In [1]:
# Configuration Section
import os
import datetime
import json
from pathlib import Path
import gymnasium as gym
import ale_py
import torch
import pandas as pd
import matplotlib.pyplot as plt
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation, RecordVideo, TransformReward
from agents.reinforce import ReinforcePolicyGradientsAgent
from agents.dqn import DQNAgent
from models.cnns import CNNBackbone
from wrappers.gym_wrappers import ProgressiveRewardWrapper, ChannelFirstWrapper, NormalizeWrapper, ChannelWiseFrameStack

# Enable autoreload for all modules
%load_ext autoreload
%autoreload 2

# Training Configuration
# Single source of truth for a run: the whole dict is dumped to config.json in
# the experiment directory and is also stored inside every training-state
# checkpoint, so every value must stay JSON-serializable.
CONFIG = {
    # Environment Settings (env_name/mode/difficulty/frameskip go to gym.make)
    "env_name": "ALE/Surround-v5",
    "game": "Surround",  # In this notebook only used to name the experiment directory
    "mode": 2,
    "difficulty": 1,
    "frameskip": 6,

    # Directory Settings
    # Layout: <base_dir>/<game>_<agent_type>_<timestamp>/<subdir>
    "base_dir": "./experiments",
    "video_subdir": "videos",
    "checkpoint_subdir": "checkpoints",
    "log_subdir": "logs",

    # Training Settings
    "num_episodes": 5000,
    "render_every_n": 5,  # Record a video when episode_id % render_every_n == 0
    "save_every_n": 5,  # Write a checkpoint and refresh the plots every N episodes

    # Environment Preprocessing
    "screen_size": 60,  # Passed to AtariPreprocessing(screen_size=...)
    "grayscale": False,  # Passed to AtariPreprocessing(grayscale_obs=...)
    "frame_stack": 0,  # Number of frames to stack; set to 0 to disable stacking
    "normalize": False, # Leave off: AtariPreprocessing is already created with scale_obs=True
    "base_survival_reward": 1,  # Passed to ProgressiveRewardWrapper; kept low since it will scale up
    # Currently disabled option: "survival_scaling_factor": 1.05 (exponential growth factor)
    "reward_scaling_factor": 1, # Rewards are multiplied by this; 1 skips the TransformReward wrapper

    # Agent Selection
    # Options: "reinforce" or "dqn" (matched case-insensitively by create_agent)
    "agent_type": "dqn",

    # Shared Agent Settings (passed to whichever agent is selected)
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "learning_rate": 1e-3,
    "gamma": 0.995, # Discount factor
    "use_cnn": True,

    # REINFORCE-specific Settings (only read when agent_type is "reinforce")
    "reinforce": {
        "max_grad_norm": 0.5,
        "entropy_coef": 1e-4
    },

    # DQN-specific Settings (only read when agent_type is "dqn")
    "dqn": {
        # Buffer settings
        "buffer_size": 100000,
        "batch_size": 512,

        # Network update settings
        # NOTE(review): units (environment steps vs. gradient updates) are
        # defined inside DQNAgent and are not visible here - confirm there.
        "target_update_freq": 1000,
        "update_freq": 4,
        "gradient_clip": 2.0,
        "hidden_dims": [256, 128],

        # Exploration settings
        # NOTE(review): whether eps_decay is applied per step or per episode is
        # decided by DQNAgent - confirm before tuning.
        "eps_start": 1.0,
        "eps_end": 0.05,
        "eps_decay": 0.99,

        # Algorithm settings
        "double_dqn": True,

        # Prioritized Experience Replay settings
        "per_alpha": 0.6,        # How much prioritization to use (0 = uniform, 1 = full prioritization)
        "per_beta_start": 0.4,   # Initial importance sampling correction
        "per_beta_end": 1.0,     # Final importance sampling correction
        "per_beta_steps": 100000 # Steps over which to anneal beta
    }
}


In [2]:

# Give this notebook session its own timestamped experiment directory.
# `experiment_dir` stays a module-level name because create_env() reads it.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_dir = Path(CONFIG["base_dir"], f"{CONFIG['game']}_{CONFIG['agent_type']}_{timestamp}")

# Pre-create the video, checkpoint and log folders inside it.
for subdir in (CONFIG[key] for key in ("video_subdir", "checkpoint_subdir", "log_subdir")):
    experiment_dir.joinpath(subdir).mkdir(parents=True, exist_ok=True)

# Keep a JSON snapshot of the configuration next to the run outputs.
with (experiment_dir / "config.json").open("w") as f:
    f.write(json.dumps(CONFIG, indent=4))

In [3]:

