In [None]:
!git clone https://github.com/Mamiglia/challenge.git

In [None]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from collections import defaultdict
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
import random
import numpy as np
import math
import pickle

from transformers import CLIPModel,CLIPProcessor
import torch
import numpy as np
from tqdm import tqdm
import transformers.utils.hub as hub_utils 
from PIL import Image
import pandas as pd
from collections import defaultdict
from challenge.src.eval import evaluate_retrieval

In [None]:
# ============================================================================
# SECTION 1: DOWNLOAD MANAGER
# ============================================================================
def get_dataset_paths():
    """Fetch the three MindEye files (fMRI betas, COCO images, trial indices).

    All files come from the ``pscotti/mindeyev2`` dataset repository on the
    HuggingFace Hub and are cached locally by ``hf_hub_download``.

    Returns:
        tuple[str, str, str]: local paths ``(betas_path, images_path, behav_path)``.
    """
    print("Downloading files from HuggingFace...")
    repo = "pscotti/mindeyev2"
    wanted_files = (
        "betas_all_subj01_fp32_renorm.hdf5",
        "coco_images_224_float16.hdf5",
        "COCO_73k_subj_indices.hdf5",
    )
    # Downloaded in the same order as the returned tuple.
    local_paths = tuple(
        hf_hub_download(repo_id=repo, filename=file_name, repo_type="dataset")
        for file_name in wanted_files
    )
    print("✓ Files downloaded successfully")
    return local_paths

In [None]:
class MindEyeDataset(Dataset):
    """Trial-level dataset pairing fMRI betas with the COCO image shown on that trial.

    HDF5 handles are opened lazily on first item access, so each DataLoader
    worker opens its own. Open handles are dropped when the dataset is pickled,
    which keeps multi-worker loading working even after the main process has
    already read items.
    """

    def __init__(self, betas_file, images_file, behav_file, 
                 subject='subj01', transform=None):
        """
        Args:
            betas_file: Path to fMRI betas HDF5 (dataset key 'betas', indexed by trial)
            images_file: Path to images HDF5 (dataset key 'images', indexed by image id)
            behav_file: Path to behavior/indices HDF5 (one array per subject key)
            subject: Subject ID (default: 'subj01')
            transform: Optional callable applied to the image tensor (e.g. CLIP preprocessing)
        """
        # Set the handle slots first so __del__/close are safe even if the
        # metadata read below raises.
        self._betas_file = None
        self._images_file = None

        # Store paths instead of file handles (for multiprocessing)
        self.betas_path = betas_file
        self.images_path = images_file
        self.transform = transform

        # Load metadata (image index for each trial) fully into memory
        with h5py.File(behav_file, 'r') as f:
            self.behav_indices = f[subject][:]

        print(f"Dataset initialized: {len(self.behav_indices)} trials")

    def _ensure_files_open(self):
        """Lazily open each HDF5 file that is not open yet (called from __getitem__)."""
        if self._betas_file is None:
            self._betas_file = h5py.File(self.betas_path, 'r')
        if self._images_file is None:
            self._images_file = h5py.File(self.images_path, 'r')

    def __getstate__(self):
        """Pickle without open HDF5 handles; each unpickled copy reopens lazily."""
        state = self.__dict__.copy()
        state['_betas_file'] = None
        state['_images_file'] = None
        return state

    def __len__(self):
        return len(self.behav_indices)

    def __getitem__(self, idx):
        """Return a dict with the trial's fMRI betas, its image, and both ids."""
        self._ensure_files_open()

        fmri_data = self._betas_file['betas'][idx]

        # behav_indices maps trial index -> row in the images file
        image_id = self.behav_indices[idx]
        image_data = self._images_file['images'][image_id]

        fmri_tensor = torch.tensor(fmri_data, dtype=torch.float32)
        image_tensor = torch.tensor(image_data, dtype=torch.float32)

        if self.transform:
            image_tensor = self.transform(image_tensor)

        return {
            'fmri': fmri_tensor,      # (15724,)
            'image': image_tensor,     # (3, 224, 224) in [0, 1]
            'image_id': image_id,      # COCO image ID (0-72999)
            'trial_idx': idx           # Trial index (0-29999 for subj01)
        }

    def close(self):
        """Close any open HDF5 handles; safe to call repeatedly."""
        for attr in ('_betas_file', '_images_file'):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def __del__(self):
        """Close HDF5 files when dataset is destroyed (tolerates partial init)."""
        self.close()

