In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
import json
import pickle
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont

from omegaconf import OmegaConf
from sklearn.linear_model import Ridge
from flux_reconstructor import Flux_Reconstructor
from sd35_reconstructor import SD35_Reconstructor
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils
from models import *

device = "cuda"

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


LOCAL RANK  0
device: cuda


In [2]:
# A notebook kernel has no real command line, so when running interactively the
# argument list is assembled by hand here and handed to the argparse cell as
# `jupyter_args`.
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    # model_name = "final_subj01_pretrained_40sess_24bs"
    model_name = "subj01_40sess_hypatia_ridge_sc_flux_enhanced"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    # (each trailing backslash continues the string literal, so the indentation
    # of the next line ends up inside the string as a run of spaces; the
    # .split() below discards that whitespace)
    # NOTE(review): --raw_path is not passed here and the parser gives it no
    # default, so raw_path is None in interactive runs — confirm that is intended.
    jupyter_args = f"--data_path=../dataset \
                    --cache_dir=../cache \
                    --model_name={model_name} --subj=1 \
                    --mode vision \
                    --dual_guidance"
    print(jupyter_args)
    # argparse expects a list of tokens, not one string
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: subj01_40sess_hypatia_ridge_flux_ip
--data_path=../dataset                     --cache_dir=../cache                     --model_name=subj01_40sess_hypatia_ridge_flux_ip --subj=1                     --mode vision                     --dual_guidance


In [3]:
# Options for this enhancement pass. In a notebook they are parsed from
# `jupyter_args`; otherwise from the process command line.
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="will load ckpt for model found in ../train_logs/model_name",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8,9,10,11],
    help="Validate on which subject?",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=2048,
)
parser.add_argument(
    "--seq_len",type=int,default=1,
)
parser.add_argument(
    "--seed",type=int,default=42,
    help="Seeds all RNGs up front, but only when positive and --gen_rep is 1",
)
parser.add_argument(
    "--mode",type=str,default="vision",
    help="Selects the dataset loader and is part of the input/output file names",
)
parser.add_argument(
    "--gen_rep",type=int,default=10,
    help="Number of reconstruction repetitions generated per sample",
)
parser.add_argument(
    "--snr",type=float,default=-1,
    help="Forwarded as snr= to the NSD mental imagery loader",
)
parser.add_argument(
    "--normalize_preds",action=argparse.BooleanOptionalAction,default=True,
    help="Rescale predicted embeddings to the training embeddings' mean/std",
)
parser.add_argument(
    "--save_raw",action=argparse.BooleanOptionalAction,default=False,
    help="Also write each reconstruction as a PNG under --raw_path",
)
parser.add_argument(
    "--raw_path",type=str,
    help="Root directory for raw per-sample outputs (no default)",
)
parser.add_argument(
    "--sd35",action=argparse.BooleanOptionalAction,default=False,
    help="Use SD35_Reconstructor instead of Flux_Reconstructor",
)
# The interactive argument string passes --dual_guidance; without this
# definition argparse aborts with "unrecognized arguments". The value is not
# read anywhere in this notebook.
parser.add_argument(
    "--dual_guidance",action=argparse.BooleanOptionalAction,default=False,
    help="Accepted so argument strings containing this flag parse; unused here",
)

# parse_args(None) falls back to sys.argv[1:], which is the script case
args = parser.parse_args(jupyter_args if utils.is_interactive() else None)

# create global variables without the args prefix
globals().update(vars(args))


if seed > 0 and gen_rep == 1:
    # seed all random functions, but only if doing 1 rep
    utils.seed_everything(seed)

# Base reconstructions are read from the *un-suffixed* model directory, so this
# must happen before model_name is extended below.
all_base_recons = torch.load(f"evals/{model_name}/{model_name}_all_recons_{mode}.pt")

# Ridge weights likewise live under the original model name.
outdir = os.path.abspath(f'../train_logs/{model_name}')
# From here on model_name refers to the enhanced outputs.
model_name = model_name + "_enhanced"
# make output directory (makedirs also creates the parent "evals" folder)
os.makedirs(f"evals/{model_name}",exist_ok=True)

# Load data

In [4]:
# Pick the dataset loader: synthetic mode always wins, otherwise subject ids
# above 8 go to the imageryrf loader (re-indexed from 1) and the rest to the
# NSD mental imagery loader.
if mode != "synthetic" and subj <= 8:
    voxels, all_images = utils.load_nsd_mental_imagery(
        subject=subj, mode=mode, stimtype="all", snr=snr, average=True, nest=False
    )
elif mode != "synthetic":
    # only the last two of the four returned values are used here
    _, _, voxels, all_images = utils.load_imageryrf(
        subject=subj-8, mode=mode, stimtype="object", average=False, nest=True, split=True
    )
else:
    voxels, all_images = utils.load_nsd_synthetic(subject=subj, average=False, nest=True)

# size of the trailing voxel axis
num_voxels = voxels.shape[-1]

