In [1]:
# Import packages and setup gpu configuration.
# This code block shouldnt need to be adjusted!
import os
import sys
import json
import yaml
import numpy as np
import pandas as pd
import copy
import math
from einops import rearrange
from einops.layers.torch import Rearrange
import time
import random
import h5py
import webdataset as wds
import gc
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import utils
from models import *
import nibabel as nib
from nilearn import plotting

import schedulefree

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

### Multi-GPU config ###
device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")

local_rank = os.getenv('LOCAL_RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print(f"LOCAL RANK={local_rank}")

num_devices = os.getenv('NUM_GPUS')
if num_devices is None: 
    num_devices = 1
else:
    num_devices = int(num_devices)
print(f"NUM GPUS={num_devices}")
distributed = True if num_devices>1 else False
if distributed: assert device_count==num_devices

node = os.getenv('SLURM_NODEID')
if node is None:
    node = 0
else:
    node = int(node)
print(f"NODE={node}")

global_rank = os.getenv('RANK')
if global_rank is None:
    global_rank = 0
else:
    global_rank = int(global_rank)
print(f"GLOBAL RANK={global_rank}")

world_size = os.getenv('WORLD_SIZE')
if world_size is None: 
    world_size = 1
else:
    world_size = int(world_size)
print(f"WORLD_SIZE={world_size}")

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

# Load parameters from yaml config
config = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)

# create global variables from the config
for attribute_name in config.keys():
    globals()[attribute_name] = config[f'{attribute_name}']

data_type = torch.float16 # change depending on your mixed_precision
# batch_size = global_batch_size // num_devices
global_batch_size = batch_size * world_size

# FSDP Setup
if distributed:
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy
    import functools
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
    print("starting init_process_group...")
    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)
    print(f"setting device to cuda:{local_rank}")
    try:
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda',local_rank)
        print(f"\nSuccessfully set cuda:{local_rank} | global_rank{global_rank} | node{node}")
    except Exception as error:        
        print(f"\nFAILED TO SET DEVICE cuda:{local_rank} | global_rank{global_rank} | node{node}")
        print("An exception occurred:", error)
        
else:
    device = torch.device('cuda')

