<a href="https://colab.research.google.com/github/JulienHelfenstein/World_model/blob/main/05_run_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Python packages this notebook uses. All output, including any
# installation errors, is redirected to /dev/null.
# NOTE(review): pyvirtualdisplay drives a system-level Xvfb binary, which pip
# does not install -- confirm it is available on the target runtime.
!pip install gymnasium[box2d] numpy torch opencv-python tqdm pyvirtualdisplay xvfbwrapper &> /dev/null

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium
import numpy as np
import os
from time import sleep
from google.colab import drive
from pyvirtualdisplay import Display
from gymnasium.wrappers import RecordVideo
from glob import glob
from IPython.display import HTML, display

# Set up and start a hidden virtual display (required on Colab)
display_colab = Display(visible=0, size=(1400, 900))
display_colab.start()

# Mount Google Drive; the project files are read from under /content/drive
drive.mount('/content/drive')

In [None]:
# Project folder on the mounted Google Drive; every path below hangs off it.
PROJECT_ROOT = "/content/drive/My Drive/Colab Notebooks/World_model"
VIDEO_DIR = os.path.join(PROJECT_ROOT, "videos")

# Make sure the video folder exists. exist_ok=True replaces the former
# "check, then create" pair, which could fail if the folder appeared between
# the check and the creation.
os.makedirs(VIDEO_DIR, exist_ok=True)

# Saved weights of the three trained models.
VAE_MODEL_PATH = os.path.join(PROJECT_ROOT, "vae.pth")
RNN_MODEL_PATH = os.path.join(PROJECT_ROOT, "rnn.pth")
CONTROLLER_SAVE_PATH = os.path.join(PROJECT_ROOT, "controller.pth")

# Model dimensions, in order: VAE latent size, action vector size (the
# controller emits steer, gas and brake), LSTM hidden size, and the number of
# mixture components used to size the MDN output layer.
z_dim, action_dim, hidden_dim, num_mixtures = 32, 3, 256, 5

# Running the agent is light enough for the CPU.
device = torch.device("cpu")

