In [74]:
!pip install torch_geometric



In [75]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [76]:
class JetGraphDataset(torch.utils.data.Dataset):
    def __init__(self, graph_path, npz_path):
        """
        Args:
            graph_path (str): Path to the processed graphs (.pt file).
            npz_path (str): Path to the corresponding npz file containing physics features.
        """
        super(JetGraphDataset, self).__init__()
        self.graphs = torch.load(graph_path, map_location=device)
        
        data_npz = np.load(npz_path)
        self.labels = data_npz['y']      # Quark/Gluon labels (converted to int)
        self.pt = data_npz['pt']         # Transverse momentum
        self.m0 = data_npz['m0']         # Jet mass
        
        for i, graph in enumerate(self.graphs):
            # Assign basic physics features
            graph.y = torch.tensor([int(self.labels[i])], dtype=torch.long, device=device)
            graph.pt = torch.tensor([self.pt[i]], dtype=torch.float, device=device)
            graph.m0 = torch.tensor([self.m0[i]], dtype=torch.float, device=device)
            # If no node features exist, initialize them
            if not hasattr(graph, 'x') or graph.x is None:
                graph.x = torch.ones((graph.num_nodes, 16), device=device)
            
            # Compute explosion metric: ratio of number of edges to number of nodes.
            num_edges = graph.edge_index.size(1) if hasattr(graph, 'edge_index') else 0
            explosion = num_edges / graph.num_nodes if graph.num_nodes > 0 else 0.0
            graph.explosion = torch.tensor([explosion], dtype=torch.float, device=device)
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return self.graphs[idx]

In [77]:
class AdvancedGATEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        """
        Advanced encoder using GATConv layers with dense connections.
        Args:
            in_channels: Input node feature dimension.
            hidden_channels: Hidden layer dimension.
            out_channels: Output embedding dimension.
            heads: Number of attention heads.
        """
        super(AdvancedGATEncoder, self).__init__()
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        self.gat2 = GATConv(hidden_channels * heads, hidden_channels, heads=heads, concat=True)
        self.gat3 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False)
        # Dense connection projection from input to out_channels
        self.skip_proj = nn.Linear(in_channels, out_channels)
    
    def forward(self, x, edge_index, batch):
        # First GAT layer
        out1 = F.elu(self.gat1(x, edge_index))
        # Second GAT layer with dense connection (concatenating previous layer's output)
        out2 = F.elu(self.gat2(out1, edge_index))
        # Third GAT layer
        out3 = self.gat3(out2, edge_index)
        # Add skip connection from input (projected) to preserve low-level features
        skip = self.skip_proj(x)
        # Global pooling for each level
        pooled_out3 = global_mean_pool(out3 + skip, batch)
        # Optionally, one could also combine pooled representations from multiple layers.
        return pooled_out3

