# The Particle Transformer

7/1/23

Summary notebook for the Particle Transformer, including all necessary functions and classes for running.

Importing all necessary libraries for implementation and training.

In [None]:
# Set the working directory: replace the placeholder 'data location' with the real path.
# Relative paths used later (TopTagging/, Checkpoints/, Weights/, Figures/) are resolved from here.
%cd 'data location'/data/

In [None]:
import os
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sklearn.metrics as metrics
from sklearn.model_selection import train_test_split
from tensorflow import keras
import tensorflow as tf
import tensorflow_addons as tfa
from functools import partial
import h5py
import matplotlib.pyplot as plt

import vector
import collections
import math
import string
from kerasPMHA import KerasPMHA


Defining a function to take a training history input and return the training curves (accuracy and loss).

In [None]:
def histplot(history):
    """Plot accuracy and loss training curves from a Keras ``History`` object.

    Left panel: training/validation accuracy, with a dotted line at the best
    validation accuracy. Right panel: training/validation loss, with a dotted
    line at the lowest validation loss.

    Args:
        history: object whose ``.history`` dict holds per-epoch values under
            "Accuracy", "val_Accuracy", "loss" and "val_loss".

    Returns:
        A plotly ``Figure`` with two subplots.
    """
    log = pd.DataFrame(history.history)
    epoch_axis = log.index.to_numpy() + 1
    last_epoch = np.max(epoch_axis)
    fig = make_subplots(rows=1, cols=2, subplot_titles=('Accuracy', 'Loss'))

    # (subplot column, train key, val key, train label, val label, best-value reducer)
    panels = (
        (1, "Accuracy", "val_Accuracy", "Accuracy", "Val accuracy", np.max),
        (2, "loss", "val_loss", "Loss", "Val loss", np.min),
    )
    for col, train_key, val_key, train_label, val_label, best in panels:
        fig.add_trace(go.Scatter(mode="markers+lines", x=epoch_axis, y=log[train_key], name=train_label), row=1, col=col)
        fig.add_trace(go.Scatter(mode="markers+lines", x=epoch_axis, y=log[val_key], name=val_label), row=1, col=col)
        best_val = best(log[val_key])
        # Horizontal reference line at the best validation value.
        fig.add_shape(type='line',
                      x0=0, y0=best_val, x1=last_epoch, y1=best_val,
                      line=dict(color='Green', dash="dot"),
                      xref='x', yref='y', row=1, col=col)

    for axis_key, title in (('xaxis', 'Epoch'), ('xaxis2', 'Epoch'), ('yaxis', 'Accuracy'), ('yaxis2', 'Loss')):
        fig['layout'][axis_key]['title'] = title

    return fig

## Data importing and preprocessing

Defining the functions to calculate the pairwise kinematics. A modulo function is also defined, as training runs into errors when using the default implementation.

In [None]:
def tensor_mod(a, n):
    """Floored modulo ``a mod n`` built from elementary float32 ops.

    Written out as a - n * floor(a / n) because the built-in modulo caused
    errors during training. The result takes the sign of ``n``.
    """
    dividend = tf.cast(a, dtype=tf.float32)
    divisor = tf.cast(n, dtype=tf.float32)
    quotient = tf.math.floor(dividend / divisor)
    return dividend - tf.multiply(divisor, quotient)

def delta_phi(a, b):
    """Azimuthal difference ``a - b`` wrapped into [-pi, pi) (returned as float32)."""
    shifted = tf.subtract(a, b) + np.pi
    wrapped = tensor_mod(shifted, tf.constant(2 * np.pi))
    return wrapped - np.pi


def delta_r2(eta1, phi1, eta2, phi2):
    """Squared angular distance dR^2 = (eta1 - eta2)^2 + wrapped(phi1 - phi2)^2."""
    d_eta = eta1 - eta2
    d_phi = delta_phi(phi1, phi2)
    return tf.square(d_eta) + tf.square(d_phi)

def to_pt2(x, eps=1e-8):
    """Squared transverse momentum px^2 + py^2 for rank-2 input.

    Args:
        x: tensor indexed as (N, features) whose first two features are px, py.
        eps: lower clip bound for pt^2; ``None`` disables clipping.

    Returns:
        Tensor of shape (N, 1).
    """
    pt2 = tf.reduce_sum(tf.square(x[:, :2]), axis=1, keepdims=True)
    if eps is None:
        return pt2
    return tf.clip_by_value(pt2, clip_value_min=eps, clip_value_max=10e32)

def atan2(y, x):
    """Quadrant-aware arctangent built only from sign/atan (used when ``for_onnx`` is set)."""
    sign_x = tf.math.sign(x)
    sign_y = tf.math.sign(y)
    x_nonzero = sign_x ** 2  # 1 where x != 0, 0 where x == 0
    # +/-pi offset for x < 0, and +/-pi/2 on the y axis where x == 0; 0 for x > 0.
    quadrant_offset = (sign_y + sign_x * (sign_y ** 2 - 1)) * (sign_x - 1) * (-math.pi / 2)
    # Adding (1 - x_nonzero) keeps the denominator non-zero when x == 0.
    principal = tf.math.atan(y / (x + (1 - x_nonzero))) * x_nonzero
    return principal + quadrant_offset


