In [1]:
# Perform the Linear Classifier Test (LCT) on the representations learned by VICReg.

# load standard python modules
import argparse
from datetime import datetime
import copy
import sys
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import random
import time
import tqdm
from pathlib import Path

# load torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [2]:
# load custom modules required for jetCLR training
from src.models.jet_augs import (
    rotate_jets,
    distort_jets,
    rescale_pts,
    crop_jets,
    translate_jets,
    collinear_fill_jets,
)
from src.models.transformer import Transformer
from src.features.perf_eval import get_perf_stats, linear_classifier_test
from src.models.pretrain_vicreg import VICReg

In [3]:
!pwd

/ssl-jet-vol-v2/JetCLR_VICReg/notebooks


In [4]:
# Root directory of the JetCLR_VICReg project. The default is the original cluster path;
# set the JETCLR_PROJECT_DIR environment variable to run the notebook from another location.
project_dir = os.environ.get("JETCLR_PROJECT_DIR", "/ssl-jet-vol-v2/JetCLR_VICReg")

# load the data files and the label files from the specified directory
def load_data(dataset_path, flag, n_files=-1):
    """
    Load the data files found under `{dataset_path}/{flag}/processed/3_features/`.

    Files are read in index order (`data_0.pt`, `data_1.pt`, ...) and their contents
    are concatenated into a single Python list.

    Args:
        dataset_path: root directory of the dataset.
        flag: name of the split sub-directory (this notebook uses "train" and "test").
        n_files: number of files to load; -1 loads every `data_*.pt` file found.

    Returns:
        A list with the concatenated contents of the loaded files.
    """
    features_dir = f"{dataset_path}/{flag}/processed/3_features"
    # Count only the data files. The same directory also holds the labels_*.pt files,
    # so globbing "*" over-counts and makes the loop request data files that do not exist.
    n_available = len(glob.glob(f"{features_dir}/data_*.pt"))

    data = []
    for i in range(n_available):
        data += torch.load(f"{features_dir}/data_{i}.pt")
        print(f"--- loaded data file {i} from `{flag}` directory")
        if n_files != -1 and i == n_files - 1:
            break

    return data

def load_labels(dataset_path, flag, n_files=-1):
    """
    Load the label files found under `{dataset_path}/{flag}/processed/3_features/`.

    Files are read in index order (`labels_0.pt`, `labels_1.pt`, ...) and their contents
    are concatenated into a single Python list.

    Args:
        dataset_path: root directory of the dataset.
        flag: name of the split sub-directory (this notebook uses "train" and "test").
        n_files: number of files to load; -1 loads every `labels_*.pt` file found.

    Returns:
        A list with the concatenated contents of the loaded label files.
    """
    features_dir = f"{dataset_path}/{flag}/processed/3_features"
    # Count only the label files. The same directory also holds the data_*.pt files,
    # so globbing "*" over-counts and makes the loop request label files that do not exist.
    n_available = len(glob.glob(f"{features_dir}/labels_*.pt"))

    data = []
    for i in range(n_available):
        data += torch.load(f"{features_dir}/labels_{i}.pt")
        print(f"--- loaded label file {i} from `{flag}` directory")
        if n_files != -1 and i == n_files - 1:
            break

    return data

In [5]:
def get_backbones(args):
    """
    Create the pair of backbone networks handed to the model.

    A Transformer is built with `input_dim=args.x_inputs`. When `args.shared` is truthy
    the second backbone is the very same object; otherwise it is an independent deep copy.

    Returns:
        Tuple `(x_backbone, y_backbone)`.
    """
    x_backbone = Transformer(input_dim=args.x_inputs)
    if args.shared:
        y_backbone = x_backbone
    else:
        y_backbone = copy.deepcopy(x_backbone)
    return x_backbone, y_backbone


def augmentation(args, x, device):
    """
    Produce two views of a batch of jets using the augmentations enabled in `args`.

    The input is cropped to `args.nconstit` constituents and rotated once. The second view
    starts as a clone of that result and then receives the optional rotation, collinear
    filling, pT distortion and translation. The first view only receives the translation.
    Both views are rescaled and returned with their last two axes swapped.

    Returns:
        Tuple `(x, y)` of the two views.
    """
    # shared starting point: crop to a fixed number of constituents, then rotate
    x = rotate_jets(crop_jets(x, args.nconstit), device)
    view = x.clone()

    if args.do_rotation:
        view = rotate_jets(view, device)
    if args.do_cf:
        # collinear filling is applied two times in a row
        for _ in range(2):
            view = collinear_fill_jets(np.array(view.cpu()), device)
    if args.do_ptd:
        view = distort_jets(view, device, strength=args.ptst, pT_clip_min=args.ptcm)
    if args.do_translation:
        # the second view is translated first, then the first view
        view = translate_jets(view, device, width=args.trsw)
        x = translate_jets(x, device, width=args.trsw)

    # rescale both views: [batch_size, 3, n_constit]
    x = rescale_pts(x)
    view = rescale_pts(view)

    # [batch_size, 3, n_constit] -> [batch_size, n_constit, 3]
    return x.transpose(1, 2), view.transpose(1, 2)

