# WGAN - Wasserstein GAN

## Synthetic tabular data leveraging WGAN

GANs are a particular kind of network that tries to learn the probability distribution of a given dataset, and for this task probability densities immediately come to mind. The classical approach is to fit a parametric density by maximum likelihood: if the real data distribution Pr admits a density and Pθ is the distribution of the parametrized density, then, asymptotically, this amounts to minimizing the Kullback-Leibler divergence KL(Pr || Pθ).

![image.png](attachment:image.png)
Fig.1 - Architecture of a generative adversarial network. (Image source: www.kdnuggets.com/2017/01/generative-…-learning.html)

For this to make sense, the model density Pθ needs to exist. This is not always the case, and in some situations the KL divergence is undefined or simply infinite.


In a nutshell, the Wasserstein GAN is an extension of the so-called Vanilla GAN. So what does it bring that is new?



## The benefits

Wasserstein GAN was introduced by Martin Arjovsky and co-authors, who proposed using the Earth-Mover distance as the loss function. I'll leave the article here:

- [WGAN article](https://arxiv.org/pdf/1701.07875.pdf)

WGAN introduces the concept of a critic instead of a discriminator: this network scores how real or fake a given event looks rather than classifying it. This change was motivated by the argument that the generator should seek to minimize the distance between the distribution of the generated data and that of the observed data. This new concept brings, of course, a few benefits:

- Training a WGAN is more stable than training, for example, a VanillaGAN
- Less sensitive to the choice of model architecture (Generator and Critic)
- Less sensitive to the choice of hyperparameters, although this is still very important to achieve good results
- Most importantly, we can finally correlate the loss of the critic with the overall quality of the generated events

### Why is the Wasserstein distance better than the KL or JS divergence?

I have partly covered it already: even when the two distributions lie on lower-dimensional manifolds that do not overlap, the Wasserstein distance still provides a meaningful and smooth measure of the distance between them, whereas the KL divergence becomes infinite and the JS divergence saturates at a constant. That is what makes this distance such a good bet as a GAN loss function.
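To make the argument concrete, here is a small NumPy sketch. The helper names (`point_mass`, `kl_divergence`, `js_divergence`, `wasserstein_1d`) are illustrative and not part of any library. It compares two point masses on a 1-D grid with unit spacing that never overlap: KL is infinite and JS stays at log 2 however far apart they are, while the Wasserstein distance grows with the shift.

```python
import numpy as np

def point_mass(position, size=10):
    """Discrete distribution on the grid 0..size-1 with all mass at `position`."""
    p = np.zeros(size)
    p[position] = 1.0
    return p

def kl_divergence(p, q):
    """KL(p || q); evaluates to inf when q is zero somewhere p is not."""
    mask = p > 0
    with np.errstate(divide="ignore"):
        return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def js_divergence(p, q):
    """Jensen-Shannon divergence built from two KL terms against the mixture."""
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def wasserstein_1d(p, q):
    """Earth-Mover distance on a unit-spaced grid: area between the two CDFs."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))))

real = point_mass(0)
for shift in (1, 3, 7):
    fake = point_mass(shift)
    print(shift, kl_divergence(real, fake), js_divergence(real, fake), wasserstein_1d(real, fake))
```

Only the last column changes with the shift, which is the kind of signal a generator can follow by gradient descent.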

And without further ado, let's jump into the implementation!

## The implementation

#### The main differences between VanillaGAN and WGAN

These are the changes we will make in order to "transform" a VanillaGAN into a WGAN:

- After every gradient update on the critic, clamp its weights to a small fixed range, [−c, c]. This is how we enforce the Lipschitz constraint.
- Use a new loss function derived from the Wasserstein distance.
- I suggest using the RMSProp optimizer on the critic rather than a momentum-based optimizer, as momentum can cause instability.
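For reference, the objective behind these changes can be written down as follows. This is a sketch in the notation of the WGAN paper, where $f_w$ is the critic with weights $w$, $g_\theta$ is the generator, $p(z)$ is the noise prior and $c$ is the clipping constant. The Wasserstein distance has the dual form

$$
W(P_r, P_\theta) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_\theta}[f(x)]
$$

and the critic approximates the supremum by solving

$$
\max_{w \in [-c,\, c]^d} \; \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{z \sim p(z)}[f_w(g_\theta(z))]
$$

while the generator minimizes $-\mathbb{E}_{z \sim p(z)}[f_w(g_\theta(z))]$. Clipping keeps $f_w$ inside a family of $K$-Lipschitz functions for some $K$ that depends on $c$ and the architecture, so the maximum estimates the Wasserstein distance up to a constant factor.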

In [None]:
import numpy as np

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import Model, constraints

In [None]:
# Let's start by defining the main components of our network - the Generator and the Critic (Discriminator)

# The Generator
# A plain builder class: build_model returns the Keras Model, so there is no need to
# subclass tf.keras.Model (which would require calling super().__init__() first)
class Generator:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        inputs = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim, activation='relu')(inputs)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        # Linear output layer with one unit per data column
        x = Dense(data_dim)(x)
        return Model(inputs=inputs, outputs=x)

In [None]:
generator = Generator(batch_size=128).build_model((80, 1000), 126, 25)

generator.summary()

This generator is pretty simple: a Dense network with 4 layers, using ReLU as the activation function on the hidden layers and a linear output layer.

In [None]:
# The Critic or Discriminator
# Same builder pattern as the Generator: a plain class whose build_model returns the Keras Model
class Critic:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        inputs = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim * 4, activation='relu')(inputs)
        x = Dropout(0.1)(x)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dropout(0.1)(x)
        x = Dense(dim, activation='relu')(x)
        # Linear output: an unbounded "realness" score, not a probability
        x = Dense(1)(x)
        return Model(inputs=inputs, outputs=x)

In [None]:
critic = Critic(batch_size=128).build_model((80, 1000), 126)

critic.summary()

Similarly to the Generator, I've decided to go for a simple network: 4 Dense layers, again with ReLU activations on the hidden layers.

But I want to emphasize the last line of the code. In a Vanilla GAN we would add this as the last layer of the network:

<span style='color:blue'> x = Dense(1, activation='sigmoid')(x)</span>

It uses the sigmoid function in the output layer of the discriminator, which means that it predicts the likelihood of a given event being real.

When it comes to WGAN, the critic model requires a linear activation, in order to predict the score of the "realness" for a given event. 

<span style='color:blue'> x = Dense(1)(x)</span>

or

<span style='color:blue'> x = Dense(1, activation='linear')(x)</span>

As I've mentioned, WGAN does not provide precise class labels to the critic (unless it is conditional).
Instead, it encourages the Critic to output scores that are different for real and fake events.
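The difference between the two output layers can be seen with a few made-up numbers. In this NumPy sketch, `logits` stands in for hypothetical raw outputs of the last Dense layer: the sigmoid squashes them into (0, 1), while the linear critic passes them through as unbounded scores.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical raw outputs of the last Dense layer for five events
logits = np.array([-8.0, -1.0, 0.0, 1.0, 8.0])

probabilities = sigmoid(logits)  # discriminator view: bounded in (0, 1)
scores = logits                  # critic view: linear activation, unbounded

print(probabilities)
print(scores)
```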

The WGAN can be implemented with -1 class labels for the real events and +1 class labels for the generated events. We initialize them as below.

In [None]:
# Adversarial ground truths: -1 for real events, +1 for generated events
batch_size = 128
valid = -np.ones((batch_size, 1))
fake = np.ones((batch_size, 1))

### The Wasserstein loss

As I've mentioned, the main contribution of the WGAN model is the use of a new loss function - the Wasserstein loss. We can implement it as a custom function in Keras that calculates the average score for the real or generated events, signed by their label.

The critic is trained to maximize the score of the real events and minimize the score of the generated ones. Below is the implementation of the Wasserstein loss.

In [None]:
# Standalone Keras loss: takes only (y_true, y_pred), so no `self` argument
def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)
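As a quick sanity check, here is a NumPy version of the same formula. The scores are made-up numbers and `wasserstein_loss_np` is just an illustrative name. With -1 labels on real events and +1 labels on generated ones, the sum of the two batch losses is the mean generated score minus the mean real score, so minimizing it pushes real scores up and generated scores down.

```python
import numpy as np

def wasserstein_loss_np(y_true, y_pred):
    # Same formula as the Keras loss, written with NumPy
    return np.mean(y_true * y_pred)

# Hypothetical critic scores for a batch of 4 real and 4 generated events
real_scores = np.array([[2.0], [1.5], [3.0], [2.5]])
fake_scores = np.array([[-1.0], [0.5], [-0.5], [0.0]])

valid = -np.ones_like(real_scores)  # label -1 for real events
fake = np.ones_like(fake_scores)    # label +1 for generated events

loss_real = wasserstein_loss_np(valid, real_scores)  # equals -mean(real_scores)
loss_fake = wasserstein_loss_np(fake, fake_scores)   # equals +mean(fake_scores)
critic_loss = loss_real + loss_fake

print(loss_real, loss_fake, critic_loss)
```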

### The weight clipping

As mentioned before, although other GAN architectures do not require it, WGAN requires weight clipping for the critic model.

In this case we've decided to extend the Keras constraint class, as shown below:

In [None]:
#https://keras.io/api/layers/constraints/
class ClipConstraint(constraints.Constraint):
    # set clip value when initialized
    def __init__(self, clip_value):
        self.clip_value = clip_value
 
    def __call__(self, weights):
        return K.clip(weights, -self.clip_value, self.clip_value)
 
    # get the config
    def get_config(self):
        return {'clip_value': self.clip_value}
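One way to wire such a constraint into the critic is through the `kernel_constraint` and `bias_constraint` arguments of `Dense`, which Keras applies after each optimizer update. The sketch below is self-contained, so it redefines the constraint using `tf.clip_by_value`. The clip value 0.01 is the default from the WGAN paper, while `build_clipped_critic` and the layer sizes are illustrative assumptions rather than the code used in the `wgan` module.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, constraints
from tensorflow.keras.layers import Dense, Input

class ClipConstraint(constraints.Constraint):
    """Clamp every weight to [-clip_value, clip_value]."""
    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        return tf.clip_by_value(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        return {'clip_value': self.clip_value}

def build_clipped_critic(data_dim, dim, clip_value=0.01):
    # Hypothetical helper: every Dense layer shares the same clipping constraint
    clip = ClipConstraint(clip_value)
    inputs = Input(shape=(data_dim,))
    x = Dense(dim * 2, activation='relu', kernel_constraint=clip, bias_constraint=clip)(inputs)
    x = Dense(dim, activation='relu', kernel_constraint=clip, bias_constraint=clip)(x)
    outputs = Dense(1, kernel_constraint=clip, bias_constraint=clip)(x)
    return Model(inputs=inputs, outputs=outputs)

clipped_critic = build_clipped_critic(data_dim=30, dim=64)

# The constraint itself is just an element-wise clamp
clipped = ClipConstraint(0.01)(tf.constant([-1.0, 0.005, 1.0]))
print(clipped.numpy())
```

Attaching the constraint to the layers means the training loop does not need a separate clipping step after each critic update.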

### WGAN in action to generate Synthetic tabular data

Now that we've covered the most theoretical piece of WGAN and described, more or less, the differences when compared with the VanillaGAN, let's jump into its use to generate synthetic tabular data.

I'll use the WGAN implementation from a Python file that I have here with the full implementation.

In [1]:
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(os.path.join(os.getcwd(), 'talks'))))

In [20]:
#Let's import the libraries and the dataset that we will be using today
from wgan import WGAN
import pandas as pd
import numpy as np
from sklearn import cluster
from credit_fraud import transformations

#The dataset that we will be using is the Credit Card fraud that can be found here https://www.kaggle.com/mlg-ulb/creditcardfraud
data = pd.read_csv('creditcard.csv', index_col=[0])

In [21]:
#List of columns different from the Class column
data_cols = list(data.columns[ data.columns != 'Class' ])
label_cols = ['Class']

print('Dataset columns: {}'.format(data_cols))
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
processed_data = data[ sorted_cols ].copy()

Dataset columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount']


In [22]:
data.head(10)

Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
0.0,-1.359807,-0.072781,2.536347,1.378155,-0.338321,0.462388,0.239599,0.098698,0.363787,0.090794,...,-0.018307,0.277838,-0.110474,0.066928,0.128539,-0.189115,0.133558,-0.021053,149.62,0
0.0,1.191857,0.266151,0.16648,0.448154,0.060018,-0.082361,-0.078803,0.085102,-0.255425,-0.166974,...,-0.225775,-0.638672,0.101288,-0.339846,0.16717,0.125895,-0.008983,0.014724,2.69,0
1.0,-1.358354,-1.340163,1.773209,0.37978,-0.503198,1.800499,0.791461,0.247676,-1.514654,0.207643,...,0.247998,0.771679,0.909412,-0.689281,-0.327642,-0.139097,-0.055353,-0.059752,378.66,0
1.0,-0.966272,-0.185226,1.792993,-0.863291,-0.010309,1.247203,0.237609,0.377436,-1.387024,-0.054952,...,-0.1083,0.005274,-0.190321,-1.175575,0.647376,-0.221929,0.062723,0.061458,123.5,0
2.0,-1.158233,0.877737,1.548718,0.403034,-0.407193,0.095921,0.592941,-0.270533,0.817739,0.753074,...,-0.009431,0.798278,-0.137458,0.141267,-0.20601,0.502292,0.219422,0.215153,69.99,0
2.0,-0.425966,0.960523,1.141109,-0.168252,0.420987,-0.029728,0.476201,0.260314,-0.568671,-0.371407,...,-0.208254,-0.559825,-0.026398,-0.371427,-0.232794,0.105915,0.253844,0.08108,3.67,0
4.0,1.229658,0.141004,0.045371,1.202613,0.191881,0.272708,-0.005159,0.081213,0.46496,-0.099254,...,-0.167716,-0.27071,-0.154104,-0.780055,0.750137,-0.257237,0.034507,0.005168,4.99,0
7.0,-0.644269,1.417964,1.07438,-0.492199,0.948934,0.428118,1.120631,-3.807864,0.615375,1.249376,...,1.943465,-1.015455,0.057504,-0.649709,-0.415267,-0.051634,-1.206921,-1.085339,40.8,0
7.0,-0.894286,0.286157,-0.113192,-0.271526,2.669599,3.721818,0.370145,0.851084,-0.392048,-0.41043,...,-0.073425,-0.268092,-0.204233,1.011592,0.373205,-0.384157,0.011747,0.142404,93.2,0
9.0,-0.338262,1.119593,1.044367,-0.222187,0.499361,-0.246761,0.651583,0.069539,-0.736727,-0.366846,...,-0.246914,-0.633753,-0.120794,-0.38505,-0.069733,0.094199,0.246219,0.083076,3.68,0


In [23]:
data.describe()

Statistic,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
count,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,...,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0,284807.0
mean,1.16598e-15,3.416908e-16,-1.37315e-15,2.086869e-15,9.604066e-16,1.490107e-15,-5.556467e-16,1.177556e-16,-2.406455e-15,2.239751e-15,...,1.656562e-16,-3.44485e-16,2.578648e-16,4.471968e-15,5.340915e-16,1.687098e-15,-3.666453e-16,-1.220404e-16,88.349619,0.001727
std,1.958696,1.651309,1.516255,1.415869,1.380247,1.332271,1.237094,1.194353,1.098632,1.08885,...,0.734524,0.7257016,0.6244603,0.6056471,0.5212781,0.482227,0.4036325,0.3300833,250.120109,0.041527
min,-56.40751,-72.71573,-48.32559,-5.683171,-113.7433,-26.16051,-43.55724,-73.21672,-13.43407,-24.58826,...,-34.83038,-10.93314,-44.80774,-2.836627,-10.2954,-2.604551,-22.56568,-15.43008,0.0,0.0
25%,-0.9203734,-0.5985499,-0.8903648,-0.8486401,-0.6915971,-0.7682956,-0.5540759,-0.2086297,-0.6430976,-0.5354257,...,-0.2283949,-0.5423504,-0.1618463,-0.3545861,-0.3171451,-0.3269839,-0.07083953,-0.05295979,5.6,0.0
50%,0.0181088,0.06548556,0.1798463,-0.01984653,-0.05433583,-0.2741871,0.04010308,0.02235804,-0.05142873,-0.09291738,...,-0.02945017,0.006781943,-0.01119293,0.04097606,0.0165935,-0.05213911,0.001342146,0.01124383,22.0,0.0
75%,1.315642,0.8037239,1.027196,0.7433413,0.6119264,0.3985649,0.5704361,0.3273459,0.597139,0.4539234,...,0.1863772,0.5285536,0.1476421,0.4395266,0.3507156,0.2409522,0.09104512,0.07827995,77.165,0.0
max,2.45493,22.05773,9.382558,16.87534,34.80167,73.30163,120.5895,20.00721,15.59499,23.74514,...,27.20284,10.50309,22.52841,4.584549,7.519589,3.517346,31.6122,33.84781,25691.16,1.0


In [24]:
#Before training the GAN, do not forget to apply the required data transformations
#To keep things simple, we apply a PowerTransformation here
#Some of the features in this dataset are highly skewed
data = transformations(data)

#For the purpose of this example we will only synthesize the minority class,
#so we keep only the events where Class == 1
train_data = data.loc[ data['Class']==1 ].copy()
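The `transformations` helper is imported from `credit_fraud` and its body is not shown here. As a rough idea of what a power transformation does, the sketch below applies scikit-learn's `PowerTransformer` to a made-up, heavily skewed column. The synthetic data and the `skewness` helper are assumptions for illustration only, not the notebook's actual preprocessing.

```python
import numpy as np
from sklearn.preprocessing import PowerTransformer

def skewness(values):
    """Sample skewness: the third standardized moment."""
    centered = values - values.mean()
    return float(np.mean(centered ** 3) / np.mean(centered ** 2) ** 1.5)

rng = np.random.default_rng(0)
# Synthetic stand-in for a skewed feature such as a transaction amount
amount = rng.lognormal(mean=3.0, sigma=1.5, size=(5000, 1))

# Yeo-Johnson power transform followed by standardization to zero mean and unit variance
transformer = PowerTransformer(method='yeo-johnson', standardize=True)
amount_transformed = transformer.fit_transform(amount)

print(skewness(amount[:, 0]), skewness(amount_transformed[:, 0]))
```

The fitted transformer also exposes `inverse_transform`, which is what you would use to map generated samples back to the original scale.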

In [25]:
#Create a new class column using KMeans - This will mainly be useful if we want to experiment with a conditional GAN
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
algorithm = cluster.KMeans
args, kwds = (), {'n_clusters':2, 'random_state':0}
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ])

print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )

fraud_w_classes = train_data.copy()
fraud_w_classes['Class'] = labels

Dataset info: Number of records - 492 Number of variables - 30
   count
0    384
1    108


### Finally the GAN training

In [60]:
#Define the WGAN and the training parameters
noise_dim = 32
dim = 128
batch_size = 128

log_step = 100
epochs = 500+1
learning_rate = 5e-4
models_dir = './cache'

In [54]:
train_sample = fraud_w_classes.copy().reset_index(drop=True)
train_sample = pd.get_dummies(train_sample, columns=['Class'], prefix='Class', drop_first=True)
label_cols = [list(train_sample.columns).index(i) for i in train_sample.columns if 'Class' in i ]
data_cols = [ i for i in train_sample.columns if i not in label_cols ]
train_sample[ data_cols ] = train_sample[ data_cols ] / 10 # scale to random noise size, one less thing to learn
train_no_label = train_sample[ data_cols ]

In [55]:
#Define the GAN and training arguments
gan_args = [batch_size, learning_rate, noise_dim, train_sample.shape[1], dim]
train_args = ['', epochs, log_step]

### Update the Critic more times than the Generator

In GAN architectures such as the VanillaGAN or even DCGAN, the generator and the discriminator are usually updated the same number of times.

But this is not the case for the WGAN, where the critic model must be updated more times than the generator model.

That's why we have an input parameter that I've called n_critic - it controls the number of times the critic gets updated for every generator update.

In this case I've set it to 3 (the WGAN paper uses 5 by default). You can try other values and check the impact on the end results.
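The update schedule can be illustrated on a toy 1-D problem with NumPy. This is a minimal sketch under strong assumptions, not the training loop of the `WGAN` class: the critic is a single clipped weight, f(x) = w * x, the generator only learns a shift, g(z) = z + theta, the real data is drawn from N(3, 1), and plain gradient descent is used instead of RMSProp. All names and step sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

n_critic = 3        # critic updates per generator update
clip_value = 0.1    # clipping range [-c, c] for the critic weight
critic_lr = 0.05
generator_lr = 0.5
batch_size = 1024

w = 0.0      # critic: f(x) = w * x
theta = 0.0  # generator: g(z) = z + theta

for step in range(200):
    # 1) Update the critic n_critic times on fresh batches
    for _ in range(n_critic):
        real = rng.normal(loc=3.0, scale=1.0, size=batch_size)
        fake = rng.normal(size=batch_size) + theta
        # Critic loss is mean(f(fake)) - mean(f(real)); this is its gradient w.r.t. w
        grad_w = fake.mean() - real.mean()
        # Gradient step followed by weight clipping
        w = float(np.clip(w - critic_lr * grad_w, -clip_value, clip_value))
    # 2) Update the generator once
    # Generator loss is -mean(f(g(z))) = -w * mean(z + theta); its gradient w.r.t. theta is -w
    theta -= generator_lr * (-w)

print(theta)  # ends up close to the real mean of 3, with some oscillation
```

Because the critic is refreshed several times before each generator step, the generator always follows a reasonably up-to-date estimate of the direction in which the generated distribution has to move.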

In [None]:
#Training the GAN model chosen: in this particular case WGAN
synthesizer = WGAN(gan_args, n_critic=3)
synthesizer.train(train_sample, train_args)

0 [D loss: -0.062969, acc.: 0.00%] [G loss: 0.003234]
1 [D loss: -0.170221, acc.: 0.00%] [G loss: -0.062807]
2 [D loss: -0.201229, acc.: 4.69%] [G loss: -0.287758]
3 [D loss: -0.154375, acc.: 45.70%] [G loss: -0.653458]
4 [D loss: -0.065258, acc.: 50.00%] [G loss: -0.942831]
5 [D loss: -0.010851, acc.: 50.00%] [G loss: -1.070124]
6 [D loss: -0.005807, acc.: 50.00%] [G loss: -0.948883]
7 [D loss: -0.132307, acc.: 48.44%] [G loss: -0.590279]
8 [D loss: -0.308321, acc.: 8.59%] [G loss: -0.013747]
9 [D loss: -0.581507, acc.: 0.00%] [G loss: 0.682335]
10 [D loss: -0.955995, acc.: 0.00%] [G loss: 1.400254]
11 [D loss: -1.339898, acc.: 0.00%] [G loss: 1.988306]
12 [D loss: -1.723626, acc.: 0.00%] [G loss: 2.349804]
13 [D loss: -1.921125, acc.: 0.39%] [G loss: 2.446639]
14 [D loss: -1.852506, acc.: 3.91%] [G loss: 1.687127]
15 [D loss: -0.964682, acc.: 25.39%] [G loss: -0.299341]
16 [D loss: 0.071267, acc.: 46.88%] [G loss: -2.081782]
17 [D loss: 0.448980, acc.: 50.00%] [G loss: -2.155408]
18 

145 [D loss: -14.062318, acc.: 0.00%] [G loss: 27.090717]
146 [D loss: -17.724545, acc.: 0.00%] [G loss: 36.266289]
147 [D loss: -24.751625, acc.: 0.00%] [G loss: 52.213791]
148 [D loss: -31.645126, acc.: 0.00%] [G loss: 66.931946]
149 [D loss: -36.146839, acc.: 0.00%] [G loss: 74.961197]
150 [D loss: -40.256039, acc.: 0.00%] [G loss: 77.022583]
151 [D loss: -38.925484, acc.: 0.00%] [G loss: 74.407173]


In [None]:
#Visualize the generated data

#Setup the visualization parameters
import matplotlib.pyplot as plt

seed = 17
test_size = 492 # number of fraud cases
noise_dim = 32

np.random.seed(seed)
z = np.random.normal(size=(test_size, noise_dim))
real = synthesizer.get_data_batch(train=train_sample, batch_size=test_size, seed=seed)
real_samples = pd.DataFrame(real, columns=data_cols)
labels = fraud_w_classes['Class']

model_names = ['WGAN']
models = {'WGAN': ['WGAN', False, synthesizer.generator]}
colors = ['deepskyblue','blue']
markers = ['o','^']
class_labels = ['Class 1','Class 2']

col1, col2 = 'V17', 'V10'

base_dir = 'cache/'

#Actual fraud data visualization
model_steps = [ 0, 200, 300, 400, 500]
rows = len(model_steps)
columns = 5

axarr = [[]]*len(model_steps)

fig = plt.figure(figsize=(14,rows*3))

for model_step_ix, model_step in enumerate(model_steps):        
    axarr[model_step_ix] = plt.subplot(rows, columns, model_step_ix*columns + 1)
    
    for group, color, marker, label in zip(real_samples.groupby('Class_1'), colors, markers, class_labels ):
        plt.scatter( group[1][[col1]], group[1][[col2]], 
                         label=label, marker=marker, edgecolors=color, facecolors='none' )
    
    plt.title('Actual Fraud Data')
    plt.ylabel(col2) # Only add y label to left plot
    plt.xlabel(col1)
    xlims, ylims = axarr[model_step_ix].get_xlim(), axarr[model_step_ix].get_ylim()
    
    if model_step_ix == 0: 
        legend = plt.legend()
        legend.get_frame().set_facecolor('white')
    
    for i, model_name in enumerate( model_names[:] ):

        [model_name, with_class, generator_model] = models[model_name]

        generator_model.load_weights( base_dir + '_generator_model_weights_step_'+str(model_step)+'.h5')

        ax = plt.subplot(rows, columns, model_step_ix*columns + 1 + (i+1) )

        if with_class:
            g_z = generator_model.predict([z, labels])
            gen_samples = pd.DataFrame(g_z, columns=data_cols+label_cols)
            for group, color, marker, label in zip( gen_samples.groupby('Class_1'), colors, markers, class_labels ):
                plt.scatter( group[1][[col1]], group[1][[col2]], 
                                 label=label, marker=marker, edgecolors=color, facecolors='none' )
        else:
            g_z = generator_model.predict(z)
            gen_samples = pd.DataFrame(g_z, columns=data_cols)
            gen_samples.to_csv('Generated_sample.csv')
            plt.scatter( gen_samples[[col1]], gen_samples[[col2]], 
                             label=class_labels[0], marker=markers[0], edgecolors=colors[0], facecolors='none' )
        plt.title(model_name)   
        plt.xlabel(col1)  # the x axis shows col1, as in the real-data panel
        ax.set_xlim(xlims), ax.set_ylim(ylims)

plt.suptitle('Comparison of WGAN outputs', size=16, fontweight='bold')
plt.tight_layout(rect=[0.075,0,1,0.95])


## The issues of WGAN

Although WGAN brings some benefits to data generation when compared with VanillaGAN, it still has its own issues:

 - WGAN can still suffer from unstable training
 - Slow convergence after weight clipping, when the clipping window is too large
 - Vanishing gradients, when the clipping window is too small
 
 Since WGAN was published, alternatives to weight clipping have been suggested to deal with these issues. The best known is the gradient penalty - [WGAN-GP article](https://arxiv.org/abs/1704.00028)
 
 ![image.png](attachment:image.png)
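To give an idea of what the gradient penalty looks like in code, here is a minimal TensorFlow sketch. It is not part of the implementation used above: `gradient_penalty` and the toy linear critic are illustrative assumptions. The penalty measures how far the norm of the critic's gradient is from 1 on points interpolated between real and generated samples; in the WGAN-GP paper it is multiplied by a coefficient (10 by default) and added to the critic loss, replacing weight clipping.

```python
import tensorflow as tf

def gradient_penalty(critic, real, fake):
    """Mean squared deviation of the critic's gradient norm from 1 on interpolated points."""
    batch = tf.shape(real)[0]
    eps = tf.random.uniform([batch, 1], 0.0, 1.0)
    interpolated = eps * real + (1.0 - eps) * fake
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic(interpolated)
    grads = tape.gradient(scores, interpolated)
    # Small constant keeps the square root differentiable at zero
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)
    return tf.reduce_mean(tf.square(norms - 1.0))

# Hypothetical linear critic f(x) = x . w with ||w|| = 5,
# so the gradient norm is 5 everywhere and the penalty is (5 - 1)^2 = 16
w = tf.constant([[3.0], [4.0]])
linear_critic = lambda x: tf.matmul(x, w)

real = tf.random.normal([8, 2])
fake = tf.random.normal([8, 2])
penalty = gradient_penalty(linear_critic, real, fake)
print(float(penalty))
```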