In [45]:
import qualia2 
from qualia2.core import *
from qualia2.functions import sigmoid, tanh, concat, relu
from qualia2.nn import Module, Linear
from qualia2.autograd import Tensor
import matplotlib.pyplot as plt
import os
import ipywidgets as widgets
from IPython.display import display

# Base directory used to locate the saved weights: the normalized current working
# directory (abspath of a bare file name resolves against the cwd). Under Jupyter the
# kernel's cwd is typically the folder that holds the notebook.
path = os.path.abspath(os.curdir)

def showfig(g, noise):
    """Plot the first sample that generator *g* produces from *noise*.

    Puts *g* into evaluation mode, runs one forward pass, takes sample 0 of the
    output, reshapes it to a 28x28 grid and displays it as a grayscale image.
    """
    g.eval()
    generated = g(noise)
    picture = generated.data[0].reshape(28, 28)
    # `gpu` / `to_cpu` come from qualia2.core; presumably this copies a device
    # array back to host memory so matplotlib can draw it.
    if gpu:
        picture = to_cpu(picture)
    plt.imshow(picture, cmap='gray', interpolation='nearest')
    plt.grid(False)
    plt.show()

class Generator(Module):
    """Fully connected GAN generator: 50 noise values in, 784 values out.

    Three tanh hidden layers (128, 256 and 512 units) feed a 784-unit output
    layer. Its output goes through tanh and then relu, so every value lies in
    [0, 1). The 784 outputs are presumably a flattened 28x28 image.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are unchanged; stored weights are
        # presumably matched against them when loading.
        self.linear1 = Linear(50, 128)
        self.linear2 = Linear(128, 256)
        self.linear3 = Linear(256, 512)
        self.linear4 = Linear(512, 784)

    def forward(self, x):
        """Map the noise batch *x* to generated outputs."""
        for hidden_layer in (self.linear1, self.linear2, self.linear3):
            x = tanh(hidden_layer(x))
        # tanh maps into (-1, 1); relu then zeroes the negative half.
        return relu(tanh(self.linear4(x)))

# One slider per latent dimension. The generator's first layer is Linear(50, 128),
# so 5 rows x 10 sliders supply exactly 50 noise values.
SLIDER_MIN = -1
SLIDER_MAX = 1
SLIDER_STEP = 0.01


def _make_latent_slider():
    """Return a new FloatSlider spanning [SLIDER_MIN, SLIDER_MAX] in SLIDER_STEP steps."""
    return widgets.FloatSlider(min=SLIDER_MIN, max=SLIDER_MAX, step=SLIDER_STEP)


# The individual module-level names a0..e9 are kept so other code can keep
# referring to specific sliders.
a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = (_make_latent_slider() for _ in range(10))
b0, b1, b2, b3, b4, b5, b6, b7, b8, b9 = (_make_latent_slider() for _ in range(10))
c0, c1, c2, c3, c4, c5, c6, c7, c8, c9 = (_make_latent_slider() for _ in range(10))
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9 = (_make_latent_slider() for _ in range(10))
e0, e1, e2, e3, e4, e5, e6, e7, e8, e9 = (_make_latent_slider() for _ in range(10))

# Lay the sliders out as a 5x10 grid: one HBox per letter, stacked in a VBox.
row0 = widgets.HBox([a0, a1, a2, a3, a4, a5, a6, a7, a8, a9])
row1 = widgets.HBox([b0, b1, b2, b3, b4, b5, b6, b7, b8, b9])
row2 = widgets.HBox([c0, c1, c2, c3, c4, c5, c6, c7, c8, c9])
row3 = widgets.HBox([d0, d1, d2, d3, d4, d5, d6, d7, d8, d9])
row4 = widgets.HBox([e0, e1, e2, e3, e4, e5, e6, e7, e8, e9])
ui = widgets.VBox([row0, row1, row2, row3, row4])

# Holds the trained generator after its first use, so moving a slider no longer
# rebuilds the network and re-reads the weights from disk on every update.
# Re-running this cell resets the cache (e.g. after re-training the weights).
_generator_cache = {}


def _load_generator():
    """Return the trained Generator, building it and loading its weights on first use only."""
    g = _generator_cache.get('g')
    if g is None:
        g = Generator()
        g.load(os.path.join(path, 'weights', 'gan_g'))
        _generator_cache['g'] = g
    return g


def generate(**kwargs):
    """Render the generator output for the latent vector set by the sliders.

    Receives one keyword argument per slider. The values are taken in sorted
    name order (a0..a9, b0..b9, ... for single-digit suffixes) instead of
    relying on the incidental order of **kwargs. They are then reshaped into a
    (1, 50) batch and plotted via showfig.
    """
    check_noise = qualia2.array([kwargs[name] for name in sorted(kwargs)]).reshape(1, 50)
    showfig(_load_generator(), check_noise)

# Build the name -> slider mapping from the grid itself: row k of `ui` holds the
# sliders named '<letter>0'..'<letter>9', with letters 'a'..'e' for rows 0..4.
# Each slider is then passed to `generate` as a keyword argument of that name.
controls = {
    letter + str(col): slider
    for letter, row in zip('abcde', ui.children)
    for col, slider in enumerate(row.children)
}
out = widgets.interactive_output(generate, controls)

display(ui, out)


VBox(children=(HBox(children=(FloatSlider(value=0.0, max=1.0, min=-1.0, step=0.01), FloatSlider(value=0.0, max…

Output()