#### Defining the environment

In [None]:
class Maze_env:
    """Grid-maze environment for a mouse-reaches-the-cheese RL task.

    The maze is a 2-D numpy array indexed as ``maze[row, col]``. Cells equal
    to ``0.0`` are walls (never entered); cells equal to ``1.0`` are free and
    are the candidates for random start positions. The agent state is a
    ``(row, col, mode)`` tuple where ``mode`` is one of ``'start'``,
    ``'valid'``, ``'invalid'`` or ``'blocked'``.

    Actions are integers: LEFT = 0, UP = 1, RIGHT = 2, DOWN = 3.
    """

    LEFT, UP, RIGHT, DOWN = 0, 1, 2, 3
    # (row delta, col delta) applied to the agent position for each action.
    _MOVES = {LEFT: (0, -1), UP: (-1, 0), RIGHT: (0, 1), DOWN: (1, 0)}

    def __init__(self, cols, rows, mod, min_reward=-200):
        """Build a maze with ``gen_maze`` and place the agent at its start.

        Args:
            cols: number of maze columns.
            rows: number of maze rows.
            mod: forwarded unchanged to ``gen_maze`` (meaning defined there).
            min_reward: cumulative-reward floor; once ``total_reward`` drops
                below it the episode is lost. It is also the reward for a
                step taken from a cell with no legal moves.
        """
        maze, mouse, cheese = gen_maze(cols, rows, mod)
        self.cols = cols
        self.rows = rows
        # Keep the generated maze pristine; the working copy may be replaced
        # or edited without affecting what reset() restores.
        self.init_maze = maze
        self.maze = np.copy(maze)
        # Start position, used as (row, col).
        self.init_pos = mouse
        self.state = (mouse[0], mouse[1], 'start')
        # Goal position, compared as (row, col).
        self.goal = cheese
        self.free_cells = [
            (r, c)
            for r in range(rows)
            for c in range(cols)
            if self.maze[r, c] == 1.0
        ]

        self.min_reward = min_reward
        self.total_reward = 0
        self.visited = set()

    def reset(self):
        """Restore the initial maze and begin a new episode.

        With probability 0.9 the agent starts at the original start position;
        otherwise it starts on a uniformly chosen free cell.
        """
        self.maze = np.copy(self.init_maze)
        if np.random.random() < 0.9:
            row, col = self.init_pos[0], self.init_pos[1]
        else:
            row, col = self.free_cells[np.random.choice(len(self.free_cells))]
        self.state = (row, col, 'start')
        self.total_reward = 0
        self.visited = set()

    def act(self, action):
        """Apply ``action`` and return ``(envstate, reward, status)``.

        ``envstate`` is the flattened observation from ``observe``; ``status``
        is ``1`` (goal reached), ``-1`` (episode lost) or ``0`` (in progress).
        """
        self.update_state(action)
        reward = self.get_reward()
        self.total_reward += reward
        status = self.game_status()
        envstate = self.observe(empty=True)
        return envstate, reward, status

    def update_state(self, action):
        """Move the agent if ``action`` is legal and record the move's mode.

        Mode becomes ``'blocked'`` when the current cell has no legal actions,
        ``'valid'`` when ``action`` is legal (the departed cell is added to
        ``visited``), and ``'invalid'`` otherwise (position unchanged).
        """
        row, col, _ = self.state
        legal = self.valid_actions()
        if not legal:
            mode = 'blocked'
        elif action in legal:
            mode = 'valid'
            self.visited.add((row, col))
            d_row, d_col = self._MOVES[action]
            row += d_row
            col += d_col
        else:
            mode = 'invalid'
        self.state = (row, col, mode)

    def get_reward(self):
        """Return the reward for the current state.

        Goal cell: 1.0. A step with no legal moves ('blocked'): ``min_reward``.
        Standing on a previously visited cell: -0.25 (checked before
        'invalid', so an illegal move from a visited cell also scores -0.25).
        Illegal move: -0.75. Ordinary legal move: -0.04. Any other case (mode
        'start' on an unvisited, non-goal cell) yields ``None``.
        """
        row, col, mode = self.state
        if row == self.goal[0] and col == self.goal[1]:
            return 1.0
        if mode == 'blocked':
            return self.min_reward
        if (row, col) in self.visited:
            return -0.25
        if mode == 'invalid':
            return -0.75
        if mode == 'valid':
            return -0.04

    def game_status(self):
        """Return ``-1`` if the episode is lost, ``1`` if won, else ``0``.

        The episode is lost when the cumulative reward falls below
        ``min_reward``, or when the agent sits in a cell with no legal moves.
        In the second case it can never leave, and the blocked-step reward
        alone may leave the total exactly at ``min_reward``, which the strict
        ``<`` test does not catch.
        """
        if self.total_reward < self.min_reward:
            return -1
        row, col, mode = self.state
        if row == self.goal[0] and col == self.goal[1]:
            return 1
        if mode == 'blocked':
            return -1
        return 0

    def observe(self, empty=True):
        """Return the rendered maze (see ``draw_env``) flattened to 1-D."""
        canvas = self.draw_env(empty)
        # reshape(-1) always yields a 1-D vector, even for a 1-cell maze.
        return canvas.reshape(-1)

    def draw_env(self, empty=False):
        """Return a rendered copy of the maze; ``self.maze`` is not modified.

        With ``empty=True`` every positive cell is set to 1.0 and the agent's
        cell is marked 0.5. Otherwise visited cells are marked 0.18 (the
        agent's current cell is not marked).
        """
        canvas = np.copy(self.maze)
        if empty:
            canvas[canvas > 0.0] = 1.0
            canvas[self.state[0], self.state[1]] = 0.5
        else:
            for cell in self.visited:
                canvas[cell] = 0.18
        return canvas

    def valid_actions(self, cell=None):
        """Return the legal actions from ``cell`` (default: agent position).

        An action is legal when the neighbouring cell lies inside the grid and
        is not a wall (value ``0.0``). Actions are listed in ascending order
        (LEFT, UP, RIGHT, DOWN). Bounds are checked independently per
        direction so single-row or single-column mazes are handled correctly.
        """
        if cell is None:
            row, col = self.state[0], self.state[1]
        else:
            row, col = cell
        actions = []
        if col > 0 and self.maze[row, col - 1] != 0.0:
            actions.append(self.LEFT)
        if row > 0 and self.maze[row - 1, col] != 0.0:
            actions.append(self.UP)
        if col < self.cols - 1 and self.maze[row, col + 1] != 0.0:
            actions.append(self.RIGHT)
        if row < self.rows - 1 and self.maze[row + 1, col] != 0.0:
            actions.append(self.DOWN)
        return actions