## First Visit Monte Carlo Prediction Estimating State-Value Function

In [1]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

from gridworld import GridWorld,plot
import numpy as np

In [6]:
def generate_episode(Pi,size):
    '''
    play an episode on given policy and return G
    '''
    
    env = GridWorld(size)
    done = False
    G = {} #episode reward
    
    while not done:
        #current_state
        current_state = env.agent_position.copy()
        
        #get action and step
        action_prob = Pi[tuple(current_state)] #
        action = np.random.choice(range(len(action_prob)),p = action_prob)
        next_state, reward, done, info = env.step(action)
        
        #append reward to all visited state
        if tuple(current_state) in G:
            G[tuple(current_state)] += reward
        else:
            G[tuple(current_state)] = reward

    return G

def MC_prediction(Pi,size,iter_num):
    '''
    perfrom first-visit MC prediction
    '''
    
    V = dict.fromkeys([(i,j) for i in range(size) for j in range(size)], 0) # values as 0
    Returns = {k: [] for k in [(i,j) for i in range(size) for j in range(size)]} # returns as 0
    
    
    for i in range(iter_num):
        #generate an episode
        G = generate_episode(Pi,size)
        
        for seen_state in G:
            #append G to Returns
            Returns[seen_state]+=[G[seen_state]]
            #average returns and update V
            V[seen_state] = round(sum(Returns[seen_state])/len(Returns[seen_state]),3)
    return V,Returns

In [13]:
size = 5
V = dict.fromkeys([(i,j) for i in range(size) for j in range(size)], 0) # values as 0
Pi = dict.fromkeys([(i,j) for i in range(size) for j in range(size) if i+j != 0 and i+j != (size-1)*2], [0.25]*4)# inital actions as UP
V_final,Returns = MC_prediction(Pi,size,5000)

In [14]:
plot(V_final,Pi,size)

0	|-1.71	|-3.34	|-4.307	|-5.598	|
-1.745	|-2.56	|-3.047	|-3.5	|-4.235	|
-3.464	|-3.006	|-2.913	|-3.006	|-3.473	|
-4.352	|-3.494	|-2.927	|-2.491	|-1.76	|
-5.43	|-4.344	|-3.438	|-1.731	|0	|
-------------------------------
X	|↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|X	|
 
