In [1]:
!pip install memory_profiler
%load_ext memory_profiler



In [2]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import time
import glob
import argparse
sys.path.append('../')

# load torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F

# load custom modules required for jetCLR training
from scripts.modules.jet_augs import rotate_jets, distort_jets, rescale_pts, crop_jets, translate_jets, collinear_fill_jets
# from scripts.modules.transformer import Transformer
from scripts.modules.losses import contrastive_loss, align_loss, uniform_loss
from scripts.modules.perf_eval import get_perf_stats, linear_classifier_test 

In [3]:
# import standard python modules
import os
import sys
import numpy as np
from sklearn import metrics

# import torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F

# import simple FCN network
from scripts.modules.fcn_linear import fully_connected_linear_network
from scripts.modules.fcn import fully_connected_network

# import preprocessing functions
from sklearn.preprocessing import StandardScaler, MaxAbsScaler, RobustScaler

In [4]:
import os
import sys
import numpy as np
import random
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F


# class for transformer network
class Transformer(nn.Module):
    # define and intialize the structure of the neural network
    def __init__(
        self,
        input_dim,
        model_dim,
        output_dim,
        n_heads,
        dim_feedforward,
        n_layers,
        learning_rate,
        n_head_layers=2,
        head_norm=False,
        dropout=0.1,
        opt="adam",
        log=False,
    ):
        super().__init__()
        # define hyperparameters
        self.input_dim = input_dim
        self.model_dim = model_dim
        self.output_dim = output_dim
        self.n_heads = n_heads
        self.dim_feedforward = dim_feedforward
        self.n_layers = n_layers
        self.learning_rate = learning_rate
        self.n_head_layers = n_head_layers
        self.head_norm = head_norm
        self.dropout = dropout
        self.log = log
        # define subnetworks
        self.embedding = nn.Linear(input_dim, model_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                model_dim, n_heads, dim_feedforward=dim_feedforward, dropout=dropout
            ),
            n_layers,
        )
        # head_layers have output_dim
        if n_head_layers == 0:
            self.head_layers = []
        else:
            if head_norm:
                self.norm_layers = nn.ModuleList([nn.LayerNorm(model_dim)])
            self.head_layers = nn.ModuleList([nn.Linear(model_dim, output_dim)])
            for i in range(n_head_layers - 1):
                if head_norm:
                    self.norm_layers.append(nn.LayerNorm(output_dim))
                self.head_layers.append(nn.Linear(output_dim, output_dim))
        # option to use adam or sgd
        if opt == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        if opt == "sgdca" or opt == "sgdslr" or opt == "sgd":
            self.optimizer = torch.optim.SGD(
                self.parameters(), lr=self.learning_rate, momentum=0.9
            )

    def forward(
        self,
        inpt,
        mask=None,
        use_mask=False,
        use_continuous_mask=False,
        mult_reps=False,
    ):
        """
        input here is (batch_size, n_constit, 3 or 7)
        but transformer expects (n_constit, batch_size, 3 or 7) so we need to transpose
        if use_mask is True, will mask out all inputs with pT=0
        """
#         print(f"input shape: {inpt.shape}")
        assert not (use_mask and use_continuous_mask)
        pt_index = 2 if args.full_kinematics else 0
        # make a copy
        x = inpt + 0.0
        if use_mask:
            pT_zero = x[:, :, pt_index] == 0
        if use_continuous_mask:
            if self.log:
                log_pT = x[:, :, pt_index]
                # exponentiate to get actual pT
                pT = torch.where(log_pT != 0, torch.exp(log_pT), torch.zeros_like(log_pT))
            else:
                pT = x[:, :, pt_index]
#             print(f"pT: {pT}")
        if use_mask:
            mask = self.make_mask(pT_zero).to(x.device)
        elif use_continuous_mask:
            mask = self.make_continuous_mask(pT).to(x.device)
        else:
            mask = None
