In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
import numpy as np
from itertools import count
from my_utils import ReplayBuffer, construct_nn, Logger

In [2]:
SEED = 11
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
class DDQN(nn.Module):
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        self.policy = construct_nn([obs_dim] + sizes + [act_dim])
        self.target = construct_nn([obs_dim] + sizes + [act_dim])
        self.target.load_state_dict(self.policy.state_dict())
        self.target.eval()
    
    def act(self, obs):
        with torch.no_grad():
            return self.policy(obs).max(0)[1].view(1, 1)

In [5]:
BATCH_SIZE = 256
GAMMA = 0.99
TARGET_UPDATE = 10
TARGET_UPDATE_AFTER = 30
NETWORK_SIZES = [24, 48]
LEARNING_RATE = 1e-5
MAX_LEN = 200
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 200

EPOCHS = 10
EPISODES_PER_EPOCH = 50

start = 0
total_steps = 0

def avg(l): return sum(l) / len(l)

logger = Logger()
logger.add_attribute('ret', [max, avg])
logger.add_attribute('len', [min, avg])
logger.add_attribute('loss', [sum])

In [6]:
env = gym.make('MountainCar-v0')
env.seed(SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n

In [7]:
buffer = ReplayBuffer(obs_dim, 1, 100000)

In [8]:
ddqn = DDQN(obs_dim, act_dim, NETWORK_SIZES).to(dev)
ddqn_optimizer = optim.Adam(ddqn.policy.parameters(), lr=LEARNING_RATE)

In [9]:
def compute_loss(q, q_exp):
    return F.mse_loss(q, q_exp)

In [11]:
def select_action(obs):
    eps = np.random.random(1)[0]
    eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * total_steps / EPS_DECAY)
    
    if eps > eps_threshold:
        return ddqn.act(obs).item()
    else:
        return np.random.choice(act_dim)

In [10]:
def optimize():
    if len(buffer) < BATCH_SIZE:
        return
    
    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew, next_obs, done = \
            batch['obs'], batch['act'], batch['rew'], batch['next_obs'], batch['done']
    
    q = ddqn.policy(obs).gather(1, torch.as_tensor(act, dtype=torch.long))
    q_best_by_policy = ddqn.policy(obs).max(1)[1].detach()
    q_next = ddqn.target(obs).gather(1, q_best_by_policy.view(-1, 1)).squeeze().detach()
    q_exp = q_next * GAMMA * (1 - done) + rew
    
    loss = compute_loss(q, q_exp.unsqueeze(1))
    logger.put('loss', loss.item())
    
    ddqn_optimizer.zero_grad()
    loss.backward()
    ddqn_optimizer.step()

In [12]:
%%time
for epoch in range(EPOCHS):
    print(f'[{epoch}] Epoch has started!')
    for episode in range(EPISODES_PER_EPOCH):
#         if episode % (EPISODES_PER_EPOCH / 5) == 0:
#             print([f'[{epoch}] e{episode} has started!'])
        obs = torch.as_tensor(env.reset(), dtype=torch.float32)
        ep_ret = 0
        ep_len = 0
        ep_loss = 0
        for t in count():
            act = select_action(obs)
            
            next_obs, rew, done, _ = env.step(act)
                
            ep_ret += rew
            ep_len += 1
            total_steps += 1
                        
            if episode % (EPISODES_PER_EPOCH // 5) == 0:
                env.render()

            next_obs = torch.as_tensor(next_obs, dtype=torch.float32)
            done = False if ep_len == MAX_LEN else done
            buffer.put(obs, act, rew, next_obs, done)
            
            obs = next_obs
            
            optimize()
            if done or ep_len == MAX_LEN:
                break
        logger.put('ret', ep_ret)
        logger.put('len', ep_len)
        if episode % TARGET_UPDATE == 0 and episode > TARGET_UPDATE_AFTER:
            ddqn.target.load_state_dict(ddqn.policy.state_dict())
    print(f'[{epoch}] Done: {logger.summarize()}')
print('Complete!')
env.close()

[0] Epoch has started!
[0] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=3291.1614
[1] Epoch has started!
[1] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=356.5439
[2] Epoch has started!
[2] ret_max=-188.0000; ret_avg=-199.7600; len_max=200.0000; len_avg=199.7600; loss_sum=221.7297
[3] Epoch has started!
[3] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=159.6550
[4] Epoch has started!
[4] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=124.3100
[5] Epoch has started!
[5] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=102.4893
[6] Epoch has started!
[6] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=87.1982
[7] Epoch has started!
[7] ret_max=-200.0000; ret_avg=-200.0000; len_max=200.0000; len_avg=200.0000; loss_sum=74.8034
[8] Epoch has started!
[8] ret_max=-200.0000; ret