In [None]:
import numpy as np

from target_assign_rl import IQLAgent, ReplayBuffer, raw_env

In [None]:
def _rollout_episode(env, rl_agent):
    """Reset ``env`` and let every agent act once, collecting one transition per agent.

    Returns:
        ``(states, actions, next_states, dones, reward, info)``. The first four are
        parallel lists with one entry per agent that acted; ``reward`` and ``info``
        come from ``env.last()`` after the final step. Only the last transition
        carries the episode's done flag; earlier ones are ``False``.
    """
    env.reset()
    states, actions, next_states = [], [], []
    for agent in env.agents:
        agent_state = env.state()
        states.append(agent_state)
        action = rl_agent.select_action(agent_state, env.action_mask(agent))
        actions.append(action)
        env.step(action)
        next_states.append(env.state())

    # env.last() yields a 5-tuple; the 3rd item is used as the done flag
    # (presumably PettingZoo-style termination; the 4th, likely truncation, is ignored).
    _, reward, done, _, info = env.last()
    dones = [False] * len(states)
    if dones:  # guard: an episode with no acting agents has no transition to mark
        dones[-1] = done
    return states, actions, next_states, dones, reward, info


def train_iql(
    env,
    num_episodes,
    batch_size=256,
    update_target_every=500,
    save_every=2000,
    buffer_capacity=100000,
    save_path="training_data.npy",
):
    """Train one ``IQLAgent`` that is shared by every agent of ``env``.

    Each episode resets the environment and lets every agent act once. One transition
    per agent goes into the replay buffer, and all of them are labelled with the reward
    observed after the final step.

    Args:
        env: Environment exposing ``reset``, ``state``, ``agents``, ``action_space``,
            ``action_mask``, ``step`` and ``last``.
        num_episodes: Number of episodes to run.
        batch_size: Mini-batch size. Learning (one update per episode) starts once the
            buffer holds more than ``4 * batch_size`` transitions.
        update_target_every: Sync the target network every this many episodes,
            starting at episode 0.
        save_every: Window length (in episodes) for progress logging. A checkpoint is
            saved whenever a window's mean reward beats every previous window's.
        buffer_capacity: Capacity passed to ``ReplayBuffer``.
        save_path: File the per-episode metrics dict is written to via ``np.save``.

    Returns:
        ``(rl_agent, training_data)``. ``training_data`` maps metric names to
        per-episode lists: rewards, losses, kd_ratio, threat_destroyed, threat_remain,
        drone_lost and epsilons.
    """
    env.reset()
    state_dim = env.state().shape[0]
    action_dim = env.action_space(env.agents[0]).n

    rl_agent = IQLAgent(state_dim, action_dim)
    replay_buffer = ReplayBuffer(buffer_capacity)

    best_avg_reward = -np.inf
    training_data = {
        "rewards": [],
        "losses": [],
        "kd_ratio": [],
        "threat_destroyed": [],
        "threat_remain": [],
        "drone_lost": [],
        "epsilons": [],
    }

    for episode in range(num_episodes):
        states, actions, next_states, dones, reward, info = _rollout_episode(env, rl_agent)

        # Pair transitions by position rather than re-reading len(env.agents): the
        # environment's agent list may have changed (e.g. been emptied) once the episode ended.
        for state, action, next_state, done in zip(states, actions, next_states, dones):
            replay_buffer.push(state, action, reward, next_state, done)

        episode_loss = 0
        # Warm-up: only learn once the buffer holds comfortably more than one batch.
        if len(replay_buffer) > batch_size * 4:
            batch = replay_buffer.sample(batch_size)
            episode_loss += rl_agent.update(batch)
            rl_agent.update_epsilon()

        if episode % update_target_every == 0:
            rl_agent.update_target_network()

        training_data["rewards"].append(reward)
        training_data["losses"].append(episode_loss)
        training_data["kd_ratio"].append(info["kd_ratio"])
        training_data["threat_destroyed"].append(info["threat_destroyed"])
        training_data["threat_remain"].append(info["num_remaining_threat"])
        training_data["drone_lost"].append(info["drone_lost"])
        training_data["epsilons"].append(rl_agent.epsilon)

        if (episode + 1) % save_every == 0:
            avg_reward = np.mean(training_data["rewards"][-save_every:])
            avg_loss = np.mean(training_data["losses"][-save_every:])
            print(
                f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Avg Loss: {avg_loss:.4f}, Epsilon: {rl_agent.epsilon:.2f}"
            )
            if avg_reward > best_avg_reward:
                best_avg_reward = avg_reward
                rl_agent.save_checkpoint(episode + 1)

    np.save(save_path, training_data)
    rl_agent.save_checkpoint(num_episodes)

    return rl_agent, training_data

In [None]:
# Scenario configuration for the target-assignment environment.
env = raw_env(
    {
        "min_drones": 20,
        "possible_level": [0, 0.05, 0.1, 0.5, 0.8],
        "threat_dist": [0.1, 0.3, 0.1, 0.35, 0.15],
        "attack_prob": 0.6,
    }
)

# Train for 1,000,000 episodes; this is likely the slowest cell of the notebook.
trained_agent, training_data = train_iql(env, num_episodes=1_000_000)

