In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [4]:
import copy
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import (
    global_mean_pool, global_max_pool, global_add_pool,
    GATConv, EdgeConv, SAGEConv, SAGPooling
)
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.transforms import BaseTransform
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, confusion_matrix, classification_report, roc_auc_score)
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

In [5]:
def extract_chunk_numbers(f):
    """Sort key for file names containing ``chunk_<start>_<end>``.

    Returns ``(start, end)`` as ints taken from the first match in ``f``; names
    without such a pattern map to ``(inf, inf)`` so they sort after all chunks.
    """
    match = re.search(r'chunk_(\d+)_(\d+)', f)
    if match is None:
        return (float('inf'), float('inf'))
    start, end = match.groups()
    return (int(start), int(end))

In [6]:
class RandomPhiShift(BaseTransform):
    """Augmentation that shifts every node's phi coordinate by one random angle.

    A single shift drawn uniformly from ``[-dphi, dphi)`` is added to
    ``pos[:, 1]`` and the result is wrapped to ``[-pi, pi]`` via atan2.
    Only ``pos`` is replaced; ``x`` and all other attributes are untouched.
    The input object is never modified: a shallow copy carrying a new ``pos``
    tensor is returned.
    """
    def __init__(self, dphi=0.2):
        self.dphi = dphi

    def __call__(self, data):
        pos = data.pos
        if pos.size(0) == 0:
            return data
        # Mutating data.pos in place would permanently alter graphs that a dataset
        # keeps in memory and re-transforms on every access, so the shifts would
        # accumulate across epochs. Work on a shallow copy with a fresh pos tensor.
        data = copy.copy(data)
        # Draw on pos's device/dtype: a CPU tensor of shape (1,) cannot be added
        # to a CUDA tensor.
        shift = (torch.rand(1, device=pos.device, dtype=pos.dtype) * 2 - 1) * self.dphi
        shifted_phi = pos[:, 1] + shift
        new_pos = pos.clone()
        new_pos[:, 1] = torch.atan2(torch.sin(shifted_phi), torch.cos(shifted_phi))
        data.pos = new_pos
        return data

In [7]:
class DynamicHybridGraph(nn.Module):
    """Rebuild a kNN graph per example from combined geometric and feature distances.

    Within each graph of the batch, node i is connected to the ``k`` nodes j != i
    with the smallest ``alpha * ||pos_i - pos_j|| + beta * ||x_i - x_j||``.
    Edge features are ``[d_eta, d_phi (wrapped to [-pi, pi]), d_R, ||x_j - x_i||]``
    with j the neighbour and i the centre node.

    NOTE(review): row 0 of the returned edge_index is the centre node and row 1
    its neighbour. Under PyG's default ``source_to_target`` flow, messages
    therefore travel from a node to its k nearest neighbours. Confirm this
    direction is intended.
    """
    def __init__(self, k=4, alpha=0.5, beta=0.5):
        super().__init__()
        self.k = k
        self.alpha = alpha
        self.beta = beta

    def forward(self, x, pos, batch):
        """
        Args:
            x: node features, shape (num_nodes, F).
            pos: node coordinates, shape (num_nodes, >=2); columns 0 and 1 are
                used as eta and phi for the edge features.
            batch: graph index of every node. The nodes of each graph are assumed
                to be contiguous, as produced by PyG batching (TODO confirm for
                other callers).

        Returns:
            ``(edge_index, edge_attr)`` of shapes (2, E) and (E, 4), with global
            node indices. Graphs with fewer than two nodes contribute no edges.
        """
        device = x.device
        batch_size = batch.max().item() + 1
        edge_index_list = []
        edge_attr_list = []

        for i in range(batch_size):
            mask = (batch == i)
            x_batch = x[mask]
            pos_batch = pos[mask]
            num_nodes = x_batch.size(0)
            if num_nodes <= 1:
                continue

            # Combined distance between every pair of nodes in this graph.
            dist_geom = torch.cdist(pos_batch, pos_batch)
            dist_feat = torch.cdist(x_batch, x_batch)
            dist_combined = self.alpha * dist_geom + self.beta * dist_feat
            # Exclude self-matches explicitly instead of dropping the first top-k
            # column: with duplicate points (or cdist rounding on the diagonal) a
            # node is not guaranteed to be ranked as its own nearest neighbour.
            self_mask = torch.eye(num_nodes, dtype=torch.bool, device=device)
            dist_combined = dist_combined.masked_fill(self_mask, float('inf'))

            k_eff = min(self.k, num_nodes - 1)
            _, topk_indices = torch.topk(dist_combined, k=k_eff, largest=False, dim=1)

            # Local (per-graph) indices: every node is the source of k_eff edges.
            source_nodes = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
            target_nodes = topk_indices.reshape(-1)
            # The global index of this graph's first node converts local ids to
            # global ids (relies on the contiguity assumption above).
            offset = mask.nonzero(as_tuple=True)[0].min().item()
            edge_index = torch.stack([source_nodes, target_nodes], dim=0) + offset

            delta_eta = pos_batch[target_nodes, 0] - pos_batch[source_nodes, 0]
            delta_phi = pos_batch[target_nodes, 1] - pos_batch[source_nodes, 1]
            delta_phi = torch.atan2(torch.sin(delta_phi), torch.cos(delta_phi))
            delta_r = torch.sqrt(delta_eta**2 + delta_phi**2)
            delta_embed_norm = torch.norm(
                x_batch[target_nodes] - x_batch[source_nodes], p=2, dim=1, keepdim=True
            )
            edge_attr = torch.cat([
                delta_eta.unsqueeze(1),
                delta_phi.unsqueeze(1),
                delta_r.unsqueeze(1),
                delta_embed_norm
            ], dim=1)

            edge_index_list.append(edge_index)
            edge_attr_list.append(edge_attr)

        if not edge_index_list:
            return (torch.zeros((2, 0), device=device, dtype=torch.long),
                    torch.zeros((0, 4), device=device))
        return torch.cat(edge_index_list, dim=1), torch.cat(edge_attr_list, dim=0)

