In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import (
    global_mean_pool, global_max_pool, global_add_pool,
    GATConv, EdgeConv, SAGPooling
)
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.transforms import BaseTransform

import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [3]:
class RandomPhiShift(BaseTransform):
    """Augmentation that shifts column 1 of ``data.pos`` by one random offset.

    A single offset is drawn uniformly from ``[-dphi, +dphi]`` per call and
    added to every node's second position coordinate; the result is wrapped
    back into ``[-pi, pi]`` via ``atan2(sin, cos)``.

    Only ``data.pos`` is modified (in place). Node features in ``data.x`` and
    any precomputed edge attributes are left untouched.

    Args:
        dphi: Half-width of the uniform interval the offset is drawn from.
    """

    def __init__(self, dphi=0.2):
        self.dphi = dphi

    def __call__(self, data):
        """Shift ``data.pos[:, 1]`` in place and return the same ``data`` object."""
        pos = getattr(data, 'pos', None)
        # Nothing to shift for graphs without positions or without nodes.
        if pos is None or pos.size(0) == 0:
            return data
        # Draw the offset on the same device as ``pos``: a 1-element CPU tensor
        # cannot be added to a tensor living on another device.
        shift = (torch.rand(1, device=pos.device) * 2 - 1) * self.dphi
        phi = pos[:, 1] + shift
        pos[:, 1] = torch.atan2(torch.sin(phi), torch.cos(phi))
        return data

In [4]:
class DynamicHybridGraph(nn.Module):
    """Builds a per-graph k-nearest-neighbour graph from positions and embeddings.

    For every graph in the batch the pairwise distance is
    ``alpha * cdist(pos, pos) + beta * cdist(x, x)``; each node is connected to
    its ``k`` closest *other* nodes of the same graph.

    Args:
        k: Number of neighbours per node (capped at ``num_nodes - 1``).
        alpha: Weight of the distance computed from ``pos``.
        beta: Weight of the distance computed from ``x``.
    """

    def __init__(self, k=4, alpha=0.5, beta=0.5):
        super().__init__()
        self.k = k
        self.alpha = alpha
        self.beta = beta

    def forward(self, x, pos, batch):
        """Compute edges and edge attributes for the whole batch.

        Args:
            x: Node embeddings, one row per node.
            pos: Node coordinates; column 0 and column 1 are used for the
                edge attributes, column 1 differences are wrapped to [-pi, pi].
            batch: Long tensor assigning every node to a graph id.

        Returns:
            ``(edge_index, edge_attr)`` where ``edge_index`` has shape
            ``(2, E)`` in batch-global node indices (row 0 = centre node,
            row 1 = neighbour) and ``edge_attr`` has shape ``(E, 4)`` with
            columns ``[d_col0, d_col1 (wrapped), sqrt(d_col0^2 + d_col1^2),
            ||x_neighbour - x_centre||_2]``.
        """
        device = x.device
        # An empty batch has no maximum; treat it as zero graphs.
        num_graphs = int(batch.max().item()) + 1 if batch.numel() > 0 else 0
        edge_index_list = []
        edge_attr_list = []

        for graph_id in range(num_graphs):
            # Global indices of this graph's nodes. Mapping local -> global
            # through this tensor is correct even when the nodes of a graph
            # are not stored contiguously in the batch.
            node_ids = (batch == graph_id).nonzero(as_tuple=True)[0]
            num_nodes = node_ids.numel()
            if num_nodes <= 1:
                continue
            x_graph = x[node_ids]
            pos_graph = pos[node_ids]

            # Neighbour selection only yields indices, so no gradient is needed.
            with torch.no_grad():
                dist = (self.alpha * torch.cdist(pos_graph, pos_graph)
                        + self.beta * torch.cdist(x_graph, x_graph))
                # Exclude self explicitly: with tied distances (e.g. duplicate
                # points) "drop the first top-k column" may drop a real
                # neighbour and keep a self-loop.
                dist.fill_diagonal_(float('inf'))
                k_eff = min(self.k, num_nodes - 1)
                neighbours = torch.topk(dist, k=k_eff, largest=False, dim=1).indices

            source_local = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
            target_local = neighbours.reshape(-1)

            delta_eta = pos_graph[target_local, 0] - pos_graph[source_local, 0]
            delta_phi = pos_graph[target_local, 1] - pos_graph[source_local, 1]
            delta_phi = torch.atan2(torch.sin(delta_phi), torch.cos(delta_phi))
            delta_r = torch.sqrt(delta_eta ** 2 + delta_phi ** 2)
            # This column stays differentiable with respect to ``x``.
            delta_embed_norm = torch.norm(
                x_graph[target_local] - x_graph[source_local], p=2, dim=1, keepdim=True
            )
            edge_attr = torch.cat([
                delta_eta.unsqueeze(1),
                delta_phi.unsqueeze(1),
                delta_r.unsqueeze(1),
                delta_embed_norm
            ], dim=1)

            edge_index_list.append(
                torch.stack([node_ids[source_local], node_ids[target_local]], dim=0)
            )
            edge_attr_list.append(edge_attr)

        if not edge_index_list:
            return (torch.zeros((2, 0), device=device, dtype=torch.long),
                    torch.zeros((0, 4), device=device, dtype=x.dtype))
        edge_index = torch.cat(edge_index_list, dim=1)
        edge_attr = torch.cat(edge_attr_list, dim=0)
        return edge_index, edge_attr

