In [3]:
from keras.layers import Input, Dense, Reshape, BatchNormalization
from keras.models import Model
from tensorflow.keras.optimizers import Adam
from keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
 
class GAN:
    """Semi-supervised GAN on MNIST.

    The discriminator trunk feeds two heads:
      * ``discriminator_model`` -- sigmoid real/fake output, trained on the
        full (unlabelled) training images and on generator samples;
      * ``classification_model`` -- 10-way softmax digit classifier, trained
        on a small class-balanced labelled subset (see ``sample_1000``).
    ``combined_model`` stacks generator -> discriminator to train the
    generator.
    """

    def __init__(self):
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        self.batch_size = 100
        # Real and fake halves of a discriminator step together make one batch.
        self.half_batch_size = self.batch_size // 2
        self.latent_dim = 100
        self.iterations = 10000
        # Kept for backward compatibility. Every model is compiled with its own
        # fresh optimizer from _make_optimizer() (see there for why).
        self.optimizer = self._make_optimizer()
        self.generator_model = self.generator()
        self.discriminator_model, self.classification_model = self.discriminator()
        self.combined_model = self.combined()

    def _make_optimizer(self):
        """Return a fresh Adam optimizer (learning rate 2e-4, beta_1 0.5).

        Each compiled model gets its own instance. The models update different
        variable sets, and one shared Adam would mix their moment estimates and
        step counter. Newer tf.keras optimizers (TF >= 2.11) also refuse to
        update variables they were not first built with.
        """
        return Adam(learning_rate=0.0002, beta_1=0.5)

    @staticmethod
    def _normalize(images):
        """Scale pixels from [0, 255] to float32 in [-1, 1], matching the tanh generator."""
        return (np.asarray(images).astype(np.float32) - 127.5) / 127.5

    def generator(self):
        """Build the generator: latent vector -> (28, 28, 1) image via a final tanh."""
        input_gen = Input(shape=(self.latent_dim,))
        hidden = input_gen
        for units in (256, 512, 1024):
            hidden = BatchNormalization(momentum=0.8)(Dense(units, activation='relu')(hidden))
        output = Dense(784, activation='tanh')(hidden)
        reshaped_output = Reshape((28, 28, 1))(output)
        gen_model = Model(input_gen, reshaped_output)
        gen_model.compile(loss='binary_crossentropy', optimizer=self._make_optimizer())
        # summary() prints by itself and returns None; print(...) would add "None".
        gen_model.summary()
        return gen_model

    def discriminator(self):
        """Build the shared-trunk discriminator on flattened 784-pixel inputs.

        Returns:
            ``(disc_model, class_model)``: the sigmoid real/fake model and the
            10-way softmax classifier. Both share the three hidden Dense layers.
        """
        input_disc = Input(shape=(784,))
        hidden1 = Dense(512, activation='relu')(input_disc)
        hidden2 = Dense(256, activation='relu')(hidden1)
        hidden3 = Dense(128, activation='relu')(hidden2)
        output = Dense(1, activation='sigmoid')(hidden3)
        output2 = Dense(10, activation='softmax', name='classification_layer')(hidden3)
        disc_model = Model(input_disc, output)
        disc_model_2 = Model(input_disc, output2)
        disc_model.compile(loss=['binary_crossentropy'], optimizer=self._make_optimizer(),
                           metrics=['accuracy'])
        disc_model_2.compile(loss=['categorical_crossentropy'], optimizer=self._make_optimizer(),
                             metrics=['accuracy'])
        disc_model.summary()
        disc_model_2.summary()
        return disc_model, disc_model_2

    def combined(self):
        """Build generator -> discriminator for the generator's training step."""
        inputs = Input(shape=(self.latent_dim,))
        gen_img = self.generator_model(inputs)
        gen_img = Reshape((784,))(gen_img)
        # Freeze the discriminator inside this model only. discriminator_model and
        # classification_model were compiled earlier while trainable. This relies
        # on Keras capturing `trainable` at compile time.
        # NOTE(review): confirm this still holds under Keras 3.
        self.discriminator_model.trainable = False
        outs = self.discriminator_model(gen_img)
        comb_model = Model(inputs, outs)
        comb_model.compile(loss='binary_crossentropy', optimizer=self._make_optimizer(),
                           metrics=['accuracy'])
        comb_model.summary()
        return comb_model

    def sample_1000(self, x, y, per_class=100, num_classes=10):
        """Draw a class-balanced labelled subset of ``per_class`` examples per class.

        Args:
            x: examples indexed along axis 0.
            y: integer labels aligned with ``x``.
            per_class: examples drawn from each class (default 100 -> 1000 total).
            num_classes: labels ``0 .. num_classes - 1`` are sampled.

        Returns:
            ``(x_subset, y_subset)`` as two lists, grouped by label in ascending order.

        Raises:
            ValueError: if some label in ``range(num_classes)`` has no examples.
        """
        x_subset = []
        y_subset = []
        for label in range(num_classes):
            x_label = x[y == label]
            if len(x_label) == 0:
                raise ValueError(f"no examples with label {label}")
            # Sample without replacement so the labelled set has no duplicates.
            # Fall back to replacement only when a class is too small.
            idx = np.random.choice(len(x_label), per_class,
                                   replace=len(x_label) < per_class)
            x_subset.extend(x_label[idx])
            y_subset.extend([label] * per_class)
        return x_subset, y_subset

    def train(self):
        """Run ``self.iterations`` semi-supervised GAN training steps.

        Each step does four updates, in order:
          1. a classifier update on a labelled half-batch;
          2. a real/fake update on a real (unlabelled) half-batch;
          3. a real/fake update on a generated half-batch;
          4. a generator update through the combined model on a full batch.
        Test-set classification accuracy (percent) is printed every 1000 steps.
        """
        train_data, train_data_y = self.sample_1000(self.x_train, self.y_train)
        train_data = self._normalize(train_data).reshape((-1, 784))
        train_data_y = to_categorical(train_data_y)

        all_train_data = self._normalize(self.x_train).reshape((-1, 784))
        # The test set never changes, so normalize it once, not at every evaluation.
        test_data = self._normalize(self.x_test).reshape((-1, 784))

        real_labels = np.ones((self.half_batch_size, 1))
        fake_labels = np.zeros((self.half_batch_size, 1))
        generator_targets = np.ones((self.batch_size, 1))

        for j in range(self.iterations):
            batch_indx = np.random.randint(0, train_data.shape[0], size=self.half_batch_size)
            batch_x = train_data[batch_indx]
            batch_y = train_data_y[batch_indx]

            batch_indx_total = np.random.randint(0, all_train_data.shape[0],
                                                 size=self.half_batch_size)
            batch_x_total = all_train_data[batch_indx_total]

            input_noise = np.random.normal(0, 1, size=(self.half_batch_size, self.latent_dim))
            # verbose=0: TF2's predict() otherwise prints a progress bar every step.
            gen_outs = self.generator_model.predict(input_noise, verbose=0).reshape((-1, 784))

            self.classification_model.train_on_batch(batch_x, batch_y)
            self.discriminator_model.train_on_batch(batch_x_total, real_labels)
            self.discriminator_model.train_on_batch(gen_outs, fake_labels)

            full_batch_input_noise = np.random.normal(0, 1, size=(self.batch_size, self.latent_dim))
            self.combined_model.train_on_batch(full_batch_input_noise, generator_targets)

            if j % 1000 == 0:
                test_results = self.classification_model.predict(test_data, verbose=0)
                predictions = np.argmax(test_results, axis=1)
                accuracy = float(np.mean(predictions == self.y_test)) * 100
                print("Accuracy After", j, "iterations: ", accuracy)
            
# Build and train only when run as a script or notebook cell. In Jupyter,
# __name__ is "__main__", so the cell still trains. Importing this module no
# longer downloads MNIST and runs 10k training steps as a side effect.
if __name__ == "__main__":
    gan = GAN()
    gan.train()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense (Dense)                (None, 256)               25856     
_________________________________________________________________
batch_normalization (BatchNo (None, 256)               1024      
_________________________________________________________________
dense_1 (Dense)              (None, 512)               131584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              525312    
___________________________________