In [None]:
# Editable (-e), quiet (-q) install of the project in the current directory.
# NOTE(review): presumably this is what makes the `implementations` package
# imported in the next cell available to the kernel - confirm.
%pip install -q -e .

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import nmmo
from nmmo import config

from implementations.train_ppo import train_ppo, evaluate_agent
from implementations.SimplierInputAgentV2 import SimplierInputAgentV2
from implementations.RandomAgent import get_avg_lifetime_for_random_agent, get_avg_reward_for_random_agent
from implementations.Observations import Observations
from implementations.CustomRewardBase import LifetimeReward, \
    ResourcesAndGatheringReward, ExplorationReward, WeightedReward, ShiftingReward
from implementations.StayNearResourcesReward import StayNearResourcesReward
from implementations.SavingCallback import SavingCallback
from implementations.AnimationCallback import AnimationCallback
from implementations.PathTrackingCallback import PathTrackingCallback
from implementations.SpawnTrackingCallback import SpawnTrackingCallback
from implementations.observations_to_inputs import observations_to_inputs_simplier
from implementations.plots import *
from implementations.jar import Jar

In [None]:
# Environment configuration: start from the nmmo defaults and set PLAYER_N to 32.
conf = config.Default()
conf.set("PLAYER_N", 32)
# Disabled toggle kept for experiments: would set NPC_N to 0.
#conf.set("NPC_N", 0)

# Training reward built from three reward terms, each mapped to a number
# (1, 1 and 0.2) - presumably the weight of that term in the combined reward.
# NOTE(review): the leading 1024 passed to every term is presumably the maximum
# episode length - confirm against the reward classes.
reward = WeightedReward({
    StayNearResourcesReward(1024, target_distance=2): 1,
    ResourcesAndGatheringReward(1024, gathering_bonus=4, scale_with_resource_change=False): 1,
    ExplorationReward(1024): 0.2
})

# Random-agent reward baseline under the same config and reward (retries=20).
# The helper returns two values: the first is printed as the baseline reward,
# the second is a collection whose standard deviation is reported next to it.
random_reward, random_rewards = get_avg_reward_for_random_agent(conf, reward=reward, retries=20)
random_reward_std = np.std(random_rewards)
print(f"Random agent reward: {random_reward:.6f} ± {random_reward_std:.6f}")

# Same baseline for lifetime; no custom reward is passed to this helper.
random_lifetime, random_lifetimes = get_avg_lifetime_for_random_agent(conf, retries=20)
random_lifetime_std = np.std(random_lifetimes)
print(f"Random agent lifetime: {random_lifetime:4.2f} ± {random_lifetime_std:.2f}")

In [None]:
# Run identifier. agent_name is passed to train_ppo and is reused by the evaluation
# cell to load f"{agent_name}_best"; save_name is the key given to SavingCallback
# and read back by the plot_*_from_save cells. Both are the same string here.
agent_name = "compound_reward_retry_negative_sample_weights"
save_name = agent_name

# Train a SimplierInputAgentV2 with PPO on a fresh environment built from `conf`,
# using the compound `reward` defined in the configuration cell.
train_ppo(nmmo.Env(conf),
          SimplierInputAgentV2(
            learning_rate=5e-5,
            lr_decay=0.999,
            min_lr=5e-7,
            # The critic starts at a 10x smaller learning rate than learning_rate
            # above, with the same decay factor and the same floor.
            critic_learning_rate=5e-6,
            critic_lr_decay=0.999,
            critic_min_lr=5e-7,
            epsilon=0.1,
            epochs=50,
            batch_size=256,
            entropy_loss_coef=3e-5,
            max_grad_norm=0.5,
            # Negative temperature, matching "negative_sample_weights" in the run name.
            # NOTE(review): confirm the sign convention inside SimplierInputAgentV2.
            sample_weights_softmin_temp=-1,
            # Per-action-head loss weights: "Move" gets 2.5, both attack heads get 0.5.
            action_loss_weights={
                "Move": 2.5,
                "AttackStyle": 0.5,
                "AttackTargetPos": 0.5,
              }),
          episodes=3000,
          save_every=25,
          print_every=5,
          eval_every=25,
          eval_episodes=5,
          custom_reward=reward,
          agent_name=agent_name,
          # The reward configuration is handed to the saving callback alongside the
          # save key, presumably so it is stored with the training history.
          callbacks=[
            SavingCallback(save_name, reward_config=reward.get_config())])

In [None]:
# Reward plot for the saved run; the random-agent baseline reward is passed in
# for comparison. window=50 is presumably a smoothing window - confirm in plots.
plot_rewards_from_save(save_name, random_agent_reward=random_reward, window=50)

In [None]:
# Lifetime plot for the saved run; the random-agent baseline lifetime is passed in
# for comparison. window=50 is presumably a smoothing window - confirm in plots.
plot_lifetimes_from_save(save_name, random_agent_lifetime=random_lifetime, window=50)

In [None]:
# Loss plot for the saved run. Note the window is 1000 here, versus 50 in the
# reward, lifetime and entropy plots.
plot_losses_from_save(save_name, window=1000)

In [None]:
# Entropy plot for the saved run (window=50, as for the reward and lifetime plots).
plot_entropies_from_save(save_name, window=50)

In [None]:
# Callback instance kept in a notebook-level variable because the next cell
# calls a plotting method on it after the evaluation has run.
spawn_callback = SpawnTrackingCallback()

# Evaluate the agent loaded from f"{agent_name}_best" for 10 episodes on a fresh
# environment. The evaluation reward is LifetimeReward(1024), not the compound
# reward used for training. The return value is discarded; only the data
# gathered by spawn_callback is used afterwards.
_ = evaluate_agent(
    nmmo.Env(conf),
    SimplierInputAgentV2.load(f"{agent_name}_best"),
    episodes=10,
    custom_reward=LifetimeReward(1024),
    callbacks=[
        spawn_callback
    ]
)

