In [1]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from tqdm import tqdm

from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from catboost import CatBoostRegressor
from sklearn.multioutput import MultiOutputRegressor
from lightgbm import LGBMRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split
from scores.score_logger import ScoreLogger

# Gym environment id; passed to both gym.make and ScoreLogger below.
ENV_NAME = "CartPole-v1"

# Discount factor applied to the best predicted next-state Q-value.
GAMMA = 0.95
# NOTE(review): not referenced anywhere in the visible code -- presumably a
# leftover from a Keras/Adam variant of this agent; confirm before removing.
LEARNING_RATE = 0.001

# Maximum number of transitions kept in the replay deque (oldest are dropped).
MEMORY_SIZE = 1000
# Minimum number of stored transitions before experience_replay fits the model.
# It is only used as a threshold: replay trains on the whole memory, not on
# BATCH_SIZE samples.
BATCH_SIZE = 20

# Epsilon-greedy schedule: starting value, lower bound, and the multiplicative
# decay applied once per successful experience_replay call.
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.05
EXPLORATION_DECAY = 0.95

Using TensorFlow backend.


In [2]:
class DQNSolver:
    """Epsilon-greedy Q-learning agent backed by a refit-from-scratch regressor.

    The Q-function is a multi-output regressor (one output per action) that is
    refit on the entire replay memory every time ``experience_replay`` runs.
    Until the first fit, all Q-values are treated as zero.

    Attributes:
        exploration_rate: current probability of taking a random action.
        action_space: number of discrete actions.
        memory: bounded deque of (state, action, reward, next_state, done).
        model: regressor exposing ``fit(X, y)`` and ``predict(X)``.
        isFit: True once ``model`` has been fitted at least once.
    """

    def __init__(self, observation_space, action_space):
        """Create an unfitted agent.

        Args:
            observation_space: size of the observation vector. Accepted for
                interface compatibility; the regressor infers the feature
                count from the data, so it is not used here.
            action_space: number of discrete actions.
        """
        self.exploration_rate = EXPLORATION_MAX

        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)

        # Alternatives that were tried in place of LightGBM:
        #   KNeighborsRegressor(n_jobs=-1)
        #   MultiOutputRegressor(SVR(), n_jobs=8)
        self.model = MultiOutputRegressor(LGBMRegressor(n_estimators=500, n_jobs=-1))
        self.isFit = False

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Pick an action for ``state`` using an epsilon-greedy policy.

        Args:
            state: observation with a leading batch axis of length 1
                (the caller reshapes it to ``[1, observation_space]``).

        Returns:
            Index of the chosen action.
        """
        if np.random.rand() < self.exploration_rate:
            return random.randrange(self.action_space)
        if self.isFit:
            q_values = self.model.predict(state)
        else:
            # No model yet: all Q-values are zero, so argmax picks action 0.
            q_values = np.zeros((1, self.action_space))
        return np.argmax(q_values[0])

    def experience_replay(self):
        """Refit the Q-model on the whole replay memory and decay exploration.

        Does nothing until at least ``BATCH_SIZE`` transitions are stored.
        For each stored transition the regression target is the model's
        current Q-vector for the state (zeros before the first fit), with the
        entry of the taken action replaced by ``reward`` for terminal
        transitions (and for every transition before the first fit), or by
        ``reward + GAMMA * max_a Q(next_state, a)`` otherwise.

        Predictions are made in two batched ``predict`` calls (states and
        next states) instead of one or two calls per transition.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        # The whole memory is used, in random order; BATCH_SIZE is only the
        # minimum amount of data required before fitting.
        batch = random.sample(self.memory, len(self.memory))
        states, actions, rewards, next_states, terminals = zip(*batch)

        # Every stored state carries a leading axis of length 1; drop it so
        # the rows stack into an (n_samples, n_features) matrix.
        X = np.array([s[0] for s in states], dtype=float)
        X_next = np.array([s[0] for s in next_states], dtype=float)
        actions = np.asarray(actions, dtype=int)
        rewards = np.asarray(rewards, dtype=float)
        terminals = np.asarray(terminals, dtype=bool)

        if self.isFit:
            # np.array(...) copies, so the targets can be edited in place.
            targets = np.array(self.model.predict(X), dtype=float)
            best_next = np.asarray(self.model.predict(X_next), dtype=float).max(axis=1)
            q_updates = np.where(terminals, rewards, rewards + GAMMA * best_next)
        else:
            targets = np.zeros((len(batch), self.action_space))
            q_updates = rewards

        # Overwrite only the Q-value of the action actually taken in each row.
        targets[np.arange(len(batch)), actions] = q_updates

        self.model.fit(X, targets)
        self.isFit = True
        self.exploration_rate = max(EXPLORATION_MIN,
                                    self.exploration_rate * EXPLORATION_DECAY)