In [6]:
# Build an empty namespace (no command-line parsing in the notebook) and fill it with the
# settings used by the rest of the notebook.
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])

vars(args).update(
    {
        "mask": False,
        "cmask": True,
        "epoch": 10,
        # batch size used by the DataLoaders that feed the model
        "batch_size": 256,
        "outdir": f"{project_dir}/models/",
        # tag that selects the checkpoint file and the output sub-directory
        "label": "1-0.1-1e-4-100",
        "dataset_path": "/ssl-jet-vol-v2/toptagging",
        "eval_path": f"{project_dir}/models/model_performances/",
        "load_vicreg_path": f"{project_dir}/models/trained_models/",
        # how many files load_data / load_labels read per split
        "num_train_files": 1,
        "num_test_files": 1,
        # whether both backbones are the same object (see get_backbones)
        "shared": False,
        "mlp": "256-256-256",
        "transform_inputs": 32,
        "Do": 1000,
        "hidden": 128,
        "sim_coeff": 1.0,
        "std_coeff": 0.1,
        "cov_coeff": 1e-4,
        "return_embedding": False,
        "return_representation": True,
        # switches and parameters read by augmentation()
        "do_translation": True,
        "do_rotation": True,
        "do_cf": True,
        "do_ptd": True,
        "nconstit": 50,  # number of constituents kept by crop_jets
        "ptst": 0.1,  # `strength` passed to distort_jets
        "ptcm": 0.1,  # `pT_clip_min` passed to distort_jets
        "trsw": 1.0,  # `width` passed to translate_jets
        "return_all_losses": False,
    }
)

In [14]:
# define the global base device
# `device` is always a torch.device (never a bare string), so code that reads it sees
# the same type on GPU and CPU sessions.
world_size = torch.cuda.device_count()
if world_size:
    device = torch.device("cuda:0")
    for i in range(world_size):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
else:
    device = torch.device("cpu")
    print("Device: CPU")
args.device = device

# hand the augmentation function and the backbone input size to the model through args
args.augmentation = augmentation

args.x_inputs = 3
args.y_inputs = 3

args.x_backbone, args.y_backbone = get_backbones(args)
args.return_representation = True
args.return_embedding = False
# load the desired trained VICReg model
model = VICReg(args).to(args.device)
# map_location remaps the stored tensors onto the device selected for this session, so a
# checkpoint written on a GPU can also be restored on a CPU-only machine
checkpoint_path = f"{args.load_vicreg_path}/vicreg_{args.label}_best.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))

# load the training and testing dataset
data_train = load_data(args.dataset_path, "train", n_files=args.num_train_files)
data_test = load_data(args.dataset_path, "test", n_files=args.num_test_files)
labels_train = load_labels(args.dataset_path, "train", n_files=args.num_train_files)
labels_test = load_labels(args.dataset_path, "test", n_files=args.num_test_files)


# TODO: delete when pasting back to LCT.py
# keep only the first `num_jets` jets of each split to shorten the notebook run
num_jets = 10000
data_train = data_train[:num_jets]
data_test = data_test[:num_jets]
labels_train = labels_train[:num_jets]
labels_test = labels_test[:num_jets]

# stack the per-jet tensors of each split into one tensor; labels become 1-D tensors
data_train = torch.stack(data_train)
data_test = torch.stack(data_test)
labels_train = torch.tensor([t.item() for t in labels_train])
labels_test = torch.tensor([t.item() for t in labels_test])

n_train = data_train.shape[0]
n_test = data_test.shape[0]

# every jet needs exactly one label, otherwise the classifier test is meaningless
assert labels_train.shape[0] == n_train, "train data/label count mismatch"
assert labels_test.shape[0] == n_test, "test data/label count mismatch"

batch_size = args.batch_size
# round up: the DataLoader also yields the final, partially filled batch
train_its = (n_train + batch_size - 1) // batch_size
test_its = (n_test + batch_size - 1) // batch_size

# obtain the representations from the trained VICReg model
def _extract_representations(net, loader, target_device, desc):
    """Run `net` over every batch of `loader`; return element [0] of each output, concatenated."""
    reps = []
    # len(loader) already counts the final partial batch, so the progress bar total is exact
    for batch in tqdm.tqdm(loader, total=len(loader), desc=desc):
        batch = batch.to(target_device)
        reps.append(net(batch)[0].detach().cpu().numpy())
    return np.concatenate(reps)


