In [1]:
from pacman.actions import Actions
from pacman.agents import Agent
from pacman.env import PacmanEnv
import numpy as np
from collections import defaultdict

In [2]:
class SARSAgent(Agent):
    def __init__(self, alpha, gamma, env, eps):
        self.alpha  = alpha
        self.gamma  = gamma
        self.env    = env
        self.eps    = eps
        self.Qtable = defaultdict(lambda : {action: 0 for action in Actions.actions})


    def training(self, n_episodes=1000):
        for i in range(n_episodes):
            state, done = self.env.reset()
            action = self.act(state)

            while not done:
                next_state, reward, done, info = self.env.step(action)
                next_action = self.act(next_state)
                self.Qtable[state][action] += self.alpha * (reward + self.gamma * self.Qtable[next_state][next_action] - self.Qtable[state][action])

                state  = next_state
                action = next_action

            self.eps = 0.99 * self.eps


    def act(self, state):
        if np.random.rand() > self.eps:
            action = max(self.Qtable[state], key=self.Qtable[state].get)

        else:
            action = Actions.sample()

        return action


    def eval(self):
        self.eps = 0


env = PacmanEnv.from_file('smallClassic',render_mode=None, config={'TIME_PENALTY' : 1})
sarsa = SARSAgent(1, 0.99, env, 0.8)
sarsa.training()

{RandomGhost at (8, 5): 'Right', RandomGhost at (11, 5): 'Left'}
{RandomGhost at (9.0, 5.0): 'Down', RandomGhost at (10.0, 5.0): 'Down'}
{RandomGhost at (9.0, 4.0): 'Down', RandomGhost at (10.0, 4.0): 'Down'}
{RandomGhost at (9.0, 3.0): 'Right', RandomGhost at (10.0, 3.0): 'Right'}
{RandomGhost at (10.0, 3.0): 'Up', RandomGhost at (11.0, 3.0): 'Right'}
{RandomGhost at (10.0, 4.0): 'Up', RandomGhost at (12.0, 3.0): 'Right'}
{RandomGhost at (10.0, 5.0): 'Right', RandomGhost at (13.0, 3.0): 'Up'}
{RandomGhost at (11.0, 5.0): 'Left', RandomGhost at (13.0, 4.0): 'Up'}
{RandomGhost at (10.0, 5.0): 'Down', RandomGhost at (13.0, 5.0): 'Right'}
{RandomGhost at (10.0, 4.0): 'Down', RandomGhost at (14.0, 5.0): 'Down'}
{RandomGhost at (10.0, 3.0): 'Right', RandomGhost at (14.0, 4.0): 'Right'}
{RandomGhost at (11.0, 3.0): 'Right', RandomGhost at (15.0, 4.0): 'Up'}
{RandomGhost at (12.0, 3.0): 'Right', RandomGhost at (15.0, 5.0): 'Right'}
{RandomGhost at (13.0, 3.0): 'Up', RandomGhost at (16.0, 5.0)

Exception: Illegal ghost action Up

In [5]:
env.reset()

(((9, 1),
  (0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0,
   0.0),
  ((8, 5), (11, 5)),
  (0, 0),
  (0, 0)),
 False)

In [6]:
env

position             (9, 1)
collected food       (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
ghost positions      ((8, 5), (11, 5))
is scared            (0, 0)
collected capsules   (0, 0)

In [3]:
env.get_legal_actions((15, 3))

['Up', 'Down', 'Right']

In [3]:
sarsa.eval()
env.set_render('ansi')
env.run_policy(sarsa, 0, .5)

% % % % % % % % % % % % % % % % % % % %
% . . . . . . %         % . . . . [31m3[0m . %
% . % % . . . % %     % % . . . % % . %
% . % o . % . . . . . . . . % . o % . %
% . % % . % . % % % % % %   % . % % . %
% . . . . . . . .               [36mE[0m . . %
% % % % % % % % % % % % % % % % % % % %
Score: -45 Game Over!


[(((9, 1),
   (0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0),
   ((8, 5), (11, 5)),
   (0, 0),
   (0, 0)),
  'Right',
  0,
  ((10, 1),
   (0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    1.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.0,
    0.

In [None]:
a = set([(1,1)])

In [None]:
b = set([(1,1), (5,5)])

In [None]:
tuple(int(capsule not in a) for capsule in b)