## The purpose of this notebook is to write and test out equivalent implementations of https://github.com/bmdillon/JetCLR/blob/main/scripts/modules/jet_augs.py with PyTorch instead of Numpy

In [5]:
! pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25ldone
[?25h  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910460 sha256=953da0708c666701552b001cfb94b4db20f01006578b47d6960d145bc8b66062
  Stored in directory: /home/jovyan/.cache/pip/wheels/aa/16/a8/fd7737d723cc1eb8df023c016c262ff4520091e1b022f8c164
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.3.1


In [1]:
import glob
import os
import sys
import numpy as np
import random
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataListLoader, DataLoader

In [2]:
# define the global base device
world_size = torch.cuda.device_count()
multi_gpu = world_size >= 2
if world_size:
    device = torch.device("cuda:0")
    for i in range(world_size):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
else:
    device = "cpu"
    print("Device: CPU")

Device 0: NVIDIA GeForce RTX 3090


## Load Data

In [3]:
# load the datafiles
def load_data(dataset_path, flag, n_files=-1):
    data_files = glob.glob(f"{dataset_path}/{flag}/processed/*")
        
    data = []
    for i, file in enumerate(data_files):
        data += torch.load(f"{dataset_path}/{flag}/processed/data_{i}.pt")
        print(f"--- loaded file {i} from `{flag}` directory")
        if n_files!=-1 and i==n_files-1:
            break
                
    return data

dataset_path = "/../ssl-jet-vol-v2/toptagging/"

data_train = load_data(dataset_path, "train", n_files=1)
# data_valid = load_data(dataset_path, "val", n_files=4)
# data_test = load_data(dataset_path, "val", n_files=1)

--- loaded file 0 from `train` directory


In [4]:
batch_size = 128

train_loader = DataLoader(data_train, batch_size)

In [50]:
for i, bb in enumerate(train_loader):
    break

In [51]:
bb

DataBatch(x=[5980, 7], y=[128], batch=[5980], ptr=[129])

In [27]:
bb.x = bb.x[:, :3]

In [29]:
bb.x

tensor([[ 4.8051e-03, -1.2196e-02,  5.8162e+00],
        [-3.0414e-02,  1.1893e-01,  4.3164e+00],
        [ 3.2171e-03, -7.5693e-03,  4.3124e+00],
        ...,
        [-6.6076e-01,  1.8006e-01, -7.0576e-01],
        [ 1.6592e-01,  6.9361e-01, -1.0265e+00],
        [ 2.4350e-01,  1.7230e-01, -1.1979e+00]])

In [15]:
torch.bincount(bb.batch)

tensor([23, 43, 41, 25, 73, 55, 53, 70, 27, 64, 89, 28, 71, 29, 40, 22, 41, 43,
        13, 25, 57, 32, 76, 57, 27, 16, 23, 50, 40, 44, 42, 42, 25, 65, 21, 44,
        61, 23, 32, 26, 53, 47, 55, 67, 54, 36, 81, 62, 62, 66, 13, 45, 68, 23,
        69, 30, 30, 34, 49, 34, 48, 65, 41, 78, 41, 32, 81, 50, 52, 60, 57, 39,
        54, 57, 73, 53, 77, 90, 48, 59, 53, 67, 49, 43, 49, 52, 39, 55, 39, 37,
        21, 33, 21, 44, 25, 48, 56, 21, 44, 36, 53, 33, 54, 45, 59, 69, 44, 34,
        69, 64, 80, 17, 39, 60, 44, 31, 62, 18, 49, 35, 52, 42, 59, 37, 62, 59,
        28, 39])

## Reshape a batch

Original shape: (n_constit_total, 7)  
Desired shape: (batch_size, 3, n_constit)

In [5]:
def convert_x_py(bb):
    bb.x = bb.x[:, :3]
    bb.x = bb.x[:, [2, 0, 1]]
    # Compute number of constituents for each item in the batch
    n_constits = torch.bincount(bb.batch)

    # Compute maximum number of constituents
    n_constit = n_constits.max().item()

    # Allocate a tensor of the desired shape, filled with a padding value (e.g. zero)
    x_padded = torch.zeros(bb.batch.max().item() + 1, 3, n_constit)

    # Fill the padded tensor with the values from Batch.x
    for i, (start, length) in enumerate(zip(bb.ptr[:-1], n_constits)):
        x_padded[i, :, :length] = bb.x[start:start+length].t()

    return x_padded.to(device)

In [6]:
def convert_x(batch):
    """
    
    
    """
    batch.x = batch.x[:, :3]  # dim 1 ordering: eta, phi, pT
    bb.x = bb.x[:, [2, 0, 1]] # dim 1 ordering: pT, eta, phi
    batch_size = batch.num_graphs
    list_of_graphs = batch.to_data_list()  # convert the batch to a list of Data objects
    max_nodes = max(data.x.size(0) for data in list_of_graphs)  # get maximum number of nodes
#     print(max_nodes)
    padded_x = []

    for data in list_of_graphs:
        padding = torch.zeros((max_nodes - data.x.size(0), 3))  # create a padding tensor
#         print(max_nodes - data.x.size(0))
        padded_data = torch.cat([data.x, padding], dim=0)  # append padding to the graph tensor
        padded_x.append(padded_data)

    padded_x = torch.stack(padded_x, dim=0)  # stack into a single tensor
    padded_x = padded_x.transpose(1, 2)  # transpose to get [batch_size, 3, max_nodes]
    return padded_x.to(device)  # output tensor of shape [batch_size, max_nodes, 3]

In [7]:
for _, bb in enumerate(train_loader):
    convert_x_py(bb)

In [8]:
for _, bb in enumerate(train_loader):
    convert_x(bb)

In [9]:
for i, bb in enumerate(train_loader):
    break

In [9]:
convert_x_py(bb)

tensor([[[-1.0421e-02, -5.1420e-03, -1.0631e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.6977e+00,  4.6504e+00,  4.0143e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.1989e-02, -3.5567e-02, -2.8496e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 6.7657e-02, -4.6331e-02, -5.9609e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.9020e+00,  4.5779e+00,  4.0775e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-4.8572e-02,  3.7484e-02, -4.7075e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-1.6790e-01, -1.9235e-01, -1.5715e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.5775e+00,  4.0804e+00,  4.0495e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 2.1439e-02,  4.5556e-02,  1.8307e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        ...,

        [[-1.4034e-01, -1.5707e-01,

In [10]:
convert_x(bb)

tensor([[[-2.1989e-02, -3.5567e-02, -2.8496e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.0421e-02, -5.1420e-03, -1.0631e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.6977e+00,  4.6504e+00,  4.0143e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-4.8572e-02,  3.7484e-02, -4.7075e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 6.7657e-02, -4.6331e-02, -5.9609e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.9020e+00,  4.5779e+00,  4.0775e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.1439e-02,  4.5556e-02,  1.8307e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.6790e-01, -1.9235e-01, -1.5715e-01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.5775e+00,  4.0804e+00,  4.0495e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        ...,

        [[ 3.3760e-01, -1.9428e-01,

## Translation

In [29]:
for i, bb in enumerate(train_loader):
    break

In [30]:
converted_batch = convert_x(bb)

In [31]:
converted_batch.device

device(type='cuda', index=0)

In [23]:
def translate_jets(batch, width=1.0):
    mask = (batch[:, 0] > 0).float() # 1 for constituents with non-zero pT, 0 otherwise

    # Calculating ptp (max - min) for eta and phi
    ptp_eta = (batch[:, 1, :].max(dim=-1, keepdim=True).values - batch[:, 1, :].min(dim=-1, keepdim=True).values)
    ptp_phi = (batch[:, 2, :].max(dim=-1, keepdim=True).values - batch[:, 2, :].min(dim=-1, keepdim=True).values)

    low_eta = -width * ptp_eta
    high_eta = +width * ptp_eta
    low_phi = torch.maximum(-width * ptp_phi, -torch.tensor(np.pi) - batch[:, 2, :].min(dim=1).values.reshape(ptp_phi.shape))
    high_phi = torch.minimum(+width * ptp_phi, +torch.tensor(np.pi) - batch[:, 2, :].max(dim=1).values.reshape(ptp_phi.shape))

    shift_eta = mask * (torch.rand_like(low_eta) * (high_eta - low_eta) + low_eta)
    shift_phi = mask * (torch.rand_like(low_phi) * (high_phi - low_phi) + low_phi)
    shift = torch.stack([torch.zeros_like(shift_eta), shift_eta, shift_phi], dim=1).squeeze()

    shifted_batch = batch + shift
    return shifted_batch

In [32]:
new_batch = translate_jets(converted_batch)

In [67]:
converted_batch[0,1:,:]

tensor([[ 0.0048, -0.0304,  0.0032, -0.0078,  0.0176, -0.0078, -0.0123, -0.0224,
         -0.0214,  0.0220, -0.0722,  0.0343,  0.0810,  0.0298, -0.0354,  0.0295,
          0.0528,  0.1544,  0.0066, -0.2777, -0.2948,  0.4583, -0.0605,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.0122,  0.1189, -0.0076, -0.0310, -0.0217, -0.0244, -0.0202, -0.0276,


In [56]:
new_batch = translate_jets(converted_batch)

In [61]:
new_batch[0, 2,:] - converted_batch[0,2,:]

tensor([0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241,
        0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241, 0.3241,
        0.3241, 0.3241, 0.3241, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])

## Rotation

In [64]:
def rotate_jets(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets rotated independently in eta-phi, same shape as input
    '''
    rot_angle = torch.rand(batch.shape[0]) * 2 * torch.tensor(np.pi)
    c = torch.cos(rot_angle)
    s = torch.sin(rot_angle)
    o = torch.ones_like(rot_angle)
    z = torch.zeros_like(rot_angle)
    rot_matrix = torch.stack([o, z, z, z, c, -s, z, s, c], dim=1).reshape(-1, 3, 3).transpose(0, 2) # (batchsize, 3, 3)
    return torch.einsum('ijk,lji->ilk', batch, rot_matrix)

In [65]:
new_batch = rotate_jets(converted_batch)

In [73]:
converted_batch[0,1:,:]

tensor([[ 0.0048, -0.0304,  0.0032, -0.0078,  0.0176, -0.0078, -0.0123, -0.0224,
         -0.0214,  0.0220, -0.0722,  0.0343,  0.0810,  0.0298, -0.0354,  0.0295,
          0.0528,  0.1544,  0.0066, -0.2777, -0.2948,  0.4583, -0.0605,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.0122,  0.1189, -0.0076, -0.0310, -0.0217, -0.0244, -0.0202, -0.0276,


In [70]:
new_batch[0, 1:,:]

tensor([[ 2.0638e-04, -1.7214e-02, -8.8983e-05,  1.9022e-02, -8.0113e-03,
          1.6462e-02,  1.9039e-02,  3.1191e-02,  2.0445e-02, -1.4775e-02,
          1.4339e-01, -2.0112e-02, -6.8857e-02, -3.1422e-02,  5.5396e-02,
         -2.5259e-03, -3.6966e-02, -1.2229e-01, -2.8403e-02,  3.0041e-01,
          2.9621e-01, -6.5082e-01,  2.0790e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.

## Normalization

In [74]:
def normalise_pts(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-normalised jets, pT in each jet sums to 1, same shape as input
    '''
    batch_norm = batch.clone()
    pt_sum = batch_norm[:, 0, :].sum(dim=1, keepdim=True)
    normalized_pt = batch_norm[:, 0, :] / pt_sum
    normalized_pt = torch.nan_to_num(normalized_pt, posinf=0.0, neginf=0.0)
    batch_norm[:, 0, :] = normalized_pt
    return batch_norm

In [75]:
new_batch = normalise_pts(converted_batch)

In [77]:
new_batch[:, 0, :].sum(dim=1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1

## Rescaling

In [78]:
def rescale_pts(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input
    '''
    batch_rscl = batch.clone()
    rescaled_pt = batch_rscl[:, 0, :] / 600
    rescaled_pt = torch.nan_to_num(rescaled_pt, posinf=0.0, neginf=0.0)
    batch_rscl[:, 0, :] = rescaled_pt
    return batch_rscl

In [79]:
new_batch = rescale_pts(converted_batch)

In [82]:
converted_batch[:, 0, :].sum(dim=1, keepdim=True) / new_batch[:, 0, :].sum(dim=1, keepdim=True)

tensor([[599.9999],
        [600.0000],
        [600.0001],
        [599.9999],
        [600.0001],
        [600.0001],
        [599.9999],
        [600.0000],
        [600.0000],
        [600.0000],
        [599.9999],
        [600.0001],
        [600.0000],
        [600.0001],
        [600.0000],
        [600.0000],
        [600.0001],
        [599.9999],
        [600.0000],
        [600.0001],
        [600.0000],
        [600.0001],
        [599.9999],
        [600.0000],
        [600.0000],
        [600.0000],
        [600.0000],
        [600.0001],
        [600.0000],
        [600.0000],
        [599.9999],
        [599.9999],
        [599.9999],
        [599.9998],
        [599.9999],
        [599.9999],
        [600.0000],
        [600.0000],
        [600.0001],
        [600.0000],
        [600.0000],
        [599.9999],
        [600.0000],
        [600.0000],
        [600.0001],
        [600.0001],
        [599.9999],
        [600.0001],
        [600.0000],
        [599.9999],


In [80]:
new_batch[:, 0, :].sum(dim=1, keepdim=True)

tensor([[0.0680],
        [0.0970],
        [0.1118],
        [0.0763],
        [0.1033],
        [0.0805],
        [0.1160],
        [0.1597],
        [0.0755],
        [0.1042],
        [0.1156],
        [0.0717],
        [0.1223],
        [0.0849],
        [0.0677],
        [0.0375],
        [0.0727],
        [0.0970],
        [0.0544],
        [0.0806],
        [0.1030],
        [0.0754],
        [0.1688],
        [0.1244],
        [0.0335],
        [0.0499],
        [0.0941],
        [0.1180],
        [0.0766],
        [0.1024],
        [0.0964],
        [0.1082],
        [0.0713],
        [0.1247],
        [0.0491],
        [0.1013],
        [0.1620],
        [0.0781],
        [0.0794],
        [0.0637],
        [0.1218],
        [0.1279],
        [0.1579],
        [0.1243],
        [0.1416],
        [0.0782],
        [0.1440],
        [0.1353],
        [0.1100],
        [0.1818],
        [0.0536],
        [0.1345],
        [0.0931],
        [0.0282],
        [0.1512],
        [0

## Cropping

In [83]:
def crop_jets( batch, nc=50):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of cropped jets, each jet is cropped to nc constituents, shape (batchsize, 3, nc)
    '''
    batch_crop = batch.clone()
    return batch_crop[:,:,0:nc]

In [84]:
new_batch = crop_jets(converted_batch)

In [85]:
new_batch.shape

torch.Size([128, 3, 50])

## Soft splitting

In [92]:
def distort_jets(batch, strength=0.1, pT_clip_min=0.1):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with each constituents position shifted independently, 
            shifts drawn from normal with mean 0, std strength/pT, same shape as input
    '''
    mask = (batch[:, 0] > 0).float() # 1 for constituents with non-zero pT, 0 otherwise
    pT = batch[:, 0]  # (batchsize, n_constit)
    clipped_pT = pT.clamp(min=pT_clip_min)
    
    shift_eta = mask * torch.randn_like(pT) * strength / clipped_pT
    shift_eta = torch.nan_to_num(shift_eta, posinf=0.0, neginf=0.0)
    
    shift_phi = mask * torch.randn_like(pT) * strength / clipped_pT
    shift_phi = torch.nan_to_num(shift_phi, posinf=0.0, neginf=0.0)
    
    zeros_tensor = torch.zeros_like(shift_eta)
    shift = torch.stack([zeros_tensor, shift_eta, shift_phi], dim=1)
    return batch + shift

In [93]:
new_batch = distort_jets(converted_batch)

In [94]:
converted_batch[0,1,:]

tensor([ 0.0048, -0.0304,  0.0032, -0.0078,  0.0176, -0.0078, -0.0123, -0.0224,
        -0.0214,  0.0220, -0.0722,  0.0343,  0.0810,  0.0298, -0.0354,  0.0295,
         0.0528,  0.1544,  0.0066, -0.2777, -0.2948,  0.4583, -0.0605,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])

In [95]:
new_batch[0,1,:]

tensor([ 3.8488e-03, -1.4790e-02,  1.0660e-02, -1.6686e-02,  1.4937e-02,
        -3.1109e-02,  4.1678e-02,  2.6737e-04,  6.9504e-03,  2.5316e-02,
        -8.1712e-02,  5.1337e-02,  1.2423e-01,  1.3699e-01,  1.7869e-02,
        -8.2534e-02,  1.9962e-01, -7.8314e-02, -2.3362e-03, -2.6009e-01,
        -7.7566e-01,  4.5826e-01, -6.0504e-02,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0

## Collinear splitting

In [None]:
for i, bb in enumerate(train_loader):
    break
converted_batch = convert_x(bb).to(device)

In [124]:
def are_all_nonzero_elements_unique(tensor):
    non_zero_elements = tensor[tensor != 0]
    unique_elements = torch.unique(non_zero_elements)
    return unique_elements.numel() == non_zero_elements.numel()

In [125]:
are_all_nonzero_elements_unique(converted_batch[0,1,:])

True

In [306]:
converted_batch.device

device(type='cuda', index=0)

In [307]:
def collinear_fill_jets_2(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with collinear splittings, the function attempts to fill as many of the zero-padded args.nconstit
    entries with collinear splittings of the constituents by splitting each constituent at most once, same shape as input
    '''
    cloned_batch = batch.clone()
    num_constituents = batch.shape[2]
    num_non_zeros = (batch[:, 0, :] != 0.0).sum(dim=-1)
#     num_non_zeros = np.array( [ np.where( batch[:,0,:][i]>0.0)[0].shape[0] for i in range(len(batch)) ] )
    
    for batch_idx in range(len(batch)):
        
#         max_non_zeros = max(num_non_zeros[batch_idx], int(num_constituents / 2))
        num_zeros_to_fill = min([num_non_zeros[batch_idx], num_constituents - num_non_zeros[batch_idx]])
        elements_to_split = torch.randperm(num_non_zeros[batch_idx])[:num_zeros_to_fill]
#         elements_to_split = np.random.choice( np.linspace(0,num_zeros_to_fill-1,num_zeros_to_fill), size=num_zeros_to_fill, replace=False )
        random_scaling_factors = torch.rand(num_zeros_to_fill)

        for split_idx in range(num_zeros_to_fill):
            element_idx = int(elements_to_split[split_idx])
            scaling_factor = random_scaling_factors[split_idx]
            original_pt = batch[batch_idx, 0, element_idx]
            cloned_batch[batch_idx, 0, element_idx] = scaling_factor * original_pt
#             cloned_batch[batch_idx, 0, int(num_non_zeros[batch_idx] + split_idx)] = (1 - scaling_factor) * batch[batch_idx, 0, element_idx]
            
            cloned_batch[batch_idx, 0, int(num_non_zeros[batch_idx] + split_idx)] = original_pt - cloned_batch[batch_idx, 0, element_idx]
#             total = cloned_batch[batch_idx, 0, int(num_non_zeros[batch_idx] + split_idx)] + cloned_batch[batch_idx, 0, element_idx]
#             print(abs(total.item() - original_pt.item() < 0.001))
            cloned_batch[batch_idx, 1, int(num_non_zeros[batch_idx] + split_idx)] = batch[batch_idx, 1, element_idx]
            cloned_batch[batch_idx, 2, int(num_non_zeros[batch_idx] + split_idx)] = batch[batch_idx, 2, element_idx]

    return cloned_batch

In [308]:
collinear_fill_jets_2(converted_batch)

tensor([[[ 2.2190e-01,  3.0742e+00,  1.6362e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.8051e-03, -3.0414e-02,  3.2171e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.2196e-02,  1.1893e-01, -7.5693e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.2420e+00,  4.0446e+00,  2.0844e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.7744e-02,  1.4459e-02, -1.4909e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-9.5150e-03,  1.1490e-02,  5.4397e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 8.3519e-01,  4.0517e-02,  2.9810e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.2029e-02,  5.5684e-02, -5.2199e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.9171e-02,  9.4480e-02,  3.1358e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        ...,

        [[ 1.3188e+00,  1.7132e-01,

In [312]:
def collinear_fill_jets_np( batch ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with collinear splittings, the function attempts to fill as many of the zero-padded args.nconstit
    entries with collinear splittings of the constituents by splitting each constituent at most once, same shape as input
    '''
    batchb = batch.copy()
    nc = batch.shape[2]  # number of constituents
    nzs = np.array( [ np.where( batch[:,0,:][i]!=0.0)[0].shape[0] for i in range(len(batch)) ] ) # number of non-zero elements
    for k in range(len(batch)):
        zs1 = np.min([nzs[k], nc-nzs[k]])  # number of zero padded entries to fill
        els = np.random.choice( np.linspace(0,nzs[k]-1,nzs[k]), size=zs1, replace=False )
        rs = np.random.uniform( size=zs1 ) # scaling factor
        for j in range(zs1):
            # split pT
            batchb[k,0,int(els[j])] = rs[j]*batch[k,0,int(els[j])]
            batchb[k,0,int(nzs[k]+j)] = (1-rs[j])*batch[k,0,int(els[j])]
            # keep eta and phi
            batchb[k,1,int(nzs[k]+j)] = batch[k,1,int(els[j])]
            batchb[k,2,int(nzs[k]+j)] = batch[k,2,int(els[j])]
    return torch.tensor(batchb).to(device)

In [313]:
collinear_fill_jets_np(np.array(converted_batch.cpu()))

tensor([[[ 5.0245e+00,  1.6125e+00,  1.2359e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 4.8051e-03, -3.0414e-02,  3.2171e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-1.2196e-02,  1.1893e-01, -7.5693e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 3.8368e-02,  2.9158e+00,  1.7454e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.7744e-02,  1.4459e-02, -1.4909e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-9.5150e-03,  1.1490e-02,  5.4397e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.5912e-01,  3.0305e+00,  3.6987e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.2029e-02,  5.5684e-02, -5.2199e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-2.9171e-02,  9.4480e-02,  3.1358e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        ...,

        [[ 2.6528e+00,  3.7950e+00,

In [292]:
new_batch = collinear_fill_jets_np(np.array(converted_batch.cpu()))

In [293]:
new_batch[0,1:]

array([[ 0.00480515, -0.03041375,  0.0032171 , -0.00780481,  0.01759285,
        -0.00775272, -0.01226979, -0.02236104, -0.02136147,  0.02201229,
        -0.07218915,  0.03431445,  0.08099574,  0.02978092, -0.03541929,
         0.02953249,  0.05279046,  0.1543557 ,  0.00662047, -0.27768683,
        -0.2948079 ,  0.45826262, -0.0605036 , -0.02236104, -0.02136147,
        -0.03041375,  0.00662047,  0.03431445, -0.00780481,  0.08099574,
         0.45826262,  0.1543557 ,  0.0032171 , -0.07218915,  0.02201229,
         0.01759285, -0.0605036 , -0.03541929, -0.00775272, -0.2948079 ,
         0.02953249,  0.00480515,  0.05279046, -0.01226979, -0.27768683,
         0.02978092,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0. 

In [207]:
converted_batch[0,1,:]

tensor([ 0.0048, -0.0304,  0.0032, -0.0078,  0.0176, -0.0078, -0.0123, -0.0224,
        -0.0214,  0.0220, -0.0722,  0.0343,  0.0810,  0.0298, -0.0354,  0.0295,
         0.0528,  0.1544,  0.0066, -0.2777, -0.2948,  0.4583, -0.0605,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])

In [304]:
split_dict = {}
k=0
for i in range(new_batch.shape[2]):
    split_dict[i] = []
    for j in range(new_batch.shape[2]):
        if new_batch[k,1,j] == new_batch[k,1,i] and new_batch[k,2,j] == new_batch[k,2,i] and new_batch[k,1,i] != 0:
            split_dict[i].append(j)

In [305]:
for i in split_dict.keys():
    if len(split_dict[i]) == 2:
        print(i, new_batch[k,0,split_dict[i][0]] + new_batch[k,0,split_dict[i][1]] - converted_batch[k,0,min(split_dict[i])].item())

0 tensor(0., device='cuda:0')
1 tensor(0., device='cuda:0')
2 tensor(0., device='cuda:0')
3 tensor(0., device='cuda:0')
4 tensor(0., device='cuda:0')
5 tensor(0., device='cuda:0')
6 tensor(0., device='cuda:0')
7 tensor(0., device='cuda:0')
8 tensor(0., device='cuda:0')
9 tensor(0., device='cuda:0')
10 tensor(0., device='cuda:0')
11 tensor(0., device='cuda:0')
12 tensor(0., device='cuda:0')
13 tensor(0., device='cuda:0')
14 tensor(0., device='cuda:0')
15 tensor(0., device='cuda:0')
16 tensor(0., device='cuda:0')
17 tensor(0., device='cuda:0')
18 tensor(0., device='cuda:0')
19 tensor(0., device='cuda:0')
20 tensor(0., device='cuda:0')
21 tensor(0., device='cuda:0')
22 tensor(0., device='cuda:0')
23 tensor(0., device='cuda:0')
24 tensor(0., device='cuda:0')
25 tensor(0., device='cuda:0')
26 tensor(0., device='cuda:0')
27 tensor(0., device='cuda:0')
28 tensor(0., device='cuda:0')
29 tensor(0., device='cuda:0')
30 tensor(0., device='cuda:0')
31 tensor(0., device='cuda:0')
32 tensor(0., devi

In [279]:
for i in range(128):
    print(f"{i}: {are_all_nonzero_elements_unique(converted_batch[i,1,:])}")
    

0: True
1: True
2: False
3: True
4: False
5: False
6: False
7: False
8: True
9: False
10: False
11: True
12: False
13: False
14: True
15: True
16: False
17: True
18: True
19: True
20: False
21: False
22: False
23: True
24: True
25: True
26: True
27: False
28: True
29: False
30: True
31: False
32: True
33: False
34: True
35: False
36: False
37: False
38: False
39: True
40: True
41: False
42: True
43: False
44: False
45: False
46: False
47: True
48: False
49: False
50: True
51: False
52: False
53: True
54: False
55: False
56: True
57: True
58: False
59: False
60: False
61: True
62: False
63: True
64: False
65: False
66: False
67: False
68: False
69: False
70: False
71: False
72: True
73: False
74: False
75: False
76: False
77: False
78: False
79: False
80: False
81: False
82: False
83: False
84: True
85: False
86: True
87: False
88: False
89: True
90: True
91: False
92: True
93: True
94: True
95: False
96: True
97: True
98: False
99: True
100: False
101: False
102: False
103: True
104: F

In [271]:
split_dict

{0: [0],
 1: [1],
 2: [2],
 3: [3],
 4: [4],
 5: [5],
 6: [6],
 7: [7],
 8: [8],
 9: [9],
 10: [10],
 11: [11],
 12: [12],
 13: [13],
 14: [14],
 15: [15],
 16: [16],
 17: [17],
 18: [18],
 19: [19],
 20: [20],
 21: [21],
 22: [22],
 23: [23],
 24: [24],
 25: [25],
 26: [26],
 27: [27],
 28: [28],
 29: [29],
 30: [30],
 31: [31],
 32: [32],
 33: [33],
 34: [34],
 35: [35, 79],
 36: [36],
 37: [37],
 38: [38],
 39: [39, 87],
 40: [40],
 41: [41],
 42: [42],
 43: [43],
 44: [44],
 45: [45],
 46: [46],
 47: [47],
 48: [48, 74],
 49: [49],
 50: [50],
 51: [51],
 52: [52],
 53: [53],
 54: [54],
 55: [55],
 56: [56],
 57: [57],
 58: [58],
 59: [59],
 60: [60],
 61: [61],
 62: [62],
 63: [63],
 64: [64],
 65: [65],
 66: [66],
 67: [67],
 68: [68],
 69: [69],
 70: [70, 89],
 71: [71],
 72: [72],
 73: [73],
 74: [48, 74],
 75: [75],
 76: [76],
 77: [77],
 78: [78],
 79: [35, 79],
 80: [80],
 81: [81],
 82: [82],
 83: [83],
 84: [84],
 85: [85],
 86: [86],
 87: [39, 87],
 88: [88],
 89: [70, 89]

In [32]:
def collinear_fill_jets( batch ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with collinear splittings, the function attempts to fill as many of the zero-padded args.nconstit
    entries with collinear splittings of the constituents by splitting each constituent at most once, same shape as input
    '''
    batchb = batch.copy()
    nc = batch.shape[2]
    nzs = np.array( [ np.where( batch[:,0,:][i]!=0.0)[0].shape[0] for i in range(len(batch)) ] )
    print(nzs)
    for k in range(len(batch)):
        nzs1 = np.max( [ nzs[k], int(nc/2) ] )
        zs1 = int(nc-nzs1)
        els = np.random.choice( np.linspace(0,nzs1-1,nzs1), size=zs1, replace=False )
        rs = np.random.uniform( size=zs1 )
        if k == 0:
            print(f"nzs[k]:{nzs[k]}")
            print(f"els: {els.shape}")
            print(23 in els)
        for j in range(zs1):
            batchb[k,0,int(els[j])] = rs[j]*batch[k,0,int(els[j])]
            batchb[k,0,int(nzs[k]+j)] = (1-rs[j])*batch[k,0,int(els[j])]
            batchb[k,1,int(nzs[k]+j)] = batch[k,1,int(els[j])]
            batchb[k,2,int(nzs[k]+j)] = batch[k,2,int(els[j])]
    return batchb

In [12]:
converted_batch = convert_x(bb)

In [33]:
new_batch = collinear_fill_jets(np.array(converted_batch.cpu()))

[23 43 41 25 73 55 53 70 27 64 89 28 71 29 40 22 41 43 13 25 57 32 76 57
 27 16 23 50 40 44 42 42 25 65 21 44 61 23 32 26 53 47 55 67 54 36 81 62
 62 66 13 45 68 23 69 30 30 34 49 34 48 65 41 78 41 32 81 50 52 60 57 39
 54 57 73 53 77 90 48 59 53 67 49 43 49 52 39 55 39 37 21 33 21 44 25 48
 56 21 44 36 53 33 54 45 59 69 44 34 69 64 80 17 39 60 44 31 62 18 49 35
 52 42 59 37 62 59 28 39]
nzs[k]:23
els: (45,)
True


In [25]:
for i in range(128):
    print(bb[i].x.shape)

torch.Size([23, 3])
torch.Size([43, 3])
torch.Size([41, 3])
torch.Size([25, 3])
torch.Size([73, 3])
torch.Size([55, 3])
torch.Size([53, 3])
torch.Size([70, 3])
torch.Size([27, 3])
torch.Size([64, 3])
torch.Size([89, 3])
torch.Size([28, 3])
torch.Size([71, 3])
torch.Size([29, 3])
torch.Size([40, 3])
torch.Size([22, 3])
torch.Size([41, 3])
torch.Size([43, 3])
torch.Size([13, 3])
torch.Size([25, 3])
torch.Size([57, 3])
torch.Size([32, 3])
torch.Size([76, 3])
torch.Size([57, 3])
torch.Size([27, 3])
torch.Size([16, 3])
torch.Size([23, 3])
torch.Size([50, 3])
torch.Size([40, 3])
torch.Size([44, 3])
torch.Size([42, 3])
torch.Size([42, 3])
torch.Size([25, 3])
torch.Size([65, 3])
torch.Size([21, 3])
torch.Size([44, 3])
torch.Size([61, 3])
torch.Size([23, 3])
torch.Size([32, 3])
torch.Size([26, 3])
torch.Size([53, 3])
torch.Size([47, 3])
torch.Size([55, 3])
torch.Size([67, 3])
torch.Size([54, 3])
torch.Size([36, 3])
torch.Size([81, 3])
torch.Size([62, 3])
torch.Size([62, 3])
torch.Size([66, 3])


In [28]:
new_batch[0,0,:]

array([ 3.5951667 ,  1.3918711 ,  2.7377632 ,  3.1586854 ,  1.0422342 ,
        2.639468  ,  0.21917054,  1.0453333 ,  0.8468509 ,  0.5749914 ,
        1.4674066 ,  0.4539976 ,  1.0031737 ,  0.5150111 ,  0.86192113,
        0.4999498 ,  0.0963839 ,  0.12090765,  0.19472317,  0.11526545,
        0.02369792, -0.015419  , -1.1311806 ,  0.        ,  0.03653057,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        1.3832604 ,  1.5746067 ,  1.7138896 ,  2.2210484 ,  2.3937337 ,
        0.        ,  0.44604877,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.8752433 ,
        0.        ,  0.        ,  1.413772  ,  0.19832505,  0.        ,
       -0.6592403 ,  0.        ,  0.        ,  0.        ,  0.  

In [29]:
np.array(converted_batch.cpu())[0,0,:]

array([ 5.816215  ,  4.316429  ,  4.31237   ,  3.3570106 ,  2.7561238 ,
        2.6759984 ,  2.6129043 ,  2.2736187 ,  2.091402  ,  1.9887633 ,
        1.9134554 ,  1.837258  ,  1.5402927 ,  1.3902544 ,  1.0075308 ,
        0.7134997 ,  0.6245758 ,  0.54390323,  0.48727167,  0.24372263,
        0.13339488, -0.6746593 , -1.1519126 ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.  