## Initial Settings

In [None]:
!pip install gym > /dev/null 2>&1

In [None]:
# for the datasets
!pip install git+https://github.com/takuseno/d4rl-atari > /dev/null 2>&1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
import random
import math
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
import time
import gym
import torchvision.transforms as trans

In [None]:
#from google.colab import drive
#drive.mount("./drive")

import os

# Root folder for everything this notebook writes (checkpoints, TensorBoard runs).
BASE_PATH = "./Saves/"

# Experiment name: used as checkpoint file prefix and as TensorBoard run name.
TEST_NAME = "CQL-pacman"

# Create "<BASE_PATH>" and "<BASE_PATH>checkpoints/" in one call.
# exist_ok=True keeps re-running the cell harmless, while genuine failures
# (e.g. missing permissions) are reported here instead of being silently
# swallowed and resurfacing later as a failed torch.save.
os.makedirs(BASE_PATH + "checkpoints/", exist_ok=True)

CHECKPOINT_PATH = BASE_PATH + "checkpoints/" + TEST_NAME
TENSORBOARD_PATH = BASE_PATH + "runs/"

## Environment

In [None]:
import d4rl_atari

# Offline-RL environment: it provides the dataset used for training and is also
# stepped directly during the periodic evaluation in run().
# NOTE(review): the id is presumably registered by the d4rl_atari import above — verify.
env = gym.make('ms-pacman-expert-v0')
video_env = gym.make("MsPacman-v0")  # environment used for test and visualization

# clear dataset variable
# (rebinding the name first drops the reference to a previously loaded dataset
# before the new one is loaded, e.g. when this cell is re-run)
dataset = 0
# load dataset
# sampleDataset() below reads its 'observations', 'actions', 'rewards' and 'terminals' entries.
dataset = env.get_dataset()


def _shift_before_terminal(index, terminals, frame_stack, lowest):
    """Move `index` back until frames index-frame_stack+1 .. index hold no terminal flag."""
    moved = True
    while moved and index > lowest:
        moved = False
        for back in range(frame_stack):
            if terminals[index - back] == 1:
                # Jump to the frame just before the terminal one, then re-check
                # the new window from scratch.
                index -= back + 1
                moved = True
                break
    # Never go below the first index that still has a full frame history.
    return max(index, lowest)


def sampleDataset(batch_size, device):
    """Randomly sample a batch of stacked-frame transitions from the global `dataset`.

    Each state is built from 4 consecutive frames (newest first) and the
    next state from the same window shifted forward by one frame.

    As in the original design, windows containing a terminal frame are avoided:
    the sampled index is moved back to just before the terminal frame. This
    assumes every episode is at least 4 frames long.

    Params
    ======
        batch_size (int): number of transitions to sample
        device: torch device the returned tensors are moved to

    Returns
    =======
        tuple (states, actions, rewards, next_states, dones); states and
        next_states stack 4 frames along the first axis of every sample,
        the other three tensors have shape (batch_size, 1)
    """
    frame_stack = 4
    observations = dataset['observations']
    terminals = dataset['terminals']

    # Indices below frame_stack-1 have no complete history: negative offsets
    # would silently wrap around to the end of the dataset. The last frame is
    # excluded because it has no successor for next_states.
    lowest = frame_stack - 1
    indices = random.sample(range(lowest, len(observations) - 1), k=batch_size)

    # The terminal flags must be looked up at the *dataset* frames behind each
    # sampled index (index - j), not at other positions of the batch list.
    indices = [_shift_before_terminal(i, terminals, frame_stack, lowest) for i in indices]

    states = torch.from_numpy(np.stack([
        np.vstack([observations[i - k] for k in range(frame_stack)]) for i in indices
    ])).float().to(device)
    next_states = torch.from_numpy(np.stack([
        np.vstack([observations[i + 1 - k] for k in range(frame_stack)]) for i in indices
    ])).float().to(device)

    actions = torch.from_numpy(np.vstack([dataset['actions'][i] for i in indices])).long().to(device)
    rewards = torch.from_numpy(np.vstack([dataset['rewards'][i] for i in indices])).float().to(device)
    dones = torch.from_numpy(np.vstack([terminals[i] for i in indices]).astype(np.uint8)).float().to(device)

    return (states, actions, rewards, next_states, dones)

## Network

