Adapted from CleanRL's SAC (continuous action) implementation: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py

In [1]:
import gym
import gym.wrappers.monitor
import gym_cartpole_swingup
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
from itertools import count
from IPython.display import Video, Image
from collections import deque
from typing import Tuple, List
from dataclasses import dataclass

## Configuration

In [2]:
# ---- config ----
# bounds for the actor's log standard deviation (its raw output is squashed into this range)
LOG_STD_MIN = -5
LOG_STD_MAX = 2
# learning rate (the optimizer cells also reuse it for the actor and the temperature)
Q_LEARNING_RATE = 1e-3
# replay buffer: capacity, and how many random-action transitions to collect before learning
MEMORY_SIZE = 100_000
MEMORY_WARMUP = 500
BATCH_SIZE = 64
# discount factor
GAMMA = 0.99
# update schedule (in environment steps) and Polyak factor for the target critics
POLICY_UPDATE_FREQUENCY = 2
TARGET_UPDATE_FREQUENCY = 1
SOFT_UPDATE_TAU = 5e-3

## Create environment

In [3]:
# CartPoleSwingUp-v1 is presumably registered as a side effect of `import gym_cartpole_swingup`
# (that module is not referenced anywhere else) — confirm if the env id is not found
env = gym.make("CartPoleSwingUp-v1")

In [4]:
# The actor rescales its tanh output with action_space.low/high, which only a Box space has.
# Raise explicitly instead of using `assert`, which is stripped when Python runs with -O.
if not isinstance(env.action_space, gym.spaces.Box):
    raise TypeError("only continuous action space is supported")

In [5]:
# preview env: record one episode of random actions to a video (skipped if already recorded)
fname = "cartpole-swingup-random.mp4"
if not os.path.exists(fname):
    recorder = gym.wrappers.monitor.video_recorder.VideoRecorder(env, fname, enabled=True)
    finished = False
    try:
        env.reset()
        recorder.capture_frame()
        for _ in tqdm(count()):
            action = env.action_space.sample()
            _, _, done, _ = env.step(action)
            recorder.capture_frame()
            if done:
                break
        finished = True
    finally:
        # Always release the recorder. If recording was interrupted (error or KeyboardInterrupt),
        # delete the partial video so the os.path.exists() check above does not keep reusing it.
        recorder.close()
        if not finished and os.path.exists(fname):
            os.remove(fname)
Video(fname)

In [6]:
# observation shape, then action shape, one per line
print(env.observation_space.shape, env.action_space.shape, sep="\n")

(5,)
(1,)


## Memory

In [7]:
@dataclass
class Transition:
    """One environment step as stored in the replay buffer."""

    observation: torch.Tensor  # observation the action was taken from
    # scalar action: the training loop stores `action.item()`, i.e. a Python float,
    # so only single-dimensional action spaces fit this record as used
    action: float
    reward: float
    next_observation: torch.Tensor  # observation returned by env.step
    terminal: bool  # when True, the critic target does not bootstrap from next_observation


@dataclass
class TransitionBatch:
    """A sampled minibatch of transitions, collated into tensors with the batch on dim 0."""

    observations: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    next_observations: torch.Tensor
    terminals: torch.Tensor  # bool flags, one per transition

In [8]:
# replay buffer: a bounded FIFO; once MEMORY_SIZE items are held, appending evicts the oldest
memory = deque([], MEMORY_SIZE)

In [9]:
def sample_memory(batch_size: int) -> TransitionBatch:
    """Draw ``batch_size`` distinct transitions uniformly from ``memory`` and collate them.

    Observations are stacked along a new leading dimension. Actions and rewards become
    float32 tensors and terminal flags a bool tensor. When the stored action/reward values
    are Python scalars, each of these has shape (batch_size,).

    Raises:
        ValueError: (from ``random.sample``) if ``memory`` holds fewer than ``batch_size`` items.
    """
    # `memory` is only read here, never rebound, so no `global` declaration is needed
    picked: List[Transition] = random.sample(memory, batch_size)

    def as_column(values, dtype):
        # gather one scalar field across the picked transitions into a 1-D tensor
        return torch.as_tensor(list(values), dtype=dtype)

    return TransitionBatch(
        observations=torch.stack([t.observation for t in picked]),
        actions=as_column((t.action for t in picked), torch.float32),
        rewards=as_column((t.reward for t in picked), torch.float32),
        next_observations=torch.stack([t.next_observation for t in picked]),
        terminals=as_column((t.terminal for t in picked), torch.bool),
    )

## Create model

