# Import packages & functions

In [None]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator
from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle

from versatile_diffusion import Reconstructor
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
from sklearn.linear_model import Ridge
import pickle
# custom functions #
import utils

In [None]:
### Multi-GPU config ###
# NOTE(review): despite its name, this reads RANK (the global process rank); the
# per-node rank would be LOCAL_RANK. In this file the value is only printed.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)

data_type = torch.float16  # keep in sync with mixed_precision="fp16" below
# Count CPU-only machines as a single device so the batch-size division below is safe.
num_devices = torch.cuda.device_count() or 1

# First use "accelerate config" in terminal and setup using deepspeed stage 2 with CPU offloading!
accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
if utils.is_interactive():  # set batch size here if using interactive notebook instead of submitting job
    global_batch_size = batch_size = 8
else:
    # Environment variables are strings: parse once so global_batch_size is an int
    # in both branches, and derive the per-device batch size from it.
    global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
    batch_size = global_batch_size // num_devices

In [None]:
print("PID of this process =", os.getpid())

# Device and process-group information as reported by Accelerate.
device = accelerator.device
print("device:", device)
world_size = accelerator.state.num_processes
distributed = accelerator.state.distributed_type != 'NO'

# Use a single device when there is no GPU or the run is not distributed.
num_devices = torch.cuda.device_count()
if not distributed or num_devices == 0:
    num_devices = 1
num_workers = num_devices
print(accelerator.state)

print(
    "distributed =", distributed,
    "num_devices =", num_devices,
    "local rank =", local_rank,
    "world size =", world_size,
    "data_type =", data_type,
)
# From here on `print` is Accelerate's print (only emits on local_rank=0).
print = accelerator.print

# Configurations

In [None]:
# if running this interactively, can specify jupyter_args here for argparser to use
# Only runs inside a notebook (utils.is_interactive()); otherwise the argparse cell
# below parses the real command line instead of jupyter_args.
if utils.is_interactive():
    model_name = "testing"
    print("model_name:", model_name)
    
    # global_batch_size and batch_size should already be defined in the 2nd cell block
    # NOTE: the --batch_size flag below is parsed by the argparse cell and, via the
    # globals() loop there, overwrites the batch_size set in the Multi-GPU config cell.
    jupyter_args = f"--data_path=../dataset/ \
                    --cache_dir=../cache/ \
                    --model_name={model_name} \
                    --batch_size=64 \
                    --no-multi_subject --subj=1 --num_sessions=40 \
                    --hidden_dim=1024 --clip_scale=1. \
                    --no-blurry_recon --blur_scale=.5  \
                    --seq_past=0 --seq_future=0 \
                    --no-use_prior --prior_scale=30 \
                    --n_blocks=4 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=150 --no-use_image_aug \
                    --ckpt_interval=1 --ckpt_saving"

    print(jupyter_args)
    # Whitespace-split into an argv-style list for parser.parse_args().
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