In [None]:
class QR_DQN(nn.Module):
    """Convolutional network that outputs N quantile estimates per action.

    Three Conv2d+ReLU stages are followed by two fully connected layers; the
    flat output is reshaped to (batch, N, action_size).

    Params
    ======
        state_size: stored as `input_shape` (not used to size the layers)
        action_size (int): number of discrete actions
        seed (int): value passed to torch.manual_seed before the layers are created
        N (int): number of quantiles per action
    """

    def __init__(self, state_size, action_size, seed, N):
        super().__init__()
        self.seed = torch.manual_seed(seed)
        self.input_shape = state_size
        self.action_size = action_size
        self.N = N

        # Feature extractor; the first convolution expects 4 input channels.
        self.layer1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU())

        # Head: 64*7*7 flattened features -> 512 hidden units -> N values per action.
        self.ff_1 = nn.Linear(64 * 7 * 7, 512)
        self.ff_2 = nn.Linear(512, action_size * N)

    def forward(self, x):
        """Return the quantile estimates, shape (batch, N, action_size).

        x is expected as (BATCH_SIZE, 4, 84, 84).
        """
        features = self.layer3(self.layer2(self.layer1(x)))
        hidden = F.relu(self.ff_1(features.view(-1, 64 * 7 * 7)))
        # NOTE(review): the ReLU on the output layer keeps every quantile
        # estimate non-negative — confirm this is intended.
        quantiles = F.relu(self.ff_2(hidden))
        return quantiles.view(hidden.shape[0], self.N, self.action_size)

    def get_action(self, input):
        """Return one value per action: the mean over the N quantiles, shape (batch, action_size)."""
        return self.forward(input).mean(dim=1)

## QR-DQN Agent

