In [None]:
from google.colab import drive
from tensorflow.python.keras.optimizer_v2.adam import Adam

# Colab only: mount Google Drive at /content/drive so the training data read in
# the main cell ('/content/drive/MyDrive/train.npy') is reachable.
drive.mount('/content/drive')

In [None]:
from __future__ import print_function, division

from abc import ABC

import keras
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Input, Dense, Reshape
from tensorflow.keras.models import Sequential, Model
from tensorflow.python.keras.layers import LeakyReLU, LSTM, Dropout, TimeDistributed
from tensorflow.python.keras.layers.merge import _Merge
from matplotlib import pyplot

In [None]:
# --- Training configuration -------------------------------------------------
SAMPLE_INTERVAL = 10   # GANMonitor saves weights every SAMPLE_INTERVAL epochs
BATCH_SIZE = 32        # mini-batch size passed to fit()
EPOCHS = 2000          # number of training epochs passed to fit()

# Module-level loss history lists (two independent, initially empty lists).
d_loss_values = []
g_loss_values = []

class RandomWeightedAverage(_Merge):
    """Random convex combination of real and generated trajectory samples.

    For every sample in the batch a coefficient ``alpha ~ U[0, 1)`` is drawn
    and the layer returns ``alpha * inputs[0] + (1 - alpha) * inputs[1]``.
    This is the interpolation used by the WGAN-GP gradient penalty.
    """

    def _merge_function(self, inputs):
        """Interpolate ``inputs[0]`` and ``inputs[1]`` with per-sample weights.

        ``alpha`` has shape ``(batch, 1, ..., 1)``. One coefficient is shared
        by all time steps/features of a sample, but it differs between
        samples. The previous hard-coded ``(1, 144, 1)`` shape drew one value
        per time step shared by the whole batch, and it only worked for
        length-144 inputs.

        Args:
            inputs: two tensors of identical shape ``(batch, ...)``.

        Returns:
            Tensor with the same shape as the inputs.
        """
        real, fake = inputs[0], inputs[1]
        # Broadcastable shape: batch dimension, then 1 for every other axis.
        alpha_shape = [K.shape(real)[0]] + [1] * (K.ndim(real) - 1)
        alpha = K.random_uniform(alpha_shape)
        return (alpha * real) + ((1 - alpha) * fake)

