In [None]:
import os
import random
import warnings
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
import torch.optim as optim
from einops import rearrange
from tqdm import tqdm
from torchvision.transforms import GaussianBlur
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, mean_squared_error, f1_score, confusion_matrix, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from scipy.stats import gaussian_kde
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler

# Device configuration and seeding
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

In [None]:

class Augmentation:
    def __init__(self, size=112, crop_scale=(0.8, 1.0), brightness=(0.8, 1.2)):
        self.size = size
        self.crop_scale = crop_scale
        self.brightness = brightness

    def __call__(self, x):
        # x is a torch tensor image of shape (C, H, W)
        _, h, w = x.shape
        
        # Random Resized Crop
        scale = random.uniform(self.crop_scale[0], self.crop_scale[1])
        new_h = int(h * scale)
        new_w = int(w * scale)
        if h > new_h and w > new_w:
            top = random.randint(0, h - new_h)
            left = random.randint(0, w - new_w)
            x = x[:, top:top+new_h, left:left+new_w]
        # Resize to the target size
        x = F.interpolate(x.unsqueeze(0), size=(self.size, self.size),
                          mode='bilinear', align_corners=False).squeeze(0)
        
        # Random Horizontal Flip
        if random.random() < 0.5:
            x = torch.flip(x, dims=[2])
            
        # Random brightness jitter
        if random.random() < 0.5:
            factor = random.uniform(self.brightness[0], self.brightness[1])
            x = x * factor
            x = torch.clamp(x, 0, 1)
        return x


In [None]:

class JetDatasetSSL(Dataset):
    def __init__(self, file_path, key, transform=None):
        self.file_path = file_path
        self.key = key
        with h5py.File(file_path, 'r') as f:
            self.length = f[key].shape[0]
        self.transform = transforms.Compose([Augmentation()]) if transform is None else transform

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as f:
            jet = f[self.key][idx]  # Expected shape: [125, 125, 8]
        # Convert shape to [channels, eta, phi]
        jet = torch.tensor(jet, dtype=torch.float32).permute(2, 0, 1)
        # Create two augmented views for contrastive learning
        view1 = self.transform(jet)
        view2 = self.transform(jet)
        return view1, view2

In [5]:
# New class for labeled dataset
class JetDatasetLabeled(Dataset):
    def __init__(self, file_path, jet_key="jet", y_key="Y", pt_key="pT", m_key="m"):
        self.file_path = file_path
        self.jet_key = jet_key
        self.y_key = y_key
        self.pt_key = pt_key
        self.m_key = m_key
        
        with h5py.File(file_path, 'r') as f:
            self.length = f[jet_key].shape[0]
            
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as f:
            jet = f[self.jet_key][idx]
            y = f[self.y_key][idx]
            pt = f[self.pt_key][idx]
            m = f[self.m_key][idx]
            
        # Convert shape to [channels, eta, phi]
        jet = torch.tensor(jet, dtype=torch.float32).permute(2, 0, 1)
        y = torch.tensor(y, dtype=torch.long)
        pt = torch.tensor(pt, dtype=torch.float32)
        m = torch.tensor(m, dtype=torch.float32)
        
        return jet, y, pt, m

In [None]:
class ParticleTransformer(nn.Module):
    def __init__(self, in_channels=8, latent_dim=256):
        super().__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv2d(8, 64, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(256),
            nn.AdaptiveAvgPool2d((4, 4))
        )
        
        self.pos_embedding = nn.Parameter(torch.randn(1, 16, 256) * 0.02)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256,
                nhead=8,
                dim_feedforward=1024,
                dropout=0.1,
                activation='gelu',
                batch_first=True,
                norm_first=True
            ),
            num_layers=2
        )
        
        self.head = nn.Sequential(
            nn.Linear(256, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(0.1),
            nn.Linear(512, latent_dim)
        )
    
    def forward(self, x):
        x = self.cnn(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = x + self.pos_embedding
        x = self.transformer(x)
        x = x.mean(dim=1)
        return self.head(x)

In [7]:
class SimCLRModel(nn.Module):
    def __init__(self, latent_dim=256, projection_dim=128):
        super().__init__()
        self.encoder = ParticleTransformer(latent_dim=latent_dim)
        self.projector = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, projection_dim)
        )

    def forward(self, x):
        features = self.encoder(x)
        projections = self.projector(features)
        return features, projections       

