In [145]:
from skimage.metrics import structural_similarity
import datasets
from diffusers import StableDiffusionPipeline
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [17]:
# Image/caption dataset; its "train" split is both the prompt source and the
# reference set searched by the evaluation loop.
pokedata = datasets.load_dataset("lambdalabs/pokemon-blip-captions")

# Base Stable Diffusion v1.4 in half precision, with the attention-processor
# weights stored under `model_path` loaded into the UNet (a LoRA fine-tune,
# judging by the directory name).
model_path = "sd-pokemon-model-lora"
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
)
pipe.unet.load_attn_procs(model_path)

# Trailing semicolon keeps the returned object's repr out of the cell output.
pipe.to("cuda");

Loading pipeline components...:  57%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                   | 4/7 [00:00<00:00,  8.03it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<

In [143]:
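# Alternative SSIM kept for reference: scores each of the three colour channels
# separately and averages the results (the active version below uses grayscale).
# def SSIM(*images):
#     assert len(images) == 2
#     images = [np.array(image.resize((512, 512))) for image in images]
#     return np.mean([structural_similarity(*[image[:,:,n_channel] for image in images]) for n_channel in range(3)])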

def SSIM(*images):
    """Structural similarity of two images, compared in grayscale at 512x512.

    Args:
        *images: exactly two image objects exposing ``resize`` and ``convert``
            (PIL-style).

    Returns:
        The value produced by ``structural_similarity`` for the two arrays.

    Raises:
        AssertionError: if the number of images is not exactly two.
    """
    assert len(images) == 2
    # Bring both images to a common size and a single channel before comparing.
    first, second = (
        np.array(picture.resize((512, 512)).convert("L")) for picture in images
    )
    return structural_similarity(first, second)

def SIFT(image1, image2, ratio=0.75):
    """Score the similarity of two images via SIFT keypoint matching.

    Both images are converted to grayscale, SIFT keypoints/descriptors are
    extracted, and descriptors of ``image1`` are matched against those of
    ``image2`` with a brute-force 2-nearest-neighbour search. A match counts
    as "good" when its best distance is below ``ratio`` times the second-best
    distance.

    Args:
        image1: image object exposing ``convert("L")`` (PIL-style).
        image2: image object exposing ``convert("L")`` (PIL-style).
        ratio: threshold of the ratio test; defaults to 0.75.

    Returns:
        float: number of good matches divided by the larger keypoint count of
        the two images. ``0.0`` when either image yields no descriptors.
    """
    sift = cv2.SIFT_create()

    keypoints1, descriptors1 = sift.detectAndCompute(np.array(image1.convert("L")), None)
    keypoints2, descriptors2 = sift.detectAndCompute(np.array(image2.convert("L")), None)

    # Featureless images (e.g. blank ones) can come back with no descriptors
    # (None or empty); matching would then raise, and the denominator below
    # could be zero. Nothing can match, so the score is 0.
    if (
        descriptors1 is None
        or descriptors2 is None
        or len(descriptors1) == 0
        or len(descriptors2) == 0
    ):
        return 0.0

    matches = cv2.BFMatcher().knnMatch(descriptors1, descriptors2, k=2)

    good_matches = 0
    for candidates in matches:
        # With fewer than two neighbours available there is no second-best
        # distance, so the ratio test cannot be applied.
        if len(candidates) < 2:
            continue
        best, runner_up = candidates[0], candidates[1]
        if best.distance < ratio * runner_up.distance:
            good_matches += 1

    # Non-empty descriptors imply at least one keypoint, so this is >= 1.
    return good_matches / max(len(keypoints1), len(keypoints2))

In [125]:
def is_all_black(image):
    """Report whether every value of ``image`` is zero.

    Args:
        image: anything ``np.array`` can convert (image object, array, nested list).

    Returns:
        A truthy NumPy boolean when all values equal 0 (also for empty input),
        falsy otherwise.
    """
    pixels = np.array(image)
    return (pixels == 0).all()

In [None]:
n = 10         # number of training captions to generate an image for
scorer = SIFT  # similarity function, called as scorer(reference_image, generated_image)

# Read the training split once. Slicing it yields one column per key; as in the
# per-row unpacking used before, the columns are assumed to be (images, captions)
# in that order. Doing this inside the loop re-read the whole split per prompt.
train_images, train_texts = pokedata['train'][:].values()

for pokeimage, poketext in zip(train_images[:n], train_texts[:n]):
    print(poketext)

    pokepredimages = pipe(
        poketext + " with white background",
        num_inference_steps=30,
        guidance_scale=7.5,
    ).images

    # Use the first output that is not entirely black (presumably blanked by
    # the pipeline's content filter); skip the prompt if none is usable.
    pokepredimage = next(
        (image for image in pokepredimages if not is_all_black(image)), None
    )
    if pokepredimage is None:
        print("All images were filtered")
        continue

    # Nearest training image to the generated one under `scorer`. `max` keeps
    # the first of equally scored candidates, like the stable descending sort
    # it replaces, without holding every (score, image) pair in memory.
    pred_score, pred_image = max(
        (
            (scorer(candidate, pokepredimage), candidate)
            for candidate in tqdm(train_images, desc="Scoring train set")
        ),
        key=lambda scored: scored[0],
    )

    # Left: image the caption belongs to; middle: generated image;
    # right: best-scoring training image.
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    axes[0].imshow(pokeimage)
    # `.3f` = three decimals; the previous `:3f` only set a minimum width.
    axes[0].set_title(f"Score: {scorer(pokeimage, pokepredimage):.3f}")

    axes[1].imshow(pokepredimage)
    axes[1].set_title("Generated")

    axes[2].imshow(pred_image)
    axes[2].set_title(f"Score: {pred_score:.3f}")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

a drawing of a green pokemon with red eyes


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.46it/s]
539it [01:06,  5.91it/s]