In [1]:
from torchrl.envs.utils import check_env_specs

from environments.flipit_geometric import FlipItEnv, FlipItMap

NUM_STEPS = 15
flipit_map = FlipItMap.load("test_map3.pth", "cuda:0")
env = FlipItEnv(flipit_map, NUM_STEPS, "cuda:0")
check_env_specs(env)

2025-06-10 09:07:49,000 [torchrl][INFO] check_env_specs succeeded!


In [3]:
import yaml

from config import TrainingConfig, LossConfig, AgentNNConfig, BackboneConfig, HeadConfig
from algorithms.simple_nn import TrainableNNAgentPolicy
from algorithms.generic_policy import MultiAgentPolicy
from algorithms.generator import AgentGenerator


with open("configs/run/test_single_training_transformer.yaml", "r") as file:
    config = yaml.safe_load(file)
training_config_defender = TrainingConfig.from_dict(config, suffix="_defender")
loss_config_defender = LossConfig.from_dict(config, suffix="_defender")
training_config_attacker = TrainingConfig.from_dict(config, suffix="_attacker")
loss_config_attacker = LossConfig.from_dict(config, suffix="_attacker")
num_nodes = flipit_map.num_nodes
agent_config = AgentNNConfig.from_dict(config)
backbone_config = BackboneConfig.from_dict(config, suffix="_backbone")
head_config = HeadConfig.from_dict(config, suffix="_head")

In [5]:
defender_agent_transformer = TrainableNNAgentPolicy(
    num_nodes=num_nodes,
    total_steps=env.num_steps,
    player_type=0,
    agent_config=agent_config,
    backbone_config=backbone_config,
    head_config=head_config,
    device="cuda:0",
    loss_config=loss_config_defender,
    training_config=training_config_defender,
    run_name="test",
)
defender_agent_transformer.eval()
defender_agent_transformer.load("saved_models/2025-06-08_16:30:57-full-transformer/defender/agent_0.pth")

attacker_agent_transformer = MultiAgentPolicy(
    num_nodes=num_nodes,
    player_type=1,
    device="cuda:0",
    embedding_size=32,
    run_name="test",
    policy_generator=AgentGenerator(
        TrainableNNAgentPolicy,
        {
            "num_nodes": num_nodes,
            "total_steps": env.num_steps,
            "player_type": 1,
            "device": "cuda:0",
            "loss_config": loss_config_attacker,
            "training_config": training_config_attacker,
            "run_name": "test",
            "agent_config": agent_config,
            "backbone_config": backbone_config,
            "head_config": head_config,
        }
    ),
)
attacker_agent_transformer.eval()
attacker_agent_transformer.load("saved_models/2025-06-08_16:30:57-full-transformer/attacker")

In [7]:
with open("configs/run/test_single_training.yaml", "r") as file:
    config = yaml.safe_load(file)
training_config_defender = TrainingConfig.from_dict(config, suffix="_defender")
loss_config_defender = LossConfig.from_dict(config, suffix="_defender")
training_config_attacker = TrainingConfig.from_dict(config, suffix="_attacker")
loss_config_attacker = LossConfig.from_dict(config, suffix="_attacker")
num_nodes = flipit_map.num_nodes
agent_config = AgentNNConfig.from_dict(config)
backbone_config = BackboneConfig.from_dict(config, suffix="_backbone")
head_config = HeadConfig.from_dict(config, suffix="_head")

defender_agent_ffn = TrainableNNAgentPolicy(
    num_nodes=num_nodes,
    total_steps=env.num_steps,
    player_type=0,
    device="cuda:0",
    loss_config=loss_config_defender,
    training_config=training_config_defender,
    run_name="test",
    agent_config=agent_config,
    backbone_config=backbone_config,
    head_config=head_config,
)
defender_agent_ffn.eval()
defender_agent_ffn.load("saved_models/2025-06-08_16:30:57-full-fnn/defender/agent_0.pth")

attacker_agent_ffn = MultiAgentPolicy(
    num_nodes=num_nodes,
    player_type=1,
    device="cuda:0",
    embedding_size=32,
    run_name="test",
    policy_generator=AgentGenerator(
        TrainableNNAgentPolicy,
        {
            "num_nodes": num_nodes,
            "total_steps": env.num_steps,
            "player_type": 1,
            "device": "cuda:0",
            "loss_config": loss_config_attacker,
            "training_config": training_config_attacker,
            "run_name": "test",
            "agent_config": agent_config,
            "backbone_config": backbone_config,
            "head_config": head_config,
        }
    ),
)
attacker_agent_ffn.eval()
attacker_agent_ffn.load("saved_models/2025-06-08_16:30:57-full-fnn/attacker")

In [8]:
from algorithms.coevosg import CoevoSGDefenderAgent, CoevoSGAttackerAgent, CoevoSGConfig

