In [None]:
import os
import sys
import json
import argparse
import numpy as np
import copy
import math
from einops import rearrange
from einops.layers.torch import Rearrange
import time
import random
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchio as tio
import nibabel as nib
import utils
from models import *

from accelerate import Accelerator, DeepSpeedPlugin

# Allow TF32 for CUDA matmuls: faster than standard float32 on GPUs that support it,
# at the cost of reduced matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True

# Notebook-only convenience. utils.is_interactive() is presumably True when running under
# Jupyter/IPython -- confirm in utils.py.
if utils.is_interactive():
    %load_ext autoreload
    # autoreload mode 2 reloads edited modules (e.g. models.py, utils.py) before each cell runs,
    # so changes to those files take effect in this notebook without restarting the kernel
    %autoreload 2

In [None]:
### class token config ###
use_cls_token = True

### Loss Config ###
use_contrastive_loss = True
# (sic) the misspelled name is kept because the training loop refers to it
constrastive_loss_weight = 1.0
# the contrastive loss is computed on the class tokens, so enabling it forces the class token on
use_cls_token = use_cls_token or use_contrastive_loss

### Multi-GPU config ###
# NOTE(review): this reads RANK, which launchers conventionally set to the *global* rank;
# LOCAL_RANK is the per-node index. Confirm which one is intended before multi-node runs.
local_rank = int(os.getenv("RANK", 0))
print("LOCAL RANK", local_rank)

# treat a CPU-only machine as a single device
num_devices = max(torch.cuda.device_count(), 1)

accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
global_batch_size = 8
if use_contrastive_loss:
    # The training loop duplicates every batch to form the contrastive pairs, so only half as
    # many distinct samples are loaded. Integer division keeps the batch size an int
    # (true division would turn it into the float 4.0).
    global_batch_size //= 2
print("GLOBAL BATCH SIZE", global_batch_size)

print("PID of this process =", os.getpid())
device = accelerator.device
print("device:", device)
num_workers = num_devices
print(accelerator.state)
world_size = accelerator.state.num_processes
distributed = accelerator.state.distributed_type != "NO"
print(
    "distributed =",
    distributed,
    "num_devices =",
    num_devices,
    "local rank =",
    local_rank,
    "world size =",
    world_size,
)
# From here on `print` is accelerate's process-aware print (avoids duplicated output
# when several processes run this notebook/script).
print = accelerator.print

# set data_type to match the accelerator's mixed precision mode
data_type = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
}.get(accelerator.mixed_precision, torch.float32)

# Prep models

