## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import gym
from collections import deque

## Q-learning model

In [2]:
class act_model(nn.Module):
    """Q-network agent with a replay buffer and epsilon-greedy exploration.

    The network has two hidden layers followed by a scalar value head and a
    per-action advantage head that are combined as ``v + (a - mean(a))``.
    The instance also owns its replay memory, exploration schedule, loss and
    optimizer, so a training loop only needs ``action``, ``memorize`` and
    ``replay``.

    Args:
        inp: Number of features in one observation.
        hidden: Width of the hidden layers.
        output: Number of discrete actions (size of the advantage head).
    """

    def __init__(self,inp,hidden,output):
        super(act_model, self).__init__()

        #Pytorch variables
        self.fc1 = nn.Linear(inp, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, hidden, bias=True)
        # NOTE(review): fc3 is not used by forward(); it is kept so that
        # parameters() / state_dict() keep the same layout.
        self.fc3 = nn.Linear(hidden, hidden, bias=True)
        self.fc4 = nn.Linear(hidden, 1, bias=True)       # value head
        self.fc5 = nn.Linear(hidden, output, bias=True)  # advantage head
        self.fc12 = nn.LeakyReLU()                       # shared activation

        #memory buffer (oldest transitions are dropped beyond 500 entries)
        self.memory = deque(maxlen=500)

        #q-learning hyperparameters
        self.gamma = .95           # discount applied to the bootstrapped value
        self.epsilon = 1.0         # exploration rate
        self.epsilon_min = .001    # decay stops once epsilon is at or below this
        self.epsilon_decay = .995  # multiplicative decay applied per replay()
        self.tau = .01             # soft-update mixing factor for the target net

        #loss & optimizer
        self.mse = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters(),lr=.001)

    #Choose epsilon-greedy action
    def action(self,state):
        """Return an action index for ``state`` using epsilon-greedy selection.

        With probability ``self.epsilon`` a uniformly random action is drawn;
        otherwise the action with the highest predicted Q-value is returned.
        The number of actions is taken from the advantage head itself, so the
        method does not depend on any module-level variable.
        """
        if(random.random() <= self.epsilon):
            return int(np.random.choice(self.fc5.out_features))
        # No gradient is needed for action selection.
        with torch.no_grad():
            q_values = self.forward(state)
        return int(np.argmax(q_values.numpy()))  #argmax_a Q(s,a)

    #save state,action,reward, & next state in replay buffer
    def memorize(self,state,action,reward,next_state,done):
        """Append one transition tuple to the replay buffer."""
        self.memory.append((state,action,reward,next_state,done))

    #get batch from memory buffer and train pytorch model
    def replay(self,batch_size,target_model):
        """Train on a random minibatch and soft-update ``target_model``.

        Each sampled transition is used for one optimizer step. For
        non-terminal transitions the bootstrap value is the smaller of the
        primary and target networks' Q-values at the target network's greedy
        next action. After every step the target parameters are moved towards
        the primary parameters by a factor ``self.tau``. Epsilon is decayed
        once per call.

        Assumes every stored state is a single unbatched 1-D observation.

        Args:
            batch_size: Number of transitions to sample.
            target_model: Network providing the second Q estimate; its
                parameters are updated in place.

        Returns:
            ``0`` if the buffer holds fewer than ``batch_size`` transitions
            (nothing is trained); otherwise ``None``.
        """
        #if batch size too small, don't train network
        if(len(self.memory) < batch_size): return 0
        minibatch = random.sample(self.memory,batch_size)

        #Train Q-networks
        for state,action,reward,next_state,done in minibatch:
            target = reward
            if not done:
                # The bootstrap target is a constant for this update, so it
                # is computed without tracking gradients.
                with torch.no_grad():
                    target_q_values = target_model.forward(next_state)
                    primary_q_values = self.forward(next_state)

                #find minimum value between the two networks of the max action
                max_target_action = int(np.argmax(target_q_values.numpy()))
                primary_q_value = primary_q_values[max_target_action].item()
                target_q_value = target_q_values[max_target_action].item()
                min_q_value = min(primary_q_value,target_q_value)

                #get target value
                target = reward + self.gamma*min_q_value

            # Prediction for the current state; the regression target is a
            # detached copy whose chosen-action entry is replaced, so only
            # that entry contributes to the loss.
            q_pred = self.forward(state)
            q_target = q_pred.detach().clone()
            q_target[action] = target

            #Primary network update
            self.optimizer.zero_grad()
            loss = self.mse(q_pred,q_target)
            loss.backward()
            self.optimizer.step()

            #Target network update (soft update towards the primary weights)
            with torch.no_grad():
                for target_param, param in zip(target_model.parameters(), self.parameters()):
                    target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

        #adjust epsilon value to take less random actions
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    #load pytorch model
    def load(self,PATH):
        """Load weights written by ``save`` from ``PATH`` into this model.

        Returns the model itself so the call can be chained or assigned.
        NOTE: ``torch.load`` unpickles data; only load files you trust.
        """
        self.load_state_dict(torch.load(PATH))
        return self

    #save pytorch model
    def save(self,PATH):
        """Write this model's ``state_dict`` (network weights) to ``PATH``."""
        torch.save(self.state_dict(), PATH)

    #run forward pass of pytorch model
    def forward(self,x):
        """Compute Q-values for ``x``.

        ``x`` may be a single observation of shape ``(inp,)`` or a batch of
        shape ``(batch, inp)``; the advantage mean is taken over the action
        dimension so each row is handled independently.
        """
        h = self.fc12(self.fc1(x))
        h = self.fc12(self.fc2(h))

        v = self.fc12(self.fc4(h))
        a = self.fc12(self.fc5(h))

        return v+(a-a.mean(dim=-1, keepdim=True))

