In [1]:
from pacman.actions import Actions
from pacman.agents import Agent
from pacman.env import PacmanEnv
import numpy as np
from collections import defaultdict

SARSA(λ): on-policy TD control with eligibility traces

In [6]:
class SARSAgent(Agent):
    """Tabular SARSA(lambda) agent with accumulating eligibility traces.

    Action values are kept in ``Qtable``, a mapping ``state -> {action: value}``
    whose rows are created on first access with every action in
    ``Actions.actions`` set to 0. States are used as dictionary keys, so they
    must be hashable.
    """

    def __init__(self, env, alpha, gamma , epsilon, lmbda):
        """Store the environment and the learning hyper-parameters.

        Args:
            env: Environment providing ``reset()`` -> ``(state, done)`` and
                ``step(action)`` -> ``(next_state, reward, done, info)``.
            alpha: Step size applied to every TD update.
            gamma: Discount factor.
            epsilon: Initial probability of picking an exploratory action.
            lmbda: Trace-decay parameter (lambda).
        """
        self.alpha   = alpha
        self.gamma   = gamma
        self.env     = env
        self.epsilon = epsilon
        self.Qtable  = defaultdict(self._zero_action_values)
        self.lmbda   = lmbda

    @staticmethod
    def _zero_action_values():
        """Return a fresh ``{action: 0}`` row covering every known action."""
        return {action: 0 for action in Actions.actions}

    def train(self, max_episodes=None, tol=1e-7, epsilon_decay=0.99):
        """Run SARSA(lambda) episodes until convergence or the episode budget.

        Training stops when the largest absolute TD error observed during an
        episode falls below ``tol``, or after ``max_episodes`` episodes,
        whichever comes first.

        Args:
            max_episodes: Maximum number of episodes; ``None`` means no limit.
            tol: Convergence threshold on the per-episode maximum |TD error|.
            epsilon_decay: Factor ``epsilon`` is multiplied by after each
                episode.
        """
        if max_episodes is None:
            max_episodes = np.inf

        curr_episode = 0
        while curr_episode < max_episodes:
            # Count the episode up front so the budget is actually enforced
            # (the counter was previously never advanced).
            curr_episode += 1
            errors = []

            # Eligibility traces are episodic: start every episode from zero.
            Etable = defaultdict(self._zero_action_values)

            state, done = self.env.reset()
            action = self.act(state)

            while not done:
                # Accumulating trace for the pair that is about to be taken.
                Etable[state][action] = Etable[state][action] + 1

                next_state, reward, done, _ = self.env.step(action)
                next_action = self.act(next_state)

                td_error = (
                    reward
                    + self.gamma * self.Qtable[next_state][next_action]
                    - self.Qtable[state][action]
                )
                errors.append(abs(td_error))

                self.update_tables(Etable, td_error)

                state  = next_state
                action = next_action

            self.epsilon = epsilon_decay * self.epsilon

            # An episode with zero steps yields no TD errors; it carries no
            # convergence information, so skip the check instead of calling
            # max() on an empty list.
            if errors and max(errors) < tol:
                break

    def act(self, state):
        """Pick an action for ``state`` using an epsilon-greedy rule.

        With probability ``epsilon`` the action comes from ``Actions.sample()``;
        otherwise the action with the highest stored value is returned (ties
        go to the first such action in the row's iteration order).
        """
        if np.random.rand() > self.epsilon:
            action_values = self.Qtable[state]
            return max(action_values, key=action_values.get)

        return Actions.sample()

    def eval(self):
        """Disable exploration by setting ``epsilon`` to 0."""
        self.epsilon = 0

    def update_tables(self, Etable, td_error):
        """Apply one TD update to every traced state and decay the traces.

        Args:
            Etable: Eligibility traces, ``state -> {action: trace}``; modified
                in place (each trace is scaled by ``gamma * lmbda``).
            td_error: TD error of the transition that was just observed.
        """
        step  = td_error * self.alpha
        decay = self.gamma * self.lmbda

        for state, traces in Etable.items():
            action_values = self.Qtable[state]
            for action in Actions.actions:
                action_values[action] += step * traces[action]
                traces[action] *= decay



            


# Build the environment, then create and train the SARSA(lambda) agent on it.
env = PacmanEnv.contourDanger(
    7,
    ghost_name="FollowGhost",
    render_mode=None,
    config={"TIME_PENALTY": 1},
)
sarsa = SARSAgent(env=env, alpha=0.2, gamma=0.99, epsilon=1, lmbda=0.2)
sarsa.train()

In [7]:
# Turn exploration off before rolling out the learned policy
# (SARSAgent.eval() sets epsilon to 0).
sarsa.eval()
# NOTE(review): 'ansi' presumably selects text rendering — confirm against PacmanEnv.set_render.
env.set_render('ansi')
# NOTE(review): the meaning of the positional arguments (0, 1) is not visible
# in this notebook — confirm against PacmanEnv.run_policy.
env.run_policy(sarsa, 0, 1)

% % % % % % % % %
% [33m<[0m [31mE[0m           %
%               %
%               %
%               %
%               %
%               %
%               %
% % % % % % % % %
Score: 19


[(((4, 4), (0.0,), ((2, 6),), (0,)),
  'Right',
  -1,
  ((5, 4), (0.0,), ((2.0, 5.0),), (0,))),
 (((5, 4), (0.0,), ((2.0, 5.0),), (0,)),
  'Up',
  -1,
  ((5, 5), (0.0,), ((2.0, 4.0),), (0,))),
 (((5, 5), (0.0,), ((2.0, 4.0),), (0,)),
  'Up',
  -1,
  ((5, 6), (0.0,), ((2.0, 5.0),), (0,))),
 (((5, 6), (0.0,), ((2.0, 5.0),), (0,)),
  'Down',
  -1,
  ((5, 5), (0.0,), ((2.0, 6.0),), (0,))),
 (((5, 5), (0.0,), ((2.0, 6.0),), (0,)),
  'Left',
  -1,
  ((4, 5), (0.0,), ((2.0, 5.0),), (0,))),
 (((4, 5), (0.0,), ((2.0, 5.0),), (0,)),
  'Up',
  -1,
  ((4, 6), (0.0,), ((3.0, 5.0),), (0,))),
 (((4, 6), (0.0,), ((3.0, 5.0),), (0,)),
  'Up',
  -1,
  ((4, 7), (0.0,), ((3.0, 6.0),), (0,))),
 (((4, 7), (0.0,), ((3.0, 6.0),), (0,)),
  'Up',
  -2,
  ((4, 7), (0.0,), ((3.0, 7.0),), (0,))),
 (((4, 7), (0.0,), ((3.0, 7.0),), (0,)),
  'Left',
  -21,
  ((3, 7), (0.0,), ((4.0, 7.0),), (0,))),
 (((3, 7), (0.0,), ((4.0, 7.0),), (0,)),
  'Left',
  -1,
  ((2, 7), (0.0,), ((3.0, 7.0),), (0,))),
 (((2, 7), (0.0,), ((3