In [5]:
import sys
import os

# Work around OpenMP "Error #15" (duplicate OpenMP runtime loaded), which was crashing the kernel.
# Presumably this has to happen before any OpenMP-linked library is imported, hence its
# position above the remaining imports.
# NOTE(review): the OpenMP runtime itself labels this flag an unsafe workaround that can cause
# crashes or wrong results; prefer removing the duplicate runtime from the environment.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

import json
import shutil
import types
import importlib
from gymnasium import spaces

# Local path setup: make the repository root (the current working directory) and its
# `baselines/` folder importable from this notebook, without adding duplicates to sys.path.
project_path = os.getcwd()
baselines_path = os.path.join(project_path, 'baselines')

for _import_root in (project_path, baselines_path):
    if _import_root not in sys.path:
        sys.path.append(_import_root)
del _import_root  # keep the notebook namespace identical to before (no loop variable left over)

print(f"Directorio de trabajo: {project_path}")

Directorio de trabajo: c:\Users\javi1\Documents\repos_git\TEL351-PokemonRed


In [6]:
# --- RELOAD MODULES ---
def reload_modules():
    """Re-import the project modules already loaded in this kernel so source edits take effect.

    Modules that have not been imported yet are skipped. A failing reload is reported and the
    loop continues with the next module. The list is ordered so that lower-level modules
    (env, features, wrappers, base) are reloaded before the agents built on top of them.
    """
    targets = (
        'v2.red_gym_env_v2',
        'advanced_agents.features',
        'advanced_agents.wrappers',
        'advanced_agents.base',
        'advanced_agents.train_agents',
        'advanced_agents.combat_apex_agent',
        'advanced_agents.puzzle_speed_agent',
        'advanced_agents.hybrid_sage_agent',
        'advanced_agents.transition_models',
    )
    for name in targets:
        # Checked on every iteration: an earlier reload may have imported a later module.
        if name not in sys.modules:
            continue
        try:
            importlib.reload(sys.modules[name])
            print(f"♻️ Recargado: {name}")
        except Exception as err:
            print(f"⚠️ No se pudo recargar {name}: {err}")

reload_modules()

♻️ Recargado: v2.red_gym_env_v2
♻️ Recargado: advanced_agents.features
♻️ Recargado: advanced_agents.wrappers
♻️ Recargado: advanced_agents.base
♻️ Recargado: advanced_agents.train_agents
♻️ Recargado: advanced_agents.combat_apex_agent
♻️ Recargado: advanced_agents.puzzle_speed_agent
♻️ Recargado: advanced_agents.hybrid_sage_agent
♻️ Recargado: advanced_agents.transition_models


In [7]:
# Copy baselines/events.json into the project root, unless a root copy already exists.
events_source = os.path.join(project_path, 'baselines', 'events.json')
events_dest = os.path.join(project_path, 'events.json')
if not os.path.exists(events_dest) and os.path.exists(events_source):
    shutil.copy(events_source, events_dest)
    print(f"Copiado events.json a {events_dest}")

In [8]:
import json
import shutil
import types
import importlib
from typing import Callable, Dict, Iterable, List, Optional

from gymnasium import spaces

# Project agents and configs. Any ImportError is announced and then re-raised unchanged,
# so the notebook stops here instead of failing later with a confusing NameError.
try:
    from advanced_agents.train_agents import _base_env_config
    from advanced_agents.combat_apex_agent import CombatApexAgent, CombatAgentConfig
    from advanced_agents.puzzle_speed_agent import PuzzleSpeedAgent, PuzzleAgentConfig
    from advanced_agents.hybrid_sage_agent import HybridSageAgent, HybridAgentConfig
except ImportError:
    print("ERROR CRÍTICO: Fallo en imports.")
    raise

# --- Load scenarios ---
SCENARIO_PATH = os.path.join(project_path, 'gym_scenarios', 'scenarios.json')
# JSON text is UTF-8 by specification. Pass the encoding explicitly so the read does not depend
# on the platform locale (e.g. cp1252 on Windows), which would mis-decode non-ASCII content.
with open(SCENARIO_PATH, 'r', encoding='utf-8') as f:
    scenarios_data = json.load(f)

# Index the scenario entries by their 'id' field for direct lookup (a later duplicate id wins).
SCENARIOS: Dict[str, Dict] = {scenario['id']: scenario for scenario in scenarios_data['scenarios']}

