In [1]:
import gym
import random
import numpy as np
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from statistics import mean, median
from collections import Counter

curses is not supported on this machine (please install/reinstall curses for an optimal experience)


In [2]:
lr = 1e-4
env = gym.make('CartPole-v0')
env.reset()

[2018-01-22 17:28:14,371] Making new env: CartPole-v0


array([ 0.04125665,  0.02371672,  0.03806007,  0.03194964])

In [3]:
goal_steps = 500
score_requirement = 50
initial_games = 10000

In [None]:
def some_random_games_first():
    for episode in range(5):
        env.reset()
        for t in range(goal_steps):
            env.render()
            # takes a random action from the action space
            action = env.action_space.sample()
            observation, reward, done, info = env.step(action)
            if done:
                break

In [None]:
some_random_games_first()

In [10]:
# generate training samples
def initial_population():
    # stores the observations that a random move made
    training_data = []
    scores = []
    # appends the data only if the corresponding score happens to be above 50
    accepted_scores = []
    for i in range(initial_games):
        score = 0
        # store game memory as we will not know until the end of the game whether or not we beat the required score
        game_memory = []
        prev_observation = []
        # iterate through the plausible steps
        for j in range(goal_steps):
            action = random.randrange(0, 2) # generates a 0 or a 1
            observation, reward, done, info = env.step(action)
            
            if len(prev_observation) > 0:
                # basically looks at the previous frame and stores the current action and if the score is high enough,
                # it will be appended to the training data as we looked at the previous frame and did something on
                # this frame that increased our score, and we would like our neural network to learn this mapping
                game_memory.append([prev_observation, action])
                
            prev_observation = observation
            score += reward
            if done:
                break
        
        if score >= score_requirement:
            accepted_scores.append(score)
            # converting output into a one-hot vector for training
            for data in game_memory:
                if data[1] == 1:
                    output = [0, 1]
                elif data[1] == 0:
                    output = [1, 0]
                    
                # appends the observation and the given output(in one-hot vector format)
                training_data.append([data[0], output])
                
        env.reset()
        scores.append(score)
    
    training_data_save = np.array(training_data)
    np.save('cartpole-v0_training_data.npy', training_data_save)
    
    print('Average accepted score: ', mean(accepted_scores))
    print('Median accepted score: ', median(accepted_scores))
    print(Counter(accepted_scores))
    
    return training_data

In [11]:
initial_population()

[2018-01-22 12:41:14,917] You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.


Average accepted score:  60.79389312977099
Median accepted score:  57.0
Counter({51.0: 37, 54.0: 33, 50.0: 33, 52.0: 25, 55.0: 21, 56.0: 19, 60.0: 18, 57.0: 17, 53.0: 17, 62.0: 17, 59.0: 12, 67.0: 11, 58.0: 11, 65.0: 11, 61.0: 11, 70.0: 9, 63.0: 9, 71.0: 8, 64.0: 7, 69.0: 6, 77.0: 6, 78.0: 6, 74.0: 5, 68.0: 5, 76.0: 4, 66.0: 3, 93.0: 3, 75.0: 3, 90.0: 3, 73.0: 2, 89.0: 2, 83.0: 2, 85.0: 2, 81.0: 2, 84.0: 1, 118.0: 1, 109.0: 1, 80.0: 1, 79.0: 1, 91.0: 1, 94.0: 1, 110.0: 1, 92.0: 1, 88.0: 1, 87.0: 1, 86.0: 1, 98.0: 1})


