In [None]:
# https://github.com/gsurma/cartpole
# https://towardsdatascience.com/cartpole-introduction-to-reinforcement-learning-ed0eb5b58288

In [1]:
import warnings
warnings.filterwarnings('ignore')

import random
import gym
import numpy as np
from collections import deque

import torch
import torch.nn.functional as F
from DQN import DQN

ENV_NAME = "CartPole-v1"

GAMMA = 0.995
LEARNING_RATE = 0.001

MEMORY_SIZE = 100000
BATCH_SIZE =  20

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.95

In [2]:
class DQNSolver:
    def __init__(self, observation_space, action_space, single = True):
        self.exploration_rate = EXPLORATION_MAX
        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.single = single
        
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 
        self.model = DQN(observation_space, action_space).to(self.device)
        
        self.model.train()
        
        if not self.single:
            self.target_net = DQN(observation_space, action_space).to(self.device)
            self.target_net.load_state_dict(self.model.state_dict())
            self.target_net.eval()
            
        #self.optimizer  = torch.optim.RMSprop(self.policy_net.parameters())
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = LEARNING_RATE)
        
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        
        if self.single:
            self.model.eval()
            q_values = self.model(state)
        else:
            q_values = self.target_net(state)
        return torch.argmax(q_values)
    
    def predict(self, state):
        if self.single:
            self.model.eval()
            prediction = self.model(state)
        else:
            prediction = self.target_net(state)
        return prediction
    
    def fit(self, state, expected_state_action_values):
        self.model.train()
        predictions = self.model(state)
        # print(predictions, expected_state_action_values)
        # F.smooth_l1_loss - Base loss fn
        loss = self.loss_fn(predictions, expected_state_action_values)
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
    
    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                action_n = self.predict(state_next)
                q_update = reward + GAMMA * torch.max(action_n)
            q_values = self.predict(state) # future_action
            q_values[action] = q_update
            
            self.fit(state, q_values) 
            
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate  = max(EXPLORATION_MIN, self.exploration_rate)


def cartpole():
    pass

def t_wrap(array, device):
    return torch.tensor(array, dtype = torch.float32, device=device)

In [None]:
# state = env.reset()
# env.render()
# cart_position, cart_velocity, pole_angle, pole_velocity = state 
# print('cart_position = ', cart_position) 
# print('cart_velocity = ', cart_velocity) 
# print('pole_angle    = ', pole_angle) 
# print('pole_velocity = ', pole_velocity)
# run += 1
# state = t_wrap(env.reset(), dqn_solver.device)
# step = 0

In [3]:
single = True
env = gym.make(ENV_NAME)
#score_logger = ScoreLogger(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space, single)
run = 0
max_score = 0

In [4]:
run += 1
state = t_wrap(env.reset(), dqn_solver.device)
step = 0
env.render() 

True

In [None]:
while True:
    step += 1
    env.render()    
    #action = dqn_solver.act(state)
    if np.random.rand() < dqn_solver.exploration_rate:
        action = random.randrange(dqn_solver.action_space)
    else:
        q_values = dqn_solver.model(state) # target_net
        # print('q_values =',q_values)
        action = torch.argmax(q_values)
        
    if isinstance(action, torch.Tensor):
        action = action.item()
    # print('action  =',action)

    state_next, reward, terminal, info = env.step(action)
    # print('cart_position, cart_velocity, pole_angle, pole_velocity')
    # print('state_n = ', state_next)
    # print('reward  = ', reward)
    # print('terminal= ', terminal)
    # print('info    = ', info)
    state_next = t_wrap(state_next, dqn_solver.device)
    reward = reward if not terminal else -reward
    dqn_solver.remember(state, action, reward, state_next, terminal)
    state = state_next
    # print('state = ', state)
    if terminal:
        # score_logger.add_score(step, run)
        print("Run: " + str(run) + ", exploration: " 
              + str(dqn_solver.exploration_rate) + ", score: " + str(step))
        run += 1
        state = t_wrap(env.reset(), dqn_solver.device)
        step = 0
        env.render() 
    else:
        # dqn_solver.experience_replay()
        if len(dqn_solver.memory) < BATCH_SIZE:
            # return
            print('collect more samples!')
        else:
            batch = random.sample(dqn_solver.memory, BATCH_SIZE)
            for state, action, reward, state_next, terminal in batch:
                q_update = reward
                if not terminal:
                    with torch.no_grad():
                        action_n = dqn_solver.model(state_next)
                    q_update = reward + GAMMA * torch.max(action_n)
                with torch.no_grad():
                    q_values = dqn_solver.model(state) # future_action
                q_values[action] = q_update
                # print('q_values =', q_values)
                dqn_solver.fit(state, q_values) 

            dqn_solver.exploration_rate *= EXPLORATION_DECAY
            dqn_solver.exploration_rate  = max(EXPLORATION_MIN, dqn_solver.exploration_rate)

collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
Run: 1, exploration: 1.0, score: 14
collect more samples!
collect more samples!
collect more samples!
collect more samples!
collect more samples!
Run: 2, exploration: 0.2919890243387723, score: 30
Run: 3, exploration: 0.18402591023557577, score: 10
Run: 4, exploration: 0.12208654873684793, score: 9
Run: 5, exploration: 0.0852575903343082, score: 8
Run: 6, exploration: 0.053733545982740265, score: 10
Run: 7, exploration: 0.03564793225056021, score: 9
Run: 8, exploration: 0.022467088258818428, score: 10
Run: 9, exploration: 0.014159869113351011, score: 10
Run: 10, exploration: 0.01, score: 9
Run: 11, exploration: 0.01, score: 10
Run: 12, exploration: 0.01, score: 8
Run: 13, exploration: 0.01, score: 9
Run: 

Run: 204, exploration: 0.01, score: 10
Run: 205, exploration: 0.01, score: 9
Run: 206, exploration: 0.01, score: 10
Run: 207, exploration: 0.01, score: 9
Run: 208, exploration: 0.01, score: 9
Run: 209, exploration: 0.01, score: 8
Run: 210, exploration: 0.01, score: 9
Run: 211, exploration: 0.01, score: 10
Run: 212, exploration: 0.01, score: 10
Run: 213, exploration: 0.01, score: 10
Run: 214, exploration: 0.01, score: 9
Run: 215, exploration: 0.01, score: 9
Run: 216, exploration: 0.01, score: 10
Run: 217, exploration: 0.01, score: 9
Run: 218, exploration: 0.01, score: 10
Run: 219, exploration: 0.01, score: 10
Run: 220, exploration: 0.01, score: 9
Run: 221, exploration: 0.01, score: 10
Run: 222, exploration: 0.01, score: 8
Run: 223, exploration: 0.01, score: 9
Run: 224, exploration: 0.01, score: 10
Run: 225, exploration: 0.01, score: 9
Run: 226, exploration: 0.01, score: 11
Run: 227, exploration: 0.01, score: 10
Run: 228, exploration: 0.01, score: 9
Run: 229, exploration: 0.01, score: 10

Run: 417, exploration: 0.01, score: 9
Run: 418, exploration: 0.01, score: 9
Run: 419, exploration: 0.01, score: 9
Run: 420, exploration: 0.01, score: 11
Run: 421, exploration: 0.01, score: 10
Run: 422, exploration: 0.01, score: 10
Run: 423, exploration: 0.01, score: 10
Run: 424, exploration: 0.01, score: 8
Run: 425, exploration: 0.01, score: 10
Run: 426, exploration: 0.01, score: 9
Run: 427, exploration: 0.01, score: 10
Run: 428, exploration: 0.01, score: 9
Run: 429, exploration: 0.01, score: 9
Run: 430, exploration: 0.01, score: 8
Run: 431, exploration: 0.01, score: 9
Run: 432, exploration: 0.01, score: 10
Run: 433, exploration: 0.01, score: 9
Run: 434, exploration: 0.01, score: 10
Run: 435, exploration: 0.01, score: 8
Run: 436, exploration: 0.01, score: 10
Run: 437, exploration: 0.01, score: 10
Run: 438, exploration: 0.01, score: 10
Run: 439, exploration: 0.01, score: 10
Run: 440, exploration: 0.01, score: 11
Run: 441, exploration: 0.01, score: 10
Run: 442, exploration: 0.01, score: 

Run: 630, exploration: 0.01, score: 9
Run: 631, exploration: 0.01, score: 10
Run: 632, exploration: 0.01, score: 9
Run: 633, exploration: 0.01, score: 9
Run: 634, exploration: 0.01, score: 9
Run: 635, exploration: 0.01, score: 9
Run: 636, exploration: 0.01, score: 10
Run: 637, exploration: 0.01, score: 10
Run: 638, exploration: 0.01, score: 10
Run: 639, exploration: 0.01, score: 9
Run: 640, exploration: 0.01, score: 10
Run: 641, exploration: 0.01, score: 10
Run: 642, exploration: 0.01, score: 9
Run: 643, exploration: 0.01, score: 10
Run: 644, exploration: 0.01, score: 10
Run: 645, exploration: 0.01, score: 10
Run: 646, exploration: 0.01, score: 9
Run: 647, exploration: 0.01, score: 10
Run: 648, exploration: 0.01, score: 10
Run: 649, exploration: 0.01, score: 9
Run: 650, exploration: 0.01, score: 10
Run: 651, exploration: 0.01, score: 9
Run: 652, exploration: 0.01, score: 10
Run: 653, exploration: 0.01, score: 10
Run: 654, exploration: 0.01, score: 10
Run: 655, exploration: 0.01, score:

