In [1]:
import torch
import models
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import utils
import random
import string
import h5py
from tqdm import tqdm
from IPython.display import display, Image

import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms


  from .autonotebook import tqdm as notebook_tqdm


In [17]:
global_batch_size = batch_size = 128
data_type = torch.float16 # change depending on your mixed_precision

if utils.is_interactive():
    # create random model_name
    model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
    model_name = model_name + "_interactive"
    print("model_name:", model_name)

    # global_batch_size and batch_size should already be defined in the above cells
    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
                    --model_name={model_name} \
                    --subj=1 --batch_size={batch_size} --blurry_recon --no-depth_recon \
                    --clip_scale=1. --blur_scale=100. --depth_scale=100. \
                    --max_lr=3e-4 --mixup_pct=.66 --num_epochs=24 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"

    jupyter_args = jupyter_args.split()#other variables can be specified in the following string:
    
    print(jupyter_args)
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: BvJzD580AL_interactive
['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=BvJzD580AL_interactive', '--subj=1', '--batch_size=128', '--blurry_recon', '--no-depth_recon', '--clip_scale=1.', '--blur_scale=100.', '--depth_scale=100.', '--max_lr=3e-4', '--mixup_pct=.66', '--num_epochs=24', '--ckpt_interval=999', '--no-use_image_aug', '--no-ckpt_saving']


In [18]:
# Command-line interface for the run. Every parsed option is also exported as a
# module-level global at the bottom of this cell.
parser = argparse.ArgumentParser(description="Model Training Configuration")

# --- run identity, data location, subject ---
parser.add_argument(
    "--model_name",
    type=str,
    default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path",
    type=str,
    default="/fsx/proj-fmri/shared/natural-scenes-dataset",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--subj",
    type=int,
    default=1,
    choices=[1,2,5,7],
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=32,
    help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
)

# --- logging / resuming ---
parser.add_argument(
    "--wandb_log",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",
    type=str,
    default="stability",
    help="wandb project name",
)

# --- objectives and loss weights ---
parser.add_argument(
    "--mixup_pct",
    type=float,
    default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--blurry_recon",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--depth_recon",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="whether to output depth reconstructions",
)
parser.add_argument(
    "--blur_scale",
    type=float,
    default=100.,
    help="multiply loss from blurry recons by this number",
)
parser.add_argument(
    "--depth_scale",
    type=float,
    default=100.,
    help="multiply loss from depth recons by this number",
)
parser.add_argument(
    "--clip_scale",
    type=float,
    default=1.,
    help="multiply contrastive loss by this number",
)
parser.add_argument(
    "--use_image_aug",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="whether to use image augmentation",
)

# --- schedule, checkpointing, reproducibility ---
parser.add_argument(
    "--num_epochs",
    type=int,
    default=120,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type",
    type=str,
    default='cycle',
    choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",
    action=argparse.BooleanOptionalAction,
    default=True,
)
parser.add_argument(
    "--ckpt_interval",
    type=int,
    default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",
    type=int,
    default=42,
)
parser.add_argument(
    "--max_lr",
    type=float,
    default=3e-4,
)

# Interactive sessions parse the prepared jupyter_args list; otherwise argparse
# falls back to the process command line.
args = parser.parse_args(jupyter_args) if utils.is_interactive() else parser.parse_args()

# create global variables without the args prefix
for attribute_name, attribute_value in vars(args).items():
    globals()[attribute_name] = attribute_value

In [19]:
# Four separate DV2MLP instances, each constructed with the same arguments.
m1, m2, m3, m4 = (models.DV2MLP(use_cont=True) for _ in range(4))

