In [1]:
"""
[Package import]
Import some useful packages here
"""

import clip
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.models import AutoencoderKL
from Models import Voxel2StableDiffusionModel
from transformers import BlipProcessor,BlipForConditionalGeneration
from utils import load_image, to_PIL, transform, decode_img, GetROI
from torchvision import transforms
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-L/14 model together with its matching image preprocessing transform.
model, preprocess = clip.load("ViT-L/14", device=device)

# VAE taken from the "vae" subfolder of the Stable Diffusion v1-4 checkpoint.
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
)
vae = vae.to(device)

In [7]:
def GetPredictions(target_idx, model_paths=None, select_ROIs=None):
    """
    Generate low-level prediction images for the target image "target_idx".

    Every checkpoint is loaded in turn into a single Voxel2StableDiffusionModel,
    the model is run on the ROI data selected for that checkpoint, and the
    predicted latent is decoded with the module-level `vae` and converted by
    to_PIL.

    Note that only subject 1's "training" images are used: the image path is
    hard-coded to ../dataset/subj01/training_split/training_images.
    If you are going to run this code, please make sure the checkpoint paths
    point to where your models are stored.

    Parameters:
    - target_idx: The index number of the image to predict.
    - model_paths: Optional sequence of checkpoint paths used for bagging.
      Defaults to models 92, 94, 97 and 100, chosen since they have better
      performance.
    - select_ROIs: Optional sequence of ROI selectors passed to GetROI, one per
      checkpoint. Defaults to [3, 1, 3, 3].

    Returns:
    - Target image, as produced by transform(load_image(...))
    - List with one prediction image per model, in the order of model_paths

    Raises:
    - ValueError: If model_paths and select_ROIs have different lengths.
    """

    if model_paths is None:
        model_paths = ['../Models/92/199', '../Models/94/199', '../Models/97/199', '../Models/100/199']
    if select_ROIs is None:
        select_ROIs = [3, 1, 3, 3]
    if len(model_paths) != len(select_ROIs):
        raise ValueError(
            f'model_paths and select_ROIs must have the same length, '
            f'got {len(model_paths)} and {len(select_ROIs)}'
        )

    image_path = f'../dataset/subj01/training_split/training_images/{target_idx}.png'
    target_image = transform(load_image(image_path))

    # One model instance is reused; only its weights are swapped per checkpoint.
    # load_state_dict does not change the train/eval flag, so eval() once is enough.
    voxel2sd = Voxel2StableDiffusionModel().to(device)
    voxel2sd.eval()
    low_level_predict = []

    # Pure inference: skip autograd bookkeeping so no graph is kept in memory.
    with torch.no_grad():
        for model_path, roi_id in zip(model_paths, select_ROIs):
            # map_location lets checkpoints saved on a GPU load on a CPU-only
            # machine (device falls back to "cpu" when CUDA is unavailable).
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=device)
            voxel2sd.load_state_dict(checkpoint['model_state_dict'])

            target_ROI = torch.from_numpy(GetROI(target_idx, roi_id)).to(device).float()
            # The model is fed a single flattened sample (batch of one).
            latent = voxel2sd(target_ROI.reshape(1, -1))

            low_level_predict.append(to_PIL(decode_img(latent, vae)[0]))

            # Drop per-checkpoint tensors before the next checkpoint is loaded.
            del checkpoint, target_ROI, latent
    return target_image, low_level_predict

In [4]:
def project_to_CLIP(image):
    """
    Embed an image into CLIP space with the CLIP image encoder.

    Parameters:
    - image: Image to embed; it is first passed through the CLIP `preprocess`
      transform.

    Returns:
    - CLIP image features for a batch that contains only this image
    """
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # Add a leading batch dimension and move the input to the model's device.
        batch = preprocess(image).unsqueeze(0)
        return model.encode_image(batch.to(device))

In [5]:
def similarity(image1_features, image2_features):
    """
    Cosine similarity between the features of two images.

    The similarity is reduced along dimension 0 and converted with .item(),
    so the result must be a single-element tensor.

    Parameters:
    - image1_features: Feature of 1st image
    - image2_features: Feature of 2nd image

    Returns:
    - Similarity of the two images as a Python number
    """
    score = torch.nn.functional.cosine_similarity(
        image1_features, image2_features, dim=0
    )
    return score.item()

In [None]:
"""
[Bagging Process]

Generate images with our models, then perform bagging: among the predictions
for one target, keep the one whose CLIP embedding has the highest cosine
similarity to the CLIP embedding of the target image.
Show (and save) the result as a strip of comparison images.

First image is the original (target) image.
Second image is the best image from the bagging process.
The remaining images are the predictions of each model.

The range of i indicates the images to predict.
"""
from PIL import ImageOps

# Layout of the comparison strip. The offsets assume that every image is
# IMG_SIZE x IMG_SIZE pixels.
IMG_SIZE = 512
BORDER = 20                          # white frame added around every image
GAP = 40                             # extra space before the per-model predictions
TILE = IMG_SIZE + 2 * BORDER         # side length of a framed image
STEP = IMG_SIZE + BORDER             # x distance between neighbouring tiles (frames overlap)
PREDS_X0 = 2 * STEP + BORDER + GAP   # x offset of the first per-model prediction
WHITE = (255, 255, 255)

sim_records = {}
for i in range(201, 300):
    target_image, low_level_predict = GetPredictions(i)
    n_preds = len(low_level_predict)

    # The target is the same for every prediction, so convert and embed it once.
    target_pil = to_PIL(target_image[0])
    target_features = project_to_CLIP(target_pil)[0]

    sims = [
        similarity(project_to_CLIP(pred_image)[0], target_features)
        for pred_image in low_level_predict
    ]
    # Take the true maximum: cosine similarity can be zero or negative, and a
    # "best so far" that starts at 0 would then never select any prediction.
    best_idx = max(range(n_preds), key=sims.__getitem__)
    best_sim = sims[best_idx]

    # Canvas wide enough for target + best + one tile per prediction.
    bg = Image.new('RGB', (PREDS_X0 + STEP * n_preds + BORDER, TILE), '#ffffff')
    for j, pred_image in enumerate(low_level_predict):
        bg.paste(ImageOps.expand(pred_image, BORDER, WHITE), (PREDS_X0 + STEP * j, 0))

    bg.paste(ImageOps.expand(target_pil, BORDER, WHITE), (0, 0))
    bg.paste(ImageOps.expand(low_level_predict[best_idx], BORDER, WHITE), (STEP, 0))

    sim_records[i] = best_sim
    print(sims)
    print(f'Round {i}: {best_sim}')
    bg.save(f'../results/low_level_{i}.png')
    bg.show()

In [31]:
# Mean of the best similarity recorded for each processed image.
# `avg` holds the running total; the mean itself is only printed.
avg = sum(sim_records.values())
print(avg / len(sim_records))

0.5493396577380952
