In [None]:
from matplotlib import pyplot as plt
from mars.utils import merge_bounding_box
import os
import numpy as np
from mars.env import SearchEnv
from mars.detector import YoloV11Detector
from mars.AgentQTableTrainer import DotaRawDataset
from mars.agent import RLQtableAgent


In [None]:

# --- Training configuration -------------------------------------------------
episodes = 8
log_interval = 1

# Q-table agent, constructed with both `training` and `load` switched on.
agent = RLQtableAgent(load=True, training=True)

# Image / label folders handed to the dataset. A relative '../data/train'
# root with the same 'images' and 'labelTxt' sub-folders is the alternative
# layout for machines where the absolute root below does not exist.
train_dataset = DotaRawDataset(
    label_dir=os.path.join('\\Data\\train', 'labelTxt'),
    image_dir=os.path.join('\\Data\\train', 'images'),
)

# Search environment with the YOLO detector attached.
env = SearchEnv()
env.set_detector(YoloV11Detector())

# --- Metric histories used for plotting -------------------------------------
episode_rewards, episode_vehicles_found, episode_steps = [], [], []
moving_avg_rewards, moving_avg_found = [], []

# Main training loop.
#
# For every episode, each (image, target) pair of the training set is loaded
# into the environment and the agent acts on it for at most
# `max_steps_per_image` steps, updating itself after every step. Per-episode
# totals (reward, merged boxes found, executed steps) are appended to the
# metric lists, a moving average over the last `ma_window` episodes is
# recorded, and the agent is saved to 'qtable-<episode>.pkl'.
max_steps_per_image = 10  # upper bound on actions taken on a single image
ma_window = 100           # window of the moving averages used for trend curves

for episode in range(episodes):
    vehicles_found_this_episode = 0
    total_reward = 0
    steps_this_episode = 0
    image_steps = []  # number of actions executed on each image this episode
    for image, target in train_dataset:
        env.set_image(image)
        env.set_all_target(target)
        status = env.reset()
        image_reward = 0
        all_obbs = []
        # Count executed actions explicitly: the loop index is one short of
        # the real count whenever all iterations run without a break.
        steps_count = 0
        for step in range(max_steps_per_image):
            # Select an action.
            action = agent.select_action(*status)
            # NOTE(review): any falsy action stops the search on this image.
            # If 0 is a valid action id this should test `is None` instead --
            # confirm against RLQtableAgent.select_action.
            if not action:
                break
            # Execute the action.
            next_status, reward, obbs, window = env.step(action)
            # Return value is unused; presumably merges `obbs` into
            # `all_obbs` in place -- TODO confirm in mars.utils.
            merge_bounding_box(all_obbs, obbs)
            # Learn from the transition.
            agent.update(status, action, reward, next_status)
            status = next_status
            image_reward += reward
            steps_count += 1
        # Fold this image's results into the episode totals.
        image_steps.append(steps_count)
        steps_this_episode += steps_count
        total_reward += image_reward
        vehicles_found_this_episode += len(all_obbs)

    # Exactly one entry per episode in every metric list, so that all curves
    # share the same x-axis when plotted.
    episode_rewards.append(total_reward)
    episode_vehicles_found.append(vehicles_found_this_episode)
    episode_steps.append(steps_this_episode)

    # Moving averages make the trend easier to see; slicing with a negative
    # start already copes with histories shorter than the window.
    moving_avg_r = np.mean(episode_rewards[-ma_window:])
    moving_avg_rewards.append(moving_avg_r)
    moving_avg_f = np.mean(episode_vehicles_found[-ma_window:])
    moving_avg_found.append(moving_avg_f)

    # Periodic progress log (every `log_interval` episodes).
    if episode % log_interval == 0:
        print(f"Episode {episode:4d}/{episodes} | "
              f"Reward: {total_reward:6.1f} | "
              f"Vehicles Found: {vehicles_found_this_episode:2d} | "
              f"Steps: {steps_this_episode:4d} | "
              f"Avg Reward (MA100): {moving_avg_r:6.1f} | "
              f"Avg Found (MA100): {moving_avg_f:4.1f}")
    # Checkpoint the agent after every episode.
    agent.save('qtable-{}.pkl'.format(episode))
# Plot the collected metrics once training has finished.
fig = plt.figure(figsize=(12, 10))

# Panel 1: vehicles found per episode, its moving average and a reference line.
found_ax = fig.add_subplot(2, 2, 1)
found_ax.plot(episode_vehicles_found, label='Per Episode', alpha=0.3)
found_ax.plot(moving_avg_found, label='Moving Avg (100)', linewidth=2)
found_ax.axhline(y=100, color='r', linestyle='--', label='True Count')
found_ax.set_xlabel('Episode')
found_ax.set_ylabel('Vehicles Found')
found_ax.set_title('Performance: Vehicles Found per Episode')
found_ax.legend()
found_ax.grid(True)

# Panel 2: total reward per episode and its moving average.
reward_ax = fig.add_subplot(2, 2, 2)
reward_ax.plot(episode_rewards, label='Per Episode', alpha=0.3)
reward_ax.plot(moving_avg_rewards, label='Moving Avg (100)', linewidth=2)
reward_ax.set_xlabel('Episode')
reward_ax.set_ylabel('Total Reward')
reward_ax.set_title('Performance: Total Reward per Episode')
reward_ax.legend()
reward_ax.grid(True)

# Panel 3: recorded step counts.
steps_ax = fig.add_subplot(2, 2, 3)
steps_ax.plot(episode_steps)
steps_ax.set_xlabel('Episode')
steps_ax.set_ylabel('Steps Taken')
steps_ax.set_title('Efficiency: Steps per Episode')
steps_ax.grid(True)

# ... further panels (e.g. loss) can be added here ...

# Tidy the layout, write the figure to disk, then display it.
fig.tight_layout()
plt.savefig('training_metrics.png')
plt.show()