In [1]:
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from functools import partial
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from utils import load_image, save_image, encode_img, decode_img, to_PIL

import io
from model import Voxel2Blip
import os
from transformers import BlipImageProcessor
from diffusers.models import AutoencoderKL


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
loss = nn.BCELoss()


In [8]:
loss(torch.tensor([1.,0.,1.]),torch.tensor([0.,0.,1]))


tensor(33.3333)

In [2]:
dataset_path = '2023-Machine-Learning-Dataset/'
training_path = dataset_path + 'subj0{}/training_split/'
training_fmri_path = training_path + 'training_fmri/'
training_images_path = training_path + 'training_images/'
testing_path = dataset_path + 'subj0{}/test_split/test_fmri/'
image_infos_path = dataset_path + 'image_infos/subj0{}_infos_train.csv'

In [3]:
class MyDataset(Dataset):
  def __init__(self, fmri_data, images_data, transform=None ,device="cpu"):
    self.fmri_data = torch.from_numpy(fmri_data).to(device)
    self.images_data = images_data.to(device)
    
    # self.images_folder = images_folder
    # self.image_paths = [f"{images_folder}/{filename}" for filename in os.listdir(images_folder)]
    self.transform = transform

  def __len__(self):
    return len(self.fmri_data)

  def __getitem__(self, idx):
    fmri = self.fmri_data[idx]
    image = self.images_data[idx]
    # image_path = self.image_paths[idx]
    # image = load_image(image_path)

    # if(self.transform):
    #   image = self.transform(image)

    return fmri, image
transform = transforms.Resize([512, 512])

# Load dataset, now only subj01
lh = np.load(training_path.format(1) + 'training_fmri/lh_training_fmri.npy')
rh = np.load(training_path.format(1) + 'training_fmri/rh_training_fmri.npy')
lrh = np.concatenate((lh, rh), axis=1)

images = torch.load('subj01_image_blip_encoded.pt')['images']

my_dataset = MyDataset(lrh, images, transform=transform)

In [4]:
commandline_args = os.environ.get('COMMANDLINE_ARGS', "--skip-torch-cuda-test --no-half")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device= torch.device('cpu')


In [5]:
imgProcessor = BlipImageProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)

In [6]:
# imageList = []
# for i in range(5000):
#     img = Image.open(f"{training_images_path.format(1)}{i}.png")
#     img = imgProcessor(img)['pixel_values'][0]
#     img = torch.from_numpy(img).to(device)
#     img = encode_img(img,vae)
#     imageList.append(img)
#     # (encode_img(torch.from_numpy(imgProcessor(image)['pixel_values'][0]).to(device
# torch.cat(imageList).shape
# torch.save({
#         'images': torch.cat(imageList),
#         }, 'subj01_image_blip_encoded.pt'
# )

In [7]:
epoch = 0
batch_size = 16
num_epochs = 2000
num_train = 5000
lr_scheduler = 'cycle'
initial_lr = 1e-4
max_lr = 5e-4
random_seed = 42
train_size = 0.9
valid_size = 1 - train_size
num_workers = torch.cuda.device_count()
losses = []
losses_val = []
lrs=[]

In [8]:
generator = torch.Generator().manual_seed(42)
trainset, validset = random_split(my_dataset, [train_size, valid_size], generator=generator)
# build dataloader
train_dataloader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=0)
val_dataloader = DataLoader(validset, batch_size=16, shuffle=False, num_workers=0)

In [9]:
model = Voxel2Blip().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr,
                                            total_steps=num_epochs*((num_train//batch_size)//num_workers),
                                            final_div_factor=1000,
                                            last_epoch=-1, pct_start=2/num_epochs)

In [10]:
checkpoint = torch.load('./ModelsBlip/1000.4_mse')
model.load_state_dict(checkpoint['model_state_dict'])
epoch = checkpoint["epoch"]
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
del checkpoint

In [11]:
progress_bar = tqdm(range(epoch, num_epochs), ncols=150)

  0%|                                                                                                                        | 0/1001 [00:00<?, ?it/s]

In [12]:
for epoch in progress_bar:
    model.train()

    loss_sum = 0
    val_loss_sum = 0

    reconst_fails = []

    for train_i, data in enumerate(train_dataloader):
        voxels, images = data
        voxels = voxels.to(device).float()
        images = images.to(device).float()

        optimizer.zero_grad()
        # run image encoder
        
        encoded_img = images #torch.cat([(encode_img(image.to(device), vae.to(device))) for image in images])
        # MLP forward
        predict = model(voxels)
        # calulate loss
        loss = F.mse_loss(predict, encoded_img)
        loss_sum += loss.item()
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        # backward
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    
        # wandb.log(logs)

        

    # After training one epoch, evaluation
    # save ckpt first
    # print('saving model')
    if (epoch+1) % 100 == 0:
        torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        }, './ModelsBlip/{}'.format(epoch+1)
        )
    # print('model saved')
    model.eval()
    
    for val_i, data in enumerate(val_dataloader):
        voxels, images = data
        voxels = voxels.to(device).float()
        images = images.to(device).float()

        # run image encoder
        
        encoded_img = images
        # MLP forward
        predict = model(voxels)
        # calulate loss
        loss = F.mse_loss(predict, encoded_img)
        losses_val.append(loss.item())
    np.mean(losses_val)
    logs = {
            "train/loss": np.mean(losses[-(train_i+1):]),
            "val/loss": np.mean(losses_val[-(val_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": len(losses)
        }
    progress_bar.set_postfix(**logs)

    # print(logs)

100%|██████████████████████████████| 1001/1001 [5:36:18<00:00, 20.16s/it, train/loss=0.0816, train/lr=0.000288, train/num_steps=282282, val/loss=1.62]
