In [1]:
! pip install gym



In [2]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import random
import gym
import threading

In [3]:
# Build the CartPole-v1 environment; the bare `Env` expression on the last line
# displays the environment's repr as the cell output.
Env = gym.make("CartPole-v1")
Env

<TimeLimit<CartPoleEnv<CartPole-v1>>>

In [4]:
# Shape tuple of the observation space.
input_dim = Env.observation_space.shape
# NOTE: this is the action-space object itself (printed as `Discrete(2)`),
# not an integer number of actions.
output_dim = Env.action_space
print(input_dim, output_dim)

(4,) Discrete(2)


In [5]:
# Reset the environment; the value returned by reset() is displayed as the cell output.
Env.reset()

array([-0.04129239, -0.03440432,  0.01166018,  0.04215938])

In [6]:
#Soft target update
def soft_update(net, net_target, tau):
  """Blend the weights of `net` into `net_target` in place (Polyak averaging).

  Each target parameter becomes `(1 - tau) * target + tau * source`.
  Parameters are paired positionally via `.parameters()`.

  Args:
    net: module providing the source weights (left unchanged).
    net_target: module whose weights are overwritten with the blend.
    tau: interpolation factor; 0 keeps the target as is, 1 copies `net`.
  """
  keep = 1.0 - tau
  for source_p, target_p in zip(net.parameters(), net_target.parameters()):
    blended = target_p.data * keep + source_p.data * tau
    target_p.data.copy_(blended)

In [9]:
class DNN(nn.Module):
  """Fully connected network with hidden widths 16 -> 32 -> 64.

  Every linear layer except the last is followed by the hidden activation;
  the last linear layer is followed by the output activation.

  Args:
    input_dim: number of input features.
    output_dim: number of output units.
    hidden_act: name of an activation class in `torch.nn` (e.g. "ReLU"),
      instantiated with no arguments and reused after each hidden layer.
    out_act: name of an activation class in `torch.nn` (e.g. "Identity"),
      instantiated with no arguments and applied to the final output.
  """

  def __init__(self, 
               input_dim: int, output_dim: int,
               hidden_act: str, out_act: str):
    super().__init__()

    # Activations are looked up by class name on torch.nn.
    self.hidden_act = getattr(nn, hidden_act)()
    self.out_act = getattr(nn, out_act)()

    # Consecutive pairs of this table give the (in, out) sizes of each Linear.
    widths = (input_dim, 16, 32, 64, output_dim)
    final_idx = len(widths) - 2

    stack = []
    for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
      stack.append(nn.Linear(fan_in, fan_out))
      stack.append(self.out_act if idx == final_idx else self.hidden_act)
    self.layers = nn.ModuleList(stack)

  def forward(self, x):
    """Feed `x` through every layer in order and return the result."""
    out = x
    for module in self.layers:
      out = module(out)
    return out

