In [78]:
import numpy as np

In [79]:
class Game():
    def __init__(self, grid, start_state=(0, 0)):
        self.grid = grid
        self.rows, self.cols = grid.shape
        self.state = start_state
        
    def get_reward(self):
        """
        return reward obtained in current state
        """
        return self.grid[self.state[0]][self.state[1]]
    
    def is_valid(self, state):
        """
        checks if this is a valid state(move)
        """
        if state[0] < 0 or state[0] >= self.rows:
            # outside row bound
            return False
        if state[1] < 0 or state[1] >= self.cols:
            # outside column bound
            return False
        if self.grid[state[0]][state[1]] == np.nan:
            # wall
            return False
        return True
    
    def next_state(self, action):
        """
        update state according to action taken
        """
        if action == "up":
            next_state = (self.state[0]-1, self.state[1])
        elif action == "down":
            next_state = (self.state[0]+1, self.state[1])
        elif action == "right":
            next_state = (self.state[0], self.state[1]+1)
        elif action == "left":
            next_state = (self.state[0], self.state[1]-1)
        else:
            raise ValueError("Invalid action!")
        if self.is_valid(next_state):
            return next_state
        else:
            return self.state
    
    def is_end(self):
        """
        Returns true if the game has come to an end
        """
        val = self.grid[self.state[0]][self.state[1]]
        return val == -1 or val == 1

In [81]:
class Agent():
    def __init__(self, game):
        self.game = game
        self.lr = 0.2
        self.explore_rate = 0.1
        self.actions = ["up", "down", "left", "right"]
        
        self.state_values = {}
        for i in range(self.game.rows):
            for j in range(self.game.cols):
                self.state_values[(i, j)] = 0
    
    def select_action(self):
        if np.random.uniform(0, 1) <= self.explore_rate:
            action = np.random.choice(self.actions)
        else:
            max_reward = -np.inf
            action = ""
            for a in self.actions:
                exp_reward = self.state_values[self.game.next_state(a)]
                if exp_reward >= max_reward:
                    max_reward = exp_reward
                    action = a
        return action
    
    def reset(self):
        self.game.state = (0, 0)
    
    
    def play_episode(self, max_iter=1000):
        self.reset()
        it = 0
        episode_states = []
        episode_states.append(self.game.state)
        while True:
            action = self.select_action()
            self.game.state = self.game.next_state(action)
            episode_states.append(self.game.state)
            
            if self.game.is_end():
                reward = self.game.get_reward()
                for state in episode_states:
                    self.state_values[state] += self.lr * (reward-self.state_values[state])
                return reward, len(episode_states)
    
    def play(self, episodes=10):
        for _ in range(episodes):
            print(self.play_episode())
        

In [83]:
grid = np.array([[0, 0, 0, 0], [0, np.nan, 0, 0], [0, 0, 0, -1], [0, 0, 0, 1]])
game = Game(grid)
agent = Agent(game)
agent.play(1000)

(-1.0, 67)
(-1.0, 12)
(1.0, 323)
(-1.0, 31667)
(1.0, 3650)
(1.0, 18289)
(1.0, 12379)
(1.0, 4638)
(1.0, 121)
(-1.0, 11532)
(-1.0, 317)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 8)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 948)
(-1.0, 8)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 6)
(-1.0, 8)
(-1.0, 6)
(-1.0, 6)
(1.0, 7)
(1.0, 10)
(1.0, 8)
(1.0, 452)
(1.0, 10858)
(1.0, 35200)
(1.0, 52310)
(-1.0, 919)
(1.0, 19261)
(1.0, 23704)
(1.0, 25364)
(1.0, 3168)
(-1.0, 13676)
(-1.0, 79)
(1.0, 634)
(1.0, 8461)
(1.0, 2894)
(1.0, 21410)
(1.0, 9923)
(1.0, 1737)
(1.0, 3810)
(-1.0, 4415)
(1.0, 1012)
(1.0, 2626)
(1.0, 3334)
(1.0, 6648)
(1.0, 99416)
(1.0, 1433)
(1.0, 19901)
(1.0, 1521)
(1.0, 573)
(1.0, 650)
(1.0, 246)
(1.0, 425)
(1.0, 54)
(1.0, 282)
(1.0, 833)
(1.0, 565)
(1.0, 140)
(1.0, 374)
(1.0, 88)
(1.0, 1004)
(1.0, 408)
(1.0, 54)
(1.0, 730)
(1.0, 1331)
(1.0, 10

(1.0, 857)
(1.0, 334)
(1.0, 51)
(1.0, 201)
(1.0, 402)
(1.0, 130)
(1.0, 94)
(1.0, 120)
(1.0, 200)
(1.0, 190)
(1.0, 190)
(1.0, 514)
(1.0, 207)
(1.0, 177)
(1.0, 178)
(1.0, 158)
(1.0, 340)
(1.0, 425)
(1.0, 68)
(1.0, 806)
(1.0, 648)
(1.0, 126)
(1.0, 107)
(1.0, 205)
(1.0, 379)
(1.0, 276)
(1.0, 73)
(1.0, 38)
(1.0, 244)
(1.0, 158)
(1.0, 614)
(1.0, 97)
(1.0, 78)
(1.0, 423)
(1.0, 116)
(1.0, 177)
(1.0, 367)
(1.0, 372)
(1.0, 63)
(1.0, 53)
(1.0, 267)
(1.0, 52)
(1.0, 293)
(1.0, 285)
(1.0, 67)
(1.0, 112)
(1.0, 179)
(1.0, 596)
(1.0, 843)
(1.0, 212)
(1.0, 228)
(1.0, 157)
(1.0, 58)
(1.0, 241)
(1.0, 811)
(1.0, 141)
(1.0, 230)
(1.0, 103)
(1.0, 369)
(1.0, 66)
(1.0, 236)
(1.0, 65)
(1.0, 479)
(1.0, 185)
(1.0, 208)
(1.0, 467)
(1.0, 67)
(1.0, 128)
(1.0, 167)
(1.0, 671)
(1.0, 126)
(1.0, 446)
(1.0, 116)
(1.0, 135)
(1.0, 212)
(1.0, 98)
(1.0, 320)
(1.0, 617)
(1.0, 324)
(1.0, 27)
(1.0, 706)
(1.0, 40)
(1.0, 271)
(1.0, 337)
(1.0, 108)
(1.0, 170)
(-1.0, 242)
(1.0, 7)
(1.0, 20)
(1.0, 29)
(1.0, 39)
(1.0, 36)
(1.0, 43)
(

In [17]:
5/np.inf

0.0