In [4]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:96% !important; }</style>"))

In [13]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Flatten, Dense, Conv2D, ReLU, LeakyReLU, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import Sequential

In [39]:
from keras.datasets import mnist

In [11]:
row = 32
col = 32
planes = 1

img_shape = (row, col , planes)
input_dims = 100

## Generator Network

In [23]:
def build_the_generator(input_dims=input_dims, img_shape=img_shape):
    generator = Sequential([
                          Dense(512, input_dim=input_dims), # Input Layer
                          LeakyReLU(alpha=0.01), # f(x) = a
                          Dense(32 * 32 * 1, activation='tanh'), # Output Layer
                          Reshape(img_shape)
                         ])
    return generator
    
    

In [24]:
generator = build_the_generator()

In [25]:
for i, layers_name in enumerate(generator.layers):
    print(i, layers_name)

0 <keras.layers.core.dense.Dense object at 0x000002DF8B9DA350>
1 <keras.layers.activation.leaky_relu.LeakyReLU object at 0x000002DF8C6BAC90>
2 <keras.layers.core.dense.Dense object at 0x000002DF89383D50>
3 <keras.layers.reshaping.reshape.Reshape object at 0x000002DF89393E90>


## Discriminator

In [26]:
def build_the_discriminator(img_shape=img_shape):
    discriminator = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.01),
        Dense(1, activation='sigmoid') # Sigmoid, Since, the discriminator is doing binary classification
    ])
    return discriminator

In [27]:
discriminator=build_the_discriminator()

In [28]:
for i, layers_name in enumerate(discriminator.layers):
    print(i, layers_name)

0 <keras.layers.reshaping.flatten.Flatten object at 0x000002DF8C9144D0>
1 <keras.layers.core.dense.Dense object at 0x000002DF8CC27D90>
2 <keras.layers.activation.leaky_relu.LeakyReLU object at 0x000002DF8C673F50>
3 <keras.layers.core.dense.Dense object at 0x000002DF8CC36C10>


In [29]:
def gan_model(generator=generator, discriminator=discriminator):
    model_gan = Sequential([
        generator,
        discriminator
    ])
    return model_gan

In [30]:
gan_model = gan_model()

#### Model Compilation

In [37]:
discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
                     loss=tf.keras.losses.binary_crossentropy,
                     metrics=['accuracy'])

# Freezing the Model
discriminator.trainable = False

# GAN Compilation
gan_model.compile(loss=tf.keras.losses.binary_crossentropy,
                 optimizer=Adam())

In [38]:
for i , layers in enumerate(gan_model.layers):
    print(f"Layer : {i} - {layers}")
    for j, layers_name in enumerate(layers.layers):
        print(f"\tInner Layer : {j} - {layers_name}")

Layer : 0 - <keras.engine.sequential.Sequential object at 0x000002DF8C542050>
	Inner Layer : 0 - <keras.layers.core.dense.Dense object at 0x000002DF8B9DA350>
	Inner Layer : 1 - <keras.layers.activation.leaky_relu.LeakyReLU object at 0x000002DF8C6BAC90>
	Inner Layer : 2 - <keras.layers.core.dense.Dense object at 0x000002DF89383D50>
	Inner Layer : 3 - <keras.layers.reshaping.reshape.Reshape object at 0x000002DF89393E90>
Layer : 1 - <keras.engine.sequential.Sequential object at 0x000002DF8CC34510>
	Inner Layer : 0 - <keras.layers.reshaping.flatten.Flatten object at 0x000002DF8C9144D0>
	Inner Layer : 1 - <keras.layers.core.dense.Dense object at 0x000002DF8CC27D90>
	Inner Layer : 2 - <keras.layers.activation.leaky_relu.LeakyReLU object at 0x000002DF8C673F50>
	Inner Layer : 3 - <keras.layers.core.dense.Dense object at 0x000002DF8CC36C10>


### Model Training