def create_env(config, video_dir=None):
    """Create and wrap the environment according to configuration.

    Wrappers are applied in this order: ProgressiveRewardWrapper, optional
    TransformReward (reward scaling), AtariPreprocessing, RecordVideo,
    ChannelFirstWrapper, optional FrameStackObservation + ChannelWiseFrameStack,
    optional NormalizeWrapper.

    Args:
        config: Configuration dict. Reads "env_name", "mode", "difficulty",
            "frameskip", "base_survival_reward", "reward_scaling_factor",
            "screen_size", "grayscale", "render_every_n", "frame_stack",
            "normalize" and, when ``video_dir`` is not given, "video_subdir".
        video_dir: Folder (str or Path) that recorded videos are written to.
            When omitted, the folder is ``experiment_dir / video_subdir`` using
            the module-level ``experiment_dir``, which is the previous
            hard-wired behaviour. Pass it explicitly to avoid depending on
            notebook-global state.

    Returns:
        The fully wrapped gymnasium environment.
    """
    env = gym.make(
        config["env_name"],
        render_mode="rgb_array",
        mode=config["mode"],
        difficulty=config["difficulty"],
        frameskip=config["frameskip"]
    )

    # Progressive reward wrapper
    env = ProgressiveRewardWrapper(
        env,
        base_survival_reward=config["base_survival_reward"]
    )

    # A factor of exactly 1 would be a no-op, so the wrapper is skipped.
    if config["reward_scaling_factor"] != 1:
        env = TransformReward(env, lambda r: r * config["reward_scaling_factor"])

    # Preprocessing wrappers
    # frame_skip=1 here: frame skipping is already requested from the base
    # environment through `frameskip` above.
    env = AtariPreprocessing(
        env,
        noop_max=15,
        frame_skip=1,
        screen_size=config["screen_size"],
        scale_obs=True,
        grayscale_obs=config["grayscale"]
    )

    # Video wrapper
    if video_dir is None:
        # Backward-compatible default: relies on the module-level run directory.
        video_dir = experiment_dir / config["video_subdir"]
    env = RecordVideo(
        env,
        video_folder=str(video_dir),
        episode_trigger=lambda episode_id: episode_id % config["render_every_n"] == 0,
        fps=60
    )

    env = ChannelFirstWrapper(env)

    # Frame stacking is optional; 0 (or a negative value) disables it.
    if config["frame_stack"] > 0:
        env = FrameStackObservation(env, config["frame_stack"])
        env = ChannelWiseFrameStack(env)

    if config["normalize"]:
        env = NormalizeWrapper(env)

    return env

def create_agent(env, config):
    """Build the agent selected by ``config["agent_type"]``.

    The agent type is matched case-insensitively against "reinforce" and
    "dqn". Shared settings (device, use_cnn, learning_rate, gamma) are passed
    to either agent; the remaining arguments come from the matching
    ``config["reinforce"]`` or ``config["dqn"]`` sub-dict.

    Raises:
        ValueError: If the agent type is neither "reinforce" nor "dqn".
    """
    agent_kind = config["agent_type"].lower()
    if agent_kind not in ("reinforce", "dqn"):
        raise ValueError(f"Unknown agent type: {config['agent_type']}")

    # Constructor arguments common to both agent classes.
    shared_kwargs = dict(
        env=env,
        device=config["device"],
        use_cnn=config["use_cnn"],
        lr=config["learning_rate"],
        gamma=config["gamma"],
    )

    if agent_kind == "reinforce":
        reinforce_cfg = config["reinforce"]
        return ReinforcePolicyGradientsAgent(
            **shared_kwargs,
            max_grad_norm=reinforce_cfg["max_grad_norm"],
            entropy_coef=reinforce_cfg["entropy_coef"],
        )

    # Only "dqn" is left at this point. A `tau` setting is not forwarded.
    dqn_cfg = config["dqn"]
    return DQNAgent(
        **shared_kwargs,
        buffer_size=dqn_cfg["buffer_size"],
        batch_size=dqn_cfg["batch_size"],
        target_update_freq=dqn_cfg["target_update_freq"],
        eps_start=dqn_cfg["eps_start"],
        eps_end=dqn_cfg["eps_end"],
        eps_decay=dqn_cfg["eps_decay"],
        hidden_dims=dqn_cfg["hidden_dims"],
        gradient_clip=dqn_cfg["gradient_clip"],
        double_dqn=dqn_cfg["double_dqn"],
        update_freq=dqn_cfg["update_freq"],
        per_alpha=dqn_cfg["per_alpha"],
        per_beta_start=dqn_cfg["per_beta_start"],
        per_beta_end=dqn_cfg["per_beta_end"],
        per_beta_steps=dqn_cfg["per_beta_steps"],
    )

