# Temporal-Difference Learning

Temporal-difference (TD) learning is a combination of Monte Carlo (MC) ideas and dynamic programming (DP) ideas.

TD borrows from both:
- Like Monte Carlo, TD can learn directly from raw experience without a model of the environment's dynamics
- Like DP, TD methods update estimates based in part on other learned estimates, without waiting for final outcomes (they bootstrap)

## Temporal-Difference Prediction

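Both TD and MC methods use experience to solve the prediction problem: estimating the state-value function of a given policy.

MC methods must wait until the end of an episode, when the return is known, before updating. TD(0) updates after every step, using the observed reward and the current estimate of the next state's value:

$V(S_t) \leftarrow V(S_t) + \alpha [R_{t+1} + \gamma V(S_{t+1}) - V(S_t)]$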

In [None]:
import collections

class TemporalLearningZeroPrediction():
    """Tabular TD(0) state-value prediction for a fixed policy.

    Every transition moves the estimate of the visited state toward the
    one-step bootstrapped target::

        V(S_t) <- V(S_t) + alpha * (R_{t+1} + gamma * V(S_{t+1}) - V(S_t))

    The update is applied as soon as the successor state is known: at once
    when ``observe`` receives ``next_state``, otherwise when the next call
    to ``observe`` (or ``optimize`` at the end of the episode) reveals it.

    Args:
        gamma: discount factor applied to the successor state's value.
        alpha: step size of each update.
        policy: callable mapping a state to the action to take.
    """

    def __init__(self, gamma, alpha, policy):
        self.gamma  = gamma
        self.alpha  = alpha
        self.policy = policy

        # V(s); states that were never updated are worth 0
        self.state_value = collections.defaultdict(lambda: 0)
        # number of TD updates applied to each state
        self.returns = collections.defaultdict(lambda: 0)

        # transition (state, reward) still waiting for its successor state;
        # both lists are empty between episodes
        self.states = []
        self.rewards = []

    def action(self, state):
        """Return the action chosen by the evaluated policy in ``state``."""
        return self.policy(state)

    def observe(self, state, action, reward, next_state=None):
        """Record one transition and apply the TD(0) update it enables.

        Args:
            state: state the action was taken in (must be hashable).
            action: action that was taken; unused, prediction only needs states.
            reward: reward received for the transition.
            next_state: state reached by the transition. When omitted (``None``),
                the update is deferred until the successor is revealed by the
                next ``observe`` call or by ``optimize``.
        """
        # Being in `state` now means it is the successor of any deferred transition.
        self._resolve_pending(self.state_value[state])

        if next_state is None:
            self.states.append(state)
            self.rewards.append(reward)
        else:
            self._td_update(state, reward, self.state_value[next_state])

    def optimize(self):
        """Close the episode: a transition still pending led to a terminal state."""
        # The value of a terminal state is 0 by definition.
        self._resolve_pending(0)

    def _resolve_pending(self, bootstrap_value):
        """Apply the deferred update(s) using ``bootstrap_value`` as V(S_{t+1})."""
        for state, reward in zip(self.states, self.rewards):
            self._td_update(state, reward, bootstrap_value)

        self.states = []
        self.rewards = []

    def _td_update(self, state, reward, bootstrap_value):
        """Move V(state) toward ``reward + gamma * bootstrap_value`` by ``alpha``."""
        target = reward + self.gamma * bootstrap_value
        self.state_value[state] += self.alpha * (target - self.state_value[state])
        self.returns[state] += 1

## Sarsa: On-policy TD Control

We follow the pattern of generalized policy iteration (GPI), now using TD methods for the evaluation (prediction) part.  
As before, we need to trade off exploration and exploitation, which leads to two families of approaches: on-policy and off-policy.

The theorems assuring the convergence of state values under TD(0) also apply to the corresponding algorithm for action values:

$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha [R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)]$

In [None]:
import collections

class Sarsa():
    """On-policy TD control (Sarsa) over a tabular action-value function.

    The action used to bootstrap an update is the action actually played on
    the following step, so ``observe`` picks it and ``action`` replays it.

    Args:
        action_space: object whose ``n`` attribute is the number of actions.
        gamma: discount factor.
        alpha: step size of each update.
        policy: callable ``policy(state_action_value, state)`` returning an
            action index.
    """

    def __init__(self, action_space, gamma, alpha, policy):
        self.gamma  = gamma
        self.alpha  = alpha
        self.policy = policy

        # Q-table: a state seen for the first time gets one zero per action
        self.state_action_value = collections.defaultdict(lambda: np.zeros((action_space.n)))
        # action already chosen for the upcoming step, None at the start of an episode
        self.selected_action = None

    def action(self, state):
        """Return the action to play in ``state``.

        The policy is only queried on the first step of an episode; afterwards
        the action chosen during the previous ``observe`` call is returned.
        """
        # Identity check: a valid action may compare equal to other objects
        # (or be array-like), but only None means "nothing selected yet".
        if self.selected_action is None:
            return self.policy(self.state_action_value, state)
        return self.selected_action

    def observe(self, state, action, reward, next_state):
        """Apply the Sarsa update for one transition.

        Q(S, A) <- Q(S, A) + alpha * (R + gamma * Q(S', A') - Q(S, A)),
        where A' is drawn from the policy in ``next_state`` and stored so the
        next call to ``action`` plays it.
        """
        next_action = self.policy(self.state_action_value, next_state)
        target = reward + self.gamma * self.state_action_value[next_state][next_action]

        self.state_action_value[state][action] += self.alpha * (target - self.state_action_value[state][action])

        self.selected_action = next_action

    def optimize(self):
        """End the episode: forget the pre-selected action."""
        self.selected_action = None

In [None]:
import collections

class Human():
    """Agent whose actions are typed by a person on standard input.

    It exposes the same ``action`` / ``observe`` / ``optimize`` methods as the
    learning agents but keeps no state and learns nothing.
    """

    def __init__(self):
        """Nothing to initialise."""

    def action(self, state):
        """Read one line from standard input and return it converted with ``int``."""
        typed = input()
        return int(typed)

    def observe(self, state, action, reward, next_state):
        """Accept a transition and discard it."""

    def optimize(self):
        """Do nothing at the end of an episode."""

In [None]:
# Windy Gridworld Env
from enum import Enum

import numpy as np

import gymnasium as gym
from gymnasium import spaces

class Actions(Enum):
    """Moves available to the agent, identified by their integer code."""
    # The integer values are the keys WindyGridworld uses to look up the
    # (row, column) step applied for each move.
    RIGHT = 0  # column + 1
    UP = 1  # row - 1
    LEFT = 2  # column - 1
    DOWN = 3  # row + 1

class WindyGridworld(gym.Env):
    """Windy gridworld: walk from (3, 0) to (3, 7) while wind pushes the agent up.

    Positions are ``(row, column)`` pairs with row 0 at the top. Before each
    move the wind of the column the agent stands in is looked up, and the
    agent is shifted that many extra rows upward. Every step is rewarded -1,
    except the step that reaches the target, which is rewarded 0 and ends the
    episode.

    Args:
        render_mode: ``None`` (no output) or ``"ascii"`` (print the grid after
            every reset and step).
        grid_shape: ``(rows, columns)`` of the grid.
    """
    metadata = { "render_modes": ["ascii"] }

    # Upward wind strength of each column; columns past the table have no wind.
    _WIND_BY_COLUMN = (0, 0, 0, 1, 1, 1, 2, 2, 1, 0)

    def __init__(self, render_mode=None, grid_shape=(7, 10)):
        self._grid_shape = grid_shape

        # Locations are (row, column) pairs, so each coordinate has its own
        # upper bound: rows - 1 for the first, columns - 1 for the second.
        # NOTE: `_get_obs` returns the agent location as a string (handy as a
        # dictionary key for tabular agents), not a sample of this space.
        lowest = np.zeros(2, dtype=int)
        highest = np.array(self._grid_shape, dtype=int) - 1
        self.observation_space = spaces.Dict(
            {
                "agent": spaces.Box(lowest, highest, shape=(2,), dtype=int),
                "target": spaces.Box(lowest, highest, shape=(2,), dtype=int),
            }
        )

        self._agent_location = np.array([3, 0], dtype=int)
        self._target_location = np.array([3, 7], dtype=int)

        # We have 4 actions, corresponding to "right", "up", "left", "down"
        self.action_space = spaces.Discrete(4)

        # Maps abstract actions from `self.action_space` to the (row, column)
        # step taken when that action is chosen,
        # i.e. 0 corresponds to "right", 1 to "up" etc.
        self._action_to_direction = {
            Actions.DOWN.value: np.array([1, 0]),
            Actions.RIGHT.value: np.array([0, 1]),
            Actions.UP.value: np.array([-1, 0]),
            Actions.LEFT.value: np.array([0, -1]),
        }

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

    def _get_obs(self):
        """Return the agent location formatted as a string, e.g. ``"[3 0]"``."""
        return str(self._agent_location)

    def _get_info(self):
        """Return the Manhattan (L1) distance between agent and target."""
        return {
            "distance": np.linalg.norm(
            self._agent_location - self._target_location, ord=1
            )
        }

    def _render_frame(self):
        """Print the grid in "ascii" mode: 1 marks the agent, 6 the target."""
        if self.render_mode == "ascii":
            # Size the frame from the configured grid, not the default 7x10
            grid = np.zeros(self._grid_shape)
            grid[self._agent_location[0], self._agent_location[1]] = 1
            grid[self._target_location[0], self._target_location[1]] = 6
            print(grid, flush=True)

    def step(self, action):
        """Apply ``action`` plus the wind and return the resulting transition.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``truncated`` is always ``False``.
        """
        # Map the action (element of {0,1,2,3}) to the direction we walk in
        direction = self._action_to_direction[action]

        # The wind is the one of the column the agent is leaving
        column = self._agent_location[1]
        wind_force = self._WIND_BY_COLUMN[column] if column < len(self._WIND_BY_COLUMN) else 0

        # We use `np.clip` to make sure we don't leave the grid
        self._agent_location[0] = np.clip(
            self._agent_location[0] + direction[0] - wind_force, 0, self._grid_shape[0] - 1
        )

        self._agent_location[1] = np.clip(
            self._agent_location[1] + direction[1], 0, self._grid_shape[1] - 1
        )

        # An episode is done iff the agent has reached the target.
        # Converted to a plain bool, the type the step API documents.
        terminated = bool(np.all(self._agent_location == self._target_location))
        reward = 0 if terminated else -1
        observation = self._get_obs()
        info = self._get_info()

        self._render_frame()

        return observation, reward, terminated, False, info

    def reset(self, seed=None, options=None):
        """Put the agent back on the start cell and return ``(observation, info)``."""
        # We need the following line to seed self.np_random
        super().reset(seed=seed)

        # reset agent's position
        self._agent_location = np.array([3, 0], dtype=int)

        observation = self._get_obs()
        info = self._get_info()

        self._render_frame()

        return observation, info

In [None]:
# Import the necessary libraries
import numpy as np
import plotly.graph_objects as go

import plotly.io as pio
# Use the 'notebook' renderer by default for every Plotly figure shown from here on
# (presumably so figures display inline in Jupyter - confirm for other front ends).
pio.renderers.default = 'notebook'

In [None]:
# Environment with its defaults: 7x10 grid and no rendering (render_mode=None).
env = WindyGridworld()

In [None]:
def play_env(env, agent):
    """Run one full episode of ``agent`` in ``env``.

    On every step the agent picks an action, the environment is advanced and
    the agent is shown the transition ``(observation, action, reward,
    new_observation)``. ``agent.optimize()`` is called once when the episode
    is over.

    Args:
        env: environment exposing ``reset()`` and
            ``step(action) -> (observation, reward, terminated, truncated, info)``.
        agent: object exposing ``action``, ``observe`` and ``optimize``.

    Returns:
        Tuple ``(reward_sum, nb_steps)``: the undiscounted sum of rewards and
        the number of steps taken in the episode.
    """
    reward_sum = 0
    nb_steps = 0

    terminated = False
    truncated = False
    observation, info = env.reset()

    # A truncated episode (e.g. a step limit) is over too; looping only on
    # `terminated` would never stop on an environment that truncates.
    while not (terminated or truncated):
        action = agent.action(observation)

        new_observation, reward, terminated, truncated, info = env.step(action)

        agent.observe(observation, action, reward, new_observation)

        observation = new_observation

        reward_sum += reward
        nb_steps += 1

    agent.optimize()

    return reward_sum, nb_steps

In [None]:
def argmax(array):
    """Return the index of the largest entry, picking at random among ties."""
    is_best = array == np.max(array)
    tied_indices = np.where(is_best)[0]
    return np.random.choice(tied_indices)

def get_epsilon_greedy_policy(epsilon=0.1):
    """Build an epsilon-greedy policy over a table of action values.

    The returned callable takes ``(state_action_value, state)``. With
    probability ``epsilon`` it returns a uniformly random action index,
    otherwise the greedy action given by ``argmax``.
    """
    def epsilon_greedy_policy(state_action_value, state):
        # Draw first, so the exploration decision always uses one uniform sample
        explore = np.random.uniform(0, 1) < epsilon

        if not explore:
            return argmax(state_action_value[state])

        nb_actions = len(state_action_value[state])
        return np.random.randint(0, nb_actions)

    return epsilon_greedy_policy

# Sarsa agent: step size 0.5, no discounting (gamma = 1), epsilon-greedy policy exploring 10% of the time.
agent = Sarsa(env.action_space, alpha=0.5, gamma=1, policy=get_epsilon_greedy_policy(epsilon=0.1))

In [None]:
# Play 200 episodes, keeping each episode's total reward and the running
# total of time steps (one entry per finished episode, starting from 0).
n_episodes = 200
rewards = []
time_steps = [0]

for _episode in range(n_episodes):
    episode_reward, episode_steps = play_env(env, agent)
    rewards.append(episode_reward)
    time_steps.append(time_steps[-1] + episode_steps)

print(agent.state_action_value)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# `subplot_titles` expects one title per subplot, so it has to be a list even
# for a single plot; a bare string is indexed character by character.
fig = make_subplots(rows=1, cols=1, subplot_titles=["Windy Gridworld"])

# Learning curve: cumulative time steps on x, number of finished episodes on y.
# A steeper curve means episodes are being completed in fewer steps.
x = np.array(time_steps)
y = np.arange(len(time_steps))
fig.add_trace(
    go.Scatter(
        x=x,
        y=y,
        line_color="red",
        name="test",
    ),
    row=1,
    col=1,
)

fig.update_layout(
    title="test",
    legend_title="Parameters",
)

fig.show()