# Registry of trainable agents: for each key, the agent class, its config class and the
# scenario phase it trains on when a plan entry does not name one explicitly.
AGENT_REGISTRY = {
    key: {
        'agent_cls': agent_cls,
        'config_cls': config_cls,
        'default_phase': default_phase,
    }
    for key, agent_cls, config_cls, default_phase in (
        ('combat', CombatApexAgent, CombatAgentConfig, 'battle'),
        ('puzzle', PuzzleSpeedAgent, PuzzleAgentConfig, 'puzzle'),
        ('hybrid', HybridSageAgent, HybridAgentConfig, 'battle'),
    )
}

# Trained models are written under <project>/models_local/<agent_key>/.
MODELS_DIR = os.path.join(project_path, 'models_local')
os.makedirs(MODELS_DIR, exist_ok=True)

def resolve_phase(scenario_id: str, phase_name: Optional[str], agent_key: str = 'combat') -> Dict:
    """Return the phase definition named ``phase_name`` inside scenario ``scenario_id``.

    Args:
        scenario_id: key into ``SCENARIOS`` (the ``id`` of an entry in scenarios.json).
        phase_name: name of the phase to select. When falsy, the ``default_phase`` of
            ``AGENT_REGISTRY[agent_key]`` is used instead.
        agent_key: registry entry whose default phase applies when ``phase_name`` is missing.
            Defaults to ``'combat'``, which reproduces the previous behaviour for existing callers.

    Returns:
        The dict from ``scenario['phases']`` whose ``name`` equals the target phase.

    Raises:
        ValueError: if the scenario, the agent key (only consulted when ``phase_name`` is
            falsy) or the phase cannot be found.
    """
    scenario = SCENARIOS.get(scenario_id)
    if scenario is None:
        raise ValueError(f"Escenario {scenario_id} no encontrado en {SCENARIO_PATH}")

    if phase_name:
        target_phase = phase_name
    else:
        # Previously this fallback was hard-wired to the combat agent's default phase,
        # so e.g. the puzzle agent silently received 'battle'.
        registry_entry = AGENT_REGISTRY.get(agent_key)
        if registry_entry is None:
            raise ValueError(f"Agente desconocido: {agent_key}")
        target_phase = registry_entry['default_phase']

    # A scenario without a 'phases' list is treated as having no phases, so the caller
    # gets the descriptive ValueError below instead of a bare KeyError.
    selected_phase = next(
        (phase for phase in scenario.get('phases', []) if phase.get('name') == target_phase),
        None,
    )
    if selected_phase is None:
        raise ValueError(f"Fase {target_phase} no encontrada en el escenario {scenario_id}")
    return selected_phase

def ensure_state_file(state_file_path: str) -> str:
    """Resolve a game ``.state`` path and make sure the file exists.

    Relative paths are interpreted against ``project_path``; absolute paths are used as given.

    Returns:
        The resolved path.

    Raises:
        FileNotFoundError: if nothing exists at the resolved path.
    """
    if os.path.isabs(state_file_path):
        resolved = state_file_path
    else:
        resolved = os.path.join(project_path, state_file_path)

    if os.path.exists(resolved):
        return resolved
    raise FileNotFoundError(
        f"No se encontró el archivo de estado requerido: {resolved}. "
        "Genera los .state con generate_gym_states.py o ajusta la ruta."
    )

def build_env_overrides(state_file_path: str, headless: bool) -> Dict:
    """Build the environment-config overrides for a local training run.

    Args:
        state_file_path: emulator state to start from; its basename also names the session folder.
        headless: when True, use ``'rgb_array'`` rendering and enable ``fast_video``;
            otherwise use ``'human'`` rendering.

    Returns:
        A dict with the keys ``init_state``, ``headless``, ``save_video``, ``gb_path``,
        ``session_path``, ``render_mode`` and ``fast_video``.
    """
    session_name = f"local_{os.path.basename(state_file_path)}"
    overrides = dict(
        init_state=state_file_path,
        headless=headless,
        save_video=False,
        gb_path=os.path.join(project_path, 'PokemonRed.gb'),
        session_path=os.path.join(project_path, 'sessions', session_name),
        render_mode='rgb_array' if headless else 'human',
        fast_video=headless,
    )
    return overrides

def _patch_callbacks(agent, additional_callbacks: Optional[List] = None):
    base_callbacks_method = agent.extra_callbacks

    def _patched_callbacks(self):
        callbacks = list(base_callbacks_method())
        if additional_callbacks:
            callbacks.extend(additional_callbacks)
        return callbacks

    agent.extra_callbacks = types.MethodType(_patched_callbacks, agent)

