In [1]:
from ipywidgets import interact, widgets, interactive
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from pathlib import Path
from PIL import Image
import math
from utils.layers import *
from utils.tools import *

In [2]:
# Restrict TensorFlow to the first GPU and enable memory growth so the runtime
# does not grab all device memory at initialisation. On a CPU-only machine the
# device list is empty, so skip the configuration instead of raising IndexError.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)

In [3]:
# Paths and global configuration, all resolved relative to the notebook's
# current working directory.
PATH_DIR = Path.cwd()
model_dir = PATH_DIR / 'bin'      # directory holding the saved .h5 weights loaded below
media_dir = PATH_DIR / 'media'

# MNIST images: 28x28 pixels with a single channel (channels-last).
input_shape = (28, 28, 1)
# NOTE(review): not referenced in any of the cells shown here.
ROUTING = False

# Import the dataset

In [4]:
# Load MNIST through Keras; `path` is the cache file name for the downloaded
# archive (relative to ~/.keras/datasets).
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data(path='mnist.npz')
X_train, y_train = mnist_train
X_test, y_test = mnist_test
# Drop the tuple handles so only the four named arrays remain in the namespace.
del mnist_train, mnist_test

In [5]:
# normalize dataset
def pre_process(x, y, num_classes=10):
    """Prepare raw image/label arrays for the network.

    Args:
        x: image array; presumably uint8 of shape (N, 28, 28) as returned by
            `tf.keras.datasets.mnist.load_data` — confirm for other datasets.
        y: integer class labels.
        num_classes: width of the one-hot encoding (10 for MNIST digits).

    Returns:
        (images, labels): images scaled to [0, 1] as float32 with a trailing
        channel axis appended, and labels one-hot encoded with
        `tf.keras.utils.to_categorical`.
    """
    # Divide in float32: `x / 255` on a uint8 array promotes to float64, which
    # doubles memory while Keras layers compute in float32 by default anyway.
    images = np.asarray(x, dtype=np.float32) / 255.0
    return images[..., None], tf.keras.utils.to_categorical(y, num_classes=num_classes)

In [6]:
# Prepare both splits. This cell rebinds the same names it reads, so only process
# arrays still in the raw 3-D (N, H, W) layout from `load_data`; `pre_process`
# appends a channel axis, so a re-run would otherwise rescale the images a second
# time and stack an extra axis.
if X_train.ndim == 3:
    X_train, y_train = pre_process(X_train, y_train)
if X_test.ndim == 3:
    X_test, y_test = pre_process(X_test, y_test)

# Create model

In [7]:
def Generator():
    """Build the decoder that reconstructs an image from capsule activations.

    Input: a flat vector of 16 * 10 = 160 values, i.e. 10 capsules of 16 dims.
    This assumes `Mask` flattens the (10, 16) DigitCaps tensor; confirm in
    utils.layers. Two ReLU dense layers (512, 1024) follow, then a sigmoid
    layer with one unit per pixel, reshaped back to `input_shape`.

    Returns:
        tf.keras.Model named 'Generator'.
    """
    caps_vector = tf.keras.Input(16*10)

    # Layers are created in the same order as always: 512 -> 1024 -> pixels -> reshape.
    # Saved h5 weights are matched to layers in that order.
    hidden = caps_vector
    for units in (512, 1024):
        hidden = tf.keras.layers.Dense(units, activation='relu')(hidden)

    pixels = tf.keras.layers.Dense(np.prod(input_shape), activation='sigmoid')(hidden)
    image = tf.keras.layers.Reshape(target_shape=input_shape, name='out_generator')(pixels)

    return tf.keras.Model(inputs=caps_vector, outputs=image, name='Generator')

