# Semi-gradient Sarsa

![Semi-gradient Sarsa](./src/sgsarsa.jpg)

In [1]:
import numpy as np
import jdc

![Episodic semi-gradient Sarsa pseudocode](./src/sgsarsa_psuedo.png)

In [2]:
class SemiGradientSARSA:
    """Semi-gradient Sarsa agent with a linear action-value function.

    The action value is ``q(s, a) = sum(w * feature_generator(s, a))``.
    ``w`` has one leading slice per action (3 actions, indexed 0..2). The
    trailing (8, 8, 8) axes are assumed to match the feature layout produced
    by ``feature_generator``, which is defined outside this cell (TODO confirm).

    Parameters
    ----------
    gamma : float
        Discount factor applied to the bootstrapped next-step value.
    alpha : float
        Step size for the weight update.
    epsilon : float
        Probability that ``policy`` takes a uniformly random action.
    """

    def __init__(self, gamma=1, alpha=1e-2, epsilon=0.1):
        self.w = np.zeros((3, 8, 8, 8))
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha = alpha

    def policy(self, state):
        """Return an epsilon-greedy action in {0, 1, 2} for ``state``.

        With probability ``epsilon`` the action is uniformly random.
        Otherwise it is the action with the largest estimated value. Ties
        are broken uniformly at random: ``np.argmax`` always returns the
        first maximal index, so with zero-initialised weights every tied
        (e.g. unvisited) state would greedily pick action 0.
        """
        if np.random.rand() < self.epsilon:
            return np.random.randint(3)
        actions = [0, 1, 2]
        # Assumes feature_generator(state, actions) lays features for each
        # action along axis 0, like self.w (TODO confirm).
        f = feature_generator(state, actions)
        # One summed value per action: q(state, 0), q(state, 1), q(state, 2).
        q = (self.w * f).reshape(3, -1).sum(axis=1)
        best = np.flatnonzero(q == q.max())
        return np.random.choice(best)


In [3]:
%%add_to SemiGradientSARSA
def estimate(self, n_episodes):
    """Learn ``self.w`` by running ``n_episodes`` episodes of semi-gradient Sarsa.

    This function relies on ``env``, ``feature_generator`` and ``tqdm``
    being defined elsewhere in the notebook. ``env.step`` is expected to
    return a 4-tuple ``(next_state, reward, done, info)``.

    Each non-final transition applies
        w <- w + alpha * (R + gamma * q(S', A') - q(S, A)) * x(S, A)
    where ``x = feature_generator`` and ``q(s, a) = sum(w * x(s, a))``.
    On the final transition (``done``), the value of S' is taken as 0, so
    the target is just R.

    Parameters
    ----------
    n_episodes : int
        Number of episodes to run.
    """
    for _ in tqdm(range(n_episodes)):
        s = env.reset()  # [pos, vel]
        a = self.policy(s)
        while True:
            s_p, reward, done, _ = env.step(a)

            # Features of (S, A) are both the input to q(S, A) and the
            # gradient of the linear estimate, so compute them once.
            x = feature_generator(s, a)
            q = np.sum(self.w * x)

            if done:
                # Terminal S': no bootstrap term, target is the reward alone.
                # NOTE(review): if `done` can also mean a time-limit cutoff
                # rather than a true terminal state, bootstrapping there may
                # be preferable; confirm against the env.
                self.w += self.alpha * (reward - q) * x
                break

            a_p = self.policy(s_p)
            q_p = np.sum(self.w * feature_generator(s_p, a_p))
            self.w += self.alpha * (reward + self.gamma * q_p - q) * x

            s, a = s_p, a_p

In [4]:
%%add_to SemiGradientSARSA
def cost_to_go(self, X, Y):
    """Return ``-max_a q(s, a)`` evaluated at every point of a 2-D grid.

    Parameters
    ----------
    X, Y : np.ndarray
        2-D arrays of equal size, e.g. from ``np.meshgrid``. The i-th
        elements of ``X.ravel()`` and ``Y.ravel()`` form the i-th state
        ``[x, y]`` (presumably position and velocity; confirm).

    Returns
    -------
    np.ndarray
        Float array shaped like ``X`` holding the negated greedy action
        value of each grid point.
    """
    points = np.column_stack((X.ravel(), Y.ravel()))
    grid = np.zeros(X.shape)
    for idx, point in enumerate(points):
        # C-order flat index -> (row, col) of the 2-D grid.
        row, col = divmod(idx, X.shape[1])
        per_action = (self.w * feature_generator(point, [0, 1, 2])).reshape(3, -1).sum(axis=1)
        grid[row, col] = per_action.max()
    return -grid