This notebook uses a diffusion prior approach to map CLIP-fMRI to CLIP-Image space. It is end-to-end in the sense that we are also fine-tuning the Brain-to-CLIP mapping, but this Brain-to-CLIP mapping is initialized separately first.

In [1]:
# # convert this notebook to .py such that you can then run it via slurm with "sbatch *.slurm"
# from subprocess import call
# command = "jupyter nbconvert CLIP_to_CLIP.ipynb --to python"
# call(command,shell=True)

# Import packages & functions

In [2]:
import os
import sys
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm
from info_nce import InfoNCE
from dalle2_pytorch import DiffusionPriorNetwork
import kornia
from kornia.augmentation.container import AugmentationSequential

import utils
from utils import torch_to_matplotlib, torch_to_Image
from models import OpenClipper, Clipper, BrainNetwork, BrainDiffusionPrior, BrainSD
from model3d import SimpleVoxel3dConvEncoder

import torch.distributed as dist
from accelerate import Accelerator

from diffusers import UniPCMultistepScheduler

# uses tf32 data type which is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# Multi-GPU config #
accelerator = Accelerator()
print = accelerator.print # only print if local_rank=0
device = accelerator.device
print("device:",device)

# a machine without CUDA devices is still counted as one device
num_devices = max(torch.cuda.device_count(), 1)
num_workers = num_devices

print(accelerator.state)
local_rank = accelerator.state.local_process_index
world_size = accelerator.state.num_processes
# flagged as distributed as soon as more than one device or more than one process is present
distributed = num_devices > 1 or world_size > 1
print("distributed =",distributed,"num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)

device: cuda
Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

distributed = False num_devices = 1 local rank = 0 world size = 1


# Configurations

In [3]:
# In an interactive session the argparser is fed from this list instead of the command line
if utils.is_interactive():
    jupyter_args = ["--voxel2clip_path=../train_logs/v2c_avg_v0_partialFalse/best.pth"]
    print(jupyter_args)

['--voxel2clip_path=../train_logs/v2c_avg_v0_partialFalse/best.pth']


In [4]:
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--flag=False`` silently
    enabled the flag. This converter keeps the ``--flag=True`` / ``--flag=False``
    syntax working and makes it mean what it says.

    Args:
        value: the raw string from the command line (or an actual bool).

    Returns:
        The parsed boolean. An empty string maps to False, as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line configuration for the training run; every option has a default
# so the script can also be run without any arguments.
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging",
)
parser.add_argument(
    "--voxel2clip_path", type=str, default="None",
    help="pretrained checkpoint to initialize voxel2clip",
)
parser.add_argument(
    "--modality", type=str, default="image", choices=["image", "text"],
    help="image or text",
)
parser.add_argument(
    "--batch_size", type=int, default=64,
    help="Our maximum for A100 was 64 for 1dim voxels and 32 for 3dim voxels",
)
parser.add_argument(
    "--combine_losses", type=_str2bool, default=True,
    help="if True train voxel2clip and diffusion prior, otherwise, just train the prior",
)
parser.add_argument(
    "--alpha_schedule", type=str, default="constant", choices=["constant", "linear"],
    help="alpha is weight for MSE diffusion prior loss",
)
parser.add_argument(
    "--clip_variant",type=str,default="ViT-L/14",choices=["ViT-L/14"],# "RN50", "ViT-L/14", "ViT-B/32", "ViT-H-14"
    help='clip / openclip variant',
)
parser.add_argument(
    "--outdir",type=str,default=None,
    help="output directory for logs and checkpoints",
)
parser.add_argument(
    "--wandb_log",type=_str2bool,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--wandb_auto_resume",type=_str2bool,default=True,
    help="automatically resume wandb run if it stops and restarts ",
)
parser.add_argument(
    "--resume_from_ckpt",type=_str2bool,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--use_mixco", type=_str2bool, default=True,
    help="use mixup contrastive loss for voxel2clip",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.5,
    help="proportion of way through training when to switch from InfoNCE to soft_clip_loss",
)
parser.add_argument(
    "--voxel_dims",type=int,default=1,choices=[1, 3],
    help="1 for flattened input, 3 for 3d input",
)
parser.add_argument(
    "--use_image_aug",type=_str2bool,default=False,
    help="whether to use image augmentation (only used for modality=image)",
)
parser.add_argument(
    "--num_epochs",type=int,default=120,
)
parser.add_argument(
    "--lr_scheduler",type=str,default='cycle',choices=['cycle','fixed'],
)
parser.add_argument(
    "--ckpt_saving",type=_str2bool,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=1,
    help="save ckpt every x epochs",
)
parser.add_argument(
    "--save_at_end",type=_str2bool,default=False,
    help="if False, will save best.ckpt whenever epoch shows best validation score",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--n_samples_save",type=int,default=0,
    help="Number of reconstructions for monitoring progress, 0 will speed up training",
)
parser.add_argument(
    "--sd_scheduler",type=str,default="unipcm",choices=["unipcm","pndms"],
    help="Noise scheduler for image reconstructions",
)

