# TD Learning: Sarsa e *Q*-learning

#### Prof. Armando Alves Neto - Introdução ao Aprendizado por Reforço - PPGEE/UFMG

<img src="cave.png" width="400">

In [1]:
%matplotlib qt
import numpy as np
import gym
from functools import partial
import class_maze as cm
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15,5)
import seaborn as sns
sns.set()

Cria uma classe auxiliar (`RunningAverage`) usada apenas para calcular a média móvel do sinal de reforço ao longo dos episódios.

In [2]:
class RunningAverage(object):
    """Moving average over the most recently pushed values.

    Attributes:
        N: window size given at construction.
        vals: list with the values currently in the window (oldest first).
        num_filled: number of occupied slots in the window.
    """

    def __init__(self, N):
        """Create an empty window.

        Args:
            N: maximum number of values kept in the window. With ``N == 0``
                the first ``push`` raises ``IndexError`` (pop from an empty
                list), so callers should pass a positive value.
        """
        self.N = N
        self.vals = []
        self.num_filled = 0

    def push(self, val):
        """Add ``val`` to the window, dropping the oldest value once it is full."""
        if self.num_filled == self.N:
            # window is full: discard the oldest sample, then store the new one
            self.vals.pop(0)
            self.vals.append(val)
        else:
            self.vals.append(val)
            self.num_filled += 1

    def get(self):
        """Return the mean of the values in the window as a float.

        Returns ``nan`` when nothing has been pushed yet (the mean of an empty
        window is undefined) instead of failing with ``ZeroDivisionError``.
        """
        if self.num_filled == 0:
            return float('nan')
        return float(sum(self.vals)) / self.num_filled

Criando a classe para o TD learning.

In [3]:
class TDlearning(object):
    """Tabular TD-learning agent: environment set-up, Q table and persistence.

    This cell only defines construction, Q-table reset/load/save and clean-up.
    The action-selection policy and the learning steps are added by the
    following cells, which extend this class.
    """

    def __init__(self, parameters):
        """Build the maze environment and initialise the learner.

        Args:
            parameters: dict from which this block reads the keys 'method',
                'gamma', 'eps', 'alpha', 'q-file' and 'load_Q'. The dict is
                also stored as ``self.parameters``.
        """

        self.parameters = parameters

        # learning method name
        self.method = parameters['method']

        # number of episodes run so far
        self.episode = 0

        # create the environment
        xlim = np.array([0.0, 10.0])
        ylim = np.array([0.0, 10.0])
        resolution = 0.4
        self.env = cm.Maze(xlim=xlim, ylim=ylim, res=resolution, image='cave.png')

        # sizes of the state and action spaces
        # (env.num_states is multiplied out into a single flat state count)
        self.num_states = np.prod(np.array(self.env.num_states))
        self.num_actions = self.env.action_space.n

        # learning parameters
        self.gamma = parameters['gamma']
        self.eps = parameters['eps']
        self.alpha = parameters['alpha']

        # file used to store / restore the Q table
        self.logfile = parameters['q-file']

        # reset the policy (and optionally load a stored table)
        self.reset_policy()

    ##########################################
    # reset the action-value function
    def reset_policy(self):
        """Reset Q(s,a) to zeros and, if requested, restore it from disk.

        When ``parameters['load_Q']`` is true, the table and the episode
        counter are read from ``self.logfile``. Loading is best effort: if the
        file is missing, unreadable, lacks an entry, or holds a table whose
        shape differs from ``(num_states, num_actions)``, a message is printed
        and learning starts from the zero table. ``Q`` and ``episode`` are
        only replaced together, never partially.
        """

        # Q(s,a)
        self.Q = np.zeros((self.num_states, self.num_actions))

        if not self.parameters['load_Q']:
            return

        try:
            with open(self.logfile, 'rb') as f:
                data = np.load(f)
                # read both entries while the file is still open
                loaded_Q = data['Q']
                # stored as a 0-d array; keep the counter a plain int
                loaded_episode = int(data['episodes'])
        except Exception as err:
            # 'Exception' (not a bare except) lets KeyboardInterrupt/SystemExit through
            print('TDlearning: could not load Q table from {!r} ({}); starting from zeros.'.format(
                self.logfile, err))
            return

        if loaded_Q.shape != self.Q.shape:
            print('TDlearning: Q table in {!r} has shape {}, expected {}; starting from zeros.'.format(
                self.logfile, loaded_Q.shape, self.Q.shape))
            return

        self.Q = loaded_Q
        self.episode = loaded_episode

    ##########################################
    # return the current policy
    def curr_policy(self, copy=False):
        """Return a callable ``state -> action`` bound to the Q table.

        Args:
            copy: when True the policy is bound to a snapshot of ``self.Q``;
                otherwise it is bound to ``self.Q`` itself, so later in-place
                updates of the table are seen by the policy.
        """
        if copy:
            return partial(self.TabularEpsilonGreedyPolicy, np.copy(self.Q))
        else:
            return partial(self.TabularEpsilonGreedyPolicy, self.Q)

    ########################################
    # save the Q(s,a) table
    def save(self):
        """Write the Q table and the episode counter to ``self.logfile`` (npz archive)."""
        with open(self.logfile, 'wb') as f:
            np.savez(f, Q=self.Q, episodes=self.episode)

    ##########################################
    def __del__(self):
        """Close the environment, if one was created."""
        # __init__ may have failed before self.env was assigned
        env = getattr(self, 'env', None)
        if env is not None:
            env.close()

