In [12]:
import os
import sys
import json
import yaml
import numpy as np
import pandas as pd
import copy
import math
from einops import rearrange
from einops.layers.torch import Rearrange
import time
import random
import h5py
import webdataset as wds
import gc
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import utils
from models import *
from accelerate import Accelerator, load_checkpoint_in_model

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

### Multi-GPU config ###
local_rank = os.getenv('RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)  

num_devices = os.getenv('NUM_GPUS')
if num_devices is None: 
    num_devices = 1
else:
    num_devices = int(num_devices)
print("NUM GPUS ", num_devices)

if utils.is_interactive():
    # Following allows you to change functions in models.py or utils.py and 
    # have this notebook automatically update with your revisions
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

LOCAL RANK  0
NUM GPUS  1
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# Load parameters from yaml config
config = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)

# create global variables from the config
for attribute_name in config.keys():
    globals()[attribute_name] = config[f'{attribute_name}']

# First use "accelerate config" in terminal for setup
data_type = torch.float16 # change depending on your mixed_precision
accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
batch_size = global_batch_size // num_devices

In [14]:
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
world_size = accelerator.state.num_processes
distributed = not accelerator.state.distributed_type == 'NO'
print(accelerator.state)

print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)
print = accelerator.print # only print if local_rank=0

PID of this process = 1557848
device: cuda
Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16


In [15]:
print(config)

# seed all random functions
utils.seed_everything(seed)

outdir = os.path.abspath(f'../ckpts/{model_name}')
print("outdir", outdir)

cache_dir = cache_dir + f'/{np.random.randint(9999)}' # create random subfolder so multiple runs arent using same directory
os.makedirs(cache_dir,exist_ok=True)
print("cache_dir", cache_dir)

if use_contrastive_loss:
    global_batch_size = global_batch_size // 2 # contrastive loss doubles the batch size with the same samples and different masks
print("global_batch_size", global_batch_size)

use_cls_token = True if use_contrastive_loss else use_cls_token
print("use_cls_token", use_cls_token)

num_patches = int(
    (img_size[0] / patch_size)
    * (img_size[1] / patch_size)
    * (img_size[2] / patch_size)
    * num_frames
)
num_patches_per_timepoint = num_patches // num_frames
num_encoder_patches = int(num_patches_per_timepoint * (1 - tube_start_masking_ratio) * num_frames)
num_decoder_patches = int(num_patches_per_timepoint * (1 - decoder_mask_ratio) * num_frames)
print("num_patches", num_patches)
print("num_encoder_patches", num_encoder_patches)
print("num_decoder_patches", num_decoder_patches)

{'model_name': 'patch8_100eps_4gpu_accelerate_b', 'use_cls_token': False, 'use_contrastive_loss': False, 'constrastive_loss_weight': 1.0, 'global_batch_size': 8, 'num_workers': 4, 'num_epochs': 100, 'seed': 42, 'max_lr': 3e-05, 'num_samples_per_epoch': 1024, 'cache_dir': 'cache/', 'ema': [0.998, 1.0], 'ipe_scale': 1.25, 'ckpt_saving': True, 'ckpt_interval': 50, 'resume_from_ckpt': False, 'wandb_log': True, 'tube_start_masking_ratio': 0.75, 'tube_end_masking_ratio': 0.75, 'decoder_mask_ratio': 0.75, 'depth': 12, 'heads': 12, 'dim': 512, 'mlp_dim': 512, 'patch_size': 8, 'frame_patch_size': 1, 'use_rope_emb': False, 'img_size': [64, 64, 48], 'num_frames': 4, 'train_urls': 's3://proj-fmri/fmri_foundation_datasets/openneuro/{000005..000664}.tar', 'test_urls': 's3://proj-fmri/fmri_foundation_datasets/openneuro/{000000..000004}.tar'}
outdir /weka/proj-fmri/ks9249/fMRI-foundation-model/ckpts/patch8_100eps_4gpu_accelerate_b
cache_dir cache//7270
global_batch_size 8
use_cls_token False
num_patch

