In [1]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import torch.nn as nn

In [None]:
# Load the dataset (the model itself is instantiated in a later cell) and
# collect the 'image' field of every training sample into a plain list.
data = datasets.load_dataset('JonasLoos/imagenet_subset', split='train')
images = []
for sample in tqdm(data):
    images.append(sample['image'])

In [22]:

# Extraction positions inside the SD1.5 UNet, grouped by main block.
# Keys are the main blocks in forward order; values are the sub-layers of that
# main block, also in forward order. The grouping is reused by the plotting
# cells to draw separators and group labels.
sd15_all_blocks = {
    'conv_in': [
        'conv_in',
    ],
    'down_blocks[0]': [
        'down_blocks[0].resnets[0]',
        'down_blocks[0].attentions[0]',
        'down_blocks[0].resnets[1]',
        'down_blocks[0].attentions[1]',
        'down_blocks[0].downsamplers[0]',
    ],
    'down_blocks[1]': [
        'down_blocks[1].resnets[0]',
        'down_blocks[1].attentions[0]',
        'down_blocks[1].resnets[1]',
        'down_blocks[1].attentions[1]',
        'down_blocks[1].downsamplers[0]',
    ],
    'down_blocks[2]': [
        'down_blocks[2].resnets[0]',
        'down_blocks[2].attentions[0]',
        'down_blocks[2].resnets[1]',
        'down_blocks[2].attentions[1]',
        'down_blocks[2].downsamplers[0]',
    ],
    'down_blocks[3]': [
        'down_blocks[3].resnets[0]',
        'down_blocks[3].resnets[1]',
    ],
    'mid_block': [
        'mid_block.resnets[0]',
        'mid_block.attentions[0]',
        'mid_block.resnets[1]',
    ],
    'up_blocks[0]': [
        'up_blocks[0].resnets[0]',
        'up_blocks[0].resnets[1]',
        # Every up block has three resnets (one more than the down blocks);
        # the third one was missing from this list.
        'up_blocks[0].resnets[2]',
        'up_blocks[0].upsamplers[0]',
    ],
    'up_blocks[1]': [
        'up_blocks[1].resnets[0]',
        'up_blocks[1].attentions[0]',
        'up_blocks[1].resnets[1]',
        'up_blocks[1].attentions[1]',
        'up_blocks[1].resnets[2]',
        'up_blocks[1].attentions[2]',
        'up_blocks[1].upsamplers[0]',
    ],
    'up_blocks[2]': [
        'up_blocks[2].resnets[0]',
        'up_blocks[2].attentions[0]',
        'up_blocks[2].resnets[1]',
        'up_blocks[2].attentions[1]',
        'up_blocks[2].resnets[2]',
        'up_blocks[2].attentions[2]',
        'up_blocks[2].upsamplers[0]',
    ],
    'up_blocks[3]': [
        'up_blocks[3].resnets[0]',
        'up_blocks[3].attentions[0]',
        'up_blocks[3].resnets[1]',
        'up_blocks[3].attentions[1]',
        'up_blocks[3].resnets[2]',
        'up_blocks[3].attentions[2]',
    ],
    'conv_out': [
        'conv_out',
    ]
}
# Flat list of all sub-layers in forward order.
blocks = [x for y in sd15_all_blocks.values() for x in y]

In [None]:
# Instantiate the SD helper for 'SD1.5' and extract representations of all
# images at every position listed in `blocks`.
# NOTE(review): the positional 50 is presumably the diffusion step/timestep at
# which representations are taken - confirm against the sdhelper API.
sd = SD('SD1.5')
representations = sd.img2repr(images, blocks, 50, seed=42)
# The model is not needed after extraction: drop the reference and release
# cached CUDA memory before the probes are trained.
del sd
torch.cuda.empty_cache()

In [None]:
# train config
batch_size = 512
num_epochs = 5
num_train = int(len(representations) * 0.8)  # leading 80% of the images train, the rest test
torch.manual_seed(42)  # make probe initialisation and batch shuffling reproducible