In [21]:
# Load the four saved checkpoints and copy their weights into m1..m4.
# map_location="cpu" keeps the deserialised tensors on the CPU; without it
# torch.load restores every tensor to the device it was saved from, which can
# put four full checkpoints on the GPU only to copy them into these modules
# (and fails outright on a machine without that device).
ckp1 = torch.load("/fsx/proj-fmri/mihirneal/MindEyeV2/train_logs/models/dp_1/last.pth", map_location="cpu")
ckp2 = torch.load("/fsx/proj-fmri/mihirneal/MindEyeV2/train_logs/models/dp_2/last.pth", map_location="cpu")
ckp3 = torch.load("/fsx/proj-fmri/mihirneal/MindEyeV2/train_logs/models/dp_3/last.pth", map_location="cpu")
ckp4 = torch.load("/fsx/proj-fmri/mihirneal/MindEyeV2/train_logs/models/dp_4/last.pth", map_location="cpu")
# Each checkpoint stores the weights under "model_state_dict"; load_state_dict
# is strict by default, so any missing/unexpected key raises here.
m1.load_state_dict(ckp1["model_state_dict"])
m2.load_state_dict(ckp2["model_state_dict"])
m3.load_state_dict(ckp3["model_state_dict"])
m4.load_state_dict(ckp4["model_state_dict"])

<All keys matched successfully>

In [24]:
# Training shards for the selected subject ("{0..36}.tar" is a brace pattern
# covering shards 0 through 36).
train_url = f"{data_path}/wds/subj0{subj}/new_train/" + "{0..36}.tar"
print(train_url)

def my_split_by_node(urls):
    """Return the shard list unchanged; passed to WebDataset as its nodesplitter."""
    return urls

# Sample counts are hard-coded and only known for subject 1. Fail loudly for any
# other subject instead of hitting a NameError below (or silently reusing stale
# values left in the kernel by an earlier run).
if subj==1:
    num_train = 24958
    num_test = 2770
else:
    raise NotImplementedError(
        f"num_train/num_test are only defined for subj=1 in this cell (got subj={subj}); "
        "add the sample counts for this subject before building the dataloaders"
    )
# The whole test set is served as a single batch.
test_batch_size = num_test

# Only the behavioural arrays are read from the shards; each sample becomes a
# (behav, future_behav) tuple. The shuffle buffer uses a fixed seed (42).
train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
                    .shuffle(750, initial=1500, rng=random.Random(42))\
                    .decode("torch")\
                    .rename(behav="behav.npy", future_behav="future_behav.npy")\
                    .to_tuple(*["behav", "future_behav"])
# drop_last=True discards a trailing partial batch.
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)

test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)

# Same pipeline as the training set, reading the single test shard.
test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
                    .shuffle(750, initial=1500, rng=random.Random(42))\
                    .decode("torch")\
                    .rename(behav="behav.npy", future_behav="future_behav.npy")\
                    .to_tuple(*["behav", "future_behav"])
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)


# ### check dataloaders are working

# In[13]:


# Sweep each loader once and accumulate two columns of the first behav row per
# sample: column 5 (later used to index `voxels`) and column 0 (later used to
# index `images`).
test_vox_indices, test_73k_images = [], []
for test_i, (behav, future_behav) in enumerate(test_dl):
    vox_col = behav[:, 0, 5].cpu().numpy()
    img_col = behav[:, 0, 0].cpu().numpy()
    test_vox_indices = np.append(test_vox_indices, vox_col)
    test_73k_images = np.append(test_73k_images, img_col)
test_vox_indices = test_vox_indices.astype(np.int16)
# Last batch index, samples implied by the batch count, samples actually collected.
print(f"{test_i} {(test_i+1) * test_batch_size} {len(test_vox_indices)}")
print("---\n")

train_vox_indices, train_73k_images = [], []
for train_i, (behav, future_behav) in enumerate(train_dl):
    vox_col = behav[:, 0, 5].long().cpu().numpy()
    img_col = behav[:, 0, 0].cpu().numpy()
    train_vox_indices = np.append(train_vox_indices, vox_col)
    train_73k_images = np.append(train_73k_images, img_col)
train_vox_indices = train_vox_indices.astype(np.int16)
print(f"{train_i} {(train_i+1) * batch_size} {len(train_vox_indices)}")