In [None]:
# ============================================================================
# SECTION 3: TRAIN/VAL SPLIT (NO DATA LEAKAGE)
# ============================================================================
def create_splits(dataset, train_ratio=0.8, seed=42):
    """
    Split trials into train/val at the IMAGE level to avoid data leakage from
    repetitions: every trial showing a given image lands on the same side.

    Args:
        dataset: object with ``__len__`` and a per-trial ``behav_indices``
            sequence holding the image id shown on each trial.
        train_ratio: fraction of unique images assigned to train (floored).
        seed: seed for shuffling the unique image ids.

    Returns:
        train_indices, val_indices: Lists of trial indices
    """
    # A private RandomState yields exactly the permutation that
    # np.random.seed(seed) + np.random.shuffle would, but without resetting
    # the global NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)

    # Group trials by image_id (dict order = first appearance, so deterministic)
    image_to_trials = defaultdict(list)
    for idx in range(len(dataset)):
        image_to_trials[dataset.behav_indices[idx]].append(idx)

    # Split at IMAGE level (not trial level!)
    unique_images = list(image_to_trials.keys())
    rng.shuffle(unique_images)

    n_train_images = int(len(unique_images) * train_ratio)
    train_images = set(unique_images[:n_train_images])
    val_images = set(unique_images[n_train_images:])

    train_indices = []
    val_indices = []
    for image_id, trial_list in image_to_trials.items():
        if image_id in train_images:
            train_indices.extend(trial_list)
        else:
            val_indices.extend(trial_list)

    # train_images is empty for tiny datasets or small train_ratio.
    reps_per_image = len(train_indices) / len(train_images) if train_images else 0.0

    print(f"\n=== Dataset Split ===")
    print(f"Train: {len(train_images)} images, {len(train_indices)} trials")
    print(f"Val:   {len(val_images)} images, {len(val_indices)} trials")
    print(f"Repetitions per image: ~{reps_per_image:.1f}")

    return train_indices, val_indices


In [None]:
# ============================================================================
# SECTION 4: CLIP EMBEDDINGS EXTRACTOR
# ============================================================================
class CLIPEmbeddingsExtractor:
    """Frozen CLIP vision tower that turns image tensors into unit-norm embeddings."""

    # Pixel statistics used by OpenAI CLIP's preprocessing. These are CLIP's
    # own values, not the ImageNet mean/std.
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, model_name="openai/clip-vit-large-patch14", device='cuda'):
        """
        Load CLIP model for extracting image embeddings

        Args:
            model_name: CLIP model variant (HuggingFace hub id)
            device: requested device; falls back to 'cpu' when CUDA is unavailable
        """
        print(f"\nLoading CLIP model: {model_name}")
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        # Not used by the methods below; preprocessing is done manually on tensors.
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

        print(f"✓ CLIP loaded on {self.device}")
        print(f"  Embedding dimension: {self.model.config.projection_dim}")

    @torch.no_grad()
    def extract_image_embeddings(self, images):
        """
        Extract L2-normalised CLIP image embeddings.

        Args:
            images: Tensor (B, 3, H, W) already normalised with CLIP's pixel
                statistics (see ``preprocess_images``).

        Returns:
            embeddings: Tensor (B, projection_dim), e.g. 768 for ViT-L/14, on
            ``self.device``. Rows are unit norm; an all-zero row stays zero.
        """
        images = images.to(self.device)
        image_features = self.model.get_image_features(pixel_values=images)
        # F.normalize clamps the norm at eps, so a zero row yields zeros rather
        # than NaN; for any non-degenerate row it equals x / ||x||.
        return F.normalize(image_features, dim=-1)

    def preprocess_images(self, images):
        """
        Standardise raw images with CLIP's per-channel mean/std.

        Args:
            images: Tensor (B, 3, H, W) with values in [0, 1]

        Returns:
            Tensor of the same shape, ready to feed to ``extract_image_embeddings``.
        """
        mean = torch.tensor(self.CLIP_MEAN, device=images.device).view(1, 3, 1, 1)
        std = torch.tensor(self.CLIP_STD, device=images.device).view(1, 3, 1, 1)
        return (images - mean) / std

