In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from einops import rearrange

# Matplotlib: Chinese labels. SimHei is tried first; the other CJK fonts and the final
# DejaVu Sans fallback (bundled with matplotlib) avoid a "findfont" warning on every text
# draw when SimHei is not installed. Without any CJK font installed, Chinese glyphs still
# will not render.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei', 'Microsoft YaHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # use an ASCII hyphen for minus signs (CJK fonts often lack U+2212)
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 10

# Device selection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备：{device}")
if device.type == 'cuda':
    print(f"GPU型号：{torch.cuda.get_device_name(0)}")

# Reproducibility: seed NumPy and PyTorch (torch.manual_seed also seeds every CUDA device).
# This fixes weight initialization and DataLoader shuffling order; some cuDNN kernels may
# still be nondeterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# ========== Core configuration ==========
# Root directory of the 3D skeleton .pkl files; override with the PKL_3D_DIR environment variable.
PKL_3D_DIR = os.environ.get('PKL_3D_DIR', '/root/second/3D')
TARGET_FRAMES = 32          # every sequence is resampled to this many frames
BATCH_SIZE = 8
EPOCHS = 30
ANOMALY_THRESHOLD = 0.8     # frames whose normalized reconstruction error is >= this count as anomalous
NUM_JOINTS = 71             # joints per skeleton frame
LEARNING_RATE = 1e-4

# Model hyper-parameters
D_MODEL = 64       # feature dimension of each joint token
NHEAD = 8          # attention heads (must divide D_MODEL)
NUM_LAYERS = 2     # ST-Transformer blocks in the encoder, and again in the decoder
DROPOUT = 0.2      # dropout rate

# Action groups (themes clustered by movement type)
ACTION_GROUPS = {
    "group_swing_squat": ["arm_swing_as", "body_swing_bs", "chest_expansion_ce", "squat_sq"],
    "group_drumming_shaking": ["drumming_dr", "maracas_forward_shaking_mfs", "maracas_shaking_ms", "sing_and_clap_sac"],
    "group_pose": ["frog_pose_fg", "tree_pose_tr", "twist_pose_tw"]
}
# Reverse lookup: theme name -> group name.
THEME_TO_GROUP = {theme: group_name
                  for group_name, themes in ACTION_GROUPS.items()
                  for theme in themes}

# Output directory for checkpoints and figures
OUTPUT_DIR = 'st_transformer_action_groups'
os.makedirs(OUTPUT_DIR, exist_ok=True)

使用设备：cuda
GPU型号：NVIDIA L20


In [None]:
class SkeletonDataset(Dataset):
    """Skeleton-sequence dataset with per-joint, per-channel z-score normalization.

    ``sequences`` is an array of shape (N, T, J, C). When ``is_train`` is True the
    normalization statistics are computed from ``sequences`` over the sample and time axes
    (any ``mean``/``std`` passed in are ignored); otherwise the supplied ``mean``/``std``
    (typically the training set's) are applied so train and test share one normalization.

    Raises:
        ValueError: if ``is_train`` is False and ``mean`` or ``std`` is missing, or if
            ``file_names`` and ``sequences`` differ in length.
    """

    def __init__(self, sequences, file_names, mean=None, std=None, is_train=True):
        if len(file_names) != len(sequences):
            raise ValueError(
                f"file_names ({len(file_names)}) and sequences ({len(sequences)}) must have the same length"
            )
        self.sequences = sequences  # (N, T, J, C), kept un-normalized for get_original_data
        self.file_names = file_names
        self.is_train = is_train

        if is_train:
            # Statistics over samples and frames -> shape (1, 1, J, C), broadcastable to every sequence.
            mean = np.mean(sequences, axis=(0, 1), keepdims=True)
            std = np.std(sequences, axis=(0, 1), keepdims=True) + 1e-6  # avoid division by zero
        elif mean is None or std is None:
            raise ValueError("mean and std must be provided when is_train=False")
        self.mean = mean
        self.std = std

        self.sequences_norm = (sequences - mean) / std

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # torch.tensor copies, so callers cannot mutate the cached normalized array.
        return torch.tensor(self.sequences_norm[idx], dtype=torch.float32)

    def get_original_data(self, idx):
        """Return ``(raw un-normalized sequence, file name)`` for sample ``idx``."""
        return self.sequences[idx], self.file_names[idx]

In [None]:
def _unify_frame_count(skeleton, target_frames):
    """Resample one (T, ...) skeleton sequence along its time axis to ``target_frames`` frames.

    Longer sequences are downsampled by taking evenly spaced frames (``np.linspace`` indices
    truncated to int). Shorter or equal-length sequences are linearly interpolated between
    neighbouring frames, which matches ``np.interp`` applied to every joint/channel track.
    Trailing dimensions are preserved. Requires T >= 1.
    """
    num_frames = skeleton.shape[0]
    if num_frames > target_frames:
        indices = np.linspace(0, num_frames - 1, target_frames, dtype=int)
        return skeleton[indices]

    positions = np.linspace(0, num_frames - 1, target_frames)
    lower = np.floor(positions).astype(int)
    upper = np.minimum(lower + 1, num_frames - 1)
    # Broadcast the per-frame interpolation weight over all trailing (joint, channel) axes.
    weight = (positions - lower).reshape((-1,) + (1,) * (skeleton.ndim - 1))
    start = skeleton[lower].astype(np.float64)
    return start + (skeleton[upper] - start) * weight


