In [2]:
# !pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25ldone
[?25h  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910454 sha256=fffb452be7a1221eda9109d82e4ba9885e2b67ba684edf0dfcfe60d8161caa21
  Stored in directory: /home/jovyan/.cache/pip/wheels/aa/16/a8/fd7737d723cc1eb8df023c016c262ff4520091e1b022f8c164
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.3.1


In [23]:
import argparse
import copy
import glob
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm
import yaml
from torch import nn

from src.models.transformer import Transformer
from src.models.jet_augs import *
from src.data.convert_data import convert_x

In [2]:
project_dir = "/ssl-jet-vol-v2/JetCLR_VICReg"

In [3]:
def contains_nan(tensor):
    has_nan = torch.isnan(tensor)
    return torch.any(has_nan).item()

In [4]:
class VICReg(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.num_features = int(
            args.mlp.split("-")[-1]
        )  # size of the last layer of the MLP projector
        self.x_transform = nn.Sequential(
            nn.BatchNorm1d(args.x_inputs),
            nn.Linear(args.x_inputs, args.transform_inputs),
            nn.BatchNorm1d(args.transform_inputs),
            nn.ReLU(),
        )
        self.y_transform = nn.Sequential(
            nn.BatchNorm1d(args.y_inputs),
            nn.Linear(args.y_inputs, args.transform_inputs),
            nn.BatchNorm1d(args.transform_inputs),
            nn.ReLU(),
        )
        self.augmentation = args.augmentation
        self.x_backbone = args.x_backbone
        self.y_backbone = args.y_backbone
        self.N_x = self.x_backbone.input_dim
        self.N_y = self.y_backbone.input_dim
        self.embedding = args.Do
        self.return_embedding = args.return_embedding
        self.return_representation = args.return_representation
        self.x_projector = Projector(args.mlp, self.embedding)
        self.y_projector = (
            self.x_projector if args.shared else copy.deepcopy(self.x_projector)
        )

    def forward(self, x):
        """
        x -> x_aug -> (x_xform) -> x_rep -> x_emb
        y -> y_aug -> (y_xform) -> y_rep -> y_emb
        _aug: augmented
        _xform: transformed by linear layer (skipped because it destroys the zero padding)
        _rep: backbone representation
        _emb: projected embedding
        """
        x_aug, y_aug = self.augmentation(
            self.args, x, self.args.device
        )  # [batch_size, n_constit, 3]
#         print(f"x_aug contains nan: {contains_nan(x_aug)}")
#         print(f"y_aug contains nan: {contains_nan(y_aug)}")

        # x_xform = self.x_transform.to(torch.double)(
        #     x_aug.x.double()
        # )  # [batch_size, n_constit, transform_inputs]?
        # y_xform = self.y_transform.to(torch.double)(
        #     y_aug.x.double()
        # )  # [batch_size, n_constit, transform_inputs]?

        x_rep = self.x_backbone(
            x_aug, use_mask=self.args.mask, use_continuous_mask=self.args.cmask
        )  # [batch_size, output_dim]
        y_rep = self.y_backbone(
            y_aug, use_mask=self.args.mask, use_continuous_mask=self.args.cmask
        )  # [batch_size, output_dim]
#         print(f"x_rep contains nan: {contains_nan(x_rep)}")
#         print(f"y_rep contains nan: {contains_nan(y_rep)}")
        if self.return_representation:
            return x_rep, y_rep

        x_emb = self.x_projector(x_rep)  # [batch_size, embedding_size]
        y_emb = self.y_projector(y_rep)  # [batch_size, embedding_size]
#         print(f"x_emb contains nan: {contains_nan(x_emb)}")
#         print(f"y_emb contains nan: {contains_nan(y_emb)}")
        if self.return_embedding:
            return x_emb, y_emb
        x = x_emb
        y = y_emb
        repr_loss = F.mse_loss(x, y)

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        cov_x = (x.T @ x) / (self.args.batch_size - 1)
        cov_y = (y.T @ y) / (self.args.batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(
            self.num_features
        ) + off_diagonal(cov_y).pow_(2).sum().div(self.num_features)

        loss = (
            self.args.sim_coeff * repr_loss
            + self.args.std_coeff * std_loss
            + self.args.cov_coeff * cov_loss
        )
        if args.return_all_losses:
            return loss, repr_loss, std_loss, cov_loss
        else:
            return loss

In [14]:
def Projector(mlp, embedding):
    mlp_spec = f"{embedding}-{mlp}"
    layers = []
    f = list(map(int, mlp_spec.split("-")))
    for i in range(len(f) - 2):
        layers.append(nn.Linear(f[i], f[i + 1]))
        layers.append(nn.BatchNorm1d(f[i + 1]))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(f[-2], f[-1], bias=False))
    return nn.Sequential(*layers)


def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def get_backbones(args):
    x_backbone = Transformer(input_dim=args.x_inputs)
    y_backbone = x_backbone if args.shared else copy.deepcopy(x_backbone)
    return x_backbone, y_backbone


def augmentation(args, x, device):
    """
    Applies all the augmentations specified in the args
    """
    # cropping to 50 particles is already done in data preprocessing
    # crop all jets to a fixed number of constituents (default=50)
#     x = crop_jets(x, args.nconstit)
    x = rotate_jets(x, device)
    y = x.clone()
    if args.do_rotation:
        y = rotate_jets(y, device)
    if args.do_cf:
        y = collinear_fill_jets(np.array(y.cpu()), device)
        y = collinear_fill_jets(np.array(y.cpu()), device)
    if args.do_ptd:
        y = distort_jets(y, device, strength=args.ptst, pT_clip_min=args.ptcm)
    if args.do_translation:
        y = translate_jets(y, device, width=args.trsw)
        x = translate_jets(x, device, width=args.trsw)
    x = rescale_pts(x)  # [batch_size, 3, n_constit]
    y = rescale_pts(y)  # [batch_size, 3, n_constit]
    x = x.transpose(1, 2)  # [batch_size, 3, n_constit] -> [batch_size, n_constit, 3]
    y = y.transpose(1, 2)  # [batch_size, 3, n_constit] -> [batch_size, n_constit, 3]
    return x, y


# load the datafiles
def load_data(dataset_path, flag, n_files=-1):
    data_files = glob.glob(f"{dataset_path}/{flag}/processed/3_features/*")

    data = []
    for i, file in enumerate(data_files):
        data += torch.load(f"{dataset_path}/{flag}/processed/3_features/data_{i}.pt")
        print(f"--- loaded file {i} from `{flag}` directory")
        if n_files != -1 and i == n_files - 1:
            break

    return data

In [6]:
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])