# Interactive sessions parse the hand-written jupyter_args list;
# passing None makes argparse fall back to the real command line.
cli_args = jupyter_args if utils.is_interactive() else None
args = parser.parse_args(cli_args)

# create global variables without the args prefix
globals().update(vars(args))

In [5]:
# the command line uses the literal string "None" to mean "no checkpoint"
voxel2clip_path = None if voxel2clip_path=='None' else voxel2clip_path

# steps for diffusion model to output pixel image
num_inference_steps = 20 if sd_scheduler=='unipcm' else 50
recons_per_clip = 1
recons_per_brain = 1

max_lr = 3e-4

# fall back to a per-model folder under train_logs, then make sure it exists
outdir = f'../train_logs/{model_name}' if outdir is None else outdir
if not os.path.exists(outdir):
    os.makedirs(outdir,exist_ok=True)

# train_augs stays None unless image augmentation was requested
train_augs = None
if use_image_aug:
    train_augs = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.5),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
        data_keys=["input"],
    )

# the annotation array is only loaded for the text modality
if modality=='text':
    annots = np.load("/fsx/proj-medarc/fmri/natural-scenes-dataset/COCO_73k_annots_curated.npy")

# Prep models and data loaders

In [6]:
# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

print('Pulling NSD webdataset data...')
train_url = "{/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/train/train_subj01_{0..17}.tar,/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/val/val_subj01_0.tar}"
val_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/test/test_subj01_{0..1}.tar"
meta_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/metadata_subj01.json"
num_train = 8559 + 300
num_val = 982

# which to use for the voxels
voxels_key_by_dims = {1: 'nsdgeneral.npy', 3: 'wholebrain_3d.npy'}
if voxel_dims not in voxels_key_by_dims:
    raise Exception(f"voxel_dims must be 1 or 3, not {voxel_dims}")
voxels_key = voxels_key_by_dims[voxel_dims]

print('Prepping train and validation dataloaders...')
dataloader_kwargs = dict(
    num_devices=num_devices,
    num_workers=num_workers,
    train_url=train_url,
    val_url=val_url,
    meta_url=meta_url,
    num_train=num_train,
    num_val=num_val,
    val_batch_size=300,
    cache_dir="/tmp/wds-cache",
    seed=seed,
    voxels_key=voxels_key,
    to_tuple=["voxels", "images", "coco"],
    local_rank=local_rank,
)
# num_train / num_val are overwritten with the values the helper returns
train_dl, val_dl, num_train, num_val = utils.get_dataloaders(batch_size, 'images', **dataloader_kwargs)

# 3D input only: coordinates of voxels whose noise-ceiling value exceeds 0.5
if voxel_dims == 3:
    import nibabel as nib
    noise_ceils_path = '/fsx/proj-medarc/fmri/natural-scenes-dataset/temp_s3/nsddata_betas/ppdata/subj01/func1pt8mm/betas_fithrf_GLMdenoise_RR/ncsnr.nii.gz'
    noise_ceils = nib.load(noise_ceils_path).get_fdata()
    x_inc,y_inc,z_inc = np.where(noise_ceils > .5) # voxel.shape torch.Size([300, 3, 68, 64, 47])

Note: not using cudnn.deterministic
Pulling NSD webdataset data...
Prepping train and validation dataloaders...
Getting dataloaders...

num_train 8859
global_batch_size 64
batch_size 64
num_workers 1
num_batches 138
num_worker_batches 138
cache_dir None

num_val 982
val_batch_size 300
val_num_workers 1


