In [9]:
import numpy as np
from tqdm import tqdm_notebook as tqdm

In [10]:
class shortGridWorld():
    def __init__(self, start_state):
        self.state = start_state
        self.reward = -1
    
    def take_action(self, state, action):
        if state == 0:
            if all(action == np.array([[0.], [1.]])):
                self.state = 0
            else:
                self.state = 1
        elif state == 1:
            if all(action == np.array([[0.], [1.]])):
                self.state = 2
            else:
                self.state = 1
        elif state == 2:
            if all(action ==np.array([[0.],[1.]])):
                self.state = 1
            else:
                self.state = 3
            
    def is_terminal(self):
        return self.state == 3        

In [112]:
class REINFORCE_mc():
    def __init__(self, Problem, num_episodes=100, y=0.9):
        self.Problem = Problem
        self.num_episodes = num_episodes
        self.y = y
        self.theta = np.zeros((2,1))
        self.x = np.array([[1.,0.],[0.,1.]])
        self.r = []
        self.s = []
        self.a = []
        self.G = []
        
        self.I = []
        self.L = []
        
    def get_policy_distribution(self):
        return (np.exp(self.theta.T @ self.x)/np.sum(np.exp(self.theta.T @ self.x))).T
    
    def sample_from_distribution(self, P):
        CP = np.cumsum(P)
        rn = np.random.uniform()

        for i in range(len(CP)):
            if CP[i] >= rn:
                return i
    
    def get_action(self):
        action = np.zeros((2,1))

        # maybe a separate functon to get the feature vector of action
        action[self.sample_from_distribution(self.get_policy_distribution())] = 1 
        return action
    
    def clear(self):
        self.r = []
        self.s = []
        self.a = []
        
    def gen_episode(self,s0=0, max_len=1000):
        self.clear()
        problem = self.Problem(s0)
        
        state = s0
        self.s.append(state)
        
        G = 0
        I = 1
        
        while (not problem.is_terminal()) and max_len != 0:
            action = self.get_action()
            self.a.append(np.argmax(action))
            
            problem.take_action(state, action)
            state = problem.state
            self.s.append(state)
            
            self.r.append(problem.reward)
            G += problem.reward * I
            I *= self.y
            
            max_len -= 1
        
        self.G.append(np.sum(np.array(self.r)))
        self.I.append(I)
        self.L.append((len(self.a), self.G[-1]))
        
        return G
    
    def mc_without_baseline(self, max_ep_len=1000):
        for i in tqdm(range(self.num_episodes)):
            alpha = 2**(-12)
            G = self.gen_episode(0, max_ep_len)
            sub = 0
            I = 1
            
            for t in range(len(self.s)-1):
                G -= sub
                delta_ln = np.copy(self.x[:,self.a[t], np.newaxis])
                pi = self.get_policy_distribution()
                delta_ln -= self.x @ pi
    
                # y^t is covered in G in this implementation. G is not strictly G, rather y^t*G
                self.theta = self.theta + alpha*G*delta_ln 
                
                sub = self.r[t]*I
                I *= self.y
        
#         print(self.get_policy_distribution())
#         print(min(self.I))
#         print(self.L)
                
            

In [116]:
rew = np.zeros((10,1))
for i in tqdm(range(10)):
    a = REINFORCE_mc(shortGridWorld, num_episodes=10, y=1)    
    a.mc_without_baseline(max_ep_len=1000)
    rew += np.reshape(np.array(a.G), (10,1))
    print(rew, a.G)
#     print(np.array(a.G).mean())
    a.G = []
    a.L = []
print(rew/100)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[ -9.]
 [ -9.]
 [ -5.]
 [ -4.]
 [ -4.]
 [ -7.]
 [ -9.]
 [-11.]
 [ -3.]
 [ -5.]] [-9, -9, -5, -4, -4, -7, -9, -11, -3, -5]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-13.]
 [-20.]
 [-12.]
 [ -7.]
 [-10.]
 [-17.]
 [-12.]
 [-16.]
 [ -8.]
 [-22.]] [-4, -11, -7, -3, -6, -10, -3, -5, -5, -17]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-23.]
 [-24.]
 [-20.]
 [-11.]
 [-13.]
 [-27.]
 [-27.]
 [-35.]
 [-14.]
 [-29.]] [-10, -4, -8, -4, -3, -10, -15, -19, -6, -7]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-28.]
 [-41.]
 [-30.]
 [-20.]
 [-24.]
 [-32.]
 [-39.]
 [-43.]
 [-20.]
 [-32.]] [-5, -17, -10, -9, -11, -5, -12, -8, -6, -3]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-41.]
 [-50.]
 [-41.]
 [-24.]
 [-28.]
 [-38.]
 [-45.]
 [-52.]
 [-25.]
 [-35.]] [-13, -9, -11, -4, -4, -6, -6, -9, -5, -3]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-48.]
 [-62.]
 [-55.]
 [-27.]
 [-42.]
 [-42.]
 [-52.]
 [-55.]
 [-33.]
 [-40.]] [-7, -12, -14, -3, -14, -4, -7, -3, -8, -5]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-70.]
 [-68.]
 [-59.]
 [-32.]
 [-54.]
 [-49.]
 [-60.]
 [-81.]
 [-40.]
 [-51.]] [-22, -6, -4, -5, -12, -7, -8, -26, -7, -11]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-77.]
 [-79.]
 [-71.]
 [-40.]
 [-61.]
 [-66.]
 [-76.]
 [-86.]
 [-43.]
 [-64.]] [-7, -11, -12, -8, -7, -17, -16, -5, -3, -13]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[-81.]
 [-89.]
 [-78.]
 [-47.]
 [-68.]
 [-69.]
 [-87.]
 [-94.]
 [-47.]
 [-74.]] [-4, -10, -7, -7, -7, -3, -11, -8, -4, -10]


HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

[[ -86.]
 [ -96.]
 [ -83.]
 [ -55.]
 [ -73.]
 [ -75.]
 [ -90.]
 [-103.]
 [ -70.]
 [ -77.]] [-5, -7, -5, -8, -5, -6, -3, -9, -23, -3]
[[-0.86]
 [-0.96]
 [-0.83]
 [-0.55]
 [-0.73]
 [-0.75]
 [-0.9 ]
 [-1.03]
 [-0.7 ]
 [-0.77]]


In [114]:
a.get_policy_distribution()

array([[0.60108247],
       [0.39891753]])

In [32]:
b = np.cumsum(np.array(a.G))
for i in range(99, len(b)):
    print((b[i]-b[i-99]+a.G[i-99])/100)

-8.38
-8.33
-8.33
-8.26
-8.25
-8.24
-8.24
-8.11
-8.1
-8.05
-8.04
-8.04
-8.12
-8.08
-8.03
-7.95
-7.82
-7.82
-7.84
-7.78
-7.67
-7.71
-7.62
-7.61
-7.64
-7.58
-7.64
-7.69
-7.64
-7.73
-7.74
-7.77
-7.73
-7.71
-7.84
-7.83
-7.81
-7.78
-7.78
-7.74
-7.74
-7.75
-7.83
-7.77
-7.83
-7.84
-7.79
-7.69
-7.68
-7.63
-7.58
-7.56
-7.58
-7.53
-7.49
-7.62
-7.6
-7.65
-7.71
-7.7
-7.62
-7.62
-7.63
-7.64
-7.6
-7.61
-7.59
-7.64
-7.6
-7.58
-7.57
-7.64
-7.57
-7.57
-7.62
-7.61
-7.26
-7.24
-7.23
-7.31
-7.18
-7.15
-7.18
-7.15
-7.17
-7.19
-7.23
-7.21
-7.11
-7.04
-7.05
-7.03
-7.03
-6.99
-6.97
-6.85
-6.81
-6.79
-6.7
-6.74
-6.74
-6.79
-6.81
-6.81
-6.85
-6.87
-6.91
-6.9
-6.92
-6.9
-6.9
-6.87
-6.71
-6.77
-6.75
-6.75
-6.73
-6.9
-6.9
-6.86
-6.9
-6.89
-6.9
-6.94
-6.89
-6.95
-6.96
-6.9
-6.95
-6.88
-7.05
-7.03
-6.92
-6.95
-6.82
-6.77
-6.78
-6.83
-6.82
-6.86
-6.87
-6.86
-6.92
-6.99
-7.05
-7.07
-7.12
-7.17
-7.27
-7.28
-7.29
-7.24
-7.34
-7.34
-7.34
-7.28
-7.3
-7.27
-7.23
-7.24
-7.24
-7.29
-7.3
-7.32
-7.34
-7.37
-7.36
-7.34
-7.36
-7

In [17]:
b

array([  -17,   -20,   -29,   -34,   -38,   -42,   -61,   -65,   -71,
         -74,   -83,   -87,  -101,  -118,  -124,  -133,  -140,  -146,
        -152,  -158,  -161,  -166,  -173,  -182,  -187,  -194,  -205,
        -219,  -223,  -229,  -232,  -243,  -249,  -253,  -259,  -262,
        -275,  -283,  -290,  -302,  -305,  -316,  -320,  -325,  -330,
        -335,  -343,  -356,  -367,  -370,  -376,  -387,  -391,  -396,
        -401,  -410,  -417,  -432,  -447,  -461,  -467,  -476,  -491,
        -512,  -519,  -531,  -542,  -552,  -556,  -570,  -580,  -587,
        -593,  -602,  -620,  -623,  -628,  -631,  -636,  -642,  -645,
        -650,  -659,  -666,  -677,  -695,  -706,  -718,  -726,  -735,
        -739,  -754,  -763,  -775,  -792,  -797,  -804,  -808,  -812,
        -815,  -818,  -823,  -829,  -838,  -853,  -861,  -864,  -876,
        -889,  -893,  -901,  -911,  -916,  -926,  -931,  -948,  -953,
        -958,  -981,  -995, -1005, -1012, -1026, -1042, -1046, -1054,
       -1065, -1072,

In [57]:
a = 0.9
for i in range(1000):
    a*=0.9
    print(a)

0.81
0.7290000000000001
0.6561000000000001
0.5904900000000002
0.5314410000000002
0.47829690000000014
0.43046721000000016
0.38742048900000015
0.34867844010000015
0.31381059609000017
0.28242953648100017
0.25418658283290013
0.22876792454961012
0.2058911320946491
0.1853020188851842
0.16677181699666577
0.1500946352969992
0.13508517176729928
0.12157665459056936
0.10941898913151243
0.0984770902183612
0.08862938119652508
0.07976644307687257
0.07178979876918531
0.06461081889226679
0.05814973700304011
0.0523347633027361
0.04710128697246249
0.042391158275216244
0.03815204244769462
0.03433683820292516
0.030903154382632643
0.02781283894436938
0.025031555049932444
0.0225283995449392
0.020275559590445278
0.01824800363140075
0.016423203268260675
0.014780882941434608
0.013302794647291147
0.011972515182562033
0.01077526366430583
0.009697737297875247
0.008727963568087723
0.00785516721127895
0.007069650490151055
0.00636268544113595
0.005726416897022355
0.00515377520732012
0.004638397686588107
0.0041745579