In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import copy
import gym
import random
import time
from collections import deque

In [2]:
GAMMA = .62
ALPHA = .7
BATCH_SIZE = 128
MIN_REPLAY_SIZE = 1000
SAMPLE_SIZE = 50000
MAX_EPOCHS = 1000
EPS_DECAY = .001
EPSILON_INITIAL = 1

In [3]:
class Net(nn.Module):
    def __init__(self, state_size, action_size):
        super(Net, self).__init__()
        learning_rate = 0.001
        self.w1 = nn.Linear(state_size, 24)
        self.w2 = nn.Linear(24, 12)
        self.w3 = nn.Linear(12, action_size)
        self.loss_ = nn.SmoothL1Loss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        
    def forward(self, x):
        x = self.w1(x)
        x = F.relu(x)
        x = self.w2(x)
        x = F.relu(x)
        x = self.w3(x)
        return x
    
    def backwards(self, my_loss):
        self.optimizer.zero_grad()
        my_loss.backward()
        self.optimizer.step()
    
    def loss(self, y_ex, y_act):
        return self.loss_(y_ex, y_act)

def train(replay_memory, main, target, batch_size):
    if len(replay_memory) < MIN_REPLAY_SIZE:
        return # wait till replay memory filled
    
    mini_batch = random.sample(replay_memory, batch_size)
    s_tminus1 = [memory[0] for memory in mini_batch]
    s_tminus1 = torch.FloatTensor(s_tminus1)
    qs_tminus1 = main.forward(s_tminus1)
    
    s_t = [memory[3] for memory in mini_batch]
    s_t = torch.FloatTensor(s_t)
    qs_t = target.forward(s_t)
    
    X = []
    Y = []
    
    # Gather Ys Q values for model
    for index, (prev_obs, reward, action, obs, done) in enumerate(mini_batch):
        q_tminus1 = qs_tminus1[index]
        q_max = torch.max(qs_t[index])
        if (not done):
            sample = reward + GAMMA * q_max
        else:
            sample = reward

        q_tminus1[action] = (1-ALPHA) * q_tminus1[action] + ALPHA * sample 
        X.append(prev_obs)
        Y.append(q_tminus1.tolist())

    # run minibatch
#     print(Y)
    Y_act = main.forward(torch.FloatTensor(X))
    loss = main.loss(torch.FloatTensor(Y), Y_act)
    if (batch_size > BATCH_SIZE):
        print(loss)
    main.backwards(loss)
        

In [4]:
main = Net(4,2)
target = copy.deepcopy(main)
print(target)
print(main)

Net(
  (w1): Linear(in_features=4, out_features=24, bias=True)
  (w2): Linear(in_features=24, out_features=12, bias=True)
  (w3): Linear(in_features=12, out_features=2, bias=True)
  (loss_): SmoothL1Loss()
)
Net(
  (w1): Linear(in_features=4, out_features=24, bias=True)
  (w2): Linear(in_features=24, out_features=12, bias=True)
  (w3): Linear(in_features=12, out_features=2, bias=True)
  (loss_): SmoothL1Loss()
)


In [6]:
env = gym.make('CartPole-v0')
env.seed(5)  # comment out after testing

max_step = 100
displacement = 4
epoch = 0
step = 0

main = Net(4,2)
target = copy.deepcopy(main)


experience_replay = deque(maxlen=SAMPLE_SIZE)

epsilon = EPSILON_INITIAL
while(epoch < MAX_EPOCHS):
    if (epsilon < .01):
        epsilon = .01
    prev_obs = env.reset()
    while True:
        step += 1

        # Epsilon Greedy
        eps_sample = random.random()        
        if eps_sample >= epsilon:
            q = main.forward(torch.FloatTensor(prev_obs))
            # exploitation
            action = torch.argmax(q).item()
        else:
            # exploration
            action = env.action_space.sample()
        # decay epsilon after every random sample
        # for more exploitation over time
        epsilon = (1 - EPS_DECAY) * epsilon

        obs, reward, done, info = env.step(env.action_space.sample())
        experience_replay.append([prev_obs, reward, action, obs, done])


        if (step % displacement == 0 and len(experience_replay) >= MIN_REPLAY_SIZE):
            train(experience_replay, main, target, BATCH_SIZE)


        prev_obs = obs
        if (step >= max_step):
            train(experience_replay, main, target, len(experience_replay))
            
            
#             print("copying main to target")
            print("Epoch", epoch, "finished")
            target = copy.deepcopy(main)
            step = 0
            epoch += 1
            if (epoch >= MAX_EPOCHS):
                break
        if done:
            break

#     env.render()
    step += 1
print("Finished!")
env.close()

Epoch 0 finished
Epoch 1 finished
Epoch 2 finished
Epoch 3 finished
Epoch 4 finished
Epoch 5 finished
Epoch 6 finished
Epoch 7 finished
Epoch 8 finished
Epoch 9 finished
tensor(0.1057, grad_fn=<SmoothL1LossBackward>)
Epoch 10 finished
tensor(0.0820, grad_fn=<SmoothL1LossBackward>)
Epoch 11 finished
tensor(0.0744, grad_fn=<SmoothL1LossBackward>)
Epoch 12 finished
tensor(0.0714, grad_fn=<SmoothL1LossBackward>)
Epoch 13 finished
tensor(0.0732, grad_fn=<SmoothL1LossBackward>)
Epoch 14 finished
tensor(0.0693, grad_fn=<SmoothL1LossBackward>)
Epoch 15 finished
tensor(0.0570, grad_fn=<SmoothL1LossBackward>)
Epoch 16 finished
tensor(0.0414, grad_fn=<SmoothL1LossBackward>)
Epoch 17 finished
tensor(0.0281, grad_fn=<SmoothL1LossBackward>)
Epoch 18 finished
tensor(0.0207, grad_fn=<SmoothL1LossBackward>)
Epoch 19 finished
tensor(0.0164, grad_fn=<SmoothL1LossBackward>)
Epoch 20 finished
tensor(0.0151, grad_fn=<SmoothL1LossBackward>)
Epoch 21 finished
tensor(0.0135, grad_fn=<SmoothL1LossBackward>)
Epo

tensor(0.0035, grad_fn=<SmoothL1LossBackward>)
Epoch 133 finished


KeyboardInterrupt: 

In [7]:
env = gym.make('CartPole-v0')
obs = env.reset()
N = 1000
for t in range(N):
    env.render()
    time.sleep(.1)
    x = torch.from_numpy(obs).float()
    action_vect = main.forward(x)
    action = torch.argmax(action_vect)
    obs, reward, done, info = env.step(action.item())
    if (done):
        print(obs)
        print("Model failed! on", t, "Could not survive", N, "steps")
        break;
if t == N:
    print("model success!")
time.sleep(1)
env.close() 

[ 0.09813787  0.56478765 -0.21013758 -1.25954689]
Model failed! on 20 Could not survive 1000 steps


In [74]:
print(env.action_space)
print(env.observation_space)

Discrete(2)
Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)


In [19]:
env.close()