In [6]:
!pip install torch_geometric



In [19]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [31]:
class JetGraphDataset(torch.utils.data.Dataset):
    def __init__(self, graph_path, npz_path):
        """
        Loads graphs from a .pt file and corresponding physics features from an npz file.
        Each graph is augmented with:
            - y: label (quark/gluon)
            - pt: normalized transverse momentum
            - m0: normalized jet mass
            - explosion: ratio of number of edges to number of nodes (physics-inspired)
        """
        super(JetGraphDataset, self).__init__()
        self.graphs = torch.load(graph_path, map_location=device)
        
        data_npz = np.load(npz_path)
        self.labels = data_npz['y']      # Quark/Gluon labels (as integers)
        self.pt = data_npz['pt']         # Transverse momentum
        self.m0 = data_npz['m0']         # Jet mass
        
        # Normalize physics features
        self.pt = (self.pt - np.mean(self.pt)) / (np.std(self.pt) + 1e-8)
        self.m0 = (self.m0 - np.mean(self.m0)) / (np.std(self.m0) + 1e-8)
        
        for i, graph in enumerate(self.graphs):
            graph.y = torch.tensor([int(self.labels[i])], dtype=torch.long, device=device)
            graph.pt = torch.tensor([self.pt[i]], dtype=torch.float, device=device)
            graph.m0 = torch.tensor([self.m0[i]], dtype=torch.float, device=device)
            # If no node features, initialize with ones
            if not hasattr(graph, 'x') or graph.x is None:
                # If not provided, default feature dimension is set to 16.
                graph.x = torch.ones((graph.num_nodes, 16), device=device)
            # Compute explosion metric: number of edges / number of nodes
            num_edges = graph.edge_index.size(1) if hasattr(graph, 'edge_index') else 0
            explosion = num_edges / graph.num_nodes if graph.num_nodes > 0 else 0.0
            graph.explosion = torch.tensor([explosion], dtype=torch.float, device=device)
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return self.graphs[idx]


In [32]:
class EnhancedGATEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.2):
        super(EnhancedGATEncoder, self).__init__()
        self.dropout = dropout
        # GAT layers with LayerNorm
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        self.ln1 = nn.LayerNorm(hidden_channels * heads)
        self.gat2 = GATConv(hidden_channels * heads, hidden_channels, heads=heads, concat=True)
        self.ln2 = nn.LayerNorm(hidden_channels * heads)
        self.gat3 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False)
        self.ln3 = nn.LayerNorm(out_channels)
        # Skip connection: project input to out_channels with adaptive gating
        self.skip_proj = nn.Linear(in_channels, out_channels)
        self.gate = nn.Sequential(nn.Linear(out_channels * 2, out_channels), nn.Sigmoid())
        # Optional spectral branch: a simple linear transformation for multi-frequency filtering
        self.spectral_linear = nn.Linear(in_channels, out_channels)
    
    def forward(self, x, edge_index, batch):
        # GAT pathway with dropout and LayerNorm
        out1 = F.elu(self.ln1(self.gat1(x, edge_index)))
        out1 = F.dropout(out1, p=self.dropout, training=self.training)
        out2 = F.elu(self.ln2(self.gat2(out1, edge_index)))
        out2 = F.dropout(out2, p=self.dropout, training=self.training)
        out3 = self.ln3(self.gat3(out2, edge_index))
        # Skip connection
        skip = self.skip_proj(x)
        combined_main = out3 + skip
        gate_factor = self.gate(torch.cat([out3, skip], dim=-1))
        gated = gate_factor * combined_main
        
        # Multi-scale pooling: fuse mean and max pooling
        pooled_mean = global_mean_pool(gated, batch)
        pooled_max = global_max_pool(gated, batch)
        pooled_main = torch.cat([pooled_mean, pooled_max], dim=-1)
        # Project multi-scale pooled embedding back to out_channels (use a linear layer)
        proj_layer = nn.Linear(pooled_main.shape[-1], skip.shape[-1]).to(x.device)
        pooled_main = torch.tanh(proj_layer(pooled_main))
        
        # Spectral branch: compute a spectral transformation (linear approximation)
        spectral_out = self.spectral_linear(x)
        pooled_spectral = global_mean_pool(spectral_out, batch)
        
        # Fuse both branches
        fused = pooled_main + pooled_spectral
        return fused

