In [2]:
import copy
import pylab
import numpy as np
from environment import Env
from keras.layers import Dense
from keras.optimizers import Adam
from keras.models import Sequential
from keras import backend as K
%run environment.py

# Number of episodes the training loop under __main__ runs (range(EPISODES)).
EPISODES = 2500

# REINFORCE (Monte-Carlo policy-gradient) agent for GridWorld.
class ReinforceAgent:
    """REINFORCE policy-gradient agent for GridWorld.

    The policy is a small fully connected network that maps a state vector of
    length ``state_size`` to a softmax distribution over ``action_size``
    discrete actions. Transitions are buffered for one whole episode
    (``append_sample``), and the network is updated once per episode from the
    discounted, normalised returns (``train_model``).
    """

    def __init__(self):
        # When True, start from the pretrained weights file loaded below.
        self.load_model = True
        # actions which the agent can take
        self.action_space = [0, 1, 2, 3, 4]
        # size of the state vector and of the action set
        self.action_size = len(self.action_space)
        self.state_size = 15
        self.discount_factor = 0.99
        self.learning_rate = 0.001

        self.model = self.build_model()
        # NOTE: this rebinds ``optimizer`` on the instance to the compiled
        # training function returned by the ``optimizer`` method, shadowing
        # that method. Kept as-is for backward compatibility.
        self.optimizer = self.optimizer()
        # Per-episode buffers: filled by append_sample, cleared by train_model.
        self.states, self.actions, self.rewards = [], [], []

        if self.load_model:
            self.model.load_weights('./save_model/reinforce_trained.h5')

    def build_model(self):
        """Build the policy network: state in, action probabilities out."""
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='softmax'))
        model.summary()
        return model

    def optimizer(self):
        """Build the policy-gradient loss and return a Keras training function.

        The returned callable takes ``[states, one_hot_actions,
        discounted_rewards]``. It applies one Adam update to the model's
        trainable weights, minimising ``-sum(log(p(a|s)) * G)``.
        """
        # Width follows action_size (not a hard-coded 5) so the placeholder
        # always matches the model's softmax output layer.
        action = K.placeholder(shape=[None, self.action_size])
        discounted_rewards = K.placeholder(shape=[None, ])

        # ``action`` is one-hot, so this selects the probability the policy
        # assigned to the action actually taken; weight its log by the return.
        action_prob = K.sum(action * self.model.output, axis=1)
        cross_entropy = K.log(action_prob) * discounted_rewards
        loss = -K.sum(cross_entropy)

        # create training function
        optimizer = Adam(lr=self.learning_rate)
        updates = optimizer.get_updates(self.model.trainable_weights, [],
                                        loss)
        train = K.function([self.model.input, action, discounted_rewards], [],
                           updates=updates)

        return train

    def get_action(self, state):
        """Sample an action index from the policy for one state.

        ``state`` is passed straight to ``model.predict``. The caller in this
        file reshapes it to ``(1, state_size)`` first.
        """
        policy = np.asarray(self.model.predict(state)[0], dtype=np.float64)
        # Renormalise in float64 so float32 rounding in the softmax output
        # cannot trip np.random.choice's "probabilities must sum to 1" check.
        policy /= policy.sum()
        return np.random.choice(self.action_size, 1, p=policy)[0]

    def discount_rewards(self, rewards):
        """Return the discounted return for every step of an episode.

        Computed backwards: G_t = rewards[t] + discount_factor * G_{t+1}.
        The result is always a float array. The old ``np.zeros_like(rewards)``
        produced an int array for integer rewards and truncated the returns.
        """
        discounted_rewards = np.zeros(len(rewards), dtype=np.float64)
        running_add = 0.0
        for t in reversed(range(len(rewards))):
            running_add = running_add * self.discount_factor + rewards[t]
            discounted_rewards[t] = running_add
        return discounted_rewards

    def append_sample(self, state, action, reward):
        """Buffer one transition.

        Stores ``state[0]``, the action as a one-hot vector of length
        ``action_size``, and the scalar reward.
        """
        self.states.append(state[0])
        self.rewards.append(reward)
        act = np.zeros(self.action_size)
        act[action] = 1
        self.actions.append(act)

    def train_model(self):
        """Run one policy-gradient update on the buffered episode, then clear it."""
        if not self.rewards:
            # Nothing was collected; there is no gradient to apply.
            self.states, self.actions, self.rewards = [], [], []
            return

        discounted_rewards = np.float32(self.discount_rewards(self.rewards))
        # Normalise returns to zero mean / unit variance to reduce variance.
        discounted_rewards -= np.mean(discounted_rewards)
        std = np.std(discounted_rewards)
        # A one-step episode (or identical returns) has zero spread. Dividing
        # by it would yield NaNs that then corrupt the network weights.
        if std > 0:
            discounted_rewards /= std

        self.optimizer([self.states, self.actions, discounted_rewards])
        self.states, self.actions, self.rewards = [], [], []