In [78]:
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, proj_dim):
        """
        Two-layer MLP for latent projection.
        """
        super(ProjectionHead, self).__init__()
        self.fc1 = nn.Linear(in_dim, proj_dim)
        self.fc2 = nn.Linear(proj_dim, proj_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [79]:
class ClassifierHead(nn.Module):
    def __init__(self, in_dim, num_classes):
        """
        Simple MLP classifier.
        """
        super(ClassifierHead, self).__init__()
        self.fc1 = nn.Linear(in_dim, in_dim // 2)
        self.fc2 = nn.Linear(in_dim // 2, num_classes)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

In [80]:
class AdvancedGraphModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, encoder_out, proj_dim, num_classes):
        """
        Combines the advanced GAT encoder, projection head, and classifier head.
        """
        super(AdvancedGraphModel, self).__init__()
        self.encoder = AdvancedGATEncoder(in_channels, hidden_channels, encoder_out)
        self.projection_head = ProjectionHead(encoder_out, proj_dim)
        # For classification, concatenate global physics features (pt and m0)
        self.classifier = ClassifierHead(encoder_out + 2, num_classes)
    
    def forward(self, data, mode='contrastive'):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        embedding = self.encoder(x, edge_index, batch)
        
        if mode == 'contrastive':
            proj = self.projection_head(embedding)
            return proj
        elif mode == 'classification':
            if hasattr(data, 'pt') and hasattr(data, 'm0'):
                physics_features = torch.cat([data.pt.view(-1, 1), data.m0.view(-1, 1)], dim=1)
                embedding = torch.cat([embedding, physics_features], dim=1)
            logits = self.classifier(embedding)
            return logits
        else:
            raise ValueError("Mode must be 'contrastive' or 'classification'.")

In [81]:
def improved_nt_xent_loss(z1, z2, temperature=0.5, margin=0.5, lambda_reg=0.1):
    batch_size = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)
    z = F.normalize(z, dim=1)
    
    similarity_matrix = torch.matmul(z, z.T)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity_matrix = similarity_matrix.masked_fill(mask, -9e15)
    similarity_matrix = similarity_matrix / temperature
    
    labels = torch.arange(batch_size, device=z.device)
    labels = torch.cat([labels, labels], dim=0)
    
    nt_xent = F.cross_entropy(similarity_matrix, labels)
    
    negatives = similarity_matrix.clone()
    negatives[mask] = -9e15
    max_negatives, _ = negatives.max(dim=1)
    margin_loss = F.relu(max_negatives - margin).mean()
    
    total_loss = nt_xent + lambda_reg * margin_loss
    return total_loss

In [82]:
class AdvancedTrainer:
    def __init__(self, model, train_loader, test_loader, device, lr=1e-3):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.5)
        self.criterion_cls = nn.CrossEntropyLoss()
    
    def pretrain(self, epochs, drop_prob=0.2, temperature=0.5, margin=0.5, lambda_reg=0.1):
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for data in self.train_loader:
                data = data.to(self.device)
                if not hasattr(data, 'x') or data.x is None:
                    data.x = torch.ones((data.num_nodes, 16), device=self.device)
                
                data1 = self._graph_augmentation(data, drop_prob)
                data2 = self._graph_augmentation(data, drop_prob)
                data1 = data1.to(self.device)
                data2 = data2.to(self.device)
                
                self.optimizer.zero_grad()
                z1 = self.model(data1, mode='contrastive')
                z2 = self.model(data2, mode='contrastive')
                loss = improved_nt_xent_loss(z1, z2, temperature, margin, lambda_reg)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            self.scheduler.step()
            avg_loss = total_loss / len(self.train_loader)
            print(f"[Pretrain] Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {self.optimizer.param_groups[0]['lr']:.6f}")
    
    def finetune(self, epochs, freeze_encoder=True):
        if freeze_encoder:
            for param in self.model.encoder.parameters():
                param.requires_grad = False
        
        optimizer_cls = Adam(self.model.classifier.parameters(), lr=1e-3)
        scheduler_cls = StepLR(optimizer_cls, step_size=5, gamma=0.5)
        
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for data in self.train_loader:
                data = data.to(self.device)
                if not hasattr(data, 'x') or data.x is None:
                    data.x = torch.ones((data.num_nodes, 16), device=self.device)
                optimizer_cls.zero_grad()
                logits = self.model(data, mode='classification')
                loss = self.criterion_cls(logits, data.y)
                loss.backward()
                optimizer_cls.step()
                total_loss += loss.item()
            scheduler_cls.step()
            avg_loss = total_loss / len(self.train_loader)
            metrics = self.evaluate()
            print(f"[Finetune] Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {metrics['accuracy']*100:.2f}%, "
                  f"F1: {metrics['f1']:.4f}, ROC-AUC: {metrics['roc_auc']:.4f}, LR: {optimizer_cls.param_groups[0]['lr']:.6f}")
    
    def evaluate(self):
        self.model.eval()
        all_preds = []
        all_labels = []
        all_probs = []
        with torch.no_grad():
            for data in self.test_loader:
                data = data.to(self.device)
                if not hasattr(data, 'x') or data.x is None:
                    data.x = torch.ones((data.num_nodes, 16), device=self.device)
                logits = self.model(data, mode='classification')
                probs = F.softmax(logits, dim=1)[:, 1]
                preds = logits.argmax(dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_probs.extend(probs.cpu().numpy())
                all_labels.extend(data.y.cpu().numpy())
        acc = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        try:
            roc_auc = roc_auc_score(all_labels, all_probs)
        except Exception:
            roc_auc = 0.0
        self.model.train()
        return {"accuracy": acc, "f1": f1, "roc_auc": roc_auc}
    
    def _graph_augmentation(self, data, drop_prob):
        node_mask = torch.rand(data.num_nodes, device=data.x.device) > drop_prob
        if node_mask.sum() == 0:
            node_mask[torch.randint(0, data.num_nodes, (1,))] = True
        new_idx = torch.zeros(data.num_nodes, dtype=torch.long, device=data.x.device)
        new_idx[node_mask] = torch.arange(node_mask.sum(), device=data.x.device)
        
        x = data.x[node_mask]
        edge_index = data.edge_index
        mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
        edge_index = edge_index[:, mask]
        edge_index = new_idx[edge_index]
        
        new_data = Data(x=x, edge_index=edge_index)
        if hasattr(data, 'batch'):
            new_data.batch = data.batch[node_mask]
        if hasattr(data, 'pt'):
            new_data.pt = data.pt
        if hasattr(data, 'm0'):
            new_data.m0 = data.m0
        return new_data


In [83]:
ggraph_path = '/kaggle/input/part-4-task-2-output/processed_jet_graphs/processed_chunk_90000_100000.pt'
npz_path = '/kaggle/input/genie-extracted-dataset/chunk_90000_100000.npz'

# Create dataset
dataset = JetGraphDataset(graph_path, npz_path)

# Split dataset indices using stratification on labels


  self.graphs = torch.load(graph_path, map_location=device)


In [84]:
indices = np.arange(len(dataset))
labels = np.array([dataset[i].y.item() for i in range(len(dataset))])
train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42, stratify=labels)