In [29]:
args.mask = False
args.cmask = True
args.epoch = 10
args.batch_size = 256
args.outdir = f"{project_dir}/models/"
args.label = "notebook_test"
args.dataset_path = "/ssl-jet-vol-v2/toptagging"
args.num_train_files = 1
args.num_val_files = 1
args.shared = False
args.mlp = "256-256-256"
args.transform_inputs = 32
args.Do = 1000
args.hidden = 128
args.sim_coeff = 25.0
args.std_coeff = 25.0
args.cov_coeff = 1.0
args.return_embedding = False
args.return_representation = False
args.do_translation = True
args.do_rotation = True
args.do_cf = True
args.do_ptd = True
args.nconstit = 50
args.ptst = 0.1
args.ptcm = 0.1
args.trsw = 1.0
args.return_all_losses = True

In [8]:
# define the global base device
world_size = torch.cuda.device_count()
if world_size:
    device = torch.device("cuda:0")
    for i in range(world_size):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
else:
    device = "cpu"
    print("Device: CPU")
args.device = device

Device 0: NVIDIA A100 80GB PCIe MIG 1g.10gb


In [15]:
n_epochs = args.epoch
batch_size = args.batch_size
outdir = args.outdir
label = args.label


model_loc = f"{outdir}/trained_models/"
model_perf_loc = f"{outdir}/model_performances/{label}"
model_dict_loc = f"{outdir}/model_dicts/"
os.system(
    f"mkdir -p {model_loc} {model_perf_loc} {model_dict_loc}"
)  # -p: create parent dirs if needed, exist_ok

# prepare data
data_train = load_data(args.dataset_path, "train", n_files=args.num_train_files)
data_valid = load_data(args.dataset_path, "val", n_files=args.num_val_files)

n_train = len(data_train)
n_val = len(data_valid)

--- loaded file 0 from `train` directory
--- loaded file 0 from `val` directory


