In [1]:
from pacman.actions import Actions
from pacman.agents import Agent
from pacman.env import PacmanEnv
import numpy as np
from collections import defaultdict

In [16]:
class SARSAgent(Agent):
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    Action values live in ``Qtable``, a ``defaultdict`` keyed by state whose
    values are ``{action: value}`` dicts initialised to 0 for every action in
    ``Actions.actions``. States therefore have to be hashable.
    """

    def __init__(self, env, alpha, gamma , epsilon, epsilon_decay=0.99):
        """Store the environment and the learning hyper-parameters.

        Args:
            env: Environment used by ``train``; it must provide ``reset()``
                returning ``(state, done)`` and ``step(action)`` returning
                ``(next_state, reward, done, info)``.
            alpha: Step size applied to the TD error.
            gamma: Discount factor applied to the next state-action value.
            epsilon: Initial exploration threshold used by ``act``.
            epsilon_decay: Factor ``epsilon`` is multiplied by after every
                training episode (defaults to the previously hard-coded 0.99).
        """
        self.alpha   = alpha
        self.gamma   = gamma
        self.env     = env
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        # Unseen states get a fresh all-zero action-value row on first access.
        self.Qtable = defaultdict(lambda : {action: 0 for action in Actions.actions})


    def train(self, max_episodes=None, tol=1e-8):
        """Run SARSA episodes until convergence or the episode budget is spent.

        Training stops as soon as the largest absolute TD error observed
        within a single episode drops below ``tol``, or after
        ``max_episodes`` episodes, whichever comes first.

        Args:
            max_episodes: Maximum number of episodes to run; ``None`` means
                no limit.
            tol: Convergence threshold on the per-episode maximum absolute
                TD error.
        """
        if max_episodes is None:
            max_episodes = np.inf

        curr_episode = 0
        while curr_episode < max_episodes:
            # Count the episode up front so max_episodes is actually enforced.
            curr_episode += 1
            errors = []

            state, done = self.env.reset()
            action = self.act(state)

            while not done:
                next_state, reward, done, _ = self.env.step(action)

                if done:
                    # Terminal transition: there is no successor action, so
                    # the target is the immediate reward alone.
                    next_action = None
                    target = reward
                else:
                    next_action = self.act(next_state)
                    target = reward + self.gamma * self.Qtable[next_state][next_action]

                td_error = target - self.Qtable[state][action]
                errors.append(abs(td_error))

                self.Qtable[state][action] += self.alpha * td_error

                state  = next_state
                action = next_action

            self.epsilon = self.epsilon_decay * self.epsilon

            # An episode that was already finished at reset produced no
            # updates and says nothing about convergence; skip the check
            # instead of calling max() on an empty list.
            if errors and max(errors) < tol:
                break


    def act(self, state):
        """Pick an action for ``state`` epsilon-greedily.

        With probability roughly ``1 - epsilon`` the highest-valued action in
        ``Qtable[state]`` is returned (ties go to the first such action in
        the dict's iteration order); otherwise the result of
        ``Actions.sample()`` is returned.
        """
        if np.random.rand() > self.epsilon:
            q_values = self.Qtable[state]
            return max(q_values, key=q_values.get)

        return Actions.sample()


    def eval(self):
        """Disable exploration by setting ``epsilon`` to 0.

        The previous ``epsilon`` is not saved, so training afterwards
        continues with ``epsilon == 0`` unless the caller resets it.
        """
        self.epsilon = 0


# Build the training environment: render_mode=None, a "FollowGhost" ghost,
# and a TIME_PENALTY of 1 passed through the config dict.
env = PacmanEnv.contourDanger(
    10,
    ghost_name="FollowGhost",
    render_mode=None,
    config={"TIME_PENALTY": 1},
)

# Hyper-parameters spelled out by keyword; epsilon=1 makes the first episode
# almost entirely exploratory.
sarsa = SARSAgent(env, alpha=0.2, gamma=0.99, epsilon=1)

# No max_episodes is given, so training ends only via the tolerance check.
sarsa.train()

In [17]:
# Switch the agent to greedy action selection (eval() sets epsilon to 0).
sarsa.eval()
# Select the 'ansi' render mode (presumably text rendering; confirm against PacmanEnv.set_render).
env.set_render('ansi')
# Run the trained agent in the environment. The meaning of the positional
# arguments (0, 1) is not visible here; TODO confirm against PacmanEnv.run_policy.
# Kept as the cell's last expression so the notebook displays its return value.
env.run_policy(sarsa, 0, 1)

% % % % % % % % % % % %
% [33m^[0m                   %
% [31mE[0m                   %
%                     %
%                     %
%                     %
%                     %
%                     %
%                     %
%                     %
%                     %
% % % % % % % % % % % %
Score: 35


[(((6, 5), (0.0,), ((2, 9),), (0,)),
  'Up',
  -1,
  ((6, 6), (0.0,), ((2.0, 8.0),), (0,))),
 (((6, 6), (0.0,), ((2.0, 8.0),), (0,)),
  'Up',
  -1,
  ((6, 7), (0.0,), ((2.0, 7.0),), (0,))),
 (((6, 7), (0.0,), ((2.0, 7.0),), (0,)),
  'Up',
  -1,
  ((6, 8), (0.0,), ((3.0, 7.0),), (0,))),
 (((6, 8), (0.0,), ((3.0, 7.0),), (0,)),
  'Up',
  -1,
  ((6, 9), (0.0,), ((3.0, 8.0),), (0,))),
 (((6, 9), (0.0,), ((3.0, 8.0),), (0,)),
  'Right',
  -1,
  ((7, 9), (0.0,), ((3.0, 9.0),), (0,))),
 (((7, 9), (0.0,), ((3.0, 9.0),), (0,)),
  'Left',
  -1,
  ((6, 9), (0.0,), ((4.0, 9.0),), (0,))),
 (((6, 9), (0.0,), ((4.0, 9.0),), (0,)),
  'Up',
  -1,
  ((6, 10), (0.0,), ((5.0, 9.0),), (0,))),
 (((6, 10), (0.0,), ((5.0, 9.0),), (0,)),
  'Up',
  -2,
  ((6, 10), (0.0,), ((5.0, 10.0),), (0,))),
 (((6, 10), (0.0,), ((5.0, 10.0),), (0,)),
  'Down',
  -1,
  ((6, 9), (0.0,), ((6.0, 10.0),), (0,))),
 (((6, 9), (0.0,), ((6.0, 10.0),), (0,)),
  'Left',
  -1,
  ((5, 9), (0.0,), ((6.0, 9.0),), (0,))),
 (((5, 9), (0.0,)