In [6]:
from sklearn import datasets
import math
import matplotlib.pyplot as plt
import numpy as np
import progressbar

from sklearn.datasets import fetch_openml

from scratchkit.dl.optim import Adam
from scratchkit.dl.losses import CrossEntropy, SquareLoss
from scratchkit.dl.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
from scratchkit.dl import NeuralNetwork

In [9]:
class AutoEncoder():
    """
    An AutoEncoder with deep-connected neural nets
    
    Training data: MNIST
    """
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 128 # The dimension of the data embedding

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        loss_function = SquareLoss

        self.encoder = self.build_encoder(optimizer, loss_function)
        self.decoder = self.build_decoder(optimizer, loss_function)

        self.autoencoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.autoencoder.layers.extend(self.encoder.layers)
        self.autoencoder.layers.extend(self.decoder.layers)

        print ()
        self.autoencoder.summary(name="Variational Autoencoder")

    def build_encoder(self, optimizer, loss_function):

        encoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        encoder.add(Dense(512, input_shape=(self.img_dim,)))
        encoder.add(Activation('leaky_relu'))
        encoder.add(BatchNormalization(momentum=0.8))
        encoder.add(Dense(256))
        encoder.add(Activation('leaky_relu'))
        encoder.add(BatchNormalization(momentum=0.8))
        encoder.add(Dense(self.latent_dim))

        return encoder

    def build_decoder(self, optimizer, loss_function):

        decoder = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        decoder.add(Dense(256, input_shape=(self.latent_dim,)))
        decoder.add(Activation('leaky_relu'))
        decoder.add(BatchNormalization(momentum=0.8))
        decoder.add(Dense(512))
        decoder.add(Activation('leaky_relu'))
        decoder.add(BatchNormalization(momentum=0.8))
        decoder.add(Dense(self.img_dim))
        decoder.add(Activation('tanh'))

        return decoder

    def train(self, n_epochs, batch_size=128, save_interval=50):

        mnist = fetch_openml('mnist_784')

        X = mnist.data
        y = mnist.target

        # Rescale [-1, 1]
        X = (X.astype(np.float32) - 127.5) / 127.5

        for epoch in range(n_epochs):

            # Select a random half batch of images
            idx = np.random.randint(0, X.shape[0], batch_size)
            imgs = X[idx]

            # Train the Autoencoder
            loss, _ = self.autoencoder.train_on_batch(imgs, imgs)

            # Display the progress
            print ("%d [D loss: %f]" % (epoch, loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch, X)

    def save_imgs(self, epoch, X):
        r, c = 5, 5 # Grid size
        # Select a random half batch of images
        idx = np.random.randint(0, X.shape[0], r*c)
        imgs = X[idx]
        # Generate images and reshape to image shape
        gen_imgs = self.autoencoder.predict(imgs).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Autoencoder")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("ae_%d.png" % epoch)
        plt.close()

In [11]:
# Train it
ae = AutoEncoder()
ae.train(n_epochs=2000, batch_size=64, save_interval=400)


+-------------------------+
| Variational Autoencoder |
+-------------------------+
Input Shape: (784,)
+------------------------+------------+--------------+
| Layer Type             | Parameters | Output Shape |
+------------------------+------------+--------------+
| Dense                  | 401920     | (512,)       |
| Activation (LeakyReLU) | 0          | (512,)       |
| BatchNormalization     | 1024       | (512,)       |
| Dense                  | 131328     | (256,)       |
| Activation (LeakyReLU) | 0          | (256,)       |
| BatchNormalization     | 512        | (256,)       |
| Dense                  | 32896      | (128,)       |
| Dense                  | 33024      | (256,)       |
| Activation (LeakyReLU) | 0          | (256,)       |
| BatchNormalization     | 512        | (256,)       |
| Dense                  | 131584     | (512,)       |
| Activation (LeakyReLU) | 0          | (512,)       |
| BatchNormalization     | 1024       | (512,)       |
| Dense        

Training: N/A% [-                                              ] ETA:  --:--:--
Training: N/A% [-                                              ] ETA:  --:--:--
Training: N/A% [-                                              ] ETA:  --:--:--


0 [D loss: 0.560906]
1 [D loss: 0.551484]
2 [D loss: 0.545883]
3 [D loss: 0.535760]
4 [D loss: 0.527998]
5 [D loss: 0.520596]
6 [D loss: 0.516944]
7 [D loss: 0.509262]
8 [D loss: 0.500547]
9 [D loss: 0.494367]
10 [D loss: 0.493100]
11 [D loss: 0.491192]
12 [D loss: 0.481135]
13 [D loss: 0.477259]
14 [D loss: 0.481425]
15 [D loss: 0.476635]
16 [D loss: 0.470372]
17 [D loss: 0.470374]
18 [D loss: 0.465281]
19 [D loss: 0.467958]
20 [D loss: 0.459608]
21 [D loss: 0.459647]
22 [D loss: 0.453013]
23 [D loss: 0.449364]
24 [D loss: 0.446528]
25 [D loss: 0.452391]
26 [D loss: 0.450766]
27 [D loss: 0.445012]
28 [D loss: 0.445116]
29 [D loss: 0.439396]
30 [D loss: 0.440218]
31 [D loss: 0.440462]
32 [D loss: 0.439656]
33 [D loss: 0.437230]
34 [D loss: 0.440809]
35 [D loss: 0.438820]
36 [D loss: 0.434369]
37 [D loss: 0.430145]
38 [D loss: 0.440041]
39 [D loss: 0.437350]
40 [D loss: 0.426140]
41 [D loss: 0.425560]
42 [D loss: 0.427967]
43 [D loss: 0.430840]
44 [D loss: 0.422797]
45 [D loss: 0.427004

361 [D loss: 0.336948]
362 [D loss: 0.325840]
363 [D loss: 0.344711]
364 [D loss: 0.337073]
365 [D loss: 0.317330]
366 [D loss: 0.331502]
367 [D loss: 0.335071]
368 [D loss: 0.329961]
369 [D loss: 0.329408]
370 [D loss: 0.345480]
371 [D loss: 0.327159]
372 [D loss: 0.325845]
373 [D loss: 0.340448]
374 [D loss: 0.330513]
375 [D loss: 0.325974]
376 [D loss: 0.331535]
377 [D loss: 0.326588]
378 [D loss: 0.337907]
379 [D loss: 0.327959]
380 [D loss: 0.326471]
381 [D loss: 0.325710]
382 [D loss: 0.322307]
383 [D loss: 0.332611]
384 [D loss: 0.317306]
385 [D loss: 0.323204]
386 [D loss: 0.329101]
387 [D loss: 0.319455]
388 [D loss: 0.326013]
389 [D loss: 0.321019]
390 [D loss: 0.322160]
391 [D loss: 0.327345]
392 [D loss: 0.325745]
393 [D loss: 0.334736]
394 [D loss: 0.323122]
395 [D loss: 0.322082]
396 [D loss: 0.317390]
397 [D loss: 0.326995]
398 [D loss: 0.325722]
399 [D loss: 0.325633]
400 [D loss: 0.329627]
401 [D loss: 0.318654]
402 [D loss: 0.323271]
403 [D loss: 0.320809]
404 [D loss

719 [D loss: 0.275136]
720 [D loss: 0.263352]
721 [D loss: 0.267721]
722 [D loss: 0.278567]
723 [D loss: 0.271579]
724 [D loss: 0.272953]
725 [D loss: 0.285441]
726 [D loss: 0.271581]
727 [D loss: 0.273598]
728 [D loss: 0.276427]
729 [D loss: 0.270776]
730 [D loss: 0.271626]
731 [D loss: 0.268862]
732 [D loss: 0.273497]
733 [D loss: 0.262384]
734 [D loss: 0.265115]
735 [D loss: 0.268041]
736 [D loss: 0.268923]
737 [D loss: 0.263871]
738 [D loss: 0.275401]
739 [D loss: 0.266247]
740 [D loss: 0.262472]
741 [D loss: 0.264428]
742 [D loss: 0.271376]
743 [D loss: 0.261737]
744 [D loss: 0.269662]
745 [D loss: 0.261394]
746 [D loss: 0.269666]
747 [D loss: 0.261857]
748 [D loss: 0.270687]
749 [D loss: 0.257760]
750 [D loss: 0.273440]
751 [D loss: 0.263558]
752 [D loss: 0.269584]
753 [D loss: 0.259218]
754 [D loss: 0.328958]
755 [D loss: 0.266209]
756 [D loss: 0.267614]
757 [D loss: 0.268450]
758 [D loss: 0.265275]
759 [D loss: 0.265057]
760 [D loss: 0.272539]
761 [D loss: 0.270330]
762 [D loss

1073 [D loss: 0.223276]
1074 [D loss: 0.215080]
1075 [D loss: 0.216989]
1076 [D loss: 0.216015]
1077 [D loss: 0.216647]
1078 [D loss: 0.247748]
1079 [D loss: 0.220541]
1080 [D loss: 0.214518]
1081 [D loss: 0.215504]
1082 [D loss: 0.220075]
1083 [D loss: 0.209746]
1084 [D loss: 0.218395]
1085 [D loss: 0.209641]
1086 [D loss: 0.211919]
1087 [D loss: 0.211294]
1088 [D loss: 0.207582]
1089 [D loss: 0.214971]
1090 [D loss: 0.208086]
1091 [D loss: 0.222644]
1092 [D loss: 0.215646]
1093 [D loss: 0.206896]
1094 [D loss: 0.211418]
1095 [D loss: 0.209477]
1096 [D loss: 0.209207]
1097 [D loss: 0.212992]
1098 [D loss: 0.211073]
1099 [D loss: 0.213281]
1100 [D loss: 0.215260]
1101 [D loss: 0.209953]
1102 [D loss: 0.205591]
1103 [D loss: 0.215391]
1104 [D loss: 0.205431]
1105 [D loss: 0.213849]
1106 [D loss: 0.202185]
1107 [D loss: 0.205390]
1108 [D loss: 0.213414]
1109 [D loss: 0.215216]
1110 [D loss: 0.208652]
1111 [D loss: 0.211170]
1112 [D loss: 0.219396]
1113 [D loss: 0.203436]
1114 [D loss: 0.

1415 [D loss: 0.171902]
1416 [D loss: 0.182527]
1417 [D loss: 0.223097]
1418 [D loss: 0.166039]
1419 [D loss: 0.164080]
1420 [D loss: 0.182927]
1421 [D loss: 0.178711]
1422 [D loss: 0.172365]
1423 [D loss: 0.167417]
1424 [D loss: 0.170864]
1425 [D loss: 0.164128]
1426 [D loss: 0.177542]
1427 [D loss: 0.174566]
1428 [D loss: 0.171975]
1429 [D loss: 0.178469]
1430 [D loss: 0.170564]
1431 [D loss: 0.175061]
1432 [D loss: 0.173016]
1433 [D loss: 0.165225]
1434 [D loss: 0.172160]
1435 [D loss: 0.179535]
1436 [D loss: 0.163909]
1437 [D loss: 0.165031]
1438 [D loss: 0.173212]
1439 [D loss: 0.164629]
1440 [D loss: 0.168311]
1441 [D loss: 0.174264]
1442 [D loss: 0.166996]
1443 [D loss: 0.171023]
1444 [D loss: 0.169427]
1445 [D loss: 0.164540]
1446 [D loss: 0.177205]
1447 [D loss: 0.160845]
1448 [D loss: 0.168069]
1449 [D loss: 0.165348]
1450 [D loss: 0.172072]
1451 [D loss: 0.171646]
1452 [D loss: 0.176803]
1453 [D loss: 0.162941]
1454 [D loss: 0.171473]
1455 [D loss: 0.159878]
1456 [D loss: 0.

1757 [D loss: 0.131459]
1758 [D loss: 0.129239]
1759 [D loss: 0.140539]
1760 [D loss: 0.134377]
1761 [D loss: 0.140450]
1762 [D loss: 0.147052]
1763 [D loss: 0.146238]
1764 [D loss: 0.135368]
1765 [D loss: 0.139108]
1766 [D loss: 0.134915]
1767 [D loss: 0.137447]
1768 [D loss: 0.136226]
1769 [D loss: 0.143368]
1770 [D loss: 0.149929]
1771 [D loss: 0.132010]
1772 [D loss: 0.135586]
1773 [D loss: 0.132458]
1774 [D loss: 0.170353]
1775 [D loss: 0.137787]
1776 [D loss: 0.134929]
1777 [D loss: 0.132476]
1778 [D loss: 0.140131]
1779 [D loss: 0.137273]
1780 [D loss: 0.137344]
1781 [D loss: 0.132264]
1782 [D loss: 0.130964]
1783 [D loss: 0.128787]
1784 [D loss: 0.129708]
1785 [D loss: 0.139605]
1786 [D loss: 0.132390]
1787 [D loss: 0.133864]
1788 [D loss: 0.131308]
1789 [D loss: 0.130535]
1790 [D loss: 0.130730]
1791 [D loss: 0.141824]
1792 [D loss: 0.219446]
1793 [D loss: 0.141172]
1794 [D loss: 0.138501]
1795 [D loss: 0.134218]
1796 [D loss: 0.136000]
1797 [D loss: 0.130675]
1798 [D loss: 0.

In [12]:
ae

<__main__.AutoEncoder at 0x269e8389940>

In [13]:
ae.img_dim

784