This notebook uses a diffusion prior approach to map CLIP-fMRI to CLIP-Image space. It is end-to-end in the sense that we are also fine-tuning the Brain-to-CLIP mapping, but this Brain-to-CLIP mapping is initialized separately first.

In [2]:
# # convert this notebook to .py such that you can then run it via slurm with "sbatch *.slurm"
# from subprocess import call
# command = "jupyter nbconvert CLIP_to_CLIP.ipynb --to python"
# call(command,shell=True)

[NbConvertApp] Converting notebook CLIP_to_CLIP.ipynb to python
[NbConvertApp] Writing 28956 bytes to CLIP_to_CLIP.py


0

# Import packages & functions

In [1]:
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm
from info_nce import InfoNCE
from dalle2_pytorch import DiffusionPriorNetwork
import kornia
from kornia.augmentation.container import AugmentationSequential

import utils
from utils import torch_to_matplotlib, torch_to_Image
from models import Clipper, BrainNetwork, BrainDiffusionPrior, BrainSD
from model3d import SimpleVoxel3dConvEncoder

import torch.distributed as dist
from accelerate import Accelerator

from diffusers import UniPCMultistepScheduler

# uses tf32 data type which is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# Log to weights and biases?
wandb_log = True

# Resume from ckpt? #
resume_from_ckpt = False
if resume_from_ckpt:
    ckpt_path = '../train_logs/vox2clip_indiv/ckpt-voxel2clip-epoch029.pth'
else:
    ckpt_path = 'none'

# Multi-GPU config #
accelerator = Accelerator()
print = accelerator.print # only print if local_rank=0

device = accelerator.device
print("device:",device)

num_devices = torch.cuda.device_count()
if num_devices==0: num_devices = 1
num_workers = num_devices

print(accelerator.state)
local_rank = accelerator.state.local_process_index
world_size = accelerator.state.num_processes
if num_devices<=1 and world_size<=1:
    distributed=False
else:
    distributed=True