In [94]:
class DQNrunner(nn.Module, threading.Thread):
  """Worker thread that plays episodes and pushes gradients to a shared global Q-network.

  Each runner owns a local Q-network and a target network. While an episode is
  played, the one-step TD loss of every transition is back-propagated into the
  local network so that gradients accumulate. `main_update` then hands those
  gradients to the global network held by `ADQNagent`, steps the agent's
  optimizer and copies the refreshed global weights back into the local network.

  Args:
    ADQNagent: object exposing `global_qnet` (nn.Module) and `optimizer`
      (an optimizer over `global_qnet`'s parameters).
    env: gym-style environment with `reset()` and `step(action)`.
    qnet: local online Q-network; must have the same architecture as the
      global network so parameters can be paired positionally.
    qnet_target: target network used for the bootstrap value.
    lr: learning rate of the per-runner Adam optimizer (`self.optimizer`),
      which is created here but not used by any method of this class.
    discount_factor: discount applied to the bootstrap value.
    epsilon: probability of taking a uniformly random action; it is constant,
      no decay is applied.
    target_update_interval: soft-update the target network every this many episodes.
    async_update_interval: extra global sync + log line every this many episodes.
    num_epi: number of episodes to play.
    tau: interpolation factor passed to `soft_update`.
  """

  # One lock shared by all runners: updates of the global network and its
  # optimizer from different threads must not interleave.
  _GLOBAL_UPDATE_LOCK = threading.Lock()

  def __init__(self, ADQNagent, env,
               qnet:nn.Module, qnet_target:nn.Module, 
               lr:float, discount_factor:float, epsilon:float,
               target_update_interval:int,
               async_update_interval:int, 
               num_epi:int, tau:float):
    
    nn.Module.__init__(self)
    threading.Thread.__init__(self)

    # Kept under its original name; now refers to the lock shared by all runners.
    self.Lock = DQNrunner._GLOBAL_UPDATE_LOCK
    self.env = env
    self.ADQNagent = ADQNagent
    self.qnet = qnet
    self.qnet_target = qnet_target

    self.lr = lr
    self.tau = tau
    self.discount_factor = discount_factor
    self.epsilon = epsilon
    self.num_epi = num_epi
    self.target_update_interval = target_update_interval
    self.async_update_interval = async_update_interval

    self.criteria = nn.MSELoss()
    self.optimizer = torch.optim.Adam(params=qnet.parameters(), lr=lr)
    
  def get_action(self, state):
    """Epsilon-greedy action for `state`, returned as a Python int (0 or 1)."""
    if np.random.uniform(0, 1, size=1) <= self.epsilon:
      # Explore: the two actions are hard-coded, each with probability 1/2.
      return int(np.random.choice([0, 1], size=1, p=[1/2, 1/2])[0])

    # Exploit: greedy action under the local Q-network.
    state = torch.as_tensor(state).float().view(1, -1)
    with torch.no_grad():
      q = self.qnet(state)
    return int(torch.argmax(q).item())

  def cum_gradient(self, state, action, next_state, reward, done):
    """Back-propagate the one-step TD loss of one transition into the local network.

    Gradients are accumulated (not applied); `main_update` consumes them.

    Args:
      state: tensor of shape (1, n_features).
      action: integer index (or 0-d integer tensor) of the action taken.
      next_state: tensor of shape (1, n_features).
      reward: scalar reward (number or 0-d tensor).
      done: 1 if the episode ended on this transition, else 0.
    """
    with torch.no_grad():
      q_next_max = torch.max(self.qnet_target(next_state))
      # (1 - done) removes the bootstrap term on the final transition.
      target = reward + self.discount_factor * q_next_max * (1 - done)

    q_taken = self.qnet(state)[0][action]
    # Match dtypes so e.g. a float64 reward does not clash with float32 Q-values.
    loss = self.criteria(q_taken, target.to(q_taken.dtype))
    loss.backward()

  def main_update(self):
    """Apply the accumulated local gradients to the global network and re-sync.

    The local gradients are moved (not aliased) onto the global parameters and
    cleared locally, so they are never applied twice and do not keep growing
    across episodes. Calling this with no pending gradients only re-syncs the
    local weights, because the optimizer skips parameters whose grad is None.
    """
    with self._GLOBAL_UPDATE_LOCK:
      pairs = zip(self.ADQNagent.global_qnet.parameters(), self.qnet.parameters())
      for global_p, local_p in pairs:
        pending = local_p.grad
        local_p.grad = None
        global_p.grad = pending
      self.ADQNagent.optimizer.step()
      self.ADQNagent.optimizer.zero_grad()
      self.qnet.load_state_dict(self.ADQNagent.global_qnet.state_dict())

  def run(self):
    """Thread body: play `num_epi` episodes, accumulating and pushing gradients."""
    self.qnet_target.load_state_dict(self.qnet.state_dict())
    
    for epi in range(self.num_epi):
      cum_r = 0.0
      reset_out = self.env.reset()
      # Newer gym versions return (observation, info); older ones return the
      # observation alone.
      if isinstance(reset_out, tuple) and len(reset_out) == 2 and isinstance(reset_out[1], dict):
        state = reset_out[0]
      else:
        state = reset_out

      while True:
        action = self.get_action(state)
        step_out = self.env.step(action)
        if len(step_out) == 5:
          # Newer gym API: (obs, reward, terminated, truncated, info).
          next_state, reward, terminated, truncated, _ = step_out
          done = bool(terminated or truncated)
        else:
          next_state, reward, done, _ = step_out
          done = bool(done)

        self.cum_gradient(
            torch.as_tensor(state).float().view(1, -1),
            torch.tensor(action),
            torch.as_tensor(next_state).float().view(1, -1),
            torch.tensor(float(reward)),
            torch.tensor(int(done)))
        cum_r += float(reward)

        # Advance to the observation produced by this step; without this every
        # step would act on and learn from the reset observation.
        state = next_state

        if done:
          self.main_update()
          break
      
      if epi % self.target_update_interval == 0:
        soft_update(self.qnet, self.qnet_target, self.tau)
      
      # Modulo (not bitwise AND) so this fires once every `async_update_interval` episodes.
      if epi % self.async_update_interval == 0:
        self.main_update()
        print("Episode: {} | cum_r: {} | main_update complete".format(epi, cum_r))

