In [1]:
import torch
import torch.nn as nn
import random
from collections import deque
import numpy as np
import torch.optim as optim
import gym

In [2]:
def cartpole_model(observation_space, action_space):
    return nn.Sequential(
        nn.Linear(observation_space, 24),
        nn.ReLU(),
        nn.Linear(24, 24),
        nn.ReLU(),
        nn.Linear(24, action_space)
    )

In [3]:
class DQN:
    def __init__(self, observation_space, action_space):
        self.exploration_rate = MAX_EXPLORE
        self.action_space = action_space
        self.observation_space = observation_space
        self.memory = deque(maxlen=MEMORY_LEN)
        
        self.target_net = cartpole_model(self.observation_space, self.action_space)
        self.policy_net = cartpole_model(self.observation_space, self.action_space)
        
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.policy_net.parameters())
        
        self.explore_limit = False
        
    def load_memory(self, state, action, reward, next_state, terminal):
        self.memory.append((state, action, reward, next_state, terminal))
        
    def predict_action(self, state):
        random_number = np.random.rand()
        
        if random_number < self.exploration_rate:
            return random.randrange(self.action_space)
        
        q_values = self.target_net(state).detach().numpy()
        return np.argmax(q_values[0])
    
    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        
        batch = random.sample(self.memory, BATCH_SIZE)
        
        for state, action, reward, next_state, terminal in batch:
            q_update = reward
            
            if not terminal:
                    q_update = reward + GAMMA * self.target_net(next_state).max(axis=1)[0]
                    
            q_values = self.target_net(state)
            q_values[0][action] = q_update
            
            loss = self.criterion(self.policy_net(state), q_values)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
        if not self.explore_limit:
            self.exploration_rate *= EXPLORE_DECAY
            if self.exploration_rate < MIN_EXPLORE:
                self.exploration_rate = MIN_EXPLORE
                self.explore_limit = True

In [4]:
ENV_NAME = "CartPole-v1"
BATCH_SIZE = 20
GAMMA = 0.95
LEARNING_RATE = 0.001
MAX_EXPLORE = 1.0
MIN_EXPLORE = 0.01
EXPLORE_DECAY = 0.995
MEMORY_LEN = 1_000_000
UPDATE_FREQ = 10

In [5]:
env = gym.make(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn = DQN(observation_space, action_space)

In [6]:
print(f'| Run | Exploration Rate | Score |')
steps = 0
for i in range(100):
    state = env.reset()
    state = np.reshape(state, [1, observation_space])
    state = torch.from_numpy(state).float()
    
    score = 0
    while True:
        steps += 1
        score += 1
        action = dqn.predict_action(state)
        next_state, reward, terminal, info = env.step(action)
        
        next_state = torch.from_numpy(np.reshape(next_state, [1, observation_space])).float()
        dqn.load_memory(state, action, reward, next_state, terminal)
        state = next_state
        
        if terminal:
            print(f'| {i+1:03} |       {dqn.exploration_rate:.4f}     |  {score:03}  |')
            break
        
        dqn.experience_replay()
        if steps%UPDATE_FREQ == 0:
            dqn.target_net.load_state_dict(dqn.policy_net.state_dict())

| Run | Exploration Rate | Score |
| 001 |       0.7862     |  068  |
| 002 |       0.7329     |  015  |
| 003 |       0.6369     |  029  |
| 004 |       0.5186     |  042  |
| 005 |       0.4575     |  026  |
| 006 |       0.3157     |  075  |
| 007 |       0.2584     |  041  |
| 008 |       0.2212     |  032  |
| 009 |       0.1670     |  057  |
| 010 |       0.1459     |  028  |
| 011 |       0.1113     |  055  |
| 012 |       0.0788     |  070  |
| 013 |       0.0684     |  029  |
| 014 |       0.0580     |  034  |
| 015 |       0.0390     |  080  |
| 016 |       0.0341     |  028  |
| 017 |       0.0312     |  019  |
| 018 |       0.0210     |  080  |
| 019 |       0.0136     |  088  |
| 020 |       0.0100     |  071  |
| 021 |       0.0100     |  043  |
| 022 |       0.0100     |  080  |
| 023 |       0.0100     |  103  |
| 024 |       0.0100     |  091  |
| 025 |       0.0100     |  031  |
| 026 |       0.0100     |  099  |
| 027 |       0.0100     |  184  |
| 028 |       0.0100

In [7]:
def play_agent(dqn, env):
    observation = env.reset()
    total_reward=0
    for _ in range(500):
        env.render()
        observation = torch.tensor(observation).type('torch.FloatTensor').view(1,-1)
        q_values = dqn.target_net(observation).detach().numpy()
        action = np.argmax(q_values[0])
        new_observation, reward, done, _ = env.step(action)
        total_reward += reward
        observation = new_observation

        if(done):
            break

    env.close()
    print("Rewards: ",total_reward)

In [8]:
play_agent(dqn, env)

Rewards:  160.0
