
# RNN-GAN (LSTM) — Génération de séries temporelles du S&P 500

**Objectif :** entraîner un WGAN-GP avec un générateur et un critique LSTM pour générer des **rendements** puis tracer des **courbes de prix** synthétiques, et comparer des **statistiques clés** entre série réelle et synthétique **à la fin de 150 epochs**.

**Points clés :**
- Données : `^GSPC` via `yfinance` (depuis 1990).
- Prétraitement : log-rendements, normalisation avec **stats d'entraînement** (pas de fuite d'info).
- Modèle : WGAN-GP (LSTM Generator + LSTM Critic avec pooling temporel).
- Hyperparamètres (par défaut) : `nz=100`, `epochs=150`, `lr_G=4e-4`, `lr_D=2e-3`, split 80%/20%.
- Éval : KS test, histogrammes, QQ-plot, **nuage de courbes de prix**, **résumé statistique** (moyenne, std, skew, kurtosis, Sharpe, VaR, autocorr lag1, Ljung–Box sur r², max drawdown, CAGR).
- Snapshots : trajectoires générées enregistrées pendant l'entraînement.

> Astuce : si instabilité, baisse `lr_D` ou augmente `n_critic`. Réduis `batch_size` si mémoire GPU limite.


# Install (Colab recommandé)


In [1]:
# Si vous n'êtes pas sur Colab, adaptez les versions de torch à votre environnement.

!pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip -q install scipy statsmodels matplotlib pandas numpy yfinance


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m780.5/780.5 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m100.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m124.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.6/410.6 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.6/121.6 MB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 MB[0m [31m43.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━

# Imports & Config

In [2]:


import os
import math
import time
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp, probplot, skew, kurtosis

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from statsmodels.stats.diagnostic import acorr_ljungbox

class CFG:
    ticker = "^GSPC"
    start = "1990-01-01"
    end = None  # jusqu'à aujourd'hui
    seq_len = 64                # longueur des fenêtres
    nz = 100                    # dimension du bruit
    batch_size = 512            # réduisez si OOM
    epochs = 150                # demandé : 150 (rapide). Mettre 2000 pour un entraînement long.
    lr_G = 4e-4
    lr_D = 1e-3                 # plus stable qu'un 2e-2 agressif sur RNN-GAN
    n_critic = 8                # updates D par update G (WGAN-GP)
    lambda_gp = 5.0            # gradient penalty
    device = "cuda" if torch.cuda.is_available() else "cpu"
    outdir = "./rnn_gan_outputs"
    plot_epochs = [10, 50, 100, 150]  # snapshots pendant l'entraînement
    use_ema = True              # utilise une EMA du générateur pour l'échantillonnage

os.makedirs(CFG.outdir, exist_ok=True)


# Données : S&P 500 -> log-rendements

In [3]:

raw = yf.download(CFG.ticker, start=CFG.start, end=CFG.end, progress=False)
close = raw["Close"].dropna()
log_price = np.log(close)
rets = log_price.diff().dropna()

# Split temporel 80/20 (pas de shuffle)
split_idx = int(len(rets) * 0.8)
train_rets = rets.iloc[:split_idx]
test_rets  = rets.iloc[split_idx:]

# Normalisation (stats d'entraînement uniquement, pour éviter toute fuite)
mu, sigma = train_rets.mean(), train_rets.std()
train_z = (train_rets - mu) / sigma
test_z  = (test_rets - mu) / sigma

print(f"Observations totales: {len(rets)} | Train: {len(train_rets)} | Test: {len(test_rets)}")
print(f"mu(train)={mu.iloc[0]:.6f}, sigma(train)={sigma.iloc[0]:.6f}")

  raw = yf.download(CFG.ticker, start=CFG.start, end=CFG.end, progress=False)


Observations totales: 8969 | Train: 7175 | Test: 1794
mu(train)=0.000284, sigma(train)=0.011075


# Fenêtrage en séquences

In [4]:


class SeqDataset(Dataset):
    def __init__(self, series: pd.Series, seq_len: int):
        self.x = series.values.astype(np.float32)
        self.seq_len = seq_len
        self.idxs = np.arange(0, len(self.x) - seq_len + 1)
    def __len__(self):
        return len(self.idxs)
    def __getitem__(self, i):
        idx = self.idxs[i]
        window = self.x[idx: idx + self.seq_len]
        return torch.from_numpy(window).unsqueeze(-1)  # (seq_len, 1)

train_ds = SeqDataset(train_z, CFG.seq_len)
train_dl = DataLoader(
    train_ds,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,                          # si souci, passer à 0
    pin_memory=(CFG.device=='cuda'),
)


# Helpers

In [5]:

def ensure_3d(x, T=CFG.seq_len):
    """Force la forme (B, T, 1). Tolère (B,T), (B,T,1), (B,T,1,1)."""
    if x.dim() == 4 and x.size(-1) == 1:
        x = x.squeeze(-1)
    if x.dim() == 2:
        x = x.unsqueeze(-1)
    if x.dim() != 3:
        # Dernier recours : reshape plat -> (B, T, 1) si possible
        B = x.size(0)
        x = x.view(B, T, -1)
        if x.size(-1) != 1:
            x = x[..., :1]
    return x

def reconstruct_price(p0, rets):
    """Reconstruit une courbe de prix à partir de rendements log."""
    return p0 * np.exp(np.cumsum(rets))

def max_drawdown(prices):
    prices = np.asarray(prices, dtype=float)
    run_max = np.maximum.accumulate(prices)
    drawdown = prices / run_max - 1.0
    return float(drawdown.min())

def annualized_vol(returns, periods_per_year=252):
    return float(np.std(returns, ddof=1) * np.sqrt(periods_per_year))

def sharpe(returns, rf=0.0, periods_per_year=252):
    ex = returns - rf / periods_per_year
    sd = np.std(ex, ddof=1)
    return float(np.mean(ex) / (sd + 1e-12) * np.sqrt(periods_per_year))

def cagr(prices, periods_per_year=252):
    return float((prices[-1] / prices[0]) ** (periods_per_year / len(prices)) - 1.0)


# Modèles : LSTM Generator & Critic (pooling temporel)

In [6]:

class LSTMGenerator(nn.Module):
    def __init__(self, nz=100, hidden=64, num_layers=1, out_dim=1):
        super().__init__()
        self.nz = nz
        self.lstm = nn.LSTM(input_size=nz, hidden_size=hidden, num_layers=num_layers, batch_first=True)
        self.proj = nn.Linear(hidden, out_dim)
    def forward(self, z):  # z: (B, T, nz)
        h, _ = self.lstm(z)
        out = self.proj(h)  # (B, T, 1)
        return out

class LSTMCritic(nn.Module):
    def __init__(self, in_dim=1, hidden=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hidden, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden, 1)
    def forward(self, x):  # x: (B, T, 1)
        x = ensure_3d(x)
        h, _ = self.lstm(x)           # (B,T,H)
        pooled = h.mean(dim=1)        # moyenne temporelle
        score = self.head(pooled)
        return score.squeeze(-1)

def init_module(m):
    if isinstance(m, (nn.Linear,)):
        nn.init.orthogonal_(m.weight); nn.init.zeros_(m.bias)
    if isinstance(m, nn.LSTM):
        for name, p in m.named_parameters():
            if "weight" in name: nn.init.orthogonal_(p)
            elif "bias" in name: nn.init.zeros_(p)

G = LSTMGenerator(nz=CFG.nz).to(CFG.device)
D = LSTMCritic().to(CFG.device)
G.apply(init_module); D.apply(init_module)

# EMA pour G (optionnel)
G_ema = None
if CFG.use_ema:
    G_ema = LSTMGenerator(nz=CFG.nz).to(CFG.device)
    G_ema.load_state_dict(G.state_dict())
    for p in G_ema.parameters(): p.requires_grad_(False)

def ema_update(model, ema_model, decay=0.999):
    if ema_model is None: return
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters()):
            p_ema.copy_(decay*p_ema + (1-decay)*p)


# Optimisation (WGAN-GP, TTUR)

In [7]:

opt_G = torch.optim.Adam(G.parameters(), lr=CFG.lr_G, betas=(0.5, 0.9))
opt_D = torch.optim.Adam(D.parameters(), lr=CFG.lr_D, betas=(0.5, 0.9))

def gradient_penalty(D, real, fake):
    real = ensure_3d(real)
    fake = ensure_3d(fake)
    B = real.size(0)
    eps = torch.rand(B, 1, 1, device=real.device)
    interp = eps * real + (1 - eps) * fake
    interp.requires_grad_(True)

    # Désactiver cuDNN pendant ce forward du Critic (double backward requis)
    with torch.backends.cudnn.flags(enabled=False):
        d_interp = D(interp)

    grads = torch.autograd.grad(
        outputs=d_interp,
        inputs=interp,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    grads = grads.view(B, -1)
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return gp

@torch.no_grad()
def sample_sequences(model, n_seq=8, seq_len=CFG.seq_len, nz=CFG.nz):
    model.eval()
    z = torch.randn(n_seq, seq_len, nz, device=CFG.device)
    x_fake = model(z).squeeze(-1).detach().cpu().numpy()
    return x_fake


# Entraînement

In [8]:

loss_log = []
start_time = time.time()
p0_train = float(close.iloc[split_idx - 1])  # base price pour visualiser des prix pendant l'entraînement

for epoch in range(1, CFG.epochs + 1):
    G.train(); D.train()
    for x_real in train_dl:
        x_real = ensure_3d(x_real.to(CFG.device))

        # --- Update D (n_critic fois) ---
        for _ in range(CFG.n_critic):
            z = torch.randn(x_real.size(0), CFG.seq_len, CFG.nz, device=CFG.device)
            x_fake = G(z).detach()
            d_real = D(x_real)
            d_fake = D(x_fake)
            gp = gradient_penalty(D, x_real, x_fake)
            loss_D = -(d_real.mean() - d_fake.mean()) + CFG.lambda_gp * gp

            opt_D.zero_grad()
            loss_D.backward()
            torch.nn.utils.clip_grad_norm_(D.parameters(), 5.0)
            opt_D.step()

        # --- Update G ---
        z = torch.randn(x_real.size(0), CFG.seq_len, CFG.nz, device=CFG.device)
        x_fake = G(z)
        d_fake = D(x_fake)
        loss_G = -d_fake.mean()

        opt_G.zero_grad()
        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(G.parameters(), 5.0)
        opt_G.step()

        if CFG.use_ema:
            ema_update(G, G_ema)

    loss_log.append((epoch, float(loss_D.detach().cpu()), float(loss_G.detach().cpu())))
    print(f"Epoch {epoch}/{CFG.epochs} | D: {loss_log[-1][1]:.4f} | G: {loss_log[-1][2]:.4f}")

    # Snapshots : courbes de rendements + prix synthétiques
    if epoch in set(CFG.plot_epochs) or epoch % 50 == 0 or epoch == CFG.epochs:
        with torch.no_grad():
            model_for_samples = G_ema if (CFG.use_ema and G_ema is not None) else G
            samples = sample_sequences(model_for_samples, n_seq=8)
            # Convertit en rendements "réels"
            samples_unscaled = samples * float(sigma) + float(mu)
            # Reconstruit des prix
            price_paths = [reconstruct_price(p0_train, s) for s in samples_unscaled]

        # Plot rendements
        plt.figure(figsize=(10, 6))
        for s in samples: plt.plot(s, alpha=0.8)
        plt.title(f"Trajectoires de rendements (normalisés) — Epoch {epoch}")
        plt.xlabel("t"); plt.ylabel("r normalisés")
        plt.grid(True, alpha=0.3); plt.tight_layout()
        plt.savefig(os.path.join(CFG.outdir, f"samples_returns_epoch_{epoch}.png")); plt.close()

        # Plot prix
        plt.figure(figsize=(10, 6))
        for s in price_paths: plt.plot(s, alpha=0.8)
        plt.title(f"Trajectoires de prix synthétiques — Epoch {epoch}")
        plt.xlabel("t"); plt.ylabel("Prix")
        plt.grid(True, alpha=0.3); plt.tight_layout()
        plt.savefig(os.path.join(CFG.outdir, f"samples_prices_epoch_{epoch}.png")); plt.close()

elapsed = time.time() - start_time
print(f"Entraînement terminé en {elapsed/60:.1f} min. Dernières pertes — D: {loss_log[-1][1]:.3f}, G: {loss_log[-1][2]:.3f}")


  p0_train = float(close.iloc[split_idx - 1])  # base price pour visualiser des prix pendant l'entraînement
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 1/150 | D: 0.4126 | G: 0.1628
Epoch 2/150 | D: -2.8589 | G: 1.7441
Epoch 3/150 | D: -3.8901 | G: 2.1164
Epoch 4/150 | D: -3.8552 | G: 2.7397
Epoch 5/150 | D: -4.1879 | G: 3.6797
Epoch 6/150 | D: -4.3392 | G: 1.9369
Epoch 7/150 | D: -3.9729 | G: 1.9053
Epoch 8/150 | D: -3.6351 | G: 1.9948
Epoch 9/150 | D: -2.4980 | G: 3.1477
Epoch 10/150 | D: -1.9288 | G: 5.6667


  samples_unscaled = samples * float(sigma) + float(mu)


Epoch 11/150 | D: -2.6211 | G: 6.1788
Epoch 12/150 | D: -2.8310 | G: 6.1326
Epoch 13/150 | D: -2.9944 | G: 8.1633
Epoch 14/150 | D: -3.0740 | G: 8.8858
Epoch 15/150 | D: -2.9711 | G: 8.3151
Epoch 16/150 | D: -3.6607 | G: 8.8994
Epoch 17/150 | D: -3.1597 | G: 9.1012
Epoch 18/150 | D: -2.6937 | G: 9.4539
Epoch 19/150 | D: -2.5901 | G: 10.1670
Epoch 20/150 | D: -2.2264 | G: 10.7636
Epoch 21/150 | D: -2.3435 | G: 10.3151
Epoch 22/150 | D: -2.0709 | G: 12.4177
Epoch 23/150 | D: -2.2086 | G: 12.4751
Epoch 24/150 | D: -2.0540 | G: 13.0757
Epoch 25/150 | D: -2.0142 | G: 12.9122
Epoch 26/150 | D: -1.6411 | G: 13.3154
Epoch 27/150 | D: -1.6941 | G: 13.1164
Epoch 28/150 | D: -1.9086 | G: 12.7462
Epoch 29/150 | D: -1.7806 | G: 13.3312
Epoch 30/150 | D: -2.1999 | G: 12.9028
Epoch 31/150 | D: -2.1134 | G: 13.4187
Epoch 32/150 | D: -2.1283 | G: 12.7002
Epoch 33/150 | D: -2.1940 | G: 13.8119
Epoch 34/150 | D: -1.9706 | G: 13.4579
Epoch 35/150 | D: -1.5385 | G: 13.3446
Epoch 36/150 | D: -2.1546 | G: 14

# Évaluation : KS test, histogrammes, QQ-plot

In [9]:

# Génère une séquence synthétique de la taille du test
G.eval()
model_for_eval = G_ema if (CFG.use_ema and G_ema is not None) else G

with torch.no_grad():
    T = len(test_z)
    z = torch.randn(1, T, CFG.nz, device=CFG.device)
    fake_norm = model_for_eval(z).squeeze().detach().cpu().numpy()   # rendements normalisés
    fake_eval = fake_norm * float(sigma) + float(mu)                 # rendements "réels" (log)

# 1D arrays
real_test = np.asarray(test_z, dtype=np.float32).reshape(-1)
real_test_unscaled  = real_test * float(sigma) + float(mu)
fake_test_unscaled  = fake_eval.astype(np.float32).reshape(-1)[: real_test.size]

# KS test (deux échantillons)
ks_stat, pval = ks_2samp(real_test_unscaled, fake_test_unscaled)
print(f"KS test: statistic={ks_stat:.4f}, p-value={pval:.4f}")

# Histogrammes comparés
plt.figure(figsize=(9, 5))
plt.hist(real_test_unscaled, bins=80, alpha=0.6, density=True, label="Réel (test)")
plt.hist(fake_test_unscaled, bins=80, alpha=0.6, density=True, label="Synthétique")
plt.title("Histogramme des rendements — Réel vs Synthétique")
plt.xlabel("Rendement journalier")
plt.ylabel("Densité")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "hist_reel_vs_synth.png"))
plt.close()

