In [1]:
# Import packages and setup gpu configuration.
# This code block shouldn't need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import pandas as pd
import copy
import math
from einops import rearrange
from einops.layers.torch import Rearrange
import time
import random
import h5py
import webdataset as wds
import gc
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import utils
import models
import nibabel as nib
from nilearn import plotting

# tf32 matmuls are faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True


def _int_from_env(var_name, fallback):
    """Return environment variable `var_name` converted to int, or `fallback` when it is unset."""
    raw_value = os.getenv(var_name)
    return fallback if raw_value is None else int(raw_value)


### Multi-GPU config ###
device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")

local_rank = _int_from_env('LOCAL_RANK', 0)
print(f"LOCAL RANK={local_rank}")

num_devices = _int_from_env('NUM_GPUS', 1)
print(f"NUM GPUS={num_devices}")
# more than one requested GPU switches the notebook into distributed mode
distributed = num_devices > 1
if distributed:
    # the number of requested GPUs has to match what CUDA reports
    assert device_count == num_devices

node = _int_from_env('SLURM_NODEID', 0)
print(f"NODE={node}")

global_rank = _int_from_env('RANK', 0)
print(f"GLOBAL RANK={global_rank}")

world_size = _int_from_env('WORLD_SIZE', 1)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

# Load parameters from yaml config.
# The context manager closes the file handle as soon as parsing is done
# (the bare `open(...)` previously left it open for the lifetime of the kernel).
# NOTE(review): FullLoader can construct more than plain scalars/lists/dicts; if
# config.yaml only holds plain values, yaml.safe_load would be the safer choice.
with open('config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# create global variables from the config: every top-level key becomes a
# notebook-level variable of the same name (e.g. `batch_size`, `seed`)
for attribute_name, attribute_value in config.items():
    globals()[attribute_name] = attribute_value

data_type = torch.float16 # change depending on your mixed_precision
# batch_size = global_batch_size // num_devices
# `batch_size` (from the config) is per process; the global batch spans all processes
global_batch_size = batch_size * world_size

# FSDP Setup
if distributed:
    # distributed-only imports; the names are used by the FSDP wrapping cell further down
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy
    import functools
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
    print("starting init_process_group...")
    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)
    print(f"setting device to cuda:{local_rank}")
    try:
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda',local_rank)
        print(f"\nSuccessfully set cuda:{local_rank} | global_rank{global_rank} | node{node}")
    except Exception as error:
        print(f"\nFAILED TO SET DEVICE cuda:{local_rank} | global_rank{global_rank} | node{node}")
        print("An exception occurred:", error)
        # Re-raise: carrying on would leave `device` undefined and fail a few lines
        # later with an unrelated NameError that hides the real cause.
        raise

else:
    # single-process run: use the default CUDA device
    device = torch.device('cuda')

