# TimeGAN (PyTorch) — S&P 500 returns

**Paper:** Yoon et al., *Time-series Generative Adversarial Networks* (NeurIPS 2019)

**Pipeline**
- **Data:** `^GSPC` via `yfinance`
- **Target:** log-returns (stationary-ish), train/test split 80/20
- **Model:** GRU-based TimeGAN (Embedder/Recovery, Generator/Supervisor, Discriminator)
- **Training:** (1) Autoencoder pretrain, (2) Supervisor pretrain, (3) Adversarial joint train
- **Eval:** KS test, histogram, QQ-plot; price reconstruction from synthetic returns

In [29]:
%%capture
!pip -q install yfinance torch torchvision torchaudio
!pip -q install scipy statsmodels matplotlib

## Imports & Seed

In [30]:
import os, math, time, json, random
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp, probplot

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from scipy.stats import ks_2samp, skew, kurtosis, jarque_bera, wasserstein_distance
from statsmodels.stats.diagnostic import acorr_ljungbox



## Config

In [31]:
class CFG:
    ticker = "^GSPC"
    start = "1990-01-01"
    end = None
    seq_len = 128            # shorter = faster
    feat_dim = 1            # return dimension
    z_dim = 8               # noise latent per timestep
    hidden = 64
    num_layers = 1
    batch_size = 256
    ae_epochs = 100         # autoencoder pretrain
    sup_epochs = 150        # supervisor pretrain
    adv_epochs = 1500       # adversarial training (increase for quality)
    lr = 1e-3
    device = "cuda" if torch.cuda.is_available() else "cpu"
    outdir = "./timegan_outputs"

os.makedirs(CFG.outdir, exist_ok=True)
print(CFG.device)

cuda


## Data: S&P500 → log-returns → windows
- Split chronologically 80/20
- Standardize using **train** stats only (avoid leakage)
- Build a tail-aware sampler to upweight rare shocks (98th percentile of |returns|)

In [32]:
raw = yf.download(CFG.ticker, start=CFG.start, end=CFG.end, progress=False, auto_adjust=True)
close = raw["Close"].dropna()
logp = np.log(close)
rets = logp.diff().dropna()

# split 80/20 (chronological)
split_idx = int(len(rets)*0.8)
train = rets.iloc[:split_idx]
test = rets.iloc[split_idx:]

# scale using train stats (avoid leakage)
mu, sigma = train.mean(), train.std()
train_z = ((train - mu)/sigma).astype(np.float32)
test_z  = ((test - mu)/sigma).astype(np.float32)

# Also keep unscaled train returns to get tail threshold
train_unscaled = (train_z.values.astype(np.float32) * float(sigma) + float(mu))
q_abs = float(np.quantile(np.abs(train_unscaled), 0.98))  # 98th pct. = rare shocks
q_abs

  train_unscaled = (train_z.values.astype(np.float32) * float(sigma) + float(mu))


0.030811725184321404

## Dataset & DataLoader (Tail-weighted sampling)

In [33]:
class SeqDatasetTail(Dataset):
    def __init__(self, series_z: pd.Series, series_unscaled: np.ndarray, seq_len: int):
        self.x = series_z.values.astype(np.float32)
        self.u = series_unscaled.astype(np.float32)
        self.seq_len = seq_len
        self.idxs = np.arange(0, len(self.x) - seq_len + 1)
        alpha = 5.0
        w = []
        for j in self.idxs:
            win = self.u[j:j+seq_len]
            w.append(1.0 + (alpha if np.max(np.abs(win)) >= q_abs else 0.0))
        self.weights = np.array(w, dtype=np.float32)

    def __len__(self): return len(self.idxs)
    def __getitem__(self, i):
        j = self.idxs[i]
        win = self.x[j:j+self.seq_len]
        return torch.from_numpy(win).unsqueeze(-1)  # (T,1)

