In [None]:
### Code cell 0 ###
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

# Show arrays with 3 decimals and no scientific notation.
np.set_printoptions(precision=3, suppress=True)

# Slippery (stochastic) FrozenLake; the default map has 16 states.
env = gym.make('FrozenLake-v1', is_slippery=True)
obs, info = env.reset()

print("Initial State:", obs)
print("Observation space size:", env.observation_space.n)
print("Action space size:", env.action_space.n)

# Tabular model: P[s][a] is a list of (prob, next_state, reward, done) tuples.
P = env.unwrapped.P

# Gather every distinct reward value that occurs anywhere in the model.
all_rewards = {
    rew
    for state in P
    for action in P[state]
    for _p, _ns, rew, _d in P[state][action]
}
print("Unique rewards in FrozenLake:", all_rewards)

reward_min, reward_max = min(all_rewards), max(all_rewards)
print("Reward range:", (reward_min, reward_max))

# Value Iteration
def value_iteration(env, discount_factor=0.99, theta=1e-6, max_iterations=10000):
    """Compute the optimal state-value function V* by value iteration.

    Runs in-place Bellman optimality sweeps over all states. Each updated V[s]
    is immediately visible to later states in the same sweep. Stops when the
    largest absolute change in a sweep drops below ``theta``, or after
    ``max_iterations`` sweeps.

    Args:
        env: environment exposing ``observation_space.n``, ``action_space.n``
            and a tabular model ``env.unwrapped.P``, where ``P[s][a]`` is a
            list of ``(prob, next_state, reward, done)`` tuples.
        discount_factor: discount gamma applied to bootstrapped values.
        theta: convergence threshold on the per-sweep max value change.
        max_iterations: upper bound on the number of sweeps (may be 0).

    Returns:
        Tuple ``(V, policy, iterations)``:
        - ``V``: 1-D array of length nS.
        - ``policy``: the greedy policy returned by ``extract_policy_from_v``.
        - ``iterations``: the number of sweeps actually performed.
    """
    nS = env.observation_space.n
    nA = env.action_space.n
    P = env.unwrapped.P
    V = np.zeros(nS)  # V(s) initialised to 0 for every state

    # Counted explicitly so max_iterations=0 does not hit an unbound loop var.
    iterations = 0
    for _ in range(max_iterations):
        iterations += 1
        delta = 0.0  # largest |V_new(s) - V_old(s)| seen in this sweep
        for s in range(nS):
            q_sa = np.zeros(nA)  # Q(s, a) for every action
            for a in range(nA):
                for prob, next_state, reward, done in P[s][a]:
                    # A `done` transition ends the episode, so nothing is
                    # bootstrapped from next_state (needed for envs whose
                    # terminal transitions land in non-absorbing states).
                    future = 0.0 if done else discount_factor * V[next_state]
                    q_sa[a] += prob * (reward + future)
            new_v = np.max(q_sa)  # Bellman optimality: V(s) = max_a Q(s, a)
            delta = max(delta, abs(new_v - V[s]))
            V[s] = new_v
        if delta < theta:  # values have (almost) stopped changing
            break

    policy = extract_policy_from_v(env, V, discount_factor)  # greedy w.r.t. V
    return V, policy, iterations

# Extract policy
def extract_policy_from_v(env, V, discount_factor=0.99):
    """Return the deterministic greedy policy with respect to a value function.

    For every state s this computes Q(s, a) = sum p * (r + gamma * V[s']).
    The future term is dropped on ``done`` transitions, because the episode
    ends there. It then picks argmax_a Q(s, a); ties go to the lowest action
    index, per ``np.argmax``.

    Args:
        env: environment exposing ``observation_space.n``, ``action_space.n``
            and ``env.unwrapped.P`` with ``P[s][a]`` a list of
            ``(prob, next_state, reward, done)`` tuples.
        V: state values indexable by state id.
        discount_factor: discount gamma.

    Returns:
        np.ndarray of shape (nS, nA): one-hot rows with 1.0 at the greedy action.
    """
    nS = env.observation_space.n
    nA = env.action_space.n
    P = env.unwrapped.P
    policy = np.zeros((nS, nA))  # each row: distribution over actions

    for s in range(nS):
        q_sa = np.zeros(nA)
        for a in range(nA):
            for prob, next_state, reward, done in P[s][a]:
                future = 0.0 if done else discount_factor * V[next_state]
                q_sa[a] += prob * (reward + future)
        # One-hot row without building an identity matrix per state.
        policy[s, np.argmax(q_sa)] = 1.0
    return policy

