In [1]:
import numpy as np
import random

class GridWorld:
    """Small deterministic grid environment with tuple-valued states.

    The agent position is stored in ``self.state`` as a pair of coordinates.
    The first coordinate is bounded by ``rows`` and the second by ``cols``.
    An episode starts at ``(0, 0)`` and ends when the position equals
    ``self.goal`` (``(3, 3)``).
    """

    def __init__(self):
        self.rows, self.cols = 4, 4
        self.state = (0, 0)
        self.goal = (3, 3)
        # Action ids understood by step().
        self.actions = [0, 1, 2, 3]

    def reset(self):
        """Put the agent back on ``(0, 0)`` and return that state."""
        self.state = (0, 0)
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done)``.

        Action effects (each clamped at the grid boundary):
        0 -> first coordinate - 1, 1 -> second coordinate + 1,
        2 -> first coordinate + 1, 3 -> second coordinate - 1.
        Any other value leaves the position unchanged but still yields a
        reward.

        The reward is 10 with ``done=True`` when the new state equals the
        goal, otherwise -1 with ``done=False``.
        """
        row, col = self.state
        if action == 0:
            row = max(0, row - 1)
        elif action == 1:
            col = min(self.cols - 1, col + 1)
        elif action == 2:
            row = min(self.rows - 1, row + 1)
        elif action == 3:
            col = max(0, col - 1)
        self.state = (row, col)

        reached_goal = self.state == self.goal
        reward = 10 if reached_goal else -1
        return self.state, reward, reached_goal

def run_sarsa(num_episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1, seed=None):
    """Train a tabular SARSA agent on ``GridWorld`` and return its Q-table.

    Args:
        num_episodes: Number of training episodes to run.
        alpha: Learning rate used in the SARSA update.
        gamma: Discount factor applied to the next state-action value.
        epsilon: Probability of picking a uniformly random action instead of
            the greedy one.
        seed: Optional seed. When given, exploration uses a private
            ``random.Random(seed)`` so the run is reproducible without
            touching the global generator. When ``None`` the global
            ``random`` module is used.

    Returns:
        The learned Q-table, a NumPy array of shape
        ``(env.rows, env.cols, len(env.actions))``.

    Side effects:
        Prints a start banner and the rounded Q-values of state ``(0, 0)``.
    """
    env = GridWorld()
    # Size the table from the environment instead of hard-coding 4x4x4.
    # Action ids are used directly as indices on the last axis, so this
    # relies on env.actions being 0..n-1.
    Q = np.zeros((env.rows, env.cols, len(env.actions)))
    rng = random if seed is None else random.Random(seed)

    def choose_action(state):
        # Epsilon-greedy over the current Q estimates; one draw decides
        # explore vs. exploit, and a second draw is made only when exploring.
        if rng.random() < epsilon:
            return rng.choice(env.actions)
        return int(np.argmax(Q[state]))

    print("Running SARSA (On-Policy)...")
    for _ in range(num_episodes):
        state = env.reset()
        action = choose_action(state)

        done = False
        while not done:
            next_state, reward, done = env.step(action)

            # On-policy: the action used in the target is the one that will
            # actually be executed on the next iteration.
            next_action = choose_action(next_state)

            # `(not done)` zeroes the bootstrap term on the terminal step.
            target = reward + gamma * Q[next_state][next_action] * (not done)
            Q[state][action] += alpha * (target - Q[state][action])

            state, action = next_state, next_action

    print("SARSA Q-Table Sample (State 0,0):", np.round(Q[0, 0], 2))
    return Q

# Entry point: run the SARSA demo only when this code executes as the main
# module (not when it is imported).
if __name__ == "__main__":
    run_sarsa()

Running SARSA (On-Policy)...
SARSA Q-Table Sample (State 0,0): [-0.57  1.    0.4  -0.51]


In [2]:
# exp7_sarsa.py

import numpy as np
import random

class GridWorld4x4:
    """4x4 grid environment with integer states and a step limit.

    States are integers ``0 .. n_states - 1`` laid out row-major
    (``state = row * n_cols + col``). Each step yields a reward of -1, except
    a step that lands on ``goal_state``, which yields 0. An episode is
    reported as done when the goal is reached or ``max_steps`` steps have
    been taken since the last reset.

    Args:
        start_state: State the agent occupies after ``reset()``.
        goal_state: State that ends the episode with reward 0.
        max_steps: Step budget after which ``done`` is reported regardless of
            the position.
    """

    def __init__(self, start_state=0, goal_state=15, max_steps=100):
        self.n_rows = 4
        self.n_cols = 4
        self.n_states = self.n_rows * self.n_cols
        self.n_actions = 4
        self.start_state = start_state
        self.goal_state = goal_state
        self.max_steps = max_steps
        # Initialise the episode bookkeeping here as well, so that step()
        # works on a fresh instance instead of raising AttributeError when
        # reset() has not been called yet.
        self.state = start_state
        self.steps = 0

    def state_to_xy(self, s):
        """Convert a state index to its ``(row, col)`` pair."""
        return (s // self.n_cols, s % self.n_cols)

    def xy_to_state(self, r, c):
        """Convert a ``(row, col)`` pair to its state index."""
        return r * self.n_cols + c

    def reset(self):
        """Start a new episode at ``start_state`` and return that state."""
        self.state = self.start_state
        self.steps = 0
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done, info)``.

        Actions: 0 -> row - 1, 1 -> col + 1, 2 -> row + 1, 3 -> col - 1, each
        clamped at the grid edge. Any other value leaves the position
        unchanged but still consumes a step. ``info`` is always an empty dict.
        """
        r, c = self.state_to_xy(self.state)
        if action == 0:
            r = max(0, r - 1)
        elif action == 1:
            c = min(self.n_cols - 1, c + 1)
        elif action == 2:
            r = min(self.n_rows - 1, r + 1)
        elif action == 3:
            c = max(0, c - 1)
        ns = self.xy_to_state(r, c)
        self.state = ns
        self.steps += 1
        # `done` covers both reaching the goal and exhausting the step budget.
        done = (ns == self.goal_state) or (self.steps >= self.max_steps)
        reward = 0 if ns == self.goal_state else -1
        return ns, reward, done, {}

