# Import packages & functions

In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
from sklearn.linear_model import Ridge
import pickle
# custom functions #
import utils
from sc_reconstructor import SC_Reconstructor
from vdvae import VDVAE

  check_for_updates()


# Configurations

In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    model_name = "jonathan_refiner"
    print("model_name:", model_name)
    
    # global_batch_size and batch_size should already be defined in the 2nd cell block
    jupyter_args = f"--data_path=/weka/proj-medarc/shared/mindeyev2_dataset \
                    --cache_dir=/weka/proj-medarc/shared/cache \
                    --model_name={model_name} \
                    --subj=1 --num_sessions=40 \
                    --no-dual_guidance --no-blurry_recon --caption_type medium"

    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: jonathan_refiner
--data_path=/weka/proj-medarc/shared/mindeyev2_dataset                     --cache_dir=/weka/proj-medarc/shared/cache                     --model_name=jonathan_refiner                     --subj=1 --num_sessions=40                     --no-dual_guidance --no-blurry_recon --caption_type medium


In [3]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
    help="Validate on which subject?",
)
parser.add_argument(
    "--num_sessions", type=float, default=40,
    help="Number of training sessions to include",
)
parser.add_argument(
    "--prompt_recon",action=argparse.BooleanOptionalAction, default=True,
    help="Use for prompt generating",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--weight_decay",type=int,default=100000,
)
parser.add_argument(
    "--max_iter",type=int,default=50000,
)
parser.add_argument(
    "--dual_guidance",action=argparse.BooleanOptionalAction,default=True,
    help="Use the decoded captions for dual guidance",
)
parser.add_argument(
    "--caption_type",type=str,default='medium',choices=['coco','short', 'medium', 'schmedium'],
)
parser.add_argument(
    "--retrieval",action=argparse.BooleanOptionalAction,default=True,
    help="Use the decoded captions for dual guidance",
)
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()
print(f"args: {args}")
# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)
    
# seed all random functions
utils.seed_everything(seed)

outdir = os.path.abspath(f'../train_logs/{model_name}')
os.makedirs(outdir,exist_ok=True)
device = "cuda"

args: Namespace(model_name='jonathan_refiner', data_path='/weka/proj-medarc/shared/mindeyev2_dataset', cache_dir='/weka/proj-medarc/shared/cache', subj=1, num_sessions=40.0, prompt_recon=True, blurry_recon=False, seed=42, weight_decay=100000, max_iter=50000, dual_guidance=False, caption_type='medium', retrieval=True)


# Prep data, models, and dataloaders

In [4]:
x_train, valid_nsd_ids_train, x_test, test_nsd_ids = utils.load_nsd(subject=subj, num_sessions=num_sessions, data_path=data_path)
print(x_train.shape, valid_nsd_ids_train.shape)

print(f"Loaded subj {subj} betas!\n")

torch.Size([27000, 15724]) (27000,)
Loaded subj 1 betas!



## Prepare git feature

In [5]:
if not os.path.exists(f'{data_path}/git_image_features.hdf5'):
    print("Creating Git Feature...")
    from PIL import Image
    import requests
    from transformers import AutoProcessor, GitVisionModel, AutoModelForCausalLM, GitModel
    from modeling_git import GitForCausalLMClipEmb
    # Load 73k NSD images
    f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
    beta_images = f['images'] 
    print("Loaded all 73k possible NSD images to cpu!", beta_images.shape)

    git_images = []
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    
    git_text_model = GitForCausalLMClipEmb.from_pretrained("microsoft/git-large-coco")
    git_text_model.to(device)
    git_text_model.eval().requires_grad_(False)
    print("success load Git model")
    for i, image in enumerate(tqdm(beta_images)):
        pil_image = (image.transpose((1, 2, 0))*255).astype(np.uint8)
        inputs = processor(images=pil_image, return_tensors="pt").pixel_values.to(device)
        outputs = git_text_model.git.image_encoder(inputs).last_hidden_state
        # valid the captions
        if i <= 5:
            generated_ids = git_text_model.generate(pixel_values=outputs, max_length=50)
            generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
            print(generated_caption)
        git_images.append(outputs.detach().cpu().numpy())


    with h5py.File(f'{data_path}/git_image_features.hdf5', 'w') as f:
        f.create_dataset('features', data=np.array(git_images))
    print("Finished!")
    del beta_images, git_images
else:
    print("git_image_features.hdf5 already exist!")

git_image_features.hdf5 already exist!


In [6]:
# Load 73k NSD images
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'] # if you go OOM you can remove the [:] so it isnt preloaded to cpu! (will require a few edits elsewhere tho)
# images = torch.Tensor(images).to("cpu").to(data_type)
print("Loaded all 73k possible NSD images to cpu!", images.shape)

