# TomatoMAP-Cls Trainer

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
import argparse
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from PIL import ImageDraw, ImageFont, Image
import torchvision.transforms as transforms
from torchvision import models
from torchvision.models import (
    MobileNet_V3_Large_Weights,
    MobileNet_V3_Small_Weights,
    MobileNet_V2_Weights,
    ResNet18_Weights,
)

# env checker
print("Environment checker:")
print(f"  PyTorch version: {torch.__version__}")
print(f"  CUDA version: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  GPU device: {torch.cuda.get_device_name(0)}")
    print(f"  GPU ram: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

Environment checker:
  PyTorch version: 2.7.1+cu126
  CUDA version: True
  GPU device: Tesla V100-PCIE-16GB
  GPU ram: 15.8 GB


In [2]:
def save_model(model, path):
    torch.save(model.state_dict(), path)
    print(f"model saved at: {path}")

def load_model(model, path, device='cpu'):
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    print(f"model loaded from: {path}")

def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

def load_checkpoint(path, model, optimizer, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"from epoch {start_epoch} re-training")
    return start_epoch

def get_font(size=30, bold=False):
    font_paths = [
        "C:/Windows/Fonts/arialbd.ttf" if bold else "C:/Windows/Fonts/arial.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" if bold else "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "/System/Library/Fonts/Supplemental/Arial-Bold.ttf" if bold else "/System/Library/Fonts/Supplemental/Arial.ttf",
    ]
    for path in font_paths:
        try:
            return ImageFont.truetype(path, size=size)
        except:
            continue
    return ImageFont.load_default()

def denormalize(img_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    mean = torch.tensor(mean).view(3, 1, 1).to(img_tensor.device)
    std = torch.tensor(std).view(3, 1, 1).to(img_tensor.device)
    return torch.clamp(img_tensor * std + mean, 0, 1)

def get_model(name, num_classes, pretrained=True):
    print(f"build model: {name}, class number: {num_classes}")
    
    if name == 'mobilenet_v3_large':
        weights = MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
        model = models.mobilenet_v3_large(weights=weights)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    elif name == 'mobilenet_v3_small':
        weights = MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        model = models.mobilenet_v3_small(weights=weights)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    elif name == 'mobilenet_v2':
        weights = MobileNet_V2_Weights.DEFAULT if pretrained else None
        model = models.mobilenet_v2(weights=weights)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif name == 'resnet18':
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        model = models.resnet18(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError(f"Model {name} not supported.")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"parameter info: Total: {total_params:,}, Trainable: {trainable_params:,}")
    
    return model

class BBCHDataset(Dataset):
    
    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = os.path.join(data_dir, split)
        self.transform = transform
        
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Directory not found: {self.data_dir}")
        
        # get all classes
        self.classes = sorted([d for d in os.listdir(self.data_dir)
                              if os.path.isdir(os.path.join(self.data_dir, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.data_dir, class_name)
            class_idx = self.class_to_idx[class_name]
            
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(class_dir, img_name)
                    self.samples.append((img_path, class_idx))
        
        print(f"loading {split} dataset: {len(self.samples)} images, {len(self.classes)} classes")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"failed to load image: {img_path}, error: {e}")
            image = Image.new('RGB', (224, 224), (0, 0, 0))
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# data enhance
def get_transforms(target_size=(640, 640)):
    train_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    return train_transform, val_transform

def get_dataloaders(data_dir, batch_size=32, target_size=(640, 640), num_workers=8, include_test=False):
    print(f"building dataloader: {data_dir}")
    
    train_transform, val_transform = get_transforms(target_size)
    
    train_dataset = BBCHDataset(data_dir, 'train', train_transform)
    val_dataset = BBCHDataset(data_dir, 'val', val_transform)

    # for windows users
    import platform
    if platform.system() == 'Windows':
        num_workers = 0
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=torch.cuda.is_available()
    )
    
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=torch.cuda.is_available()
    )
    
    test_loader = None
    if include_test:
        test_dir = os.path.join(data_dir, 'test')
        if os.path.exists(test_dir):
            test_dataset = BBCHDataset(data_dir, 'test', val_transform)
            test_loader = DataLoader(
                test_dataset, batch_size=batch_size, shuffle=False,
                num_workers=num_workers, pin_memory=torch.cuda.is_available()
            )
        else:
            print("test set not found, using val as test")
            test_loader = val_loader

    return train_loader, val_loader, test_loader

In [3]:
CLASSIFICATION_CONFIG = {
    'data_dir': 'TomatoMAP/TomatoMAP-Cls',
    'model_name': 'mobilenet_v3_large',  # 'mobilenet_v3_small', 'mobilenet_v2', 'resnet18'
    'num_classes': 50,
    'batch_size': 32,
    'num_epochs': 30,
    'learning_rate': 1e-4,
    'target_size': (640, 640),
    'patience': 3,
    'save_interval': 20
}

print("config:")
for key, value in CLASSIFICATION_CONFIG.items():
    print(f"  {key}: {value}")

config:
  data_dir: TomatoMAP/TomatoMAP-Cls
  model_name: mobilenet_v3_large
  num_classes: 50
  batch_size: 32
  num_epochs: 30
  learning_rate: 0.0001
  target_size: (640, 640)
  patience: 3
  save_interval: 20


In [None]:
def train_model(config):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    output_dir = Path(f"cls/runs/{config['model_name']}_cls")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    train_loader, val_loader, test_loader = get_dataloaders(
        config['data_dir'], 
        batch_size=config['batch_size'],
        target_size=config['target_size'],
        num_workers=8,
        include_test=True
    )
    
    model = get_model(config['model_name'], config['num_classes'], pretrained=True)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=0.01)
    
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    
    best_val_acc = 0.0
    patience_counter = 0
    
    print(f"Training start with {config['num_epochs']} epoch(s),")
    
    for epoch in range(config['num_epochs']):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]")
        
        for batch_idx, (images, labels) in enumerate(train_pbar):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            
            current_acc = 100 * train_correct / train_total
            train_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{current_acc:.2f}%'
            })
        
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Val]")
            
            for images, labels in val_pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                
                current_acc = 100 * val_correct / val_total
                val_pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{current_acc:.2f}%'
                })
        
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_acc = 100 * train_correct / train_total
        val_acc = 100 * val_correct / val_total
        
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        
        print(f"\n Epoch {epoch+1}/{config['num_epochs']}:")
        print(f"  Train - Loss: {avg_train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"  Val - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.2f}%")
        print(f"  lr: {current_lr:.2e}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_path = output_dir / f"best_{config['model_name']}.pth"
            save_model(model, best_model_path)
            patience_counter = 0
            print(f"  best val acc: {best_val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  val acc not raised ({patience_counter}/{config['patience']})")
        
        if (epoch + 1) % config['save_interval'] == 0:
            checkpoint_path = output_dir / f"checkpoint_epoch_{epoch+1}.pth"
            save_checkpoint(model, optimizer, epoch, checkpoint_path)
        
        if patience_counter >= config['patience']:
            print(f"\n Trigger early stop. Val acc has {config['patience']} epochs no improve")
            break
        
        print("-" * 60)
    
    final_model_path = output_dir / f"final_{config['model_name']}.pth"
    save_model(model, final_model_path)
    
    print(f"\n TomatoMAP-Cls is trained!")
    print(f"  best val acc: {best_val_acc:.2f}%")
    print(f"  model saved at: {output_dir}")
    
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='train loss', color='blue')
    plt.plot(val_losses, label='val loss', color='red')
    plt.title('training loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 3, 2)
    plt.plot(train_accuracies, label='train acc', color='blue')
    plt.plot(val_accuracies, label='val acc', color='red')
    plt.title('training acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 3, 3)
    lrs = []
    for i in range(len(train_losses)):
        if i < 30:
            lrs.append(config['learning_rate'])
        elif i < 60:
            lrs.append(config['learning_rate'] * 0.1)
        else:
            lrs.append(config['learning_rate'] * 0.01)
    plt.plot(lrs, label='lr', color='green')
    plt.title('lr changes')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(output_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    history_df = pd.DataFrame({
        'epoch': range(1, len(train_losses) + 1),
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accuracies,
        'val_acc': val_accuracies
    })
    history_df.to_csv(output_dir / 'training_history.csv', index=False)
    print(f" training log saved at: {output_dir / 'training_history.csv'}")
    
    return model, best_val_acc, output_dir, test_loader


print("=" * 60)
print("TomatoMAP-Cls Trainer")
print("=" * 60)

if not os.path.exists(CLASSIFICATION_CONFIG['data_dir']):
    print(f"dataset not exist")
    print(f"   path: {CLASSIFICATION_CONFIG['data_dir']}")
    print(f"   please check data structure")
else:
    print(f"data founded at: {CLASSIFICATION_CONFIG['data_dir']}")
    
    train_dir = os.path.join(CLASSIFICATION_CONFIG['data_dir'], 'train')
    val_dir = os.path.join(CLASSIFICATION_CONFIG['data_dir'], 'val')
    test_dir = os.path.join(CLASSIFICATION_CONFIG['data_dir'], 'test')
    
    if not os.path.exists(train_dir):
        print(f"training subset not exist: {train_dir}")
    elif not os.path.exists(val_dir):
        print(f"val subset not exist: {val_dir}")
    elif not os.path.exists(test_dir):
        print(f"test subset not exist: {test_dir}")
        print(f"   using val subset for test")
    else:
        print(f"TomatoMAP-Cls is well structured.")
        
        print("\n training config:")
        for key, value in CLASSIFICATION_CONFIG.items():
            print(f"   {key}: {value}")
        
        print("\n training start.")
        
        try:
            model, best_acc, output_dir, test_loader = train_model(CLASSIFICATION_CONFIG)
            
            print("\n" + "=" * 60)
            print("\n training finished!")
            print(f"   best val acc is: {best_acc:.2f}%")
            print(f"   model saved at: {output_dir}")
            
            print("\n evaluating on test subset...")
            model.eval()
            test_correct = 0
            test_total = 0
            test_predictions = []
            test_labels = []
            
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            
            with torch.no_grad():
                test_pbar = tqdm(test_loader, desc="evaluating")
                for images, labels in test_pbar:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    
                    test_total += labels.size(0)
                    test_correct += (predicted == labels).sum().item()

                    test_predictions.extend(predicted.cpu().numpy())
                    test_labels.extend(labels.cpu().numpy())
                    
                    current_acc = 100 * test_correct / test_total
                    test_pbar.set_postfix({'Acc': f'{current_acc:.2f}%'})
            
            test_accuracy = 100 * test_correct / test_total
            print(f" test subset acc: {test_accuracy:.2f}%")

            print("\n building confusion matrix")
            
            train_dataset = test_loader.dataset
            class_names = train_dataset.classes
            
            cm = confusion_matrix(test_labels, test_predictions)
            
            cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
            cm_df.to_csv(output_dir / 'confusion_matrix.csv')

            normalized_cm = cm_df.div(cm_df.sum(axis=1), axis=0).fillna(0)
            
            matrix = normalized_cm.T.to_numpy()
            
            from matplotlib import rcParams
            # rcParams['font.family'] = 'Calibri' # Ubuntu doesn't own this when training on ubuntu VM
            rcParams['font.size'] = 8
            
            masked_matrix = np.ma.masked_where(matrix == 0, matrix)
            
            from matplotlib.colors import Normalize
            cmap = plt.cm.jet
            cmap.set_bad(color='white')
            norm = Normalize(vmin=0.1, vmax=1)
            
            fig_width_in = 3.1
            fig_height_in = fig_width_in
            fig, ax = plt.subplots(figsize=(fig_width_in, fig_height_in))
            
            im = ax.imshow(masked_matrix, cmap=cmap, norm=norm)

            # For further process for publishing purpose, labels are removed :)
            ax.set_xlabel("")
            ax.set_ylabel("")
            
            ax.set_xticks([])
            ax.set_yticks([])
            
            plt.tight_layout()
            plt.savefig(output_dir / 'normalized_confusion_matrix.png', format='png', dpi=300)
            plt.show()
            
            # plt.figure(figsize=(12, 10))
            # disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
            # disp.plot(cmap='Blues', values_format='d')
            # plt.title(f'Detailed Confusion Matrix (test acc: {test_accuracy:.2f}%)', fontsize=8)
            # plt.xticks(rotation=45, ha='right')
            # plt.yticks(rotation=0)
            # plt.tight_layout()
            # plt.savefig(output_dir / 'detailed_confusion_matrix.png', dpi=300, bbox_inches='tight')
            # plt.show()
            
            test_results = {
                'test_accuracy': test_accuracy,
                'total_samples': test_total,
                'correct_predictions': test_correct,
                'num_classes': len(class_names),
                'class_names': class_names
            }
            
            import json
            with open(output_dir / 'test_results.json', 'w', encoding='utf-8') as f:
                json.dump(test_results, f, indent=2, ensure_ascii=False)
            
            print("\n" + "=" * 60)
            print(" Evaluation results:")
            print(f"   best val acc: {best_acc:.2f}%")
            print(f"   test acc: {test_accuracy:.2f}%")
            print(f"   class number: {len(class_names)}")
            print(f"   test data size: {test_total}")
            print(f"   results saved at: {output_dir}")
            print(f"   GGWP!")
            print("=" * 60)
            
        except KeyboardInterrupt:
            print("\n training interruptted")
            
        except Exception as e:
            print(f"\n error during training:")
            print(f"   error info: {str(e)}")
            print("\nDetails:")
            import traceback
            traceback.print_exc()

TomatoMAP-Cls Trainer
data founded at: TomatoMAP/TomatoMAP-Cls
TomatoMAP-Cls is well structured.

 training config:
   data_dir: TomatoMAP/TomatoMAP-Cls
   model_name: mobilenet_v3_large
   num_classes: 50
   batch_size: 32
   num_epochs: 30
   learning_rate: 0.0001
   target_size: (640, 640)
   patience: 3
   save_interval: 20

 training start.
Using device: cuda
building dataloader: TomatoMAP/TomatoMAP-Cls
loading train dataset: 45099 images, 50 classes
loading val dataset: 12870 images, 50 classes
loading test dataset: 6495 images, 50 classes
build model: mobilenet_v3_large, class number: 50
parameter info: Total: 4,266,082, Trainable: 4,266,082
Training start with 30 epoch(s),


Epoch 1/30 [Train]: 100%|████████████████| 1410/1410 [06:55<00:00,  3.39it/s, Loss=1.8002, Acc=32.31%]
Epoch 1/30 [Val]: 100%|████████████████████| 403/403 [01:00<00:00,  6.67it/s, Loss=1.8581, Acc=40.05%]



 Epoch 1/30:
  Train - Loss: 2.1123, Acc: 32.31%
  Val - Loss: 1.6768, Acc: 40.05%
  lr: 1.00e-04
model saved at: cls/runs/mobilenet_v3_large_cls/best_mobilenet_v3_large.pth
  best val acc: 40.05%
------------------------------------------------------------


Epoch 2/30 [Train]: 100%|████████████████| 1410/1410 [06:48<00:00,  3.45it/s, Loss=1.9883, Acc=42.30%]
Epoch 2/30 [Val]: 100%|████████████████████| 403/403 [01:07<00:00,  5.99it/s, Loss=0.9373, Acc=45.84%]



 Epoch 2/30:
  Train - Loss: 1.6399, Acc: 42.30%
  Val - Loss: 1.4930, Acc: 45.84%
  lr: 1.00e-04
model saved at: cls/runs/mobilenet_v3_large_cls/best_mobilenet_v3_large.pth
  best val acc: 45.84%
------------------------------------------------------------


Epoch 3/30 [Train]: 100%|████████████████| 1410/1410 [06:52<00:00,  3.42it/s, Loss=1.7607, Acc=47.10%]
Epoch 3/30 [Val]: 100%|████████████████████| 403/403 [01:00<00:00,  6.64it/s, Loss=0.9665, Acc=47.71%]



 Epoch 3/30:
  Train - Loss: 1.4578, Acc: 47.10%
  Val - Loss: 1.4351, Acc: 47.71%
  lr: 1.00e-04
model saved at: cls/runs/mobilenet_v3_large_cls/best_mobilenet_v3_large.pth
  best val acc: 47.71%
------------------------------------------------------------


Epoch 4/30 [Train]: 100%|████████████████| 1410/1410 [06:46<00:00,  3.47it/s, Loss=1.2591, Acc=51.96%]
Epoch 4/30 [Val]: 100%|████████████████████| 403/403 [01:06<00:00,  6.06it/s, Loss=1.2630, Acc=53.02%]



 Epoch 4/30:
  Train - Loss: 1.3072, Acc: 51.96%
  Val - Loss: 1.2602, Acc: 53.02%
  lr: 1.00e-04
model saved at: cls/runs/mobilenet_v3_large_cls/best_mobilenet_v3_large.pth
  best val acc: 53.02%
------------------------------------------------------------


Epoch 5/30 [Train]:  28%|████▊            | 395/1410 [01:58<04:26,  3.80it/s, Loss=1.2057, Acc=55.83%]

# TomatoMAP-Det Trainer

In [None]:
from ultralytics import YOLO
from ultralytics import RTDETR

# using proper libiary?
import ultralytics
ultralytics.checks()
print(ultralytics.__file__)

In [None]:
torch.use_deterministic_algorithms(False)

print("\n" + "=" * 60)
print("TomatoMAP-Det Trainer")
print("\n" + "=" * 60)

print("downloading pretrained model: ")

model = YOLO("yolo11l.pt")

print("model info: ")

train_result = model.train(
    data="det/TomatoMAP-Det.yaml",
    epochs=500,
    imgsz=640,
    device=[0],
    batch=32,
    patience=10,
    project="det/output",
    cfg="det/best_hyperparameters.yaml", # fine-tuned hyperparameters, ready to use, details please contact us per email
    #profile=True,
    plots=True
)

# TomatoMAP-Seg Trainer

In [None]:
import os
import cv2
import json
import yaml
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from pathlib import Path
from collections import OrderedDict

# Detectron2
import detectron2
from detectron2.engine import DefaultTrainer, DefaultPredictor, HookBase
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog, build_detection_test_loader, build_detection_train_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import setup_logger

setup_logger()

print(f"  Detectron2 version: {detectron2.__version__}")

In [None]:
def flatten_segmentation(points):
    #format points [[x,y],[x,y]] to [x1,y1,x2,y2,...]
    return [coord for pair in points for coord in pair]

def load_categories_from_yaml(yaml_path):
    with open(yaml_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    categories = []
    cat_map = {}
    cat_id = 1
    for item in data['label']:
        name = item['name']
        if name == '__background__':
            continue
        categories.append({
            "id": cat_id,
            "name": name,
            "supercategory": "none"
        })
        cat_map[name] = cat_id
        cat_id += 1
    return categories, cat_map

def convert_isat_folder_to_coco(task_dir, label_dir, yaml_path, output_dir, train_ratio=0.7, val_ratio=0.2):
    print("ISAT2COCO...")
    
    os.makedirs(output_dir, exist_ok=True)

    categories, category_map = load_categories_from_yaml(yaml_path)
    print(f"loaded {len(categories)} classes")

    if not os.path.exists(task_dir):
        print(f"image folder not exist: {task_dir}")
        return False
    
    if not os.path.exists(label_dir):
        print(f"label folder not exist: {label_dir}")
        return False

    images = [f for f in os.listdir(task_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
    json_map = {os.path.splitext(f)[0]: f for f in os.listdir(label_dir) if f.endswith(".json")}

    print(f"found {len(images)} images")
    print(f"found {len(json_map)} labels")

    # matching image and labels
    dataset = []
    unmatched_images = []
    
    for img_name in tqdm(images, desc="matching image and labels"):
        base = os.path.splitext(img_name)[0]
        if base in json_map:
            dataset.append({
                "img_file": img_name,
                "json_file": json_map[base]
            })
        else:
            unmatched_images.append(img_name)

    print(f"successfully matches {len(dataset)} pairs")
    if unmatched_images:
        print(f" {len(unmatched_images)} unmatches images")

    if len(dataset) == 0:
        print("no matched pairs")
        return False

    random.shuffle(dataset)
    total = len(dataset)
    train_end = int(total * train_ratio)
    val_end = int(total * (train_ratio + val_ratio))

    splits = {
        "train": dataset[:train_end],
        "val": dataset[train_end:val_end],
        "test": dataset[val_end:]
    }

    print(f"\n spilting dataset:")
    for split_name, split_data in splits.items():
        print(f"  {split_name}: {len(split_data)} iamges ({len(split_data)/total*100:.1f}%)")

    conversion_stats = {}
    
    for split_name, split_data in splits.items():
        if len(split_data) == 0:
            print(f" {split_name} dataset empty, skipping")
            continue
            
        print(f"\n transforming {split_name} dataset.")
        
        coco = {
            "images": [],
            "annotations": [],
            "categories": categories
        }
        ann_id = 1
        img_id = 1
        
        processed_annotations = 0
        skipped_annotations = 0

        for item in tqdm(split_data, desc=f"处理{split_name}"):
            img_path = os.path.join(task_dir, item["img_file"])
            json_path = os.path.join(label_dir, item["json_file"])

            if not os.path.exists(json_path):
                print(f" label files not exist: {json_path}")
                continue

            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    isat = json.load(f)
            except Exception as e:
                print(f"failed to load label file {json_path}: {e}")
                continue

            info = isat['info']
            coco["images"].append({
                "file_name": item["img_file"],
                "id": img_id,
                "width": info["width"],
                "height": info["height"]
            })

            for obj in isat.get('objects', []):
                cat = obj['category']
                if cat not in category_map:
                    skipped_annotations += 1
                    continue

                seg_flat = flatten_segmentation(obj["segmentation"])
                if len(seg_flat) < 6:
                    skipped_annotations += 1
                    continue

                coco["annotations"].append({
                    "id": ann_id,
                    "image_id": img_id,
                    "category_id": category_map[cat],
                    "segmentation": [seg_flat],
                    "bbox": obj["bbox"],
                    "area": obj["area"],
                    "iscrowd": obj.get("iscrowd", 0),
                    "group_id": obj.get("group", None)
                })
                ann_id += 1
                processed_annotations += 1
                
            img_id += 1

        output_file = os.path.join(output_dir, f"{split_name}.json")
        with open(output_file, "w", encoding='utf-8') as f:
            json.dump(coco, f, indent=2, ensure_ascii=False)

        conversion_stats[split_name] = {
            'images': len(split_data),
            'annotations': processed_annotations,
            'skipped': skipped_annotations
        }

        print(f"  generated {split_name}.json")
        print(f"     images: {len(split_data)}")
        print(f"     labels: {processed_annotations}")
        print(f"     skipped: {skipped_annotations}")

    print(f"\n ISAT2COCO finished")
    print(f" output at: {output_dir}")
    print(f" transform info:")
    
    total_images = sum(stats['images'] for stats in conversion_stats.values())
    total_annotations = sum(stats['annotations'] for stats in conversion_stats.values())
    total_skipped = sum(stats['skipped'] for stats in conversion_stats.values())
    
    print(f"  total images: {total_images}")
    print(f"  total labels: {total_annotations}")
    print(f"  skipped labels: {total_skipped}")
    
    return True

In [None]:
ISAT_CONFIG = {
    'task_dir': "TomatoMAP/TomatoMAP-Seg/images",
    'label_dir': "TomatoMAP/TomatoMAP-Seg/labels",
    'yaml_path': "TomatoMAP/TomatoMAP-Seg/labels/isat.yaml",
    'output_dir': "TomatoMAP/TomatoMAP-Seg/cocoOut",
    'train_ratio': 0.7,
    'val_ratio': 0.2,    # rest 0.1 is test
    'auto_convert': True
}

DATASET_CONFIG = {
    'dataset_root': "TomatoMAP/TomatoMAP-Seg/",
    'img_dir': "TomatoMAP/TomatoMAP-Seg/images",
    'coco_ann_dir': "TomatoMAP/TomatoMAP-Seg/cocoOut",
    'isat_yaml_path': "TomatoMAP/TomatoMAP-Seg/labels/isat.yaml",
    'output_dir': "TomatoMAP/TomatoMAP-Seg/output",
    'num_classes': 10,    # without background
}

TRAINING_CONFIG = {
    'model_name': "mask_rcnn_R_50_FPN_3x",
    'batch_size': 4,
    'base_lr': 0.00024,
    'max_epochs': 100,
    'patience': 15,
    'num_workers': 8,  # Windows user please set to 0
    'score_thresh_test': 0.3,
    'input_min_size_train': (640, 672, 704, 736, 768, 800),  # 多尺度训练
    'input_max_size_train': 1333,
    'checkpoint_period': 10,
    'eval_period': 10,
}

print("Configurations:")
print("ISAT converter config:")
for key, value in ISAT_CONFIG.items():
    print(f"  {key}: {value}")
print("\n dataset config:")
for key, value in DATASET_CONFIG.items():
    print(f"  {key}: {value}")
print("\n training config:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")

In [None]:
print("\n" + "=" * 60)
print("ISAT coverting to COCO format")
print("=" * 60)

need_conversion = ISAT_CONFIG['auto_convert']

coco_files_exist = all(
    os.path.exists(os.path.join(ISAT_CONFIG['output_dir'], f"{split}.json"))
    for split in ['train', 'val', 'test']
)

if coco_files_exist and not need_conversion:
    print("coco format exist, skipping!")
    print("   if wanna reconvert, set ISAT_CONFIG['auto_convert'] = True")
else:
    required_isat_paths = [
        ISAT_CONFIG['task_dir'],
        ISAT_CONFIG['label_dir'], 
        ISAT_CONFIG['yaml_path']
    ]
    
    missing_paths = [path for path in required_isat_paths if not os.path.exists(path)]
    
    if missing_paths:
        print("following ISAT path not exist:")
        for path in missing_paths:
            print(f"   {path}")
        print("\nplease check ISAT_CONFIG path setting")
        conversion_success = False
    else:
        print("ISAT checked, start converting...")
        
        conversion_success = convert_isat_folder_to_coco(
            task_dir=ISAT_CONFIG['task_dir'],
            label_dir=ISAT_CONFIG['label_dir'],
            yaml_path=ISAT_CONFIG['yaml_path'],
            output_dir=ISAT_CONFIG['output_dir'],
            train_ratio=ISAT_CONFIG['train_ratio'],
            val_ratio=ISAT_CONFIG['val_ratio']
        )

if 'conversion_success' not in locals():
    conversion_success = True

if conversion_success:
    DATASET_CONFIG['coco_ann_dir'] = ISAT_CONFIG['output_dir']
    DATASET_CONFIG['img_dir'] = ISAT_CONFIG['task_dir'] 
    DATASET_CONFIG['isat_yaml_path'] = ISAT_CONFIG['yaml_path']
    print(f"\n Configuration of dataset is updated:")
    print(f"   image path: {DATASET_CONFIG['img_dir']}")
    print(f"   label path: {DATASET_CONFIG['coco_ann_dir']}")
else:
    print("\n Transfer failed")

In [None]:
def analyze_dataset_areas():
    print(f"\n analyze dataset object areas...")
    print(f"{'='*60}")
    
    for split in ['train', 'val', 'test']:
        ann_file = os.path.join(DATASET_CONFIG['coco_ann_dir'], f"{split}.json")
        if not os.path.exists(ann_file):
            print(f"label file {ann_file} not exist")
            continue
            
        with open(ann_file, 'r') as f:
            data = json.load(f)
        
        image_info = {img['id']: img for img in data['images']}
        
        areas_original = []
        areas_scaled = []
        
        min_size = min(TRAINING_CONFIG['input_min_size_train'])
        max_size = TRAINING_CONFIG['input_max_size_train']
        
        for ann in data['annotations']:
            if 'area' in ann:
                area = ann['area']
            else:
                bbox = ann.get('bbox', [0, 0, 0, 0])
                area = bbox[2] * bbox[3]
            areas_original.append(area)
            
            img_id = ann['image_id']
            if img_id in image_info:
                img = image_info[img_id]
                orig_w, orig_h = img['width'], img['height']
                
                size = max(orig_w, orig_h)
                if size > max_size:
                    scale = max_size / size
                else:
                    scale = min_size / min(orig_w, orig_h)
                    if scale * size > max_size:
                        scale = max_size / size
                
                scaled_area = area * (scale ** 2)
                areas_scaled.append(scaled_area)
        
        areas_original = np.array(areas_original)
        areas_scaled = np.array(areas_scaled) if areas_scaled else areas_original
        
        print(f"\n{split.upper()} dataset analysis:")
        print(f"-" * 40)
        
        if len(data['images']) > 0:
            avg_width = np.mean([img['width'] for img in data['images']])
            avg_height = np.mean([img['height'] for img in data['images']])
            print(f"average size: {avg_width:.0f} x {avg_height:.0f}")
        
        print(f"total object amount: {len(areas_original)}")
        
        print(f"\n original object image size distribution:")
        small_orig = np.sum(areas_original < 32**2)
        medium_orig = np.sum((areas_original >= 32**2) & (areas_original < 96**2))
        large_orig = np.sum(areas_original >= 96**2)
        
        print(f"  small object (<32²): {small_orig} ({small_orig/len(areas_original)*100:.1f}%)")
        print(f"  mid object (32²-96²): {medium_orig} ({medium_orig/len(areas_original)*100:.1f}%)")
        print(f"  big object (>96²): {large_orig} ({large_orig/len(areas_original)*100:.1f}%)")
        print(f"  min area: {np.min(areas_original):.0f} pixel²")
        print(f"  max area: {np.max(areas_original):.0f} pixel²")
        print(f"  mean area: {np.mean(areas_original):.0f} pixel²")
        
        print(f"\nscaled to {min_size}-{max_size} :")
        small_scaled = np.sum(areas_scaled < 32**2)
        medium_scaled = np.sum((areas_scaled >= 32**2) & (areas_scaled < 96**2))
        large_scaled = np.sum(areas_scaled >= 96**2)
        
        print(f"  small object (<32²): {small_scaled} ({small_scaled/len(areas_scaled)*100:.1f}%)")
        print(f"  mid object (32²-96²): {medium_scaled} ({medium_scaled/len(areas_scaled)*100:.1f}%)")
        print(f"  big object (>96²): {large_scaled} ({large_scaled/len(areas_scaled)*100:.1f}%)")
        
        if small_scaled == 0:
            print(f"\n after scale, no small object - APs set to -1")
        if medium_scaled == 0:
            print(f" after scale, no mid object - APm set to -1")
        if large_scaled == 0:
            print(f" after scale, no big object - APl set to -1")

def get_dataset_info():
    print(f"\n TomatoMAP-Seg info:")
    print(f"{'='*40}")
    
    for split in ['train', 'val', 'test']:
        ann_file = os.path.join(DATASET_CONFIG['coco_ann_dir'], f"{split}.json")
        if os.path.exists(ann_file):
            with open(ann_file, 'r') as f:
                data = json.load(f)
            print(f"{split}: {len(data['images'])} images, {len(data['annotations'])} labels")
    
    return True

if conversion_success:
    get_dataset_info()
    analyze_dataset_areas()

In [None]:
class BestModelHook(HookBase):
    # hook to save the best model based on validation segmentation mAP
    
    def __init__(self, cfg, eval_period, patience=10):
        self.cfg = cfg.clone()
        self.eval_period = eval_period
        self.patience = patience
        self.best_score = 0  # 使用0而不是-1
        self.best_metric_name = None
        self.best_epoch = -1
        self.epochs_without_improvement = 0
        self.should_stop = False
        self.history = []  # 记录历史
        
    def get_valid_score(self, segm_results):
        priority_metrics = ["AP", "AP50", "AP75", "APm", "APl"]
        
        for metric in priority_metrics:
            value = segm_results.get(metric, -1)
            if value != -1:
                return metric, value
        
        return None, None
    
    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final_iter = next_iter == self.trainer.max_iter
        
        if (next_iter % self.eval_period == 0 and not is_final_iter):
            current_epoch = (next_iter // self.eval_period)
            
            results = self._do_eval()
            if results is None:
                print(f"Epoch {current_epoch}: evaluate failed")
                return
            
            segm_results = results.get("segm", {})
            bbox_results = results.get("bbox", {})
            
            metric_name, current_score = self.get_valid_score(segm_results)
            
            print(f"\n{'='*60}")
            print(f"Epoch {current_epoch} evaluate result:")
            print(f"{'='*60}")
            
            print("\n bbox metrics:")
            for key in ["AP", "AP50", "AP75", "APs", "APm", "APl"]:
                value = bbox_results.get(key, -1)
                if value != -1:
                    print(f"  {key}: {value:.4f} ✓")
                else:
                    print(f"  {key}: N/A")
            
            print("\n seg metrics:")
            for key in ["AP", "AP50", "AP75", "APs", "APm", "APl"]:
                value = segm_results.get(key, -1)
                if value != -1:
                    print(f"  {key}: {value:.4f} ✓")
                else:
                    print(f"  {key}: N/A (no object for this class)")
            
            if metric_name is None:
                print("\n waring! no useful metrics")
                print("please check TomatoMAP-Seg structure")

                metric_name, current_score = self.get_valid_score(bbox_results)
                if metric_name is not None:
                    print(f"using bbox metric: {metric_name} = {current_score:.4f}")
                else:
                    return
            
            print(f"\n main metrics: {metric_name} = {current_score:.4f}")
            
            self.history.append({
                'epoch': current_epoch,
                'metric': metric_name,
                'score': current_score,
                'all_metrics': {**segm_results, **{'bbox_' + k: v for k, v in bbox_results.items()}}
            })
            
            if current_score > self.best_score:
                improvement = current_score - self.best_score
                self.best_score = current_score
                self.best_metric_name = metric_name
                self.best_epoch = current_epoch
                self.epochs_without_improvement = 0
                
                self.trainer.checkpointer.save("model_best")
                print(f"\n best model saved")
                print(f"   score: {current_score:.4f} (↑{improvement:.4f})")
                
                best_results_file = os.path.join(self.cfg.OUTPUT_DIR, "best_results.json")
                with open(best_results_file, 'w') as f:
                    json.dump({
                        'epoch': current_epoch,
                        'metric': metric_name,
                        'score': current_score,
                        'segm_results': segm_results,
                        'bbox_results': bbox_results
                    }, f, indent=2)
            else:
                self.epochs_without_improvement += 1
                gap = self.best_score - current_score
                print(f"\ncurrent: {current_score:.4f} | best: {self.best_score:.4f} (gap: {gap:.4f})")
                print(f"continuted {self.epochs_without_improvement}/{self.patience} epoch no improve")
            
            if self.epochs_without_improvement >= self.patience:
                print(f"\n{'='*60}")
                print(f"early stop triggered")
                print(f"   best {self.best_metric_name}: {self.best_score:.4f} (epoch {self.best_epoch})")
                print(f"   total epochs: {current_epoch}")
                print(f"{'='*60}")
                self.should_stop = True

                self.trainer.storage._iter = self.trainer.max_iter
    
    def _do_eval(self):

        try:
            evaluator = COCOEvaluator("tomato_val", self.cfg, False, 
                                    output_dir=os.path.join(self.cfg.OUTPUT_DIR, "inference"))
            val_loader = build_detection_test_loader(self.cfg, "tomato_val")
            results = inference_on_dataset(self.trainer.model, val_loader, evaluator)
            return results
        except Exception as e:
            print(f"evaluate failed: {e}")
            import traceback
            traceback.print_exc()
            return None

class MyTrainer(DefaultTrainer):
    
    def __init__(self, cfg, patience=None):
        self.patience = patience if patience is not None else TRAINING_CONFIG['patience']
        super().__init__(cfg)
    
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(
            dataset_name=dataset_name,
            distributed=False,
            output_dir=output_folder,
            use_fast_impl=True,
            tasks=("bbox", "segm"),
        )
    
    def build_hooks(self):
        hooks = super().build_hooks()
        
        try:
            train_loader = build_detection_train_loader(self.cfg)
            iters_per_epoch = len(train_loader) // self.cfg.SOLVER.IMS_PER_BATCH
            print(f"iters per epoch: {iters_per_epoch}")
        except:
            iters_per_epoch = 106
            print(f"iters per epoch: {iters_per_epoch}")
        
        eval_period = iters_per_epoch * TRAINING_CONFIG['eval_period']
        
        best_model_hook = BestModelHook(self.cfg, eval_period, self.patience)
        hooks.append(best_model_hook)
        
        self.best_model_hook = best_model_hook
        
        return hooks
    
    def run_step(self):
        super().run_step()
        
        if hasattr(self, 'best_model_hook') and self.best_model_hook.should_stop:
            print("early stop triggered, training stop")
            self.storage._iter = self.max_iter
    
    def train(self):
        super().train()
        
        if hasattr(self, 'best_model_hook'):
            print(f"\n{'='*60}")
            print(f"training info:")
            print(f"{'='*60}")
            if self.best_model_hook.best_score > 0:
                print(f"best {self.best_model_hook.best_metric_name}: {self.best_model_hook.best_score:.4f}")
                print(f"best epoch: {self.best_model_hook.best_epoch}")
                print(f"best model saved as: model_best.pth")
            else:
                print("no metrics found for training")
            
            history_file = os.path.join(self.cfg.OUTPUT_DIR, "training_history.json")
            with open(history_file, 'w') as f:
                json.dump(self.best_model_hook.history, f, indent=2)
            print(f"training log saved at: {history_file}")

In [None]:
def register_all_datasets():
    print("register dataset")
    
    try:
        with open(DATASET_CONFIG['isat_yaml_path'], 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        labels = [item['name'] for item in data['label'] if item['name'] != '__background__']
        print(f"loading label classes: {len(labels)} classes")
        for i, label in enumerate(labels):
            print(f"  {i}: {label}")
    except Exception as e:
        print(f"class label loading failed: {e}")
        return None
    
    datasets = ['train', 'val', 'test']
    
    for dataset_name in datasets:
        dataset_key = f"tomato_{dataset_name}"
        
        try:
            from detectron2.data.datasets.coco import _PREDEFINED_SPLITS_COCO
            if dataset_key in _PREDEFINED_SPLITS_COCO:
                del _PREDEFINED_SPLITS_COCO[dataset_key]
        except ImportError:
            try:
                from detectron2.data.datasets.builtin import _PREDEFINED_SPLITS_COCO
                if dataset_key in _PREDEFINED_SPLITS_COCO:
                    del _PREDEFINED_SPLITS_COCO[dataset_key]
            except ImportError:
                try:
                    from detectron2.data.datasets.register_coco import _PREDEFINED_SPLITS_COCO
                    if dataset_key in _PREDEFINED_SPLITS_COCO:
                        del _PREDEFINED_SPLITS_COCO[dataset_key]
                except ImportError:
                    print(f"  can't clean {dataset_key} version cap)")
        
        try:
            if MetadataCatalog.has(dataset_key):
                MetadataCatalog.remove(dataset_key)
        except:
            pass
    
    for dataset_name in datasets:
        coco_json = os.path.join(DATASET_CONFIG['coco_ann_dir'], f"{dataset_name}.json")
        dataset_key = f"tomato_{dataset_name}"
        
        if os.path.exists(coco_json):
            abs_coco_json = os.path.abspath(coco_json)
            abs_img_dir = os.path.abspath(DATASET_CONFIG['img_dir'])
            
            try:
                register_coco_instances(
                    dataset_key, 
                    {}, 
                    abs_coco_json, 
                    abs_img_dir
                )
                MetadataCatalog.get(dataset_key).thing_classes = labels
                print(f"  registered {dataset_key}")
            except Exception as e:
                print(f"  register failed: {e}")
                try:
                    MetadataCatalog.get(dataset_key).thing_classes = labels
                    print(f"  re-setting {dataset_key} meta data")
                except Exception as e2:
                    print(f"  meta setting failed: {e2}")
        else:
            print(f"  can't find {coco_json}")
    
    return labels

def build_cfg():
    cfg = get_cfg()
    
    model_config_file = f"COCO-InstanceSegmentation/{TRAINING_CONFIG['model_name']}.yaml"
    cfg.merge_from_file(model_zoo.get_config_file(model_config_file))
    
    cfg.DATASETS.TRAIN = ("tomato_train",)
    cfg.DATASETS.TEST = ("tomato_val",)
    
    cfg.DATALOADER.NUM_WORKERS = TRAINING_CONFIG['num_workers']
    
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = DATASET_CONFIG['num_classes']
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config_file)
    
    cfg.SOLVER.IMS_PER_BATCH = TRAINING_CONFIG['batch_size']
    cfg.SOLVER.BASE_LR = TRAINING_CONFIG['base_lr']
    
    estimated_iters_per_epoch = 106
    cfg.SOLVER.MAX_ITER = estimated_iters_per_epoch * TRAINING_CONFIG['max_epochs']
    
    cfg.SOLVER.STEPS = (int(cfg.SOLVER.MAX_ITER * 0.7), int(cfg.SOLVER.MAX_ITER * 0.9))
    cfg.SOLVER.GAMMA = 0.1
    
    cfg.INPUT.MIN_SIZE_TRAIN = TRAINING_CONFIG['input_min_size_train']
    cfg.INPUT.MAX_SIZE_TRAIN = TRAINING_CONFIG['input_max_size_train']
    
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TRAINING_CONFIG['score_thresh_test']
    
    cfg.OUTPUT_DIR = DATASET_CONFIG['output_dir']
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    
    cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    cfg.SOLVER.CHECKPOINT_PERIOD = estimated_iters_per_epoch * TRAINING_CONFIG['checkpoint_period']
    
    return cfg

if conversion_success:
    print("\n" + "=" * 60)
    print("TomatoMAP-Seg Registeration")
    print("=" * 60)
    
    class_labels = register_all_datasets()
    
    if class_labels is not None:
        print("TomatoMAP-Seg is registered")
        
        cfg = build_cfg()
        print("building configed")
        
        print(f"Configuration:")
        print(f"  model: {TRAINING_CONFIG['model_name']}")
        print(f"  class num: {DATASET_CONFIG['num_classes']}")
        print(f"  batch size: {TRAINING_CONFIG['batch_size']}")
        print(f"  lr: {TRAINING_CONFIG['base_lr']}")
        print(f"  max epoch: {TRAINING_CONFIG['max_epochs']}")
        print(f"  patience: {TRAINING_CONFIG['patience']}")
        print(f"  imput size: {TRAINING_CONFIG['input_min_size_train'][0]}-{TRAINING_CONFIG['input_max_size_train']}")
        print(f"  output path: {cfg.OUTPUT_DIR}")
        print(f"  device: {cfg.MODEL.DEVICE}")
    else:
        print("data registeration failed")
        conversion_success = False

In [None]:
def train_model():
    print("Training TomatoMAP-Seg")
    
    trainer = MyTrainer(cfg, patience=TRAINING_CONFIG['patience'])
    trainer.resume_or_load(resume=False)
    
    print(f"\n training configuration:")
    print(f"  model: {TRAINING_CONFIG['model_name']}")
    print(f"  max epoch: {TRAINING_CONFIG['max_epochs']}")
    print(f"  patience: {TRAINING_CONFIG['patience']} epochs")
    print(f"  eval period: per {TRAINING_CONFIG['eval_period']} epochs")
    print(f"  save check point: per {TRAINING_CONFIG['checkpoint_period']} epochs")
    print(f"  multi scale training: {TRAINING_CONFIG['input_min_size_train'][0]}-{TRAINING_CONFIG['input_max_size_train']}")
    
    print(f"\n{'='*60}")
    print(f"training start")
    print(f"{'='*60}")
    
    try:
        trainer.train()
        
        print("\n training finished")
        
        config_path = os.path.join(cfg.OUTPUT_DIR, "config.yaml")
        with open(config_path, "w") as f:
            f.write(cfg.dump())
        print(f"config saved: {config_path}")
        
        return trainer, cfg
        
    except KeyboardInterrupt:
        print("\n training interrupted")
        return None, cfg
        
    except Exception as e:
        print(f"\n error occurs:")
        print(f"   error info: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, cfg

if 'class_labels' in locals() and class_labels is not None and conversion_success:
    print("\n" + "=" * 60)
    print("ready to start training")
    print("=" * 60)
    
    trainer, cfg = train_model()
    
    if trainer is not None:
        print("\n training finished")
        
        print(f"\n output path:")
        output_dir = Path(cfg.OUTPUT_DIR)
        if output_dir.exists():
            for file in output_dir.iterdir():
                if file.is_file():
                    print(f"  📄 {file.name}")
    else:
        print("\n training failed")
else:
    print("can't start training, please check data structure")

In [None]:
def evaluate_model(model_path="model_best.pth", dataset_name="tomato_test"):
    print(f"evaluating {dataset_name} ...")
    
    eval_cfg = build_cfg()
    
    full_model_path = os.path.join(eval_cfg.OUTPUT_DIR, model_path)
    if not os.path.exists(full_model_path):
        print(f"model not exist: {full_model_path}")

        final_model_path = os.path.join(eval_cfg.OUTPUT_DIR, "model_final.pth")
        if os.path.exists(final_model_path):
            full_model_path = final_model_path
            print(f"using final model: {final_model_path}")
        else:
            print("can't find any models")
            return None
    
    eval_cfg.MODEL.WEIGHTS = full_model_path
    print(f"load model: {full_model_path}")
    
    try:
        evaluator = COCOEvaluator(dataset_name, eval_cfg, False, output_dir=eval_cfg.OUTPUT_DIR)
        test_loader = build_detection_test_loader(eval_cfg, dataset_name)
        
        model = MyTrainer.build_model(eval_cfg)
        
        print("start evaluating")
        results = inference_on_dataset(model, test_loader, evaluator)
        
        print("\n evaluation result:")
        
        if "bbox" in results:
            print("\n bbox result:")
            bbox_results = results["bbox"]
            for key in ["AP", "AP50", "AP75", "APs", "APm", "APl"]:
                value = bbox_results.get(key, -1)
                if value != -1:
                    print(f"  {key}: {value:.4f}")
                else:
                    print(f"  {key}: N/A")
        
        if "segm" in results:
            print("\n segm result:")
            segm_results = results["segm"]
            for key in ["AP", "AP50", "AP75", "APs", "APm", "APl"]:
                value = segm_results.get(key, -1)
                if value != -1:
                    print(f"  {key}: {value:.4f}")
                else:
                    print(f"  {key}: N/A (no objects in this size)")
        
        results_file = os.path.join(eval_cfg.OUTPUT_DIR, f"eval_results_{dataset_name}.json")
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"\n evaluation results saved at: {results_file}")
        
        return results
        
    except Exception as e:
        print(f"failed to evaluate: {e}")
        import traceback
        traceback.print_exc()
        return None

if 'trainer' in locals() and trainer is not None:
    print("\n" + "=" * 60)
    print("Model evaluation started")
    print("=" * 60)
    
    test_results = evaluate_model("model_best.pth", "tomato_test")
    
    final_results = evaluate_model("model_final.pth", "tomato_test")

In [None]:
def visualize_predictions(dataset_name="tomato_test", num_samples=5, model_path="model_best.pth"):

    print(f"plotting {dataset_name} inference result")
    
    vis_cfg = build_cfg()
    full_model_path = os.path.join(vis_cfg.OUTPUT_DIR, model_path)
    
    if not os.path.exists(full_model_path):
        print(f"model not exist: {full_model_path}")

        final_model_path = os.path.join(vis_cfg.OUTPUT_DIR, "model_final.pth")
        if os.path.exists(final_model_path):
            full_model_path = final_model_path
            print(f"using final model: {final_model_path}")
        else:
            print("no model file exist")
            return
    
    vis_cfg.MODEL.WEIGHTS = full_model_path
    
    predictor = DefaultPredictor(vis_cfg)
    
    try:
        metadata = MetadataCatalog.get(dataset_name)
    except:
        print(f" can't get {dataset_name} metadata")
        metadata = None
    
    img_dir = DATASET_CONFIG['img_dir']
    if not os.path.exists(img_dir):
        print(f"image folder not eixst: {img_dir}")
        return
    
    img_list = [f for f in os.listdir(img_dir) 
                if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]
    
    if not img_list:
        print(f"can't find images in {img_dir}")
        return
    
    random.shuffle(img_list)
    shown = 0
    
    print(f"using {model_path} generating {num_samples} samples...")
    
    for file in img_list:
        try:
            img_path = os.path.join(img_dir, file)
            im = cv2.imread(img_path)
            
            if im is None:
                print(f"failed to load image: {img_path}")
                continue
            
            outputs = predictor(im)
            
            v = Visualizer(im[:, :, ::-1], metadata=metadata, scale=1.2)
            v._default_font_size = 20
            out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
            
            save_path = os.path.join(vis_cfg.OUTPUT_DIR, f"prediction_{shown+1}_{file}")
            cv2.imwrite(save_path, out.get_image()[:, :, ::-1])
            print(f"  saved at: {save_path}")
            
            shown += 1
            if shown >= num_samples:
                break
                
        except Exception as e:
            print(f"error when processing {file} : {e}")
            continue
    
    print("Plotting finished!")

def plot_training_history():
    history_file = os.path.join(cfg.OUTPUT_DIR, "training_history.json")
    
    if not os.path.exists(history_file):
        print(f"can't find training log: {history_file}")
        return
    
    try:
        with open(history_file, 'r') as f:
            history = json.load(f)
        
        if not history:
            print("training log is empty")
            return
        
        epochs = [h['epoch'] for h in history]
        scores = [h['score'] for h in history]
        
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, scores, 'b-o', linewidth=2, markersize=6)
        plt.title(f'training log - {history[0]["metric"]}', fontsize=14)
        plt.xlabel('Epoch', fontsize=8)
        plt.ylabel(f'{history[0]["metric"]}', fontsize=12)
        plt.grid(True, alpha=0.3)
        
        best_idx = scores.index(max(scores))
        plt.annotate(f'best: {max(scores):.4f}', 
                    xy=(epochs[best_idx], scores[best_idx]),
                    xytext=(10, 10), textcoords='offset points',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7),
                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
        
        plt.tight_layout()
        
        curve_path = os.path.join(cfg.OUTPUT_DIR, "training_curve.png")
        plt.savefig(curve_path, dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"training curve saved at: {curve_path}")
        
    except Exception as e:
        print(f"failed to plot training log: {e}")

if 'trainer' in locals() and trainer is not None:
    print("\n" + "=" * 60)
    print("plotting start")
    print("=" * 60)
    
    visualize_predictions("tomato_test", num_samples=3, model_path="model_best.pth")
    
    plot_training_history()
    
    print("plotting finished")

print("Output saved.")