# train models
# For every block two linear probes are fit on single-token features:
#   * a regressor predicting the (x, y) token position as two real numbers
#   * a classifier predicting the x and y index as two w-way classifications
regressors = []
classifiers = []
# accuracy[block, epoch] on the test split; NaN marks entries that were never filled
accuracies_reg = torch.full((len(blocks), num_epochs), torch.nan)
accuracies_cls = torch.full((len(blocks), num_epochs), torch.nan)
for block_idx, block in enumerate(tq:=tqdm(blocks)):
    tq.set_postfix(block=block)
    _, features, w, h = representations[0][block].shape
    # The classifier emits w+h logits which are viewed as (w, 2) below; that
    # layout is only valid for square feature maps.
    assert w == h, f'{block}: expected a square feature map, got {w}x{h}'
    regressor = nn.Linear(features, 2).cuda()
    opt1 = torch.optim.Adam(regressor.parameters(), lr=1e-3)
    classifier = nn.Linear(features, w+h).cuda()
    opt2 = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    regressors.append(regressor)
    classifiers.append(classifier)

    # one row per token: (num_images * w * h, features)
    reprs = torch.stack([r[block].squeeze(0) for r in representations]).permute(0, 2, 3, 1).flatten(0, 2).cuda()
    # matching integer (x, y) grid index of every token: (num_images * w * h, 2)
    labels = torch.stack(torch.meshgrid(torch.arange(w), torch.arange(h), indexing='ij'), dim=-1).expand(len(representations), -1, -1, -1).flatten(0, 2).cuda()
    # split by image so that no test image contributes tokens to training
    split = num_train * w * h
    reprs_train, reprs_test = reprs[:split], reprs[split:]
    labels_train, labels_test = labels[:split], labels[split:]

    for epoch in trange(num_epochs, leave=False):
        regressor.train()
        classifier.train()
        indices = torch.randperm(len(reprs_train))
        for i in trange(0, len(reprs_train), batch_size, leave=False):
            # shuffled token indices of the current batch (the last batch may be shorter)
            batch_indices = indices[i:i+batch_size]
            reprs_batch = reprs_train[batch_indices].float()
            labels_batch = labels_train[batch_indices]

            opt1.zero_grad()
            preds = regressor(reprs_batch)
            loss = nn.functional.mse_loss(preds, labels_batch.float())
            loss.backward()
            opt1.step()

            opt2.zero_grad()
            # -1 instead of batch_size: the final batch is smaller whenever the
            # number of training tokens is not a multiple of batch_size, and a
            # fixed-size view would raise there.
            preds = classifier(reprs_batch).view(-1, w, 2)
            # cross_entropy expects classes on dim 1: input (N, w, 2), target (N, 2)
            loss = nn.functional.cross_entropy(preds, labels_batch)
            loss.backward()
            opt2.step()

        with torch.no_grad():
            regressor.eval()
            preds = regressor(reprs_test.float())
            # fraction of individual coordinates that round to the correct index
            acc_reg = (preds.round() == labels_test).float().mean()

            classifier.eval()
            preds = classifier(reprs_test.float()).view(-1, w, 2)
            # fraction of individual coordinates whose arg-max class is correct
            acc_cls = (preds.argmax(dim=1) == labels_test).float().mean()

        accuracies_reg[block_idx, epoch] = acc_reg.cpu()
        accuracies_cls[block_idx, epoch] = acc_cls.cpu()


In [None]:
# Test accuracy of both probe types for every extracted layer, one line per
# epoch (later epochs are drawn in darker shades).
fig, ax = plt.subplots(figsize=(8, 3))

# plot x ticks: one short layer-type label per extracted layer
x = np.arange(len(blocks))
# checked in order, first match wins
tick_kinds = (
    ('attentions', 'attn'),
    ('resnets', 'res'),
    ('downsamplers', 'down'),
    ('upsamplers', 'up'),
    ('conv', 'conv'),
)
ticks = [next((kind for needle, kind in tick_kinds if needle in block), '?') for block in blocks]
ax.set_xticks(x)
ax.set_xticklabels(ticks, rotation=90)