print("PID of this process =",os.getpid())
print("device =", device, "distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)


Number of available CUDA devices: 1
LOCAL RANK=0
NUM GPUS=1
NODE=0
GLOBAL RANK=0
WORLD_SIZE=1
PID of this process = 3914273
device = cuda distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16


# Configuration

In [2]:
print(config)

# seed all random functions
utils.seed_everything(seed)

# checkpoints of this run are stored in a directory named after the model
outdir = os.path.abspath('../ckpts/{}'.format(model_name))
print("outdir", outdir)

# patches along each spatial axis of `img_size` (depth, height, width)
depth_patches = img_size[0] / patch_size
height_patches = img_size[1] / patch_size
width_patches = img_size[2] / patch_size
num_patches = int(depth_patches * height_patches * width_patches * num_frames)
num_patches_per_timepoint = num_patches // num_frames

# patches that remain visible to the encoder at the starting tube masking ratio
num_encoder_patches = int(
    num_patches_per_timepoint * (1 - tube_start_masking_ratio) * num_frames
)

print("num_patches", num_patches)
print("num_encoder_patches", num_encoder_patches)

{'model_name': 'apr17_encoder32_decoder32_randomtubemask_fixdecodemask_learnableposemb_lr4e-6', 'use_cls_token': False, 'use_contrastive_loss': False, 'constrastive_loss_weight': 1.0, 'batch_size': 8, 'num_workers': 10, 'num_epochs': 1000, 'seed': 42, 'max_lr': 4e-06, 'num_samples_per_epoch': 1024, 'cache_dir': 'cache', 'ckpt_saving': True, 'ckpt_interval': 50, 'resume_from_ckpt': True, 'wandb_log': True, 'wandb_group_name': 'mamba', 'tube_start_masking_ratio': 0.75, 'tube_end_masking_ratio': 0.75, 'decoder_mask_ratio': 0.75, 'encoder_depth': 32, 'decoder_depth': 32, 'patch_size': 8, 'frame_patch_size': 1, 'encoder_outdim': 6, 'decoder_outdim': 512, 'use_rope_emb': False, 'masking_strategy': 'MNI', 'img_size': [88, 104, 72], 'is_random_tube_mask_ratio': True, 'num_frames': 4, 'is_s3': False, 'train_urls': ['/scratch/gpfs/KNORMAN/nsdfoundation/wds/{000005..000099}.tar'], 'test_urls': ['/scratch/gpfs/KNORMAN/nsdfoundation/wds/{000000..000004}.tar'], 'test_num_iterations_per_epoch': 10, '

# Prep models

In [3]:

# Build the model through models.get_mamba; sizes and depths come from the config.
model = models.get_mamba(
    "middle",
    channels=1,
    img_size=img_size,  # depth, height, width
    patch_size=(patch_size,) * 3,  # same patch edge length on all three spatial axes
    num_frames=num_frames,
    frame_patch_size=frame_patch_size,
    device=device,
    embed_dim=512,
    encoder_outdim=encoder_outdim,
    decoder_outdim=decoder_outdim,
    encoder_depth=encoder_depth,
    decoder_depth=decoder_depth,
)

utils.count_params(model)

# move the model to the target device
model = model.to(device)

# preprocessing applied to every raw batch before patchifying
aug_transform = utils.DataPrepper(
    masking_strategy="conservative",
    patch_depth=patch_size,
    patch_height=patch_size,
    patch_width=patch_size,
    frame_patch_size=frame_patch_size,
)

Use checkpoint: False
Checkpoint number: 0
512 rms_norm True residual_in_fp32 True fused_add_norm True bimamba True ssm_cfg None
param counts:
117,486,598 total
117,486,598 trainable


# Create dataset and data loaders

In [4]:
def log_and_continue(exn):
    """Webdataset error handler: print a warning for `exn` and return True so the exception is ignored."""
    warning = f'Handling webdataset error ({exn!r}). Ignoring.'
    print(warning)
    return True

def filter_corrupted_images(sample):
    """Keep a sample only when it contains the required "func.npy" entry."""
    return "func.npy" in sample

### ================      Train / Test Datasets and DataLoaders    ====================
from braceexpand import braceexpand
print(train_urls)


def _expand_shard_urls(url_patterns):
    """Brace-expand every pattern in `url_patterns` into one flat list of shard URLs.

    When the `is_s3` config flag is set, each URL is wrapped as
    `pipe:aws s3 cp <url> -`; otherwise it is used as a plain string.
    """
    shard_urls = [url for pattern in url_patterns for url in braceexpand(pattern)]
    if is_s3:
        return [f"pipe:aws s3 cp {url} -" for url in shard_urls]
    return [str(url) for url in shard_urls]


def _build_func_dataset(shard_urls):
    """Create the webdataset pipeline that yields (key, func) tuples read from `shard_urls`.

    Samples without a "func.npy" entry are dropped, and errors raised while reading
    are reported by `log_and_continue` instead of aborting the iteration.
    """
    return (
        wds.WebDataset(shard_urls, resampled=True, nodesplitter=wds.split_by_node, handler=log_and_continue)
        .shuffle(100, initial=100, rng=random.Random(seed))
        .select(filter_corrupted_images)
        .decode("torch")
        .rename(key="__key__", func="func.npy")
        .to_tuple(*("key","func"))
    )


def _build_loader(dataset):
    """Wrap `dataset` in a DataLoader that yields full batches of `batch_size` samples."""
    # NOTE(review): the `num_workers` config value is not passed here, so the
    # DataLoader default is used — confirm that is intended.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)


# Train and test share one pipeline definition; only the shard URLs differ.
expanded_urls = _expand_shard_urls(train_urls)
train_data = _build_func_dataset(expanded_urls)
train_dl = _build_loader(train_data)

expanded_urls = _expand_shard_urls(test_urls)
test_data = _build_func_dataset(expanded_urls)
test_dl = _build_loader(test_data)

['/scratch/gpfs/KNORMAN/nsdfoundation/wds/{000005..000099}.tar']


### Check data loaders work and calculate number of iterations per epoch

In [5]:
if not distributed:
    # Smoke test: pull a couple of batches from the train loader and time it
    start_time = time.time()
    num_it = 2
    print(f"Yielding {num_it} batches")

    # zip with range(num_it) stops after exactly `num_it` batches
    for i, batch in zip(range(num_it), train_dl):
        print("iter",i)
        key, input_func = batch

    print("Done!")
    print("input_func", input_func.shape)

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")

Yielding 2 batches
iter 0
iter 1
Done!
input_func torch.Size([8, 32, 88, 104, 72])
Execution time: 4.834169149398804 seconds


# Playing with the data, visualization of patching + masking

In [6]:
if masking_strategy=="MNI":
    # NOTE(review): absolute, user-specific path — consider deriving it from the
    # `cache_dir` config entry so the notebook runs on other machines.
    mni_template_path = "/scratch/gpfs/qanguyen/mamba_fmri/fMRI-MAE/cache/tpl-MNI152NLin2009cAsym_res-02_T1w_brain.nii.gz"
    MNI_brain = nib.load(mni_template_path).get_fdata()

    # the slice bounds span 88, 104 and 72 voxels along the three axes
    brain_pos_voxels = MNI_brain[6:94, 8:112, 10:82]

    # add three leading singleton axes before patchifying the template
    brain_pos_tensor = torch.Tensor(brain_pos_voxels)[None, None, None]
    brain_pos_pats = model.patchify(brain_pos_tensor)

    # merge all middle axes into one, average over the last axis, take batch item 0
    brain_pos_pats_flat = rearrange(brain_pos_pats, "b ... d -> b (...) d")
    brain_pos_pats_vit = brain_pos_pats_flat.mean(-1)[0]

In [7]:
# if utils.is_interactive():
#     # extract func volumes and their reference mean and standard deviation volumes
#     if masking_strategy=="MNI":
#         func, _ = aug_transform(input_func)
#     else:
#         func, brain_pos_voxels = aug_transform(input_func)
#         brain_pos_pats = model.patchify(torch.Tensor(brain_pos_voxels)[None,None,None])
#         brain_pos_pats_vit = rearrange(brain_pos_pats, "b ... d -> b (...) d").mean(-1)[0]
#     func = func.reshape(-1, num_frames, func.shape[-3], func.shape[-2], func.shape[-1])
#     func = func.unsqueeze(1)  # add empty first dimension to serve as 1d channel dimension

#     # patchify func samples
#     print("func", func.shape)
#     patches = model.patchify(func)
#     print("patches", patches.shape)
#     patches_vit = rearrange(patches, "b ... d -> b (...) d")
#     print("patches_vit", patches_vit.shape)
#     print("num patches in one timepoint", patches_vit.shape[1] // num_frames)

#     # start by masking everything (aka include nothing)
#     tube_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
#     # approximate brain positive patches for the whole batch
#     batch_positive_approx = (brain_pos_pats_vit > 0)
#     mask_idx_candidates = torch.where(batch_positive_approx)[0]
#     mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]
#     print("Percentage of brain positive patches", len(mask_idx_candidates) / len(batch_positive_approx))
#     tube_idx = mask_idx_candidates[: int(num_patches / num_frames * (1 - tube_start_masking_ratio))]
#     print("num tube patches =", len(tube_idx))
#     tube_mask[tube_idx] = True  # Trues mean to include the patch, False means to remove the patch
#     tube_mask = tube_mask.tile(num_frames)  # repeat masking for the other timepoints
#     print("tube mask percent", tube_mask.sum().item() / len(tube_mask))

#     # create decoder mask similar to tube mask, but ensure no overlap
#     decoder_mask = torch.zeros(num_patches // num_frames).to(torch.bool)  # start by masking everything (aka include nothing)
#     remaining_mask_idx = mask_idx_candidates[int(num_patches / num_frames * (1 - tube_start_masking_ratio)) :]  # brain positive tokens not selected for the encoder tokens
#     decoder_mask_idx = remaining_mask_idx[:int(num_patches / num_frames * (1 - decoder_mask_ratio))]
#     print("num decoder patches =", len(decoder_mask_idx))
#     decoder_mask[decoder_mask_idx] = True
#     decoder_mask = decoder_mask.tile(num_frames)  # repeat masking for the other timepoints
#     print("decoder_mask percent", decoder_mask.sum().item() / len(decoder_mask))

#     # apply masks to patches_vit
#     tube_patches_vit = copy.deepcopy(patches_vit.detach())
#     decoder_patches_vit = copy.deepcopy(patches_vit.detach())
#     # tube_patches_vit[:, tube_mask] = 1
#     # decoder_patches_vit[:, decoder_mask] = 1
#     tube_patches_vit[:, ~tube_mask] = 0
#     decoder_patches_vit[:, ~decoder_mask] = 0

#     # undo patchification so we can visualize
#     tube_unpatches = rearrange(
#         tube_patches_vit,
#         "b (f d h w) c -> b f d h w c",
#         d=img_size[0]//patch_size,
#         h=img_size[1]//patch_size,
#         w=img_size[2]//patch_size,
#     )
#     decoder_unpatches = rearrange(
#         decoder_patches_vit,
#         "b (f d h w) c -> b f d h w c",
#         d=img_size[0]//patch_size,
#         h=img_size[1]//patch_size,
#         w=img_size[2]//patch_size,
#     )
#     print("tube_unpatches", tube_unpatches.shape)
#     print("decoder_unpatches", decoder_unpatches.shape)
    
#     encoder_func = rearrange(
#         tube_unpatches,
#         "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
#         b=len(func),
#         f=num_frames,
#         d=img_size[0] // patch_size,
#         h=img_size[1] // patch_size,
#         w=img_size[2] // patch_size,
#         pd=patch_size,
#         ph=patch_size,
#         pw=patch_size,
#         pf=frame_patch_size,
#     )
#     decoder_func = rearrange(
#         decoder_unpatches,
#         "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
#         b=len(func),
#         f=num_frames,
#         d=img_size[0] // patch_size,
#         h=img_size[1] // patch_size,
#         w=img_size[2] // patch_size,
#         pd=patch_size,
#         ph=patch_size,
#         pw=patch_size,
#         pf=frame_patch_size,
#     )
#     print("encoder_func", encoder_func.shape)
#     print("decoder_func", decoder_func.shape)
    
#     brain_pos_vit = copy.deepcopy(patches_vit.detach())
#     brain_pos_vit[:,batch_positive_approx.repeat(num_frames)] = 1
#     brain_pos_vit[:,~batch_positive_approx.repeat(num_frames)] = 0
#     brain_pos_unpatches = rearrange(
#         brain_pos_vit,
#         "b (f d h w) c -> b f d h w c",
#         d=img_size[0]//patch_size,
#         h=img_size[1]//patch_size,
#         w=img_size[2]//patch_size,
#     )
#     brain_pos_func = rearrange(
#         brain_pos_unpatches,
#         "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
#         b=len(func),
#         f=num_frames,
#         d=img_size[0] // patch_size,
#         h=img_size[1] // patch_size,
#         w=img_size[2] // patch_size,
#         pd=patch_size,
#         ph=patch_size,
#         pw=patch_size,
#         pf=frame_patch_size,
#     )

#     # Visualize
#     idx = 0
#     print("original func")
#     display(transforms.ToPILImage()(utils.reshape_to_2d(func[idx].clamp(0,1))))
    
#     print("\nbrain-positive patches")
#     display(transforms.ToPILImage()(utils.reshape_to_2d(brain_pos_func[idx].clamp(0,1))))

#     print("\nencoder func")
#     display(transforms.ToPILImage()(utils.reshape_to_2d(encoder_func[idx].clamp(0,1))))

#     print("\ndecoder func")
#     display(transforms.ToPILImage()(utils.reshape_to_2d(decoder_func[idx].clamp(0,1))))

# FSDP / optimizer / saving functions


In [8]:
if distributed:
    from mamba_ssm.modules.mamba_simple import Block

    # Auto-wrap policy keyed on the `Block` layer class
    # (alternative kept for reference: size_based_auto_wrap_policy with min_num_params=200000)
    my_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={Block},
    )

    print(f"\nPrepping FSDP on {global_rank} {node}...\n")
    fsdp_options = dict(
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        auto_wrap_policy=my_auto_wrap_policy,
        use_orig_params=False,
        cpu_offload=None,  # CPUOffload(offload_params=True)
        sync_module_states=True,
        limit_all_gathers=True,  # See https://github.com/pytorch/pytorch/issues/91165
        device_id=device,
    )
    model = FSDP(model, **fsdp_options)
    print(f"\nSuccessfully loaded FSDP model to device on global_rank {global_rank}\n")

    # wait until every rank has finished wrapping its model
    dist.barrier()
    print(f"\nSuccessfully loaded FSDP model to device on global_rank {global_rank}\n")

In [9]:
# Parameters whose name contains one of these substrings get no weight decay;
# every other parameter is decayed with 1e-2.
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
opt_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

if distributed:
    # Scale the learning rate with the global batch size.
    # Keep the unscaled value so the log line shows both numbers (it used to print
    # the already-scaled value twice).
    # NOTE(review): `max_lr` is overwritten in place, so re-running this cell
    # multiplies it again.
    base_lr = max_lr
    max_lr = base_lr * global_batch_size
    print(f"multiply lr {base_lr} by global batch size: max_lr={max_lr}")
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
num_iterations_per_epoch = num_samples_per_epoch // global_batch_size

# NOTE(review): total_steps is multiplied by num_devices, while the training loop
# calls lr_scheduler.step() once per optimizer step in each process — confirm the
# extra factor is intended for distributed runs.
total_steps = num_epochs * num_iterations_per_epoch * num_devices
print("total_steps", total_steps)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=total_steps,
)

