In [6]:
import gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [7]:
# Create the CartPole-v1 environment. `env` is used further down to read the
# action space, to reset for an initial state, and to sample random actions.
env = gym.make("CartPole-v1")

# Pick the compute device in priority order: CUDA first, then the MPS backend,
# and finally the CPU. Each availability check only runs if the previous one failed.
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")

# Bare expression so the notebook cell displays the selected device.
device

device(type='cpu')

## Replay Memory

Create a replay memory that keeps track of all the transitions (state, action, next_state, and reward). These transitions are sampled at random to build the training data.


In [8]:
# One recorded step: the fields are stored in the order
# (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])

class ReplayMemory(object):
  """Bounded buffer of `Transition` records.

  Backed by a `deque` with `maxlen = capacity`, so once the buffer is full
  each new record pushes out the oldest one.
  """

  def __init__(self, capacity):
    """Create an empty buffer that holds at most `capacity` transitions."""
    self.memory = deque(maxlen = capacity)

  def push(self, *args):
    """Wrap the positional arguments in a `Transition` and append it."""
    record = Transition(*args)
    self.memory.append(record)

  def sample(self, batch_size):
    """Return a list of `batch_size` stored transitions drawn via `random.sample`.

    `random.sample` raises `ValueError` when `batch_size` exceeds the
    number of stored transitions.
    """
    return random.sample(self.memory, k = batch_size)

  def __len__(self):
    """Return the number of transitions currently stored."""
    return len(self.memory)

## DQN algorithm
### Model

In [9]:
class DQN(nn.Module):
  """Fully connected network with two 128-unit hidden layers.

  Maps an input of width `n_observations` to an output of width `n_actions`
  (one value per action).
  """

  def __init__(self, n_observations, n_actions):
    """Build the three linear layers.

    Args:
      n_observations: size of the last dimension of the input.
      n_actions: size of the last dimension of the output.
    """
    super().__init__()
    hidden_units = 128
    self.layer1 = nn.Linear(n_observations, hidden_units)
    self.layer2 = nn.Linear(hidden_units, hidden_units)
    self.layer3 = nn.Linear(hidden_units, n_actions)

  def forward(self, x):
    """Apply layer1 and layer2 with ReLU after each, then layer3 with no activation."""
    first = self.layer1(x)
    second = self.layer2(F.relu(first))
    return self.layer3(F.relu(second))

### Training

In [54]:
# Hyperparameters
BATCH_SIZE = 128  # Number of transitions sampled from the replay buffer
GAMMA = 0.99  # Discount factor for the rewards
EPS_START = 0.9  # Starting value of epsilon (probability of a random action)
EPS_END = 0.05  # Final value of epsilon
EPS_DECAY = 1000  # Controls the rate of exponential decay of epsilon; a higher value means slower decay
TAU = 0.005  # Update rate of the target network
LR = 1e-4  # Learning rate of the AdamW optimizer

# Get the number of actions from the gym action space
n_actions = env.action_space.n

# Reset once to obtain an initial observation. Older gym releases return the
# observation directly, while gym >= 0.26 returns an (observation, info) tuple;
# unwrap the tuple so that len(state) is the observation size in both cases
# (otherwise it would be 2, the length of the tuple).
state = env.reset()
if isinstance(state, tuple):
  state = state[0]

# Get the number of state observations
n_observations = len(state)

# Two networks with the same architecture; the target network starts as an
# exact copy of the policy network's weights.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only the policy network's parameters are handed to the optimizer.
optimizer = optim.AdamW(policy_net.parameters(), lr = LR, amsgrad = True)
memory = ReplayMemory(10000)

# Global step counter; select_action reads and increments it to decay epsilon.
steps_done = 0

def select_action(state):
  """Choose an action for `state` with an epsilon-greedy rule.

  Epsilon decays exponentially from EPS_START towards EPS_END as the global
  `steps_done` counter grows; the counter is incremented on every call.

  Returns:
    A 1x1 `torch.long` tensor: either a random action sampled from
    `env.action_space`, or the index of the largest output of
    `policy_net(state)` along dimension 1.
  """
  global steps_done
  roll = random.random()
  decay = math.exp(-1 * steps_done / EPS_DECAY)
  eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
  steps_done += 1

  # Explore: the draw fell at or below the current epsilon.
  if roll <= eps_threshold:
    return torch.tensor([[env.action_space.sample()]], device = device, dtype = torch.long)

  # Exploit: take the highest-valued action; no gradient tracking is needed here.
  with torch.no_grad():
    return policy_net(state).max(1).indices.view(1, 1)
  
# Values plotted by plot_durations, one entry per episode.
episode_durations = []

def plot_durations(show_result = False):
  """Plot `episode_durations` on figure 1, plus a 100-episode moving average.

  Args:
    show_result: when True the figure is titled 'Result' and left in place;
      otherwise the figure is cleared first, titled 'Training', and the
      notebook output is cleared (deferred) after display so the next call
      redraws in place.

  NOTE(review): this function reads the module-level names `is_ipython` and
  `display`, which are not defined in the cells visible here; confirm they
  are set up before the first call.
  """
  plt.figure(1)
  durations_t = torch.tensor(episode_durations, dtype = torch.float)
  if not show_result:
    plt.clf()
  plt.title('Result' if show_result else 'Training')
  plt.xlabel('Episode')
  plt.ylabel('Duration')
  plt.plot(durations_t.numpy())

  # Once 100 episodes exist, overlay the mean of each sliding 100-episode
  # window, left-padded with 99 zeros so it lines up with the episode axis.
  if len(durations_t) >= 100:
    window_means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
    padded_means = torch.cat((torch.zeros(99), window_means))
    plt.plot(padded_means.numpy())

  plt.pause(0.001)
  if not is_ipython:
    return
  display.display(plt.gcf())
  if not show_result:
    display.clear_output(wait = True)


In [None]:
def optimize_model():
  if len(memory) < BATCH_SIZE:
    return
  transitions = memory.sample(BATCH_SIZE)
  batch = Transition(*zip(*transitions))

  non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device = device, dtype = torch.bool)
  non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

  state_batch = torch.cat(batch.state)
  action_batch = torch.cat(batch.action)
  reward_batch = torch.cat(batch.reward)