<a href="https://colab.research.google.com/github/aliciafmachado/sac/blob/main/notebooks/Demo_colab_mujoco.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installation

In [1]:
#@title Installations { form-width: "30%" }

# Fixing the haiku problem
!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Standard installs
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[jax]
!pip install dm-acme[tf]
!pip install dm-acme[envs]
!pip install dm-env
!pip install dm-haiku
!pip install dm-tree
!pip install chex
!sudo apt-get install -y xvfb ffmpeg
!pip install imageio
!pip install gym
!pip install gym[classic_control]

# Need ml-collections for config file
!pip install ml_collections

!apt-get install x11-utils
!pip install pyglet

!pip install gym pyvirtualdisplay

from IPython.display import clear_output
clear_output()

In [2]:
#@title Imports  { form-width: "30%" }

%matplotlib inline
import IPython
from IPython.display import HTML
from IPython import display as ipythondisplay

import acme
from acme import datasets
from acme import types
from acme import specs
from acme.wrappers import gym_wrapper
import base64
from base64 import b64encode
import chex
import collections
from collections import namedtuple
import dm_env
import enum
import functools
import gym
import haiku as hk
import imageio
import io
import itertools
import jax
from jax import tree_util
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import multiprocessing as mp
import multiprocessing.connection
import numpy as np
import pandas as pd
import random
import reverb
import rlax
import time
import tree
from typing import *
import warnings
# Headless rendering setup: configure pyglet, then start a virtual X display.
import pyglet
# These options are set before `pyglet.window` is imported below.
# NOTE(review): presumably pyglet reads them at import time of its window
# backend, so the order of these statements matters — confirm before moving.
pyglet.options['search_local_libs'] = False
pyglet.options['shadow_window']=False
from pyglet.window import xlib
# Overrides a private flag of pyglet's xlib backend.
# NOTE(review): assumed to work around a UTF-8/X11 issue on Colab — confirm.
xlib._have_utf8 = False

from pyvirtualdisplay import Display
# Off-screen 1400x900 display.
# NOTE(review): this rebinds the name `display`, which notebooks normally use
# for IPython's rich-display helper; the IPython module itself is imported as
# `ipythondisplay` above.
display = Display(visible=False, size=(1400, 900))
display.start()

# Print NumPy arrays with 3 decimals and without scientific notation.
np.set_printoptions(precision=3, suppress=1)

%matplotlib inline

In [3]:
#Include this at the top of your colab code
import os
if not os.path.exists('.mujoco_setup_complete'):
  # Get the prereqs
  ! apt-get -qq update
  ! apt-get -qq install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libgl1-mesa-dev libglew-dev patchelf
  # Get Mujoco
  ! mkdir ~/.mujoco
  ! wget -q https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz
  ! tar -zxf mujoco.tar.gz -C "$HOME/.mujoco"
  ! rm mujoco.tar.gz
  # Add it to the actively loaded path and the bashrc path (these only do so much)
  ! echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc 
  ! echo 'export LD_PRELOAD=$LD_PRELOAD:/usr/lib/x86_64-linux-gnu/libGLEW.so' >> ~/.bashrc 
  # THE ANNOYING ONE, FORCE IT INTO LDCONFIG SO WE ACTUALLY GET ACCESS TO IT THIS SESSION
  ! echo "/root/.mujoco/mujoco210/bin" > /etc/ld.so.conf.d/mujoco_ld_lib_path.conf
  ! ldconfig
  # Install Mujoco-py
  ! pip install -U 'mujoco-py<2.2,>=2.1'
  # run once
  ! touch .mujoco_setup_complete

# Per-kernel MuJoCo initialisation: extend the loader environment variables and
# import mujoco_py once. `_mujoco_run_once` survives in the kernel namespace,
# so re-running this cell does not append the same paths again.
_mujoco_run_once = globals().get('_mujoco_run_once', False)

