In [1]:
"""
[Colab initialization]
Initialize Colab in this cell, mount drive and cd to the working directory
"""

# Optional: only needed when running on Google Colab. Uncomment to mount
# Google Drive and switch into the project directory. It is left commented out
# so the notebook also runs locally, where the relative paths used later
# ('../dataset/', '../Models/') resolve against the notebook's directory.
# `%cd` (IPython magic) and `os.chdir` do the same thing; one of them is enough.
# from google.colab import drive
# import os

# drive.mount('/content/drive')
# %cd '/content/drive/My Drive/MindEye'
# os.chdir('/content/drive/My Drive/MindEye')

'\n[Colab initialization]\nInitialize Colab in this cell, mount drive and cd to the working directory\n'

In [2]:
"""
[Package import]
Import some useful packages here
"""

import os

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from diffusers.models import AutoencoderKL
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
from torchvision import transforms
from tqdm import tqdm

from Models import Voxel2StableDiffusionModel, MyDataset
from utils import encode_img, GetRoiMaskedLR

In [3]:
"""
[Select Torch Device]

Use the (current) CUDA device when one is available, otherwise the CPU.
"""

has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(device)

cuda


In [4]:
"""
[Load Low-level Model]

Build the voxel -> Stable Diffusion latent model (defined in Models.py)
and move its parameters onto `device`.
"""
voxel2sd = Voxel2StableDiffusionModel().to(device)
voxel2sd

