In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
from gym import wrappers
import numpy as np
from itertools import count
from tqdm.notebook import tqdm
from my_utils import ReplayBuffer, construct_nn, Logger
from time import time
import os

In [None]:
SEED = 11
# Seed both global RNGs used by this notebook (PyTorch and NumPy) with the same value.
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(SEED)

In [None]:
# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Actor(nn.Module):
    """Policy network: maps an observation to an action via the `pi` sub-network."""

    def __init__(self, obs_dim, act_dim, sizes):
        """Build the policy sub-network.

        Args:
            obs_dim: size of the observation vector (first layer width).
            act_dim: size of the action vector (last layer width).
            sizes: list of intermediate layer widths.
        """
        super().__init__()

        layer_dims = [obs_dim] + sizes + [act_dim]
        # NOTE(review): nn.Tanh is presumably used by construct_nn as the output
        # activation (bounding actions to [-1, 1]) -- confirm in my_utils.
        self.pi = construct_nn(layer_dims, nn.Tanh)

    def forward(self, obs):
        """Return the output of the policy sub-network for `obs`."""
        action = self.pi(obs)
        return action

In [None]:
class Critic(nn.Module):
    """Q-network: scores an (observation, action) pair with the `q` sub-network."""

    def __init__(self, obs_dim, act_dim, sizes):
        """Build the Q sub-network.

        Args:
            obs_dim: size of the observation vector.
            act_dim: size of the action vector.
            sizes: list of intermediate layer widths.
        """
        super().__init__()

        # The input is the observation and action concatenated; the output has width 1.
        layer_dims = [obs_dim + act_dim] + sizes + [1]
        self.q = construct_nn(layer_dims)

    def forward(self, obs, act):
        """Return the Q sub-network output for `obs` and `act`, last axis squeezed away."""
        joint_input = torch.cat([obs, act], dim=-1)
        q_value = self.q(joint_input)
        return q_value.squeeze(-1)

In [None]:
class TD3(nn.Module):
    """Container holding one actor and two critics."""

    def __init__(self, obs_dim, act_dim, sizes):
        """Create the actor and both critics with the same dimensions.

        Args:
            obs_dim: size of the observation vector.
            act_dim: size of the action vector.
            sizes: list of intermediate layer widths shared by all three networks.
        """
        super().__init__()

        self.actor = Actor(obs_dim, act_dim, sizes)
        self.critic1 = Critic(obs_dim, act_dim, sizes)
        self.critic2 = Critic(obs_dim, act_dim, sizes)

    @torch.no_grad()
    def act(self, obs):
        """Return the actor's output for `obs`, computed with gradients disabled."""
        return self.actor(obs)

    def criticize(self, obs, act):
        """Return the pair (critic1 output, critic2 output) for `obs` and `act`."""
        q1 = self.critic1(obs, act)
        q2 = self.critic2(obs, act)
        return q1, q2

In [None]:
# --- Optimisation ---
BATCH_SIZE = 100            # transitions sampled from the replay buffer per optimize() call
GAMMA = 0.99                # discount applied to the bootstrapped target Q-value
POLYAK = 0.995              # fraction of the old target weights kept on each soft update
POLICY_UPDATE_EVERY = 2     # modulus applied to total_steps when scheduling actor updates
NETWORK_SIZES = [256, 256]  # `sizes` argument passed to the TD3 networks

# --- Action noise (see select_action) ---
NOISE_STD = 0.2             # scale of the Gaussian noise added to actions
NOISE_LIMIT = 0.5           # the noise is clamped to [-NOISE_LIMIT, NOISE_LIMIT]

# --- Training schedule ---
EPOCHS = 20
EPISODES_PER_EPOCH = 50
MAX_LEN = 1600              # cap on steps per episode; also used to size the replay buffer
RENDER_PER_EPOCH = 1        # divides EPISODES_PER_EPOCH to pick which episodes are rendered

