<a href="https://colab.research.google.com/github/JHyunjun/SNU/blob/main/SAC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you run this notebook in a local Jupyter environment instead of Google Colab, set

```
colab = False
```

In [15]:
colab = True
if colab:
    !pip install gym==0.21 pyvirtualdisplay > /dev/null 2>&1
    !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
    !apt-get update > /dev/null 2>&1
    !apt-get install cmake > /dev/null 2>&1
    !pip install --upgrade setuptools 2>&1
    !pip install ez_setup > /dev/null 2>&1
    !pip3 install box2d-py
    !pip3 install gym[Box_2D]
    !pip3 install pybullet --upgrade

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [16]:
if colab:
    # Mount Google Drive into the Colab filesystem at /content/drive.
    from google.colab import drive
    drive.mount('/content/drive')

    # Change into the project directory and list its contents.
    # NOTE(review): the recorded output of this cell shows "[Errno 2] No such file or
    # directory" for this path, and the later `from utils import *` cell failed with
    # ModuleNotFoundError. Presumably utils.py / buffer.py are expected to live in this
    # directory; adjust the path to match your own Drive layout.
    %cd /content/drive/MyDrive/Colab\ Notebooks/rl-master/rl-master/day4/sac
    !ls

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[Errno 2] No such file or directory: '/content/drive/MyDrive/Colab Notebooks/rl-master/rl-master/day4/sac'
/content
drive  sample_data


