In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
from gd_env import GeometryEnv
from model import DQN
import os

In [None]:
def _load_checkpoint(checkpoint_path, policy_net, target_net, optimizer, device):
    """Restore training state from ``checkpoint_path`` when that file exists.

    Accepted formats:

    * a periodic training checkpoint written by ``train_dqn``
      (keys ``policy_state``, ``target_state``, ``optimizer_state``,
      ``epsilon``, ``episode``);
    * the final-model file written at the end of ``train_dqn``
      (keys ``state_dim``, ``model_state_dict``);
    * a bare policy-network ``state_dict``.

    Args:
        checkpoint_path: path to the checkpoint; a falsy value or a missing
            file means "start from scratch".
        policy_net: network whose weights are overwritten in place.
        target_net: network whose weights are overwritten in place.
        optimizer: optimizer whose state is restored if the checkpoint has it.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        ``(eps, start_ep)``: exploration rate and first episode index to run.
        Both are ``(1.0, 0)`` unless a training checkpoint stores other values.
    """
    eps, start_ep = 1.0, 0
    if not (checkpoint_path and os.path.exists(checkpoint_path)):
        print("No se encontró checkpoint, entrenamiento desde cero.")
        return eps, start_ep

    print(f"Cargando checkpoint desde: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if "policy_state" in checkpoint:
        policy_net.load_state_dict(checkpoint["policy_state"])
        # Fall back to the policy weights if no target snapshot was stored.
        target_net.load_state_dict(checkpoint.get("target_state", policy_net.state_dict()))
        if "optimizer_state" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
        eps = checkpoint.get("epsilon", 1.0)
        start_ep = checkpoint.get("episode", 0)
    else:
        # The final-model file nests the weights under "model_state_dict";
        # anything else is treated as a bare state_dict.
        policy_net.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
        target_net.load_state_dict(policy_net.state_dict())

    print(f"Checkpoint cargado correctamente (episodio {start_ep})")
    return eps, start_ep


def _td_loss(batch, policy_net, target_net, criterion, gamma, state_dim, device):
    """Compute the one-step TD (DQN) loss for a minibatch of transitions.

    Args:
        batch: sequence of ``(state, action, reward, next_state, done)`` tuples
            as stored in the replay memory.
        policy_net: online network; the loss stays attached to its graph.
        target_net: frozen network used for the bootstrap target.
        criterion: loss function applied to (predicted Q, target Q).
        gamma: discount factor.
        state_dim: expected width of each stacked state.
        device: torch device for the minibatch tensors.

    Returns:
        The scalar loss tensor, or ``None`` (after logging) when the stacked
        states do not have ``state_dim`` columns.
    """
    states, actions, rewards, next_states, dones = zip(*batch)

    s_b = torch.tensor(np.array(states), dtype=torch.float32, device=device)
    if s_b.shape[1] != state_dim:
        print(f"Shape error: s_b {s_b.shape}, esperado (*, {state_dim})")
        return None

    a_b = torch.tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
    r_b = torch.tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
    s2_b = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)
    d_b = torch.tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)

    q_vals = policy_net(s_b).gather(1, a_b)
    with torch.no_grad():
        next_q = target_net(s2_b).max(1, keepdim=True)[0]
    # Terminal transitions (d_b == 1) do not bootstrap from the next state.
    target = r_b + gamma * next_q * (1 - d_b)
    return criterion(q_vals, target)