print("\nDone with model preparations!")
num_params = utils.count_params(model)

total_steps 128000

Done with model preparations!
param counts:
117,486,598 total
117,486,598 trainable


In [10]:
# checkpoint file that is looked up when resuming a run
default_ckpt_path = f'{outdir}/last.pth'

def save_ckpt(model,tag="last"):
    """Write the current training state to `<outdir>/<tag>.pth`.

    Every process calls `model.state_dict()`, but only global rank 0 writes the
    file. Besides the model weights, the checkpoint stores the notebook-level
    `epoch`, `optimizer` and `lr_scheduler`, which are read from globals rather
    than passed in.

    Args:
        model: model whose `state_dict()` is saved.
        tag: file name stem of the checkpoint (default "last").
    """
    if distributed:
        dist.barrier()
    model_states = model.state_dict()

    # all other ranks are done once the state dict has been gathered
    if global_rank != 0:
        return

    os.makedirs(outdir, exist_ok=True)
    ckpt_path = f'{outdir}/{tag}.pth'
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_states,
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
    }
    torch.save(checkpoint, ckpt_path)
    print(f"\n---saved {ckpt_path}!---\n")

def resume_ckpt(model, optimizer, lr_scheduler, device, ckpt_path=default_ckpt_path):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Every process loads the checkpoint. Previously only global rank 0 did, and all
    other ranks returned `epoch = 0` with untouched model/optimizer/scheduler state,
    so in a distributed run the processes disagreed on the epoch to resume from.

    Args:
        model: model to load `model_state_dict` into.
        optimizer: optimizer to load `optimizer_state_dict` into.
        lr_scheduler: scheduler to load `lr_scheduler_state_dict` into.
        device: `map_location` used when reading the checkpoint.
        ckpt_path: checkpoint file to read; must be readable by every process.

    Returns:
        Tuple `(model, optimizer, lr_scheduler, epoch)` where `epoch` is the value
        stored in the checkpoint.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): save_ckpt stores rank 0's `optimizer.state_dict()`; for a sharded
    # (FSDP) run confirm that this is the state every rank should load.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    epoch = checkpoint['epoch']
    # drop the reference to the loaded tensors before clearing the CUDA cache
    del checkpoint
    if distributed: dist.barrier()
    torch.cuda.empty_cache()
    return model, optimizer, lr_scheduler, epoch

