In [96]:
import os

import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Dropout, Input, Dense, Flatten, Reshape, BatchNormalization
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras
import tensorflow as tf
from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras.models import load_model

In [93]:
class GAN:
    """Fully-connected GAN for 28x28x1 images, trained on MNIST.

    The instance owns three Keras models:

    * ``discriminator`` -- image -> single sigmoid score, compiled on its own
      with binary cross-entropy and an accuracy metric.
    * ``generator`` -- ``lat_dim``-dimensional noise vector -> image with a
      ``tanh`` output layer.
    * ``combined`` -- the generator followed by the discriminator. The
      discriminator's ``trainable`` flag is set to False before this model is
      compiled, so that fitting ``combined`` is meant to update the generator
      only.

    Parameters
    ----------
    sample_path_pattern : str, optional
        printf-style pattern containing exactly one ``%d`` placeholder, which
        is filled with the iteration number when a sample grid is saved (for
        example ``"samples/%d.png"``). When omitted,
        ``DEFAULT_SAMPLE_PATTERN`` is used.
    """

    # Output pattern used when the caller does not supply one. This is the
    # location the notebook originally wrote to; pass ``sample_path_pattern``
    # to the constructor to write somewhere portable instead.
    DEFAULT_SAMPLE_PATTERN = r"C:\Users\achint\Downloads\GANs\%d.png"

    def __init__(self, sample_path_pattern=None):
        # Image and latent-space dimensions.
        self.rows = 28
        self.cols = 28
        self.channels = 1
        self.img_shape = (self.rows, self.cols, self.channels)
        self.lat_dim = 100

        if sample_path_pattern is None:
            sample_path_pattern = self.DEFAULT_SAMPLE_PATTERN
        self.sample_path_pattern = sample_path_pattern

        # Each compiled model gets its own Adam instance so that optimizer
        # state is not shared between the discriminator and the combined model.
        self.discriminator = self.disc()
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=Adam(0.0002, 0.5),
            metrics=['accuracy'],
        )

        self.generator = self.gen()
        z = Input(shape=(self.lat_dim,))
        img = self.generator(z)

        # Set after the discriminator has been compiled and before the
        # combined model is compiled; the order of these statements matters.
        self.discriminator.trainable = False
        validity = self.discriminator(img)

        self.combined = Model(z, validity)
        self.combined.compile(
            loss='binary_crossentropy',
            optimizer=Adam(0.0002, 0.5),
        )

    def disc(self):
        """Build the (uncompiled) discriminator and print its summary.

        Returns
        -------
        Model
            Maps an ``img_shape`` input to one sigmoid output.
        """
        model = Sequential()
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512, activation='relu'))
        model.add(Dense(256, activation='relu'))
        model.add(Dense(1, activation='sigmoid'))
        print("Discriminator :")
        # summary() prints by itself and returns None, so it is not wrapped in
        # print() (that would emit a stray "None" line).
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def gen(self):
        """Build the (uncompiled) generator and print its summary.

        Returns
        -------
        Model
            Maps a ``(lat_dim,)`` input to an ``img_shape`` output produced by
            a ``tanh`` layer followed by a reshape.
        """
        model = Sequential()
        model.add(Dense(256, input_dim=self.lat_dim, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        print("Generator:")
        model.summary()

        z = Input(shape=(self.lat_dim,))
        img = model(z)

        return Model(z, img)

    def _latent_batch(self, n):
        """Draw ``n`` standard-normal latent vectors, shape ``(n, lat_dim)``."""
        return np.random.normal(0, 1, (n, self.lat_dim))

    def _sample_path(self, epoch):
        """Return the file path for the sample grid of iteration ``epoch``."""
        return self.sample_path_pattern % epoch

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates on MNIST.

        Parameters
        ----------
        epochs : int
            Number of loop iterations. Each iteration draws ONE random batch,
            so this counts training steps rather than passes over the dataset.
        batch_size : int
            Number of real images (and of generated images) per step.
        sample_interval : int
            A sample grid is saved whenever the iteration index is a multiple
            of this value (including iteration 0).
        """
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale pixel values with x / 127.5 - 1 and append a channel axis.
        X_train = X_train / 127.5 - 1
        X_train = np.expand_dims(X_train, axis=3)

        # Target labels: 1 for "real", 0 for "generated".
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for it in range(epochs):

            # ---- discriminator step ----
            # Indices are sampled with replacement from the training set.
            ids = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[ids]

            gen_imgs = self.generator.predict(self._latent_batch(batch_size))

            D_loss_real = self.discriminator.train_on_batch(imgs, valid)
            D_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Element-wise mean of the two results; index 0 is logged as the
            # loss and index 1 as the accuracy below.
            D_loss = 0.5 * np.add(D_loss_fake, D_loss_real)

            # ---- generator step ----
            # Fresh noise, labelled "valid": the combined model is fitted so
            # that generated images are scored as real.
            z = self._latent_batch(batch_size)
            G_loss = self.combined.train_on_batch(z, valid)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (it, D_loss[0], 100*D_loss[1], G_loss))

            if it % sample_interval == 0:
                self.sample_images(it)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated images for iteration ``epoch``.

        The file name comes from ``sample_path_pattern``; the target directory
        is created when it does not exist yet.
        """
        r = 5
        c = 5
        gen_imgs = self.generator.predict(self._latent_batch(r * c))
        # Map values with 0.5 * x + 0.5 before plotting.
        gen_imgs = (0.5 * gen_imgs) + 0.5

        fig, axs = plt.subplots(r, c)

        count = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                count += 1

        out_path = self._sample_path(epoch)
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path)
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

In [105]:
# Driver cell: build the GAN and train it inside an explicit TF1-compat
# session / graph pair. A custom config could be passed to Session() here.
sess = tf.compat.v1.Session()
graph = tf.compat.v1.get_default_graph()

with graph.as_default():
    # Switch the output_all_intermediates flag on before any model is built.
    tf.compat.v1.experimental.output_all_intermediates(True)
    # Hand `sess` to the Keras backend before the models are created.
    set_session(sess)

    gan = GAN()
    # 3000 single-batch steps of 32 images; a sample grid every 200 steps.
    gan.train(epochs=3000, batch_size=32, sample_interval=200)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Discriminator :
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
None
Generator:
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape    

  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.453111, acc.: 79.69%] [G loss: 0.587574]
2 [D loss: 0.386737, acc.: 76.56%] [G loss: 0.628791]
3 [D loss: 0.352603, acc.: 79.69%] [G loss: 0.693770]
4 [D loss: 0.335748, acc.: 84.38%] [G loss: 0.851937]
5 [D loss: 0.267965, acc.: 100.00%] [G loss: 0.957399]
6 [D loss: 0.234546, acc.: 100.00%] [G loss: 1.160879]
7 [D loss: 0.188888, acc.: 100.00%] [G loss: 1.389143]
8 [D loss: 0.167025, acc.: 100.00%] [G loss: 1.526902]
9 [D loss: 0.145547, acc.: 100.00%] [G loss: 1.689894]
10 [D loss: 0.107856, acc.: 100.00%] [G loss: 1.878368]
11 [D loss: 0.087529, acc.: 100.00%] [G loss: 2.080362]
12 [D loss: 0.073568, acc.: 100.00%] [G loss: 2.178494]
13 [D loss: 0.077455, acc.: 100.00%] [G loss: 2.364131]
14 [D loss: 0.065638, acc.: 100.00%] [G loss: 2.559041]
15 [D loss: 0.052068, acc.: 100.00%] [G loss: 2.661669]
16 [D loss: 0.047411, acc.: 100.00%] [G loss: 2.746705]
17 [D loss: 0.041469, acc.: 100.00%] [G loss: 2.992876]
18 [D loss: 0.042474, acc.: 100.00%] [G loss: 2.901683]
19 [D