In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
from gym import wrappers
import numpy as np
from itertools import count
from tqdm.notebook import tqdm
from my_utils import ReplayBuffer, construct_nn, Logger
from time import time
import os
import pathlib
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Load the TensorBoard notebook extension so the `%tensorboard` magic used later in the notebook is available.
%load_ext tensorboard

In [3]:
# Fix the PyTorch and NumPy RNGs so runs are repeatable.
# The Gym environment is seeded separately, right after it is created.
SEED = 11
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(SEED)

In [4]:
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
class Actor(nn.Module):
    '''
    Deterministic policy network: maps an observation to the action to take in the environment.

    The MLP is built by `construct_nn` with `nn.Tanh` as its output activation
    (the printed model summary shows a final Tanh layer), so actions lie in [-1, 1].

    Args:
        obs_dim (int): Size of the observation vector
        act_dim (int): Size of the action vector
        sizes (list of ints): Hidden layer widths, ordered from observation to action
    '''
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        layer_sizes = [obs_dim] + sizes + [act_dim]
        # The attribute name `pi` is part of the state_dict keys stored in checkpoints.
        self.pi = construct_nn(layer_sizes, nn.Tanh)

    def forward(self, obs):
        '''
        Map an observation to an action.

        Args:
            obs (torch.tensor): the environment state to act in

        Returns:
            (torch.tensor): the policy's action
        '''
        action = self.pi(obs)
        return action

In [6]:
class Critic(nn.Module):
    '''
    Q-function approximator: estimates how good an action is in a given observation.

    The observation and action are concatenated along the last dimension and passed
    through an MLP with a single output unit. That trailing unit dimension is squeezed away.

    Args:
        obs_dim (int): Size of the observation vector
        act_dim (int): Size of the action vector
        sizes (list of ints): Hidden layer widths, ordered from input to output
    '''
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        in_features = obs_dim + act_dim
        # The attribute name `q` is part of the state_dict keys stored in checkpoints.
        self.q = construct_nn([in_features] + sizes + [1])

    def forward(self, obs, act):
        '''
        Score an (observation, action) pair.

        Args:
            obs (torch.tensor): observation(s)
            act (torch.tensor): action(s), concatenated to `obs` along the last axis

        Returns:
            (torch.tensor): Q-value estimate(s) with the trailing size-1 dimension removed
        '''
        joint = torch.cat((obs, act), dim=-1)
        q_value = self.q(joint)
        return q_value.squeeze(-1)

In [7]:
class TD3(nn.Module):
    '''
    Container for the three TD3 networks: one Actor and two Critics.

    Two independently initialised critics let the training code take the smaller of their
    estimates, which gives a more conservative judgement of the Actor's actions.

    Args:
        obs_dim (int): Size of the observation vector
        act_dim (int): Size of the action vector
        sizes (list of ints): Hidden layer widths, ordered from input to output; shared by the Actor and both Critics
    '''
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        # Construction/registration order (actor, critic1, critic2) fixes both the RNG draws
        # made during initialisation and the order of parameters()/state_dict keys.
        self.actor = Actor(obs_dim, act_dim, sizes)
        self.critic1, self.critic2 = (Critic(obs_dim, act_dim, sizes) for _ in range(2))

    def act(self, obs):
        '''
        Return the Actor's action for `obs` without tracking gradients.

        Args:
            obs (torch.tensor): the state to act in

        Returns:
            (torch.tensor): the Actor's action
        '''
        with torch.no_grad():
            action = self.actor(obs)
        return action

    def criticize(self, obs, act):
        '''
        Score an action with both critics.

        Args:
            obs (torch.tensor): the state the action was taken in
            act (torch.tensor): the action

        Returns:
            (tuple of torch.tensor): (critic1 estimate, critic2 estimate)
        '''
        q1 = self.critic1(obs, act)
        q2 = self.critic2(obs, act)
        return q1, q2

Let's set some hyperparameters!
These mostly come from the TD3 paper: they are either the values recommended by the authors, or values well suited to this problem (such as the network size).

In [8]:
BATCH_SIZE = 100
GAMMA = 0.99
POLYAK = 0.995
POLICY_UPDATE_EVERY = 2
NETWORK_SIZES = [256, 256]
MAX_LEN = 1600
EPOCHS = 13 # Upper bound of EPOCHS * EPISODES_PER_EPOCH * MAX_LEN ≈ 1M env steps; episodes that end early make it fewer
EPISODES_PER_EPOCH = 50
RENDER_PER_EPOCH = 1 # Number of episodes to render every epoch. Footage is saved to the disk.
NOISE_LIMIT = 0.5
NOISE_STD = 0.2