In [None]:
class QRDQN_Agent():
    """QR-DQN agent trained offline with the CQL(H) regulariser.

    The agent owns a local (trained) and a target QR_DQN network, samples
    batches through sampleDataset() and minimises the quantile Huber loss plus
    ALPHA times the conservative (log-sum-exp) term.
    """

    def __init__(self,
                 state_size,
                 action_size,
                 BATCH_SIZE,
                 N,
                 LR,
                 EPS_ADAM,
                 TAU,
                 GAMMA,
                 ALPHA,
                 UPDATE_TARGET_NETWORK_STEPS,
                 device,
                 seed,
                 action_repeat=1):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            BATCH_SIZE (int): size of the training batch
            N (int): number of heads (quantiles) of the training network
            LR (float): learning rate
            EPS_ADAM (float): epsilon used by the optimizer ADAM
            TAU (float): tau for soft updating the network weights
            GAMMA (float): discount factor
            ALPHA (float): weight of the log_sum_exp part of the loss
            UPDATE_TARGET_NETWORK_STEPS (int): steps between each update of the target network
            device (str): device that is used for the compute
            seed (int): random seed
            action_repeat (int): number of consecutive act() calls that reuse
                one decision. The default of 1 decides on every call, which is
                what the previous implementation effectively did (its repeat
                counter was reset in unreachable code).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.BATCH_SIZE = BATCH_SIZE
        self.N = N
        self.TAU = TAU
        self.GAMMA = GAMMA
        self.ALPHA = ALPHA
        self.update_target_network_count = 0  # learn() calls since the last target update
        self.UPDATE_TARGET_NETWORK_STEPS = UPDATE_TARGET_NETWORK_STEPS
        random.seed(seed)
        self.seed = seed  # random.seed() returns None, so keep the value itself
        self.device = device

        # create the tensor containing the quantiles values: i/N for i = 1..N, shape (N,)
        self.quantile_tau = torch.FloatTensor([i/self.N for i in range(1, self.N+1)]).to(device)

        # initialize the action retainer variables
        self.action_repeat = max(1, int(action_repeat))
        self.action_step = self.action_repeat  # forces a fresh decision on the first act()
        self.last_action = None

        # Q-Networks
        self.qnetwork_local = QR_DQN(state_size, action_size, seed, self.N).to(device)  # main network
        self.qnetwork_target = QR_DQN(state_size, action_size, seed, self.N).to(device) # auxiliary network

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR, eps=EPS_ADAM)
        print(self.qnetwork_local)

    def step(self, writer, frame):
        """Run one training step and log its loss.

        Params
        ======
            writer: object with an add_scalar(tag, value, step) method
            frame (int): training step used as x-axis of the "Q_loss" scalar
        """
        # sample the dataset for BATCH_SIZE transitions
        experiences = sampleDataset(self.BATCH_SIZE, self.device)
        # learn from this batch and calculate the loss
        loss = self.learn(experiences)
        writer.add_scalar("Q_loss", loss, frame)

    def act(self, state, eps=0.):
        """Return an action for the given state (epsilon-greedy on the local network).

        A new decision is taken every `action_repeat` calls; the calls in
        between return the previously chosen action.

        Params
        ======
            state (torch.Tensor): current stacked state, without batch dimension
            eps (float): probability of picking a uniformly random action
        """
        if self.last_action is not None and self.action_step < self.action_repeat:
            self.action_step += 1
            return self.last_action

        state = state.float().unsqueeze(0).to(self.device)
        # Evaluate in eval mode, then restore whatever mode the network was in
        # (so calling act() during evaluation does not flip it back to train).
        was_training = self.qnetwork_local.training
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local.get_action(state)
        self.qnetwork_local.train(was_training)

        # Epsilon-greedy action selection
        if random.random() > eps:
            action = np.argmax(action_values.cpu().data.numpy())
        else:
            action = random.choice(np.arange(self.action_size))

        self.last_action = action
        self.action_step = 1  # this call is the first use of the new decision
        return action

    def learn(self, experiences):
        """Update value parameters using given batch of experience tuples.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tensors

        Returns
        =======
            the total loss (quantile Huber loss + ALPHA * CQL term) as a numpy value
        """
        states, actions, rewards, next_states, dones = experiences
        batch_size = states.shape[0]

        # ---- targets: no gradient flows through the target network ----
        with torch.no_grad():
            next_quantiles = self.qnetwork_target(next_states)  # (BATCH_SIZE, N, action_size)
            # greedy action of the target network w.r.t. the mean over quantiles
            greedy = next_quantiles.mean(dim=1).argmax(dim=1, keepdim=True)   # (BATCH_SIZE, 1)
            greedy = greedy.unsqueeze(-1).expand(batch_size, self.N, 1)       # (BATCH_SIZE, N, 1)
            Q_targets_next = next_quantiles.gather(2, greedy).transpose(1, 2)  # (BATCH_SIZE, 1, N)
            Q_targets = rewards.unsqueeze(-1) + (self.GAMMA * Q_targets_next * (1 - dones.unsqueeze(-1)))  # (BATCH_SIZE, 1, N)

        # ---- predictions of the local network for the taken actions ----
        Q_expected_actions = self.qnetwork_local(states)  # (BATCH_SIZE, N, action_size)
        Q_expected = Q_expected_actions.gather(2, actions.unsqueeze(-1).expand(batch_size, self.N, 1))  # (BATCH_SIZE, N, 1)

        # td_error[b, i, j] = target quantile j - predicted quantile i
        td_error = Q_targets - Q_expected
        assert td_error.shape == (batch_size, self.N, self.N), "wrong td error shape"
        huber_l = calculate_huber_loss(td_error, 1.0)
        # The quantile level belongs to the *predicted* quantile (axis 1). A
        # plain (N,) tensor would broadcast over the target axis instead and
        # give every predicted quantile the same objective.
        tau = self.quantile_tau.view(1, self.N, 1)
        quantil_l = torch.abs(tau - (td_error.detach() < 0).float()) * huber_l / 1.0

        # mean over target samples, sum over predicted quantiles, mean over the batch
        loss = quantil_l.mean(dim=2).sum(dim=1).mean()

        # ---- CQL extension ----
        q_values = Q_expected_actions.mean(dim=1)  # (BATCH_SIZE, action_size)
        log_sum_exp = torch.logsumexp(q_values, dim=1)  # (BATCH_SIZE)
        alpha_term = (log_sum_exp - q_values.gather(1, actions).squeeze(1)).mean()

        # CQL(H)
        loss = self.ALPHA * alpha_term + loss

        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.update_target_network_count += 1
        if self.update_target_network_count >= self.UPDATE_TARGET_NETWORK_STEPS:
            self.soft_update(self.qnetwork_local, self.qnetwork_target)
            # Restart the count; without this the target would be refreshed only once.
            self.update_target_network_count = 0
        # ------------------------------------------------------------- #

        return loss.detach().cpu().numpy()

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.TAU*local_param.data + (1.0-self.TAU)*target_param.data)


def calculate_huber_loss(td_errors, k=1.0):
    """
    Calculate huber loss element-wisely depending on k:
    0.5 * e^2 where |e| <= k, and k * (|e| - 0.5 * k) elsewhere.
    """
    loss = torch.where(td_errors.abs() <= k, 0.5 * td_errors.pow(2), k * (td_errors.abs() - 0.5 * k))
    return loss


