In [10]:
from torch import nn
import torch 
import os, sys
sys.path.append('/mnt/data/projects/wankun01/workdir/playground/dqn/venv/lib/python3.9/site-packages')
import gym
import itertools
import numpy as np
import random
from collections import deque

GAMMA = 0.99
BATCH_SIZE = 64
BUFFER_SIZE = 50000
MIN_REPLAY_SIZE = 1000
EPS_START = 1.0
EPS_END = 0.02
EPS_DECAY = 10000
TARGET_UPDATE_FREQ = 1000

class Network(nn.Module):
    def __init__(self, env):
        super().__init__()
        input_dim = int(np.prod(env.observation_space.shape))
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.Tanh(),
            nn.Linear(64, env.action_space.n)
        )
    def forward(self, x):
        return self.net(x)
    
    def act(self, obs):
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        q_values = self.forward(obs_t.unsqueeze(0))
        max_q_index = torch.argmax(q_values, dim=1)[0]
        action = max_q_index.detach().item()
        return action

# create env
env = gym.make('CartPole-v1')
replay_buffer = deque(maxlen=BUFFER_SIZE)
rew_buffer = deque([0.0], maxlen=100)
episode_reward = 0.0

# print('Device', 'cuda:0' if torch.cuda.is_available() else 'cpu')

online_net = Network(env)
target_net = Network(env)
target_net.load_state_dict(online_net.state_dict())
optimizer = torch.optim.Adam(online_net.parameters(), lr=5e-4)

In [18]:
# replay buffer init by a 1st random action 
obs = env.reset()
for _ in range(MIN_REPLAY_SIZE):
    action = env.action_space.sample()
    new_obs, rew, done, _ = env.step(action)
    transition = (obs, action, rew, done, new_obs)
    replay_buffer.append(transition)
    obs = new_obs
    
    if done:
        obs = env.reset()

# Main training loop
obs = env.reset()
for step in itertools.count():
    eps = np.interp(step, [0, EPS_DECAY], [EPS_START, EPS_END])
    rnd_sample = random.random()
    
    if rnd_sample <= eps:
        action = env.action_space.sample()
    else:
        action = online_net.act(obs)
                                
    new_obs, rew, done, _ = env.step(action)
    transition = (obs, action, rew, done, new_obs)
    
    replay_buffer.append(transition)
    obs = new_obs
    
    episode_reward += rew

    if done:
        osb = env.reset()
        rew_buffer.append(episode_reward)
        episode_reward = 0.0


    # start grad step
    transitions = random.sample(replay_buffer, BATCH_SIZE)
    
    obses = np.asarray([t[0] for t in transitions])
    actions = np.asarray([t[1] for t in transitions])
    rews = np.asarray([t[2] for t in transitions])
    dones = np.asarray([t[3] for t in transitions])
    new_obses = np.asarray([t[4] for t in transitions])

    obses_t = torch.as_tensor(obses, dtype=torch.float32)
    actions_t = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(-1)
    rews_t = torch.as_tensor(rews, dtype=torch.float32).unsqueeze(-1)
    dones_t = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(-1)
    new_obses_t = torch.as_tensor(new_obses, dtype=torch.float32)


    # compute targets
    target_q_values = target_net(new_obses_t)
    max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]
    # print(rews_t.shape, dones_t.shape, max_target_q_values.shape)
    targets = rews_t + GAMMA * (1 - dones_t) * max_target_q_values

    # compute loss
    q_values = online_net(obses_t)
    # print(q_values.shape, actions_t.shape)
    action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)
    loss = nn.functional.smooth_l1_loss(action_q_values, targets)
    loss = loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())

    # logging
    if step % 1000 == 0:
        print()
        print('Step', step)
        print('Avg Rew', np.mean(rew_buffer))

        


Step 0
Avg Rew 10.24

Step 1000
Avg Rew 15.28

Step 2000
Avg Rew 21.19

Step 3000
Avg Rew 26.21

Step 4000
Avg Rew 30.1

Step 5000
Avg Rew 36.17

Step 6000
Avg Rew 40.58

Step 7000
Avg Rew 46.08

Step 8000
Avg Rew 49.74

Step 9000
Avg Rew 56.0

Step 10000
Avg Rew 59.2

Step 11000
Avg Rew 59.19

Step 12000
Avg Rew 61.63

Step 13000
Avg Rew 60.21

Step 14000
Avg Rew 56.86

Step 15000
Avg Rew 55.46

Step 16000
Avg Rew 53.79

Step 17000
Avg Rew 53.49

Step 18000
Avg Rew 56.56

Step 19000
Avg Rew 61.31

Step 20000
Avg Rew 66.57

Step 21000
Avg Rew 72.3

Step 22000
Avg Rew 74.94

Step 23000
Avg Rew 77.2


KeyboardInterrupt: 

In [10]:
!jupyter nbconvert dqn.ipynb --to=script

[NbConvertApp] Converting notebook dqn.ipynb to script
[NbConvertApp] Writing 3938 bytes to dqn.py