In [7]:
print('Creating Clipper...')

# Don't L2 norm the extracted CLIP embeddings since we want the prior
# to learn un-normed embeddings for usage with the SD image variation pipeline.
# ViT-H-14 goes through the OpenClipper wrapper, every other variant through Clipper.
clipper_cls = OpenClipper if clip_variant == "ViT-H-14" else Clipper
clip_extractor = clipper_cls(clip_variant, device=device, train_transforms=train_augs)

print('Creating voxel2clip...')

# output width of the voxel2clip mapper: 1024 for ViT-H-14 / RN50, 768 for the rest
clip_emb_dim = 1024 if clip_variant in ("ViT-H-14", "RN50") else 768

if voxel_dims == 1: # 1D data
    voxel2clip_kwargs = dict(out_dim=clip_emb_dim)
    voxel2clip = BrainNetwork(**voxel2clip_kwargs)
elif voxel_dims == 3: # 3D data
    # NOTE(review): `voxel` is not assigned anywhere before this point in the
    # visible code, so this branch presumably raises NameError on a fresh run.
    # The intended spatial dims need to be confirmed before this can be fixed.
    voxel2clip_kwargs = dict(
        out_dim=clip_emb_dim,
        dims=voxel.shape[2:],
        channels=[64, 128, 256, 128],
        strides=[1, 2, 3, 3],
        padding=[1, 1, 1, 1],
        dilation=[1, 1, 1, 1],
        kernel=[3, 3, 3, 3],
    )
    voxel2clip = SimpleVoxel3dConvEncoder(**voxel2clip_kwargs)

print("params of voxel2clip:")
if local_rank==0: utils.count_params(voxel2clip)
    
# When only the prior is trained, voxel2clip is loaded from a checkpoint and frozen.
if not combine_losses:
    # without a checkpoint there is nothing to freeze; fail with a readable message
    if voxel2clip_path is None:
        raise ValueError("--voxel2clip_path must be given when combine_losses is False")

    voxel2clip.to(device)
    checkpoint = torch.load(voxel2clip_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    try:
        voxel2clip.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch: retry after converting ddp model to non-ddp format.
        # Any other failure (e.g. a missing 'model_state_dict' entry) is raised as-is.
        state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()}
        voxel2clip.load_state_dict(state_dict)

    # freeze when not combining models
    voxel2clip.eval()
    voxel2clip.requires_grad_(False)
    
print('Creating diffusion prior...')
# not using prior_kwargs b/c the model is pretrained
# kwargs for DiffusionPriorNetwork
prior_network_kwargs = dict()
# kwargs for DiffusionNetwork; voxel2clip is only handed over when both are trained together
diffusion_kwargs = dict(
    condition_on_text_encodings=False,
    timesteps=1000,
    voxel2clip=voxel2clip if combine_losses else None,
)
diffusion_prior = BrainDiffusionPrior.from_pretrained(
    prior_network_kwargs,
    diffusion_kwargs,
    voxel2clip_path=voxel2clip_path if combine_losses else None,
)

print("params of diffusionprior:")
if local_rank==0: utils.count_params(diffusion_prior)

# The Stable Diffusion components are only needed when sample reconstructions are saved
if n_samples_save > 0:
    if local_rank == 0: print('Creating SD image variation pipeline...')
    from diffusers import StableDiffusionImageVariationPipeline
    from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel, UniPCMultistepScheduler

    sd_snapshot = ".cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/a2a13984e57db80adcc9e3f85d568dcccb9b29fc/"
    sd_cache_dir = os.path.join(os.path.expanduser('~'), sd_snapshot)
    if not os.path.isdir(sd_cache_dir): # download from huggingface if not already downloaded / cached
        print("Downloading lambdalabs/sd-image-variations-diffusers from huggingface...")
        sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained("lambdalabs/sd-image-variations-diffusers", revision="v2.0")
        sd_cache_dir = "lambdalabs/sd-image-variations-diffusers"

    unet = UNet2DConditionModel.from_pretrained(sd_cache_dir,subfolder="unet").to(device)
    vae = AutoencoderKL.from_pretrained(sd_cache_dir,subfolder="vae").to(device)
    # the scheduler config is always read as PNDM and converted when unipcm was requested
    noise_scheduler = PNDMScheduler.from_pretrained(sd_cache_dir, subfolder="scheduler")
    if sd_scheduler=='unipcm':
        noise_scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)

    # neither network is trained here: eval mode, no gradients
    for frozen_module in (unet, vae):
        frozen_module.eval()
        frozen_module.requires_grad_(False)
    
