<a href="https://colab.research.google.com/github/JulienHelfenstein/World_model/blob/main/03_train_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab-only API).
drive.mount('/content/drive')

# Root directory of the project on the mounted drive; the data and model
# paths used further down are built relative to it.
PROJECT_ROOT = "/content/drive/My Drive/Colab Notebooks/World_model"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
from torch.distributions import Categorical, Normal, MixtureSameFamily

In [None]:
# --- 1. Configuration and hyperparameters ---

# Compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model dimensions (must match the trained VAE and the recorded data).
z_dim = 32         # size of the VAE latent vector
action_dim = 3     # CarRacing: [steer, gas, brake]
hidden_dim = 256   # size of the LSTM hidden state
num_mixtures = 5   # number of mixture components predicted per step

# Training settings.
seq_length = 50    # length of the sequences fed to the RNN
batch_size = 32
learning_rate = 1e-3
num_epochs = 10    # 10-20 epochs is a reasonable starting point

# File locations, all relative to PROJECT_ROOT.
VAE_MODEL_PATH = os.path.join(PROJECT_ROOT, "vae.pth")
RNN_MODEL_PATH = os.path.join(PROJECT_ROOT, "rnn.pth")
DATA_FILE = os.path.join(PROJECT_ROOT, "data/carracing_data.npz")
RNN_DATA_FILE = os.path.join(PROJECT_ROOT, "data/rnn_data.npz")

In [None]:
# --- 2. Définition des Modèles ---

