In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import random

In [None]:
class ReplayMemory:
    """Fixed-capacity ring buffer of environment transitions.

    Each transition is stored as (state, action, reward, next state, done)
    in preallocated NumPy arrays. Once ``max_size`` transitions have been
    pushed, each new one overwrites the oldest.

    Attributes:
        state_buf: float32 array of shape (max_size, obs_size).
        action_buf: float32 array of shape (max_size,); one action per step.
        reward_buf: float32 array of shape (max_size,).
        ns_buf: float32 array of shape (max_size, obs_size); next states.
        done_buf: boolean array of shape (max_size,); episode-end flags.
        ptr: index of the slot the next ``push`` writes to.
        size: number of valid transitions currently stored.
    """

    # obs_size = state/observable space size
    def __init__(self, max_size: int, obs_size: int, batch_size: int = 32):
        """Allocate the buffers.

        Args:
            max_size: capacity of the buffer, in transitions (must be > 0).
            obs_size: length of one observation vector.
            batch_size: number of transitions ``sample`` tries to return
                (must be > 0).

        Raises:
            ValueError: if ``max_size`` or ``batch_size`` is not positive.
        """
        if max_size <= 0:
            raise ValueError("max_size must be a positive integer")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        # NumPy shapes are (rows, cols): one row per transition, so that
        # buf[i] is the full observation vector of transition i.
        self.state_buf = np.zeros((max_size, obs_size), dtype=np.float32)
        # different from pytorch's dqn tutorial, we don't store all actions but rather only the one we take
        self.action_buf = np.zeros(max_size, dtype=np.float32)
        self.reward_buf = np.zeros(max_size, dtype=np.float32)
        self.ns_buf = np.zeros((max_size, obs_size), dtype=np.float32)
        # this will serve as a mask later
        # (np.bool_ rather than np.bool: the latter is missing in NumPy 1.24-1.26)
        self.done_buf = np.zeros(max_size, dtype=np.bool_)

        self.max_size, self.batch_size = max_size, batch_size

        self.ptr, self.size = 0, 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return self.size

    def push(self, state, action, reward, ns, done=False):
        """Store one transition, overwriting the oldest one when full.

        Args:
            state: observation before the action (length ``obs_size``).
            action: the action that was taken.
            reward: reward received for the action.
            ns: observation after the action (length ``obs_size``).
            done: whether the transition ended the episode.
        """
        slot = self.ptr
        self.state_buf[slot] = state
        self.action_buf[slot] = action
        self.reward_buf[slot] = reward
        self.ns_buf[slot] = ns
        self.done_buf[slot] = done

        # Parentheses matter: `ptr + 1 % max_size` would never wrap around.
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self):
        """Draw a random batch of distinct stored transitions.

        Returns:
            dict with keys ``state``, ``action``, ``reward``, ``ns`` and
            ``done``. Each value is an array whose first axis has length
            ``min(batch_size, size)``. The arrays are copies, so later
            pushes do not modify a batch that was already sampled.

        Raises:
            ValueError: if nothing has been pushed yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayMemory")

        # Only the first `size` slots hold real data; sample without replacement.
        count = min(self.batch_size, self.size)
        idx = random.sample(range(self.size), count)
        return dict(
            state=self.state_buf[idx],
            action=self.action_buf[idx],
            reward=self.reward_buf[idx],
            ns=self.ns_buf[idx],
            done=self.done_buf[idx],
        )

In [None]:
class Network(nn.Module):
    """Fully connected network with two ReLU hidden layers and a linear output."""

    def __init__(
        self,
        in_size: int,
        out_size: int,
        hidden_size: int = 128
    ):
        """Build the layer stack.

        Args:
            in_size: number of input features.
            out_size: number of output values.
            hidden_size: width of both hidden layers.
        """
        # nn.Module must be initialised before a submodule is assigned;
        # without this call, setting self.layers raises AttributeError.
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the layer stack and return the result."""
        return self.layers(x)

In [None]:
class DQNAgent:
    """Agent that holds an environment and a ``Network`` sized to match it."""

    def __init__(self, env: gym.Env):
        """Keep a reference to ``env`` and build the network from its spaces.

        Args:
            env: environment whose observation space exposes ``shape`` and
                whose action space exposes ``n``.
        """
        self.env = env

        # Input width is the first entry of the observation space's shape.
        # Output width is read from `.n` because the action space is Discrete,
        # not Box -> https://www.gymlibrary.dev/api/spaces/#discrete
        self.nn = Network(
            in_size=env.observation_space.shape[0],
            out_size=env.action_space.n,
        )