In [None]:
# Command-line options. The next cell turns every parsed option into a module-level
# global (e.g. args.subj -> subj).
# NOTE(review): several options (e.g. --use_prior, --visualize_prior, --mixup_pct,
# --max_lr, --n_blocks, --hidden_dim, --num_epochs, --batch_size) are not read anywhere
# in this ridge-regression script; presumably they are kept so the same command lines
# work with related training scripts.
parser = argparse.ArgumentParser(description="Model Training Configuration")
# Also names the output directory ../train_logs/<model_name> where ridge weights are saved.
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
# Root for the SNR betas, the 73k-image HDF5 file and the preprocessed_data/ caches.
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
# Also where the Reconstructor and the blurry-recon autoencoder checkpoint are loaded from.
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
# Subject whose betas are loaded and regressed (passed to utils.create_snr_betas/load_nsd).
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8,9,10,11],
    help="Validate on which subject?",
)
parser.add_argument(
    "--multisubject_ckpt", type=str, default=None,
    help="Path to pre-trained multisubject model to finetune a single subject from. multisubject must be False.",
)
parser.add_argument(
    "--num_sessions", type=int, default=1,
    help="Number of training sessions to include",
)
parser.add_argument(
    "--use_prior",action=argparse.BooleanOptionalAction,default=True,
    help="whether to train diffusion prior (True) or just rely on retrieval part of the pipeline (False)",
)
parser.add_argument(
    "--visualize_prior",action=argparse.BooleanOptionalAction,default=False,
    help="output visualizations from unCLIP every ckpt_interval (requires much more memory!)",
)
parser.add_argument(
    "--batch_size", type=int, default=16,
    help="Batch size can be increased by 10x if only training retreival submodule and not diffusion prior",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
# In this script, also enables loading the autoencoder, encoding VAE latents and
# fitting a third ("blurry") ridge model on them.
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--blur_scale",type=float,default=.5,
    help="multiply loss from blurry recons by this number",
)
parser.add_argument(
    "--clip_scale",type=float,default=1.,
    help="multiply contrastive loss by this number",
)
parser.add_argument(
    "--prior_scale",type=float,default=30,
    help="multiply diffusion prior loss by this",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=False,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs",type=int,default=150,
    help="number of epochs of training",
)
# Only changes the printed subj_list in this script; data loading uses --subj alone.
parser.add_argument(
    "--multi_subject",action=argparse.BooleanOptionalAction,default=False,
)
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=1024,
)
parser.add_argument(
    "--seq_past",type=int,default=0,
)
parser.add_argument(
    "--seq_future",type=int,default=0,
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
# Controls whether the ../train_logs/<model_name> output directory is created.
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
# Passed to utils.seed_everything; the Ridge models themselves use random_state=42.
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
)
# Used as the Ridge regression `alpha` (L2 penalty) for every ridge model fitted below.
parser.add_argument(
    "--weight_decay",type=float,default=60000,
)
# --train_imageryrf / --no_nsd only change subj_list, and only with --multi_subject.
parser.add_argument(
    "--train_imageryrf",action=argparse.BooleanOptionalAction,default=False,
    help="Use the ImageryRF dataset for pretraining",
)
parser.add_argument(
    "--no_nsd",action=argparse.BooleanOptionalAction,default=False,
    help="Don't use the Natural Scenes Dataset for pretraining",
)
# Forwarded as `threshold` to utils.create_snr_betas.
parser.add_argument(
    "--snr_threshold",type=float,default=-1.0,
    help="Used for calculating SNR on a whole brain to narrow down voxels.",
)
# NOTE(review): not read anywhere in this script.
parser.add_argument(
    "--mode",type=str,default="all",
)
# Also builds CLIP text embeddings for the captions and fits a text ridge model.
parser.add_argument(
    "--dual_guidance",action=argparse.BooleanOptionalAction,default=False,
    help="Use the decoded captions for dual guidance",
)
# In a notebook, parse the hand-written jupyter_args instead of sys.argv.
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# Expose every parsed option as a module-level name (args.subj -> subj, ...).
globals().update(vars(args))

# Seed all random number generators used by the project.
utils.seed_everything(seed)

# Output directory for this run; only created up front when checkpoint saving is on.
outdir = os.path.abspath(f'../train_logs/{model_name}')
if ckpt_saving and not os.path.exists(outdir):
    os.makedirs(outdir, exist_ok=True)

# kornia is only needed for augmentation / blurry-recon pipelines.
if use_image_aug or blurry_recon:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
if use_image_aug:
    img_augment = AugmentationSequential(
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )

# Subjects other than `subj` when training multi-subject; otherwise just `subj`.
if not multi_subject:
    subj_list = [subj]
else:
    if not train_imageryrf:
        subj_list = np.arange(1, 9)
    elif no_nsd:
        subj_list = np.arange(9, 12)  # 9, 10, 11 are the ImageryRF subjects
    else:
        subj_list = np.arange(1, 12)
    subj_list = subj_list[subj_list != subj]

print("subj_list", subj_list, "num_sessions", num_sessions)

# Prep data, models, and dataloaders

In [None]:
# SNR-filtered betas for the selected subject, then the train/test split and the
# NSD image ids that index into the 73k image/caption/embedding tables.
betas = utils.create_snr_betas(
    subject=subj,
    data_type=torch.float16,
    data_path=data_path,
    threshold=snr_threshold,
)
x_train, valid_nsd_ids_train, x_test, test_nsd_ids = utils.load_nsd(
    subject=subj,
    betas=betas,
    data_path=data_path,
)
print(x_train.shape, valid_nsd_ids_train.shape)

print(f"Loaded subj {subj} betas!\n")

In [None]:
# Load 73k NSD images.
# `images` is the h5py Dataset itself, not an in-memory copy: slicing it reads from
# disk on demand, so the HDF5 handle `f` is kept open for later cells. To preload
# everything into RAM instead, materialize it with `images[:]` (needs enough host memory).
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images']
print("Loaded all 73k possible NSD images to cpu!", images.shape)

