In [None]:
%pip install -q -e .

In [None]:
import nmmo
from implementations.train_ppo import train_ppo, EvaluationCallback, evaluate_agent
from implementations.PpoAgent import PPOAgent
from implementations.SimplierInputAgent import SimplierInputAgent
from implementations.RandomAgent import get_avg_lifetime_for_random_agent, get_avg_reward_for_random_agent
from implementations.Observations import Observations
from implementations.CustomRewardBase import LifetimeReward, ResourcesReward, CustomRewardBase, ResourcesAndGatheringReward
from implementations.SavingCallback import SavingCallback
from implementations.observations_to_inputs import observations_to_inputs_simplier
from implementations.jar import Jar
from implementations.ActionData import ActionData
import torch
import os
import shutil
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from nmmo import config

In [None]:
class AnimationCallback(EvaluationCallback):
    """Evaluation callback that draws one agent's visible map window per step and builds a GIF per episode.

    ``step`` saves a PNG frame named ``step_<step>.png`` under ``plots/frames``;
    ``episode_end`` stitches the frames into
    ``plots/animations/<output_name>_<episode>_<unix time>.gif`` and then removes
    the frame directory.
    """

    def __init__(self, agent_id: int, output_name: str, quiet: bool = True):
        """Store the tracked agent, the output prefix and the tile colour palette.

        Args:
            agent_id: Id of the agent whose view is rendered.
            output_name: Prefix used for the GIF file name.
            quiet: If True, no progress or warning messages are printed.
        """
        self.agent_id = agent_id
        self.output_name = output_name
        self.quiet = quiet
        self._plots_dir = "plots"
        self._image_dir = f"{self._plots_dir}/frames"
        # Number of `step` calls seen in the current episode; used as the GIF frame count.
        self._current_episode_steps = 0

        # Material texture name -> hex colour, used for both the map tiles and the legend.
        self.tile_color_map = {
            'void': '#000000',      # Black
            'water': '#4169E1',     # Royal blue
            'grass': '#7CBA3B',     # Yellow-green
            'stone': '#808080',     # Gray

            'ore': '#8B4513',       # Saddle brown
            'slag': '#A0522D',      # Sienna

            'herb': '#00FF7F',      # Spring green
            'weeds': '#556B2F',     # Dark olive green

            'foilage': '#90EE90',   # Light green
            'scrub': '#3CB371',     # Medium sea green

            'crystal': '#B8860B',   # Dark goldenrod
            'fragment': '#DEB887',  # Burlywood

            'tree': '#228B22',      # Forest green
            'stump': '#8B4513',     # Saddle brown

            'fish': '#87CEEB',      # Sky blue
            'ocean': '#000080',     # Navy blue
        }

    def create_animation(self, output_file: str, fps: float = 2) -> None:
        """Assemble the saved step frames into an animated GIF.

        Frames ``step_0.png`` .. ``step_{n-1}.png`` are read from the frame directory,
        where ``n`` is the number of ``step`` calls recorded for the current episode
        (assumes the ``step`` index passed to ``step`` counts up from 0 - TODO confirm
        against the evaluation loop). Missing frames are skipped, leaving the previous
        image on screen.

        Args:
            output_file: Path of the GIF to write; its parent directory is created if needed.
            fps: Frames per second of the resulting animation.
        """
        # With no recorded steps there is nothing to encode; skip instead of handing
        # the writer an empty frame sequence.
        if self._current_episode_steps <= 0:
            if not self.quiet:
                print(f"No frames recorded, skipping animation {output_file}")
            return

        # A bare file name has an empty dirname, which os.makedirs rejects.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        fig, ax = plt.subplots(figsize=(10, 8))

        def update(step):
            filename = os.path.join(self._image_dir, f"step_{step}.png")
            if os.path.exists(filename):
                img = plt.imread(filename)
                ax.clear()
                ax.imshow(img)
                ax.axis('off')
            else:
                if not self.quiet:
                    print(f"File {filename} does not exist")

        try:
            ani = animation.FuncAnimation(fig, update, frames=self._current_episode_steps, repeat=False)
            ani.save(output_file, writer='pillow', fps=fps)
        finally:
            # Always release the figure, even when encoding fails.
            plt.close(fig)
        if not self.quiet:
            print(f"Saved animation to {output_file}")

    def _get_color_grid(
        self,
        tile_rows: np.ndarray,
        tile_cols: np.ndarray,
        tile_values: np.ndarray,
        min_row: int,
        min_col: int,
        grid_rows: int,
        grid_cols: int
    ) -> np.ndarray:
        """Build a ``(grid_rows, grid_cols, 3)`` RGB image of the observed tiles.

        Args:
            tile_rows: Absolute row of every observed tile.
            tile_cols: Absolute column of every observed tile.
            tile_values: Material index of every observed tile.
            min_row: Smallest observed row; mapped to image row 0.
            min_col: Smallest observed column; mapped to image column 0.
            grid_rows: Height of the image in tiles.
            grid_cols: Width of the image in tiles.

        Returns:
            Float array with RGB components in [0, 1]; cells without a tile stay black.

        Raises:
            KeyError: If a material index or its texture name has no colour mapping.
        """
        def hex_to_rgb(hex_color):
            hex_color = hex_color.lstrip('#')
            return [int(hex_color[i:i+2], 16)/255.0 for i in (0, 2, 4)]

        index_to_material = {m.index: m.tex for m in nmmo.material.All.materials}
        color_grid = np.zeros((grid_rows, grid_cols, 3))
        for r, c, v in zip(tile_rows, tile_cols, tile_values):
            color_grid[r - min_row, c - min_col] = hex_to_rgb(self.tile_color_map[index_to_material[v]])
        return color_grid

    def _get_action_text(self, action: dict[str, dict[str, int]]) -> str:
        """Format an action dictionary as one line per recognised action type.

        Recognised entries are 'Move', 'Attack', 'Use' and 'Destroy'. Returns
        "No action" when `action` is empty/None or contains none of them.
        """
        if not action:
            return "No action"

        action_parts = []
        if 'Move' in action and 'Direction' in action['Move']:
            action_parts.append(f"Move: Direction {action['Move']['Direction']}")
        if 'Attack' in action:
            attack = action['Attack']
            if 'Style' in attack and 'Target' in attack:
                action_parts.append(f"Attack: Style {attack['Style']}, Target {attack['Target']}")
        if 'Use' in action and 'InventoryItem' in action['Use']:
            action_parts.append(f"Use: Item {action['Use']['InventoryItem']}")
        if 'Destroy' in action and 'InventoryItem' in action['Destroy']:
            action_parts.append(f"Destroy: Item {action['Destroy']['InventoryItem']}")

        return '\n'.join(action_parts) if action_parts else "No action"

    def _plot_entities(self, ax, agent_observations: Observations, min_row: int, min_col: int):
        """Scatter every other visible entity on `ax`, coloured by its health.

        The tracked agent itself and entries with id 0 (presumably padding rows -
        TODO confirm) are skipped.
        """
        def get_health_color(health):
            if health > 75: return 'green'
            elif health > 50: return 'yellow'
            elif health > 25: return 'orange'
            return 'red'

        for idx, entity_id in enumerate(agent_observations.entities.id):
            if entity_id == self.agent_id or entity_id == 0:
                continue

            local_x = agent_observations.entities.row[idx] - min_row
            local_y = agent_observations.entities.col[idx] - min_col
            health = agent_observations.entities.health[idx]
            # imshow puts columns on the x axis and rows on the y axis.
            ax.scatter(local_y, local_x, c=get_health_color(health), s=50,
                      alpha=0.7, label=f'Entity {entity_id}', edgecolors='black')

    def _plot_tile_legend(self, ax):
        """Draw one coloured swatch with the material name for every entry of the palette."""
        ax.set_xlim(0, 1)
        ax.set_ylim(0, len(self.tile_color_map))
        ax.invert_yaxis()
        ax.axis('off')
        items = list(self.tile_color_map.items())
        for idx, (mat, color) in enumerate(items):
            rect = plt.Rectangle((0, idx), 0.5, 0.5, facecolor=color, edgecolor='black', alpha=0.8)
            ax.add_patch(rect)
            # Dark swatches need light text to stay readable.
            text_color = 'white' if mat in ['ocean', 'void'] else 'black'
            ax.text(0.25, idx + 0.25, mat, va='center', ha='center', fontsize=12, family='monospace', color=text_color)

    def plot_agent_view(
        self,
        obs: dict[int, Observations],
        env_actions: dict[int, dict[str, dict[str, int]]],
        agent_id: int,
        step: int
    ) -> None:
        """Render the map window seen by `agent_id` and save it as ``step_<step>.png``.

        The frame shows the tiles, the agent, other visible entities, the chosen action
        and the agent's health/food/water. Nothing is drawn when the agent has no
        observation for this step or does not appear in its own entity table.

        Args:
            obs: Observations of every agent for this step.
            env_actions: Action dictionaries of every agent for this step.
            agent_id: Agent whose view is rendered.
            step: Step index, used in the frame file name and the on-image label.
        """
        os.makedirs(self._image_dir, exist_ok=True)

        agent_observations = obs.get(agent_id)
        if agent_observations is None or not isinstance(agent_observations, Observations):
            if not self.quiet:
                print(f"Agent {agent_id} not found in observations at step {step}")
            return

        # Resolve the agent's own entity row before any figure exists, so a missing row
        # neither raises IndexError nor leaks an open figure.
        agent_rows = np.where(agent_observations.entities.id == agent_id)[0]
        if len(agent_rows) == 0:
            if not self.quiet:
                print(f"Agent {agent_id} not found among visible entities at step {step}")
            return
        agent_idx = agent_rows[0]

        # Extract tile information (columns 0, 1, 2 are row, column and material index)
        tiles = agent_observations.tiles
        tile_rows, tile_cols, tile_values = tiles[:, 0], tiles[:, 1], tiles[:, 2]
        min_row, max_row = tile_rows.min(), tile_rows.max()
        min_col, max_col = tile_cols.min(), tile_cols.max()
        grid_rows, grid_cols = max_row - min_row + 1, max_col - min_col + 1

        # Create plot
        fig, (ax, ax_legend) = plt.subplots(1, 2, figsize=(10, 8), gridspec_kw={'width_ratios': [10, 2]})
        try:
            self._plot_tile_legend(ax_legend)

            color_grid = self._get_color_grid(tile_rows, tile_cols, tile_values, min_row, min_col, grid_rows, grid_cols)
            ax.imshow(color_grid, interpolation='nearest', alpha=0.8)

            # Plot agent
            center_x = agent_observations.entities.row[agent_idx] - min_row
            center_y = agent_observations.entities.col[agent_idx] - min_col
            agent_health = agent_observations.entities.health[agent_idx]
            agent_food = agent_observations.entities.food[agent_idx]
            agent_water = agent_observations.entities.water[agent_idx]
            agent_health_color = 'lightgreen' if agent_health > 75 else 'yellowgreen' if agent_health > 50 else 'darkorange' if agent_health > 25 else 'darkred'
            ax.scatter(center_y, center_x, c=agent_health_color, s=100, label=f'Agent {agent_id}', edgecolors='black')

            # Plot other entities
            self._plot_entities(ax, agent_observations, min_row, min_col)

            # Add text information
            agent_stats = f"Health: {agent_health:>3}\n" + \
                          f"Food:   {agent_food:>3}\n" + \
                          f"Water:  {agent_water:>3}"
            action_text = self._get_action_text(env_actions.get(agent_id))

            ax.text(0.05, 0.05, f"Action:\n{action_text}", transform=ax.transAxes, fontsize=12,
                    verticalalignment='bottom', horizontalalignment='left', color='black',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'), family='monospace')
            ax.text(0.95, 0.95, f"Step:  {step:>4}", transform=ax.transAxes, fontsize=12,
                    verticalalignment='top', horizontalalignment='right', color='black',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'), family='monospace')
            ax.text(0.95, 0.90, agent_stats, transform=ax.transAxes, fontsize=12,
                    verticalalignment='top', horizontalalignment='right', color='black',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='black'), family='monospace')

            # Set up axes: ticks sit on tile centres and are labelled with absolute map coordinates
            ax.set_xticks(np.arange(grid_cols))
            ax.set_yticks(np.arange(grid_rows))
            ax.set_xticklabels(np.arange(min_col, max_col + 1))
            ax.set_yticklabels(np.arange(min_row, max_row + 1))
            legend = ax.legend(loc='upper left')
            for text in legend.get_texts():
                text.set_family('monospace')

            fig.tight_layout()
            output_file = f"{self._image_dir}/step_{step}.png"
            fig.savefig(output_file)
        finally:
            # One figure is created per step; close it even when drawing or saving fails.
            plt.close(fig)

        if not self.quiet:
            print(f"Saved agent view to {output_file}")

    def step(
        self,
        observations_per_agent: dict[int, Observations],
        actions_per_agent: dict[int, ActionData],
        episode: int,
        step: int) -> None:
        """Save a frame for the tracked agent and count the step towards the episode's GIF."""
        self.plot_agent_view(
            observations_per_agent,
            {agent_id: actions.action_dict for agent_id, actions in actions_per_agent.items()},
            self.agent_id,
            step)
        self._current_episode_steps += 1

    def episode_start(self, episode: int) -> None:
        """Reset the per-episode frame counter."""
        self._current_episode_steps = 0

    def episode_end(
        self,
        episode: int,
        rewards_per_agent: dict[int, float],
        losses: tuple[list[float], list[float], list[float]]
    )-> None:
        """Write the episode's GIF and delete the now-unneeded frame directory."""
        self.create_animation(f"{self._plots_dir}/animations/{self.output_name}_{episode}_{int(time.time())}.gif", fps=2)
        if os.path.exists(self._image_dir):
            shutil.rmtree(self._image_dir)

In [None]:
def plot_losses(
    actor_losses: list[float], 
    critic_losses: list[float],
    window: int = 500
) -> None:
    """Plot actor and critic losses side by side with a running mean over `window` samples.

    The actor panel additionally shades +/- one standard deviation, computed over
    the same `window` samples as the running mean. The critic panel uses a
    logarithmic y axis. A running mean is only drawn for a series that holds at
    least `window` samples; shorter series show the raw curve alone.

    Args:
        actor_losses: Actor loss per update, in chronological order.
        critic_losses: Critic loss per update, in chronological order.
        window: Number of samples in the running mean / standard deviation.

    Raises:
        ValueError: If `window` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))
    kernel = np.ones(window) / window

    ax1.plot(actor_losses, label="Actor Loss", color='blue', alpha=0.4)
    # 'valid' convolution only lines up with range(window-1, n) when n >= window.
    if len(actor_losses) >= window:
        actor = np.asarray(actor_losses, dtype=float)
        actor_steps = range(window - 1, len(actor))
        actor_losses_smooth = np.convolve(actor, kernel, mode='valid')
        # The std covers exactly the `window` samples ending at i, matching the mean.
        actor_losses_std = np.array([np.std(actor[i - window + 1:i + 1]) for i in actor_steps])

        ax1.plot(actor_steps, actor_losses_smooth,
                 label=f"Running Mean (window={window})", color='red')
        ax1.fill_between(actor_steps,
                         actor_losses_smooth - actor_losses_std,
                         actor_losses_smooth + actor_losses_std,
                         alpha=0.2, color='red', label='Standard Deviation')

    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_title("Actor Loss Over Time")
    ax1.legend()

    ax2.semilogy(critic_losses, label="Critic Loss", color='blue', alpha=0.4)
    if len(critic_losses) >= window:
        critic_losses_smooth = np.convolve(critic_losses, kernel, mode='valid')
        ax2.semilogy(range(window - 1, len(critic_losses)), critic_losses_smooth,
                     label=f"Running Mean (window={window})", color='red')
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Loss (log scale)")
    ax2.set_title("Critic Loss Over Time")
    ax2.legend()

    plt.tight_layout()
    plt.show()
    
def plot_losses_from_save(agent_name: str, window: int = 500) -> None:
    """Load the history saved under `agent_name` and plot its actor/critic losses.

    Every history entry carries its losses at index 2; element 0 of that entry
    is the list of actor losses and element 1 the list of critic losses. The
    per-episode lists are concatenated in history order before plotting.

    Args:
        agent_name: Key of the saved history inside the "saves" jar.
        window: Running-mean window forwarded to `plot_losses`.
    """
    actor_losses: list[float] = []
    critic_losses: list[float] = []

    for episode in Jar("saves").get(agent_name):
        episode_losses = episode[2]
        actor_losses.extend(episode_losses[0])
        critic_losses.extend(episode_losses[1])

    plot_losses(actor_losses, critic_losses, window)

def plot_rewards(
    avg_rewards: list[float], 
    max_rewards: list[float], 
    min_rewards: list[float], 
    ninetieth_percentile_rewards: list[float] | None = None, 
    random_agent_reward: float | None = None,
    window: int = 50
) -> None:
    _, ax = plt.subplots(figsize=(10, 6))
    
    if ninetieth_percentile_rewards is not None:
        ax.plot(ninetieth_percentile_rewards, label="90th Percentile Reward", color='purple', alpha=0.4)
        
    ax.plot(avg_rewards, label="Average Reward", color='red', alpha=0.4)
    ax.plot(max_rewards, label="Max Reward", color='pink', alpha=0.8)
    ax.plot(min_rewards, label="Min Reward", color='green', alpha=0.4)
    
    if random_agent_reward is not None:
        ax.axhline(y=random_agent_reward, label="Random Agent Reward", color='black', linestyle='--')
    
    avg_rewards_smooth = np.convolve(avg_rewards, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(avg_rewards)), avg_rewards_smooth, label=f"Running Mean (window={window})", color='blue')

    ax.set_xlabel("Episode")
    ax.set_ylabel("Reward")
    ax.set_title("Rewards Over Time")
    ax.legend()
    plt.show()

def plot_rewards_from_save(agent_name: str, window: int = 50, random_agent_reward: float | None = None) -> None:
    history = Jar("saves").get(agent_name)

    rewards = [episode[1] for episode in history]
    num_agents = len(rewards[0])
    
    avg_rewards = [np.mean([r for r in reward.values()]) for reward in rewards]
    max_rewards = [np.max([r for r in reward.values()]) for reward in rewards]
    min_rewards = [np.min([r for r in reward.values()]) for reward in rewards]
    
    if num_agents > 30:
        ninetieth_percentile_rewards = [np.percentile([r for r in reward.values()], 90) for reward in rewards]
    else:
        ninetieth_percentile_rewards = None
    
    plot_rewards(avg_rewards, max_rewards, min_rewards, ninetieth_percentile_rewards, random_agent_reward, window)

def plot_lifetimes(
    avg_lifetimes: list[float], 
    max_lifetimes: list[float], 
    min_lifetimes: list[float], 
    ninetieth_percentile: list[float] | None = None,
    random_agent_lifetime: float | None = None,
    window: int = 50
) -> None:
    _, ax = plt.subplots(figsize=(10, 6))
    
    if ninetieth_percentile is not None:
        ax.plot(ninetieth_percentile, label="90th Percentile", color='purple', alpha=0.4)
        
    ax.plot(avg_lifetimes, label="Average Lifetime", color='red', alpha=0.4)
    ax.plot(max_lifetimes, label="Max Lifetime", color='pink', alpha=0.8)
    ax.plot(min_lifetimes, label="Min Lifetime", color='green', alpha=0.4)
    
    if random_agent_lifetime is not None:
        ax.axhline(y=random_agent_lifetime, label="Random Agent Lifetime", color='black', linestyle='--')

    avg_rewards_smooth = np.convolve(avg_lifetimes, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(avg_lifetimes)), avg_rewards_smooth, label=f"Running Mean (window={window})", color='blue')

    ax.set_xlabel("Episode")
    ax.set_ylabel("Lifetime")
    ax.set_title("Agent Lifetime Over Time")
    ax.legend()
    plt.show()   

def plot_lifetimes_from_save(agent_name: str, random_agent_lifetime: float | None = None, window: int = 50) -> None:
    history = Jar("saves").get(agent_name)

    lifetimes = [episode[3] for episode in history]
    num_agents = len(lifetimes[0])
    
    avg_lifetimes = [np.mean([r for r in reward.values()]) for reward in lifetimes]
    max_lifetimes = [np.max([r for r in reward.values()]) for reward in lifetimes]
    min_lifetimes = [np.min([r for r in reward.values()]) for reward in lifetimes]
    
    if num_agents > 30:
        ninetieth_percentile = [np.percentile([r for r in lifetime.values()], 90) for lifetime in lifetimes]
    else:
        ninetieth_percentile = None
    
    plot_lifetimes(avg_lifetimes, max_lifetimes, min_lifetimes, ninetieth_percentile, random_agent_lifetime, window)

In [None]:
def get_all_observations_from_save(save_name: str, agent_ids: list[int]) -> list[Observations]:
    """Collect every saved observation of the requested agents into one flat list.

    Index 0 of each history entry is iterated as a sequence of 2-tuples whose
    first element maps agent id to that agent's observation. Results keep
    history order, then step order, then the order of `agent_ids`; agents
    missing from a step are skipped.

    Args:
        save_name: Key of the saved history inside the "saves" jar.
        agent_ids: Ids of the agents whose observations are wanted.

    Returns:
        The matching observations as a flat list.
    """
    collected = []
    for episode in Jar("saves").get(save_name):
        for step_observations, _ in episode[0]:
            for wanted_id in agent_ids:
                if wanted_id in step_observations:
                    collected.append(step_observations[wanted_id])
    return collected

In [None]:
# Environment configuration shared by the baseline runs and the training cell:
# 32 players and no NPCs on top of the nmmo defaults.
conf = config.Default()
conf.set("PLAYER_N", 32)
conf.set("NPC_N", 0)

# Reward used for both the random baseline and PPO training.
# NOTE(review): the meaning of the 1024 argument is not visible here - confirm against CustomRewardBase.
reward = ResourcesAndGatheringReward(1024)

# Random-agent baselines: each helper returns an average plus the individual values
# behind it (presumably one per retry - TODO confirm), from which the spread is computed.
random_reward, random_rewards = get_avg_reward_for_random_agent(conf, reward=reward, retries=5)
random_reward_std = np.std(random_rewards)

random_lifetime, random_lifetimes = get_avg_lifetime_for_random_agent(conf, retries=5)
random_lifetime_std = np.std(random_lifetimes)

print(f"Random agent reward: {random_reward:.6f} ± {random_reward_std:.6f}")
print(f"Random agent lifetime: {random_lifetime:4.2f} ± {random_lifetime_std:.2f}")

In [None]:
# The same identifier is used for the trained agent and for the saved history
# that the plotting cells below read back.
# NOTE(review): the name mentions "npcs" although NPC_N is set to 0 above - confirm the name is intended.
agent_name = "copy_arrays_32_agents_npcs_quick_test"
save_name = agent_name

# Train a SimplierInputAgent with PPO on a fresh environment built from `conf`,
# using the same reward as the random baseline. SavingCallback is given the ids 0..31.
# NOTE(review): save_every (200) is larger than episodes (20) - confirm whether a
# checkpoint is still written for such a short run.
train_ppo(nmmo.Env(conf),
          SimplierInputAgent(
              learning_rate=5e-5,
              epsilon=0.1,
              epochs=50,
              batch_size=256),
          episodes=20,
          save_every=200,
          print_every=1,
          custom_reward=reward,
          agent_name=agent_name,
          callbacks=[SavingCallback(save_name, saved_agent_ids=list(range(32)))])

In [None]:
# Reward statistics per episode, with the random-agent reward as the baseline.
# NOTE(review): window (50) is larger than the 20 episodes trained above - confirm
# the save holds enough episodes for the running mean.
plot_rewards_from_save(save_name, random_agent_reward=random_reward, window=50)

In [None]:
# Lifetime statistics per episode, with the random-agent lifetime as the baseline.
# NOTE(review): window (50) is larger than the 20 episodes trained above - confirm
# the save holds enough episodes for the running mean.
plot_lifetimes_from_save(save_name, random_agent_lifetime=random_lifetime, window=50)

In [None]:
# Actor/critic losses of every saved episode, smoothed over 2000 updates.
plot_losses_from_save(save_name, window=2000)

In [9]:
# Re-create the network inputs for every saved observation so they can be sanity-checked below.
# NOTE(review): ids 0..31 are requested; if nmmo player ids start at 1, id 32 is
# missed and id 0 never matches - confirm.
observations = get_all_observations_from_save(save_name, agent_ids=list(range(32)))
net_inputs = [observations_to_inputs_simplier(obs, device="cpu") for obs in observations]

# Each converted input is indexed as (tiles, self data, move mask, attack mask);
# the extra [0] presumably drops a leading batch dimension - TODO confirm.
tiles = [inp[0][0] for inp in net_inputs]
# Last 9 of the 28 per-tile features, keeping only tiles where at least one of them is non-zero.
tile_features = [feature for tile in tiles for feature in tile.reshape(-1, 28)[:, -9:]if not torch.all(feature == 0)]

self_datas = [inp[1][0] for inp in net_inputs]
move_masks = [inp[2][0] for inp in net_inputs]
attack_masks = [inp[3][0] for inp in net_inputs]

In [None]:
def assert_for_all(values, assertion_fn, description):
    """Print how many entries of `values` satisfy `assertion_fn`.

    Nothing is raised on failure: the result is reported as
    "<description>: <passed>/<total>" followed by a check mark when every entry
    passed and a cross otherwise.

    Args:
        values: Sized collection of items to check.
        assertion_fn: Callable returning a truthy/falsy (or 0/1-like) result per item.
        description: Label printed in front of the counts.
    """
    correct_count = 0
    for value in values:
        correct_count = correct_count + assertion_fn(value)
    total_count = len(values)

    label = description + ':'
    verdict = '✅' if correct_count == total_count else '❌'
    print(f"{label:<35}{correct_count}/{total_count} {verdict}")

# Tile tensor checks: a 15x15 window with 28 features per tile, where features 0-15
# sum to one, features 16-17 sum to one, and feature 18 is a 0/1 flag.
assert_for_all(tiles, lambda x: x.shape == torch.Size([15, 15, 28]), "Tiles shape")
assert_for_all(tiles, lambda x: torch.all(torch.sum(x[:, :, :16], dim=-1) == 1), "16 features one-hot encoded")
assert_for_all(tiles, lambda x: torch.all(torch.sum(x[:, :, 16:18], dim=-1) == 1), "Each tile either passable or not")
assert_for_all(tiles, lambda x: torch.all(torch.logical_or(x[:, :, 18] == 0, x[:, :, 18] == 1)), "Each tile harvestable or not")
# TODO: Check seen entity data
print()

# Self data: 5 values per agent, each expected to lie in [0, 1].
assert_for_all(self_datas, lambda x: x.shape == torch.Size([5]), "Self data shape")
assert_for_all(self_datas, lambda x: torch.all((x >= 0) & (x <= 1)), "All values between 0 and 1")
print()

# Attack mask: 3 entries, all expected to be enabled.
assert_for_all(attack_masks, lambda x: x.shape == torch.Size([3]), "Attack mask shape")
assert_for_all(attack_masks, lambda x: torch.all(x == 1), "Every attack style valid")
print()

# Move mask: 5 entries; the last one must always be enabled and at least one of the
# first four must be enabled.
# NOTE(review): the last entry is presumably the "stay in place" option - confirm,
# since the label "Can not move" reads ambiguously.
assert_for_all(move_masks, lambda x: x.shape == torch.Size([5]), "Move mask shape")
assert_for_all(move_masks, lambda x: x[-1] == 1, "Can not move")
assert_for_all(move_masks, lambda x: torch.any(x[:-1] == 1), "Can move somewhere")