In [8]:
class LARS(optim.Optimizer):
    def __init__(self, optimizer, eps=1e-9, trust_coef=0.002):
        """
        LARS wrapper for an optimizer with updated hyperparameters.

        Args:
            optimizer: The inner optimizer (e.g. AdamW).
            eps: Small value for numerical stability.
            trust_coef: Coefficient for computing the local learning rate.
        """
        self.optimizer = optimizer
        self.eps = eps
        self.trust_coef = trust_coef

    def step(self, closure=None):
        for group in self.optimizer.param_groups:
            weight_decay = group.get("weight_decay", 0)
            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p = d_p.add(p.data, alpha=weight_decay)
                w_norm = p.data.norm()
                g_norm = d_p.norm()
                if w_norm > 0 and g_norm > 0:
                    local_lr = self.trust_coef * w_norm / (g_norm + self.eps)
                    d_p.mul_(local_lr)
        self.optimizer.step(closure)

    def zero_grad(self):
        self.optimizer.zero_grad()

In [9]:
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.05, eps=1e-6):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)

        # Normalize the input embeddings
        z_i = F.normalize(z_i, dim=1, eps=self.eps)
        z_j = F.normalize(z_j, dim=1, eps=self.eps)

        # Concatenate the embeddings
        representations = torch.cat([z_i, z_j], dim=0)

        # Compute similarity matrix (use float32 for stability)
        similarity_matrix = torch.mm(representations, representations.t()) / self.temperature
        similarity_matrix = similarity_matrix.float()

        # Mask out self-similarity using a safer mask value
        mask = torch.eye(2 * batch_size, device=z_i.device).bool()
        similarity_matrix = similarity_matrix.masked_fill(mask, -1e4)

        # Compute positive and negative samples
        positives = torch.cat([
            torch.diag(similarity_matrix, batch_size),
            torch.diag(similarity_matrix, -batch_size)
        ])

        negatives = similarity_matrix[~mask].view(2 * batch_size, -1)

        # Calculate the loss using logsumexp for numerical stability
        logsumexp_negatives = torch.logsumexp(negatives, dim=1)
        loss = -torch.log(torch.exp(positives) / (torch.exp(logsumexp_negatives) + self.eps)).mean()

        return loss

