In [1]:
%pip uninstall -y torch torchvision torchaudio


Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip uninstall -y box2d-py Box2D box2d



Note: you may need to restart the kernel to use updated packages.


In [3]:
# Preferido (tiene wheels para py3.12):
%pip install --no-cache-dir -q --only-binary=:all: Box2D==2.3.10

# Si (y solo si) lo anterior fallara:
# %pip install --no-cache-dir -q --only-binary=:all: box2d-py==2.3.8


Note: you may need to restart the kernel to use updated packages.


In [4]:
# CPU seguro
%pip install -q --index-url https://download.pytorch.org/whl/cpu torch==2.5.1+cpu
# (o cu121 si estás 100% seguro que tu entorno tiene CUDA 12.1 funcional)


Note: you may need to restart the kernel to use updated packages.


In [5]:
%pip install -q numpy==1.26.4 matplotlib==3.8.4 tqdm==4.66.4 cma==3.2.2 imageio==2.36.0 imageio-ffmpeg==0.5.1


Note: you may need to restart the kernel to use updated packages.


In [6]:
%pip install --no-cache-dir -q pygame==2.6.1


Note: you may need to restart the kernel to use updated packages.


In [7]:
!pip install box2d==2.3.10 gymnasium==0.29.1 pygame==2.6.1 numpy==1.26.4 matplotlib==3.8.4 tqdm==4.66.4


Collecting gymnasium==0.29.1
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium==0.29.1)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1


In [8]:
import gymnasium as gym, Box2D, torch, numpy as np
print("Gym:", gym.__version__, "| Torch:", torch.__version__)
env = gym.make("LunarLander-v2")
s, _ = env.reset()
print("Obs:", s.shape, "Actions:", env.action_space.n)


Gym: 0.29.1 | Torch: 2.5.1+cpu
Obs: (8,) Actions: 4


In [9]:
# ===========================================
# 1) Carpeta del proyecto + semillas + utils
# ===========================================
from pathlib import Path
import json, math, random, time
import numpy as np
import torch
import gymnasium as gym
from collections import deque, namedtuple

# --- Carpeta donde guardaremos TODO (dataset, modelos, videos, figuras)
ROOT = Path("./world_models")
for p in [ROOT, ROOT/"data", ROOT/"checkpoints", ROOT/"videos", ROOT/"figs"]:
    p.mkdir(parents=True, exist_ok=True)

# --- Semillas para reproducibilidad
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# --- Elegimos la versión del entorno disponible
def pick_lander_env():
    for env_id in ("LunarLander-v3","LunarLander-v2"):
        try:
            env = gym.make(env_id); env.close()
            return env_id
        except:
            pass
    raise RuntimeError("No encontré LunarLander v2/v3.")
ENV_ID = pick_lander_env()
print("Usando:", ENV_ID)

# --- Rango “seguro” para recortar observaciones (ayuda a estabilizar el modelo)
LOW  = np.array([-1.5,-0.5,-2.0,-2.0,-math.pi,-3.0,0.0,0.0],dtype=np.float32)
HIGH = np.array([ 1.5, 1.5, 2.0, 2.0, math.pi, 3.0,1.0,1.0],dtype=np.float32)
def clip_obs(s: np.ndarray) -> np.ndarray:
    """Recorta la observación a rangos razonables y asegura dtype float32."""
    s = np.asarray(s, np.float32)
    s[:6] = np.clip(s[:6], LOW[:6], HIGH[:6])  # las 6 primeras son continuas
    s[6:] = np.clip(s[6:], 0.0, 1.0)           # patas ya están en [0,1]
    return s

# --- Tupla para transiciones (s, a, r, s', done)
Transition = namedtuple("Transition", ["s","a","r","sp","d"])

# --- Buffer simple (lo usaremos para recolectar dataset real)
class ReplayBuffer:
    def __init__(self, cap=1_000_000): self.buf = deque(maxlen=cap)
    def __len__(self): return len(self.buf)
    def add(self, *args): self.buf.append(Transition(*args))

