# Import packages & functions

In [None]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
from sklearn.linear_model import Ridge
import pickle
# custom functions #
import utils
from sc_reconstructor import SC_Reconstructor
from vdvae import VDVAE

# Configurations

In [None]:
# Interactive (Jupyter) runs have no real command line, so build the CLI argument list here.
# The argparse cell below parses `jupyter_args` instead of sys.argv when utils.is_interactive().
if utils.is_interactive():
    model_name = "jonathan_unclip"
    print("model_name:", model_name)
    
    # Same flags you would pass on the command line. The trailing backslashes continue the
    # f-string, so the next lines' indentation lands inside the string; .split() below discards it.
    jupyter_args = f"--data_path=/weka/proj-medarc/shared/mindeyev2_dataset \
                    --cache_dir=/weka/proj-medarc/shared/cache \
                    --model_name={model_name} \
                    --batch_size=64 \
                    --no-multi_subject --subj=1 --num_sessions=40 \
                    --no-dual_guidance --no-blurry_recon --prompt_recon --caption_type medium"

    print(jupyter_args)
    # argparse expects a list of tokens, like sys.argv[1:]
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

In [None]:
# Command-line configuration. Every option parsed here is promoted to a module-level
# global (without the `args.` prefix) by the next statement block, so later cells refer
# to e.g. `subj`, `data_path`, `weight_decay` directly.
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8,9,10,11],
    help="Validate on which subject?",
)
parser.add_argument(
    "--multisubject_ckpt", type=str, default=None,
    help="Path to pre-trained multisubject model to finetune a single subject from. multisubject must be False.",
)
parser.add_argument(
    "--num_sessions", type=int, default=1,
    help="Number of training sessions to include",
)
parser.add_argument(
    "--use_prior",action=argparse.BooleanOptionalAction,default=True,
    help="whether to train diffusion prior (True) or just rely on retrieval part of the pipeline (False)",
)
parser.add_argument(
    "--visualize_prior",action=argparse.BooleanOptionalAction,default=False,
    help="output visualizations from unCLIP every ckpt_interval (requires much more memory!)",
)
parser.add_argument(
    "--batch_size", type=int, default=16,
    help="Batch size can be increased by 10x if only training retrieval submodule and not diffusion prior",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--prompt_recon",action=argparse.BooleanOptionalAction, default=True,
    help="Use for prompt generating",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--blur_scale",type=float,default=.5,
    help="multiply loss from blurry recons by this number",
)
parser.add_argument(
    "--clip_scale",type=float,default=1.,
    help="multiply contrastive loss by this number",
)
parser.add_argument(
    "--prior_scale",type=float,default=30,
    help="multiply diffusion prior loss by this",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=False,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs",type=int,default=150,
    help="number of epochs of training",
)
parser.add_argument(
    "--multi_subject",action=argparse.BooleanOptionalAction,default=False,
)
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=1024,
)
parser.add_argument(
    "--seq_past",type=int,default=0,
)
parser.add_argument(
    "--seq_future",type=int,default=0,
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
)
# Used as the Ridge `alpha` (a regularization strength) for the blurry-recon model, so it
# must accept non-integer values such as 1e5 or 0.5; type=int rejected those.
parser.add_argument(
    "--weight_decay",type=float,default=60000,
    help="Ridge regularization strength (alpha) for the blurry-recon regression",
)
parser.add_argument(
    "--max_iter",type=int,default=50000,
    help="max_iter passed to every sklearn Ridge model",
)
parser.add_argument(
    "--train_imageryrf",action=argparse.BooleanOptionalAction,default=False,
    help="Use the ImageryRF dataset for pretraining",
)
parser.add_argument(
    "--no_nsd",action=argparse.BooleanOptionalAction,default=False,
    help="Don't use the Natural Scenes Dataset for pretraining",
)
parser.add_argument(
    "--snr_threshold",type=float,default=-1.0,
    help="Used for calculating SNR on a whole brain to narrow down voxels.",
)
parser.add_argument(
    "--mode",type=str,default="all",
)
parser.add_argument(
    "--dual_guidance",action=argparse.BooleanOptionalAction,default=False,
    help="Use the decoded captions for dual guidance",
)
parser.add_argument(
    "--top_n_rank_order_rois",type=int, default=-1,
    help="Used for selecting the top n rois on a whole brain to narrow down voxels.",
)
parser.add_argument(
    "--samplewise_rank_order_rois",action=argparse.BooleanOptionalAction, default=False,
    help="Use the samplewise rank order rois versus voxelwise",
)
parser.add_argument(
    "--caption_type",type=str,default='coco',choices=['coco','short', 'medium'],
)
# In a notebook there is no real command line: parse the list built in the previous cell.
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# Expose every parsed option as a module-level name (e.g. `seed`, `model_name`) so the
# remaining cells can use them without the `args.` prefix.
globals().update(vars(args))

