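This notebook trains a network (`voxel2clip`) that maps NSD brain voxels to CLIP space, currently with an MSE loss. After each epoch it visualizes versatile-diffusion reconstructions from the predicted embeddings.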

In [7]:
# # convert this notebook to .py such that you can then run it via slurm with "sbatch *.slurm"
# from subprocess import call
# command = "jupyter nbconvert Brain_to_CLIP_refine.ipynb --to python"
# call(command,shell=True)

[NbConvertApp] Converting notebook Brain_to_CLIP_refine2.ipynb to python
[NbConvertApp] Writing 37354 bytes to Brain_to_CLIP_refine2.py


0

# Import packages & functions

In [2]:
import os
import sys
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm
from info_nce import InfoNCE
from dalle2_pytorch import DiffusionPriorNetwork
import kornia
from kornia.augmentation.container import AugmentationSequential

import torch.distributed as dist
from accelerate import Accelerator

# Allow TF32 tensor-core matmuls: faster than full float32 matmuls on GPUs that support TF32.
torch.backends.cuda.matmul.allow_tf32 = True

# Multi-GPU config #
accelerator = Accelerator()
# From here on `print` is accelerator.print, so only the main process prints (local_rank 0).
print = accelerator.print

device = accelerator.device
print("device:",device)

# With no visible GPU, still treat the run as using one device.
num_devices = max(torch.cuda.device_count(), 1)
num_workers = num_devices

print(accelerator.state)
local_rank = accelerator.state.local_process_index
world_size = accelerator.state.num_processes
# Distributed whenever more than one GPU is visible or more than one process was launched.
distributed = num_devices > 1 or world_size > 1
print("distributed =",distributed,"num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)

# custom models and functions #
import utils
from utils import torch_to_matplotlib, torch_to_Image
from models import BrainNetwork#, BrainDiffusionPrior
# from model3d import SimpleVoxel3dConvEncoder

device: cuda
Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

distributed = False num_devices = 1 local rank = 0 world size = 1


# Configurations

In [3]:
# can specify jupyter_args here for argparser to use if running this code interactively
if utils.is_interactive():
    jupyter_args=[]
    jupyter_args.append("--model_name=testing")
    jupyter_args.append("--modality=image")
    jupyter_args.append("--clip_variant=ViT-L/14")
    jupyter_args.append("--batch_size=256")
    jupyter_args.append("--num_epochs=100")
    jupyter_args.append("--with_mse")
    jupyter_args.append("--versatile")
    jupyter_args.append("--mixup_pct=0.")
    jupyter_args.append("--mse_pct=0.")
    jupyter_args.append("--max_lr=1e-4")
    print(jupyter_args)
    
    %load_ext autoreload
    %autoreload 2

['--model_name=testing', '--modality=image', '--clip_variant=ViT-L/14', '--batch_size=256', '--num_epochs=100', '--with_mse', '--versatile', '--mixup_pct=0.', '--mse_pct=0.', '--max_lr=1e-4']


In [4]:
# Command-line configuration. Every parsed option is also exposed as a bare global (see the loop at the end).
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging",
)
parser.add_argument(
    "--modality", type=str, default="image", choices=["image", "text"],
    help="image or text",
)
parser.add_argument(
    "--batch_size", type=int, default=300,
    help="Our maximum for A100 was 300 for 1dim voxels and 128 for 3dim voxels",
)
parser.add_argument(
    "--clip_variant",type=str,default="ViT-L/14",choices=["RN50", "ViT-L/14", "ViT-B/32", "ViT-H-14", "RN50x64"],
    help='clip / openclip variant',
)
parser.add_argument(
    "--outdir",type=str,default=None,
    help="output directory for logs and checkpoints",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.5,
    help="proportion of way through training when to switch from InfoNCE to soft_clip_loss",
)
parser.add_argument(
    "--voxel_dims",type=int,default=1,choices=[1, 3],
    help="1 for flattened input, 3 for 3d input",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
    help="whether to use image augmentation (only used for modality=image)",
)
parser.add_argument(
    "--num_epochs",type=int,default=120,
)
parser.add_argument(
    "--lr_scheduler",type=str,default='cycle',choices=['cycle','fixed'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=1,
    help="save ckpt every x epochs",
)
parser.add_argument(
    "--save_at_end",action=argparse.BooleanOptionalAction,default=False,
    help="if False, will save best.ckpt whenever epoch shows best validation score",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--with_mse",action=argparse.BooleanOptionalAction,default=False,
    help="Add mse loss to the other losses",
)
# float (not int) so fractional weights such as 0.5 are accepted; integer values still parse
parser.add_argument(
    "--mse_mult",type=float,default=1,
    help="Multiplier for mse loss",
)
parser.add_argument(
    "--text_token",action=argparse.BooleanOptionalAction,default=False,
    help="Map to text token space instead of CLIP",
)
parser.add_argument(
    "--versatile",action=argparse.BooleanOptionalAction,default=False,
    help="Map to 257x768 versatile diffusion CLIP space",
)
parser.add_argument(
    "--mse_pct",type=float,default=1.0,
    help="What percentage of way through training to start adding mse loss",
)
parser.add_argument(
    "--initial_lr",type=float,default=3e-4,
    help="lr if lr_scheduler is fixed",
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
    help="max_lr if lr_scheduler is cycle (OneCycleLR)",
)
parser.add_argument(
    "--att",action=argparse.BooleanOptionalAction,default=False,
    help="Map to 257x1024 instead of 257x768 versatile diffusion CLIP space",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix (args.batch_size -> batch_size, ...)
for attribute_name, attribute_value in vars(args).items():
    globals()[attribute_name] = attribute_value

In [5]:
# Output directory for checkpoints/logs; defaults to ../train_logs/<model_name>.
if outdir is None:
    outdir = os.path.abspath(f'../train_logs/{model_name}')
os.makedirs(outdir, exist_ok=True)

# Image augmentations, handed to the CLIP extractor as `train_transforms`.
if use_image_aug:
    train_augs = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.5),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
        data_keys=["input"],
    )
else:
    train_augs = None

# COCO captions: the training loop indexes `annots` both for text-modality targets and for
# text-token targets, so load them in either case (text_token with modality=image used to NameError).
if modality=='text' or text_token:
    annots = np.load("/fsx/proj-medarc/fmri/natural-scenes-dataset/COCO_73k_annots_curated.npy")

# Frozen text encoder + tokenizer used to build text-token targets.
if text_token and clip_variant!="ViT-H-14":
    from transformers import CLIPTextModel, CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    text_encoder.eval()
    text_encoder.requires_grad_(False)
elif text_token and clip_variant=="ViT-H-14":
    from diffusers import StableUnCLIPImg2ImgPipeline
    sd_cache_dir = '/fsx/proj-medarc/fmri/cache/models--stabilityai--stable-diffusion-2-1-unclip/snapshots/5eaf116f1b118d1756d5df9f578e8259befa95b7'
    with torch.no_grad():
        sd_pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            sd_cache_dir, torch_dtype=torch.float16,
        )
    # keep only the pipeline's tokenizer/text encoder; drop the rest of the pipeline
    tokenizer = sd_pipe.tokenizer
    text_encoder = sd_pipe.text_encoder.to(device)
    text_encoder.eval()
    text_encoder.requires_grad_(False)
    del sd_pipe

# Prep models and data loaders

In [6]:
# LOAD NON-MSE CHECKPOINT
# Keep only the model weights (loaded onto the CPU); the rest of the checkpoint dict is not retained.
# NOTE(review): `non_mse_ckpt` is only consumed by the warm-start cell further down, which is commented out.
non_mse_ckpt = torch.load(
    '/fsx/proj-medarc/fmri/paulscotti/fMRI-reconstruction-NSD/train_logs/v2c_vers/last.pth',
    map_location='cpu',
)['model_state_dict']

In [7]:
# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

print('Pulling NSD webdataset data...')

# Training uses the 18 train shards plus the val shard; evaluation uses the two test shards.
train_url = "{/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/train/train_subj01_{0..17}.tar,/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/val/val_subj01_0.tar}"
val_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/test/test_subj01_{0..1}.tar"
meta_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/metadata_subj01.json"
num_train = 8559 + 300  # presumably train-split size + val-shard size
num_val = 982

# Name of the voxel array inside each webdataset sample, by input dimensionality.
_voxels_key_by_dims = {1: 'nsdgeneral.npy', 3: 'wholebrain_3d.npy'}
if voxel_dims not in _voxels_key_by_dims:
    raise Exception(f"voxel_dims must be 1 or 3, not {voxel_dims}")
voxels_key = _voxels_key_by_dims[voxel_dims]

print('Prepping train and validation dataloaders...')
# get_dataloaders also returns the (possibly adjusted) sample counts, which replace the values above.
train_dl, val_dl, num_train, num_val = utils.get_dataloaders(
    batch_size,'images',
    num_devices=num_devices,
    num_workers=num_workers,
    train_url=train_url,
    val_url=val_url,
    meta_url=meta_url,
    num_train=num_train,
    num_val=num_val,
    val_batch_size=300,
    cache_dir="/tmp/wds-cache",
    seed=seed,
    voxels_key=voxels_key,
    to_tuple=["voxels", "images", "coco"],
    local_rank=local_rank,
)

if voxel_dims == 3:
    import nibabel as nib
    noise_ceils_path = '/fsx/proj-medarc/fmri/natural-scenes-dataset/temp_s3/nsddata_betas/ppdata/subj01/func1pt8mm/betas_fithrf_GLMdenoise_RR/ncsnr.nii.gz'
    noise_ceils = nib.load(noise_ceils_path).get_fdata()
    # coordinates of voxels whose ncsnr value exceeds 0.5; used later to crop the 3D volumes
    x_inc,y_inc,z_inc = np.where(noise_ceils > .5) # voxel.shape torch.Size([300, 3, 68, 64, 47])

Note: not using cudnn.deterministic
Pulling NSD webdataset data...
Prepping train and validation dataloaders...
Getting dataloaders...

num_train 8859
global_batch_size 256
batch_size 256
num_workers 1
num_batches 34
num_worker_batches 34
cache_dir None

num_val 982
val_batch_size 300
val_num_workers 1


In [8]:
print('Creating Clipper...')

# Don't L2 norm the extracted CLIP embeddings since we want the prior
# to learn un-normed embeddings for usage with the SD image variation pipeline.
if clip_variant == "ViT-H-14":
    from models import OpenClipper
    clip_extractor = OpenClipper(clip_variant, device=device, train_transforms=train_augs)
    out_dim = 1024
else:
    from models import Clipper
    if not versatile:
        clip_extractor = Clipper(clip_variant, device=device, train_transforms=train_augs)
        out_dim = 768
    else:
        print("Using versatile CLIP space")
        # Versatile space = 257 hidden-state tokens; `att` selects the refined 1024-d variant.
        clip_extractor = Clipper(clip_variant, device=device, hidden_state=True, refine=att, train_transforms=train_augs)
        if att:
            print("HIDDEN STATE AND REFINE BOTH TRUE")
        out_dim = 257 * (1024 if att else 768)
print("out_dim:",out_dim)

print('Creating voxel2clip...')

if voxel_dims == 1: # 1D data
    voxel2clip_kwargs = dict(out_dim=out_dim)
    voxel2clip = BrainNetwork(**voxel2clip_kwargs)
elif voxel_dims == 3: # 3D data
    # The file-level import of the 3D encoder is commented out, so bring it in only when needed.
    from model3d import SimpleVoxel3dConvEncoder
    # Spatial extent of the cropped volume the model receives: the training loop keeps only the
    # np.unique(x_inc / y_inc / z_inc) slices. (The previous `voxel.shape[2:]` referenced an
    # undefined `voxel`.) NOTE(review): confirm SimpleVoxel3dConvEncoder expects `dims` as the
    # (x, y, z) extent of its input.
    voxel_crop_dims = (len(np.unique(x_inc)), len(np.unique(y_inc)), len(np.unique(z_inc)))
    voxel2clip_kwargs = dict(
        # text-token targets flatten 77 token embeddings (presumably the CLIP tokenizer max length)
        out_dim=77*out_dim if text_token else out_dim,
        dims=voxel_crop_dims,
        channels=[64, 128, 256, 128],
        strides=[1, 2, 3, 3],
        padding=[1, 1, 1, 1],
        dilation=[1, 1, 1, 1],
        kernel=[3, 3, 3, 3],
    )
    voxel2clip = SimpleVoxel3dConvEncoder(**voxel2clip_kwargs)

print("params of voxel2clip:")
if local_rank==0:
    utils.count_params(voxel2clip)

# AdamW: weight decay on every parameter except biases.
no_decay = ['bias']

def _uses_decay(param_name):
    """True when the parameter name contains none of the `no_decay` substrings."""
    return not any(nd in param_name for nd in no_decay)

named_params = list(voxel2clip.named_parameters())
opt_grouped_parameters = [
    {'params': [p for n, p in named_params if _uses_decay(n)], 'weight_decay': 1e-2},
    {'params': [p for n, p in named_params if not _uses_decay(n)], 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr) # lr doesnt get used if lr_scheduler='cycle'

# NOTE: `lr_scheduler` is rebound from the CLI string to the scheduler object (or None for 'fixed').
if lr_scheduler == 'cycle':
    global_batch_size = batch_size * num_devices
    steps_per_epoch = num_train // global_batch_size
    total_steps = num_epochs * steps_per_epoch
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1,
        pct_start=2/num_epochs,  # warm up over the first two epochs
    )
elif lr_scheduler == 'fixed':
    lr_scheduler = None
    
def save_ckpt(tag):
    """Save model/optimizer/scheduler state plus training history to ``{outdir}/{tag}.pth``.

    Reads the notebook globals ``outdir``, ``voxel2clip``, ``optimizer``, ``lr_scheduler``, ``epoch``,
    ``losses``, ``val_losses``, ``lrs`` and the forward/backward accuracy trackers. When ``with_mse``
    is set, it also reads ``mse_losses`` and ``val_mse_losses``.

    Args:
        tag: checkpoint name, e.g. ``'best'`` or ``'last'``.
    """
    ckpt_path = outdir+f'/{tag}.pth'
    print(f'saving {ckpt_path}',flush=True)
    # A 'fixed' schedule is represented by lr_scheduler=None, so comparing against the string 'fixed'
    # was always True and crashed on None.state_dict(). Test for None (or a leftover string) instead.
    if lr_scheduler is None or isinstance(lr_scheduler, str):
        lr_dict = None
    else:
        lr_dict = lr_scheduler.state_dict()
    ckpt = {
        'epoch': epoch,
        'model_state_dict': voxel2clip.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler': lr_dict,
        'train_losses': losses,
        'val_losses': val_losses,
        'fwd_percent_correct': fwd_percent_correct,
        'bwd_percent_correct': bwd_percent_correct,
        'val_fwd_percent_correct': val_fwd_percent_correct,
        'val_bwd_percent_correct': val_bwd_percent_correct,
        'lrs': lrs,
    }
    if with_mse:
        ckpt["mse_losses"] = mse_losses
        ckpt["val_mse_losses"] = val_mse_losses
    torch.save(ckpt, ckpt_path)

print("\nDone with model preparations!")

Creating Clipper...
Using versatile CLIP space
ViT-L/14 cuda
HIDDEN STATE AND REFINE BOTH TRUE
out_dim: 263168
Creating voxel2clip...
params of voxel2clip:
param counts:
1,209,775,104 total
1,209,775,104 trainable

Done with model preparations!


# Weights and Biases

In [9]:
# params for wandb
if local_rank==0 and wandb_log:
    import wandb

    # `wandb_project` comes from the --wandb_project CLI flag (default 'stability').
    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')#, relogin=True)
    wandb_config = {
      "model_name": model_name,
      "modality": modality,
      "text_token": text_token,
      "voxel_dims": voxel_dims,
      "clip_variant": clip_variant,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      # log the scheduler *name*: the global `lr_scheduler` now holds the scheduler object (or None)
      "lr_scheduler": args.lr_scheduler,
      "mixup_pct": mixup_pct,
      "mse_pct": mse_pct,
      "with_mse": with_mse,
      "mse_mult": mse_mult,
      "versatile": versatile,
      "att": att,
      "num_train": num_train,
      "num_val": num_val,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
      "train_url": train_url,
      "val_url": val_url,
    }
    print("wandb_config:\n",wandb_config)
    # Use model_name as the run id so a restarted job resumes the same wandb run.
    print("wandb_id:",model_name)
    wandb.init(
        id = model_name,
        project=wandb_project,
        name=wandb_run,
        config=wandb_config,
        notes=wandb_notes,
        resume="allow",
    )
else:
    wandb_log = False

# Start from versatile non-MSE ckpt (currently disabled: the cell below is commented out)

In [10]:
# state_dict = non_mse_ckpt
# for key in list(state_dict.keys()):
#     if 'module.' in key:
#         state_dict[key.replace('module.', '')] = state_dict[key]
#         del state_dict[key]
# voxel2clip.load_state_dict(non_mse_ckpt)

# Huggingface Accelerate

In [11]:
# Let Accelerate wrap the model, optimizer, dataloaders and scheduler for the current device setup.
# NOTE(review): lr_scheduler may be None here ('fixed' schedule); presumably passed through unchanged.
(voxel2clip, optimizer, train_dl, val_dl, lr_scheduler) = accelerator.prepare(
    voxel2clip,
    optimizer,
    train_dl,
    val_dl,
    lr_scheduler,
)

# Main

In [12]:
from diffusers.utils import randn_tensor
from transformers import CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel, UniPCMultistepScheduler

# Frozen versatile-diffusion components, used only to visualize reconstructions from CLIP embeddings.
sd_cache_dir = '/fsx/proj-medarc/fmri/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'

image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').to(device)
image_encoder.eval() # only used for inference (post-layernorm / projection of embeddings)
image_encoder.requires_grad_(False)

unet = UNet2DConditionModel.from_pretrained(sd_cache_dir,subfolder="image_unet").to(device)
unet.eval() # dont want to train model
unet.requires_grad_(False) # dont need to calculate gradients

vae = AutoencoderKL.from_pretrained(sd_cache_dir,subfolder="vae").to(device)
vae.eval()
vae.requires_grad_(False)

scheduler = "unipc" # "pndms" or "unipc"
noise_scheduler = PNDMScheduler.from_pretrained(sd_cache_dir, subfolder="scheduler")
if scheduler == "unipc":
    noise_scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
elif scheduler != "pndms":
    raise ValueError(f"scheduler must be 'pndms' or 'unipc', not {scheduler!r}")
num_inference_steps = 20

def decode_latents(latents):
    """Decode VAE latents into image tensors clamped to [0, 1].

    Args:
        latents: latent tensor produced by the denoising loop.
    Returns:
        ``vae.decode(...).sample`` mapped via ``x / 2 + 0.5`` (VAE output assumed in [-1, 1]) and clamped.
    """
    # 0.18215 is the usual Stable Diffusion VAE latent scale (presumably vae.config.scaling_factor); undo it.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    return image

guidance_scale=7.5
do_classifier_free_guidance = guidance_scale > 1.0
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
height = unet.config.sample_size * vae_scale_factor
width = unet.config.sample_size * vae_scale_factor

In [13]:
# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

# Per-epoch weight on the MSE term; currently constant. An mse_pct warm-up ramp is kept for reference:
# mse_alphas = np.hstack((np.zeros(int(mse_pct * num_epochs)), np.linspace(0,1,int(num_epochs-(mse_pct * num_epochs)))**2))
mse_alphas = np.ones(num_epochs)

epoch = 0
losses, val_losses, lrs, mse_losses, val_mse_losses = [], [], [], [], []
best_val_loss = 1e9
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

mse = nn.MSELoss()
val_voxel0 = val_image0 = None

def _resume_from_last_ckpt():
    """Restore model/optimizer/scheduler state and loss history from ``{outdir}/last.pth``.

    Returns:
        The epoch to continue from. This is one past the stored epoch, because ``save_ckpt('last')``
        runs after that epoch's training and validation have finished. Resuming *at* the stored epoch
        would repeat it and step OneCycleLR past ``total_steps``.
    """
    checkpoint = torch.load(outdir+'/last.pth')
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # lr_scheduler is None for a 'fixed' schedule (and such checkpoints store None).
    if lr_scheduler is not None and checkpoint.get('lr_scheduler') is not None:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    voxel2clip.load_state_dict(checkpoint['model_state_dict'])
    # Continue the saved history so later checkpoints/logs keep the earlier epochs.
    losses.extend(checkpoint.get('train_losses', []))
    val_losses.extend(checkpoint.get('val_losses', []))
    lrs.extend(checkpoint.get('lrs', []))
    mse_losses.extend(checkpoint.get('mse_losses', []))
    val_mse_losses.extend(checkpoint.get('val_mse_losses', []))
    return checkpoint['epoch'] + 1

# Optionally resume from checkpoint (explicit flag, or an auto-resumed wandb run) #
if resume_from_ckpt or (wandb_log and wandb.run.resumed):
    print("\n---resuming from last.pth ckpt---\n")
    epoch = _resume_from_last_ckpt()
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
for epoch in progress_bar:
    voxel2clip.train()

    sims = 0.
    sims_base = 0.
    val_sims = 0.
    val_sims_base = 0.
    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    val_fwd_percent_correct = 0.
    val_bwd_percent_correct = 0.

    for train_i, (voxel, image, coco) in enumerate(train_dl):
        optimizer.zero_grad()
        
        repeat_index = train_i % 3

        image = image.float()
        voxel = voxel.float()[:,repeat_index].float()
        
        if voxel_dims == 3:
            voxel = voxel[:,np.unique(x_inc),:,:]
            voxel = voxel[:,:,np.unique(y_inc),:]
            voxel = voxel[:,:,:,np.unique(z_inc)]

        if epoch < int(mixup_pct * num_epochs):
            voxel, perm, betas, select = utils.mixco(voxel)

        if text_token:
            img_annots = utils.select_annotations(annots[coco.cpu().numpy()], random=True)
            img_annots = tokenizer(img_annots.tolist(), padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            img_annots = text_encoder(img_annots.input_ids.to(device))[0]
            clip_target = img_annots.reshape(len(img_annots),-1).float()
        elif modality=="text":
            img_annots = utils.select_annotations(annots[coco.cpu().numpy()], random=True)
            clip_target = clip_extractor.embed_text(img_annots).float()
        else:
            clip_target = clip_extractor.embed_image(image).float()
        clip_target=clip_target.reshape(len(clip_target),-1).to(voxel.dtype)
        
        clip_voxels = voxel2clip(voxel)            
            
        # if epoch < int(mixup_pct * num_epochs):
        #     loss = utils.mixco_nce(
        #         nn.functional.normalize(clip_voxels, dim=-1), 
        #         nn.functional.normalize(clip_target, dim=-1),
        #         temp=0.006, perm=perm, betas=betas, select=select,
        #         distributed=distributed, accelerator=accelerator, local_rank=local_rank)
        # else:
        #     epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
        #     loss = utils.soft_clip_loss(
        #         nn.functional.normalize(clip_voxels, dim=-1), 
        #         nn.functional.normalize(clip_target, dim=-1),
        #         temp=epoch_temp,
        #         distributed=distributed, accelerator=accelerator)
        
        if with_mse:
            # clip_voxels_norm = clip_voxels.reshape(-1, 257, 768)
            # clip_voxels_norm = nn.functional.normalize(clip_voxels_norm,dim=-1)
            # clip_target_norm = clip_target.reshape(-1, 257, 768)
            # clip_target_norm = nn.functional.normalize(clip_target_norm,dim=-1)
            
            mse_amount = mse_alphas[epoch]
            mseloss = mse(clip_voxels,clip_target)*mse_mult
            # #loss = (loss*(1-mse_amount)) + (mseloss * mse_amount)
            loss = mseloss
            mse_losses.append(mseloss.item())
        utils.check_loss(loss)

        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        if distributed:
            sims_base += F.cosine_similarity(accelerator.gather(clip_target),
                                                  accelerator.gather(clip_voxels)).mean().item()
        else:
            sims_base += F.cosine_similarity(clip_target,clip_voxels).mean().item()

        # forward and backward top 1 accuracy
        labels = torch.arange(len(clip_target)).to(device)
        fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target, clip_voxels), labels, k=1)
        bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels, clip_target), labels, k=1)

        accelerator.backward(loss)
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

    voxel2clip.eval()
    if local_rank==0: # i think its possible to remove this if statement though with some revisions
        for val_i, (voxel, image, coco) in enumerate(val_dl): 
            with torch.no_grad():
                #repeat_index = val_i % 3

                image = image.float()
                #voxel = voxel[:,repeat_index].float()
                voxel = torch.mean(voxel,axis=1).float().to(device)

                if voxel_dims == 3:
                    voxel = voxel[:,np.unique(x_inc),:,:]
                    voxel = voxel[:,:,np.unique(y_inc),:]
                    voxel = voxel[:,:,:,np.unique(z_inc)]

                if val_image0 is None:
                    val_image0 = image.detach().clone()
                    val_voxel0 = voxel.detach().clone()

                if text_token:
                    img_annots = utils.select_annotations(annots[coco.cpu().numpy()], random=True)
                    img_annots = tokenizer(img_annots.tolist(), padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
                    img_annots = text_encoder(img_annots.input_ids.to(device))[0]
                    clip_target = img_annots.reshape(len(img_annots),-1).float()
                elif modality=="text":
                    img_annots = utils.select_annotations(annots[coco.cpu().numpy()], random=False)
                    clip_target = clip_extractor.embed_text(img_annots).float()
                else:
                    clip_target = clip_extractor.embed_image(image).float()
                clip_target=clip_target.reshape(len(clip_target),-1).to(voxel.dtype)

                clip_voxels = voxel2clip(voxel)

                # if epoch < int(mixup_pct * num_epochs):
                #     val_loss = utils.mixco_nce(
                #         nn.functional.normalize(clip_voxels, dim=-1), 
                #         nn.functional.normalize(clip_target, dim=-1),
                #         temp=0.006, 
                #         distributed=distributed, accelerator=accelerator, local_rank=local_rank)
                # else:
                #     epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                #     val_loss = utils.soft_clip_loss(
                #         nn.functional.normalize(clip_voxels, dim=-1), 
                #         nn.functional.normalize(clip_target, dim=-1),
                #         temp=epoch_temp, 
                #         distributed=distributed, accelerator=accelerator)

                if with_mse:
                    # clip_voxels_norm = clip_voxels.reshape(-1, 257, 768)
                    # clip_voxels_norm = nn.functional.normalize(clip_voxels_norm,dim=-1)
                    # clip_target_norm = clip_target.reshape(-1, 257, 768)
                    # clip_target_norm = nn.functional.normalize(clip_target_norm,dim=-1)
                    
                    val_mseloss = mse(clip_voxels,clip_target)*mse_mult
                    # # val_loss = (val_loss*(1-mse_amount)) + (val_mseloss * mse_amount)
                    val_loss = val_mseloss
                    val_mse_losses.append(val_mseloss.item())
                    
                utils.check_loss(val_loss)

                val_losses.append(val_loss.item())

                if distributed:
                    val_sims_base += F.cosine_similarity(accelerator.gather(clip_target),
                                                          accelerator.gather(clip_voxels)).mean().item()
                else:
                    val_sims_base += F.cosine_similarity(clip_target,clip_voxels).mean().item()

                labels = torch.arange(len(clip_target)).to(device)
                val_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target, clip_voxels), labels, k=1)
                val_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels, clip_target), labels, k=1)

        if (not save_at_end and ckpt_saving) or (save_at_end and epoch == num_epochs - 1):
            # save best model
            val_loss = np.mean(val_losses[-(val_i+1):])
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_ckpt('best')
            else:
                print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f}')
        
        # Save model checkpoint every `ckpt_interval`` epochs or on the last epoch
        if (ckpt_interval is not None and (epoch + 1) % ckpt_interval == 0) or epoch == num_epochs - 1:
            save_ckpt(f'last')
            if wandb_log: # save last ckpt so you can resume from it if need be
                wandb.save(os.path.abspath(outdir)+'/last.pth', base_path=os.path.abspath(outdir))

        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
                "val/loss": np.mean(val_losses[-(val_i+1):]),
                "train/lr": lrs[-1],
                "train/num_steps": len(losses),
                "val/num_steps": len(val_losses),
                "train/cosine_sim_base": sims_base / (train_i + 1),
                "val/cosine_sim_base": val_sims_base / (val_i + 1),
                "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
                "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
                "val/val_fwd_pct_correct": val_fwd_percent_correct / (val_i + 1),
                "val/val_bwd_pct_correct": val_bwd_percent_correct / (val_i + 1),
                "train/mse_losses": np.mean(mse_losses[-(train_i+1):]),
                "val/mse_losses": np.mean(val_mse_losses[-(val_i+1):])}
        progress_bar.set_postfix(**logs)
        
        with torch.no_grad():
            if not att:
                clip_target = clip_target.reshape(-1, 257, 768)
                clip_voxels = clip_voxels.reshape(-1, 257, 768)
            else:
                clip_target = clip_target.reshape(-1, 257, 1024)
                clip_voxels = clip_voxels.reshape(-1, 257, 1024)
            ww = 2 # only reconstruct one sample in batch
            for ee,in_emb in enumerate([clip_target, clip_voxels]):
                if ee==0 and epoch>0:
                    continue
                if att:
                    input_embedding = image_encoder.vision_model.post_layernorm(in_emb[[ww]])
                    input_embedding = image_encoder.visual_projection(input_embedding)
                    input_embedding = nn.functional.normalize(input_embedding,dim=-1)
                else:
                    input_embedding = nn.functional.normalize(in_emb[[ww]],dim=-1)

                # Sample an image from the diffusion model conditioned on `input_embedding`,
                # then decode the final latents. The result is stored in `recons_clip` on the
                # ee==0 pass and in `recons` otherwise.
                # NOTE(review): repeat(1, 1, 1) does not tile a 3-D tensor (it only makes a copy);
                # on a 2-D tensor it would prepend a size-1 leading dim. Intent unclear — confirm.
                input_embedding = input_embedding.repeat(1, 1, 1)
                # Classifier-free guidance: prepend an all-zeros copy as the unconditional input,
                # so the first half of the batch is unconditional and the second half conditional
                # (this order must match the noise_pred.chunk(2) unpacking in the loop below).
                # NOTE(review): the batch is doubled even when do_classifier_free_guidance is False,
                # in which case latent_model_input (B) and encoder_hidden_states (2B) would disagree
                # in batch size — confirm guidance is always enabled here.
                input_embedding = torch.cat([torch.zeros_like(input_embedding), input_embedding]).to(device)

                # 4. Prepare timesteps
                # Number of denoising steps is hard-coded to 20.
                noise_scheduler.set_timesteps(num_inference_steps=20, device=device)

                # 5b. Prepare latent variables
                batch_size = input_embedding.shape[0] // 2 # divide by 2 bc we doubled it for classifier-free guidance
                # Latent-space shape: spatial size is the image size divided by the VAE scale factor.
                # NOTE(review): newer diffusers releases may deprecate `unet.in_channels` in favour of
                # `unet.config.in_channels` — check against the installed version.
                shape = (batch_size, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor)

                timesteps = noise_scheduler.timesteps
                # Start from Gaussian noise scaled to the scheduler's expected initial sigma.
                latents = randn_tensor(shape, device=device, dtype=input_embedding.dtype)
                latents = latents * noise_scheduler.init_noise_sigma

                # 7. Denoising loop
                # (the loop index `i` is unused)
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)

                    # Predict the noise residual, conditioned on the (unconditional, conditional) embeddings.
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embedding).sample

                    # perform guidance
                    # Push the prediction away from the unconditional branch by `guidance_scale`.
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
                # Decode latents to images on CPU; the ee==0 pass feeds the "orig CLIP" panel,
                # any other pass feeds the "brain" panel.
                if ee==0:
                    recons_clip = decode_latents(latents).detach().cpu()
                else:
                    recons = decode_latents(latents).detach().cpu()
                
            # Visual sanity check: the original image next to the reconstruction from the
            # original CLIP embedding (`recons_clip`) and the reconstruction from the
            # brain-predicted embedding (`recons`).
            num_xaxis_subplots = 3
            fig, ax = plt.subplots(1, num_xaxis_subplots,
                                   figsize=(9,3), facecolor=(1, 1, 1))
            panels = (
                ("Original Image", image[[ww]]),
                ("Recon from orig CLIP", recons_clip),
                ("Recon from brain", recons),
            )
            for panel_ax, (title, panel_img) in zip(ax, panels):
                panel_ax.set_title(title)
                panel_ax.imshow(utils.torch_to_Image(panel_img))
                panel_ax.axis('off')
            if wandb_log and local_rank==0:
                # wandb.Image rasterises the figure here, so it is safe to close it afterwards.
                logs["val/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
            else:
                plt.show()
            # Close the figure on both paths. plt.show() is a no-op on non-interactive
            # backends (e.g. a batch/slurm job), so without this every call on that path
            # would leave another open figure behind and memory would keep growing.
            plt.close(fig)

        if wandb_log:
            wandb.log(logs)
            
    if distributed:
        dist.barrier()

# Close the wandb run; only the main process (local_rank 0) does this.
if wandb_log and local_rank==0:
    wandb.finish()

# `print` is accelerator.print here, so this is emitted once (main process only).
# A bare `exit` previously followed this line. As an expression statement it never
# called the exit helper, so it stopped neither the exported script nor the notebook.
# It has been removed: the cells below contain only commented-out scratch code.
print("\n===Finished!===\n")

In [None]:
# with torch.no_grad():
#     clip_target = clip_extractor.embed_image(image).float()
#     clip_target=clip_target.reshape(len(clip_target),-1).to(voxel.dtype)
#     in_emb = clip_target.reshape(-1, 257, 1024).detach().clone()
#     ww = 2
    
#     in_emb[:,:200] = torch.rand_like(in_emb[:,:200]) * 10
    
#     input_embedding = image_encoder.vision_model.post_layernorm(in_emb[[ww]])
#     input_embedding = image_encoder.visual_projection(input_embedding)
#     input_embedding = nn.functional.normalize(input_embedding,dim=-1)

#     input_embedding = input_embedding.repeat(1, 1, 1)
#     input_embedding = torch.cat([torch.zeros_like(input_embedding), input_embedding]).to(device)

#     # 4. Prepare timesteps
#     noise_scheduler.set_timesteps(num_inference_steps=20, device=device)

#     # 5b. Prepare latent variables
#     batch_size = input_embedding.shape[0] // 2 # divide by 2 bc we doubled it for classifier-free guidance
#     shape = (batch_size, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor)

#     timesteps = noise_scheduler.timesteps
#     latents = randn_tensor(shape, device=device, dtype=input_embedding.dtype)
#     latents = latents * noise_scheduler.init_noise_sigma

#     # 7. Denoising loop
#     for i, t in enumerate(timesteps):
#         # expand the latents if we are doing classifier free guidance
#         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
#         latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)

#         noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embedding).sample

#         # perform guidance
#         if do_classifier_free_guidance:
#             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
#             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

#         # compute the previous noisy sample x_t -> x_t-1
#         latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
#     recons_clip = decode_latents(latents).detach().cpu()
#     plt.imshow(utils.torch_to_Image(image[ww]))
#     plt.show()
#     plt.imshow(utils.torch_to_Image(recons_clip))
#     plt.show()

In [None]:
# targs = clip_target
# preds = clip_voxels

# temp=0.125
# clip_clip = (targs @ targs.T)/temp
# brain_clip = (preds @ targs.T)/temp

# loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
# loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.T.softmax(-1)).sum(-1).mean()

# loss = (loss1 + loss2)/2

In [None]:
# def soft_clip_loss(preds, targs, temp=0.125, distributed=False, accelerator=None):
#     clip_clip = (targs @ targs.T)/temp
#     brain_clip = (preds @ targs.T)/temp
    
#     loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
#     loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.T.softmax(-1)).sum(-1).mean()
    
#     loss = (loss1 + loss2)/2
#     return loss

# def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False, accelerator=None, local_rank=None):
#     brain_clip = (preds @ targs.T)/temp
    
#     probs = torch.diag(betas)
#     probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas

#     loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()