In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
import torch
print(torch.__version__)

2.5.1+cu121


In [3]:
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import (
    global_mean_pool, global_max_pool, global_add_pool,
    GATConv, EdgeConv, SAGPooling,GATConv, EdgeConv, SAGEConv, SAGPooling, global_mean_pool, global_max_pool, global_add_pool, BatchNorm)
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.transforms import BaseTransform

import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [4]:
def extract_chunk_numbers(f):
    """Sort key for file names containing ``chunk_<a>_<b>``.

    Args:
        f: File name (or path) to inspect.

    Returns:
        ``(a, b)`` as a tuple of ints when the pattern occurs anywhere in
        ``f``; otherwise ``(inf, inf)`` so that non-matching names sort last.
    """
    found = re.search(r'chunk_(\d+)_(\d+)', f)
    if found is None:
        unmatched = float('inf')
        return (unmatched, unmatched)
    first, second = (int(part) for part in found.groups())
    return (first, second)

In [5]:
class RandomPhiShift(BaseTransform):
    """Augmentation that shifts the phi coordinate of every node by one random offset.

    A single offset is drawn uniformly from ``[-dphi, dphi]`` per call, added
    to column 1 of ``data.pos`` and wrapped back into ``[-pi, pi]``.

    The input object is never modified: a shallow copy carrying a new ``pos``
    tensor is returned. This matters because datasets in this notebook hand
    the *stored* graph to the transform on every ``__getitem__``; mutating it
    in place would accumulate shifts from epoch to epoch.

    Args:
        dphi: Maximum absolute shift applied to phi.
    """

    def __init__(self, dphi=0.2):
        self.dphi = dphi

    def __call__(self, data):
        """Return a copy of ``data`` with a randomly shifted phi column in ``pos``.

        Graphs without nodes are returned unchanged.
        """
        import copy  # stdlib; imported locally so the cell stays self-contained

        pos = data.pos
        if pos.size(0) == 0:
            return data

        # Draw on CPU (same RNG stream as before), then match pos device/dtype.
        shift = ((torch.rand(1) * 2 - 1) * self.dphi).to(device=pos.device, dtype=pos.dtype)
        phi = pos[:, 1] + shift

        shifted_pos = pos.clone()
        # atan2(sin, cos) wraps the angle back into [-pi, pi].
        shifted_pos[:, 1] = torch.atan2(torch.sin(phi), torch.cos(phi))

        # Shallow copy: all other tensors are shared, only `pos` is replaced.
        out = copy.copy(data)
        out.pos = shifted_pos
        # NOTE(review): only `pos` is shifted; any copy of phi stored inside
        # `data.x` is left untouched - confirm that is intended.
        return out

In [6]:
class DynamicHybridGraph(nn.Module):
    """Rebuilds a k-nearest-neighbour graph from geometry and current features.

    For each graph in the batch the pairwise distance
    ``alpha * cdist(pos) + beta * cdist(x)`` is computed and every node is
    linked to its ``k`` closest nodes (the first top-k hit, taken to be the
    node itself, is dropped).

    Args:
        k: Number of neighbours per node.
        alpha: Weight of the positional distance.
        beta: Weight of the feature distance.
    """

    def __init__(self, k=5, alpha=0.5, beta=0.5):
        super().__init__()
        self.k, self.alpha, self.beta = k, alpha, beta

    def forward(self, x, pos, batch):
        """Build edges and 4-dimensional edge attributes for a batch.

        Args:
            x: Node features, one row per node.
            pos: Node positions; column 0 is used as eta, column 1 as phi.
            batch: Graph id of every node.

        Returns:
            ``(edge_index, edge_attr)`` where ``edge_attr`` columns are
            ``[delta_eta, delta_phi, delta_r, ||delta_x||]``. When no graph
            has more than one node, empty ``(2, 0)`` / ``(0, 4)`` tensors.
        """
        device = x.device
        num_graphs = batch.max().item() + 1
        collected_edges, collected_attrs = [], []

        for graph_id in range(num_graphs):
            member = batch == graph_id
            feats, coords = x[member], pos[member]
            n_nodes = feats.size(0)
            if n_nodes <= 1:
                # A lone node cannot have neighbours.
                continue

            # Hybrid distance between every pair of nodes of this graph.
            geom = torch.cdist(coords, coords)
            feat = torch.cdist(feats, feats)
            hybrid = self.alpha * geom + self.beta * feat

            n_select = min(self.k + 1, n_nodes)
            nearest = torch.topk(hybrid, k=n_select, largest=False, dim=1).indices
            dst = nearest[:, 1:].reshape(-1)  # drop the first hit (self)
            src = torch.arange(n_nodes, device=device).repeat_interleave(n_select - 1)

            # Relative geometry; delta_phi is wrapped into [-pi, pi].
            d_eta = coords[dst, 0] - coords[src, 0]
            raw_phi = coords[dst, 1] - coords[src, 1]
            d_phi = torch.atan2(torch.sin(raw_phi), torch.cos(raw_phi))
            d_r = torch.sqrt(d_eta**2 + d_phi**2)
            feat_gap = torch.norm(feats[dst] - feats[src], p=2, dim=1, keepdim=True)

            attrs = torch.cat([torch.stack([d_eta, d_phi, d_r], dim=1), feat_gap], dim=1)

            # Local -> global node ids: shift by the first index of this graph
            # (relies on a graph's nodes being contiguous in the batch).
            first_node = member.nonzero(as_tuple=True)[0].min().item()
            collected_edges.append(torch.stack([src, dst], dim=0) + first_node)
            collected_attrs.append(attrs)

        if len(collected_edges) == 0:
            empty_index = torch.zeros((2, 0), device=device, dtype=torch.long)
            empty_attr = torch.zeros((0, 4), device=device)
            return (empty_index, empty_attr)

        return torch.cat(collected_edges, dim=1), torch.cat(collected_attrs, dim=0)

