In [None]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
import json
import pickle
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import PIL
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms

from PIL import Image, ImageDraw, ImageFont

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
# sys.path.append('generative_models/')
# import sgm
from flux_reconstructor import Flux_Reconstructor
from omegaconf import OmegaConf
from sklearn.linear_model import Ridge
# tf32 data type is faster than standard float32
# (enables TF32 for CUDA matmuls process-wide via the torch backend flag)
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
# Project-local helpers; later cells call utils.is_interactive, utils.seed_everything
# and the utils.load_* dataset loaders.
import utils
# from models import *
# Device string handed to Flux_Reconstructor below; hard-coded to CUDA.
device = "cuda"
print("device:",device)

In [None]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    # model_name = "final_subj01_pretrained_40sess_24bs"
    model_name = "subj01_40sess_hypatia_ridge_sc2"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    jupyter_args = f"--data_path=../dataset \
                    --cache_dir=../cache \
                    --model_name={model_name} --subj=1 \
                    --mode vision \
                    --dual_guidance"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

In [None]:
# ---------------------------------------------------------------------------
# Command-line / notebook configuration.
# Every option is afterwards exposed as a bare global (e.g. `subj`, `mode`)
# so the remaining cells can use them without the `args.` prefix.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Model Training Configuration")

# -- model identity and filesystem locations --
parser.add_argument("--model_name", type=str, default="testing",
                    help="will load ckpt for model found in ../train_logs/model_name")
parser.add_argument("--data_path", type=str, default=os.getcwd(),
                    help="Path to where NSD data is stored / where to download it to")
