# Chapter 3: SARSA
## Author: Wenchang Gao

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
import gym

import time
import os

Corridor and TD example

In [None]:
class Corridor:

  def __init__(self):
    self.board = range(5)
    self.state = 1
    self.observation_space = 5
    self.action_space = 2
  
  def step(self, action):
    state_prime = self.state+1 if action == 0 else self.state-1
    self.state = state_prime
    done = (state_prime==0) or (state_prime==4)
    reward = 1 if state_prime==4 else 0
    return state_prime, reward, done
  
  def reset(self):
    self.state = 1
    return self.state
  
  def render(self):
    for i in range(5):
      print(i if self.state!=i else '*', end=' ')
    print('')


In [None]:
class TabularSARSA:

  def __init__(self, obs, act, gamma=0.99, epsilon=0.3):
    self.q_table = np.zeros((obs, act), dtype=np.float32)
    self.epsilon = epsilon
    self.gamma = gamma
  
  def print_table(self):
    for i in range(len(self.q_table)):
      for j in range(len(self.q_table[i])):
        print(self.q_table[i][j], end=' ')
      print('')
    print('')

  def act(self, state):
    prob = np.random.random()
    # print(self.act)
    action = np.random.choice(len(self.q_table[state])) \
            if prob<self.epsilon else np.argmax(self.q_table[state])
    return action


Train the agent

In [None]:
def trainTD(agent=TabularSARSA(5, 2), env=Corridor(), episodes=100):
  for epi in range(episodes):
    display = epi%1000==0
    if display:
      print(f'Episode {epi}:')
    state = env.reset()
    if display:
      env.render()
    action = agent.act(state)
    done = False
    while not done:
      state_prime, reward, done = env.step(action)
      if display:
        env.render()
      action_prime = agent.act(state_prime)
      agent.q_table[state, action] = reward+agent.gamma*agent.q_table[state_prime, action_prime]
      state, action = state_prime, action_prime

  agent.print_table()

In [None]:
trainTD()

Episode 0:
0 * 2 3 4 
0 1 * 3 4 
0 1 2 * 4 
0 1 * 3 4 
0 * 2 3 4 
0 1 * 3 4 
0 1 2 * 4 
0 1 * 3 4 
0 1 2 * 4 
0 1 2 3 * 
0.0 0.0 
0.98010004 0.0 
0.99 0.0 
1.0 0.0 
0.0 0.0 



SARSA with a neural network to approximate Q-function

SARSA batched memory

In [8]:
class Experience:
  def __init__(self):
    self.label_list = ['states', 'actions', 'rewards', 'next_states',
                       'next_actions', 'dones']
    self.batch = {
        'states':[],
        'actions':[],
        'rewards':[],
        'next_states':[],
        'next_actions':[],
        'dones':[]
    }

  def add_exp(self, exp):
    for i in range(len(self.label_list)):
      self.batch[self.label_list[i]].append([exp[i]])

  def reset(self):
    for i in self.label_list:
      self.batch[i] = []

SARSA agent

In [18]:
class Net(nn.Module):
  
  def __init__(self, indim, outdim):
    super(Net, self).__init__()
    self.model = nn.Sequential(*[
        nn.Linear(indim, 32),
        nn.ReLU(),
        nn.Linear(32, 32),
        nn.ReLU(),
        nn.Linear(32, outdim),
        nn.Softmax()
    ])
  
  def forward(self, state):
    return self.model(state)

In [71]:
class SARSA:
  '''
  SARSA based on a neural network
  '''

  def __init__(self, indim, outdim, epsilon=0.9, minie=0.1, gamma=0.999, device='cpu'):
    self.epsilon = epsilon
    self.gamma = gamma
    self.minie = minie
    self.device = device
    self.net, self.target = Net(indim, outdim), Net(indim, outdim)
    self.net, self.target = self.net.to(self.device), self.target.to(self.device)
    self.optimizer = optim.Adam(self.net.parameters(), lr=0.01)
    self.training = True
  
  def train(self):
    self.training = True
  
  def eval(self):
    self.training = False
  
  def epsilon_decay(self):
    if self.epsilon > self.minie: self.epsilon -= 0.01
  
  def forward(self, state):
    return self.net(state)
  
  def egreedy(self, actions):
    prob = np.random.random()
    # print(prob)
    # print(actions)
    return np.random.choice(len(actions)) if prob<self.epsilon else torch.argmax(actions).item()
  
  def learn(self, memory):
    '''
    returns ys according to memory for agent to learn
    '''
    states = torch.Tensor(memory.batch['states'], device=self.device)
    actions = torch.Tensor(memory.batch['actions'], device=self.device)
    rewards = torch.Tensor(memory.batch['rewards'], device=self.device)
    next_states, next_actions = torch.Tensor(memory.batch['next_states'], device=self.device), torch.Tensor(memory.batch['next_actions'], device=self.device)
    dones = torch.Tensor(memory.batch['dones'], device=self.device)
    with torch.no_grad():
      target_q = self.target(next_states)
    target_qa = torch.zeros_like(rewards)
    for i in range(len(rewards)):
      target_qa[i] = target_q[i][next_actions[i]]

    with torch.no_grad():
      ys = rewards+(1.-dones)*self.gamma*self.target_qa
    qs = self.net(states)
    y_hats = torch.zeros_like(rewards)
    for i in range(len(rewards)):
      y_hats[i] = qs[i][actions[i]]
    return ys, y_hats

  def act(self, state):
    '''
    choose action based on epsilon-greedy policy
    '''
    actions = self.forward(state)
    if self.training:
      action = self.egreedy(actions)
    else:
      action = torch.argmax(actions)
    # print(action)
    return action
  
  def save(self, path):
    torch.save(self.net.state_dict(), path+'.sarsa')
  
  def load(self, path):
    self.net.load_state_dict(path)
    self.net.to(self.device)

In [72]:
def train(env, agent, episodes=1000, lr=0.001):
  memory = Experience()
  criterion = torch.nn.MSELoss()
  loss = 0.
  opt = optim.Adam(agent.net.parameters(), lr=lr)
  for epi in range(episodes):
    memory.reset()
    state = env.reset()
    state = torch.from_numpy(state)
    action = agent.act(state)
    # print(action)
    done = False
    tot_reward = 0
    while not done:
      next_state, reward, done, _ = env.step(action)
      next_state = torch.from_numpy(next_state)
      tot_reward += reward
      next_action = agent.act(next_state)
      exp = [state, action, reward, next_state, next_action, done]
      memory.add_exp(exp)
      state, action = next_state, next_action
    opt.zero_grad()
    y, y_hat = agent.learn(memory)
    loss = criterion(y_hat, y)
    loss.backward()
    opt.step()
    agent.epsilon_decay
    print(f'Episode: {epi}, total reward:{tot_reward}')


In [73]:
env = gym.make('CartPole-v0')
indim = env.observation_space.shape[0]
outdim = env.action_space.n 
agent = SARSA(indim, outdim)
train(env, agent)

  f"The environment {id} is out of date. You should consider "
  "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
  "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
  input = module(input)


ValueError: ignored