print("distributed =",distributed,"num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)

device: cuda
Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

distributed = False num_devices = 1 local rank = 0 world size = 1


# Configurations

In [6]:
voxel2clip_path = '../train_logs/v2c_avg_v0_partialFalse/best.pth' # ckpt path for voxel2clip model

combine_models = True # combine voxel2clip and prior into one model and train both end to end
combine_losses = True # when combine_models=True, use two terms in the loss, NCE and MSE
prior_pretrained = True # starting point = LAION aesthetics

model_name = "diffusion_prior_test"
if not wandb_log:
    model_name = "testing"

modality = "image"
if modality == "text":
    is_text = True
else:
    is_text = False
clip_variant = "ViT-L/14" # ("RN50", "ViT-L/14", "ViT-B/32")
clamp_embs = False # clamp embeddings to (-1.5, 1.5)
alpha_schedule = "constant" # ("constant", "linear") - alpha is weight the MSE DP loss
voxel_dims = 1 # 1 for flattened input, 3 for 3d input
seed = 42

# Currently, reconstructing back to image takes up too much memory. Better to save ckpts and separately evaluate!
n_samples_save = 0 # how many SD reconstruction samples to save to monitor progress
num_inference_steps = 20 # how many steps for diffusion model to output pixel image reconstruction
img2img_strength = .6 # closer to 0 the more the recon will look like the input image 
recons_per_clip = 2
recons_per_brain = 4

use_mixco = False # use mixup-contrastive on the voxels
mixup_pct = 0.5

# clip_aug_mode = 'none' # ('none', 'x', 'y')
# clip_aug_prob = 0.03 # prob of applying augmentation to a batch
sd_scheduler = 'unipcm' # scheduler for SD image variation pipeline ('pndms', 'unipcm')
use_image_aug = False # use image augmentation prior to getting CLIP embeddings

num_epochs = 60
if voxel_dims==1:
    batch_size = 64
else:
    batch_size = 32

lr_scheduler = 'cycle'
initial_lr = 5e-4 # only used if lr_scheduler is 'fixed'
max_lr = 3e-4
ckpt_saving = True
ckpt_interval = 1 #10
save_at_end = True
outdir = f'../train_logs/{model_name}'
if not os.path.exists(outdir):
    os.makedirs(outdir,exist_ok=True)
remote_data = False # if True, pull webdatasets from huggingface

train_augs = AugmentationSequential(
    kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
    kornia.augmentation.Resize((224, 224)),
    kornia.augmentation.RandomHorizontalFlip(p=0.5),
    kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
    kornia.augmentation.RandomGrayscale(p=0.3),
    data_keys=["input"],
)

# Prep models and data loaders

In [7]:
# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

if modality=='text':
    print('Using CLIP-text, preparing COCO annotations...')
    import h5py
    # load COCO annotations curated in the same way as the mind_reader (Lin Sprague Singh) preprint
    f = h5py.File('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_subj_indices.hdf5', 'r')
    subj01_order = f['subj01'][:]
    f.close()
    annots = np.load('/scratch/gpfs/KNORMAN/nsdgeneral_hdf5/COCO_73k_annots_curated.npy',allow_pickle=True)
    subj01_annots = annots[subj01_order]

print('Pulling NSD webdataset data...')
if remote_data:
    # pull data directly from huggingface
    train_url, val_url = utils.get_huggingface_urls(data_commit)
    meta_url = None
else:
    # local paths
    # data_commit = '9947586218b6b7c8cab804009ddca5045249a38d'
    # train_url = f"/fsx/proj-medarc/fmri/natural-scenes-dataset/{data_commit}/datasets_pscotti_naturalscenesdataset_resolve_{data_commit}_webdataset_train/train_subj01_{{0..49}}.tar"
    # val_url = f"/fsx/proj-medarc/fmri/natural-scenes-dataset/{data_commit}/datasets_pscotti_naturalscenesdataset_resolve_{data_commit}_webdataset_val/val_subj01_0.tar"
    # meta_url = None
    # num_train = num_val = None # None means use all samples as specified in webdataset metadata.json
    
    train_url = "{/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/train/train_subj01_{0..17}.tar,/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/val/val_subj01_0.tar}"
    val_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/test/test_subj01_{0..1}.tar"
    meta_url = "/fsx/proj-medarc/fmri/natural-scenes-dataset/webdataset_avg_split/metadata_subj01.json"
    num_train = 8559 + 300
    num_val = 982

# which to use for the voxels
if voxel_dims == 1:
    voxels_key = 'nsdgeneral.npy'
elif voxel_dims == 3:
    voxels_key = 'wholebrain_3d.npy'
else:
    raise Exception(f"voxel_dims must be 1 or 3, not {voxel_dims}")

print('Prepping train and validation dataloaders...')
train_dl, val_dl, num_train, num_val = utils.get_dataloaders(
    batch_size,'images',
    num_devices=num_devices,
    num_workers=num_workers,
    train_url=train_url,
    val_url=val_url,
    meta_url=meta_url,
    num_train=num_train,
    num_val=num_val,
    val_batch_size=300,
    cache_dir="/tmp/wds-cache",
    seed=seed,
    voxels_key=voxels_key,
    local_rank=local_rank,
)

if voxel_dims == 3:
    import nibabel as nib
    noise_ceils_path = '/fsx/proj-medarc/fmri/natural-scenes-dataset/temp_s3/nsddata_betas/ppdata/subj01/func1pt8mm/betas_fithrf_GLMdenoise_RR/ncsnr.nii.gz'
    noise_ceils = nib.load(noise_ceils_path).get_fdata()
    # plt.plot(np.sort(noise_ceils.flatten()))
    # plt.show()
    x_inc,y_inc,z_inc = np.where(noise_ceils > .5)

    # check that your data loader is working and save voxel shape after excluding low signal voxels
    for val_i, (voxel, img_input, key) in enumerate(val_dl):
        voxel = voxel[:,:,np.unique(x_inc),:,:]
        voxel = voxel[:,:,:,np.unique(y_inc),:]
        voxel = voxel[:,:,:,:,np.unique(z_inc)]
        print("voxel.shape", voxel.shape) # voxel.shape torch.Size([300, 3, 68, 64, 47])
        break

Note: not using cudnn.deterministic
Pulling NSD webdataset data...
Prepping train and validation dataloaders...
Getting dataloaders...

num_train 8859
global_batch_size 64
batch_size 64
num_workers 1
num_batches 138
num_worker_batches 138
cache_dir None

num_val 982
val_batch_size 300
val_num_workers 1


In [8]:
print('Creating Clipper...')
    
# Don't L2 norm the extracted CLIP embeddings since we want the prior 
# to learn un-normed embeddings for usage with the SD image variation pipeline.
clip_extractor = Clipper(clip_variant, clamp_embs=False, norm_embs=False, device=device, train_transforms=train_augs)

print('Creating voxel2clip...')

if voxel_dims == 1: # 1D data
    voxel2clip_kwargs = dict(out_dim=768)
    voxel2clip = BrainNetwork(**voxel2clip_kwargs)
elif voxel_dims == 3: # 3D data
    voxel2clip_kwargs = dict(
        out_dim=768,
        dims=voxel.shape[2:],
        channels=[64, 128, 256, 128],
        strides=[1, 2, 3, 3],
        padding=[1, 1, 1, 1],
        dilation=[1, 1, 1, 1],
        kernel=[3, 3, 3, 3],
    )
    voxel2clip = SimpleVoxel3dConvEncoder(**voxel2clip_kwargs)  

print("params of voxel2clip:")
if local_rank==0:
    utils.count_params(voxel2clip)
    
if not combine_models:
    # load voxel2clip model weights
    ckpt = torch.load(voxel2clip_path, map_location=device)
    if 'model_state_dict' in ckpt:
        ckpt = ckpt['model_state_dict']
    voxel2clip.load_state_dict(ckpt)

    # freeze when not combining models
    voxel2clip.eval()
    voxel2clip.requires_grad_(False)
    
if local_rank==0: print('Creating diffusion prior...')
prior_kwargs = dict(
    pretrained=prior_pretrained,
    network_kwargs=dict(),
    prior_kwargs=dict(),
)
if not prior_kwargs['pretrained']:
    # same as DALLE2-pytorch
    prior_network = DiffusionPriorNetwork(
        **prior_kwargs['network_kwargs'],
    )

    # custom version of DiffusionPrior from DALLE2-pytorch
    diffusion_prior = BrainDiffusionPrior(
        net=prior_network,
        voxel2clip=voxel2clip,
        **prior_kwargs['prior_kwargs'],
    )
else:
    # not using prior_kwargs b/c the model is pretrained
    diffusion_prior = BrainDiffusionPrior.from_pretrained(
        # kwargs for DiffusionPriorNetwork
        dict(),
        # kwargs for DiffusionNetwork
        dict(
            condition_on_text_encodings=False,
            timesteps=1000,
            voxel2clip=voxel2clip if combine_models else None,
        ),
        voxel2clip_path=voxel2clip_path if combine_models else None,
    )

print("params of diffusionprior:")
if local_rank==0: utils.count_params(diffusion_prior)

if n_samples_save > 0:
    if local_rank == 0: print('Creating SD image variation pipeline...')
    from diffusers import StableDiffusionImageVariationPipeline
    from diffusers import AutoencoderKL, PNDMScheduler, UNet2DConditionModel, UniPCMultistepScheduler
    
    sd_cache_dir = os.path.join(
                        os.path.expanduser('~'), 
                        ".cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/a2a13984e57db80adcc9e3f85d568dcccb9b29fc/"
                    )
    if not os.path.isdir(sd_cache_dir): # download from huggingface if not already downloaded / cached
        from diffusers import StableDiffusionImageVariationPipeline
        print("Downloading lambdalabs/sd-image-variations-diffusers from huggingface...")
        sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained("lambdalabs/sd-image-variations-diffusers", revision="v2.0")
        sd_cache_dir = "lambdalabs/sd-image-variations-diffusers"

    unet = UNet2DConditionModel.from_pretrained(sd_cache_dir,subfolder="unet").to(device)
    vae = AutoencoderKL.from_pretrained(sd_cache_dir,subfolder="vae").to(device)
    noise_scheduler = PNDMScheduler.from_pretrained(sd_cache_dir, subfolder="scheduler")
    if sd_scheduler=='unipcm':
        noise_scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)

    unet.eval() # dont want to train model
    unet.requires_grad_(False) # dont need to calculate gradients

    vae.eval()
    vae.requires_grad_(False)
    
