In [1]:
import numpy as np
import cv2
from collections import deque
import gymnasium as gym
from pyboy import PyBoy
from pyboy.pyboy import WindowEvent
from gymnasium import spaces
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
import torch
import os 
import time
import matplotlib.pyplot as plt
from PIL import Image
from threading import Thread
import keyboard
from pathlib import Path



In [2]:
class InteractiveTrainer:
    def __init__(self, env, model):
        
        self.env = env
        self.model = model
        self.manual_mode = False
        self.current_obs = None
        self.running = True
        self.control_thread = None
        self.key_bindings = {
            'up':0,
            'down':1,
            'left':2,
            'right':3,
            'a':4,
            's':5,
            'enter':6,
            'space':7
        }
        
    def start(self):
        
        self.control_thread = Thread(target=self._control_loop)
        self.control_thread.start()
        
    def _control_loop(self):
        """Main control loop running in bg"""
        while self.running:
            if keyboard.is_pressed('h'):
                self.manual_mode = not self.manual_mode
                time.sleep(0.5)
                
                if keyboard.is_pressed('q'):
                    self.running = False
                    print("\nTraining interrupted by user")
                    time.sleep(0.5)
                    
                time.sleep(0.5)
                
    def get_action(self, obs):
        """Get the next action based on current mode"""
        self.current_obs = obs
        
        if self.manual_mode:
            return self._get_manual_action()
        else:
            return self.model.predict(obs, deterministic=False)[0]
        
    def _get_manual_action(self):
        """Get action from keyboard input"""
        action = None
        for key, action_code in self.key_bindings.items():
            if keyboard.is_pressed(key):
                action = action_code
                break
        return action if action is not None else 0
    

Environment