In [16]:
model = SimpleViT(
    image_size=img_size,  # depth, height, width
    image_patch_size=(patch_size,patch_size,patch_size),  # depth, height, width patch size
    frames=num_frames,
    frame_patch_size=frame_patch_size,
    depth=depth,
    heads=heads,
    dim=dim,
    mlp_dim=mlp_dim,  # TODO: right now dim needs to equal mlp_dim, and both need to be 512
    channels=1,
    use_rope_emb=use_rope_emb,
    use_cls_token=use_cls_token,
)
utils.count_params(model)

# test that the model works without error
model = model.to(device)
inputdata = torch.randn(6, 1, 4, 64, 64, 48).to(device)
encoder_mask = torch.zeros(num_patches).to(device).to(torch.bool)
encoder_mask[:num_encoder_patches] = True
maskedtokens = ~encoder_mask
print(encoder_mask)
print(maskedtokens)

with torch.no_grad():
    print("\nxencoder")
    xencoder_out = model(
                inputdata,
                encoder_mask=encoder_mask,
                encoder_type = "x",
                verbose=True)
    print("\npredictor")
    predictor_out = model(
                xencoder_out, 
                encoder_mask=encoder_mask, 
                encoder_type = "p",
                verbose=True)
    print("\nyencoder")
    yencoderout = model(
                inputdata, 
                encoder_mask=encoder_mask, 
                encoder_type = "y",
                verbose=True)
    if use_cls_token:
        enc_cls_token = xencoder_out[:, :1, :]
        encoder_patches = xencoder_out[:, 1:, :]
        pred_cls_token = predictor_out[:, :1, :]
        predictor_patches = predictor_out[:, 1:, :]
        print("\nenc_cls_token", enc_cls_token.shape)
        print("encoder_patches", encoder_patches.shape)
        print("pred_cls_token", pred_cls_token.shape)
        print("predictor_patches", predictor_patches.shape)

param counts:
76,138,496 total
76,138,496 trainable
tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0')
tensor([False, False, False,  ...,  True,  True,  True], device='cuda:0')

xencoder
input shape torch.Size([6, 1, 4, 64, 64, 48])
after patching torch.Size([6, 4, 8, 8, 6, 512])
convert to embedding torch.Size([6, 4, 8, 8, 6, 512])
flattening torch.Size([6, 1536, 512])
positional embedding torch.Size([1536, 512])
current shape torch.Size([6, 1536, 512])
after masking torch.Size([6, 384, 512])
final shape torch.Size([6, 384, 512])

predictor
input shape torch.Size([6, 384, 512])
positional embedding torch.Size([1536, 512])
masked tokens torch.Size([1152, 512])
concatenation torch.Size([6, 1536, 512])
after transformer torch.Size([6, 1152, 512])

yencoder
input shape torch.Size([6, 1, 4, 64, 64, 48])
after patching torch.Size([6, 4, 8, 8, 6, 512])
convert to embedding torch.Size([6, 4, 8, 8, 6, 512])
flattening torch.Size([6, 1536, 512])
positional embedding torch

In [17]:
aug_transform = utils.DataPrepper(
    masking_strategy="conservative",
    patch_depth=patch_size,
    patch_height=patch_size,
    patch_width=patch_size,
    frame_patch_size=frame_patch_size,
)

def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

def filter_corrupted_images(sample):
    """If all the required files are not present don't use them."""
    correct_data = ("func.png" in sample and "dataset.txt" in sample and "header.npy" in sample and "meansd.png" in sample and "minmax.npy" in sample)
    return correct_data

### ================      Train Dataset and DataLoader    ====================
if train_urls[:2] == "s3":
    train_urls = f"pipe:aws s3 cp {train_urls} -"
