# TomatoMAP-Cls Trainer

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
import argparse
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from PIL import ImageDraw, ImageFont
import torchvision.transforms as transforms

# Environment check: report the PyTorch build and, when a GPU is visible,
# its name and total memory.
print("check env:")
print(f"  PyTorch version: {torch.__version__}")
# is_available() returns a boolean, so the label says "available", not "version".
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  GPU device: {torch.cuda.get_device_name(0)}")
    print(f"  GPU ram: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

print("env check done")

check env:
  PyTorch version: 2.3.0
  CUDA version: True
  GPU device: NVIDIA GeForce RTX 4060 Laptop GPU
  GPU ram: 8.0 GB
env check done


In [2]:
def save_model(model, path):
    """Serialize ``model``'s state dict to ``path`` and print the destination."""
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"model saved at: {path}")

def load_model(model, path, device='cpu'):
    """Load a saved state dict from ``path`` into ``model`` and switch it to eval mode.

    ``device`` is forwarded to ``torch.load`` as ``map_location``. The model is
    modified in place; nothing is returned.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    print(f"model loaded from: {path}")

def save_checkpoint(model, optimizer, epoch, path):
    """Write a resumable checkpoint (epoch, model state, optimizer state) to ``path``."""
    snapshot = {}
    snapshot['epoch'] = epoch
    snapshot['model_state_dict'] = model.state_dict()
    snapshot['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(snapshot, path)

def load_checkpoint(path, model, optimizer, device):
    """Restore model and optimizer state from a checkpoint file.

    Returns the epoch index at which training should resume, i.e. the stored
    epoch plus one.
    """
    state = torch.load(path, map_location=device)
    restore_plan = ((model, 'model_state_dict'), (optimizer, 'optimizer_state_dict'))
    for target, key in restore_plan:
        target.load_state_dict(state[key])
    resume_from = state['epoch'] + 1
    print(f"from epoch {resume_from} re-training")
    return resume_from

def get_font(size=30, bold=False):
    """Return a TrueType font of the given size, falling back to PIL's default.

    Tries one well-known font location each for Windows, Linux (DejaVu) and
    macOS, in that order, and returns the first that loads.

    Args:
        size: Font size passed to ``ImageFont.truetype``.
        bold: Pick the bold variant of each candidate font when True.

    Returns:
        The first font that loads, or ``ImageFont.load_default()`` when none of
        the candidate files can be opened.
    """
    font_paths = [
        "C:/Windows/Fonts/arialbd.ttf" if bold else "C:/Windows/Fonts/arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" if bold else "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/System/Library/Fonts/Supplemental/Arial Bold.ttf" if bold else "/System/Library/Fonts/Supplemental/Arial.ttf",
    ]
    for path in font_paths:
        try:
            return ImageFont.truetype(path, size=size)
        except (OSError, ImportError):
            # OSError: font file missing/unreadable on this platform.
            # ImportError: Pillow built without FreeType support.
            # Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
            continue
    return ImageFont.load_default()

def denormalize(img_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization and clamp the result to [0, 1].

    ``mean`` and ``std`` must each hold three values; they are reshaped to
    (3, 1, 1) so they broadcast against ``img_tensor``.
    """
    device = img_tensor.device
    shift = torch.tensor(mean).view(3, 1, 1).to(device)
    scale = torch.tensor(std).view(3, 1, 1).to(device)
    restored = img_tensor * scale + shift
    return restored.clamp(0, 1)

# Progress marker emitted once the helper definitions in this cell have run.
print("functions defination done!")

functions defination done!


In [3]:
from torchvision import models
from torchvision.models import (
    MobileNet_V3_Large_Weights,
    MobileNet_V3_Small_Weights,
    MobileNet_V2_Weights,
    ResNet18_Weights,
)

def get_model(name, num_classes, pretrained=True):
    """Build a torchvision backbone with its head resized to ``num_classes``.

    Supported names: 'mobilenet_v3_large', 'mobilenet_v3_small',
    'mobilenet_v2' and 'resnet18'. With ``pretrained`` the DEFAULT torchvision
    weights are loaded; otherwise the network is built with ``weights=None``.

    Raises:
        ValueError: if ``name`` is not one of the supported architectures.
    """
    print(f"build model: {name}, class number: {num_classes}")

    supported = ('mobilenet_v3_large', 'mobilenet_v3_small', 'mobilenet_v2', 'resnet18')
    if name not in supported:
        raise ValueError(f"Model {name} not supported.")

    # name -> (constructor, weights enum)
    factories = {
        'mobilenet_v3_large': (models.mobilenet_v3_large, MobileNet_V3_Large_Weights),
        'mobilenet_v3_small': (models.mobilenet_v3_small, MobileNet_V3_Small_Weights),
        'mobilenet_v2': (models.mobilenet_v2, MobileNet_V2_Weights),
        'resnet18': (models.resnet18, ResNet18_Weights),
    }
    build, weight_enum = factories[name]
    model = build(weights=weight_enum.DEFAULT if pretrained else None)

    # Swap the final linear layer for one sized to the requested class count.
    if name == 'resnet18':
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        head_idx = 1 if name == 'mobilenet_v2' else 3
        model.classifier[head_idx] = nn.Linear(model.classifier[head_idx].in_features, num_classes)

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"parameter info: Total{total_params:,}, Trainable{trainable_params:,}")

    return model

# Progress marker emitted once the model factory in this cell has been defined.
print("model defination done!")

model defination done!


In [6]:
from torch.utils.data import DataLoader, Dataset
from PIL import Image