print("Estructura lista:", list(ROOT.iterdir()))


Usando: LunarLander-v2
Estructura lista: [PosixPath('world_models/figs'), PosixPath('world_models/data'), PosixPath('world_models/videos'), PosixPath('world_models/checkpoints')]


In [10]:
# ==========================================================
# 2) Recolecta un dataset real con vecenv (rápido y estable)
# ==========================================================
from tqdm import trange

def collect_dataset(env_id=ENV_ID, steps=300_000, num_envs=8,
                    eps_start=1.0, eps_end=0.05):
    """
    Recolecta 'steps' transiciones usando ε-greedy MUY simple.
    - num_envs: entornos en paralelo (acelera muchísimo la recolección)
    - epsilon: parte alto y decae linealmente; así mezclamos aleatorio y heurística
    Guarda S, A, R, Sp, D + estadísticas para normalizar.
    """
    env = gym.vector.make(env_id, num_envs=num_envs, asynchronous=True)
    s, _ = env.reset(seed=SEED)
    s = np.stack([clip_obs(si) for si in s])

    S, A, R, Sp, D = [], [], [], [], []
    eps = eps_start
    pbar = trange(0, steps, num_envs, ncols=120, desc="Recolectando")
    while len(S) < steps:
        # Política “tonta”: mira la altura y decide algo grosero sobre el motor
        greedy = np.zeros((num_envs,), dtype=np.int64)
        greedy[s[:,1] > 0.3] = 2  # (heurística arbitraria sólo para explorar)
        greedy[s[:,1] < -0.3] = 1
        a = greedy
        # ε-greedy
        mask = np.random.rand(num_envs) < eps
        a[mask] = np.random.randint(0, 4, size=mask.sum())
        # decaimos epsilon proporcional a transiciones
        eps = max(eps_end, eps - (eps_start-eps_end) / (steps/num_envs))

        sp, r, term, trunc, _ = env.step(a)
        d = np.logical_or(term, trunc)
        sp = np.stack([clip_obs(si) for si in sp])

        S.append(s.copy()); A.append(a.copy()); R.append(r.copy()); Sp.append(sp.copy()); D.append(d.astype(np.float32))
        s = sp
        pbar.update(num_envs)
    env.close()

    # Concatenamos y guardamos
    S = np.concatenate(S, 0); A = np.concatenate(A, 0); R = np.concatenate(R, 0)
    Sp = np.concatenate(Sp, 0); D = np.concatenate(D, 0)

    np.savez_compressed(ROOT/"data"/"lander_dataset.npz", S=S, A=A, R=R, Sp=Sp, D=D)
    print("Dataset guardado en:", (ROOT/"data"/"lander_dataset.npz").resolve())
    return S, A, R, Sp, D

# 🔁 Si es la primera vez, descomenta:
S, A, R, Sp, D = collect_dataset(steps=300_000, num_envs=8)
# Si ya lo tienes guardado, puedes cargarlo así:
# data = np.load(ROOT/"data"/"lander_dataset.npz")
# S, A, R, Sp, D = data["S"], data["A"], data["R"], data["Sp"], data["D"]


Recolectando: 2400000it [01:40, 23837.69it/s]                                                                           

Dataset guardado en: /home/jovyan/MVP_RL_LunarLander/world_models/data/lander_dataset.npz