# QQ-plot contre normale
plt.figure(figsize=(5,5))
probplot(real_test_unscaled, dist="norm", plot=plt)
plt.title("QQ-plot (Réel vs Normale)")
plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "qqplot_real_norm.png"))
plt.close()

plt.figure(figsize=(5,5))
probplot(fake_test_unscaled, dist="norm", plot=plt)
plt.title("QQ-plot (Synthétique vs Normale)")
plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "qqplot_fake_norm.png"))
plt.close()


KS test: statistic=0.1159, p-value=0.0000


  fake_eval = fake_norm * float(sigma) + float(mu)                 # rendements "réels" (log)
  real_test_unscaled  = real_test * float(sigma) + float(mu)


# Prix : reconstruction + comparaisons statistiques (réel vs synthétique)


In [10]:

# Nuage de k trajectoires synthétiques sur la longueur du test
k = 20
T = len(test_z)
fake_mat = []
with torch.no_grad():
    for _ in range(k):
        z = torch.randn(1, T, CFG.nz, device=CFG.device)
        seq_norm = model_for_eval(z).squeeze().detach().cpu().numpy()  # rendements normalisés
        seq = seq_norm * float(sigma) + float(mu)                       # rendements log réels
        fake_mat.append(seq)
fake_mat = np.stack(fake_mat, axis=0)  # (k, T)