# Load the matching 73k NSD captions from the preprocessed .npy array.
captions = np.load(f'{data_path}/preprocessed_data/annots_73k.npy')
print("Loaded all 73k NSD captions to cpu!", captions.shape)

## Load models

In [None]:

if blurry_recon:
    from diffusers import AutoencoderKL
    # Build the VAE architecture, then load its pretrained weights from cache_dir.
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    # Restore the checkpoint on CPU: without map_location, tensors go back onto the
    # device they were saved from (e.g. cuda:0 in every process). `.to(device)` below
    # then moves the model to this process's device.
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth', map_location="cpu")
    autoenc.load_state_dict(ckpt)
    del ckpt  # weights now live in `autoenc`; drop the duplicate state dict

    # Frozen feature extractor: inference mode, no gradients.
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device)
    utils.count_params(autoenc)

### VD/CLIP image embedding model

In [None]:
# Per-item embedding sizes filled from the Versatile Diffusion Reconstructor:
# image embeddings are (clip_seq_dim, clip_emb_dim), text embeddings are
# (clip_text_seq_dim, clip_text_emb_dim).
clip_emb_dim = 768
clip_text_emb_dim = 768
clip_seq_dim = 257
clip_text_seq_dim = 77
clip_extractor = Reconstructor(device=device, cache_dir=cache_dir)
clip_variant = "ViT-L-14"
# Prefixes for the cached embedding files in {data_path}/preprocessed_data; the
# embedding cell builds its file paths from these names.
# NOTE(review): if existing caches were written under a different variant name,
# set these to match so they are reused instead of regenerated.
image_embedding_variant = clip_variant
text_embedding_variant = clip_variant
# Autoencoder latents are stored flattened to latent_dim values per image.
latent_dim = 3136
latent_variant = "autoenc"

# Create (or load cached) CLIP image/text and VAE embeddings

In [None]:
emb_batch_size = 50


def _encode_in_batches(num_items, out, encode_batch, desc):
    """Fill ``out`` chunk by chunk with ``encode_batch(start, stop)`` and return it.

    Steps through ``range(0, num_items, emb_batch_size)`` so a trailing partial
    batch is encoded as well rather than left as zeros. Runs under
    ``torch.no_grad()`` because the embeddings are only cached, never trained through.
    """
    with torch.no_grad():
        for start in tqdm(range(0, num_items, emb_batch_size), desc=desc):
            stop = min(start + emb_batch_size, num_items)
            out[start:stop] = encode_batch(start, stop)
    return out


def _gather_rows(table, ids):
    """Return a new float32 tensor holding ``table[i]`` for each id in ``ids``, in order."""
    index = torch.as_tensor(ids, dtype=torch.long, device="cpu")
    return table[index].to(torch.float32)


def _embed_image_batch(start, stop):
    """CLIP image embeddings for ``images[start:stop]``, moved to CPU."""
    return clip_extractor.embed_image(torch.from_numpy(images[start:stop])).detach().to("cpu")


# --- CLIP image embeddings for all 73k images (cached), then this subject's train rows.
file_path = f"{data_path}/preprocessed_data/{image_embedding_variant}_image_embeddings.pt"
if not os.path.exists(file_path):
    # Generate CLIP Image embeddings
    print("Generating CLIP Image embeddings!")
    clip_image = _encode_in_batches(
        len(images),
        torch.zeros((len(images), clip_seq_dim, clip_emb_dim)),
        _embed_image_batch,
        desc="Encoding images...",
    )
    torch.save(clip_image, file_path)
else:
    clip_image = torch.load(file_path)
clip_image_train = _gather_rows(clip_image, valid_nsd_ids_train)
file_path = f"{data_path}/preprocessed_data/subject{subj}/{image_embedding_variant}_image_embeddings_train.pt"
if not os.path.exists(file_path):
    torch.save(clip_image_train, file_path)

