In [1]:
from pacman.actions import Actions
from pacman.agents import Agent
from pacman.env import PacmanEnv
import numpy as np
from collections import defaultdict

In [2]:
class SARSAgent(Agent):
    """Tabular SARSA agent with epsilon-greedy exploration.

    Q-values live in ``self.Qtable``: a ``defaultdict`` mapping a state to a
    dict ``{action: value}`` over ``Actions.actions``, with every value
    starting at 0 the first time a state is looked up.

    Args:
        alpha: Learning rate for each temporal-difference update.
        gamma: Discount factor applied to the bootstrapped next-step value.
        env: Environment providing ``reset()`` and ``step(action)``.
        eps: Initial exploration probability used by ``act``.
    """

    def __init__(self, alpha, gamma, env, eps):
        self.alpha = alpha
        self.gamma = gamma
        self.env = env
        self.eps = eps
        # Unseen states get a fresh all-zero row of action values.
        self.Qtable = defaultdict(lambda: {action: 0 for action in Actions.actions})

    def training(self, n_episodes=1000, eps_decay=0.99):
        """Train with on-policy SARSA for ``n_episodes`` episodes.

        Each transition (s, a, r, s', a') applies
        ``Q[s][a] += alpha * (target - Q[s][a])``, where ``target`` is
        ``r + gamma * Q[s'][a']``, or just ``r`` when the step ended the
        episode (there is no successor action to bootstrap from).

        Args:
            n_episodes: Number of episodes to run.
            eps_decay: Factor multiplied into ``self.eps`` after each episode;
                defaults to 0.99, the value previously hard-coded.
        """
        for _ in range(n_episodes):
            # NOTE(review): reset() is unpacked as (state, done); confirm this
            # matches PacmanEnv.reset's return value.
            state, done = self.env.reset()
            action = self.act(state)

            while not done:
                next_state, reward, done, _info = self.env.step(action)
                if done:
                    # Terminal transition: do not bootstrap from next_state,
                    # whose table row may hold values learned on
                    # non-terminal visits of the same state key.
                    next_action = None
                    target = reward
                else:
                    next_action = self.act(next_state)
                    target = reward + self.gamma * self.Qtable[next_state][next_action]
                q_row = self.Qtable[state]
                q_row[action] += self.alpha * (target - q_row[action])

                state = next_state
                action = next_action

            self.eps *= eps_decay

    def act(self, state):
        """Return an epsilon-greedy action for ``state``.

        With probability ``1 - eps`` the action with the highest Q-value is
        chosen. Ties go to the first action in ``Actions.actions`` order,
        because ``max`` returns the first maximal key. Otherwise the choice
        is deferred to ``Actions.sample()`` (presumably a uniformly random
        action).
        """
        # np.random.rand() draws from [0, 1), so ``>=`` makes eps == 0
        # strictly greedy and eps == 1 strictly random.
        if np.random.rand() >= self.eps:
            q_row = self.Qtable[state]
            return max(q_row, key=q_row.get)
        return Actions.sample()

    def eval(self):
        """Switch to purely greedy behaviour by setting ``eps`` to 0.

        The previous ``eps`` is not saved, so calling ``training`` afterwards
        continues without exploration.
        """
        self.eps = 0


# Load the 'testMaze' layout with rendering disabled, then fit a SARSA agent
# (alpha=0.8, gamma=0.9, initial eps=0.8) using the default training schedule.
env = PacmanEnv.from_file('testMaze', render_mode=None)
sarsa = SARSAgent(alpha=0.8, gamma=0.9, env=env, eps=0.8)
sarsa.training()

In [4]:
# Switch the environment to text (ANSI) rendering and make the agent greedy.
env.set_render('ansi')
sarsa.eval()  # sets eps to 0, disabling exploration
# NOTE(review): the positional arguments are forwarded unchanged; their
# meaning is defined by PacmanEnv.run_policy. Kept as the last expression so
# the notebook displays the returned trajectory.
env.run_policy(sarsa, 100, 0, 1)

% % % % % % % % %
% [33m^[0m %       %   %
%   %   %   %   %
%   %   %   %   %
%       %       %
% % % % % % % % %
Score: 51


[((7, 4), 'Down', 0, (7, 3)),
 ((7, 3), 'Down', 0, (7, 2)),
 ((7, 2), 'Down', 0, (7, 1)),
 ((7, 1), 'Left', 0, (6, 1)),
 ((6, 1), 'Left', 0, (5, 1)),
 ((5, 1), 'Up', 0, (5, 2)),
 ((5, 2), 'Up', 0, (5, 3)),
 ((5, 3), 'Up', 0, (5, 4)),
 ((5, 4), 'Left', 0, (4, 4)),
 ((4, 4), 'Left', 0, (3, 4)),
 ((3, 4), 'Down', 0, (3, 3)),
 ((3, 3), 'Down', 0, (3, 2)),
 ((3, 2), 'Down', 0, (3, 1)),
 ((3, 1), 'Left', 0, (2, 1)),
 ((2, 1), 'Left', 0, (1, 1)),
 ((1, 1), 'Up', 0, (1, 2)),
 ((1, 2), 'Up', 0, (1, 3)),
 ((1, 3), 'Up', 51, (1, 4))]