In [None]:
from types import SimpleNamespace
import numpy as np
import matplotlib.pyplot as plt
from pdb import set_trace
from tqdm import tqdm

import torch
from torch import nn, optim
from torch.nn import functional as F

import keras
from keras import layers, models
from dataset import get_data
from network import AE
from training import trainer
from testing import reconstruction, display

# Report the active Keras backend and the image data format it is configured
# with, one per line.
print(keras.backend.backend(), keras.backend.image_data_format(), sep="\n")

In [None]:
# Run configuration collected in one namespace so later cells read every
# setting from a single place.
args = SimpleNamespace(
    dataset="fmnist",
    # Use the GPU when one is visible to torch, otherwise fall back to the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
    img=32,      # image size
    ch=1,        # number of channels
    batch=100,   # batch size
    dim=2,       # embedding dimension
    epoch=10,
    lr=1e-3,
)
print(args)

In [None]:
# Build the train/test loaders for the configured dataset and batch size, then
# show the length of each split (presumably a batch count — confirm against
# dataset.get_data).
loader = get_data(args.dataset, args.batch)
tuple(len(split) for split in (loader.train, loader.test))

In [None]:
# Pull a single batch from the training loader and inspect the shapes of its
# two components.
x, y = next(iter(loader.train))
tuple(part.shape for part in (x, y))

In [None]:
# Instantiate the autoencoder on the selected device and inspect its layout.
model = AE(args.img, args.ch, args.dim).to(args.device)
print("\n", model.shape_bf)

# summary() is called directly rather than wrapped in print(): a Keras-style
# summary() writes the table itself and returns None, so print(summary())
# would append a stray "None" after every table.
# NOTE(review): assumes encoder/decoder are Keras models — confirm in network.AE.
print()
model.encoder.summary()
print()
model.decoder.summary()

In [None]:
# Training setup: a freshly constructed autoencoder (replacing any instance
# created earlier), an Adam optimiser over its parameters, and a binary
# cross-entropy loss.
loss_fn = nn.BCELoss()
model = AE(args.img, args.ch, args.dim)
model = model.to(args.device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

In [None]:
# Run training. trainer returns the model together with a history object whose
# .train / .test attributes are plotted in a later cell.
model, history = trainer(
    model,
    loader,
    args.epoch,
    optimizer,
    loss_fn,
    args.device,
)
# Display whether the model was left in training mode.
model.training

In [None]:
# Plot the recorded train and test losses on a shared axis.
fig, ax = plt.subplots()
ax.plot(history.train, label="train loss")
ax.plot(history.test, label="test loss")
ax.legend()
plt.show()

In [None]:
# Load the saved checkpoint and restore its weights into the model; the state
# dict is stored under the "model_dict" key.
# map_location remaps the stored tensors onto the device chosen in `args`, so a
# checkpoint written on a GPU still loads on a CPU-only machine (without it,
# torch.load tries to restore tensors to the device they were saved from).
checkpoint = torch.load(
    "best_model.pth",
    map_location=args.device,
    weights_only=True,
)
model.load_state_dict(checkpoint["model_dict"])
# Display the current train/eval mode flag.
model.training

In [None]:
# Run the project's reconstruction helper on the test loader. It returns three
# objects, bound here as X, X_hat and Y (presumably inputs, reconstructions and
# labels — confirm in testing.reconstruction); show their shapes.
X, X_hat, Y = reconstruction(model, loader.test, 50, args.device)
tuple(item.shape for item in (X, X_hat, Y))

In [None]:
# Show the example items followed by their reconstructions.
# Note: `display` here is the helper imported from the project's `testing`
# module, which shadows IPython's built-in display().
print("Example real clothing items")
display(X)
print("Corresponding reconstructions")
display(X_hat)

In [None]:
# Push the example images through the encoder with gradient tracking disabled,
# then move the result to host memory as a NumPy array.
with torch.no_grad():
    latent = model.encoder(X)
    embeddings = latent.cpu().numpy()
embeddings.shape

In [None]:
# Scatter the first two embedding coordinates, coloured by label
# (clothing type - see table).
figsize = 8
fig, ax = plt.subplots(figsize=(figsize, figsize))
points = ax.scatter(
    embeddings[:, 0],
    embeddings[:, 1],
    c=Y,
    cmap="rainbow",
    alpha=0.8,
    s=3,
)
fig.colorbar(points, ax=ax)
plt.show()

In [None]:
# Per-dimension bounds of the embeddings computed above.
mins = embeddings.min(axis=0)
maxs = embeddings.max(axis=0)

# Draw one uniformly distributed latent point per cell of the preview grid,
# staying inside those bounds.
grid_width, grid_height = 6, 3
n_points = grid_width * grid_height
sample = np.random.uniform(low=mins, high=maxs, size=(n_points, args.dim))
sample.shape

In [None]:
# Decode the sampled latent points with gradient tracking disabled.
# NOTE(review): the decoder is handed a NumPy array directly; this relies on
# the decoder accepting array input — confirm in network.AE.
with torch.no_grad():
    decoded = model.decoder(sample)
    generated_sample = decoded.cpu().numpy()
generated_sample.shape

In [None]:
# First figure: the existing embeddings (small black dots) overlaid with the
# newly sampled latent points (larger blue dots).
figsize = 8
fig, ax = plt.subplots(figsize=(figsize, figsize))
ax.scatter(embeddings[:, 0], embeddings[:, 1], c="black", alpha=0.5, s=2)
ax.scatter(sample[:, 0], sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()

# Second figure: the decoded image for each sampled point, laid out on a
# grid_height x grid_width grid and captioned with the point's coordinates
# rounded to one decimal place.
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

n_cells = grid_width * grid_height
for i in range(n_cells):
    caption = str(np.round(sample[i], 1))
    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    # The caption is positioned in axes coordinates, just below the image.
    ax.text(0.5, -0.35, caption,
            fontsize=10, ha="center", transform=ax.transAxes)
    ax.imshow(generated_sample[i].squeeze(), cmap="Greys")

In [None]:
# Colour the embeddings by their label (clothing type - see table)
figsize = 12
grid_size = 15
plt.figure(figsize=(figsize, figsize))
plt.scatter(embeddings[:, 0], embeddings[:, 1], cmap="rainbow",
            c=Y, alpha=0.8, s=300)
plt.colorbar()

# Regular grid_size x grid_size lattice spanning the bounding box of the
# embeddings. x runs from its minimum to its maximum, y from its maximum down
# to its minimum, so the first grid row holds the largest y value.
# Vectorised ndarray.min/max replace the builtin min/max, which iterate the
# array element by element in Python.
x = np.linspace(embeddings[:, 0].min(), embeddings[:, 0].max(), grid_size)
y = np.linspace(embeddings[:, 1].max(), embeddings[:, 1].min(), grid_size)
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
# Pair the flattened coordinates into a (grid_size**2, 2) array of latent
# points without a Python-level zip.
grid = np.column_stack((xv, yv))

# Decode every lattice point with gradient tracking disabled.
with torch.no_grad():
    generated_sample = model.decoder(grid).cpu().numpy()

# Draw the decoded images in the same row-major order as the lattice.
fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    ax.imshow(generated_sample[i, :, :].squeeze(), cmap="Greys")