In [None]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Custom layer: reshape a flat batch back into per-sample tensors.
class Reshape(nn.Module):
    """Reshape every sample of a batch to a fixed shape.

    The leading (batch) dimension is preserved and the rest of each sample is
    viewed as ``self.shape``. ``Tensor.view`` is used, so the input must be
    view-compatible (e.g. contiguous).
    """

    def __init__(self, *args):
        super().__init__()
        # Per-sample target shape, each dimension coerced to int.
        self.shape = tuple(int(dim) for dim in args)

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, *self.shape)

In [None]:
# Network architecture (layer order must match the saved checkpoint).
class Autoencoder(nn.Module):
    """Small convolutional autoencoder sized for 28x28 single-channel images.

    ``encoder`` maps a (N, 1, 28, 28) batch to a flat (N, 128) code
    (8 channels x 4 x 4), squashed by a sigmoid before the final pooling.
    ``decoder`` maps such a code back to a (N, 1, 28, 28) image, also passed
    through a sigmoid.

    The position of each layer inside the two ``nn.Sequential`` containers is
    part of the checkpoint format: ``state_dict`` keys such as
    ``encoder.0.weight`` come from these indices, so layers must not be
    reordered, inserted or removed.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = [
            # stage 1: 28x28 -> 14x14
            nn.Conv2d(1, 16, (3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # stage 2: 14x14 -> 7x7
            nn.Conv2d(16, 8, (3, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # stage 3: 7x7 -> 4x4 (padding=1 on the pool rounds 7/2 up to 4)
            nn.Conv2d(8, 8, (3, 3), padding=1),
            nn.Sigmoid(),
            nn.MaxPool2d(2, padding=1),
            nn.Flatten(),
        ]
        decoder_layers = [
            Reshape(8, 4, 4),
            # 4x4 -> 8x8
            nn.Conv2d(8, 8, (3, 3), padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2, 2)),
            # 8x8 -> 16x16
            nn.Conv2d(8, 8, (3, 3), padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2, 2)),
            # unpadded conv 16x16 -> 14x14, then upsample to 28x28
            nn.Conv2d(8, 16, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2, 2)),
            # back to a single channel
            nn.Conv2d(16, 1, (3, 3), padding=1),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to its flat code, then decode it back to an image."""
        return self.decoder(self.encoder(x))

    def __str__(self):
        return "".join((str(self.encoder), str(self.decoder)))

In [None]:
# Build the network; its random weights are replaced by the checkpoint below.
autoencoder: Autoencoder = Autoencoder()

In [None]:
# Checkpoint file holding the trained state_dict.
model_name = 'pytorch_mnist_autoencoder_model.pth'
# map_location='cpu': this notebook never moves the model to a GPU, so a
# checkpoint saved from CUDA tensors must not require CUDA to deserialize.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints
# (on torch >= 1.13, passing weights_only=True restricts what can be loaded).
autoencoder.load_state_dict(torch.load(model_name, map_location='cpu'))

In [None]:
# The decoder alone drives the interactive generator below.
decoder = autoencoder.decoder
# Switch the whole model (encoder too, which encodes the sample image) to
# inference mode, not just the decoder.
autoencoder.eval();

In [None]:
sample_path = 'mnist/inp00009.png'
sample_image = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread signals failure by returning None rather than raising.
if sample_image is None:
    raise FileNotFoundError(f"could not read sample image: {sample_path}")
# Scale pixels by 1/255 and add batch/channel axes -> (1, 1, H, W) float32.
sample_blob = cv2.dnn.blobFromImage(sample_image, 1.0/255.0)
# Inference only: no autograd graph needed.
with torch.no_grad():
    sample_features = autoencoder.encoder(torch.from_numpy(sample_blob))
# Flat code vector that the trackbars edit in place.
features = sample_features.squeeze(0).cpu().numpy()
# The decoder's Reshape(8, 4, 4) and the 0..127 index trackbar need exactly
# 128 values; fail early with a clear message instead of inside the decoder.
if features.shape != (8 * 4 * 4,):
    raise ValueError(
        f"expected a 128-element code, got shape {features.shape}; "
        "the sample image should be MNIST-sized (28x28)")

In [None]:
# Snapshot of the code vector that was last rendered. It is filled with NaN so
# np.array_equal(features, last_features) is guaranteed to be False at first
# (NaN never compares equal), forcing the render loop to draw an initial frame
# regardless of the values in `features`.
last_features = np.full_like(features, np.nan)

In [None]:
# Open the preview window and paint it black until the first decoded frame.
cv2.namedWindow("generator")
blank_frame = np.zeros((560, 560), dtype=np.uint8)
cv2.imshow('generator', blank_frame)

In [None]:
# Initial trackbar state: the selected code element and the "value" slider position.
index, feature = 55, 0

In [None]:
def update_index(*args):
    """Trackbar callback for "index": choose which code element "value" edits.

    OpenCV passes the new trackbar position as the first positional argument;
    it is stored in the module-level ``index``.
    """
    global index
    index = args[0]

In [None]:
def update_feature(*args):
    """Trackbar callback for "value": set the currently selected code element.

    OpenCV passes the new trackbar position (0..127) as the first positional
    argument. It is mapped linearly to [0.0, 1.0] and written into
    ``features[index]`` in place. No ``global`` statement is needed because
    ``features`` is mutated rather than rebound and ``index`` is only read.
    """
    features[index] = float(args[0]) / 127.0

In [None]:
# "index" selects a code element; its range follows the actual code length.
cv2.createTrackbar("index", "generator", index, features.size - 1, update_index)
# "value" uses a 0..127 scale (update_feature maps it to [0, 1]). Start it at
# the selected element's current value so the first drag does not jump from 0.
initial_value = int(round(float(features[index]) * 127))
cv2.createTrackbar("value", "generator", initial_value, 127, update_feature)

In [None]:
# Interactive render loop: redraw whenever a trackbar has changed the code,
# exit on Esc or when the window is closed.
while True:

    if not np.array_equal(features, last_features):
        last_features = np.copy(features)
        # Add a batch dimension: (n,) -> (1, n).
        coded = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
        # Inference only: avoid building an autograd graph on every redraw.
        with torch.no_grad():
            decoded = decoder(coded)
        # (1, 1, H, W) sigmoid output -> (H, W) uint8, then enlarge for display.
        decoded = np.asarray(decoded[0].squeeze(0).numpy() * 255, np.uint8)
        decoded = cv2.resize(decoded, (560, 560))
        cv2.imshow('generator', decoded)

    # Esc quits. Mask to the low byte: some HighGUI backends set extra high
    # bits in the returned key code.
    if (cv2.waitKey(10) & 0xFF) == 27:
        break
    # Closing the window via its close button produces no key press; without
    # this check the loop would spin forever and block the kernel.
    if cv2.getWindowProperty('generator', cv2.WND_PROP_VISIBLE) < 1:
        break

In [None]:
cv2.destroyAllWindows()
# Pump the HighGUI event loop once so the close request is actually processed;
# otherwise the window can linger (e.g. on GTK/macOS) when run from a notebook.
cv2.waitKey(1);