In [1]:
from torch.utils.data import Dataset
import os
import h5py
import numpy as np
from collections import defaultdict
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from vector_quantize_pytorch import VectorQuantize
from torchmetrics.functional import pearson_corrcoef
from pathlib import Path

from utils import preprocess_features, count_total_files, load_fmri, perform_pca, align_features_and_fmri_samples

In [None]:
# class AlgonautsDataset(Dataset):
#     def __init__(self, features_dir, fmri_dir, movies, subjects, excluded_samples_start=5, excluded_samples_end=5, hrf_delay=3, stimulus_window=5):
#         self.features_dir = features_dir
#         self.fmri_dir = fmri_dir
#         self.movies = movies
#         self.subjects = subjects
#         self.excluded_samples_start = excluded_samples_start
#         self.excluded_samples_end = excluded_samples_end
#         self.hrf_delay = hrf_delay
#         self.stimulus_window = stimulus_window
#         self.partition_indices = defaultdict(list)
        
#         audio_data, video_data, language_data, fmri_data = [], [], [], []
#         current_idx = 0
        
#         # Load fMRI data for current subject
#         for subject in self.subjects:
#             fmri_dict = load_fmri(self.fmri_dir, subject)
        
#         total_files = count_total_files(self.features_dir, self.movies)
#         pbar = tqdm(desc='Loading dataset', total=total_files)
        
#         for movie in self.movies:
#             start_idx = current_idx
#             test_arr = []
#             fmri_test = []
#             if 'friends' in movie:
#                 season = movie.split('-')[1]
#                 dir_list = sorted(os.listdir(self.features_dir + 'audio')) #List of all audio for each subset of dataset
#                 for episode in dir_list:
#                     if f"{season}e" in episode and '_features_' in episode:
#                         episode_base = episode.split('_features_')[0] # friends_s01e01 and so on....
                        
#                         features = {'audio': None, 'visual': None, 'language': None}
                        
#                         # Load all modalities
#                         for modality in ['audio', 'visual', 'language']:
#                             with h5py.File(os.path.join(self.features_dir, modality, f"{episode_base}_features_{modality}.h5"), 'r') as f:
#                                 try:
#                                     key = 'language_pooler_output' if modality == 'language' else modality
#                                     st_season_episode = episode_base.split('_')[1]
#                                     test_arr.append(st_season_episode)
#                                     features[modality] = f[st_season_episode][key][:]
#                                 except:
#                                     f.visit(lambda x: print(x))
                                    
#                         # Get fMRI data
#                         try:
#                             fmri_season_episode = episode_base.split("_")[1]
#                             fmri = fmri_dict[fmri_season_episode]
#                             assert fmri_season_episode == st_season_episode
#                         except:
#                             print(fmri_dict.keys())
#                             print("Current key:", episode_base.split("_")[1])
                            
#                         # Align features with fMRI using sliding window
#                         if all(v is not None for v in features.values()):
#                             valid_fmri = fmri[excluded_samples_start:-excluded_samples_end]
#                             for s in range(len(valid_fmri)):
#                                 aligned_features = []
                                
#                                 # Handle audio and video with sliding window
#                                 for mod in ['audio', 'visual']:
#                                     if s < (stimulus_window + hrf_delay):
#                                         idx_start = excluded_samples_start
#                                         idx_end = idx_start + stimulus_window
#                                     else:
#                                         idx_start = s + excluded_samples_start - hrf_delay - stimulus_window + 1
#                                         idx_end = idx_start + stimulus_window
                                        
#                                     if idx_end > len(features[mod]):
#                                         idx_end = len(features[mod])
#                                         idx_start = idx_end - stimulus_window
                                        
#                                     feat = features[mod][idx_start:idx_end]
#                                     aligned_features.append(feat)
                                
#                                 # Handle language features
#                                 idx = s + excluded_samples_start - hrf_delay
#                                 if idx >= len(features['language']):
#                                     lang_feat = features['language'][-1]
#                                 else:
#                                     lang_feat = features['language'][idx]
#                                 aligned_features.append(lang_feat)
                                
#                                 # Store aligned features
#                                 audio_data.append(aligned_features[0])
#                                 video_data.append(aligned_features[1])
#                                 language_data.append(aligned_features[2])
#                                 fmri_data.append(fmri[s + excluded_samples_start])
#                                 current_idx += 1
                                
#                         pbar.update(1)
                        
#             else:
#                 # Handle movies similarly...
#                 movie_name = movie.replace('movie10-', '')
#                 partitions = sorted([f for f in os.listdir(self.features_dir + 'audio') if movie_name in f and '_features_' in f])
                
#                 for partition in partitions:
#                     partition_base = partition.split('_features_')[0]
                    
#                     features = {'audio': None, 'visual': None, 'language': None}
                    
#                     # Load features
#                     for modality in ['audio', 'visual', 'language']:
#                         with h5py.File(os.path.join(self.features_dir, modality, f"{partition_base}_features_{modality}.h5"), 'r') as f:
#                             try:
#                                 key = 'language_pooler_output' if modality == 'language' else modality
#                                 features[modality] = f[partition_base][key][:]
#                                 test_arr.append(partition_base)
#                             except:
#                                 f.visit(lambda x: print(x))
                                
#                     # Get fMRI data
#                     fmri = fmri_dict[partition_base]
#                     fmri_test.append(partition_base)
#                     # Align features with fMRI
#                     if all(v is not None for v in features.values()):
#                         valid_fmri = fmri[excluded_samples_start:-excluded_samples_end]
#                         for s in range(len(valid_fmri)):
#                             aligned_features = []
                            
#                             # Handle audio and video
#                             for mod in ['audio', 'visual']:
#                                 if s < (stimulus_window + hrf_delay):
#                                     idx_start = excluded_samples_start
#                                     idx_end = idx_start + stimulus_window
#                                 else:
#                                     idx_start = s + excluded_samples_start - hrf_delay - stimulus_window + 1
#                                     idx_end = idx_start + stimulus_window
                                    
#                                 if idx_end > len(features[mod]):
#                                     idx_end = len(features[mod])
#                                     idx_start = idx_end - stimulus_window
                                    
#                                 feat = features[mod][idx_start:idx_end]
#                                 aligned_features.append(feat)
                            
