In [None]:
#!apt-get update
#!apt-get install -y swig python3-dev

In [None]:
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip install -r requirements.txt

In [None]:
# %% [markdown]
# # PPO en CarRacing-v3 (discreto) con TorchRL  
#
# Este notebook entrena un agente en CarRacing-v3 usando PPO (discreto) de TorchRL, y replica el flujo de tu notebook original para crear el entorno.
# Incluye:
# 1. Registro de un nuevo Gym env “DiscreteCarRacing-v3” que envuelve a CarRacing-v3 con acciones discretas.  
# 2. Creación del entorno TorchRL con `GymEnv("DiscreteCarRacing-v3", continuous=False, render_mode="rgb_array", device=device)` + `TransformedEnv`/`Compose` tal como en tu notebook.  
# 3. Definición del actor–crítico con `SafeModule`, `ProbabilisticActor` y `ValueOperator`.  
# 4. Uso de `SyncDataCollector`, `ClipPPOLoss` y `GAE` para PPO “out-of-the-box” de TorchRL.  
# 5. Barra de progreso (`tqdm`) y función `plot(logs)` idéntica a la tuya, para actualizar en vivo las curvas de “Avg Reward (Train)” y “Avg Reward (Eval)”.  
#
# **Dependencias**:
# ```bash
# pip install torch torchvision torchrl gymnasium[box2d] tqdm matplotlib
# ```

# %% [markdown]
# ## 1. Importaciones y utilidades de graficado

# %% [code]
import warnings
warnings.filterwarnings("ignore")

from IPython.display import clear_output, display
import matplotlib.pyplot as plt
from tqdm import tqdm

import gymnasium as gym
from gymnasium.envs.registration import register
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# TorchRL
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    DoubleToFloat, ToTensorImage, GrayScale, UnsqueezeTransform, CatFrames,
    ObservationNorm, StepCounter, Compose
)
from torchrl.envs import TransformedEnv
from torchrl.collectors import SyncDataCollector
from torchrl.envs import AutoResetEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.modules import SafeModule, ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE


def plot(logs):
    """
    Redraw the live reward curves for the current run.

    ``logs`` is a mapping holding two sequences, ``logs["train_reward"]`` and
    ``logs["eval_reward"]``. The previous cell output is cleared and a fresh
    1x2 figure is displayed: "Avg Reward (Train)" in blue on the left and
    "Avg Reward (Eval)" in green on the right. The figure is closed after
    being displayed so repeated calls do not accumulate open figures.
    """
    clear_output(wait=True)
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    panels = axes.flatten()
    curves = (
        ("Avg Reward (Train)", logs["train_reward"], "blue"),
        ("Avg Reward (Eval)", logs["eval_reward"], "green"),
    )
    for panel, (heading, values, colour) in zip(panels, curves):
        panel.plot(values, color=colour)
        panel.set_title(heading)
        # Recompute the data limits so the axes track the growing series.
        panel.relim()
        panel.autoscale_view()
    plt.tight_layout()
    display(fig)
    plt.close(fig)


# %% [markdown]
# ## 2. Wrapper discreto para CarRacing-v3

