In [None]:
### Code cell 0 ###
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

np.set_printoptions(precision=3, suppress=True)  # 3 decimals, no scientific notation

# 1) Environment setup
# Slippery FrozenLake on the default map; with render_mode='ansi', env.render()
# returns the map as text instead of opening a window.
env = gym.make('FrozenLake-v1', is_slippery=True, render_mode='ansi')
obs, info = env.reset()

print("Initial State:", obs)
print("Action Space:", env.action_space)            # Discrete(4)
print("Observation Space:", env.observation_space)  # Discrete(16)
# Read the grid shape from the map itself rather than hard-coding it, so the
# printout stays correct if a different map is used.
print("Grid shape (rows, cols):", env.unwrapped.desc.shape)

# Reward range, derived from the transition model.
# P[s][a] is a list of (prob, next_state, reward, done) tuples.
P = env.unwrapped.P
# Collect the distinct rewards once and reuse the set for both min and max.
rewards = {r for s in P for a in P[s] for (_, _, r, _) in P[s][a]}
reward_min = min(rewards)
reward_max = max(rewards)
print("Reward range:", (reward_min, reward_max))

# Show the text rendering of the map (tiles S, F, H, G).
frame = env.render()
print(frame)

gamma = 0.99                   # discount factor
theta = 1e-8                   # convergence threshold on the largest per-sweep change
nS = env.observation_space.n   # number of states (16)
nA = env.action_space.n        # number of actions (4)

# 2) Define a fixed (random) policy π (NOT optimal)
# Row s holds π(a|s); every action gets probability 1/nA in every state.
policy = np.full((nS, nA), 1.0 / nA)
print("\nRandom fixed policy (each row = probs over 4 actions):")
print(policy)

# 3) Policy evaluation: compute V^π for this fixed policy (no max over actions)
def policy_evaluation(env,policy,gamma=0.99,theta=1e-8,max_iter=1000):
    """Iteratively compute the state-value function V^π of a fixed policy.

    Repeatedly applies the Bellman expectation backup
        V(s) = sum_a π(a|s) * sum_{s'} P(s'|s,a) * [r + γ V(s')]
    to every state. Updates are made in place, so states visited later in a
    sweep already see the values written earlier in that same sweep.

    Args:
        env: Environment exposing ``env.unwrapped.P`` (where ``P[s][a]`` is an
            iterable of ``(prob, next_state, reward, done)`` tuples) and
            ``env.observation_space.n``.
        policy: Indexable by state; ``policy[s]`` yields the probability of
            each action in action-index order.
        gamma: Discount factor.
        theta: A sweep whose largest absolute value change is below this
            threshold ends the iteration.
        max_iter: Upper bound on the number of sweeps.

    Returns:
        A 1-D NumPy array of length ``env.observation_space.n`` with the
        estimated value of each state. A one-line status message is printed
        saying whether the threshold was met or the sweep budget ran out.
    """
    P = env.unwrapped.P
    nS = env.observation_space.n
    V = np.zeros(nS)  # start with V(s)=0 for all states
    it = 0
    converged = False
    while it < max_iter:
        delta = 0.0  # largest |V_old(s) - V_new(s)| seen in this sweep
        for s in range(nS):
            v_old = V[s]
            v_new = 0.0
            for a, action_prob in enumerate(policy[s]):
                if action_prob == 0:
                    continue  # actions the policy never takes contribute nothing
                # The `done` flag is not used: the backup bootstraps from
                # V[next_s] for every transition.
                # NOTE(review): this assumes terminal states evaluate to 0
                # under P -- confirm before reusing with other environments.
                for prob, next_s, reward, _done in P[s][a]:
                    v_new += action_prob * prob * (reward + gamma * V[next_s])
            V[s] = v_new
            delta = max(delta, abs(v_old - v_new))
        it += 1
        if delta < theta:
            converged = True
            break
    # Only claim convergence when the threshold was actually met; hitting
    # max_iter used to be reported as convergence too.
    if converged:
        print(f"\nPolicy evaluation converged in {it} iterations.")
    else:
        print(f"\nPolicy evaluation stopped after {it} iterations "
              f"without reaching theta={theta}.")
    return V

# Evaluate the uniform random policy and print V^π laid out like the map.
V_pi = policy_evaluation(env, policy, gamma, theta)
# Take the grid dimensions from the map layout instead of assuming 4x4, so the
# reshape also works for other map sizes.
nrow, ncol = env.unwrapped.desc.shape
print(f"\nState values V^π (as {nrow}x{ncol} grid, under random policy):")
print(V_pi.reshape(nrow, ncol))

# 4) Greedy policy improvement (one step) using V^π
def q_from_v(env,V,s,gamma=0.99):
    """Return the action values Q(s, a) for one state, derived from V.

    For each action index ``a`` in ``range(env.action_space.n)`` this sums
    ``prob * (reward + gamma * V[next_state])`` over the transitions listed
    in ``env.unwrapped.P[s][a]``.

    Args:
        env: Environment exposing ``env.unwrapped.P`` and ``env.action_space.n``.
        V: State values, indexable by state.
        s: State whose action values are wanted.
        gamma: Discount factor.

    Returns:
        A 1-D float NumPy array with one entry per action.
    """
    outcomes_by_action = env.unwrapped.P[s]
    n_actions = env.action_space.n
    return np.array(
        [
            # Start at 0.0 and add left-to-right, one term per transition.
            sum(
                (p * (r + gamma * V[nxt]) for p, nxt, r, _ in outcomes_by_action[action]),
                0.0,
            )
            for action in range(n_actions)
        ],
        dtype=float,
    )

