In [6]:
from enum import Enum
from envs.grid import Grid
from envs.constants import SQUARE_SIZE

# Fire Evacuation Planner MDP
This agent will implement a classic MDP with states, rewards, and transition models.
Extending the MDP to our use case could include:
 - Fire Spread algorithm:
   - Episode ends if stepping in fire state
   - Firefighter (MDP agent) receives a reward for steps that reach people who need to be rescued
   - Generate a more sophisticated environment based on the grid: walls, doors, and so on
   - A default per-step reward of around -0.04 could encourage efficiency
   - Pass arguments to the grid when defining the base environment (walls, starting fire, people)

## Compare Reinforcement Learning Methods (Q-learning, SARSA) to Classical Methods (Policy iteration, value iteration, linear programming)
The separate models aim to answer whether classical or RL-based methods are better suited to simulating a real-world fire hazard on a building floor.

# Possible challenges of a classical MDP implementation
Since we are dealing with a classical MDP, we need to make sure that all processes are Markovian: the next state and reward must depend only on the current state and the chosen action, not on the history of how the agent got there.

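If we encode the fire so that it spreads independently of the state the agent observes, the agent is no longer acting in an MDP. The same (position-only) state can lead to different outcomes depending on fire the agent cannot see.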

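 - One way to solve this is to include the fire status of every grid cell in the state. This quickly becomes a lot of states and calculations, even for a simple grid: with $n$ traversable cells there are up to $2^n$ fire configurations for every agent position.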

For small grids, such as a 3x4 example, this is already a challenge. For bigger ones, reinforcement learning almost certainly needs to be adopted to manage the changing environment.

In [7]:
from envs.grid import Grid
from envs.tiles.tile import Tile # Import Tile for type hinting if necessary

class FireEvacuationAgentMDP:
    """Grid-world MDP for a single agent moving through a (possibly burning) building floor.

    The full MDP state is a *flat* tuple ``(agent_x, agent_y, *fire_flags)``. Here
    ``fire_flags`` holds one ``is_on_fire`` value per traversable tile, scanned row by
    row (``y`` outer, ``x`` inner). Including the fire configuration in the state is what
    keeps the process Markovian, even though ``grid.update()`` evolves the fire between
    steps.

    Tiles are addressed as ``grid.tiles[x][y]``. The action ``'up'`` increases ``y`` and
    ``'down'`` decreases it.
    """

    # Action name -> (dx, dy) displacement applied to the agent's (x, y) position.
    _MOVES = {
        'up': (0, 1),
        'left': (-1, 0),
        'right': (1, 0),
        'down': (0, -1),
    }

    def __init__(self, start_state: tuple[int, int], grid: Grid):
        """Create the MDP on ``grid``.

        Args:
            start_state: Requested ``(x, y)`` start position. If it is not a traversable
                tile, the first traversable tile in scan order is used instead. If the
                grid has no traversable tiles at all, ``(0, 0)`` is used.
            grid: Environment grid. ``grid.size`` is used for both dimensions (assumes a
                square grid). Tiles are read via ``grid.tiles[x][y]``, and ``grid.update()``
                is called once per step.
        """
        self.grid = grid
        self.rows = grid.size
        self.cols = grid.size
        self.actions = ['up', 'left', 'right', 'down']

        self.possible_agent_positions = [(x, y) for x, y, _ in self._iter_traversable_tiles()]

        if start_state in self.possible_agent_positions:
            start = start_state
        elif self.possible_agent_positions:
            start = self.possible_agent_positions[0]
        else:
            start = (0, 0)  # Fallback: the grid has no traversable tile.
        self.start_state_position = self.current_state_position = start

        # Flat MDP state: (agent_x, agent_y, *fire_flags) — see class docstring.
        self.current_mdp_state = self._get_current_mdp_state()

    def _iter_traversable_tiles(self):
        """Yield ``(x, y, tile)`` for every traversable tile, with ``y`` outer and ``x`` inner."""
        for y in range(self.rows):
            for x in range(self.cols):
                tile = self.grid.tiles[x][y]
                if tile.is_traversable:
                    yield x, y, tile

    def _get_fire_config_tuple(self) -> tuple:
        """Return the ``is_on_fire`` value of every traversable tile, in scan order."""
        return tuple(tile.is_on_fire for _, _, tile in self._iter_traversable_tiles())

    def _get_current_mdp_state(self) -> tuple:
        """Return the flat MDP state ``(agent_x, agent_y, *fire_flags)``."""
        x, y = self.current_state_position
        return (x, y) + self._get_fire_config_tuple()

    def reset(self):
        """Move the agent back to its start position and return the recomputed MDP state.

        NOTE(review): only the agent position is reset. The grid (and therefore the fire)
        is left as-is, so the returned state reflects the grid's current fire configuration.
        TODO: also reset the grid once a Grid-level reset is available.
        """
        self.current_state_position = self.start_state_position
        self.current_mdp_state = self._get_current_mdp_state()
        return self.current_mdp_state

    def step(self, action: str) -> tuple:
        """Apply ``action``, advance the grid by one update, and report the outcome.

        A move that would leave the grid or enter a non-traversable tile keeps the agent
        in place. ``grid.update()`` runs after the agent moves, so the returned state
        contains the post-update fire configuration.

        Returns:
            ``(mdp_state, reward, is_terminal, info)``. Reward and termination are
            placeholders for now (always ``0`` and ``False``). ``info`` is an empty dict.

        Raises:
            ValueError: if ``action`` is not one of ``self.actions``.
        """
        if action not in self.actions:
            raise ValueError("Invalid action")

        dx, dy = self._MOVES[action]
        x, y = self.current_state_position
        nx, ny = x + dx, y + dy

        # Bumping into the grid edge or a non-traversable tile leaves the agent where it is.
        if (0 <= nx < self.cols and 0 <= ny < self.rows
                and self.grid.tiles[nx][ny].is_traversable):
            self.current_state_position = (nx, ny)

        self.grid.update()  # Fire spread/extinguish logic runs here.

        reward = 0  # TODO: define rewards (e.g. per-step cost, rescue bonus, fire penalty).
        is_terminal = False  # TODO: define terminal conditions (e.g. stepping into fire).

        self.current_mdp_state = self._get_current_mdp_state()
        return self.current_mdp_state, reward, is_terminal, {}  # info dict for additional details

    def get_possible_states(self):
        """Return the traversable agent positions ``(x, y)``.

        A full MDP would enumerate every (position, fire configuration) combination.
        That set can be extremely large, so for now only the agent positions are returned.
        """
        return self.possible_agent_positions

    def __str__(self):
        return f"Agent is at state {self.current_mdp_state}"


In [11]:
grid_instance = Grid(size=5, tile_size=SQUARE_SIZE)

# Default start position, used when the grid has no traversable tile at all.
initial_agent_pos = (0, 0)

# First traversable tile, in the same scan order the MDP uses (y outer, x inner).
actual_start_state = next(
    (
        (x, y)
        for y in range(grid_instance.size)
        for x in range(grid_instance.size)
        if grid_instance.tiles[x][y].is_traversable
    ),
    None,
)
found_traversable = actual_start_state is not None

if not found_traversable:
    print("Warning: No traversable states found in the grid!")
    # Keep `actual_start_state` defined so the MDP construction cell still runs
    # instead of failing with NameError.
    actual_start_state = initial_agent_pos

In [12]:
# Build the MDP on the grid created above, starting from the tile found in the previous cell.
mdp = FireEvacuationAgentMDP(grid=grid_instance, start_state=actual_start_state)

In [13]:
# The class defines __str__ but no __repr__, so print explicitly to get the readable form.
print(str(mdp))

Agent is at state (0, 0)


In [14]:
# Test moves
mdp.step('right')
print(mdp)
mdp.step('down')
print(mdp)
mdp.step('left')
print(mdp)
mdp.step('up')
print(mdp)

Agent is at state (1, 0)
Agent is at state (1, 0)
Agent is at state (0, 0)
Agent is at state (0, 1)