In [11]:
# ==================================================
# 3) Estadísticas para normalizar entradas y targets
# ==================================================
class Stats:
    """
    Guardamos medias y desvíos para:
    - s (estado),
    - ds = s_{t+1} - s_t,
    - r (recompensa).
    Esto vuelve "bien condicionado" el entrenamiento del modelo del mundo.
    """
    def __init__(self):
        self.mu_s=None; self.std_s=None
        self.mu_ds=None; self.std_ds=None
        self.mu_r=0.;   self.std_r=1.
    def fit(self, S, R, Sp):
        DS = Sp - S
        eps = 1e-6
        self.mu_s  = S.mean(0); self.std_s  = S.std(0) + eps
        self.mu_ds = DS.mean(0); self.std_ds = DS.std(0) + eps
        self.mu_r  = float(R.mean()); self.std_r = float(R.std() + eps)
    def dumps(self):
        return dict(mu_s=self.mu_s.tolist(), std_s=self.std_s.tolist(),
                    mu_ds=self.mu_ds.tolist(), std_ds=self.std_ds.tolist(),
                    mu_r=self.mu_r, std_r=self.std_r)
    @staticmethod
    def loads(js):
        st = Stats()
        st.mu_s  = np.array(js["mu_s"],  np.float32)
        st.std_s = np.array(js["std_s"], np.float32)
        st.mu_ds = np.array(js["mu_ds"], np.float32)
        st.std_ds= np.array(js["std_ds"],np.float32)
        st.mu_r  = float(js["mu_r"]); st.std_r = float(js["std_r"])
        return st

STATS = Stats(); STATS.fit(S, R, Sp)
(Path(ROOT/"data"/"stats.json")).write_text(json.dumps(STATS.dumps(), indent=2))
print("Stats guardadas:", (ROOT/"data"/"stats.json").resolve())


Stats guardadas: /home/jovyan/MVP_RL_LunarLander/world_models/data/stats.json


In [12]:
# ======================================================
# 4) MDN-RNN: RNN que predice una mezcla Gaussiana sobre
#    Δs, además de recompensa y done. Entrenamos con NLL.
# ======================================================
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# --- utilidades MDN
def mdn_params(x, K, D):
    """
    Parte la salida lineal en:
    - pi_logits: [B,T,K]
    - mu:        [B,T,K,D]
    - log_sigma: [B,T,K,D]
    """
    B,T,_ = x.shape
    pi = x[...,:K]
    mu = x[...,K:K+K*D].view(B,T,K,D)
    ls = x[...,K+K*D:].view(B,T,K,D)
    return pi, mu, ls

def mdn_nll(x, pi_logits, mu, log_sigma):
    """
    NLL de una mezcla Gaussiana diagonal.
    x: [B,T,D], pi_logits: [B,T,K], mu/log_sigma: [B,T,K,D]
    """
    import math as _m
    B,T,D = x.shape; K = pi_logits.shape[-1]
    x = x.unsqueeze(2)             # [B,T,1,D]
    var = torch.exp(2*log_sigma)   # [B,T,K,D]
    log_comp = -0.5*((x-mu)**2/var + 2*log_sigma + _m.log(2*_m.pi))  # [B,T,K,D]
    log_comp = log_comp.sum(-1)    # [B,T,K]
    log_mix  = torch.log_softmax(pi_logits, -1) + log_comp
    nll = -torch.logsumexp(log_mix, -1)  # [B,T]
    return nll.mean()

# --- dataset de secuencias (teacher forcing)
class SeqDataset(Dataset):
    """
    Corta el flujo plano en bloques de longitud T.
    Devuelve:
      x  = concat( s_normalizado, onehot(a) )  --> entrada al RNN
      ds = Δs_normalizado                      --> target para MDN
      r  = r_normalizado
      d  = done (0/1)
    """
    def __init__(self, S, A, R, Sp, D, stats:Stats, T=32):
        DS = Sp - S
        self.s  = ((S  - stats.mu_s)/stats.std_s).astype(np.float32)
        self.ds = ((DS - stats.mu_ds)/stats.std_ds).astype(np.float32)
        self.r  = ((R  - stats.mu_r)/stats.std_r).astype(np.float32)
        self.d  = D.astype(np.float32)
        self.a  = A.astype(np.int64)
        self.T  = T
        self.N  = len(S)//T
    def __len__(self): return self.N
    def __getitem__(self, i):
        i0 = i*self.T; i1 = i0 + self.T
        s   = torch.from_numpy(self.s[i0:i1])     # [T,8]
        ds  = torch.from_numpy(self.ds[i0:i1])    # [T,8]
        r   = torch.from_numpy(self.r[i0:i1])     # [T]
        d   = torch.from_numpy(self.d[i0:i1])     # [T]
        a   = torch.from_numpy(self.a[i0:i1])     # [T]
        aoh = F.one_hot(a, num_classes=4).float() # [T,4]
        x   = torch.cat([s, aoh], -1)             # [T,12]
        return x, ds, r, d

