<a href="https://colab.research.google.com/github/aneeshcheriank/approaching-any-machine-learning-problem/blob/main/First_Gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import tensorflow.keras as keras

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, LeakyReLU, MaxPooling2D
from tensorflow.keras.layers import Dense, Flatten, Conv2DTranspose, Reshape

from tqdm import tqdm

In [3]:
BUFFER = 1024
BATCH = 256
EPOCHS = 10

In [19]:
# data loader
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = (x_train/255.).reshape((-1, 28, 28, 1))

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\
.shuffle(BUFFER)\
.batch(BATCH)

x, y = iter(train_data).next()
print(x.shape)
print(y.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(256, 28, 28, 1)
(256,)


In [16]:
# models
generator = Sequential([
        Dense(16, kernel_initializer='he_normal', use_bias = False, input_dim=(8)), 
        Reshape((4, 4, 1)),               
        Conv2DTranspose(8, (2, 2), strides=(2, 2), kernel_initializer='he_normal'),  
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        Conv2DTranspose(8, (2, 2), strides=(2, 2), kernel_initializer='he_normal'),  
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        Conv2DTranspose(8, (2, 2), strides=(2, 2), kernel_initializer='he_normal'),  
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        Conv2D(1, (5, 5), activation='tanh')            
])

discriminator = Sequential([
        Conv2D(16, (2, 2), kernel_initializer='he_normal', input_shape=(28, 28, 1)),
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        MaxPooling2D((2, 2)),
        Conv2D(16, (2, 2), kernel_initializer='he_normal'),
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        MaxPooling2D((2, 2)),
        Conv2D(16, (2, 2), kernel_initializer='he_normal'),
        BatchNormalization(),
        LeakyReLU(alpha=0.5),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(128, kernel_initializer='he_normal'),
        LeakyReLU(alpha=0.5),
        Dense(1, activation='softmax')
])

gan = Sequential([
       generator, discriminator           
])

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
loss = tf.keras.losses.BinaryCrossentropy() 

gan_loss = []
dis_loss = []
# training loop
for epoch in range(EPOCHS):
  d_loss = 0
  g_loss = 0
  for real_image, _ in tqdm(train_data):
    batch_size = real_image.shape[0]
    noise = tf.random.normal((batch_size, 8))

    fake_images = tf.cast(generator(noise), dtype=tf.float64)
    train_batch = tf.concat([real_image, fake_images], axis=0)
    labels = tf.concat(
        [[[0.]]*batch_size + [[1.]]*batch_size], axis=0
    )

    # set the discirminator as trainable
    discriminator.trainable = True

    # calculate gradients and update weights
    with tf.GradientTape() as t:
      # predict the labels
      pred = discriminator(train_batch)
      l = loss(labels, pred)

      gradients = t.gradient(l, discriminator.trainable_weights) 
      optimizer.apply_gradients(
          zip(gradients, discriminator.trainable_weights)
      )
    
    # RMS of the loss
    d_loss += tf.sqrt(tf.math.reduce_mean(tf.math.square(l)))

    # set the discirminator as not trainable
    discriminator.trainable = False

    noise = tf.random.normal((batch_size*2, 8))
    labels = tf.constant([[0.]]*(2*batch_size), dtype=tf.float32)

    with tf.GradientTape() as t:
      pred = gan(noise)
      l = loss(labels, pred)

      gradients = t.gradient(l, gan.trainable_weights) 
      optimizer.apply_gradients(
          zip(gradients, gan.trainable_weights)
      )

    # RMS of the loss
    g_loss += tf.sqrt(tf.math.reduce_mean(tf.math.square(l)))
  
  gan_loss.append(g_loss)
  dis_loss.append(d_loss)
  if (epoch+1)%2 == 0:
    print(f'Epoch {epoch+1}/{EPOCHS}, dis loss {dis_loss}, gan loss = {g_loss}')

print('discirminator_loss')
print(dis_loss)
print('gan_loss')
print(gan_loss)

  3%|▎         | 6/235 [00:10<06:44,  1.77s/it]

In [25]:
fake_images.dtype

tf.float32