In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import random

In [None]:
'''
Why a replay memory? Transitions collected one after another are highly correlated. If we trained on them in order,
each update would keep drawing on the same few, closely related experiences, which leads to a form of over-fitting.
Remember that, in Q-learning, we update a guess with a guess (bootstrapping). If the sampled states are too "near" each other,
an update to Q(s, a) is very likely to directly or indirectly alter Q(s', a') as well. Sampling uniformly at random from a large buffer breaks up these correlations.
'''
class ReplayMemory:
    """Fixed-size circular buffer of (state, action, reward, next_state, done) transitions.

    Once ``max_size`` transitions have been stored, each new ``push`` overwrites
    the oldest entry. ``sample`` draws a uniformly random minibatch without
    replacement.

    Args:
        max_size: maximum number of transitions kept.
        obs_size: length of a (flat) observation vector.
        batch_size: number of transitions returned by ``sample``.
    """

    def __init__(self, max_size: int, obs_size: int, batch_size: int = 32):
        # Row i holds transition i, so observation buffers are (max_size, obs_size):
        # indexing with a row index (or a list of them) yields whole observations.
        self.state_buf = np.zeros((max_size, obs_size), dtype=np.float32)
        # Different from PyTorch's DQN tutorial, we don't store all actions, only
        # the index of the one we took. Integer dtype so it can be used as an index
        # (e.g. with torch.gather) after conversion.
        self.action_buf = np.zeros(max_size, dtype=np.int64)
        self.reward_buf = np.zeros(max_size, dtype=np.float32)
        self.ns_buf = np.zeros((max_size, obs_size), dtype=np.float32)
        # Episode-end flags; serve as a mask for bootstrapped targets later.
        self.done_buf = np.zeros(max_size, dtype=bool)

        self.max_size, self.batch_size = max_size, batch_size

        # ptr: next slot to write; size: number of valid transitions stored.
        self.ptr, self.size = 0, 0

    def __len__(self):
        """Number of transitions currently stored."""
        return self.size

    def push(self, state, action, reward, ns, done=False):
        """Store one transition, overwriting the oldest one when the buffer is full.

        ``done`` defaults to False so 4-argument calls keep working.
        """
        idx = self.ptr
        self.state_buf[idx] = state
        self.action_buf[idx] = action
        self.reward_buf[idx] = reward
        self.ns_buf[idx] = ns
        self.done_buf[idx] = done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self):
        """Return a random minibatch of ``batch_size`` distinct stored transitions.

        Returns:
            dict of NumPy arrays keyed by ``state``, ``action``, ``reward``,
            ``ns`` and ``done``; each has ``batch_size`` rows.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored
                (raised by ``random.sample``).
        """
        idx = random.sample(range(self.size), self.batch_size)
        return dict(
            state=self.state_buf[idx],
            action=self.action_buf[idx],
            reward=self.reward_buf[idx],
            ns=self.ns_buf[idx],
            done=self.done_buf[idx],
        )

In [None]:
class Network(nn.Module):
    """Fully-connected Q-network mapping an observation to one value per action.

    Architecture: in_size -> 128 -> ReLU -> 128 -> ReLU -> out_size.
    """

    def __init__(
        self,
        in_size,
        out_size
    ):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise PyTorch raises AttributeError on `self.layers = ...`.
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_size)
        )

    def forward(self, x):
        """Map a (..., in_size) float tensor to (..., out_size) outputs."""
        return self.layers(x)