In [5]:
def create_multiscale_knn(pos, k1=4, k2=6):
    """Build two k-nearest-neighbour graphs (small and large scale) over ``pos``.

    Distances are Euclidean over all columns of ``pos``. The distance matrix
    and the neighbour ranking are computed once and shared by both scales.

    Args:
        pos: ``(n, d)`` float tensor of node coordinates; columns 0 and 1 are
            used for the edge attributes.
        k1: Neighbours per node for the first (small-scale) graph.
        k2: Neighbours per node for the second (large-scale) graph.

    Returns:
        ``(edge_index_s, edge_attr_s, edge_index_l, edge_attr_l)``. Each
        ``edge_index`` is ``(2, n * min(k, n - 1))`` with row 0 the centre
        node and row 1 its neighbour. Each ``edge_attr`` has columns
        ``[d_col0, d_col1 wrapped to [-pi, pi], sqrt(d_col0^2 + d_col1^2)]``.
        For ``n <= 1`` both graphs are empty (``(2, 0)`` and ``(0, 3)``).
    """
    n = pos.size(0)
    device = pos.device

    if n <= 1:
        def empty_graph():
            return (torch.zeros((2, 0), dtype=torch.long, device=device),
                    torch.zeros((0, 3), dtype=torch.float, device=device))
        eidx_s, eattr_s = empty_graph()
        eidx_l, eattr_l = empty_graph()
        return eidx_s, eattr_s, eidx_l, eattr_l

    with torch.no_grad():
        dist = torch.cdist(pos, pos)
        # Mask self-distances instead of assuming self is ranked first:
        # with tied (zero) distances that assumption can leave self-loops.
        dist.fill_diagonal_(float('inf'))
        k_max = min(max(k1, k2), n - 1)
        # topk returns neighbours sorted by increasing distance, so the first
        # k columns of the k_max ranking are the k nearest neighbours.
        _, ranked = torch.topk(dist, k=k_max, dim=1, largest=False)

    def knn_graph(k):
        k_eff = min(k, n - 1)
        source_nodes = torch.arange(n, device=device).repeat_interleave(k_eff)
        target_nodes = ranked[:, :k_eff].reshape(-1)
        edge_index = torch.stack([source_nodes, target_nodes], dim=0)
        delta_eta = pos[target_nodes, 0] - pos[source_nodes, 0]
        delta_phi = pos[target_nodes, 1] - pos[source_nodes, 1]
        delta_phi = torch.atan2(torch.sin(delta_phi), torch.cos(delta_phi))
        delta_r = torch.sqrt(delta_eta ** 2 + delta_phi ** 2)
        edge_attr = torch.stack([delta_eta, delta_phi, delta_r], dim=1)
        return edge_index, edge_attr

    eidx_s, eattr_s = knn_graph(k1)
    eidx_l, eattr_l = knn_graph(k2)
    return eidx_s, eattr_s, eidx_l, eattr_l