In [7]:
def create_multiscale_knn(pos, k1=3, k2=5):
    """Build two k-NN graphs (small and large neighbourhood) over ``pos``.

    Each node is connected to its ``k`` nearest nodes by Euclidean distance
    on ``pos`` (the first top-k hit, taken to be the node itself, is dropped).
    Edge attributes are ``[delta_eta, delta_phi, delta_r]`` computed from
    columns 0 and 1 of ``pos``, with ``delta_phi`` wrapped into ``[-pi, pi]``.

    The pairwise distance matrix is computed once and shared by both scales,
    and every tensor is created on ``pos.device`` so the function also works
    for non-CPU inputs.

    Args:
        pos: ``(n, d)`` tensor of node positions.
        k1: Neighbours per node in the small-scale graph.
        k2: Neighbours per node in the large-scale graph.

    Returns:
        ``(edge_index_small, edge_attr_small, edge_index_large, edge_attr_large)``.
        With ``n <= 1`` all four are empty (``(2, 0)`` long / ``(0, 3)`` float).
    """
    n = pos.size(0)
    device = pos.device
    # Shared by both scales; not needed when there is nothing to connect.
    dist = torch.cdist(pos, pos) if n > 1 else None

    def knn_graph(k):
        if n <= 1:
            return (torch.zeros((2, 0), dtype=torch.long, device=device),
                    torch.zeros((0, 3), dtype=torch.float, device=device))
        k_eff = min(k, n - 1)  # cannot have more neighbours than other nodes
        _, nn_idx = torch.topk(dist, k=k_eff + 1, dim=1, largest=False)
        nn_idx = nn_idx[:, 1:]  # drop the first hit (self)
        rows = torch.arange(n, device=device).view(-1, 1).repeat(1, k_eff)
        edge_index = torch.stack([rows.reshape(-1), nn_idx.reshape(-1)], dim=0)

        source_nodes, target_nodes = edge_index[0], edge_index[1]
        delta_eta = pos[target_nodes, 0] - pos[source_nodes, 0]
        delta_phi = pos[target_nodes, 1] - pos[source_nodes, 1]
        delta_phi = torch.atan2(torch.sin(delta_phi), torch.cos(delta_phi))
        delta_r = torch.sqrt(delta_eta**2 + delta_phi**2)
        edge_attr = torch.stack([delta_eta, delta_phi, delta_r], dim=1)
        return edge_index, edge_attr

    eidx_s, eattr_s = knn_graph(k1)
    eidx_l, eattr_l = knn_graph(k2)
    return eidx_s, eattr_s, eidx_l, eattr_l

