# Representation Autoencoder

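Goal: modify the representation channels so that two channels contain the positional information and the remaining channels do not. (The current implementation encodes the position as `H + W` one-hot channels rather than two scalar channels.)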

WIP: doesn't work (yet)

## Setup

In [None]:
from sdhelper import SD
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import datasets
from tqdm.autonotebook import tqdm, trange
import PIL.Image


In [None]:
# load model and data
# `SD` is imported from `sdhelper`; 'sd15' presumably selects Stable Diffusion 1.5 — TODO confirm.
sd = SD('sd15')
# 'data' configuration, 'train' split of the SPair-71k dataset hosted on the Hugging Face Hub.
# NOTE: trust_remote_code=True allows `datasets` to run the loading script shipped with the dataset repository.
data = datasets.load_dataset('0jl/SPair-71k', 'data', split='train', trust_remote_code=True)

In [None]:
# config
p = 'up_blocks[1]'  # key of the representation requested from `sd.img2repr` and read back from its `.data`
img_size = 512      # edge length (pixels) every image is padded/resized to before extraction
device = 'cuda'     # device used for the autoencoder, the position estimator and all training tensors

In [None]:
# precalculate representations

def expand_and_resize(x: PIL.Image.Image, size, border_pad=True):
    """Pad an image to a centered square and resize it to ``size`` x ``size``.

    The image is pasted into the middle of a square RGB canvas whose edge
    equals the longer side of ``x``. With ``border_pad`` the empty strips on
    both sides are filled by stretching the outermost pixel row (wide images)
    or pixel column (tall images) of ``x``.

    Args:
        x: Source image.
        size: Edge length in pixels of the returned square image.
        border_pad: Fill the padding with the stretched image border instead
            of leaving the canvas' default fill.

    Returns:
        A new ``size`` x ``size`` RGB image.
    """
    width, height = x.size
    side = max(width, height)
    left = (side - width) // 2
    top = (side - height) // 2

    canvas = PIL.Image.new('RGB', (side, side))
    canvas.paste(x, (left, top))

    if border_pad:
        if width > height:
            # With an odd size difference the bottom strip is one pixel taller
            # than the top strip; size each strip separately so no row is
            # left unpadded.
            bottom = side - height - top
            if top > 0:  # a zero-height strip cannot be resized
                first_row = x.crop((0, 0, width, 1))
                canvas.paste(first_row.resize((width, top)), (0, 0))
            if bottom > 0:
                last_row = x.crop((0, height - 1, width, height))
                canvas.paste(last_row.resize((width, bottom)), (0, top + height))
        elif height > width:
            # Same as above, for the left/right strips of a tall image.
            right = side - width - left
            if left > 0:
                first_col = x.crop((0, 0, 1, height))
                canvas.paste(first_col.resize((left, height)), (0, 0))
            if right > 0:
                last_col = x.crop((width - 1, 0, width, height))
                canvas.paste(last_col.resize((right, height)), (left + width, 0))

    return canvas.resize((size, size))

def transform_img(img):
    """Square-pad ``img`` (with border padding) and scale it to ``img_size``."""
    return expand_and_resize(img, img_size, True)


# One representation per dataset entry, stored in dataset order so that the
# entries can later be looked up by their data index.
representations = []
for sample in tqdm(data, desc='Calculating representations'):
    # The prompt is the part of the sample name before the first '/'.
    prompt_text = sample['name'].split('/')[0]
    # NOTE(review): the positional argument 100 is presumably the diffusion timestep — confirm against sdhelper.
    extracted = sd.img2repr(transform_img(sample['img']), [p], 100, prompt=prompt_text)
    representations.append(extracted.data[p])

In [None]:
# Channel count and spatial size of a stored representation; the leading axis
# (presumably a batch axis of size 1 — TODO confirm) is skipped.
num_channels, H, W = representations[0].shape[1:]
# The training and evaluation cells use H and W interchangeably, so the feature map must be square.
assert H == W

## Train autoencoder

In [None]:
class SDReprAutoencoder(nn.Module):
    """Linear autoencoder acting on the channel (last) dimension.

    Args:
        num_channels: Size of the input and of the reconstructed output.
        encoded_channels: Size of the encoded representation.

    Attributes:
        encoder: ``nn.Linear(num_channels, encoded_channels)``.
        decoder: ``nn.Linear(encoded_channels, num_channels)``.
    """

    def __init__(self, num_channels, encoded_channels):
        super().__init__()
        self.encoder = nn.Linear(num_channels, encoded_channels)
        self.decoder = nn.Linear(encoded_channels, num_channels)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decoder(self.encoder(x))

# Autoencoder whose code has num_channels feature channels plus H+W extra
# channels; the training loop treats the extra channels as position logits.
ae = SDReprAutoencoder(num_channels, num_channels+H+W).to(device)
ae_optimizer = torch.optim.Adam(ae.parameters(), lr=1e-4)
# Linear probe that predicts the H+W position logits from the first
# num_channels code channels (see the training loop).
position_estimator = nn.Sequential(
    nn.Linear(num_channels, H+W, device=device),
)
# NOTE: the name is misspelled ("optimizser") but is referenced under this spelling by the training loop.
pe_optimizser = torch.optim.Adam(position_estimator.parameters(), lr=1e-4)

In [None]:
# plot training progress
from trainplot.trainplot import TrainPlotPlotlyExperimental as TrainPlot
# `tp` receives the averaged losses from the training loop; its y axis is set to log scale.
tp = TrainPlot(threaded=True)
tp.fig.update_yaxes(type="log");
# `tp_acc` receives the averaged accuracies from the training loop.
tp_acc = TrainPlot(threaded=True)

In [None]:
# training loop
#
# Per image, three objectives are combined:
#   * ae_loss:  reconstruction of the standardized representation,
#   * pos_loss: the extra H+W code channels should predict the pixel position,
#   * pe_loss:  a linear probe (`position_estimator`) tries to predict the
#               position from the first num_channels code channels; the
#               autoencoder is rewarded for making this probe fail.

# The one-hot position targets are the same for every image, so build them once.
# Row i of the flattened (H*W, C) representation is pixel (y, x) = divmod(i, W);
# H == W is asserted earlier in the notebook, so this matches the square layout used below.
x_positions = torch.arange(W, device=device).repeat(H)
y_positions = torch.arange(H, device=device).repeat_interleave(W)
positions = torch.cat([
    F.one_hot(x_positions, W).float(),
    F.one_hot(y_positions, H).float(),
], dim=1)

for epoch in trange(50):
    accuracies = []
    losses = []
    for i, reprs in enumerate(representations):

        # setup input: (C, H, W) -> (H*W, C), standardized over the whole image
        r = reprs[0].to(dtype=torch.float32, device=device)
        r = r.flatten(1, 2).T
        r = (r - r.mean()) / r.std()

        # forward pass
        encoded = ae.encoder(r)
        feature_code = encoded[:, :num_channels]
        position_code = encoded[:, num_channels:]
        position_estimation = position_estimator(feature_code)
        # alternative tried earlier: decode from the real positions instead of position_code
        decoded = ae.decoder(encoded)

        # calculate losses
        # NOTE(review): `positions` holds two ones per row (x and y one-hot), so
        # it is not a probability distribution over the H+W logits; separate
        # cross-entropies for the x and y halves may be what is intended.
        # (An MSE-based variant of both position losses was tried earlier.)
        pos_loss = F.cross_entropy(position_code, positions)
        pe_loss = F.cross_entropy(position_estimation, positions)
        ae_loss = F.mse_loss(decoded, r) + F.l1_loss(decoded, r)
        ae_loss_full = 10*ae_loss + 1*pos_loss - .5*pe_loss
        # TODO: maybe add loss trying to have the embedding be close to the actual representation

        # Loss for the estimator update, evaluated on a detached code. It has
        # the same value and the same gradients w.r.t. the estimator as
        # pe_loss, but owns its graph: no retain_graph is needed and nothing
        # is backpropagated into the encoder after its weights were stepped.
        pe_train_loss = F.cross_entropy(position_estimator(feature_code.detach()), positions)

        # optimize autoencoder
        ae_optimizer.zero_grad()
        ae_loss_full.backward()
        ae_optimizer.step()

        # optimize position estimator
        # (zero_grad discards the estimator gradients left by ae_loss_full.backward())
        pe_optimizser.zero_grad()
        pe_train_loss.backward()
        pe_optimizser.step()

        # log
        with torch.no_grad():
            target_idx = positions.unflatten(-1, (2, H)).argmax(dim=-1)
            pos_acc = (position_code.unflatten(-1, (2, H)).argmax(dim=-1) == target_idx).float().mean(axis=0).cpu()
            pe_acc = (position_estimation.unflatten(-1, (2, H)).argmax(dim=-1) == target_idx).float().mean(axis=0).cpu()
            accuracies.append(dict(pos_x=pos_acc[0].item(), pos_y=pos_acc[1].item(), pe_x=pe_acc[0].item(), pe_y=pe_acc[1].item()))
            losses.append(dict(pos=pos_loss.item(), pe=pe_loss.item(), ae=ae_loss.item()))
            if i % 100 == 0:
                # plot the mean over the images seen since the last plot update
                tp_acc(**{k: np.mean([x[k] for x in accuracies]) for k in accuracies[0].keys()})
                tp(**{k: np.mean([x[k] for x in losses]) for k in losses[0].keys()})
                accuracies = []
                losses = []

## Calculate Semantic Correspondence

In [None]:
# 'pairs' configuration, 'test' split of SPair-71k: each entry references a source and a target
# image by data index and lists keypoints for both (fields used by the evaluation cells below).
pairs = datasets.load_dataset('0jl/SPair-71k', 'pairs', split='test', trust_remote_code=True)

In [None]:
# reference: semantic correspondence on the unmodified representations
# metric: share of keypoints predicted within 10% of the target bounding box (PCK@0.1_bbox)

correct_reference = []
positions_reference = []
progress = tqdm(pairs, desc='Calculating SC')
for pair in progress:
    src_map = representations[pair['src_data_index']].squeeze(0)
    trg_map = representations[pair['trg_data_index']].squeeze(0)

    src_size = pair['src_img'].size
    trg_size = pair['trg_img'].size
    src_side = max(src_size)
    trg_side = max(trg_size)
    bbox = pair['trg_bndbox']
    tbb_max = max(bbox[2] - bbox[0], bbox[3] - bbox[1])

    for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps']):
        # source keypoint -> cell of the square-padded feature map
        src_row = int((sy + (src_side - src_size[1])/2) * src_map.shape[1] / src_side)
        src_col = int((sx + (src_side - src_size[0])/2) * src_map.shape[2] / src_side)
        src_repr = src_map[:, src_row, src_col]

        # channel-wise dot product with every target cell, then take the best match
        similarity = (trg_map * src_repr[:, None, None]).sum(dim=0)
        y_max, x_max = np.unravel_index(similarity.argmax().cpu(), similarity.shape)

        # best cell -> pixel coordinates of the unpadded target image
        x_max_pixel = x_max / trg_map.shape[2] * trg_side - (trg_side - trg_size[0]) / 2
        y_max_pixel = y_max / trg_map.shape[1] * trg_side - (trg_side - trg_size[1]) / 2

        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct_reference.append(relative_distance < 0.1)
        positions_reference.append((x_max_pixel, y_max_pixel))

    if len(correct_reference) % 100 == 0:
        progress.set_postfix(pck=np.mean(correct_reference)*100)

In [None]:
# using AE representations

# Encode every representation and keep only the first num_channels code
# channels (the position channels are dropped).
# * The input is standardized exactly as in the training loop; the encoder
#   never saw raw representations.
# * no_grad: without it every stored tensor would keep its autograd graph
#   (and the GPU input it references) alive.
encoded_representation = []
with torch.no_grad():
    for r in tqdm(representations, desc='Calculating encoded representations'):
        pixels = r.reshape(-1, H*W).T.to(device=device, dtype=torch.float32)  # (H*W, C)
        pixels = (pixels - pixels.mean()) / pixels.std()
        code = ae.encoder(pixels)  # (H*W, num_channels+H+W)
        encoded_representation.append(code.T.unflatten(1, (H, W))[:num_channels, :, :].to('cpu'))

# calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
correct = []
positions = []
for x in (t:=tqdm(pairs, desc='Calculating SC')):
    a = encoded_representation[x['src_data_index']]
    b = encoded_representation[x['trg_data_index']]
    tbb_max = max(x['trg_bndbox'][2] - x['trg_bndbox'][0], x['trg_bndbox'][3] - x['trg_bndbox'][1])
    for ([sx, sy],[tx,ty]) in zip(x['src_kps'], x['trg_kps']):
        # source keypoint -> cell of the square-padded feature map
        src_repr = a[:,
            int((sy + (max(x['src_img'].size) - x['src_img'].size[1])/2) * a.shape[1] / max(x['src_img'].size)),
            int((sx + (max(x['src_img'].size) - x['src_img'].size[0])/2) * a.shape[2] / max(x['src_img'].size)),
        ]
        # NOTE(review): despite the name this is an unnormalised dot product, not a cosine similarity.
        cossim = (b * src_repr[:,None,None]).sum(dim=0)
        y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)
        # best cell -> pixel coordinates of the unpadded target image
        x_max_pixel = x_max / b.shape[2] * max(x['trg_img'].size) - (max(x['trg_img'].size) - x['trg_img'].size[0]) / 2
        y_max_pixel = y_max / b.shape[1] * max(x['trg_img'].size) - (max(x['trg_img'].size) - x['trg_img'].size[1]) / 2
        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct.append(relative_distance < 0.1)
        positions.append((x_max_pixel, y_max_pixel))
    if len(correct) % 100 == 0:
        t.set_postfix(pck=np.mean(correct)*100)

In [None]:
# PCK of the encoded representations (`correct`) next to the PCK of the original ones (`correct_reference`)
f'{np.mean(correct):.2%} instead of {np.mean(correct_reference):.2%}'

In [None]:
# with ae output for verification

# Run every representation through encoder and decoder.
# * The input is standardized exactly as in the training loop, so the
#   reconstruction lives in that standardized space.
# * no_grad: without it every stored tensor would keep its autograd graph
#   (and the GPU input it references) alive.
decoded_representation = []
with torch.no_grad():
    for r in tqdm(representations, desc='Calculating decoded representations'):
        pixels = r.reshape(-1, H*W).T.to(device=device, dtype=torch.float32)  # (H*W, C)
        pixels = (pixels - pixels.mean()) / pixels.std()
        reconstruction = ae.decoder(ae.encoder(pixels))  # (H*W, C)
        decoded_representation.append(reconstruction.T.unflatten(1, (H, W)).to('cpu'))

# Results go into their own lists so the reference results (`correct_reference`,
# `positions_reference`) used by the comparison cell are not overwritten.
correct_decoded = []
positions_decoded = []
for x in (t:=tqdm(pairs, desc='Calculating SC')):
    a = decoded_representation[x['src_data_index']]
    b = decoded_representation[x['trg_data_index']]
    tbb_max = max(x['trg_bndbox'][2] - x['trg_bndbox'][0], x['trg_bndbox'][3] - x['trg_bndbox'][1])
    for ([sx, sy],[tx,ty]) in zip(x['src_kps'], x['trg_kps']):
        # source keypoint -> cell of the square-padded feature map
        src_repr = a[:,
            int((sy + (max(x['src_img'].size) - x['src_img'].size[1])/2) * a.shape[1] / max(x['src_img'].size)),
            int((sx + (max(x['src_img'].size) - x['src_img'].size[0])/2) * a.shape[2] / max(x['src_img'].size)),
        ]
        # NOTE(review): despite the name this is an unnormalised dot product, not a cosine similarity.
        cossim = (b * src_repr[:,None,None]).sum(dim=0)
        y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)
        # best cell -> pixel coordinates of the unpadded target image
        x_max_pixel = x_max / b.shape[2] * max(x['trg_img'].size) - (max(x['trg_img'].size) - x['trg_img'].size[0]) / 2
        y_max_pixel = y_max / b.shape[1] * max(x['trg_img'].size) - (max(x['trg_img'].size) - x['trg_img'].size[1]) / 2
        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct_decoded.append(relative_distance < 0.1)
        positions_decoded.append((x_max_pixel, y_max_pixel))
    if len(correct_decoded) % 100 == 0:
        t.set_postfix(pck=np.mean(correct_decoded)*100)