#         print(f"mask : {mask}")
        x = torch.transpose(x, 0, 1)
        # (n_constit, batch_size, model_dim)
        x = self.embedding(x)
#         print(f"after embedding: {x.shape}")
        x = self.transformer(x, mask=mask)
#         print(f"after transformer: {x.shape}")
        if use_mask:
            # set masked constituents to zero
            # otherwise the sum will change if the constituents with 0 pT change
            x[torch.transpose(pT_zero, 0, 1)] = 0
        elif use_continuous_mask:
            # scale x by pT, so that function is IR safe
            # transpose first to get correct shape
#             x *= torch.transpose(pT, 0, 1)[:, :, None]
            pass
        # sum over sequence dim
        # (batch_size, model_dim)
        x = x.sum(0)
#         print(f"after summing: {x.shape}")
        return self.head(x, mult_reps)

    def head(self, x, mult_reps):
        """
        calculates output of the head if it exists, i.e. if n_head_layer>0
        returns multiple representation layers if asked for by mult_reps = True
        input:  x shape=(batchsize, model_dim)
                mult_reps boolean
        output: reps shape=(batchsize, output_dim)                  for mult_reps=False
                reps shape=(batchsize, number_of_reps, output_dim)  for mult_reps=True
        """
        relu = nn.ReLU()
        if mult_reps == True:
            if self.n_head_layers > 0:
                reps = torch.empty(x.shape[0], self.n_head_layers + 1, self.output_dim)
                # Transform x to output_dim size before assignment
                x_transformed = (
                    self.head_layers[0](relu(x)) if self.n_head_layers > 0 else x
                )
                reps[:, 0] = x_transformed
                for i, layer in enumerate(self.head_layers):
                    if self.head_norm:
                        x = self.norm_layers[i](x)
                    x = relu(x)
                    x = layer(x)
                    reps[:, i + 1] = x
                return reps
            else:
                reps = x[:, None, :]
                return reps
        else:
            for i, layer in enumerate(self.head_layers):
                if self.head_norm:
                    x = self.norm_layers[i](x)
                x = relu(x)
                x = layer(x)
            return x

    def forward_batchwise(
        self, x, batch_size, use_mask=False, use_continuous_mask=False
    ):
        device = next(self.parameters()).device
        with torch.no_grad():
            if self.n_head_layers == 0:
                rep_dim = self.model_dim
                number_of_reps = 1
            elif self.n_head_layers > 0:
                rep_dim = self.output_dim
                number_of_reps = self.n_head_layers + 1
            out = torch.empty(x.size(0), number_of_reps, rep_dim)
            idx_list = torch.split(torch.arange(x.size(0)), batch_size)
            for idx in idx_list:
                output = (
                    self(
                        x[idx].to(device),
                        use_mask=use_mask,
                        use_continuous_mask=use_continuous_mask,
                        mult_reps=True,
                    )
                    .detach()
                    .cpu()
                )
                out[idx] = output
        return out

    def make_mask(self, pT_zero):
        """
        Input: batch of bools of whether pT=0, shape (batchsize, n_constit)
        Output: mask for transformer model which masks out constituents with pT=0, shape (batchsize*n_transformer_heads, n_constit, n_constit)
        mask is added to attention output before softmax: 0 means value is unchanged, -inf means it will be masked
        """
        n_constit = pT_zero.size(1)
        pT_zero = torch.repeat_interleave(pT_zero, self.n_heads, axis=0)
        pT_zero = torch.repeat_interleave(pT_zero[:, None], n_constit, axis=1)
        mask = torch.zeros(pT_zero.size(0), n_constit, n_constit)
        mask[pT_zero] = -np.inf
