In [1]:
# from google.colab import drive
# import os

# drive.mount('/content/drive')
# %cd '/content/drive/My Drive/MindEye'
# os.chdir('/content/drive/My Drive/MindEye')

In [2]:
import os
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib
from matplotlib import pyplot as plt
from nilearn import datasets
from nilearn import plotting

import torch
import torch.nn as nn
import wandb
import pandas as pd
from tqdm import tqdm
from utils import load_image, save_image, encode_img, decode_img, to_PIL
import torch.nn.functional as F
from diffusers.models.vae import Decoder
from diffusers.models import AutoencoderKL
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.io import read_image
from collections import OrderedDict

In [3]:
def DoOneRoiMask(data_path, subject_num, data, LR, roi, maskedFmri):

  if roi in ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"]:
      roi_class = 'prf-visualrois'
  elif roi in ["EBA", "FBA-1", "FBA-2", "mTL-bodies"]:
      roi_class = 'floc-bodies'
  elif roi in ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"]:
      roi_class = 'floc-faces'
  elif roi in ["OPA", "PPA", "RSC"]:
      roi_class = 'floc-places'
  elif roi in ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"]:
      roi_class = 'floc-words'
  elif roi in ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]:
      roi_class = 'streams'

  # load mask file
  roi_space_dir = os.path.join(data_path, 'subj0{}'.format(subject_num), 'roi_masks',
      LR[0]+'h.'+roi_class+'_space.npy')
  roi_map_dir = os.path.join(data_path, 'subj0{}'.format(subject_num), 'roi_masks',
      'mapping_'+roi_class+'.npy')
  roi_space = np.load(roi_space_dir)
  roi_map = np.load(roi_map_dir, allow_pickle=True).item()

  # Select the vertices corresponding to the ROI of interest
  roi_mapping = list(roi_map.keys())[list(roi_map.values()).index(roi)]

  target_index = np.where(np.isin(roi_space, roi_mapping))[0]

  for i in target_index:
    maskedFmri[:, i] = data[:, i]

  return maskedFmri

In [4]:
def GetRoiMaskedFmri(data_path, subject_num, LR, regions):
  """
    data_path: string -- path to folder for subj0x
    subject_num: int -- subject number for training ex: 1
    LR: string "left" or "right" -- specify the fmri data is left or right hemisphere
    regoins: list of strings of interseted regions, split by bankspace -- ex: ["FFA-1", "OPA"]

    return: masked fmri data -- (n, 19004) or (n, 20054)
  """
  if LR == "left":
    data = np.load(os.path.join(data_path, 'subj0{}'.format(subject_num), 'training_split/training_fmri',
    LR[0]+'h_training_fmri.npy'))
    assert (data.shape[1] == 19004)
  elif LR == "right":
    data = np.load(os.path.join(data_path, 'subj0{}'.format(subject_num), 'training_split/training_fmri',
    LR[0]+'h_training_fmri.npy'))
    assert (data.shape[1] == 20544)

  maskedFmriData = np.zeros_like(data, dtype=float)

  for region in regions:
    DoOneRoiMask(data_path=data_path, subject_num=subject_num, data=data, LR=LR, roi=region, maskedFmri=maskedFmriData)

  return maskedFmriData

In [5]:
# usage

# determine the datapath and training subject
data_path = './dataset'
subject_num = 1

# get masked fmri data
ROIs_test1 = ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4", "FFA-1", "FFA-2", "PPA"]
ROIs_test2 = ["EBA", "FBA-1", "FBA-2", "mTL-bodies", "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces", "OPA", "PPA", "RSC"]
ROIs_test3 = ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]

lh = GetRoiMaskedFmri(data_path=data_path, subject_num=subject_num, LR="left", regions=ROIs_test1)
rh = GetRoiMaskedFmri(data_path=data_path, subject_num=subject_num, LR="right", regions=ROIs_test1)
lrh = np.concatenate((lh, rh), axis=1)

In [6]:
wandb.login()

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [8]:
class Voxel2StableDiffusionModel(torch.nn.Module):
    # define the prototype of the module
    def __init__(self, in_dim=39548, h=2048, n_blocks=4):
        super().__init__()

        self.lin0 = nn.Sequential(
            nn.Linear(in_dim, h, bias=False),
            nn.LayerNorm(h),
            nn.SiLU(inplace=True),
            nn.Dropout(0.5),
        )

        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h, bias=False),
                nn.LayerNorm(h),
                nn.SiLU(inplace=True),
                nn.Dropout(0.3),
            ) for _ in range(n_blocks)
        ])
        
        self.lin1 = nn.Linear(h, 16384, bias=False)
        self.norm = nn.GroupNorm(1, 64)

        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

        init_weights(self.lin0)
        self.mlp.apply(init_weights)
        init_weights(self.lin1)

        self.upsampler = Decoder(
            in_channels=64,
            out_channels=4,
            up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
            block_out_channels=[64, 128, 256],
            layers_per_block=1,
        )
        for parm in self.upsampler.parameters():
            parm.require_grad = False
        self.upsampler.eval()

    # define how it forward, using the module defined above
    def forward(self, x):
        x = self.lin0(x)
        residual = x
        for res_block in self.mlp:
            x = res_block(x)
            x = x + residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        x = self.norm(x.reshape(x.shape[0], -1, 16, 16).contiguous())
        return self.upsampler(x)

In [9]:
voxel2sd = Voxel2StableDiffusionModel()
voxel2sd.to(device)