In [None]:
# ============================================================================
# SECTION 5: CLIP EMBEDDING CACHE BUILDER
# ============================================================================
def build_clip_cache(dataset, clip_extractor, cache_path='clip_embeddings.pt', 
                     batch_size=64):
    """
    Pre-compute CLIP image embeddings for every trial of ``dataset`` and cache
    them on disk, so training only needs to read fMRI betas.

    Row i of the result corresponds to trial i of ``dataset`` (the loader is
    not shuffled). An existing cache is reused only when its row count matches
    ``len(dataset)``. Otherwise it is rebuilt, because the file name alone does
    not identify which dataset produced it.

    Args:
        dataset: MindEyeDataset instance (items are dicts with an 'image' tensor)
        clip_extractor: CLIPEmbeddingsExtractor instance
        cache_path: Where to save/load embeddings (parent dirs are created)
        batch_size: Batch size for processing

    Returns:
        embeddings: CPU tensor of shape (len(dataset), D), D = CLIP projection dim

    Raises:
        ValueError: if ``dataset`` is empty and no valid cache exists.
    """
    n_trials = len(dataset)

    if os.path.exists(cache_path):
        print(f"Loading cached CLIP embeddings from {cache_path}")
        # map_location keeps the cache loadable on CPU-only machines.
        cached = torch.load(cache_path, map_location='cpu')
        if cached.shape[0] == n_trials:
            return cached
        print(f"  Cache has {cached.shape[0]} rows but dataset has {n_trials} trials; rebuilding")

    if n_trials == 0:
        raise ValueError("Cannot build a CLIP embeddings cache for an empty dataset")

    print(f"\nBuilding CLIP embeddings cache...")
    print(f"This will take a few minutes but only needs to be done once!")

    dataloader = DataLoader(dataset, batch_size=batch_size, 
                            shuffle=False, num_workers=0)

    all_embeddings = []
    for batch in tqdm(dataloader, desc="Extracting CLIP embeddings"):
        # Move to the CLIP device first so normalisation runs there, not on CPU.
        images = batch['image'].to(clip_extractor.device)
        images = clip_extractor.preprocess_images(images)
        embeddings = clip_extractor.extract_image_embeddings(images)
        all_embeddings.append(embeddings.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0)

    Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(all_embeddings, cache_path)
    print(f"✓ Saved {all_embeddings.shape[0]} embeddings to {cache_path}")
    print(f"  Shape: {all_embeddings.shape}")

    return all_embeddings

In [None]:
# ============================================================================
# SECTION 6: DATASET WITH PRECOMPUTED CLIP EMBEDDINGS
# ============================================================================
class MindEyeWithCLIP(Dataset):
    """
    Dataset returning (fMRI, CLIP_embedding, image_id) triples.

    Much faster than computing CLIP on-the-fly. ``clip_embeddings[i]`` must
    correspond to ``base_dataset[i]``: slice the full-dataset embedding table
    with the same indices used to build the Subset.
    """
    def __init__(self, base_dataset, clip_embeddings):
        self.base_dataset = base_dataset
        self.clip_embeddings = clip_embeddings

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Fetch the base item once: each access reads the fMRI betas AND the
        # image from disk, so indexing it twice doubled the I/O per sample.
        item = self.base_dataset[idx]
        return (
            item['fmri'],               # (15724,)
            self.clip_embeddings[idx],  # (768,)
            item['image_id'],
        )

