In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from target_assign_agent import RuleAgent, IQLAgent
from target_assign_env import raw_env, TaskAllocationEnv

In [None]:
# Evaluation environment used by the cells below (min_drones=20 with a custom
# threat-level configuration; threat_dist presumably gives the probability of
# each entry of possible_level — TODO confirm against raw_env).
env = raw_env(
    {
        "min_drones": 20,
        "possible_level": [0, 0.1, 0.4, 0.8],
        "threat_dist": [0.15, 0.25, 0.35, 0.25],
    }
)

# Read the network input/output sizes off a freshly reset environment.
env.reset()
state_dim, action_dim = (
    env.state().shape[0],
    env.action_space(env.agents[0]).n,
)

# Rule-based baseline and the trained IQL agent restored from its checkpoint file.
a20_ckpt = "checkpoint_A20.pth"
rule_agent = RuleAgent(num_threats=20)
a20_agent = IQLAgent(state_dim, action_dim)
a20_agent.load_checkpoint(a20_ckpt)

In [None]:
def _build_lost_env(env, num_drones):
    """Return a fresh ``TaskAllocationEnv`` with ``num_drones`` drones replaying ``env``'s threats.

    The new env is configured with ``env``'s ``attack_prob``, ``possible_level``
    and ``threat_dist``, reset, and then has ``threat_levels``, ``actual_threats``
    and ``num_actual_threat`` overwritten with copies of ``env``'s values.
    """
    lost_env = TaskAllocationEnv(
        dict(
            min_drones=num_drones,
            max_drones=num_drones,
            attack_prob=env.attack_prob,
            possible_level=env.possible_level,
            threat_dist=env.threat_dist,
        )
    )
    lost_env.reset()
    lost_env.threat_levels = env.threat_levels.copy()
    lost_env.actual_threats = env.actual_threats.copy()
    lost_env.num_actual_threat = env.num_actual_threat
    # Recompute so pre_allocation reflects the copied threats rather than
    # whatever reset() left behind.
    lost_env.pre_allocation = lost_env.calculate_pre_allocation()
    return lost_env


def _episode_metrics(prefix, reward, info, allocation):
    """Flatten one rollout's reward / info / allocation into ``<prefix>_*`` record fields."""
    return {
        f"{prefix}_assignments": allocation.copy(),
        f"{prefix}_reward": reward,
        # Shallow copy in case the env reuses and mutates its info dict later.
        f"{prefix}_info": dict(info),
        f"{prefix}_coverage": info["coverage"],
        f"{prefix}_threat_destroyed": info["threat_destroyed"],
        f"{prefix}_drone_lost": info["drone_lost"],
        f"{prefix}_kd_ratio": info["kd_ratio"],
        f"{prefix}_remaining_threat": info["num_remaining_threat"],
    }