def to_ptrapphim(x, return_mass=True, eps=1e-8, for_onnx=False):
    """Convert rank-2 (px, py, pz, E) four-vectors to (pt, rapidity, phi[, m]).

    Args:
        x: tensor indexed as (N, 4) holding px, py, pz, E along axis 1.
        return_mass: if True, append the invariant mass as a fourth column.
        eps: lower clip bound for pt^2 and m^2 (``None`` disables clipping).
        for_onnx: use the sign/atan-based ``atan2`` instead of ``tf.math.atan2``.

    Returns:
        Tensor of shape (N, 3), or (N, 4) when ``return_mass`` is True.
    """
    px, py, pz, energy = tf.split(x, (1, 1, 1, 1), axis=1)
    pt = tf.sqrt(to_pt2(x, eps=eps))
    # y = 0.5 * ln((E + pz) / (E - pz)), written as 0.5 * ln(1 + 2 pz / (E - pz)).
    rapidity = 0.5 * tf.math.log(1 + (2 * pz) / (energy - pz))
    # Keep |y| >= 1e-20 while preserving the sign (sign(0) == 0 keeps y == 0 at 0).
    sign = tf.math.sign(rapidity)
    rapidity = sign * tf.clip_by_value(tf.abs(rapidity), clip_value_min=1e-20, clip_value_max=10e32)

    phi = (atan2 if for_onnx else tf.math.atan2)(py, px)
    if not return_mass:
        return tf.concat((pt, rapidity, phi), axis=1)

    # Mass is computed inline from rank-2 slices, so this function does not
    # depend on which rank the shared `to_m2` helper accepts.
    m2 = tf.square(energy) - tf.reduce_sum(tf.square(x[:, :3]), axis=1, keepdims=True)
    if eps is not None:
        m2 = tf.clip_by_value(m2, clip_value_min=eps, clip_value_max=10e32)
    return tf.concat((pt, rapidity, tf.sqrt(m2)), axis=1)


def boost(x, boostp4, eps=1e-8):
    """Boost four-vectors ``x`` into the rest frame of ``boostp4``.

    Args:
        x: tensor indexed as (N, 4) holding (px, py, pz, E).
        boostp4: tensor indexed as (N, 4) holding the (px, py, pz, E) of the
            frame to boost into.
        eps: lower clip bound for the frame energy and for 1 - beta^2.

    Returns:
        Tensor of shape (N, 3): the boosted 3-momentum of ``x``.
    """
    # beta = -p / E of the target frame. Only the energy (denominator) is
    # clipped; clipping the ratio itself forced every beta component positive.
    energy = tf.clip_by_value(boostp4[:, 3:], clip_value_min=eps, clip_value_max=10e32)
    beta = -boostp4[:, :3] / energy
    beta2 = tf.reduce_sum(tf.square(beta), axis=1, keepdims=True)
    # Lorentz factor gamma = 1 / sqrt(1 - beta^2).
    gamma = tf.math.rsqrt(tf.clip_by_value(1 - beta2, clip_value_min=eps, clip_value_max=10e32))
    # (gamma - 1) / beta^2, defined as 0 when the frame is already at rest.
    gamma2 = tf.where(beta2 == 0, tf.zeros_like(beta2), (gamma - 1) / beta2)

    bp = tf.reduce_sum(x[:, :3] * beta, axis=1, keepdims=True)
    return x[:, :3] + gamma2 * bp * beta + x[:, 3:] * gamma * beta


def p3_norm(p, eps=1e-8):
    """Unit 3-vector p[:, :3] / |p[:, :3]| for rank-2 four-vectors.

    The norm (denominator) is clipped below at ``eps`` to avoid division by
    zero. Clipping the normalized vector instead replaced every negative
    component with ``eps``.
    """
    p3 = p[:, :3]
    norm = tf.norm(p3, axis=1, keepdims=True)
    return p3 / tf.clip_by_value(norm, clip_value_min=eps, clip_value_max=10e32)

def to_m2(x, eps=1e-8):
    """Squared invariant mass E^2 - |p|^2 computed along the last axis.

    Works for any rank >= 2, e.g. (N, 4) or (batch, particles, 4).

    Args:
        x: tensor whose last axis holds (px, py, pz, E).
        eps: lower clip bound (``None`` disables clipping); any m^2 below
            ``eps``, including negative values, is clipped to ``eps``.

    Returns:
        Tensor with the shape of ``x`` except a last axis of size 1.
    """
    m2 = tf.square(x[..., 3:4]) - tf.reduce_sum(tf.square(x[..., :3]), axis=-1, keepdims=True)
    if eps is not None:
        m2 = tf.clip_by_value(m2, clip_value_min=eps, clip_value_max=10e32)
    return m2

Defining the import function for top data, and functions to calculate single-particle transverse momentum, rapidity, and phi.

In [None]:
def to_pt2_pre(x, eps=1e-8):
    """Squared transverse momentum for rank-3 input indexed as (batch, particles, features).

    Sums px^2 + py^2 (features 0 and 1), keeping a trailing axis of size 1;
    clipped below at ``eps`` unless ``eps`` is None.
    """
    squared = tf.square(x[:, :, :2])
    pt2 = tf.reduce_sum(squared, axis=2, keepdims=True)
    return pt2 if eps is None else tf.clip_by_value(pt2, clip_value_min=eps, clip_value_max=10e32)

def to_ptrapphim_pre(x, return_mass=True, eps=1e-8, for_onnx=False):
    """Per-particle (pt, rapidity, phi[, m]) for rank-3 four-vector input.

    Args:
        x: tensor indexed as (batch, particles, 4) holding px, py, pz, E on axis 2.
        return_mass: if False, return the tuple ``(pt, rapidity, phi)`` of
            (batch, particles, 1) tensors; if True, return one
            (batch, particles, 4) tensor with pt, rapidity, phi and m on the last axis.
        eps: lower clip bound for pt^2 and m^2 (``None`` disables clipping).
        for_onnx: use the sign/atan-based ``atan2`` instead of ``tf.math.atan2``.
    """
    px, py, pz, energy = tf.split(x, (1, 1, 1, 1), axis=2)
    pt = tf.sqrt(to_pt2_pre(x, eps=eps))

    # All-zero (padded) particles give 0/0 = NaN here; callers in this notebook
    # replace NaNs with 0 afterwards.
    rapidity = 0.5 * tf.math.log(1 + (2 * pz) / (energy - pz))
    sign = tf.math.sign(rapidity)
    rapidity = sign * tf.clip_by_value(tf.abs(rapidity), clip_value_min=1e-20, clip_value_max=10e32)

    phi = (atan2 if for_onnx else tf.math.atan2)(py, px)
    if not return_mass:
        return pt, rapidity, phi
    m = tf.sqrt(to_m2(x, eps=eps))
    # Concatenate on the feature axis; axis=1 would stack along the particle axis.
    return tf.concat((pt, rapidity, phi, m), axis=-1)

