In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sdhelper import SD
from tqdm.notebook import tqdm, trange
import torch
import pickle

In [None]:
# Report which CUDA device (if any) torch can see in this kernel.
if not torch.cuda.is_available():
    print("No GPU available.")
else:
    gpu_name = torch.cuda.get_device_name(0)
    print("GPU Name:", gpu_name)

In [None]:
# Labels from imagenet-labels.txt, one per line; sampled as prompts below.
# `with` closes the file handle. splitlines() drops the trailing '\n' that
# readlines() kept, which otherwise ends up in prompts and plot titles.
with open('imagenet-labels.txt', encoding='utf-8') as f:
    labels = f.read().splitlines()

In [None]:
# Instantiate the sdhelper SDXL-Turbo wrapper once; the experiment cell reuses `sd`.
sd = SD(
    'SDXL-Turbo',
    disable_progress_bar=True,
)

In [None]:
# Experiment: for each sampled prompt, zero out one spatial position of the
# 'mid_block' output at a time and measure how much the generated image changes.
#   impact_image[i, j]        -- mean absolute pixel change when (i, j) is masked
#   impact_distribution[i, j] -- per-pixel absolute change (channel-averaged),
#                                summed over all sampled prompts
# NOTE(review): 1280 channels, a 16x16 mid-block grid and 512x512 output images
# are assumed for SDXL-Turbo at its default resolution -- confirm for other settings.
N_SAMPLES = 1      # number of random prompts to evaluate
H_GRID = 16        # spatial size of the masked mid-block feature map
H_CHANNELS = 1280  # channel count of the mid-block output
IMG_SIZE = 512     # side length of the generated images

data = []
impact_distribution = np.zeros([H_GRID, H_GRID, IMG_SIZE, IMG_SIZE])
for _ in trange(N_SAMPLES):
    prompt = np.random.choice(labels)
    # Explicit int64: with a 32-bit default int, the bound 2**32 would overflow.
    seed = int(np.random.randint(0, 2**32, dtype=np.int64))
    base_img = sd(prompt, steps=1, seed=seed).result_image
    # Cast to float before subtracting. For uint8 image arrays, a - b wraps
    # around (10 - 20 == 246) instead of going negative, which corrupts abs().
    base_arr = np.asarray(base_img, dtype=np.float64)
    impact_image = np.zeros([H_GRID, H_GRID])
    for i, j in tqdm(list(np.ndindex(H_GRID, H_GRID))):
        mask = torch.ones([1, H_CHANNELS, H_GRID, H_GRID], device=sd.device, dtype=torch.float16)
        mask[:, :, i, j] = 0
        # Scale only the 'mid_block' output; the hook returns None elsewhere
        # (presumably meaning "leave unchanged" -- sdhelper semantics, TODO confirm).
        mod_img = sd(prompt, steps=1, seed=seed, modification=lambda module, input, output, pos: output * mask if pos == 'mid_block' else None).result_image
        diff = np.abs(base_arr - np.asarray(mod_img, dtype=np.float64))
        impact_image[i, j] = diff.mean()
        # Average over colour channels so the per-pixel map fits the (H, W) slot.
        impact_distribution[i, j] += diff.mean(axis=-1) if diff.ndim == 3 else diff
    data.append((prompt, seed, base_img, impact_image))

In [None]:
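# Save results (uncomment to run). Filenames follow the `-dataN` / `-impact_distributionN`
# pattern read by the load cell; impact_distribution is saved too because that cell loads it.
# with open('h-space-locality-test-data3.pkl', 'wb') as f:
#     pickle.dump(data, f)
# np.save('h-space-locality-test-impact_distribution3.npy', impact_distribution)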

In [None]:
# Load previously saved results when `data` is not already defined in this kernel
# (i.e. the experiment cell was skipped).
# NOTE(review): pickle.load and np.load(allow_pickle=True) can execute arbitrary
# code from the file -- only load files you produced yourself.
if 'data' not in locals():
    with open('h-space-locality-test-data2.pkl', 'rb') as f:
        data = pickle.load(f)
    impact_distribution = np.load('h-space-locality-test-impact_distribution2.npy', allow_pickle=True)

In [None]:
# Heatmap of the per-prompt masking-impact maps (4th field of each `data` entry),
# averaged element-wise over all prompts.
fig, ax = plt.subplots()
avg_heatmap = ax.imshow(np.mean([entry[3] for entry in data], axis=0))
fig.colorbar(avg_heatmap)
ax.set_axis_off()
plt.show()

In [None]:
# Tile every per-position impact map into one large image: tile (row i, column j)
# is impact_distribution[i, j], the pixel-wise impact of masking position (i, j).
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(np.block([list(tile_row) for tile_row in impact_distribution]))
ax.set_axis_off()
plt.show()

In [None]:
# Histogram-like scatter plot: per-pixel impact vs. distance from the masked position.

def plot_impact_distribution(i, j, distribution=None, n_samples=50, window=200):
    """Scatter per-pixel impact against distance from masked h-space position (i, j).

    The image is divided into a grid with the same shape as the h-space grid.
    Each pixel of the impact map for (i, j) is plotted at its Euclidean
    distance (in pixels) from the centre of grid cell (i, j). A red curve
    shows the mean impact per distinct distance, smoothed with a trailing
    moving average over ``window`` consecutive distinct distances.

    Args:
        i, j: row / column index of the masked h-space position.
        distribution: array of shape (grid_h, grid_w, img_h, img_w). Defaults
            to the notebook-level ``impact_distribution``.
        n_samples: divisor applied to the impact values. The original cell
            divides by 50, presumably the number of accumulated prompts --
            TODO confirm against the data being plotted.
        window: moving-average width (>= 1). The curve is skipped when there
            are not more distinct distances than this.
    """
    if distribution is None:
        distribution = impact_distribution
    grid_h, grid_w, img_h, img_w = distribution.shape

    # Pixel coordinates of the centre of grid cell (i, j).
    base_pos = (np.array([i, j]) + .5) / np.array([grid_h, grid_w]) * np.array([img_h, img_w])
    # (row, col) of every pixel in row-major order, matching .ravel() below.
    positions = np.indices((img_h, img_w)).reshape(2, -1).T
    values = distribution[i, j].ravel() / n_samples
    distances = np.sqrt(((positions - base_pos) ** 2).sum(axis=1))

    plt.scatter(distances, values, s=1, alpha=.05)
    plt.title(f'Impact distribution for position {i},{j}')
    plt.xlabel('Distance from position')
    plt.ylabel('Impact')

    # Mean impact per distinct distance (vectorised group-by; x is sorted ascending).
    x, inverse = np.unique(distances, return_inverse=True)
    y = np.bincount(inverse, weights=values) / np.bincount(inverse)
    if len(y) > window:
        # Trailing moving average: the value at x[m] averages y[m-window+1 .. m].
        y_cumsum = np.cumsum(y)
        y_smooth = (y_cumsum[window:] - y_cumsum[:-window]) / window
        plt.plot(x[window:], y_smooth, label='mean impact', color='red')
        plt.legend()
    plt.show()

# Only the diagonal positions (0, 0), (1, 1), ..., (15, 15) are plotted.
for diag in range(16):
    plot_impact_distribution(diag, diag)

In [None]:
# Per masked position: distance within which a given fraction of its impact lies,
# relative to a spatially uniform impact distribution.
def plot_perc_impact_distance2(percentile=.5, distribution=None):
    """Plot, for every h-space position, how concentrated its impact is around it.

    The image is divided into a grid with the same shape as the h-space grid.
    For masked position (i, j), pixels are sorted by distance from the centre
    of grid cell (i, j). We find the distance at which the cumulative,
    normalised impact first exceeds ``percentile``, then subtract the same
    distance for a uniform impact map. Negative values therefore mean the
    impact is more local than uniform. Positions with zero total impact are NaN.

    Args:
        percentile: cumulative-impact fraction, expected in (0, 1).
        distribution: array of shape (grid_h, grid_w, img_h, img_w). Defaults
            to the notebook-level ``impact_distribution``.
    """
    if distribution is None:
        distribution = impact_distribution
    grid_h, grid_w, img_h, img_w = distribution.shape
    grid_shape = np.array([grid_h, grid_w])
    img_shape = np.array([img_h, img_w])

    # Loop-invariant: pixel (row, col) coordinates and the uniform baseline's cumulative mass.
    positions = np.indices((img_h, img_w)).reshape(2, -1).T
    n_pixels = img_h * img_w
    uniform_cumsum = np.cumsum(np.full(n_pixels, 1 / n_pixels))

    def first_exceeding(cumsum):
        # First index with cumsum > percentile (same as np.argmax(cumsum > percentile)
        # for a non-decreasing cumsum). Clamped so that float round-off near 1.0
        # cannot make it silently fall back to index 0 as argmax would.
        return min(int(np.searchsorted(cumsum, percentile, side='right')), len(cumsum) - 1)

    impact_distance = np.zeros([grid_h, grid_w])
    for i, j in tqdm(np.ndindex(grid_h, grid_w), total=grid_h * grid_w):
        base_pos = (np.array([i, j]) + .5) / grid_shape * img_shape
        distances = np.sqrt(((positions - base_pos) ** 2).sum(axis=1))
        order = np.argsort(distances)
        sorted_distances = distances[order]
        values = distribution[i, j].ravel()
        total = values.sum()
        if total <= 0:
            # No measurable impact: the percentile distance is undefined.
            impact_distance[i, j] = np.nan
            continue
        values_cumsum = np.cumsum(values[order] / total)
        result_dist = sorted_distances[first_exceeding(values_cumsum)]
        uniform_dist = sorted_distances[first_exceeding(uniform_cumsum)]
        impact_distance[i, j] = result_dist - uniform_dist

    plt.title(f'{percentile:.0%} impact distance minus uniform baseline')
    plt.imshow(impact_distance)
    plt.colorbar()
    plt.axis('off')
    plt.show()

# Sweep a few cumulative-impact fractions, from very local (1%) to broad (90%).
for perc in (.01, .1, .5, .9):
    plot_perc_impact_distance2(perc)

In [None]:
# For each sample, show the generated image next to its masking-impact map.
for prompt, seed, base_img, impact_image in data:
    fig, (ax_img, ax_impact) = plt.subplots(1, 2)
    ax_img.set_title(f'"{prompt}"')
    ax_img.imshow(base_img)
    ax_img.set_axis_off()
    ax_impact.set_title('masking impact')
    ax_impact.imshow(impact_image)
    ax_impact.set_axis_off()
    fig.tight_layout()
    plt.show()