In [None]:
from typing import Callable

import datasets
import numpy as np
import PIL.Image
import PIL.ImageFilter
import torch
from matplotlib import pyplot as plt
from tqdm.autonotebook import tqdm, trange

from sdhelper import SD

In [None]:
# Seed every RNG the notebook touches so Restart & Run All reproduces the results:
# `noise_img` draws from the global NumPy RNG; torch is seeded in case
# `sd.img2repr` samples noise internally (presumably — TODO confirm in sdhelper).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

sd = SD('SD1.5')
# NOTE(review): trust_remote_code=True executes the dataset repository's loading
# script locally; keep it only if that repository is trusted (consider pinning `revision=`).
# `pairs` refer to images in `data` by position (src_data_index / trg_data_index).
# NOTE(review): `data` uses split='train' while `pairs` uses split='test'; presumably the
# 'data' config stores all images in a single split — confirm.
data = datasets.load_dataset('0jl/SPair-71k', 'data', split='train', trust_remote_code=True)
pairs = datasets.load_dataset('0jl/SPair-71k', 'pairs', split='test', trust_remote_code=True)

In [None]:
def sc(transform_img: Callable, threshold: float = 0.1) -> float:
    """Measure semantic-correspondence accuracy (PCK@bbox) under an image transform.

    Every image in the global `data` is passed through `transform_img` and encoded
    once with `sd.img2repr` (block 'up_blocks[1]', positional argument 100 — presumably
    the diffusion timestep, confirm in sdhelper — and a prompt taken from the first
    path component of the image `name`, presumably its object category).
    For each keypoint of each pair in the global `pairs`, the source feature at the
    keypoint is matched against the whole target feature map by cosine similarity.
    The best-matching cell is mapped back to target pixel coordinates. The prediction
    is correct when it lies within `threshold * max(bbox width, bbox height)` of the
    annotated target keypoint.

    Relies on the notebook globals `sd`, `data` and `pairs`.

    Args:
        transform_img: function applied to every image before encoding. It must
            preserve the image size, because keypoints are scaled with the
            untransformed image sizes stored in `pairs`.
        threshold: PCK threshold relative to the larger side of the target bounding
            box (default 0.1, i.e. PCK@0.1_bbox).

    Returns:
        PCK in percent over all keypoints, or nan if no keypoints were evaluated.
        The value is also printed.
    """
    # Encode every image once; the pairs below index into this list.
    representations = [
        sd.img2repr(transform_img(x['img']), ['up_blocks[1]'], 100, prompt=x['name'].split('/')[0])
        for x in tqdm(data, desc='Calculating representations')
    ]

    # Running counts instead of a growing list: O(1) per keypoint to report progress.
    n_correct = 0
    n_total = 0
    for pair in (t := tqdm(pairs, desc='Calculating SC')):
        # Features are indexed as [channel, row, col]. L2-normalise over channels so a
        # dot product is a cosine similarity. Out-of-place, so the stored
        # representations are never mutated.
        a = torch.nn.functional.normalize(representations[pair['src_data_index']].concat(), dim=0)
        b = torch.nn.functional.normalize(representations[pair['trg_data_index']].concat(), dim=0)
        _, a_h, a_w = a.shape
        _, b_h, b_w = b.shape
        src_w, src_h = pair['src_img'].size
        trg_w, trg_h = pair['trg_img'].size
        x0, y0, x1, y1 = pair['trg_bndbox']
        bbox_size = max(x1 - x0, y1 - y0)

        for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps']):
            # Feature cell containing the source keypoint; clamp keypoints lying on the far border.
            row = min(int(sy * a_h // src_h), a_h - 1)
            col = min(int(sx * a_w // src_w), a_w - 1)
            cossim = torch.einsum('c,chw->hw', a[:, row, col], b)
            # Flat argmax -> (row, col) of the row-major (b_h, b_w) similarity map.
            y_max, x_max = divmod(int(cossim.argmax()), b_w)
            # Map the winning cell back to target pixels using its centre, not its top-left corner.
            x_pred = (x_max + 0.5) * trg_w / b_w
            y_pred = (y_max + 0.5) * trg_h / b_h
            dist = np.hypot(x_pred - tx, y_pred - ty)
            n_correct += bool(dist / bbox_size < threshold)
            n_total += 1

        if n_total:
            # refresh=False: shown on tqdm's next scheduled redraw, so this is cheap per pair.
            t.set_postfix(pck=100 * n_correct / n_total, refresh=False)

    pck = 100 * n_correct / n_total if n_total else float('nan')
    print(f'PCK@{threshold}_bbox: {pck:.2f}%')
    return pck

        
def transform_img(img):
    """Identity transform: return the given image object unchanged.

    Usable as a no-op `transform_img` argument to `sc`.
    """
    unchanged = img
    return unchanged


In [None]:

def noise_img(img, noise_level, rng=None):
    """Return a copy of `img` with additive Gaussian pixel noise.

    Args:
        img: PIL image (any array-like accepted by `np.asarray` works for the array step).
        noise_level: standard deviation of the noise, in 0-255 pixel units.
        rng: optional `np.random.Generator` for reproducible noise. When None, the
            global `np.random` state is used, so seed it for reproducibility.

    Returns:
        A new uint8 PIL image with noisy values rounded and clipped to [0, 255].
    """
    # Convert once; the original converted the image to an array twice.
    pixels = np.asarray(img)
    sampler = np.random if rng is None else rng
    noise = sampler.normal(0, noise_level, pixels.shape)
    # Round before the uint8 cast: plain truncation would bias the zero-mean noise by about -0.5.
    noisy = np.clip(np.rint(pixels + noise), 0, 255)
    return PIL.Image.fromarray(noisy.astype(np.uint8))

# Sweep Gaussian-noise strength and run the SC evaluation once per level.
for noise_level in (0, 5, 50, 100, 200):
    print(f'Noise level: {noise_level}')
    # Bind the current level as a default argument so the transform never late-binds the loop variable.
    sc(lambda img, level=noise_level: noise_img(img, level))


In [None]:
def blur_img(img, blur_radius):
    """Return a Gaussian-blurred copy of a PIL image.

    `PIL.ImageFilter` is a separate submodule that `import PIL.Image` does not load.
    It must be imported explicitly at the top of the notebook, otherwise accessing
    `PIL.ImageFilter` can raise AttributeError.

    Args:
        img: PIL image to blur.
        blur_radius: Gaussian blur radius in pixels (0 leaves the image unchanged).

    Returns:
        A new, blurred PIL image.
    """
    return img.filter(PIL.ImageFilter.GaussianBlur(blur_radius))

# Sweep Gaussian-blur radius and run the SC evaluation once per radius.
for blur_radius in (0, 2, 5, 20, 50):
    print(f'Blur radius: {blur_radius}')
    # Bind the current radius as a default argument so the transform never late-binds the loop variable.
    sc(lambda img, radius=blur_radius: blur_img(img, radius))

In [None]:
# Preview: the first image with the strongest noise level used in the sweep.
noise_img(img=data[0]['img'], noise_level=200)

In [None]:
# Preview: the first image blurred with a moderate radius from the sweep.
blur_img(img=data[0]['img'], blur_radius=5)