def train_dqn(episodes=1000, checkpoint_path=None, max_steps=2000):
    """Train a DQN agent on ``GeometryEnv`` with experience replay and a target network.

    Args:
        episodes: exclusive END index of the episode range, not a count of
            extra episodes. Training runs ``range(start_ep, episodes)``, where
            ``start_ep`` comes from the checkpoint (0 when starting fresh).
            Resuming a checkpoint saved at episode 500 therefore needs
            ``episodes > 500`` to train anything.
        checkpoint_path: optional checkpoint to resume from; ignored when the
            file does not exist. Periodic checkpoints, the final-model file and
            bare state_dicts are all accepted.
        max_steps: maximum environment steps per episode.

    Side effects:
        Writes ``checkpoints/geometry_ep{N}.pth`` every 100 episodes and
        ``geometry_dash_dqn.pth`` (keys ``state_dim``, ``model_state_dict``)
        at the end, both relative to the current working directory.

    Returns:
        The trained policy network.
    """
    env = GeometryEnv()

    sample_state = env.reset()
    state_dim = len(sample_state)
    n_actions = 2

    print(f"Estado inicial: shape={sample_state.shape}, dim={state_dim}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Usando dispositivo:", device)

    policy_net = DQN(state_dim, n_actions).to(device)
    target_net = DQN(state_dim, n_actions).to(device)
    # Start the target as an exact copy of the policy; otherwise it is an
    # unrelated random network until the first periodic sync.
    target_net.load_state_dict(policy_net.state_dict())
    optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
    criterion = nn.SmoothL1Loss()

    memory = deque(maxlen=50000)
    batch_size = 64
    gamma = 0.99
    eps_min = 0.05
    eps_decay = 0.995
    update_target_every = 100  # measured in episodes

    eps, start_ep = _load_checkpoint(checkpoint_path, policy_net, target_net, optimizer, device)
    if start_ep >= episodes:
        print(f"Aviso: no hay episodios que entrenar (start_ep={start_ep}, episodes={episodes}).")

    target_net.eval()
    os.makedirs("checkpoints", exist_ok=True)

    # ===============================
    # MAIN TRAINING LOOP
    # ===============================
    for ep in range(start_ep, episodes):
        s = env.reset().astype(np.float32)
        total_r = 0.0

        for t in range(max_steps):
            if np.any(np.isnan(s)):
                print(f"Estado NaN en ep {ep}, t={t}. Reiniciando entorno.")
                s = env.reset().astype(np.float32)
                continue

            # Epsilon-greedy action selection.
            if random.random() < eps:
                a = random.randrange(n_actions)
            else:
                with torch.no_grad():
                    qvals = policy_net(torch.tensor(s, dtype=torch.float32, device=device).unsqueeze(0))
                    a = int(torch.argmax(qvals).item())

            s2, r, done, _ = env.step(a)
            s2 = np.array(s2, dtype=np.float32)
            # A NaN next-state would make the loss NaN for every minibatch that
            # samples it, so such transitions are not stored; the NaN check at
            # the top of the loop then resets the environment.
            if not np.isnan(s2).any():
                memory.append((s, a, r, s2, done))
            s = s2
            total_r += r

            # Learn only once the buffer can fill a minibatch.
            if len(memory) >= batch_size:
                batch = random.sample(memory, batch_size)
                loss = _td_loss(batch, policy_net, target_net, criterion, gamma, state_dim, device)
                if loss is not None:
                    if torch.isfinite(loss):
                        optimizer.zero_grad()
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
                        optimizer.step()
                    else:
                        # Skip NaN/inf losses: backpropagating them would corrupt the weights.
                        print(f"Loss NaN en ep {ep}, paso {t}")

            # Checked on every step, including warm-up and skipped updates, so a
            # finished episode is never stepped again.
            if done:
                break

        eps = max(eps_min, eps * eps_decay)
        if ep % update_target_every == 0:
            target_net.load_state_dict(policy_net.state_dict())

        if (ep + 1) % 20 == 0:
            print(f"Ep {ep+1:04d} | Reward: {total_r:.1f} | Eps: {eps:.2f}")

        if (ep + 1) % 100 == 0:
            ckpt_path = f"checkpoints/geometry_ep{ep+1}.pth"
            torch.save({
                "episode": ep + 1,
                "policy_state": policy_net.state_dict(),
                "target_state": target_net.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "epsilon": eps,
            }, ckpt_path)
            print(f"Checkpoint guardado: {ckpt_path}")

    torch.save({
        "state_dim": state_dim,
        "model_state_dict": policy_net.state_dict()
    }, "geometry_dash_dqn.pth")

    print("Modelo final guardado como geometry_dash_dqn.pth")
    return policy_net

In [12]:
# Resume training from a saved checkpoint and display the returned policy network.
# NOTE: `episodes` is an absolute end index (training runs range(start_ep, episodes)),
# so when the checkpoint stores its episode number, `episodes` must exceed it for
# any training to happen.
RESUME_CHECKPOINT = "checkpoints/geometry_ep500.pth"
trained_policy_net = train_dqn(checkpoint_path=RESUME_CHECKPOINT, episodes=100)
trained_policy_net

🔍 Estado inicial: shape=(9,), dim=9
⚙️  Usando dispositivo: cpu
🧠 Cargando checkpoint desde: checkpoints/geometry_ep500.pth
✅ Checkpoint cargado correctamente (episodio 0)
Ep 0020 | Reward: 0.0 | Eps: 0.90
Ep 0040 | Reward: 2.0 | Eps: 0.82
Ep 0060 | Reward: 1.0 | Eps: 0.74
Ep 0080 | Reward: -5.0 | Eps: 0.67
Ep 0100 | Reward: -5.0 | Eps: 0.61
💾 Checkpoint guardado: checkpoints/geometry_ep100.pth
✅ Modelo final guardado como geometry_dash_dqn.pth


DQN(
  (net): Sequential(
    (0): Linear(in_features=9, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=2, bias=True)
  )
)