In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
import numpy as np
from itertools import count
from tqdm.notebook import tqdm
from my_utils import ReplayBuffer, construct_nn, Logger

In [2]:
SEED = 11
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
class DDQN(nn.Module):
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        self.policy = construct_nn([obs_dim] + sizes + [act_dim])
        self.target = construct_nn([obs_dim] + sizes + [act_dim])
        self.target.load_state_dict(self.policy.state_dict())
        self.target.eval()
    
    def act(self, obs):
        with torch.no_grad():
            return self.policy(obs).max(0)[1].view(1, 1)

In [5]:
BATCH_SIZE = 32
GAMMA = 0.99
TARGET_UPDATE = 10
TARGET_UPDATE_AFTER = 30
NETWORK_SIZES = [24, 48]
MAX_LEN = 200
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 5000

EPOCHS = 10
EPISODES_PER_EPOCH = 50

start = 0
total_steps = 0

def avg(l): return sum(l) / len(l)

logger = Logger()
logger.add_attribute('ret', [max, avg])
logger.add_attribute('len', [min, avg])
logger.add_attribute('loss', [sum])

In [6]:
env = gym.make('MountainCar-v0')
env.seed(SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n

In [7]:
buffer = ReplayBuffer(obs_dim, 1, 100000, dev)

In [8]:
ddqn = DDQN(obs_dim, act_dim, NETWORK_SIZES).to(dev)
ddqn_optimizer = optim.Adam(ddqn.policy.parameters())

In [9]:
def compute_loss(q, q_exp):
    return F.mse_loss(q, q_exp)

In [10]:
def select_action(obs):
    eps = np.random.random(1)[0]
    eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * total_steps / EPS_DECAY)
    if eps > eps_threshold:
        return ddqn.act(obs).item()
    else:
        return np.random.choice(act_dim)

In [11]:
def optimize():
    if len(buffer) < BATCH_SIZE:
        return
    
    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew, next_obs, done = \
            batch['obs'], batch['act'], batch['rew'], batch['next_obs'], batch['done']
    
    q = ddqn.policy(obs).gather(1, act.long())
    q_best_by_policy = ddqn.policy(obs).max(1)[1].detach()
    q_next = ddqn.target(obs).gather(1, q_best_by_policy.view(-1, 1)).squeeze().detach()
    q_exp = q_next * GAMMA * (1 - done) + rew
    
    loss = compute_loss(q, q_exp.unsqueeze(1))
    logger.put('loss', loss.item())
    
    ddqn_optimizer.zero_grad()
    loss.backward()
    ddqn_optimizer.step()

In [12]:
%%time
for epoch in range(EPOCHS):
    for episode in tqdm(range(EPISODES_PER_EPOCH), desc=f'[{epoch}]'):
        obs = torch.as_tensor(env.reset(), dtype=torch.float32).to(dev)
        ep_ret = 0
        ep_len = 0
        ep_loss = 0
        for t in count():
            act = select_action(obs)

            next_obs, rew, done, _ = env.step(act)

            ep_ret += rew
            ep_len += 1
            total_steps += 1
                        
            if episode == 0:
                env.render()

            next_obs = torch.as_tensor(next_obs, dtype=torch.float32).to(dev)
            done = False if ep_len == MAX_LEN else done
            buffer.put(obs, act, rew, next_obs, done)
            
            obs = next_obs
            
            optimize()
            if done or ep_len == MAX_LEN:
                break
        logger.put('ret', ep_ret)
        logger.put('len', ep_len)
        if episode % TARGET_UPDATE == 0 and episode > TARGET_UPDATE_AFTER:
            ddqn.target.load_state_dict(ddqn.policy.state_dict())
    print(f'[{epoch}] Done: {logger.summarize()}')
print('Complete!')
env.close()

HBox(children=(FloatProgress(value=0.0, description='[0]', max=50.0, style=ProgressStyle(description_width='in…


[0] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=88.5227


HBox(children=(FloatProgress(value=0.0, description='[1]', max=50.0, style=ProgressStyle(description_width='in…


[1] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=5.3837


HBox(children=(FloatProgress(value=0.0, description='[2]', max=50.0, style=ProgressStyle(description_width='in…


[2] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=4.8975


HBox(children=(FloatProgress(value=0.0, description='[3]', max=50.0, style=ProgressStyle(description_width='in…


[3] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=4.6178


HBox(children=(FloatProgress(value=0.0, description='[4]', max=50.0, style=ProgressStyle(description_width='in…


[4] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=4.4215


HBox(children=(FloatProgress(value=0.0, description='[5]', max=50.0, style=ProgressStyle(description_width='in…


[5] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=4.3449


HBox(children=(FloatProgress(value=0.0, description='[6]', max=50.0, style=ProgressStyle(description_width='in…


[6] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=4.2509


HBox(children=(FloatProgress(value=0.0, description='[7]', max=50.0, style=ProgressStyle(description_width='in…


[7] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=3.9854


HBox(children=(FloatProgress(value=0.0, description='[8]', max=50.0, style=ProgressStyle(description_width='in…


[8] Done: ret_max=-200.0000; ret_avg=-200.0000; len_min=200.0000; len_avg=200.0000; loss_sum=3.8801


HBox(children=(FloatProgress(value=0.0, description='[9]', max=50.0, style=ProgressStyle(description_width='in…


[9] Done: ret_max=-195.0000; ret_avg=-199.9000; len_min=195.0000; len_avg=199.9000; loss_sum=4.0827
Complete!
CPU times: user 5min 14s, sys: 3.86 s, total: 5min 18s
Wall time: 5min 33s
