In [7]:
import math
import gymnasium as gym
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from collections import deque
import random

In [8]:
# Shared environment and compute device used by every later cell.
env = gym.make("CartPole-v1")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
class ReplayBuffer:
    """Bounded FIFO store of transitions for experience replay.

    Backed by a ``deque`` with ``maxlen=max_capacity``: once full, every new
    push silently evicts the oldest stored transition.
    """

    def __init__(self, max_capacity: int = 10000):
        self.max_capacity = max_capacity
        # maxlen makes the deque drop items from the left when it overflows.
        self.memory = deque(maxlen=max_capacity)

    def push(self, transition):
        """Append one transition (evicting the oldest if at capacity)."""
        self.memory.append(transition)

    def sample(self, batch_size: int = 8):
        """Return a list of ``batch_size`` distinct stored transitions, drawn
        without replacement via ``random.sample``; raises ``ValueError`` if
        fewer than ``batch_size`` transitions are stored."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)

In [10]:
class DQN(nn.Module):
    """Multi-layer perceptron producing one Q-value per action.

    Layout: Linear(n_obs, width) -> ReLU -> Linear(width, width) -> ReLU
    -> Linear(width, n_actions), held in ``self.net`` (so state_dict keys
    are ``net.0.*``, ``net.2.*`` and ``net.4.*``).
    """

    def __init__(self, n_obs: int, n_actions: int, width: int = 128):
        super().__init__()
        layers = [
            nn.Linear(n_obs, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs of shape (..., n_obs) to Q-values of shape (..., n_actions)."""
        return self.net(x)



In [None]:
# Hyperparameters
batch_size = 128         # minibatch size; train() returns early until the buffer holds this many
gamma = 0.99             # presumably the TD discount factor — not used yet in the cells above
eps_start = 0.9          # epsilon-greedy exploration rate at step 0 (see select_action)
eps_end = 0.01           # asymptotic exploration rate
eps_decay = 2500         # time constant, in steps, of the exponential epsilon decay
tau = 0.005              # presumably the target-network soft-update rate — not used yet; TODO confirm
lr = 3e-4                # AdamW learning rate
buffer_capacity = 10000  # max transitions kept by the replay buffer
seed = 0                 # seeds every RNG below so Restart & Run All reproduces a run

# Seed Python's RNG (epsilon draws, replay sampling) and torch (network
# initialisation) before anything random happens.
random.seed(seed)
torch.manual_seed(seed)

n_actions = env.action_space.n
# reset(seed=...) fixes the environment's initial-state RNG; the action space
# keeps its own RNG, used by env.action_space.sample() during exploration.
state, info = env.reset(seed=seed)
env.action_space.seed(seed)
n_obs = len(state)

policy = DQN(n_obs, n_actions).to(device)
target = DQN(n_obs, n_actions).to(device)
# Start the target network as an exact copy of the policy network.
target.load_state_dict(policy.state_dict())

optimizer = torch.optim.AdamW(policy.parameters(), lr=lr)
buffer = ReplayBuffer(buffer_capacity)

In [None]:
# Global count of environment steps, read by the epsilon schedule below.
# Python ints are immutable, so select_action cannot advance this counter;
# the training loop must do `steps_done += 1` after each call.
steps_done = 0

def select_action(state, steps_done: int):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    The exploration rate decays exponentially with the step count::

        eps = eps_end + (eps_start - eps_end) * exp(-steps_done / eps_decay)

    With probability ``1 - eps``, the greedy action ``argmax policy(state)`` is
    returned; otherwise an action from ``env.action_space.sample()``.

    Args:
        state: Observation tensor passed straight to ``policy``. Assumed to be a
            single observation shaped (1, n_obs) on ``device`` — TODO confirm
            against the training loop.
        steps_done: Number of steps taken so far. It is only read here; the
            caller is responsible for incrementing its own counter, because
            rebinding this parameter would not affect the caller.

    Returns:
        A ``torch.long`` tensor of shape (1, 1) holding the chosen action index,
        in both the greedy and the random branch (for a (1, n_obs) state).
    """
    s = random.random()
    eps_threshold = eps_end + (eps_start - eps_end) * math.exp(-steps_done / eps_decay)
    if s > eps_threshold:
        with torch.no_grad():
            # keepdim=True gives shape (1, 1), matching the random branch.
            return policy(state).argmax(dim=-1, keepdim=True)
    # Explore: random action, wrapped to shape (1, 1) like the greedy branch.
    return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


In [None]:
def train():
    if len(buffer) < batch_size:
        return
    transitions = buffer.sample(batch_size)

    3