## Setup

In [1]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
import torch.nn as nn

In [None]:
# load model and dataset
# `sd` is the sdhelper wrapper used below through `sd.img2repr(...)` and `sd.device`.
sd = SD('SD1.5')
# Train split of the SPair-71k dataset ('data' configuration); each record is
# read below through its 'img' and 'name' fields.
# NOTE(review): trust_remote_code=True allows the dataset repository's own
# loading script to run locally - confirm the source is trusted.
data = datasets.load_dataset('0jl/SPair-71k', 'data', split='train', trust_remote_code=True)

In [3]:
# config
pos = ['down_blocks[3]', 'mid_block', 'up_blocks[0]', 'up_blocks[1]', 'up_blocks[2]']
img_size = 512

## Precalculate representations

In [None]:
# precalculate representations

def expand_and_resize(x: PIL.Image.Image, size = img_size, border_pad=True):
    """Centre an image on a square canvas and resize it to ``size`` x ``size``.

    The canvas side equals the longer side of ``x``. The image is pasted in
    the middle; the two strips left over along the shorter axis are either
    left black or, with ``border_pad``, filled by stretching the outermost
    pixel row/column of the image over them.

    Args:
        x: Source image.
        size: Side length of the returned square image.
        border_pad: Fill the padding strips with the stretched edge pixels
            instead of leaving them black.

    Returns:
        A new RGB image of size ``(size, size)``.
    """
    n, m = x.size  # PIL reports (width, height)
    s = max(n, m)

    # Padding on each side. When the size difference is odd the far side gets
    # the extra pixel, so near and far strips must be sized separately;
    # using the floor for both would leave one black row/column.
    left, top = (s - n) // 2, (s - m) // 2
    right, bottom = s - n - left, s - m - top

    r = PIL.Image.new('RGB', (s, s))
    r.paste(x, (left, top))
    if border_pad:
        # Empty strips are skipped: resizing to a zero-sized target is rejected
        # by PIL (happens when the sides differ by exactly one pixel).
        if top > 0:
            r.paste(x.crop((0, 0, n, 1)).resize((n, top)), (0, 0))
        if bottom > 0:
            r.paste(x.crop((0, m - 1, n, m)).resize((n, bottom)), (0, top + m))
        if left > 0:
            r.paste(x.crop((0, 0, 1, m)).resize((left, m)), (0, 0))
        if right > 0:
            r.paste(x.crop((n - 1, 0, n, m)).resize((right, m)), (left + n, 0))
    return r.resize((size, size))

# One dict per dataset image: block name -> 2-D float32 feature matrix.
# Each extracted tensor has its leading dimension dropped, is permuted to put
# the first remaining axis last, and the two leading axes are then flattened
# (assumes img2repr yields (1, C, H, W) tensors, giving (H*W, C) - TODO confirm).
representations = []
for x in tqdm(data, desc='Calculating representations'):
    # The prompt is the part of the sample name before the first '/'.
    prompt = x['name'].split('/')[0]
    r = sd.img2repr(expand_and_resize(x['img']), pos, 100, prompt=prompt)
    representations.append({
        block: feat[0].permute(1, 2, 0).flatten(0, 1).to(torch.float32)
        for block, feat in r.data.items()
    })

## Linear Layer to estimate position

In [None]:
# init linear classification layers
# One linear probe per entry of `pos`. An all-zero image is run through the
# model once purely to read the tensor shapes: dim 1 becomes the input width
# and dim 2 * dim 3 the number of output classes (one per spatial location).
linear_layers = []
sample_reprs = sd.img2repr(np.zeros((img_size, img_size, 3)), pos, 100)
for p in pos:
    shape = sample_reprs[p].shape
    linear_layers.append(
        nn.Linear(
            in_features=shape[1],
            out_features=shape[2] * shape[3],
            dtype=torch.float32,
            device=sd.device,
        )
    )

In [None]:
# One independent Adam optimizer per linear probe, in the same order as
# `linear_layers`.
optimizers = [torch.optim.Adam(layer.parameters(), lr=1e-3) for layer in linear_layers]

In [None]:
# train
# Every probe receives the feature matrix of one image (one row per spatial
# location) and is trained to classify each row as its own row index, i.e. to
# recover the location from the features. The last 200 images are held out
# for the evaluation cell.
for epoch in trange(10):
    print(f'Epoch {epoch}')
    for i, sample in enumerate(tqdm(representations[:-200], leave=False)):
        for p, l, o in zip(pos, linear_layers, optimizers):
            # Use the model's device rather than a hard-coded 'cuda' so that
            # inputs, layers (created on sd.device) and targets always agree.
            x = sample[p].to(sd.device)
            l.zero_grad()
            y = l(x)
            # Target for row k is class k.
            loss = nn.functional.cross_entropy(y, torch.arange(y.shape[0], device=sd.device))
            loss.backward()
            o.step()
        if i % 100 == 0:
            # `loss` is whatever the inner loop assigned last, i.e. the loss
            # of the final entry of `pos` only.
            print(loss.item())
    # for sample in tqdm(representations[-200:], leave=False):
    #     for p, l in zip(pos, linear_layers):
    #         x = sample[p][0].permute(1,2,0).flatten(0,1).to('cuda')
    #         y = l(x)
    #         loss = nn.functional.cross_entropy(y, torch.arange(y.shape[0], device=sd.device))
    #         plot(loss.item())