In [22]:
len(data_train)

100001

In [20]:
train_file = torch.load(f"{args.dataset_path}/train/processed/3_features/data_0.pt")

In [21]:
train_file.shape

torch.Size([100001, 3, 50])

In [24]:
train_loader = DataLoader(data_train, batch_size)

In [25]:
for _, batch in enumerate(train_loader):
    break

In [27]:
batch.shape

torch.Size([256, 3, 50])

In [30]:
args.augmentation = augmentation

args.x_inputs = 3
args.y_inputs = 3

args.x_backbone, args.y_backbone = get_backbones(args)
model = VICReg(args).to(args.device)

train_its = int(n_train / batch_size)
val_its = int(n_val / batch_size)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
loss_val_epochs = []  # loss recorded for each epoch
repr_loss_val_epochs, std_loss_val_epochs, cov_loss_val_epochs = [], [], []
# invariance, variance, covariance loss recorded for each epoch
loss_val_batches = []  # loss recorded for each batch
loss_train_epochs = []  # loss recorded for each epoch
repr_loss_train_epochs, std_loss_train_epochs, cov_loss_train_epochs = [], [], []
# invariance, variance, covariance loss recorded for each epoch
loss_train_batches = []  # loss recorded for each batch
l_val_best = 999999
for m in range(n_epochs):
    print(f"Epoch {m}\n")
    loss_train_epoch = []  # loss recorded for each batch in this epoch
    repr_loss_train_epoch, std_loss_train_epoch, cov_loss_train_epoch = [], [], []
    # invariance, variance, covariance loss recorded for each batch in this epoch
    loss_val_epoch = []  # loss recorded for each batch in this epoch
    repr_loss_val_epoch, std_loss_val_epoch, cov_loss_val_epoch = [], [], []
    # invariance, variance, covariance loss recorded for each batch in this epoch

    train_loader = DataLoader(data_train, batch_size)
    model.train()
    pbar = tqdm.tqdm(train_loader, total=train_its)
#     for _, batch in tqdm.tqdm(enumerate(train_loader)):
    for _, batch in enumerate(pbar):
        batch = batch.to(args.device)
        optimizer.zero_grad()
        if args.return_all_losses:
            loss, repr_loss, std_loss, cov_loss = model.forward(batch)
#             print(loss, repr_loss, std_loss, cov_loss)
            repr_loss_train_epoch.append(repr_loss.detach().cpu().item())
            std_loss_train_epoch.append(std_loss.detach().cpu().item())
            cov_loss_train_epoch.append(cov_loss.detach().cpu().item())
        else:
            loss = model.forward(batch)
        loss.backward()
        optimizer.step()
        loss = loss.detach().cpu().item()
        loss_train_batches.append(loss)
        loss_train_epoch.append(loss)
        pbar.set_description(f"Training loss: {loss:.4f}")
        print(f"Training loss: {loss:.4f}")
    model.eval()
    valid_loader = DataLoader(data_valid, batch_size)
    pbar = tqdm.tqdm(valid_loader, total=val_its)
#     for _, batch in tqdm.tqdm(enumerate(valid_loader)):
    for _, batch in enumerate(pbar):
        batch = batch.to(args.device)
        batch = convert_x(batch, args.device)  # [batch_size, 3, n_constit]
        if args.return_all_losses:
            loss, repr_loss, std_loss, cov_loss = model.forward(batch)
            repr_loss_val_epoch.append(repr_loss.detach().cpu().item())
            std_loss_val_epoch.append(std_loss.detach().cpu().item())
            cov_loss_val_epoch.append(cov_loss.detach().cpu().item())
            loss = loss.detach().cpu().item()
        else:
            loss = model.forward(batch).cpu().item()
        loss_val_batches.append(loss)
        loss_val_epoch.append(loss)
        pbar.set_description(f"Validation loss: {loss:.4f}")
        print(f"Validation loss: {loss:.4f}")
    l_val = np.mean(np.array(loss_val_epoch))
    l_train = np.mean(np.array(loss_train_epoch))
    loss_val_epochs.append(l_val)
    loss_train_epochs.append(l_train)

    if args.return_all_losses:
        repr_l_val = np.mean(np.array(repr_loss_val_epoch))
        repr_l_train = np.mean(np.array(repr_loss_train_epoch))
        std_l_val = np.mean(np.array(std_loss_val_epoch))
        std_l_train = np.mean(np.array(std_loss_train_epoch))
        cov_l_val = np.mean(np.array(cov_loss_val_epoch))
        cov_l_train = np.mean(np.array(cov_loss_train_epoch))

        repr_loss_val_epochs.append(repr_l_val)
        std_loss_val_epochs.append(std_l_val)
        cov_loss_val_epochs.append(cov_l_val)

        repr_loss_train_epochs.append(repr_l_train)
        std_loss_train_epochs.append(std_l_train)
        cov_loss_train_epochs.append(cov_l_train)
    # save the model
    if l_val < l_val_best:
        print("New best model")
        l_val_best = l_val
        torch.save(model.state_dict(), f"{model_loc}/vicreg_{label}_best.pth")
    torch.save(model.state_dict(), f"{model_loc}/vicreg_{label}_last.pth")
