In [None]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
from trainplot.trainplot import TrainPlotPlotlyExperimental as TrainPlot
import random
from collections import defaultdict
from typing import Literal, Callable
from functools import partial, cache
def ceil(x):
    '''Return the ceiling of `x` as a plain Python int (computed with numpy, so numpy scalars work too).'''
    return int(np.ceil(x))

In [None]:
# Load the 'sd1.5' model wrapper (presumably Stable Diffusion 1.5); every representation below is extracted through `sd`.
sd = SD('sd1.5')

print(f'Using {sd.model_name} model, available extract positions: {sd.available_extract_positions}')
# SPair-71k test split: each row is a source/target image pair with keypoints, bounding boxes and a category (see sc()).
dataset_pairs: datasets.Dataset = datasets.load_dataset('0jl/SPair-71k', trust_remote_code=True, split='test')
# NOTE(review): not referenced anywhere else in this section — presumably the name of a dataset of cached representations; verify or remove.
repr_dataset_name = f'{sd.model_name}-SPair-71k-repr'

def expand_and_resize(x: PIL.Image.Image, size = 960, border_pad=True):
    '''Pad an image to a centred square and resize it to `size` x `size`.

    Args:
        x: input image; `x.size` is (width, height).
        size: side length of the returned square image.
        border_pad: if True, the padding bands are filled by stretching the image's
            first/last row (wide images) or first/last column (tall images) outwards;
            otherwise the padding stays black.

    Returns:
        RGB image of size (size, size). Use `expand_and_resize_keypoint` to map
        keypoints into the same frame.
    '''
    n, m = x.size  # width, height
    s = max(n, m)
    # The image is centred; when the size difference is odd, the extra pixel belongs to the
    # bottom/right band, so the two bands can differ in thickness by one pixel.
    left, top = (s - n) // 2, (s - m) // 2
    right, bottom = s - n - left, s - m - top
    r = PIL.Image.new('RGB', (s, s))
    r.paste(x, (left, top))
    if border_pad:
        # Only paint non-empty bands: PIL rejects resizing to a zero width/height.
        if n > m:
            if top > 0:
                r.paste(x.crop((0, 0, n, 1)).resize((n, top)), (0, 0))
            if bottom > 0:
                r.paste(x.crop((0, m - 1, n, m)).resize((n, bottom)), (0, top + m))
        elif m > n:
            if left > 0:
                r.paste(x.crop((0, 0, 1, m)).resize((left, m)), (0, 0))
            if right > 0:
                r.paste(x.crop((n - 1, 0, n, m)).resize((right, m)), (left + n, 0))
    return r.resize((size, size))

def expand_and_resize_keypoint(x, y, n, m, o, p):
    '''Map keypoint (x, y) of an n x m (width x height) image into the o x p frame produced by expand_and_resize.

    The keypoint is first offset by the centring padding of the max(n, m) square, then
    scaled from that square to the output width `o` and height `p`.
    '''
    side = max(n, m)
    pad_x = (side - n) // 2
    pad_y = (side - m) // 2
    new_x = (x + pad_x) * o / side
    new_y = (y + pad_y) * p / side
    return new_x, new_y

def expand(x: PIL.Image.Image, size = 960):
    '''Rescale an image so its shorter side becomes `size` (aspect ratio kept, new dimensions truncated to int).'''
    width, height = x.size
    scale = size / min(x.size)
    new_size = (int(width * scale), int(height * scale))
    return x.resize(new_size)

def expand_keypoint(x, y, n, m, o, p):
    '''Scale keypoint (x, y) from an n x m (width x height) image to an o x p image.'''
    scaled_x = x * o / n
    scaled_y = y * p / m
    return scaled_x, scaled_y

def get_transforms(name, size=960):
    '''Look up an (image_transform, keypoint_transform) pair by name.

    Args:
        name: 'expand_and_resize', 'expand', or None (identity transforms).
        size: target size forwarded to the image transform.

    Returns:
        Tuple (image_transform, keypoint_transform). The keypoint transform takes
        (x, y, orig_width, orig_height, new_width, new_height).

    Raises:
        ValueError: for an unknown name.
    '''
    if name == 'expand_and_resize':
        return partial(expand_and_resize, size=size), expand_and_resize_keypoint
    if name == 'expand':
        return partial(expand, size=size), expand_keypoint
    if name is None:
        identity_image = lambda x, *args, **kwargs: x
        identity_keypoint = lambda x, y, n, m, o, p: (x, y)
        return identity_image, identity_keypoint
    raise ValueError(f'Unknown transform name: {name}')

