In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm
from utils import load_image, save_image, encode_img, decode_img, to_PIL, transform
import torch.nn.functional as F
from diffusers.models.vae import Decoder
from diffusers.models import AutoencoderKL
from torch.utils.data import DataLoader, random_split
from collections import OrderedDict

In [2]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:
class Voxel2StableDiffusionModel(torch.nn.Module):
    """Regress a multi-channel latent image from a flat voxel vector.

    Pipeline: input projection -> ``n_blocks`` residual MLP blocks ->
    bias-free linear projection to 64 * 16 * 16 features -> reshape to
    64 channels of 16x16 -> GroupNorm -> diffusers ``Decoder`` (three up
    blocks, 4 output channels).

    Args:
        in_dim: length of the input voxel vector.
        h: hidden width shared by every MLP block.
        n_blocks: number of residual MLP blocks after the input projection.
    """

    def __init__(self, in_dim=39548, h=8192, n_blocks=8):
        super().__init__()

        def mlp_block(n_in, drop_p):
            # Linear (no bias) -> LayerNorm -> SiLU -> Dropout, output width h.
            return nn.Sequential(
                nn.Linear(n_in, h, bias=False),
                nn.LayerNorm(h),
                nn.SiLU(inplace=True),
                nn.Dropout(drop_p),
            )

        # Sub-modules are created in the same order and under the same
        # attribute names as before, so state_dict keys are unchanged.
        self.lin0 = mlp_block(in_dim, 0.5)
        self.mlp = nn.ModuleList([mlp_block(h, 0.25) for _ in range(n_blocks)])

        # 16384 features == 64 channels on a 16x16 grid (see the reshape in forward).
        self.lin1 = nn.Linear(h, 64 * 16 * 16, bias=False)
        self.norm = nn.GroupNorm(1, 64)

        self.upsampler = Decoder(
            in_channels=64,
            out_channels=4,
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            block_out_channels=[64, 128, 256],
            layers_per_block=1,
        )

    def forward(self, x):
        """Run the MLP backbone, fold the result into a 16x16 grid and decode it.

        Args:
            x: batch of voxel vectors; assumed to be (batch, in_dim) — TODO confirm
               against the caller.

        Returns:
            Whatever ``self.upsampler`` returns for the normalized
            (batch, 64, 16, 16) grid.
        """
        hidden = self.lin0(x)
        # Residual stack: each block's output is added to that block's input.
        for block in self.mlp:
            hidden = hidden + block(hidden)
        flat = hidden.reshape(len(hidden), -1)
        grid = self.lin1(flat)
        grid = grid.reshape(grid.shape[0], -1, 16, 16).contiguous()
        return self.upsampler(self.norm(grid))

In [4]:
# Build the voxel -> latent regressor with the class defaults spelled out.
# NOTE(review): at h=8192 the bias-free Linear layers alone hold about 1e9
# float32 weights (39548*8192 + 8*8192*8192 + 8192*16384), i.e. roughly
# 3.7 GiB before gradients and optimizer state. The recorded run of this
# notebook ended in a CUDA out-of-memory error on a 6 GiB GPU, so a smaller
# `h` / `n_blocks` is probably needed on that hardware.
voxel2sd = Voxel2StableDiffusionModel(in_dim=39548, h=8192, n_blocks=8)

In [5]:
# Move the model's parameters to the selected device; as the last expression
# of the cell, the returned module is also displayed.
voxel2sd.to(device=device)

Voxel2StableDiffusionModel(
  (lin0): Sequential(
    (0): Linear(in_features=39548, out_features=8192, bias=False)
    (1): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
    (2): SiLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
  )
  (mlp): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=8192, out_features=8192, bias=False)
      (1): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=8192, out_features=8192, bias=False)
      (1): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=8192, out_features=8192, bias=False)
      (1): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
      (2): SiLU(inplace=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (3): Sequential(
      (0): 

In [6]:
# ---- hyperparameters ----

# reproducibility
random_seed = 42

# data
num_train = 5000                 # number of training samples taken from the subject
train_size = 0.7                 # fraction of samples used for training
valid_size = 1 - train_size      # remaining fraction is used for validation
batch_size = 8
# NOTE(review): this is the GPU count, yet it is passed to DataLoader as the
# number of worker processes and also divides the scheduler step count.
num_workers = torch.cuda.device_count()

# optimisation
num_epochs = 120
# NOTE(review): this string is replaced later by the scheduler object that is
# assigned to the same name.
lr_scheduler = 'cycle'
# NOTE(review): initial_lr is larger than max_lr — confirm this is intended.
initial_lr = 1e-3
max_lr = 5e-4

In [7]:
# Path templates: the '{}' placeholder is filled with the subject number
# through str.format (e.g. training_path.format(1)).
dataset_path = './dataset/'
subject_dir = dataset_path + 'subj0{}/'

# training split
training_path = subject_dir + 'training_split/'
training_fmri_path = training_path + 'training_fmri/'
training_images_path = training_path + 'training_images/'

# test split (fMRI only)
testing_path = subject_dir + 'test_split/test_fmri/'

In [8]:
# Load dataset, now only subj01
subject = 1

# Left/right hemisphere responses are stored separately; join them along the
# voxel axis so each row is one sample's full voxel vector.
fmri_dir = training_fmri_path.format(subject)
lh = np.load(fmri_dir + 'lh_training_fmri.npy')
rh = np.load(fmri_dir + 'rh_training_fmri.npy')
lrh = np.concatenate((lh, rh), axis=1)

# Not used by this cell any more; kept in case other cells rely on the name.
from IPython.display import clear_output

if len(lrh) < num_train:
    raise ValueError(f'num_train={num_train} exceeds the {len(lrh)} available fMRI samples')

# Index -> sample record. The size comes from the `num_train` hyperparameter
# rather than a repeated literal, and the dict is built in one pass instead of
# clearing and re-printing the cell output once per sample.
images_dir = training_images_path.format(subject)
dataset = {
    i: {'voxel': lrh[i], 'image_path': images_dir + f'{i}.png'}
    for i in range(num_train)
}

4999


In [9]:
# Seeded random train/validation split so the partition is reproducible.
split_fractions = [train_size, valid_size]
generator = torch.Generator()
generator.manual_seed(random_seed)
trainset, validset = random_split(dataset, split_fractions, generator=generator)

In [10]:
# build dataloaders
# Training batches are reshuffled every epoch (seeded for reproducibility) so
# the model does not see identical batches in an identical order each epoch.
# Validation order stays fixed so per-epoch metrics are comparable.
shuffle_generator = torch.Generator().manual_seed(random_seed)
train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, generator=shuffle_generator)
val_dataloader = DataLoader(validset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

In [11]:
# Split parameters into a weight-decayed group (weight matrices / conv kernels)
# and a non-decayed group (biases and normalization scales).
# The split is by tensor rank instead of by name: the normalization layers here
# are registered under names such as `lin0.1.weight` or `norm.weight`, so a
# 'LayerNorm.weight' substring test never matches them and their scales would
# silently receive weight decay.
decay_params, no_decay_params = [], []
for _, param in voxel2sd.named_parameters():
    if param.ndim <= 1:
        no_decay_params.append(param)
    else:
        decay_params.append(param)

opt_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 1e-2},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

In [12]:
# Pretrained Stable Diffusion VAE, used only to produce target latents.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# It is never trained here: put it in eval mode and freeze its weights so no
# gradients (and no autograd graph through the encoder) are kept for it.
vae = vae.to(device).eval().requires_grad_(False)

In [13]:
# AdamW over the decayed / non-decayed parameter groups.
# OneCycleLR overwrites each group's lr when it is constructed, so the value
# passed here only matters if the scheduler is removed.
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr)