# Train entries first, then test entries.
all_vox_indices = np.concatenate((train_vox_indices, test_vox_indices))
all_images = np.concatenate((train_73k_images, test_73k_images))

/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/new_train/{0..36}.tar
/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar
0 2770 2770
---

194 24960 24960


In [26]:
# load betas
# The dataset is read fully into memory ([:]) inside the `with` block, so the
# HDF5 handle can be closed immediately instead of being leaked when `f` is
# rebound below. After each block `f` refers to a closed file.
# (alternative source: f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5')
with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:
    voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
# Kept on the CPU and cast to the notebook-wide data_type.
voxels = torch.Tensor(voxels).to("cpu").to(data_type)
print("voxels", voxels.shape)
# The last axis of the betas array is the per-sample voxel dimension.
num_voxels = voxels.shape[-1]

# load orig images
with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:
    images = f['images'][:]
images = torch.Tensor(images).to("cpu").to(data_type)
print("images", images.shape)

subj01 betas loaded into memory
voxels torch.Size([27750, 15724])
images torch.Size([73000, 3, 224, 224])


In [27]:
# Build one (image, voxel) pair per unique test image by averaging the voxel
# rows of all repeats of that image.
test_image, test_voxel = None, None
with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
    for test_i, (behav, future_behav) in enumerate(test_dl):
        # all test samples should be loaded per batch such that test_i should never exceed 0
        assert len(behav) == num_test

        ## Average same-image repeats ##
        if test_image is None:
            # Column 5 of the first behav row indexes `voxels`; column 0 indexes `images`.
            voxel = voxels[behav[:,0,5].cpu().long()]
            image = behav[:,0,0].cpu().long()

            # NOTE: despite its name, sort_indices holds the inverse mapping returned by
            # torch.unique; it is not used below and is kept only so the name stays defined.
            unique_image, sort_indices = torch.unique(image, return_inverse=True)

            # Collect the per-image rows and stack once at the end. Growing the
            # tensors with torch.vstack inside the loop re-copies everything
            # accumulated so far on every iteration.
            image_rows, voxel_rows = [], []
            for im in unique_image:
                locs = torch.where(im == image)[0]
                image_rows.append(images[im])
                voxel_rows.append(torch.mean(voxel[locs],axis=0))
            test_image = torch.stack(image_rows)
            test_voxel = torch.stack(voxel_rows)

