In [1]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
import sgm
from pkgs.MindEyeV2.src.generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2
from pkgs.MindEyeV2.src.generative_models.sgm.models.diffusion import DiffusionEngine
from pkgs.MindEyeV2.src.generative_models.sgm.util import append_dims
from omegaconf import OmegaConf

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import pkgs.MindEyeV2.src.utils as utils
from pkgs.MindEyeV2.src.models import *

import lovely_tensors as lt
lt.monkey_patch()

from csng.data import get_dataloaders
from csng.utils.mix import seed_all
from csng.utils.data import crop

DATA_PATH_BRAINREADER = os.path.join(os.environ["DATA_PATH"], "brainreader")
DATA_PATH_MINDEYE = os.path.join(os.environ["DATA_PATH"], "mindeye")
DATA_PATH_MINDEYE_CACHE = os.path.join(DATA_PATH_MINDEYE, "cache")
print(f"{DATA_PATH_BRAINREADER=}\n{DATA_PATH_MINDEYE=}\n{DATA_PATH_MINDEYE_CACHE=}")

# accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
# device = accelerator.device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:",device)

!nvidia-smi



DATA_PATH_BRAINREADER='/scratch/izar/sobotka/csng/brainreader'
DATA_PATH_MINDEYE='/scratch/izar/sobotka/csng/mindeye'
DATA_PATH_MINDEYE_CACHE='/scratch/izar/sobotka/csng/mindeye/cache'
device: cuda
Sun May  4 16:42:50 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-PCIE-32GB           On  | 00000000:86:00.0 Off |                  Off |
| N/A   36C    P0              26W / 250W |      0MiB / 32768MiB |      0%      Default |
|                                         |       

## Configuration

In [4]:
# Load pretrained model ckpt
model_name = "csng_18-02-25_19-45"
tag = "best"
print(f"\n---loading {model_name}:{tag} ckpt---\n")
checkpoint = torch.load(f"{DATA_PATH_MINDEYE}/train_logs/{model_name}/{tag}.pth", map_location='cpu')
state_dict = checkpoint['model_state_dict']
cfg = checkpoint['cfg']
evals_dir = cfg["model"]["evalsdir"]
outdir = cfg["model"]["outdir"]
assert os.path.exists(outdir)
os.makedirs(evals_dir, exist_ok=True)
cfg


---loading csng_18-02-25_19-45:best ckpt---



{'device': 'cuda',
 'data_type': torch.float16,
 'seed': 0,
 'data': {'mixing_strategy': 'parallel_min',
  'max_training_batches': None,
  'brainreader_mouse': {'device': 'cuda',
   'mixing_strategy': 'parallel_min',
   'max_batches': None,
   'data_dir': '/scratch/izar/sobotka/csng/brainreader/data',
   'batch_size': 12,
   'sessions': [6],
   'resize_stim_to': (36, 64),
   'normalize_stim': True,
   'normalize_resp': True,
   'div_resp_by_std': True,
   'clamp_neg_resp': False,
   'additional_keys': None,
   'avg_test_resp': True,
   'drop_last': True}},
 'wandb': {'project': 'MindEye', 'group': 'mindeye'},
 'model': {'model_name': 'csng_18-02-25_19-45',
  'cache_dir': '/scratch/izar/sobotka/csng/mindeye/cache',
  'data_path': '/scratch/izar/sobotka/csng/brainreader',
  'outdir': '/scratch/izar/sobotka/csng/mindeye/train_logs/csng_18-02-25_19-45',
  'evalsdir': '/scratch/izar/sobotka/csng/mindeye/evals/csng_18-02-25_19-45',
  'ckpt_saving': True,
  'ckpt_interval': 1,
  'subj_list': 

## Model

In [5]:
clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    arch=cfg["model"]["clip_img_embedder_arch"],
    version=cfg["model"]["clip_img_embedder_version"],
    output_tokens=True,
    only_tokens=True,
    cache_dir=cfg["model"]["cache_dir"],
)
clip_img_embedder.to(cfg["device"])

open_clip_pytorch_model.bin:   0%|          | 0.00/10.2G [00:00<?, ?B/s]

FrozenOpenCLIPImageEmbedder(
  (model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 1664, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (patch_dropout): Identity()
      (ln_pre): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-47): 48 x ResidualAttentionBlock(
            (ln_1): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1664, out_features=1664, bias=True)
            )
            (ls_1): Identity()
            (ln_2): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=1664, out_features=8192, bias=True)
              (gelu): GELU(approximate='none')
              (c_proj): Linear(in_features=8192, out_features=1664, bias=True)
            )
            (ls_2): Identity()


