<div style="text-align:center">
    <h1>
        SARSA
    </h1>
</div>

<br><br>

<div style="text-align:center">
    In this notebook we implement an on-policy method that learns by trial and error and uses bootstrapping.
    The name SARSA comes from the five quantities that each update uses:
</div>

\begin{equation}
\text{State}_t, \text{Action}_t, \text{Reward}_t, \text{State}_{t+1}, \text{Action}_{t+1}
\end{equation}
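Abbreviating State, Action and Reward as $S$, $A$ and $R$ and keeping the indices above, these five quantities combine in the tabular SARSA update described by Sutton & Barto. Here $\alpha$ is the step size, $\gamma$ the discount factor, and $R_t$ the reward received after taking $A_t$ in $S_t$ (Sutton & Barto index this reward as $R_{t+1}$):

\begin{equation}
Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_t + \gamma \, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]
\end{equation}

Two properties follow from this rule. The method is on-policy because $A_{t+1}$ is the action actually chosen by the policy being learned. It bootstraps because the current estimate $Q(S_{t+1}, A_{t+1})$ stands in for the rest of the return. When $S_{t+1}$ is terminal, $Q(S_{t+1}, \cdot)$ is taken to be $0$.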

<br>


## Import the packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from envs import Maze
from utils import plot_policy, plot_action_values, test_agent

## Create the environment, action-value table and policy

#### Create the environment

In [None]:
env = Maze()

#### Create the action-value table $Q(s, a)$

In [None]:
action_values = np.zeros(shape=(5, 5, 4))
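The table holds one entry per (cell, action) pair: 5 × 5 = 25 cells with 4 actions each, for 100 values, all initialized to zero. The following minimal sketch shows how the table is indexed. It assumes, as the policy's `action_values[state]` lookup suggests, that a state is a `(row, col)` tuple. The particular state and action are arbitrary.

```python
import numpy as np

# Same layout as the notebook: a 5x5 grid of cells, 4 actions per cell.
action_values = np.zeros(shape=(5, 5, 4))

# Assumption: a state is a (row, col) tuple, as `action_values[state]` suggests.
state = (2, 3)
action = 1

# Indexing with the tuple selects one cell and returns its 4 action values.
q_row = action_values[state]
print(q_row.shape)  # (4,)

# `action_values[state]` is a view, so assigning through it (as the SARSA
# update does) writes into the table itself.
action_values[state][action] = 0.5
print(action_values[2, 3, 1])  # 0.5
```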

#### Create the policy $\pi(s)$

In [None]:
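def policy(state, epsilon=0.):
    # Epsilon-greedy with respect to the global action_values table.
    if np.random.random() < epsilon:
        # Explore: pick one of the 4 actions uniformly at random.
        return np.random.randint(4)
    else:
        # Exploit: pick a greedy action, breaking ties uniformly at random.
        av = action_values[state]
        return np.random.choice(np.flatnonzero(av == av.max()))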
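With probability $\epsilon$ the policy picks one of the 4 actions uniformly at random, which may include the greedy one. Otherwise it picks a greedy action, splitting ties evenly. If $k$ actions tie for the maximum, each of them is chosen with probability $(1 - \epsilon)/k + \epsilon/4$, and every other action with probability $\epsilon/4$. The following sketch checks this empirically. `epsilon_greedy` is a hypothetical standalone variant of `policy` that takes the Q row and a NumPy `Generator` explicitly instead of reading the global table. The Q values are made up for illustration.

```python
import numpy as np

def epsilon_greedy(q_row, epsilon, rng):
    """Same rule as `policy`, but with the Q row and RNG passed in explicitly."""
    if rng.random() < epsilon:
        # Explore: uniform over all actions (the greedy ones included).
        return int(rng.integers(len(q_row)))
    # Exploit: break ties between maximal actions uniformly at random.
    return int(rng.choice(np.flatnonzero(q_row == q_row.max())))

rng = np.random.default_rng(0)
q_row = np.array([0.0, 1.0, 0.5, 1.0])  # hypothetical: actions 1 and 3 tie
epsilon = 0.2
n = 50_000

samples = [epsilon_greedy(q_row, epsilon, rng) for _ in range(n)]
freqs = np.bincount(samples, minlength=4) / n

# Theory: non-greedy actions get epsilon/4 = 0.05 each; the two tied
# greedy actions get (1 - epsilon)/2 + epsilon/4 = 0.45 each.
expected = np.array([0.05, 0.45, 0.05, 0.45])
print(np.round(freqs, 3))
print(expected)
```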

#### Plot the action-value table $Q(s,a)$

In [None]:
plot_action_values(action_values)

#### Plot the policy

In [None]:
plot_policy(action_values, env.render(mode='rgb_array'))

## Implement the algorithm

<br>



<div style="text-align:center">
    Adapted from Sutton & Barto: "Reinforcement Learning: An Introduction".
</div>

In [None]:
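def sarsa(action_values, policy, episodes, alpha=0.1, gamma=0.99, epsilon=0.2):

    for episode in range(1, episodes + 1):
        state = env.reset()
        # Choose A_t with the same epsilon-greedy policy that is being learned (on-policy).
        action = policy(state, epsilon)
        done = False
        while not done:
            # Take A_t, observe R_t and S_{t+1}, then choose A_{t+1}.
            next_state, reward, done, _ = env.step(action)
            next_action = policy(next_state, epsilon)

            qsa = action_values[state][action]
            # Q(terminal, .) is 0 by definition, so don't bootstrap past the end of the episode.
            next_qsa = 0. if done else action_values[next_state][next_action]
            # Move Q(S_t, A_t) a step of size alpha toward the TD target.
            action_values[state][action] = qsa + alpha * (reward + gamma * next_qsa - qsa)
            state = next_state
            action = next_action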
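To make the update in `sarsa` concrete, here is a single step worked by hand. The step size and discount match the function's defaults (`alpha=0.1`, `gamma=0.99`). The Q values and the reward are hypothetical numbers chosen for illustration.

```python
alpha, gamma = 0.1, 0.99   # same defaults as sarsa()
qsa = 0.5                  # hypothetical current estimate Q(S_t, A_t)
reward = -1.0              # hypothetical reward observed after taking A_t
next_qsa = 0.2             # hypothetical Q(S_{t+1}, A_{t+1}) for the action chosen next

td_target = reward + gamma * next_qsa   # -1.0 + 0.99 * 0.2 = -0.802
td_error = td_target - qsa              # -0.802 - 0.5 = -1.302
new_qsa = qsa + alpha * td_error        # 0.5 + 0.1 * (-1.302) = 0.3698
print(round(new_qsa, 4))  # 0.3698

# If S_{t+1} were terminal, the bootstrap term is dropped (Q(terminal, .) = 0).
terminal_qsa = qsa + alpha * (reward - qsa)   # 0.5 + 0.1 * (-1.5) = 0.35
```

The estimate moves only a tenth of the way toward the target. This is why many episodes are needed before the table settles.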

In [None]:
sarsa(action_values, policy, 100)

## Show the results

#### Show the resulting action-value table $Q(s,a)$

In [None]:
plot_action_values(action_values)

#### Show the resulting policy $\pi(\cdot|s)$

In [None]:
plot_policy(action_values, env.render(mode='rgb_array'))
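With its default `epsilon=0.`, `policy` acts greedily with respect to the learned table. The following sketch extracts the greedy action for every cell from a table of the same 5 × 5 × 4 shape. It also builds the full distribution $\pi(\cdot|s)$ of the $\epsilon$-greedy policy used during training. The random table is a hypothetical stand-in for the learned `action_values`, and `epsilon=0.2` matches the default in `sarsa`. One difference from `policy`: `np.argmax` returns the first maximal index on ties, whereas `policy` breaks ties at random.

```python
import numpy as np

# Hypothetical 5x5x4 table standing in for the learned action_values.
rng = np.random.default_rng(0)
action_values = rng.normal(size=(5, 5, 4))

# Greedy action index for every cell: argmax over the last (action) axis.
greedy_actions = np.argmax(action_values, axis=-1)
print(greedy_actions.shape)  # (5, 5)

# pi(a | s) under the epsilon-greedy policy, assuming a unique maximum per
# cell (which holds almost surely for this random table).
epsilon = 0.2
n_actions = action_values.shape[-1]
probs = np.full(action_values.shape, epsilon / n_actions)
rows, cols = np.indices(greedy_actions.shape)
probs[rows, cols, greedy_actions] += 1.0 - epsilon
```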

#### Test the resulting agent

In [None]:
test_agent(env, policy)