In [None]:
def symmetric_contrastive_loss(text_proj, image_emb, temperature=0.05):
    """CLIP-style InfoNCE loss averaged over both retrieval directions.

    Row i of ``text_proj`` and row i of ``image_emb`` form the positive pair;
    every other row of the batch acts as a negative. Both inputs are
    L2-normalised before the temperature-scaled similarity is computed.
    """
    queries = F.normalize(text_proj, dim=-1)
    keys = F.normalize(image_emb, dim=-1)
    logits = queries @ keys.T / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)
    forward_loss = F.cross_entropy(logits, targets)      # query -> image
    backward_loss = F.cross_entropy(logits.T, targets)   # image -> query
    return (forward_loss + backward_loss) / 2

def cosine_regression_loss(text_proj, image_emb):
    """Mean cosine distance ``1 - cos(a_i, b_i)`` between matching rows.

    Only the diagonal pairs are compared; minimising this pulls each
    projection towards its own target embedding.
    """
    per_pair_cos = torch.sum(
        F.normalize(text_proj, dim=-1) * F.normalize(image_emb, dim=-1),
        dim=-1,
    )
    return torch.mean(1 - per_pair_cos)
    
def triplet_loss(text_proj, image_emb, margin=0.4):
    """Hardest-negative triplet margin loss over the in-batch similarity matrix.

    For each query i, the positive is image i and the negative is the most
    similar other image in the batch; the loss is
    ``mean(relu(margin - pos + hardest_neg))``.
    """
    anchors = F.normalize(text_proj, dim=-1)
    candidates = F.normalize(image_emb, dim=-1)
    similarity = torch.matmul(anchors, candidates.T)
    positives = similarity.diagonal()
    on_diagonal = torch.eye(anchors.shape[0], device=similarity.device).bool()
    # A huge negative on the diagonal keeps the positive out of the max.
    hardest_negatives = similarity.masked_fill(on_diagonal, -1e9).max(dim=1).values
    return F.relu(margin - positives + hardest_negatives).mean()


def combined_loss(text_proj, image_emb, loss_arg):
    """Weighted blend of contrastive, triplet and cosine-regression losses.

    ``loss_arg`` keys: ``ALPHA`` (contrastive weight), optional ``BETA``
    (triplet weight, default 0.3), ``TEMPERATURE`` and ``MARGIN``. The
    regression term receives the remaining weight ``1 - ALPHA - BETA``
    (not validated to be non-negative).
    """
    w_contrastive = loss_arg["ALPHA"]
    w_triplet = loss_arg.get("BETA", 0.3)

    contrastive_term = symmetric_contrastive_loss(text_proj, image_emb, loss_arg["TEMPERATURE"])
    triplet_term = triplet_loss(text_proj, image_emb, margin=loss_arg["MARGIN"])
    regression_term = cosine_regression_loss(text_proj, image_emb)

    w_regression = 1 - w_contrastive - w_triplet
    return w_contrastive * contrastive_term + w_triplet * triplet_term + w_regression * regression_term

