In [3]:
import os
import warnings
from collections import defaultdict

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

In [163]:
def count_total_files(features_dir, movie_list):
    """Count the audio feature files that belong to the requested movies.

    The ``audio`` sub-directory of ``features_dir`` is listed once, and every
    file whose name contains ``'_features_'`` is matched against each movie.

    Args:
        features_dir: Directory that contains the ``audio`` sub-directory.
            A trailing path separator is optional.
        movie_list: Iterable of movie identifiers. Identifiers containing
            ``'friends'`` (e.g. ``'friends-s01'``) match file names that
            contain ``'<season>e'``; any other identifier matches on its name
            with the ``'movie10-'`` prefix removed.

    Returns:
        int: Total number of matching files, summed over ``movie_list``. A
        file that matches several movies is counted once per movie.
    """
    # os.path.join works with or without a trailing separator on
    # features_dir, unlike plain string concatenation.
    audio_dir = os.path.join(features_dir, 'audio')
    # Hoisted out of the loop: the listing does not depend on the movie.
    feature_files = [f for f in os.listdir(audio_dir) if '_features_' in f]

    total_files = 0
    for movie in movie_list:
        if 'friends' in movie:
            pattern = f"{movie.split('-')[1]}e"
        else:
            pattern = movie.replace('movie10-', '')
        total_files += sum(1 for f in feature_files if pattern in f)

    return total_files

def get_fmri_files(subject, task_type, fmri_dir):
    """Return the path of the single fMRI file for a subject and task.

    Looks inside ``<fmri_dir>/sub-<subject>/func/`` for entries whose name
    contains ``sub-<subject>_task-<task_type>``.

    Raises:
        ValueError: If the number of matching entries is not exactly one.
    """
    func_dir = os.path.join(fmri_dir, f'sub-{subject}/func/')
    wanted = f'sub-{subject}_task-{task_type}'

    matches = []
    for entry in os.listdir(func_dir):
        if wanted in entry:
            matches.append(entry)

    if len(matches) == 1:
        return os.path.join(func_dir, matches[0])
    raise ValueError(f"Expected 1 fMRI file for subject {subject} and task {task_type}, found {len(matches)}")

def load_fmri(root_data_dir, subject):
    """Load one subject's Friends and Movie10 fMRI runs into a dictionary.

    Two HDF5 files are read from ``<root_data_dir>/sub-0<subject>/func/``.
    Every dataset in them is copied into memory as ``float32`` and stored
    under its name with the first 13 characters removed (presumably a fixed
    ``ses-XXX_task-`` style prefix; verify against the files).

    The ``figures`` (12 splits) and ``life`` (5 splits) movies are then
    collapsed: for each split, the first two runs whose key contains
    ``'<movie><NN>'`` are averaged, stored under ``'<movie><NN>'``, and the two
    source runs are removed.

    Args:
        root_data_dir: Directory containing the ``sub-0<subject>`` folders.
        subject: Subject identifier without the leading zero (e.g. ``'1'``).

    Returns:
        dict[str, np.ndarray]: Mapping from run key to ``float32`` array.

    Raises:
        ValueError: If fewer than two runs exist for a ``figures``/``life``
            split (previously surfaced as a bare ``IndexError``).
    """
    def _read_runs(fmri_file):
        # Context manager guarantees the HDF5 handle is closed, even on error.
        path = os.path.join(root_data_dir, f'sub-0{subject}', 'func', fmri_file)
        with h5py.File(path, 'r') as h5_file:
            return {
                str(key[13:]): val[:].astype(np.float32)
                for key, val in h5_file.items()
            }

    fmri = {}
    fmri.update(_read_runs(
        f'sub-0{subject}_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5'
    ))
    fmri.update(_read_runs(
        f'sub-0{subject}_task-movie10_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_bold.h5'
    ))

    # Movies whose splits are stored as repeated runs, with their split count.
    for prefix, n_splits in (('figures', 12), ('life', 5)):
        for split in range(1, n_splits + 1):
            movie = f'{prefix}{split:02d}'
            # Snapshot the matching keys before the dictionary is modified.
            repeats = [key for key in fmri if movie in key]
            if len(repeats) < 2:
                raise ValueError(
                    f"Expected 2 repeats of '{movie}' for subject {subject}, found {len(repeats)}"
                )
            first, second = repeats[:2]
            fmri[movie] = ((fmri.pop(first) + fmri.pop(second)) / 2).astype(np.float32)

    return fmri

