In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from vector_quantize_pytorch import VectorQuantize

In [3]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
   def __init__(self, in_dim, out_dim):
       super().__init__()
       self.downsample = in_dim != out_dim
       self.net = nn.Sequential(
           nn.Linear(in_dim, out_dim),
           nn.LayerNorm(out_dim),
           nn.GELU(),
           nn.Linear(out_dim, out_dim),
           nn.LayerNorm(out_dim),
           nn.GELU()
       )
       if self.downsample:
           self.proj = nn.Linear(in_dim, out_dim)
   
   def forward(self, x):
       if self.downsample:
           return self.proj(x) + self.net(x)
       return x + self.net(x)

class Encoder(nn.Module):
   def __init__(self, input_dim=1000, hidden_dims=[512, 384, 256], num_tokens=32, codebook_dim=64):
       super().__init__()
       
       # Initial projection with one residual block
       self.input_proj = ResidualBlock(input_dim, hidden_dims[0])
       
       # Main network with one residual block per layer
       layers = []
       for i in range(len(hidden_dims)-1):
           layers.append(ResidualBlock(hidden_dims[i], hidden_dims[i+1]))
       self.layers = nn.Sequential(*layers)
       
       # Project to token space with one residual block
       self.token_proj = ResidualBlock(hidden_dims[-1], num_tokens * codebook_dim)
       
       self.num_tokens = num_tokens
       self.codebook_dim = codebook_dim
       
   def forward(self, x):
       x = self.input_proj(x)
       x = self.layers(x)
       x = self.token_proj(x)
       return x.view(x.shape[0], self.num_tokens, self.codebook_dim)

class Decoder(nn.Module):
   def __init__(self, output_dim=1000, hidden_dims=[256, 384, 512], num_tokens=32, codebook_dim=64):
       super().__init__()
       
       # Process tokens with one residual block
       self.token_proj = ResidualBlock(num_tokens * codebook_dim, hidden_dims[0])
       
       # Main network with one residual block per layer
       layers = []
       for i in range(len(hidden_dims)-1):
           layers.append(ResidualBlock(hidden_dims[i], hidden_dims[i+1]))
       self.layers = nn.Sequential(*layers)
       
       # Final projection with one residual block
       self.output_proj = ResidualBlock(hidden_dims[-1], output_dim)
       
   def forward(self, x):
       # x shape: [batch_size, num_tokens, codebook_dim] 
       x = x.reshape(x.shape[0], -1)  # Flatten tokens
       x = self.token_proj(x)
       x = self.layers(x)
       return self.output_proj(x)

In [8]:
import lightning as L
from torchmetrics.functional import pearson_corrcoef

class VQVAE(L.LightningModule):
    def __init__(
            self, 
            input_dim=1000, 
            hidden_dims=[512, 384, 256], 
            num_tokens=32, 
            codebook_size=1024, 
            codebook_dim=8,
            commitment_weight=0.25,
            quantizer_decay=0.99,
            learning_rate=3e-4,
            weight_decay=0.01
            ):
        super().__init__()
        
        self.encoder = Encoder(input_dim, hidden_dims, num_tokens, codebook_dim)
        self.decoder = Decoder(input_dim, hidden_dims[::-1], num_tokens, codebook_dim)
        self.quantizer = VectorQuantize(
                dim=codebook_dim,
                codebook_size=codebook_size,
                decay=quantizer_decay,
                commitment_weight=commitment_weight
                )
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()
        
    def forward(self, x):
        z = self.encoder(x)
        z_q, indices, commitment_loss = self.quantizer(z)
        x_recon = self.decoder(z_q)
        return x_recon, commitment_loss, indices
    
    def calculate_metrics(self, x, x_recon):
        # Flatten the tensors for correlation calculation
        x_flat = x.reshape(x.shape[0], -1)
        x_recon_flat = x_recon.reshape(x_recon.shape[0], -1)
        
        # Calculate Pearson R for each sample in batch
        correlations = torch.stack([
            pearson_corrcoef(x_flat[i], x_recon_flat[i])
            for i in range(x_flat.shape[0])
        ])
        avg_pearson_r = correlations.mean()
        
        # Calculate variance explained
        total_variance = torch.var(x_flat, dim=1).sum()
        residual_variance = torch.var(x_flat - x_recon_flat, dim=1).sum()
        variance_explained = 1 - (residual_variance / total_variance)
        
        return avg_pearson_r, variance_explained


    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        x = batch[0]
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        x_recon, commitment_loss, _ = self(x)
        
        recon_loss = F.mse_loss(x_recon, x)
        total_loss = recon_loss + commitment_loss

        pearson_r, var = self.calculate_metrics(x, x_recon)
        
        # Log metrics
        self.log('train_loss', total_loss)
        self.log('train_recon_loss', recon_loss)
        self.log('train_commitment_loss', commitment_loss)
        self.log('train_pearson_r', pearson_r)
        self.log('train_variance', var)
        
        return total_loss
    
    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        x = batch[0]
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        x_recon, commitment_loss, _ = self(x)
        
        recon_loss = F.mse_loss(x_recon, x)
        total_loss = recon_loss + commitment_loss

        val_pearson_r, val_var = self.calculate_metrics(x, x_recon)
        
        # Log metrics
        self.log('val_loss', total_loss)
        self.log('val_recon_loss', recon_loss)
        self.log('val_commitment_loss', commitment_loss)
        self.log('val_pearson_r', val_pearson_r)
        self.log('val_variance', val_var)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,  # Learning rate
            betas=(0.9, 0.999),  # Beta parameters
            eps=1e-8,  # Epsilon to prevent division by zero
            weight_decay=self.weight_decay  # Weight decay for regularization
        )
        return optimizer