if __name__ == "__main__":
    # `time` is not imported anywhere visible in this file (it presumably
    # reached the notebook namespace via `%run environment.py`), so import it
    # explicitly. `os` is needed to create the output directories.
    import os
    import time

    # savefig / save_weights below fail if these directories are missing.
    os.makedirs("./save_graph", exist_ok=True)
    os.makedirs("./save_model", exist_ok=True)

    env = Env()
    agent = ReinforceAgent()

    global_step = 0
    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0
        # fresh env
        state = env.reset()
        # Batch of one, sized from the agent rather than a hard-coded 15.
        state = np.reshape(state, [1, agent.state_size])

        while not done:
            # Per-step wall-clock timing, printed every step.
            start = time.time()
            global_step += 1
            # get action for the current state and go one step in environment
            action = agent.get_action(state)
            next_state, reward, done = env.step(action)
            next_state = np.reshape(next_state, [1, agent.state_size])

            agent.append_sample(state, action, reward)
            score += reward
            state = copy.deepcopy(next_state)

            end = time.time()
            print(end - start)

            if done:
                # update policy neural network once per finished episode
                agent.train_model()
                scores.append(score)
                episodes.append(e)
                score = round(score, 2)
                print("episode:", e, "  score:", score, "  time_step:",
                      global_step)

        # Every 100 episodes: redraw the score curve and checkpoint weights.
        if e % 100 == 0:
            pylab.plot(episodes, scores, 'b')
            pylab.savefig("./save_graph/reinforce.png")
            agent.model.save_weights("./save_model/reinforce.h5")

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 24)                384       
_________________________________________________________________
dense_5 (Dense)              (None, 24)                600       
_________________________________________________________________
dense_6 (Dense)              (None, 5)                 125       
Total params: 1,109
Trainable params: 1,109
Non-trainable params: 0
_________________________________________________________________
0.018595218658447266
0.0008349418640136719
0.0007507801055908203
0.00063323974609375
0.000713348388671875
0.0006330013275146484
0.00067901611328125
0.0006549358367919922
0.0006136894226074219
0.0006229877471923828
0.0007040500640869141
episode: 0   score: -0.1   time_step: 11
0.0007979869842529297
0.0006930828094482422
0.0005948543548583984
0.0007479190826416016
0.0006172657012939453
0.000798225402832

0.0006968975067138672
0.0010018348693847656
0.0006320476531982422
0.0005691051483154297
0.0007548332214355469
0.000701904296875
0.0006849765777587891
0.0006809234619140625
0.0008099079132080078
episode: 31   score: 0.1   time_step: 293
0.0006377696990966797
0.0008459091186523438
0.0006680488586425781
0.000762939453125
0.0006988048553466797
0.0007691383361816406
0.0007259845733642578
0.0007441043853759766
0.0006902217864990234
episode: 32   score: 0.1   time_step: 302
0.0006971359252929688
0.0006668567657470703
0.0007371902465820312
0.0006921291351318359
0.0007371902465820312
0.0006628036499023438
0.0007939338684082031
0.0006539821624755859
0.0007040500640869141
episode: 33   score: 0.1   time_step: 311
0.0006661415100097656
0.0007979869842529297
0.0006339550018310547
0.0007300376892089844
0.0007140636444091797
0.0007238388061523438
0.0006861686706542969
0.0007791519165039062
0.0006449222564697266
episode: 34   score: 0.1   time_step: 320
0.0007040500640869141
0.0006911754608154297
0.00

