In [1]:
import tensorflow as tf
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Conv2DTranspose, Conv2D, BatchNormalization, Reshape, Dropout
from tensorflow.nn import relu, tanh, leaky_relu, sigmoid
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
(x_train, y_train), (x_test, y_test) = load_data()
x_train = x_train / 255
x_test = x_test / 255
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

In [3]:
generator = Sequential()

generator.add(Dense(128*7*7, input_dim=100, activation=leaky_relu))
generator.add(Reshape((7, 7, 128)))
generator.add(BatchNormalization())

generator.add(Conv2DTranspose(128, (3,3), strides=(2,2), padding="same", activation=leaky_relu))
generator.add(Conv2DTranspose(128, (3,3), strides=(2,2), padding="same", activation=leaky_relu))
generator.add(BatchNormalization())
generator.add(Dropout(0.3))

generator.add(Conv2D(1, (3,3), padding="same", activation=tanh))

generator.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 6272)              633472    
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization (BatchNo (None, 7, 7, 128)         512       
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 128)       147584    
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 128)       147584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 28, 28, 128)       512       
_________________________________________________________________
dropout (Dropout)            (None, 28, 28, 128)       0

In [4]:
discriminator = Sequential()

discriminator.add(Conv2D(64, (3,3), padding="same", input_shape=(28,28,1), activation=leaky_relu))
discriminator.add(BatchNormalization())

discriminator.add(Conv2D(128, (3,3), padding="same", activation=leaky_relu))
discriminator.add(BatchNormalization())
discriminator.add(Dropout(0.3))

discriminator.add(Conv2D(64, (3,3), padding="same", activation=leaky_relu))
discriminator.add(BatchNormalization())
discriminator.add(Dropout(0.3))

discriminator.add(Flatten())
discriminator.add(Dense(64, activation=leaky_relu))
discriminator.add(Dense(128, activation=leaky_relu))
discriminator.add(Dense(1, activation=sigmoid))

discriminator.compile(loss ="binary_crossentropy", optimizer ="adam")
discriminator.trainable = False
discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 28, 28, 64)        640       
_________________________________________________________________
batch_normalization_2 (Batch (None, 28, 28, 64)        256       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 128)       73856     
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 128)       512       
_________________________________________________________________
dropout_1 (Dropout)          (None, 28, 28, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        73792     
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 64)       

In [5]:
gan = Sequential([generator, discriminator])

gan.compile(loss ="binary_crossentropy", optimizer ="adam")

In [6]:
gen, dis = gan.layers

In [None]:
for epoch in tqdm(range(10)):
    # real samples for discriminator
    random = np.random.randint(0, 10_000)
    x_real = x_train[0:random]
    np.random.shuffle(x_real)
    y_real = np.ones(random)
    
    # fake samples for discriminator
    x_fake = np.random.rand(random, 28, 28, 1)
    y_fake = np.zeros(random)
    
    # make a huge array for discriminator
    x = np.concatenate((x_real, x_fake))
    y = np.concatenate((y_real, y_fake))
    
    # train the discriminator
    dis.trainable = True
    dis.fit(x, y, epochs=1)
    
    # train the gans model
    noise = np.random.rand(100,100)
    y2 = np.ones(100)
    dis.trainable = False
    gan.fit(noise, y2, epochs=1)
    
    # save the prediction
    pred = gen(noise, training=False)
    pred = np.array(pred)[33].squeeze(axis=-1)
    plt.imsave(f"{epoch}.png", pred)

  0%|          | 0/10 [00:00<?, ?it/s]

 53/291 [====>.........................] - ETA: 1:55 - loss: 0.0144