# %% [code]
class DiscreteCarRacingWrapper(gym.ActionWrapper):
    """
    Map a discrete action id (0..4) to the continuous [steer, gas, brake] vector.

        0: no-op      -> [ 0.0, 0.0, 0.0]
        1: accelerate -> [ 0.0, 1.0, 0.0]
        2: brake      -> [ 0.0, 0.0, 0.8]
        3: turn left  -> [-1.0, 0.0, 0.0]
        4: turn right -> [ 1.0, 0.0, 0.0]

    The wrapped environment must accept a 3-element float action; the wrapper
    replaces ``action_space`` with ``Discrete(5)``.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.discrete_actions = [
            np.array([ 0.0, 0.0, 0.0], dtype=np.float32),
            np.array([ 0.0, 1.0, 0.0], dtype=np.float32),
            np.array([ 0.0, 0.0, 0.8], dtype=np.float32),
            np.array([-1.0, 0.0, 0.0], dtype=np.float32),
            np.array([ 1.0, 0.0, 0.0], dtype=np.float32),
        ]
        self.action_space = gym.spaces.Discrete(len(self.discrete_actions))

    def action(self, action_discrete: int) -> np.ndarray:
        """
        Translate a discrete action id into its continuous control vector.

        Args:
            action_discrete: anything ``int()`` accepts (Python int, NumPy
                scalar, 0-d array/tensor) with a value in ``[0, 4]``.

        Returns:
            A fresh float32 array ``[steer, gas, brake]``.

        Raises:
            IndexError: if the id is outside ``[0, 4]``. Negative ids are
                rejected explicitly; plain list indexing would silently wrap
                them around (``-1`` would become "turn right").
        """
        index = int(action_discrete)
        if not 0 <= index < len(self.discrete_actions):
            raise IndexError(
                f"discrete action {index} out of range "
                f"[0, {len(self.discrete_actions) - 1}]"
            )
        # Return a copy so that in-place edits made downstream can never
        # corrupt the shared action table.
        return self.discrete_actions[index].copy()


# %% [markdown]
# ## 3. Registrar “DiscreteCarRacing-v3” en Gym
#
# Aquí creamos un nuevo id de entorno para Gym que internamente usa nuestro wrapper discreto.
# De esta forma podemos usar `GymEnv("DiscreteCarRacing-v3", ...)` sin errores de env_name.

# %% [code]
def _make_discrete_car_racing(render_mode="rgb_array", **_ignored_kwargs):
    """
    Entry point for "DiscreteCarRacing-v3".

    Builds CarRacing-v3 in continuous mode and wraps it with
    DiscreteCarRacingWrapper. Only ``render_mode`` is forwarded; any other
    keyword (for example the ``continuous=False`` that callers in this
    notebook pass) is dropped on purpose, because the wrapper needs the
    underlying env to keep its continuous 3-element action space.
    """
    return DiscreteCarRacingWrapper(
        gym.make("CarRacing-v3", continuous=True, render_mode=render_mode)
    )


# Register only once. An explicit registry check makes the cell safe to re-run
# without a blanket `except Exception: pass` that would also hide real errors.
if "DiscreteCarRacing-v3" not in gym.registry:
    register(
        id="DiscreteCarRacing-v3",
        entry_point=_make_discrete_car_racing,
    )


# %% [markdown]
# ## 4. Crear entornos TorchRL (base “continuo” corregido a “discreto”)
#
# En tu notebook original usabas algo como:
# ```
# base_env = GymEnv("CarRacing-v3", continuous=cont, render_mode="rgb_array", device=device)
# env = TransformedEnv(base_env, Compose(DoubleToFloat(), ToTensorImage(), GrayScale(), ...))
# ```
# Ahora reemplazamos `"CarRacing-v3"` con `"DiscreteCarRacing-v3"` y `continuous=False`.

# %% [code]
def make_torchrl_env(device: str = "cpu"):
    """
    Build the transformed TorchRL environment for "DiscreteCarRacing-v3".

    Pipeline applied on top of the GymEnv:
      1. DoubleToFloat    - casts float64 entries to float32.
      2. ToTensorImage    - uint8 HWC image -> float CHW image in [0, 1],
                            written under "pixels".
      3. GrayScale        - 3 channels -> 1 channel.
      4. CatFrames        - stacks the last 4 frames on the channel dim.
      5. ObservationNorm  - normalises "pixels" (stats initialised below).
      6. StepCounter      - adds "step_count".

    Args:
        device: device on which the environment tensors are placed.

    Returns:
        A TransformedEnv with initialised ObservationNorm statistics.
    """
    base_env = GymEnv(
        "DiscreteCarRacing-v3",
        continuous=False,
        render_mode="rgb_array",
        # The policy samples integer ids with torch.distributions.Categorical,
        # so ask for an integer (categorical) action spec rather than the
        # default one-hot encoding.
        categorical_action_encoding=True,
        device=device,
    )

    # Without `from_pixels=True` GymEnv stores a Box observation under
    # "observation", not "pixels". Detect the actual key so the image pipeline
    # works in both cases and always produces "pixels".
    if "pixels" in base_env.observation_spec.keys():
        image_key = "pixels"
    else:
        image_key = "observation"

    env = TransformedEnv(
        base_env,
        Compose(
            DoubleToFloat(),                     # float64 -> float32
            ToTensorImage(in_keys=[image_key], out_keys=["pixels"]),  # HWC uint8 -> CHW float in [0, 1]
            GrayScale(),                         # [3, H, W] -> [1, H, W]
            # No UnsqueezeTransform: CatFrames concatenates along the existing
            # channel dim, giving [4, H, W], which is what the CNN expects.
            CatFrames(N=4, dim=-3),
            ObservationNorm(in_keys=["pixels"]),
            StepCounter(),
        ),
    )

    # ObservationNorm is the second-to-last transform; estimate its loc/scale
    # from a short random rollout.
    env.transform[-2].init_stats(num_iter=256, reduce_dim=0, cat_dim=0)

    return env


# %% [markdown]
# ## 5. Definir Actor–Crítico con SafeModule + ProbabilisticActor + ValueOperator

# %% [code]
# 5.1) Cuerpo CNN (idéntico al tuyo, pero ajustado a input de 4 frames en escala de grises)
class CNNBody(nn.Module):
    """
    Convolutional feature extractor for stacks of 4 grayscale 96x96 frames.

    Input:  ``[*batch, 4, 96, 96]`` where ``*batch`` is any number of leading
            dimensions (none, one, or several).
    Output: ``[*batch, 512]``.
    """

    def __init__(self):
        super().__init__()
        # Input is (4, 96, 96): four stacked grayscale frames.
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)   # -> [32, 23, 23]
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)  # -> [64, 10, 10]
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)  # -> [64, 8, 8]
        self.flatten_dim = 64 * 8 * 8
        self.fc = nn.Linear(self.flatten_dim, 512)

    def forward(self, obs_image: torch.Tensor) -> torch.Tensor:
        """
        Encode an image stack into a 512-d feature vector.

        Leading batch dimensions are flattened before the convolutions and
        restored afterwards. This way an unbatched observation yields ``[512]``
        (not ``[1, 512]``) and inputs with two or more batch dimensions do not
        break ``Conv2d``.
        """
        batch_shape = obs_image.shape[:-3]
        x = obs_image.reshape(-1, *obs_image.shape[-3:])
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.reshape(-1, self.flatten_dim)
        x = F.relu(self.fc(x))
        return x.reshape(*batch_shape, x.shape[-1])

# 5.2) Red Actor–Crítico completa
class ActorCriticNet(nn.Module):
    """
    Shared-trunk actor-critic network.

    A CNNBody produces 512-d features that feed two linear heads: one emitting
    ``num_actions`` policy logits and one emitting a scalar state value.
    """

    def __init__(self, num_actions: int = 5):
        super().__init__()
        feature_dim = 512  # width of the CNNBody output
        self.body = CNNBody()
        self.policy_head = nn.Linear(feature_dim, num_actions)
        self.value_head = nn.Linear(feature_dim, 1)

    def forward(self, obs_image: torch.Tensor):
        """Return ``(logits, value)``; ``value`` has its trailing size-1 dim removed."""
        shared = self.body(obs_image)
        state_value = self.value_head(shared)
        return self.policy_head(shared), state_value.squeeze(-1)

# 5.3) SafeModule + ProbabilisticActor + ValueOperator
class _PolicyLogits(nn.Module):
    """Adapter that exposes only the policy logits of a shared actor-critic net."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        return self.net(pixels)[0]