## Initialize parameters

In [3]:
# Network dimensions: observation size, hidden width, number of actions.
# These match the CartPole-v0 environment used in the training cell.
inp = 4
hid = 24
out = 2

# Training schedule: number of episodes and replay minibatch size.
epochs = 2000
batch_size = 50

# Seed every random source the agent draws from so runs are repeatable.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

primary_model = act_model(inp,hid,out)
target_model = act_model(inp,hid,out)
# Start the target network from the same weights as the primary network so
# both Q estimates agree before the first soft update.
target_model.load_state_dict(primary_model.state_dict())

## Train models

In [4]:
env = gym.make('CartPole-v0')
# Upper bound on steps per episode; assumed to equal the environment's own
# time limit (200 for CartPole-v0) - confirm if the environment id changes.
max_steps = 200
try:
    for epoch in range(epochs):
        state = torch.FloatTensor(env.reset())
        for t in range(max_steps):
            action = primary_model.action(state)
            next_state, reward, done, info = env.step(action)

            # An episode that ends because the step limit was reached is a
            # success, not a failure: only penalise (and mark as terminal)
            # episodes that ended before the limit.
            timed_out = info.get('TimeLimit.truncated', False) or t == max_steps - 1
            failed = done and not timed_out
            reward = reward if not failed else -10

            next_state = torch.FloatTensor(next_state)
            primary_model.memorize(state,action,reward,next_state,failed)
            state = next_state

            primary_model.replay(batch_size,target_model)

            if done:
                break
        print("episode: {}/{}, score: {}, e: {:.2}"
              .format(epoch, epochs, t, primary_model.epsilon))
finally:
    # Release the environment even if training is interrupted.
    env.close()

episode: 0/2000, score: 40, e: 1.0
episode: 1/2000, score: 19, e: 0.94
episode: 2/2000, score: 15, e: 0.87
episode: 3/2000, score: 26, e: 0.76
episode: 4/2000, score: 19, e: 0.69
episode: 5/2000, score: 35, e: 0.57
episode: 6/2000, score: 15, e: 0.53
episode: 7/2000, score: 8, e: 0.51
episode: 8/2000, score: 22, e: 0.45
episode: 9/2000, score: 29, e: 0.39
episode: 10/2000, score: 9, e: 0.37
episode: 11/2000, score: 34, e: 0.31
episode: 12/2000, score: 37, e: 0.26
episode: 13/2000, score: 68, e: 0.18
episode: 14/2000, score: 38, e: 0.15
episode: 15/2000, score: 22, e: 0.13
episode: 16/2000, score: 35, e: 0.11
episode: 17/2000, score: 61, e: 0.081
episode: 18/2000, score: 28, e: 0.07
episode: 19/2000, score: 27, e: 0.061
episode: 20/2000, score: 38, e: 0.05
episode: 21/2000, score: 19, e: 0.045
episode: 22/2000, score: 26, e: 0.04
episode: 23/2000, score: 22, e: 0.035
episode: 24/2000, score: 28, e: 0.031
episode: 25/2000, score: 34, e: 0.026
episode: 26/2000, score: 141, e: 0.013
episod

KeyboardInterrupt: 