<a href="https://colab.research.google.com/github/ManupatiEshwar/reniforecement/blob/main/Lab%203.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gymnasium as gym
import numpy as np
from collections import defaultdict
from typing import Callable, Tuple

In [None]:
def epsilon_greedy(Q: np.ndarray, s: int, epsilon: float) -> int:
    """Pick an action for state ``s`` with an epsilon-greedy rule over ``Q``.

    One uniform sample is drawn from NumPy's global random state. If it falls
    below ``epsilon`` a uniformly random column index of ``Q`` is returned;
    otherwise the index of the largest entry in row ``s`` is returned.
    """
    explore = np.random.rand() < epsilon
    if not explore:
        # Exploit: first index of the maximum value in this state's row.
        return int(np.argmax(Q[s]))
    n_actions = Q.shape[1]
    return np.random.randint(n_actions)

def run_episode(env, policy: Callable[[int], int], gamma: float = 0.99, max_steps: int = 1000) -> Tuple[float, int]:
    """Roll out one episode and return ``(discounted_return, steps_taken)``.

    The environment is reset, then ``policy`` is queried for an action at each
    state until the environment reports termination or truncation, or until
    ``max_steps`` actions have been taken. The reward received at step ``k``
    (counting from 0) is weighted by ``gamma ** k``.
    """
    state, _info = env.reset()
    discounted_return = 0.0
    n_steps = 0
    while n_steps < max_steps:
        action = policy(state)
        state, reward, terminated, truncated, _info = env.step(action)
        # Discount by the number of steps taken *before* this one.
        discounted_return += reward * (gamma ** n_steps)
        n_steps += 1
        if terminated or truncated:
            break
    return discounted_return, n_steps

# TD(0) Policy Evaluation

In [None]:
def td0_policy_evaluation(env_id: str = "FrozenLake-v1",
                          is_slippery: bool = True,
                          gamma: float = 0.99,
                          alpha: float = 0.1,
                          epsilon_random_policy: float = 1.0,
                          episodes: int = 20_000,
                          seed: int = 0) -> np.ndarray:
    """
    Estimate V^pi for the equiprobable random policy using tabular TD(0).

    Args:
        env_id: Gymnasium environment id; the environment must expose
            discrete observation and action spaces (``.n`` is read on both).
        is_slippery: Forwarded to ``gym.make``.
        gamma: Discount factor.
        alpha: Step size of the TD(0) update.
        epsilon_random_policy: Kept for interface compatibility. The
            non-random branch was never implemented (it also sampled a
            uniform action), so the evaluated policy is uniform random
            regardless of this value.
        episodes: Number of episodes to sample.
        seed: Seeds both the environment (first reset) and the action RNG.

    Returns:
        Array of shape ``(nS,)`` with the estimated state values.
    """
    env = gym.make(env_id, is_slippery=is_slippery)
    try:
        rng = np.random.default_rng(seed)
        env.reset(seed=seed)

        nS = env.observation_space.n
        nA = env.action_space.n
        V = np.zeros(nS, dtype=np.float64)

        def policy(s: int) -> int:
            # Equiprobable random policy; cast so the env gets a plain int
            # rather than a NumPy scalar.
            return int(rng.integers(nA))

        for _episode in range(episodes):
            s, _ = env.reset()
            done = False
            while not done:
                a = policy(s)
                s_next, r, terminated, truncated, _ = env.step(a)
                # Only a genuine terminal state has zero future value. A
                # truncated episode (e.g. cut off by a step limit) stopped in
                # a state that still has value, so keep bootstrapping from it.
                bootstrap = 0.0 if terminated else gamma * V[s_next]
                V[s] += alpha * (r + bootstrap - V[s])
                s = s_next
                done = terminated or truncated
    finally:
        # Release the environment even if sampling raised.
        env.close()
    return V

# SARSA (on-policy TD control)

