https://arxiv.org/abs/2310.16121

https://github.com/abogatskiy/PELICAN-nano

In [None]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
def load(path="/home/nikolai/data/TopTagingML/train.h5", stop=100000, max_particles=200):
    """Read jets from an HDF5 table and strip their zero padding.

    Args:
        path: HDF5 file readable by ``pd.read_hdf`` under the ``"table"`` key.
        stop: number of rows to read (``None`` reads the whole table).
        max_particles: particle slots per jet; the first ``4 * max_particles``
            columns are reshaped into ``(max_particles, 4)`` per row.
            NOTE(review): the 4 features are assumed to be (E, px, py, pz);
            the notebook treats feature 0 as the time-like component.

    Returns:
        ``(arrays, y)``: a list with one ``(n_i, 4)`` array per jet, with
        every all-zero particle slot removed, and the ``is_signal_new``
        column as a numpy array.
    """
    n_features = 4
    df = pd.read_hdf(path, key="table", stop=stop)
    particles = (
        df.iloc[:, : max_particles * n_features]
        .to_numpy()
        .reshape(-1, max_particles, n_features)
    )
    # Variable-length (list of arrays) representation without the padding.
    arrays = [jet[~(jet == 0).all(axis=-1)] for jet in particles]
    y = df.is_signal_new.to_numpy()
    return arrays, y

In [None]:
# Per-jet particle arrays (padding stripped) and their labels.
arrays, y = load()

In [None]:
len(arrays)