def _read_constituents(df, prefix, dataset_size, padding):
    """Return an (events, padding) float64 array of columns ``{prefix}_0 .. {prefix}_{padding-1}``."""
    columns = ["{}_{}".format(prefix, i) for i in range(padding)]
    return df[columns].iloc[:dataset_size].to_numpy(dtype=np.float64)


def _mask_padding_and_nans(chunk):
    """Zero all features of padded particles, replace NaNs with 0 and cast to float32."""
    chunk = tf.convert_to_tensor(chunk)
    # A particle is real if ANY of (px, py, pz, E) is non-zero; padding is all zeros.
    # (Requiring ALL four to be non-zero wiped real particles with a zero component.)
    is_real = tf.reduce_any(chunk[:, :, :4] != 0., axis=-1)
    chunk = chunk * tf.cast(is_real, chunk.dtype)[..., tf.newaxis]
    chunk = tf.where(tf.math.is_nan(chunk), tf.zeros_like(chunk), chunk)
    return tf.cast(chunk, dtype=tf.float32)


def import_top_data(filename, dataset_size, alpha = 0., eps=1e-8, padding=200):
    """Load top-tagging jets from ``TopTagging/<filename>.h5`` and build per-particle features.

    Args:
        filename: file stem inside ``TopTagging/`` (read with key 'table').
        dataset_size: number of leading events (rows) to read.
        alpha: angle in radians; if non-zero, every constituent is rotated by
            ``alpha`` in the (py, pz) plane before features are computed.
        eps: clip floor for log pt (and, negated, for the |log E| clip).
        padding: number of constituent slots read per jet (capped at 200).

    Returns:
        (features, labels): a float32 tensor of shape (events, padding, 11) with
        per-particle features [px, py, pz, E, y - y_jet, dphi to jet, log pt,
        log E, pt / pt_jet, E / E_jet, dR to jet], padded slots and NaNs set to 0;
        and float32 labels from the ``is_signal_new`` column.
    """
    print("Importing data...\n")
    path = "TopTagging/{}.h5".format(filename)
    df = pd.read_hdf(path, 'table')

    # Never read more than 200 constituent columns.
    padding = min(padding, 200)
    part_px = _read_constituents(df, "PX", dataset_size, padding)
    part_py = _read_constituents(df, "PY", dataset_size, padding)
    part_pz = _read_constituents(df, "PZ", dataset_size, padding)
    part_E = _read_constituents(df, "E", dataset_size, padding)

    if alpha != 0:
        # Rotation in the (py, pz) plane; both right-hand sides use the old values.
        part_py, part_pz = (np.cos(alpha) * part_py - np.sin(alpha) * part_pz,
                            np.cos(alpha) * part_pz + np.sin(alpha) * part_py)

    # Jet four-momentum = sum over constituent slots (zero slots contribute nothing).
    jet_px = part_px.sum(axis=1)
    jet_py = part_py.sum(axis=1)
    jet_pz = part_pz.sum(axis=1)
    jet_E = part_E.sum(axis=1)

    batch_part = tf.stack([part_px, part_py, part_pz, part_E], axis=-1)
    batch_jet = tf.stack([jet_px, jet_py, jet_pz, jet_E], axis=-1)
    # Repeat the jet four-vector once per particle slot: (events, padding, 4).
    batch_jet = tf.tile(batch_jet[:, tf.newaxis, :], (1, padding, 1))

    part_pt, part_rapidity, part_phi = to_ptrapphim_pre(batch_part, return_mass=False, eps=1e-8, for_onnx=False)
    jet_pt, jet_rapidity, jet_phi = to_ptrapphim_pre(batch_jet, return_mass=False, eps=1e-8, for_onnx=False)

    jet_E = tf.tile(tf.reshape(jet_E, (-1, 1)), (1, padding))

    part_eta = part_rapidity - jet_rapidity
    part_phi = delta_phi(part_phi, jet_phi)
    # NOTE(review): clipping after the log maps every pt < 1 to eps -- confirm intended.
    part_logpt = tf.clip_by_value(tf.math.log(part_pt), clip_value_min=eps, clip_value_max=10e32)
    log_E = tf.math.log(part_E)
    part_logE = tf.math.sign(log_E) * tf.clip_by_value(tf.abs(log_E), clip_value_min=-eps, clip_value_max=10e32)
    part_ptptjet = part_pt / jet_pt
    part_EEjet = part_E / jet_E
    part_dR = tf.sqrt(tf.square(tf.cast(part_eta, dtype=tf.float32)) + tf.square(tf.cast(part_phi, dtype=tf.float32)))
    y = df['is_signal_new'][:dataset_size].to_list()

    # Reshape every feature to (events, padding); unlike a bare squeeze this
    # also works when there is a single event or a single particle slot.
    def as_grid(t):
        return tf.reshape(t, part_px.shape)

    batch = np.stack([part_px, part_py, part_pz, part_E,
                      as_grid(part_eta), as_grid(part_phi), as_grid(part_logpt), as_grid(part_logE),
                      as_grid(part_ptptjet), as_grid(part_EEjet), as_grid(part_dR)], axis=-1)

    # Mask in chunks to bound peak memory; concatenate once at the end.
    chunk_size = 100000
    chunks = [_mask_padding_and_nans(batch[start:start + chunk_size])
              for start in range(0, len(batch), chunk_size)]
    if chunks:
        tf_batch = tf.concat(chunks, axis=0)
    else:
        tf_batch = tf.zeros((0, padding, 11), dtype=tf.float32)
    print("Data imported.\n")

    return tf_batch, tf.convert_to_tensor(y, dtype=tf.float32)

