In [18]:
%matplotlib notebook
import numpy as np
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt


In [4]:
class shortGridWorld():
    """Four-state corridor gridworld; state 3 is terminal and every step costs -1.

    Actions are (2, 1) column vectors. The action ``[[0.], [1.]]`` keeps
    state 0 in place, advances state 1 to 2 and sends state 2 back to 1.
    Any other action moves 0 -> 1, keeps state 1 in place and moves 2 -> 3.
    """
    def __init__(self, start_state):
        self.state = start_state
        self.reward = -1

    def take_action(self, state, action):
        """Set ``self.state`` to the successor of ``state`` under ``action``.

        The transition is keyed on the ``state`` argument, not on
        ``self.state``. States without an entry (e.g. the terminal state 3)
        leave ``self.state`` untouched.
        """
        # state -> (successor for action [[0.], [1.]], successor for any other action)
        successors = {0: (0, 1), 1: (2, 1), 2: (1, 3)}
        if state not in successors:
            return
        on_second, on_other = successors[state]
        picked_second = all(action == np.array([[0.], [1.]]))
        self.state = on_second if picked_second else on_other

    def is_terminal(self):
        """Return True once the agent has reached state 3."""
        return self.state == 3

In [38]:
class REINFORCE_mc():
    """REINFORCE (Monte-Carlo policy gradient) with a softmax policy, optionally with a baseline.

    Action preferences are ``theta.T @ x``, where column ``i`` of ``x`` is the
    feature vector of action ``i``. The features do not depend on the state, so
    the agent learns a single state-independent action distribution.

    Args:
        Problem: environment factory; ``Problem(s0)`` must return an object with
            ``state``, ``reward``, ``take_action(state, action)`` and ``is_terminal()``.
        num_episodes: number of episodes run by ``mc`` / ``mc_with_baseline``.
        alpha_t: step size for the policy parameters ``theta``.
        alpha_w: step size for the scalar baseline ``w`` (``mc_with_baseline`` only).
        y: discount factor (gamma).
        max_ep_len: maximum number of steps before an episode is truncated.
    """
    def __init__(self, Problem, num_episodes=100, alpha_t=2**(-13), alpha_w = 0.001, y=0.9, max_ep_len=1000):
        self.Problem = Problem
        self.num_episodes = num_episodes
        self.y = y
        # softmax([-1.47, 1.47]) puts ~0.95 of the probability on action 1.
        self.theta = np.array([[-1.47], [1.47]])
        # Scalar (state-independent) baseline weight; only mc_with_baseline updates it.
        self.w = 0
        # One-hot feature vector per action: column i is x(a=i).
        self.x = np.array([[1.,0.],[0.,1.]])
        # Trajectory buffers of the most recent episode (reset by clear()).
        self.r = []
        self.s = []
        self.a = []
        # Per-episode discounted returns, accumulated across all episodes.
        self.G = []
        self.alpha_w = alpha_w
        self.alpha_t = alpha_t
        self.max_ep_len = max_ep_len

        # Per-episode final discount y**T and (episode length, return) pairs.
        self.I = []
        self.L = []

    def get_policy_distribution(self):
        """Return the softmax action probabilities as a (n_actions, 1) column."""
        h = self.theta.T @ self.x  # (1, n_actions) action preferences
        # Shifting by the max leaves the softmax unchanged but prevents exp overflow.
        e = np.exp(h - np.max(h))
        return (e / np.sum(e)).T

    def sample_from_distribution(self, P):
        """Draw an index i with probability P[i] by inverse-CDF sampling.

        Raises:
            ValueError: if ``P`` is empty.
        """
        CP = np.cumsum(P)
        if len(CP) == 0:
            raise ValueError("cannot sample from an empty distribution")
        rn = np.random.uniform()

        for i in range(len(CP)):
            if CP[i] >= rn:
                return i
        # Rounding can leave CP[-1] slightly below 1. Fall back to the last index
        # instead of returning None, which would set every entry of the action vector.
        return len(CP) - 1

    def get_action(self):
        """Sample an action from the current policy as a (2, 1) one-hot column vector."""
        action = np.zeros((2,1))

        # TODO: maybe a separate function to build the feature vector of an action
        action[self.sample_from_distribution(self.get_policy_distribution())] = 1
        return action

    def clear(self):
        """Empty the per-episode trajectory buffers."""
        self.r = []
        self.s = []
        self.a = []

    def gen_episode(self,s0=0, max_len=1000):
        """Roll out one episode from ``s0`` with the current policy.

        Fills ``self.s`` (states, including ``s0``), ``self.a`` (action indices)
        and ``self.r`` (rewards). Appends the discounted return to ``self.G``,
        the final discount ``y**T`` to ``self.I`` and ``(length, return)`` to
        ``self.L``. The rollout stops at a terminal state or after ``max_len`` steps.

        Returns:
            The discounted return G_0 of the episode.
        """
        self.clear()
        problem = self.Problem(s0)

        state = s0
        self.s.append(state)

        G = 0
        I = 1

        while (not problem.is_terminal()) and max_len != 0:
            action = self.get_action()
            self.a.append(np.argmax(action))

            problem.take_action(state, action)
            state = problem.state
            self.s.append(state)

            self.r.append(problem.reward)
            G += problem.reward * I
            I *= self.y

            max_len -= 1

        self.G.append(G)
        self.I.append(I)
        self.L.append((len(self.a), self.G[-1]))

        return G

    def _discounted_returns(self):
        """Return [G_0, ..., G_{T-1}] for the stored episode via G_t = R_{t+1} + y * G_{t+1}."""
        returns = [0.0] * len(self.r)
        g = 0.0
        for t in reversed(range(len(self.r))):
            g = self.r[t] + self.y * g
            returns[t] = g
        return returns

    def _grad_ln_pi(self, action_index):
        """Gradient of ln pi(a | theta) for the softmax policy: x(a) - sum_b pi(b) x(b)."""
        return self.x[:, action_index, np.newaxis] - self.x @ self.get_policy_distribution()

    def mc(self):
        """Run ``num_episodes`` episodes of REINFORCE without a baseline, updating ``theta``."""
        alpha = self.alpha_t
        for i in tqdm(range(self.num_episodes)):
            self.gen_episode(0, self.max_ep_len)
            discount = 1.0  # y**t

            for t, G_t in enumerate(self._discounted_returns()):
                # theta <- theta + alpha * y^t * G_t * grad ln pi(A_t), using the current theta.
                self.theta = self.theta + alpha * discount * G_t * self._grad_ln_pi(self.a[t])
                discount *= self.y

    def mc_with_baseline(self):
        """Run ``num_episodes`` episodes of REINFORCE with a learned scalar baseline ``w``."""
        alpha = self.alpha_t
        alpha_w = self.alpha_w

        for i in tqdm(range(self.num_episodes)):
            self.gen_episode(0, self.max_ep_len)
            discount = 1.0  # y**t

            # Returns are computed exactly per step (no running division by y**t,
            # which previously corrupted G_t from t=2 onward whenever y < 1).
            for t, G_t in enumerate(self._discounted_returns()):
                delta = G_t - self.w

                self.w += alpha_w*delta

                self.theta += alpha * discount * delta * self._grad_ln_pi(self.a[t])
                discount *= self.y
                
            