In [None]:
# Per-device batch size derived from the global batch size.
batch_size = int(global_batch_size // num_devices)
if batch_size < 1:
    # fail here with a clear message instead of later inside the DataLoader constructor
    raise ValueError(
        f"global_batch_size={global_batch_size} is too small for num_devices={num_devices}: "
        "per-device batch size would be 0"
    )
print("batch_size", batch_size)
num_epochs = 30
tube_mask_ratio = 0.75  # fraction of patches hidden from the encoder
decoder_mask_ratio = 0.75  # fraction of patches not reconstructed by the decoder
img_size = (64, 64, 48)
input_size = list(img_size)
print("input_size", input_size)
seed = 42
num_frames = 4
tubelet_size = 1

patch_size = 8
frame_patch_size = 1
if any(size % patch_size for size in img_size):
    # a non-divisible volume would otherwise be silently truncated to a wrong patch count
    raise ValueError(f"img_size={img_size} must be divisible by patch_size={patch_size}")
# number of spatial patches in one timepoint, and in the whole clip
num_patches_per_timepoint = math.prod(size // patch_size for size in img_size)
num_patches = num_patches_per_timepoint * num_frames
# patches kept visible for the encoder / reconstructed by the decoder, over all frames
num_encoder_patches = int(
    num_patches_per_timepoint * (1 - tube_mask_ratio) * num_frames
)
num_decoder_patches = int(
    num_patches_per_timepoint * (1 - decoder_mask_ratio) * num_frames
)
print("num_patches", num_patches)
print("num_encoder_patches", num_encoder_patches)
print("num_decoder_patches", num_decoder_patches)

max_lr = 3e-5  # 3e-5 seems to be working best? original videomae used 1.5e-4
train_urls = "s3://proj-fmri/fmri_foundation_datasets/openneuro/{000001..000664}.tar"
test_urls = "s3://proj-fmri/fmri_foundation_datasets/openneuro/000000.tar"
num_samples_per_epoch = 512

In [None]:
model = SimpleViT(
    image_size=img_size,  # depth, height, width
    image_patch_size=(
        patch_size,
        patch_size,
        patch_size,
    ),  # depth, height, width patch size
    frames=num_frames,
    frame_patch_size=frame_patch_size,
    depth=12,
    heads=12,
    dim=512,
    mlp_dim=512,  # TODO: right now dim needs to equal mlp_dim, and both need to be 512
    num_encoder_patches=num_encoder_patches,
    num_decoder_patches=num_decoder_patches,
    channels=1,
    use_rope_emb=False,
    use_cls_token=use_cls_token,
)
utils.count_params(model)

# Smoke test: run one encoder pass and one decoder pass on random data.
model = model.to(device)
# True marks a patch that is included; the first patches go to the encoder ...
encoder_mask = torch.zeros(num_patches, dtype=torch.bool, device=device)
encoder_mask[:num_encoder_patches] = True
# ... and the last patches go to the decoder
decoder_mask = torch.zeros(num_patches, dtype=torch.bool, device=device)
decoder_mask[-num_decoder_patches:] = True
# dummy batch shaped (batch, channel, frames, *img_size), built from the configured sizes
# so the check keeps working when num_frames / img_size are changed
dummy_func = torch.randn(6, 1, num_frames, *img_size).to(device)
with torch.no_grad():
    print("\nencoder")
    encoder_out = model(
        dummy_func,
        encoder_mask=encoder_mask,
        verbose=True,
    )
    print("\ndecoder")
    decoder_out = model(
        encoder_out, encoder_mask=encoder_mask, decoder_mask=decoder_mask, verbose=True
    )
    if use_cls_token:
        # the class token occupies the first sequence position of both outputs
        enc_cls_token = encoder_out[:, :1, :]
        encoder_patches = encoder_out[:, 1:, :]
        dec_cls_token = decoder_out[:, :1, :]
        decoder_patches = decoder_out[:, 1:, :]
        print("enc_cls_token", enc_cls_token.shape)
        print("encoder_patches", encoder_patches.shape)
        print("dec_cls_token", dec_cls_token.shape)
        print("decoder_patches", decoder_patches.shape)

In [None]:
def my_split_by_node(urls):
    """Node splitter that does no splitting.

    Passed as ``nodesplitter`` to ``wds.WebDataset`` so that the dataset uses the
    shard list exactly as given.

    Args:
        urls: the shard URLs handed in by the dataset.

    Returns:
        The very same ``urls`` object, untouched.
    """
    return urls


# transform applied to every decoded (func, minmax, meansd) tuple
aug_transform = utils.DataPrepper(
    masking_strategy="conservative",
    patch_depth=patch_size,
    patch_height=patch_size,
    patch_width=patch_size,
    frame_patch_size=frame_patch_size,
)


def _as_pipe_url(urls):
    """Wrap an s3 URL spec in an `aws s3 cp` pipe command; return anything else unchanged."""
    if urls[:2] == "s3":
        return f"pipe:aws s3 cp {urls} -"
    return urls


def _finish_pipeline(dataset):
    """Append the rename / decode / tuple / augment / epoch stages shared by both splits."""
    renamed = dataset.rename(
        key="__key__",
        func="func.png",
        header="header.npy",
        dataset="dataset.txt",
        minmax="minmax.npy",
        meansd="meansd.png",
    )
    decoded = renamed.map_dict(
        func=utils.grayscale_decoder,
        meansd=utils.grayscale_decoder,
        minmax=utils.numpy_decoder,
    )
    samples = decoded.to_tuple("func", "minmax", "meansd").map(aug_transform)
    return samples.with_epoch(num_samples_per_epoch)


# options common to the train and test loaders
_loader_kwargs = dict(
    batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True
)

# training split: resampled shards with a seeded shuffle buffer
train_urls = _as_pipe_url(train_urls)
print(train_urls)
train_data = _finish_pipeline(
    wds.WebDataset(
        train_urls, resampled=True, nodesplitter=my_split_by_node
    ).shuffle(100, initial=100, rng=random.Random(seed))
)
train_dl = torch.utils.data.DataLoader(train_data, **_loader_kwargs)

# test split: single pass over the shards, no shuffling
test_urls = _as_pipe_url(test_urls)
print(test_urls)
test_data = _finish_pipeline(
    wds.WebDataset(test_urls, resampled=False, nodesplitter=my_split_by_node)
)
test_dl = torch.utils.data.DataLoader(test_data, **_loader_kwargs)

## test that data loaders work and calculate number of iterations per epoch

In [None]:
# Iterate the training loader once to verify it works and to count batches per epoch.
train_samp = 0
train_i = -1  # stays -1 if the loader yields nothing
for train_i, out in enumerate(train_dl):
    train_samp += len(out[0])
# enumerate is zero-based, so the number of iterations is the last index plus one
num_iterations_per_epoch = train_i + 1
if num_iterations_per_epoch == 0:
    raise RuntimeError("train_dl yielded no batches; check train_urls and batch_size")
print("num_iterations_per_epoch", num_iterations_per_epoch, "\n")

func, meansd, brain_pos_pats = out
print(func.shape, meansd.shape, brain_pos_pats.shape)

# Pull a few test batches; `out` is left holding a test batch for the visualization cell below.
# test_samp = 0
for test_i, out in enumerate(test_dl):
    if test_i > 5:
        break
    # test_samp += len(out[0])

# Playing with the data, visualization of patching + masking

In [None]:
# Visual sanity check of patchification and of the encoder ("tube") / decoder masks.

# extract func volumes and their reference mean and standard deviation volumes
func, meansd, brain_pos_pats = out
func = func.unsqueeze(1)  # add empty first dimension to serve as 1d channel dimension

# patchify func samples
print("func", func.shape)
patches = model.patchify(func)
print("patches", patches.shape)

# compress into ViT compatible inputs (bs x seq_len x emb_size)
patches_vit = rearrange(patches, "b ... d -> b (...) d")
print("patches_vit", patches_vit.shape)
print("num patches in one timepoint", patches_vit.shape[1] // num_frames)

# per-timepoint counts used to build the masks
patches_per_frame = num_patches // num_frames
num_tube_per_frame = int(num_patches / num_frames * (1 - tube_mask_ratio))
num_decoder_per_frame = int(num_patches / num_frames * (1 - decoder_mask_ratio))
# Number of patches along each spatial axis. These were previously passed as `patch_size`
# and the literals 8/8/6, which only worked because 64 / 8 happens to equal the patch size.
grid_d, grid_h, grid_w = (size // patch_size for size in img_size)

tube_mask = torch.zeros(patches_per_frame).to(
    torch.bool
)  # start by masking everything (aka include nothing)
batch_positive_approx = (
    brain_pos_pats[:, :patches_per_frame].float().mean(dim=0) > 0
)  # approximate brain positive patches for the whole batch
mask_idx_candidates = torch.where(batch_positive_approx)[0]
mask_idx_candidates = mask_idx_candidates[torch.randperm(len(mask_idx_candidates))]
print(
    "Percentage of brain positive patches",
    len(mask_idx_candidates) / len(batch_positive_approx),
)
tube_idx = mask_idx_candidates[:num_tube_per_frame]
print("num tube patches =", len(tube_idx))
tube_mask[
    tube_idx
] = True  # Trues mean to include the patch, False means to remove the patch
tube_mask = tube_mask.tile(num_frames)  # repeat masking for the other timepoints
print("tube mask percent", tube_mask.sum().item() / len(tube_mask))


# create decoder mask similar to tube mask, but ensure no overlap
decoder_mask = torch.zeros(patches_per_frame).to(
    torch.bool
)  # start by masking everything (aka include nothing)
remaining_mask_idx = mask_idx_candidates[
    num_tube_per_frame:
]  # brain positive tokens not selected for the encoder tokens
decoder_mask_idx = remaining_mask_idx[:num_decoder_per_frame]
print("num decoder patches =", len(decoder_mask_idx))
decoder_mask[decoder_mask_idx] = True
decoder_mask = decoder_mask.tile(num_frames)  # repeat masking for the other timepoints
print("decoder_mask percent", decoder_mask.sum().item() / len(decoder_mask))

# apply masks to copies of patches_vit: excluded patches are zeroed out
tube_patches_vit = patches_vit.clone()
decoder_patches_vit = patches_vit.clone()
tube_patches_vit[:, ~tube_mask] = 0.0
decoder_patches_vit[:, ~decoder_mask] = 0.0

# undo patchification so we can visualize
# step 1: split the flat token axis back into (frames, depth, height, width) patch grid
tube_unpatches = rearrange(
    tube_patches_vit,
    "b (f d h w) c -> b f d h w c",
    f=num_frames,
    d=grid_d,
    h=grid_h,
    w=grid_w,
)
decoder_unpatches = rearrange(
    decoder_patches_vit,
    "b (f d h w) c -> b f d h w c",
    f=num_frames,
    d=grid_d,
    h=grid_h,
    w=grid_w,
)
print("tube_unpatches", tube_unpatches.shape)
print("decoder_unpatches", decoder_unpatches.shape)
# step 2: unfold every patch vector into its voxels and stitch the volume back together
unpatch_pattern = "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)"
unpatch_sizes = dict(
    b=len(func),
    f=num_frames,
    d=grid_d,
    h=grid_h,
    w=grid_w,
    pd=patch_size,
    ph=patch_size,
    pw=patch_size,
    pf=frame_patch_size,
)
tube_func = rearrange(tube_unpatches, unpatch_pattern, **unpatch_sizes)
decoder_func = rearrange(decoder_unpatches, unpatch_pattern, **unpatch_sizes)
print("tube_func", tube_func.shape)
print("decoder_func", decoder_func.shape)

idx = 0
# NOTE(review): the volumes below are restored as `x * mean + sd`; whether that matches the
# normalization done in utils.DataPrepper cannot be verified from this notebook -- confirm.
mean, sd = meansd[idx]
brain_pos_pat = brain_pos_pats[idx]

print("original func without adding mean/sd references")
display(transforms.ToPILImage()(utils.reshape_to_2d(func[idx])))
print("original func")
display(transforms.ToPILImage()(utils.reshape_to_2d(func[idx] * mean + sd)))
print(
    f"Brain positive patches: {brain_pos_pat.count_nonzero()*100/len(brain_pos_pat)}% of the patches are remaining"
)
# Blow the per-patch indicator up to voxel resolution by repeating each entry patch_size
# times along every spatial axis (torch-native equivalent of nested np.repeat calls).
expanded_mask = brain_pos_pat.view(
    [
        num_frames // frame_patch_size,
        grid_d,
        grid_h,
        grid_w,
    ]
)
for spatial_axis in (1, 2, 3):
    expanded_mask = expanded_mask.repeat_interleave(patch_size, dim=spatial_axis)
display(transforms.ToPILImage()(utils.reshape_to_2d(expanded_mask).float()))

print("\ntube func without adding mean/sd references")
display(transforms.ToPILImage()(utils.reshape_to_2d(tube_func[idx])))
print("tube func")
display(transforms.ToPILImage()(utils.reshape_to_2d(tube_func[idx] * mean + sd)))

print("\ndecoder func without adding mean/sd references")
display(transforms.ToPILImage()(utils.reshape_to_2d(decoder_func[idx])))
print("decoder func")
display(transforms.ToPILImage()(utils.reshape_to_2d(decoder_func[idx] * mean + sd)))

# Set up optimizer and begin model training

In [None]:
# Parameters whose name contains one of these substrings get no weight decay.
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

# Single pass over the parameters, sorting each one into the decayed or the exempt group.
decayed_params, exempt_params = [], []
for param_name, param in model.named_parameters():
    is_exempt = any(tag in param_name for tag in no_decay)
    (exempt_params if is_exempt else decayed_params).append(param)

opt_grouped_parameters = [
    {"params": decayed_params, "weight_decay": 1e-2},
    {"params": exempt_params, "weight_decay": 0.0},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

# one scheduler step is budgeted per training iteration across all epochs
total_steps = num_epochs * num_iterations_per_epoch
print("total_steps", total_steps)
scheduler_options = dict(
    max_lr=max_lr,
    total_steps=total_steps,
    final_div_factor=1000,
    last_epoch=-1,
    pct_start=2 / num_epochs,  # warm-up phase spans the first two epochs' worth of steps
)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **scheduler_options)

print("\nDone with model preparations!")

In [None]:
# Reset the training bookkeeping.
epoch = 0
losses = []
test_losses = []
lrs = []
best_test_loss = 1e9

# Release cached CUDA memory, then hand the training objects over to accelerate.
torch.cuda.empty_cache()
prepared_objects = accelerator.prepare(model, optimizer, train_dl, lr_scheduler)
model, optimizer, train_dl, lr_scheduler = prepared_objects

In [None]:
# Training / evaluation loop.
# Each epoch: (1) one pass over train_dl optimizing reconstruction (+ optional contrastive)
# loss, (2) one gradient-free pass over test_dl, (3) progress-bar logging and, on selected
# epochs, a visualization of the reconstruction for the last test batch.
mse = nn.MSELoss()
if use_contrastive_loss:
    # NOTE(review): this parameter is created after the optimizer was built and is never
    # registered with it, so it keeps its initial value log(1 / 0.07) for the whole run.
    logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
lrs, recon_losses, contrastive_losses, test_losses = [], [], [], []
recon_image_list = []

# per-timepoint patch counts used when sampling masks
patches_per_frame = num_patches // num_frames
num_tube_per_frame = int(num_patches / num_frames * (1 - tube_mask_ratio))
num_decoder_per_frame = int(num_patches / num_frames * (1 - decoder_mask_ratio))
# number of patches along each spatial axis (derived instead of hard-coded 8 / 8 / 6)
grid_d, grid_h, grid_w = (size // patch_size for size in img_size)
# accelerator.prepare may wrap the model (e.g. for distributed training), and wrappers do not
# expose custom methods; fetch patchify from the underlying model
patchify = accelerator.unwrap_model(model).patchify


def _sample_patch_masks(brain_pos_pats):
    """Draw a random, non-overlapping encoder (tube) mask and decoder mask for one batch.

    A patch position is a candidate when it is non-zero in `brain_pos_pats` for at least one
    sample of the batch (only the first timepoint's positions are inspected). Candidates are
    shuffled; the first `num_tube_per_frame` go to the encoder and the next
    `num_decoder_per_frame` to the decoder. Both per-timepoint masks are then repeated for
    every frame, so the same positions are used at all timepoints.

    Returns:
        (tube_mask, decoder_mask): CPU bool tensors of length `num_patches`; True means the
        patch is included.
    """
    batch_positive_approx = (
        brain_pos_pats[:, :patches_per_frame].float().mean(dim=0) > 0
    )
    # Index on the CPU, where the masks live; this is a no-op for CPU batches and keeps the
    # index tensors and the masks on one device when the loader yields device tensors.
    candidates = torch.where(batch_positive_approx)[0].cpu()
    candidates = candidates[torch.randperm(len(candidates))]

    tube_mask = torch.zeros(patches_per_frame, dtype=torch.bool)
    tube_mask[candidates[:num_tube_per_frame]] = True

    decoder_mask = torch.zeros(patches_per_frame, dtype=torch.bool)
    decoder_idx = candidates[num_tube_per_frame:][:num_decoder_per_frame]
    decoder_mask[decoder_idx] = True

    return tube_mask.tile(num_frames), decoder_mask.tile(num_frames)


def _reconstruct(func, tube_mask, decoder_mask):
    """Encode the visible patches, decode, and return what is needed for the loss.

    Returns:
        (encoder_out, output, target_patches_vit): the raw encoder output, the last
        `num_decoder_patches` decoder tokens (the reconstructed patches), and the ground-truth
        patches of `func` flattened to (batch, tokens, patch_dim).
    """
    encoder_out = model(func, encoder_mask=tube_mask)
    # decode both the encoder_out patches and masked decoder patches
    decoder_out = model(encoder_out, encoder_mask=tube_mask, decoder_mask=decoder_mask)
    # subset only the reconstructed decoder patches
    output = decoder_out[:, -num_decoder_patches:]
    target_patches_vit = rearrange(patchify(func), "b ... d -> b (...) d")
    return encoder_out, output, target_patches_vit


progress_bar = tqdm(range(epoch, num_epochs), ncols=1200, disable=(local_rank != 0))
for epoch in progress_bar:
    with torch.cuda.amp.autocast(dtype=data_type):
        model.train()
        for train_i, batch in enumerate(
            train_dl
        ):  # total samples in 1 epoch = train_dl.nsamples
            optimizer.zero_grad()

            func, meansd, brain_pos_pats = batch
            if use_contrastive_loss:  # create positive pairs by duplicating the batch
                # NOTE(review): both halves are encoded with the same tube_mask below, so the
                # two "views" differ only through stochastic layers in the model, if any.
                func = torch.cat([func, func], dim=0)
                meansd = torch.cat([meansd, meansd], dim=0)
                brain_pos_pats = torch.cat([brain_pos_pats, brain_pos_pats], dim=0)

            func = func.unsqueeze(1).to(device)

            # tube mask = same visible positions at every timepoint; decoder mask is disjoint
            tube_mask, decoder_mask = _sample_patch_masks(brain_pos_pats)

            encoder_out, output, target_patches_vit = _reconstruct(
                func, tube_mask, decoder_mask
            )
            if use_cls_token:
                enc_cls_token = encoder_out[:, :1, :]

            # compare to ground truth and calculate loss
            target = target_patches_vit[:, decoder_mask]
            recon_loss = mse(output, target)
            loss = recon_loss

            # contrastive loss between the class tokens of the two halves of the batch
            if use_contrastive_loss:
                n_b = len(func) // 2
                cls_token1 = enc_cls_token[
                    :n_b, 0, :
                ]  # first half of batch, cls_token shape B, 1, d_model
                cls_token2 = enc_cls_token[n_b:, 0, :]
                contrastive_loss = utils.contrastive_loss(
                    cls_token1, cls_token2, temperature=logit_scale
                )
                loss = loss + constrastive_loss_weight * contrastive_loss
                contrastive_losses.append(contrastive_loss.item())

            accelerator.backward(loss)
            optimizer.step()
            # log the reconstruction term only, so this curve is comparable to the test loss
            recon_losses.append(recon_loss.item())
            # learning rate that was used for this update
            lrs.append(optimizer.param_groups[0]["lr"])
            # advance the one-cycle schedule; without this the LR never leaves its start value
            lr_scheduler.step()

        model.eval()
        # evaluation needs no gradients; skipping graph construction saves memory
        with torch.no_grad():
            for test_i, batch in enumerate(test_dl):
                func, meansd, brain_pos_pats = batch
                func = func.unsqueeze(1).to(device)

                tube_mask, decoder_mask = _sample_patch_masks(brain_pos_pats)
                encoder_out, output, target_patches_vit = _reconstruct(
                    func, tube_mask, decoder_mask
                )

                # compare to ground truth and calculate loss
                target = target_patches_vit[:, decoder_mask]
                loss = mse(output, target)
                test_losses.append(loss.item())

        # average over the batches of this epoch only
        logs = {
            "train/loss": np.mean(recon_losses[-(train_i + 1) :]),
            "test/loss": np.mean(test_losses[-(test_i + 1) :]),
        }
        progress_bar.set_postfix(**logs)

        # Plot progress (first sample of the last test batch)
        with torch.no_grad():
            # prep reference volumes for going back to original data
            # NOTE(review): volumes are restored as `x * mean + sd`; confirm this matches the
            # normalization applied in utils.DataPrepper.
            idx = 0
            mean, sd = meansd[idx]
            mean, sd = mean.to(device), sd.to(device)
            if epoch == 0:
                print("original volumes without adding mean/sd references")
                display(
                    transforms.ToPILImage()(utils.reshape_to_2d(func[idx]) * 5)
                )  # scaling by 5 for visualization contrast
                print("original volumes")
                display(
                    transforms.ToPILImage()(utils.reshape_to_2d(func[idx] * mean + sd))
                )
            if epoch % 5 == 0:
                # undo patchification so we can visualize: place the reconstructed patches at
                # their positions in an otherwise empty token grid
                decode_vis = torch.zeros_like(target_patches_vit)
                # under autocast the model output may be half precision; indexed assignment
                # requires matching dtypes
                decode_vis[:, decoder_mask] = output.to(decode_vis.dtype)
                decoder_unpatches = rearrange(
                    decode_vis,
                    "b (f d h w) c -> b f d h w c",
                    f=num_frames,
                    d=grid_d,
                    h=grid_h,
                    w=grid_w,
                )
                decoder_func = rearrange(
                    decoder_unpatches,
                    "b f d h w (pd ph pw pf c) -> b c (f pf) (d pd) (h ph) (w pw)",
                    b=len(decode_vis),
                    f=num_frames,
                    d=grid_d,
                    h=grid_h,
                    w=grid_w,
                    pd=patch_size,
                    ph=patch_size,
                    pw=patch_size,
                    pf=frame_patch_size,
                )
                print("recons of decoded patches without adding mean/sd references")
                display(
                    transforms.ToPILImage()(utils.reshape_to_2d(decoder_func[idx] * 5))
                )  # scaling by 5 for visualization contrast
                print("recons of decoded patches")
                display(
                    transforms.ToPILImage()(
                        utils.reshape_to_2d(decoder_func[idx] * mean + sd)
                    )
                )

In [None]:
# Collect the loss curves to draw, in display order; the contrastive curve only exists
# when that loss is enabled.
loss_curves = [(recon_losses, "Training re-construction losses")]
if use_contrastive_loss:
    loss_curves.append((contrastive_losses, "Training contrastive losses"))
loss_curves.append((test_losses, "Test losses"))

# one small figure per curve
for curve_values, curve_title in loss_curves:
    plt.figure(figsize=(8, 3))
    plt.plot(curve_values)
    plt.title(curve_title)
    plt.show()

In [None]:
# save model ckpt
# Make sure every process has finished training before the weights are written.
accelerator.wait_for_everyone()
if accelerator.is_main_process:  # a single writer avoids several ranks clobbering the file
    # Unwrap first: a model wrapped by accelerator.prepare for distributed training would
    # otherwise save its weights under wrapper-prefixed keys that a bare model cannot load.
    torch.save(
        {"model_state_dict": accelerator.unwrap_model(model).state_dict()},
        "last.ckpt",
    )