In [4]:
def plot_training_history(history, save_dir):
    """Plot and save training metrics.

    Writes ``training_history.png`` (one stacked panel per non-empty metric
    group) plus one ``<metric>_history.png`` per recorded column, except the
    bookkeeping columns 'steps', 'total_steps' and 'buffer_size'.

    Args:
        history: List of per-episode result dicts (one row per episode).
        save_dir: Output directory (str or Path); created if missing.

    Returns:
        None. An empty history only creates ``save_dir`` and writes no figure.
    """
    save_dir = Path(save_dir)
    df = pd.DataFrame(history)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Nothing recorded yet: plt.subplots() cannot build a zero-row grid.
    if df.empty:
        return

    # Metric groupings for the combined figure; only columns actually present
    # in the data are drawn, and groups that end up empty get no panel.
    metric_groups = [
        ('Value Metrics', ['total_return', 'avg_q_value']),  # DQN and REINFORCE
        ('Loss Metrics', ['q_loss', 'policy_loss', 'entropy_loss', 'total_loss']),  # Combined
        ('Auxiliary Metrics', ['eps']),  # DQN-specific
    ]
    panels = []
    for title, candidates in metric_groups:
        present = [m for m in candidates if m in df.columns]
        if present:
            panels.append((title, present))

    if panels:
        n_rows = len(panels)
        # squeeze=False always yields a 2-D axes array, so the single-panel
        # case needs no special handling.
        fig, axes = plt.subplots(n_rows, 1, figsize=(12, 5 * n_rows), squeeze=False)
        for ax, (title, metrics) in zip(axes[:, 0], panels):
            for metric in metrics:
                ax.plot(df[metric], label=metric)
            ax.set_title(title)
            ax.set_xlabel('Episode')
            ax.legend()
        fig.tight_layout()
        fig.savefig(save_dir / "training_history.png")
        plt.close(fig)

    # Also save individual plots
    excluded_metrics = {'steps', 'total_steps', 'buffer_size'}
    for metric in df.columns.tolist():
        if metric in excluded_metrics:
            continue
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(df[metric])
        ax.set_title(metric)
        ax.set_xlabel('Episode')
        ax.set_ylabel(metric)
        fig.savefig(save_dir / f"{metric}_history.png")
        # Close explicitly: this runs at every checkpoint, so figures would pile up.
        plt.close(fig)

In [5]:

def save_checkpoint(agent, episode, history, save_dir, config):
    """Write the agent and the training progress as two timestamped files.

    The agent is saved through ``agent.save`` to
    ``agent_<agent_type>_episode_<episode>_<timestamp>.pth``. The episode
    number, history and config are saved with ``torch.save`` to
    ``training_state_episode_<episode>_<timestamp>.pt``. Both files share the
    same timestamp and live in ``save_dir``, which is created if needed.

    Returns:
        Tuple ``(agent_path, checkpoint_path)`` of the two files written.
    """
    # One timestamp for both files so they can be paired up by name.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir.mkdir(parents=True, exist_ok=True)

    # Agent weights / state
    agent_path = save_dir / f"agent_{config['agent_type']}_episode_{episode}_{stamp}.pth"
    agent.save(agent_path)

    # Training progress needed to resume or analyse the run
    checkpoint_path = save_dir / f"training_state_episode_{episode}_{stamp}.pt"
    training_state = dict(episode=episode, history=history, config=config)
    torch.save(training_state, checkpoint_path)

    return agent_path, checkpoint_path

def load_checkpoint(checkpoint_path, env, agent_path=None):
    """Load agent checkpoint and return training state.

    ``checkpoint_path`` is the training-state file holding 'episode',
    'history' and 'config'. The agent weights live in a separate file, so they
    are loaded from ``agent_path`` rather than from the training-state file.

    Args:
        checkpoint_path: Path (str or Path) to a
            ``training_state_episode_<episode>_<timestamp>.pt`` file.
        env: Environment used to construct the agent.
        agent_path: Optional explicit path to the agent file. When omitted, the
            sibling ``agent_<agent_type>_episode_<episode>_<timestamp>.pth`` in
            the same directory is used. If ``checkpoint_path`` does not follow
            the ``training_state_`` naming, the file itself is passed to
            ``agent.load`` (the previous behaviour).

    Returns:
        Tuple ``(agent, episode, history, config)``.

    Raises:
        FileNotFoundError: If the sibling agent file derived from the
            checkpoint name does not exist.
    """
    # NOTE(review): torch.load unpickles the file - only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(checkpoint_path)
    config = checkpoint['config']

    if agent_path is None:
        state_file = Path(checkpoint_path)
        prefix = "training_state_"
        if state_file.stem.startswith(prefix):
            # "episode_<n>_<timestamp>" is shared by both file names.
            run_suffix = state_file.stem[len(prefix):]
            agent_path = state_file.with_name(
                f"agent_{config['agent_type']}_{run_suffix}.pth"
            )
            if not agent_path.exists():
                raise FileNotFoundError(
                    f"Agent file not found next to training state: {agent_path}"
                )
        else:
            # Unrecognized name: keep the old behaviour and let the agent try it.
            agent_path = checkpoint_path

    # Create agent with saved configuration
    agent = create_agent(env, config)
    agent.load(agent_path)

    return agent, checkpoint['episode'], checkpoint['history'], config