# Parameters whose name contains an entry of no_decay get no weight decay
no_decay = ['bias']
decay_params, no_decay_params = [], []
for param_name, param in diffusion_prior.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
opt_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 1e-2},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=5e-4) # lr doesnt get used if lr_scheduler='cycle'

# lr_scheduler arrives as the CLI string and is replaced by the scheduler object (or None)
if lr_scheduler == 'cycle':
    global_batch_size = batch_size * num_devices
    total_steps = num_epochs*(num_train//global_batch_size)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )
elif lr_scheduler == 'fixed':
    lr_scheduler = None
    
def save_ckpt(tag):
    """Write the current training state to ``{outdir}/{tag}.pth``.

    Reads the module-level ``epoch``, ``diffusion_prior``, ``optimizer``,
    ``lr_scheduler``, ``losses``, ``val_losses`` and ``lrs``.

    Args:
        tag: checkpoint file stem, e.g. ``'best'`` or ``'last'``.

    The ``'lr_scheduler'`` entry is None when no scheduler is in use
    (``--lr_scheduler=fixed``); previously that configuration crashed here.
    ``'voxel2clip_state_dict'`` is included only when ``diffusion_prior``
    exposes a ``voxel2clip`` module; otherwise a notice is printed and the
    rest of the state is still saved.
    """
    ckpt_path = outdir+f'/{tag}.pth'
    print(f'saving {ckpt_path}',flush=True)
    state = {
        'epoch': epoch,
        'model_state_dict': diffusion_prior.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None,
        'train_losses': losses,
        'val_losses': val_losses,
        'lrs': lrs,
    }
    try:
        state['voxel2clip_state_dict'] = diffusion_prior.voxel2clip.state_dict()
    except AttributeError:
        # voxel2clip is None or not reachable as an attribute of diffusion_prior
        print("couldnt save voxel2clip.state_dict")
    torch.save(state, ckpt_path)

print("\nDone with model preparations!")

Creating Clipper...
ViT-L/14 cuda
Creating voxel2clip...
params of voxel2clip:
param counts:
134,722,304 total
134,722,304 trainable
Creating diffusion prior...
params of diffusionprior:
param counts:
236,616,336 total
236,616,320 trainable

Done with model preparations!


# Weights and Biases

In [8]:
# params for wandb
# Only the local rank-0 process logs. The project name comes from --wandb_project
# (default "stability") instead of being overwritten with a hard-coded value.
if local_rank==0 and wandb_log:
    import wandb

    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')#, relogin=True)
    # Only names that exist in this script are logged; the former
    # "clamp_embs" and "ckpt_path" entries referenced undefined variables.
    wandb_config = {
      "model_name": model_name,
      "modality": modality,
      "voxel_dims": voxel_dims,
      "clip_variant": clip_variant,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      # the global lr_scheduler has been replaced by the scheduler object by now,
      # so log the original command-line choice instead
      "lr_scheduler": args.lr_scheduler,
      "mixup_pct": mixup_pct,
      "num_train": num_train,
      "num_val": num_val,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
      "resume_from_ckpt": resume_from_ckpt,
      "train_url": train_url,
      "val_url": val_url,
      "voxel2clip_path": voxel2clip_path,
      "combine_losses": combine_losses,
      "use_mixco": use_mixco,
      "n_samples_save": n_samples_save,
      "sd_scheduler": sd_scheduler,
    }
    print("wandb_config:\n",wandb_config)

    wandb_init_kwargs = dict(
        project=wandb_project,
        name=wandb_run,
        config=wandb_config,
        notes=wandb_notes,
    )
    if wandb_auto_resume:
        # a fixed run id lets a restarted job attach to the same wandb run
        print("wandb_id:",model_name)
        wandb_init_kwargs.update(id=model_name, resume="allow")
    wandb.init(**wandb_init_kwargs)

# Huggingface Accelerate

In [9]:
# Hand everything used in training to Accelerate and keep the objects it returns
prepared = accelerator.prepare(diffusion_prior, optimizer, train_dl, val_dl, lr_scheduler)
diffusion_prior, optimizer, train_dl, val_dl, lr_scheduler = prepared

# Main

In [10]:
# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

# Running training state; save_ckpt() reads these module-level names.
epoch = 0
losses, val_losses, lrs = [], [], []
best_val_loss = 1e9
# one temperature per epoch of the soft_clip_loss phase (the epochs after the mixup phase)
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

if use_mixco:
    contrast_loss = utils.mixco_nce
else:
    contrast_loss = InfoNCE()
print('contrast_loss', contrast_loss)

# weight for prior's MSE loss term
if alpha_schedule == 'constant':
    alphas = np.ones(num_epochs) * 0.01
elif alpha_schedule == 'linear':
    alphas = np.linspace(0.01, 0.05, num_epochs, endpoint=True)
else:
    raise ValueError(f'unknown alpha_schedule: {alpha_schedule}')

# first validation batch, captured in the loop and reused for sample reconstructions
val_voxel0 = val_image0 = None

# Optionally resume from checkpoint #
# NOTE(review): in the visible code wandb is only imported on local_rank 0, so with
# wandb_log=True the other ranks would presumably hit a NameError here - confirm.
if (wandb_log and wandb.run.resumed) or resume_from_ckpt:
    print("\n---resuming from last.pth ckpt---\n")
    checkpoint = torch.load(outdir+'/last.pth', map_location=device)
    # NOTE(review): the checkpoint stores the epoch that had just finished, so training
    # restarts by repeating that epoch; checkpoint['epoch'] + 1 may be intended - confirm.
    epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # there is no scheduler to restore with lr_scheduler='fixed'
    if lr_scheduler is not None and checkpoint.get('lr_scheduler') is not None:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    diffusion_prior.load_state_dict(checkpoint['model_state_dict'])
    try: # will fail if not using pretrained voxel2clip
        diffusion_prior.voxel2clip.load_state_dict(checkpoint['voxel2clip_state_dict'])
    except (AttributeError, KeyError):
        # no voxel2clip attached to the prior, or none stored in the checkpoint
        print("no separate voxel2clip state restored from checkpoint")

# Main loop: one pass over train_dl and one over val_dl per epoch, then
# checkpointing, logging and optional sample reconstruction on local rank 0.

# The retained voxel coordinates never change, so derive them once, not per batch.
if voxel_dims == 3:
    x_keep, y_keep, z_keep = np.unique(x_inc), np.unique(y_inc), np.unique(z_inc)

progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
for epoch in progress_bar:
    diffusion_prior.train()

    # per-epoch accumulators, averaged over the number of batches when logged
    sims = 0.
    sims_base = 0.
    val_sims = 0.
    val_sims_base = 0.
    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    val_fwd_percent_correct = 0.
    val_bwd_percent_correct = 0.
    loss_nce_sum = 0.
    loss_prior_sum = 0.
    val_loss_nce_sum = 0.
    val_loss_prior_sum = 0.

    alpha = alphas[epoch]

    for train_i, (voxel, image, coco) in enumerate(train_dl):
        optimizer.zero_grad()

        # index along dim 1 of the voxel batch, cycling 0,1,2 with the batch counter
        # (presumably the repeated presentations of each image - confirm)
        repeat_index = train_i % 3

        image = image.float()
        voxel = voxel[:,repeat_index].float()

        if voxel_dims == 3:
            voxel = voxel[:,x_keep,:,:]
            voxel = voxel[:,:,y_keep,:]
            voxel = voxel[:,:,:,z_keep]

        if modality=='text':
            img_annots = utils.select_annotations(annots[coco.cpu().numpy()], random=True)
            clip_target = clip_extractor.embed_text(img_annots).float()
        else:
            clip_target = clip_extractor.embed_image(image).float()
        # .to() is not in-place; the result has to be assigned to take effect
        clip_target = clip_target.to(voxel.dtype)

        if use_mixco and epoch < int(mixup_pct * num_epochs):
            voxel, perm, betas, select = utils.mixco(voxel)
            clip_target = utils.mixco_clip_target(clip_target, perm, select, betas)

        if combine_losses:
            # loss here is MSE for the prior, clip_voxels are voxel2clip outputs
            loss, pred, clip_voxels = diffusion_prior(image_embed=clip_target, voxel=voxel)
            utils.check_loss(loss)
            # combine losses for contrastive learned voxel2clip mapper and the prior
            if use_mixco:
                if epoch < int(mixup_pct * num_epochs):
                    loss_nce = contrast_loss(
                        nn.functional.normalize(clip_voxels, dim=-1),
                        nn.functional.normalize(clip_target, dim=-1),
                        temp=0.006, perm=perm, betas=betas, select=select,
                        distributed=distributed, accelerator=accelerator, local_rank=local_rank)

                else:
                    epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                    loss_nce = utils.soft_clip_loss(
                        nn.functional.normalize(clip_voxels, dim=-1),
                        nn.functional.normalize(clip_target, dim=-1),
                        temp=epoch_temp,
                        distributed=distributed, accelerator=accelerator)
            else:
                loss_nce = contrast_loss(
                    nn.functional.normalize(clip_voxels, dim=-1),
                    nn.functional.normalize(clip_target, dim=-1),
                )
            utils.check_loss(loss_nce)

            loss_nce_sum += loss_nce.item()
            loss_prior_sum += loss.item() * .01

            # MSE and NCE are weighted equally at the beginning,
            # with alpha=0.01 we'll have something like .01*300 + .99*3 = 3 + 3
            loss = alpha * loss + (1-alpha) * loss_nce
        else:
            clip_voxels = voxel2clip(voxel)
            loss, pred, _ = diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)

            loss = loss * .01 # to keep scale of losses roughly similar
            loss_prior_sum += loss.item()

        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        # similarity after prior diffusion
        sims += F.cosine_similarity(clip_target,pred).mean().item()
        # baseline similarity before prior diffusion
        sims_base += F.cosine_similarity(clip_target,clip_voxels).mean().item()

        # forward and backward top 1 accuracy
        # (.item() as in the validation loop, so plain floats are accumulated and logged)
        labels = torch.arange(len(clip_target)).to(device)
        fwd_percent_correct += utils.topk(
            utils.batchwise_cosine_similarity(clip_target, clip_voxels), labels, k=1
        ).item()
        bwd_percent_correct += utils.topk(
            utils.batchwise_cosine_similarity(clip_voxels, clip_target), labels, k=1
        ).item()

        accelerator.backward(loss)
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

    diffusion_prior.eval()
    for val_i, (voxel, image, coco) in enumerate(val_dl):
        with torch.no_grad():
            repeat_index = val_i % 3

            image = image.float()
            voxel = voxel[:,repeat_index].float()

            if voxel_dims == 3:
                voxel = voxel[:,x_keep,:,:]
                voxel = voxel[:,:,y_keep,:]
                voxel = voxel[:,:,:,z_keep]

            # keep the very first validation batch for the sample reconstructions below
            if val_image0 is None:
                val_image0 = image.detach().clone()
                val_voxel0 = voxel.detach().clone()

            if modality=='text':
                img_annots = utils.select_annotations(annots[coco.cpu().numpy()], random=False)
                clip_target = clip_extractor.embed_text(img_annots).float()
            else:
                clip_target = clip_extractor.embed_image(image).float()
            clip_target = clip_target.to(voxel.dtype)

            if combine_losses:
                loss, pred, clip_voxels = diffusion_prior(image_embed=clip_target, voxel=voxel)
                utils.check_loss(loss)
                if use_mixco:
                    if epoch < int(mixup_pct * num_epochs):
                        # no mixing is applied to validation data, hence no perm/betas/select
                        loss_nce = utils.mixco_nce(
                            nn.functional.normalize(clip_voxels, dim=-1),
                            nn.functional.normalize(clip_target, dim=-1),
                            temp=0.006,
                            distributed=distributed, accelerator=accelerator, local_rank=local_rank)
                    else:
                        epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                        loss_nce = utils.soft_clip_loss(
                            nn.functional.normalize(clip_voxels, dim=-1),
                            nn.functional.normalize(clip_target, dim=-1),
                            temp=epoch_temp,
                            distributed=distributed, accelerator=accelerator)
                else:
                    loss_nce = contrast_loss(
                        nn.functional.normalize(clip_voxels, dim=-1),
                        nn.functional.normalize(clip_target, dim=-1),
                    )
                utils.check_loss(loss_nce)

                val_loss_nce_sum += loss_nce.item()
                val_loss_prior_sum += loss.item() * .01

                val_loss = alpha * loss + (1-alpha) * loss_nce
            else:
                clip_voxels = voxel2clip(voxel)
                loss, pred, _ = diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)

                loss = loss * .01
                val_loss = loss
                val_loss_prior_sum += loss.item()

            val_losses.append(val_loss.item())
            val_sims += F.cosine_similarity(clip_target, pred).mean().item()
            val_sims_base += F.cosine_similarity(clip_target, clip_voxels).mean().item()
            labels = torch.arange(len(clip_voxels)).to(device)
            val_fwd_percent_correct += utils.topk(
                utils.batchwise_cosine_similarity(clip_target, clip_voxels), labels, k=1
            ).item()
            val_bwd_percent_correct += utils.topk(
                utils.batchwise_cosine_similarity(clip_voxels, clip_target), labels, k=1
            ).item()

    if local_rank==0:
        if (not save_at_end and ckpt_saving) or (save_at_end and epoch == num_epochs - 1):
            # save best model
            # val_losses holds every batch of every epoch; the tail is this epoch's share
            val_loss = np.mean(val_losses[-(val_i+1):])
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_ckpt('best')
            else:
                print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f}')

        # Save model checkpoint every `ckpt_interval`` epochs or on the last epoch
        if (ckpt_interval is not None and (epoch + 1) % ckpt_interval == 0) or epoch == num_epochs - 1:
            save_ckpt('last')
            if wandb_log: # save last ckpt so you can resume from it if need be
                wandb.save(os.path.abspath(outdir)+'/last.pth', base_path=os.path.abspath(outdir))

        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
                "val/loss": np.mean(val_losses[-(val_i+1):]),
                "train/loss_nce": loss_nce_sum / (train_i + 1),
                "train/loss_mse": loss_prior_sum / (train_i + 1),
                "val/loss_nce": val_loss_nce_sum / (val_i + 1),
                "val/loss_mse": val_loss_prior_sum / (val_i + 1),
                "train/lr": lrs[-1],
                "train/num_steps": len(losses),
                "val/num_steps": len(val_losses),
                "train/sim": sims / (train_i + 1),
                "val/sim": val_sims / (val_i + 1),
                "train/cosine_sim_base": sims_base / (train_i + 1),
                "val/cosine_sim_base": val_sims_base / (val_i + 1),
                "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
                "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
                "val/val_fwd_pct_correct": val_fwd_percent_correct / (val_i + 1),
                "val/val_bwd_pct_correct": val_bwd_percent_correct / (val_i + 1),
                "train/alpha": alpha}
        progress_bar.set_postfix(**logs)

        # sample some images
        if (ckpt_interval is not None and (epoch + 1) % ckpt_interval == 0) or epoch == num_epochs - 1:
            if n_samples_save > 0:
                # attach the frozen voxel2clip for the duration of the reconstruction call
                if not combine_losses:
                    diffusion_prior.voxel2clip = voxel2clip
                # validation
                grid = utils.reconstruct_from_clip(
                    val_image0, val_voxel0,
                    diffusion_prior,
                    clip_extractor, unet, vae, noise_scheduler,
                    img_lowlevel = None,
                    num_inference_steps = num_inference_steps,
                    n_samples_save = n_samples_save,
                    recons_per_clip = recons_per_clip,
                    recons_per_brain = recons_per_brain,
                    guidance_scale = 7.5,
                    img2img_strength = 1,
                    timesteps = 1000,
                    seed = seed,
                    retrieve = False,
                    plotting = True,
                )
                grid = grid[0]
                grid.savefig(os.path.join(outdir, f'samples-val-epoch{epoch:03d}.png'))
                if wandb_log:
                    logs["val/recons"] = wandb.Image(grid, caption=f"epoch{epoch:03d}")
                plt.close()
                if not combine_losses:
                    diffusion_prior.voxel2clip = None

        if wandb_log:
            wandb.log(logs)

    # Keep all processes in step while rank 0 saves and samples. Going through the
    # accelerator also works when `distributed` is True only because several GPUs are
    # visible to a single process, where a raw dist.barrier() has no process group.
    accelerator.wait_for_everyone()

if wandb_log and local_rank==0:
    wandb.finish()

print("\n===Finished!===\n")