In [8]:
def create_multiscale_knn(pos, k1=4, k2=6):
    """Build two kNN graphs over node positions at different neighbourhood sizes.

    Args:
        pos: tensor of shape (num_nodes, >=2); columns 0 and 1 are used as eta
            and phi, and all columns enter the Euclidean distance.
        k1: neighbours per node in the small-scale graph.
        k2: neighbours per node in the large-scale graph.

    Returns:
        ``(edge_index_small, edge_attr_small, edge_index_large, edge_attr_large)``.
        Each edge_index is (2, E) on ``pos.device``. Row 0 is the centre node and
        row 1 one of its ``min(k, num_nodes - 1)`` nearest other nodes. Each
        edge_attr is (E, 3) holding ``[d_eta, d_phi wrapped to [-pi, pi], d_R]``.
    """
    def knn_graph(pos, k):
        n = pos.size(0)
        k_eff = min(k, n - 1)
        if k_eff <= 0:
            return (torch.zeros((2, 0), dtype=torch.long, device=pos.device),
                    torch.zeros((0, 3), dtype=pos.dtype, device=pos.device))
        dist = torch.cdist(pos, pos)
        # Mask the diagonal so a node never selects itself, even when several
        # nodes share a position (then "drop the first column" is not reliable).
        self_mask = torch.eye(n, dtype=torch.bool, device=pos.device)
        dist = dist.masked_fill(self_mask, float('inf'))
        _, nn_idx = torch.topk(dist, k=k_eff, dim=1, largest=False)
        source_nodes = torch.arange(n, device=pos.device).repeat_interleave(k_eff)
        target_nodes = nn_idx.reshape(-1)
        edge_index = torch.stack([source_nodes, target_nodes], dim=0)
        delta_eta = pos[target_nodes, 0] - pos[source_nodes, 0]
        delta_phi = pos[target_nodes, 1] - pos[source_nodes, 1]
        delta_phi = torch.atan2(torch.sin(delta_phi), torch.cos(delta_phi))
        delta_r = torch.sqrt(delta_eta**2 + delta_phi**2)
        edge_attr = torch.stack([delta_eta, delta_phi, delta_r], dim=1)
        return edge_index, edge_attr

    eidx_s, eattr_s = knn_graph(pos, k1)
    eidx_l, eattr_l = knn_graph(pos, k2)
    return eidx_s, eattr_s, eidx_l, eattr_l