# Environment steps taken so far, accumulated across all episodes.
total_steps = 0

def avg(l):
    """Return the arithmetic mean of the values in `l`.

    Raises ZeroDivisionError when `l` is empty.
    """
    total = sum(l)
    return total / len(l)

# Register the two per-episode quantities recorded by the training loop.
# NOTE(review): the list is presumably the set of summary functions applied by
# logger.summarize() -- confirm in my_utils.Logger.
logger = Logger()
for tracked_name in ('ret', 'len'):
    logger.add_attribute(tracked_name, [max, avg])

MODEL_SAVE_PATH = './td3'
# Create the checkpoint directory unless it already exists. If a non-directory
# already occupies this path, FileExistsError is still raised.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

In [None]:
env = gym.make('BipedalWalker-v3')
env.seed(SEED)

# Dimensions are read from the first axis of each space's shape.
obs_space, act_space = env.observation_space, env.action_space
obs_dim = obs_space.shape[0]
act_dim = act_space.shape[0]

# Action bounds as float32 tensors on the compute device, used to clip noisy actions.
act_high = torch.as_tensor(act_space.high, dtype=torch.float32, device=dev)
act_low = torch.as_tensor(act_space.low, dtype=torch.float32, device=dev)

In [None]:
# Flag flipped by the training loop; it decides whether an episode is recorded.
RENDER_THIS = False


def _record_this_episode(episode_id):
    """Return the current value of the global RENDER_THIS flag (episode_id is ignored)."""
    return RENDER_THIS


# Videos go to a fresh, timestamped directory.
env = wrappers.Monitor(env, f'./videos/{str(time())}/', video_callable=_record_this_episode)

In [None]:
# NOTE(review): the third argument is presumably the buffer capacity in
# transitions (five epochs at maximum episode length) -- confirm in my_utils.ReplayBuffer.
buffer_capacity = 5 * EPISODES_PER_EPOCH * MAX_LEN
buffer = ReplayBuffer(obs_dim, act_dim, buffer_capacity, dev)

In [None]:
# Build the online network first and the target network second, both on `dev`.
td3, td3_target = (TD3(obs_dim, act_dim, NETWORK_SIZES).to(dev) for _ in range(2))
# Start the target as an exact copy of the online network's weights.
td3_target.load_state_dict(td3.state_dict())
td3_target.eval()

In [None]:
# One Adam optimizer (default settings) per online sub-network.
actor_optimizer, critic1_optimizer, critic2_optimizer = (
    optim.Adam(module.parameters())
    for module in (td3.actor, td3.critic1, td3.critic2)
)

In [None]:
def compute_loss(q1, q2, q_exp):
    """Return the combined critic loss.

    Args:
        q1: estimates from the first critic.
        q2: estimates from the second critic.
        q_exp: target values both critics are regressed towards.

    Returns:
        The sum of the two mean-squared errors, as a scalar tensor.
    """
    loss_q1 = F.mse_loss(q1, q_exp)
    loss_q2 = F.mse_loss(q2, q_exp)
    return loss_q1 + loss_q2

In [None]:
def select_action(obs, network, noisy=True):
    """Query `network.act(obs)` and optionally perturb the result.

    Args:
        obs: observation tensor passed to `network.act`.
        network: object exposing an `act(obs)` method.
        noisy: when True, add clamped Gaussian noise and clip to the action bounds.

    Returns:
        The raw action when `noisy` is False. Otherwise the action plus
        NOISE_STD-scaled Gaussian noise (clamped to +/- NOISE_LIMIT), clipped
        element-wise to [act_low, act_high].
    """
    action = network.act(obs)
    if not noisy:
        return action

    perturbation = (torch.randn_like(action) * NOISE_STD).clamp_(-NOISE_LIMIT, NOISE_LIMIT)
    perturbed = action + perturbation
    # Lower bound first, then upper bound.
    return torch.min(torch.max(perturbed, act_low), act_high)