def load_data_by_action_group(root_dir, target_frames=32):
    """Load 3D skeleton samples, group them by action group, and build train/test loaders.

    Every ``.pkl`` file under ``root_dir`` is read. Files whose ``metadata['data_type']`` does
    not end in ``_3d``, or whose theme is not in THEME_TO_GROUP, are skipped. Each sample's
    ``skeleton_3d`` is resampled to ``target_frames`` frames. Each group is split 8:2 into
    train/test with ``random_state=42``; the test set reuses the training normalization.

    Returns:
        dict mapping group name -> {"train_dataset", "test_dataset", "train_loader", "test_loader"}.
        Groups with fewer than 2 samples cannot be split and are omitted.

    Raises:
        ValueError: if a sample's ``skeleton_3d`` is not of shape (T, NUM_JOINTS, 3).
    """
    group_data = {group_name: {"sequences": [], "file_names": []}
                  for group_name in ACTION_GROUPS}

    # Sorted so the sample order, and therefore the seeded train/test split, does not depend
    # on the filesystem's directory traversal order.
    pkl_files = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(root_dir)
        for file in files
        if file.endswith('.pkl')
    )

    print(f"找到 {len(pkl_files)} 个pkl文件，按动作组加载...")

    expected_tail = (NUM_JOINTS, 3)
    skipped_empty = 0
    for pkl_path in tqdm(pkl_files, desc="加载数据"):
        # SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
        with open(pkl_path, 'rb') as f:
            payload = pickle.load(f)

        # Keep only the 3D modality.
        metadata = payload['metadata']
        if metadata['data_type'].split('_')[-1] != '3d':
            continue

        theme_name = metadata['theme_name']
        group_name = THEME_TO_GROUP.get(theme_name)
        if group_name is None:
            continue

        for sample in payload['samples']:
            skeleton = np.asarray(sample['skeleton_3d'])  # expected (T, J, C)
            if skeleton.ndim >= 1 and skeleton.shape[0] == 0:
                skipped_empty += 1  # an empty sequence cannot be resampled
                continue
            if skeleton.shape[1:] != expected_tail:
                raise ValueError(
                    f"{pkl_path}: expected skeleton_3d of shape (T, {NUM_JOINTS}, 3), got {skeleton.shape}"
                )
            group_data[group_name]["sequences"].append(_unify_frame_count(skeleton, target_frames))
            group_data[group_name]["file_names"].append(f"{os.path.basename(pkl_path)}_{theme_name}")

    if skipped_empty:
        print(f"跳过 {skipped_empty} 个空序列样本")

    # Build datasets and DataLoaders per group.
    group_datasets = {}
    for group_name, bundle in group_data.items():
        sequences = np.array(bundle["sequences"], dtype=np.float32)
        file_names = bundle["file_names"]
        print(f"\n{group_name}：样本数={len(sequences)}")
        if len(sequences) < 2:
            # train_test_split needs at least one sample on each side of the split.
            print(f"{group_name}：样本不足（<2），跳过该组")
            continue

        # In-group 8:2 train/test split
        train_seq, test_seq, train_files, test_files = train_test_split(
            sequences, file_names, test_size=0.2, random_state=42, shuffle=True
        )

        train_dataset = SkeletonDataset(train_seq, train_files, is_train=True)
        test_dataset = SkeletonDataset(test_seq, test_files,
                                       mean=train_dataset.mean, std=train_dataset.std, is_train=False)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

        group_datasets[group_name] = {
            "train_dataset": train_dataset,
            "test_dataset": test_dataset,
            "train_loader": train_loader,
            "test_loader": test_loader
        }

    return group_datasets

# ========== Load every action group ==========
group_datasets = load_data_by_action_group(root_dir=PKL_3D_DIR, target_frames=TARGET_FRAMES)

找到 11 个pkl文件，按动作组加载...


加载数据: 100%|██████████| 11/11 [00:00<00:00, 17.59it/s]


group_swing_squat：样本数=440

group_drumming_shaking：样本数=891

group_pose：样本数=362





In [None]:
class SpatialAttention(nn.Module):
    """Post-norm Transformer layer that attends across joints within each frame.

    Input and output shape: (B, T, J, d_model). Each frame is treated as an independent
    sequence of J joint tokens, so no information flows between frames in this layer.
    """

    def __init__(self, d_model, nhead=8, dropout=0.1):
        super().__init__()
        hidden = d_model * 4
        # Keep this creation order: it fixes the state_dict keys and the order in which
        # parameters draw from the random number generator at initialization.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        batch, frames, joints, dim = x.shape
        # (B, T, J, d) -> (B*T, J, d): fold time into the batch axis.
        tokens = x.reshape(batch * frames, joints, dim)

        attended, _ = self.self_attn(tokens, tokens, tokens)
        tokens = self.norm1(tokens + self.dropout1(attended))          # attention sub-layer
        tokens = self.norm2(tokens + self.dropout2(self.ffn(tokens)))  # feed-forward sub-layer

        return tokens.reshape(batch, frames, joints, dim)

class TemporalAttention(nn.Module):
    """Post-norm Transformer layer that attends across frames for each joint separately.

    Input and output shape: (B, T, J, d_model). Each joint's trajectory is an independent
    sequence of T frame tokens, so no information flows between joints in this layer.
    """

    def __init__(self, d_model, nhead=8, dropout=0.1):
        super().__init__()
        hidden = d_model * 4
        # Keep this creation order: it fixes the state_dict keys and the order in which
        # parameters draw from the random number generator at initialization.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        batch, frames, joints, dim = x.shape
        # (B, T, J, d) -> (B, J, T, d) -> (B*J, T, d): put joints on the batch axis.
        tokens = x.permute(0, 2, 1, 3).reshape(batch * joints, frames, dim)

        attended, _ = self.self_attn(tokens, tokens, tokens)
        tokens = self.norm1(tokens + self.dropout1(attended))          # attention sub-layer
        tokens = self.norm2(tokens + self.dropout2(self.ffn(tokens)))  # feed-forward sub-layer

        # (B*J, T, d) -> (B, J, T, d) -> (B, T, J, d)
        return tokens.reshape(batch, joints, frames, dim).permute(0, 2, 1, 3)

class STTransformerBlock(nn.Module):
    """Spatio-temporal block: spatial attention, then temporal attention, each added back
    to the running input through a linear projection.

    Input and output shape: (B, T, J, d_model).
    """

    def __init__(self, d_model, nhead=8, dropout=0.1):
        super().__init__()
        self.spatial_attn = SpatialAttention(d_model, nhead, dropout)
        self.temporal_attn = TemporalAttention(d_model, nhead, dropout)
        # NOTE(review): a single projection is shared by both branches, and each branch
        # already contains its own residual connection, so x enters each sum twice. Giving
        # each branch its own projection would change the state_dict of saved checkpoints.
        self.residual_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Order matters: the temporal branch sees the output of the spatial update.
        for branch in (self.spatial_attn, self.temporal_attn):
            x = x + self.residual_proj(branch(x))
        return x

