In [1]:
import gymnasium as gym
import moving_firefighter_env as mfp
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from itertools import count

In [2]:
class GreeedyPolicy:
    """Greedy firefighter policy: defend the candidate node with the largest subtree.

    The policy is built once from an environment observation. It reconstructs
    two networkx views of the instance:

    * ``graph_burnt``: the graph the fire spreads over (nodes ``0..k-1`` joined
      by the observation's ``edge_links``);
    * ``graph_fighter``: a complete graph over the fighter's nodes, whose edges
      carry a ``"distance"`` attribute taken from the observation's edge weights.

    ``depth[v]`` is the number of nodes in the DFS subtree of ``v`` when
    ``graph_burnt`` is explored from the first burning node. On a tree this is
    presumably the number of nodes that defending ``v`` shields from the fire.

    NOTE: the class name keeps its original spelling so existing callers work.
    """

    def __init__(self, initial_state):
        """Build the graph views and subtree sizes from ``initial_state``.

        Args:
            initial_state: observation dict (as returned by ``env.reset()``).
                Reads ``"graph_burnt"`` and ``"graph_fighter"`` (objects exposing
                ``nodes``, ``edges`` and ``edge_links``) and ``"burnt_nodes"``.
        """
        graph_burnt_state = initial_state["graph_burnt"]
        graph_burnt = nx.Graph()
        graph_burnt.add_nodes_from(range(len(graph_burnt_state.nodes)))
        graph_burnt.add_edges_from(graph_burnt_state.edge_links)

        # Read from ``initial_state`` (not a notebook-global ``state``) so the
        # policy always reflects the observation it was actually given.
        graph_fighter_state = initial_state["graph_fighter"]
        graph_fighter = nx.complete_graph(len(graph_fighter_state.nodes))

        # Node features of the fighter graph are used as drawing positions.
        positions = dict(enumerate(graph_fighter_state.nodes))

        # ``.item()`` unwraps each weight into a plain Python scalar.
        distances = {
            tuple(edge): weight.item()
            for edge, weight in zip(graph_fighter_state.edge_links, graph_fighter_state.edges)
        }
        nx.set_edge_attributes(graph_fighter, distances, "distance")

        # Only the first burning node is used as the DFS root.
        root = initial_state["burnt_nodes"][0]

        self.graph_burnt = graph_burnt
        self.graph_fighter = graph_fighter
        self.positions = positions
        self.distances = distances
        self.depth = self._subtree_sizes(graph_burnt, root)
        self.root = root

    @staticmethod
    def _subtree_sizes(graph, root):
        """Return ``{node: size of its DFS subtree}`` for nodes reachable from ``root``.

        Iterative equivalent of a recursive DFS that marks a neighbour visited
        when it is first discovered. The explicit stack avoids Python's
        recursion limit on deep graphs, and the ``visited`` set keeps membership
        checks O(1).

        Args:
            graph: any object with a ``neighbors(node)`` method (e.g. ``nx.Graph``).
            root: node to start the search from.

        Returns:
            dict mapping each reachable node to the number of nodes in its DFS
            subtree (itself included). Unreachable nodes are absent.
        """
        sizes = {}
        visited = {root}
        # Each frame: [node, iterator over its neighbours, running subtree size].
        stack = [[root, iter(graph.neighbors(root)), 1]]
        while stack:
            frame = stack[-1]
            for neighbor in frame[1]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append([neighbor, iter(graph.neighbors(neighbor)), 1])
                    break
            else:
                # All neighbours handled: the subtree of frame[0] is complete,
                # so fold its size into the parent frame.
                stack.pop()
                sizes[frame[0]] = frame[2]
                if stack:
                    stack[-1][2] += frame[2]
        return sizes

    def next_action(self, valid_actions):
        """Pick the valid action whose node has the largest subtree size.

        A single valid action is returned unchanged. Otherwise the first entry
        is skipped. NOTE(review): presumably it is a non-defending move; confirm
        against the environment. Ties among the remaining actions are broken
        uniformly at random via ``np.random``.

        Args:
            valid_actions: sequence of node indices returned by the environment.

        Returns:
            The chosen action.

        Raises:
            ValueError: if ``valid_actions`` is empty.
        """
        if len(valid_actions) == 0:
            raise ValueError("next_action() needs at least one valid action")
        if len(valid_actions) == 1:
            return valid_actions[0]
        candidates = valid_actions[1:]
        max_depth = max(self.depth[action] for action in candidates)
        greedy_actions = [action for action in candidates if self.depth[action] == max_depth]
        return np.random.choice(greedy_actions)

    def draw(self):
        """Plot the fire's graph and the fighter's graph side by side.

        In the fighter view the last node gets a different colour.
        NOTE(review): presumably it is the firefighter itself; confirm.
        Node counts come from the graphs, so no notebook global is needed.
        """
        num_burnt = self.graph_burnt.number_of_nodes()
        num_fighter = self.graph_fighter.number_of_nodes()
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        nx.draw(self.graph_burnt, pos=self.positions, with_labels=True,
                node_color=[0] * num_burnt, ax=ax1)
        nx.draw(self.graph_fighter, pos=self.positions, with_labels=True,
                node_color=[0] * (num_fighter - 1) + [1], ax=ax2)
        ax1.set_title("Fire's graph view")
        ax2.set_title("Fighter's graph view")
        plt.show()

In [10]:
# Evaluate the greedy policy over several episodes and report the mean return.
experiments = 20  # number of episodes to average over
n = 200  # instance size passed to the environment
seed = 32  # seeds both the environment and the policy's tie-breaking

# GreeedyPolicy breaks ties with np.random; seed it so the average is reproducible.
np.random.seed(seed)

env = gym.make("mfp/MovingFirefighter-v0", n=n, num_fires=1, is_tree=True, time_slots=0.3, render_mode="human", seed=seed)
env.metadata["render_fps"] = 32

final_reward = 0  # sum of rewards over all episodes
try:
    for _ in range(experiments):
        state, _ = env.reset()
        policy = GreeedyPolicy(state)
        for _ in count():
            valid_actions = env.unwrapped.valid_actions()
            action = policy.next_action(valid_actions)
            observation, reward, terminated, truncated, info = env.step(action)
            final_reward += reward

            if terminated or truncated:
                break
finally:
    # Release the environment (and its render window) even if an episode fails.
    env.close()

print(f"Average reward: {final_reward / experiments}.")

Average reward: -50.95.
