In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import random

In [2]:
'''
We use a replay memory because, otherwise, the network would be trained on consecutive transitions, which are highly correlated with each other;
training would keep hitting the same few, closely related memories, which leads to some form of over-fitting to recent experience.
Sampling random mini-batches from a large buffer breaks that correlation. Remember that, in Q-learning, we update a guess with a guess (bootstrapping):
if a batch contained states that are too "near" each other, an update to Q(s, a) would very likely also alter, directly or indirectly, the estimate Q(s', a') used as its own target.
'''
class ReplayMemory:
    """Fixed-size circular buffer of transitions for experience replay.

    Each transition occupies one row across a set of parallel numpy arrays.
    Once ``max_size`` transitions have been pushed, the oldest ones are
    overwritten.

    Args:
        max_size: maximum number of transitions kept.
        obs_size: length of a (flat) observation vector.
        batch_size: number of transitions returned by ``sample``.
    """

    def __init__(self, max_size: int, obs_size: int, batch_size: int = 32):
        # One row per transition: shape (max_size, obs_size), so ``buf[i]`` is
        # the i-th stored observation and ``buf[i] = state`` writes a full row.
        self.state_buf = np.zeros((max_size, obs_size), dtype=np.float32)
        self.ns_buf = np.zeros((max_size, obs_size), dtype=np.float32)
        # Unlike PyTorch's DQN tutorial, we store only the action actually
        # taken. It is an integer so it can be used directly as an index
        # (e.g. with torch.gather).
        self.action_buf = np.zeros(max_size, dtype=np.int64)
        self.reward_buf = np.zeros(max_size, dtype=np.float32)
        # 1.0 if the next state is non-terminal, 0.0 otherwise; used as a mask
        # on the bootstrapped term. Float dtype, so booleans pushed in are
        # converted to 1.0 / 0.0.
        self.not_done_buf = np.zeros(max_size, dtype=np.float32)

        self.max_size, self.batch_size = max_size, batch_size
        # ptr: next slot to write; size: number of valid rows stored so far
        self.ptr, self.size = 0, 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return self.size

    def push(self, state, action, reward, ns, not_done):
        """Store one transition, overwriting the oldest one when full.

        Args:
            state: observation before the action (length ``obs_size``).
            action: index of the action taken.
            reward: reward received.
            ns: observation after the action (length ``obs_size``).
            not_done: truthy if ``ns`` is not terminal.
        """
        idx = self.ptr
        self.state_buf[idx] = state
        self.action_buf[idx] = action
        self.reward_buf[idx] = reward
        self.ns_buf[idx] = ns
        self.not_done_buf[idx] = not_done

        # advance the write pointer circularly; size saturates at max_size
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self):
        """Return ``batch_size`` distinct random transitions as a dict of arrays.

        Keys: "state", "action", "reward", "ns", "not_done".

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored
                (raised by ``random.sample``).
        """
        idx = random.sample(range(self.size), self.batch_size)
        return dict(
            state=self.state_buf[idx],
            action=self.action_buf[idx],
            reward=self.reward_buf[idx],
            ns=self.ns_buf[idx],
            not_done=self.not_done_buf[idx],
        )