#                             # Handle language
#                             idx = s + excluded_samples_start - hrf_delay
#                             if idx >= len(features['language']):
#                                 lang_feat = features['language'][-1]
#                             else:
#                                 lang_feat = features['language'][idx]
#                             aligned_features.append(lang_feat)
                            
#                             # Store aligned features
#                             audio_data.append(aligned_features[0])
#                             video_data.append(aligned_features[1])
#                             language_data.append(aligned_features[2])
#                             fmri_data.append(fmri[s + excluded_samples_start])
#                             current_idx += 1
                            
#                     pbar.update(1)
            
#             self.partition_indices[movie] = (start_idx, current_idx)
#         pbar.close()
#         # Convert to tensors

#         video_preproc = preprocess_features(np.stack(video_data).reshape(-1, 8192))
#         language_preproc = preprocess_features(np.stack(language_data).reshape(-1, 768))
#         print("Video shape: ", video_preproc.shape)
#         print("Language shape: ", language_preproc.shape)

#         video_pca = perform_pca(video_preproc, n_components=250)
#         language_pca = perform_pca(language_preproc, n_components=250)
#         print("Video shape: ", video_pca.shape)
#         print("Language shape: ", language_pca.shape)
#         import sys; sys.exit()

#         self.audio = torch.from_numpy(np.stack(audio_data))
#         self.video = torch.from_numpy(video_pca)
#         self.language = torch.from_numpy(language_pca)
#         self.fmri = torch.from_numpy(np.stack(fmri_data))
        
#         # Verify loading order
#         for movie in self.movies:
#             start, end = self.partition_indices[movie]
#             assert start < end, f"Invalid index range for {movie}"
            
#     def __len__(self):
#         return len(self.audio)

#     def __getitem__(self, idx):
#         return {
#             'audio': self.audio[idx],
#             'video': self.video[idx],
#             'language': self.language[idx],
#             'fmri': self.fmri[idx]
#         }
        
#     def get_partition_indices(self):
#         return self.partition_indices

In [51]:
# Load subject 1's fMRI responses. The next cell iterates the result with
# `.items()`, so `load_fmri` (from utils) appears to return a mapping of
# split name -> array; its exact schema is defined in utils, not here.
subject = '1'
fmri_dir = '/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/'
fmri = load_fmri(fmri_dir, subject)


In [56]:
# Print every fMRI split's name and shape, then the shape of all splits
# concatenated along the first (time) axis. A plain loop is used for the
# prints: a list comprehension built only for its side effects is an anti-pattern.
print(f"Subject {subject} fMRI movies splits name and shape:")
for split_name, split_fmri in fmri.items():
    print(f"{split_name} {split_fmri.shape}")
print(np.concatenate(list(fmri.values()), axis=0).shape)