# compute main blocks names and positions
main_blocks = []
main_block_positions = []  # x index of the first layer of every main block
layer_counter = 0
for block_name, block_list in sd15_all_blocks.items():
    if 'mid' in block_name:
        name = 'mid'
    elif 'conv' in block_name:
        name = block_name[len('conv_'):]  # 'conv_in' -> 'in', 'conv_out' -> 'out'
    else:
        name = block_name.replace('_blocks','').replace('[','').replace(']','')
    main_blocks.append(name)
    main_block_positions.append(layer_counter)
    layer_counter += len(block_list)

# plot main blocks names, centred below the layers they span
group_centers = [p+len(bl)/2-0.5 for p, bl in zip(main_block_positions, sd15_all_blocks.values())]
ax_x2 = ax.secondary_xaxis(location=0)
ax_x2.set_xticks(group_centers, labels=[f'\n\n\n{b}' for b in main_blocks], ha='center')
ax_x2.tick_params(length=0)

# lines between main blocks
boundaries = [p-0.5 for p in main_block_positions[1:]]
for boundary in boundaries:
    # a single colour: the original passed both `color` and its alias `c`
    ax.axvline(x=boundary, linestyle='--', color='lightgray')
ax_x3 = ax.secondary_xaxis(location=0)
ax_x3.set_xticks(boundaries, labels=[])
ax_x3.tick_params(axis='x', length=34, width=1.5, color='lightgray')

# plot accuracies
colors_reg = plt.cm.Blues(np.linspace(0.3, 0.8, accuracies_reg.shape[1]))
colors_cls = plt.cm.Reds(np.linspace(0.3, 0.8, accuracies_cls.shape[1]))
for i in range(accuracies_reg.shape[1]):
    l1 = ax.plot(accuracies_cls[:, i]*100, label='classification', color=colors_cls[i])
    l2 = ax.plot(accuracies_reg[:, i]*100, label='regression', color=colors_reg[i])

# finish plot
ax.set_ylabel('test accuracy (%)')  # values are scaled by 100 above
# one legend entry per probe type, using the lines of the last epoch
ax.legend([l1[0], l2[0]], [l1[0].get_label(), l2[0].get_label()])
plt.tight_layout()
plt.show()

In [None]:
# One row per main block (using its last layer): true token positions next to
# the positions predicted by the classifier and by the regressor for the first
# test image. Every token keeps the colour of its true position in all panels.
fig, axs = plt.subplots(len(sd15_all_blocks), 3, figsize=(8, 2*len(sd15_all_blocks)))

# last layer of every main block; computed once instead of once per iteration
last_layers = {layers[-1] for layers in sd15_all_blocks.values()}

row = 0
for block, regressor, classifier in zip(blocks, regressors, classifiers):
    if block not in last_layers:
        continue
    with torch.no_grad():
        regressor.eval()
        classifier.eval()

        # Get predictions for the first image of the test split
        r = representations[num_train][block].squeeze(0).permute(1,2,0)
        w, h, features = r.shape
        r = r.flatten(0, 1).float().cuda()
        preds_reg = regressor(r)
        # classes live on dim 1 of the (N, w, 2) view, as during training
        preds_cls = classifier(r).view(len(r), w, 2).argmax(dim=1)

    # Position colour code: hue from the angle around `offset`, third channel
    # from the distance to it. Vectorised over the whole grid.
    # NOTE(review): offset = w/2 + .5 lies one cell past the geometric centre
    # (w-1)/2 of the index grid - confirm this is intended.
    offset = w/2 + .5
    di = torch.arange(w).view(-1, 1) - offset
    dj = torch.arange(h).view(1, -1) - offset
    angle = torch.atan2(di, dj)
    dist = 1 - torch.sqrt(di**2 + dj**2) / offset / 2**0.5
    colorwheel = torch.stack([.5+.5*torch.sin(angle), .5+.5*torch.sin(angle+torch.pi/2), dist], dim=-1).clamp(0, 1)
    colors = colorwheel.flatten(0, 1).numpy()  # same token order as `r`

    # Plot actual vs predicted positions
    if row == 0:
        axs[0, 0].set_title('True Positions')
        axs[0, 1].set_title('Classification Results')
        axs[0, 2].set_title('Regression Results')
    block_name = block.split('.')[0].replace('_blocks', '').replace('_block', '').replace('_', '-')
    axs[row, 0].text(-0.1, 0.5, block_name, ha='right', va='center', transform=axs[row, 0].transAxes)

    panels = (
        (np.arange(w).repeat(h), np.tile(np.arange(h), w)),  # true positions
        preds_cls.T.cpu(),                                    # classification results
        preds_reg.T.cpu(),                                    # regression results
    )
    for col, (xs, ys) in enumerate(panels):
        panel = axs[row, col]
        panel.scatter(xs, ys, c=colors, s=5000/w/h, alpha=0.5)
        # x carries the first (w) index and y the second (h) index
        panel.set_xlim(-0.1*w, 1.1*w)
        panel.set_ylim(-0.1*h, 1.1*h)
        panel.set_aspect('equal')
        panel.axis('off')

    row += 1