class STTransformerAutoencoder(nn.Module):
    """Spatio-temporal Transformer autoencoder that reconstructs 3D skeleton sequences.

    Input and output are tensors of shape (B, T, J, 3). Each (frame, joint) coordinate triple
    is projected to D_MODEL features, passed through NUM_LAYERS encoder and NUM_LAYERS decoder
    STTransformerBlocks, and projected back to 3 coordinates.

    NOTE(review): no positional, frame-index or joint-identity embedding is added, so the
    attention layers cannot distinguish frames or joints by position. The encoder output also
    keeps the full (B, T, J, D_MODEL) shape, with no bottleneck, so the network may learn a
    near-identity mapping. Both limit how informative reconstruction error is as an anomaly
    score. Changing either alters the state_dict of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Input projection: (B, T, J, 3) -> (B, T, J, D_MODEL), per coordinate triple.
        self.input_proj = nn.Sequential(
            nn.Linear(3, D_MODEL),
            nn.LayerNorm(D_MODEL),
            nn.ReLU(),
            nn.Dropout(DROPOUT)
        )

        # Encoder: stacked spatio-temporal Transformer blocks.
        self.encoder = nn.Sequential(*[
            STTransformerBlock(D_MODEL, NHEAD, DROPOUT) for _ in range(NUM_LAYERS)
        ])

        # Decoder: same structure as the encoder.
        self.decoder = nn.Sequential(*[
            STTransformerBlock(D_MODEL, NHEAD, DROPOUT) for _ in range(NUM_LAYERS)
        ])

        # Output projection back to 3D coordinates.
        self.output_proj = nn.Sequential(
            nn.Linear(D_MODEL, D_MODEL // 2),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL // 2, 3)
        )

    def forward(self, x):
        x = self.input_proj(x)        # (B, T, J, D_MODEL)
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return self.output_proj(x_recon)  # (B, T, J, 3)

    def get_reconstruction_error(self, x, err_min=None, err_max=None):
        """Score each frame and sequence of ``x`` by reconstruction error, mapped to [0, 1].

        Args:
            x: batch of shape (B, T, J, C), in the same normalized space the model was trained on.
            err_min, err_max: optional fixed error range for the min-max mapping, e.g. taken
                from errors on normal training data, which makes scores comparable across
                sequences. When omitted (the default), each sequence is rescaled by its own
                min/max frame error. Scores are then only relative within a sequence: every
                non-constant sequence gets exactly one 0 and one 1.

        Returns:
            (seq_probs, frame_probs): float32 NumPy arrays of shape (B,) and (B, T).
            ``frame_probs`` are per-frame MSEs mapped into [0, 1] (clipped). A zero-width
            range maps to 0. ``seq_probs`` is the mean of ``frame_probs`` per sequence.

        Raises:
            ValueError: if both bounds are given and ``err_max <= err_min``.

        The forward pass runs in eval mode (dropout off); the module's previous
        train/eval mode is restored afterwards.
        """
        if err_min is not None and err_max is not None and err_max <= err_min:
            raise ValueError("err_max must be greater than err_min")

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x_recon = self.forward(x)
                # Frame-level MSE over joints and channels: (B, T)
                frame_errors = torch.mean((x - x_recon) ** 2, dim=[2, 3])
        finally:
            self.train(was_training)

        errors = frame_errors.cpu().numpy()
        lo = errors.min(axis=1, keepdims=True) if err_min is None else err_min
        hi = errors.max(axis=1, keepdims=True) if err_max is None else err_max
        span = hi - lo
        # A zero range would divide by zero; like sklearn's MinMaxScaler, map it to 0.
        span = np.where(span > 0, span, 1.0)
        frame_probs = np.clip((errors - lo) / span, 0.0, 1.0).astype(np.float32)

        # Sequence-level score: mean of the frame scores.
        seq_probs = np.mean(frame_probs, axis=1)

        return seq_probs, frame_probs

In [None]:
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass of ``model`` over ``loader`` and return ``(avg_loss, frame_acc)``.

    With an ``optimizer`` the pass trains: forward, backward, gradient clipping at norm 1.0,
    then a step. Without one it only evaluates under ``torch.no_grad()``.

    ``avg_loss`` is the per-sample mean of ``criterion`` over the dataset (NaN if the dataset
    is empty). ``frame_acc`` is the fraction of frames whose score from
    ``model.get_reconstruction_error`` is below ANOMALY_THRESHOLD. Every sample is assumed to
    be normal, so such a frame counts as correct.

    NOTE(review): with the default per-sequence min-max scoring this accuracy is largely fixed
    by construction (~0.9 in the logged runs) and says little about reconstruction quality.
    """
    total_loss = 0.0
    correct_frames = 0
    total_frames = 0
    for batch in loader:
        batch = batch.to(device)
        if optimizer is not None:
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(batch), batch)  # reconstruction loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
            optimizer.step()
        else:
            model.eval()
            with torch.no_grad():
                loss = criterion(model(batch), batch)
        total_loss += loss.item() * batch.size(0)

        # Score in eval mode so dropout does not add noise to the error-based accuracy.
        model.eval()
        _, frame_probs = model.get_reconstruction_error(batch)
        correct_frames += int((frame_probs < ANOMALY_THRESHOLD).sum())
        total_frames += frame_probs.size

    num_samples = len(loader.dataset)
    avg_loss = total_loss / num_samples if num_samples else float('nan')
    frame_acc = correct_frames / total_frames if total_frames else 0.0
    return avg_loss, frame_acc