In [6]:


def main(config):
    """Run a full training session and return its artefacts.

    Creates a timestamped experiment directory with video, checkpoint and log
    sub-folders, writes ``config.json``, builds the environment and the agent,
    then trains for ``config["num_episodes"]`` episodes. A checkpoint and
    refreshed plots are written every ``config["save_every_n"]`` episodes, and
    once more at the end unless the last episode was just saved.

    Args:
        config: Configuration dict (see CONFIG); must be JSON-serializable.

    Returns:
        Tuple ``(agent, history, env, experiment_dir)``. ``history`` is the
        list of per-episode result dicts from ``agent.run_episode``. The
        environment is returned open; the caller is responsible for closing it.

    Side effects:
        Rebinds the module-level ``experiment_dir`` to this run's directory.
    """
    # create_env() takes its video folder from the module-level
    # `experiment_dir`. Publish this run's directory there; otherwise the
    # recordings end up under whatever directory that global pointed to before,
    # not next to this run's checkpoints and logs.
    global experiment_dir

    # Create experiment directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_name = f"{config['game']}_{config['agent_type']}_{timestamp}"
    experiment_dir = Path(config["base_dir"]) / experiment_name

    # Create subdirectories
    for subdir in [config["video_subdir"], config["checkpoint_subdir"], config["log_subdir"]]:
        (experiment_dir / subdir).mkdir(parents=True, exist_ok=True)
    checkpoint_dir = experiment_dir / config["checkpoint_subdir"]
    log_dir = experiment_dir / config["log_subdir"]

    # Save configuration
    with open(experiment_dir / "config.json", "w") as f:
        json.dump(config, f, indent=4)

    # Create environment and agent
    env = create_env(config)
    agent = create_agent(env, config)

    print(f"\nAgent Type: {config['agent_type'].upper()}")
    print("Model Architecture:")
    print(agent.model if hasattr(agent, 'model') else agent.q_network)
    print("\nObservation Space:", env.observation_space)

    # Training loop
    history = []
    last_saved_episode = None  # Episode count covered by the latest checkpoint

    for episode in range(config["num_episodes"]):
        # Run episode
        results = agent.run_episode(env)
        history.append(results)

        # Print progress
        print(f"Episode {episode + 1}/{config['num_episodes']}: {results}")

        # Save checkpoint
        if (episode + 1) % config["save_every_n"] == 0:
            agent_path, checkpoint_path = save_checkpoint(
                agent,
                episode + 1,
                history,
                checkpoint_dir,
                config
            )
            print(f"Saved checkpoint: {checkpoint_path}")

            # Plot training history
            plot_training_history(history, log_dir)
            last_saved_episode = episode + 1

    # Final save - skipped when the last loop iteration already wrote an
    # identical checkpoint and plots.
    if last_saved_episode != config["num_episodes"]:
        agent_path, checkpoint_path = save_checkpoint(
            agent,
            config["num_episodes"],
            history,
            checkpoint_dir,
            config
        )

        # Plot training history (nothing to plot if no episode was run)
        if history:
            plot_training_history(history, log_dir)

    return agent, history, env, experiment_dir

In [7]:
# Run the full training session with the notebook-level CONFIG and keep the
# trained agent, per-episode history, environment and run directory available
# for later cells. This is the long-running cell of the notebook.
agent, history, env, exp_dir = main(CONFIG)

A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]
  logger.warn(


CNN Dummy Output Shape: torch.Size([1, 4608])
CNN Dummy Output Shape: torch.Size([1, 4608])

Agent Type: DQN
Model Architecture:
CNNBackbone(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (hidden_layers): ModuleList(
    (0): Linear(in_features=4608, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=5, bias=True)
  )
)

Observation Space: Box(0.0, 255.0, (3, 60, 60), float32)
Episode 1/1000: {'total_return': -2.8400000000000194, 'steps': 726, 'total_steps': 726, 'eps': 1.0, 'buffer_size': 726, 'q_loss': 0.013398924646123002, 'avg_q_value': -0.042401958593788244, 'max_q_value': 0.039605668142127494, 'min_q_value': -0.21700