Run: 844, exploration: 0.01, score: 10
Run: 845, exploration: 0.01, score: 10
Run: 846, exploration: 0.01, score: 8
Run: 847, exploration: 0.01, score: 8
Run: 848, exploration: 0.01, score: 10
Run: 849, exploration: 0.01, score: 10
Run: 850, exploration: 0.01, score: 9
Run: 851, exploration: 0.01, score: 10
Run: 852, exploration: 0.01, score: 9
Run: 853, exploration: 0.01, score: 9
Run: 854, exploration: 0.01, score: 9
Run: 855, exploration: 0.01, score: 9
Run: 856, exploration: 0.01, score: 9
Run: 857, exploration: 0.01, score: 8
Run: 858, exploration: 0.01, score: 10
Run: 859, exploration: 0.01, score: 9
Run: 860, exploration: 0.01, score: 10
Run: 861, exploration: 0.01, score: 11
Run: 862, exploration: 0.01, score: 10
Run: 863, exploration: 0.01, score: 9
Run: 864, exploration: 0.01, score: 10
Run: 865, exploration: 0.01, score: 10
Run: 866, exploration: 0.01, score: 13
Run: 867, exploration: 0.01, score: 10
Run: 868, exploration: 0.01, score: 9
Run: 869, exploration: 0.01, score: 1

Run: 1056, exploration: 0.01, score: 9
Run: 1057, exploration: 0.01, score: 10
Run: 1058, exploration: 0.01, score: 9
Run: 1059, exploration: 0.01, score: 9
Run: 1060, exploration: 0.01, score: 9
Run: 1061, exploration: 0.01, score: 10
Run: 1062, exploration: 0.01, score: 10
Run: 1063, exploration: 0.01, score: 9
Run: 1064, exploration: 0.01, score: 10
Run: 1065, exploration: 0.01, score: 9
Run: 1066, exploration: 0.01, score: 9
Run: 1067, exploration: 0.01, score: 11
Run: 1068, exploration: 0.01, score: 10
Run: 1069, exploration: 0.01, score: 9
Run: 1070, exploration: 0.01, score: 10
Run: 1071, exploration: 0.01, score: 10
Run: 1072, exploration: 0.01, score: 9
Run: 1073, exploration: 0.01, score: 10
Run: 1074, exploration: 0.01, score: 8
Run: 1075, exploration: 0.01, score: 9
Run: 1076, exploration: 0.01, score: 10
Run: 1077, exploration: 0.01, score: 9
Run: 1078, exploration: 0.01, score: 10
Run: 1079, exploration: 0.01, score: 10
Run: 1080, exploration: 0.01, score: 10
Run: 1081, e

Run: 1264, exploration: 0.01, score: 8
Run: 1265, exploration: 0.01, score: 10
Run: 1266, exploration: 0.01, score: 9
Run: 1267, exploration: 0.01, score: 9
Run: 1268, exploration: 0.01, score: 8
Run: 1269, exploration: 0.01, score: 10
Run: 1270, exploration: 0.01, score: 9
Run: 1271, exploration: 0.01, score: 10
Run: 1272, exploration: 0.01, score: 9
Run: 1273, exploration: 0.01, score: 10
Run: 1274, exploration: 0.01, score: 9
Run: 1275, exploration: 0.01, score: 8
Run: 1276, exploration: 0.01, score: 9
Run: 1277, exploration: 0.01, score: 10
Run: 1278, exploration: 0.01, score: 10
Run: 1279, exploration: 0.01, score: 10
Run: 1280, exploration: 0.01, score: 10
Run: 1281, exploration: 0.01, score: 10
Run: 1282, exploration: 0.01, score: 9
Run: 1283, exploration: 0.01, score: 9
Run: 1284, exploration: 0.01, score: 10
Run: 1285, exploration: 0.01, score: 10
Run: 1286, exploration: 0.01, score: 10
Run: 1287, exploration: 0.01, score: 9
Run: 1288, exploration: 0.01, score: 10
Run: 1289, e