In [6]:
class JetGraphProcessor:
    """Converts chunked jet images (``.npz``) into lists of graph ``Data`` objects.

    Args:
        input_dir: Directory containing the ``.npz`` chunk files.
        output_dir: Directory the processed ``.pt`` files are written to
            (created if missing).
        transform: Optional callable applied once to every graph before it is
            saved. NOTE(review): a random transform passed here is baked into
            the saved file rather than re-sampled per epoch.
        k: Neighbours per node for the small-scale graph; the large-scale
            graph uses ``2 * k``.
        min_energy_threshold: A pixel becomes a node only if its channel-summed
            value exceeds this threshold.
    """

    # Pixel indices are mapped to coordinates in [-COORD_HALF_RANGE, COORD_HALF_RANGE).
    COORD_HALF_RANGE = 0.8

    def __init__(self, input_dir, output_dir, transform=None, k=8, min_energy_threshold=1e-4):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.transform = transform
        self.k = k
        self.min_energy_threshold = min_energy_threshold
        os.makedirs(output_dir, exist_ok=True)

    def _convert_to_graph(self, jet_image, label, m0, pt):
        """Turn one ``(H, W, C)`` image into a graph with one node per active pixel.

        Node features (13 columns, in order): total energy, channel 0,
        channel 1, channel 2, total / pt, channel 2 / total, 3x3 neighbourhood
        sum, pixel_eta, pixel_phi, log1p(total), sqrt(total), angle to the
        image centre, normalised distance to the image centre.
        ``total`` is the sum of channels 0-2. The code names these channels
        ecal / hcal / tracks.

        Returns:
            ``Data`` with ``x``, ``pos`` (pixel_eta, pixel_phi), small-scale
            ``edge_index`` / ``edge_attr``, large-scale ``edge_index_large`` /
            ``edge_attr_large``, ``y`` and ``global_features`` = ``[[m0, pt]]``.
            An image with no active pixel yields a single all-zero node.
        """
        # Derive the geometry from the image itself (125 x 125 reproduces the
        # previously hard-coded 125.0 / 62.5 constants exactly).
        height, width = jet_image.shape[0], jet_image.shape[1]
        center_row, center_col = height / 2.0, width / 2.0
        center_scale = max(center_row, center_col)
        half_range = self.COORD_HALF_RANGE

        jet_image_2d = np.sum(jet_image, axis=2)
        # Zero-pad by one pixel so the 3x3 window around (i, j) is padded[i:i+3, j:j+3].
        padded = np.pad(jet_image_2d, pad_width=1, mode='constant', constant_values=0)
        non_zero_indices = np.where(jet_image_2d > self.min_energy_threshold)
        points = []
        features = []
        for i, j in zip(non_zero_indices[0], non_zero_indices[1]):
            pixel_eta = (i / float(height) * 2 - 1) * half_range
            pixel_phi = (j / float(width) * 2 - 1) * half_range
            energy_ecal = jet_image[i, j, 0]
            energy_hcal = jet_image[i, j, 1]
            energy_tracks = jet_image[i, j, 2]
            total_energy = energy_ecal + energy_hcal + energy_tracks
            pt_fraction = total_energy / (pt + 1e-9)
            charged_fraction = energy_tracks / total_energy if total_energy > 0 else 0.0
            local_sum = np.sum(padded[i:i+3, j:j+3])
            angle_center = np.arctan2(j - center_col, i - center_row)
            norm_dist_center = np.sqrt((i - center_row)**2 + (j - center_col)**2) / center_scale
            features.append([
                total_energy, energy_ecal, energy_hcal, energy_tracks,
                pt_fraction, charged_fraction, local_sum,
                pixel_eta, pixel_phi, np.log1p(total_energy),
                np.sqrt(total_energy), angle_center, norm_dist_center
            ])
            points.append([pixel_eta, pixel_phi])
        if len(points) == 0:
            # Keep the graph non-empty so downstream batching/pooling still works.
            points = [[0, 0]]
            features = [[0.0]*13]
        x = torch.tensor(features, dtype=torch.float)
        pos = torch.tensor(points, dtype=torch.float)
        # Honour the configured k (the default k=8 gives the former 8 / 16).
        eidx_s, eattr_s, eidx_l, eattr_l = create_multiscale_knn(pos, k1=self.k, k2=2 * self.k)
        data = Data(
            x=x,
            pos=pos,
            edge_index=eidx_s,
            edge_attr=eattr_s,
            y=torch.tensor([label], dtype=torch.long)
        )
        data.edge_index_large = eidx_l
        data.edge_attr_large = eattr_l
        data.global_features = torch.tensor([m0, pt], dtype=torch.float).unsqueeze(0)
        return data

    def process_chunks(self, chunk_files):
        """Convert every chunk file and save one ``processed_<name>.pt`` per chunk.

        Each ``.npz`` must provide the arrays ``X_jets``, ``y``, ``m0`` and ``pt``.

        Returns:
            Total number of graphs written.
        """
        total_graphs = 0
        for chunk_file in chunk_files:
            print(f"Processing {chunk_file} ...")
            # Close the archive as soon as the arrays have been read.
            with np.load(os.path.join(self.input_dir, chunk_file)) as data_npz:
                X_jets = data_npz['X_jets']
                y = data_npz['y']
                m0 = data_npz['m0']
                pt = data_npz['pt']
            graph_list = []
            for i in tqdm(range(len(X_jets)), desc=f"Processing {chunk_file}"):
                graph = self._convert_to_graph(X_jets[i], int(y[i]), m0[i], pt[i])
                if self.transform:
                    graph = self.transform(graph)
                graph_list.append(graph)
            output_file = os.path.join(self.output_dir, f"processed_{chunk_file.replace('.npz', '.pt')}")
            torch.save(graph_list, output_file)
            total_graphs += len(graph_list)
        print(f"Processed {total_graphs} graphs from {len(chunk_files)} chunks.")
        return total_graphs


