In [29]:
import numpy as np
import pandas as pd
import time

# Event labels; EVENTS.index(event) is used as the event's position in the door
# array and in the padded eta rows.
EVENTS = ["m1", "m2", "m3", "c1", "c2", "c3"]

# Event each agent needs in order to move out of a given room
# (mouse rooms 1-3, cat rooms 3-5).
_MOUSE_EXIT_EVENT = {1: "m2", 2: "m1", 3: "m3"}
_CAT_EXIT_EVENT = {3: "c3", 4: "c1", 5: "c2"}

# Mouse state index per (mouse_room, cat_room) observation. The mouse only
# distinguishes whether the cat is in room 3 or not.
MOUSE_STATES = {
    (mouse, cat): 2 * (mouse - 1) + (0 if cat == 3 else 1)
    for mouse in _MOUSE_EXIT_EVENT
    for cat in _CAT_EXIT_EVENT
}

# Cat state index per (mouse_room, cat_room) observation. The cat distinguishes
# its own room, and whether the mouse is in room 3 or not.
CAT_STATES = {
    (mouse, cat): (cat - 3) + (3 if mouse == 3 else 0)
    for mouse in _MOUSE_EXIT_EVENT
    for cat in _CAT_EXIT_EVENT
}

# Feasible (mouse_event, cat_event) pair for each (mouse_room, cat_room) observation.
FEASIBLE_EVENTS = {
    (mouse, cat): (mouse_event, cat_event)
    for mouse, mouse_event in _MOUSE_EXIT_EVENT.items()
    for cat, cat_event in _CAT_EXIT_EVENT.items()
}

In [30]:
def policy_num_to_binary_list(policy_num, n_bits=3):
    """
    Takes a policy number (0-7 by default) and returns the binary list corresponding
    to it, most significant bit first.
        1 represents enabled and opened
        0 represents disabled and closed
    Ex) (n_bits=3)
        0 = [0, 0, 0]
        1 = [0, 0, 1]
        2 = [0, 1, 0]
        3 = [0, 1, 1]
        4 = [1, 0, 0]
        5 = [1, 0, 1]
        6 = [1, 1, 0]
        7 = [1, 1, 1]

    Args:
        policy_num (int): Column number that represents the policy
            (0 to 2**n_bits - 1). Python and NumPy integers are accepted.
        n_bits (int): Number of digits in the returned list (default 3).
    Returns:
        list of int (0 or 1): Binary list, most significant bit first.
    Raises:
        TypeError: if policy_num is not an integer.
        ValueError: if policy_num is outside 0 .. 2**n_bits - 1 (such values
            cannot be represented in n_bits digits).
    """
    import operator

    # operator.index accepts int and NumPy integer types but rejects floats.
    num = operator.index(policy_num)
    if not 0 <= num < 2 ** n_bits:
        raise ValueError(
            f"policy_num must be in [0, {2 ** n_bits - 1}], got {policy_num!r}"
        )
    # Extract bits from the most significant position down to bit 0.
    return [(num >> shift) & 1 for shift in reversed(range(n_bits))]

def get_net_policy(cat_policy, mouse_policy):
    """
    Intersect the cat's and the mouse's enabled-event lists.

    Args:
        cat_policy (list of str): Events enabled by the cat's policy.
        mouse_policy (list of str): Events enabled by the mouse's policy.
    Returns:
        list of str: Entries of cat_policy that also appear in mouse_policy,
        kept in cat_policy order (duplicates preserved).
    """
    return [event for event in cat_policy if event in mouse_policy]

def mouse_observation_to_state(observation):
    """Return the mouse's state index for a (mouse_room, cat_room) observation, or None if unknown."""
    key = tuple(observation)
    return MOUSE_STATES.get(key)

def cat_observation_to_state(observation):
    """Return the cat's state index for a (mouse_room, cat_room) observation, or None if unknown."""
    key = tuple(observation)
    return CAT_STATES.get(key)

def get_event(q_mouse, q_cat, eta_mouse, eta_cat, state, observation, print_policy):
    """
    Choose the next event from the events enabled by both the mouse's and the cat's policy.

    Each agent's policy number is the argmax of its Q-table row for its own state.
    Policy 1 enables that agent's own feasible event at `observation`, in addition to
    the other agent's events ("c1"-"c3" for the mouse, "m1"-"m3" for the cat). Any
    other policy enables only the other agent's events. Among the events in the net
    policy, the one with the highest eta (the larger of the mouse and cat entries) is
    chosen. Ties go to the later event in the net policy.

    Args:
        q_mouse, q_cat: 2-D arrays indexed by agent state; argmax over a row gives
            the policy number.
        eta_mouse, eta_cat: 2-D arrays indexed by agent state (see padding note below).
        state: (mouse_room, cat_room) used for the state lookups and disabled events.
        observation: (mouse_room, cat_room) used to look up the feasible events.
        print_policy (bool): If true, print the mouse, cat and net policies.
    Returns:
        tuple: (event, disabled). `event` is the chosen event name, or None when no
        event is selected (e.g. the net policy is empty). `disabled` is the list of
        feasible events the net policy leaves disabled.
    """
    mouse_state = mouse_observation_to_state(state)
    # The cat must use its own state encoding (CAT_STATES), not the mouse's.
    cat_state = cat_observation_to_state(state)

    mouse_policy_num = np.argmax(q_mouse[mouse_state])

    if mouse_policy_num == 1:
        mouse_policy = [FEASIBLE_EVENTS.get(tuple(observation))[0]] + ["c1", "c2", "c3"]
    else:
        mouse_policy = ["c1", "c2", "c3"]

    cat_policy_num = np.argmax(q_cat[cat_state])

    if cat_policy_num == 1:
        cat_policy = ["m1", "m2", "m3"] + [FEASIBLE_EVENTS.get(tuple(observation))[1]]
    else:
        cat_policy = ["m1", "m2", "m3"]

    net_policy = get_net_policy(cat_policy, mouse_policy)

    if print_policy:
        print("Mouse Policy:", mouse_policy)
        print("  Cat Policy:", cat_policy)
        print("  Net Policy:", net_policy)

    # Expand each eta row so it can be indexed with EVENTS.index(event). The -1
    # placeholders sit at "c1" (mouse row) and "m1" (cat row). Presumably those
    # tables have no column for that event -- TODO confirm against the CSV files.
    eta_mouse_state = eta_mouse[mouse_state]
    dummy_eta_mouse = np.concatenate((eta_mouse_state[0:3], [-1], eta_mouse_state[3:]))

    eta_cat_state = eta_cat[cat_state]
    dummy_eta_cat = np.concatenate(([-1], eta_cat_state))

    # Start below any real eta, so an enabled event is chosen even if every
    # candidate eta is negative (a 0 start would leave `event` as None).
    max_eta = -np.inf
    event = None
    for curr_event in net_policy:
        curr_event_num = EVENTS.index(curr_event)
        curr_eta = max(dummy_eta_mouse[curr_event_num], dummy_eta_cat[curr_event_num])

        if max_eta <= curr_eta:  # `<=` so later events win ties
            max_eta = curr_eta
            event = curr_event

    disabled = get_disabled_event(net_policy, state)

    return event, disabled

def get_disabled_event(net_policy, observation):
    """
    List the events feasible at `observation` that the net policy does not enable.

    Args:
        net_policy (list of str): Events enabled by both agents.
        observation: (mouse_room, cat_room) key into FEASIBLE_EVENTS.
    Returns:
        list of str: Disabled feasible events, in EVENTS order.
    """
    feasible = FEASIBLE_EVENTS.get(tuple(observation))
    return [
        name
        for name in EVENTS
        if name not in net_policy and name in feasible
    ]

In [31]:
def _load_table(csv_path):
    """Read a CSV, drop its "Unnamed: 0" column (presumably a saved pandas index), return a NumPy array."""
    frame = pd.read_csv(csv_path)
    return frame.drop(columns="Unnamed: 0").to_numpy()


# Q-tables: rows are indexed by agent state; argmax over a row gives the policy number.
q_mouse = _load_table("q_mouse.csv")
q_cat = _load_table("q_cat.csv")

# Eta tables: rows are indexed by agent state.
eta_mouse = _load_table("eta_mouse.csv")
eta_cat = _load_table("eta_cat.csv")


# 1. Render Simulation

In [32]:
def update_door(doors, event, disabled):
    """
    Open the door of the event that just occurred and close any open disabled doors, in place.

    Args:
        doors: Mutable array with one entry per EVENTS index (1 = open, 0 = closed).
        event (str): Event that just occurred; its entry is set to 1.
        disabled (list of str): Events whose entries are set to 0 if currently 1.
    """
    opened = EVENTS.index(event)
    doors[opened] = 1

    # map() is lazy, so each name is resolved just before its door is checked.
    for closed in map(EVENTS.index, disabled):
        if doors[closed] == 1:
            doors[closed] = 0
    
def cat_move(cat_position, event):
    """
    Return the cat's room after `event` occurs.

    The cat moves around the cycle 3 -> 4 -> 5 -> 3. It leaves room 3 on "c3",
    room 4 on "c1" and room 5 on "c2". Any other event leaves it where it is.

    Args:
        cat_position (int): Current room, one of 3, 4, 5.
        event (str): Event that just occurred.
    Returns:
        int: The cat's new room.
    Raises:
        ValueError: if cat_position is not 3, 4 or 5.
    """
    # room -> (event required to leave, room reached)
    transitions = {3: ("c3", 4), 4: ("c1", 5), 5: ("c2", 3)}
    if cat_position not in transitions:
        raise ValueError(f"cat_position must be 3, 4 or 5, got {cat_position!r}")

    required_event, next_position = transitions[cat_position]
    return next_position if event == required_event else cat_position

def mouse_move(mouse_position, event):
    """
    Return the mouse's room after `event` occurs.

    The mouse moves around the cycle 3 -> 2 -> 1 -> 3. It leaves room 3 on "m3",
    room 2 on "m1" and room 1 on "m2". Any other event leaves it where it is.

    Args:
        mouse_position (int): Current room, one of 1, 2, 3.
        event (str): Event that just occurred.
    Returns:
        int: The mouse's new room.
    Raises:
        ValueError: if mouse_position is not 1, 2 or 3.
    """
    # room -> (event required to leave, room reached)
    transitions = {1: ("m2", 3), 2: ("m1", 1), 3: ("m3", 2)}
    if mouse_position not in transitions:
        raise ValueError(f"mouse_position must be 1, 2 or 3, got {mouse_position!r}")

    required_event, next_position = transitions[mouse_position]
    return next_position if event == required_event else mouse_position

def render(mouse_position, cat_position):
    """
    Render the environment.

    Prints the two positions on one line, then an ASCII map of rooms 1-5 with the
    mouse drawn as "M" and the cat as "C".

    Args:
        mouse_position (int): Mouse room (1-3).
        cat_position (int): Cat room (3-5).
    """
    print(mouse_position, cat_position)

    # grid[0..4] are rooms 1-5. Index 5 is the lower half of room 3, so the cat
    # (lower half) and the mouse (upper half, index 2) can both appear in room 3.
    grid = [" "] * 6
    grid[5 if cat_position == 3 else cat_position - 1] = "C"
    grid[mouse_position - 1] = "M"

    board = [
        "_________________________",
        "|   1   |       |   4   |",
        f"|   {grid[0]}   |   3   |   {grid[3]}   |",
        f"---------   {grid[2]}   ---------",
        f"|   2   |   {grid[5]}   |   5   |",
        f"|   {grid[1]}   |       |   {grid[4]}   |",
    ]
    print("\n".join(board))

In [33]:
from IPython.display import clear_output

N_STEPS = 20         # maximum number of simulated events
FRAME_DELAY_S = 0.7  # pause between rendered frames, in seconds

cat_position = 4
mouse_position = 2

terminated = False

render(mouse_position, cat_position)

# One door per EVENTS index: 1 = open, 0 = closed. All doors start open.
doors = np.ones(shape=(len(EVENTS),))

for count in range(N_STEPS):
    state = (mouse_position, cat_position)
    event, disabled = get_event(q_mouse, q_cat, eta_mouse, eta_cat, state, state, True)

    if event is None:
        # No event was selected (e.g. the net policy is empty). update_door would
        # otherwise fail on EVENTS.index(None).
        print("No event selected; stopping simulation.")
        break

    update_door(doors, event, disabled)

    cat_position = cat_move(cat_position, event)
    mouse_position = mouse_move(mouse_position, event)

    time.sleep(FRAME_DELAY_S)
    clear_output(wait=True)
    render(mouse_position, cat_position)

    print("Door Condition:", doors)
    print("Event:", event)
    print("Disabled:", disabled)

    # Stop once the cat and the mouse are both in room 3.
    terminated = cat_position == 3 and mouse_position == 3
    if terminated:
        break

3 4
_________________________
|   1   |       |   4   |
|       |   3   |   C   |
---------   M   ---------
|   2   |       |   5   |
|       |       |       |
Door Condition: [1. 1. 1. 0. 1. 1.]
Event: m2
Disabled: []