#         print(f"mask: {mask}")
        return mask

    def make_continuous_mask(self, pT):
        """
        Input: batch of pT values, shape (batchsize, n_constit)
        Output: mask for transformer model: -1/pT, shape (batchsize*n_transformer_heads, n_constit, n_constit)
        mask is added to attention output before softmax: 0 means value is unchanged, -inf means it will be masked
        intermediate values mean it is partly masked
        This function implements IR safety in the transformer
        """
#         print(f"pT : {pT}")
        n_constit = pT.size(1)
        pT_reshape = torch.repeat_interleave(pT, self.n_heads, axis=0)
        pT_reshape = torch.repeat_interleave(pT_reshape[:, None], n_constit, axis=1)
        # mask = -1/pT_reshape
        mask = 0.5 * torch.log(pT_reshape)
#         print(f"mask: {mask}")
        return mask


In [5]:
def load_data(dataset_path, flag, n_files=-1):
    if args.full_kinematics:
        data_files = glob.glob(f"{dataset_path}/{flag}/processed/7_features_raw/data/*")
    else:
        data_files = glob.glob(f"{dataset_path}/{flag}/processed/3_features_raw/data/*")

    data = []
    for i, file in enumerate(data_files):
        if args.full_kinematics:
            data.append(np.load(f"{dataset_path}/{flag}/processed/7_features_raw/data/data_{i}.npy")) 
        else:
            data.append(np.load(f"{dataset_path}/{flag}/processed/3_features_raw/data/data_{i}.pt")) 
        print(f"--- loaded file {i} from `{flag}` directory")
        if n_files != -1 and i == n_files - 1:
            break

    return data


def load_labels(dataset_path, flag, n_files=-1):
    data_files = glob.glob(f"{dataset_path}/{flag}/processed/7_features_raw/labels/*")

    data = []
    for i, file in enumerate(data_files):
        data.append(np.load(f"{dataset_path}/{flag}/processed/7_features_raw/labels/labels_{i}.npy"))
        print(f"--- loaded label file {i} from `{flag}` directory")
        if n_files != -1 and i == n_files - 1:
            break

    return data

In [6]:
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])

In [7]:
inpput_dim = 7
args.sbratio = 1
args.output_dim = 1000
args.model_dim = 1000 
args.n_heads = 4
args.dim_feedforward= 1000
args.n_layers= 4 
args.learning_rate = 0.00005 
args.n_head_layers = 2 
args.opt = "adam"
args.label = "test-aug-7-300"
args.load_path = f"/ssl-jet-vol-v2/JetCLR/models/experiments/{args.label}/final_model.pt"
args.trs = True
args.mask = False
args.cmask = True
args.batch_size = 128
args.trsw = 0.1
args.full_kinematics = True
args.num_files = 1

In [8]:
print( "loading data")
data = load_data("/ssl-jet-vol-v2/toptagging", "train", args.num_files)
labels = load_labels("/ssl-jet-vol-v2/toptagging", "train", args.num_files)
tr_dat_in = np.concatenate(data, axis=0)  # Concatenate along the first axis
tr_lab_in = np.concatenate(labels, axis=0)
tr_dat_in = tr_dat_in[:10000]
tr_lab_in = tr_lab_in[:10000]
print(tr_lab_in.shape)

loading data
--- loaded file 0 from `train` directory
--- loaded label file 0 from `train` directory
(10000,)


In [9]:
# input dim to the transformer -> (pt,eta,phi)
input_dim = tr_dat_in.shape[1]
print("input_dim: ", input_dim)

# creating the training dataset
print( "shuffling data and doing the S/B split", flush=True )
tr_bkg_dat = tr_dat_in[ tr_lab_in==0 ].copy()
tr_sig_dat = tr_dat_in[ tr_lab_in==1 ].copy()
nbkg_tr = int( tr_bkg_dat.shape[0] )
nsig_tr = int( args.sbratio * nbkg_tr )
list_tr_dat = list( tr_bkg_dat[ 0:nbkg_tr ] ) + list( tr_sig_dat[ 0:nsig_tr ] )
list_tr_lab = [ 0 for i in range( nbkg_tr ) ] + [ 1 for i in range( nsig_tr ) ]
ldz_tr = list( zip( list_tr_dat, list_tr_lab ) )
random.shuffle( ldz_tr )
tr_dat, tr_lab = zip( *ldz_tr )
# reducing the training data
tr_dat = np.array( tr_dat )
tr_lab = np.array( tr_lab )

