# SARSA algorithm

In [1]:
import gym
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple


Declare hyperparameters for model training.

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gamma = 0.99                 # discount factor used in the SARSA target
batch = 1                    # transitions per gradient step (1 = update after every step)
lr = 0.001                   # Adam learning rate
epsilon_decay_rate = 0.9999  # multiplicative epsilon decay applied on every update
seed = 0                     # seed for numpy / torch so training runs are repeatable

# Seed the libraries that drive exploration and weight initialisation.
# (The gym environment itself is not seeded here; GPU kernels may also add nondeterminism.)
np.random.seed(seed)
torch.manual_seed(seed)

#### gamma
The CartPole environment gives a reward of +1 for every time step until the pole falls. If gamma is set to 1, the q values are undiscounted sums of these rewards and grow without bound as episodes get longer, so in theory the q network cannot be trained (in practice, a reasonably good policy can still be found even with gamma = 1). To learn the q function accurately with gamma = 1, note that an episode in CartPole-v1 lasts at most 500 time steps. Using this, we could build a network whose input state also includes the current time step. To keep this tutorial simple, we use only the state provided by the CartPole environment.

### Building the SARSA network with PyTorch

First, we build a q network that approximates the q function in the CartPole environment. We then create a SARSA class that handles model training and the agent's behavior.

In [3]:
class Q_Network(nn.Module):
    """MLP approximating the q function: maps a state to one q value per action.

    Layout: obs_dim -> hidden_dim -> hidden_dim -> hidden_dim -> act_dim, with
    ReLU after each hidden layer and a linear output layer.

    Args:
        obs_dim: size of the state vector; defaults to the notebook-level ``obs_size``.
        act_dim: number of discrete actions; defaults to the notebook-level ``act_size``.
        hidden_dim: width of the three hidden layers.
    """
    def __init__(self, obs_dim=None, act_dim=None, hidden_dim=32):
        super(Q_Network, self).__init__()

        # Fall back to the globals defined in the environment cell so that the
        # existing `Q_Network()` call keeps working unchanged.
        if obs_dim is None:
            obs_dim = obs_size
        if act_dim is None:
            act_dim = act_size

        # Attribute names fc1..fc4 are kept so previously saved state_dicts still load.
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, act_dim)

    def forward(self, s):
        """Return q values for state tensor ``s``; the last dimension has size act_dim."""
        q = F.relu(self.fc1(s))
        q = F.relu(self.fc2(q))
        q = F.relu(self.fc3(q))
        q = self.fc4(q)

        return q