def simulate_drone_lost(
    trained_agent,
    compare_agent=None,
    num_episodes=100,
    max_drone_lost=4,
    env=None,
    seed=None,
):
    """Compare an allocation under drone loss with a re-solved allocation on a smaller fleet.

    For each episode:
      1. ``env`` is reset and ``trained_agent`` picks an action for every agent
         in turn. Just before the last agent steps, ``drone_lost`` (drawn
         uniformly from ``0..max_drone_lost``) of the *other* agents are marked
         in ``env.truncations`` (how the env scores that is env-defined).
      2. A fresh env with exactly ``num_agents - drone_lost`` drones is built
         with the same threat scenario, and ``compare_agent`` allocates on it.
      3. Reward and ``info`` metrics of both rollouts are recorded.

    Args:
        trained_agent: object with ``predict(state, action_mask) -> action``;
            drives the original env.
        compare_agent: same interface; drives the reduced env. Defaults to
            ``trained_agent``.
        num_episodes: number of episodes to simulate.
        max_drone_lost: inclusive upper bound on drones lost per episode. Must
            be smaller than the number of agents, because the acting (last)
            drone is never lost.
        env: environment to reuse (it is reset and stepped every episode).
            Defaults to ``TaskAllocationEnv(dict(min_drones=20))``.
        seed: if not None, passed to ``np.random.seed`` before anything else
            runs. This makes the loss sampling reproducible. It covers the env
            only if the env draws from NumPy's global RNG (not verified here).

    Returns:
        list[dict]: one record per episode with ``episode``, ``threat_levels``,
        ``num_actual_threat``, ``num_drones_lost`` and ``original_*`` / ``new_*``
        fields (assignments, reward, info, coverage, threat_destroyed,
        drone_lost, kd_ratio, remaining_threat).

    Raises:
        ValueError: if ``max_drone_lost`` >= the number of agents after reset.
    """
    if seed is not None:
        np.random.seed(seed)
    if env is None:
        env = TaskAllocationEnv(dict(min_drones=20))

    compare_agent = trained_agent if compare_agent is None else compare_agent
    comparative_data = []

    for episode in range(num_episodes):
        env.reset()
        # Snapshot turn order and fleet size right after reset. Iterating the
        # live env.agents list is fragile if the env edits it while stepping,
        # and the fleet size must come from the env, not a fixed 20.
        agents = list(env.agents)
        num_agents = len(agents)
        if max_drone_lost >= num_agents:
            raise ValueError(
                f"max_drone_lost={max_drone_lost} must be smaller than the number "
                f"of drones ({num_agents}); the acting drone is never lost"
            )
        drone_lost = np.random.randint(0, max_drone_lost + 1)

        for turn, agent in enumerate(agents):
            action = trained_agent.predict(env.state(), env.action_mask(agent))
            if drone_lost > 0 and turn == num_agents - 1:
                # Sample directly from the other drones instead of
                # rejection-sampling, which could loop forever.
                others = [a for a in agents if a != agent]
                for idx in np.random.choice(len(others), drone_lost, replace=False):
                    env.truncations[others[idx]] = True
            env.step(action)

        _, original_reward, _, _, original_info = env.last()

        lost_env = _build_lost_env(env, num_agents - drone_lost)
        for agent in list(lost_env.agents):
            lost_env.step(
                compare_agent.predict(lost_env.state(), lost_env.action_mask(agent))
            )
        _, new_reward, _, _, new_info = lost_env.last()

        record = {
            "episode": episode,
            "threat_levels": env.threat_levels.copy(),
            "num_actual_threat": env.num_actual_threat,
            "num_drones_lost": drone_lost,
        }
        record.update(
            _episode_metrics("original", original_reward, original_info, env.actual_allocation)
        )
        record.update(
            _episode_metrics("new", new_reward, new_info, lost_env.actual_allocation)
        )
        comparative_data.append(record)

    return comparative_data

In [None]:
def _scatter_with_identity(ax, df, metric, label, title, ref_range=None):
    """Scatter ``original_<metric>`` (x) against ``new_<metric>`` (y) on ``ax``.

    A dashed red y = x reference line is drawn over ``ref_range`` = (lo, hi)
    when given. Otherwise it spans the finite range of both columns, so points
    above the line are episodes where the new value is larger.
    """
    original = df[f"original_{metric}"]
    new = df[f"new_{metric}"]
    ax.scatter(original, new)
    if ref_range is None:
        # Drop +/-inf values (if any) so the line's endpoints stay finite.
        values = pd.concat([original, new]).replace([np.inf, -np.inf], np.nan).dropna()
        if not values.empty:
            ref_range = (values.min(), values.max())
    if ref_range is not None:
        lo, hi = ref_range
        ax.plot([lo, hi], [lo, hi], "r--")
    ax.set_xlabel(f"Original {label}")
    ax.set_ylabel(f"New {label}")
    ax.set_title(title)


