In [7]:
"""
番茄疾病分类训练脚本 - 精简版
核心关注点：
1. 统一的验证函数，确保计算一致性
2. 只关注核心指标：Loss 和 Accuracy
3. 结构化、模块化的代码
"""

import os
import sys
import jittor as jt
from jittor import nn
import numpy as np
import time
from datetime import timedelta

# =========================================================================
# 设置项目路径
# =========================================================================
# Treat the parent of the current working directory as the project root and
# make it the highest-priority import location.
current_dir = os.path.abspath('.')
project_root = os.path.dirname(current_dir)

# Remove one existing entry (if there is one) before re-inserting at the front.
try:
    sys.path.remove(project_root)
except ValueError:
    pass

sys.path.insert(0, project_root)

from data_loader import get_dataloader, CLASS_NAMES
from models.vit_model import Visual_Transformer
from config import Config

# =========================================================================
# 工具函数
# =========================================================================
def set_seed(seed=42):
    """Seed NumPy's global RNG and Jittor's global seed for reproducibility.

    Args:
        seed: Integer seed passed to both libraries (default 42).
    """
    np.random.seed(seed)
    # NOTE(review): jt.set_global_seed may itself reseed NumPy — confirm
    # against the Jittor docs before reordering or removing either call.
    jt.set_global_seed(seed)

def format_time(seconds):
    """Render a duration in seconds as an hour/minute/second display string.

    Fractions of a second are truncated. The hour part is shown only when it
    is positive; the minute part is shown when hours are shown or when the
    minute count itself is positive.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        The formatted string, e.g. ``"1分钟27秒"``.
    """
    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    leftover_secs = int(seconds % 60)

    # Guard clauses, shortest representation first.
    if whole_hours <= 0 and whole_minutes <= 0:
        return f"{leftover_secs}秒"
    if whole_hours <= 0:
        return f"{whole_minutes}分钟{leftover_secs}秒"
    return f"{whole_hours}小时{whole_minutes}分钟{leftover_secs}秒"

# =========================================================================
# 核心验证函数（精简版）
# =========================================================================
def validate_batch_unified(images, labels, model, config):
    """Run one batch through ``model`` and compute its loss and accuracy.

    Shared by the training and validation loops so both compute their
    metrics the same way.

    Args:
        images: Batch of inputs passed to ``model``.
        labels: Batch of class labels; entries outside
            ``[0, config.NUM_CLASSES)`` are dropped together with their images.
        model: Callable producing per-class scores for ``images``.
        config: Object exposing ``NUM_CLASSES``.

    Returns:
        ``(loss, acc, skipped)``. When the batch cannot be used (no valid
        labels, NaN in the model output, or NaN loss) the result is
        ``(None, None, True)``. Otherwise ``loss`` is the cross-entropy loss,
        ``acc`` is the fraction of correct predictions as a Python float and
        ``skipped`` is ``False``.
    """
    skip_batch = (None, None, True)

    # Step 1: keep only the samples whose label is a legal class index.
    host_labels = labels.numpy()
    in_range = (host_labels >= 0) & (host_labels < config.NUM_CLASSES)
    keep = jt.array(in_range)
    kept = int(jt.sum(keep))

    if kept == 0:
        return skip_batch

    if kept < len(images):
        images, labels = images[keep], labels[keep]

    # Step 2: forward pass; a NaN anywhere in the output invalidates the batch.
    logits = model(images)
    if jt.isnan(logits).any():
        return skip_batch

    # Step 3: loss, with the same NaN guard.
    loss = nn.cross_entropy_loss(logits, labels)
    if jt.isnan(loss).any():
        return skip_batch

    # Step 4: accuracy over the retained samples.
    # NOTE(review): element [0] of jt.argmax's result is used as the predicted
    # class ids — presumably it returns (indices, values); confirm.
    predicted = jt.argmax(logits, dim=1)[0]
    hits = jt.sum(predicted == labels)
    batch_acc = float(hits) / labels.shape[0]

    return loss, batch_acc, False

