In [1]:
!pip install memory_profiler
%load_ext memory_profiler



In [2]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import time
import glob
import argparse
sys.path.append('../')

# load torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F

# load custom modules required for jetCLR training
from scripts.modules.jet_augs import rotate_jets, distort_jets, rescale_pts, crop_jets, translate_jets, collinear_fill_jets
# from scripts.modules.transformer import Transformer
from scripts.modules.losses import contrastive_loss, align_loss, uniform_loss
from scripts.modules.perf_eval import get_perf_stats, linear_classifier_test 

In [3]:
def load_data(dataset_path, flag, n_files=-1):
    """Load the pre-processed constituent arrays of one dataset split.

    The feature set is selected by the notebook-level ``args.full_kinematics``:
    True reads ``7_features_raw/data/data_{i}.npy``, False reads
    ``3_features_raw/data/data_{i}.pt``.

    Files are addressed by running index (``data_0``, ``data_1``, ...), not by
    the names returned from ``glob``; the directory listing only determines how
    many files are read.

    Args:
        dataset_path: root directory of the dataset.
        flag: name of the split sub-directory (e.g. "train").
        n_files: maximum number of files to load; -1 loads every file found.

    Returns:
        A list with one numpy array per loaded file.
    """
    if args.full_kinematics:
        data_dir = f"{dataset_path}/{flag}/processed/7_features_raw/data"
        extension = "npy"
    else:
        data_dir = f"{dataset_path}/{flag}/processed/3_features_raw/data"
        extension = "pt"

    n_available = len(glob.glob(f"{data_dir}/*"))

    data = []
    for i in range(n_available):
        file_path = f"{data_dir}/data_{i}.{extension}"
        if extension == "pt":
            # `.pt` files are PyTorch-serialised; np.load cannot turn them into
            # an array, so read them with torch and convert afterwards.
            loaded = torch.load(file_path, map_location="cpu")
            if isinstance(loaded, torch.Tensor):
                loaded = loaded.detach().cpu().numpy()
        else:
            loaded = np.load(file_path)
        data.append(loaded)
        print(f"--- loaded file {i} from `{flag}` directory")
        if n_files != -1 and i == n_files - 1:
            break

    return data


def load_labels(dataset_path, flag, n_files=-1):
    """Load the label arrays of one dataset split.

    Labels are always read from ``7_features_raw/labels/labels_{i}.npy``. The
    directory listing is only used to count the files; they are then opened by
    running index.

    Args:
        dataset_path: root directory of the dataset.
        flag: name of the split sub-directory (e.g. "train").
        n_files: maximum number of files to load; -1 loads every file found.

    Returns:
        A list with one numpy array per loaded file.
    """
    labels_dir = f"{dataset_path}/{flag}/processed/7_features_raw/labels"
    n_entries = len(glob.glob(f"{labels_dir}/*"))

    loaded = []
    for idx in range(n_entries):
        loaded.append(np.load(f"{labels_dir}/labels_{idx}.npy"))
        print(f"--- loaded label file {idx} from `{flag}` directory")
        # stop once the requested number of files has been read
        reached_limit = n_files != -1 and idx == n_files - 1
        if reached_limit:
            break

    return loaded

In [4]:
# Build an empty argparse namespace so that the notebook can use the same
# `args.<option>` attribute style as a command-line script.
parser = argparse.ArgumentParser()
# args=[] makes argparse ignore the kernel's own command line
args = parser.parse_args(args=[])

In [5]:
# Run configuration, stored as attributes on the `args` namespace.
# Placeholder only: `input_dim` is recomputed from the loaded data before the
# network is built.
input_dim = 7
args.sbratio = 1  # signal count = sbratio * background count in the training set
# network sizes passed to the Transformer constructor
args.output_dim = 1000
args.model_dim = 1000
args.n_heads = 4
args.dim_feedforward = 1000
args.n_layers = 4
args.learning_rate = 0.00005
args.n_head_layers = 2
args.opt = "adam"
args.label = "test-notebook"
# args.load_path = f"/ssl-jet-vol-v2/JetCLR/models/experiments/{args.label}/model_ep20.pt"
# translation augmentation switch and width (only referenced by commented-out code below)
args.trs = True
args.trsw = 0.1
# masking mode handed to the network: binary pT == 0 mask vs. continuous pT mask
args.mask = False
args.cmask = True
args.batch_size = 128
# True selects the 7-feature files; it is also used for the pT column choice
# and passed as `log=` when the network is created
args.full_kinematics = True
args.num_files = 1  # number of data/label files to load

