This notebook adapts the following example to work with my input files and to use GraphSAGE:

https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing

Training a GNN for graph classification usually follows a simple recipe:
1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (readout layer)
3. Train a final classifier on the graph embedding


In [None]:
from datetime import datetime
import math
import os.path as osp
import sys
import time
import torch
import torch.nn.functional as F
from torch.nn import Linear
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn import datasets, metrics, model_selection, svm
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.colors import ListedColormap
import numpy as np
from scipy.signal import find_peaks
import time
import csv
from tqdm import tqdm
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
# from torch_geometric.nn import GCNConv
import torch_geometric.transforms as T
# from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.nn import GraphSAGE
# import dgl
import torch.nn as nn
# import itertools
# import scipy.sparse as sp
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
import seaborn as sns  # For a nicer looking matrix


import pandas as pd

In [None]:
def normaliseX(xPos):
    """Rescale x positions: shift by +350, then divide by 100.

    Works on scalars and NumPy arrays alike; returns the rescaled value(s).
    """
    offset, scale = 350, 100
    return (xPos + offset) / scale

In [None]:
def normaliseY(yPos):
    """Rescale y positions: shift by +650, then divide by 100.

    Works on scalars and NumPy arrays alike; returns the rescaled value(s).
    """
    offset, scale = 650, 100
    return (yPos + offset) / scale

In [None]:
def normaliseZ(zPos):
    """Rescale z positions by dividing by 100 (no offset applied)."""
    scale = 100
    return zPos / scale

In [None]:
def normaliseAdc(adcCounts):
    """Rescale ADC counts by dividing by 10."""
    scale = 10
    return adcCounts / scale

### This is the function that loops over the lines of the input csv file.
It calls *process_event*, which unpacks one line of the file (each line = 1 cluster = 1 graph) and returns an "eventArray" containing the node features, the edge source and edge destination node IDs, and the graph label (0 = shower, 1 = track). Each row of the csv file (without its first column) is passed to it via "data", together with empty lists for x, y, z, adcCounts, truePdg and trueMCIndex; *process_event* does not currently use these lists.

In [None]:
def process_file(input_file):
    """Read a csv of clusters and build a class-balanced list of PyG graphs.

    Each csv row is one cluster. Its first column is skipped and the remaining
    fields are handed to ``process_event``. Every accepted cluster becomes a
    ``torch_geometric.data.Data`` object with node features ``x``, an
    undirected ``edge_index`` and a graph label ``y`` (0 = shower, 1 = track).

    Clusters of a class whose quota is already full are skipped. Reading stops
    as soon as both quotas are full. The quotas are the notebook globals
    ``n_showers_in_sample_max`` and ``n_tracks_in_sample_max``.

    Args:
        input_file: path to the csv file.

    Returns:
        list of ``Data`` graphs.
    """
    n_tracks_in_sample = 0
    n_showers_in_sample = 0

    # Count rows for the progress bar without loading the whole file into memory.
    with open(input_file, 'r', newline='') as f:
        num_events = sum(1 for _ in f)
    print("num_events = ", num_events)

    dataset = []
    # Scratch lists forwarded to process_event (which does not currently use them).
    xPos, yPos, zPos = [], [], []
    adcCounts, truePdg, trueMCIndex = [], [], []

    # newline='' is what the csv module documentation requires for reader input.
    with open(input_file, 'r', newline='') as f:
        reader = csv.reader(f)
        for i, row in enumerate(tqdm(reader, desc="Test", miniters=100, total=num_events)):
            # Stop before parsing (and plotting) another cluster once both quotas are full.
            if (n_tracks_in_sample >= n_tracks_in_sample_max) and (n_showers_in_sample >= n_showers_in_sample_max):
                return dataset

            print("i = ", i, " -----------------------------------------------------------------")
            eventArray = process_event(row[1:], f"{i}", xPos, yPos, zPos, adcCounts, truePdg, trueMCIndex)

            # process_event signals a rejected cluster with [] (a bare `return`
            # gives None); `not` handles both, whereas `== []` let None through.
            if not eventArray:
                continue

            features, edgeStart, edgeEnd, label = eventArray
            if len(edgeStart) == 0:
                continue

            print("n_tracks_in_sample = ", n_tracks_in_sample, " n_showers_in_sample = ", n_showers_in_sample, " n_tracks_in_sample_max  = ", n_tracks_in_sample_max, " n_showers_in_sample_max = ", n_showers_in_sample_max)
            if label == 0 and n_showers_in_sample >= n_showers_in_sample_max:
                continue
            if label == 1 and n_tracks_in_sample >= n_tracks_in_sample_max:
                continue

            edge_index = torch.tensor(np.stack([edgeStart, edgeEnd]), dtype=torch.long)
            # process_event emits each hit pair once (source < destination), but
            # message passing (e.g. SAGEConv) only flows source -> target, so
            # add the reverse direction to make the graph undirected.
            edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

            graph = Data(x=torch.tensor(features, dtype=torch.float), edge_index=edge_index)
            graph.y = torch.tensor([label])
            dataset.append(graph)

            if label == 0:
                n_showers_in_sample += 1
            else:
                n_tracks_in_sample += 1
            print("len(dataset)= ", len(dataset))

    return dataset

