In [None]:
# Papermill injected parameters (default values here are placeholders).
# Values may arrive as strings; the next cell normalizes start/goal/holes/map_desc.
rows = 5  # grid height used when the map is generated by build_desc
cols = 5  # grid width; also used to flatten (row, col) into a state index
start = (0, 0)  # (row, col) of the 'S' cell
goal = (4, 4)  # (row, col) of the 'G' cell
holes = [(0,3),(1,0),(2,2),(3,0),(3,2)]  # (row, col) of every 'H' cell
slippery = False  # forwarded to gym.make as is_slippery
slip_prob = 0.2  # (not used by base gym env)
gamma = 0.99  # discount factor used by policy evaluation / improvement
theta = 1e-8  # policy evaluation stops once the largest value change drops below this
map_desc = None  # optional list[str] provided by UI
current_state = None  # optional flattened index
current_row = None  # optional row if current_state not given
current_col = None  # optional col if current_state not given


In [None]:
import numpy as np, time, ast
try:
    import gymnasium as gym
except ImportError:
    import gym

# --- Normalize potentially stringified parameters ---
def _to_tuple_2(x):
    if isinstance(x, tuple) and len(x)==2:
        return (int(x[0]), int(x[1]))
    if isinstance(x, str):
        t = ast.literal_eval(x)
        return (int(t[0]), int(t[1]))
    t = tuple(x)
    return (int(t[0]), int(t[1]))

def _to_tuple_list(x):
    if isinstance(x, str):
        v = ast.literal_eval(x)
    else:
        v = x
    return [tuple(map(int,t)) for t in v]

# Coerce the injected values into the shapes the rest of the notebook uses.
start, goal = _to_tuple_2(start), _to_tuple_2(goal)
holes = _to_tuple_list(holes)
# A stringified map description is parsed back into a Python literal.
map_desc = ast.literal_eval(map_desc) if isinstance(map_desc, str) else map_desc

# --- Build map_desc if not supplied ---
def build_desc(rows, cols, start, goal, holes):
    """Build a FrozenLake-style map description.

    Args:
        rows: Number of grid rows.
        cols: Number of grid columns.
        start: ``(row, col)`` of the start cell, written as ``'S'``.
        goal: ``(row, col)`` of the goal cell, written as ``'G'``.
        holes: Iterable of ``(row, col)`` hole cells, written as ``'H'``.

    Returns:
        A list of ``rows`` strings, each ``cols`` characters long, with
        ``'F'`` in every cell that is not a start, goal, or hole.

    Raises:
        IndexError: If any coordinate lies outside the grid. Negative
            coordinates are rejected too, because numpy would otherwise
            wrap them around and silently mark the wrong cell.
    """
    def _check(kind, r, c):
        if not (0 <= r < rows and 0 <= c < cols):
            raise IndexError(f'{kind} cell ({r}, {c}) is outside the {rows}x{cols} grid')

    arr = np.full((rows, cols), 'F', dtype='<U1')
    # Holes are painted first so that a hole listed on the start or goal
    # cell cannot erase it; a map without 'S' or 'G' is unusable.
    for hr, hc in holes:
        _check('hole', hr, hc)
        arr[hr, hc] = 'H'
    sr, sc = start
    gr, gc = goal
    _check('start', sr, sc)
    _check('goal', gr, gc)
    arr[sr, sc] = 'S'
    arr[gr, gc] = 'G'
    return [''.join(r) for r in arr]

if map_desc is None:
    map_desc = build_desc(rows, cols, start, goal, holes)

# --- Create environment ---
env = gym.make(
    'FrozenLake-v1',
    is_slippery=bool(slippery),
    desc=map_desc,
    render_mode='human',
)