Subject 1 fMRI movies splits name and shape:
s01e02a (482, 1000)
s01e02b (482, 1000)
s01e03a (472, 1000)
s01e03b (472, 1000)
s01e04a (503, 1000)
s01e04b (503, 1000)
s01e05a (468, 1000)
s01e05b (468, 1000)
s01e01a (592, 1000)
s01e01b (592, 1000)
s01e06a (488, 1000)
s01e06b (488, 1000)
s01e07a (493, 1000)
s01e07b (493, 1000)
s01e08a (476, 1000)
s01e08b (476, 1000)
s01e09a (468, 1000)
s01e09b (468, 1000)
s01e10a (496, 1000)
s01e10b (496, 1000)
s01e11a (477, 1000)
s01e11b (477, 1000)
s01e12a (472, 1000)
s01e12b (471, 1000)
s01e13a (468, 1000)
s01e13b (468, 1000)
s01e14a (491, 1000)
s01e14b (491, 1000)
s01e15a (477, 1000)
s01e15b (477, 1000)
s01e16a (483, 1000)
s01e16b (483, 1000)
s01e17a (483, 1000)
s01e17b (483, 1000)
s01e18a (497, 1000)
s01e18b (497, 1000)
s01e19a (460, 1000)
s01e19b (460, 1000)
s01e20a (460, 1000)
s01e20b (460, 1000)
s01e21a (477, 1000)
s01e21b (477, 1000)
s01e22a (474, 1000)
s01e22b (474, 1000)
s01e23a (462, 1000)
s01e23b (462, 1000)
s01e24a (515, 1000)
s01e24b (515, 1

In [15]:
class AlgonautsDataset(Dataset):
    """Stimulus-feature / fMRI dataset for one subject of the Algonauts 2025 challenge.

    Loads per-split audio, visual and language features from HDF5 files, loads the
    subject's fMRI with ``load_fmri`` and aligns the two with
    ``align_features_and_fmri_samples`` (both from ``utils``).

    NOTE(review): as written, ``__init__`` calls ``sys.exit()`` right after the
    alignment step (a debugging stop), so constructing this class raises
    ``SystemExit`` and no instance is ever returned. The code after that call is
    unreachable and would not work if reached: it stacks ``raw_audio`` /
    ``raw_video`` / ``raw_language``, which are never populated. ``__len__`` and
    ``__getitem__`` read ``self.audio`` etc., which are only assigned in that
    unreachable section.
    """
    def __init__(self, features_dir, fmri_dir, movies, subject, excluded_samples_start=5, excluded_samples_end=5, hrf_delay=3, stimulus_window=5):
        """Load stimulus features and fMRI, then align them.

        Args:
            features_dir: root folder with ``audio/``, ``visual/`` and ``language/``
                sub-folders of ``<split>_features_<modality>.h5`` files. It is
                joined with ``+`` in places, so it must end with a path separator.
            fmri_dir: folder passed to ``load_fmri``.
            movies: names such as ``"friends-s01"`` or ``"movie10-<name>"``.
            subject: subject identifier passed to ``load_fmri``.
            excluded_samples_start, excluded_samples_end, hrf_delay, stimulus_window:
                forwarded to ``align_features_and_fmri_samples``.
        """
        self.features_dir = features_dir
        self.fmri_dir = fmri_dir
        self.movies = movies
        self.subject = subject
        self.excluded_samples_start = excluded_samples_start
        self.excluded_samples_end = excluded_samples_end
        self.hrf_delay = hrf_delay
        self.stimulus_window = stimulus_window
        # movie -> (start, end) sample range; only filled in the unreachable tail of __init__.
        self.partition_indices = defaultdict(list)
        
        # modality -> {split key -> feature array}
        stimuli_features = {"visual": {}, "audio": {}, "language": {}}
        # NOTE(review): these lists are never filled (the extend() calls below are
        # commented out), yet the unreachable code after sys.exit() stacks them.
        raw_audio, raw_video = [], []
        raw_language = []
        
        
        total_files = count_total_files(self.features_dir, self.movies)
        # NOTE(review): pbar is never advanced in the loops below and is only
        # closed after the sys.exit() call, so it is left open.
        pbar = tqdm(desc='Loading dataset', total=total_files)
        
        # Load audio and video features first
        for movie in self.movies:
            if 'friends' in movie:
                # "friends-s01" -> "s01"
                season = movie.split('-')[1]
                dir_list = sorted(os.listdir(self.features_dir + 'audio')) #List of all audio for each subset of dataset
                for episode in dir_list:
                    if f"{season}e" in episode and '_features_' in episode:
                        episode_base = episode.split('_features_')[0] # friends_s01e01 and so on....
                        
                        # Load audio and video features
                        # NOTE(review): `features` is never read; results go into stimuli_features.
                        features = {'audio': None, 'visual': None}
                        
                        # Both modalities are keyed by the split name taken from the
                        # file name, e.g. "friends_s01e01a" -> "s01e01a".
                        for modality in ['audio', 'visual']:
                            with h5py.File(os.path.join(self.features_dir, modality, f"{episode_base}_features_{modality}.h5"), 'r') as f:
                                try:
                                    stimuli_features[modality][episode_base.split('_')[1]] = f[episode_base.split('_')[1]][modality][:]
                                except:
                                    # NOTE(review): bare except — prints the HDF5 tree and
                                    # silently skips this split (and swallows any other error).
                                    f.visit(lambda x: print(x))
                                    
                        # if all(v is not None for v in features.values()):
                        #     raw_audio.extend(features['audio'])
                        #     raw_video.extend(features['visual'])
                        # else:
                        #     print(f"Could not load features for {episode_base}")


                # Language files are listed from their own folder, so a split can end up
                # with language features but no audio/visual ones (or vice versa).
                lang_dir_list = sorted(os.listdir(self.features_dir + 'language'))
                for episode in lang_dir_list:
                    if f"{season}e" in episode and '_features_' in episode:
                        episode_base = episode.split('_features_')[0]
                        
                        with h5py.File(os.path.join(self.features_dir, 'language', f"{episode_base}_features_language.h5"), 'r') as f:
                            try:
                                st_season_episode = episode_base.split('_')[1]
                                stimuli_features['language'][st_season_episode] = f[st_season_episode]['language_pooler_output'][:]
                            except:
                                f.visit(lambda x: print(x))
            else:
                # Movie10 partitions are keyed by the full file-name prefix
                # (partition_base), unlike Friends splits.
                movie_name = movie.replace('movie10-', '')
                partitions = sorted([f for f in os.listdir(self.features_dir + 'audio') if movie_name in f and '_features_' in f])
                
                for partition in partitions:
                    partition_base = partition.split('_features_')[0]
                    
                    # NOTE(review): `features` is never read.
                    features = {'audio': None, 'visual': None}
                    
                    for modality in ['audio', 'visual']:
                        with h5py.File(os.path.join(self.features_dir, modality, f"{partition_base}_features_{modality}.h5"), 'r') as f:
                            try:
                                stimuli_features[modality][partition_base] = f[partition_base][modality][:]
                            except:
                                f.visit(lambda x: print(x))

                    # if all(v is not None for v in features.values()):
                    #     raw_audio.extend(features['audio'])
                    #     raw_video.extend(features['visual'])
                    # else:
                    #     print(f"Could not load features for {partition_base}")  
                                
                lang_partitions = sorted([f for f in os.listdir(self.features_dir + 'language') if movie_name in f and '_features_' in f])
                
                for partition in lang_partitions:
                    partition_base = partition.split('_features_')[0]
                    
                    with h5py.File(os.path.join(self.features_dir, 'language', f"{partition_base}_features_language.h5"), 'r') as f:
                        try:
                            stimuli_features['language'][partition_base] = f[partition_base]['language_pooler_output'][:]
                        except:
                            f.visit(lambda x: print(x))





        # NOTE(review): an earlier notebook cell iterates load_fmri's result with
        # .items(), so it appears to be a dict of split name -> array.
        fmri_data = load_fmri(self.fmri_dir, self.subject)

        # Returns (stimulus feature matrix, fMRI matrix); for ["friends-s01"] the
        # printed shapes were (22787, 42368) and (22787, 1000).
        feature_train, fmri_train = align_features_and_fmri_samples(
            stimuli_features, 
            fmri_data, 
            self.excluded_samples_start, 
            self.excluded_samples_end, 
            self.hrf_delay, 
            self.stimulus_window, 
            self.movies
        )

        print("Training fMRI responses shape:")
        print(fmri_train.shape)
        print('(Train samples × Parcels)')
        print("\nTraining stimulus features shape:")
        print(feature_train.shape)
        print('(Train samples × Features)')
        # print(stimuli_features.keys())
        # print(stimuli_features['visual'].keys())
        # print(stimuli_features['audio'].keys())
        # print(stimuli_features['language'].keys())
        # print(f"Subject {subject} fMRI movies splits name and shape:")
        # for key, value in fmri_data.items():
        #     print(key + " " + str(value.shape))
        
        # stimuli_features['visual'] = raw_video
        # stimuli_features['audio'] = raw_audio
        # stimuli_features['language'] = raw_language
        # for key_modality, value_modality in stimuli_features.items():
        #     print(f"\n{key_modality} features movie splits name and shape:")
        #     for key_movie, value_movie in value_modality.items():
        #         print(key_movie + " " + str(value_movie.shape))
        # NOTE(review): debugging stop. This raises SystemExit from inside __init__
        # (the construction cell's output shows it), so no dataset object is ever
        # created. Everything below in __init__ is currently unreachable.
        import sys; sys.exit()

        pbar.close()

        # NOTE(review): np.stack over a dict iterates its keys (split names), not
        # the feature arrays, so these prints would not report feature shapes.
        print("Raw video shape: ", np.stack(stimuli_features['visual']).shape)
        print("Raw audio shape: ", np.stack(stimuli_features['audio']).shape)
        print("Raw language shape: ", np.stack(stimuli_features['language']).shape)
        print("Raw fMRI shape: ", np.stack(fmri_data).shape)
        
        # NOTE(review): raw_video / raw_language are still empty lists here, so
        # np.stack would fail.
        video_preproc = preprocess_features(np.stack(raw_video))
        language_preproc = preprocess_features(np.stack(raw_language))
        print("Video shape: ", video_preproc.shape)
        print("Language shape: ", language_preproc.shape)

        video_pca = perform_pca(video_preproc, n_components=250)
        language_pca = perform_pca(language_preproc, n_components=250)
        print("Video shape after PCA: ", video_pca.shape)   
        print("Language shape after PCA: ", language_pca.shape)

        # Now align features with fMRI using sliding window and apply exclusion
        aligned_audio, aligned_video, aligned_language = [], [], []
        valid_fmri = []
        
        # current_pos is a running index across all movies, not reset per movie.
        current_pos = 0
        for movie in self.movies:
            start_idx = current_pos
            # NOTE(review): fmri_data appears to be a dict (see above), so this slice
            # would fail; len(raw_audio) is also 0 here.
            movie_fmri = fmri_data[start_idx:start_idx+len(raw_audio)]
            
            # Apply exclusion to fMRI
            valid_movie_fmri = movie_fmri[excluded_samples_start:-excluded_samples_end]
            valid_fmri.extend(valid_movie_fmri)
            
            for s in range(len(valid_movie_fmri)):
                # Window of stimulus_window samples: for the first
                # stimulus_window + hrf_delay samples it starts at current_pos;
                # afterwards it ends (inclusive) hrf_delay samples before current_pos.
                if s < (stimulus_window + hrf_delay):
                    idx_start = current_pos
                    idx_end = idx_start + stimulus_window
                else:
                    idx_start = current_pos - hrf_delay - stimulus_window + 1
                    idx_end = idx_start + stimulus_window
                    
                # Align audio features
                if idx_end > len(raw_audio):
                    idx_end = len(raw_audio)
                    idx_start = idx_end - stimulus_window
                aligned_audio.append(raw_audio[idx_start:idx_end])
                
                # Align video features
                if idx_end > len(video_pca):
                    idx_end = len(video_pca)
                    idx_start = idx_end - stimulus_window
                aligned_video.append(video_pca[idx_start:idx_end])
                
                # Handle language features: one vector hrf_delay samples back,
                # falling back to the last vector past the end.
                idx = current_pos - hrf_delay
                if idx >= len(language_pca):
                    aligned_language.append(language_pca[-1])
                else:
                    aligned_language.append(language_pca[idx])
                    
                current_pos += 1
            
            # Update partition indices after alignment
            self.partition_indices[movie] = (start_idx, current_pos)

        # Convert to tensors
        self.audio = torch.from_numpy(np.stack(aligned_audio))
        self.video = torch.from_numpy(np.stack(aligned_video))
        self.language = torch.from_numpy(np.stack(aligned_language))
        self.fmri = torch.from_numpy(np.stack(valid_fmri))
        
        # Verify loading order
        for movie in self.movies:
            start, end = self.partition_indices[movie]
            assert start < end, f"Invalid index range for {movie}"
            
    def __len__(self):
        """Return the number of samples (the length of ``self.audio``)."""
        return len(self.audio)

    def __getitem__(self, idx):
        """Return sample *idx* as a dict with 'audio', 'video', 'language' and 'fmri' tensors."""
        return {
            'audio': self.audio[idx],
            'video': self.video[idx],
            'language': self.language[idx],
            'fmri': self.fmri[idx]
        }
        
    def get_partition_indices(self):
        """Return the ``{movie: (start, end)}`` sample ranges recorded in ``__init__``."""
        return self.partition_indices


In [16]:
# Raw stimulus-feature root (must end with '/' because AlgonautsDataset joins
# sub-folders with '+') and the competition fMRI folder.
features_dir = '/home/pranav/mihir/algonauts_challenge/AlgonautsDS-features/developer_kit/stimulus_features/raw/'
fmri_dir = '/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/'
# Movie names follow the "friends-sXX" / "movie10-<name>" patterns parsed by AlgonautsDataset.
movies_train = ["friends-s01"]
# movies_train = ["friends-s01", "friends-s02", "friends-s03", "friends-s04", "friends-s05", "movie10-bourne", "movie10-figures", "movie10-life", "movie10-wolf"]
movies_val = ["friends-s06"]
# NOTE(review): `modality` is not used by any code visible in this notebook.
modality = "all"  #@param ["visual", "audio", "language", "all"]

excluded_samples_start = 5  #@param {type:"slider", min:0, max:20, step:1}
excluded_samples_end = 5  #@param {type:"slider", min:0, max:20, step:1}
hrf_delay = 3  #@param {type:"slider", min:0, max:10, step:1}
stimulus_window = 5

# NOTE(review): subject is an int here but the string '1' in the earlier
# load_fmri cell; confirm load_fmri treats both the same way.
subject = 1 #@param ["1", "2", "3", "5"] {type:"raw", allow-input: true}

# NOTE(review): AlgonautsDataset.__init__ currently ends in sys.exit(), so the
# train_ds line raises SystemExit (see the output below) and val_ds is never built.
train_ds = AlgonautsDataset(features_dir, fmri_dir, movies=movies_train, subject=subject, excluded_samples_start=excluded_samples_start, excluded_samples_end=excluded_samples_end, hrf_delay=hrf_delay, stimulus_window=stimulus_window)
val_ds = AlgonautsDataset(features_dir, fmri_dir, movies=movies_val, subject=subject, excluded_samples_start=excluded_samples_start, excluded_samples_end=excluded_samples_end, hrf_delay=hrf_delay, stimulus_window=stimulus_window)



Training fMRI responses shape:
(22787, 1000)
(Train samples × Parcels)

Training stimulus features shape:
(22787, 42368)
(Train samples × Features)


SystemExit: 

In [35]:
import os
import h5py
import numpy as np

# Sanity check: load every Friends season-1 *audio* feature file and count the
# feature samples, for comparison with the fMRI sample count. (This cell reads
# the `audio` folder; earlier versions mislabeled it as "language".)
audio_dir = '/home/pranav/mihir/algonauts_challenge/AlgonautsDS-features/developer_kit/stimulus_features/raw/audio'
# Expected number of split files for Friends season 1 (value carried over from
# the original check; the fMRI cell also finds 48 runs for s01).
expected_files = 48


def load_season_features(feature_dir, season='s01', modality='audio'):
    """Load every ``friends_<season>e*_features_<modality>.h5`` file in *feature_dir*.

    Files are read in sorted file-name order. Each file must contain a group
    named after its split (e.g. ``s01e01a`` for ``friends_s01e01a_...``) holding
    a dataset named *modality*.

    Returns:
        list of numpy arrays, one per matching file (empty if none match).
    """
    prefix = f'friends_{season}e'
    suffix = f'_features_{modality}.h5'
    per_file = []
    for file_name in sorted(os.listdir(feature_dir)):
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            continue
        # "friends_s01e01a_features_audio.h5" -> "s01e01a"
        split_key = file_name.split('_')[1]
        with h5py.File(os.path.join(feature_dir, file_name), 'r') as f:
            feature = f[split_key][modality][:]
        print(f"Loading file: {file_name}, shape: {feature.shape}")
        per_file.append(feature)
    return per_file


audio_per_file = load_season_features(audio_dir, season='s01', modality='audio')
if len(audio_per_file) != expected_files:
    print(f"Warning: Expected {expected_files} files but found {len(audio_per_file)}")
else:
    print(f"Loaded {len(audio_per_file)} files for friends s01")

# np.concatenate([]) raises an opaque ValueError; fail with a clear message instead.
if not audio_per_file:
    raise FileNotFoundError(f"No friends s01 audio feature files found in {audio_dir}")
audio_features = np.concatenate(audio_per_file, axis=0)
print(f"\nNumber of samples in friends s01 audio features: {len(audio_features)}")
print(f"Shape of audio features: {audio_features.shape}")




Loading file: friends_s01e01a_features_audio.h5, shape: (592, 128)
Loading file: friends_s01e01b_features_audio.h5, shape: (591, 128)
Loading file: friends_s01e02a_features_audio.h5, shape: (483, 128)
Loading file: friends_s01e02b_features_audio.h5, shape: (483, 128)
Loading file: friends_s01e03a_features_audio.h5, shape: (473, 128)
Loading file: friends_s01e03b_features_audio.h5, shape: (473, 128)
Loading file: friends_s01e04a_features_audio.h5, shape: (503, 128)
Loading file: friends_s01e04b_features_audio.h5, shape: (503, 128)
Loading file: friends_s01e05a_features_audio.h5, shape: (469, 128)
Loading file: friends_s01e05b_features_audio.h5, shape: (469, 128)
Loading file: friends_s01e06a_features_audio.h5, shape: (489, 128)
Loading file: friends_s01e06b_features_audio.h5, shape: (489, 128)
Loading file: friends_s01e07a_features_audio.h5, shape: (493, 128)
Loading file: friends_s01e07b_features_audio.h5, shape: (493, 128)
Loading file: friends_s01e08a_features_audio.h5, shape: (476, 

In [42]:
# Sanity check: load every Friends season-1 fMRI run for sub-01 and count the
# samples, for comparison with the stimulus-feature sample count.
file_path = "/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/sub-01/func/sub-01_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5"
fmri_season = 's01'
expected_runs = 48


def _episode_of(run_key):
    """Return the split name of a run key, e.g. 'ses-001_task-s01e02a' -> 's01e02a'."""
    return run_key.split('_')[1].split('-')[1]


with h5py.File(file_path, 'r') as f:
    # Sessions are not in episode order (s01e01 is recorded in ses-003), so order
    # the runs by split name. This matches the sorted file-name order used when
    # loading stimulus features, keeping the concatenations comparable.
    run_keys = sorted(
        (key for key in f.keys() if _episode_of(key)[:3] == fmri_season),
        key=_episode_of,
    )
    fmri_features = [f[key][:] for key in run_keys]

if len(fmri_features) != expected_runs:
    print(f"Warning: Expected {expected_runs} files but found {len(fmri_features)}")
else:
    print(f"Loaded {len(fmri_features)} files for friends s01")

if not fmri_features:
    raise KeyError(f"No friends {fmri_season} runs found in {file_path}")
fmri_features = np.concatenate(fmri_features, axis=0)
print(f"\nNumber of samples in friends s01 fMRI responses: {len(fmri_features)}")
print(f"Shape of fMRI responses: {fmri_features.shape}")


Loaded 48 files for friends s01

Number of samples in friends s01 language features: 23267
Shape of language features: (23267, 1000)


In [13]:
# Report the number of samples in the training split, then the validation split.
for split_ds in (train_ds, val_ds):
    print(len(split_ds))

22787
22924


In [15]:
# Inspect the per-modality tensor shapes of the first sample.
# Index the dataset once instead of calling __getitem__ four times. Keep the
# result in `sample` rather than rebinding the module-level name `fmri`, which
# holds the fMRI dict loaded by the earlier load_fmri cell.
sample = train_ds[0]
for key in ('audio', 'video', 'language', 'fmri'):
    print(sample[key].shape)

torch.Size([1, 128])
torch.Size([250])
torch.Size([250])
torch.Size([1000])


In [None]:
# Shared loader settings; only the training loader shuffles.
batch_size = 4
num_workers = 4
loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}

train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = torch.utils.data.DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [4]:
partition_indices = train_ds.get_partition_indices()  # {movie: (start, end)} sample ranges
print(partition_indices)

defaultdict(<class 'list'>, {'friends-s01': (0, 22787)})


In [47]:
# Gather every run whose key contains 's01' (in sorted key order) and stack the
# runs along the first axis; then list all run keys present in the file.
file = "/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/sub-01/func/sub-01_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5"
with h5py.File(file, 'r') as f:
    key_arr = list(f.keys())
    data_keys = [run_key for run_key in sorted(key_arr) if 's01' in run_key]
    data = np.concatenate([f[run_key][:] for run_key in data_keys], axis=0)
print(sorted(key_arr))

['ses-001_task-s01e02a', 'ses-001_task-s01e02b', 'ses-001_task-s01e03a', 'ses-001_task-s01e03b', 'ses-002_task-s01e04a', 'ses-002_task-s01e04b', 'ses-002_task-s01e05a', 'ses-002_task-s01e05b', 'ses-003_task-s01e01a', 'ses-003_task-s01e01b', 'ses-003_task-s01e06a', 'ses-003_task-s01e06b', 'ses-004_task-s01e07a', 'ses-004_task-s01e07b', 'ses-004_task-s01e08a', 'ses-004_task-s01e08b', 'ses-004_task-s01e09a', 'ses-004_task-s01e09b', 'ses-005_task-s01e10a', 'ses-005_task-s01e10b', 'ses-005_task-s01e11a', 'ses-005_task-s01e11b', 'ses-006_task-s01e12a', 'ses-006_task-s01e12b', 'ses-006_task-s01e13a', 'ses-006_task-s01e13b', 'ses-006_task-s01e14a', 'ses-006_task-s01e14b', 'ses-007_task-s01e15a', 'ses-007_task-s01e15b', 'ses-007_task-s01e16a', 'ses-007_task-s01e16b', 'ses-007_task-s01e17a', 'ses-007_task-s01e17b', 'ses-008_task-s01e18a', 'ses-008_task-s01e18b', 'ses-008_task-s01e19a', 'ses-008_task-s01e19b', 'ses-009_task-s01e20a', 'ses-009_task-s01e20b', 'ses-009_task-s01e21a', 'ses-009_task-s

In [None]:
data.shape  # shape of the concatenated s01 runs built by the h5py key-listing cell

In [None]:
print(sorted(data_keys))  # run keys whose arrays were concatenated into `data`

In [None]:
# Walk the whole dataset once (each step calls __getitem__) and count the samples,
# as a check that iteration agrees with len(train_ds).
num = sum(1 for _sample in train_ds)
print(num)
# print(i['fmri'].shape)

In [5]:


class ResidualBlock(nn.Module):
    """Two-layer MLP block with a residual connection.

    Computes ``shortcut(x) + net(x)``, where ``net`` is
    Linear -> LayerNorm -> GELU -> Linear -> LayerNorm -> GELU. ``shortcut`` is
    the identity when ``in_dim == out_dim`` and a learned ``nn.Linear``
    projection otherwise.

    Args:
        in_dim: width of the input's last dimension.
        out_dim: width of the output's last dimension.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # True whenever the shortcut needs a projection to match widths; this
        # covers widening as well as narrowing, despite the attribute name.
        self.downsample = in_dim != out_dim
        branch = [
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
        ]
        self.net = nn.Sequential(*branch)
        if self.downsample:
            self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        shortcut = self.proj(x) if self.downsample else x
        return shortcut + self.net(x)

class Encoder(nn.Module):
    """Map a flat input vector to ``num_tokens`` continuous token embeddings.

    Architecture: ``ResidualBlock(input_dim -> hidden_dims[0])``, then one
    ``ResidualBlock`` per consecutive pair in *hidden_dims*, then
    ``ResidualBlock(hidden_dims[-1] -> num_tokens * codebook_dim)``. The output
    is reshaped to ``(batch, num_tokens, codebook_dim)``.

    Args:
        input_dim: width of the input's last dimension.
        hidden_dims: non-empty sequence of hidden widths (any sequence; a tuple
            default avoids the shared-mutable-default pitfall).
        num_tokens: number of output tokens per sample.
        codebook_dim: embedding width of each token.

    Raises:
        ValueError: if *hidden_dims* is empty.
    """

    def __init__(self, input_dim=1000, hidden_dims=(512, 384, 256), num_tokens=32, codebook_dim=64):
        super().__init__()
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one width")

        # Initial projection with one residual block
        self.input_proj = ResidualBlock(input_dim, hidden_dims[0])

        # One residual block per consecutive pair of hidden widths
        self.layers = nn.Sequential(
            *(ResidualBlock(d_in, d_out) for d_in, d_out in zip(hidden_dims, hidden_dims[1:]))
        )

        # Project to token space with one residual block
        self.token_proj = ResidualBlock(hidden_dims[-1], num_tokens * codebook_dim)

        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim

    def forward(self, x):
        """Return token embeddings of shape ``(x.shape[0], num_tokens, codebook_dim)``.

        Assumes a 2-D ``(batch, input_dim)`` input; the final ``view`` only
        succeeds when each row yields exactly ``num_tokens * codebook_dim`` values.
        """
        x = self.input_proj(x)
        x = self.layers(x)
        x = self.token_proj(x)
        return x.view(x.shape[0], self.num_tokens, self.codebook_dim)

class Decoder(nn.Module):
    """Map a grid of token embeddings back to an output vector.

    Mirror of ``Encoder``: the ``(batch, num_tokens, codebook_dim)`` input is
    flattened to ``(batch, num_tokens * codebook_dim)``, passed through
    ``token_proj`` (-> hidden_dims[0]), one ResidualBlock per consecutive pair
    of ``hidden_dims``, and ``output_proj`` (hidden_dims[-1] -> output_dim).

    Args:
        output_dim: Width of the reconstructed output.
        hidden_dims: Widths of the hidden stages in decoding order; must be
            non-empty. The default is a tuple so it cannot be mutated and shared
            across instances.
        num_tokens: Number of tokens per sample in the input.
        codebook_dim: Embedding width of each token.
    """

    def __init__(self, output_dim=1000, hidden_dims=(256, 384, 512), num_tokens=32, codebook_dim=64):
        super().__init__()

        # Project the flattened tokens into the first hidden width.
        self.token_proj = ResidualBlock(num_tokens * codebook_dim, hidden_dims[0])

        # One residual block per consecutive pair of hidden widths.
        self.layers = nn.Sequential(
            *(ResidualBlock(d_in, d_out) for d_in, d_out in zip(hidden_dims[:-1], hidden_dims[1:]))
        )

        # Final projection to the output width.
        self.output_proj = ResidualBlock(hidden_dims[-1], output_dim)

    def forward(self, x):
        """Decode ``x`` of shape ``(batch, num_tokens, codebook_dim)`` to ``(batch, output_dim)``."""
        x = x.reshape(x.shape[0], -1)  # flatten the tokens of each sample
        x = self.token_proj(x)
        x = self.layers(x)
        return self.output_proj(x)

class VQVAE(L.LightningModule):
    """Vector-quantised autoencoder, used in this notebook as a frozen fMRI tokenizer.

    ``encoder`` maps an input to ``num_tokens`` vectors of width ``codebook_dim``.
    ``quantizer`` snaps each one to one of ``codebook_size`` codebook entries.
    ``decoder`` (built with the reversed ``hidden_dims``) maps the quantised
    tokens back to ``input_dim``.

    ``hidden_dims`` defaults to a tuple so the default cannot be mutated and
    shared across instances; lists (e.g. from saved hparams) work the same.

    NOTE(review): ``learning_rate`` and ``weight_decay`` are stored but no
    training hooks are visible in this class; presumably training happened with
    a fuller definition elsewhere — confirm before training from this one.
    """

    def __init__(
            self,
            input_dim=1000,
            hidden_dims=(512, 384, 256),
            num_tokens=32,
            codebook_size=1024,
            codebook_dim=8,
            commitment_weight=0.25,
            quantizer_decay=0.99,
            learning_rate=3e-4,
            weight_decay=0.01
            ):
        super().__init__()

        self.encoder = Encoder(input_dim, hidden_dims, num_tokens, codebook_dim)
        self.decoder = Decoder(input_dim, hidden_dims[::-1], num_tokens, codebook_dim)
        self.quantizer = VectorQuantize(
                dim=codebook_dim,
                codebook_size=codebook_size,
                decay=quantizer_decay,
                commitment_weight=commitment_weight
                )
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns:
            ``(x_recon, commitment_loss, indices)``.
        """
        z = self.encoder(x)
        z_q, indices, commitment_loss = self.quantizer(z)
        x_recon = self.decoder(z_q)
        return x_recon, commitment_loss, indices

    def encode(self, x):
        """Return the codebook indices the quantizer assigns to ``encoder(x)``.

        NOTE(review): this runs the quantizer's forward pass; with an EMA
        codebook (``decay``) it may update the codebook when the module is in
        training mode — keep it in ``eval()`` when used as a frozen tokenizer.
        """
        z = self.encoder(x)
        _, indices, _ = self.quantizer(z)
        return indices

    def decode(self, indices):
        """Reconstruct outputs from codebook ``indices``.

        NOTE(review): ``get_codes_from_indices`` returns raw codebook vectors.
        That presumably matches the quantizer output here because the quantizer
        is built with a single ``dim`` (no separate projection) — confirm.
        """
        z_q = self.quantizer.get_codes_from_indices(indices)
        return self.decoder(z_q)
    
class MultiModalMLP(L.LightningModule):
    """Predict fMRI VQ-VAE token indices from concatenated video/audio/text features.

    A frozen ``VQVAE`` loaded from ``fmri_tok_dir`` turns each fMRI sample into
    ``num_tokens`` codebook indices. An MLP regresses those indices (as floats,
    with MSE) from the stimulus features. For metrics, the rounded and clamped
    predictions are decoded back to fMRI space with the frozen tokenizer.

    NOTE(review): codebook indices are categorical, so MSE treats numerically
    adjacent indices as "close" even though their codes need not be. A per-token
    classifier (``num_tokens * codebook_size`` logits + cross-entropy) is
    probably a better fit; left unchanged here to keep the interface/checkpoints.

    Args:
        video_dim, audio_dim, text_dim: Feature widths of each modality.
        hidden_dim: Width of the MLP hidden layers.
        num_tokens: Number of fMRI tokens predicted per sample.
        fmri_tok_dir: Path to the VQVAE checkpoint used as the tokenizer.
        learning_rate, weight_decay: AdamW settings.
    """

    def __init__(self, video_dim, audio_dim, text_dim, hidden_dim, num_tokens, fmri_tok_dir, learning_rate, weight_decay):
        super().__init__()

        # Feature widths of each modality.
        self.video_dim = video_dim
        self.audio_dim = audio_dim
        self.text_dim = text_dim

        # Frozen tokenizer: provides target indices and decodes predictions for metrics.
        self.fmri_tokenizer = VQVAE.load_from_checkpoint(fmri_tok_dir)
        for param in self.fmri_tokenizer.parameters():
            param.requires_grad = False
        self.fmri_tokenizer.eval()

        # Trainable embedding substituted for missing (NaN) text features.
        self.missing_text_token = nn.Parameter(torch.randn(1, text_dim))

        # MLP over the concatenated features.
        total_dim = video_dim + audio_dim + text_dim
        self.mlp = nn.Sequential(
            nn.Linear(total_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_tokens)
        )
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def forward(self, video, audio, text):
        """Predict ``num_tokens`` unrounded token values per sample.

        Args:
            video: Video features. The concatenation below requires this to be
                3-D, presumably ``[batch_size, 1, video_dim]`` — TODO confirm.
            audio: Audio features with the same layout as ``video``
                (presumably ``[batch_size, 1, audio_dim]``).
            text: Text features ``[batch_size, text_dim]``. Any row containing a
                NaN is treated as missing and replaced by ``missing_text_token``.

        Returns:
            Tensor ``[batch_size, num_tokens]``.
        """
        batch_size = video.shape[0]
        # A row is "present" only if it contains no NaN at all.
        text_mask = ~torch.isnan(text).any(dim=1, keepdim=True)

        # Replace missing text with the learned token.
        text = torch.where(
            text_mask,
            text,
            self.missing_text_token.expand(batch_size, -1)
        )

        # Lift text to [B, 1, text_dim] so it concatenates with video/audio,
        # then drop the singleton axis again.
        text = text.unsqueeze(1)
        x = torch.cat([video, audio, text], dim=-1).squeeze(1)
        return self.mlp(x)

    def calculate_metrics(self, x, x_recon):
        """Compare ground truth ``x`` with reconstruction ``x_recon``.

        Both tensors are flattened to ``(batch, -1)``.

        Returns:
            ``(avg_pearson_r, variance_explained, mse, mae)``:
            - the mean of the per-sample Pearson r;
            - ``1 - sum_i var(x_i - x_recon_i) / sum_i var(x_i)``, pooled over the batch;
            - element-wise MSE and MAE.
        """
        x_flat = x.reshape(x.shape[0], -1)
        x_recon_flat = x_recon.reshape(x_recon.shape[0], -1)

        # Pearson r for each sample in the batch, then averaged.
        correlations = torch.stack([
            pearson_corrcoef(x_flat[i], x_recon_flat[i])
            for i in range(x_flat.shape[0])
        ])
        avg_pearson_r = correlations.mean()

        # Variance explained, pooled over the batch.
        total_variance = torch.var(x_flat, dim=1).sum()
        residual_variance = torch.var(x_flat - x_recon_flat, dim=1).sum()
        variance_explained = 1 - (residual_variance / total_variance)

        mse = F.mse_loss(x_flat, x_recon_flat)
        mae = F.l1_loss(x_flat, x_recon_flat)

        return avg_pearson_r, variance_explained, mse, mae

    def _predicted_indices(self, logits):
        """Round regressed token values to integer codebook indices in ``[0, codebook_size - 1]``."""
        # Unclamped rounding can produce negative or >= codebook_size values,
        # which are not valid codebook entries; depending on the lookup they
        # would silently wrap or fail.
        max_index = self.fmri_tokenizer.hparams.codebook_size - 1
        return logits.round().long().clamp(0, max_index)

    def _shared_step(self, batch, stage):
        """Compute the token-regression loss and reconstruction metrics and log them as ``{stage}_*``."""
        video, audio, text, fmri = batch['video'], batch['audio'], batch['language'], batch['fmri']
        logits = self(video, audio, text)
        # Targets: the tokenizer's integer indices, cast to float for MSE.
        fmri_tokens = self.fmri_tokenizer.encode(fmri).to(dtype=logits.dtype)
        loss = F.mse_loss(logits, fmri_tokens)

        # Metrics only: rounding to integers already cuts the graph, so skip autograd.
        with torch.no_grad():
            recon_fmri = self.fmri_tokenizer.decode(self._predicted_indices(logits))
            avg_pearson_r, variance_explained, mse, mae = self.calculate_metrics(fmri, recon_fmri)

        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_pearson_r', avg_pearson_r)
        self.log(f'{stage}_variance_explained', variance_explained)
        self.log(f'{stage}_mse', mse)
        self.log(f'{stage}_mae', mae)
        return loss

    def training_step(self, batch):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch):
        self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """AdamW with a cosine-annealing learning-rate schedule.

        NOTE(review): ``T_max`` is hard-coded to 100. If training runs for fewer
        scheduler steps (e.g. a 50-epoch run with per-epoch stepping), the cosine
        never reaches its minimum. ``monitor`` only matters for plateau-style
        schedulers.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

In [6]:
# Run configuration for the multimodal MLP, plus the train/val DataLoaders.

# Data loading
batch_size = 4
num_workers = 4

# Input feature widths fed to MultiModalMLP (must match the extracted features)
video_dim = 8192
audio_dim = 128
text_dim = 768

# Model / optimisation
hidden_dim = 1024
num_tokens = 32
learning_rate = 1e-4
weight_decay = 0.01

# Checkpoint of the frozen fMRI VQ-VAE tokenizer that produces the target tokens
fmri_tok_dir = "/home/pranav/mihir/algonauts_challenge/algonauts2025/checkpoints/fmri_vqvae_res_gelu_8qdim_1024/epoch=44_val_loss=0.137.ckpt"

train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)
val_dl = torch.utils.data.DataLoader(
    val_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,  # fixed order for validation
)

In [7]:
# Report the number of samples in each split.
# NOTE(review): the recorded output shows more validation (22924) than
# training (22787) samples — confirm this split is intended.
print("Training samples: ", len(train_ds))
print("Validation samples: ", len(val_ds))

Training samples:  22787
Validation samples:  22924


In [8]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

run_name = "MM_MLP_1024dim_only_sub01"

# Log this run to Weights & Biases; local files go under wandb_logs/.
wandb_logger = WandbLogger(
    name=run_name,
    save_dir="wandb_logs/",
    project="algonauts2025",
)

# 50 full-precision (32-bit) epochs. Checkpoints go to checkpoints/<run_name>/:
# the single best model by val_loss, plus the most recent one (save_last).
trainer = L.Trainer(
    logger=wandb_logger,
    precision=32,
    max_epochs=50,
    callbacks=[
        ModelCheckpoint(
            dirpath=f"checkpoints/{run_name}",
            filename="{epoch:02d}_{val_loss:.3f}",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            save_last=True,
        ),
    ],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
from lightning.pytorch.utilities.model_summary import ModelSummary

# Let PyTorch trade some float32 matmul precision for speed ('high' setting).
torch.set_float32_matmul_precision('high')

# Build the regressor from the configuration cell (this also loads the frozen tokenizer).
mm_mlp = MultiModalMLP(
    fmri_tok_dir=fmri_tok_dir,
    video_dim=video_dim,
    audio_dim=audio_dim,
    text_dim=text_dim,
    hidden_dim=hidden_dim,
    num_tokens=num_tokens,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

# Depth 1 lists the frozen tokenizer and the trainable MLP as separate rows.
summary = ModelSummary(mm_mlp, max_depth=1)
print(summary)

  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | fmri_tokenizer | VQVAE      | 5.4 M  | eval 
1 | mlp            | Sequential | 10.4 M | train
  | other params   | n/a        | 768    | n/a  
------------------------------------------------------
10.4 M    Trainable params
5.4 M     Non-trainable params
15.8 M    Total params
63.157    Total estimated model params size (MB)
8         Modules in train mode
79        Modules in eval mode


In [10]:
# Train, validating on val_dl each epoch (checkpointing is driven by val_loss).
trainer.fit(mm_mlp, train_dataloaders=train_dl, val_dataloaders=val_dl)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmihir-neal[0m ([33mmihirneal[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | fmri_tokenizer | VQVAE      | 5.4 M  | eval 
1 | mlp            | Sequential | 10.4 M | train
  | other params   | n/a        | 768    | n/a  
------------------------------------------------------
10.4 M    Trainable params
5.4 M     Non-trainable params
15.8 M    Total params
63.157    Total estimated model params size (MB)
8         Modules in train mode
79        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