Run: 1472, exploration: 0.01, score: 11
Run: 1473, exploration: 0.01, score: 10
Run: 1474, exploration: 0.01, score: 9
Run: 1475, exploration: 0.01, score: 9
Run: 1476, exploration: 0.01, score: 9
Run: 1477, exploration: 0.01, score: 10
Run: 1478, exploration: 0.01, score: 9
Run: 1479, exploration: 0.01, score: 10
Run: 1480, exploration: 0.01, score: 10
Run: 1481, exploration: 0.01, score: 8
Run: 1482, exploration: 0.01, score: 8
Run: 1483, exploration: 0.01, score: 10
Run: 1484, exploration: 0.01, score: 9
Run: 1485, exploration: 0.01, score: 10
Run: 1486, exploration: 0.01, score: 11
Run: 1487, exploration: 0.01, score: 9
Run: 1488, exploration: 0.01, score: 8
Run: 1489, exploration: 0.01, score: 10
Run: 1490, exploration: 0.01, score: 10
Run: 1491, exploration: 0.01, score: 10
Run: 1492, exploration: 0.01, score: 11
Run: 1493, exploration: 0.01, score: 9
Run: 1494, exploration: 0.01, score: 9
Run: 1495, exploration: 0.01, score: 9
Run: 1496, exploration: 0.01, score: 9
Run: 1497, ex

Run: 1680, exploration: 0.01, score: 9
Run: 1681, exploration: 0.01, score: 10
Run: 1682, exploration: 0.01, score: 8
Run: 1683, exploration: 0.01, score: 10
Run: 1684, exploration: 0.01, score: 10
Run: 1685, exploration: 0.01, score: 10
Run: 1686, exploration: 0.01, score: 11
Run: 1687, exploration: 0.01, score: 9
Run: 1688, exploration: 0.01, score: 10
Run: 1689, exploration: 0.01, score: 8
Run: 1690, exploration: 0.01, score: 9
Run: 1691, exploration: 0.01, score: 8
Run: 1692, exploration: 0.01, score: 11
Run: 1693, exploration: 0.01, score: 9
Run: 1694, exploration: 0.01, score: 8
Run: 1695, exploration: 0.01, score: 8
Run: 1696, exploration: 0.01, score: 8
Run: 1697, exploration: 0.01, score: 10
Run: 1698, exploration: 0.01, score: 9
Run: 1699, exploration: 0.01, score: 10
Run: 1700, exploration: 0.01, score: 11
Run: 1701, exploration: 0.01, score: 10
Run: 1702, exploration: 0.01, score: 9
Run: 1703, exploration: 0.01, score: 9
Run: 1704, exploration: 0.01, score: 9
Run: 1705, exp

Run: 1888, exploration: 0.01, score: 10
Run: 1889, exploration: 0.01, score: 9
Run: 1890, exploration: 0.01, score: 9
Run: 1891, exploration: 0.01, score: 10
Run: 1892, exploration: 0.01, score: 10
Run: 1893, exploration: 0.01, score: 9
Run: 1894, exploration: 0.01, score: 10
Run: 1895, exploration: 0.01, score: 9
Run: 1896, exploration: 0.01, score: 9
Run: 1897, exploration: 0.01, score: 9
Run: 1898, exploration: 0.01, score: 9
Run: 1899, exploration: 0.01, score: 10
Run: 1900, exploration: 0.01, score: 8
Run: 1901, exploration: 0.01, score: 8
Run: 1902, exploration: 0.01, score: 9
Run: 1903, exploration: 0.01, score: 10
Run: 1904, exploration: 0.01, score: 9
Run: 1905, exploration: 0.01, score: 11
Run: 1906, exploration: 0.01, score: 9
Run: 1907, exploration: 0.01, score: 9
Run: 1908, exploration: 0.01, score: 10
Run: 1909, exploration: 0.01, score: 10
Run: 1910, exploration: 0.01, score: 9
Run: 1911, exploration: 0.01, score: 8
Run: 1912, exploration: 0.01, score: 9
Run: 1913, explo

Run: 2096, exploration: 0.01, score: 10
Run: 2097, exploration: 0.01, score: 9
Run: 2098, exploration: 0.01, score: 10
Run: 2099, exploration: 0.01, score: 12
Run: 2100, exploration: 0.01, score: 10
Run: 2101, exploration: 0.01, score: 9
Run: 2102, exploration: 0.01, score: 10
Run: 2103, exploration: 0.01, score: 8
Run: 2104, exploration: 0.01, score: 10
Run: 2105, exploration: 0.01, score: 10
Run: 2106, exploration: 0.01, score: 9
Run: 2107, exploration: 0.01, score: 10
Run: 2108, exploration: 0.01, score: 10
Run: 2109, exploration: 0.01, score: 11
Run: 2110, exploration: 0.01, score: 9
Run: 2111, exploration: 0.01, score: 10
Run: 2112, exploration: 0.01, score: 10
Run: 2113, exploration: 0.01, score: 10
Run: 2114, exploration: 0.01, score: 9
Run: 2115, exploration: 0.01, score: 9
Run: 2116, exploration: 0.01, score: 9
Run: 2117, exploration: 0.01, score: 9
Run: 2118, exploration: 0.01, score: 10
Run: 2119, exploration: 0.01, score: 8
Run: 2120, exploration: 0.01, score: 9
Run: 2121, 

