In [60]:
import sys
import os

# Dynamically add root directory (project base)
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


In [61]:
import numpy as np 
import tensorflow as tf
import random
from tensorflow.keras import layers, Model
from sklearn.model_selection import train_test_split
from tensorflow.keras import mixed_precision
from models.timegan.timegan_model import TimeGAN

mixed_precision.set_global_policy('mixed_float16')


In [62]:
# Load your processed data
X = np.load("../data/processed/chest_X_conditional.npy")
print(f"✅ X shape: {X.shape}")

✅ X shape: (55917, 256, 12)


In [73]:
seq_len = 256
feature_dim = 12  # 8 signals + 4 class indicators
hidden_dim = 16
num_layers = 2
iterations = 10000
batch_size = 64
gamma = 1  # weight on supervised loss


timegan = TimeGAN(seq_len, feature_dim, hidden_dim, num_layers)

In [64]:
bce = tf.keras.losses.BinaryCrossentropy()
mse = tf.keras.losses.MeanSquaredError()

generator_optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(learning_rate=1e-4), dynamic=True)
discriminator_optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(learning_rate=1e-4), dynamic=True)
embedder_optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(learning_rate=1e-4), dynamic=True)
supervisor_optimizer = mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(learning_rate=1e-4), dynamic=True)

In [65]:
def sample_random_noise(batch_size, seq_len, dim):
    return tf.random.normal(shape=(batch_size, seq_len, dim), dtype=tf.float32)

In [66]:
@tf.function(jit_compile=True)
def train_embedder(x):
    with tf.GradientTape() as tape:
        h = timegan.embedder(x)
        x_tilde = timegan.recovery(h)
        e_loss = mse(x, x_tilde)

    grads = tape.gradient(e_loss, timegan.embedder.trainable_variables + timegan.recovery.trainable_variables)
    embedder_optimizer.apply_gradients(zip(grads, timegan.embedder.trainable_variables + timegan.recovery.trainable_variables))
    return e_loss

In [67]:
@tf.function(jit_compile=True)
def train_supervisor(x):
    with tf.GradientTape() as tape:
        h = timegan.embedder(x)
        h_supervised = timegan.supervisor(h)
        s_loss = mse(h[:, 1:, :], h_supervised[:, :-1, :])  # next-step prediction

    grads = tape.gradient(s_loss, timegan.supervisor.trainable_variables)
    supervisor_optimizer.apply_gradients(zip(grads, timegan.supervisor.trainable_variables))

    return s_loss

In [68]:
@tf.function(jit_compile=True)
def train_generator(x):
    z = sample_random_noise(tf.shape(x)[0], seq_len, hidden_dim)

    with tf.GradientTape() as tape:
        e_hat = timegan.generator(z)
        h_hat = timegan.supervisor(e_hat)
        y_fake = timegan.discriminator(h_hat)

        g_loss_u = bce(tf.ones_like(y_fake), y_fake)  # adversarial loss
        g_loss_s = mse(h_hat[:, 1:, :], timegan.supervisor(h_hat)[:, :-1, :])  # supervised loss
        g_loss = g_loss_u + gamma * g_loss_s

    grads = tape.gradient(g_loss, timegan.generator.trainable_variables + timegan.supervisor.trainable_variables)
    generator_optimizer.apply_gradients(zip(grads, timegan.generator.trainable_variables + timegan.supervisor.trainable_variables))

    return g_loss

In [69]:
@tf.function(jit_compile=True)
def train_discriminator(x):
    h = timegan.embedder(x)
    z = sample_random_noise(tf.shape(x)[0], seq_len, hidden_dim)
    h_hat = timegan.generate_latent(z)

    with tf.GradientTape() as tape:
        y_real = timegan.discriminator(h)
        y_fake = timegan.discriminator(tf.stop_gradient(h_hat))

        d_loss_real = bce(tf.ones_like(y_real), y_real)
        d_loss_fake = bce(tf.zeros_like(y_fake), y_fake)
        d_loss = d_loss_real + d_loss_fake

    grads = tape.gradient(d_loss, timegan.discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(grads, timegan.discriminator.trainable_variables))

    return d_loss

In [70]:
# Suppose you have a model like timegan
# You create fake (zero) data to initialize it:

# 1. This calls the embedder + recovery to get x_tilde
_ = timegan(tf.zeros((1, seq_len, feature_dim)))

