In [22]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.neighbors import NearestNeighbors
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import crop
from tqdm import tqdm

from autoencoder import Autoencoder, Encoder, Decoder
from get_orlov_datasets import get_orlov_datasets

In [23]:
# Experiment configuration: everything a reader might tune lives here.
SUBIMAGE_SIZE = 40
BATCH_SIZE = 256
NUM_LOADERS_WORKERS = 0
# The default is an absolute, machine-specific path. On another machine set the
# environment variable PRETRAINED_AUTOENCODER_FILE instead of editing this cell.
PRETRAINED_AUTOENCODER_FILE = os.environ.get(
    'PRETRAINED_AUTOENCODER_FILE',
    'C:/_DIPLOMA/code/checkpoints/autoenc4_test/lightning_logs/ldim-256_c_hid-32_lam-50-decoder_extended-3/checkpoints/epoch=14-step=570.ckpt',
)
TRAINED_KMEANS_FILE = './checkpoints/kmeans/v1_KMeans.pkl'
# Number of k-means clusters — TODO confirm it matches the pickled model's n_clusters.
CENTROIDS_COUNT = 48

In [24]:
# One image per batch: the feature-extraction loop reads only batch_images[0].
# subimage_size=None presumably yields whole images, which are tiled into
# subimages later — TODO confirm against get_orlov_datasets.
train_loader, val_loader, test_loader, additional = get_orlov_datasets(
    train_subimages_num=1,
    num_loaders_workers=NUM_LOADERS_WORKERS,
    batch_size=1,
    subimage_size=None,
)

In [25]:
# Unpack the extra values returned by get_orlov_datasets; the *_count values are
# used below as per-split row counts for the feature/label buffers.
(data_train, data_test,
 train_count, val_count, test_count) = additional

In [26]:
autoencoder_model = Autoencoder.load_from_checkpoint(PRETRAINED_AUTOENCODER_FILE)
# Build a stand-alone encoder with the same hyperparameters as the checkpoint
# and copy the pretrained encoder weights into it (strict key matching).
encoder = Encoder(
    num_input_channels=3,
    base_channel_size=32,
    latent_dim=256,
)
encoder.load_state_dict(autoencoder_model.encoder.state_dict())

<All keys matched successfully>

In [27]:
# Load the pre-trained k-means model. The `with` block guarantees the file
# handle is closed once unpickling finishes.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(TRAINED_KMEANS_FILE, 'rb') as kmeans_file:
    kmeans: KMeans = pickle.load(kmeans_file)

In [28]:
# Per-image output buffers filled in place by occs_dataset: one 256-wide feature
# row (the encoder's latent_dim) and one label column per image in each split.
train_xs, val_xs, test_xs = (np.zeros((count, 256)) for count in (train_count, val_count, test_count))
train_ys, val_ys, test_ys = (np.zeros((count, 1)) for count in (train_count, val_count, test_count))

In [29]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [30]:
# Move the encoder to the selected device and switch it to inference mode.
# A freshly constructed nn.Module starts in training mode. Any dropout or
# batch-norm layers (whether Encoder has them is not visible here) would then act
# stochastically or update running statistics during feature extraction.
encoder = encoder.to(device).eval()

In [31]:
def occs_dataset(loader, data, xs, ys, model=None, target_device=None):
    """Fill ``xs``/``ys`` in place with one pooled embedding and one label per image.

    Each image from ``loader`` is tiled by ``data.get_pure_image_subimages``.
    Every subimage is encoded, and the encoder outputs are averaged over the
    first (subimage) axis to give a single feature row for the image.

    Args:
        loader: iterable yielding ``(batch_images, batch_classes)``; only the
            first element of each batch is used, so it should be built with
            ``batch_size=1``.
        data: object whose ``get_pure_image_subimages(image)`` returns a pair
            whose first element is a tensor batch of subimages (the second
            element is ignored).
        xs: preallocated array with at least one row per loader item; row ``i``
            receives the mean embedding of image ``i``.
        ys: preallocated array with at least one row per loader item; row ``i``
            receives the class of image ``i``.
        model: encoder to apply; defaults to the notebook-level ``encoder``.
        target_device: device to run the encoder on; defaults to the
            notebook-level ``device``.
    """
    if model is None:
        model = encoder
    if target_device is None:
        target_device = device

    for row, (batch_images, batch_classes) in enumerate(tqdm(loader)):
        # The loader batch holds one image; its subimages form the encoder batch.
        subimages_batch, _ = data.get_pure_image_subimages(batch_images[0])
        subimages_batch = subimages_batch.to(target_device)

        with torch.no_grad():
            representations = model(subimages_batch)

        # Average-pool the subimage embeddings into one vector for the image.
        # Presumably representations is (num_subimages, latent_dim) — TODO confirm.
        xs[row] = representations.cpu().numpy().mean(axis=0)
        ys[row] = batch_classes[0]

In [32]:
# Encode every split into its preallocated feature/label buffers.
# NOTE(review): all splits, including test, use data_train.datasets[0] for
# subimage tiling even though data_test exists — confirm the helper is
# split-independent.
for split_loader, split_xs, split_ys in (
    (train_loader, train_xs, train_ys),
    (val_loader, val_xs, val_ys),
    (test_loader, test_xs, test_ys),
):
    occs_dataset(split_loader, data_train.datasets[0], split_xs, split_ys)

100%|██████████| 300/300 [01:00<00:00,  4.94it/s]
100%|██████████| 41/41 [00:09<00:00,  4.52it/s]
100%|██████████| 33/33 [00:07<00:00,  4.62it/s]


In [57]:
# Index the training-image embeddings for 3-nearest-neighbour lookups.
# fit() returns the estimator itself, so it can be chained.
knn = NearestNeighbors(n_neighbors=3).fit(train_xs)
knn

In [58]:
# kNN evaluation on the validation split, using one batched neighbour query
# instead of one query per sample.
# Per sample, take the median of the 0/1 "neighbour label == true label" flags.
# For odd k this is 1 exactly when a strict majority of the k nearest training
# images share the true class; an even split would contribute 0.5.
# NOTE: the "5" in `val_5_acc` is historical — k is knn.n_neighbors (3 here).
# The name is kept so later cells still work.
neighbour_ids = knn.kneighbors(val_xs, return_distance=False)  # (n_val, k)
neighbour_ys = train_ys[neighbour_ids, 0]  # (n_val, k)
label_matches = neighbour_ys == val_ys  # val_ys is (n_val, 1): broadcasts over k
val_5_acc = np.median(label_matches, axis=1).mean()
print("Accuracy:", val_5_acc)

Accuracy: 0.1951219512195122
