In [1]:
import os
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import h5py
from tqdm import tqdm
import numpy as np

# Seed NumPy and PyTorch (CPU, plus every visible GPU when CUDA is present)
# so weight initialisation and any sampling are repeatable across runs.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Prefer the GPU when one is available; everything else falls back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:

class JetDatasetLabeled(Dataset):
    def __init__(self, file_path, jet_key="jet", y_key="Y", pt_key="pT", m_key="m"):
        self.file_path = file_path
        self.jet_key = jet_key
        self.y_key = y_key
        self.pt_key = pt_key
        self.m_key = m_key
        
        with h5py.File(file_path, 'r') as f:
            self.length = f[jet_key].shape[0]
            
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as f:
            jet = f[self.jet_key][idx]
            y = f[self.y_key][idx]
            pt = f[self.pt_key][idx]
            m = f[self.m_key][idx]
            
        # Convert shape to [channels, eta, phi]
        jet = torch.tensor(jet, dtype=torch.float32).permute(2, 0, 1)
        y = torch.tensor(y, dtype=torch.float32)
        pt = torch.tensor(pt, dtype=torch.float32)
        m = torch.tensor(m, dtype=torch.float32)
        
        return jet, y, pt, m

In [3]:
class ParticleTransformer(nn.Module):
    """CNN + Transformer encoder mapping a jet image to a latent vector.

    Pipeline:
    1. A three-stage convolutional stem ends in ``AdaptiveAvgPool2d((4, 4))``,
       giving a 4x4 grid of 256-dim feature vectors.
    2. The grid is flattened into a sequence of 16 tokens and a learned
       positional embedding is added.
    3. The tokens pass through a 2-layer pre-norm Transformer encoder and are
       mean-pooled.
    4. An MLP head projects the pooled vector to ``latent_dim``.

    Args:
        in_channels: number of channels in the input image (input width of the
            first Conv2d). Defaults to 8.
        latent_dim: size of the output embedding. Defaults to 256.

    Note:
        Submodule names and their order inside each ``nn.Sequential`` define the
        ``state_dict`` keys. Keep them unchanged so existing checkpoints still load.
    """

    def __init__(self, in_channels=8, latent_dim=256):
        super().__init__()

        self.cnn = nn.Sequential(
            # Honour the constructor argument instead of a hard-coded 8, so
            # non-8-channel inputs can be used (default behaviour is unchanged).
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(256),
            # Fixed 4x4 output grid regardless of input spatial size -> 16 tokens.
            nn.AdaptiveAvgPool2d((4, 4))
        )

        # One learned embedding per token: 16 = 4*4 pooled positions, 256 = CNN channels.
        self.pos_embedding = nn.Parameter(torch.randn(1, 16, 256) * 0.02)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256,
                nhead=8,
                dim_feedforward=1024,
                dropout=0.1,
                activation='gelu',
                batch_first=True,
                norm_first=True
            ),
            num_layers=2
        )

        self.head = nn.Sequential(
            nn.Linear(256, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(0.1),
            nn.Linear(512, latent_dim)
        )

    def forward(self, x):
        """Encode a batch of images into ``(batch, latent_dim)`` embeddings.

        ``x`` is assumed to be ``(batch, in_channels, H, W)``; the dataset in this
        notebook yields per-sample ``[channels, eta, phi]``.
        """
        x = self.cnn(x)                                # (b, 256, 4, 4)
        x = rearrange(x, 'b c h w -> b (h w) c')       # (b, 16, 256) token sequence
        x = x + self.pos_embedding
        x = self.transformer(x)
        x = x.mean(dim=1)                              # average over the 16 tokens
        return self.head(x)


In [4]:
class SimCLRModel(nn.Module):
    """SimCLR-style wrapper: a ParticleTransformer encoder plus an MLP projection head.

    ``forward`` returns both the encoder output (size ``latent_dim``) and its
    projection (size ``projection_dim``).

    Args:
        latent_dim: encoder output size and projector hidden width. Defaults to 256.
        projection_dim: projector output size. Defaults to 128.
    """

    def __init__(self, latent_dim=256, projection_dim=128):
        super().__init__()
        self.encoder = ParticleTransformer(latent_dim=latent_dim)
        # Layer order fixes the state_dict keys (projector.0 / .2 / .3); keep it
        # so saved checkpoints remain loadable.
        projection_layers = [
            nn.Linear(latent_dim, latent_dim),
            nn.GELU(),
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, projection_dim),
        ]
        self.projector = nn.Sequential(*projection_layers)

    def forward(self, x):
        """Return ``(features, projections)`` for a batch of jet images ``x``."""
        encoded = self.encoder(x)
        return encoded, self.projector(encoded)

