In [None]:
from diffusers import UNet2DModel
import PIL.Image
import numpy as np
import torch

import tqdm
from diffusers import StableDiffusionPipeline, DiffusionPipeline
# from diffusers.utils.remote_utils import remote_decode
from diffusers import AutoencoderKL

import os
# The notebook sits one level below the project root; move there so the local
# `src` package (and the relative `data/` paths used below) resolve. Guarded so
# re-running this cell does not climb one more directory each time.
if not os.path.isdir('src'):
    os.chdir('..')
from src.utils import gen_img, plot_images, load_vaes, compare_all_vaes_pandas
from src.custom_vae import download_custom_vae, load_custom_vae

import seaborn as sns
import matplotlib.pyplot as plt
import zipfile
from torch import nn

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Load the pretrained "stabilityai/sd-vae-ft-mse" autoencoder and place it on
# the device selected above.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse"
).to(device)

In [None]:
# Download and unpack the dog image dataset unless it is already present.
# Plain Python (subprocess + zipfile) is used instead of shell magics so the cell
# behaves the same on Linux and Windows and stops with an error if gdown fails.
import subprocess

DATA_DIR = "data"
ZIP_PATH = os.path.join(DATA_DIR, "all-dogs.zip")
# Hyphenated to match the archive name; assumes the zip unpacks to a top-level
# `all-dogs/` folder — TODO confirm against the archive contents.
DOGS_DIR = os.path.join(DATA_DIR, "all-dogs")
GDRIVE_FILE_ID = "1KXRTB_q4uub_XOHecpsQjE4Kmv76sZbV"

is_data_downloaded = os.path.isdir(DOGS_DIR)

if not is_data_downloaded:
    # Make sure the target directory exists before downloading into it.
    os.makedirs(DATA_DIR, exist_ok=True)

    # Skip the download when a valid archive is already on disk; is_zipfile is
    # False for a missing or truncated file, so those are fetched again.
    if not zipfile.is_zipfile(ZIP_PATH):
        print('Downloading data...')
        subprocess.run(["gdown", GDRIVE_FILE_ID, "-O", ZIP_PATH], check=True)

    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(DATA_DIR)

In [None]:
# Generate the probe images used for the VAE comparison, decoded with `vae`,
# one per seed.
# NOTE(review): only a single image is produced; if the CKA comparison treats
# images as samples, n=1 may be degenerate — confirm and raise the count.
imgs = []
for seed in range(1):
    imgs.append(gen_img(prompt="a photo of a huge dog", vae=vae, seed=seed))
plot_images(*imgs)

In [None]:
# Load the set of VAEs to compare. Presumably a mapping of name -> model (a
# 'custom_vae' key is assigned into it further down) — TODO confirm in src.utils.
vaes = load_vaes()

In [None]:
# Fetch (per the helper's name) and load the project's custom VAE.
# NOTE(review): `custom_vae` is not referenced anywhere below — the 'custom_vae'
# entry added to `vaes` later wraps the stock `vae` instead. Confirm intent.
download_custom_vae()
custom_vae = load_custom_vae()

In [None]:
class VAEDog(nn.Module):
    """Autoencoder assembled from a separately supplied encoder and decoder.

    Args:
        encoder: callable applied by `encode`; registered as a submodule when it
            is an `nn.Module`, so `.to()` / `.parameters()` reach it.
        decoder: callable applied by `decode`; registered the same way.

    Calling the module (`forward`) returns `decode(encode(x))`.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # Tuple assignment still goes through nn.Module.__setattr__ per target.
        self.encoder, self.decoder = encoder, decoder

    def encode(self, x):
        """Return the raw output of the wrapped encoder for input `x`."""
        return self.encoder(x)

    def decode(self, z):
        """Return the raw output of the wrapped decoder for latent `z`."""
        return self.decoder(z)

    def forward(self, x):
        """Full pass: encode `x`, then decode the resulting latent."""
        return self.decode(self.encode(x))

In [None]:
# Wrap the encoder/decoder of the pretrained `vae` and add it to the comparison.
# Bound to a new name instead of reassigning `vae`: rebinding would make the
# image-generation cell above pass this wrapper (not the AutoencoderKL) to
# `gen_img` if it is re-run.
# NOTE(review): this wraps the stock sd-vae-ft-mse, not the `custom_vae` loaded
# earlier, despite the 'custom_vae' key — confirm which model was intended.
# NOTE(review): AutoencoderKL's own encode/decode also apply quant_conv /
# post_quant_conv around the bare encoder/decoder; calling them directly skips
# that, so VAEDog.forward may not match the original model — verify.
dog_vae = VAEDog(vae.encoder, vae.decoder).to(device)

vaes['custom_vae'] = dog_vae

In [None]:
# Pairwise CKA similarity between every VAE in `vaes`, evaluated on `imgs`.
# Consumed below as a DataFrame (`.astype(float)` fed to seaborn).
# `compare_all_vaes_pandas` comes from the imports at the top of the notebook.
cka_df = compare_all_vaes_pandas(vaes, imgs, batch_size=4)

In [None]:
# Annotated heatmap of the pairwise CKA similarities; also written to disk.
fig, ax = plt.subplots()
sns.heatmap(cka_df.astype(float), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
ax.set_title("CKA Similarity Between VAEs")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
plt.setp(ax.get_yticklabels(), rotation=0)
fig.tight_layout()
fig.savefig("cka_heatmap.png")
plt.show()

In [None]:
# Hierarchically clustered view of the same CKA matrix: rows/columns are
# reordered so VAEs with similar representations sit next to each other.
sns.clustermap(cka_df.astype(float),
               cmap="coolwarm",
               annot=True,
               fmt=".2f",
               linewidths=0.5)
# clustermap draws into its own new figure, which is the current one here.
plt.suptitle("CKA Clustermap of VAE Similarities", y=1.05)
# y=1.05 puts the title above the figure canvas; bbox_inches="tight" enlarges
# the saved image to include it instead of cropping it out of the PNG.
plt.savefig("cka_clustermap.png", bbox_inches="tight")
plt.show()