# --- el RNN con cabezas MDN + r + done
class MDNRNN(nn.Module):
    """
    LSTM(256) → MDN (K mezclas) para Δs, y cabezas para r (MSE) y done (BCE).
    """
    def __init__(self, s_dim=8, a_dim=4, h=256, K=5):
        super().__init__()
        self.rnn = nn.LSTM(input_size=s_dim+a_dim, hidden_size=h, batch_first=True)
        self.mdn = nn.Linear(h, K + K*s_dim + K*s_dim)  # pi + mu + log_sigma
        self.head_r = nn.Linear(h, 1)
        self.head_d = nn.Linear(h, 1)
        self.K = K; self.s_dim = s_dim
    def forward(self, x, h=None):
        y, h = self.rnn(x, h)           # y: [B,T,H]
        mdn_out = self.mdn(y)
        pi, mu, ls = mdn_params(mdn_out, self.K, self.s_dim)
        r = self.head_r(y).squeeze(-1)
        d = torch.sigmoid(self.head_d(y)).squeeze(-1)
        return (pi, mu, ls), r, d, h


In [13]:
# ======================================================
# 5) Entrenador: DataLoader, NLL + MSE + BCE, guardado
# ======================================================
from torch.utils.data import random_split, DataLoader

def train_mdnrnn(S, A, R, Sp, D, stats:Stats,
                 K=5, H=256, T=32, batch=512, lr=3e-4, epochs=25):
    ds = SeqDataset(S,A,R,Sp,D, stats, T=T)
    n = len(ds); n_val = max(256, n//10); n_tr = n - n_val
    tr_set, val_set = random_split(ds, [n_tr, n_val], generator=torch.Generator().manual_seed(SEED))
    tr = DataLoader(tr_set, batch_size=batch, shuffle=True,  drop_last=True)
    va = DataLoader(val_set, batch_size=batch, shuffle=False, drop_last=False)

    model = MDNRNN(K=K, h=H)  # CPU va sobrado para este tamaño
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_loss = 1e9; best_state=None
    for ep in range(1, epochs+1):
        model.train(); tr_mdn=tr_r=tr_d=0; nsteps=0
        for x, ds_t, r_t, d_t in tr:
            (pi,mu,ls), rp, dp, _ = model(x)
            loss_mdn = mdn_nll(ds_t, pi, mu, ls)
            loss_r   = F.mse_loss(rp, r_t)
            loss_d   = F.binary_cross_entropy(dp, d_t)
            loss = loss_mdn + loss_r + loss_d
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            tr_mdn += loss_mdn.item(); tr_r += loss_r.item(); tr_d += loss_d.item(); nsteps += 1

        model.eval(); va_mdn=va_r=va_d=0; m=0
        with torch.no_grad():
            for x, ds_t, r_t, d_t in va:
                (pi,mu,ls), rp, dp, _ = model(x)
                va_mdn += mdn_nll(ds_t, pi, mu, ls).item()
                va_r   += F.mse_loss(rp, r_t).item()
                va_d   += F.binary_cross_entropy(dp, d_t).item()
                m += 1

        train_log = f"Ep {ep:02d} | tr(MDN:{tr_mdn/nsteps:.3f} r:{tr_r/nsteps:.3f} d:{tr_d/nsteps:.3f})"
        val_log   = f" | va(MDN:{va_mdn/m:.3f} r:{va_r/m:.3f} d:{va_d/m:.3f})"
        print(train_log + val_log)

        score = va_mdn/m + va_r/m + va_d/m
        if score < best_loss:
            best_loss = score
            best_state = {k:v.cpu() for k,v in model.state_dict().items()}

    # Guardamos el mejor
    torch.save(best_state, ROOT/"checkpoints"/"mdnrnn.pt")
    (ROOT/"checkpoints"/"stats.json").write_text(json.dumps(stats.dumps(), indent=2))
    print("Guardado MDN-RNN en:", (ROOT/"checkpoints"/"mdnrnn.pt").resolve())
    return model

MDN = train_mdnrnn(S,A,R,Sp,D, STATS, epochs=20)


Ep 01 | tr(MDN:-2.121 r:1.000 d:0.251) | va(MDN:-7.393 r:0.947 d:0.129)
Ep 02 | tr(MDN:-7.398 r:0.995 d:0.129) | va(MDN:-8.707 r:0.946 d:0.125)
Ep 03 | tr(MDN:-8.070 r:0.993 d:0.124) | va(MDN:-9.459 r:0.945 d:0.121)
Ep 04 | tr(MDN:-8.010 r:0.993 d:0.121) | va(MDN:-6.933 r:0.944 d:0.117)
Ep 05 | tr(MDN:-9.250 r:0.992 d:0.117) | va(MDN:-9.894 r:0.944 d:0.114)
Ep 06 | tr(MDN:-9.615 r:0.991 d:0.114) | va(MDN:-9.799 r:0.943 d:0.110)
Ep 07 | tr(MDN:-8.227 r:0.990 d:0.110) | va(MDN:-9.143 r:0.941 d:0.107)
Ep 08 | tr(MDN:-9.153 r:0.988 d:0.107) | va(MDN:-9.919 r:0.938 d:0.104)
Ep 09 | tr(MDN:-11.653 r:0.987 d:0.105) | va(MDN:-12.689 r:0.937 d:0.101)
Ep 10 | tr(MDN:-10.985 r:0.984 d:0.103) | va(MDN:-8.548 r:0.936 d:0.100)
Ep 11 | tr(MDN:-9.446 r:0.983 d:0.101) | va(MDN:-10.549 r:0.934 d:0.098)
Ep 12 | tr(MDN:-11.449 r:0.980 d:0.099) | va(MDN:-12.619 r:0.931 d:0.096)
Ep 13 | tr(MDN:-12.082 r:0.979 d:0.097) | va(MDN:-8.556 r:0.929 d:0.094)
Ep 14 | tr(MDN:-11.434 r:0.977 d:0.095) | va(MDN:-12.028 

In [14]:
# ==========================================
# 6) DreamEnv: simula episodios en el “sueño”
# ==========================================
class DreamEnv:
    """
    Simulador ligero:
      - Estado interno = (s_t real, h_t del LSTM)
      - step(a) usa el MDN para muestrear Δs con temperatura tau
      - devuelve (s_{t+1}, r_hat, done_hat)
    """
    def __init__(self, mdn:MDNRNN, stats:Stats, tau=1.15, device="cpu"):
        self.net = mdn.eval().to(device)
        self.stats = stats; self.tau = tau; self.device = device
        self.h = None; self.s = None

    def reset(self, s0=None):
        if s0 is None:
            s0 = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0., 0.], np.float32)
        self.s = clip_obs(s0); self.h = None
        return self.s.copy()

    @torch.no_grad()
    def step(self, a:int):
        # normalizamos s_t y concatenamos onehot(a)
        s_n = (torch.tensor(self.s, dtype=torch.float32, device=self.device)
               - torch.tensor(self.stats.mu_s, device=self.device)) / torch.tensor(self.stats.std_s, device=self.device)
        a_oh = F.one_hot(torch.tensor([a], device=self.device), num_classes=4).float()
        x = torch.cat([s_n.unsqueeze(0), a_oh], -1).unsqueeze(0)  # [1,1,12]

        (pi,mu,ls), r_n, d_p, self.h = self.net(x, self.h)

        # muestreo de mezcla con temperatura
        pi = torch.softmax(pi[0,0]/self.tau, -1)  # [K]
        k  = torch.multinomial(pi, 1).item()
        mu_k  = mu[0,0,k]
        std_k = torch.exp(ls[0,0,k]) * self.tau
        ds_n  = torch.normal(mu_k, std_k)

        # desnormalizamos y avanzamos
        ds = (ds_n.cpu().numpy())*self.stats.std_ds + self.stats.mu_ds
        sp = clip_obs(self.s + ds)
        r  = float((r_n[0,0].item())*self.stats.std_r + self.stats.mu_r)
        d  = bool(d_p[0,0].item() > 0.5)
        self.s = sp
        return sp, r, d, {"k":k, "pi":pi.cpu().numpy()}