class BBCHDataset(Dataset):
    """Image-folder classification dataset.

    Expects the layout ``<data_dir>/<split>/<class_name>/<image file>``.
    Class indices follow the alphabetical order of the class directory names,
    and files inside each class are listed in sorted order so that sample
    indices are reproducible across runs and file systems.

    Attributes set by ``__init__``:
        data_dir: path of the split directory.
        classes: sorted list of class directory names.
        class_to_idx: mapping from class name to integer label.
        samples: list of ``(image_path, label)`` tuples.
    """

    def __init__(self, data_dir, split='train', transform=None):
        """Index every .png/.jpg/.jpeg file below ``<data_dir>/<split>``.

        Raises:
            FileNotFoundError: if the split directory does not exist.
        """
        self.data_dir = os.path.join(data_dir, split)
        self.transform = transform

        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Directory not found: {self.data_dir}")

        # Every sub-directory of the split is one class.
        self.classes = sorted(
            d for d in os.listdir(self.data_dir)
            if os.path.isdir(os.path.join(self.data_dir, d))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = []
        for class_name, class_idx in self.class_to_idx.items():
            class_dir = os.path.join(self.data_dir, class_name)
            # os.listdir returns entries in arbitrary order; sort for determinism.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_dir, img_name), class_idx))

        print(f"loading {split} dataset: {len(self.samples)} images, {len(self.classes)} classes")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        An unreadable image is reported and replaced by a black 224x224 RGB
        placeholder so one corrupt file does not abort a whole epoch.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the file handle as soon as the pixel
            # data has been converted.
            with Image.open(img_path) as handle:
                image = handle.convert('RGB')
        except Exception as e:
            print(f"failed to load image: {img_path}, error: {e}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform is not None:
            image = self.transform(image)

        return image, label

def get_transforms(target_size=(640, 640)):
    """Return ``(train_transform, val_transform)`` for the given image size.

    Both pipelines resize to ``target_size``, convert to a tensor and apply the
    same mean/std normalization. The training pipeline additionally applies a
    random horizontal flip, a random rotation and colour jitter between the
    resize and the tensor conversion.
    """
    def resize():
        return transforms.Resize(target_size)

    def to_normalized_tensor():
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]

    train_transform = transforms.Compose([resize(), *augmentations, *to_normalized_tensor()])
    val_transform = transforms.Compose([resize(), *to_normalized_tensor()])
    return train_transform, val_transform

def get_dataloaders(data_dir, batch_size=32, target_size=(640, 640), num_workers=4, include_test=False):
    """Create the train/val (and optionally test) data loaders.

    Only the training loader shuffles. When ``include_test`` is True and
    ``<data_dir>/test`` is missing, the validation loader is returned in the
    test slot as well.

    Returns:
        ``(train_loader, val_loader, test_loader)``; ``test_loader`` is None
        when ``include_test`` is False.
    """
    import platform

    print(f"building dataloader: {data_dir}")

    train_tf, eval_tf = get_transforms(target_size)
    train_set = BBCHDataset(data_dir, 'train', train_tf)
    val_set = BBCHDataset(data_dir, 'val', eval_tf)

    # Worker processes are disabled on Windows.
    workers = 0 if platform.system() == 'Windows' else num_workers

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=workers, pin_memory=torch.cuda.is_available()
        )

    train_loader = make_loader(train_set, True)
    val_loader = make_loader(val_set, False)

    if not include_test:
        return train_loader, val_loader, None

    if os.path.exists(os.path.join(data_dir, 'test')):
        test_loader = make_loader(BBCHDataset(data_dir, 'test', eval_tf), False)
    else:
        print("test set not found, using val as test")
        test_loader = val_loader

    return train_loader, val_loader, test_loader
# Progress marker emitted once the data-loading helpers have been defined.
print("dataloader setup!")


dataloader setup!


In [7]:
# Hyper-parameters and paths shared by the classification training cells.
CLASSIFICATION_CONFIG = {
    'data_dir': 'TomatoMAP-Cls',  # dataset root; train/ and val/ (optionally test/) live below it
    'model_name': 'mobilenet_v3_large',  # 'mobilenet_v3_large', 'mobilenet_v3_small', 'mobilenet_v2', 'resnet18'
    'num_classes': 50,
    'batch_size': 32,
    'num_epochs': 150,
    'learning_rate': 1e-4,
    'target_size': (640, 640),  # size passed to the Resize transform
    'patience': 15,  # epochs without improvement before early stopping
    'save_interval': 20  # write a periodic checkpoint every N epochs
}

# Echo the configuration so it is recorded in the notebook output.
print("config:")
for key, value in CLASSIFICATION_CONFIG.items():
    print(f"  {key}: {value}")

config:
  data_dir: TomatoMAP-Cls
  model_name: mobilenet_v3_large
  num_classes: 50
  batch_size: 32
  num_epochs: 150
  learning_rate: 0.0001
  target_size: (640, 640)
  patience: 15
  save_interval: 20