In [8]:
def CapsNet():
    """Build the capsule encoder used inside `modelPlay`.

    A stack of valid-padded convolutions shrinks the image to a single spatial
    position. The resulting 128 channels are regrouped into 16 capsules of
    8 dims, passed through `squash`, and fed to `DigitCaps(10, 16)`.

    Layer creation order matters: weights are later restored from an .h5 file,
    which matches layers by order, so do not reorder or insert layers.

    Returns:
        tf.keras.Model named 'CapsNet' with two outputs:
        - digit_caps: shape (None, 10, 16) according to the model summary.
        - digit_caps_len: output of `Length`. Presumably the per-capsule vector
          norm used as a class score; confirm in utils.layers.
    """
    inputs = tf.keras.Input(input_shape)
    
    # Spatial sizes in the comments below assume input_shape == (28, 28, 1).
    # 5x5 valid conv -> (24, 24, 32)
    x = tf.keras.layers.Conv2D(32,5,activation="relu", padding='valid', kernel_initializer='he_normal')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    # 3x3 valid conv -> (22, 22, 64)
    x = tf.keras.layers.Conv2D(64,3, activation='relu', padding='valid', kernel_initializer='he_normal')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    # 3x3 valid conv -> (20, 20, 64)
    x = tf.keras.layers.Conv2D(64,3, activation='relu', padding='valid', kernel_initializer='he_normal')(x)   
    x = tf.keras.layers.BatchNormalization()(x)
    # 3x3 valid conv, stride 2 -> (9, 9, 128)
    x = tf.keras.layers.Conv2D(128,3,2, activation='relu', padding='valid', kernel_initializer='he_normal')(x)   
    x = tf.keras.layers.BatchNormalization()(x)
    # 9x9 conv with groups=128 (one filter per channel) collapses the 9x9 map -> (1, 1, 128)
    x = tf.keras.layers.Conv2D(128, 9, activation='linear', groups=128, padding='valid')(x)
    # 128 values regrouped as 16 capsules of 8 dims
    x = tf.keras.layers.Reshape((16,8))(x)
    # `squash` comes from utils.layers (star import); presumably the capsule squashing non-linearity
    x = tf.keras.layers.Lambda(squash, name='squash')(x)
    digit_caps = DigitCaps(10,16)(x)
    
    digit_caps_len = Length(name='micro_capsnet_output')(digit_caps)

    return tf.keras.Model(inputs=inputs,outputs=[digit_caps,digit_caps_len], name='CapsNet')

In [9]:
def modelPlay(generator):
    """Build the interactive "play" model: CapsNet encoder plus perturbed decoder.

    The image is encoded into DigitCaps vectors. A caller-supplied offset is
    added to every capsule dimension, the result is masked with the true label,
    and `generator` reconstructs an image from it. This makes it possible to
    see what each capsule dimension encodes.

    Args:
        generator: Keras model mapping the masked capsule vector to an image
            (e.g. the model returned by `Generator()`).

    Returns:
        tf.keras.Model with inputs [image, one-hot label, noise] and the
        reconstructed image as its only output.
    """
    # Inputs are created in the original order (image, noise, label) so the
    # auto-generated layer names stay unchanged.
    inputs = tf.keras.Input(input_shape)
    noise = tf.keras.Input(shape=(10, 16))    # additive offset for every DigitCaps dimension
    y_true = tf.keras.Input(shape=(10,))      # one-hot label selecting the capsule to decode

    # Only the capsule vectors feed the decoder; the length output is not needed here.
    digit_caps, _ = CapsNet()(inputs)
    noised_digitcaps = tf.keras.layers.Add()([digit_caps, noise])

    # Given [capsules, labels], `Mask` presumably keeps only the labelled
    # capsule; confirm in utils.layers. (The former unlabelled
    # `Mask()(digit_caps)` branch was never connected to the output and was removed.)
    masked_noised_y = Mask()([noised_digitcaps, y_true])

    x_rec_play = generator(masked_noised_y)

    return tf.keras.models.Model([inputs, y_true, noise], x_rec_play)

In [10]:
# Wire a freshly built decoder into the interactive model and print its layer graph.
model_play = modelPlay(generator=Generator())
model_play.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
CapsNet (Functional)            [(None, 10, 16), (No 162400      input_2[0][0]                    
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 10, 16)]     0                                            
__________________________________________________________________________________________________
add (Add)                       (None, 10, 16)       0           CapsNet[0][0]                    
                                                                 input_3[0][0]         

In [11]:
# Saved weights for the whole play model (CapsNet encoder + Generator decoder).
name_model = 'efficient_capsnet_mnist_022.h5'
weights_path = model_dir.joinpath(name_model)

