This file contains a GNN-based solution for a specific task of the project [Graph Neural Networks for End-to-End Particle Identification with the CMS Experiment](https://docs.google.com/document/d/1lWTSASnVICm_4Zof7wr6_LkS24P_Z8TR1px_tctemQI/edit).

GNN layers used:

- [Graph Convolution Layer](https://arxiv.org/pdf/1609.02907.pdf)

- [PointNet Convolution Layer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PointNetConv.html#torch_geometric.nn.conv.PointNetConv) 

In both models, the architecture consists of two message-passing layers.
Latent embedding dimension is set to 300.

Node features:
- Channel values
- Global Positional encoding (3D coordinates of the nodes) - optional

Edge features:
- Euclidean distance between nodes.

Both models are trained for 75 epochs.

In [None]:
import os
from tqdm import tqdm
import torch
from torch_geometric.nn import MessagePassing,GPSConv, GINEConv
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.utils import degree
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

from dataset import JetsGraphsDataset
from torch_geometric.nn.conv import GATConv,PointNetConv

In [None]:
### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN-style message passing that also uses edge features.

    Node features are projected to ``emb_dim`` by ``self.linear`` and edge
    features by ``self.edge_encoder``. Every node sums ``relu(x_j + e_ij)``
    over its incoming edges (``aggr="add"``), adds ``(1 + eps) * x_i`` with a
    learnable ``eps``, and feeds the result through a
    Linear -> BatchNorm1d -> ReLU -> Linear MLP.

    Args:
        emb_dim (int): width of the projected node/edge embeddings and of the output.
        input_node_dim (int): width of the incoming node features.
        input_edge_dim (int): width of the incoming edge features.
    """

    def __init__(self, emb_dim, input_node_dim, input_edge_dim):
        super().__init__(aggr="add")

        hidden_dim = 2 * emb_dim
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        )
        # Learnable weight of the node's own embedding, initialised at zero.
        self.eps = torch.nn.Parameter(torch.zeros(1))
        self.linear = torch.nn.Linear(input_node_dim, emb_dim)
        self.edge_encoder = torch.nn.Linear(input_edge_dim, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        node_emb = self.linear(x)
        edge_emb = self.edge_encoder(edge_attr)
        neighbour_sum = self.propagate(edge_index, x=node_emb, edge_attr=edge_emb)
        return self.mlp((1 + self.eps) * node_emb + neighbour_sum)

    def message(self, x_j, edge_attr):
        # Each message is the neighbour embedding shifted by its edge embedding.
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # The aggregated sum is used unchanged.
        return aggr_out

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    """GCN-style convolution with edge features.

    Node features are projected to ``emb_dim`` and edge features are encoded
    to the same width. Neighbour messages ``relu(x_j + e_ij)`` are weighted by
    ``1 / sqrt(deg_row * deg_col)`` and summed. Here ``deg`` counts the
    occurrences of each node in ``edge_index[0]``, plus one. A self term
    ``relu(x_i + root_emb) / deg_i`` is then added.

    Args:
        emb_dim (int): output / embedding width.
        input_node_dim (int): width of the incoming node features.
        input_edge_dim (int): width of the incoming edge features.
    """

    def __init__(self, emb_dim, input_node_dim, input_edge_dim):
        super().__init__(aggr='add')

        self.linear = torch.nn.Linear(input_node_dim, emb_dim)
        # Single learnable vector added to every node's own embedding.
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.edge_encoder = torch.nn.Linear(input_edge_dim, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        node_emb = self.linear(x)
        edge_emb = self.edge_encoder(edge_attr)

        src, dst = edge_index

        # The "+ 1" accounts for the implicit self loop, so deg >= 1.
        deg = degree(src, node_emb.size(0), dtype=node_emb.dtype) + 1
        inv_sqrt_deg = deg.pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
        edge_norm = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        neighbour_term = self.propagate(edge_index, x=node_emb, edge_attr=edge_emb, norm=edge_norm)
        self_term = F.relu(node_emb + self.root_emb.weight) / deg.view(-1, 1)
        return neighbour_term + self_term

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # The aggregated sum is used unchanged.
        return aggr_out

In [None]:
class mlp(torch.nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        input_node_dim (int): width of the input features.
        emb_dim (int): width of the output; the hidden layer is ``2 * emb_dim`` wide.
    """

    def __init__(self, input_node_dim, emb_dim):
        super().__init__()
        hidden_dim = 2 * emb_dim
        layers = [
            torch.nn.Linear(input_node_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        ]
        # Kept as a Sequential named `mlp` so state_dict keys stay "mlp.<idx>.*".
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

In [None]:
class MessagePasssing_Module(torch.nn.Module):
    """
    MessagePasssing_Module stacks ``num_layer`` (>= 2) GNN layers of one type.

    Each layer is followed by BatchNorm1d, ReLU (skipped after the last layer)
    and dropout, with an optional residual connection.
    Output:
        node representations
    """
    def __init__(self, num_layer, input_node_dim, input_edge_dim, emb_dim, extraPE_dim=None,
                 extraPE_method='sum', drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
        Args:
            num_layer (int): number of GNN message passing layers (must be >= 2).
            input_node_dim (int): width of the node features fed to the first layer.
            input_edge_dim (int): width of the edge features.
            emb_dim (int): node embedding dimensionality.
            extraPE_dim (int | None): width of an extra positional encoding read from
                ``batched_data.extraPE``. Disabled when falsy.
            extraPE_method (str): 'sum' adds the encoded PE to x; 'cat' concatenates it.
            drop_ratio (float): dropout probability applied after every layer.
            JK (str): 'last' returns the last layer's output; 'sum' sums the layer outputs.
            residual (bool): add each layer's input to its output (only where the widths match).
            gnn_type (str): one of 'gin', 'gcn', 'gat', 'pointnet', 'gps'.

        Raises:
            ValueError: num_layer < 2, or unknown gnn_type / JK / extraPE_method.
        '''
        super(MessagePasssing_Module, self).__init__()

        self.gnn_type = gnn_type
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.input_node_dim = input_node_dim
        self.input_edge_dim = input_edge_dim
        ### add residual connection or not
        self.residual = residual
        self.extraPE_dim = extraPE_dim
        self.extraPE_method = extraPE_method

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Validate up front: an unknown value would otherwise leave a local
        # unbound and fail only inside forward().
        if JK not in ("last", "sum"):
            raise ValueError("Undefined JK mode called {}".format(JK))
        if extraPE_dim and extraPE_method not in ("sum", "cat"):
            raise ValueError("Undefined extraPE_method called {}".format(extraPE_method))

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        if extraPE_dim:
            # this is to transform random walk encoding
            self.extraPE_Encoder = torch.nn.Linear(self.extraPE_dim, self.input_node_dim)

        if extraPE_dim and extraPE_method == 'cat':
            # x and the encoded PE (both input_node_dim wide) are concatenated.
            self.input_node_dim *= 2

        for layer in range(num_layer):
            in_dim = self.input_node_dim if layer == 0 else emb_dim
            self.convs.append(self._build_conv(in_dim, emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def _build_conv(self, in_dim, emb_dim):
        """Create one message-passing layer of ``self.gnn_type`` mapping in_dim -> emb_dim features."""
        gnn_type = self.gnn_type
        if gnn_type == 'gin':
            return GINConv(emb_dim, input_node_dim=in_dim, input_edge_dim=self.input_edge_dim)
        if gnn_type == 'gcn':
            return GCNConv(emb_dim, input_node_dim=in_dim, input_edge_dim=self.input_edge_dim)
        if gnn_type == 'gat':
            return GATConv(in_channels=in_dim, out_channels=emb_dim, edge_dim=self.input_edge_dim)
        if gnn_type == "pointnet":
            # PointNetConv feeds [x_j, pos_j - pos_i] to the local MLP; the "+ 3"
            # assumes 3-D positions.
            return PointNetConv(mlp(in_dim + 3, emb_dim), None)
        if gnn_type == "gps":
            # GINEConv needs an explicitly declared mlp.
            # NOTE(review): GPSConv keeps `in_dim` channels and adds its input
            # residually, so in_dim must equal emb_dim for the layer and the
            # following BatchNorm1d(emb_dim) to line up.
            gine_mlp = torch.nn.Sequential(torch.nn.Linear(in_dim, 2 * emb_dim),
                                           torch.nn.BatchNorm1d(2 * emb_dim),
                                           torch.nn.ReLU(),
                                           torch.nn.Linear(2 * emb_dim, emb_dim))
            return GPSConv(in_dim, GINEConv(gine_mlp, edge_dim=self.input_edge_dim),
                           heads=5, attn_dropout=0.3)
        raise ValueError('Undefined GNN type called {}'.format(gnn_type))

    def forward(self, batched_data):
        """Return node representations for ``batched_data``.

        Reads ``x``, ``edge_index``, ``edge_attr``, ``pos`` and ``batch`` from
        the batch, plus ``extraPE`` when an extra positional encoding is enabled.
        """
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        pos = batched_data.pos
        batch = batched_data.batch

        if self.extraPE_dim:
            extraPE_emb = self.extraPE_Encoder(batched_data.extraPE)
            if self.extraPE_method == 'sum':
                x = x + extraPE_emb
            else:  # 'cat' -- validated in __init__
                x = torch.cat((x, extraPE_emb), 1)
        h_list = [x]

        last_layer = self.num_layer - 1
        for layer, (conv, batch_norm) in enumerate(zip(self.convs, self.batch_norms)):
            h_in = h_list[layer]
            if self.gnn_type == 'pointnet':
                h = conv(h_in, pos, edge_index)
            elif self.gnn_type == 'gps':
                h = conv(h_in, edge_index, batch=batch, edge_attr=edge_attr)
            else:
                h = conv(h_in, edge_index, edge_attr)

            h = batch_norm(h)
            if layer != last_layer:
                # No ReLU after the last layer.
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)

            # The layer-0 input may have a different width than emb_dim; a
            # residual is only defined where the shapes agree.
            if self.residual and h.shape == h_in.shape:
                h = h + h_in

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            return h_list[-1]
        # JK == "sum": sum the layer outputs. The raw input h_list[0] is
        # included only when it has the same width as the embeddings, since
        # it cannot be added otherwise.
        target_shape = h_list[-1].shape
        return sum(h for h in h_list if h.shape == target_shape)


In [None]:
class GNN(torch.nn.Module):
    """Graph-level classifier.

    Pipeline: optional node-wise Linear pre-processing -> MessagePasssing_Module
    -> graph pooling -> Linear post-processing -> softmax over classes.
    """

    def __init__(self, num_classes=2, num_layer=5, num_pre_fnn_layers=0, num_post_fnn_layers=1, hasPos=True, num_coords=3,
                 input_spec_fts_dim=3, input_edge_dim=1, emb_dim=300, extraPE_dim=None, extraPE_method='sum', gnn_type='gcn',
                 residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        '''
            num_classes (int): number of output classes.
            num_layer (int): number of message passing layers (must be >= 2).
            num_pre_fnn_layers (int): Linear layers applied to node features before message passing.
            num_post_fnn_layers (int): Linear layers applied to the pooled graph embedding (must be >= 1).
            hasPos (bool) : whether input node features should contain global positioning embedded
                            ps: global positioning is the coordinate of the pixel on 2D grid.
                            Forced to False for gnn_type == "pointnet".
            num_coords : number of coordinates required for the positional embedding
            input_spec_fts_dim (int) : denotes number of specific features (features apart from positional embedding)
            input_edge_dim (int): width of the edge features.
            emb_dim (int): hidden embedding width.
            extraPE_dim: Denotes dimension of random walk or Laplacian eigenvector positional encoding
            extraPE_method: Denotes how random walk embeddings or Laplacian eigenvector positional encoding should be embedded
                         cat - concatenation, sum - summation.
            gnn_type (str): 'gin', 'gcn', 'gat', 'pointnet' or 'gps'.
            residual, drop_ratio, JK: forwarded to MessagePasssing_Module.
            graph_pooling (str): 'sum', 'mean', 'max' or 'attention'.

        Raises:
            ValueError: invalid layer counts or pooling type.
        '''

        super(GNN, self).__init__()

        self.gnn_type = gnn_type
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.hasPos = hasPos
        self.num_coords = num_coords
        self.num_classes = num_classes
        self.num_pre_fnn_layers = num_pre_fnn_layers
        self.num_post_fnn_layers = num_post_fnn_layers
        self.graph_pooling = graph_pooling
        self.input_spec_fts_dim = input_spec_fts_dim
        self.input_edge_dim = input_edge_dim
        self.extraPE_method = extraPE_method

        if self.gnn_type == "pointnet":
            # PointNet receives coordinates through `pos` instead of node features.
            self.hasPos = False

        # Tag used by __str__, and therefore by checkpoint file names.
        self.pos_kwd = "hasPos" if self.hasPos else "noPos"

        if self.hasPos:
            self.input_node_dim = self.input_spec_fts_dim + num_coords
        else:
            self.input_node_dim = self.input_spec_fts_dim

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        if self.num_post_fnn_layers < 1:
            raise ValueError("Number of post-FNN layers must be greater than or equal to 1.")

        self.graph_pred_pre_linear_list = torch.nn.ModuleList()

        # dimension of node fts which are fed into message passing layers
        self.input_node_dim_mp = self.input_node_dim

        if self.num_pre_fnn_layers > 0:
            self.graph_pred_pre_linear_list.append(torch.nn.Linear(self.input_node_dim, emb_dim))
            for i in range(1, num_pre_fnn_layers):
                self.graph_pred_pre_linear_list.append(torch.nn.Linear(emb_dim, emb_dim))
            self.input_node_dim_mp = emb_dim

        ### GNN to generate node embeddings
        self.gnn_node = MessagePasssing_Module(num_layer, input_node_dim=self.input_node_dim_mp,
                                               input_edge_dim=self.input_edge_dim, emb_dim=emb_dim,
                                               extraPE_dim=extraPE_dim, extraPE_method=self.extraPE_method,
                                               JK=JK, drop_ratio=drop_ratio, residual=residual,
                                               gnn_type=gnn_type)

        ### Pooling function to generate entire-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim),
                                                                    torch.nn.BatchNorm1d(2 * emb_dim), torch.nn.ReLU(),
                                                                    torch.nn.Linear(2 * emb_dim, 1)))
        else:
            raise ValueError("Invalid graph pooling type.")

        # NOTE(review): the pre/post Linear layers are applied back to back with
        # no activation in between, so stacking several of them is still a
        # single affine map. Confirm this is intended.
        self.graph_pred_post_linear_list = torch.nn.ModuleList()

        for i in range(num_post_fnn_layers - 1):
            self.graph_pred_post_linear_list.append(torch.nn.Linear(emb_dim, emb_dim))
        self.graph_pred_post_linear_list.append(torch.nn.Linear(emb_dim, self.num_classes))

    def forward(self, batched_data, return_logits=False):
        """Classify every graph in ``batched_data``.

        Assumes ``batched_data.x`` holds the ``input_spec_fts_dim`` specific
        features first, followed by the coordinates. Note: this overwrites
        ``batched_data.pos`` (all columns after the specific features) and
        ``batched_data.x`` (the pre-processed features) in place.

        Args:
            batched_data: a PyG batch with x, edge_index, edge_attr and batch.
            return_logits (bool): return the raw class scores instead of softmax
                probabilities. Use this with losses such as CrossEntropyLoss
                that expect logits.

        Returns:
            Tensor of shape (num_graphs, num_classes): softmax probabilities, or
            logits when ``return_logits`` is True.
        """
        input_x = batched_data.x  # here we can split the x

        batched_data.pos = input_x[:, self.input_spec_fts_dim:]  # this will keep pos embeddings
        prep_x = input_x[:, :self.input_node_dim]

        # preprocessing node features (only).
        for pre_linear in self.graph_pred_pre_linear_list:
            prep_x = pre_linear(prep_x)

        batched_data.x = prep_x

        h_node = self.gnn_node(batched_data)
        output = self.pool(h_node, batched_data.batch)

        # postprocessing graph embeddings (only).
        for post_linear in self.graph_pred_post_linear_list:
            output = post_linear(output)

        if return_logits:
            return output
        return F.softmax(output, dim=1)

    def __str__(self):
        return self.gnn_type + f"-model-{self.pos_kwd}"

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Number of training epochs used by train_model for every model below.
epochs: int = 75
# Multi-class cross-entropy loss; torch.nn.CrossEntropyLoss expects
# unnormalised scores (logits) and integer class targets.
multicls_criterion = torch.nn.CrossEntropyLoss()

In [None]:
def import_dataset(name, transform=None, pre_transform=None, pre_filter=None, root='../dataset/'):
    """Construct the JetsGraphsDataset called ``name``, located under ``root``.

    Args:
        name (str): dataset name passed to JetsGraphsDataset.
        transform, pre_transform, pre_filter: forwarded unchanged to JetsGraphsDataset.
        root (str): dataset root directory (defaults to the original '../dataset/').

    Returns:
        The JetsGraphsDataset instance.
    """
    return JetsGraphsDataset(root, name=name, transform=transform,
                             pre_transform=pre_transform, pre_filter=pre_filter)

In [None]:
def create_loaders(dataset, batch_size=4, split=(0.7, 0.2, 0.1), seed=42):
    """Randomly split ``dataset`` into train/validation/test DataLoaders.

    Args:
        dataset: an indexable PyG dataset (supports ``dataset[list_of_indices]``).
        batch_size (int): graphs per batch for all three loaders.
        split (sequence of 3 floats): train/valid/test fractions passed to random_split.
        seed (int): seed of the generator, so the split is reproducible.

    Returns:
        (train_dataloader, valid_dataloader, test_dataloader). Only the train
        loader shuffles.
    """
    # random splitting dataset (of indices, so the dataset itself is not copied)
    generator = torch.Generator().manual_seed(seed)
    train_inx, valid_inx, test_inx = random_split(range(len(dataset)), list(split), generator=generator)

    train_dataloader = DataLoader(dataset[list(train_inx)], batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(dataset[list(valid_inx)], batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(dataset[list(test_inx)], batch_size=batch_size, shuffle=False)

    return train_dataloader, valid_dataloader, test_dataloader

In [None]:
def train(model, device, loader, optimizer):
    """Run one training epoch over ``loader``.

    Args:
        model: a model whose forward returns per-class softmax probabilities
            (as GNN.forward does by default).
        device: device the batches are moved to.
        loader: iterable of PyG batches with ``x`` and ``y``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        float: mean loss over the batches that were optimised (nan if none were).
    """
    model.train()

    loss_accum = 0.0
    num_steps = 0
    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        # A single-node batch is skipped: BatchNorm1d cannot compute training
        # statistics from one row.
        if batch.x.shape[0] == 1:
            continue

        output = model(batch)
        optimizer.zero_grad()
        # The model outputs softmax probabilities, while CrossEntropyLoss expects
        # logits. log(p) is a valid logit vector whose log_softmax is itself
        # (p sums to 1), so this yields the true cross-entropy instead of a
        # doubly-softmaxed loss. The clamp avoids log(0).
        log_probs = torch.log(output.clamp_min(torch.finfo(output.dtype).tiny))
        loss = multicls_criterion(log_probs, batch.y.view(-1).to(torch.int64))
        loss.backward()
        optimizer.step()

        loss_accum += loss.item()
        num_steps += 1

    avg_loss = loss_accum / num_steps if num_steps else float("nan")
    print('Average training loss: {}'.format(avg_loss))
    return avg_loss

In [None]:
def evaluate(model, device, loader, evaluator="roauc", num_classes=None):
    """Score ``model`` on all of ``loader`` with a macro-averaged metric.

    Args:
        model: module returning a (num_graphs, num_classes) score tensor.
        device: device the batches are moved to.
        loader: iterable of PyG batches carrying targets in ``y``.
        evaluator (str): "roauc" (MulticlassAUROC) or "acc" (MulticlassAccuracy).
        num_classes (int | None): number of classes. Inferred from the width
            of the model output when None.

    Returns:
        float: the metric value.

    Raises:
        ValueError: unknown ``evaluator`` or a loader that yields no batches.

    Note: leaves ``model`` in eval mode.
    """
    if evaluator == "roauc":
        metric_cls, metric_kwargs = MulticlassAUROC, {"average": "macro", "thresholds": None}
    elif evaluator == "acc":
        metric_cls, metric_kwargs = MulticlassAccuracy, {"average": "macro"}
    else:
        raise ValueError("Unknown evaluator {!r}; expected 'roauc' or 'acc'.".format(evaluator))

    model.eval()

    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred_chunks.append(model(batch).detach().cpu().float())
            target_chunks.append(batch.y.view(-1).cpu())

    if not pred_chunks:
        raise ValueError("Cannot evaluate on a loader that yields no batches.")

    preds = torch.cat(pred_chunks)
    targets = torch.cat(target_chunks).to(torch.int64)
    if num_classes is None:
        num_classes = preds.size(1)

    metric = metric_cls(num_classes=num_classes, **metric_kwargs)
    return metric(preds, targets).item()

In [None]:
def _latest_checkpoint(checkpoints_path, run_name):
    """Return the path of the highest-epoch '{run_name}-{epoch}.pt' file in checkpoints_path, or None."""
    prefix, suffix = f"{run_name}-", ".pt"
    best_epoch, best_path = None, None
    for filename in os.listdir(checkpoints_path):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        epoch_str = filename[len(prefix):-len(suffix)]
        # Exact match on the epoch number, so other runs whose names merely
        # start with run_name are ignored.
        if not epoch_str.isdigit():
            continue
        epoch = int(epoch_str)
        if best_epoch is None or epoch > best_epoch:
            best_epoch, best_path = epoch, os.path.join(checkpoints_path, filename)
    return best_path


def train_model(model, optimizer, dataset, checkpoints_path="../models", run_name=None):
    """Train ``model`` up to the global ``epochs``, resuming from its latest checkpoint.

    After every epoch, the model/optimizer state is saved as
    '{run_name}-{epoch}.pt' and the previous epoch's file is deleted. Train and
    validation ROC-AUC are printed each epoch, and the test ROC-AUC at the end.

    Args:
        model: the model to train (its ``str()`` is the default run name).
        optimizer: optimizer over ``model``'s parameters.
        dataset: dataset split into loaders by ``create_loaders``.
        checkpoints_path (str): directory for checkpoints (created if missing).
        run_name (str | None): checkpoint file prefix; defaults to ``str(model)``.
            Pass distinct names for models whose ``str()`` coincides.

    Raises:
        RuntimeError: a checkpoint exists for ``run_name`` but does not fit ``model``.
    """
    if run_name is None:
        run_name = str(model)
    os.makedirs(checkpoints_path, exist_ok=True)

    # create loaders
    train_dataloader, valid_dataloader, test_dataloader = create_loaders(dataset)

    starting_epoch = 1
    resume_path = _latest_checkpoint(checkpoints_path, run_name)
    if resume_path is not None:
        # map_location lets a GPU-saved checkpoint be resumed on CPU (and vice versa).
        checkpoint = torch.load(resume_path, map_location=device)
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
        except RuntimeError as err:
            raise RuntimeError(
                f"Checkpoint {resume_path} does not match this model's architecture; "
                "pass a distinct run_name or remove the stale checkpoint."
            ) from err
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        starting_epoch = checkpoint['epoch'] + 1

    for epoch in range(starting_epoch, epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_dataloader, optimizer)

        # save checkpoint of current epoch
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, os.path.join(checkpoints_path, f"{run_name}-{epoch}.pt"))

        # delete checkpoint of previous epoch (it may not exist, e.g. when it was removed manually)
        previous_path = os.path.join(checkpoints_path, f"{run_name}-{epoch - 1}.pt")
        if epoch > 1 and os.path.exists(previous_path):
            os.remove(previous_path)

        print("Evaluating...")
        train_perf_roauc = evaluate(model, device, train_dataloader)
        valid_perf_roauc = evaluate(model, device, valid_dataloader)
        print('ROAUC scores: ', {'Train': train_perf_roauc, 'Validation': valid_perf_roauc})

    print('\nFinished training!')
    print('\nROAUC Test score: {}'.format(evaluate(model, device, test_dataloader)))

## Training PointNet Conv based GNN model

In [None]:
# train_model needs the dataset to build its train/valid/test loaders.
jets_dataset = import_dataset(name="QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272")

pointnet_model = GNN(num_classes = 2, num_layer = 2,num_post_fnn_layers=2,input_edge_dim = 1,num_coords=3, 
                 input_spec_fts_dim=3,gnn_type = 'pointnet', emb_dim = 300, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(pointnet_model.parameters(), lr=1e-3)

train_model(pointnet_model, optimizer, jets_dataset)

## Training of GCN based model

### Training with GPE (all x,y,z coords)

In [None]:
# train_model needs the dataset to build its train/valid/test loaders.
jets_dataset = import_dataset(name="QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272")

gcn_model = GNN(num_classes = 2, num_layer = 2,num_post_fnn_layers=2,hasPos=True,input_edge_dim = 1,num_coords=3, 
                input_spec_fts_dim=3, gnn_type = 'gcn', emb_dim = 300, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(gcn_model.parameters(), lr=1e-3)

train_model(gcn_model, optimizer, jets_dataset)

### Training with GPE (only x,y coords)

In [None]:
# train_model needs the dataset to build its train/valid/test loaders.
jets_dataset = import_dataset(name="QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272")

# NOTE(review): str(gcn_model) is "gcn-model-hasPos", the same as the num_coords=3
# GCN model, so train_model will look for that model's checkpoints in ../models.
# Clear them (or give this run its own checkpoint name) before running.
gcn_model = GNN(num_classes = 2, num_layer = 2,num_post_fnn_layers=2,hasPos=True,input_edge_dim = 1,num_coords=2, 
                input_spec_fts_dim=3, gnn_type = 'gcn', emb_dim = 300, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(gcn_model.parameters(), lr=1e-3)

train_model(gcn_model, optimizer, jets_dataset)

### Training without GPE

In [None]:
# train_model needs the dataset to build its train/valid/test loaders.
jets_dataset = import_dataset(name="QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272")

gcn_model = GNN(num_classes = 2, num_layer = 2,num_post_fnn_layers=2,hasPos=False,input_edge_dim = 1, 
            input_spec_fts_dim=3, gnn_type = 'gcn', emb_dim = 300, drop_ratio = 0.3).to(device)
optimizer = optim.Adam(gcn_model.parameters(), lr=1e-3)

train_model(gcn_model, optimizer, jets_dataset)

### Training with deeper model (with 5 GCN layers)

In [None]:
# train_model needs the dataset to build its train/valid/test loaders.
jets_dataset = import_dataset(name="QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272")

# NOTE(review): str(gcn_model) is "gcn-model-hasPos", shared with the earlier
# hasPos GCN models, so train_model will look for their checkpoints in ../models.
# Clear them (or give this run its own checkpoint name) before running.
gcn_model = GNN(num_classes = 2, num_layer = 5,num_post_fnn_layers=2,hasPos=True,input_edge_dim = 1,num_coords=2, 
                input_spec_fts_dim=3, gnn_type = 'gcn', emb_dim = 100, drop_ratio = 0.3, JK='sum').to(device)
optimizer = optim.Adam(gcn_model.parameters(), lr=1e-3)

train_model(gcn_model, optimizer, jets_dataset)

### Training GPS model

In [None]:
# Dataset for GPS: every graph gets a 20-step random-walk positional encoding
# stored under the `extraPE` attribute, which MessagePasssing_Module reads.
transform = T.AddRandomWalkPE(walk_length=20, attr_name='extraPE')
jets_dataset = import_dataset(
    name="QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272",
    transform=transform,
)

GPS_model = GNN(
    num_classes=2,
    num_layer=5,
    num_pre_fnn_layers=1,
    num_post_fnn_layers=2,
    hasPos=True,
    input_edge_dim=1,
    num_coords=2,
    input_spec_fts_dim=3,
    gnn_type='gps',
    emb_dim=300,
    extraPE_dim=20,  # width of the random-walk encoding, i.e. walk_length above
    drop_ratio=0.3,
).to(device)
optimizer = optim.Adam(GPS_model.parameters(), lr=1e-3)

train_model(GPS_model, optimizer, jets_dataset)