# 🚀 Entraînement du VAE

Notebook pour entraîner le VAE sur les séquences de workload.

---

## 📦 Installation et imports

In [None]:
# Installation des dépendances si nécessaire
# !pip install torch matplotlib tqdm numpy

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
import json
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm

print("="*70)
print("üöÄ ENTRA√éNEMENT DU VAE")
print("="*70)

## 💾 Montage de Google Drive (optionnel)

Si vos données sont sur Google Drive, exécutez cette cellule (nécessite Google Colab).

In [None]:
# Optionally mount Google Drive. The google.colab module only exists inside
# Colab, so outside it we skip the mount instead of aborting the notebook.
try:
    from google.colab import drive
except ImportError:
    drive = None  # keep the name defined for later cells that might check it
    print("Google Colab not detected: skipping Google Drive mount.")
else:
    drive.mount('/content/drive')


## 🏗️ Définition du modèle VAE

Le modèle est importé depuis `src/models/vae_lstm` ; si ce module n'est pas disponible, définissez la classe du modèle ici.

In [None]:
# Import the LSTM-based VAE class, aliased to the generic name `VAE` that the
# following cells use, plus the VAE loss function and its per-epoch tracker.
# NOTE(review): requires the project root (containing `src/`) to be importable
# from the notebook's working directory — confirm when running on Colab.
from src.models.vae_lstm import VAELSTM as VAE
from src.training.losses import vae_loss, VAELossTracker

print("‚úì Mod√®les import√©s depuis src/models/vae_lstm")


## ⚙️ Configuration