train_ds = SeqDatasetTail(train_z, train_unscaled, CFG.seq_len)
weights_t = torch.as_tensor(train_ds.weights, dtype=torch.double)
sampler = WeightedRandomSampler(weights_t, num_samples=len(train_ds.weights), replacement=True)
train_dl = DataLoader(train_ds, batch_size=CFG.batch_size, sampler=sampler, drop_last=True)

## TimeGAN modules (GRU-based)

In [34]:
def squash_to_3d(x, feat_dim=1, T=None):
    """Ensure shape (B, T, feat). Tolerates (B,T), (B,T,feat,1), etc."""
    if T is None:
        T = x.shape[1] if x.dim() >= 2 else 1
    while x.dim() > 3 and x.size(-1) == 1:
        x = x.squeeze(-1)
    if x.dim() == 2:
        x = x.unsqueeze(-1)
    if x.dim() != 3:
        B = x.size(0)
        x = x.view(B, T, feat_dim)
    return x.contiguous()

def to(x):
    return x.to(CFG.device)

def sample_z(batch, seq_len, z_dim):
    return torch.randn(batch, seq_len, z_dim, device=CFG.device)

class GRUSeq(nn.Module):
    def __init__(self, in_dim, out_dim, hidden, num_layers=1):
        super().__init__()
        self.in_dim = in_dim
        self.rnn = nn.GRU(input_size=in_dim, hidden_size=hidden,
                          num_layers=num_layers, batch_first=True)
        self.proj = nn.Linear(hidden, out_dim)
    def forward(self, x):
        x = squash_to_3d(x, feat_dim=self.in_dim)
        h, _ = self.rnn(x)
        return self.proj(h)

class Embedder(nn.Module):
    def __init__(self, x_dim, hidden, num_layers=1):
        super().__init__()
        self.net = GRUSeq(x_dim, hidden, hidden, num_layers)
    def forward(self, x):
        return self.net(x)

class Recovery(nn.Module):
    def __init__(self, hidden, x_dim, num_layers=1):
        super().__init__()
        self.net = GRUSeq(hidden, x_dim, hidden, num_layers)
    def forward(self, h):
        return self.net(h)

class Generator(nn.Module):
    def __init__(self, z_dim, hidden, num_layers=1):
        super().__init__()
        self.net = GRUSeq(z_dim, hidden, hidden, num_layers)
    def forward(self, z):
        return self.net(z)

class Supervisor(nn.Module):
    def __init__(self, hidden, num_layers=1):
        super().__init__()
        self.net = GRUSeq(hidden, hidden, hidden, num_layers)
    def forward(self, h):
        return self.net(h)

class Discriminator(nn.Module):
    def __init__(self, hidden, num_layers=1):
        super().__init__()
        self.rnn = nn.GRU(input_size=hidden, hidden_size=hidden, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden, 1)
    def forward(self, h):  # h: (B,T,H)
        o, _ = self.rnn(h)
        logit_t = self.head(o)        # (B,T,1)
        logit = logit_t.mean(dim=1)   # (B,1)
        return logit.squeeze(-1)      # (B,)

## Instantiate modules & optimizers

In [35]:
E = Embedder(CFG.feat_dim, CFG.hidden, CFG.num_layers).to(CFG.device)
R = Recovery(CFG.hidden, CFG.feat_dim, CFG.num_layers).to(CFG.device)
G = Generator(CFG.z_dim, CFG.hidden, CFG.num_layers).to(CFG.device)
S = Supervisor(CFG.hidden, CFG.num_layers).to(CFG.device)
D = Discriminator(CFG.hidden, CFG.num_layers).to(CFG.device)

params_ae = list(E.parameters()) + list(R.parameters())
opt_ae = torch.optim.Adam(params_ae, lr=CFG.lr)
opt_sup = torch.optim.Adam(S.parameters(), lr=CFG.lr)
opt_g   = torch.optim.Adam(list(G.parameters()) + list(S.parameters()), lr=CFG.lr)
opt_e   = torch.optim.Adam(E.parameters(), lr=CFG.lr)
opt_d   = torch.optim.Adam(D.parameters(), lr=CFG.lr)