def analyze_compare_data(comparative_data):
    """Plot and summarise original-vs-new metrics from ``simulate_drone_lost``.

    Draws a 4x2 figure:
      * six scatter plots (reward, coverage, threats destroyed, drones lost,
        K/D ratio, remaining threat) of original vs. new values, each with a
        dashed y = x line;
      * reward improvement (new - original) against the number of drones lost;
      * a heatmap with rows = average threat level, average per-threat
        assignment count of the original allocation, and of the new one.
    It then prints the mean (new - original) difference of each metric.

    Args:
        comparative_data: non-empty list of per-episode dicts with keys
            ``num_drones_lost``, ``threat_levels`` (same length every episode),
            ``original_assignments`` / ``new_assignments`` (non-negative int
            arrays, passed to ``np.bincount``) and ``original_*`` / ``new_*``
            for reward, coverage, threat_destroyed, drone_lost, kd_ratio and
            remaining_threat.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: if ``comparative_data`` is empty.
    """
    if len(comparative_data) == 0:
        raise ValueError("comparative_data is empty; nothing to analyze")
    df = pd.DataFrame(comparative_data)

    fig, axs = plt.subplots(4, 2, figsize=(20, 30))
    fig.suptitle("Comparison of Original and New Allocation Strategies", fontsize=16)

    # Panels 1-6: original vs new per metric.
    # (grid position, column suffix, axis label, title, fixed y = x range or None)
    scatter_panels = [
        ((0, 0), "reward", "Reward", "Reward Comparison", None),
        ((0, 1), "coverage", "Coverage", "Coverage Comparison", (0, 1)),
        ((1, 0), "threat_destroyed", "Threats Destroyed", "Threat Destruction Comparison", None),
        ((1, 1), "drone_lost", "Drones Lost", "Drone Loss Comparison", None),
        ((2, 0), "kd_ratio", "K/D Ratio", "K/D Ratio Comparison", None),
        ((2, 1), "remaining_threat", "Remaining Threat", "Remaining Threat Comparison", None),
    ]
    for pos, metric, label, title, ref_range in scatter_panels:
        _scatter_with_identity(axs[pos], df, metric, label, title, ref_range)

    # Panel 7: reward difference as a function of the number of drones lost.
    improvement = df["new_reward"] - df["original_reward"]
    axs[3, 0].scatter(df["num_drones_lost"], improvement)
    axs[3, 0].set_xlabel("Number of Drones Lost")
    axs[3, 0].set_ylabel("Reward Improvement")
    axs[3, 0].set_title("Drone Loss vs Performance Improvement")

    # Panel 8: average threat level and average assignment count per threat.
    avg_threat_levels = np.mean(
        [episode["threat_levels"] for episode in comparative_data], axis=0
    )
    # Size the per-threat counts from the data (not a fixed 20) so all three
    # heatmap rows have the same length and vstack succeeds.
    num_threats = len(avg_threat_levels)
    avg_original_assignments = np.mean(
        [
            np.bincount(episode["original_assignments"], minlength=num_threats)
            for episode in comparative_data
        ],
        axis=0,
    )
    avg_new_assignments = np.mean(
        [
            np.bincount(episode["new_assignments"], minlength=num_threats)
            for episode in comparative_data
        ],
        axis=0,
    )
    assignment_data = np.vstack(
        (avg_threat_levels, avg_original_assignments, avg_new_assignments)
    )
    # NOTE: threat levels and assignment counts share one colour scale, so the
    # colouring is dominated by whichever row has larger values.
    sns.heatmap(
        assignment_data,
        ax=axs[3, 1],
        cmap="YlOrRd",
        annot=True,
        fmt=".2f",
        yticklabels=["Threat Level", "Original", "New"],
    )
    axs[3, 1].set_title("Average Threat Levels and Assignments")
    axs[3, 1].set_ylabel("Threat Level | Original | New")
    axs[3, 1].set_xlabel("Threat Position")

    plt.tight_layout()
    plt.show()

    print("\nAverage improvements:")
    summary_rows = [
        ("Reward", "reward"),
        ("Coverage", "coverage"),
        ("Threats destroyed", "threat_destroyed"),
        ("Drones lost", "drone_lost"),
        ("K/D ratio", "kd_ratio"),
        ("Remaining threat", "remaining_threat"),
    ]
    for label, metric in summary_rows:
        mean_diff = (df[f"new_{metric}"] - df[f"original_{metric}"]).mean()
        print(f"{label}: {mean_diff:.4f}")

In [None]:
# Head-to-head run on identical threat scenarios. With max_drone_lost=0 no
# drones are removed, so the "original_*" columns come from the trained IQL
# agent (a20_agent) and the "new_*" columns from the rule-based agent on a
# rebuilt env that replays the same threats.
SIM_SEED = 0
# Seed NumPy's global RNG so re-running this cell reproduces the episodes.
# This covers the env only if it samples from np.random (TODO confirm); the
# agents may use other RNGs (e.g. torch).
np.random.seed(SIM_SEED)
compare_data = simulate_drone_lost(
    a20_agent, rule_agent, num_episodes=1000, max_drone_lost=0, env=env
)
df = pd.DataFrame(compare_data)
df.describe()

In [None]:
# Plot the original-vs-new comparison and print the mean per-metric differences.
analyze_compare_data(comparative_data=compare_data)