In [None]:
# Hyper-parameter table shared by every later cell of the notebook.
CONFIG = dict(
    # --- Data ---
    data_dir='data/processed/sequences',

    # --- VAE architecture ---
    input_dim=100,            # sequence length x n_features (flattened width)
    latent_dim=32,            # size of the latent space
    hidden_dims=[256, 128],   # encoder/decoder hidden layer sizes
    activation='relu',
    dropout=0.1,

    # --- Training ---
    n_epochs=100,
    batch_size=32,
    learning_rate=1e-3,
    beta=1.0,                 # weight of the KL-divergence term (beta-VAE)

    # --- Checkpointing ---
    checkpoint_dir='checkpoints',
    save_every=10,            # write a periodic checkpoint every N epochs

    # --- Device: CUDA when available, CPU otherwise ---
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Echo the most important settings, one per line.
print("‚öôÔ∏è  Configuration:")
for _label, _key in (
    ("Device", 'device'),
    ("Latent dim", 'latent_dim'),
    ("Batch size", 'batch_size'),
    ("Epochs", 'n_epochs'),
    ("Learning rate", 'learning_rate'),
    ("Beta (KL weight)", 'beta'),
):
    print(f"   {_label}: {CONFIG[_key]}")

## 📂 Chargement des données

In [None]:
print("="*70)
print("üìÇ Chargement des donn√©es")
print("="*70)

data_dir = Path(CONFIG['data_dir'])

# Charger s√©quences
train_data = np.load(data_dir / 'train.npy')
val_data = np.load(data_dir / 'val.npy')

print(f"‚úì Train raw: {train_data.shape}")
print(f"‚úì Val raw:   {val_data.shape}")

# Si le mod√®le import√© est un LSTM-based VAE, conserver la forme s√©quentielle
is_lstm_vae = False
try:
    is_lstm_vae = hasattr(VAE, '__name__') and ('LSTM' in VAE.__name__.upper() or 'LSTMVAE' in VAE.__name__.upper())
except NameError:
    is_lstm_vae = False

if is_lstm_vae:
    # Attendu: (n_samples, seq_len, n_features)
    # Si les donn√©es sont d√©j√† en 3D, gardez-les. Si elles sont en 2D (flatten), essayez d'inf√©rer seq_len et input_size.
    if train_data.ndim == 3:
        seq_train = train_data
    elif train_data.ndim == 2:
        # Essayer d'utiliser CONFIG pour restaurer la forme
        seq_len = CONFIG.get('sequence_length')
        input_size = CONFIG.get('input_size')
        if seq_len is None or input_size is None:
            # tenter d'inf√©rer: supposer input_size=1
            input_size = CONFIG.get('input_size', 1)
            if CONFIG.get('input_dim'):
                seq_len = CONFIG['input_dim'] // input_size
            else:
                # fallback: prendre sqrt approximation
                seq_len = train_data.shape[1] // input_size
        train_data = train_data.reshape(len(train_data), seq_len, input_size)
        val_data = val_data.reshape(len(val_data), seq_len, input_size)

    print(f"‚úì Donn√©es format√©es pour LSTM-VAE: {train_data.shape}")

    # Cr√©er DataLoaders (s√©quences 3D)
    train_dataset = TensorDataset(torch.FloatTensor(train_data))
    val_dataset = TensorDataset(torch.FloatTensor(val_data))
else:
    # Flatten pour VAE dense (batch, seq_len*features)
    train_flat = train_data.reshape(len(train_data), -1)
    val_flat = val_data.reshape(len(val_data), -1)

    print(f"\nApr√®s flatten:")
    print(f"  Train: {train_flat.shape}")
    print(f"  Val:   {val_flat.shape}")

    # Cr√©er DataLoaders
    train_dataset = TensorDataset(torch.FloatTensor(train_flat))
    val_dataset = TensorDataset(torch.FloatTensor(val_flat))

train_loader = DataLoader(
    train_dataset, 
    batch_size=CONFIG['batch_size'], 
    shuffle=True
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=CONFIG['batch_size'], 
    shuffle=False
)

print(f"\n‚úì DataLoaders cr√©√©s:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")

## 🏗️ Création du modèle

In [None]:
print("="*70)
print("üèóÔ∏è  Cr√©ation du mod√®le VAE")
print("="*70)

import inspect

def instantiate_vae_class(VAE_cls):
    sig = inspect.signature(VAE_cls.__init__)
    param_names = set(sig.parameters.keys())
    param_names.discard('self')

    # Old dense VAE signature
    if 'input_dim' in param_names:
        return VAE_cls(
            input_dim=CONFIG['input_dim'],
            latent_dim=CONFIG['latent_dim'],
            hidden_dims=CONFIG.get('hidden_dims', [256, 128]),
            activation=CONFIG.get('activation', 'relu'),
            dropout=CONFIG.get('dropout', 0.1)
        )

    # LSTM VAE signature
    if 'input_size' in param_names or 'sequence_length' in param_names:
        try:
            seq_len = train_data.shape[1]
            input_size = train_data.shape[2] if train_data.ndim == 3 else 1
        except NameError:
            seq_len = CONFIG.get('sequence_length', 288)
            input_size = CONFIG.get('input_size', 1)

        hidden_size = CONFIG.get('hidden_size', CONFIG.get('hidden_dims', [128])[0])
        latent_dim = CONFIG['latent_dim']
        num_layers = CONFIG.get('num_layers', 2)
        dropout = CONFIG.get('dropout', 0.1)
        bidirectional = CONFIG.get('bidirectional', False)

        return VAE_cls(
            input_size=input_size,
            sequence_length=seq_len,
            hidden_size=hidden_size,
            latent_dim=latent_dim,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional
        )

    # Fallback: try simple positional init
    try:
        return VAE_cls(CONFIG['input_dim'], CONFIG['latent_dim'])
    except Exception as e:
        raise RuntimeError(f"Impossible d'instancier le mod√®le VAE: {e}")

# Instantiate
model = instantiate_vae_class(VAE)
model = model.to(CONFIG['device'])

print(f"‚úì Mod√®le cr√©√©: {type(model).__name__}")
print(f"  Device: {CONFIG['device']}")
# Param√®tres si disponible
try:
    print(f"  Param√®tres: {model.count_parameters():,}")
except Exception:
    pass

# Optimiseur
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
print(f"\n‚úì Optimiseur: Adam (lr={CONFIG['learning_rate']})")

## 🔥 Entraînement

In [None]:
print("="*70)
print("üî• Entra√Ænement")
print("="*70)

# Tracking
history = {
    'train_loss': [],
    'train_recon': [],
    'train_kl': [],
    'val_loss': [],
    'val_recon': [],
    'val_kl': []
}

best_val_loss = float('inf')

# Cr√©er dossier checkpoints
Path(CONFIG['checkpoint_dir']).mkdir(exist_ok=True)

for epoch in tqdm(range(CONFIG['n_epochs']), desc="Epochs"):
    
    # ========== TRAINING ==========
    model.train()
    train_tracker = VAELossTracker()
    
    for batch in train_loader:
        x = batch[0].to(CONFIG['device'])
        
        # Forward
        x_recon, mu, log_var = model(x)
        
        # Loss
        loss, recon_loss, kl_div = vae_loss(
            x_recon, x, mu, log_var, 
            beta=CONFIG['beta']
        )
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track
        train_tracker.update(
            loss.item(), 
            recon_loss.item(), 
            kl_div.item()
        )
    
    train_losses = train_tracker.get_average()
    
    # ========== VALIDATION ==========
    model.eval()
    val_tracker = VAELossTracker()
    
    with torch.no_grad():
        for batch in val_loader:
            x = batch[0].to(CONFIG['device'])
            
            x_recon, mu, log_var = model(x)
            loss, recon_loss, kl_div = vae_loss(
                x_recon, x, mu, log_var,
                beta=CONFIG['beta']
            )
            
            val_tracker.update(
                loss.item(),
                recon_loss.item(),
                kl_div.item()
            )
    
    val_losses = val_tracker.get_average()
    
    # ========== LOGGING ==========
    history['train_loss'].append(train_losses['total'])
    history['train_recon'].append(train_losses['recon'])
    history['train_kl'].append(train_losses['kl'])
    history['val_loss'].append(val_losses['total'])
    history['val_recon'].append(val_losses['recon'])
    history['val_kl'].append(val_losses['kl'])
    
    # Print
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"\nEpoch {epoch+1:3d}/{CONFIG['n_epochs']} | "
              f"Train Loss: {train_losses['total']:.4f} "
              f"(R: {train_losses['recon']:.4f}, KL: {train_losses['kl']:.4f}) | "
              f"Val Loss: {val_losses['total']:.4f}")
    
    # ========== CHECKPOINTING ==========
    
    # Sauvegarder meilleur mod√®le
    if val_losses['total'] < best_val_loss:
        best_val_loss = val_losses['total']
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': best_val_loss,
            'config': CONFIG
        }, Path(CONFIG['checkpoint_dir']) / 'vae_best.pth')
    
    # Sauvegarder p√©riodiquement
    if (epoch + 1) % CONFIG['save_every'] == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history,
            'config': CONFIG
        }, Path(CONFIG['checkpoint_dir']) / f'vae_epoch_{epoch+1}.pth')

