In [1]:
import gym
import random
import numpy as np
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression
from statistics import mean, median
from collections import Counter

curses is not supported on this machine (please install/reinstall curses for an optimal experience)


In [2]:
lr = 1e-4
env = gym.make('CartPole-v0')
env.reset()

[2018-01-22 17:28:14,371] Making new env: CartPole-v0


array([ 0.04125665,  0.02371672,  0.03806007,  0.03194964])

In [4]:
goal_steps = 500
score_requirement = 50
initial_games = 25000

In [None]:
def some_random_games_first():
    for episode in range(5):
        env.reset()
        for t in range(goal_steps):
            env.render()
            # takes a random action from the action space
            action = env.action_space.sample()
            observation, reward, done, info = env.step(action)
            if done:
                break

In [None]:
some_random_games_first()

In [5]:
# generate training samples
def initial_population():
    # stores the observations that a random move made
    training_data = []
    scores = []
    # appends the data only if the corresponding score happens to be above 50
    accepted_scores = []
    for i in range(initial_games):
        score = 0
        # store game memory as we will not know until the end of the game whether or not we beat the required score
        game_memory = []
        prev_observation = []
        # iterate through the plausible steps
        for j in range(goal_steps):
            action = random.randrange(0, 2) # generates a 0 or a 1
            observation, reward, done, info = env.step(action)
            
            if len(prev_observation) > 0:
                # basically looks at the previous frame and stores the current action and if the score is high enough,
                # it will be appended to the training data as we looked at the previous frame and did something on
                # this frame that increased our score, and we would like our neural network to learn this mapping
                game_memory.append([prev_observation, action])
                
            prev_observation = observation
            score += reward
            if done:
                break
        
        if score >= score_requirement:
            accepted_scores.append(score)
            # converting output into a one-hot vector for training
            for data in game_memory:
                if data[1] == 1:
                    output = [0, 1]
                elif data[1] == 0:
                    output = [1, 0]
                    
                # appends the observation and the given output(in one-hot vector format)
                training_data.append([data[0], output])
                
        env.reset()
        scores.append(score)
    
    training_data_save = np.array(training_data)
    np.save('cartpole-v0_training_data.npy', training_data_save)
    
    print('Average accepted score: ', mean(accepted_scores))
    print('Median accepted score: ', median(accepted_scores))
    print(Counter(accepted_scores))
    
    return training_data

In [11]:
initial_population()

[2018-01-22 12:41:14,917] You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.


Average accepted score:  60.79389312977099
Median accepted score:  57.0
Counter({51.0: 37, 54.0: 33, 50.0: 33, 52.0: 25, 55.0: 21, 56.0: 19, 60.0: 18, 57.0: 17, 53.0: 17, 62.0: 17, 59.0: 12, 67.0: 11, 58.0: 11, 65.0: 11, 61.0: 11, 70.0: 9, 63.0: 9, 71.0: 8, 64.0: 7, 69.0: 6, 77.0: 6, 78.0: 6, 74.0: 5, 68.0: 5, 76.0: 4, 66.0: 3, 93.0: 3, 75.0: 3, 90.0: 3, 73.0: 2, 89.0: 2, 83.0: 2, 85.0: 2, 81.0: 2, 84.0: 1, 118.0: 1, 109.0: 1, 80.0: 1, 79.0: 1, 91.0: 1, 94.0: 1, 110.0: 1, 92.0: 1, 88.0: 1, 87.0: 1, 86.0: 1, 98.0: 1})


