# Import packages & functions

In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator
from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from models import Clipper
from versatile_diffusion import Reconstructor
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
from sklearn.linear_model import Ridge
import pickle
# custom functions #
import utils

In [2]:
### Multi-GPU config ###
# RANK is read from the environment; when it is unset we fall back to rank 0.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)  

data_type = torch.float16 # change depending on your mixed_precision
num_devices = torch.cuda.device_count()
if num_devices==0: num_devices = 1

# First use "accelerate config" in terminal and setup using deepspeed stage 2 with CPU offloading!
accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
if utils.is_interactive(): # set batch size here if using interactive notebook instead of submitting job
    global_batch_size = batch_size = 8
else:
    # Environment values are strings: convert once so global_batch_size has the
    # same type (int) in both branches and logs/arithmetics behave consistently.
    global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
    batch_size = global_batch_size // num_devices

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


LOCAL RANK  0


In [3]:
# Report process / device identity while the builtin print is still active on every rank.
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)

world_size = accelerator.state.num_processes
distributed = accelerator.state.distributed_type != 'NO'

# Count GPUs only for distributed runs; a non-distributed or GPU-less run counts as one device.
num_devices = torch.cuda.device_count()
num_devices = num_devices if (distributed and num_devices != 0) else 1
num_workers = num_devices

print(accelerator.state)
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

# From here on `print` is accelerator.print
print = accelerator.print # only print if local_rank=0

PID of this process = 84693
device: cuda
Distributed environment: DistributedType.NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16


# Configurations

In [4]:
# Interactive-only configuration: when the notebook runs in a kernel (rather than as a
# submitted script) the command line is emulated by building `jupyter_args`, which the
# argparse cell passes to parser.parse_args().
if utils.is_interactive():
    model_name = "subj01_40sess_hypatia_ridge_flat_dp_test"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    # (the trailing backslashes continue one string literal; it is split on whitespace below)
    jupyter_args = f"--data_path=../dataset \
                    --cache_dir=../cache \
                    --model_name={model_name} --subj=1 \
                    --no-multi_subject \
                    --mode imagery \
                    --use_prior \
                    --dual_guidance"

    print(jupyter_args)
    # turn the single string into an argv-style list of tokens
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: subj01_40sess_hypatia_ridge_flat_dp_light
--data_path=../dataset                     --cache_dir=../cache                     --model_name=subj01_40sess_hypatia_ridge_flat_dp_light --subj=1                     --no-multi_subject                     --mode imagery                     --use_prior                     --dual_guidance


In [5]:
# Command-line interface for this notebook/script. Every option declared here is
# promoted to a module-level variable of the same name right after parsing.
parser = argparse.ArgumentParser(description="Model Training Configuration")

# --- run identity and data locations ---
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
# subjects below 9 are handled as NSD subjects and 9-11 as ImageryRF subjects by the data-loading cell
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8,9,10,11],
    help="Validate on which subject?",
)
parser.add_argument(
    "--multisubject_ckpt", type=str, default=None,
    help="Path to pre-trained multisubject model to finetune a single subject from. multisubject must be False.",
)
parser.add_argument(
    "--num_sessions", type=int, default=1,
    help="Number of training sessions to include",
)

# --- pipeline switches ---
parser.add_argument(
    "--use_prior",action=argparse.BooleanOptionalAction,default=True,
    help="whether to train diffusion prior (True) or just rely on retrieval part of the pipeline (False)",
)
parser.add_argument(
    "--visualize_prior",action=argparse.BooleanOptionalAction,default=False,
    help="output visualizations from unCLIP every ckpt_interval (requires much more memory!)",
)
# NOTE(review): once the parsed values are promoted to globals, this default replaces
# the `batch_size` computed in the Multi-GPU config cell — confirm that is intended.
parser.add_argument(
    "--batch_size", type=int, default=16,
    help="Batch size can be increased by 10x if only training retreival submodule and not diffusion prior",
)

# --- logging / resuming ---
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=True,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="mindeye_imagery",
    help="wandb project name",
)

# --- loss weighting and augmentation ---
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--blur_scale",type=float,default=.5,
    help="multiply loss from blurry recons by this number",
)
parser.add_argument(
    "--clip_scale",type=float,default=1.,
    help="multiply contrastive loss by this number",
)
parser.add_argument(
    "--prior_scale",type=float,default=30,
    help="multiply diffusion prior loss by this",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=False,
    help="whether to use image augmentation",
)

# --- training schedule and model shape ---
parser.add_argument(
    "--num_epochs",type=int,default=150,
    help="number of epochs of training",
)
parser.add_argument(
    "--multi_subject",action=argparse.BooleanOptionalAction,default=False,
)
# new_test selects the "new_test" webdataset split instead of "test" when the test loader is built
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=1024,
)
# seq_len is later computed as seq_past + 1 + seq_future
parser.add_argument(
    "--seq_past",type=int,default=0,
)
parser.add_argument(
    "--seq_future",type=int,default=0,
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=5e-4,
)
# used as the `alpha` of the sklearn Ridge models fitted later in this notebook
parser.add_argument(
    "--ridge_weight_decay",type=float,default=60000,
)
parser.add_argument(
    "--prior_weight_decay",type=float,default=1e-2,
)

# --- dataset selection ---
parser.add_argument(
    "--train_imageryrf",action=argparse.BooleanOptionalAction,default=False,
    help="Use the ImageryRF dataset for pretraining",
)
parser.add_argument(
    "--no_nsd",action=argparse.BooleanOptionalAction,default=False,
    help="Don't use the Natural Scenes Dataset for pretraining",
)
parser.add_argument(
    "--snr_threshold",type=float,default=-1.0,
    help="Used for calculating SNR on a whole brain to narrow down voxels.",
)
# forwarded as `mode=` to utils.load_imageryrf; accepted values are not visible here
parser.add_argument(
    "--mode",type=str,default="all",
)
parser.add_argument(
    "--dual_guidance",action=argparse.BooleanOptionalAction,default=False,
    help="Use the decoded captions for dual guidance",
)

# In a notebook there is no real argv, so parse the emulated argument list instead.
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# Expose every parsed option as a module-level variable (args.seed -> seed, ...)
globals().update(vars(args))

# seed all random functions
utils.seed_everything(seed)

# Output directory for this run; only created up front when checkpoints will be saved
outdir = os.path.abspath(f'../train_logs/{model_name}')
if ckpt_saving and not os.path.exists(outdir):
    os.makedirs(outdir,exist_ok=True)

