In [10]:
import gymnasium as gym
import torch
import matplotlib.pyplot as plt
import numpy as np
from datetime    import datetime
from pathlib import Path
from agents.dqn import DQNAgent

# hot reload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:


# Run settings and hyperparameters for the CartPole DQN experiment.
# Built with dict(...) so every setting reads as a keyword; key order matches
# the grouping below (run, agent, DQN, prioritization).
CONFIG = dict(
    # --- Run ---
    env_name="CartPole-v1",
    num_episodes=1000,
    save_every_n=50,

    # --- Agent settings ---
    device="cuda" if torch.cuda.is_available() else "cpu",  # prefer GPU when one is visible
    use_cnn=False,  # CartPole uses MLP

    # --- DQN specific ---
    learning_rate=0.001,
    gamma=0.99,
    buffer_size=50_000,
    batch_size=32,
    target_update_freq=1,  # (in steps), set to 1 for soft update
    tau=0.001,
    eps_start=1.0,
    eps_end=0.01,
    eps_decay=0.99,
    hidden_dims=[32, 32],
    gradient_clip=1.0,
    double_dqn=True,
    update_freq=4,

    # --- Prioritization ---
    per_alpha=0.6,           # How much prioritization to use (0 = uniform, 1 = full prioritization)
    per_beta_start=0.4,      # Initial importance sampling correction
    per_beta_end=1.0,        # Final importance sampling correction
    per_beta_steps=100_000,  # Steps over which to anneal beta
)

def plot_training_history(returns, q_losses, q_values, save_dir):
    """Plot training metrics in three stacked panels and save them as a PNG.

    The figure is written to ``<save_dir>/training_curves.png``; a file
    already at that path is overwritten.

    Args:
        returns: Sequence of episode returns, plotted against its index.
        q_losses: Sequence of Q-loss values, plotted against its index.
        q_values: Sequence of average Q-values, plotted against its index.
        save_dir: Output directory, given as a ``pathlib.Path`` or a plain
            string. It is created (including parents) if it does not exist.
    """
    # Accept str as well as Path, and make sure the target directory exists
    # so the save below cannot fail on a missing folder.
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15))
    try:
        # Plot returns
        ax1.plot(returns)
        ax1.set_title('Episode Returns')
        ax1.set_xlabel('Episode')
        ax1.set_ylabel('Return')

        # Plot Q-losses
        ax2.plot(q_losses)
        ax2.set_title('Q-Loss')
        ax2.set_xlabel('Episode')
        ax2.set_ylabel('Loss')

        # Plot average Q-values
        ax3.plot(q_values)
        ax3.set_title('Average Q-Value')
        ax3.set_xlabel('Episode')
        ax3.set_ylabel('Q-Value')

        fig.tight_layout()
        fig.savefig(out_dir / 'training_curves.png')
    finally:
        # Always release this specific figure, even if plotting or saving
        # raised; the function is called repeatedly during training and
        # unclosed figures would otherwise accumulate.
        plt.close(fig)


In [12]:

def main():
    """Train a DQN agent on the configured environment and record progress.

    Builds the environment and the agent from ``CONFIG``, creates a
    timestamped directory under ``./experiments``, and runs
    ``CONFIG["num_episodes"]`` episodes. A progress summary is printed every
    10 episodes. Every ``CONFIG["save_every_n"]`` episodes a checkpoint is
    saved and the training curves are redrawn; the curves are drawn once more
    after the last episode.

    Returns:
        tuple: ``(agent, env, returns)`` where ``returns`` is the list of
        ``"total_return"`` values collected from each episode.
    """
    env = gym.make(CONFIG["env_name"])

    # CONFIG keys forwarded to the agent, in constructor-call order. Each is
    # passed under its own name except "learning_rate", which goes in as `lr`.
    # CONFIG["tau"] is deliberately not forwarded.
    forwarded_keys = (
        "device", "use_cnn", "learning_rate", "gamma", "buffer_size",
        "batch_size", "target_update_freq", "eps_start", "eps_end",
        "eps_decay", "hidden_dims", "gradient_clip", "double_dqn",
        "update_freq", "per_alpha", "per_beta_start", "per_beta_end",
        "per_beta_steps",
    )
    kwarg_names = {"learning_rate": "lr"}
    agent_kwargs = {kwarg_names.get(key, key): CONFIG[key] for key in forwarded_keys}
    agent = DQNAgent(env=env, **agent_kwargs)

    # One output directory per run, named by start time.
    run_dir = Path("./experiments") / f"CartPole_{datetime.now():%Y%m%d_%H%M%S}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Per-episode metric histories.
    returns, q_losses, q_values = [], [], []

    print("Starting training...")
    print(f"Agent architecture:\n{agent.q_network}")

    # Episodes are numbered from 1 so the number can be used directly in
    # progress messages and checkpoint file names.
    for episode_num in range(1, CONFIG["num_episodes"] + 1):
        stats = agent.run_episode(env)

        returns.append(stats["total_return"])
        q_losses.append(stats["q_loss"])
        q_values.append(stats["avg_q_value"])

        # Progress report every 10th episode.
        if episode_num % 10 == 0:
            recent_mean = np.mean(returns[-10:])
            print(f"Episode {episode_num}/{CONFIG['num_episodes']}")
            print(f"Average Return (last 10): {recent_mean:.2f}")
            print(f"Epsilon: {agent.eps:.3f}")
            print(f"Steps Taken: {stats['steps']}")
            print(f"Latest Q-Loss: {stats['q_loss']:.6f}")
            print("--------------------")

        # Periodic checkpoint plus a refreshed plot of the curves so far.
        if episode_num % CONFIG["save_every_n"] == 0:
            agent.save(run_dir / f"checkpoint_episode_{episode_num}.pth")
            plot_training_history(returns, q_losses, q_values, run_dir)

    # Final plots
    plot_training_history(returns, q_losses, q_values, run_dir)

    return agent, env, returns

# Entry point: when this cell runs as the main module, start training and keep
# the trained agent, the environment, and the per-episode returns as top-level
# names so later cells can use them.
if __name__ == "__main__":
    agent, env, returns = main()

Starting training...
Agent architecture:
MLPBackbone(
  (model): Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=2, bias=True)
  )
)
Episode 10/1000
Average Return (last 10): 24.00
Epsilon: 0.904
Steps Taken: 19
Latest Q-Loss: 0.154025
--------------------
Episode 20/1000
Average Return (last 10): 26.40
Epsilon: 0.818
Steps Taken: 26
Latest Q-Loss: 0.011843
--------------------
Episode 30/1000
Average Return (last 10): 17.60
Epsilon: 0.740
Steps Taken: 23
Latest Q-Loss: 0.010068
--------------------
Episode 40/1000
Average Return (last 10): 22.20
Epsilon: 0.669
Steps Taken: 15
Latest Q-Loss: 0.010655
--------------------
Episode 50/1000
Average Return (last 10): 16.50
Epsilon: 0.605
Steps Taken: 13
Latest Q-Loss: 0.012483
--------------------
Episode 60/1000
Average Return (last 10): 14.90
Epsilon: 0.547
Steps Taken: 14
Latest Q

: 