def policy_improvement(env,V,discount_factor=0.99):
    """Build the deterministic policy that is greedy with respect to V.

    For every state, the action values are obtained from ``q_from_v`` and the
    action with the largest value receives probability 1. ``np.argmax``
    returns the first maximum, so ties go to the lowest action index.

    Args:
        env: Environment exposing ``env.observation_space.n`` and
            ``env.action_space.n`` (and whatever ``q_from_v`` reads).
        V: State values used to score the actions.
        discount_factor: Discount passed through to ``q_from_v``.

    Returns:
        An ``(n_states, n_actions)`` float array in which each row is one-hot.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    greedy = np.zeros((n_states, n_actions))
    for state in range(n_states):
        action_values = q_from_v(env, V, state, discount_factor)
        # Mark the greedy action; every other entry in the row stays 0.
        greedy[state, np.argmax(action_values)] = 1.0
    return greedy

# One greedy improvement step against the values of the random policy.
policy_improved = policy_improvement(env, V_pi, discount_factor=gamma)

# Print each heading followed by its policy matrix on the next line.
print("\nOld random policy (π):", policy, sep="\n")
print("\nGreedy improved policy (π') from V^π:", policy_improved, sep="\n")

# 5) Visualization helper (same as before)
def plot(V,policy,col_ramp=1,dpi=175,draw_vals=True,title_suffix=""):
    """Show the value function as a coloured grid annotated per cell.

    The map layout is read from the module-level ``env``. Each cell shows its
    tile letter, optionally its value, and optionally an arrow for the action
    with the highest probability under ``policy``. ``np.argmax`` picks the
    first maximum, so a row with tied probabilities shows the lowest-index
    action.

    Args:
        V: State values; reshaped to the map's (rows, cols).
        policy: Indexable by state, giving per-action scores/probabilities, or
            ``None`` to draw no arrows.
        col_ramp: Truthy selects the 'cool' colormap, falsy selects 'gray'.
        dpi: Written to matplotlib's global ``figure.dpi`` setting.
        draw_vals: Whether to print the numeric value in each non-hole cell.
        title_suffix: Text appended to the figure title.
    """
    # These are global matplotlib settings; they are set before the figure is
    # created and are not restored afterwards.
    plt.rcParams['figure.dpi'] = dpi
    plt.rcParams['axes.edgecolor'] = (0.32, 0.36, 0.38)
    plt.rcParams['font.size'] = 6 if env.unwrapped.nrow == 8 else 8
    plt.figure(figsize=(3, 3))

    layout = env.unwrapped.desc  # map layout: S,F,H,G (as bytes)
    nrow, ncol = layout.shape
    value_grid = V.reshape((nrow, ncol))

    colormap = 'cool' if col_ramp else 'gray'
    plt.imshow(value_grid, cmap=colormap, alpha=0.7)
    ax = plt.gca()

    arrow_for_action = {0: '←', 1: '↓', 2: '→', 3: '↑'}
    # Tiles not listed here (the frozen 'F' tiles) are drawn in black.
    tile_colors = {'H': 'red', 'G': 'green', 'S': 'blue'}

    # Cell borders: lines sit half a unit off the integer cell centres.
    for col_edge in range(ncol + 1):
        ax.axvline(col_edge - 0.5, lw=0.5, color='black')
    for row_edge in range(nrow + 1):
        ax.axhline(row_edge - 0.5, lw=0.5, color='black')

    # Visit cells in row-major order so the state index matches r*ncol+c.
    for state in range(nrow * ncol):
        r, c = divmod(state, ncol)
        cell_value = V[state]
        tile = layout[r, c].decode('utf-8')  # b'S' -> 'S'

        plt.text(c, r, tile, ha='center', va='center',
                 color=tile_colors.get(tile, 'black'),
                 fontsize=10, fontweight='bold')

        # Values are not printed on hole tiles.
        if draw_vals and tile != 'H':
            plt.text(c, r + 0.3, f"{cell_value:.2f}", ha='center', va='center',
                     color='black', fontsize=6)

        if policy is not None:
            chosen = np.argmax(policy[state])
            plt.text(c, r - 0.25, arrow_for_action[chosen], ha='center', va='center',
                     color='purple', fontsize=12)

    plt.title(f"FrozenLake: V^π and policy {title_suffix}")
    plt.axis('off')
    plt.show()

# 6) Plot old and improved policies (with same V^π just for comparison of arrows)
# Both figures use the same V^π so that only the arrows differ. For the uniform
# policy every row is a tie, and plot() resolves ties with np.argmax, so all of
# its arrows show action 0 (←).
for shown_policy, suffix in (
    (policy, "(old random π)"),
    (policy_improved, "(greedy improved π')"),
):
    plot(V_pi, shown_policy, draw_vals=True, title_suffix=suffix)
