<a href="https://colab.research.google.com/github/EVSoaress/rl_studies/blob/main/rl_normalized_advantage_function.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## DQN for continuous action spaces: Normalized Advantage Function (NAF)

In [None]:
# System dependency: Xvfb, the X virtual framebuffer used by the headless
# display started in the next cell.
!apt-get install -y xvfb

# Python dependencies: gym pinned to 0.22 plus its optional `box2d` extra,
# PyTorch Lightning pinned to 1.6.0, and pyvirtualdisplay (unpinned).
!pip install \
    gym==0.22 \
    gym[box2d] \
    pytorch-lightning==1.6.0 \
    pyvirtualdisplay

In [None]:
from pyvirtualdisplay import Display
# Launch an invisible (headless) 1400x900 virtual display so environments can
# render frames on a machine without a screen (e.g. Colab).
Display(size=(1400, 900), visible=False).start()

In [None]:
import copy
import gym
import torch

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

from gym.wrappers import RecordVideo, RecordEpisodeStatistics, TimeLimit


# Number of visible CUDA devices, and the device string used for tensors:
# the first GPU when CUDA is available, otherwise the CPU.
num_gpus = torch.cuda.device_count()
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'




In [None]:
def display_video(episode=0, video_dir='/content/videos'):
  """Embed a recorded episode video as an inline HTML5 player.

  Args:
    episode: episode index used in the file name
      ``rl-video-episode-{episode}.mp4``.
    video_dir: directory that holds the recordings; defaults to the
      original hard-coded Colab location.

  Returns:
    An ``HTML`` object containing a ``<video>`` element whose source is the
    file's bytes inlined as a base64 data URL.
  """
  video_path = f'{video_dir}/rl-video-episode-{episode}.mp4'
  # Read-only binary mode (the file is never written), and the context
  # manager guarantees the handle is closed instead of leaking it.
  with open(video_path, 'rb') as video_file:
    video_bytes = video_file.read()
  video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
  return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

### Create the Deep Q-Network

In [None]:
class NafDQN(nn.Module):
  """Q-network for the Normalized Advantage Function (NAF) algorithm.

  Q(s, a) = V(s) + A(s, a), with the quadratic advantage

      A(s, a) = -1/2 * (a - mu(s))^T P(s) (a - mu(s)),   P(s) = L(s) L(s)^T

  where L(s) is lower-triangular with a strictly positive (exponentiated)
  diagonal. P(s) is therefore positive definite, A(s, a) <= 0, and mu(s) is
  the action with the highest Q-value.

  Args:
    hidden_size: width of the two hidden layers of the shared trunk.
    obs_size: number of observation features.
    action_dims: number of action dimensions.
    max_action: per-dimension bound that scales the tanh output of the mu
      head; anything accepted by ``torch.as_tensor`` (e.g. a numpy array).
  """

  def __init__(self, hidden_size, obs_size, action_dims, max_action):
    super().__init__()
    self.action_dims = action_dims
    # Non-persistent buffer: moved by ``model.to(...)`` together with the
    # parameters (no dependency on a global device), but not saved in the
    # state_dict.
    self.register_buffer(
        'max_action',
        torch.as_tensor(max_action, dtype=torch.float32),
        persistent=False)
    self.net = nn.Sequential(
        nn.Linear(obs_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU()
    )
    self.linear_mu = nn.Linear(hidden_size, action_dims)
    self.linear_value = nn.Linear(hidden_size, 1)
    # One output per entry of the lower triangle of L, diagonal included.
    self.linear_matrix = nn.Linear(hidden_size,
                                   action_dims * (action_dims + 1) // 2)

  def _scaled_mu(self, features):
    """Map trunk features to the greedy action, bounded by ``max_action``."""
    return torch.tanh(self.linear_mu(features)) * self.max_action

  @torch.no_grad()
  def mu(self, x):
    """Return the action with the highest Q-value for observations ``x``."""
    return self._scaled_mu(self.net(x))

  @torch.no_grad()
  def value(self, x):
    """Return the state value V(x) (a single output feature per input)."""
    return self.linear_value(self.net(x))

  def forward(self, x, a):
    """Return Q(x, a).

    ``x`` is expected as (batch, obs_size) and ``a`` as (batch, action_dims);
    the result has shape (batch, 1), matching ``value``.
    """
    features = self.net(x)
    mu = self._scaled_mu(features)
    value = self.linear_value(features)

    # Assemble the lower-triangular L(x) on the same device/dtype as the input.
    entries = torch.tanh(self.linear_matrix(features))
    L = torch.zeros(features.shape[0], self.action_dims, self.action_dims,
                    dtype=entries.dtype, device=entries.device)
    rows, cols = torch.tril_indices(row=self.action_dims,
                                    col=self.action_dims,
                                    device=entries.device)
    L[:, rows, cols] = entries
    # Exponentiate the diagonal (out of place) so it is strictly positive,
    # which makes P = L L^T positive definite.
    diagonal = torch.diagonal(L, dim1=1, dim2=2)
    L = L - torch.diag_embed(diagonal) + torch.diag_embed(diagonal.exp())
    # Matrix product, not element-wise: P = L @ L^T.
    P = L @ L.transpose(1, 2)

    u_mu = (a - mu).unsqueeze(dim=1)   # (batch, 1, action_dims)
    u_mu_t = u_mu.transpose(1, 2)      # (batch, action_dims, 1)
    adv = -0.5 * (u_mu @ P @ u_mu_t)   # (batch, 1, 1)
    # Drop the trailing singleton so the sum with value is (batch, 1).
    return adv.squeeze(-1) + value