In [6]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import h5py
from tqdm import tqdm
import numpy as np

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fix the NumPy and PyTorch seeds so repeated runs draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # Seed every visible CUDA device as well.
    torch.cuda.manual_seed_all(42)


In [7]:

class JetDatasetLabeled(Dataset):
    """Map-style dataset that reads labelled jets from an HDF5 file.

    Indexing returns a tuple ``(jet, y, pt, m)`` of ``float32`` tensors read
    from the datasets named by ``jet_key``, ``y_key``, ``pt_key`` and
    ``m_key``. The jet entry is permuted with ``(2, 0, 1)``, i.e. its last
    axis is moved to the front.

    The file is opened lazily and the handle is reused for later reads, rather
    than reopening the file for every sample. The handle is tied to the
    process that opened it: a different process (for example a ``DataLoader``
    worker) opens its own handle, and the handle is never pickled.

    Args:
        file_path: Path of the HDF5 file to read.
        jet_key: Name of the dataset holding the jet images.
        y_key: Name of the dataset holding the labels.
        pt_key: Name of the dataset holding the pT values.
        m_key: Name of the dataset holding the mass values.
    """

    def __init__(self, file_path, jet_key="jet", y_key="Y", pt_key="pT", m_key="m"):
        self.file_path = file_path
        self.jet_key = jet_key
        self.y_key = y_key
        self.pt_key = pt_key
        self.m_key = m_key

        # Lazily created read handle and the PID of the process that owns it.
        self._h5 = None
        self._h5_pid = None

        # Only the sample count is needed up front; close the file right away
        # so that no open handle exists before worker processes are started.
        with h5py.File(file_path, 'r') as f:
            self.length = f[jet_key].shape[0]

    def __len__(self):
        """Return the number of samples (first dimension of the jet dataset)."""
        return self.length

    def __getstate__(self):
        """Return the instance state without the open file handle."""
        state = self.__dict__.copy()
        state["_h5"] = None
        state["_h5_pid"] = None
        return state

    def _open_file(self):
        """Return this process's read-only handle, opening it on first use."""
        pid = os.getpid()
        if getattr(self, "_h5", None) is None or getattr(self, "_h5_pid", None) != pid:
            # Either nothing is open yet, or the handle was inherited from
            # another process and must not be shared: open a fresh one.
            self._h5 = h5py.File(self.file_path, 'r')
            self._h5_pid = pid
        return self._h5

    def close(self):
        """Close the cached file handle, if this process opened one."""
        handle = getattr(self, "_h5", None)
        if handle is not None and getattr(self, "_h5_pid", None) == os.getpid():
            handle.close()
        self._h5 = None
        self._h5_pid = None

    def __getitem__(self, idx):
        """Read sample ``idx`` and return ``(jet, y, pt, m)`` as float32 tensors."""
        f = self._open_file()
        jet = f[self.jet_key][idx]
        y = f[self.y_key][idx]
        pt = f[self.pt_key][idx]
        m = f[self.m_key][idx]

        # Convert shape to [channels, eta, phi]
        jet = torch.tensor(jet, dtype=torch.float32).permute(2, 0, 1)
        y = torch.tensor(y, dtype=torch.float32)
        pt = torch.tensor(pt, dtype=torch.float32)
        m = torch.tensor(m, dtype=torch.float32)

        return jet, y, pt, m

In [8]:

class ResNetSSL(nn.Module):
    """ResNet encoder with a linear latent head and an MLP projection head.

    A torchvision ResNet (built with ``weights=None``) is used as the
    backbone. Its first convolution is replaced by a bias-free convolution
    that accepts ``in_channels`` input channels while keeping the original
    output channels, kernel size, stride and padding. Its ``fc`` layer is
    replaced by ``nn.Identity`` so the backbone returns the features that
    previously fed that layer.

    Args:
        latent_dim: Size of the latent vector produced by ``final_layer``.
        projection_dim: Size of the projector output.
        resnet_depth: Backbone depth; 18 and 34 are supported.
        in_channels: Number of channels of the input images (default 8).

    Raises:
        ValueError: If ``resnet_depth`` is not a supported depth.
    """

    def __init__(self, latent_dim=256, projection_dim=128, resnet_depth=18, in_channels=8):
        super().__init__()

        backbone_factories = {18: models.resnet18, 34: models.resnet34}
        if resnet_depth not in backbone_factories:
            raise ValueError("Unsupported ResNet depth")
        self.encoder_backbone = backbone_factories[resnet_depth](weights=None)

        # Read the feature width from the backbone's own classifier instead of
        # repeating it as a literal, so it cannot drift from the architecture.
        num_bottleneck_features = self.encoder_backbone.fc.in_features

        # Swap the stem convolution for one that takes `in_channels` inputs;
        # every other hyper-parameter is copied from the original layer.
        original_conv1 = self.encoder_backbone.conv1
        self.encoder_backbone.conv1 = nn.Conv2d(
            in_channels,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=False
        )

        # Drop the classification layer; the backbone now outputs features.
        self.encoder_backbone.fc = nn.Identity()

        self.final_layer = nn.Linear(num_bottleneck_features, latent_dim)

        self.projector = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 2),
            nn.ReLU(),
            nn.Linear(latent_dim * 2, projection_dim)
        )

    def forward(self, x):
        """Return ``(features_latent, projections)`` for the input batch ``x``.

        ``features_latent`` is the output of ``final_layer`` applied to the
        backbone features; ``projections`` is the projector applied to
        ``features_latent``.
        """
        features_bottleneck = self.encoder_backbone(x)
        features_latent = self.final_layer(features_bottleneck)
        projections = self.projector(features_latent)
        return features_latent, projections