# Reconstruit les prix réel et synthétiques à partir du dernier prix du train
p0 = float(close.iloc[split_idx - 1])
real_price = reconstruct_price(p0, real_test_unscaled)
fake_prices = np.array([reconstruct_price(p0, f) for f in fake_mat])

# === Plot : réel vs nuage de trajectoires synthétiques ===
plt.figure(figsize=(10, 6))
for fp in fake_prices:
    plt.plot(fp, alpha=0.45, linewidth=1)
plt.plot(real_price, linewidth=2.5, label="Réel (test)")
plt.title("Prix synthétiques vs prix réel — fenêtre de test")
plt.xlabel("t")
plt.ylabel("Prix")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "prices_real_vs_synth.png"))
plt.close()

# Choisit une trajectoire synthétique "typique" (volatilité médiane)
vols = [annualized_vol(f) for f in fake_mat]
idx_med = int(np.argsort(vols)[len(vols)//2])
fake_returns_typ = fake_mat[idx_med]
fake_price_typ   = fake_prices[idx_med]

def lbq_p_r2(returns, lag=10):
    df = acorr_ljungbox(returns**2, lags=[lag], return_df=True)
    return float(df["lb_pvalue"].iloc[-1])

# Tableau de stats : réel vs synthétique (rendements + prix)
def summarize(returns, prices, name):
    return {
        "serie": name,
        "mean": float(np.mean(returns)),
        "std": float(np.std(returns, ddof=1)),
        "skew": float(skew(returns)),
        "kurt_excess": float(kurtosis(returns, fisher=True)),
        "sharpe_ann": sharpe(returns),
        "vol_ann": annualized_vol(returns),
        "VaR_95": float(np.quantile(returns, 0.05)),
        "VaR_99": float(np.quantile(returns, 0.01)),
        "autocorr_lag1": float(pd.Series(returns).autocorr(1)),
        "LBQ_p_r2_lag10": lbq_p_r2(returns, lag=10),
        "max_drawdown": max_drawdown(prices),
        "CAGR": cagr(prices),
    }

summary_df = pd.DataFrame([
    summarize(real_test_unscaled, real_price, "Réel (test)"),
    summarize(fake_returns_typ,   fake_price_typ, "Synthétique (vol médiane)"),
])

# Ajoute le KS déjà calculé pour info
summary_df["KS_statistic"] = [ks_stat, ks_stat]
summary_df["KS_pvalue"]    = [pval, pval]

# Sauvegardes
csv_path = os.path.join(CFG.outdir, "summary_stats.csv")
summary_df.to_csv(csv_path, index=False)
print("Résumé statistique (arrondi) :")
print(summary_df.round(6))

# Histogramme des prix finaux (répartition des niveaux atteints)
plt.figure(figsize=(8,5))
plt.hist([fp[-1] for fp in fake_prices], bins=40, alpha=0.7, density=True, label="Synthétique (prix final)")
plt.axvline(real_price[-1], linestyle="--", label="Réel (prix final)")
plt.title("Distribution du prix final — Synthétique vs Réel")
plt.xlabel("Prix final")
plt.ylabel("Densité")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "final_price_distribution.png"))
plt.close()

