In [1]:
# Import packages and set up the GPU configuration.
# This cell shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import pandas as pd
import copy
import math
from einops import rearrange
from einops.layers.torch import Rearrange
import time
import random
import h5py
import webdataset as wds
import gc
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import utils
from models import get_vit
import nibabel as nib
from nilearn import plotting

import lightning as pl
from typing import List
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, RichModelSummary, RichProgressBar

# Enable TensorFloat-32 (TF32) math for CUDA matrix multiplications:
# faster than full-precision float32 at slightly reduced precision.
cuda_matmul_backend = torch.backends.cuda.matmul
cuda_matmul_backend.allow_tf32 = True

  _torch_pytree._register_pytree_node(


In [2]:
# Load parameters from yaml config
config = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)

# create global variables from the config
for attribute_name in config.keys():
    globals()[attribute_name] = config[f'{attribute_name}']

# Configuration

In [3]:
print(config)

# seed all random functions
utils.seed_everything(seed)

outdir = os.path.abspath(f'../ckpts/{model_name}')
print("outdir", outdir)

# Start from the configured value rather than the current global so that re-running
# this cell does not keep halving the batch size.
global_batch_size = config['global_batch_size']
if use_contrastive_loss:
    # the contrastive loss duplicates every sample, doubling the effective batch size
    global_batch_size = global_batch_size // 2
print("global_batch_size", global_batch_size)

# The contrastive loss is computed on the encoder CLS token, so it forces the token on.
use_cls_token = True if use_contrastive_loss else use_cls_token
print("use_cls_token", use_cls_token)

# Patch-grid bookkeeping. Integer math keeps the count exact; a spatial dim that is not a
# multiple of patch_size would otherwise be silently truncated.
assert all(dim % patch_size == 0 for dim in img_size), (
    f"img_size {img_size} is not divisible by patch_size {patch_size}"
)
num_patches = math.prod(dim // patch_size for dim in img_size) * num_frames
num_patches_per_timepoint = num_patches // num_frames
num_encoder_patches = int(num_patches_per_timepoint * (1 - tube_start_masking_ratio) * num_frames)
num_decoder_patches = int(num_patches_per_timepoint * (1 - decoder_mask_ratio) * num_frames)
print("num_patches", num_patches)
print("num_encoder_patches", num_encoder_patches)
print("num_decoder_patches", num_decoder_patches)

{'model_name': 'test_ddp_lightning6', 'use_cls_token': False, 'use_contrastive_loss': False, 'constrastive_loss_weight': 1.0, 'global_batch_size': 4, 'num_workers': 8, 'max_steps': 10000, 'eval_per_n_steps': 2000, 'limit_val_batches': 800, 'seed': 42, 'max_lr': 3e-05, 'ckpt_saving': True, 'ckpt_interval': 50, 'resume_from_ckpt': False, 'wandb_log': True, 'tube_start_masking_ratio': 0.75, 'tube_end_masking_ratio': 0.75, 'decoder_mask_ratio': 0.75, 'encoder_model': 'vit_base', 'decoder_model': 'vit_small', 'patch_size': 8, 'frame_patch_size': 1, 'use_rope_emb': False, 'masking_strategy': 'MNI', 'img_size': [88, 104, 72], 'num_frames': 4, 'is_s3': False, 'test_urls': '/weka/proj-fmri/souvik/foundational_model/000000.tar', 'train_urls': '/weka/proj-fmri/souvik/foundational_model/000000.tar'}
outdir /weka/proj-fmri/souvik/foundational_model/ckpts/test_ddp_lightning6
global_batch_size 4
use_cls_token False
num_patches 5148
num_encoder_patches 1287
num_decoder_patches 1287


# DataModule

In [4]:
from dataloader import fMRIDataModule

# Lightning data module built from the configured train/test shard URLs.
datamodule = fMRIDataModule(
    train_urls=train_urls,
    test_urls=test_urls,
    batch_size=global_batch_size,
    num_workers=num_workers,
)

# Training Module

In [5]:
class FMRITrainer(pl.LightningModule):
    """Masked-reconstruction pretraining of a ViT encoder/decoder on fMRI volumes.

    Each step draws a "tube" mask (the same spatial patches for every frame) over
    brain-positive patches for the encoder input. It also draws a disjoint decoder mask
    from the remaining brain-positive patches. The decoder outputs for the decoder-mask
    patches are compared with the patchified input via MSE. With ``use_contrastive_loss``
    the training batch is duplicated. A weighted contrastive loss between the encoder
    CLS tokens of the two halves is then added.
    """

    # MNI152 brain template used by the "MNI" masking strategy (overridable in __init__).
    DEFAULT_MNI_TEMPLATE_PATH = "/weka/proj-fmri/paulscotti/old_fMRI-foundation-model/dataset_creation/afni_conversion/tpl-MNI152NLin2009cAsym_res-02_T1w_brain.nii.gz"

    def __init__(
        self, 
        encoder_model: str, 
        decoder_model: str,
        img_size: List[int],
        patch_size: int,
        num_frames: int,
        frame_patch_size: int,
        use_rope_emb: bool,
        use_cls_token: bool,
        masking_strategy: str,
        max_steps: int,
        tube_start_masking_ratio: float,
        tube_end_masking_ratio: float,
        use_contrastive_loss: bool,
        decoder_mask_ratio: float,
        constrastive_loss_weight: float,
        max_lr: float,
        batch_size: int,
        mni_template_path: str = DEFAULT_MNI_TEMPLATE_PATH,
    ):
        """Build the ViT, the data-prep transform and (for "MNI") the brain-patch map.

        Args:
            encoder_model / decoder_model: size names passed to ``get_vit``.
            img_size: spatial volume size (3 values).
            patch_size: cubic spatial patch edge length.
            num_frames / frame_patch_size: timepoints per sample and temporal patch size.
            use_rope_emb / use_cls_token: forwarded to ``get_vit``. The CLS token is
                forced on when ``use_contrastive_loss`` is set, because that loss reads it.
            masking_strategy: "MNI" uses a fixed template brain mask. Any other value uses
                the brain mask returned by the data-prep transform.
            max_steps: total optimizer steps (masking-ratio schedule and OneCycle LR).
            tube_start_masking_ratio / tube_end_masking_ratio: encoder masking-ratio
                schedule endpoints passed to ``utils.get_masking_ratio``.
            use_contrastive_loss / constrastive_loss_weight: enable and weight the
                contrastive term.
            decoder_mask_ratio: masking ratio used to size the decoder mask.
            max_lr: peak learning rate.
            batch_size: stored on the module.
            mni_template_path: template NIfTI used when ``masking_strategy == "MNI"``.
        """
        super().__init__()
        assert len(img_size) == 3 # 3D volumes
        # shared_step reads the encoder CLS token for the contrastive loss; without it
        # training would fail on an undefined variable.
        use_cls_token = use_cls_token or use_contrastive_loss
        self.max_steps = max_steps
        self.tube_start_masking_ratio = tube_start_masking_ratio
        self.tube_end_masking_ratio = tube_end_masking_ratio
        self.masking_strategy = masking_strategy
        self.num_frames = num_frames
        self.decoder_mask_ratio = decoder_mask_ratio
        self.constrastive_loss_weight = constrastive_loss_weight
        self.max_lr = max_lr
        self.batch_size = batch_size
        self.use_contrastive_loss = use_contrastive_loss
        self.use_cls_token = use_cls_token

        self.model = get_vit(
            size={"encoder": encoder_model, "decoder": decoder_model},
            image_size=img_size,  # depth, height, width
            image_patch_size=(patch_size, patch_size, patch_size),  # depth, height, width patch size
            frames=num_frames,
            frame_patch_size=frame_patch_size,
            channels=1,
            use_rope_emb=use_rope_emb,
            use_cls_token=use_cls_token,
        )
        self.aug_transform = utils.DataPrepper(
            num_frames=num_frames,
            masking_strategy=masking_strategy,
            patch_depth=patch_size,
            patch_height=patch_size,
            patch_width=patch_size,
            frame_patch_size=frame_patch_size,
        )
        self.num_patches = int(
            (img_size[0] / patch_size)
            * (img_size[1] / patch_size)
            * (img_size[2] / patch_size)
            * (num_frames / frame_patch_size)
        )

        if self.masking_strategy == "MNI":
            MNI_brain = nib.load(mni_template_path).get_fdata()
            # Fixed crop of the template to 88 x 104 x 72 voxels; it must match img_size.
            brain_pos_voxels = MNI_brain[6:94, 8:112, 10:82]
            if tuple(brain_pos_voxels.shape) != tuple(img_size):
                raise ValueError(
                    f"MNI crop has shape {tuple(brain_pos_voxels.shape)}, "
                    f"expected img_size {tuple(img_size)}"
                )
            self.brain_pos_pats, self.brain_pos_pats_vit = self._patchify_brain_mask(brain_pos_voxels)

        ## LOSS
        self.mse = nn.MSELoss()
        if use_contrastive_loss:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))  # learned logit scale

    def _patchify_brain_mask(self, brain_pos_voxels):
        """Patchify a 3D brain volume and return ``(patches, per_patch_mean)``.

        ``per_patch_mean`` is the mean value of each patch of the single (first) item,
        flattened over all patch positions. Patches with a positive mean are treated as
        brain-positive when masks are drawn.
        """
        brain_pos_pats = self.model.patchify(torch.Tensor(brain_pos_voxels)[None, None, None])
        brain_pos_pats_vit = rearrange(brain_pos_pats, "b ... d -> b (...) d").mean(-1)[0]
        return brain_pos_pats, brain_pos_pats_vit

    def _build_masks(self, brain_pos_pats_vit, tube_mask_ratio):
        """Draw disjoint encoder (tube) and decoder masks over brain-positive patches.

        Returns ``(tube_mask, decoder_mask)``. Each is a boolean vector where the
        per-frame mask is tiled ``num_frames`` times. Returns ``None`` when there are too
        few brain-positive patches to fill both masks.
        """
        patches_per_frame = self.num_patches // self.num_frames
        # Number of patches per frame kept for the encoder and used as decoder targets.
        num_tube = int(self.num_patches / self.num_frames * (1 - tube_mask_ratio))
        num_decoder = int(self.num_patches / self.num_frames * (1 - self.decoder_mask_ratio))

        mask_idx_candidates = torch.where(brain_pos_pats_vit > 0)[0]
        if len(mask_idx_candidates) < num_tube + num_decoder:
            return None
        mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]

        tube_mask = torch.zeros(patches_per_frame, dtype=torch.bool, device=self.device)
        tube_mask[mask_idx_candidates[:num_tube]] = True

        # The decoder mask takes the next shuffled candidates, so it never overlaps the tube.
        decoder_mask = torch.zeros(patches_per_frame, dtype=torch.bool, device=self.device)
        decoder_mask[mask_idx_candidates[num_tube:num_tube + num_decoder]] = True

        # Same spatial mask for every frame/timepoint.
        return tube_mask.tile(self.num_frames), decoder_mask.tile(self.num_frames)

    def shared_step(self, batch, batch_idx, phase: str):
        """Run one masked-reconstruction step for ``phase`` ("train", "val" or "test").

        Assumes ``batch['func.npy']`` holds the fMRI input accepted by the data-prep
        transform (shape not visible here -- TODO confirm). Returns the total loss, or
        ``None`` when too few brain-positive patches remain to build the masks. In that
        case the step is skipped.
        """
        tube_mask_ratio = utils.get_masking_ratio(
            current_epoch=self.global_step, 
            total_epochs=self.max_steps, 
            start_masking_ratio=self.tube_start_masking_ratio, 
            end_masking_ratio=self.tube_end_masking_ratio
        )

        input_func = batch['func.npy'].to(self.device)
        if self.masking_strategy == "MNI":
            func, _ = self.aug_transform(input_func)
            brain_pos_pats_vit = self.brain_pos_pats_vit
        else:
            func, brain_pos_voxels = self.aug_transform(input_func)
            _, brain_pos_pats_vit = self._patchify_brain_mask(brain_pos_voxels)

        if self.use_contrastive_loss and phase == "train":
            # Create positive pairs by duplicating the batch. The brain-patch map is shared
            # by all samples, so it needs no duplication.
            # NOTE(review): one tube/decoder mask is drawn per step and applied to the whole
            # batch, so both halves of each pair see identical inputs -- confirm intended.
            func = torch.cat([func, func], dim=0)

        func = func.unsqueeze(1)
        batch_size = func.shape[0]

        masks = self._build_masks(brain_pos_pats_vit, tube_mask_ratio)
        if masks is None:
            print("Brain volume skipped due to not enough brain-positive patches remaining...")
            return None
        tube_mask, decoder_mask = masks

        # encode the tube patches
        encoder_out = self.model(func.to(self.device), encoder_mask=tube_mask.to(self.device))
        enc_cls_token = encoder_out[:, :1, :] if self.use_cls_token else None

        # decode both the encoder_out patches and masked decoder patches
        decoder_out = self.model(encoder_out, encoder_mask=tube_mask, decoder_mask=decoder_mask)
        # subset only the reconstructed decoder patches
        output = decoder_out[:, -decoder_mask.sum():]

        # compare to ground truth and calculate loss
        target_patches = self.model.patchify(func)
        target_patches_vit = rearrange(target_patches, "b ... d -> b (...) d")
        target = target_patches_vit[:, decoder_mask]
        loss = self.mse(output, target)

        # Train metrics are logged per step. Val/test metrics are averaged over the epoch
        # and synced across DDP ranks, so an epoch-level `val/loss` covers the whole
        # validation set rather than its last batch.
        is_train = phase == "train"
        log_kwargs = dict(
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            batch_size=batch_size,
            sync_dist=not is_train,
        )
        self.log(f"{phase}/recon_loss", loss, **log_kwargs)

        # contrastive loss between the CLS tokens of the two batch halves
        if self.use_contrastive_loss and phase == "train":
            n_b = len(func) // 2
            cls_token1 = enc_cls_token[:n_b, 0, :]  # first half of batch, cls_token shape B, 1, d_model
            cls_token2 = enc_cls_token[n_b:, 0, :]
            # NOTE(review): logit_scale (initialised to log(1/0.07)) is passed as
            # `temperature`; confirm this matches utils.contrastive_loss's convention.
            contrastive_loss = utils.contrastive_loss(cls_token1, cls_token2, temperature=self.logit_scale)
            cnt_loss = self.constrastive_loss_weight * contrastive_loss
            self.log(f"{phase}/contrastive_loss", cnt_loss, **log_kwargs)
            # Out-of-place add so the tensor already logged as recon_loss is untouched.
            loss = loss + cnt_loss
        self.log(f"{phase}/loss", loss, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; delegates to ``shared_step``."""
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; delegates to ``shared_step``."""
        return self.shared_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Lightning test hook; delegates to ``shared_step``."""
        return self.shared_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """Use AdamW with a decay/no-decay split and a OneCycle LR stepped every optimizer step.

        Biases, LayerNorm parameters (matched by name) and the learned contrastive
        ``logit_scale`` get no weight decay. All other model parameters use AdamW's
        default weight decay.
        """
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        # NOTE(review): matching is by parameter name; confirm the ViT's norm layers are
        # actually named "LayerNorm", otherwise their weights are still decayed.
        decay_params = [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)]
        no_decay_params = [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)]
        if self.use_contrastive_loss:
            # logit_scale lives on this module, not on self.model, so it must be added
            # explicitly or it is never updated.
            no_decay_params.append(self.logit_scale)
        opt_grouped_parameters = [
            {'params': decay_params},
            # Without an explicit value this group would inherit AdamW's default decay.
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        self.optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=self.max_lr)
        self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.max_lr,
            total_steps=self.max_steps
        )
        # total_steps is counted in optimizer steps, so step the scheduler every step
        # (Lightning's default scheduler interval is "epoch").
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {"scheduler": self.lr_scheduler, "interval": "step", "frequency": 1},
        }

# Callbacks and trainer
> Currently using:
> - Checkpoint saving (keeps the best `val/loss`)
> - LR monitor to log the learning rate
> - Rich model summary
> - Rich progress bar

In [6]:
# Keep the single best checkpoint by validation loss. The metric name contains "/", so
# the value is formatted into an explicit filename with auto_insert_metric_name disabled.
# Otherwise Lightning would insert "val/loss=..." and create an extra sub-directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=outdir,
    filename='epoch={epoch}-val_loss={val/loss:.5f}',
    auto_insert_metric_name=False,
    save_top_k=1,
    verbose=True,
    monitor='val/loss',
    mode='min',
)
callbacks = [
    LearningRateMonitor(),
    checkpoint_callback,
    RichModelSummary(max_depth=2),
    RichProgressBar(),
]
trainer = Trainer(
    devices="auto",
    num_nodes=int(os.getenv('NUM_NODES', 1)),
    precision="16-mixed",
    logger=WandbLogger(project='found', id=model_name, name=model_name),
    callbacks=callbacks,
    max_steps=max_steps,
    # Trainer has no `ddp` argument; the DDP strategy is passed via `strategy`.
    strategy=DDPStrategy(process_group_backend="gloo"),
    log_every_n_steps=10,
    # Optional config-driven validation controls:
    # val_check_interval=eval_per_n_steps,
    # limit_val_batches=limit_val_batches,
)

/admin/home-kaladin/fmri/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /admin/home-kaladin/fmri/lib/python3.11/site-pac ...
Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Instantiate the LightningModule from the config globals (arguments listed in the
# same order as the FMRITrainer constructor).
module = FMRITrainer(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    img_size=img_size,
    patch_size=patch_size,
    num_frames=num_frames,
    frame_patch_size=frame_patch_size,
    use_rope_emb=use_rope_emb,
    use_cls_token=use_cls_token,
    masking_strategy=masking_strategy,
    max_steps=max_steps,
    tube_start_masking_ratio=tube_start_masking_ratio,
    tube_end_masking_ratio=tube_end_masking_ratio,
    use_contrastive_loss=use_contrastive_loss,
    decoder_mask_ratio=decoder_mask_ratio,
    constrastive_loss_weight=constrastive_loss_weight,
    max_lr=max_lr,
    batch_size=global_batch_size,
)

In [8]:
# Train (and periodically validate) the module on the data module's loaders.
trainer.fit(model=module, datamodule=datamodule)

[34m[1mwandb[0m: Currently logged in as: [33mmandalsouvik76[0m ([33msouvikmandal[0m). Use [1m`wandb login --relogin`[0m to force relogin


Adding decoder ImageHandler to decoders.
Adding decoder ImageHandler to decoders.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2,5,6]


Output()