In [10]:
def train_ssl(model, train_loader, num_epochs=100, save_dir="./models", validation_loader=None, patience=5, use_lars=True):
    """
    Train the SSL model with improved monitoring, validation, early stopping, and optional LARS.
    
    Args:
        model: The SimCLR model to train.
        train_loader: DataLoader for training data.
        num_epochs: Number of training epochs.
        save_dir: Directory to save model checkpoints.
        validation_loader: Optional loader for validation.
        patience: Number of epochs with no improvement after which training will be stopped.
        use_lars: If True, wraps the underlying optimizer with LARS.
        
    Returns:
        Trained model and training history.
    """
    os.makedirs(save_dir, exist_ok=True)
    
    # Setup base optimizer using AdamW
    base_optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=4e-5)
    optimizer = LARS(base_optimizer, eps=1e-9, trust_coef=0.002) if use_lars else base_optimizer
    
    # Use the inner optimizer for the scheduler if LARS wrapper is used
    scheduler = CosineAnnealingLR(optimizer.optimizer if use_lars else optimizer, T_max=num_epochs)
    
    # Contrastive loss with temperature scaling
    criterion = NTXentLoss(temperature=0.05)
    
    best_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}
    early_stop_counter = 0
    scaler = GradScaler(device=device) if device.type == "cuda" else GradScaler()
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for view1, view2 in pbar:
            x1, x2 = view1.to(device), view2.to(device)
            optimizer.zero_grad()
            
            with autocast(device_type=device.type):
                _, proj1 = model(x1)
                _, proj2 = model(x2)
                loss = criterion(proj1, proj2)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer.optimizer if use_lars else optimizer)
            scaler.update()
            
            total_loss += loss.item()
            current_lr = (optimizer.optimizer.param_groups[0]['lr'] if use_lars 
                          else optimizer.param_groups[0]['lr'])
            pbar.set_postfix({'loss': loss.item(), 'lr': current_lr})
        
        avg_loss = total_loss / len(train_loader)
        history['train_loss'].append(avg_loss)
        
        current_loss = avg_loss
        if validation_loader is not None:
            model.eval()
            val_total_loss = 0.0
            with torch.no_grad():
                for view1, view2 in validation_loader:
                    x1, x2 = view1.to(device), view2.to(device)
                    _, proj1 = model(x1)
                    _, proj2 = model(x2)
                    loss = criterion(proj1, proj2)
                    val_total_loss += loss.item()
            val_loss = val_total_loss / len(validation_loader)
            history['val_loss'].append(val_loss)
            print(f"Epoch {epoch+1} - Train Loss: {avg_loss:.4f}, Val Loss: {val_loss:.4f}")
            current_loss = val_loss
        else:
            print(f"Epoch {epoch+1} - Train Loss: {avg_loss:.4f}")
        
        if val_loss < best_loss:
            best_loss = val_loss
            early_stop_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': (optimizer.optimizer.state_dict() if use_lars 
                                         else optimizer.state_dict()),
                'loss': best_loss,
                'history': history,
            }, os.path.join(save_dir, "best_model.pt"))
            print(f"New best model saved! Loss: {best_loss:.4f}")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1} with best loss {best_loss:.4f}")
                break
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': (optimizer.optimizer.state_dict() if use_lars 
                                     else optimizer.state_dict()),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': current_loss,
            'history': history,
        }, os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt"))
            
        scheduler.step()
    
    # Plot training curve
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    if history['val_loss']:
        plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    
    plt.subplot(1, 2, 2)
    plt.semilogy(history['train_loss'], label='Training Loss')
    if history['val_loss']:
        plt.semilogy(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (log scale)')
    plt.legend()
    plt.title('Loss on Log Scale')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "training_curve.png"))
    plt.close()
    
    return model, history

In [11]:

# Extract features using the pre-trained encoder
def extract_features(model, dataloader):
    model.eval()
    features_list = []
    y_list = []
    pt_list = []
    m_list = []
    
    with torch.no_grad():
        for jets, y, pt, m in tqdm(dataloader, desc="Extracting features"):
            jets = jets.to(device)
            features, _ = model(jets)
            features_list.append(features.cpu())
            y_list.append(y)
            pt_list.append(pt)
            m_list.append(m)
    
    # Concatenate all batches
    features = torch.cat(features_list, dim=0)
    y = torch.cat(y_list, dim=0)
    pt = torch.cat(pt_list, dim=0)
    m = torch.cat(m_list, dim=0)
    
    return features, y, pt, m



In [12]:

def main():
    # Configuration
    ssl_file_path = "/kaggle/input/dataset-specific-unlabelled/Dataset_Specific_Unlabelled.h5"
    labeled_file_path = "/kaggle/input/dataset-specific-labelled-full-only-for-2i/Dataset_Specific_labelled_full_only_for_2i.h5"
    batch_size = 128
    num_epochs_ssl = 30  # Increased for better convergence
    latent_dim = 256
    save_dir = "./models"
    os.makedirs(save_dir, exist_ok=True)
    
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    
    # Step 1: Self-Supervised Learning Pretraining
    print("Loading SSL dataset...")
    try:
        ssl_dataset = JetDatasetSSL(ssl_file_path, "jet")
        print(f"Dataset loaded with {len(ssl_dataset)} samples")
        
        # Create train/validation split for SSL
        train_size = int(0.9 * len(ssl_dataset))
        val_size = len(ssl_dataset) - train_size
        ssl_train_dataset, ssl_val_dataset = torch.utils.data.random_split(
            ssl_dataset, [train_size, val_size], 
            generator=torch.Generator().manual_seed(42)
        )
        
        ssl_train_loader = DataLoader(
            ssl_train_dataset, 
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=True
        )
        
        ssl_val_loader = DataLoader(
            ssl_val_dataset, 
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            drop_last=True
        )
        
        # Initialize model
        print(f"Initializing SimCLR model on {device}")
        ssl_model = SimCLRModel(latent_dim=latent_dim).to(device)
        
        # Print model summary
        total_params = sum(p.numel() for p in ssl_model.parameters())
        print(f"Model created with {total_params:,} parameters")
        
        # Train model
        print("Starting SSL training...")
        ssl_model, history = train_ssl(
            ssl_model, 
            ssl_train_loader, 
            num_epochs=num_epochs_ssl, 
            save_dir=save_dir,
            validation_loader=ssl_val_loader ,use_lars=True)  
        
        print("SSL training completed!")
        
    except Exception as e:
        print(f"Error during SSL training: {e}")
        raise

