In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load both FashionMNIST splits under ./data, downloading them on first use.
# Each split gets its own ToTensor() transform; the training split is built first.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

class WrappedDataLoader:
    """Iterable wrapper that post-processes every batch of another loader.

    Each batch produced by ``dataloader`` is unpacked into positional
    arguments for ``func`` (i.e. ``func(*batch)``), and the result of that
    call is yielded in place of the raw batch. ``len()`` is delegated to the
    wrapped loader, so the wrapper reports the same number of batches.
    """

    def __init__(self, dataloader, func):
        self.dataloader = dataloader  # underlying iterable of batch tuples
        self.func = func  # applied as func(*batch) to each batch

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        # Generator function: the wrapped loader is only iterated once the
        # caller starts pulling batches, and func is looked up per batch.
        yield from (self.func(*batch) for batch in self.dataloader)

def get_data(training_data, test_data, batch_size, shuffle_train=True, shuffle_test=False):
    """Build DataLoaders for the training and test datasets.

    The training loader is shuffled by default so the order of mini-batches
    changes every epoch; the evaluation loader is kept in a fixed order by
    default since shuffling does not affect evaluation results.

    Args:
        training_data: Dataset used for training.
        test_data: Dataset used for evaluation.
        batch_size: Number of samples per batch, used for both loaders.
        shuffle_train: Whether to shuffle the training loader (default True).
        shuffle_test: Whether to shuffle the test loader (default False).

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``.
    """
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=shuffle_train)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle_test)
    return train_dataloader, test_dataloader

def gpu_preprocess(x, y, target_device=None):
    """Reshape a batch of images to (N, 1, 28, 28) and move it and its labels to a device.

    Intended to be passed as ``func`` to ``WrappedDataLoader``.

    Args:
        x: Input images; any shape whose element count is a multiple of 28 * 28.
        y: Labels for the batch.
        target_device: Device to move both tensors to. When None (the default),
            the notebook-level ``device`` is used ("cuda" if available, else "cpu").

    Returns:
        Tuple ``(x, y)`` with ``x`` reshaped to (-1, 1, 28, 28); both on the target device.
    """
    if target_device is None:
        target_device = device
    # reshape rather than view: view raises on non-contiguous inputs whose
    # dimensions would have to be merged, reshape copies when necessary.
    return x.reshape(-1, 1, 28, 28).to(target_device), y.to(target_device)


In [None]:
# FashionMNIST class index -> human-readable label (keys 0..9 in order).
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))

# Show a 3x3 grid of randomly drawn training images with their labels.
# Indices come from torch's global RNG, one draw per panel.
cols, rows = 3, 3
figure = plt.figure(figsize=(8, 8))
for panel in range(1, rows * cols + 1):
    random_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[random_idx]
    ax = figure.add_subplot(rows, cols, panel)
    ax.set_title(labels_map[label])
    ax.axis("off")
    ax.imshow(img.squeeze(), cmap="gray")
plt.show()

### Loading the saved autoencoder model

In [None]:
from autoencoder import Autoencoder

# Build the autoencoder on the selected device and restore its saved weights.
saved_ae_model = Autoencoder().to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint written on a
# GPU can still be loaded on a CPU-only machine. weights_only=True restricts
# unpickling to tensors/primitive containers (safer for files of unknown origin).
ae_state_dict = torch.load('AE.pth', map_location=device, weights_only=True)
saved_ae_model.load_state_dict(ae_state_dict)
# Switch to inference mode (affects layers such as dropout/batch-norm, if present).
saved_ae_model.eval()

In [None]:
import random
import time

# Pick one random test image.
# random.randrange excludes its upper bound, so the index is always valid
# (random.randint(0, n) includes n and could raise IndexError).
# The `random` module is already seeded from OS entropy at import; call
# random.seed(<int>) here if a reproducible pick is wanted.
sample_index = random.randrange(len(test_data))

img, label = test_data[sample_index]
img = img.to(device)

with torch.no_grad():
    # Add a leading batch dimension so the model receives a batch, as the encoder
    # does when it is fed DataLoader batches in the latent-space cell.
    pred_img = saved_ae_model(img.unsqueeze(0))
    # Restore the 28x28 image layout from the model's last dimension
    # (assumes the model emits flattened 784-pixel outputs — TODO confirm).
    pred_img = torch.unflatten(pred_img, -1, (28, 28))

# Side-by-side: original image (left) vs. autoencoder reconstruction (right).
figure = plt.figure(figsize=(8, 8))
cols, rows = 2, 1

panels = [
    (f"True: {labels_map[label]}", img),
    (f"autoencoder: {labels_map[label]}", pred_img),
]
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    figure.add_subplot(rows, cols, position)
    plt.title(panel_title)
    plt.axis("off")
    plt.imshow(panel_image.cpu().squeeze(), cmap="gray")

plt.show()

In [None]:
# Project test images through the encoder and plot the first two latent
# dimensions, coloured by true class label.
N_LATENT_BATCHES = 64  # number of test batches (of 64 images) to plot

_, test_dataloader = get_data(training_data, test_data, batch_size=64)

latent_chunks, label_chunks = [], []
with torch.no_grad():
    for batch_idx, (X, y) in enumerate(test_dataloader):
        if batch_idx >= N_LATENT_BATCHES:
            break
        latent_chunks.append(saved_ae_model.encoder(X.to(device)).cpu())
        label_chunks.append(y)

latents = torch.cat(latent_chunks).numpy()
latent_labels = torch.cat(label_chunks).numpy()

# A single scatter call with a fixed colour range (classes 0..9 centred in the
# 10 tab10 bins) keeps every class the same colour. Per-batch scatter calls would
# each rescale colours to that batch's own min/max label.
fig, ax = plt.subplots(figsize=(8, 8))
points = ax.scatter(
    latents[:, 0], latents[:, 1],
    c=latent_labels, cmap='tab10', vmin=-0.5, vmax=9.5,
)
fig.colorbar(points, ax=ax, ticks=range(10), label="class index")
ax.set(title="Encoder latent space (test set)", xlabel="z[0]", ylabel="z[1]")

plt.show()

    