<a href="https://colab.research.google.com/github/TiaBerte/rl-soft-actor-critic/blob/main/render.ipynb" target="_parent">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

In [1]:
# Rendering Dependencies
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

# Gym Dependencies
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[mujoco] > /dev/null 2>&1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting setuptools
  Downloading setuptools-66.0.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 57.4.0
    Uninstalling setuptools-57.4.0:
      Successfully uninstalled setuptools-57.4.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.
cvxpy 1.2.3 requires setuptools<=64.0.2, but you have setuptools 66.0.0 which is incompatible.[0m[31m
[0mSuccessfully installed setuptools-66.0.0


In [2]:
!git clone https://github.com/TiaBerte/rl-soft-actor-critic.git
%cd rl-soft-actor-critic

Cloning into 'rl-soft-actor-critic'...
remote: Enumerating objects: 56, done.[K
remote: Counting objects: 100% (31/31), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 56 (delta 10), reused 10 (delta 0), pack-reused 25[K
Unpacking objects: 100% (56/56), 125.88 MiB | 11.53 MiB/s, done.
Updating files: 100% (17/17), done.
/content/rl-soft-actor-critic


In [3]:
import gym
from gym.wrappers.record_video import RecordVideo
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from sac import SAC
from replay_buffer import ReplayBuffer
from pyvirtualdisplay import Display

In [4]:
def show_video(video_dir='video'):
    """Embed the most recently written MP4 from ``video_dir`` in the notebook.

    The file is base64-encoded into an inline ``<video>`` tag, so it plays in
    the notebook without serving any files. If no MP4 is found, a message is
    printed instead.

    Args:
        video_dir: directory searched for ``*.mp4`` files. Defaults to
            ``video`` relative to the current working directory.
    """
    import os

    mp4list = glob.glob(os.path.join(video_dir, '*.mp4'))
    if not mp4list:
        print("Could not find video")
        return

    # glob() returns files in arbitrary order; choose the newest one so a
    # re-run shows the latest recording rather than an old one.
    mp4 = max(mp4list, key=os.path.getmtime)
    # Read-only binary mode inside a context manager so the handle is closed.
    with open(mp4, 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                  controls style="height: 400px;">
                  <source src="data:video/mp4;base64,{0}" type="video/mp4" />
              </video>'''.format(encoded.decode('ascii'))))
    

def wrap_env(env):
    """Return ``env`` wrapped in ``RecordVideo`` writing into ``./video``.

    The episode trigger accepts every episode index, so every episode is
    recorded.
    """
    def _record_every_episode(_episode_index):
        return True

    return RecordVideo(env, './video', episode_trigger=_record_every_episode)


def main(args):
    """Roll out one episode of a SAC agent in a recorded Gym environment.

    Starts an invisible virtual display so rendering works without a screen,
    then builds ``args.env_name`` wrapped by ``wrap_env`` (video recording).
    If ``args.model_path`` is set, the agent weights are restored from it.
    A single episode is then run, the environment and display are released,
    and ``show_video()`` embeds the recording.

    Args:
        args: argparse.Namespace from ``parser``. ``env_name`` and
            ``model_path`` are read here, and the whole namespace is
            forwarded to ``SAC``.
    """
    display = Display(visible=0, size=(1400, 900))
    display.start()
    env = None
    try:
        env = wrap_env(gym.make(args.env_name))
        action_dim = env.action_space.shape[0]
        state_dim = env.observation_space.shape[0]
        # Only the first action dimension's upper bound is used as the scale;
        # assumes all action dimensions share the same bound — TODO confirm.
        scale = env.action_space.high[0]

        # main never reads the buffer after handing it to SAC; the meaning of
        # the two size arguments is defined in replay_buffer.py (not shown).
        replay_buffer = ReplayBuffer(1e4, 1e5)

        agent = SAC(state_dim, action_dim, scale, replay_buffer, args)
        # Per the --model_path help, a missing path means "randomly
        # initialized", so only load when a path was actually given.
        if args.model_path is not None:
            agent.load_checkpoint(args.model_path)

        # NOTE: pre-0.26 Gym API: reset() returns only the observation,
        # step() returns a 4-tuple and render() takes a `mode` keyword.
        state = env.reset()
        done = False
        while not done:
            env.render(mode='rgb_array')
            # NOTE(review): the second argument presumably selects
            # evaluation/deterministic actions — confirm in sac.py.
            action = agent.get_action(state, True)
            state, _reward, done, _info = env.step(action)
    finally:
        # Release the environment and the Xvfb process even if model loading
        # or stepping raised; the original never stopped the display.
        if env is not None:
            env.close()
        display.stop()

    show_video()

In [5]:
from argparse import ArgumentParser
from typing import List


# Hyper-parameters forwarded to SAC by `main`.
parser = ArgumentParser()
parser.add_argument("--env_name", help="Gym environment", default="Humanoid-v4", type=str)
parser.add_argument("--gamma", help="discount value", default=0.99, type=float)
parser.add_argument("--alpha", help="temperature entropy", default=0.2, type=float)
parser.add_argument("--alpha_tuning", help="enable automatic tuning of the entropy temperature alpha", action='store_true')
# NOTE(review): the help text for --K was copied from --alpha; its meaning is
# defined in sac.py — confirm and update the help string.
parser.add_argument("--K", help="temperature entropy", type=int, default=0)
parser.add_argument("--tau", help="soft update", default=5e-3, type=float)
parser.add_argument("--batch_size", help="batch size", default=256, type=int)
# NOTE(review): lr_p / lr_c / lr_a presumably are the policy / critic / alpha
# learning rates — confirm in sac.py.
parser.add_argument("--lr_p", help="learning rate", default=3e-4, type=float)
parser.add_argument("--lr_c", help="learning rate", default=3e-4, type=float)
parser.add_argument("--lr_a", help="learning rate", default=3e-4, type=float)
# argparse calls `type` on each command-line token, so list options need
# type=int with nargs="+"; typing.List[int] is not callable and would reject
# every override (e.g. `--hidden_dim_q 128 64`).
parser.add_argument("--hidden_dim_q", help="hidden dim list", default=[256, 256], type=int, nargs="+")
parser.add_argument("--hidden_dim_p", help="hidden dim list", default=[256, 256], type=int, nargs="+")
parser.add_argument("--log_std_min", help="log std", default=-20, type=float)
parser.add_argument("--log_std_max", help="log std", default=3, type=float)
parser.add_argument("--model_path", help="path from which model is loaded, if none the model is randomly initialized", type=str, default=None)

# Jupyter starts the kernel with `-f <connection file>`; this dummy option
# absorbs it should sys.argv ever be parsed.
parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")
# Parse an explicit empty list so the defaults never depend on the kernel's argv.
args = parser.parse_args([])

In [6]:
# Render the checkpoint saved by the `sac_alpha_tuning` HalfCheetah run.
ENV_NAME = "HalfCheetah-v4"
CHECKPOINT_PATH = (
    "/content/rl-soft-actor-critic/checkpoints/HalfCheetah-v4/sac_alpha_tuning/"
    "HalfCheetah-v4_sac_alpha_tuning_ep_5780_test_rew_11348.2"
)
cli_args = ["--env_name", ENV_NAME, "--model_path", CHECKPOINT_PATH]
args = parser.parse_args(cli_args)
main(args)

  deprecation(
  deprecation(


Loading model ...
Model loaded from /content/Humanoid-v4_sac_avg_K_10_ep_9410_test_rew_7972.3


  logger.deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