## Run declaration

In [None]:
def _checkpoint_payload(frame):
    """Collect everything needed to resume training after `frame`."""
    return {
        "frame": frame,
        "local_network": agent.qnetwork_local.state_dict(),
        "target_network": agent.qnetwork_target.state_dict(),
        "optim": agent.optimizer.state_dict()
    }


def _evaluate_agent():
    """Play one greedy episode on `env` and return its total reward."""
    state = env.reset()
    done = False
    score = 0
    count = 0

    prev_state = state.copy()
    last_max_states = deque(maxlen=4)  # the first is the newest
    for _ in range(3):
        last_max_states.appendleft(state.copy())

    while not done:
        # element-wise maximum of the two most recent frames
        last_max_states.appendleft(np.maximum(state, prev_state))

        states_quadr = torch.stack([torch.from_numpy(x) for x in last_max_states])
        action = agent.act(states_quadr, 0)
        next_state, reward, done, _ = env.step(action)

        prev_state = state
        state = next_state
        score += reward

        count += 1
        # give up on an episode that has not scored after 1000 steps
        if count == 1000 and score == 0:
            break

    return score


def run(frames=10000, save_every=10000, load_from_checkpoint=False):
    """Quantile Regression - Deep Q-Learning (offline training loop).

    Uses the module-level `agent`, `writer`, `env` and CHECKPOINT_PATH.

    Params
    ======
        frames (int): maximum number of training frames
        save_every (int): number of frame between each save (an evaluation
            episode is played after every save)
        load_from_checkpoint (bool): resume from CHECKPOINT_PATH + ".tar"
    """

    start_frame = 1

    # if we want to continue the training, load the old checkpoint
    if load_from_checkpoint:
        # map_location lets a checkpoint written on another device be loaded here
        checkpoint = torch.load(CHECKPOINT_PATH + ".tar", map_location=agent.device)

        # the checkpoint is written *after* training on its frame, so resume
        # with the following one instead of repeating it
        start_frame = checkpoint["frame"] + 1
        agent.qnetwork_local.load_state_dict(checkpoint["local_network"])
        agent.qnetwork_target.load_state_dict(checkpoint["target_network"])
        agent.optimizer.load_state_dict(checkpoint["optim"])

    for frame in range(start_frame, frames+1):
        agent.step(writer, frame)

        print("\rFrame {} ".format(frame), end="")

        # every save_every frames we save the state_dict in a file
        if frame % save_every != 0:
            continue

        # one numbered snapshot plus the "latest" file used for resuming
        payload = _checkpoint_payload(frame)
        torch.save(payload, CHECKPOINT_PATH + str(frame) + ".tar")
        torch.save(payload, CHECKPOINT_PATH + ".tar")

        # after every save, evaluate the agent on one episode
        agent.qnetwork_local.eval()
        score = _evaluate_agent()

        writer.add_scalar("Eval Score", score, frame)
        print("\rFrame {} \tScore: {}".format(frame, score))

        agent.qnetwork_local.train()

## Initialization and random seed settings

In [None]:
# TensorBoard writer; scalars are logged under <TENSORBOARD_PATH><TEST_NAME>.
writer = SummaryWriter(TENSORBOARD_PATH + TEST_NAME)
seed = 1
BATCH_SIZE = 32
N = 200  # number of quantiles estimated per action
GAMMA = 0.99
TAU = 1   # update the target network by copying the local network
LR = 5e-5
EPS_ADAM = 0.01/32
ALPHA = 1.0  # weight of the CQL (log_sum_exp) term in the loss
UPDATE_TARGET_NETWORK_STEPS = 2000

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using ", device)


# Only NumPy is seeded here; `random` and torch are seeded inside the
# QRDQN_Agent and QR_DQN constructors.
np.random.seed(seed)

print('State space: {}'.format(env.observation_space.shape))
print('Action space: {}'.format(env.action_space.n))

env.seed(seed)
action_size = env.action_space.n
state_size = env.observation_space.shape

# The agent (and its two networks) is created after all seeds are set.
agent = QRDQN_Agent(state_size=state_size,    
                  action_size=action_size,
                  BATCH_SIZE=BATCH_SIZE,
                  N=N, 
                  LR=LR, 
                  EPS_ADAM=EPS_ADAM, 
                  TAU=TAU, 
                  GAMMA=GAMMA,
                  ALPHA=ALPHA,
                  UPDATE_TARGET_NETWORK_STEPS=UPDATE_TARGET_NETWORK_STEPS,
                  device=device, 
                  seed=seed)


