In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import gym
import numpy as np
from itertools import count

In [2]:
SEED = 11
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fd9800435b0>

In [3]:
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def construct_nn(sizes: list, output=nn.Identity):
    layers = []
    for i in range(len(sizes) - 1):
        act = nn.ReLU if i < len(sizes) - 2 else output
        layers += [nn.Linear(sizes[i], sizes[i+1]), act()]
    return nn.Sequential(*layers)

In [5]:
class DDQN(nn.Module):
    def __init__(self, obs_dim, act_dim, sizes):
        super().__init__()
        self.policy = construct_nn([obs_dim] + sizes + [act_dim])
        self.target = construct_nn([obs_dim] + sizes + [act_dim])
        self.target.load_state_dict(self.policy.state_dict())
        self.target.eval()
    
    def act(self, obs):
        with torch.no_grad():
            return self.policy(obs).max(0)[1].view(1, 1)

In [6]:
class ReplayBuffer:
    def __init__(self, obs_dim, act_dim, size):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)

        self.ptr, self.size, self.limit = 0, 0, size
    
    def put(self, obs, act, rew, next_obs, done):
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.next_obs_buf[self.ptr] = next_obs
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.limit
        self.size = min(self.size + 1, self.limit)
    
    def sample_batch(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)
        return {
            'obs': torch.as_tensor(self.obs_buf[idx], dtype=torch.float32),
            'act': torch.as_tensor(self.act_buf[idx], dtype=torch.float32),
            'rew': torch.as_tensor(self.rew_buf[idx], dtype=torch.float32),
            'next_obs': torch.as_tensor(self.next_obs_buf[idx], dtype=torch.float32),
            'done': torch.as_tensor(self.done_buf[idx], dtype=torch.float32)
        }
    
    def __len__(self):
        return self.size

In [7]:
BATCH_SIZE = 256
GAMMA = 0.99
TARGET_UPDATE = 10
TARGET_UPDATE_AFTER = 30
NETWORK_SIZES = [24, 48]
LEARNING_RATE = 1e-5
MAX_LEN = 200
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 200

EPOCHS = 10
EPISODES_PER_EPOCH = 50

start = 0
total_steps = 0

Logger = dict(
    loss=[],
    ret=[]
)

In [8]:
env = gym.make('MountainCar-v0')
env.seed(SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n

In [9]:
buffer = ReplayBuffer(obs_dim, 1, 100000)

In [10]:
ddqn = DDQN(obs_dim, act_dim, NETWORK_SIZES).to(dev)
ddqn_optimizer = optim.Adam(ddqn.policy.parameters(), lr=LEARNING_RATE)

In [11]:
def compute_loss(q, q_exp):
    return F.mse_loss(q, q_exp)

In [12]:
def optimize():
    if len(buffer) < BATCH_SIZE:
        return
    
    batch = buffer.sample_batch(BATCH_SIZE)
    obs, act, rew, next_obs, done = \
            batch['obs'], batch['act'], batch['rew'], batch['next_obs'], batch['done']
    
    q = ddqn.policy(obs).gather(1, torch.as_tensor(act, dtype=torch.long))
    q_best_by_policy = ddqn.policy(obs).max(1)[1].detach()
    q_next = ddqn.target(obs).gather(1, q_best_by_policy.view(-1, 1)).squeeze().detach()
    q_exp = (q_next * GAMMA) + rew

    ddqn_optimizer.zero_grad()
    loss = compute_loss(q, q_exp.unsqueeze(1))
    Logger['loss'].append(loss.item())
    loss.backward()
    ddqn_optimizer.step()

In [13]:
def select_action(obs):
    eps = np.random.random(1)[0]
    eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * total_steps / EPS_DECAY)
    
    if eps > eps_threshold:
        return ddqn.act(obs).item()
    else:
        return np.random.choice(act_dim)

In [14]:
%%time
for epoch in range(EPOCHS):
    print(f'[{epoch}] Epoch has started!')
    for episode in range(EPISODES_PER_EPOCH):
#         if episode % (EPISODES_PER_EPOCH / 5) == 0:
#             print([f'[{epoch}] e{episode} has started!'])
        obs = torch.as_tensor(env.reset(), dtype=torch.float32)
        ep_ret = 0
        ep_len = 0
        ep_loss = 0
        for t in count():
            act = select_action(obs)
            
            next_obs, rew, done, _ = env.step(act)
            done = False if ep_len == MAX_LEN else done
            next_obs = torch.as_tensor(next_obs, dtype=torch.float32)
                
            ep_ret += rew
            ep_len += 1
            total_steps += 1
            
            if episode % (EPISODES_PER_EPOCH // 5) == 0:
                env.render()
            
            buffer.put(obs, act, rew, next_obs, done)         
            obs = next_obs
            
            optimize()
            if done or ep_len == MAX_LEN:
                break
        Logger['ret'].append(ep_ret)
        if episode % TARGET_UPDATE == 0 and episode > TARGET_UPDATE_AFTER:
            ddqn.target.load_state_dict(ddqn.policy.state_dict())
    epoch_loss = sum(Logger['loss'][start:])
    epoch_min_loss = min(Logger['loss'][start:])
    epoch_avg_ret = sum(Logger['ret'][epoch*EPISODES_PER_EPOCH:]) / (len(Logger['ret']) - epoch*EPISODES_PER_EPOCH)
    print(f'[{epoch}] Epoch has completed: loss={epoch_loss} min_loss={epoch_min_loss} avg_ret={epoch_avg_ret}')
    start = len(Logger['loss'])
print('Complete!')
env.render()
env.close()

[0] Epoch has started!
[0] Epoch has completed: loss=3361.267687718617 min_loss=0.0028642904944717884 avg_ret=-200.0
[1] Epoch has started!
[1] Epoch has completed: loss=361.8883148918321 min_loss=7.620827091159299e-05 avg_ret=-200.0
[2] Epoch has started!
[2] Epoch has completed: loss=228.12890471004903 min_loss=1.8410824850434437e-05 avg_ret=-200.0
[3] Epoch has started!
[3] Epoch has completed: loss=166.91742463772334 min_loss=6.124999345047399e-06 avg_ret=-200.0
[4] Epoch has started!
[4] Epoch has completed: loss=125.67890689285832 min_loss=5.30208126292564e-06 avg_ret=-200.0
[5] Epoch has started!
[5] Epoch has completed: loss=100.46328944095694 min_loss=7.498708328057546e-06 avg_ret=-200.0
[6] Epoch has started!
[6] Epoch has completed: loss=84.08605581662368 min_loss=6.5629806158540305e-06 avg_ret=-200.0
[7] Epoch has started!
[7] Epoch has completed: loss=72.23479927427434 min_loss=5.189955118112266e-06 avg_ret=-200.0
[8] Epoch has started!
[8] Epoch has completed: loss=62.869