In [None]:
import numpy as np
from sklearn.metrics import pairwise
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.cluster.hierarchy as spc
import os
import pickle
from lib.get_model import get_model, GetGANModelParameters, GetDiffusionModelParameters
from lib.data import create_dataset
from lib.random import set_random_seed
import torch
from lib.path import create_evaluate_path
import lib.metrics as metrics
from lib.metrics import multiprocessed_pixel_wise_correlation_phi
from mpl_toolkits.axes_grid1 import host_subplot
import mpl_toolkits.axisartist as AA
import matplotlib.pyplot as plt
import torchvision.utils as vutils

In [None]:
def CKA(X, Y, kernel=lambda x,y: pairwise.rbf_kernel(x,y)):
    '''Estimates centered kernel alignment between k(X,X') and k(Y,Y').

    Args:
        X, Y: Arrays of identical shape. The first axis indexes the samples
            (its length fixes the size of the centering matrix).
        kernel: Callable ``kernel(A, B)`` returning the Gram matrix between
            the rows of ``A`` and the rows of ``B``. Defaults to
            scikit-learn's RBF kernel with its default bandwidth.

    Returns:
        ``<K_Xc, K_Yc>_F / (||K_Xc||_F * ||K_Yc||_F)`` where ``K_Xc`` and
        ``K_Yc`` are the double-centered Gram matrices of ``X`` and ``Y``.

    Raises:
        AssertionError: If ``X`` and ``Y`` do not have the same shape.
    '''
    assert X.shape == Y.shape
    n_size = X.shape[0]
    K_X = kernel(X, X)
    K_Y = kernel(Y, Y)
    # Centering matrix H = I - (1/n) 11^T; H K H removes row and column means.
    centering = np.eye(n_size) - np.ones((n_size, n_size))/n_size
    K_Xc = centering @ K_X @ centering
    K_Yc = centering @ K_Y @ centering
    K_Xc_norm = np.linalg.norm(K_Xc, ord='fro')
    K_Yc_norm = np.linalg.norm(K_Yc, ord='fro')
    # trace(A.T @ B) is the Frobenius inner product sum(A * B); the element-wise
    # form avoids building a full n x n matrix product just to read its diagonal.
    return np.sum(K_Xc * K_Yc) / (K_Xc_norm * K_Yc_norm)

In [None]:
import multiprocess as mp
from functools import partial
from lib.notebook import get_tqdm

# Progress-bar callable supplied by the project's notebook helper; the cells
# below call it as tqdm(iterable, total=...).
tqdm = get_tqdm()

