In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
from gym import wrappers
import numpy as np
from itertools import count
from tqdm.notebook import tqdm
from time import time
import os
import pathlib
from torch.utils.tensorboard import SummaryWriter
from my_utils import ReplayBuffer, construct_nn, Logger

In [2]:
%load_ext tensorboard

In [3]:
SEED = 11
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
class DDQN(nn.Module):
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        self.policy = construct_nn([obs_dim] + sizes + [act_dim])
        self.target = construct_nn([obs_dim] + sizes + [act_dim])
        self.target.load_state_dict(self.policy.state_dict())
        self.target.eval()
    
    def act(self, obs):
        with torch.no_grad():
            return self.policy(obs).max(0)[1].view(1, 1)

In [6]:
BATCH_SIZE = 64
GAMMA = 0.99
TARGET_UPDATE = 50
TARGET_UPDATE_AFTER = 30
NETWORK_SIZES = [48, 48]
MAX_LEN = 200
EPS_START = 0.95
EPS_END = 0.005
EPS_DECAY = 2000

EPOCHS = 5
EPISODES_PER_EPOCH = 50
RENDER_PER_EPOCH = 2

total_steps = 0

def avg(l): return sum(l) / len(l)

logger = Logger()
logger.add_attribute('ret', [max, avg])
logger.add_attribute('len', [min, avg])
logger.add_attribute('loss', [avg])
logger.add_attribute('q', [avg])

RUN_ID = str(time())
RUN_PATH = f'./ddqn/{RUN_ID}'
if not os.path.exists(RUN_PATH) or not os.path.isdir(RUN_PATH):
    pathlib.Path(RUN_PATH).mkdir(parents=True, exist_ok=True)
    
TENSORBOARD_PATH = f'./ddqn/tensorboard'
writer = SummaryWriter(f'{TENSORBOARD_PATH}/{RUN_ID}')

In [7]:
env = gym.make('MountainCar-v0')
env.seed(SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n

In [8]:
RENDER_THIS = False
env = wrappers.Monitor(env, f'{RUN_PATH}/videos/', video_callable=lambda episode_id: RENDER_THIS)

In [9]:
buffer = ReplayBuffer(obs_dim, 1, 10000, dev)

In [10]:
ddqn = DDQN(obs_dim, act_dim, NETWORK_SIZES).to(dev)
ddqn_optimizer = optim.Adam(ddqn.policy.parameters())

In [11]:
def compute_loss(q, q_exp):
    return F.mse_loss(q, q_exp)

In [12]:
def select_action(obs, randomize=True):
    if randomize:
        eps = np.random.random(1)[0]
        eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * total_steps / EPS_DECAY)
        if eps > eps_threshold:
            return ddqn.act(obs).item()
    return np.random.choice(act_dim)

In [13]:
def optimize():
    if len(buffer) < BATCH_SIZE:
        return
    
    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew, next_obs, done = \
            batch['obs'], batch['act'], batch['rew'], batch['next_obs'], batch['done']
    
    q = ddqn.policy(obs).gather(1, act.long())
    q_best_by_policy = ddqn.policy(next_obs).max(1)[1].detach()
    q_next = ddqn.target(obs).gather(1, q_best_by_policy.view(-1, 1)).squeeze().detach()
    q_exp = q_next * GAMMA * (1 - done) + rew
    
    loss = compute_loss(q, q_exp.unsqueeze(1))
    
    logger.put('loss', loss.item())
    logger.put('q', q.mean().item())
    
    ddqn_optimizer.zero_grad()
    loss.backward()
    ddqn_optimizer.step()

In [14]:
%tensorboard --logdir $TENSORBOARD_PATH --host localhost

Reusing TensorBoard on port 6006 (pid 266506), started 2:21:07 ago. (Use '!kill 266506' to kill it.)

In [15]:
%%time
for epoch in range(EPOCHS):
    for episode in tqdm(range(EPISODES_PER_EPOCH), desc=f'[{epoch}]'):
        RENDER_THIS = True if episode % (EPISODES_PER_EPOCH // RENDER_PER_EPOCH) == 0 else False

        obs = torch.as_tensor(env.reset(), dtype=torch.float32).to(dev)
        ep_ret = 0
        ep_len = 0
        for t in count():
            act = select_action(obs)

            next_obs, rew, done, _ = env.step(act)
            
            # Motivate the agent to speed up more often
            rew += (abs(next_obs[1]) - abs(obs[1])) * 250

            ep_ret += rew
            ep_len += 1
            total_steps += 1
                        
            if RENDER_THIS:
                env.render()

            next_obs = torch.as_tensor(next_obs, dtype=torch.float32).to(dev)
            done = False if ep_len == MAX_LEN else done
            buffer.put(obs, act, rew, next_obs, done)
            
            obs = next_obs
            
            optimize()
            if done or ep_len == MAX_LEN:
                break
        logger.put('ret', ep_ret)
        logger.put('len', ep_len)
        if episode % TARGET_UPDATE == 0 and episode > TARGET_UPDATE_AFTER:
            ddqn.target.load_state_dict(ddqn.policy.state_dict())
            
        tb_data = logger.summarize(attributes=['loss', 'q'], fmt=False)
        ep_id = epoch * EPISODES_PER_EPOCH + episode
        for (attr, val) in tb_data:
            writer.add_scalar(attr, val, ep_id)
        writer.add_scalar('return', ep_ret, ep_id)
        writer.add_scalar('length', ep_len, ep_id)
        writer.flush()

    print(f'[{epoch}] Done: {logger.summarize(attributes=["ret", "len"])}')
    torch.save(ddqn.state_dict(), f'{RUN_PATH}/ddqn-{time()}-{epoch}.pt')
print('Complete!')
env.close()

HBox(children=(FloatProgress(value=0.0, description='[0]', max=50.0, style=ProgressStyle(description_width='in…


[0] Done: ret_max=-84.7824; ret_avg=-143.4456; len_min=90.0000; len_avg=152.7800


HBox(children=(FloatProgress(value=0.0, description='[1]', max=50.0, style=ProgressStyle(description_width='in…


[1] Done: ret_max=-77.0586; ret_avg=-117.9509; len_min=84.0000; len_avg=127.7800


HBox(children=(FloatProgress(value=0.0, description='[2]', max=50.0, style=ProgressStyle(description_width='in…


[2] Done: ret_max=-78.2371; ret_avg=-107.7594; len_min=85.0000; len_avg=117.3000


HBox(children=(FloatProgress(value=0.0, description='[3]', max=50.0, style=ProgressStyle(description_width='in…


[3] Done: ret_max=-78.1032; ret_avg=-103.0356; len_min=85.0000; len_avg=112.7600


HBox(children=(FloatProgress(value=0.0, description='[4]', max=50.0, style=ProgressStyle(description_width='in…


[4] Done: ret_max=-77.3485; ret_avg=-107.9094; len_min=84.0000; len_avg=117.8800
Complete!
CPU times: user 2min 44s, sys: 4.06 s, total: 2min 49s
Wall time: 3min 22s