In [None]:
def optimize(update_actor=False):
    """Run one gradient step on the critics and, optionally, on the actor.

    Does nothing until the replay buffer holds at least BATCH_SIZE transitions.

    Args:
        update_actor: when truthy, also take an actor step and soft-update all
            three target sub-networks towards the online ones.
    """
    if len(buffer) < BATCH_SIZE:
        return

    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew = batch['obs'], batch['act'], batch['rew']
    next_obs, done = batch['next_obs'], batch['done']

    # Target action: target actor output with clamped noise, clipped to bounds.
    next_act = select_action(next_obs, td3_target)
    with torch.no_grad():
        # Bootstrap from the smaller target-critic estimate; the (1 - done)
        # factor removes the bootstrap term where done is 1.
        target_q1, target_q2 = td3_target.criticize(next_obs, next_act)
        q_exp = torch.min(target_q1, target_q2) * GAMMA * (1 - done) + rew

    q1, q2 = td3.criticize(obs, act)
    critics_loss = compute_loss(q1, q2, q_exp)

    critic_optimizers = (critic1_optimizer, critic2_optimizer)
    for critic_opt in critic_optimizers:
        critic_opt.zero_grad()
    critics_loss.backward()
    for critic_opt in critic_optimizers:
        critic_opt.step()

    if not update_actor:
        return

    # Actor objective: maximise critic1's value of the actor's own actions.
    actor_loss = -td3.critic1(obs, td3.actor(obs)).mean()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # Soft update: target <- POLYAK * target + (1 - POLYAK) * online.
    module_pairs = (
        (td3.actor, td3_target.actor),
        (td3.critic1, td3_target.critic1),
        (td3.critic2, td3_target.critic2),
    )
    with torch.no_grad():
        for online, target in module_pairs:
            for p, p_target in zip(online.parameters(), target.parameters()):
                p_target.data.mul_(POLYAK)
                p_target.data.add_((1 - POLYAK) * p.data)

In [None]:
%%time
# Training loop: EPOCHS x EPISODES_PER_EPOCH episodes, with one optimize() call
# per environment step. After every epoch the logger summary is printed and the
# online network's weights are checkpointed.
for epoch in range(EPOCHS):
    for episode in tqdm(range(EPISODES_PER_EPOCH), desc=f'[{epoch}]'):
        # Must be set before env.reset(): the Monitor's video_callable reads
        # this global to decide whether the episode is recorded.
        RENDER_THIS = episode % (EPISODES_PER_EPOCH // RENDER_PER_EPOCH) == 0

        obs = torch.as_tensor(env.reset(), dtype=torch.float32).to(dev)
        ep_ret = 0
        ep_len = 0
        for t in count():
            act = select_action(obs, td3)

            next_obs, rew, done, _ = env.step(act.cpu().numpy())
            next_obs = torch.as_tensor(next_obs, dtype=torch.float32).to(dev)

            ep_ret += rew
            ep_len += 1
            total_steps += 1

            # Hitting the step cap is a time-limit cut-off, not a terminal
            # state, so store done=False and let the critic target keep
            # bootstrapping. This check must follow the ep_len increment:
            # before it, ep_len can never equal MAX_LEN inside this loop.
            timed_out = ep_len == MAX_LEN
            if timed_out:
                done = False

            if RENDER_THIS:
                env.render()

            buffer.put(obs, act, rew, next_obs, done)
            obs = next_obs

            # Update the actor (and targets) once every POLICY_UPDATE_EVERY
            # steps. Passing the bare remainder would instead update on every
            # step that is NOT a multiple of the interval.
            optimize(total_steps % POLICY_UPDATE_EVERY == 0)

            if done or timed_out:
                break
        logger.put('ret', ep_ret)
        logger.put('len', ep_len)
    print(f'[{epoch}] {logger.summarize()}')
    torch.save(td3.state_dict(), f'{MODEL_SAVE_PATH}/td3-{time()}-{epoch}.pt')
env.close()