Importing 250,000 training events, defining the maximum number of particles per event to be 128.

In [None]:
# Maximum number of particles (constituent columns) read per jet.
num_partons = 128
training_data = import_top_data(
    'train',
    dataset_size=250000,
    padding=num_partons,
)

Function collating the pairwise kinematic functions, taking an argument that defines which features to use.

In [None]:
def to_m2_pair(x, eps=1e-8):
    """Squared invariant mass E^2 - |p|^2 for rank-4 input indexed as (batch, i, j, features).

    Returns a (batch, i, j, 1) tensor, clipped below at ``eps`` unless ``eps`` is None.
    """
    energy2 = tf.square(x[:, :, :, 3:4])
    p2 = tf.reduce_sum(tf.square(x[:, :, :, :3]), axis=-1, keepdims=True)
    m2 = energy2 - p2
    if eps is None:
        return m2
    return tf.clip_by_value(m2, clip_value_min=eps, clip_value_max=10e32)

def pairwise_lv_fts(input_data, output_vars=['kt', 'z', 'delta', 'm2'], eps=1e-8, for_onnx=False):
    """Pairwise particle features for every (i, j) pair in each jet.

    Args:
        input_data: tensor indexed as (batch, particles, features); the first
            four features are px, py, pz, E.
        output_vars: features to return, in order; any of 'kt' (ln kT),
            'z' (ln z), 'delta' (ln dR) and 'm2' (ln m^2 of the pair).
        eps: clip floor applied before taking logarithms.
        for_onnx: use ONNX-friendly ops for atan2 and the pairwise minimum.

    Returns:
        Tensor of shape (batch, particles, particles, len(output_vars)).
        Entries involving a padded (all-zero) particle, and any NaNs, are 0.
    """
    x = input_data[:, :, :4]

    # A particle is real if any four-momentum component is non-zero.
    is_real = tf.cast(tf.reduce_any(x != 0., axis=-1), x.dtype)
    pair_mask = is_real[:, :, tf.newaxis] * is_real[:, tf.newaxis, :]   # (batch, P, P)

    # Each is (batch, P, 1); computed once.
    pt, rap, phi = to_ptrapphim_pre(x, False, eps=None, for_onnx=for_onnx)
    pt_i, rap_i, phi_i = (t[:, tf.newaxis, :, :] for t in (pt, rap, phi))   # (batch, 1, P, 1)
    pt_j, rap_j, phi_j = (t[:, :, tf.newaxis, :] for t in (pt, rap, phi))   # (batch, P, 1, 1)

    delta = tf.sqrt(delta_r2(rap_i, phi_i, rap_j, phi_j))
    lndelta = tf.math.log(tf.clip_by_value(delta, clip_value_min=eps, clip_value_max=10e32))

    ptmin = tf.where(pt_i <= pt_j, pt_i, pt_j) if for_onnx else tf.math.minimum(pt_i, pt_j)
    # Clip kT before the log (as for delta and z), so soft pairs keep their
    # negative ln kT; clipping after the log flattened every kT < 1 to eps.
    lnkt = tf.math.log(tf.clip_by_value(ptmin * delta, clip_value_min=eps, clip_value_max=10e32))

    ptmin_sum = tf.clip_by_value(ptmin / (pt_i + pt_j), clip_value_min=eps, clip_value_max=10e32)
    log_z = tf.math.log(ptmin_sum)
    # Sign-preserving: |ln z| is kept >= eps.
    lnz = tf.math.sign(log_z) * tf.clip_by_value(tf.abs(log_z), clip_value_min=eps, clip_value_max=10e32)

    lnm2 = tf.math.log(to_m2_pair(x[:, tf.newaxis, :, :] + x[:, :, tf.newaxis, :], eps=eps))

    kinematic_dict = {'kt': lnkt, 'z': lnz, 'delta': lndelta, 'm2': lnm2}
    # Every feature is (batch, P, P, 1): concatenating on the last axis yields
    # (batch, P, P, K) without squeeze/transpose or batch-size-1 special cases.
    y = tf.concat([kinematic_dict[name] for name in output_vars], axis=-1)

    out = pair_mask[..., tf.newaxis] * y
    return tf.where(tf.math.is_nan(out), tf.zeros_like(out), out)

## Embedding functions

Defining single and particle pair embedding functions.

In [None]:
class Embedding(tf.keras.layers.Layer):
    """Per-particle embedding: Masking -> BatchNorm -> [LayerNorm -> Dense]*.

    Args:
        d_model: sequence of output widths, one (LayerNorm, Dense) pair per
            entry; the last width is the embedding dimension.
        input_dim: number of input features (stored only).
        activation: activation of every Dense layer.
    """
    def __init__(self, d_model, input_dim, activation = 'gelu'):
        super().__init__()
        self.d_model = d_model
        self.input_dim = input_dim
        self.masking = tf.keras.layers.Masking()
        self.batchnorm = tf.keras.layers.BatchNormalization()
        # The attribute name `layers` is kept: renaming it would change the
        # object paths stored by `save_weights` checkpoints.
        self.layers = [
            [tf.keras.layers.LayerNormalization(), tf.keras.layers.Dense(width, activation=activation)]
            for width in d_model
        ]

    def call(self, x):
        hidden = self.batchnorm(self.masking(x))
        for norm, dense in self.layers:
            hidden = dense(norm(hidden))
        return hidden
    
class PairEmbed(tf.keras.layers.Layer):
    """Embeds pairwise kinematic features into a (batch, P, P, d_model[-1]) tensor.

    Pipeline: Masking -> pairwise_lv_fts -> BatchNorm -> [Conv1D(k=1) -> BatchNorm -> GELU]*.
    ``max_len`` and ``activation`` are accepted but not used.
    """
    def __init__(self, max_len, output_vars, d_model, activation = 'gelu', eps=1e-8, for_onnx=False):
        super().__init__()
        self.for_onnx = for_onnx
        self.pairwise_lv_fts = partial(pairwise_lv_fts, output_vars = output_vars, eps=eps, for_onnx = for_onnx)
        self.masking = tf.keras.layers.Masking()
        self.num_outputs = len(output_vars)
        self.batchnorm = tf.keras.layers.BatchNormalization()
        # Attribute name `layers` kept for checkpoint compatibility.
        self.layers = [
            [tf.keras.layers.Conv1D(filters=width, kernel_size=1),
             tf.keras.layers.BatchNormalization(),
             tf.keras.layers.Activation('gelu')]
            for width in d_model
        ]
        self.output_dim = d_model[-1]

    def call(self, x):
        # x: (batch, particles, features)
        batch_size = tf.shape(x)[0]
        n_particles = tf.shape(x)[1]
        pair_fts = self.pairwise_lv_fts(self.masking(x))   # (batch, P, P, num_outputs)
        # Flatten the pair grid so every pair is a length-1 sequence for the 1x1 convolutions.
        hidden = tf.reshape(pair_fts, (batch_size, n_particles * n_particles, 1, self.num_outputs))
        hidden = self.batchnorm(hidden)
        for conv, norm, act in self.layers:
            hidden = act(norm(conv(hidden)))
        hidden = tf.reshape(hidden, (batch_size, n_particles, n_particles, self.output_dim))
        return self.masking(hidden)

## Defining attention layer classes

The particle and class attention blocks are defined, inheriting from a parent BaseAttention class.

In [None]:
class ScalingLayer(tf.keras.layers.Layer):
    """Multiplies its input by one trainable scalar ('learned_scalar'), initialised to 1."""

    def __init__(self,):
        super().__init__()
        self.initial_value = 1.
        self.learned = tf.Variable(self.initial_value, name = 'learned_scalar')

    def call(self, x):
        scale = self.learned
        return scale * x
    
class BaseAttention(tf.keras.layers.Layer):
    """Sub-layers shared by the attention blocks below.

    Every normalisation / feed-forward position has its OWN layer. Previously
    one LayerNormalization and one Dense were reused at every position, so both
    feed-forward projections shared weights and all norms shared scale/offset.
    NOTE: weights saved before this change do not map onto these layers.

    Args:
        embed_dim: width of the residual stream.
        key_dim: per-head key/query size of the attention layers.
        num_heads: number of attention heads.
        dropout: dropout rate for attention weights and block outputs.
        **kwargs: accepted and ignored.
    """

    def __init__(self, embed_dim, key_dim, num_heads, dropout, **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.add = tf.keras.layers.Add()
        self.gelu = tf.keras.layers.Activation(tf.keras.activations.gelu)
        self.layernorm = tf.keras.layers.LayerNormalization()        # pre-attention
        self.post_attn_norm = tf.keras.layers.LayerNormalization()
        self.pre_fc_norm = tf.keras.layers.LayerNormalization()
        self.post_fc_norm = tf.keras.layers.LayerNormalization()
        self.linear = tf.keras.layers.Dense(embed_dim, activation=None)      # first feed-forward projection
        self.linear_out = tf.keras.layers.Dense(embed_dim, activation=None)  # second feed-forward projection
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim = key_dim, dropout=dropout)
        self.p_mha = KerasPMHA(num_heads=num_heads, key_dim = key_dim, dropout=dropout)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def _feed_forward(self, x, use_dropout=True):
        """Residual feed-forward: norm -> Dense -> GELU -> [Dropout] -> norm -> Dense -> [Dropout] -> + x."""
        residual = x
        x = self.pre_fc_norm(x)
        x = self.gelu(self.linear(x))
        if use_dropout:
            x = self.dropout(x)
        x = self.linear_out(self.post_fc_norm(x))
        if use_dropout:
            x = self.dropout(x)
        return self.add([x, residual])


class PAttentionBlock(BaseAttention):
    """Particle attention block: KerasPMHA with the pairwise embedding passed as
    ``bias_mask``, followed by the residual feed-forward sub-block."""

    def call(self, x, context):
        # x: single particle embedding, (Batch size, Num. Partons, Embedding Dimension)
        # context: pairwise particle embedding, (Batch size, Num. Partons, Num. Partons, last width of
        #   the pair embedding -- equal to num_heads in ParticleTransformer)
        residual = x
        x = self.layernorm(x)
        x = self.p_mha(x, x, x, bias_mask=context)
        x = self.dropout(self.post_attn_norm(x))
        x = self.add([x, residual])
        return self._feed_forward(x)


class PlainAttentionBlock(BaseAttention):
    """Plain attention block: identical to PAttentionBlock except that standard
    multi-head attention (no pairwise bias) replaces KerasPMHA."""

    def call(self, x):
        # x: single particle embedding, (Batch size, Num. Partons, Embedding Dimension)
        residual = x
        x = self.layernorm(x)
        x = self.mha(x, x, x)
        x = self.dropout(self.post_attn_norm(x))
        x = self.add([x, residual])
        return self._feed_forward(x)


