In [None]:
import random
import os
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from target_assign_env import raw_env

In [None]:
class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one Q-value per action.

    Two hidden layers of 256 units with ReLU activations are followed by a
    linear output layer.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Number of Q-values produced.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 256
        # The attribute names are also the state_dict keys used by checkpoints,
        # so they stay fc1/fc2/fc3.
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_dim)

    def forward(self, x):
        """Return the raw (unactivated) Q-values for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class ReplayBuffer:
    """Bounded FIFO store of transition tuples with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest entry once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, initial_state, joint_action, reward, final_state, done):
        """Store one transition as a 5-tuple in argument order."""
        transition = (initial_state, joint_action, reward, final_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn without replacement.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored transitions.
        """
        stored = self.buffer
        return random.sample(stored, batch_size)

    def clear(self):
        """Remove every stored transition."""
        self.buffer.clear()

    def __len__(self):
        """Return the number of transitions currently stored."""
        stored = self.buffer
        return len(stored)


class IQLAgent:
    """Q-learning agent with an epsilon-greedy policy and a separate target network.

    Args:
        state_dim: Length of the state vector fed to the networks.
        action_dim: Number of discrete actions (network outputs).
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor used in the TD target.
        epsilon_start: Initial exploration probability.
        epsilon_end: Lower bound of the exploration probability.
        epsilon_decay: Multiplicative factor applied by ``update_epsilon``.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        lr=1e-5,
        gamma=0.99,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=0.995,
    ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        # Start the target network as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

    def save_checkpoint(self, episode, path="checkpoints"):
        """Write networks, optimizer state and epsilon to ``path``.

        The file is named ``checkpoint_episode_{episode}.pth``; the directory
        is created if needed.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)

        checkpoint = {
            "episode": episode,
            "q_network_state_dict": self.q_network.state_dict(),
            "target_network_state_dict": self.target_network.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "epsilon": self.epsilon,
        }

        checkpoint_path = os.path.join(path, f"checkpoint_episode_{episode}.pth")
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at episode {episode}")

    def load_checkpoint(self, checkpoint_path):
        """Restore the agent from ``checkpoint_path``.

        Returns:
            The episode number stored in the checkpoint, or ``None`` (after
            printing a message) when the file does not exist.
        """
        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint file not found: {checkpoint_path}")
            return None

        # The networks of this class are never moved off the CPU, so map the
        # tensors there; this also lets checkpoints written on a GPU machine load.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        self.q_network.load_state_dict(checkpoint["q_network_state_dict"])
        self.target_network.load_state_dict(checkpoint["target_network_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.epsilon = checkpoint["epsilon"]

        print(f"Checkpoint loaded from episode {checkpoint['episode']}")
        return checkpoint["episode"]

    def select_action(self, state, action_mask):
        """Pick an action epsilon-greedily among the actions allowed by ``action_mask``.

        Args:
            state: State vector accepted by ``torch.FloatTensor``.
            action_mask: Per-action validity flags; truthy entries are allowed.
                Boolean and integer (0/1) masks are both accepted.

        Returns:
            Index of the chosen action (a NumPy integer).
        """
        # Coerce to bool first: on an integer mask ``~`` is a bitwise NOT
        # (0 -> -1, 1 -> -2), which would be used as negative indices and
        # mask the wrong actions.
        valid = np.asarray(action_mask, dtype=bool)

        if random.random() > self.epsilon:
            with torch.no_grad():
                q_values = self.q_network(torch.FloatTensor(state)).numpy()
                # Invalid actions can never win the argmax.
                q_values[~valid] = -np.inf
                return np.argmax(q_values)
        else:
            valid_actions = np.where(valid)[0]
            return np.random.choice(valid_actions)

    def update(self, batch):
        """Run one gradient step on a batch of transitions.

        Args:
            batch: Iterable of ``(state, action, reward, next_state, done)`` tuples.

        Returns:
            The scalar MSE loss between current Q-values and TD targets.
        """
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(np.array(actions))
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(dones)

        # squeeze(1) removes only the gather dimension, so a batch of one keeps
        # shape (1,) and matches the target instead of collapsing to a scalar.
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            # Terminal transitions (done == 1) do not bootstrap.
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = F.mse_loss(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def update_epsilon(self):
        """Decay epsilon by ``epsilon_decay``, never going below ``epsilon_end``."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)


class RuleAgent:
    """Rule-based baseline that fills a planned allocation in index order.

    The agent keeps its own running count per index and returns the lowest
    index whose count is still below the planned amount. ``action_mask`` is
    accepted for interface compatibility but not consulted.

    Args:
        num_threats: Number of indices scanned when selecting an action.
    """

    def __init__(self, num_threats=20):
        self.max_threats = num_threats
        self.current_allocation = np.zeros(num_threats)
        self.pre_allocation = None
        self.index = 0

    def select_action(self, state, action_mask):
        """Return the next index to allocate, or ``None`` when none needs more.

        ``state`` is reshaped into three equal rows; only the middle row (the
        planned allocation) is used.
        """
        _, planned, _ = state.reshape([3, -1])

        # Start over when the plan differs from the tracked one, or when the
        # tracked counts already satisfy the plan.
        plan_changed = not np.array_equal(self.pre_allocation, planned)
        plan_fulfilled = np.array_equal(self.current_allocation, planned)
        if plan_changed or plan_fulfilled:
            self.reset(planned)

        while self.index < self.max_threats:
            slot = self.index
            needs_more = self.current_allocation[slot] < self.pre_allocation[slot]
            if not needs_more:
                self.index += 1
                continue
            self.current_allocation[slot] += 1
            return slot
        return None

    def reset(self, allocation):
        """Adopt ``allocation`` as the plan and clear the running counts."""
        self.index = 0
        self.current_allocation = np.zeros(self.max_threats)
        self.pre_allocation = allocation

In [None]:
def train_iql(env, num_episodes, batch_size=256, update_target_every=500, save_every=2000):
    """Train one shared IQLAgent on ``env`` and collect per-episode metrics.

    Every episode performs one step per entry of ``env.agents``. All of those
    transitions are stored with the reward that ``env.last()`` reports after
    the final step, and only the last transition carries the done flag. At
    most one gradient update runs per episode, once the replay buffer holds
    more than ``4 * batch_size`` transitions.

    Args:
        env: Environment exposing ``reset``, ``state``, ``agents``,
            ``action_space``, ``action_mask``, ``step`` and ``last``.
        num_episodes: Number of episodes to run.
        batch_size: Number of transitions sampled per gradient update.
        update_target_every: Episode interval between target-network syncs.
        save_every: Episode interval for the progress report; a checkpoint is
            written when the window's mean reward beats the best seen so far.

    Returns:
        ``(rl_agent, training_data)`` where ``training_data`` maps metric
        names to per-episode lists.

    Side effects:
        Writes ``training_data.npy`` and checkpoint files (through
        ``IQLAgent.save_checkpoint``), and prints progress.
    """
    env.reset()
    state_dim = env.state().shape[0]
    action_dim = env.action_space(env.agents[0]).n

    rl_agent = IQLAgent(state_dim, action_dim)
    replay_buffer = ReplayBuffer(100000)

    best_avg_reward = -np.inf
    training_data = {
        "rewards": [],
        "losses": [],
        "epsilons": [],
        "kd_ratio": [],
        "threat_destroyed": [],
        "threat_remain": [],
        "drone_lost": [],
    }
    # (key in the env's info dict, key in training_data)
    info_metrics = (
        ("kd_ratio", "kd_ratio"),
        ("threat_destroyed", "threat_destroyed"),
        ("remaining_threat", "threat_remain"),
        ("drone_lost", "drone_lost"),
    )

    for episode in range(num_episodes):
        env.reset()
        episode_reward = 0
        episode_loss = 0

        # One [state, action, next_state, done] record per agent step.
        transitions = []
        for agent in env.agents:
            agent_state = env.state()
            action = rl_agent.select_action(agent_state, env.action_mask(agent))
            env.step(action)
            transitions.append([agent_state, action, env.state(), False])

        _, reward, done, _, info = env.last()
        transitions[-1][3] = done
        for i in range(len(env.agents)):
            state, action, next_state, is_done = transitions[i]
            replay_buffer.push(state, action, reward, next_state, is_done)
        episode_reward += reward

        # Wait until the buffer is comfortably larger than one batch.
        if len(replay_buffer) > batch_size * 4:
            episode_loss += rl_agent.update(replay_buffer.sample(batch_size))
            rl_agent.update_epsilon()

        if episode % update_target_every == 0:
            rl_agent.update_target_network()

        training_data["rewards"].append(episode_reward)
        training_data["losses"].append(episode_loss)
        training_data["epsilons"].append(rl_agent.epsilon)
        for info_key, metric in info_metrics:
            training_data[metric].append(info[info_key])

        if (episode + 1) % save_every == 0:
            avg_reward = np.mean(training_data["rewards"][-save_every:])
            avg_loss = np.mean(training_data["losses"][-save_every:])
            print(
                f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Avg Loss: {avg_loss:.4f}, Epsilon: {rl_agent.epsilon:.2f}"
            )
            # Only keep checkpoints that improve on the best window so far.
            if avg_reward > best_avg_reward:
                best_avg_reward = avg_reward
                rl_agent.save_checkpoint(episode + 1)

    np.save("training_data.npy", training_data)
    rl_agent.save_checkpoint(num_episodes)

    return rl_agent, training_data

In [None]:
# Build the environment with min_drones=20 and train for one million episodes.
# train_iql returns a single agent (despite the plural name) and the dict of
# per-episode metrics; both names are reused by later cells.
env = raw_env(dict(min_drones=20))
trained_agents, training_data = train_iql(env, num_episodes=int(1e6))

### Data Analysis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def plot_moving_average(df, column, window=100, figsize=(12, 6)):
    """Plot a metric and its rolling mean against the episode number.

    The figure is saved as ``{column}_moving_average.png`` and then closed.

    Args:
        df: DataFrame with an ``"episode"`` column and the metric column.
        column: Name of the metric column to plot.
        window: Rolling-window length in episodes.
        figsize: Figure size passed to matplotlib.
    """
    episodes = df["episode"]
    raw_values = df[column]
    smoothed = raw_values.rolling(window=window).mean()

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(episodes, raw_values, alpha=0.3, label="Raw")
    ax.plot(episodes, smoothed, label=f"{window}-episode Moving Average")
    ax.set_title(f"{column.capitalize()} over Episodes")
    ax.set_xlabel("Episode")
    ax.set_ylabel(column.capitalize())
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(f"{column}_moving_average.png")
    plt.close(fig)


def plot_multi_metric_comparison(df, metrics, figsize=(12, 6), normalize=True):
    """Plot several metrics against the episode number on one set of axes.

    The figure is saved as ``{metrics joined by "_"}_comparison.png`` and then
    closed.

    Args:
        df: DataFrame with an ``"episode"`` column and one column per metric.
        metrics: Names of the metric columns to plot.
        figsize: Figure size passed to matplotlib.
        normalize: When true, each metric is min-max scaled to [0, 1] so that
            metrics with different ranges can share one axis.
    """
    fig, ax = plt.subplots(figsize=figsize)
    for metric in metrics:
        values = df[metric]
        if normalize:
            low = values.min()
            span = values.max() - low
            # A constant metric has zero span; plot it as a flat zero line
            # instead of dividing 0 by 0 (which yields an all-NaN, invisible line).
            values = (values - low) / span if span != 0 else values - low
        ax.plot(df["episode"], values, label=metric)

    # Only claim "Normalized" in the labels when the data really is normalized.
    ax.set_title("Normalized Metrics Comparison" if normalize else "Metrics Comparison")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Normalized Value" if normalize else "Value")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    metrics_str = "_".join(metrics)
    fig.savefig(f"{metrics_str}_comparison.png")
    plt.close(fig)


def plot_correlation_heatmap(df, figsize=(10, 8)):
    """Draw an annotated heatmap of the pairwise correlations of ``df``'s columns.

    The figure is saved as ``correlation_heatmap.png`` and then closed.

    Args:
        df: DataFrame whose column correlations (``df.corr()``) are plotted.
        figsize: Figure size passed to matplotlib.
    """
    correlations = df.corr()

    fig, ax = plt.subplots(figsize=figsize)
    # Pin the colour scale to the full correlation range [-1, 1], centred on 0.
    sns.heatmap(
        correlations,
        annot=True,
        cmap="coolwarm",
        vmin=-1,
        vmax=1,
        center=0,
        ax=ax,
    )
    ax.set_title("Correlation Heatmap of Training Metrics")
    fig.tight_layout()
    fig.savefig("correlation_heatmap.png")
    plt.close(fig)


def analyze_training_data(data, normalize=True):
    """Run the full analysis of the training metrics and save every chart.

    Args:
        data: Mapping of metric name to per-episode values, convertible with
            ``pd.DataFrame``.
        normalize: Forwarded to ``plot_multi_metric_comparison``.

    Returns:
        The metrics DataFrame with an added 1-based ``"episode"`` column.
    """
    df = pd.DataFrame(data)
    df["episode"] = range(1, len(df) + 1)

    tracked_metrics = ["rewards", "losses", "kd_ratio"]

    # Moving-average chart for each tracked metric.
    for metric in tracked_metrics:
        plot_moving_average(df, metric)

    # All tracked metrics on one chart.
    plot_multi_metric_comparison(df, tracked_metrics, normalize=normalize)

    # Correlation heatmap over every column.
    plot_correlation_heatmap(df)

    print("分析完成，所有图表已保存。")

    return df

In [None]:
# Save all training charts; df holds one row per training episode.
df = analyze_training_data(training_data)
# Average every column over consecutive blocks of 1000 episodes.
df_mean = df.groupby(np.arange(len(df))//1000).mean()
# Summary statistics of the block averages (displayed as the cell output).
df_mean.describe()

In [None]:
# Compare destroyed vs. remaining threats on the 1000-episode block averages,
# in their original units (no min-max scaling).
plot_multi_metric_comparison(df_mean, ["threat_destroyed", "threat_remain"], normalize=False)

### Action Analysis

In [None]:
def inference_and_collect_data(env, trained_agent, num_episodes=100):
    """Roll out ``trained_agent`` on ``env`` and record one result dict per episode.

    Args:
        env: Environment exposing ``reset``, ``state``, ``agents``,
            ``action_mask``, ``step``, ``last`` and the attributes
            ``threat_levels``, ``current_allocation`` and ``actual_threats``.
        trained_agent: Any object with ``select_action(state, action_mask)``
            (both ``IQLAgent`` and ``RuleAgent`` are used in this notebook).
        num_episodes: Number of episodes to roll out.

    Returns:
        A list of per-episode dicts with the initial state, the chosen actions,
        a threat-level snapshot per step, coverage, final reward and the
        outcome counters reported in the environment's info dict.
    """
    # NOTE(review): IQLAgent.select_action is epsilon-greedy, so rollouts still
    # contain random actions unless the agent's epsilon is 0 -- confirm that
    # this is intended for evaluation.
    collected_data = []

    for _ in range(num_episodes):
        env.reset()
        episode_data = {
            "initial_state": env.state(),
            "assignments": [],
            "threat_levels": [],
            "coverage": [],
            "final_reward": 0,
        }

        for agent in env.agents:
            # Feed the agent env.state(), the same input used during training
            # and in simulate_drone_lost, rather than the env.last() observation.
            state = env.state()
            action_mask = env.action_mask(agent)
            action = trained_agent.select_action(state, action_mask)
            env.step(action)

            # Collect per-step data; copy the threat levels so later in-place
            # changes inside the environment cannot alter what was recorded.
            episode_data["assignments"].append(action)
            episode_data["threat_levels"].append(np.array(env.threat_levels, copy=True))

        _, reward, _, _, info = env.last()
        episode_data["final_reward"] = reward
        # Fraction of positive-level threats that received at least one assignment.
        episode_data["coverage"] = np.sum(env.current_allocation > 0) / np.sum(env.threat_levels > 0)
        episode_data["threat_destroyed"] = info["threat_destroyed"]
        episode_data["drone_lost"] = info["drone_lost"]
        episode_data["remaining_threat"] = info["remaining_threat"]
        episode_data["actual_threat"] = env.actual_threats

        collected_data.append(episode_data)

    return collected_data

def analyze_assignment_strategy(collected_data):
    """Visualise the rollouts produced by ``inference_and_collect_data``.

    Shows, in order: histograms of final reward and coverage, a box plot of
    drones assigned per target grouped by threat level, a scatter of threats
    destroyed against drones lost, a histogram of remaining threat, and a
    heatmap of average threat level and average assignments per target.

    Args:
        collected_data: List of per-episode dicts with the keys
            ``final_reward``, ``coverage``, ``threat_levels``, ``assignments``,
            ``threat_destroyed``, ``drone_lost`` and ``remaining_threat``.
    """
    # 1. Reward and coverage distributions
    rewards = [episode['final_reward'] for episode in collected_data]
    coverages = [episode['coverage'] for episode in collected_data]

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.hist(rewards, bins=20)
    plt.title('Distribution of Final Rewards')
    plt.xlabel('Reward')
    plt.ylabel('Frequency')

    plt.subplot(1, 2, 2)
    plt.hist(coverages, bins=20)
    plt.title('Distribution of Coverage')
    plt.xlabel('Coverage')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.show()

    print(f"Average Reward: {np.mean(rewards):.2f}")
    print(f"Average Coverage: {np.mean(coverages):.2f}")

    # Per-episode threat levels and number of drones assigned to each target.
    # The first recorded snapshot is used (threat levels are assumed constant
    # within an episode). Each assignment is a target index, so counting the
    # indices gives one value per target, aligned with the threat levels.
    per_episode_levels = []
    per_episode_counts = []
    for episode in collected_data:
        levels = np.asarray(episode['threat_levels'][0])
        counts = np.bincount(
            np.asarray(episode['assignments'], dtype=int), minlength=len(levels)
        )
        per_episode_levels.append(levels)
        per_episode_counts.append(counts)

    # 2. Relationship between threat level and assignments.
    # Pair each target's threat level with the number of drones sent to it;
    # pairing levels positionally with the raw per-drone actions would relate
    # target k to drone k and breaks when the two counts differ.
    all_threat_levels = []
    all_assignments = []
    for levels, counts in zip(per_episode_levels, per_episode_counts):
        all_threat_levels.extend(levels)
        all_assignments.extend(counts[:len(levels)])

    df = pd.DataFrame({'Threat Level': all_threat_levels, 'Drones Assigned': all_assignments})

    plt.figure(figsize=(10, 6))
    sns.boxplot(x='Threat Level', y='Drones Assigned', data=df)
    plt.title('Distribution of Assignments by Threat Level')
    plt.show()

    # 3. Threats destroyed vs. drones lost
    threats_destroyed = [episode['threat_destroyed'] for episode in collected_data]
    drones_lost = [episode['drone_lost'] for episode in collected_data]

    plt.figure(figsize=(10, 6))
    plt.scatter(drones_lost, threats_destroyed)
    plt.title('Threats Destroyed vs Drones Lost')
    plt.xlabel('Drones Lost')
    plt.ylabel('Threats Destroyed')
    plt.show()

    # 4. Remaining threat distribution
    remaining_threats = [episode['remaining_threat'] for episode in collected_data]

    plt.figure(figsize=(10, 6))
    plt.hist(remaining_threats, bins=20)
    plt.title('Distribution of Remaining Threats')
    plt.xlabel('Remaining Threat')
    plt.ylabel('Frequency')
    plt.show()

    # 5. Assignment strategy heatmap (target count taken from the data rather
    # than a hard-coded 20)
    avg_assignments = np.mean(per_episode_counts, axis=0)
    avg_threat_levels = np.mean(per_episode_levels, axis=0)

    plt.figure(figsize=(12, 6))
    sns.heatmap(np.vstack((avg_threat_levels, avg_assignments)),
                cmap='YlOrRd', annot=True, fmt='.2f')
    plt.title('Average Threat Levels and Assignments')
    plt.ylabel('Threat Level | Assignments')
    plt.xlabel('Threat Position')
    plt.show()

In [None]:
# Roll out the agent returned by train_iql for 1000 episodes on the training
# environment, then plot how it distributes its assignments.
inference_data = inference_and_collect_data(env, trained_agents, num_episodes=1000)
analyze_assignment_strategy(inference_data)

### Test

In [None]:
def simulate_drone_lost(trained_agent, num_episodes=100, max_drone_lost=4):
    """Compare losing drones mid-episode with re-planning for the smaller fleet.

    For each episode a random number of drones (0..``max_drone_lost``) is
    marked as truncated just before the last drone acts in the original
    environment. A second environment with only the surviving number of
    drones and the same threats is then solved from scratch by the same agent.

    Args:
        trained_agent: Object with ``select_action(state, action_mask)``.
        num_episodes: Number of paired episodes to simulate.
        max_drone_lost: Upper bound (inclusive) on drones lost per episode.

    Returns:
        A list of per-episode dicts holding the threat levels, the number of
        drones lost and, for both the original and the re-planned run, the
        assignments, reward, coverage and info-dict counters.
    """
    env = raw_env(dict(min_drones=20))
    comparative_data = []

    for episode in range(num_episodes):
        env.reset()
        # Take the fleet size from the environment instead of assuming 20.
        total_drones = len(env.agents)
        drone_lost = np.random.randint(0, max_drone_lost + 1)
        # The drone acting last is never lost, so at most total_drones - 1 can
        # be; without this cap the selection below could not be satisfied.
        drone_lost = max(0, min(drone_lost, total_drones - 1))
        original_assignments = []

        for i, agent in enumerate(env.agents):
            state = env.state()
            action_mask = env.action_mask(agent)
            action = trained_agent.select_action(state, action_mask)
            if drone_lost > 0 and i == len(env.agents) - 1:
                # Draw the lost drones from everyone except the acting drone.
                # This is the same distribution as rejection sampling over all
                # drones, but it always terminates.
                other_drones = [other for other in env.agents if other != agent]
                lost_drones = np.random.choice(other_drones, drone_lost, replace=False)
                for drone in lost_drones:
                    env.truncations[drone] = True
            env.step(action)
            original_assignments.append(action)

        _, original_reward, _, __, original_info = env.last()

        # Environment with only the surviving drones, facing the same threats.
        num_drones = total_drones - drone_lost
        lost_env = raw_env(dict(min_drones=num_drones, max_drones=num_drones))
        lost_env.reset()
        lost_env.threat_levels = env.threat_levels
        lost_env.actual_threats = env.actual_threats
        lost_env.num_actual_threat = env.num_actual_threat
        # Recompute the pre-allocation after overriding the threat data.
        lost_env.pre_allocation = lost_env.calculate_pre_allocation()

        new_assignments = []
        for agent in lost_env.agents:
            state = lost_env.state()
            action_mask = lost_env.action_mask(agent)
            action = trained_agent.select_action(state, action_mask)
            lost_env.step(action)
            new_assignments.append(action)

        _, new_reward, _, __, new_info = lost_env.last()

        episode_data = {
            "episode": episode,
            "threat_levels": env.threat_levels,
            "num_drones_lost": drone_lost,
            "original_assignments": original_assignments,
            "original_reward": original_reward,
            "original_coverage": np.sum(env.current_allocation > 0) / np.sum(env.threat_levels > 0),
            "original_threat_destroyed": original_info["threat_destroyed"],
            "original_drone_lost": original_info["drone_lost"],
            "original_kd_ratio": original_info["kd_ratio"],
            "original_remaining_threat": original_info["remaining_threat"],
            "new_assignments": new_assignments,
            "new_reward": new_reward,
            "new_coverage": np.sum(lost_env.current_allocation > 0) / np.sum(lost_env.threat_levels > 0),
            "new_threat_destroyed": new_info["threat_destroyed"],
            "new_drone_lost": new_info["drone_lost"],
            "new_kd_ratio": new_info["kd_ratio"],
            "new_remaining_threat": new_info["remaining_threat"],
        }
        comparative_data.append(episode_data)

    return comparative_data


def analyze_compare_data(comparative_data):
    """Plot and print the original-vs-re-planned comparison from ``simulate_drone_lost``.

    Draws a 4x2 grid: six scatter plots (original value on x, new value on y,
    with a dashed y = x reference line), a scatter of reward change against
    drones lost, and a heatmap of average threat levels and assignments.
    Prints the improvement rate and the mean change of each metric.

    Args:
        comparative_data: List of per-episode dicts with ``original_*`` and
            ``new_*`` entries, ``num_drones_lost`` and ``threat_levels``.
    """
    # Convert the records to a DataFrame.
    df = pd.DataFrame(comparative_data)

    # Share of episodes in which re-planning gave a strictly higher reward.
    improvements = (df["new_reward"] > df["original_reward"]).mean()
    print(f"Improvement rate: {improvements:.2%}")

    # 4x2 grid of subplots.
    fig, axs = plt.subplots(4, 2, figsize=(20, 30))
    fig.suptitle("Comparison of Original and New Allocation Strategies", fontsize=16)

    # Panels 1-6: (grid position, metric suffix, x label, y label, title,
    # fixed reference-line range or None to span the original values).
    scatter_panels = [
        ((0, 0), "reward", "Original Reward", "New Reward", "Reward Comparison", None),
        ((0, 1), "coverage", "Original Coverage", "New Coverage", "Coverage Comparison", (0, 1)),
        (
            (1, 0),
            "threat_destroyed",
            "Original Threats Destroyed",
            "New Threats Destroyed",
            "Threat Destruction Comparison",
            None,
        ),
        ((1, 1), "drone_lost", "Original Drones Lost", "New Drones Lost", "Drone Loss Comparison", None),
        ((2, 0), "kd_ratio", "Original K/D Ratio", "New K/D Ratio", "K/D Ratio Comparison", None),
        (
            (2, 1),
            "remaining_threat",
            "Original Remaining Threat",
            "New Remaining Threat",
            "Remaining Threat Comparison",
            None,
        ),
    ]
    for position, metric, x_label, y_label, title, fixed_range in scatter_panels:
        ax = axs[position]
        original = df["original_" + metric]
        ax.scatter(original, df["new_" + metric])
        if fixed_range is None:
            low, high = original.min(), original.max()
        else:
            low, high = fixed_range
        # Points above the dashed y = x line have a larger new value.
        ax.plot([low, high], [low, high], "r--")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)

    # 7. Number of drones lost vs. reward change
    improvement = df["new_reward"] - df["original_reward"]
    axs[3, 0].scatter(df["num_drones_lost"], improvement)
    axs[3, 0].set_xlabel("Number of Drones Lost")
    axs[3, 0].set_ylabel("Reward Improvement")
    axs[3, 0].set_title("Drone Loss vs Performance Improvement")

    # 8. Heatmap of average threat levels and assignments. The number of
    # targets is taken from each episode's threat levels rather than a
    # hard-coded 20, so the rows always have matching widths.
    avg_original_assignments = np.mean(
        [
            np.bincount(episode["original_assignments"], minlength=len(episode["threat_levels"]))
            for episode in comparative_data
        ],
        axis=0,
    )
    avg_new_assignments = np.mean(
        [
            np.bincount(episode["new_assignments"], minlength=len(episode["threat_levels"]))
            for episode in comparative_data
        ],
        axis=0,
    )
    avg_threat_levels = np.mean(
        [episode["threat_levels"] for episode in comparative_data], axis=0
    )

    assignment_data = np.vstack(
        (avg_threat_levels, avg_original_assignments, avg_new_assignments)
    )
    sns.heatmap(assignment_data, ax=axs[3, 1], cmap="YlOrRd", annot=True, fmt=".2f")
    axs[3, 1].set_title("Average Threat Levels and Assignments")
    axs[3, 1].set_ylabel("Threat Level | Original | New")
    axs[3, 1].set_xlabel("Threat Position")

    plt.tight_layout()
    plt.show()

    # Additional statistics: mean (new - original) for every metric.
    print("\nAverage improvements:")
    summary_rows = [
        ("Reward", "reward"),
        ("Coverage", "coverage"),
        ("Threats destroyed", "threat_destroyed"),
        ("Drones lost", "drone_lost"),
        ("K/D ratio", "kd_ratio"),
        ("Remaining threat", "remaining_threat"),
    ]
    for label, metric in summary_rows:
        mean_change = (df["new_" + metric] - df["original_" + metric]).mean()
        print(f"{label}: {mean_change:.4f}")

In [None]:
# Load a trained agent: size a fresh IQLAgent from the environment, then
# restore its weights from a checkpoint file.
env = raw_env(dict(min_drones=20))
env.reset()
state_dim = env.state().shape[0]
action_dim = env.action_space(env.agents[0]).n
trained_agent = IQLAgent(state_dim, action_dim)

# NOTE(review): save_checkpoint writes "checkpoints/checkpoint_episode_<N>.pth";
# confirm that "checkpoint.pth" is the intended file.
checkpoint_path = "checkpoint.pth"
loaded_episode = trained_agent.load_checkpoint(checkpoint_path)
# load_checkpoint only prints and returns None when the file is missing; stop
# here so the cells below never evaluate a randomly initialised network.
if loaded_episode is None:
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
loaded_episode

In [None]:
# Evaluate the loaded IQL agent for 1000 episodes on a fresh min_drones=20
# environment and display summary statistics of the per-episode results.
env_20 = raw_env(dict(min_drones=20))
test_data = inference_and_collect_data(env_20, trained_agent, num_episodes=1000)
df = pd.DataFrame(test_data)
df.describe()

In [None]:
# Baseline: run the rule-based agent through the same evaluation on env_20.
# test_data and df are overwritten with the baseline's results.
rule_agent = RuleAgent(num_threats=20)
test_data = inference_and_collect_data(env_20, rule_agent, num_episodes=1000)
df = pd.DataFrame(test_data)
df.describe()

In [None]:
# Simulate 1000 paired episodes in which up to 5 drones are lost, and display
# summary statistics of the comparison records.
compare_data = simulate_drone_lost(trained_agent, num_episodes=1000, max_drone_lost=5)
df = pd.DataFrame(compare_data)
df.describe()

In [None]:
# Plot and print the original-vs-re-planned comparison for the simulated losses.
analyze_compare_data(compare_data)