In [None]:
"""
[Package import]
Import some useful packages here
"""

import clip
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.models import AutoencoderKL
from Models import Voxel2StableDiffusionModel, MyDataset
from transformers import BlipProcessor,BlipForConditionalGeneration
from utils import load_image, to_PIL, transform, decode_img, GetROI, GetRoiMaskedLR, DoOneRoiMask
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import warnings
warnings.filterwarnings('ignore')

In [None]:
"""
[Load models]
"""

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
modelGeneration = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
sdPipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to(device)

"""
[Hyperparameters and arguments]
"""
train_size = 0.7
valid_size = 1 - train_size
num_workers = torch.cuda.device_count()
ROIs_test1 = ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4", "FFA-1", "FFA-2", "PPA"]
ROIs_test2 = ["EBA", "FBA-1", "FBA-2", "mTL-bodies", "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces", "OPA", "PPA", "RSC"]
ROIs_test3 = ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"]
ROIs_tests = [0, ROIs_test1, ROIs_test2, ROIs_test3]

"""
[Path information]
"""

dataset_path = '../dataset/'
training_path = dataset_path + 'subj0{}/training_split/'
training_fmri_path = training_path + 'training_fmri/'
training_images_path = training_path + 'training_images/'
training_VAE_laten_path = training_path + 'fMRI_VAE.npy'
testing_path = dataset_path + 'subj0{}/test_split/test_fmri/'

In [None]:
def project_to_CLIP(image):
    """
    Encode an image into CLIP image-embedding space.

    Parameters:
    - image: Image to encode; it is passed through CLIP's ``preprocess`` transform.

    Returns:
    - CLIP image features with a leading batch dimension of size 1, on ``device``.
    """
    with torch.no_grad():
        # `[None]` adds the batch axis CLIP's encoder expects.
        batch = preprocess(image)[None].to(device)
        return model.encode_image(batch)

def similarity(image1_features, image2_features):
    """
    Cosine similarity between two image feature vectors.

    Parameters:
    - image1_features: Feature vector of the first image (un-batched)
    - image2_features: Feature vector of the second image (un-batched)

    Returns:
    - The cosine similarity as a Python float
    """
    # dim=0 because each argument is a single vector, not a batch of vectors.
    score = torch.nn.functional.cosine_similarity(image1_features, image2_features, dim=0)
    return score.item()

In [None]:
# Evaluate each trained voxel->Stable-Diffusion checkpoint: reconstruct images from
# ROI-masked fMRI voxels, then score each reconstruction against the ground-truth
# image with CLIP cosine similarity, on both the training and validation splits.

model_paths = ['../Models/92/199', '../Models/94/199', '../Models/97/199', '../Models/100/199']
select_ROIs = [3, 1, 3, 3]             # ROI-set number (index into ROIs_tests) per checkpoint
# Seed for each checkpoint's train/valid split — presumably the seed used when that
# checkpoint was trained, so the validation split is truly held out (TODO confirm).
random_seeds = [3, 2423, 849, 1921]
MAX_SAMPLES_PER_SPLIT = 100            # the full pipeline is slow; score only the first N samples
SUBJECT_NUM = 1
# Columns [:LH_NUM_VERTICES] are masked as the left hemisphere ('l'), the rest as the right ('r').
# NOTE(review): 19004 is presumably subject 1's left-hemisphere vertex count — confirm for other subjects.
LH_NUM_VERTICES = 19004

voxel2sd = Voxel2StableDiffusionModel().to(device)
sims_train_all = {}
avgs_train_all = {}
sims_valid_all = {}
avgs_valid_all = {}


def _mask_voxel(voxel, select_ROI):
    """Keep only the vertices of the regions in ``ROIs_tests[select_ROI]``; return a float tensor on ``device``."""
    lh = voxel[:, :LH_NUM_VERTICES]
    rh = voxel[:, LH_NUM_VERTICES:]
    lhMasked = np.zeros_like(lh, dtype=float)
    rhMasked = np.zeros_like(rh, dtype=float)
    for region in ROIs_tests[select_ROI]:
        # DoOneRoiMask presumably fills `maskedFmri` in place — the arrays are read back below.
        DoOneRoiMask(data_path=dataset_path, subject_num=SUBJECT_NUM, data=lh, LR='l', roi=region, maskedFmri=lhMasked)
        DoOneRoiMask(data_path=dataset_path, subject_num=SUBJECT_NUM, data=rh, LR='r', roi=region, maskedFmri=rhMasked)
    return torch.from_numpy(np.concatenate((lhMasked, rhMasked), axis=1)).to(device).float()