class CAttentionBlock(BaseAttention):
    """Class attention block, following the design outlined in the Particle
    Transformer paper: the class token attends to [class token; particles]."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, x_cls):
        # x: output of final Particle/Plain attention layer, (Batch size, Num. Partons, Embedding Dimension)
        # x_cls: class token to extract features, (Batch size, 1, Embedding Dimension)
        # Keys/values are the normalised concatenation; the query is the raw class token.
        u = self.layernorm(tf.concat((x_cls, x), axis=1))
        x = self.mha(x_cls, u, u)
        x = self.post_attn_norm(x)
        x = self.add([x, x_cls])
        return self._feed_forward(x, use_dropout=False)

## Defining the Particle Transformer

Now that the component attention layers are defined, the complete transformer model can be built. It is defined below, along with a multi-layer perceptron (MLP) layer for the output.

In [None]:
class MLP(tf.keras.layers.Layer):
    """Output head: a stack of Dense -> (activation) -> Dropout layers.

    Args:
        MLP_params: sequence of (out_dim, drop_rate) pairs, one per layer.
        hidden_activation: activation after every layer except the last.
            Without it, consecutive Dense layers collapse into a single linear
            map. The last layer stays linear so the caller can apply
            sigmoid/softmax.
    """
    def __init__(self, MLP_params, hidden_activation='relu'):
        super().__init__()
        n_layers = len(MLP_params)
        layer_arr = []
        for idx, (out_dim, drop_rate) in enumerate(MLP_params):
            activation = hidden_activation if idx < n_layers - 1 else None
            layer_arr.append([tf.keras.layers.Dense(out_dim, activation=activation),
                              tf.keras.layers.Dropout(rate = drop_rate)])

        self.MLP_layers = layer_arr

    def call(self, x):
        # x: class token or single layer attention, (Batch size, ..., embedding dimension)
        for dense, dropout in self.MLP_layers:
            x = dropout(dense(x))
        return x

class ParticleTransformer(tf.keras.Model):
    """Particle Transformer classifier.

    Per-particle features are embedded and passed through ``N_p`` particle
    attention blocks: ``PAttentionBlock`` with pairwise features, or
    ``PlainAttentionBlock`` when ``isPlain``. Then ``N_c`` class-attention blocks
    summarise the jet into a learned class token that feeds an MLP head.

    Args:
        d_model: widths of the per-particle embedding; d_model[-1] is the model width.
        d_pair_model: widths of the pairwise embedding. ``num_heads`` is appended
            to a copy, so the pair embedding's last axis equals ``num_heads``.
        max_len: passed through to ``PairEmbed``.
        input_dim: number of per-particle input features.
        output_vars: pairwise features to compute (see ``pairwise_lv_fts``).
        num_heads: attention heads per attention layer.
        N_p: number of particle attention blocks.
        N_c: number of class attention blocks.
        output_dims: MLP spec as (out_dim, dropout) pairs. A final out_dim of 1
            gives a sigmoid output, otherwise softmax.
        isPlain: use plain multi-head attention without pairwise features.
    """
    def __init__(self, d_model, d_pair_model, max_len, input_dim, output_vars, num_heads, N_p, N_c, output_dims, isPlain = False, **kwargs):
        super().__init__(**kwargs)
        # Work on a copy: appending to the caller's list mutated it, so re-using
        # one list for several models grew it by `num_heads` each time.
        d_pair_model = list(d_pair_model) + [num_heads]
        self.embed_dim = d_model[-1]
        self.embed = Embedding(d_model, input_dim)
        self.pair_embed = PairEmbed(max_len, output_vars, d_pair_model)
        self.isPlain = isPlain

        block_cls = PlainAttentionBlock if isPlain else PAttentionBlock
        self.particle_attention_arr = [
            block_cls(embed_dim=self.embed_dim, key_dim=self.embed_dim, num_heads=num_heads, dropout=0.1)
            for _ in range(N_p)
        ]
        self.class_attention_arr = [
            CAttentionBlock(embed_dim=self.embed_dim, key_dim=self.embed_dim, num_heads=num_heads, dropout=0.0)
            for _ in range(N_c)
        ]

        # Standard CS class token
        self.cls_tkn = tf.Variable(name="class token", trainable=True, initial_value = tf.random.normal(shape=(1,1,self.embed_dim),stddev=0.2))
        # Alternative, feature-based token (not used in `call`)
        self.token_layer = Embedding(d_model, 1)

        self.mlp = MLP(output_dims)
        self.softmax = tf.keras.layers.Softmax()
        self.sigmoid = tf.keras.layers.Activation(tf.keras.activations.sigmoid)
        self.n_p = N_p
        self.n_c = N_c
        self.output_dims = output_dims

    def call(self, inputs):
        # inputs: (Batch size, Num. partons, Num. kinematic features); the
        # first four features are px, py, pz, energy.
        x = self.embed(inputs)
        if self.isPlain:
            for block in self.particle_attention_arr:
                x = block(x)
        else:
            U = self.pair_embed(inputs)
            for block in self.particle_attention_arr:
                x = block(x, U)

        broadcast_shape = (tf.shape(x)[0], 1, 1)
        class_tkn = tf.tile(self.cls_tkn, broadcast_shape)

        # Feature-based token
#         p = inputs[:,:,:4]
#         E_tot = tf.reduce_sum(p[:,:,3], axis=1)
#         pz_tot = tf.reduce_sum(p[:,:,2], axis=1)
#         E_T2 = tf.square(E_tot) - tf.square(pz_tot)
#         E_T2 = tf.reshape(E_T2,broadcast_shape)
#         class_tkn = self.token_layer(E_T2)

        x_context = x
        x = class_tkn
        for block in self.class_attention_arr:
            x = block(x_context, x)

        # (batch, 1, embed_dim) -> (batch, embed_dim)
        x = tf.reshape(x, (tf.shape(x)[0], tf.shape(x)[2]))
        x = self.mlp(x)
        if self.output_dims[-1][0] == 1:
            return self.sigmoid(x)
        return self.softmax(x)

## Defining a wrapper

A complete wrapper is defined, with methods to test, train, and to save/load weights. 

Parameters of the Particle Transformer are defined upon initialization, including chosen pairwise features and the option to use the plain MHA.

In [None]:
class ParticleTransformerWrapper():
    """Build, train, evaluate and checkpoint a ``ParticleTransformer``.

    Constructor arguments are forwarded to ``ParticleTransformer``
    (``max_number_partons`` as ``max_len``, ``pairwise_outputs`` as
    ``output_vars``). ``train_filename`` and ``dataset_size`` are accepted but
    not used.
    """
    def __init__(self, d_model, d_pair_model, max_number_partons, input_dim, pairwise_outputs, num_heads, N_p, N_c, output_dims, isPlain = False, train_filename = 'train', model_name = "ParticleTransformer",  dataset_size = 10000, **kwargs):
        self.ParT = ParticleTransformer(d_model, d_pair_model, max_number_partons, input_dim, pairwise_outputs,  num_heads, N_p, N_c, output_dims, isPlain=isPlain, **kwargs)
        self.input_dim = input_dim
        self.padding = max_number_partons
        self.model_name = model_name
        self.weight_filename = model_name
        self.weight_log = []
        self.val_acc_loss_log = []
        self.is_built = False
        self.out_dim = output_dims[-1][0]

    def _compile(self, metrics, learning_rate=0.001, weight_decay=0.0):
        """Compile ``self.ParT`` with Lookahead(RectifiedAdam) and return the loss name.

        Binary cross-entropy is used for a single (sigmoid) output, categorical
        cross-entropy otherwise.
        """
        radam = tfa.optimizers.RectifiedAdam(learning_rate=learning_rate, beta_1=0.95, beta_2=0.999,
                                             epsilon=0.00001, weight_decay=weight_decay)
        loss_func = 'binary_crossentropy' if self.out_dim == 1 else 'categorical_crossentropy'
        self.ParT.compile(loss=loss_func, optimizer=tfa.optimizers.Lookahead(radam), metrics=metrics)
        return loss_func

    def _unique_name(self, template):
        """Return model_name, or model_name1, model_name2, ... -- the first name
        for which ``template.format(name)`` is not an existing file."""
        name, suffix = self.model_name, 1
        while os.path.isfile(template.format(name)):
            name = "{}{}".format(self.model_name, suffix)
            suffix += 1
        return name

    def build(self):
        """Compile (RAdam weight decay 0.01, 'Accuracy' metric) and build for (None, padding, input_dim) inputs."""
        loss_func = self._compile(metrics=['Accuracy'], weight_decay=0.01)
        print(loss_func)
        self.ParT.build((None, self.padding, self.input_dim))
        self.is_built = True

    def summary(self):
        if self.is_built:
            return self.ParT.summary()
        print("This model has not yet been built. Build the model first by calling `build()` or by calling the model on a batch of data.")
        return None

    def train(self, dataset, dynamic_lr = False, initial_lr=0.001, max_lr = 0.01, min_max_ratio = 0.1, rampup_epochs = 0,sustain_ratio = 0.7, early_stopping = True, print_summary = False, train_val_split = 0.3, metrics = ['Accuracy'], epochs = 20, batch_size = 8, learning_rate = 0.001):
        """Train on ``dataset = (X, y)`` using a random train/validation split.

        The model is fitted exactly once. The callbacks are:
        - a best-``val_Accuracy`` weight checkpoint, saved every epoch;
        - early stopping on ``val_loss`` (patience 5), if ``early_stopping``;
        - a ramp-up / sustain / exponential-decay learning-rate schedule, if ``dynamic_lr``.

        ``learning_rate`` is the RAdam learning rate; the schedule overrides it
        per epoch when ``dynamic_lr`` is set. Final weights are written to
        ``Weights/`` and the training curves to ``Figures/<name>.png``.
        """
        print("Loading data...\n")
        X, y = dataset
        X_train, X_test, y_train, y_test = train_test_split(X.numpy(), y.numpy(), shuffle=True, test_size=train_val_split)
        X_train = tf.convert_to_tensor(X_train)
        X_test = tf.convert_to_tensor(X_test)
        y_train = tf.convert_to_tensor(y_train)
        y_test = tf.convert_to_tensor(y_test)
        print("Loaded data.\nCompiling model...\n")

        self._compile(metrics=metrics, learning_rate=learning_rate)
        print("Model compiled.\nBeginning training... \n")

        path = "Checkpoints/{}".format(self.model_name) + "/cp-{epoch:04d}"

        # Schedule: linear ramp from initial_lr to max_lr over `rampup_epochs`,
        # hold max_lr for int(sustain_ratio * epochs) epochs, then decay by a
        # factor of e per epoch towards min_max_ratio * max_lr.
        min_lr = min_max_ratio * max_lr
        sustain_epochs = int(sustain_ratio * epochs)
        exp_decay = np.exp(-1)
        grad = (max_lr - initial_lr) / rampup_epochs if rampup_epochs != 0 else 0

        def lrfn(epoch):
            if epoch < rampup_epochs:
                return initial_lr + grad * epoch
            if epoch < sustain_epochs + rampup_epochs:
                return max_lr
            return (max_lr - min_lr) * exp_decay ** (epoch - rampup_epochs - sustain_epochs) + min_lr

        self.ParT.save_weights(path.format(epoch=0))
        callbacks = [tf.keras.callbacks.ModelCheckpoint(
            filepath=path,
            verbose=1,
            save_weights_only=True,
            monitor='val_Accuracy',
            mode='max',
            save_best_only=True)]
        if dynamic_lr:
            callbacks.insert(0, tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True))
        if early_stopping:
            callbacks.insert(0, keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=5))

        # A single fit call. The previous if/if/if-else chain could run several
        # full trainings back to back (three with both options enabled).
        history = self.ParT.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1,
                                callbacks=callbacks, validation_data=(X_test, y_test), shuffle=True)

        self.save_weights()
        self.is_built = True

        if print_summary:
            print(self.ParT.summary())

        figure = histplot(history)
        figure.show()
        os.makedirs("Figures", exist_ok=True)
        filename = self._unique_name("Figures/{}.png")
        figure.write_image("Figures/{}.png".format(filename), width=800, height=500)

    def test(self, dataset, get_score=True, percent_of_total_shown=0.1, batch_size=8):
        """Predict on ``dataset = (X, y)``, optionally print loss/accuracy, and plot the ROC.

        Only a random ``percent_of_total_shown`` fraction of ROC points is
        plotted; the AUC uses all of them.

        Returns:
            (fig, AUC): the plotly figure and the ROC AUC (NaN if the ROC failed).
        """
        X, y = dataset
        y_pred = self.ParT.predict(X, batch_size=batch_size)
        if get_score:
            try:
                score = self.ParT.evaluate(X, y, batch_size=batch_size, verbose=1)
                print('Test loss:', score[0])
                print('Test accuracy:', score[1])
            except RuntimeError:
                print("Cannot evaluate model without first compiling the model for training/testing")

        fig = make_subplots(rows=1, cols=1,subplot_titles=(['ROC']))
        try:
            fpr, tpr, thresholds = metrics.roc_curve(y, y_pred)
            AUC = metrics.auc(fpr, tpr)
            reduced_fpr, reduced_tpr, reduced_threshold = [], [], []
            for i in range(len(fpr)):
                if np.random.rand() < percent_of_total_shown and i < len(thresholds):
                    reduced_fpr.append(fpr[i])
                    reduced_tpr.append(tpr[i])
                    reduced_threshold.append(thresholds[i])

            fig.add_trace(go.Scatter(x=reduced_fpr, y=reduced_tpr,
                                    mode='lines',
                                    marker_line_color="midnightblue",
                                    text = reduced_threshold,
                                    hovertemplate = 'Thresh: %{text:.3f}<extra></extra>',
                                    ), row=1, col=1)

        except ValueError:
            print("Could not compute ROC due NaN values in y_pred")
            # np.nan: the np.NaN alias was removed in NumPy 2.0.
            AUC = np.nan
        fig['layout']['xaxis1']['title']='FPR'
        fig['layout']['yaxis1']['title']='TPR'

        return fig, AUC

    def save_weights(self):
        """Save to Weights/<name>, using the first of model_name, model_name1, ... not already taken."""
        filename = self._unique_name("Weights/{}.index")
        self.ParT.save_weights("Weights/{}".format(filename))
        self.weight_filename = filename

    def load_weights(self, filepath):
        """Load TF-format weights from ``filepath`` if ``filepath + '.index'`` exists."""
        if not os.path.isfile(filepath + ".index"):
            print("{} does not appear to exist.".format(filepath))
            return
        try:
            self.ParT.load_weights(filepath)
        except Exception as err:
            # `except (...)` raised TypeError instead of catching the failure.
            print("Error loading weights. Check 'filename'.data-00000-of-00001 exists in the weights folder.")
            print(err)

## Defining the transformers

The 6 transformer versions of interest are defined, with 8 particle blocks and 2 class blocks, with 8 heads for each attention layer.

In [None]:
# Full model: all four pairwise features with particle-attention blocks.
ParT_Full = ParticleTransformerWrapper(d_model=[128, 512, 128], d_pair_model=[64,64,64],
                                        max_number_partons=num_partons, input_dim=11,
                                        pairwise_outputs=['kt', 'z', 'delta', 'm2'],
                                        num_heads=8, N_p=8, N_c=2, output_dims=[[1,0]],
                                        model_name='ParT_Full_run1')

# `training_data` and `num_partons` come from the data-import cell above. The
# identical 250k-event import is not repeated here, which previously doubled
# load time and memory.
batch_size = 256
sustain_ratio = 0.7   # fraction of epochs held at max_lr before the exponential decay
epochs = 8

ParT_Full.train(dataset=training_data, epochs=epochs, train_val_split=0.3,
                batch_size=batch_size, dynamic_lr=True, max_lr=0.001,
                sustain_ratio=sustain_ratio, early_stopping=False)



In [None]:
def _make_variant(model_name, pairwise_outputs=None, isPlain=False):
    """Build one ParT variant; all variants share the same architecture hyper-parameters.

    Fresh list literals are created on every call so no list object is shared
    between models.
    """
    if pairwise_outputs is None:
        pairwise_outputs = ['kt', 'z', 'delta', 'm2']
    return ParticleTransformerWrapper(d_model=[128, 512, 128], d_pair_model=[64, 64, 64],
                                      max_number_partons=num_partons, input_dim=11,
                                      pairwise_outputs=pairwise_outputs,
                                      num_heads=8, N_p=8, N_c=2, output_dims=[[1, 0]],
                                      isPlain=isPlain, model_name=model_name)


# NOTE(review): this rebinds `ParT_Full`, discarding the in-memory model trained
# in the previous cell (its weights were saved under 'ParT_Full_run1').
ParT_Full = _make_variant('ParT_Full')
ParT_Plain = _make_variant('ParT_Plain', isPlain=True)
ParT_m2 = _make_variant('ParT_m2', ['m2'])
ParT_delta = _make_variant('ParT_delta', ['delta'])
ParT_kt = _make_variant('ParT_kt', ['kt'])
ParT_z = _make_variant('ParT_z', ['z'])

Training each version for 8 epochs, with a learning rate of 0.001 for 6 epochs and exponentially decaying towards 0.0001 for 2 epochs.