In [8]:
class JetGraphProcessor:
    """Turn jet images stored in ``.npz`` chunks into PyG ``Data`` graphs.

    Each ``.npz`` file is expected to provide the arrays ``X_jets`` (images),
    ``y`` (labels), ``m0`` and ``pt`` (per-jet scalars).

    Args:
        input_dir: Directory containing ``.npz`` chunks (for
            ``process_all_chunks``) or the path of a single ``.npz`` file
            (for ``process_chunk``).
        transform: Optional callable applied to every produced graph.
        k: Stored on the instance. NOTE(review): not read by any method of
            this class; ``_convert_to_graph`` uses fixed ``k1=4`` / ``k2=8``.
        min_energy_threshold: Pixels whose channel-summed value does not
            exceed this threshold are ignored.
    """

    def __init__(self, input_dir, transform=None, k=5, min_energy_threshold=1e-5):
        self.input_dir = input_dir
        self.transform = transform
        self.k = k
        self.min_energy_threshold = min_energy_threshold

    def _convert_to_graph(self, jet_image, label, m0, pt):
        """
        Convert a jet image (3D numpy array) into a PyG Data object.
        Assumes jet_image shape is (H, W, C).

        Every pixel whose channel sum exceeds ``min_energy_threshold`` becomes
        a node with 13 features:
        ``[total, ch0, ch1, ch2, pt_fraction, charged_fraction, local_sum,
        eta, phi, log1p(total), sqrt(total), angle_center, norm_dist_center]``
        (channels 0/1/2 are treated as ECAL/HCAL/tracks by the code below).
        If no pixel passes the threshold a single all-zero node is emitted.

        Returns:
            ``Data`` with ``x``, ``pos``, the small-scale k-NN graph in
            ``edge_index`` / ``edge_attr``, the large-scale graph in
            ``edge_index_large`` / ``edge_attr_large``, the label in ``y`` and
            ``global_features = [[m0, pt]]``.
        """
        # Sum across channels to form a 2D projection
        jet_image_2d = np.sum(jet_image, axis=2)
        # Zero padding lets the 3x3 neighbourhood sum work at the image border.
        padded = np.pad(jet_image_2d, pad_width=1, mode='constant', constant_values=0)
        non_zero_indices = np.where(jet_image_2d > self.min_energy_threshold)
        points = []
        features = []
        for i, j in zip(non_zero_indices[0], non_zero_indices[1]):
            # Pixel index -> coordinate in [-0.8, 0.8).
            # NOTE(review): 125.0 and 62.5 hard-code a 125x125 image - confirm.
            pixel_eta = (i / 125.0 * 2 - 1) * 0.8
            pixel_phi = (j / 125.0 * 2 - 1) * 0.8
            energy_ecal = jet_image[i, j, 0]
            energy_hcal = jet_image[i, j, 1]
            energy_tracks = jet_image[i, j, 2]
            total_energy = energy_ecal + energy_hcal + energy_tracks
            pt_fraction = total_energy / (pt + 1e-9)  # epsilon guards pt == 0
            charged_fraction = energy_tracks / total_energy if total_energy > 0 else 0.0
            # padded[i:i+3, j:j+3] is the 3x3 window centred on pixel (i, j).
            local_sum = np.sum(padded[i:i+3, j:j+3])
            angle_center = np.arctan2(j - 62.5, i - 62.5)
            norm_dist_center = np.sqrt((i - 62.5)**2 + (j - 62.5)**2) / 62.5
            features.append([
                total_energy, energy_ecal, energy_hcal, energy_tracks,
                pt_fraction, charged_fraction, local_sum,
                pixel_eta, pixel_phi, np.log1p(total_energy),
                np.sqrt(total_energy), angle_center, norm_dist_center
            ])
            points.append([pixel_eta, pixel_phi])
        if len(points) == 0:
            # Keep downstream code simple: an empty image becomes one zero node.
            points = [[0, 0]]
            features = [[0.0]*13]
        x = torch.tensor(features, dtype=torch.float)
        pos = torch.tensor(points, dtype=torch.float)
        eidx_s, eattr_s, eidx_l, eattr_l = create_multiscale_knn(pos, k1=4, k2=8)
        data = Data(
            x=x,
            pos=pos,
            edge_index=eidx_s,
            edge_attr=eattr_s,
            y=torch.tensor([label], dtype=torch.long)
        )
        data.edge_index_large = eidx_l
        data.edge_attr_large = eattr_l
        data.global_features = torch.tensor([m0, pt], dtype=torch.float).unsqueeze(0)
        return data

    def _graphs_from_npz(self, npz_path, desc):
        """Load one ``.npz`` chunk and convert every jet in it to a graph.

        The archive is opened in a ``with`` block so its file handle is
        released as soon as the arrays have been read.

        Args:
            npz_path: Path of the ``.npz`` file.
            desc: Label shown on the progress bar.

        Raises:
            KeyError: If the archive has no ``y`` array.
        """
        with np.load(npz_path) as data_npz:
            if 'y' not in data_npz:
                raise KeyError("Label key 'y' not found in the dataset.")
            X_jets = data_npz['X_jets']
            y = data_npz['y']
            m0 = data_npz['m0']
            pt = data_npz['pt']

        graphs = []
        for i in tqdm(range(X_jets.shape[0]), desc=desc):
            graph = self._convert_to_graph(X_jets[i], int(y[i]), m0[i], pt[i])
            if self.transform:
                graph = self.transform(graph)
            graphs.append(graph)
        return graphs

    def process_all_chunks(self):
        """
        Process all .npz files from input_dir and accumulate graphs in memory.

        Files are visited in the order given by ``extract_chunk_numbers``.
        """
        all_chunk_files = sorted(
            [f for f in os.listdir(self.input_dir) if f.endswith('.npz')],
            key=extract_chunk_numbers
        )
        all_graphs = []
        for chunk_file in all_chunk_files:
            print(f"Processing {chunk_file} ...")
            all_graphs.extend(
                self._graphs_from_npz(
                    os.path.join(self.input_dir, chunk_file),
                    desc=f"Processing {chunk_file}",
                )
            )
        print(f"Total graphs processed: {len(all_graphs)}")
        return all_graphs

    def process_chunk(self):
        """
        Process a single NPZ file and return a list of graphs.

        Here ``self.input_dir`` is the path of the ``.npz`` file itself.
        """
        return self._graphs_from_npz(self.input_dir, desc="Processing single chunk")


In [9]:
class BatchNorm(nn.Module):
    """Thin module wrapper that applies ``nn.BatchNorm1d`` to its input.

    Args:
        num_features: Number of features passed to ``nn.BatchNorm1d``.
    """

    def __init__(self, num_features):
        super().__init__()
        norm_layer = nn.BatchNorm1d(num_features)
        # Attribute name `bn` determines the state-dict keys (`bn.weight`, ...).
        self.bn = norm_layer

    def forward(self, x):
        """Return the batch-normalised ``x``."""
        normalized = self.bn(x)
        return normalized

In [10]:
class InMemoryJetGraphDataset(Dataset):
    """Dataset view over a list of already-built graphs held in memory.

    Args:
        graphs: Indexable collection of graphs.
        transform: Optional callable applied to a graph each time it is
            fetched; skipped when falsy.
    """

    def __init__(self, graphs, transform=None):
        self.transform = transform
        self.graphs = graphs

    def __len__(self):
        """Number of stored graphs."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return graph ``idx``, passed through ``transform`` when one is set."""
        item = self.graphs[idx]
        return self.transform(item) if self.transform else item