bce = nn.BCEWithLogitsLoss()
mse = nn.MSELoss()

## Loss helpers

In [36]:
def sup_loss(H_real, H_sup):
    return mse(H_real[:,1:,:], H_sup[:,:-1,:])

def moment_loss(x, x_hat):
    m1 = x.mean(dim=1)
    s1 = x.std(dim=1)
    m2 = x_hat.mean(dim=1)
    s2 = x_hat.std(dim=1)
    return mse(m1, m2) + mse(s1, s2)

## Phase 1 — Autoencoder pretrain (E, R)

In [37]:
print("[Phase 1] Autoencoder pretrain...")
for epoch in range(1, CFG.ae_epochs + 1):
    E.train(); R.train()
    losses = []
    for xb in train_dl:
        xb = squash_to_3d(to(xb), feat_dim=CFG.feat_dim, T=CFG.seq_len)  # (B,T,1)
        H = E(xb)
        X_tilde = R(H)
        L_ae = mse(xb, X_tilde)
        opt_ae.zero_grad(); L_ae.backward(); opt_ae.step()
        losses.append(L_ae.item())
    if epoch % 10 == 0 or epoch == 1:
        print(f"AE epoch {epoch:03d} | recon MSE: {np.mean(losses):.5f}")

[Phase 1] Autoencoder pretrain...
AE epoch 001 | recon MSE: 1.16919
AE epoch 010 | recon MSE: 0.00678
AE epoch 020 | recon MSE: 0.00091
AE epoch 030 | recon MSE: 0.00031
AE epoch 040 | recon MSE: 0.00013
AE epoch 050 | recon MSE: 0.00009
AE epoch 060 | recon MSE: 0.00007
AE epoch 070 | recon MSE: 0.00036
AE epoch 080 | recon MSE: 0.00045
AE epoch 090 | recon MSE: 0.00004
AE epoch 100 | recon MSE: 0.00030


## Phase 2 — Supervisor pretrain (S)

In [38]:
print("[Phase 2] Supervisor pretrain...")
for epoch in range(1, CFG.sup_epochs + 1):
    E.train(); S.train()
    losses = []
    for xb in train_dl:
        xb = squash_to_3d(to(xb), feat_dim=CFG.feat_dim, T=CFG.seq_len)
        H = E(xb).detach()
        H_sup = S(H)
        L_s = sup_loss(H, H_sup)
        opt_sup.zero_grad(); L_s.backward(); opt_sup.step()
        losses.append(L_s.item())
    if epoch % 10 == 0 or epoch == 1:
        print(f"SUP epoch {epoch:03d} | sup MSE: {np.mean(losses):.5f}")

[Phase 2] Supervisor pretrain...
SUP epoch 001 | sup MSE: 0.03766
SUP epoch 010 | sup MSE: 0.01181
SUP epoch 020 | sup MSE: 0.01106
SUP epoch 030 | sup MSE: 0.01040
SUP epoch 040 | sup MSE: 0.00922
SUP epoch 050 | sup MSE: 0.00829
SUP epoch 060 | sup MSE: 0.00742
SUP epoch 070 | sup MSE: 0.00650
SUP epoch 080 | sup MSE: 0.00578
SUP epoch 090 | sup MSE: 0.00506
SUP epoch 100 | sup MSE: 0.00451
SUP epoch 110 | sup MSE: 0.00410
SUP epoch 120 | sup MSE: 0.00383
SUP epoch 130 | sup MSE: 0.00355
SUP epoch 140 | sup MSE: 0.00333
SUP epoch 150 | sup MSE: 0.00315


## Phase 3 — Adversarial joint training