Probabilidade de escolha de uma ação $a$ baseada na política $\varepsilon$-soft:
$$
\pi(a|S_t) \gets 
                        \begin{cases}
                            1 - \varepsilon + \varepsilon/|\mathcal{A}|,  & \text{se}~ a = A^*,\\
                            \varepsilon/|\mathcal{A}|, & \text{caso contrário.}
                        \end{cases}
$$

In [4]:
class TDlearning(TDlearning):
    ##########################################
    # action selection (epsilon-soft)
    def TabularEpsilonGreedyPolicy(self, Q, state):
        """Sample an action for ``state`` from the epsilon-soft policy derived from ``Q``.

        The greedy action (argmax of ``Q[state, :]``) gets probability
        ``1 - eps + eps/|A|``; every other action gets ``eps/|A|``.

        Args:
            Q: action-value table indexed as ``Q[state, action]``.
            state: row index into ``Q``.

        Returns:
            The index of the sampled action.
        """

        # current greedy action
        greedy = Q[state, :].argmax()

        # total number of actions
        n_actions = Q.shape[1]

        # start with the exploration share everywhere, then raise the greedy entry
        share = self.eps/n_actions
        probs = np.full(n_actions, share)
        probs[greedy] = 1.0 - self.eps + share

        return np.random.choice(n_actions, p=probs)

Método do Sarsa:
- aplique ação $A$, receba $S'$ e $R$
- escolha $A'$ a partir de $S'$ usando $Q$ ($\varepsilon$-greedy, por exemplo)
- $Q(S,A) \gets Q(S,A) + \alpha \big[R + \gamma Q(S',A') - Q(S,A)\big]$
- $S \gets S'$
- $A \gets A'$

In [5]:
class TDlearning(TDlearning):
    ##########################################
    def sarsa(self, S, A):
        """Perform one Sarsa step from state ``S`` with the already chosen action ``A``.

        Applies ``A`` to the environment, samples the next action with
        ``self.policy`` and updates ``Q[S, A]`` towards
        ``R + gamma * Q[S', A']``.

        Returns:
            Tuple ``(S', A', R, done)``.
        """

        # environment interaction step
        next_state, reward, done, _ = self.env.step(A)

        # choose A' from S'
        next_action = self.policy(next_state)

        # Q(s,a) update: move the estimate by alpha times the TD error
        td_error = reward + self.gamma*self.Q[next_state, next_action] - self.Q[S, A]
        self.Q[S, A] = self.Q[S, A] + self.alpha*td_error

        return next_state, next_action, reward, done