# =========================================================================
# 训练和验证函数 (精简版)
# =========================================================================
def train_one_epoch(model, train_loader, optimizer, config, epoch, num_epochs):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is evaluated with ``validate_batch_unified``; unusable batches
    are counted and skipped, the rest are used for an optimizer step. Running
    averages are printed every 100 batches and on the last expected batch.

    Args:
        model: Model to train; switched to training mode first.
        train_loader: Iterable of ``(images, labels)`` batches. ``len()`` of it
            is divided by ``config.BATCH_SIZE`` to get the expected batch count.
        optimizer: Optimizer whose ``step(loss)`` performs the update.
        config: Object exposing ``BATCH_SIZE`` (and whatever
            ``validate_batch_unified`` reads).
        epoch: Unused here; kept for interface compatibility.
        num_epochs: Unused here; kept for interface compatibility.

    Returns:
        ``(avg_loss, avg_acc, used_batches, skipped_batches, elapsed_seconds)``
        where the averages are unweighted means over the used batches
        (``0.0`` if no batch was used).
    """
    model.train()

    loss_sum = 0.0
    acc_sum = 0.0
    used = 0
    dropped = 0

    # Ceiling division: expected number of batches for the progress display.
    num_batches = (len(train_loader) + config.BATCH_SIZE - 1) // config.BATCH_SIZE
    started = time.time()

    for step, (images, labels) in enumerate(train_loader, start=1):
        loss, acc, skipped = validate_batch_unified(images, labels, model, config)

        if skipped:
            dropped += 1
            continue

        optimizer.step(loss)

        loss_sum += float(loss)
        acc_sum += acc
        used += 1

        if step % 100 == 0 or step == num_batches:
            running_loss = loss_sum / used
            running_acc = acc_sum / used
            print(f"  Batch [{step:4d}/{num_batches}] | "
                  f"Loss: {running_loss:.4f} | Acc: {running_acc*100:.2f}%")

    elapsed = time.time() - started

    if used > 0:
        avg_loss = loss_sum / used
        avg_acc = acc_sum / used
    else:
        avg_loss = 0.0
        avg_acc = 0.0

    return avg_loss, avg_acc, used, dropped, elapsed


def validate(model, valid_loader, config):
    """Evaluate ``model`` on every batch of ``valid_loader`` without gradients.

    Args:
        model: Model to evaluate; switched to eval mode first.
        valid_loader: Iterable of ``(images, labels)`` batches.
        config: Configuration object forwarded to ``validate_batch_unified``.

    Returns:
        ``(avg_loss, avg_acc, used_batches, skipped_batches, elapsed_seconds)``
        where the averages are unweighted means over the used batches
        (``0.0`` if no batch was used).
    """
    model.eval()

    loss_sum = 0.0
    acc_sum = 0.0
    used = 0
    dropped = 0

    started = time.time()

    with jt.no_grad():
        for images, labels in valid_loader:
            loss, acc, skipped = validate_batch_unified(images, labels, model, config)

            if not skipped:
                loss_sum += float(loss)
                acc_sum += acc
                used += 1
            else:
                dropped += 1

    elapsed = time.time() - started

    if used > 0:
        avg_loss = loss_sum / used
        avg_acc = acc_sum / used
    else:
        avg_loss = 0.0
        avg_acc = 0.0

    return avg_loss, avg_acc, used, dropped, elapsed


def print_training_summary(epoch, num_epochs, train_stats, val_stats):
    """Print the per-epoch report for the training and validation phases.

    Args:
        epoch: Zero-based epoch index (displayed as ``epoch + 1``).
        num_epochs: Total number of epochs, for the header.
        train_stats: 5-tuple ``(loss, acc, batches, skipped, seconds)`` from
            the training phase.
        val_stats: 5-tuple with the same layout from the validation phase.

    Returns:
        The validation accuracy taken from ``val_stats``, so the caller can
        compare it against the best value so far.
    """
    # Unpack both tuples up front so a malformed argument fails before
    # anything is printed.
    train_loss, train_acc, train_batches, train_skipped, train_time = train_stats
    val_loss, val_acc, val_batches, val_skipped, val_time = val_stats

    def report(title, loss, acc, batches, skipped, seconds):
        # One phase section; the skipped line appears only when non-zero.
        print(f"\n【{title}】 耗时: {format_time(seconds)}")
        print(f"  有效batch: {batches}")
        if skipped > 0:
            print(f"  跳过batch: {skipped}")
        print(f"  平均Loss:  {loss:.4f}")
        print(f"  平均Acc:   {acc*100:.2f}%")

    rule = '=' * 70
    print(f"\n{rule}")
    print(f"Epoch [{epoch+1}/{num_epochs}] 总结")
    print(f"{rule}")

    report("训练阶段", train_loss, train_acc, train_batches, train_skipped, train_time)
    report("验证阶段", val_loss, val_acc, val_batches, val_skipped, val_time)

    return val_acc


# =========================================================================
# 主训练函数
# =========================================================================
def train():
    """Run the full training pipeline and return the per-epoch metric history.

    Steps: build the configuration, select CUDA when available, seed the RNGs,
    create the train/validation loaders and the model, then for each epoch
    train, validate, print a summary, and save weights. The best model (by
    strictly greater average validation accuracy) is written to
    ``<project_root>/checkpoints/best_model.pkl``; an additional checkpoint is
    written every 5 epochs.

    Returns:
        dict: ``{'train_loss': [...], 'train_acc': [...], 'val_loss': [...],
        'val_acc': [...]}`` with one entry per completed epoch. Previously this
        history was collected but discarded; callers that ignore the return
        value are unaffected.
    """
    # Initialisation
    config = Config()
    jt.flags.use_cuda = 1 if jt.has_cuda else 0
    set_seed(42)

    # Training configuration
    EPOCHS = config.EPOCHS
    BATCH_SIZE = config.BATCH_SIZE
    LEARNING_RATE = config.LEARNING_RATE
    TRAIN_SAMPLE_RATIO = 1.0
    VALID_SAMPLE_RATIO = 1.0
    IMG_SIZE = config.IMG_SIZE

    print("\n" + "="*70)
    print("番茄疾病分类训练 (精简版)")
    print("="*70)
    print(f"\n【训练配置】")
    print(f"  Epochs:        {EPOCHS}")
    print(f"  Batch Size:    {BATCH_SIZE}")
    print(f"  Learning Rate: {LEARNING_RATE}")
    # ... other configuration prints elided ...

    # Data loading
    print(f"\n【数据加载】")
    data_root = os.path.join(project_root, 'tomato_yolo_dataset')
    save_dir = os.path.join(project_root, 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)

    train_loader = get_dataloader(data_root, 'train', BATCH_SIZE, IMG_SIZE, True, 0, TRAIN_SAMPLE_RATIO)
    valid_loader = get_dataloader(data_root, 'val', BATCH_SIZE, IMG_SIZE, False, 0, VALID_SAMPLE_RATIO)

    # Ceiling division of len(loader) by the batch size, for display only.
    num_train_batches = (len(train_loader) + BATCH_SIZE - 1) // BATCH_SIZE
    num_valid_batches = (len(valid_loader) + BATCH_SIZE - 1) // BATCH_SIZE

    print(f"  训练集: {len(train_loader):5d} 样本, {num_train_batches:4d} batches")
    print(f"  验证集: {len(valid_loader):5d} 样本, {num_valid_batches:4d} batches")

    # Model construction
    print(f"\n【模型构建】")
    model = Visual_Transformer(
        img_size=config.IMG_SIZE,
        patch_size=config.PATCH_SIZE,
        # ... other model parameters elided ...
        hidden_dim=config.MLP_Hidden_Dim
    )
    # ... model parameter printing elided ...

    # Optimizer
    optimizer = nn.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # Bookkeeping for the best model and the metric history
    best_val_acc = 0.0
    best_epoch = 0
    training_history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    total_start_time = time.time()

    print(f"\n{'='*70}\n开始训练\n{'='*70}")

    # Training loop
    for epoch in range(EPOCHS):
        epoch_start_time = time.time()

        print(f"\n--- Epoch [{epoch+1}/{EPOCHS}] ---")

        # Train
        print(f"[训练中...]")
        train_stats = train_one_epoch(model, train_loader, optimizer, config, epoch, EPOCHS)

        # Validate
        print(f"[验证中...]")
        val_stats = validate(model, valid_loader, config)

        # Print the summary; it returns the average validation accuracy
        current_val_acc = print_training_summary(epoch, EPOCHS, train_stats, val_stats)

        # Record history (index 0 is the average loss, index 1 the average accuracy)
        training_history['train_loss'].append(train_stats[0])
        training_history['train_acc'].append(train_stats[1])
        training_history['val_loss'].append(val_stats[0])
        training_history['val_acc'].append(val_stats[1])

        # Save the best model so far
        if current_val_acc > best_val_acc:
            best_val_acc = current_val_acc
            best_epoch = epoch + 1
            model_path = os.path.join(save_dir, 'best_model.pkl')
            jt.save(model.state_dict(), model_path)
            print(f"\n  ✅ 保存最佳模型! (Epoch {best_epoch}, 平均验证准确率: {best_val_acc*100:.2f}%)")

        # Periodic checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pkl')
            jt.save(model.state_dict(), checkpoint_path)
            print(f"  ✅ 保存checkpoint: epoch_{epoch+1}")

        epoch_time = time.time() - epoch_start_time
        print(f"\n  本轮耗时: {format_time(epoch_time)}")

    # Training finished
    total_time = time.time() - total_start_time

    print(f"\n{'='*70}\n训练完成!\n{'='*70}")
    print(f"\n【训练总结】")
    print(f"  总耗时:         {format_time(total_time)}")
    print(f"  最佳Epoch:      {best_epoch}/{EPOCHS}")
    print(f"  最佳平均验证准确率: {best_val_acc*100:.2f}%")
    print(f"\n{'='*70}\n")

    # Hand the collected metrics back instead of discarding them.
    return training_history


if __name__ == '__main__':
    # Script entry point: run training, report a user interrupt briefly, and
    # print a message plus full traceback for any other failure.
    import traceback

    try:
        train()
    except KeyboardInterrupt:
        print("\n\n训练被用户中断")
    except Exception as exc:
        print(f"\n\n训练过程中发生错误: {exc}")
        traceback.print_exc()


番茄疾病分类训练 (精简版)

【训练配置】
  Epochs:        15
  Batch Size:    16
  Learning Rate: 0.001

【数据加载】
  训练集: 14526 样本,  908 batches
  验证集:  3632 样本,  227 batches

【模型构建】

开始训练

--- Epoch [1/15] ---
[训练中...]
  Batch [ 100/908] | Loss: 2.0155 | Acc: 30.44%
  Batch [ 200/908] | Loss: 1.9284 | Acc: 33.16%
  Batch [ 300/908] | Loss: 1.8987 | Acc: 34.10%
  Batch [ 400/908] | Loss: 1.8432 | Acc: 35.92%
  Batch [ 500/908] | Loss: 1.7868 | Acc: 37.74%
  Batch [ 600/908] | Loss: 1.7570 | Acc: 39.02%
  Batch [ 700/908] | Loss: 1.7233 | Acc: 40.13%
  Batch [ 800/908] | Loss: 1.6869 | Acc: 41.37%
  Batch [ 900/908] | Loss: 1.6518 | Acc: 42.67%
  Batch [ 908/908] | Loss: 1.6492 | Acc: 42.77%
[验证中...]

Epoch [1/15] 总结

【训练阶段】 耗时: 1分钟27秒
  有效batch: 908
  平均Loss:  1.6492
  平均Acc:   42.77%

【验证阶段】 耗时: 11秒
  有效batch: 227
  平均Loss:  1.2681
  平均Acc:   56.50%

  ✅ 保存最佳模型! (Epoch 1, 平均验证准确率: 56.50%)

  本轮耗时: 1分钟38秒

--- Epoch [2/15] ---
[训练中...]
  Batch [ 100/908] | Loss: 1.3773 | Acc: 52.62%
  Batch [ 200/908] | L