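## Initial Settings

Offline reinforcement learning on the d4rl-pybullet `ant-bullet-medium-v0` dataset with Soft Actor-Critic plus a Conservative Q-Learning (CQL, Lagrange variant) critic penalty. This section installs dependencies, imports libraries and prepares the output folders.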

In [None]:
# Install the offline-RL dataset package and the PyBullet simulator (shell escapes, output silenced).
# NOTE(review): versions are unpinned, so a future re-run may pick up incompatible releases.
!pip install git+https://github.com/takuseno/d4rl-pybullet > /dev/null 2>&1
!pip install pybullet > /dev/null 2>&1

In [None]:
import random
from typing import Tuple

import gym
import pybullet
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter

import d4rl_pybullet  # imported for its side effect: registers the d4rl-pybullet gym environments

In [None]:
# Uncomment when running on Colab with outputs stored on Google Drive:
# from google.colab import drive
# drive.mount("./drive")

import os

# All training artifacts (checkpoints, TensorBoard runs) live under BASE_PATH.
BASE_PATH = "./Saves/"
TEST_NAME = "SAC-ant-pybullet"

CHECKPOINT_PATH = BASE_PATH + "checkpoints/" + TEST_NAME
TENSORBOARD_PATH = BASE_PATH + "runs/"

# exist_ok keeps this cell idempotent without swallowing real failures
# (e.g. permission errors) the way a bare `except: pass` would.
os.makedirs(BASE_PATH + "checkpoints/", exist_ok=True)

## Set random seed

In [None]:
# Reproducibility: force deterministic cuDNN kernels and seed every RNG the notebook uses.
seed = 1

if torch.backends.cudnn.enabled:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed torch, NumPy and Python's `random` (in that order) with the same value.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

## Network

In [None]:
def init_layer_uniform(layer: nn.Linear, init_w: float = 3e-3) -> nn.Linear:
    """Re-initialize `layer` in place with U(-init_w, init_w) weights and biases.

    Returns the same layer object so the call can be used inline.
    """
    # weight first, then bias: same RNG draw order as a manual .uniform_() on each
    for param in (layer.weight, layer.bias):
        nn.init.uniform_(param, -init_w, init_w)
    return layer


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy for SAC.

    Two 256-unit ReLU hidden layers feed two heads: a mean head (squashed by
    tanh) and a log-std head mapped into [log_std_min, log_std_max].
    ``forward`` draws a reparameterized sample z ~ N(mu, std) and returns
    action = tanh(z) together with the log-probability of that action
    (change-of-variables corrected, summed over the action dimensions).
    """

    # log(2) as a plain Python float, used by the stable tanh log-Jacobian.
    _LOG_2 = float(np.log(2.0))

    def __init__(
        self, 
        in_dim: int, 
        out_dim: int,
        log_std_min: float = -20,
        log_std_max: float = 2,
    ):
        """Initialize.

        Args:
            in_dim: size of the state vector.
            out_dim: size of the action vector.
            log_std_min: lower bound the log-std head is mapped onto.
            log_std_max: upper bound the log-std head is mapped onto.
        """
        super(Actor, self).__init__()
        
        # set the log std range
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        
        # set the hidden layers
        self.hidden1 = nn.Linear(in_dim, 256)
        self.hidden2 = nn.Linear(256, 256)
        
        # set log_std layer (small uniform init)
        self.log_std_layer = nn.Linear(256, out_dim)
        self.log_std_layer = init_layer_uniform(self.log_std_layer)

        # set mean layer (small uniform init)
        self.mu_layer = nn.Linear(256, out_dim)
        self.mu_layer = init_layer_uniform(self.mu_layer)

    @staticmethod
    def _tanh_log_abs_det_jacobian(z: torch.Tensor) -> torch.Tensor:
        """Return log(1 - tanh(z)^2) element-wise, computed in a numerically stable way.

        Uses log(1 - tanh(z)^2) = 2 * (log 2 - z - softplus(-2z)), which stays
        accurate when tanh(z) rounds to +-1 in float32 (|z| roughly > 9), where
        log(1 - tanh(z)^2 + eps) would be clipped at log(eps).
        """
        return 2.0 * (Actor._LOG_2 - z - F.softplus(-2.0 * z))

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample an action for `state`.

        Returns:
            action: tanh-squashed sample in (-1, 1); last dimension is out_dim.
            log_prob: log-probability of `action`, summed over the last
                dimension with keepdim=True.
        """
        x = F.relu(self.hidden1(state))
        x = F.relu(self.hidden2(x))
        
        # get mean, bounded to (-1, 1)
        mu = self.mu_layer(x).tanh()
        
        # get std: tanh output in (-1, 1) is mapped affinely onto (log_std_min, log_std_max)
        log_std = self.log_std_layer(x).tanh()
        log_std = self.log_std_min + 0.5 * (
            self.log_std_max - self.log_std_min
        ) * (log_std + 1)
        std = torch.exp(log_std)
        
        # reparameterized sample so gradients flow through z into the network
        dist = Normal(mu, std)
        z = dist.rsample()
        
        # squash and apply the tanh change-of-variables correction
        action = z.tanh()
        log_prob = dist.log_prob(z) - self._tanh_log_abs_det_jacobian(z)
        log_prob = log_prob.sum(-1, keepdim=True)
        
        return action, log_prob
    
    
class CriticQ(nn.Module):
    """State-action value network Q(s, a): two 256-unit ReLU hidden layers and a scalar head."""

    def __init__(self, in_dim: int):
        """Build the MLP; `in_dim` must equal state_dim + action_dim."""
        super().__init__()

        self.hidden1 = nn.Linear(in_dim, 256)
        self.hidden2 = nn.Linear(256, 256)
        # small uniform init on the output head keeps initial Q estimates near zero
        self.out = init_layer_uniform(nn.Linear(256, 1))

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """Return Q-values (trailing dimension of size 1) for the concatenated (state, action)."""
        features = torch.cat((state, action), dim=-1)
        for layer in (self.hidden1, self.hidden2):
            features = F.relu(layer(features))
        return self.out(features)

## SAC Agent with CQL penalty

In [None]:
class SACAgent:
    """Offline SAC agent with a Conservative Q-Learning (CQL, Lagrange variant) critic penalty.
    
    Attributes:
        actor (nn.Module): actor model to select actions
        actor_optimizer (Optimizer): optimizer for training actor
        qf_1 (nn.Module): critic model to predict state-action values
        qf_2 (nn.Module): critic model to predict state-action values
        qf_1_target (nn.Module): target critic model to predict state-action values
        qf_2_target (nn.Module): target critic model to predict state-action values
        qf_1_optimizer (Optimizer): optimizer for training qf_1
        qf_2_optimizer (Optimizer): optimizer for training qf_2
        state_dim (int): dimension of the state space
        action_dim (int): dimension of the action space
        dataset (dict of numpy): offline dataset passed at construction.
            NOTE(review): training batches are drawn through the module-level
            ``sample_batch`` helper, not from this attribute.
        batch_size (int): batch size for sampling
        gamma (float): discount factor
        tau (float): parameter for soft target update
        policy_update_freq (int): the target critics are soft-updated every this many steps
        random_actions_num (int): uniform and policy actions sampled per state for the CQL term
        device (torch.device): cpu / gpu
        target_entropy (int): desired entropy (-action_dim) for the entropy constraint
        log_alpha (torch.Tensor): log of the entropy weight
        alpha_optimizer (Optimizer): optimizer for log_alpha
        log_alpha_cql (torch.Tensor): log of the Lagrange multiplier of the conservative term
        alpha_cql_optimizer (Optimizer): optimizer for log_alpha_cql
        cql_threshold (float): target gap for the CQL penalty; alpha_cql grows while
            the penalty exceeds it and shrinks otherwise
        total_step (int): total step numbers
    """
    
    def __init__(
        self,
        dataset,
        state_dim,
        action_dim,
        batch_size: int,
        gamma: float = 0.99,
        tau: float = 5e-3,
        policy_update_freq: int = 1,
        cql_threshold=10,
        random_actions_num=10
    ):
        """Initialize networks, optimizers and the two learnable temperatures."""
        self.dataset = dataset
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.policy_update_freq = policy_update_freq
        self.random_actions_num = random_actions_num

        # device: cpu / gpu
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        print(self.device)
        
        # automatic entropy tuning
        self.target_entropy = -np.prod((self.action_dim,)).item()  # heuristic
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)

        # with Lagrange alpha
        self.cql_threshold = cql_threshold
        self.log_alpha_cql = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_cql_optimizer = optim.Adam([self.log_alpha_cql], lr=3e-4)

        # actor
        self.actor = Actor(self.state_dim, self.action_dim).to(self.device)
        
        # q function
        self.qf_1 = CriticQ(self.state_dim + self.action_dim).to(self.device)
        self.qf_2 = CriticQ(self.state_dim + self.action_dim).to(self.device)
        self.qf_1_target = CriticQ(self.state_dim + self.action_dim).to(self.device)
        self.qf_2_target = CriticQ(self.state_dim + self.action_dim).to(self.device)
        self.qf_1_target.load_state_dict(self.qf_1.state_dict())
        self.qf_2_target.load_state_dict(self.qf_2.state_dict())

        # optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-5)
        self.qf_1_optimizer = optim.Adam(self.qf_1.parameters(), lr=3e-4)
        self.qf_2_optimizer = optim.Adam(self.qf_2.parameters(), lr=3e-4)
        
        # total steps count
        self.total_step = 0
    
    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Sample an action from the current (stochastic) policy for `state`."""
        # Acting never needs a backward graph.
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            action, _ = self.actor(state_tensor)
        return action.cpu().numpy()

    def get_policy_actions(self, states, random_actions_num):
        """Sample `random_actions_num` policy actions for every state.

        states: shape = (B, state_dim)
        Returns the actions as produced by the actor on the repeated states,
        shape (B * random_actions_num, action_dim), where rows
        i*n .. i*n + n - 1 belong to state i, and their log-probs reshaped to
        (B, random_actions_num, 1).
        """
        states_resized = states.repeat_interleave(random_actions_num, dim=0)
        actions, log_prob = self.actor(states_resized)
        return actions, log_prob.view(states.shape[0], random_actions_num, 1)
    
    def get_q_from_actions(self, states, actions, random_actions_num):
        """Evaluate both critics on every state paired with its block of actions.

        states: shape = (B, state_dim)
        actions: shape = (B * random_actions_num, action_dim), grouped per state
            as returned by ``get_policy_actions``
        Returns (q1, q2), each of shape (B, random_actions_num, 1).
        """
        states_resized = states.repeat_interleave(random_actions_num, dim=0)
        q1 = self.qf_1(states_resized, actions)
        q2 = self.qf_2(states_resized, actions)
        return q1.view(states.shape[0], random_actions_num, 1), q2.view(states.shape[0], random_actions_num, 1)
    
    def update_model(self):
        """Run one optimization step for alpha, alpha_cql, the actor and both critics.

        Each optimizer is stepped on its own loss only: the actor loss no longer
        leaks gradients into the critics, and the CQL penalty no longer leaks
        gradients into the actor.

        Returns:
            (actor_loss, qf_1_loss, qf_2_loss, alpha_loss, alpha_cql_loss) as numpy arrays.
        """
        device = self.device  # for shortening the following lines
        n_samples = self.random_actions_num
        
        samples = sample_batch(self.batch_size)
        states = torch.FloatTensor(samples["state"]).to(device)
        next_states = torch.FloatTensor(samples["next_state"]).to(device)
        actions = torch.FloatTensor(samples["action"]).to(device)
        rewards = torch.FloatTensor(samples["reward"].reshape(-1, 1)).to(device)
        dones = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        new_actions, log_prob = self.actor(states)
        
        # ---- entropy temperature (dual problem) ----
        alpha_loss = (
            -self.log_alpha * (log_prob + self.target_entropy).detach()
        ).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        
        # Detached: alpha is a fixed coefficient inside the actor and critic losses.
        alpha = self.log_alpha.exp().detach()

        # ---- TD target (no gradient flows into the actor or target critics) ----
        mask = 1 - dones
        with torch.no_grad():
            target_actions, target_actions_log_prob = self.actor(next_states)
            q_next = torch.min(
                self.qf_1_target(next_states, target_actions),
                self.qf_2_target(next_states, target_actions),
            ) - alpha * target_actions_log_prob
            q_target = rewards + self.gamma * q_next * mask

        q_1_pred = self.qf_1(states, actions)
        q_2_pred = self.qf_2(states, actions)
        qf_1_loss_mse = F.mse_loss(q_1_pred, q_target)
        qf_2_loss_mse = F.mse_loss(q_2_pred, q_target)

        # ---- CQL penalty ----
        random_actions = torch.empty(
            states.shape[0] * n_samples, actions.shape[-1], device=device
        ).uniform_(-1, 1)
        # The penalty trains only the critics, so the policy samples are constants here.
        with torch.no_grad():
            policy_actions, policy_actions_log_prob = self.get_policy_actions(states, n_samples)

        q1_random, q2_random = self.get_q_from_actions(states, random_actions, n_samples)  # (B, n, 1)
        q1_policy, q2_policy = self.get_q_from_actions(states, policy_actions, n_samples)

        # importance sampling (continuous variant of log_sum_exp);
        # the uniform density on [-1, 1]^action_dim is 0.5 ** action_dim
        random_density = np.log(0.5 ** self.action_dim)
        q1_sampling = torch.cat([q1_random - random_density, q1_policy - policy_actions_log_prob], 1)
        q2_sampling = torch.cat([q2_random - random_density, q2_policy - policy_actions_log_prob], 1)

        cql_q1_loss = torch.logsumexp(q1_sampling, dim=1).mean() - q_1_pred.mean()
        cql_q2_loss = torch.logsumexp(q2_sampling, dim=1).mean() - q_2_pred.mean()

        # ---- Lagrange multiplier for the CQL penalty ----
        alpha_cql = torch.clamp(self.log_alpha_cql.exp(), min=0.0, max=1000000.0)
        cql_q1_gap = cql_q1_loss - self.cql_threshold
        cql_q2_gap = cql_q2_loss - self.cql_threshold

        # Only log_alpha_cql is trained by this loss, so the gaps are constants here.
        alpha_cql_loss = -0.5 * alpha_cql * (cql_q1_gap.detach() + cql_q2_gap.detach())
        self.alpha_cql_optimizer.zero_grad()
        alpha_cql_loss.backward()
        self.alpha_cql_optimizer.step()

        # The critics see the multiplier as a fixed (pre-step) coefficient.
        alpha_cql = alpha_cql.detach()
        qf_1_loss = qf_1_loss_mse + alpha_cql * cql_q1_gap
        qf_2_loss = qf_2_loss_mse + alpha_cql * cql_q2_gap

        # ---- actor (backward must happen before the critics are stepped in place) ----
        q_pred = torch.min(
            self.qf_1(states, new_actions), self.qf_2(states, new_actions)
        )
        actor_loss = (alpha * log_prob - q_pred).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ---- critics ----
        # zero_grad here also discards the gradients actor_loss deposited in the critics.
        self.qf_1_optimizer.zero_grad()
        self.qf_2_optimizer.zero_grad()
        (qf_1_loss + qf_2_loss).backward()
        self.qf_1_optimizer.step()
        self.qf_2_optimizer.step()

        # target update (qf1 and qf2)
        if self.total_step % self.policy_update_freq == 0:
            self._target_soft_update()
        
        return actor_loss.detach().cpu().numpy(), qf_1_loss.detach().cpu().numpy(), qf_2_loss.detach().cpu().numpy(), alpha_loss.detach().cpu().numpy(), alpha_cql_loss.detach().cpu().numpy()
    
    def train(self, num_frames: int, save_every: int, test_env, writer, load_from_checkpoint: bool = False):
        """Train the agent up to frame `num_frames`.

        Losses are logged to `writer` every frame. Every `save_every` frames a
        checkpoint is written and one evaluation episode is run on `test_env`.
        A final checkpoint is written at the end. With `load_from_checkpoint`,
        training resumes after the frame stored in CHECKPOINT_PATH + ".tar".
        """
        start_frame = 1

        if load_from_checkpoint:
            # The stored frame already completed before saving, so resume at the next one.
            start_frame = self._load_checkpoint() + 1
        
        for self.total_step in range(start_frame, num_frames + 1):
            print("\rframe: " + str(self.total_step), end="")
            actor_loss, qf1_loss, qf2_loss, alpha_loss, alpha_cql_loss = self.update_model()

            writer.add_scalar("Actor_loss", actor_loss, self.total_step)
            writer.add_scalar("Qf1_loss", qf1_loss, self.total_step)
            writer.add_scalar("Qf2_loss", qf2_loss, self.total_step)
            writer.add_scalar("Alpha_loss", alpha_loss, self.total_step)
            writer.add_scalar("Alpha_cql_loss", alpha_cql_loss, self.total_step)

            # save the model every SAVE_EVERY steps
            if self.total_step % save_every == 0:
                self._save_checkpoint()
                self.test(test_env, writer)
        
        self._save_checkpoint()

    def _load_checkpoint(self) -> int:
        """Restore networks, temperatures and optimizers from CHECKPOINT_PATH + ".tar".

        Returns the frame number stored in the checkpoint.
        """
        cp = torch.load(CHECKPOINT_PATH + ".tar", map_location=self.device)

        self.actor.load_state_dict(cp["actor_network"])
        self.qf_1.load_state_dict(cp["qf1_network"])
        self.qf_2.load_state_dict(cp["qf2_network"])
        self.qf_1_target.load_state_dict(cp["qf1_target_network"])
        self.qf_2_target.load_state_dict(cp["qf2_target_network"])

        # Copy in place: the alpha optimizers hold references to these exact tensors,
        # so rebinding the attributes would leave the optimizers updating orphans.
        with torch.no_grad():
            self.log_alpha.copy_(cp["log_alpha"])
            self.log_alpha_cql.copy_(cp["log_alpha_cql"])

        self.actor_optimizer.load_state_dict(cp["actor_optimizer"])
        self.qf_1_optimizer.load_state_dict(cp["qf1_optimizer"])
        self.qf_2_optimizer.load_state_dict(cp["qf2_optimizer"])
        self.alpha_optimizer.load_state_dict(cp["alpha_optimizer"])
        self.alpha_cql_optimizer.load_state_dict(cp["alpha_cql_optimizer"])

        self.total_step = cp["frame"]
        return cp["frame"]

    def _checkpoint_state(self) -> dict:
        """Collect everything needed to resume training into one dict."""
        return {
            "frame":                self.total_step,
            "actor_network":        self.actor.state_dict(),
            "qf1_network":          self.qf_1.state_dict(),
            "qf2_network":          self.qf_2.state_dict(),
            "qf1_target_network":   self.qf_1_target.state_dict(),
            "qf2_target_network":   self.qf_2_target.state_dict(),
            "log_alpha":            self.log_alpha,
            "log_alpha_cql":        self.log_alpha_cql,
            "actor_optimizer":      self.actor_optimizer.state_dict(),
            "qf1_optimizer":        self.qf_1_optimizer.state_dict(),
            "qf2_optimizer":        self.qf_2_optimizer.state_dict(),
            "alpha_optimizer":      self.alpha_optimizer.state_dict(),
            "alpha_cql_optimizer":  self.alpha_cql_optimizer.state_dict(),
        }
    
    def _save_checkpoint(self):
        """Write the training state twice: as the latest checkpoint
        (CHECKPOINT_PATH + ".tar") and as a per-frame copy
        (CHECKPOINT_PATH + "<frame>.tar")."""
        state = self._checkpoint_state()
        torch.save(state, CHECKPOINT_PATH + ".tar")
        torch.save(state, CHECKPOINT_PATH + str(self.total_step) + ".tar")
        
    def test(self, test_env, writer):
        """Run one episode with the stochastic policy and log its return as "Eval_score"."""
        
        state = test_env.reset()
        done = False
        score = 0
        
        while not done:
            action = self.select_action(state)
            next_state, reward, done, _ = test_env.step(action)

            state = next_state
            score += reward
        
        print("score: ", score)
        writer.add_scalar("Eval_score", score, self.total_step)
    
    def _target_soft_update(self):
        """Soft-update both target critics: target = tau*local + (1-tau)*target."""
        tau = self.tau
        
        for target_net, local_net in ((self.qf_1_target, self.qf_1), (self.qf_2_target, self.qf_2)):
            for t_param, l_param in zip(target_net.parameters(), local_net.parameters()):
                t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)

## Environment


In [None]:
# environment
# `env` is also passed to agent.train() as the evaluation environment.
# `dataset` is the offline data; sample_batch() below indexes its
# "observations", "actions", "rewards" and "terminals" entries.
env = gym.make("ant-bullet-medium-v0")
dataset = env.get_dataset()

def sample_batch(batch_size: int, data=None):
    """Sample a random mini-batch of (s, a, s', r, done) transitions.

    Args:
        batch_size: number of distinct starting indices drawn (without replacement).
        data: d4rl-style mapping with "observations", "actions", "rewards" and
            "terminals" arrays. Defaults to the module-level ``dataset``.

    Returns:
        dict with keys "state", "action", "next_state", "reward", "done".

    The next state for index i is taken to be observations[i + 1], so the final
    index is never drawn. Indices that land on a terminal step are shifted back
    by one, except at index 0, where there is no earlier step.
    NOTE(review): because of that shift, transitions with done=1 are almost
    never sampled. Confirm this is intended for this dataset's terminal flags.
    """
    if data is None:
        data = dataset  # module-level offline dataset

    num_transitions = len(data["actions"])
    # dtype pinned so an empty sample still yields a valid integer index array
    indices = np.asarray(
        random.sample(range(num_transitions - 1), batch_size), dtype=np.int64
    )

    # Vectorized terminal shift. np.maximum stops a terminal at index 0 from
    # wrapping to -1, i.e. silently pairing the dataset's last step with observations[0].
    on_terminal = (data["terminals"][indices] == 1).astype(np.int64)
    indices = np.maximum(indices - on_terminal, 0)

    return {
        "state": data["observations"][indices],
        "action": data["actions"][indices],
        "next_state": data["observations"][indices + 1],
        "reward": data["rewards"][indices],
        "done": data["terminals"][indices],
    }



## Initialize

In [None]:
# ---- hyper-parameters -------------------------------------------------------
num_frames = 5_000_000          # total gradient steps (one update_model() call per frame)
save_every = num_frames // 100  # checkpoint + evaluation every 1% of training
batch_size = 256                # transitions per gradient step

# TensorBoard scalars are written under TENSORBOARD_PATH/<TEST_NAME>
writer = SummaryWriter(TENSORBOARD_PATH + TEST_NAME)

# ---- agent ------------------------------------------------------------------
agent = SACAgent(
    dataset,
    state_dim=env.observation_space.shape[0],
    action_dim=env.action_space.shape[0],
    batch_size=batch_size,
    random_actions_num=10,  # uniform + policy samples per state in the CQL term
    cql_threshold=10,       # target gap for the CQL Lagrange multiplier
)

## Train

In [None]:
# Train from scratch; `env` is used for the periodic evaluation episodes.
# Set load_from_checkpoint=True to resume from CHECKPOINT_PATH + ".tar".
agent.train(num_frames, save_every, test_env=env, load_from_checkpoint=False, writer=writer)

## Test

In [None]:
# Dependencies for recording and displaying rollouts on a headless machine
# (virtual display + video encoding). Output is silenced, so install failures are hidden.
# NOTE(review): an unpinned `pip install gym` here may change the gym version used above — confirm.
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install x11-utils > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

In [None]:
# Video-recording setup: gym Monitor wrapper, helpers to embed the result, and a virtual display.
from gym import logger as gymlogger
from gym.wrappers import Monitor
gymlogger.set_level(40) #error only
import glob
import io
import os
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
import time

from pyvirtualdisplay import Display
# Start an invisible 1400x900 virtual display so the env.render() calls below have a screen.
# NOTE: this rebinds `display`, shadowing IPython's built-in display() in this notebook;
# the helpers below use `ipythondisplay.display` instead.
display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of gym environment and displaying it
To enable video, just do "env = wrap_env(env)""
"""
def show_video():
    mp4list = glob.glob('videos/*/*.mp4')
    mp4list.sort(key=os.path.getmtime)
    if len(mp4list) > 0:
        mp4 = mp4list[-1]
        video = io.open(mp4, 'rb').read()
        encoded = base64.b64encode(video)
        ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                    loop controls style="height: 400px;">
                    <source src="data:video/mp4;base64,{0}" type="video/mp4" />
                </video>'''.format(encoded.decode('ascii'))))
    else: 
        print("Could not find video")
    

def wrap_env(env):
    """Wrap `env` in a gym Monitor that saves its episodes as videos.

    Each call records into a fresh timestamped folder, ./videos/<time.time()>/,
    which matches the videos/*/*.mp4 pattern that show_video() searches.
    """
    recording_dir = './videos/' + str(time.time()) + '/'
    return Monitor(env, recording_dir)

In [None]:
# Roll out the trained policy for a few recorded episodes and show each video with its return.
for episode in range(5):
    test_env = wrap_env(gym.make("ant-bullet-medium-v0"))
    try:
        state = test_env.reset()
        done = False
        score = 0

        while not done:
            test_env.render()

            action = agent.select_action(state)
            state, reward, done, _ = test_env.step(action)
            score += reward
    finally:
        # Close even if the rollout raises, so the Monitor can finalize its recording
        # and the simulator is released.
        test_env.close()

    show_video()

    print("score: ", score)