In [3]:
class Network(nn.Module):
    """Q-network: maps an observation to one Q-value per discrete action.

    Architecture: two hidden layers of 128 units with ReLU activations,
    followed by a linear output layer of size ``out_size``.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        width = 128
        self.layers = nn.Sequential(
            nn.Linear(in_size, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, out_size),
        )

    def forward(self, x):
        """Return the Q-values for ``x`` (its last dimension must be ``in_size``)."""
        q_values = self.layers(x)
        return q_values

In [5]:
class DQNAgent:
    def __init__(self,
        env: gym.Env,
        memory_size: int,
        batch_size: int,
        eps_decay: float,
        max_eps: float = 0.9,
        min_eps: float = 0.1,
        gamma: float = .99,
        lr: float = 1e-4,
        target_update: int = 100,
    ):
        """Build the online/target networks, the replay memory and the optimizer.

        Args:
            env: environment; ``observation_space.shape[0]`` gives the network
                input size and ``action_space.n`` (Discrete space) the output size.
            memory_size: capacity of the replay memory.
            batch_size: number of transitions sampled per gradient step.
            eps_decay: fraction of ``max_eps - min_eps`` removed from epsilon
                at each decay step.
            max_eps: initial exploration rate.
            min_eps: lower bound for the exploration rate.
            gamma: discount factor.
            lr: AdamW learning rate.
            target_update: number of gradient steps between hard copies of the
                online network into the target network.
        """
        self.env = env
        self.memory_size = memory_size
        self.eps_decay = eps_decay
        self.eps = max_eps
        self.max_eps = max_eps
        self.min_eps = min_eps
        self.gamma = gamma
        self.batch_size = batch_size
        self.lr = lr
        self.target_update = target_update
        obs_size = env.observation_space.shape[0]
        # unlike the observation space (Box), the action space is Discrete,
        # so its size is ``.n`` -> https://www.gymlibrary.dev/api/spaces/#discrete
        action_size = env.action_space.n

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.memory = ReplayMemory(memory_size, obs_size, batch_size)

        # the target network starts as an exact copy of the online network
        self.dqn = Network(obs_size, action_size).to(self.device)
        self.dqn_target = Network(obs_size, action_size).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())

        self.optimizer = torch.optim.AdamW(self.dqn.parameters(), lr=self.lr, amsgrad=True)

        # [state, action, next_state, reward, done] of the current step
        self.transition = []

        # NOTE: this flag must not be called ``self.test``: an instance
        # attribute with that name would shadow the ``test`` method.
        self.is_test = False

    def choose_action(self, state):
        """Pick an action epsilon-greedily with respect to the online network.

        Args:
            state: current observation (anything ``torch.as_tensor`` accepts).

        Returns:
            int: index of the chosen action.
        """
        if self.eps > np.random.random():
            # explore: uniform random action from the action space
            selected_action = self.env.action_space.sample()
        else:
            # exploit: greedy action under the current Q estimates. Gradients
            # are not needed here, and the input must be on the network's device.
            with torch.no_grad():
                state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
                selected_action = self.dqn(state_t).argmax().item()
        # env.step expects a plain integer, not a tensor / numpy scalar
        selected_action = int(selected_action)
        if not self.is_test:
            # remember (state, action); take_step appends the rest of the transition
            self.transition = [state, selected_action]
        return selected_action
    
    def take_step(self, action):
        """Apply ``action`` to the environment and, when training, store the transition.

        Args:
            action: action index, as returned by ``choose_action``.

        Returns:
            tuple: ``(next_state, reward, done)``, where ``done`` is True when
            the episode ended for any reason (termination or truncation).
        """
        next_state, reward, terminated, truncated, _ = self.env.step(action)

        done = terminated or truncated

        if not self.is_test:
            # in training, self.transition already holds [state, action]
            # (set by choose_action)
            state, chosen_action = self.transition[:2]
            self.transition = [state, chosen_action, next_state, reward, done]
            # Pass arguments in the order push expects:
            # (state, action, reward, ns, not_done).
            # Only a real termination stops bootstrapping; a time-limit
            # truncation still has a meaningful future value.
            self.memory.push(state, chosen_action, reward, next_state, not terminated)

        return next_state, reward, done
    
    def compute_dqn(self, batch):
        """Compute the DQN (smooth-L1 / Huber) loss for a batch of memories.

        Args:
            batch: mapping with array values under "state", "action",
                "reward", "ns" and "not_done" (one row per transition).

        Returns:
            torch.Tensor: scalar loss, differentiable w.r.t. ``self.dqn``.
        """
        device = self.device
        state = torch.as_tensor(batch["state"], dtype=torch.float32, device=device)
        next_state = torch.as_tensor(batch["ns"], dtype=torch.float32, device=device)
        # gather needs int64 indices shaped (B, 1)
        action = torch.as_tensor(batch["action"], device=device).long().reshape(-1, 1)
        reward = torch.as_tensor(batch["reward"], dtype=torch.float32, device=device).reshape(-1, 1)
        # 1.0 for non-terminal next states, 0.0 for terminal ones
        not_done = torch.as_tensor(batch["not_done"], dtype=torch.float32, device=device).reshape(-1, 1)

        # Q(s, a) of the online network for the actions actually taken
        curr_q_value = self.dqn(state).gather(1, action)
        # max_a' Q_target(s', a'). The target network is the one that lags
        # slightly behind the online DQN. ``.max`` returns (values, indices);
        # we need the values, with no gradient flowing into the target.
        with torch.no_grad():
            next_q_value = self.dqn_target(next_state).max(dim=1, keepdim=True)[0]
        # y = r + gamma * max Q_target(s', a') for non-terminal s', y = r otherwise
        target = reward + self.gamma * next_q_value * not_done
        loss = nn.functional.smooth_l1_loss(curr_q_value, target)
        return loss
        
    def update_model(self):
        """Take one optimization step on a freshly sampled batch of memories.

        Returns:
            float: the batch loss as a plain Python number (used for logging
            and plotting).
        """
        loss = self.compute_dqn(self.memory.sample())

        # PyTorch accumulates gradients across backward() calls, so they are
        # cleared before computing this step's gradients.
        # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch
        optimizer = self.optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def train(self, num_frames, plotting_interval=200):
        """Run epsilon-greedy DQN training for ``num_frames`` environment steps.

        Args:
            num_frames: number of environment steps to run.
            plotting_interval: call ``self._plot`` every this many frames.
        """
        self.is_test = False

        # gymnasium's reset() returns (observation, info)
        state, _ = self.env.reset()
        # gradient steps taken so far; drives the target-network updates
        update_cnt = 0
        score = 0
        scores = []
        losses = []
        epsilons = []

        for frame_idx in range(1, num_frames + 1):
            action = self.choose_action(state)
            next_state, reward, done = self.take_step(action)

            state = next_state
            score += reward

            if done:
                # episode finished: record its return and start a new one
                state, _ = self.env.reset()
                scores.append(score)
                score = 0

            # first let the memory collect enough transitions for one batch,
            # and THEN start training
            if self.memory.size >= self.batch_size:
                loss = self.update_model()
                losses.append(loss)
                update_cnt += 1
                # linear decay of epsilon, floored at min_eps
                self.eps = max(self.min_eps, self.eps - (self.max_eps - self.min_eps) * self.eps_decay)
                epsilons.append(self.eps)

                if update_cnt % self.target_update == 0:
                    self.target_hard_update()

            if frame_idx % plotting_interval == 0:
                self._plot(frame_idx, scores, losses, epsilons)
        self.env.close()
    def test(self):
        """Switch the agent to evaluation mode.

        choose_action and take_step check this flag and stop recording
        transitions while it is set.
        """
        self.is_test = True

    def target_hard_update(self):
        """Overwrite the target network's weights with the online network's."""
        online_weights = self.dqn.state_dict()
        self.dqn_target.load_state_dict(online_weights)

    def _plot(self, frame_idx, scores, losses, epsilons)

        # https://stackoverflow.com/questions/37970424/what-is-the-difference-between-drawing-plots-using-plot-axes-or-figure-in-matpl
        # https://matplotlib.org/stable/_images/anatomy.png
        fig, (ax1, ax2, ax3) = plt.subplots(3)
        ax1.title
