In [8]:
import os
import sys
import numpy as np
import random
import time
import torch
# set gpu device
device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" )
torch.use_deterministic_algorithms(True)
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

# Getting the batch x

In [22]:
#loading in data ------------------------------------------------------------
import My_Anom_extargs as args
sys.path.insert(1, '/remote/gpu05/rueschkamp/projects/torch_datasets/')
from semi_dataset import SemiV_Dataset
from torch.utils.data import DataLoader


#starting training loader --------------------------------------


training_set = SemiV_Dataset(
                                    data_path = args.data_path,
                                    signal_origin= "qcd",
                                    usage= "training",
                                    number_constit= args.n_constit,
                                    number_of_jets= 1e3,
                                    ratio=0.2
                                    )

dl_training = DataLoader(training_set,batch_size=128, shuffle=True)

for i,(data,labels) in enumerate(dl_training):

    x = data

    break

In [3]:
print(x)

tensor([[[ 6.0063e+01,  3.6949e+01,  1.6779e+01,  1.2583e+01,  1.0393e+01,
           3.0307e+00,  2.8600e+00,  2.8419e+00,  2.3515e+00,  2.0964e+00,
           1.7141e+00,  1.7024e+00,  1.3026e+00,  1.2009e+00,  1.1571e+00,
           4.4814e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 3.1984e-03, -8.2359e-02,  1.6986e-02, -3.8284e-04, -4.6613e-03,
           2.4926e-03,  3.4589e-02,  1.4970e-01,  1.4167e-01,  2.8626e-01,
           2.8447e-01, -1.8889e-01,  4.1370e-01,  1.6515e-01,  8.1092e-02,
           3.5033e-01,  

In [None]:
#x = torch.randint(0,20,[128, 3, 50])
x = torch.ones([128, 3, 50])*20

pTs= torch.ones([128, 50])*10
etas = torch.ones([128, 50])*100
phis = torch.ones([128, 50])*1000
x = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)

x = x = torch.ones([1, 3, 50])

In [None]:
x = x.to(device)

# Anomaly Augmentations

### Basic fkt

In [9]:
def rescale_pts(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input
    '''
    batch_rscl = batch.clone()
    batch_rscl[:,0,:] = torch.nan_to_num(batch_rscl[:,0,:]/600, nan=0.0, posinf=0.0, neginf=0.0)
    return batch_rscl

def recentre_jet(batch):
    batchc = batch.clone()
    pts = batch[:,0,:]
    etas = batch[:,1,:]
    phis = batch[:,2,:]
    if torch.sum( pts ) != 0:
        eta_shift = torch.sum(  pts*etas  ) / torch.sum( pts )
        phi_shift = torch.sum(  pts*phis ) / torch.sum( pts )
        etas = etas - eta_shift
        phis = phis - phi_shift

    pTs, indices = torch.sort(pts, dim=1, descending=True) # Ordering pTs
    etas = etas.gather(dim=1,index=indices)
    phis = phis.gather(dim=1,index=indices)

    batch_recentred = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)

    #print(batch_dropped.size())
    return batch_recentred



#### Checking recentre

In [None]:
def recentre_jet_np(batch):
    batchc = batch.copy()
    nj = batch.shape[0]
    for i in range( nj ):
        pts = batch[i,0,:]
        etas = batch[i,1,:]
        phis = batch[i,2,:]
        nc = len( pts )
        eta_shift = np.sum( [ pts[j]*etas[j] for j in range( nc ) ] ) / np.sum( pts )
        phi_shift = np.sum( [ pts[j]*phis[j] for j in range( nc ) ] ) / np.sum( pts )
        batchc[i,1,:] = batch[i,1,:] - eta_shift
        batchc[i,2,:] = batch[i,2,:] - phi_shift
    return batchc

t = recentre_jet(x)
n = recentre_jet_np(x.cpu().numpy())

print(t.cpu().numpy() - n)

#### Checking rescale

In [None]:


def rescale_pt_np(dataset):
    for i in range(0, dataset.shape[0]):
        dataset[i,0,:] = dataset[i,0,:]/600
    return dataset

t = rescale_pts(x).cpu()
n = rescale_pt_np(x.cpu().numpy())

print(t.numpy() - n)


## Drop Constituents

### Independent

In [None]:
torch.manual_seed(42)

In [None]:
def drop_constits_jet( batch, prob=0.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where each jet has some fraction of missing constituents
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batch_dropped = batch.clone()
    #n_nonzero = torch.sum(batch_dropped[:,0,:]>0, dim=1)
    nj = batch_dropped.shape[0]
    nc = batch_dropped.shape[2]
    torch.manual_seed(42)
    mask = torch.rand((nj, nc)) > prob
    print("first mask:",mask)
    mask = mask.int().to(device)
    print(mask)
    num_zeros_tensor = (mask == 0).sum().item()
    #print(num_zeros_tensor)
    batch_dropped = batch_dropped * mask.unsqueeze(1)

    #print(batch_dropped)
    pts = torch.sum( batch[:,0,:], axis=1 )
    pts_aug = torch.sum( batch_dropped[:,0,:], axis=1 )

    pts_aug[pts_aug == 0] = 1
    pt_rescale = pts/pts_aug

    pTs= batch_dropped[:,0,:]
    etas = batch_dropped[:,1,:]
    phis = batch_dropped[:,2,:]
    pTs *= pt_rescale.unsqueeze(1)
    
    batch_dropped = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)

    #print(batch_dropped.size())
    return recentre_jet( batch_dropped )


In [None]:
def drop_constits_jet_np( batch, prob=0.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where each jet has some fraction of missing constituents
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batchc = batch.copy()
    nj = batchc.shape[0]
    nc = batchc.shape[2]
    nzs = np.array( [ np.where( batchc[:,0,:]>0.0 )[0].shape[0] for i in range(len(batch)) ] )

    torch.manual_seed(42)
    mask = torch.rand((nj, nc)) > prob
    print("first mask:",mask)
    mask = mask.int()
    mask = mask.numpy()
    print(mask)
    num_zeros_array = ((np.array(mask)) == 0).sum()
    #print(num_zeros_array)

    for i in range( nj ):
        for j in range( nc ):
            if mask[i][j]==0:
                batchc[i,:,j] = np.array([0.0,0.0,0.0])
    pts = np.sum( batch[:,0,:], axis=1 )
    pts_aug = np.sum( batchc[:,0,:], axis=1 )
    pt_rescale = [ pts[i]/pts_aug[i] for i in range(nj) ]
    for i in range(nj):
        batchc[i,0,:] = batchc[i,0,:]*pt_rescale[i]
    return recentre_jet_np( batchc )

x=x.to(device)
t = drop_constits_jet(x,0.9)
n = drop_constits_jet_np(x.cpu().numpy(),0.9)


print(x.size())
print(t)
print(n)

In [36]:
def drop_constits_jet_ordered( batch, prob=0.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where each jet has some fraction of missing constituents
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batch_dropped = batch.clone()
    #n_nonzero = torch.sum(batch_dropped[:,0,:]>0, dim=1)
    nj = batch_dropped.shape[0]
    nc = batch_dropped.shape[2]
    mask = torch.rand((nj, nc)) > prob
    mask = mask.int().to(device)
    batch_dropped = batch_dropped * mask.unsqueeze(1)

    #print(batch_dropped)
    pts = torch.sum( batch[:,0,:], axis=1 )
    pts_aug = torch.sum( batch_dropped[:,0,:], axis=1 )

    pTs= batch_dropped[:,0,:]
        
    if torch.any(pts_aug != 0):
        pt_rescale = torch.where(pts_aug != 0, pts / pts_aug, torch.ones_like(pts))
        print(pt_rescale)
        pTs *= pt_rescale.unsqueeze(1)


    pTs, indices = torch.sort(pTs, dim=1, descending=True) # Ordering pTs
    #print("pts:", pTs)
    etas = batch_dropped[:,1,:]
    etas = etas.gather(dim=1,index=indices)
    phis = batch_dropped[:,2,:]
    phis = phis.gather(dim=1,index=indices)

    
    batch_dropped = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)
    non_zero_count = torch.sum(pTs != 0, dim=-1, keepdim=True)
    #print(batch_dropped)
    if torch.min(non_zero_count)==0:
        return drop_constits_jet_ordered(batch,prob)
    else:
        return recentre_jet( batch_dropped )

### Safe

In [26]:
pTs= x[:,0,:]
print(pTs.size())
# Count non-zero entries in the last dimension
non_zero_count = torch.sum(pTs != 0, dim=-1, keepdim=True)
dropping_numbers = torch.round( non_zero_count * 0.3)
#print(dropping_numbers)
total_length_mask = pTs.size(1)

mask_safe = []
# Create a mask with zeros distributed between ones
for i in range(len(dropping_numbers)):
    length_of_nonzero_pTs = int(non_zero_count[i])
    #print(dropping_numbers[i])
    n_drop = int(dropping_numbers[i])

    non_zero_mask = torch.cat((torch.zeros(n_drop), torch.ones(length_of_nonzero_pTs - n_drop))) #creating mask for non zero entries
    shuffled_non_zero_mask = non_zero_mask[torch.randperm(non_zero_mask.size(0))]  # Generate random permutation of non-zero mask

    #print(shuffled_non_zero_mask)

    jet_mask = torch.cat((shuffled_non_zero_mask,torch.zeros(total_length_mask-length_of_nonzero_pTs))) #Creating mask for whole jet

    mask_safe.append(jet_mask)

mask = torch.stack(mask_safe)

result = x.cpu() * mask.unsqueeze(1)
print(result)
    
# Shuffle the mask tensor

torch.Size([128, 50])
tensor([[[ 0.0000e+00,  0.0000e+00,  1.4455e+01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-0.0000e+00, -0.0000e+00,  5.4634e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -0.0000e+00,  1.1028e-01,  ..., -0.0000e+00,
          -0.0000e+00, -0.0000e+00]],

        [[ 0.0000e+00,  2.0434e+01,  1.5722e+01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -2.6525e-02, -1.5643e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-0.0000e+00, -5.1978e-02, -4.5160e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  1.7668e+01,  1.2581e+01,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  5.1081e-02,  9.1307e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.1180e-02,  3.9251e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        ...,

        [[ 2.

In [39]:
def drop_constits_jet_safe( batch, prob=0.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where each jet has some fraction of missing constituents
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batch_dropped = batch.clone()

    pTs= batch_dropped[:,0,:]
    non_zero_count = torch.sum(pTs != 0, dim=-1, keepdim=True)
    dropping_numbers = torch.round( non_zero_count * prob)
    #print(dropping_numbers)
    total_length_mask = pTs.size(1)

    mask_safe = []
    # Create a mask with zeros distributed between ones
    for i in range(len(dropping_numbers)):
        length_of_nonzero_pTs = int(non_zero_count[i])
        #print(dropping_numbers[i])
        n_drop = int(dropping_numbers[i])

        non_zero_mask = torch.cat((torch.zeros(n_drop), torch.ones(length_of_nonzero_pTs - n_drop))) #creating mask for non zero entries
        shuffled_non_zero_mask = non_zero_mask[torch.randperm(non_zero_mask.size(0))]  # Generate random permutation of non-zero mask

        #print(shuffled_non_zero_mask)

        jet_mask = torch.cat((shuffled_non_zero_mask,torch.zeros(total_length_mask-length_of_nonzero_pTs))) #Creating mask for whole jet

        mask_safe.append(jet_mask)

    mask = torch.stack(mask_safe).to(device)

    result = x * mask.unsqueeze(1)

    

    #From here on just rescaling and reordering --> Masking is done here. 

    #print(batch_dropped)
    pts = torch.sum( result[:,0,:], axis=1 )
    pts_aug = torch.sum( result[:,0,:], axis=1 )


    if torch.any(pts_aug != 0):
        pt_rescale = torch.where(pts_aug != 0, pts / pts_aug, torch.ones_like(pts))
        #print(pt_rescale)
        pTs *= pt_rescale.unsqueeze(1)

    pTs_dropped =result[:,0,:]

    pTs, indices = torch.sort(pTs_dropped, dim=1, descending=True) # Ordering pTs
    #print("pts:", pTs)
    etas = result[:,1,:]
    etas = etas.gather(dim=1,index=indices)
    phis = result[:,2,:]
    phis = phis.gather(dim=1,index=indices)

    
    batch_dropped = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)

    #print(batch_dropped)
    return recentre_jet( batch_dropped )

In [40]:
print(drop_constits_jet_safe(x)[0])
print(x[0])

tensor([[ 1.8307e+01,  1.3657e+01,  1.0552e+01,  7.4996e+00,  7.3442e+00,
          5.7860e+00,  5.6805e+00,  3.2277e+00,  2.5522e+00,  2.1958e+00,
          1.8695e+00,  1.7832e+00,  1.4848e+00,  1.4705e+00,  1.3568e+00,
          1.0411e+00,  1.0220e+00,  9.7170e-01,  9.4382e-01,  6.8864e-01,
          6.4694e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-6.8864e-02, -5.9721e-02, -2.5275e-02,  4.3302e-02, -3.1788e-02,
          1.0998e-01, -5.9714e-02, -1.4244e-01,  3.2493e-02,  1.2836e-01,
         -5.5845e-02,  1.2249e-01, -6.6407e-01,  3.6923e-01, -2.5517e-02,
          2.6570e-01, -1.2463e-01,  7

## pT reweight


In [None]:
def pt_reweight_jet( batch, beta=1.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where the pt of the constituents in each jet has has been re-weighted by some power
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batchc = batch.clone()

    etas = batchc[:,1,:]
    phis = batchc[:,2,:]
    batchc = batchc[:,0,:]**beta
    pts = torch.sum( batch[:,0,:], axis=1 )
    pts_aug = torch.sum( batchc, axis=1 )

    pts_aug[pts_aug == 0] = 1
    pt_rescale =  pts/pts_aug
    pTs = pt_rescale.unsqueeze(-1)* batchc
    #print(pTs)

    jet = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)
    return recentre_jet( jet.float() )

In [None]:
def pt_reweight_jet_np( batch, beta=1.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where the pt of the constituents in each jet has has been re-weighted by some power
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batchc = batch.copy()
    nj = batchc.shape[0]
    nc = batchc.shape[2]
    for i in range( nj ):
        for j in range( nc ):
            batchc[i,0,j] = batch[i,0,j]**beta
    pts = np.sum( batch[:,0,:], axis=1 )
    pts_aug = np.sum( batchc[:,0,:], axis=1 )
    pt_rescale = [ pts[i]/pts_aug[i] for i in range(nj) ]
    for i in range(nj):
        batchc[i,0,:] = batchc[i,0,:]*pt_rescale[i]
    #print(batchc[:,0,:])
    return recentre_jet_np( batchc )

t = pt_reweight_jet(x)
n = pt_reweight_jet_np(x.cpu().numpy())

print(t.cpu().numpy() - n)
print(t - torch.Tensor(n).to(device))
#print(t - x)

In [None]:
def pt_reweight_jet_ordered( batch, beta=1.5 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    Dim 1 ordering: (pT, eta, phi)
    Output: batch of jets where the pt of the constituents in each jet has has been re-weighted by some power
    Note: rescale pts so that the augmented jet pt matches the original
    '''
    batchc = batch.clone()

    etas = batchc[:,1,:]
    phis = batchc[:,2,:]
    batchc = batchc[:,0,:]**beta
    pts = torch.sum( batch[:,0,:], axis=1 )
    pts_aug = torch.sum( batchc, axis=1 )

    pts_aug[pts_aug == 0] = 1
    pt_rescale =  pts/pts_aug
    pTs = pt_rescale.unsqueeze(-1)* batchc
    #print(pTs)
    pTs, indices = torch.sort(pTs, dim=1, descending=True) # Ordering pTs
    etas = batch_dropped[:,1,:]
    etas = etas.gather(dim=1,index=indices)
    phis = batch_dropped[:,2,:]
    phis = phis.gather(dim=1,index=indices)

    jet = torch.cat((pTs.unsqueeze(1),etas.unsqueeze(1),phis.unsqueeze(1)),axis = 1)
    return recentre_jet( jet.float() )

## adding stuff


In [None]:
def add_jets(batch):
    batch_filled = batch.clone()
    n_constit = batch_filled.shape[2]
    n_nonzero = torch.sum(batch_filled[:,0,:]>0, dim=1) #number of zero constituents
    print(n_nonzero)
    n_split = torch.min(torch.stack([n_nonzero, n_constit-n_nonzero], dim=1), dim=1).values
    print(n_split)

In [None]:
y = add_jets(x)

# Physical Augmentations

## Collinear fill jets 

In [None]:
def collinear_fill_jets_fast(batch):
    '''
    Fill as many of the zero-padded entries with collinear splittings
    of the constituents by splitting each constituent at most once.
    Parameters
    ----------
    batch : torch.Tensor
        batch of jets with zero-padding
    Returns
    -------
    batch_filled : torch.Tensor
        batch of jets with collinear splittings
    '''
    batch_filled = batch.clone()
    n_constit = batch_filled.shape[2]
    n_nonzero = torch.sum(batch_filled[:,0,:]>0, dim=1)
    
    n_split = torch.min(torch.stack([n_nonzero, n_constit-n_nonzero], dim=1), dim=1).values
    idx_flip = torch.where(n_nonzero != n_split)[0]
    mask_split = (batch_filled[:,0,:] != 0)
    
    mask_split[idx_flip] = torch.flip(mask_split[idx_flip].float(), dims=[1]).bool()

    #print(mask_split)
    mask_split[idx_flip] = ~mask_split[idx_flip]
    r_split = torch.rand(size=mask_split.shape, device=batch.device)
    
    a = r_split*mask_split*batch_filled[:,0,:]
    b = (1-r_split)*mask_split*batch_filled[:,0,:]
    c = ~mask_split*batch_filled[:,0,:]
    batch_filled[:,0,:] = a+c+torch.flip(b, dims=[1])
    batch_filled[:,1,:] += torch.flip(mask_split*batch_filled[:,1,:], dims=[1])
    batch_filled[:,2,:] += torch.flip(mask_split*batch_filled[:,2,:], dims=[1])
    return batch_filled

In [None]:
def collinear_fill_jets_np( batch ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with collinear splittings, the function attempts to fill as many of the zero-padded args.nconstit
    entries with collinear splittings of the constituents by splitting each constituent at most once, same shape as input
    '''
    batchb = batch.copy()
    nc = batch.shape[2]
    nzs = np.array( [ np.where( batch[:,0,:][i]>0.0)[0].shape[0] for i in range(len(batch)) ] )
    for k in range(len(batch)):
        nzs1 = np.max( [ nzs[k], int(nc/2) ] )
        zs1 = int(nc-nzs1)
        els = np.random.choice( np.linspace(0,nzs1-1,nzs1), size=zs1, replace=False )
        rs = np.random.uniform( size=zs1 )
        for j in range(zs1):
            batchb[k,0,int(els[j])] = rs[j]*batch[k,0,int(els[j])]
            batchb[k,0,int(nzs[k]+j)] = (1-rs[j])*batch[k,0,int(els[j])]
            batchb[k,1,int(nzs[k]+j)] = batch[k,1,int(els[j])]
            batchb[k,2,int(nzs[k]+j)] = batch[k,2,int(els[j])]
    return batchb

In [None]:
t = collinear_fill_jets_fast(x)
n = collinear_fill_jets_np(x.cpu().numpy())

print(t.cpu().numpy() - n)

## Rotations 

### Only phi

In [None]:
def rotate_jets(batch):

    rot_batch = batch # is this right when it should stay on gpu?
    batch_size = batch.size(0)
    constit = batch.size(2)

    rotate_tensor = torch.rand([batch_size,constit]) * 2 * np.pi #creating the array of random rotations

    rot_batch[:,2,:] =+ np.pi # shifting the phi tensor to make use of the % function
    rot_batch[:,2,:] += rotate_tensor
    rot_batch[:,2,:] %= 2 * np.pi # getting it in the same output range
    rot_batch[:,2,:] =- np.pi # shifting back

    return rot_batch

### Bary does it all

In [None]:
import torch


def rotate_jets(batch, ra ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets rotated independently in eta-phi, same shape as input
    '''
    device = batch.device
    batch_size = batch.size(0)

    torch.manual_seed(42)
    rot_angle = ra
    #rot_angle = torch.rand(batch_size, device="cpu") * 2 * np.pi
    c = torch.cos(rot_angle)
    s = torch.sin(rot_angle)
    o = torch.ones_like(rot_angle)
    z = torch.zeros_like(rot_angle)

    #print(o.shape)

    rot_matrix = torch.stack([
    torch.stack([o, z, z], dim=0),
    torch.stack([z, c, s], dim=0),
    torch.stack([z, -s, c], dim=0)], dim=1) # (3, 3, batch_size]

    #print(rot_matrix[:,:,0])

    return torch.einsum('ijk,lji->ilk', batch, rot_matrix)


In [None]:
def rotate_jets_np( batch , ra ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets rotated independently in eta-phi, same shape as input
    '''

    rot_angle = ra

    torch.manual_seed(42)
    #rot_angle = (torch.rand(128, device="cpu") * 2 * np.pi).cpu().numpy()
    c = np.cos(rot_angle)
    s = np.sin(rot_angle)
    o = np.ones_like(rot_angle)
    z = np.zeros_like(rot_angle)
    rot_matrix = np.array([[o, z, z], [z, c, -s], [z, s, c]]) # (3, 3, batchsize)
    return np.einsum('ijk,lji->ilk', batch, rot_matrix)

In [None]:
rot_angle = torch.rand(batch_size, device="cpu") * 2 * np.pi

t = rotate_jets(x.cpu(),rot_angle).cpu().numpy()
n = rotate_jets_np(x.cpu().numpy(),rot_angle.numpy())

print((t- n))

In [None]:
batch_size = 2
rot_angle = torch.ones(batch_size, device="cpu")

c = torch.Tensor([1])#torch.cos(rot_angle)
s = torch.Tensor([2])#torch.sin(rot_angle)
o = torch.Tensor([3])#torch.ones_like(rot_angle)
z = torch.Tensor([4])#torch.zeros_like(rot_angle)

rot_matrix = torch.stack([
torch.stack([o, z, z], dim=0),
torch.stack([z, c, s], dim=0),
torch.stack([z, -s, c], dim=0)], dim=1) # (3, 3, batch_size]

rot_matrix_np = np.array([[o.numpy(), z.numpy(), z.numpy()], [z.numpy(), c.numpy(), -s.numpy()], [z.numpy(), s.numpy(), c.numpy()]]) # (3, 3, batchsize)

In [None]:
print(rot_matrix.numpy())
print(rot_matrix_np)

## Distort Jets 

In [None]:
import torch
torch.manual_seed(42)

def distort_jets(batch, strength=0.1, pT_clip_min=0.1):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with each constituents position shifted independently, shifts drawn from normal with mean 0, std strength/pT, same shape as input
    '''
    #strength = torch.Tensor(strength).to(device)
    pT = batch[:, 0]  # (batchsize, n_constit)
    shift_eta = torch.nan_to_num(
        #strength * torch.randn(batch.shape[0], batch.shape[2]).to(device) / pT.clip(min=pT_clip_min).to(device),
        strength * torch.ones(batch.shape[0], batch.shape[2]).to(device) / pT.clip(min=pT_clip_min).to(device),
        posinf=0.0,
        neginf=0.0,
    ).to(device)  # * mask
    shift_phi = torch.nan_to_num(
        #strength * torch.randn(batch.shape[0], batch.shape[2]).to(device) / pT.clip(min=pT_clip_min).to(device),
        strength * torch.ones(batch.shape[0], batch.shape[2]).to(device) / pT.clip(min=pT_clip_min).to(device),
        posinf=0.0,
        neginf=0.0,
    ).to(device)  # * mask
    shift = torch.stack([torch.zeros((batch.shape[0], batch.shape[2])).to(device), shift_eta, shift_phi], 1)
    return batch + shift


In [None]:
def distort_jets_np( batch, strength=0.1, pT_clip_min=0.1 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of jets with each constituents position shifted independently, shifts drawn from normal with mean 0, std strength/pT, same shape as input
    '''
    pT = batch[:,0]   # (batchsize, n_constit)
    #shift_eta = np.nan_to_num( strength * np.random.randn(batch.shape[0], batch.shape[2]) / pT.clip(min=pT_clip_min), posinf = 0.0, neginf = 0.0 )# * mask
    #shift_phi = np.nan_to_num( strength * np.random.randn(batch.shape[0], batch.shape[2]) / pT.clip(min=pT_clip_min), posinf = 0.0, neginf = 0.0 )# * mask

    shift_eta = np.nan_to_num(strength * np.ones((batch.shape[0], batch.shape[2]), dtype=np.float64) / pT.clip(min=pT_clip_min), posinf=0.0, neginf=0.0)
    shift_phi = np.nan_to_num(strength * np.ones((batch.shape[0], batch.shape[2]), dtype=np.float64) / pT.clip(min=pT_clip_min), posinf=0.0, neginf=0.0)

    shift = np.stack( [ np.zeros( (batch.shape[0], batch.shape[2]) ), shift_eta, shift_phi ], 1)
    return batch + shift

In [None]:
t = distort_jets(x).cpu()
n = distort_jets_np(x.cpu().numpy())

print(t.numpy() - n)

## Translate Jets 

In [None]:
import torch

def ptp(input, dim=None, keepdim=False):
    if dim is None:
        return input.max() - input.min()
    return input.max(dim, keepdim).values - input.min(dim, keepdim).values


def translate_jets(batch, width=1.0):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of eta-phi translated jets, same shape as input
    '''

    device = batch.device
    mask = (batch[:, 0] > 0)  # 1 for constituents with non-zero pT, 0 otherwise
    ptp_eta = ptp(batch[:, 1, :], dim=-1, keepdim=True)  # ptp = 'peak to peak' = max - min DOUBLE CHECKED
    ptp_phi = ptp(batch[:, 2, :], dim=-1, keepdim=True)  # ptp = 'peak to peak' = max - min
    low_eta = -width * ptp_eta
    high_eta = +width * ptp_eta
    low_phi = torch.max(-width * ptp_phi, -torch.tensor(np.pi, device=device) - torch.amin(batch[:, 2, :], dim=-1, keepdim=True))
    high_phi = torch.min(+width * ptp_phi, +torch.tensor(np.pi, device=device) - torch.amax(batch[:, 2, :], dim=-1, keepdim=True)) #DOUBLE CHECKED
    shift_eta = mask * torch.rand((batch.shape[0], 1), device=device) * (high_eta - low_eta) + low_eta
    shift_phi = mask * torch.rand((batch.shape[0], 1), device=device) * (high_phi - low_phi) + low_phi
    shift = torch.stack([torch.zeros((batch.shape[0], batch.shape[2]), device=device), shift_eta, shift_phi], dim=1)
    print(shift)
    shifted_batch = batch + shift
    return shifted_batch



In [None]:
def translate_jets_np( batch, width=1.0 ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of eta-phi translated jets, same shape as input
    '''
    
    mask = (batch[:,0] > 0) # 1 for constituents with non-zero pT, 0 otherwise
    ptp_eta  = np.ptp(batch[:,1,:], axis=-1, keepdims=True) # ptp = 'peak to peak' = max - min
    ptp_phi  = np.ptp(batch[:,2,:], axis=-1, keepdims=True) # ptp = 'peak to peak' = max - min
    low_eta  = -width*ptp_eta
    high_eta = +width*ptp_eta
    low_phi  = np.maximum(-width*ptp_phi, -np.pi-np.amin(batch[:,2,:], axis=1).reshape(ptp_phi.shape))
    high_phi = np.minimum(+width*ptp_phi, +np.pi-np.amax(batch[:,2,:], axis=1).reshape(ptp_phi.shape))
    print(high_phi)
    #shift_eta = (mask * torch.rand((batch.shape[0], 1), device=device) * (high_eta - low_eta)).cpu().numpy()
    #shift_phi = (mask * torch.rand((batch.shape[0], 1), device=device) * (high_phi - low_phi)).cpu().numpy()

    shift_eta = mask*np.random.uniform(low=low_eta, high=high_eta, size=(batch.shape[0], 1))
    shift_phi = mask*np.random.uniform(low=low_phi, high=high_phi, size=(batch.shape[0], 1))
    shift = np.stack([np.zeros((batch.shape[0], batch.shape[2])), shift_eta, shift_phi], 1)
    print(shift)
    shifted_batch = batch+shift
    return shifted_batch


As the random function can not be compared i show here that they are doing the same using the mean and the variance.

In [None]:
batch = x
ptp_eta = ptp(batch[:, 1, :], dim=-1, keepdim=True) 
batch = batch.cpu()
ptp_eta_np  = np.ptp(batch[:,1,:], axis=-1, keepdims=True) # ptp = 'peak to peak' = max - min

print(ptp_eta)
print(ptp_eta_np)

import numpy as np
import torch

# Set the seed for reproducibility
np.random.seed(0)
torch.manual_seed(0)

# Parameters for uniform distribution
low = 0.0
high = 1.0
size = (200, 3000)  # Shape of the output array

# Generate random numbers using np.random.uniform()
np_uniform = np.random.uniform(low=low, high=high, size=size)
print("NumPy random numbers:\n", np_uniform)
print(np.mean(np_uniform),np.var(np_uniform))
# Generate random numbers using torch.rand()
torch_uniform = low + (high - low) * torch.rand(size)
print("PyTorch random numbers:\n", torch_uniform)
print(torch.mean(torch_uniform),torch.var(torch_uniform))

## normalize

In [None]:
def normalise_pts(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-normalised jets, pT in each jet sums to 1, same shape as input
    '''
    batch_norm = batch.clone()
    batch_norm[:, 0, :] = torch.nan_to_num(batch_norm[:, 0, :] / torch.sum(batch_norm[:, 0, :], dim=1)[:, None], posinf=0.0, neginf=0.0)
    return batch_norm

In [None]:
def normalise_pts_np( batch ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-normalised jets, pT in each jet sums to 1, same shape as input
    '''
    batch_norm = batch.copy()
    batch_norm[:,0,:] = np.nan_to_num(batch_norm[:,0,:]/np.sum(batch_norm[:,0,:], axis=1)[:, np.newaxis], posinf = 0.0, neginf = 0.0 )
    return batch_norm

In [None]:
t = normalise_pts(x).cpu()
n = normalise_pts_np(x.cpu().numpy())

print(t.numpy() - n)

## rescale

In [None]:
def rescale_pts(batch):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input
    '''
    batch_rscl = batch.clone()
    batch_rscl[:,0,:] = torch.nan_to_num(batch_rscl[:,0,:]/600, nan=0.0, posinf=0.0, neginf=0.0)
    return batch_rscl

In [None]:
def rescale_pts_np( batch ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input
    '''
    batch_rscl = batch.copy()
    batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600, posinf = 0.0, neginf = 0.0 )
    return batch_rscl


In [None]:
t = rescale_pts(x).cpu()
n = rescale_pts_np(x.cpu().numpy())


print(t)
print(t.numpy() - n)

## Crop it 

In [None]:
def crop_jets( batch, nc ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of cropped jets, each jet is cropped to nc constituents, shape (batchsize, 3, nc)
    '''
    batch_crop = batch.clone()
    return batch_crop[:,:,0:nc]

In [None]:
def crop_jets_np( batch, nc ):
    '''
    Input: batch of jets, shape (batchsize, 3, n_constit)
    dim 1 ordering: (pT, eta, phi)
    Output: batch of cropped jets, each jet is cropped to nc constituents, shape (batchsize, 3, nc)
    '''
    batch_crop = batch.copy()
    return batch_crop[:,:,0:nc]

In [None]:
t = crop_jets(x,3).cpu()
n = crop_jets_np(x.cpu().numpy(),3)

print(t.numpy() - n)