Based on code from https://github.com/eriklindernoren/Keras-GAN.



In [1]:
from __future__ import print_function, division

from imblearn.over_sampling import SMOTE
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D, concatenate, Conv2DTranspose, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.utils import plot_model
from keras.models import Sequential, Model, load_model
from keras.optimizers import Adam, RMSprop
from keras import losses
from keras.utils import to_categorical
import keras.backend as K

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd



## Load Dataset

In [2]:
# Raw FER2013 table: one row per face with its emotion label, pixel string and usage split
data = pd.read_csv(filepath_or_buffer='./fer2013.csv')
data.head(n=5)

Unnamed: 0,emotion,pixels,Usage
0,0,70 80 82 72 58 58 60 63 54 58 60 48 89 115 121...,Training
1,0,151 150 147 155 148 133 111 140 170 174 182 15...,Training
2,2,231 212 156 164 174 138 161 173 182 200 106 38...,Training
3,4,24 32 36 30 32 23 19 20 30 41 21 22 32 34 21 1...,Training
4,6,4 0 0 0 0 0 0 0 0 0 0 0 3 15 23 28 48 50 58 84...,Training


In [3]:
data['emotion'].value_counts()

3    8989
6    6198
4    6077
2    5121
0    4953
5    4002
1     547
Name: emotion, dtype: int64

In [4]:
# Drop label 1 (only 547 rows above; presumably FER2013's 'Disgust') and move
# label 6 into the freed slot so the remaining classes form the contiguous
# range 0-5 used by `dic` and the model (1 -> 'Neutral').
# `.assign` builds a new frame instead of writing into a filtered slice,
# which avoids pandas' SettingWithCopyWarning.
# NOTE: not idempotent -- re-running this cell would also drop the remapped
# class 1; re-run the load cell first.
data = data[data.emotion != 1].assign(emotion=lambda df: df.emotion.replace(6, 1))

In [5]:
data['emotion'].value_counts()

3    8989
1    6198
4    6077
2    5121
0    4953
5    4002
Name: emotion, dtype: int64

In [6]:
# Label index -> emotion name (after the 6 -> 1 remapping above)
dic = dict(enumerate(('Angry', 'Neutral', 'Fear', 'Happy', 'Sad', 'Surprise')))

In [7]:
# Dataset constants: 6 emotion classes, 48x48 single-channel images
num_classes, img_width, img_height = 6, 48, 48

In [8]:
X = data['pixels']
y = data['emotion']

# Each 'pixels' entry is a space-separated string of img_height * img_width grey
# levels (assumed 8-bit, 0-255). Scale to [-1, 1] so real images share the range
# of the generator's tanh output; BiCoWGAN.sample_images maps generated images
# back to [0, 1] with 0.5 * x + 0.5.
X_train = np.array([row.split() for row in X], dtype=np.float32) / 127.5 - 1.0

# NHWC layout: (samples, rows=height, cols=width, channels=1)
X_train = X_train.reshape(X_train.shape[0], img_height, img_width, 1)

# Column vector of integer labels, matching the models' (1,) label inputs
y_train = y.to_numpy().reshape(-1, 1)

# Optional class rebalancing (currently disabled):
# sm = SMOTE()
# X_train, y_train = sm.fit_sample(X_train, y_train)

assert len(X_train) == len(y_train), "images and labels are misaligned"
print(X_train.shape)

(35340, 48, 48, 1)


## BiCoWGAN Model