In [39]:
print("[Phase 3] Adversarial training...")
for epoch in range(1, CFG.adv_epochs + 1):
    G.train(); S.train(); D.train(); E.train(); R.train()
    loss_d_hist, loss_g_hist = [], []
    for xb in train_dl:
        xb = squash_to_3d(to(xb), feat_dim=CFG.feat_dim, T=CFG.seq_len)
        B = xb.size(0)
        ones = torch.ones(B, device=CFG.device)
        zeros = torch.zeros(B, device=CFG.device)

        # ---- Update D ----
        H_real = E(xb).detach()
        H_real_sup = S(H_real).detach()
        z = sample_z(B, CFG.seq_len, CFG.z_dim)
        H_gen = G(z).detach()
        H_hat = S(H_gen).detach()

        D_real = D(H_real)
        D_fake = D(H_hat)
        loss_d = bce(D_real, ones) + bce(D_fake, zeros)
        opt_d.zero_grad(); loss_d.backward(); opt_d.step()
        loss_d_hist.append(loss_d.item())

        # ---- Update G (+S) ----
        z = sample_z(B, CFG.seq_len, CFG.z_dim)
        H_gen = G(z)
        H_hat = S(H_gen)
        X_hat = R(H_hat)

        # adversarial
        D_fake_for_g = D(H_hat)
        L_g_adv = bce(D_fake_for_g, ones)

        # moment matching on decoded fakes vs real
        L_mom = moment_loss(xb, X_hat)

        # supervised on synthetic hidden (TimeGAN)
        L_s = sup_loss(H_gen, H_hat)

        # Tail-aware loss (use fixed q_abs from train set)
        x_u  = xb * float(sigma) + float(mu)
        xh_u = X_hat * float(sigma) + float(mu)
        mask = (x_u.abs() >= q_abs).float()
        num  = mask.sum() + 1e-6
        L_tail = ((x_u - xh_u)**2 * mask).sum() / num

        # Total generator loss (coefficients to tune as needed)
        L_g_total = L_g_adv + 100.0 * L_s + 1.0 * L_mom + 10.0 * L_tail
        opt_g.zero_grad(); L_g_total.backward(); opt_g.step()
        loss_g_hist.append(L_g_total.item())

        # ---- Update Embedder (fresh graph)
        H_real_e = E(xb)
        X_tilde_e = R(H_real_e)
        L_r_e = mse(xb, X_tilde_e)
        H_sup_real_e = S(H_real_e)
        L_s_e = sup_loss(H_real_e, H_sup_real_e)
        L_e = L_r_e + 0.1 * L_s_e
        opt_e.zero_grad(); L_e.backward(); opt_e.step()

    if epoch % 20 == 0 or epoch == 1:
        print(f"ADV epoch {epoch:03d} | D: {np.mean(loss_d_hist):.4f} | G_total: {np.mean(loss_g_hist):.4f}")

[Phase 3] Adversarial training...


  x_u  = xb * float(sigma) + float(mu)
  xh_u = X_hat * float(sigma) + float(mu)


ADV epoch 001 | D: 1.0555 | G_total: 2.3000
ADV epoch 020 | D: 0.0291 | G_total: 5.4951
ADV epoch 040 | D: 0.0076 | G_total: 6.1763
ADV epoch 060 | D: 0.2746 | G_total: 5.4841
ADV epoch 080 | D: 1.2614 | G_total: 1.8001
ADV epoch 100 | D: 1.1975 | G_total: 1.8463
ADV epoch 120 | D: 1.1834 | G_total: 1.7228
ADV epoch 140 | D: 1.2238 | G_total: 2.1321
ADV epoch 160 | D: 1.2393 | G_total: 1.6166
ADV epoch 180 | D: 1.7273 | G_total: 2.2268
ADV epoch 200 | D: 1.2261 | G_total: 1.5327
ADV epoch 220 | D: 1.3887 | G_total: 1.3395
ADV epoch 240 | D: 1.3805 | G_total: 1.7893
ADV epoch 260 | D: 1.3064 | G_total: 1.1702
ADV epoch 280 | D: 1.3097 | G_total: 1.2171
ADV epoch 300 | D: 1.3838 | G_total: 1.2149
ADV epoch 320 | D: 1.4142 | G_total: 1.1430
ADV epoch 340 | D: 1.2929 | G_total: 1.1962
ADV epoch 360 | D: 1.3470 | G_total: 1.2431
ADV epoch 380 | D: 1.3640 | G_total: 1.0412
ADV epoch 400 | D: 1.3300 | G_total: 1.0590
ADV epoch 420 | D: 1.3649 | G_total: 1.2674
ADV epoch 440 | D: 1.3653 | G_to

