In [None]:
%pip install -q -e .

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import nmmo
from nmmo import config

from implementations.train_ppo import train_ppo, evaluate_agent
from implementations.SimplierInputAgentV2 import SimplierInputAgentV2
from implementations.RandomAgent import get_avg_lifetime_for_random_agent, get_avg_reward_for_random_agent
from implementations.Observations import Observations
from implementations.CustomRewardBase import LifetimeReward, ResourcesAndGatheringReward, \
    ExplorationReward, WeightedReward, ShiftingReward, CurriculumReward, ResourcesReward, \
    CustomRewardBase
from implementations.StayNearResourcesReward import StayNearResourcesReward
from implementations.SavingCallback import SavingCallback
from implementations.AnimationCallback import AnimationCallback
from implementations.PathTrackingCallback import PathTrackingCallback
from implementations.SpawnTrackingCallback import SpawnTrackingCallback
from implementations.observations_to_inputs import observations_to_inputs_simplier
from implementations.plots import *
from implementations.jar import Jar

In [None]:
# Environment configuration shared by the baseline, training and evaluation cells
# of this notebook: the nmmo default config with PLAYER_N set to 32.
conf = config.Default()
conf.set("PLAYER_N", 32)
#conf.set("NPC_N", 0)

# --- Disabled alternative (kept for reference): curriculum reward ---------------
# The commented-out code below builds three individual rewards, measures the
# random-agent average for each of them, and pairs each reward with a multiple of
# its random-agent average inside a CurriculumReward.
# lifetime_reward = LifetimeReward(max_lifetime=1000)
# resources_reward = ResourcesReward(max_lifetime=1000)
# exploration_reward = ExplorationReward(max_lifetime=1000, map_size=128, view_radius=7)

# random_lifetime_reward, _ = get_avg_reward_for_random_agent(conf, reward=lifetime_reward, retries=20)
# print(f"Random agent lifetime reward: {random_lifetime_reward:.6f}")
# random_resources_reward, _ = get_avg_reward_for_random_agent(conf, reward=resources_reward, retries=20)
# print(f"Random agent resources reward: {random_resources_reward:.6f}")
# random_exploration_reward, _ = get_avg_reward_for_random_agent(conf, reward=exploration_reward, retries=20)
# print(f"Random agent exploration reward: {random_exploration_reward:.6f}")

# reward = CurriculumReward(reward_stages=[
#     (3.5 * random_resources_reward, resources_reward),
#     (3 * random_exploration_reward, exploration_reward),
#     (0, lifetime_reward)
# ])

# Active reward used by the training cell: a WeightedReward over three components.
# NOTE(review): the dict values (1, 1, 0.2) are presumably the per-component weights,
# and the positional 1024 is presumably a max lifetime (compare `max_lifetime=` in the
# disabled block above) -- confirm against the reward classes.
reward = WeightedReward({
    StayNearResourcesReward(1024, target_distance=2): 1,
    ResourcesAndGatheringReward(1024, gathering_bonus=4, scale_with_resource_change=True): 1,
    ExplorationReward(1024): 0.2
})

# Random-agent baseline for the active reward (retries=20). The helper returns the
# average together with the individual values, from which the spread is computed.
# `random_reward` is read again by the plotting cells further down.
random_reward, random_rewards = get_avg_reward_for_random_agent(conf, reward=reward, retries=20)
random_reward_std = np.std(random_rewards)
print(f"Random agent reward: {random_reward:.6f} ± {random_reward_std:.6f}")

# Same baseline for the lifetime; `random_lifetime` is read by the lifetime plots.
random_lifetime, random_lifetimes = get_avg_lifetime_for_random_agent(conf, retries=20)
random_lifetime_std = np.std(random_lifetimes)
print(f"Random agent lifetime: {random_lifetime:4.2f} ± {random_lifetime_std:.2f}")

In [None]:
def make_agent() -> SimplierInputAgentV2:
    """Build a fresh SimplierInputAgentV2 with the baseline hyperparameters.

    Every call constructs a new agent object, so a caller may modify the
    returned instance without touching any other agent.
    """
    # Learning-rate schedule passed for the policy side.
    actor_schedule = {
        "learning_rate": 5e-5,
        "lr_decay": 0.999,
        "min_lr": 5e-7,
    }
    # Separate schedule passed for the critic.
    critic_schedule = {
        "critic_learning_rate": 5e-6,
        "critic_lr_decay": 0.999,
        "critic_min_lr": 5e-7,
    }
    # Remaining update settings.
    update_settings = {
        "epsilon": 0.1,
        "epochs": 50,
        "batch_size": 256,
        "entropy_loss_coef": 3e-5,
        "max_grad_norm": 0.5,
        "sample_weights_softmin_temp": -0.5,
    }
    per_action_loss_weights = {
        "Move": 2.5,
        "AttackStyle": 0.5,
        "AttackTargetPos": 0.5,
    }
    return SimplierInputAgentV2(
        **actor_schedule,
        **critic_schedule,
        **update_settings,
        action_loss_weights=per_action_loss_weights,
    )
    
