# Deep neural network agent with GRU

### Other reference implementations
* https://github.com/keep9oing/DRQN-Pytorch-CartPole-v1/blob/main/DRQN.py (PyTorch)
* https://github.com/marload/DeepRL-TensorFlow2/blob/master/DRQN/DRQN_Discrete.py (clean, TensorFlow)

In [None]:

import torch.nn as nn
import torch.nn.functional as F
import torch as th


class GRUAgent(nn.Module):
    """Recurrent Q-network: linear encoder -> ReLU -> GRU -> linear Q-value head.

    The GRU is created with ``batch_first=True``, so batched inputs are
    expected as ``(batch, sequence, input_size)``. The recurrent state is kept
    on the module between calls (``self.hidden``) so an episode can be fed one
    step at a time; call :meth:`reset` (or pass ``reset_rnn=True``) at the
    start of every new episode or training sequence.

    Args:
        input_size: Number of features per observation.
        output_size: Number of Q-values (actions) produced per time step.
        hidden_size: Width of the encoder output and of the GRU state.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.rnn = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)
        # Start without a recurrent state so forward() is usable before the
        # first explicit reset(); nn.GRU treats None as an all-zero state.
        self.hidden = None

    def reset(self):
        """Drop the stored recurrent state; the next forward starts from zeros."""
        self.hidden = None

    def forward(self, obs, reset_rnn=False):
        """Compute Q-values for ``obs`` and advance the stored recurrent state.

        Args:
            obs: Observation tensor whose last dimension is ``input_size``.
            reset_rnn: If true, discard the stored recurrent state before
                processing ``obs`` (equivalent to calling :meth:`reset` first).

        Returns:
            Tensor shaped like ``obs`` with the last dimension replaced by
            ``output_size``.
        """
        if reset_rnn:
            self.reset()
        x = F.relu(self.linear1(obs))
        # The new state is stored as returned by the GRU (still attached to
        # the autograd graph); reset between independent sequences.
        h, self.hidden = self.rnn(x, self.hidden)
        return self.linear2(h)


In [None]:


class DQN():
    """Deep Q-learning controller with a policy network and a target network.

    Two construction modes exist:

    * Trainable (``policy_model is None``): policy and target networks are
      built from ``Model(**model_args)`` and optimised with RMSprop.
    * Inference only (``policy_model`` given): the supplied network is used
      as is; ``eps_greedy`` and ``update`` are not available.

    Args:
        gamma: Discount factor applied to the next-state value.
        target_update_freq: The target network is synchronised with the
            policy network whenever ``update_step`` is a multiple of this.
        model_args: Keyword arguments forwarded to ``Model``.
        opt_args: Keyword arguments forwarded to ``th.optim.RMSprop``.
        eps: Probability of taking a random action in ``eps_greedy``.
        policy_model: Ready-made network for inference-only use.
        device: Device on which helper tensors are created and to which the
            networks are moved. ``None`` leaves the networks where they are
            and uses PyTorch's defaults for new tensors.
    """

    def __init__(
            self, *, gamma=None, target_update_freq=None, model_args=None, opt_args=None, eps=None,
            policy_model=None, device=None):
        # eps_greedy() and update() read self.device, so it must always exist.
        self.device = device
        if policy_model is None:
            self.policy_model = Model(**(model_args or {}))  # Here goes the Neural Network
            self.target_model = Model(**(model_args or {}))  # Here goes the Neural Network
            if device is not None:
                # Move before building the optimizer so it tracks the final parameters.
                self.policy_model.to(device)
                self.target_model.to(device)

            self.target_model.eval()
            self.optimizer = th.optim.RMSprop(self.policy_model.parameters(), **(opt_args or {}))
            self.gamma = gamma
            self.target_update_freq = target_update_freq
            self.eps = eps
            self.trainable = True
        else:
            self.policy_model = policy_model
            if device is not None:
                self.policy_model.to(device)
            self.trainable = False

    def get_q(self, obs, first=False):
        """Return the policy network's Q-values for ``obs`` without tracking gradients.

        Args:
            obs: Encoded observation passed straight to the policy network.
            first: True on the first step of an episode; forwarded to the
                network as ``reset_rnn`` so it can clear its recurrent state.
        """
        with th.no_grad():
            return self.policy_model(obs, reset_rnn=first)

    def eps_greedy(self, q_values):
        """Pick actions epsilon-greedily.

        Args:
            q_values: Tensor of type `th.float` and arbitrary shape; the last
                dimension enumerates the actions.
        Returns:
            actions: Tensor of type `th.long` with the same shape as
                `q_values` except for the last dimension, which is removed.
        """
        assert self.trainable, "Model is not trainable."
        n_actions = q_values.shape[-1]
        actions_shape = q_values.shape[:-1]

        greedy_actions = q_values.argmax(-1)
        random_actions = th.randint(0, n_actions, size=actions_shape, device=self.device)

        # One uniform draw per entry decides whether the random action is taken.
        select_random = th.rand(size=actions_shape, device=self.device) < self.eps
        return th.where(select_random, random_actions, greedy_actions)

    def update(self, update_step, action, reward, **obs):
        """Run one optimisation step on a batch of episodes.

        Args:
            update_step: Index of the current update; controls when the target
                network is refreshed.
            action: Long tensor of the actions taken; used as gather index
                over the last dimension of the Q-values.
            reward: Tensor of rewards, indexed as ``[episode, step]``.
            **obs: Remaining tensors of the batch. The dict is handed to the
                networks unchanged as their first argument.

        Returns:
            The element-wise smooth-L1 loss averaged over the first (episode)
            dimension only.
        """
        assert self.trainable, "Model is not trainable."
        if (update_step % self.target_update_freq == 0):
            # copy policy net to target net
            self.target_model.load_state_dict(self.policy_model.state_dict())

        self.policy_model.train()
        current_state_action_values = self.policy_model(
            obs, reset_rnn=True).gather(-1, action.unsqueeze(-1))

        next_state_values = th.zeros_like(reward, device=self.device)

        # we skip the first observation and set the future value for the terminal
        # state to 0
        next_state_values[:, :-1] \
            = self.target_model(obs, reset_rnn=True)[:, 1:].max(-1)[0].detach()

        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.gamma) + reward

        loss_ur = th.nn.functional.smooth_l1_loss(
            current_state_action_values, expected_state_action_values.unsqueeze(-1),
            reduction='none')
        loss_ur = loss_ur.mean(dim=0)

        loss = loss_ur.mean()

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        # Clamp gradients element-wise to [-1, 1]. Parameters that did not
        # take part in the forward pass have no gradient and are skipped.
        for param in self.policy_model.parameters():
            if param.grad is not None:
                param.grad.clamp_(-1, 1)
        self.optimizer.step()
        return loss_ur

    def save(self, filename):
        """Serialise the policy network (the whole module object) to ``filename``."""
        to_save = {
            'policy_model': self.policy_model
        }
        th.save(to_save, filename)

    @classmethod
    def load(cls, filename, device):
        """Create an inference-only controller from a file written by ``save``.

        Args:
            filename: Path of the checkpoint.
            device: Device the stored tensors are mapped to.
        """
        # NOTE(review): the checkpoint is a pickled module. Unpickling executes
        # arbitrary code, so only load trusted files; recent PyTorch releases
        # default th.load to weights_only=True, which rejects such checkpoints.
        to_load = th.load(filename, map_location=device)
        return cls(**to_load, device=device)


# Replay memory

In [None]:
import collections
import numpy as np
import torch as th

class Memory():
    """Fixed-size replay buffer that stores whole episodes in pre-allocated tensors.

    Episodes are written in batches of ``batch_size`` rows. ``start_batch``
    reserves the rows, ``add`` fills one round (time step) of those rows, and
    ``finish_batch`` makes them available for sampling. Once the buffer is
    full, writing wraps around and overwrites the oldest rows.

    Args:
        device: Device on which the storage tensors live.
        n_batches: Number of batches the buffer can hold.
        n_rounds: Number of rounds (time steps) stored per episode.
        batch_size: Number of episodes per batch; also the sample size.
        output_file: Optional path the buffer is dumped to on wrap-around.
    """

    def __init__(
            self, device, n_batches, n_rounds, batch_size, output_file=None):
        self.memory = None
        self.n_batches = n_batches
        self.size = n_batches * batch_size
        self.batch_size = batch_size
        self.n_rounds = n_rounds
        self.device = device
        self.output_file = output_file
        self.start_row = None
        self.end_row = None
        self.rewind = 0
        # Rows of the batch currently being written; empty until start_batch().
        self.current_batch = []
        # Row indices of finished episodes, newest on the left. maxlen equals
        # the number of rows, so overwritten rows drop out on the right.
        self.deque = collections.deque([], maxlen=self.n_batches*batch_size)

    @property
    def last_valid_row(self):
        """End row of the current batch, or None before the first batch."""
        if self.start_row is None:
            return None
        lvr = self.start_row + self.batch_size
        return lvr if lvr >= 0 else None

    def init_store(self, state):
        """Allocate one zero tensor per non-None entry of ``state``.

        Each tensor is shaped ``(size, n_rounds, *t.shape[2:])`` and keeps the
        dtype of the example tensor ``t``.
        """
        self.memory = {
            k: th.zeros((self.size, self.n_rounds, *t.shape[2:]),
                        dtype=t.dtype, device=self.device)
            for k, t in state.items() if t is not None
        }

    def write(self):
        """Dump the stored tensors to ``output_file``; no-op if none is configured.

        NOTE(review): minimal persistence hook; the file is overwritten on
        every call. Adjust the format and naming to the project's needs.
        """
        if self.output_file is None or self.memory is None:
            return
        th.save({k: v.cpu() for k, v in self.memory.items()}, self.output_file)

    def start_batch(self):
        """Reserve the next ``batch_size`` rows, wrapping to row 0 when full."""
        if (self.end_row is None):
            self.start_row = 0
        elif (self.end_row + self.batch_size) > self.size:
            # The next batch would not fit any more. A batch ending exactly at
            # `size` still fits, hence the strict comparison.
            self.write()
            self.rewind += 1
            self.start_row = 0
        else:
            self.start_row = self.end_row
        self.end_row = self.start_row + self.batch_size
        self.current_batch = list(range(self.start_row, self.end_row))

    def finish_batch(self):
        """Mark the rows of the current batch as available for sampling."""
        self.deque.extendleft(self.current_batch)

    def add(self, round_number, **state):
        """Store one round of the current batch.

        Args:
            round_number: Index along the round (time) axis to write to.
            **state: Tensors indexed as ``t[:, 0]`` to obtain one entry per
                episode of the batch; ``None`` values are ignored. The first
                call determines which keys are stored.
        """
        if self.memory is None:
            self.init_store(state)

        for k, t in state.items():
            if t is not None:
                self.memory[k][self.start_row:self.end_row, round_number] = t[:, 0].to(self.device)

    def sample(self, **kwargs):
        """Draw ``batch_size`` distinct finished episodes uniformly at random.

        Returns:
            A dict of tensors (see ``get_relative``), or ``None`` while fewer
            than ``batch_size`` episodes have been finished.
        """
        assert self.batch_size is not None, 'No sample size defined.'
        if len(self) < self.batch_size:
            return None
        relative_episode = np.random.choice(len(self), self.batch_size, replace=False)
        return self.get_relative(relative_episode, **kwargs)

    def get_relative(self, relative_episode, keys=None, device=None):
        """Fetch episodes by recency position (0 = most recently finished).

        Args:
            relative_episode: A position or a sequence of positions into the
                list of finished episodes.
            keys: Names of the stored tensors to return; all if ``None``.
            device: Optional device the result is moved to.

        Returns:
            Dict mapping each requested key to the selected rows.
        """
        if keys is None:
            keys = self.memory.keys()
        # A deque only supports integer indexing, so translate positions to
        # storage rows one by one.
        if np.ndim(relative_episode) == 0:
            rows = self.deque[int(relative_episode)]
        else:
            rows = [self.deque[int(i)] for i in relative_episode]
        hist_idx = th.tensor(rows, dtype=th.int64, device=self.device)

        sample = {k: v[hist_idx] for k, v in self.memory.items() if k in keys}
        if device is not None:
            sample = {k: v.to(device) for k, v in sample.items()}
        return sample

    def __len__(self):
        """Number of finished episodes available for sampling."""
        return len(self.deque)


# Sampling of one batch of episodes

In [None]:
from itertools import count

def run_batch(env, controller, replay_mem=None, on_policy=True, update_step=None):
    """Play one batch of episodes to the end and collect per-round metrics.

    Args:
        env: Environment exposing ``reset()`` and ``step(actions)``; ``step``
            returns ``(obs, rewards, done, info)``.
        controller: Provides ``get_q`` and, for exploration, ``eps_greedy``.
        replay_mem: If given, every round is stored via ``replay_mem.add``.
        on_policy: True selects the greedy action; False samples
            epsilon-greedily through the controller.
        update_step: Passed through into each metrics record.

    Returns:
        A list with one metrics dict per round played.
    """
    sampling = 'greedy' if on_policy else 'eps-greedy'
    collected = []

    obs = env.reset()
    round_number = 0
    while True:
        encoded = encode_observations(obs)  # TODO

        # Ask the controller for Q-values; the first round resets its RNN state.
        q_values = controller.get_q(encoded, first=round_number == 0)

        # Greedy choice when on policy, otherwise an explorative sample.
        actions = q_values.argmax(-1) if on_policy else controller.eps_greedy(q_values=q_values)

        prev_obs = obs
        obs, rewards, done, info = env.step(actions)

        # Store the observation the action was based on, not the resulting one.
        if replay_mem is not None:
            replay_mem.add(
                round_number=round_number, actions=actions, rewards=rewards,
                prev_obs=prev_obs)

        collected.append(dict(
            **info,
            rewards=rewards.mean().item(),
            q_min=q_values.min().item(),
            q_max=q_values.max().item(),
            q_mean=q_values.mean().item(),
            round_number=round_number,
            sampling=sampling,
            update_step=update_step,
        ))

        if done:
            return collected
        round_number += 1

# Main loop

In [None]:
# Configuration placeholders: fill these in before running the cell.
eval_period = ...
n_update_steps = ...
device = ...

device = th.device(device)
cpu = th.device('cpu')

env = ...

controller = ...

replay_mem = Memory(...)

metrics_list = []

# Per-round loss of the most recent update. Stays None until the replay
# memory holds enough episodes for a first training step.
loss = None

for update_step in range(n_update_steps):
    # Memory.start_batch() takes no arguments.
    replay_mem.start_batch()

    # here we sample one batch of episodes and add them to the replay buffer
    off_policy_metrics = run_batch(env, controller, replay_mem, on_policy=False, update_step=update_step)

    replay_mem.finish_batch()

    # allow controller to update itself
    sample = replay_mem.sample(device=device)

    # Only a dict is a usable sample; the memory signals "not enough data yet"
    # with a non-dict value.
    if isinstance(sample, dict):
        batch = dict(sample)
        # run_batch stores the tensors as 'actions'/'rewards', whereas
        # update() names its parameters 'action'/'reward'.
        loss = controller.update(
            update_step, action=batch.pop('actions'), reward=batch.pop('rewards'), **batch)

    if (update_step % eval_period) == 0:
        if loss is None:
            # No update has happened yet, so there is no loss to attach.
            metrics_list.extend(off_policy_metrics)
        else:
            metrics_list.extend([{**m, 'loss': l.item()} for m, l in zip(off_policy_metrics, loss)])
        # We run one batch on policy
        metrics_list.extend(
            run_batch(env, controller, replay_mem=None, on_policy=True, update_step=update_step))


model_file = ...
controller.save(model_file)

In [None]:
# save metrics file

import pandas as pd

# Columns that identify a measurement; every other column is a metric value.
id_vars = ['round_number', 'sampling', 'update_step']

df = pd.DataFrame.from_records(metrics_list)

# Keep the frame's own column order so the melted output is identical across
# runs (a set difference would order the columns arbitrarily).
value_vars = [col for col in df.columns if col not in id_vars]

# Long format: one row per (round_number, sampling, update_step, metric).
df = df.melt(id_vars=id_vars, value_vars=value_vars, var_name='metric')

metrics_file = ...

df.to_parquet(metrics_file)