In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.metrics import mean_absolute_error, root_mean_squared_error, r2_score, roc_auc_score

In [2]:
class CustomDataset(Dataset):
    def __init__(self):
        self.lepton = 0
        self.nu = 0
        self.probe_jet = 0
        self.probe_jet_constituents = 0
        self.balance_jets = 0
        self.labels = 0
        self.track_labels = 0
    def __getitem__(self, idx):
        return self.lepton[idx], self.nu[idx], self.probe_jet[idx], self.probe_jet_constituents[idx], self.balance_jets[idx], self.labels[idx], self.track_labels[idx]
    def __len__(self):
        return len(self.lepton)

In [3]:
tag="U_1M"

In [4]:
dset = torch.load("dataset"+tag+".pt", weights_only=False)
train_dataset, test_dataset = torch.utils.data.random_split(dset, [0.8, 0.2])
train_loader = DataLoader(train_dataset, batch_size=256)
test_loader = DataLoader(test_dataset, batch_size=256)

In [5]:
class Encoder(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Encoder, self).__init__()
        self.pre_norm_Q = nn.LayerNorm(embed_dim)
        self.pre_norm_K = nn.LayerNorm(embed_dim)
        self.pre_norm_V = nn.LayerNorm(embed_dim)
        self.attention = nn.MultiheadAttention(embed_dim,num_heads=num_heads,batch_first=True, dropout=0.25)
        self.post_norm = nn.LayerNorm(embed_dim)
        self.out = nn.Linear(embed_dim,embed_dim)
    def forward(self, Query, Key, Value):
        Query = self.pre_norm_Q(Query)
        Key = self.pre_norm_K(Key)
        Value = self.pre_norm_V(Value)
        context, weights = self.attention(Query, Key, Value)
        context = self.post_norm(context)
        latent = Query + context
        tmp = F.gelu(self.out(latent))
        latent = latent + tmp
        return_weights=False
        if return_weights:
            return latent,weights
        return latent

In [6]:
class Model(nn.Module):  
    def __init__(self, embed_dim, num_heads):
        super(Model, self).__init__()   
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        
        # Initiliazer
        self.lepton_initializer = nn.Linear(4, self.embed_dim)
        self.MET_initializer = nn.Linear(2, self.embed_dim)
        self.probe_jet_initializer = nn.Linear(3, self.embed_dim)
        self.probe_jet_constituent_initializer = nn.Linear(4, self.embed_dim)
        self.small_jet_initializer = nn.Linear(3, self.embed_dim)
           
        # Transformer Stack
        self.event_encoder1 = Encoder(self.embed_dim, self.num_heads)
        self.event_encoder2 = Encoder(self.embed_dim, self.num_heads)
        self.event_encoder3 = Encoder(self.embed_dim, self.num_heads)
        self.event_encoder4 = Encoder(self.embed_dim, self.num_heads)
        self.event_encoder5 = Encoder(self.embed_dim, self.num_heads)
        
        # Kinematics Regression
        self.top_regression_input = nn.Linear(self.embed_dim, self.embed_dim)
        self.top_regression = nn.Linear(self.embed_dim, 4)
        self.down_regression_input = nn.Linear(self.embed_dim, self.embed_dim)
        self.down_regression = nn.Linear(self.embed_dim, 4)

        # Direct Regression Task
        self.direct_input = nn.Linear(8, self.embed_dim)
        self.direct_regression = nn.Linear(self.embed_dim, 1)
        
        # Track Classification
        self.track_classification = nn.Linear(self.embed_dim, 3)

    def forward(self, lepton, MET, probe_jet, probe_jet_constituent, small_jet):
        
        # Feature initialization layers
        lepton_embedding = torch.unsqueeze(F.gelu(self.lepton_initializer(lepton)), dim=1)
        MET_embedding = torch.unsqueeze(F.gelu(self.MET_initializer(MET)), dim=1)
        probe_jet_embedding = torch.unsqueeze(F.gelu(self.probe_jet_initializer(probe_jet)), dim=1)
        probe_jet_constituent_embedding = F.gelu(self.probe_jet_constituent_initializer(probe_jet_constituent))
        small_jet_embedding = F.gelu(self.small_jet_initializer(small_jet))
        
        num_leptons = lepton_embedding.shape[1]
        num_MET = MET_embedding.shape[1]
        num_probe_jet = probe_jet_embedding.shape[1]
        num_constituents = probe_jet_constituent_embedding.shape[1]
        num_small_jets = small_jet_embedding.shape[1]
        
        # Combine objects into single event tensor
        event_embedding = torch.cat([probe_jet_embedding, probe_jet_constituent_embedding, lepton_embedding, MET_embedding, small_jet_embedding], axis=1)
        
        # Event Level Attention
        event_embedding = self.event_encoder1(event_embedding,event_embedding,event_embedding)
        event_embedding = self.event_encoder2(event_embedding,event_embedding,event_embedding)
        event_embedding = self.event_encoder3(event_embedding,event_embedding,event_embedding)
        #event_embedding = self.event_encoder4(event_embedding,event_embedding,event_embedding)
        #event_embedding = self.event_encoder5(event_embedding,event_embedding,event_embedding)
        
        # Track Classificiation
        start_idx=num_probe_jet
        end_idx=start_idx+num_constituents
        track_tensor = event_embedding[:,start_idx:end_idx]
        track_output = self.track_classification(track_tensor)
        
        # Probe jet classification
        start_idx=0
        end_idx=num_probe_jet
        probe_jet_tensor = torch.squeeze(event_embedding[:,start_idx:end_idx], dim=1)
        
        # Get Top output
        top_kinematics = F.gelu(self.top_regression_input(probe_jet_tensor))
        top_kinematics = self.top_regression(top_kinematics)
        
        # Get Down output
        down_kinematics = F.gelu(self.down_regression_input(probe_jet_tensor))
        down_kinematics = self.down_regression(down_kinematics)
        down_kinematics = torch.cat([F.tanh(down_kinematics[:,0:3]), down_kinematics[:,3:]], axis=1)
        
        # Direct Regression output
        direct_embedding = torch.cat([top_kinematics, down_kinematics], axis=1)
        direct_embedding = F.gelu(self.direct_input(direct_embedding))
        costheta = self.direct_regression(direct_embedding)
        
        # Construct output
        output = torch.cat([top_kinematics, down_kinematics, costheta], axis=1)

        return output, track_output