In [None]:
def compute_mrr_at_k_batched(text_proj, image_emb, k=100, batch_size=128):
    """Mean reciprocal rank @k where the correct match of query i is gallery row i.

    Similarities are cosine (both inputs are L2-normalised) and are computed
    in query chunks of ``batch_size`` to bound memory. A query whose match is
    not within the top ``k`` contributes 0.

    Args:
        text_proj: (N, D) query embeddings.
        image_emb: (M, D) gallery embeddings; row i is the match of query i.
        k: cut-off rank (clamped to the gallery size).
        batch_size: number of queries scored per chunk.

    Returns:
        float: MRR@k in [0, 1]; 0.0 when there are no queries.
    """
    text_proj = F.normalize(text_proj, dim=-1)
    image_emb = F.normalize(image_emb, dim=-1)

    n_queries = text_proj.shape[0]
    if n_queries == 0:
        return 0.0
    top = min(k, image_emb.shape[0])

    # 1/rank for positions 0..top-1; float64 to match Python-float accumulation.
    inv_rank = 1.0 / torch.arange(1, top + 1, device=text_proj.device, dtype=torch.float64)

    total = 0.0
    with torch.no_grad():
        for start in range(0, n_queries, batch_size):
            end = min(start + batch_size, n_queries)
            sims = torch.matmul(text_proj[start:end], image_emb.T)
            _, top_idx = torch.topk(sims, k=top, dim=1)
            truth = torch.arange(start, end, device=top_idx.device).unsqueeze(1)
            # Gallery indices in a top-k row are unique, so each row has at
            # most one hit and the row sum is exactly 1/rank (or 0).
            hits = (top_idx == truth).to(torch.float64)
            total += (hits * inv_rank).sum().item()

    return total / n_queries

In [None]:
def _train_one_epoch(model, loader, optimizer, loss_fn, loss_arg, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running = 0.0
    for X_batch, y_batch, _image_id in tqdm(loader, desc=desc):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(X_batch), y_batch, loss_arg)
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / len(loader)


@torch.no_grad()
def _validate(model, loader, loss_fn, loss_arg, device, collect=False):
    """Evaluate ``model`` on ``loader`` in a single pass.

    Returns ``(mean_loss, preds, targets)``. ``preds``/``targets`` are CPU
    tensors concatenated over the loader when ``collect`` is True, else None.
    """
    model.eval()
    running = 0.0
    preds, targets = [], []
    for X_batch, y_batch, _image_id in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        outputs = model(X_batch)
        running += loss_fn(outputs, y_batch, loss_arg).item()
        if collect:
            preds.append(outputs.cpu())
            targets.append(y_batch.cpu())
    mean_loss = running / len(loader)
    if not collect:
        return mean_loss, None, None
    return mean_loss, torch.cat(preds, dim=0), torch.cat(targets, dim=0)


def train_model(augmentation_par,train_par, loss_par,model, train_loader, val_loader, device, model_path,augmenter=None):
    """Train ``model`` with AdamW + linear warmup/cosine LR, checkpointing on val MRR@100.

    Every 2nd epoch (and the last one) the model is scored on ``val_loader``
    with MRR@100. The state dict is saved to ``model_path`` whenever that MRR
    improves. Training stops early after ``patience`` evaluations without
    improvement.

    Args:
        augmentation_par: augmentation config (currently unused by this loop).
        train_par: dict with ``LR``, ``WEIGHT_DEC``, ``WARMUP`` (epochs), ``EPOCHS``.
        loss_par: dict with ``FUNC`` (callable ``(pred, target, arg) -> loss``) and ``ARG``.
        model: module mapping fMRI batches to embedding batches.
        train_loader / val_loader: yield ``(fmri, clip_embedding, image_id)`` triples.
        device: device the batches are moved to (should match the model's).
        model_path: checkpoint path for the best state dict.
        augmenter: unused; kept for interface compatibility.

    Returns:
        The model, with the best checkpoint's weights restored if one was saved.
    """
    lr = train_par["LR"]
    weight_decay = train_par["WEIGHT_DEC"]
    warmup_epochs = train_par["WARMUP"]
    epochs = train_par["EPOCHS"]
    loss_fn = loss_par["FUNC"]
    loss_arg = loss_par["ARG"]

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler_warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
    scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=1e-7)

    best_mrr = -1.0          # below any valid MRR, so the first evaluation always checkpoints
    saved_checkpoint = False
    patience_counter = 0
    patience = 5             # counted in evaluations, not epochs

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn, loss_arg,
                                      device, desc=f"Epoch {epoch+1}/{epochs}")

        evaluate_now = (epoch + 1) % 2 == 0 or epoch == epochs - 1
        val_loss, all_preds, all_targets = _validate(model, val_loader, loss_fn, loss_arg,
                                                     device, collect=evaluate_now)

        if evaluate_now:
            mrr_100 = compute_mrr_at_k_batched(all_preds.to(device), all_targets.to(device), k=100, batch_size=128)
            print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}, MRR@100={mrr_100:.4f}",
                  evaluate_retrieval(all_preds, all_targets, np.arange(len(all_preds))))

            # Checkpoint/early-stop on the retrieval metric, not on train-vs-val loss.
            if mrr_100 > best_mrr:
                best_mrr = mrr_100
                patience_counter = 0
                Path(model_path).parent.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), model_path)
                saved_checkpoint = True
                print(f"  New best: {mrr_100:.4f}")
            else:
                patience_counter += 1

            del all_preds, all_targets
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}")

        if epoch < warmup_epochs:
            scheduler_warmup.step()
        else:
            scheduler_cosine.step()

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if saved_checkpoint:
        model.load_state_dict(torch.load(model_path, map_location=device))
    return model