In [9]:
class JetGraphProcessor:
    """Convert jet images stored in ``.npz`` files into PyG ``Data`` graphs.

    Every file must provide the arrays ``X_jets`` (one image per jet, indexed
    ``(H, W, C)`` with channels 0/1/2 read as ECAL/HCAL/tracks), ``y`` (labels),
    ``m0`` and ``pt``. Each pixel whose channel-summed value exceeds
    ``min_energy_threshold`` becomes a node.
    """
    def __init__(self, input_dir, transform=None, k=8, min_energy_threshold=1e-4):
        """
        Args:
            input_dir: directory scanned by ``process_all_chunks``; for
                ``process_chunk`` this must be the path of a single ``.npz`` file.
            transform: optional callable applied to every graph after it is built.
            k: neighbours per node in the small-scale kNN graph; the large-scale
                graph uses ``2 * k``. The default 8 gives 8 and 16 neighbours.
            min_energy_threshold: pixels whose channel sum is not above this
                value are dropped.
        """
        self.input_dir = input_dir
        self.transform = transform
        self.k = k
        self.min_energy_threshold = min_energy_threshold

    def _node_features(self, jet_image, pt):
        """Vectorised per-pixel node features and (eta, phi) positions.

        Args:
            jet_image: array of shape (H, W, C) with C >= 3.
            pt: jet transverse momentum, used for the ``pt_fraction`` feature.

        Returns:
            ``(features, points)``: float64 arrays of shape (N, 13) and (N, 2),
            with one row per selected pixel in row-major order. Feature columns
            are:
            - total, ecal, hcal, tracks
            - pt_fraction, charged_fraction
            - 3x3 local sum
            - eta, phi
            - log1p(total), sqrt(total)
            - angle to the image centre
            - distance to the image centre divided by half the (smaller) image side
        """
        height, width = jet_image.shape[:2]
        image_2d = np.sum(jet_image, axis=2)
        # 3x3 neighbourhood sum for every pixel at once (zero padding at the
        # border); equivalent to np.sum(padded[i:i+3, j:j+3]) per pixel.
        padded = np.pad(image_2d, pad_width=1, mode='constant', constant_values=0)
        local_sums = sum(padded[di:di + height, dj:dj + width]
                         for di in range(3) for dj in range(3))
        rows, cols = np.nonzero(image_2d > self.min_energy_threshold)

        channels = jet_image[rows, cols, :3].astype(np.float64)
        ecal, hcal, tracks = channels[:, 0], channels[:, 1], channels[:, 2]
        total = ecal + hcal + tracks
        pt_fraction = total / (pt + 1e-9)
        charged_fraction = np.divide(tracks, total, out=np.zeros_like(total), where=total > 0)

        # Map pixel index i in [0, H) to (i / H * 2 - 1) * 0.8, i.e. into [-0.8, 0.8).
        eta = (rows / height * 2 - 1) * 0.8
        phi = (cols / width * 2 - 1) * 0.8
        d_row = rows - height / 2.0
        d_col = cols - width / 2.0
        angle_center = np.arctan2(d_col, d_row)
        norm_dist_center = np.hypot(d_row, d_col) / (min(height, width) / 2.0)

        features = np.stack([
            total, ecal, hcal, tracks,
            pt_fraction, charged_fraction, local_sums[rows, cols],
            eta, phi, np.log1p(total),
            np.sqrt(total), angle_center, norm_dist_center,
        ], axis=1)
        points = np.stack([eta, phi], axis=1)
        return features, points

    def _convert_to_graph(self, jet_image, label, m0, pt):
        """
        Convert a jet image (3D numpy array, shape (H, W, C)) into a PyG Data object.

        Nodes are the pixels selected by ``_node_features``. If none pass the
        threshold, a single all-zero node at (0, 0) is emitted. ``edge_index`` /
        ``edge_attr`` hold the k-NN graph, ``edge_index_large`` /
        ``edge_attr_large`` the 2k-NN graph, and ``global_features`` is
        ``[[m0, pt]]``.
        """
        features, points = self._node_features(jet_image, pt)
        if features.shape[0] == 0:
            features = np.zeros((1, features.shape[1]))
            points = np.zeros((1, 2))
        x = torch.tensor(features, dtype=torch.float)
        pos = torch.tensor(points, dtype=torch.float)
        eidx_s, eattr_s, eidx_l, eattr_l = create_multiscale_knn(pos, k1=self.k, k2=2 * self.k)
        data = Data(
            x=x,
            pos=pos,
            edge_index=eidx_s,
            edge_attr=eattr_s,
            y=torch.tensor([label], dtype=torch.long)
        )
        data.edge_index_large = eidx_l
        data.edge_attr_large = eattr_l
        data.global_features = torch.tensor([m0, pt], dtype=torch.float).unsqueeze(0)
        return data

    def _graphs_from_npz(self, npz_path, desc):
        """Load one ``.npz`` file and convert every jet in it into a graph.

        Raises:
            KeyError: if a required array (``X_jets``, ``y``, ``m0``, ``pt``) is missing.
        """
        # The context manager closes the archive's file handle; the arrays read
        # from it are fully loaded and stay valid afterwards.
        with np.load(npz_path) as data_npz:
            X_jets = data_npz['X_jets']
            if 'y' not in data_npz:
                raise KeyError("Label key 'y' not found in the dataset.")
            y = data_npz['y']
            m0 = data_npz['m0']
            pt = data_npz['pt']
        graphs = []
        for i in tqdm(range(X_jets.shape[0]), desc=desc):
            graph = self._convert_to_graph(X_jets[i], int(y[i]), m0[i], pt[i])
            if self.transform:
                graph = self.transform(graph)
            graphs.append(graph)
        return graphs

    def process_all_chunks(self):
        """
        Process all .npz files in input_dir (sorted by chunk numbers) and
        accumulate the graphs in memory.
        """
        all_chunk_files = sorted(
            [f for f in os.listdir(self.input_dir) if f.endswith('.npz')],
            key=extract_chunk_numbers
        )
        all_graphs = []
        for chunk_file in all_chunk_files:
            print(f"Processing {chunk_file} ...")
            all_graphs.extend(self._graphs_from_npz(
                os.path.join(self.input_dir, chunk_file), desc=f"Processing {chunk_file}"
            ))
        print(f"Total graphs processed: {len(all_graphs)}")
        return all_graphs

    def process_chunk(self):
        """
        Process the single NPZ file whose path is stored in ``input_dir`` and
        return a list of graphs.
        """
        return self._graphs_from_npz(self.input_dir, desc="Processing single chunk")


In [10]:
class BatchNorm(nn.Module):
    """Thin wrapper that applies ``nn.BatchNorm1d(num_features)`` to its input.

    The wrapped layer is stored as ``self.bn``; that name determines the
    state_dict keys (``bn.weight``, ``bn.running_mean``, ...).
    """

    def __init__(self, num_features):
        super(BatchNorm, self).__init__()
        norm_layer = nn.BatchNorm1d(num_features)
        self.bn = norm_layer

    def forward(self, x):
        norm_layer = self.bn
        return norm_layer(x)

In [11]:
class InMemoryJetGraphDataset(Dataset):
    """Map-style dataset over a list of graphs already held in memory.

    Args:
        graphs: indexable sequence of graph objects.
        transform: optional callable applied to the stored graph each time it is
            fetched. It receives the stored object itself, not a copy.
    """

    def __init__(self, graphs, transform=None):
        self.graphs, self.transform = graphs, transform

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        item = self.graphs[idx]
        return self.transform(item) if self.transform else item

In [12]:
class CoordinateEmbedding(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that embeds coordinate vectors.

    Maps inputs of shape (..., in_dim) to (..., out_dim).
    """

    def __init__(self, in_dim=2, out_dim=4):
        super().__init__()
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, coords):
        embedded = self.mlp(coords)
        return embedded


In [13]:
class NonLocalBlock(nn.Module):
    """
    A non-local block for self-attention over the nodes.
    Adapted from 'Non-local Neural Networks', with a simplified design.

    Pairwise scores are plain dot products theta(x_i) . phi(x_j), divided by
    the number of attended nodes (no softmax). They weight g(x), and the result
    is projected back by W, optionally batch-normalised, and added to the input
    as a residual.
    """
    def __init__(self, in_channels, inter_channels=None, bn_layer=True):
        super().__init__()
        self.in_channels = in_channels
        if inter_channels is None:
            inter_channels = in_channels // 2
            if inter_channels == 0:
                inter_channels = 1
        self.inter_channels = inter_channels

        # Theta, Phi, and g are 1x1 convolutions (pointwise fully connected layers)
        self.theta = nn.Conv1d(in_channels, inter_channels, kernel_size=1, bias=False)
        self.phi   = nn.Conv1d(in_channels, inter_channels, kernel_size=1, bias=False)
        self.g     = nn.Conv1d(in_channels, inter_channels, kernel_size=1, bias=False)
        self.W     = nn.Conv1d(inter_channels, in_channels, kernel_size=1, bias=False)
        if bn_layer:
            self.bn = nn.BatchNorm1d(in_channels)
        else:
            self.bn = None

    def forward(self, x, batch=None):
        """
        Args:
            x: node features of shape (N, C).
            batch: optional (N,) graph-assignment vector. When given, each node
                attends only to nodes of its own graph and scores are divided by
                that graph's node count. When omitted, every node attends to all
                N rows of ``x`` (for a batched input this mixes different graphs).

        Returns:
            Tensor of shape (N, C).
        """
        _num_nodes, _channels = x.shape  # enforces a 2-D input
        # Treat nodes as the length dimension of a single (1, C, N) sequence.
        x_reshaped = x.transpose(0, 1).unsqueeze(0)

        theta_x = self.theta(x_reshaped).squeeze(0).transpose(0, 1)  # (N, inter_channels)
        phi_x = self.phi(x_reshaped).squeeze(0)                      # (inter_channels, N)
        g_x = self.g(x_reshaped).squeeze(0).transpose(0, 1)          # (N, inter_channels)

        f = torch.matmul(theta_x, phi_x)  # (N, N)
        if batch is None:
            f_div_C = f / f.shape[-1]
        else:
            # Zero cross-graph scores and normalise each row by its own graph size,
            # so a graph's output does not depend on the other graphs in the batch.
            same_graph = batch.view(-1, 1) == batch.view(1, -1)
            graph_sizes = torch.bincount(batch)[batch].to(f.dtype).unsqueeze(1)
            f_div_C = f.masked_fill(~same_graph, 0.0) / graph_sizes

        y = torch.matmul(f_div_C, g_x)           # (N, inter_channels)
        y = y.transpose(0, 1).unsqueeze(0)       # (1, inter_channels, N)
        W_y = self.W(y)
        if self.bn is not None:
            W_y = self.bn(W_y)
        # Residual connection back in (N, C) layout.
        return W_y.squeeze(0).transpose(0, 1) + x

In [14]:
class EnhancedJetGNN(nn.Module):
    """Jet-graph classifier.

    Node features are split into coordinates (columns 7 and 8, embedded by a
    small MLP) and the remaining ``node_dim - 2`` columns. The combined node
    embedding goes through alternating GATConv / EdgeConv layers, optionally
    averaged with a SAGEConv branch, and optionally through a non-local block
    and SAGPooling. It is then mean/max/sum pooled and concatenated with the
    encoded ``global_features`` before an MLP classifier that returns
    log-probabilities.
    """
    def __init__(self, node_dim, global_dim, hidden_dim=64, out_channels=2, 
                 num_layers=3, dropout=0.3, use_dynamic_graph=True, heads=4, 
                 use_sagpool=True, sagpool_ratio=0.5, use_non_local=False, use_sageconv=True):
        """
        Args:
            node_dim: number of node feature columns (at least 9, since columns
                7 and 8 are read as coordinates).
            global_dim: number of graph-level feature columns.
            hidden_dim: embedding width; should be divisible by ``heads`` so the
                GATConv output width equals ``hidden_dim``.
            out_channels: number of classes.
            num_layers: number of GNN layers (GAT first, then alternating
                EdgeConv / GAT) and of SAGEConv layers.
            dropout: dropout probability used throughout.
            use_dynamic_graph: rebuild the graph from the current embeddings
                before every GNN layer instead of using ``data.edge_index``.
            heads: attention heads per GATConv.
            use_sagpool, sagpool_ratio: enable SAGPooling and its keep ratio.
            use_non_local: apply a NonLocalBlock after the GNN branches.
            use_sageconv: add the SAGEConv branch and average it with the GNN branch.
        """
        super().__init__()
        self.node_dim = node_dim
        self.global_dim = global_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_dynamic_graph = use_dynamic_graph
        self.heads = heads
        self.use_sagpool = use_sagpool
        self.sagpool_ratio = sagpool_ratio
        self.use_non_local = use_non_local
        self.use_sageconv = use_sageconv
        
        # Coordinate embedding for (η, φ) at indices 7 and 8
        self.coord_embed = CoordinateEmbedding(in_dim=2, out_dim=4)
        # Process the remaining node features
        self.feature_mlp = nn.Sequential(
            nn.Linear(node_dim - 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Combine coordinate embedding with processed features
        self.comb_mlp = nn.Sequential(
            nn.Linear(hidden_dim + 4, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Global feature encoder
        self.global_encoder = nn.Sequential(
            nn.Linear(global_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Dynamic adjacency module
        if use_dynamic_graph:
            self.dynamic_graph = DynamicHybridGraph(k=8, alpha=0.5, beta=0.5)
        # GNN layers (alternating GAT and EdgeConv)
        self.gnn_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.gnn_layers.append(
            GATConv(hidden_dim, hidden_dim // heads, heads=heads, edge_dim=4, dropout=dropout)
        )
        self.batch_norms.append(BatchNorm(hidden_dim))
        for i in range(1, num_layers):
            if i % 2 == 1:
                self.gnn_layers.append(
                    EdgeConv(
                        nn=nn.Sequential(
                            nn.Linear(hidden_dim*2, hidden_dim),
                            nn.ReLU(),
                            nn.Dropout(dropout),
                            nn.Linear(hidden_dim, hidden_dim)
                        ),
                        aggr='mean'
                    )
                )
            else:
                self.gnn_layers.append(
                    GATConv(hidden_dim, hidden_dim // heads, heads=heads, edge_dim=4, dropout=dropout)
                )
            self.batch_norms.append(BatchNorm(hidden_dim))
        
        # SAGEConv branch (if enabled)
        if self.use_sageconv:
            self.sage_layers = nn.ModuleList()
            for i in range(num_layers):
                self.sage_layers.append(SAGEConv(hidden_dim, hidden_dim, aggr='mean'))
        
        # Insert a non-local block if enabled.
        if self.use_non_local:
            self.non_local = NonLocalBlock(in_channels=hidden_dim)
        
        # Optional SAGPooling for hierarchical pooling
        if self.use_sagpool:
            self.sagpool = SAGPooling(hidden_dim, ratio=self.sagpool_ratio)
        
        # Final classifier: multi-scale pooling (mean, max, sum) combined with global features.
        # Pooled features have dimension hidden_dim each and the global encoder output is
        # hidden_dim, so the classifier input is 3*hidden_dim + hidden_dim = 4*hidden_dim.
        self.classifier = nn.Sequential(
            nn.Linear(4 * hidden_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_channels)
        )

    @staticmethod
    def _append_embedding_distance(x, edge_index, edge_attr):
        """Return ``edge_attr`` with ``||x[dst] - x[src]||`` appended when it has 3 columns.

        Graphs built by ``create_multiscale_knn`` carry ``[d_eta, d_phi, d_R]``
        (3 columns), while the GATConv layers are built with ``edge_dim=4``.
        Appending the embedding distance produces the same 4-column layout
        DynamicHybridGraph emits. ``None`` or any other width is returned unchanged.
        """
        if edge_attr is None or edge_attr.size(1) != 3:
            return edge_attr
        src, dst = edge_index
        dist = torch.norm(x[dst] - x[src], p=2, dim=1, keepdim=True)
        return torch.cat([edge_attr, dist], dim=1)
    
    def forward(self, data):
        """Classify a (possibly batched) graph.

        Args:
            data: object with ``x`` (num_nodes, node_dim), ``pos``,
                ``edge_index`` (2, E), ``edge_attr``, optional ``batch``
                (num_nodes,) and optional ``global_features``
                (num_graphs, global_dim).

        Returns:
            Log-probabilities of shape (num_graphs, out_channels).
        """
        x, batch = data.x, data.batch
        edge_index, edge_attr = data.edge_index, data.edge_attr
        static_edge_attr = edge_attr
        pos = data.pos
        # Resolve a missing batch vector first: the global-feature fallback below
        # needs the number of graphs.
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        if hasattr(data, 'global_features'):
            global_features = data.global_features
        else:
            global_features = torch.zeros((batch.max().item() + 1, self.global_dim), device=x.device)
        
        # Process node features: separate coordinate features at indices 7 and 8.
        coords = x[:, [7, 8]]
        other_feats = torch.cat([x[:, :7], x[:, 9:]], dim=1)
        coord_emb = self.coord_embed(coords)
        feat_emb = self.feature_mlp(other_feats)
        x_initial = self.comb_mlp(torch.cat([feat_emb, coord_emb], dim=1))  # (num_nodes, hidden_dim)
        global_x = self.global_encoder(global_features)  # (num_graphs, hidden_dim)
        
        # GNN branch. The graph is (re)built at the top of every iteration, and
        # gnn_layers always holds at least one layer, so no separate
        # dynamic-graph pass is needed before the loop.
        x_gnn = x_initial
        for i, conv in enumerate(self.gnn_layers):
            if self.use_dynamic_graph:
                edge_index, edge_attr = self.dynamic_graph(x_gnn, pos, batch)
            else:
                edge_attr = self._append_embedding_distance(x_gnn, edge_index, static_edge_attr)
            if isinstance(conv, GATConv):
                x_gnn = conv(x_gnn, edge_index, edge_attr=edge_attr)
            else:
                x_gnn = conv(x_gnn, edge_index)
            x_gnn = self.batch_norms[i](x_gnn)
            x_gnn = F.relu(x_gnn)
            x_gnn = F.dropout(x_gnn, p=self.dropout, training=self.training)
        
        # SAGEConv branch (if enabled), run on the graph used by the last GNN layer.
        if self.use_sageconv:
            x_sage = x_initial
            for sage in self.sage_layers:
                x_sage = sage(x_sage, edge_index)
                x_sage = F.relu(x_sage)
                x_sage = F.dropout(x_sage, p=self.dropout, training=self.training)
            # Fuse outputs (average)
            x = (x_gnn + x_sage) / 2.0
        else:
            x = x_gnn
        
        # Apply non-local block if enabled.
        # NOTE(review): NonLocalBlock attends over every row of x, i.e. across all
        # jets in the mini-batch; consider restricting its attention per graph.
        if self.use_non_local:
            x = self.non_local(x)
        
        # Apply SAGPooling if enabled
        if self.use_sagpool:
            x, edge_index, edge_attr, batch, _, _ = self.sagpool(x, edge_index, edge_attr, batch=batch)
        
        pooled_mean = global_mean_pool(x, batch)
        pooled_max = global_max_pool(x, batch)
        pooled_sum = global_add_pool(x, batch)
        
        combined = torch.cat([pooled_mean, pooled_max, pooled_sum, global_x], dim=1)
        out = self.classifier(combined)
        return F.log_softmax(out, dim=-1)

In [16]:
def visualize_graph(graph, ax=None):
    """Scatter-plot a graph's nodes at ``pos`` columns 0/1 and draw its edges.

    Args:
        graph: object with ``pos`` (num_nodes, >=2), ``edge_index`` (2, E) and a
            single-element ``y`` shown in the title.
        ax: matplotlib Axes to draw on; a new figure is created when omitted.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    # detach().cpu() so graphs living on the GPU or carrying autograd history can be plotted.
    pos = graph.pos.detach().cpu().numpy()
    ax.scatter(pos[:, 0], pos[:, 1], c='blue', s=30, label='Nodes')
    if graph.edge_index.size(1) > 0:
        edge_index = graph.edge_index.detach().cpu().numpy()
        # One LineCollection instead of one Line2D per edge: kNN graphs have
        # k edges per node, and a separate artist per edge is slow to build and draw.
        segments = np.stack([pos[edge_index[0], :2], pos[edge_index[1], :2]], axis=1)
        ax.add_collection(LineCollection(segments, colors='gray', linewidths=0.5))
        ax.autoscale_view()
    ax.set_title(f"Graph (Label: {graph.y.item()})")
    ax.legend()
    return ax

In [15]:
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, roc_auc_score
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

class EnhancedJetGNNTrainer:
    """Training / evaluation loop for a graph classifier that returns log-probabilities.

    Args:
        model: module called as ``model(batch)`` that returns log-probabilities
            of shape (num_graphs, num_classes); moved to ``device``.
        device: device the model and every batch are moved to.
        optimizer: optimizer to use. Defaults to SGD (momentum 0.9) over the
            model parameters with ``lr`` / ``weight_decay``.
        scheduler: LR scheduler stepped once per epoch. Defaults to
            ``CosineAnnealingWarmRestarts(T_0=5, T_mult=1)``.
        lr, weight_decay: only used to build the default optimizer.
    """
    def __init__(self, model, device, optimizer=None, scheduler=None, lr=1e-2, weight_decay=5e-4):
        self.model = model.to(device)
        self.device = device
        
        # Use SGD with momentum as the optimizer.
        if optimizer is None:
            self.optimizer = SGD(self.model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
        else:
            self.optimizer = optimizer
        
        # Scheduler: CosineAnnealingWarmRestarts with T_0=5.
        self.scheduler = scheduler if scheduler is not None else CosineAnnealingWarmRestarts(self.optimizer, T_0=5, T_mult=1)
        
        # Use NLLLoss since the model outputs log_softmax.
        self.criterion = nn.NLLLoss().to(device)

    def train(self, train_loader, val_loader, num_epochs=50, patience=10, model_save_path='best_jet_gnn_model.pt'):
        """Train with early stopping on validation loss.

        The weights with the lowest validation loss are written to
        ``model_save_path`` and restored into the model before returning. If no
        epoch ever improves on ``inf`` (e.g. the loss is always NaN), no file is
        written and the final weights are kept.

        Returns:
            ``(train_losses, val_metrics)``: the per-epoch mean training loss
            and, for each completed epoch, the dict returned by ``evaluate`` on
            ``val_loader``.
        """
        best_val_loss = float('inf')
        best_state = None
        wait = 0
        train_losses = []
        val_metrics = []

        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            for data in train_loader:
                data = data.to(self.device)
                self.optimizer.zero_grad()
                out = self.model(data)
                loss = self.criterion(out, data.y.view(-1))
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item() * data.num_graphs

            epoch_train_loss = running_loss / len(train_loader.dataset)
            metrics = self.evaluate(val_loader)
            train_losses.append(epoch_train_loss)
            val_metrics.append(metrics)

            print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {epoch_train_loss:.4f} | "
                  f"Val Loss: {metrics['loss']:.4f} | Accuracy: {metrics['accuracy']:.4f} | "
                  f"F1: {metrics['f1_score']:.4f} | AUC: {metrics.get('auc', 0):.4f}")

            if self.scheduler is not None:
                self.scheduler.step()

            if metrics['loss'] < best_val_loss:
                best_val_loss = metrics['loss']
                # Keep an in-memory snapshot too, so the restore below never
                # depends on (or picks up a stale) checkpoint file on disk.
                best_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
                torch.save(self.model.state_dict(), model_save_path)
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered.")
                    break

        if best_state is not None:
            self.model.load_state_dict(best_state)
        else:
            print("No validation improvement recorded; keeping the final weights.")
        return train_losses, val_metrics

    def evaluate(self, loader):
        """Compute loss and classification metrics over ``loader``.

        Returns a dict with ``loss`` (mean per graph), ``accuracy``, macro
        ``precision`` / ``recall`` / ``f1_score``, ``confusion_matrix``,
        ``classification_report`` and ``auc``. ``auc`` uses the class-1
        probability, so it assumes binary labels, and it is 0.0 when
        roc_auc_score cannot be computed.
        """
        self.model.eval()
        y_true, y_pred, y_score = [], [], []
        loss_sum = 0.0

        with torch.no_grad():
            for data in loader:
                data = data.to(self.device)
                out = self.model(data)
                labels = data.y.view(-1)
                loss = self.criterion(out, labels)
                loss_sum += loss.item() * data.num_graphs

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(out.argmax(dim=1).cpu().numpy())
                # For binary classification, extract probability for class 1.
                y_score.extend(out.exp()[:, 1].cpu().numpy())

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_score = np.array(y_score)

        avg_loss = loss_sum / len(loader.dataset)
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        f1 = f1_score(y_true, y_pred, average='macro')
        conf_matrix = confusion_matrix(y_true, y_pred)
        class_report = classification_report(y_true, y_pred)

        try:
            auc = roc_auc_score(y_true, y_score)
        except ValueError:
            # Raised e.g. when y_true contains a single class; other errors propagate.
            auc = 0.0

        return {
            "loss": avg_loss,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report,
            "auc": auc
        }

In [14]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, roc_auc_score

In [None]:
def main():
    """Build graphs from the configured NPZ chunks, train EnhancedJetGNN and report metrics.

    The data are split into train / validation / test. The validation split
    drives early stopping and checkpoint selection. The test split is evaluated
    only once at the end, so the reported test metrics were not used to select
    the model.
    """
    # Configuration: Toggle 'use_non_local' to switch between baseline and non-local GNN.
    config = {
        'chunk_files': [
            '/kaggle/input/chunk_0_10000.npz',
            '/kaggle/input/chunk_10000_20000.npz',
        ],
        'batch_size': 16,
        'hidden_channels': 128,
        'num_layers': 2,
        'dropout': 0.4,
        'learning_rate': 0.001,
        'weight_decay': 5e-4,
        'epochs': 10,
        'patience': 8,
        'test_size': 0.2,  # 20% held out for the final test report
        'val_size': 0.1,   # fraction of the remaining data used for validation / early stopping
        'seed': 42,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'use_dynamic_graph': True,
        'gat_heads': 4,
        'use_sagpool': True,
        'sagpool_ratio': 0.5,
        # Set to True to use the non-local GNN variant.
        'use_non_local': True
    }

    # Seed torch (all devices) and numpy so weight init and loader shuffling are reproducible.
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])

    print(f"Using device: {config['device']}")

    # Process multiple NPZ files into graphs.
    all_graphs = []
    for file in config['chunk_files']:
        print(f"Processing file: {file}")
        processor = JetGraphProcessor(file)
        all_graphs.extend(processor.process_chunk())
    print(f"Total graphs processed: {len(all_graphs)}")

    # Create an in-memory dataset.
    dataset = InMemoryJetGraphDataset(all_graphs)
    print(f"InMemoryJetGraphDataset created with {len(dataset)} graphs.")

    # Hold out the test split first, then carve the validation split out of the
    # remaining data so model selection never sees the test graphs.
    indices = list(range(len(dataset)))
    train_idx, test_idx = train_test_split(indices, test_size=config['test_size'], random_state=config['seed'])
    train_idx, val_idx = train_test_split(train_idx, test_size=config['val_size'], random_state=config['seed'])

    train_dataset = torch.utils.data.Subset(dataset, train_idx)
    val_dataset = torch.utils.data.Subset(dataset, val_idx)
    test_dataset = torch.utils.data.Subset(dataset, test_idx)

    print(f"Train size: {len(train_dataset)}")
    print(f"Val size:   {len(val_dataset)}")
    print(f"Test size:  {len(test_dataset)}")

    # Use PyTorch Geometric DataLoader for batching Data objects.
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

    # Determine feature dimensions from a sample graph.
    sample_data = dataset[0]
    node_dim = sample_data.x.size(1)
    global_dim = sample_data.global_features.size(1) if hasattr(sample_data, 'global_features') else 2

    model = EnhancedJetGNN(
        node_dim=node_dim,
        global_dim=global_dim,
        hidden_dim=config['hidden_channels'],
        out_channels=2,
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        use_dynamic_graph=config['use_dynamic_graph'],
        heads=config['gat_heads'],
        use_sagpool=config['use_sagpool'],
        sagpool_ratio=config['sagpool_ratio'],
        use_non_local=config['use_non_local']
    ).to(config['device'])

    trainer = EnhancedJetGNNTrainer(
        model=model,
        device=config['device'],
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
    )

    # Early stopping and best-checkpoint selection use the validation split only.
    train_losses, val_metrics_history = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['epochs'],
        patience=config['patience'],
        model_save_path='best_jet_gnn_model.pt'
    )

    # Evaluate the selected model once on the held-out test set.
    final_test_metrics = trainer.evaluate(test_loader)
    print("\nTest Set Metrics:")
    print(f"Accuracy:  {final_test_metrics['accuracy']:.4f}")
    print(f"Precision: {final_test_metrics['precision']:.4f}")
    print(f"Recall:    {final_test_metrics['recall']:.4f}")
    print(f"F1 Score:  {final_test_metrics['f1_score']:.4f}")
    print(f"AUC:       {final_test_metrics['auc']:.4f}")
    print("Confusion Matrix:")
    print(final_test_metrics['confusion_matrix'])
    print("Classification Report:")
    print(final_test_metrics['classification_report'])

    # Plot training loss and validation metrics.
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 5))
    loss_ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
    loss_ax.legend()

    if val_metrics_history:
        epochs = range(1, len(val_metrics_history) + 1)
        metric_ax.plot(epochs, [m['accuracy'] for m in val_metrics_history], label='Accuracy')
        metric_ax.plot(epochs, [m['f1_score'] for m in val_metrics_history], label='F1 Score')
        metric_ax.set(xlabel='Epoch', ylabel='Score', title='Validation Metrics')
        metric_ax.legend()
    else:
        print("Validation metrics history is not available for plotting.")

    fig.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()

Using device: cuda
Processing file: /kaggle/input/chunk_0_10000.npz


Processing single chunk:   0%|          | 4/10000 [00:00<04:18, 38.69it/s]