0.0008220672607421875
0.0014798641204833984
0.000762939453125
0.0011680126190185547
0.0008068084716796875
0.0008268356323242188
0.0007288455963134766
0.0009479522705078125
episode: 68   score: 0.1   time_step: 635
0.0006692409515380859
0.0007340908050537109
0.0007202625274658203
0.0006880760192871094
0.0006840229034423828
0.0007898807525634766
0.0006363391876220703
0.0007081031799316406
0.0007910728454589844
episode: 69   score: 0.1   time_step: 644
0.0006511211395263672
0.0006899833679199219
0.0007951259613037109
0.0006811618804931641
0.0007340908050537109
0.0006971359252929688
0.0007269382476806641
0.0006351470947265625
0.0007212162017822266
episode: 70   score: 0.1   time_step: 653
0.0006811618804931641
0.0007781982421875
0.0006709098815917969
0.0006992816925048828
0.0006859302520751953
0.0007250308990478516
0.0006589889526367188
0.0007090568542480469
0.0006749629974365234
episode: 71   score: 0.1   time_step: 662
0.0006909370422363281
0.0006451606750488281
0.0007967948913574219
0.0

0.0008120536804199219
episode: 113   score: 0.1   time_step: 1041
0.0006890296936035156
0.0007102489471435547
0.0006849765777587891
0.000782012939453125
0.0006420612335205078
0.0007197856903076172
0.0006546974182128906
0.0007171630859375
0.0007128715515136719
episode: 114   score: 0.1   time_step: 1050
0.0007178783416748047
0.0007150173187255859
0.0008041858673095703
0.0007429122924804688
0.0007801055908203125
0.0007710456848144531
0.0008890628814697266
0.0007641315460205078
0.0008051395416259766
episode: 115   score: 0.1   time_step: 1059
0.0006740093231201172
0.0006780624389648438
0.0006270408630371094
0.0008127689361572266
0.0006508827209472656
0.0006608963012695312
0.0006802082061767578
0.0007140636444091797
0.0006759166717529297
episode: 116   score: 0.1   time_step: 1068
0.0007188320159912109
0.0005910396575927734
0.0006551742553710938
0.0006582736968994141
0.0007469654083251953
0.0006692409515380859
0.0006930828094482422
0.0006811618804931641
0.0007061958312988281
episode: 117  

0.0006508827209472656
episode: 150   score: 0.1   time_step: 1375
0.0006940364837646484
0.0007150173187255859
0.0006339550018310547
0.0006949901580810547
0.0006620883941650391
0.000698089599609375
0.0006442070007324219
0.0007717609405517578
0.0007550716400146484
episode: 151   score: 0.1   time_step: 1384
0.0008478164672851562
0.0007679462432861328
0.0010020732879638672
0.0006918907165527344
0.0007131099700927734
0.0006721019744873047
0.0007457733154296875
0.0006301403045654297
0.0007030963897705078
episode: 152   score: 0.1   time_step: 1393
0.0006568431854248047
0.0007328987121582031
0.0006508827209472656
0.00080108642578125
0.0006601810455322266
0.0007331371307373047
0.0006611347198486328
0.0007307529449462891
0.0006461143493652344
episode: 153   score: 0.1   time_step: 1402
0.0006821155548095703
0.0006327629089355469
0.0007658004760742188
0.0007550716400146484
0.0007069110870361328
0.0007679462432861328
0.0007579326629638672
0.0006248950958251953
0.0006110668182373047
episode: 154 

