In [10]:
!pip install stable_baselines3[extra]
!pip install gymnasium-robotics



In [11]:
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
import torch
from torch import Tensor
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from tqdm.notebook import tqdm

from typing import Tuple, List, Dict

In [12]:
class ReplayBuffer:
  def __init__(self, max_len: int, n_envs: int, state_size: int, act_size: int, device: str):
    self.device = device
    self.max_len = max_len
    self.n_envs = n_envs
    self.state_size = state_size
    self.act_size = act_size
    self.buffer = []

  def add_frame(self, state, next_state, action, reward, finished) -> None:
    mask = 1 - finished
    self.buffer.append((state, next_state, action, reward, mask))

  def random_sample(self, batch_size: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    states = torch.zeros(batch_size, self.n_envs, self.state_size, device=self.device)
    next_states = torch.zeros(batch_size, self.n_envs, self.state_size, device=self.device)
    actions = torch.zeros(batch_size, self.n_envs, self.act_size, device=self.device)
    rewards = torch.zeros(batch_size, self.n_envs, device=self.device)
    mask = torch.zeros(batch_size, self.n_envs, device=self.device)
    sample_tuple = (states, next_states, actions, rewards, mask)

    samples = np.random.randint(0, len(self.buffer), size=(batch_size,))

    for i, s in enumerate(samples):
      for t in range(len(sample_tuple)):
        sample_tuple[t][i] = self.buffer[s][t]

    states = sample_tuple[0].reshape(batch_size * self.n_envs, self.state_size)
    next_states = sample_tuple[1].reshape(batch_size * self.n_envs, self.state_size)
    actions = sample_tuple[2].reshape(batch_size * self.n_envs, self.act_size)
    rewards = sample_tuple[3].reshape(batch_size * self.n_envs)
    mask = sample_tuple[4].reshape(batch_size * self.n_envs)

    return states, next_states, actions, rewards, mask


In [13]:
class MLP(nn.Module):
  def __init__(self, input_size: int, net_arch: List[int], activation: nn.Module):
    super(MLP, self).__init__()
    layers = [nn.Linear(input_size, net_arch[0]), activation()]
    for i in range(len(net_arch) - 1):
      layers.append(nn.Linear(net_arch[i], net_arch[i+1]))
      layers.append(activation())
    self.model = nn.Sequential(*layers)

  def forward(self, x: Tensor) -> Tensor:
    return self.model(x)

class Actor(nn.Module):
  def __init__(self, input_size: int, output_size: int, net_arch: List[int], activation: nn.Module):
    super(Actor, self).__init__()
    self.base = MLP(input_size, net_arch, activation)
    self.mu_head = nn.Linear(net_arch[-1], output_size)
    self.sigma_head = nn.Linear(net_arch[-1], output_size)

  def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
    x = self.base(x)
    mu = self.mu_head(x)
    log_sigma = self.sigma_head(x)
    return mu, log_sigma

  def eval_state(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    mu, log_sigma = self.forward(x)
    sigma = log_sigma.exp()
    pi_s = torch.distributions.Normal(mu, sigma)
    action = pi_s.rsample()
    bounded_action = torch.tanh(action) # shift the result to valid region
    log_prob = pi_s.log_prob(action).sum(axis=-1)
    tanh_fix = 2 * (np.log(2) - action - nn.functional.softplus(-2*action))
    log_prob -= tanh_fix.sum(axis=1)
    return bounded_action, log_prob

  def get_action(self, x: Tensor) -> Tensor:
    mu, log_sigma = self.forward(x)
    sigma = log_sigma.exp()
    pi_s = torch.distributions.Normal(mu, sigma)
    action = pi_s.sample()
    return torch.tanh(action)


class QNet(nn.Module):
  def __init__(self, input_size: int, net_arch: List[int], activation: nn.Module):
    super(QNet, self).__init__()
    self.base = MLP(input_size, net_arch, activation)
    self.value_head = nn.Linear(net_arch[-1], 1)

  def forward(self, s: Tensor, a: Tensor) -> Tensor:
    x = torch.cat((s, a), dim = 1)
    x = self.base(x)
    x = self.value_head(x)
    return x

In [14]:
class SAC(nn.Module):
  def __init__(self, obs_size: int, act_size: int, lr: float, gamma: float, tau: float, device: str):
    super(SAC, self).__init__()
    self.activation = nn.ReLU
    self.actor = Actor(obs_size, act_size, [64, 64, 64, 64], self.activation)
    self.Q1 = QNet(obs_size + act_size, [64, 64, 64], self.activation)
    self.Q2 = QNet(obs_size + act_size, [64, 64, 64], self.activation)
    self.critic1 = QNet(obs_size + act_size, [64, 64, 64], self.activation)
    self.critic2 = QNet(obs_size + act_size, [64, 64, 64], self.activation)
    self.log_ent_coe = torch.tensor([1.0], requires_grad=True, device=device)

    self.actor_optim = Adam(self.actor.parameters(), lr)
    self.Q1_optim = Adam(self.Q1.parameters(), lr)
    self.Q2_optim = Adam(self.Q2.parameters(), lr)
    self.Ent_optim = Adam([self.log_ent_coe], lr)

    self.gamma = gamma
    self.max_grad = 1
    self.tau = tau
    self.target_ent = -np.log(act_size)
    self.device = device

    self.dict_to_vec = lambda state : torch.from_numpy(
        np.concat([state['achieved_goal'], state['desired_goal'], state['observation']], axis = -1)
        )

  def forward(self, x: Tensor) -> Tensor:
    raise NotImplementedError()

  @torch.no_grad()
  def rollout(self, env, max_step: int, buffer: ReplayBuffer) -> None:
    state = env.reset()
    state = self.dict_to_vec(state)
    scores = None
    success = 0
    for t in range(max_step):
      action, log_prob = self.actor.eval_state(state)
      next_state, reward, terminated, info = env.step(action.detach().numpy())
      next_state = self.dict_to_vec(next_state)
      termination = np.array([info_n['is_success'] for info_n in info])
      if any(terminated):
        success += np.sum(termination)
        state = self.dict_to_vec(env.reset())
      termination_mask = torch.from_numpy(1 - termination)
      buffer.add_frame(state, next_state, action.detach(), torch.from_numpy(reward), termination_mask)
      state = next_state

      if scores is None:
        scores = reward
      else:
        scores += reward
    return scores.mean(), success

  def get_q(self, state: Tensor, action: Tensor) -> Tensor:
    q1 = self.Q1(state, action)
    q2 = self.Q2(state, action)
    return torch.min(q1, q2)

  def step_optim(self, loss, model_name: str, logger: Dict | None) -> None:
    optimizer = getattr(self, model_name + "_optim")
    optimizer.zero_grad()
    loss.backward()
    # model = getattr(self, model_name)
    # nn.utils.clip_grad_norm_(model.parameters(), self.max_grad)
    optimizer.step()
    if logger is not None:
      logger[model_name + "_loss"] = loss.item()

  def polyak_average(self, model, target_model) -> None:
    for p, target_p in zip(model.parameters(), target_model.parameters()):
      target_p.data.copy_((1 - self.tau) * target_p + self.tau * p)

  def update(self, sample: Tuple) -> None:
    log_info = {}

    state, next_state, action, reward, mask = sample
    new_action, log_prob = self.actor.eval_state(state)

    ent_loss = -self.log_ent_coe * (log_prob + self.target_ent).detach()
    ent_loss = ent_loss.mean()
    self.step_optim(ent_loss, "Ent", log_info)

    ent_coe = torch.exp(self.log_ent_coe.detach())
    with torch.no_grad():
      next_action, next_log_prob = self.actor.eval_state(next_state)
      next_q1 = self.critic1(next_state, next_action)
      next_q2 = self.critic2(next_state, next_action)
      next_q = torch.min(next_q1, next_q2) - ent_coe * next_log_prob.unsqueeze(1)
      target_q = reward + self.gamma * mask * next_q

    Q1_value = self.Q1(state, action)
    Q2_value = self.Q2(state, action)
    Q1_loss = (Q1_value - target_q.detach()).pow(2).mean()
    Q2_loss = (Q2_value - target_q.detach()).pow(2).mean()
    self.step_optim(Q1_loss, "Q1", log_info)
    self.step_optim(Q2_loss, "Q2", log_info)

    new_q = self.get_q(state, new_action)
    actor_loss = (ent_coe * log_prob.unsqueeze(1) + new_q).mean()
    self.step_optim(actor_loss, "actor", log_info)

    self.polyak_average(self.Q1, self.critic1)
    self.polyak_average(self.Q2, self.critic2)

    return log_info

In [15]:
device = 'cpu'

def train(model, env, buffer, episodes: int, max_steps: int,
          batch_size: int, update_steps: int, gamma: float, print_per_epi: int, device: str) -> None:
  scores = []
  losses = []
  success_rate = 0
  for epi in tqdm(range(episodes)):
    score, success = model.rollout(env, max_steps, buffer)
    scores.append(score)
    success_rate += success

    Q1_losses = []
    Q2_losses = []
    actor_losses = []
    for _ in range(update_steps):
      log_info = model.update(buffer.random_sample(batch_size))
      Q1_losses.append(log_info['Q1_loss'])
      Q2_losses.append(log_info['Q2_loss'])
      actor_losses.append(log_info['actor_loss'])
    losses.append({
        "Q1_loss": np.mean(Q1_losses),
        "Q2_loss": np.mean(Q2_losses),
        "actor_loss": np.mean(actor_losses)
        })

    if epi % print_per_epi == 0:
      L_q1 = 0
      L_q2 = 0
      L_actor = 0
      n = len(losses)
      for i in losses:
        L_q1 += i["Q1_loss"]
        L_q2 += i["Q2_loss"]
        L_actor += i["actor_loss"]
      print(f"mean score is {np.mean(scores)}, success_rate is {success_rate / (max_steps * print_per_epi)}")
      print(f"Q1_loss is {L_q1 / n:.5f}, Q2_loss is {L_q2 / n:.5f}, actor_loss is {L_actor / n:.5f}")
      scores = []
      losses = []


In [16]:
n_envs = 4
env_name = "FetchReach-v4"
eval_env = gym.make(env_name, render_mode="rgb_array")
print(eval_env.observation_space.sample())
print(eval_env.action_space.sample())
print("--------------------------")
obs_size = 16
act_size = 4

seed = 0
env = VecNormalize(make_vec_env(env_name, n_envs, seed = seed))
state = env.reset()
print(state)
def dict_to_vec(state):
  return torch.from_numpy(np.concat([state['achieved_goal'], state['desired_goal'], state['observation']], axis = -1))
print(dict_to_vec(state))

{'achieved_goal': array([ 3.61161903,  0.45822068, -0.54496972]), 'desired_goal': array([-0.56726475, -0.46459048, -1.09511917]), 'observation': array([ 0.68497806, -1.05098556, -1.88148002, -1.61562953, -1.3532234 ,
        1.33078812,  1.25619465, -1.42892526, -0.38106066, -1.19690965])}
[-0.39724004 -0.31798786 -0.54086494  0.19464599]
--------------------------
OrderedDict({'achieved_goal': array([[0.00400882, 0.00299729, 0.00235733],
       [0.00400882, 0.00299729, 0.00235733],
       [0.00400882, 0.00299729, 0.00235733],
       [0.00400882, 0.00299729, 0.00235733]], dtype=float32), 'desired_goal': array([[ 1.2169112 , -0.5693924 , -1.1366205 ],
       [ 0.6380136 ,  1.7233865 , -0.8499632 ],
       [-0.5194549 , -0.47270492,  1.0115223 ],
       [-1.3334591 , -0.68046874,  0.97554284]], dtype=float32), 'observation': array([[ 4.00882121e-03,  2.99728918e-03,  2.35732808e-03,
         1.94986319e-06,  5.85004898e-08, -2.63015956e-08,
        -4.23140356e-10,  1.23087375e-07,  3.32

In [17]:
lr = 5e-4
gamma = 0.99
tau = 0.1
n_episodes = 1000
max_steps = 100
update_steps = 50
batch_size = 4

print_per_epi = 10

if seed:
  torch.manual_seed(seed)
  np.random.seed(seed)

model = SAC(obs_size, act_size, lr, gamma, tau, device)
buffer = ReplayBuffer(max_steps, n_envs, obs_size, act_size, device)
train(model, env, buffer, n_episodes, max_steps, batch_size, update_steps, gamma, print_per_epi, device)

  0%|          | 0/1000 [00:00<?, ?it/s]

mean score is -27.88772964477539, success_rate is 0.0
Q1_loss is 4.28689, Q2_loss is 4.30449, actor_loss is -7.45056
mean score is -8.746698379516602, success_rate is 0.0010000000474974513
Q1_loss is 2.55527, Q2_loss is 2.55724, actor_loss is -6.54660
mean score is -8.825740814208984, success_rate is 0.0020000000949949026
Q1_loss is 2.23792, Q2_loss is 2.23857, actor_loss is -5.09553
mean score is -8.793031692504883, success_rate is 0.0020000000949949026
Q1_loss is 1.45892, Q2_loss is 1.45990, actor_loss is -4.05008
mean score is -8.745903015136719, success_rate is 0.003000000026077032
Q1_loss is 0.91136, Q2_loss is 0.91143, actor_loss is -3.18464
mean score is -8.744513511657715, success_rate is 0.003000000026077032
Q1_loss is 0.94709, Q2_loss is 0.94761, actor_loss is -2.51247
mean score is -8.778531074523926, success_rate is 0.004999999888241291
Q1_loss is 0.99066, Q2_loss is 0.99125, actor_loss is -2.00470
mean score is -8.782065391540527, success_rate is 0.004999999888241291
Q1_lo

KeyboardInterrupt: 

In [None]:
import imageio


images = []

state, _ = eval_env.reset()
state = dict_to_vec(state).to(torch.float)
images.append(eval_env.render())

for i in range(50):
  action = model.actor.get_action(state)
  state, reward, terminated, truncated, _ = eval_env.step(action.detach().numpy())
  state = dict_to_vec(state).to(torch.float)
  # images.append(eval_env.render())

imageio.mimsave("./result.gif", images)

In [18]:
import stable_baselines3 as sb
from stable_baselines3.sac.policies import MultiInputPolicy

env_name = "FetchReach-v4"
model = sb.SAC(MultiInputPolicy, env_name, verbose=1)
model.learn(total_timesteps=50000, log_interval=10)

Using cpu device
Creating environment from the given name 'FetchReach-v4'
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -49.2    |
|    success_rate    | 0.1      |
| time/              |          |
|    episodes        | 10       |
|    fps             | 35       |
|    time_elapsed    | 13       |
|    total_timesteps | 500      |
| train/             |          |
|    actor_loss      | -6.24    |
|    critic_loss     | 0.0853   |
|    ent_coef        | 0.887    |
|    ent_coef_loss   | -0.804   |
|    learning_rate   | 0.0003   |
|    n_updates       | 399      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -49.4    |
|    success_rate    | 0.05     |
| time/              |          |
|    episodes        | 20       |

KeyboardInterrupt: 