train_graphs = [dataset[i] for i in train_idx]
test_graphs = [dataset[i] for i in test_idx]

print(f"Total graphs: {len(dataset)}; Training: {len(train_graphs)}; Testing: {len(test_graphs)}")

batch_size = 32
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

in_channels = train_graphs[0].x.shape[1]
hidden_channels = 64
encoder_out = 64
proj_dim = 32
num_classes = 2

model = AdvancedGraphModel(in_channels, hidden_channels, encoder_out, proj_dim, num_classes)
trainer = AdvancedTrainer(model, train_loader, test_loader, device, lr=1e-3)


Total graphs: 10000; Training: 8000; Testing: 2000




In [85]:
pretrain_epochs = 10  # Adjust as necessary
print("Starting contrastive pre-training...")
trainer.pretrain(epochs=pretrain_epochs, drop_prob=0.1, temperature=0.7, margin=0.3, lambda_reg=0.4
                )

# %% [markdown]
# ### 6.2 Classification Fine-tuning
# 
# Fine-tune the classifier for anomaly detection/classification. Metrics are computed after each epoch.

# %%
finetune_epochs = 10  # Adjust as necessary
print("\nStarting classification fine-tuning...")
trainer.finetune(epochs=finetune_epochs, freeze_encoder=True)

Starting contrastive pre-training...
[Pretrain] Epoch 1/10, Loss: 6428570146570240.0000, LR: 0.001000
[Pretrain] Epoch 2/10, Loss: 6428570146570240.0000, LR: 0.001000
[Pretrain] Epoch 3/10, Loss: 6428570146570240.0000, LR: 0.001000
[Pretrain] Epoch 4/10, Loss: 6428570146570240.0000, LR: 0.001000
[Pretrain] Epoch 5/10, Loss: 6428570146570240.0000, LR: 0.000500
[Pretrain] Epoch 6/10, Loss: 6428570146570240.0000, LR: 0.000500
[Pretrain] Epoch 7/10, Loss: 6428570146570240.0000, LR: 0.000500
[Pretrain] Epoch 8/10, Loss: 6428570146570240.0000, LR: 0.000500
[Pretrain] Epoch 9/10, Loss: 6428570146570240.0000, LR: 0.000500
[Pretrain] Epoch 10/10, Loss: 6428570146570240.0000, LR: 0.000250

Starting classification fine-tuning...
[Finetune] Epoch 1/10, Loss: 0.6848, Accuracy: 61.85%, F1: 0.5846, ROC-AUC: 0.7160, LR: 0.001000
[Finetune] Epoch 2/10, Loss: 0.6284, Accuracy: 65.80%, F1: 0.6573, ROC-AUC: 0.7050, LR: 0.001000
[Finetune] Epoch 3/10, Loss: 0.6238, Accuracy: 66.75%, F1: 0.6662, ROC-AUC: 0.

In [86]:
final_metrics = trainer.evaluate()
print("Final Evaluation Metrics:")
print(f"Accuracy: {final_metrics['accuracy']*100:.2f}%")
print(f"F1 Score: {final_metrics['f1']:.4f}")
print(f"ROC-AUC: {final_metrics['roc_auc']:.4f}")

Final Evaluation Metrics:
Accuracy: 63.95%
F1 Score: 0.6343
ROC-AUC: 0.7076