# create two validation sets: 
# one for training the linear classifier test (LCT)
# and one for testing on it
# we will do this just with tr_dat_in, but shuffled and split 50/50
# this should be fine because the jetCLR training doesn't use labels
# we want the LCT to use S/B=1 all the time
list_vl_dat = list( tr_dat_in.copy() )
list_vl_lab = list( tr_lab_in.copy() )
ldz_vl = list( zip( list_vl_dat, list_vl_lab ) )
random.shuffle( ldz_vl )
vl_dat, vl_lab = zip( *ldz_vl )
vl_dat = np.array( vl_dat )
vl_lab = np.array( vl_lab )
vl_len = vl_dat.shape[0]
vl_split_len = int( vl_len/2 )
vl_dat_1 = vl_dat[ 0:vl_split_len ]
vl_lab_1 = vl_lab[ 0:vl_split_len ]
vl_dat_2 = vl_dat[ -vl_split_len: ]
vl_lab_2 = vl_lab[ -vl_split_len: ]

input_dim:  7
shuffling data and doing the S/B split


In [10]:
# set-up parameters for the LCT
linear_input_size = args.output_dim
linear_n_epochs = 750
linear_learning_rate = 0.001
linear_batch_size = 128

# initialise the network
print( "initialising the network", flush=True )
%memit net = Transformer( input_dim, args.model_dim, args.output_dim, args.n_heads, args.dim_feedforward, args.n_layers, args.learning_rate, args.n_head_layers, dropout=0.1, opt=args.opt, log=args.full_kinematics )
# send network to device
device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" )
net.to( device )
# print(net)
net.load_state_dict(torch.load(f"{args.load_path}"))

initialising the network
peak memory: 1126.80 MiB, increment: 107.85 MiB


<All keys matched successfully>

In [11]:
# import sys

# def bytes_to_gb(size_in_bytes):
#     """Convert bytes to gigabytes."""
#     return size_in_bytes / (1024**3)

# global_items = list(globals().items())  # Create a list of global items
# threshold_mb = 10  # Threshold in MB
# threshold_bytes = threshold_mb * (1024**2)  # Convert threshold to bytes

# print(f"{'Variable Name':<20} {'Size (GB)':>10}")
# print('-' * 32)

# for var_name, value in global_items:
#     size_in_bytes = sys.getsizeof(value)
#     if size_in_bytes > threshold_bytes:
#         size_in_gb = bytes_to_gb(size_in_bytes)
#         print(f"{var_name:<20} {size_in_gb:>10.6f}")

In [12]:
import gc
print( "starting the final LCT run", flush=True )
print("obtaining representations")
# evaluate the network on the testing data, applying some augmentations first if it's required
# if args.trs:
#     vl_dat_1 = translate_jets( vl_dat_1, width=args.trsw )
#     vl_dat_2 = translate_jets( vl_dat_2, width=args.trsw )
with torch.no_grad():
    net.eval()
    #vl_reps_1 = F.normalize( net.forward_batchwise( torch.Tensor( vl_dat_1 ).transpose(1,2), args.batch_size, use_mask=args.mask, use_continuous_mask=args.cmask ).detach().cpu(), dim=-1 ).numpy()
    #vl_reps_2 = F.normalize( net.forward_batchwise( torch.Tensor( vl_dat_2 ).transpose(1,2), args.batch_size, use_mask=args.mask, use_continuous_mask=args.cmask ).detach().cpu(), dim=-1 ).numpy()
    vl_reps_1 = net.forward_batchwise( torch.Tensor( vl_dat_1 ).transpose(1,2), args.batch_size, use_mask=args.mask, use_continuous_mask=args.cmask ).detach().cpu().numpy()
    vl_reps_2 = net.forward_batchwise( torch.Tensor( vl_dat_2 ).transpose(1,2), args.batch_size, use_mask=args.mask, use_continuous_mask=args.cmask ).detach().cpu().numpy()
    net.train()