In [None]:
def create_model(model_par):
    """Build a model via the factory stored under ``model_par["CREATE_MODEL"]``.

    The factory receives the whole ``model_par`` dict. The total parameter
    count is printed as a sanity check before the model is returned.
    """
    print("\n2. Building model...")

    factory = model_par["CREATE_MODEL"]
    model = factory(model_par)

    n_params = sum(param.numel() for param in model.parameters())
    print(f"Parameters: {n_params:,}")
    return model

def create_Enc(model_par):
    """Instantiate an ``EncoderMLP`` from a config dict and move it to ``DEVICE``.

    Reads ``INPUT_DIM``, ``OUTPUT_DIM``, ``HIDDEN_DIMS`` and ``DROPOUT``.
    """
    encoder = EncoderMLP(
        input_dim=model_par["INPUT_DIM"],
        output_dim=model_par["OUTPUT_DIM"],
        hidden_dims=model_par["HIDDEN_DIMS"],
        dropout=model_par["DROPOUT"],
    )
    return encoder.to(DEVICE)

class EncoderMLP(nn.Module):
    """Residual MLP mapping a flat input vector (here fMRI betas) to the target embedding space.

    Main path: a stack of Linear -> BatchNorm1d -> GELU -> Dropout blocks, then a
    Linear + BatchNorm1d head. Its output is L2-normalised. A linear skip from
    the input, also L2-normalised, is added after scaling by the learnable
    ``skip_weight`` (initialised to 0.1). The sum is NOT re-normalised.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, dropout):
        super().__init__()
        print(f"Creating {type(self).__name__}")
        layers = []
        prev_dim = input_dim
        last_hidden = len(hidden_dims) - 1

        for i, hidden_dim in enumerate(hidden_dims):
            # The block feeding the output head gets half the dropout rate.
            p = dropout if i < last_hidden else dropout * 0.5
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(p),
            ])
            prev_dim = hidden_dim

        layers.extend([
            nn.Linear(prev_dim, output_dim),
            nn.BatchNorm1d(output_dim),
        ])

        self.network = nn.Sequential(*layers)
        self.skip = nn.Linear(input_dim, output_dim)
        self.skip_weight = nn.Parameter(torch.tensor(0.1))

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to (batch, output_dim)."""
        out = F.normalize(self.network(x), dim=-1)
        skip = F.normalize(self.skip(x), dim=-1)
        return out + self.skip_weight * skip


