# TL;DR

In this lab scenario you will finish implementing DQN, a variant of the Q-learning method. On top of standard Q-learning with neural networks as function approximators, DQN uses:
* experience replay - used to make better use of samples from the environment and to decorrelate the elements of a batch,
* target network - used to avoid constantly changing targets in the learning process (to avoid "chasing its own tail").

For the algorithm's details, recall the lecture and/or follow the [original paper](https://arxiv.org/abs/1312.5602), which is rather self-contained and not hard to understand.

Without changing any hyperparameters, the agent should solve the problem (obtain rewards of ~200) after ~1000 episodes, which takes ~10 minutes of training on a GPU runtime.

You can run this code locally (not in Colab), which lets you watch the agent in action. Unfortunately, visualization inside Colab worked poorly and was removed from this lab scenario.

# Imports

In [None]:
!pip install Box2D==2.3.10 gym==0.26.2 pygame==2.6.1
!pip install numpy==1.26.4 torch
!pip install tensorboard
!pip install matplotlib
!pip install ipywidgets
!pip install opencv-python
!pip install tqdm

In [2]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

import numpy as np
import argparse
import datetime
import time
import random
from collections import namedtuple
from pathlib import Path
from typing import Tuple, List
from matplotlib import pyplot as plt
import IPython.display as display
import ipywidgets as widgets
import cv2
from tqdm import tqdm

%load_ext tensorboard

# Utilities


## Misc

In [3]:
def try_gpu(i: int = 0):
    """Return gpu(i) if exists, otherwise return cpu()"""
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f"cuda:{i}")
    return torch.device("cpu")


def save_model(model: nn.Module, PATH: str):
    """Saves model's state_dict.

    Reference: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    """
    torch.save(model.state_dict(), PATH)


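def load_model(model: nn.Module, PATH: str):
    """Loads model's parameters from state_dict.

    map_location places the loaded tensors on the device the model already uses,
    so a checkpoint saved on GPU can also be loaded on a CPU-only machine.
    """
    device = next(model.parameters()).device
    model.load_state_dict(torch.load(PATH, map_location=device))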

## Scheduler

Training RL agents requires dealing with the exploration-exploitation trade-off. To handle it, we will adopt the most basic, yet surprisingly effective, epsilon-greedy strategy. At the beginning our agent will focus on exploration; over time it will start exploiting its knowledge, becoming more and more greedy. To implement this logic, we will use the LinearDecay scheduler.

In [4]:
class Constant:
    """Constant scheduler.

    Can be used, e.g., to create an agent with a greedy policy, namely epsilon == 0
    """

    def __init__(self, value: float, *args):
        self._value = value

    def value(self, *args):
        return self._value


class LinearDecay:
    """Linear decay scheduler.

    At each call linearly decays the value by simply subtracting `decay` from the current value,
    until some minimum value is reached.
    Can be used e.g. to decay epsilon value for epsilon-greedy exploration/exploitation strategy.
    """

    def __init__(self, initial_value: float, final_value: float, decay: float):
        self._value = initial_value
        self.final_value = final_value
        self.decay = decay

    def value(self, *args) -> float:
        self._value = max(self.final_value, self._value - self.decay)
        return self._value
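
Note that `act()` (defined below) calls `exploration_fn.value()` once per action, so the decay is applied per environment step, not per episode. The sketch below uses a standalone copy of the scheduler (named `LinearDecayCopy` here only so the cell does not depend on, or shadow, the class above) and counts how many calls the settings from `main()`, `LinearDecay(1, 0.05, 0.00001)`, need to reach the floor:

```python
class LinearDecayCopy:
    """Standalone copy of LinearDecay above, so this cell does not depend on it."""

    def __init__(self, initial_value: float, final_value: float, decay: float):
        self._value = initial_value
        self.final_value = final_value
        self.decay = decay

    def value(self, *args) -> float:
        self._value = max(self.final_value, self._value - self.decay)
        return self._value


def calls_until_floor(scheduler: LinearDecayCopy, max_calls: int = 10_000_000) -> int:
    """Count value() calls until the scheduler first returns its final value."""
    for n_calls in range(1, max_calls + 1):
        if scheduler.value() <= scheduler.final_value:
            return n_calls
    raise RuntimeError("final value not reached within max_calls")


# Same settings as in main(): LinearDecay(1, 0.05, 0.00001)
steps = calls_until_floor(LinearDecayCopy(1.0, 0.05, 0.00001))
# Analytically (1.0 - 0.05) / 0.00001 = 95,000; float rounding may add one call.
print(steps)
```

So epsilon only reaches 0.05 after roughly 95,000 environment steps; with episodes capped at 1000 steps in `learn()`, that is at least 95 episodes, and more if episodes end before the cap.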

## Replay buffer

The key trick that makes DQN feasible is the replay buffer. The idea is to store observed transitions, sample them randomly, and perform updates based on them. This solution has many advantages; the most significant ones are:

1.   *Data efficiency* - each transition (env step) can be used in many weight updates.
2.   *Data decorrelation* - consecutive transitions are naturally highly correlated. Randomizing the samples reduces these correlations, thus reducing the variance of the updates.

Note that when learning by experience replay, it is necessary to learn off-policy (because our current parameters are different from those used to generate the sample), which motivates the choice of Q-learning.

In [5]:
# non_terminal_mask is 1.0 if next_state is non-terminal and 0.0 if it is terminal;
# it becomes useful when using target_net to predict Q-values of next_state.
Transition = namedtuple(
    "Transition", ("state", "action", "next_state", "reward", "non_terminal_mask")
)


class ReplayBuffer(object):
    def __init__(self, size: int):
        """Create new replay buffer.

        Args:
            size: capacity of the buffer
        """
        self._storage: List[Transition] = []
        self._capacity = size
        self._next_idx = 0

    def add(self, data: Transition):
        if len(self._storage) < self._capacity:
            self._storage.append(None)
        self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._capacity

    def sample(self, batch_size: int) -> List[Transition]:
        """Sample a batch of experience from memory.

        Args:
            batch_size: size of the batch

        Returns:
            batch of transitions
        """
        batch = random.sample(self._storage, batch_size)
        return batch

    def __len__(self) -> int:
        return len(self._storage)
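
`add()` treats `_storage` as a circular buffer: the list grows until it reaches capacity, after which `_next_idx` wraps around and the oldest transition is overwritten. The sketch below replays that indexing on plain integers (a toy capacity of 3, chosen only for illustration), and also shows why sampling needs a guard when the buffer is still small:

```python
import random

capacity = 3
storage = []
for t in range(5):  # t stands in for the t-th transition passed to add()
    slot = t % capacity  # equals ReplayBuffer._next_idx just before this add()
    if len(storage) < capacity:
        storage.append(t)  # buffer not full yet: the list grows (t lands at index slot)
    else:
        storage[slot] = t  # buffer full: overwrite the oldest entry
print(storage)  # [3, 4, 2]

# random.sample refuses to draw more items than are stored, so without the
# early return in _update_policy_net a too-small buffer would raise here.
try:
    random.sample(storage, 4)
except ValueError as err:
    print("ValueError:", err)
```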

## MLP Network

For fast iteration we will stick to numerical observations (the original DQN paper works with raw pixel observations). We will use a simple MLP to approximate the Q-values: given a state, the network outputs one Q-value estimate per action.

In [6]:
class MLP(nn.Module):
    """Simple MLP net.

    Each of the layers, except the last one, is followed by a ReLU non-linearity.
    """

    def __init__(self, layers_sizes: List[int]):
        super(MLP, self).__init__()

        modules = []
        for in_features, out_features in zip(layers_sizes, layers_sizes[1:-1]):
            modules.extend(
                [
                    nn.Linear(in_features, out_features),
                    nn.ReLU(),
                ]
            )
        # final output is not followed by non-linearity
        modules.extend([nn.Linear(layers_sizes[-2], layers_sizes[-1])])
        self.layers = nn.Sequential(*modules)

    def forward(self, state):
        return self.layers(state)
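
The constructor pairs consecutive entries of `layers_sizes`, so `MLP([OBS_SHAPE, 256, 256, N_ACTIONS])` as built in `main()` consists of three `nn.Linear` layers. For LunarLander (8 observation values, 4 actions), the plain-Python sketch below lists those layer shapes and counts weights plus biases; the helper names are illustrative only:

```python
from typing import List, Tuple


def linear_shapes(layers_sizes: List[int]) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each nn.Linear that MLP creates."""
    # Same pairs as zip(sizes, sizes[1:-1]) plus the final (sizes[-2], sizes[-1]) layer.
    return list(zip(layers_sizes, layers_sizes[1:]))


def n_parameters(layers_sizes: List[int]) -> int:
    # A Linear(in, out) layer has in * out weights and out biases.
    return sum(i * o + o for i, o in linear_shapes(layers_sizes))


sizes = [8, 256, 256, 4]
print(linear_shapes(sizes))  # [(8, 256), (256, 256), (256, 4)]
print(n_parameters(sizes))   # 69124
```

For an actual `MLP([8, 256, 256, 4])` instance, `sum(p.numel() for p in net.parameters())` should report the same count.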

# DQN Agent

First, we implement the constructor and some utility functions for the agent.

In [22]:
class DQNAgent:
    def __init__(
        self, exploration_fn, policy_net: torch.nn.Module, target_net: torch.nn.Module
    ):
        self.exploration_fn = exploration_fn
        self.policy_net = policy_net
        self.target_net = target_net
        self.optim = None
        self.replay_buffer = None

    def save_policy_net(self, checkpoint: str):
        """Saves policy_net parameters as given checkpoint.

        state_dict of current policy_net is stored.

        Args:
            checkpoint: path where the model's parameters will be stored.
        """
        save_model(self.policy_net, checkpoint)

    def load_policy_net(self, checkpoint: str):
        """Loads policy_net parameters from given checkpoint.

        Note that a model with the matching architecture must be instantiated first,
        as a checkpoint stores only its parameters (in the form of a state_dict).

        Args:
            checkpoint: path to model's parameters.
        """
        load_model(self.policy_net, checkpoint)

    def play_episodes(self, n_episodes: int, env: gym.Env):
        """Function to watch the agent playing - locally"""

        def render(frame, widget):
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            _, jpeg = cv2.imencode(".jpeg", frame)
            widget.value = jpeg.tobytes()

        self.policy_net.eval()

        for episode in range(n_episodes):
            # 0.5 sec breaks between episodes, so it's easier to watch
            time.sleep(0.5)
            state = env.reset()

            if isinstance(state, tuple):
                state = state[0]

            total_reward, timesteps, done = 0, 0, False
            image_widget = widgets.Image(format="jpeg")
            display.display(image_widget)
            frame = env.render()
            render(frame=frame, widget=image_widget)
            
            episode_steps_left = 1000
            while not done and episode_steps_left > 0:
                episode_steps_left -= 1
                # Pick next action, simulate and observe next_state and reward
                action = self.act(state)
                next_state, reward, done, _, _ = env.step(action.item())
                state = next_state

                frame = env.render()
                render(frame=frame, widget=image_widget)
                # To make watching easier
                time.sleep(0.01)

                total_reward += reward
                timesteps += 1

            print(f"Episode length: {timesteps}, total reward: {total_reward}")

### Policy 

Given an observation, the agent follows an epsilon-greedy strategy.

In [23]:
class DQNAgent(DQNAgent):
    def act(self, obs) -> torch.Tensor:
        """Epsilon-greedy policy derived from policy_net

        With probability epsilon select a random action a_t.
        Otherwise select a_t = max_a(Q(obs, a; theta))
        """
        eps_exploration = self.exploration_fn.value()
        if torch.rand(1).item() <= eps_exploration:
            return torch.randint(0, N_ACTIONS, [1])
        else:
            if not isinstance(obs, torch.Tensor):
                obs = torch.tensor(obs, dtype=torch.float32, device=DEVICE).view(
                    -1, OBS_SHAPE
                )
            with torch.no_grad():
                return torch.argmax(self.policy_net(obs))
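
One consequence of this policy: the random branch draws from all `N_ACTIONS` actions, including the greedy one, so the greedy action is taken with probability $1 - \varepsilon + \varepsilon / N_{actions}$ rather than $1 - \varepsilon$; for $\varepsilon = 0.05$ and 4 actions that is $0.95 + 0.0125 = 0.9625$. The plain-Python sketch below (hypothetical helper names, made-up Q-values) checks this empirically:

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """Random action with probability epsilon, otherwise the argmax of q_values."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def greedy_probability(epsilon: float, n_actions: int) -> float:
    # Exploit branch, plus the chance that the random branch hits the greedy action
    return (1 - epsilon) + epsilon / n_actions


rng = random.Random(0)
q = [0.1, 0.7, -0.3, 0.2]  # made-up Q-values; action 1 is greedy
n_trials = 100_000
hits = sum(epsilon_greedy(q, 0.05, rng) == 1 for _ in range(n_trials))
print(greedy_probability(0.05, 4))
print(hits / n_trials)  # empirical frequency, close to the value above
```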

### Learning procedure

In [24]:
class DQNAgent(DQNAgent):
    def learn(
        self,
        gamma: float,
        optim: torch.optim.Optimizer,
        n_episodes: int,
        batch_size: int,
        target_update_interval: int,
        buffer_size: int,
        checkpoints_dir: str,
        checkpoint_save_interval: int,
        tensorboard_log_dir: str,
        env: gym.Env,
    ):
        self.optim = optim
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.policy_net.train()

        total_steps, rewards_history = 0, []
        writer = SummaryWriter(tensorboard_log_dir)

        for episode in tqdm(range(n_episodes), desc="Training episode"):
            episode_reward, episode_steps, done = 0, 0, False
            state = env.reset()

            if isinstance(state, tuple):
                state = state[0]
            episode_steps_left = 1000
            while not done and episode_steps_left > 0:
                episode_steps_left -= 1
                # Pick next action, simulate and observe next_state and reward
                action = self.act(state)
                next_state, reward, done, _, _ = env.step(action.item())

                ##### TODO IMPLEMENT #####
                # Store Transition in replay buffer. ("state", "action", "next_state", "reward", "non_terminal_mask")

                ##### END OF TODO    #####

                # Update policy_net (returns None until the buffer holds batch_size transitions)
                loss = self._update_policy_net(gamma, batch_size)

                # Update current state
                state = next_state

                # Update target_net with current parameters
                if (total_steps + 1) % target_update_interval == 0:
                    self._update_target_net()
                # Periodically save a checkpoint of policy_net
                if (total_steps + 1) % checkpoint_save_interval == 0:
                    self.save_policy_net(
                        f"{checkpoints_dir}/params_nsteps{total_steps + 1}_nepis{episode}"
                    )

                # Misc
                total_steps += 1
                episode_steps += 1
                episode_reward += reward
                if loss is not None:  # None while the buffer is still warming up
                    writer.add_scalar("Loss/MSE", loss, total_steps)

            rewards_history.append(episode_reward)
            # Tensorboard
            writer.add_scalar("Reward/episode", episode_reward, episode)
            writer.add_scalar(
                "Reward/mean_100_episodes", np.mean(rewards_history[-100:]), episode
            )
            writer.add_scalar("Episode/n_steps", episode_steps, episode)
            writer.add_scalar(
                "Misc/eps_exploration", self.exploration_fn._value, episode
            )

        writer.close()

### PolicyNet update step
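
The update implemented below can be summarized as follows. Writing $\theta$ for the `policy_net` parameters, $\theta^-$ for the `target_net` parameters, $m_i$ for the `non_terminal_mask` entry of transition $i$ and $B$ for `batch_size`, each sampled transition $(s_i, a_i, r_i, s'_i)$ gets the target $y_i$, and the loss is the mean squared error that `F.mse_loss` computes with its default `reduction="mean"`:

$$
y_i = r_i + \gamma \, m_i \, \max_{a'} Q(s'_i, a'; \theta^-)
$$

$$
\mathcal{L}(\theta) = \frac{1}{B} \sum_{i=1}^{B} \bigl( y_i - Q(s_i, a_i; \theta) \bigr)^2
$$

The targets are computed from `target_net` under `.detach()`, so gradients flow only through $Q(s_i, a_i; \theta)$; for a terminal transition ($m_i = 0$) the target is just $r_i$, as the comments in `get_targets` describe.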

In [25]:
class DQNAgent(DQNAgent):
    def _update_policy_net(self, gamma: float, batch_size: int):
        """Perform one round of policy_net update.

        Sample a random minibatch of transitions (phi(s_t), a_t, r_t, phi(s_t+1)) from replay buffer
        and update policy_net according to DQN algorithm.
        """
        if len(self.replay_buffer) < batch_size:
            return

        def get_targets(gamma: float, batch: Transition):
            """Uses `target_net` and immediate rewards to calculate expected future rewards."""
            batch_next_state = torch.tensor(batch.next_state, device=DEVICE).detach()
            # target_net prediction for terminal states should be 0, as our expectation from terminal state is 0
            non_terminal_mask = torch.tensor(
                batch.non_terminal_mask, device=DEVICE
            ).detach()
            next_state_bootstrapped_values = (
                torch.max(self.target_net(batch_next_state), dim=1)[0].detach()
                * non_terminal_mask
            )
            assert torch.all(
                (non_terminal_mask == 0).nonzero()
                == (next_state_bootstrapped_values == 0).nonzero()
            )

            assert len(batch.reward.shape) == 1
            assert len(next_state_bootstrapped_values.shape) == 1
            ##### TODO IMPLEMENT - given the pieces from above, compute the targets #####
            # to match remaining portions of the code, reshape the target tensor as follows: (-1, 1)

            # Expected future reward for terminal state is equal to immediate reward
            # For non terminal states expected future reward:
            # immediate reward + discounted future expectation
            ##### END OF TODO    #####

            assert targets.shape == (batch.next_state.shape[0], 1)
            return targets

        def get_state_action_values(batch):
            """Uses `policy_net` to calculate current estimates of future rewards."""
            batch_state = torch.tensor(batch.state, device=DEVICE)
            # Calculate current estimates for the (state, action) we have observed and taken
            # 'preds' shape: (batch_size, n_actions)
            preds = self.policy_net(batch_state)
            # Extracting values from various indices might be a little confusing:
            # https://medium.com/analytics-vidhya/understanding-indexing-with-pytorch-gather-33717a84ebc4
            action_index = torch.tensor(
                batch.action, dtype=torch.long, device=DEVICE
            ).unsqueeze(-1)
            state_action_values = torch.gather(preds, dim=1, index=action_index)
            return state_action_values

        # Sample and convert batch into big Transition of form:
        # Transition(state=(0,0,...), action=(1,4,...), next_state=(0,3,...), reward=(3,0,...), non_terminal_mask=(0,1,0,...))
        # In other words: list_of_tuples -> tuple_of_lists
        transitions = self.replay_buffer.sample(batch_size)
        batch = Transition(*zip(*transitions))
        # Convert to numpy arrays so that we can use binary mask as indices to extract e.g. non terminal masks
        # Types are chosen so that torch.tensor will inherit correct one
        batch = Transition(
            np.array(batch.state),
            np.array(batch.action),
            np.array(batch.next_state),
            np.array(batch.reward, np.float32),
            np.array(batch.non_terminal_mask, np.float32),
        )

        state_action_values = get_state_action_values(batch)
        targets = get_targets(gamma, batch)
        loss = F.mse_loss(state_action_values, targets)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        return loss.item()
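
Two idioms in `_update_policy_net` are easier to follow on toy data: `Transition(*zip(*transitions))` transposes a list of transitions into a single transition of tuples, and `torch.gather(preds, dim=1, index=action_index)` picks, for each row, the Q-value of the action actually taken. The sketch below reproduces both with NumPy, using `np.take_along_axis` as a stand-in for `torch.gather`; the `ToyTransition` name and all numbers are made up for illustration:

```python
from collections import namedtuple

import numpy as np

ToyTransition = namedtuple(
    "ToyTransition", ("state", "action", "next_state", "reward", "non_terminal_mask")
)

# Three toy transitions with 2-dimensional states
transitions = [
    ToyTransition((0.0, 1.0), 2, (0.5, 1.5), 1.0, 1.0),
    ToyTransition((1.0, 0.0), 0, (1.5, 0.5), 0.0, 1.0),
    ToyTransition((2.0, 2.0), 1, (2.5, 2.5), -1.0, 0.0),
]

# list of tuples -> tuple of lists (one field per slot)
batch = ToyTransition(*zip(*transitions))
print(batch.action)  # (2, 0, 1)

# Fake Q-value predictions, shape (batch_size, n_actions)
preds = np.array([[0.1, 0.2, 0.3],
                  [1.0, 1.1, 1.2],
                  [2.0, 2.1, 2.2]])
# Column vector of action indices, like .unsqueeze(-1) in the cell above
action_index = np.array(batch.action).reshape(-1, 1)
# Row i keeps preds[i, action_index[i]], giving shape (batch_size, 1)
state_action_values = np.take_along_axis(preds, action_index, axis=1)
print(state_action_values.ravel())
```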

### TargetNet update
Finally, the last missing step is to *update target_net*.

In [26]:
class DQNAgent(DQNAgent):
    def _update_target_net(self):
        """Sets `target_net` parameters to the current `policy_net` parameters."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

# Environment

We will try to solve LunarLander-v2: https://gym.openai.com/envs/LunarLander-v2/

The LunarLander env can be considered solved once the agent achieves 200 points.

In [None]:
env = gym.make("LunarLander-v2", render_mode="rgb_array")
N_ACTIONS = env.action_space.n
OBS_SHAPE = env.observation_space.shape[0]  # 8 for LunarLander
print(f"Number of actions = {N_ACTIONS}")

# Experiment

In [None]:
DEVICE = try_gpu()
print(DEVICE)

EXP_NAME = "LunarLander"
LOG_DIR = f"runs/{EXP_NAME}"
TENSORBOARD_LOG_DIR = f"runs/{EXP_NAME}/tensorboard"
CHECKPOINTS_DIR = f"runs/{EXP_NAME}/checkpoints"
Path(CHECKPOINTS_DIR).mkdir(parents=True, exist_ok=True)

In [14]:
def parse_args():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch DQN implementation")

    # Hack for colab...
    parser.add_argument(
        "-f",
        "--fff",
        help="a dummy argument to fool ipython in colab. Comment out for local dev.",
        default="1",
    )

    # To see the agent playing
    parser.add_argument(
        "--play",
        action="store_true",
        help="play mode: if set, the agent plays the env instead of training (default: False). "
        "If checkpoint is not specified, a randomly initialized network will play",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint storing state_dict to load for the model. "
        "If None then agent will be initialized with random params (default: None)",
    )
    parser.add_argument(
        "--n_episodes", type=int, default=10, help="number of episodes to play"
    )

    # To train the agent
    parser.add_argument(
        "--exp_dir",
        type=str,
        default=f"exp/{datetime.datetime.now().timestamp()}",
        help="experiment directory where logs and checkpoints will be stored (default: exp/{timestamp})",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0005,
        metavar="LR",
        help="learning rate (default: 0.0005)",
    )
    args = parser.parse_args()

    return args
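
Declaring a boolean option with `type=bool` is a common argparse pitfall: argparse passes the raw string to `bool()`, and every non-empty string is truthy, so `--play False` would still enable play mode. The standalone sketch below (the `--play_bool` and `--play_flag` names are made up for the comparison) contrasts it with `action="store_true"`, the usual way to declare an on/off switch:

```python
import argparse

parser = argparse.ArgumentParser()
# Pitfall: the value string is passed to bool(), and bool("False") is True
parser.add_argument("--play_bool", type=bool, default=False)
# Idiomatic switch: False unless the flag is present on the command line
parser.add_argument("--play_flag", action="store_true")

# Explicit argument lists avoid picking up the notebook kernel's sys.argv
print(parser.parse_args(["--play_bool", "False"]).play_bool)  # True
print(parser.parse_args([]).play_flag)                         # False
print(parser.parse_args(["--play_flag"]).play_flag)            # True
```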

In [15]:
def main():
    args = parse_args()

    layers = [OBS_SHAPE, 256, 256, N_ACTIONS]
    policy_net = MLP(layers).to(DEVICE)
    target_net = MLP(layers).to(DEVICE)

    agent_params = {
        "exploration_fn": LinearDecay(1, 0.05, 0.00001),
        "policy_net": policy_net,
        "target_net": target_net,
    }
    if args.play:
        print("Wanna play a game...")
        agent_params["exploration_fn"] = Constant(0.01)
        agent = DQNAgent(**agent_params)
        if args.checkpoint:
            agent.load_policy_net(args.checkpoint)
        agent.play_episodes(args.n_episodes, env=env)

    else:
        print("Training mode...")
        train_params = {
            "gamma": 0.99,
            "optim": torch.optim.Adam(policy_net.parameters(), lr=0.0005),
            "n_episodes": int(2e4),
            "batch_size": 64,
            # Target update interval in number of env steps (not episodes)
            "target_update_interval": 100,
            "buffer_size": 10000,
            "checkpoint_save_interval": 5000,
            "checkpoints_dir": CHECKPOINTS_DIR,
            "tensorboard_log_dir": TENSORBOARD_LOG_DIR,
        }

        agent = DQNAgent(**agent_params)
        agent.learn(**train_params, env=env)


In [None]:
# Start TensorBoard in Google Colab
# If you can't see anything, run this cell twice
%tensorboard --logdir $TENSORBOARD_LOG_DIR

In [None]:
main()

# Tasks


1.   Implement the missing code marked with `##### TODO IMPLEMENT #####`.
2.   Experiment with the hyperparameters, e.g., gamma (discount factor) and epsilon (exploration-exploitation trade-off).
3.   Observe weird behaviors of the agent, e.g., "forgetting how to play": the reward drops significantly and is then "re-learned" again. Why can it happen? What can we do to avoid it?
4.   Change the args and observe the trained model's behavior. What do you see?
5.   What can be improved in the training code?