In [None]:
@cache
def get_empty_repr(pos: tuple[str, ...], size: tuple[int, int]):
    '''Mean representation of an all-black image, memoized per (pos, size).

    Args:
        pos: extraction positions. Must be hashable because of functools.cache:
            pass a tuple (a list raises "TypeError: unhashable type").
        size: (width, height) of the blank image.

    Returns:
        dict with the keys of sd.img2repr's output; each value is the element-wise
        mean over 20 extractions at step 999 (stored on the CPU).
    '''
    # NOTE(review): presumably repeated 20 times because img2repr is stochastic (noise at step 999) — confirm.
    empty_reprs = []
    for _ in range(20):
        empty_repr = sd.img2repr(PIL.Image.new('RGB', size), pos, 999, output_device='cpu')
        assert isinstance(empty_repr, dict)
        empty_reprs.append(empty_repr)
    return {k: torch.mean(torch.stack([x[k] for x in empty_reprs]), dim=0) for k in empty_reprs[0]}



def get_zoomed_reprs(img: PIL.Image.Image, pos: list[str], step: int = 1):
    '''Representations of `img` at 13 zoom levels, from 2**-3 to 2**3 in half-octave steps.

    Each zoomed copy is passed through sd.img2repr and the per-position outputs are
    merged with concat_reprs. Returns the list of merged tensors, smallest zoom first.
    '''
    width, height = img.size

    def repr_at(half_octave: int) -> torch.Tensor:
        scale = 2**(half_octave/2)
        resized = img.resize((int(width*scale), int(height*scale)))
        out = sd.img2repr(resized, pos, step, output_device='cpu')
        assert isinstance(out, dict)
        return concat_reprs(out)

    return [repr_at(half_octave) for half_octave in range(-6, 7)]



