In [None]:
import copy
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans

# Project-local modules live in ../scripts; extend the path before importing them.
sys.path.append("../scripts")
from data import get_sample_patches_dataset, get_filenames
from model import SmallAutoEncoder, AutoEncoder, DEC

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load data: collect the input file names, build the patch dataset resized to
# 28x28, and wrap it in a loader that yields batches of 256.
filenames = get_filenames()
dataset = get_sample_patches_dataset(
    filenames=filenames,
    resize=(28, 28),
)
# shuffle=False is the DataLoader default; it is spelled out because later cells
# index `dataset` by the row position of each embedding.
dl = torch.utils.data.DataLoader(dataset=dataset, batch_size=256, shuffle=False)

In [None]:
# Build the three models whose outputs are clustered and compared below:
#   'autoencoder' - the trained autoencoder,
#   'resnet'      - a sub-network sliced out of the autoencoder's encoder,
#   'dec'         - the DEC model, which emits per-sample cluster scores.
model = {}

model['autoencoder'] = AutoEncoder(64)
# map_location="cpu": nothing in this notebook moves the models off the CPU, and it
# lets a checkpoint that was saved from a GPU load on a CPU-only machine.
model['autoencoder'].load_state_dict(
    torch.load("../models/AE_small/model.pt", map_location="cpu")
)
model['autoencoder'].eval()

# nn.Module has no `child()` method, and `children()` returns a generator that
# cannot be sliced, so materialise the children and re-wrap the slice as a module.
# NOTE(review): this keeps the original `[-3:]` (last three children of the encoder)
# and assumes they can be chained sequentially - confirm against the AutoEncoder
# definition that this is the intended "resnet" part.
model['resnet'] = nn.Sequential(*list(model['autoencoder'].encoder.children())[-3:])
model['resnet'].eval()

# Give DEC its own copy of the encoder. If DEC registers the encoder as a submodule
# (not visible from here - TODO confirm), loading the DEC checkpoint into a shared
# encoder would overwrite the autoencoder's weights in place, and the 'resnet' and
# 'autoencoder' embeddings would silently be computed with DEC weights.
model['dec'] = DEC(10, 64, copy.deepcopy(model['autoencoder'].encoder))
model['dec'].load_state_dict(
    torch.load("../models/DEC_small/model.pt", map_location="cpu")
)
model['dec'].eval()

In [None]:
# Run every model over the whole dataset and collect one row per sample.
# All three lists must exist before the loop appends to them.
embeddings = {"resnet": [], "autoencoder": [], "dec": []}

# Inference only: no_grad avoids building autograd graphs for every batch.
with torch.no_grad():
    for batch in dl:
        # The plotting cells read images as `dataset[idx][0]`; if samples are tuples,
        # the default collate yields a list whose first entry is the image batch.
        # NOTE(review): confirm the sample structure of the dataset.
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        # Use the actual batch size: the last batch is usually smaller than 256.
        n_samples = images.shape[0]

        embeddings["resnet"].append(
            model["resnet"](images).cpu().numpy().reshape(n_samples, -1)
        )
        # Presumably the autoencoder returns a tuple whose first item is the
        # embedding - TODO confirm against the AutoEncoder.forward definition.
        embeddings["autoencoder"].append(
            model["autoencoder"](images)[0].cpu().numpy().reshape(n_samples, -1)
        )
        # max(1)[1] is the argmax over dim 1, i.e. one cluster index per sample;
        # keep it 1-D so it has the same shape as the KMeans labels below.
        embeddings["dec"].append(
            model["dec"](images).max(1)[1].cpu().numpy().reshape(-1)
        )

embeddings["resnet"] = np.concatenate(embeddings["resnet"])
embeddings["autoencoder"] = np.concatenate(embeddings["autoencoder"])
embeddings["dec"] = np.concatenate(embeddings["dec"])