In [9]:
from functools import partial
from diffusers.models.vae import Decoder
class BrainNetwork(nn.Module):
    """Residual MLP-mixer backbone with optional "text" and "pool" output heads.

    The input is expected to already have shape (batch, seq_len, h): the input
    projection that would map ``in_dim`` features to that shape is disabled, so
    ``in_dim`` is accepted only for interface compatibility and is unused.

    Args:
        text: build and use the text head (``clin1``, ``clin2``, ``clip_proj``).
        pool: build the pooled head (``clin3``). Only used by ``forward`` when
            ``text`` is False, because the text branch returns first.
        out_dim: output width of ``clin1`` and hidden width of ``clin3``. For the
            text head it must be divisible by ``clip_size``.
        in_dim: unused (see above).
        seq_len: size of the second input axis.
        h: size of the last input axis.
        n_blocks: number of (block1, block2) mixer pairs.
        drop: dropout probability inside each mixer MLP.
        clip_size: last-axis size of the first text-head output.
    """

    def __init__(self, text, pool, out_dim=2048, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=2048):
        super().__init__()
        self.is_text = text
        self.is_pool = pool
        self.seq_len = seq_len
        self.h = h
        self.clip_size = clip_size

        # Mixer blocks. Submodules are created in the same order as before so
        # state-dict keys and seeded initialisation are unchanged.
        self.mixer_blocks1 = nn.ModuleList([
            self.mixer_block1(h, drop) for _ in range(n_blocks)
        ])
        self.mixer_blocks2 = nn.ModuleList([
            self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
        ])

        # Output heads, all fed with the flattened (h * seq_len) mixer output.
        if self.is_text:
            self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)
            # 257 rows of 768 features; reshaped and projected in forward().
            self.clin2 = nn.Linear(h * seq_len, 768*257, bias=True)
            self.clip_proj = nn.Sequential(
                nn.LayerNorm(768),
                nn.GELU(),
                nn.Linear(768, 2048),
                nn.LayerNorm(2048),
                nn.GELU(),
                nn.Linear(2048, 2048),
                nn.LayerNorm(2048),
                nn.GELU(),
                nn.Linear(2048, 768)
            )
        if self.is_pool:
            self.clin3 = nn.Sequential(
                nn.Linear(h * seq_len, out_dim),
                nn.LayerNorm(out_dim),
                nn.GELU(),
                nn.Linear(out_dim, 1280),
            )

    def mixer_block1(self, h, drop):
        """LayerNorm + MLP acting on the last axis of a (batch, seq_len, h) tensor."""
        return nn.Sequential(
            nn.LayerNorm(h),
            self.mlp(h, h, drop),  # mixes along the h (feature/channel) axis
        )

    def mixer_block2(self, seq_len, drop):
        """LayerNorm + MLP acting on the last axis of a (batch, h, seq_len) tensor."""
        return nn.Sequential(
            nn.LayerNorm(seq_len),
            self.mlp(seq_len, seq_len, drop)  # mixes along the seq_len (token) axis
        )

    def mlp(self, in_dim, out_dim, drop):
        """Two-layer MLP: Linear -> GELU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        """Run the mixer stack and the configured output head.

        Args:
            x: tensor of shape (batch, seq_len, h).

        Returns:
            If the text head is enabled: ``(c, i_proj)`` where ``c`` has shape
            (batch, out_dim // clip_size, clip_size) and ``i_proj`` has shape
            (batch, 257, 768).
            Otherwise, if the pool head is enabled: tensor of shape (batch, 1280).

        Raises:
            RuntimeError: if the network was built with neither head enabled
                (previously this case silently returned ``None``).
        """
        # Each block type keeps its own residual stream, stored in the layout
        # that block operates on.
        residual1 = x
        residual2 = x.permute(0,2,1)
        for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
            x = block1(x) + residual1
            residual1 = x
            x = x.permute(0,2,1)  # -> (batch, h, seq_len)

            x = block2(x) + residual2
            residual2 = x
            x = x.permute(0,2,1)  # -> (batch, seq_len, h)

        # Flatten to (batch, seq_len * h)
        x = x.reshape(x.size(0), -1)

        if self.is_text:
            c = self.clin1(x)
            i = self.clin2(x)
            i_proj = self.clip_proj(i.reshape(len(i), -1, 768))
            c = c.reshape(len(c), -1, self.clip_size)
            return c, i_proj

        if self.is_pool:
            c_p = self.clin3(x)
            return c_p

        raise RuntimeError(
            "BrainNetwork was built with text=False and pool=False, so it has no output head"
        )


        
# Text-head network moved to the GPU: one sequence position of width 4096,
# with out_dim = 77 * 2048.
b_text = BrainNetwork(
    text=True,
    pool=False,
    h=4096,
    in_dim=2048,
    seq_len=1,
    clip_size=2048,
    out_dim=2048*77,
).cuda()
# b_pool = BrainNetwork(text=False, pool=True, h=hidden_dim, in_dim=2048, seq_len=seq_len, clip_size=2048, out_dim=2048*77).cuda()
utils.count_params(b_text)
# utils.count_params(b_pool)

# Smoke test on random input: a batch of 2 samples shaped (seq_len=1, h=4096).
b = torch.randn(2, 1, 4096).cuda()
print(f"b.shape {b.shape}")
emb, img = b_text(b)
print(f"{emb.shape} {img.shape}")

param counts:
1,596,367,896 total
1,596,367,896 trainable
b.shape torch.Size([2, 1, 4096])
torch.Size([2, 77, 2048]) torch.Size([2, 257, 768])