In [34]:
# Average the per-episode return of plain REINFORCE over independent trials.
SEED = 0
np.random.seed(SEED)  # make the stochastic experiment reproducible

num_trials = 100
num_ep_per_trial = 1000

rew = np.zeros((num_ep_per_trial,1))
for i in tqdm(range(num_trials)):
    # A fresh agent each trial, so nothing carries over between trials.
    a = REINFORCE_mc(shortGridWorld, num_episodes=num_ep_per_trial, alpha_t=2**(-12), y=1, max_ep_len=1000)
    a.mc()
    # a.G holds one return per episode, in episode order.
    rew += np.reshape(np.array(a.G), (num_ep_per_trial,1))

print(rew/num_trials)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

[[-61.3 ]
 [-63.58]
 [-61.66]
 [-55.48]
 [-50.5 ]
 [-57.78]
 [-50.14]
 [-47.79]
 [-54.91]
 [-45.06]
 [-41.97]
 [-41.13]
 [-45.32]
 [-40.94]
 [-51.11]
 [-39.93]
 [-38.15]
 [-44.43]
 [-45.62]
 [-38.31]
 [-37.9 ]
 [-40.21]
 [-37.27]
 [-34.33]
 [-38.11]
 [-39.88]
 [-32.15]
 [-38.31]
 [-31.7 ]
 [-33.6 ]
 [-34.79]
 [-35.  ]
 [-33.28]
 [-28.01]
 [-34.19]
 [-32.81]
 [-28.66]
 [-30.27]
 [-32.95]
 [-32.18]
 [-27.62]
 [-30.87]
 [-25.45]
 [-30.2 ]
 [-32.86]
 [-31.  ]
 [-28.16]
 [-29.14]
 [-25.41]
 [-24.77]
 [-26.29]
 [-27.15]
 [-27.58]
 [-28.71]
 [-24.53]
 [-26.92]
 [-23.94]
 [-24.84]
 [-28.6 ]
 [-24.34]
 [-27.27]
 [-21.97]
 [-26.81]
 [-23.88]
 [-24.  ]
 [-24.83]
 [-24.21]
 [-21.78]
 [-22.41]
 [-23.83]
 [-23.58]
 [-22.79]
 [-22.84]
 [-22.58]
 [-22.46]
 [-19.65]
 [-24.42]
 [-21.94]
 [-22.27]
 [-24.41]
 [-21.77]
 [-23.27]
 [-20.42]
 [-20.46]
 [-19.65]
 [-20.4 ]
 [-20.72]
 [-20.82]
 [-21.31]
 [-22.04]
 [-18.85]
 [-18.88]
 [-21.43]
 [-19.56]
 [-18.45]
 [-17.31]
 [-17.38]
 [-18.53]
 [-17.54]
 [-20.8 ]


In [35]:
# Action probabilities (column vector) of the agent trained in the last trial.
final_policy = a.get_policy_distribution()
final_policy

array([[0.55244391],
       [0.44755609]])

In [36]:
# Scalar baseline weight. Only mc_with_baseline() updates it, so after a plain
# mc() run it still holds its initial value 0.
baseline_w = a.w
baseline_w

0

In [42]:
# REINFORCE with a learned scalar baseline, using the same protocol as the plain-REINFORCE cell.
# The constants are defined here so this cell does not depend on that one having run.
num_trials = 100
num_ep_per_trial = 1000
np.random.seed(0)  # reproducible trials

rew2 = np.zeros((num_ep_per_trial,1))
for i in tqdm(range(num_trials)):
    # A fresh agent each trial, so nothing carries over between trials.
    a = REINFORCE_mc(shortGridWorld, num_episodes=num_ep_per_trial, alpha_t=2**(-12), alpha_w = 0.0001, y=1, max_ep_len=1000)
    a.mc_with_baseline()
    # a.G holds one return per episode, in episode order.
    rew2 += np.reshape(np.array(a.G), (num_ep_per_trial,1))

print(rew2/num_trials)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

KeyboardInterrupt: 

In [43]:
# Learning curves: per-episode return averaged over num_trials independent trials.
# NOTE(review): if the baseline cell was interrupted before finishing, rew2 holds a
# partial sum, and dividing by num_trials biases its curve toward 0.
fig, ax = plt.subplots()
ax.plot(rew/num_trials, label='mc')
ax.plot(rew2/num_trials, label='with_baseline')
ax.set(title='REINFORCE on shortGridWorld', xlabel='Episode', ylabel='Average return per episode')
ax.legend()
plt.show()

<IPython.core.display.Javascript object>