#     del vl_dat_1, vl_dat_2
#     del net
#     gc.collect()
print("finished obtaining representations, starting LCT")

starting the final LCT run
obtaining representations
finished obtaining representations, starting LCT


In [13]:
del tr_dat_in, tr_bkg_dat, tr_sig_dat, tr_dat, vl_dat
gc.collect()

9

In [14]:
# final LCT for each rep layer
auc_lst, imt_lst = [], []
for run in range(3):
    for i in range(vl_reps_1.shape[1]):
#         if i == 1:
        out_dat_f, out_lbs_f, losses_f, val_losses_f = linear_classifier_test( linear_input_size, linear_batch_size, linear_n_epochs, "adam", linear_learning_rate, vl_reps_1[:,i,:], np.expand_dims(vl_lab_1, axis=1), vl_reps_2[:,i,:], np.expand_dims(vl_lab_2, axis=1) )
        auc, imtafe = get_perf_stats( out_lbs_f, out_dat_f )
        ep=0
        step_size = 25
        for (lss, val_lss) in zip(losses_f[::step_size], val_losses_f):
            print( f"(rep layer {i}) epoch: " + str( ep ) + ", loss: " + str( lss ) + ", val loss: " + str( val_lss ), flush=True)
            ep+=step_size
        print( f"run {run} (rep layer {i}) auc: "+str( round(auc, 4) ), flush=True )
        print( f"run {run} (rep layer {i}) imtafe: "+str( round(imtafe, 1) ), flush=True)
        auc_lst.append(round(auc, 4))
        imt_lst.append(round(imtafe, 1))

Device 0: NVIDIA A10
49.86111
re-initialized LCT
8.639498
(rep layer 0) epoch: 0, loss: 29.526554, val loss: 15.651224
(rep layer 0) epoch: 25, loss: 8.260195, val loss: 14.25308
(rep layer 0) epoch: 50, loss: 8.639498, val loss: 10.515268
(rep layer 0) epoch: 75, loss: 11.006305, val loss: 9.1581745
(rep layer 0) epoch: 100, loss: 4.239189, val loss: 7.8309126
(rep layer 0) epoch: 125, loss: 4.095844, val loss: 7.1283646
(rep layer 0) epoch: 150, loss: 2.9051363, val loss: 5.998016
(rep layer 0) epoch: 175, loss: 3.8697696, val loss: 6.7341166
(rep layer 0) epoch: 200, loss: 5.6250305, val loss: 6.497891
(rep layer 0) epoch: 225, loss: 2.9808502, val loss: 5.307939
(rep layer 0) epoch: 250, loss: 2.3099823, val loss: 6.3966956
(rep layer 0) epoch: 275, loss: 2.1831696, val loss: 8.391468
(rep layer 0) epoch: 300, loss: 1.563565, val loss: 7.41227
(rep layer 0) epoch: 325, loss: 1.8401449, val loss: 6.4278436
(rep layer 0) epoch: 350, loss: 1.7700554, val loss: 6.4312167
(rep layer 0) 

