# Importing dependencies

In [21]:
import gym

import numpy as np
import random
from collections import namedtuple, deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [22]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


# Model definition

The network takes the environment's state variables (e.g. the position of the cart and the angle of the pole) as input and outputs one Q-value per action.

In [23]:
class DQN(nn.Module):
    """Small fully connected Q-network: two 64-unit ReLU layers and a linear head.

    Args:
        output_size: number of values produced per input row (one per action).
    """

    def __init__(self, output_size):
        super().__init__()
        # Lazy layers infer their input width from the first batch they see,
        # so the state dimension does not have to be given here.
        self.fn1 = nn.LazyLinear(64)
        self.fn2 = nn.LazyLinear(64)
        self.fn3 = nn.LazyLinear(output_size)
        self.act = nn.ReLU()

    def forward(self, input):
        """Map a batch of states to a batch of per-action values."""
        hidden = self.fn1(input)
        hidden = self.act(hidden)
        hidden = self.fn2(hidden)
        hidden = self.act(hidden)
        return self.fn3(hidden)


# Memory
The agent stores its past experiences (transitions) in a replay memory so that it can sample them again later and learn from them.

In [24]:
# One stored transition. `is_going` is the flag that train_step multiplies
# into the bootstrapped term of the Q target.
Experience = namedtuple(
    'Experience',
    ('curr_state', 'action', 'next_state', 'reward', 'is_going'),
)


class Memory(object):
    """Bounded replay buffer of `Experience` tuples."""

    def __init__(self, capacity):
        # A deque with maxlen drops its oldest entry once `capacity` is reached.
        self.memory = deque(maxlen=capacity)

    def remember(self, *args):
        """Store one transition; `args` are the Experience fields in order."""
        transition = Experience(*args)
        self.memory.append(transition)

    def recall(self, batch_size):
        """Sample `batch_size` transitions without replacement.

        Returns a single Experience whose fields are tuples holding that
        field for every sampled transition (column-wise layout).
        """
        sampled = random.sample(self.memory, batch_size)
        columns = zip(*sampled)
        return Experience._make(columns)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

# Creating the agent

In [25]:
class DQNAgent:
    """Epsilon-greedy Deep Q-Network agent with experience replay.

    The agent owns a `DQN` model, an AdamW optimizer, a Smooth-L1 loss and a
    bounded replay `Memory`. It expects a discrete action space
    (`env.action_space.n`) and an environment whose `reset()` returns the
    observation and whose `step()` returns a 4-tuple
    `(observation, reward, done, info)`.
    """

    def __init__(self, env):
        """Build the model, optimizer and replay memory for `env`."""
        self.action_size = env.action_space.n  # output size

        self.model = DQN(self.action_size).to(device)

        self.loss_fn = nn.SmoothL1Loss()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=0.0005)

        self.exploration_rate = 1  # initial exploration rate, always leave at 1
        self.exploration_rate_decay = 0.9999  # rate at which the exploration decreases
        self.exploration_rate_min = 0.1  # minimum exploration rate

        self.gamma = 0.999  # falloff for Q score

        self.batch_size = 64
        self.num_epochs = 500

        self.memory = Memory(10000)  # how many of the previous samples are used

    def act(self, state):
        """Pick an action epsilon-greedily and decay the exploration rate.

        Args:
            state: batched state tensor; `train` passes shape (1, state_dim).

        Returns:
            The chosen action index as a Python int.
        """
        if random.random() < self.exploration_rate:
            action = random.randrange(self.action_size)  # act randomly
        else:
            action = self.act_ideal(state)  # act ideally

        # Decay once per call, but never below the configured floor.
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        return action

    def act_ideal(self, state):
        """Return the greedy action (highest predicted Q-value) for `state`."""
        # Pure inference: no autograd graph is needed to take an argmax.
        with torch.no_grad():
            return self.model(state).max(1)[1].item()

    def train_step(self):
        """Run one optimisation step on a random minibatch of stored experiences.

        Does nothing until the memory holds at least `batch_size` experiences.
        """
        if len(self.memory) < self.batch_size:
            return
        batch = self.memory.recall(self.batch_size)

        curr_states = torch.stack(batch.curr_state).squeeze(1)
        next_states = torch.stack(batch.next_state).squeeze(1)
        # Convert through a single NumPy array each: building a tensor from a
        # tuple of arrays/scalars element by element is slow.
        actions = torch.as_tensor(np.asarray(batch.action, dtype=np.float32), device=device)
        rewards = torch.as_tensor(np.asarray(batch.reward, dtype=np.float32), device=device)
        is_goings = torch.as_tensor(np.asarray(batch.is_going, dtype=np.float32), device=device)

        # This is the fundamental logic behind calculating a deep Q value.
        # Q(s, a) of the action actually taken, selected with the one-hot mask.
        curr_Q = self.model(curr_states).mul(actions).sum(1)

        # The bootstrapped target must be a constant with respect to the
        # parameters; otherwise the optimizer also moves the target towards the
        # prediction. `is_goings` zeroes the bootstrap term on the final step.
        with torch.no_grad():
            next_Q = self.model(next_states).max(1)[0]
            expected_Q = rewards + is_goings * self.gamma * next_Q

        # Loss signature is (input, target).
        loss = self.loss_fn(curr_Q, expected_Q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, env, print_epochs=False):
        """Train for `self.num_epochs` episodes on `env`.

        Args:
            env: environment to interact with (reset/step API as described on
                the class).
            print_epochs: when True, print the score after every episode.

        Returns:
            List with one score per episode; the score is the number of steps
            the episode lasted.
        """
        scores = []

        for epoch in range(self.num_epochs):
            done = False
            curr_state = env.reset()
            curr_state = torch.Tensor(curr_state).to(device)
            score = 0

            while not done:

                action = self.act(curr_state.unsqueeze(0))
                next_state, reward, done, _ = env.step(action)
                next_state = torch.Tensor(next_state).to(device)
                # One-hot mask sized from the environment's action space rather
                # than a hard-coded 2, so the agent works for any action count.
                action_encode = np.eye(self.action_size)[action]
                self.memory.remember(curr_state, action_encode, next_state, reward, 1 - done)

                curr_state = next_state

                self.train_step()

                score += 1

            scores.append(score)
            if print_epochs:
                print("Epoch: " + str(epoch + 1) + ". Score is: " + str(score))

        return scores


