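# LSGAN on MNIST

Trains a least-squares GAN (a discriminator with a linear output, trained with MSE loss) to generate MNIST digits.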

In [1]:
%reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y


In [2]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import Input
from tensorflow.keras.optimizers import RMSprop
from keras.models import Model
from keras.datasets import mnist
from keras import backend as K
from keras.models import load_model

import numpy as np
import argparse
import matplotlib.pyplot as plt

from lib import gan

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [3]:
def _rmsprop_with_inverse_time_decay(learning_rate, decay):
    """Build an RMSprop optimizer whose learning rate is ``learning_rate / (1 + decay * step)``.

    This reproduces the legacy Keras optimizer ``decay=`` argument, which the
    Keras optimizers introduced in TF 2.11 no longer accept. The new
    ``weight_decay`` argument is NOT a replacement: it shrinks the weights,
    not the learning rate.

    Args:
        learning_rate: Initial learning rate.
        decay: Inverse-time decay rate applied per optimizer step.

    Returns:
        An ``RMSprop`` optimizer driven by an ``InverseTimeDecay`` schedule.
    """
    # decay_steps=1 with staircase=False (the default) gives exactly
    # lr / (1 + decay * iterations), i.e. the legacy `decay` behaviour.
    schedule = keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=learning_rate,
        decay_steps=1,
        decay_rate=decay,
    )
    return RMSprop(learning_rate=schedule)


def build_and_train_models(latent_size=100,
                           batch_size=64,
                           train_steps=40000,
                           lr=2e-4,
                           decay=6e-8,
                           model_name='lsgan_mnist'):
    """Build the LSGAN generator, discriminator and adversarial models for MNIST, then train them.

    The discriminator has a linear output (``activation=None``). Both the
    discriminator and the adversarial model are compiled with MSE loss, which
    is what makes this a least-squares GAN. Training is delegated to
    ``gan.train``.

    Args:
        latent_size: Dimension of the generator's input z-vector.
        batch_size: Batch size forwarded to ``gan.train``.
        train_steps: Number of training steps forwarded to ``gan.train``.
            Lower it for a quicker run.
        lr: RMSprop learning rate for the discriminator; the adversarial
            model uses half of it.
        decay: Inverse-time learning-rate decay for the discriminator; the
            adversarial model uses half of it.
        model_name: Name given to the adversarial model and forwarded to
            ``gan.train`` (presumably used for output file names; TODO confirm).

    The defaults reproduce the original hard-coded settings.
    """
    # Only the training images are needed: add a channel axis and scale the
    # pixel values by 1/255.
    (x_train, _), (_, _) = mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    # Discriminator: linear (activation=None) real/fake score for the MSE loss.
    image_inputs = keras.Input(shape=(image_size, image_size, 1),
                               name='discriminator_input')
    discriminator = gan.discriminator(image_inputs, activation=None)
    discriminator.compile(loss='mse',
                          optimizer=_rmsprop_with_inverse_time_decay(lr, decay),
                          metrics=['accuracy'])
    discriminator.summary()

    # Generator: takes a latent z-vector of size `latent_size`.
    z_inputs = keras.Input(shape=(latent_size,), name='z_input')
    generator = gan.generator(z_inputs, image_size)
    generator.summary()

    # Adversarial model = generator followed by the frozen discriminator.
    # Keras applies `trainable` when a model is compiled, so freezing here
    # (after the discriminator's own compile above) keeps the discriminator
    # trainable on its own. The adversarial model updates only the generator.
    discriminator.trainable = False
    adversarial = Model(z_inputs, discriminator(generator(z_inputs)), name=model_name)
    adversarial.compile(loss='mse',
                        optimizer=_rmsprop_with_inverse_time_decay(lr * 0.5, decay * 0.5),
                        metrics=['accuracy'])
    adversarial.summary()

    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    gan.train(models, x_train, params)

In [5]:
if __name__ == '__main__':
    # A notebook kernel runs cells in the __main__ module, so this always
    # trains here; the guard only matters if the code is exported to a .py
    # file and imported elsewhere.
    build_and_train_models()

Metal device set to: Apple M2
Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 discriminator_input (InputL  [(None, 28, 28, 1)]      0         
 ayer)                                                           
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 32)        832       
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 14, 14, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 64)          51264     
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 7, 7, 64)          0         
                       

2023-02-26 14:28:25.360762: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-02-26 14:28:25.361046: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


 ormalization)                                                   
                                                                 
 activation (Activation)     (None, 7, 7, 128)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 128)      409728    
 nspose)                                                         
                                                                 
 batch_normalization_1 (Batc  (None, 14, 14, 128)      512       
 hNormalization)                                                 
                                                                 
 activation_1 (Activation)   (None, 14, 14, 128)       0         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 64)       204864    
 ranspose)                                                       
                                                                 
 batch_nor

2023-02-26 14:28:25.660499: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-02-26 14:28:25.728879: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-02-26 14:28:26.007700: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-02-26 14:28:26.704596: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


0: [discriminator loss: 0.500465, acc: 0.500000] [adversarial loss: 0.524313, acc: 0.000000]
1: [discriminator loss: 0.284929, acc: 0.500000] [adversarial loss: 0.210235, acc: 0.984375]
2: [discriminator loss: 0.122025, acc: 0.914062] [adversarial loss: 0.189520, acc: 1.000000]
3: [discriminator loss: 0.029710, acc: 0.992188] [adversarial loss: 0.316113, acc: 1.000000]
4: [discriminator loss: 0.188594, acc: 1.000000] [adversarial loss: 1.304102, acc: 0.000000]
5: [discriminator loss: 0.192147, acc: 0.609375] [adversarial loss: 0.697409, acc: 0.000000]
6: [discriminator loss: 0.067073, acc: 0.953125] [adversarial loss: 0.314149, acc: 0.000000]
7: [discriminator loss: 0.024091, acc: 0.992188] [adversarial loss: 0.136305, acc: 1.000000]
8: [discriminator loss: 0.025811, acc: 1.000000] [adversarial loss: 0.145105, acc: 1.000000]
9: [discriminator loss: 0.016205, acc: 1.000000] [adversarial loss: 0.054616, acc: 1.000000]
10: [discriminator loss: 0.011949, acc: 1.000000] [adversarial loss: 0

88: [discriminator loss: 0.015425, acc: 1.000000] [adversarial loss: 0.013546, acc: 1.000000]
89: [discriminator loss: 0.010681, acc: 1.000000] [adversarial loss: 0.076251, acc: 1.000000]
90: [discriminator loss: 0.015047, acc: 1.000000] [adversarial loss: 0.018334, acc: 1.000000]
91: [discriminator loss: 0.009276, acc: 1.000000] [adversarial loss: 0.068541, acc: 1.000000]
92: [discriminator loss: 0.014587, acc: 1.000000] [adversarial loss: 0.019203, acc: 1.000000]
93: [discriminator loss: 0.010191, acc: 1.000000] [adversarial loss: 0.059593, acc: 1.000000]
94: [discriminator loss: 0.013435, acc: 1.000000] [adversarial loss: 0.015261, acc: 1.000000]
95: [discriminator loss: 0.008441, acc: 1.000000] [adversarial loss: 0.061436, acc: 1.000000]
96: [discriminator loss: 0.012152, acc: 1.000000] [adversarial loss: 0.011828, acc: 1.000000]
97: [discriminator loss: 0.006314, acc: 1.000000] [adversarial loss: 0.038646, acc: 1.000000]
98: [discriminator loss: 0.006936, acc: 1.000000] [adversari

KeyboardInterrupt: 