In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Network(nn.Module):
    """Fully connected network: two hidden layers of 128 ReLU units each.

    Maps an input whose last dimension is ``in_size`` to an output whose last
    dimension is ``out_size``.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        hidden = 128
        # Built in input-to-output order so the Sequential indices (0..4),
        # and therefore the state_dict keys, stay stable.
        stack = [
            nn.Linear(in_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_size),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        out = self.layers(x)
        return out

In [None]:
class DQNAgent:
    """DQN agent with an epsilon-greedy policy, replay memory and target network.

    ``self.dqn`` is the online network that gets optimised. ``self.dqn_target``
    is a copy that lags behind it and is only used to evaluate the next state
    when building the TD target; it is re-synchronised every ``target_update``
    optimisation steps.
    """

    def __init__(
        self,
        env: "gym.Env",
        seed: int,
        mu: int,
        lambd: int,
        *,
        memory_size: int = 1000,
        batch_size: int = 32,
        target_update: int = 100,
        eps_decay: float = 1 / 2000,
        max_eps: float = 1.0,
        min_eps: float = 0.1,
        gamma: float = 0.99,
    ):
        """Create the networks, optimizer and replay memory for ``env``.

        Args:
            env: Environment to act in. ``observation_space.shape[0]`` and
                ``action_space.n`` are read, so a 1-D observation space and a
                Discrete action space are required.
            seed: Stored on the agent; used in the checkpoint file name.
            mu: Stored on the agent; not used by the DQN logic in this class.
            lambd: Stored on the agent; not used by the DQN logic in this
                class. Spelled ``lambd`` because ``lambda`` is a reserved word
                and cannot be a parameter name.
            memory_size: First argument handed to ``ReplayBuffer``
                (presumably its capacity - confirm against ReplayBuffer).
            batch_size: Third argument handed to ``ReplayBuffer``; ``train``
                also waits this many frames before the first optimisation.
            target_update: Number of optimisation steps between target
                network synchronisations.
            eps_decay: Fraction of ``max_eps - min_eps`` subtracted from
                epsilon after every optimisation step.
            max_eps: Initial exploration probability.
            min_eps: Lower bound for the exploration probability.
            gamma: Discount factor applied to the next-state value.
        """
        self.env = env
        self.seed = seed
        self.mu = mu
        self.lambd = lambd
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update
        self.eps_decay = eps_decay
        self.eps = max_eps
        self.max_eps = max_eps
        self.min_eps = min_eps
        self.gamma = gamma
        obs_size = env.observation_space.shape[0]
        # different from obs_size, we use Discrete, not Box -> https://www.gymlibrary.dev/api/spaces/#discrete
        action_size = env.action_space.n

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.memory = ReplayBuffer(memory_size, obs_size, batch_size)

        self.dqn = Network(obs_size, action_size).to(self.device)
        self.dqn_target = Network(obs_size, action_size).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        # The target network is never optimised: keep it in eval mode. Its
        # forward pass in compute_dqn runs under torch.no_grad().
        self.dqn_target.eval()

        # No learning rate is passed, so AdamW runs with its library default.
        self.optimizer = optim.AdamW(self.dqn.parameters(), amsgrad=True)

        # Scratch list holding the transition currently being assembled:
        # choose_action writes [state, action], take_step appends the rest.
        self.transition = []

        self.is_test = False

    def choose_action(self, state: np.ndarray) -> int:
        """Pick an action for ``state`` with an epsilon-greedy policy.

        While training, a random action is sampled with probability
        ``self.eps``; otherwise (and always in test mode) the action with the
        highest predicted Q-value is returned. While training, the chosen
        ``[state, action]`` pair is also stored in ``self.transition``.
        """
        explore = (not self.is_test) and self.eps > np.random.random()
        if explore:
            selected_action = int(self.env.action_space.sample())
        else:
            # The state is wrapped in a tensor on the model's device because
            # that is what nn.Module expects; no gradient is needed here.
            with torch.no_grad():
                state_t = torch.as_tensor(
                    state, dtype=torch.float32, device=self.device
                )
                selected_action = int(self.dqn(state_t).argmax().item())
        if not self.is_test:
            self.transition = [state, selected_action]
        return selected_action

    def take_step(self, action: int) -> "tuple[np.ndarray, float, bool]":
        """Apply ``action`` to the environment.

        Returns ``(next_state, reward, done)`` where ``done`` is true when the
        episode was terminated or truncated. While training, the finished
        transition is pushed to the replay memory.
        """
        next_state, reward, terminated, truncated, _ = self.env.step(action)

        done = terminated or truncated

        if not self.is_test:
            # In training self.transition already holds [state, action]
            # (written by choose_action), so only the remainder is appended.
            self.transition += [reward, next_state, done]
            self.memory.push(*self.transition)

        return next_state, reward, done

    def compute_dqn(self, samples: "dict[str, np.ndarray]"):
        """Calculate the DQN loss for a batch of memories.

        Args:
            samples: Mapping with the keys ``"state"``, ``"ns"``,
                ``"action"``, ``"reward"`` and ``"done"``, each holding one
                batch of the corresponding quantity.

        Returns:
            ``(loss, mean_q)``: the smooth-L1 loss tensor (still attached to
            the graph) and the mean predicted Q-value of the taken actions as
            a NumPy value.
        """
        device = self.device
        # torch.as_tensor builds the tensors directly on the target device;
        # the legacy torch.FloatTensor constructor rejects non-CPU devices.
        state = torch.as_tensor(samples["state"], dtype=torch.float32, device=device)
        ns = torch.as_tensor(samples["ns"], dtype=torch.float32, device=device)
        action = torch.as_tensor(samples["action"], dtype=torch.long, device=device)
        reward = torch.as_tensor(samples["reward"], dtype=torch.float32, device=device)
        done = torch.as_tensor(samples["done"], dtype=torch.float32, device=device)

        # Q-value the online network assigns to the action actually taken.
        # squeeze(1) drops only the gather column, so a batch of one keeps
        # its batch dimension.
        curr_q_value = self.dqn(state).gather(1, action.unsqueeze(1)).squeeze(1)

        # The target network is always a little behind the online one and is
        # used for the next-state estimate.
        # NOTE: .max(1)[0] is the maximum VALUE; .max(1)[1] would be the INDEX
        # of the maximum and silently yields targets made of action indices.
        with torch.no_grad():
            next_q_value = self.dqn_target(ns).max(1)[0]

        # No bootstrapping from the next state once the episode has ended.
        mask = 1 - done
        target = reward + next_q_value * self.gamma * mask
        loss = F.smooth_l1_loss(curr_q_value, target)

        # .cpu() is required before .numpy() when the model lives on the GPU.
        return loss, torch.mean(curr_q_value).detach().cpu().numpy()

    def update_model(self) -> "tuple[float, np.ndarray]":
        """Run one optimisation step on a batch sampled from the memory.

        Returns ``(loss, mean_q)`` as plain values for logging.
        """
        samples = self.memory.sample()

        loss, q_value = self.compute_dqn(samples)

        # Gradients accumulate across backward() calls, so clear them first:
        # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item(), q_value

    def target_hard_update(self):
        """Copy the online network's weights into the target network."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def train(self, seed: int, num_frames: int, plotting_interval: int = 200):
        """Train for ``num_frames`` environment steps.

        The environment is reset with ``seed`` and then ``seed + k`` for the
        k-th finished episode. Weights are saved when training finishes or is
        interrupted with Ctrl-C, and the environment is always closed.
        """
        self.is_test = False
        try:
            state, _ = self.env.reset(seed=seed)
            total_episodes = 0
            # number of optimisation steps taken after the initial
            # "memory gathering" phase
            eff_episode = 0
            score = 0
            scores = []
            losses = []
            predictions = []

            for frame_idx in range(1, num_frames + 1):
                action = self.choose_action(state)
                next_state, reward, done = self.take_step(action)

                state = next_state
                score += reward

                if done:
                    total_episodes += 1
                    # trying to make gym reproducible
                    state, _ = self.env.reset(seed=seed + total_episodes)
                    scores.append(score)
                    score = 0

                # One transition is pushed per frame, so wait until at least
                # one batch worth of transitions has been collected.
                if frame_idx >= self.batch_size:
                    loss, curr_q_value = self.update_model()
                    losses.append(loss)
                    predictions.append(curr_q_value)
                    eff_episode += 1

                    # Linear epsilon decay, clamped at min_eps.
                    self.eps = max(
                        self.min_eps,
                        self.eps - (self.max_eps - self.min_eps) * self.eps_decay,
                    )

                    if eff_episode % self.target_update == 0:
                        self.target_hard_update()

                if frame_idx % plotting_interval == 0:
                    self._plot(frame_idx, scores, losses, predictions)
        except KeyboardInterrupt:
            self.save_state()
        else:
            self.save_state('saved-state-done')
        finally:
            self.env.close()

    def save_state(self, name='saved-state'):
        """Write the online network's weights to ``saved-states/`` and return them.

        NOTE(review): ``name`` is accepted but is not part of the file path,
        so interrupted and completed runs write to the same file - confirm
        whether the name was meant to appear in the path.
        """
        import os

        state_dict = self.dqn.state_dict()
        # torch.save does not create missing directories.
        os.makedirs('saved-states', exist_ok=True)
        torch.save(state_dict, f'saved-states/{NB_ID}.s{self.seed}.pt')
        print('saved state!')
        return state_dict

    def load_state(self, state_dict):
        """Load ``state_dict`` into the online network and return it."""
        self.dqn.load_state_dict(state_dict)
        return state_dict

    def test(self, video_folder: str):
        """Play one greedy episode while recording a video to ``video_folder``."""
        self.is_test = True
        # save current environment to swap it back later on
        naive_env = self.env

        self.env = gym.wrappers.RecordVideo(self.env, video_folder=video_folder, name_prefix=NB_ID)
        try:
            state, _ = self.env.reset()
            self.env.start_video_recorder()
            done = False
            score = 0
            while not done:
                action = self.choose_action(state)
                next_state, reward, done = self.take_step(action)

                state = next_state
                score += reward

            print("score: ", score)
        finally:
            # Close the recorder and restore the plain environment even when
            # the episode fails part-way through.
            self.env.close()
            self.env = naive_env

    def _plot(self, frame_idx, scores, losses, predictions):
        """Redraw the training dashboard: episode scores, loss and mean Q-value."""
        clear_output(wait=True)
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 5), num=1, clear=True)

        ax1.set_title(f'frame {frame_idx} | score: {np.mean(scores[-10:])}')
        ax1.plot(scores)

        ax2.set_title('loss')
        ax2.plot(losses)

        ax3.set_title('Q value')
        ax3.plot(predictions)

        plt.show()


In [None]:
# Build the environment and the agent that will be trained on it.
env_id = "CartPole-v1"
env = gym.make(env_id)

# hyperparameters
seed = 0
mu = 20
lambd = 100

# DQNAgent is the only agent class defined in this notebook; pass it the
# arguments its constructor declares (env, seed, mu, lambd).
agent = DQNAgent(env, seed, mu, lambd)