In [15]:
# ==========================================================
# 7) Render minimalista (triángulo) + guardado de videos
# ==========================================================
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.backends.backend_agg import FigureCanvasAgg
import imageio

def draw_lander_from_state(s, size=256):
    """
    Dibuja un triángulo representando la nave usando x,y y el ángulo theta.
    Es muy barato y suficiente para visualizar el “sueño”.
    """
    x,y, vx,vy, theta, vth, leg_l, leg_r = s
    fig = plt.figure(figsize=(size/100, size/100), dpi=100)
    ax = fig.add_axes([0,0,1,1]); ax.set_xlim(-1.5,1.5); ax.set_ylim(-0.2,1.6); ax.axis('off')
    ax.plot([-1.5,1.5],[0,0], lw=2, color='gray')  # suelo

    L = 0.1
    pts = np.array([[0,L],[-L,-L],[L,-L]])
    c, sct = math.cos(theta), math.sin(theta)
    R = np.array([[c,-sct],[sct,c]])
    pts = (pts @ R.T) + np.array([x,y])
    ax.add_patch(Polygon(pts, closed=True, color='steelblue'))

    canvas = FigureCanvasAgg(fig); canvas.draw()
    frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    frame = frame.reshape((int(fig.bbox.height), int(fig.bbox.width), 4))[...,:3]
    plt.close(fig)
    return frame