Run: 2304, exploration: 0.01, score: 9
Run: 2305, exploration: 0.01, score: 8
Run: 2306, exploration: 0.01, score: 10
Run: 2307, exploration: 0.01, score: 9
Run: 2308, exploration: 0.01, score: 10
Run: 2309, exploration: 0.01, score: 10
Run: 2310, exploration: 0.01, score: 9
Run: 2311, exploration: 0.01, score: 9
Run: 2312, exploration: 0.01, score: 10
Run: 2313, exploration: 0.01, score: 10
Run: 2314, exploration: 0.01, score: 10
Run: 2315, exploration: 0.01, score: 10
Run: 2316, exploration: 0.01, score: 10
Run: 2317, exploration: 0.01, score: 10
Run: 2318, exploration: 0.01, score: 8
Run: 2319, exploration: 0.01, score: 8
Run: 2320, exploration: 0.01, score: 9
Run: 2321, exploration: 0.01, score: 8
Run: 2322, exploration: 0.01, score: 9
Run: 2323, exploration: 0.01, score: 9
Run: 2324, exploration: 0.01, score: 10
Run: 2325, exploration: 0.01, score: 10
Run: 2326, exploration: 0.01, score: 10
Run: 2327, exploration: 0.01, score: 10
Run: 2328, exploration: 0.01, score: 9
Run: 2329, e

Run: 2512, exploration: 0.01, score: 8
Run: 2513, exploration: 0.01, score: 8
Run: 2514, exploration: 0.01, score: 8
Run: 2515, exploration: 0.01, score: 12
Run: 2516, exploration: 0.01, score: 9
Run: 2517, exploration: 0.01, score: 9
Run: 2518, exploration: 0.01, score: 10
Run: 2519, exploration: 0.01, score: 9
Run: 2520, exploration: 0.01, score: 9
Run: 2521, exploration: 0.01, score: 9
Run: 2522, exploration: 0.01, score: 9
Run: 2523, exploration: 0.01, score: 10
Run: 2524, exploration: 0.01, score: 9
Run: 2525, exploration: 0.01, score: 10
Run: 2526, exploration: 0.01, score: 10
Run: 2527, exploration: 0.01, score: 10
Run: 2528, exploration: 0.01, score: 9
Run: 2529, exploration: 0.01, score: 10
Run: 2530, exploration: 0.01, score: 10
Run: 2531, exploration: 0.01, score: 9
Run: 2532, exploration: 0.01, score: 10
Run: 2533, exploration: 0.01, score: 9
Run: 2534, exploration: 0.01, score: 10
Run: 2535, exploration: 0.01, score: 9
Run: 2536, exploration: 0.01, score: 8
Run: 2537, expl

Run: 2720, exploration: 0.01, score: 9
Run: 2721, exploration: 0.01, score: 9
Run: 2722, exploration: 0.01, score: 10
Run: 2723, exploration: 0.01, score: 10
Run: 2724, exploration: 0.01, score: 9
Run: 2725, exploration: 0.01, score: 9
Run: 2726, exploration: 0.01, score: 10
Run: 2727, exploration: 0.01, score: 9
Run: 2728, exploration: 0.01, score: 9
Run: 2729, exploration: 0.01, score: 9
Run: 2730, exploration: 0.01, score: 9
Run: 2731, exploration: 0.01, score: 11
Run: 2732, exploration: 0.01, score: 9
Run: 2733, exploration: 0.01, score: 9
Run: 2734, exploration: 0.01, score: 10
Run: 2735, exploration: 0.01, score: 9
Run: 2736, exploration: 0.01, score: 10
Run: 2737, exploration: 0.01, score: 10
Run: 2738, exploration: 0.01, score: 10
Run: 2739, exploration: 0.01, score: 10
Run: 2740, exploration: 0.01, score: 9
Run: 2741, exploration: 0.01, score: 10
Run: 2742, exploration: 0.01, score: 10
Run: 2743, exploration: 0.01, score: 10
Run: 2744, exploration: 0.01, score: 8
Run: 2745, ex

In [None]:
env.close()