# Chapter 3: SARSA
## Author: Wenchang Gao

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
import gym

import time
import os

Corridor and TD example

In [53]:
class Corridor:

  def __init__(self):
    self.board = range(5)
    self.state = 1
    self.observation_space = 5
    self.action_space = 2
  
  def step(self, action):
    state_prime = self.state+1 if action == 0 else self.state-1
    self.state = state_prime
    done = (state_prime==0) or (state_prime==4)
    reward = 1 if state_prime==4 else 0
    return state_prime, reward, done
  
  def reset(self):
    self.state = 1
    return self.state
  
  def render(self):
    for i in range(5):
      print(i if self.state!=i else '*', end=' ')
    print('')


In [54]:
class TabularSARSA:

  def __init__(self, obs, act, gamma=0.99, epsilon=0.3):
    self.q_table = np.zeros((obs, act), dtype=np.float32)
    self.epsilon = epsilon
    self.gamma = gamma
  
  def print_table(self):
    for i in range(len(self.q_table)):
      for j in range(len(self.q_table[i])):
        print(self.q_table[i][j], end=' ')
      print('')
    print('')

  def act(self, state):
    prob = np.random.random()
    # print(self.act)
    action = np.random.choice(len(self.q_table[state])) \
            if prob<self.epsilon else np.argmax(self.q_table[state])
    return action


Train the agent

In [55]:
def trainTD(agent=TabularSARSA(5, 2), env=Corridor(), episodes=100):
  for epi in range(episodes):
    display = epi%1000==0
    if display:
      print(f'Episode {epi}:')
    state = env.reset()
    if display:
      env.render()
    action = agent.act(state)
    done = False
    while not done:
      state_prime, reward, done = env.step(action)
      if display:
        env.render()
      action_prime = agent.act(state_prime)
      agent.q_table[state, action] = reward+agent.gamma*agent.q_table[state_prime, action_prime]
      state, action = state_prime, action_prime

  agent.print_table()

In [56]:
trainTD()

Episode 0:
0 * 2 3 4 
0 1 * 3 4 
0 1 2 * 4 
0 1 * 3 4 
0 * 2 3 4 
0 1 * 3 4 
0 1 2 * 4 
0 1 * 3 4 
0 1 2 * 4 
0 1 2 3 * 
0.0 0.0 
0.98010004 0.0 
0.99 0.0 
1.0 0.0 
0.0 0.0 