def save_video(frames, path, fps=30):
    path = Path(path); path.parent.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(path, frames, fps=fps)
    return str(path)


In [None]:
# ==========================================================
# 8) Controller pequeño + CMA-ES (entrenado en DreamEnv)
# ==========================================================
import cma
import torch.nn as nn

class Controller(nn.Module):
    """
    Política muy compacta.
    - Entrada = concat(s_t, h_t) ≈ 8 + 256 = 264 dim
    - Salida = logits para 4 acciones (softmax → distribución)
    """
    def __init__(self, in_dim=8+256, hidden=64, out_dim=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, out_dim)
        )
    def forward(self, s, h):
        # h es una tupla (h_n, c_n); aplanamos h_n
        h_flat = (h[0] if h is not None else torch.zeros(1,1,256)).squeeze(0).squeeze(0)
        x = torch.cat([torch.tensor(s, dtype=torch.float32), h_flat.cpu()], -1)
        logits = self.net(x)
        return torch.softmax(logits, -1).detach().cpu().numpy()

def flatten_params(model:nn.Module):
    with torch.no_grad(): return torch.cat([p.view(-1) for p in model.parameters()]).numpy()

def assign_params(model:nn.Module, vec):
    vec = torch.tensor(vec, dtype=torch.float32)
    i = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(vec[i:i+n].view_as(p)); i += n