def pixel_wise_CKA(index, image_size, samples, kernel=lambda x, y: pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Computes the CKA between one pixel and every pixel with a larger linear index.

    Args:
        index (int): Row-major linear index of the reference pixel.
        image_size (int): Side length of the (square) image grid.
        samples: Array indexed as ``samples[:, :, row, col]``; the two leading
            axes are handed to ``CKA`` unchanged (presumably
            (batch, channels, height, width) -- confirm against the caller).
        kernel (function): Kernel forwarded to ``CKA``. Defaults to the RBF
            kernel with ``gamma=None``.

    Returns:
        list: ``image_size`` lists of ``image_size`` values. The entry for a
        pixel whose linear index is greater than ``index`` holds the CKA value;
        every other entry (including the reference pixel itself) is the
        placeholder ``1.``.
    """
    row, col = divmod(index, image_size)
    grid = []
    for other_row in range(image_size):
        grid_row = []
        for other_col in range(image_size):
            other_index = other_row * image_size + other_col
            if index < other_index:
                # Only the upper triangle (in linear-index terms) is computed.
                value = CKA(samples[:, :, row, col], samples[:, :, other_row, other_col], kernel=kernel)
            else:
                value = 1.
            grid_row.append(value)
        grid.append(grid_row)
    return grid

def multiprocessed_pixel_wise_CKA(image_size, samples, cpus=8, kernel=lambda x, y: pairwise.rbf_kernel(x, y, gamma=None)):
    """
    Runs ``pixel_wise_CKA`` for every pixel of the image in a process pool.

    Args:
        image_size (int): Side length of the (square) image grid.
        samples: Sample array forwarded unchanged to ``pixel_wise_CKA``.
        cpus (int): Number of worker processes in the pool. Default is 8.
        kernel (function): Kernel forwarded to ``pixel_wise_CKA``. Defaults to
            the RBF kernel with ``gamma=None``.

    Returns:
        numpy.ndarray: Array of shape
        (image_size, image_size, image_size, image_size). The first two axes
        address the reference pixel, the last two the compared pixel. Only
        pairs where the compared pixel has the larger linear index carry a
        computed value; all others hold the ``1.`` placeholder.
    """
    n_pixels = image_size**2
    worker = partial(pixel_wise_CKA, image_size=image_size, samples=samples, kernel=kernel)
    with mp.Pool(cpus) as pool:
        # imap preserves input order, so row k of the result belongs to pixel k.
        per_pixel = list(tqdm(pool.imap(worker, range(n_pixels)), total=n_pixels))
    corr = np.array(per_pixel)
    return corr.reshape(image_size, image_size, image_size, image_size)

In [None]:
##########################
### DCGAN CelebA #########
##########################

In [None]:

# --- Configuration for the DCGAN / CelebA analysis ---------------------------
model_name = "DCGAN64x64"
max_images_per_epoch = 0
# n and m are the two leading axes that generated samples are reshaped to below.
n = 20
m = n
batch_size_gmmd = 150  # 100
batch_size_corr = 100  # images per pixel-wise CKA batch
nc = 3                 # image channels
max_epochs = 50
epochs = list(range(0, max_epochs + 1))  # [1, 25, 50]
# Seed before creating the dataset so that everything after this line runs
# under the seeded state.
seed = set_random_seed(6746)
dataset_name = "celeba"
image_size = 64
n_subset = 1000
cpus = 75
cmap = plt.cm.Reds
plt.rcParams.update({'font.size': 14})

training_dataset = create_dataset(dataset_name, image_size, subset=n_subset, load_dataset_in_memory=True)
training_data = training_dataset.data.numpy()
print(training_data.shape)

# heuristic for setting kernel bandwidth (=gamma):
# inverse median of the squared Euclidean distances between flattened images
gamma = 1/np.median(pairwise.euclidean_distances(training_data.reshape(training_data.shape[0], -1), squared=True))

def rbf_kernel_(x,y):
    """RBF kernel with the bandwidth `gamma` computed above (read at call time)."""
    return pairwise.rbf_kernel(x, y, gamma=gamma)

kernel = rbf_kernel_
kernel_name = "rbf_kernel"

In [None]:
# Create the cache directory (and any missing parents) idempotently, so the
# cell can be re-run without raising FileExistsError.
os.makedirs(f"data/{dataset_name}/corrs/", exist_ok=True)
agg = []  # not referenced again in the visible notebook
for i in range(n_subset // batch_size_corr):
    batch = training_data[i*batch_size_corr:(i+1)*batch_size_corr]
    data_corr_tri = multiprocessed_pixel_wise_CKA(image_size, batch, cpus=cpus, kernel=kernel)
    data_corr_tri = data_corr_tri.reshape(image_size**2, image_size**2)
    ### make triangular matrix symmetric
    # Only the strict upper triangle was computed; mirror it into the lower one.
    data_corr = data_corr_tri.copy()
    lower_tri = np.tril_indices_from(data_corr_tri, k=-1)
    data_corr[lower_tri] = data_corr_tri.T[lower_tri]
    data_corr = data_corr.reshape(image_size, image_size, image_size, image_size)

    # One pickle per batch; the next cell reloads and averages them.
    with open(f"data/{dataset_name}/corrs/CKA {kernel_name} {seed} n_subset-{n_subset} batch-{i}.pkl", "wb") as file:
        pickle.dump(data_corr, file)

In [None]:
# Reload the per-batch CKA matrices written by the previous cell and average
# them element-wise.
n_batches = n_subset // batch_size_corr
data_corr = []
for i in tqdm(range(n_batches), total=n_batches):
    with open(f"data/{dataset_name}/corrs/CKA {kernel_name} {seed} n_subset-{n_subset} batch-{i}.pkl", "rb") as file:
        data_corr.append(pickle.load(file))
# Stack along a new leading batch axis, then average that axis away.
data_corr = np.array(data_corr).mean(axis=0)

In [None]:
# Output directory for both figures of this cell.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/corr_pixel_wise_cluster/", exist_ok = True)

num_clusters = 5
# One colour per cluster label (fcluster labels start at 1).
colors = {
    1: '#003f5c',
    2: '#fb8500',
    3: '#e40b0b',
    4: '#219ebc',
    5: '#8ab17d',
}

# Each row of the flattened matrix is one pixel's CKA profile against all pixels.
big_corr_matrix = data_corr.reshape(image_size*image_size, image_size*image_size)
plt.rcParams.update({'font.size': 22})
# Hierarchical (average-linkage) clustering of the pixels by their CKA profiles.
pdist = spc.distance.pdist(big_corr_matrix)
linkage = spc.linkage(pdist, method='average')
if num_clusters is None:
    # No target count: cut the tree at half of the largest pairwise distance.
    idx = spc.fcluster(linkage, 0.5 * pdist.max(), criterion='distance')
else:
    idx = spc.fcluster(linkage, criterion='maxclust', t=num_clusters)
# Cluster label of every pixel, laid out on the image grid.
cluster = np.array(idx).reshape(image_size, image_size)
plt.imshow(cluster, cmap=plt.cm.colors.ListedColormap(colors.values()))
handles = [plt.Line2D([0, 1], [0, 1], color=colors[label], linewidth=5) for label in range(1, num_clusters + 1)]
labels = [f"Cluster {label}" for label in range(1, num_clusters + 1)]
plt.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1))
plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/corr_pixel_wise_cluster/Clustering {kernel_name=}.pdf", bbox_inches='tight')
plt.show()
# Reorder rows and columns so that pixels of the same cluster are adjacent.
order = np.argsort(idx)
reordered_corr_matrix = big_corr_matrix[order][:, order]
print(reordered_corr_matrix.shape)
plt.imshow(reordered_corr_matrix, cmap=cmap, vmin=0, vmax=1)
plt.colorbar()
plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/corr_pixel_wise_cluster/Clustered Correlation Matrix {kernel_name=} {num_clusters}.pdf", bbox_inches='tight')
plt.show();

In [None]:
# Load the test split and arrange part of it into fixed-size groups for the
# distribution metrics below.
create_evaluate_path(dataset_name, image_size)
dataset = create_dataset(dataset_name, image_size, load_dataset_in_memory=True, type="test")
# NOTE(review): `dataloader` is not referenced again in the visible notebook.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size_gmmd, shuffle=False, num_workers=0)
# Keep only the first sixteenth of the test split.
data = dataset.data[:len(dataset)//16]
# Drop the remainder that does not fill a whole group, then reshape to
# (groups, batch_size_gmmd, nc, image_size, image_size).
data = np.array(data[:(len(data)//batch_size_gmmd) * batch_size_gmmd].reshape(-1, batch_size_gmmd, nc, image_size, image_size))
data.shape

In [None]:
# Select the GAN configuration(s) whose get() tuple has 64 at positions 1 and 2
# and 0.0002 at position 3 (unpacked further below as ngf, ndf and lr).
model_params = GetGANModelParameters()
model_params = [
    model_param for model_param in model_params
    if model_param.get()[1]==64 and model_param.get()[2]==64 and model_param.get()[3]==0.0002
]
# Load the pickled generated samples of the first matching configuration.
filename = f'data/samples DCGAN64x64 0 {model_params[0]}.pkl'
with open(filename, 'rb') as file:
    exp_results = pickle.load(file)
# Add a leading "model" axis of length 1 so later loops can iterate over models.
tests = np.array([exp_results])
tests.shape

In [None]:
# For every cluster label: gather the test-data pixels belonging to that
# cluster and flatten them (together with the channel axis) into one feature
# axis, keeping the two leading axes of `data`.
test_data_pixel_cluster = [np.array([data[:, :, :, x, y] for x in range(image_size) for y in range(image_size) if cluster[x, y] == c]).transpose(1, 2, 0, 3).reshape(*data.shape[:2], -1) for c in set(cluster.flatten())]

In [None]:
# Number of leading epoch entries evaluated in the cluster-wise loops below.
cluster_max_epochs = 50

In [None]:
# Cluster-wise metrics: for every model, epoch and pixel cluster, compare the
# generated pixels of that cluster against the matching test-data pixels.
kernel = rbf_kernel_
# Nested result lists indexed as [model][cluster] -> list of per-epoch values.
gmmd_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
cos_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
for j, model_tests in enumerate(tests):
    for test_epoch in tqdm(model_tests[:cluster_max_epochs], total=cluster_max_epochs):
        # Same per-cluster gathering as for the test data, applied to this epoch's samples.
        generated_data_pixel_cluster = [np.array([test_epoch[:, :, :, x, y] for x in range(image_size) for y in range(image_size) if cluster[x, y] == c]).transpose(1, 2, 0, 3).reshape(n, m, -1) for c in set(cluster.flatten())]
        for i, (generated_cluster, test_cluster) in enumerate(zip(generated_data_pixel_cluster, test_data_pixel_cluster)):
            gmmd_dist[j][i].append(metrics.gmmd(generated_cluster, test_cluster, kernel=kernel))
            cos_dist[j][i].append(metrics.cos(generated_cluster, test_cluster, kernel=kernel))

In [None]:
# Cache the cluster-wise results so the expensive cell above need not be re-run.
with open(f"data/{dataset_name}/cms_mmd_{kernel_name}_{seed}_n_subset-{n_subset}.pkl", "wb") as file:
    pickle.dump((cos_dist, gmmd_dist), file)

In [None]:
# Reload the cached cluster-wise results written by the previous cell.
with open(f"data/{dataset_name}/cms_mmd_{kernel_name}_{seed}_n_subset-{n_subset}.pkl", "rb") as file:
    (cos_dist, gmmd_dist) = pickle.load(file)

# Convert the nested [model][cluster][epoch] lists to arrays for slicing.
gmmd_dist = np.array(gmmd_dist)
cos_dist = np.array(cos_dist)
cos_dist.shape

In [None]:
# Plot one CMS curve per pixel cluster for every model configuration.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/cluster_wise", exist_ok=True)

plt.rcParams.update({'font.size': 18})
for model_cos_dist, model_param in zip(cos_dist, model_params):
    # These names become notebook globals; later cells interpolate them into file names.
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"cos dist for all clusters {model_param}")
    # Enumerate from 1 so that i matches the keys of `colors`.
    for i, cos_dist_i in enumerate(model_cos_dist, 1):
        plt.plot(cos_dist_i, label=f"Cluster {i}", color=colors[i], linewidth=5)
    plt.xlim(1, cluster_max_epochs)
    # Scale the y-axis to the plotted values from index 1 onwards (index 0 is excluded).
    plt.ylim(np.min(model_cos_dist[:, 1:cluster_max_epochs + 1]), np.max(model_cos_dist[:, 1:cluster_max_epochs + 1]))
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.xlabel("Epochs")
    plt.ylabel("CMS")
    #plt.xticks(range(2, cluster_max_epochs+1, 4))
    # plt.title("gmmd dist for all clusters")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/cluster_wise/Cluster cos of {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Image-wise metrics: compare whole flattened images instead of pixel clusters.
cpus = 1
kernel = rbf_kernel_
# Flatten every image to a feature vector: (models, epochs, n, m, features).
flatted_tests = tests.reshape(len(model_params), len(epochs), n, m, -1)
image_wise_mmd = np.array([metrics.multiprocessed_gmmd(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos = np.array([metrics.multiprocessed_cos(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
max_epochs = 50

# Cache the image-wise results.
with open(f"data/{dataset_name}/image_wise_cms_mmd_{kernel_name}_{seed}_n_subset-{n_subset}.pkl", "wb") as file:
    pickle.dump((image_wise_cos, image_wise_mmd), file)

In [None]:
# Reload the cached image-wise results.
with open(f"data/{dataset_name}/image_wise_cms_mmd_{kernel_name}_{seed}_n_subset-{n_subset}.pkl", "rb") as file:
    (image_wise_cos, image_wise_mmd) = pickle.load(file)

# Product of the cluster-wise CMS values over the cluster axis (axis 1).
image_wise_cos_prod = cos_dist.prod(axis=1)
image_wise_cos_prod.shape

In [None]:
# Cluster counts for which the product-of-cluster CMS is evaluated below.
n_clusters_to_eval = [2, 5, 50]

In [None]:
def eval_clusters(num_clusters):
    """
    Re-cluster the pixels into `num_clusters` groups and return the product of
    the per-cluster CMS values.

    Reads the notebook globals `data_corr`, `image_size`, `data`, `tests`,
    `cluster_max_epochs`, `n`, `m` and `kernel`.

    Args:
        num_clusters: Maximum number of flat clusters (`maxclust` criterion).
            If None, the tree is cut at half of the largest pairwise distance.

    Returns:
        The array built from the [model][cluster][epoch] scores of
        `metrics.cos`, multiplied over the cluster axis (axis 1).
    """
    n_pixels = image_size*image_size
    corr_rows = data_corr.reshape(n_pixels, n_pixels)
    condensed = spc.distance.pdist(corr_rows)
    tree = spc.linkage(condensed, method='average')
    if num_clusters is not None:
        flat_labels = spc.fcluster(tree, criterion='maxclust', t=num_clusters)
    else:
        flat_labels = spc.fcluster(tree, 0.5 * condensed.max(), criterion='distance')
    label_map = np.array(flat_labels).reshape(image_size, image_size)
    label_ids = list(set(label_map.flatten()))

    def gather(images, lead_shape, label):
        # Collect the pixels carrying `label` and flatten them (with the
        # channel axis) into a single trailing feature axis.
        picked = np.array([
            images[:, :, :, x, y]
            for x in range(image_size)
            for y in range(image_size)
            if label_map[x, y] == label
        ])
        return picked.transpose(1, 2, 0, 3).reshape(*lead_shape, -1)

    reference_groups = [gather(data, data.shape[:2], label) for label in label_ids]
    scores = [[[] for _ in label_ids] for _ in tests]
    for model_idx, model_tests in enumerate(tests):
        for test_epoch in tqdm(model_tests[:cluster_max_epochs], total=cluster_max_epochs):
            generated_groups = [gather(test_epoch, (n, m), label) for label in label_ids]
            for group_idx, (generated_cluster, test_cluster) in enumerate(zip(generated_groups, reference_groups)):
                scores[model_idx][group_idx].append(metrics.cos(generated_cluster, test_cluster, kernel=kernel))
    return np.array(scores).prod(axis=1)

# Evaluate the product-of-cluster CMS once per requested cluster count.
# Note: the loop variable rebinds the notebook-level `num_clusters`.
image_wise_cos_cluster_prod = []
for num_clusters in tqdm(n_clusters_to_eval, total=len(n_clusters_to_eval)):
    results = eval_clusters(num_clusters)
    image_wise_cos_cluster_prod.append(results)

In [None]:
# Cache the per-cluster-count results of eval_clusters.
with open(f"data/{dataset_name}/image_wise_cms_cluster_prod_{kernel_name}_{seed}_n_subset-{n_subset}.pkl", "wb") as file:
    pickle.dump(image_wise_cos_cluster_prod, file)

In [None]:
# Reload the cached per-cluster-count results.
with open(f"data/{dataset_name}/image_wise_cms_cluster_prod_{kernel_name}_{seed}_n_subset-{n_subset}.pkl", "rb") as file:
    image_wise_cos_cluster_prod = pickle.load(file)

image_wise_cos_cluster_prod[0].shape

In [None]:
# Plot the image-wise CMS together with the product-of-cluster CMS for each
# evaluated cluster count.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/image_wise", exist_ok=True)

max_epochs = 50
linewidth=5

# NOTE(review): `image_wise_mmd_i` and `_s` are not used inside this loop.
for image_wise_mmd_i, image_wise_cos_i, model_param in zip(image_wise_mmd, image_wise_cos, model_params):
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"dist image wise {model_param}")
    s = slice(0, max_epochs+1)
    _s = slice(1, max_epochs+1)
    plt.plot(image_wise_cos_i[s], label="CMS (Higher is better)", linewidth=linewidth)
    for j, image_wise_cos_prod_i in zip(n_clusters_to_eval, image_wise_cos_cluster_prod):
        # Index 0 selects the first model of each eval_clusters result.
        plt.plot(image_wise_cos_prod_i[0][s], label=f"Product CMS for {j} clusters", linewidth=linewidth)
    plt.legend(loc='upper left', bbox_to_anchor=(0, 1.5))
    plt.xlim((1, max_epochs))
    plt.xlabel("Epochs")
    plt.ylabel("Value")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/image_wise/multi cluster CMS {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Save a grid of generated samples for every fourth epoch from 1 to 49.
os.makedirs(f"plots/celeba 64x64/gen_samples", exist_ok=True)
look_ensambles = 20
look_samples = 20
for epoch in epochs[slice(1, 50, 4)]:
    print((f"{epoch=}"))
    # Flatten the two leading sample axes into one batch and apply (x + 1) / 2
    # (presumably mapping [-1, 1] images to [0, 1] -- confirm the generator's output range).
    images = (torch.tensor(tests[0, epoch, :look_ensambles, :look_samples].reshape((look_ensambles - 0)*look_samples, nc, image_size, image_size)) + 1)/2
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(np.transpose(vutils.make_grid(images, nrow=look_samples), (1, 2, 0)))
    plt.axis("off")
    plt.savefig(f"plots/celeba 64x64/gen_samples/{dataset_name} epoch{epoch} {model_params[0]}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Load precomputed reference scores; judging by the names these are Inception
# score (mean/std), FID and KID (mean/std) -- produced outside this notebook.
with open('data/singles.pkl', 'rb') as f:
     isc_mean_single, isc_std_single, fid_single, kid_mean_single, kid_std_single = pickle.load(f)

In [None]:
# Plot ISC, KID, FID and MMD over the epochs in one figure with three y-axes.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/image_wise", exist_ok=True)

plt.rcParams.update({'font.size': 18})
linewidth = 5

host = host_subplot(111, axes_class=AA.Axes)
# Leave room on the right for the additional y-axes.
plt.subplots_adjust(right=0.75)

# Two extra y-axes sharing the host's x-axis.
par1 = host.twinx()
par2 = host.twinx()

# Shift the second extra axis outwards so it does not overlap the first.
offset = 75
new_fixed_axis = par2.get_grid_helper().new_fixed_axis
par2.axis["right"] = new_fixed_axis(loc="right", axes=par2,
                                        offset=(offset, 0))
par1.axis["right"].toggle(all=True)
par2.axis["right"].toggle(all=True)

host.set_xlabel("Epochs")
host.set_ylabel("ISC")
par1.set_ylabel("KID / MMD")
par2.set_ylabel("FID")

# Epochs 1..50; entry 0 of every series is skipped via [1:].
x_axis = np.arange(1, 51)
p1, = host.plot(x_axis, isc_mean_single.mean(0)[1:], label="ISC (Higher is better)", linewidth=linewidth)
p2, = par1.plot(x_axis, kid_mean_single.mean(0)[1:], label="KID (Lower is better)", linewidth=linewidth)
p3, = par2.plot(x_axis, fid_single.mean(0)[1:], label="FID (Lower is better)", linewidth=linewidth)
p4, = par1.plot(x_axis, image_wise_mmd[0][1:], label="MMD (Lower is better)", linewidth=linewidth)

host.legend(loc='upper left', bbox_to_anchor=(0., 1.5))

plt.draw()
plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/image_wise/alternative errors.pdf", bbox_inches='tight')
plt.show();

In [None]:
##############################
##### ChestMNIST #############
##############################

In [None]:
##############################
##### DCGAN ##################
##############################

In [None]:
# --- Configuration for the ChestMNIST analysis (rebinds the CelebA settings) ---
seed = set_random_seed(6746)
# n and m are the two leading axes that generated samples are reshaped to below.
n = 20
m = n
ngpu = 1
n_subset = 2000
batch_size_corr = 100  # images per pixel-wise CKA batch
batch_size_gmmd = 150  # 100
workers_gmmd = 22
nz = 100
nc = 1                 # image channels
image_size = 28
model_name = "DCGAN28x28"
dataset_name = "ChestMNIST"
max_epochs = 50
epochs = list(range(0, max_epochs + 1))  # [1, 25, 50]
corr_epoch = 40
cpus = 1 # mp.cpu_count()
max_images_per_epoch = 0  # set to 0 for max
num_clusters = 5  # None
# Use the first GPU when available and requested, otherwise the CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [None]:
plt.rcParams.update({'font.size': 14})

training_dataset = create_dataset(dataset_name, image_size, subset=n_subset, load_dataset_in_memory=True)
training_data = training_dataset.data.numpy()
print(training_data.shape)

# Bandwidth heuristic: inverse median of the squared Euclidean distances
# between the flattened training images.
gamma = 1/np.median(pairwise.euclidean_distances(training_data.reshape(training_data.shape[0], -1), squared=True))

# Redefines the helper of the CelebA section; it reads the global `gamma`
# at call time, i.e. the ChestMNIST bandwidth computed just above.
def rbf_kernel_(x,y):
    return pairwise.rbf_kernel(x,y,gamma=gamma)

kernel = rbf_kernel_
kernel_name = "rbf_kernel"


In [None]:
# Compute the pixel-wise CKA matrix per training batch and cache each to disk.
os.makedirs(f"data/{dataset_name}/corrs", exist_ok=True)

cpus = 50
agg = []  # NOTE(review): not referenced again in the visible notebook
for i in range(n_subset // batch_size_corr):
    batch = training_data[i*batch_size_corr:(i+1)*batch_size_corr]
    data_corr_tri = multiprocessed_pixel_wise_CKA(image_size, batch, cpus=cpus, kernel=kernel)
    data_corr_tri = data_corr_tri.reshape(image_size**2, image_size**2)
    ### make triangular matrix symmetric
    # Only the strict upper triangle was computed; mirror it into the lower one.
    data_corr = data_corr_tri.copy()
    data_corr[np.tril_indices_from(data_corr_tri, k=-1)] = data_corr_tri.T[np.tril_indices_from(data_corr_tri, k=-1)]
    data_corr = data_corr.reshape(image_size, image_size, image_size, image_size)

    with open(f"data/{dataset_name}/corrs/CKA {kernel_name} {seed} n_subset-{n_subset} batch-{i}.pkl", "wb") as file:
        pickle.dump(data_corr, file)


In [None]:
# Reload the per-batch CKA matrices and average them element-wise.
n_batches = n_subset // batch_size_corr
data_corr = []
for i in tqdm(range(n_batches), total=n_batches):
    with open(f"data/{dataset_name}/corrs/CKA {kernel_name} {seed} n_subset-{n_subset} batch-{i}.pkl", "rb") as file:
        data_corr.append(pickle.load(file))
# Stack along a new leading batch axis, then average that axis away.
data_corr = np.array(data_corr).mean(axis=0)

In [None]:
# Cluster the ChestMNIST pixels by their CKA profiles and plot the result.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/corr_pixel_wise_cluster", exist_ok=True)

cmap = cmap=plt.cm.Reds
# Overrides the value set in the ChestMNIST configuration cell.
num_clusters = 6
# One colour per cluster label (fcluster labels start at 1).
colors = {
    1: '#003f5c',
    2: '#fb8500',
    3: '#e40b0b',
    4: '#219ebc',
    5: '#8ab17d',
    6: '#d08c60',
}
# Each row of the flattened matrix is one pixel's CKA profile against all pixels.
big_corr_matrix = data_corr.reshape(image_size*image_size, image_size*image_size)
pdist = spc.distance.pdist(big_corr_matrix)
linkage = spc.linkage(pdist, method='average')
if num_clusters is None:
    # No target count: cut the tree at half of the largest pairwise distance.
    idx = spc.fcluster(linkage, 0.5 * pdist.max(), criterion='distance')
else:
    idx = spc.fcluster(linkage, criterion='maxclust', t=num_clusters)
# Cluster label of every pixel, laid out on the image grid.
cluster = np.array(idx).reshape(image_size, image_size)
plt.imshow(cluster, cmap=plt.cm.colors.ListedColormap(colors.values()))
handles = [plt.Line2D([0, 1], [0, 1], color=colors[label], linewidth=3) for label in range(1, num_clusters + 1)]
labels = [f"Cluster {label}" for label in range(1, num_clusters + 1)]
plt.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1))

# NOTE(review): batch_size, ngf, ndf, lr and beta1 are not assigned in this
# section before this cell; they are whatever an earlier plotting loop left
# behind. The file names therefore depend on hidden state, and the cell raises
# NameError if run without those earlier cells. Confirm the intended naming.
plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/corr_pixel_wise_cluster/same {model_name} models Clustering {batch_size=} {ngf=} {ndf=} {lr=} {beta1=} {kernel_name=}.pdf", bbox_inches='tight')
plt.show()
# Reorder rows and columns so that pixels of the same cluster are adjacent.
order = np.argsort(idx)
reordered_corr_matrix = big_corr_matrix[order][:, order]
print(reordered_corr_matrix.shape)
plt.imshow(reordered_corr_matrix, cmap=cmap, vmin=0, vmax=1)
plt.colorbar()
plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/corr_pixel_wise_cluster/same {model_name} models Clustered Correlation Matrix {batch_size=} {ngf=} {ndf=} {lr=} {beta1=} {kernel_name=}.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Load the test split, keep its first sixteenth and arrange it into groups of
# batch_size_gmmd images (dropping the remainder).
dataset = create_dataset(dataset_name, image_size, load_dataset_in_memory=True, type="test")
data = dataset.data[:len(dataset)//16]
data = np.array(data[:(len(data)//batch_size_gmmd) * batch_size_gmmd].reshape(-1, batch_size_gmmd, nc, image_size, image_size))
print(data.shape)

# For every cluster label: gather the test-data pixels of that cluster and
# flatten them (with the channel axis) into one trailing feature axis.
test_data_pixel_cluster = [np.array([data[:, :, :, x, y] for x in range(image_size) for y in range(image_size) if cluster[x, y] == c]).transpose(1, 2, 0, 3).reshape(*data.shape[:2], -1) for c in set(cluster.flatten())]

In [None]:
# Select the GAN configuration(s) whose get() tuple has 64 at positions 1 and 2
# and 0.0008 at position 3 (unpacked further below as ngf, ndf and lr).
model_params = GetGANModelParameters()
model_params = [
    model_param for model_param in model_params
    if model_param.get()[1]==64 and model_param.get()[2]==64 and model_param.get()[3]==0.0008
]
# Load the pickled generated samples of the first matching configuration.
filename = f'data/samples DCGAN28x28 0 {model_params[0]}.pkl'
with open(filename, 'rb') as file:
    exp_results = pickle.load(file)
# Add a leading "model" axis of length 1 so later loops can iterate over models.
tests = np.array([exp_results])
tests.shape

In [None]:
# Save a grid of generated samples for epochs 1 to 5.
# NOTE(review): the directory is hard-coded as "celeba 28x28" although
# dataset_name is "ChestMNIST" in this section -- looks like a copy/paste
# leftover; confirm before changing, since a later cell saves into it too.
os.makedirs(f"plots/celeba 28x28/gen_samples", exist_ok=True)

look_ensambles = 20
look_samples = 20
for epoch in epochs[slice(1, 6, 1)]:
    print((f"{epoch=}"))
    # Flatten the two leading sample axes into one batch and apply (x + 1) / 2.
    images = (torch.tensor(tests[0, epoch, :look_ensambles, :look_samples].reshape((look_ensambles - 0)*look_samples, nc, image_size, image_size)) + 1)/2
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(np.transpose(vutils.make_grid(images, nrow=look_samples), (1, 2, 0)))
    plt.axis("off")
    plt.savefig(f"plots/celeba 28x28/gen_samples/{dataset_name} epoch{epoch} {model_params[0]}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Cluster-wise metrics for the first 20 epoch entries of the loaded samples.
cluster_max_epochs = 20
kernel = rbf_kernel_
# Nested result lists indexed as [model][cluster] -> list of per-epoch values.
gmmd_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
cos_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
for j, model_tests in enumerate(tests):
    for test_epoch in tqdm(model_tests[:cluster_max_epochs], total=cluster_max_epochs):
        # Gather this epoch's generated pixels per cluster, as done for the test data.
        generated_data_pixel_cluster = [np.array([test_epoch[:, :, :, x, y] for x in range(image_size) for y in range(image_size) if cluster[x, y] == c]).transpose(1, 2, 0, 3).reshape(n, m, -1) for c in set(cluster.flatten())]
        for i, (generated_cluster, test_cluster) in enumerate(zip(generated_data_pixel_cluster, test_data_pixel_cluster)):
            gmmd_dist[j][i].append(metrics.gmmd(generated_cluster, test_cluster, kernel=kernel))
            cos_dist[j][i].append(metrics.cos(generated_cluster, test_cluster, kernel=kernel))

In [None]:
# Plot one CMS curve per pixel cluster; only x-values 1..6 are shown.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/cluster_wise", exist_ok=True)

plt.rcParams.update({'font.size': 18})
for model_cos_dist, model_param in zip(cos_dist, model_params):
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"cos dist for all clusters {model_param}")
    # Enumerate from 1 so that i matches the keys of `colors`.
    for i, cos_dist_i in enumerate(model_cos_dist, 1):
        # x-values start at 1, i.e. list entry k is drawn at x = k + 1.
        plt.plot(np.arange(1, len(cos_dist_i)+1), cos_dist_i, label=f"Cluster {i}", color=colors[i], linewidth=5)
    plt.xlim(1, 6)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.xlabel("Epochs")
    plt.ylabel("CMS")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/cluster_wise/Cluster cos of {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Image-wise metrics on whole flattened images.
cpus = 1
# Rebinds `epochs` (previously a list) to an array of 0..50.
epochs = np.arange(0, 51)
# Flatten every image to a feature vector: (models, epochs, n, m, features).
flatted_tests = tests.reshape(len(model_params), len(epochs), n, m, -1)
image_wise_mmd = np.array([metrics.multiprocessed_gmmd(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos = np.array([metrics.multiprocessed_cos(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos.shape

In [None]:
# Product of the cluster-wise CMS values over the cluster axis (axis 1),
# wrapped in a one-element list to match the layout the plotting cell expects.
image_wise_cos_cluster_prod = [np.array(cos_dist).prod(axis=1)]
image_wise_cos_cluster_prod[0].shape

In [None]:
# Plot the image-wise CMS next to the product-of-cluster CMS.
os.makedirs(f"plots/{dataset_name} {image_size}x{image_size}/image_wise", exist_ok=True)

max_epochs = 6
linewidth=5
# Take the cluster count for the legend from the data actually plotted
# (cos_dist is indexed [model][cluster]) rather than hard-coding it, so the
# label cannot drift from the clustering that produced cos_dist.
n_clusters_to_eval = [len(cos_dist[0])]

for image_wise_mmd_i, image_wise_cos_i, model_param in zip(image_wise_mmd, image_wise_cos, model_params):
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"dist image wise {model_param}")
    s = slice(0, max_epochs)
    _s = slice(1, max_epochs+1)
    # Entry k of each series is drawn at x = k + 1.
    plt.plot(np.arange(1, max_epochs+1), image_wise_cos_i[s], label="CMS (Higher is better)", linewidth=linewidth)
    for j, image_wise_cos_prod_i in zip(n_clusters_to_eval, image_wise_cos_cluster_prod):
        # Index 0 selects the first model.
        plt.plot(np.arange(1, max_epochs+1), image_wise_cos_prod_i[0][s], label=f"Product CMS for {j} clusters", linewidth=linewidth)
    plt.legend(loc='upper left', bbox_to_anchor=(0, 1.3))
    plt.xlim((1, max_epochs))
    plt.xlabel("Epochs")
    plt.ylabel("Value")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/image_wise/multi cluster CMS {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Select the GAN configuration(s) whose get() tuple has 64 at positions 1 and 2
# and 1e-05 at position 3 (unpacked further below as ngf, ndf and lr).
model_params = GetGANModelParameters()
model_params = [
    model_param for model_param in model_params
    if model_param.get()[1]==64 and model_param.get()[2]==64 and model_param.get()[3]==1e-05
]
# Load the pickled generated samples of the first matching configuration.
filename = f'data/samples DCGAN28x28 0 {model_params[0]}.pkl'
with open(filename, 'rb') as file:
    exp_results = pickle.load(file)
# Add a leading "model" axis of length 1 so later loops can iterate over models.
tests = np.array([exp_results])
tests.shape

In [None]:
# Save a grid of generated samples for epochs 1 to 5.
# NOTE(review): the directory is hard-coded as "celeba 28x28" although
# dataset_name is "ChestMNIST" in this section -- looks like a copy/paste
# leftover; confirm before changing, since a later cell saves into it too.
os.makedirs(f"plots/celeba 28x28/gen_samples", exist_ok=True)

look_ensambles = 20
look_samples = 20
for epoch in epochs[slice(1, 6, 1)]:
    print((f"{epoch=}"))
    # Flatten the two leading sample axes into one batch and apply (x + 1) / 2.
    images = (torch.tensor(tests[0, epoch, :look_ensambles, :look_samples].reshape((look_ensambles - 0)*look_samples, nc, image_size, image_size)) + 1)/2
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(np.transpose(vutils.make_grid(images, nrow=look_samples), (1, 2, 0)))
    plt.axis("off")
    plt.savefig(f"plots/celeba 28x28/gen_samples/{dataset_name} epoch{epoch} {model_params[0]}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Cluster-wise metrics for the first 20 epoch entries of the loaded samples.
cluster_max_epochs = 20
kernel = rbf_kernel_
# Nested result lists indexed as [model][cluster] -> list of per-epoch values.
gmmd_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
cos_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
for j, model_tests in enumerate(tests):
    for test_epoch in tqdm(model_tests[:cluster_max_epochs], total=cluster_max_epochs):
        # Gather this epoch's generated pixels per cluster, as done for the test data.
        generated_data_pixel_cluster = [np.array([test_epoch[:, :, :, x, y] for x in range(image_size) for y in range(image_size) if cluster[x, y] == c]).transpose(1, 2, 0, 3).reshape(n, m, -1) for c in set(cluster.flatten())]
        for i, (generated_cluster, test_cluster) in enumerate(zip(generated_data_pixel_cluster, test_data_pixel_cluster)):
            gmmd_dist[j][i].append(metrics.gmmd(generated_cluster, test_cluster, kernel=kernel))
            cos_dist[j][i].append(metrics.cos(generated_cluster, test_cluster, kernel=kernel))

In [None]:
# Plot one CMS curve per pixel cluster over x-values 1..20.
# This cell writes into the cluster_wise plot directory without creating it.
plt.rcParams.update({'font.size': 18})
for model_cos_dist, model_param in zip(cos_dist, model_params):
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"cos dist for all clusters {model_param}")
    # Enumerate from 1 so that i matches the keys of `colors`.
    for i, cos_dist_i in enumerate(model_cos_dist, 1):
        # x-values start at 1, i.e. list entry k is drawn at x = k + 1.
        plt.plot(np.arange(1, len(cos_dist_i)+1), cos_dist_i, label=f"Cluster {i}", color=colors[i], linewidth=5)
    plt.xlim(1, 20)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.xlabel("Epochs")
    plt.ylabel("CMS")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/cluster_wise/Cluster cos of {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Image-wise metrics on whole flattened images.
cpus = 1
# Flatten every image to a feature vector: (models, epochs, n, m, features).
flatted_tests = tests.reshape(len(model_params), len(epochs), n, m, -1)
image_wise_mmd = np.array([metrics.multiprocessed_gmmd(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos = np.array([metrics.multiprocessed_cos(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos.shape

In [None]:
# Product of the cluster-wise CMS values over the cluster axis (axis 1),
# wrapped in a one-element list to match the layout the plotting cell expects.
image_wise_cos_cluster_prod = [np.array(cos_dist).prod(axis=1)]
image_wise_cos_cluster_prod[0].shape

In [None]:
# Plot the image-wise CMS next to the product-of-cluster CMS.
max_epochs = 19
linewidth=5
# Take the cluster count for the legend from the data actually plotted
# (cos_dist is indexed [model][cluster]) rather than hard-coding it.
n_clusters_to_eval = [len(cos_dist[0])]

for image_wise_mmd_i, image_wise_cos_i, model_param in zip(image_wise_mmd, image_wise_cos, model_params):
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"dist image wise {model_param}")
    s = slice(0, max_epochs+1)
    _s = slice(1, max_epochs+1)
    # Build the x-axis from the series being plotted instead of from a loop
    # variable left over from an earlier cell; entry k is drawn at x = k + 1.
    cms_values = image_wise_cos_i[s]
    plt.plot(np.arange(1, len(cms_values)+1), cms_values, label="CMS (Higher is better)", linewidth=linewidth)
    for j, image_wise_cos_prod_i in zip(n_clusters_to_eval, image_wise_cos_cluster_prod):
        # Index 0 selects the first model.
        prod_values = image_wise_cos_prod_i[0][s]
        plt.plot(np.arange(1, len(prod_values)+1), prod_values, label=f"Product CMS for {j} clusters", linewidth=linewidth)
    plt.legend(loc='upper left', bbox_to_anchor=(0, 1.3))
    plt.xlim((1, max_epochs+1))
    plt.xlabel("Epochs")
    plt.ylabel("Value")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/image_wise/multi cluster CMS {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
####################################
######### DDPM #####################
####################################

In [None]:
# List the available diffusion-model configurations, then keep the first one.
model_params = GetDiffusionModelParameters()
for i, model_param in enumerate(model_params):
    print(i, model_param)
    
# NOTE(review): assumes model_params can be iterated a second time after the
# loop above (i.e. it is not a one-shot iterator) -- confirm.
model_params = [list(model_params)[0]]
# Load the pickled generated samples of that configuration.
filename = f'data/samples Diffusion28x28 0 {model_params[0]}.pkl'
with open(filename, 'rb') as file:
    exp_results = pickle.load(file)
# Add a leading "model" axis of length 1 so later loops can iterate over models.
tests = np.array([exp_results])
tests.shape

In [None]:
# Save a grid of generated samples for every fifth epoch from 0 to 25.
# Create the target directory here so the cell does not depend on an earlier
# cell having created it.
# NOTE(review): the directory is hard-coded as "celeba 28x28" although
# dataset_name is "ChestMNIST" in this section -- confirm the intended path.
os.makedirs("plots/celeba 28x28/gen_samples", exist_ok=True)

look_ensambles = 20
look_samples = 20
for epoch in epochs[slice(0, 30, 5)]:
    print((f"{epoch=}"))
    # Flatten the two leading sample axes into one batch and apply (x + 1) / 2.
    images = (torch.tensor(tests[0, epoch, :look_ensambles, :look_samples].reshape((look_ensambles - 0)*look_samples, nc, image_size, image_size)) + 1)/2
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(np.transpose(vutils.make_grid(images, nrow=look_samples), (1, 2, 0)))
    plt.axis("off")
    plt.savefig(f"plots/celeba 28x28/gen_samples/{dataset_name} epoch{epoch} {model_params[0]}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
#model_id = 0
# Cluster-wise metrics for the first 30 epoch entries of the diffusion samples.
cluster_max_epochs = 30
kernel = rbf_kernel_
# Nested result lists indexed as [model][cluster] -> list of per-epoch values.
gmmd_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
cos_dist = [[[] for _ in set(cluster.flatten())] for _ in tests]
for j, model_tests in enumerate(tests):
    for test_epoch in tqdm(model_tests[:cluster_max_epochs], total=cluster_max_epochs):
        # Gather this epoch's generated pixels per cluster, as done for the test data.
        generated_data_pixel_cluster = [np.array([test_epoch[:, :, :, x, y] for x in range(image_size) for y in range(image_size) if cluster[x, y] == c]).transpose(1, 2, 0, 3).reshape(n, m, -1) for c in set(cluster.flatten())]
        for i, (generated_cluster, test_cluster) in enumerate(zip(generated_data_pixel_cluster, test_data_pixel_cluster)):
            gmmd_dist[j][i].append(metrics.gmmd(generated_cluster, test_cluster, kernel=kernel))
            cos_dist[j][i].append(metrics.cos(generated_cluster, test_cluster, kernel=kernel))

In [None]:
# Plot one CMS curve per pixel cluster; only x-values 1..10 are shown.
plt.rcParams.update({'font.size': 18})
for model_cos_dist, model_param in zip(cos_dist, model_params):
    # NOTE(review): unpacks five values under GAN-style names from a
    # diffusion-model parameter object -- confirm that get() returns a 5-tuple here.
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"cos dist for all clusters {model_param}")
    # Enumerate from 1 so that i matches the keys of `colors`.
    for i, cos_dist_i in enumerate(model_cos_dist, 1):
        # x-values start at 1, i.e. list entry k is drawn at x = k + 1.
        plt.plot(np.arange(1, len(cos_dist_i)+1), cos_dist_i, label=f"Cluster {i}", color=colors[i], linewidth=5)
    plt.xlim(1, 10)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.xlabel("Epochs")
    plt.ylabel("CMS")
    # NOTE(review): model_name is still "DCGAN28x28" at this point, so the file
    # name carries the DCGAN label for the diffusion model.
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/cluster_wise/Cluster cos of {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()

In [None]:
# Image-wise metrics on whole flattened images.
cpus = 1
# Rebinds `epochs` to 0..30 for the reshape of the diffusion samples below.
epochs = np.arange(0, 31)
# Flatten every image to a feature vector: (models, epochs, n, m, features).
flatted_tests = tests.reshape(len(model_params), len(epochs), n, m, -1)
image_wise_mmd = np.array([metrics.multiprocessed_gmmd(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos = np.array([metrics.multiprocessed_cos(flatted_tests[i], data.reshape(data.shape[0], data.shape[1], -1), kernel=kernel, cpus=cpus) for i in range(len(model_params))])
image_wise_cos.shape

In [None]:
# Product of the cluster-wise CMS values over the cluster axis (axis 1),
# wrapped in a one-element list to match the layout the plotting cell expects.
image_wise_cos_cluster_prod = [np.array(cos_dist).prod(axis=1)]
image_wise_cos_cluster_prod[0].shape

In [None]:
# Plot the image-wise CMS next to the product-of-cluster CMS.
max_epochs = 10
linewidth=5
# Take the cluster count for the legend from the data actually plotted
# (cos_dist is indexed [model][cluster]) rather than hard-coding it.
n_clusters_to_eval = [len(cos_dist[0])]

for image_wise_mmd_i, image_wise_cos_i, model_param in zip(image_wise_mmd, image_wise_cos, model_params):
    batch_size, ngf, ndf, lr, beta1 = model_param.get()
    print(f"dist image wise {model_param}")
    s = slice(0, max_epochs)
    _s = slice(1, max_epochs+1)
    # Entry k of each series is drawn at x = k + 1.
    plt.plot(np.arange(1, max_epochs+1), image_wise_cos_i[s], label="CMS (Higher is better)", linewidth=linewidth)
    for j, image_wise_cos_prod_i in zip(n_clusters_to_eval, image_wise_cos_cluster_prod):
        # Index 0 selects the first model.
        plt.plot(np.arange(1, max_epochs+1), image_wise_cos_prod_i[0][s], label=f"Product CMS for {j} clusters", linewidth=linewidth)
    plt.legend(loc='upper left', bbox_to_anchor=(0, 1.3))
    plt.xlim((1, max_epochs+1))
    plt.xlabel("Epochs")
    plt.ylabel("Value")
    plt.savefig(f"plots/{dataset_name} {image_size}x{image_size}/image_wise/multi cluster CMS {model_name} {max_images_per_epoch} {model_param} {kernel_name}.pdf", bbox_inches='tight')
    plt.show()