## Train

In [None]:
# Total number of training steps (10 million) and checkpoint interval
# (1% of the run, i.e. 100 checkpoints/evaluations overall).
TOTAL_FRAMES = int(10e6)
SAVE_EVERY = TOTAL_FRAMES//100

# Time the whole training run.
t0 = time.time()
run(frames= TOTAL_FRAMES, save_every= SAVE_EVERY, load_from_checkpoint=False)
t1 = time.time()

print("\nTraining time: {}min".format(round((t1-t0)/60,2)))

# save the checkpoint
# (same layout and file name as the "latest" checkpoint written inside run())
torch.save({
    "frame": TOTAL_FRAMES,
    "local_network": agent.qnetwork_local.state_dict(),
    "target_network": agent.qnetwork_target.state_dict(),
    "optim": agent.optimizer.state_dict()
}, CHECKPOINT_PATH + ".tar")

## Test

In [None]:
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install x11-utils > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

In [None]:
from gym import logger as gymlogger
from gym.wrappers import Monitor
gymlogger.set_level(40) #error only
import glob
import io
import os
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
import time

from pyvirtualdisplay import Display
# Start a non-visible (visible=0) virtual display of 1400x900.
# NOTE(review): presumably required so video_env.render() works on a machine
# without a screen — confirm.
display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of a gym environment and to display the recording.
To enable video, just do `env = wrap_env(env)`.
"""
def show_video():
    """Display the most recently modified recording found under videos/*/.

    The newest ``.mp4`` file (by modification time) is embedded as a
    base64-encoded HTML video in the notebook output. If no recording
    exists, "Could not find video" is printed instead.
    """
    mp4list = glob.glob('videos/*/*.mp4')
    if not mp4list:
        print("Could not find video")
        return

    mp4list.sort(key=os.path.getmtime)
    mp4 = mp4list[-1]
    # read through a context manager so the file handle is always closed
    with open(mp4, 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                    loop controls style="height: 400px;">
                    <source src="data:video/mp4;base64,{0}" type="video/mp4" />
                </video>'''.format(encoded.decode('ascii'))))
    

def wrap_env(env):
    """Wrap `env` in a Monitor that records into a fresh, timestamped folder under ./videos/."""
    # Monitor objects are used to save interactions as videos
    video_dir = './videos/' + str(time.time()) + '/'
    return Monitor(env, video_dir)

In [None]:
def preprocessing(imageRGB):
    """Convert an RGB frame (height, width, 3) to an 84x84 grey-scale tensor.

    The grey value is the weighted channel sum
    0.2126 * R + 0.7152 * G + 0.0722 * B, which is then resized to 84x84.
    """
    rgb = torch.Tensor(imageRGB)
    red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue

    # Resize receives a leading channel axis, which is removed again afterwards.
    resize = trans.Resize((84, 84))
    return resize(luminance.unsqueeze(0)).squeeze(0)

In [None]:
# Play 5 recorded test episodes with the greedy policy (eps = 0) and show each video.
agent.qnetwork_local.eval()

for e in range(5):
  # NOTE(review): video_env is rebound to a new Monitor wrapper on every
  # iteration, so the wrappers nest across episodes — confirm this is intended.
  video_env = wrap_env(video_env)
  state = video_env.reset()
  done = False
  score = 0

  # Frame history: the deque starts with 3 copies of the first frame and gets
  # its 4th entry at the start of the first loop iteration below.
  prev_state = state.copy()
  last_max_states = deque(maxlen=4) # the first is the newest
  for i in range(3):
    last_max_states.appendleft(state.copy())

  count = 0
  while not done:
    video_env.render()

    # element-wise maximum of the two most recent raw frames
    last_max_states.appendleft(np.maximum(state, prev_state))

    # each raw RGB frame is converted to an 84x84 grey-scale tensor before stacking
    states_quadr = torch.stack([preprocessing(x) for x in last_max_states])
    action = agent.act(states_quadr, 0)

    # The very first action of every episode is forced to 1.
    # NOTE(review): presumably needed to get the game going — verify what
    # action 1 means in this environment.
    if count == 0:
      action = 1
      count += 1

    next_state, reward, done, _ = video_env.step(action)

    prev_state = state
    state = next_state
    score += reward

  video_env.close()
  show_video()

  print('Test episode {} - R(tau) = {}'.format(e, score))