torch.Size([18, 1, 15724]) torch.Size([18, 3, 425, 425])


# Load pretrained models


In [6]:
# Choose the diffusion backend and record, for each of its two text-embedding
# streams, the variant name used in file names together with the
# (sequence length, feature size) that ridge predictions are reshaped to.
if not sd35:
    reconstructor = Flux_Reconstructor(embedder_only=False, device=device)
    text_embedding_variant, clip_text_seq_dim, clip_text_emb_dim = "flux_clip", 1, 768
    text_embedding_variant2, clip_text_seq_dim2, clip_text_emb_dim2 = "flux_t5", 64, 4096
else:
    reconstructor = SD35_Reconstructor(embedder_only=False, device=device)
    text_embedding_variant, clip_text_seq_dim, clip_text_emb_dim = "sd35_t5", 154, 4096
    text_embedding_variant2, clip_text_seq_dim2, clip_text_emb_dim2 = "sd35_clip", 1, 2048

### Compute ground truth embeddings for training data (for feature normalization)

In [7]:
# Training-set text embeddings supply the mean/std that predicted embeddings
# are rescaled to, so they are only loaded when normalize_preds is on.
if normalize_preds:
    # map_location="cpu": these statistics are later combined with ridge
    # predictions created via torch.from_numpy (CPU tensors), so load onto the
    # CPU regardless of the device the file was saved from.
    file_path_txt = f"{data_path}/preprocessed_data/subject{subj}/{text_embedding_variant}_text_embeddings_train.pt"
    clip_text_train = torch.load(file_path_txt, map_location="cpu")

    file_path_txt2 = f"{data_path}/preprocessed_data/subject{subj}/{text_embedding_variant2}_text_embeddings_train.pt"
    clip_text_train2 = torch.load(file_path_txt2, map_location="cpu")

# Predicting latent vectors for reconstruction  

In [8]:

def _predict_text_embeddings(weights_path, features, seq_dim, emb_dim):
    """Apply saved ridge weights to `features`.

    weights_path: pickle file holding a dict with "coef" and "intercept" entries.
    features: input passed unchanged to Ridge.predict.
    Returns the prediction as a torch tensor reshaped to (-1, seq_dim, emb_dim).
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this at
    # weight files produced by a trusted training run.
    with open(weights_path, 'rb') as f:
        weights = pickle.load(f)
    # The estimator is never fitted here: its coefficients and intercept are
    # injected from the pickle and only predict() is called.
    ridge = Ridge(
        alpha=60000,
        max_iter=50000,
        random_state=42,
    )
    ridge.coef_ = weights["coef"]
    ridge.intercept_ = weights["intercept"]
    return torch.from_numpy(ridge.predict(features).reshape(-1, seq_dim, emb_dim))


def _match_train_statistics(pred, train, seq_dim):
    """Rescale `pred` in place, one sequence position at a time.

    Each position is z-scored across samples (axis 0) and then mapped onto the
    mean/std of the same position in `train`. The 1e-6 guards against division
    by a zero standard deviation.
    """
    for position in range(seq_dim):
        current = pred[:, position]
        z_scored = (current - torch.mean(current,axis=0)) / (torch.std(current,axis=0) + 1e-6)
        pred[:, position] = z_scored * torch.std(train[:, position],axis=0) + torch.mean(train[:, position],axis=0)


# Both embedding streams are predicted from the same input: index 0 along the
# second axis of `voxels`.
pred_clip_text = _predict_text_embeddings(
    f'{outdir}/ridge_text_weights.pkl', voxels[:,0], clip_text_seq_dim, clip_text_emb_dim
)
pred_clip_text2 = _predict_text_embeddings(
    f'{outdir}/ridge_text2_weights.pkl', voxels[:,0], clip_text_seq_dim2, clip_text_emb_dim2
)

print("pred_clip_text shape:", pred_clip_text.shape)
print("Number of NaN values in pred_clip_text:", torch.isnan(pred_clip_text).sum().item())

print("pred_clip_text2 shape:", pred_clip_text2.shape)
print("Number of NaN values in pred_clip_text2:", torch.isnan(pred_clip_text2).sum().item())
if normalize_preds:
    _match_train_statistics(pred_clip_text, clip_text_train, clip_text_seq_dim)
    _match_train_statistics(pred_clip_text2, clip_text_train2, clip_text_seq_dim2)
# These summaries print even when normalize_preds is off, in which case they
# repeat the un-normalized values.
print("Normalized pred_clip_text shape:", pred_clip_text.shape)
print("Normalized Number of NaN values in pred_clip_text:", torch.isnan(pred_clip_text).sum().item())

print("Normalized pred_clip_text2 shape:", pred_clip_text2.shape)
print("Normalized Number of NaN values in pred_clip_text2:", torch.isnan(pred_clip_text2).sum().item())

In [10]:
# Persist both predicted text-embedding tensors under the raw-output tree.
# Use the --raw_path root (the same one the reconstruction loop writes under)
# when it was supplied; otherwise keep the original hard-coded location.
raw_base = raw_path if raw_path is not None else "/export/raid1/home/kneel027/Second-Sight/output/mental_imagery_paper_b3"
raw_root = f"{raw_base}/{mode}/{model_name}/subject{subj}/"
os.makedirs(raw_root,exist_ok=True)
torch.save(pred_clip_text, os.path.join(raw_root, f"{text_embedding_variant}_text_voxels.pt"))
torch.save(pred_clip_text2, os.path.join(raw_root, f"{text_embedding_variant2}_text_voxels.pt"))
print(raw_root)

/export/raid1/home/kneel027/Second-Sight/output/mental_imagery_paper_b3/vision/subj01_40sess_hypatia_ridge_flux_ip/subject1/


In [11]:
# Only final_recons is filled in by this cell; the other placeholders stay None.
final_recons = None
final_predcaptions = None
final_clipvoxels = None
final_blurryrecons = None

# Build the raw-output directory only when a root was actually supplied;
# formatting raw_path=None into the path would silently create "None/...".
if raw_path is not None:
    raw_root = f"{raw_path}/{mode}/{model_name}/subject{subj}/"
    print("raw_root:", raw_root)
    os.makedirs(raw_root,exist_ok=True)
elif save_raw:
    raise ValueError("--save_raw requires --raw_path to be set")

num_samples_per_image = 1
plotting = False  # debug switch: display the first sample's output, then abort

recons_per_rep = []
for rep in tqdm(range(gen_rep)):
    # Fresh seed for every repetition, kept in its own name so the --seed
    # option is not overwritten.
    rep_seed = random.randint(0,10000000)
    utils.seed_everything(seed = rep_seed)
    print(f"seed = {rep_seed}")

    # collected in a list and stacked once, instead of re-allocating per sample
    rep_samples = []

    with torch.no_grad():
        for idx in tqdm(range(0,voxels.shape[0]), desc="sample loop"):
            # Hand the base reconstruction for this sample/repetition plus both
            # predicted text embeddings to the reconstructor.
            samples = reconstructor.reconstruct(
                                image=transforms.ToPILImage()(all_base_recons[idx][rep]),
                                c_t=pred_clip_text[idx],
                                t5=pred_clip_text2[idx],
                                n_samples=1,
                                strength=0.8,
                                seed=rep_seed)
            # The file imports `Image` from PIL (not the `PIL` package name),
            # so the class has to be referenced as Image.Image.
            if isinstance(samples, Image.Image):
                samples = transforms.ToTensor()(samples)
            samples = samples.unsqueeze(0)
            rep_samples.append(samples.cpu())

            if plotting:
                for s in range(num_samples_per_image):
                    plt.figure(figsize=(2,2))
                    plt.imshow(transforms.ToPILImage()(samples[s]))
                    plt.axis('off')
                    plt.show()
                print(model_name)
                # dont actually want to run the whole thing with plotting=True
                raise RuntimeError("plotting=True is a debug mode; stopping after the first sample")

            if save_raw:
                # print(f"Saving raw images to {raw_root}/{idx}/{rep}.png")
                os.makedirs(f"{raw_root}/{idx}/", exist_ok=True)
                transforms.ToPILImage()(samples[0]).save(f"{raw_root}/{idx}/{rep}.png")
                if rep == 0:
                    # the inputs only need to be written once per sample
                    transforms.ToPILImage()(all_base_recons[idx][rep]).save(f"{raw_root}/{idx}/low_level.png")
                    transforms.ToPILImage()(all_images[idx]).save(f"{raw_root}/{idx}/ground_truth.png")

    # one row per sample for this repetition
    all_recons = torch.vstack(rep_samples)
    # new axis 1 indexes the repetition
    recons_per_rep.append(all_recons.unsqueeze(1))

# gen_rep == 0 leaves nothing to concatenate; keep the None placeholder then
final_recons = torch.cat(recons_per_rep, dim=1) if recons_per_rep else None

torch.save(final_recons,f"evals/{model_name}/{model_name}_all_recons_{mode}.pt")
print(f"saved {model_name} mi outputs!")

# if not utils.is_interactive():
#     sys.exit(0)


raw_root: /export/raid1/home/kneel027/Second-Sight/output/mental_imagery_paper_b3/vision/subj01_40sess_hypatia_ridge_flux_ip/subject1/


  0%|          | 0/10 [00:00<?, ?it/s]

seed = 8678759




sample loop:   0%|          | 0/18 [00:03<?, ?it/s]
  0%|          | 0/10 [00:03<?, ?it/s]


NameError: name 'reconstructor' is not defined

In [None]:

# Exit with status 0 when utils.is_interactive() is falsy (presumably a plain
# script run rather than a notebook kernel), so nothing after this cell runs.
if not utils.is_interactive():
    sys.exit(0)