In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

# Print numpy floats with three decimals and without scientific notation.
np.set_printoptions(suppress=True,precision=3)

# 1) Environment setup: 4x4 slippery lake, rendered as ANSI text.
env=gym.make('FrozenLake-v1',render_mode='ansi',is_slippery=True)
obs,info=env.reset()

print("Initial State:",obs)
print("Action Space:",env.action_space)           # Discrete(4)
print("Observation Space:",env.observation_space) # Discrete(16)
print("Grid shape (rows, cols):",(4,4))

# Transition model: P[s][a] is a list of (prob,next_state,reward,done) tuples.
P=env.unwrapped.P
# Collect every distinct reward once, then read off both extremes.
reward_values={r for s in P for a in P[s] for (_,_,r,_) in P[s][a]}
reward_min,reward_max=min(reward_values),max(reward_values)
print("Reward range:",(reward_min,reward_max))

# ANSI text map made of S, F, H and G tiles.
frame=env.render()
print(frame)

# Settings shared by the dynamic-programming code below.
gamma=0.99                  # discount factor
theta=1e-8                  # small threshold for convergence
nS=env.observation_space.n  # number of states (16)
nA=env.action_space.n       # number of actions (4)

def plot(V,policy,col_ramp=1,dpi=175,draw_vals=True):
    """Draw the FrozenLake grid with state values and best-action arrows.

    The map layout is read from the module-level ``env`` (it is not a
    parameter), and matplotlib's global ``rcParams`` are modified.

    Args:
        V: numpy array of state values, indexed by state ``r*ncol+c``.
        policy: per-state action weights; the arg-max action of each state
            is drawn as an arrow. Pass ``None`` to draw no arrows.
        col_ramp: truthy selects the 'cool' colormap, falsy selects 'gray'.
        dpi: figure resolution stored in ``rcParams['figure.dpi']``.
        draw_vals: when true, print the value of every non-hole tile.
    """
    lake=env.unwrapped
    plt.rcParams.update({
        'figure.dpi':dpi,
        'axes.edgecolor':(0.32,0.36,0.38),
        'font.size':6 if lake.nrow==8 else 8,  # smaller font on the 8x8 map
    })
    plt.figure(figsize=(3,3))

    desc=lake.desc                    # map layout: S,F,H,G
    nrow,ncol=desc.shape
    colormap='cool' if col_ramp else 'gray'
    plt.imshow(V.reshape((nrow,ncol)),cmap=colormap,alpha=0.7)
    ax=plt.gca()

    # Cell borders sit half a unit away from the integer cell centres.
    for col_edge in range(ncol+1):
        ax.axvline(col_edge-0.5,lw=0.5,color='black')
    for row_edge in range(nrow+1):
        ax.axhline(row_edge-0.5,lw=0.5,color='black')

    arrow_dict=dict(enumerate('←↓→↑'))           # action index -> arrow
    tile_colors={'H':'red','G':'green','S':'blue'}  # other tiles are black
    centered={'ha':'center','va':'center'}

    # Walk the cells in row-major order: label, then value, then arrow.
    for s in range(nrow*ncol):
        r,c=divmod(s,ncol)
        tile=desc[r,c].decode('utf-8')  # b'S' -> 'S'

        plt.text(c,r,tile,color=tile_colors.get(tile,'black'),
                 fontsize=10,fontweight='bold',**centered)

        if draw_vals and tile!='H':     # holes get no value label
            plt.text(c,r+0.3,f"{V[s]:.2f}",color='black',fontsize=6,**centered)

        if policy is not None:
            best_action=np.argmax(policy[s])
            plt.text(c,r-0.25,arrow_dict[best_action],
                     color='purple',fontsize=12,**centered)

    plt.title("FrozenLake: V and π")
    plt.axis('off')
    plt.show()


def policy_evaluation(env,policy,discount_factor=0.99,theta=1e-8):
    """Evaluate ``policy`` iteratively and return its state-value function.

    Sweeps all states in place with the Bellman expectation backup until the
    largest change within one sweep drops below ``theta``.

    A transition flagged ``done`` ends the episode, so no future value is
    bootstrapped from its ``next_state``. When terminal states are absorbing
    with zero reward this yields the same values as bootstrapping would;
    when they are not, it keeps post-terminal rewards out of the estimate.

    Args:
        env: object exposing ``observation_space.n`` and ``unwrapped.P``,
            where ``P[s][a]`` is a list of ``(prob,next_state,reward,done)``.
        policy: per-state action probabilities, indexed ``policy[s][a]``.
        discount_factor: discount applied to the next state's value.
        theta: convergence threshold on the per-sweep maximum change.

    Returns:
        numpy array of length ``nS`` holding the estimated V(s).
    """
    nS=env.observation_space.n
    P=env.unwrapped.P
    V=np.zeros(nS)  # start with V(s)=0

    while True:
        delta=0.0
        for s in range(nS):
            v=0.0
            # sum over actions and next states: Bellman expectation
            for a,action_prob in enumerate(policy[s]):
                if action_prob==0:continue
                for prob,next_state,reward,done in P[s][a]:
                    # nothing follows a terminal transition
                    future=0.0 if done else discount_factor*V[next_state]
                    v+=action_prob*prob*(reward+future)
            delta=max(delta,abs(V[s]-v))
            V[s]=v  # in-place update: later states in this sweep see it
        if delta<theta:break
    return V


