In [1]:
import torch
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, KandinskyV22PriorPipeline
from accelerate.utils import set_seed
from PIL import Image
import pandas as pd
import os

from torchmetrics.multimodal.clip_score import CLIPScore
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image import StructuralSimilarityIndexMeasure, MultiScaleStructuralSimilarityIndexMeasure
from tqdm.notebook import tqdm

from src.dataset import MetricsDataset

from src.ip_adapter import IPAdapter

In [2]:
set_seed(2204)

In [3]:
base_model_path = 'dreamlike-art/dreamlike-anime-1.0'
vae_model_path = "stabilityai/sd-vae-ft-mse"
image_encoder_path = "/home/chaichuk/Annual_Project/IP-Adapter/models/image_encoder"
prior_model_path = "kandinsky-community/kandinsky-2-2-prior"
ip_ckpt = "/home/chaichuk/Annual_Project/Team73-Annual-Project/weights/512_res_model/checkpoint-100/ip_adapter.bin"
device = "cuda:0"

In [4]:
def collate_fn(data):
    pil_face_image = [example["pil_face_image"] for example in data]
    pt_face_image = torch.stack([example["pt_face_image"] for example in data])
    anime_image = torch.stack([example["anime_image"] for example in data])
    text = [example["text"] for example in data]

    return {
        "pil_face_image": pil_face_image,
        "pt_face_image": pt_face_image,
        "anime_image": anime_image,
        "text": text
    }

dataset = MetricsDataset(num_samples=5000)

dataloader = DataLoader(
        dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=50,
        num_workers=4,
        pin_memory=True
)

In [5]:
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)

In [6]:
pipeline = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None
)

pipeline.set_progress_bar_config(disable=True)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [7]:
# pipe_prior = KandinskyV22PriorPipeline.from_pretrained(prior_model_path, torch_dtype=torch.float16).to(device)
# pipe_prior.set_progress_bar_config(disable=True)

In [8]:
ip_model = IPAdapter(pipeline, image_encoder_path, ip_ckpt, device)

  state_dict = torch.load(self.ip_ckpt, map_location="cpu")


In [9]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [10]:
clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16").to(device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0).to(device)
fid = FrechetInceptionDistance(normalize=True).to(device)

for batch in tqdm(dataloader):
    texts = batch["text"]
    pil_face_images = batch["pil_face_image"]
    pt_face_images = batch["pt_face_image"]
    anime_images = batch["anime_image"]

    fid.update(anime_images.to(device), real=True)

    # clip_t2i_embeds = pipe_prior(texts).image_embeds
    # t2i_images = ip_model.generate(clip_image_embeds=clip_t2i_embeds, num_samples=1, num_inference_steps=50, scale=0.7, height=512, width=512, output_type='pt')

    images = ip_model.generate(pil_image=pil_face_images, prompt='one person\'s face, hand-drawn anime style', num_samples=1, num_inference_steps=50, height=512, width=512, scale=0.7, output_type='pt')
    ssim.update(images.to(torch.float32), pt_face_images.to(device))
    ms_ssim.update(images.to(torch.float32), pt_face_images.to(device))
    fid.update(images.to(torch.float32), real=False)
    clip_score.update(images, texts)


print('FID:', float(fid.compute()))
print('SSIM:', float(ssim.compute()))
print('MS_SSIM:', float(ms_ssim.compute()))
print('CLIP Score:', float(clip_score.compute()))

  0%|          | 0/100 [00:00<?, ?it/s]

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
Token indices sequence length is longer than the specified maximum sequence length for this model (87 > 77). Running this sequence through the model will result in indexing errors


FID: tensor(120.8018, device='cuda:0')
SSIM: tensor(0.3272, device='cuda:0')
MS_SSIM: tensor(0.2138, device='cuda:0')
CLIP Score: tensor(22.4388, device='cuda:0')