total_steps = 0

def avg(l):
    '''Arithmetic mean of `l`; returns NaN for an empty sequence instead of raising ZeroDivisionError.'''
    if len(l) == 0:
        return float('nan')
    return sum(l) / len(l)

logger = Logger()
logger.add_attribute('ret', [max, avg])
logger.add_attribute('len', [max, avg])
logger.add_attribute('critics_loss', [avg])
logger.add_attribute('actor_loss', [avg])
logger.add_attribute('q1', [avg])
logger.add_attribute('q2', [avg])

# Every run gets its own directory (named after the start timestamp) holding the model
# snapshots taken at the end of every epoch and the video footage of the training process.
# TensorBoard logs go to a per-run subdirectory of TENSORBOARD_PATH instead.
RUN_ID = str(time())
RUN_PATH = f'./td3/{RUN_ID}'
# exist_ok=True already makes this a no-op for an existing directory. A regular *file* at
# RUN_PATH now raises FileExistsError here instead of failing later when writing into it.
pathlib.Path(RUN_PATH).mkdir(parents=True, exist_ok=True)

TENSORBOARD_PATH = './td3/tensorboard'
writer = SummaryWriter(f'{TENSORBOARD_PATH}/{RUN_ID}')

In [9]:
env = gym.make('BipedalWalker-v3')
# NOTE(review): `env.seed` is the classic Gym API (removed in gym>=0.26); confirm the pinned gym version.
env.seed(SEED)

# Flat sizes of the observation and action vectors (24 and 4 according to the printed model summary).
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# Per-dimension action bounds as float32 tensors on `dev`; used to clip noisy actions.
act_high, act_low = (
    torch.as_tensor(bound, dtype=torch.float32).to(dev)
    for bound in (env.action_space.high, env.action_space.low)
)



In [10]:
# Global flag consulted by the Monitor's video callback. It is meant to be set per episode
# inside the training loop to decide which episodes are rendered and saved to disk.
RENDER_THIS = False


def _record_this_episode(episode_id):
    '''Monitor video callback: ignores the episode index and defers to the current RENDER_THIS.'''
    # RENDER_THIS is looked up at call time, so changes made in the training loop take effect.
    return RENDER_THIS


env = wrappers.Monitor(env, f'{RUN_PATH}/videos/', video_callable=_record_this_episode)

In [11]:
# Replay buffer sized for roughly five epochs' worth of maximum-length episodes.
REPLAY_SIZE = 5 * EPISODES_PER_EPOCH * MAX_LEN
buffer = ReplayBuffer(obs_dim, act_dim, size=REPLAY_SIZE, dev=dev)

In [12]:
# The online network is trained directly; the target network trails it via Polyak averaging.
# Both are constructed the same way (online first), then the target copies the online
# weights so the two start out identical.
td3, td3_target = (TD3(obs_dim, act_dim, NETWORK_SIZES).to(dev) for _ in range(2))
td3_target.load_state_dict(td3.state_dict())
td3_target.eval()

