In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
from gym import wrappers
import numpy as np
from itertools import count
from tqdm.notebook import tqdm
from time import time
import os
import pathlib
from torch.utils.tensorboard import SummaryWriter
from my_utils import ReplayBuffer, construct_nn, Logger

In [None]:
# Register the TensorBoard notebook extension so that the `%tensorboard` line
# magic used in a later cell is available.
%load_ext tensorboard

In [None]:
# Single seed shared by the RNGs this notebook configures: PyTorch and NumPy
# here, and the gym environment once it has been created.
SEED = 11
for seed_rng in (torch.manual_seed, np.random.seed):
    # A loop (rather than a bare trailing call) keeps the cell free of output.
    seed_rng(SEED)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class DDQN(nn.Module):
    """Holder for the two Q-networks used by the Double-DQN agent.

    Attributes:
        policy: online network, built by `construct_nn`.
        target: second network with the same layer sizes; it is initialised
            from `policy`'s weights and switched to eval mode.
    """

    def __init__(self, obs_dim, act_dim, sizes):
        """Build both networks.

        Args:
            obs_dim: width of the input layer.
            act_dim: width of the output layer (one value per action).
            sizes: list of hidden-layer widths placed between the two.
        """
        super().__init__()

        def build_q_net():
            # Both networks get an identical, freshly built layer-size list.
            return construct_nn([obs_dim] + sizes + [act_dim])

        self.policy = build_q_net()
        self.target = build_q_net()
        # Start the target as an exact copy of the online network.
        self.target.load_state_dict(self.policy.state_dict())
        self.target.eval()

    @torch.no_grad()
    def act(self, obs):
        """Return the index of the largest policy output along dim 0.

        The result is reshaped to a (1, 1) tensor and is computed without
        gradient tracking.
        """
        q_values = self.policy(obs)
        _, best_action = q_values.max(0)
        return best_action.view(1, 1)

In [None]:
# --- Learning hyperparameters -------------------------------------------
BATCH_SIZE = 32            # minibatch size drawn from the replay buffer
GAMMA = 0.99               # discount factor applied to the bootstrapped value
TARGET_UPDATE = 10         # target-sync period, measured in episodes
TARGET_UPDATE_AFTER = 30   # no target sync until this many episodes have passed
NETWORK_SIZES = [24, 48]   # hidden-layer widths of each Q-network
MAX_LEN = 200              # step cap enforced by the training loop

# Exploration threshold decays exponentially from EPS_START towards EPS_END,
# with EPS_DECAY as the time constant in environment steps.
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 5000

# --- Training schedule ----------------------------------------------------
EPOCHS = 10
EPISODES_PER_EPOCH = 50
RENDER_PER_EPOCH = 2

# Environment steps taken so far; drives the epsilon schedule.
total_steps = 0


def avg(l):
    """Arithmetic mean of a non-empty sequence of numbers."""
    return sum(l) / len(l)


# --- Metric logger --------------------------------------------------------
logger = Logger()
for attribute, reducers in (
    ('ret', [max, avg]),
    ('len', [min, avg]),
    ('loss', [avg]),
    ('q', [avg]),
):
    logger.add_attribute(attribute, reducers)

# --- Output locations -----------------------------------------------------
# Each run gets its own directory, keyed by the wall-clock timestamp.
RUN_ID = str(time())
RUN_PATH = f'./ddqn/{RUN_ID}'
pathlib.Path(RUN_PATH).mkdir(parents=True, exist_ok=True)

# All runs share one TensorBoard root so they can be compared side by side.
TENSORBOARD_PATH = './ddqn/tensorboard'
writer = SummaryWriter(f'{TENSORBOARD_PATH}/{RUN_ID}')

In [None]:
# Create the MountainCar task and seed it with the notebook-wide SEED.
env = gym.make('MountainCar-v0')
env.seed(SEED)
# Network input width and number of discrete actions, read from the env spaces.
obs_dim, act_dim = env.observation_space.shape[0], env.action_space.n

In [None]:
# Flag toggled by the training loop to choose which episodes get recorded.
RENDER_THIS = False


def _record_this_episode(episode_id):
    """Video predicate for Monitor: ignores the id and follows RENDER_THIS."""
    return RENDER_THIS


# Wrap the env so selected episodes are written under this run's videos folder.
env = wrappers.Monitor(env, f'{RUN_PATH}/videos/', video_callable=_record_this_episode)

In [None]:
# Replay memory living on `dev`. Presumably the `1` is the action width and
# 100_000 the capacity - TODO confirm against my_utils.ReplayBuffer.
buffer = ReplayBuffer(obs_dim, 1, 100_000, dev)

In [None]:
# Build the agent (online + target networks) and move it to the chosen device.
ddqn = DDQN(obs_dim, act_dim, NETWORK_SIZES)
ddqn = ddqn.to(dev)
# Only the online (policy) network's parameters are handed to the optimizer;
# the target network is refreshed by copying weights in the training loop.
ddqn_optimizer = optim.Adam(params=ddqn.policy.parameters())