In [None]:
# Inspect one jet; its array has shape (n_particles, 4).
p4 = arrays[0]
p4.shape

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each particle array with its label.

    Items come back as ``(array, label)`` tuples exactly as stored; no
    padding or conversion happens here.
    """

    def __init__(self, arrays, y):
        self.arrays, self.y = arrays, y

    def __getitem__(self, i):
        array, label = self.arrays[i], self.y[i]
        return array, label

    def __len__(self):
        return len(self.arrays)

In [None]:
# One item per jet: (particle array, label).
ds = Dataset(arrays, y)

In [None]:
def collate_fn(batch):
    """Zero-pad a list of ``(particles, label)`` pairs into batch tensors.

    Args:
        batch: sequence of ``(x, y)`` pairs, where ``x`` is a numpy array with
            rows of 4 features and ``y`` is a scalar label.

    Returns:
        dict with ``"x"``: float32 ``(B, L, 4)`` particles zero-padded to the
        longest jet ``L``; ``"y"``: float32 ``(B,)`` labels; ``"mask"``: bool
        ``(B, L)`` that is True at padded slots.
    """
    xs = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    width = max(map(len, xs))
    padded = torch.zeros((len(batch), width, 4), dtype=torch.float32)
    is_pad = torch.zeros((len(batch), width), dtype=bool)
    for row, particles in enumerate(xs):
        n = len(particles)
        padded[row, :n] = torch.from_numpy(particles)
        is_pad[row, n:] = True
    return {
        "x": padded,
        "y": torch.tensor(labels, dtype=torch.float32),
        "mask": is_pad,
    }

In [None]:
# Shuffled loader over the whole dataset, used here only to grab one example batch.
dl = DataLoader(ds, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [None]:
batch = next(iter(dl))

In [None]:
batch["x"].shape, batch["y"].shape, batch["mask"].shape

In [None]:
bx = batch["x"]

In [None]:
# Pairwise product over the particle axis: prod[b, i, j, mu] = bx[b, i, mu] * bx[b, j, mu].
prod = bx[:, :, np.newaxis] * bx[:, np.newaxis, :]
prod.shape

In [None]:
# Minkowski inner product p_i . p_j with metric (+, -, -, -), component 0 being
# time-like. The matrix is symmetric in (i, j); padded (all-zero) particles give 0.
dot2 = prod[..., 0] - prod[..., 1:].sum(axis=-1)
dot2.shape

In [None]:
len(dot2.ravel())

In [None]:
# log(1 + d) maps d = 0 (padded pairs) to 0.
# NOTE(review): non-finite wherever dot2 <= -1 (e.g. from float32 round-off on
# the diagonal) — confirm this cannot occur for this data.
logdot2 = torch.log(1 + dot2)

In [None]:
# Distribution of log(1 + d_ij) over all pairs, padded pairs included (they sit at 0).
plt.hist(logdot2.ravel().numpy(), bins=100);
#plt.yscale("log")
#plt.yscale("log")

![image.png](attachment:88b26230-e1b0-4dbe-9648-c8b56634f4ea.png)

Figure above from the [ML4Jets2023 PELICAN slides](https://indico.cern.ch/event/1253794/contributions/5588625/attachments/2748386/4783069/ML4Jets23_PELICAN.pdf)

In [None]:
logdot2.shape

In [None]:
# Mean particle multiplicity per jet; motivates the scale factor sf = 1/50 below.
np.mean([len(x) for x in arrays])

In [None]:
# One factor of sf per summed particle index keeps aggregated sums O(1).
sf = 1 / 50

In [None]:
# Sum over axis 1 (column sums, shape (B, 1, N)); logdot2 is symmetric, so
# these equal the row sums.
rowsum = logdot2.sum(axis=1, keepdim=True) * sf
rowsum.shape

In [None]:
# Total sum over both particle axes, shape (B, 1, 1), scaled by sf**2.
totsum = logdot2.sum(axis=(1, 2), keepdim=True) * sf**2
totsum.shape

In [None]:
# Padded particle count of this batch.
N = batch["x"].shape[1]
N

In [None]:
rowsum

In [None]:
totsum.expand(-1, N, -1).squeeze(-1)

In [None]:
# The six 2->2 aggregations, each of shape (B, N, N):
agg0 = logdot2                                   # identity
agg1 = rowsum.expand(-1, N, -1)                  # [b, i, j] = sum for index j
agg2 = agg1.transpose(-1, -2)                    # [b, i, j] = sum for index i
agg3 = torch.diag_embed(rowsum.squeeze(1))       # sums on the diagonal
agg4 = totsum.expand(-1, N, N)                   # total sum everywhere
agg5 = torch.diag_embed(totsum.expand(-1, N, -1).squeeze(-1))  # total sum on the diagonal

In [None]:
agg3.shape

In [None]:
def six_aggs(x, sf=1/50):
    rowsum = x.sum(axis=1, keepdim=True) * sf
    totsum = x.sum(axis=(1, 2), keepdim=True) * sf**2
    N = x.shape[1]
    aggs = []
    aggs.append(x)
    aggs.append(rowsum.expand(-1, N, -1))
    aggs.append(aggs[0].transpose(-1, -2))
    aggs.append(torch.diag_embed(rowsum.squeeze(1)))
    aggs.append(totsum.expand(-1, N, N))
    aggs.append(torch.diag_embed(totsum.expand(-1, N, -1).squeeze(-1)))
    return torch.stack(aggs, -1)

In [None]:
# All six aggregations stacked on a trailing axis: (B, N, N, 6).
aggs = six_aggs(logdot2)
aggs.shape

In [None]:
# Pair mask: True where particle i or particle j is a padded slot.
sq_mask = batch["mask"][:, :, np.newaxis] | batch["mask"][:, np.newaxis, :]
sq_mask[0]

In [None]:
sq_mask.shape

![image.png](attachment:fde72581-6064-4115-a82c-ac8a3409d60f.png)

In [None]:
c = nn.Linear(6, 2)(aggs).relu().masked_fill(sq_mask[..., np.newaxis], 0)
c.shape

In [None]:
ft = torch.cat([(c.sum((1, 2)) * sf**2), c.diagonal(dim1=1, dim2=2).sum(-1)], -1)
ft.shape

In [None]:
nn.Linear(4, 1)(ft).sigmoid()

In [None]:
class NanoPelican(nn.Module):
    """Tiny Lorentz-invariant jet classifier in the style of PELICAN-nano.

    Pipeline: pairwise Minkowski dot products -> ``log1p`` -> six 2->2
    aggregations -> ``Linear(6, 2)`` + ReLU (padded pairs zeroed) -> per
    channel, the total sum scaled by ``sf**2`` and the trace scaled by ``sf``
    -> ``Linear(4, 1)``.

    ``forward`` returns raw logits; apply ``sigmoid`` to get probabilities.
    ``fit`` trains with ``binary_cross_entropy_with_logits`` and ``evaluate``
    applies ``sigmoid`` itself. Returning probabilities here would therefore
    squash the output twice.

    Args:
        sf: scale factor per summed particle index, used both inside the
            aggregations and in the final pooling.
    """

    def __init__(self, sf=1/50):
        super().__init__()
        self.linear1 = nn.Linear(6, 2)
        self.linear2 = nn.Linear(4, 1)
        self.sf = sf

    def forward(self, x, mask):
        """Score a padded batch of jets.

        Args:
            x: four-momenta of shape ``(batch, N, 4)``; component 0 is treated
                as time-like by the metric below.
            mask: bool ``(batch, N)``, True at padded particle slots.

        Returns:
            Logits of shape ``(batch, 1)``.
        """
        # A pair (i, j) is masked if either particle is padding.
        pair_mask = (mask[:, :, np.newaxis] | mask[:, np.newaxis, :])[..., np.newaxis]
        prod = x[:, :, np.newaxis] * x[:, np.newaxis, :]
        # Minkowski inner product p_i . p_j with metric (+, -, -, -).
        dots = prod[..., 0] - prod[..., 1:].sum(dim=-1)
        # log(1 + d); padded pairs (d = 0) stay 0 and so do not affect the sums.
        h = torch.log1p(dots)
        h = six_aggs(h, sf=self.sf)
        h = self.linear1(h).relu().masked_fill(pair_mask, 0)
        # Use the instance's sf (not a module-level global) for the 2->0 pooling.
        pooled = torch.cat(
            [
                h.sum((1, 2)) * self.sf**2,
                h.diagonal(dim1=1, dim2=2).sum(-1) * self.sf,
            ],
            -1,
        )
        return self.linear2(pooled)

In [None]:
# Fresh model; the next cell runs it on the example batch as a smoke test.
model = NanoPelican()

In [None]:
model(batch["x"], batch["mask"])

In [None]:
# Trainable parameters: Linear(6, 2) has 6*2 + 2 = 14 and Linear(4, 1) has
# 4*1 + 1 = 5, so this evaluates to 19.
sum(par.numel() for par in model.parameters())

In [None]:
# 90/10 train/validation split of the jets.
# NOTE(review): unseeded, so the split changes between runs; pass
# generator=torch.Generator().manual_seed(...) if reproducibility matters.
ds_train, ds_val = torch.utils.data.random_split(ds, [0.9, 0.1])
kwargs = dict(batch_size=32, collate_fn=collate_fn)
# Only the training loader shuffles; the validation order is fixed.
dl_train = DataLoader(ds_train, shuffle=True, **kwargs)
dl_val = DataLoader(ds_val, **kwargs)

In [None]:
# Adam with PyTorch's default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
from tqdm.auto import tqdm

In [None]:
# Per-epoch {"loss", "val_loss"} records; fit() appends to this list, so it
# keeps growing across repeated fit() calls.
history = []

In [None]:
def fit(model, optimizer, dl_train, dl_val, history, epochs=1):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: module called as ``model(x, mask)`` that returns one logit per
            example with shape ``(batch, 1)``.
        optimizer: optimizer over ``model``'s parameters.
        dl_train: loader of training batches, dicts with ``"x"``, ``"mask"``
            and float ``"y"`` targets (as built by ``collate_fn``).
        dl_val: loader of validation batches in the same format.
        history: list that receives one ``{"loss", "val_loss"}`` dict (mean
            batch losses) per epoch; it is mutated in place.
        epochs: number of passes over ``dl_train``.

    Returns:
        Per-batch training losses of the final epoch (empty if ``epochs`` is 0).
    """
    loss_fn = F.binary_cross_entropy_with_logits

    def forward(batch):
        # squeeze(-1) rather than squeeze(): a final batch of one example must
        # keep shape (1,) to match batch["y"]; a bare squeeze() makes it 0-d
        # and the loss raises on the size mismatch.
        return model(batch["x"], batch["mask"]).squeeze(-1)

    def train_step(batch):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(forward(batch), batch["y"])
        loss.backward()
        optimizer.step()
        return loss.item()

    def val_step(batch):
        model.eval()
        with torch.no_grad():
            return loss_fn(forward(batch), batch["y"]).item()

    losses = []
    for _ in range(epochs):
        losses = [train_step(batch) for batch in tqdm(dl_train)]
        val_losses = [val_step(batch) for batch in dl_val]
        history.append({"loss": np.mean(losses), "val_loss": np.mean(val_losses)})
        print(history[-1])

    return losses

In [None]:
# Train for five epochs; fit() returns the last epoch's per-batch training
# losses, which become this cell's displayed output.
fit(model, optimizer, dl_train, dl_val, history, epochs=5)

In [None]:
# Training vs. validation loss, one row per recorded epoch.
pd.DataFrame(history).plot()

In [None]:
def evaluate(model, dl):
    """Return ``sigmoid`` of the model output for every batch in ``dl``, concatenated.

    The model is put in eval mode and run without gradient tracking. ``dl``
    must yield dicts with ``"x"`` and ``"mask"`` (as built by ``collate_fn``).
    The model output's size-1 axis 1 is dropped, giving one score per example.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for batch in tqdm(dl):
            out = model(batch["x"], mask=batch["mask"])
            scores.append(out.sigmoid().squeeze(1))
    return torch.cat(scores)

In [None]:
y_pred = evaluate(model, DataLoader(ds, batch_size=32, collate_fn=collate_fn))

In [None]:
from sklearn.metrics import roc_curve, auc

In [None]:
fpr, tpr, thr = roc_curve(y, y_pred.numpy())
auc(fpr, tpr)

In [None]:
plt.plot(tpr, 1 / fpr)
plt.yscale("log")
(1 / fpr)[tpr>0.3].max()