# On a besoin de la *définition* du CVAE pour charger le state_dict
# (copié-collé de train_vae_carracing.py)
class CVAE(nn.Module):
    """Encoder part of a convolutional VAE.

    Only the layers needed to turn an image into a latent vector are defined
    here (convolutional encoder, mean head, log-variance head); this class
    has no decoder.

    Args:
        z_dim: size of the latent vector.
        image_channels: number of channels of the input images.
    """

    def __init__(self, z_dim, image_channels=3):
        super().__init__()
        self.z_dim = z_dim

        # Four convolutions (kernel 4, stride 2, padding 1), each followed by
        # a ReLU. Layers are created in the same order as their positions in
        # the Sequential, so parameter names are encoder.0, encoder.2, ...
        channel_widths = (image_channels, 32, 64, 128, 256)
        stages = []
        for in_ch, out_ch in zip(channel_widths[:-1], channel_widths[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1))
            stages.append(nn.ReLU())
        self.encoder = nn.Sequential(*stages)

        # Size of the flattened encoder output: 256 channels on a 4x4 grid,
        # which is what four stride-2 halvings of a 64x64 image produce.
        self.flat_size = 256 * 4 * 4
        self.fc_mu = nn.Linear(self.flat_size, z_dim)
        self.fc_logvar = nn.Linear(self.flat_size, z_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the latent distribution for images ``x``."""
        features = self.encoder(x)
        flattened = features.view(-1, self.flat_size)
        mu = self.fc_mu(flattened)
        log_var = self.fc_logvar(flattened)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps`` standard normal and ``std = exp(log_var / 2)``."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    # The decoder is not needed here; only encode/reparameterize are kept.

# Le "Moteur de Rêve" (MDN-RNN)

class MDNRNN(nn.Module):
    """LSTM with a mixture-density head that models the next latent vector.

    At every time step the LSTM consumes the concatenation of a latent vector
    and an action, and a linear head turns the LSTM output into the
    parameters of a mixture of ``num_mixtures`` diagonal Gaussians over a
    ``z_dim``-dimensional vector: one logit per component, then a mean and a
    log standard deviation per component and latent dimension.

    Args:
        z_dim: size of the latent vector.
        action_dim: size of the action vector.
        hidden_dim: size of the LSTM hidden state.
        num_mixtures: number of Gaussian components in the mixture.
    """

    def __init__(self, z_dim, action_dim, hidden_dim, num_mixtures):
        super().__init__()
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.num_mixtures = num_mixtures

        input_dim = z_dim + action_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

        # Per step: num_mixtures logits + num_mixtures * z_dim means
        # + num_mixtures * z_dim log standard deviations.
        mdn_output_dim = num_mixtures * (1 + 2 * z_dim)
        self.mdn_output = nn.Linear(hidden_dim, mdn_output_dim)

    def forward(self, z_seq, a_seq, hidden_state):
        """Run the LSTM over a sequence and return the predicted mixture.

        Args:
            z_seq: latent vectors, shape (batch_size, seq_len, z_dim).
            a_seq: actions, shape (batch_size, seq_len, action_dim).
            hidden_state: ``(h, c)`` tuple passed to the LSTM.

        Returns:
            ``(distribution, next_hidden)`` where ``distribution`` has batch
            shape (batch_size, seq_len) and event shape (z_dim,), and
            ``next_hidden`` is the LSTM state after the last step.
        """
        lstm_input = torch.cat([z_seq, a_seq], dim=-1)

        # lstm_output shape: (batch_size, seq_len, hidden_dim)
        lstm_output, next_hidden = self.lstm(lstm_input, hidden_state)

        # mdn_params shape: (batch_size, seq_len, mdn_output_dim)
        mdn_params = self.mdn_output(lstm_output)

        distribution = self.get_mixture_distribution(mdn_params)
        return distribution, next_hidden

    def get_mixture_distribution(self, mdn_params):
        """Build the Gaussian mixture described by raw head outputs.

        Args:
            mdn_params: tensor of shape
                (..., num_mixtures * (1 + 2 * z_dim)); the last axis holds
                the mixture logits, then the means, then the log standard
                deviations.

        Returns:
            A ``MixtureSameFamily`` whose batch shape is the leading shape of
            ``mdn_params`` and whose event shape is (z_dim,).
        """
        k, d = self.num_mixtures, self.z_dim

        # Split the last axis into logits / means / log-sigmas.
        pi_logits, mu, log_sigma = torch.split(mdn_params, [k, k * d, k * d], dim=-1)

        # (..., k * d) -> (..., k, d), whatever the leading dimensions are.
        component_shape = (*mdn_params.shape[:-1], k, d)
        mu = mu.reshape(component_shape)
        log_sigma = log_sigma.reshape(component_shape)

        # Small constant keeps the scale strictly positive.
        sigma = torch.exp(log_sigma) + 1e-6

        # MixtureSameFamily needs the component axis to be the LAST batch
        # axis of the component distribution. Independent(..., 1) moves the
        # z_dim axis from batch to event, giving batch shape (..., k) and
        # event shape (d,). A bare Normal with batch shape (..., k, d) is
        # rejected by the MixtureSameFamily constructor.
        components = torch.distributions.Independent(Normal(loc=mu, scale=sigma), 1)
        pi_dist = Categorical(logits=pi_logits)
        return MixtureSameFamily(pi_dist, components)

# --- 3. MDN loss function ---
def mdn_loss_function(mixture_distribution, target_z):
    """Return the mean negative log-likelihood of ``target_z``.

    Args:
        mixture_distribution: mixture with event shape (z_dim,), e.g. the
            distribution returned by ``MDNRNN.forward``.
        target_z: targets of shape (batch_size, seq_len, z_dim).

    Returns:
        Scalar tensor: minus the log-probability averaged over batch and
        sequence positions.
    """
    # The target is passed as-is: MixtureSameFamily.log_prob inserts the
    # component axis itself, so it must not be expanded beforehand.
    log_prob = mixture_distribution.log_prob(target_z)
    return -torch.mean(log_prob)

In [None]:
# --- 4. Phase 2a: Pré-traitement des données (VAE -> Z) ---
def create_rnn_data():
    """Encode every recorded frame with the trained VAE and cache the result.

    Reads ``observations`` and ``actions`` from ``DATA_FILE``, encodes the
    observations batch by batch with the VAE loaded from ``VAE_MODEL_PATH``
    (one sampled latent vector per frame), and writes the arrays
    ``z_vectors`` and ``actions`` to ``RNN_DATA_FILE``. Returns immediately
    if ``RNN_DATA_FILE`` already exists.

    Raises:
        RuntimeError: if the checkpoint lacks weights for a layer of ``CVAE``.
    """
    if os.path.exists(RNN_DATA_FILE):
        print(f"Le fichier de données pré-traitées {RNN_DATA_FILE} existe déjà.")
        return

    print("Phase 2a: Pré-traitement des données (Images -> Vecteurs Z)...")

    # 1. Load the trained VAE.
    vae_model = CVAE(z_dim).to(device)
    state_dict = torch.load(VAE_MODEL_PATH, map_location=device)
    # strict=False so that extra entries in the checkpoint (e.g. decoder
    # weights, which this encoder-only CVAE does not define) are ignored
    # instead of raising. Missing encoder weights are still an error.
    load_result = vae_model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        raise RuntimeError(
            f"VAE checkpoint {VAE_MODEL_PATH} is missing weights: "
            f"{load_result.missing_keys}"
        )
    # Inference mode for the layers; gradients are disabled separately below.
    vae_model.eval()

    # 2. Load the raw data (images + actions) and release the file handle.
    with np.load(DATA_FILE) as raw_data:
        observations = raw_data['observations']  # expected shape (N, 64, 64, 3)
        actions = raw_data['actions']            # expected shape (N, 3)

    # 3. Encode all images into latent vectors, one small batch at a time.
    #    Each batch is converted and moved to the device individually so the
    #    full float32 dataset never has to fit in device memory.
    # NOTE(review): pixels are only cast to float32, not rescaled. If the VAE
    # was trained on inputs scaled to [0, 1], apply the same scaling here.
    # Confirm against the VAE training script.
    vae_batch_size = 256
    z_vectors = []
    with torch.no_grad():
        pbar = tqdm(range(0, len(observations), vae_batch_size), desc="Encodage VAE")
        for start in pbar:
            chunk = observations[start : start + vae_batch_size]
            # (B, H, W, C) -> (B, C, H, W)
            batch_obs = torch.from_numpy(chunk).permute(0, 3, 1, 2).to(device, dtype=torch.float32)
            mu, log_var = vae_model.encode(batch_obs)
            z = vae_model.reparameterize(mu, log_var)
            z_vectors.append(z.cpu().numpy())  # keep results on the CPU

    # 4. Concatenate and save. Actions are stored unchanged.
    all_z = np.concatenate(z_vectors, axis=0)
    all_actions = actions

    print(f"Encodage terminé. Shape Z: {all_z.shape}, Shape A: {all_actions.shape}")
    np.savez_compressed(RNN_DATA_FILE, z_vectors=all_z, actions=all_actions)
    print(f"Données pré-traitées sauvegardées dans {RNN_DATA_FILE}")

In [None]:
# --- 5. Phase 2b: Dataset pour Séquences ---
class SequenceDataset(Dataset):
    """Sliding windows over the stored latent vectors and actions.

    Item ``i`` is the window of ``seq_length`` consecutive steps starting at
    step ``i``, together with the latent vectors one step later as targets.

    Args:
        data_file: ``.npz`` file (path or file-like object) containing the
            arrays ``z_vectors`` and ``actions``.
        seq_length: number of steps per window; must be at least 1.

    Raises:
        ValueError: if ``seq_length`` is smaller than 1 or the two arrays do
            not have the same number of steps.
    """

    def __init__(self, data_file, seq_length):
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")

        # The context manager closes the underlying file once both arrays
        # have been read into memory.
        with np.load(data_file) as data:
            self.z_vectors = torch.from_numpy(data['z_vectors']).float()
            self.actions = torch.from_numpy(data['actions']).float()

        if len(self.z_vectors) != len(self.actions):
            raise ValueError(
                f"z_vectors has {len(self.z_vectors)} steps but actions has "
                f"{len(self.actions)}"
            )
        self.seq_length = seq_length

    def __len__(self):
        # A window starting at idx needs steps idx .. idx + seq_length
        # (inputs plus the shifted target), so the last valid start is
        # N - seq_length - 1, i.e. N - seq_length windows in total.
        # Clamped at 0 because len() must not return a negative value.
        return max(0, len(self.z_vectors) - self.seq_length)

    def __getitem__(self, idx):
        """Return ``(z_seq, a_seq, target_z_seq)`` for window ``idx``.

        Raises:
            IndexError: if ``idx`` does not designate a complete window.
        """
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            # Without this check, slicing would silently return a truncated
            # or empty window.
            raise IndexError(f"window index out of range for {count} windows")

        stop = idx + self.seq_length

        # Input sequence
        z_seq = self.z_vectors[idx:stop]
        a_seq = self.actions[idx:stop]

        # Target sequence (shifted by one time step)
        target_z_seq = self.z_vectors[idx + 1 : stop + 1]

        return z_seq, a_seq, target_z_seq

In [None]:
# --- 6. Script Principal d'Entraînement (Phase 2b) ---
if __name__ == "__main__":

    # 1. Run phase 2a (pre-processing of images into latent vectors).
    create_rnn_data()

    print("Phase 2b: Entraînement du MDN-RNN...")

    # 2. Build the dataset and its loader.
    dataset = SequenceDataset(RNN_DATA_FILE, seq_length)
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    # 3. Create the model and the optimizer.
    model = MDNRNN(z_dim, action_dim, hidden_dim, num_mixtures).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()  # training mode

    # 4. Training loop.
    for epoch in range(num_epochs):
        pbar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        total_epoch_loss = 0.0

        for z_seq, a_seq, target_z_seq in pbar:
            z_seq = z_seq.to(device)
            a_seq = a_seq.to(device)
            target_z_seq = target_z_seq.to(device)

            # Fresh zero (h, c) state for every batch, allocated directly on
            # the device and sized from the batch actually received, so the
            # last (possibly smaller) batch needs no special case.
            current_batch = z_seq.size(0)
            hidden_state = (
                torch.zeros(1, current_batch, hidden_dim, device=device),
                torch.zeros(1, current_batch, hidden_dim, device=device),
            )

            # --- Forward pass ---
            distribution, _ = model(z_seq, a_seq, hidden_state)

            # --- Loss ---
            loss = mdn_loss_function(distribution, target_z_seq)

            # --- Backward pass ---
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; .item() synchronises with the device.
            loss_value = loss.item()
            total_epoch_loss += loss_value
            pbar.set_postfix(loss=f"{loss_value:.4f}")

        # Mean of the per-batch losses over the epoch.
        avg_loss = total_epoch_loss / len(data_loader)
        print(f"Fin Epoch {epoch+1}. Perte moyenne : {avg_loss:.4f}")

    # 5. Save the trained weights.
    print("Entraînement du RNN terminé.")
    torch.save(model.state_dict(), RNN_MODEL_PATH)
    print(f"Modèle RNN sauvegardé dans {RNN_MODEL_PATH}")