In [None]:
#0 = shower, 1 = track, -999 = other
def find_true_particle_label(hpdg):
    """Classify a cluster as shower (0), track (1) or other (-999) from its hits' true PDG codes.

    The decision is made on the PDG code that occurs most often among the hits.
    Codes are compared by absolute value, so particles and antiparticles get
    the same label (e.g. e-/e+, mu-/mu+, pi-/pi+, K-/K+).

    Args:
        hpdg: sequence or array of per-hit true PDG codes (ints or floats).

    Returns:
        int: 0 if the dominant particle is a photon or electron; 1 if it is a
        muon, proton, charged kaon or charged pion; -999 otherwise, including
        when ``hpdg`` is empty.
    """
    shower_pdgs = {11, 22}
    track_pdgs = {13, 211, 321, 2212}

    pdgs = np.abs(np.asarray(hpdg, dtype=float).ravel())
    if pdgs.size == 0:
        return -999

    unique_values, counts = np.unique(pdgs, return_counts=True)
    # np.argmax returns the first maximum, so ties resolve deterministically
    # to the smallest |PDG| (np.unique returns values in ascending order).
    most_frequent_pdg = int(unique_values[np.argmax(counts)])

    if most_frequent_pdg in shower_pdgs:
        return 0
    if most_frequent_pdg in track_pdgs:
        return 1
    return -999