[[array([-0.00597893, -0.20881766, -0.03947159,  0.2866333 ]), [0, 1]],
 [array([-0.01015528, -0.01315568, -0.03373893, -0.01823271]), [0, 1]],
 [array([-0.0104184 ,  0.18243348, -0.03410358, -0.32136684]), [1, 0]],
 [array([-0.00676973, -0.01218664, -0.04053092, -0.03963092]), [0, 1]],
 [array([-0.00701346,  0.18349238, -0.04132354, -0.34482122]), [0, 1]],
 [array([-0.00334361,  0.37917707, -0.04821996, -0.65024333]), [0, 1]],
 [array([ 0.00423993,  0.57493633, -0.06122483, -0.95771209]), [1, 0]],
 [array([ 0.01573865,  0.3806887 , -0.08037907, -0.68487543]), [1, 0]],
 [array([ 0.02335243,  0.18676928, -0.09407658, -0.41854145]), [0, 1]],
 [array([ 0.02708781,  0.38308966, -0.10244741, -0.73933759]), [1, 0]],
 [array([ 0.03474961,  0.18952029, -0.11723416, -0.48057278]), [1, 0]],
 [array([ 0.03854001, -0.00376857, -0.12684561, -0.22701755]), [1, 0]],
 [array([ 0.03846464, -0.19687106, -0.13138596,  0.02311639]), [1, 0]],
 [array([ 0.03452722, -0.38988803, -0.13092364,  0.27163005]), [

In [12]:
def neural_network_model(input_size):
    network = input_data(shape=[None, input_size, 1], name='input')
    network = fully_connected(network, 128, activation='relu')
    network = dropout(network, 0.8) # keep_rate=0.8
    network = fully_connected(network, 256, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 512, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 256, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 128, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=lr, loss='categorical_crossentropy', name='targets')
    model = tflearn.DNN(network, tensorboard_dir='log')
    return model

In [15]:
def train_model(training_data, model=False):
    X = np.array([i[0] for i in training_data]).reshape(-1, len(training_data[0][0]), 1) # grabs the observation column of the training data
    y = [i[1] for i in training_data]
    
    if not model:
        model = neural_network_model(input_size=len(X[0]))
    
    model.fit({'input':X}, {'targets':y}, n_epoch=5, snapshot_step=500, show_metric=True, run_id='cartpole-v0-v0.00')
    return model

In [16]:
training_data = initial_population()
model = train_model(training_data)

Training Step: 1754  | total loss: [1m[32m0.66691[0m[0m | time: 8.414s
| Adam | epoch: 005 | loss: 0.66691 - acc: 0.5935 -- iter: 22400/22434
Training Step: 1755  | total loss: [1m[32m0.66660[0m[0m | time: 8.446s
| Adam | epoch: 005 | loss: 0.66660 - acc: 0.5966 -- iter: 22434/22434
--


In [17]:
model.save('cartpole-v0-v0.00')

INFO:tensorflow:C:\Users\Aman Deep Singh\Documents\Python\Data Science\Machine Learning\Reinforcement Learning\cartpole-v0-v0.00 is not in all_model_checkpoint_paths. Manually adding it.


[2018-01-22 13:14:20,888] C:\Users\Aman Deep Singh\Documents\Python\Data Science\Machine Learning\Reinforcement Learning\cartpole-v0-v0.00 is not in all_model_checkpoint_paths. Manually adding it.


In [18]:
scores = []
choices = []
for each_game in range(10):
    score = 0
    game_memory = []
    prev_observation = []
    env.reset()
    for _ in range(goal_steps):
        env.render()
        if len(prev_observation) == 0:
            action = random.randrange(0, 2)
        else:
            print(model.predict(prev_observation.reshape(-1, len(prev_observation), 1)))
            action = np.argmax(model.predict(prev_observation.reshape(-1, len(prev_observation), 1))[0])
            
        choices.append(action)
        new_observation, reward, done, info = env.step(action)
        prev_observation = new_observation
        game_memory.append([new_observation, action])
        score += reward
        if done:
            break
    scores.append(score)
    
print('Average Score ', sum(scores)/len(scores))
print(f'Choice 1: {choices.count(1)/len(choices)}, Choice 0: {choices.count(0)/len(choices)}')

[[ 0.42026985  0.57973015]]
[[ 0.45732287  0.5426771 ]]
[[ 0.52607733  0.47392267]]
[[ 0.45816597  0.541834  ]]
[[ 0.5336014  0.4663986]]
[[ 0.45944723  0.54055285]]
[[ 0.54118496  0.45881504]]
[[ 0.46124899  0.53875101]]
[[ 0.54909354  0.45090646]]
[[ 0.46368751  0.53631252]]
[[ 0.55687678  0.44312322]]
[[ 0.46708906  0.53291088]]
[[ 0.56493688  0.43506309]]
[[ 0.47152847  0.52847153]]
[[ 0.57421583  0.42578408]]
[[ 0.47728881  0.52271116]]
[[ 0.58534509  0.41465491]]
[[ 0.48469114  0.51530886]]
[[ 0.59687316  0.4031269 ]]
[[ 0.49518701  0.50481302]]
[[ 0.61005825  0.38994175]]
[[ 0.51057297  0.48942703]]
[[ 0.44847319  0.5515269 ]]
[[ 0.50720739  0.49279261]]
[[ 0.44782892  0.55217111]]
[[ 0.50370145  0.49629858]]
[[ 0.44712827  0.55287176]]
[[ 0.50037062  0.49962941]]
[[ 0.44633174  0.55366826]]
[[ 0.49742299  0.50257707]]
[[ 0.60183585  0.39816415]]
[[ 0.5152958   0.48470417]]
[[ 0.45162171  0.54837829]]
[[ 0.51359391  0.48640612]]
[[ 0.45172241  0.54827756]]
[[ 0.51222038  0.48777

[[ 0.54179937  0.4582006 ]]
[[ 0.47005364  0.52994633]]
[[ 0.54692459  0.45307544]]
[[ 0.47457522  0.52542484]]
[[ 0.55245578  0.44754422]]
[[ 0.48107028  0.51892972]]
[[ 0.55765706  0.44234294]]
[[ 0.48970655  0.51029348]]
[[ 0.56330436  0.43669561]]
[[ 0.49908054  0.5009194 ]]
[[ 0.56936735  0.43063259]]
[[ 0.51271361  0.48728642]]
[[ 0.46164232  0.53835773]]
[[ 0.51319003  0.48680994]]
[[ 0.46342498  0.53657502]]
[[ 0.51396561  0.48603442]]
[[ 0.46568781  0.53431225]]
[[ 0.51530516  0.48469487]]
[[ 0.46848166  0.53151828]]
[[ 0.51721799  0.48278204]]
[[ 0.4720048   0.52799523]]
[[ 0.51971877  0.48028117]]
[[ 0.47633392  0.52366602]]
[[ 0.5223577   0.47764236]]
[[ 0.4818759  0.5181241]]
[[ 0.52481288  0.47518709]]
[[ 0.48849249  0.51150751]]
[[ 0.52692699  0.47307301]]
[[ 0.49437439  0.50562567]]
[[ 0.52894533  0.4710547 ]]
[[ 0.50200188  0.49799809]]
[[ 0.4698711   0.53012896]]
[[ 0.50241077  0.49758923]]
[[ 0.47295722  0.52704281]]
[[ 0.50205851  0.49794155]]
[[ 0.47640637  0.52359

[[ 0.49506998  0.50493008]]
[[ 0.6323573   0.36764273]]
[[ 0.50308645  0.49691355]]
[[ 0.44016173  0.55983829]]
[[ 0.48918438  0.51081556]]
[[ 0.62656528  0.37343475]]
[[ 0.4953261  0.5046739]]
[[ 0.63297218  0.36702788]]
[[ 0.50276375  0.49723619]]
[[ 0.44076285  0.55923712]]
[[ 0.48841116  0.51158887]]
[[ 0.62598085  0.37401918]]
[[ 0.49374935  0.50625062]]
[[ 0.63174117  0.36825886]]
[[ 0.50035083  0.49964914]]
[[ 0.44063869  0.55936128]]
[[ 0.48525158  0.51474845]]
[[ 0.62272435  0.37727556]]
[[ 0.49008414  0.50991583]]
[[ 0.62702513  0.3729749 ]]
[[ 0.49537462  0.50462538]]
[[ 0.63165724  0.36834276]]
[[ 0.50155044  0.4984495 ]]
[[ 0.44187763  0.55812246]]
[[ 0.4856863  0.5143137]]
[[ 0.62091392  0.37908611]]
[[ 0.48955619  0.51044387]]
[[ 0.62373102  0.37626898]]
[[ 0.49333325  0.50666672]]
[[ 0.62675142  0.37324855]]
[[ 0.49732807  0.50267196]]
[[ 0.62981927  0.37018073]]
[[ 0.50215918  0.49784082]]
[[ 0.44366157  0.55633843]]
[[ 0.48575008  0.51424992]]
[[ 0.61728382  0.3827161

[[ 0.4873648   0.51263523]]
[[ 0.49922785  0.50077218]]
[[ 0.51623863  0.48376137]]
[[ 0.50325239  0.49674767]]
[[ 0.49228653  0.50771344]]
[[ 0.50287437  0.49712566]]
[[ 0.49221513  0.50778484]]
[[ 0.50361234  0.49638766]]
[[ 0.49277154  0.50722843]]
[[ 0.50471056  0.4952895 ]]
[[ 0.49418491  0.50581509]]
[[ 0.50668603  0.49331397]]
[[ 0.49697065  0.50302941]]
[[ 0.50800043  0.49199954]]
[[ 0.50306743  0.4969326 ]]
[[ 0.44900328  0.55099666]]
[[ 0.50231349  0.49768656]]
[[ 0.44814116  0.55185878]]
[[ 0.50149024  0.49850976]]
[[ 0.44723335  0.55276662]]
[[ 0.50064862  0.49935141]]
[[ 0.44629395  0.55370605]]
[[ 0.49971011  0.50028986]]
[[ 0.6068095  0.3931905]]
[[ 0.52311683  0.47688317]]
[[ 0.45187628  0.54812372]]
[[ 0.52505952  0.47494051]]
[[ 0.4521164   0.54788363]]
[[ 0.52694291  0.47305706]]
[[ 0.45240107  0.5475989 ]]
[[ 0.52894092  0.47105911]]
[[ 0.45272854  0.54727149]]
[[ 0.53103787  0.4689621 ]]
[[ 0.45310357  0.5468964 ]]
[[ 0.53328544  0.46671453]]
[[ 0.45354453  0.54645

[[ 0.49693438  0.50306559]]
[[ 0.50409079  0.49590924]]
[[ 0.49649721  0.50350285]]
[[ 0.50247413  0.49752578]]
[[ 0.49647078  0.50352925]]
[[ 0.50116378  0.49883625]]
[[ 0.49650875  0.50349122]]
[[ 0.50023139  0.49976864]]
[[ 0.49681684  0.50318313]]
[[ 0.49921304  0.50078702]]
[[ 0.49703622  0.50296372]]
[[ 0.49197415  0.50802583]]
[[ 0.48560879  0.51439118]]
[[ 0.47960472  0.52039522]]
[[ 0.47576565  0.52423435]]
[[ 0.47308496  0.52691507]]
[[ 0.47050518  0.52949482]]
[[ 0.63189447  0.36810553]]
[[ 0.52014017  0.47985986]]
[[ 0.44580758  0.55419242]]
[[ 0.5105328   0.48946717]]
[[ 0.44348207  0.5565179 ]]
[[ 0.5023011  0.4976989]]
[[ 0.44083184  0.55916816]]
[[ 0.49468714  0.5053128 ]]
[[ 0.62172854  0.37827137]]
[[ 0.50607747  0.49392247]]
[[ 0.44223288  0.55776715]]
[[ 0.49798888  0.50201112]]
[[ 0.62542367  0.37457636]]
[[ 0.51138663  0.4886134 ]]
[[ 0.44403023  0.55596977]]
[[ 0.50307214  0.49692789]]
[[ 0.44140613  0.55859381]]
[[ 0.49540654  0.50459349]]
[[ 0.62180132  0.37819