def epsilon_greedy(Q, s, n_actions, eps):
    """Pick an action for state ``s`` with an epsilon-greedy rule.

    One draw from ``random.random()`` decides whether to explore. If it is
    below ``eps``, a uniformly random action index in ``range(n_actions)`` is
    returned. Otherwise the index of the largest entry of ``Q[s]`` is
    returned as a plain ``int``.
    """
    explore = random.random() < eps
    return random.randrange(n_actions) if explore else int(np.argmax(Q[s]))

def sarsa(env, num_episodes=2000, alpha=0.5, gamma=1.0, eps=0.1):
    """Run tabular SARSA on ``env`` and return ``(Q, policy)``.

    ``env`` must expose ``n_states``, ``n_actions``, ``reset()`` returning a
    state index, and ``step(action)`` returning
    ``(next_state, reward, done, info)``.

    Args:
        env: Environment to learn on.
        num_episodes: Number of episodes to run.
        alpha: Learning rate.
        gamma: Discount factor.
        eps: Exploration probability passed to ``epsilon_greedy``.

    Returns:
        ``Q``: array of shape ``(n_states, n_actions)`` with the learned
        action values. ``policy``: array with the arg-max action per state.
    """
    Q = np.zeros((env.n_states, env.n_actions))

    for _episode in range(num_episodes):
        state = env.reset()
        action = epsilon_greedy(Q, state, env.n_actions, eps)

        while True:
            next_state, reward, done, _info = env.step(action)
            # The successor action is always sampled, and the bootstrap term
            # is applied unmasked, including on the step that ends the episode.
            next_action = epsilon_greedy(Q, next_state, env.n_actions, eps)
            td_target = reward + gamma * Q[next_state, next_action]
            Q[state, action] += alpha * (td_target - Q[state, action])
            if done:
                break
            state, action = next_state, next_action

    greedy_policy = np.argmax(Q, axis=1)
    return Q, greedy_policy

# Entry point: train SARSA on the default grid, then print the learned
# Q-table and the greedy policy laid out as a grid.
if __name__ == "__main__":
    env = GridWorld4x4()
    Q, pi = sarsa(env)
    print("SARSA Q-table:")
    print(Q)
    print("\nDerived SARSA policy (0:U,1:R,2:D,3:L):")
    # Take the grid shape from the environment rather than hard-coding 4x4,
    # so the layout stays consistent with the state indexing used by env.
    print(pi.reshape(env.n_rows, env.n_cols))


SARSA Q-table:
[[-6.33471621 -5.38650269 -5.94220543 -6.35286445]
 [-5.48780575 -4.79188751 -4.23659414 -7.1486671 ]
 [-4.54383597 -4.71299329 -3.6026735  -6.41880734]
 [-4.31009506 -4.05996366 -2.02505467 -5.18282878]
 [-7.27501214 -4.55959271 -5.91773903 -5.66418541]
 [-5.30890974 -3.27555798 -5.29633922 -6.10114671]
 [-4.66526042 -2.21786723 -3.34499423 -5.69117701]
 [-3.69138677 -3.24195007 -1.00000861 -3.61250328]
 [-6.31885776 -3.05162992 -3.40321725 -5.05599167]
 [-6.11645427 -2.03797823 -3.43883653 -5.06029875]
 [-4.15106289 -1.00000001 -1.52274214 -3.38246145]
 [-2.28562982 -1.00095157  0.         -3.17671954]
 [-4.99861519 -2.03870098 -4.7748514  -3.9438077 ]
 [-3.64407212 -1.39115129 -3.31745796 -3.89413123]
 [-3.75059147  0.         -1.18245964 -3.33400978]
 [ 0.          0.          0.          0.        ]]

Derived SARSA policy (0:U,1:R,2:D,3:L):
[[1 2 2 2]
 [1 1 1 2]
 [1 1 1 2]
 [1 1 1 0]]