print(train_urls)
train_data = (
    wds.WebDataset(train_urls, resampled=True, nodesplitter=my_split_by_node, cache_dir=cache_dir, handler=log_and_continue)
    .shuffle(100, initial=100, rng=random.Random(seed))
    .select(filter_corrupted_images)
    .rename(key="__key__",
        func="func.png",
        header="header.npy",
        dataset="dataset.txt",
        minmax="minmax.npy",
        meansd="meansd.png")
    .map_dict(func=utils.grayscale_decoder,
        meansd=utils.grayscale_decoder,
        minmax=utils.numpy_decoder)
    .to_tuple(*("func", "minmax", "meansd"))
    .map(aug_transform)
    .with_epoch(num_samples_per_epoch)
)
train_dl = wds.WebLoader(
    train_data.batched(batch_size), 
    pin_memory=True,
    shuffle=False,
    batch_size=None,
    num_workers=num_workers, 
    persistent_workers=num_workers>0,
).with_epoch(num_samples_per_epoch//batch_size)

### ================      Test Dataset and DataLoader    ====================
if test_urls[:2] == "s3":
    test_urls = f"pipe:aws s3 cp {test_urls} -"
print(test_urls)
test_data = (
    wds.WebDataset(test_urls, resampled=False, nodesplitter=my_split_by_node, cache_dir=cache_dir, handler=log_and_continue)
    .select(filter_corrupted_images)
    .rename(key="__key__",
        func="func.png",
        header="header.npy",
        dataset="dataset.txt",
        minmax="minmax.npy",
        meansd="meansd.png")
    .map_dict(func=utils.grayscale_decoder,
        meansd=utils.grayscale_decoder,
        minmax=utils.numpy_decoder)
    .to_tuple(*("func", "minmax", "meansd"))
    .map(aug_transform)
)
test_dl = wds.WebLoader(
    test_data.batched(batch_size), 
    pin_memory=True,
    shuffle=False,
    batch_size=None,
    num_workers=num_workers,
    persistent_workers=num_workers>0,
)

pipe:aws s3 cp s3://proj-fmri/fmri_foundation_datasets/openneuro/{000005..000664}.tar -
pipe:aws s3 cp s3://proj-fmri/fmri_foundation_datasets/openneuro/{000000..000004}.tar -


In [18]:
from accelerate.state import AcceleratorState
try:
    AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = global_batch_size
    print("deepspeed reconfigured, train_micro_batch_size_per_gpu = ", global_batch_size)
except:
    print("skipping deepspeed reconfiguration...")

skipping deepspeed reconfiguration...


In [19]:
num_iterations_per_epoch = num_samples_per_epoch // batch_size
print(f"num_iterations_per_epoch {num_iterations_per_epoch}")

train_batch = next(iter(train_dl))
func, meansd, brain_pos_pats = train_batch
print("Train batch:", func.shape, meansd.shape, brain_pos_pats.shape)

for test_num_iterations_per_epoch, test_batch in enumerate(test_dl):
    pass
print(f"test_num_iterations_per_epoch {test_num_iterations_per_epoch}")
func, meansd, brain_pos_pats = test_batch
print("Test batch:", func.shape, meansd.shape, brain_pos_pats.shape)

num_iterations_per_epoch 128
Train batch: torch.Size([8, 4, 64, 64, 48]) torch.Size([8, 2, 64, 64, 48]) torch.Size([8, 1536])
test_num_iterations_per_epoch 199
Test batch: torch.Size([8, 4, 64, 64, 48]) torch.Size([8, 2, 64, 64, 48]) torch.Size([8, 1536])


In [22]:
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
opt_grouped_parameters = [
    {'params': [p for n, p in model.xencoder_transformer.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.predictor_transformer.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
] #we only update the xencoder and predictor, the yencoder is updated through a moving average of the xencoder

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

total_steps = num_epochs * num_iterations_per_epoch * num_devices
print("total_steps", total_steps)

momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(num_iterations_per_epoch*num_epochs*ipe_scale)
                          for i in range(int(num_iterations_per_epoch*num_epochs*ipe_scale)+1))

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=total_steps,
)

print("\nDone with model preparations!")
num_params = utils.count_params(model)

total_steps 12800

Done with model preparations!
param counts:
76,138,496 total
76,138,496 trainable


In [23]:
def save_ckpt(tag="last"):
    ckpt_path = outdir+f'/{tag}'
    os.makedirs(ckpt_path,exist_ok=True)
    accelerator.save_model(model, ckpt_path, max_shard_size="2GB", safe_serialization=True)
    print(f"\n---saved {ckpt_path}!---\n")
        
def save_progress(tag="last"):
    if accelerator.is_main_process:
        ckpt_path = outdir+f'/{tag}'
        torch.save(
                {
                    "optimizer": optimizer.state_dict(),
                    "scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                    "recon_losses": recon_losses,
                    "contrastive_losses": contrastive_losses,
                    "test_losses": test_losses,
                    "lrs": lrs,
                },
                os.path.join(ckpt_path, f"params.pt"),
            )

In [None]:
if accelerator.is_main_process and wandb_log: # only use main process for wandb logging
    import wandb
    wandb_project = 'found'
    print(f"wandb {wandb_project} run {model_name}")
    # need to configure wandb beforehand in terminal with "wandb init"!
    wandb_config = {
      "model_name": model_name,
      "global_batch_size": global_batch_size,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "num_samples_per_epoch": num_samples_per_epoch,
      "depth": depth,
      "heads": heads,
      "dim": dim,
      "mlp_dim": mlp_dim,
      "tube_start_masking_ratio": tube_start_masking_ratio,
      "tube_end_masking_ratio": tube_end_masking_ratio,
      "decoder_mask_ratio": decoder_mask_ratio,
      "num_frames": num_frames,
      "patch_size": patch_size,
      "frame_patch_size": frame_patch_size,
      "use_contrastive_loss": use_contrastive_loss,
      "use_cls_token": use_cls_token,
      "constrastive_loss_weight": constrastive_loss_weight,
      "num_params": num_params,
      "max_lr": max_lr,
      "ckpt_interval": ckpt_interval,
      "ckpt_saving": ckpt_saving,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
      "train_urls": train_urls,
      "test_urls": test_urls,
    }
    print("wandb_config:\n",wandb_config)
    print("wandb_id:",model_name)
    wandb.init(
        id=model_name,
        project=wandb_project,
        name=model_name,
        config=wandb_config,
        resume="allow",
    )
else:
    wandb_log = False

# Start training

In [None]:
epoch = 0
lrs, recon_losses, contrastive_losses, test_losses = [], [], [], []
best_test_loss = 1e9
torch.cuda.empty_cache()

In [None]:
# resume from ckpt (e.g., if you are resuming from a run that got pre-empted)
load_progress = False
if wandb_log:
    if wandb.run.resumed:
        load_checkpoint_in_model(model, outdir+"/last")
        load_progress = True
elif resume_from_ckpt: # if resuming without using wandb
    load_checkpoint_in_model(model, outdir+"/last")
    load_progress = True
    
if load_progress:
    ckpt_path = outdir+'/last'
    prev_params = torch.load(ckpt_path+"/params.pt")
    optimizer.load_state_dict(prev_params["optimizer"])
    lr_scheduler.load_state_dict(prev_params["scheduler"])
    epoch = prev_params["epoch"]
    recon_losses = prev_params["recon_losses"]
    contrastive_losses = prev_params["contrastive_losses"]
    test_losses = prev_params["test_losses"]
    lrs = prev_params["lrs"]
    for _ in range(epoch * num_iterations_per_epoch):
            next(momentum_scheduler)
    print("Loaded model params from", ckpt_path, "at epoch", epoch)

In [None]:
model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dl, lr_scheduler
)