In [8]:
if cfg["model"]["blurry_recon"]:
    from diffusers import AutoencoderKL
    autoenc = AutoencoderKL(**cfg["model"]["autoenc"])
    ckpt = torch.load(f'{cfg["model"]["cache_dir"]}/sd_image_var_autoenc.pth')
    autoenc.load_state_dict(ckpt)
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(cfg["device"])
    utils.count_params(autoenc)

param counts:
83,653,863 total
0 trainable


In [9]:
class MindEyeModule(nn.Module):
    def __init__(self):
        super(MindEyeModule, self).__init__()
    
    def forward(self, x):
        return x
        
class RidgeRegression(torch.nn.Module):
    # make sure to add weight_decay when initializing optimizer to enable regularization
    def __init__(self, input_sizes, out_features): 
        super(RidgeRegression, self).__init__()
        self.out_features = out_features
        self.linears = torch.nn.ModuleList([
                torch.nn.Linear(input_size, out_features) for input_size in input_sizes
            ])
    
    def forward(self, x, subj_idx):
        out = self.linears[subj_idx](x[:,0]).unsqueeze(1)
        return out

model = MindEyeModule()
model.ridge = RidgeRegression(cfg["model"]["num_voxels_list"], out_features=cfg["model"]["hidden_dim"])

In [10]:
from pkgs.MindEyeV2.src.models import BrainNetwork

model.backbone = BrainNetwork(**cfg["model"]["brainnetwork"]) 
utils.count_params(model.ridge)
utils.count_params(model.backbone)
utils.count_params(model)

# setup diffusion prior network
prior_network = PriorNetwork(**cfg["model"]["prior_network"])

model.diffusion_prior = BrainDiffusionPrior(net=prior_network, **cfg["model"]["brain_diffusion_prior"])
model.to(cfg["device"])

utils.count_params(model.diffusion_prior)
utils.count_params(model)

param counts:
6,595,584 total
6,595,584 trainable
param counts:
345,356,284 total
345,356,284 trainable
param counts:
351,951,868 total
351,951,868 trainable
param counts:
259,865,216 total
259,865,200 trainable
param counts:
611,817,084 total
611,817,068 trainable


611817068

In [12]:
# Load pretrained model ckpt
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
del checkpoint

In [14]:
# setup text caption networks
from transformers import AutoProcessor, AutoModelForCausalLM
from pkgs.MindEyeV2.src.modeling_git import GitForCausalLMClipEmb

processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
clip_text_model = GitForCausalLMClipEmb.from_pretrained("microsoft/git-large-coco")
# clip_text_model.to(device) # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4
clip_text_model.to("cpu") # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4
clip_text_model.eval().requires_grad_(False)
cfg["model"]["clip_text_seq_dim"] = 257
cfg["model"]["clip_text_emb_dim"] = 1024

class CLIPConverter(torch.nn.Module):
    def __init__(self):
        super(CLIPConverter, self).__init__()
        self.linear1 = nn.Linear(cfg["model"]["clip_seq_dim"], cfg["model"]["clip_text_seq_dim"])
        self.linear2 = nn.Linear(cfg["model"]["clip_emb_dim"], cfg["model"]["clip_text_emb_dim"])

    def forward(self, x):
        x = x.permute(0,2,1)
        x = self.linear1(x)
        x = self.linear2(x.permute(0,2,1))
        return x
        