# 2. This initializes the generator + supervisor (latent space modules)
_ = timegan.generate_latent(tf.zeros((1, seq_len, hidden_dim)))

# 3. This initializes the discriminator
_ = timegan.discriminator(tf.zeros((1, seq_len, hidden_dim)))

In [71]:
# Checkpoint setup for saving and restoring
ckpt = tf.train.Checkpoint(
    generator=timegan.generator,
    supervisor=timegan.supervisor,
    discriminator=timegan.discriminator,
    embedder=timegan.embedder,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    embedder_optimizer=embedder_optimizer,
    supervisor_optimizer=supervisor_optimizer
)

ckpt_manager = tf.train.CheckpointManager(ckpt, './checkpoints/timegan', max_to_keep=3)

# Optionally restore the latest checkpoint if it exists
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print(f"✅ Restored from checkpoint: {ckpt_manager.latest_checkpoint}")
else:
    print("🆕 Initializing from scratch.")

🆕 Initializing from scratch.


In [74]:
X_train = tf.convert_to_tensor(X.astype(np.float32))

log_dir = "./logs/timegan"
summary_writer = tf.summary.create_file_writer(log_dir)

# Logging and checkpoints
log_every = 100
save_every = 1000
g_losses, d_losses, e_losses, s_losses = [], [], [], []

print("🚀 Starting TimeGAN training...")

for step in range(iterations):
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    x_batch = tf.gather(X_train, idx)
    x_batch.set_shape([batch_size, seq_len, feature_dim])

    # Phase 1: Embedder training
    if step < 1000:
        e_loss = train_embedder(x_batch)
        if step % log_every == 0:
            print(f"[Step {step}] Embedder loss: {e_loss.numpy():.4f}")
            e_losses.append((step, e_loss.numpy()))
            with summary_writer.as_default():
                tf.summary.scalar("Embedder Loss", e_loss, step=step)
        continue

    # Phase 2: Supervisor training
    if step < 2000:
        s_loss = train_supervisor(x_batch)
        if step % log_every == 0:
            print(f"[Step {step}] Supervisor loss: {s_loss.numpy():.4f}")
            s_losses.append((step, s_loss.numpy()))
            with summary_writer.as_default():
                tf.summary.scalar("Supervisor Loss", s_loss, step=step)
        continue

    # Phase 3: Joint training
    g_loss = train_generator(x_batch)
    d_loss = train_discriminator(x_batch)

    if step % log_every == 0:
        print(f"[Step {step}] Generator loss: {g_loss:.4f}, Discriminator loss: {d_loss:.4f}")
        g_losses.append((step, float(g_loss)))
        d_losses.append((step, float(d_loss)))
        with summary_writer.as_default():
            tf.summary.scalar("Generator Loss", g_loss, step=step)
            tf.summary.scalar("Discriminator Loss", d_loss, step=step)

    # Optional model saving
    if step % save_every == 0 and step != 0:
        ckpt_manager.save()
        print(f"✅ Saved checkpoint at step {step}")

print("🚀 TimeGAN training End")


🚀 Starting TimeGAN training...
[Step 0] Embedder loss: 0.5026
[Step 100] Embedder loss: 0.5029
[Step 200] Embedder loss: 0.4757
[Step 300] Embedder loss: 0.5031
[Step 400] Embedder loss: 0.5427
[Step 500] Embedder loss: 0.4926
[Step 600] Embedder loss: 0.5059
[Step 700] Embedder loss: 0.4908
[Step 800] Embedder loss: 0.4353
[Step 900] Embedder loss: 0.5116
[Step 1000] Supervisor loss: 0.0258
[Step 1100] Supervisor loss: 0.0246
[Step 1200] Supervisor loss: 0.0229
[Step 1300] Supervisor loss: 0.0198
[Step 1400] Supervisor loss: 0.0195
[Step 1500] Supervisor loss: 0.0186
[Step 1600] Supervisor loss: 0.0173
[Step 1700] Supervisor loss: 0.0172
[Step 1800] Supervisor loss: 0.0180
[Step 1900] Supervisor loss: 0.0164
[Step 2000] Generator loss: 0.7103, Discriminator loss: 1.4533
✅ Saved checkpoint at step 2000
[Step 2100] Generator loss: 0.6922, Discriminator loss: 1.3625
[Step 2200] Generator loss: 0.6984, Discriminator loss: 1.2727
[Step 2300] Generator loss: 0.7010, Discriminator loss: 1.20