def q_from_v(env,V,s,gamma=0.99):
    """Return the action values Q(s,·) obtained by a one-step lookahead on V.

    A transition flagged ``done`` ends the episode, so only its immediate
    reward counts and no value is bootstrapped from its ``next_state``.

    Args:
        env: object exposing ``action_space.n`` and ``unwrapped.P``, where
            ``P[s][a]`` is a list of ``(prob,next_state,reward,done)``.
        V: state values, indexable by state.
        s: state whose action values are computed.
        gamma: discount applied to the next state's value.

    Returns:
        numpy array of length ``nA`` with one expected return per action.
    """
    P=env.unwrapped.P
    nA=env.action_space.n
    q=np.zeros(nA)
    for a in range(nA):
        for prob,next_state,reward,done in P[s][a]:
            # nothing follows a terminal transition
            future=0.0 if done else gamma*V[next_state]
            q[a]+=prob*(reward+future)
    return q


def policy_improvement(env,V,discount_factor=0.99):
    """Build the deterministic policy that is greedy with respect to ``V``.

    Args:
        env: object exposing ``observation_space.n`` and ``action_space.n``;
            it is also passed on to ``q_from_v``.
        V: state values used for the one-step lookahead.
        discount_factor: discount handed to ``q_from_v``.

    Returns:
        numpy array of shape ``(nS,nA)`` whose rows are one-hot on the
        highest-valued action (``np.argmax`` picks the first on ties).
    """
    n_states=env.observation_space.n
    n_actions=env.action_space.n
    greedy=np.zeros((n_states,n_actions))
    for state in range(n_states):
        action_values=q_from_v(env,V,state,discount_factor)
        greedy[state,np.argmax(action_values)]=1.0  # one-hot row
    return greedy


def policy_iteration(env,discount_factor=0.99,theta=1e-8,max_iter=100):
    """Alternate policy evaluation and greedy improvement until stable.

    Starts from the uniform random policy and stops when an improvement
    step leaves the policy unchanged or after ``max_iter`` rounds.

    Args:
        env: object exposing ``observation_space.n`` and ``action_space.n``;
            it is also passed on to the evaluation and improvement steps.
        discount_factor: discount used by both steps.
        theta: convergence threshold handed to ``policy_evaluation``.
        max_iter: upper bound on evaluation/improvement rounds.

    Returns:
        Tuple ``(V,policy,it)``: the value function from the last
        evaluation, the final policy and the number of rounds performed.
        With ``max_iter<=0`` no round runs, so ``V`` is all zeros and
        ``policy`` is the initial uniform policy. If ``max_iter`` is
        exhausted before the policy stabilises, ``V`` belongs to the policy
        that preceded the last improvement step.
    """
    nS=env.observation_space.n
    nA=env.action_space.n

    # start with random (uniform) policy
    policy=np.ones((nS,nA))/nA
    # defined up front so the return is valid even if the loop never runs
    V=np.zeros(nS)

    it=0
    while it<max_iter:
        it+=1
        # 1) policy evaluation: V^π
        V=policy_evaluation(env,policy,discount_factor,theta)

        # 2) policy improvement: greedy w.r.t V^π
        new_policy=policy_improvement(env,V,discount_factor)

        # stable when the greedy step reproduces the current policy
        stable=np.array_equal(new_policy,policy)
        policy=new_policy
        if stable:break

    return V,policy,it


# Solve the lake with policy iteration and report what was found.
V_pi,policy_pi,iters=policy_iteration(env,theta=theta,discount_factor=gamma)
print(f"\nPolicy iteration converged in {iters} iterations.")

# sep="\n" puts each array on the line after its heading, as two prints would.
print("\nFinal value function V^π (4x4):",V_pi.reshape((4,4)),sep="\n")
print("\nFinal policy (one-hot over actions):",policy_pi,sep="\n")

# Show the values together with the greedy-action arrows.
plot(V_pi,policy=policy_pi,draw_vals=True)