# --- CLIP text embeddings for the captions (only needed for dual guidance).
if dual_guidance:
    file_path_txt = f"{data_path}/preprocessed_data/{text_embedding_variant}_text_embeddings.pt"
    if not os.path.exists(file_path_txt):
        # Generate CLIP Text embeddings
        print("Generating CLIP Text embeddings!")

        def _embed_text_batch(start, stop):
            """CLIP text embeddings for ``captions[start:stop]``, moved to CPU."""
            return clip_extractor.embed_text(captions[start:stop].tolist()).detach().to("cpu")

        clip_text = _encode_in_batches(
            len(captions),
            torch.zeros((len(captions), clip_text_seq_dim, clip_text_emb_dim)),
            _embed_text_batch,
            desc="Encoding captions...",
        )
        torch.save(clip_text, file_path_txt)
    else:
        clip_text = torch.load(file_path_txt)
    clip_text_train = _gather_rows(clip_text, valid_nsd_ids_train)
    file_path_txt = f"{data_path}/preprocessed_data/subject{subj}/{text_embedding_variant}_text_embeddings_train.pt"
    if not os.path.exists(file_path_txt):
        torch.save(clip_text_train, file_path_txt)

# --- Autoencoder latents for the blurry-reconstruction ridge model.
if blurry_recon:
    file_path = f"{data_path}/preprocessed_data/{latent_variant}_image_embeddings.pt"
    if not os.path.exists(file_path):
        print("Generating VAE Image embeddings!")

        def _embed_vae_batch(start, stop):
            """Flattened, scaled autoencoder latents for ``images[start:stop]``."""
            # images[start:stop] is already a batch (no extra batch dim needed).
            # 2*x-1 rescales pixels to [-1, 1], assuming they are stored in [0, 1].
            pixels = torch.from_numpy(images[start:stop]).to(device=device, dtype=torch.float16)
            # 0.18215 is presumably the standard SD-VAE latent scaling factor.
            latents = autoenc.encode(2 * pixels - 1).latent_dist.mode() * 0.18215
            return latents.detach().to("cpu").reshape(stop - start, -1)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            vae_image = _encode_in_batches(
                len(images),
                torch.zeros((len(images), latent_dim)),
                _embed_vae_batch,
                desc="Encoding images...",
            )
        # Cache the latents so later runs take the torch.load branch.
        torch.save(vae_image, file_path)
    else:
        vae_image = torch.load(file_path)
    vae_image_train = _gather_rows(vae_image, valid_nsd_ids_train)
    file_path_vae = f"{data_path}/preprocessed_data/subject{subj}/{latent_variant}_image_embeddings_train.pt"
    if not os.path.exists(file_path_vae):
        torch.save(vae_image_train, file_path_vae)
print(f"Loaded train image clip and text clip for subj{subj}!", clip_image_train.shape)

# Train Ridge regression models

In [None]:
start = time.time()
# The config cell only creates outdir when --ckpt_saving is set, but the ridge
# weights below are always written there, so make sure it exists.
os.makedirs(outdir, exist_ok=True)


def _fit_ridge_and_save(targets, label, filename):
    """Fit a Ridge regression from the training betas to ``targets`` and pickle it.

    Args:
        targets: 2-D array-like, one row per training sample (rows aligned with x_train).
        label: Name shown in the progress message.
        filename: File inside ``outdir`` that receives {"coef": ..., "intercept": ...}.

    Returns:
        The fitted ``Ridge`` model.
    """
    print(f"Training Ridge {label} model with alpha={weight_decay}")
    ridge = Ridge(
        alpha=weight_decay,
        max_iter=50000,
        random_state=42,
    )
    ridge.fit(x_train, targets)
    datadict = {"coef": ridge.coef_, "intercept": ridge.intercept_}
    # The handle is local to this function, so the module-level `f` (the open
    # HDF5 image file) is no longer rebound by these writes.
    with open(f'{outdir}/{filename}', 'wb') as weights_file:
        pickle.dump(datadict, weights_file)
    return ridge


# CLIP image embeddings are flattened to one row of clip_seq_dim * clip_emb_dim per sample.
model = _fit_ridge_and_save(
    clip_image_train.reshape(len(clip_image_train), -1),
    "CLIP Image",
    "ridge_image_weights.pkl",
)
ridge_weights, ridge_biases = model.coef_, model.intercept_

if dual_guidance:
    model = _fit_ridge_and_save(
        clip_text_train.reshape(len(clip_text_train), -1),
        "CLIP Text",
        "ridge_text_weights.pkl",
    )
    ridge_weights_txt, ridge_biases_txt = model.coef_, model.intercept_

if blurry_recon:
    model = _fit_ridge_and_save(
        vae_image_train,
        "Blurry recon",
        "ridge_blurry_weights.pkl",
    )
    ridge_weights_blurry, ridge_biases_blurry = model.coef_, model.intercept_


print(f"Elapsed training time for {model_name}: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start))}")