In [10]:
class SoftQNetwork(nn.Module):
    """State-action value critic Q(s, a): a 3-layer MLP over the concatenated obs and action."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__()
        obs_dim = math.prod(env.observation_space.shape)
        act_dim = math.prod(env.action_space.shape)
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        """Score actions ``a`` taken in observations ``x``.

        The inputs are concatenated along dim 1. For 2-D (batch, features) inputs the result
        has shape (batch, 1).
        """
        hidden = torch.cat((x, a), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

In [11]:
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy mapping observations to actions within the Box bounds."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__()
        obs_dim = math.prod(env.observation_space.shape)
        act_dim = math.prod(env.action_space.shape)
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, act_dim)
        self.fc_logstd = nn.Linear(256, act_dim)
        # Affine map from tanh's (-1, 1) onto [low, high]. Stored as buffers so the values
        # move with the module and appear in its state_dict.
        low, high = env.action_space.low, env.action_space.high
        self.register_buffer("action_scale", torch.tensor((high - low) / 2.))
        self.register_buffer("action_bias", torch.tensor((high + low) / 2.))

    def forward(self, x):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian.

        ``log_std`` is kept within [LOG_STD_MIN, LOG_STD_MAX].
        """
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        mean = self.fc_mean(features)
        # tanh bounds the raw output to (-1, 1); rescale it affinely onto the log-std range
        unit = torch.tanh(self.fc_logstd(features))
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (unit + 1)
        return mean, log_std

    def get_action(self, x: torch.Tensor):
        """Sample a reparameterized action for observations ``x`` (batch on dim 0).

        Returns:
            action: sampled action rescaled to the action bounds.
            log_prob: log-density of ``action`` summed over action dims (keepdim, so
                (batch, 1) for 2-D input), including the tanh/rescale correction.
            mean: deterministic action, i.e. the squashed and rescaled Gaussian mean.
        """
        mean, log_std = self(x)
        dist = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = dist.rsample()  # reparameterized, so gradients flow to mean/std
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias
        # Change of variables for a = scale * tanh(u) + bias: subtract log|da/du| per dim.
        # The 1e-6 keeps the log finite when tanh saturates at +-1.
        log_jacobian = torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6)
        log_prob = (dist.log_prob(pre_tanh) - log_jacobian).sum(1, keepdim=True)
        deterministic = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, deterministic

In [12]:
# weight update functions
# adapted from https://github.com/ghliu/pytorch-ddpg/blob/master/util.py#L26
def soft_update(target, source, tau):
    """Polyak-average ``source``'s state into ``target`` in place.

    Each floating-point entry of ``target.state_dict()`` becomes
    ``target * (1 - tau) + source * tau``. Non-floating entries cannot be blended; one example
    is BatchNorm's integer ``num_batches_tracked`` counter. Blending such an entry and writing
    it back would silently truncate it, so these entries are copied from ``source`` instead.

    Args:
        target: module whose parameters and buffers are updated in place.
        source: module with the same state_dict keys whose values are blended in.
        tau: interpolation weight given to ``source`` (1.0 copies it outright).

    Raises:
        ValueError: if the two modules' state_dict keys differ.
    """
    # state_dict() rather than parameters() so buffers (e.g. BatchNorm running stats) are
    # included; its tensors share storage with the module, so copy_ updates the module itself.
    target_state = target.state_dict()
    source_state = source.state_dict()
    if target_state.keys() != source_state.keys():
        raise ValueError("soft_update: target and source have different state_dict keys")
    with torch.no_grad():
        for key, target_value in target_state.items():
            source_value = source_state[key]
            if target_value.is_floating_point():
                target_value.copy_(target_value * (1.0 - tau) + source_value * tau)
            else:
                target_value.copy_(source_value)


def hard_update(target, source):
    """Copy every parameter and buffer of ``source`` into ``target``."""
    target.load_state_dict(source.state_dict())

In [13]:
# policy network
actor = Actor(env)
# twin soft Q-critics (the loop takes their minimum) plus slowly-tracking target copies;
# construction order is kept: actor, qf1, qf2, qf1_target, qf2_target
qf1, qf2 = SoftQNetwork(env), SoftQNetwork(env)
qf1_target, qf2_target = SoftQNetwork(env), SoftQNetwork(env)
for target_net, online_net in ((qf1_target, qf1), (qf2_target, qf2)):
    # each target starts as an exact copy of its online critic
    target_net.load_state_dict(online_net.state_dict())
# a single optimizer drives both critics
q_optimizer = torch.optim.Adam([*qf1.parameters(), *qf2.parameters()], lr=Q_LEARNING_RATE)
# NOTE(review): the actor reuses Q_LEARNING_RATE; the referenced CleanRL script exposes a
# separate policy learning rate — confirm sharing it is intended
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=Q_LEARNING_RATE)