In [26]:
class BiCoWGAN():
    """Bidirectional conditional GAN trained with the Wasserstein (WGAN) objective.

    Builds three Keras models for 48x48x1 images and 6 class labels:

    * ``generator``: ``[z, label] -> image`` (tanh output, range [-1, 1]).
    * ``encoder``:   ``image -> z``.
    * ``critic``:    ``[z, image, label] -> unbounded score`` (linear output).

    The critic learns to score encoder pairs ``(E(x), x, label)`` as real and
    generator pairs ``(z, G(z, label), label)`` as fake. ``bicowgan_generator``
    trains the generator and encoder together to invert those scores. Critic
    weights are clipped to ``[-clip_value, clip_value]`` after every critic step.

    ``train`` reads the notebook-level ``X_train`` / ``y_train`` unless arrays
    are passed in, and ``sample_images`` reads the global ``dic`` for titles.
    """

    def __init__(self):
        # Input shape
        self.img_rows = 48
        self.img_cols = 48
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 6
        self.latent_dim = 100

        # Critic steps per generator step, clip range and RMSprop learning rate
        # follow the recommendations of the WGAN paper.
        self.n_critic = 5
        self.clip_value = 0.01
        optimizer_c = RMSprop(lr=0.00005)

        # Build and compile the critic
        self.critic = self.build_critic()
        self.critic.summary()  # summary() prints itself; it returns None
        plot_model(self.critic, show_shapes=True)
        # 'accuracy' is not meaningful for a Wasserstein critic. It is kept so
        # train_on_batch returns [loss, acc], as the progress print in train() expects.
        self.critic.compile(loss=self.wasserstein_loss,
            optimizer=optimizer_c,
            metrics=['accuracy'])

        # Build the generator and the encoder
        self.generator = self.build_generator()
        self.encoder = self.build_encoder()

        # Class label shared by both critic paths. It is int32 to match the
        # label inputs of the generator and critic, which feed Embedding layers.
        label = Input(shape=(1,), dtype='int32')

        # Freeze the critic inside the combined model so that training it only
        # updates the generator and encoder. The flag is set after the critic was
        # compiled, so the critic's own train_on_batch still updates its weights.
        self.critic.trainable = False

        # Generate an image from sampled noise
        z = Input(shape=(self.latent_dim, ))
        img_ = self.generator([z, label])

        # Encode a real image
        img = Input(shape=self.img_shape)
        z_ = self.encoder(img)

        # Latent -> img is fake, and img -> latent is valid
        fake = self.critic([z, img_, label])
        valid = self.critic([z_, img, label])

        # Combined model: trains generator and encoder to fool the critic
        self.bicowgan_generator = Model([z, img, label], [fake, valid])
        self.bicowgan_generator.compile(loss=[self.wasserstein_loss, self.wasserstein_loss],
            optimizer=optimizer_c)

    def wasserstein_loss(self, y_true, y_pred):
        """Wasserstein loss: mean of target (-1 real / +1 fake) times critic score."""
        return K.mean(y_true * y_pred)

    def build_encoder(self):
        """Return a Model mapping an image of shape ``img_shape`` to a ``latent_dim`` vector."""
        model = Sequential()

        # Without activations the conv/BN stack would collapse to an affine map,
        # so each conv stage ends in a LeakyReLU.
        model.add(Conv2D(64, (5,5), strides=(2,2), padding='same', input_shape=self.img_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(128, (5,5), strides=(2,2), padding='same'))
        model.add(BatchNormalization(momentum=0.9))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(256, (5,5), strides=(2,2), padding='same'))
        model.add(BatchNormalization(momentum=0.9))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(512, (5,5), strides=(2,2), padding='same'))
        model.add(BatchNormalization(momentum=0.9))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Flatten())
        model.add(Dense(self.latent_dim))

        print('encoder')
        model.summary()

        img = Input(shape=self.img_shape)
        z = model(img)

        return Model(img, z)

    def build_generator(self):
        """Return a Model ``[z, label] -> image`` with a tanh output in [-1, 1].

        The label is embedded into a ``latent_dim`` vector and multiplied
        element-wise with the noise before upsampling 12x12 -> 24x24 -> 48x48.
        """
        model = Sequential()

        model.add(Dense(128 * 12 * 12, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((12, 12, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        print('generator')
        model.summary()

        z = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        model_input = multiply([z, label_embedding])
        img = model(model_input)

        return Model([z, label], img)

    def build_critic(self):
        """Return the critic Model ``[z, image, label] -> score``.

        The output is linear (no sigmoid). A Wasserstein critic must produce an
        unbounded score, because the loss is the mean of target * score.
        """
        xi = Input(shape=self.img_shape)
        zi = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        xn = Conv2D(128, (5,5), padding='same')(xi)
        xn = LeakyReLU(alpha=0.2)(xn)
        # downsample to 24x24
        xn = Conv2D(128, (5,5), strides=(2,2), padding='same')(xn)
        xn = LeakyReLU(alpha=0.2)(xn)
        # downsample to 12x12
        xn = Conv2D(128, (5,5), strides=(2,2), padding='same')(xn)
        xn = LeakyReLU(alpha=0.2)(xn)
        # downsample to 6x6
        xn = Conv2D(128, (5,5), strides=(2,2), padding='same')(xn)
        xn = LeakyReLU(alpha=0.2)(xn)
        # downsample to 3x3
        xn = Conv2D(128, (5,5), strides=(2,2), padding='same')(xn)
        xn = LeakyReLU(alpha=0.2)(xn)
        xn = Flatten()(xn)

        zn = Flatten()(zi)

        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))

        nn = concatenate([zn, xn, label_embedding])
        nn = Dense(1)(nn)

        return Model([zi, xi, label], nn, name='critic')

    def train(self, epochs, batch_size=128, sample_interval=50, x_data=None, y_data=None):
        """Run the WGAN training loop.

        Args:
            epochs: number of generator updates. Each one is preceded by
                ``n_critic`` critic updates.
            batch_size: samples per critic/generator batch.
            sample_interval: save a sample grid every this many epochs.
            x_data: images of shape (N, 48, 48, 1), expected in [-1, 1] to match
                the generator's tanh output. Defaults to the global ``X_train``.
            y_data: integer labels of shape (N, 1). Defaults to the global ``y_train``.
        """
        if x_data is None:
            x_data = X_train
        if y_data is None:
            y_data = y_train

        # Wasserstein targets: -1 for real (encoder) pairs, +1 for generated pairs.
        # With zeros as the fake target, the fake term of the loss would vanish.
        valid = -np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))

        for epoch in range(epochs):

            for _ in range(self.n_critic):

                # ---------------------
                #  Train Critic
                # ---------------------

                # Select a random batch of images and encode
                idx = np.random.randint(0, x_data.shape[0], batch_size)
                imgs, labels = x_data[idx], y_data[idx]
                z_ = self.encoder.predict(imgs)

                # Sample noise and generate images for the same labels
                z = np.random.normal(0, 1, (batch_size, self.latent_dim))
                imgs_ = self.generator.predict([z, labels])

                # Train the critic (img -> z is valid, z -> img is fake)
                d_loss_real = self.critic.train_on_batch([z_, imgs, labels], valid)
                d_loss_fake = self.critic.train_on_batch([z, imgs_, labels], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # Clip critic weights (WGAN Lipschitz constraint)
                for layer in self.critic.layers:
                    layer.set_weights([np.clip(w, -self.clip_value, self.clip_value)
                                       for w in layer.get_weights()])

            # ---------------------
            #  Train Generator + Encoder
            # ---------------------

            # The label input feeds both critic paths, including the one that
            # carries the real images. Use those images' true labels rather than
            # random ones. Targets are swapped relative to the critic's so that
            # the generator and encoder learn to fool it.
            g_loss = self.bicowgan_generator.train_on_batch([z, imgs, labels], [valid, fake])

            # Plot the progress
            if epoch % 20 == 0:
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save one generated image per class to ``images/<epoch>.png``.

        Titles come from the notebook-level ``dic`` (label -> emotion name).
        """
        import os

        r, c = 1, self.num_classes
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.arange(0, c).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images from the tanh range [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        # squeeze=False keeps a 2-D axes array even for a single row or column
        fig, axs = plt.subplots(r, c, squeeze=False)
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.set_title("%s" % dic[sampled_labels[cnt][0]])
            ax.axis('off')
        # savefig does not create missing directories
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)


In [None]:
if __name__ == '__main__':
    # Build the networks, then train; a sample grid is written every 200 epochs.
    bicowgan = BiCoWGAN()
    bicowgan.train(epochs=10000, sample_interval=200, batch_size=128)

Model: "critic"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_73 (InputLayer)           [(None, 48, 48, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 48, 48, 128)  3328        input_73[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_58 (LeakyReLU)      (None, 48, 48, 128)  0           conv2d_84[0][0]                  
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 24, 24, 128)  409728      leaky_re_lu_58[0][0]             
_____________________________________________________________________________________________

In [None]:
if __name__ == '__main__':
    # FIXME: `BiCoGAN` (the non-Wasserstein variant) is not defined anywhere in
    # this notebook. The output below this cell comes from an earlier kernel
    # session. Without this guard, Restart & Run All would stop here with a
    # NameError.
    if 'BiCoGAN' in globals():
        bicogan = BiCoGAN()
        bicogan.train(epochs=25000, batch_size=128, sample_interval=200)
    else:
        print('Skipping BiCoGAN training: class BiCoGAN is not defined in this notebook.')

Model: "discriminator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 48, 48, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 48, 48, 128)  3328        input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 48, 48, 128)  0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 24, 24, 128)  409728      leaky_re_lu[0][0]                
______________________________________________________________________________________