In [None]:
%load_ext autoreload
%autoreload 2

import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
import tsgm

In [None]:
# Training configuration.
batch_size = 128
seq_len = 123  # number of time steps per series

# Channel widths: noise, data features and conditioning signal all use a
# single channel in this notebook.
latent_dim = feature_dim = output_dim = 1

# Both networks see their main input concatenated channel-wise with the
# conditioning signal, hence the "+ output_dim".
discriminator_in_channels = feature_dim + output_dim
generator_in_channels = latent_dim + output_dim


In [None]:
# Create the discriminator.
# Four strided Conv1D + LeakyReLU stages, followed by global average pooling
# over the time axis and a single sigmoid unit.
discriminator_layers = [keras.layers.InputLayer((seq_len, discriminator_in_channels))]
for n_filters in (64, 128, 128, 128):
    discriminator_layers.append(layers.Conv1D(n_filters, 3, strides=2, padding="same"))
    discriminator_layers.append(layers.LeakyReLU(alpha=0.2))
discriminator_layers.append(layers.GlobalAvgPool1D())
discriminator_layers.append(layers.Dense(1, activation="sigmoid"))

discriminator = keras.Sequential(discriminator_layers, name="discriminator")
discriminator.summary()

# Create the generator.
# Each of the three stride-2 transposed convolutions (padding="same") doubles
# the time axis, so the sequence reaching the pooling layer is
# `upsampling_factor * seq_len` steps long. Pooling with a window and stride
# equal to that factor brings the output back to exactly `seq_len` steps. A
# window that does not match the factor (e.g. 57 for seq_len = 123) yields an
# output whose length differs from seq_len and cannot be paired with the
# per-timestep labels during training.
upsampling_factor = 2 ** 3
generator = keras.Sequential([
        keras.layers.InputLayer((seq_len, generator_in_channels)),

        layers.Conv1DTranspose(64, 2, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(rate=0.2),
        layers.Conv1DTranspose(64, 2, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(rate=0.2),
        layers.Conv1DTranspose(64, 2, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(rate=0.2),

        layers.LSTM(128, return_sequences=True),
        layers.Dropout(rate=0.2),

        # Downsample back to seq_len steps, then map every step to one value
        # in [-1, 1] with its own (unshared) weights.
        layers.AveragePooling1D(pool_size=upsampling_factor, strides=upsampling_factor),
        layers.LocallyConnected1D(1, 1, activation="tanh"),
    ],
    name="generator",
)
generator.summary()

In [None]:
# Generator built with the functional API: three identical upsampling stages,
# average pooling back down along the time axis, and a per-timestep tanh output.
g_input = keras.Input((seq_len, generator_in_channels))

x = g_input
for _stage in range(3):
    x = layers.Conv1DTranspose(64, 2, strides=2, padding="same")(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dropout(rate=0.2)(x)

# Pooling window derived from the upsampled length relative to seq_len.
upsampled_len = x.shape[1]
pool_and_stride = round((upsampled_len + 1) / (seq_len + 1))
x = layers.AveragePooling1D(pool_size=pool_and_stride, strides=pool_and_stride)(x)
g_output = layers.LocallyConnected1D(1, 1, activation="tanh")(x)

generator = keras.Model(inputs=g_input, outputs=g_output, name="generator")
generator.summary()

In [None]:
# Discriminator built with the functional API: four strided convolution stages
# (Conv1D -> LeakyReLU -> Dropout), global average pooling over time, and a
# single sigmoid output.
d_input = keras.Input((seq_len, discriminator_in_channels))

x = d_input
for n_filters in (64, 128, 128, 128):
    x = layers.Conv1D(n_filters, 3, strides=2, padding="same")(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dropout(rate=0.2)(x)

x = layers.GlobalAvgPool1D()(x)
d_output = layers.Dense(1, activation="sigmoid")(x)

discriminator = keras.Model(inputs=d_input, outputs=d_output, name="discriminator")
discriminator.summary()

In [None]:
class ConditionalGAN(keras.Model):
    """Conditional GAN for sequences with a per-timestep conditioning signal.

    The generator is fed noise concatenated channel-wise with the labels; the
    discriminator is fed a (real or generated) sequence concatenated
    channel-wise with the same labels.

    Args:
        discriminator: model scoring a sequence-plus-label tensor.
        generator: model mapping a noise-plus-label tensor to a sequence.
        latent_dim: number of noise channels sampled per time step.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        """Loss trackers exposed to Keras as this model's metrics."""
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers and the shared loss function.

        Args:
            d_optimizer: optimizer applied to the discriminator weights.
            g_optimizer: optimizer applied to the generator weights.
            loss_fn: callable ``loss_fn(y_true, y_pred)`` used for both updates.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def _sample_generator_input(self, conditioning):
        """Draw fresh noise and append the conditioning channel to it.

        Args:
            conditioning: tensor of shape ``(batch, seq_len, 1)``.

        Returns:
            Tensor of shape ``(batch, seq_len, latent_dim + 1)``.
        """
        # Batch size and sequence length are taken from the conditioning tensor
        # itself, so the model does not depend on notebook-level globals.
        dims = tf.shape(conditioning)
        noise = tf.random.normal(shape=(dims[0], dims[1], self.latent_dim))
        return tf.concat([noise, conditioning], axis=2)

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: tuple ``(real_images, one_hot_labels)``. The labels are
                indexed as a 2-D ``(batch, seq_len)`` tensor and are appended
                as one extra channel to both the generator input and the
                discriminator input.

        Returns:
            Dict with the running means of the generator loss (``"g_loss"``)
            and the discriminator loss (``"d_loss"``).
        """
        # Unpack the data.
        real_images, one_hot_labels = data
        batch_size = tf.shape(real_images)[0]

        # Conditioning channel shared by the generator and discriminator inputs.
        conditioning = one_hot_labels[:, :, None]

        # Generate a fake batch for the discriminator update. This happens
        # outside of any tape: no generator gradients are needed here.
        # `training=True` is passed explicitly on every sub-model call because
        # train_step calls the sub-models directly; without the flag, layers
        # such as Dropout are not guaranteed to run in training mode.
        generated_images = self.generator(
            self._sample_generator_input(conditioning), training=True
        )

        # Attach the labels to both fake and real sequences and stack them
        # (fake first, real second).
        fake_image_and_labels = tf.concat([generated_images, conditioning], -1)
        real_image_and_labels = tf.concat([real_images, conditioning], -1)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Discriminator targets, in the same order as `combined_images`:
        # 1 = generated, 0 = real.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Fresh noise (same labels) for the generator update.
        random_vector_labels = self._sample_generator_input(conditioning)

        # Generator targets: 0 is the "real" class above, so the generator is
        # rewarded when the discriminator takes its output for real data.
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels, training=True)
            fake_image_and_labels = tf.concat([fake_images, conditioning], -1)
            predictions = self.discriminator(fake_image_and_labels, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

In [None]:
# Synthetic training data: 50,000 series of length seq_len with one feature,
# together with their labels.
X, y = tsgm.utils.gen_sine_const_switch_dataset(50_000, seq_len, 1, max_value=20, const=10)

# Fit the scaler on X (constructed with the range (-1, 1)) and cast the result
# to float32. The labels are only cast, not rescaled.
scaler = tsgm.utils.TSFeatureWiseScaler((-1, 1))
X_train = scaler.fit_transform(X).astype(np.float32)
y = y.astype(np.float32)

# Shuffled, batched (series, labels) pipeline consumed by fit().
dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)


In [None]:
import typing

class GANMonitor(keras.callbacks.Callback):
    """Keras callback that plots generator samples at the end of every epoch.

    Args:
        num_samples: number of samples generated per class.
        latent_dim: size of the noise vector sampled for each sample.
        num_classes: number of one-hot classes to condition on.
        save: if True, figures are written to ``save_path``; otherwise they
            are shown with ``plt.show()``.
        save_path: directory for the saved figures. Defaults to ``/tmp/`` when
            ``save`` is True and no path is given.
    """

    def __init__(self, num_samples: int, latent_dim: int = 128, num_classes: int = 2,
                 save: bool = True, save_path: typing.Optional[str] = None):
        # Initialise the base Callback so the attributes Keras expects exist.
        super().__init__()
        self._num_samples = num_samples
        self._latent_dim = latent_dim
        self._num_classes = num_classes
        self._save = save
        self._save_path = save_path

        if self._save and self._save_path is None:
            self._save_path = "/tmp/"
            print("[WARNING]: save_path is not specified. Using `/tmp` as the default save_path")

        if self._save is False and self._save_path is not None:
            print("[WARNING]: save_path is specified, but save is False.")

        # The output directory is only needed when figures are actually saved.
        if self._save:
            os.makedirs(self._save_path, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``num_classes * num_samples`` samples and plot each one.

        Args:
            epoch: index of the epoch that just finished (used in file names).
            logs: metrics dict supplied by Keras; unused.
        """
        n_total = self._num_classes * self._num_samples
        random_latent_vectors = tf.random.normal(shape=(n_total, self._latent_dim))

        # One one-hot row per class, each repeated num_samples times.
        labels = keras.utils.to_categorical(np.arange(self._num_classes), self._num_classes)
        labels = tf.repeat(labels, self._num_samples, axis=0)

        # NOTE(review): this feeds the generator a 2-D (n_total,
        # latent_dim + num_classes) tensor. A generator that expects a
        # (batch, seq_len, channels) input with per-timestep labels will reject
        # it -- confirm the intended conditioning before enabling the callback.
        generated_images = self.model.generator(tf.concat([random_latent_vectors, labels], 1))

        for i in range(n_total):
            # One figure per sample, closed afterwards, so lines from earlier
            # samples and epochs do not pile up on the same axes.
            fig, ax = plt.subplots()
            sns.lineplot(
                x=range(0, generated_images[i].shape[0]),
                y=tf.squeeze(generated_images[i]),
                ax=ax,
            )
            if self._save:
                fig.savefig(os.path.join(self._save_path, "epoch_{}_sample_{}".format(epoch, i)))
            else:
                plt.show()
            plt.close(fig)

In [None]:
# Assemble the conditional GAN from the generator and discriminator.
cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)

# Both networks get their own Adam instance with identical settings.
adam_settings = dict(learning_rate=0.0001, beta_1=0.5)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(**adam_settings),
    g_optimizer=keras.optimizers.Adam(**adam_settings),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# The monitor callback is instantiated but not passed to fit(); the
# `callbacks` argument is commented out below.
cbk = GANMonitor(num_samples=2, latent_dim=latent_dim)
cond_gan.fit(dataset, epochs=10000)  # , callbacks=[cbk])

In [None]:
# Plot the scaled training series together with their labels using tsgm's
# line-plot helper (third argument 5: presumably the number of series drawn --
# confirm against the tsgm docs), then save the current figure as a PDF.
tsgm.utils.visualize_ts_lineplot(X_train, y, 5)
plt.savefig("data_temporal_gan.pdf", bbox_inches='tight')

In [None]:
n_samples = 5

# Condition the generator on the label sequences of the first n_samples
# training examples, each paired with freshly sampled per-timestep noise.
tmp_latent = tf.random.normal(shape=(n_samples, seq_len, latent_dim))
conditioning_labels = y[:n_samples, :, None]
random_vector_labels = tf.concat([tmp_latent, conditioning_labels], axis=2)

generated_images = cond_gan.generator(random_vector_labels)

In [None]:
# Plot the generated series with the same tsgm helper used for the real data
# and save the current figure as a PDF.
# NOTE(review): `generated_images` was built from y[:n_samples], while the full
# `y` array is passed here -- this presumably relies on the helper indexing
# both arguments by the same sample index; confirm.
tsgm.utils.visualize_ts_lineplot(generated_images, y, 5)
plt.savefig("synth_data_temporal_gan.pdf", bbox_inches='tight')