# Fail early with the expected location rather than an opaque HDF5 error.
# Pass a plain string because some older TF 2.x releases call `.endswith` on
# the filepath inside `load_weights`, and that breaks for pathlib.Path.
if not weights_path.is_file():
    raise FileNotFoundError('Pretrained weights not found: {}'.format(weights_path))
model_play.load_weights(str(weights_path))

# Interact with model

In [20]:
class Visualizer(object):
    """Interactive explorer for the 16 DigitCaps dimensions of `model_play`.

    Shows one FloatSlider per capsule dimension and an index box that selects
    a test digit. Each slider value is added to that dimension of all 10
    capsules before the generator reconstructs the image, so moving a slider
    shows how the dimension changes the reconstruction.

    Relies on the notebook globals `model_play`, `X_test` and `y_test`.
    """

    def __init__(self, min_value=-0.30, max_value=0.30, step=0.05):
        """Create the slider and index widgets.

        Args:
            min_value: lower bound of every perturbation slider.
            max_value: upper bound of every perturbation slider.
            step: slider increment.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.step = step
        self.sliders = {
            str(i): widgets.FloatSlider(min=self.min_value, max=self.max_value, step=self.step)
            for i in range(16)
        }
        # Bounded so the index box cannot point past the end of the test set.
        self.text = widgets.BoundedIntText(min=0, max=len(X_test) - 1)
        self.sliders['index'] = self.text

    @staticmethod
    def _resolve_index(raw_index, n_samples):
        """Map a raw index value to a valid row of a dataset with `n_samples` rows.

        Negative values are mirrored with `abs`, as before. Values past the end
        are clamped to the last row, so a preview never slices out of range.
        """
        return max(0, min(abs(int(raw_index)), n_samples - 1))

    def affineTransform(self, **info):
        """Reconstruct the selected test digit under the current slider offsets.

        Args:
            **info: widget values keyed like `self.sliders`. Keys '0'..'15'
                hold the per-dimension offsets; 'index' holds the test-sample index.
        """
        index = self._resolve_index(info['index'], len(X_test))

        # The offset for dimension d is applied to that dimension of all 10 capsules.
        noise = np.zeros([1, 10, 16])
        for d in range(16):
            noise[:, :, d] = info[str(d)]

        X_gen = model_play.predict([X_test[index:index+1], y_test[index:index+1], noise])

        fig, ax = plt.subplots(1, 2, figsize=(12, 12))
        ax[0].imshow(X_test[index], cmap='gray')
        ax[0].set_title('Input Digit')
        ax[1].imshow(X_gen[0], cmap='gray')
        ax[1].set_title('Output Generator')
        plt.show()

    def on_button_clicked(self, k):
        """Reset every perturbation slider to zero.

        `k` is the Button that fired the event (passed by `on_click`); it is unused.
        """
        for i in range(16):
            self.sliders[str(i)].value = 0

    def start(self):
        """Lay out the widgets in rows of four sliders and attach the live preview."""
        button = widgets.Button(description="Reset")
        button.on_click(self.on_button_clicked)

        main = widgets.HBox([self.text, button])
        rows = [
            widgets.HBox([self.sliders[str(i)] for i in range(first, first + 4)])
            for first in range(0, 16, 4)
        ]

        out = widgets.interactive_output(self.affineTransform, self.sliders)

        display(main, *rows, out)

In [21]:
# Build the widget panel and render it below this cell.
visualizer = Visualizer()
visualizer.start()

HBox(children=(IntText(value=0), Button(description='Reset', style=ButtonStyle())))

HBox(children=(FloatSlider(value=0.0, max=7.3, min=-7.3, step=0.05), FloatSlider(value=0.0, max=7.3, min=-7.3,…

HBox(children=(FloatSlider(value=0.0, max=7.3, min=-7.3, step=0.05), FloatSlider(value=0.0, max=7.3, min=-7.3,…

HBox(children=(FloatSlider(value=0.0, max=7.3, min=-7.3, step=0.05), FloatSlider(value=0.0, max=7.3, min=-7.3,…

HBox(children=(FloatSlider(value=0.0, max=7.3, min=-7.3, step=0.05), FloatSlider(value=0.0, max=7.3, min=-7.3,…

Output()