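# SARSA
## Linear Approximation and Deep Versions

SARSA estimates the action-value function Q via on-policy temporal-difference (TD) learning: each update moves Q(s, a) toward r + γ·Q(s', a'), where a' is the action the policy actually takes next.
It works only with discrete action spaces, because acting greedily requires taking the argmax of Q over all actions.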

## SARSA for Tabular Applications
### SARSA on FrozenLake

![grafik.png](attachment:grafik.png)

In [5]:
import gym
import numpy as np

def SARSA(env, alpha=0.85, num_tau=5000, DISC_FACTOR=.9, epsilon=0.5):
    """Learn a tabular action-value function with on-policy SARSA.

    Args:
        env: gym environment whose observation and action spaces are
            Discrete (states and actions are integer indices).
        alpha: step size of the TD update.
        num_tau: number of training episodes.
        DISC_FACTOR: discount factor (gamma) applied to the bootstrap term.
        epsilon: probability of taking a uniformly random action in the
            epsilon-greedy behaviour policy; constant over training.

    Returns:
        np.ndarray of shape (n_states, n_actions) with the learned Q-values.
    """
    # Read the table size from the spaces: env.nS / env.nA are not
    # attributes of the toy-text envs in later gym releases, while
    # Discrete spaces always expose `.n`.
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    Q = np.zeros([n_states, n_actions])

    def policy(state):
        # Epsilon-greedy w.r.t. the current table (ties -> lowest index).
        if np.random.rand() < epsilon:
            return env.action_space.sample()
        return np.argmax(Q[state])

    for _episode in range(num_tau):
        s = env.reset()
        a = policy(s)
        done = False
        while not done:
            s_prime, r, done, _ = env.step(a)
            # On-policy: bootstrap from the action we will actually take next.
            a_prime = policy(s_prime)
            Q[s, a] += alpha * (r + DISC_FACTOR * Q[s_prime, a_prime] - Q[s, a])
            s, a = s_prime, a_prime
    return Q


In [18]:
env = gym.make('FrozenLake-v0', is_slippery=False)
q = SARSA(env, num_tau=10000)

# Greedy rollout: from the start state, always take the highest-valued
# action in the learned table and sum the rewards collected.
ret, done = 0, False
state = env.reset()
while True:
    a = np.argmax(q[state, :])
    state, rew, done, _ = env.step(a)
    ret += rew
    if done:
        break
ret

1.0

In [17]:
env = gym.make('FrozenLake-v0', is_slippery=True)
q = SARSA(env, num_tau=10000)

# Evaluate the greedy policy over several independent episodes and report
# the average return per episode.
n_eval_episodes = 100
ret = 0
for _episode in range(n_eval_episodes):
    # Every episode needs a fresh reset: continuing to step an environment
    # after it reported done=True does not start a new episode.
    state = env.reset()
    done = False
    while not done:
        a = np.argmax(q[state, :])
        state, rew, done, _ = env.step(a)
        ret += rew
ret / n_eval_episodes

0.0

In [82]:
"""
SFFF
FHFH
FFFH
HFFG

LEFT = 0
DOWN = 1
RIGHT = 2
UP = 3

"""
for i in range(16):
    print(np.argmax(q[i]))

0
2
3
3
1
0
1
0
2
2
1
0
0
2
2
0


In [35]:
import gym
import numpy as np

env = gym.make('Pendulum-v0')

# Scratch inspection of Pendulum's spaces. Only the last expression of a
# notebook cell is displayed, so the first two lines below are evaluated but
# their results are discarded.
# The displayed bound is a float array with one element (array([2.])), i.e.
# a continuous action -- so the argmax-based SARSA above cannot be applied
# directly, and np.argmax over a one-element sample is always 0.
np.argmax(env.action_space.sample())
len(env.observation_space.high)
env.action_space.high

array([2.], dtype=float32)

## SARSA with Deep Learning
### SARSA on CartPole

In [12]:
import torch
import numpy as np

def deep_SARSA(env,  Q, epsilon = 0.5, num_samples = 10000, alpha=0.8, DISC_FACTOR=0.9, loss = torch.nn.MSELoss()):
    """Online (no replay memory) SARSA that trains a neural network Q-estimator.

    Args:
        env (gym environment): environment with a discrete action space;
            states must be convertible to a float tensor.
        Q (torch.nn.Module): model mapping one state to a 1-D tensor with one
            Q-value per action. It is trained in place.
        epsilon (float): probability of a random action in the epsilon-greedy
            behaviour policy.
        num_samples (int): total number of environment transitions to train on,
            summed over episodes.
        alpha (float): learning rate passed to Adam. Note that the default 0.8
            is very large for Adam.
        DISC_FACTOR (float): discount factor gamma.
        loss (callable): regression loss between prediction and TD target.

    Returns:
        torch.nn.Module: the trained model ``Q`` (the same object passed in).
    """
    optimizer = torch.optim.Adam(Q.parameters(), lr=alpha)

    def as_input(state):
        return torch.as_tensor(state, dtype=torch.float32)

    def select_action(state):
        # Epsilon-greedy; the greedy branch needs no autograd graph.
        if torch.rand(1).item() < epsilon:
            return int(env.action_space.sample())
        with torch.no_grad():
            return int(torch.argmax(Q(as_input(state))))

    # Count transitions across episodes with a dedicated counter. The original
    # incremented the `for` loop variable, which was overwritten every episode,
    # so the budget never stopped training.
    step = 0
    while step < num_samples:
        s = env.reset()
        a = select_action(s)
        done = False
        while not done and step < num_samples:
            s_prime, r, done, _ = env.step(a)
            step += 1
            a_prime = select_action(s_prime)

            # Semi-gradient TD: the target is a constant, so no gradient flows
            # through the bootstrap term. Episode ends (done=True) do not
            # bootstrap from the next state.
            with torch.no_grad():
                bootstrap = 0.0 if done else DISC_FACTOR * Q(as_input(s_prime))[a_prime].item()
                target = torch.tensor(r + bootstrap, dtype=torch.float32)

            prediction = Q(as_input(s))[a]
            optimizer.zero_grad()
            loss(prediction, target).backward()
            optimizer.step()

            s, a = s_prime, a_prime
    return Q

import py_inforce as pin
import gym

env = gym.make('CartPole-v0')
# Size the Q-network from the environment: observation features in, one
# Q-value per discrete action out (4 and 2 for CartPole-v0).
in_dim, out_dim = env.observation_space.shape[0], env.action_space.n
Q = pin.MLP(
    [in_dim, 128, 128, out_dim],
    torch.nn.ReLU,
)

# Train with deep_SARSA's default hyperparameters.
deep_SARSA(env, Q)

KeyboardInterrupt: 

In [44]:
# Convert a sampled observation to float32, the same way states are fed to Q.
# The output shows entries around +/-1e38: sampling the observation space
# draws from its bounds, which are effectively unbounded for two dimensions.
torch.tensor(env.observation_space.sample(), dtype=torch.float32)

tensor([ 2.5745e+00, -1.5814e+38,  3.6223e-01, -2.1714e+37])

In [41]:
# Raw sampled observation: a float32 numpy array (see output).
env.observation_space.sample()

array([3.5891879e+00, 1.0153620e+38, 2.5003183e-01, 1.0558521e+38],
      dtype=float32)

In [6]:
# torch.rand(1) is a one-element float tensor; deep_SARSA compares it with
# epsilon for the epsilon-greedy coin flip.
torch.rand(1).numpy()

array([0.0227617], dtype=float32)

In [10]:
# Shape of the action deep_SARSA passes to env.step for a random action:
# a 0-d int64 numpy array (see output), not a plain Python int.
torch.tensor(env.action_space.sample()).numpy()

array(0, dtype=int64)