def make_test_configs() -> dict[str, SimplierInputAgentV2]:
    """Build the ablation variants of the baseline agent.

    Each variant is a fresh ``make_agent()`` result on which one group of
    attributes is overwritten after construction.

    Returns:
        Mapping from variant name to the configured agent, in definition order.
    """
    # Simple ablations: variant name -> attributes to overwrite on a baseline agent.
    # NOTE(review): the overrides are applied after construction, so they only take
    # effect if the agent reads these attributes at training time -- confirm.
    # NOTE(review): the constructor keyword is `sample_weights_softmin_temp`, while the
    # sample-weight variants set `softmin_temp`; confirm this is the attribute the
    # agent actually reads, otherwise those three variants equal the baseline.
    attribute_overrides = {
        "no_lr_decay": {"lr_decay": 1, "critic_lr_decay": 1},
        "no_entropy_loss": {"entropy_loss_coef": 0},
        "no_action_loss_weights": {"action_loss_weights": {}},
        "no_sample_weights": {"softmin_temp": 0},
        "positive_sample_weights": {"softmin_temp": 1},
        "more_negative_sample_weights": {"softmin_temp": -1},
        # "No clipping" is expressed as a very large norm limit.
        "no_gradient_clipping": {"max_grad_norm": 1000},
        "no_advantage_normalization": {"normalize_advantages": False},
    }

    configs = {}
    for name, overrides in attribute_overrides.items():
        agent = make_agent()
        for attribute, value in overrides.items():
            setattr(agent, attribute, value)
        configs[name] = agent

    # The critic copies the policy's learning-rate schedule from the same agent,
    # so this variant cannot be expressed as constant overrides.
    no_separate_critic_lr = make_agent()
    no_separate_critic_lr.critic_learning_rate = no_separate_critic_lr.learning_rate
    no_separate_critic_lr.critic_lr_decay = no_separate_critic_lr.lr_decay
    no_separate_critic_lr.critic_min_lr = no_separate_critic_lr.min_lr
    configs["no_separate_critic_lr"] = no_separate_critic_lr

    # The mapping must be returned: run_test_configs() iterates over it.
    return configs
    
def run_test_configs(
    reward: CustomRewardBase,
    run_only: list[str] | None = None,
    episodes: int = 400,
    dir_name: str = "final_tests"
) -> None:
    """Train every ablation variant with ``reward`` and plot each run.

    Random-agent baselines for reward and lifetime are measured once and printed.
    Each selected variant is then trained with ``train_ppo`` under the agent name
    ``{dir_name}/{name}``, and its reward, lifetime, loss and entropy plots are
    requested with ``save_as`` targets below ``dir_name``.

    Args:
        reward: Reward passed to the baselines and to ``train_ppo``.
        run_only: If given, only variants whose name is in this list are run.
        episodes: Number of training episodes per variant.
        dir_name: Prefix for the agent names and the plot targets.
    """
    def new_env_conf():
        # Default nmmo configuration with PLAYER_N set to 32.
        env_conf = config.Default()
        env_conf.set("PLAYER_N", 32)
        return env_conf

    configs = make_test_configs()
    if run_only is not None:
        configs = {k: v for k, v in configs.items() if k in run_only}

    # The environment config has to exist before the baselines are measured;
    # it used to be created only inside the loop below, after its first use.
    baseline_conf = new_env_conf()

    random_reward, random_rewards = get_avg_reward_for_random_agent(baseline_conf, reward=reward, retries=20)
    random_reward_std = np.std(random_rewards)
    print(f"Random agent reward: {random_reward:.6f} ± {random_reward_std:.6f}")

    random_lifetime, random_lifetimes = get_avg_lifetime_for_random_agent(baseline_conf, retries=20)
    random_lifetime_std = np.std(random_lifetimes)
    print(f"Random agent lifetime: {random_lifetime:4.2f} ± {random_lifetime_std:.2f}")

    for name, agent in configs.items():
        # Each variant gets its own freshly built config and environment.
        env_conf = new_env_conf()

        agent_name = f"{dir_name}/{name}"
        train_ppo(nmmo.Env(env_conf),
            agent,
            episodes=episodes,
            save_every=25,
            print_every=5,
            eval_every=50,
            eval_episodes=5,
            custom_reward=reward,
            agent_name=agent_name,
            callbacks=[
                SavingCallback(agent_name, reward_config=reward.get_config())])

        plot_rewards_from_save(agent_name, random_agent_reward=random_reward, window=50, save_as=f"{dir_name}/rewards/{name}")
        plot_lifetimes_from_save(agent_name, random_agent_lifetime=random_lifetime, window=50, save_as=f"{dir_name}/lifetimes/{name}")
        plot_losses_from_save(agent_name, window=1000, save_as=f"{dir_name}/losses/{name}")
        plot_entropies_from_save(agent_name, window=1000, save_as=f"{dir_name}/entropies/{name}")
        