plt.tight_layout()
plt.show()

In [None]:
# One figure per extracted layer: reference colour code of the true token
# positions, then the positions predicted by the classifier and by the
# regressor for the first test image.
for block, regressor, classifier in zip(blocks, regressors, classifiers):
    with torch.no_grad():
        regressor.eval()
        classifier.eval()

        # Get predictions for the first image of the test split
        r = representations[num_train][block].squeeze(0).permute(1,2,0)
        w, h, features = r.shape
        r = r.flatten(0, 1).float().cuda()
        preds_reg = regressor(r)
        # classes live on dim 1 of the (N, w, 2) view, as during training
        preds_cls = classifier(r).view(len(r), w, 2).argmax(dim=1)

    # Position colour code: hue from the angle around `offset`, third channel
    # from the distance to it. Vectorised over the whole grid.
    # NOTE(review): offset = w/2 + .5 lies one cell past the geometric centre
    # (w-1)/2 of the index grid - confirm this is intended.
    offset = w/2 + .5
    di = torch.arange(w).view(-1, 1) - offset
    dj = torch.arange(h).view(1, -1) - offset
    angle = torch.atan2(di, dj)
    dist = 1 - torch.sqrt(di**2 + dj**2) / offset / 2**0.5
    colorwheel = torch.stack([.5+.5*torch.sin(angle), .5+.5*torch.sin(angle+torch.pi/2), dist], dim=-1).clamp(0, 1)
    colors = colorwheel.flatten(0, 1).numpy()  # same token order as `r`

    # same marker size in all three panels (the regression panel used 100/w)
    marker_size = 10000/w/h

    # Plot actual vs predicted positions
    fig, axs = plt.subplots(1, 3, figsize=(11, 4))
    panels = (
        # (title, (x values, y values), extra scatter kwargs)
        ('Colorwheel', (np.arange(w).repeat(h), np.tile(np.arange(h), w)), {}),
        (block, preds_cls.T.cpu(), {'alpha': 0.5}),  # classification results
        (block, preds_reg.T.cpu(), {'alpha': 0.5}),  # regression results
    )
    for panel, (title, (xs, ys), extra) in zip(axs, panels):
        panel.scatter(xs, ys, c=colors, s=marker_size, **extra)
        panel.set_title(title)
        # x carries the first (w) index and y the second (h) index
        panel.set_xlim(-0.1*w, 1.1*w)
        panel.set_ylim(-0.1*h, 1.1*h)
        panel.axis('off')

    plt.tight_layout()
    plt.show()
    # one figure per layer: close it so figures do not accumulate in memory
    plt.close(fig)

In [None]:
# Scratch inspection of the classifier's predicted (x, y) indices.
# Relies on `classifier`, `r` and `w` left over from the last iteration of the
# loop in the previous cell, so it only works when run right after that cell.
with torch.no_grad():
    # Classes live on dim 1 of the (N, w, 2) view, as in training and plotting;
    # arg-max over the last dim would compare the x logit with the y logit.
    scratch_preds = classifier(r).view(len(r), w, 2).argmax(dim=1)
scratch_preds
