In [13]:
import torch
from torch import nn
import gymnasium as gym
import math

In [102]:
EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 1000

BATCH_SIZE = 128
MAX_MEMORY_LENGTH = 1000
GAMMA = 0.99
TAU = 0.005

NUM_STEPS = 1_000_000
REPORT_EVERY = 100
LEARNING_RATE = 1e-5
NUM_STEPS_TO_AVERAGE = 100

In [103]:
class DQN(nn.Module):
    def __init__(self, obs_space, action_space):
        super().__init__()

        # Maps quality of each action at given state
        self.fc = nn.Sequential(nn.Linear(obs_space, 32),
                                nn.ReLU(),
                                nn.Linear(32, 32),
                                nn.ReLU(),
                                nn.Linear(32, action_space))

    def forward(self, state):
        return self.fc(state)

In [104]:
class ReplayBuffer:
    def __init__(self, max_size):
        self._max_size = max_size

        # Memory for storing transitions
        self._memory = {"state": [],
                       "action": [],
                       "new_state": [],
                       "reward": []}

    def get_batch(self, batch_size):
        if batch_size > len(self._memory["state"]):
            raise ValueError("Batch size must be smaller than the memory length")

        batch_indices = torch.randperm(len(self._memory["state"]))[:batch_size]

        states = torch.tensor(self._memory["state"])[batch_indices]
        actions = torch.tensor(self._memory["action"])[batch_indices]
        new_states = torch.tensor(self._memory["new_state"])[batch_indices]
        rewards = torch.tensor(self._memory["reward"])[batch_indices]

        return states, actions, new_states, rewards
    
    def update(self, state, action, new_state, reward):
        self._memory["state"].append(state)
        self._memory["action"].append(action)
        self._memory["new_state"].append(new_state)
        self._memory["reward"].append(reward)

        if len(self._memory["state"]) > self._max_size:
            self._memory["state"].pop(0)
            self._memory["action"].pop(0)
            self._memory["new_state"].pop(0)
            self._memory["reward"].pop(0)

    def __len__(self):
        return len(self._memory["state"])

In [105]:
def get_action(state, policy_network, action_space, eps_threshold):
    if torch.rand(1).item() > eps_threshold:
        actions = policy_network(state)
        return torch.argmax(actions).item()
    else:
        return torch.randint(action_space, (1,)).item()

In [113]:
def optimize_model(replay_buffer: ReplayBuffer, 
                   policy_network: torch.nn.Module, 
                   target_network: torch.nn.Module, 
                   optimizer: torch.optim.Adam, 
                   criterion: nn.SmoothL1Loss) -> None:
    optimizer.zero_grad()
    
    state_tensor, action_tensor, new_state_tensor, reward_tensor = replay_buffer.get_batch(BATCH_SIZE)

    # Get current q values for each action taken given a state
    q_values_for_actions_taken = policy_network(state_tensor)[0, action_tensor]

    expected_q_values = GAMMA*torch.max(target_network(new_state_tensor), dim=-1).values+reward_tensor
    expected_q_values *= reward_tensor

    loss = criterion(q_values_for_actions_taken, expected_q_values)

    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_network.parameters(), 100)
    optimizer.step()

In [114]:
env = gym.make("CartPole-v1", render_mode="rgb_array")

In [115]:
obs_space = env.observation_space.shape[0]
action_space = env.action_space.n.item()

print(f"Observation space is: {obs_space}")
print(f"Action space is: {action_space}")

Observation space is: 4
Action space is: 2


In [116]:
policy_network = DQN(obs_space, action_space)
target_network = DQN(obs_space, action_space)
target_network.load_state_dict(policy_network.state_dict())

<All keys matched successfully>

In [117]:
replay_buffer = ReplayBuffer(MAX_MEMORY_LENGTH)

In [118]:
optimizer = torch.optim.Adam(policy_network.parameters(), LEARNING_RATE)
criterion = nn.SmoothL1Loss()

In [119]:
episode_reward = []

state, info = env.reset()
episode_length = torch.tensor(0, dtype=torch.float64)
for step in range(NUM_STEPS):
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * step / EPS_DECAY)
    
    action = get_action(torch.tensor(state), policy_network, action_space, eps_threshold)
    new_state, reward, terminated, truncated, _ = env.step(action)

    episode_length += 1

    if terminated or truncated:
        env.reset()
        episode_reward.append(episode_length)
        episode_length = torch.tensor(0, dtype=torch.float64)
        reward = 0

    replay_buffer.update(state, action, new_state, reward)

    if len(replay_buffer) > BATCH_SIZE:
        optimize_model(replay_buffer, policy_network, target_network, optimizer, criterion)

    target_net_state_dict = target_network.state_dict()
    policy_net_state_dict = policy_network.state_dict()
    for key in policy_net_state_dict:
        target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
    target_network.load_state_dict(target_net_state_dict)

    if len(episode_reward) > NUM_STEPS_TO_AVERAGE:
        mean_episode_reward = torch.mean(torch.tensor(episode_reward))
        episode_reward = []
        print(f"Step: {step}, mean episode reward: {mean_episode_reward}")

Step: 1437, mean episode reward: 14.237623762376238
Step: 2482, mean episode reward: 10.346534653465346
Step: 3443, mean episode reward: 9.514851485148515
Step: 4398, mean episode reward: 9.455445544554456
Step: 5363, mean episode reward: 9.554455445544555
Step: 6329, mean episode reward: 9.564356435643564
Step: 7295, mean episode reward: 9.564356435643564
Step: 8243, mean episode reward: 9.386138613861386
Step: 9183, mean episode reward: 9.306930693069306
Step: 10143, mean episode reward: 9.504950495049505
Step: 11099, mean episode reward: 9.465346534653465
Step: 12044, mean episode reward: 9.356435643564357
Step: 12997, mean episode reward: 9.435643564356436
Step: 13937, mean episode reward: 9.306930693069306
Step: 14895, mean episode reward: 9.485148514851485
Step: 15839, mean episode reward: 9.346534653465346
Step: 16773, mean episode reward: 9.247524752475247
Step: 17722, mean episode reward: 9.396039603960396
Step: 18669, mean episode reward: 9.376237623762377
Step: 19620, mean e

KeyboardInterrupt: 