In [None]:
# test
# Per-image top-1 accuracy of every probe on the 200 held-out images: the
# fraction of rows whose arg-max class equals the row's own index.
accuracies = {p: [] for p in pos}
with torch.no_grad():
    for sample in tqdm(representations[-200:]):
        for p, l in zip(pos, linear_layers):
            # Same device as the layers (created on sd.device) instead of a
            # hard-coded 'cuda'.
            x = sample[p].to(sd.device)
            y = l(x)
            target = torch.arange(y.shape[0], device=sd.device)
            accuracies[p].append((y.argmax(dim=1) == target).float().mean().item())

In [None]:
# Report, for each block, the mean of the per-image accuracies as a percentage.
for p, a in accuracies.items():
    print(f'{p:15} {np.mean(a):7.2%}')

In [None]:
# Histogram of the bias values of the third probe (index 2, i.e. the probe
# built for pos[2]).
plt.hist(linear_layers[2].bias.cpu().detach().numpy().flatten(), bins=100);

In [None]:
# Histogram of all weight values of the third probe (index 2); the count axis
# is logarithmic.
plt.hist(linear_layers[2].weight.cpu().detach().numpy().flatten(), bins=100)
plt.yscale('log')
plt.xlabel('Weight value')
plt.ylabel('Frequency')

In [None]:
# Fraction of the third probe's weights whose absolute value exceeds 0.25.
(linear_layers[2].weight.cpu().detach().abs().numpy().flatten() > 0.25).sum() / linear_layers[2].weight.numel()

## Get PCs corresponding to positions

In [None]:
def compute_pca_basis(data):
    """Compute a PCA basis for a 2-D tensor with one sample per row.

    Args:
        data: Tensor of shape (num_samples, num_features).

    Returns:
        A tuple ``(basis, mean)``: ``basis`` is (num_features, num_features)
        with the eigenvectors of the sample covariance matrix as columns,
        ordered by decreasing eigenvalue; ``mean`` is the per-feature mean
        that was subtracted before the covariance was formed.
    """
    data_mean = data.mean(dim=0)
    centered = data - data_mean

    # Unbiased sample covariance (divides by num_samples - 1).
    num_samples = centered.size(0)
    covariance = centered.T.mm(centered) / (num_samples - 1)

    # eigh is for symmetric matrices; reorder so the largest eigenvalue's
    # eigenvector comes first.
    eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
    order = eigenvalues.argsort(descending=True)
    return eigenvectors[:, order], data_mean


def transform_data(data, pca_basis, data_mean):
    """Project row-wise samples onto a PCA basis.

    Args:
        data: Tensor of shape (num_samples, num_features).
        pca_basis: Matrix whose columns are the basis vectors.
        data_mean: Per-feature mean subtracted from ``data`` before projecting.

    Returns:
        The centred data multiplied by ``pca_basis``.
    """
    return (data - data_mean).mm(pca_basis)


# pca transform
# For every block: fit a PCA over the feature channels on the training images,
# project all images onto it and measure how strongly each principal
# component's spatial map follows a linear ramp along the y or the x axis.
for p in tqdm(pos):
    print(f'calculating PCA basis for {p}')
    # Each stored representation is a (positions, channels) matrix.
    num_features = representations[0][p].shape[1]
    # Stack the training images along dim 0 so that every spatial location of
    # every image is one sample and the channels are the PCA features.
    # (Concatenating along dim 1 and transposing would instead run the PCA
    # over spatial positions, and dim 1 below would not index PCs.)
    train_features = torch.cat([r[p] for r in representations[:-200]], dim=0)
    pca_basis, data_mean = compute_pca_basis(train_features)
    print('finding PCs corresponding to directions')
    # Project each image and transpose -> (images, PCs, positions).
    representations_pca = torch.stack([transform_data(r[p], pca_basis, data_mean).T for r in representations])

    # calc. correlation between all dimensions (PCs) and positions for whole (test) dataset
    # Scale every PC map to unit length over the positions.
    representations_pca = representations_pca / representations_pca.pow(2).sum(dim=2, keepdim=True).sqrt()
    n = int(representations_pca.shape[2] ** 0.5)  # side length of the square feature map
    # Centred coordinate ramp, tiled to the flattened row-major grid.
    positions = torch.arange(n, device=representations_pca.device) - (n/2 - 0.5)
    y_ramp = positions.repeat_interleave(n)  # constant within a row, changes between rows
    x_ramp = positions.repeat(n)             # changes within a row, repeats for each row
    # Normalise the *tiled* ramps; normalising only the length-n ramp leaves
    # the tiled one with norm sqrt(n) and inflates every value by that factor.
    y_ramp = y_ramp / y_ramp.pow(2).sum().sqrt()
    x_ramp = x_ramp / x_ramp.pow(2).sum().sqrt()
    # Cosine similarity per image and PC, averaged over the images.
    y_correlations = (representations_pca * y_ramp[None, None, :]).sum(dim=2).mean(dim=0)
    x_correlations = (representations_pca * x_ramp[None, None, :]).sum(dim=2).mean(dim=0)
    y_sorted_idx =  y_correlations.abs().argsort(descending=True)
    x_sorted_idx =  x_correlations.abs().argsort(descending=True)
    print(f'Y best: {y_correlations[y_sorted_idx[0]]:+.4f} at PC {y_sorted_idx[0]:4}, top 10: {", ".join(f"{y:+.2f}" for y in y_correlations[y_sorted_idx[:10]])}')
    print(f'X best: {x_correlations[x_sorted_idx[0]]:+.4f} at PC {x_sorted_idx[0]:4}, top 10: {", ".join(f"{x:+.2f}" for x in x_correlations[x_sorted_idx[:10]])}')

In [None]:
# Inspect the shape of the tensor left over from the last iteration of the
# loop above (i.e. for the final entry of `pos`).
representations_pca.shape