# Wasserstein GAN (W-GAN)

The Wasserstein GAN was originally proposed by [Arjovsky et al.](https://arxiv.org/pdf/1701.07875.pdf) in their work titled *Wasserstein GAN*. This notebook uses a basic implementation where the generator and discriminator (critic) models use convolutional layers, batch normalization and upsampling.
The notebook trains both networks with the RMSprop optimizer to play the minimax game, and showcases the approach on MNIST digit generation.
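The compile cells in this notebook use RMSprop with a learning rate of 0.00005. The paper reports that critic training sometimes became unstable with a momentum-based optimizer such as Adam, and it uses RMSProp instead. For reference, below is a minimal NumPy sketch of a textbook RMSprop step. `rho=0.9` and `eps=1e-7` mirror the Keras defaults, but Keras's internal formula may place epsilon slightly differently, so treat this as an illustration rather than the library's code.

```python
import numpy as np

def rmsprop_step(w, grad, v, lr=0.00005, rho=0.9, eps=1e-7):
    """Textbook RMSprop update without momentum (illustrative, not Keras internals)."""
    # running average of squared gradients, kept per parameter
    v = rho * v + (1.0 - rho) * grad ** 2
    # scale each parameter's step by the root of its running average
    w = w - lr * grad / (np.sqrt(v) + eps)
    return w, v

w = np.zeros(2)
v = np.zeros_like(w)
grad = np.array([1.0, -4.0])
w, v = rmsprop_step(w, grad, v)
print(v)  # running averages: 0.1 and 1.6
print(w)  # both steps have magnitude lr / sqrt(0.1)
```

Starting from `v = 0`, the first step has magnitude `lr / sqrt(1 - rho)` for every parameter, whatever the gradient's scale. This shows how RMSprop normalizes step sizes without accumulating momentum.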

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PacktPublishing/Hands-On-Generative-AI-with-Python-and-TensorFlow-2/blob/master/Chapter_6/wasserstein_gan.ipynb)

## Load Libraries

```python
from tensorflow.keras import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras import datasets
import numpy as np
```

## Load Utility Functions

```python
from gan_utils import build_critic
from gan_utils import build_dc_generator
from gan_utils import sample_images
from gan_utils import wasserstein_loss
```
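`gan_utils.wasserstein_loss` is imported rather than shown. A common Keras definition is the batch mean of `y_true * y_pred`. The NumPy sketch below assumes that form (an assumption, not the verified contents of `gan_utils`) and uses made-up critic scores. It shows why the training loop labels real images -1 and fake images +1.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    # ASSUMED form of gan_utils.wasserstein_loss: mean of label * critic score
    return np.mean(y_true * y_pred)

# hypothetical critic scores for one batch of real and one batch of fake images
real_scores = np.array([[2.0], [1.0], [3.0]])
fake_scores = np.array([[-1.0], [0.0], [1.0]])
real_y = -np.ones_like(real_scores)   # label -1 for real
fake_y = np.ones_like(fake_scores)    # label +1 for fake

loss_real = wasserstein_loss(real_y, real_scores)  # -mean(D(real)) = -2.0
loss_fake = wasserstein_loss(fake_y, fake_scores)  # +mean(D(fake)) = 0.0

# minimizing loss_real + loss_fake maximizes mean(D(real)) - mean(D(fake))
w_estimate = -(loss_real + loss_fake)
print(w_estimate)  # 2.0
```

`train()` averages the two terms with `0.5 * np.add(...)`, so its critic loss is half of `loss_real + loss_fake`.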

## W-GAN Training Loop
- Follows the procedure proposed in the original paper
- Train the critic on a batch of real samples and a batch of fake samples
- Calculate the critic (discriminator) loss
- Clip the critic weights to `[-clip_value, clip_value]` after each critic update
- Train the critic 5 times per training cycle of the generator
- Use the Wasserstein loss for both the generator and the critic
- Freeze the critic and train the generator
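In the paper's formulation, the critic $f_w$ is restricted to a set of weights $\mathcal{W}$. Here that set is the box $[-c, c]^d$ enforced by clipping, with $c$ = `clip_value` and $d$ the number of critic weights. The two networks play the following game:

$$\min_{G}\ \max_{w \in \mathcal{W}}\ \mathbb{E}_{x \sim \mathbb{P}_r}\big[f_w(x)\big] - \mathbb{E}_{z \sim p(z)}\big[f_w(G(z))\big], \qquad \mathcal{W} = [-c, c]^{d}$$

This assumes `wasserstein_loss` is the batch mean of label × critic score; the `gan_utils` source is not shown. Under that assumption, the labels used in `train()` ($-1$ for real, $+1$ for fake) give the losses below for a batch of size $m$. $L_{\text{critic}}$ is summed over the real and fake `train_on_batch` calls, and $L_G$ comes from training the generator with the real label:

$$L_{\text{critic}} = \frac{1}{m}\sum_{i=1}^{m} f_w\big(G(z^{(i)})\big) - \frac{1}{m}\sum_{i=1}^{m} f_w\big(x^{(i)}\big), \qquad L_{G} = -\frac{1}{m}\sum_{i=1}^{m} f_w\big(G(z^{(i)})\big)$$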

```python
def train(generator=None, discriminator=None, gan_model=None,
          epochs=1000, discriminator_cycles=5, batch_size=128, sample_interval=50,
          z_dim=100, clip_value=0.01):
    # Load MNIST train samples
    (X_train, _), (_, _) = datasets.mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1]
    X_train = X_train.astype(np.float32) / 127.5 - 1
    X_train = np.expand_dims(X_train, axis=3)

    # Prepare GAN output labels: -1 for real images, +1 for fake images
    real_y = -np.ones((batch_size, 1))
    fake_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # train the critic (discriminator) several times per generator update
        for _ in range(discriminator_cycles):
            # pick random real samples from X_train
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_imgs = X_train[idx]

            # pick random noise samples (z) from a normal distribution
            noise = np.random.normal(0, 1, (batch_size, z_dim))
            # use the generator model to generate fake samples
            fake_imgs = generator.predict(noise, verbose=0)

            # calculate critic loss on real samples
            disc_loss_real = discriminator.train_on_batch(real_imgs, real_y)

            # calculate critic loss on fake samples
            disc_loss_fake = discriminator.train_on_batch(fake_imgs, fake_y)

            # overall critic loss ([loss, accuracy] averaged over both batches)
            discriminator_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)

            # clip critic weights to [-clip_value, clip_value] so the critic stays
            # (approximately) Lipschitz, as the Wasserstein (EM) distance estimate requires
            for l in discriminator.layers:
                weights = l.get_weights()
                weights = [np.clip(w, -clip_value, clip_value) for w in weights]
                l.set_weights(weights)

        # train the generator
        # pick random noise samples (z) from a normal distribution
        noise = np.random.normal(0, 1, (batch_size, z_dim))

        # update the generator through the frozen critic, labelling fakes as real (-1)
        gen_loss = gan_model.train_on_batch(noise, real_y)

        # training updates (raw Wasserstein loss values)
        print("%d [Discriminator loss: %f] [Generator loss: %f]" %
              (epoch, discriminator_loss[0], gen_loss[0]))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            sample_images(epoch, generator)
```
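Weight clipping keeps every critic weight inside $[-c, c]$. This bounds the critic's Lipschitz constant by some $K$ that depends on $c$ and on the architecture. The NumPy sketch below checks the idea on a hypothetical linear critic $f(x) = w \cdot x$ over flattened 28×28 inputs, where the bound is easy to state: $|f(x) - f(y)| \le c\,\lVert x - y \rVert_1$. It is an illustration only and does not compute the constant for the convolutional critic used here.

```python
import numpy as np

rng = np.random.default_rng(0)
clip_value = 0.01

# hypothetical linear critic f(x) = w . x on flattened 28x28 images
w = rng.normal(0.0, 1.0, size=784)
w = np.clip(w, -clip_value, clip_value)  # same element-wise clip as in train()

def critic(x):
    return x @ w

# |w . (x - y)| <= max|w_i| * ||x - y||_1 <= clip_value * ||x - y||_1
x = rng.uniform(-1.0, 1.0, size=784)
y = rng.uniform(-1.0, 1.0, size=784)
gap = abs(critic(x) - critic(y))
bound = clip_value * np.sum(np.abs(x - y))
print(gap <= bound)  # True
```

The paper itself calls clipping a crude way to enforce the constraint. A small `clip_value` limits the critic's capacity, while a large one can make the critic slow to train.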

## Prepare Discriminator (Critic) Model

```python
discriminator = build_critic()
discriminator.compile(loss=wasserstein_loss,
                      optimizer=RMSprop(learning_rate=0.00005),
                      metrics=['accuracy'])
```

```
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 14, 14, 16)        160       
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 14, 14, 16)        0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 14, 14, 16)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 7, 7, 32)          4640      
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 8, 8, 32)          0         
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 8, 8, 32)          0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 8, 8, 32)         
```
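The first rows of the critic summary can be reproduced by hand. The sketch below assumes stride-2 convolutions with `'same'` padding and 3×3 kernels. These settings are inferred because they reproduce the listed shapes and counts; the `build_critic` source is not shown.

```python
import math

def same_conv_output(size, stride):
    # 'same' padding with stride s: output size = ceil(input size / s)
    return math.ceil(size / stride)

def conv2d_params(kernel, c_in, c_out):
    # kernel x kernel weights per (input, output) channel pair + one bias per filter
    return kernel * kernel * c_in * c_out + c_out

h1 = same_conv_output(28, 2)      # conv2d_4 output height/width
h2 = same_conv_output(h1, 2)      # conv2d_5 output height/width
p1 = conv2d_params(3, 1, 16)      # conv2d_4 parameters
p2 = conv2d_params(3, 16, 32)     # conv2d_5 parameters

print(h1, p1)  # 14 160
print(h2, p2)  # 7 4640
```

The following `zero_padding2d_1` layer then grows the 7×7 map to 8×8, as listed in the summary.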

## Prepare Generator Model

```python
generator = build_dc_generator()
```

```
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_2 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 128)       147584    
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 128)       512       
_________________________________________________________________
activation (Activation)      (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 28, 28, 128)      
```
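The generator summary follows from a few formulas. The Dense layer maps the 100-dimensional noise to 7 × 7 × 128 = 6272 units, which are reshaped into a 7×7×128 feature map. Each `UpSampling2D` then doubles the height and width. The 3×3 kernel below is an inference that reproduces the listed count, not something the notebook states.

```python
def dense_params(n_in, n_out):
    # weight matrix + bias vector
    return n_in * n_out + n_out

def conv2d_params(kernel, c_in, c_out):
    # kernel x kernel weights per (input, output) channel pair + one bias per filter
    return kernel * kernel * c_in * c_out + c_out

def batchnorm_params(channels):
    # gamma and beta (trainable) + moving mean and moving variance (non-trainable)
    return 4 * channels

z_dim = 100
dense_units = 7 * 7 * 128                # reshaped to (7, 7, 128)
dense_p = dense_params(z_dim, dense_units)
conv_p = conv2d_params(3, 128, 128)      # 3x3 kernel is inferred, not confirmed
bn_p = batchnorm_params(128)
sizes = [7 * 2 ** k for k in range(3)]   # each UpSampling2D doubles H and W

print(dense_units, dense_p, conv_p, bn_p, sizes)
# 6272 633472 147584 512 [7, 14, 28]
```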

## Prepare GAN Model

```python
# Noise for generator
z_dim = 100
z = Input(shape=(z_dim,))
img = generator(z)

# Freeze the critic inside the combined model
discriminator.trainable = False

# Get the critic's score for the generated images
valid = discriminator(img)

# Stack the critic on top of the generator
gan_model = Model(z, valid)
gan_model.compile(loss=wasserstein_loss,
                  optimizer=RMSprop(learning_rate=0.00005),
                  metrics=['accuracy'])
gan_model.summary()
```

```
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential_2 (Sequential)    (None, 28, 28, 1)         856193    
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 99201     
Total params: 955,394
Trainable params: 855,809
Non-trainable params: 99,585
_________________________________________________________________
```
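The combined model's counts show the effect of freezing the critic. All 99,201 critic weights become non-trainable, and the remaining 384 non-trainable weights sit inside the generator. Those 384 are consistent with BatchNorm moving statistics. The visible `batch_normalization` layer accounts for 256 of them (moving mean and variance over 128 channels). The other 128 presumably come from layers in the truncated part of the generator summary, which is an inference rather than something shown. The arithmetic, using only numbers copied from the summaries:

```python
generator_params = 856_193        # sequential_2 row
critic_params = 99_201            # sequential_1 row (frozen)
reported_trainable = 855_809
reported_non_trainable = 99_585

total = generator_params + critic_params
extra_frozen = reported_non_trainable - critic_params  # non-trainable weights inside the generator
first_bn_moving_stats = 2 * 128   # moving mean + moving variance of batch_normalization

print(total)                                                  # 955394
print(generator_params - extra_frozen == reported_trainable)  # True
print(extra_frozen, first_bn_moving_stats)                    # 384 256
```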


## Train W-GAN

```python
train(generator, discriminator, gan_model, epochs=4000, batch_size=64, sample_interval=100)
```
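In `train()`, one "epoch" is a single generator update preceded by `discriminator_cycles` critic cycles, not a full pass over MNIST. The arithmetic below shows what this call amounts to, using the 60,000-image MNIST training split and the default `discriminator_cycles=5`. Real batches are drawn with replacement via `np.random.randint`, so "passes" here means images drawn relative to the dataset size, not guaranteed full coverage.

```python
iterations = 4000            # the `epochs` argument: one generator update each
critic_cycles = 5            # default discriminator_cycles
batch_size = 64
sample_interval = 100
mnist_train_size = 60_000

critic_batches = iterations * critic_cycles * 2             # one real + one fake batch per cycle
real_images_drawn = iterations * critic_cycles * batch_size
passes_over_mnist = real_images_drawn / mnist_train_size
sample_calls = len(range(0, iterations, sample_interval))   # epochs where epoch % 100 == 0

print(critic_batches, real_images_drawn, round(passes_over_mnist, 2), sample_calls)
# 40000 1280000 21.33 40
```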

## Output
Samples generated after 4000 training iterations (the `epochs` argument of `train`):
<img src="outputs/w_gan_output.png">