print("PID of this process =",os.getpid())
print("device =", device, "distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

Number of available CUDA devices: 1
LOCAL RANK=0
NUM GPUS=1
NODE=0
GLOBAL RANK=0
WORLD_SIZE=1
PID of this process = 1339581
device = cuda distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16


# Configuration

In [2]:
print(config)

# seed all random functions
utils.seed_everything(seed)

outdir = os.path.abspath(f'../ckpts/{model_name}')
os.makedirs(outdir,exist_ok=True)
print("outdir", outdir)
print("global_batch_size", global_batch_size)
print("use_cls_token", use_cls_token)

if type(patch_size) == int:
    patch_size = [patch_size,patch_size,patch_size]
patch_depth = patch_size[0]
patch_height = patch_size[1]
patch_width = patch_size[2]

num_patches = int(
    (img_size[0] / patch_depth)
    * (img_size[1] / patch_height)
    * (img_size[2] / patch_width)
    * num_frames
)
num_patches_per_timepoint = num_patches // frame_patch_size
num_encoder_patches = int(np.floor((num_patches_per_timepoint * num_frames // frame_patch_size) * (1 - tube_start_masking_ratio)))
num_decoder_patches = int(np.floor((num_patches_per_timepoint * num_frames  // frame_patch_size) * (1 - decoder_mask_ratio)))
print("num_patches", num_patches)
print("num_patches_per_timepoint", num_patches_per_timepoint)
print("num_encoder_patches", num_encoder_patches)
print("num_decoder_patches", num_decoder_patches)

{'model_name': 'mini_nomask_logitsCLS_downstream_40ep_l', 'use_cls_token': True, 'use_contrastive_loss': True, 'contrastive_loss_weight': 0.1, 'batch_size': 256, 'num_workers': 10, 'num_epochs': 20, 'seed': 42, 'max_lr': 3e-06, 'num_samples_per_epoch': 1024, 'test_num_samples_per_epoch': 384, 'ckpt_saving': True, 'ckpt_interval': 50, 'resume_from_ckpt': True, 'wandb_log': True, 'tube_start_masking_ratio': 0.75, 'tube_end_masking_ratio': 0.75, 'decoder_mask_ratio': 0.75, 'patch_size': [8, 8, 8], 'frame_patch_size': 4, 'use_rope_emb': False, 'masking_strategy': 'None', 'encoder_model': 'vit_mini', 'decoder_model': 'vit_mini', 'img_size': [88, 104, 72], 'num_frames': 4, 'is_s3': False, 'train_urls': ['/weka/proj-fmri/shared/NSD_MNI_wds/{000000..000699}.tar'], 'test_urls': ['/weka/proj-fmri/shared/NSD_MNI_wds/{000700..000738}.tar']}
outdir /weka/proj-fmri/paulscotti/fMRI-foundation-model/ckpts/mini_nomask_logitsCLS_downstream_40ep_l
global_batch_size 256
use_cls_token True
num_patches 5148

# Prep models

In [3]:
vit_size = {
    "encoder": encoder_model,
    "decoder": decoder_model
}
    
model = get_vit(
    size=vit_size,
    image_size=img_size,  # depth, height, width
    image_patch_size=(patch_depth,patch_height,patch_width),  # depth, height, width patch size
    frames=num_frames,
    frame_patch_size=frame_patch_size,
    channels=1,
    use_rope_emb=use_rope_emb,
    use_cls_token=use_cls_token,
)
utils.count_params(model)

# function to select random num_frames from sample and obtain brain-positive patches
aug_transform = utils.DataPrepper(
    num_frames=num_frames*2,
    masking_strategy=masking_strategy,
    patch_depth=patch_depth,
    patch_height=patch_height,
    patch_width=patch_width,
    frame_patch_size=frame_patch_size,
)

# test that the model works without error
model = model.to(device)
encoder_mask = torch.zeros(num_patches_per_timepoint).to(torch.bool)
encoder_mask[:num_encoder_patches] = True
decoder_mask = torch.zeros(num_patches_per_timepoint).to(torch.bool)
decoder_mask[-num_decoder_patches:] = True
decoder_mask[encoder_mask] = False
with torch.no_grad():
    print("\nencoder")
    encoder_out = model(
                torch.randn(batch_size, 1, num_frames, img_size[0], img_size[1], img_size[2]).to(device),
                encoder_mask=encoder_mask,
                verbose=True)
    print("\ndecoder")
    decoder_out = model(
                encoder_out, 
                encoder_mask=encoder_mask, 
                decoder_mask=decoder_mask, 
                verbose=True)
    if use_cls_token:
        enc_cls_token = encoder_out[:, :1, :]
        encoder_patches = encoder_out[:, 1:, :]
        dec_cls_token = decoder_out[:, :1, :]
        decoder_patches = decoder_out[:, 1:, :]
        print("\nenc_cls_token", enc_cls_token.shape)
        print("encoder_patches", encoder_patches.shape)
        print("dec_cls_token", dec_cls_token.shape)
        print("decoder_patches", decoder_patches.shape)

param counts:
1,693,008 total
1,693,008 trainable

encoder
torch.Size([256, 1, 4, 88, 104, 72])
patched torch.Size([256, 1, 11, 13, 9, 2048])
reshaped torch.Size([256, 1287, 2048])
masked torch.Size([256, 321, 2048])
patched_emb torch.Size([256, 321, 48])
pe torch.Size([1287, 48])
masked torch.Size([256, 322, 48])
torch.Size([256, 322, 48])

decoder
torch.Size([256, 322, 48])
pe torch.Size([1287, 48])
pos_emd_encoder torch.Size([321, 48])
pos_emd_decoder torch.Size([321, 48])
x_concat torch.Size([256, 643, 48])
torch.Size([256, 643, 48])
proj torch.Size([256, 643, 2048])

enc_cls_token torch.Size([256, 1, 48])
encoder_patches torch.Size([256, 321, 48])
dec_cls_token torch.Size([256, 1, 2048])
decoder_patches torch.Size([256, 642, 2048])


### Add "linear" probe

In [4]:
class LinearProbe(nn.Module):
    def __init__(self, input_dim, h=256, num_classes=8):
        super(LinearProbe, self).__init__()
        # self.classifier = nn.Linear(input_dim, num_classes)
        self.classifier = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.GELU(),
            nn.Linear(input_dim, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Linear(h, h),
            nn.LayerNorm(h),
            nn.GELU(),
            nn.Linear(h, num_classes)
        )
    def forward(self, x):
        x = self.classifier(x)
        return x

In [5]:
# if use_cls_token:
#     model.cont = LinearProbe((num_encoder_patches+1)*model.encoder_embed_dim,h=768,num_classes=768)
# else:
# model.cont = LinearProbe(model.encoder_embed_dim,h=256,num_classes=256)
# model = model.to(device)

## Create dataset and data loaders

In [6]:
# from dataloader import create_dataset, create_loader
# train_urls = train_urls[0]
# print(train_urls)

# train_dp = create_dataset(train_urls, 
#                           is_s3=train_urls[:2]=="s3", 
#                           sample_shuffle=100, shard_shuffle=100)
# train_dl = create_loader(train_dp, batch_size=batch_size, num_workers=num_workers)

In [7]:
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

def filter_corrupted_images(sample):
    """If all the required files are not present don't use them."""
    correct_data = ("func.npy" in sample)
    return correct_data

### ================      Train Dataset and DataLoader    ====================
from braceexpand import braceexpand
print(train_urls)
if is_s3:
    expanded_urls = [f"pipe:aws s3 cp {url} -" for pattern in train_urls for url in braceexpand(pattern)]
else:
    expanded_urls = [str(url) for pattern in train_urls for url in braceexpand(pattern)]

train_data = (
    wds.WebDataset(expanded_urls, resampled=True, nodesplitter=wds.split_by_node, handler=log_and_continue)
    .shuffle(100, initial=100, rng=random.Random(seed))
    .select(filter_corrupted_images)
    .decode("torch")
)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)

### ================      Test Dataset and DataLoader    ====================
print(test_urls)
if is_s3:
    expanded_urls = [f"pipe:aws s3 cp {url} -" for pattern in test_urls for url in braceexpand(pattern)]
else:
    expanded_urls = [str(url) for pattern in train_urls for url in braceexpand(pattern)]

test_data = (
    wds.WebDataset(expanded_urls, resampled=True, nodesplitter=wds.split_by_node, handler=log_and_continue)
    .shuffle(100, initial=100, rng=random.Random(seed))
    .select(filter_corrupted_images)
    .decode("torch")
)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)

['/weka/proj-fmri/shared/NSD_MNI_wds/{000000..000699}.tar']
['/weka/proj-fmri/shared/NSD_MNI_wds/{000700..000738}.tar']


### Check data loaders work

In [8]:
# if utils.is_interactive():
#     start_time = time.time() 
#     num_it = 2
#     print(f"Yielding {num_it} batches")
    
#     for i, batch in enumerate(test_dl):
#         print("iter",i)
#         input_func = batch['func.npy']
#         subject_id = batch['subject_id.txt']
#         subject_id = [int(subject[-2:]) for subject in subject_id]
#         # session_id = batch['session_id.txt']
#         # session_id = [int(session[-2:]) for session in session_id]
#         func, brain_pos_pats = aug_transform(input_func)
#         if i >= (num_it-1):
#             break
    
#     print("Done!")
#     print("input_func", input_func.shape)
#     print("func", func.shape)
#     print("subject_id", subject_id)

#     end_time = time.time()  
#     execution_time = end_time - start_time  
#     print(f"Execution time: {execution_time} seconds")

### Playing with the data, visualization of patching + masking

In [9]:
# if utils.is_interactive():
#     func, brain_pos_pats = aug_transform(input_func)
#     print(func.shape)
#     display(utils.view_brain(func,cut_coords=(44,44,44)))
# # plt.hist(func[0,0].flatten().clamp(.25,3),bins=100)

# Set up optimizer and saving functions

In [10]:
if distributed:    
    # my_auto_wrap_policy = functools.partial(
    #     size_based_auto_wrap_policy, min_num_params=200000
    # )
    my_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy, 
        transformer_layer_cls={
            Attention, # <--- Your Transformer layer class
        },
    )
    print(f"\nPrepping FSDP on {global_rank} {node}...\n")
    model = model.to(device)
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        auto_wrap_policy=my_auto_wrap_policy,
        use_orig_params=False,
        cpu_offload=None, #CPUOffload(offload_params=True)
        sync_module_states=True,
        limit_all_gathers=True, # See https://github.com/pytorch/pytorch/issues/91165
        device_id=device,
    )
    print(f"\nSuccessfully loaded FSDP model to device on global_rank {global_rank}\n")
    dist.barrier()
    print(f"\nSuccessfully loaded FSDP model to device on global_rank {global_rank}\n")

In [11]:
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
opt_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

if distributed:
    max_lr = max_lr * global_batch_size
    print(f"multiply lr {max_lr} by global batch size: max_lr={max_lr}")

# optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
optimizer = schedulefree.AdamWScheduleFree(opt_grouped_parameters, lr=max_lr)

num_iterations_per_epoch = num_samples_per_epoch // global_batch_size
print("num_iterations_per_epoch", num_iterations_per_epoch)

probe_num_iterations_per_epoch = test_num_samples_per_epoch // global_batch_size
print("probe_num_iterations_per_epoch", probe_num_iterations_per_epoch)

total_steps = num_epochs * num_iterations_per_epoch * num_devices
print("total_steps", total_steps)

# lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer,
#     max_lr=max_lr,
#     total_steps=total_steps,
# )

print("\nDone with model preparations!")
num_params = utils.count_params(model)

num_iterations_per_epoch 4
probe_num_iterations_per_epoch 1
total_steps 80

Done with model preparations!
param counts:
1,838,256 total
1,838,256 trainable


In [12]:
def save_ckpt(model,tag="last"):
    if distributed: dist.barrier()
    model_states = model.state_dict()
    if global_rank == 0:
        os.makedirs(outdir,exist_ok=True)
        ckpt_path = outdir+f'/{tag}.pth'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_states,
            'optimizer_state_dict': optimizer.state_dict(),
        }, ckpt_path)
        print(f"\n---saved {ckpt_path}!---\n")

# Start wandb (if enabled)

In [13]:
if utils.is_interactive():
#     wandb_log = False
    ckpt_saving = False
if local_rank==0 and wandb_log: # only use main process for wandb logging
    import wandb
    wandb_project = 'found'
    print(f"wandb {wandb_project} run {model_name}")
    # need to configure wandb beforehand in terminal with "wandb init"!
    wandb_config = {
      "model_name": model_name,
      "global_batch_size": global_batch_size,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "num_samples_per_epoch": num_samples_per_epoch,
      "test_num_samples_per_epoch": test_num_samples_per_epoch,
      "num_iterations_per_epoch": num_iterations_per_epoch,
      "encoder_model": encoder_model,
      "decoder_model": decoder_model,
      "tube_start_masking_ratio": tube_start_masking_ratio,
      "tube_end_masking_ratio": tube_end_masking_ratio,
      "decoder_mask_ratio": decoder_mask_ratio,
      "num_frames": num_frames,
      "patch_size": patch_size,
      "frame_patch_size": frame_patch_size,
      "use_contrastive_loss": use_contrastive_loss,
      "use_cls_token": use_cls_token,
      "contrastive_loss_weight": contrastive_loss_weight,
      "num_params": num_params,
      "max_lr": max_lr,
      "ckpt_interval": ckpt_interval,
      "ckpt_saving": ckpt_saving,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
      "train_urls": train_urls,
    }
    print("wandb_config:\n",wandb_config)
    print("wandb_id:",model_name)
    wandb.init(
        id=model_name,
        project=wandb_project,
        name=model_name,
        config=wandb_config,
        resume="allow",
    )
else:
    wandb_log = False

wandb found run mini_nomask_logitsCLS_downstream_40ep_l
wandb_config:
 {'model_name': 'mini_nomask_logitsCLS_downstream_40ep_l', 'global_batch_size': 256, 'batch_size': 256, 'num_epochs': 20, 'num_samples_per_epoch': 1024, 'test_num_samples_per_epoch': 384, 'num_iterations_per_epoch': 4, 'encoder_model': 'vit_mini', 'decoder_model': 'vit_mini', 'tube_start_masking_ratio': 0.75, 'tube_end_masking_ratio': 0.75, 'decoder_mask_ratio': 0.75, 'num_frames': 4, 'patch_size': [8, 8, 8], 'frame_patch_size': 4, 'use_contrastive_loss': True, 'use_cls_token': True, 'contrastive_loss_weight': 0.1, 'num_params': 1838256, 'max_lr': 3e-06, 'ckpt_interval': 50, 'ckpt_saving': False, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1, 'train_urls': ['/weka/proj-fmri/shared/NSD_MNI_wds/{000000..000699}.tar']}
wandb_id: mini_nomask_logitsCLS_downstream_40ep_l


[34m[1mwandb[0m: Currently logged in as: [33mpaul-scotti[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Start training

In [14]:
epoch = 0
lrs, train_losses, recon_losses, contrastive_losses = [], [], [], []
cos_sim_encoder_output, cos_sim_decoder_output, cos_sim_encoder_output_patchwise = [], [], []
probe_losses, probe_accs, test_losses, test_accs = [], [], [], []

In [15]:
# # resume from ckpt (e.g., if you are resuming from a run that got pre-empted)
# load_progress = False
# if wandb_log:
#     if wandb.run.resumed:
#         load_checkpoint_in_model(model, outdir+"/last")
#         load_progress = True
# elif resume_from_ckpt: # if resuming without using wandb
#     load_checkpoint_in_model(model, outdir+"/last")
#     load_progress = True

In [16]:
if masking_strategy=="MNI":
    from einops.layers.torch import Rearrange
    MNI_brain = nib.load("/weka/proj-fmri/paulscotti/fMRI-foundation-model/dataset_creation/afni_conversion/tpl-MNI152NLin2009cAsym_res-02_T1w_brain.nii.gz").get_fdata()
    brain_pos_voxels = MNI_brain[6:94,8:112,10:82]
    brain_pos_pats = Rearrange(
            "b c (f pf) (d pd) (h ph) (w pw) -> b f d h w (pd ph pw pf c)",
            pd=patch_depth,
            ph=patch_height,
            pw=patch_width,
            pf=1,
        )(torch.Tensor(brain_pos_voxels)[None,None,None])
    brain_pos_pats_vit = rearrange(brain_pos_pats, "b ... d -> b (...) d").mean(-1)[0]

In [18]:
mse = nn.MSELoss()
l1 = nn.L1Loss()
crossentropy = nn.CrossEntropyLoss()
if use_contrastive_loss:
    contrastive_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs)
progress_bar = tqdm(range(epoch, num_epochs), disable=local_rank!=0, desc="Overall")
for epoch in progress_bar:
    # get the masking ratio for the current epoch
    tube_mask_ratio = utils.get_masking_ratio(
        current_epoch=epoch, 
        total_epochs=num_epochs, 
        start_masking_ratio=tube_start_masking_ratio, 
        end_masking_ratio=tube_end_masking_ratio
    )
    with torch.cuda.amp.autocast(dtype=data_type):
        model.train()
        optimizer.train()
        for train_i, batch in enumerate(train_dl):
            optimizer.zero_grad()

            input_func = batch['func.npy']

            subject_id = batch['subject_id.txt']
            subject_id = torch.Tensor([int(subject[-2:]) for subject in subject_id]).long()
            subject_id = torch.repeat_interleave(subject_id.long(), 2).to(device)
            # session_id = batch['session_id.txt']
            # session_id = torch.Tensor([int(session[-2:]) for session in session_id]).long().repeat(2).to(device)
            # session_id = torch.repeat_interleave(session_id.long(), 2)

            if masking_strategy=="None":
                func, _ = aug_transform(input_func)
                brain_pos_pats_vit = torch.ones(num_patches_per_timepoint)
            elif masking_strategy=="MNI":
                func, _ = aug_transform(input_func)
            else:
                func, brain_pos_voxels = aug_transform(input_func)
                brain_pos_pats = model.patchify(torch.Tensor(brain_pos_voxels)[None,None,None])
                brain_pos_pats_vit = rearrange(brain_pos_pats, "b ... d -> b (...) d").mean(-1)[0]

            func = func.reshape(-1, num_frames, 
                                func.shape[-3], func.shape[-2], func.shape[-1])
            func = func.unsqueeze(1).clamp(0,1)
            
            # create encoder and decoder masks
            rand_patches = torch.randperm(num_patches_per_timepoint)
            
            encoder_mask = torch.zeros(num_patches_per_timepoint).to(torch.bool)
            encoder_mask[rand_patches[:num_encoder_patches]] = True
            encoder_mask = encoder_mask.tile(num_frames//frame_patch_size)
            
            decoder_mask = torch.zeros(num_patches_per_timepoint).to(torch.bool)
            decoder_mask[rand_patches[num_encoder_patches:num_encoder_patches+num_decoder_patches]] = True
            decoder_mask = decoder_mask.tile(num_frames//frame_patch_size)

            # encode the tube patches
            encoder_out = model(func, encoder_mask=encoder_mask, device=device)
            if use_cls_token:
                enc_cls_token = encoder_out[:,:1,:]

            # decode both the encoder_out patches and masked decoder patches
            decoder_out = model(encoder_out, encoder_mask=encoder_mask, decoder_mask=decoder_mask, device=device)
            # subset only the reconstructed decoder patches
            output = decoder_out[:, -decoder_mask.sum():]

            # compare to ground truth and calculate loss
            target_patches = model.patchify(func)
            target_patches_vit = rearrange(target_patches, "b ... d -> b (...) d")
            target = target_patches_vit.to(device)[:, decoder_mask]

            target_mean = target.mean(0)
            target_std = target.std(0)
            target_normed = (target - target_mean) / (target_std + 1e-6)

            recon_loss = mse(output, target_normed)
            recon_losses.append(recon_loss.item())
            loss = recon_loss

            # contrastive loss
            if use_contrastive_loss:
                # encode the decoder patches
                encoder_out2 = model(func, encoder_mask=decoder_mask, device=device)
                enc_cls_token2 = encoder_out2[:,:1,:]
                
                temp = contrastive_temps[epoch]
                
                logits = (nn.functional.normalize(enc_cls_token.flatten(1),dim=-1) @
                            nn.functional.normalize(enc_cls_token2.flatten(1),dim=-1).T) / temp
                
                # logits = (nn.functional.normalize(model.cont(encoder_out.flatten(1)),dim=-1) @
                #             nn.functional.normalize(model.cont(encoder_out2.flatten(1)),dim=-1).T) / temp
                
                labels = torch.arange(len(logits)).long().to(device)
                loss1 = crossentropy(logits, labels)
                # loss1 = -(logits.log_softmax(-1) * labels.softmax(-1)).sum(-1).mean()
                loss2 = crossentropy(logits.T, labels)
                contr_loss = (loss1 + loss2)/2
                
                contrastive_losses.append(contr_loss.item())
                loss += (contr_loss * contrastive_loss_weight)

            cos_sim_encoder_output_patchwise.append(utils.patchwise_cosine_similarity(encoder_out).mean().item())
            cos_sim_encoder_output.append(utils.batchwise_cosine_similarity(encoder_out.flatten(1)/1e3,encoder_out.flatten(1)/1e3)[~torch.eye(len(encoder_out),dtype=torch.bool)].mean().item())
            cos_sim_decoder_output.append(utils.batchwise_cosine_similarity(output,output)[~torch.eye(len(output),dtype=torch.bool)].mean().item())

            loss.backward()
            optimizer.step()
            lrs.append(optimizer.param_groups[0]["lr"])
            train_losses.append(loss.item())

            if train_i >= (num_iterations_per_epoch-1):
                break

        # reset linear_probe
        # if use_cls_token:
        #     linear_probe = LinearProbe((num_patches_per_timepoint+1)*model.encoder_embed_dim)
        # else:
        #     linear_probe = LinearProbe(num_patches_per_timepoint*model.encoder_embed_dim)
        linear_probe = LinearProbe(model.encoder_embed_dim)
        linear_probe = linear_probe.to(device)
        probe_opt_grouped_parameters = [
            {'params': [p for n, p in linear_probe.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
            {'params': [p for n, p in linear_probe.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        probe_optimizer = torch.optim.AdamW(probe_opt_grouped_parameters, lr=3e-3)

        if True:#(epoch % 5 == 0) or (epoch == num_epochs-1):
            model.eval()
            optimizer.eval()
            linear_probe.train()
            for probe_i, batch in enumerate(train_dl):
                probe_optimizer.zero_grad()

                input_func = batch['func.npy']

                subject_id = batch['subject_id.txt']
                subject_id = torch.Tensor([int(subject[-2:]) for subject in subject_id]).long()
                subject_id = torch.repeat_interleave(subject_id.long(), 2).to(device)

                func, _ = aug_transform(input_func)
                func = func.reshape(-1, num_frames, 
                                    func.shape[-3], func.shape[-2], func.shape[-1])
                func = func.unsqueeze(1).clamp(0,1)

                encoder_mask = torch.ones(num_patches_per_timepoint).to(torch.bool)
                encoder_mask = encoder_mask.tile(num_frames//frame_patch_size)

                # encode the tube patches
                with torch.no_grad():
                    encoder_out = model(func, encoder_mask=encoder_mask, device=device)
                    encoder_out = encoder_out[:,:1,:]
                    encoder_out = nn.functional.normalize(encoder_out,dim=-1)

                # linear probe
                subject_pred = linear_probe(encoder_out.flatten(1).to(device))
                probe_loss = crossentropy(subject_pred, subject_id-1) # minus 1 because subject_id is 1-indexed

                probe_accuracy = (torch.max(subject_pred,1).indices == (subject_id-1)).sum() / len(subject_id)
                probe_accs.append(probe_accuracy.item())
                probe_losses.append(probe_loss.item())

                print(probe_i, probe_accuracy.item(), probe_loss.item())

                probe_loss.backward()
                probe_optimizer.step()

                if probe_i >= (probe_num_iterations_per_epoch-1):
                    break

            for test_i, batch in enumerate(test_dl):
                input_func = batch['func.npy']

                subject_id = batch['subject_id.txt']
                subject_id = torch.Tensor([int(subject[-2:]) for subject in subject_id]).long()
                subject_id = torch.repeat_interleave(subject_id.long(), 2).to(device)

                func, _ = aug_transform(input_func)
                func = func.reshape(-1, num_frames, 
                                    func.shape[-3], func.shape[-2], func.shape[-1])
                func = func.unsqueeze(1).clamp(0,1)

                encoder_mask = torch.ones(num_patches_per_timepoint).to(torch.bool)
                encoder_mask = encoder_mask.tile(num_frames//frame_patch_size)

                # encode the tube patches
                with torch.no_grad():
                    encoder_out = model(func, encoder_mask=encoder_mask, device=device)
                    encoder_out = encoder_out[:,:1,:]
                    encoder_out = nn.functional.normalize(encoder_out,dim=-1)

                # linear probe
                subject_pred = linear_probe(encoder_out.flatten(1).to(device))
                test_loss = crossentropy(subject_pred, subject_id-1) # minus 1 because subject_id is 1-indexed

                test_accuracy = (torch.max(subject_pred,1).indices == (subject_id-1)).sum() / len(subject_id)
                test_accs.append(test_accuracy.item())
                test_losses.append(test_loss.item())

                print("test", test_i, test_accuracy.item(), test_loss.item())

                if test_i >= 1:
                    break

        logs = {
            "train/loss": np.mean(train_losses[-(train_i + 1) :]),
            "train/recon_losses": np.mean(recon_losses[-(train_i + 1) :]),
            "train/contrastive_losses": np.mean(contrastive_losses[-(train_i + 1) :]),
            "train/num_steps": len(recon_losses),
            "train/cos_sim_encoder_output": np.mean(cos_sim_encoder_output[-(train_i + 1) :]),
            "train/cos_sim_decoder_output": np.mean(cos_sim_decoder_output[-(train_i + 1) :]),
            "train/cos_sim_encoder_output_patchwise": np.mean(cos_sim_encoder_output_patchwise[-(train_i + 1) :]),
            "train/probe_losses": np.mean(probe_losses[-(probe_i + 1) :]),
            "train/probe_accs": np.mean(probe_accs[-(probe_i + 1) :]),
            "test/probe_losses": np.mean(test_losses[-(test_i + 1) :]),
            "test/probe_accs": np.mean(test_accs[-(test_i + 1) :]),
            "lr": np.mean(lrs[-(train_i + 1) :]),
            "epoch": epoch,
            "tube_mask_ratio": tube_mask_ratio,
            "decoder_mask_ratio": decoder_mask_ratio,
        }
        progress_bar.set_postfix(**logs)
        if utils.is_interactive(): print(logs)

        # Plot progress (first sample in batch)
        with torch.no_grad():
            if utils.is_interactive() or wandb_log:
                if epoch % 50 == 0:
                    output = (output * target_std) + target_mean
                    idx = 0
                    
                    decode_vis = torch.zeros_like(target_patches_vit)
                    decode_vis[:, decoder_mask] = output.to(decode_vis.device).to(decode_vis.dtype)
                    decoder_unpatches = rearrange(
                        decode_vis,
                        "b (f d h w) c -> b f d h w c",
                        d=img_size[0]//patch_depth,
                        h=img_size[1]//patch_height,
                        w=img_size[2]//patch_width,
                    )
                    decoder_func = rearrange(
                        decoder_unpatches,
                        "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
                        b=batch_size*2,
                        f=num_frames//frame_patch_size,
                        d=img_size[0]//patch_depth,
                        h=img_size[1]//patch_height,
                        w=img_size[2]//patch_width,
                        pd=patch_depth,
                        ph=patch_height,
                        pw=patch_width,
                        pf=frame_patch_size,
                    )
                    orig_image = utils.reshape_to_2d(func[idx])
                    recon_image = utils.reshape_to_2d(decoder_func[idx])

                    combined_image = orig_image.clone()
                    combined_image[recon_image!=0] = recon_image[recon_image!=0]

                    random_start = np.arange(3100,3450)
                    orig_image = transforms.ToPILImage()(orig_image[:,random_start])
                    recon_image = transforms.ToPILImage()(recon_image[:,random_start])
                    combined_image = transforms.ToPILImage()(combined_image[:,random_start])

                    if wandb_log:
                        logs[f"train/orig"] = wandb.Image(orig_image, caption=f"epoch{epoch:03d}")
                        logs[f"train/recon"] = wandb.Image(recon_image, caption=f"epoch{epoch:03d}")
                        logs[f"train/combined"] = wandb.Image(combined_image, caption=f"epoch{epoch:03d}")
                    else:
                        if epoch==0:
                            print("orig_image")
                            display(orig_image)
                            print("recon_image")
                            display(recon_image)
                            print("combined_image")
                        display(combined_image)

    if wandb_log: wandb.log(logs)

    # Save model checkpoint
    if (ckpt_saving) and ((epoch % ckpt_interval == 0) or (epoch==num_epochs-1)):
        save_ckpt(model,"last")

    # wait for other GPUs to catch up if needed
    if distributed: dist.barrier()
    torch.cuda.empty_cache()
        
if distributed:
    dist.destroy_process_group()

Overall:   0%|          | 0/20 [00:00<?, ?it/s]

0 0.078125 2.080991744995117
test 0 0.146484375 2.087881088256836
test 1 0.033203125 2.1139793395996094
{'train/loss': 0.6546683609485626, 'train/recon_losses': nan, 'train/contrastive_losses': 6.546683311462402, 'train/num_steps': 0, 'train/cos_sim_encoder_output': 0.6011962890625, 'train/cos_sim_decoder_output': nan, 'train/cos_sim_encoder_output_patchwise': 0.2939453125, 'train/probe_losses': 2.080991744995117, 'train/probe_accs': 0.078125, 'test/probe_losses': 2.1009302139282227, 'test/probe_accs': 0.08984375, 'lr': 3e-06, 'epoch': 0, 'tube_mask_ratio': 0.75, 'decoder_mask_ratio': 0.75}


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


0 0.2265625 2.132457733154297
test 0 0.10546875 2.0733165740966797
test 1 0.1875 1.9566822052001953
{'train/loss': 0.646026149392128, 'train/recon_losses': nan, 'train/contrastive_losses': 6.460261344909668, 'train/num_steps': 0, 'train/cos_sim_encoder_output': 0.6077880859375, 'train/cos_sim_decoder_output': nan, 'train/cos_sim_encoder_output_patchwise': 0.294189453125, 'train/probe_losses': 2.132457733154297, 'train/probe_accs': 0.2265625, 'test/probe_losses': 2.0149993896484375, 'test/probe_accs': 0.146484375, 'lr': 3e-06, 'epoch': 1, 'tube_mask_ratio': 0.75, 'decoder_mask_ratio': 0.75}
0 0.123046875 2.136058807373047
test 0 0.119140625 2.041532516479492
test 1 0.275390625 1.9539203643798828
{'train/loss': 0.6089650392532349, 'train/recon_losses': nan, 'train/contrastive_losses': 6.0896501541137695, 'train/num_steps': 0, 'train/cos_sim_encoder_output': 0.609375, 'train/cos_sim_decoder_output': nan, 'train/cos_sim_encoder_output_patchwise': 0.3001708984375, 'train/probe_losses': 2.13

In [None]:
encoder_mask.device

In [None]:
decoder_mask.sum()

In [None]:
plt.figure(figsize=(8, 3))
plt.plot(probe_losses)
# plt.title("Training re-construction losses")
plt.show()
if use_contrastive_loss:
    plt.figure(figsize=(8, 3))
    plt.plot(contrastive_losses)
    plt.title("Training contrastive losses")
    plt.show()