# Seed every RNG the project's helper knows about, for reproducible runs.
utils.seed_everything(seed)

# All artifacts for this run (e.g. the pickled ridge weights) go under ../train_logs/<model_name>.
outdir = os.path.abspath(f'../train_logs/{model_name}')
os.makedirs(outdir, exist_ok=True)

device = "cuda"

# Prep data, models, and dataloaders

In [None]:
# Voxel selection for this subject: keep voxels passing `snr_threshold`, stored as float16.
betas = utils.create_snr_betas(
    subject=subj,
    data_type=torch.float16,
    data_path=data_path,
    threshold=snr_threshold,
)
# Alternative voxel selection by rank-ordered ROIs (disabled):
# betas = utils.load_subject_based_on_rank_order_rois(excluded_subject=subj, data_type=torch.float16, top_n_rois=top_n_rank_order_rois, samplewise=samplewise_rank_order_rois)

# Split into train/test brain responses plus the NSD image ids each row corresponds to.
(x_train, valid_nsd_ids_train,
 x_test, test_nsd_ids) = utils.load_nsd(subject=subj, betas=betas, data_path=data_path)
print(x_train.shape, valid_nsd_ids_train.shape)

print(f"Loaded subj {subj} betas!\n")

## Prepare GiT image features

In [None]:
# Cache location for the GiT image-encoder features of all 73k NSD images.
# NOTE(review): machine-specific path; it must match the path used by the cell that loads
# the features. The old existence check looked under `{data_path}` while the file was
# written (and read back) here, so the expensive encoding re-ran on every execution.
git_features_path = '/weka/proj-fmri/jonxu/MindEye_Imagery/data/git_image_features.hdf5'
if not os.path.exists(git_features_path):
    print("Creating Git Feature...")
    from transformers import AutoProcessor
    from modeling_git import GitForCausalLMClipEmb

    git_images = []
    # Load 73k NSD images; the context manager closes the HDF5 file once encoding is done.
    with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as images_h5:
        beta_images = images_h5['images']
        print("Loaded all 73k possible NSD images to cpu!", beta_images.shape)

        processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        git_text_model = GitForCausalLMClipEmb.from_pretrained("microsoft/git-large-coco")
        git_text_model.to(device)
        git_text_model.eval().requires_grad_(False)
        print("success load Git model")

        # Pure inference: no autograd bookkeeping needed.
        with torch.no_grad():
            for i, image in enumerate(tqdm(beta_images)):
                # Channels-last uint8 array for the processor.
                # Assumes stored images are CHW floats in [0, 1] -- TODO confirm.
                pil_image = (image.transpose((1, 2, 0)) * 255).astype(np.uint8)
                inputs = processor(images=pil_image, return_tensors="pt").pixel_values.to(device)
                outputs = git_text_model.git.image_encoder(inputs).last_hidden_state
                # Sanity check: decode captions from the features of the first six images.
                if i <= 5:
                    generated_ids = git_text_model.generate(pixel_values=outputs, max_length=50)
                    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
                    print(generated_caption)
                git_images.append(outputs.detach().cpu().numpy())

    with h5py.File(git_features_path, 'w') as f:
        f.create_dataset('features', data=np.array(git_images))
    print("Finished!")
    # Release the host copy of the features and the GPU-resident GiT model before later cells
    # load other large models.
    del beta_images, git_images, git_text_model, processor
else:
    print("git_image_features.hdf5 already exist!")

In [None]:
# Load 73k NSD images. `images` is an on-disk h5py dataset (nothing is preloaded), so
# `images_h5` is deliberately kept open for as long as `images` may be indexed.
images_h5 = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = images_h5['images']
print("Loaded all 73k possible NSD images to cpu!", images.shape)

# Load 73k NSD captions; the file depends on --caption_type.
caption_files = {
    "coco": "annots_73k.npy",
    "short": "short_length_captions.npy",
    "medium": "mid_length_captions_73K.npy",
}
if caption_type not in caption_files:
    raise ValueError("Invalid caption type")
caption_file = caption_files[caption_type]
captions = np.load(f'{data_path}/preprocessed_data/{caption_file}')
print("Loaded all 73k NSD captions to cpu!", captions.shape)

# Gather only this subject's training samples, row-aligned with x_train via valid_nsd_ids_train.
num_train = len(valid_nsd_ids_train)
train_images = torch.zeros((num_train, 3, 224, 224))
# 257 x 1024 per image matches one stored GiT encoder output (assumed; TODO confirm if the model changes).
train_git_images = torch.zeros((num_train, 257, 1024))
train_captions = np.zeros((num_train,), dtype=object)