# Training the agent

In [26]:
env = gym.make("CartPole-v0")
agent = DQNAgent(env)
# No explicit env.reset() here: agent.train() resets the environment at the
# start of every episode.
scores = agent.train(env, print_epochs=True)

# Learning curve: one point per training episode.
fig, ax = plt.subplots()
ax.plot(scores)
ax.set_xlabel("Episode")
ax.set_ylabel("Score (steps per episode)")
ax.set_title("DQN training on CartPole-v0")
plt.show()



Epoch: 1. Score is: 26
Epoch: 2. Score is: 28
Epoch: 3. Score is: 41
Epoch: 4. Score is: 12
Epoch: 5. Score is: 12
Epoch: 6. Score is: 19
Epoch: 7. Score is: 29
Epoch: 8. Score is: 58
Epoch: 9. Score is: 38
Epoch: 10. Score is: 25
Epoch: 11. Score is: 20
Epoch: 12. Score is: 25
Epoch: 13. Score is: 16
Epoch: 14. Score is: 32
Epoch: 15. Score is: 20
Epoch: 16. Score is: 32
Epoch: 17. Score is: 21
Epoch: 18. Score is: 15
Epoch: 19. Score is: 18
Epoch: 20. Score is: 15
Epoch: 21. Score is: 16
Epoch: 22. Score is: 34
Epoch: 23. Score is: 25
Epoch: 24. Score is: 21
Epoch: 25. Score is: 43
Epoch: 26. Score is: 53
Epoch: 27. Score is: 20
Epoch: 28. Score is: 22
Epoch: 29. Score is: 30
Epoch: 30. Score is: 17
Epoch: 31. Score is: 21
Epoch: 32. Score is: 41
Epoch: 33. Score is: 33
Epoch: 34. Score is: 34
Epoch: 35. Score is: 74
Epoch: 36. Score is: 33
Epoch: 37. Score is: 13
Epoch: 38. Score is: 33
Epoch: 39. Score is: 51
Epoch: 40. Score is: 33
Epoch: 41. Score is: 48
Epoch: 42. Score is: 14
E

KeyboardInterrupt: 

# Testing the agent

In [None]:
# Number of greedy evaluation episodes; also the divisor of the average below.
NUM_TEST_EPISODES = 100

test_scores = []

# Evaluation only needs forward passes, so autograd is switched off.
with torch.no_grad():
    for episode in range(NUM_TEST_EPISODES):

        done = False
        curr_state = env.reset()
        curr_state = torch.Tensor(curr_state).to(device)
        score = 0

        while not done:

            # Always take the greedy action: no exploration during testing.
            action = agent.act_ideal(curr_state.unsqueeze(0))
            curr_state, _, done, _ = env.step(action)
            curr_state = torch.Tensor(curr_state).to(device)

            score += 1

        test_scores.append(score)

# Average over the episodes actually run instead of a repeated literal.
avg = sum(test_scores) / len(test_scores)
print(avg)

env.close()

115.5