In [None]:
l1 = nn.L1Loss() #Following VJEPA architecture, which uses L1 loss not L2 loss

progress_bar = tqdm(range(epoch, num_epochs), disable=not accelerator.is_main_process, desc="Overall")
for epoch in progress_bar:
    # get the masking ratio for the current epoch
    tube_mask_ratio = utils.get_masking_ratio(
        current_epoch=epoch, 
        total_epochs=num_epochs, 
        start_masking_ratio=tube_start_masking_ratio, 
        end_masking_ratio=tube_end_masking_ratio
    )
    with torch.cuda.amp.autocast(dtype=data_type):
        model.train()
        for train_i, batch in enumerate(tqdm(train_dl, disable=not accelerator.is_main_process, 
                 total=num_iterations_per_epoch, leave=False, desc="Training")):
            optimizer.zero_grad()

            func, meansd, brain_pos_pats = batch

            func = func.unsqueeze(1).to(device)

            # create tube mask (i.e., a mask that is the same for all frames/timepoints)
            tube_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
            batch_positive_approx = (brain_pos_pats[:, :num_patches // num_frames].float().mean(dim=0) > 0)
            mask_idx_candidates = torch.where(batch_positive_approx)[0]
            # check if there's not enough brain left for code to continue
            if len(mask_idx_candidates) < (int(num_patches/num_frames*(1-tube_mask_ratio))+int(num_patches/num_frames*(1-decoder_mask_ratio))):
                print("Brain volume skipped due to not enough brain-positive patches remaining...")
                continue
            mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]
            tube_idx = mask_idx_candidates[:int(num_patches / num_frames * (1 - tube_mask_ratio))]
            tube_mask[tube_idx] = True
            tube_mask = tube_mask.tile(num_frames)

            # feed into x-encoder
            xencoder_out = model(func, encoder_mask=tube_mask, encoder_type = "x")
            
            # feed output of x-encoder into predictor
            predictor_out = model(xencoder_out, encoder_mask=tube_mask, encoder_type="p")
            
            # feed entire func into y-encoder
            yencoder_out = model(func, encoder_mask=tube_mask, encoder_type = "y")
            
            # compare output of predictor to output of y-encoder and calculate L1 Loss
            loss = l1(predictor_out,yencoder_out)
            
            # backwards + step
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            recon_losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]["lr"])
            
            # update y-encoder using exponential-moving average of x-encoder params to prevent collapse
            m = next(momentum_scheduler)
            with torch.no_grad():
                for param_q, param_k in zip(model.xencoder_transformer.parameters(), model.yencoder_transformer.parameters()):
                    param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)
                    
        model.eval()
        for test_i, batch in enumerate(tqdm(test_dl, disable=not accelerator.is_main_process, total=test_num_iterations_per_epoch, leave=False, desc="Testing")):
            func, meansd, brain_pos_pats = batch
            func = func.unsqueeze(1).to(device)

            # create tube mask (i.e., a mask that is the same for all frames/timepoints)
            tube_mask = torch.zeros(num_patches // num_frames).to(torch.bool)
            batch_positive_approx = brain_pos_pats[:, :num_patches // num_frames].float().mean(dim=0) > 0
            mask_idx_candidates = torch.where(batch_positive_approx)[0]
            # check if there's not enough brain left for code to continue
            if len(mask_idx_candidates) < (int(num_patches/num_frames*(1-tube_mask_ratio))+int(num_patches/num_frames*(1-decoder_mask_ratio))):
                if test_i==0:
                    print("Brain volume skipped due to not enough brain-positive patches remaining...")
                continue
            mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]
            tube_idx = mask_idx_candidates[:int(num_patches / num_frames * (1 - tube_mask_ratio))]
            tube_mask[tube_idx] = True
            tube_mask = tube_mask.tile(num_frames)
            
            # feed into x-encoder
            xencoder_out = model(func, encoder_mask=tube_mask, encoder_type = "x")
            
            # feed output of x-encoder into predictor
            predictor_out = model(xencoder_out, encoder_mask=tube_mask, encoder_type="p")
            
            # feed entire func into y-encoder
            yencoder_out = model(func, encoder_mask=tube_mask, encoder_type = "y")

            # compare output of predictor to output of y-encoder and calculate L1 Loss
            loss = l1(predictor_out,yencoder_out)
            test_losses.append(loss.item())

        logs = {
            "train/loss": np.mean(recon_losses[-(train_i + 1) :]),
            "test/loss": np.mean(test_losses[-(test_i + 1) :]),
            "train/num_steps": len(recon_losses),
            "test/num_steps": len(test_losses),
            "lr": np.mean(lrs[-(train_i + 1) :]),
            "epoch": epoch,
            "tube_mask_ratio": tube_mask_ratio,
        }
        progress_bar.set_postfix(**logs)
        if wandb_log: wandb.log(logs)
                
        # Save model checkpoint
        if (ckpt_saving) and ((epoch % ckpt_interval == 0) or (epoch==num_epochs-1)):
            save_ckpt()
            save_progress()
            
        # wait for other GPUs to catch up if needed
        accelerator.wait_for_everyone()
        torch.cuda.empty_cache()
        gc.collect()
        
# remove cache directories
os.system('rm -fr "%s"' % f"{cache_dir}")