def _plot_training_curves(group_name, train_losses, val_losses, train_accs, val_accs):
    """Save side-by-side loss and accuracy curves to OUTPUT_DIR/<group_name>_train_metrics.png."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))

    ax_loss.plot(train_losses, label='训练损失', color='#1f77b4', linewidth=2, marker='.')
    ax_loss.plot(val_losses, label='验证损失', color='#ff7f0e', linewidth=2, marker='.')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('MSE损失')
    ax_loss.set_title(f'{group_name} 损失曲线')
    ax_loss.legend()
    ax_loss.grid(alpha=0.3)

    ax_acc.plot(train_accs, label='训练准确率', color='#2ca02c', linewidth=2, marker='.')
    ax_acc.plot(val_accs, label='验证准确率', color='#d62728', linewidth=2, marker='.')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('准确率')
    ax_acc.set_title(f'{group_name} 准确率曲线')
    ax_acc.legend()
    ax_acc.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, f'{group_name}_train_metrics.png'), dpi=300)
    plt.close(fig)


def train_group_model(group_name, train_loader, test_loader, train_dataset):
    """Train one ST-Transformer autoencoder for an action group and log loss/accuracy.

    Args:
        group_name: key of ACTION_GROUPS; used in log lines and output file names.
        train_loader, test_loader: DataLoaders yielding batches of normalized skeleton tensors.
        train_dataset: the training SkeletonDataset; its ``mean``/``std`` go into the checkpoint.

    Returns:
        (model, best_model_path): the model carrying the weights of the epoch with the lowest
        validation loss (the same weights saved at ``best_model_path``), and that path.
    """
    model = STTransformerAutoencoder().to(device)
    criterion = nn.MSELoss()  # reconstruction loss
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=1e-5  # decoupled weight decay (AdamW), not an L2 penalty in the loss
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)  # cosine LR annealing

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_loss = float('inf')
    best_val_acc = 0.0   # validation accuracy of the lowest-validation-loss epoch
    best_state = None    # in-memory copy of the best weights, restored before returning
    best_model_path = os.path.join(OUTPUT_DIR, f"st_transformer_{group_name}_best.pth")

    print(f"\n开始训练动作组：{group_name}")
    print("="*80)
    print(f"{'Epoch':<5} {'Train Loss':<12} {'Val Loss':<12} {'Train Acc':<12} {'Val Acc':<12}")
    print("="*80)

    for epoch in range(EPOCHS):
        avg_train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        avg_val_loss, val_acc = _run_epoch(model, test_loader, criterion)
        train_losses.append(avg_train_loss)
        train_accs.append(train_acc)
        val_losses.append(avg_val_loss)
        val_accs.append(val_acc)

        # Keep the best model by validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_val_acc = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({
                'group_name': group_name,
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'mean': train_dataset.mean,
                'std': train_dataset.std,
                'best_val_loss': best_val_loss,
                'best_val_acc': best_val_acc,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'train_accs': train_accs,
                'val_accs': val_accs
            }, best_model_path)

        scheduler.step()

        print(f"{epoch+1:<5} {avg_train_loss:<12.6f} {avg_val_loss:<12.6f} {train_acc:<12.4f} {val_acc:<12.4f}")

    _plot_training_curves(group_name, train_losses, val_losses, train_accs, val_accs)

    # Return the best-epoch weights so the returned model matches the saved checkpoint.
    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"\n{group_name} 训练完成！")
    # Accuracy reported here is the one stored in the checkpoint (at the best-loss epoch).
    print(f"最优验证损失：{best_val_loss:.6f} | 最优验证准确率：{best_val_acc:.4f}")
    print(f"模型保存路径：{best_model_path}")

    return model, best_model_path

# ========== Train one model per action group ==========
# trained_models[group_name] -> model returned by train_group_model, its checkpoint path,
# and the group's held-out test dataset/loader.
trained_models = {}
for group_name, data in group_datasets.items():
    model, model_path = train_group_model(
        group_name=group_name,
        train_loader=data['train_loader'],
        test_loader=data['test_loader'],
        train_dataset=data['train_dataset'],
    )
    trained_models[group_name] = dict(
        model=model,
        model_path=model_path,
        test_dataset=data['test_dataset'],
        test_loader=data['test_loader'],
    )


开始训练动作组：group_swing_squat
Epoch Train Loss   Val Loss     Train Acc    Val Acc     
1     0.506479     0.121261     0.8746       0.9034      
2     0.240002     0.064288     0.8942       0.9087      
3     0.179673     0.044448     0.8946       0.9148      
4     0.153961     0.037044     0.8968       0.9119      
5     0.139187     0.033087     0.9038       0.9173      
6     0.128868     0.028326     0.9044       0.9158      
7     0.121232     0.024487     0.9031       0.9212      
8     0.115477     0.022271     0.9052       0.9215      
9     0.110392     0.020673     0.9027       0.9237      
10    0.106958     0.022206     0.9014       0.9169      
11    0.103777     0.020125     0.9031       0.9176      
12    0.101055     0.019714     0.9038       0.9194      
13    0.098837     0.017542     0.9023       0.9180      
14    0.097021     0.017542     0.9085       0.9251      
15    0.095555     0.017033     0.9047       0.9229      
16    0.094247     0.019325     0.9078       

findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun

30    0.087891     0.016465     0.9087       0.9343      


findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun


group_swing_squat 训练完成！
最优验证损失：0.015938 | 最优验证准确率：0.9357
模型保存路径：st_transformer_action_groups/st_transformer_group_swing_squat_best.pth

开始训练动作组：group_drumming_shaking
Epoch Train Loss   Val Loss     Train Acc    Val Acc     
1     0.310515     0.044372     0.8952       0.9288      
2     0.134963     0.028358     0.9078       0.9384      
3     0.112271     0.021996     0.9070       0.9370      
4     0.101129     0.016228     0.9094       0.9351      
5     0.093799     0.015844     0.9075       0.9365      
6     0.088524     0.017088     0.9112       0.9361      
7     0.084335     0.014192     0.9094       0.9391      
8     0.080369     0.014059     0.9106       0.9380      
9     0.077667     0.013229     0.9118       0.9370      
10    0.075720     0.013036     0.9147       0.9370      
11    0.073963     0.011863     0.9141       0.9385      
12    0.072469     0.011897     0.9159       0.9361      
13    0.071327     0.011432     0.9111       0.9380      
14    0.070358     0

findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun

30    0.065190     0.011547     0.9123       0.9340      


findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun


group_drumming_shaking 训练完成！
最优验证损失：0.011301 | 最优验证准确率：0.9391
模型保存路径：st_transformer_action_groups/st_transformer_group_drumming_shaking_best.pth

开始训练动作组：group_pose
Epoch Train Loss   Val Loss     Train Acc    Val Acc     
1     0.567990     0.163170     0.8718       0.8857      
2     0.234833     0.095015     0.8953       0.9007      
3     0.176663     0.060405     0.9042       0.9174      
4     0.147023     0.042955     0.9078       0.9157      
5     0.130584     0.034400     0.9119       0.9212      
6     0.120194     0.029460     0.9092       0.9217      
7     0.112755     0.027188     0.9097       0.9204      
8     0.106931     0.024737     0.9096       0.9229      
9     0.103034     0.023024     0.9118       0.9259      
10    0.100134     0.025494     0.9140       0.9307      
11    0.097434     0.020580     0.9085       0.9238      
12    0.095390     0.023098     0.9148       0.9315      
13    0.093517     0.019792     0.9104       0.9294      
14    0.091979     0.0

findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun

30    0.083693     0.016947     0.9171       0.9336      


findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun


group_pose 训练完成！
最优验证损失：0.016947 | 最优验证准确率：0.9341
模型保存路径：st_transformer_action_groups/st_transformer_group_pose_best.pth


In [None]:
# NOTE(review): this cell re-defines train_group_model and its helpers, duplicating the
# previous training cell, and on Run All this later definition is the one that takes effect.
# Keep both copies identical or delete one.
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass of ``model`` over ``loader`` and return ``(avg_loss, frame_acc)``.

    With an ``optimizer`` the pass trains: forward, backward, gradient clipping at norm 1.0,
    then a step. Without one it only evaluates under ``torch.no_grad()``.

    ``avg_loss`` is the per-sample mean of ``criterion`` over the dataset (NaN if the dataset
    is empty). ``frame_acc`` is the fraction of frames whose score from
    ``model.get_reconstruction_error`` is below ANOMALY_THRESHOLD. Every sample is assumed to
    be normal, so such a frame counts as correct.

    NOTE(review): with the default per-sequence min-max scoring this accuracy is largely fixed
    by construction (~0.9 in the logged runs) and says little about reconstruction quality.
    """
    total_loss = 0.0
    correct_frames = 0
    total_frames = 0
    for batch in loader:
        batch = batch.to(device)
        if optimizer is not None:
            model.train()
            optimizer.zero_grad()
            loss = criterion(model(batch), batch)  # reconstruction loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
            optimizer.step()
        else:
            model.eval()
            with torch.no_grad():
                loss = criterion(model(batch), batch)
        total_loss += loss.item() * batch.size(0)

        # Score in eval mode so dropout does not add noise to the error-based accuracy.
        model.eval()
        _, frame_probs = model.get_reconstruction_error(batch)
        correct_frames += int((frame_probs < ANOMALY_THRESHOLD).sum())
        total_frames += frame_probs.size

    num_samples = len(loader.dataset)
    avg_loss = total_loss / num_samples if num_samples else float('nan')
    frame_acc = correct_frames / total_frames if total_frames else 0.0
    return avg_loss, frame_acc