In [6]:
def plot_training_curve(log_path, model_name):
    """Plot loss/accuracy curves from a CSV training log.

    Reads the columns ``epoch``, ``train_loss``, ``val_loss``,
    ``train_accuracy`` and ``val_accuracy`` from ``log_path`` and draws a 2x2
    figure: losses, accuracies, training loss alone, and the train-minus-val
    accuracy gap. The figure is saved to
    ``checkpoints/<model_name>_detailed_curves.png`` and shown.

    Returns None. If ``log_path`` does not exist a message is printed and
    nothing is plotted.
    """
    if not os.path.exists(log_path):
        print("📭 没有找到训练日志")
        return

    df = pd.read_csv(log_path)

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    ax1.plot(df['epoch'], df['train_loss'], 'b-', label='训练损失', linewidth=2)
    ax1.plot(df['epoch'], df['val_loss'], 'r-', label='验证损失', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('训练/验证损失')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy curves
    ax2.plot(df['epoch'], df['train_accuracy'], 'g-', label='训练准确率', linewidth=2)
    ax2.plot(df['epoch'], df['val_accuracy'], 'orange', label='验证准确率', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('训练/验证准确率')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Training loss on its own axis (the log has no learning-rate column).
    ax3.plot(df['epoch'], df['train_loss'], 'b-', alpha=0.7)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Train Loss')
    ax3.set_title('训练损失详细')
    ax3.grid(True, alpha=0.3)

    # Overfitting indicator: gap between training and validation accuracy.
    overfitting = df['train_accuracy'] - df['val_accuracy']
    ax4.plot(df['epoch'], overfitting, 'purple', label='过拟合程度', linewidth=2)
    ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Train Acc - Val Acc')
    ax4.set_title('过拟合检测')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    # The log may live anywhere, so make sure the output directory exists
    # before saving instead of failing in savefig.
    os.makedirs("checkpoints", exist_ok=True)
    out_path = f"checkpoints/{model_name}_detailed_curves.png"
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"📈 训练曲线已保存: {out_path}")

def evaluate_classification(model, dataloader, criterion, device):
    """Evaluate a classifier over ``dataloader`` without tracking gradients.

    Puts the model in eval mode and returns ``(avg_loss, accuracy)``, where
    ``avg_loss`` is the mean of the per-batch loss values and ``accuracy`` is
    the fraction of correctly classified samples.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()

            guesses = logits.argmax(dim=1)
            seen += batch_labels.size(0)
            hits += (guesses == batch_labels).sum().item()

    return loss_sum / len(dataloader), hits / seen

def train_classification(config):
    """Train an image classifier with early stopping and CSV logging.

    Builds the data loaders and model from ``config``, trains with Adam and a
    cosine-annealed learning rate, appends per-epoch metrics to
    ``checkpoints/<model_name>_train_log.csv``, keeps the weights with the
    lowest validation loss, and stops early after ``config['patience']``
    epochs without improvement.

    Args:
        config: Dict with keys ``data_dir``, ``batch_size``, ``target_size``,
            ``model_name``, ``learning_rate``, ``num_epochs``, ``patience``
            and ``save_interval``. ``config['num_classes']`` is overwritten
            in place with the class count found in the training split.

    Returns:
        ``(model, train_loader, val_loader)`` on success, or
        ``(None, None, None)`` if any exception was raised (the traceback is
        printed instead of being propagated).
    """
    print("🚀 开始PyTorch分类训练")
    print("="*50)

    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🔥 使用设备: {device}")

    try:
        # Data loaders
        train_loader, val_loader, _ = get_dataloaders(
            config['data_dir'],
            config['batch_size'],
            config['target_size']
        )

        # Class information comes from the dataset, not from the config.
        num_classes = len(train_loader.dataset.classes)
        class_names = train_loader.dataset.classes
        config['num_classes'] = num_classes  # keep the config in sync

        print(f"📊 数据集信息:")
        print(f"  训练批次: {len(train_loader)}")
        print(f"  验证批次: {len(val_loader)}")
        print(f"  实际类别数: {num_classes}")
        print(f"  类别名称: {class_names[:5]}...")

        # Model
        model = get_model(config['model_name'], num_classes).to(device)

        # Loss, optimizer and LR schedule
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'])

        # Output locations
        os.makedirs("checkpoints", exist_ok=True)
        log_path = f"checkpoints/{config['model_name']}_train_log.csv"

        if not os.path.exists(log_path):
            with open(log_path, 'w') as f:
                # A real newline is required here: an escaped backslash would
                # put the header and every row on one unparsable line.
                f.write("epoch,train_loss,train_accuracy,val_loss,val_accuracy\n")

        print(f"💾 保存设置:")
        print(f"  检查点目录: checkpoints/")
        print(f"  训练日志: {log_path}")

        # Training loop
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(config['num_epochs']):
            # Training phase
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            print(f"\n📅 Epoch {epoch+1}/{config['num_epochs']}")

            progress_bar = tqdm(train_loader, desc=f"训练Epoch {epoch+1}")
            for batch_idx, (images, labels) in enumerate(progress_bar):
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                # Live metrics on the progress bar
                current_acc = correct / total
                progress_bar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{current_acc:.4f}'
                })

                # Periodic detailed line every 50 batches
                if (batch_idx + 1) % 50 == 0:
                    print(f"    Batch {batch_idx+1}/{len(train_loader)} - Loss: {loss.item():.4f}, Acc: {current_acc:.4f}")

            # Epoch-level training metrics
            train_loss = running_loss / len(train_loader)
            train_acc = correct / total

            # Validation phase
            val_loss, val_acc = evaluate_classification(model, val_loader, criterion, device)

            # Learning-rate schedule
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]

            print(f"📊 Epoch {epoch+1} 结果:")
            print(f"  训练: Loss={train_loss:.4f}, Acc={train_acc:.4f}")
            print(f"  验证: Loss={val_loss:.4f}, Acc={val_acc:.4f}")
            print(f"  学习率: {current_lr:.6f}")

            # Append this epoch to the CSV log (one row per line).
            with open(log_path, 'a') as f:
                f.write(f"{epoch+1},{train_loss:.4f},{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

            # Always refresh the "last" checkpoint
            save_checkpoint(model, optimizer, epoch, "checkpoints/last_checkpoint.pth")

            # Keep the best model by validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                save_model(model, f"checkpoints/{config['model_name']}_best.pth")
                print("🌟 保存最佳模型!")
            else:
                patience_counter += 1
                print(f"⏰ 早停计数: {patience_counter}/{config['patience']}")

            # Periodic checkpoint
            if (epoch + 1) % config['save_interval'] == 0:
                save_checkpoint(model, optimizer, epoch,
                              f"checkpoints/{config['model_name']}_epoch{epoch+1}.pth")
                print(f"💾 保存检查点: epoch{epoch+1}")

            # Early stopping
            if patience_counter >= config['patience']:
                print("🛑 触发早停!")
                break

        # Final weights (written to the working directory, not checkpoints/)
        save_model(model, f"{config['model_name']}_final.pth")

        # Training curves
        plot_training_curve(log_path, config['model_name'])

        print(f"\n🎉 训练完成!")
        print(f"📁 最佳模型: checkpoints/{config['model_name']}_best.pth")
        print(f"📁 最终模型: {config['model_name']}_final.pth")

        return model, train_loader, val_loader

    except Exception as e:
        print(f"❌ 训练失败: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None

# Progress marker emitted once the training helpers in this cell have been defined.
print("✅ PyTorch分类训练函数定义完成")

✅ PyTorch分类训练函数定义完成


In [9]:
# Full training loop
def train_model(config):
    """Train a classifier end to end and save models, curves and history.

    Uses AdamW with a StepLR schedule (step 30, gamma 0.1), tracks the best
    validation accuracy for model selection and early stopping, and writes all
    artefacts below ``runs/<model_name>_cls``.

    Args:
        config: Dict with keys ``model_name``, ``data_dir``, ``batch_size``,
            ``target_size``, ``num_classes``, ``learning_rate``,
            ``num_epochs``, ``patience`` and ``save_interval``. An optional
            ``num_workers`` entry overrides the default of 4 loader workers.

    Returns:
        ``(model, best_val_acc, output_dir, test_loader)`` where
        ``best_val_acc`` is a percentage and ``output_dir`` is a ``Path``.
    """

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🔧 使用设备: {device}")

    # Output directory
    output_dir = Path(f"runs/{config['model_name']}_cls")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Data loaders (including the test split)
    train_loader, val_loader, test_loader = get_dataloaders(
        config['data_dir'],
        batch_size=config['batch_size'],
        target_size=config['target_size'],
        num_workers=config.get('num_workers', 4),
        include_test=True
    )

    # Model
    model = get_model(config['model_name'], config['num_classes'], pretrained=True)
    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=0.01)

    # Learning-rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # Per-epoch history
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    # Learning rate actually used in each epoch; recorded from the optimizer so
    # the plot stays correct for any schedule length.
    learning_rates = []

    # Early-stopping state
    best_val_acc = 0.0
    patience_counter = 0

    print(f"🚀 开始训练 {config['num_epochs']} 个epoch...")

    for epoch in range(config['num_epochs']):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [训练]")

        for batch_idx, (images, labels) in enumerate(train_pbar):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Running statistics
            train_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            # Live metrics on the progress bar
            current_acc = 100 * train_correct / train_total
            train_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{current_acc:.2f}%'
            })

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [验证]")

            for images, labels in val_pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                current_acc = 100 * val_correct / val_total
                val_pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{current_acc:.2f}%'
                })

        # Epoch averages (accuracy in percent)
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_acc = 100 * train_correct / train_total
        val_acc = 100 * val_correct / val_total

        # Record history
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        # Must be read before scheduler.step() to get this epoch's rate.
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # Advance the schedule; current_lr is the rate for the next epoch.
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Epoch summary
        print(f"\n📊 Epoch {epoch+1}/{config['num_epochs']}:")
        print(f"  训练 - Loss: {avg_train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"  验证 - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.2f}%")
        print(f"  学习率: {current_lr:.2e}")

        # Keep the best model by validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_path = output_dir / f"best_{config['model_name']}.pth"
            save_model(model, best_model_path)
            patience_counter = 0
            print(f"  ⭐ 新的最佳验证准确率: {best_val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  ⏳ 验证准确率未提升 ({patience_counter}/{config['patience']})")

        # Periodic checkpoint
        if (epoch + 1) % config['save_interval'] == 0:
            checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pth"
            save_checkpoint(model, optimizer, epoch, checkpoint_path)

        # Early stopping
        if patience_counter >= config['patience']:
            print(f"\n🛑 早停触发! 验证准确率连续 {config['patience']} 个epoch未提升")
            break

        print("-" * 60)

    # Training finished: save the final weights
    final_model_path = output_dir / f"final_{config['model_name']}.pth"
    save_model(model, final_model_path)

    print(f"\n🎉 训练完成!")
    print(f"  最佳验证准确率: {best_val_acc:.2f}%")
    print(f"  模型保存路径: {output_dir}")

    # Training curves
    plt.figure(figsize=(15, 5))

    # Loss curves
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='训练损失', color='blue')
    plt.plot(val_losses, label='验证损失', color='red')
    plt.title('训练损失曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Accuracy curves
    plt.subplot(1, 3, 2)
    plt.plot(train_accuracies, label='训练准确率', color='blue')
    plt.plot(val_accuracies, label='验证准确率', color='red')
    plt.title('训练准确率曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Learning-rate curve, from the values recorded during training
    plt.subplot(1, 3, 3)
    plt.plot(learning_rates, label='学习率', color='green')
    plt.title('学习率变化')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Persist the training history
    history_df = pd.DataFrame({
        'epoch': range(1, len(train_losses) + 1),
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accuracies,
        'val_acc': val_accuracies
    })
    history_df.to_csv(output_dir / 'training_history.csv', index=False)
    print(f"📈 训练历史已保存: {output_dir / 'training_history.csv'}")

    return model, best_val_acc, output_dir, test_loader


# Run training, then evaluate on the test split
print("=" * 60)
print("🎯 TomatoMAP 分类模型训练")
print("=" * 60)

# Check the dataset root first
if not os.path.exists(CLASSIFICATION_CONFIG['data_dir']):
    print(f"❌ 错误: 数据目录不存在")
    print(f"   路径: {CLASSIFICATION_CONFIG['data_dir']}")
    print(f"   请确保数据目录存在并包含 train/ 和 val/ 子目录")
else:
    print(f"✅ 数据目录存在: {CLASSIFICATION_CONFIG['data_dir']}")

    # Check the train, val and test split directories
    train_dir = os.path.join(CLASSIFICATION_CONFIG['data_dir'], 'train')
    val_dir = os.path.join(CLASSIFICATION_CONFIG['data_dir'], 'val')
    test_dir = os.path.join(CLASSIFICATION_CONFIG['data_dir'], 'test')

    if not os.path.exists(train_dir):
        print(f"❌ 错误: 训练目录不存在: {train_dir}")
    elif not os.path.exists(val_dir):
        print(f"❌ 错误: 验证目录不存在: {val_dir}")
    else:
        # A missing test split is only a warning: get_dataloaders returns the
        # validation loader in the test slot, so training still proceeds.
        if os.path.exists(test_dir):
            print(f"✅ 训练、验证和测试目录都存在")
        else:
            print(f"⚠️ 警告: 测试目录不存在: {test_dir}")
            print(f"   将使用验证集作为测试集")

        # Show the training configuration
        print("\n⚙️ 训练配置:")
        for key, value in CLASSIFICATION_CONFIG.items():
            print(f"   {key}: {value}")

        print("\n🚀 开始训练...")

        try:
            # Train
            model, best_acc, output_dir, test_loader = train_model(CLASSIFICATION_CONFIG)

            print("\n" + "=" * 60)
            print("🎉 训练成功完成!")
            print(f"   最佳验证准确率: {best_acc:.2f}%")
            print(f"   模型保存目录: {output_dir}")

            # Evaluate on the test split
            print("\n🧪 在测试集上评估模型...")
            model.eval()
            test_correct = 0
            test_total = 0
            test_predictions = []
            test_labels = []

            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            with torch.no_grad():
                test_pbar = tqdm(test_loader, desc="测试集评估")
                for images, labels in test_pbar:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    predicted = outputs.argmax(dim=1)

                    test_total += labels.size(0)
                    test_correct += (predicted == labels).sum().item()

                    # Collect predictions for the confusion matrix
                    test_predictions.extend(predicted.cpu().tolist())
                    test_labels.extend(labels.cpu().tolist())

                    current_acc = 100 * test_correct / test_total
                    test_pbar.set_postfix({'Acc': f'{current_acc:.2f}%'})

            test_accuracy = 100 * test_correct / test_total
            print(f"🎯 测试集准确率: {test_accuracy:.2f}%")

            # Confusion matrix
            print("\n📊 生成混淆矩阵...")

            # Class names come from the dataset that was actually evaluated.
            class_names = test_loader.dataset.classes

            # Pass the full label range so the matrix always has one row and
            # column per class, even if a class never occurs in this split;
            # otherwise its shape would not match display_labels.
            cm = confusion_matrix(
                test_labels, test_predictions,
                labels=list(range(len(class_names)))
            )

            # Draw on an explicit axes so no extra empty figure is created.
            fig, ax = plt.subplots(figsize=(12, 10))
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
            disp.plot(ax=ax, cmap='Blues', values_format='d')
            ax.set_title(f'混淆矩阵 (测试集准确率: {test_accuracy:.2f}%)', fontsize=16)
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            plt.setp(ax.get_yticklabels(), rotation=0)
            fig.tight_layout()
            fig.savefig(output_dir / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
            plt.show()

            # Persist the test results
            test_results = {
                'test_accuracy': test_accuracy,
                'total_samples': test_total,
                'correct_predictions': test_correct,
                'num_classes': len(class_names),
                'class_names': class_names
            }

            import json
            with open(output_dir / 'test_results.json', 'w', encoding='utf-8') as f:
                json.dump(test_results, f, indent=2, ensure_ascii=False)

            print("\n" + "=" * 60)
            print("📈 完整评估结果:")
            print(f"   验证集最佳准确率: {best_acc:.2f}%")
            print(f"   测试集准确率: {test_accuracy:.2f}%")
            print(f"   类别数量: {len(class_names)}")
            print(f"   测试样本数: {test_total}")
            print(f"   结果保存路径: {output_dir}")
            print("=" * 60)

        except KeyboardInterrupt:
            print("\n⏹️ 训练被用户中断")

        except Exception as e:
            print(f"\n❌ 训练过程中出现错误:")
            print(f"   错误信息: {str(e)}")
            print("\n详细错误信息:")
            import traceback
            traceback.print_exc()

🎯 TomatoMAP 分类模型训练
✅ 数据目录存在: TomatoMAP-Cls
✅ 训练、验证和测试目录都存在

⚙️ 训练配置:
   data_dir: TomatoMAP-Cls
   model_name: mobilenet_v3_large
   num_classes: 50
   batch_size: 32
   num_epochs: 150
   learning_rate: 0.0001
   target_size: (640, 640)
   patience: 15
   save_interval: 20

🚀 开始训练...
🔧 使用设备: cpu
🔧 创建数据加载器: TomatoMAP-Cls
📊 加载 train 数据集: 45099 张图片, 50 个类别
📊 加载 val 数据集: 12870 张图片, 50 个类别
📊 加载 test 数据集: 6495 张图片, 50 个类别
✅ 数据加载器创建完成
🤖 创建模型: mobilenet_v3_large, 类别数: 50
📊 参数统计: 总数4,266,082, 可训练4,266,082
🚀 开始训练 150 个epoch...


Epoch 1/150 [训练]:   0%|▏                                                                                                                            | 2/1410 [00:21<4:10:10, 10.66s/it, Loss=3.9196, Acc=4.69%]



⏹️ 训练被用户中断


# TomatoMAP-Seg Trainer

In [1]:
import os
import cv2
import json
import yaml
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from pathlib import Path
from sklearn.metrics import confusion_matrix
from pycocotools import mask as mask_utils

# Detectron2相关导入
import detectron2
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog, build_detection_test_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.utils.visualizer import Visualizer
from detectron2.data.datasets.coco import load_coco_json
from detectron2.utils.logger import setup_logger

# Set up Detectron2 logging
setup_logger()

# Environment check: report library versions and, when a GPU is visible,
# its name and total memory.
print("🔧 环境检查:")
print(f"  PyTorch版本: {torch.__version__}")
print(f"  CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  GPU设备: {torch.cuda.get_device_name(0)}")
    print(f"  GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"  Detectron2版本: {detectron2.__version__}")
print("✅ 环境检查完成")

🔧 环境检查:
  PyTorch版本: 2.3.0+cu121
  CUDA可用: False
  Detectron2版本: 0.6
✅ 环境检查完成


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Dataset configuration
DATASET_CONFIG = {
    'dataset_root': "./",
    'img_dir': "images",  # image directory
    'coco_ann_dir': "labels",  # directory with COCO-format annotations
    'isat_yaml_path': "isat.yaml",  # ISAT configuration file
    'output_dir': "./output",  # output directory
    'num_classes': 10,  # number of classes (excluding background)
}

# Training configuration
TRAINING_CONFIG = {
    'model_name': "mask_rcnn_R_50_FPN_3x",
    'batch_size': 2,
    'base_lr': 0.00025,
    'max_iter': 1000,
    'num_workers': 0,  # set to 0 on Windows; 4 is an option on Linux
    'score_thresh_test': 0.5,
    'input_min_size_test': 800,
    'input_max_size_test': 1333,
}

# Echo both configurations so they are recorded in the notebook output.
print("⚙️ 语义实例分割配置:")
print("📁 数据集配置:")
for key, value in DATASET_CONFIG.items():
    print(f"  {key}: {value}")
print("\n🎯 训练配置:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

⚙️ 语义实例分割配置:
📁 数据集配置:
  dataset_root: ./
  img_dir: images
  coco_ann_dir: labels
  isat_yaml_path: isat.yaml
  output_dir: ./output
  num_classes: 10

🎯 训练配置:
  model_name: mask_rcnn_R_50_FPN_3x
  batch_size: 2
  base_lr: 0.00025
  max_iter: 1000
  num_workers: 0
  score_thresh_test: 0.5
  input_min_size_test: 800
  input_max_size_test: 1333


In [4]:
# Cell 4: dataset validation and registration
print("🔍 验证数据集完整性...")

# Required files and directories
required_paths = [
    DATASET_CONFIG['img_dir'],
    DATASET_CONFIG['coco_ann_dir'],
    DATASET_CONFIG['isat_yaml_path']
]

for path in required_paths:
    if os.path.exists(path):
        print(f"  ✅ {path}")
    else:
        print(f"  ❌ {path} - 路径不存在!")

# COCO-format annotation files, one per split
coco_files = ['train.json', 'val.json', 'test.json']
available_splits = []

for coco_file in coco_files:
    coco_path = os.path.join(DATASET_CONFIG['coco_ann_dir'], coco_file)
    if os.path.exists(coco_path):
        print(f"  ✅ {coco_file}")
        available_splits.append(coco_file.replace('.json', ''))

        # Basic statistics of the annotation file. JSON is read as UTF-8
        # explicitly so non-ASCII names do not depend on the platform's
        # default encoding.
        with open(coco_path, 'r', encoding='utf-8') as f:
            coco_data = json.load(f)
        print(f"    - 图像数量: {len(coco_data['images'])}")
        print(f"    - 标注数量: {len(coco_data['annotations'])}")
        print(f"    - 类别数量: {len(coco_data['categories'])}")
    else:
        print(f"  ⚠️ {coco_file} - 文件不存在")

if not available_splits:
    print("❌ 错误: 没有找到任何COCO格式的标注文件!")
else:
    print(f"\n📊 可用数据集: {available_splits}")

    # Register the datasets.
    # NOTE(review): register_all_datasets is not defined in the cells shown
    # here; presumably it comes from an earlier cell - confirm.
    class_labels = register_all_datasets()
    print("✅ 数据集注册完成")

🔍 验证数据集完整性...
  ❌ images - 路径不存在!
  ✅ labels
  ❌ isat.yaml - 路径不存在!
  ⚠️ train.json - 文件不存在
  ⚠️ val.json - 文件不存在
  ⚠️ test.json - 文件不存在
❌ 错误: 没有找到任何COCO格式的标注文件!


In [5]:
def train_model():
    """Train the instance-segmentation model and report progress on stdout.

    Builds the config with ``build_cfg``, hands it to ``save_config`` together
    with the output directory, then runs a ``CustomTrainer`` after
    ``resume_or_load(resume=False)``.

    Returns:
        ``(trainer, cfg)`` when training finishes, or ``(None, cfg)`` when it
        is interrupted from the keyboard or raises an exception.
    """
    def _report(label, value):
        # One indented "label: value" line of the training summary.
        print(f"  {label}: {value}")

    print("🚀 开始训练语义实例分割模型...")

    cfg = build_cfg()
    save_config(cfg, cfg.OUTPUT_DIR)

    # Summary of the run that is about to start.
    print("\n📋 训练信息:")
    _report("模型", TRAINING_CONFIG['model_name'])
    _report("类别数", DATASET_CONFIG['num_classes'])
    _report("批次大小", TRAINING_CONFIG['batch_size'])
    _report("学习率", TRAINING_CONFIG['base_lr'])
    _report("最大迭代", TRAINING_CONFIG['max_iter'])
    _report("输出目录", cfg.OUTPUT_DIR)
    _report("设备", cfg.MODEL.DEVICE)

    trainer = CustomTrainer(cfg)
    trainer.resume_or_load(resume=False)

    print(f"\n🎯 开始训练 {TRAINING_CONFIG['max_iter']} 步...")

    try:
        trainer.train()

        print("\n🎉 训练完成!")

        # Only announce the final weights if the file is actually there.
        weights_file = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
        if os.path.exists(weights_file):
            print(f"💾 最终模型已保存: {weights_file}")

        return trainer, cfg

    except KeyboardInterrupt:
        print("\n⏹️ 训练被用户中断")

    except Exception as e:
        print("\n❌ 训练过程中出现错误:")
        print(f"   错误信息: {e!s}")
        import traceback
        traceback.print_exc()

    # Both failure paths end here: no trainer, but the config is still returned.
    return None, cfg

# Training can only start when the dataset cell produced class labels and
# found at least one annotation split.
if 'class_labels' not in locals() or not available_splits:
    print("❌ 无法开始训练，请检查数据集配置")
else:
    print("=" * 60)
    print("🎯 准备开始训练...")
    print("=" * 60)

    # Run the training.
    trainer, cfg = train_model()

    if trainer is None:
        print("\n❌ 训练未成功完成")
    else:
        print("\n✅ 训练流程完成!")

        # List the regular files found in the output directory.
        print("\n📁 输出文件:")
        output_dir = Path(cfg.OUTPUT_DIR)
        if output_dir.exists():
            for file in output_dir.iterdir():
                if not file.is_file():
                    continue
                print(f"  📄 {file.name}")

❌ 无法开始训练，请检查数据集配置


In [6]:
def evaluate_model(dataset_name="tomato_test", config_override=None):
    """Evaluate the trained model on a registered dataset with COCOEvaluator.

    Args:
        dataset_name: Name of the registered dataset to evaluate on.
        config_override: Optional mapping applied to the config through
            ``cfg.update`` before evaluation.

    Returns:
        The value returned by ``inference_on_dataset``, or ``None`` when
        ``model_final.pth`` does not exist in the output directory. The
        results are also printed and written to
        ``eval_results_<dataset_name>.json`` in the output directory.
    """
    print(f"📊 在 {dataset_name} 上评估模型...")

    # Build the config.
    cfg = build_cfg()
    if config_override:
        # NOTE(review): this is a plain dict-style update, so a nested section
        # would be replaced rather than merged - confirm against callers.
        cfg.update(config_override)

    # Locate the trained weights.
    model_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    if not os.path.exists(model_path):
        print(f"❌ 模型文件不存在: {model_path}")
        return None

    cfg.MODEL.WEIGHTS = model_path

    # Create the evaluator and the data loader.
    evaluator = COCOEvaluator(dataset_name, cfg, False, output_dir=cfg.OUTPUT_DIR)
    val_loader = build_detection_test_loader(cfg, dataset_name)

    # build_model() only creates the architecture with freshly initialised
    # parameters; it does not read cfg.MODEL.WEIGHTS. Load the checkpoint
    # explicitly, otherwise the metrics describe an untrained network.
    from detectron2.checkpoint import DetectionCheckpointer

    model = DefaultTrainer.build_model(cfg)
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    print("🔍 开始评估...")
    results = inference_on_dataset(model, val_loader, evaluator)

    # Show the results.
    print("\n📈 评估结果:")
    print(json.dumps(results, indent=2))

    # Persist the results next to the model.
    results_path = os.path.join(cfg.OUTPUT_DIR, f"eval_results_{dataset_name}.json")
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"💾 评估结果已保存: {results_path}")

    return results

# Evaluate only when a previous cell left a usable trainer behind.
if locals().get('trainer') is not None:
    print(f"\n{'=' * 60}")
    print("📊 开始模型评估")
    print("=" * 60)

    # Test split.
    if 'test' in available_splits:
        test_results = evaluate_model("tomato_test")

    # Validation split.
    if 'val' in available_splits:
        val_results = evaluate_model("tomato_val")

    print("✅ 评估完成!")

In [7]:
def visualize_predictions(dataset_name="tomato_test", num_samples=5):
    """Draw model predictions on randomly sampled images and save them.

    The predictor is built from ``build_cfg()`` with ``model_final.pth`` from
    the output directory as weights. Each rendered image is written to the
    output directory as ``prediction_<n>_<original file name>``.

    Args:
        dataset_name: Registered dataset name; the part after the last ``_``
            selects the annotation file ``<split>.json``.
        num_samples: Maximum number of images to sample.
    """
    print(f"🎨 可视化 {dataset_name} 预测结果...")

    # Build the config and point it at the trained weights.
    cfg = build_cfg()
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")

    # Create the predictor.
    predictor = DefaultPredictor(cfg)
    metadata = MetadataCatalog.get(dataset_name)

    # Load the dataset records of the requested split.
    split_name = dataset_name.split('_')[-1]
    dataset_dicts = load_coco_json(
        os.path.join(DATASET_CONFIG['coco_ann_dir'], f"{split_name}.json"),
        DATASET_CONFIG['img_dir']
    )

    # Pick a random subset, never more than the dataset contains.
    sample_data = random.sample(dataset_dicts, min(num_samples, len(dataset_dicts)))

    print(f"🖼️ 处理 {len(sample_data)} 张图像...")

    for i, d in enumerate(sample_data, start=1):
        img_path = d["file_name"]
        if not os.path.exists(img_path):
            print(f"⚠️ 图像不存在: {img_path}")
            continue

        # cv2.imread signals an unreadable file by returning None rather than
        # raising, so skip such files instead of crashing in the predictor.
        im = cv2.imread(img_path)
        if im is None:
            print(f"⚠️ 图像无法读取: {img_path}")
            continue

        # Predict.
        outputs = predictor(im)

        # Render; the channel order is reversed for the Visualizer and
        # reversed back before saving with OpenCV.
        v = Visualizer(im[:, :, ::-1], metadata=metadata, scale=1.2)
        v._default_font_size = 20
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

        # Save, and only report success when the file was really written:
        # cv2.imwrite returns False on failure instead of raising.
        save_path = os.path.join(cfg.OUTPUT_DIR, f"prediction_{i}_{os.path.basename(img_path)}")
        if cv2.imwrite(save_path, out.get_image()[:, :, ::-1]):
            print(f"  💾 {save_path}")
        else:
            print(f"  ⚠️ 保存失败: {save_path}")

    print("✅ 可视化完成!")

def _row_normalize(counts):
    """Scale each row of *counts* to sum to 1, leaving all-zero rows at zero."""
    counts = counts.astype(float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Dividing only where the row total is positive avoids the 0/0
    # RuntimeWarning (and the NaN clean-up) for classes without any pixels.
    return np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)


def _ground_truth_label_mask(annotations, height, width):
    """Paint every annotation's category id into a (height, width) uint8 mask."""
    label_mask = np.zeros((height, width), dtype=np.uint8)
    for ann in annotations:
        category_id = ann["category_id"]
        segmentation = ann["segmentation"]
        if isinstance(segmentation, list):
            # Polygon list: convert to RLE and merge the parts into one mask.
            rle = mask_utils.merge(mask_utils.frPyObjects(segmentation, height, width))
        elif isinstance(segmentation, dict) and "counts" in segmentation:
            rle = segmentation
        else:
            continue
        label_mask[mask_utils.decode(rle) == 1] = category_id
    return label_mask


def _predicted_label_mask(instances, height, width):
    """Paint every predicted instance's class id into a (height, width) uint8 mask."""
    label_mask = np.zeros((height, width), dtype=np.uint8)
    for idx in range(len(instances)):
        class_id = int(instances.pred_classes[idx])
        label_mask[instances.pred_masks[idx].numpy() == 1] = class_id
    return label_mask


def generate_confusion_matrix(dataset_name="tomato_val", score_thresh=0.3):
    """Compute, save and plot a pixel-wise, row-normalised confusion matrix.

    For every readable image a ground-truth label mask (from the annotations)
    and a predicted label mask (from the model's instance masks) are compared
    pixel by pixel. The accumulated matrix is normalised per row, written to
    ``confusion_matrix_<split>.xlsx`` and plotted to
    ``confusion_matrix_<split>.png`` in the output directory.

    NOTE(review): pixels without any annotation/prediction keep the value 0,
    the same value used for class index 0. That is only meaningful if class 0
    is a background class - confirm against the dataset's label list.

    Args:
        dataset_name: Registered dataset name; the part after the last ``_``
            selects the annotation file ``<split>.json``.
        score_thresh: Value assigned to ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``
            before the predictor is built.
    """
    print(f"📊 生成 {dataset_name} 混淆矩阵...")

    cfg = build_cfg()
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    predictor = DefaultPredictor(cfg)

    split_name = dataset_name.split('_')[-1]
    dataset_dicts = load_coco_json(
        os.path.join(DATASET_CONFIG['coco_ann_dir'], f"{split_name}.json"),
        DATASET_CONFIG['img_dir'],
        dataset_name=dataset_name
    )

    metadata = MetadataCatalog.get(dataset_name)
    class_names = metadata.thing_classes
    num_classes = len(class_names)
    labels = list(range(num_classes))

    cmatrix_total = np.zeros((num_classes, num_classes), dtype=np.int64)

    print(f"🔍 处理 {len(dataset_dicts)} 张图像...")

    for data in tqdm(dataset_dicts, desc="生成混淆矩阵"):
        height, width = data["height"], data["width"]
        image = cv2.imread(data["file_name"])
        if image is None:
            # Unreadable image: skip it.
            continue

        # Ground-truth label mask.
        gt_mask = _ground_truth_label_mask(data.get("annotations", []), height, width)

        # Predicted label mask.
        instances = predictor(image)["instances"].to("cpu")
        pred_mask = _predicted_label_mask(instances, height, width)

        # Accumulate this image's pixel-wise confusion matrix.
        cmatrix_total += confusion_matrix(
            gt_mask.flatten(),
            pred_mask.flatten(),
            labels=labels
        )

    # Normalise each row (true class) to sum to 1.
    cmatrix_norm = _row_normalize(cmatrix_total)

    # Save as Excel.
    df = pd.DataFrame(cmatrix_norm, index=class_names, columns=class_names)
    excel_path = os.path.join(cfg.OUTPUT_DIR, f"confusion_matrix_{split_name}.xlsx")
    df.to_excel(excel_path)
    print(f"💾 混淆矩阵已保存: {excel_path}")

    # Plot the matrix; zero cells are masked so they stay blank.
    fig, ax = plt.subplots(figsize=(10, 8))
    masked = np.ma.masked_where(cmatrix_norm == 0, cmatrix_norm)
    im = ax.imshow(masked, cmap="Blues", vmin=0.0, vmax=1.0)

    # Write the value into every non-zero cell, white on dark cells.
    for i in range(num_classes):
        for j in range(num_classes):
            val = cmatrix_norm[i, j]
            if val > 0:
                color = 'white' if val > 0.5 else 'black'
                ax.text(j, i, f"{val:.2f}", ha='center', va='center', color=color, fontsize=8)

    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticklabels(class_names)
    ax.set_xlabel("预测类别")
    ax.set_ylabel("真实类别")
    ax.set_title(f"混淆矩阵 - {dataset_name}")

    plt.colorbar(im, ax=ax)
    plt.tight_layout()

    # Save the figure before showing it, then release it.
    img_path = os.path.join(cfg.OUTPUT_DIR, f"confusion_matrix_{split_name}.png")
    plt.savefig(img_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    print(f"🎨 混淆矩阵图已保存: {img_path}")
    print("✅ 混淆矩阵生成完成!")

# Visualise only when a previous cell left a usable trainer behind.
if locals().get('trainer') is not None:
    print(f"\n{'=' * 60}")
    print("🎨 开始结果可视化")
    print("=" * 60)

    # Prediction overlays for the test split.
    if 'test' in available_splits:
        visualize_predictions("tomato_test", num_samples=3)

    # Confusion matrix for the validation split.
    if 'val' in available_splits:
        generate_confusion_matrix("tomato_val")

    print("✅ 可视化完成!")

print("\n🎉 Detectron2语义实例分割训练流程完成!")


🎉 Detectron2语义实例分割训练流程完成!