def _use_multi_input_policy_if_dict_obs(agent) -> None:
    """Switch ``agent.policy_name`` to ``"MultiInputPolicy"`` when the agent's env has a Dict observation space.

    A throw-away env is built with ``agent.make_env()`` only to inspect its ``observation_space``.
    The env is always closed, even if the inspection raises.
    """
    env_for_check = agent.make_env()
    try:
        obs_space = getattr(env_for_check, 'observation_space', None)
        if isinstance(obs_space, spaces.Dict):
            print("Observación Dict detectada -> MultiInputPolicy")
            # Bound as a method, presumably because the agent base class calls agent.policy_name().
            agent.policy_name = types.MethodType(lambda self: "MultiInputPolicy", agent)
    finally:
        env_for_check.close()


def train_single_run(
    agent_key: str,
    scenario_id: str,
    phase_name: Optional[str],
    total_timesteps: int = 200_000,
    headless: bool = False,
    additional_callbacks: Optional[List] = None
):
    """Train one agent on one scenario phase and save the resulting model.

    Args:
        agent_key: key into ``AGENT_REGISTRY`` (``'combat'``, ``'puzzle'`` or ``'hybrid'``).
        scenario_id: scenario id from scenarios.json.
        phase_name: phase to train on. When falsy, this agent's registry ``default_phase`` is used.
        total_timesteps: training budget passed to the agent config.
        headless: run the emulator without a window (see ``build_env_overrides``).
        additional_callbacks: extra callbacks appended to ``agent.extra_callbacks()``.

    Returns:
        The object returned by ``agent.train()``. Its ``model`` is saved to
        ``MODELS_DIR/<agent_key>/<scenario_id>_<phase_name>.zip``.

    Raises:
        ValueError: unknown agent, scenario or phase.
        FileNotFoundError: the phase's ``state_file`` does not exist.
    """
    registry_entry = AGENT_REGISTRY.get(agent_key)
    if registry_entry is None:
        raise ValueError(f"Agente desconocido: {agent_key}")

    # Resolve the default here, so that (a) it is this agent's default rather than the combat
    # agent's, and (b) the log line and the model filename never contain "None".
    phase_name = phase_name or registry_entry['default_phase']
    phase = resolve_phase(scenario_id, phase_name)
    state_file_path = ensure_state_file(phase['state_file'])

    env_overrides = build_env_overrides(state_file_path, headless=headless)
    config = registry_entry['config_cls'](
        env_config=_base_env_config(env_overrides),
        total_timesteps=total_timesteps
    )

    agent = registry_entry['agent_cls'](config)
    _use_multi_input_policy_if_dict_obs(agent)

    if additional_callbacks:
        _patch_callbacks(agent, additional_callbacks)

    print(
        f"\n=== Entrenando {agent_key.upper()} en {scenario_id} ({phase_name}) por {total_timesteps:,} pasos ===")
    runtime = agent.train()

    agent_dir = os.path.join(MODELS_DIR, agent_key)
    os.makedirs(agent_dir, exist_ok=True)
    model_path = os.path.join(agent_dir, f"{scenario_id}_{phase_name}.zip")
    runtime.model.save(model_path)
    print(f"Modelo guardado en {model_path}")

    return runtime

def train_plan(
    agent_key: str,
    plan: List[Dict],
    default_timesteps: int = 200_000,
    headless: bool = False,
    callback_factory: Optional[Callable[[Dict], Optional[List]]] = None
) -> Dict[tuple, object]:
    """Run ``train_single_run`` for every entry of ``plan``, in order.

    Args:
        agent_key: key into ``AGENT_REGISTRY``.
        plan: list of dicts with a required ``'scenario'`` and optional ``'phase'`` and
            ``'timesteps'`` keys. A missing phase falls back to the agent's ``default_phase``;
            missing timesteps fall back to ``default_timesteps``.
        default_timesteps: training budget for entries without ``'timesteps'``.
        headless: forwarded to ``train_single_run``.
        callback_factory: called once per plan entry (with that entry); its return value is
            passed as ``additional_callbacks`` for that run.

    Returns:
        ``{(scenario_id, phase_name): runtime}``. If the plan repeats a (scenario, phase)
        pair, the later run overwrites the earlier one.

    Raises:
        ValueError: for an unknown ``agent_key`` (raised up front, matching ``train_single_run``,
            instead of a bare KeyError from the registry lookup).
    """
    registry_entry = AGENT_REGISTRY.get(agent_key)
    if registry_entry is None:
        raise ValueError(f"Agente desconocido: {agent_key}")
    default_phase = registry_entry['default_phase']

    results = {}
    total_runs = len(plan)
    for run_idx, entry in enumerate(plan, start=1):
        scenario_id = entry['scenario']
        phase_name = entry.get('phase') or default_phase
        run_timesteps = entry.get('timesteps', default_timesteps)
        callbacks = callback_factory(entry) if callback_factory is not None else None
        print(f"\n>>> [{agent_key.upper()}] Ejecución {run_idx}/{total_runs}")
        runtime = train_single_run(
            agent_key=agent_key,
            scenario_id=scenario_id,
            phase_name=phase_name,
            total_timesteps=run_timesteps,
            headless=headless,
            additional_callbacks=callbacks
        )
        results[(scenario_id, phase_name)] = runtime
    return results