clip_convert = CLIPConverter()
state_dict = torch.load(f"{cfg['model']['cache_dir']}/bigG_to_L_epoch8.pth", map_location='cpu')['model_state_dict']
clip_convert.load_state_dict(state_dict, strict=True)
# clip_convert.to(device) # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4
clip_convert.to("cpu") # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4
del state_dict

In [23]:
# prep unCLIP
from omegaconf import OmegaConf
from copy import deepcopy

config = OmegaConf.load("src/generative_models/configs/unclip6.yaml")
config = OmegaConf.to_container(config, resolve=True)
cfg["model"]["unclip"] = deepcopy(config)
unclip_params = config["model"]["params"]
network_config = unclip_params["network_config"]
denoiser_config = unclip_params["denoiser_config"]
first_stage_config = unclip_params["first_stage_config"]
conditioner_config = unclip_params["conditioner_config"]
sampler_config = unclip_params["sampler_config"]
scale_factor = unclip_params["scale_factor"]
disable_first_stage_autocast = unclip_params["disable_first_stage_autocast"]
offset_noise_level = unclip_params["loss_fn_config"]["params"]["offset_noise_level"]

first_stage_config['target'] = 'sgm.models.autoencoder.AutoencoderKL'
sampler_config['params']['num_steps'] = 38

diffusion_engine = DiffusionEngine(
    network_config=network_config,
    denoiser_config=denoiser_config,
    first_stage_config=first_stage_config,
    conditioner_config=conditioner_config,
    sampler_config=sampler_config,
    scale_factor=scale_factor,
    disable_first_stage_autocast=disable_first_stage_autocast
)
# set to inference
diffusion_engine.eval().requires_grad_(False)
diffusion_engine.to(cfg["device"])

ckpt_path = f'{cfg["model"]["cache_dir"]}/unclip6_epoch0_step110000.ckpt'
ckpt = torch.load(ckpt_path, map_location='cpu')
diffusion_engine.load_state_dict(ckpt['state_dict'])

batch={"jpg": torch.randn(1,3,1,1).to(cfg["device"]), # jpg doesnt get used, it's just a placeholder
      "original_size_as_tuple": torch.ones(1, 2).to(cfg["device"]) * 768,
      "crop_coords_top_left": torch.zeros(1, 2).to(cfg["device"])}
out = diffusion_engine.conditioner(batch)
vector_suffix = out["vector"].to(cfg["device"])
print("vector_suffix", vector_suffix.shape)



Initialized embedder #0: FrozenOpenCLIPImageEmbedder with 1909889025 params. Trainable: False
Initialized embedder #1: ConcatTimestepEmbedderND with 0 params. Trainable: False
Initialized embedder #2: ConcatTimestepEmbedderND with 0 params. Trainable: False
vector_suffix torch.Size([1, 1024])


## Data

In [28]:
cfg["crop_wins"]

KeyError: 'crop_wins'

In [None]:
### select a subject to test on
save_to = os.path.join(evals_dir, f"patterns_{subj_name}")
os.makedirs(save_to, exist_ok=True)
subj_name = "subj06"
cfg["data_name"] = "brainreader_mouse"
cfg["crop_wins"] = {
    "mouse_v1": (22, 36),
    "cat_v1": (20, 20),
    "brainreader_mouse": None,
}
data_tier = "test"
subj_list_idx = list(cfg["model"]["num_voxels"].keys()).index(subj_name)

patterns_file = torch.load(f"src/stim_resp_pairs_{subj_name}.pt")
images, voxels = patterns_file["stim"], {subj_name: patterns_file["resp"]}

# voxels = {subj_name: []}
# images = []
# for b_i, batch in enumerate(test_dl):
#     images.append(batch.images.cpu())
#     voxels[subj_name].append(batch.responses.cpu())
# images = torch.cat(images, dim=0)
# voxels = {k: torch.cat(v, dim=0) for k,v in voxels.items()}
print(f"{subj_name=}\n{images=}\n{voxels[subj_name]=}")

subj_name='subj06'
images=tensor[28, 1, 36, 64] n=64512 (0.2Mb) x∈[-1.575, 4.105] μ=-4.943e-08 σ=1.000 cuda:0
voxels[subj_name]=tensor[28, 8587] n=240436 (0.9Mb) x∈[0.001, 6.323] μ=0.403 σ=0.406


