# In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import copy

# Compact, non-scientific float printing for any arrays printed later.
np.set_printoptions(precision=3,suppress=True)

# ========= 1) Environment setup =========
# Slippery FrozenLake with text ('ansi') rendering; the first reset uses seed 0.
env=gym.make('FrozenLake-v1',is_slippery=True,render_mode='ansi')
obs,info=env.reset(seed=0)

print("Initial State:",obs)
print("Observation space size:",env.observation_space.n)  # 16 states (4x4)
print("Action space size:",env.action_space.n)            # 4 actions (L,D,R,U)

# Tabular model: P_orig[s][a] is a list of (prob, next_state, reward, done).
P_orig=env.unwrapped.P

# Gather every distinct reward in the model once, then take both extremes.
_model_rewards={
    r
    for per_action in P_orig.values()
    for outcomes in per_action.values()
    for (_,_,r,_) in outcomes
}
reward_min,reward_max=min(_model_rewards),max(_model_rewards)
print("Reward range (env default):",(reward_min,reward_max))

print(env.render())  # text map

desc=env.unwrapped.desc              # 2D array of bytes: b'S', b'F', b'H', b'G'
flat_desc=desc.flatten()             # state index -> tile byte (row-major)


# ========= 2) Helper: modify rewards in transition model =========
def modify_rewards_frozenlake(P_base,flat_desc,
                              step_reward=0.0,
                              hole_reward=0.0,
                              goal_reward=1.0):
    """
    Return a new transition model whose rewards depend on the tile that
    each transition ENTERS:

      S/F -> step_reward
      H   -> hole_reward
      G   -> goal_reward

    Args:
      P_base:    mapping with P_base[s][a] = list of
                 (prob, next_state, reward, done) tuples. It is not mutated.
      flat_desc: sequence indexed by state; each item is a bytes-like tile
                 letter (it is decoded as UTF-8).
      step_reward, hole_reward, goal_reward: rewards assigned per tile type.

    Returns:
      A freshly built dict of dicts of lists with the same keys,
      probabilities, next states and done flags as P_base.

    Transitions that LEAVE a hole or goal tile are copied unchanged.
    Gymnasium's FrozenLake encodes those terminal tiles as absorbing
    self-loops; re-rewarding such a loop would pay the goal/hole reward on
    every step forever, so a planner that bootstraps through it would see a
    terminal value of reward/(1-gamma) instead of a one-off reward.
    """
    terminal_tiles=('G','H')
    reward_for_tile={'G':goal_reward,'H':hole_reward}

    P_new={}
    for s,actions in P_base.items():
        leaves_terminal=flat_desc[s].decode('utf-8') in terminal_tiles
        P_new[s]={}
        for a,outcomes in actions.items():
            if leaves_terminal:
                # Absorbing state: keep the original transitions (new list,
                # same immutable tuples) so no reward accrues after the end.
                P_new[s][a]=list(outcomes)
                continue
            P_new[s][a]=[
                (prob,
                 next_state,
                 reward_for_tile.get(flat_desc[next_state].decode('utf-8'),
                                     step_reward),
                 done)
                for prob,next_state,_old_r,done in outcomes
            ]
    return P_new