# Start wandb (if enabled)

In [11]:
# interactive sessions neither save checkpoints nor log to wandb
if utils.is_interactive():
    ckpt_saving = False
    wandb_log = False

if global_rank != 0 or not wandb_log:
    wandb_log = False
else:  # only use main process for wandb logging
    import wandb
    wandb_project = 'found'
    print(f"wandb {wandb_project} run {model_name}")
    # need to configure wandb beforehand in terminal with "wandb init"!
    wandb_config = dict(
        model_name=model_name,
        global_batch_size=global_batch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        num_samples_per_epoch=num_samples_per_epoch,
        tube_start_masking_ratio=tube_start_masking_ratio,
        tube_end_masking_ratio=tube_end_masking_ratio,
        decoder_mask_ratio=decoder_mask_ratio,
        num_frames=num_frames,
        patch_size=patch_size,
        frame_patch_size=frame_patch_size,
        use_contrastive_loss=use_contrastive_loss,
        use_cls_token=use_cls_token,
        constrastive_loss_weight=constrastive_loss_weight,
        num_params=num_params,
        max_lr=max_lr,
        ckpt_interval=ckpt_interval,
        ckpt_saving=ckpt_saving,
        seed=seed,
        distributed=distributed,
        num_devices=num_devices,
        world_size=world_size,
        train_urls=train_urls,
        test_urls=test_urls,
    )
    print("wandb_config:\n",wandb_config)
    print("wandb_id:",model_name)
    # the run is named after the model and may resume an existing run
    wandb.init(
        project=wandb_project,
        name=model_name,
        config=wandb_config,
        resume="allow",
        group=wandb_group_name,
    )