# Force starting state
env.reset()
try:
    # A flattened index is row * grid_width + col. Prefer the column count
    # the environment exposes (if any) so the index stays correct when a
    # supplied map_desc is not `cols` wide; otherwise fall back to `cols`.
    grid_width = getattr(env.unwrapped, 'ncol', cols)
    if current_state is not None:
        forced_state = int(current_state)
    elif (current_row is not None) and (current_col is not None):
        forced_state = int(current_row)*grid_width + int(current_col)
    else:
        forced_state = start[0]*grid_width + start[1]
    # Reject indices the environment does not have instead of failing
    # later inside the solver or env.step.
    if not 0 <= forced_state < env.observation_space.n:
        raise ValueError(
            f'state index {forced_state} is outside 0..{env.observation_space.n - 1}'
        )
    env.unwrapped.s = forced_state
except Exception as exc:
    # Best effort: keep the state chosen by reset(), but say why.
    print(f'Warning: could not force the starting state ({exc}); keeping the state from reset()')

print(f'Start position in notebook set to index: {env.unwrapped.s}')
# Transition model reference
P = env.unwrapped.P
nS = env.observation_space.n
nA = env.action_space.n

# --- Policy Iteration Components ---
def policy_evaluation(pi, V=None, gamma=gamma, theta=theta):
    """Iteratively evaluate the deterministic policy *pi*.

    Sweeps every state in index order, updating values in place, until the
    largest change seen in a full sweep is below *theta*.

    Args:
        pi: Sequence giving the action taken in each state.
        V: Optional starting values; copied, never modified. Zeros if omitted.
        gamma: Discount factor.
        theta: Convergence threshold on the per-sweep maximum change.

    Returns:
        A float64 array of ``nS`` state values.
    """
    if V is None:
        values = np.zeros(nS, dtype=np.float64)
    else:
        values = np.array(V, dtype=np.float64, copy=True)

    converged = False
    while not converged:
        largest_change = 0.0
        for state in range(nS):
            previous = values[state]
            backup = 0.0
            for prob, next_state, reward, done in P[state][pi[state]]:
                # Terminal transitions contribute no future value.
                bootstrap = 0.0 if done else values[next_state]
                backup += prob * (reward + gamma * bootstrap)
            values[state] = backup
            largest_change = max(largest_change, abs(previous - backup))
        converged = largest_change < theta
    return values

def policy_improvement(V, gamma=gamma):
    """Return the policy that is greedy with respect to the values *V*.

    For each state, the one-step lookahead value of every action is
    computed from the transition model ``P`` and the first maximizing
    action (``np.argmax``) is chosen.

    Args:
        V: State values indexed by state.
        gamma: Discount factor.

    Returns:
        An integer array of ``nS`` actions.
    """
    greedy = np.zeros(nS, dtype=int)
    for state in range(nS):
        action_values = np.zeros(nA, dtype=np.float64)
        for action in range(nA):
            total = 0.0
            for prob, next_state, reward, done in P[state][action]:
                # Terminal transitions contribute no future value.
                future = 0.0 if done else V[next_state]
                total += prob * (reward + gamma * future)
            action_values[action] = total
        greedy[state] = int(np.argmax(action_values))
    return greedy

def policy_iteration(gamma=gamma, theta=theta):
    """Alternate policy evaluation and greedy improvement until stable.

    Starts from a random policy and zero values. Each round evaluates the
    current policy (warm-started from the previous values) and then takes
    the greedy policy for those values.

    Because evaluation is only accurate to *theta*, improvement can bounce
    between equally good policies without ever producing the same policy
    twice in a row. To guarantee termination, the loop also stops as soon
    as improvement yields a policy that was already visited.

    Args:
        gamma: Discount factor.
        theta: Convergence threshold passed to policy evaluation.

    Returns:
        ``(pi, V, iters)``: the final policy, the values from the last
        evaluation, and the number of evaluate/improve rounds performed.
    """
    pi = np.random.randint(0, nA, size=nS, dtype=int)
    V = np.zeros(nS, dtype=np.float64)
    # Byte snapshots of every policy seen so far, used for cycle detection.
    visited = {pi.tobytes()}
    iters = 0
    while True:
        iters += 1
        V = policy_evaluation(pi, V, gamma=gamma, theta=theta)
        new_pi = policy_improvement(V, gamma=gamma)
        stable = np.array_equal(pi, new_pi)
        snapshot = new_pi.tobytes()
        pi = new_pi
        if stable or snapshot in visited:
            break
        visited.add(snapshot)
    return pi, V, iters