# After training

np.save(
    f"{model_perf_loc}/vicreg_{label}_loss_train_epochs.npy",
    np.array(loss_train_epochs),
)
np.save(
    f"{model_perf_loc}/vicreg_{label}_loss_train_batches.npy",
    np.array(loss_train_batches),
)
np.save(
    f"{model_perf_loc}/vicreg_{label}_loss_val_epochs.npy",
    np.array(loss_val_epochs),
)
np.save(
    f"{model_perf_loc}/vicreg_{label}_loss_val_batches.npy",
    np.array(loss_val_batches),
)
if args.return_all_losses:
    np.save(
        f"{model_perf_loc}/vicreg_{label}_repr_loss_train_epochs.npy",
        np.array(repr_loss_train_epochs),
    )
    np.save(
        f"{model_perf_loc}/vicreg_{label}_std_loss_train_epochs.npy",
        np.array(std_loss_train_epochs),
    )
    np.save(
        f"{model_perf_loc}/vicreg_{label}_cov_loss_train_epochs.npy",
        np.array(cov_loss_train_epochs),
    )
    np.save(
        f"{model_perf_loc}/vicreg_{label}_repr_loss_val_epochs.npy",
        np.array(repr_loss_val_epochs),
    )
    np.save(
        f"{model_perf_loc}/vicreg_{label}_std_loss_val_epochs.npy",
        np.array(std_loss_val_epochs),
    )
    np.save(
        f"{model_perf_loc}/vicreg_{label}_cov_loss_val_epochs.npy",
        np.array(cov_loss_val_epochs),
    )

Epoch 0



Training loss: 22.6040:   0%|▏                                                                         | 1/390 [00:02<13:15,  2.05s/it]

Training loss: 22.6040


Training loss: 22.8887:   1%|▍                                                                         | 2/390 [00:04<13:02,  2.02s/it]

Training loss: 22.8887


Training loss: 22.6074:   1%|▌                                                                         | 3/390 [00:06<12:57,  2.01s/it]

Training loss: 22.6074


Training loss: 22.2898:   1%|▊                                                                         | 4/390 [00:08<12:47,  1.99s/it]

Training loss: 22.2898


Training loss: 22.1764:   1%|▉                                                                         | 5/390 [00:09<12:47,  1.99s/it]

Training loss: 22.1764


Training loss: 22.0609:   2%|█▏                                                                        | 6/390 [00:12<13:17,  2.08s/it]

Training loss: 22.0609


Training loss: 22.0753:   2%|█▎                                                                        | 7/390 [00:14<13:06,  2.05s/it]

Training loss: 22.0753


Training loss: 22.0753:   2%|█▎                                                                        | 7/390 [00:15<14:04,  2.20s/it]


KeyboardInterrupt: 

## Inspect augmented jet

In [None]:
for _, batch in tqdm.tqdm(enumerate(train_loader)):
    batch = batch.to(args.device)
    batch = convert_x(batch, args.device)
    break

In [None]:
batch.shape

In [None]:
x_aug, y_aug = augmentation(args, batch, args.device)

In [None]:
x_aug.shape

In [None]:
has_nan = torch.isnan(y_aug)
contains_nan = torch.any(has_nan)
print(contains_nan.item())

In [None]:
contains_nan(y_aug)

In [None]:
model.forward(batch)