In [11]:
class CoordinateEmbedding(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that embeds coordinates.

    Args:
        in_dim: Size of the incoming coordinate vector.
        out_dim: Size of the hidden layer and of the returned embedding.
    """

    def __init__(self, in_dim=2, out_dim=4):
        super().__init__()
        stages = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, coords):
        """Return the embedding of ``coords``."""
        embedded = self.mlp(coords)
        return embedded


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, EdgeConv, SAGEConv, SAGPooling, global_mean_pool, global_max_pool, global_add_pool, BatchNorm

class EnhancedJetGNN(nn.Module):
    """Jet-graph classifier with a dense GAT/EdgeConv branch and a SAGEConv branch.

    Forward pipeline:
      1. Node features are split into the two coordinate columns (indices 7
         and 8) and the remaining columns; each part is embedded, the two are
         concatenated and mixed by ``comb_mlp``. A separate ``physics_mlp``
         over the raw node features is added to that result.
      2. ``num_layers`` message-passing layers alternate GATConv (even index)
         and EdgeConv (odd index). Each layer emits ``growth_rate`` channels
         that are concatenated to its input and projected back to
         ``hidden_dim`` by a transition layer.
      3. Optionally a stack of SAGEConv layers runs in parallel on the same
         initial embedding and is averaged with the GNN branch.
      4. Optional SAGPooling, then mean/max/sum readout concatenated with the
         encoded global features and fed to the classifier.

    Args:
        node_dim: Number of input node features (must be >= 9 because columns
            7 and 8 are read as coordinates).
        global_dim: Number of per-graph global features.
        hidden_dim: Width of all hidden representations.
        out_channels: Number of classes.
        num_layers: Number of layers in each branch.
        dropout: Dropout probability.
        use_dynamic_graph: Rebuild edges before every GNN layer with
            ``DynamicHybridGraph`` instead of using ``data.edge_index``.
        heads: Attention heads of each GATConv; must divide ``hidden_dim // 2``.
        use_sagpool: Apply SAGPooling before the readout.
        sagpool_ratio: Fraction of nodes kept by SAGPooling.
        use_sageconv: Enable the parallel SAGEConv branch.
        edge_dim: Width of the edge attributes seen by GATConv. Defaults to 4
            with the dynamic graph (its edge attributes have 4 columns) and to
            3 otherwise (the width produced by ``create_multiscale_knn``).

    Raises:
        ValueError: If ``hidden_dim // 2`` is not divisible by ``heads``.
    """

    def __init__(self, node_dim, global_dim, hidden_dim=64, out_channels=2,
                 num_layers=3, dropout=0.3, use_dynamic_graph=True, heads=4,
                 use_sagpool=True, sagpool_ratio=0.5, use_sageconv=True,
                 edge_dim=None):
        super().__init__()
        self.node_dim = node_dim
        self.global_dim = global_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_dynamic_graph = use_dynamic_graph
        self.heads = heads
        self.use_sagpool = use_sagpool
        self.sagpool_ratio = sagpool_ratio
        self.use_sageconv = use_sageconv

        # Growth rate for dense connectivity.
        self.growth_rate = hidden_dim // 2
        if self.growth_rate % heads != 0:
            # GATConv concatenates `heads` outputs of size growth_rate // heads;
            # a remainder would make its width differ from the BatchNorm below.
            raise ValueError(
                f"hidden_dim // 2 (= {self.growth_rate}) must be divisible by heads (= {heads})."
            )
        if edge_dim is None:
            edge_dim = 4 if use_dynamic_graph else 3
        self.edge_dim = edge_dim

        # Coordinate embedding for (eta, phi) at indices 7 and 8.
        self.coord_embed = CoordinateEmbedding(in_dim=2, out_dim=4)
        # Process the remaining node features.
        self.feature_mlp = nn.Sequential(
            nn.Linear(node_dim - 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Dedicated physics branch over the full raw feature vector.
        self.physics_mlp = nn.Sequential(
            nn.Linear(node_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.ReLU()
        )
        # Combine coordinate embedding with processed features.
        self.comb_mlp = nn.Sequential(
            nn.Linear(hidden_dim + 4, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Global feature encoder.
        self.global_encoder = nn.Sequential(
            nn.Linear(global_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Dynamic adjacency module.
        if use_dynamic_graph:
            self.dynamic_graph = DynamicHybridGraph(k=5, alpha=0.5, beta=0.5)

        # Build GNN layers with dense (concatenative) connections.
        self.gnn_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.transition_layers = nn.ModuleList()

        # Every layer sees hidden_dim channels: the transition layer projects
        # the concatenation [input, new_features] back to hidden_dim.
        in_dim = hidden_dim
        per_head_dim = self.growth_rate // heads

        for i in range(num_layers):
            if i % 2 == 0:
                self.gnn_layers.append(
                    GATConv(in_dim, per_head_dim, heads=heads, edge_dim=edge_dim, dropout=dropout)
                )
            else:
                # EdgeConv feeds its MLP with two concatenated in_dim-wide
                # vectors per edge, hence 2 * in_dim inputs.
                mlp = nn.Sequential(
                    nn.Linear(in_dim * 2, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim, self.growth_rate)
                )
                self.gnn_layers.append(
                    EdgeConv(nn=mlp, aggr='mean')
                )
            self.batch_norms.append(BatchNorm(self.growth_rate))
            self.transition_layers.append(
                nn.Sequential(
                    nn.Linear(in_dim + self.growth_rate, hidden_dim),
                    nn.ReLU()
                )
            )

        # SAGEConv branch for parallel feature extraction. Each layer maps
        # hidden_dim -> hidden_dim so that layers can be chained and the
        # result can be averaged with the hidden_dim-wide GNN branch.
        if self.use_sageconv:
            self.sage_layers = nn.ModuleList(
                SAGEConv(hidden_dim, hidden_dim, aggr='mean') for _ in range(num_layers)
            )

        # Optional SAGPooling for hierarchical pooling.
        if self.use_sagpool:
            self.sagpool = SAGPooling(hidden_dim, ratio=self.sagpool_ratio)

        # Final classifier: multi-scale pooling (mean, max, sum) combined with global features.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3 + hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_channels)
        )

    def forward(self, data):
        """Classify every graph in ``data``.

        Args:
            data: Batched graph object providing ``x``, ``pos``,
                ``edge_index``, ``edge_attr``, ``batch`` and optionally
                ``global_features`` (one row per graph).

        Returns:
            Log-probabilities of shape ``(num_graphs, out_channels)``.
        """
        x, batch = data.x, data.batch
        edge_index, edge_attr = data.edge_index, data.edge_attr
        pos = data.pos

        # Resolve `batch` first: the global-feature fallback below needs it.
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        global_features = getattr(data, 'global_features', None)
        if global_features is None:
            num_graphs = int(batch.max().item()) + 1
            global_features = torch.zeros((num_graphs, self.global_dim), device=x.device)

        # Keep the raw node features for the physics branch.
        raw_features = x
        # Process coordinate features (assumed at indices 7 and 8).
        coords = x[:, [7, 8]]
        other_feats = torch.cat([x[:, :7], x[:, 9:]], dim=1)
        coord_emb = self.coord_embed(coords)
        feat_emb = self.feature_mlp(other_feats)
        # (N, hidden_dim + 4) -> (N, hidden_dim)
        x = self.comb_mlp(torch.cat([feat_emb, coord_emb], dim=1))
        # The physics embedding is hidden_dim wide, so it is added after
        # comb_mlp where both operands have the same width.
        x = x + self.physics_mlp(raw_features)
        global_x = self.global_encoder(global_features)

        # Process through GNN branch with dense connectivity.
        x_gnn = x
        for i, conv in enumerate(self.gnn_layers):
            if self.use_dynamic_graph:
                edge_index, edge_attr = self.dynamic_graph(x_gnn, pos, batch)
            if isinstance(conv, GATConv):
                new_features = conv(x_gnn, edge_index, edge_attr=edge_attr)
            else:
                new_features = conv(x_gnn, edge_index)
            new_features = self.batch_norms[i](new_features)
            new_features = F.relu(new_features)
            new_features = F.dropout(new_features, p=self.dropout, training=self.training)
            concatenated = torch.cat([x_gnn, new_features], dim=1)
            x_gnn = self.transition_layers[i](concatenated)

        # Process through SAGEConv branch if enabled. It reuses the most
        # recent edge_index (the last dynamic graph when that is enabled).
        if self.use_sageconv:
            x_sage = x
            for sage_conv in self.sage_layers:
                x_sage = sage_conv(x_sage, edge_index)
                x_sage = F.relu(x_sage)
                x_sage = F.dropout(x_sage, p=self.dropout, training=self.training)
            # Fuse the outputs by averaging.
            x_fused = (x_gnn + x_sage) / 2.0
        else:
            x_fused = x_gnn

        # Apply SAGPooling if enabled.
        if self.use_sagpool:
            x_fused, edge_index, edge_attr, batch, _, _ = self.sagpool(
                x_fused, edge_index, edge_attr, batch=batch
            )

        pooled_mean = global_mean_pool(x_fused, batch)
        pooled_max = global_max_pool(x_fused, batch)
        pooled_sum = global_add_pool(x_fused, batch)
        combined = torch.cat([pooled_mean, pooled_max, pooled_sum, global_x], dim=1)
        out = self.classifier(combined)
        return F.log_softmax(out, dim=-1)


In [13]:
def visualize_graph(graph, ax=None):
    """Draw the nodes and edges of one graph on a matplotlib axis.

    Tensors are detached and moved to the CPU before conversion, so graphs
    that live on a GPU or still carry gradients can be plotted as well.

    Args:
        graph: Object exposing ``pos`` (columns 0/1 are used as x/y),
            ``edge_index`` and a single-element ``y``.
        ax: Axis to draw on; a new figure is created when omitted.

    Returns:
        The axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    pos = graph.pos.detach().cpu().numpy()
    ax.scatter(pos[:, 0], pos[:, 1], c='blue', s=30, label='Nodes')
    # Plot edges
    if graph.edge_index.size(1) > 0:
        edge_index = graph.edge_index.detach().cpu().numpy()
        for src_idx, dst_idx in zip(edge_index[0], edge_index[1]):
            src, dst = pos[src_idx], pos[dst_idx]
            ax.plot([src[0], dst[0]], [src[1], dst[1]], c='gray', linewidth=0.5)
    ax.set_title(f"Graph (Label: {graph.y.item()})")
    ax.legend()
    return ax

In [14]:
import os
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, confusion_matrix, classification_report, roc_auc_score)