0.0006401538848876953
0.0006301403045654297
0.0006601810455322266
0.0007398128509521484
0.0006899833679199219
0.0006971359252929688
0.0006682872772216797
0.0006859302520751953
0.0006473064422607422
episode: 187   score: 0.1   time_step: 1708
0.0007581710815429688
0.0006420612335205078
0.0007081031799316406
0.0006349086761474609
0.0007779598236083984
0.0006580352783203125
0.0007138252258300781
0.0006439685821533203
0.0007240772247314453
episode: 188   score: 0.1   time_step: 1717
0.0006508827209472656
0.000675201416015625
0.0007138252258300781
0.0007081031799316406
0.0006458759307861328
0.0006988048553466797
0.0006399154663085938
0.0007197856903076172
0.0006039142608642578
episode: 189   score: 0.1   time_step: 1726
0.0006859302520751953
0.0007588863372802734
0.000827789306640625
0.0006699562072753906
0.0007390975952148438
0.0006718635559082031
0.0007700920104980469
0.0006630420684814453
0.0006561279296875
episode: 190   score: 0.1   time_step: 1735
0.0009379386901855469
0.0009899139404

0.0006711483001708984
0.0006358623504638672
0.0005838871002197266
0.0007939338684082031
0.0007071495056152344
0.000713348388671875
0.0006978511810302734
0.0007538795471191406
0.0005888938903808594
episode: 221   score: 0.1   time_step: 2014
0.0007469654083251953
0.0006601810455322266
0.0007989406585693359
0.0006759166717529297
0.0007150173187255859
0.0006110668182373047
0.0006918907165527344
0.0006592273712158203
0.0006947517395019531
episode: 222   score: 0.1   time_step: 2023
0.0006799697875976562
0.0007369518280029297
0.0006711483001708984
0.0006778240203857422
0.0006110668182373047
0.0007290840148925781
0.0007328987121582031
0.0007240772247314453
0.0006718635559082031
episode: 223   score: 0.1   time_step: 2032
0.0006930828094482422
0.0006480216979980469
0.0007669925689697266
0.0006902217864990234
0.0008361339569091797
0.0006842613220214844
0.0007088184356689453
0.0005948543548583984
0.0006518363952636719
episode: 224   score: 0.1   time_step: 2041
0.0008490085601806641
0.000853061

0.0008180141448974609
0.0009548664093017578
0.0007348060607910156
0.0008418560028076172
episode: 257   score: 0.1   time_step: 2339
0.0006480216979980469
0.0006849765777587891
0.0006477832794189453
0.0007159709930419922
0.0007200241088867188
0.0007359981536865234
0.0006630420684814453
0.0007290840148925781
0.0006382465362548828
episode: 258   score: 0.1   time_step: 2348
0.0007178783416748047
0.0006699562072753906
0.0007927417755126953
0.0006251335144042969
0.0006351470947265625
0.0006542205810546875
0.0007050037384033203
0.0006132125854492188
0.0007507801055908203
episode: 259   score: 0.1   time_step: 2357
0.0006210803985595703
0.0007109642028808594
0.0006389617919921875
0.0007779598236083984
0.0006659030914306641
0.0007998943328857422
0.0006957054138183594
0.0006921291351318359
0.0005998611450195312
episode: 260   score: 0.1   time_step: 2366
0.000804901123046875
0.0007908344268798828
0.0007648468017578125
0.0007979869842529297
0.0006878376007080078
0.0006389617919921875
0.000617027

0.0006432533264160156
episode: 294   score: 0.1   time_step: 2673
0.0006651878356933594
0.0007481575012207031
0.0006730556488037109
0.0008029937744140625
0.0006802082061767578
0.0008420944213867188
0.0006489753723144531
0.0006802082061767578
0.0006160736083984375
episode: 295   score: 0.1   time_step: 2682
0.0007178783416748047
0.0005981922149658203
0.0007863044738769531
0.0007510185241699219
0.0007779598236083984
0.0006849765777587891
0.000698089599609375
0.0006279945373535156
0.0007710456848144531
episode: 296   score: 0.1   time_step: 2691
0.0009601116180419922
0.0009348392486572266
0.0008020401000976562
0.0007469654083251953
0.0007069110870361328
0.0007789134979248047
0.0006649494171142578
0.0007109642028808594
0.0006673336029052734
episode: 297   score: 0.1   time_step: 2700
0.0007200241088867188
0.0006551742553710938
0.0007510185241699219
0.0007159709930419922
0.0007970333099365234
0.0006783008575439453
0.0007410049438476562
0.0007119178771972656
0.0008947849273681641
episode: 29

TclError: invalid command name ".!canvas"