In [96]:
class ADQN(nn.Module):
  """Asynchronous DQN agent: a global Q-network updated by several runner threads.

  Args:
    local_qnet: template network; every runner receives its own copies of it.
    global_qnet: shared network that receives the runners' gradients.
    global_lr: learning rate of the Adam optimizer over `global_qnet`.
    thread_num: number of runner threads started by `train()`.
  """

  def __init__(self, 
               local_qnet:nn.Module,
               global_qnet:nn.Module, global_lr:float, thread_num:int):
    super().__init__()
    self.global_qnet = global_qnet
    self.local_qnet = local_qnet
    self.global_lr = global_lr
    self.thread_num = thread_num
    self.optimizer = torch.optim.Adam(params=self.global_qnet.parameters(), lr=self.global_lr)

  def get_action(self, state):
    """Greedy action under the global Q-network, returned as a 0-d tensor."""
    state = torch.as_tensor(state).float().view(1,-1)
    with torch.no_grad():
      q = self.global_qnet(state)
    return torch.argmax(q)

  def train(self, mode=None):
    """Start the runner threads, or switch train/eval mode when `mode` is given.

    Called without arguments (the original usage) this launches `thread_num`
    `DQNrunner` threads and returns None without waiting for them to finish.

    This method shadows `nn.Module.train(mode)`, which `eval()` and parent
    modules call with a boolean. When `mode` is passed the call is therefore
    forwarded to `nn.Module.train`, so `agent.eval()` works instead of raising
    a TypeError.
    """
    if mode is not None:
      return super().train(mode)

    import copy

    # Every runner gets its own environment and its own online/target networks:
    # - a single environment stepped from several threads would have its
    #   episodes interleaved;
    # - passing one module as both `qnet` and `qnet_target` makes the target
    #   network identical to the online network;
    # - sharing one local network between runners mixes their gradients.
    # `Env` and `gym` are the module-level names defined in earlier cells.
    runners = [DQNrunner(env=gym.make(Env.spec.id),
                         qnet=copy.deepcopy(self.local_qnet),
                         qnet_target=copy.deepcopy(self.local_qnet),
                         ADQNagent=self,
                         lr=0.001, discount_factor=0.99, epsilon=1.0,
                         target_update_interval=20,
                         async_update_interval=50,
                         num_epi=200, tau=0.002) for _ in range(self.thread_num)]    

    for i, runner in enumerate(runners):
      print("Start runner #{}".format(i))
      runner.start()