In [5]:

def extract_features(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and collect features plus per-jet metadata.

    The model is called as ``features, _ = model(jets)``, so it must return a
    2-tuple whose first element is the feature tensor (as ``SimCLRModel.forward``
    does). Features are moved to CPU batch by batch. They are concatenated along
    dim 0, together with the ``y``, ``pt`` and ``m`` tensors yielded by the loader.

    Args:
        model: module returning ``(features, other)`` for a batch of jets.
        dataloader: iterable of ``(jets, y, pt, m)`` batches.
        device: device that ``jets`` are moved to before the forward pass.
            Defaults to the device of the model's first parameter, or CPU if the
            model has no parameters.

    Returns:
        ``(features, y, pt, m)``: each one is a tensor concatenated over all batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        # Derive the device from the model rather than relying on a notebook global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    features_list, y_list, pt_list, m_list = [], [], [], []
    try:
        with torch.no_grad():
            for jets, y, pt, m in tqdm(dataloader, desc="Extracting features"):
                features, _ = model(jets.to(device))
                features_list.append(features.cpu())
                y_list.append(y)
                pt_list.append(pt)
                m_list.append(m)
    finally:
        # Restore the caller's train/eval mode instead of leaving the model in eval.
        model.train(was_training)

    if not features_list:
        # torch.cat([]) would fail with an unhelpful message; say what went wrong.
        raise ValueError("extract_features: dataloader yielded no batches")

    return (
        torch.cat(features_list, dim=0),
        torch.cat(y_list, dim=0),
        torch.cat(pt_list, dim=0),
        torch.cat(m_list, dim=0),
    )



## Load the pre-trained ParticleTransformer weights and extract latent vectors


In [6]:

def load_model(checkpoint_name, model, device, checkpoint_dir="/kaggle/input/greattt", weights_only=False):
    """Load a training checkpoint's weights into ``model`` and report its metadata.

    Args:
        checkpoint_name: file name of the checkpoint inside ``checkpoint_dir``.
        model: module whose ``load_state_dict`` receives ``checkpoint['model_state_dict']``.
        device: passed to ``torch.load`` as ``map_location``.
        checkpoint_dir: directory holding the checkpoint. Defaults to the Kaggle
            input folder this notebook was written against.
        weights_only: forwarded to ``torch.load``. It is passed explicitly so the
            "weights_only default is changing" FutureWarning is not emitted, and
            so behaviour does not shift with the PyTorch version.

    Returns:
        The same ``model`` instance, with its weights loaded.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code from the file. Use it only on trusted checkpoints; pass
    # weights_only=True when the checkpoint holds only tensors and plain Python types.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Epoch/loss are informational only; don't fail after a successful load if absent.
    epoch = checkpoint.get('epoch', '?')
    loss = checkpoint.get('loss')
    loss_str = f"{float(loss):.4f}" if loss is not None else "n/a"
    print(f"Loaded model from {checkpoint_name} (Epoch {epoch} with Loss {loss_str})")
    return model

# Location and loader settings for the labelled evaluation set.
labeled_file_path = "/kaggle/input/dataset-specific-labelled-full-only-for-2i/Dataset_Specific_labelled_full_only_for_2i.h5"
batch_size = 64

# Restore the self-supervised encoder + projector from the best checkpoint.
ssl_model = load_model("best_model (4).pt", SimCLRModel(latent_dim=256).to(device), device)
ssl_model.eval()

# Fixed order (shuffle=False) keeps features aligned with the file's row order.
labeled_dataset = JetDatasetLabeled(labeled_file_path)
labeled_loader = DataLoader(
    labeled_dataset,
    shuffle=False,
    batch_size=batch_size,
    pin_memory=True,
    num_workers=4,
)

# Keep only the latent features and labels; pT and mass are discarded here.
features, y_labels, _, _ = extract_features(ssl_model, labeled_loader)

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded model from best_model (4).pt (Epoch 29 with Loss 0.0254)


Extracting features: 100%|██████████| 157/157 [00:37<00:00,  4.22it/s]


In [12]:
# Bundle the latent vectors with their labels (both on CPU) and persist them via torch.save.
features_to_save, labels_to_save = features.cpu(), y_labels.cpu()
data_to_save = dict(features=features_to_save, labels=labels_to_save)

torch.save(data_to_save, "/kaggle/working/ParticleTransformer_latent_vectors")
print("Data saved successfully using PyTorch.")

Data saved successfully using PyTorch.