if __name__ == "__main__":
    main()

Loading SSL dataset...
Dataset loaded with 60000 samples
Initializing SimCLR model on cuda




Model created with 2,321,344 parameters
Starting SSL training...


Epoch 1/30: 100%|██████████| 421/421 [04:32<00:00,  1.55it/s, loss=4.05, lr=0.0003]


Epoch 1 - Train Loss: 4.7785, Val Loss: 5.6166
New best model saved! Loss: 5.6166


Epoch 2/30: 100%|██████████| 421/421 [04:09<00:00,  1.69it/s, loss=3.02, lr=0.000299]


Epoch 2 - Train Loss: 3.4739, Val Loss: 4.0090
New best model saved! Loss: 4.0090


Epoch 3/30: 100%|██████████| 421/421 [03:30<00:00,  2.00it/s, loss=2.44, lr=0.000297]


Epoch 3 - Train Loss: 2.9188, Val Loss: 11.1161


Epoch 4/30: 100%|██████████| 421/421 [03:29<00:00,  2.01it/s, loss=1.8, lr=0.000293]


Epoch 4 - Train Loss: 2.2420, Val Loss: 3.0492
New best model saved! Loss: 3.0492


Epoch 5/30: 100%|██████████| 421/421 [03:25<00:00,  2.05it/s, loss=1.22, lr=0.000287]


Epoch 5 - Train Loss: 1.6156, Val Loss: 2.0525
New best model saved! Loss: 2.0525


Epoch 6/30: 100%|██████████| 421/421 [03:32<00:00,  1.98it/s, loss=0.725, lr=0.00028]


Epoch 6 - Train Loss: 1.0695, Val Loss: 1.1865
New best model saved! Loss: 1.1865


Epoch 7/30: 100%|██████████| 421/421 [03:33<00:00,  1.97it/s, loss=0.744, lr=0.000271]


Epoch 7 - Train Loss: 0.9892, Val Loss: 3.6994


Epoch 8/30: 100%|██████████| 421/421 [03:32<00:00,  1.98it/s, loss=0.655, lr=0.000261]


Epoch 8 - Train Loss: 0.7230, Val Loss: 5.7851


Epoch 9/30: 100%|██████████| 421/421 [03:23<00:00,  2.07it/s, loss=0.671, lr=0.00025]


Epoch 9 - Train Loss: 0.6023, Val Loss: 2.9267


Epoch 10/30: 100%|██████████| 421/421 [03:25<00:00,  2.05it/s, loss=0.296, lr=0.000238]


Epoch 10 - Train Loss: 0.4459, Val Loss: 0.9301
New best model saved! Loss: 0.9301


Epoch 11/30: 100%|██████████| 421/421 [03:24<00:00,  2.06it/s, loss=0.271, lr=0.000225]


Epoch 11 - Train Loss: 0.3265, Val Loss: 0.3724
New best model saved! Loss: 0.3724


Epoch 12/30: 100%|██████████| 421/421 [03:28<00:00,  2.02it/s, loss=0.254, lr=0.000211]


Epoch 12 - Train Loss: 0.3233, Val Loss: 0.1862
New best model saved! Loss: 0.1862


Epoch 13/30: 100%|██████████| 421/421 [03:27<00:00,  2.03it/s, loss=0.205, lr=0.000196]


Epoch 13 - Train Loss: 0.2373, Val Loss: 0.2413


Epoch 14/30: 100%|██████████| 421/421 [03:24<00:00,  2.06it/s, loss=0.154, lr=0.000181]


