<a href="https://colab.research.google.com/github/ThomasWong-ST/Intro-to-RL/blob/main/Function_Approximation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

#Exp 9.1: State Aggregation on the 1000-state Random Walk

In [2]:
class RandomWalk:
    """
    Implements the 1000-state random walk environment described in
    Sutton and Barto's Reinforcement Learning, Example 9.1 (page 203).
    """
    def __init__(self):
        self.n_states = 1000
        self.start_state = 500
        self.step_range = 100

        # Terminal states are outside the 1-1000 range.
        # State 0 is the left terminal state, 1001 is the right.
        self.left_terminal = 0
        self.right_terminal = self.n_states + 1

        self.terminal_rewards = (-1, 1)  # (left_reward, right_reward)
        self.state = None

        # Pre-calculate the possible moves for efficiency
        moves = np.arange(1, self.step_range + 1)
        self._possible_moves = np.concatenate((-moves, moves))

    def reset(self):
        """
        Resets the environment to the starting state.
        All episodes begin in state 500.
        """
        self.state = self.start_state
        return self.state

    def step(self):
        """
        Takes a random step of 1-100 units to the left or right.

        Returns:
            next_state (int): The state after the move. Can be a terminal state.
            reward (int): The reward for this transition.
            done (bool): True if the episode has terminated.
        """
        if self.state is None:
            raise ValueError("You must call reset() before calling step()")

        # Select a random move from [-100, -1] U [1, 100]
        move = np.random.choice(self._possible_moves)
        next_state = self.state + move

        # Check for termination
        if next_state < 1:
            reward = self.terminal_rewards[0]
            done = True
            self.state = self.left_terminal
            return self.left_terminal, reward, done
        elif next_state > self.n_states:
            reward = self.terminal_rewards[1]
            done = True
            self.state = self.right_terminal
            return self.right_terminal, reward, done
        else:
            reward = 0
            done = False
            self.state = next_state
            return self.state, reward, done

    def get_states(self):
        """Return list of nonterminal state indices."""
        return list(range(1, self.n_states + 1))


1. Weight Vector (w): We have a weight vector, let's call it w = [w₀, w₁, ..., w₉]ᵀ. It has 10 components, one for each group of states.

2. Feature Vector (x(s)): For state aggregation, the feature vector x(s) is constructed in a special way. It's a vector that has:

- The same number of components as the weight vector (10 in our case).

- A value of 1 at the index corresponding to the group the state s belongs to.

- A value of 0 for all other components.

3. This is often called a one-hot encoding. For example:

- If state s = 42 is in group 0, then x(42) = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]ᵀ.

- If state s = 350 is in group 3, then x(350) = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]ᵀ.

4. Calculating wᵀx(s): When we compute the dot product: wᵀx(s) = w₀*x₀(s) + w₁*x₁(s) + ... + w₉*x₉(s)

Since only one component of x(s) is 1 (say, at index i) and the rest are 0, this simplifies to:

wᵀx(s) = wᵢ * 1 = wᵢ

...where i is the index returned by our get_group(s) function.

In [3]:
def get_group(state, num_groups=10, states_per_group=100):
  """ Returns the index of the group the state belongs to. """
  # Ensure state is within valid range (1 to 1000 for non-terminals)
  if state <= 0 or state > num_groups * states_per_group:
        # Handle terminal or invalid states - they don't map to a feature
        return None
  return (state - 1) // states_per_group # group indices 0 to 9

def v_hat_RW(state, w, num_groups=10, states_per_group=100):
  """ Approximate value function using state aggregation. """
  group_index = get_group(state, num_groups, states_per_group)
  if group_index is None:
      return 0 # Value of terminal states is 0
  return w[group_index]

def grad_v_hat(state, num_weights=10, num_groups=10, states_per_group=100):
  """ Gradient of the approximate value function for state aggregation. """
  group_index = get_group(state, num_groups, states_per_group)
  gradient = np.zeros(num_weights)
  if group_index is not None:
      gradient[group_index] = 1.0
  return gradient