# Module-level memo of get_shifted_reprs() results. sc() both fills and reads it and decides the key format.
# It is never cleared, so it grows with every new key sc() inserts.
shifted_reprs_cache = {}
def get_shifted_reprs(img: PIL.Image.Image, pos: list[str], step: int, skip: int = 1, smothing=False) -> torch.Tensor:
    '''Dense representation built by extracting features from shifted copies of an image.

    For every position in `pos`, the image is shifted by sub-cell offsets (every `skip`
    pixels inside one representation cell of `tile_size` pixels). sd.img2repr is run on
    each shifted copy, and each output cell is written back (approximately) at the pixel
    it is centred on. The full-resolution result is then subsampled with stride `skip`,
    starting at pixel skip//2.

    Args:
        img: input image; assumed RGB so np.array(img) is (height, width, 3) uint8 — TODO confirm for other modes.
        pos: extraction positions; their channels are stacked in this order.
        step: diffusion step passed to sd.img2repr.
        skip: stride in pixels of both the shifts and the output grid.
        smothing: if True, apply a 3x3 binomial blur to every channel of the result.

    Returns:
        CPU tensor (total_channels, ceil((height - skip//2) / skip), ceil((width - skip//2) / skip)).
    '''

    # representation sample, only used to read the per-position representation shapes
    sample_repr = sd.img2repr(img, pos, step, output_device='cpu')
    assert isinstance(sample_repr, dict)
    m,n = img.size  # m = width, n = height (PIL order)

    # create big image: the original in the middle of a 3x3 canvas, so shifted crops stay in bounds
    img_arr = np.array(img)
    big_img = np.zeros((n*3, m*3, 3), dtype=np.uint8)
    big_img[n:2*n, m:2*m] = img_arr

    # border pad big image (replicate the edge pixels/rows/columns outwards)
    big_img[:n, :m] = img_arr[0, 0]  # top left
    big_img[:n, 2*m:] = img_arr[0, -1]  # top right
    big_img[2*n:, :m] = img_arr[-1, 0]  # bottom left
    big_img[2*n:, 2*m:] = img_arr[-1, -1]  # bottom right
    big_img[:n, m:2*m] = img_arr[None,0,:]  # top
    big_img[2*n:, m:2*m] = img_arr[None,-1,:]  # bottom
    big_img[n:2*n, :m] = img_arr[:,0,None]  # left
    big_img[n:2*n, 2*m:] = img_arr[:,-1,None]  # right

    # setup result tensor at full pixel resolution (channels indexed as dim 1 of sd.img2repr outputs)
    result = torch.zeros((sum(sample_repr[x].shape[1] for x in pos), n, m))

    # shift image and calculate representations
    i_channel = 0
    for p in pos:
        # pixels per representation cell: power of two (1..128) closest to image_height / repr_height
        tile_size = min(*2**np.arange(0, 8), key=lambda x: abs(x - n // sample_repr[p].shape[-2]))
        num_channels = sample_repr[p].shape[1]
        for i in range(skip//2, tile_size, skip):
            for j in range(skip//2, tile_size, skip):
                # shifting by (i - tile_size//2) moves the centre of cell k onto pixel k*tile_size + i
                x = i - tile_size // 2
                y = j - tile_size // 2
                shifted_img = big_img[n+x:n+x+n, m+y:m+y+m]
                shifted_repr = sd.img2repr(shifted_img, [p], step, output_device='cpu')[p][0]
                x_shape = min(ceil((n-i)/tile_size), shifted_repr.shape[1])
                y_shape = min(ceil((m-j)/tile_size), shifted_repr.shape[2])
                result[i_channel:i_channel+num_channels, i:x_shape*tile_size:tile_size, j:y_shape*tile_size:tile_size] = shifted_repr[:, :x_shape, :y_shape]
        i_channel += num_channels

    result = result[:,skip//2::skip,skip//2::skip]

    if smothing:
        # 3x3 binomial (approximately Gaussian) blur, applied to every channel independently.
        # conv2d expects weights shaped (out_channels, in_channels/groups, kH, kW), so the kernel
        # is repeated once per channel and used as a depthwise conv (groups=C). Replicate padding
        # keeps the spatial size unchanged and avoids darkening the borders.
        total_channels = result.shape[0]
        kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32, device=result.device) / 16
        kernel = kernel[None, None].repeat(total_channels, 1, 1, 1)
        padded = torch.nn.functional.pad(result[None], (1, 1, 1, 1), mode='replicate')
        result = torch.nn.functional.conv2d(padded, kernel, groups=total_channels)[0]

    return result



def concat_reprs(reprs: dict[str, torch.Tensor]):
    '''Concatenate representations with different spatial sizes into a single tensor with the largest spatial size.'''
    # If the representation sizes are not multiples of each other, the bottom and right edges of the spatially larger representations will be 0-padded.
    max_spatial = np.array(max(x.shape[-2:] for x in reprs.values()))
    min_spatial = np.array(min(x.shape[-2:] for x in reprs.values()))
    while (max_spatial > min_spatial).any(): min_spatial *= 2
    spatial = min_spatial
    num_features1 = sum(x.shape[1] for x in reprs.values())
    repr_full = torch.zeros((num_features1, *spatial), device=sd.device)
    i = 0
    for p in reprs:
        r1 = reprs[p]
        _, num_channels1, n1, m1 = r1.shape
        tmp1 = r1.repeat_interleave(spatial[0]//n1, dim=-2).repeat_interleave(spatial[1]//m1, dim=-1)
        repr_full[i:i+num_channels1, :tmp1.shape[-2], :tmp1.shape[-1]] = tmp1.to(repr_full.device)
        i += num_channels1
    return repr_full



@torch.no_grad()
def sc(
        sample: dict,
        plot: bool = True,
        precomputed_reprs: list | None = None,
        extraction_step: int = 1,
        pos: list[str] = ['up_blocks[1]'],
        transform: Callable = lambda x: x,
        transform_keypoint: Callable = lambda x, y, *_: (x, y),
        subtract_empty_repr: bool = False,
        shift_reprs: bool = False,
        shift_skip: int = 8,
        zoom_reprs: bool = False,
    ):
    '''Solve semantic correspondence for a single sample.

    For every source keypoint, the feature vector at that location of the source
    representation is compared (cosine similarity) with every location of the target
    representation; the arg-max is the predicted target keypoint.

    Args:
        sample: one SPair-71k pair. Reads 'src_img', 'trg_img', 'src_kps', 'trg_kps',
            'src_bndbox', 'trg_bndbox', 'src_name', 'trg_name', 'category', 'pair_id', and
            'src_data_index' / 'trg_data_index' when `precomputed_reprs` is given.
        plot: show a source / target / similarity figure for every keypoint.
        precomputed_reprs: optional list of representation dicts indexed by data index;
            when given, `pos` is taken from their keys.
        extraction_step: diffusion step passed to sd.img2repr / get_shifted_reprs.
        pos: extraction positions.
        transform, transform_keypoint: image transform and matching keypoint transform
            (see get_transforms).
        subtract_empty_repr: subtract the averaged representation of a blank image.
        shift_reprs: use the dense shift-based representations from get_shifted_reprs
            (memoized in shifted_reprs_cache).
        shift_skip: stride in pixels of the shift-based representations.
        zoom_reprs: currently unused.

    Returns:
        list with one PCK@0.1 flag per keypoint pair: the prediction error divided by
        the larger side of the (transformed) target bounding box is <= 0.1.
    '''

    if shift_reprs:
        assert precomputed_reprs is None
        assert not subtract_empty_repr

        def shifted(name_key: str, img_key: str) -> torch.Tensor:
            # Key on every parameter that changes the result, so calls with different settings
            # do not reuse stale entries. TODO: `transform` is not part of the key (transform
            # objects are recreated on every get_transforms() call).
            cache_key = (sample[name_key], tuple(pos), extraction_step, shift_skip)
            if cache_key not in shifted_reprs_cache:
                shifted_reprs_cache[cache_key] = get_shifted_reprs(transform(sample[img_key]), pos, extraction_step, shift_skip)
            return shifted_reprs_cache[cache_key]

        repr1_full = shifted('src_name', 'src_img')
        repr2_full = shifted('trg_name', 'trg_img')
    else:
        # load representations
        if precomputed_reprs is not None:
            repr1 = precomputed_reprs[sample['src_data_index']]
            repr2 = precomputed_reprs[sample['trg_data_index']]
            assert set(repr1.keys()) == set(repr2.keys())
            pos = list(repr1.keys())
        else:
            category = dataset_pairs.features['category'].names[sample['category']]
            repr1 = sd.img2repr(transform(sample['src_img']), extract_positions=pos, step=extraction_step, prompt=category)
            repr2 = sd.img2repr(transform(sample['trg_img']), extract_positions=pos, step=extraction_step, prompt=category)
        assert isinstance(repr1, dict) and isinstance(repr2, dict)

        # concatenate representations
        repr1_full = concat_reprs(repr1)
        repr2_full = concat_reprs(repr2)
        if subtract_empty_repr:
            repr1_empty = concat_reprs(get_empty_repr(tuple(pos), transform(sample['src_img']).size))
            repr2_empty = concat_reprs(get_empty_repr(tuple(pos), transform(sample['trg_img']).size))
            repr1_full = repr1_full - repr1_empty
            repr2_full = repr2_full - repr2_empty

    # get images
    src_img = transform(sample['src_img'])
    trg_img = transform(sample['trg_img'])
    sn, sm = src_img.size  # transformed (width, height)
    tn, tm = trg_img.size
    src_size = sample['src_img'].size  # original (width, height), needed by transform_keypoint
    trg_size = sample['trg_img'].size
    assert len(sample['src_kps']) == len(sample['trg_kps'])

    # get bounding boxes; float dtype so transformed coordinates are not truncated to ints
    sbb = np.array(sample['src_bndbox'], dtype=float)
    sbb[:2] = transform_keypoint(*sbb[:2], *src_size, sn, sm)
    sbb[2:] = transform_keypoint(*sbb[2:], *src_size, sn, sm)
    tbb = np.array(sample['trg_bndbox'], dtype=float)
    tbb[:2] = transform_keypoint(*tbb[:2], *trg_size, tn, tm)
    tbb[2:] = transform_keypoint(*tbb[2:], *trg_size, tn, tm)
    tbb_max = max(tbb[2] - tbb[0], tbb[3] - tbb[1])

    if not shift_reprs:
        # (height, width) of the largest source representation; loop invariant
        max_spatial1 = np.array(max(repr1[x].shape[-2:] for x in pos))

    # solve semantic correspondence for each keypoint pair
    pcks = []
    for ([sx, sy],[tx,ty]) in zip(sample['src_kps'], sample['trg_kps']):

        # transform keypoints
        sx, sy = transform_keypoint(sx, sy, *src_size, sn, sm)
        tx, ty = transform_keypoint(tx, ty, *trg_size, tn, tm)

        # locate the source feature vector
        if shift_reprs:
            row = int(sy) // shift_skip
            col = int(sx) // shift_skip
        else:
            row = int(sy / (sm / max_spatial1[-2]))
            col = int(sx / (sn / max_spatial1[-1]))
        # keypoints on the last pixels (or grids that do not divide the image) could index past the end
        row = min(max(row, 0), repr1_full.shape[-2] - 1)
        col = min(max(col, 0), repr1_full.shape[-1] - 1)
        point = repr1_full[:, row, col, None, None]

        # calc similarities
        similarities = torch.nn.functional.cosine_similarity(repr2_full, point, dim=0).cpu()  # cossim
        # similarities = (repr2_full - point).abs().mean(dim=0).cpu()  # MAE - doesn't seem to work well
        max_i = similarities.argmax().item()
        x_max = max_i % repr2_full.shape[-1]
        y_max = max_i // repr2_full.shape[-1]

        # calculate error distance -> PCK (prediction = centre of the arg-max cell, in transformed pixels)
        if shift_reprs:
            # shifted sample k is centred on pixel k*shift_skip + shift_skip//2 (see get_shifted_reprs)
            x_max_pixel = x_max * shift_skip + shift_skip // 2
            y_max_pixel = y_max * shift_skip + shift_skip // 2
        else:
            x_max_pixel = (x_max + .5) * tn / repr2_full.shape[-1]
            y_max_pixel = (y_max + .5) * tm / repr2_full.shape[-2]
        dist = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2)**0.5
        dist_rel = dist / tbb_max
        pck = dist_rel <= 0.1
        pcks.append(pck)

        if not plot:
            continue

        # setup plot
        plt.figure(figsize=(9, 3))
        plt.suptitle(f'{sample["src_name"].split("/")[0]} (id:{sample["pair_id"]}) - rel.dist.: {dist_rel:.2f}')

        # plot source image
        plt.subplot(131)
        plt.title('source')
        plt.imshow(src_img)
        plt.scatter([sx], [sy], c='r')  # source keypoint
        plt.plot([sbb[0], sbb[2], sbb[2], sbb[0], sbb[0]], [sbb[1], sbb[1], sbb[3], sbb[3], sbb[1]], c='gray')  # bounding box
        plt.axis('off')

        # plot target image
        plt.subplot(132)
        plt.title('target')
        plt.imshow(trg_img)
        plt.scatter([tx], [ty], c='r')  # target keypoint
        plt.plot([tbb[0], tbb[2], tbb[2], tbb[0], tbb[0]], [tbb[1], tbb[1], tbb[3], tbb[3], tbb[1]], c='gray')  # bounding box
        plt.axis('off')

        # plot similarities
        plt.subplot(133)
        plt.title('similarities')
        plt.imshow(similarities.view(*repr2_full.shape[-2:]).numpy())  # similarities
        plt.scatter([x_max], [y_max], c='b')  # predicted keypoint
        plt.scatter([tx/tn*repr2_full.shape[-1]-.5], [ty/tm*repr2_full.shape[-2]-.5], c='r')  # true target keypoint
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    return pcks


def sc_plot_random(
        i: int | None = None,
        pos: list[str] = ['up_blocks[0]'],
        trans: Literal['expand'] | Literal['expand_and_resize'] | None = None,
        step: int = 1,
        **kwargs,
):
    '''Visualize sc() on one pair of dataset_pairs.

    Args:
        i: index into dataset_pairs; when None, a random index is drawn with torch.randint and printed.
        pos: extraction positions forwarded to sc().
        trans: transform name resolved with get_transforms (default size).
        step: diffusion step forwarded to sc() as extraction_step.
        **kwargs: further keyword arguments for sc() (e.g. shift_reprs=True).

    Returns None; the PCK list returned by sc() is discarded.
    '''
    if i is None:
        i = torch.randint(len(dataset_pairs), (1,)).item()
        print(f'{i = }')
    img_tf, kp_tf = get_transforms(trans)
    sc(
        dataset_pairs[i],
        plot=True,
        extraction_step=step,
        pos=pos,
        transform=img_tf,
        transform_keypoint=kp_tf,
        **kwargs,
    )


def sc_calc_dataset(
        pos: list[str] = ['up_blocks[0]'],
        trans: Literal['expand'] | Literal['expand_and_resize'] | None = None,
        repr_extraction: Callable | None = None,
        step: int = 1,
        shift_reprs: bool = False,
    ):
    '''Evaluate semantic correspondence (PCK@0.1) on every pair of dataset_pairs.

    Args:
        pos: extraction positions.
        trans: transform name for get_transforms (size 512).
        repr_extraction: callable (img, pos, step, prompt) -> representation dict used to
            precompute representations; defaults to sd.img2repr with spatial_avg=False on the CPU.
        step: diffusion step.
        shift_reprs: use shift-based dense representations computed per pair inside sc()
            instead of precomputing them.

    Returns:
        (pcks, cpcks): all PCK flags and a dict category -> PCK flags. A KeyboardInterrupt
        stops the loop early and the partial results are reported and returned. Other
        exceptions propagate after the partial results are printed.
    '''
    transform, transform_keypoint = get_transforms(trans, size=512)
    if repr_extraction is None:
        repr_extraction_fn = lambda x, pos, step, prompt: sd.img2repr(x, pos, step, prompt, spatial_avg=False, output_device='cpu')
    else:
        repr_extraction_fn = repr_extraction

    # precalculate representations (not possible for shift-based representations, which sc() builds itself)
    if not shift_reprs:
        data_dataset = datasets.load_dataset('0jl/SPair-71k', 'data', trust_remote_code=True, split='train')
        dataset_reprs = [repr_extraction_fn(transform(x['img']), pos, step, x['name'].split('/')[0]) for x in tqdm(data_dataset, desc='calculating representations')]
    else:
        dataset_reprs = None

    pcks = []
    cpcks = {}
    try:
        for sample in tqdm(dataset_pairs, desc='processing samples'):
            assert isinstance(sample, dict)
            new_pcks = sc(sample, plot=False, precomputed_reprs=dataset_reprs, extraction_step=step, pos=pos, transform=transform, transform_keypoint=transform_keypoint, shift_reprs=shift_reprs)
            pcks += new_pcks
            c = dataset_pairs.features['category'].names[sample['category']]
            cpcks.setdefault(c, []).extend(new_pcks)
    except KeyboardInterrupt:
        print('Interrupted - reporting partial results.')
    finally:
        # no `return` here: that would silently swallow every exception, not only the interrupt
        print('SC Results:')
        print(f'PCK: {np.mean(pcks):5.2%} ({len(pcks)})')
        for k, v in cpcks.items():
            print(f'{k+":":<15} {np.mean(v):5.1%} {f"({len(v)})":>7}')
    return pcks, cpcks


def sc_calc_dataset_small(
        num_samples=50,
        seed=42,
        pos: list[str] = ['up_blocks[0]', 'up_blocks[1]'],
        trans: Literal['expand'] | Literal['expand_and_resize'] | None = None,
        step: int = 1,
    ):
    '''Quick PCK estimate on `num_samples` pairs drawn with replacement from dataset_pairs.

    Indices come from random.Random(seed), so the subset is reproducible.
    Representations are extracted on the fly by sc(). Prints the PCK and returns None.
    '''
    transform, transform_keypoint = get_transforms(trans)
    rng = random.Random(seed)
    last_index = len(dataset_pairs) - 1

    all_pcks = []
    for _ in trange(num_samples):
        pair = dataset_pairs[rng.randint(0, last_index)]
        all_pcks.extend(sc(pair, plot=False, extraction_step=step, pos=pos, transform=transform, transform_keypoint=transform_keypoint))
    print(f'PCK: {np.mean(all_pcks):5.1%} ({len(all_pcks)})')



def _print_hyper_opt_report(runs: list, available_pos: list[str]):
    '''Print mean PCK of the random search overall and broken down by transform, position, position set and step range.'''
    print('Random Hyperparameter Optimization Results:')
    res = np.mean([x[1] for x in runs])
    print(f'PCK: {res:.2%} ({len(runs)} runs)')

    print('transforms:')
    for t in ['expand_and_resize', 'expand', None]:
        r = [x[1] for x in runs if x[0]["t"] == t]
        print(f'  {str(t)+":":<20} {np.mean(r):5.1%} ({np.mean(r)-res:+6.1%}) {f"({len(r)})":>7}')

    print('positions:')
    for p in available_pos:
        r = [x[1] for x in runs if p in x[0]["p"]]
        r_ = [x[1] for x in runs if p not in x[0]["p"]]
        print(f'  {p+":":<20} with: {np.mean(r):5.1%} ({np.mean(r)-res:+6.1%}) {f"({len(r)})":>7}, without: {np.mean(r_):.1%} ({np.mean(r_)-res:+6.1%}) {f"({len(r_)})":>7}')

    print('position combinations: ')
    pos_combinations = defaultdict(list)
    for x in runs:
        pos_combinations[tuple(sorted(x[0]["p"]))].append(x[1])
    for combo, pck in sorted(pos_combinations.items(), key=lambda x: -np.mean(x[1]))[:20]:
        print(f'  {str(combo):50}: {np.mean(pck):5.1%} ({np.mean(pck)-res:+6.1%}) {f"({len(pck)})":>7}')

    print('steps:')
    for s in range(0, 1000, 100):
        # bucket (s, s+100]; the upper bound must be s+100, not the step itself
        r = [x[1] for x in runs if s < x[0]["s"] <= s + 100]
        print(f'  {f"{s}-{s+100}":20} {np.mean(r):5.1%} ({np.mean(r)-res:+6.1%}) {f"({len(r)})":>7}')


def random_hyper_opt():
    '''Random search over transform, extraction positions and diffusion step; stop it with KeyboardInterrupt.

    Each iteration draws one random hyperparameter set and one random pair from
    dataset_pairs and records one PCK flag per keypoint. When the search ends (by
    interrupt, or by an error that is re-raised afterwards), a breakdown of the runs
    collected so far is printed.

    Returns:
        list of (dict(t=transform_name, p=positions, s=step), pck) tuples.
    '''
    available_pos = [x for x in sd.available_extract_positions if any(f'{y}_block' in x for y in ['down', 'mid', 'up'])]
    runs = []
    try:
        for _ in trange(int(1e10)):
            # randomize hyperparameters
            t = ['expand_and_resize', 'expand', None][random.randint(0, 2)]
            p = random.sample(available_pos, random.randint(1, len(available_pos)))
            s = random.randint(1, 999)

            # run
            t1, t2 = get_transforms(t)
            sample = dataset_pairs[random.randint(0, len(dataset_pairs)-1)]
            pcks = sc(sample, plot=False, extraction_step=s, pos=p, transform=t1, transform_keypoint=t2)
            for pck in pcks:
                runs.append((dict(t=t,p=p,s=s), pck))
    except KeyboardInterrupt:
        pass  # interrupting is the intended way to end the search
    finally:
        # no `return` here: it would swallow real errors as well
        _print_hyper_opt_report(runs, available_pos)
    return runs
    
def multi_repr_extraction(x, pos, step, prompt, count = 4):
    '''Run sd.img2repr `count` times on the same input and average the outputs key by key.

    Usable as the `repr_extraction` argument of sc_calc_dataset (same call signature plus `count`).
    '''
    samples = [sd.img2repr(x, pos, step, prompt, spatial_avg=False, output_device='cpu') for _ in range(count)]
    return {key: torch.stack([s[key] for s in samples], dim=0).mean(dim=0) for key in samples[0]}

# Experiment driver: the commented-out calls are alternative entry points kept for manual toggling.
# sc_plot_random(pos=['up_blocks[1]'], trans=None, step=100, shift_reprs=True)
# sc_calc_dataset_small()
# Full SPair-71k evaluation: 'expand' transform (shorter side scaled to 512 px inside sc_calc_dataset),
# up_blocks[1] features at step 100, with shift_reprs=True requesting shift-based representations.
pcks, cpcks = sc_calc_dataset(
    pos=['up_blocks[1]'],
    trans='expand',
    repr_extraction=None,
    step=100,
    shift_reprs=True,
)
# runs = random_hyper_opt()