In [2]:
!pip install gymnasium[other]

Collecting moviepy>=1.0.0 (from gymnasium[other])
  Downloading moviepy-2.1.2-py3-none-any.whl.metadata (6.9 kB)
Collecting opencv-python>=3.0 (from gymnasium[other])
  Downloading opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting seaborn>=0.13 (from gymnasium[other])
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting imageio<3.0,>=2.5 (from moviepy>=1.0.0->gymnasium[other])
  Downloading imageio-2.37.0-py3-none-any.whl.metadata (5.2 kB)
Collecting imageio_ffmpeg>=0.2.0 (from moviepy>=1.0.0->gymnasium[other])
  Downloading imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting proglog<=1.0.0 (from moviepy>=1.0.0->gymnasium[other])
  Downloading proglog-0.1.10-py3-none-any.whl.metadata (639 bytes)
Collecting python-dotenv>=0.10 (from moviepy>=1.0.0->gymnasium[other])
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting pillow>=8 (from matplotlib>=3.0-

In [3]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
import logging
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

# Hyperparameters

# --- Replay buffer / optimisation ---
MEMORY_SIZE = 10_000    # maximum number of transitions kept in the replay deque
BATCH_SIZE = 64         # transitions sampled per optimisation step
LEARNING_RATE = 1e-3    # learning rate handed to the Adam optimizer
GAMMA = 0.99            # discount factor applied to the bootstrapped next-state value

# --- Epsilon-greedy exploration ---
EPS_START = 1.0         # initial exploration probability
EPS_END = 0.01          # lower bound for the exploration probability
EPS_DECAY = 0.995       # multiplicative decay applied after each optimisation step

# --- Schedules (both measured in episodes) ---
TARGET_UPDATE = 10            # sync the target network every this many episodes
training_period = 25          # Record every 25 episodes
num_training_episodes = 100   # Total number of training episodes

# Neural Network for DQN
class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers followed by a linear head.

    Args:
        input_dim: Number of features in each input row.
        output_dim: Number of outputs produced per input row (one per action
            when used by the agent in this file).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_width = 128
        # Layers are created in fc1 -> fc2 -> fc3 order and keep these attribute
        # names, so parameter initialisation order and state_dict keys are stable.
        self.fc1 = nn.Linear(input_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, output_dim)

    def forward(self, x):
        """Return the network output for ``x``; no activation on the last layer."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

# DQN Agent
class DQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    The agent owns two ``DQN`` networks: ``policy_net`` (trained) and
    ``target_net`` (a copy used to compute bootstrapped targets and refreshed
    via :meth:`update_target_net`).

    Args:
        env: Environment exposing ``observation_space.shape[0]`` and
            ``action_space.n`` / ``action_space.sample()``.
    """

    def __init__(self, env):
        self.env = env
        self.input_dim = env.observation_space.shape[0]
        self.output_dim = env.action_space.n

        self.policy_net = DQN(self.input_dim, self.output_dim)
        self.target_net = DQN(self.input_dim, self.output_dim)
        # Start with identical weights; the target net is never trained directly.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        # Bounded replay buffer: the oldest transitions are dropped once full.
        self.memory = deque(maxlen=MEMORY_SIZE)
        # Initialised here but not modified anywhere in this class.
        self.steps_done = 0
        self.epsilon = EPS_START

    def select_action(self, state):
        """Pick an action epsilon-greedily.

        Args:
            state: Tensor passed straight to ``policy_net``. The greedy branch
                takes the argmax over the whole output, so it expects a single
                state (e.g. shape ``(1, input_dim)`` as built by the training loop).

        Returns:
            A random action from ``env.action_space`` with probability
            ``self.epsilon``, otherwise the index of the largest Q-value.
        """
        if random.random() < self.epsilon:
            return self.env.action_space.sample()  # Explore

        with torch.no_grad():
            return self.policy_net(state).argmax().item()  # Exploit

    def store_transition(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple to the buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def optimize_model(self):
        """Run one gradient step on a random minibatch and decay epsilon.

        Does nothing until the buffer holds at least ``BATCH_SIZE`` transitions.
        Stored states / next states are concatenated with ``torch.cat``, so they
        must be tensors with a leading batch dimension of 1.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.cat(states)
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float)
        next_states = torch.cat(next_states)
        dones = torch.tensor(dones, dtype=torch.float)

        # Q(s, a) for the actions actually taken. squeeze(1) removes only the
        # gathered column, so the result stays 1-D (shape (batch,)) even when
        # BATCH_SIZE == 1; a bare squeeze() would collapse it to a 0-d tensor
        # and mismatch the target's shape.
        current_q_values = (
            self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        )

        # Targets are constants for the loss: skip building an autograd graph.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        # (1 - dones) removes the bootstrap term for transitions flagged done.
        target_q_values = rewards + (1 - dones) * GAMMA * next_q_values

        loss = nn.functional.mse_loss(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon (once per optimisation step, floored at EPS_END)
        self.epsilon = max(EPS_END, self.epsilon * EPS_DECAY)

    def update_target_net(self):
        """Copy the current policy-network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

# Main
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Seed the random sources used below so that runs are repeatable:
    # Python's `random` (epsilon draw, replay sampling) and torch (weight init).
    seed = 0
    random.seed(seed)
    torch.manual_seed(seed)

    # Create the environment with recording wrappers
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="training",
                      episode_trigger=lambda x: x % training_period == 0)  # Record periodically
    env = RecordEpisodeStatistics(env)
    env.action_space.seed(seed)  # random exploratory actions come from this space

    try:
        # Create the DQN agent
        agent = DQNAgent(env)

        # Train the agent
        for episode_num in range(num_training_episodes):
            # Seed the environment once; later resets continue from its RNG.
            state, _ = env.reset(seed=seed if episode_num == 0 else None)
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            episode_over = False

            while not episode_over:
                action = agent.select_action(state)  # Epsilon-greedy action
                next_state, reward, terminated, truncated, info = env.step(action)
                episode_over = terminated or truncated

                # Store transition and optimize the model.
                # Only a true termination is stored as `done`: a time-limit
                # truncation ends the episode but the next state is not
                # terminal, so its value must still be bootstrapped.
                next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
                agent.store_transition(state, action, reward, next_state, terminated)
                agent.optimize_model()

                state = next_state

            # Log episode statistics
            logging.info(f"Episode {episode_num + 1}: {info['episode']}")

            # Update the target network periodically
            if episode_num % TARGET_UPDATE == 0:
                agent.update_target_net()
    finally:
        # Always close the environment, even if training raises, so the
        # wrappers can release their resources.
        env.close()

  logger.warn(
INFO:root:Episode 1: {'r': 22.0, 'l': 22, 't': 0.057855}
INFO:root:Episode 2: {'r': 33.0, 'l': 33, 't': 0.002229}
INFO:root:Episode 3: {'r': 13.0, 'l': 13, 't': 0.019272}
INFO:root:Episode 4: {'r': 44.0, 'l': 44, 't': 0.090424}
INFO:root:Episode 5: {'r': 26.0, 'l': 26, 't': 0.05275}
INFO:root:Episode 6: {'r': 10.0, 'l': 10, 't': 0.02524}
INFO:root:Episode 7: {'r': 21.0, 'l': 21, 't': 0.067525}
INFO:root:Episode 8: {'r': 17.0, 'l': 17, 't': 0.037292}
INFO:root:Episode 9: {'r': 14.0, 'l': 14, 't': 0.034833}
INFO:root:Episode 10: {'r': 51.0, 'l': 51, 't': 0.09857}
INFO:root:Episode 11: {'r': 17.0, 'l': 17, 't': 0.039353}
INFO:root:Episode 12: {'r': 42.0, 'l': 42, 't': 0.078392}
INFO:root:Episode 13: {'r': 8.0, 'l': 8, 't': 0.02103}
INFO:root:Episode 14: {'r': 13.0, 'l': 13, 't': 0.032225}
INFO:root:Episode 15: {'r': 12.0, 'l': 12, 't': 0.029744}
INFO:root:Episode 16: {'r': 8.0, 'l': 8, 't': 0.022524}
INFO:root:Episode 17: {'r': 12.0, 'l': 12, 't': 0.028173}
INFO:root:Episod