# Load 73k NSD captions
if caption_type == "schmedium":
    captions_small = np.load(f'{data_path}/preprocessed_data/short_length_captions.npy')
    captions_medium = np.load(f'{data_path}/preprocessed_data/mid_length_captions_73K.npy')
    # Create a mask to randomly select elements from both arrays
    mask = np.random.rand(len(captions_small)) > 0.5
    # Mix the arrays based on the mask
    captions = np.where(mask, captions_small, captions_medium)
else:
    if caption_type == "coco":
        caption_file = "annots_73k.npy"
    elif caption_type == "short":
        caption_file = "short_length_captions.npy"
    elif caption_type == "medium":
        caption_file = "mid_length_captions_73K.npy"
    else:
        raise ValueError("Invalid caption type")
    captions = np.load(f'{data_path}/preprocessed_data/{caption_file}')
print("Loaded all 73k NSD captions to cpu!", captions.shape)

train_images = torch.zeros((len(valid_nsd_ids_train), 3, 224, 224))
train_captions = np.zeros((len(valid_nsd_ids_train),), dtype=object)

# Load specific training data
for i, idx in enumerate(valid_nsd_ids_train):
    train_images[i] =  torch.from_numpy(images[idx])
    train_captions[i] = captions[idx]
    
print(f"Filtered down to only the {len(valid_nsd_ids_train)} training images for subject {subj}!")

Loaded all 73k possible NSD images to cpu! (73000, 3, 224, 224)
Loaded all 73k NSD captions to cpu! (73000,)
Filtered down to only the 27000 training images for subject 1!


## Load models

### Feature extractor model

In [9]:
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder
clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    arch="ViT-bigG-14",
    version="laion2b_s39b_b160k",
    output_tokens=True,
    only_tokens=True,
)
clip_img_embedder.to(device)
image_embedding_variant = "ViT-bigG-14"
clip_emb_dim = 1664
clip_seq_dim = 256

clip_extractor = SC_Reconstructor(compile_models=False, embedder_only=True, device=device, cache_dir=cache_dir)
retrieval_embedding_variant = "stable_cascade_hidden"
retrieval_emb_dim = 1024
retrieval_seq_dim = 257

prompt_embedding_variant = "git"
git_seq_dim = 257
git_emb_dim = 1024

Stable Cascade Reconstructor: Loading model...
['model_version', 'effnet_checkpoint_path', 'previewer_checkpoint_path']
['transforms', 'clip_preprocess', 'gdf', 'sampling_configs', 'effnet_preprocess']


OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacty of 79.11 GiB of which 3.25 MiB is free. Process 4066995 has 67.03 GiB memory in use. Including non-PyTorch memory, this process has 12.06 GiB memory in use. Of the allocated memory 11.32 GiB is allocated by PyTorch, and 243.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# Creating block of CLIP embeddings