def generate_episdoe(env, policy, max_steps=100):
  states, actions, rewards = [], [], []
  state = env.reset()
  for episode_num in range(max_steps):
      # Generate an episode using the environment's dynamics
      states, rewards = [], []
      state = env.reset() # Starts at 500
      done = False
      while not done:
            # In this random walk, there's no action choice by the policy.
            # The environment dictates the next state.
          next_state, reward, done = env.step()
          states.append(state)
          rewards.append(reward)
          state = next_state
          # Optional: Add a max_steps check here if needed
      return states, actions, rewards

In [4]:
env = RandomWalk()
dummy_policy = lambda s: None
generated_episode = generate_episdoe(env, dummy_policy)
print(generated_episode)

([500, np.int64(434), np.int64(354), np.int64(375), np.int64(405), np.int64(470), np.int64(538), np.int64(624), np.int64(555), np.int64(494), np.int64(556), np.int64(465), np.int64(460), np.int64(512), np.int64(416), np.int64(366), np.int64(354), np.int64(368), np.int64(359), np.int64(290), np.int64(345), np.int64(406), np.int64(335), np.int64(275), np.int64(204), np.int64(129), np.int64(207), np.int64(248), np.int64(340), np.int64(321), np.int64(260), np.int64(211), np.int64(266), np.int64(240), np.int64(312), np.int64(365), np.int64(462), np.int64(466), np.int64(454), np.int64(460), np.int64(538), np.int64(563), np.int64(485), np.int64(398), np.int64(437), np.int64(395), np.int64(320), np.int64(399), np.int64(443), np.int64(367), np.int64(400), np.int64(330), np.int64(396), np.int64(420), np.int64(410), np.int64(318), np.int64(328), np.int64(423), np.int64(489), np.int64(464), np.int64(513), np.int64(450), np.int64(397), np.int64(363), np.int64(388), np.int64(289), np.int64(251), np.

#Gradient Monte Carlo Algorithm for Estimating $\hat{v} ≈ v_{π}$

In summary of what is happening with MC gradient. We define v hat and it's gradient, based on the structure of our example, it makes sense to choose a 10 element vector, where each element is a representation of an equal divid of the 1000 states (i.e. something like 0-100 for first element, 101 to 200 to second,...), and the gradient is just  $\nabla\hat{v}(s_t,\mathbf{w}) = \left( \frac{\partial\hat{v}(s_t,\mathbf{w})}{\partial w_1}, \frac{\partial\hat{v}(s_t,\mathbf{w})}{\partial w_2}, \dots, \frac{\partial\hat{v}(s_t,\mathbf{w})}{\partial w_d} \right)^T$.

We than have G_t which is obtained by standard MC return calculate with a whole episode, and than afterwards the weights are calculated using the update rule. My intuition here is that as the states are being explored, the agent will reach the terminal states and the reward of the terminal states will flow backward towards the starting position. and because the return for the first 5 element of v hat are bad, this naturally reduces the weights and the next 5 elements in v hat are good, so that naturally pushes the weights up.

Note that this specific Gradient Monte Carlo algorithm is for prediction (evaluating states), not control (choosing actions). In the 1000-state random walk (Example 9.1), the "agent" doesn't actually choose its moves; the environment's step() function randomly moves it +/- 1 to 100 states. The algorithm learns the value of the states under this random movement policy. It doesn't learn a policy to prefer moving right.

In [53]:
def gradient_mc_prediction(env, policy, num_episodes=1000,max_steps=100, alpha=2e-5, gamma=1.0, num_weights=10):
    """
    Implements Gradient Monte Carlo prediction (page 202).
    Uses state aggregation from Example 9.1.
    Note: 'policy' is not used here as the environment defines transitions,
          but kept for consistency with your previous structure.
    """
    w = np.zeros(num_weights) # Initialize weights

    for episode_num in range(num_episodes):
      states, actions, rewards = generate_episdoe(env, policy, max_steps)
      G = 0
      # Loop backwards through the episode steps
      for t in reversed(range(len(states))):
          G = gamma * G + rewards[t]
          St = states[t]

          # Calculate gradient and update weights
          # No need for first-visit check with function approximation
          gradient = grad_v_hat(St, num_weights)
          delta = G - v_hat_RW(St, w) # Target G, prediction v_hat
          w += alpha * delta * gradient # The core update rule (9.7)

      # Optional: Add print statements to track progress, e.g., print RMS error
      # if (episode_num + 1) % 100 == 0:
      #     print(f"Episode {episode_num+1} finished.")


    return w # Return the learned weights

In [54]:
env = RandomWalk()

# The policy argument isn't really needed for this specific env.step()
# but we pass a dummy lambda function.
dummy_policy = lambda s: None

# Use alpha from the example figure caption [cite: 869]
alpha_from_example = 2e-5

learned_weights = gradient_mc_prediction(env, dummy_policy, num_episodes=30000, alpha=alpha_from_example)

print("Learned Weights:", learned_weights)

# You can now use v_hat(state, learned_weights) to get predictions
print("\nExample Predictions:")
for s in [1, 50, 100, 500, 901, 950, 1000]:
  print(f"  State {s} (Group {get_group(s)}): Predicted Value = {v_hat_RW(s, learned_weights):.4f}")

Learned Weights: [-0.63807779 -0.61957346 -0.46703318 -0.27897621 -0.08998633  0.08614611
  0.27091379  0.45763158  0.62571118  0.64410345]

Example Predictions:
  State 1 (Group 0): Predicted Value = -0.6381
  State 50 (Group 0): Predicted Value = -0.6381
  State 100 (Group 0): Predicted Value = -0.6381
  State 500 (Group 4): Predicted Value = -0.0900
  State 901 (Group 9): Predicted Value = 0.6441
  State 950 (Group 9): Predicted Value = 0.6441
  State 1000 (Group 9): Predicted Value = 0.6441


#Semi Gradient TD(0) for Estimating $\hat{v} ≈ v_{π}$

In [16]:
def semi_gradient_TD_zero(env, policy, num_episodes=1000,max_steps=500, alpha=2e-5, gamma=1.0, num_weights=10):
  """
  For this specific RandomWalk environment, the step()
  method internally randomizes the next state without needing
  an action a as input. So, the policy argument to gradient_TD_zero
  isn't actually used to choose actions.
  """
  w = np.zeros(num_weights)
  for ep in range(num_episodes):
    s = env.reset()
    #a = policy
    steps = 0

    while steps < max_steps:
        s_next, r, done = env.step()

        if done:
            # terminal target: v_hat(S',w) = 0
            td_target = r
            w += alpha * (td_target - v_hat_RW(s, w))*grad_v_hat(s, num_weights)
            steps += 1
            break

        # on-policy action for next state
        #a_next = policy(Q, s_next)

        # Q-learning target and update
        #max_action = np.max(Q[s_next])
        td_target = r + gamma * v_hat_RW(s_next, w) - v_hat_RW(s, w)
        w += alpha * td_target * grad_v_hat(s, num_weights)

        s = s_next
        steps += 1
  return w

In [18]:
env = RandomWalk()

# The policy argument isn't really needed for this specific env.step()
# but we pass a dummy lambda function.
dummy_policy = lambda s: None

# Use alpha from the example figure caption [cite: 869]
alpha_from_example = 2e-5

learned_weights = semi_gradient_TD_zero(env, policy=dummy_policy, num_episodes=10000, alpha=alpha_from_example)

print("Learned Weights:", learned_weights)

# You can now use v_hat(state, learned_weights) to get predictions
print("\nExample Predictions:")
for s in [1, 50, 100, 500, 901, 950, 1000]:
  print(f"  State {s} (Group {get_group(s)}): Predicted Value = {v_hat_RW(s, learned_weights):.4f}")

Learned Weights: [-8.96420644e-02 -9.58753321e-03 -1.06855577e-03 -1.24668513e-04
 -1.29467460e-05  1.34778394e-05  1.24984884e-04  1.07660483e-03
  9.41200760e-03  8.87900254e-02]

Example Predictions:
  State 1 (Group 0): Predicted Value = -0.0896
  State 50 (Group 0): Predicted Value = -0.0896
  State 100 (Group 0): Predicted Value = -0.0896
  State 500 (Group 4): Predicted Value = -0.0000
  State 901 (Group 9): Predicted Value = 0.0888
  State 950 (Group 9): Predicted Value = 0.0888
  State 1000 (Group 9): Predicted Value = 0.0888
