In [1]:
import os
import sys
import jittor as jt
from jittor import nn
import numpy as np
import time
from datetime import timedelta

# The parent of the current working directory is treated as the project root.
current_dir = os.path.abspath(os.curdir)
project_root = os.path.split(current_dir)[0]

# Make the project root the first sys.path entry: drop one existing
# occurrence (if any) and re-insert it at index 0, so the project-local
# imports that follow are looked up there first.
try:
    sys.path.remove(project_root)
except ValueError:
    # Not on sys.path yet -- nothing to remove.
    pass

sys.path.insert(0, project_root)

from data_loader import get_dataloader, CLASS_NAMES
from models.vit_model import Visual_Transformer
from config import Config


def set_seed(seed=42):
    """Seed NumPy and Jittor with the same value.

    Args:
        seed: Value handed to ``np.random.seed`` and then to
            ``jt.set_global_seed``.
    """
    # Order matters: NumPy is seeded first, Jittor's global seed second.
    for apply_seed in (np.random.seed, jt.set_global_seed):
        apply_seed(seed)


def calculate_accuracy(outputs, labels):
    """Return the share of samples whose arg-max prediction equals the label.

    Args:
        outputs: Model scores; the arg-max is taken along ``dim=1``.
        labels: Ground-truth labels; ``labels.shape[0]`` is the divisor.

    Returns:
        Number of matching predictions divided by ``labels.shape[0]``.
    """
    # Element 0 of jt.argmax's result is used as the predicted class indices
    # (per the Jittor docs it returns the indices together with the values).
    predicted = jt.argmax(outputs, dim=1)[0]
    matches = predicted == labels
    num_samples = labels.shape[0]
    return float(jt.sum(matches)) / num_samples


def validate_labels(labels, num_classes):
    """Build a boolean mask marking labels that lie in ``[0, num_classes)``.

    Args:
        labels: Object exposing ``.numpy()``; its values are the labels.
        num_classes: Exclusive upper bound for a valid label.

    Returns:
        ``jt.array`` wrapping the element-wise boolean mask
        (True where ``0 <= label < num_classes``).
    """
    host_labels = labels.numpy()
    non_negative = host_labels >= 0
    below_limit = host_labels < num_classes
    return jt.array(np.logical_and(non_negative, below_limit))


def format_time(seconds):
    """Format a duration given in seconds using ``timedelta``'s string form.

    Args:
        seconds: Duration in seconds; truncated to an integer via ``int()``.

    Returns:
        ``str(timedelta(...))`` of the truncated duration, e.g. ``"1:01:01"``.
    """
    whole_seconds = int(seconds)
    duration = timedelta(seconds=whole_seconds)
    return str(duration)