In [51]:
losses = list()
accuracies = list()
chkps = list()

def train_the_model(epochs , batch_size, checkpoint):
    # Loading the dataset
    (X_train, y_train), (X_val, y_val) = mnist.load_data()
    
    # Rescaling the dataset
    X_train = X_train / X_train.max()
    X_train = tf.expand_dims(X_train, axis=3) # Flatten the dataset 
    
    # Real image labels
    real_labels = np.ones((batch_size, 1))
    # Fake image labels
    fake_labels = np.zeros((batch_size, 1))
    
    for iteration in range(1,epochs + 1):
        # Batch of real images
        idx = np.random.randint(0, X_train.shape[0] , batch_size)
        imgs = X_train[idx]
        
        # Batch of fake images 
        z = np.random.normal(0, 1, (batch_size))
        z = tf.expand_dims(z, axis=1) # Flatten Image Matrix
        gen_imgs = generator.predict(z)
        
        # Train Discriminator
        d_loss_real = discriminator.train_on_batch(imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
        d_loss, accuracy = 0.5 * np.add(d_loss_fake, d_loss_real)
        
        # Batch of fake Images for GAN
        z = np.random.normal(0, 1, (batch_size, 1)) # Flatten Image 
        gen_imgs = generator.predict(z)
        
        # GAN loss
        g_loss = gan_model.train_on_batch(z, real_labels)
        
        if (iteration % checkpoint) == 0:
            losses.append(d_loss, g_loss)
            accuracies.append(100 * accuracy)
            chkps.append(iteration)
            
            print(f"[ Iteration : {iteration}, Loss : {d_loss}, Accuracy : {accuracy * 100}, GanLoss : {g_loss} ]")
            show_sample_images(generator)
        

In [41]:
def show_sample_images(generator,rows=4, cols=4):
    
    z = np.random.normal(0, 1, (rows * cols, input_dims))
    gen_imgs = generator.predict(z)
    
    # Rescaling the image 
    #gen_imgs = 0.5 * gen_imgs + 0.5 
    cnt = 0
    plt.figure(figsize=(16,16))
    for row in range(1, rows + 1):
        for col in range(1, cols + 1):
            plt.subplot(row, col , cnt + 1)
            plt.imshow(gen_imgs[cnt, : , : , 0], cmap='gray')
            plt.axis(False)
            cnt += 1
        
    

In [50]:
iterations = 20000
batch_size = 512
sample_interval = 1000

train_the_model(iterations, batch_size, sample_interval)

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([37104, 21402, 49647,  2308,  4888,  7605, 56521, 50931, 12039,
       51679, 52852, 52947, 13242, 59265, 37394, 20444, 29963, 43773,
       53203, 32249, 55178, 20857, 47618, 38034,  3612,  1599, 52535,
       11418, 20046, 26416, 38725,  5723, 29328, 30268, 40452, 44107,
       11972,  9157, 32770, 54650, 30934, 15991, 35429, 53844, 58928,
       26727, 37649, 35869, 51812, 46786, 10323, 42968, 58411, 51459,
       50407, 25724, 42662, 50737, 28103, 52879, 22241, 10453, 51701,
       31756, 51815, 24664,  7825, 13991, 47759, 10382, 16148, 38152,
       17079, 19974, 22742, 55730, 57004, 20155, 50717, 45748,  9361,
       40415, 23702, 36451,  2844, 33579, 46646, 45245, 53963, 48034,
        8494, 51934, 39233, 33712, 31045, 22799, 32480, 48996, 45612,
        8088, 17385, 18977, 13308, 59215, 38002, 44439, 54580,  6233,
        5044, 25752, 11050,  1213, 16049, 13203, 30204,  4570, 34477,
       47597, 38948, 21864, 19542,  8893, 39462,  6588, 13845, 33218,
       49309, 56737, 14601, 12338, 40278, 15681, 14499,  4082, 28700,
       56985, 33862, 35817, 54303, 50850, 30384,  5357, 15455, 18281,
       49674, 26334, 14525, 12034, 11797, 41051,  4072, 26874, 19476,
       29043, 42508, 20716, 58882, 30288, 49325, 53601, 46067, 14498,
        9254, 57390,  1929, 15716, 26128,  3957, 49082, 44611, 14145,
       10154, 43274, 58040, 58018, 31531, 31500,   646, 42100, 44566,
        2206, 57714, 41991, 13628, 27104, 19495, 23992, 21706, 30698,
        4831, 14359, 53822, 35602, 55664,  5049, 58250, 40952, 39336,
       37535, 11597,   573, 15537, 49569, 31941, 42656,  5617, 41101,
       35333, 55145, 28073, 30919, 48606,  8349, 59543, 22275,   676,
       47093, 18029, 24085,  2089, 14273, 50288, 46814, 58839,  5439,
       16699, 16029, 35067,  7978, 50323, 33482, 25756, 57195, 14659,
       26381, 42524, 58648, 46000, 42641, 39760, 50085, 36181, 52229,
       38707, 22975, 10824, 13690, 42781, 26435, 49917, 11487, 51065,
       43291, 53730, 25419, 32438, 20428, 19607,  6711, 41154,  2537,
       45543,  3171, 19374, 32294, 44092, 11143, 25439, 53653, 42064,
       53677,  3086, 40802, 42898, 35776,  1997, 29234,  3152, 18977,
       34404,  1128, 17519, 21804, 23908, 48858, 34573, 51800, 55817,
        8595, 17801,  5928, 31065, 50265, 12224, 39956, 44792, 33679,
       33143, 10295,  2656,     2, 39278, 54740, 22531, 33907, 33705,
        3459, 30194, 33910, 47205, 23384,  2300, 51444, 14705, 20442,
       29965, 26237, 39830,  8899, 11213, 39638, 44746,  3613,  5668,
       44983, 21475, 19746, 12924, 49590, 50942, 32055, 26166, 29239,
       38543, 16195, 10753, 30141, 54183,  7369, 35948, 44068,  6039,
        4041, 22286,  7973, 40694, 53653, 31250, 10517, 23608, 20962,
       28879, 26772, 40948, 25659, 41614, 19200,  1688, 24878, 17465,
       51290, 25138, 23158, 26156, 45080, 42151, 29110,   886, 34288,
       31330, 19504, 51151, 37775, 37947,  4930, 16126, 49754, 26917,
       13942, 55081, 45923, 57417, 40405, 15483,  5915, 54425, 52249,
       56408,  8056,  2666, 43202, 56139, 50253,  5495, 54453, 15076,
        8834, 57400, 32334, 32450, 45684, 50127,  7208, 37947,  8330,
        9455, 45509, 50494, 14257, 37553,  2347, 38895, 35053,  1254,
       29915, 54139, 14338, 20078, 58143, 23525, 25665, 10418, 57565,
        4063, 24421, 39343, 23341, 30786, 12127, 43510,  5849,  9482,
       37177, 18091, 48664, 49368, 23784, 22667, 49586, 41401,  2565,
       57319, 19620, 51232, 42668, 20389, 40330, 21553, 42561, 33804,
       46984, 34513, 43883, 10797, 12108, 22307, 30853, 19412, 50357,
       16702, 18738, 52939, 26450,  3765, 18180, 26286, 15554,  4858,
       22914,  6525, 46119,  3179, 32569, 10195,  5480, 38253, 23687,
       33850, 49515, 10439, 42376, 21040, 59366, 19622, 13666,  6411,
       24984, 52360,  2524, 59806, 50052, 37777, 58774, 15135,  7840,
       55568, 53596, 45505, 59871, 15904, 11669, 41305, 41795, 25472,
       45609, 26324, 11724, 10851, 43462, 31014, 50063, 46594])