In [3]:
class PokemonRedEnv(gym.Env):
    def __init__(self, rom_path, headless=True, max_steps=10000, render_mode=None):
        super().__init__()
        
        self.rom_path = rom_path
        self.headless = headless
        self.render_mode = render_mode

        self.pyboy = PyBoy(rom_path, window="null" if headless else "SDL2", debug=False)
        self.game_wrapper = self.pyboy.game_wrapper
        self.screen_has_alpha = None
        self.action_space = spaces.Discrete(8)
        
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(4, 84, 84),
            dtype=np.uint8
        )
        
        self.frame_stack = deque(maxlen=4)
        
        self.step_count = 0
        self.max_steps = max_steps
        
        # Previous state tracking
        self.previous_badges = 0
        self.previous_pokemon_count = 0
        self.previous_levels = {}
        self.previous_areas_visited = set()
        self.previous_trainer_battles = 0 
        self.visited_maps = set()
        self.previous_battle_state = 0
        self.previous_position = (0, 0)
        
        # Game state flags
        self.has_pokemon = False
        self.has_left_oaks_lab = False
        self.has_left_pallet_town = False
        self.previous_menu_state = 0
        self.was_in_battle = False
        
        # Memory addresses
        self.BADGES_ADDRESS = 0xD356
        self.POKEMON_COUNT_ADDRESS = 0xD163
        self.PLAYER_X_ADDRESS = 0xD362
        self.PLAYER_Y_ADDRESS = 0xD361
        self.CURRENT_MAP_ADDRESS = 0xD35E
        
        self.action_names = {
            0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT", 
            4: "A", 5: "B", 6: "START", 7: "SELECT"
        }
        
        # Episode tracking
        self.episode_rewards = []
        self.episode_steps = []
        self.episode_badges = []
        
        self._init_frame_stack()
        
    def _init_frame_stack(self):
        """Initialize frame stack with identical frames"""
        initial_frame = self._get_processed_frame()
        for _ in range(4):
            self.frame_stack.append(initial_frame)
            
    def _get_processed_frame(self):
        """Get processed frame: grayscale + resize to 84x84"""
        screen = self.pyboy.screen.ndarray
        
        if self.screen_has_alpha is None:
            self.screen_has_alpha = screen.shape[2] == 4
        
        if self.screen_has_alpha:
            screen = screen[:, :, :3]   

        gray = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        
        return resized
    
    def _get_raw_screen(self):
        """Get raw screen for visualization"""
        screen = self.pyboy.screen.ndarray
        if screen.shape[2] == 4:
            screen = screen[:, :, :3]
        return screen
    
    def _get_observation(self):
        """Get stacked frames as observation"""
        return np.array(self.frame_stack, dtype=np.uint8)
    
    def _read_memory(self, address):
        """Read value from game memory"""
        try:
            if hasattr(self.pyboy, 'memory') and address is not None:
                return self.pyboy.memory[address]
        except Exception as e:
            print(f"Error reading memory at {hex(address)}: {e}")
        return 0
        
    def _get_badges_count(self):
        """Get number of badges earned"""
        badges_byte = self._read_memory(self.BADGES_ADDRESS)
        return bin(badges_byte).count('1')
    
    def _get_pokemon_count(self):
        """Get number of Pokemon caught"""
        return self._read_memory(self.POKEMON_COUNT_ADDRESS)
    
    def _get_player_position(self):
        """Get player position"""
        x = self._read_memory(self.PLAYER_X_ADDRESS)
        y = self._read_memory(self.PLAYER_Y_ADDRESS)
        return (x, y)
    
    def _get_current_map(self):
        """Get current map ID"""
        return self._read_memory(self.CURRENT_MAP_ADDRESS)
    
    def _is_in_battle(self):
        """Check if currently in a battle"""
        battle_type = self._read_memory(0xD057)
        return battle_type != 0 
    
    def get_party_hp(self):
        """Get total HP of all Pokemon in party"""
        total_hp = 0
        for i in range(6):
            base_address = 0xD16B + (i * 44)
            species = self._read_memory(base_address)
            if species != 0:
                hp_address = base_address + 0x22
                current_hp = self._read_memory(hp_address)
                total_hp += current_hp
        return total_hp
    
    def _get_total_xp(self):
        """Get total experience points of all Pokemon"""
        total_xp = 0
        for i in range(6):
            xp_addr = 0xD16B + (i * 44) + 8
            xp = (self._read_memory(xp_addr) << 16) + \
                (self._read_memory(xp_addr + 1) << 8) + \
                (self._read_memory(xp_addr + 2))
            total_xp += xp
        return total_xp

    def _get_pokemon_levels(self):
        """Get levels of all Pokemon in party"""
        levels = {}
        for i in range(6):
            species_addr = 0xD16B + (i * 44)
            level_addr = species_addr + 33

            species = self._read_memory(species_addr)
            if species > 0:
                level = self._read_memory(level_addr)
                levels[species] = level
        return levels

    def _is_done(self):
        """Check if episode should end"""
        if self.step_count >= self.max_steps:
            return True

        if self._get_badges_count() >= 8:
            return True
        
        return False

    def reset(self, seed=None, options=None):
        """Reset the environment"""
        super().reset(seed=seed)
        
        self.pyboy.stop()
        self.pyboy = PyBoy(self.rom_path, window="null" if self.headless else "SDL2", debug=False)
        self.game_wrapper = self.pyboy.game_wrapper

        self._skip_intro()

        # Reset all state variables
        self.step_count = 0
        self.previous_badges = 0
        self.previous_pokemon_count = 0
        self.previous_levels = {}
        self.previous_areas_visited = set()
        self.previous_trainer_battles = 0
        self.previous_position = (0, 0)
        self.visited_maps = set()
        self.previous_battle_state = 0
        self.was_in_battle = False

        self._init_frame_stack()
        
        return self._get_observation(), {}

    def _skip_intro(self):
        """Skip the game intro sequence"""
        for _ in range(200):
            self.pyboy.tick()

        for _ in range(20):
            self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A)
            for _ in range(5):
                self.pyboy.tick()
            self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A)
            for _ in range(10):
                self.pyboy.tick()

    def step(self, action):
        """Execute one step in the environment"""
        self._execute_action(action)

        new_frame = self._get_processed_frame()
        self.frame_stack.append(new_frame)

        observation = self._get_observation()
        reward = self._calculate_reward()
        done = self._is_done()

        self.step_count += 1
        
        info = {
            'badges': self._get_badges_count(),
            'pokemon_count': self._get_pokemon_count(),
            'position': self._get_player_position(),
            'map_id': self._get_current_map(),
            'step_count': self.step_count,
            'total_xp': self._get_total_xp(),
            'party_hp': self.get_party_hp(),
            'in_battle': self._is_in_battle(),
            'action_name': self.action_names[action],
            'reward': reward
        }

        return observation, reward, done, False, info

    def _execute_action(self, action):
        """Execute the given action"""
        actions = {
            0: self._move_up,
            1: self._move_down,
            2: self._move_left,
            3: self._move_right, 
            4: self._press_a,
            5: self._press_b,
            6: self._press_start,
            7: self._press_select
        }

        actions[action]()

    def _move_up(self):
        self.pyboy.send_input(WindowEvent.PRESS_ARROW_UP)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_ARROW_UP)
        for _ in range(4):
            self.pyboy.tick()

    def _move_down(self):
        self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN)
        for _ in range(4):
            self.pyboy.tick()

    def _move_left(self):
        self.pyboy.send_input(WindowEvent.PRESS_ARROW_LEFT)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_ARROW_LEFT)
        for _ in range(4):
            self.pyboy.tick()

    def _move_right(self):
        self.pyboy.send_input(WindowEvent.PRESS_ARROW_RIGHT)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_ARROW_RIGHT)
        for _ in range(4):
            self.pyboy.tick()

    def _press_a(self):
        self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A)
        for _ in range(4):
            self.pyboy.tick()

    def _press_b(self):
        self.pyboy.send_input(WindowEvent.PRESS_BUTTON_B)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_B)
        for _ in range(4):
            self.pyboy.tick()

    def _press_start(self):
        self.pyboy.send_input(WindowEvent.PRESS_BUTTON_START)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
        for _ in range(4):
            self.pyboy.tick()
        
    def _press_select(self):
        self.pyboy.send_input(WindowEvent.PRESS_BUTTON_SELECT)
        for _ in range(8):
            self.pyboy.tick()
        self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_SELECT)
        for _ in range(4):
            self.pyboy.tick()

    def _calculate_reward(self):
        """Calculate reward based on game state changes"""
        reward = 0
        reward_components = {}

        current_badges = self._get_badges_count()
        current_pokemon_count = self._get_pokemon_count()
        current_position = self._get_player_position()
        current_map = self._get_current_map()
        currently_in_battle = self._is_in_battle()

        # Badge rewards
        if current_badges > self.previous_badges:
            badge_reward = 1000 * (current_badges - self.previous_badges)
            reward += badge_reward
            reward_components['badges'] = badge_reward
            self.previous_badges = current_badges

        # Pokemon count rewards
        if current_pokemon_count > self.previous_pokemon_count:
            pokemon_reward = 100 * (current_pokemon_count - self.previous_pokemon_count)
            reward += pokemon_reward
            reward_components['pokemon_caught'] = pokemon_reward
            self.previous_pokemon_count = current_pokemon_count

        # Battle start reward
        if currently_in_battle and not self.was_in_battle:
            battle_start_reward = 30
            reward += battle_start_reward
            reward_components['battle_start'] = battle_start_reward
            self.battle_start_hp = self.get_party_hp()
            self.battle_start_time = self.step_count

        # Battle end rewards
        if self.was_in_battle and not currently_in_battle:
            current_hp = self.get_party_hp()
            if hasattr(self, 'battle_start_hp') and current_hp >= self.battle_start_hp:
                battle_duration = self.step_count - getattr(self, 'battle_start_time', self.step_count)
                time_bonus = max(0, 50 - battle_duration // 10)
                battle_win_reward = 100 + time_bonus
                reward += battle_win_reward
                reward_components['battle_win'] = battle_win_reward
            else:
                battle_loss_penalty = -50
                reward += battle_loss_penalty
                reward_components['battle_loss'] = battle_loss_penalty

            # Check for blackout
            if self._read_memory(0xCD3A) == 0x3E:
                blackout_penalty = -100
                reward += blackout_penalty
                reward_components['blackout'] = blackout_penalty

        # Exploration rewards
        if current_map not in self.visited_maps:
            self.visited_maps.add(current_map)
            exploration_reward = 20 + (5 if current_map > 20 else 0)
            reward += exploration_reward
            reward_components['exploration'] = exploration_reward

        # Early game rewards
        early_game_reward = self._get_early_game_rewards(current_map, current_position)
        if early_game_reward > 0:
            reward += early_game_reward
            reward_components['early_game'] = early_game_reward

        # Movement rewards
        if current_position != self.previous_position:
            movement_reward = 0.5
            if self._is_progressive_movement(self.previous_position, current_position, current_map):
                movement_reward += 0.5
            reward += movement_reward
            reward_components['movement'] = movement_reward

        # Wild Pokemon encounter
        wild_pokemon_flag = self._read_memory(0xD12A)
        if wild_pokemon_flag > 0:
            wild_encounter_reward = 25
            reward += wild_encounter_reward
            reward_components['wild_encounter'] = wild_encounter_reward

        # Level up rewards
        current_levels = self._get_pokemon_levels()
        if hasattr(self, 'previous_levels'):
            total_level_reward = 0
            for pokemon_id, level in current_levels.items():
                if pokemon_id in self.previous_levels:
                    level_gain = level - self.previous_levels[pokemon_id]
                    if level_gain > 0:
                        total_level_reward += 20 * level_gain
            
            if total_level_reward > 0:
                reward += total_level_reward
                reward_components['level_up'] = total_level_reward

        # Update state
        self.previous_levels = current_levels.copy()
        self.was_in_battle = currently_in_battle
        self.previous_position = current_position

        # Small penalty for each step to encourage efficiency
        reward -= 0.05

        return reward

    def _get_early_game_rewards(self, current_map, current_position):
        """Calculate early game specific rewards"""
        # Placeholder for early game logic
        # You can implement specific rewards for early game progression
        return 0

    def _is_progressive_movement(self, prev_pos, curr_pos, map_id):
        """Check if movement is progressive (exploring new areas)"""
        # Simple heuristic - you can make this more sophisticated
        distance = abs(curr_pos[0] - prev_pos[0]) + abs(curr_pos[1] - prev_pos[1])
        return distance > 0

    def render(self):
        """Render the environment"""
        if self.render_mode == 'human':
            return self._get_raw_screen()
        return None
        
    def close(self):
        """Close the environment"""
        self.pyboy.stop()

Pokemon Environment

In [4]:
def create_pokemon_env(rom_path, headless=True, max_steps=50000, render_mode=None):
    """Factory function to create Pokemon environment"""
    env = PokemonRedEnv(rom_path, headless=headless, max_steps=max_steps, render_mode=render_mode)
    
    if not headless:
        import pygame
        pygame.init()
        env.info_screen = pygame.display.set_mode((400, 300))
        env.font = pygame.font.Font(None, 24)
    
    return env

Interactive play

In [5]:
def interactive_play(rom_path):

    print("Starting interactive play mode")
    print("use arrows to move, A/S for A/B actions, enter for START, space for SELECT")
    
    env = PokemonRedEnv(rom_path, headless=False, max_steps=10000, render_mode='human')
    obs, _ = env.rest()

    import pygame
    pygame.init()

    screen = pygame.display.set_mode((400, 300 ))
    pygame.display.set_caption("Pokemon Red Controls")
    font = pygame.font.Font(None, 24)

    running = True
    clock = pygame.time.Clock()

    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                action = None
                if event.key == pygame.K_UP:
                    action = 0
                elif event.key == pygame.K_DOWN:
                    action = 1
                elif event.key == pygame.K_LEFT:
                    action = 2
                elif event.key == pygame.K_RIGHT:
                    action = 3
                elif event.key == pygame.K_a:
                    action = 4
                elif event.key == pygame.K_s:
                    action = 5
                elif event.key == pygame.K_RETURN:
                    action = 6
                elif event.key == pygame.K_SPACE:
                    action = 7
                elif event.key == pygame.K_q:
                    running = False
                    
                if action is not None:
                    obs, reward, done, _, info = env.step(action)
                    print(f"Action: {env.action_names[action]}, "
                          f"Reward: {reward}, "
                          f"Badges: {info['badges']}, "
                          f"Pokemon Count: {info['pokemon_count']}, ")
                    
                    if done:
                        print("Episode finished")
                        obs, _ = env.reset()

        screen.fill((0, 0, 0))
        instructions = [
            "Pokemon red controls:",
            "arrow keys: move",
            "A/S: A/B actions",
            "Enter: START",
            "Space: SELECT",
            "Q: Quit"
        ]


        for i, instruction in enumerate(instructions):
            text = font.render(instruction, True, (255, 255, 255))
            screen.blit(text, (10, 10 + i * 30))

        pygame.display.flip()
        clock.tick(60)

    env.close()
    pygame.quit()

watch agent play

In [6]:
def watch_agent_play(model_path, rom_path , episodes= 5,delay = 0.1):
    print(f"loading model from {model_path}")

    try:
        model = DQN.load(model_path)
        print("model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    
    env = PokemonRedEnv(rom_path, headless= False, max_steps = 10000)

    for episode in range(episodes):
        print(f"Episode {episode + 1}/{episodes}")
        obs, _ = env.reset()
        total_reward = 0
        episode_steps = 0

        time.sleep(2)

        for step in range(10000):
            action, _ = model.predict(obs, deterministic = False)
            obs, reward, done, _, info = env.step(action)
            total_reward += reward
            episode_steps += 1

            time.sleep(delay)


        if episode  < episodes - 1:
            print(f"Episode {episode + 1} finished: Total Reward: {total_reward}, Steps: {episode_steps}")

    env.close()


Gym Environment

In [7]:
class PokemonRedEnv(gym.Env):
    def _get_early_game_rewards(self, current_map, current_position):
        reward = 0

        if current_map == 13 and not getattr(self, 'reached_route_1', False):
            reward += 50
            self.reached_route_1 = True


        if self._get_pokemon_count() > 0 and not getattr(self, 'has_pokemon', False):
            reward += 200
            self.has_pokemon = True

        return reward
    

Train pokemon agent

In [8]:
def _train_pokemon_agent(show_training = False, interactive = False):

    ROM_PATH = r"C:\Users\Kamlesh\Documents\Code aur Masti\pokemonrl\Pokemon Red (USA, Europe).gb"
    MODEL_SAVE_PATH = "./pokemon_dqn_model"
    LOG_PATH = "./pokemon_logs"

    os.makedirs(MODEL_SAVE_PATH, exist_ok=True )
    os.makedirs(LOG_PATH, exist_ok=True)

    env = create_pokemon_env(ROM_PATH, headless= not(show_training or interactive))


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   
    print(f"Using device: {device}")

    model = DQN(
        "CnnPolicy",
        env,
        learning_rate=1e-4,
        buffer_size=50000,
        learning_starts=10000,
        batch_size=32,
        tau=1.0,
        gamma=0.99,
        train_freq=4,
        gradient_steps=1,
        target_update_interval=10000,
        exploration_fraction= 0.3,
        exploration_initial_eps=1.0,
        exploration_final_eps=0.05,
        max_grad_norm=0.5,
        tensorboard_log=LOG_PATH,
        device=device,
        verbose=1
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=50000,
        save_path=MODEL_SAVE_PATH,
        name_prefix="pokemon_dqn",
        verbose=1
    )

    print("Starting training...")
    if interactive:
        print("Starting interactive training mode. Press 'h' to toggle manual control, 'q' to quit.")

    try:
        if interactive:
            trainer = InteractiveTrainer(env, model)
            trainer.start()

            obs = env.reset()
            total_reward = 0
            episode = 0

            while trainer.running:
                action = trainer.get_action(obs)
                obs, reward, done,  info = env.step(action)
                total_reward += reward

                if done:
                    print(f"Episode {episode + 1} finished: Total Reward: {total_reward}")
                    obs = env.reset()
                    total_reward = 0
                    episode += 1

                    if episode % 10 ==0:
                        model.save(f"{MODEL_SAVE_PATH}/pokemon_dqn_episode_{episode}")

                time.sleep(0.05)

        else:

            model.learn(
                total_timesteps = 2000000,
                callback = checkpoint_callback,
                log_interval = 100,
                progress_bar= True
            )

        model.save(f"{MODEL_SAVE_PATH}/pokemon_dqn_final")        
        print("Training completed and model saved.")

    except KeyboardInterrupt:
        print("Training interrupted by user. Saving model...")
        model.save(f"{MODEL_SAVE_PATH}/pokemon_dqn_interrupted")
        print("Model saved.")
        
    finally:
        if interactive and 'trainer' in locals():
            trainer.running = False
            if trainer.control_thread is not None:
                trainer.control_thread.join()
        env.close()

def manual_early_game_demo(rom_path):
    print("manual early game demo")

    env = PokemonRedEnv(rom_path, headless=False, max_steps=5000, render_mode='human')
    obs, _ = env.reset()

    time.sleep(2)

    demo_actions = [
            
        (1, "Moving down to Oak"),
        (1, "Moving down to Oak"),
        (4, "Press A to talk to Oak"),
        (4, "Continue dialogue"),
        (4, "Continue dialogue"),
        (4, "Continue dialogue"),
        (4, "Continue dialogue"),

        (0, "Moving up toward Pokeballs"),
        (3, "Moving right toward Pokeballs"),
        (4, "Press A on Pokeball"),
        (4, "Choose Pokemon"),
        (4, "Confirm choice"),
        (4, "Continue dialogue"),
        (4, "Continue dialogue"),

        (1, "Moving toward exit"),
        (1, "Moving toward exit"),
        (1, "Moving toward exit"),
        (4, "Press A at door (if needed)"),

        (0, "Moving north in Pallet Town"),
        (0, "Moving north in Pallet Town"),
        (0, "Moving north in Pallet Town"),
    ]

    total_reward = 0
    for i, (action, description) in enumerate(demo_actions):
        obs, reward, done, _, infor = env.step(action)
        total_reward += reward

        time.sleep(1)

        if done:
            print("demo completed")
            break

    print(f"Demo finished: Total Reward: {total_reward}")
    input("press enter to exit")
    env.close()


test environment

In [9]:
def test_environment(rom_path):
    print("testing environment")
    
    try:
        env = PokemonRedEnv(rom_path, headless=False, max_steps=1000, render_mode='human')
        print("Environment initialized successfully")

        obs, _ = env.reset()
        print("Environment reset successfully")

        for i in range(10):
            action = env.action_space.sample()
            obs, reward, done, _, info = env.step(action)
            print(f"Step {i + 1}: Action: {env.action_names[action]}, Reward: {reward}, Info: {info}")
            time.sleep(0.5)


        env.close()
        print("Environment closed successfully")

    except Exception as e:
        print(f"environment test failed: {e}")


curriculum training

In [10]:
def smart_training_with_curriculum(rom_path):
    print("starting smart training")

    MODEL_SAVE_PATH= "./pokemon_dqn_model"
    LOG_PATH = "./pokemon_logs"

    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    os.makedirs(LOG_PATH, exist_ok=True)


    env = DummyVecEnv([lambda: create_pokemon_env(rom_path, headless=True, max_steps=2000)])


    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"using device:{device}")

    model = DQN(
        "CnnPolicy",
        env,
        learning_rate = 2e-4,
        buffer_size= 50000,
        learning_starts= 5000,
        batch_size=32,
        tau=1.0,
        gamma=0.95,
        train_freq=4,
        gradient_steps=1,
        target_update_interval=5000,
        exploration_fraction=0.3,
        exploration_initial_eps=1.0,
        exploration_final_eps= 0.1,
        max_grad_norm= 10,
        tensorboard_log=LOG_PATH,
        device = device,
        verbose=1
    )

    print("training phase 1")
    model.learn(total_timesteps=5000000,progress_bar= True)
    model.save(f"{MODEL_SAVE_PATH}/pokemon_dqn_phase1")

    print("phase 2: full gameplay training")
    env.close()
    def make_env():
        return create_pokemon_env(rom_path, headless = True, max_steps = 50000)
    env = DummyVecEnv([lambda: create_pokemon_env(rom_path, headless=True, max_steps = 50000)])

    model.learning_rate = 1e-4,
    model.gamma = 0.99,
    model.exploration_final_eps = 0.02

    checkpoint_callback = CheckpointCallback(
        save_freq=100000,
        save_path= MODEL_SAVE_PATH,
        name_prefix="pokemon_dqn_phase2"
    )

    print("training phase 2: full game training")

    model.learn(
        total_timesteps=1500000,
        callback=checkpoint_callback,
        progress_bar=True
    )
    
    model.save(f"{MODEL_SAVE_PATH}/pokemon_dqn_final_curriculum")
    print("curriculum training completed")

    env.close()


main

In [11]:
if __name__ == "__main__":

    print("Pokemon Red DQN training & visualization")
    print("=" * 50)

    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("CUDA not available, using CPU")

    print("\nOptions:")
    print("1. Train new model (headless)")
    print("2. Train new model (with visualization)")
    print("3. Interactive training ")
    print("4. Train with curriculum learning")
    print("5. Watch trained agent play")
    print("6. Play interactively")
    print("7. Manual early game demo")
    print("8. Test environment setup")

    choice = input("\nChoice (1-8): ")

    ROM_PATH = "Pokemon Red (USA, Europe).gb"

    if choice == "1":
        _train_pokemon_agent(show_training=False)
    elif choice == "2":
        _train_pokemon_agent(show_training=True)
    elif choice == "3":
        _train_pokemon_agent(show_training=True, interactive=True)
    elif choice == "4":
            smart_training_with_curriculum(ROM_PATH)
    elif choice == "5":
        model_path = input("enter model path: ")
        episodes = int(input("number of episodes to watch (default 3): "))
        delay = float(input("delay between actions in seconds (default 0.1): "))
        watch_agent_play(model_path, ROM_PATH, episodes, delay)
    elif choice == "6":
        interactive_play(ROM_PATH)
    elif choice == "7":
        manual_early_game_demo(ROM_PATH)
    elif choice == "8":
        test_environment(ROM_PATH)
    else:
        print("invalid choice")



Pokemon Red DQN training & visualization
CUDA not available, using CPU

Options:
1. Train new model (headless)
2. Train new model (with visualization)
3. Interactive training 
4. Train with curriculum learning
5. Watch trained agent play
6. Play interactively
7. Manual early game demo
8. Test environment setup


TypeError: object.__new__() takes exactly one argument (the type to instantiate)

In [None]:
import gymnasium as gym

print("Registered environments:")
# Correct way for Gymnasium 1.0.0
for env_id in gym.envs.registry:
    if 'pokemon' in env_id.lower() or 'red' in env_id.lower():
        print(f"  {env_id}")

# Also check all environments (first few)
print("\nAll registered environments (first 10):")
env_list = list(gym.envs.registry.keys())
for env_id in env_list[:10]:
    print(f"  {env_id}")

Registered environments:

All registered environments (first 10):
  CartPole-v0
  CartPole-v1
  MountainCar-v0
  MountainCarContinuous-v0
  Pendulum-v1
  Acrobot-v1
  phys2d/CartPole-v0
  phys2d/CartPole-v1
  phys2d/Pendulum-v0
  LunarLander-v3
