In [1]:
import numpy as np
from collections import defaultdict
from environment import TreasureCube
from pprint import pprint

In [2]:
env = TreasureCube(max_step=500)

In [3]:
epsilon = 0.1
ACTION_SPACE = ['left', 'right', 'forward', 'backward', 'up', 'down']
nA = len(ACTION_SPACE)

In [4]:
episode_rewards_progress = []

In [5]:
def get_epision_greedy_action_policy(Q, observation):
    
    A = np.ones(nA, dtype=float) * epsilon / nA
    best_action = np.argmax(Q[observation])
    A[best_action] += (1.0 - epsilon)
    
    return A

In [6]:
def generate_episode(Q):
    episode = []
    current_state = env.reset()
    # current_state = ''.join(map(str, env.curr_pos))
    terminate = False
    episode_reward = 0

    while not terminate:
        
        prob_scores = get_epision_greedy_action_policy(Q,current_state)
        action = np.random.choice(ACTION_SPACE, p=prob_scores)
        
        reward, terminate, next_state = env.step(action)
        episode.append((current_state, action, reward))

        episode_reward += reward
        
        current_state = next_state    
    
    episode_rewards_progress.append(episode_reward)

    return episode

In [14]:
test = defaultdict(lambda: np.zeros(nA))
episodes = generate_episode(test)

In [7]:
def mc_control_epsilon_greedy(total_episodes):
  
    returns_sum = defaultdict(float)
    states_count = defaultdict(float)
    
    Q = defaultdict(lambda: np.zeros(nA)) # length of action space
    
    for k in range(total_episodes):
        
        episode = generate_episode(Q)
        
        state_actions_in_episode = list(set([(sar[0], sar[1]) for sar in episode]))
        
        for i,sa_pair in enumerate(state_actions_in_episode):
            state, action = sa_pair
            
            action_index = ACTION_SPACE.index(action)

            G = sum([sar[2] for i,sar in enumerate(episode[i:])])
            
            returns_sum[sa_pair] += G
            states_count[sa_pair] += 1.0
            Q[state][action_index] = returns_sum[sa_pair] / states_count[sa_pair]
        
        
    return Q

In [8]:
q_table = mc_control_epsilon_greedy(5000)
q_table

defaultdict(<function __main__.mc_control_epsilon_greedy.<locals>.<lambda>()>,
            {'000': array([-5.33241758, -6.47397959, -5.11947368, -4.85978261, -4.7795571 ,
                    -5.765     ]),
             '001': array([-6.37855362, -5.93258145, -7.02047478, -5.31374622, -5.4762533 ,
                    -6.09040404]),
             '101': array([-6.8025641 , -9.0047619 , -5.83232975, -9.34411765, -7.83469388,
                    -9.80952381]),
             '201': array([ -8.10365854,  -8.47073171,  -9.30106383,  -5.64680711,
                    -10.98382353,  -9.98974359]),
             '202': array([-10.26119403,  -5.12734177,  -9.99076923,  -7.30335731,
                     -9.75689655,  -9.025     ]),
             '301': array([-12.10235294, -10.35714286, -10.72162162,  -9.28714286,
                     -7.04980211, -10.29295775]),
             '300': array([ -8.20697674,  -7.7195122 ,  -8.50363636,  -9.98823529,
                     -6.70390438, -12.01836735]),
        

In [9]:
np.argmax(q_table['000'])

4

In [10]:
optimal_policy = defaultdict(str)

In [11]:
for state in q_table:
    action_space_index = np.argmax(q_table[state])
    action = ACTION_SPACE[action_space_index]

    optimal_policy[state] = action

optimal_policy

defaultdict(str,
            {'000': 'up',
             '001': 'backward',
             '101': 'forward',
             '201': 'backward',
             '202': 'right',
             '301': 'up',
             '300': 'up',
             '200': 'forward',
             '211': 'forward',
             '111': 'right',
             '011': 'backward',
             '002': 'right',
             '102': 'right',
             '103': 'forward',
             '203': 'forward',
             '303': 'backward',
             '302': 'forward',
             '100': 'right',
             '313': 'backward',
             '213': 'right',
             '311': 'up',
             '110': 'down',
             '310': 'up',
             '010': 'backward',
             '021': 'right',
             '121': 'up',
             '221': 'forward',
             '321': 'right',
             '220': 'up',
             '210': 'left',
             '120': 'forward',
             '020': 'right',
             '022': 'forward',
             

In [12]:
episode_rewards_progress.index(max(episode_rewards_progress))

2985