# Start training

In [12]:
# training starts at epoch 0 unless a checkpoint overrides it later
epoch = 0

# per-step / per-epoch histories, each tracked in its own list
lrs = []
train_losses = []
recon_losses = []
contrastive_losses = []
test_losses = []

# initial value for the best test loss
best_test_loss = 1e9

torch.cuda.empty_cache()

In [13]:
# NOTE: while `debug` is True the training loop breaks after the first batch of each epoch
debug = True

# resume only when requested in the config and a checkpoint file actually exists
if resume_from_ckpt == True and os.path.exists(default_ckpt_path):
    print(f"Resuming from {default_ckpt_path}...")
    model, optimizer, lr_scheduler, epoch = resume_ckpt(model, optimizer, lr_scheduler, device)

In [17]:
# synchronize all processes before training starts
if distributed:
    dist.barrier()

# reconstruction loss
mse = nn.MSELoss()

if use_contrastive_loss:
    # temperature schedule, indexed by epoch inside the training loop
    contrastive_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs)

model.train()

# starts at `epoch`; the bar is disabled on every rank except global rank 0
progress_bar = tqdm(range(epoch, num_epochs), desc="Overall", disable=(global_rank != 0))
for epoch in progress_bar:
    # get the masking ratio for the current epoch
    tube_mask_ratio = utils.get_masking_ratio(
        current_epoch=epoch, 
        total_epochs=num_epochs, 
        start_masking_ratio=tube_start_masking_ratio, 
        end_masking_ratio=tube_end_masking_ratio
    )
    num_decoder_patches = int(num_patches * tube_mask_ratio)
    with torch.cuda.amp.autocast(dtype=data_type):
        model.train()
        for train_i, batch in enumerate(train_dl):
            optimizer.zero_grad()

            key, input_func = batch 
            if masking_strategy=="MNI":
                func, _ = aug_transform(input_func)
            else:
                func, brain_pos_voxels = aug_transform(input_func)
                brain_pos_pats = model.patchify(torch.Tensor(brain_pos_voxels)[None,None,None])
                brain_pos_pats_vit = rearrange(brain_pos_pats, "b ... d -> b (...) d").mean(-1)[0]
            if is_random_tube_mask_ratio==True: tube_mask_ratio = np.random.uniform(low=0.001,high=0.75) # random tube mask ratio to train with different #s of patches
            func = func.reshape(-1, num_frames, func.shape[-3], func.shape[-2], func.shape[-1]).float()
            func = func.unsqueeze(1).clamp(0,1).to(device)

            # create tube mask (i.e., a mask that is the same for all frames/timepoints)
            tube_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
            batch_positive_approx = (brain_pos_pats_vit > 0)
            mask_idx_candidates = torch.where(batch_positive_approx)[0]
            mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]
            num_masked_voxels = min(int(len(mask_idx_candidates) * (1 - tube_mask_ratio)), len(mask_idx_candidates) - 1)
            tube_idx = mask_idx_candidates[:num_masked_voxels]
            tube_mask[tube_idx] = True
            tube_mask = tube_mask.tile(num_frames)#.to(device)
             

            # create decoder mask
            decoder_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
            remaining_mask_idx = mask_idx_candidates[num_masked_voxels:]
            decoder_mask_idx = remaining_mask_idx#[:int(num_patches / num_frames * (1 - decoder_mask_ratio))]
            decoder_mask[decoder_mask_idx] = True
            decoder_mask = decoder_mask.tile(num_frames)#.to(device)
 
            # decoder_mask = ~tube_mask
            
            # encode the tube patches
            encoder_out = model(func, encoder_mask=tube_mask)
            if use_cls_token:
                enc_cls_token = encoder_out[:,:1,:]
            # print ("tube_mask", tube_mask.shape)
            print ("encoder_out", encoder_out.shape)
            
            # decode both the encoder_out patches and masked decoder patches
            decoder_out = model(encoder_out, encoder_mask=tube_mask, decoder_mask=decoder_mask)
            # decoder_out = model(func, encoder_mask=tube_mask, decoder_mask=decoder_mask)
            # subset only the reconstructed decoder patches
            # print ("decoder_out", decoder_out.shape, "num_decoder_patches", num_decoder_patches)
            # the decoder mask gets concatenated to the end
            output = decoder_out[:, -decoder_mask.sum():].clamp(0,1)
            

            # compare to ground truth and calculate loss 
            target_patches = model.patchify(func)
            target_patches_vit = rearrange(target_patches, "b ... d -> b (...) d")
            target = target_patches_vit[:, decoder_mask]
            #print("encoder_out",encoder_out.shape,"output",output.shape,"target",target.shape)
            # print("output")
            # #print(output[0,:10,:10])

            # #print(target[0,:10,:10])
            # plt.imshow(torch.corrcoef(output[0]).detach().cpu(),vmin=0,vmax=1 )
            # plt.show()
            # plt.hist(torch.corrcoef(output[0]).detach().cpu().flatten())
            # plt.show()
            # print("target")
            # plt.imshow(torch.corrcoef(target[0]).detach().cpu(),vmin=0,vmax=1 )
            # plt.show()
            # plt.hist(torch.corrcoef(target[0]).detach().cpu().flatten())
            # plt.show()
            
            # print("torch.corrcoef(output[0])",torch.corrcoef(output[0]))
            # ## visualize
            # decode_vis = torch.zeros_like(target_patches_vit)
            # print("output",output.shape, "-decoder_mask.sum()",-decoder_mask.sum())
            # decode_vis[:, decoder_mask] = output.to(decode_vis.device).to(decode_vis.dtype)
            # decoder_unpatches = rearrange(
            #     decode_vis,
            #     "b (f d h w) c -> b f d h w c",
            #     d=img_size[0]//patch_size,
            #     h=img_size[1]//patch_size,
            #     w=img_size[2]//patch_size,
            # )
            # decoder_func = rearrange(
            #     decoder_unpatches,
            #     "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
            #     b=batch_size,
            #     f=num_frames,
            #     d=img_size[0]//patch_size,
            #     h=img_size[1]//patch_size,
            #     w=img_size[2]//patch_size,
            #     pd=patch_size,
            #     ph=patch_size,
            #     pw=patch_size,
            #     pf=frame_patch_size,
            # )
            # if train_i%2 == 1:
            #     orig_image = torch.zeros_like(utils.reshape_to_2d(func[idx]))
            # else:
            #     orig_image =  (utils.reshape_to_2d(func[idx]))
            # recon_image = utils.reshape_to_2d(decoder_func[idx])
            # print("orig_image",orig_image.shape, orig_image.min(),orig_image.max(), "recon_image",recon_image.shape,recon_image.min(),recon_image.max())
            # combined_image = orig_image.clone()
            # combined_image[recon_image!=0] = recon_image[recon_image!=0]

            # random_start = np.random.randint(recon_image.shape[1]-400)
            # orig_image = transforms.ToPILImage()(orig_image[:,random_start:random_start+100])
            # recon_image = transforms.ToPILImage()(recon_image[:,random_start:random_start+100])
            # combined_image = transforms.ToPILImage()(combined_image[:,random_start:random_start+100])


            # plt.imshow( (utils.reshape_to_2d(func[idx])).detach().cpu() [:,random_start:random_start+100] ,vmin=0,vmax=1)
            # plt.show()
            # plt.imshow(recon_image.detach().cpu()[:,random_start:random_start+100],vmin=0,vmax=1)
            # plt.show()
            # plt.imshow(combined_image.detach().cpu() [:,random_start:random_start+100] ,vmin=0,vmax=1)
            # plt.show() 
            # print("func",func.shape,"output",output.shape,"target",target.shape)
            loss = mse(output, target)
            # print ( "decoder_mask.sum()", decoder_mask.sum(), "loss", loss.item())
            recon_losses.append(loss.item())
 
            # contrastive loss
            if use_contrastive_loss:
                enc_norm = nn.functional.normalize(encoder_out.flatten(1), dim=-1)
                cosine_similarities = enc_norm @ enc_norm.T
                
                softmax_scores = nn.functional.softmax(cosine_similarities / contrastive_temps[epoch], dim=-1)
                contrastive_loss = nn.functional.cross_entropy(softmax_scores, torch.arange(len(cosine_similarities)).to(device))
                
                loss += constrastive_loss_weight * contrastive_loss
                contrastive_losses.append(contrastive_loss.item())

            
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            train_losses.append(loss.item())
            # print("loss",loss.item(), "tube_mask_ratio", tube_mask_ratio)
            lrs.append(optimizer.param_groups[0]["lr"])
            if (train_i >= (num_iterations_per_epoch-1)) or debug:
                print("train_i", train_i, "local_rank", local_rank, "global_rank", global_rank)
                break
            
        # ---- Evaluation pass ----
        # Runs batches from test_dl through masking, encoding, decoding and an MSE
        # reconstruction loss with gradients disabled; one loss value is appended to
        # test_losses per batch.
        model.eval()
        with torch.no_grad():
            for test_i, batch in enumerate(test_dl):
                key, input_func = batch 

                if masking_strategy=="MNI":
                    # NOTE(review): brain_pos_pats_vit is not recomputed in this branch, so
                    # whatever value it already holds is reused below — confirm it is
                    # defined before this loop runs.
                    func, _ = aug_transform(input_func)
                else:
                    # Patchify the brain-position volume returned by aug_transform, flatten
                    # the patch grid, average over the last axis and keep the first entry.
                    func, brain_pos_voxels = aug_transform(input_func)
                    brain_pos_pats = model.patchify(torch.Tensor(brain_pos_voxels)[None,None,None])
                    brain_pos_pats_vit = rearrange(brain_pos_pats, "b ... d -> b (...) d").mean(-1)[0]
                    
                if is_random_tube_mask_ratio==True: tube_mask_ratio = np.random.uniform(low=0.001,high=0.75) # random tube mask ratio, redrawn per batch. NOTE(review): this also runs during evaluation, so each test batch is scored at a different ratio and the last draw is what gets logged
                # Regroup into clips of num_frames volumes, then insert a singleton axis at
                # dim 1 (presumably the channel axis — TODO confirm) and clamp to [0, 1].
                func = func.reshape(-1, num_frames, func.shape[-3], func.shape[-2], func.shape[-1]).float()
                func = func.unsqueeze(1).clamp(0,1).to(device)
                # create tube mask (i.e., a mask that is the same for all frames/timepoints)
                tube_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
                # Candidate patches are those with a positive brain_pos_pats_vit value;
                # shuffle them so the selection below is random.
                batch_positive_approx = (brain_pos_pats_vit > 0)
                mask_idx_candidates = torch.where(batch_positive_approx)[0]
                mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]
                # Select a (1 - tube_mask_ratio) fraction of the candidates, capped so that
                # at least one candidate is always left over for the decoder mask.
                num_masked_voxels = min(int(len(mask_idx_candidates) * (1 - tube_mask_ratio)), len(mask_idx_candidates) - 1)
                tube_idx = mask_idx_candidates[:num_masked_voxels]
                tube_mask[tube_idx] = True
                # Repeat the per-frame mask num_frames times so every frame shares it.
                tube_mask = tube_mask.tile(num_frames)#.to(device)
                

                # create decoder mask: every candidate patch that was not selected for the
                # tube mask (the decoder_mask_ratio sub-sampling is currently commented out)
                decoder_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
                remaining_mask_idx = mask_idx_candidates[num_masked_voxels:]
                decoder_mask_idx = remaining_mask_idx#[:int(num_patches / num_frames * (1 - decoder_mask_ratio))]
                decoder_mask[decoder_mask_idx] = True
                decoder_mask = decoder_mask.tile(num_frames)#.to(device)

                # decoder_mask = ~tube_mask
                
                # encode the tube patches
                encoder_out = model(func, encoder_mask=tube_mask)
                if use_cls_token:
                    # keep the first token of the encoder output
                    enc_cls_token = encoder_out[:,:1,:]
                # print ("test encoder_out", encoder_out.shape, "test decoder_mask.sum()", decoder_mask.sum())
                # print ("tube_mask", tube_mask.shape)
                # print ("encoder_out", encoder_out.shape)
                # decode both the encoder_out patches and masked decoder patches
                decoder_out = model(encoder_out, encoder_mask=tube_mask, decoder_mask=decoder_mask)
                # decoder_out = model(func, encoder_mask=tube_mask, decoder_mask=decoder_mask)
                # subset only the reconstructed decoder patches
                # print ("decoder_out", decoder_out.shape, "num_decoder_patches", num_decoder_patches)
                # the decoder mask gets concatenated to the end, so the reconstructions are
                # the last decoder_mask.sum() tokens of decoder_out
                output = decoder_out[:, -decoder_mask.sum():].clamp(0,1)
    

                # compare to ground truth and calculate loss: the targets are the input
                # patches selected by decoder_mask
                target_patches = model.patchify(func)
                target_patches_vit = rearrange(target_patches, "b ... d -> b (...) d")
                target = target_patches_vit[:, decoder_mask]
                loss = mse(output, target)
                test_losses.append(loss.item())

                # stop after test_num_iterations_per_epoch batches
                if test_i >= (test_num_iterations_per_epoch-1):
                    break
 
        # Epoch summary: average each per-step history over the steps that ran in
        # this epoch (train_i / test_i hold the index of the last batch processed).
        _train_tail = slice(-(train_i + 1), None)
        _test_tail = slice(-(test_i + 1), None)

        logs = dict()
        logs["train/loss"] = np.mean(train_losses[_train_tail])
        logs["train/recon_losses"] = np.mean(recon_losses[_train_tail])
        logs["train/contrastive_losses"] = np.mean(contrastive_losses[_train_tail])
        logs["train/num_steps"] = len(recon_losses)
        logs["test/loss"] = np.mean(test_losses[_test_tail])
        logs["test/num_steps"] = len(test_losses)
        logs["lr"] = np.mean(lrs[_train_tail])
        logs["epoch"] = epoch
        logs["tube_mask_ratio"] = tube_mask_ratio
        logs["decoder_mask_ratio"] = decoder_mask_ratio

        # Show the summary on the progress bar; distributed runs also print it.
        progress_bar.set_postfix(**logs)
        if distributed:
            print(logs)

