In [1]:
from environment import TreasureCube

In [2]:
import numpy as np
from collections import defaultdict
from pprint import pprint
from copy import deepcopy

In [3]:
action_space = ['left', 'right', 'forward', 'backward', 'up', 'down']
DISCOUNT_FACTOR = 0.99
THETA = 0.0001

In [4]:
def generate_all_possible_states(dimension):
    return [f"{i}{j}{k}" for i in range(dimension) for j in range(dimension) for k in range(dimension)]

In [5]:
env = TreasureCube(max_step=500)
state_list = generate_all_possible_states(env.dim)
q_table = defaultdict(lambda: np.zeros(len(action_space)))
v_table = defaultdict(lambda: 0.0)

In [6]:
def populate_q_table(state):
    action_index = 0

    for action in action_space:
        all_possible_actions = get_list_of_possible_actions(action)

        action_value = calculate_q_value_for_action(all_possible_actions, action)
        q_table[state][action_index] = action_value

        action_index += 1

In [7]:
def get_list_of_possible_actions(action):
    copy_of_action_space = deepcopy(action_space)
    if action == 'left':
        copy_of_action_space.remove('right')
        return copy_of_action_space
    if action == 'right':
        copy_of_action_space.remove('left')
        return copy_of_action_space
    if action == 'up':
        copy_of_action_space.remove('down')
        return copy_of_action_space
    if action == 'down':
        copy_of_action_space.remove('up')
        return copy_of_action_space
    if action == 'forward':
        copy_of_action_space.remove('backward')
        return copy_of_action_space

    # if action is backward
    copy_of_action_space.remove('forward')
    return copy_of_action_space

In [8]:
def calculate_q_value_for_action(all_possible_actions, current_state_action):
    action_value = 0

    for action in all_possible_actions:
        # pass in the env in to the function in the future
        current_env_state_in_list = env.curr_pos
        
        reward, _, next_state = env.step(action)

        transition_probability = get_transition_probability(action, current_state_action)
        action_value += transition_probability * (reward + DISCOUNT_FACTOR * v_table[next_state])

        env.curr_pos = current_env_state_in_list
    
    return action_value

In [9]:
def get_transition_probability(action, current_state_action):
    return 0.6 if current_state_action == action else 0.1

In [10]:
env.reset()
while True:
    delta = 0

    for state in state_list:
        populate_q_table(state)

        best_action_value_of_q_from_state = np.max(q_table[state])

        delta = max(delta, best_action_value_of_q_from_state - v_table[state])

        v_table[state] = best_action_value_of_q_from_state
    
    if delta < THETA:
        break

In [11]:
pprint(v_table)

defaultdict(<function <lambda> at 0x000001C85A130048>,
            {'000': 5.083734847365726,
             '001': 5.0507240598497045,
             '002': 5.155585885845626,
             '003': 5.108745919637135,
             '010': 5.030043986894787,
             '011': 5.114933564408955,
             '012': 5.09925752492153,
             '013': 5.070115489519788,
             '020': 5.056550581831263,
             '021': 4.998940755345577,
             '022': 5.098601878119148,
             '023': 5.083724596770177,
             '030': 5.086432674748336,
             '031': 5.024437737081961,
             '032': 5.055587375613738,
             '033': 5.063239605948565,
             '100': 5.100814953457908,
             '101': 5.108702900230931,
             '102': 5.068327465520567,
             '103': 5.027080805668779,
             '110': 5.006411337794145,
             '111': 4.999661685795445,
             '112': 5.166717586950516,
             '113': 5.086956639361642,
         

In [12]:
pprint(q_table)

defaultdict(<function <lambda> at 0x000001C84BDCA4C8>,
            {'000': array([5.04131296, 5.06259361, 5.01525527, 5.08373485, 4.96570781,
       4.96671931]),
             '001': array([4.96239466, 4.95047411, 5.01825467, 4.97321547, 5.02251622,
       5.05072406]),
             '002': array([5.06537663, 5.12866339, 5.05045731, 5.15558589, 5.11497521,
       5.12730572]),
             '003': array([5.09481575, 5.08079819, 5.10874592, 5.06015752, 5.07380792,
       4.94162425]),
             '010': array([4.95596262, 5.02163286, 5.01542366, 5.01764328, 5.02486671,
       5.03004399]),
             '011': array([5.06343661, 5.06526526, 5.08790401, 5.11493356, 5.01792616,
       5.11198531]),
             '012': array([5.07804582, 5.05738495, 4.95577182, 5.09925752, 4.97682869,
       4.98314394]),
             '013': array([5.00627039, 5.07011549, 4.98659953, 4.97950622, 4.9630074 ,
       4.9809492 ]),
             '020': array([4.98124583, 4.96146499, 4.96366485, 5.05655058, 4.9511

In [13]:
policy = defaultdict(str)

In [14]:
env.reset()

for state in state_list:
    q_table_buffer = np.zeros(len(action_space))

    action_index = 0
    for action in action_space:
        all_possible_actions = get_list_of_possible_actions(action)
        
        action_value = calculate_q_value_for_action(all_possible_actions, action)
        q_table_buffer[action_index] = action_value

        action_index += 1
    
    best_action_index = np.argmax(q_table_buffer)
    policy[state] = action_space[best_action_index]

In [15]:
policy

defaultdict(str,
            {'000': 'down',
             '001': 'up',
             '002': 'down',
             '003': 'left',
             '010': 'left',
             '011': 'up',
             '012': 'left',
             '013': 'right',
             '020': 'forward',
             '021': 'left',
             '022': 'backward',
             '023': 'forward',
             '030': 'left',
             '031': 'right',
             '032': 'up',
             '033': 'up',
             '100': 'left',
             '101': 'backward',
             '102': 'left',
             '103': 'down',
             '110': 'left',
             '111': 'down',
             '112': 'up',
             '113': 'down',
             '120': 'left',
             '121': 'forward',
             '122': 'up',
             '123': 'forward',
             '130': 'up',
             '131': 'forward',
             '132': 'left',
             '133': 'up',
             '200': 'left',
             '201': 'left',
             '202': 'u