In [1]:
import numpy as np
from sdhelper import SD
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
import torch
from tqdm.autonotebook import tqdm, trange
from torch.nn import functional as F
import PIL.Image
from ipywidgets import interact, FloatSlider, IntSlider


In [None]:
# Load the ImageNet subset and keep the first NUM_IMAGES images as RGB PIL images.
# Only the rows that are actually kept are decoded and converted, instead of
# converting the whole split and discarding everything after the slice.
NUM_IMAGES = 20
imagenet_subset = load_dataset("JonasLoos/imagenet_subset", split="train")
images = [
    imagenet_subset[idx]['image'].convert('RGB')
    for idx in trange(min(NUM_IMAGES, len(imagenet_subset)))  # tolerate splits shorter than NUM_IMAGES
]

In [3]:
def rgb2gbr(arr: np.ndarray, interpolation: float):
    """Blend an image with a copy of itself whose last axis is reversed.

    Args:
        arr: array indexed as ``arr[:, :, channel]``.
        interpolation: weight of the channel-reversed copy; the original
            image is weighted by ``1 - interpolation``.

    Returns:
        The weighted sum of the reversed-channel image and the original.
    """
    flipped_channels = arr[:, :, ::-1]
    original_weight = 1 - interpolation
    return flipped_channels * interpolation + arr * original_weight

# Grass texture, loaded once as an RGB array and reused by `add_grass`.
grass_image = np.array(PIL.Image.open('../data/grass.png').convert('RGB'))
def add_grass(img: np.ndarray, interpolation: float):
    """Overlay the grass texture on `img`.

    The texture is cropped from its top-left corner to the height and width of
    `img`. The texture weight is ``interpolation / 2``, so at
    ``interpolation == 1`` the result is an even mix of image and texture.
    """
    rows, cols = img.shape[0], img.shape[1]
    texture_crop = grass_image[:rows, :cols, :]
    image_weight = 1 - interpolation / 2
    return img * image_weight + texture_crop * interpolation / 2

# Knitting texture, loaded once as an RGB array and reused by `add_knitting`.
knitting_image = np.array(PIL.Image.open('../data/knitting.png').convert('RGB'))
def add_knitting(img: np.ndarray, interpolation: float):
    """Overlay the knitting texture on `img`.

    The texture is cropped from its top-left corner to the height and width of
    `img`. The texture weight is ``interpolation / 2``, so at
    ``interpolation == 1`` the result is an even mix of image and texture.
    """
    rows, cols = img.shape[0], img.shape[1]
    texture_crop = knitting_image[:rows, :cols, :]
    image_weight = 1 - interpolation / 2
    return img * image_weight + texture_crop * interpolation / 2

def edges_only(arr: np.ndarray, interpolation: float):
    """Blend an image with a Sobel edge map computed from it.

    Args:
        arr: image array of shape (H, W, C); the first three channels are
            used for the grayscale conversion.
        interpolation: weight of the edge map; the original image is weighted
            by ``1 - interpolation``.

    Returns:
        Float array: ``edge_map * interpolation + arr * (1 - interpolation)``,
        where the edge map is a uint8 image replicated over three channels.
    """
    # Weighted channel sum -> grayscale, reshaped to (1, 1, H, W) for conv2d.
    gray_tensor = torch.from_numpy(np.dot(arr[..., :3], [0.2989, 0.5870, 0.1140]))[None, None, :, :]

    # Sobel kernels (horizontal and vertical gradients).
    kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float64)
    ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float64)

    # padding=1 keeps the output the same spatial size as the input.
    dx = F.conv2d(gray_tensor, kx, padding=1)
    dy = F.conv2d(gray_tensor, ky, padding=1)

    edges = torch.sqrt(dx**2 + dy**2)[0, 0].numpy()
    edges **= .5  # square root compresses the range, making weak edges more visible

    # Normalize to 0..255. An image without any gradient (e.g. all black) has
    # peak == 0; skip the division there instead of computing 0/0.
    peak = edges.max()
    if peak > 0:
        edges = edges / peak * 255
    edges = edges.astype(np.uint8)

    return np.stack([edges] * 3, axis=-1) * interpolation + arr * (1 - interpolation)

def add_noise(arr: np.ndarray, interpolation: float):
    """Mix uniform random noise into an image.

    Noise is drawn from NumPy's global random state, uniformly in [0, 255),
    with the same shape as `arr`. The noise weight is ``interpolation / 2``
    and the image weight is ``1 - interpolation / 2``.
    """
    image_part = arr * (1 - interpolation / 2)
    noise = np.random.rand(*arr.shape) * 255
    return image_part + noise * interpolation / 2

def blur(arr: np.ndarray, interpolation: float):
    """Apply a Gaussian blur whose radius grows with `interpolation`.

    Args:
        arr: image array accepted by ``PIL.Image.fromarray``.
        interpolation: blur strength; the blur radius is ``16 * interpolation``.

    Returns:
        The blurred image as a NumPy array.
    """
    # Imported locally: the file only does `import PIL.Image`, which does not
    # guarantee that the `PIL.ImageFilter` submodule is loaded and reachable
    # as an attribute of the `PIL` package.
    from PIL import ImageFilter

    gaussian = ImageFilter.GaussianBlur(radius=16 * interpolation)
    return np.array(PIL.Image.fromarray(arr).filter(gaussian))