#         if global_rank==0:
           # Plot progress (first sample in batch)
#             with torch.no_grad():
#                 if utils.is_interactive():
#                     idx = 0
#                     if epoch % 5 == 0:
#                         decode_vis = torch.zeros_like(target_patches_vit)
#                         decode_vis[:, decoder_mask] = output.to(decode_vis.device).to(decode_vis.dtype)
#                         decoder_unpatches = rearrange(
#                             decode_vis,
#                             "b (f d h w) c -> b f d h w c",
#                             d=img_size[0]//patch_size,
#                             h=img_size[1]//patch_size,
#                             w=img_size[2]//patch_size,
#                         )
#                         decoder_func = rearrange(
#                             decoder_unpatches,
#                             "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
#                             b=batch_size*2,
#                             f=num_frames,
#                             d=img_size[0]//patch_size,
#                             h=img_size[1]//patch_size,
#                             w=img_size[2]//patch_size,
#                             pd=patch_size,
#                             ph=patch_size,
#                             pw=patch_size,
#                             pf=frame_patch_size,
#                         )
#                         orig_image = utils.reshape_to_2d(func[idx])
#                         recon_image = utils.reshape_to_2d(decoder_func[idx])
    
#                         combined_image = orig_image.clone()
#                         combined_image[recon_image!=0] = recon_image[recon_image!=0]
                        