Voxel2StableDiffusionModel(
  (lin0): Sequential(
    (0): Linear(in_features=39548, out_features=2048, bias=False)
    (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (2): SiLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (mlp): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
    (3): Sequential(
      (0): Lin

In [10]:
from torchsummary import summary
summary(voxel2sd, (39548,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 2048]      80,994,304
         LayerNorm-2                 [-1, 2048]           4,096
              SiLU-3                 [-1, 2048]               0
           Dropout-4                 [-1, 2048]               0
            Linear-5                 [-1, 2048]       4,194,304
         LayerNorm-6                 [-1, 2048]           4,096
              SiLU-7                 [-1, 2048]               0
           Dropout-8                 [-1, 2048]               0
            Linear-9                 [-1, 2048]       4,194,304
        LayerNorm-10                 [-1, 2048]           4,096
             SiLU-11                 [-1, 2048]               0
          Dropout-12                 [-1, 2048]               0
           Linear-13                 [-1, 2048]       4,194,304
        LayerNorm-14                 [-

In [11]:
# some hyperparameters
batch_size = 16
num_epochs = 200
num_train = 5000
lr_scheduler = 'cycle'
initial_lr = 1e-4
max_lr = 5e-4
random_seed = 42
train_size = 0.7
valid_size = 1 - train_size
num_workers = torch.cuda.device_count()

In [12]:
# some path information
dataset_path = './dataset/'
training_path = dataset_path + 'subj0{}/training_split/'
training_fmri_path = training_path + 'training_fmri/'
training_images_path = training_path + 'training_images/'
testing_path = dataset_path + 'subj0{}/test_split/test_fmri/'

In [13]:
class MyDataset(Dataset):
  def __init__(self, fmri_data, images_folder, transform=None):
    self.fmri_data = fmri_data
    self.images_folder = images_folder
    self.image_paths = [f"{images_folder}/{filename}" for filename in os.listdir(images_folder)]
    self.transform = transform

  def __len__(self):
    return len(self.fmri_data)

  def __getitem__(self, idx):
    fmri = self.fmri_data[idx]
    image_path = self.image_paths[idx]
    image = load_image(image_path)

    if(self.transform):
      image = self.transform(image)

    return fmri, image

In [14]:
transform = transforms.Resize([512, 512])
my_dataset = MyDataset(lrh, training_images_path.format(1), transform=transform)

In [15]:
# train-val split
generator = torch.Generator().manual_seed(random_seed)
trainset, validset = random_split(my_dataset, [train_size, valid_size], generator=generator)

In [16]:
# build dataloader
train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_dataloader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [17]:
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in voxel2sd.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in voxel2sd.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [18]:
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)

In [19]:
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr,
                                            total_steps=num_epochs*((num_train//batch_size)//num_workers),
                                            final_div_factor=1000,
                                            last_epoch=-1, pct_start=2/num_epochs)
# lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
#                                                    milestones=[50*i for i in range(num_epochs*((num_train//batch_size)//num_workers//50))],
#                                                    gamma=0.1)

In [20]:
# initialize wandb
wandb.init(
    # set the wandb project where this run will be logged
    project="MindEye",

    # track hyperparameters and run metadata
    config={
        "learning_rate": initial_lr,
        "architecture": "MLP",
        "dataset": "NSD",
        "epochs": num_epochs,
        "random_seed": random_seed,
        "train_size": train_size,
        "valid_size": valid_size
    }
)

In [21]:
# epoch = 0

losses = []
val_losses = []
lrs = []

In [22]:
checkpoint = torch.load('./Models/100')
voxel2sd.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch'] + 1
loss = checkpoint['loss']

progress_bar = tqdm(range(epoch, num_epochs), ncols=150)

voxel2sd.eval()

  0%|                                                                                                                          | 0/99 [00:00<?, ?it/s]

Voxel2StableDiffusionModel(
  (lin0): Sequential(
    (0): Linear(in_features=39548, out_features=2048, bias=False)
    (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (2): SiLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (mlp): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
    (3): Sequential(
      (0): Lin

In [None]:
for epoch in progress_bar:
    voxel2sd.train()

    loss_sum = 0
    val_loss_sum = 0

    reconst_fails = []

    for train_i, data in enumerate(train_dataloader):
        voxels, images = data
        voxels = voxels.to(device).float()
        images = images.to(device).float()

        optimizer.zero_grad()
        # run image encoder
        encoded_latents = torch.cat([encode_img(image, vae).to(device) for image in images])
        # MLP forward
        encoded_predict = voxel2sd(voxels)
        # calulate loss
        loss = F.l1_loss(encoded_predict, encoded_latents)
        loss_sum += loss.item()
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        # backward
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        logs = {
            "train/loss": np.mean(losses[-(train_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": 100*219+len(losses),
            "train/loss_mse": loss_sum / (train_i + 1)
        }
        wandb.log(logs)

        progress_bar.set_postfix(**logs)

    # After training one epoch, evaluation
    # save ckpt first
    print('saving model')
    torch.save({
      'epoch': epoch,
      'model_state_dict': voxel2sd.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss,
      }, './Models/{}'.format(epoch)
    )
    print('model saved')
    

    # print(logs)

  0%|        | 0/99 [04:11<?, ?it/s, _runtime=264, _timestamp=1.7e+9, train/loss=0.499, train/loss_mse=0.499, train/lr=3.65e-5, train/num_steps=21975]