In [None]:
# Install the project found in the current working directory in editable mode (-e),
# with reduced pip output (-q). `%pip` (rather than `!pip`) installs into the
# environment of the running kernel.
# NOTE(review): presumably this is what makes the `implementations` package
# importable in the next cell — confirm against the project's packaging config.
%pip install -q -e .

In [None]:
import nmmo
from implementations.train_ppo import train_ppo, EvaluationCallback
from implementations.Observations import Observations
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
class AnimationCallback(EvaluationCallback):
    """Render one agent's local view at every step and build a GIF per episode.

    Each call to ``step`` draws the tracked agent's observed tile grid, the
    other entities it observes, its health/food/water and its chosen action,
    and saves the figure as ``<image_dir>/step_<step>.png``. ``episode_end``
    then stitches the frames into ``<output_file>_<episode>.gif``.

    Frames of successive episodes share ``image_dir`` and overwrite each other,
    so the callback remembers which steps were actually rendered during the
    current episode; leftover frames from an earlier episode are never pulled
    into a later animation.

    NOTE(review): the base-class initializer is intentionally not invoked
    (unchanged from the previous implementation) — confirm that
    ``EvaluationCallback`` needs no setup of its own.
    """

    # Colours for the health buckets (>75, >50, >25, otherwise), best to worst.
    _AGENT_HEALTH_COLORS = ('lightgreen', 'yellowgreen', 'darkorange', 'darkred')
    _ENTITY_HEALTH_COLORS = ('green', 'yellow', 'orange', 'red')

    def __init__(self, agent_id: int, image_dir: str, output_file: str, quiet: bool = True):
        """Configure the callback.

        Args:
            agent_id: Id of the agent whose view is rendered.
            image_dir: Directory that receives the per-step PNG frames.
            output_file: Path prefix of the GIFs; ``_<episode>.gif`` is appended.
            quiet: When True, suppress all progress/diagnostic prints.
        """
        self.agent_id = agent_id
        self.image_dir = image_dir
        self.output_file = output_file
        self.quiet = quiet
        self._current_episode_steps = 0
        # Step indices passed to ``step`` during the current episode, in order.
        self._step_indices: list[int] = []
        # Subset of steps for which a frame was actually written this episode.
        self._rendered_steps: set[int] = set()

    def _log(self, message: str) -> None:
        """Print ``message`` unless the callback was created with ``quiet=True``."""
        if not self.quiet:
            print(message)

    def _frame_path(self, step: int) -> str:
        """Return the PNG path used for the frame of ``step``."""
        return os.path.join(self.image_dir, f"step_{step}.png")

    @staticmethod
    def _health_color(health, palette) -> str:
        """Map a health value onto one of four colours (>75, >50, >25, else)."""
        high, medium, low, critical = palette
        if health > 75:
            return high
        if health > 50:
            return medium
        if health > 25:
            return low
        return critical

    def create_animation(self, output_file: str, fps: float = 2) -> None:
        """Assemble the frames of the current episode into an animated GIF.

        One GIF frame is emitted per recorded step. Steps without a rendered
        image (e.g. the tracked agent was absent from the observations) repeat
        the previously shown image instead of loading a stale file from disk.

        Args:
            output_file: Destination path of the GIF.
            fps: Frames per second of the resulting animation.
        """
        # Prefer the steps seen via ``step``; fall back to rendered frames when
        # ``plot_agent_view`` was driven directly.
        frame_steps = list(self._step_indices) or sorted(self._rendered_steps)
        if not frame_steps:
            # The Pillow writer cannot save an animation with zero frames.
            self._log(f"No frames recorded; skipping animation {output_file}")
            return

        # ``dirname`` is '' for a bare file name, which ``os.makedirs`` rejects.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            def update(step):
                filename = self._frame_path(step)
                if step not in self._rendered_steps:
                    # Any file at this path belongs to an earlier episode.
                    self._log(f"No frame rendered for step {step} in this episode")
                    return
                if not os.path.exists(filename):
                    self._log(f"File {filename} does not exist")
                    return
                img = plt.imread(filename)
                ax.clear()
                ax.imshow(img)
                ax.axis('off')

            ani = animation.FuncAnimation(fig, update, frames=frame_steps, repeat=False)
            ani.save(output_file, writer='pillow', fps=fps)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        self._log(f"Saved animation to {output_file}")

    def plot_agent_view(self, obs: dict[int, Observations], env_actions: dict[int, dict[str, dict[str, int]]], agent_id: int, step: int) -> None:
        """Draw the view of ``agent_id`` at ``step`` and save it as a PNG frame.

        Nothing is drawn when ``agent_id`` has no entry in ``obs``.

        Args:
            obs: Observations keyed by agent id.
            env_actions: Chosen actions keyed by agent id, then action type,
                then argument name.
            agent_id: Agent whose view is rendered.
            step: Step index, used in the title and the frame file name.
        """
        os.makedirs(self.image_dir, exist_ok=True)

        agent_observations = obs.get(agent_id, None)
        if agent_observations is None:
            self._log(f"Agent {agent_id} not found in observations at step {step}")
            return

        tiles = agent_observations.tiles
        tile_rows, tile_cols, tile_values = tiles[:, 0], tiles[:, 1], tiles[:, 2]

        # NOTE(review): index 0 of the entity arrays is assumed to describe the
        # observing agent itself — confirm against the Observations layout.
        entities = agent_observations.entities
        agent_health = entities.health[0]
        agent_food = entities.food[0]
        agent_water = entities.water[0]

        # Build a dense grid covering the bounding box of the observed tiles.
        min_row, max_row = tile_rows.min(), tile_rows.max()
        min_col, max_col = tile_cols.min(), tile_cols.max()
        grid_rows, grid_cols = max_row - min_row + 1, max_col - min_col + 1
        grid = np.zeros((grid_rows, grid_cols))
        grid[tile_rows - min_row, tile_cols - min_col] = tile_values

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            ax.imshow(grid, cmap='coolwarm', interpolation='nearest', alpha=0.8)

            # imshow puts columns on the x axis and rows on the y axis.
            agent_row_offset = entities.row[0] - min_row
            agent_col_offset = entities.col[0] - min_col
            ax.scatter(
                agent_col_offset, agent_row_offset,
                c=self._health_color(agent_health, self._AGENT_HEALTH_COLORS),
                s=100, label=f'Agent {agent_id}', edgecolors='black'
            )

            agent_stats = f"Health: {agent_health}\nFood: {agent_food}\nWater: {agent_water}"

            action = env_actions.get(agent_id, None)
            action_details = "No action"

            if action:
                move_direction = action.get('Move', {}).get('Direction', None)
                attack_style = action.get('Attack', {}).get('Style', None)
                attack_target = action.get('Attack', {}).get('Target', None)
                use_item = action.get('Use', {}).get('InventoryItem', None)
                destroy_item = action.get('Destroy', {}).get('InventoryItem', None)

                detail_lines = []
                if move_direction is not None:
                    detail_lines.append(f"Move: Direction {move_direction}\n")
                if attack_style is not None and attack_target is not None:
                    detail_lines.append(f"Attack: Style {attack_style}, Target {attack_target}\n")
                if use_item is not None:
                    detail_lines.append(f"Use: Item {use_item}\n")
                if destroy_item is not None:
                    detail_lines.append(f"Destroy: Item {destroy_item}\n")
                action_details = "".join(detail_lines)

            ax.text(0.05, 0.05, f"Action:\n{action_details}", transform=ax.transAxes, fontsize=10,
                    verticalalignment='bottom', horizontalalignment='left', color='black',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'))

            for idx, entity_id in enumerate(entities.id):
                # Skip the tracked agent and entries with id 0.
                # NOTE(review): id 0 presumably marks an empty/padding slot — confirm.
                if entity_id == agent_id or entity_id == 0:
                    continue

                ax.scatter(
                    entities.col[idx] - min_col, entities.row[idx] - min_row,
                    c=self._health_color(entities.health[idx], self._ENTITY_HEALTH_COLORS),
                    s=50, alpha=0.7, label=f'Entity {entity_id}', edgecolors='black'
                )

            # Label the axes with the tile coordinates rather than grid offsets.
            ax.set_xticks(np.arange(grid_cols))
            ax.set_yticks(np.arange(grid_rows))
            ax.set_xticklabels(np.arange(min_col, max_col + 1))
            ax.set_yticklabels(np.arange(min_row, max_row + 1))

            ax.set_title(f"Agent {agent_id}'s View at Step {step}")
            ax.text(0.95, 0.95, agent_stats, transform=ax.transAxes, fontsize=12,
                    verticalalignment='top', horizontalalignment='right', color='black',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'))
            ax.legend(loc='upper left')
            fig.tight_layout()

            output_file = self._frame_path(step)
            fig.savefig(output_file)
        finally:
            # Close this specific figure even when drawing or saving fails.
            plt.close(fig)

        # Remember that this step has a fresh frame for the current episode.
        self._rendered_steps.add(step)
        self._log(f"Saved agent view to {output_file}")

    def step(
        self,
        observations_per_agent: dict[int, Observations],
        actions_per_agent: dict[int, dict[str, dict[str, int]]],
        episode: int,
        step: int) -> None:
        """Render the tracked agent's view for this step and record the step."""
        self.plot_agent_view(observations_per_agent, actions_per_agent, self.agent_id, step)
        self._step_indices.append(step)
        self._current_episode_steps += 1

    def episode_start(self, episode: int) -> None:
        """Reset the per-episode frame bookkeeping."""
        self._current_episode_steps = 0
        self._step_indices = []
        self._rendered_steps = set()

    def episode_end(self, episode: int, rewards_per_agent: dict[int, float]) -> None:
        """Write the GIF of the finished episode to ``<output_file>_<episode>.gif``."""
        self.create_animation(f"{self.output_file}_{episode}.gif", fps=2)

In [None]:
# Train a PPO agent for 10 episodes in an NMMO environment created with no arguments.
# The AnimationCallback tracks agent 1: per-step frames are written to "views/frames"
# and one GIF per episode to "views/agent_views_episode_<episode>.gif".
# NOTE(review): save_every=1 presumably checkpoints the agent after every episode
# under a name derived from agent_name — confirm in train_ppo.
env = nmmo.Env()
train_ppo(env, 
          episodes=10, 
          save_every=1, 
          agent_name="test_agent", 
          callbacks=[AnimationCallback(1, "views/frames", "views/agent_views_episode", quiet=True)])

In [None]:
# NOTE(review): this import would normally live in the notebook's top import cell
# next to the other `implementations.train_ppo` imports.
from implementations.train_ppo import evaluate_loaded_agent


# Evaluate a previously saved agent for 3 episodes in a fresh environment.
# Frames and GIFs go to "views_2/..." so the training cell's output is not overwritten.
# NOTE(review): "test_agent_at_ep10" presumably names the checkpoint written by the
# training cell after its 10th episode — confirm the naming scheme in train_ppo.
env = nmmo.Env()
evaluate_loaded_agent(env, 
                      episodes=3, 
                      agent_name="test_agent_at_ep10", 
                      callbacks=[AnimationCallback(1, "views_2/frames", "views_2/agent_views_episode", quiet=True)])