In [14]:
# setup autotune: the entropy temperature is learned as alpha = exp(log_alpha)
# target entropy heuristic: minus the number of action dimensions
target_entropy = -float(math.prod(env.action_space.shape))
log_alpha = torch.zeros(1, requires_grad=True)  # alpha starts at exp(0) = 1
alpha_optimizer = torch.optim.Adam([log_alpha], lr=Q_LEARNING_RATE)
# plain-float snapshot of alpha used in the losses; refreshed after every alpha step
alpha = log_alpha.exp().item()

## Train

In [15]:
obs = torch.as_tensor(env.reset(), dtype=torch.float32)
actor_loss = 0.0  # displayed in the progress bar until the first policy update
with tqdm() as pbar:
    for step in count():
        # ---- act ----
        if len(memory) < MEMORY_WARMUP:
            # warm-up: fill the replay buffer with uniformly random actions
            action = env.action_space.sample()
        else:
            # get_action works on batches: add a batch dim for the single observation, then
            # drop it so env.step receives an action shaped like action_space (a (1, act_dim)
            # action is the likely cause of the odd next-obs shape handled below)
            with torch.no_grad():
                action, _, _ = actor.get_action(obs.unsqueeze(0))
            action = action.squeeze(0).cpu().numpy()

        env.render()

        next_obs, reward, done, info = env.step(action)
        # CartPoleSwingUp-v1 was seen returning an oddly shaped next obs; the reshape is kept
        # as a guard and is a no-op when the shape already matches observation_space
        next_obs = torch.as_tensor(next_obs, dtype=torch.float32).reshape(env.observation_space.shape)

        # A time-limit cut-off is not a real terminal state, so the critic must still bootstrap
        # from it; gym's TimeLimit wrapper flags it via info["TimeLimit.truncated"].
        truncated = bool(info.get("TimeLimit.truncated", False))
        memory.append(Transition(
            observation=obs,
            action=action.item(),  # scalar action: only single-dimensional action spaces fit
            reward=float(reward),  # accepts Python floats, NumPy scalars and size-1 arrays
            next_observation=next_obs,
            terminal=bool(done) and not truncated,
        ))

        # the episode ends on any `done`, truncated or not
        if done:
            obs = torch.as_tensor(env.reset(), dtype=torch.float32)
        else:
            obs = next_obs

        if len(memory) < MEMORY_WARMUP:
            pbar.set_description("warmup")
            pbar.total = MEMORY_WARMUP
        else:
            pbar.total = None
            batch = sample_memory(BATCH_SIZE)

            # ---- critic update: regress both Q-nets onto the soft Bellman target ----
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _ = actor.get_action(batch.next_observations)
                qf1_next_target = qf1_target(batch.next_observations, next_state_actions)
                qf2_next_target = qf2_target(batch.next_observations, next_state_actions)
                # pessimistic (min) target value minus the entropy-weighted log-probability
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                # terminal transitions contribute only their reward
                next_q_value = batch.rewards.unsqueeze(1) + (((1. - batch.terminals.float()) * GAMMA).unsqueeze(1) * min_qf_next_target)
            qf1_a_values = qf1(batch.observations, batch.actions.unsqueeze(1))
            qf2_a_values = qf2(batch.observations, batch.actions.unsqueeze(1))
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()

            # ---- delayed actor and temperature updates ----
            if step % POLICY_UPDATE_FREQUENCY == 0:
                # make up for the delay by running POLICY_UPDATE_FREQUENCY updates at once
                for _ in range(POLICY_UPDATE_FREQUENCY):
                    pi, log_pi, _ = actor.get_action(batch.observations)
                    qf1_pi = qf1(batch.observations, pi)
                    qf2_pi = qf2(batch.observations, pi)
                    # keep the (batch, 1) shape shared with log_pi: flattening this made the
                    # subtraction below broadcast into a (batch, batch) matrix
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = (alpha * log_pi - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    # autotune: move alpha so the policy's log-prob tracks -target_entropy
                    with torch.no_grad():
                        _, log_pi, _ = actor.get_action(batch.observations)
                    alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()

                    alpha_optimizer.zero_grad()
                    alpha_loss.backward()
                    alpha_optimizer.step()
                    alpha = log_alpha.exp().item()

            # ---- Polyak-average the target critics ----
            if step % TARGET_UPDATE_FREQUENCY == 0:
                soft_update(qf1_target, qf1, SOFT_UPDATE_TAU)
                soft_update(qf2_target, qf2, SOFT_UPDATE_TAU)

            pbar.set_description(f"reward: {float(reward):.3f}, actor_loss: {float(actor_loss):.3f}, qf1_loss: {qf1_loss.item():.3f}, qf2_loss: {qf2_loss.item():.3f}, alpha: {alpha:.3f}")
        pbar.update()

0it [00:00, ?it/s]

KeyboardInterrupt: 