# Load 73k GiT NSD features; this path must match where the feature-creation cell writes them.
git_features_path = '/weka/proj-fmri/jonxu/MindEye_Imagery/data/git_image_features.hdf5'
with h5py.File(git_features_path, 'r') as git_h5:
    git_features = git_h5['features']
    for i, idx in enumerate(valid_nsd_ids_train):
        train_images[i] = torch.from_numpy(images[idx])
        train_git_images[i] = torch.from_numpy(git_features[idx])
        train_captions[i] = captions[idx]
    del git_features

print(f"Filtered down to only the {len(valid_nsd_ids_train)} training images for subject {subj}!")

## Load models

### Feature extractor model

In [None]:
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder

# OpenCLIP ViT-bigG/14 image encoder configured to emit only its token sequence.
# Downstream cells allocate (clip_seq_dim, clip_emb_dim) = (256, 1664) per image for it.
image_embedding_variant = "ViT-bigG-14"
clip_seq_dim, clip_emb_dim = 256, 1664

clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    arch=image_embedding_variant,
    version="laion2b_s39b_b160k",
    output_tokens=True,
    only_tokens=True,
)
clip_img_embedder.to(device)

# VDVAE wrapper, placed on the same device.
vdvae = VDVAE(device=device, cache_dir=cache_dir)

### unCLIP model

In [None]:
from omegaconf import OmegaConf

In [None]:
# prep unCLIP

from generative_models.sgm.models.diffusion import DiffusionEngine
from generative_models.sgm.util import append_dims

def update_conf_paths(config_dict):
    """Prefix ``sgm`` import targets with ``generative_models.``, in place.

    Walks the whole config tree, including dicts nested inside lists (e.g. a list of
    embedder configs, each carrying its own ``target``). Every ``"target"`` whose string
    value starts with ``"sgm"`` is rewritten to ``"generative_models." + value``.
    Already-prefixed targets do not start with ``"sgm"``, so repeated calls are harmless.

    Args:
        config_dict: plain ``dict`` config (e.g. from ``OmegaConf.to_container``); mutated.

    Returns:
        None.
    """
    def _visit(node):
        if isinstance(node, dict):
            for key, value in node.items():
                if key == "target" and isinstance(value, str) and value.startswith("sgm"):
                    # Reassigning an existing key does not change the dict's size, so this is
                    # safe while iterating.
                    node[key] = "generative_models." + value
                else:
                    _visit(value)
        elif isinstance(node, list):
            # The previous version only recursed into dicts and missed list-held configs.
            for item in node:
                _visit(item)

    _visit(config_dict)

# Load the unCLIP YAML config and resolve it into plain Python dicts/lists.
config = OmegaConf.to_container(
    OmegaConf.load("generative_models/configs/unclip6.yaml"),
    resolve=True,
)

# Update target paths
# update_conf_paths(config)

# Pull the sub-configs the DiffusionEngine constructor needs.
unclip_params = config["model"]["params"]
network_config = unclip_params["network_config"]
denoiser_config = unclip_params["denoiser_config"]
first_stage_config = unclip_params["first_stage_config"]
conditioner_config = unclip_params["conditioner_config"]
sampler_config = unclip_params["sampler_config"]
scale_factor = unclip_params["scale_factor"]
disable_first_stage_autocast = unclip_params["disable_first_stage_autocast"]
offset_noise_level = unclip_params["loss_fn_config"]["params"]["offset_noise_level"]

# Local overrides: autoencoder class path and number of sampler steps.
first_stage_config['target'] = 'generative_models.sgm.models.autoencoder.AutoencoderKL'
sampler_config['params']['num_steps'] = 38

diffusion_engine = DiffusionEngine(
    network_config=network_config,
    denoiser_config=denoiser_config,
    first_stage_config=first_stage_config,
    conditioner_config=conditioner_config,
    sampler_config=sampler_config,
    scale_factor=scale_factor,
    disable_first_stage_autocast=disable_first_stage_autocast,
)
# Inference only: eval mode and frozen weights, then move to the target device.
diffusion_engine.eval().requires_grad_(False)
diffusion_engine.to(device)

# Weights are loaded on CPU first and copied into the (already device-resident) module.
ckpt_path = f'{cache_dir}/unclip6_epoch0_step110000.ckpt'
ckpt = torch.load(ckpt_path, map_location='cpu')
diffusion_engine.load_state_dict(ckpt['state_dict'])