In [3]:
# Build the environment, the score logger and the agent.
env = gym.make(ENV_NAME)
score_logger = ScoreLogger(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space)
run = 0
# The episode loop has no exit condition of its own: it stops only when an
# exception propagates (a kernel interrupt, or something raised by a callee).
# NOTE(review): the saved output ends with "NameError: name 'exit' is not
# defined"; its origin is not visible here -- presumably ScoreLogger.add_score
# calls exit() once the task counts as solved. TODO confirm.
# The try/finally makes sure the environment is released either way.
try:
    while True:
        run += 1
        # NOTE(review): reset()/step() are unpacked with the classic gym API
        # (observation only / 4-tuple); newer gym releases return extra values.
        state = env.reset()
        # The agent expects observations with a leading batch axis of 1.
        state = np.reshape(state, [1, observation_space])
        step = 0
        while True:
            step += 1
            #env.render()
            action = dqn_solver.act(state)
            state_next, reward, terminal, info = env.step(action)
            # Flip the sign of the reward on the step that ends the episode.
            reward = reward if not terminal else -reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, terminal)
            state = state_next
            if terminal:
                print("Run: " + str(run) + ", exploration: " + str(dqn_solver.exploration_rate) + ", score: " + str(step))
                score_logger.add_score(step, run)
                break
        # Refit the model once per episode, after the episode has finished.
        dqn_solver.experience_replay()
finally:
    env.close()

Run: 1, exploration: 1.0, score: 32
Scores: (min: 32, avg: 32, max: 32)

Run: 2, exploration: 0.95, score: 18
Scores: (min: 18, avg: 25, max: 32)



  z = np.polyfit(np.array(trend_x), np.array(y[1:]), 1)


Run: 3, exploration: 0.9025, score: 24
Scores: (min: 18, avg: 24.666666666666668, max: 32)

Run: 4, exploration: 0.8573749999999999, score: 22
Scores: (min: 18, avg: 24, max: 32)

Run: 5, exploration: 0.8145062499999999, score: 27
Scores: (min: 18, avg: 24.6, max: 32)

Run: 6, exploration: 0.7737809374999999, score: 40
Scores: (min: 18, avg: 27.166666666666668, max: 40)

Run: 7, exploration: 0.7350918906249998, score: 31
Scores: (min: 18, avg: 27.714285714285715, max: 40)

Run: 8, exploration: 0.6983372960937497, score: 20
Scores: (min: 18, avg: 26.75, max: 40)

Run: 9, exploration: 0.6634204312890623, score: 18
Scores: (min: 18, avg: 25.77777777777778, max: 40)

Run: 10, exploration: 0.6302494097246091, score: 26
Scores: (min: 18, avg: 25.8, max: 40)

Run: 11, exploration: 0.5987369392383786, score: 19
Scores: (min: 18, avg: 25.181818181818183, max: 40)

Run: 12, exploration: 0.5688000922764596, score: 40
Scores: (min: 18, avg: 26.416666666666668, max: 40)

Run: 13, exploration: 0.540

Run: 87, exploration: 0.05, score: 76
Scores: (min: 13, avg: 72.10344827586206, max: 198)

Run: 88, exploration: 0.05, score: 209
Scores: (min: 13, avg: 73.6590909090909, max: 209)

Run: 89, exploration: 0.05, score: 97
Scores: (min: 13, avg: 73.92134831460675, max: 209)

Run: 90, exploration: 0.05, score: 171
Scores: (min: 13, avg: 75, max: 209)

Run: 91, exploration: 0.05, score: 124
Scores: (min: 13, avg: 75.53846153846153, max: 209)

Run: 92, exploration: 0.05, score: 163
Scores: (min: 13, avg: 76.48913043478261, max: 209)

Run: 93, exploration: 0.05, score: 234
Scores: (min: 13, avg: 78.18279569892474, max: 234)

Run: 94, exploration: 0.05, score: 161
Scores: (min: 13, avg: 79.06382978723404, max: 234)

Run: 95, exploration: 0.05, score: 191
Scores: (min: 13, avg: 80.2421052631579, max: 234)

Run: 96, exploration: 0.05, score: 187
Scores: (min: 13, avg: 81.35416666666667, max: 234)

Run: 97, exploration: 0.05, score: 129
Scores: (min: 13, avg: 81.84536082474227, max: 234)

Run: 98

NameError: name 'exit' is not defined