parser.add_argument("--cache_dir", type=str, default=os.getcwd(),
                    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.")
parser.add_argument("--subj", type=int, default=1, choices=list(range(1, 12)),
                    help="Validate on which subject?")

# -- model / pipeline switches --
parser.add_argument("--blurry_recon", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--n_blocks", type=int, default=4)
parser.add_argument("--hidden_dim", type=int, default=2048)
parser.add_argument("--seq_len", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--mode", type=str, default="vision")
parser.add_argument("--gen_rep", type=int, default=10)
parser.add_argument("--dual_guidance", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--snr", type=float, default=-1)
parser.add_argument("--normalize_preds", action=argparse.BooleanOptionalAction, default=True)

# -- raw output dumping --
parser.add_argument("--save_raw", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--raw_path", type=str)
parser.add_argument("--textstrength", type=float, default=0.5)

# -- ROI selection --
parser.add_argument("--top_n_rank_order_rois", type=int, default=-1,
                    help="Used for selecting the top n rois on a whole brain to narrow down voxels.")
parser.add_argument("--samplewise_rank_order_rois", action=argparse.BooleanOptionalAction, default=False,
                    help="Use the samplewise rank order rois versus voxelwise")

# In a notebook the arguments come from `jupyter_args`; as a script, passing
# None makes argparse fall back to sys.argv[1:].
args = parser.parse_args(jupyter_args if utils.is_interactive() else None)

# Promote every parsed option to a module-level global of the same name.
globals().update(vars(args))

# Deterministic seeding only makes sense for a single repetition; with several
# repetitions each one is expected to differ.
if seed > 0 and gen_rep == 1:
    utils.seed_everything(seed)

# Directory holding the trained weights for this model.
outdir = os.path.abspath(f'../train_logs/{model_name}')

# Directory that receives the reconstructions produced by this notebook.
os.makedirs("evals", exist_ok=True)
os.makedirs(f"evals/{model_name}", exist_ok=True)

# Load data

In [None]:
# Pick the dataset loader from the requested mode / subject id.
if mode == "synthetic":
    # NSD synthetic stimuli.
    voxels, all_images = utils.load_nsd_synthetic(
        subject=subj, average=False, nest=True,
    )
elif subj > 8:
    # Subject ids above 8 are routed to the imageryrf loader, re-indexed from 1;
    # only the last two of its four return values are used here.
    _, _, voxels, all_images = utils.load_imageryrf(
        subject=subj-8, mode=mode, stimtype="object",
        average=False, nest=True, split=True,
    )
else:
    # NSD mental-imagery data for subjects 1-8.
    # (ROI options not currently passed: top_n_rois=top_n_rank_order_rois,
    #  samplewise=samplewise_rank_order_rois)
    voxels, all_images = utils.load_nsd_mental_imagery(
        subject=subj, mode=mode, stimtype="all", average=True, nest=False,
    )

# Size of the trailing (voxel) axis of the loaded responses.
num_voxels = voxels.shape[-1]

# Load pretrained models

# Load the Flux reconstructor

In [None]:
# Full reconstructor (embedder_only=False) placed on the configured device.
reconstructor = Flux_Reconstructor(embedder_only=False, device=device)

# First text embedding: variant name plus (sequence length, feature size).
text_embedding_variant = "flux_clip"
clip_text_seq_dim = 1
clip_text_emb_dim = 768

# Second text embedding: variant name plus (sequence length, feature size).
text_embedding_variant2 = "flux_t5"
clip_text_seq_dim2 = 64
clip_text_emb_dim2 = 4096

# Image latent: variant name plus flattened latent size.
latent_embedding_variant = "flux"
latent_emb_dim = 262144

### Load precomputed ground-truth embeddings of the training data (used for feature normalization)

In [None]:
# Load the training-set embeddings whose per-feature mean/std are used later to
# rescale the ridge predictions.
# If this is erroring, feature extraction failed in Train.ipynb
if normalize_preds:
    # map_location="cpu": the ridge predictions are built from numpy arrays and
    # therefore live on the CPU. Without it, a file saved from GPU tensors would be
    # restored onto CUDA and the normalization arithmetic would fail with a
    # device mismatch.
    file_path_txt = f"{data_path}/preprocessed_data/subject{subj}/{text_embedding_variant}_text_embeddings_train.pt"
    clip_text_train = torch.load(file_path_txt, map_location="cpu")

    file_path_txt2 = f"{data_path}/preprocessed_data/subject{subj}/{text_embedding_variant2}_text_embeddings_train.pt"
    clip_text_train2 = torch.load(file_path_txt2, map_location="cpu")

    if blurry_recon:
        file_path = f"{data_path}/preprocessed_data/subject{subj}/{latent_embedding_variant}_latent_embeddings_train.pt"
        vae_image_train = torch.load(file_path, map_location="cpu")
        # NOTE(review): `strength` is not read by the generation loop in this
        # notebook (it passes literal strengths) — kept for any external use.
        strength = 0.75
    else:
        strength = 1.0

# Predicting latent vectors for reconstruction  

In [None]:
def _load_ridge(weights_path):
    """Rebuild a Ridge regressor from a pickled {"coef", "intercept"} dict.

    Returns the (model, datadict) pair. The Ridge hyper-parameters are only
    placeholders: the model is never fitted here, just used for `predict`.

    SECURITY: pickle.load executes arbitrary code from the file — only load
    weight files produced by your own training run.
    """
    with open(weights_path, 'rb') as f:
        datadict = pickle.load(f)
    ridge = Ridge(
        alpha=60000,
        max_iter=50000,
        random_state=42,
    )
    ridge.coef_ = datadict["coef"]
    ridge.intercept_ = datadict["intercept"]
    return ridge, datadict


def _match_train_stats(pred, train, eps=1e-6):
    """Z-score `pred` over dim 0, then rescale to the mean/std of `train` over dim 0.

    `eps` guards against division by zero for features whose predictions are
    constant across samples.
    """
    standardized = (pred - torch.mean(pred, dim=0)) / (torch.std(pred, dim=0) + eps)
    return standardized * torch.std(train, dim=0) + torch.mean(train, dim=0)


# All three regressors consume the same input slice of the voxel array.
ridge_inputs = voxels[:,0]

# First text embedding.
model, text_datadict = _load_ridge(f'{outdir}/ridge_text_weights.pkl')
pred_clip_text = torch.from_numpy(
    model.predict(ridge_inputs).reshape(-1, clip_text_seq_dim, clip_text_emb_dim)
)

# Second text embedding.
model, text_datadict2 = _load_ridge(f'{outdir}/ridge_text2_weights.pkl')
pred_clip_text2 = torch.from_numpy(
    model.predict(ridge_inputs).reshape(-1, clip_text_seq_dim2, clip_text_emb_dim2)
)

# Flattened image latent (only when the low-level pathway is enabled).
if blurry_recon:
    model, blurry_datadict = _load_ridge(f'{outdir}/ridge_blurry_weights.pkl')
    pred_blurry_vae = torch.from_numpy(
        model.predict(ridge_inputs).reshape(-1, latent_emb_dim)
    )

# Rescale predictions so each feature matches the training-set statistics.
# NOTE: the statistics are taken across the evaluated samples, so this needs
# more than one sample to be meaningful (std of a single sample is NaN).
if normalize_preds:
    for sequence in range(clip_text_seq_dim):
        pred_clip_text[:, sequence] = _match_train_stats(
            pred_clip_text[:, sequence], clip_text_train[:, sequence]
        )
    for sequence in range(clip_text_seq_dim2):
        pred_clip_text2[:, sequence] = _match_train_stats(
            pred_clip_text2[:, sequence], clip_text_train2[:, sequence]
        )
    if blurry_recon:
        # Uses the same epsilon as the text branches; the original divided by
        # the raw std here and could produce inf/NaN for constant features.
        pred_blurry_vae = _match_train_stats(pred_blurry_vae, vae_image_train)

In [None]:
# Generate `gen_rep` reconstructions per sample and save them under evals/.
#
# Outputs:
#   final_recons       - one entry per sample along dim 0, one per repetition along dim 1
#   final_blurryrecons - low-level reconstructions, produced on the first
#                        repetition only (dim 1 has size 1); None when
#                        blurry_recon is disabled
final_recons = None
final_predcaptions = None
final_clipvoxels = None
final_blurryrecons = None

if save_raw:
    if raw_path is None:
        # Without this check the f-string below would silently create a
        # directory literally named "None".
        raise ValueError("--save_raw requires --raw_path to be set")
    raw_root = f"{raw_path}/{mode}/{model_name}/subject{subj}/"
    print("raw_root:", raw_root)
    os.makedirs(raw_root,exist_ok=True)
    # Each predicted embedding is stored under its own variant name.
    torch.save(pred_clip_text, f"{raw_root}/{text_embedding_variant}_text_voxels.pt")
    torch.save(pred_clip_text2, f"{raw_root}/{text_embedding_variant2}_text_voxels.pt")
    if blurry_recon:
        torch.save(pred_blurry_vae, f"{raw_root}/{latent_embedding_variant}_latent_voxels.pt")


for rep in tqdm(range(gen_rep)):
    # Fresh random seed per repetition so repetitions differ from each other.
    utils.seed_everything(seed = random.randint(0,10000000))
    # get all reconstructions
    # all_images = None
    all_blurryrecons = None
    all_recons = None
    all_predcaptions = []
    all_clipvoxels = None

    minibatch_size = 1
    plotting = False

    # Per-sample results are collected in lists and concatenated once per
    # repetition instead of re-allocating with vstack on every sample.
    recon_batches = []
    blurry_batches = []
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        for idx in tqdm(range(0,voxels.shape[0]), desc="sample loop"):
            clip_text_voxels = pred_clip_text[idx]
            clip_text_voxels2 = pred_clip_text2[idx]
            # pred_blurry_vae only exists when the low-level pathway is enabled.
            # NOTE(review): assumes reconstruct() accepts latent=None — TODO confirm.
            latent_voxels = pred_blurry_vae[idx] if blurry_recon else None

            # The low-level image does not depend on the repetition, so it is
            # generated once, on the first pass.
            if blurry_recon and rep==0:
                blurred_image = reconstructor.reconstruct(latent=latent_voxels,
                                                            n_samples=1,
                                                            strength=0.0)

                im = transforms.ToTensor()(blurred_image)
                # Add a batch dimension so samples stack as (N, C, H, W), matching
                # the layout used for the full reconstructions below.
                blurry_batches.append(im.unsqueeze(0).cpu())

            samples = reconstructor.reconstruct(latent=latent_voxels,
                                        c_t=clip_text_voxels,
                                        t5=clip_text_voxels2,
                                        n_samples=1,
                                        strength=0.8)
            if isinstance(samples, PIL.Image.Image):
                samples = transforms.ToTensor()(samples)
            samples = samples.unsqueeze(0)
            recon_batches.append(samples.cpu())

            if save_raw:
                # print(f"Saving raw images to {raw_root}/{idx}/{rep}.png")
                os.makedirs(f"{raw_root}/{idx}/", exist_ok=True)
                transforms.ToPILImage()(samples[0]).save(f"{raw_root}/{idx}/{rep}.png")
                transforms.ToPILImage()(all_images[idx]).save(f"{raw_root}/{idx}/ground_truth.png")
                if rep == 0:
                    if blurry_recon:
                        blurred_image.save(f"{raw_root}/{idx}/low_level.png")
                    # Per-sample copies of the predicted text embeddings.
                    torch.save(clip_text_voxels, f"{raw_root}/{idx}/{text_embedding_variant}_text_voxels.pt")
                    torch.save(clip_text_voxels2, f"{raw_root}/{idx}/{text_embedding_variant2}_text_voxels.pt")

        all_recons = torch.cat(recon_batches, dim=0)
        if final_recons is None:
            final_recons = all_recons.unsqueeze(1)
        else:
            final_recons = torch.cat((final_recons, all_recons.unsqueeze(1)), dim=1)

        # Low-level reconstructions exist for the first repetition only; later
        # repetitions have nothing to append (the original dereferenced None here).
        if blurry_recon and rep == 0:
            all_blurryrecons = torch.cat(blurry_batches, dim=0)
            final_blurryrecons = all_blurryrecons.unsqueeze(1)

if blurry_recon:
    torch.save(final_blurryrecons,f"evals/{model_name}/{model_name}_all_blurryrecons_{mode}.pt")
torch.save(final_recons,f"evals/{model_name}/{model_name}_all_recons_{mode}.pt")
print(f"saved {model_name} mi outputs!")


In [None]:
# When run as a plain script (not inside Jupyter), terminate here with a
# success status.
if not utils.is_interactive():
    raise SystemExit(0)