# Run the conditioner once on a placeholder batch (original size 768x768, zero crop offset)
# and keep its "vector" output.
batch = {
    "jpg": torch.randn(1, 3, 1, 1).to(device),  # placeholder only; per the original note it is unused
    "original_size_as_tuple": torch.ones(1, 2).to(device) * 768,
    "crop_coords_top_left": torch.zeros(1, 2).to(device),
}
out = diffusion_engine.conditioner(batch)
vector_suffix = out["vector"].to(device)
print("vector_suffix", vector_suffix.shape)


# Creating block of CLIP embeddings

In [None]:
# Cached CLIP image-token embeddings for this subject's training images, row-aligned with train_images.
file_path = f"{data_path}/preprocessed_data/subject{subj}/{image_embedding_variant}_image_embeddings_train.pt"
emb_batch_size = 50
if not os.path.exists(file_path):
    # Generate CLIP Image embeddings
    print("Generating Image embeddings!")
    clip_image_train = torch.zeros((len(train_images), clip_seq_dim, clip_emb_dim)).to("cpu")
    # Iterate over batch start offsets so the final partial batch is encoded too. The old
    # `range(len // emb_batch_size)` loop left the last `len % emb_batch_size` rows as zeros;
    # caches written by that version should be deleted and regenerated.
    with torch.no_grad():
        for batch_start in tqdm(range(0, len(train_images), emb_batch_size), desc="Encoding images..."):
            batch_end = batch_start + emb_batch_size
            image_batch = train_images[batch_start:batch_end].to(device).to(dtype=torch.float16)
            clip_image_train[batch_start:batch_end] = clip_img_embedder(image_batch).to("cpu")

    # The per-subject cache directory may not exist yet.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    torch.save(clip_image_train, file_path)
else:
    clip_image_train = torch.load(file_path)

print(f"Loaded train image clip {clip_image_train.shape} for subj{subj}!")

# Train Ridge regression models

In [None]:
# Fail fast. These optional regression targets are not computed in any cell shown here.
# Previously the run crashed with a NameError only after the slow image fit had finished.
_missing_targets = [
    f"--{flag} requires `{target}`"
    for flag, enabled, target in (
        ("dual_guidance", dual_guidance, "clip_text_train"),
        ("blurry_recon", blurry_recon, "vae_image_train"),
    )
    if enabled and target not in globals()
]
if _missing_targets:
    raise NameError("Missing regression targets: " + "; ".join(_missing_targets))


def _fit_and_save_ridge(features, targets, alpha, save_path, label):
    """Fit one sklearn Ridge model (features -> targets) and pickle its weights.

    The pickle at ``save_path`` holds ``{"coef": coef_, "intercept": intercept_}``.
    Returns the fitted model so the caller can also keep the weights in memory.
    """
    print(f"Training Ridge {label} model with alpha={alpha}")
    ridge = Ridge(
        alpha=alpha,
        max_iter=max_iter,
        random_state=42,
    )
    ridge.fit(features, targets)
    with open(save_path, 'wb') as fh:
        pickle.dump({"coef": ridge.coef_, "intercept": ridge.intercept_}, fh)
    return ridge


start = time.time()
# No zero-filled placeholder arrays are preallocated any more. The image one alone was
# (clip_seq_dim * clip_emb_dim) x n_voxels float32 and was overwritten by coef_ immediately.

# Brain activity -> flattened CLIP image tokens.
model = _fit_and_save_ridge(
    x_train, clip_image_train.reshape(len(clip_image_train), -1),
    alpha=100000, save_path=f'{outdir}/ridge_image_weights.pkl', label="Image",
)
ridge_weights, ridge_biases = model.coef_, model.intercept_

if dual_guidance:
    model = _fit_and_save_ridge(
        x_train, clip_text_train.reshape(len(clip_text_train), -1),
        alpha=100000, save_path=f'{outdir}/ridge_text_weights.pkl', label="Text",
    )
    ridge_weights_txt, ridge_biases_txt = model.coef_, model.intercept_

if blurry_recon:
    model = _fit_and_save_ridge(
        x_train, vae_image_train,
        alpha=weight_decay, save_path=f'{outdir}/ridge_blurry_weights.pkl', label="Blurry recon",
    )
    ridge_weights_blurry, ridge_biases_blurry = model.coef_, model.intercept_

if prompt_recon:
    # NOTE(review): this previously referenced undefined names (git_text_train, git_seq_dim,
    # git_emb_dim) and raised NameError. train_git_images is the GiT feature target built
    # earlier in this notebook, row-aligned with x_train via valid_nsd_ids_train.
    model = _fit_and_save_ridge(
        x_train, train_git_images.reshape(len(train_git_images), -1),
        alpha=100000, save_path=f'{outdir}/ridge_prompt_weights.pkl', label="prompt recon",
    )
    ridge_weights_prompt, ridge_biases_prompt = model.coef_, model.intercept_

print(f"Elapsed training time for {model_name}: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start))}")