# Distributional Reinforcement Learning with Quantile Regression

<img src="https://ars-ashuha.github.io/images/QR-Net1.png" width="1200">
<img src="https://ars-ashuha.github.io/images/QR-Net2.png" width="1200">

- Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/pdf/1710.10044.pdf 
- If you get stuck, you can peek at a reference solution: https://github.com/ars-ashuha/quantile-regression-dqn-pytorch

# Implementation

In [None]:
import gym
import torch
import pickle
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim

from logger import Logger
from rl_utils import ReplayMemory, huber

In [None]:
class Network(nn.Module):
    """Quantile-regression Q-network (exercise template).

    For every action the network is expected to output ``num_quant``
    quantile estimates of the return distribution Z(s, a), i.e. a tensor of
    shape ``batch_size x num_actions x num_quant``.

    Args:
        len_state: presumably the length of the state vector fed to the
            network (it is not used by the template code itself).
        num_quant: number of quantiles predicted per action.
        num_actions: number of discrete actions.
    """

    def __init__(self, len_state, num_quant, num_actions):
        super().__init__()

        self.num_quant = num_quant
        self.num_actions = num_actions

        ###########################################################
        ########         Your code should be here        ##########
        # Define your model here (e.g. self.layer1 = ...), it is ok to use
        # just two layers and tanh nonlinearity, do not forget that
        # shape of the output should be
        # batch_size x self.num_actions x self.num_quant
        raise NotImplementedError('Define the layers of Network in __init__')
        ###########################################################

    def forward(self, x):
        """Return the quantile estimates for a batch of states ``x``."""
        ###########################################################
        ########         Your code should be here        ##########
        # Compute the output of the network and return it as a
        # tensor of shape batch_size x self.num_actions x self.num_quant
        raise NotImplementedError('Compute the network output in forward')
        ###########################################################

    def select_action(self, state, eps):
        """Choose an action epsilon-greedily.

        Args:
            state: current state; anything that is not a ``torch.Tensor`` is
                wrapped into a tensor holding a batch of one.
            eps: a uniformly random action is kept when
                ``random.random() <= eps``.

        Returns:
            The chosen action as a Python ``int``.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.Tensor([state])

        # Exploration default: sample over *all* actions. The upper bound is
        # exclusive, so it has to be num_actions (a hard-coded 2 would never
        # pick the third action of a three-action environment).
        action = torch.randint(0, self.num_actions, (1,))
        if random.random() > eps:
            ###########################################################
            ########         Your code should be here        ##########
            # Select the greedy action wrt Q(s, a) = E(Z(s, a))
            # and assign it to `action`
            raise NotImplementedError('Select the greedy action in select_action')
            ###########################################################
        return int(action)

In [None]:
# Exploration schedule: a random action is taken with probability eps(steps),
# which decays exponentially from eps_start towards eps_end; eps_dec sets
# how many steps one e-fold of the decay takes.
eps_start, eps_end, eps_dec = 0.9, 0.1, 500


def eps(steps):
    """Return the exploration probability after ``steps`` environment steps."""
    return eps_end + (eps_start - eps_end) * np.exp(-1. * steps / eps_dec)

In [None]:
# Environment under study: CartPole-v0 first,
# MountainCar-v0 afterwards (change env_name to switch).
env_name = 'CartPole-v0'
env = gym.make(env_name)

# Replay memory built with the argument 10000 (presumably its capacity --
# see rl_utils) and the training logger (the fmt entry presumably controls
# how the 'loss' column is printed -- see logger).
memory = ReplayMemory(10000)
logger = Logger('q-net', fmt=dict(loss='.5f'))

In [None]:
# Exercise: construct the two networks below.
# NOTE: the training loop copies Z's weights into Ztgt with
# Ztgt.load_state_dict(Z.state_dict()), so both need the same architecture.
Z = # Define Z, the approximation network
Ztgt = # Define Ztgt, the target network
# Only Z's parameters are optimized (Adam, learning rate 1e-3).
optimizer = optim.Adam(Z.parameters(), 1e-3)

In [None]:
# Quantile midpoints tau_i = (2i + 1) / (2N) for i = 0..N-1 with
# N = Z.num_quant, stored as a single row of shape 1 x N.
tau = torch.Tensor(
    (np.arange(Z.num_quant) * 2 + 1) / (Z.num_quant * 2.0)
).reshape(1, -1)

## Training cycle

In [None]:
# Training loop: act epsilon-greedily, store every transition in the replay
# memory and, once a full batch is available, take one optimizer step per
# environment step on the quantile Huber loss.
gamma, batch_size = 0.99, 32
steps_done, running_reward = 0, None
loss = None  # stays None until the first gradient step has been taken

for episode in range(501):
    sum_reward = 0
    state = env.reset()
    while True:
        steps_done += 1

        action = Z.select_action(torch.Tensor([state]), eps(steps_done))
        next_state, reward, done, _ = env.step(action)
        memory.push(state, action, next_state, reward, float(done))
        sum_reward += reward
        state = next_state

        # Learn only when the memory can supply a full batch. While it is
        # still filling up we simply keep acting, so warm-up episodes run to
        # their natural end instead of being abandoned after a single step.
        if len(memory) >= batch_size:
            ###########################################################
            ########         Your code should be here        ##########
            # Sample transitions from Replay Memory
            states, actions, rewards, next_states, dones = ...
            ###########################################################

            ###########################################################
            ########         Your code should be here        ##########
            # Calculate quantiles theta for current state and actions
            theta = ...
            # Calculate quantiles for the next state with the target network
            # and then take the values for the max action
            Znext_max = ...
            Ttheta = rewards + gamma * (1 - dones) * Znext_max
            # Calculate loss, use this trick to compute pairwise differences
            # Trick: Tensor of shape (3,2,1) minus Tensor of shape (1,2,3)
            # is a Tensor of shape (3, 2, 3) with all pairwise differences :)
            # Use the elementwise Huber function to compute the Huber loss
            diff = Ttheta.t().unsqueeze(-1) - theta
            loss = torch.mean(huber(diff) * (tau - (diff.detach() < 0).float()).abs())
            ###########################################################

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Refresh the target network every 100 environment steps.
        if steps_done % 100 == 0:
            Ztgt.load_state_dict(Z.state_dict())

        if done:
            # Report every 50th episode, but only once there is something
            # to report (a loss and a previous running reward).
            if episode % 50 == 0 and loss is not None and running_reward is not None:
                logger.add(episode, steps=steps_done, running_reward=running_reward, loss=loss.item())
                logger.iter_info()

            # Exponential moving average of the episode return. Compare with
            # None explicitly so that a running reward of exactly 0 is not
            # mistaken for "not initialised yet".
            if running_reward is None:
                running_reward = sum_reward
            else:
                running_reward = 0.2 * sum_reward + running_reward * 0.8
            break

# Visualization

## Train the model for the MountainCar-v0 environment here!

In [None]:
import time
import matplotlib.pyplot as plt
from IPython import display

%matplotlib inline
import seaborn as sns
# Use seaborn's 'whitegrid' style for the plots below.
sns.set_style('whitegrid')

from matplotlib import rcParams
# Default figure size (width, height) and resolution for every new figure.
rcParams['figure.figsize'] = 7, 2
rcParams['figure.dpi'] = 150

In [None]:
# Human-readable action labels per environment name; the position in each
# list is the action id, and the labels end up in the plot legend.
actions = {
    'CartPole-v0': ['Left', 'Right'],
    'MountainCar-v0': ['Left', 'Non', 'Right'],
}

In [None]:
def get_plot(q):
    """Build the coordinates of a step curve (empirical CDF) from values ``q``.

    Every value in ``q`` contributes a near-vertical jump of height
    ``1 / len(q)``; the curve is extended to the left of the first value and
    to the right of the last one by 20% of that value's magnitude.

    Args:
        q: non-empty indexable sequence of numbers (an empty one raises
            ``IndexError``).

    Returns:
        A tuple ``(x, y)`` of two equally long lists of plot coordinates.
    """
    first, last = q[0], q[-1]
    gap = 1e-8           # tiny x-offset that makes each jump look vertical
    step = 1 / len(q)    # height gained at every value

    xs = [first - np.abs(first * 0.2)]
    ys = [0]
    level = 0
    for value in q:
        xs.extend((value - gap, value))
        ys.extend((level, level + step))
        level += step

    # Flat tail at height 1 to the right of the last value.
    xs.extend((last + gap, last + np.abs(last * 0.2)))
    ys.extend((1.0, 1.0))
    return xs, ys

In [None]:
# Roll out one episode with the trained network and, every third step, show
# the rendered frame next to the step curves built from Z's output for each
# action.
state, done, steps = env.reset(), False, 0
while True:
    plt.clf()
    steps += 1
    action = Z.select_action(torch.Tensor([state]), eps(steps_done))
    state, reward, done, _ = env.step(action)

    # Leave as soon as the episode is over. The check must not depend on
    # whether this step happens to be one that is drawn, otherwise the loop
    # keeps stepping an environment that has already finished.
    if done:
        break

    if steps % 3 == 0:
        plt.subplot(1, 2, 1)
        plt.title('step = %s' % steps)
        plt.imshow(env.render(mode='rgb_array'))
        plt.axis('off')

        plt.subplot(1, 2, 2)
        Zval = Z(torch.Tensor([state])).detach().numpy()
        for i in range(env.action_space.n):
            x, y = get_plot(Zval[0][i])
            plt.plot(x, y, label='%s Q=%.1f' % (actions[env_name][i], Zval[0][i].mean()))
        # One legend for all curves, created after they have been drawn.
        plt.legend(bbox_to_anchor=(1.1, 1.1), ncol=3, prop={'size': 6})

        display.clear_output(wait=True)
        display.display(plt.gcf())
plt.clf()