In [None]:
class CVAE(nn.Module):
    """Encoder side of a convolutional VAE.

    Four stride-2 convolutions (kernel 4, padding 1), each followed by a ReLU,
    feed two linear heads that produce the mean and log-variance of a
    ``z_dim``-dimensional latent. Each convolution halves the spatial size, so
    the fixed flattened size of 256 * 4 * 4 corresponds to a 64x64 input.

    Args:
        z_dim: Size of the latent vector.
        image_channels: Number of channels of the input image.
    """

    def __init__(self, z_dim, image_channels=3):
        super().__init__()
        self.z_dim = z_dim

        # Channel widths along the stack; consecutive pairs define one conv.
        widths = (image_channels, 32, 64, 128, 256)
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
            )
            stages.append(nn.ReLU())
        self.encoder = nn.Sequential(*stages)

        self.flat_size = 256 * 4 * 4
        self.fc_mu = nn.Linear(self.flat_size, z_dim)
        self.fc_logvar = nn.Linear(self.flat_size, z_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the latent distribution for images ``x``."""
        features = self.encoder(x)
        flattened = features.view(-1, self.flat_size)
        mu = self.fc_mu(flattened)
        log_var = self.fc_logvar(flattened)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw a latent sample as ``mu + eps * std`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return mu + noise * std

class MDNRNN(nn.Module):
    """Single-layer LSTM with three linear heads on its output.

    The LSTM consumes the concatenation of a latent vector and an action. Its
    output feeds a mixture-density head, a scalar reward head and a scalar
    "done" logit head.

    Args:
        z_dim: Size of the latent vector.
        action_dim: Size of the action vector.
        hidden_dim: LSTM hidden size.
        num_mixtures: Number of mixture components of the MDN head.
    """

    def __init__(self, z_dim, action_dim, hidden_dim, num_mixtures):
        super().__init__()
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.num_mixtures = num_mixtures

        self.lstm = nn.LSTM(z_dim + action_dim, hidden_dim, batch_first=True)

        # Each mixture contributes 1 + 2 * z_dim values (presumably one weight
        # plus a mean and a scale per latent dimension; the code that decodes
        # these parameters is not part of this file).
        params_per_mixture = 1 + 2 * z_dim
        self.mdn_output = nn.Linear(hidden_dim, num_mixtures * params_per_mixture)
        self.reward_head = nn.Linear(hidden_dim, 1)
        self.done_head = nn.Linear(hidden_dim, 1)

    def forward(self, z_t, a_t, hidden_state):
        """Advance the LSTM by one time step.

        Args:
            z_t: Latent vectors, concatenated with ``a_t`` on the last axis.
            a_t: Action vectors.
            hidden_state: LSTM ``(h, c)`` state passed straight to ``nn.LSTM``.

        Returns:
            Tuple ``(mdn_params, pred_reward, pred_done_logits, next_hidden)``.
        """
        step_input = torch.cat((z_t, a_t), dim=-1)
        # Insert a length-1 sequence axis for the batch_first LSTM, then drop
        # it again from the output.
        features, next_hidden = self.lstm(step_input.unsqueeze(1), hidden_state)
        features = features.squeeze(1)
        return (
            self.mdn_output(features),
            self.reward_head(features),
            self.done_head(features),
            next_hidden,
        )

class Controller(nn.Module):
    """Linear policy mapping ``[z_t, h_t]`` to a bounded 3-value action.

    Args:
        z_dim: Size of the latent vector.
        hidden_dim: Size of the recurrent hidden vector.
        action_dim: Number of outputs of the linear layer.
    """

    def __init__(self, z_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc = nn.Linear(z_dim + hidden_dim, action_dim)

    def forward(self, z_t, h_t):
        """Return ``[steer, gas, brake]`` for each row of the batch.

        Steering is squashed to (-1, 1) with tanh; gas and brake are squashed
        to (0, 1) with a sigmoid.
        """
        features = torch.cat((z_t, h_t), dim=-1)
        raw = self.fc(features)
        # Column slices keep the second axis so the pieces can be re-joined.
        squashed = (
            torch.tanh(raw[:, 0:1]),
            torch.sigmoid(raw[:, 1:2]),
            torch.sigmoid(raw[:, 2:3]),
        )
        return torch.cat(squashed, dim=-1)

In [None]:
def preprocess_obs(obs):
    """Convert one environment frame into a ``(1, C, 64, 64)`` float tensor.

    Args:
        obs: NumPy image array; assumed to be laid out as (height, width,
            channels) with values in 0-255 -- confirm against the environment.

    Returns:
        The frame scaled by 1/255, moved to the module-level ``device`` and
        bilinearly resized to 64x64, with a leading batch axis of size 1.
    """
    channels_first = torch.from_numpy(obs).permute(2, 0, 1)
    batch = (channels_first.float() / 255.0).unsqueeze(0).to(device)
    return F.interpolate(
        batch, size=(64, 64), mode='bilinear', align_corners=False
    )

In [None]:
def _load_trained(model, weights_path):
    """Move ``model`` to ``device``, load its saved weights and set eval mode."""
    model = model.to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def _make_recording_env():
    """Create the CarRacing environment wrapped so episode 0 is recorded."""
    # 'rgb_array' (rather than 'human') lets RecordVideo capture the frames.
    try:
        env = gymnasium.make('CarRacing-v2', render_mode='rgb_array')
    except (gymnasium.error.DeprecatedEnv, gymnasium.error.VersionNotFound):
        # The notebook installs an unpinned gymnasium; releases that no longer
        # register CarRacing-v2 provide CarRacing-v3 instead.
        env = gymnasium.make('CarRacing-v3', render_mode='rgb_array')
    # Only the first episode is recorded, as an example.
    return RecordVideo(env, video_folder=VIDEO_DIR, episode_trigger=lambda e: e == 0)


def _run_episode(env, vae, rnn, controller):
    """Play one episode with the trained models and return its total reward."""
    total_reward = 0.0
    obs, _ = env.reset()
    # The recurrent state starts at zero and is kept as (1, hidden_dim) so it
    # can be fed directly to the controller.
    h_t = torch.zeros(1, hidden_dim, device=device)
    c_t = torch.zeros(1, hidden_dim, device=device)

    done = False
    while not done:
        with torch.no_grad():
            z_t = vae.reparameterize(*vae.encode(preprocess_obs(obs)))
            action_tensor = controller(z_t, h_t)
            # nn.LSTM expects a leading num_layers axis on its state; add it
            # for the call and strip it again afterwards.
            _, _, _, (h_t, c_t) = rnn(
                z_t, action_tensor, (h_t.unsqueeze(0), c_t.unsqueeze(0))
            )
            h_t, c_t = h_t.squeeze(0), c_t.squeeze(0)

        action_np = action_tensor.squeeze(0).cpu().numpy()
        obs, reward, terminated, truncated, _ = env.step(action_np)
        total_reward += reward
        done = terminated or truncated

    return total_reward


def _show_latest_video(video_dir):
    """Display the most recently written .mp4 of ``video_dir`` inline."""
    import base64  # local import: only needed for the inline player

    video_files = glob(os.path.join(video_dir, "*.mp4"))
    if not video_files:
        print("Aucun fichier vidéo trouvé.")
        return

    # Pick by modification time; sorting file names does not give recency.
    video_path = max(video_files, key=os.path.getmtime)

    # The notebook front-end runs in the browser and cannot read a path on the
    # runtime's file system, so the video is embedded as a base64 data URI.
    with open(video_path, "rb") as video_file:
        encoded = base64.b64encode(video_file.read()).decode("ascii")
    html = f"""
        <video width="600" controls>
          <source src="data:video/mp4;base64,{encoded}" type="video/mp4">
        </video>
        """
    display(HTML(html))
    print(f"Vidéo affichée depuis {video_path}")


def run_agent():
    """Load the trained models, play two CarRacing episodes and show the video.

    The VAE, MDN-RNN and controller weights are read from the module-level
    paths. The first episode is recorded into ``VIDEO_DIR`` and the most recent
    recording is displayed inline once the simulation has finished. The
    environment is always closed, even if an episode raises, so the recorder
    can finalise its file.
    """
    print("Chargement des modèles entraînés...")

    vae = _load_trained(CVAE(z_dim), VAE_MODEL_PATH)
    rnn = _load_trained(
        MDNRNN(z_dim, action_dim, hidden_dim, num_mixtures), RNN_MODEL_PATH
    )
    controller = _load_trained(
        Controller(z_dim, hidden_dim, action_dim), CONTROLLER_SAVE_PATH
    )

    print("Modèles chargés. Lancement de l'environnement...")

    env = _make_recording_env()
    try:
        for episode in range(2):  # Two episodes; only the first is recorded
            print(f"--- Début de l'Épisode {episode + 1} ---")
            total_reward = _run_episode(env, vae, rnn, controller)
            print(f"Épisode terminé. Récompense totale : {total_reward:.2f}")
    finally:
        env.close()
    print("Simulation terminée.")

    # --- Show the video in Colab ---
    print("Affichage de la vidéo enregistrée...")
    _show_latest_video(VIDEO_DIR)

# --- Run everything ---
if __name__ == "__main__":
    run_agent()