In [1]:
import gym
import numpy as np
import random
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import Adam

from collections import deque
import os

Using TensorFlow backend.


In [2]:
env = gym.make('Acrobot-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
batch_size = 32
n_episodes = 1000
output_dir = 'model_output/Acrobat/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

In [3]:
class AcrobatArm():
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.learning_rate = 0.001
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.memory = deque(maxlen = 2000)
        self.model = self.build_model()
        
    def build_model(self):
        model = Sequential()
        model.add(Dense(32, activation = 'relu', input_dim = self.state_size))
        model.add(Dense(32, activation = 'relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(learning_rate=self.learning_rate))
        
        return model
    
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        
    def train(self, batch_size):
        if len(agent.memory) >= batch_size:
            mini_batch = random.sample(self.memory, batch_size)

            if done:
                target = reward
            else:
                target = reward + np.max(self.gamma*self.model.predict(next_state))

            y = self.model.predict(state)
            y[0][action] = target

            self.model.fit(state, y, epochs=1, verbose=0)

        
    def act(self, state):
        if self.epsilon > np.random.randn():
            return env.action_space.sample()
        pred = self.model.predict(state)
        return np.argmax(pred[0])
    
    def save(self, name): 
        self.model.save_weights(name)

In [4]:
agent = AcrobatArm(state_size, action_size)

In [5]:
for epi in range(n_episodes):
    state = env.reset().reshape(1, state_size)
    
    done = False
    time= 0
    reward_in_epi = []
    while not done:
        env.render()
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        reward_in_epi.append(reward)
        
        #reward = reward if not done else -10
        next_state = np.reshape(next_state, [1, state_size])
        agent.remember(state, action, reward, next_state, done)
        agent.train(batch_size)
        state = next_state
        if done: 
            if time >=199:
                print("episode: {}/{}, TReward {}, e: {:.2}".format(epi, n_episodes-1, np.mean(reward_in_epi), agent.epsilon))
            else:
                print('-'*20 + 'Reached  episode: {}/{} TReward {}, e: {:.2}'.format(epi, n_episodes-1, np.mean(reward_in_epi), agent.epsilon))
        
        time+=1

        
    if agent.epsilon > agent.epsilon_min:
            agent.epsilon *= agent.epsilon_decay
            
    if epi%50 == 0:
        agent.save(output_dir + "weights_" + '{:04d}'.format(epi) + ".hdf5")

episode: 0/999, TReward -1.0, e: 1.0
episode: 1/999, TReward -0.9970674486803519, e: 0.99
episode: 2/999, TReward -1.0, e: 0.99
episode: 3/999, TReward -1.0, e: 0.99
episode: 4/999, TReward -1.0, e: 0.98
episode: 5/999, TReward -1.0, e: 0.98
episode: 6/999, TReward -0.997229916897507, e: 0.97
episode: 7/999, TReward -1.0, e: 0.97
episode: 8/999, TReward -1.0, e: 0.96
episode: 9/999, TReward -1.0, e: 0.96
episode: 10/999, TReward -1.0, e: 0.95
episode: 11/999, TReward -1.0, e: 0.95
episode: 12/999, TReward -1.0, e: 0.94
episode: 13/999, TReward -1.0, e: 0.94
episode: 14/999, TReward -1.0, e: 0.93
episode: 15/999, TReward -0.9974226804123711, e: 0.93
episode: 16/999, TReward -1.0, e: 0.92
episode: 17/999, TReward -1.0, e: 0.92
episode: 18/999, TReward -0.9979381443298969, e: 0.91
episode: 19/999, TReward -1.0, e: 0.91
episode: 20/999, TReward -1.0, e: 0.9
episode: 21/999, TReward -1.0, e: 0.9
episode: 22/999, TReward -0.9971590909090909, e: 0.9
episode: 23/999, TReward -1.0, e: 0.89
epis

episode: 190/999, TReward -0.9977064220183486, e: 0.39
episode: 191/999, TReward -1.0, e: 0.38
episode: 192/999, TReward -0.9975728155339806, e: 0.38
episode: 193/999, TReward -1.0, e: 0.38
episode: 194/999, TReward -0.9977678571428571, e: 0.38
episode: 195/999, TReward -1.0, e: 0.38
episode: 196/999, TReward -1.0, e: 0.37
episode: 197/999, TReward -0.9978308026030369, e: 0.37
episode: 198/999, TReward -0.9974874371859297, e: 0.37
episode: 199/999, TReward -1.0, e: 0.37
episode: 200/999, TReward -0.9979550102249489, e: 0.37
episode: 201/999, TReward -0.9975786924939467, e: 0.37
episode: 202/999, TReward -0.9967105263157895, e: 0.36
episode: 203/999, TReward -0.9979423868312757, e: 0.36
episode: 204/999, TReward -0.9972527472527473, e: 0.36
episode: 205/999, TReward -0.9969879518072289, e: 0.36
episode: 206/999, TReward -1.0, e: 0.36
episode: 207/999, TReward -1.0, e: 0.35
episode: 208/999, TReward -0.9971509971509972, e: 0.35
episode: 209/999, TReward -1.0, e: 0.35
episode: 210/999, TR

episode: 347/999, TReward -0.9970588235294118, e: 0.18
episode: 348/999, TReward -0.996309963099631, e: 0.17
episode: 349/999, TReward -0.9977876106194691, e: 0.17
episode: 350/999, TReward -0.9978586723768736, e: 0.17
episode: 351/999, TReward -0.9977116704805492, e: 0.17
episode: 352/999, TReward -1.0, e: 0.17
episode: 353/999, TReward -0.9978586723768736, e: 0.17
episode: 354/999, TReward -1.0, e: 0.17
episode: 355/999, TReward -0.9971509971509972, e: 0.17
episode: 356/999, TReward -1.0, e: 0.17
episode: 357/999, TReward -1.0, e: 0.17
episode: 358/999, TReward -1.0, e: 0.17
episode: 359/999, TReward -1.0, e: 0.17
episode: 360/999, TReward -0.9979423868312757, e: 0.16
episode: 361/999, TReward -1.0, e: 0.16
episode: 362/999, TReward -0.9974937343358395, e: 0.16
episode: 363/999, TReward -1.0, e: 0.16
episode: 364/999, TReward -0.9975247524752475, e: 0.16
episode: 365/999, TReward -0.996268656716418, e: 0.16
episode: 366/999, TReward -0.9971509971509972, e: 0.16
episode: 367/999, TRew

episode: 506/999, TReward -0.9970760233918129, e: 0.079
episode: 507/999, TReward -1.0, e: 0.079
episode: 508/999, TReward -0.9976580796252927, e: 0.078
episode: 509/999, TReward -0.9968454258675079, e: 0.078
episode: 510/999, TReward -1.0, e: 0.078
episode: 511/999, TReward -0.9973821989528796, e: 0.077
episode: 512/999, TReward -0.9974874371859297, e: 0.077
episode: 513/999, TReward -0.997716894977169, e: 0.076
episode: 514/999, TReward -0.9973890339425587, e: 0.076
episode: 515/999, TReward -0.9978858350951374, e: 0.076
episode: 516/999, TReward -0.9963898916967509, e: 0.075
--------------------Reached  episode: 517/999 TReward -0.9931972789115646, e: 0.075
episode: 518/999, TReward -0.9964912280701754, e: 0.075
episode: 519/999, TReward -0.9979253112033195, e: 0.074
episode: 520/999, TReward -0.996969696969697, e: 0.074
episode: 521/999, TReward -0.997624703087886, e: 0.073
episode: 522/999, TReward -1.0, e: 0.073
episode: 523/999, TReward -0.997275204359673, e: 0.073
episode: 524/

episode: 668/999, TReward -0.9979674796747967, e: 0.035
episode: 669/999, TReward -0.9967637540453075, e: 0.035
episode: 670/999, TReward -0.997907949790795, e: 0.035
episode: 671/999, TReward -0.9955156950672646, e: 0.035
episode: 672/999, TReward -0.9977272727272727, e: 0.034
episode: 673/999, TReward -0.9976303317535545, e: 0.034
episode: 674/999, TReward -0.9969135802469136, e: 0.034
episode: 675/999, TReward -1.0, e: 0.034
episode: 676/999, TReward -0.9975609756097561, e: 0.034
episode: 677/999, TReward -0.9974554707379135, e: 0.034
episode: 678/999, TReward -0.9975369458128078, e: 0.033
episode: 679/999, TReward -0.9961240310077519, e: 0.033
episode: 680/999, TReward -0.9971181556195965, e: 0.033
episode: 681/999, TReward -0.9971181556195965, e: 0.033
episode: 682/999, TReward -0.9968944099378882, e: 0.033
episode: 683/999, TReward -0.9973821989528796, e: 0.033
episode: 684/999, TReward -0.9964028776978417, e: 0.032
episode: 685/999, TReward -0.997979797979798, e: 0.032
episode: 

episode: 822/999, TReward -1.0, e: 0.016
episode: 823/999, TReward -0.9971098265895953, e: 0.016
episode: 824/999, TReward -0.9974683544303797, e: 0.016
episode: 825/999, TReward -0.9970238095238095, e: 0.016
episode: 826/999, TReward -1.0, e: 0.016
episode: 827/999, TReward -0.9974619289340102, e: 0.016
episode: 828/999, TReward -0.997275204359673, e: 0.016
episode: 829/999, TReward -1.0, e: 0.016
episode: 830/999, TReward -1.0, e: 0.016
episode: 831/999, TReward -1.0, e: 0.016
episode: 832/999, TReward -0.996415770609319, e: 0.015
episode: 833/999, TReward -1.0, e: 0.015
episode: 834/999, TReward -1.0, e: 0.015
episode: 835/999, TReward -0.9975247524752475, e: 0.015
episode: 836/999, TReward -0.9969230769230769, e: 0.015
episode: 837/999, TReward -1.0, e: 0.015
episode: 838/999, TReward -0.9967845659163987, e: 0.015
episode: 839/999, TReward -0.9976958525345622, e: 0.015
episode: 840/999, TReward -0.9967532467532467, e: 0.015
episode: 841/999, TReward -0.9970059880239521, e: 0.015
ep

--------------------Reached  episode: 984/999 TReward -0.9928571428571429, e: 0.01
episode: 985/999, TReward -0.9969040247678018, e: 0.01
episode: 986/999, TReward -0.9959677419354839, e: 0.01
episode: 987/999, TReward -0.9958333333333333, e: 0.01
episode: 988/999, TReward -0.9965635738831615, e: 0.01
episode: 989/999, TReward -0.9967105263157895, e: 0.01
episode: 990/999, TReward -0.9974093264248705, e: 0.01
episode: 991/999, TReward -0.9958847736625515, e: 0.01
episode: 992/999, TReward -0.9962264150943396, e: 0.01
episode: 993/999, TReward -0.9970588235294118, e: 0.01
episode: 994/999, TReward -0.9959016393442623, e: 0.01
--------------------Reached  episode: 995/999 TReward -0.9945652173913043, e: 0.01
episode: 996/999, TReward -0.9961389961389961, e: 0.01
episode: 997/999, TReward -0.995575221238938, e: 0.01
episode: 998/999, TReward -0.9966329966329966, e: 0.01
episode: 999/999, TReward -0.9959677419354839, e: 0.01