no_decay = ['bias']
opt_grouped_parameters = [
    {'params': [p for n, p in voxel2clip.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in voxel2clip.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr) # lr doesnt get used if lr_scheduler='cycle'

if lr_scheduler == 'fixed':
    lr_scheduler = None
elif lr_scheduler == 'cycle':
    total_steps=num_epochs*(num_train//batch_size)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )
    
def save_ckpt(tag):
    ckpt_path = os.path.join(outdir, f'{tag}.pth')
    print(f'saving {ckpt_path}',flush=True)
    state_dict = diffusion_prior.state_dict()
    torch.save({
        'epoch': epoch,
        'model_state_dict': diffusion_prior.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': losses,
        'val_losses': val_losses,
        'lrs': lrs,
        }, ckpt_path)
        
print("\nDone with model preparations!")

Creating Clipper...
ViT-L/14 cuda
Creating voxel2clip...
params of voxel2clip:
param counts:
134,722,304 total
134,722,304 trainable
Creating diffusion prior...
params of diffusionprior:
param counts:
236,616,336 total
236,616,320 trainable
Creating SD image variation pipeline...

Done with model preparations!


# Weights and Biases

In [9]:
# params for wandb
if local_rank==0 and wandb_log:
    wandb_project = 'stability'
    wandb_run = model_name
    wandb_notes = ''
    
    if wandb_log: 
        import wandb
        print(f"wandb {wandb_project} run {wandb_run}")
        wandb.login(host='https://stability.wandb.io')#, relogin=True)
        wandb_config = {
          "model_name": model_name,
          "modality": modality,
          "voxel_dims": voxel_dims,
          "clip_variant": clip_variant,
          "batch_size": batch_size,
          "num_epochs": num_epochs,
          "use_image_aug": use_image_aug,
          "max_lr": max_lr,
          "lr_scheduler": lr_scheduler,
          "clamp_embs": clamp_embs,
          "mixup_pct": mixup_pct,
          "num_train": num_train,
          "num_val": num_val,
          "seed": seed,
          "distributed": distributed,
          "num_devices": num_devices,
          "world_size": world_size,
          "resume_from_ckpt": resume_from_ckpt,
          "ckpt_path": ckpt_path,
          "train_url": train_url,
          "val_url": val_url,
          "voxel2clip_path": voxel2clip_path,
          "combine_models": combine_models,
          "combine_losses": combine_losses,
          "prior_pretrained": prior_pretrained,
          "use_mixco": use_mixco, 
          "n_samples_save": n_samples_save,
          "sd_scheduler": sd_scheduler,
        }
        print("wandb_config:\n",wandb_config)
        wandb.init(
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
        )

# Huggingface Accelerate

In [10]:
diffusion_prior, optimizer, train_dl, val_dl, lr_scheduler = accelerator.prepare(
    diffusion_prior, optimizer, train_dl, val_dl, lr_scheduler
)

# Main

In [None]:
# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

epoch = 0
losses, val_losses, lrs = [], [], []
best_val_loss = 1e9
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

if use_mixco:
    contrast_loss = utils.mixco_nce
else:
    contrast_loss = InfoNCE()
print('contrast_loss', contrast_loss)

# weight for prior's MSE loss term
if alpha_schedule == 'constant':
    alphas = np.ones(num_epochs) * 0.01
elif alpha_schedule == 'linear':
    alphas = np.linspace(0.01, 0.05, num_epochs, endpoint=True)
else:
    raise ValueError(f'unknown alpha_schedule: {alpha_schedule}')

val_voxel0 = val_image0 = None

# Optionally resume from checkpoint #
# PS: still need to check that this actually works
if resume_from_ckpt:
    print("\n---resuming from ckpt_path---\n",ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    try:
        diffusionprior.load_state_dict(checkpoint['model_state_dict'])
    except:
        state_dict = checkpoint['model_state_dict']
        for key in list(state_dict.keys()):
            if 'module.' in key:
                state_dict[key.replace('module.', '')] = state_dict[key]
                del state_dict[key]
        diffusionprior.load_state_dict(state_dict)

progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
for epoch in progress_bar:
    diffusion_prior.train()

    sims = 0.
    sims_base = 0.
    val_sims = 0.
    val_sims_base = 0.
    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    val_fwd_percent_correct = 0.
    val_bwd_percent_correct = 0.
    loss_nce_sum = 0.
    loss_prior_sum = 0.
    val_loss_nce_sum = 0.
    val_loss_prior_sum = 0.
    # loss_on_aug = []
    # loss_off_aug = []
    # image_aug = None
    
    alpha = alphas[epoch]

    for train_i, (voxel, image, trial) in enumerate(train_dl):
        optimizer.zero_grad()
        
        repeat_index = train_i % 3

        image = image.float()
        voxel = voxel.float()[:,repeat_index].float()
        
        if voxel_dims == 3:
            voxel = voxel[:,np.unique(x_inc),:,:]
            voxel = voxel[:,:,np.unique(y_inc),:]
            voxel = voxel[:,:,:,np.unique(z_inc)]

        if is_text:
            annots = utils.select_annotations(subj01_annots[trial], random=True)
            clip_target = clip_extractor.embed_text(annots).float()
        else:
            clip_target = clip_extractor.embed_image(image).float()
        clip_target.to(voxel.dtype)
        
        if use_mixco and epoch < int(mixup_pct * num_epochs):
            voxel, perm, betas, select = utils.mixco(voxel)
            clip_target = utils.mixco_clip_target(clip_target, perm, select, betas)
        
        if combine_models:
            # loss here is MSE for the prior, clip_voxels are voxel2clip outputs
            loss, pred, clip_voxels = diffusion_prior(image_embed=clip_target, voxel=voxel)
            utils.check_loss(loss)
            
            if combine_losses:
                # combine losses for contrastive learned voxel2clip mapper and the prior
                if use_mixco:
                    if epoch < int(mixup_pct * num_epochs):
                        loss_nce = contrast_loss(
                            nn.functional.normalize(clip_voxels, dim=-1), 
                            nn.functional.normalize(clip_target, dim=-1),
                            temp=0.006, perm=perm, betas=betas, select=select,
                            distributed=distributed, accelerator=accelerator, local_rank=local_rank)

                    else:
                        epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                        loss_nce = utils.soft_clip_loss(
                            nn.functional.normalize(clip_voxels, dim=-1), 
                            nn.functional.normalize(clip_target, dim=-1),
                            temp=epoch_temp,
                            distributed=distributed, accelerator=accelerator)
                else:
                    loss_nce = contrast_loss(
                        nn.functional.normalize(clip_voxels, dim=-1), 
                        nn.functional.normalize(clip_target, dim=-1),
                    )
                utils.check_loss(loss_nce)

                loss_nce_sum += loss_nce.item()
                loss_prior_sum += loss.item()

                # MSE and NCE are weighted equally at the beginning,
                # with alpha=0.01 we'll have something like .01*300 + .99*3 = 3 + 3
                loss = alpha * loss + (1-alpha) * loss_nce
            else:
                loss_prior_sum += loss.item()
        else:
            # don't train end to end, just use the frozen voxel2clip to get clip_voxels
            if distributed:
                clip_voxels = voxel2clip.module(voxel)
            else:
                clip_voxels = voxel2clip(voxel)
                
            loss, pred, _ = diffusion_prior(text_embed=clip_voxels, image_embed=clip_target)
            utils.check_loss(loss)
            # loss here is MSE for the prior when not combining losses
            loss_prior_sum += loss.item()
            
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        # similarity after prior diffusion
        sims += F.cosine_similarity(accelerator.gather(clip_target), 
                                    accelerator.gather(pred)).mean().item()
        # baseline similarity before prior diffusion
        sims_base += F.cosine_similarity(accelerator.gather(clip_target),
                                    accelerator.gather(clip_voxels)).mean().item()

        # forward and backward top 1 accuracy
        labels = torch.arange(len(clip_target)).to(device)
        fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target, clip_voxels), labels, k=1)
        bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels, clip_target), labels, k=1)

        accelerator.backward(loss)
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

    diffusion_prior.eval()
    for val_i, (voxel, image, key) in enumerate(val_dl): 
        with torch.no_grad():
            repeat_index = val_i % 3

            image = image.float()
            voxel = voxel[:,repeat_index].float()

            if voxel_dims == 3:
                voxel = voxel[:,np.unique(x_inc),:,:]
                voxel = voxel[:,:,np.unique(y_inc),:]
                voxel = voxel[:,:,:,np.unique(z_inc)]

            if val_image0 is None:
                val_image0 = image.detach().clone()
                val_voxel0 = voxel.detach().clone()

            if is_text:
                annots = utils.select_annotations(subj01_annots[trial], random=False)
                clip_target = clip_extractor.embed_text(annots).float()
            else:
                clip_target = clip_extractor.embed_image(image).float()
            clip_target.to(voxel.dtype)
            
            if combine_models:
                loss, pred, clip_voxels = diffusion_prior(image_embed=clip_target, voxel=voxel) \
                            if not distributed else diffusion_prior.module(image_embed=clip_target, voxel=voxel)
                utils.check_loss(loss)
                
                if combine_losses:
                    if use_mixco:
                        if epoch < int(mixup_pct * num_epochs):
                            loss_nce = contrast_loss(
                                nn.functional.normalize(clip_voxels, dim=-1), 
                                nn.functional.normalize(clip_target, dim=-1),
                                temp=0.006, perm=perm, betas=betas, select=select,
                                distributed=distributed, accelerator=accelerator, local_rank=local_rank)

                        else:
                            epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                            loss_nce = utils.soft_clip_loss(
                                nn.functional.normalize(clip_voxels, dim=-1), 
                                nn.functional.normalize(clip_target, dim=-1),
                                temp=epoch_temp,
                                distributed=distributed, accelerator=accelerator)
                    else:
                        loss_nce = contrast_loss(
                            nn.functional.normalize(clip_voxels, dim=-1), 
                            nn.functional.normalize(clip_target, dim=-1),
                        )
                    utils.check_loss(loss_nce)
                    
                    val_loss_nce_sum += loss_nce.item()
                    val_loss_prior_sum += loss.item()

                    val_loss = alpha * loss + (1-alpha) * loss_nce
                else:
                    val_loss = loss
                    val_loss_prior_sum += loss.item()
            else:
                clip_voxels = voxel2clip(voxel)
                val_loss, pred, _ = diffusion_prior(text_embed=clip_voxels, image_embed=clip_target) \
                    if not distributed else diffusion_prior.module(text_embed=clip_voxels, image_embed=clip_target)

            val_losses.append(val_loss.item())
            val_sims += F.cosine_similarity(clip_target, pred).mean().item()
            val_sims_base += F.cosine_similarity(clip_target, clip_voxels).mean().item()
            labels = torch.arange(len(clip_voxels)).to(device)
            val_fwd_percent_correct += utils.topk(
                utils.batchwise_cosine_similarity(clip_target, clip_voxels), labels, k=1
            ).item()
            val_bwd_percent_correct += utils.topk(
                utils.batchwise_cosine_similarity(clip_voxels, clip_target), labels, k=1
            ).item()

    if local_rank==0:
        if (not save_at_end and ckpt_saving) or (save_at_end and epoch == num_epochs - 1):
            # save best model
            val_loss = np.mean(val_losses[-(val_i+1):])
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_ckpt('best')
            else:
                print(f'not best - val_loss: {val_loss:.3f}, best_val_loss: {best_val_loss:.3f}')

        # Save model checkpoint every `ckpt_interval`` epochs or on the last epoch
        if (ckpt_interval is not None and (epoch + 1) % ckpt_interval == 0) or epoch == num_epochs - 1:
            save_ckpt(f'epoch{epoch:03d}')

        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
                "val/loss": np.mean(val_losses[-(val_i+1):]),
                "val/loss": np.mean(val_losses[-(val_i+1):]),
                "train/loss_nce": loss_nce_sum / (train_i + 1),
                "train/loss_mse": loss_prior_sum / (train_i + 1),
                "val/loss_nce": val_loss_nce_sum / (val_i + 1),
                "val/loss_mse": val_loss_prior_sum / (val_i + 1),
                "train/lr": lrs[-1],
                "train/num_steps": len(losses),
                "val/num_steps": len(val_losses),
                "train/sim": sims / (train_i + 1),
                "val/sim": val_sims / (val_i + 1),
                "train/cosine_sim_base": sims_base / (train_i + 1),
                "val/cosine_sim_base": val_sims_base / (val_i + 1),
                "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
                "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
                "val/val_fwd_pct_correct": val_fwd_percent_correct / (val_i + 1),
                "val/val_bwd_pct_correct": val_bwd_percent_correct / (val_i + 1),
                "train/alpha": alpha}
        progress_bar.set_postfix(**logs)
        
        # sample some images
        if (ckpt_interval is not None and (epoch + 1) % ckpt_interval == 0) or epoch == num_epochs - 1:
            if (not save_at_end and n_samples_save > 0) or (save_at_end and epoch == num_epochs - 1):
                # training   
                grid = utils.reconstruct_from_clip(
                    image0, voxel0,
                    diffusion_prior,
                    clip_extractor, unet, vae, noise_scheduler,
                    img_lowlevel = None,
                    num_inference_steps = num_inference_steps,
                    n_samples_save = n_samples_save,
                    recons_per_clip = recons_per_clip,
                    recons_per_brain = recons_per_brain,
                    guidance_scale = 7.5,
                    img2img_strength = img2img_strength,
                    timesteps = recon_timesteps,
                    seed = seed,
                    distributed = distributed,
                )
                grid.savefig(os.path.join(outdir, f'samples-train-epoch{epoch:03d}.png'))
                if wandb_log and local_rank==0:
                    logs[f"train/recons"] = wandb.Image(grid, caption=f"epoch{epoch:03d}")

                # validation
                grid = utils.reconstruct_from_clip(
                    val_image0, val_voxel0,
                    diffusion_prior, 
                    clip_extractor, unet, vae, noise_scheduler,
                    img_lowlevel = None,
                    num_inference_steps = num_inference_steps,
                    n_samples_save = n_samples_save,
                    recons_per_clip = recons_per_clip,
                    recons_per_brain = recons_per_brain,
                    guidance_scale = 7.5,
                    img2img_strength = img2img_strength,
                    timesteps = recon_timesteps,
                    seed = seed,
                    distributed = distributed,
                )
                grid.savefig(os.path.join(outdir, f'samples-val-epoch{epoch:03d}.png'))
                if wandb_log and local_rank==0:
                    logs[f"val/recons"] = wandb.Image(grid, caption=f"epoch{epoch:03d}")
            

        if wandb_log:
            wandb.log(logs)
            
    if distributed:
        dist.barrier()

if wandb_log and local_rank==0:
    wandb.finish()

print("\n===Finished!===\n")

Note: not using cudnn.deterministic
contrast_loss InfoNCE()


  0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   

saving ../train_logs/testing/epoch000.pth


  2%|███████████▉                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  | 1/60 [01:56<1:54:44, 116.69s/it, train/alpha=0.01, train/bwd_pct_correct=tensor(0.9617, device='cuda:0'), train/cosine_sim_base=0.223, train/fwd_pct_correct=tensor(0.9786, device='cuda:0'), train/loss=5.6, train/loss_mse=338, train/loss_nce=2.24, train/lr=0.000155, train/num