def dream_episode(env:DreamEnv, ctrl:Controller, max_steps=1000, render_every=0):
    frames=[]; s = env.reset(); total=0.0
    for t in range(max_steps):
        if render_every and (t%render_every==0): frames.append(draw_lander_from_state(s,256))
        probs = ctrl(s, env.h)
        a = int(np.random.choice(4, p=probs))
        s, r, d, _ = env.step(a); total += r
        if d: break
    return total, frames

def train_cma_dream(mdn_path=ROOT/"checkpoints"/"mdnrnn.pt",
                    stats_path=ROOT/"checkpoints"/"stats.json",
                    gens=60, popsize=384, episodes_per_ind=8, tau=1.15):
    """
    CMA-ES entrena el controlador EXCLUSIVAMENTE en el entorno soñado.
    - gens: generaciones de evolución
    - popsize: tamaño de población por generación
    - episodes_per_ind: episodios soñados por individuo (promediamos)
    - tau: temperatura del MDN (1.1–1.3 suele ir bien)
    """
    mdn = MDNRNN(); mdn.load_state_dict(torch.load(mdn_path, map_location="cpu"))
    stats = Stats.loads(json.loads(Path(stats_path).read_text()))
    dream = DreamEnv(mdn, stats, tau=tau, device="cpu")

    ctrl = Controller()
    x0 = flatten_params(ctrl)
    es = cma.CMAEvolutionStrategy(x0, 0.5, {"popsize": popsize, "maxiter": gens, "verb_log":0, "verb_disp":1})

    best = (None, -1e9)
    gen = 0
    while not es.stop():
        cand = es.ask()
        fitness = []
        for w in cand:
            assign_params(ctrl, w)
            returns = []
            for _ in range(episodes_per_ind):
                ret, _ = dream_episode(dream, ctrl, max_steps=1000, render_every=0)
                returns.append(ret)
            mean_ret = float(np.mean(returns))
            fitness.append(-mean_ret)  # CMA minimiza
            if mean_ret > best[1]:
                best = (w.copy(), mean_ret)
        es.tell(cand, fitness)
        gen += 1
        print(f"Gen {gen:03d} | best_dream_return={best[1]:.1f}")

    assign_params(ctrl, best[0])
    torch.save(ctrl.state_dict(), ROOT/"checkpoints"/"controller.pt")
    print("Controller guardado:", (ROOT/"checkpoints"/"controller.pt").resolve())

    # Grabamos un video soñado como muestra
    _, frames = dream_episode(dream, ctrl, max_steps=800, render_every=2)
    save_video(frames, ROOT/"videos"/"dream_controller.mp4", fps=30)
    print("Video soñado →", (ROOT/"videos"/"dream_controller.mp4").resolve())
    return ctrl

CTRL = train_cma_dream(gens=40, popsize=256, episodes_per_ind=6, tau=1.15)


(128_w,256)-aCMA-ES (mu_w=66.9,w_1=3%) in dimension 21380 (seed=124935, Wed Aug 13 14:08:06 2025)
Gen 001 | best_dream_return=-760.0
Gen 002 | best_dream_return=-716.5
Gen 003 | best_dream_return=-664.8
Gen 004 | best_dream_return=-664.8
Gen 005 | best_dream_return=-664.8
Gen 006 | best_dream_return=-526.7
Gen 007 | best_dream_return=-526.7
Gen 008 | best_dream_return=-526.7


