<a href="https://colab.research.google.com/github/aneeshcheriank/approaching-any-machine-learning-problem/blob/main/First_Gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow.keras as keras

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, LeakyReLU, MaxPooling2D
from tensorflow.keras.layers import Dense, Flatten, Conv2DTranspose, Reshape

from tqdm import tqdm
import matplotlib.pyplot as plt

In [3]:
BUFFER = 1024
BATCH = 256
EPOCHS = 10

In [19]:
# data loader
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = (x_train/255.).reshape((-1, 28, 28, 1))

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\
.shuffle(BUFFER)\
.batch(BATCH)

x, y = iter(train_data).next()
print(x.shape)
print(y.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(256, 28, 28, 1)
(256,)


In [16]:
# models
generator = Sequential([
        Dense(16, kernel_initializer='he_normal', use_bias = False, input_dim=(8)), 
        Reshape((4, 4, 1)),               
        Conv2DTranspose(8, (2, 2), strides=(2, 2), kernel_initializer='he_normal'),  
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        Conv2DTranspose(8, (2, 2), strides=(2, 2), kernel_initializer='he_normal'),  
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        Conv2DTranspose(8, (2, 2), strides=(2, 2), kernel_initializer='he_normal'),  
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        Conv2D(1, (5, 5), activation='tanh')            
])

discriminator = Sequential([
        Conv2D(16, (2, 2), kernel_initializer='he_normal', input_shape=(28, 28, 1)),
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        MaxPooling2D((2, 2)),
        Conv2D(16, (2, 2), kernel_initializer='he_normal'),
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        MaxPooling2D((2, 2)),
        Conv2D(16, (2, 2), kernel_initializer='he_normal'),
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(128, kernel_initializer='he_normal'),
        LeakyReLU(alpha=0.5),
        Dense(1, activation='softmax')
])

gan = Sequential([
       generator, discriminator           
])

In [34]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
loss = tf.keras.losses.BinaryCrossentropy() 

gan_loss = []
dis_loss = []
# training loop
for epoch in range(EPOCHS):
  d_loss = 0
  g_loss = 0
  for real_image, _ in tqdm(train_data):
    batch_size = real_image.shape[0]
    noise = tf.random.normal((batch_size, 8))

    fake_images = tf.cast(generator(noise), dtype=tf.float64)
    train_batch = tf.concat([real_image, fake_images], axis=0)
    labels = tf.concat(
        [[[0.]]*batch_size + [[1.]]*batch_size], axis=0
    )

    # set the discirminator as trainable
    discriminator.trainable = True

    # calculate gradients and update weights
    with tf.GradientTape() as t:
      # predict the labels
      pred = discriminator(train_batch)
      l = loss(labels, pred)

      gradients = t.gradient(l, discriminator.trainable_weights) 
      optimizer.apply_gradients(
          zip(gradients, discriminator.trainable_weights)
      )
    
    # RMS of the loss
    d_loss += tf.math.reduce_mean(l)

    # set the discirminator as not trainable
    discriminator.trainable = False

    noise = tf.random.normal((batch_size*2, 8))
    labels = tf.constant([[0.]]*(2*batch_size), dtype=tf.float32)

    with tf.GradientTape() as t:
      pred = gan(noise)
      l = loss(labels, pred)

      gradients = t.gradient(l, gan.trainable_weights) 
      optimizer.apply_gradients(
          zip(gradients, gan.trainable_weights)
      )

    # RMS of the loss
    g_loss += tf.math.reduce_mean(l)
  
  gan_loss.append(g_loss.numpy())
  dis_loss.append(d_loss.numpy())
  if (epoch+1)%2 == 0:
    print(f'Epoch {epoch+1}/{EPOCHS}, dis loss {d_loss.numpy():.4f}, gan loss = {g_loss.numpy():.4f}')

print('discirminator_loss')
print(dis_loss)
print('gan_loss')
print(gan_loss)

100%|██████████| 235/235 [04:20<00:00,  1.11s/it]
100%|██████████| 235/235 [04:22<00:00,  1.11s/it]


Epoch 2/10, dis loss 2.044924020767212, gan loss = 2126.35546875


100%|██████████| 235/235 [04:19<00:00,  1.11s/it]
100%|██████████| 235/235 [04:19<00:00,  1.11s/it]


Epoch 4/10, dis loss 3.000990867614746, gan loss = 2079.308837890625


100%|██████████| 235/235 [04:18<00:00,  1.10s/it]
100%|██████████| 235/235 [04:22<00:00,  1.11s/it]


Epoch 6/10, dis loss 2.4166922569274902, gan loss = 2097.734619140625


100%|██████████| 235/235 [04:27<00:00,  1.14s/it]
100%|██████████| 235/235 [04:18<00:00,  1.10s/it]


Epoch 8/10, dis loss 2.530951738357544, gan loss = 2220.668212890625


100%|██████████| 235/235 [04:22<00:00,  1.11s/it]
100%|██████████| 235/235 [04:22<00:00,  1.12s/it]

Epoch 10/10, dis loss 1.4921554327011108, gan loss = 2279.336669921875
discirminator_loss
[2.9921033, 2.044924, 2.877721, 3.0009909, 2.2605298, 2.4166923, 2.5833619, 2.5309517, 1.6626145, 1.4921554]
gan_loss
[2151.0876, 2126.3555, 2110.291, 2079.3088, 2127.4287, 2097.7346, 2134.46, 2220.6682, 2237.9507, 2279.3367]





In [25]:
noise = tf.random.normal((1, 8), dtype=tf.float32)
image = generator(noise)

plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.show()

tf.float32