Epoch 14 - Train Loss: 0.1712, Val Loss: 0.1550
New best model saved! Loss: 0.1550


Epoch 15/30: 100%|██████████| 421/421 [03:29<00:00,  2.01it/s, loss=0.17, lr=0.000166]


Epoch 15 - Train Loss: 0.1433, Val Loss: 0.3746


Epoch 16/30: 100%|██████████| 421/421 [03:24<00:00,  2.06it/s, loss=0.147, lr=0.00015]


Epoch 16 - Train Loss: 0.1144, Val Loss: 0.1190
New best model saved! Loss: 0.1190


Epoch 17/30: 100%|██████████| 421/421 [03:28<00:00,  2.02it/s, loss=0.102, lr=0.000134]


Epoch 17 - Train Loss: 0.0836, Val Loss: 0.0629
New best model saved! Loss: 0.0629


Epoch 18/30: 100%|██████████| 421/421 [03:30<00:00,  2.00it/s, loss=0.0728, lr=0.000119]


Epoch 18 - Train Loss: 0.0752, Val Loss: 0.0967


Epoch 19/30: 100%|██████████| 421/421 [03:27<00:00,  2.03it/s, loss=0.0958, lr=0.000104]


Epoch 19 - Train Loss: 0.0692, Val Loss: 0.0485
New best model saved! Loss: 0.0485


Epoch 20/30: 100%|██████████| 421/421 [03:28<00:00,  2.02it/s, loss=0.0514, lr=8.9e-5]


Epoch 20 - Train Loss: 0.0643, Val Loss: 0.0406
New best model saved! Loss: 0.0406


Epoch 21/30: 100%|██████████| 421/421 [03:26<00:00,  2.04it/s, loss=0.0758, lr=7.5e-5]


Epoch 21 - Train Loss: 0.0578, Val Loss: 0.0356
New best model saved! Loss: 0.0356


Epoch 22/30: 100%|██████████| 421/421 [03:27<00:00,  2.03it/s, loss=0.0597, lr=6.18e-5]


Epoch 22 - Train Loss: 0.0523, Val Loss: 0.0387


Epoch 23/30: 100%|██████████| 421/421 [03:31<00:00,  1.99it/s, loss=0.0278, lr=4.96e-5]


Epoch 23 - Train Loss: 0.0521, Val Loss: 0.0378


Epoch 24/30: 100%|██████████| 421/421 [03:28<00:00,  2.02it/s, loss=0.0326, lr=3.85e-5]


Epoch 24 - Train Loss: 0.0477, Val Loss: 0.0306
New best model saved! Loss: 0.0306


Epoch 25/30: 100%|██████████| 421/421 [03:34<00:00,  1.97it/s, loss=0.0592, lr=2.86e-5]


Epoch 25 - Train Loss: 0.0437, Val Loss: 0.0295
New best model saved! Loss: 0.0295


Epoch 26/30: 100%|██████████| 421/421 [03:38<00:00,  1.93it/s, loss=0.0465, lr=2.01e-5]


Epoch 26 - Train Loss: 0.0423, Val Loss: 0.0278
New best model saved! Loss: 0.0278


Epoch 27/30: 100%|██████████| 421/421 [03:35<00:00,  1.96it/s, loss=0.0194, lr=1.3e-5]


Epoch 27 - Train Loss: 0.0405, Val Loss: 0.0281


Epoch 28/30: 100%|██████████| 421/421 [03:32<00:00,  1.98it/s, loss=0.0306, lr=7.34e-6]


Epoch 28 - Train Loss: 0.0380, Val Loss: 0.0309


Epoch 29/30: 100%|██████████| 421/421 [03:32<00:00,  1.98it/s, loss=0.0226, lr=3.28e-6]


Epoch 29 - Train Loss: 0.0394, Val Loss: 0.0271
New best model saved! Loss: 0.0271


Epoch 30/30: 100%|██████████| 421/421 [03:37<00:00,  1.94it/s, loss=0.0273, lr=8.22e-7]


Epoch 30 - Train Loss: 0.0390, Val Loss: 0.0254
New best model saved! Loss: 0.0254
SSL training completed!