## Sampling synthetic returns & simple evaluation
- KS test (real test returns vs synthetic)
- Histograms
- QQ-plots vs Normal

In [40]:
@torch.no_grad()
def sample_synthetic(n_seq=200):
    G.eval(); S.eval(); R.eval()
    out = []
    for _ in range(n_seq):
        z = sample_z(1, CFG.seq_len, CFG.z_dim)
        H_gen = G(z)
        H_hat = S(H_gen)
        X_hat = R(H_hat)
        out.append(X_hat.squeeze(0).cpu().numpy())
    return np.stack(out, axis=0)  # (n_seq, T, 1)

synth = sample_synthetic(n_seq=500).squeeze(-1)  # (N,T)
real_test = np.asarray(test_z, dtype=np.float32).reshape(-1)
synth_flat = np.asarray(synth, dtype=np.float32).reshape(-1)
mu, sigma = float(mu), float(sigma)
real_unscaled  = real_test * sigma + mu
synth_unscaled = synth_flat[: real_test.size] * sigma + mu

def metrics(x: np.ndarray):
    x = np.asarray(x, dtype=np.float64)
    out = {
        "mean": float(x.mean()),
        "std": float(x.std(ddof=1)),
        "skew": float(skew(x, bias=False)),
        "kurtosis_excess": float(kurtosis(x, fisher=True, bias=False)),
    }
    jb_stat, jb_p = jarque_bera(x)
    out["JB_stat"] = float(jb_stat); out["JB_p"] = float(jb_p)
    for q in (0.01, 0.05, 0.50, 0.95, 0.99):
        out[f"Q{int(q*100)}"] = float(np.quantile(x, q))
    # Risk tails (lower tail)
    var95 = np.quantile(x, 0.05); var99 = np.quantile(x, 0.01)
    out["VaR_95"] = float(var95); out["VaR_99"] = float(var99)
    out["ES_95"] = float(x[x <= var95].mean()) if (x <= var95).any() else np.nan
    out["ES_99"] = float(x[x <= var99].mean()) if (x <= var99).any() else np.nan
    # ACF(1) for returns and |returns|
    def acf1(y):
        y = y - y.mean()
        return float(np.corrcoef(y[1:], y[:-1])[0,1]) if len(y) > 1 else np.nan
    out["acf1_returns"] = acf1(x)
    out["acf1_abs_returns"] = acf1(np.abs(x))
    # Ljung–Box at lag 10 (returns & |returns|)
    try:
        lb_ret = acorr_ljungbox(x, lags=[10], return_df=True)
        out["LB_Q_ret_lag10"] = float(lb_ret["lb_stat"].iloc[0])
        out["LB_p_ret_lag10"] = float(lb_ret["lb_pvalue"].iloc[0])
    except Exception:
        out["LB_Q_ret_lag10"] = np.nan; out["LB_p_ret_lag10"] = np.nan
    try:
        lb_abs = acorr_ljungbox(np.abs(x), lags=[10], return_df=True)
        out["LB_Q_abs_lag10"] = float(lb_abs["lb_stat"].iloc[0])
        out["LB_p_abs_lag10"] = float(lb_abs["lb_pvalue"].iloc[0])
    except Exception:
        out["LB_Q_abs_lag10"] = np.nan; out["LB_p_abs_lag10"] = np.nan
    return out

# 3) Compute stats tables + distances (the “statistical value” complementing the histogram)
real_stats  = metrics(real_unscaled)
synth_stats = metrics(synth_unscaled)

df_stats = pd.DataFrame({"Real": real_stats, "Synthetic": synth_stats})
df_stats["Diff (Synth-Real)"] = df_stats["Synthetic"] - df_stats["Real"]

