# Fine-tuning with Ultralytics YOLO

This notebook demonstrates fine-tuning Ultralytics YOLO models (YOLO11 or YOLOv8) on a custom dataset with single-GPU training.

**Key Features:**
- Simple API with minimal boilerplate
- Fast training and inference
- COCO annotations converted to YOLO label format
- Production-ready MLflow integration
- Deployment-ready model wrapper with base64 input support

**References:**
- https://docs.ultralytics.com/modes/train/
- https://docs.ultralytics.com/datasets/detect/

**Prerequisites**
- Databricks Runtime 17.3 LTS ML (for numpy 2.x compatibility)
- Single GPU cluster
- Cluster started with `scripts/init_script_ultralytics.sh` init script
- COCO format dataset created from video processing notebooks:
  - `1_Processing_w_ffmpeg.ipynb` - standardize input videos
  - `2_Batch_Inference_w_opencv.ipynb` - create COCO annotations 

In [None]:
%sh
# Install these via the cluster init script so they persist across cluster restarts
# /databricks/python/bin/pip install -U ultralytics supervision

In [None]:
%pip install -U mlflow psutil nvidia-ml-py
%restart_python


# Setup and Configure

Initialize the variables needed for training. The dataset is read from a Unity Catalog Volume; training outputs are written to local disk.


In [None]:
import mlflow
import os
from pathlib import Path

ds_catalog = 'brian_ml_dev'
ds_schema = 'image_processing'
coco_volume = 'coco_dataset'
training_volume = 'training'

mlflow_experiment = '/Users/brian.law@databricks.com/brian_yolo_training'

volume_path = f"/Volumes/{ds_catalog}/{ds_schema}/{coco_volume}"
training_volume_path = "/local_disk0/ultralytics_logging_folder"  # local disk, not a UC Volume
image_path = f'{volume_path}/images'
annotation_json = f'{volume_path}/annotations.json'

# YOLO model to start from (yolo11n.pt, yolo11s.pt, yolo11m.pt, yolo11l.pt, yolo11x.pt)
# or use yolov8n.pt, yolov8s.pt, etc.
YOLO_MODEL = 'yolo11n.pt'  # Start with nano for faster training

print(f"Dataset location: {volume_path}")
print(f"Training outputs: {training_volume_path}")


In [None]:
# Training hyperparameters
EPOCHS = 2
BATCH_SIZE = 128
IMG_SIZE = 640
initial_lr = 0.005  # passed to model.train as lr0
final_lr = 0.1      # passed as lrf: a fraction of lr0, so the final LR is lr0 * lrf
run_name = 'single_gpu_run'

## Prepare YOLO Dataset Configuration

YOLO expects a `.yaml` file that defines:
- Path to train/val images
- Number of classes
- Class names

We'll convert our COCO format to YOLO's expected structure.
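
COCO category IDs are not guaranteed to start at zero or to be contiguous, while YOLO label files use indices `0..nc-1`. The remapping can be sketched as follows. The `categories` list and the helper name `build_class_mapping` are illustrative assumptions, not values from this dataset:

```python
# Hypothetical COCO categories with non-contiguous IDs (illustrative only)
categories = [
    {"id": 3, "name": "car"},
    {"id": 1, "name": "person"},
    {"id": 90, "name": "toothbrush"},
]


def build_class_mapping(categories):
    """Return (class_names, COCO id -> 0-based YOLO index), ordered by COCO id."""
    ordered = sorted(categories, key=lambda cat: cat["id"])
    class_names = [cat["name"] for cat in ordered]
    id_map = {cat["id"]: idx for idx, cat in enumerate(ordered)}
    return class_names, id_map


class_names, id_map = build_class_mapping(categories)
print(class_names)
print(id_map)
```

Sorting by COCO id keeps the order of `names` in the YAML file consistent with the class indices written to the label files.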


In [None]:
import json
import yaml

# Read COCO annotations to get class information
with open(annotation_json, 'r') as f:
    coco_data = json.load(f)

# Extract categories and create mapping from COCO category_id to YOLO class_id (0-indexed)
categories = coco_data['categories']
sorted_categories = sorted(categories, key=lambda x: x['id'])
class_names = [cat['name'] for cat in sorted_categories]
num_classes = len(class_names)

# CRITICAL: Create mapping from COCO category_id to YOLO class index (0-based)
# COCO IDs might be non-contiguous (e.g., 1, 2, 3, ..., 90) but YOLO needs 0, 1, 2, ..., n-1
coco_id_to_yolo_id = {cat['id']: idx for idx, cat in enumerate(sorted_categories)}

print(f"Number of classes: {num_classes}")
print(f"Classes: {class_names[:10]}...")  # Show first 10
print(f"COCO ID to YOLO ID mapping sample: {dict(list(coco_id_to_yolo_id.items())[:5])}")

# Create YOLO dataset config
yolo_config = {
    'path': volume_path,  # Root directory
    'train': 'images',  # Train images relative to 'path'
    'val': 'images',    # Using same for now - split in production
    'nc': num_classes,  # Number of classes
    'names': class_names  # Class names
}

# Save config file
config_path = f"{volume_path}/data.yaml"
with open(config_path, 'w') as f:
    yaml.dump(yolo_config, f)

print(f"\nYOLO config saved to: {config_path}")


## Convert COCO to YOLO Format

YOLO expects annotations in a specific format:
- One `.txt` file per image
- Each line: `class_id center_x center_y width height` (normalized 0-1)

We'll convert the COCO annotations to YOLO format.
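
As a worked example of the line format, the sketch below converts a single COCO box to a YOLO label line. It is one possible approach, not the conversion cell itself: it clips the box at its corners before normalizing, so a box that spills over the image edge shrinks rather than having its center and size clamped independently. The helper names `coco_bbox_to_yolo` and `yolo_line` are hypothetical.

```python
def coco_bbox_to_yolo(bbox, img_width, img_height):
    """Convert a COCO [x, y, w, h] pixel box to normalized (cx, cy, w, h).

    The box is clipped to the image at its corners first. Returns None
    when no part of the box lies inside the image.
    """
    x, y, w, h = bbox
    x1, y1 = max(0.0, x), max(0.0, y)
    x2, y2 = min(float(img_width), x + w), min(float(img_height), y + h)
    if x2 <= x1 or y2 <= y1:
        return None
    cx = (x1 + x2) / 2 / img_width
    cy = (y1 + y2) / 2 / img_height
    return cx, cy, (x2 - x1) / img_width, (y2 - y1) / img_height


def yolo_line(class_id, box):
    """Format one label line: class index followed by four normalized values."""
    return f"{class_id} " + " ".join(f"{v:.6f}" for v in box)


# A box fully inside a 640x480 image
inside = coco_bbox_to_yolo([100, 50, 200, 100], 640, 480)
print(yolo_line(0, inside))

# A box that extends past the right and bottom edges gets clipped
overflow = coco_bbox_to_yolo([600, 400, 100, 100], 640, 480)
print(yolo_line(0, overflow))
```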


In [None]:
## Diagnose COCO Data Format (Run this to debug coordinate issues)

# Inspect a few annotations to understand the coordinate format
print("=== COCO Data Inspection ===\n")

# Check images
sample_images = coco_data['images'][:3]
print(f"Total images: {len(coco_data['images'])}")
print(f"\nSample images:")
for img in sample_images:
    print(f"  ID: {img['id']}, File: {img['file_name']}, Size: {img['width']}x{img['height']}")

# Check annotations
sample_annotations = coco_data['annotations'][:5]
print(f"\nTotal annotations: {len(coco_data['annotations'])}")
print(f"\nSample annotations:")
for ann in sample_annotations:
    img_info = next((img for img in coco_data['images'] if img['id'] == ann['image_id']), None)
    if img_info:
        bbox = ann['bbox']
        print(f"\n  Annotation ID: {ann['id']}")
        print(f"    Image: {img_info['file_name']} ({img_info['width']}x{img_info['height']})")
        print(f"    Category ID: {ann['category_id']}")
        print(f"    Bbox: {bbox}")
        print(f"    Bbox/Image ratio: x={bbox[0]/img_info['width']:.3f}, y={bbox[1]/img_info['height']:.3f}, w={bbox[2]/img_info['width']:.3f}, h={bbox[3]/img_info['height']:.3f}")
        
        # Check if coordinates seem normalized or in pixels
        if bbox[0] < 2 and bbox[1] < 2 and bbox[2] < 2 and bbox[3] < 2:
            print(f"    ⚠️  WARNING: Bbox values are < 2, might already be normalized!")
        if bbox[0] > img_info['width'] or bbox[1] > img_info['height']:
            print(f"    ⚠️  WARNING: Bbox x/y exceed image dimensions!")
        if (bbox[0] + bbox[2]) > img_info['width'] or (bbox[1] + bbox[3]) > img_info['height']:
            print(f"    ⚠️  WARNING: Bbox extends beyond image boundaries!")


In [None]:
from pathlib import Path

# Create labels directory
labels_dir = Path(volume_path) / 'labels'
labels_dir.mkdir(exist_ok=True)

print(f"Converting COCO annotations to YOLO format...")

# Group annotations by image_id
image_annotations = {}
for ann in coco_data['annotations']:
    image_id = ann['image_id']
    if image_id not in image_annotations:
        image_annotations[image_id] = []
    image_annotations[image_id].append(ann)

# Convert each image's annotations
images_dict = {img['id']: img for img in coco_data['images']}
converted_count = 0
skipped_annotations = 0
invalid_coords_count = 0

# Debug: Check first annotation to understand the coordinate format
if coco_data['annotations']:
    first_ann = coco_data['annotations'][0]
    first_img = images_dict[first_ann['image_id']]
    print(f"\nDebug - First annotation:")
    print(f"  Image size: {first_img['width']}x{first_img['height']}")
    print(f"  Bbox (COCO): {first_ann['bbox']}")
    print(f"  Bbox format should be: [x, y, width, height] in pixels")

for image_id, image_info in images_dict.items():
    img_width = image_info['width']
    img_height = image_info['height']
    
    # Get corresponding annotations
    annotations = image_annotations.get(image_id, [])
    
    # Create label file
    image_filename = Path(image_info['file_name']).stem
    label_file = labels_dir / f"{image_filename}.txt"
    
    with open(label_file, 'w') as f:
        for ann in annotations:
            # Get COCO category_id and map to YOLO class index
            coco_category_id = ann['category_id']
            
            # Skip if category_id is not in the mapping (should not happen with valid COCO data)
            if coco_category_id not in coco_id_to_yolo_id:
                print(f"Warning: Unknown category_id {coco_category_id} in image {image_filename}")
                skipped_annotations += 1
                continue
            
            # Map to YOLO class index (0-based)
            yolo_class_id = coco_id_to_yolo_id[coco_category_id]
            
            # COCO format: [x, y, width, height] (top-left corner)
            x, y, w, h = ann['bbox']
            
            # Validate bbox values are positive and reasonable
            if x < 0 or y < 0 or w <= 0 or h <= 0:
                if invalid_coords_count == 0:
                    print(f"Warning: Invalid bbox in {image_filename}: x={x}, y={y}, w={w}, h={h}")
                invalid_coords_count += 1
                continue
            
            # Convert to YOLO format: [center_x, center_y, width, height] (normalized)
            center_x = (x + w / 2) / img_width
            center_y = (y + h / 2) / img_height
            norm_w = w / img_width
            norm_h = h / img_height
            
            # Validate normalized coordinates are in valid range [0, 1]
            # Allow slight overflow due to floating point, but clip to [0, 1]
            if center_x > 1.05 or center_y > 1.05 or norm_w > 1.05 or norm_h > 1.05:
                if invalid_coords_count < 5:  # Print first few examples
                    print(f"\nWarning: Out of bounds coordinates in {image_filename}:")
                    print(f"  Image size: {img_width}x{img_height}")
                    print(f"  COCO bbox: x={x}, y={y}, w={w}, h={h}")
                    print(f"  YOLO (before clip): cx={center_x:.4f}, cy={center_y:.4f}, w={norm_w:.4f}, h={norm_h:.4f}")
                invalid_coords_count += 1
                # Skip this annotation if severely out of bounds
                if center_x > 1.5 or center_y > 1.5 or norm_w > 1.5 or norm_h > 1.5:
                    continue
            
            # Clip coordinates to valid range [0, 1]
            center_x = max(0.0, min(1.0, center_x))
            center_y = max(0.0, min(1.0, center_y))
            norm_w = max(0.0, min(1.0, norm_w))
            norm_h = max(0.0, min(1.0, norm_h))
            
            # Write in YOLO format (using mapped class_id)
            f.write(f"{yolo_class_id} {center_x:.6f} {center_y:.6f} {norm_w:.6f} {norm_h:.6f}\n")
    
    converted_count += 1
    if converted_count % 100 == 0:
        print(f"Converted {converted_count}/{len(images_dict)} images...")

print(f"\nConversion complete! {converted_count} label files created in {labels_dir}")
if skipped_annotations > 0:
    print(f"Warning: Skipped {skipped_annotations} annotations with unknown category IDs")
if invalid_coords_count > 0:
    print(f"Warning: Found {invalid_coords_count} annotations with invalid/out-of-bounds coordinates (clipped or skipped)")


## Training on Single GPU

Start with single GPU training to validate the setup.
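
The `lrf` argument is a multiplier on `lr0`, not an absolute learning rate. Assuming Ultralytics' default linear schedule (no `cos_lr`) and ignoring warmup, the per-epoch rate can be sketched as below. `linear_lr` is an illustrative helper, not an Ultralytics function:

```python
def linear_lr(epoch, epochs, lr0, lrf):
    """Learning rate at `epoch` under a linear decay from lr0 to lr0 * lrf."""
    factor = max(1 - epoch / epochs, 0) * (1.0 - lrf) + lrf
    return lr0 * factor


# Values matching the hyperparameters used in this notebook
lr0, lrf, epochs = 0.005, 0.1, 2
schedule = [linear_lr(e, epochs, lr0, lrf) for e in range(epochs + 1)]
print(schedule)
```

With these settings the rate ends at `lr0 * lrf`, one tenth of the starting value. Ultralytics' default `optimizer='auto'` setting may choose its own initial rate and ignore `lr0`, so it is worth checking the training log for the values actually used.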


In [None]:
from ultralytics import YOLO
from ultralytics import settings
import torch
import torch.distributed as dist

settings.update({"mlflow": True})

# Setting MLflow configs for ultralytics
os.environ['MLFLOW_EXPERIMENT_NAME'] = mlflow_experiment
os.environ['MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING'] = "true"

# keep run active to log the best model into mlflow
os.environ['MLFLOW_KEEP_RUN_ACTIVE'] = "true"

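# Set up the torch.distributed process group that Ultralytics requires.
# The env:// rendezvous needs these variables; the defaults below assume a
# single local process and are only applied when the variables are unset.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")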

In [None]:
# Check GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load pretrained model
model = YOLO(YOLO_MODEL)

print(f"\nLoaded {YOLO_MODEL}")
print(f"Model summary:")
model.info()

In [None]:
# Train the model
results = model.train(
        data=config_path,
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        lr0=initial_lr,       # initial learning rate
        lrf=final_lr,         # final LR factor (relative to lr0)
        imgsz=IMG_SIZE,
        name=run_name,
        project=training_volume_path,
        device=0,  # Use GPU 0
        workers=8,
        patience=50,
        save=True,
        save_period=5,  # Save checkpoint every 5 epochs
        verbose=True
    )

# MLflow Integration

Ultralytics YOLO includes built-in MLflow callback integration for logging training metrics and checkpoints. We extend this with:

- **Dataset Tracking**: Log COCO dataset metadata for full lineage
- **Deployment-Ready Models**: PyFunc wrapper with base64 input support for Model Serving
- **Metrics Linkage**: Connect validation metrics to both model and dataset (MLflow 3.x)

## Log Dataset to Active MLflow Run

Use `mlflow.data` module to log the dataset as a trackable input.


In [None]:
import mlflow
import pandas as pd
from pathlib import Path

# Get the active run (should be the training run)
active_run = mlflow.active_run()

if active_run:
    print(f"Active Run ID: {active_run.info.run_id}")
    print(f"Logging dataset using mlflow.data module...\n")
    
    # Prepare dataset metadata DataFrame
    # Include image info and link to annotations
    image_data = []
    
    # Create mapping of image_id to annotations
    image_to_anns = {}
    for ann in coco_data['annotations']:
        img_id = ann['image_id']
        if img_id not in image_to_anns:
            image_to_anns[img_id] = []
        image_to_anns[img_id].append(ann)
    
    # Build comprehensive dataset representation
    for img in coco_data['images']:
        img_id = img['id']
        anns = image_to_anns.get(img_id, [])
        
        # Get class distribution for this image
        classes_in_image = [ann['category_id'] for ann in anns]
        
        image_data.append({
            'image_id': img_id,
            'file_name': img['file_name'],
            'width': img['width'],
            'height': img['height'],
            'num_annotations': len(anns),
            'classes': ','.join(map(str, classes_in_image)) if classes_in_image else ''
        })
    
    dataset_df = pd.DataFrame(image_data)
    
    # Create MLflow Dataset with full metadata
    dataset = mlflow.data.from_pandas(
        dataset_df,
        source=volume_path,
        name=f"{ds_catalog}.{ds_schema}.{coco_volume}",
        targets="num_annotations",  # Proxy target column; the detector itself predicts boxes and classes
    )
    
    # Log the dataset as an input to the training run
    mlflow.log_input(dataset, context="training")
    
    print(f"✓ Logged dataset with {len(dataset_df)} images")
    print(f"  - Source: {volume_path}")
    print(f"  - Name: {ds_catalog}.{ds_schema}.{coco_volume}")
    print(f"  - Total annotations: {dataset_df['num_annotations'].sum()}")
    print(f"  - Avg annotations/image: {dataset_df['num_annotations'].mean():.2f}")
    print(f"\n{'='*60}")
    print(f"Dataset logged to active run!")
    print(f"View in MLflow UI at: {mlflow_experiment}")
    print(f"{'='*60}")
    
else:
    print("⚠ No active MLflow run found!")
    print("The run may have already been ended by Ultralytics.")
    print("\nTo log the dataset manually, reopen the run with mlflow.start_run(run_id=...) and rerun this cell.")


## Log YOLO Model for Deployment

Create an MLflow PyFunc wrapper that makes the model deployment-ready:

- **Handles base64 input**: REST API-friendly image encoding
- **Complete preprocessing**: Image decoding, resizing, and normalization
- **Post-processing included**: NMS and confidence filtering applied internally
- **Structured output**: Returns detection boxes, scores, and class IDs
- **Artifact-based loading**: Model weights loaded from MLflow artifacts
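
To make the post-processing step concrete, here is a minimal pure-Python sketch of confidence filtering followed by greedy, class-agnostic NMS. It illustrates the idea only; the wrapper itself relies on `torchvision.ops.nms`. The helper names and the sample boxes are made up for illustration.

```python
def iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, conf_thres=0.25, iou_thres=0.5):
    """Return indices of kept boxes, highest score first."""
    # Drop low-confidence candidates, then visit the rest by descending score
    candidates = [i for i, s in enumerate(scores) if s >= conf_thres]
    candidates.sort(key=lambda i: scores[i], reverse=True)
    keep = []
    for i in candidates:
        # Keep a box only if it does not overlap a kept box too strongly
        if all(iou(boxes[i], boxes[j]) <= iou_thres for j in keep):
            keep.append(i)
    return keep


# Illustrative boxes: two near-duplicates, one separate box, one weak box
boxes = [[10, 10, 110, 110], [12, 12, 112, 112], [200, 200, 260, 260], [0, 0, 5, 5]]
scores = [0.90, 0.80, 0.70, 0.10]
kept = greedy_nms(boxes, scores)
print(kept)
```

The second box is suppressed because it overlaps the higher-scoring first box, and the last box is removed by the confidence threshold before NMS runs.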

In [None]:
# PyFunc Model Wrapper for Deployment
# This wrapper handles base64 image input and includes the model weights as artifacts
import torch
from torchvision.ops import nms
import mlflow.pyfunc
import base64
from io import BytesIO
from PIL import Image
import numpy as np

class YOLOPyFuncWrapper(mlflow.pyfunc.PythonModel):
    """
    MLflow PyFunc wrapper for YOLO model with base64 string input support.
    Handles decoding, preprocessing, inference, and post-processing.
    Uses artifacts to load the model file.
    """
    
    def __init__(self, img_size=640, conf_thres=0.25, iou_thres=0.5, max_det=300):
        """Initialize with detection parameters"""
        self.img_size = img_size
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.max_det = max_det
        self.model = None
    
    def load_context(self, context):
        """Load the YOLO model during MLflow model loading"""
        from ultralytics import YOLO
        import torch
        
        # Load the YOLO model from artifacts
        model_path = context.artifacts["model"]
        yolo = YOLO(model_path)
        self.model = yolo.model.eval()
        
        print(f"YOLO model loaded from artifacts: {model_path}")
    
    @staticmethod
    def xywh_to_xyxy(b):
        """Convert box format from xywh to xyxy"""
        x, y, w, h = b.unbind(-1)
        return torch.stack([x-w/2, y-h/2, x+w/2, y+h/2], dim=-1)
    
    def preprocess_base64(self, base64_str):
        """Decode base64 image and convert to tensor format"""
        # Handle different string types (str, np.str_, bytes)
        if isinstance(base64_str, bytes):
            base64_str = base64_str.decode('utf-8')
        else:
            base64_str = str(base64_str)
        
        # Decode base64 to bytes
        img_bytes = base64.b64decode(base64_str)
        
        # Load as PIL Image
        img = Image.open(BytesIO(img_bytes)).convert('RGB')
        
        # Resize to expected size
        img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
        
        # Convert to numpy array and normalize
        img_array = np.array(img, dtype=np.float32) / 255.0
        
        # Convert HWC to CHW format
        img_array = np.transpose(img_array, (2, 0, 1))
        
        # Convert to torch tensor with batch dimension
        img_tensor = torch.from_numpy(img_array).unsqueeze(0)
        
        return img_tensor
    
    def predict(self, context, model_input):
        """
        Handle prediction requests from Model Serving endpoint.
        
        Args:
            context: MLflow context (unused in predict, used in load_context)
            model_input: Dict, DataFrame, or direct input with 'images' containing base64 strings
            
        Returns:
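            numpy array of detections with shape (batch_size, max_det, 6),
            padded with -1 where fewer than max_det detections survive.
            Each detection: [x1, y1, x2, y2, confidence, class_id], with box
            coordinates in pixels of the resized img_size x img_size input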
        """
        # Extract images from input
        if isinstance(model_input, dict):
            images = model_input.get('images')
        elif hasattr(model_input, 'to_dict'):  # DataFrame
            images = model_input['images'].tolist() if 'images' in model_input.columns else None
            if images and len(images) == 1:
                images = images[0]
        elif hasattr(model_input, '__getitem__'):  # List or array-like
            images = model_input
        else:
            images = str(model_input)
        
        if images is None:
            raise ValueError("Input must contain 'images' key with base64 encoded image string")
        
        # Handle different input formats for base64 strings
        with torch.no_grad():
            # Single string (most common case for serving endpoint)
            if isinstance(images, (str, bytes, np.str_)):
                x = self.preprocess_base64(images)
            
            # 0-dimensional numpy array containing a string
            elif isinstance(images, np.ndarray) and images.ndim == 0:
                x = self.preprocess_base64(images.item())
            
            # List/tuple of strings (batch)
            elif isinstance(images, (list, tuple)):
                processed = [self.preprocess_base64(img) for img in images]
                x = torch.cat(processed, dim=0)
            
            # Numpy array of strings
            elif isinstance(images, np.ndarray) and images.dtype.kind in ('U', 'S', 'O'):
                processed = [self.preprocess_base64(str(img)) for img in images.flat]
                x = torch.cat(processed, dim=0)
            
            else:
                raise ValueError(
                    f"Unsupported input type: {type(images)}. "
                    f"Expected base64 encoded string or list of strings. "
                    f"Input dtype: {getattr(images, 'dtype', 'N/A')}"
                )
            
            # Move to model device
            device = next(self.model.parameters()).device
            x = x.to(device)
            
            # Run inference
            out = self.model(x)
            
            # Handle different output formats
            if isinstance(out, (list, tuple)):
                out = out[0]
            
            # Ensure correct dimension ordering
            if out.dim() == 3 and out.shape[1] < out.shape[2]:
                out = out.permute(0, 2, 1).contiguous()
            
            # Parse detections: [batch, anchors, 4 + num_classes]
            # YOLOv8/YOLO11 detection heads emit no objectness channel, and the
            # class scores are already sigmoid-activated in eval mode
            boxes_xywh = out[..., :4]
            cls = out[..., 4:]
            
            # Compute confidence scores and class IDs
            conf, cls_id = cls.max(-1)
            
            # Convert boxes to xyxy format
            boxes = self.xywh_to_xyxy(boxes_xywh)
            
            # Apply NMS and filtering per image
            N = boxes.shape[0]
            out_pad = x.new_full((N, self.max_det, 6), -1.0)
            
            for i in range(N):
                # Filter by confidence threshold
                mask = conf[i] >= self.conf_thres
                if mask.sum() == 0:
                    continue
                
                b = boxes[i][mask]
                s = conf[i][mask]
                c = cls_id[i][mask].float()
                
                # Apply NMS
                keep = nms(b, s, self.iou_thres)[:self.max_det]
                
                # Store results
                k = keep.numel()
                if k > 0:
                    out_pad[i, :k, :4] = b[keep]
                    out_pad[i, :k, 4] = s[keep]
                    out_pad[i, :k, 5] = c[keep]
            
            # Convert to numpy
            result = out_pad.cpu().numpy()
            
            return result

In [None]:
# Prepare best model path and artifacts for pyfunc wrapper
import tempfile
import shutil

best_model_path = f"{training_volume_path}/{run_name}/weights/best.pt"

# Create temporary directory for artifacts
artifacts_dir = tempfile.mkdtemp()

# Copy the best.pt file to artifacts directory
artifact_model_path = os.path.join(artifacts_dir, "best.pt")
shutil.copy2(best_model_path, artifact_model_path)

# Create artifacts dict
artifacts = {
    "model": artifact_model_path
}

# Create pyfunc wrapper instance
pyfunc_wrapper = YOLOPyFuncWrapper(
    img_size=IMG_SIZE,
    conf_thres=0.25,
    iou_thres=0.5,
    max_det=300
)

print(f"✓ Prepared model artifacts from: {best_model_path}")

In [None]:
from mlflow.models.signature import ModelSignature
from mlflow.types import Schema, TensorSpec, ColSpec
import numpy as np

# Create signature with base64 string input (deployment-ready)
signature = ModelSignature(
    inputs=Schema([ColSpec("string", "images")]),
    outputs=Schema([TensorSpec(np.dtype(np.float32), (-1, -1, 6), "detections")])
)

# Create input example as base64 encoded image
dummy_img = np.random.randint(0, 256, (IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)  # upper bound is exclusive
pil_img = Image.fromarray(dummy_img)
buffer = BytesIO()
pil_img.save(buffer, format='PNG')
img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
input_example = {"images": img_base64}

# Log model as pyfunc with artifacts
try:
    model_info = mlflow.pyfunc.log_model(
        python_model=pyfunc_wrapper,
        artifact_path="best_model",
        artifacts=artifacts,
        signature=signature,
        input_example=input_example,
        pip_requirements=[
            "torch",
            "torchvision", 
            "ultralytics",
            "pillow",
            "numpy"
        ],
        # registered_model_name="brian_ml_dev.image_processing.yolo_best_model",  # optional: register to UC
    )
    
    model_id = model_info.model_id
    print(f"✓ Model logged successfully!")
    print(f"  Model ID: {model_id}")
    print(f"  Artifact path: best_model")
    print(f"  Input format: base64 encoded image strings")
    print(f"  Output format: [batch, detections, 6] where each detection is [x1, y1, x2, y2, conf, class_id]")
    
finally:
    # Clean up temporary directory
    shutil.rmtree(artifacts_dir, ignore_errors=True)

## Log Final Metrics with Model and Dataset Linkage

Using MLflow 3.x features, we can link metrics to both the logged model and dataset for complete lineage tracking.


In [None]:
# Log final validation metrics linked to both LoggedModel and Dataset
try:
    # Get the dataset entity we logged earlier
    active_run = mlflow.active_run()
    
    if active_run and 'dataset' in locals():
        print(f"Logging final metrics linked to model and dataset...")
        
        # Extract final validation metrics from the `results` object returned by
        # model.train() in the training cell. The 0.0 values are placeholders
        # that remain only if `results` is unavailable.
        val_mAP50 = 0.0
        val_mAP50_95 = 0.0
        
        # Try to extract from results if available
        if 'results' in locals() and hasattr(results, 'results_dict'):
            val_mAP50 = float(results.results_dict.get('metrics/mAP50(B)', 0.0))
            val_mAP50_95 = float(results.results_dict.get('metrics/mAP50-95(B)', 0.0))
        
        # Log metrics linked to both model_id and dataset
        mlflow.log_metric(
            key="final_val_mAP50",
            value=val_mAP50,
            step=EPOCHS,
            model_id=model_id,  # Links to LoggedModel
            dataset=dataset  # Links to dataset
        )
        
        mlflow.log_metric(
            key="final_val_mAP50-95",
            value=val_mAP50_95,
            step=EPOCHS,
            model_id=model_id,  # Links to LoggedModel
            dataset=dataset  # Links to dataset
        )
        
        print(f"✓ Final metrics logged and linked!")
        print(f"  - mAP50: {val_mAP50:.4f}")
        print(f"  - mAP50-95: {val_mAP50_95:.4f}")
        print(f"  - Linked to model_id: {model_id}")
        print(f"  - Linked to dataset: {ds_catalog}.{ds_schema}.{coco_volume}")
        print(f"\n{'='*60}")
        print("Complete lineage established:")
        print("  Dataset → Training → Model → Metrics")
        print(f"{'='*60}")
    else:
        print("⚠ Dataset not available for linking. Make sure the 'Log Dataset to Active MLflow Run' cell was executed.")
        
except Exception as e:
    print(f"Warning: Could not log final metrics with linkage - {e}")
    print("This is expected if using MLflow < 3.0 or if dataset was not logged")


In [None]:
# Close out active mlflow run
mlflow.end_run()
print("✓ MLflow run ended successfully")

# Next Steps

## What You've Created

✅ **Complete Lineage Tracking:**
- Dataset logged with COCO metadata and image statistics
- Training hyperparameters captured automatically by Ultralytics
- Model logged with deployment-ready PyFunc wrapper
- Final metrics linked to both model and dataset (MLflow 3.x)

✅ **Deployment-Ready Model:**
- Accepts base64 encoded image strings (REST API friendly)
- Handles preprocessing and post-processing internally
- Returns structured detection output: `[batch, detections, 6]` where each detection is `[x1, y1, x2, y2, confidence, class_id]`
- Includes all dependencies in pip_requirements
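
A client for this model can be sketched with the standard library alone. The example below builds a JSON request body carrying a base64 `images` field and strips the `-1` padding rows from one image's detections. The `dataframe_records` layout is one of the request formats accepted by MLflow scoring servers; the helper names, the stand-in image bytes, and the sample response are assumptions for illustration.

```python
import base64
import json


def build_payload(image_bytes):
    """Wrap raw image bytes in a JSON body with a base64 'images' field."""
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return json.dumps({"dataframe_records": [{"images": encoded}]})


def valid_detections(padded_rows):
    """Drop padding rows (confidence below zero) from one image's detections."""
    return [row for row in padded_rows if row[4] >= 0]


# Stand-in bytes; a real client would read a JPEG or PNG file here
original = b"\x89PNG fake image bytes"
payload = build_payload(original)
decoded = base64.b64decode(json.loads(payload)["dataframe_records"][0]["images"])

# Hypothetical response for one image with max_det=3
response_rows = [
    [10.0, 20.0, 110.0, 220.0, 0.91, 0.0],
    [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
    [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
]
detections = valid_detections(response_rows)
print(len(detections))
```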

## Deployment Options

### 1. Model Serving Endpoint
Deploy the logged model to a Databricks Model Serving endpoint:
```python
# Reuse the URI returned by mlflow.pyfunc.log_model in the logging cell
model_uri = model_info.model_uri

# Create serving endpoint via Databricks UI or API
```

### 2. Register to Unity Catalog
For versioned model management, uncomment `registered_model_name` in the `mlflow.pyfunc.log_model` call:
```python
registered_model_name="<catalog>.<schema>.<model_name>"
```

### 3. Batch Inference
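Load the model and score a batch of base64-encoded images (for distributed scoring on Spark, the loaded model would need to be wrapped in a UDF):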
```python
model = mlflow.pyfunc.load_model(model_uri)
predictions = model.predict({"images": base64_images})
```

## Training at Scale

For larger datasets or faster training, use the multi-GPU notebook which leverages TorchDistributor with the `scripts/train_yolo.py` training script across multiple GPUs.