In [None]:
# Plot from the callback that was passed to evaluate_agent in the previous cell.
# NOTE(review): going by the method name this relates spawn density to reward;
# the meaning of max_distance=50 and smoothing=1 should be confirmed in
# SpawnTrackingCallback.
spawn_callback.plot_density_reward_correlation(max_distance=50, smoothing=1)

In [9]:
def get_all_observations_from_save(save_name: str, agent_ids: list[int]) -> list[Observations]:
    """Collect the saved observations of the given agents into one flat list.

    The history stored under ``save_name`` in the ``"saves"`` jar is walked entry
    by entry. The first element of every history entry is iterated as pairs; the
    first item of each pair is a mapping indexed by agent id (presumably the
    observations of one environment step - confirm against the saving code).

    Args:
        save_name: Key of the saved history inside the ``"saves"`` jar.
        agent_ids: Agent ids to collect; ids missing from a mapping are skipped.

    Returns:
        The matching observations, ordered by history entry, then by pair, then
        by the order of ``agent_ids``.
    """
    saved_history = Jar("saves").get(save_name)

    collected = []
    for entry in saved_history:
        for step_observations, _ in entry[0]:
            for wanted_id in agent_ids:
                if wanted_id in step_observations:
                    collected.append(step_observations[wanted_id])
    return collected

def verify_observations(save_name: str) -> None:
    """Print a sanity report for the network inputs built from saved observations.

    Every observation saved under ``save_name`` for agent ids 1..32 is converted
    with ``observations_to_inputs_simplier`` on the CPU. For each check one line
    ``<description>: <passed>/<total>`` is printed, followed by ✅ when all inputs
    pass and ❌ otherwise. Afterwards one line is printed for every ordered pair
    of agents whose mutual sighting counts differ.

    A failed check is only reported, never raised.

    Args:
        save_name: Key of the saved history to verify.
    """
    observations = get_all_observations_from_save(save_name, agent_ids=list(range(1, 33)))
    net_inputs = [observations_to_inputs_simplier(obs, device="cpu") for obs in observations]

    # Each converted input is a sequence of four parts; element 0 of every part is
    # taken (presumably stripping a batch dimension - confirm in the converter).
    tiles = [inp[0][0] for inp in net_inputs]
    self_datas = [inp[1][0] for inp in net_inputs]
    move_masks = [inp[2][0] for inp in net_inputs]
    attack_masks = [inp[3][0] for inp in net_inputs]

    def assert_for_all(values, assertion_fn, description):
        # The predicates return either a Python bool (shape comparisons) or a
        # 0-dim tensor (torch.all / element comparisons); bool() normalises both
        # so the count is a plain int.
        correct_count = sum(bool(assertion_fn(tensor)) for tensor in values)
        total_count = len(values)
        print(f"{(description+':'):<35}{correct_count}/{total_count} {('✅' if correct_count == total_count else '❌')}")

    assert_for_all(tiles, lambda x: x.shape == torch.Size([15, 15, 28]), "Tiles shape")
    assert_for_all(tiles, lambda x: torch.all(torch.sum(x[:, :, :16], dim=-1) == 1), "16 features one-hot encoded")
    assert_for_all(tiles, lambda x: torch.all(torch.sum(x[:, :, 16:18], dim=-1) == 1), "Each tile either passable or not")
    assert_for_all(tiles, lambda x: torch.all(torch.logical_or(x[:, :, 18] == 0, x[:, :, 18] == 1)), "Each tile harvestable or not")
    # TODO: Check seen entity data (presumably the last 9 of the 28 tile channels,
    # which none of the checks above cover).
    print()

    assert_for_all(self_datas, lambda x: x.shape == torch.Size([5]), "Self data shape")
    assert_for_all(self_datas, lambda x: torch.all((x >= 0) & (x <= 1)), "All values between 0 and 1")
    print()

    assert_for_all(attack_masks, lambda x: x.shape == torch.Size([3]), "Attack mask shape")
    assert_for_all(attack_masks, lambda x: torch.all(x == 1), "Every attack style valid")
    print()

    assert_for_all(move_masks, lambda x: x.shape == torch.Size([5]), "Move mask shape")
    # The last mask entry must always be enabled, and at least one of the other
    # four entries must be enabled as well.
    assert_for_all(move_masks, lambda x: x[-1] == 1, "Can not move")
    assert_for_all(move_masks, lambda x: torch.any(x[:-1] == 1), "Can move somewhere")

    # seen_ids[a][b] counts in how many observations of agent a the entity id b
    # appears. Counts are accumulated over every loaded observation and keyed by
    # id only. Id 0 and the observing agent's own id are ignored.
    # NOTE(review): this relies on the ids hashing by value (Python/NumPy ints);
    # verify if obs.entities.id yields torch tensor elements.
    seen_ids = {}
    for obs in observations:
        seen_by_agent = seen_ids.setdefault(obs.agent_id, {})

        for seen_id in obs.entities.id:
            if seen_id == 0 or seen_id == obs.agent_id:
                continue

            seen_by_agent[seen_id] = seen_by_agent.get(seen_id, 0) + 1

    # Check that if one agent sees another, the other agent sees the first agent
    for agent_id, seen in seen_ids.items():
        for seen_id, count in seen.items():
            reverse_count = seen_ids.get(seen_id, {}).get(agent_id, 0)
            if reverse_count != count:
                print(f"Agent {agent_id} saw {seen_id} {count} times, but {seen_id} saw {agent_id} {reverse_count} times")