# ========= 3) Value Iteration (takes P) =========
def value_iteration(env,P,discount_factor=0.99,theta=1e-8,max_iterations=1000):
    """
    Run value iteration on a tabular transition model.

    Args:
      env:             object exposing observation_space.n and action_space.n
                       (number of states / actions).
      P:               model with P[s][a] = iterable of
                       (prob, next_state, reward, done) tuples.
      discount_factor: gamma used in the Bellman backup.
      theta:           stop once the largest per-state change in a sweep
                       falls below this value.
      max_iterations:  upper bound on the number of sweeps.

    Returns:
      (V, policy, n_sweeps, deltas) where V is the value array, policy is the
      one-hot greedy policy from extract_policy_from_v, n_sweeps is the number
      of sweeps actually performed (0 when max_iterations <= 0) and deltas
      holds the largest change observed in each sweep.

    Notes:
      * V is updated in place, so states later in a sweep already see the
        values written earlier in that sweep.
      * The `done` flag is not used to cut off bootstrapping. A terminal
        state therefore only ends up with value 0 if P models it as an
        absorbing self-loop with zero reward; a self-loop with reward r
        would converge to r/(1-discount_factor).
    """
    nS=env.observation_space.n
    nA=env.action_space.n
    V=np.zeros(nS)
    deltas=[]
    # Counted explicitly so the return value is defined even when the loop
    # body never runs (max_iterations <= 0).
    n_sweeps=0

    for n_sweeps in range(1,max_iterations+1):
        delta=0.0
        for s in range(nS):
            q_sa=np.zeros(nA)
            for a in range(nA):
                for prob,next_state,reward,done in P[s][a]:
                    q_sa[a]+=prob*(reward+discount_factor*V[next_state])
            new_v=np.max(q_sa)
            delta=max(delta,abs(new_v-V[s]))
            V[s]=new_v
        deltas.append(delta)
        if delta<theta:
            break

    policy=extract_policy_from_v(env,P,V,discount_factor)
    return V,policy,n_sweeps,deltas


# ========= 4) Policy extraction (uses P and V) =========
def extract_policy_from_v(env,P,V,discount_factor=0.99):
    """
    Derive the greedy one-hot policy implied by a state-value table.

    For each state the one-step lookahead value of every action is computed
    from P[s][a] (tuples of prob, next_state, reward, done) and V, and the
    first action with the largest value gets probability 1.

    Returns an array of shape (n_states, n_actions) with exactly one 1.0 per row.
    """
    n_states=env.observation_space.n
    n_actions=env.action_space.n

    def lookahead(state,action):
        # Expected immediate reward plus discounted value of the successor.
        return sum(p*(r+discount_factor*V[s2])
                   for p,s2,r,_done in P[state][action])

    policy=np.zeros((n_states,n_actions))
    for state in range(n_states):
        action_values=np.array([lookahead(state,action)
                                for action in range(n_actions)])
        # argmax returns the first maximiser, so ties go to the lowest action.
        policy[state,np.argmax(action_values)]=1.0
    return policy


# ========= 5) Evaluate policy on real env =========
def evaluate_policy(env,policy,n_episodes=500):
    """
    Roll out the greedy action of `policy` in `env` and summarise the results.

    An episode counts as a success when its final step (terminated or
    truncated) carries a positive reward.

    Returns (success_rate, avg_steps_success): the fraction of successful
    episodes and the mean length of those episodes, or float('inf') for the
    latter when no episode succeeded.
    """
    wins=0
    steps_in_wins=0

    for _episode in range(n_episodes):
        state,_info=env.reset()
        n_steps=0

        while True:
            chosen=np.argmax(policy[state])
            state,reward,terminated,truncated,_info=env.step(chosen)
            n_steps+=1
            if terminated or truncated:
                if reward>0:  # reached goal G
                    wins+=1
                    steps_in_wins+=n_steps
                break

    success_rate=wins/n_episodes
    if wins>0:
        avg_steps_success=steps_in_wins/wins
    else:
        avg_steps_success=float('inf')
    return success_rate,avg_steps_success


# ========= 6) Plot helpers =========
def plot_grid_values_and_policy(env,V,policy,label=""):
    """
    Draw the map as a heat-map of V with, per cell, the tile letter, the
    state value (omitted on hole tiles) and an arrow for the greedy action.

    `env.unwrapped.desc` supplies the grid of tile bytes; `label`, when
    non-empty, is appended to the figure title in parentheses.
    """
    desc=env.unwrapped.desc
    nrow,ncol=desc.shape

    arrows={0:'←',1:'↓',2:'→',3:'↑'}
    letter_colors={'H':'red','G':'green','S':'blue'}  # anything else: black

    plt.figure(figsize=(3,3))
    plt.imshow(V.reshape((nrow,ncol)),cmap='cool',alpha=0.7)
    ax=plt.gca()

    # Thin cell borders on the half-integer coordinates between cells.
    for x in range(ncol+1):
        ax.axvline(x-0.5,lw=0.5,color='black')
    for y in range(nrow+1):
        ax.axhline(y-0.5,lw=0.5,color='black')

    # Row-major walk over the map: state index s <-> (row, col).
    for s,raw_tile in enumerate(desc.flatten()):
        r,c=divmod(s,ncol)
        tile=raw_tile.decode('utf-8')

        plt.text(c,r,tile,ha='center',va='center',
                 color=letter_colors.get(tile,'black'),
                 fontsize=10,fontweight='bold')

        if tile!='H':
            plt.text(c,r+0.3,f"{V[s]:.2f}",ha='center',va='center',
                     color='black',fontsize=6)

        plt.text(c,r-0.25,arrows[np.argmax(policy[s])],
                 ha='center',va='center',color='purple',fontsize=12)

    suffix=f" ({label})" if label else ""
    plt.title("FrozenLake: V and π*"+suffix)
    plt.axis('off')
    plt.show()


