YOLO Model Training Pipeline with MLflow Integration

This notebook implements the core training phase for the YOLOv8n object detection model. It executes a two-phase workflow: initial training with transfer learning from COCO pre-trained weights, followed by computation of validation metrics on the best checkpoint. The best model is then registered in the MLflow Model Registry. All training artifacts, hyperparameters, and metrics are tracked using MLflow for reproducibility and model management.

Training Workflow:
1. Load YOLOv8n architecture with COCO pre-trained weights
2. Configure dataset and training parameters (50 epochs, batch size 16)
3. Initialize MLflow experiment tracking for reproducibility
4. Execute model training with early stopping and learning rate scheduling
5. Validate trained model on validation dataset
6. Compute detection metrics (mAP50, mAP50-95, precision, recall)
7. Register best model to MLflow Model Registry with Production stage

In [None]:
import os
from pathlib import Path

import mlflow
import mlflow.pytorch
import numpy as np
import torch
import yaml
from ultralytics import YOLO

# Seed the NumPy and PyTorch RNGs so repeated runs are comparable.
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# The default store is the absolute path /mlruns (filesystem root). Honour the
# standard MLFLOW_TRACKING_URI variable so the notebook can be pointed at a
# writable location or a tracking server without editing the code.
mlflow.set_tracking_uri(os.environ.get('MLFLOW_TRACKING_URI', 'file:///mlruns'))
mlflow.set_experiment('yolo_3class_detection')

# Dataset and output locations, relative to the notebook's working directory.
DATA_DIR = Path('../data')
MODELS_DIR = Path('../models')
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Banner summarising the run configuration.
print("YOLO TRAINING PIPELINE")
print("=" * 50)
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
print("Classes: 3 (person, car, dog)")
print("=" * 50)

In [None]:
# Single source of truth for the tunable training settings: the same values
# are logged to MLflow and passed to model.train(), so they cannot drift apart.
TRAIN_ARGS = {
    'epochs': 50,
    'imgsz': 416,
    'batch': 16,
    'patience': 10,
    'seed': 42,
}

model = YOLO('yolov8n.pt')

# Location of the best checkpoint, derived from the project/name passed to train().
best_model_path = MODELS_DIR / 'yolo_run' / 'weights' / 'best.pt'

# The context manager closes the run even if training raises, so a failed
# attempt does not leave an active run that blocks the next start_run().
with mlflow.start_run(run_name='yolo_training'):
    mlflow.log_params({
        'model': 'yolov8n',
        'epochs': TRAIN_ARGS['epochs'],
        'batch_size': TRAIN_ARGS['batch'],
        'device': str(torch.device('cuda' if torch.cuda.is_available() else 'cpu')),
        'seed': TRAIN_ARGS['seed'],
        'num_classes': 3,
        'classes': 'person, car, dog',
        'imgsz': TRAIN_ARGS['imgsz'],
        'patience': TRAIN_ARGS['patience'],
    })

    results = model.train(
        data=str(DATA_DIR / 'data.yaml'),
        device=0 if torch.cuda.is_available() else 'cpu',
        save=True,
        exist_ok=True,
        project=str(MODELS_DIR),
        name='yolo_run',
        verbose=True,
        **TRAIN_ARGS,
    )

    # training_status is 1 when the best checkpoint exists on disk, else 0.
    if best_model_path.exists():
        mlflow.log_artifact(str(best_model_path), artifact_path='models')
        mlflow.log_metric('training_status', 1)
        print(f"Best model saved: {best_model_path}")
    else:
        print("Warning: Best model path not found")
        mlflow.log_metric('training_status', 0)

print("Training completed with MLflow tracking")
print(f"Model location: {best_model_path}")

In [None]:
# Reload the best checkpoint written by the training cell and evaluate it.
best_model = YOLO(str(best_model_path))

results = best_model.val(
    data=str(DATA_DIR / 'data.yaml'),
    imgsz=416,
    batch=16,
    device=0 if torch.cuda.is_available() else 'cpu',
    verbose=False
)

box_metrics = getattr(results, 'box', None)

# The context manager closes the run even if metric extraction raises.
with mlflow.start_run(run_name='yolo_validation'):
    mlflow.log_params({
        'model': 'yolov8n_best',
        'validation_split': 'val'
    })

    if box_metrics is not None:
        # mp / mr are the mean precision / recall over classes; unlike
        # p.mean() / r.mean() they are defined when nothing was detected.
        metrics = {
            'mAP50': float(box_metrics.map50),
            'mAP50_95': float(box_metrics.map),
            'precision': float(box_metrics.mp),
            'recall': float(box_metrics.mr),
        }
        mlflow.log_metrics(metrics)

        print("Validation Metrics:")
        for metric_name, metric_value in metrics.items():
            print(f"  {metric_name}: {metric_value:.4f}")
    else:
        # Make the gap visible instead of finishing with an empty run.
        print("Warning: validation returned no box metrics; nothing logged")

print("Validation completed")

In [None]:
client = mlflow.tracking.MlflowClient()

# The context manager closes the run even if registration or promotion raises.
with mlflow.start_run(run_name='yolo_model_registration') as registration_run:
    mlflow.pytorch.log_model(
        pytorch_model=best_model.model,
        artifact_path='yolo_model',
        registered_model_name='yolo_3class_detector'
    )

    model_uri = mlflow.get_artifact_uri('yolo_model')
    print(f"Model registered at: {model_uri}")

    # get_latest_versions() returns one entry per stage, so its first element
    # is not necessarily the version registered above. Pick the version that
    # belongs to this run, falling back to the highest version number.
    all_versions = client.search_model_versions("name='yolo_3class_detector'")
    run_versions = [
        v for v in all_versions if v.run_id == registration_run.info.run_id
    ]
    candidates = run_versions or all_versions
    if not candidates:
        raise RuntimeError(
            "No registered versions found for model 'yolo_3class_detector'"
        )
    model_version = max(candidates, key=lambda v: int(v.version))

    # NOTE(review): model stages are deprecated in recent MLflow releases in
    # favour of aliases; kept here because the workflow targets 'Production'.
    client.transition_model_version_stage(
        name='yolo_3class_detector',
        version=model_version.version,
        stage='Production'
    )

    print(f"Model transitioned to Production - Version: {model_version.version}")

print("Model Registration completed")