In [7]:
class ChunkedJetGraphDataset(Dataset):
    """Map-style dataset over graphs stored as per-chunk ``.pt`` list files.

    Every file in ``processed_dir`` whose name starts with ``processed_chunk_``
    is treated as one chunk (a saved list of graphs). Chunks are concatenated
    in sorted file-name order.

    Args:
        processed_dir: Directory holding the chunk files.
        transform: Optional callable applied to a *copy* of each returned graph.
        max_cached_chunks: Number of most-recently-used chunks kept in memory
            (per process). Higher values trade memory for fewer disk reads
            under shuffled access.
    """

    def __init__(self, processed_dir, transform=None, max_cached_chunks=4):
        from collections import OrderedDict

        super().__init__()
        self.processed_dir = processed_dir
        self.transform = transform
        self.max_cached_chunks = max(1, int(max_cached_chunks))
        # file name -> loaded list of graphs, ordered from least to most recently used.
        self._chunk_cache = OrderedDict()
        self.processed_files = sorted([f for f in os.listdir(processed_dir) if f.startswith('processed_chunk_')])
        self._create_index_mapping()

    def _read_chunk(self, file):
        """Deserialize one chunk file from disk (no caching)."""
        # weights_only=False: the chunks are pickled Python graph objects, which
        # the restricted loader rejects. Pickle can execute arbitrary code, so
        # only point this dataset at files produced by this pipeline.
        return torch.load(os.path.join(self.processed_dir, file), weights_only=False)

    def _load_chunk(self, file):
        """Return the graphs of ``file``, reading from disk only on a cache miss."""
        cache = self._chunk_cache
        if file in cache:
            cache.move_to_end(file)
            return cache[file]
        graphs = self._read_chunk(file)
        cache[file] = graphs
        while len(cache) > self.max_cached_chunks:
            cache.popitem(last=False)
        return graphs

    def _create_index_mapping(self):
        """Record ``(start, end, file)`` global index ranges for every readable chunk."""
        self.chunk_sizes = []
        self.chunk_indices = []
        start_idx = 0
        for file in self.processed_files:
            file_path = os.path.join(self.processed_dir, file)
            if os.path.exists(file_path):
                try:
                    graphs = self._read_chunk(file)
                    chunk_size = len(graphs)
                    self.chunk_sizes.append(chunk_size)
                    end_idx = start_idx + chunk_size
                    self.chunk_indices.append((start_idx, end_idx, file))
                    start_idx = end_idx
                    del graphs
                except Exception as e:
                    # Best effort: an unreadable chunk is reported and skipped.
                    print(f"Error loading {file}: {e}")
        self.total_size = sum(self.chunk_sizes)

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        """Return graph ``idx`` (negative indices count from the end).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        lookup = idx + self.total_size if idx < 0 else idx
        for start_idx, end_idx, file in self.chunk_indices:
            if start_idx <= lookup < end_idx:
                graph = self._load_chunk(file)[lookup - start_idx]
                if self.transform:
                    import copy
                    # Transforms may mutate their input; never hand them the
                    # cached object or the change would persist across reads.
                    graph = self.transform(copy.deepcopy(graph))
                return graph
        raise IndexError(f"Index {idx} out of range")

In [8]:
class CoordinateEmbedding(nn.Module):
    """Small two-layer perceptron that embeds raw coordinates.

    Architecture: ``Linear(in_dim, out_dim) -> ReLU -> Linear(out_dim, out_dim)``.

    Args:
        in_dim: Number of input coordinate columns.
        out_dim: Width of the hidden layer and of the returned embedding.
    """

    def __init__(self, in_dim=2, out_dim=4):
        super().__init__()
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, coords):
        """Map ``(N, in_dim)`` coordinates to ``(N, out_dim)`` embeddings."""
        embedded = self.mlp(coords)
        return embedded

In [None]:
class EnhancedJetGNN(nn.Module):
    """Graph classifier alternating GAT and EdgeConv layers with global features.

    Pipeline: node features are split into coordinates (columns 7 and 8) and
    the remaining columns, embedded separately and fused; the fused embedding
    passes through ``num_layers`` graph convolutions (layer 0 and every even
    layer are ``GATConv``, odd layers are ``EdgeConv``), each followed by batch
    normalisation, ReLU and dropout; optional ``SAGPooling``; mean / max / sum
    readouts concatenated with the encoded per-graph global features; MLP
    classifier returning log-probabilities.

    Args:
        node_dim: Number of node feature columns (must be at least 9 because
            columns 7 and 8 are read as coordinates).
        global_dim: Number of per-graph global feature columns.
        hidden_dim: Hidden width; must be divisible by ``heads``.
        out_channels: Number of classes.
        num_layers: Number of graph convolution layers.
        dropout: Dropout probability used throughout.
        use_dynamic_graph: If True, edges are rebuilt before every layer by
            ``DynamicHybridGraph`` (4 edge-attribute columns); otherwise the
            edges stored on the input are used.
        heads: Number of attention heads per ``GATConv``.
        use_sagpool: Whether to apply ``SAGPooling`` before the readout.
        sagpool_ratio: Fraction of nodes kept by ``SAGPooling``.
        edge_dim: Edge-attribute width expected by the ``GATConv`` layers.
            Defaults to 4 with the dynamic graph and to 3 otherwise (the width
            of the edge attributes produced by ``create_multiscale_knn``).

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``heads`` or
            ``node_dim`` is smaller than 9.
    """

    def __init__(self, node_dim, global_dim, hidden_dim=64, out_channels=2, 
                 num_layers=3, dropout=0.3, use_dynamic_graph=True, heads=4, 
                 use_sagpool=True, sagpool_ratio=0.5, edge_dim=None):
        super().__init__()
        if hidden_dim % heads != 0:
            # GATConv concatenates `heads` outputs of width hidden_dim // heads;
            # a remainder would silently change the layer width.
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads}).")
        if node_dim < 9:
            raise ValueError(f"node_dim ({node_dim}) must be at least 9: columns 7 and 8 hold the coordinates.")
        if edge_dim is None:
            edge_dim = 4 if use_dynamic_graph else 3

        self.node_dim = node_dim
        self.global_dim = global_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_dynamic_graph = use_dynamic_graph
        self.heads = heads
        self.use_sagpool = use_sagpool
        self.sagpool_ratio = sagpool_ratio
        self.edge_dim = edge_dim
        
        # Coordinate embedding for (η, φ) at indices 7 and 8
        self.coord_embed = CoordinateEmbedding(in_dim=2, out_dim=4)
        # Process the remaining node features
        self.feature_mlp = nn.Sequential(
            nn.Linear(node_dim - 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        # Combine coordinate embedding with processed features
        self.comb_mlp = nn.Sequential(
            nn.Linear(hidden_dim + 4, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Global feature encoder
        self.global_encoder = nn.Sequential(
            nn.Linear(global_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Dynamic adjacency module
        if use_dynamic_graph:
            self.dynamic_graph = DynamicHybridGraph(k=8, alpha=0.5, beta=0.5)
        # GNN layers (alternating GAT and EdgeConv)
        self.gnn_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for i in range(num_layers):
            if i % 2 == 1:
                edge_mlp = nn.Sequential(
                    nn.Linear(hidden_dim*2, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim, hidden_dim)
                )
                self.gnn_layers.append(EdgeConv(nn=edge_mlp, aggr='mean'))
            else:
                self.gnn_layers.append(
                    GATConv(hidden_dim, hidden_dim // heads, heads=heads,
                            edge_dim=edge_dim, dropout=dropout)
                )
            # nn.BatchNorm1d normalises (num_nodes, hidden_dim) node features;
            # the previously referenced `BatchNorm` name was never imported.
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # Optional SAGPooling for hierarchical pooling
        if self.use_sagpool:
            self.sagpool = SAGPooling(hidden_dim, ratio=self.sagpool_ratio)
        # Final classifier: multi-scale pooling (mean, max, sum) combined with global features
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim*3 + hidden_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_channels)
        )
    
    def forward(self, data):
        """Classify every graph in ``data``.

        Args:
            data: Batched graph object providing ``x``, ``pos``, ``edge_index``,
                ``edge_attr`` and optionally ``batch`` and ``global_features``
                (one row per graph).

        Returns:
            ``(num_graphs, out_channels)`` tensor of log-probabilities.
        """
        x, batch = data.x, data.batch
        edge_index, edge_attr = data.edge_index, data.edge_attr
        pos = data.pos
        # Resolve `batch` first: the global-feature fallback below needs it.
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        global_features = getattr(data, 'global_features', None)
        if global_features is None:
            num_graphs = int(batch.max().item()) + 1 if batch.numel() > 0 else 0
            global_features = torch.zeros((num_graphs, self.global_dim), device=x.device)
        
        # Separate coordinate features (assumed at indices 7 and 8)
        coords = x[:, [7, 8]]
        other_feats = torch.cat([x[:, :7], x[:, 9:]], dim=1)
        coord_emb = self.coord_embed(coords)
        feat_emb = self.feature_mlp(other_feats)
        x = torch.cat([feat_emb, coord_emb], dim=1)
        x = self.comb_mlp(x)
        global_x = self.global_encoder(global_features)
        
        for i, conv in enumerate(self.gnn_layers):
            if self.use_dynamic_graph:
                # Rebuild the neighbourhood from the current embeddings.
                edge_index, edge_attr = self.dynamic_graph(x, pos, batch)
            if isinstance(conv, GATConv):
                x = conv(x, edge_index, edge_attr=edge_attr)
            else:
                x = conv(x, edge_index)
            x = self.batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        if self.use_sagpool:
            x, edge_index, edge_attr, batch, _, _ = self.sagpool(x, edge_index, edge_attr, batch=batch)
        pooled_mean = global_mean_pool(x, batch)
        pooled_max = global_max_pool(x, batch)
        pooled_sum = global_add_pool(x, batch)
        combined = torch.cat([pooled_mean, pooled_max, pooled_sum, global_x], dim=1)
        out = self.classifier(combined)
        return F.log_softmax(out, dim=-1)

In [10]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed targets, for log-probability inputs.

    The target distribution puts ``1 - smoothing`` on the true class and
    spreads ``smoothing`` evenly over the remaining ``num_classes - 1`` classes.

    Args:
        smoothing: Probability mass moved away from the true class.
        reduction: ``'mean'``, ``'sum'``, or anything else for per-sample losses.
    """

    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, pred, target):
        """Compute the loss.

        Args:
            pred: ``(N, C)`` tensor that is multiplied directly with the target
                distribution, i.e. it is expected to hold log-probabilities.
            target: ``(N,)`` long tensor of class indices.
        """
        num_classes = pred.size(1)
        off_value = self.smoothing / (num_classes - 1)
        on_value = 1.0 - self.smoothing
        # The soft targets are constants; keep them out of the autograd graph.
        with torch.no_grad():
            soft_targets = torch.full_like(pred, off_value)
            soft_targets.scatter_(1, target.unsqueeze(1), on_value)
        per_sample = (-soft_targets * pred).sum(dim=1)
        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample

In [11]:
class EnhancedJetGNNTrainer:
    """Training / evaluation loop with Adam, label smoothing and early stopping.

    Args:
        model: Module mapping a batch to ``(num_graphs, 2)`` log-probabilities.
        device: Device the model and every batch are moved to.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        smoothing: Label-smoothing factor of the loss.
    """

    def __init__(self, model, device, lr=1e-3, weight_decay=5e-4, smoothing=0.1):
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.criterion = LabelSmoothingLoss(smoothing=smoothing).to(device)
    
    def train(self, train_loader, val_loader, num_epochs=50, patience=10, model_save_path='best_jet_gnn_model.pt'):
        """Train with early stopping on the validation loss.

        The weights with the lowest validation loss are written to
        ``model_save_path`` and restored into the model before returning.

        Returns:
            ``(train_losses, val_metrics)``: per-epoch mean training loss and
            per-epoch dicts with ``loss``, ``accuracy``, ``f1`` and ``auc``.
        """
        best_val_loss = float('inf')
        checkpoint_saved = False
        wait = 0
        train_losses = []
        val_metrics = []
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            graphs_seen = 0
            for data in train_loader:
                data = data.to(self.device)
                self.optimizer.zero_grad()
                out = self.model(data)
                loss = self.criterion(out, data.y.view(-1))
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item() * data.num_graphs
                graphs_seen += data.num_graphs
            # Normalise by the graphs actually seen so the mean stays correct
            # with samplers or drop_last.
            epoch_train_loss = running_loss / graphs_seen if graphs_seen else float('nan')
            val_loss, val_acc, val_f1, val_auc = self.evaluate(val_loader)
            train_losses.append(epoch_train_loss)
            val_metrics.append({'loss': val_loss, 'accuracy': val_acc, 'f1': val_f1, 'auc': val_auc})
            print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {epoch_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), model_save_path)
                checkpoint_saved = True
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print("Early stopping triggered.")
                    break
        # Restore only a checkpoint written by THIS run: if the validation loss
        # never improved (e.g. NaN) the file is missing or stale from an
        # earlier run.
        if checkpoint_saved:
            self.model.load_state_dict(torch.load(model_save_path, map_location=self.device))
        return train_losses, val_metrics
    
    def evaluate(self, loader, return_predictions=False):
        """Evaluate the model on ``loader``.

        Returns:
            ``(loss, accuracy, f1, auc)`` or, with ``return_predictions=True``,
            ``(metrics_dict, (y_true, y_pred, y_score))`` where ``y_score`` is
            the predicted probability of class 1. AUC is reported as ``0.0``
            when it cannot be computed. An empty loader yields a NaN loss and
            zero for the other metrics.
        """
        self.model.eval()
        correct = 0
        total = 0
        y_true, y_pred, y_score = [], [], []
        loss_sum = 0.0
        with torch.no_grad():
            for data in loader:
                data = data.to(self.device)
                out = self.model(data)
                loss = self.criterion(out, data.y.view(-1))
                loss_sum += loss.item() * data.num_graphs
                pred = out.argmax(dim=-1)
                correct += pred.eq(data.y.view(-1)).sum().item()
                total += data.num_graphs
                y_true.append(data.y.view(-1).cpu().numpy())
                y_pred.append(pred.cpu().numpy())
                # The model returns log-probabilities; exp() recovers P(class 1).
                prob_class1 = out.exp()[:, 1].cpu().numpy()
                y_score.append(prob_class1)

        if total == 0:
            # np.concatenate rejects an empty list; report neutral metrics.
            empty = np.array([])
            if return_predictions:
                return {'loss': float('nan'), 'accuracy': 0, 'f1': 0.0, 'auc': 0.0}, (empty, empty, empty)
            return float('nan'), 0, 0.0, 0.0

        y_true = np.concatenate(y_true, axis=0)
        y_pred = np.concatenate(y_pred, axis=0)
        y_score = np.concatenate(y_score, axis=0)
        val_loss = loss_sum / total
        accuracy = correct / total
        from sklearn.metrics import f1_score, roc_auc_score
        f1 = f1_score(y_true, y_pred, average='binary')
        try:
            auc = roc_auc_score(y_true, y_score)
        except ValueError:
            # Raised when AUC is undefined, e.g. only one class is present.
            auc = 0.0
        if return_predictions:
            return {'loss': val_loss, 'accuracy': accuracy, 'f1': f1, 'auc': auc}, (y_true, y_pred, y_score)
        else:
            return val_loss, accuracy, f1, auc