In [9]:
import h5py
import glob
from torch.utils.data import Dataset, DataLoader

class FMRIDataset(Dataset):
    def __init__(self, h5_file):
        data_list = []
        try:
            with h5py.File(h5_file, 'r') as f:
                for key in f.keys():
                    data = torch.from_numpy(f[key][:]).float()
                    data_list.append(data)
        except Exception as e:
            print(f"Error loading {h5_file}: {str(e)}")
        self.data = torch.cat(data_list, dim=0)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# Create separate datasets for train and val
train_file = '/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/sub-01/func/sub-01_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5'
val_file = '/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/sub-01/func/sub-01_task-movie10_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_bold.h5'

train_dataset = FMRIDataset(train_file)
val_dataset = FMRIDataset(val_file)

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [10]:
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


Training samples: 137913
Validation samples: 24758


In [11]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

run_name = "fmri_vqvae_res_gelu_8qdim_1024"
wandb_logger = WandbLogger(
    project="algonauts2025",
    name=run_name,
    save_dir="wandb_logs/"
)
trainer = L.Trainer(
    max_epochs=50,
    precision=32,
    logger=wandb_logger,
    callbacks=[
        ModelCheckpoint(
            dirpath=f'checkpoints/{run_name}',
            filename='{epoch:02d}_{val_loss:.3f}',
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            save_last=True
        )
    ]
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
from lightning.pytorch.utilities.model_summary import ModelSummary
torch.set_float32_matmul_precision('high')
model = VQVAE()
summary = ModelSummary(model, max_depth=2)
print(summary)

   | Name                  | Type              | Params | Mode 
---------------------------------------------------------------------
0  | encoder               | Encoder           | 2.2 M  | train
1  | encoder.input_proj    | ResidualBlock     | 1.3 M  | train
2  | encoder.layers        | Sequential        | 807 K  | train
3  | encoder.token_proj    | ResidualBlock     | 132 K  | train
4  | decoder               | Decoder           | 3.2 M  | train
5  | decoder.token_proj    | ResidualBlock     | 132 K  | train
6  | decoder.layers        | Sequential        | 1.0 M  | train
7  | decoder.output_proj   | ResidualBlock     | 2.0 M  | train
8  | quantizer             | VectorQuantize    | 0      | train
9  | quantizer.project_in  | Identity          | 0      | train
10 | quantizer.project_out | Identity          | 0      | train
11 | quantizer._codebook   | EuclideanCodebook | 0      | train
---------------------------------------------------------------------
5.4 M     Trainable params
0

In [13]:
model

VQVAE(
  (encoder): Encoder(
    (input_proj): ResidualBlock(
      (net): Sequential(
        (0): Linear(in_features=1000, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): GELU(approximate='none')
        (3): Linear(in_features=512, out_features=512, bias=True)
        (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (5): GELU(approximate='none')
      )
      (proj): Linear(in_features=1000, out_features=512, bias=True)
    )
    (layers): Sequential(
      (0): ResidualBlock(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=384, bias=True)
          (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
          (3): Linear(in_features=384, out_features=384, bias=True)
          (4): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (5): GELU(approximate='none')
        )
        (proj): Linear(in_features=512, out_f

In [14]:
trainer.fit(model, train_loader, val_loader)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmihir-neal[0m ([33mmihirneal[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params | Mode 
-----------------------------------------------------
0 | encoder   | Encoder        | 2.2 M  | train
1 | decoder   | Decoder        | 3.2 M  | train
2 | quantizer | VectorQuantize | 0      | train
-----------------------------------------------------
5.4 M     Trainable params
0         Non-trainable params
5.4 M     Total params
21.596    Total estimated model params size (MB)
78        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [15]:
# model = VQVAE()
model = VQVAE.load_from_checkpoint("/home/pranav/mihir/algonauts_challenge/algonauts2025/checkpoints/fmri_vqvae_res_gelu_8qdim_1024/epoch=44_val_loss=0.137.ckpt")
model.eval()
val_data = []
val_recon = []
with torch.no_grad():
    for batch in val_loader:
        x = batch[0].to(model.device)
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        x_recon, _, _ = model(x)
        val_data.append(x)
        val_recon.append(x_recon)

val_data = torch.cat(val_data)
val_recon = torch.cat(val_recon)

#Calculating dead codes and perplexity
from tqdm import tqdm

# Initialize counters
codebook_size = 1024
counts = torch.zeros(codebook_size, device="cuda")  # Use "cpu" if needed

# Run inference on the test set
model.eval()
with torch.no_grad():
    for batch in tqdm(val_loader):
        x = batch.to(model.device)
        z_e = model.encoder(x)  # Get encoder output
        z_q, indices, _ = model.quantizer(z_e)  # Assuming `vq_layer` returns indices
        indices_flat = indices.view(-1)  # Flatten all batch/spatial dimensions
        counts += torch.bincount(indices_flat, minlength=codebook_size)

# Compute metrics
num_dead_codes = (counts == 0).sum().item()
prob = counts / (counts.sum() + 1e-10)
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))

print(f"Dead codes: {num_dead_codes} / {codebook_size}")
print(f"Perplexity: {perplexity.item():.1f} (max: {codebook_size})")


100%|██████████| 774/774 [00:01<00:00, 524.64it/s]

Dead codes: 0 / 1024
Perplexity: 870.3 (max: 1024)





In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plot code frequencies
plt.figure(figsize=(10, 4))
sns.histplot(counts.cpu().numpy(), bins=50, log_scale=(False, True))
plt.xlabel("Code Frequency (log scale)")
plt.title("Code Usage Distribution")
plt.show()

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

codebook_vectors = model.quantizer.codebook.cpu().numpy()  # Shape [1024, 8]
similarity = cosine_similarity(codebook_vectors)

plt.figure(figsize=(8, 8))
plt.imshow(similarity, cmap="viridis", vmin=-1, vmax=1)
plt.colorbar(label="Cosine Similarity")
plt.title("Codebook Vector Similarity")
plt.show()

In [None]:
from sklearn.manifold import TSNE

# Collect codebook vectors and their usage counts
codebook = model.quantizer.codebook.cpu().numpy()
counts_np = counts.cpu().numpy()

# Run t-SNE
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
codebook_2d = tsne.fit_transform(codebook)

# Plot with color indicating usage frequency
plt.figure(figsize=(10, 8))
plt.scatter(codebook_2d[:, 0], codebook_2d[:, 1], c=np.log(counts_np + 1), cmap="viridis")
plt.colorbar(label="Log(Usage Count + 1)")
plt.title("Codebook t-SNE Projection")
plt.show()

In [17]:
var_explained = 1 - torch.var(val_data - val_recon) / torch.var(val_data)
pearson_r = torch.corrcoef(torch.stack([val_data.flatten(), val_recon.flatten()]))[0,1]
mse = torch.mean((val_data - val_recon) ** 2)
mae = torch.mean(torch.abs(val_data - val_recon))

print(f"Variance explained: {var_explained:.3f}")
print(f"Pearson correlation: {pearson_r:.3f}")
print(f"MSE: {mse:.3f}")
print(f"MAE: {mae:.3f}")

Variance explained: 0.628
Pearson correlation: 0.792
MSE: 0.137
MAE: 0.290


In [18]:
def analyze_codebook(model, val_loader, device='cuda'):
    model.eval()
    usage_count = torch.zeros(model.quantizer.codebook_size, device=device)
    total_tokens = 0
    distances = []
    
    with torch.no_grad():
        for batch in val_loader:
            x = batch[0].to(device)
            if len(x.shape) == 1:
                x = x.unsqueeze(0)
            # Get encoder output
            z = model.encoder(x)
            # Get quantizer outputs before quantization
            z_flat = z.reshape(-1, z.shape[-1])
            _, indices, _ = model.quantizer(z_flat)
            
            # Update usage count
            unique, counts = torch.unique(indices, return_counts=True)
            usage_count[unique] += counts
            total_tokens += indices.numel()
            
            # Calculate distances to assigned codebook vectors
            dists = torch.cdist(z_flat, model.quantizer.codebook)
            distances.append(dists.min(dim=1)[0])
    
    # Calculate metrics
    usage_prob = usage_count / total_tokens
    active_codes = (usage_count > 0).sum().item()
    entropy = -(usage_prob * torch.log2(usage_prob + 1e-10)).sum().item()
    distances = torch.cat(distances)
    
    print(f"Active codebook vectors: {active_codes}/{model.quantizer.codebook_size}")
    print(f"Codebook entropy: {entropy:.2f} bits")
    print(f"Mean distance to codebook: {distances.mean():.4f}")
    print(f"Distance std: {distances.std():.4f}")
    
    return {
        'usage_count': usage_count.cpu(),
        'usage_prob': usage_prob.cpu(),
        'active_codes': active_codes,
        'entropy': entropy,
        'distances': distances.cpu(),
    }

In [None]:
analyze_codebook(model, val_loader)

In [None]:
model.quantizer.weight