In [1]:
! pip install gdown
! gdown --id 1WO2K-SfU2dntGU4Bb3IYBp9Rh7rtTYEr -O filename
! pip install h5p
! pip install torch_geometric
! pip install torch_sparse torch_scatter torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__.split('+')[0])")+cpu.html

Downloading...
From (original): https://drive.google.com/uc?id=1WO2K-SfU2dntGU4Bb3IYBp9Rh7rtTYEr
From (redirected): https://drive.google.com/uc?id=1WO2K-SfU2dntGU4Bb3IYBp9Rh7rtTYEr&confirm=t&uuid=5a0d3c53-1b08-44e8-ba59-cde7deb121da
To: /kaggle/working/filename
100%|█████████████████████████████████████████| 701M/701M [00:06<00:00, 107MB/s]
[31mERROR: Could not find a version that satisfies the requirement h5p (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for h5p[0m[31m
[0mCollecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_g

In [5]:
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from sklearn.neighbors import radius_neighbors_graph
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import roc_auc_score
import gc

# ======================
# 1. Graph Construction
# ======================
def multi_channel_image_to_graph(ecal, hcal, track, threshold=0.01):
    """Convert 3-channel jet image to graph"""
    nodes = []
    height, width = ecal.shape
    
    for i in range(height):
        for j in range(width):
            total_energy = ecal[i,j] + hcal[i,j] + track[i,j]
            if total_energy > threshold:
                nodes.append([
                    i/float(height),   # Normalized x
                    j/float(width),    # Normalized y
                    ecal[i,j],         # ECAL
                    hcal[i,j],         # HCAL
                    track[i,j]         # Track
                ])
    
    if len(nodes) == 0:  # Fallback for empty graphs
        combined = ecal + hcal + track
        max_idx = np.unravel_index(np.argmax(combined), combined.shape)
        nodes.append([
            max_idx[0]/float(height), max_idx[1]/float(width),
            ecal[max_idx], hcal[max_idx], track[max_idx]
        ])
    
    nodes = np.array(nodes, dtype=np.float32)
    pos = nodes[:, :2]
    
    if len(nodes) > 1:
        edges = radius_neighbors_graph(pos, radius=0.15, mode='connectivity')
        edge_index = torch.tensor(edges.nonzero(), dtype=torch.long)
    else:
        edge_index = torch.tensor([[0], [0]], dtype=torch.long)
    
    return Data(x=torch.tensor(nodes, dtype=torch.float),
                edge_index=edge_index)

# ======================
# 2. Data Loading
# ======================
def load_data(filename, num_jets=30000, threshold=0.01):
    """Load jet data from HDF5 and convert to graphs"""
    graphs = []
    with h5py.File(filename, 'r') as f:
        X_jets = f['X_jets'][:num_jets]  # Load first `num_jets` jets
        m0 = f['m0'][:num_jets]
        pt = f['pt'][:num_jets]
        y = f['y'][:num_jets]
        
        for i in tqdm(range(num_jets), desc="Creating graphs"):
            ecal = X_jets[i, 0, :, :]  # ECAL channel
            hcal = X_jets[i, 1, :, :]  # HCAL channel
            track = X_jets[i, 2, :, :]  # Track channel
            
            data = multi_channel_image_to_graph(ecal, hcal, track, threshold)
            data.m0 = torch.tensor([m0[i]], dtype=torch.float)
            data.pt = torch.tensor([pt[i]], dtype=torch.float)
            data.y = torch.tensor([int(y[i])], dtype=torch.long)
            graphs.append(data)
    
    return graphs

# ======================
# 3. Non-Local Block
# ======================
class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, reduction=2):
        super().__init__()
        self.inter_channels = in_channels // reduction
        
        self.theta = nn.Linear(in_channels, self.inter_channels)
        self.phi = nn.Linear(in_channels, self.inter_channels)
        self.g = nn.Linear(in_channels, self.inter_channels)
        self.out = nn.Linear(self.inter_channels, in_channels)
    
    def forward(self, x, batch):
        N = x.size(0)
        
        theta = self.theta(x)  # [N, inter_channels]
        phi = self.phi(x)      # [N, inter_channels]
        g = self.g(x)          # [N, inter_channels]
        
        # Attention scores [N, N]
        attention = torch.matmul(theta, phi.T) / np.sqrt(self.inter_channels)
        attention = F.softmax(attention, dim=-1)
        
        # Aggregation
        out = torch.matmul(attention, g)  # [N, inter_channels]
        out = self.out(out)  # [N, in_channels]
        
        return out + x  # Residual connection

# ======================
# 4. GNN Models
# ======================
class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')
        self.mlp = nn.Sequential(
            nn.Linear(2*in_channels, out_channels),
            nn.ReLU(),
            nn.BatchNorm1d(out_channels),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        return self.mlp(torch.cat([x_i, x_j-x_i], dim=1))

class JetGNN(nn.Module):
    def __init__(self, node_features=5, hidden_dim=64):
        super().__init__()
        self.conv1 = EdgeConv(node_features, hidden_dim)
        self.conv2 = EdgeConv(hidden_dim, hidden_dim)
        self.conv3 = EdgeConv(hidden_dim, hidden_dim)
        
        self.global_mlp = nn.Sequential(
            nn.Linear(2, 12),
            nn.ReLU(),
            nn.BatchNorm1d(12))
        
        self.classifier = nn.Sequential(
            nn.Linear(12 + hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2))
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = F.leaky_relu(self.conv1(x, edge_index))
        x = F.leaky_relu(self.conv2(x, edge_index))
        x = F.leaky_relu(self.conv3(x, edge_index))
        
        graph_feat = global_mean_pool(x, batch)
        global_feat = self.global_mlp(torch.stack([data.m0, data.pt], dim=1))
        
        return self.classifier(torch.cat([graph_feat, global_feat], dim=1))

class NonLocalJetGNN(nn.Module):
    def __init__(self, node_features=5, hidden_dim=64):
        super().__init__()
        self.conv1 = EdgeConv(node_features, hidden_dim)
        self.conv2 = EdgeConv(hidden_dim, hidden_dim)
        self.conv3 = EdgeConv(hidden_dim, hidden_dim)
        
        self.non_local = NonLocalBlock(hidden_dim)
        
        self.global_mlp = nn.Sequential(
            nn.Linear(2, 12),
            nn.ReLU(),
            nn.BatchNorm1d(12))
        
        self.classifier = nn.Sequential(
            nn.Linear(12 + hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2))
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = F.leaky_relu(self.conv1(x, edge_index))
        x = F.leaky_relu(self.conv2(x, edge_index))
        x = F.leaky_relu(self.conv3(x, edge_index))
        
        x = self.non_local(x, batch)  # Non-local aggregation
        
        graph_feat = global_mean_pool(x, batch)
        global_feat = self.global_mlp(torch.stack([data.m0, data.pt], dim=1))
        
        return self.classifier(torch.cat([graph_feat, global_feat], dim=1))

# ======================
# 5. Training Utilities
# ======================
class EarlyStopping:
    def __init__(self, patience=5, delta=0, path='best_model.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_score is None:
            self.best_score = val_loss
            self.save_checkpoint(model)
        elif val_loss > self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_loss
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), self.path)

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    pbar = tqdm(loader, leave=False, desc="Training")
    for data in pbar:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y.squeeze())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    return total_loss / len(loader)

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds, all_targets = [], []
    pbar = tqdm(loader, leave=False, desc="Validation")
    with torch.no_grad():
        for data in pbar:
            data = data.to(device)
            out = model(data)
            loss = criterion(out, data.y.squeeze())
            total_loss += loss.item()
            
            probs = F.softmax(out, dim=1)
            all_preds.append(probs[:, 1].cpu().numpy())
            all_targets.append(data.y.squeeze().cpu().numpy())
    
    auc = roc_auc_score(np.concatenate(all_targets), np.concatenate(all_preds))
    return total_loss / len(loader), auc

# ======================
# 6. Main Training Loop
# ======================
def main():
    # Config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    threshold = 0.01
    batch_size = 32
    hidden_dim = 64
    lr = 0.001
    epochs = 100
    num_jets = 30000  # Reduce if memory is limited
    
    # Load and preprocess data
    print("Loading data...")
    graphs = load_data('/kaggle/working/filename', num_jets=num_jets, threshold=threshold)
    
    # Split data
    train_graphs, val_graphs = train_test_split(graphs, test_size=0.2, random_state=42)
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=batch_size)
    
    # Initialize models
    models = {
        "Baseline GNN": JetGNN(hidden_dim=hidden_dim).to(device),
        "Non-Local GNN": NonLocalJetGNN(hidden_dim=hidden_dim).to(device)
    }
    
    results = {}
    criterion = nn.CrossEntropyLoss()
    
    for name, model in models.items():
        print(f"\n===== Training {name} =====")
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
        early_stopping = EarlyStopping(patience=5, path=f'best_{name.lower().replace(" ", "_")}.pt')
        
        best_auc = 0
        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss, val_auc = validate(model, val_loader, criterion, device)
            
            scheduler.step(val_loss)
            if val_auc > best_auc:
                best_auc = val_auc
                torch.save(model.state_dict(), f'best_{name.lower().replace(" ", "_")}_auc.pt')
            
            print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val AUC: {val_auc:.4f}")
            
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping triggered for {name}!")
                break
        
        results[name] = best_auc
    
    print("\n===== Final Results =====")
    for name, auc in results.items():
        print(f"{name}: ROC-AUC = {auc:.4f}")

if __name__ == "__main__":
    main()

Loading data...


Creating graphs:   0%|          | 0/30000 [00:00<?, ?it/s]


===== Training Baseline GNN =====




Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 001 | Train Loss: 0.6286 | Val Loss: 0.6140 | Val AUC: 0.7259


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 002 | Train Loss: 0.6196 | Val Loss: 0.6125 | Val AUC: 0.7283


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 003 | Train Loss: 0.6200 | Val Loss: 0.6122 | Val AUC: 0.7287


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 004 | Train Loss: 0.6187 | Val Loss: 0.6110 | Val AUC: 0.7288


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 005 | Train Loss: 0.6185 | Val Loss: 0.6110 | Val AUC: 0.7283


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 006 | Train Loss: 0.6192 | Val Loss: 0.6132 | Val AUC: 0.7275


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 007 | Train Loss: 0.6179 | Val Loss: 0.6115 | Val AUC: 0.7290


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 008 | Train Loss: 0.6173 | Val Loss: 0.6109 | Val AUC: 0.7291


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 009 | Train Loss: 0.6172 | Val Loss: 0.6110 | Val AUC: 0.7291


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 010 | Train Loss: 0.6169 | Val Loss: 0.6112 | Val AUC: 0.7292


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 011 | Train Loss: 0.6169 | Val Loss: 0.6105 | Val AUC: 0.7299


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 012 | Train Loss: 0.6165 | Val Loss: 0.6113 | Val AUC: 0.7283


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 013 | Train Loss: 0.6176 | Val Loss: 0.6107 | Val AUC: 0.7290


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 014 | Train Loss: 0.6166 | Val Loss: 0.6104 | Val AUC: 0.7296


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 015 | Train Loss: 0.6173 | Val Loss: 0.6105 | Val AUC: 0.7294


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 016 | Train Loss: 0.6166 | Val Loss: 0.6108 | Val AUC: 0.7290


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 017 | Train Loss: 0.6160 | Val Loss: 0.6104 | Val AUC: 0.7292


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 018 | Train Loss: 0.6159 | Val Loss: 0.6100 | Val AUC: 0.7294


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 019 | Train Loss: 0.6158 | Val Loss: 0.6102 | Val AUC: 0.7291


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 020 | Train Loss: 0.6167 | Val Loss: 0.6096 | Val AUC: 0.7296


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 021 | Train Loss: 0.6160 | Val Loss: 0.6114 | Val AUC: 0.7287


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 022 | Train Loss: 0.6163 | Val Loss: 0.6111 | Val AUC: 0.7293


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 023 | Train Loss: 0.6162 | Val Loss: 0.6105 | Val AUC: 0.7293


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 024 | Train Loss: 0.6160 | Val Loss: 0.6104 | Val AUC: 0.7295


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 025 | Train Loss: 0.6168 | Val Loss: 0.6098 | Val AUC: 0.7296
Early stopping triggered for Baseline GNN!

===== Training Non-Local GNN =====


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 001 | Train Loss: 0.6269 | Val Loss: 0.6172 | Val AUC: 0.7282


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 002 | Train Loss: 0.6192 | Val Loss: 0.6167 | Val AUC: 0.7282


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 003 | Train Loss: 0.6209 | Val Loss: 0.6123 | Val AUC: 0.7282


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 004 | Train Loss: 0.6191 | Val Loss: 0.6145 | Val AUC: 0.7287


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 005 | Train Loss: 0.6190 | Val Loss: 0.6117 | Val AUC: 0.7290


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 006 | Train Loss: 0.6195 | Val Loss: 0.6103 | Val AUC: 0.7288


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 007 | Train Loss: 0.6182 | Val Loss: 0.6106 | Val AUC: 0.7288


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 008 | Train Loss: 0.6180 | Val Loss: 0.6129 | Val AUC: 0.7279


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 009 | Train Loss: 0.6183 | Val Loss: 0.6121 | Val AUC: 0.7281


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 010 | Train Loss: 0.6176 | Val Loss: 0.6109 | Val AUC: 0.7293


Training:   0%|          | 0/750 [00:00<?, ?it/s]

Validation:   0%|          | 0/188 [00:00<?, ?it/s]

Epoch 011 | Train Loss: 0.6174 | Val Loss: 0.6105 | Val AUC: 0.7293
Early stopping triggered for Non-Local GNN!

===== Final Results =====
Baseline GNN: ROC-AUC = 0.7299
Non-Local GNN: ROC-AUC = 0.7293


In [None]:
# import h5py
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch_geometric.data import Data, DataLoader
# from torch_geometric.nn import MessagePassing, global_mean_pool
# from sklearn.neighbors import radius_neighbors_graph
# from sklearn.model_selection import train_test_split
# from tqdm.auto import tqdm
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# from sklearn.metrics import roc_auc_score
# import gc

# # ======================
# # 1. Memory-Efficient Graph Creation
# # ======================
# def process_in_chunks(h5_path, chunk_size=30000, threshold=0.01):
#     """Process HDF5 file in chunks to save memory"""
#     all_graphs = []
#     with h5py.File(h5_path, 'r') as f:
#         total_jets = f['X_jets'].shape[0]
        
#         for start_idx in tqdm(range(0, total_jets, chunk_size), 
#                           desc="Processing chunks"):
#             end_idx = min(start_idx + chunk_size, total_jets)
            
#             # Load chunk
#             X_chunk = f['X_jets'][start_idx:end_idx]
#             m0_chunk = f['m0'][start_idx:end_idx]
#             pt_chunk = f['pt'][start_idx:end_idx]
#             y_chunk = f['y'][start_idx:end_idx]
            
#             # Process chunk
#             chunk_graphs = []
#             for i in range(X_chunk.shape[0]):
#                 data = multi_channel_image_to_graph(
#                     X_chunk[i,0], X_chunk[i,1], X_chunk[i,2], threshold)
#                 data.m0 = torch.tensor([m0_chunk[i]], dtype=torch.float)
#                 data.pt = torch.tensor([pt_chunk[i]], dtype=torch.float)
#                 data.y = torch.tensor([int(y_chunk[i])], dtype=torch.long)
#                 chunk_graphs.append(data)
            
#             all_graphs.extend(chunk_graphs)
            
#             # Clean up memory
#             del X_chunk, m0_chunk, pt_chunk, y_chunk, chunk_graphs
#             gc.collect()
    
#     return all_graphs

# # ======================
# # 2. Early Stopping Class
# # ======================
# class EarlyStopping:
#     def __init__(self, patience=5, delta=0, path='best_model.pt'):
#         self.patience = patience
#         self.delta = delta
#         self.path = path
#         self.counter = 0
#         self.best_score = None
#         self.early_stop = False

#     def __call__(self, val_loss, model):
#         if self.best_score is None:
#             self.best_score = val_loss
#             self.save_checkpoint(model)
#         elif val_loss > self.best_score + self.delta:
#             self.counter += 1
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_score = val_loss
#             self.save_checkpoint(model)
#             self.counter = 0

#     def save_checkpoint(self, model):
#         torch.save(model.state_dict(), self.path)

# # ======================
# # 3. Graph Construction 
# # ======================
# def multi_channel_image_to_graph(ecal, hcal, track, threshold=0.01):
#     """Convert 3-channel jet image to graph"""
#     nodes = []
#     height, width = ecal.shape
    
#     for i in range(height):
#         for j in range(width):
#             total_energy = ecal[i,j] + hcal[i,j] + track[i,j]
#             if total_energy > threshold:
#                 nodes.append([
#                     i/float(height),   # norm x
#                     j/float(width),   # norm y
#                     ecal[i,j],        # ECAL
#                     hcal[i,j],        # HCAL
#                     track[i,j]        # Track
#                 ])
    
#     if len(nodes) == 0:  # Fallback
#         combined = ecal + hcal + track
#         max_idx = np.unravel_index(np.argmax(combined), combined.shape)
#         nodes.append([
#             max_idx[0]/float(height), max_idx[1]/float(width),
#             ecal[max_idx], hcal[max_idx], track[max_idx]
#         ])
    
#     nodes = np.array(nodes, dtype=np.float32)
#     pos = nodes[:, :2]
    
#     if len(nodes) > 1:
#         edges = radius_neighbors_graph(pos, radius=0.15, mode='connectivity')
#         edge_index = torch.tensor(edges.nonzero(), dtype=torch.long)
#     else:
#         edge_index = torch.tensor([[0], [0]], dtype=torch.long)
    
#     return Data(x=torch.tensor(nodes, dtype=torch.float),
#                 edge_index=edge_index)

# # ======================
# # 4. GNN Model
# # ======================
# class EdgeConv(MessagePassing):
#     def __init__(self, in_channels, out_channels):
#         super().__init__(aggr='mean')
#         self.mlp = nn.Sequential(
#             nn.Linear(2*in_channels, out_channels),
#             nn.ReLU(),
#             nn.BatchNorm1d(out_channels),
#             nn.Linear(out_channels, out_channels)
#         )
    
#     def forward(self, x, edge_index):
#         return self.propagate(edge_index, x=x)
    
#     def message(self, x_i, x_j):
#         return self.mlp(torch.cat([x_i, x_j-x_i], dim=1))

# class JetGNN(nn.Module):
#     def __init__(self, node_features=5, hidden_dim=64):
#         super().__init__()
#         self.conv1 = EdgeConv(node_features, hidden_dim)
#         self.conv2 = EdgeConv(hidden_dim, hidden_dim)
#         self.conv3 = EdgeConv(hidden_dim, hidden_dim)
#         self.global_mlp = nn.Sequential(
#             nn.Linear(2, 12),
#             nn.ReLU(),
#             nn.BatchNorm1d(12))
#         self.classifier = nn.Sequential(
#             nn.Linear(12 + hidden_dim, hidden_dim // 2),
#             nn.ReLU(),
#             nn.Linear(hidden_dim // 2, 2))
    
#     def forward(self, data):
#         x, edge_index, batch = data.x, data.edge_index, data.batch
#         x = F.leaky_relu(self.conv1(x, edge_index))
#         x = F.leaky_relu(self.conv2(x, edge_index))
#         x = F.leaky_relu(self.conv3(x, edge_index))
#         graph_feat = global_mean_pool(x, batch)
#         global_feat = self.global_mlp(torch.stack([data.m0, data.pt], dim=1))
#         return self.classifier(torch.cat([graph_feat, global_feat], dim=1))

# # ======================
# # 5. Data Loading (First 30,000 jets only)
# # ======================
# def load_data(filename, num_jets=30000):
#     with h5py.File(filename, 'r') as f:
#         X_jets = f['X_jets'][:num_jets]  # Only load first 30,000 jets
#         m0 = f['m0'][:num_jets]
#         pt = f['pt'][:num_jets]
#         y = f['y'][:num_jets]
#     return X_jets, m0, pt, y

# def create_graph_dataset(X_jets, m0, pt, y, threshold=0.01):
#     graphs = []
#     num_jets = X_jets.shape[0]
    
#     for i in tqdm(range(num_jets), desc="Creating graphs"):
#         ecal = X_jets[i, 0, :, :]  # ECAL channel
#         hcal = X_jets[i, 1, :, :]  # HCAL channel
#         track = X_jets[i, 2, :, :]  # Track channel
        
#         data = multi_channel_image_to_graph(ecal, hcal, track, threshold)
#         data.m0 = torch.tensor([m0[i]], dtype=torch.float)
#         data.pt = torch.tensor([pt[i]], dtype=torch.float)
#         data.y = torch.tensor([int(y[i])], dtype=torch.long)
#         graphs.append(data)
    
#     return graphs

# # ======================
# # 6. Training Loop with tqdm
# # ======================
# def train_epoch(model, loader, optimizer, criterion, device):
#     model.train()
#     total_loss, correct = 0, 0
#     pbar = tqdm(loader, leave=False, desc="Training")
#     for data in pbar:
#         data = data.to(device)
#         optimizer.zero_grad()
#         out = model(data)
#         loss = criterion(out, data.y.squeeze())
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
#         correct += (out.argmax(dim=1) == data.y.squeeze()).sum().item()
#         pbar.set_postfix({
#             'loss': f"{loss.item():.4f}",
#             'acc': f"{(out.argmax(dim=1) == data.y.squeeze()).float().mean().item():.4f}"
#         })
    
#     return total_loss/len(loader), correct/len(loader.dataset)

# def validate(model, loader, criterion, device):
#     model.eval()
#     total_loss, correct = 0, 0
#     all_preds = []
#     all_targets = []
#     pbar = tqdm(loader, leave=False, desc="Validation")
#     with torch.no_grad():
#         for data in pbar:
#             data = data.to(device)
#             out = model(data)
#             loss = criterion(out, data.y.squeeze())
#             total_loss += loss.item()
#             correct += (out.argmax(dim=1) == data.y.squeeze()).sum().item()
            
#             # Store predictions and targets for AUC calculation
#             probs = F.softmax(out, dim=1)
#             all_preds.append(probs[:, 1].cpu().numpy())  # Probability of class 1
#             all_targets.append(data.y.squeeze().cpu().numpy())
            
#             pbar.set_postfix({
#                 'val_loss': f"{loss.item():.4f}",
#                 'val_acc': f"{(out.argmax(dim=1) == data.y.squeeze()).float().mean().item():.4f}"
#             })
    
#     # Calculate AUC
#     all_preds = np.concatenate(all_preds)
#     all_targets = np.concatenate(all_targets)
#     auc = roc_auc_score(all_targets, all_preds)
    
#     return total_loss/len(loader), correct/len(loader.dataset), auc

# # ======================
# # 7. Main Execution
# # ======================
# def main():
#     # Config
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     threshold = 0.01
#     patience = 5
#     lr = 0.001
#     batch_size = 32
#     hidden_dim = 128
#     chunk_size = 30000  # Process 30,000 jets at a time
    
#     # Process data in chunks
#     print("Processing entire dataset in chunks...")
#     graphs = process_in_chunks('/kaggle/working/filename', chunk_size, threshold)
    
#     # Split data
#     train_graphs, val_graphs = train_test_split(graphs, test_size=0.2, random_state=42)
#     del graphs  # Free memory
#     gc.collect()
    
#     # Create dataloaders
#     train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
#     val_loader = DataLoader(val_graphs, batch_size=batch_size)
    
#     # Initialize model
#     model = JetGNN(hidden_dim=hidden_dim).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
#     scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
#     criterion = nn.CrossEntropyLoss()
#     early_stopping = EarlyStopping(patience=patience, path='best_jetgnn.pt')
    
#     # Training loop
#     best_auc = 0.0
#     for epoch in range(100):
#         # Training
#         train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        
#         # Validation
#         val_loss, val_acc, val_auc = validate(model, val_loader, criterion, device)
        
#         # Update scheduler
#         scheduler.step(val_loss)
        
#         # Track best AUC
#         if val_auc > best_auc:
#             best_auc = val_auc
#             torch.save(model.state_dict(), 'best_jetgnn_auc.pt')
        
#         # Print epoch stats
#         print(f"\nEpoch {epoch+1:03d}")
#         print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
#         print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}")
#         print(f"Best AUC: {best_auc:.4f}")
#         print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")
        
#         # Early stopping check
#         early_stopping(val_loss, model)
#         if early_stopping.early_stop:
#             print("\nEarly stopping triggered")
#             break
    
#     # Load best model (based on validation loss)
#     model.load_state_dict(torch.load('best_jetgnn.pt'))
#     print("\nTraining complete. Best model saved to 'best_jetgnn.pt'")
#     print(f"Best AUC during training: {best_auc:.4f}")

# if __name__ == "__main__":
#     main()