# The scheduler is stepped once per training batch, so the cycle length must be
# the real number of optimizer steps. Deriving it from `num_train` and
# `num_workers` ignored the train/validation split and fails outright when
# `num_workers` is 0 (division by zero) or large (too few steps -> ValueError
# from OneCycleLR part-way through training).
steps_per_epoch = len(train_dataloader)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=num_epochs * steps_per_epoch,
    final_div_factor=1000,
    last_epoch=-1,
    pct_start=2 / num_epochs,  # warm up over roughly the first two epochs
)

In [14]:
# Per-step histories filled in by the training loop.
losses, val_losses, lrs = [], [], []

# Epoch iterator wrapped in a progress bar; starts from `epoch` and runs up to
# `num_epochs`.
epoch = 0
progress_bar = tqdm(range(epoch, num_epochs), ncols=150)

  0%|                                                                                                                         | 0/120 [00:00<?, ?it/s]

In [15]:
def _encode_targets(image_paths):
    """Load the images at ``image_paths`` and return their concatenated VAE latents.

    The VAE only supplies regression targets, so encoding runs under
    ``torch.no_grad()``: no autograd graph is kept for the encoder, which
    saves GPU memory on every batch.
    """
    with torch.no_grad():
        latents = [
            encode_img(transform(load_image(path)).to(device).float(), vae)
            for path in image_paths
        ]
    return torch.cat(latents)


for epoch in progress_bar:
    # ---- training ----
    voxel2sd.train()

    loss_sum = 0
    val_loss_sum = 0

    for train_i, data in enumerate(train_dataloader):
        voxels = data['voxel'].to(device).float()
        # run image encoder (targets only, no gradients)
        encoded_latents = _encode_targets(data['image_path'])

        optimizer.zero_grad(set_to_none=True)
        # MLP forward
        encoded_predict = voxel2sd(voxels)
        # calculate loss
        loss = F.mse_loss(encoded_predict, encoded_latents)
        loss_value = loss.item()
        loss_sum += loss_value
        losses.append(loss_value)
        lrs.append(optimizer.param_groups[0]['lr'])

        # backward
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # running mean over the batches seen so far in this epoch
        logs = OrderedDict(
            train_loss=np.mean(losses[-(train_i+1):]),
            lr=lrs[-1],
        )
        progress_bar.set_postfix(**logs)

    # ---- validation ----
    # eval() disables dropout; no_grad() stops autograd from storing
    # activations for a pass that is never back-propagated.
    voxel2sd.eval()
    with torch.no_grad():
        for val_i, data in enumerate(val_dataloader):
            voxels = data['voxel'].to(device).float()
            # run image encoder
            encoded_latents = _encode_targets(data['image_path'])
            # MLP forward
            encoded_predict = voxel2sd(voxels)
            # calculate loss
            loss_value = F.mse_loss(encoded_predict, encoded_latents).item()
            val_loss_sum += loss_value
            val_losses.append(loss_value)

    # Print results
    logs = {
        "train/loss": np.mean(losses[-(train_i+1):]),
        "val/loss": np.mean(val_losses[-(val_i+1):]),
        "train/lr": lrs[-1],
        "train/num_steps": len(losses),
        "train/loss_mse": loss_sum / (train_i + 1),
        "val/loss_mse": val_loss_sum / (val_i + 1)
    }
    print(logs)

  0%|                                                                                                                         | 0/120 [00:08<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.21 GiB (GPU 0; 6.00 GiB total capacity; 15.28 GiB already allocated; 0 bytes free; 15.38 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF