In [1]:
import numpy as np

In [24]:
actions = np.zeros((3, 3, 2))
actions[..., 0], actions[..., 1] = np.meshgrid(np.arange(-1,2), np.arange(-1,2))
print(actions)

[[[-1. -1.]
  [ 0. -1.]
  [ 1. -1.]]

 [[-1.  0.]
  [ 0.  0.]
  [ 1.  0.]]

 [[-1.  1.]
  [ 0.  1.]
  [ 1.  1.]]]


In [25]:
actions = np.empty((3,3), dtype=object)
for dvx in range(-1,2):
    for dvy in range(-1,2):
        actions[dvx+1, dvy+1] = np.array([dvx, dvy])
        
print(actions)

[[array([-1, -1]) array([-1,  0]) array([-1,  1])]
 [array([ 0, -1]) array([0, 0]) array([0, 1])]
 [array([ 1, -1]) array([1, 0]) array([1, 1])]]


In [26]:
def random_start_state():
    x = np.random.randint(0, 14)
    return (np.array([x, 0]), np.zeros((2), dtype=int))

def did_win(state):
    x, v = state
    nx = x + v
    fx = 14
    fy0 = 26
    fy1 = 31
    return x[0]<fx and nx[0]>=fx and x[1]>=fy0 and x[1]<=fy1 and nx[1]>=fy0 and nx[1]<=fy1    

def did_out_of_bounds(state):
    x, v = state
    nx = x + v
    
    first = nx[0]>=0 and nx[0]<=5 and nx[1]>=0 and nx[1]<=31
    second = nx[0]>=0 and nx[0]<=14 and nx[1]>=26 and nx[1]<=31
    
    return not (first or second)
    

def environment(state, action):
    x, v = state
    dv = action
    
    if np.random.rand() < 0.1:
        dv = np.zeros_like(dv)
    
    print(dv)
    end = False
    
    if did_win(state):
        end = True
        newstate = (x+v, v+dv)
    elif did_out_of_bounds(state):
        newstate = random_start_state()
        pass
    else:
        nx = x+v
        nv = (v+dv).clip(0,4)
        newstate = (nx, nv)
        
    reward = -1
    return newstate, reward, end

def action_from_policy(policy, state):
    x, v = state
    action_probs = policy[x[0], x[1], v[0], v[1]]
    action = np.random.choice(actions.flatten(), p=action_probs.flatten()/action_probs.sum())
    return action

def generate_episode(policy):
    state = random_start_state()
    end = False
    
    total_reward = 0
    all_states = [state]
    
    while not end:
        action = action_from_policy(policy, state)
        state, reward, end = environment(state, action)
        
        total_reward += reward
        all_states.append(state)
    return all_states, total_reward

In [27]:
def show_state(state):
    x, v = state
    for yy in range(32, -2, -1):
        for xx in range(-1, 16, 1):
            point_x = np.array([xx,yy])
            point_state = (point_x, np.zeros((2)))
            if did_out_of_bounds(point_state):
                print('X', end='')
            elif (x == point_x).all():
                print('O', end='')
            elif (x+v == point_x).all():
                print('.', end='')
            else:
                print(' ', end='')
            print(' ', end='')
        print()
        
from IPython.display import clear_output
def show_all_states(all_states):
    for state in all_states:
        show_state(state)
#         clear_output(wait=True)

In [28]:
policy = np.ones((15, 32, 5, 5, 3, 3))
policy /= policy.sum()

for i in range(10):
    all_states, total_reward = generate_episode(policy)
    print(total_reward)