class SARSA():
    """On-policy SARSA agent whose q function is a neural network.

    Transitions ``(s, a, r, s_, a_, d)`` are collected with ``stack_buffer`` and
    consumed by ``update``, which regresses Q(s, a) toward the SARSA target
    ``r + (1 - d) * gamma * Q(s_, a_)`` with an MSE loss.

    Uses the notebook-level settings ``device``, ``lr``, ``gamma``,
    ``epsilon_decay_rate``, ``obs_size`` and ``act_size``.
    """
    def __init__(self, q_network, epsilon = 1):
        """
        Args:
            q_network: module returning one q value per action for a float state tensor.
            epsilon: initial exploration probability of the epsilon-greedy policy.
        """
        self.net = q_network
        self.epsilon = epsilon
        self.min_epsilon = 0.01  # floor for the epsilon decay applied in update()
        self.transition = namedtuple("Transition" , ['s', 'a', 'r', 's_', 'a_', 'd'])
        self.buffer = []
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr)

    def get_action(self, s, is_train=True):
        """Choose an action index for state ``s``.

        With ``is_train=True`` the epsilon-greedy policy is used: a uniformly
        random action with probability ``self.epsilon``, else the greedy one.
        With ``is_train=False`` the greedy action is always returned.
        """
        if is_train and np.random.rand() < self.epsilon:
            return np.random.randint(act_size)

        s_tensor = torch.as_tensor(s, dtype=torch.float32, device=device)
        # Action selection needs no gradients.
        with torch.no_grad():
            q = self.net(s_tensor)
        return torch.argmax(q).item()

    def stack_buffer(self, sample):
        """Append one transition given as ``[s, a, r, s_, a_, d]``."""
        self.buffer.append(self.transition(*sample))

    @staticmethod
    def _to_tensor(values, shape, dtype):
        """Stack per-transition values into a ``dtype`` tensor of ``shape`` on ``device``."""
        # Converting through one numpy array avoids PyTorch's very slow
        # element-wise conversion of a Python list of ndarrays.
        return torch.as_tensor(np.asarray(values)).reshape(shape).to(device=device, dtype=dtype)

    def update(self):
        """Take one gradient step on every buffered transition, then clear the buffer.

        Also decays ``self.epsilon`` (never below ``self.min_epsilon``).

        Returns:
            The MSE loss as a 0-d numpy array.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("update() called with an empty transition buffer")

        self.epsilon = max(self.min_epsilon, self.epsilon * epsilon_decay_rate)

        # The batch size comes from the buffer itself rather than the global `batch`,
        # so a buffer of any size can be consumed.
        n = len(self.buffer)
        s = self._to_tensor([e.s for e in self.buffer], (n, obs_size), torch.float32)
        a = self._to_tensor([e.a for e in self.buffer], (n, 1), torch.long)
        r = self._to_tensor([e.r for e in self.buffer], (n, 1), torch.float32)
        s_ = self._to_tensor([e.s_ for e in self.buffer], (n, obs_size), torch.float32)
        a_ = self._to_tensor([e.a_ for e in self.buffer], (n, 1), torch.long)
        d = self._to_tensor([e.d for e in self.buffer], (n, 1), torch.float32)
        self.buffer.clear()

        q = torch.gather(self.net(s), 1, a)
        with torch.no_grad():
            # On-policy target: bootstrap from the action actually chosen next (a_),
            # not max_a Q(s_, a); (1 - d) zeroes the bootstrap at episode end.
            q_ = torch.gather(self.net(s_), 1, a_)
            target = r + (1 - d) * gamma * q_
        loss = self.criterion(q, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().numpy()




### Visualization of training

Next, to visualize training, we define a class that plots the loss and the cumulative reward as training progresses.

In [4]:
 class Graph():
    def __init__(self):
        self.max_num_data = 200
        
        self.fig, self.ax = plt.subplots(1, 2, figsize=(12, 4))
        self.loss_graph, = self.ax[0].plot(1)
        self.reward_graph, = self.ax[1].plot(1)

        self.ax[0].set_title("Loss")
        self.ax[0].set_xlim(0,1)
        self.ax[0].set_ylim(0,2)
        self.ax[0].set_yscale('log')
        self.ax[0].set_xlabel('step')
        self.ax[0].set_ylabel('loss')

        self.ax[1].set_title("Cumulative reward")
        self.ax[1].set_xlim(0,10)
        self.ax[1].set_ylim(0,2)
        self.ax[1].set_xlabel('episode')
        self.ax[1].set_ylabel('reward')

    
    def update_loss(self, loss_list):
        while len(loss_list) > self.max_num_data:
            loss_list = [data for idx, data in enumerate(loss_list) if idx %2 == 0]
        
        x = [i[0] for i in loss_list]
        y = [i[1] for i in loss_list]
        
        self.ax[0].set_xlim(0,x[-1])
        self.ax[0].set_ylim(np.min(y)*0.9,np.max(y)*1.1)

        self.loss_graph.set_xdata(x)
        self.loss_graph.set_ydata(y)        

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()     
        
    def update_reward(self, reward_list):
        while len(reward_list) > self.max_num_data:
            reward_list = [data for idx, data in enumerate(reward_list) if idx %2 == 0]
        
        x = [i[0] for i in reward_list]
        y = [i[1] for i in reward_list]
        
        self.ax[1].set_xlim(0,x[-1])
        self.ax[1].set_ylim(np.min(y)*0.9, np.max(y)*1.1)
        self.reward_graph.set_xdata(x)
        self.reward_graph.set_ydata(y)        

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()    

### Training the model

Create a CartPole environment, the q network, and an instance of our SARSA agent.

In [5]:
# Build the CartPole environment and read the state/action dimensions that
# Q_Network and SARSA pick up from the notebook namespace.
env = gym.make("CartPole-v1")
obs_size, act_size = env.observation_space.shape[0], env.action_space.n

# The q network lives on `device`; the agent wraps it with an epsilon-greedy policy.
q_network = Q_Network().to(device)
agent = SARSA(q_network)

After setting the training-related variables, we run the CartPole environment. Each time the agent receives an observation from the environment, it chooses an action with the epsilon-greedy policy, using the q values computed by the q network. The agent then receives a reward and moves on to the next state.

The original SARSA algorithm does not use mini-batches; it updates after every step. Because we approximate the q function with a neural network, this example also supports mini-batch training: the model is updated each time `batch` transitions have been collected. With the default `batch = 1`, it updates after every step, as in the original algorithm.

In [6]:
%matplotlib notebook
plt.ion()
graph = Graph()
loss_list = []
reward_list = []
loss_record_term = 100
max_steps = 100000
num_update = 0
episode = 0
step = 0
while step < max_steps:
    episode += 1
    cum_reward = 0
    state = env.reset()
    action = agent.get_action(state)

    while True :
        step += 1
        env.render()
        next_state, reward, done, info = env.step(action)
        next_action = agent.get_action(next_state)
        sample = [state, action, reward, next_state, next_action, done]
        agent.stack_buffer(sample)
        cum_reward += reward    
        
        state = next_state
        action = next_action
        # update
        if len(agent.buffer) == batch:
            loss = agent.update()
            num_update += 1
        # record loss
        if step % loss_record_term == 0: 
            loss_list.append([step,loss])
            graph.update_loss(loss_list)
            #print("num_update, loss , epsilon : " ,num_update, loss, agent.epsilon)

        # terminated
        if done:
            reward_list.append([episode, cum_reward])
            graph.update_reward(reward_list)
            #print("episode: {}, total reward: {}".format(episode, cum_reward))
            break
env.close()

<IPython.core.display.Javascript object>

Save the trained model.

In [8]:
# Persist only the learned weights; the architecture comes from Q_Network when reloading.
torch.save(obj=q_network.state_dict(), f="./sarsa_cartpole.pt")

### Test with trained model

Load the saved model and get ready to run.

In [9]:
# If the kernel was restarted, rebuild the environment and the network first.
# Q_Network reads the notebook-level obs_size / act_size, so define them before creating it.
#env = gym.make("CartPole-v1")
#obs_size = env.observation_space.shape[0]
#act_size = env.action_space.n
#q_network = Q_Network().to(device)

# Load the trained q network. map_location lets a checkpoint saved on a GPU
# be loaded on a CPU-only machine (and vice versa).
q_network.load_state_dict(torch.load("./sarsa_cartpole.pt", map_location=device))
q_network.eval()
# epsilon = 0 means get_action never explores; min_epsilon = 0 keeps it at 0
# even if update() is called.
agent = SARSA(q_network, epsilon= 0)
agent.min_epsilon=0

Evaluate the trained SARSA model with a greedy policy (epsilon = 0).

In [10]:
# Play four episodes with the loaded agent and report each episode's return.
for episode in range(1, 5):
    state = env.reset()
    action = agent.get_action(state)
    cum_reward = 0
    done = False

    while not done:
        env.render()
        state, reward, done, info = env.step(action)
        # Pick the follow-up action before accumulating, matching the training loop order.
        action = agent.get_action(state)
        cum_reward += reward

    print("episode: {}, total reward: {}".format(episode, cum_reward))
env.close()

episode: 1, total reward: 500.0
episode: 2, total reward: 500.0
episode: 3, total reward: 500.0
episode: 4, total reward: 500.0