### Data Analysis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def plot_moving_average(df, column, window=1000, figsize=(12, 6)):
    """Save a line chart of ``column`` per episode with its rolling mean overlaid.

    Args:
        df: Table with an ``"episode"`` column and the metric ``column``.
        column: Name of the metric column to plot.
        window: Rolling-window length (in rows) of the moving average. The first
            ``window - 1`` smoothed points are NaN and are therefore not drawn.
        figsize: Matplotlib figure size in inches.

    The figure is written to ``<column>_moving_average.png`` and then closed.
    """
    episodes = df["episode"]
    values = df[column]
    smoothed = values.rolling(window=window).mean()
    pretty_name = column.capitalize()

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(episodes, values, alpha=0.3, label="Raw")
    ax.plot(episodes, smoothed, label=f"{window}-episode Moving Average")
    ax.set_title(f"{pretty_name} over Episodes")
    ax.set_xlabel("Episode")
    ax.set_ylabel(pretty_name)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(f"{column}_moving_average.png")
    plt.close(fig)


def _min_max_normalize(series):
    """Rescale ``series`` linearly onto [0, 1]; a constant series maps to all zeros."""
    shifted = series - series.min()
    span = series.max() - series.min()
    # A zero span would make every point 0/0 = NaN, silently removing the line from the plot.
    if span == 0:
        return shifted.astype(float)
    return shifted / span


def plot_multi_metric_comparison(df, metrics, figsize=(12, 6), normalize=True):
    """Save one chart overlaying several per-episode metrics.

    Args:
        df: Table with an ``"episode"`` column plus every column named in ``metrics``.
        metrics: Column names to plot, one line each.
        figsize: Matplotlib figure size in inches.
        normalize: If True, min-max scale each metric to [0, 1] so metrics with
            different units share one axis (a constant metric is drawn at 0).

    The figure is written to ``<metric1>_<metric2>_..._comparison.png`` and then closed.
    """
    scale_label = "Normalized" if normalize else "Raw"
    fig, ax = plt.subplots(figsize=figsize)
    for metric in metrics:
        values = _min_max_normalize(df[metric]) if normalize else df[metric]
        ax.plot(df["episode"], values, label=metric)
    ax.set_title(f"{scale_label} Metrics Comparison")
    ax.set_xlabel("Episode")
    ax.set_ylabel(f"{scale_label} Value")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    metrics_str = "_".join(metrics)
    fig.savefig(f"{metrics_str}_comparison.png")
    plt.close(fig)


def plot_correlation_heatmap(df, figsize=(10, 8)):
    """Save an annotated heatmap of pairwise correlations between numeric columns of ``df``.

    Columns that are neither numeric nor boolean are dropped first. ``DataFrame.corr``
    raises on such columns in pandas >= 2.0, and the metric columns here are filled
    from environment ``info`` values whose types are not guaranteed.

    Args:
        df: Per-episode metrics table.
        figsize: Matplotlib figure size in inches.

    The figure is written to ``correlation_heatmap.png`` and then closed.
    """
    corr = df.select_dtypes(include=["number", "bool"]).corr()
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(corr, annot=True, cmap="coolwarm", vmin=-1, vmax=1, center=0, ax=ax)
    ax.set_title("Correlation Heatmap of Training Metrics")
    fig.tight_layout()
    fig.savefig("correlation_heatmap.png")
    plt.close(fig)


def analyze_training_data(data, normalize=True):
    """Build a per-episode DataFrame from ``data`` and save the standard training plots.

    Args:
        data: Mapping of metric name -> per-episode list (e.g. the dict returned by
            ``train_iql``). It must contain ``"rewards"``, ``"losses"`` and ``"kd_ratio"``.
        normalize: Forwarded to ``plot_multi_metric_comparison``.

    Returns:
        The DataFrame built from ``data``, with a 1-based ``"episode"`` column appended.
    """
    frame = pd.DataFrame(data).assign(episode=lambda f: range(1, len(f) + 1))

    headline_metrics = ["rewards", "losses", "kd_ratio"]
    # One moving-average chart per headline metric...
    for metric in headline_metrics:
        plot_moving_average(frame, metric)
    # ...then the same metrics overlaid on one chart, plus a correlation heatmap.
    plot_multi_metric_comparison(frame, headline_metrics, normalize=normalize)
    plot_correlation_heatmap(frame)

    print("分析完成，所有图表已保存。")  # "Analysis complete; all figures saved."

    return frame

In [None]:
# Reuse the in-memory results if the training cell has run in this kernel; otherwise
# (e.g. after a restart) reload the metrics dict written to training_data.npy.
# allow_pickle is needed because a dict was saved — only load files you produced yourself.
if "training_data" not in globals():
    training_data = np.load("training_data.npy", allow_pickle=True).item()

df = analyze_training_data(training_data)

# Average consecutive blocks of episodes to smooth out per-episode noise.
EPISODES_PER_BUCKET = 1000
df_mean = df.groupby(np.arange(len(df)) // EPISODES_PER_BUCKET).mean()
df_mean.describe()

In [None]:
# Plot the 1000-episode averages of the engagement outcomes on their raw scales.
plot_multi_metric_comparison(
    df_mean,
    metrics=["threat_destroyed", "threat_remain", "drone_lost"],
    normalize=False,
)