In [25]:
img_tform = transforms.Compose([
    transforms.Lambda(lambda x: crop(x, cfg["crop_wins"][cfg["data_name"]])),
    # transforms.Resize((224, 224), antialias=True),
    # transforms.Lambda(lambda x: x.repeat(1, 3, 1, 1)),
])

In [26]:
# get all reconstructions
model.to(cfg["device"])
model.eval().requires_grad_(False)

all_blurryrecons = None
all_recons = None
all_predcaptions = []
all_clipvoxels = None

minibatch_size = 1
num_samples_per_image = 1

plotting = False

seed_all(cfg["seed"])
with torch.no_grad():
    for start_idx in tqdm(range(0,len(images),minibatch_size)):
        voxel = voxels[subj_name][start_idx:start_idx + minibatch_size].unsqueeze(1).to(cfg["device"])

        # voxel_ridge = model.ridge(voxel, 0) # 0th index of subj_list
        voxel_ridge = model.ridge(voxel, subj_list_idx)
        torch.cuda.empty_cache()
        backbone, clip_voxels, blurry_image_enc = model.backbone(voxel_ridge)
        blurry_image_enc = blurry_image_enc[0]
                
        # Save retrieval submodule outputs
        if all_clipvoxels is None:
            all_clipvoxels = clip_voxels.cpu()
        else:
            all_clipvoxels = torch.vstack((all_clipvoxels, clip_voxels.cpu()))
        
        # Feed voxels through OpenCLIP-bigG diffusion prior
        prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, 
                        text_cond = dict(text_embed = backbone), 
                        cond_scale = 1., timesteps = 20)
        
        pred_caption_emb = clip_convert(prior_out.to(clip_convert.linear1.weight.device, clip_convert.linear1.weight.dtype))
        generated_ids = clip_text_model.generate(pixel_values=pred_caption_emb, max_length=20)
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
        all_predcaptions = np.hstack((all_predcaptions, generated_caption))
        print(generated_caption)
        
        ### feed diffusion prior outputs through unCLIP
        for i in range(len(voxel)):
            assert images.amin() < 0 or images.amax() > 1
            samples = utils.unclip_recon(
                prior_out[[i]],
                diffusion_engine,
                vector_suffix,
                num_samples=num_samples_per_image,
                clamp=False, # to [0, 1]
            )
            if all_recons is None:
                all_recons = samples.cpu()
            else:
                all_recons = torch.vstack((all_recons, samples.cpu()))
            if plotting:
                for s in range(num_samples_per_image):
                    plt.figure(figsize=(2,2))
                    # plt.imshow(transforms.ToPILImage()(samples[s]))
                    plt.imshow(samples[s].cpu().permute(1,2,0).to(torch.float32))
                    plt.axis('off')
                    plt.show()

        if cfg["model"]["blurry_recon"]:
            # blurred_image = (autoenc.decode(blurry_image_enc/0.18215).sample/ 2 + 0.5).clamp(0,1)
            blurred_image = autoenc.decode(blurry_image_enc/0.18215).sample # already z-scored
            
            for i in range(len(voxel)):
                im = torch.Tensor(blurred_image[i])
                if all_blurryrecons is None:
                    all_blurryrecons = im[None].cpu()
                else:
                    all_blurryrecons = torch.vstack((all_blurryrecons, im[None].cpu()))
                if plotting:
                    plt.figure(figsize=(2,2))
                    # plt.imshow(transforms.ToPILImage()(im))
                    plt.imshow(im.cpu().permute(1,2,0).to(torch.float32))
                    plt.axis('off')
                    plt.show()

# resize outputs before saving
imsize = 256
all_recons = transforms.Resize((imsize,imsize))(all_recons).float()
if cfg["model"]["blurry_recon"]: 
    all_blurryrecons = transforms.Resize((imsize,imsize))(all_blurryrecons).float()

