# sEMG-HHT Dual Classifier System

## System Architecture

This notebook implements a dual classifier system:

1. **Deep Learning CNN** - Action quality classification (Full, Half, Invalid)

2. **SVM Classifier** - Gender classification (Male, Female)

## Key Improvements

### Fixes for Training Problems:
- ✅ **Expanded CNN Architecture** - 7-layer deep network with stronger feature extraction
- ✅ **Batch Normalization** - Speeds up training and helps prevent vanishing gradients
- ✅ **Kaiming Initialization** - Proper weight initialization to keep gradients flowing
- ✅ **Learning Rate Warmup** - Prevents instability early in training
- ✅ **Gradient Clipping** - Prevents gradient explosion
- ✅ **Data Augmentation** - Improves model generalization

### Network Scale:
- **Input**: 1×256×256
- **Channels**: 64 → 128 → 256 → 512 → 1024 → 2048 → 2048
- **Feature dim**: 2048
- **Classifier**: 2048 → 1024 → 512 → 3 classes
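
The figures above follow from two simple rules used by the model code later in the notebook: the channel width doubles for the first six layers and then stays constant, and the two hidden sizes of the classifier head are half and a quarter of the feature dimension, clamped to fixed ranges. A minimal sketch in plain Python, assuming `base_channels = 64` and seven layers as listed; the helper names `channel_progression` and `head_dims` are illustrative and not part of the notebook:

```python
def channel_progression(base_channels: int = 64, num_layers: int = 7) -> list:
    """Double the width for the first six layers, then hold it constant."""
    return [base_channels * 2 ** min(i, 5) for i in range(num_layers)]


def head_dims(feature_dim: int, n_classes: int = 3) -> list:
    """Hidden sizes are feature_dim/2 and feature_dim/4, clamped to [256, 1024] and [128, 512]."""
    hidden_1 = min(max(256, feature_dim // 2), 1024)
    hidden_2 = min(max(128, feature_dim // 4), 512)
    return [feature_dim, hidden_1, hidden_2, n_classes]


channels = channel_progression()
print(" -> ".join(map(str, channels)))
print(" -> ".join(map(str, head_dims(channels[-1]))))
```

Changing `num_layers` in this sketch is a quick way to see how the feature dimension and head sizes would shift before building the real model.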

## Data Requirements

- **Format**: `.npz` files, each containing a 256×256 HHT matrix
- **Channels**: Single channel (grayscale)
- **Naming**: `MUSCLENAME_movement_GENDER_###.npz`
  - Example: `BICEPS_fatiguetest_M_006.npz` (male, full movement)
  - Example: `TRICEPS_half_F_012.npz` (female, half movement)
  - Test files: names start with `Test`
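
To make the naming convention concrete, the sketch below splits a filename into its four fields with a regular expression. The pattern, the `split_fields` helper, and the `Test_001.npz` name are assumptions made for illustration; the parser defined later in the notebook uses keyword matching instead and may accept names this pattern rejects.

```python
import re

# Hypothetical strict pattern for MUSCLENAME_movement_GENDER_###.npz
NAME_PATTERN = re.compile(
    r"^(?P<muscle>[A-Za-z]+)_(?P<movement>[A-Za-z]+)_(?P<gender>[MF])_(?P<index>\d+)\.npz$"
)


def split_fields(filename: str):
    """Return the four name fields as a dict, or None for test files and non-matching names."""
    if filename.lower().startswith("test"):
        return None
    match = NAME_PATTERN.match(filename)
    return match.groupdict() if match else None


print(split_fields("BICEPS_fatiguetest_M_006.npz"))
print(split_fields("TRICEPS_half_F_012.npz"))
print(split_fields("Test_001.npz"))  # hypothetical test-file name
```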

## 1. Environment Setup

In [None]:
import os
import sys

# 检测Kaggle环境 | Detect Kaggle environment
IS_KAGGLE = os.path.exists('/kaggle/input')

if IS_KAGGLE:
    DATA_DIR = '/kaggle/input/hilbertmatrix-npz/hht_matrices'
    CHECKPOINT_DIR = '/kaggle/working/checkpoints'
    print('🏃 在Kaggle上运行 | Running on Kaggle')
    print(f'📁 数据目录 | Data directory: {DATA_DIR}')
else:
    DATA_DIR = './data'
    CHECKPOINT_DIR = './checkpoints'
    print('💻 本地运行 | Running locally')
    print(f'📁 数据目录 | Data directory: {DATA_DIR}')

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f'💾 检查点目录 | Checkpoint directory: {CHECKPOINT_DIR}')

## 2. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import glob
import re
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import warnings
from typing import Tuple, List, Dict, Optional
import pickle
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# 设置随机种子 | Set random seeds
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 检测设备 | Detect device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'🖥️  使用设备 | Using device: {device}')
if torch.cuda.is_available():
    print(f'   GPU: {torch.cuda.get_device_name(0)}')

## 3. Hyperparameter Configuration

Configure all training parameters here. Adjust these values as needed.

In [None]:
# =============================================================================
# 超参数配置 | HYPERPARAMETER CONFIGURATION
# =============================================================================

# -----------------------------------------------------------------------------
# 模型架构 | Model Architecture
# -----------------------------------------------------------------------------
MODEL_IN_CHANNELS = 1              # 输入通道数（灰度图）| Input channels (grayscale)
MODEL_BASE_CHANNELS = 64           # 基础通道数 | Base channels
MODEL_NUM_LAYERS = 7               # 卷积层数（扩展至7层）| Number of conv layers (expanded to 7)
MODEL_DROPOUT_RATE = 0.5           # Dropout率 | Dropout rate

# -----------------------------------------------------------------------------
# 动作质量CNN训练配置 | Action Quality CNN Training Config
# -----------------------------------------------------------------------------
ACTION_EPOCHS = 100                # 训练轮数 | Training epochs
ACTION_BATCH_SIZE = 16             # 批次大小 | Batch size
ACTION_LEARNING_RATE = 0.0001      # 学习率（降低以提高稳定性）| Learning rate (lowered for stability)
ACTION_WEIGHT_DECAY = 1e-4         # L2正则化 | L2 regularization
ACTION_WARMUP_EPOCHS = 5           # 学习率预热轮数 | LR warmup epochs
ACTION_GRAD_CLIP = 1.0             # 梯度裁剪值 | Gradient clipping value

# -----------------------------------------------------------------------------
# 学习率调度器 | Learning Rate Scheduler
# -----------------------------------------------------------------------------
LR_SCHEDULER_FACTOR = 0.5          # 学习率衰减因子 | LR decay factor
LR_SCHEDULER_PATIENCE = 7          # 等待轮数 | Patience epochs
LR_SCHEDULER_MIN_LR = 1e-6         # 最小学习率 | Minimum LR

# -----------------------------------------------------------------------------
# 训练轮数配置 | Training Rounds Configuration
# -----------------------------------------------------------------------------
NUM_TRAINING_ROUNDS = 3            # 总训练轮数 | Total training rounds
EPOCHS_PER_ROUND = 100             # 每轮训练的epoch数 | Epochs per training round

# -----------------------------------------------------------------------------
# SVM配置 | SVM Configuration
# -----------------------------------------------------------------------------
SVM_KERNEL = 'rbf'                 # SVM核函数 | SVM kernel
SVM_C = 10.0                       # 正则化参数 | Regularization parameter
SVM_GAMMA = 'scale'                # Gamma参数 | Gamma parameter

# -----------------------------------------------------------------------------
# 数据配置 | Data Configuration
# -----------------------------------------------------------------------------
DATA_NORMALIZE = True              # 数据归一化 | Data normalization
DATA_TEST_SIZE = 0.2               # 验证集比例 | Validation split ratio
DATA_AUGMENTATION = True           # 数据增强 | Data augmentation

# -----------------------------------------------------------------------------
# 检查点配置 | Checkpoint Configuration
# -----------------------------------------------------------------------------
CHECKPOINT_INTERVAL = 10           # 检查点保存间隔 | Checkpoint save interval

print('='*80)
print('超参数配置 | HYPERPARAMETER CONFIGURATION')
print('='*80)
print(f'\n📐 模型架构 | Model Architecture:')
print(f'   输入通道 Input channels: {MODEL_IN_CHANNELS}')
print(f'   基础通道 Base channels: {MODEL_BASE_CHANNELS}')
print(f'   网络层数 Network layers: {MODEL_NUM_LAYERS}')
print(f'   Dropout率 Dropout rate: {MODEL_DROPOUT_RATE}')
print(f'\n🎯 动作质量训练 | Action Quality Training:')
print(f'   训练轮数 Epochs: {ACTION_EPOCHS}')
print(f'   批次大小 Batch size: {ACTION_BATCH_SIZE}')
print(f'   学习率 Learning rate: {ACTION_LEARNING_RATE}')
print(f'   权重衰减 Weight decay: {ACTION_WEIGHT_DECAY}')
print(f'   预热轮数 Warmup epochs: {ACTION_WARMUP_EPOCHS}')
print(f'   梯度裁剪 Gradient clipping: {ACTION_GRAD_CLIP}')
print(f'\n🔄 训练轮数配置 | Training Rounds Configuration:')
print(f'   总训练轮数 Total rounds: {NUM_TRAINING_ROUNDS}')
print(f'   每轮epoch数 Epochs per round: {EPOCHS_PER_ROUND}')
print(f'\n📉 学习率调度器 | Learning Rate Scheduler:')
print(f'   衰减因子 Decay factor: {LR_SCHEDULER_FACTOR}')
print(f'   等待轮数 Patience: {LR_SCHEDULER_PATIENCE}')
print(f'   最小学习率 Min LR: {LR_SCHEDULER_MIN_LR}')
print(f'\n🔧 SVM配置 | SVM Configuration:')
print(f'   核函数 Kernel: {SVM_KERNEL}')
print(f'   C参数 C: {SVM_C}')
print(f'   Gamma: {SVM_GAMMA}')
print('='*80)
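
The training function in section 6 runs `num_rounds * epochs_per_round` epochs in total and derives the round shown in its log from the global epoch index. A small sketch of that bookkeeping with the values configured above; `round_of` is an illustrative helper name, not a function from the notebook:

```python
NUM_TRAINING_ROUNDS = 3
EPOCHS_PER_ROUND = 100

# Total number of epochs the training loop iterates over
total_epochs = NUM_TRAINING_ROUNDS * EPOCHS_PER_ROUND


def round_of(epoch: int) -> int:
    """Map a 0-based global epoch index to its 1-based training round."""
    return epoch // EPOCHS_PER_ROUND + 1


print(total_epochs)
for epoch in (0, 99, 100, 299):
    print(epoch, round_of(epoch))
```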

## 4. Model Architecture

### Expanded CNN Encoder (7 layers)

Key improvements to solve training issues:

1. **Deeper network**: 7 conv layers for more complex features

2. **Batch Normalization**: BN after each conv layer for faster convergence

3. **Kaiming Init**: Helps prevent vanishing/exploding gradients

4. **Residual connections**: Supported by `ImprovedConvBlock`, but only enabled when a block has stride 1 and equal input/output channels; with the stride-2 blocks used in the encoder below they stay inactive
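
Each encoder block uses a 3×3 convolution with stride 2 and padding 1, so every layer halves the spatial size. The sketch below applies the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, to a 256×256 input over seven layers; `conv_out_size` is an illustrative helper, not part of the notebook:

```python
def conv_out_size(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of one spatial dimension: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 256
sizes = [size]
for _ in range(7):  # seven stride-2 blocks
    size = conv_out_size(size)
    sizes.append(size)

print(sizes)
```

An eighth stride-2 block would reduce the map to 1×1, which is why the encoder caps the depth at eight layers for this input size.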

In [None]:
class ImprovedConvBlock(nn.Module):
    """
    改进的卷积块，包含BatchNorm和残差连接
    Improved conv block with BatchNorm and residual connection
    """
    def __init__(self, in_channels: int, out_channels: int, 
                 kernel_size: int = 3, stride: int = 2, padding: int = 1,
                 use_residual: bool = False):
        super(ImprovedConvBlock, self).__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, 
                             kernel_size=kernel_size, 
                             stride=stride, 
                             padding=padding,
                             bias=False)  # No bias needed: BatchNorm follows the conv
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        
        self.use_residual = use_residual and (in_channels == out_channels) and (stride == 1)
        
        # Kaiming初始化 | Kaiming initialization (a matches the LeakyReLU negative slope)
        nn.init.kaiming_normal_(self.conv.weight, a=0.2, mode='fan_out', nonlinearity='leaky_relu')
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        
        out = self.conv(x)
        out = self.bn(out)
        out = self.activation(out)
        
        if self.use_residual:
            out = out + identity
            
        return out


class ExpandedCNNEncoder(nn.Module):
    """
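    Configurable-depth CNN encoder for feature extraction
    
    Architecture:
    - Configurable number of conv layers (1-8)
    - Channels double per layer up to layer 6, then stay constant: 64 → 128 → 256 → 512 → ...
    - Each block: Conv2d + BatchNorm + LeakyReLU
    - Global average pooling produces the feature vector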
    
    Input: (B, 1, 256, 256)
    Output: (B, feature_dim)
    """
    def __init__(self, in_channels: int = 1, base_channels: int = 64, num_layers: int = 7):
        super(ExpandedCNNEncoder, self).__init__()
        
        # 验证参数 | Validate parameters
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")
        if num_layers > 8:
            raise ValueError(f"num_layers must be at most 8 for 256x256 input, got {num_layers}")
        
        self.num_layers = num_layers
        
        # Define channel progression
        # Channels double for the first six layers, then stay at the layer-6 width
        channels = []
        for i in range(num_layers):
            if i < 6:
                channels.append(base_channels * (2**i))
            else:
                # From layer 7 onward, keep the same width as layer 6
                channels.append(base_channels * (2**5))  # 2048 for base_channels=64
        
        # 构建编码器层 | Build encoder layers
        layers = []
        current_channels = in_channels
        
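        for i, out_channels in enumerate(channels):
            # Every block downsamples by 2. num_layers is capped at 8 above,
            # so a 256x256 input never shrinks below 1x1 (256 -> 1 after 8 halvings).
            stride = 2
            
            # Residual connections are requested from layer 6 onward when channel
            # counts match, but ImprovedConvBlock only enables them for stride 1,
            # so they stay inactive with the stride-2 blocks used here.
            use_residual = (i >= 5 and current_channels == out_channels)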
            
            layers.append(ImprovedConvBlock(
                in_channels=current_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                use_residual=use_residual
            ))
            current_channels = out_channels
        
        self.encoder = nn.Sequential(*layers)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = channels[-1]
        
        print(f'✅ 编码器已创建 | Encoder created:')
        print(f'   层数 Layers: {len(channels)}')
        print(f'   通道序列 Channels: {in_channels} → {" → ".join(map(str, channels))}')
        print(f'   输出特征维度 Output feature dim: {self.feature_dim}')
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """提取特征 | Extract features"""
        x = self.encoder(x)
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        return x
    
    def get_feature_dim(self) -> int:
        """返回特征维度 | Return feature dimension"""
        return self.feature_dim


class ActionQualityCNN(nn.Module):
    """
    动作质量分类器（深度学习）
    Action Quality Classifier (Deep Learning)
    
    3个类别 | 3 classes:
    - 0: Full (全程)
    - 1: Half (半程)
    - 2: Invalid (无效)
    """
    def __init__(self, encoder: ExpandedCNNEncoder, n_classes: int = 3, dropout_rate: float = 0.5):
        super(ActionQualityCNN, self).__init__()
        
        self.encoder = encoder
        feature_dim = encoder.get_feature_dim()
        
        # 自适应分类头：根据特征维度调整中间层大小
        # Adaptive classifier head: scale intermediate layers based on feature_dim
        hidden_dim_1 = min(max(256, feature_dim // 2), 1024)
        hidden_dim_2 = min(max(128, feature_dim // 4), 512)
        
        # 3层分类头，逐步降维 | 3-layer classification head with gradual dimension reduction
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim_1),
            nn.BatchNorm1d(hidden_dim_1),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            
            nn.Linear(hidden_dim_1, hidden_dim_2),
            nn.BatchNorm1d(hidden_dim_2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            
            nn.Linear(hidden_dim_2, n_classes)
        )
        
        # 初始化分类头 | Initialize classifier head
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
        print(f'✅ 动作质量分类器已创建 | Action Quality Classifier created:')
        print(f'   输入特征维度 Input feature dim: {feature_dim}')
        print(f'   分类头结构 Classifier: {feature_dim} → {hidden_dim_1} → {hidden_dim_2} → {n_classes}')
        print(f'   Dropout率 Dropout rate: {dropout_rate}')
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """前向传播 | Forward pass"""
        features = self.encoder(x)
        logits = self.classifier(features)
        return logits
    
    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """仅提取特征，用于SVM | Extract features only, for SVM"""
        return self.encoder(x)


print('\n✅ 模型类定义完成 | Model classes defined')

## 5. Data Loading

Load real HHT matrices from a Kaggle dataset or a local directory.

In [None]:
def parse_filename(filename: str) -> Optional[Dict[str, str]]:
    """
    解析文件名提取标签
    Parse filename to extract labels
    
    文件命名格式 | File naming format:
    - MUSCLENAME_movement_GENDER_###.npz
    - 例如 Example: BICEPS_fatiguetest_M_006.npz
    
    Returns:
        dict with 'gender' and 'movement' keys, or None for test files and
        for names whose gender or movement cannot be determined
    """
    basename = os.path.basename(filename)
    
    # 跳过测试文件 | Skip test files
    if basename.lower().startswith('test'):
        return None
    
    # 提取性别 | Extract gender
    gender_match = re.search(r'[_-]([MF])[_-]', basename)
    if not gender_match:
        return None
    gender = gender_match.group(1)
    
    # 提取动作质量 | Extract movement quality
    basename_lower = basename.lower()
    if 'fatiguetest' in basename_lower or 'full' in basename_lower:
        movement = 'full'
    elif 'half' in basename_lower:
        movement = 'half'
    elif 'invalid' in basename_lower or 'wrong' in basename_lower:
        movement = 'invalid'
    else:
        return None
    
    return {'gender': gender, 'movement': movement}


def load_real_data(data_dir: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """
    从目录加载真实数据
    Load real data from directory
    
    Returns:
        X: HHT matrices (N, 256, 256)
        y_movement: Movement labels (N,) - 0=full, 1=half, 2=invalid
        y_gender: Gender labels (N,) - 0=F, 1=M (LabelEncoder sorts classes alphabetically)
        filenames: List of filenames
    """
    print(f'\n📂 从目录加载数据 | Loading data from: {data_dir}')
    
    npz_files = glob.glob(os.path.join(data_dir, '*.npz'))
    print(f'   找到 Found {len(npz_files)} .npz files')
    
    X_list = []
    y_movement_list = []
    y_gender_list = []
    filenames = []
    
    # 标签编码器 | Label encoders
    movement_encoder = LabelEncoder()
    movement_encoder.fit(['full', 'half', 'invalid'])
    
    gender_encoder = LabelEncoder()
    gender_encoder.fit(['M', 'F'])
    
    # 加载数据 | Load data
    for npz_file in tqdm(npz_files, desc='Loading files'):
        labels = parse_filename(npz_file)
        
        if labels is None:  # 测试文件 | Test file
            continue
        
        try:
            data = np.load(npz_file)
            if 'hht' in data:
                hht_matrix = data['hht']
            else:
                hht_matrix = data[list(data.keys())[0]]
            
            # 验证形状 | Verify shape
            if hht_matrix.shape != (256, 256):
                print(f'   ⚠️  跳过 Skipping {os.path.basename(npz_file)}: 形状不匹配 shape mismatch {hht_matrix.shape}')
                continue
            
            # 归一化到[0,1] | Normalize to [0,1]
            if DATA_NORMALIZE:
                hht_min = hht_matrix.min()
                hht_max = hht_matrix.max()
                if hht_max > hht_min:
                    hht_matrix = (hht_matrix - hht_min) / (hht_max - hht_min)
            
            # 编码标签 | Encode labels
            movement_label = movement_encoder.transform([labels['movement']])[0]
            gender_label = gender_encoder.transform([labels['gender']])[0]
            
            X_list.append(hht_matrix)
            y_movement_list.append(movement_label)
            y_gender_list.append(gender_label)
            filenames.append(npz_file)
            
        except Exception as e:
            print(f'   ❌ 错误 Error loading {os.path.basename(npz_file)}: {e}')
            continue
    
    X = np.array(X_list, dtype=np.float32)
    y_movement = np.array(y_movement_list, dtype=np.int64)
    y_gender = np.array(y_gender_list, dtype=np.int64)
    
    print(f'\n✅ 数据加载完成 | Data loading complete:')
    print(f'   样本数 Samples: {len(X)}')
    print(f'   形状 Shape: {X.shape}')
    print(f'   数据范围 Data range: [{X.min():.4f}, {X.max():.4f}]')
    
    print(f'\n📊 动作质量分布 | Movement quality distribution:')
    for i, movement in enumerate(['full', 'half', 'invalid']):
        count = np.sum(y_movement == i)
        print(f'   {movement}: {count} samples ({count/len(y_movement)*100:.1f}%)')
    
    print(f'\n👥 性别分布 | Gender distribution:')
    for i, gender in enumerate(gender_encoder.classes_):  # sorted order: F=0, M=1
        count = np.sum(y_gender == i)
        print(f'   {gender}: {count} samples ({count/len(y_gender)*100:.1f}%)')
    
    return X, y_movement, y_gender, filenames


# 加载数据 | Load data
if os.path.exists(DATA_DIR):
    X, y_movement, y_gender, filenames = load_real_data(DATA_DIR)
    
    # 分割数据 | Split data
    X_train, X_val, y_movement_train, y_movement_val, y_gender_train, y_gender_val = train_test_split(
        X, y_movement, y_gender, 
        test_size=DATA_TEST_SIZE, 
        random_state=SEED, 
        stratify=y_movement  # 按动作质量分层 | Stratify by movement quality
    )
    
    print(f'\n✂️  数据分割 | Data split:')
    print(f'   训练集 Training: {len(X_train)} samples')
    print(f'   验证集 Validation: {len(X_val)} samples')
    
else:
    print(f'\n❌ 数据目录未找到 | Data directory not found: {DATA_DIR}')
    print('   请确保数据集已添加到此笔记本 | Please ensure dataset is added to this notebook')
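
The loader above rescales every matrix independently to [0, 1] and leaves a constant matrix untouched, which avoids a division by zero. A dependency-free sketch of that rule on a tiny nested list, written in plain Python for illustration (the notebook applies it to NumPy arrays); `min_max_normalize` is an illustrative name:

```python
def min_max_normalize(matrix):
    """Rescale a 2D list to [0, 1]; return a copy unchanged when all values are equal."""
    flat = [value for row in matrix for value in row]
    lo, hi = min(flat), max(flat)
    if hi <= lo:
        return [row[:] for row in matrix]
    return [[(value - lo) / (hi - lo) for value in row] for row in matrix]


print(min_max_normalize([[2.0, 4.0], [6.0, 10.0]]))
print(min_max_normalize([[3.0, 3.0], [3.0, 3.0]]))
```

Because the scaling is per matrix, absolute amplitude differences between recordings are discarded; only the relative energy pattern within each HHT matrix reaches the model.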

## 6. Train Action Quality Classifier (Deep Learning CNN)

Train the action quality classifier using the expanded 7-layer CNN architecture.

### Training Improvements:

1. **LR Warmup** - Linearly increase the LR over the first 5 epochs

2. **Gradient Clipping** - Prevent gradient explosion

3. **Label Smoothing** - Improve generalization

4. **Cosine Annealing** - Smoother LR decay after warmup
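
The warmup and cosine phases combine into a single multiplier applied to the base learning rate. A minimal sketch of that multiplier in plain Python, assuming 5 warmup epochs, a 100-epoch horizon, and a base LR of 1e-4 purely for illustration; `lr_multiplier` is an illustrative name:

```python
import math


def lr_multiplier(epoch: int, warmup_epochs: int = 5, total_epochs: int = 100) -> float:
    """Linear warmup to 1.0, then cosine decay towards 0 at total_epochs."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


base_lr = 1e-4
for epoch in (0, 4, 5, 52, 99):
    print(epoch, f"{base_lr * lr_multiplier(epoch):.2e}")
```

Note that the cosine term is only monotonic while `epoch < total_epochs`; if training continues past the horizon passed to the schedule, the multiplier starts rising again.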

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    标签平滑交叉熵损失
    Label smoothing cross entropy loss
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
    
    def forward(self, pred, target):
        n_classes = pred.size(-1)
        log_preds = F.log_softmax(pred, dim=-1)
        
        # 平滑目标 | Smooth targets
        with torch.no_grad():
            true_dist = torch.zeros_like(log_preds)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        
        return torch.mean(torch.sum(-true_dist * log_preds, dim=-1))


def get_lr_schedule(optimizer, warmup_epochs, total_epochs, base_lr):
    """
    创建学习率调度器（预热 + 余弦退火）
    Create LR scheduler (warmup + cosine annealing)
    """
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # 线性预热 | Linear warmup
            return (epoch + 1) / warmup_epochs
        else:
            # 余弦退火 | Cosine annealing
            progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
            return 0.5 * (1.0 + np.cos(np.pi * progress))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def train_action_quality_model(
    model, 
    X_train, y_train, 
    X_val, y_val,
    epochs=100,
    batch_size=16,
    learning_rate=0.0001,
    warmup_epochs=5,
    grad_clip=1.0,
    device='cuda',
    resume_from=None,
    num_rounds=1,
    epochs_per_round=100
):
    """
    训练动作质量分类模型
    Train action quality classification model
    """
    print('\n' + '='*80)
    print('开始训练动作质量分类器 | Starting Action Quality Classifier Training')
    print('='*80)
    
    model = model.to(device)
    
    # 准备数据 | Prepare data
    if X_train.ndim == 3:
        X_train = X_train[:, np.newaxis, :, :]  # Add channel dim
        X_val = X_val[:, np.newaxis, :, :]
    
    train_dataset = torch.utils.data.TensorDataset(
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long)
    )
    val_dataset = torch.utils.data.TensorDataset(
        torch.tensor(X_val, dtype=torch.float32),
        torch.tensor(y_val, dtype=torch.long)
    )
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, 
        num_workers=0, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=0, pin_memory=True
    )
    
    # 损失函数和优化器 | Loss and optimizer
    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=learning_rate, 
        weight_decay=ACTION_WEIGHT_DECAY
    )
    
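    # 学习率调度器 | LR scheduler
    # Span the schedule over the whole run (all rounds); a shorter horizon would
    # make the cosine term rise again once the epoch index passes it
    scheduler = get_lr_schedule(optimizer, warmup_epochs, num_rounds * epochs_per_round, learning_rate)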
    
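    # ReduceLROnPlateau作为备用 | ReduceLROnPlateau as backup
    # Note: LambdaLR recomputes the LR from the base LR on every step, so a
    # plateau reduction is overwritten at the next scheduler.step()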
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=LR_SCHEDULER_FACTOR, 
        patience=LR_SCHEDULER_PATIENCE, min_lr=LR_SCHEDULER_MIN_LR
    )
    
    # 训练历史 | Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'lr': []
    }
    
    best_val_acc = 0.0
    best_model_path = os.path.join(CHECKPOINT_DIR, 'best_action_quality_model.pt')
    start_epoch = 0
    start_round = 0
    
    # 尝试从检查点恢复 | Try to resume from checkpoint
    if resume_from and os.path.exists(resume_from):
        print(f'\n📂 从检查点恢复训练 | Resuming training from checkpoint: {resume_from}')
        checkpoint = torch.load(resume_from, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'history' in checkpoint:
            history = checkpoint['history']
        if 'best_val_acc' in checkpoint:
            best_val_acc = checkpoint['best_val_acc']
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1
        if 'round' in checkpoint:
            start_round = checkpoint['round']
        print(f'   ✅ 已恢复到epoch {start_epoch}, 最佳验证准确率: {best_val_acc:.4f}')
        print(f'   ✅ Resumed to epoch {start_epoch}, best val acc: {best_val_acc:.4f}')
    
    # 计算总epoch数 | Calculate total epochs
    total_epochs = num_rounds * epochs_per_round
    
    print(f'\n🚀 训练配置 | Training configuration:')
    print(f'   设备 Device: {device}')
    print(f'   训练样本 Training samples: {len(X_train)}')
    print(f'   验证样本 Validation samples: {len(X_val)}')
    print(f'   训练轮数 Training rounds: {num_rounds}')
    print(f'   每轮epoch数 Epochs per round: {epochs_per_round}')
    print(f'   总epoch数 Total epochs: {total_epochs}')
    print(f'   起始epoch Starting epoch: {start_epoch}')
    print(f'   批次大小 Batch size: {batch_size}')
    print(f'   学习率 Learning rate: {learning_rate}')
    print(f'   预热轮数 Warmup epochs: {warmup_epochs}')
    print(f'   梯度裁剪 Gradient clipping: {grad_clip}')
    print(f'   权重衰减 Weight decay: {ACTION_WEIGHT_DECAY}')
    print()
    
    # 训练循环 | Training loop
    for epoch in range(start_epoch, total_epochs):
        current_round = epoch // epochs_per_round + 1
        # 训练阶段 | Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{total_epochs}')
        for batch_X, batch_y in pbar:
            batch_X = batch_X.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            
            # 梯度裁剪 | Gradient clipping
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            
            optimizer.step()
            
            train_loss += loss.item() * batch_X.size(0)
            _, predicted = outputs.max(1)
            train_total += batch_y.size(0)
            train_correct += predicted.eq(batch_y).sum().item()
            
            # 更新进度条 | Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*train_correct/train_total:.2f}%'
            })
        
        train_loss /= train_total
        train_acc = train_correct / train_total
        
        # 验证阶段 | Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                batch_X = batch_X.to(device, non_blocking=True)
                batch_y = batch_y.to(device, non_blocking=True)
                
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                
                val_loss += loss.item() * batch_X.size(0)
                _, predicted = outputs.max(1)
                val_total += batch_y.size(0)
                val_correct += predicted.eq(batch_y).sum().item()
        
        val_loss /= val_total
        val_acc = val_correct / val_total
        
        # 更新学习率 | Update learning rate
        scheduler.step()
        plateau_scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']
        
        # 保存历史 | Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)
        
        # 打印进度 | Print progress
        print(f'Round [{current_round}/{num_rounds}] Epoch [{epoch+1:3d}/{total_epochs}] | '
              f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | '
              f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | '
              f'LR: {current_lr:.6f}')
        
        # 保存最佳模型 | Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            temp_best_path = best_model_path + '.tmp'
            try:
                torch.save({
                    'epoch': epoch,
                    'round': current_round,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_acc': best_val_acc,
                    'history': history
                }, temp_best_path)
                os.replace(temp_best_path, best_model_path)
                print(f'  ⭐ 新最佳模型！| New best model! Val Acc: {val_acc:.4f}')
            except Exception as e:
                print(f'  ❌ 保存最佳模型失败 | Failed to save best model: {e}')
                if os.path.exists(temp_best_path):
                    os.remove(temp_best_path)
        
        # 定期保存检查点 | Save checkpoint periodically
        if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f'action_quality_round_{current_round}_epoch_{epoch+1}.pt')
            temp_path = checkpoint_path + '.tmp'
            
            # 检查磁盘空间 | Check disk space
            try:
                import shutil
                stat = shutil.disk_usage(CHECKPOINT_DIR)
                available_gb = stat.free / (1024 ** 3)
                if available_gb < 0.5:  # 少于500MB | Less than 500MB
                    print(f'  ⚠️  磁盘空间不足 ({available_gb:.2f} GB)，跳过检查点保存')
                    print(f'  ⚠️  Low disk space ({available_gb:.2f} GB), skipping checkpoint')
                else:
                    # 使用临时文件原子写入 | Atomic write with temp file
                    torch.save({
                        'epoch': epoch,
                        'round': current_round,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'best_val_acc': best_val_acc,
                        'history': history
                    }, temp_path)
                    os.replace(temp_path, checkpoint_path)
                    print(f'  💾 检查点已保存 | Checkpoint saved: {checkpoint_path}')
            except Exception as e:
                print(f'  ❌ 保存检查点失败 | Failed to save checkpoint: {e}')
                if os.path.exists(temp_path):
                    os.remove(temp_path)
    
    print('\n' + '='*80)
    print(f'✅ 训练完成！| Training complete!')
    print(f'   最佳验证准确率 Best validation accuracy: {best_val_acc:.4f}')
    print(f'   最佳模型已保存 Best model saved to: {best_model_path}')
    print('='*80)
    
    return history, best_model_path


# 创建并训练模型 | Create and train model
if 'X_train' in locals():
    # 创建编码器 | Create encoder
    encoder = ExpandedCNNEncoder(
        in_channels=MODEL_IN_CHANNELS,
        base_channels=MODEL_BASE_CHANNELS,
        num_layers=MODEL_NUM_LAYERS
    )
    
    # 创建动作质量分类器 | Create action quality classifier
    action_model = ActionQualityCNN(
        encoder=encoder,
        n_classes=3,  # full, half, invalid
        dropout_rate=MODEL_DROPOUT_RATE
    )
    
    # 检查是否有之前的检查点 | Check for previous checkpoint
    resume_checkpoint = os.path.join(CHECKPOINT_DIR, 'best_action_quality_model.pt')
    if not os.path.exists(resume_checkpoint):
        resume_checkpoint = None
        print('\n🆕 开始新的训练 | Starting new training')
    else:
        print(f'\n♻️  发现检查点，将继续训练 | Found checkpoint, will resume training')
    
    # 训练 | Train
    action_history, action_best_path = train_action_quality_model(
        model=action_model,
        X_train=X_train,
        y_train=y_movement_train,
        X_val=X_val,
        y_val=y_movement_val,
        epochs=ACTION_EPOCHS,
        batch_size=ACTION_BATCH_SIZE,
        learning_rate=ACTION_LEARNING_RATE,
        warmup_epochs=ACTION_WARMUP_EPOCHS,
        grad_clip=ACTION_GRAD_CLIP,
        device=device,
        resume_from=resume_checkpoint,
        num_rounds=NUM_TRAINING_ROUNDS,
        epochs_per_round=EPOCHS_PER_ROUND
    )
else:
    print('\n⚠️  请先加载数据 | Please load data first')

## 7. 可视化训练过程 | Visualize Training Process

In [None]:
def plot_training_history(history):
    """绘制训练历史 | Plot training history"""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss
    axes[0].plot(epochs, history['train_loss'], 'b-', label='训练损失 Train Loss', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], 'r-', label='验证损失 Val Loss', linewidth=2)
    axes[0].set_xlabel('轮次 Epoch', fontsize=12)
    axes[0].set_ylabel('损失 Loss', fontsize=12)
    axes[0].set_title('训练和验证损失 | Training and Validation Loss', fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[1].plot(epochs, history['train_acc'], 'b-', label='训练准确率 Train Acc', linewidth=2)
    axes[1].plot(epochs, history['val_acc'], 'r-', label='验证准确率 Val Acc', linewidth=2)
    axes[1].set_xlabel('轮次 Epoch', fontsize=12)
    axes[1].set_ylabel('准确率 Accuracy', fontsize=12)
    axes[1].set_title('训练和验证准确率 | Training and Validation Accuracy', fontsize=14, fontweight='bold')
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)
    
    # Learning Rate
    axes[2].plot(epochs, history['lr'], 'g-', linewidth=2)
    axes[2].set_xlabel('轮次 Epoch', fontsize=12)
    axes[2].set_ylabel('学习率 Learning Rate', fontsize=12)
    axes[2].set_title('学习率调度 | Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[2].set_yscale('log')
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # 打印统计 | Print statistics
    print(f'\n📊 训练统计 | Training Statistics:')
    print(f'   最终训练准确率 Final train acc: {history["train_acc"][-1]:.4f}')
    print(f'   最终验证准确率 Final val acc: {history["val_acc"][-1]:.4f}')
    print(f'   最佳验证准确率 Best val acc: {max(history["val_acc"]):.4f}')
    print(f'   最终训练损失 Final train loss: {history["train_loss"][-1]:.4f}')
    print(f'   最终验证损失 Final val loss: {history["val_loss"][-1]:.4f}')


if 'action_history' in locals():
    plot_training_history(action_history)
else:
    print('\n⚠️  请先训练模型 | Please train the model first')

## 8. 训练性别分类器（SVM）| Train Gender Classifier (SVM)

使用训练好的动作质量CNN提取特征，然后在这些特征上训练SVM进行性别分类。
Use the trained action quality CNN to extract features, then train an SVM on those features for gender classification.
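
The second stage (standardize the features, then fit an SVM) can be sketched on its own, without the CNN. This is a minimal illustration under stated assumptions: the feature matrix is synthetic, the 16-dimensional size is chosen only to keep the sketch fast, and the label encoding (0 = M, 1 = F) is assumed rather than taken from the data loading code.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic stand-in for CNN features (assumption: the real features are
# 2048-dim; 16 dims keep this sketch fast). Labels: 0 = M, 1 = F (assumed).
rng = np.random.default_rng(0)
n_per_class, dim = 40, 16
features = np.vstack([
    rng.normal(loc=-3.0, scale=1.0, size=(n_per_class, dim)),
    rng.normal(loc=+3.0, scale=1.0, size=(n_per_class, dim)),
])
labels = np.array([0] * n_per_class + [1] * n_per_class)

# Hold out every fourth sample for validation.
val_mask = np.arange(len(labels)) % 4 == 0
X_tr, y_tr = features[~val_mask], labels[~val_mask]
X_va, y_va = features[val_mask], labels[val_mask]

# The scaler is fit on training features only; the pipeline reuses those
# statistics when it transforms the validation features.
clf = make_pipeline(StandardScaler(), SVC(kernel='rbf', C=10.0, gamma='scale'))
clf.fit(X_tr, y_tr)
val_accuracy = clf.score(X_va, y_va)
print(f'validation accuracy: {val_accuracy:.2f}')
```

Wrapping the scaler and the SVM in one pipeline is a design alternative to keeping them as separate attributes: it makes it harder to apply the SVM to unscaled features by accident, and the pair can be pickled as a single object.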

In [None]:
class GenderSVMClassifier:
    """
    性别SVM分类器
    Gender SVM Classifier
    
    使用CNN特征 + SVM分类器
    Uses CNN features + SVM classifier
    """
    def __init__(self, feature_extractor, svm_kernel='rbf', svm_C=10.0, svm_gamma='scale', device='cuda'):
        self.feature_extractor = feature_extractor
        self.device = device
        self.scaler = StandardScaler()
        self.svm = SVC(
            kernel=svm_kernel,
            C=svm_C,
            gamma=svm_gamma,
            probability=True,
            random_state=SEED
        )
        self.is_fitted = False
        
        print(f'✅ 性别分类器已创建 | Gender classifier created:')
        print(f'   特征提取器 Feature extractor: CNN')
        print(f'   SVM核函数 SVM kernel: {svm_kernel}')
        print(f'   C参数 C: {svm_C}')
        print(f'   Gamma: {svm_gamma}')
    
    def extract_features(self, X, batch_size=32):
        """提取特征 | Extract features"""
        self.feature_extractor.eval()
        
        if X.ndim == 3:
            X = X[:, np.newaxis, :, :]
        
        features_list = []
        with torch.no_grad():
            for i in tqdm(range(0, len(X), batch_size), desc='Extracting features'):
                batch = torch.tensor(X[i:i+batch_size], dtype=torch.float32).to(self.device)
                batch_features = self.feature_extractor.extract_features(batch)
                features_list.append(batch_features.cpu().numpy())
        
        return np.vstack(features_list)
    
    def fit(self, X, y, batch_size=32):
        """训练SVM | Train SVM"""
        print(f'\n🔧 开始训练性别SVM分类器 | Starting Gender SVM Classifier Training')
        
        # 提取特征 | Extract features
        features = self.extract_features(X, batch_size)
        
        # 归一化 | Normalize
        print('   归一化特征 | Normalizing features...')
        features_scaled = self.scaler.fit_transform(features)
        
        # 训练SVM | Train SVM
        print('   训练SVM | Training SVM...')
        self.svm.fit(features_scaled, y)
        
        self.is_fitted = True
        print('   ✅ 训练完成 | Training complete!')
    
    def predict(self, X, batch_size=32):
        """预测 | Predict"""
        if not self.is_fitted:
            raise RuntimeError('必须先训练模型 | Must fit model first')
        
        features = self.extract_features(X, batch_size)
        features_scaled = self.scaler.transform(features)
        return self.svm.predict(features_scaled)
    
    def predict_proba(self, X, batch_size=32):
        """预测概率 | Predict probabilities"""
        if not self.is_fitted:
            raise RuntimeError('必须先训练模型 | Must fit model first')
        
        features = self.extract_features(X, batch_size)
        features_scaled = self.scaler.transform(features)
        return self.svm.predict_proba(features_scaled)
    
    def evaluate(self, X, y, batch_size=32):
        """评估模型 | Evaluate model"""
        y_pred = self.predict(X, batch_size)
        accuracy = accuracy_score(y, y_pred)
        
        return {
            'accuracy': accuracy,
            'predictions': y_pred,
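            # labels=[0, 1] 保证即使某一类缺失，报告和矩阵仍为两类（假设 0 = M，1 = F）| labels=[0, 1] keeps the report and matrix at two classes even if one class is absent (assumes 0 = M, 1 = F)
            'classification_report': classification_report(y, y_pred, labels=[0, 1], target_names=['M', 'F']),
            'confusion_matrix': confusion_matrix(y, y_pred, labels=[0, 1])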
        }
    
    def save(self, path):
        """保存模型 | Save model"""
        with open(f'{path}_scaler.pkl', 'wb') as f:
            pickle.dump(self.scaler, f)
        with open(f'{path}_svm.pkl', 'wb') as f:
            pickle.dump(self.svm, f)
        print(f'💾 模型已保存 | Model saved to {path}_*.pkl')
    
    @classmethod
    def load(cls, path, feature_extractor, device='cuda'):
        """加载模型 | Load model"""
        classifier = cls(feature_extractor, device=device)
        with open(f'{path}_scaler.pkl', 'rb') as f:
            classifier.scaler = pickle.load(f)
        with open(f'{path}_svm.pkl', 'rb') as f:
            classifier.svm = pickle.load(f)
        classifier.is_fitted = True
        print(f'📂 模型已加载 | Model loaded from {path}_*.pkl')
        return classifier


# 训练性别分类器 | Train gender classifier
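if 'action_model' in locals() and 'X_train' in locals() and 'action_best_path' in locals():
    # 加载最佳动作质量模型用于特征提取 | Load best action quality model for feature extraction
    # weights_only=False 会反序列化任意对象，仅用于本笔记本自己保存的检查点 | weights_only=False unpickles arbitrary objects; use it only for checkpoints written by this notebook
    checkpoint = torch.load(action_best_path, map_location=device, weights_only=False)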
    action_model.load_state_dict(checkpoint['model_state_dict'])
    action_model.eval()
    
    # 创建性别分类器 | Create gender classifier
    gender_classifier = GenderSVMClassifier(
        feature_extractor=action_model,
        svm_kernel=SVM_KERNEL,
        svm_C=SVM_C,
        svm_gamma=SVM_GAMMA,
        device=device
    )
    
    # 训练 | Train
    gender_classifier.fit(X_train, y_gender_train, batch_size=32)
    
    # 评估训练集 | Evaluate on training set
    print(f'\n📊 训练集评估 | Training Set Evaluation:')
    train_results = gender_classifier.evaluate(X_train, y_gender_train, batch_size=32)
    print(f'   训练准确率 Training Accuracy: {train_results["accuracy"]:.4f}')
    
    # 评估验证集 | Evaluate on validation set
    print(f'\n📊 验证集评估 | Validation Set Evaluation:')
    val_results = gender_classifier.evaluate(X_val, y_gender_val, batch_size=32)
    print(f'   验证准确率 Validation Accuracy: {val_results["accuracy"]:.4f}')
    print(f'\n分类报告 | Classification Report:')
    print(val_results['classification_report'])
    
    # 保存模型 | Save model
    gender_model_path = os.path.join(CHECKPOINT_DIR, 'gender_svm_model')
    gender_classifier.save(gender_model_path)
    
else:
    print('\n⚠️  请先训练动作质量模型 | Please train action quality model first')

## 9. 综合评估 | Comprehensive Evaluation

在验证集上评估两个分类器的性能。
Evaluate the performance of both classifiers on the validation set.
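
The per-class numbers in a classification report can be read directly off the confusion matrix. The sketch below shows that relationship in plain Python; `per_class_metrics` is a hypothetical helper, and the counts are invented for illustration, not results from this notebook.

```python
def per_class_metrics(cm, class_names):
    """Return {class: (precision, recall)} from a square confusion matrix.

    Rows are true labels and columns are predicted labels, which matches
    the layout returned by sklearn.metrics.confusion_matrix.
    """
    n = len(class_names)
    metrics = {}
    for k, name in enumerate(class_names):
        true_pos = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision = true_pos / predicted_k if predicted_k else 0.0
        recall = true_pos / actual_k if actual_k else 0.0
        metrics[name] = (precision, recall)
    return metrics


# Hypothetical counts, for illustration only.
example_cm = [
    [8, 2, 0],   # true Full
    [1, 7, 2],   # true Half
    [0, 1, 9],   # true Invalid
]
example_metrics = per_class_metrics(example_cm, ['Full', 'Half', 'Invalid'])
for name, (p, r) in example_metrics.items():
    print(f'{name:8s} precision={p:.2f} recall={r:.2f}')
```

Reading the matrix row by row gives recall (how many true samples of a class were found), while reading it column by column gives precision (how many predictions of a class were right).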

In [None]:
# 综合评估 | Comprehensive evaluation
if 'action_model' in locals() and 'gender_classifier' in locals():
    print('='*80)
    print('综合评估报告 | COMPREHENSIVE EVALUATION REPORT')
    print('='*80)
    
    # 动作质量评估 | Action quality evaluation
    print(f'\n🎯 动作质量分类器（深度学习CNN）| Action Quality Classifier (Deep Learning CNN)')
    print('-'*80)
    
    action_model.eval()
    X_val_4d = X_val[:, np.newaxis, :, :] if X_val.ndim == 3 else X_val
    
    # 分批推理，避免一次性将整个验证集送入GPU | Run inference in batches so the whole validation set is never on the GPU at once
    batch_preds = []
    with torch.no_grad():
        for i in range(0, len(X_val_4d), 32):
            batch = torch.tensor(X_val_4d[i:i+32], dtype=torch.float32).to(device)
            batch_preds.append(action_model(batch).argmax(dim=1).cpu().numpy())
    
    y_pred_action = np.concatenate(batch_preds)
    action_acc = float((y_pred_action == np.asarray(y_movement_val)).mean())
    
    print(f'验证集准确率 Validation Accuracy: {action_acc:.4f}')
    print(f'\n分类报告 Classification Report:')
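    # labels=[0, 1, 2] 保证即使某一类缺失，报告和矩阵仍为三类 | labels=[0, 1, 2] keeps the report and matrix at three classes even if one class is absent
    print(classification_report(y_movement_val, y_pred_action, labels=[0, 1, 2], target_names=['Full', 'Half', 'Invalid']))
    
    print(f'\n混淆矩阵 Confusion Matrix:')
    cm_action = confusion_matrix(y_movement_val, y_pred_action, labels=[0, 1, 2])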
    print(cm_action)
    
    # 性别评估 | Gender evaluation
    print(f'\n👥 性别分类器（SVM）| Gender Classifier (SVM)')
    print('-'*80)
    
    val_results_gender = gender_classifier.evaluate(X_val, y_gender_val, batch_size=32)
    print(f'验证集准确率 Validation Accuracy: {val_results_gender["accuracy"]:.4f}')
    print(f'\n分类报告 Classification Report:')
    print(val_results_gender['classification_report'])
    
    print(f'\n混淆矩阵 Confusion Matrix:')
    print(val_results_gender['confusion_matrix'])
    
    # 可视化混淆矩阵 | Visualize confusion matrices
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # 动作质量混淆矩阵 | Action quality confusion matrix
    im1 = axes[0].imshow(cm_action, cmap='Blues')
    axes[0].set_title('动作质量混淆矩阵\nAction Quality Confusion Matrix', fontsize=12, fontweight='bold')
    axes[0].set_xlabel('预测标签 Predicted', fontsize=10)
    axes[0].set_ylabel('真实标签 True', fontsize=10)
    axes[0].set_xticks([0, 1, 2])
    axes[0].set_yticks([0, 1, 2])
    axes[0].set_xticklabels(['Full', 'Half', 'Invalid'])
    axes[0].set_yticklabels(['Full', 'Half', 'Invalid'])
    
    # 添加数值标注 | Add value annotations
    for i in range(3):
        for j in range(3):
            axes[0].text(j, i, str(cm_action[i, j]), 
                        ha='center', va='center', color='white' if cm_action[i, j] > cm_action.max()/2 else 'black')
    plt.colorbar(im1, ax=axes[0])
    
    # 性别混淆矩阵 | Gender confusion matrix
    cm_gender = val_results_gender['confusion_matrix']
    im2 = axes[1].imshow(cm_gender, cmap='Greens')
    axes[1].set_title('性别混淆矩阵\nGender Confusion Matrix', fontsize=12, fontweight='bold')
    axes[1].set_xlabel('预测标签 Predicted', fontsize=10)
    axes[1].set_ylabel('真实标签 True', fontsize=10)
    axes[1].set_xticks([0, 1])
    axes[1].set_yticks([0, 1])
    axes[1].set_xticklabels(['M', 'F'])
    axes[1].set_yticklabels(['M', 'F'])
    
    # 添加数值标注 | Add value annotations
    for i in range(2):
        for j in range(2):
            axes[1].text(j, i, str(cm_gender[i, j]),
                        ha='center', va='center', color='white' if cm_gender[i, j] > cm_gender.max()/2 else 'black')
    plt.colorbar(im2, ax=axes[1])
    
    plt.tight_layout()
    plt.show()
    
    print('\n' + '='*80)
    print('✅ 评估完成 | Evaluation Complete')
    print('='*80)
    
else:
    print('\n⚠️  请先训练两个模型 | Please train both models first')

## 10. 总结和下一步 | Summary and Next Steps

### 模型保存位置 | Model Save Locations:

- **动作质量分类器** Action Quality Classifier: `{CHECKPOINT_DIR}/best_action_quality_model.pt`
- **性别SVM分类器** Gender SVM Classifier: `{CHECKPOINT_DIR}/gender_svm_model_*.pkl`
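
Before reusing the models in a later session, it can help to confirm that all three files are actually on disk. The sketch below builds the paths from the file names used by the save steps in this notebook; `expected_artifacts` and `missing_artifacts` are hypothetical helpers, not part of the notebook.

```python
import os


def expected_artifacts(checkpoint_dir):
    """Map artifact names to the paths the save steps write to."""
    gender_prefix = os.path.join(checkpoint_dir, 'gender_svm_model')
    return {
        'action_model': os.path.join(checkpoint_dir, 'best_action_quality_model.pt'),
        'gender_scaler': f'{gender_prefix}_scaler.pkl',
        'gender_svm': f'{gender_prefix}_svm.pkl',
    }


def missing_artifacts(checkpoint_dir):
    """Return the names of artifacts that are not on disk yet."""
    return [name for name, path in expected_artifacts(checkpoint_dir).items()
            if not os.path.isfile(path)]


# './checkpoints' is the local default from the environment setup cell.
print(missing_artifacts('./checkpoints'))
```

An empty list means both classifiers can be restored: the `.pt` file through `torch.load`, and the two `.pkl` files through `GenderSVMClassifier.load`.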

### 关键改进总结 | Key Improvements Summary:

1. ✅ **扩展的7层CNN** - 更强的特征提取能力
   **Expanded 7-layer CNN** - Stronger feature extraction

2. ✅ **批归一化 + Kaiming初始化** - 缓解梯度消失，保持梯度稳定流动
   **BatchNorm + Kaiming Init** - Mitigate vanishing gradients and keep gradient flow stable

3. ✅ **学习率预热和调度** - 更稳定的训练
   **LR warmup and scheduling** - More stable training

4. ✅ **梯度裁剪** - 防止梯度爆炸
   **Gradient clipping** - Prevent gradient explosion

5. ✅ **标签平滑** - 提高泛化能力
   **Label smoothing** - Better generalization

6. ✅ **双分类器系统** - 每个任务使用专门的分类器（CNN负责动作质量，SVM负责性别）
   **Dual classifier system** - A dedicated classifier per task (CNN for action quality, SVM for gender)
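
To make item 3 concrete, here is a minimal sketch of one common warmup schedule: a linear ramp followed by cosine decay. The cosine shape and the values (`base_lr=1e-3`, 5 warmup epochs, 50 total epochs) are assumptions chosen for illustration; the notebook's actual schedule is controlled by `ACTION_LEARNING_RATE` and `ACTION_WARMUP_EPOCHS` and may differ. `lr_at_epoch` is a hypothetical helper.

```python
import math


def lr_at_epoch(epoch, base_lr, warmup_epochs, total_epochs, min_lr=0.0):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Illustrative values only (assumptions, not the notebook's settings).
schedule = [lr_at_epoch(e, base_lr=1e-3, warmup_epochs=5, total_epochs=50)
            for e in range(50)]
print([f'{lr:.2e}' for lr in schedule[:7]])
```

Starting the ramp at `base_lr / warmup_epochs` rather than at zero keeps every value positive, which matters when the schedule is drawn on a logarithmic axis.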

### 使用建议 | Usage Recommendations:

1. **调整超参数** - 根据数据集大小调整学习率、批次大小等
   **Tune hyperparameters** - Adjust LR, batch size based on dataset size

2. **数据增强** - 如果训练数据较少，可以加强或扩展数据增强
   **Data augmentation** - Strengthen or extend it if training data is limited

3. **模型集成** - 可以训练多个模型并集成预测结果
   **Model ensemble** - Train multiple models and ensemble predictions

4. **持续监控** - 观察训练曲线，确保loss下降、accuracy提升
   **Monitor training** - Watch training curves, ensure loss decreases and accuracy improves
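
For recommendation 3, one simple way to combine several trained models is soft voting: average their class probabilities and take the most likely class. The sketch below uses plain Python lists; `soft_vote` is a hypothetical helper, and the probabilities are invented for illustration.

```python
def soft_vote(prob_sets):
    """Average class probabilities across models and pick the top class.

    prob_sets: list over models; each entry is a list over samples of
    per-class probability lists (for example softmax outputs).
    Returns (averaged_probs, predicted_labels).
    """
    n_models = len(prob_sets)
    averaged, labels = [], []
    for sample_probs in zip(*prob_sets):  # one tuple of model outputs per sample
        mean = [sum(col) / n_models for col in zip(*sample_probs)]
        averaged.append(mean)
        labels.append(max(range(len(mean)), key=mean.__getitem__))
    return averaged, labels


# Made-up probabilities from three hypothetical models for two samples,
# over the classes (Full, Half, Invalid).
model_a = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
model_b = [[0.5, 0.4, 0.1], [0.1, 0.3, 0.6]]
model_c = [[0.7, 0.2, 0.1], [0.2, 0.3, 0.5]]
avg, pred = soft_vote([model_a, model_b, model_c])
print(pred)
```

In the second sample, `model_a` alone would predict Half, but the averaged probabilities favor Invalid, which is the kind of single-model disagreement an ensemble is meant to smooth out.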

### 预期效果 | Expected Results:

- **损失下降** Loss decreases: 损失曲线应呈现明显的下降趋势
  The loss curve should show a clear downward trend
  
- **准确率提升** Accuracy improves: 准确率应该稳步提升
  Accuracy should improve steadily
  
- **收敛稳定** Stable convergence: 训练通常应在50-100轮内收敛，具体取决于数据集大小和学习率
  Training should typically converge within 50-100 epochs, depending on dataset size and learning rate

如果仍然遇到训练问题，请检查：
If you still encounter training issues, check:

1. 数据质量和分布 | Data quality and class distribution
2. 学习率是否合适 | Whether the learning rate is appropriate
3. 批次大小是否合适 | Whether the batch size is appropriate
4. 是否需要更多数据 | Whether more data is needed
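
For the first check, a quick look at the class counts often reveals whether imbalance is the problem. The sketch below is a minimal example in plain Python; `class_balance` is a hypothetical helper, the labels are made up, and the encoding (0 = Full, 1 = Half, 2 = Invalid) is an assumption.

```python
from collections import Counter


def class_balance(labels, class_names):
    """Return per-class counts and the max/min imbalance ratio."""
    counts = Counter(labels)
    per_class = {name: counts.get(idx, 0) for idx, name in enumerate(class_names)}
    smallest = min(per_class.values())
    ratio = max(per_class.values()) / smallest if smallest else float('inf')
    return per_class, ratio


# Made-up labels, for illustration only; in the notebook this would be
# y_movement_train (assumed encoding: 0 = Full, 1 = Half, 2 = Invalid).
toy_labels = [0] * 60 + [1] * 30 + [2] * 10
per_class, ratio = class_balance(toy_labels, ['Full', 'Half', 'Invalid'])
print(per_class, f'imbalance ratio: {ratio:.1f}')
```

A large ratio suggests that overall accuracy may hide poor performance on the rarest class, so the per-class rows of the classification report deserve a closer look.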