#                         random_start = np.random.randint(recon_image.shape[1]-400)
#                         orig_image = transforms.ToPILImage()(orig_image[:,random_start:random_start+400])
#                         recon_image = transforms.ToPILImage()(recon_image[:,random_start:random_start+400])
#                         combined_image = transforms.ToPILImage()(combined_image[:,random_start:random_start+400])
    
#                         if wandb_log:
#                             logs[f"train/orig"] = wandb.Image(orig_image, caption=f"epoch{epoch:03d}")
#                             logs[f"train/recon"] = wandb.Image(recon_image, caption=f"epoch{epoch:03d}")
#                             logs[f"train/combined"] = wandb.Image(combined_image, caption=f"epoch{epoch:03d}")
#                         else:
#                             # display(orig_image)
#                             # display(recon_image)
#                             display(combined_image)
#             if wandb_log: wandb.log(logs)
        
        # wait for other GPUs to catch up if needed
#         if distributed: dist.barrier()

        # Save the "last" checkpoint every ckpt_interval epochs (epoch 0 included) and
        # on the final epoch; skipped when ckpt_saving is off or when debugging.
        if (ckpt_saving) and ((epoch % ckpt_interval == 0) or (epoch==num_epochs-1)) and (debug==False):
            save_ckpt(model,"last")
            
        # Release unused cached GPU memory held by PyTorch's allocator at the end of the epoch
        torch.cuda.empty_cache()
        