# ========= 7) Test different reward structures =========
if __name__=="__main__":
    gamma=0.99

    # One experiment per row: (name, step_reward, hole_reward, goal_reward)
    reward_settings=[
        ("default",        0.0,   0.0,   1.0),   # baseline: only the goal pays
        ("step_penalty",  -0.05,  0.0,   1.0),   # every ordinary move costs a little
        ("hole_penalty",   0.0,  -1.0,   1.0),   # entering a hole is punished
        ("high_goal",      0.0,   0.0,   5.0),   # goal is worth five times as much
    ]

    # Results of each experiment, keyed by setting name.
    all_V,all_iters,all_deltas={},{},{}
    success_rates,avg_steps={},{}

    def _show_bar_chart(names,values,title,ylabel):
        # One bar per reward setting, in the order given by `names`.
        plt.figure(figsize=(6,4))
        plt.bar(names,values)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xticks(rotation=20)
        plt.show()

    for name,step_r,hole_r,goal_r in reward_settings:
        print(f"\n=== Running Value Iteration with reward setting: {name} ===")

        # Plan on the re-rewarded model ...
        P_mod=modify_rewards_frozenlake(
            P_orig,flat_desc,
            step_reward=step_r,hole_reward=hole_r,goal_reward=goal_r,
        )
        V_opt,policy_opt,iters,deltas=value_iteration(env,P_mod,discount_factor=gamma)
        all_V[name],all_iters[name],all_deltas[name]=V_opt,iters,deltas
        print(f"Converged in {iters} iterations for setting '{name}'")

        # ... then measure the resulting policy by rolling it out in `env`.
        rate,avg_steps_succ=evaluate_policy(env,policy_opt,n_episodes=500)
        success_rates[name],avg_steps[name]=rate,avg_steps_succ
        print(f"Success rate ({name}): {rate*100:.2f}%")
        print(f"Avg steps / successful episode ({name}): {avg_steps_succ:.2f}")

        plot_grid_values_and_policy(env,V_opt,policy_opt,label=name)

    # ---- Convergence (delta) comparison ----
    plt.figure(figsize=(8,4))
    for cfg in reward_settings:
        plt.plot(all_deltas[cfg[0]],label=cfg[0])
    plt.title("Value Iteration Convergence for Different Reward Structures (FrozenLake)")
    plt.xlabel("Iteration")
    plt.ylabel("Delta (max |V_new - V_old|)")
    plt.yscale("log")
    plt.grid(True)
    plt.legend()
    plt.show()

    # ---- Success rate vs reward structure ----
    labels=[cfg[0] for cfg in reward_settings]
    sr_vals=[success_rates[name]*100 for name in labels]
    _show_bar_chart(labels,sr_vals,
                    "Success Rate vs Reward Structure (FrozenLake)",
                    "Success Rate (%)")

    # ---- Average steps vs reward structure ----
    steps_vals=[avg_steps[name] for name in labels]
    _show_bar_chart(labels,steps_vals,
                    "Average Steps per Successful Episode vs Reward Structure",
                    "Avg Steps (successful episodes)")

    # ---- Iterations to converge vs reward structure ----
    it_vals=[all_iters[name] for name in labels]
    _show_bar_chart(labels,it_vals,
                    "Convergence Speed vs Reward Structure (FrozenLake)",
                    "Iterations to Converge")