defender_agent_coevosg = CoevoSGDefenderAgent(
    num_nodes=num_nodes,
    player_type=0,
    device="cpu",
    run_name="test",
    config=CoevoSGConfig(),
    env=env,
)

attacker_agent_coevosg = CoevoSGAttackerAgent(
    num_nodes=num_nodes,
    player_type=1,
    device="cpu",
    run_name="test",
    config=CoevoSGConfig(),
    env=env,
)

defender_agent_coevosg.eval()
attacker_agent_coevosg.eval()
defender_agent_coevosg.load("saved_models/2025-06-09_08:45:52-full-coevosg-/defender/agent_0.pth")
attacker_agent_coevosg.load("saved_models/2025-06-09_08:45:52-full-coevosg-/attacker/agent_0.pth")

In [9]:
from algorithms.generic_policy import RandomAgent, GreedyOracleAgent

attacker_agent_random = RandomAgent(num_nodes=num_nodes, embedding_size=32, player_type=1, device="cuda:0", run_name="test")
attacker_greedy_oracle = GreedyOracleAgent(
    num_nodes=num_nodes, total_steps=env.num_steps, embedding_size=32, player_type=1, device="cuda:0", run_name="test", env_map=flipit_map
)

In [10]:
from utils import compare_agent_pairs

results = compare_agent_pairs(
    [
        (defender_agent_transformer, attacker_agent_transformer, "transformer"),
        (defender_agent_ffn, attacker_agent_ffn, "ffn"),
        (defender_agent_coevosg, attacker_agent_coevosg, "coevosg"),
    ],
    [
        (attacker_agent_random, "random"),
        (attacker_greedy_oracle, "greedy"),
    ],
    env,
    print_results=True,
)

Defender: transformer vs Attacker: transformer => Defender avg reward: 126.8858 (7.2249)
Defender: transformer vs Attacker: ffn => Defender avg reward: 130.8115 (4.6874)
Defender: transformer vs Attacker: coevosg => Defender avg reward: 140.3997
Defender: transformer vs Attacker: random => Defender avg reward: 141.7736
Defender: transformer vs Attacker: greedy => Defender avg reward: 133.1033
Defender: transformer => Avg reward: 134.5948 (6.3482)
Defender: ffn vs Attacker: transformer => Defender avg reward: 120.1021 (7.6354)
Defender: ffn vs Attacker: ffn => Defender avg reward: 130.6920 (3.9075)
Defender: ffn vs Attacker: coevosg => Defender avg reward: 128.4320
Defender: ffn vs Attacker: random => Defender avg reward: 143.0464
Defender: ffn vs Attacker: greedy => Defender avg reward: 121.3210
Defender: ffn => Avg reward: 128.7187 (9.1955)
Defender: coevosg vs Attacker: transformer => Defender avg reward: 119.3308 (3.2711)
Defender: coevosg vs Attacker: ffn => Defender avg reward: 12

In [8]:
results

{'transformer/transformer/avg': 36.41667556762695,
 'transformer/transformer/std': 2.404609203338623,
 'transformer/ffn/avg': 36.88743591308594,
 'transformer/ffn/std': 2.0703463554382324,
 'transformer/coevosg/avg': 35.601593017578125,
 'transformer/coevosg/std': None,
 'transformer/random/avg': 41.157005310058594,
 'transformer/random/std': None,
 'transformer/greedy/avg': 36.866844177246094,
 'transformer/greedy/std': None,
 'transformer/avg': 37.38591384887695,
 'transformer/std': 2.1715357303619385,
 'ffn/transformer/avg': 36.282413482666016,
 'ffn/transformer/std': 2.292834758758545,
 'ffn/ffn/avg': 38.51555633544922,
 'ffn/ffn/std': 1.8750312328338623,
 'ffn/coevosg/avg': 33.64737319946289,
 'ffn/coevosg/std': None,
 'ffn/random/avg': 39.295902252197266,
 'ffn/random/std': None,
 'ffn/greedy/avg': 39.965843200683594,
 'ffn/greedy/std': None,
 'ffn/avg': 37.541419982910156,
 'ffn/std': 2.58135986328125,
 'coevosg/transformer/avg': 35.037418365478516,
 'coevosg/transformer/std': 2

In [7]:
from algorithms.generic_policy import CombinedPolicy

combined = CombinedPolicy(defender_agent_transformer, attacker_greedy_oracle)
output = combined.evaluate(env, 1000, current_player=0, add_logs=False)

In [8]:
output

tensor([[ 9.6579, -0.0557],
        [ 9.6579, -0.0557],
        [ 9.6579, -0.0557],
        ...,
        [ 9.6579, -0.0557],
        [ 9.6579, -0.0557],
        [ 9.6579, -0.0557]], device='cuda:0')