ks_stat, ks_p = ks_2samp(real_unscaled, synth_unscaled)
w_dist = wasserstein_distance(real_unscaled, synth_unscaled)

print(f"KS stat = {ks_stat:.4f}  (p = {ks_p:.4f})")
print(f"Wasserstein distance (1D EMD) = {w_dist:.6f}")
display(df_stats)

# Optional: save the table next to your figures
os.makedirs(CFG.outdir, exist_ok=True)
df_stats.to_csv(os.path.join(CFG.outdir, "timegan_stats_comparison.csv"))

ks_stat, pval = ks_2samp(real_unscaled, synth_unscaled)
print(f"KS test (test returns vs synthetic): stat={ks_stat:.4f}, p={pval:.4f}")

plt.figure(figsize=(9,5))
plt.hist(real_unscaled, bins=80, alpha=0.6, density=True, label="Réel (test)")
plt.hist(synth_unscaled, bins=80, alpha=0.6, density=True, label="Synthétique")
plt.title("Histogramme des rendements — Réel vs Synthétique (TimeGAN)")
plt.xlabel("Rendement journalier")
plt.ylabel("Densité")
plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "timegan_hist.png")); plt.close()

plt.figure(figsize=(5,5)); probplot(real_unscaled, dist="norm", plot=plt)
plt.title("QQ-plot (Réel vs Normale)"); plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "timegan_qq_real.png")); plt.close()

plt.figure(figsize=(5,5)); probplot(synth_unscaled, dist="norm", plot=plt)
plt.title("QQ-plot (Synthétique vs Normale)"); plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "timegan_qq_synth.png")); plt.close()

KS stat = 0.0608  (p = 0.0027)
Wasserstein distance (1D EMD) = 0.001886


  mu, sigma = float(mu), float(sigma)


Unnamed: 0,Real,Synthetic,Diff (Synth-Real)
mean,0.0004756263,0.0008842929,0.0004086666
std,0.01272047,0.01189536,-0.0008251089
skew,-0.6306668,-0.2855872,0.3450796
kurtosis_excess,14.68291,0.5855083,-14.0974
JB_stat,16136.97,49.53799,-16087.44
JB_p,0.0,1.749693e-11,1.749693e-11
Q1,-0.0352582,-0.03086625,0.004391948
Q5,-0.01856772,-0.01985536,-0.001287633
Q50,0.0008841069,0.001374868,0.0004907613
Q95,0.01676704,0.01987438,0.003107338


KS test (test returns vs synthetic): stat=0.0608, p=0.0027


## Price reconstruction from synthetic returns

In [41]:
@torch.no_grad()
def synth_prices(n=3, base_price=None):
    if base_price is None:
        base_price = float(close.iloc[int(len(close)*0.8)])
    seq = sample_synthetic(n_seq=n).squeeze(-1)
    rets_unscaled = seq * sigma + mu
    P0 = float(base_price)
    prices = []
    for i in range(n):
        rp = rets_unscaled[i]
        lp = np.cumsum(rp)
        p = P0 * np.exp(lp)
        prices.append(p)
    return np.array(prices)

paths = synth_prices(n=5)
plt.figure(figsize=(10,6))
for i in range(paths.shape[0]):
    plt.plot(paths[i], alpha=0.8)
plt.title("Prix synthétiques reconstruits (TimeGAN)")
plt.xlabel("t (jours)"); plt.ylabel("Prix")
plt.grid(True, alpha=0.3); plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "timegan_price_paths.png")); plt.close()
print("Done. Figures saved in:", os.path.abspath(CFG.outdir))

Done. Figures saved in: /content/timegan_outputs


  base_price = float(close.iloc[int(len(close)*0.8)])


## Notes
- Consider gradient clipping (e.g., 1.0–5.0) if you see instability.
- Label smoothing for D targets can help (e.g., real=0.9).
- For deeper validation, compare ACF/PACF and Ljung–Box tests on returns and |returns| for real vs synthetic.