class AlgonautsDataset(Dataset):
    """In-memory dataset pairing stimulus features with fMRI samples.

    For every requested movie, the audio, visual and language feature files
    found under ``features_dir`` are loaded. Each retained fMRI sample is
    paired with a window of ``stimulus_window`` audio/visual feature rows and
    one language feature row.

    When several subjects are given, the samples of all subjects are pooled.
    The stimulus features of a run are repeated once per subject that has fMRI
    data for it. Runs without fMRI data for a subject, or with a missing
    feature entry, are skipped with a warning.

    Args:
        features_dir: Directory with ``audio``/``visual``/``language``
            sub-directories of ``*_features_<modality>.h5`` files.
        fmri_dir: Root directory passed to ``load_fmri``.
        movies: Movie identifiers such as ``'friends-s01'`` or
            ``'movie10-bourne'``.
        subjects: Non-empty sequence of subject identifiers for ``load_fmri``.
        excluded_samples_start: Number of leading fMRI samples to drop.
        excluded_samples_end: Number of trailing fMRI samples to drop
            (``0`` keeps the tail).
        hrf_delay: Offset, in samples, between a stimulus row and the fMRI
            sample it is paired with.
        stimulus_window: Number of consecutive audio/visual rows per sample.

    Attributes:
        audio, video, language, fmri: Stacked tensors, one row per sample.
        partition_indices: ``movie -> (start, end)`` half-open index range.

    Raises:
        ValueError: If ``subjects`` is empty.
        AssertionError: If a requested movie contributed no samples.
    """

    MODALITIES = ('audio', 'visual', 'language')

    def __init__(self, features_dir, fmri_dir, movies, subjects, excluded_samples_start=5, excluded_samples_end=5, hrf_delay=3, stimulus_window=5):
        if not subjects:
            raise ValueError("AlgonautsDataset needs at least one subject")

        self.features_dir = features_dir
        self.fmri_dir = fmri_dir
        self.movies = movies
        self.subjects = subjects
        self.excluded_samples_start = excluded_samples_start
        self.excluded_samples_end = excluded_samples_end
        self.hrf_delay = hrf_delay
        self.stimulus_window = stimulus_window
        self.partition_indices = defaultdict(list)

        # Keep every subject's runs; the previous loop overwrote the dict so
        # only the last subject was ever used.
        fmri_by_subject = {
            subject: load_fmri(self.fmri_dir, subject) for subject in self.subjects
        }

        # The audio directory is used as the index of available runs.
        audio_dir = os.path.join(self.features_dir, 'audio')
        feature_files = sorted(f for f in os.listdir(audio_dir) if '_features_' in f)

        audio_data, video_data, language_data, fmri_data = [], [], [], []
        current_idx = 0

        total_files = count_total_files(self.features_dir, self.movies)
        with tqdm(desc='Loading dataset', total=total_files) as pbar:
            for movie in self.movies:
                start_idx = current_idx
                is_friends = 'friends' in movie
                if is_friends:
                    pattern = f"{movie.split('-')[1]}e"
                else:
                    pattern = movie.replace('movie10-', '')

                for file_name in feature_files:
                    if pattern not in file_name:
                        continue
                    file_base = file_name.split('_features_')[0]
                    # Friends runs are keyed by the part after the first '_';
                    # other movies are keyed by the full file base.
                    run_key = file_base.split('_')[1] if is_friends else file_base

                    features = self._load_features(file_base, run_key)
                    if features is not None:
                        for subject in self.subjects:
                            fmri = fmri_by_subject[subject].get(run_key)
                            if fmri is None:
                                # Never fall back to another run's fMRI.
                                warnings.warn(
                                    f"No fMRI data for run '{run_key}' of subject {subject}; skipping"
                                )
                                continue
                            audio, video, language, targets = self._align_samples(
                                features,
                                fmri,
                                excluded_samples_start,
                                excluded_samples_end,
                                hrf_delay,
                                stimulus_window,
                            )
                            audio_data.extend(audio)
                            video_data.extend(video)
                            language_data.extend(language)
                            fmri_data.extend(targets)
                            current_idx += len(targets)

                    pbar.update(1)

                self.partition_indices[movie] = (start_idx, current_idx)

        # Convert to tensors
        self.audio = torch.from_numpy(np.stack(audio_data))
        self.video = torch.from_numpy(np.stack(video_data))
        self.language = torch.from_numpy(np.stack(language_data))
        self.fmri = torch.from_numpy(np.stack(fmri_data))

        # Every requested movie must have contributed at least one sample.
        for movie in self.movies:
            start, end = self.partition_indices[movie]
            assert start < end, f"Invalid index range for {movie}"

    def _load_features(self, file_base, group_key):
        """Read all three modalities of one run; return None if any is missing."""
        features = {}
        for modality in self.MODALITIES:
            path = os.path.join(self.features_dir, modality, f"{file_base}_features_{modality}.h5")
            dataset_key = 'language_pooler_output' if modality == 'language' else modality
            with h5py.File(path, 'r') as f:
                try:
                    features[modality] = f[group_key][dataset_key][:]
                except KeyError:
                    # Report what the file contains instead of swallowing
                    # every exception with a bare except.
                    names = []
                    f.visit(names.append)
                    warnings.warn(
                        f"'{group_key}/{dataset_key}' not found in {path}; "
                        f"available entries: {names}; skipping run"
                    )
                    return None
        return features

    @staticmethod
    def _align_samples(features, fmri, excluded_samples_start, excluded_samples_end, hrf_delay, stimulus_window):
        """Pair each retained fMRI sample with its stimulus features.

        Returns four equally long lists: audio windows, visual windows,
        language rows and fMRI targets.
        """
        audio, video, language, targets = [], [], [], []

        # Computed from the length so that excluded_samples_end == 0 keeps the
        # tail; the slice fmri[start:-0] would be empty.
        n_valid = max(len(fmri) - excluded_samples_start - excluded_samples_end, 0)
        n_language = len(features['language'])

        for s in range(n_valid):
            # The earliest samples all reuse the first retained window.
            if s < (stimulus_window + hrf_delay):
                win_start = excluded_samples_start
            else:
                win_start = s + excluded_samples_start - hrf_delay - stimulus_window + 1
            win_end = win_start + stimulus_window

            windows = {}
            for mod in ('audio', 'visual'):
                idx_start, idx_end = win_start, win_end
                # Features shorter than the fMRI run: reuse the last window.
                if idx_end > len(features[mod]):
                    idx_end = len(features[mod])
                    idx_start = idx_end - stimulus_window
                windows[mod] = features[mod][idx_start:idx_end]

            # Clamp at 0 so hrf_delay > excluded_samples_start cannot wrap
            # around to the end of the run, and at the last row when the
            # language features are shorter than the fMRI run.
            lang_idx = max(s + excluded_samples_start - hrf_delay, 0)
            lang_idx = min(lang_idx, n_language - 1)

            audio.append(windows['audio'])
            video.append(windows['visual'])
            language.append(features['language'][lang_idx])
            targets.append(fmri[s + excluded_samples_start])

        return audio, video, language, targets

    def __len__(self):
        """Return the number of aligned samples."""
        return len(self.audio)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors."""
        return {
            'audio': self.audio[idx],
            'video': self.video[idx],
            'language': self.language[idx],
            'fmri': self.fmri[idx]
        }

    def get_partition_indices(self):
        """Return the ``movie -> (start, end)`` sample index ranges."""
        return self.partition_indices


In [177]:
# Dataset locations. NOTE(review): these are absolute paths on one machine;
# they must be edited before the notebook can run anywhere else.
features_dir = '/home/pranav/mihir/algonauts_challenge/AlgonautsDS-features/developer_kit/stimulus_features/raw/'
fmri_dir = '/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/'
# Movie identifiers handed to AlgonautsDataset for the two splits.
movies_train = ["friends-s01", "friends-s02", "friends-s03", "friends-s04", "friends-s05", "movie10-bourne", "movie10-figures", "movie10-life", "movie10-wolf"]
movies_val = ["friends-s06"]
# NOTE(review): `modality` is not read by any code visible in this notebook.
modality = "all"  #@param ["visual", "audio", "language", "all"]

# Alignment settings forwarded unchanged to both AlgonautsDataset instances.
excluded_samples_start = 5  #@param {type:"slider", min:0, max:20, step:1}
excluded_samples_end = 5  #@param {type:"slider", min:0, max:20, step:1}
hrf_delay = 3  #@param {type:"slider", min:0, max:10, step:1}
stimulus_window = 1

# NOTE(review): the validation split uses a different subject ('2') than the
# training split ('1') — confirm this is intended.
train_subjects = ['1'] #@param ["1", "2", "3", "5"] {type:"raw", allow-input: true}
test_subjects = ['2']

# Build both datasets eagerly; all features and fMRI samples are held in memory.
train_ds = AlgonautsDataset(features_dir, fmri_dir, movies=movies_train, subjects=train_subjects, excluded_samples_start=excluded_samples_start, excluded_samples_end=excluded_samples_end, hrf_delay=hrf_delay, stimulus_window=stimulus_window)
val_ds = AlgonautsDataset(features_dir, fmri_dir, movies=movies_val, subjects=test_subjects, excluded_samples_start=excluded_samples_start, excluded_samples_end=excluded_samples_end, hrf_delay=hrf_delay, stimulus_window=stimulus_window)

Loading dataset: 100%|██████████| 286/286 [00:00<00:00, 323.70it/s]
Loading dataset: 100%|██████████| 50/50 [00:00<00:00, 321.29it/s]


In [179]:
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from vector_quantize_pytorch import VectorQuantize
from torchmetrics.functional import pearson_corrcoef

class ResidualBlock(nn.Module):
    """Two Linear -> LayerNorm -> GELU stages with a skip connection.

    When ``in_dim`` differs from ``out_dim`` the skip path goes through a
    linear projection (``proj``); otherwise the input is added unchanged.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.downsample = in_dim != out_dim

        # The main path is built before the projection so that sub-module
        # names and creation order stay as they were.
        stages = [
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
        ]
        self.net = nn.Sequential(*stages)

        if self.downsample:
            self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return ``skip(x) + net(x)``."""
        skip = self.proj(x) if self.downsample else x
        return skip + self.net(x)

class Encoder(nn.Module):
    """Stack of residual blocks that maps an input vector to a token grid.

    The input passes through ``input_proj``, one residual block per
    consecutive pair in ``hidden_dims``, and ``token_proj``. The result is
    viewed as ``[batch, num_tokens, codebook_dim]``.
    """

    def __init__(self, input_dim=1000, hidden_dims=[512, 384, 256], num_tokens=32, codebook_dim=64):
        super().__init__()

        # input_dim -> first hidden width
        self.input_proj = ResidualBlock(input_dim, hidden_dims[0])

        # one residual block per consecutive pair of hidden widths
        width_pairs = zip(hidden_dims[:-1], hidden_dims[1:])
        self.layers = nn.Sequential(
            *[ResidualBlock(d_in, d_out) for d_in, d_out in width_pairs]
        )

        # last hidden width -> flattened token grid
        self.token_proj = ResidualBlock(hidden_dims[-1], num_tokens * codebook_dim)

        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim

    def forward(self, x):
        """Encode ``x`` and return ``[batch, num_tokens, codebook_dim]``."""
        hidden = self.token_proj(self.layers(self.input_proj(x)))
        return hidden.view(hidden.shape[0], self.num_tokens, self.codebook_dim)

class Decoder(nn.Module):
    """Stack of residual blocks that maps a token grid back to a vector.

    The token grid is flattened, passed through ``token_proj``, one residual
    block per consecutive pair in ``hidden_dims``, and ``output_proj``.
    """

    def __init__(self, output_dim=1000, hidden_dims=[256, 384, 512], num_tokens=32, codebook_dim=64):
        super().__init__()

        # flattened token grid -> first hidden width
        self.token_proj = ResidualBlock(num_tokens * codebook_dim, hidden_dims[0])

        # one residual block per consecutive pair of hidden widths
        width_pairs = zip(hidden_dims[:-1], hidden_dims[1:])
        self.layers = nn.Sequential(
            *[ResidualBlock(d_in, d_out) for d_in, d_out in width_pairs]
        )

        # last hidden width -> output_dim
        self.output_proj = ResidualBlock(hidden_dims[-1], output_dim)

    def forward(self, x):
        """Decode ``x`` of shape ``[batch, num_tokens, codebook_dim]``."""
        flat_tokens = x.reshape(x.shape[0], -1)
        hidden = self.layers(self.token_proj(flat_tokens))
        return self.output_proj(hidden)

class VQVAE(L.LightningModule):
    """Vector-quantised autoencoder built from ``Encoder``, ``VectorQuantize`` and ``Decoder``.

    The decoder mirrors the encoder by receiving ``hidden_dims`` reversed.
    All constructor arguments are recorded through ``save_hyperparameters``.
    """

    def __init__(self, input_dim=1000, hidden_dims=[512, 384, 256], num_tokens=32,
                 codebook_size=1024, codebook_dim=8, commitment_weight=0.25,
                 quantizer_decay=0.99, learning_rate=3e-4, weight_decay=0.01):
        super().__init__()

        self.encoder = Encoder(input_dim, hidden_dims, num_tokens, codebook_dim)
        self.decoder = Decoder(input_dim, hidden_dims[::-1], num_tokens, codebook_dim)
        self.quantizer = VectorQuantize(
            dim=codebook_dim,
            codebook_size=codebook_size,
            decay=quantizer_decay,
            commitment_weight=commitment_weight,
        )

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def forward(self, x):
        """Return ``(reconstruction, commitment_loss, indices)`` for ``x``."""
        quantized, indices, commitment_loss = self.quantizer(self.encoder(x))
        return self.decoder(quantized), commitment_loss, indices

    def encode(self, x):
        """Return only the code indices the quantizer assigns to ``x``."""
        quantizer_out = self.quantizer(self.encoder(x))
        return quantizer_out[1]

    def decode(self, indices):
        """Look up the codes for ``indices`` and run them through the decoder."""
        return self.decoder(self.quantizer.get_codes_from_indices(indices))
    
class MultiModalMLP(L.LightningModule):
    """MLP that predicts fMRI code indices from video, audio and text features.

    A frozen ``VQVAE`` checkpoint turns each fMRI sample into ``num_tokens``
    code indices. The MLP is trained with cross-entropy to predict, for every
    token position, which of the tokenizer's ``codebook_size`` codes was
    assigned.

    Args:
        video_dim: Width of the flattened video features of one sample
            (window length times per-row width).
        audio_dim: Width of the flattened audio features of one sample.
        text_dim: Width of the text feature vector.
        hidden_dim: Width of the two hidden layers.
        num_tokens: Tokens per fMRI sample; must equal the tokenizer's value.
        fmri_tok_dir: Path of the ``VQVAE`` checkpoint to load.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Raises:
        ValueError: If ``num_tokens`` disagrees with the loaded tokenizer.
    """

    def __init__(self, video_dim, audio_dim, text_dim, hidden_dim, num_tokens, fmri_tok_dir, learning_rate, weight_decay):
        super().__init__()

        # Dimensions for each modality
        self.video_dim = video_dim
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.num_tokens = num_tokens

        # Frozen fMRI tokenizer: no gradients, kept in eval mode (see train()).
        self.fmri_tokenizer = VQVAE.load_from_checkpoint(fmri_tok_dir)
        for param in self.fmri_tokenizer.parameters():
            param.requires_grad = False
        self.fmri_tokenizer.eval()

        if self.fmri_tokenizer.encoder.num_tokens != num_tokens:
            raise ValueError(
                f"num_tokens={num_tokens} does not match the tokenizer's "
                f"{self.fmri_tokenizer.encoder.num_tokens} tokens"
            )
        # Number of classes per token, taken from the tokenizer's saved
        # hyper-parameters (VQVAE calls save_hyperparameters()).
        self.codebook_size = self.fmri_tokenizer.hparams.codebook_size

        # Trainable token for missing text
        self.missing_text_token = nn.Parameter(torch.randn(1, text_dim))

        # MLP layers. The head emits one logit per (token, code) pair: a head
        # of width num_tokens cannot be scored against code indices.
        total_dim = video_dim + audio_dim + text_dim
        self.mlp = nn.Sequential(
            nn.Linear(total_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_tokens * self.codebook_size)
        )
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen tokenizer in eval.

        ``requires_grad=False`` only stops gradient updates. The tokenizer is
        kept in eval mode so that any train-mode behaviour of its quantizer
        (the EMA ``decay`` codebook update — verify against
        vector_quantize_pytorch) stays disabled when something calls
        ``.train()`` on this module.
        """
        super().train(mode)
        self.fmri_tokenizer.eval()
        return self

    def forward(self, video, audio, text):
        """Predict code logits for every token position.

        Args:
            video: ``[batch_size, video_dim]`` or ``[batch_size, window, dim]``
                with ``window * dim == video_dim``.
            audio: ``[batch_size, audio_dim]`` or ``[batch_size, window, dim]``
                with ``window * dim == audio_dim``.
            text: ``[batch_size, text_dim]``; rows containing NaN are treated
                as missing.

        Returns:
            Tensor of shape ``[batch_size, num_tokens, codebook_size]``.
        """
        batch_size = video.shape[0]

        # The dataset yields windowed [batch, window, dim] tensors; flatten
        # them so they can be concatenated with the 2-D text features.
        video = video.reshape(batch_size, -1)
        audio = audio.reshape(batch_size, -1)

        # Generate text mask based on NaN values
        text_mask = ~torch.isnan(text).any(dim=1, keepdim=True)

        # Replace missing text with learned token
        text = torch.where(
            text_mask,
            text,
            self.missing_text_token.expand(batch_size, -1)
        )

        # Concatenate all features
        x = torch.cat([video, audio, text], dim=1)

        # Predict tokens
        logits = self.mlp(x)
        return logits.view(batch_size, self.num_tokens, self.codebook_size)

    def calculate_metrics(self, x, x_recon):
        """Return ``(mean per-sample Pearson r, variance explained, MSE, MAE)``."""
        # Flatten the tensors for correlation calculation
        x_flat = x.reshape(x.shape[0], -1)
        x_recon_flat = x_recon.reshape(x_recon.shape[0], -1)

        # Calculate Pearson R for each sample in batch
        correlations = torch.stack([
            pearson_corrcoef(x_flat[i], x_recon_flat[i])
            for i in range(x_flat.shape[0])
        ])
        avg_pearson_r = correlations.mean()

        # Calculate variance explained
        total_variance = torch.var(x_flat, dim=1).sum()
        residual_variance = torch.var(x_flat - x_recon_flat, dim=1).sum()
        variance_explained = 1 - (residual_variance / total_variance)

        # Calculate MSE and MAE
        mse = F.mse_loss(x_flat, x_recon_flat)
        mae = F.l1_loss(x_flat, x_recon_flat)

        return avg_pearson_r, variance_explained, mse, mae

    def _shared_step(self, batch, stage):
        """Compute the loss and log every metric under the ``stage`` prefix."""
        video, audio, text, fmri = batch['video'], batch['audio'], batch['language'], batch['fmri']
        logits = self(video, audio, text)

        with torch.no_grad():
            fmri_tokens = self.fmri_tokenizer.encode(fmri)

        # cross_entropy expects [N, C] logits and [N] class indices, so every
        # token position is scored as its own classification problem.
        loss = F.cross_entropy(
            logits.reshape(-1, self.codebook_size),
            fmri_tokens.reshape(-1)
        )

        with torch.no_grad():
            # decode() takes code indices, not logits.
            recon_fmri = self.fmri_tokenizer.decode(logits.argmax(dim=-1))
            avg_pearson_r, variance_explained, mse, mae = self.calculate_metrics(fmri, recon_fmri)

        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_pearson_r', avg_pearson_r)
        self.log(f'{stage}_variance_explained', variance_explained)
        self.log(f'{stage}_mse', mse)
        self.log(f'{stage}_mae', mae)
        return loss

    def training_step(self, batch, batch_idx=None):
        """Run one training batch and return the cross-entropy loss."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx=None):
        """Run one validation batch and return the cross-entropy loss."""
        return self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """Return AdamW over the trainable parameters with cosine annealing."""
        # Leave the frozen tokenizer out of the optimizer.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.weight_decay)
        # NOTE(review): T_max is fixed at 100 and is independent of the
        # Trainer's max_epochs — confirm this is intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

In [172]:
# Example usage:
# Hyper-parameters for MultiModalMLP and the DataLoaders that feed it.
# NOTE(review): this cell reads train_ds / val_ds, so it has to run after the
# dataset cell even though its recorded execution count is lower.

batch_size = 4
# Input widths passed to MultiModalMLP; assumed to match the flattened
# per-sample feature sizes in the datasets — TODO confirm.
video_dim = 8192
audio_dim = 128
text_dim = 768
hidden_dim = 1024
num_tokens = 32
learning_rate = 1e-4
weight_decay = 0.01 
# Checkpoint of the VQVAE used as the frozen fMRI tokenizer.
# NOTE(review): absolute path on one machine.
fmri_tok_dir = "/home/pranav/mihir/algonauts_challenge/algonauts2025/checkpoints/fmri_vqvae_res_gelu_8qdim_1024/epoch=44_val_loss=0.137.ckpt"

# Only the training loader shuffles.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)

Output shape: torch.Size([32, 32])


In [180]:
# Report how many aligned samples each split holds.
print("Training samples: ", len(train_ds))
print("Validation samples: ", len(val_ds))

Training samples:  129516
Validation samples:  22924


In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# run_name labels both the W&B run and the checkpoint directory.
run_name = "MM_MLP_1024dim"
wandb_logger = WandbLogger(
    project="algonauts2025",
    name=run_name,
    save_dir="wandb_logs/"
)
# Train for up to 50 epochs in 32-bit precision. The checkpoint callback
# monitors 'val_loss' (logged by MultiModalMLP.validation_step) in 'min' mode
# with save_top_k=1 and save_last=True.
trainer = L.Trainer(
    max_epochs=50,
    precision=32,
    logger=wandb_logger,
    callbacks=[
        ModelCheckpoint(
            dirpath=f'checkpoints/{run_name}',
            filename='{epoch:02d}_{val_loss:.3f}',
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            save_last=True
        )
    ]
)

In [None]:
from lightning.pytorch.utilities.model_summary import ModelSummary
# Request PyTorch's 'high' float32 matmul precision setting.
torch.set_float32_matmul_precision('high')
# Build the model from the hyper-parameters defined in the example-usage cell;
# this loads the VQVAE checkpoint at fmri_tok_dir.
mm_mlp = MultiModalMLP(
    video_dim=video_dim,
    audio_dim=audio_dim,
    text_dim=text_dim,
    hidden_dim=hidden_dim,
    num_tokens=num_tokens,
    fmri_tok_dir=fmri_tok_dir,
    learning_rate=learning_rate,
    weight_decay=weight_decay
)
# Print a module summary down to two levels of nesting.
summary = ModelSummary(mm_mlp, max_depth=2)
print(summary)

In [None]:
# Train mm_mlp on train_dl, using val_dl as the validation loader.
trainer.fit(mm_mlp, train_dl, val_dl)