def _reconstruct_and_score(voxel, image_latent, select_ROI, sd_generator):
    """Reconstruct one image from fMRI voxels and return its CLIP similarity to the ground-truth image."""
    # Masking happens on the CPU copy that the DataLoader yields; no GPU round trip needed.
    masked_voxel = _mask_voxel(voxel.float(), select_ROI)

    target_image = to_PIL(decode_img(image_latent.to(device).float()[0], vae)[0])
    target_image_CLIP = project_to_CLIP(target_image)[0]

    # Coarse reconstruction: voxels -> SD latent -> VAE-decoded image.
    latent_image = to_PIL(decode_img(voxel2sd(masked_voxel), vae)[0])

    # Caption the coarse image with BLIP. The processor applies BLIP's own resize,
    # rescale and normalization; feeding raw uint8 pixels skips all of that.
    blip_inputs = processor(images=latent_image, return_tensors="pt").to(device, modelGeneration.dtype)
    caption_ids = modelGeneration.generate(**blip_inputs)
    prompt = processor.decode(caption_ids[0], skip_special_tokens=True).replace('blurry ', '')

    # Refine the coarse image with img2img guided by the caption.
    predict_image = sdPipe(prompt=prompt, image=latent_image, strength=0.75, guidance_scale=6,
                           generator=sd_generator).images[0]
    predict_image_CLIP = project_to_CLIP(predict_image)[0]
    return similarity(target_image_CLIP, predict_image_CLIP)


@torch.no_grad()  # inference only: don't build autograd graphs through voxel2sd / the VAE
def _evaluate_split(dataloader, select_ROI, sd_generator, max_samples=MAX_SAMPLES_PER_SPLIT):
    """Return per-sample CLIP similarities for the first ``max_samples`` items of ``dataloader``."""
    sims = []
    for sample_idx, (voxel, image_latent) in enumerate(dataloader):
        if sample_idx >= max_samples:
            break
        sims.append(_reconstruct_and_score(voxel, image_latent, select_ROI, sd_generator))
    return sims


for model_idx, (model_path, select_ROI, random_seed) in enumerate(zip(model_paths, select_ROIs, random_seeds)):
    print(f'Processing model {model_path}')
    # map_location lets GPU-saved checkpoints load on whatever device is available.
    checkpoint = torch.load(model_path, map_location=device)
    voxel2sd.load_state_dict(checkpoint['model_state_dict'])
    voxel2sd.eval()

    # Build data loaders. Use a local name for the resize so the imported
    # `utils.transform` is not shadowed for later cells.
    resize_512 = transforms.Resize([512, 512])
    lrh = GetRoiMaskedLR(select_ROI, dataset_path)
    my_dataset = MyDataset(lrh, training_VAE_laten_path.format(SUBJECT_NUM), transform=resize_512)
    split_generator = torch.Generator().manual_seed(random_seed)
    trainset, validset = random_split(my_dataset, [train_size, valid_size], generator=split_generator)
    train_dataloader = DataLoader(trainset, batch_size=1, shuffle=False, num_workers=num_workers)
    val_dataloader = DataLoader(validset, batch_size=1, shuffle=False, num_workers=num_workers)

    for split_name, dataloader, sims_all, avgs_all in (
        ('training', train_dataloader, sims_train_all, avgs_train_all),
        ('validation', val_dataloader, sims_valid_all, avgs_valid_all),
    ):
        # Seeded so the img2img sampling noise (and thus the score) is reproducible.
        sd_generator = torch.Generator(device=device).manual_seed(random_seed)
        sims = _evaluate_split(dataloader, select_ROI, sd_generator)
        # NaN instead of ZeroDivisionError if a split turns out to be empty.
        average = sum(sims) / len(sims) if sims else float('nan')
        sims_all[model_idx] = sims
        avgs_all[model_idx] = average
        print(f'Average similarity on {split_name} set: {average}')

In [None]:
# Mean CLIP similarity per checkpoint (train split first, then validation),
# keyed by the checkpoint's index in `model_paths`. Per-sample scores live in
# `sims_train_all` / `sims_valid_all` if a closer look is needed.
for split_averages in (avgs_train_all, avgs_valid_all):
    print(split_averages)