# Solve once up front: final policy, its value estimates, and the number of rounds used.
pi_opt, V_opt, iters = policy_iteration(gamma=gamma, theta=theta)

def run_episode(env, pi):
    """Run one episode in *env* following the deterministic policy *pi*.

    After ``env.reset()`` the state is overridden from the module-level
    ``current_state`` (a flattened index) or, failing that, from
    ``current_row`` / ``current_col``. If the override cannot be applied
    (bad value or out-of-range index) a warning is printed and the episode
    starts from the reset state instead.

    Args:
        env: Environment with the 5-tuple ``step`` API.
        pi: Sequence giving the action to take in each state.

    Returns:
        ``(total, steps, terminated, truncated)``: summed reward, number of
        steps taken, and the two end-of-episode flags from the last step.
    """
    obs, info = env.reset()
    try:
        forced = None
        if current_state is not None:
            forced = int(current_state)
        elif (current_row is not None) and (current_col is not None):
            # Prefer the environment's own column count (if it exposes one)
            # so the index is right even when the map is not `cols` wide.
            width = getattr(env.unwrapped, 'ncol', cols)
            forced = int(current_row)*width + int(current_col)
        if forced is not None:
            # An index the policy has no entry for would crash on pi[obs].
            if not 0 <= forced < len(pi):
                raise ValueError(f'state index {forced} is outside 0..{len(pi) - 1}')
            env.unwrapped.s = forced
            obs = forced
    except Exception as exc:
        # Best effort: fall back to the reset state, but do not hide why.
        print(f'Warning: could not force the episode start state ({exc}); using the reset state')

    terminated = truncated = False
    total = 0.0
    steps = 0
    while not (terminated or truncated):
        a = int(pi[obs])
        obs, r, terminated, truncated, info = env.step(a)
        total += r
        steps += 1
    return total, steps, terminated, truncated

# Run one greedy episode and report; the environment is closed even if
# the episode or the printing fails.
try:
    total_reward, steps, terminated, truncated = run_episode(env, pi_opt)
    print(f'Episode -> reward: {total_reward}, steps: {steps}, terminated: {terminated}, truncated: {truncated}')
    print(f'Converged in {iters} iterations')

    # Lay results out using the shape of the map the environment was built
    # from. Guessing a square from sqrt(nS) prints non-square maps flat and
    # mis-shapes ones whose cell count happens to be a perfect square
    # (e.g. a 2x8 map would be shown as 4x4).
    grid_rows = len(map_desc)
    grid_cols = len(map_desc[0]) if grid_rows else 0
    has_grid = grid_rows > 0 and grid_rows * grid_cols == nS

    # Label text is kept unchanged so anything reading this output still matches.
    print('Optimal V (reshaped if square):')
    if has_grid:
        print(np.round(V_opt.reshape(grid_rows, grid_cols), 3))
    else:
        print(np.round(V_opt, 3))

    arrow_map = {0:'←',1:'↓',2:'→',3:'↑'}
    print('\nOptimal Policy:')
    if has_grid:
        grid = np.array([arrow_map[a] for a in pi_opt]).reshape(grid_rows, grid_cols)
        for grid_row in grid:
            print(' '.join(grid_row))
    else:
        print(pi_opt)

    # Keep window responsive briefly
    try:
        import pygame
        for _ in range(80):
            pygame.event.pump(); time.sleep(0.03)
    except Exception:
        # pygame missing or not usable: just pause instead.
        time.sleep(2.0)
finally:
    env.close()
