# Analysis of the spatial variance/std of representations

### Problem

Are the scales of the different representation channels comparable? If not, all of this might be heavily biased. The analysis below shows that this is not too bad.

## Setup

In [None]:
from sdhelper import SD
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
import random
from tqdm.autonotebook import tqdm, trange

In [None]:

def concat_reprs(reprs: dict[str, torch.Tensor], pos: list[str]):
    '''Stack the representations named in `pos` along the channel axis on a common spatial grid.

    Each selected tensor is unpacked as (batch, channels, height, width). The
    common grid is found by doubling the smallest spatial size until it is at
    least as large as the largest one in both dimensions. Every representation
    is then enlarged by an integer factor (each value repeated along height and
    width) and written into its own channel range of the output.

    If the sizes are not multiples of each other, the enlarged representation is
    smaller than the grid and the bottom and right edges stay 0.

    Returns a tensor of shape (sum of channels, grid height, grid width) created
    with `torch.zeros`. The batch axis is not kept; the assignment presumably
    relies on the batch size being 1.
    '''
    sizes = [reprs[name].shape[-2:] for name in pos]
    largest = np.array(max(sizes))
    target = np.array(min(sizes))
    # Grow the smallest size by powers of two until it covers the largest one.
    while np.any(largest > target):
        target = target * 2
    height, width = (int(v) for v in target)

    total_channels = sum(reprs[name].shape[1] for name in pos)
    merged = torch.zeros((total_channels, height, width))

    offset = 0
    for name in pos:
        block = reprs[name]
        _, channels, rows, cols = block.shape
        # Integer-factor upsampling: repeat every row, then every column.
        upsampled = block.repeat_interleave(height // rows, dim=-2)
        upsampled = upsampled.repeat_interleave(width // cols, dim=-1)
        out_h, out_w = upsampled.shape[-2:]
        merged[offset:offset + channels, :out_h, :out_w] = upsampled
        offset += channels
    return merged


In [None]:
# Create the sdhelper `SD` object for 'sdxl-turbo'; the cells below use it for generation and representation extraction.
sd = SD('sdxl-turbo')

## Mean representation values

In [None]:
# Generate an image for the prompt 'a cat' and request the 'mid_block' representations alongside it.
res = sd('a cat', extract_positions=['mid_block'])

In [None]:
# Compare the mean absolute 'mid_block' activation of the last recorded generation
# step with the one obtained by re-extracting from the finished image.
print('mean repr value during generation:', res.representations['mid_block'][-1].abs().mean().item())
# The timestep is given explicitly as 0: the loop variable `i` is not defined yet at
# this point on a fresh run (and holds a stale value on a re-run).
print('mean repr value during extraction', sd.img2repr(res.result_image, ['mid_block'], 0)['mid_block'].abs().mean().item())  # timestep 0 is different from the last generation step

# Per-channel mean absolute value (averaged over dims 0, 2 and 3), sorted in
# descending order, one curve per extraction timestep.
for i in range(0, 501, 50):
    mid_repr = sd.img2repr(res.result_image, ['mid_block'], i)['mid_block']
    plt.plot(sorted(mid_repr.abs().mean(dim=(0, 2, 3)), reverse=True), label=f'step {i}')
plt.xlabel('channel (sorted by mean absolute value)')
plt.ylabel('mean absolute value')
plt.legend()
plt.show()

# Histogram of the per-channel means at timestep 100, binned on a logarithmic scale
# (50 bin edges spaced evenly in log10 between the smallest and largest mean).
mid_repr = sd.img2repr(res.result_image, ['mid_block'], 100)['mid_block']
channel_means = mid_repr.abs().mean(dim=(0, 2, 3))
plt.hist(channel_means, bins=np.logspace(np.log10(channel_means.min()), np.log10(channel_means.max()), 50))
plt.xscale('log')
plt.xlabel('mean absolute value')
plt.ylabel('number of channels')
plt.show()

## Dispersion of representation values

In [None]:
def plot_std(
        img: Image.Image,
        pos: list[str],
        num_samples: int,
        step: int = 50,
        vmax: float = 4,
) -> None:
    '''Show `img` next to a heat map of how much its representations vary between extractions.

    `sd.img2repr(img, pos, step)` is called `num_samples` times with identical
    arguments, and each result is merged with `concat_reprs`. The standard
    deviation is taken across the samples and then averaged over the first
    remaining axis (the channel axis of the `concat_reprs` output), leaving one
    value per spatial position.

    Args:
        img: Image to extract representations from.
        pos: Names of the representation positions to extract and concatenate.
        num_samples: Number of repeated extractions. With fewer than 2 samples
            the default `torch.std` estimate is NaN.
        step: Timestep passed to `sd.img2repr`.
        vmax: Upper limit of the heat map colour scale (the lower limit is 0).
    '''
    reprs = torch.stack([concat_reprs(sd.img2repr(img, pos, step), pos) for _ in trange(num_samples)])
    # Std over the sample axis, then mean over channels -> one value per pixel.
    std = reprs.std(dim=0).mean(dim=0)
    plt.subplot(1, 2, 1)
    plt.title('Input image')
    plt.imshow(img)
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.title(f'Standard deviation (mean: +-{std.mean().item():.2f})')
    plt.imshow(std.cpu().numpy(), vmin=0, vmax=vmax, cmap='hot')
    plt.show()

# Repeat the dispersion plot for five newly generated 'a cat' images,
# using 100 'mid_block' extractions per image.
for i in range(5):
    plot_std(sd('a cat').result_image, ['mid_block'], 100)

In [None]:
# Representation position and test image for the timestep sweep.
pos = ['up_blocks[1]']
img = sd('a cat').result_image

# For each timestep 0, 50, ..., 250: extract the representation of the same image
# 100 times, take the standard deviation across those samples and average it
# over all remaining axes into a single scalar.
stds = []
for timestep in tqdm(range(0, 251, 50)):
    reprs = torch.stack([concat_reprs(sd.img2repr(img, pos, timestep), pos) for _ in range(100)])
    std = reprs.std(dim=0).mean()  # 0-dim tensor
    stds.append(std)

In [None]:
# Show the image used for the timestep sweep as the cell output.
img

In [None]:
# Plot the mean standard deviation collected in `stds` against the timestep.
# The range here must match the one used when `stds` was filled.
plt.plot(range(0, 251, 50), stds)
plt.xlabel('Timestep')
plt.ylabel('Mean standard deviation')
# Render explicitly, like the other plotting cells in this notebook.
plt.show()