In [1]:
%matplotlib
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.datasets.mnist as mnist
from autoencoders import init_linear_autoencoder, init_dense_autoencoder
from numpy import random as rnd

Using matplotlib backend: Qt5Agg


In [12]:
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train = x_train /255.
x_test  = x_test /255.
input_shape = x_train.shape[1:]
num_train  = x_train.shape[0]
num_test   = x_test.shape[0]
img_width  = x_train.shape[1]
img_height = x_train.shape[2]

latent_shape = (32,)

In [13]:
lin_auto, lin_enco, lin_deco = init_linear_autoencoder(latent_shape, input_shape)
dens_auto, dens_enco, dens_deco = init_dense_autoencoder(latent_shape, input_shape)


Model: "model_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 32)                25120     
_________________________________________________________________
reshape_4 (Reshape)          (None, 32)                0         
Total params: 25,120
Trainable params: 25,120
Non-trainable params: 0
_________________________________________________________________
Model: "model_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 32)]              0         
______________________________________

In [14]:
lin_auto.compile(optimizer = 'adam', loss = 'binary_crossentropy')
dens_auto.compile(optimizer = 'adam', loss = 'MSE')

In [None]:
lin_hist = lin_auto.fit(x_train, x_train, epochs = 10, batch_size = 256, validation_data = (x_test,x_test), shuffle = True)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10

In [15]:
dens_hist = dens_auto.fit(x_train, x_train, epochs = 50, batch_size = 256, validation_data = (x_test,x_test), shuffle = True)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [16]:
lin_x = lin_auto.predict(x_train)
dens_x = dens_auto.predict(x_train)

In [7]:
""" Display linear results """
N = 4
rnd_i = rnd.permutation(x_train.shape[0])[0:N]
fig, axs = plt.subplots(2, N)
for i in range(N):
    axs[0,i].imshow(lin_x[rnd_i[i]],vmin = 0, vmax = 1.0)
    axs[0,i].set_title("Number is " + str(y_train[rnd_i[i]]))
    axs[1,i].imshow(x_train[rnd_i[i]],vmin = 0, vmax = 1.0)
    axs[1,i].set_title("Number is " + str(y_train[rnd_i[i]]))

In [17]:
""" Display dense results """
N = 4
rnd_i = rnd.permutation(x_train.shape[0])[0:N]
fig, axs = plt.subplots(2, N)
for i in range(N):
    axs[0,i].imshow(dens_x[rnd_i[i]],vmin = 0, vmax = 1.0)
    axs[0,i].set_title("Number is " + str(y_train[rnd_i[i]]))
    axs[1,i].imshow(x_train[rnd_i[i]],vmin = 0, vmax = 1.0)
    axs[1,i].set_title("Number is " + str(y_train[rnd_i[i]]))

In [28]:
lin_latent = lin_enco.predict(x_train)
dens_latent = dens_enco.predict(x_train[y_train ==4])

In [29]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

In [30]:
dens_latent.shape


(5842, 32)

In [31]:
vv = normalize(PCA(2).fit_transform(dens_latent.T), axis = 0)
dens_mean = np.mean(dens_latent, axis = 0)
dens_mean.shape

(32,)

In [16]:
plt.imshow(dens_latent,aspect='auto')
plt.colorbar()
plt.ion()

In [11]:
np.max(x_train[rnd_i[i]])

0.19105611215508309

In [32]:
class LatentDigit:
    def __init__(self, model, fig,
                              latent_boundary = (-1., 1., -1., 1.),
                              latent_shape = (2,),
                              latent_vectors = np.array([[1,0],[0,1]]),
                              original_point = np.nan):
        
        self.original_point = np.zeros(latent_shape) if np.any(np.isnan(original_point)) else original_point
        self.latent_point = self.original_point
        self.component_point = np.zeros((2,))
        self.latent_shape = latent_shape
        self.latent_vectors = latent_vectors
        self.model = model
        self.image = self.model.predict(np.array([self.latent_point]))[0]
        
        self.fig = fig
        self.fig.subplots(1,2)
        self.latent_boundary = latent_boundary
        self.setup_latent_space()
        
        self.click_tol = 0.05
        self.draw()
        self.fig.canvas.mpl_connect('button_press_event'  , self.button_press_callback)
        self.fig.canvas.mpl_connect('motion_notify_event' ,self.motion_notify_callback)
        self.fig.canvas.mpl_connect('button_release_event', self.button_release_callback)
        self.fig.canvas.mpl_connect('key_press_event', self.key_press_callback)
        self.point_selected = False        
    
    def setup_latent_space(self):
        self.fig.axes[0].set_xlim(self.latent_boundary[0], self.latent_boundary[1])
        self.fig.axes[0].set_ylim(self.latent_boundary[2], self.latent_boundary[3])       
        
    
    def draw(self):
        self.fig.axes[0].cla()
        self.setup_latent_space()
        self.fig.axes[0].scatter(self.component_point[0], self.component_point[1])
        self.fig.axes[1].imshow(self.image,vmin = 0, vmax = 1.0)
        self.fig.canvas.draw_idle()
    
    def button_press_callback(self, event):
        self.update_point_select(event)
    
    def update_point_select(self, event):
        click_point = np.array([event.xdata, event.ydata])
        dist = np.linalg.norm(click_point-self.component_point)
        self.point_selected =  dist < self.click_tol
    
    def motion_notify_callback(self, event):
        if not self.point_selected: return
        self.update_point(event)
        self.update_digit()
        self.draw()
    
    def button_release_callback(self, event):
        self.point_selected = False
    
    def update_point(self, event):
        self.component_point = np.array([event.xdata, event.ydata])

    def update_digit(self):
        self.latent_point = self.original_point+ self.latent_vectors@self.component_point
        self.image = self.model.predict(np.array([self.latent_point]))[0]
    def key_press_callback(self, event):
        if not event.key == " ": return
        self.update_point(event)
        self.update_digit()
        self.draw()



In [34]:
ld = LatentDigit(dens_deco, plt.figure(), latent_shape = latent_shape,latent_boundary = (-10., 10., -10., 10.),  latent_vectors = vv, original_point = dens_mean)

In [71]:
dens_latent[0]

array([0.19989485, 0.08612621, 0.20631026, 0.1652254 , 0.1744605 ,
       0.14679769, 0.62257195, 0.10897823, 0.35317123, 0.18604484],
      dtype=float32)