class WGANGP(keras.Model, ABC):
    """Wasserstein GAN with gradient penalty (WGAN-GP) for 1-D trajectories.

    The generator maps ``latent_dim``-dimensional Gaussian noise to a
    trajectory of shape ``traj_shape`` (``(144, 1)``). The critic
    ("discriminator") returns one unbounded score per time step.

    Each ``train_step`` performs ``n_discriminator`` critic updates, then one
    generator update. The critic loss includes a gradient penalty weighted
    by ``gp_weight``.

    Call :meth:`compile` with both optimizers and loss callables before
    :meth:`fit`.

    NOTE(review): this subclasses the standalone ``keras.Model`` but builds
    its layers from ``tensorflow.keras`` / ``tensorflow.python.keras``.
    Confirm these resolve to the same Keras in the installed TF version.
    """

    def __init__(self):
        super().__init__()
        self.max_length = 144
        self.features = 1
        self.traj_shape = (self.max_length, self.features)
        self.latent_dim = 100

        # Critic updates per generator update, and the gradient-penalty weight
        # (lambda). Both values follow the WGAN-GP paper's recommendation.
        # The optimizers are configured by the caller via compile().
        self.n_discriminator = 5
        self.gp_weight = 10

        # Build the generator and discriminator (critic).
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss callables used by :meth:`train_step`.

        Args:
            d_optimizer: optimizer applied to the discriminator weights.
            g_optimizer: optimizer applied to the generator weights.
            d_loss_fn: callable invoked as
                ``d_loss_fn(real_traj=real_scores, fake_traj=fake_scores)``.
            g_loss_fn: callable invoked as ``g_loss_fn(fake_scores)``.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, real_trajectories, fake_trajectories):
        """Return the gradient penalty ``mean((||grad D(x_hat)||_2 - 1) ** 2)``.

        ``x_hat`` is a random point on the line between each real sample and
        its paired fake sample, with one mixing coefficient per sample.

        Args:
            real_trajectories: batch of real trajectories, rank 3
                ``(batch, max_length, features)``.
            fake_trajectories: generated batch with the same shape.

        Returns:
            Scalar tensor; callers multiply it by ``gp_weight``.
        """
        batch_size = tf.shape(real_trajectories)[0]
        # One alpha per sample, broadcast over time steps and features.
        alpha = tf.random.uniform(
            [batch_size, 1, 1], 0.0, 1.0, dtype=real_trajectories.dtype
        )
        interpolated_traj = real_trajectories + alpha * (fake_trajectories - real_trajectories)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated_traj)
            pred = self.discriminator(interpolated_traj, training=True)
        grads = gp_tape.gradient(pred, interpolated_traj)

        # Per-sample L2 norm over time and feature axes. The small epsilon
        # keeps the derivative of sqrt finite when a gradient is exactly zero.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)

    def build_generator(self):
        """Build the generator: an MLP from noise to a ``traj_shape`` trajectory.

        Three Dense+LeakyReLU+BatchNorm stages (256, 512, 1024 units) feed a
        ``tanh`` output layer, so generated values lie in [-1, 1]. The output
        is reshaped to ``traj_shape``.

        Returns:
            Functional ``Model`` with input shape ``(latent_dim,)``.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(int(np.prod(self.traj_shape)), activation='tanh'))
        model.add(Reshape(self.traj_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        traj = model(noise)

        return Model(noise, traj)

    def build_discriminator(self):
        """Build the critic: two stacked LSTMs plus a per-time-step dense head.

        Returns:
            Functional ``Model`` mapping a ``traj_shape`` trajectory to one
            unbounded score per time step, shape ``(max_length, 1)``.
        """
        model = Sequential()

        model.add(LSTM(128, input_shape=self.traj_shape, return_sequences=True))
        model.add(Dropout(0.2))
        model.add(LSTM(128, return_sequences=True))
        model.add(Dropout(0.2))
        model.add(TimeDistributed(Dense(128)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(TimeDistributed(Dense(64)))
        model.add(LeakyReLU(alpha=0.2))
        # Linear output: a WGAN critic produces an unbounded score whose
        # Lipschitz constraint comes from the gradient penalty. The previous
        # tanh squashed scores into [-1, 1] and saturated.
        model.add(TimeDistributed(Dense(1)))

        model.summary()

        traj = Input(shape=self.traj_shape)
        validity = model(traj)

        return Model(traj, validity)

    def train_step(self, real_trajectories):
        """Run one WGAN-GP optimisation step on a batch of real trajectories.

        Steps, in execution order:
          1. Repeat ``n_discriminator`` times:
             - sample noise and generate fake trajectories;
             - compute the critic loss ``d_loss_fn`` plus ``gp_weight`` times
               the gradient penalty;
             - update the discriminator.
          2. Once: sample fresh noise, compute ``g_loss_fn`` on the critic
             scores of generated trajectories, and update the generator.

        Per-step losses are not appended to Python lists here. ``fit`` runs
        this method inside ``tf.function``, where such side effects execute
        only while tracing and would capture symbolic tensors. Use the
        ``History`` returned by ``fit`` (or callback ``logs``) instead.

        Returns:
            ``{"d_loss": last critic loss, "g_loss": generator loss}``.
        """
        # fit(x) without targets may deliver the batch wrapped in a tuple.
        if isinstance(real_trajectories, tuple):
            real_trajectories = real_trajectories[0]
        batch_size = tf.shape(real_trajectories)[0]

        for _ in range(self.n_discriminator):
            noise = tf.random.normal((batch_size, self.latent_dim), 0, 1)
            with tf.GradientTape() as tape:
                fake_trajectories = self.generator(noise, training=True)
                fake_logits = self.discriminator(fake_trajectories, training=True)
                real_logits = self.discriminator(real_trajectories, training=True)

                d_cost = self.d_loss_fn(real_traj=real_logits, fake_traj=fake_logits)

                # Computed inside the outer tape so the penalty's dependence on
                # the critic weights is differentiated too (second-order grads).
                gp = self.gradient_penalty(real_trajectories, fake_trajectories)
                d_loss = d_cost + gp * self.gp_weight

            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Generator update with fresh noise.
        noise = tf.random.normal((batch_size, self.latent_dim), 0, 1)
        with tf.GradientTape() as tape:
            generated_trajectories = self.generator(noise, training=True)
            gen_traj_logits = self.discriminator(generated_trajectories, training=True)
            g_loss = self.g_loss_fn(gen_traj_logits)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

# Critic (discriminator) loss.
def discriminator_loss(real_traj, fake_traj):
    """Wasserstein critic loss: mean score of fakes minus mean score of reals.

    Args:
        real_traj: critic outputs for real trajectories.
        fake_traj: critic outputs for generated trajectories.

    Returns:
        Scalar tensor. Minimising it raises real scores relative to fake ones.
    """
    mean_fake_score = tf.reduce_mean(fake_traj)
    mean_real_score = tf.reduce_mean(real_traj)
    return tf.subtract(mean_fake_score, mean_real_score)


# Generator loss.
def generator_loss(fake_traj):
    """Negated mean critic score of generated trajectories.

    Args:
        fake_traj: critic outputs for generated trajectories.

    Returns:
        Scalar tensor. Minimising it raises the critic's score of the fakes.
    """
    mean_fake_score = tf.reduce_mean(fake_traj)
    return tf.negative(mean_fake_score)


class GANMonitor(keras.callbacks.Callback):
    """Keras callback that periodically checkpoints generator and critic weights.

    At the end of every epoch whose index is divisible by the sample interval
    (including epoch 0), the model's ``generator`` and ``discriminator``
    weights are saved to ``G_model_<epoch>.h5`` and ``D_model_<epoch>.h5``.
    Paths are relative to the current working directory.
    """

    def __init__(self, latent_dim=100, sample_interval=None):
        """
        Args:
            latent_dim: stored on the instance; not used by this callback.
            sample_interval: epochs between checkpoints. ``None`` (default)
                reads the module-level ``SAMPLE_INTERVAL`` at each epoch end.
        """
        # Initialise base-class state (e.g. ``self.model``) set by Keras.
        super().__init__()
        self.latent_dim = latent_dim
        self.sample_interval = sample_interval

    def on_epoch_end(self, epoch, logs=None):
        """Save weights when ``epoch`` is a multiple of the sample interval."""
        interval = SAMPLE_INTERVAL if self.sample_interval is None else self.sample_interval
        if epoch % interval == 0:
            self.model.generator.save_weights(f"G_model_{epoch}.h5")
            self.model.discriminator.save_weights(f"D_model_{epoch}.h5")


# Line plot of the GAN losses, saved to an image file.
def plot_history(d_loss, g_loss, path='plot_line_plot_loss.png'):
    """Plot discriminator and generator loss curves and save the figure.

    Args:
        d_loss: sequence of discriminator (critic) loss values.
        g_loss: sequence of generator loss values.
        path: output image path; defaults to 'plot_line_plot_loss.png'.
    """
    fig, ax = pyplot.subplots()
    try:
        ax.plot(d_loss, label='discri')
        ax.plot(g_loss, label='gen')
        ax.set_xlabel('step')
        ax.set_ylabel('loss')
        ax.set_title('WGAN-GP training losses')
        ax.legend()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting/saving raised.
        pyplot.close(fig)

In [None]:

if __name__ == '__main__':
    callback = GANMonitor()
    wgan = WGANGP()

    generator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    discriminator_optimizer = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    wgan.compile(
        d_optimizer=discriminator_optimizer,
        g_optimizer=generator_optimizer,
        g_loss_fn=generator_loss,
        d_loss_fn=discriminator_loss,
    )

    # Training data.
    # NOTE(review): allow_pickle=True lets np.load run arbitrary code from a
    # crafted file. Only load trusted data, and drop the flag if train.npy is
    # a plain numeric array.
    X_train = np.load('/content/drive/MyDrive/train.npy', allow_pickle=True)
    assert X_train.shape[1:] == wgan.traj_shape, (
        f"expected trajectories of shape (N, {wgan.max_length}, {wgan.features}), "
        f"got {X_train.shape}"
    )

    history = wgan.fit(X_train, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=[callback])
    # Plot the per-epoch losses Keras recorded from train_step's return value.
    # Python-list side effects inside train_step are unreliable under tf.function.
    plot_history(history.history["d_loss"], history.history["g_loss"])