[[array([-0.00597893, -0.20881766, -0.03947159,  0.2866333 ]), [0, 1]],
 [array([-0.01015528, -0.01315568, -0.03373893, -0.01823271]), [0, 1]],
 [array([-0.0104184 ,  0.18243348, -0.03410358, -0.32136684]), [1, 0]],
 [array([-0.00676973, -0.01218664, -0.04053092, -0.03963092]), [0, 1]],
 [array([-0.00701346,  0.18349238, -0.04132354, -0.34482122]), [0, 1]],
 [array([-0.00334361,  0.37917707, -0.04821996, -0.65024333]), [0, 1]],
 [array([ 0.00423993,  0.57493633, -0.06122483, -0.95771209]), [1, 0]],
 [array([ 0.01573865,  0.3806887 , -0.08037907, -0.68487543]), [1, 0]],
 [array([ 0.02335243,  0.18676928, -0.09407658, -0.41854145]), [0, 1]],
 [array([ 0.02708781,  0.38308966, -0.10244741, -0.73933759]), [1, 0]],
 [array([ 0.03474961,  0.18952029, -0.11723416, -0.48057278]), [1, 0]],
 [array([ 0.03854001, -0.00376857, -0.12684561, -0.22701755]), [1, 0]],
 [array([ 0.03846464, -0.19687106, -0.13138596,  0.02311639]), [1, 0]],
 [array([ 0.03452722, -0.38988803, -0.13092364,  0.27163005]), [

In [6]:
def neural_network_model(input_size):
    network = input_data(shape=[None, input_size, 1], name='input')
    network = fully_connected(network, 128, activation='relu')
    network = dropout(network, 0.8) # keep_rate=0.8
    network = fully_connected(network, 256, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 512, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 256, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 128, activation='relu')
    network = dropout(network, 0.8)
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=lr, loss='categorical_crossentropy', name='targets')
    model = tflearn.DNN(network, tensorboard_dir='log')
    return model

In [7]:
def train_model(training_data, model=False):
    X = np.array([i[0] for i in training_data]).reshape(-1, len(training_data[0][0]), 1) # grabs the observation column of the training data
    y = [i[1] for i in training_data]
    
    if not model:
        model = neural_network_model(input_size=len(X[0]))
    
    model.fit({'input':X}, {'targets':y}, n_epoch=5, snapshot_step=500, show_metric=True, run_id='cartpole-v0-v0.00')
    return model

In [8]:
training_data = initial_population()
model = train_model(training_data)

Training Step: 4309  | total loss: [1m[32m0.66298[0m[0m | time: 24.079s
| Adam | epoch: 005 | loss: 0.66298 - acc: 0.6074 -- iter: 55104/55126
Training Step: 4310  | total loss: [1m[32m0.65800[0m[0m | time: 24.137s
| Adam | epoch: 005 | loss: 0.65800 - acc: 0.6092 -- iter: 55126/55126
--


In [9]:
model.save('cartpole-v0-v0.10')

INFO:tensorflow:C:\Users\Aman Deep Singh\Documents\Python\Data Science\Machine Learning\Reinforcement Learning\cartpole-v0-v0.10 is not in all_model_checkpoint_paths. Manually adding it.


[2018-01-22 17:32:20,862] C:\Users\Aman Deep Singh\Documents\Python\Data Science\Machine Learning\Reinforcement Learning\cartpole-v0-v0.10 is not in all_model_checkpoint_paths. Manually adding it.


In [10]:
scores = []
choices = []
for each_game in range(10):
    score = 0
    game_memory = []
    prev_observation = []
    env.reset()
    for _ in range(goal_steps):
        env.render()
        if len(prev_observation) == 0:
            action = random.randrange(0, 2)
        else:
            print(model.predict(prev_observation.reshape(-1, len(prev_observation), 1)))
            action = np.argmax(model.predict(prev_observation.reshape(-1, len(prev_observation), 1))[0])
            
        choices.append(action)
        new_observation, reward, done, info = env.step(action)
        prev_observation = new_observation
        game_memory.append([new_observation, action])
        score += reward
        if done:
            break
    scores.append(score)
    
print('Average Score ', sum(scores)/len(scores))
print(f'Choice 1: {choices.count(1)/len(choices)}, Choice 0: {choices.count(0)/len(choices)}')

[[ 0.65564549  0.34435448]]
[[ 0.54755849  0.45244148]]
[[ 0.44263902  0.55736107]]
[[ 0.54467857  0.45532146]]
[[ 0.43925086  0.56074911]]
[[ 0.54060084  0.45939916]]
[[ 0.43510371  0.56489629]]
[[ 0.53540921  0.46459079]]
[[ 0.42989689  0.57010311]]
[[ 0.52909887  0.4709011 ]]
[[ 0.4234333  0.5765667]]
[[ 0.52155226  0.47844777]]
[[ 0.41570035  0.58429968]]
[[ 0.51260698  0.48739305]]
[[ 0.40658981  0.59341019]]
[[ 0.50230163  0.49769837]]
[[ 0.39598823  0.60401171]]
[[ 0.49077275  0.50922722]]
[[ 0.60059172  0.39940828]]
[[ 0.49285248  0.50714755]]
[[ 0.60435534  0.39564469]]
[[ 0.49611458  0.50388539]]
[[ 0.60931748  0.39068252]]
[[ 0.50080889  0.49919111]]
[[ 0.3952615   0.60473847]]
[[ 0.49079701  0.5092029 ]]
[[ 0.60286993  0.3971301 ]]
[[ 0.49455246  0.50544751]]
[[ 0.60826832  0.39173171]]
[[ 0.49969104  0.50030899]]
[[ 0.61492181  0.38507822]]
[[ 0.50615335  0.49384671]]
[[ 0.40116271  0.59883732]]
[[ 0.49774012  0.50225991]]
[[ 0.61218554  0.3878144 ]]
[[ 0.50377077  0.49622

[[ 0.49211276  0.50788724]]
[[ 0.60746723  0.39253277]]
[[ 0.49608687  0.50391316]]
[[ 0.61261553  0.38738453]]
[[ 0.50127238  0.49872765]]
[[ 0.39764103  0.60235894]]
[[ 0.49160579  0.50839424]]
[[ 0.60684556  0.39315444]]
[[ 0.495543    0.50445706]]
[[ 0.61195767  0.38804236]]
[[ 0.50065351  0.49934646]]
[[ 0.39705944  0.60294056]]
[[ 0.49096662  0.50903344]]
[[ 0.60603577  0.39396426]]
[[ 0.49481678  0.50518316]]
[[ 0.61105955  0.38894048]]
[[ 0.4997972   0.50020272]]
[[ 0.61713552  0.38286445]]
[[ 0.50597763  0.49402237]]
[[ 0.40262383  0.59737617]]
[[ 0.49706185  0.50293809]]
[[ 0.61375469  0.38624534]]
[[ 0.50236905  0.49763086]]
[[ 0.39870128  0.60129875]]
[[ 0.49273592  0.50726402]]
[[ 0.60825545  0.39174452]]
[[ 0.49683508  0.50316489]]
[[ 0.61351925  0.38648078]]
[[ 0.50217193  0.49782807]]
[[ 0.39855015  0.60144991]]
[[ 0.49257368  0.50742626]]
[[ 0.60807765  0.39192235]]
[[ 0.49669659  0.50330341]]
[[ 0.61336553  0.38663444]]
[[ 0.50204307  0.4979569 ]]
[[ 0.39845106  0.601

[[ 0.49969131  0.50030863]]
[[ 0.5008015   0.49919853]]
[[ 0.41359296  0.58640707]]
[[ 0.50949156  0.4905085 ]]
[[ 0.404246    0.59575397]]
[[ 0.49892855  0.50107139]]
[[ 0.61454797  0.38545206]]
[[ 0.50268841  0.49731159]]
[[ 0.39787614  0.60212386]]
[[ 0.49184012  0.50815982]]
[[ 0.60639977  0.39360029]]
[[ 0.4948245   0.50517553]]
[[ 0.61051756  0.38948241]]
[[ 0.49905971  0.50094038]]
[[ 0.61585033  0.3841497 ]]
[[ 0.50465578  0.49534431]]
[[ 0.40129569  0.59870434]]
[[ 0.49537098  0.50462896]]
[[ 0.61138237  0.38861763]]
[[ 0.50022292  0.49977705]]
[[ 0.39689946  0.60310054]]
[[ 0.49049976  0.50950021]]
[[ 0.60504824  0.39495176]]
[[ 0.49429193  0.5057081 ]]
[[ 0.61009902  0.38990095]]
[[ 0.4991616   0.50083846]]
[[ 0.61621815  0.38378182]]
[[ 0.50532991  0.49467003]]
[[ 0.40305445  0.59694558]]
[[ 0.49688768  0.50311232]]
[[ 0.61321384  0.38678622]]
[[ 0.50224996  0.49775001]]
[[ 0.39982393  0.60017604]]
[[ 0.49321121  0.50678885]]
[[ 0.6083883   0.39161161]]
[[ 0.49760777  0.502

[[ 0.60452962  0.39547035]]
[[ 0.49767488  0.50232518]]
[[ 0.60856724  0.39143273]]
[[ 0.50197798  0.49802202]]
[[ 0.40227589  0.59772408]]
[[ 0.49255711  0.50744295]]
[[ 0.6018272  0.3981728]]
[[ 0.49559656  0.50440347]]
[[ 0.605268    0.39473197]]
[[ 0.49934685  0.50065309]]
[[ 0.60954505  0.39045489]]
[[ 0.50384218  0.49615791]]
[[ 0.4048624   0.59513754]]
[[ 0.49475056  0.5052495 ]]
[[ 0.60347849  0.39652151]]
[[ 0.49800953  0.50199044]]
[[ 0.60713214  0.39286789]]
[[ 0.50195831  0.49804169]]
[[ 0.40295035  0.59704971]]
[[ 0.49231249  0.50768751]]
[[ 0.59959638  0.40040353]]
[[ 0.49488199  0.50511807]]
[[ 0.60240537  0.3975946 ]]
[[ 0.49805641  0.50194359]]
[[ 0.60598183  0.39401808]]
[[ 0.50189126  0.49810874]]
[[ 0.40345791  0.59654206]]
[[ 0.49224174  0.50775826]]
[[ 0.59821296  0.40178704]]
[[ 0.49460998  0.50538993]]
[[ 0.60074204  0.39925802]]
[[ 0.49752364  0.50247633]]
[[ 0.60400695  0.39599299]]
[[ 0.50105417  0.49894586]]
[[ 0.40300995  0.59699005]]
[[ 0.49113289  0.50886

[[ 0.56931514  0.43068486]]
[[ 0.50297076  0.49702922]]
[[ 0.4230046   0.57699543]]
[[ 0.49880287  0.50119716]]
[[ 0.56798393  0.4320161 ]]
[[ 0.50500596  0.49499407]]
[[ 0.42965868  0.57034129]]
[[ 0.50165558  0.49834439]]
[[ 0.42919359  0.57080644]]
[[ 0.49814367  0.50185627]]
[[ 0.56204844  0.43795156]]
[[ 0.50421596  0.49578395]]
[[ 0.43634161  0.56365836]]
[[ 0.50173324  0.49826673]]
[[ 0.43653265  0.56346726]]
[[ 0.49916768  0.50083232]]
[[ 0.55799574  0.44200426]]
[[ 0.50516623  0.49483374]]
[[ 0.4449982  0.5550018]]
[[ 0.50384444  0.49615559]]
[[ 0.44646811  0.55353189]]
[[ 0.50257027  0.49742979]]
[[ 0.44765052  0.55234945]]
[[ 0.50131536  0.49868464]]
[[ 0.44863677  0.55136323]]
[[ 0.50005317  0.4999468 ]]
[[ 0.44963795  0.55036199]]
[[ 0.49897504  0.50102496]]
[[ 0.54603446  0.4539656 ]]
[[ 0.50493938  0.49506059]]
[[ 0.45807654  0.54192346]]
[[ 0.50517154  0.49482846]]
[[ 0.46029079  0.53970915]]
[[ 0.5053035   0.49469653]]
[[ 0.46244094  0.53755915]]
[[ 0.50555378  0.49444

[[ 0.49311239  0.50688756]]
[[ 0.6088087   0.39119133]]
[[ 0.49779388  0.50220621]]
[[ 0.61475605  0.38524398]]
[[ 0.5039835   0.49601647]]
[[ 0.40096846  0.59903151]]
[[ 0.49540472  0.50459528]]
[[ 0.61174935  0.38825071]]
[[ 0.50105864  0.49894136]]
[[ 0.39825606  0.60174394]]
[[ 0.49219236  0.50780767]]
[[ 0.60749519  0.39250484]]
[[ 0.49702734  0.50297272]]
[[ 0.61373162  0.38626832]]
[[ 0.50321043  0.49678957]]
[[ 0.40121958  0.59878051]]
[[ 0.4949576   0.50504243]]
[[ 0.61086166  0.38913837]]
[[ 0.50044638  0.49955359]]
[[ 0.39845499  0.60154492]]
[[ 0.49172685  0.50827318]]
[[ 0.60637784  0.39362216]]
[[ 0.49638563  0.50361431]]
[[ 0.61200154  0.38799855]]
[[ 0.50221217  0.49778774]]
[[ 0.40085775  0.59914219]]
[[ 0.49391124  0.50608873]]
[[ 0.60846275  0.39153731]]
[[ 0.49895301  0.50104696]]
[[ 0.61439329  0.38560668]]
[[ 0.50516313  0.49483681]]
[[ 0.40440986  0.59559017]]
[[ 0.49729908  0.50270087]]
[[ 0.61191624  0.38808376]]
[[ 0.50279498  0.49720502]]
[[ 0.40186962  0.598

[[ 0.49007922  0.50992078]]
[[ 0.57422495  0.42577496]]
[[ 0.48979324  0.51020676]]
[[ 0.57287586  0.42712414]]
[[ 0.48945734  0.51054275]]
[[ 0.57145941  0.42854062]]
[[ 0.48903391  0.51096606]]
[[ 0.56993753  0.43006247]]
[[ 0.48848835  0.51151156]]
[[ 0.56827366  0.43172634]]
[[ 0.48777315  0.51222688]]
[[ 0.56643564  0.43356439]]
[[ 0.48683664  0.51316333]]
[[ 0.56442183  0.43557814]]
[[ 0.48567882  0.51432115]]
[[ 0.5621919   0.43780807]]
[[ 0.48426208  0.51573789]]
[[ 0.55969095  0.44030905]]
[[ 0.4825139  0.5174861]]
[[ 0.55697268  0.44302735]]
[[ 0.4804208   0.51957923]]
[[ 0.55390567  0.44609436]]
[[ 0.47789291  0.52210701]]
[[ 0.55042058  0.44957939]]
[[ 0.47494996  0.5250501 ]]
[[ 0.54650491  0.45349509]]
[[ 0.47144118  0.52855885]]
[[ 0.5419665   0.45803353]]
[[ 0.46744329  0.53255671]]
[[ 0.53672367  0.46327627]]
[[ 0.46280512  0.53719485]]
[[ 0.53075856  0.46924144]]
[[ 0.4572798  0.5427202]]
[[ 0.52386898  0.47613096]]
[[ 0.4507758  0.5492242]]
[[ 0.5161435   0.48385647]