In [None]:
import os
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import AutoencoderKL
from diffusers.models.vae import Decoder
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.io import read_image
from tqdm import tqdm

from utils import load_image, save_image, encode_img, decode_img, to_PIL

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
class Voxel2StableDiffusionModel(torch.nn.Module):
    """Residual MLP followed by a convolutional ``Decoder`` upsampler.

    The input vector (``in_dim`` features) is projected to width ``h``, refined
    by ``n_blocks`` residual blocks, projected to 16384 features, reshaped to a
    16x16 map (64 channels for a 2-D input batch), group-normalised and passed
    to ``Decoder`` (configured with ``in_channels=64``, ``out_channels=4``).

    Args:
        in_dim: number of input features per sample.
        h: hidden width of the MLP.
        n_blocks: number of residual blocks after the input projection.
    """

    def __init__(self, in_dim=39548, h=2048, n_blocks=4):
        super().__init__()

        def _projection(n_inputs, drop_prob):
            # Linear (no bias) -> LayerNorm -> SiLU -> Dropout, always of width h.
            return nn.Sequential(
                nn.Linear(n_inputs, h, bias=False),
                nn.LayerNorm(h),
                nn.SiLU(inplace=True),
                nn.Dropout(drop_prob),
            )

        # Modules are created in the same order as their attribute names appear
        # in the state dict: lin0, mlp, lin1, norm, upsampler.
        self.lin0 = _projection(in_dim, 0.5)
        self.mlp = nn.ModuleList([_projection(h, 0.25) for _ in range(n_blocks)])

        self.lin1 = nn.Linear(h, 16384, bias=False)
        self.norm = nn.GroupNorm(1, 64)

        self.upsampler = Decoder(
            in_channels=64,
            out_channels=4,
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            block_out_channels=[64, 128, 256],
            layers_per_block=1,
        )

    def forward(self, x):
        """Run the MLP, reshape to a 16x16 feature map and upsample it."""
        hidden = self.lin0(x)
        # Each block's output is added to its own input (skip connection).
        for block in self.mlp:
            hidden = block(hidden) + hidden
        hidden = hidden.reshape(hidden.size(0), -1)
        projected = self.lin1(hidden)
        feature_map = projected.reshape(projected.size(0), -1, 16, 16).contiguous()
        return self.upsampler(self.norm(feature_map))

In [None]:
# Build the voxel-to-latent model with its default sizes
# (in_dim=39548, h=2048, n_blocks=4).
voxel2sd = Voxel2StableDiffusionModel()

In [None]:
# Move the model to the selected device. Being the last expression of the
# cell, the returned module is also displayed by the notebook.
voxel2sd.to(device)

In [None]:
# from torchsummary import summary
# device = 'cuda'
# summary(voxel2sd.to(device), input_size=(39548,))

In [None]:
# some hyperparameters
batch_size = 32  # samples per batch for both DataLoaders
num_epochs = 120  # length of the epoch range driving the training loop
num_train = 5000  # NOTE(review): presumably the expected number of training samples — confirm against the real dataset size
lr_scheduler = 'cycle'  # scheduler-type label; this name is later rebound to the OneCycleLR object
initial_lr = 1e-3  # NOTE(review): larger than max_lr below — confirm which value is intended
max_lr = 5e-4  # peak learning rate handed to OneCycleLR
random_seed = 42  # seed of the generator used for the train/validation split
train_size = 0.7  # fraction of the dataset assigned to the training split
valid_size = 1 - train_size  # remaining fraction, assigned to the validation split
num_workers = torch.cuda.device_count()  # number of visible CUDA devices (0 without a GPU); passed to the DataLoaders as num_workers

In [None]:
# some path information
# The '{}' placeholder in the templates below is filled with the subject
# number through str.format, e.g. training_path.format(1) -> '.../subj01/training_split/'.
dataset_path = '../2023-Machine-Learning-Dataset/'  # dataset root, relative to the working directory
training_path = dataset_path + 'subj0{}/training_split/'  # per-subject training split
training_fmri_path = training_path + 'training_fmri/'  # directory of the training fMRI .npy files
training_images_path = training_path + 'training_images/'  # directory of the training images
testing_path = dataset_path + 'subj0{}/test_split/test_fmri/'  # per-subject test-split fMRI directory