# Tear down the distributed process group once training has finished.
if distributed:
    # Imported here so the teardown does not depend on `dist` having been bound
    # by an earlier cell.
    import torch.distributed as dist

    # Only destroy a group that was actually created: destroy_process_group()
    # fails when no default process group has been initialised.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

Overall:   0%|          | 0/1000 [00:00<?, ?it/s]

patched_emb torch.Size([1, 1287, 512]) tensor([[-0.0054,  0.0035,  0.0132,  ...,  0.0053,  0.0061, -0.0405],
        [ 0.0187, -0.0197, -0.0276,  ...,  0.0096, -0.0321,  0.0140],
        [-0.0166,  0.0359, -0.0093,  ...,  0.0124,  0.0023,  0.0261],
        ...,
        [-0.0110,  0.0134,  0.0288,  ...,  0.0074,  0.0040,  0.0210],
        [ 0.0089,  0.0277, -0.0107,  ..., -0.0220,  0.0190, -0.0158],
        [ 0.0232, -0.0165, -0.0057,  ..., -0.0431,  0.0079, -0.0061]],
       device='cuda:0', grad_fn=<SelectBackward0>)
encoder_out torch.Size([8, 2556, 6])


AttributeError: 'VisionMamba' object has no attribute 'encoder_to_decoder'

In [None]:
# Plot the per-step training loss histories collected by the training loop.
# Axis labels are set so each figure can be read on its own.
plt.figure(figsize=(8, 3))
plt.plot(recon_losses)
plt.title("Training re-construction losses")
plt.xlabel("Training step")
plt.ylabel("Reconstruction loss")
plt.show()
if use_contrastive_loss:
    # The contrastive history only has entries when the contrastive term is enabled.
    plt.figure(figsize=(8, 3))
    plt.plot(contrastive_losses)
    plt.title("Training contrastive losses")
    plt.xlabel("Training step")
    plt.ylabel("Contrastive loss")
    plt.show()