In [None]:
# ---- Encoder: 15724 fMRI betas -> 768-d CLIP image embedding ----
MLP_ENC_PAR = dict(
    NAME="ModelEncoder",
    HIDDEN_DIMS=[12000, 10000, 8192, 4096, 2048],
    INPUT_DIM=15724,
    OUTPUT_DIM=768,
    DROPOUT=0.2,
    CREATE_MODEL=create_Enc,
)

# ---- Loss: ALPHA*contrastive + BETA*triplet + (1-ALPHA-BETA)*cosine regression ----
ENC_LOSS_PAR = dict(
    NAME="COMB_TRIPLET_AND_CONTR",
    FUNC=combined_loss,
    ARG=dict(TEMPERATURE=0.7, ALPHA=0.4, MARGIN=0.5, BETA=0.3),
)

# ---- Training ----
ENC_TRAINING_PAR = dict(LR=0.0005, WARMUP=5, EPOCHS=50, WEIGHT_DEC=0.01, FLOW=False)

# ---- Dataset ----
# NOTE(review): BATCH_SIZE / TRAIN_SIZE are not read by the visible main cell,
# which hard-codes its own batch size and train ratio — confirm intent.
ENC_DATASET_PAR = dict(BATCH_SIZE=4096 * 2, TRAIN_SIZE=0.9)

# ---- Files / runtime ----
# NOTE(review): DATASET_PATH is not referenced in the visible cells.
DATASET_PATH = "/kaggle/input/d/niccolosici/aml-dataset/train/train.npz"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
AUGMENTATION_PAR = dict(USE_AUGMENTATION=False)

In [None]:
# ============================================================================
# SECTION 7: MAIN EXECUTION
# ============================================================================
if __name__ == "__main__":

    # ===== STEP 1: Download Data =====
    betas_path, images_path, behav_path = get_dataset_paths()

    # ===== STEP 2: Create Dataset =====
    dataset = MindEyeDataset(
        betas_file=betas_path,
        images_file=images_path,
        behav_file=behav_path
    )

    # ===== STEP 3: Create Train/Val Split (image-level, no leakage) =====
    train_indices, val_indices = create_splits(dataset, train_ratio=0.8)

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    # ===== STEP 4: Load CLIP =====
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    clip_extractor = CLIPEmbeddingsExtractor(device=device)

    # ===== STEP 5: Build CLIP Embeddings Cache =====
    print("\n" + "="*60)
    print("BUILDING CLIP EMBEDDINGS CACHE")
    print("="*60)

    clip_embeddings = build_clip_cache(
        dataset=dataset,
        clip_extractor=clip_extractor,
        cache_path='clip_embeddings_subj01.pt',
        batch_size=64
    )

    # ===== STEP 6: Create Final Datasets =====
    # The cache rows follow full-dataset trial order, so slice them with the
    # same index lists as the Subsets to keep (fMRI, CLIP) pairs aligned.
    train_dataset_clip = MindEyeWithCLIP(train_dataset, clip_embeddings[train_indices])
    val_dataset_clip = MindEyeWithCLIP(val_dataset, clip_embeddings[val_indices])

    # ===== STEP 7: DataLoaders =====
    print("\n" + "="*60)
    print("TESTING DATALOADERS")
    print("="*60)

    # drop_last: a trailing batch of size 1 would crash BatchNorm1d in train
    # mode and leave the in-batch contrastive/triplet losses with no negatives.
    train_loader = DataLoader(train_dataset_clip, batch_size=32,
                              shuffle=True, num_workers=0, drop_last=True)
    val_loader = DataLoader(val_dataset_clip, batch_size=32,
                            shuffle=False, num_workers=0)

    # ===== STEP 8: Build and Train Encoder =====
    encoder_model = create_model(MLP_ENC_PAR)
    encoder_model = train_model(AUGMENTATION_PAR, ENC_TRAINING_PAR, ENC_LOSS_PAR, encoder_model,
                                train_loader, val_loader, DEVICE, "./encoderF2C.pth")
    del encoder_model