### Configura los planes de entrenamiento locales
Especifica los escenarios, las fases y los timesteps de cada agente. Cada bloque de entrenamiento puede ejecutarse por separado; usa `DEFAULT_HEADLESS_LOCAL = True` para entrenar sin la ventana del emulador (SDL) o `False` si quieres verla.

In [9]:
# Local training plans: one dict per run with the scenario id, the phase name and the
# number of timesteps. Uncomment extra entries to train on more gyms.
combat_plan_local = [
    dict(scenario="pewter_brock", phase="battle", timesteps=200_000),
    # dict(scenario="cerulean_misty", phase="battle", timesteps=220_000),
]

puzzle_plan_local = [
    dict(scenario="pewter_brock", phase="puzzle", timesteps=180_000),
    # dict(scenario="cerulean_misty", phase="puzzle", timesteps=200_000),
]

hybrid_plan_local = [
    dict(scenario="pewter_brock", phase="battle", timesteps=220_000),
    # dict(scenario="vermillion_lt_surge", phase="battle", timesteps=250_000),
]

# Fallbacks for plan entries that omit "timesteps", and the window switch for every run.
DEFAULT_TIMESTEPS_LOCAL = 200_000
DEFAULT_HEADLESS_LOCAL = False  # Set to True if you do not need the SDL window

In [None]:
# Train the combat agent on every entry of its local plan.
combat_runs_local = train_plan(
    'combat',
    combat_plan_local,
    default_timesteps=DEFAULT_TIMESTEPS_LOCAL,
    headless=DEFAULT_HEADLESS_LOCAL,
)


>>> [COMBAT] Ejecución 1/1
Observación Dict detectada -> MultiInputPolicy

=== Entrenando COMBAT en pewter_brock (battle) por 200,000 pasos ===


In [None]:
# Train the puzzle agent on every entry of its local plan.
puzzle_runs_local = train_plan(
    'puzzle',
    puzzle_plan_local,
    default_timesteps=DEFAULT_TIMESTEPS_LOCAL,
    headless=DEFAULT_HEADLESS_LOCAL,
)

In [None]:
# Train the hybrid agent on every entry of its local plan.
hybrid_runs_local = train_plan(
    'hybrid',
    hybrid_plan_local,
    default_timesteps=DEFAULT_TIMESTEPS_LOCAL,
    headless=DEFAULT_HEADLESS_LOCAL,
)

In [None]:
# Save every model trained in this session.
# The previous version of this cell used AGENT_TYPE, SCENARIO_ID, PHASE_NAME and `model`, none of
# which is defined in the notebook, so it raised NameError on a fresh kernel. It now walks
# the result dicts returned by train_plan, skipping any agent whose training cell has not run.
# Note: train_single_run already saves each model as MODELS_DIR/<agent>/<scenario>_<phase>.zip;
# this cell additionally writes the flat "<agent>_<scenario>_<phase>" naming.
save_dir = MODELS_DIR
os.makedirs(save_dir, exist_ok=True)
runs_by_agent = {
    'combat': globals().get('combat_runs_local', {}),
    'puzzle': globals().get('puzzle_runs_local', {}),
    'hybrid': globals().get('hybrid_runs_local', {}),
}
saved_paths = []
for agent_type, runs in runs_by_agent.items():
    for (scenario_id, phase_name), runtime in runs.items():
        save_path = os.path.join(save_dir, f"{agent_type}_{scenario_id}_{phase_name}")
        runtime.model.save(save_path)
        saved_paths.append(save_path)
        print(f"Modelo guardado en {save_path}")
if not saved_paths:
    print("No hay modelos entrenados en esta sesión; ejecuta antes alguna celda de train_plan.")