In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
import json
import pickle
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

from PIL import Image, ImageDraw, ImageFont

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2
from generative_models.sgm.models.diffusion import DiffusionEngine
from generative_models.sgm.util import append_dims
from omegaconf import OmegaConf
from models import ClipperLarge
from sklearn.linear_model import Ridge
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils
from models import *

### Multi-GPU config ###
local_rank = os.getenv('RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)  

accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
device = accelerator.device
print("device:",device)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


LOCAL RANK  0
device: cuda


In [2]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    # model_name = "final_subj01_pretrained_40sess_24bs"
    model_name = "subj01_40sess_hypatia_turbo_ridge"
    print("model_name:", model_name)

    # other variables can be specified in the following string:
    jupyter_args = f"--data_path=../dataset \
                    --cache_dir=../cache \
                    --model_name={model_name} --subj=1 \
                    --mode imagery --no-blurry_recon"
    print(jupyter_args)
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: subj01_40sess_hypatia_turbo_ridge
--data_path=../dataset                     --cache_dir=../cache                     --model_name=subj01_40sess_hypatia_turbo_ridge --subj=1                     --mode imagery --no-blurry_recon


In [3]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="will load ckpt for model found in ../train_logs/model_name",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd(),
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8,9,10,11],
    help="Validate on which subject?",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=2048,
)
parser.add_argument(
    "--seq_len",type=int,default=1,
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--mode",type=str,default="vision",
)
parser.add_argument(
    "--gen_rep",type=int,default=10,
)
parser.add_argument(
    "--dual_guidance",action=argparse.BooleanOptionalAction,default=False,
)
parser.add_argument(
    "--snr",type=float,default=-1,
)
parser.add_argument(
    "--alpha",type=float,default=60000,
)
if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)


if seed > 0 and gen_rep == 1:
    # seed all random functions, but only if doing 1 rep
    utils.seed_everything(seed)
outdir = os.path.abspath(f'../train_logs/{model_name}')

# make output directory
os.makedirs("evals",exist_ok=True)
os.makedirs(f"evals/{model_name}",exist_ok=True)

In [4]:
if mode == "synthetic":
    voxels, all_images = utils.load_nsd_synthetic(subject=subj, average=False, nest=True)
elif subj > 8:
    _, _, voxels, all_images = utils.load_imageryrf(subject=subj-8, mode=mode, stimtype="object", average=False, nest=True, split=True)
else:
    voxels, all_images = utils.load_nsd_mental_imagery(subject=subj, mode=mode, stimtype="all", snr=snr, average=True, nest=False)
num_voxels = voxels.shape[-1]

torch.Size([18, 1, 15724]) torch.Size([18, 3, 425, 425])


In [5]:
_, _, x_test, test_nsd_ids = utils.load_nsd(subject=subj, data_path=data_path)

# Load 73k NSD images
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'] # if you go OOM you can remove the [:] so it isnt preloaded to cpu! (will require a few edits elsewhere tho)
# images = torch.Tensor(images).to("cpu").to(data_type)
print("Loaded all 73k possible NSD images to cpu!", images.shape)

images_test = torch.zeros((len(test_nsd_ids), 3, 224, 224))
for i, idx in enumerate(test_nsd_ids):
    images_test[i] =  torch.from_numpy(images[idx])
print(f"Loaded test images for subj{subj}!", images_test.shape)

Loaded all 73k possible NSD images to cpu! (73000, 3, 224, 224)
Loaded test images for subj1! torch.Size([1000, 3, 224, 224])


In [20]:
print('Creating Clipper...')
clip_emb_dim = 1024
clip_seq_dim = 1
clip_txt_seq_dim=77
clip_extractor = ClipperLarge(device="cuda")

# # Generate CLIP Image embeddings
# clip_image_test = torch.zeros((len(all_images), clip_emb_dim)).to("cpu")
# for i in tqdm(range(0, len(all_images), 50)):
#     img_pils = [transforms.ToPILImage()(images_test) for images_test in all_images[i:i+50]]
#     clip_image_test[i:i+50] = clip_extractor.embed_image(img_pils).detach().to("cpu")
# torch.save(clip_image_test, f"{data_path}/preprocessed_data/subject{subj}/ViT-H-14_image_embeddings_nsdimagery.pt")

generator = Generator4Embeds(num_inference_steps=4, device=device, cache_dir=cache_dir)

Creating Clipper...


100%|██████████| 1/1 [00:00<00:00,  3.43it/s]


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [19]:
clip_image_test = torch.load(f"{data_path}/preprocessed_data/subject{subj}/ViT-H-14_image_embeddings_test.pt")
clip_image_imagery = torch.load(f"{data_path}/preprocessed_data/subject{subj}/ViT-H-14_image_embeddings_nsdimagery.pt")