print(f"\n{'='*70}")
print("‚úÖ Entra√Ænement termin√© !")
print(f"Meilleure val loss: {best_val_loss:.4f}")
print("="*70)

## 📊 Visualisation des résultats

In [None]:
print("üìä G√©n√©ration des visualisations...")

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Total loss
axes[0].plot(history['train_loss'], label='Train', linewidth=2)
axes[0].plot(history['val_loss'], label='Val', linewidth=2)
axes[0].set_title('Total Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Reconstruction loss
axes[1].plot(history['train_recon'], label='Train', linewidth=2)
axes[1].plot(history['val_recon'], label='Val', linewidth=2)
axes[1].set_title('Reconstruction Loss')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# KL divergence
axes[2].plot(history['train_kl'], label='Train', linewidth=2)
axes[2].plot(history['val_kl'], label='Val', linewidth=2)
axes[2].set_title('KL Divergence')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('KL')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(Path(CONFIG['checkpoint_dir']) / 'training_history.png', dpi=150)
print(f"‚úì Sauvegard√©: {CONFIG['checkpoint_dir']}/training_history.png")
plt.show()

# Sauvegarder historique
with open(Path(CONFIG['checkpoint_dir']) / 'history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"\n{'='*70}")
print("üéâ TOUT EST PR√äT !")
print("="*70)
print(f"\nüìÅ Fichiers cr√©√©s:")
print(f"  {CONFIG['checkpoint_dir']}/vae_best.pth")
print(f"  {CONFIG['checkpoint_dir']}/training_history.png")
print(f"  {CONFIG['checkpoint_dir']}/history.json")

## 💾 Sauvegarde sur Google Drive (optionnel)

In [None]:
# Décommentez pour copier les checkpoints vers Google Drive
# import shutil
# drive_checkpoint_dir = '/content/drive/MyDrive/vae_checkpoints'
# Path(drive_checkpoint_dir).mkdir(parents=True, exist_ok=True)
# shutil.copytree(CONFIG['checkpoint_dir'], drive_checkpoint_dir, dirs_exist_ok=True)
# print(f"‚úì Checkpoints copi√©s vers {drive_checkpoint_dir}")

## 🚀 Prochaines étapes

Maintenant que votre VAE est entraîné, vous pouvez :

1. **Générer de nouveaux scénarios** en échantillonnant dans l'espace latent
2. **Analyser l'espace latent** pour comprendre ce que le modèle a appris
3. **Interpoler** entre différents scénarios
4. **Reconstruire** des séquences existantes pour évaluer la qualité

---

**Bon entra√Ænement ! üéâ**