In [None]:
from typing import Optional, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [None]:
class TTBarDataset(Dataset):
    """Binary particle-ordering dataset built from a ``.npy`` event file.

    The first 16 values of every row are reshaped to a (4, 4) array. Each
    selected event appears twice: once as stored (label 1.0) and once with its
    rows reordered to ``[0, 3, 2, 1]`` (label 0.0). The resulting 2N samples
    are shuffled once at construction time with the global NumPy RNG.

    Args:
        index: ``(start, stop)`` fractions of the file's rows to keep, e.g.
            ``(0.0, 0.8)`` for the first 80%. Both ends are rounded to the
            nearest row, so adjacent splits sharing a boundary never overlap.
        data_path: ``.npy`` file to load. ``None`` uses ``DEFAULT_DATA_PATH``.
            The array is assumed to be 2-D with at least 16 columns
            (TODO confirm against the data files).

    Attributes:
        source: float32 tensor of shape (2 * num_samples, 4, 4).
        targets: float32 tensor of shape (2 * num_samples,), 1.0 / 0.0 labels.
        num_samples: number of *events* selected (half the dataset length).
        indices: row indices of the selected events in the file.
    """

    DEFAULT_DATA_PATH = "../sample_data/parton_level.npy"

    def __init__(self,
                 index: Tuple[float, float] = (0.0, 1.0),
                 data_path: Optional[str] = None):
        if data_path is None:
            data_path = self.DEFAULT_DATA_PATH

        data = np.load(data_path)
        num = data.shape[0]
        # Vectorized equivalent of reshaping row[0:16] to (4, 4) per event;
        # also yields a well-formed (0, 4, 4) array for an empty file.
        events = data[:, :16].reshape(num, 4, 4)

        start = int(round(index[0] * num))
        stop = int(round(index[1] * num))
        indices = np.arange(num)[start:stop]
        self.num_samples = indices.shape[0]
        self.indices = indices

        original = events[indices]
        swapped = original[:, [0, 3, 2, 1]]

        source = np.concatenate((original, swapped))
        label = np.concatenate((np.ones(self.num_samples), np.zeros(self.num_samples)))

        order = np.arange(self.num_samples * 2)
        np.random.shuffle(order)

        self.source = torch.from_numpy(source[order]).float()
        self.targets = torch.from_numpy(label[order]).float()

    def __len__(self):
        # Originals *and* swapped copies: 2 * num_samples. Returning
        # num_samples here would hide half of the shuffled data from loaders.
        return self.source.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, y)`` as NumPy copies: x is (4, 4) float32, y a 0-d float32 label."""
        x = np.copy(self.source[idx])
        y = np.copy(self.targets[idx])
        return x, y


class TTBarDataset_detector(TTBarDataset):
    """``TTBarDataset`` reading detector-level events by default."""

    DEFAULT_DATA_PATH = "../sample_data/detector_level.npy"

In [None]:
def create_linear_stack(input_dim: int, output_dim: int, dropout_rate: Optional[float] = None):
    """Build one Linear -> PReLU -> BatchNorm1d [-> Dropout] stack.

    Args:
        input_dim: input features of the ``nn.Linear``.
        output_dim: output features; also sizes the per-channel PReLU and the
            BatchNorm1d.
        dropout_rate: probability for the optional trailing ``nn.Dropout``.
            ``None`` (default) falls back to the notebook-level ``dropout``
            hyperparameter, which must be defined before this is called.

    Returns:
        A plain ``list`` of modules (so callers can ``extend`` it or wrap it in
        ``nn.Sequential``). No Dropout is appended when the rate is 0.
    """
    if dropout_rate is None:
        dropout_rate = dropout  # global from the hyperparameter cell

    layers = [
        nn.Linear(input_dim, output_dim),
        nn.PReLU(output_dim),
        nn.BatchNorm1d(output_dim),
    ]
    if dropout_rate > 0.0:
        layers.append(nn.Dropout(dropout_rate))
    return layers


def create_linear_layers(num_layers: int, hidden_dim: int, dropout_rate: Optional[float] = None):
    """Stack ``num_layers`` width-preserving ``create_linear_stack`` blocks.

    Args:
        num_layers: number of hidden_dim -> hidden_dim stacks.
        hidden_dim: width of every layer.
        dropout_rate: forwarded to ``create_linear_stack`` (``None`` = global).

    Returns:
        ``nn.Sequential`` of all the layers.
    """
    layers = []
    for _ in range(num_layers):
        layers.extend(create_linear_stack(hidden_dim, hidden_dim, dropout_rate))
    return nn.Sequential(*layers)

In [None]:
class ParticleEncoder(nn.Module):
    """Per-particle MLP embedding followed by a Transformer encoder.

    Reads the notebook-level hyperparameters ``initial_embedding_dim``,
    ``num_embedding_layers`` and ``num_encoder_layers`` and uses
    ``create_linear_stack``; these must be defined before construction.

    Args:
        input_dim: number of features per particle.
        transformer_options: positional arguments for
            ``nn.TransformerEncoderLayer``:
            ``(d_model, nhead, dim_feedforward, dropout, activation)``.
            ``d_model`` is also the embedding's output width.
    """

    def __init__(self, input_dim: int, transformer_options: Tuple[int, int, int, float, str]):
        super().__init__()

        # The embedding must output exactly the transformer's d_model, so take
        # it from the options rather than from a separate global.
        self.hidden_dim = transformer_options[0]

        self.embedding = self.create_embedding_layers(input_dim)

        self.encoder_layer = nn.TransformerEncoderLayer
        self.encoder = nn.TransformerEncoder(self.encoder_layer(*transformer_options), num_encoder_layers)

    def create_embedding_layers(self, input_dim):
        """Build the widening MLP mapping ``input_dim`` -> ``self.hidden_dim``.

        Starts at ``initial_embedding_dim`` and doubles the width for at most
        ``num_embedding_layers`` extra stacks, stopping before the width would
        reach ``self.hidden_dim``; a final stack projects to ``self.hidden_dim``.
        """
        current_dim = initial_embedding_dim
        layers = create_linear_stack(input_dim, current_dim)

        for _ in range(num_embedding_layers):
            next_dim = 2 * current_dim
            if next_dim >= self.hidden_dim:
                break
            layers.extend(create_linear_stack(current_dim, next_dim))
            current_dim = next_dim

        layers.extend(create_linear_stack(current_dim, self.hidden_dim))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Tensor:
        """Encode a batch of particle sets.

        Args:
            x: tensor of shape (batch, particles, input_dim).

        Returns:
            Tensor of shape (particles, batch, hidden_dim): sequence-first, the
            layout ``nn.TransformerEncoder`` uses when ``batch_first`` is not set.
        """
        batch_size, max_particles, input_dim = x.shape

        # Fold particles into the batch so BatchNorm1d sees (N, features).
        hidden = self.embedding(x.reshape(-1, input_dim))
        hidden = hidden.view(batch_size, max_particles, self.hidden_dim)

        hidden = hidden.transpose(0, 1)
        return self.encoder(hidden)

In [None]:
# Hyperparameters. These are read as globals by create_linear_stack and
# ParticleEncoder (defined above) and by the dataloader, model and training
# cells below.
SEED = 42
np.random.seed(SEED)     # TTBarDataset shuffles with the global NumPy RNG
torch.manual_seed(SEED)  # weight initialisation and DataLoader shuffling

batch_size = 512
num_dataloader_workers = 16
hidden_dim = 64               # transformer d_model; must divide by num_attention_heads
num_attention_heads = 4
dropout = 0.1                 # used by create_linear_stack and the transformer layers
transformer_activation = 'relu'
initial_embedding_dim = 8
num_encoder_layers = 3
learning_rate = 0.001
l2_penalty = 9e-05            # Adam weight_decay
num_embedding_layers = 8      # upper bound; widening stops before reaching hidden_dim

assert hidden_dim % num_attention_heads == 0, "hidden_dim must be divisible by num_attention_heads"

In [None]:
# Set to True to use detector-level instead of parton-level events.
USE_DETECTOR_LEVEL = False
dataset_class = TTBarDataset_detector if USE_DETECTOR_LEVEL else TTBarDataset

# 80% / 10% / 10% split by row position in the event file.
training_dataset = dataset_class(index=[0.0, 0.8])
validation_dataset = dataset_class(index=[0.8, 0.9])
testing_dataset = dataset_class(index=[0.9, 1.0])

In [None]:
def _make_dataloader(dataset: Dataset, shuffle: bool, drop_last: bool) -> DataLoader:
    """Build a DataLoader using the notebook-wide batch size and worker count."""
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      drop_last=drop_last,
                      num_workers=num_dataloader_workers,
                      pin_memory=True)


def train_dataloader() -> DataLoader:
    """Shuffled training loader; the final partial batch is dropped."""
    return _make_dataloader(training_dataset, shuffle=True, drop_last=True)


def val_dataloader() -> DataLoader:
    """Unshuffled loader over the whole validation split."""
    # drop_last=False so every validation sample is scored, including the
    # final partial batch (and small splits don't yield zero batches).
    return _make_dataloader(validation_dataset, shuffle=False, drop_last=False)


def testing_dataloader() -> DataLoader:
    """Unshuffled loader over the whole test split."""
    # drop_last=False: the reported test accuracy must cover every sample.
    return _make_dataloader(testing_dataset, shuffle=False, drop_last=False)

In [None]:
class ttbarNetwork(pl.LightningModule):
    """Binary classifier scoring whether a (4, 4) particle array is in stored order.

    A ParticleEncoder embeds the 4 particles and runs a Transformer encoder
    over them. The per-particle encodings are concatenated and passed through
    a two-layer head with a sigmoid, giving one probability per event (label
    convention of the TTBarDataset cell: 1 = stored order, 0 = rows swapped).

    Reads the notebook-level hyperparameters hidden_dim, num_attention_heads,
    dropout, transformer_activation, learning_rate and l2_penalty.
    """

    def __init__(self):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_particles = 4

        transformer_options = (self.hidden_dim,
                               num_attention_heads,
                               self.hidden_dim,  # dim_feedforward
                               dropout,
                               transformer_activation)

        # 4 features per particle (the (4, 4) event layout).
        self.encoder = ParticleEncoder(4, transformer_options)

        self.loss = nn.BCELoss()  # expects probabilities, hence the sigmoid below
        self.relu = nn.ReLU()
        self.activation = nn.Sigmoid()
        # Head over the concatenated per-particle encodings.
        self.embedding2 = nn.Sequential(nn.Linear(self.hidden_dim * self.num_particles, self.hidden_dim))
        self.embedding3 = nn.Sequential(nn.Linear(self.hidden_dim, 1))
        self.hid_dim = hidden_dim  # same value as self.hidden_dim; kept for compatibility

    def forward(self, x: torch.Tensor) -> Tensor:
        """Return one probability per event, shape (batch,)."""
        encoded = self.encoder(x)  # (particles, batch, hidden)
        num_particles, batch_size, width = encoded.shape

        # Back to batch-first, then concatenate the particle encodings.
        flat = encoded.transpose(0, 1).reshape(batch_size, num_particles * width)
        hidden = self.relu(self.embedding2(flat))
        logits = self.embedding3(hidden)

        return self.activation(logits).view(-1)

    def training_step(self, batch, batch_nb):
        """Compute and log the BCE training loss for one batch.

        Raises:
            ValueError: if the loss is NaN.
        """
        x, targets = batch
        predictions = self.forward(x)

        # BCELoss uses reduction="mean", so this is already a scalar.
        loss = self.loss(predictions, targets)

        if torch.isnan(loss):
            raise ValueError("Training loss has diverged.")

        self.log("train_loss", loss)
        return loss

    @staticmethod
    def accuracy(predictions: Tensor, targets: Tensor) -> Tensor:
        """Element-wise correctness of thresholded predictions.

        Args:
            predictions: probabilities; rounded with ``torch.round`` (round
                half to even, so exactly 0.5 becomes 0).
            targets: 0/1 labels, same shape as ``predictions``.

        Returns:
            Boolean tensor, True where the rounded prediction equals the target.
        """
        return predictions.round() == targets

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Adam with L2 weight decay, using the notebook hyperparameters."""
        return torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=l2_penalty)

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch.

        Returns the batch accuracy and its size, so the epoch average can be
        weighted correctly when the last batch is smaller.
        """
        x, targets = batch
        predictions = self.forward(x)

        val_loss = self.loss(predictions, targets)
        self.log("val_loss", val_loss)

        accuracy = self.accuracy(predictions, targets).float().mean()
        self.log("accuracy", accuracy)

        return {"accuracy": accuracy, "num_samples": targets.numel()}

    def validation_epoch_end(self, outputs):
        """Print the sample-weighted mean validation accuracy for the epoch."""
        if not outputs:
            return
        accuracies = torch.stack([out["accuracy"] for out in outputs])
        counts = torch.tensor([out["num_samples"] for out in outputs],
                              dtype=accuracies.dtype, device=accuracies.device)
        average_accuracy = (accuracies * counts).sum() / counts.sum()

        print(average_accuracy)

In [None]:
# Keep the checkpoint with the lowest validation loss; it is reloaded via
# checkpoint_callback.best_model_path after training.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

model = ttbarNetwork()
trainer = pl.Trainer(max_epochs=300, gpus=1, precision=32, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader(), val_dataloader())

In [None]:
# Reload the best checkpoint recorded by the training run's ModelCheckpoint.
best_checkpoint_path = checkpoint_callback.best_model_path
model = ttbarNetwork.load_from_checkpoint(best_checkpoint_path)

In [None]:
def check_accuracy(loader, model, device: Optional[str] = None) -> float:
    """Fraction of samples whose rounded model output equals the label.

    Moves ``model`` to ``device`` and switches it to eval mode (both persist
    after the call), then scores every batch from ``loader`` without gradients.

    Args:
        loader: iterable of ``(x, y)`` batches, e.g. a DataLoader.
        model: module assumed to return one probability per sample, shape
            (batch,), as ttbarNetwork does.
        device: device to evaluate on. ``None`` selects ``"cuda:0"`` when CUDA
            is available and ``"cpu"`` otherwise.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined),
            e.g. a split smaller than one batch with ``drop_last=True``.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model.to(device=device)
    model.eval()

    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            predictions = model(x).round()
            num_correct += (predictions == y).sum().item()
            num_samples += predictions.size(0)

    if num_samples == 0:
        raise ValueError("check_accuracy: loader yielded no samples")
    return num_correct / num_samples

In [None]:
# Accuracy of the reloaded best model on the held-out test split.
test_accuracy = check_accuracy(testing_dataloader(), model)
print(test_accuracy)