# Plot V
def plot_values(env, V, gamma_label=""):
    """Plot a state-value function as a line over state indices.

    Args:
        env: the environment. It is used only to name the plot, via
            ``env.spec.id`` when present; otherwise "FrozenLake-v1" is used
            as before.
        V: 1-D sequence of state values indexed by state id.
        gamma_label: optional text (e.g. ``"0.9"``) appended to the title.
    """
    spec = getattr(env, "spec", None)
    env_id = getattr(spec, "id", None) or "FrozenLake-v1"

    plt.figure(figsize=(6, 3))
    plt.plot(V)  # V over state index 0..len(V)-1
    title = "Value Function (" + env_id + ")"
    if gamma_label:
        title += " (gamma=" + gamma_label + ")"
    plt.title(title)
    # \u2013 is an en dash; the escape keeps the label immune to re-encoding.
    plt.xlabel(f"State (0\u2013{len(V) - 1})")
    plt.ylabel("Value")
    plt.grid(True)
    plt.show()

# Plot policy
def plot_policy(env, policy, gamma_label=""):
    """Bar-plot the greedy action index selected in each state.

    Args:
        env: environment; only ``observation_space.n`` is read (bar positions).
        policy: per-state action scores/probabilities; argmax over axis 1 is
            taken as the chosen action.
        gamma_label: optional text appended to the title as "(gamma=...)".
    """
    chosen = np.argmax(policy, axis=1)  # greedy action for every state
    positions = np.arange(env.observation_space.n)

    heading = "Greedy Policy"
    if gamma_label:
        heading = heading + " (gamma=" + gamma_label + ")"

    plt.figure(figsize=(6, 3))
    plt.bar(positions, chosen)
    plt.title(heading)
    plt.xlabel("State")
    plt.ylabel("Action (0=Left,1=Down,2=Right,3=Up)")
    plt.show()

### Running VI
if __name__ == "__main__":
    # Discount factors to compare against each other.
    gammas = [0.9, 0.99, 0.6]
    all_V = {}      # gamma -> converged value function
    all_iters = []  # sweep counts, in the same order as `gammas`

    for gamma in gammas:
        print("\n=== Running Value Iteration gamma=", gamma, "===")
        V_opt, policy_opt, iterations = value_iteration(env, discount_factor=gamma)
        all_V[gamma] = V_opt
        all_iters.append(iterations)

        print("Converged in", iterations, "iterations")

        gamma_str = str(gamma)
        plot_values(env, V_opt, gamma_label=gamma_str)
        plot_policy(env, policy_opt, gamma_label=gamma_str)

    # Overlay the value functions of every discount factor on one figure.
    plt.figure(figsize=(7, 3))
    for gamma in gammas:
        plt.plot(all_V[gamma], label=f"gamma={gamma}")
    plt.title("Value Functions Comparison")
    plt.xlabel("State")
    plt.ylabel("Value")
    plt.legend()
    plt.grid(True)
    plt.show()

    # How many sweeps each discount factor needed before stopping.
    plt.figure(figsize=(5, 3))
    plt.bar(list(map(str, gammas)), all_iters)
    plt.title("Iterations vs Gamma")
    plt.xlabel("Gamma")
    plt.ylabel("Iterations")
    plt.show()

    ## Observations:
    ## - A larger gamma gives larger values: the agent cares more about future rewards.
    ## - A larger gamma also needs more iterations to converge.