def make_plots_from_save(save_name: str) -> None:
    """Request the reward, lifetime, loss and entropy plots for a saved run.

    Each plot helper receives a ``save_as`` target of the form
    ``variants/<plot type>/<save_name>``. The notebook-level ``random_reward``
    and ``random_lifetime`` values are forwarded to the reward and lifetime plots.
    """
    dir_name = "variants"

    # Resolve all targets first so the calls below read uniformly.
    rewards_target = f"{dir_name}/rewards/{save_name}"
    lifetimes_target = f"{dir_name}/lifetimes/{save_name}"
    losses_target = f"{dir_name}/losses/{save_name}"
    entropies_target = f"{dir_name}/entropies/{save_name}"

    plot_rewards_from_save(
        save_name,
        random_agent_reward=random_reward,
        window=50,
        save_as=rewards_target,
    )
    plot_lifetimes_from_save(
        save_name,
        random_agent_lifetime=random_lifetime,
        window=50,
        save_as=lifetimes_target,
    )
    plot_losses_from_save(save_name, window=1000, save_as=losses_target)
    plot_entropies_from_save(save_name, window=50, save_as=entropies_target)

In [None]:
# Main training run with the weighted reward defined in the configuration cell.
# The run is saved under `save_name`, which the plotting and evaluation cells
# below read back.
agent_name = "variation_scaled_rathering_bonus"
save_name = agent_name

train_ppo(nmmo.Env(conf),
          # Baseline hyperparameters come from make_agent(), the same helper the
          # ablation variants start from, so the two cannot drift apart.
          make_agent(),
          episodes=2000,
          save_every=25,
          print_every=5,
          eval_every=50,
          eval_episodes=5,
          custom_reward=reward,
          agent_name=agent_name,
          callbacks=[
            SavingCallback(save_name, reward_config=reward.get_config())])

In [None]:
# Reward plot for the run saved under `save_name`; the random-agent reward measured
# in the configuration cell is passed along. `window` is presumably a smoothing
# window -- confirm in implementations.plots.
plot_rewards_from_save(save_name, random_agent_reward=random_reward, window=50)

In [None]:
# Lifetime plot for the same run, with the random-agent lifetime passed along.
plot_lifetimes_from_save(save_name, random_agent_lifetime=random_lifetime, window=50)

In [None]:
# Loss plot for the same run; note the much larger window (1000) than the
# reward and lifetime plots use (50).
plot_losses_from_save(save_name, window=1000)

In [None]:
# Entropy plot for the same run.
# NOTE(review): window=50 here, whereas run_test_configs() plots entropies with
# window=1000 -- confirm which one is intended.
plot_entropies_from_save(save_name, window=50)

In [None]:
# Evaluate the agent loaded from the `{agent_name}_best` save (presumably the best
# checkpoint written during training -- confirm) for 20 episodes. The evaluation is
# scored with LifetimeReward(1024) rather than the weighted training reward.
# The two tracking callbacks are kept in variables because the cells below plot
# from them after the evaluation has run.
spawn_callback = SpawnTrackingCallback()
path_callback = PathTrackingCallback()

eval_rewards, _ = evaluate_agent(
    nmmo.Env(conf),
    SimplierInputAgentV2.load(f"{agent_name}_best"),
    episodes=20,
    custom_reward=LifetimeReward(1024),
    callbacks=[
        spawn_callback,
        path_callback,
        # NOTE(review): meaning of the leading 1 is not visible here -- confirm
        # against AnimationCallback's signature.
        AnimationCallback(1, f"{agent_name}_best_animation")
    ],
    sample_actions=True
)

# Mean and spread of the per-evaluation rewards returned above.
print(f"Average reward: {np.mean(eval_rewards):.6f} ± {np.std(eval_rewards):.6f}")

In [None]:
# Plot from the spawn-tracking callback that was attached to the evaluation above;
# it must run after that cell so the callback has data.
spawn_callback.plot_density_reward_correlation(max_distance=50, smoothing=1)

