In [1]:
import os

In [2]:
path = 'C:\\Users\\raven\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python37\\site-packages'

In [3]:
os.chdir(path)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import gym
from collections import deque

In [5]:
class act_model(nn.Module):
    def __init__(self,inp,hidden,output):
        super(act_model, self).__init__()
        self.fc1 = nn.Linear(inp, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, hidden, bias=True)
        self.fc3 = nn.Linear(hidden, output, bias=True)
        self.fc12 = nn.LeakyReLU()
        
        self.memory = deque(maxlen=200)
        
        self.gamma = .95
        self.epsilon = 1.0 #exploration rate
        self.epsilon_min = .001
        self.epsilon_decay = .995
        self.tau = .01
        
        self.mse = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters(),lr=.001)
        
    def action(self,state):
        #choose action with probability
        if(random.random() <= self.epsilon):
            return np.random.choice(out, 1)[0]
        else:
            q_values = self.forward(state)
            return np.argmax(q_values.detach().numpy())  #Q(s,a)
            
    def memorize(self,state,action,reward,next_state,done):
        self.memory.append((state,action,reward,next_state,done))
        
    def replay(self,batch_size,target_model):
        if(len(self.memory) < batch_size): return 0 
        minibatch = random.sample(self.memory,batch_size)
        for state,action,reward,next_state,done in minibatch:
            target = reward
            if not done:
                target_q_values = target_model.forward(next_state)
                primary_q_values = self.forward(next_state)
                
                max_target_action = np.argmax(target_q_values.detach().numpy())
                primary_q_value = primary_q_values[max_target_action]
                target_q_value = primary_q_values[max_target_action]
                
                min_action = min(primary_q_value,target_q_value)
                
                target = (reward + self.gamma*min_action)
                
            target_f = self.forward(state)
            target_f[action] = target
            target_g = self.forward(state)

            self.zero_grad()
            self.optimizer.zero_grad()
            
            loss = self.mse(target_g,target_f)
            loss.backward(retain_graph=True)
            self.optimizer.step() 
        
            # target network update
            for target_param, param in zip(target_model.parameters(), self.parameters()):
                target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
        
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
            
    def load(self,PATH):
        torch.save(self, PATH)

    def save(self,PATH):
        model = torch.load(PATH)
        return model
        
    def forward(self,x):
        out = self.fc12(self.fc1(x))
        out = self.fc12(self.fc2(out))
        out = self.fc3(out)
        
        return out

In [6]:
inp = 4
hid = 24
out = 2
primary_model = act_model(inp,hid,out)
target_model = act_model(inp,hid,out)

env = gym.make('CartPole-v0')
epochs = 2000
batch_size = 50

for epoch in range(epochs):
    state = env.reset()
    for t in range(200):
        state = torch.FloatTensor(state)
        action = primary_model.action(state)
        next_state, reward, done, info = env.step(action)
        reward = reward if not done else -10
        
        next_state = torch.FloatTensor(next_state)
        primary_model.memorize(state,action,reward,next_state,done)
        state = next_state
        
        primary_model.replay(batch_size,target_model)

        if done:
            print("episode: {}/{}, score: {}, e: {:.2}"
                  .format(epoch, epochs, t, primary_model.epsilon))
            break   
env.close()

episode: 1/2000, score: 163, e: 0.21
episode: 2/2000, score: 189, e: 0.08
episode: 4/2000, score: 154, e: 0.013
episode: 6/2000, score: 99, e: 0.003
episode: 8/2000, score: 105, e: 0.001
episode: 9/2000, score: 107, e: 0.001
episode: 11/2000, score: 102, e: 0.001
episode: 12/2000, score: 176, e: 0.001
episode: 13/2000, score: 170, e: 0.001
episode: 14/2000, score: 95, e: 0.001
episode: 16/2000, score: 87, e: 0.001
episode: 18/2000, score: 122, e: 0.001
episode: 19/2000, score: 141, e: 0.001
episode: 21/2000, score: 125, e: 0.001
episode: 23/2000, score: 103, e: 0.001
episode: 25/2000, score: 85, e: 0.001
episode: 26/2000, score: 116, e: 0.001
episode: 27/2000, score: 128, e: 0.001
episode: 28/2000, score: 152, e: 0.001
episode: 29/2000, score: 115, e: 0.001
episode: 30/2000, score: 122, e: 0.001
episode: 32/2000, score: 117, e: 0.001
episode: 34/2000, score: 96, e: 0.001


KeyboardInterrupt: 