# Specify the directory where the weights are saved
weights_dir = f'{outdir}/{alpha}_alpha_weights.pkl'
print(f"Loading Ridge with alpha={alpha}, dir: {weights_dir}")

# Check if the weights file exists
assert os.path.exists(weights_dir)
# Load the regression weights
with open(weights_dir, 'rb') as f:
    weights = pickle.load(f)

# Create a new instance of the Ridge model
model = Ridge(
    alpha=alpha,
    max_iter=50000,
    random_state=42,
)

# Set the regression weights in the model
model.coef_ = weights["coef"]
model.intercept_ = weights["intercept"]

shared1000_pred_clip = model.predict(x_test)

import torch.nn.functional as F

cosine_sim = F.cosine_similarity(torch.from_numpy(shared1000_pred_clip), clip_image_test, dim=0)
print(f"Mean Shared1000 Cosine Similarity: {cosine_sim.mean().item()}")

nsd_imagery_pred_clip = model.predict(voxels[:,0])

cosine_sim = F.cosine_similarity(torch.from_numpy(nsd_imagery_pred_clip), clip_image_imagery, dim=0)
print(f"Mean NSDImagery: {mode} Cosine Similarity: {cosine_sim.mean().item()}")

# torch.save(shared1000_pred_clip,f"evals/{model_name}/{model_name}_all_clipvoxels.pt")
# torch.save(nsd_imagery_pred_clip,f"evals/{model_name}/{model_name}_all_clipvoxels_{mode}.pt")

# save_path = f"/export/raid1/home/kneel027/Second-Sight/output/mental_imagery_paper_b3/{mode}/mindeye_imagery_turbo_ridge_alpha_{alpha}/subject{subj}/"
# final_recons = None

# for rep in range(gen_rep):
#     utils.seed_everything(seed = random.randint(0,10000000))
#     all_recons = None
#     for sample in tqdm(range(len(nsd_imagery_pred_clip)), desc="nsdimagery recons"):
#         os.makedirs(f"{save_path}/{sample}/",exist_ok=True)
#         pred_clip = torch.from_numpy(nsd_imagery_pred_clip[sample]).unsqueeze(0)
#         gen_img = generator.generate(pred_clip)
#         gen_img.save(f"{save_path}/{sample}/{rep}.png")
#         transforms.ToPILImage()(all_images[sample]).save(f"{save_path}/{sample}/ground_truth.png")
        
#         samples = transforms.ToTensor()(gen_img).unsqueeze(0)
#         if all_recons is None:
#             all_recons = samples.cpu()
#         else:
#             all_recons = torch.vstack((all_recons, samples.cpu()))
#     if final_recons is None:
#         final_recons = all_recons.unsqueeze(1)
#     else:
#         final_recons = torch.cat((final_recons, all_recons.unsqueeze(1)), dim=1)
# torch.save(final_recons,f"evals/{model_name}/{model_name}_all_recons_{mode}.pt")
    
    
    # final_recons = None
    # save_path = f"/export/raid1/home/kneel027/Second-Sight/output/mindeye_imagery_turbo_ridge_alpha_{alpha}/subject{subj}/"
    # for rep in range(10):
    #     utils.seed_everything(seed = random.randint(0,10000000))
    #     all_recons = None
    #     for sample in tqdm(range(len(shared1000_pred_clip)), desc="shared1000 recons"):
    #         os.makedirs(f"{save_path}/{sample}/",exist_ok=True)
    #         pred_clip = torch.from_numpy(shared1000_pred_clip[sample]).unsqueeze(0)
    #         gen_img = generator.generate(pred_clip)
    #         gen_img.save(f"{save_path}/{sample}/{rep}.png")
    #         transforms.ToPILImage()(images_test[sample]).save(f"{save_path}/{sample}/ground_truth.png")
            
    #         samples = transforms.ToTensor()(gen_img).unsqueeze(0)
    #         if all_recons is None:
    #             all_recons = samples.cpu()
    #         else:
    #             all_recons = torch.vstack((all_recons, samples.cpu()))
    #     if final_recons is None:
    #         final_recons = all_recons.unsqueeze(1)
    #     else:
    #         final_recons = torch.cat((final_recons, all_recons.unsqueeze(1)), dim=1)
    # torch.save(final_recons,f"evals/{model_name}_alpha_{alpha}/{model_name}_alpha_{alpha}_all_recons.pt")



Loading Ridge with alpha=60000
Mean Shared1000 Cosine Similarity: 0.5007764716582568
Mean NSDImagery: imagery Cosine Similarity: 0.12528371837698826


nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  1.99it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.88it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.88it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.85it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.80it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.91it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:08<00:00,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.89it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  1.99it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.88it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.79it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.78it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.86it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.91it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:09<00:00,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.89it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  1.98it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.98it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.98it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.85it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.78it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.78it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.82it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:09<00:00,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.89it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.86it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.80it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:07<00:02,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:08<00:01,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:09<00:00,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.89it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:09,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.88it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.99it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.88it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.82it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.91it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:08<00:00,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.92it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  2.03it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.98it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.80it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:09<00:00,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.89it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  2.02it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.98it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.98it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.86it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.86it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.85it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.91it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:08<00:00,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.91it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  2.00it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.91it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:08<00:00,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.92it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  1.99it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:00<00:07,  2.02it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  2.01it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.92it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.87it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.82it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.85it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.89it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:08<00:00,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.92it/s]
nsdimagery recons:   0%|          | 0/18 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:   6%|▌         | 1/18 [00:00<00:08,  2.03it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  11%|█         | 2/18 [00:01<00:08,  1.96it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  17%|█▋        | 3/18 [00:01<00:07,  1.97it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  22%|██▏       | 4/18 [00:02<00:07,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  28%|██▊       | 5/18 [00:02<00:06,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  33%|███▎      | 6/18 [00:03<00:06,  1.95it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  39%|███▉      | 7/18 [00:03<00:05,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  44%|████▍     | 8/18 [00:04<00:05,  1.86it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  50%|█████     | 9/18 [00:04<00:04,  1.84it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  56%|█████▌    | 10/18 [00:05<00:04,  1.85it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  61%|██████    | 11/18 [00:05<00:03,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  67%|██████▋   | 12/18 [00:06<00:03,  1.81it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  72%|███████▏  | 13/18 [00:06<00:02,  1.83it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  78%|███████▊  | 14/18 [00:07<00:02,  1.90it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  83%|████████▎ | 15/18 [00:07<00:01,  1.93it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  89%|████████▉ | 16/18 [00:08<00:01,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons:  94%|█████████▍| 17/18 [00:08<00:00,  1.94it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

nsdimagery recons: 100%|██████████| 18/18 [00:09<00:00,  1.91it/s]


RuntimeError: Parent directory evals/subj01_40sess_hypatia_turbo_ridge_alpha_60000 does not exist.

In [9]:
# imsize = 150
# if all_images.shape[-1] != imsize:
#     all_images = transforms.Resize((imsize,imsize))(transforms.CenterCrop(all_images.shape[2])(all_images)).float()
# if all_recons.shape[-1] != imsize:
#     all_recons = transforms.Resize((imsize,imsize))(transforms.CenterCrop(all_images.shape[2])(all_recons)).float()
# print(all_images.shape, all_recons.shape)
# num_images = all_recons.shape[0]
# num_rows = (2 * num_images + 11) // 12

# # Interleave tensors
# merged = torch.stack([val for pair in zip(all_images, all_recons) for val in pair], dim=0)

# # Calculate grid size
# grid = torch.zeros((num_rows * 12, 3, all_recons.shape[-1], all_recons.shape[-1]))

# # Populate the grid
# grid[:2*num_images] = merged
# grid_images = [transforms.functional.to_pil_image(grid[i]) for i in range(num_rows * 12)]

# # Create the grid image
# grid_image = Image.new('RGB', (all_recons.shape[-1] * 12, all_recons.shape[-1] * num_rows))  # 12 images wide

# # Paste images into the grid
# for i, img in enumerate(grid_images):
#     grid_image.paste(img, (all_recons.shape[-1] * (i % 12), all_recons.shape[-1] * (i // 12)))

# # Create title row image
# title_height = 150
# title_image = Image.new('RGB', (grid_image.width, title_height), color=(255, 255, 255))
# draw = ImageDraw.Draw(title_image)
# font = ImageFont.truetype("arial.ttf", 38)  # Change font size to 3 times bigger (15*3)
# title_text = f"Model: {model_name}, Mode: {mode}"
# bbox = draw.textbbox((0, 0), title_text, font=font)
# text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
# draw.text(((grid_image.width - text_width) / 2, (title_height - text_height) / 2), title_text, fill="black", font=font)

# # Combine title and grid images
# final_image = Image.new('RGB', (grid_image.width, grid_image.height + title_height))
# final_image.paste(title_image, (0, 0))
# final_image.paste(grid_image, (0, title_height))

# final_image.save(f"../figs/{model_name}_{len(all_recons)}recons_{mode}.png")
# print(f"saved ../figs/{model_name}_{len(all_recons)}recons_{mode}.png")

In [None]:

if not utils.is_interactive():
    sys.exit(0)