In [13]:
import math
from collections import deque

import gymnasium as gym
import numpy as np
import torch
from torch import nn

In [102]:
# Epsilon-greedy exploration schedule: the training loop computes
# eps = EPS_END + (EPS_START - EPS_END) * exp(-step / EPS_DECAY).
EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 3000

# Number of transitions requested from the replay buffer per optimisation step.
BATCH_SIZE = 128
# Maximum number of transitions the replay buffer is constructed with.
MAX_MEMORY_LENGTH = 10_000
# Factor applied to the target network's max Q-value when building the target.
GAMMA = 0.99
# Mixing weight of the policy network in the per-step soft target update.
TAU = 0.005

# Total number of environment steps the training loop iterates over.
NUM_STEPS = 1_000_000
# Learning rate passed to the optimizer.
LEARNING_RATE = 1e-4
# Threshold on the number of finished episodes collected before their mean
# length is printed (despite the name it is compared against an episode count).
NUM_STEPS_TO_AVERAGE = 10

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: one output value per action for a given state.

    Args:
        obs_space: Number of input features (size of one observation vector).
        action_space: Number of outputs (one Q-value per action).
    """

    def __init__(self, obs_space, action_space):
        super().__init__()

        hidden_width = 128

        # Layers are created in input-to-output order and wrapped in a single
        # Sequential stored as `fc`, so parameters are named fc.0, fc.2, fc.4.
        layers = [
            nn.Linear(obs_space, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, action_space),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network output for `state` (last dim must be `obs_space`)."""
        q_values = self.fc(state)
        return q_values

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, new_state, reward) transitions.

    Each field is kept in its own bounded deque, so once `max_size` transitions
    are stored, appending a new one drops the oldest in O(1).

    Args:
        max_size: Maximum number of transitions kept.
    """

    _FIELDS = ("state", "action", "new_state", "reward")

    def __init__(self, max_size):
        self._max_size = max_size

        # Memory for storing transitions; deque(maxlen=...) evicts the oldest
        # entry automatically instead of the O(n) list.pop(0).
        self._memory = {field: deque(maxlen=max_size) for field in self._FIELDS}

    def get_batch(self, batch_size):
        """Sample `batch_size` distinct transitions uniformly at random.

        Only the sampled transitions are converted to tensors; the rest of the
        memory is left untouched. Tensor dtypes are inferred by NumPy from the
        sampled entries.

        Args:
            batch_size: Number of transitions to draw; may equal the current
                memory length but not exceed it.

        Returns:
            Tuple ``(states, actions, new_states, rewards)`` of tensors whose
            first dimension is `batch_size`, with rows aligned across tensors.

        Raises:
            ValueError: If `batch_size` exceeds the number of stored transitions.
        """
        stored = len(self._memory["state"])
        if batch_size > stored:
            raise ValueError("Batch size must be smaller than the memory length")

        # Same sampling scheme as before (first `batch_size` entries of a random
        # permutation), so torch's RNG is consumed identically.
        picked = torch.randperm(stored)[:batch_size].tolist()

        states = self._gather("state", picked)
        actions = self._gather("action", picked)
        new_states = self._gather("new_state", picked)
        rewards = self._gather("reward", picked)

        return states, actions, new_states, rewards

    def _gather(self, field, indices):
        """Stack the entries of `field` at `indices` into one tensor."""
        column = self._memory[field]
        return torch.tensor(np.array([column[i] for i in indices]))

    def update(self, state, action, new_state, reward):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self._memory["state"].append(state)
        self._memory["action"].append(action)
        self._memory["new_state"].append(new_state)
        self._memory["reward"].append(reward)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self._memory["state"])

In [None]:
def get_action(state, policy_network, action_space, eps_threshold):
    """Pick an action epsilon-greedily.

    Args:
        state: Tensor fed to `policy_network`.
        policy_network: Module returning one value per action for `state`.
        action_space: Number of actions; random actions are drawn from
            ``[0, action_space)``.
        eps_threshold: Probability threshold for exploring; a uniform draw
            greater than this selects the greedy action.

    Returns:
        The chosen action index as a Python int.
    """
    if torch.rand(1).item() > eps_threshold:
        # Action selection never needs gradients; skipping autograd avoids
        # building a graph that would be thrown away immediately.
        with torch.no_grad():
            actions = policy_network(state)
        return torch.argmax(actions).item()
    else:
        return torch.randint(action_space, (1,)).item()

In [None]:
def optimize_model(replay_buffer: ReplayBuffer,
                   policy_network: torch.nn.Module,
                   target_network: torch.nn.Module,
                   optimizer: torch.optim.Optimizer,
                   criterion: nn.Module) -> None:
    """Run one gradient step of the policy network on a sampled batch.

    Samples ``BATCH_SIZE`` transitions, regresses the policy network's value of
    each taken action towards ``reward + GAMMA * max_a target_network(new_state)``,
    clips gradient values to ``[-100, 100]`` and steps the optimizer.

    Args:
        replay_buffer: Source of transitions; ``get_batch`` must return
            ``(states, actions, new_states, rewards)`` tensors.
        policy_network: Network being trained.
        target_network: Network providing the bootstrap values; it is only
            evaluated here, never updated.
        optimizer: Optimizer over ``policy_network``'s parameters.
        criterion: Loss taking ``(prediction, target)``.

    Raises:
        ValueError: Propagated from ``get_batch`` if the buffer holds fewer
            than ``BATCH_SIZE`` transitions.
    """
    state_tensor, action_tensor, new_state_tensor, reward_tensor = replay_buffer.get_batch(BATCH_SIZE)

    # Get current q values for each action taken given a state.
    # squeeze(1) drops only the gathered column, so a batch of one stays 1-D.
    q_values_for_actions_taken = (
        policy_network(state_tensor)
        .gather(1, action_tensor.reshape(-1, 1))
        .squeeze(1)
    )

    # The target is a constant for this regression step: evaluate the target
    # network without autograd so no graph is built and no gradients pile up
    # in its parameters.
    with torch.no_grad():
        next_state_values = target_network(new_state_tensor).max(dim=-1).values
        # Rewards may arrive as float64 (NumPy default); cast so the target
        # keeps the same dtype as the network outputs.
        expected_q_values = reward_tensor.to(next_state_values.dtype) + GAMMA * next_state_values

    optimizer.zero_grad()
    loss = criterion(q_values_for_actions_taken, expected_q_values)

    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_network.parameters(), 100)
    optimizer.step()

In [None]:
# Create the CartPole-v1 environment. The "rgb_array" render mode is requested
# here; none of the cells shown in this notebook section call env.render().
env = gym.make("CartPole-v1", render_mode="rgb_array")

In [None]:
# Length of one observation vector (first entry of the observation space shape).
obs_space = env.observation_space.shape[0]
# Number of discrete actions. int() accepts both a NumPy integer and a plain
# Python int, whereas .item() exists only on the former.
action_space = int(env.action_space.n)

print(f"Observation space is: {obs_space}")
print(f"Action space is: {action_space}")

In [None]:
# Two networks with identical architecture: the policy network is the one being
# trained, and the target network starts as an exact copy of its weights.
policy_network = DQN(obs_space, action_space)
target_network = DQN(obs_space, action_space)
target_network.load_state_dict(policy_network.state_dict())

In [None]:
# Transition store capped at MAX_MEMORY_LENGTH entries.
replay_buffer = ReplayBuffer(MAX_MEMORY_LENGTH)

In [None]:
# AdamW (with the AMSGrad variant enabled) over the policy network's parameters
# only; the target network is never passed to an optimizer.
optimizer = torch.optim.AdamW(policy_network.parameters(), LEARNING_RATE, amsgrad=True)
# Smooth L1 loss between predicted and target Q-values.
criterion = nn.SmoothL1Loss()

In [None]:
# Lengths (in steps) of episodes finished since the last progress print.
episode_reward = []

state, info = env.reset()
episode_length = 0
for step in range(NUM_STEPS):
    # Exponentially decay epsilon from EPS_START towards EPS_END.
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * step / EPS_DECAY)

    action = get_action(torch.tensor(state), policy_network, action_space, eps_threshold)
    new_state, reward, terminated, truncated, _ = env.step(action)

    episode_length += 1

    episode_over = terminated or truncated
    if episode_over:
        episode_reward.append(episode_length)
        episode_length = 0
        # NOTE(review): the final transition of an episode is stored with
        # reward 0 and no "done" flag, so the target still bootstraps from the
        # terminal observation. Kept as-is; masking the bootstrap would need a
        # done flag in the replay buffer.
        reward = 0

    # Store the transition before any reset so it pairs the state the action
    # was taken in with the observation that action produced.
    replay_buffer.update(state, action, new_state, reward)

    if episode_over:
        # Start the next episode from the freshly reset observation; carrying
        # the terminal observation forward would select the next action from a
        # stale state and record a transition that spans two episodes.
        state, info = env.reset()
    else:
        state = new_state

    if len(replay_buffer) > BATCH_SIZE:
        optimize_model(replay_buffer, policy_network, target_network, optimizer, criterion)

    # Soft update: move every target parameter a fraction TAU towards the
    # corresponding policy parameter.
    target_net_state_dict = target_network.state_dict()
    policy_net_state_dict = policy_network.state_dict()
    for key in policy_net_state_dict:
        target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
    target_network.load_state_dict(target_net_state_dict)

    # Report once exactly NUM_STEPS_TO_AVERAGE episodes have finished.
    if len(episode_reward) >= NUM_STEPS_TO_AVERAGE:
        mean_episode_reward = sum(episode_reward) / len(episode_reward)
        episode_reward = []
        print(f"Step: {step}, mean episode reward: {mean_episode_reward}")