In [6]:
# Dataset location and the number of jets kept for this quick test.
dataset_path = "/ssl-jet-vol-v2/toptagging"
n_keep = 10

print("loading data")
data = load_data(dataset_path, "train", args.num_files)
labels = load_labels(dataset_path, "train", args.num_files)
# Stack the per-file arrays along the first axis and keep only the first jets.
tr_dat_in = np.concatenate(data, axis=0)[:n_keep]
tr_lab_in = np.concatenate(labels, axis=0)[:n_keep]
print(tr_lab_in.shape)

loading data
--- loaded file 0 from `train` directory
--- loaded label file 0 from `train` directory
(10,)


In [7]:
# input dim to the transformer = size of axis 1 of the loaded array
# (printed as 7 in the recorded output of this cell)
input_dim = tr_dat_in.shape[1]
print("input_dim: ", input_dim)

# creating the training dataset:
# all background jets (label 0) plus up to sbratio * n_background signal jets (label 1)
print( "shuffling data and doing the S/B split", flush=True )
tr_bkg_dat = tr_dat_in[ tr_lab_in==0 ].copy()
tr_sig_dat = tr_dat_in[ tr_lab_in==1 ].copy()
nbkg_tr = int( tr_bkg_dat.shape[0] )
nsig_tr = int( args.sbratio * nbkg_tr )
list_tr_dat = list( tr_bkg_dat[ 0:nbkg_tr ] ) + list( tr_sig_dat[ 0:nsig_tr ] )
list_tr_lab = [ 0 for i in range( nbkg_tr ) ] + [ 1 for i in range( nsig_tr ) ]
# NOTE: if fewer than nsig_tr signal jets are available, zip() silently drops
# the surplus trailing 1-labels, so data and labels stay aligned
ldz_tr = list( zip( list_tr_dat, list_tr_lab ) )
# shuffle data and labels together (the `random` module is not seeded in this cell)
random.shuffle( ldz_tr )
tr_dat, tr_lab = zip( *ldz_tr )
# convert the shuffled tuples back to arrays
tr_dat = np.array( tr_dat )
tr_lab = np.array( tr_lab )

# create two validation sets:
# one for training the linear classifier test (LCT)
# and one for testing on it
# we will do this just with tr_dat_in, but shuffled and split 50/50
# this should be fine because the jetCLR training doesn't use labels
# we want the LCT to use S/B=1 all the time
list_vl_dat = list( tr_dat_in.copy() )
list_vl_lab = list( tr_lab_in.copy() )
ldz_vl = list( zip( list_vl_dat, list_vl_lab ) )
random.shuffle( ldz_vl )
vl_dat, vl_lab = zip( *ldz_vl )
vl_dat = np.array( vl_dat )
vl_lab = np.array( vl_lab )
vl_len = vl_dat.shape[0]
vl_split_len = int( vl_len/2 )
# first half -> set 1, last half -> set 2; with an odd number of jets the
# middle one is unused
# NOTE: if vl_split_len is 0 the slice [-0:] selects the whole array
vl_dat_1 = vl_dat[ 0:vl_split_len ]
vl_lab_1 = vl_lab[ 0:vl_split_len ]
vl_dat_2 = vl_dat[ -vl_split_len: ]
vl_lab_2 = vl_lab[ -vl_split_len: ]

input_dim:  7
shuffling data and doing the S/B split


# Transformer

In [8]:
import os
import sys
import numpy as np
import random
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F


