# Setup

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch

# Seed PyTorch's global random number generator so that tensors drawn through torch
# (e.g. torch.rand / torch.randn further down) are repeatable between runs.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x149a4d950>

In [2]:

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using mps device


# Dataset

In [3]:
from jetnet.datasets import JetNet
from jetnet.datasets.normalisations import FeaturewiseLinearBounded, FeaturewiseLinear

Issue: coffea.nanoevents.methods.vector will be removed and replaced with scikit-hep vector. Nanoevents schemas internal to coffea will be migrated. Otherwise please consider using that package!.
  from coffea.nanoevents.methods import vector


In [4]:
# Whether to keep the last entry of JetNet.ALL_PARTICLE_FEATURES (the mask) as a particle feature
MASK = True
# Number of particles requested per jet
NUM_PARTICLES = 30
# Fraction of the jets assigned to the training split; the remainder goes to the validation split
TRAIN_SPLIT = 0.7

In [5]:

# Maximum pre-scaling value of each particle feature, as stored on JetNet.fpnd_norm.
# When the mask feature is kept, a maximum of 1 is appended for it.
feature_maxes = JetNet.fpnd_norm.feature_maxes
if MASK:
    feature_maxes = feature_maxes + [1]

# Particle normalisation is currently disabled: the bounded normaliser below is kept for
# reference only and "particle_normalisation" is commented out in data_args.
# A bare `FeaturewiseLinear(feature_scales=1 / NUM_PARTICLES)` expression was removed from
# this cell: its result was never bound to a name nor passed to JetNet, so it had no effect.
# particle_normalizer = FeaturewiseLinearBounded(
#     feature_norms=1.0,
#     feature_shifts=[0.0, 0.0, -0.5, -0.5] if MASK else [0.0, 0.0, -0.5],
#     feature_maxes=feature_maxes, # Max pre-scaling values of each feature
# )

# Keyword arguments shared by every JetNet(...) split constructed below
data_args = {
    # Gluons, light quarks, and top quarks
    "jet_type": ["g", "q", "t"],
    "data_dir": "datasets/jetnet",
    "num_particles": NUM_PARTICLES,
    # Drop the trailing mask feature when MASK is False
    "particle_features": (
        JetNet.ALL_PARTICLE_FEATURES if MASK else JetNet.ALL_PARTICLE_FEATURES[:-1]
    ),
    # The order of the list is preserved in the retrieved data
    "jet_features": ["eta", "pt", "mass", "num_particles", "type"],
    # "particle_normalisation": particle_normalizer,
    # The last fraction is 0, so only the first two splits contain jets
    "split_fraction": [TRAIN_SPLIT, 1 - TRAIN_SPLIT, 0],
    "download": True
}

In [20]:
from torch.utils.data import DataLoader
# Build the two splits that receive jets under data_args["split_fraction"].
# The last fraction there is 0, so the "valid" split serves as the held-out set and is
# named X_test in this notebook.
# NOTE(review): presumably the fractions are ordered train / valid / test - confirm against the JetNet docs.
X_train = JetNet(**data_args, split="train")
X_test = JetNet(**data_args, split="valid")

`X_train` consists of 368113 jets, each of which is represented as a tuple with 2 elements:
1. A shape `30 x 4` tensor representing each of the particles in the jet, where the features are in the order `['etarel', 'phirel', 'ptrel', 'mask']`
2. A length `5` tensor consisting of the jet features `["eta", "pt", "mass", "num_particles", "type"]`

### Cleaning

We convert the particles from relative polar coordinates to absolute Cartesian coordinates. To do this, we follow the interface of the JetNet `relEtaPhiPt_to_cartesian` utility function (the relative-to-absolute step itself is re-implemented in the code below), which takes in two parameters:
1. Particle features, where the last axis is $\eta^\text{rel}, \phi^\text{rel}, p_\text{T}^\text{rel}$
2. Jet features, where the last axis is $\eta, \phi, p_\text{T}, E/c$

$E/c$ is equivalent to jet mass. Values for the azimuthal angle $\phi$ are not provided for jets in the dataset due to the azimuthal symmetry of the collider system. We therefore provide random $\phi$ values for the jets.