In [None]:
def sarsa_control(env_id: str = "FrozenLake-v1",
                  is_slippery: bool = True,
                  gamma: float = 0.99,
                  alpha: float = 0.1,
                  epsilon_start: float = 1.0,
                  epsilon_end: float = 0.05,
                  epsilon_decay_steps: int = 50_000,
                  episodes: int = 100_000,
                  max_steps_per_ep: int = 200,
                  seed: int = 1) -> Tuple[np.ndarray, np.ndarray]:
    """
    Learn action values with tabular SARSA under an epsilon-greedy policy.

    Epsilon is interpolated linearly from ``epsilon_start`` to
    ``epsilon_end`` over the first ``epsilon_decay_steps`` environment steps
    (counted across all episodes) and held at ``epsilon_end`` afterwards.

    Args:
        env_id: Gymnasium environment id; the environment must expose
            discrete observation and action spaces (``.n`` is read on both).
        is_slippery: Forwarded to ``gym.make``.
        gamma: Discount factor.
        alpha: Step size of the SARSA update.
        epsilon_start: Exploration rate at global step 0.
        epsilon_end: Exploration rate once the decay has finished.
        epsilon_decay_steps: Number of global steps the decay spans.
        episodes: Number of training episodes.
        max_steps_per_ep: Hard cap on steps taken within one episode.
        seed: Seeds both the environment (first reset) and the action RNG.

    Returns:
        ``(Q, policy_greedy)`` where ``Q`` has shape ``(nS, nA)`` and
        ``policy_greedy`` is ``argmax`` of ``Q`` along the action axis.
    """
    env = gym.make(env_id, is_slippery=is_slippery)
    try:
        env.reset(seed=seed)
        rng = np.random.default_rng(seed)

        nS = env.observation_space.n
        nA = env.action_space.n
        Q = np.zeros((nS, nA), dtype=np.float64)

        def epsilon_by_step(t: int) -> float:
            # Linear decay, clamped to [epsilon_start, epsilon_end].
            frac = min(1.0, max(0.0, t / max(1, epsilon_decay_steps)))
            return epsilon_start + (epsilon_end - epsilon_start) * frac

        def select_action(state: int, eps: float) -> int:
            # Epsilon-greedy on the current Q; always returns a plain int.
            if rng.random() < eps:
                return int(rng.integers(nA))
            return int(np.argmax(Q[state]))

        timestep = 0
        for _episode in range(episodes):
            s, _ = env.reset()
            a = select_action(s, epsilon_by_step(timestep))
            for _step in range(max_steps_per_ep):
                s_next, r, terminated, truncated, _ = env.step(a)
                a_next = select_action(s_next, epsilon_by_step(timestep + 1))
                # Only a genuine terminal state has zero future value. On
                # truncation (e.g. a step limit) the next state-action pair
                # still carries value, so keep bootstrapping from it.
                bootstrap = 0.0 if terminated else gamma * Q[s_next, a_next]
                Q[s, a] += alpha * (r + bootstrap - Q[s, a])

                timestep += 1
                s, a = s_next, a_next
                if terminated or truncated:
                    break
    finally:
        # Release the environment even if training raised.
        env.close()

    policy_greedy = np.argmax(Q, axis=1)
    return Q, policy_greedy


# Example

In [None]:
if __name__ == "__main__":
    # Estimate state values of the random policy with TD(0).
    td0_settings = dict(
        env_id="FrozenLake-v1",
        is_slippery=True,  # switch to False for the deterministic grid
        gamma=0.99,
        alpha=0.1,
        epsilon_random_policy=1.0,
        episodes=25_000,
        seed=42,
    )
    V = td0_policy_evaluation(**td0_settings)
    print("TD(0) Value function estimate (V):")
    # Lay the 16 state values out as the 4x4 FrozenLake grid.
    value_grid = V.reshape(4, 4)
    print(value_grid)

TD(0) Value function estimate (V):
[[0.01154136 0.00756639 0.02747178 0.013248  ]
 [0.01991066 0.         0.0626995  0.        ]
 [0.03684871 0.0758343  0.18880016 0.        ]
 [0.         0.16529731 0.52072376 0.        ]]


In [None]:
    # Train SARSA on slippery FrozenLake and show the resulting greedy policy.
    sarsa_settings = dict(
        env_id="FrozenLake-v1",
        is_slippery=True,
        gamma=0.99,
        alpha=0.1,
        epsilon_start=1.0,
        epsilon_end=0.05,
        epsilon_decay_steps=80_000,
        episodes=120_000,
        max_steps_per_ep=200,
        seed=7,
    )
    Q, pi = sarsa_control(**sarsa_settings)
    print("\nGreedy policy from SARSA (actions 0:Left, 1:Down, 2:Right, 3:Up):")
    print(pi.reshape(4, 4))
    print("\nState-action values Q[s,a] (reshaped per action for readability):")
    # Each row of Q.T holds one action's values across all states.
    for a, action_values in enumerate(Q.T):
        print(f"Action {a}:")
        print(action_values.reshape(4, 4))



Greedy policy from SARSA (actions 0:Left, 1:Down, 2:Right, 3:Up):
[[0 3 1 3]
 [0 0 2 0]
 [3 1 0 0]
 [0 2 1 0]]

State-action values Q[s,a] (reshaped per action for readability):
Action 0:
[[0.35879099 0.256256   0.2354679  0.12917957]
 [0.40523588 0.         0.15157585 0.        ]
 [0.2666601  0.318245   0.56247467 0.        ]
 [0.         0.30634394 0.63039223 0.        ]]
Action 1:
[[0.31427124 0.19830804 0.24873389 0.11229348]
 [0.23471058 0.         0.169121   0.        ]
 [0.26211337 0.56930223 0.30913391 0.        ]
 [0.         0.38829745 0.85720356 0.        ]]
Action 2:
[[0.2981604  0.10287446 0.2361989  0.14233501]
 [0.24663785 0.         0.21868749 0.        ]
 [0.25617631 0.40590125 0.30485473 0.        ]
 [0.         0.70056179 0.62130542 0.        ]]
Action 3:
[[0.31968015 0.26972596 0.23349976 0.22772739]
 [0.29724172 0.         0.05580837 0.        ]
 [0.47997678 0.27368312 0.16935533 0.        ]
 [0.         0.49462199 0.63985899 0.        ]]
