In [1]:
"""
[Package import]
Import some useful packages here
"""

import numpy as np
import torch
import wandb
from tqdm import tqdm
from utils import encode_img, GetRoiMaskedLR
from Models import Voxel2StableDiffusionModel, MyDataset
import torch.nn.functional as F
from diffusers.models import AutoencoderKL
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

In [2]:
"""
[Select Torch Device]
"""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
"""
[Hyperparameters]
"""

train_size = 0.7
valid_size = 1 - train_size
num_workers = torch.cuda.device_count()

# We normally only modify the following hyperparameters
random_seed = 42
ROI_num = 1 # For ROI Mask. [1, 2, 3]

In [4]:
"""
[Path information]
"""

dataset_path = '../dataset/'
training_path = dataset_path + 'subj0{}/training_split/'
training_fmri_path = training_path + 'training_fmri/'
training_images_path = training_path + 'training_images/'
training_VAE_laten_path = training_path + 'fMRI_VAE.npy'
testing_path = dataset_path + 'subj0{}/test_split/test_fmri/'

"""
[Load VAE]
"""

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)

"""
[ROI-Masks]
"""

ROIs_test1 = ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4", "FFA-1", "FFA-2", "PPA"]
ROIs_test2 = ["EBA", "FBA-1", "FBA-2", "mTL-bodies", "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces", "OPA", "PPA", "RSC"]
ROIs_test3 = ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]
ROIs_tests = [0, ROIs_test1, ROIs_test2, ROIs_test3]

In [5]:
def project_to_CLIP(image, clip_model=None, clip_preprocess=None, target_device=None):
    """
    Project an image into CLIP feature space.

    Parameters:
    - image: Image to project, in whatever form `clip_preprocess` accepts
      (presumably a PIL image -- confirm against the CLIP preprocess in use).
    - clip_model: Object exposing `encode_image`. Defaults to the notebook-global `model`.
    - clip_preprocess: Callable turning `image` into an unbatched tensor. Defaults to the
      notebook-global `preprocess`.
    - target_device: Device the preprocessed tensor is moved to. Defaults to the
      notebook-global `device`.

    Returns:
    - Output of `clip_model.encode_image` for a batch holding this single image
      (a batch dimension of size 1 is added before encoding).
    """
    # The globals `model` and `preprocess` are not created by any visible cell; accepting them
    # explicitly keeps the function usable without relying on hidden kernel state.
    if clip_model is None:
        clip_model = model
    if clip_preprocess is None:
        clip_preprocess = preprocess
    if target_device is None:
        target_device = device

    with torch.no_grad():
        batch = clip_preprocess(image).unsqueeze(0).to(target_device)
        features = clip_model.encode_image(batch)
    return features

def similarity(image1_features, image2_features):
    """
    Calculate the cosine similarity of two image feature tensors.

    Both inputs are flattened first, so plain vectors of shape (D,) and batch-of-one
    features of shape (1, D), such as those produced by `project_to_CLIP`, both work.

    Parameters:
    - image1_features: Feature tensor of the 1st image.
    - image2_features: Feature tensor of the 2nd image (same number of elements).

    Returns:
    - Cosine similarity of the two images as a Python float.
    """
    # Without flattening, CosineSimilarity(dim=0) on (1, D) inputs compares each of the D
    # columns over a length-1 axis, yielding D values, and `.item()` then raises.
    cos = torch.nn.CosineSimilarity(dim=0)
    return cos(image1_features.flatten(), image2_features.flatten()).item()

In [7]:
"""
[Evaluate trained checkpoints]
Reload each Voxel2StableDiffusion checkpoint, rebuild a train/valid split from its seed,
and report the mean per-batch L1 loss on both splits.
"""

model_paths = ['../Models/92/199', '../Models/94/199', '../Models/97/199', '../Models/100/199']
select_ROIs = [3, 1, 3, 3]            # ROI scheme per checkpoint (passed to GetRoiMaskedLR)
random_seeds = [3, 2423, 849, 1921]   # split seed per checkpoint -- presumably the training-time seed; confirm
subject_id = 1                        # fills the 'subj0{}' placeholder of the latent path
eval_batch_size = 64
assert len(model_paths) == len(select_ROIs) == len(random_seeds), 'per-checkpoint settings are misaligned'


def evaluate_l1_losses(model, dataloader):
    """
    Compute the L1 loss of `model` on every batch of `dataloader`.

    Runs under `torch.no_grad()`: this is evaluation only, so no autograd graph is built.

    Parameters:
    - model: Module mapping a batch of voxels to predicted latents.
    - dataloader: Iterable yielding `(voxels, image_latents)` pairs.

    Returns:
    - List of per-batch L1 losses as Python floats.
    """
    batch_losses = []
    with torch.no_grad():
        for voxels, image_latents in dataloader:
            voxels = voxels.to(device).float()
            image_latents = image_latents.to(device).float()
            # Concatenate the per-sample latent tensors along their first dimension.
            target_latent = torch.cat([x for x in image_latents])
            predict_latent = model(voxels)
            batch_losses.append(F.l1_loss(target_latent, predict_latent).item())
    return batch_losses


voxel2sd = Voxel2StableDiffusionModel().to(device)
transform = transforms.Resize([512, 512])  # loop-invariant, so built once
train_losses = {}
train_loss_avgs = {}
valid_losses = {}
valid_loss_avgs = {}

for model_idx, (model_path, roi_scheme, split_seed) in enumerate(zip(model_paths, select_ROIs, random_seeds)):
    print(f'Processing model {model_path}')
    # map_location lets a checkpoint saved from a GPU also load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    voxel2sd.load_state_dict(checkpoint['model_state_dict'])
    voxel2sd.eval()

    # Build the data loaders for this checkpoint's ROI scheme and split seed.
    lrh = GetRoiMaskedLR(roi_scheme, dataset_path)
    my_dataset = MyDataset(lrh, training_VAE_laten_path.format(subject_id), transform=transform)
    generator = torch.Generator().manual_seed(split_seed)
    trainset, validset = random_split(my_dataset, [train_size, valid_size], generator=generator)
    train_dataloader = DataLoader(trainset, batch_size=eval_batch_size, shuffle=False, num_workers=num_workers)
    val_dataloader = DataLoader(validset, batch_size=eval_batch_size, shuffle=False, num_workers=num_workers)

    # NOTE: the averages are unweighted means of per-batch means, so a smaller final
    # batch counts as much as a full one.
    train_losses[model_idx] = evaluate_l1_losses(voxel2sd, train_dataloader)
    train_loss_avgs[model_idx] = np.mean(train_losses[model_idx])
    print(f'Average loss on training set: {train_loss_avgs[model_idx]}')

    valid_losses[model_idx] = evaluate_l1_losses(voxel2sd, val_dataloader)
    valid_loss_avgs[model_idx] = np.mean(valid_losses[model_idx])
    print(f'Average loss on validation set: {valid_loss_avgs[model_idx]}')

Processing model ../Models/92/199
Average loss on training set: 0.3989420733668587
Average loss on validation set: 0.6483591025074323
Processing model ../Models/94/199
Average loss on training set: 0.39130232659253206
Average loss on validation set: 0.659694088002046
Processing model ../Models/97/199
Average loss on training set: 0.4028529248454354
Average loss on validation set: 0.6492126733064651
Processing model ../Models/100/199
Average loss on training set: 0.370662676746195
Average loss on validation set: 0.6589636156956354
