## GAN in Tensorflow

Reference : https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch?hl=ko

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from tqdm import tqdm

In [2]:
# Root directory for dataset
dataroot = "your path!"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 224

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
latent_dim = 1024

# Number of training epochs
num_epochs = 10


In [3]:
## Load CelebA
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  dataroot,
  seed=123,
  image_size=(image_size, image_size),
  batch_size=batch_size)

Found 202599 files belonging to 1 classes.


In [4]:
## Preprocessing
augmentation_layer = keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255),
    layers.experimental.preprocessing.CenterCrop(image_size,image_size),
    layers.experimental.preprocessing.Normalization()    
])
image_ds = train_ds.map(lambda x,y: (augmentation_layer(x),y))

In [3]:
discriminator = keras.Sequential(
    [
        # Input: (h,w,c) = (224,224,3)
        keras.Input(shape=(224, 224, 3)),
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        # Input : (112, 112, 64)     
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        # Input : (56, 56, 128)     
        layers.Conv2D(256, (5, 5), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),     
        # Input : (28, 28, 256)     
        layers.Conv2D(256, (5, 5), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),     
        # Input : (14, 14, 256)             
        layers.Conv2D(512, (5, 5), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        # Input : (7, 7, 512)
        layers.Flatten(),
        # Input : (7*7*512)
        layers.Softmax()

    ],
    name="discriminator",
)
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_5 (Conv2D)           (None, 112, 112, 64)      4864      
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 112, 112, 64)      0         
                                                                 
 conv2d_6 (Conv2D)           (None, 56, 56, 128)       204928    
                                                                 
 batch_normalization_4 (Batc  (None, 56, 56, 128)      512       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 56, 56, 128)       0         
                                                                 
 conv2d_7 (Conv2D)           (None, 28, 28, 256)       819456    
                                                     

In [4]:
generator = keras.Sequential(
    [     
        # Input: (1024)
        layers.Dense(7*7*256, use_bias=False, input_shape=(1024,)),
        layers.BatchNormalization(),
        layers.ReLU(),     
        layers.Reshape((7,7,256)),
        # Input: (7, 7, 256)
        layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),     
        layers.ReLU(),
        # Input: (14, 14, 256)
        layers.Conv2DTranspose(256, (3, 3), strides=(1, 1), padding="same"),
        layers.BatchNormalization(),     
        layers.ReLU(),
        # Input: (14, 14, 256)
        layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),     
        layers.ReLU(),
        # Input: (28, 28, 256)
        layers.Conv2DTranspose(256, (3, 3), strides=(1, 1), padding="same"),
        layers.BatchNormalization(),     
        layers.ReLU(),
        # Input: (28, 28, 256)
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),     
        layers.ReLU(),
        # Forth layer, Input: (56, 56, 128)
        layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same"),
        layers.BatchNormalization(),     
        layers.ReLU(),
     
        # Final layer, Input: (112, 112, 64)
        layers.Conv2DTranspose(3, (4, 4), strides=(2, 2), padding="same", activation = "tanh"),
    ],
    name="generator",
)
generator.summary()

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 12544)             12845056  
                                                                 
 batch_normalization_8 (Batc  (None, 12544)            50176     
 hNormalization)                                                 
                                                                 
 re_lu (ReLU)                (None, 12544)             0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 256)      1048832   
 nspose)                                                         
                                                                 
 batch_normalization_9 (Batc  (None, 14, 14, 256)      10

In [7]:
# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0002)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0002)

# Instantiate a loss function.
loss_fn = keras.losses.MeanSquaredError()

In [8]:
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))

In [9]:
@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss, generated_images

In [10]:
img_list = []
G_losses = []
D_losses = []
tf.keras.backend.clear_session()
for epoch in range(num_epochs):
    print("\nStart epoch", epoch + 1)

    for _, (real_images,_) in enumerate(tqdm(image_ds)):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

    # Logging.
    # Print metrics
    print("discriminator loss at epochs %d: %.2f" % (epoch + 1, d_loss))
    print("adversarial loss at epochs %d: %.2f" % (epoch + 1, g_loss))


Start epoch 0


100%|██████████| 1583/1583 [01:50<00:00, 14.36it/s]


discriminator loss at epochs 0: 0.73
adversarial loss at epochs 0: 1.31

Start epoch 1


100%|██████████| 1583/1583 [01:46<00:00, 14.80it/s]


discriminator loss at epochs 1: 0.73
adversarial loss at epochs 1: 1.31

Start epoch 2


100%|██████████| 1583/1583 [01:46<00:00, 14.80it/s]


discriminator loss at epochs 2: 0.73
adversarial loss at epochs 2: 1.31

Start epoch 3


100%|██████████| 1583/1583 [01:46<00:00, 14.80it/s]


discriminator loss at epochs 3: 0.74
adversarial loss at epochs 3: 1.31

Start epoch 4


100%|██████████| 1583/1583 [01:46<00:00, 14.81it/s]


discriminator loss at epochs 4: 0.73
adversarial loss at epochs 4: 1.31

Start epoch 5


100%|██████████| 1583/1583 [01:46<00:00, 14.81it/s]


discriminator loss at epochs 5: 0.73
adversarial loss at epochs 5: 1.31

Start epoch 6


100%|██████████| 1583/1583 [01:46<00:00, 14.81it/s]


discriminator loss at epochs 6: 0.73
adversarial loss at epochs 6: 1.31

Start epoch 7


100%|██████████| 1583/1583 [01:46<00:00, 14.81it/s]


discriminator loss at epochs 7: 0.73
adversarial loss at epochs 7: 1.31

Start epoch 8


100%|██████████| 1583/1583 [01:46<00:00, 14.81it/s]


discriminator loss at epochs 8: 0.73
adversarial loss at epochs 8: 1.31

Start epoch 9


100%|██████████| 1583/1583 [01:47<00:00, 14.79it/s]

discriminator loss at epochs 9: 0.73
adversarial loss at epochs 9: 1.31



