# Tutorial: temporal conditional GANs

In [None]:
%load_ext autoreload
%autoreload 2

from tensorflow import keras
from tensorflow.keras import layers

import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import tsgm

We want to generate a temporal dataset in which every time step of each time series is labelled with one of two classes. Let's go through the solution step by step.

#### 1. Define parameters of GAN:
First, we define the parameters of the GAN and of the training algorithm.
- `latent_dim` is the size of input noise in GAN,
- `output_dim` is the dimensionality of the label attached to each time step; the two classes mentioned above are encoded in a single label channel, so it is set to 1,
- `feature_dim` is the number of time series features,
- `seq_len` is the length of the time series.

In [None]:
# Mini-batch size used when the training dataset is batched.
batch_size = 128

# Tensor dimensions shared by the data, the generator and the discriminator.
seq_len = 123      # time steps per series
feature_dim = 1    # number of time series features
latent_dim = 1     # noise channels per time step fed to the generator
output_dim = 1     # label channels per time step used for conditioning

# Channel counts once the label has been concatenated onto each network input.
discriminator_in_channels = output_dim + feature_dim
generator_in_channels = output_dim + latent_dim


#### 2. Choose architecture.
Here, you can either use one of the architectures presented in `tsgm.models.architectures`, or define custom discriminator and generator architectures as `tf` models.

In [None]:
# Fix the global NumPy and TensorFlow seeds before the networks are built so
# that a "Restart Kernel -> Run All" pass is as repeatable as possible.
# NOTE(review): assumes tsgm draws its randomness from these global
# generators -- confirm against the tsgm sources.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Look up the "t-cgan_c4" entry in the tsgm architecture zoo and instantiate it
# with the dimensions chosen in the parameter cell.
architecture = tsgm.models.architectures.zoo["t-cgan_c4"](
    seq_len=seq_len, feat_dim=feature_dim,
    latent_dim=latent_dim, output_dim=output_dim)

# The architecture object exposes both networks; keep them under short names
# because the model-definition cell consumes them.
discriminator, generator = architecture.discriminator, architecture.generator


#### 3. Load data:
We work with a toy dataset and use the `tsgm` utility `tsgm.utils.gen_sine_const_switch_dataset` to generate it. Next, we scale the dataset feature-wise so that each feature lies in $[-1, 1]$, using `tsgm.utils.TSFeatureWiseScaler`.

In [None]:
# Generate the toy dataset.
# NOTE(review): the positional arguments are presumably (number of series,
# series length, number of features) -- confirm against the tsgm docs.
X, y = tsgm.utils.gen_sine_const_switch_dataset(50_000, seq_len, 1, max_value=20, const=10)

# Scale feature-wise into [-1, 1] and cast to float32 for TensorFlow.
scaler = tsgm.utils.TSFeatureWiseScaler((-1, 1))
X_train = scaler.fit_transform(X).astype(np.float32)
y = y.astype(np.float32)

# Build the input pipeline. The shuffle buffer spans the whole dataset: a
# buffer smaller than the dataset only shuffles elements locally, so samples
# that are far apart in the source arrays could never share a batch.
dataset = tf.data.Dataset.from_tensor_slices((X_train, y))
dataset = dataset.shuffle(buffer_size=len(X_train)).batch(batch_size)


#### 4. Define model and train it.
We define the conditional GAN model (`tsgm.models.cgan.ConditionalGAN`), compile it (here, you can choose different optimizers for the discriminator and the generator), and train it using the `.fit` method.

In [None]:
# Both networks get their own Adam optimizer, configured identically.
adam_config = dict(learning_rate=1e-4, beta_1=0.5)

# Wrap the discriminator and generator into a single trainable model.
cond_gan = tsgm.models.cgan.ConditionalGAN(
    discriminator=discriminator,
    generator=generator,
    latent_dim=latent_dim,
    temporal=True,
)

# Attach the optimizers and the binary cross-entropy loss.
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(**adam_config),
    g_optimizer=keras.optimizers.Adam(**adam_config),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# A single pass over the batched dataset.
cond_gan.fit(dataset, epochs=1)


#### 5. Visually explore the dataset.
There are many tools for convenient visualizations of temporal datasets. Here, we use `tsgm.utils.visualize_ts_lineplot`, which is convenient for TS classification datasets.

In [None]:
# Plot the scaled training series together with their labels, then store the
# figure as a tightly cropped PDF.
# NOTE(review): the third argument is presumably the number of series drawn --
# confirm against the tsgm docs.
n_plotted = 5
real_plot_path = "data_temporal_gan.pdf"

tsgm.utils.visualize_ts_lineplot(X_train, y, n_plotted)
plt.savefig(real_plot_path, bbox_inches="tight")

In [None]:
n_samples = 5

# Gaussian noise with `latent_dim` channels for every time step of every sample.
latent_noise = tf.random.normal(shape=(n_samples, seq_len, latent_dim))

# Labels of the first `n_samples` training series, with a trailing channel axis
# added so they can be stacked onto the noise along the last dimension.
condition = y[:n_samples, :, np.newaxis]
generator_input = tf.concat([latent_noise, condition], axis=-1)

# Generator output for the noise/label pairs above; the name is kept because
# the following cell reads it.
generated_images = cond_gan.generator(generator_input)

In [None]:
# Generated series i was conditioned on y[i] for i < n_samples, so pass exactly
# those labels (not the full label array) and reuse `n_samples` instead of a
# hard-coded count, keeping this plot in sync with the generation cell.
tsgm.utils.visualize_ts_lineplot(generated_images, y[:n_samples], n_samples)
plt.savefig("synth_data_temporal_gan.pdf", bbox_inches='tight')