# kornia is imported lazily: it is only needed for augmentation or blurry reconstructions
if use_image_aug or blurry_recon:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
if use_image_aug:
    img_augment = AugmentationSequential(
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )

# Decide which subjects are trained on.
if not multi_subject:
    subj_list = [subj]
else:
    if not train_imageryrf:
        subj_list = np.arange(1,9)
    elif no_nsd:
        # 9,10,11 is ImageryRF subjects
        subj_list = np.arange(9,12)
    else:
        subj_list = np.arange(1,12)
    # the validation subject is held out of the multi-subject training pool
    subj_list = subj_list[subj_list != subj]

print("subj_list", subj_list, "num_sessions", num_sessions)

subj_list [1] num_sessions 1


# Prep data, models, and dataloaders

### Create wds dataloaders; preload betas and all 73k possible images

In [6]:
def my_split_by_node(urls):
    """Identity node splitter: return the shard URLs unchanged."""
    return urls

num_voxels_list = []
# NOTE(review): this hard-codes a single device and discards the num_devices computed
# in the earlier cells — confirm that is intended for multi-GPU runs.
num_devices = 1
if multi_subject:
    # one entry per NSD subject 1..8 (indexed with s-1 when the train URLs are built)
    nsessions_allsubj=np.array([40, 40, 32, 30, 40, 32, 40, 30])
    num_samples_per_epoch = (750*40) // num_devices 
else:
    num_samples_per_epoch = (750*num_sessions) // num_devices 

print("dividing batch size by subj_list, which will then be concatenated across subj during training...") 
# Derive the per-subject batch size from the parsed argument rather than from the
# current global: re-running this cell then yields the same value instead of
# dividing an already-divided batch size again.
batch_size = args.batch_size // len(subj_list)
if batch_size < 1:
    # Fail with an actionable message instead of a bare ZeroDivisionError below.
    raise ValueError(
        f"batch_size ({args.batch_size}) must be at least the number of training "
        f"subjects ({len(subj_list)}) so that every subject contributes a sample per step"
    )

num_iterations_per_epoch = num_samples_per_epoch // (batch_size*len(subj_list))

print("batch_size =", batch_size, "num_iterations_per_epoch =",num_iterations_per_epoch, "num_samples_per_epoch =",num_samples_per_epoch)



dividing batch size by subj_list, which will then be concatenated across subj during training...
batch_size = 16 num_iterations_per_epoch = 46 num_samples_per_epoch = 750