if not _mujoco_run_once:
  # Add it to the actively loaded path and the bashrc path (these only do so much)
  for _env_var, _extra_entry in (
      ('LD_LIBRARY_PATH', '/root/.mujoco/mujoco210/bin'),
      ('LD_PRELOAD', '/usr/lib/x86_64-linux-gnu/libGLEW.so'),
  ):
    _current_value = os.environ.get(_env_var)
    if _current_value is None:
      # Variable not set yet: start it with just our entry.
      os.environ[_env_var] = _extra_entry
    else:
      # Variable already set: append our entry after a ':' separator.
      os.environ[_env_var] = _current_value + ':' + _extra_entry
  # presetup so we don't see output on first env initialization
  import mujoco_py
  _mujoco_run_once = True

Selecting previously unselected package libgl1-mesa-glx:amd64.
(Reading database ... (Reading database ... 5%(Reading database ... 10%(Reading database ... 15%(Reading database ... 20%(Reading database ... 25%(Reading database ... 30%(Reading database ... 35%(Reading database ... 40%(Reading database ... 45%(Reading database ... 50%(Reading database ... 55%(Reading database ... 60%(Reading database ... 65%(Reading database ... 70%(Reading database ... 75%(Reading database ... 80%(Reading database ... 85%(Reading database ... 90%(Reading database ... 95%(Reading database ... 100%(Reading database ... 156271 files and directories currently installed.)
Preparing to unpack .../0-libgl1-mesa-glx_20.0.8-0ubuntu1~18.04.1_amd64.deb ...
Unpacking libgl1-mesa-glx:amd64 (20.0.8-0ubuntu1~18.04.1) ...
Selecting previously unselected package libglew2.0:amd64.
Preparing to unpack .../1-libglew2.0_2.0.0-5_amd64.deb ...
Unpacking libglew2.0:amd64 (2.0.0-5) ...
Selecting previously 

In [4]:
! git clone https://aliciafmachado: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@github.com/aliciafmachado/sac.git

Cloning into 'sac'...
remote: Enumerating objects: 203, done.[K
remote: Counting objects: 100% (203/203), done.[K
remote: Compressing objects: 100% (144/144), done.[K
remote: Total 203 (delta 107), reused 131 (delta 49), pack-reused 0[K
Receiving objects: 100% (203/203), 347.81 KiB | 23.19 MiB/s, done.
Resolving deltas: 100% (107/107), done.
/content/sac


In [1]:
% cd sac

/content/sac


In [2]:
# Install the project in the current directory in editable (development) mode.
! pip install -e .

Obtaining file:///content/sac
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: sac-jax
  Attempting uninstall: sac-jax
    Found existing installation: sac-jax 0.0.1
    Can't uninstall 'sac-jax'. No files were found to uninstall.
  Running setup.py develop for sac-jax
Successfully installed sac-jax-0.0.1
[0m

In [56]:
# Run the repository's buffer test script as a quick sanity check.
! python src/tests/test_buffer.py

[[0.9888073 ]
 [0.04515111]
 [0.944697  ]
 [0.7622583 ]
 [0.46336806]
 [0.6107861 ]
 [0.27294135]
 [0.15984583]
 [0.96106446]
 [0.58273685]]


In [17]:
from src.configs.default import get_config
from src.agents.sac import SAC
from src.envs.pendulum import PendulumEnv
import tensorflow as tf
import acme

# Make sure tf does not allocate gpu memory: give TensorFlow an empty list of
# visible GPU devices.
tf.config.experimental.set_visible_devices([], 'GPU')
# Default configuration, training-mode Pendulum environment, its acme
# environment spec, and the SAC agent built from that spec and config.
configs = get_config()
env = PendulumEnv(for_evaluation=False)
environment_spec = acme.make_environment_spec(env)
model = SAC(environment_spec, configs)

# NOTE(review): neither constant is read by the cells shown in this notebook,
# and `n_updates` is rebound to 100 by a later cell — confirm they are needed.
n_collection_steps = 10
n_updates = 1000


In [None]:
# Running random agent
from src.envs.pendulum import PendulumEnv
from src.agents.random_agent import RandomAgent
import tree

# Roll out `n_simulations` full episodes with a random agent; nothing is
# stored or learned here, the loop only exercises the environment.
n_simulations = 10
render = False
random_agent = RandomAgent(acme.make_environment_spec(env))


def _add_batch_axis(x):
    """Prepend a length-1 axis to one observation leaf."""
    return x[None]


for _ in range(n_simulations):
    ts = env.reset()
    episode_over = False
    while not episode_over:
        # The agent works on batched observations; take its first action.
        batched_observation = tree.map_structure(_add_batch_axis, ts.observation)
        a = random_agent.batched_actor_step(batched_observation)[0]
        ts = env.step(a)
        if render:
            env._env.render()
        episode_over = ts.last()

env.close()
print("Done.")


In [18]:
# pre filling buffer
import copy
import tree
from src.agents.random_agent import RandomAgent

# Module-level random agent built from the environment's acme spec; `pre_fill`
# below reads this global rather than taking the agent as an argument.
random_agent = RandomAgent(acme.make_environment_spec(env))

def pre_fill(env, buffer, n_trajectories, agent=None):
  """Fill `buffer` with transitions from `n_trajectories` complete episodes.

  Each episode starts with `env.reset()` and is stepped until the returned
  timestep reports `last()`. Every step is written to the buffer through
  `buffer.store(state=..., action=..., reward=..., next_state=..., done=...)`.

  Args:
    env: Environment exposing `reset()` and `step(action)`, both returning a
      timestep with `observation`, `reward` and `last()`.
    buffer: Object exposing the `store(...)` method described above.
    n_trajectories: Number of complete episodes to collect.
    agent: Object exposing `batched_actor_step(batched_observation)`. Defaults
      to the notebook-level `random_agent`.

  Returns:
    The same `buffer` object, after all episodes have been stored (also when
    `n_trajectories` is 0).
  """
  if agent is None:
    agent = random_agent

  for _ in range(n_trajectories):
    ts = env.reset()
    while True:
      # Snapshot the pre-step observation so the stored `state` cannot be
      # affected if the environment reuses its observation array.
      state = copy.deepcopy(ts.observation)
      # The agent expects a leading batch axis; take the single action back out.
      batched_observation = tree.map_structure(lambda x: x[None], state)
      action = agent.batched_actor_step(batched_observation)[0]
      ts = env.step(action)
      done = ts.last()
      buffer.store(state=state, action=action, reward=ts.reward,
                   next_state=ts.observation, done=done)
      if done:
        break

  # Return only after every requested trajectory has been collected.
  return buffer

In [19]:
# Request 10 random-agent trajectories for the SAC model's own buffer.
buffer = pre_fill(env, model.buffer, 10)

In [20]:
# Sanity check: draw a sample of 10 from the buffer and show its actions.
print(buffer.sample(10).actions)

[[0.21850133]
 [0.58273685]
 [0.4510342 ]
 [0.3866675 ]
 [0.6531116 ]
 [0.99430287]
 [0.10536897]
 [0.9010216 ]
 [0.74942386]
 [0.09902263]]


In [21]:
# Now we use update_fn with things on the buffer:
# `ls` is whatever `model.initialize()` returns; the following cells pass it to
# `model.update_fn` and rebind it to that call's first return value.
ls = model.initialize()

In [26]:
# Running agent with transitions fresh from the environment
from jax import numpy as jnp
from src.utils.training_utils import Transitions

# Step the environment with the random agent and feed every single transition
# straight into `model.update_fn`, without going through the buffer.
n_simulations = 10
render = False


def _with_batch_axis(x):
    """Prepend a length-1 axis to one array leaf."""
    return x[None]


for _ in range(n_simulations):
    ts = env.reset()
    obs = tree.map_structure(_with_batch_axis, ts.observation)
    episode_over = False
    while not episode_over:
        a = random_agent.batched_actor_step(obs)
        ts = env.step(a[0])
        # Keep the pre-step (already batched) observation before rebinding `obs`.
        last_obs = copy.deepcopy(obs)
        obs = tree.map_structure(_with_batch_axis, ts.observation)
        reward = tree.map_structure(_with_batch_axis, ts.reward)
        # NOTE(review): `done` is 0-dimensional here, whereas the synthetic
        # batch in the next cell uses (batch, 1) arrays for rewards and dones —
        # confirm update_fn accepts both layouts.
        done = jnp.array(ts.last())
        action = a

        t = Transitions(
            observations=last_obs,
            actions=action,
            rewards=reward,
            dones=done,
            next_observations=obs,
        )

        # One update per environment step; thread the learner state through.
        ls, logs = model.update_fn(ls, t)

        if render:
            env._env.render()
        episode_over = ts.last()

env.close()
print("Done.")

Done.


In [9]:
# Make sure tf does not allocate gpu memory.
from src.utils.training_utils import Transitions
from jax import numpy as jnp

# Smoke-test `model.update_fn` on a fixed, hand-built batch of two transitions
# and print the losses after every update.
ls = model.initialize()

# Single-row shapes taken from the environment spec.
obs_shp = (1, *environment_spec.observations.shape)
act_shp = (1, *environment_spec.actions.shape)


def _two_rows(first_fill, second_fill, row_shape):
  """Concatenate two constant single-row arrays along axis 0."""
  return jnp.concatenate([first_fill(row_shape), second_fill(row_shape)], axis=0)


# Row 0 and row 1 use opposite constants; the next observations are swapped.
fake_obs = _two_rows(jnp.zeros, jnp.ones, obs_shp)
fake_actions = _two_rows(jnp.zeros, jnp.ones, act_shp)
fake_n_obs = _two_rows(jnp.ones, jnp.zeros, obs_shp)
fake_reward = jnp.full((2,1), 0.2)
fake_dones = jnp.full((2,1), 0)

# Show the shape of every field before building the batch.
for _fake_field in (fake_obs, fake_actions, fake_n_obs, fake_reward, fake_dones):
  print(_fake_field.shape)


transitions = Transitions(
  observations=fake_obs,
  actions=fake_actions,
  next_observations=fake_n_obs,
  rewards=fake_reward,
  dones=fake_dones,
)

# Apply the same batch repeatedly, threading the learner state through.
n_updates = 100
for i in range(n_updates):
  ls, logs = model.update_fn(ls, transitions)
  print(logs)

(2, 3)
(2, 1)
(2, 3)
(2, 1)
(2, 1)
{'loss_pi': DeviceArray(-13.3467655, dtype=float32), 'loss_q': DeviceArray(0.10719462, dtype=float32), 'loss_v': DeviceArray(1685610., dtype=float32)}
{'loss_pi': DeviceArray(-1361.3218, dtype=float32), 'loss_q': DeviceArray(8.9778805, dtype=float32), 'loss_v': DeviceArray(1.222655e+10, dtype=float32)}
{'loss_pi': DeviceArray(-116987.805, dtype=float32), 'loss_q': DeviceArray(2.6512635, dtype=float32), 'loss_v': DeviceArray(4.893285e+12, dtype=float32)}
{'loss_pi': DeviceArray(-2333868.2, dtype=float32), 'loss_q': DeviceArray(0.7675578, dtype=float32), 'loss_v': DeviceArray(2.8826857e+14, dtype=float32)}
{'loss_pi': DeviceArray(-17695148., dtype=float32), 'loss_q': DeviceArray(1.1345975, dtype=float32), 'loss_v': DeviceArray(4.85507e+15, dtype=float32)}
{'loss_pi': DeviceArray(-72423320., dtype=float32), 'loss_q': DeviceArray(1.0975128, dtype=float32), 'loss_v': DeviceArray(4.5915988e+16, dtype=float32)}
{'loss_pi': DeviceArray(-2.2247258e+08, dtype=f