In [7]:
print("GPU Available: ", torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

GPU Available:  True
cuda:0


In [8]:
model = Model(32,4).to(device)

optimizer = optim.AdamW(model.parameters(), lr=0.0001)

MSE_loss_fn = nn.MSELoss()
CCE_loss_fn = nn.CrossEntropyLoss()

In [9]:
print("Trainable Parameters :", sum(p.numel() for p in model.parameters() if p.requires_grad))

Trainable Parameters : 31148


In [10]:
for lepton, MET, probe_jet, constituents, small_jet, labels, track_labels in train_loader:
    batch_size = len(lepton)
    print(model(lepton.to(device), MET.to(device), probe_jet.to(device), constituents.to(device), small_jet.to(device)))
    break

(tensor([[-0.1239, -0.1371, -0.8665,  ..., -0.1440,  0.0027, -0.0483],
        [-0.2162,  0.0592, -0.7414,  ...,  0.1410,  0.1185, -0.0281],
        [-0.3144,  0.0642, -0.7334,  ...,  0.0894,  0.1028, -0.0337],
        ...,
        [-0.1431, -0.0591, -0.7414,  ..., -0.1682,  0.0924, -0.0438],
        [-0.0967, -0.1011, -0.7617,  ..., -0.2614, -0.0172, -0.0228],
        [-0.0734, -0.1476, -0.9021,  ..., -0.2058, -0.0434, -0.0404]],
       device='cuda:0', grad_fn=<CatBackward0>), tensor([[[-2.1215,  0.7314, -0.4901],
         [-2.2045,  0.6428, -0.5950],
         [-2.2467,  0.8070, -0.5939],
         ...,
         [-2.1221,  0.5342, -0.3118],
         [-2.1414,  0.4034, -0.2959],
         [-2.1710,  0.4695, -0.3263]],

        [[-1.7370,  0.8895,  0.2570],
         [-1.8859,  0.9233,  0.2062],
         [-1.8379,  0.8849,  0.2634],
         ...,
         [-1.5917,  0.6052,  0.4240],
         [-1.6936,  0.5915,  0.3825],
         [-1.6645,  0.6231,  0.4238]],

        [[-2.0263,  0.7359, 

In [11]:
def train(model, optimizer, train_loader, val_loader, epochs=40):
    
    combined_history = []
    
    for e in range(epochs):
        model.train()
        cumulative_loss_train = 0
        cumulative_loss_val = 0
        
        num_train = len(train_loader)*batch_size
        num_val = len(val_loader)*batch_size

        for lepton, MET, probe_jet, constituents, small_jet, labels, track_labels in train_loader:
            optimizer.zero_grad()
            
            output, trk_output = model(lepton.to(device), MET.to(device), probe_jet.to(device), constituents.to(device), small_jet.to(device))
            
            top_loss      = MSE_loss_fn(output[:,0:4], labels[:,0:4].to(device))
            down_loss     = MSE_loss_fn(output[:,4:8], labels[:,4:8].to(device))
            costheta_loss = MSE_loss_fn(output[:,-1], labels[:,-1].to(device))
            track_loss    = CCE_loss_fn(trk_output, track_labels.to(device))
            
            alpha = 1
            beta  = 1
            gamma = 1
            delta = 1
            loss  = alpha*top_loss + beta*down_loss + gamma*costheta_loss + delta*track_loss
            
            loss.backward()
            optimizer.step()
            
            cumulative_loss_train+=loss.detach().cpu().numpy().mean()
            
        cumulative_loss_train = cumulative_loss_train / num_train
        
        model.eval()
        for lepton, MET, probe_jet, constituents, small_jet, labels, track_labels in val_loader:
            output, trk_output = model(lepton.to(device), MET.to(device), probe_jet.to(device), constituents.to(device), small_jet.to(device))
            
            top_loss      = MSE_loss_fn(output[:,0:4], labels[:,0:4].to(device))
            down_loss     = MSE_loss_fn(output[:,4:8], labels[:,4:8].to(device))
            costheta_loss = MSE_loss_fn(output[:,-1], labels[:,-1].to(device))
            track_loss    = CCE_loss_fn(trk_output, track_labels.to(device))
            
            loss  = alpha*top_loss + beta*down_loss + gamma*costheta_loss + delta*track_loss

            cumulative_loss_val+=loss.detach().cpu().numpy().mean()
        
        cumulative_loss_val = cumulative_loss_val / num_val
        
        combined_history.append([cumulative_loss_train, cumulative_loss_val])

        if e%1==0:
            print('Epoch:',e+1,'\tTrain Loss:',round(cumulative_loss_train,6),'\tVal Loss:',round(cumulative_loss_val,6))
            
    return np.array(combined_history)

In [None]:
history = train(model, optimizer, train_loader, test_loader, epochs=100)

Epoch: 1 	Train Loss: 301.75235 	Val Loss: 274.63794


In [None]:
plt.figure()
plt.plot(history[20:,0], label="Train")
plt.plot(history[20:,1], label="Val")
plt.title('Loss')
plt.legend()
plt.yscale('log')
plt.show()

In [None]:
num_feats=9
pred_labels = np.array([]).reshape(0,num_feats)
true_labels = np.array([]).reshape(0,num_feats)
for lepton, MET, probe_jet, constituents, small_jet, labels, track_labels in test_loader:
    output, trk_output = model(lepton.to(device), MET.to(device), probe_jet.to(device), constituents.to(device), small_jet.to(device))
    pred_labels = np.vstack((pred_labels,output.detach().cpu().numpy()))
    true_labels = np.vstack((true_labels,labels.detach().cpu().numpy()))

In [None]:
feats = ['top_px','top_py','top_pz','top_E','down_px','down_py','down_pz','down_E', 'costheta']
ranges = [(-1000,1000),(-1000,1000),(-1000,1000),(0,1500),(-1.5,1.5),(-1.5,1.5),(-1.5,1.5),(0,150),(-1.5,1.5)]

In [None]:
print("Plotting predictions...")
for i in range(num_feats):
    plt.figure()
    plt.hist(np.ravel(true_labels[:,i]),histtype='step',color='r',label='True Distribution',bins=50,range=ranges[i])
    plt.hist(np.ravel(pred_labels[:,i]),histtype='step',color='b',label='Predicted Distribution',bins=50,range=ranges[i])
    plt.title("Predicted Ouput Distribution using Attention Model")
    plt.legend()
    plt.yscale('log')
    plt.xlabel(feats[i],loc='right')
    #plt.savefig(out_dir+"/pred_1d_"+feats[i]+".png")
    plt.show()
    plt.close()

    plt.figure()
    plt.title("Ouput Distribution using Attention Model")
    plt.hist2d(np.ravel(pred_labels[:,i]),np.ravel(true_labels[:,i]), bins=100,norm=mcolors.LogNorm(),range=(ranges[i],ranges[i]))
    plt.xlabel('Predicted '+feats[i],loc='right')
    plt.ylabel('True '+feats[i],loc='top')
    diff = ranges[i][1] - ranges[i][0]
    plt.text(ranges[i][1]-0.3*diff,ranges[i][0]+0.2*diff,"$R^2$ value: "+str(round(r2_score(np.ravel(true_labels[:,i]),np.ravel(pred_labels[:,i])),3)),backgroundcolor='r',color='k')
    #print("R^2 value: ", round(r2_score(true_labels[:,i],predicted_labels[:,i]),3))
    #plt.savefig(out_dir+"/pred_2d_"+feats[i]+".png")
    plt.show()
    plt.close()