model.eval()
train_loader = DataLoader(data_train, args.batch_size)
test_loader = DataLoader(data_test, args.batch_size)
with torch.no_grad():
    tr_reps = _extract_representations(model, train_loader, args.device, "train")
    te_reps = _extract_representations(model, test_loader, args.device, "test")

# perform the linear classifier test (LCT) on the representations
i = 0  # index used to tag the log lines and the output file names
linear_input_size = tr_reps.shape[1]
linear_n_epochs = 750
linear_learning_rate = 0.001
linear_batch_size = 1024
out_dat_f, out_lbs_f, losses_f = linear_classifier_test(
    linear_input_size,
    linear_batch_size,
    linear_n_epochs,
    linear_learning_rate,
    tr_reps,
    labels_train,
    te_reps,
    labels_test,
)
auc, imtafe = get_perf_stats(out_lbs_f, out_dat_f)

# NOTE(review): the last recorded run of this cell logged a loss that stays near 50.8 from
# epoch 25 onwards — confirm that the linear classifier is actually learning.
step_size = 25
for ep in range(0, len(losses_f), step_size):
    print(f"(rep layer {i}) epoch: " + str(ep) + ", loss: " + str(losses_f[ep]), flush=True)
print(f"(rep layer {i}) auc: " + str(round(auc, 4)), flush=True)
print(f"(rep layer {i}) imtafe: " + str(round(imtafe, 1)), flush=True)

# np.save does not create missing directories, so make sure the output folder exists
out_dir = args.eval_path + f"{args.label}/"
os.makedirs(out_dir, exist_ok=True)
np.save(out_dir + f"linear_losses_{i}.npy", losses_f)
np.save(out_dir + f"test_linear_cl_{i}.npy", out_dat_f)
np.save(out_dir + f"test_linear_cl_labels_{i}.npy", out_lbs_f)

Device 0: NVIDIA A100 80GB PCIe MIG 1g.10gb
--- loaded data file 0 from `train` directory
--- loaded data file 0 from `test` directory
--- loaded label file 0 from `train` directory
--- loaded label file 0 from `test` directory


39: : 40it [00:28,  1.42it/s]                                                                                                          
39: : 40it [00:27,  1.44it/s]                                                                                                          


(rep layer 0) epoch: 0, loss: 46.482967
(rep layer 0) epoch: 25, loss: 50.817123
(rep layer 0) epoch: 50, loss: 50.811146
(rep layer 0) epoch: 75, loss: 50.8231
(rep layer 0) epoch: 100, loss: 50.811146
(rep layer 0) epoch: 125, loss: 50.763313
(rep layer 0) epoch: 150, loss: 50.867943
(rep layer 0) epoch: 175, loss: 50.787228
(rep layer 0) epoch: 200, loss: 50.77826
(rep layer 0) epoch: 225, loss: 50.882893
(rep layer 0) epoch: 250, loss: 50.775272
(rep layer 0) epoch: 275, loss: 50.802177
(rep layer 0) epoch: 300, loss: 50.80517
(rep layer 0) epoch: 325, loss: 50.78125
(rep layer 0) epoch: 350, loss: 50.748363
(rep layer 0) epoch: 375, loss: 50.817123
(rep layer 0) epoch: 400, loss: 50.882893
(rep layer 0) epoch: 425, loss: 50.77826
(rep layer 0) epoch: 450, loss: 50.79022
(rep layer 0) epoch: 475, loss: 50.77826
(rep layer 0) epoch: 500, loss: 50.799187
(rep layer 0) epoch: 525, loss: 50.84403
(rep layer 0) epoch: 550, loss: 50.817123
(rep layer 0) epoch: 575, loss: 50.799187
(rep l

In [8]:
# inspect the shape of the test data
data_test.shape

torch.Size([1000, 3, 50])

In [9]:
# NOTE(review): this rebinds `data_train` to the raw list returned by load_data, replacing
# the stacked tensor built earlier in the notebook — consider using a distinct name.
data_train = load_data(args.dataset_path, "train", n_files=args.num_train_files)

--- loaded data file 0 from `train` directory


In [10]:
# inspect the shape of the first entry of `data_train`
data_train[0].shape

torch.Size([3, 50])

In [11]:
# stack the entries of `data_train` along a new leading dimension
data_train_tensor = torch.stack(data_train)

In [12]:
# inspect the shape of the stacked training tensor
data_train_tensor.shape

torch.Size([100001, 3, 50])

In [13]:
# inspect the shape of the labels after converting every entry to a Python scalar
torch.tensor([t.item() for t in labels_train]).shape

torch.Size([1000])

(10000, 1000)