Método do *Q*-learning:
- escolha $A$ a partir de $S$ usando $Q$ ($\varepsilon$-greedy, por exemplo)
- aplique ação $A$, receba $S'$ e $R$
- $Q(S,A) \gets Q(S,A) + \alpha \big[R + \gamma \max\limits_a Q(S',a) - Q(S,A)\big]$
- $S \gets S'$

In [6]:
class TDlearning(TDlearning):
    ##########################################
    def qlearning(self, S):
        """Perform one Q-learning step from state ``S``.

        Samples an action with ``self.policy``, applies it to the environment
        and updates ``Q[S, A]`` towards ``R + gamma * max_a Q[S', a]``.

        Returns:
            Tuple ``(S', R, done)``.
        """

        # \pi(s): behaviour action for the current state
        action = self.policy(S)

        # environment interaction step
        next_state, reward, done, _ = self.env.step(action)

        # bootstrap from the best action value available in S'
        td_error = reward + self.gamma*self.Q[next_state, :].max() - self.Q[S, action]
        self.Q[S, action] = self.Q[S, action] + self.alpha*td_error

        return next_state, reward, done

Simulando um episódio completo com um dos dois métodos (*Sarsa* ou *Q*-learning).

In [7]:
class TDlearning(TDlearning):
    ##########################################
    # simulate one episode until the end, following the current policy
    def rollout(self, max_iter=1000, render=False):
        """Run one learning episode and return the list of rewards received.

        Args:
            max_iter: maximum number of environment steps in the episode.
            render: when True, the environment is drawn after every step on
                subplot (1, 3, 3) of the current matplotlib figure.

        Returns:
            List with the reward of each step taken.

        Raises:
            ValueError: if ``self.method`` is neither 'Sarsa' nor 'Q-learning'.
        """

        # fail early with a clear message; otherwise neither branch below runs
        # and the loop dies with an unbound-variable error on the reward
        if self.method not in ('Sarsa', 'Q-learning'):
            raise ValueError(
                "unknown TD method {!r}: expected 'Sarsa' or 'Q-learning'".format(self.method))

        # start the environment (original note: the start is random)
        S = self.env.reset()

        # \pi(s): initial action; only the Sarsa branch consumes it
        A = self.policy(S)

        # list of rewards
        rewards = []

        for _ in range(max_iter):

            if self.method == 'Sarsa':
                # next state and next action come out of the Sarsa step
                S, A, R, done = self.sarsa(S, A)
            else:
                # Q-learning picks its own action inside the step
                S, R, done = self.qlearning(S)

            # store the reward
            rewards.append(R)

            # render the environment
            if render:
                plt.subplot(1, 3, 3)
                plt.gca().clear()
                self.env.render(self.Q)

            # reached a terminal state?
            if done:
                break

        return rewards

Executando um episódio.

In [8]:
class TDlearning(TDlearning):
    ##########################################
    def runEpisode(self):
        """Run one episode and return the sum of its rewards.

        Increments the episode counter, binds ``self.policy`` to the live Q
        table, runs ``rollout`` (rendering on episodes 1, 101, 201, ...) and
        saves the Q table when ``parameters['save_Q']`` is true.
        """

        # new episode
        self.episode += 1

        # policy bound to the live Q table (on-policy)
        self.policy = self.curr_policy()

        # render only once every 100 episodes, starting with the first
        show = (self.episode - 1) % 100 == 0

        # generate an episode following the current policy
        episode_rewards = self.rollout(render=show)

        # save the Q table
        if self.parameters['save_Q']:
            self.save()

        return np.array(episode_rewards).sum()

Código principal:
- episodes: número de episódios
- gamma: fator de desconto
- eps: $\varepsilon$
- alpha: $\alpha$
- method: *Sarsa* ou *Q-learning*
- save_Q: salva tabela *Q*
- load_Q: carrega tabela *Q*
- q-file: arquivo da tabela *Q*

In [9]:
##########################################
# main
##########################################
if __name__ == '__main__':

    plt.ion()

    # moving average of the episode reward (window of 100 episodes)
    avg_calc = RunningAverage(100)

    # parameters
    parameters = {'episodes'  : 2000,
                  'gamma'     : 0.99,
                  'eps'       : 0.5e-2,
                  'alpha'     : 0.5,
                  'method'    : 'Q-learning', # 'Sarsa' or 'Q-learning'
                  'save_Q'    : True,
                  'load_Q'    : False,
                  'q-file'    : 'q-table.npy',}

    # TD algorithm
    mc = TDlearning(parameters)

    # reward history
    rewards = []
    avg_rewards = []

    # strict '<': the counter starts at 0 and runEpisode increments it, so
    # '<=' would run one episode more than parameters['episodes']
    while mc.episode < parameters['episodes']:
        # run one episode
        total_reward = mc.runEpisode()

        # rewards
        rewards.append(total_reward)
        # average reward
        avg_calc.push(total_reward)
        avg_rewards.append(avg_calc.get())

        # redraw the learning curve on the left two thirds of figure 1
        plt.figure(1)
        plt.subplot(1, 3, (1,2))
        plt.gca().clear()
        plt.plot(avg_rewards, 'b', linewidth=2)
        plt.plot(rewards, 'r', alpha=0.3)
        plt.title('Reforço por episódios')
        plt.xlabel('Episódios')
        plt.ylabel('Reforço')

        plt.show()
        plt.pause(.1)

    plt.ioff()