print(f"Stats sauvegardées : {csv_path}")


  seq = seq_norm * float(sigma) + float(mu)                       # rendements log réels
  p0 = float(close.iloc[split_idx - 1])


Résumé statistique (arrondi) :
                       serie      mean       std      skew  kurt_excess  \
0                Réel (test)  0.000476  0.012720 -0.630139    14.638681   
1  Synthétique (vol médiane)  0.000172  0.006494 -0.190256     0.950450   

   sharpe_ann   vol_ann    VaR_95    VaR_99  autocorr_lag1  LBQ_p_r2_lag10  \
0    0.593558  0.201931 -0.018568 -0.035258      -0.155738        0.000000   
1    0.419566  0.103092 -0.010741 -0.017781       0.008898        0.000041   

   max_drawdown      CAGR  KS_statistic  KS_pvalue  
0     -0.339250  0.129527      0.115942        0.0  
1     -0.121176  0.042755      0.115942        0.0  
Stats sauvegardées : ./rnn_gan_outputs/summary_stats.csv


# Sauvegarde des pertes + courbes

In [11]:


log_df = pd.DataFrame(loss_log, columns=["epoch", "loss_D", "loss_G"])
log_csv = os.path.join(CFG.outdir, "losses.csv")
log_df.to_csv(log_csv, index=False)
print(f"Pertes sauvegardées dans {log_csv}")

plt.figure(figsize=(8,4))
plt.plot(log_df["epoch"], log_df["loss_D"], label="D")
plt.plot(log_df["epoch"], log_df["loss_G"], label="G")
plt.xlabel("Epoch"); plt.ylabel("Loss")
plt.title("Pertes WGAN-GP")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(CFG.outdir, "loss_curves.png"))
plt.close()


Pertes sauvegardées dans ./rnn_gan_outputs/losses.csv



## Notes
- Si les figures sont trop bruitées, augmentez `n_critic` (p.ex. 7 ou 10) ou baissez `lr_D`.
- Si `DataLoader` pose problème sous Windows, mettez `num_workers=0`.
- Pour des échantillons plus lisses, gardez `use_ema=True` (l'échantillonnage utilise `G_ema`).
- `batch_size` peut être réduit (p.ex. 128) si la mémoire GPU est limite.