class _StateValue(nn.Module):
    """Adapter that exposes only the state value, with a trailing size-1 dim."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        return self.net(pixels)[1].unsqueeze(-1)


def make_ac_modules(device: str = "cpu"):
    """
    Build the TorchRL actor and critic on top of one shared ActorCriticNet.

    Args:
        device: device on which the network parameters are created.

    Returns:
        (actor_module, value_module):
          - actor_module: ProbabilisticActor reading "pixels", writing
            "logits", "action" and the sampled action's log-probability
            (stored under the library's default log-prob key).
          - value_module: ValueOperator reading "pixels" and writing
            "state_value" with shape ``[*batch, 1]``.

    Both modules wrap the same ``ActorCriticNet`` instance, so they share
    parameters; build the optimizer from a de-duplicated parameter set
    (e.g. ``loss_module.parameters()``).
    """
    base_net = ActorCriticNet(num_actions=5).to(device)

    # (a) Policy: "pixels" -> "logits" -> Categorical(logits=...) -> "action".
    # SafeModule has no `function=` argument, so each output is selected by a
    # small nn.Module adapter instead of a lambda.
    actor_sm = SafeModule(
        module=_PolicyLogits(base_net),
        in_keys=["pixels"],
        out_keys=["logits"],
    )
    actor_module = ProbabilisticActor(
        module=actor_sm,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=Categorical,
        return_log_prob=True,
        # Sample by default; evaluation code switches mode with
        # `with set_exploration_type(...)`.
        default_interaction_type=ExplorationType.RANDOM,
    )

    # (b) Critic: "pixels" -> "state_value" (ValueOperator's default out key).
    value_module = ValueOperator(
        module=_StateValue(base_net),
        in_keys=["pixels"],
    )

    return actor_module, value_module


# %% [markdown]
# ## 6. Configuración del Collector y pérdida PPO (ClipPPOLoss + GAE)

# %% [code]
# 6.1) Hyperparameters
frames_per_batch   = 1024    # environment steps collected per iteration
num_iterations     = 256     # number of collect/update iterations
ppo_epoch          = 4       # optimisation passes over each collected batch
sub_batch_size     = 128     # PPO minibatch size
gamma              = 0.99
lam                = 0.95
learning_rate      = 2e-4
clip_epsilon       = 0.2
value_loss_coeff   = 0.5     # weight applied to the critic loss when summing
max_grad_norm      = 0.5
max_eval_steps     = 1000    # upper bound on the length of an evaluation episode

# 6.2) TorchRL environments
device = "cuda" if torch.cuda.is_available() else "cpu"
env_train = make_torchrl_env(device=device)
env_eval  = make_torchrl_env(device=device)

# 6.3) Actor and critic
actor_module, value_module = make_ac_modules(device=device)

# 6.4) PPO loss and advantage estimator.
# ClipPPOLoss only computes the loss terms; the optimisation loop lives below.
# The entropy and critic coefficients are left at the library defaults because
# their keyword names changed between TorchRL releases.
# NOTE(review): defaults are assumed to be 0.01 (entropy) and 1.0 (critic);
# confirm for the installed version.
clip_ppo_loss = ClipPPOLoss(
    actor_network=actor_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=True,
)
# `loss.parameters()` yields each shared parameter once, unlike concatenating
# the actor and critic parameter lists.
optimizer = torch.optim.Adam(clip_ppo_loss.parameters(), lr=learning_rate)
gae_module = GAE(
    gamma=gamma,
    lmbda=lam,
    value_network=value_module,
    average_gae=True,   # standardise advantages within each batch
)

# 6.5) Synchronous collector. It resets finished episodes on its own, so the
# environment is passed directly (no auto-reset wrapper).
train_collector = SyncDataCollector(
    env_train,
    policy=actor_module,
    frames_per_batch=frames_per_batch,
    total_frames=frames_per_batch * num_iterations,
    device=device,
)


# %% [markdown]
# ## 7. Training + evaluation loop with tqdm and plot(logs)

# %% [code]
# 7.1) Logs used for plotting
logs = {
    "train_reward": [],
    "eval_reward":  []
}
num_eval_episodes = 3  # evaluation episodes per iteration

pbar = tqdm(train_collector, total=num_iterations, desc="PPO Entrenando", unit="iter")
for iter_idx, experience in enumerate(pbar):
    # (1) `experience` is the on-policy rollout of `frames_per_batch` steps
    #     yielded by the collector for this iteration.

    # (2) + (3) PPO update: several passes over the batch, recomputing the
    #     advantages with the current critic before each pass.
    for _ in range(ppo_epoch):
        gae_module(experience)  # writes "advantage" and "value_target"
        flat_batch = experience.reshape(-1)
        permutation = torch.randperm(flat_batch.shape[0])
        for minibatch_idx in permutation.split(sub_batch_size):
            loss_terms = clip_ppo_loss(flat_batch[minibatch_idx])
            loss = (
                loss_terms["loss_objective"]
                + value_loss_coeff * loss_terms["loss_critic"]
                + loss_terms["loss_entropy"]
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(clip_ppo_loss.parameters(), max_grad_norm)
            optimizer.step()

    # (4) Mean per-step reward of the collected batch.
    avg_train_reward = experience["next", "reward"].mean().item()
    logs["train_reward"].append(avg_train_reward)

    # (5) Evaluation with the deterministic (mode) policy. The exploration
    #     type only takes effect inside the `with` block.
    eval_episode_rewards = []
    with set_exploration_type(ExplorationType.MODE), torch.no_grad():
        for _ in range(num_eval_episodes):
            # rollout() resets the env and stops at the first done flag.
            eval_rollout = env_eval.rollout(max_eval_steps, actor_module)
            eval_episode_rewards.append(eval_rollout["next", "reward"].sum().item())
    avg_eval_reward = float(np.mean(eval_episode_rewards))
    logs["eval_reward"].append(avg_eval_reward)

    # (6) Update the progress bar and redraw the curves.
    pbar.set_postfix({
        "avg_train": f"{avg_train_reward:.1f}",
        "avg_eval":  f"{avg_eval_reward:.1f}"
    })
    plot(logs)

pbar.close()
train_collector.shutdown()


# %% [markdown]
# ## 8. Guardar modelo final
#
# Si quieres salvar el checkpoint del actor y el crítico:
# ```python
# import uuid
# save_name = f"ppo_cr_{uuid.uuid4().hex[:6]}.pt"
# checkpoint = {
#     "actor_state_dict": actor_module.state_dict(),
#     "value_state_dict": value_module.state_dict(),
# }
# torch.save(checkpoint, save_name)
# print("Modelo guardado en", save_name)
# ```
#
# ———————————————————————————————————————————————————————————————————————————————  