In [None]:
class MyDataset(Dataset):
    """Dataset pairing each fMRI sample with the image file at the same index.

    Sample ``i`` is ``(fmri_data[i], image_i)``, where ``image_i`` is loaded
    with ``load_image`` from the ``i``-th entry of the *sorted* listing of
    ``images_folder``.

    NOTE(review): this assumes that sorting the file names reproduces the row
    order of ``fmri_data`` (e.g. zero-padded indices in the names) — confirm
    against the dataset's naming scheme.

    Args:
        fmri_data: indexable collection of fMRI samples; its length defines
            the dataset length.
        images_folder: directory containing one image file per fMRI sample.
        transform: optional callable applied to every loaded image.
    """

    def __init__(self, fmri_data, images_folder, transform=None):
        self.fmri_data = fmri_data
        self.images_folder = images_folder
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting makes the index -> file mapping deterministic, so every fMRI
        # row is paired with the same image on every machine and every run.
        self.image_paths = [
            f"{images_folder}/{filename}"
            for filename in sorted(os.listdir(images_folder))
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of fMRI samples."""
        return len(self.fmri_data)

    def __getitem__(self, idx):
        """Return the ``(fmri, image)`` pair stored at position ``idx``."""
        fmri = self.fmri_data[idx]
        image = load_image(self.image_paths[idx])

        # Compare against None explicitly: a transform object that defines
        # __len__ (e.g. an empty container module) would otherwise be skipped.
        if self.transform is not None:
            image = self.transform(image)

        return fmri, image

In [None]:
# Every image handed out by the dataset is resized to 512x512.
transform = transforms.Resize(size=[512, 512])

# Only subject 1 is used for now: read the left- and right-hemisphere arrays
# and join them along axis 1 into a single array.
subj01_training_dir = training_path.format(1)
lh, rh = (
    np.load(subj01_training_dir + fmri_file)
    for fmri_file in (
        'training_fmri/lh_training_fmri.npy',
        'training_fmri/rh_training_fmri.npy',
    )
)
lrh = np.concatenate([lh, rh], axis=1)

my_dataset = MyDataset(
    fmri_data=lrh,
    images_folder=training_images_path.format(1),
    transform=transform,
)

In [None]:
# Split the dataset into training and validation subsets by fraction; the
# seeded generator makes the split identical on every run.
generator = torch.Generator()
generator.manual_seed(random_seed)
trainset, validset = random_split(
    dataset=my_dataset,
    lengths=[train_size, valid_size],
    generator=generator,
)

In [None]:
# Build the DataLoaders.
# The training loader reshuffles its samples at the start of every epoch so the
# optimizer does not see the same batches in the same order each time; the
# seeded generator keeps that shuffling reproducible across runs.
train_dataloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    generator=torch.Generator().manual_seed(random_seed),
)
# Validation order does not influence the averaged loss, so it stays fixed.
val_dataloader = DataLoader(
    validset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# Split the model's parameters into two optimizer groups: one with weight
# decay (1e-2) and one without (biases and LayerNorm parameters).
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

# named_parameters() builds names from attribute names and nn.Sequential
# indices (e.g. 'lin0.1.weight'), never from class names, so the 'LayerNorm.*'
# patterns above cannot match by themselves. Identify the LayerNorm parameters
# by module type instead so they really are excluded from weight decay.
layernorm_param_ids = {
    id(param)
    for module in voxel2sd.modules()
    if isinstance(module, nn.LayerNorm)
    for param in module.parameters(recurse=False)
}


def _skip_weight_decay(name, param):
    """Return True when `param` must be placed in the zero-weight-decay group."""
    return id(param) in layernorm_param_ids or any(nd in name for nd in no_decay)


opt_grouped_parameters = [
    {
        'params': [p for n, p in voxel2sd.named_parameters() if not _skip_weight_decay(n, p)],
        'weight_decay': 1e-2,
    },
    {
        'params': [p for n, p in voxel2sd.named_parameters() if _skip_weight_decay(n, p)],
        'weight_decay': 0.0,
    },
]

In [None]:
# Pretrained Stable Diffusion VAE, used only to turn images into target
# latents. It is not part of the optimizer's parameter groups, so put it in
# eval mode and disable gradients on its parameters: encoding a target then
# never records an autograd graph through the VAE.
vae = (
    AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    .to(device)
    .eval()
    .requires_grad_(False)
)

In [None]:
# AdamW over the two weight-decay groups. The `lr` given here is only a
# placeholder: OneCycleLR overwrites every group's learning rate when it is
# constructed and again after each scheduler step.
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr)

# The training loop calls lr_scheduler.step() once per batch, so the schedule
# must span exactly num_epochs * len(train_dataloader) steps. Deriving it from
# the real loader length (instead of a hard-coded sample count divided by
# num_workers) avoids a ZeroDivisionError when num_workers is 0 and prevents
# OneCycleLR from raising once more steps are taken than were declared.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=num_epochs * len(train_dataloader),
    final_div_factor=1000,
    last_epoch=-1,
    # Fraction of the schedule spent increasing the learning rate:
    # the equivalent of two epochs.
    pct_start=2 / num_epochs,
)

In [None]:
# Per-step histories: training loss, validation loss and learning rate.
losses, val_losses, lrs = [], [], []

# Epoch counter and the progress bar iterated by the training loop.
epoch = 0
progress_bar = tqdm(range(epoch, num_epochs), ncols=150)

In [None]:
def _latent_mse(voxels, images):
    """Return the MSE between predicted latents and VAE latents for one batch.

    Both inputs are moved to `device` and cast to float. The target latents are
    produced image by image with `encode_img` under `torch.no_grad()`, so no
    autograd graph is recorded for the targets; gradients can only flow through
    `voxel2sd`.
    """
    voxels = voxels.to(device).float()
    images = images.to(device).float()
    with torch.no_grad():
        encoded_latents = torch.cat([encode_img(image, vae).to(device) for image in images])
    encoded_predict = voxel2sd(voxels)
    return F.mse_loss(encoded_predict, encoded_latents)


# torch.save does not create missing directories; make sure the checkpoint
# folder exists before the first epoch finishes.
os.makedirs('./Models', exist_ok=True)

for epoch in progress_bar:
    voxel2sd.train()

    loss_sum = 0
    val_loss_sum = 0

    for train_i, (voxels, images) in enumerate(train_dataloader):
        optimizer.zero_grad()

        loss = _latent_mse(voxels, images)
        batch_loss = loss.item()
        loss_sum += batch_loss
        losses.append(batch_loss)
        # Learning rate that is applied to this optimizer step.
        lrs.append(optimizer.param_groups[0]['lr'])

        # backward
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # the schedule advances once per batch

        logs = {
            "train/loss": np.mean(losses[-(train_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": len(losses),
            "train/loss_mse": loss_sum / (train_i + 1)
        }

        progress_bar.set_postfix(**logs)

    # After training one epoch: save a checkpoint first. 'loss' is the loss of
    # the last training batch, detached so only its value is stored.
    torch.save({
      'epoch': epoch,
      'model_state_dict': voxel2sd.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss.detach(),
      }, './Models/{}'.format(epoch)
    )

    # Then evaluate. No gradients are needed here, so disable autograd to avoid
    # building (and keeping in memory) a graph for every validation batch.
    voxel2sd.eval()
    num_val_batches = 0
    with torch.no_grad():
        for val_i, (voxels, images) in enumerate(val_dataloader):
            loss = _latent_mse(voxels, images)
            val_loss_sum += loss.item()
            val_losses.append(loss.item())
            num_val_batches += 1

    # Report NaN instead of failing when the validation loader yields nothing.
    if num_val_batches:
        val_loss_mean = np.mean(val_losses[-num_val_batches:])
        val_loss_mse = val_loss_sum / num_val_batches
    else:
        val_loss_mean = val_loss_mse = float('nan')

    # Show the end-of-epoch summary (training and validation) on the progress bar.
    logs = {
        "train/loss": np.mean(losses[-(train_i+1):]),
        "val/loss": val_loss_mean,
        "train/lr": lrs[-1],
        "train/num_steps": len(losses),
        "train/loss_mse": loss_sum / (train_i + 1),
        "val/loss_mse": val_loss_mse
    }
    progress_bar.set_postfix(**logs)