# Helper function to load processed graphs if already cached, or process and cache them.
def load_or_process_graphs(raw_data_path, cache_path='cached_graphs.pt'):
    """Return processed graphs, reading them from ``cache_path`` when it exists.

    Args:
        raw_data_path: Iterable of ``.npz`` file paths; a single path given
            as ``str`` / ``os.PathLike`` is accepted too.
        cache_path: File used to cache the list of processed graphs.

    Returns:
        The list of graphs (loaded from the cache or freshly processed).
    """
    if os.path.exists(cache_path):
        print("Loading graphs from cache...")
        # The cache holds arbitrary pickled Python objects, so full unpickling
        # is requested explicitly (newer PyTorch defaults to weights_only=True
        # and would reject it).
        # SECURITY: unpickling can execute code - only load caches you created.
        return torch.load(cache_path, weights_only=False)

    print("Cache not found. Processing raw data...")
    # A bare string would otherwise be iterated character by character.
    if isinstance(raw_data_path, (str, os.PathLike)):
        raw_data_path = [raw_data_path]

    all_graphs = []
    for file in raw_data_path:
        print(f"Processing file: {file}")
        processor = JetGraphProcessor(file)
        all_graphs.extend(processor.process_chunk())
    # Cache the processed graphs.
    torch.save(all_graphs, cache_path)
    return all_graphs

class EnhancedJetGNNTrainer:
    """Training / evaluation helper for models that return log-probabilities.

    Args:
        model: Module called as ``model(data)`` and returning ``log_softmax``
            outputs of shape ``(num_graphs, num_classes)``.
        device: Device the model and batches are moved to.
        optimizer: Optional optimizer; AdamW is created when omitted.
        scheduler: Optional LR scheduler; ``train`` creates a
            ``CosineAnnealingLR`` when omitted.
        lr: Learning rate for the default optimizer.
        weight_decay: Weight decay for the default optimizer.
        grad_clip: Max gradient norm used for clipping.
        monitor_metric: Key of the ``evaluate`` result used to select the best
            model (higher is better, except ``'loss'`` which is minimised).
            ``'accuracy_score'`` is accepted as an alias of ``'accuracy'``.

    Attributes:
        train_losses: Per-epoch training loss recorded by the last ``train`` call.
        val_metrics_history: Per-epoch ``evaluate`` results of the last ``train`` call.
    """

    # Alternative spellings mapped to the keys actually returned by `evaluate`.
    _METRIC_ALIASES = {'accuracy_score': 'accuracy'}

    def __init__(self, model, device, optimizer=None, scheduler=None, lr=1e-3, weight_decay=5e-4, grad_clip=1.0, monitor_metric='accuracy_score'):
        self.model = model.to(device)
        self.device = device

        # Use AdamW optimizer.
        if optimizer is None:
            self.optimizer = AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            self.optimizer = optimizer

        # Scheduler will be set later if not provided.
        self.scheduler = scheduler

        # Use NLLLoss since the model is expected to output log_softmax values.
        self.criterion = nn.NLLLoss().to(device)

        # Gradient clipping threshold.
        self.grad_clip = grad_clip

        # Without the alias the default 'accuracy_score' never matched a key of
        # the metrics dict and model selection silently fell back to the loss.
        self.monitor_metric = self._METRIC_ALIASES.get(monitor_metric, monitor_metric)

        # Histories filled by `train`.
        self.train_losses = []
        self.val_metrics_history = []

    def train_step(self, data):
        """Run one optimisation step on a batch that is already on ``self.device``.

        Returns:
            The detached scalar loss tensor of this batch.
        """
        self.model.train()
        self.optimizer.zero_grad()
        out = self.model(data)
        loss = self.criterion(out, data.y.view(-1))
        loss.backward()

        # Apply gradient clipping.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
        self.optimizer.step()
        return loss.detach()

    def _monitor_value(self, metrics):
        """Return the scalar to maximise for model selection."""
        if self.monitor_metric == 'loss':
            return -metrics['loss']
        value = metrics.get(self.monitor_metric, None)
        if value is None:
            # Fallback if the requested metric is not available.
            return -metrics['loss']
        return value

    def train(self, train_loader, val_loader, num_epochs=50, patience=10, model_save_path='best_jet_gnn_model.pt'):
        """Train with early stopping and restore the best checkpoint.

        Args:
            train_loader: Loader yielding training batches.
            val_loader: Loader used for per-epoch evaluation.
            num_epochs: Maximum number of epochs.
            patience: Epochs without improvement before stopping.
            model_save_path: Where the best ``state_dict`` is written.

        Returns:
            ``(train_losses, val_metrics)``: per-epoch training loss and
            per-epoch ``evaluate`` results.
        """
        # Initialize CosineAnnealingLR scheduler if not provided.
        if self.scheduler is None:
            # T_max is the maximum number of iterations, here we set it to the number of epochs.
            self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs)

        best_metric = -float('inf')
        wait = 0
        saved_best = False
        self.train_losses = []
        self.val_metrics_history = []

        for epoch in range(num_epochs):
            running_loss = 0.0

            for data in train_loader:
                data = data.to(self.device)
                loss = self.train_step(data)
                # Weight by batch size so the epoch loss is a per-graph mean.
                running_loss += loss.item() * data.num_graphs

            # Step the cosine annealing scheduler at the end of the epoch.
            self.scheduler.step()

            epoch_train_loss = running_loss / len(train_loader.dataset)
            metrics = self.evaluate(val_loader)
            self.train_losses.append(epoch_train_loss)
            self.val_metrics_history.append(metrics)

            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {epoch_train_loss:.4f} | "
                  f"Val Loss: {metrics['loss']:.4f} | Accuracy: {metrics['accuracy']:.4f} | "
                  f"F1: {metrics['f1_score']:.4f} | AUC: {metrics.get('auc', 0):.4f} | LR: {current_lr:.6f}")

            # Use the specified monitor metric to track the best model.
            current_metric = self._monitor_value(metrics)

            if current_metric > best_metric:
                best_metric = current_metric
                torch.save(self.model.state_dict(), model_save_path)
                saved_best = True
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered.")
                    break

        # Only restore a checkpoint written by this run (none exists when no
        # epoch ran or the monitored value was never comparable, e.g. NaN).
        if saved_best:
            self.model.load_state_dict(
                torch.load(model_save_path, map_location=self.device, weights_only=True)
            )
        return self.train_losses, self.val_metrics_history

    def evaluate(self, loader):
        """Evaluate the model on ``loader``.

        Returns:
            Dict with ``loss``, ``accuracy``, macro ``precision`` / ``recall`` /
            ``f1_score``, ``confusion_matrix``, ``classification_report`` and
            ``auc`` (``0.0`` when ``roc_auc_score`` raises ``ValueError``, e.g.
            when only one class is present).
        """
        self.model.eval()
        y_true, y_pred, y_score = [], [], []
        loss_sum = 0.0

        with torch.no_grad():
            for data in loader:
                data = data.to(self.device)
                out = self.model(data)
                loss = self.criterion(out, data.y.view(-1))
                loss_sum += loss.item() * data.num_graphs

                preds = out.argmax(dim=1).cpu().numpy()
                y_true.extend(data.y.cpu().numpy())
                y_pred.extend(preds)
                # For binary classification, extract probability for class 1
                # (exp turns the log-probabilities back into probabilities).
                y_score.extend(out.exp()[:, 1].cpu().numpy())

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_score = np.array(y_score)

        avg_loss = loss_sum / len(loader.dataset)
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        f1 = f1_score(y_true, y_pred, average='macro')
        conf_matrix = confusion_matrix(y_true, y_pred)
        class_report = classification_report(y_true, y_pred)

        try:
            auc = roc_auc_score(y_true, y_score)
        except ValueError:
            # Best effort: AUC is undefined for this label set.
            auc = 0.0

        return {
            "loss": avg_loss,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report,
            "auc": auc
        }