In [9]:

def extract_features(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and collect its latent features.

    The model is switched to evaluation mode (and left in it) and is run
    without gradient tracking. For every batch, the first element returned by
    the model is moved to the CPU and stored together with the batch's
    ``y``, ``pt`` and ``m`` tensors.

    Args:
        model: Module whose call returns a 2-tuple; the first element is kept
            as the feature tensor and the second is ignored.
        dataloader: Iterable yielding ``(jets, y, pt, m)`` batches.
        device: Device the input batches are moved to. When ``None`` the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function does not rely on a global.

    Returns:
        Tuple ``(features, y, pt, m)``, each the concatenation along dim 0 of
        the per-batch tensors. An empty dataloader makes ``torch.cat`` fail
        because there is nothing to concatenate.
    """
    if device is None:
        # Follow the model rather than a notebook-level global so that inputs
        # and weights always end up on the same device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    features_list = []
    y_list = []
    pt_list = []
    m_list = []

    with torch.no_grad():
        for jets, y, pt, m in tqdm(dataloader, desc="Extracting features"):
            jets = jets.to(device)
            features, _ = model(jets)
            # Move results off the accelerator immediately to keep its memory free.
            features_list.append(features.cpu())
            y_list.append(y)
            pt_list.append(pt)
            m_list.append(m)

    features = torch.cat(features_list, dim=0)
    y = torch.cat(y_list, dim=0)
    pt = torch.cat(pt_list, dim=0)
    m = torch.cat(m_list, dim=0)

    return features, y, pt, m



## Loading the pre-trained ResNetSSL weights and extracting latent vectors


In [10]:

def load_model(checkpoint_name, model, device, checkpoint_dir="/kaggle/input/greattt"):
    """Load weights from a checkpoint file into ``model`` and return it.

    The checkpoint is expected to be a mapping with the keys
    ``'model_state_dict'``, ``'epoch'`` and ``'loss'``; the last two are only
    used for the summary line that is printed.

    Args:
        checkpoint_name: File name of the checkpoint inside ``checkpoint_dir``.
        model: Module that receives the weights (modified in place).
        device: ``map_location`` used when reading the checkpoint.
        checkpoint_dir: Directory that contains the checkpoint file.

    Returns:
        The same ``model`` object, with the checkpoint weights loaded.
    """
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
    # NOTE(security): unless weights_only=True is in effect, torch.load
    # unpickles arbitrary objects, so only load checkpoints from a trusted
    # source. Passing weights_only=True is safer when the file holds plain
    # tensors and Python scalars only.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {checkpoint_name} (Epoch {checkpoint['epoch']} with Loss {checkpoint['loss']:.4f})")
    return model

# Location of the labelled jets and the batch size used to stream them.
labeled_file_path = "/kaggle/input/dataset-specific-labelled-full-only-for-2i/Dataset_Specific_labelled_full_only_for_2i.h5"
batch_size = 64

# Rebuild the architecture on the target device, then restore the trained weights.
ssl_model = load_model(
    "best_model (20).pt",
    ResNetSSL(latent_dim=256, projection_dim=128, resnet_depth=18).to(device),
    device,
)
ssl_model.eval()

# Stream the labelled jets in file order (no shuffling).
labeled_dataset = JetDatasetLabeled(labeled_file_path)
labeled_loader = DataLoader(
    dataset=labeled_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Keep the latent vectors and the labels; pT and mass are not needed here.
features, y_labels, _, _ = extract_features(ssl_model, labeled_loader)

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded model from best_model (20).pt (Epoch 22 with Loss 0.0037)


Extracting features: 100%|██████████| 157/157 [00:41<00:00,  3.77it/s]