In [17]:
if colab:
    import gym
    from gym.wrappers import Monitor
    import glob
    import io
    import base64
    from IPython.display import HTML
    from pyvirtualdisplay import Display
    from IPython import display as ipythondisplay

    # Create and start an off-screen display (visible=0, 1400x900) before anything renders.
    # Presumably this is what lets env.render() work on a headless Colab VM - verify.
    display = Display(visible=0, size=(1400, 900))
    display.start()

    def show_video():
        """Embed the first ``video/*.mp4`` file inline in the notebook output.

        The clip is base64-encoded into an HTML ``<video>`` tag and shown through
        IPython's display machinery. If no matching file exists, the message
        ``Could not find video`` is printed instead.

        Returns:
            None
        """
        mp4list = glob.glob('video/*.mp4')
        if not mp4list:
            print("Could not find video")
            return

        # Read-only binary mode inside a context manager: the original opened the file
        # read-write ('r+b') and never closed the handle.
        with open(mp4list[0], 'rb') as video_file:
            encoded = base64.b64encode(video_file.read())

        ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                    loop controls style="height: 400px;">
                    <source src="data:video/mp4;base64,{0}" type="video/mp4" />
                </video>'''.format(encoded.decode('ascii'))))
        

    def wrap_env(env):
        """Return ``env`` wrapped in a gym ``Monitor`` that targets ``./video`` (``force=True``)."""
        return Monitor(env, './video', force=True)

# SAC Practice

Reminder: key elements of SAC


*   Max-entropy MDP setting
*   Soft actor improvement with KL-divergence
*   Reparameterization trick





In [18]:
import time
import csv
import gym
import copy
import os
import numpy as np
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Independent
from torch.distributions.normal import Normal

from utils import *
from buffer import *

import pybullet_envs

ModuleNotFoundError: ignored

In [None]:
# Pick the compute device once: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('current device =', device)

# 0. Define Q-network & policy-network

In [None]:
##################################################
##  Policy network with multi-layer perceptron  ##
##################################################

# Input - |S|
# Output - normal distribution of size |A|

class SACActor(nn.Module):
    """Tanh-squashed Gaussian policy built from a two-hidden-layer MLP.

    Architecture::

                          -> fc3 -> mu
        s -> fc1 -> fc2 <
                          -> fc4 -> log(sigma)

    Args:
        dimS: size of the state input.
        dimA: size of the action output.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        ctrl_range: scale applied to the tanh output, so actions lie in
            ``(-ctrl_range, ctrl_range)``.
    """

    def __init__(self, dimS, dimA, hidden1, hidden2, ctrl_range):
        super().__init__()
        # Shared trunk
        self.fc1 = nn.Linear(dimS, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)

        # Output heads: fc3 -> mean, fc4 -> log standard deviation
        self.fc3 = nn.Linear(hidden2, dimA)
        self.fc4 = nn.Linear(hidden2, dimA)

        self.ctrl_range = ctrl_range

    def forward(self, state, eval=False, with_log_prob=False):
        """Compute an action for ``state`` and, optionally, its log-probability.

        Args:
            state: state tensor fed to ``fc1``.
            eval: if True, act deterministically with the Gaussian mean.
            with_log_prob: if True (and ``eval`` is False), also return the
                log-probability of the squashed action. The correction sums over
                ``dim=1``, so this path requires a batched 2-D input.

        Returns:
            ``(a, log_prob)``: ``a`` is ``ctrl_range * tanh(u)``; ``log_prob`` is a
            tensor with one entry per batch row, or ``None`` when it was not requested
            or ``eval`` is True.
        """
        features = F.relu(self.fc2(F.relu(self.fc1(state))))
        mu = self.fc3(features)

        # Clamp log(sigma) as in Haarnoja's SAC implementation:
        # https://github.com/haarnoja/sac.git
        log_sigma = self.fc4(features).clamp(min=-20.0, max=2.0)

        # Diagonal Gaussian whose last dimension is treated as one event
        policy = Independent(Normal(mu, log_sigma.exp()), 1)

        if eval:
            # Deterministic action centred at the mean
            u, log_prob = mu, None
        else:
            # rsample() (not sample()) keeps the draw differentiable: reparameterization trick
            u = policy.rsample()
            log_prob = None
            if with_log_prob:
                # Change of variables for a = ctrl_range * tanh(u):
                # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), written in this
                # form for numerical stability.
                per_dim = np.log(2.0) + 0.5 * np.log(self.ctrl_range) - u - F.softplus(-2.0 * u)
                log_prob = policy.log_prob(u) - 2.0 * torch.sum(per_dim, dim=1)

        # Squash with tanh and scale to the control range
        return self.ctrl_range * torch.tanh(u), log_prob
    

##################################################
##  Critic network with multi-layer perceptron  ##
##################################################

# Input - |S|+|A|
# Output - single value

class DoubleCritic(nn.Module):
    """Pair of independent Q-networks (the twin-critic idea from TD3).

    Each critic is an MLP with two hidden layers that maps the concatenated
    ``[state, action]`` vector to a single value.

    Args:
        dimS: size of the state input.
        dimA: size of the action input.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
    """

    def __init__(self, dimS, dimA, hidden1, hidden2):
        super().__init__()
        in_features = dimS + dimA

        # Critic 1
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)

        # Critic 2
        self.fc4 = nn.Linear(in_features, hidden1)
        self.fc5 = nn.Linear(hidden1, hidden2)
        self.fc6 = nn.Linear(hidden2, 1)

    def forward(self, state, action):
        """Return ``(Q1(s, a), Q2(s, a))``; each has a trailing dimension of size 1."""
        state_action = torch.cat([state, action], dim=1)

        q1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(state_action)))))
        q2 = self.fc6(F.relu(self.fc5(F.relu(self.fc4(state_action)))))

        return q1, q2

    def Q1(self, state, action):
        """Evaluate only the first critic on ``(state, action)``."""
        hidden = torch.cat([state, action], dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))

        return self.fc3(hidden)

# 1. Define SAC agent

In [None]:
class SACAgent:
    """Soft actor-critic agent: policy, twin critics, target critics, replay buffer and optimizers.

    Args:
        dimS: size of the state vector.
        dimA: size of the action vector.
        ctrl_range: action scale passed to the actor.
        gamma: discount factor.
        pi_lr: learning rate of the policy optimizer.
        q_lr: learning rate of the critic optimizer.
        polyak: interpolation weight used in ``target_update``.
        alpha: entropy temperature stored on the agent for the loss computation.
        hidden1: width of the first hidden layer of every network.
        hidden2: width of the second hidden layer of every network.
        buffer_size: capacity handed to ``ReplayBuffer`` as ``limit``.
        batch_size: minibatch size stored on the agent.
        device: torch device the networks are moved to.
        render: stored on the agent; not used inside this class.
    """

    def __init__(self,
                 dimS,
                 dimA,
                 ctrl_range,
                 gamma=0.99,
                 pi_lr=1e-4,
                 q_lr=1e-3,
                 polyak=1e-3,
                 alpha=0.2,
                 hidden1=256,
                 hidden2=256,
                 buffer_size=1000000,
                 batch_size=128,
                 device='cpu',
                 render=False):

        self.dimS = dimS
        self.dimA = dimA
        self.ctrl_range = ctrl_range

        self.gamma = gamma
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.polyak = polyak
        self.alpha = alpha

        self.batch_size = batch_size

        # networks definition
        # pi : actor network, Q : 2 critic network
        self.pi = SACActor(dimS, dimA, hidden1, hidden2, ctrl_range).to(device)
        self.Q = DoubleCritic(dimS, dimA, hidden1, hidden2).to(device)

        # target networks: start as an exact copy of the critics
        self.target_Q = copy.deepcopy(self.Q).to(device)
        # presumably disables gradients on the target copy (helper comes from utils) - verify
        freeze(self.target_Q)

        self.buffer = ReplayBuffer(dimS, dimA, limit=buffer_size)

        self.Q_optimizer = Adam(self.Q.parameters(), lr=self.q_lr)
        self.pi_optimizer = Adam(self.pi.parameters(), lr=self.pi_lr)

        self.device = device
        self.render = render

    def act(self, state, eval=False):
        """Return the policy's action for ``state`` as a NumPy array.

        Args:
            state: observation convertible to a float tensor.
            eval: if True, use the deterministic (mean) action instead of sampling.
        """
        state = torch.tensor(state, dtype=torch.float).to(self.device)
        with torch.no_grad():
            action, _ = self.pi(state, eval=eval, with_log_prob=False)
        action = action.cpu().detach().numpy()

        return action

    def target_update(self):
        """Polyak-average the critic weights into the target critic.

        ``target <- polyak * online + (1 - polyak) * target``
        """
        for params, target_params in zip(self.Q.parameters(), self.target_Q.parameters()):
            target_params.data.copy_(self.polyak * params.data + (1.0 - self.polyak) * target_params.data)

    def save_model(self, path):
        """Save network and optimizer state dicts to ``path + 'model.pth.tar'``.

        Args:
            path: filename prefix, e.g. ``'./checkpoints/<env>/sac_final_'``.
        """
        print('adding checkpoints...')
        checkpoint_path = path + 'model.pth.tar'
        # Make sure the destination directory exists so saving cannot fail on a fresh run.
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
                    {'actor': self.pi.state_dict(),
                     'critic': self.Q.state_dict(),
                     'target_critic': self.target_Q.state_dict(),
                     'actor_optimizer': self.pi_optimizer.state_dict(),
                     'critic_optimizer': self.Q_optimizer.state_dict()
                    },
                    checkpoint_path)

# 2. Implement one-step param update

In [None]:
def update(agent, batch):
    """Run one SAC gradient step: critic update, actor update, soft target update.

    Args:
        agent: object exposing ``pi``, ``Q``, ``target_Q``, ``Q_optimizer``,
            ``pi_optimizer``, ``gamma``, ``alpha``, ``device`` and ``target_update()``.
        batch: sampled transitions with fields ``obs``, ``act``, ``next_obs``,
            ``rew`` and ``done``.

    Shape handling:
        The critics return values with a trailing dimension of 1, while the actor's
        log-probabilities have one entry per batch row. Rewards, done flags and
        log-probabilities are therefore reshaped to column vectors first. Without this,
        ``target_q - alpha * log_probs`` broadcasts to a ``(batch, batch)`` matrix and
        the critics are regressed onto mismatched targets.
    """
    # Use the agent's own device rather than a notebook-level global.
    dev = agent.device

    def to_tensor(array):
        return torch.tensor(array, dtype=torch.float).to(dev)

    obs_batch = to_tensor(batch.obs)
    act_batch = to_tensor(batch.act)
    next_obs_batch = to_tensor(batch.next_obs)
    # Column vectors, so they line up with the (batch, 1) critic outputs whatever
    # layout the buffer used to store them.
    rew_batch = to_tensor(batch.rew).reshape(-1, 1)
    done_batch = to_tensor(batch.done).reshape(-1, 1)
    masks = 1.0 - done_batch

    #########################
    ##    Critic Update    ##
    #########################
    # Build the Bellman target without tracking gradients
    with torch.no_grad():
        # a' ~ pi(.|s') together with log pi(a'|s')
        next_actions, next_log_probs = agent.pi(next_obs_batch, with_log_prob=True)
        next_log_probs = next_log_probs.reshape(-1, 1)

        # Q(s', a') from both target critics
        target_q1, target_q2 = agent.target_Q(next_obs_batch, next_actions)

        # Take the smaller estimate to mitigate overestimation (idea from TD3)
        target_q = torch.min(target_q1, target_q2)

        # TQ = r + gamma * (1 - done) * [ Q(s', a') - alpha * log pi(a'|s') ]
        TQ = rew_batch + agent.gamma * masks * (target_q - agent.alpha * next_log_probs)

    # Sum of the two mean-squared Bellman errors
    Q1, Q2 = agent.Q(obs_batch, act_batch)
    Q_loss1 = torch.mean((Q1 - TQ)**2)
    Q_loss2 = torch.mean((Q2 - TQ)**2)
    Q_loss = Q_loss1 + Q_loss2

    # Gradient descent on the critic loss
    agent.Q_optimizer.zero_grad()
    Q_loss.backward()
    agent.Q_optimizer.step()

    ########################
    ##    Actor Update    ##
    ########################
    actions, log_probs = agent.pi(obs_batch, with_log_prob=True)
    log_probs = log_probs.reshape(-1, 1)

    # freeze/unfreeze come from utils; presumably they toggle requires_grad on the
    # critic so the actor step leaves no gradients on it - verify.
    freeze(agent.Q)
    try:
        q1, q2 = agent.Q(obs_batch, actions)
        q = torch.min(q1, q2)

        # Soft actor loss: E[ alpha * log pi(a|s) - Q(s, a) ]
        pi_loss = torch.mean(agent.alpha * log_probs - q)

        # Gradient descent on pi_loss (minimising it maximises Q plus entropy)
        agent.pi_optimizer.zero_grad()
        pi_loss.backward()
        agent.pi_optimizer.step()
    finally:
        # Always restore the critic, even if the actor step raised.
        unfreeze(agent.Q)

    ####################################
    ##    Soft Target Critic Update    #
    ####################################
    agent.target_update()

# 3. Putting these together

In [None]:
def run_sac(
            agent,
            env_id,
            max_iter=1e6,
            eval_interval=2000,
            start_train=10000,
            train_interval=50,
            fill_buffer=20000,
            truncate=1000,
            ):
    """Train ``agent`` on the gym environment ``env_id`` with periodic evaluation.

    Args:
        agent: SAC agent providing ``act``, ``buffer``, ``batch_size`` and ``save_model``.
        env_id: gym environment id; also used in log and checkpoint paths.
        max_iter: number of environment steps (converted to ``int``).
        eval_interval: evaluate and checkpoint every this many steps.
        start_train: first step at which gradient updates are performed.
        train_interval: steps between update rounds; each round runs this many updates.
        fill_buffer: number of initial steps that use uniformly random actions.
        truncate: maximum episode length; ``None`` disables truncation.

    Side effects:
        Writes ``./train_log/<env_id>/SAC_<time>.csv``, ``./eval_log/<env_id>/SAC_<time>.csv``
        and a ``.txt`` file listing the call arguments, and saves a checkpoint under
        ``./checkpoints/<env_id>/`` at every evaluation.
    """
    # Snapshot the call arguments before any other local exists, so the run-config
    # file lists exactly the hyperparameters of this run.
    params = dict(locals())

    max_iter = int(max_iter)
    env = gym.make(env_id)

    # Always bind max_ep_len: the original assigned it only when truncate was set, so
    # truncate=None raised UnboundLocalError. Infinity never equals a step count.
    max_ep_len = truncate if truncate is not None else float('inf')

    # helper from utils; presumably creates the log/checkpoint directories - verify
    set_log_dir(env_id)

    # Logging & Saving Weights
    current_time = time.strftime("%m%d-%H%M%S")
    train_log_path = './train_log/' + env_id + '/SAC_' + current_time + '.csv'
    path = './eval_log/' + env_id + '/SAC_' + current_time

    with open(path + '.txt', 'w') as f:
        for key, val in params.items():
            print(key, '=', val, file=f)

    try:
        # Context managers guarantee both CSV files are flushed and closed even if
        # training is interrupted or raises.
        with open(train_log_path, 'w', encoding='utf-8', newline='') as train_log, \
                open(path + '.csv', 'w', encoding='utf-8', newline='') as eval_log:
            train_logger = csv.writer(train_log)
            eval_logger = csv.writer(eval_log)

            ##############################
            ##    Main training loop    ##
            ##############################
            obs = env.reset()
            step_count, ep_reward = 0, 0

            for t in range(max_iter + 1):
                # Rollout agent to fill in replay buffer
                if t < fill_buffer:
                    # For early stage of training, use random agent to promote exploration
                    action = env.action_space.sample()
                else:
                    action = agent.act(obs)

                next_obs, reward, done, _ = env.step(action)
                step_count += 1

                # Hitting the length cap is a time-out, not a terminal state: store
                # done=False so the critic target still bootstraps from next_obs.
                timed_out = step_count == max_ep_len
                if timed_out:
                    done = False

                agent.buffer.append(obs, action, next_obs, reward, done)

                obs = next_obs
                ep_reward += reward

                # Reset environment if trajectory ends
                if done or timed_out:
                    train_logger.writerow([t, ep_reward])
                    obs = env.reset()
                    step_count, ep_reward = 0, 0

                # Actor-Critic
                if (t >= start_train) and (t % train_interval == 0):
                    # One update per environment step collected since the last round.
                    for _ in range(train_interval):
                        # Use the agent's configured batch size; the original read an
                        # undeclared notebook-level `batch_size` global.
                        batch = agent.buffer.sample_batch(batch_size=agent.batch_size)
                        update(agent, batch)

                # Evaluate agent
                if t % eval_interval == 0:
                    eval_score = eval_agent(agent, env_id, render=False)
                    log = [t, eval_score]
                    print('step {} : {:.4f}'.format(t, eval_score))
                    eval_logger.writerow(log)
                    agent.save_model('./checkpoints/' + env_id + '/sac_{}th_iter_'.format(t))
    finally:
        # Release the training environment created above.
        env.close()

    return

# 4. Let's train our agent!

### Hyperparameter setting

In [None]:
# --- Environment (continuous control) ---
env_id = 'HalfCheetahBulletEnv-v0'
env = gym.make(env_id)
dimS, dimA, ctrl_range, max_ep_len = get_env_spec(env)
render = False

# --- Rollout / training schedule (in environment steps) ---
max_iter = 5e5
truncate = 1000
fill_buffer = 2e4
start_train = 1e4
train_interval = 50
eval_interval = 5000

# --- Replay buffer ---
buffer_size = 1e6
batch_size = 4000

# --- Networks and optimisation ---
hidden1 = hidden2 = 256
lr = 3e-4
tau = 5e-3

# --- SAC objective ---
gamma, alpha = 0.99, 0.01

### Setup environment and agent

In [None]:
# Other environments you can try instead:
# HopperBulletEnv-v0
# HumanoidBulletEnv-v0
# Walker2DBulletEnv-v0
# HalfCheetahBulletEnv-v0

env = gym.make('HalfCheetahBulletEnv-v0')
# Return value is discarded here; presumably called for the information it prints - verify.
get_env_spec(env)

# Let's watch robotics environment!
# On Colab the env is wrapped so the episode is recorded to ./video.
if colab:
    env = wrap_env(env)

obs = env.reset()
done = False
score = 0.

# Roll out one episode with uniformly random actions and accumulate the reward.
while not done:
    env.render()
    obs, rew, done, _ = env.step(env.action_space.sample())
    score += rew

env.close()
print('score : ', score)

# Show the recorded episode inline.
if colab:
    show_video()

In [None]:
# Instantiate the agent from the hyperparameters defined above.
# The single learning rate `lr` is used for both actor and critic, and `tau` is the polyak weight.
agent = SACAgent(
                 dimS,
                 dimA,
                 ctrl_range,
                 gamma=gamma,
                 pi_lr=lr,
                 q_lr=lr,
                 polyak=tau,
                 alpha=alpha,
                 hidden1=hidden1,
                 hidden2=hidden2,
                 buffer_size=int(buffer_size),
                 batch_size=batch_size,
                 device=device,
                 render=render
                 )

# Load pretrained model (optional: uncomment to start from saved baseline weights)
# load_model(agent, path='./checkpoints/'+env_id+'/sac_baseline_model.pth.tar', device=device)

### Run experiment!

In [None]:
# Train the agent with the schedule configured in the hyperparameter cell.
run_sac(
        agent,
        env_id,
        max_iter=max_iter,
        eval_interval=eval_interval,
        start_train=start_train,
        train_interval=train_interval,
        fill_buffer=fill_buffer,
        truncate=truncate,
        )

Save the model trained so far.

In [None]:
# save_model appends 'model.pth.tar', so this writes ./checkpoints/<env_id>/sac_final_model.pth.tar
agent.save_model('./checkpoints/' + env_id + '/sac_final_')

# 5. Watch the trained agent!

In [None]:
# Reload the saved weights and watch the trained agent for one episode.
env_id = 'HalfCheetahBulletEnv-v0'

env = gym.make(env_id)
dimS, dimA, ctrl_range, max_ep_len = get_env_spec(env)
# On Colab the env is wrapped so the episode is recorded to ./video.
if colab:
    env = wrap_env(env)

obs = env.reset()
done = False
score = 0.

# Build the agent on the same device that load_model is given; the original used
# the constructor default ('cpu') regardless of `device`.
agent = SACAgent(dimS, dimA, ctrl_range, device=device)
load_model(agent, path='./checkpoints/'+env_id+'/sac_final_model.pth.tar', device=device)
# load_model(agent, path='./checkpoints/'+env_id+'/sac_expert_model.pth.tar', device=device)

while not done:
    env.render()
    # eval=True selects the deterministic (mean) action, so the reported score
    # reflects the learned policy rather than exploration noise.
    obs, rew, done, _ = env.step(agent.act(obs, eval=True))
    score += rew

env.close()
print('score : ', score)

# Show the recorded episode inline.
if colab:
    show_video()