In [8]:
file_path = f"{data_path}/preprocessed_data/subject{subj}/{image_embedding_variant}_image_embeddings_train.pt"
emb_batch_size = 50
if not os.path.exists(file_path):
    # Generate CLIP Image embeddings
    print("Generating Image embeddings!")
    clip_image_train = torch.zeros((len(train_images), clip_seq_dim, clip_emb_dim)).to("cpu")
    for i in tqdm(range(len(train_images) // emb_batch_size), desc="Encoding images..."):
        batch = train_images[i * emb_batch_size:i * emb_batch_size + emb_batch_size]
        batch = batch.to(device).to(dtype=torch.float16)
        embedding = clip_img_embedder(batch).to("cpu")
        clip_image_train[i * emb_batch_size:i * emb_batch_size + emb_batch_size] = embedding

    torch.save(clip_image_train, file_path)
else:
    clip_image_train = torch.load(file_path)

if retrieval:
    file_path = f"{data_path}/preprocessed_data/subject{subj}/{retrieval_embedding_variant}_retrieval_embeddings_train.pt"
    emb_batch_size = 50
    if not os.path.exists(file_path):
        # Generate CLIP Retrieval embeddings
        print("Generating Retrieval embeddings!")
        retrieval_image_train = torch.zeros((len(train_images), retrieval_seq_dim, retrieval_emb_dim)).to("cpu")
        for i in tqdm(range(len(train_images) // emb_batch_size), desc="Encoding images..."):
            batch_list = []
            for img in train_images[i * emb_batch_size:i * emb_batch_size + emb_batch_size]:
                batch_list.append(transforms.ToPILImage()(img))
            retrieval_image_train[i * emb_batch_size:i * emb_batch_size + emb_batch_size] = clip_extractor.embed_image(batch_list, hidden=True).to("cpu")
        # Normalize for optimal cosine similarity
        retrieval_image_train = torch.nn.functional.normalize(retrieval_image_train, p=2, dim=2)
        torch.save(retrieval_image_train, file_path)
    else:
        retrieval_image_train = torch.load(file_path)
        
# Load 73k GiT NSD features
if prompt_recon:
    file_path_git = f"{data_path}/preprocessed_data/subject{subj}/{prompt_embedding_variant}_prompt_embeddings_train.pt"
    if not os.path.exists(file_path_git):
        with h5py.File(f'{data_path}/git_image_features.hdf5', 'r') as f:
            git_features = f['features'][:]
        train_git_images = torch.zeros((len(valid_nsd_ids_train), 257,1024))
        for i, idx in enumerate(valid_nsd_ids_train):
            train_git_images[i] = torch.from_numpy(git_features[idx])
        torch.save(train_git_images, file_path_git)
        del git_features
    else:
        train_git_images = torch.load(file_path_git)

# Cut down vectors to only samples used for training based on num_sessions, this assumed scanIDs are in order:
clip_image_train = clip_image_train[:len(train_images)]
if retrieval:
    retrieval_image_train = retrieval_image_train[:len(train_images)]
if prompt_recon:
    train_git_images = train_git_images[:len(train_images)]
    
print(f"Loaded vectors for subj{subj}!")

Generating Retrieval embeddings!


Encoding images...:   0%|                                                                         | 0/540 [00:00<?, ?it/s]


NameError: name 'clip_extractor' is not defined

# Train Ridge regression models

In [None]:
start = time.time()
ridge_weights = np.zeros((clip_seq_dim * clip_emb_dim, x_train.shape[-1])).astype(np.float32)
ridge_biases = np.zeros((clip_seq_dim * clip_emb_dim)).astype(np.float32)
print(f"Training Ridge Image model with alpha=100000")
model = Ridge(
    alpha=100000,
    max_iter=max_iter,
    random_state=42,
)

param_count = ridge_weights.size + ridge_biases.size  # + n_outputs if intercept is included
print(f"Estimated parameter count before fitting: {param_count}")

model.fit(x_train, clip_image_train.reshape(len(clip_image_train), -1))
ridge_weights = model.coef_
ridge_biases = model.intercept_
datadict = {"coef" : ridge_weights, "intercept" : ridge_biases}
# Save the regression weights
with open(f'{outdir}/ridge_image_weights.pkl', 'wb') as f:
    pickle.dump(datadict, f)

if retrieval:
    ridge_weights = np.zeros((retrieval_seq_dim * retrieval_emb_dim, x_train.shape[-1])).astype(np.float32)
    ridge_biases = np.zeros((retrieval_seq_dim * retrieval_emb_dim)).astype(np.float32)
    print(f"Training Ridge Retrieval model with alpha={weight_decay}")
    model = Ridge(
        alpha=weight_decay,
        max_iter=max_iter,
        random_state=42,
    )
    x_train_norm = torch.nn.functional.normalize(x_train, p=2, dim=1)
    model.fit(x_train_norm, retrieval_image_train.reshape(len(retrieval_image_train), -1))
    ridge_weights = model.coef_
    ridge_biases = model.intercept_
    datadict = {"coef" : ridge_weights, "intercept" : ridge_biases}
    # Save the regression weights
    with open(f'{outdir}/ridge_retrieval_weights.pkl', 'wb') as f:
        pickle.dump(datadict, f)
    
    del retrieval_image_train
    del ridge_weights
    del ridge_biases
    del datadict
        
if prompt_recon:
    ridge_weights_prompt = np.zeros((git_seq_dim*git_emb_dim, x_train.shape[-1])).astype(np.float32)
    ridge_biases_prompt = np.zeros((git_seq_dim*git_emb_dim)).astype(np.float32)
    print(f"Training Ridge prompt recon model with alpha={weight_decay}")
    model = Ridge(
        alpha=weight_decay,
        max_iter=max_iter,
        random_state=42,
    )
    model.fit(x_train, train_git_images.reshape(len(train_git_images), -1))
    ridge_weights_prompt = model.coef_
    ridge_biases_prompt = model.intercept_
    datadict = {"coef" : ridge_weights_prompt, "intercept" : ridge_biases_prompt}
    # Save the regression weights
    with open(f'{outdir}/ridge_prompt_weights.pkl', 'wb') as f:
        pickle.dump(datadict, f)
        
    del train_git_images
    del ridge_weights_prompt
    del ridge_biases_prompt
    del datadict

print(f"Elapsed training time for {model_name}: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start))}")