In [None]:
# ==========================================================
# 9) Eval real/soñado + videos + RMSE de rollouts del modelo
# ==========================================================
def eval_real_with_ctrl(env_id=ENV_ID, ctrl:Controller,
                        mdn_path=ROOT/"checkpoints"/"mdnrnn.pt",
                        stats_path=ROOT/"checkpoints"/"stats.json",
                        episodes=10, grab_first=True):
    """
    Ejecuta el controller en el entorno REAL.
    Importante: actualizamos h_t del MDN usando (s_t, a_{t-1}) reales
    para que el controller reciba su memoria como en el paper.
    """
    mdn = MDNRNN(); mdn.load_state_dict(torch.load(mdn_path, map_location="cpu")); mdn.eval()
    st  = Stats.loads(json.loads(Path(stats_path).read_text()))
    env = gym.make(env_id, render_mode="rgb_array")
    scores=[]; frames=[]

    for ep in range(episodes):
        s,_ = env.reset(seed=SEED+ep); s = clip_obs(s)
        total=0.0; done=False; h=None; a_prev=0
        while not done:
            # Paso del RNN con (s_t, a_{t-1}) reales para actualizar h
            s_n = (torch.tensor(s, dtype=torch.float32) - torch.tensor(st.mu_s)) / torch.tensor(st.std_s)
            a_oh = F.one_hot(torch.tensor([a_prev]), num_classes=4).float()
            x = torch.cat([s_n.unsqueeze(0), a_oh], -1).unsqueeze(0)
            with torch.no_grad():
                _,_,_, h = mdn(x, h)

            probs = ctrl(s, h)
            a = int(np.random.choice(4, p=probs))
            sp, r, term, trunc, _ = env.step(a); d = term or trunc
            if grab_first and ep==0:
                frames.append(env.render())
            total += r; s = clip_obs(sp); done=d; a_prev=a
        scores.append(total)

    env.close()
    if grab_first and frames:
        save_video(frames, ROOT/"videos"/"real_controller.mp4", fps=30)
        print("Video REAL →", (ROOT/"videos"/"real_controller.mp4").resolve())
    return float(np.mean(scores)), scores

def rollout_rmse(model:MDNRNN, stats:Stats, S,A,Sp, horizon=30, trials=300):
    """
    Mide qué tan bien “se sostiene” el modelo por múltiples pasos.
    Tomamos fragmentos del dataset y comparamos (x,y,theta).
    """
    model.eval()
    idx = np.random.choice(len(S)-horizon-1, size=min(trials, len(S)-horizon-1), replace=False)
    errs=[]
    with torch.no_grad():
        for i in idx:
            s = S[i]; h=None
            err=[]
            for t in range(horizon):
                a = A[i+t]
                s_n = (torch.tensor(s, dtype=torch.float32) - torch.tensor(stats.mu_s)) / torch.tensor(stats.std_s)
                a_oh = F.one_hot(torch.tensor([a]), num_classes=4).float()
                x = torch.cat([s_n.unsqueeze(0), a_oh], -1).unsqueeze(0)
                (pi,mu,ls),_,_, h = model(x, h)
                k = torch.softmax(pi[0,0],-1).argmax().item()  # “mejor” componente
                ds_n = mu[0,0,k]
                ds = (ds_n.numpy())*stats.std_ds + stats.mu_ds
                s = clip_obs(s + ds)
                err.append(np.square(s[:3] - Sp[i+t][:3]).mean())  # x,y,theta
            errs.append(np.mean(err))
    return float(np.sqrt(np.mean(errs)))

# --- Cargar mejor controller y evaluar
CTRL.load_state_dict(torch.load(ROOT/"checkpoints"/"controller.pt", map_location="cpu"))

# 9a) Retorno soñado (un episodio + vídeo ya guardado)
dream = DreamEnv(MDN, STATS, tau=1.15, device="cpu")
dream_ret, _ = dream_episode(dream, CTRL, max_steps=800, render_every=0)
print(f"Return SOÑADO (1 ep): {dream_ret:.1f}")

# 9b) Retorno real (10 episodios + vídeo del 1º episodio)
real_mean, real_scores = eval_real_with_ctrl(ENV_ID, CTRL, episodes=10, grab_first=True)
print(f"Return REAL (media 10 eps): {real_mean:.1f}")

# 9c) Calidad del modelo: RMSE de rollout a 30 pasos
rmse30 = rollout_rmse(MDN, STATS, S,A,Sp, horizon=30, trials=300)
print(f"RMSE@30 pasos (x,y,theta): {rmse30:.3f}")


In [None]:
pip install pipreqs 