In [7]:
from jetnet.utils import EtaPhiPtE_to_cartesian

def transform_rel_particle_coordinates_to_cartesian(X):
    """
    Convert jet-relative particle coordinates into absolute Cartesian four-vectors.

    The relative -> absolute step (etarel, phirel, ptrel) -> (eta, phi, pt, E) is done by
    hand in this function; only the final (eta, phi, pt, E) -> Cartesian conversion is
    delegated to JetNet's EtaPhiPtE_to_cartesian utility.

    Args:
        X: dataset for which X[:] yields (particle_features, jet_features), where
            particle_features has shape (n_jets, n_particles, n_particle_features) with
            etarel, phirel, ptrel as its first three features, and jet_features has shape
            (n_jets, n_jet_features) with eta, pt, mass as its first three features.

    Returns:
        The output of EtaPhiPtE_to_cartesian for the absolute (eta, phi, pt, E) values,
        one four-vector per particle.

    Notes:
        * Jets carry no phi value, so a random phi in [0, 2*pi) is drawn per jet from the
          global torch RNG; the result therefore depends on the RNG state.
        * The particle energy is set to pt * cosh(eta), i.e. |p|, which treats every
          particle as massless. The jet mass column is read but not used.
    """
    # Index the dataset a single time: every X[:] goes through the dataset's __getitem__,
    # so repeating it would fetch the full dataset again for each use.
    dataset = X[:]
    particle_polarrel_features = dataset[0][:, :, :3]
    jet_data = dataset[1]

    # Assemble jet features as (eta, phi, pt, mass); phi is the second column
    jet_eta = jet_data[:, 0].unsqueeze(1)
    jet_phi_vals = (2 * torch.pi) * torch.rand(len(X)).unsqueeze(1)
    jet_pt_ec = jet_data[:, 1:3]
    jet_features = torch.concat([jet_eta, jet_phi_vals, jet_pt_ec], dim=-1)

    # Because of issues with the JetNet utility implementation, we do the conversion ourselves
    eta_rel, phi_rel, pt_rel = torch.unbind(particle_polarrel_features, dim=-1)
    Eta, Phi, Pt, _ = torch.unbind(jet_features, dim=-1)

    # Per-jet values are (n_jets,); unsqueeze so they broadcast over the particle axis
    pt = pt_rel * Pt.unsqueeze(1)
    eta = eta_rel + Eta.unsqueeze(1)
    phi = phi_rel + Phi.unsqueeze(1)
    # |p| = pt * cosh(eta), used as the energy component
    p0 = pt * torch.cosh(eta)

    stacked = torch.stack([eta, phi, pt, p0], dim=-1)
    return EtaPhiPtE_to_cartesian(stacked)


In [8]:
# Convert both splits to Cartesian four-vectors. The training split is converted first,
# which fixes the order in which the random jet phi values are drawn.
X_train_particle_transformed = transform_rel_particle_coordinates_to_cartesian(X_train)
X_test_particle_transformed = transform_rel_particle_coordinates_to_cartesian(X_test)

# One shape per line: train, then test
print(
    X_train_particle_transformed.shape,
    X_test_particle_transformed.shape,
    sep="\n",
)

torch.Size([368113, 30, 4])
torch.Size([157763, 30, 4])


In [9]:
# Sourced from LorentzNet

def normsq4(p):
    r"""Squared Minkowski norm over the last axis (metric signature +, -, -, -).

    `\|p\|^2 = p[0]^2 - p[1]^2 - p[2]^2 - p[3]^2`

    The last axis of `p` holds the components; all leading axes are kept.
    """
    squares = torch.square(p)
    time_sq = squares[..., 0]

    # t^2 - (x^2 + y^2 + z^2), written as 2 t^2 - (t^2 + x^2 + y^2 + z^2)
    return 2 * time_sq - squares.sum(dim=-1)
    
def dotsq4(p,q):
    r"""Minkowski inner product over the last axis (metric signature +, -, -, -).

    `<p,q> = p[0]q[0] - p[1]q[1] - p[2]q[2] - p[3]q[3]`

    `p` and `q` are multiplied elementwise, so they may broadcast against each other.
    """
    products = p * q
    time_term = products[..., 0]

    # t - (x + y + z) on the products, written as 2 t - (t + x + y + z)
    return 2 * time_term - products.sum(dim=-1)
    