In [None]:
class DQNAgent:
    """DQN agent: epsilon-greedy acting, a replay memory, and online + target networks.

    Args:
        env: Gymnasium environment; its observation space must expose
            ``shape[0]`` and its action space ``n`` (a Discrete space).
        max_eps: initial (and maximum) exploration rate.
        min_eps: lower bound intended for the exploration rate.
        eps_decay: epsilon decay rate; stored but not yet applied by any method
            of this class. The default only makes the signature valid — tune it.
        gamma: discount factor used in the TD target.
        batch_size: minibatch size for replay sampling.
    """

    def __init__(
        self,
        env: gym.Env,
        max_eps: float = 0.9,
        min_eps: float = 0.1,
        eps_decay: float = 1 / 2000,
        gamma: float = .99,
        batch_size: int = 32,
    ):
        self.env = env
        self.eps = max_eps
        self.max_eps = max_eps
        self.min_eps = min_eps
        self.eps_decay = eps_decay
        self.gamma = gamma
        self.batch_size = batch_size
        obs_size = env.observation_space.shape[0]
        # Different from obs_size, the action space is Discrete, not Box, so `.n`
        # is the number of actions -> https://www.gymlibrary.dev/api/spaces/#discrete
        action_size = env.action_space.n

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.memory = ReplayMemory(10000, obs_size, batch_size)

        self.dqn = Network(obs_size, action_size).to(self.device)
        self.dqn_target = Network(obs_size, action_size).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())

        self.optimizer = torch.optim.AdamW(self.dqn.parameters())

        # Transition under construction: [state, action] after choose_action,
        # completed with [reward, next_state, terminated] in take_step.
        self.transition = []

        # When True, transitions are not recorded into memory.
        self.is_test = False

    def choose_action(self, state):
        """Pick an action epsilon-greedily and start recording the transition.

        With probability ``self.eps`` a random action is sampled from the action
        space; otherwise the online network's highest-valued action is used.

        Args:
            state: observation as returned by ``env.reset`` / ``env.step``.

        Returns:
            int: the selected action index.
        """
        if self.eps > np.random.random():
            selected_action = int(self.env.action_space.sample())
        else:
            state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            # Acting needs no gradients; .item() turns the 0-d tensor into an int
            # that env.step accepts.
            with torch.no_grad():
                selected_action = int(self.dqn(state_t).argmax().item())

        if not self.is_test:
            self.transition = [state, selected_action]
        return selected_action

    def take_step(self, action):
        """Step the environment and, in training mode, store the transition.

        Only ``terminated`` is stored as the memory's done flag: a time-limit
        truncation does not make the next state's value zero, so it must not
        mask the bootstrapped target.

        Returns:
            tuple: ``(next_state, reward, done)`` where ``done`` is True when the
            episode terminated or was truncated.
        """
        next_state, reward, terminated, truncated, _ = self.env.step(action)
        done = terminated or truncated

        if not self.is_test:
            # self.transition is [state, action] (set by choose_action); append
            # the rest in ReplayMemory.push's (state, action, reward, ns, done) order.
            self.transition += [reward, next_state, terminated]
            self.memory.push(*self.transition)

        return next_state, reward, done

    def compute_dqn(self, batch):
        """Calculate the DQN loss (smooth L1 / Huber) for a batch of memories.

        Args:
            batch: dict of arrays with keys ``state``, ``action``, ``reward``,
                ``ns`` and optionally ``done`` (treated as all False if absent).

        Returns:
            torch.Tensor: scalar loss, differentiable w.r.t. ``self.dqn``.
        """
        device = self.device
        state = torch.as_tensor(batch["state"], dtype=torch.float32, device=device)
        next_state = torch.as_tensor(batch["ns"], dtype=torch.float32, device=device)
        action = torch.as_tensor(
            batch["action"], dtype=torch.long, device=device
        ).reshape(-1, 1)
        reward = torch.as_tensor(
            batch["reward"], dtype=torch.float32, device=device
        ).reshape(-1, 1)
        done_np = batch.get("done", np.zeros(len(batch["reward"]), dtype=bool))
        done = torch.as_tensor(done_np, dtype=torch.float32, device=device).reshape(-1, 1)

        # Q(s, a) for the actions actually taken.
        q_value = self.dqn(state).gather(1, action)
        # Bootstrapped target from the frozen target network; no gradient flows
        # into it, and terminal transitions contribute only their reward.
        with torch.no_grad():
            next_q = self.dqn_target(next_state).max(dim=1, keepdim=True).values
            target_value = reward + self.gamma * next_q * (1.0 - done)

        return nn.functional.smooth_l1_loss(q_value, target_value)

        