In [15]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, roc_auc_score

In [None]:
import os
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader
import gc

# Configuration for training
# Every tunable of the training run lives in this one mapping.
config = dict(
    # --- data ---
    chunk_dir='/kaggle/input/genie-extracted-dataset',  # Directory containing all NPZ chunk files.
    chunk_pattern='chunk_*.npz',  # Pattern to match chunk files.
    batch_size=16,
    # --- model ---
    hidden_channels=128,
    num_layers=2,
    dropout=0.4,
    # --- optimisation ---
    learning_rate=0.001,
    weight_decay=5e-4,
    epochs=10,
    patience=8,
    test_size=0.2,  # 20% of all graphs will be used for testing.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    # --- architecture switches ---
    use_dynamic_graph=True,
    gat_heads=4,
    use_sagpool=True,
    sagpool_ratio=0.5,
    use_non_local=True,
)

print(f"Using device: {config['device']}")

# Helper function to load all chunk files in the specified directory.
def get_chunk_files(chunk_dir, pattern):
    """Return the paths in ``chunk_dir`` matching ``pattern`` in chunk order.

    Files are ordered by the numbers in their ``chunk_<a>_<b>`` name (via
    ``extract_chunk_numbers``, the same key ``process_all_chunks`` uses), so
    ``chunk_2_*`` comes before ``chunk_10_*``. Names without that pattern sort
    last; the path itself breaks ties to keep the order deterministic.

    Args:
        chunk_dir: Directory to search.
        pattern: Glob pattern relative to ``chunk_dir``.

    Returns:
        List of matching paths (empty when nothing matches).
    """
    import glob

    matches = glob.glob(os.path.join(chunk_dir, pattern))
    return sorted(
        matches,
        # basename: numbers in parent directory names must not affect the order
        key=lambda path: (extract_chunk_numbers(os.path.basename(path)), path),
    )