# saving
print(all_recons.shape)
torch.save(img_tform(images), f"{save_to}/{subj_name}_{data_tier}_all_images.pt")
torch.save(images, f"{save_to}/{subj_name}_{data_tier}_all_images_before_transform.pt")
torch.save(voxels,f"{save_to}/{subj_name}_{data_tier}_all_voxels.pt") 
if cfg["model"]["blurry_recon"]:
    torch.save(all_blurryrecons, f"{save_to}/{subj_name}_{data_tier}_all_blurryrecons.pt")
torch.save(all_recons, f"{save_to}/{subj_name}_{data_tier}_all_recons.pt")
torch.save(all_predcaptions, f"{save_to}/{subj_name}_{data_tier}_all_predcaptions.pt")
torch.save(all_clipvoxels, f"{save_to}/{subj_name}_{data_tier}_all_clipvoxels.pt")
torch.save(cfg, f"{save_to}/cfg.pt")
print(f"saved {cfg['model']['model_name']} outputs!")

  0%|                                     | 0/28 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a black and white photo of a person.']


  4%|█                            | 1/28 [00:16<07:34, 16.83s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a large white and black object.']


  7%|██                           | 2/28 [00:28<05:57, 13.75s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a black and white photo of a person.']


 11%|███                          | 3/28 [00:40<05:24, 12.97s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a large black and white object.']


 14%|████▏                        | 4/28 [00:52<04:57, 12.41s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a view of a person.']


 18%|█████▏                       | 5/28 [01:03<04:36, 12.02s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a person is standing up.']


 21%|██████▏                      | 6/28 [01:14<04:19, 11.79s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a car is parked in front of a building.']


 25%|███████▎                     | 7/28 [01:26<04:11, 11.95s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a dog is standing in front of a camera.']


 29%|████████▎                    | 8/28 [01:39<04:01, 12.07s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a man is standing in front of a building.']


 32%|█████████▎                   | 9/28 [01:51<03:50, 12.16s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a man is holding a piece of paper.']


 36%|██████████                  | 10/28 [02:03<03:38, 12.12s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a car is parked in front of a building.']


 39%|███████████                 | 11/28 [02:15<03:27, 12.18s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a group of people.']


 43%|████████████                | 12/28 [02:27<03:09, 11.85s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a man is standing in front of a wall.']


 46%|█████████████               | 13/28 [02:39<02:59, 11.99s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a person is standing up.']


 50%|██████████████              | 14/28 [02:50<02:45, 11.81s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a large white and black truck.']


 54%|███████████████             | 15/28 [03:02<02:32, 11.75s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a person is standing up.']


 57%|████████████████            | 16/28 [03:13<02:19, 11.63s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a man is standing in front of a wall.']


 61%|█████████████████           | 17/28 [03:26<02:10, 11.84s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a woman holding a cell phone.']


 64%|██████████████████          | 18/28 [03:37<01:57, 11.77s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a large black and white photo of a person.']


 68%|███████████████████         | 19/28 [03:50<01:47, 11.94s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a woman holding a cell phone.']


 71%|████████████████████        | 20/28 [04:01<01:34, 11.85s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

["a close up of a person's face"]


 75%|█████████████████████       | 21/28 [04:13<01:23, 11.92s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a large group of people.']


 79%|██████████████████████      | 22/28 [04:25<01:10, 11.76s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a car is parked in front of a truck.']


 82%|███████████████████████     | 23/28 [04:37<00:59, 11.93s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a group of people standing around.']


 86%|████████████████████████    | 24/28 [04:49<00:47, 11.85s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a group of people.']


 89%|█████████████████████████   | 25/28 [05:00<00:34, 11.63s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a large group of people.']


 93%|██████████████████████████  | 26/28 [05:11<00:23, 11.56s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a car is parked in front of a building.']


 96%|███████████████████████████ | 27/28 [05:23<00:11, 11.79s/it]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

['a group of people standing around each other.']


100%|████████████████████████████| 28/28 [05:36<00:00, 12.00s/it]


torch.Size([28, 3, 256, 256])


KeyError: 'crop_wins'