class ProjectionHead(nn.Module):
    def __init__(self, in_dim, proj_dim):
        super(ProjectionHead, self).__init__()
        self.fc1 = nn.Linear(in_dim, proj_dim)
        self.fc2 = nn.Linear(proj_dim, proj_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Modified AuxiliaryReconstructionHead to output target dimension equal to original node feature dimension.
class AuxiliaryReconstructionHead(nn.Module):
    def __init__(self, in_dim, recon_dim, target_dim):
        """
        Reconstruct aggregated node features as an auxiliary task.
        Args:
            in_dim: dimension of the encoder output.
            recon_dim: hidden dimension for reconstruction.
            target_dim: target dimension to reconstruct (should equal original node feature dimension).
        """
        super(AuxiliaryReconstructionHead, self).__init__()
        self.fc1 = nn.Linear(in_dim, recon_dim)
        self.fc2 = nn.Linear(recon_dim, target_dim)
    
    def forward(self, z):
        z_rec = F.relu(self.fc1(z))
        z_rec = self.fc2(z_rec)
        return z_rec

class ClassifierHead(nn.Module):
    def __init__(self, in_dim, num_classes):
        super(ClassifierHead, self).__init__()
        self.fc1 = nn.Linear(in_dim, in_dim // 2)
        self.fc2 = nn.Linear(in_dim // 2, num_classes)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

In [33]:
class EnhancedGraphModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, encoder_out, proj_dim, num_classes, use_aux_recon=True):
        super(EnhancedGraphModel, self).__init__()
        self.encoder = EnhancedGATEncoder(in_channels, hidden_channels, encoder_out)
        self.projection_head = ProjectionHead(encoder_out, proj_dim)
        self.use_aux_recon = use_aux_recon
        if self.use_aux_recon:
            # Reconstruction head: output dimension equals original node feature dimension (in_channels)
            self.reconstruction_head = AuxiliaryReconstructionHead(encoder_out, encoder_out // 2, target_dim=in_channels)
        # Classifier: concatenates graph embedding with physics features (pt, m0, explosion)
        self.classifier = ClassifierHead(encoder_out + 3, num_classes)
        self.classifier_dropout = nn.Dropout(0.2)
    
    def forward(self, data, mode='contrastive'):
        # data: a Batch from PyG containing x, edge_index, batch, pt, m0, explosion
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        embedding = self.encoder(x, edge_index, batch)
        
        if mode == 'contrastive':
            proj = self.projection_head(embedding)
            if self.use_aux_recon:
                rec = self.reconstruction_head(embedding)
                return proj, rec, embedding
            return proj, None, embedding
        elif mode == 'classification':
            if hasattr(data, 'pt') and hasattr(data, 'm0') and hasattr(data, 'explosion'):
                physics_features = torch.cat([
                    data.pt.view(-1, 1),
                    data.m0.view(-1, 1),
                    data.explosion.view(-1, 1)
                ], dim=1)
                embedding = torch.cat([embedding, physics_features], dim=1)
            embedding = self.classifier_dropout(embedding)
            logits = self.classifier(embedding)
            return logits
        else:
            raise ValueError("Mode must be 'contrastive' or 'classification'.")

In [36]:
def improved_nt_xent_loss(z1, z2, temperature=0.5, margin=0.5, lambda_reg=0.1):
    """
    NT-Xent loss with hard negative margin regularization.
    """
    batch_size = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)
    z = F.normalize(z, dim=1)
    
    similarity_matrix = torch.matmul(z, z.T)
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    similarity_matrix = similarity_matrix.masked_fill(mask, -9e15)
    similarity_matrix = similarity_matrix / temperature
    
    labels = torch.arange(batch_size, device=z.device)
    labels = torch.cat([labels, labels], dim=0)
    nt_xent = F.cross_entropy(similarity_matrix, labels)
    
    negatives = similarity_matrix.clone()
    negatives[mask] = -9e15
    max_negatives, _ = negatives.max(dim=1)
    margin_loss = F.relu(max_negatives - margin).mean()
    
    total_loss = nt_xent + lambda_reg * margin_loss
    return total_loss


In [37]:
class EnhancedTrainer:
    def __init__(self, model, train_loader, test_loader, device, lr=1e-3, lambda_aux=0.5):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.optimizer = Adam(self.model.parameters(), lr=lr)
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.5)
        self.criterion_cls = nn.CrossEntropyLoss()
        self.criterion_rec = nn.MSELoss()
        self.lambda_aux = lambda_aux  # Weight for auxiliary reconstruction loss
    
    def pretrain(self, epochs, drop_prob=0.2, edge_perturb_prob=0.1, temperature=0.5, margin=0.5, lambda_reg=0.1):
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for data in self.train_loader:
                data = data.to(self.device)
                if not hasattr(data, 'x') or data.x is None:
                    data.x = torch.ones((data.num_nodes, 16), device=self.device)
                
                # Generate two augmented views:
                # View 1: Node dropout + edge perturbation
                data1 = self._graph_augmentation(data, drop_prob, edge_perturb_prob)
                # View 2: Feature masking augmentation (zero out a fraction of node features)
                data2 = self._feature_mask_augmentation(data, mask_prob=0.2)
                data1 = data1.to(self.device)
                data2 = data2.to(self.device)
                
                self.optimizer.zero_grad()
                proj1, rec1, _ = self.model(data1, mode='contrastive')
                proj2, rec2, _ = self.model(data2, mode='contrastive')
                
                # Contrastive loss on projections
                loss_contrast = improved_nt_xent_loss(proj1, proj2, temperature, margin, lambda_reg)
                
                # Auxiliary reconstruction loss:
                # Use global mean pooling of the original node features as target.
                pooled1 = global_mean_pool(data1.x, data1.batch)
                pooled2 = global_mean_pool(data2.x, data2.batch)
                loss_rec1 = self.criterion_rec(rec1, pooled1)
                loss_rec2 = self.criterion_rec(rec2, pooled2)
                loss_rec = (loss_rec1 + loss_rec2) / 2.0
                
                total_loss_batch = loss_contrast + self.lambda_aux * loss_rec
                
                # Optional: Add adversarial loss for robustness (currently commented)
                # total_loss_batch += 0.01 * self._adversarial_loss(data)
                
                total_loss_batch.backward()
                self.optimizer.step()
                total_loss += total_loss_batch.item()
            self.scheduler.step()
            avg_loss = total_loss / len(self.train_loader)
            print(f"[Pretrain] Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, LR: {self.optimizer.param_groups[0]['lr']:.6f}")
    
    def finetune(self, epochs, freeze_encoder=True, early_stop_patience=5):
        if freeze_encoder:
            for param in self.model.encoder.parameters():
                param.requires_grad = False
        
        optimizer_cls = Adam(self.model.classifier.parameters(), lr=1e-3)
        scheduler_cls = StepLR(optimizer_cls, step_size=5, gamma=0.5)
        
        best_loss = float('inf')
        patience_counter = 0
        
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for data in self.train_loader:
                data = data.to(self.device)
                if not hasattr(data, 'x') or data.x is None:
                    data.x = torch.ones((data.num_nodes, 16), device=self.device)
                optimizer_cls.zero_grad()
                logits = self.model(data, mode='classification')
                loss = self.criterion_cls(logits, data.y)
                loss.backward()
                optimizer_cls.step()
                total_loss += loss.item()
            scheduler_cls.step()
            avg_loss = total_loss / len(self.train_loader)
            metrics = self.evaluate()
            print(f"[Finetune] Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Accuracy: {metrics['accuracy']*100:.2f}%, "
                  f"F1: {metrics['f1']:.4f}, ROC-AUC: {metrics['roc_auc']:.4f}, LR: {optimizer_cls.param_groups[0]['lr']:.6f}")
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
                # Optionally save checkpoint here
            else:
                patience_counter += 1
                if patience_counter >= early_stop_patience:
                    print(f"Early stopping triggered at epoch {epoch+1}")
                    break
    
    def evaluate(self):
        self.model.eval()
        all_preds, all_labels, all_probs = [], [], []
        with torch.no_grad():
            for data in self.test_loader:
                data = data.to(self.device)
                if not hasattr(data, 'x') or data.x is None:
                    data.x = torch.ones((data.num_nodes, 16), device=self.device)
                logits = self.model(data, mode='classification')
                probs = F.softmax(logits, dim=1)[:, 1]
                preds = logits.argmax(dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_probs.extend(probs.cpu().numpy())
                all_labels.extend(data.y.cpu().numpy())
        acc = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds, average="weighted")
        try:
            roc_auc = roc_auc_score(all_labels, all_probs)
        except Exception:
            roc_auc = 0.0
        self.model.train()
        return {"accuracy": acc, "f1": f1, "roc_auc": roc_auc}
    
    def _graph_augmentation(self, data, drop_prob, edge_perturb_prob):
        """
        Apply node dropout and stochastic edge perturbation.
        """
        node_mask = torch.rand(data.num_nodes, device=data.x.device) > drop_prob
        if node_mask.sum() == 0:
            node_mask[torch.randint(0, data.num_nodes, (1,))] = True
        new_idx = torch.zeros(data.num_nodes, dtype=torch.long, device=data.x.device)
        new_idx[node_mask] = torch.arange(node_mask.sum(), device=data.x.device)
        x = data.x[node_mask]
        edge_index = data.edge_index
        mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
        edge_index = edge_index[:, mask]
        edge_index = new_idx[edge_index]
        if edge_index.size(1) > 0:
            edge_drop_mask = torch.rand(edge_index.size(1), device=edge_index.device) > edge_perturb_prob
            edge_index = edge_index[:, edge_drop_mask]
        new_data = Data(x=x, edge_index=edge_index)
        if hasattr(data, 'batch'):
            new_data.batch = data.batch[node_mask]
        if hasattr(data, 'pt'):
            new_data.pt = data.pt
        if hasattr(data, 'm0'):
            new_data.m0 = data.m0
        if hasattr(data, 'explosion'):
            new_data.explosion = data.explosion
        return new_data
    
    def _feature_mask_augmentation(self, data, mask_prob=0.2):
        """
        Augment graphs by masking a fraction of node features.
        """
        new_data = Data()
        new_data.x = data.x.clone()
        mask = torch.rand(data.x.size(0), device=data.x.device) < mask_prob
        new_data.x[mask] = 0.0  # Zero out selected features
        new_data.edge_index = data.edge_index
        new_data.batch = data.batch
        if hasattr(data, 'pt'): new_data.pt = data.pt
        if hasattr(data, 'm0'): new_data.m0 = data.m0
        if hasattr(data, 'explosion'): new_data.explosion = data.explosion
        return new_data
    
    def _adversarial_loss(self, data, epsilon=0.01):
        """
        Compute a simple adversarial perturbation loss.
        Adds Gaussian noise to node features to simulate worst-case perturbation.
        """
        perturbed_x = data.x + epsilon * torch.randn_like(data.x)
        adv_data = Data(x=perturbed_x, edge_index=data.edge_index, batch=data.batch)
        adv_proj, _, _ = self.model(adv_data, mode='contrastive')
        orig_proj, _, _ = self.model(data, mode='contrastive')
        return F.mse_loss(adv_proj, orig_proj)

In [38]:
graph_path = '/kaggle/input/part-4-task-2-output/processed_jet_graphs/processed_chunk_90000_100000.pt'
npz_path = '/kaggle/input/genie-extracted-dataset/chunk_90000_100000.npz'

# Create dataset
dataset = JetGraphDataset(graph_path, npz_path)

# Split dataset indices using stratification on labels


  self.graphs = torch.load(graph_path, map_location=device)


In [43]:
indices = np.arange(len(dataset))
labels = np.array([dataset[i].y.item() for i in range(len(dataset))])
train_idx, test_idx = train_test_split(indices, test_size=0.1, random_state=42, stratify=labels)

train_graphs = [dataset[i] for i in train_idx]
test_graphs = [dataset[i] for i in test_idx]

print(f"Total graphs: {len(dataset)}; Training graphs: {len(train_graphs)}; Testing graphs: {len(test_graphs)}")

batch_size = 256
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

in_channels = train_graphs[0].x.shape[1]
hidden_channels = 64
encoder_out = 64
proj_dim = 32
num_classes = 2

# Instantiate the EnhancedGraphModel
model = EnhancedGraphModel(in_channels, hidden_channels, encoder_out, proj_dim, num_classes, use_aux_recon=True)
trainer = EnhancedTrainer(model, train_loader, test_loader, device, lr=1e-3, lambda_aux=0.5)



Total graphs: 10000; Training graphs: 9000; Testing graphs: 1000




In [44]:
pretrain_epochs = 10  # Adjust as necessary
print("Starting contrastive pre-training...")
trainer.pretrain(epochs=pretrain_epochs, drop_prob=0.1, temperature=0.7, margin=0.3, lambda_reg=0.4
                )

# %% [markdown]
# ### 6.2 Classification Fine-tuning
# 
# Fine-tune the classifier for anomaly detection/classification. Metrics are computed after each epoch.

# %%
finetune_epochs = 20  # Adjust as necessary
print("\nStarting classification fine-tuning...")
trainer.finetune(epochs=finetune_epochs, freeze_encoder=True)

Starting contrastive pre-training...




OutOfMemoryError: CUDA out of memory. Tried to allocate 728.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 589.12 MiB is free. Process 4668 has 15.31 GiB memory in use. Of the allocated memory 14.49 GiB is allocated by PyTorch, and 540.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [42]:
final_metrics = trainer.evaluate()
print("Final Evaluation Metrics:")
print(f"Accuracy: {final_metrics['accuracy']*100:.2f}%")
print(f"F1 Score: {final_metrics['f1']:.4f}")
print(f"ROC-AUC: {final_metrics['roc_auc']:.4f}")

Final Evaluation Metrics:
Accuracy: 66.60%
F1 Score: 0.6658
ROC-AUC: 0.7196


In [None]:
finetune_epochs = 10  # Adjust as necessary
print("\nStarting classification fine-tuning...")
trainer.finetune(epochs=finetune_epochs, freeze_encoder=True)