In [None]:
def visualise_input(x,y,z,adc,mc,pdg):
    """Scatter-plot one cluster's hits in the X-Y, X-Z and Y-Z projections.

    Points are coloured by ``pdg`` with the 'tab20' colormap. ``adc`` and
    ``mc`` are accepted but not used.
    """
    panels = (
        (x, y, "X-Y Plot", "X", "Y"),
        (x, z, "X-Z Plot", "X", "Z"),
        (y, z, "Y-Z Plot", "Y", "Z"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (horizontal, vertical, title, xlabel, ylabel) in zip(axes, panels):
        ax.scatter(horizontal, vertical, c=pdg, cmap='tab20')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_loss_evolution(epochs, training_loss, test_loss, accuracy, ylim1, ylim2):
    """Plot per-epoch losses (left axis, log scale) and accuracy (right axis).

    Args:
        epochs: x values.
        training_loss, test_loss: loss series drawn on the left axis.
        accuracy: accuracy series drawn on a twin right axis.
        ylim1, ylim2: upper y-limits for the loss and accuracy axes. A value
            <= 0 leaves that axis autoscaled; otherwise its lower limit is 0.0001.
    """
    fig, loss_ax = plt.subplots()

    loss_series = (
        (training_loss, 'Training Loss', 'b'),
        (test_loss, 'Validation Loss', 'g'),
    )
    for values, label, colour in loss_series:
        loss_ax.plot(epochs, values, label=label, color=colour)

    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss', color='b')
    loss_ax.tick_params('y', colors='b')
    loss_ax.legend(loc='upper left')
    loss_ax.set_title('Training and Validation Loss Over Epochs')

    acc_ax = loss_ax.twinx()
    acc_ax.plot(epochs, accuracy, label='Accuracy', color='r')
    acc_ax.set_ylabel('Accuracy Scale', color='r')
    acc_ax.tick_params('y', colors='r')
    acc_ax.legend(loc='upper right')

    for ax, upper in ((loss_ax, ylim1), (acc_ax, ylim2)):
        if upper > 0:
            ax.set_ylim(0.0001, upper)
    loss_ax.set_yscale('log')

    plt.draw()
    plt.show()

In [None]:
def process_event(data, event, xPos, yPos, zPos, adcCounts, truePdg, trueMCIndex,
                  *, visualise=True, include_truth_features=False):
    """Unpack one csv row (one cluster) into node features, edges and a graph label.

    ``data`` is the csv row without its first column. ``data[0]`` is the number
    of hits. It is followed by 6 fields per hit (true MC index, true PDG, x, y,
    z, adc) and exactly one more trailing field, which is ignored.

    Two hits are joined by an edge (source < destination) when their 3-D
    distance is at most ``max_hit_distance_radius``.

    Reads the notebook globals normalise_positions, normalise_adcs,
    n_hits_in_graph_min, n_hits_in_graph_max, minPurity and
    max_hit_distance_radius.

    Args:
        data: list of string fields (layout above).
        event, xPos, yPos, zPos, adcCounts, truePdg, trueMCIndex: unused; kept
            so existing callers keep working.
        visualise: if True, plot the cluster with ``visualise_input`` once it
            passes the purity cut.
        include_truth_features: if True, append the true MC index and true PDG
            as node-feature columns 4 and 5. Off by default because the graph
            label is computed from the PDG codes, so feeding them to the model
            leaks the answer.

    Returns:
        ``[features, edgeStart, edgeEnd, label]`` on success. ``features`` is an
        (n_hits, 4) float array of (x, y, z, adc), or (n_hits, 6) with truth
        features. ``edgeStart`` / ``edgeEnd`` are int64 node-index arrays.
        ``label`` is 0 (shower) or 1 (track). Returns ``[]`` if the cluster is
        rejected.
    """
    nh_coords = 6
    n_hits = int(data[0])
    h_start, h_finish = 1, nh_coords * n_hits + 1
    length = len(data[h_start:-1])

    if length != (n_hits * nh_coords):
        print('Missing information in input file')
        print(n_hits, length)
        # Previously a bare `return` (None), which crashed callers that index the result.
        return []

    hmc = np.array(data[h_start:h_finish:nh_coords], dtype=int)            # true MC index
    hpdg = np.array(data[h_start + 1:h_finish:nh_coords], dtype=float)     # true PDG code
    hx = np.array(data[h_start + 2:h_finish:nh_coords], dtype=float)       # x coord
    hy = np.array(data[h_start + 3:h_finish:nh_coords], dtype=float)       # y coord
    hz = np.array(data[h_start + 4:h_finish:nh_coords], dtype=float)       # z coord
    hadc = np.array(data[h_start + 5:h_finish:nh_coords], dtype=float)     # adc

    if normalise_positions == 1:
        hx = normaliseX(hx)
        hy = normaliseY(hy)
        hz = normaliseZ(hz)

    if normalise_adcs == 1:
        hadc = normaliseAdc(hadc)

    print("n_hits = ", n_hits)

    if (n_hits > n_hits_in_graph_max) or (n_hits < n_hits_in_graph_min):
        print("n hits outside allowed interval")
        return []

    # Reject clusters where the dominant MC particle contributes less than minPurity of the hits.
    unique_values, counts = np.unique(hmc, return_counts=True)
    if len(unique_values) >= 2 and counts.max() / counts.sum() < minPurity:
        print("The most contributing MC true particle in this cluster contributes less than minPurity hits")
        return []

    if visualise:
        visualise_input(hx, hy, hz, hadc, hmc, hpdg)

    true_particle_label = find_true_particle_label(hpdg)
    print("true_particle_label = ", true_particle_label)
    if true_particle_label not in (0, 1):
        print("true_particle_label of wrong type = ", true_particle_label)
        return []

    # One row per hit. Previously the array was truncated to the number of
    # *edges*, which dropped nodes that edge indices still referenced.
    columns = [hx, hy, hz, hadc]
    if include_truth_features:
        columns += [hmc, hpdg]
    features = np.column_stack(columns).astype(float)

    edgeStart, edgeEnd = _radius_edges(hx, hy, hz, max_hit_distance_radius)
    return [features, edgeStart, edgeEnd, true_particle_label]


def _radius_edges(hx, hy, hz, radius):
    """Return (src, dst) int64 arrays of all hit pairs src < dst with 3-D distance <= radius."""
    src_chunks = [np.empty(0, dtype=np.int64)]
    dst_chunks = [np.empty(0, dtype=np.int64)]
    for hit1 in range(len(hx) - 1):
        # Distances from hit1 to every later hit, vectorised over the later hits.
        # This needs O(n_hits) memory instead of an n_hits**2 preallocation.
        dx = hx[hit1] - hx[hit1 + 1:]
        dy = hy[hit1] - hy[hit1 + 1:]
        dz = hz[hit1] - hz[hit1 + 1:]
        dist = np.sqrt(dx**2 + dy**2 + dz**2)
        near = np.flatnonzero(dist <= radius).astype(np.int64) + (hit1 + 1)
        src_chunks.append(np.full(near.size, hit1, dtype=np.int64))
        dst_chunks.append(near)
    return np.concatenate(src_chunks), np.concatenate(dst_chunks)

In [None]:
class FinalClassifier(torch.nn.Module):
    """Graph-level classification head: dropout (p=0.5) followed by one linear layer.

    Args:
        hidden_channels: size of the incoming graph embedding.
        num_classes: number of output logits.
    """

    def __init__(self, hidden_channels, num_classes):
        super().__init__()
        # The attribute name `lin` is kept so state_dict keys stay unchanged.
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) class logits of shape (..., num_classes)."""
        dropped = F.dropout(x, p=0.5, training=self.training)
        return self.lin(dropped)
        

In [None]:
# Readout layer
class ReadoutLayer(torch.nn.Module):
    """Parameter-free readout that averages node embeddings within each graph.

    This is a thin wrapper around torch_geometric's ``global_mean_pool``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        """Mean-pool node features ``x`` per graph, using the graph-assignment vector ``batch``."""
        pooled = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        return pooled

In [None]:
def train_gnn(dataset):
    """Train a GraphSAGE track/shower graph classifier on ``dataset`` and plot its learning curves.

    The dataset is split at random into a test set of
    ``int(len(dataset) * test_train_ratio)`` graphs and a training set with the
    rest. The model chain is GraphSAGE node embeddings -> mean-pool readout ->
    dropout + linear classifier. It is trained with Adam and cross-entropy for
    every epoch in the notebook global ``epochs``.

    Reads the notebook globals test_train_ratio, model_hidden_channels,
    model_num_layers, learning_rate and epochs.

    Args:
        dataset: non-empty list of torch_geometric ``Data`` graphs carrying
            ``x``, ``edge_index`` and a graph label ``y`` (0 = shower, 1 = track).

    Returns:
        dict with the trained ``graphsage``, ``readout`` and ``classifier``
        modules, plus the ``test_true_labels`` / ``test_predicted_labels``
        tensors from the final epoch.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("train_gnn needs a non-empty dataset")

    test_graphs_size = int(len(dataset) * test_train_ratio)
    train_graphs_size = len(dataset) - test_graphs_size
    print("train_graphs_size = ", train_graphs_size, " test_graphs_size = ", test_graphs_size, " len(dataset)= ", len(dataset))

    # Split dataset into train and test samples of graphs randomly.
    indices = np.random.choice(len(dataset), test_graphs_size, replace=False)
    is_train = np.ones(len(dataset), dtype=bool)
    is_train[indices] = False
    test_dataset = [dataset[i] for i in indices]
    train_dataset = [graph for graph, keep in zip(dataset, is_train) if keep]

    print(f'Number of training graphs: {len(train_dataset)}')
    print(f'Number of test graphs: {len(test_dataset)}')

    train_loader = DataLoader(train_dataset, batch_size=180, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=180, shuffle=False)

    for step, data in enumerate(train_loader):
        print(f'Step {step + 1}:')
        print('=======')
        print(f'Number of graphs in the current batch: {data.num_graphs}')
        print(data)
        print()

    device = torch.device("cpu")
    dataset_num_classes = 2  # Tracks and showers

    GraphSAGE_model = GraphSAGE(
        dataset[0].num_node_features,
        hidden_channels=model_hidden_channels,
        num_layers=model_num_layers,
    ).to(device)
    ReadoutLayer_model = ReadoutLayer().to(device)
    FinalClassifier_model = FinalClassifier(
        hidden_channels=model_hidden_channels,
        num_classes=dataset_num_classes,
    ).to(device)
    modules = (GraphSAGE_model, ReadoutLayer_model, FinalClassifier_model)

    # Set up loss and optimizer.
    model_parameters = [p for module in modules for p in module.parameters()]
    optimizer = torch.optim.Adam(model_parameters, lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    def _logits(batch_data):
        """Run the full model on one mini-batch and return the class logits."""
        # `batch` must be passed by keyword: GraphSAGE's third positional
        # argument is `edge_weight`, not the batch vector.
        h = GraphSAGE_model(batch_data.x, batch_data.edge_index, batch=batch_data.batch)
        return FinalClassifier_model(ReadoutLayer_model(h, batch_data.batch))

    n_train = len(train_loader.dataset)
    n_test = len(test_loader.dataset)

    training_loss_array = []
    test_loss_array = []
    accuracy_array = []
    epochs_array = []
    test_true_labels = torch.empty(0, dtype=torch.long)
    test_predicted_labels = torch.empty(0, dtype=torch.long)

    for e in epochs:
        print("epoch = ", e)

        # Train
        for module in modules:
            module.train()
        train_correct = 0
        train_loss_sum = 0.0
        for data in train_loader:
            optimizer.zero_grad()  # Clear gradients.
            logits = _logits(data)
            loss = criterion(logits, data.y)
            loss.backward()  # Derive gradients.
            optimizer.step()  # Update parameters based on gradients.
            # Weight by graph count so the epoch loss is a per-graph mean, not the last batch's loss.
            train_loss_sum += loss.item() * data.num_graphs
            train_correct += int((logits.argmax(dim=1) == data.y).sum())
        train_acc = train_correct / n_train if n_train else float('nan')
        training_loss = train_loss_sum / n_train if n_train else float('nan')

        # Test: the whole evaluation loop runs under no_grad, not just the .eval() calls.
        for module in modules:
            module.eval()
        test_correct = 0
        test_loss_sum = 0.0
        true_chunks, pred_chunks = [], []
        with torch.no_grad():
            for data in test_loader:
                logits = _logits(data)
                test_loss_sum += criterion(logits, data.y).item() * data.num_graphs
                pred = logits.argmax(dim=1)  # Use the class with highest probability.
                test_correct += int((pred == data.y).sum())
                true_chunks.append(data.y)
                pred_chunks.append(pred)
        test_acc = test_correct / n_test if n_test else float('nan')
        test_loss = test_loss_sum / n_test if n_test else float('nan')
        if true_chunks:
            test_true_labels = torch.cat(true_chunks)
            test_predicted_labels = torch.cat(pred_chunks)

        training_loss_array.append(training_loss)
        test_loss_array.append(test_loss)
        accuracy_array.append(test_acc)
        epochs_array.append(e)

        print(f'Epoch: {e:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

    plot_loss_evolution(epochs_array, training_loss_array, test_loss_array, accuracy_array, -999, -999)

    return {
        "graphsage": GraphSAGE_model,
        "readout": ReadoutLayer_model,
        "classifier": FinalClassifier_model,
        "test_true_labels": test_true_labels,
        "test_predicted_labels": test_predicted_labels,
    }

## Run the GNN! (Here are the commands to prepare samples and run the training)

### Parameters (modify based on your needs!):
**prepare_samples_again** : if the sample is already loaded in memory in the notebook, I don't want to accidentally re-prepare it (as it's time consuming), so this flag is checked later in the notebook before *process_file* is called to build the dataset.

**n_hits_in_graph_min, n_hits_in_graph_max** : minimum and maximum number of hits for cluster to be in input sample for training

**minPurity** : fraction of the hits in the cluster that are contributed by the most contributing MC particle. Only tracks and showers with a purity of at least minPurity are selected.

**test_train_ratio** (fraction of the graphs held out as the test sample), **epochs** and **learning_rate** are self-describing.

***model_hidden_channels*** : the dimensionality of the hidden feature space in the GraphSAGE model

***model_num_layers*** : number of message-passing layers in the GraphSAGE model

***normalise_adcs***, ***normalise_positions*** : if 1, rescale the ADC counts / hit positions to a "small number" range. The values are not normalised to between 0 and 1, but they are broadly in the 1–10 range.

***plot_input_distributions*** : if true, plot the input node feature distributions (i.e. x, y, z, adc) plus some other quantities. (Note: no such flag is currently defined in the parameters cell below.)

***max_hit_distance_radius*** : only create an edge between two hits that are at most this distance apart. Units change based on normalisation.


In [None]:
# Parameters (edit to suit your run)

# Sample preparation
prepare_samples_again = True

# Cluster selection
n_hits_in_graph_min = 200
n_hits_in_graph_max = 12000
minPurity = 0.95

# Maximum number of graphs of each class kept in the sample
n_tracks_in_sample_max = 200
n_showers_in_sample_max = 200

# Training
test_train_ratio = 0.1
epochs = range(1, 100)  # epochs 1..99
learning_rate = 0.01
model_hidden_channels = 16
model_num_layers = 3

# Input normalisation switches (1 = on)
normalise_adcs = 1
normalise_positions = 1

# Edge radius. Normalised positions are divided by 100, so 0.1 corresponds
# to the un-normalised radius of 10 (10 cm if raw positions are in cm).
max_hit_distance_radius = 0.1 if normalise_positions else 10

In [None]:
if prepare_samples_again:
    # Alternative input files:
    #   training_out_test_feb_nue_43.csv
    #   training_out_test_feb_nue_showers.csv
    #   training_out_tsId.csv
    #   InputForTrackShowerID_GNN_nu_and_nue.csv  (huge file: 100 files of nu and 100 files of nue)
    file_path = "InputForTrackShowerID_GNN_nu_and_nue_20files.csv"  # large file (10 files of nu and 10 files of nue)
    output_directory = "outputDir"
    dataset = process_file(file_path)

In [None]:
# Train the GNN on the prepared dataset
n_graphs = len(dataset)
print("len(dataset) = ", n_graphs)
train_gnn(dataset)