def _plot_training_curves(group_name, train_losses, val_losses, train_accs, val_accs):
    """Save side-by-side loss and accuracy curves to OUTPUT_DIR/<group_name>_train_metrics.png."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))

    ax_loss.plot(train_losses, label='训练损失', color='#1f77b4', linewidth=2, marker='.')
    ax_loss.plot(val_losses, label='验证损失', color='#ff7f0e', linewidth=2, marker='.')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('MSE损失')
    ax_loss.set_title(f'{group_name} 损失曲线')
    ax_loss.legend()
    ax_loss.grid(alpha=0.3)

    ax_acc.plot(train_accs, label='训练准确率', color='#2ca02c', linewidth=2, marker='.')
    ax_acc.plot(val_accs, label='验证准确率', color='#d62728', linewidth=2, marker='.')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('准确率')
    ax_acc.set_title(f'{group_name} 准确率曲线')
    ax_acc.legend()
    ax_acc.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, f'{group_name}_train_metrics.png'), dpi=300)
    plt.close(fig)


def train_group_model(group_name, train_loader, test_loader, train_dataset):
    """Train one ST-Transformer autoencoder for an action group and log loss/accuracy.

    Args:
        group_name: key of ACTION_GROUPS; used in log lines and output file names.
        train_loader, test_loader: DataLoaders yielding batches of normalized skeleton tensors.
        train_dataset: the training SkeletonDataset; its ``mean``/``std`` go into the checkpoint.

    Returns:
        (model, best_model_path): the model carrying the weights of the epoch with the lowest
        validation loss (the same weights saved at ``best_model_path``), and that path.
    """
    model = STTransformerAutoencoder().to(device)
    criterion = nn.MSELoss()  # reconstruction loss
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=1e-5  # decoupled weight decay (AdamW), not an L2 penalty in the loss
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)  # cosine LR annealing

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_loss = float('inf')
    best_val_acc = 0.0   # validation accuracy of the lowest-validation-loss epoch
    best_state = None    # in-memory copy of the best weights, restored before returning
    best_model_path = os.path.join(OUTPUT_DIR, f"st_transformer_{group_name}_best.pth")

    print(f"\n开始训练动作组：{group_name}")
    print("="*80)
    print(f"{'Epoch':<5} {'Train Loss':<12} {'Val Loss':<12} {'Train Acc':<12} {'Val Acc':<12}")
    print("="*80)

    for epoch in range(EPOCHS):
        avg_train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        avg_val_loss, val_acc = _run_epoch(model, test_loader, criterion)
        train_losses.append(avg_train_loss)
        train_accs.append(train_acc)
        val_losses.append(avg_val_loss)
        val_accs.append(val_acc)

        # Keep the best model by validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_val_acc = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({
                'group_name': group_name,
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'mean': train_dataset.mean,
                'std': train_dataset.std,
                'best_val_loss': best_val_loss,
                'best_val_acc': best_val_acc,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'train_accs': train_accs,
                'val_accs': val_accs
            }, best_model_path)

        scheduler.step()

        print(f"{epoch+1:<5} {avg_train_loss:<12.6f} {avg_val_loss:<12.6f} {train_acc:<12.4f} {val_acc:<12.4f}")

    _plot_training_curves(group_name, train_losses, val_losses, train_accs, val_accs)

    # Return the best-epoch weights so the returned model matches the saved checkpoint.
    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"\n{group_name} 训练完成！")
    # Accuracy reported here is the one stored in the checkpoint (at the best-loss epoch).
    print(f"最优验证损失：{best_val_loss:.6f} | 最优验证准确率：{best_val_acc:.4f}")
    print(f"模型保存路径：{best_model_path}")

    return model, best_model_path

# ========== Train one model per action group ==========
# NOTE(review): this cell repeats the previous training loop. It resets trained_models and
# retrains every group from scratch, overwriting the checkpoints written by that earlier cell.
# trained_models[group_name] -> model returned by train_group_model, its checkpoint path,
# and the group's held-out test dataset/loader.
trained_models = {}
for group_name, data in group_datasets.items():
    model, model_path = train_group_model(
        group_name=group_name,
        train_loader=data['train_loader'],
        test_loader=data['test_loader'],
        train_dataset=data['train_dataset'],
    )
    trained_models[group_name] = dict(
        model=model,
        model_path=model_path,
        test_dataset=data['test_dataset'],
        test_loader=data['test_loader'],
    )


开始训练动作组：group_swing_squat
Epoch Train Loss   Val Loss     Train Acc    Val Acc     
1     0.506479     0.121261     0.8746       0.9034      
2     0.240002     0.064288     0.8942       0.9087      
3     0.179673     0.044448     0.8946       0.9148      
4     0.153961     0.037044     0.8968       0.9119      
5     0.139187     0.033087     0.9038       0.9173      
6     0.128868     0.028326     0.9044       0.9158      
7     0.121232     0.024487     0.9031       0.9212      
8     0.115477     0.022271     0.9052       0.9215      
9     0.110392     0.020673     0.9027       0.9237      
10    0.106958     0.022206     0.9014       0.9169      
11    0.103777     0.020125     0.9031       0.9176      
12    0.101055     0.019714     0.9038       0.9194      
13    0.098837     0.017542     0.9023       0.9180      
14    0.097021     0.017542     0.9085       0.9251      
15    0.095555     0.017033     0.9047       0.9229      
16    0.094247     0.019325     0.9078       

findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun

30    0.087891     0.016465     0.9087       0.9343      


findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun


group_swing_squat 训练完成！
最优验证损失：0.015938 | 最优验证准确率：0.9357
模型保存路径：st_transformer_action_groups/st_transformer_group_swing_squat_best.pth

开始训练动作组：group_drumming_shaking
Epoch Train Loss   Val Loss     Train Acc    Val Acc     
1     0.310515     0.044372     0.8952       0.9288      
2     0.134963     0.028358     0.9078       0.9384      
3     0.112271     0.021996     0.9070       0.9370      
4     0.101129     0.016228     0.9094       0.9351      
5     0.093799     0.015844     0.9075       0.9365      
6     0.088524     0.017088     0.9112       0.9361      
7     0.084335     0.014192     0.9094       0.9391      
8     0.080369     0.014059     0.9106       0.9380      
9     0.077667     0.013229     0.9118       0.9370      
10    0.075720     0.013036     0.9147       0.9370      
11    0.073963     0.011863     0.9141       0.9385      
12    0.072469     0.011897     0.9159       0.9361      
13    0.071327     0.011432     0.9111       0.9380      
14    0.070358     0

findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun

30    0.065190     0.011547     0.9123       0.9340      


findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun


group_drumming_shaking 训练完成！
最优验证损失：0.011301 | 最优验证准确率：0.9391
模型保存路径：st_transformer_action_groups/st_transformer_group_drumming_shaking_best.pth

开始训练动作组：group_pose
Epoch Train Loss   Val Loss     Train Acc    Val Acc     
1     0.567990     0.163170     0.8718       0.8857      
2     0.234833     0.095015     0.8953       0.9007      
3     0.176663     0.060405     0.9042       0.9174      
4     0.147023     0.042955     0.9078       0.9157      
5     0.130584     0.034400     0.9119       0.9212      
6     0.120194     0.029460     0.9092       0.9217      
7     0.112755     0.027188     0.9097       0.9204      
8     0.106931     0.024737     0.9096       0.9229      
9     0.103034     0.023024     0.9118       0.9259      
10    0.100134     0.025494     0.9140       0.9307      
11    0.097434     0.020580     0.9085       0.9238      
12    0.095390     0.023098     0.9148       0.9315      
13    0.093517     0.019792     0.9104       0.9294      
14    0.091979     0.0

findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun

30    0.083693     0.016947     0.9171       0.9336      


findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not found because none of the following families were found: SimHei
findfont: Generic family 'sans-serif' not foun


group_pose 训练完成！
最优验证损失：0.016947 | 最优验证准确率：0.9341
模型保存路径：st_transformer_action_groups/st_transformer_group_pose_best.pth


In [None]:
def detect_anomaly(group_name, model, test_dataset, test_loader):
    """Run frame-level anomaly detection for one action group and save the results.

    For each batch of ``test_loader`` the model's ``get_reconstruction_error``
    is called. It returns one sequence-level probability and one list of
    frame-level probabilities per sequence. One result row is built per
    (sequence, frame) pair. The rows are written to
    ``OUTPUT_DIR/<group_name>_anomaly_results.csv`` and a summary is printed.

    Args:
        group_name: Action-group name, stored in every row and used in the CSV name.
        model: Trained model exposing ``eval()`` and ``get_reconstruction_error(batch)``.
        test_dataset: Dataset whose ``file_names`` list is in the same order the
            loader yields sequences. The loader must therefore not shuffle.
        test_loader: Iterable of batches (tensors) over ``test_dataset``.

    Returns:
        pandas.DataFrame with columns ``group_name, file_name, sequence_idx,
        frame_idx, sequence_anomaly_prob, frame_anomaly_prob, is_anomaly``.
        It has one row per frame and is empty (columns only) for an empty loader.
    """
    columns = ['group_name', 'file_name', 'sequence_idx', 'frame_idx',
               'sequence_anomaly_prob', 'frame_anomaly_prob', 'is_anomaly']

    model.eval()
    all_seq_probs = []
    all_frame_probs = []
    all_file_names = []
    # Index into test_dataset of the current batch's first sequence. It is
    # advanced by the real batch length: `batch_idx * BATCH_SIZE` misaligns
    # file names whenever the loader's batch_size differs from BATCH_SIZE.
    seq_offset = 0

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            seq_probs, frame_probs = model.get_reconstruction_error(batch)

            all_seq_probs.extend(seq_probs)
            all_frame_probs.extend(frame_probs)

            batch_len = len(batch)
            all_file_names.extend(test_dataset.file_names[seq_offset:seq_offset + batch_len])
            seq_offset += batch_len

    frame_results = []
    for seq_idx, (file_name, seq_prob, seq_frame_probs) in enumerate(
            zip(all_file_names, all_seq_probs, all_frame_probs)):
        # float() turns numpy / 0-d tensor scalars into plain Python floats.
        seq_prob = round(float(seq_prob), 4)
        # Iterate over the frames the model actually returned instead of
        # assuming exactly TARGET_FRAMES of them.
        for frame_idx, frame_prob in enumerate(seq_frame_probs):
            frame_prob = float(frame_prob)
            frame_results.append({
                'group_name': group_name,
                'file_name': file_name,
                'sequence_idx': seq_idx,
                'frame_idx': frame_idx,
                'sequence_anomaly_prob': seq_prob,
                'frame_anomaly_prob': round(frame_prob, 4),
                'is_anomaly': frame_prob >= ANOMALY_THRESHOLD
            })

    # Passing `columns` keeps the schema stable even when there are no rows.
    results_df = pd.DataFrame(frame_results, columns=columns)

    results_path = os.path.join(OUTPUT_DIR, f'{group_name}_anomaly_results.csv')
    results_df.to_csv(results_path, index=False, encoding='utf-8-sig')

    total_frames = len(results_df)
    anomaly_frames = int(results_df['is_anomaly'].sum())
    total_seqs = len(all_seq_probs)
    # NOTE(review): a sequence counts as anomalous if ANY of its frames crosses
    # ANOMALY_THRESHOLD. sequence_anomaly_prob is saved but not used here.
    # In the recorded run 100% of test sequences were flagged in every group,
    # with about 2 flagged frames each. Check whether get_reconstruction_error
    # rescales frame probabilities per sequence; that would push every
    # sequence's maximum frame over the threshold. TODO confirm.
    anomaly_seqs = int(results_df.groupby('sequence_idx')['is_anomaly'].any().sum())
    frame_pct = anomaly_frames / total_frames * 100 if total_frames else 0.0
    seq_pct = anomaly_seqs / total_seqs * 100 if total_seqs else 0.0

    print(f"\n{group_name} 异常检测结果：")
    print(f"总帧数：{total_frames} | 异常帧数：{anomaly_frames} ({frame_pct:.2f}%)")
    print(f"总序列数：{total_seqs} | 异常序列数：{anomaly_seqs} ({seq_pct:.2f}%)")
    print(f"结果保存路径：{results_path}")

    return results_df

# ========== Run anomaly detection on every trained action group ==========
# Each group's per-frame result table is kept in `all_detection_results`.
all_detection_results = {}
for group_name, info in trained_models.items():
    results_df = detect_anomaly(
        group_name=group_name,
        model=info['model'],
        test_dataset=info['test_dataset'],
        test_loader=info['test_loader'],
    )
    all_detection_results[group_name] = results_df

# Closing banner: one print call, one output line per argument.
print(
    "\n" + "=" * 60,
    "所有动作组训练和检测完成！",
    f"结果保存在：{OUTPUT_DIR}",
    "=" * 60,
    sep="\n",
)


group_swing_squat 异常检测结果：
总帧数：2816 | 异常帧数：185 (6.57%)
总序列数：88 | 异常序列数：88 (100.00%)
结果保存路径：st_transformer_action_groups/group_swing_squat_anomaly_results.csv



group_drumming_shaking 异常检测结果：
总帧数：5728 | 异常帧数：378 (6.60%)
总序列数：179 | 异常序列数：179 (100.00%)
结果保存路径：st_transformer_action_groups/group_drumming_shaking_anomaly_results.csv

group_pose 异常检测结果：
总帧数：2336 | 异常帧数：155 (6.64%)
总序列数：73 | 异常序列数：73 (100.00%)
结果保存路径：st_transformer_action_groups/group_pose_anomaly_results.csv

所有动作组训练和检测完成！
结果保存在：st_transformer_action_groups


In [None]:
# %% Cell 7: detailed visualization (frame counts x modalities), class-count-safe
def detailed_visualization_with_frames(log_data, label_names=None):
    """Plot and print a 2D-vs-3D, per-frame-count comparison of training logs.

    Draws one 2x3 figure, saved as ``comprehensive_visualization_with_frames.png``:
      1. train/test loss curves of every (modality, frame count) group
      2. train/test accuracy curves, with each group's best accuracy marked
      3. a polar radar chart of best accuracy / precision / recall / F1
      4. the confusion matrix of the group with the highest best F1
      5. per-class accuracy of every group, over the best group's classes
      6. best accuracy per frame count, 2D vs 3D
    It then saves a classification-report heatmap of the best group as
    ``best_group_classification_report.png`` and prints a text summary.

    Args:
        log_data: dict with keys ``'2d'`` and ``'3d'``, each mapping a frame
            count to a result dict, and ``'params'`` holding ``'supported_frames'``
            and ``'epochs'``. Each result dict must provide:
            - ``train_log`` / ``test_log`` with per-epoch ``'loss'`` and ``'acc'``;
            - ``best_metrics`` with ``acc``, ``precision``, ``recall``, ``f1``,
              ``epoch``, ``preds`` and ``labels``;
            - ``train_samples``, ``test_samples`` and ``model_path``.
        label_names: indexable mapping (list or dict) from class id to display
            name. When ``None``, generic ``class_<id>`` names are used.

    Raises:
        ValueError: if ``log_data`` contains no 2D/3D results at all.
    """
    # These metrics are not imported in the notebook's import cell.
    from sklearn.metrics import classification_report, confusion_matrix

    # SimHei is not installed on the host (see the findfont warnings), so use
    # common Latin fonts and keep every label in English.
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'Helvetica', 'sans-serif']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['figure.figsize'] = (18, 12)
    plt.rcParams['font.size'] = 11

    results_2d = log_data['2d']
    results_3d = log_data['3d']
    supported_frames = log_data['params']['supported_frames']
    epochs = log_data['params']['epochs']
    epochs_range = range(1, epochs + 1)
    colors_modality = {'2d': '#2E86AB', '3d': '#A23B72'}

    # One tuple per (modality, frame count) group. Keeping the raw keys avoids
    # re-parsing them out of the display label later on.
    groups = [
        (modality, frame_num, res, f'{modality.upper()}-{frame_num}f')
        for modality, results in (('2d', results_2d), ('3d', results_3d))
        for frame_num, res in results.items()
    ]
    if not groups:
        raise ValueError("log_data contains no 2D/3D results to visualize")

    def class_name(class_id):
        # Generic name when the caller supplied no label mapping.
        return f'class_{class_id}' if label_names is None else label_names[class_id]

    def plot_train_test(metric):
        # Train lines are solid (32 frames) or dashed; test lines are dotted
        # (32 frames) or dash-dot.
        for modality, frame_num, res, label in groups:
            color = colors_modality[modality]
            is_32 = frame_num == 32
            plt.plot(epochs_range, res['train_log'][metric], label=f'{label}-Train',
                     color=color, linewidth=2, linestyle='-' if is_32 else '--')
            plt.plot(epochs_range, res['test_log'][metric], label=f'{label}-Test',
                     color=color, linewidth=2, linestyle=':' if is_32 else '-.')

    # Explicit figure so nothing is drawn into a figure left open by an earlier cell.
    plt.figure(figsize=(18, 12))

    # ========== Panel 1: loss curves (all groups) ==========
    plt.subplot(2, 3, 1)
    plot_train_test('loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve Comparison (All Groups)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
    plt.grid(alpha=0.3)

    # ========== Panel 2: accuracy curves with best accuracy marked ==========
    plt.subplot(2, 3, 2)
    plot_train_test('acc')
    for modality, _, res, _ in groups:
        color = colors_modality[modality]
        best_acc = res['best_metrics']['acc']
        best_epoch = res['best_metrics']['epoch']
        plt.scatter(best_epoch, best_acc, color=color, s=50, zorder=5)
        plt.annotate(f'{best_acc:.3f}',
                     xy=(best_epoch, best_acc),
                     xytext=(best_epoch + 1, best_acc - 0.05),
                     color=color, fontsize=8)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Curve Comparison (All Groups)')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
    plt.grid(alpha=0.3)
    plt.ylim(0, 1.05)

    # ========== Panel 3: multi-metric radar chart ==========
    metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    metric_keys = ['acc', 'precision', 'recall', 'f1']
    angles = np.linspace(0, 2 * np.pi, len(metrics_names), endpoint=False).tolist()
    closed_angles = angles + angles[:1]  # repeat the first angle to close the polygon
    # plt.polar() on an existing Cartesian subplot yields a plain line plot,
    # so the subplot itself must be created with a polar projection.
    ax_radar = plt.subplot(2, 3, 3, projection='polar')
    for modality, _, res, label in groups:
        values = [res['best_metrics'][key] for key in metric_keys]
        closed_values = values + values[:1]  # new list: log data is not mutated
        color = colors_modality[modality]
        ax_radar.plot(closed_angles, closed_values, label=label, color=color,
                      linewidth=2, marker='o', markersize=3)
        ax_radar.fill(closed_angles, closed_values, color=color, alpha=0.1)
    ax_radar.set_xticks(angles)
    ax_radar.set_xticklabels(metrics_names, fontsize=10)
    ax_radar.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax_radar.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], alpha=0.7)
    ax_radar.set_ylim(0, 1.0)
    ax_radar.set_title('Multi-Metric Radar Chart (All Groups)', pad=20)
    ax_radar.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0), fontsize=9)

    # ========== Panel 4: confusion matrix of the best-F1 group ==========
    # max() keeps the first group on ties, like the original strict '>' scan,
    # but always selects a group even when every F1 is 0.
    _, _, best_res, best_label = max(groups, key=lambda group: group[2]['best_metrics']['f1'])
    best_f1 = best_res['best_metrics']['f1']
    best_true = np.asarray(best_res['best_metrics']['labels'])
    best_preds = np.asarray(best_res['best_metrics']['preds'])
    # Only classes present in this group's labels or predictions, so the tick
    # labels always match the confusion-matrix size.
    actual_class_ids = sorted(set(best_true.tolist()) | set(best_preds.tolist()))
    actual_class_names = [class_name(class_id) for class_id in actual_class_ids]
    best_conf = confusion_matrix(best_true, best_preds, labels=actual_class_ids)

    # Row-normalize to percentages per true class. Classes that only occur as
    # predictions have an all-zero row and stay 0 instead of becoming NaN.
    row_sums = best_conf.sum(axis=1, keepdims=True)
    cm_norm = np.divide(best_conf * 100.0, row_sums,
                        out=np.zeros(best_conf.shape, dtype=float), where=row_sums > 0)

    plt.subplot(2, 3, 4)
    sns.heatmap(cm_norm, annot=best_conf, fmt='d', cmap='Blues',
                xticklabels=actual_class_names, yticklabels=actual_class_names,
                cbar_kws={'label': 'Accuracy (%)'},
                annot_kws={'fontsize': 9})
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'Confusion Matrix (Best Group: {best_label}, F1: {best_f1:.3f})')
    plt.xticks(rotation=45, ha='right', fontsize=9)
    plt.yticks(rotation=0, fontsize=9)

    # ========== Panel 5: per-class accuracy of every group ==========
    plt.subplot(2, 3, 5)
    x = np.arange(len(actual_class_ids))
    n_groups = len(groups)
    width = min(0.15, 0.8 / n_groups)  # shrink bars when many groups share a tick
    offset = -(n_groups - 1) * width / 2  # centre the bar cluster on each tick
    for i, (modality, _, res, label) in enumerate(groups):
        # Arrays are required for the element-wise mask below; lists would
        # compare as a single scalar.
        group_true = np.asarray(res['best_metrics']['labels'])
        group_preds = np.asarray(res['best_metrics']['preds'])
        class_acc = []
        for class_id in actual_class_ids:
            mask = group_true == class_id
            n_true = int(mask.sum())
            class_acc.append(float((group_preds[mask] == class_id).sum()) / n_true if n_true else 0.0)

        positions = x + offset + i * width
        plt.bar(positions, class_acc, width, label=label,
                color=colors_modality[modality], alpha=0.8)
        if i < 3:  # annotate only the first three groups to limit clutter
            for pos, acc in zip(positions, class_acc):
                plt.text(pos, acc + 0.02, f'{acc:.2f}', ha='center', va='bottom', fontsize=7)
    plt.xlabel('Class')
    plt.ylabel('Class Accuracy')
    plt.title('Class-Level Accuracy Comparison (All Groups)')
    plt.xticks(x, actual_class_names, rotation=45, ha='right', fontsize=9)
    plt.ylim(0, 1.1)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
    plt.grid(alpha=0.3, axis='y')

    # ========== Panel 6: best accuracy per frame count, 2D vs 3D ==========
    plt.subplot(2, 3, 6)
    frame_labels = [f'{frame_num}f' for frame_num in supported_frames]
    accs_2d = [results_2d[f]['best_metrics']['acc'] if f in results_2d else 0 for f in supported_frames]
    accs_3d = [results_3d[f]['best_metrics']['acc'] if f in results_3d else 0 for f in supported_frames]
    x = np.arange(len(frame_labels))
    bar_width = 0.35
    for shift, accs, legend_name, modality in ((-bar_width / 2, accs_2d, '2D Modality', '2d'),
                                               (bar_width / 2, accs_3d, '3D Modality', '3d')):
        plt.bar(x + shift, accs, bar_width, label=legend_name,
                color=colors_modality[modality], alpha=0.8)
        for pos, acc in zip(x + shift, accs):
            if acc > 0:  # 0 marks a frame count missing for this modality
                plt.text(pos, acc + 0.02, f'{acc:.3f}', ha='center', va='bottom', fontsize=10)
    plt.xlabel('Frame Number')
    plt.ylabel('Best Accuracy')
    plt.title('Accuracy Comparison (Modality x Frames)')
    plt.xticks(x, frame_labels, fontsize=10)
    plt.ylim(0, 1.1)
    plt.legend(fontsize=10)
    plt.grid(alpha=0.3, axis='y')

    # Lay out once, after all panels exist.
    plt.tight_layout()
    plt.savefig('comprehensive_visualization_with_frames.png', dpi=300, bbox_inches='tight')
    plt.show()

    # ========== Extra: classification-report heatmap of the best group ==========
    plt.figure(figsize=(12, 6))
    report = classification_report(
        best_true, best_preds,
        labels=actual_class_ids,          # classes actually present in this group
        target_names=actual_class_names,  # names aligned with `labels`
        output_dict=True,
        zero_division=0
    )
    # Pick the per-class rows by name instead of slicing off trailing summary rows.
    report_df = pd.DataFrame.from_dict(
        {name: report[name] for name in actual_class_names}, orient='index'
    )[['precision', 'recall', 'f1-score']]

    sns.heatmap(report_df, annot=True, cmap='Blues', vmin=0, vmax=1, fmt='.3f',
                annot_kws={'fontsize': 10},
                cbar_kws={'label': 'Score'})
    plt.title(f'Classification Report (Best Group: {best_label}, F1: {best_f1:.3f})', fontsize=12)
    plt.xlabel('Metrics', fontsize=11)
    plt.ylabel('Classes', fontsize=11)
    plt.xticks(rotation=0, fontsize=10)
    plt.yticks(rotation=0, fontsize=10)
    plt.tight_layout()
    plt.savefig('best_group_classification_report.png', dpi=300, bbox_inches='tight')
    plt.show()

    # ========== Text report ==========
    print("\n" + "="*80)
    print("📋 Detailed Performance Comparison Report (Modality x Frames)")
    print("="*80)
    print(f"{'Group':<12} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Sample Size':<10}")
    print("-"*80)
    for _, _, res, label in groups:
        metrics = res['best_metrics']
        sample_num = res['train_samples'] + res['test_samples']
        print(f"{label:<12} {metrics['acc']:.4f}     {metrics['precision']:.4f}     "
              f"{metrics['recall']:.4f}     {metrics['f1']:.4f}     {sample_num:<10}")

    print(f"\n🏆 Best Group: {best_label}")
    print(f"  Best F1-Score: {best_f1:.4f}")
    print(f"  Best Accuracy: {best_res['best_metrics']['acc']:.4f}")
    print(f"  Actual Classes: {actual_class_names}")
    print(f"  Corresponding Model Path: {best_res['model_path']}")

    print(f"\n📁 Generated Visualization Files:")
    print(f"  1. comprehensive_visualization_with_frames.png (Comprehensive Comparison)")
    print(f"  2. best_group_classification_report.png (Best Group Classification Report)")
    print("="*80)

# Load the training log and run the visualization.
# NOTE(review): no visible cell of this notebook writes
# 'training_logs_with_frames.pkl' (a 2D/3D classification log). Confirm the
# file and its schema match detailed_visualization_with_frames before relying
# on this cell.
# pickle.load can execute arbitrary code: only load a log this project wrote.
with open('training_logs_with_frames.pkl', 'rb') as f:
    log_data = pickle.load(f)


def _fallback_label_names(log):
    """Map every class id in ``log``'s best labels/predictions to ``class_<id>``."""
    class_ids = set()
    for modality in ('2d', '3d'):
        for res in log[modality].values():
            class_ids.update(np.asarray(res['best_metrics']['labels']).tolist())
            class_ids.update(np.asarray(res['best_metrics']['preds']).tolist())
    return {class_id: f'class_{class_id}' for class_id in class_ids}


# `dataset_2d` is not defined in this notebook (the previous run raised
# NameError), so use its label names only when it exists.
if 'dataset_2d' in globals():
    viz_label_names = dataset_2d.label_names
else:
    viz_label_names = _fallback_label_names(log_data)

detailed_visualization_with_frames(log_data, label_names=viz_label_names)

NameError: name 'dataset_2d' is not defined