# class for transformer network
class Transformer(nn.Module):
    """Transformer encoder over jet constituents with an optional linear head.

    `forward` applies a per-constituent linear embedding, an
    `nn.TransformerEncoder`, a sum over the constituent axis and finally
    `n_head_layers` (ReLU -> Linear) head layers.

    Args:
        input_dim: number of features per constituent.
        model_dim: width of the embedding and of the encoder.
        output_dim: width of every head layer.
        n_heads: number of attention heads.
        dim_feedforward: hidden width of the encoder feed-forward blocks.
        n_layers: number of encoder layers.
        learning_rate: learning rate of the optimizer created here.
        n_head_layers: number of head layers; 0 disables the head.
        head_norm: if True, a LayerNorm precedes every head layer.
        dropout: dropout probability of the encoder layers.
        opt: "adam" creates an Adam optimizer; "sgd", "sgdca" or "sgdslr"
            create SGD with momentum 0.9. Any other value leaves
            `self.optimizer` undefined.
        log: if True, the pT feature is treated as log(pT) when the continuous
            mask is built.
        pt_index: position of the pT feature along the last input axis. When
            left at None, `forward` falls back to the notebook-level
            `args.full_kinematics` (index 2 if set, otherwise 0).
    """

    def __init__(
        self,
        input_dim,
        model_dim,
        output_dim,
        n_heads,
        dim_feedforward,
        n_layers,
        learning_rate,
        n_head_layers=2,
        head_norm=False,
        dropout=0.1,
        opt="adam",
        log=False,
        pt_index=None,
    ):
        super().__init__()
        # define hyperparameters
        self.input_dim = input_dim
        self.model_dim = model_dim
        self.output_dim = output_dim
        self.n_heads = n_heads
        self.dim_feedforward = dim_feedforward
        self.n_layers = n_layers
        self.learning_rate = learning_rate
        self.n_head_layers = n_head_layers
        self.head_norm = head_norm
        self.dropout = dropout
        self.log = log
        self.pt_index = pt_index
        # define subnetworks
        self.embedding = nn.Linear(input_dim, model_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                model_dim, n_heads, dim_feedforward=dim_feedforward, dropout=dropout
            ),
            n_layers,
        )
        # head_layers have output_dim
        if n_head_layers == 0:
            self.head_layers = []
        else:
            if head_norm:
                self.norm_layers = nn.ModuleList([nn.LayerNorm(model_dim)])
            self.head_layers = nn.ModuleList([nn.Linear(model_dim, output_dim)])
            for _ in range(n_head_layers - 1):
                if head_norm:
                    self.norm_layers.append(nn.LayerNorm(output_dim))
                self.head_layers.append(nn.Linear(output_dim, output_dim))
        # option to use adam or sgd
        if opt == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        if opt in ("sgdca", "sgdslr", "sgd"):
            self.optimizer = torch.optim.SGD(
                self.parameters(), lr=self.learning_rate, momentum=0.9
            )

    def _resolve_pt_index(self):
        """Return the pT column: the explicit `pt_index`, else the `args`-based default."""
        if self.pt_index is not None:
            return self.pt_index
        # backward-compatible fallback on the notebook-level configuration
        return 2 if args.full_kinematics else 0

    def forward(
        self,
        inpt,
        mask=None,
        use_mask=False,
        use_continuous_mask=False,
        mult_reps=False,
    ):
        """
        input here is (batch_size, n_constit, n_features)
        but transformer expects (n_constit, batch_size, n_features) so we need to transpose
        if use_mask is True, will mask out all inputs with pT=0
        if use_continuous_mask is True, an additive 0.5*log(pT) attention mask is used
        the two options are mutually exclusive
        the `mask` argument is ignored: the mask is always rebuilt from the flags
        returns the output of `head` (see there for shapes)
        """
        print(f"input shape: {inpt.shape}")
        assert not (use_mask and use_continuous_mask)
        pt_index = self._resolve_pt_index()
        # make a copy
        x = inpt + 0.0
        if use_mask:
            pT_zero = x[:, :, pt_index] == 0
        if use_continuous_mask:
            if self.log:
                log_pT = x[:, :, pt_index]
                # exponentiate to get actual pT; exact zeros are kept at zero
                pT = torch.where(log_pT != 0, torch.exp(log_pT), torch.zeros_like(log_pT))
            else:
                pT = x[:, :, pt_index]
            print(f"pT: {pT}")
        if use_mask:
            mask = self.make_mask(pT_zero).to(x.device)
        elif use_continuous_mask:
            mask = self.make_continuous_mask(pT).to(x.device)
        else:
            mask = None
        print(f"mask : {mask}")
        x = torch.transpose(x, 0, 1)
        # (n_constit, batch_size, model_dim)
        x = self.embedding(x)
        print(f"after embedding: {x.shape}")
        x = self.transformer(x, mask=mask)
        print(f"after transformer: {x.shape}")
        if use_mask:
            # set masked constituents to zero
            # otherwise the sum will change if the constituents with 0 pT change
            x[torch.transpose(pT_zero, 0, 1)] = 0
        elif use_continuous_mask:
            # scaling x by pT (for IR safety) is currently disabled:
            # x *= torch.transpose(pT, 0, 1)[:, :, None]
            pass
        # sum over sequence dim
        # (batch_size, model_dim)
        x = x.sum(0)
        print(f"after summing: {x.shape}")
        return self.head(x, mult_reps)

    def head(self, x, mult_reps):
        """
        calculates output of the head if it exists, i.e. if n_head_layer>0
        returns multiple representation layers if asked for by mult_reps = True
        input:  x shape=(batchsize, model_dim)
                mult_reps boolean
        output: reps shape=(batchsize, output_dim)                  for mult_reps=False
                reps shape=(batchsize, number_of_reps, output_dim)  for mult_reps=True
                (with n_head_layers == 0 the last dimension is model_dim and
                number_of_reps is 1)
        with mult_reps=True, representation 0 is the first head layer applied to
        relu(x) without the optional LayerNorm, so it equals representation 1
        whenever head_norm is False
        """
        relu = nn.ReLU()
        if mult_reps:
            if self.n_head_layers > 0:
                # Transform x to output_dim size before assignment
                first_rep = self.head_layers[0](relu(x))
                # allocate on the same device and with the same dtype as the activations
                reps = first_rep.new_empty(
                    x.shape[0], self.n_head_layers + 1, self.output_dim
                )
                reps[:, 0] = first_rep
                for i, layer in enumerate(self.head_layers):
                    if self.head_norm:
                        x = self.norm_layers[i](x)
                    x = layer(relu(x))
                    reps[:, i + 1] = x
                return reps
            return x[:, None, :]
        for i, layer in enumerate(self.head_layers):
            if self.head_norm:
                x = self.norm_layers[i](x)
            x = layer(relu(x))
        return x

    def forward_batchwise(
        self, x, batch_size, use_mask=False, use_continuous_mask=False
    ):
        """
        runs `forward` with mult_reps=True on chunks of `batch_size` rows, without gradients
        each chunk is moved to the device of the model parameters, the result is moved back to cpu
        output: tensor of shape (x.size(0), number_of_reps, rep_dim) on cpu
        """
        device = next(self.parameters()).device
        with torch.no_grad():
            if self.n_head_layers > 0:
                rep_dim = self.output_dim
                number_of_reps = self.n_head_layers + 1
            else:
                rep_dim = self.model_dim
                number_of_reps = 1
            out = torch.empty(x.size(0), number_of_reps, rep_dim)
            idx_list = torch.split(torch.arange(x.size(0)), batch_size)
            for idx in idx_list:
                output = (
                    self(
                        x[idx].to(device),
                        use_mask=use_mask,
                        use_continuous_mask=use_continuous_mask,
                        mult_reps=True,
                    )
                    .detach()
                    .cpu()
                )
                out[idx] = output
        return out

    def make_mask(self, pT_zero):
        """
        Input: batch of bools of whether pT=0, shape (batchsize, n_constit)
        Output: mask for transformer model which masks out constituents with pT=0, shape (batchsize*n_transformer_heads, n_constit, n_constit)
        mask is added to attention output before softmax: 0 means value is unchanged, -inf means it will be masked
        the mask is created on the same device as pT_zero
        """
        n_constit = pT_zero.size(1)
        pT_zero = torch.repeat_interleave(pT_zero, self.n_heads, dim=0)
        pT_zero = torch.repeat_interleave(pT_zero[:, None], n_constit, dim=1)
        # allocate next to the boolean index so that the assignment below never mixes devices
        mask = torch.zeros(
            pT_zero.size(0), n_constit, n_constit, device=pT_zero.device
        )
        mask[pT_zero] = float("-inf")
        return mask

    def make_continuous_mask(self, pT):
        """
        Input: batch of pT values, shape (batchsize, n_constit)
        Output: mask for transformer model: 0.5*log(pT), shape (batchsize*n_transformer_heads, n_constit, n_constit)
        mask is added to attention output before softmax: 0 means value is unchanged, -inf means it will be masked
        intermediate values mean it is partly masked; pT = 0 gives -inf
        This function implements IR safety in the transformer
        """
        n_constit = pT.size(1)
        pT_reshape = torch.repeat_interleave(pT, self.n_heads, dim=0)
        pT_reshape = torch.repeat_interleave(pT_reshape[:, None], n_constit, dim=1)
        # an earlier variant used -1/pT_reshape instead
        mask = 0.5 * torch.log(pT_reshape)
        return mask