Voxel2StableDiffusionModel(
  (lin0): Sequential(
    (0): Linear(in_features=39548, out_features=2048, bias=False)
    (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (2): SiLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (mlp): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=2048, out_features=2048, bias=False)
      (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
    )
  )
  (lin1): Linear(in_features=2048, out_features=16384, bias=False)
  (norm): GroupNorm(1, 64, eps=1e-05, affine=True)
  (upsampler): Decoder(
    (conv_in): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (up_blocks): ModuleList(
      (0): UpDecoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(256, 256, kernel_size=(3, 3), stri

In [5]:
"""
[Summary of Model]

The model's input is a flat voxel vector (39548 values for this model);
this cell prints a layer-by-layer summary of output shapes and parameter counts.
"""

# Read the input width from the first Linear layer instead of hard-coding it,
# so the summary always matches the model definition.
num_voxels = voxel2sd.lin0[0].in_features
# torchsummary.summary defaults to device="cuda"; pass the selected device so
# this cell also works on CPU-only machines.
summary(voxel2sd, (num_voxels,), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 2048]      80,994,304
         LayerNorm-2                 [-1, 2048]           4,096
              SiLU-3                 [-1, 2048]               0
           Dropout-4                 [-1, 2048]               0
            Linear-5                 [-1, 2048]       4,194,304
         LayerNorm-6                 [-1, 2048]           4,096
              SiLU-7                 [-1, 2048]               0
           Dropout-8                 [-1, 2048]               0
            Linear-9                 [-1, 2048]       4,194,304
        LayerNorm-10                 [-1, 2048]           4,096
             SiLU-11                 [-1, 2048]               0
          Dropout-12                 [-1, 2048]               0
           Linear-13                 [-1, 2048]       4,194,304
        LayerNorm-14                 [-

In [6]:
"""
[Hyperparameters]
"""

batch_size = 16
num_epochs = 200
num_train = 5000
# Name of the LR schedule in use. The scheduler *object* is created later in
# the optimizer cell under the name `lr_scheduler`, so a distinct name is used
# here to avoid silently shadowing this setting.
lr_scheduler_name = 'cycle'
initial_lr = 1e-4
max_lr = 5e-4
train_size = 0.7
valid_size = 1 - train_size
# DataLoader worker processes. Tied to the GPU count here, so on a CPU-only
# machine this is 0, i.e. data is loaded in the main process.
num_workers = torch.cuda.device_count()

# We normally only modify the following hyperparameters
random_seed = 42
ROI_num = 1 # For ROI Mask. [1, 2, 3]

# Seed the global RNGs so dropout and data shuffling are reproducible.
# The train/val split uses its own generator seeded from `random_seed`.
# Note: voxel2sd is constructed in an earlier cell, so its initial weights
# are not covered by this seed.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [7]:
"""
[Path information]

The literal `{}` in the subject-specific paths is a placeholder that is later
filled with the subject number via str.format (e.g. `.format(1)` -> 'subj01').
"""

dataset_path = '../dataset/'
subject_dir = dataset_path + 'subj0{}/'
training_path = subject_dir + 'training_split/'
training_fmri_path = f'{training_path}training_fmri/'
training_images_path = f'{training_path}training_images/'
testing_path = subject_dir + 'test_split/test_fmri/'

In [8]:
"""
[ROI Mask]

Apply an ROI mask to the subject's fMRI data. Three ROI mask sets are
available; `ROI_num` (set in the hyperparameter cell) selects which one.

Note: this implementation only handles subject 1.
"""

# NOTE(review): the name `lrh` suggests the masked left/right-hemisphere
# fMRI data returned by utils.GetRoiMaskedLR — confirm against utils.py.
roi_choice = ROI_num
lrh = GetRoiMaskedLR(roi_choice, dataset_path)

In [9]:
"""
[Load Training Dataset and Data Preprocessing]
Load the training dataset for subject 1, split it into train and validation
sets, and wrap both in DataLoaders.
"""

# Passed to MyDataset; presumably applied to the stimulus images so they are
# 512x512 before VAE encoding — TODO confirm in Models.py.
transform = transforms.Resize([512, 512])
my_dataset = MyDataset(lrh, training_images_path.format(1), transform=transform)

# Reproducible train/val split.
generator = torch.Generator().manual_seed(random_seed)
trainset, validset = random_split(my_dataset, [train_size, valid_size], generator=generator)

# Reshuffle training batches every epoch (with a seeded generator so runs stay
# reproducible). Without shuffling, the model sees the exact same batch order
# in every epoch. Validation order does not affect the metric, so it stays fixed.
train_generator = torch.Generator().manual_seed(random_seed)
train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              generator=train_generator, num_workers=num_workers)
val_dataloader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [10]:
"""
[Optimizer Initialization]

AdamW (decoupled weight decay) driven by a OneCycleLR schedule.
"""

# Exclude biases and normalization-layer parameters from weight decay.
# Matching on parameter *names* such as 'LayerNorm.weight' does not work for
# this model: its norm layers sit inside Sequential/Decoder containers, so
# their parameters are named like 'lin0.1.weight' and would otherwise be decayed.
norm_layer_types = (torch.nn.LayerNorm, torch.nn.GroupNorm)
no_decay_ids = {
    id(param)
    for module in voxel2sd.modules()
    for param_name, param in module.named_parameters(recurse=False)
    if param_name == 'bias' or isinstance(module, norm_layer_types)
}
opt_grouped_parameters = [
    {'params': [p for p in voxel2sd.parameters() if id(p) not in no_decay_ids], 'weight_decay': 1e-2},
    {'params': [p for p in voxel2sd.parameters() if id(p) in no_decay_ids], 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr)

# The training loop steps the scheduler once per training batch, so size the
# schedule from the real loader length. OneCycleLR raises ValueError if it is
# stepped more than `total_steps` times. An estimate based on a fixed sample
# count divided by the GPU count can undershoot, and divides by zero on
# CPU-only machines.
total_steps = num_epochs * len(train_dataloader)

# NOTE: OneCycleLR overrides the optimizer's lr. The schedule starts at
# max_lr / div_factor (default div_factor=25), so `initial_lr` does not set
# the starting LR here.
# When resuming from a checkpoint, restore this scheduler's state_dict as well,
# otherwise the schedule restarts from step 0.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=total_steps,
    final_div_factor=1000,
    last_epoch=-1,
    pct_start=2 / num_epochs,
)

In [11]:
"""
[Load VAE]

Stable Diffusion v1-4's VAE, used to encode target images into latents.
It is a frozen target encoder: its weights are never given to the optimizer.
"""

# requires_grad_(False): if gradients ever reach the VAE through encode_img,
# they would otherwise accumulate in its .grad buffers indefinitely, because
# optimizer.zero_grad() only clears voxel2sd's parameters.
vae = (
    AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    .to(device)
    .requires_grad_(False)
    .eval()
)

In [12]:
"""
[Wandb Initialization]

Log in first, then record the run configuration.
"""

# Uses cached wandb credentials (or the WANDB_API_KEY environment variable) and
# prompts otherwise — never hard-code the API key in the notebook.
wandb.login()

# initialize wandb
wandb.init(
    # set the wandb project where this run will be logged
    project="MindEye",

    # Track the hyperparameters needed to tell runs apart. ROI_num in
    # particular is the main knob between experiments.
    config={
        "learning_rate": initial_lr,
        "max_lr": max_lr,
        "lr_schedule": "OneCycleLR",
        "batch_size": batch_size,
        "architecture": "MLP",
        "dataset": "NSD",
        "roi_num": ROI_num,
        "epochs": num_epochs,
        "random_seed": random_seed,
        "train_size": train_size,
        "valid_size": valid_size
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mkoioslin[0m ([33mtkoioslin[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
"""
[Recording Variables]

Training bookkeeping: the current epoch, the global optimizer-step count, and
per-batch training losses, validation losses and learning rates.
"""

epoch, steps = 0, 0
losses, val_losses, lrs = [], [], []
progress_bar = tqdm(range(epoch, num_epochs), ncols=150)

  0%|                                                                                                                         | 0/200 [00:00<?, ?it/s]

In [14]:
"""
[Resume Training from Pretrained Model]

If you have a model trained before, you may use this cell
to resume the training process
"""

# To resume: uncomment the lines below and set the checkpoint number.
# The training cell writes checkpoints to '../Models/<epoch>', so load from
# '../Models/' (not './Models/').
# checkpoint = torch.load('../Models/100', map_location=device)  # [TODO] change the checkpoint (epoch) number
# voxel2sd.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# # Restore the LR schedule too if the checkpoint stored it; otherwise
# # OneCycleLR restarts from step 0.
# if 'lr_scheduler_state_dict' in checkpoint:
#     lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
# epoch = checkpoint['epoch'] + 1  # the saved epoch has already finished
# loss = checkpoint['loss']
# # The saved value already counts every optimizer step taken, so no `+ 1`.
# steps = checkpoint['steps']

# progress_bar = tqdm(range(epoch, num_epochs), ncols=150)

'\n[Resume Training from Pretrained Model]\n\nIf you have a model trained before, you may use this cell\nto resume the training process\n'

In [None]:
"""
[Training Process]

Train voxel2sd to regress the VAE latents of the stimulus images from fMRI
voxels (L1 loss), evaluate on the validation split after every epoch, and
save a checkpoint per epoch.
"""

checkpoint_dir = '../Models'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)


def compute_batch_loss(voxels, images):
    """Return the L1 loss between voxel2sd's prediction and the images' VAE latents.

    Args:
        voxels: batch of voxel vectors from a DataLoader.
        images: matching batch of images; each one is encoded with `vae`.

    Returns:
        Scalar loss tensor. It carries an autograd graph through voxel2sd only
        when called with gradients enabled.
    """
    voxels = voxels.to(device).float()
    images = images.to(device).float()
    # The VAE latents are fixed regression targets, so build no graph for them.
    with torch.no_grad():
        encoded_latents = torch.cat([encode_img(image, vae).to(device) for image in images])
    encoded_predict = voxel2sd(voxels)
    return F.l1_loss(encoded_predict, encoded_latents)


def mean_or_nan(total, count):
    """Return total / count, or NaN when count is 0 (e.g. an empty DataLoader)."""
    return total / count if count else float('nan')


for epoch in progress_bar:
    # ---- training ----
    voxel2sd.train()
    loss_sum = 0.0
    num_train_batches = 0

    for voxels, images in train_dataloader:
        optimizer.zero_grad()
        loss = compute_batch_loss(voxels, images)

        loss_value = loss.item()
        loss_sum += loss_value
        losses.append(loss_value)
        # LR used for this step (read before the scheduler advances it).
        lrs.append(optimizer.param_groups[0]['lr'])
        steps += 1
        num_train_batches += 1

        # backward
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # "train/loss_mse" is kept as the metric key for dashboard continuity,
        # but the value is the running mean *L1* loss of this epoch.
        train_loss_mean = mean_or_nan(loss_sum, num_train_batches)
        logs = {
            "train/loss": train_loss_mean,
            "train/lr": lrs[-1],
            "train/num_steps": steps,
            "train/loss_mse": train_loss_mean
        }
        wandb.log(logs)
        progress_bar.set_postfix(**logs)

    # ---- validation ----
    voxel2sd.eval()
    val_loss_sum = 0.0
    num_val_batches = 0
    # No gradients are needed for evaluation; without no_grad every batch
    # would build an autograd graph through voxel2sd.
    with torch.no_grad():
        for voxels, images in val_dataloader:
            val_loss = compute_batch_loss(voxels, images).item()
            val_loss_sum += val_loss
            val_losses.append(val_loss)
            num_val_batches += 1

    train_loss_mean = mean_or_nan(loss_sum, num_train_batches)
    val_loss_mean = mean_or_nan(val_loss_sum, num_val_batches)
    logs = {
        "train/loss": train_loss_mean,
        "val/loss": val_loss_mean,
        "train/lr": lrs[-1],
        "train/num_steps": steps,
        "train/loss_mse": train_loss_mean,
        "val/loss_mse": val_loss_mean
    }
    wandb.log(logs)

    # Save a checkpoint every epoch. tqdm.write keeps messages from garbling
    # the progress bar.
    tqdm.write('saving model')
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': voxel2sd.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Lets a resumed run continue the OneCycleLR schedule in place.
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            # Mean validation loss of this epoch, stored as a plain float.
            'loss': val_loss_mean,
            'steps': steps,
        },
        os.path.join(checkpoint_dir, str(epoch)),
    )
    tqdm.write('model saved')