def psi(p):
    r"""Signed logarithmic compression: `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`.

    Applied elementwise; the output has the same sign as the input and psi(0) = 0.
    The docstring is a raw string so that the LaTeX backslashes are not parsed as
    (invalid) string escape sequences.
    """
    # log1p(|p|) equals log(|p| + 1) but stays accurate when |p| is tiny, where
    # |p| + 1 would round to exactly 1 and the logarithm would collapse to 0.
    return torch.sign(p) * torch.log1p(torch.abs(p))

  ''' `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`


## Models

### Architecture

For testing, we design a simple message-passing Lorentz equivariant network based on LorentzNet

The inputs to each layer are the particle features $x_i$ for $i = 1 \dots N$, where $N$ is the number of particles (30). The message $m_{ij}^l$ between particles $i$ and $j$ in the $l$-th layer is
$$
\begin{aligned}
m^l_{ij} &= \phi_e\left(\psi(\|x_i^l - x_j^l\|^2),\ \psi(\langle x_i^l, x_j^l \rangle)\right)\\
m^l_{ij} &\leftarrow \phi_m(m^l_{ij})\, m^l_{ij}
\end{aligned}
$$

where $\phi_e$ and $\phi_m$ are neural networks, $\psi(\cdot) = \text{sgn}(\cdot) \log (|\cdot| + 1)$, $\| \cdot \|^2$ is the squared Minkowski norm, and $\langle \cdot, \cdot \rangle$ is the Minkowski inner product.

The updated velocity after the $l$-th step is
$$
x_i^{l+1} = x_i^l + c \sum\limits_{j=1}^N \phi_x(m^l_{ij}, t') \cdot (x_i^l - x_j^l)
$$

where $\phi_x(m^l_{ij}, t')$ is a scalar, preserving Lorentz equivariance, and $t'$ is the time embedding.

In [None]:
# Number of per-pair edge features fed to phi_e:
# the Minkowski norm term and the Minkowski inner product term
N_EDGE_FEATURES = 2
from torch import nn

def minkowski_features(x):
    """
    Build the pairwise edge features for every ordered particle pair (i, j).

    Args:
        x: tensor whose last two axes are (n_particles, n_features).

    Returns:
        A tuple (norms, dots, diffs):
        norms - psi of the squared Minkowski norm of x_i - x_j, one value per pair
        dots  - psi of the Minkowski inner product <x_i, x_j>, one value per pair
        diffs - the raw difference vectors x_i - x_j
    """
    # Insert singleton axes so that arithmetic broadcasts over all (i, j) pairs
    first = x.unsqueeze(-2)   # (..., n_particles, 1, n_features): particle i
    second = x.unsqueeze(-3)  # (..., 1, n_particles, n_features): particle j
    pairwise_diffs = first - second  # (..., n_particles, n_particles, n_features)

    compressed_norms = psi(normsq4(pairwise_diffs))
    compressed_dots = psi(dotsq4(first, second))
    return compressed_norms, compressed_dots, pairwise_diffs

class FMLorentzLayer(nn.Module):
    """
    One Lorentz-equivariant message-passing layer that outputs a per-particle velocity.

    For every particle pair (i, j) the Lorentz-invariant edge features returned by
    minkowski_features are embedded by phi_e and gated by phi_m. The gated message is
    concatenated with a learned time embedding and mapped by phi_x to a single scalar,
    which weights the difference vector x_i - x_j. Averaging over j and scaling by
    c_weight gives the velocity of particle i.

    Args:
        n_hidden: width of the hidden representations.
        dropout: accepted for interface compatibility; currently unused by this layer.
        c_weight: constant factor applied to the returned velocity.
        last_layer: accepted for interface compatibility; currently unused by this layer.
    """

    def __init__(self, n_hidden,
                 dropout=0., c_weight=1.0, last_layer=False):
        super().__init__()

        self.c_weight = c_weight

        # Scalar time -> n_hidden embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, n_hidden),
            nn.SiLU(),
            nn.Linear(n_hidden, n_hidden),
        )

        # Edge features (norm, dot) -> message embedding
        self.phi_e = nn.Sequential(
            nn.Linear(N_EDGE_FEATURES, n_hidden, bias=False),
            nn.LayerNorm(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU()
        )

        # Output projection of phi_x, initialised with a very small gain so that the
        # layer starts out producing near-zero velocities
        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.phi_x = nn.Sequential(
            #  Message + time -> Embedding
            nn.Linear(n_hidden * 2, n_hidden),
            nn.ReLU(),
            layer)

        # Sigmoid gate in (0, 1) applied to each message
        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

    def message_passing(self, norms, dots, diffs):
        """
        Turn the per-pair invariants into gated messages.

        Args:
            norms: per-pair norm feature.
            dots: per-pair inner-product feature, same shape as norms.
            diffs: per-pair difference vectors; part of the signature but not used here.

        Returns:
            phi_e(norms, dots) multiplied by its phi_m gate, with a trailing axis of
            size n_hidden.
        """
        # Stack the two invariants into a trailing feature axis of size N_EDGE_FEATURES
        inp = torch.stack([norms, dots], dim=-1)
        out = self.phi_e(inp)
        out = out * self.phi_m(out)
        return out

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Predict the velocity of every particle at time t.

        Args:
            x_t: particle four-vectors of shape (batch, n_particles, n_features).
            t: one time value per batch element; any shape holding exactly `batch`
                elements (e.g. (batch,), (batch, 1) or (batch, 1, 1)).

        Returns:
            Tensor with the same shape as x_t.
        """
        t_emb = self.time_embed(t.unsqueeze(-1))

        norms, dots, diffs = minkowski_features(x_t)
        messages = self.message_passing(norms, dots, diffs)

        batch_size, n_particles, _, n_hidden = messages.shape
        # One embedding per batch element, repeated (without copying) for every pair
        t_broadcast = t_emb.reshape(batch_size, 1, 1, -1).expand(-1, n_particles, n_particles, -1)

        # Concatenate messages with time
        messages_with_time = torch.cat([messages, t_broadcast], dim=-1)
        velocity_magnitude = self.phi_x(messages_with_time)
        # A scalar per pair times the difference vector keeps the output Lorentz equivariant
        velocity = velocity_magnitude * diffs
        # Aggregate over the partner particle j and apply the configured scale
        velocity = self.c_weight * torch.mean(velocity, dim=-2)

        return velocity

In [None]:
class LorentzFMNet(nn.Module):
    """
    Stack of FMLorentzLayer blocks used as a flow-matching velocity field.

    Each layer refines the particle coordinates with a residual update
    (x <- x + layer(x, t)), as described in the notebook's architecture section, and the
    network returns the total displacement accumulated over all layers as the velocity.
    With a single layer this is exactly that layer's output.

    Args:
        n_hidden: hidden width passed to every layer.
        n_layers: number of stacked layers.
        dropout: forwarded to every layer.
        c_weight: forwarded to every layer.
    """

    def __init__(self, n_hidden, n_layers, dropout=0., c_weight=1.0):
        super().__init__()
        self.layers = nn.ModuleList([
            FMLorentzLayer(n_hidden, dropout=dropout, c_weight=c_weight)
            for _ in range(n_layers)
        ])

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the velocity field at (x_t, t).

        Args:
            x_t: tensor of shape (batch, n_particles, n_features).
            t: one time value per batch element, in a shape accepted by the layers.

        Returns:
            Tensor shaped like x_t; all zeros when the network has no layers.
        """
        hidden = x_t
        velocity = torch.zeros_like(x_t)
        for layer in self.layers:
            # Feed each layer the coordinates produced by the previous one, so that every
            # layer contributes to the output and receives gradients.
            delta = layer(hidden, t)
            hidden = hidden + delta
            velocity = velocity + delta
        return velocity

    def step(self, x_t: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor) -> torch.Tensor:
        """
        Advance x_t from t_start to t_end with one explicit midpoint step.

        Args:
            x_t: tensor of shape (batch, n_particles, n_features).
            t_start: single-element tensor holding the start time.
            t_end: end time, a single-element tensor or a Python number.

        Returns:
            x_t + (t_end - t_start) * v(x_mid, t_mid), shaped like x_t.
        """
        batch_size = x_t.shape[0]

        # Keep the step size as (1, 1, 1) so that it broadcasts against
        # (batch, n_particles, n_features) for any batch size.
        dt = (t_end - t_start).reshape(1, 1, 1)
        half_dt = dt / 2

        # One time value per batch element for the velocity network
        t_batch = t_start.reshape(1).expand(batch_size)

        # Euler half step to the midpoint, then a full step with the midpoint velocity
        start_vel = self.forward(x_t=x_t, t=t_batch)
        midpoint_x = x_t + start_vel * half_dt
        midpoint_vel = self.forward(x_t=midpoint_x, t=t_batch + half_dt.reshape(1))

        return x_t + dt * midpoint_vel

In [12]:
# Build a 12-layer network with hidden width 64 and move its parameters to the selected device.
# NOTE(review): FMLorentzLayer.__init__ accepts `dropout` but never uses it, so dropout=0.1
# has no effect on the constructed model (the printed architecture contains no Dropout module).
model = LorentzFMNet(n_hidden=64, n_layers=12, dropout=0.1, c_weight=1.0).to(device)
print(model)

LorentzFMNet(
  (layers): ModuleList(
    (0-11): 12 x FMLorentzLayer(
      (time_embed): Sequential(
        (0): Linear(in_features=1, out_features=64, bias=True)
        (1): SiLU()
        (2): Linear(in_features=64, out_features=64, bias=True)
      )
      (phi_e): Sequential(
        (0): Linear(in_features=2, out_features=64, bias=False)
        (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=64, out_features=64, bias=True)
        (4): ReLU()
      )
      (phi_x): Sequential(
        (0): Linear(in_features=128, out_features=64, bias=True)
        (1): ReLU()
        (2): Linear(in_features=64, out_features=1, bias=False)
      )
      (phi_m): Sequential(
        (0): Linear(in_features=64, out_features=1, bias=True)
        (1): Sigmoid()
      )
    )
  )
)


### Training

In [13]:
# Smoke test: push the first two training jets through the untrained model with one time
# value per jet, and display the shape of the predicted velocity.
time = torch.tensor([0.1, 0.5], device=device)  # Example time input
x = X_train_particle_transformed[:2].to(device)  # Move to the same device as the model
output = model(x, time)
output.shape

torch.Size([2, 30, 4])

In [49]:

# Number of jets per mini-batch
BATCH_SIZE = 128
# Batch the transformed four-vector tensors; only the training loader is shuffled.
# NOTE(review): pin_memory is aimed at CUDA host-to-device copies - confirm it has any
# effect (or does not just warn) on the device selected earlier in the notebook.
X_train_loaded = DataLoader(X_train_particle_transformed, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
X_test_loaded = DataLoader(X_test_particle_transformed, shuffle=False, batch_size=BATCH_SIZE, pin_memory=True)

In [50]:
# Flow-matching training: regress the model's velocity at a random point on the straight
# line between a noise sample x_0 and a data sample x_1 onto the constant target x_1 - x_0.
epochs = 100
losses = []  # mean training loss of each completed epoch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Build the loss module once rather than on every batch
criterion = nn.MSELoss()

for epoch in range(epochs):

    epoch_loss = []

    for data in X_train_loaded:
        # x_0: standard-normal noise shaped like the batch; x_1: the data batch itself
        x_0 = torch.randn_like(data).to(device)
        x_1 = data.to(device)

        # One time in [0, 1) per jet, shaped (batch, 1, 1) to broadcast over particles and features
        t = torch.rand(x_0.shape[0], device=device).view(-1, 1, 1)
        x_t = (1 - t) * x_0 + t * x_1  # Linear interpolation
        dx_t = x_1 - x_0  # Target velocity along the straight path
        optimizer.zero_grad()

        loss = criterion(model(x_t, t), dx_t)
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())

    losses.append(np.mean(epoch_loss))
    # Report every 10th epoch (1-based in the message)
    if epoch % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {losses[-1]:.4f}")

KeyboardInterrupt: 