In [9]:
# set-up parameters for the LCT (linear classifier test)
# (in the cells shown here they are only referenced by commented-out code)
linear_input_size = args.output_dim
linear_n_epochs = 750
linear_learning_rate = 0.001
linear_batch_size = 128

# initialise the network
print( "initialising the network", flush=True )
# %memit (memory_profiler extension) reports the memory used while constructing the model;
# `log=args.full_kinematics` makes the network treat the pT feature as log(pT)
%memit net = Transformer( input_dim, args.model_dim, args.output_dim, args.n_heads, args.dim_feedforward, args.n_layers, args.learning_rate, args.n_head_layers, dropout=0.1, opt=args.opt, log=args.full_kinematics )
# send network to device (GPU if available, otherwise CPU)
device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" )
# last expression of the cell, so the module structure is displayed as the cell output
net.to( device )
# print(net)
# loading pre-trained weights is disabled (args.load_path is commented out in the config cell)
# net.load_state_dict(torch.load(f"{args.load_path}"))

initialising the network
peak memory: 1013.61 MiB, increment: 104.00 MiB


Transformer(
  (embedding): Linear(in_features=7, out_features=1000, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1000, out_features=1000, bias=True)
        )
        (linear1): Linear(in_features=1000, out_features=1000, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1000, out_features=1000, bias=True)
        (norm1): LayerNorm((1000,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((1000,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1000, out_features=1000, bias=True)
        )
        (linear1): Linear(in_f

In [10]:
import gc

print("starting the final LCT run", flush=True)
print("obtaining representations")
# Evaluate the network on the first validation half with the masking mode
# taken from `args.mask` / `args.cmask`.
# Currently disabled here: the translation augmentation (args.trs / args.trsw),
# L2-normalisation of the representations (F.normalize), the evaluation of the
# second half `vl_dat_2`, and the del / gc.collect() clean-up.
net.eval()
with torch.no_grad():
    # swap axes 1 and 2 of the stored array before passing it to the network
    vl_inputs_1 = torch.Tensor(vl_dat_1).transpose(1, 2)
    vl_reps_1 = net.forward_batchwise(
        vl_inputs_1,
        args.batch_size,
        use_mask=args.mask,
        use_continuous_mask=args.cmask,
    )
    vl_reps_1 = vl_reps_1.detach().cpu().numpy()
# switch back to training mode once the representations are extracted
net.train()
print("finished obtaining representations, starting LCT")

starting the final LCT run
obtaining representations
input shape: torch.Size([5, 50, 7])
pT: tensor([[1.4794e+02, 5.5542e+01, 4.3797e+01, 3.4844e+01, 2.9229e+01, 2.6833e+01,
         2.2414e+01, 2.1758e+01, 1.7662e+01, 1.5363e+01, 1.4716e+01, 1.3230e+01,
         1.2635e+01, 8.8253e+00, 8.7712e+00, 8.7003e+00, 7.9881e+00, 7.3455e+00,
         6.1866e+00, 5.4705e+00, 5.0854e+00, 4.7240e+00, 3.6427e+00, 3.5194e+00,
         3.1081e+00, 3.1059e+00, 3.0073e+00, 2.9894e+00, 2.1623e+00, 2.1467e+00,
         2.0780e+00, 1.5348e+00, 1.4796e+00, 1.1885e+00, 1.1550e+00, 9.1264e-01,
         7.1755e-01, 6.6400e-01, 5.3050e-01, 4.5812e-01, 3.7212e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [7.8762e+01, 7.4373e+01, 5.7039e+01, 5.4135e+01, 4.6028e+01, 3.6450e+01,
         2.1482e+01, 1.9150e+01, 1.8814e+01, 1.4018e+01, 1.2491e+01, 1.2001e+01,
         1.0184e+01, 9.9931e+00, 8.4996e+00, 7.6117e+00, 7.1444

In [11]:
# Display the representations returned by forward_batchwise in the previous
# cell; axes are (jet, representation layer, feature).
vl_reps_1

array([[[-13.654608  ,  23.118135  , -17.19412   , ...,  35.22066   ,
          20.608694  , -44.329098  ],
        [-13.654608  ,  23.118135  , -17.19412   , ...,  35.22066   ,
          20.608694  , -44.329098  ],
        [  4.5870886 ,   1.2627772 ,  -4.2939386 , ...,   3.5898564 ,
          -2.4007485 , -15.663402  ]],

       [[  2.9299662 ,  19.328493  , -11.448227  , ...,  22.024143  ,
          12.216901  , -61.507187  ],
        [  2.9299662 ,  19.328493  , -11.448227  , ...,  22.024143  ,
          12.216901  , -61.507187  ],
        [ -3.06046   ,  -3.3615606 ,  -7.2038755 , ...,   2.6103375 ,
          -5.6897707 , -17.941868  ]],

       [[  1.8149749 ,   8.06336   , -13.503469  , ...,  26.003977  ,
           2.6337273 , -50.747673  ],
        [  1.8149749 ,   8.06336   , -13.503469  , ...,  26.003977  ,
           2.6337273 , -50.747673  ],
        [  1.25453   ,  -1.1534108 ,  -4.2633305 , ...,   3.3173125 ,
          -3.9180858 , -16.847599  ]],

       [[ -6.336482  ,

In [12]:
import gc

print("starting the final LCT run", flush=True)
print("obtaining representations")
# Repeat the evaluation of the first validation half, this time with the
# binary pT == 0 mask. NOTE: this overwrites `args.mask` / `args.cmask` for
# every cell executed afterwards.
# Currently disabled here: the translation augmentation (args.trs / args.trsw),
# L2-normalisation of the representations (F.normalize), the evaluation of the
# second half `vl_dat_2`, and the del / gc.collect() clean-up.
args.mask = True
args.cmask = False
net.eval()
with torch.no_grad():
    # swap axes 1 and 2 of the stored array before passing it to the network
    vl_inputs_1 = torch.Tensor(vl_dat_1).transpose(1, 2)
    vl_reps_1 = net.forward_batchwise(
        vl_inputs_1,
        args.batch_size,
        use_mask=args.mask,
        use_continuous_mask=args.cmask,
    )
    vl_reps_1 = vl_reps_1.detach().cpu().numpy()
# switch back to training mode once the representations are extracted
net.train()
print("finished obtaining representations, starting LCT")

starting the final LCT run
obtaining representations
input shape: torch.Size([5, 50, 7])
mask : tensor([[[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf]],

        ..

In [13]:
# Display the representations from the previous cell (binary-mask run);
# the full slice selects the whole array.
vl_reps_1[:, :, :]

array([[[-10.8137245 ,  20.04731   , -12.585735  , ...,  30.353998  ,
          21.448772  , -43.442318  ],
        [-10.8137245 ,  20.04731   , -12.585735  , ...,  30.353998  ,
          21.448772  , -43.442318  ],
        [  2.7370608 ,   2.056719  ,  -3.056351  , ...,   1.1202877 ,
          -2.482028  , -13.235533  ]],

       [[  1.5579587 ,  19.09314   , -11.285239  , ...,  23.047926  ,
          18.242231  , -67.74737   ],
        [  1.5579587 ,  19.09314   , -11.285239  , ...,  23.047926  ,
          18.242231  , -67.74737   ],
        [ -2.3121598 ,  -0.52468693,  -6.8692713 , ...,   1.7296107 ,
          -7.405306  , -16.74133   ]],

       [[  2.3459933 ,   9.238085  ,  -5.801456  , ...,  11.23956   ,
           7.154127  , -31.271805  ],
        [  2.3459933 ,   9.238085  ,  -5.801456  , ...,  11.23956   ,
           7.154127  , -31.271805  ],
        [ -1.2445474 ,  -0.57823384,  -3.5734456 , ...,   1.0001347 ,
          -4.175991  ,  -8.271915  ]],

       [[ -2.3444216 ,

In [14]:
# del tr_dat_in, tr_bkg_dat, tr_sig_dat, tr_dat, vl_dat
# gc.collect()

In [15]:
# # final LCT for each rep layer
# for run in range(3):
#     for i in range(vl_reps_1.shape[1]):
#         if i == 1:
#             out_dat_f, out_lbs_f, losses_f, val_losses_f = linear_classifier_test( linear_input_size, linear_batch_size, linear_n_epochs, "adam", linear_learning_rate, vl_reps_1[:,i,:], np.expand_dims(vl_lab_1, axis=1), vl_reps_2[:,i,:], np.expand_dims(vl_lab_2, axis=1) )
#             auc, imtafe = get_perf_stats( out_lbs_f, out_dat_f )
#             ep=0
#             step_size = 25
#             for (lss, val_lss) in zip(losses_f[::step_size], val_losses_f):
#                 print( f"(rep layer {i}) epoch: " + str( ep ) + ", loss: " + str( lss ) + ", val loss: " + str( val_lss ), flush=True)
#                 ep+=step_size
#             print( f"(rep layer {i}) auc: "+str( round(auc, 4) ), flush=True )
#             print( f"(rep layer {i}) imtafe: "+str( round(imtafe, 1) ), flush=True)