# Cluster the continuous embeddings with KMeans; DEC already outputs cluster ids.
clusters = {}
# random_state fixes the initialisation so a re-run reproduces the same clusters.
cls = KMeans(10, n_init=20, random_state=0)
# fit_predict returns the label array itself (there is no `.labels_` on an ndarray),
# and every call refits the estimator from scratch.
clusters["resnet"] = cls.fit_predict(embeddings["resnet"])
clusters["autoencoder"] = cls.fit_predict(embeddings["autoencoder"])
clusters["dec"] = embeddings["dec"]

In [None]:
# Show up to 10 random member images for each of the 10 clusters of one model.
name = "resnet"
for i in range(10):
    samples = clusters[name] == i
    # Row indices of the samples assigned to cluster i.
    members = np.where(samples)[0]
    if members.size == 0:
        # np.random.choice raises ValueError on an empty array; skip empty clusters.
        continue
    # Draw without replacement when the cluster is big enough, so no image repeats.
    random_samples = np.random.choice(members, 10, replace=members.size < 10)
    fig, ax = plt.subplots(3, 4, figsize=(10, 8))
    # plt.subplots(3, 4) returns a 3x4 array of Axes, so `ax[j]` would be a whole
    # row; iterate over the flattened grid instead. Hide every panel's axes first
    # so the two unused panels stay blank.
    for panel in ax.flat:
        panel.axis("off")
    for panel, idx in zip(ax.flat, random_samples):
        # Images are stored channel-first; imshow expects channel-last.
        panel.imshow(dataset[int(idx)][0].permute(1, 2, 0))
    fig.suptitle(f"{name}: cluster {i}")
    plt.show()

In [None]:
# Show up to 10 random member images for each of the 10 clusters of one model.
name = "autoencoder"
for i in range(10):
    samples = clusters[name] == i
    # Row indices of the samples assigned to cluster i.
    members = np.where(samples)[0]
    if members.size == 0:
        # np.random.choice raises ValueError on an empty array; skip empty clusters.
        continue
    # Draw without replacement when the cluster is big enough, so no image repeats.
    random_samples = np.random.choice(members, 10, replace=members.size < 10)
    fig, ax = plt.subplots(3, 4, figsize=(10, 8))
    # plt.subplots(3, 4) returns a 3x4 array of Axes, so `ax[j]` would be a whole
    # row; iterate over the flattened grid instead. Hide every panel's axes first
    # so the two unused panels stay blank.
    for panel in ax.flat:
        panel.axis("off")
    for panel, idx in zip(ax.flat, random_samples):
        # Images are stored channel-first; imshow expects channel-last.
        panel.imshow(dataset[int(idx)][0].permute(1, 2, 0))
    fig.suptitle(f"{name}: cluster {i}")
    plt.show()

In [None]:
# Show up to 10 random member images for each of the 10 clusters of one model.
name = "dec"
for i in range(10):
    samples = clusters[name] == i
    # Row indices of the samples assigned to cluster i.
    members = np.where(samples)[0]
    if members.size == 0:
        # np.random.choice raises ValueError on an empty array; skip empty clusters.
        continue
    # Draw without replacement when the cluster is big enough, so no image repeats.
    random_samples = np.random.choice(members, 10, replace=members.size < 10)
    fig, ax = plt.subplots(3, 4, figsize=(10, 8))
    # plt.subplots(3, 4) returns a 3x4 array of Axes, so `ax[j]` would be a whole
    # row; iterate over the flattened grid instead. Hide every panel's axes first
    # so the two unused panels stay blank.
    for panel in ax.flat:
        panel.axis("off")
    for panel, idx in zip(ax.flat, random_samples):
        # Images are stored channel-first; imshow expects channel-last.
        panel.imshow(dataset[int(idx)][0].permute(1, 2, 0))
    fig.suptitle(f"{name}: cluster {i}")
    plt.show()

In [None]:
# load only autoencoder and cluster with that