In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym

In [2]:
# Hyper Parameters
BATCH_SIZE = 32
LR = 0.01  # learning rate
EPSILON = 0.9  # greedy policy
GAMMA = 0.9  # reward discount
TARGET_REPLACE_ITER = 100  # target update frequency
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v0')
env = env.unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(),
                              int) else env.action_space.sample().shape  # to confirm the shape


In [3]:
# Policy Network
class Net(nn.Module):
    def __init__(self, ):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 50)
        self.fc1.weight.data.normal_(0, 0.1)  # initialization
        self.out = nn.Linear(50, N_ACTIONS)
        self.out.weight.data.normal_(0, 0.1)  # initialization

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value


class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()

        self.learn_step_counter = 0  # for target updating
        self.memory_counter = 0  # for storing memory
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))  # initialize memory
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()

    def choose_action(self, x):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)
        # input only one sample
        if np.random.uniform() < EPSILON:  # greedy
            actions_value = self.eval_net.forward(x)
            action = torch.max(actions_value, 1)[1].data.numpy()
            action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)  # return the argmax index
        else:  # random
            action = np.random.randint(0, N_ACTIONS)
            action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
        return action

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        # target parameter update
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # sample batch transitions
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        # q_eval w.r.t the action in experience
        q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
        q_next = self.target_net(b_s_).detach()  # detach from graph, don't backpropagate
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)  # shape (batch, 1)
        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


dqn = DQN()

print('\nCollecting experience...')
for i_episode in range(400):
    s = env.reset()
    ep_r = 0
    print('Episode ', i_episode)

    cnt = 0
    while True:
        #env.render()
        a = dqn.choose_action(s)

        # action
        s_, r, done, info = env.step(a)

        # modify the reward

        # see https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
        x, x_dot, theta, theta_dot = s_  # CartPole states, including pos, velocity, angle, tip velocity
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8  # pos is not deviating too much
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5  # angle is not too big
        r = r1 + r2
        
        dqn.store_transition(s, a, r, s_)

        ep_r += r
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode,
                      '| Ep_reward: ', round(ep_r, 2))

        if done:
            print(' steps ', cnt)
            break
        s = s_
        cnt += 1


Collecting experience...
Episode  0
 steps  8
Episode  1
 steps  8
Episode  2
 steps  9
Episode  3
 steps  10
Episode  4
 steps  11
Episode  5
 steps  11
Episode  6
 steps  9
Episode  7
 steps  9
Episode  8
 steps  8
Episode  9
 steps  9
Episode  10
 steps  7
Episode  11
 steps  12
Episode  12
 steps  8
Episode  13
 steps  12
Episode  14
 steps  8
Episode  15
 steps  10
Episode  16
 steps  9
Episode  17
 steps  9
Episode  18
 steps  9
Episode  19
 steps  8
Episode  20
 steps  8
Episode  21
 steps  8
Episode  22
 steps  8
Episode  23
 steps  8
Episode  24
 steps  8
Episode  25
 steps  10
Episode  26
 steps  8
Episode  27
 steps  9
Episode  28
 steps  10
Episode  29
 steps  10
Episode  30
 steps  9
Episode  31
 steps  9
Episode  32
 steps  10
Episode  33
 steps  9
Episode  34
 steps  9
Episode  35
 steps  9
Episode  36
 steps  9
Episode  37
 steps  10
Episode  38
 steps  8
Episode  39
 steps  9
Episode  40
 steps  9
Episode  41
 steps  9
Episode  42
 steps  8
Episode  43
 steps  9
Episo

Ep:  277 | Ep_r:  412.82
 steps  1135
Episode  278
Ep:  278 | Ep_r:  662.08
 steps  2469
Episode  279
Ep:  279 | Ep_r:  492.14
 steps  1273
Episode  280
Ep:  280 | Ep_r:  321.62
 steps  956
Episode  281
Ep:  281 | Ep_r:  1474.53
 steps  3117
Episode  282
Ep:  282 | Ep_r:  701.61
 steps  2437
Episode  283
Ep:  283 | Ep_r:  1288.69
 steps  3863
Episode  284
Ep:  284 | Ep_r:  433.14
 steps  1234
Episode  285
Ep:  285 | Ep_r:  672.15
 steps  1689
Episode  286
Ep:  286 | Ep_r:  366.71
 steps  1753
Episode  287
Ep:  287 | Ep_r:  320.46
 steps  1708
Episode  288
Ep:  288 | Ep_r:  406.51
 steps  1162
Episode  289
Ep:  289 | Ep_r:  791.26
 steps  1735
Episode  290
Ep:  290 | Ep_r:  239.38
 steps  782
Episode  291
Ep:  291 | Ep_r:  561.36
 steps  1604
Episode  292
Ep:  292 | Ep_r:  433.41
 steps  1112
Episode  293
Ep:  293 | Ep_r:  320.28
 steps  1830
Episode  294
Ep:  294 | Ep_r:  369.89
 steps  1201
Episode  295
Ep:  295 | Ep_r:  500.34
 steps  1524
Episode  296
Ep:  296 | Ep_r:  254.99
 steps