In [None]:
# Plot from the path-tracking callback that was attached to the evaluation above.
# NOTE(review): the argument 10 is presumably the number of paths to draw -- confirm.
path_callback.plot_paths(10)

In [12]:
def get_all_observations_from_save(save_name: str, agent_ids: list[int]) -> list[Observations]:
    """Collect the stored observations of the given agents from a saved history.

    The history is read from the ``"saves"`` jar under ``save_name``. For every
    entry, the first element is iterated as a sequence of pairs whose first item
    maps agent ids to observations; the second item of each pair is ignored.

    Args:
        save_name: Key of the history inside the jar.
        agent_ids: Agent ids to collect; ids missing from a step are skipped.

    Returns:
        Flat list of observations, ordered by entry, then step, then the order
        of ``agent_ids``.
    """
    history = Jar("saves").get(save_name)

    collected = []
    for episode in history:
        for step_observations, _ in episode[0]:
            for agent_id in agent_ids:
                if agent_id in step_observations:
                    collected.append(step_observations[agent_id])
    return collected

def verify_observations(save_name: str) -> None:
    """Sanity-check the network inputs built from a saved run and print a report.

    Observations of agents 1..32 are loaded from ``save_name`` and converted with
    ``observations_to_inputs_simplier`` on the CPU. For every check one line
    ``<description>: <passed>/<total>`` with a pass/fail mark is printed.
    Afterwards each pair of agents is tested for symmetric visibility: if one
    agent saw another N times, the other must have seen the first N times too;
    every mismatch is printed.

    Args:
        save_name: Name of the save to load the observations from.
    """
    observations = get_all_observations_from_save(save_name, agent_ids=list(range(1, 33)))
    net_inputs = [observations_to_inputs_simplier(obs, device="cpu") for obs in observations]

    # Only index 0 of each input component is inspected.
    tiles = [inp[0][0] for inp in net_inputs]
    self_datas = [inp[1][0] for inp in net_inputs]
    move_masks = [inp[2][0] for inp in net_inputs]
    attack_masks = [inp[3][0] for inp in net_inputs]

    def assert_for_all(values, assertion_fn, description):
        # Several checks return 0-d torch tensors instead of Python bools. Summing
        # those yields a tensor, which would be printed as "tensor(N)"; bool()
        # keeps the count a plain int.
        correct_count = sum(bool(assertion_fn(tensor)) for tensor in values)
        total_count = len(values)
        print(f"{(description+':'):<35}{correct_count}/{total_count} {('✅' if correct_count == total_count else '❌')}")

    assert_for_all(tiles, lambda x: x.shape == torch.Size([15, 15, 28]), "Tiles shape")
    assert_for_all(tiles, lambda x: torch.all(torch.sum(x[:, :, :16], dim=-1) == 1), "16 features one-hot encoded")
    assert_for_all(tiles, lambda x: torch.all(torch.sum(x[:, :, 16:18], dim=-1) == 1), "Each tile either passable or not")
    assert_for_all(tiles, lambda x: torch.all(torch.logical_or(x[:, :, 18] == 0, x[:, :, 18] == 1)), "Each tile harvestable or not")
    # TODO: Check seen entity data
    print()

    assert_for_all(self_datas, lambda x: x.shape == torch.Size([5]), "Self data shape")
    assert_for_all(self_datas, lambda x: torch.all((x >= 0) & (x <= 1)), "All values between 0 and 1")
    print()

    assert_for_all(attack_masks, lambda x: x.shape == torch.Size([3]), "Attack mask shape")
    assert_for_all(attack_masks, lambda x: torch.all(x == 1), "Every attack style valid")
    print()

    assert_for_all(move_masks, lambda x: x.shape == torch.Size([5]), "Move mask shape")
    assert_for_all(move_masks, lambda x: x[-1] == 1, "Can not move")
    assert_for_all(move_masks, lambda x: torch.any(x[:-1] == 1), "Can move somewhere")

    # seen_ids[a][b] counts in how many observations of agent a the entity b appears.
    seen_ids = {}
    for obs in observations:
        seen_by_agent = seen_ids.setdefault(obs.agent_id, {})
        for seen_id in obs.entities.id:
            # Id 0 and the observing agent itself are not counted.
            if seen_id == 0 or seen_id == obs.agent_id:
                continue
            seen_by_agent[seen_id] = seen_by_agent.get(seen_id, 0) + 1

    # Check that if one agent sees another, the other agent sees the first agent
    for agent_id, seen in seen_ids.items():
        for seen_id, count in seen.items():
            reverse_count = seen_ids.get(seen_id, {}).get(agent_id, 0)
            if reverse_count != count:
                print(f"Agent {agent_id} saw {seen_id} {count} times, but {seen_id} saw {agent_id} {reverse_count} times")