(rep layer 1) epoch: 50, loss: 0.9999418, val loss: 5.523946
(rep layer 1) epoch: 75, loss: 8.401289, val loss: 3.6073213
(rep layer 1) epoch: 100, loss: 0.92329884, val loss: 3.9422877
(rep layer 1) epoch: 125, loss: 2.1512856, val loss: 1.9486511
(rep layer 1) epoch: 150, loss: 1.1059268, val loss: 5.268175
(rep layer 1) epoch: 175, loss: 1.003024, val loss: 3.9997218
(rep layer 1) epoch: 200, loss: 0.6518978, val loss: 2.9528856
(rep layer 1) epoch: 225, loss: 1.2670442, val loss: 3.4426644
(rep layer 1) epoch: 250, loss: 2.1813362, val loss: 1.5459422
(rep layer 1) epoch: 275, loss: 1.0555383, val loss: 5.886836
(rep layer 1) epoch: 300, loss: 1.0326312, val loss: 1.1928298
(rep layer 1) epoch: 325, loss: 0.8550622, val loss: 1.6839293
(rep layer 1) epoch: 350, loss: 0.80870634, val loss: 3.482765
(rep layer 1) epoch: 375, loss: 0.5692836, val loss: 16.092722
(rep layer 1) epoch: 400, loss: 0.93458337, val loss: 1.3406739
(rep layer 1) epoch: 425, loss: 0.6674718, val loss: 1.01278

(rep layer 2) epoch: 125, loss: 0.2655799, val loss: 0.41207266
(rep layer 2) epoch: 150, loss: 0.2828299, val loss: 0.39262837
(rep layer 2) epoch: 175, loss: 0.26476574, val loss: 0.4065502
(rep layer 2) epoch: 200, loss: 0.26954925, val loss: 0.41225624
(rep layer 2) epoch: 225, loss: 0.26688206, val loss: 0.4406068
(rep layer 2) epoch: 250, loss: 0.26234972, val loss: 0.44570258
(rep layer 2) epoch: 275, loss: 0.25896657, val loss: 0.3983139
(rep layer 2) epoch: 300, loss: 0.27258623, val loss: 0.41636336
(rep layer 2) epoch: 325, loss: 0.2601922, val loss: 0.40453732
(rep layer 2) epoch: 350, loss: 0.26122904, val loss: 0.40331036
(rep layer 2) epoch: 375, loss: 0.26159173, val loss: 0.4393117
(rep layer 2) epoch: 400, loss: 0.26568615, val loss: 0.41451687
(rep layer 2) epoch: 425, loss: 0.25714728, val loss: 0.3981221
(rep layer 2) epoch: 450, loss: 0.2519011, val loss: 0.40180925
(rep layer 2) epoch: 475, loss: 0.25791037, val loss: 0.42515492
(rep layer 2) epoch: 500, loss: 0.

In [15]:
vl_reps_1[:10, 1, :]

array([[-152.35652  ,  -31.516916 ,  -30.420609 , ...,  -35.251858 ,
           2.250964 ,   -4.41827  ],
       [-103.33725  ,  -59.772614 ,  -26.388754 , ...,  -87.4075   ,
        -142.72273  , -135.32193  ],
       [ -70.02197  ,  -28.411247 ,  -15.907927 , ...,  -47.56984  ,
        -115.01345  , -115.87604  ],
       ...,
       [-106.91914  ,  -37.978436 ,  -23.446894 , ..., -110.92433  ,
        -109.37837  , -101.36799  ],
       [ -37.637566 ,  -66.314896 ,  -17.828943 , ..., -167.56693  ,
         -94.2389   ,  -95.51067  ],
       [ -80.782295 ,  -64.91017  ,  -51.30025  , ..., -157.12285  ,
         -66.60137  ,   -7.9691277]], dtype=float32)

In [19]:
auc_lst.sort()

In [20]:
imt_lst.sort()

In [21]:
auc_lst

[0.9001, 0.9098, 0.9159, 0.9179, 0.92, 0.9278, 0.9322, 0.9323, 0.9329]

In [22]:
imt_lst

[11.8, 15.4, 15.7, 16.4, 17.0, 18.4, 19.0, 19.1, 21.9]

In [26]:
print("\n--- hyper-parameters ---")


--- hyper-parameters ---