TD3(
  (actor): Actor(
    (pi): Sequential(
      (0): Linear(in_features=24, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=4, bias=True)
      (5): Tanh()
    )
  )
  (critic1): Critic(
    (q): Sequential(
      (0): Linear(in_features=28, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
  (critic2): Critic(
    (q): Sequential(
      (0): Linear(in_features=28, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
)

In [13]:
# One Adam optimizer (default hyperparameters) per TD3 sub-network.
# The two critic optimizers are only ever stepped together, on the summed critic loss.
actor_optimizer, critic1_optimizer, critic2_optimizer = (
    optim.Adam(net.parameters()) for net in (td3.actor, td3.critic1, td3.critic2)
)

In [14]:
# Joint critic loss: both critics regress onto the same target values.
def compute_loss(q1, q2, q_exp):
    '''
    Sum of the mean-squared errors of both critics' estimates against a shared target.

    Args:
        q1 (torch.tensor): first critic's Q-value estimates
        q2 (torch.tensor): second critic's Q-value estimates
        q_exp (torch.tensor): target Q-values

    Returns:
        (torch.tensor): scalar loss
    '''
    loss_first = F.mse_loss(q1, q_exp)
    loss_second = F.mse_loss(q2, q_exp)
    return loss_first + loss_second

In [15]:
# Action selection with optional clipped Gaussian noise (std NOISE_STD, clipped to ±NOISE_LIMIT).
# NOTE(review): this same noise is used both for exploration (training loop) and for target
# policy smoothing (optimize). The TD3 paper uses unclipped N(0, 0.1) noise for exploration.
def select_action(obs, network, noisy=True):
    '''
    Choose an action with `network`'s actor, optionally perturbed by clipped Gaussian noise.

    Args:
        obs (torch.tensor): observation(s) to act on
        network (TD3): network whose `act` method supplies the deterministic action
        noisy (bool): if True, add noise ~ N(0, NOISE_STD) clipped to [-NOISE_LIMIT, NOISE_LIMIT],
            then clip the result to [act_low, act_high]

    Returns:
        (torch.tensor): the (possibly noisy) action
    '''
    action = network.act(obs)
    if not noisy:
        return action
    perturbation = torch.randn_like(action).mul(NOISE_STD).clamp(-NOISE_LIMIT, NOISE_LIMIT)
    return torch.min(torch.max(action + perturbation, act_low), act_high)

In [16]:
def _soft_update(online, target, polyak):
    '''
    Polyak-average the parameters of `online` into `target`, in place:
    p_target <- polyak * p_target + (1 - polyak) * p

    Args:
        online (nn.Module): network whose weights are tracked
        target (nn.Module): network being updated; must yield parameters() in the same
            order as `online` (true when both are instances of the same class)
        polyak (float): fraction of the old target weights that is kept
    '''
    with torch.no_grad():
        for p, p_target in zip(online.parameters(), target.parameters()):
            p_target.mul_(polyak)
            p_target.add_((1 - polyak) * p)


# Do not update the actor every single time: authors recommend doing it half as often
def optimize(update_actor=False):
    '''
    Perform one TD3 optimization step on a batch sampled from the replay buffer.

    The critics are updated on every call. The actor, and the Polyak update of all target
    networks, happen only when `update_actor` is truthy (intended to be once every
    POLICY_UPDATE_EVERY environment steps). Does nothing until the buffer holds at least
    BATCH_SIZE transitions.

    Args:
        update_actor (bool): also update the actor and soft-update the target networks
    '''
    if len(buffer) < BATCH_SIZE:
        return

    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew, next_obs, done = (
        batch[key] for key in ('obs', 'act', 'rew', 'next_obs', 'done')
    )

    with torch.no_grad():
        # Target policy smoothing: the target actor's action plus clipped noise.
        next_act = select_action(next_obs, td3_target)
        # Clipped double-Q: bootstrap from the smaller of the two target critics;
        # (1 - done) drops the bootstrap term for terminal transitions.
        q_exp = torch.min(*td3_target.criticize(next_obs, next_act)) * GAMMA * (1 - done) + rew

    # Regress both online critics onto the shared target.
    q1, q2 = td3.criticize(obs, act)
    critics_loss = compute_loss(q1, q2, q_exp)

    logger.put('critics_loss', critics_loss.item())
    logger.put('q1', q1.mean().item())
    logger.put('q2', q2.mean().item())

    # Separate optimizers, but always stepped together on the summed critic loss.
    critic1_optimizer.zero_grad()
    critic2_optimizer.zero_grad()
    critics_loss.backward()
    critic1_optimizer.step()
    critic2_optimizer.step()

    if not update_actor:
        return

    # Deterministic policy gradient: maximise critic1's estimate of the actor's actions.
    # backward() also fills critic1's .grad. That is harmless because the critic optimizers
    # zero their gradients before the next critic backward pass.
    actor_loss = -td3.critic1(obs, td3.actor(obs)).mean()
    logger.put('actor_loss', actor_loss.item())

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # Delayed target update, performed together with the actor update
    # (covers the actor and both critics, in registration order).
    _soft_update(td3, td3_target, POLYAK)

In [17]:
# Embed TensorBoard in the notebook. It reads every run logged under TENSORBOARD_PATH;
# each run's SummaryWriter writes to its own RUN_ID subdirectory.
%tensorboard --logdir $TENSORBOARD_PATH --host localhost

Reusing TensorBoard on port 6006 (pid 7340), started 0:00:55 ago. (Use '!kill 7340' to kill it.)

In [18]:
%%time
for epoch in range(EPOCHS):
    for episode in tqdm(range(EPISODES_PER_EPOCH), desc=f'[{epoch}]'):
        # Do not render all the episode, as it would take too much time:
        # instead, only render $RENDER_PER_EPOCH ones every epoch
#         RENDER_THIS = True if episode % (EPISODES_PER_EPOCH // RENDER_PER_EPOCH) == 0 else False
        
        obs = torch.as_tensor(env.reset(), dtype=torch.float32).to(dev)
        ep_ret = 0
        ep_len = 0        
        for t in count():
            act = select_action(obs, td3)
            
            next_obs, rew, done, _ = env.step(act.cpu().numpy())
            next_obs = torch.as_tensor(next_obs, dtype=torch.float32).to(dev)
            done = False if ep_len == MAX_LEN else done

            ep_ret += rew
            ep_len += 1
            total_steps += 1

            if RENDER_THIS:
                env.render()
            
            buffer.put(obs, act, rew, next_obs, done)
            obs = next_obs
            
            optimize(total_steps % POLICY_UPDATE_EVERY)
            
            if done or ep_len == MAX_LEN:
                break
        logger.put('ret', ep_ret)
        logger.put('len', ep_len)

        tb_data = logger.summarize(attributes=['critics_loss', 'actor_loss', 'q1', 'q2'], fmt=False)
        ep_id = epoch * EPISODES_PER_EPOCH + episode
        for (attr, val) in tb_data:
            writer.add_scalar(attr, val, ep_id)
        writer.add_scalar('return', ep_ret, ep_id)
        writer.add_scalar('length', ep_len, ep_id)
        writer.flush()

    print(f'[{epoch}] {logger.summarize(attributes=["ret", "len"])}')
    torch.save(td3.state_dict(), f'{RUN_PATH}/td3-epoch-{epoch}.pt')
env.close()

HBox(children=(FloatProgress(value=0.0, description='[0]', max=50.0, style=ProgressStyle(description_width='in…


[0] ret_max=-104.4224; ret_avg=-122.3216; len_max=1600.0000; len_avg=171.2000


HBox(children=(FloatProgress(value=0.0, description='[1]', max=50.0, style=ProgressStyle(description_width='in…


[1] ret_max=-83.6791; ret_avg=-115.2220; len_max=1102.0000; len_avg=174.3000


HBox(children=(FloatProgress(value=0.0, description='[2]', max=50.0, style=ProgressStyle(description_width='in…


[2] ret_max=-43.3383; ret_avg=-116.3471; len_max=1600.0000; len_avg=693.5600


HBox(children=(FloatProgress(value=0.0, description='[3]', max=50.0, style=ProgressStyle(description_width='in…


[3] ret_max=244.2094; ret_avg=-13.2133; len_max=1600.0000; len_avg=888.5200


HBox(children=(FloatProgress(value=0.0, description='[4]', max=50.0, style=ProgressStyle(description_width='in…


[4] ret_max=277.3995; ret_avg=138.6303; len_max=1554.0000; len_avg=910.4400


HBox(children=(FloatProgress(value=0.0, description='[5]', max=50.0, style=ProgressStyle(description_width='in…


[5] ret_max=290.4075; ret_avg=122.8555; len_max=1133.0000; len_avg=691.5200


HBox(children=(FloatProgress(value=0.0, description='[6]', max=50.0, style=ProgressStyle(description_width='in…


[6] ret_max=292.8721; ret_avg=148.2164; len_max=996.0000; len_avg=686.3800


HBox(children=(FloatProgress(value=0.0, description='[7]', max=50.0, style=ProgressStyle(description_width='in…


[7] ret_max=293.1718; ret_avg=149.7821; len_max=968.0000; len_avg=676.2400


HBox(children=(FloatProgress(value=0.0, description='[8]', max=50.0, style=ProgressStyle(description_width='in…


[8] ret_max=296.9715; ret_avg=245.9381; len_max=931.0000; len_avg=806.0400


HBox(children=(FloatProgress(value=0.0, description='[9]', max=50.0, style=ProgressStyle(description_width='in…


[9] ret_max=301.2549; ret_avg=197.1258; len_max=890.0000; len_avg=696.8800


HBox(children=(FloatProgress(value=0.0, description='[10]', max=50.0, style=ProgressStyle(description_width='i…


[10] ret_max=302.2743; ret_avg=207.4384; len_max=877.0000; len_avg=695.7000


HBox(children=(FloatProgress(value=0.0, description='[11]', max=50.0, style=ProgressStyle(description_width='i…


[11] ret_max=305.1095; ret_avg=226.9158; len_max=856.0000; len_avg=713.6400


HBox(children=(FloatProgress(value=0.0, description='[12]', max=50.0, style=ProgressStyle(description_width='i…


[12] ret_max=304.6622; ret_avg=280.2831; len_max=896.0000; len_avg=793.3600


HBox(children=(FloatProgress(value=0.0, description='[13]', max=50.0, style=ProgressStyle(description_width='i…


[13] ret_max=302.7494; ret_avg=289.1408; len_max=1106.0000; len_avg=852.1400