In [7]:
# Build one training dataset + dataloader per subject and record each subject's voxel count.
# All four dicts are keyed by the zero-padded subject tag, e.g. 'subj01'.
train_data = {}
train_dl = {}
num_voxels = {}
voxels = {}
# NOTE(review): x_train / train_nsd_ids / x_test / test_nsd_ids are rebound on every NSD
# iteration, so after the loop they describe only the last NSD subject processed; the
# loop variable `s` also stays defined and is read by later cells.
for s in subj_list:
    print(f"Training with {num_sessions} sessions")
    # If an NSD subject
    if s < 9:
        # Brace-expanded shard range "{0..K}.tar"; K comes from the per-subject session
        # table in multi-subject mode, otherwise from --num_sessions.
        if multi_subject:
            train_url = f"{data_path}/wds/subj{s:02d}/train/" + "{0.." + f"{nsessions_allsubj[s-1]-1}" + "}.tar"
        else:
            train_url = f"{data_path}/wds/subj{s:02d}/train/" + "{0.." + f"{num_sessions-1}" + "}.tar"
        print(train_url)
        
        # Resampled webdataset stream yielding (behav, past_behav, future_behav, olds_behav) tuples
        train_data[f'subj{s:02d}'] = wds.WebDataset(train_url,resampled=True,nodesplitter=my_split_by_node)\
                            .shuffle(750, initial=1500, rng=random.Random(42))\
                            .decode("torch")\
                            .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                            .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
        train_dl[f'subj{s:02d}'] = torch.utils.data.DataLoader(train_data[f'subj{s:02d}'], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
        # Betas are produced by utils.create_snr_betas with the SNR threshold, then split by utils.load_nsd
        betas = utils.create_snr_betas(subject=s, data_type=data_type, data_path=data_path, threshold = snr_threshold)
        x_train, train_nsd_ids, x_test, test_nsd_ids = utils.load_nsd(subject=s, betas=betas, data_path=data_path)
        print(x_test.shape, train_nsd_ids.shape)
        # voxel count = size of the last axis of a test sample
        num_voxels_list.append(x_test[0].shape[-1])
        num_voxels[f'subj{s:02d}'] = x_test[0].shape[-1]
        voxels[f'subj{s:02d}'] = x_train
    elif s < 12:
        # ImageryRF subject: no webdataset shards; subjects 9..11 map to ImageryRF ids 1..3 (s-8)
        # NOTE(review): this branch does not set x_train / train_nsd_ids / x_test /
        # test_nsd_ids, which later cells read — confirm ImageryRF-only runs are supported.
        train_url = ""
        test_url = ""
        betas, images, _, _ = utils.load_imageryrf(subject=int(s-8), mode=mode, mask=True, stimtype="object", average=False, nest=False, split=True)
        # replace NaNs with zeros before casting to the working dtype
        betas = torch.where(torch.isnan(betas), torch.zeros_like(betas), betas)
        betas = betas.to("cpu").to(data_type)
        num_voxels_list.append(betas[0].shape[-1])
        num_voxels[f'subj{s:02d}'] = betas[0].shape[-1]
        # counted after the NaN replacement above, so this acts as a post-clean sanity check
        num_nan_values = torch.sum(torch.isnan(betas))
        print("Number of NaN values in betas:", num_nan_values.item())
        # shuffle once up front with a shared permutation; the DataLoader itself does not shuffle
        indices = torch.randperm(len(betas))
        shuffled_betas = betas[indices]
        shuffled_images = images[indices]
        train_data[f'subj{s:02d}'] = torch.utils.data.TensorDataset(shuffled_betas, shuffled_images)
        train_dl[f'subj{s:02d}'] = torch.utils.data.DataLoader(train_data[f'subj{s:02d}'], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
        
        
    # elif s < 15:
    #     betas, images = utils.load_imageryrf(subject=int(s-11), mode="imagery", mask=True, stimtype="object", average=False, nest=False)
    #     betas = torch.where(torch.isnan(betas), torch.zeros_like(betas), betas)
    #     betas = betas.to("cpu").to(data_type)
    #     num_voxels_list.append(betas[0].shape[-1])
    #     num_voxels[f'subj{s:02d}'] = betas[0].shape[-1]
        
    #     indices = torch.randperm(len(betas))
    #     shuffled_betas = betas[indices]
    #     shuffled_images = images[indices]
    #     train_data[f'subj{s:02d}'] = torch.utils.data.TensorDataset(shuffled_betas, shuffled_images)
    #     train_dl[f'subj{s:02d}'] = torch.utils.data.DataLoader(train_data[f'subj{s:02d}'], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
    print(f"num_voxels for subj{s:02d}: {num_voxels[f'subj{s:02d}']}")

print("Loaded all subj train dls and betas!\n")

# Validate only on one subject (doesn't support ImageryRF)
if multi_subject: 
    subj = subj_list[0] # cant validate on the actual held out person so picking first in subj_list

# Choose the held-out split and its trial count: subjects 3/6 and 4/8 have their own
# counts, every other subject uses the split's default.
if new_test: # using larger test set from after full dataset released
    test_split = "new_test"
    num_test = {3: 2371, 4: 2188, 6: 2371, 8: 2188}.get(subj, 3000)
else: # using old test set from before full dataset released (used in original MindEye paper)
    test_split = "test"
    num_test = {3: 2113, 4: 1985, 6: 2113, 8: 1985}.get(subj, 2770)
test_url = f"{data_path}/wds/subj0{subj}/{test_split}/0.tar"
print(test_url)

if subj < 9:
    # NSD subject: stream the behavioural tuples from the webdataset shard
    test_data = (
        wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)
        .shuffle(750, initial=1500, rng=random.Random(42))
        .decode("torch")
        .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")
        .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
    )
else:
    # ImageryRF subject: keep only the last two values returned by the loader (betas, images)
    _, _, betas, images = utils.load_imageryrf(subject=int(subj-8), mode=mode, mask=True, stimtype="object", average=False, nest=True, split=True)
    num_test = len(betas)
    betas = torch.where(torch.isnan(betas), torch.zeros_like(betas), betas)
    betas = betas.to("cpu").to(data_type)
    num_nan_values = torch.sum(torch.isnan(betas))
    print("Number of NaN values in test betas:", num_nan_values.item())
    test_data = torch.utils.data.TensorDataset(betas, images)

# The whole test set is served as one batch of num_test items
test_dl = torch.utils.data.DataLoader(test_data, batch_size=num_test, shuffle=False, drop_last=True, pin_memory=True)
print(f"Loaded test dl for subj{subj}!\n")

seq_len = seq_past + 1 + seq_future
print(f"currently using {seq_len} seq_len (chose {seq_past} past behav and {seq_future} future behav)")

Training with 1 sessions
../dataset/wds/subj01/train/{0..0}.tar
torch.Size([1000, 15724]) (27000,)
num_voxels for subj01: 15724
Loaded all subj train dls and betas!

../dataset/wds/subj01/new_test/0.tar
Loaded test dl for subj1!

currently using 1 seq_len (chose 0 past behav and 0 future behav)


In [8]:
# Load 73k NSD images
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'] # if you go OOM you can remove the [:] so it isnt preloaded to cpu! (will require a few edits elsewhere tho)
# images = torch.Tensor(images).to("cpu").to(data_type)
print("Loaded all 73k possible NSD images to cpu!", images.shape)

# Load 73k NSD captions
captions = np.load(f'{data_path}/preprocessed_data/annots_73k.npy')
print("Loaded all 73k NSD captions to cpu!", captions.shape)

Loaded all 73k possible NSD images to cpu! (73000, 3, 224, 224)
Loaded all 73k NSD captions to cpu! (73000,)


## Load models

In [9]:

if blurry_recon:
    from diffusers import AutoencoderKL    
    # Four-stage encoder/decoder; the architecture must match the checkpoint loaded below.
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D'] * 4,
        up_block_types=['UpDecoderBlock2D'] * 4,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    # The checkpoint is loaded as a plain state dict, with no key remapping.
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth')
    autoenc.load_state_dict(ckpt)
    
    # Frozen, inference-only autoencoder on the accelerator device
    autoenc.eval().requires_grad_(False).to(device)
    utils.count_params(autoenc)

param counts:
83,653,863 total
0 trainable


### VD/CLIP image embeddings model

In [10]:
# Embedding geometry used throughout the notebook: flat image embeddings are reshaped
# to (clip_seq_dim, clip_emb_dim) and flat text embeddings to (clip_text_seq_dim, clip_emb_dim).
clip_emb_dim = 768
clip_seq_dim = 257
clip_text_seq_dim=77
# Reconstructor is used below through embed_image / embed_text to produce those embeddings.
clip_extractor = Reconstructor(device=device, cache_dir=cache_dir)
# Tag embedded in the cached embedding filenames ({clip_variant}_image_embeddings.pt, ...)
clip_variant = "ViT-L-14_2"


Reconstructor: Loading model... fp16: True

#######################
# Running in eps mode #
#######################

making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Load pth from ../cache/kl-f8.pth
Load autoencoderkl with total 83653863 parameters,72921.759 parameter sum.
Load optimus_bert_connector with total 109489920 parameters,19325.272 parameter sum.
Load optimus_gpt2_connector with total 132109824 parameters,18600.700 parameter sum.
Load pth from ../cache/optimus-vae.pth
Load optimus_vae_next with total 241599744 parameters,-344611.688 parameter sum.




Load clip_image_context_encoder with total 427616513 parameters,64007.510 parameter sum.
Load clip_text_context_encoder with total 427616513 parameters,64007.510 parameter sum.
Load openai_unet_2d_next with total 859520964 parameters,99914.823 parameter sum.
Load openai_unet_0d_next with total 1706797888 parameters,250071.939 parameter sum.
Load vd_v2_0 with total 3746805485 parameters,206311.852 parameter sum.


# Diffusion Prior Models

In [11]:
if use_prior:
    from models import *

    class MindEyeModule(nn.Module):
        """Bare container module: sub-networks are attached as attributes after construction.

        forward() is the identity; the attached priors are called directly.
        """
        def __init__(self):
            super(MindEyeModule, self).__init__()
        def forward(self, x):
            return x
            
    dp_model = MindEyeModule()

    # setup diffusion prior network
    out_dim = clip_emb_dim
    depth = 6
    dim_head = 64
    heads = clip_emb_dim//64 # heads * dim_head = clip_emb_dim
    timesteps = 100

    # Image prior: its own denoising network wrapped in a DiffusionPrior
    prior_network = DiffusionPriorUNet(
        cond_dim=out_dim, 
        dropout=0.1)
    
    dp_model.diffusion_prior = DiffusionPrior(
        prior_network,
        device=device,
    )
    if dual_guidance:
        # Text prior: must wrap its OWN network. Previously `prior_network` was passed
        # here, so both priors shared one set of weights and `prior_network_txt` was
        # created but never used.
        prior_network_txt = DiffusionPriorUNet(
            cond_dim=out_dim, 
            dropout=0.1)

        dp_model.diffusion_prior_txt = DiffusionPrior(
            prior_network_txt,
            device=device,
        )
    #     utils.count_params(dp_model.diffusion_prior_txt)
    # utils.count_params(dp_model.diffusion_prior)
    # num_params = utils.count_params(dp_model)
    dp_model.to(device)
    # for name, param in dp_model.named_parameters():
        # print("DP Model Dtype:", param.dtype)
        # break

# Creating block of CLIP embeddings

In [12]:
def _batch_bounds(num_items, size):
    """Return [(lo, hi), ...] slice bounds covering range(num_items) in chunks of `size`.

    The final chunk may be shorter than `size`, so no trailing items are dropped.
    """
    return [(lo, min(lo + size, num_items)) for lo in range(0, num_items, size)]


file_path = f"{data_path}/preprocessed_data/{clip_variant}_image_embeddings.pt"
emb_batch_size = 50
if not os.path.exists(file_path):
    # Generate CLIP Image embeddings (one flattened row per image) and cache them on disk
    print("Generating CLIP Image embeddings!")
    clip_image = torch.zeros((len(images), clip_seq_dim * clip_emb_dim)).to("cpu")
    for lo, hi in tqdm(_batch_bounds(len(images), emb_batch_size), desc="Encoding images..."):
        batch_images = images[lo:hi]
        batch_embeddings = clip_extractor.embed_image(torch.from_numpy(batch_images)).reshape(hi - lo, -1).detach().to("cpu")
        clip_image[lo:hi] = batch_embeddings

    torch.save(clip_image, file_path)
else:
    clip_image = torch.load(file_path)

if dual_guidance:
    file_path_txt = f"{data_path}/preprocessed_data/{clip_variant}_text_embeddings.pt"
    if not os.path.exists(file_path_txt):
        # Generate CLIP Text embeddings (one flattened row per caption) and cache them on disk
        print("Generating CLIP Text embeddings!")
        clip_text = torch.zeros((len(captions), clip_text_seq_dim * clip_emb_dim)).to("cpu")
        # progress label fixed: this loop encodes captions, not images
        for lo, hi in tqdm(_batch_bounds(len(captions), emb_batch_size), desc="Encoding captions..."):
            batch_captions = captions[lo:hi]
            clip_text[lo:hi] = clip_extractor.embed_text(batch_captions).reshape(hi - lo, -1).detach().to("cpu")
        torch.save(clip_text, file_path_txt)
    else:
        clip_text = torch.load(file_path_txt)

if blurry_recon:
    file_path = f"{data_path}/preprocessed_data/autoenc_image_embeddings.pt"
    if not os.path.exists(file_path):
        # Generate VAE latent embeddings; each image flattens to 3136 values
        print("Generating VAE Image embeddings!")
        vae_image = torch.zeros((len(images), 3136)).to("cpu")
        with torch.cuda.amp.autocast(dtype=torch.float16):
            for lo, hi in tqdm(_batch_bounds(len(images), emb_batch_size), desc="Encoding images..."):
                batch_images = images[lo:hi]
                # Rescale to 2*x - 1 before encoding. The batch stays 4-D: the previous
                # unsqueeze(0) added a fifth leading axis that the 2-D conv encoder cannot take.
                batch_images = 2 * torch.from_numpy(batch_images).detach().to(device=device, dtype=torch.float16) - 1
                # 0.18215 scales the latents (presumably the Stable Diffusion latent scaling factor — confirm)
                batch_embeddings = (autoenc.encode(batch_images).latent_dist.mode() * 0.18215).detach().to("cpu").reshape(hi - lo, -1)
                vae_image[lo:hi] = batch_embeddings

        # Persist the cache; without this the embeddings were recomputed on every run
        torch.save(vae_image, file_path)
    else:
        vae_image = torch.load(file_path)

In [13]:
# Filter to only ones needed during training

def _select_rows(bank, nsd_ids, *trailing_shape):
    """Gather the rows of `bank` listed in `nsd_ids`.

    Args:
        bank: 2-D tensor with one flattened embedding per image/caption.
        nsd_ids: integer ids (array-like) indexing the rows of `bank`.
        *trailing_shape: per-sample shape each gathered row is reshaped to.

    Returns:
        float32 CPU tensor of shape (len(nsd_ids), *trailing_shape), the same dtype
        and device the previous zero-filled buffers had.
    """
    index = torch.as_tensor(np.asarray(nsd_ids), dtype=torch.long)
    # one vectorised gather instead of a Python loop over tens of thousands of rows
    rows = bank[index].to("cpu", torch.float32)
    return rows.reshape(len(index), *trailing_shape)


clip_image_train = _select_rows(clip_image, train_nsd_ids, clip_seq_dim, clip_emb_dim)
clip_image_test = _select_rows(clip_image, test_nsd_ids, clip_seq_dim, clip_emb_dim)

# clip_text only exists when dual guidance is enabled, and vae_image only when blurry
# reconstructions are enabled; touching them unconditionally raised NameError with
# the default options.
if dual_guidance:
    clip_text_train = _select_rows(clip_text, train_nsd_ids, clip_text_seq_dim, clip_emb_dim)
    clip_text_test = _select_rows(clip_text, test_nsd_ids, clip_text_seq_dim, clip_emb_dim)
if blurry_recon:
    vae_image_train = _select_rows(vae_image, train_nsd_ids, 3136)
    vae_image_test = _select_rows(vae_image, test_nsd_ids, 3136)

print(f"Loaded train/test images and captions for subj{subj}!", clip_image_train.shape, clip_image_test.shape)

Loaded train/test images and captions for subj1! torch.Size([27000, 257, 768]) torch.Size([1000, 257, 768])


# Train Ridge regression models

In [14]:
start = time.time()

def _fit_or_load_ridge(weights_path, targets, label):
    """Return ridge weights for `targets`, training them only if no cached file exists.

    Args:
        weights_path: pickle file holding {"coef": ..., "intercept": ...}.
        targets: 2-D regression targets, one row per row of the global `x_train`.
        label: human-readable model name used in the training log line.

    Returns:
        dict with keys "coef" and "intercept" (the fitted sklearn Ridge attributes).
    """
    if os.path.exists(weights_path):
        # NOTE: pickle.load can execute arbitrary code; only load weight files
        # that were written by this notebook.
        with open(weights_path, 'rb') as weights_file:
            return pickle.load(weights_file)

    print(f"Training Ridge {label} model with alpha={ridge_weight_decay}")
    model = Ridge(
        alpha=ridge_weight_decay,
        max_iter=50000,
        random_state=42,
    )
    model.fit(x_train, targets)
    datadict = {"coef" : model.coef_, "intercept" : model.intercept_}

    # Save the regression weights. The output directory is only pre-created when
    # ckpt_saving is enabled, so make sure it exists before writing.
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    with open(weights_path, 'wb') as weights_file:
        pickle.dump(datadict, weights_file)
    return datadict


# The zero-filled weight/bias buffers that used to be allocated here were overwritten
# by the fitted coefficients straight away, so they are gone. File handles are no longer
# bound to `f`, which keeps the h5py handle of the image file intact.
image_datadict = _fit_or_load_ridge(
    f'{outdir}/ridge_image_weights.pkl',
    clip_image_train.reshape(len(clip_image_train), -1),
    "CLIP Image",
)

if dual_guidance:
    text_datadict = _fit_or_load_ridge(
        f'{outdir}/ridge_text_weights.pkl',
        clip_text_train.reshape(len(clip_text_train), -1),
        "CLIP Text",
    )

if blurry_recon:
    blurry_datadict = _fit_or_load_ridge(
        f'{outdir}/ridge_blurry_weights.pkl',
        vae_image_train,
        "Blurry recon",
    )

print(f"{model_name} model trained/loaded in {time.strftime('%H:%M:%S', time.gmtime(time.time() - start))}")
# If we arent going to train the diffusion prior, stop here:
if not use_prior:
    sys.exit(0)

subj01_40sess_hypatia_ridge_flat_dp_light model trained/loaded in 00:00:13


# Predict ridge variables for diffusion prior stage 2 training

In [15]:
# Training variables

def _ridge_predict(datadict, x, *embedding_shape):
    """Apply stored ridge weights to `x` and return the predictions as a tensor.

    Args:
        datadict: dict with "coef" and "intercept" as saved by the ridge training cell.
        x: input samples accepted by sklearn's Ridge.predict.
        *embedding_shape: per-sample shape each flat prediction is reshaped to.

    Returns:
        torch tensor of shape (n_samples, *embedding_shape).
    """
    # Only coef_/intercept_ matter for predict(); the estimator is just a carrier for
    # them. alpha follows the configured value rather than a hard-coded 60000.
    model = Ridge(
        alpha=ridge_weight_decay,
        max_iter=50000,
        random_state=42,
    )
    model.coef_ = datadict["coef"]
    model.intercept_ = datadict["intercept"]
    return torch.from_numpy(model.predict(x).reshape(-1, *embedding_shape))


# The zero tensors that used to be pre-allocated for each prediction were overwritten
# immediately, so the results are now assigned directly.
pred_clip_image_train = _ridge_predict(image_datadict, x_train, clip_seq_dim, clip_emb_dim)

if dual_guidance:
    pred_clip_text_train = _ridge_predict(text_datadict, x_train, clip_text_seq_dim, clip_emb_dim)
if blurry_recon:
    pred_blurry_vae_train = _ridge_predict(blurry_datadict, x_train, 3136)
    
# normalizing preds
# for sequence in range(clip_seq_dim):
#     std_pred_clip_image_train = (pred_clip_image_train[:, sequence] - torch.mean(pred_clip_image_train[:, sequence],axis=0)) / torch.std(pred_clip_image_train[:, sequence],axis=0)
#     pred_clip_image_train[:, sequence] = std_pred_clip_image_train * torch.std(clip_image_train[:, sequence],axis=0) + torch.mean(clip_image_train[:, sequence],axis=0)
# if dual_guidance:
#     for sequence in range(clip_text_seq_dim):
#         std_pred_clip_text_train = (pred_clip_text_train[:, sequence] - torch.mean(pred_clip_text_train[:, sequence],axis=0)) / torch.std(pred_clip_text_train[:, sequence],axis=0)
#         pred_clip_text_train[:, sequence] = std_pred_clip_text_train * torch.std(clip_text_train[:, sequence],axis=0) + torch.mean(clip_text_train[:, sequence],axis=0)
# if blurry_recon:
#     std_pred_blurry_vae_train = (pred_blurry_vae_train - torch.mean(pred_blurry_vae_train,axis=0)) / torch.std(pred_blurry_vae_train,axis=0)
#     pred_blurry_vae_train = std_pred_blurry_vae_train * torch.std(vae_image_train,axis=0) + torch.mean(vae_image_train,axis=0)
    
    
# Testing variables:
pred_clip_image_test = _ridge_predict(image_datadict, x_test, clip_seq_dim, clip_emb_dim)

if dual_guidance:
    pred_clip_text_test = _ridge_predict(text_datadict, x_test, clip_text_seq_dim, clip_emb_dim)
if blurry_recon:
    pred_blurry_vae_test = _ridge_predict(blurry_datadict, x_test, 3136)
    
# normalizing preds
# for sequence in range(clip_seq_dim):
#     std_pred_clip_image_test = (pred_clip_image_test[:, sequence] - torch.mean(pred_clip_image_test[:, sequence],axis=0)) / torch.std(pred_clip_image_test[:, sequence],axis=0)
#     pred_clip_image_test[:, sequence] = std_pred_clip_image_test * torch.std(clip_image_train[:, sequence],axis=0) + torch.mean(clip_image_train[:, sequence],axis=0)
# if dual_guidance:
#     for sequence in range(clip_text_seq_dim):
#         std_pred_clip_text_test = (pred_clip_text_test[:, sequence] - torch.mean(pred_clip_text_test[:, sequence],axis=0)) / torch.std(pred_clip_text_test[:, sequence],axis=0)
#         pred_clip_text_test[:, sequence] = std_pred_clip_text_test * torch.std(clip_text_train[:, sequence],axis=0) + torch.mean(clip_text_train[:, sequence],axis=0)
# if blurry_recon:
#     std_pred_blurry_vae_test = (pred_blurry_vae_test - torch.mean(pred_blurry_vae_test,axis=0)) / torch.std(pred_blurry_vae_test,axis=0)
#     pred_blurry_vae_test = std_pred_blurry_vae_test * torch.std(vae_image_train,axis=0) + torch.mean(vae_image_train,axis=0)

# Train Diffusion Priors

In [None]:
# only use main process for wandb logging: any other rank (or an explicit opt-out)
# ends up with wandb_log == False
wandb_log = bool(local_rank==0 and wandb_log)
if wandb_log:
    import wandb
    print(f"wandb {wandb_project} run {model_name}")
    # need to configure wandb beforehand in terminal with "wandb init"!
    # Snapshot of the run configuration recorded with the wandb run
    wandb_config = {
      "model_name": model_name,
      "use-prior": use_prior,
      "blurry_recon": blurry_recon,
      "global_batch_size": global_batch_size,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "num_sessions": num_sessions,
    #   "num_params": num_params,
      "clip_scale": clip_scale,
      "prior_scale": prior_scale,
      "blur_scale": blur_scale,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      "lr_scheduler_type": lr_scheduler_type,
      "mixup_pct": mixup_pct,
      "num_samples_per_epoch": num_samples_per_epoch,
      "num_test": num_test,
      "ckpt_interval": ckpt_interval,
      "ckpt_saving": ckpt_saving,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
      "train_url": train_url,
      "test_url": test_url,
      "train_imageryrf": train_imageryrf,
      "mode": mode,
    }
    print("wandb_config:\n",wandb_config)
    print("wandb_id:",model_name)
    # the model name doubles as both the run id and the display name
    wandb.init(
        project=wandb_project,
        id=model_name,
        name=model_name,
        config=wandb_config,
        resume=None,
    )

In [None]:
from torch.utils.data import Dataset
class EmbeddingDataset(Dataset):
    """Map-style dataset pairing predicted embeddings with their targets.

    Item ``idx`` is a dict with the keys ``"clip_pred"`` and ``"clip_target"``
    holding ``clip_pred[idx]`` and ``clip_target[idx]`` respectively.

    Args:
        clip_pred: indexable, sized container of predicted embeddings.
        clip_target: indexable, sized container of target embeddings, aligned
            index-by-index with ``clip_pred``.

    Raises:
        ValueError: if both containers are given and their lengths differ.
    """

    def __init__(self, clip_pred=None, clip_target=None):
        # The dataset length is taken from clip_pred alone, so a mismatched
        # clip_target would either be silently truncated or fail with an
        # IndexError in the middle of an epoch. Fail early instead.
        if clip_pred is not None and clip_target is not None:
            if len(clip_pred) != len(clip_target):
                raise ValueError(
                    "clip_pred and clip_target must have the same length, "
                    f"got {len(clip_pred)} and {len(clip_target)}"
                )
        self.clip_pred = clip_pred
        self.clip_target = clip_target

    def __len__(self):
        """Return the number of (prediction, target) pairs."""
        return len(self.clip_pred)

    def __getitem__(self, idx):
        """Return the prediction/target pair stored at position ``idx``."""
        return {
            "clip_pred": self.clip_pred[idx],
            "clip_target": self.clip_target[idx]
        }

# The prior-training datasets hold float32 CPU copies of the embeddings.
data_type=torch.float32

# Image embeddings: the training loader is shuffled, the test loader keeps its order.
image_dataset_train = EmbeddingDataset(
    clip_pred=pred_clip_image_train.to(device="cpu", dtype=data_type),
    clip_target=clip_image_train.to(device="cpu", dtype=data_type),
)
image_dataloader_train = torch.utils.data.DataLoader(
    dataset=image_dataset_train,
    batch_size=batch_size,
    shuffle=True,
)

image_dataset_test = EmbeddingDataset(
    clip_pred=pred_clip_image_test.to(device="cpu", dtype=data_type),
    clip_target=clip_image_test.to(device="cpu", dtype=data_type),
)
image_dataloader_test = torch.utils.data.DataLoader(
    dataset=image_dataset_test,
    batch_size=batch_size,
    shuffle=False,
)

# Text embeddings are only wrapped when dual guidance is enabled.
if dual_guidance:
    text_dataset_train = EmbeddingDataset(
        clip_pred=pred_clip_text_train.to(device="cpu", dtype=data_type),
        clip_target=clip_text_train.to(device="cpu", dtype=data_type),
    )
    text_dataloader_train = torch.utils.data.DataLoader(
        dataset=text_dataset_train,
        batch_size=batch_size,
        shuffle=True,
    )

    text_dataset_test = EmbeddingDataset(
        clip_pred=pred_clip_text_test.to(device="cpu", dtype=data_type),
        clip_target=clip_text_test.to(device="cpu", dtype=data_type),
    )
    text_dataloader_test = torch.utils.data.DataLoader(
        dataset=text_dataset_test,
        batch_size=batch_size,
        shuffle=False,
    )
    

In [None]:

# no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

# opt_grouped_parameters = []
# if use_prior:
#     opt_grouped_parameters = [
#         {'params': [p for n, p in dp_model.diffusion_prior.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': prior_weight_decay},
#         {'params': [p for n, p in dp_model.diffusion_prior.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
#     ]
#     if dual_guidance:
#         opt_grouped_parameters.extend([
#         {'params': [p for n, p in dp_model.diffusion_prior_txt.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': prior_weight_decay},
#         {'params': [p for n, p in dp_model.diffusion_prior_txt.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
#         ])
        
# optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

# if lr_scheduler_type == 'linear':
#     lr_scheduler = torch.optim.lr_scheduler.LinearLR(
#         optimizer,
#         total_iters=int(np.floor(num_epochs*len(train_dataloader))),
#         last_epoch=-1
#     )
# elif lr_scheduler_type == 'cycle':
#     total_steps=int(np.floor(num_epochs*len(train_dataloader)))
#     print("total_steps", total_steps)
#     lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#         optimizer, 
#         max_lr=max_lr,
#         total_steps=total_steps,
#         final_div_factor=1000,
#         last_epoch=-1, pct_start=2/num_epochs
#     )
    
def save_ckpt(tag):
    """Save a training checkpoint to ``{outdir}/{tag}.pth``.

    The checkpoint always contains ``dp_model``'s state dict. The epoch
    counter, optimizer state, LR-scheduler state and the loss / learning-rate
    histories are read from notebook globals and stored when they exist;
    entries whose global is not defined are stored as ``None`` (or an empty
    list for the histories) instead of aborting the save with a NameError.

    Args:
        tag: checkpoint name without extension (e.g. ``'last'``).
    """
    ckpt_path = outdir+f'/{tag}.pth'
    scope = globals()

    def _optional_state(name):
        # The optimizer / scheduler setup in this notebook is commented out,
        # so these objects may not exist when the checkpoint is written.
        obj = scope.get(name)
        return None if obj is None else obj.state_dict()

    skipped = [name for name in ('epoch', 'optimizer', 'lr_scheduler') if scope.get(name) is None]
    if skipped:
        print(f"save_ckpt: {', '.join(skipped)} not defined; storing None for those entries")

    # Make sure the target directory exists before writing.
    os.makedirs(outdir, exist_ok=True)
    torch.save({
        'epoch': scope.get('epoch'),
        'model_state_dict': dp_model.state_dict(),
        'optimizer_state_dict': _optional_state('optimizer'),
        'lr_scheduler': _optional_state('lr_scheduler'),
        'train_losses': scope.get('losses', []),
        'test_losses': scope.get('test_losses', []),
        'lrs': scope.get('lrs', []),
        }, ckpt_path)
    print(f"\n---saved {outdir}/{tag} ckpt!---\n")

# def load_ckpt(tag,load_lr=True,load_optimizer=True,load_epoch=True,strict=True,outdir=outdir,multisubj_loading=False): 
#     print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
#     checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
#     state_dict = checkpoint['model_state_dict']
#     if multisubj_loading: # remove incompatible ridge layer that will otherwise error
#         state_dict.pop('ridge.linears.0.weight',None)
#     model.load_state_dict(state_dict, strict=strict)
#     if load_epoch:
#         globals()["epoch"] = checkpoint['epoch']
#         print("Epoch",epoch)
#     if load_optimizer:
#         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     if load_lr:
#         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
#     del checkpoint

print("\nDone with model preparations!")



In [None]:
# Loss modules and bookkeeping containers. The three history lists are the
# ones save_ckpt() writes into the checkpoint.
mse = nn.MSELoss()
l1 = nn.L1Loss()
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
torch.cuda.empty_cache()

# NOTE(review): the epoch count is hard-coded to 150 here, while the wandb
# config logs `num_epochs` - confirm which value is intended.
# with torch.cuda.amp.autocast(dtype=data_type):
dp_model.diffusion_prior.train(image_dataloader_train, num_epochs=150, learning_rate=max_lr, wandb_log=wandb_log)
# text_dataloader_train only exists when dual guidance is enabled, so the text
# prior must only be trained in that case (otherwise this raises a NameError).
if dual_guidance:
    dp_model.diffusion_prior_txt.train(text_dataloader_train, num_epochs=150, learning_rate=max_lr, wandb_log=wandb_log)
print("\n===Finished!===\n")
if ckpt_saving:
    save_ckpt('last')

# Inference: load the trained checkpoint and generate reconstructions

In [None]:
# Setup
# NOTE(review): the reconstructor is aliased to clip_extractor, so the same
# object is used for both reconstruct() and candidate scoring - confirm intended.
reconstructor = clip_extractor
save_raw = True
# Load pretrained model ckpt
tag='last'
outdir = os.path.abspath(f'../train_logs/{model_name}')
print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
try:
    checkpoint = torch.load(f'{outdir}/{tag}.pth', map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    # dp_model.load_state_dict(state_dict, strict=False)
    
    dp_model.load_state_dict(state_dict, strict=True)
    del checkpoint
except Exception as load_err: # probably ckpt is saved using deepspeed format
    # Catch Exception rather than everything so Ctrl-C / SystemExit still
    # propagate, and show the original failure so a genuine problem (missing
    # file, state-dict key mismatch) is not hidden behind the fallback.
    import deepspeed
    print(f"regular ckpt load error: {load_err!r}")
    print("load ckpt failed, loading deepspeed ckpt...")
    state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir=outdir, tag=tag)
    dp_model.load_state_dict(state_dict, strict=False)
    del state_dict
print("ckpt loaded!")

In [None]:
def _print_tensor_stats(label, tensors, count_nans=False):
    """Print shape, dtype and mean (optionally NaN counts) of the non-None tensors.

    Entries that are None (e.g. the text guidance when dual guidance is off)
    are skipped instead of raising an AttributeError.
    """
    present = [t for t in tensors if t is not None]
    msg = (
        f"{label}: shape {', '.join(format(t.shape) for t in present)}, "
        f"type {', '.join(format(t.dtype) for t in present)}, "
        f"mean: {', '.join(format(torch.mean(t)) for t in present)}"
    )
    if count_nans:
        msg += f", num_nans {', '.join(format(torch.isnan(t).sum()) for t in present)}"
    print(msg)


# Accumulators across repetitions; dim 1 indexes the repetition.
final_recons = None
final_predcaptions = None
final_clipvoxels = None
final_blurryrecons = None
raw_root = f"/export/raid1/home/kneel027/Second-Sight/output/mental_imagery_paper_b3/{mode}/{model_name}/subject{subj}/"
print("raw_root:", raw_root)
recons_per_sample = 16
data_type = torch.float16
for rep in tqdm(range(gen_rep)):
    # A fresh random seed is drawn for every repetition and printed.
    seed = random.randint(0,10000000)
    utils.seed_everything(seed = seed)
    print(f"seed = {seed}")
    # get all reconstructions    
    # all_images = None
    all_blurryrecons = None
    all_recons = None
    all_predcaptions = []
    all_clipvoxels = None
    
    minibatch_size = 1
    num_samples_per_image = 1
    plotting = False
    
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        for idx in tqdm(range(0,voxels.shape[0]), desc="sample loop"):
            
            clip_voxels = pred_clip_image[idx].unsqueeze(0)
            if dual_guidance:
                clip_text_voxels = pred_clip_text[idx].unsqueeze(0)
            else:
                clip_text_voxels = None
            _print_tensor_stats("Ridge clip properties", (clip_voxels, clip_text_voxels))
            # Save retrieval submodule outputs
            if all_clipvoxels is None:
                all_clipvoxels = clip_voxels.to('cpu')
            else:
                all_clipvoxels = torch.vstack((all_clipvoxels, clip_voxels.to('cpu')))
            
            # Set defaults if diffusion prior is not enabled
            prior_out = clip_voxels.reshape((-1, clip_seq_dim, clip_emb_dim)).to(device=device, dtype=data_type)
            if dual_guidance:
                prior_out_txt = clip_text_voxels.reshape((-1, clip_text_seq_dim, clip_emb_dim)).to(device=device, dtype=data_type)
            else:
                prior_out_txt = None
            # Overwrite guidance variables if diffusion prior is enabled
            if use_prior:
                _print_tensor_stats("Converted CLIP clip properties", (prior_out, prior_out_txt), count_nans=True)
                # Feed the predicted embeddings through the diffusion prior(s)
                prior_out = dp_model.diffusion_prior.generate(c_embeds=prior_out, num_inference_steps=50, guidance_scale=5.0)
                if dual_guidance:
                    prior_out_txt = dp_model.diffusion_prior_txt.generate(c_embeds=prior_out_txt, num_inference_steps=50, guidance_scale=5.0)
                _print_tensor_stats("Diffusion Prior clip properties", (prior_out, prior_out_txt), count_nans=True)
            
            # Low-level (blurry) reconstruction; stays None when disabled so the
            # code below never touches an undefined or stale variable.
            blurred_image = None
            if blurry_recon:
                blurred_image = (autoenc.decode(pred_blurry_vae[idx].reshape((1,4,28,28)).half().to(device)/0.18215).sample/ 2 + 0.5).clamp(0,1)
                im = torch.Tensor(blurred_image)
                if all_blurryrecons is None:
                    all_blurryrecons = im.cpu()
                else:
                    all_blurryrecons = torch.vstack((all_blurryrecons, im.cpu()))
                if plotting:
                    plt.figure(figsize=(2,2))
                    # im carries a leading batch dimension of 1; ToPILImage needs a single image
                    plt.imshow(transforms.ToPILImage()(im[0]))
                    plt.axis('off')
                    plt.show()
            
            # NOTE(review): assumes reconstruct() accepts image=None when no
            # blurry reconstruction is available - confirm against its signature.
            init_image = None
            if blurred_image is not None:
                init_image = transforms.ToPILImage()(torch.Tensor(blurred_image[0]))
            
            # Feed outputs through versatile diffusion
            # NOTE(review): the same seed is passed to every candidate; if
            # reconstruct() reseeds from it the candidates may be identical - confirm.
            samples_multi = [reconstructor.reconstruct(
                                image=init_image,
                                c_i=prior_out,
                                c_t=prior_out_txt,
                                n_samples=1,
                                textstrength=0.4,
                                strength=0.85,
                                seed=seed) for _ in range(recons_per_sample)]
            samples = utils.pick_best_recon(samples_multi, clip_voxels, clip_extractor)
            # Only `Image` is imported from PIL in this notebook, so the class
            # must be referenced as Image.Image (PIL.Image.Image is a NameError).
            if isinstance(samples, Image.Image):
                samples = transforms.ToTensor()(samples)
            samples = samples.unsqueeze(0)
            
            if all_recons is None:
                all_recons = samples.cpu()
            else:
                all_recons = torch.vstack((all_recons, samples.cpu()))
            if plotting:
                for s in range(num_samples_per_image):
                    plt.figure(figsize=(2,2))
                    plt.imshow(transforms.ToPILImage()(samples[s]))
                    plt.axis('off')
                    plt.show()
                    
            if plotting: 
                print(model_name)
                # dont actually want to run the whole thing with plotting=True
                raise RuntimeError("plotting=True is a debugging mode; stopping after the first sample")

            if save_raw:
                # print(f"Saving raw images to {raw_root}/{idx}/{rep}.png")
                os.makedirs(f"{raw_root}/{idx}/", exist_ok=True)
                transforms.ToPILImage()(samples[0]).save(f"{raw_root}/{idx}/{rep}.png")
                # all_images is not defined in this cell; presumably the ground-truth
                # images loaded earlier in the notebook - TODO confirm.
                transforms.ToPILImage()(all_images[idx]).save(f"{raw_root}/{idx}/ground_truth.png")
                if rep == 0:
                    if blurred_image is not None:
                        transforms.ToPILImage()(torch.Tensor(blurred_image[0]).cpu()).save(f"{raw_root}/{idx}/low_level.png")
                    torch.save(clip_voxels, f"{raw_root}/{idx}/clip_image_voxels.pt")
                    if dual_guidance:
                        torch.save(clip_text_voxels, f"{raw_root}/{idx}/clip_text_voxels.pt")
        # resize outputs before saving
        imsize = 256
        # saving
        # print(all_recons.shape)
        # torch.save(all_images,"evals/all_images.pt")
        if final_recons is None:
            final_recons = all_recons.unsqueeze(1)
            # final_predcaptions = np.expand_dims(all_predcaptions, axis=1)
            final_clipvoxels = all_clipvoxels.unsqueeze(1)
            if blurry_recon:
                final_blurryrecons = all_blurryrecons.unsqueeze(1)
        else:
            final_recons = torch.cat((final_recons, all_recons.unsqueeze(1)), dim=1)
            # final_predcaptions = np.concatenate((final_predcaptions, np.expand_dims(all_predcaptions, axis=1)), axis=1)
            final_clipvoxels = torch.cat((final_clipvoxels, all_clipvoxels.unsqueeze(1)), dim=1)
            if blurry_recon:
                # Append in the same repetition order as final_recons / final_clipvoxels
                # so index r along dim 1 refers to the same repetition in every tensor.
                final_blurryrecons = torch.cat((final_blurryrecons, all_blurryrecons.unsqueeze(1)), dim = 1)
        
# Make sure the output directory exists before saving.
os.makedirs(f"evals/{model_name}", exist_ok=True)
if blurry_recon:
    torch.save(final_blurryrecons,f"evals/{model_name}/{model_name}_all_blurryrecons_{mode}.pt")
torch.save(final_recons,f"evals/{model_name}/{model_name}_all_recons_{mode}.pt")
# torch.save(final_predcaptions,f"evals/{model_name}/{model_name}_all_predcaptions_{mode}.pt")
torch.save(final_clipvoxels,f"evals/{model_name}/{model_name}_all_clipvoxels_{mode}.pt")
print(f"saved {model_name} mi outputs!")

# if not utils.is_interactive():
#     sys.exit(0)