[-1 -1]
[ 1 -1]
[-1 -1]
[-1  0]
[0 0]
[-1 -1]
[0 1]
[0 0]
[ 1 -1]
[0 0]
[1 1]
[ 0 -1]
[1 0]
[0 0]
[1 0]
[ 0 -1]
[1 0]
[0 0]
[-1  1]
[0 1]
[0 1]
[ 0 -1]
[0 1]
[0 1]
[1 1]
[ 0 -1]
[1 1]
[0 0]
[0 0]
[-1 -1]
[0 0]
[-1  0]
[0 0]
[-1  1]
[-1  0]
[1 0]
[0 0]
[ 0 -1]
[ 1 -1]
[0 0]
[ 0 -1]
[0 0]
[-1  1]
[-1  0]
[-1  0]
[-1 -1]
[0 0]
[ 1 -1]
[1 1]
[ 0 -1]
[-1  1]
[-1  1]
[-1  0]
[0 0]
[-1  1]
[ 0 -1]
[-1  0]
[0 0]
[0 1]
[-1  1]
[-1  1]
[-1  1]
[-1  1]
[0 0]
[ 1 -1]
[1 1]
[0 0]
[ 1 -1]
[ 1 -1]
[0 0]
[0 0]
[-1  0]
[1 0]
[1 1]
[-1 -1]
[0 0]
[-1  1]
[0 1]
[ 1 -1]
[-1 -1]
[-1 -1]
[-1  1]
[-1 -1]
[-1 -1]
[-1  1]
[ 0 -1]
[-1  0]
[1 0]
[-1  0]
[0 1]
[0 0]
[ 0 -1]
[0 0]
[1 1]
[-1 -1]
[-1  0]
[1 1]
[0 0]
[0 1]
[ 1 -1]
[-1  0]
[ 1 -1]
[-1  0]
[0 0]
[0 0]
[ 1 -1]
[ 1 -1]
[ 1 -1]
[-1  0]
[ 0 -1]
[-1  0]
[0 0]
[-1 -1]
[ 1 -1]
[0 0]
[ 1 -1]
[-1  1]
[0 0]
[0 1]
[-1  0]
[-1  1]
[-1  0]
[1 1]
[ 1 -1]
[ 1 -1]
[1 1]
[-1  1]
[-1  1]
[-1  1]
[-1  0]
[-1  0]
[0 1]
[-1  0]
[0 0]
[0 0]
[1 1]
[1 1]
[0 0]
[0 0]
[0 1]
[0 1

[0 0]
[ 0 -1]
[0 0]
[-1  0]
[1 1]
[1 0]
[ 0 -1]
[-1 -1]
[0 0]
[-1  0]
[0 0]
[1 1]
[-1 -1]
[ 1 -1]
[ 1 -1]
[1 0]
[-1  0]
[0 0]
[ 0 -1]
[ 1 -1]
[ 1 -1]
[1 1]
[-1 -1]
[0 1]
[-1  1]
[0 0]
[0 0]
[0 1]
[-1  0]
[0 0]
[-1 -1]
[1 1]
[-1  0]
[-1  1]
[1 1]
[1 1]
[0 0]
[0 0]
[ 0 -1]
[ 0 -1]
[ 1 -1]
[0 0]
[ 1 -1]
[0 0]
[0 0]
[-1  0]
[-1  1]
[-1  0]
[-1  1]
[0 1]
[0 0]
[1 1]
[ 1 -1]
[ 1 -1]
[ 0 -1]
[1 1]
[0 1]
[-1 -1]
[-1  0]
[1 1]
[ 1 -1]
[1 1]
[1 0]
[-1  1]
[1 1]
[1 0]
[-1  0]
[1 1]
[ 0 -1]
[-1 -1]
[0 0]
[1 0]
[1 1]
[0 0]
[-1  0]
[1 1]
[-1  0]
[0 1]
[0 0]
[-1 -1]
[ 0 -1]
[0 0]
[1 0]
[-1  0]
[0 0]
[-1  0]
[1 0]
[0 0]
[ 1 -1]
[0 0]
[0 0]
[0 0]
[ 1 -1]
[ 0 -1]
[0 0]
[-1 -1]
[0 0]
[1 1]
[0 1]
[0 0]
[0 0]
[-1  0]
[1 1]
[1 1]
[0 1]
[-1 -1]
[ 0 -1]
[1 0]
[ 1 -1]
[1 0]
[1 1]
[ 0 -1]
[-1  0]
[0 0]
[-1  0]
[0 0]
[0 0]
[1 1]
[0 1]
[0 0]
[0 0]
[1 1]
[-1  0]
[-1  0]
[1 1]
[ 1 -1]
[-1  0]
[1 1]
[-1  0]
[0 1]
[1 1]
[0 1]
[0 1]
[1 1]
[1 0]
[0 0]
[0 0]
[ 0 -1]
[-1  0]
[0 0]
[1 0]
[ 1 -1]
[0 1]
[0 0]
[0 0]
[0 0]
[ 

[0 0]
[-1 -1]
[1 1]
[0 1]
[ 1 -1]
[-1 -1]
[0 1]
[ 0 -1]
[0 0]
[-1  0]
[0 1]
[1 1]
[0 0]
[0 0]
[-1 -1]
[0 0]
[1 0]
[1 0]
[-1  0]
[-1  1]
[-1  1]
[-1  1]
[1 0]
[0 0]
[-1  1]
[-1  1]
[1 0]
[-1  1]
[ 1 -1]
[0 0]
[0 0]
[-1  0]
[-1  0]
[-1  0]
[0 0]
[-1  0]
[-1  0]
[-1 -1]
[0 1]
[-1  1]
[1 0]
[0 0]
[1 0]
[0 0]
[ 0 -1]
[0 1]
[-1  1]
[-1  1]
[-1 -1]
[1 0]
[1 0]
[0 0]
[1 1]
[ 1 -1]
[1 0]
[ 1 -1]
[-1  1]
[1 0]
[ 1 -1]
[ 1 -1]
[0 0]
[0 0]
[ 0 -1]
[ 1 -1]
[0 0]
[-1 -1]
[0 1]
[0 0]
[1 0]
[-1  0]
[0 1]
[1 1]
[0 0]
[1 0]
[1 0]
[-1  0]
[0 0]
[-1  1]
[0 0]
[-1 -1]
[1 1]
[1 1]
[-1  1]
[0 0]
[0 0]
[-1  0]
[-1 -1]
[1 1]
[0 1]
[0 1]
[0 1]
[-1  1]
[0 0]
[ 0 -1]
[-1  0]
[1 1]
[ 0 -1]
[ 0 -1]
[ 1 -1]
[1 0]
[ 1 -1]
[0 0]
[ 0 -1]
[ 1 -1]
[-1 -1]
[0 0]
[0 0]
[-1 -1]
[-1  1]
[ 0 -1]
[1 1]
[ 1 -1]
[ 0 -1]
[ 1 -1]
[0 1]
[1 0]
[ 0 -1]
[0 0]
[-1  0]
[1 1]
[1 0]
[-1  1]
[1 1]
[0 0]
[ 0 -1]
[1 0]
[0 0]
[0 1]
[ 1 -1]
[-1  1]
[0 0]
[-1  1]
[1 0]
[0 0]
[ 1 -1]
[0 0]
[1 1]
[1 1]
[1 0]
[0 1]
[0 1]
[1 0]
[-1  0]
[0 0]
[1 0]


[1 0]
[-1 -1]
[-1  1]
[0 0]
[0 1]
[-1  1]
[0 0]
[-1  0]
[-1 -1]
[0 0]
[0 1]
[-1 -1]
[-1  1]
[ 0 -1]
[0 0]
[0 1]
[1 1]
[-1  1]
[-1 -1]
[-1 -1]
[-1 -1]
[ 0 -1]
[-1  0]
[0 0]
[ 1 -1]
[0 1]
[-1 -1]
[-1  0]
[-1  0]
[1 0]
[ 1 -1]
[-1  0]
[-1  0]
[-1  1]
[0 0]
[0 0]
[ 1 -1]
[1 0]
[ 1 -1]
[-1  0]
[0 0]
[-1  1]
[-1  0]
[0 1]
[0 1]
[-1  0]
[0 1]
[0 0]
[1 1]
[0 0]
[0 0]
[1 1]
[0 0]
[0 1]
[1 0]
[1 0]
[ 0 -1]
[ 0 -1]
[ 0 -1]
[0 0]
[0 1]
[0 1]
[0 0]
[1 0]
[-1  1]
[0 0]
[-1  0]
[0 0]
[ 0 -1]
[-1 -1]
[ 0 -1]
[0 0]
[0 0]
[-1 -1]
[0 0]
[0 1]
[-1  0]
[0 0]
[-1 -1]
[1 0]
[0 1]
[1 0]
[-1  1]
[ 1 -1]
[-1  0]
[1 0]
[0 0]
[ 0 -1]
[1 1]
[-1  1]
[0 0]
[0 1]
[-1  0]
[0 1]
[1 0]
[-1  1]
[1 1]
[0 0]
[ 1 -1]
[-1  1]
[1 0]
[-1  1]
[ 0 -1]
[-1  0]
[0 0]
[-1  1]
[1 0]
[-1 -1]
[0 0]
[0 0]
[0 1]
[0 1]
[-1  0]
[0 1]
[ 1 -1]
[ 1 -1]
[1 0]
[ 1 -1]
[1 0]
[0 0]
[ 1 -1]
[1 0]
[-1  1]
[-1  0]
[0 0]
[0 0]
[-1 -1]
[-1 -1]
[ 0 -1]
[1 1]
[1 0]
[ 1 -1]
[1 0]
[0 0]
[0 0]
[-1 -1]
[1 1]
[ 1 -1]
[-1  0]
[-1  1]
[-1  0]
[ 0 -1]
[-1 -1]


[-1  1]
[ 0 -1]
[0 0]
[-1  1]
[0 0]
[1 0]
[1 0]
[ 1 -1]
[1 0]
[1 1]
[1 0]
[-1 -1]
[1 0]
[-1  0]
[-1 -1]
[-1 -1]
[0 0]
[-1 -1]
[ 0 -1]
[0 0]
[1 1]
[ 1 -1]
[1 1]
[-1  0]
[0 0]
[-1  1]
[1 0]
[ 1 -1]
[0 0]
[0 0]
[0 0]
[ 1 -1]
[ 0 -1]
[ 1 -1]
[-1 -1]
[0 0]
[0 0]
[1 0]
[-1  0]
[ 1 -1]
[1 0]
[ 0 -1]
[0 0]
[-1 -1]
[0 0]
[1 1]
[0 0]
[-1 -1]
[-1 -1]
[ 0 -1]
[1 1]
[0 0]
[-1 -1]
[-1  0]
[-1  0]
[-1  0]
[-1  0]
[-1  0]
[0 0]
[ 1 -1]
[0 1]
[1 0]
[-1  1]
[1 1]
[1 1]
[ 1 -1]
[-1  1]
[ 0 -1]
[0 1]
[-1  0]
[ 1 -1]
[0 0]
[-1  0]
[-1  1]
[0 0]
[0 0]
[0 1]
[-1 -1]
[1 1]
[0 0]
[0 0]
[0 0]
[1 1]
[ 0 -1]
[ 0 -1]
[-1  1]
[1 1]
[-1 -1]
[ 0 -1]
[0 0]
[ 1 -1]
[1 1]
[1 1]
[1 1]
[0 0]
[-1  0]
[1 0]
[-1 -1]
[0 1]
[0 0]
[1 0]
[-1 -1]
[ 0 -1]
[-1  1]
[ 0 -1]
[ 0 -1]
[1 0]
[-1  0]
[0 0]
[0 1]
[1 0]
[-1  1]
[ 0 -1]
[ 0 -1]
[0 0]
[ 0 -1]
[1 0]
[1 1]
[ 1 -1]
[-1 -1]
[0 0]
[0 0]
[1 1]
[ 0 -1]
[ 1 -1]
[-1  1]
[0 1]
[-1  1]
[1 1]
[-1  0]
[0 1]
[0 1]
[-1 -1]
[0 0]
[1 0]
[1 0]
[-1  1]
[0 1]
[ 0 -1]
[ 1 -1]
[0 0]
[0 0]
[-1 -1]


[-1 -1]
[-1  0]
[0 1]
[ 1 -1]
[0 0]
[ 1 -1]
[0 0]
[ 0 -1]
[ 1 -1]
[-1  1]
[0 0]
[0 0]
[0 0]
[ 0 -1]
[-1 -1]
[1 0]
[-1  0]
[ 1 -1]
[0 1]
[ 0 -1]
[ 1 -1]
[ 1 -1]
[ 0 -1]
[1 1]
[-1  0]
[-1 -1]
[-1  0]
[-1 -1]
[0 0]
[0 0]
[-1 -1]
[ 0 -1]
[0 0]
[ 0 -1]
[1 0]
[1 0]
[0 0]
[ 0 -1]
[0 1]
[ 1 -1]
[0 0]
[ 0 -1]
[ 0 -1]
[0 1]
[ 1 -1]
[0 0]
[-1 -1]
[-1  0]
[1 0]
[ 0 -1]
[0 0]
[-1 -1]
[1 0]
[-1 -1]
[ 1 -1]
[0 1]
[ 0 -1]
[-1  1]
[0 0]
[-1 -1]
[0 0]
[1 1]
[0 0]
[-1 -1]
[ 0 -1]
[1 0]
[0 0]
[ 0 -1]
[-1  1]
[-1  1]
[ 0 -1]
[-1  1]
[ 1 -1]
[1 1]
[0 0]
[-1  1]
[0 1]
[0 0]
[ 1 -1]
[ 0 -1]
[0 1]
[0 0]
[0 0]
[-1  1]
[-1 -1]
[0 0]
[ 0 -1]
[ 1 -1]
[0 0]
[-1  0]
[1 1]
[1 1]
[0 1]
[ 0 -1]
[ 1 -1]
[-1 -1]
[0 0]
[1 0]
[ 1 -1]
[1 1]
[1 1]
[0 0]
[1 1]
[0 0]
[-1  0]
[1 1]
[1 1]
[-1  1]
[-1  1]
[1 1]
[-1  0]
[ 0 -1]
[1 1]
[0 0]
[-1  1]
[-1 -1]
[ 0 -1]
[1 1]
[-1 -1]
[-1  0]
[-1 -1]
[0 0]
[0 0]
[1 1]
[ 1 -1]
[1 1]
[0 0]
[0 0]
[1 1]
[-1 -1]
[ 1 -1]
[1 1]
[-1  1]
[0 1]
[ 1 -1]
[1 1]
[-1  1]
[ 1 -1]
[-1  1]
[ 1 -1]
[1 1]
[1

[-1 -1]
[ 0 -1]
[ 0 -1]
[-1 -1]
[0 1]
[-1  0]
[1 0]
[0 1]
[ 0 -1]
[1 1]
[1 1]
[-1  1]
[-1 -1]
[1 1]
[-1  1]
[0 0]
[1 0]
[1 1]
[-1  1]
[-1  1]
[0 0]
[0 0]
[ 0 -1]
[-1  0]
[ 1 -1]
[-1  0]
[1 0]
[ 1 -1]
[-1  1]
[1 1]
[-1  0]
[-1  0]
[0 0]
[1 0]
[-1 -1]
[ 0 -1]
[-1 -1]
[0 0]
[ 1 -1]
[1 1]
[0 0]
[-1 -1]
[1 0]
[-1 -1]
[-1  0]
[1 0]
[1 1]
[0 0]
[ 0 -1]
[-1  1]
[0 0]
[0 1]
[ 1 -1]
[-1 -1]
[-1  0]
[-1  0]
[0 0]
[1 0]
[ 0 -1]
[ 0 -1]
[1 1]
[-1  0]
[1 0]
[1 1]
[-1  1]
[-1  1]
[1 1]
[ 1 -1]
[ 0 -1]
[-1  0]
[0 0]
[0 0]
[0 0]
[-1  0]
[ 1 -1]
[ 0 -1]
[0 0]
[0 0]
[1 1]
[0 0]
[0 0]
[0 1]
[-1 -1]
[ 1 -1]
[-1  0]
[1 1]
[1 1]
[0 0]
[0 1]
[-1 -1]
[-1  1]
[1 0]
[ 0 -1]
[-1  1]
[-1  0]
[0 0]
[0 1]
[0 0]
[1 0]
[ 1 -1]
[-1  0]
[1 1]
[0 0]
[-1 -1]
[-1  1]
[ 0 -1]
[0 0]
[0 0]
[0 0]
[-1 -1]
[0 1]
[0 0]
[0 1]
[ 1 -1]
[1 0]
[-1 -1]
[0 0]
[0 1]
[0 0]
[0 1]
[0 0]
[1 0]
[ 0 -1]
[-1 -1]
[1 0]
[-1  1]
[0 1]
[ 1 -1]
[-1  1]
[0 0]
[-1  1]
[-1  0]
[1 0]
[-1  1]
[-1 -1]
[ 1 -1]
[0 1]
[0 0]
[ 1 -1]
[0 0]
[1 0]
[ 0 -1]
[-1  0

KeyboardInterrupt: 