# Image transformation used by the preview and representation cells of this
# notebook. It is called as `conversion_function(arr, interpolation)`, so any
# of the functions defined above can be assigned here.
conversion_function = add_noise

In [None]:
def interpolate_images(interpolation, image_index):
    """Display `images[image_index]` after applying `conversion_function`.

    Args:
        interpolation: strength passed through to `conversion_function`.
        image_index: index into the module-level `images` list.
    """
    source = np.array(images[image_index])
    transformed = conversion_function(source, interpolation)

    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(transformed.astype(np.uint8))
    ax.axis('off')
    plt.show()

# Interactive preview of the selected transformation.
interact(
    interpolate_images,
    interpolation=FloatSlider(min=0, max=1, step=0.01, value=0.5),
    # Slider bounds are inclusive, so the maximum must be the last valid index
    # of `images`, not its length.
    image_index=IntSlider(min=0, max=len(images) - 1, step=1, value=0)
)

In [None]:
# Static preview: the first image at five evenly spaced interpolation strengths.
preview_strengths = np.linspace(0, 1, 5)
preview_source = np.array(images[0])
_, preview_axes = plt.subplots(1, len(preview_strengths), figsize=(10, 2))
for ax, strength in zip(preview_axes, preview_strengths):
    ax.imshow(conversion_function(preview_source, strength).astype(np.uint8))
    ax.set_title(f'{strength:.2f}')
    ax.axis('off')
plt.show()

In [None]:
# Instantiate the `SD` helper from `sdhelper` with its default arguments. The
# following cells use `sd.img2repr` and `sd.available_extract_positions`.
# NOTE(review): presumably this loads a Stable Diffusion model — confirm which one.
sd = SD()


In [None]:
# Representations of the unmodified images (seed=0) at every extract position.
# NOTE(review): the positional `50` is presumably the diffusion timestep — confirm.
reference_representations = sd.img2repr(images, sd.available_extract_positions, 50, seed=0)

# One entry per interpolation strength: representations of the transformed images.
# These use seed=42, so even the strength-0 entry differs from the reference in seed.
representations = []
for step in tqdm(np.linspace(0, 1, 5)):
    tmp_images = [
        conversion_function(np.array(img), step).astype(np.uint8)
        for img in images
    ]
    representations.append(sd.img2repr(tmp_images, sd.available_extract_positions, 50, seed=42))

In [None]:
# Correspondence accuracy between reference and transformed representations.
# Result layout: accuracies[extract position, interpolation step, image].
# Use the GPU when one is present, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

accuracies = np.zeros((len(sd.available_extract_positions), len(representations), len(images)))
for block_idx, block in enumerate(tqdm(sd.available_extract_positions)):
    for i in trange(len(images)):
        a = reference_representations[i].at(block).to(device)
        for int_idx, bs in enumerate(representations):
            b = bs[i].at(block).to(device)
            sim = a.cosine_similarity(b)
            # NOTE(review): presumably `sim` is (n, n, n, n), i.e. pairwise
            # similarities between spatial positions — confirm against sdhelper.
            n = sim.shape[0]
            # Flatten to an (n*n, n*n) matrix; a column counts as correct when
            # its highest-similarity row has the same flat index as the column.
            best_match = sim.view(n*n, n*n).argmax(dim=0)
            is_correct = best_match == torch.arange(n*n, device=device)
            accuracies[block_idx, int_idx, i] = is_correct.float().mean().item()

In [None]:
# Sanity check; `accuracies` is allocated as (extract positions, interpolation steps, images).
accuracies.shape

In [None]:
# Top panel: mean accuracy per extract position over the interpolation steps.
# Bottom panel: the same curves as relative change from the first step.
fig, axs = plt.subplots(2, 1, figsize=(8, 6))
top_ax, bottom_ax = axs
steps_axis = np.linspace(0, 1, len(representations))
colors = plt.cm.rainbow(np.linspace(0, 1, len(sd.available_extract_positions)))

for block_idx, (block, color) in enumerate(zip(sd.available_extract_positions, colors)):
    mean_acc = accuracies[block_idx].mean(axis=1)  # average over images
    top_ax.plot(steps_axis, mean_acc, label=block, color=color)

    init_acc = accuracies[block_idx, 0].mean()  # accuracy at the first interpolation step
    acc_change = (mean_acc - init_acc) / init_acc
    bottom_ax.plot(steps_axis, acc_change, label=block, color=color)

top_ax.legend(bbox_to_anchor=(1.01, 0.5), loc='center left')
top_ax.set_xticklabels([])
top_ax.set_ylabel('Accuracy')

bottom_ax.set_xlabel('Interpolation step')
bottom_ax.set_ylabel('Rel. Change in Accuracy')

plt.tight_layout()
plt.show()