In [None]:
def compute_loss(q, q_exp):
    """Mean-squared error between predicted Q-values and their targets.

    Args:
        q: predicted Q-values.
        q_exp: target values compared against `q`.

    Returns:
        Scalar tensor holding the mean of the squared differences.
    """
    return F.mse_loss(input=q, target=q_exp, reduction='mean')

In [None]:
def select_action(obs, randomize=True):
    """Choose an action for `obs`, epsilon-greedily or purely greedily.

    Args:
        obs: observation tensor passed on to `ddqn.act`.
        randomize: when True (default) use epsilon-greedy exploration with a
            threshold that decays exponentially in `total_steps` from
            EPS_START towards EPS_END. When False, always return the greedy
            action (e.g. for evaluation).

    Returns:
        The chosen action index.
    """
    if not randomize:
        # Pure exploitation. Previously this path fell through to the uniform
        # random branch, so randomize=False produced *only* random actions.
        return ddqn.act(obs).item()

    roll = np.random.random(1)[0]
    eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * total_steps / EPS_DECAY)
    if roll > eps_threshold:
        # Exploit: act greedily with respect to the online network.
        return ddqn.act(obs).item()
    # Explore: uniform random action.
    return np.random.choice(act_dim)

In [None]:
def optimize():
    """Run one Double-DQN gradient step on a minibatch from the replay buffer.

    Does nothing until the buffer holds at least BATCH_SIZE transitions.
    Otherwise it samples a batch, builds the Double-DQN target

        rew + GAMMA * (1 - done) * Q_target(next_obs, argmax_a Q_policy(next_obs, a))

    takes one optimizer step on the policy network, and logs the loss and the
    mean predicted Q-value.
    """
    if len(buffer) < BATCH_SIZE:
        return

    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew, next_obs, done = (
        batch[key] for key in ('obs', 'act', 'rew', 'next_obs', 'done')
    )

    # Q(s, a) of the online network for the actions that were actually taken.
    q = ddqn.policy(obs).gather(1, act.long())

    # The target is a constant with respect to the optimisation, so no
    # gradients are tracked while it is being built.
    with torch.no_grad():
        # Double DQN: the online network *selects* the next action and the
        # target network *evaluates* it. Both must look at the NEXT state;
        # the previous code used `obs` here and never read `next_obs`.
        next_act = ddqn.policy(next_obs).argmax(dim=1, keepdim=True)
        # squeeze(1) drops only the gathered column, never the batch axis.
        q_next = ddqn.target(next_obs).gather(1, next_act).squeeze(1)
        # NOTE(review): assumes `rew` and `done` are 1-D (batch,) tensors, as
        # the unsqueeze(1) below implies - confirm against ReplayBuffer.
        q_exp = q_next * GAMMA * (1 - done) + rew

    loss = compute_loss(q, q_exp.unsqueeze(1))

    logger.put('loss', loss.item())
    logger.put('q', q.mean().item())

    ddqn_optimizer.zero_grad()
    loss.backward()
    ddqn_optimizer.step()

In [None]:
# Start TensorBoard from the notebook, reading event files under
# TENSORBOARD_PATH (shared by all runs) and binding to localhost.
%tensorboard --logdir $TENSORBOARD_PATH --host localhost

In [None]:
%%time
for epoch in range(EPOCHS):
    for episode in tqdm(range(EPISODES_PER_EPOCH), desc=f'[{epoch}]'):
        RENDER_THIS = True if episode % (EPISODES_PER_EPOCH // RENDER_PER_EPOCH) == 0 else False

        obs = torch.as_tensor(env.reset(), dtype=torch.float32).to(dev)
        ep_ret = 0
        ep_len = 0
        for t in count():
            act = select_action(obs)

            next_obs, rew, done, _ = env.step(act)

            ep_ret += rew
            ep_len += 1
            total_steps += 1
                        
            if RENDER_THIS:
                env.render()

            next_obs = torch.as_tensor(next_obs, dtype=torch.float32).to(dev)
            done = False if ep_len == MAX_LEN else done
            buffer.put(obs, act, rew, next_obs, done)
            
            obs = next_obs
            
            optimize()
            if done or ep_len == MAX_LEN:
                break
        logger.put('ret', ep_ret)
        logger.put('len', ep_len)
        if episode % TARGET_UPDATE == 0 and episode > TARGET_UPDATE_AFTER:
            ddqn.target.load_state_dict(ddqn.policy.state_dict())
            
        tb_data = logger.summarize(attributes=['loss', 'q'], fmt=False)
        ep_id = epoch * EPISODES_PER_EPOCH + episode
        for (attr, val) in tb_data:
            writer.add_scalar(attr, val, ep_id)
        writer.add_scalar('return', ep_ret, ep_id)
        writer.add_scalar('length', ep_len, ep_id)
        writer.flush()

    print(f'[{epoch}] Done: {logger.summarize(attributes=["ret", "len"])}')
    torch.save(ddqn.state_dict(), f'{RUN_PATH}/ddqn-{time()}-{epoch}.pt')
print('Complete!')
env.close()