In [12]:
def main():
    """End-to-end pipeline: preprocess chunks, train, evaluate, run inference, plot.

    All ``.npz`` chunks in ``input_dir`` except the skipped ones are used. The
    last chunk (in sorted file-name order) is held out and only used for a
    final inference pass; the others are converted to graphs, cached under
    ``processed_dir`` and split into train / validation / test subsets.
    """
    # PyG's loader collates `Data` objects into a `Batch` (with `.batch` and
    # `.num_graphs`); torch.utils.data.DataLoader's default collate cannot.
    from torch_geometric.loader import DataLoader as GraphDataLoader

    config = {
        'input_dir': '/kaggle/input/genie-extracted-dataset',
        'processed_dir': '/kaggle/working/processed_jet_graphs',
        'batch_size': 64,
        'hidden_channels': 64,
        'num_layers': 3,
        'dropout': 0.3,
        'learning_rate': 0.001,
        'weight_decay': 5e-4,
        'epochs': 50,
        'patience': 10,
        'test_size': 0.1,
        'val_size': 0.1,
        'use_dynamic_graph': True,
        'gat_heads': 4,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'use_sagpool': True,
        'sagpool_ratio': 0.5,
        'seed': 42
    }
    
    print(f"Using device: {config['device']}")

    # Seed the generators used for weight init, shuffling and dropout.
    np.random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    
    # Define the chunks to skip.
    skipped_chunks = {
        "chunk_0_10000.npz",
        "chunk_100000_110000.npz",
        "chunk_10000_20000.npz",
        "chunk_110000_120000.npz"
    }
    
    # List all .npz chunk files from the input directory, skipping the ones we don't want.
    all_chunk_files = sorted([f for f in os.listdir(config['input_dir']) 
                               if f.endswith('.npz') and f not in skipped_chunks])
    
    if len(all_chunk_files) < 2:
        raise ValueError("Need at least 2 chunk files for training and inference separation.")

    train_chunk_files = all_chunk_files[:-1]
    last_chunk_file = all_chunk_files[-1]

    def processed_name(chunk_file):
        # Must match the output naming used by JetGraphProcessor.process_chunks.
        return f"processed_{chunk_file.replace('.npz', '.pt')}"
    
    os.makedirs(config['processed_dir'], exist_ok=True)
    # Process only the chunks whose output is missing, so a run interrupted
    # half-way resumes instead of silently training on a partial dataset.
    pending_chunks = [
        f for f in train_chunk_files
        if not os.path.exists(os.path.join(config['processed_dir'], processed_name(f)))
    ]
    if pending_chunks:
        print("Processing data into graph representations for training...")
        # NOTE(review): the random phi shift is applied once here and saved
        # with the graphs; it is not re-sampled per epoch.
        processor = JetGraphProcessor(
            config['input_dir'],
            config['processed_dir'],
            transform=RandomPhiShift(dphi=0.2),
            k=8,
            min_energy_threshold=1e-4
        )
        processor.process_chunks(pending_chunks)
    else:
        print(f"Using existing processed graphs from {config['processed_dir']}")

    # The dataset picks up EVERY 'processed_chunk_' file in the directory, so
    # leftovers from earlier runs (skipped or held-out chunks) would leak in.
    expected_files = {processed_name(f) for f in train_chunk_files}
    unexpected_files = sorted(
        f for f in os.listdir(config['processed_dir'])
        if f.startswith('processed_chunk_') and f not in expected_files
    )
    if unexpected_files:
        print(f"Warning: {len(unexpected_files)} unexpected processed file(s) will be included "
              f"in the dataset: {unexpected_files}")
    
    # Create dataset from the processed (saved) chunks.
    dataset = ChunkedJetGraphDataset(config['processed_dir'])
    print(f"Total dataset size: {len(dataset)}")
    
    indices = list(range(len(dataset)))
    train_idx, test_idx = train_test_split(indices, test_size=config['test_size'], random_state=config['seed'])
    train_idx, val_idx = train_test_split(train_idx, test_size=config['val_size']/(1 - config['test_size']), random_state=config['seed'])
    
    train_dataset = torch.utils.data.Subset(dataset, train_idx)
    val_dataset   = torch.utils.data.Subset(dataset, val_idx)
    test_dataset  = torch.utils.data.Subset(dataset, test_idx)
    
    print(f"Train size: {len(train_dataset)}")
    print(f"Validation size: {len(val_dataset)}")
    print(f"Test size: {len(test_dataset)}")
    
    train_loader = GraphDataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader   = GraphDataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)
    test_loader  = GraphDataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)
    
    sample_data = dataset[0]
    node_dim = sample_data.x.size(1)
    global_dim = sample_data.global_features.size(1) if hasattr(sample_data, 'global_features') else 2
    
    model = EnhancedJetGNN(
        node_dim=node_dim,
        global_dim=global_dim,
        hidden_dim=config['hidden_channels'],
        out_channels=2,
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        use_dynamic_graph=config['use_dynamic_graph'],
        heads=config['gat_heads'],
        use_sagpool=config['use_sagpool'],
        sagpool_ratio=config['sagpool_ratio']
    ).to(config['device'])
    
    trainer = EnhancedJetGNNTrainer(
        model=model,
        device=config['device'],
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        smoothing=0.1
    )
    
    # Train; the weights with the best validation loss are saved and restored.
    train_losses, val_metrics = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['epochs'],
        patience=config['patience'],
        model_save_path='best_jet_gnn_model.pt'
    )
    
    test_metrics, (y_true, y_pred, y_score) = trainer.evaluate(test_loader, return_predictions=True)
    print("\nTest Set Metrics:")
    print(f"Accuracy:  {test_metrics['accuracy']:.4f}")
    print(f"F1 Score:  {test_metrics['f1']:.4f}")
    print(f"AUC:       {test_metrics['auc']:.4f}")
    
    # Inference on the last chunk (which was not saved)
    print(f"Running inference on the last chunk: {last_chunk_file}")
    with np.load(os.path.join(config['input_dir'], last_chunk_file)) as data_npz:
        X_jets = data_npz['X_jets']
        y_last = data_npz['y']
        m0_last = data_npz['m0']
        pt_last = data_npz['pt']
    # No random transform at inference time: held-out inputs stay deterministic.
    processor = JetGraphProcessor(
        config['input_dir'],
        config['processed_dir'],
        transform=None,
        k=8,
        min_energy_threshold=1e-4
    )
    last_chunk_graphs = [
        processor._convert_to_graph(X_jets[i], int(y_last[i]), m0_last[i], pt_last[i])
        for i in tqdm(range(len(X_jets)), desc=f"Processing last chunk {last_chunk_file} for inference")
    ]
    last_chunk_loader = GraphDataLoader(last_chunk_graphs, batch_size=config['batch_size'], shuffle=False)
    last_chunk_metrics, (lc_y_true, lc_y_pred, lc_y_score) = trainer.evaluate(last_chunk_loader, return_predictions=True)
    print("\nLast Chunk Inference Metrics:")
    print(f"Accuracy:  {last_chunk_metrics['accuracy']:.4f}")
    print(f"F1 Score:  {last_chunk_metrics['f1']:.4f}")
    print(f"AUC:       {last_chunk_metrics['auc']:.4f}")
    
    # Plot training history
    fig, (loss_ax, score_ax) = plt.subplots(1, 2, figsize=(12, 4))
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot([m['loss'] for m in val_metrics], label='Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    score_ax.plot([m['accuracy'] for m in val_metrics], label='Accuracy')
    score_ax.plot([m['f1'] for m in val_metrics], label='F1')
    score_ax.plot([m['auc'] for m in val_metrics], label='AUC')
    score_ax.set_xlabel('Epoch')
    score_ax.set_ylabel('Score')
    score_ax.legend()
    fig.tight_layout()
    fig.savefig('training_history.png')
    plt.show()

if __name__ == "__main__":
    main()


Using device: cuda
Processing data into graph representations for training...
Processing chunk_0_10000.npz ...


Processing chunk_0_10000.npz: 100%|██████████| 10000/10000 [03:54<00:00, 42.69it/s]


Processing chunk_100000_110000.npz ...


Processing chunk_100000_110000.npz: 100%|██████████| 10000/10000 [04:06<00:00, 40.52it/s]


Processing chunk_10000_20000.npz ...


Processing chunk_10000_20000.npz: 100%|██████████| 10000/10000 [03:53<00:00, 42.89it/s]


Processing chunk_110000_120000.npz ...


Processing chunk_110000_120000.npz: 100%|██████████| 10000/10000 [04:09<00:00, 40.11it/s]


Processing chunk_120000_130000.npz ...


Processing chunk_120000_130000.npz:  39%|███▉      | 3875/10000 [01:37<02:33, 39.90it/s]


KeyboardInterrupt: 