In [None]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from IPython.display import display

In [None]:
# architecture
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a flat 128-value latent code.

    Encoder: three (conv -> activation -> max-pool) stages taking a single-channel
    image from 1 to 16 to 8 to 8 channels, then a Flatten. For 28x28 inputs the
    spatial size goes 28 -> 14 -> 7 -> 4, giving 8 * 4 * 4 = 128 latent values.
    The last stage uses a sigmoid, so latent values lie in (0, 1).

    Decoder: unflattens the 128 values back to (8, 4, 4), then three
    (conv -> ReLU -> 2x upsample) stages and a final conv + sigmoid, producing a
    1x28x28 reconstruction with pixel values in (0, 1).

    The layer order inside each nn.Sequential must not change: it fixes the
    state_dict keys (e.g. ``encoder.0.weight``) that saved checkpoints rely on.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(1, 16, (3, 3), padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 28 -> 14
            nn.Conv2d(16, 8, (3, 3), padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 14 -> 7
            # padding=1 on the last pool turns the odd 7x7 map into 4x4 instead of 3x3
            nn.Conv2d(8, 8, (3, 3), padding=1), nn.Sigmoid(), nn.MaxPool2d(2, padding=1),
            nn.Flatten(),  # (N, 8, 4, 4) -> (N, 128)
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Unflatten(1, (8, 4, 4)),  # (N, 128) -> (N, 8, 4, 4)
            nn.Conv2d(8, 8, (3, 3), padding=1), nn.ReLU(), nn.Upsample(scale_factor=(2, 2)),  # 4 -> 8
            nn.Conv2d(8, 8, (3, 3), padding=1), nn.ReLU(), nn.Upsample(scale_factor=(2, 2)),  # 8 -> 16
            # unpadded conv trims 16 -> 14 so the final upsample lands on 28
            nn.Conv2d(8, 16, (3, 3)), nn.ReLU(), nn.Upsample(scale_factor=(2, 2)),  # 14 -> 28
            nn.Conv2d(16, 1, (3, 3), padding=1), nn.Sigmoid(),  # 28 -> 28, one channel
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to its latent code and return the decoded reconstruction."""
        return self.decoder(self.encoder(x))

    def __str__(self):
        """Printable form: the encoder's repr immediately followed by the decoder's."""
        return ''.join(str(part) for part in (self.encoder, self.decoder))

In [None]:
# Build the network skeleton; its randomly initialised weights are replaced
# by load_state_dict() when the uploaded checkpoint is loaded.
autoencoder = Autoencoder()

In [None]:
# Restore the trained weights from a checkpoint uploaded through the Colab file widget.
import io
from google.colab import files

model_name = 'pytorch_mnist_autoencoder_model.pth'
print('upload',model_name)
uploaded = files.upload()
# files.upload() returns {original filename: bytes}. Colab stores a re-upload of an
# existing name under a de-duplicated name (e.g. "... (1).pth"), so re-reading
# model_name from disk could load a stale file. Prefer the freshly returned bytes and
# fall back to the on-disk file only when this name was not part of the upload.
model_bytes = uploaded.get(model_name)
checkpoint = io.BytesIO(model_bytes) if model_bytes is not None else model_name
# map_location='cpu': a checkpoint saved from a GPU session otherwise fails to
# deserialize on a CPU-only runtime; the model here lives on the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
# (on PyTorch >= 1.13 consider weights_only=True).
autoencoder.load_state_dict(torch.load(checkpoint, map_location='cpu'))

In [None]:
# Keep a handle on the decoder half alone and put it in inference mode.
# train(False) is exactly what eval() does; it returns the module, which the cell echoes.
decoder = autoencoder.decoder
decoder.train(False)

In [None]:
# Upload the sample digit image that the next cell reads as 'inp00009.png'.
# Imported here as well so this cell does not depend on the model-upload cell's import.
from google.colab import files

print('upload inp00009.png')
uploaded_sample = files.upload()
# files.upload() returns {original filename: bytes}. A re-upload of an existing name is
# saved by Colab under a de-duplicated name (e.g. "inp00009 (1).png"), which the next
# cell would never read. Write the fresh bytes back under the expected name.
if 'inp00009.png' in uploaded_sample:
    with open('inp00009.png', 'wb') as sample_file:
        sample_file.write(uploaded_sample['inp00009.png'])

In [None]:
# Encode the uploaded sample digit into the autoencoder's latent space.
sample_image = cv2.imread('inp00009.png',cv2.IMREAD_GRAYSCALE)
if sample_image is None:
    # cv2.imread signals a missing/unreadable file by returning None rather than raising,
    # which would otherwise surface as an obscure error inside blobFromImage.
    raise FileNotFoundError("could not read 'inp00009.png'; upload it first")
# Scale pixels by 1/255 and pack the image into an NCHW float blob for the encoder.
sample_blob = cv2.dnn.blobFromImage(sample_image,1.0/255.0)
with torch.no_grad():  # inference only: no autograd graph is needed
    sample_features = autoencoder.encoder(torch.from_numpy(sample_blob))
# Latent vector of the single sample as a NumPy array (drops the batch axis).
features = sample_features.squeeze(0).cpu().numpy()

In [None]:
# Latent code shape: (batch, flattened features), as produced by the encoder's final Flatten.
print(sample_features.size())

In [None]:
# Index shown by the sliders on the previous callback (-1: none yet), and an
# editable copy of the sample's latent vector, detached from autograd and with
# its own storage so edits never touch sample_features.
last_index = -1
features = sample_features[0].detach().clone()

In [None]:
# Function to update and display the generated image
_syncing_value_slider = False  # True while update_latent itself is moving value_slider


def update_latent(index=0, value=64):
    """Edit one latent component and display the decoder's reconstruction.

    If ``index`` differs from the index handled on the previous call, the value
    slider is moved to show that component's current value (scaled by 127) and the
    latent vector is left unchanged. Otherwise ``features[index]`` is set to
    ``value / 127``. In both cases the (edited) latent vector is decoded and shown
    as a 420x420 grayscale image.

    Args:
        index: position in the global ``features`` latent vector to view/edit.
        value: slider value; mapped to a latent value as ``value / 127``.
    """
    global last_index, _syncing_value_slider
    if _syncing_value_slider:
        # Re-entrant call fired by the value_slider assignment below when interact is
        # observing that slider. Skip it: the outer call draws the figure, and writing
        # here would overwrite features[index] with its truncated slider value.
        return
    if index != last_index:
        last_index = index
        _syncing_value_slider = True
        try:
            value_slider.value = int(features[index].item() * 127)
        finally:
            _syncing_value_slider = False
    else:
        features[index] = value / 127.0
    with torch.no_grad():
        # The decoder's first layer, nn.Unflatten(1, (8, 4, 4)), needs a (batch, 128)
        # code, so add only the batch axis (a second unsqueeze gives (1, 1, 128) and fails).
        coded = features.unsqueeze(0)
        # (1, 1, H, W) reconstruction -> (H, W) image
        decoded = decoder(coded).squeeze(0).squeeze(0)
    decoded = (decoded.numpy() * 255).astype(np.uint8)
    # Nearest-neighbour upscaling keeps the individual pixels visible.
    decoded_resized = cv2.resize(decoded, (420, 420), interpolation=cv2.INTER_NEAREST)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(decoded_resized, cmap='gray')
    ax.axis('off')
    plt.show()

In [None]:
# Interactive sliders for latent space:
#  - Index picks which latent component to edit (0 .. latent_dim - 1).
#  - Value holds that component scaled by 127; it starts at component 0's value.
latent_dim = len(features)
index_slider = IntSlider(
    min=0,
    max=latent_dim - 1,
    step=1,
    value=0,
    description='Index',
)
value_slider = IntSlider(
    min=0,
    max=127,
    step=1,
    value=int(features[0].item() * 127),
    description='Value',
)

In [None]:
def interactive_update(index, value):
    """Callback for ``interact``: pass the current slider values on to update_latent."""
    return update_latent(index=index, value=value)

In [None]:
# Render the Index/Value sliders and wire them to interactive_update.
# interact() displays the widget itself and returns the callback; the trailing ';'
# stops the notebook from also echoing that function's repr as the cell output.
interact(interactive_update, index=index_slider, value=value_slider);