# End-to-end pipeline: build graphs once, split, train with early stopping, report.
def main():
    """Process all chunks, train ``EnhancedJetGNN`` and report held-out metrics.

    Steps:
      1. Convert every ``.npz`` chunk to graphs once and keep them in memory.
      2. Stratified split into train / validation / test. The validation
         fraction of the training part is ``config.get('val_size', 0.1)``.
      3. Train through ``EnhancedJetGNNTrainer.train``, which handles early
         stopping on validation macro-F1 and restores the best checkpoint.
      4. Evaluate once on the untouched test split and plot the history.

    Raises:
        FileNotFoundError: If no chunk file matches the configured pattern.
    """
    # Imported locally so this function does not depend on which of the
    # notebook's several `DataLoader` imports ran last.
    from torch_geometric.loader import DataLoader
    import matplotlib.pyplot as plt

    # Get all NPZ chunk files.
    chunk_files = get_chunk_files(config['chunk_dir'], config['chunk_pattern'])
    print(f"Found {len(chunk_files)} chunk files.")
    if not chunk_files:
        raise FileNotFoundError(
            f"No files matching {config['chunk_pattern']!r} in {config['chunk_dir']!r}."
        )

    # Process every chunk exactly once.
    all_graphs = []
    for file in chunk_files:
        print(f"Loading chunk: {file}")
        all_graphs.extend(JetGraphProcessor(file).process_chunk())
        gc.collect()
    print(f"Total graphs processed from all chunks: {len(all_graphs)}")

    # Create an in-memory dataset from all graphs.
    dataset = InMemoryJetGraphDataset(all_graphs)
    print(f"InMemoryJetGraphDataset created with {len(dataset)} graphs.")

    # Stratified train / validation / test indices.
    labels = [int(graph.y.item()) for graph in all_graphs]
    indices = list(range(len(dataset)))
    train_idx, test_idx = train_test_split(
        indices, test_size=config['test_size'], random_state=42, stratify=labels
    )
    train_idx, val_idx = train_test_split(
        train_idx,
        test_size=config.get('val_size', 0.1),
        random_state=42,
        stratify=[labels[i] for i in train_idx],
    )

    # Only the training subset is ever shown to the optimiser.
    train_loader = DataLoader(torch.utils.data.Subset(dataset, train_idx),
                              batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(torch.utils.data.Subset(dataset, val_idx),
                            batch_size=config['batch_size'], shuffle=False)
    test_loader = DataLoader(torch.utils.data.Subset(dataset, test_idx),
                             batch_size=config['batch_size'], shuffle=False)
    print(f"Train / val / test sizes: {len(train_idx)} / {len(val_idx)} / {len(test_idx)}")

    # Initialize the model using a sample graph from the full dataset.
    sample_data = dataset[0]
    node_dim = sample_data.x.size(1)
    global_dim = sample_data.global_features.size(1) if hasattr(sample_data, 'global_features') else 2

    if config.get('use_non_local'):
        # EnhancedJetGNN has no non-local block and accepts no such argument.
        print("Note: 'use_non_local' is set but not supported by EnhancedJetGNN; ignoring it.")

    model = EnhancedJetGNN(
        node_dim=node_dim,
        global_dim=global_dim,
        hidden_dim=config['hidden_channels'],
        out_channels=2,
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        use_dynamic_graph=config['use_dynamic_graph'],
        heads=config['gat_heads'],
        use_sagpool=config['use_sagpool'],
        sagpool_ratio=config['sagpool_ratio'],
    ).to(config['device'])

    # Initialize the trainer; the best model is selected on validation macro-F1.
    trainer = EnhancedJetGNNTrainer(
        model=model,
        device=config['device'],
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        monitor_metric='f1_score',
    )

    train_losses, val_history = trainer.train(
        train_loader,
        val_loader,
        num_epochs=config['epochs'],
        patience=config['patience'],
        model_save_path='best_jet_gnn_model.pt',
    )

    # Final evaluation on test set (best checkpoint is already restored).
    final_metrics = trainer.evaluate(test_loader)
    print("\nFinal Test Set Metrics:")
    print(f"Accuracy:  {final_metrics['accuracy']:.4f}")
    print(f"Precision: {final_metrics['precision']:.4f}")
    print(f"Recall:    {final_metrics['recall']:.4f}")
    print(f"F1 Score:  {final_metrics['f1_score']:.4f}")
    print(f"AUC:       {final_metrics['auc']:.4f}")
    print("Confusion Matrix:")
    print(final_metrics['confusion_matrix'])
    print("Classification Report:")
    print(final_metrics['classification_report'])

    # Plotting training loss and validation metrics.
    fig, (ax_loss, ax_val) = plt.subplots(1, 2, figsize=(12, 5))
    ax_loss.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training Loss')
    ax_loss.legend()

    if val_history:
        epochs_plot = range(1, len(val_history) + 1)
        ax_val.plot(epochs_plot, [m['accuracy'] for m in val_history], label='Accuracy')
        ax_val.plot(epochs_plot, [m['f1_score'] for m in val_history], label='F1 Score')
        ax_val.set_xlabel('Epoch')
        ax_val.set_ylabel('Score')
        ax_val.set_title('Validation Metrics')
        ax_val.legend()
    else:
        print("Validation metrics history not available for plotting.")

    fig.tight_layout()
    plt.show()

# Run the pipeline when this cell/script is executed directly; importing the
# module elsewhere does not trigger training.
if __name__ == '__main__':
    main()


Using device: cuda
Processing file: /kaggle/input/chunk_0_10000.npz


Processing single chunk:  73%|███████▎  | 7311/10000 [03:07<01:05, 40.81it/s]