def print_progress_bar(current, total, prefix='', suffix='', length=40):
    """Redraw a one-line text progress bar on stdout.

    The line starts with a carriage return and is printed without a trailing
    newline, so consecutive calls overwrite each other in a terminal.

    Args:
        current: Number of completed steps. Values outside ``[0, total]`` are
            clamped so the bar never exceeds ``length`` characters and the
            percentage stays within 0-100.
        total: Total number of steps. Nothing is printed when it is not
            positive.
        prefix: Text placed before the bar.
        suffix: Text placed after the percentage.
        length: Width of the bar in characters.
    """
    if total <= 0:
        return
    # Clamp progress: an overshooting counter (e.g. more batches than the
    # previously recorded total) would otherwise widen the bar past `length`.
    done = min(max(current, 0), total)
    percent = 100 * (done / float(total))
    filled = int(length * done // total)
    bar = '█' * filled + '░' * (length - filled)
    print(f'\r{prefix} |{bar}| {percent:.1f}% {suffix}', end='', flush=True)


def train():
    """Train a Visual_Transformer and checkpoint the best-performing weights.

    Steps, in order:
      1. Build train/val loaders over ``<project_root>/tomato_yolo_dataset``.
      2. Scan the first 10 training batches and report out-of-range labels.
      3. Build the model and run one smoke-test forward pass.
      4. For each epoch: train (samples with invalid labels are filtered out,
         batches producing NaN outputs/loss are skipped), then validate and
         compute sample-level and balanced (mean per-class) accuracy.
      5. Save ``checkpoints/best_model.pkl`` whenever balanced accuracy
         improves.

    All progress is reported via ``print``; the function returns nothing.
    """
    config = Config()
    # Run on GPU when Jittor reports CUDA support, otherwise on CPU.
    jt.flags.use_cuda = 1 if jt.has_cuda else 0
    set_seed(42)

    # ========== Configuration ==========
    EPOCHS = 15
    BATCH_SIZE = 8
    NUM_WORKERS = 0
    TRAIN_SAMPLE_RATIO = 0.2
    VAL_SAMPLE_RATIO = 0.2
    PRINT_FREQ = 50

    data_root = os.path.join(project_root, 'tomato_yolo_dataset')
    save_dir = os.path.join(project_root, 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"🚀 训练配置")
    print(f"{'='*70}")
    print(f"设备: {'🎮 GPU' if jt.flags.use_cuda else '💻 CPU'}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"训练轮数: {EPOCHS}")
    print(f"训练采样率: {TRAIN_SAMPLE_RATIO*100:.0f}%")
    print(f"验证采样率: {VAL_SAMPLE_RATIO*100:.0f}%")
    print(f"{'='*70}\n")

    # ========== Data loading ==========
    print("📦 加载数据...", flush=True)

    train_loader = get_dataloader(
        root_dir=data_root,
        mode='train',
        batch_size=BATCH_SIZE,
        img_size=config.IMG_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        sample_ratio=TRAIN_SAMPLE_RATIO
    )

    val_loader = get_dataloader(
        root_dir=data_root,
        mode='val',
        batch_size=BATCH_SIZE,
        img_size=config.IMG_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        sample_ratio=VAL_SAMPLE_RATIO
    )

    # Batch counts are unknown up front; they are filled in with the number of
    # batches actually iterated once the first epoch's pass has finished.
    total_train_batches = None
    total_val_batches = None

    print(f"✓ 数据加载完成\n")

    # ========== Quick label sanity check ==========
    # Only the first 10 training batches are inspected; this is a spot check,
    # not a full scan of the dataset.
    print("🔍 检查数据集标签（前10个batch）...", flush=True)
    check_count = 0
    invalid_count = 0
    check_batches = 0
    for idx, (_, labels) in enumerate(train_loader):
        if idx >= 10:
            break
        labels_np = labels.numpy()
        check_count += len(labels_np)
        invalid_count += ((labels_np < 0) | (labels_np >= config.NUM_CLASSES)).sum()
        check_batches += 1

    if invalid_count > 0:
        print(f"  ⚠️  发现 {invalid_count}/{check_count} 个异常标签 ({invalid_count/check_count*100:.1f}%)")
        print(f"  训练时会自动过滤\n")
    else:
        print(f"  ✓ 标签正常 (检查了{check_batches}个batch)\n")

    # ========== Build the model ==========
    print("🏗️  创建模型...", flush=True)

    # NOTE(review): config.NUM_CLASSES is not passed to the constructor here;
    # presumably the model's own default output size matches it -- confirm.
    model = Visual_Transformer(
        img_size=config.IMG_SIZE,
        patch_size=config.PATCH_SIZE,
        in_channels=config.IN_CHANNELS,
        embed_dim=config.EMBED_DIM,
        depth=config.NUM_LAYERS,
        num_heads=config.NUM_HEADS,
        dropout_rate=0.1,
        hidden_dim=config.MLP_Hidden_Dim
    )

    # Smoke test: forward at most 2 validly-labelled images from the first
    # batch that contains any. If no batch qualifies, nothing is printed.
    print("  测试前向传播...", flush=True)
    model.eval()
    with jt.no_grad():
        for imgs, lbls in train_loader:
            valid_mask = validate_labels(lbls, config.NUM_CLASSES)
            valid_count = int(jt.sum(valid_mask))
            if valid_count > 0:
                valid_imgs = imgs[valid_mask]
                test_out = model(valid_imgs[:min(2, valid_count)])
                print(f"  ✓ 测试通过: {test_out.shape}\n")
                break

    # ========== Optimizer ==========
    optimizer = nn.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=1e-4)

    # ========== Training loop ==========
    best_acc = 0.0
    best_balanced_acc = 0.0
    total_start_time = time.time()

    print(f"{'='*70}")
    print(f"🎯 开始训练 ({time.strftime('%H:%M:%S')})")
    print(f"{'='*70}\n", flush=True)

    for epoch in range(EPOCHS):
        epoch_start_time = time.time()

        print(f"\n{'='*70}")
        print(f"📅 Epoch [{epoch+1}/{EPOCHS}]")
        print(f"{'='*70}")

        # ========== Training phase ==========
        print(f"\n🔥 训练中...", flush=True)
        model.train()

        # Running sums are per *batch* (each non-skipped batch contributes its
        # own mean loss/accuracy once), not per sample.
        train_loss_sum = 0.0
        train_acc_sum = 0.0
        train_batches = 0
        skipped = 0
        actual_batch_count = 0

        train_start = time.time()

        for batch_idx, (images, labels) in enumerate(train_loader):
            actual_batch_count = batch_idx + 1

            if batch_idx == 0:
                print(f"  ⏳ 第一个batch...", end='', flush=True)
                first_start = time.time()

            valid_mask = validate_labels(labels, config.NUM_CLASSES)
            valid_count = int(jt.sum(valid_mask))

            # Whole batch has out-of-range labels: nothing to train on.
            if valid_count == 0:
                skipped += 1
                continue

            # Keep only the samples whose labels are in range.
            if valid_count < len(labels):
                images = images[valid_mask]
                labels = labels[valid_mask]

            outputs = model(images)
            if jt.isnan(outputs).any():
                skipped += 1
                continue

            loss = nn.cross_entropy_loss(outputs, labels)
            if jt.isnan(loss).any():
                skipped += 1
                continue

            # The loss is handed straight to the optimizer; no separate
            # backward call appears in this loop.
            optimizer.step(loss)

            train_loss_sum += float(loss)
            train_acc_sum += calculate_accuracy(outputs, labels)
            train_batches += 1

            # Only reached when the first batch was not skipped above.
            if batch_idx == 0:
                first_time = time.time() - first_start
                print(f" ✓ ({first_time:.1f}秒)", flush=True)

            # In-place progress line every 10 batches (skipped batches never
            # reach this point because of the `continue`s above).
            if (batch_idx + 1) % 10 == 0:
                avg_loss = train_loss_sum / train_batches if train_batches > 0 else 0
                avg_acc = train_acc_sum / train_batches if train_batches > 0 else 0
                elapsed = time.time() - train_start
                speed = (batch_idx + 1) / elapsed

                if total_train_batches is not None:
                    eta = (total_train_batches - batch_idx - 1) / speed if speed > 0 else 0
                    suffix = f"Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.1f}% | {speed:.1f}it/s | ETA: {int(eta)}s"
                    print_progress_bar(batch_idx + 1, total_train_batches, prefix='  进度', suffix=suffix)
                else:
                    suffix = f"Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.1f}% | {speed:.1f}it/s"
                    print(f'\r  进度: {batch_idx+1} batches | {suffix}', end='', flush=True)

            # Persistent summary line (on its own row) every PRINT_FREQ batches.
            if (batch_idx + 1) % PRINT_FREQ == 0:
                avg_loss = train_loss_sum / train_batches if train_batches > 0 else 0
                avg_acc = train_acc_sum / train_batches if train_batches > 0 else 0
                elapsed = time.time() - train_start
                speed = (batch_idx + 1) / elapsed

                if total_train_batches is not None:
                    print(f"\n  📊 [{batch_idx+1}/{total_train_batches}] "
                          f"Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.2f}% | "
                          f"{speed:.1f} batch/s", flush=True)
                else:
                    print(f"\n  📊 [{batch_idx+1}] "
                          f"Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.2f}% | "
                          f"{speed:.1f} batch/s", flush=True)

        # First completed pass: remember the real batch count so later epochs
        # can draw a progress bar with an ETA.
        if total_train_batches is None:
            total_train_batches = actual_batch_count
            print(f"\n\n  ℹ️  检测到训练集实际batch数: {total_train_batches}")

        print()

        train_time = time.time() - train_start
        avg_train_loss = train_loss_sum / train_batches if train_batches > 0 else 0
        avg_train_acc = train_acc_sum / train_batches if train_batches > 0 else 0

        print(f"\n  ✅ 训练完成 ({train_time:.1f}秒)")
        print(f"     遍历: {actual_batch_count} batches")
        print(f"     有效: {train_batches} batches")
        print(f"     Loss: {avg_train_loss:.4f} | Acc: {avg_train_acc*100:.2f}%")
        if skipped > 0:
            print(f"     跳过: {skipped} 个异常batch")
        print(flush=True)

        # ========== Validation phase ==========
        print(f"🔍 验证中...", flush=True)
        model.eval()

        val_loss_sum = 0.0
        val_acc_sum = 0.0
        val_batches = 0
        # Per-class counters, indexed by label, for sample-level and balanced
        # accuracy.
        class_correct = np.zeros(config.NUM_CLASSES, dtype=np.int64)
        class_total = np.zeros(config.NUM_CLASSES, dtype=np.int64)
        val_skipped = 0
        actual_val_batch_count = 0

        val_start = time.time()

        with jt.no_grad():
            for batch_idx, (images, labels) in enumerate(val_loader):
                actual_val_batch_count = batch_idx + 1

                # Same filtering as in training: drop out-of-range labels.
                valid_mask = validate_labels(labels, config.NUM_CLASSES)
                valid_count = int(jt.sum(valid_mask))

                if valid_count == 0:
                    val_skipped += 1
                    continue

                if valid_count < len(labels):
                    images = images[valid_mask]
                    labels = labels[valid_mask]

                outputs = model(images)
                if jt.isnan(outputs).any():
                    val_skipped += 1
                    continue

                loss = nn.cross_entropy_loss(outputs, labels)
                val_loss_sum += float(loss)
                val_acc_sum += calculate_accuracy(outputs, labels)
                val_batches += 1

                # Tally per-class totals and correct predictions.
                preds = jt.argmax(outputs, dim=1)[0].numpy()
                labels_np = labels.numpy()
                for pred, label in zip(preds, labels_np):
                    label = int(label)
                    pred = int(pred)
                    if 0 <= label < config.NUM_CLASSES:
                        class_total[label] += 1
                        if pred == label:
                            class_correct[label] += 1

                # In-place progress line every 20 validation batches.
                if (batch_idx + 1) % 20 == 0:
                    avg_loss = val_loss_sum / val_batches if val_batches > 0 else 0
                    avg_acc = val_acc_sum / val_batches if val_batches > 0 else 0

                    if total_val_batches is not None:
                        suffix = f"Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.1f}%"
                        print_progress_bar(batch_idx + 1, total_val_batches, prefix='  进度', suffix=suffix)
                    else:
                        suffix = f"Loss: {avg_loss:.4f} | Acc: {avg_acc*100:.1f}%"
                        print(f'\r  进度: {batch_idx+1} batches | {suffix}', end='', flush=True)

        # First completed pass: remember the real validation batch count.
        if total_val_batches is None:
            total_val_batches = actual_val_batch_count
            print(f"\n\n  ℹ️  检测到验证集实际batch数: {total_val_batches}")

        print()

        val_time = time.time() - val_start

        # Metrics, reporting and checkpointing only happen when at least one
        # validation batch was usable.
        if val_batches > 0:
            avg_val_loss = val_loss_sum / val_batches
            avg_val_acc = val_acc_sum / val_batches

            # Sample-level accuracy derived from the per-class tallies. Unlike
            # the per-batch mean above, it is not skewed by batches that
            # shrank after label filtering.
            total_correct = int(np.sum(class_correct))
            total_samples = int(np.sum(class_total))
            real_acc = total_correct / total_samples if total_samples > 0 else 0.0

            # Balanced accuracy: unweighted mean of per-class accuracy over
            # the classes that have at least one validation sample.
            class_accs = []
            for i in range(config.NUM_CLASSES):
                if class_total[i] > 0:
                    acc = class_correct[i] / class_total[i]
                    class_accs.append(acc)
            balanced_acc = np.mean(class_accs) if class_accs else 0.0

            print(f"\n  ✅ 验证完成 ({val_time:.1f}秒)")
            print(f"     遍历: {actual_val_batch_count} batches")
            print(f"     有效: {val_batches} batches")
            print(f"     样本: {total_samples} 个 (正确: {total_correct})")
            print(f"     Loss: {avg_val_loss:.4f}")
            # Report the sample-level accuracy, and the per-batch mean too
            # when the two disagree.
            print(f"     准确率: {real_acc*100:.2f}% (类别统计)")
            if abs(real_acc - avg_val_acc) > 0.01:  # warn when they differ by more than 1 percentage point
                print(f"     ⚠️  Batch平均: {avg_val_acc*100:.2f}% (可能不准确)")
            print(f"     平衡准确率: {balanced_acc*100:.2f}%")
            if val_skipped > 0:
                print(f"     跳过: {val_skipped} 个异常batch")
            print(flush=True)

            # Per-class breakdown every 5th epoch and on the final epoch.
            if (epoch + 1) % 5 == 0 or epoch == EPOCHS - 1:
                print(f"\n  📊 各类别准确率:")
                for i in range(config.NUM_CLASSES):
                    # Classes without an entry in CLASS_NAMES are not listed.
                    if i < len(CLASS_NAMES):
                        if class_total[i] > 0:
                            acc = class_correct[i] / class_total[i] * 100
                            correct_count = int(class_correct[i])
                            total_count = int(class_total[i])
                            print(f"     {CLASS_NAMES[i]:<28}: {acc:>5.1f}% ({correct_count}/{total_count})")
                        else:
                            print(f"     {CLASS_NAMES[i]:<28}: {'N/A':>5} (0/0)")

                # Show the raw counts behind the sample-level accuracy.
                print(f"\n  ℹ️  准确率验证:")
                print(f"     各类正确数总和: {total_correct}")
                print(f"     各类样本数总和: {total_samples}")
                print(f"     手动计算准确率: {total_correct}/{total_samples} = {real_acc*100:.2f}%")

            # Model selection is driven by *balanced* accuracy; best_acc just
            # records the sample-level accuracy of that same epoch.
            if balanced_acc > best_balanced_acc:
                improvement = (balanced_acc - best_balanced_acc) * 100
                best_balanced_acc = balanced_acc
                best_acc = real_acc  # sample-level accuracy at the best-balanced epoch
                model_path = os.path.join(save_dir, 'best_model.pkl')
                jt.save(model.state_dict(), model_path)
                print(f"\n  🎉 新最佳! 平衡准确率: {balanced_acc*100:.2f}% (↑{improvement:.2f}%)", flush=True)

        # Epoch summary: remaining time is extrapolated from the mean epoch
        # duration so far.
        epoch_time = time.time() - epoch_start_time
        total_elapsed = time.time() - total_start_time
        avg_epoch_time = total_elapsed / (epoch + 1)
        remaining = avg_epoch_time * (EPOCHS - epoch - 1)

        print(f"\n  ⏱️  本轮: {epoch_time/60:.1f}min | 已用: {total_elapsed/60:.1f}min | 剩余: {remaining/60:.1f}min")
        print(f"  📅 预计完成: {time.strftime('%H:%M:%S', time.localtime(time.time() + remaining))}", flush=True)

    # Done: final report.
    total_time = time.time() - total_start_time

    print(f"\n{'='*70}")
    print(f"🎊 训练完成! 总耗时: {total_time/60:.1f}分钟")
    print(f"{'='*70}")
    print(f"最佳准确率: {best_acc*100:.2f}%")
    print(f"最佳平衡准确率: {best_balanced_acc*100:.2f}%")
    print(f"模型保存: {os.path.join(save_dir, 'best_model.pkl')}")
    print(f"{'='*70}\n", flush=True)


# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    train()

[38;5;2m[i 1025 10:53:29.945191 28 log.cc:351] Load log_sync: 1[m
[38;5;2m[i 1025 10:53:29.983234 28 compiler.py:956] Jittor(1.3.10.0) src: /home/jittor/SCC_Model/ViT/.venv/lib/python3.10/site-packages/jittor[m
[38;5;2m[i 1025 10:53:29.987114 28 compiler.py:957] g++ at /usr/bin/g++(11.4.0)[m
[38;5;2m[i 1025 10:53:29.988391 28 compiler.py:958] cache_path: /home/jittor/.cache/jittor/jt1.3.10/g++11.4.0/py3.10.12/Linux-6.6.87.2x4a/AMDRyzen97940Hxd7/fa38/main[m
[38;5;2m[i 1025 10:53:30.071170 28 install_cuda.py:96] cuda_driver_version: [12, 9][m
[38;5;2m[i 1025 10:53:30.076601 28 __init__.py:412] Found /home/jittor/.cache/jittor/jtcuda/cuda12.2_cudnn8_linux/bin/nvcc(12.2.140) at /home/jittor/.cache/jittor/jtcuda/cuda12.2_cudnn8_linux/bin/nvcc.[m
[38;5;2m[i 1025 10:53:30.124254 28 __init__.py:412] Found addr2line(2.38) at /usr/bin/addr2line.[m
[38;5;2m[i 1025 10:53:30.198392 28 compiler.py:1013] cuda key:cu12.2.140[m
[38;5;2m[i 1025 10:53:30.633069 28 __init__.py:227] Total 


🚀 训练配置
设备: 🎮 GPU
Batch Size: 8
训练轮数: 15
训练采样率: 20%
验证采样率: 20%

📦 加载数据...
train 数据集: 从 14527 个样本中采样了 2905 个样本 (20.0%)
val 数据集: 从 3632 个样本中采样了 726 个样本 (20.0%)
✓ 数据加载完成

🔍 检查数据集标签（前10个batch）...
  ⚠️  发现 8/80 个异常标签 (10.0%)
  训练时会自动过滤

🏗️  创建模型...
  测试前向传播...
  ✓ 测试通过: [2,10,]

🎯 开始训练 (10:53:31)


📅 Epoch [1/15]

🔥 训练中...
  ⏳ 第一个batch... ✓ (0.4秒)
  进度: 50 batches | Loss: 2.1742 | Acc: 30.8% | 14.6it/s
  📊 [50] Loss: 2.1742 | Acc: 30.75% | 14.6 batch/s
  进度: 100 batches | Loss: 2.0259 | Acc: 34.4% | 16.1it/s
  📊 [100] Loss: 2.0259 | Acc: 34.38% | 16.1 batch/s
  进度: 150 batches | Loss: 1.9632 | Acc: 34.7% | 16.7it/s
  📊 [150] Loss: 1.9632 | Acc: 34.67% | 16.7 batch/s
  进度: 200 batches | Loss: 1.9217 | Acc: 34.9% | 17.0it/s
  📊 [200] Loss: 1.9217 | Acc: 34.94% | 17.0 batch/s
  进度: 250 batches | Loss: 1.9163 | Acc: 34.8% | 17.3it/s
  📊 [250] Loss: 1.9163 | Acc: 34.80% | 17.3 batch/s
  进度: 300 batches | Loss: 1.9106 | Acc: 34.5% | 17.5it/s
  📊 [300] Loss: 1.9106 | Acc: 34.46% | 17.5 batch/s
  进