# Fine-Tuning Ultralytics YOLO on Multiple GPUs

This notebook fine-tunes Ultralytics YOLO detection models (YOLO11 or YOLOv8) on a custom dataset. Training is distributed across the GPUs of a single Databricks node.

**References:**
- https://docs.ultralytics.com/modes/train/
- https://docs.ultralytics.com/datasets/detect/

**Prerequisites**
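- Databricks Runtime ML (MLR) 17.3 LTS, needed for NumPy 2.x compatibility
- A cluster whose driver node has multiple GPUs (training runs on the driver with `local_mode=True`)
- Cluster started with the `scripts/init_script_ultralytics.sh` init script
- **YOLO dataset already prepared** with:
  - `data.yaml` config file in your dataset volume
  - Images in an `images/` directory
  - Labels in a `labels/` directory, one `.txt` file per image (YOLO format: one `class_id center_x center_y width height` line per object, coordinates normalized to [0, 1])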
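Before training, it can help to spot-check a few label files. Below is a minimal sketch that parses one detection label line and converts it to pixel corners. `YoloBox`, `parse_yolo_label_line` and `to_pixel_xyxy` are hypothetical helpers written for illustration; they are not Ultralytics APIs. The sketch assumes plain detection labels with exactly five fields per line. Segmentation or pose labels carry more fields.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class YoloBox:
    """One object from a detection label line; coordinates are fractions of image size."""

    class_id: int
    cx: float
    cy: float
    w: float
    h: float


def parse_yolo_label_line(line: str, num_classes: int) -> YoloBox:
    """Parse 'class_id center_x center_y width height' and check its ranges."""
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"expected 5 fields, got {len(parts)}: {line!r}")
    class_id = int(parts[0])
    cx, cy, w, h = (float(p) for p in parts[1:])
    if not 0 <= class_id < num_classes:
        raise ValueError(f"class_id {class_id} is outside [0, {num_classes})")
    for name, value in (("center_x", cx), ("center_y", cy), ("width", w), ("height", h)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name}={value} is not normalized to [0, 1]")
    return YoloBox(class_id, cx, cy, w, h)


def to_pixel_xyxy(box: YoloBox, img_w: int, img_h: int) -> Tuple[float, float, float, float]:
    """Convert a normalized center/size box to pixel corners (x1, y1, x2, y2)."""
    return (
        (box.cx - box.w / 2) * img_w,
        (box.cy - box.h / 2) * img_h,
        (box.cx + box.w / 2) * img_w,
        (box.cy + box.h / 2) * img_h,
    )
```

Take the line `"0 0.5 0.5 0.25 0.5"` in a 640×480 image. It describes a box centered in the frame, and `to_pixel_xyxy` returns `(240.0, 120.0, 400.0, 360.0)`.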


In [None]:
%pip install -U mlflow psutil nvidia-ml-py
%restart_python


# Setup and Configuration

Initialize all variables needed for training.


In [None]:
import mlflow
import os
from pathlib import Path

# Dataset configuration
ds_catalog = 'brian_ml_dev'
ds_schema = 'image_processing'
dataset_volume = 'coco_dataset'

# Paths
dataset_path = f"/Volumes/{ds_catalog}/{ds_schema}/{dataset_volume}"
config_path = f"{dataset_path}/data.yaml"  # YOLO config file (must exist)
training_volume_path = "/local_disk0/ultralytics_logging_folder"  # Ultralytics run outputs on the driver's local disk (deleted when the cluster terminates)

# MLflow
mlflow_experiment = '/Users/brian.law@databricks.com/brian_yolo_training'

# MLflow connectivity
browser_host = spark.conf.get("spark.databricks.workspaceUrl")
db_host = f"https://{browser_host}"
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# YOLO model to start from (yolo11n.pt, yolo11s.pt, yolo11m.pt, yolo11l.pt, yolo11x.pt)
YOLO_MODEL = 'yolo11n.pt'

print(f"Dataset location: {dataset_path}")
print(f"Config file: {config_path}")
print(f"Training outputs: {training_volume_path}")

In [None]:
import ast

# Create Databricks widgets for hyperparameters
dbutils.widgets.text("epochs", "2", "Epochs")
dbutils.widgets.text("batch_size", "128", "Batch Size")
dbutils.widgets.text("img_size", "640", "Image Size")
dbutils.widgets.text("initial_lr", "0.005", "Initial Learning Rate")
dbutils.widgets.text("final_lr", "0.1", "Final LR Factor")
dbutils.widgets.text("device_config", "[0,1]", "Device Config (e.g., [0,1])")
dbutils.widgets.text("run_name", "multi_gpu_run", "Run Name")

# Get hyperparameters from widgets
EPOCHS = int(dbutils.widgets.get("epochs"))
BATCH_SIZE = int(dbutils.widgets.get("batch_size"))
IMG_SIZE = int(dbutils.widgets.get("img_size"))
initial_lr = float(dbutils.widgets.get("initial_lr"))  # Ultralytics `lr0`
final_lr = float(dbutils.widgets.get("final_lr"))  # Ultralytics `lrf`: final LR = lr0 * lrf
device_config = ast.literal_eval(dbutils.widgets.get("device_config"))  # Parse the list literal without executing code
run_name = dbutils.widgets.get("run_name")

print("Training Hyperparameters:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Image Size: {IMG_SIZE}")
print(f"  Initial LR: {initial_lr}")
print(f"  Final LR Factor: {final_lr}")
print(f"  Device Config: {device_config}")
print(f"  Run Name: {run_name}")
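`initial_lr` and `final_lr` map to Ultralytics' `lr0` and `lrf`. Note that `lrf` is a fraction of `lr0`, not an absolute rate.

The sketch below shows how the per-epoch learning rate decays. It assumes the default linear schedule (`cos_lr=False`) and ignores the warmup phase (`warmup_epochs`). `linear_lr` is a hypothetical helper that reproduces the formula as we understand it from the Ultralytics trainer, so check it against your installed version.

```python
def linear_lr(epoch: int, epochs: int, lr0: float, lrf: float) -> float:
    """Learning rate for a 0-based `epoch` under linear decay toward lr0 * lrf.

    Assumed to match Ultralytics' default (non-cosine) schedule; warmup is not modeled.
    """
    factor = max(1 - epoch / epochs, 0) * (1.0 - lrf) + lrf
    return lr0 * factor


# Widget defaults from the cell above
for e in range(2):
    print(f"epoch {e}: lr = {linear_lr(e, epochs=2, lr0=0.005, lrf=0.1):.6f}")
```

With the widget defaults, the second epoch runs at 0.00275. The rate only approaches `lr0 * lrf` = 0.0005 at the very end of training. With just two epochs, the default warmup may also cover much or all of the run.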

## Validate Dataset

Check that `data.yaml` exists and print its key fields before launching training.


In [None]:
import yaml
from pathlib import Path

# Verify config file exists
if not Path(config_path).exists():
    raise FileNotFoundError(f"Config file not found: {config_path}\n"
                            f"Please run dataset preparation first.")

# Load and display config
with open(config_path, 'r') as f:
    yolo_config = yaml.safe_load(f)

# `names` may be a list or a {class_id: name} dict, and `nc` is optional when `names` is given
names = yolo_config['names']
class_names = [names[k] for k in sorted(names)] if isinstance(names, dict) else list(names)
num_classes = yolo_config.get('nc', len(class_names))

print("YOLO Dataset Configuration:")
print(f"  Path: {yolo_config.get('path', '(not set)')}")
print(f"  Train: {yolo_config['train']}")
print(f"  Val: {yolo_config['val']}")
print(f"  Classes: {num_classes}")
print(f"  Class names (first 10): {class_names[:10]}...")
print("\n✓ Dataset configuration validated!")
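The config check above does not look at the image and label files themselves. Ultralytics finds each image's label by two substitutions:

- it replaces the last `/images/` directory in the image path with `/labels/`;
- it swaps the file extension for `.txt`.

An image with no label file is treated as a background image rather than an error. A large number of them therefore usually points to a path mismatch.

Below is a minimal sketch of the same mapping, for counting unlabeled images before a long run. `label_path_for` and `find_unlabeled_images` are hypothetical helpers. The extension set is an assumption, not Ultralytics' exact list.

```python
import os
from pathlib import Path
from typing import List

# Assumed set of image extensions; Ultralytics accepts a similar, not identical, list
IMG_EXTS = {".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp"}


def label_path_for(image_path: str) -> str:
    """Map .../images/<split>/name.jpg to .../labels/<split>/name.txt (last 'images' dir only)."""
    marker, replacement = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
    head, found, tail = image_path.rpartition(marker)
    if not found:
        raise ValueError(f"no '{marker}' directory in {image_path!r}")
    return str(Path(head + replacement + tail).with_suffix(".txt"))


def find_unlabeled_images(images_dir: str) -> List[str]:
    """Return image files under images_dir that have no matching label file."""
    missing = []
    for path in sorted(Path(images_dir).rglob("*")):
        if path.suffix.lower() in IMG_EXTS and not Path(label_path_for(str(path))).exists():
            missing.append(str(path))
    return missing
```

For example, `find_unlabeled_images(f"{dataset_path}/images")` lists the images without labels. Walking a large Volume this way can take a while.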


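## Multi-GPU Training with TorchDistributor

For multi-GPU training on Databricks, we use `TorchDistributor`, which launches and manages one training process per GPU. We need it because Ultralytics' built-in multi-GPU mode (`device=[0,1]`) starts its own DDP subprocess from a generated script. That subprocess doesn't work in a Databricks notebook environment.

**How it works:**
- TorchDistributor spawns one process per GPU on the driver node (`local_mode=True`)
- Each process runs the training function and reads its `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` from environment variables
- The processes join an NCCL process group for GPU-to-GPU communication
- Only rank 0 logs to MLflow

**Caveat:** each process calls `model.train()` with a single device (`device=local_rank`). Check in the training logs that your Ultralytics version actually synchronizes gradients across ranks, that is, wraps the model in `DistributedDataParallel`. If it does not:
- each rank trains an independent copy on its portion of the data;
- `batch` is applied per process;
- only rank 0's weights are saved.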
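The training function guards each MLflow call with an `if rank == 0:` block. The same guard can be factored into a decorator. Below is a minimal sketch with a hypothetical `rank_zero_only` helper. It reads the `RANK` environment variable set for each process and defaults to 0 when the variable is absent, as the training function does.

```python
import functools
import os
from typing import Any, Callable, Optional


def current_rank() -> int:
    """Global rank of this process; 0 when not launched in a distributed context."""
    return int(os.environ.get("RANK", 0))


def rank_zero_only(fn: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    """Run `fn` only on rank 0; other ranks skip the call and get None."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
        if current_rank() == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper


@rank_zero_only
def log_message(msg: str) -> str:
    print(f"[Rank 0] {msg}")
    return msg
```

The rank is read each time the wrapped function is called, not when it is decorated. This matters because the decorated function may be defined before the process knows its rank.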


In [None]:
import torch

# Check GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    print(f"Available GPUs: {gpu_count}")
    for i in range(gpu_count):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    raise RuntimeError("CUDA not available. This notebook requires GPU.")
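The launch cell derives `num_processes` from `len(device_config)`. The indices themselves are never passed to `TorchDistributor`; each process uses the GPU given by its `LOCAL_RANK`.

A quick guard can fail fast when the widget asks for GPUs the node does not have. It is sketched here as a hypothetical `validate_device_config` helper.

```python
def validate_device_config(devices: object, available_gpus: int) -> int:
    """Check a requested GPU list against the GPU count and return the process count.

    Raises ValueError unless `devices` is a non-empty list of distinct,
    non-negative integer indices that all exist on this node.
    """
    if not isinstance(devices, list) or not devices:
        raise ValueError(f"device_config must be a non-empty list, got {devices!r}")
    if any(isinstance(d, bool) or not isinstance(d, int) or d < 0 for d in devices):
        raise ValueError(f"GPU indices must be non-negative integers, got {devices!r}")
    if len(set(devices)) != len(devices):
        raise ValueError(f"duplicate GPU indices in {devices!r}")
    if max(devices) >= available_gpus:
        raise ValueError(
            f"requested GPU {max(devices)} but only {available_gpus} GPU(s) are available"
        )
    return len(devices)
```

In this notebook it could be called as `validate_device_config(device_config, torch.cuda.device_count())` right after the GPU check.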

In [None]:
def train_single_gpu(run_id: str):
    """
    Training function that runs on each GPU process.
    TorchDistributor will spawn this function on each GPU.
    Includes MLflow dataset and model logging (rank 0 only).
    """
    import os
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torchvision.ops import nms
    from ultralytics import YOLO
    from ultralytics import settings
    import mlflow
    import pandas as pd
    import numpy as np
    from mlflow.models.signature import ModelSignature
    from mlflow.types import Schema, TensorSpec

    # Set Databricks credentials
    os.environ['DATABRICKS_HOST'] = db_host
    os.environ['DATABRICKS_TOKEN'] = db_token
    
    # Get distributed training context
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    
    print(f"[Rank {rank}] Process started - local_rank: {local_rank}, world_size: {world_size}")
    print(f"[Rank {rank}] CUDA available: {torch.cuda.is_available()}")
    
    # Pin this process to its GPU before creating the NCCL communicator
    torch.cuda.set_device(local_rank)

    # Disable NCCL's InfiniBand and peer-to-peer transports: a common workaround for
    # hangs on some cloud GPU instances, at some cost in inter-GPU bandwidth
    os.environ['NCCL_IB_DISABLE'] = '1'
    os.environ['NCCL_P2P_DISABLE'] = '1'
    
    # Initialize distributed process group if not already initialized
    if world_size > 1 and not dist.is_initialized():
        print(f"[Rank {rank}] Initializing process group...")
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=world_size,
            rank=rank
        )
        print(f"[Rank {rank}] Process group initialized")
    
    # MLflow setup - ONLY on rank 0
    if rank == 0:
        settings.update({"mlflow": True})
        os.environ['MLFLOW_EXPERIMENT_NAME'] = mlflow_experiment
        os.environ['MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING'] = "true"
        os.environ['MLFLOW_RUN_ID'] = run_id
        print(f"[Rank 0] MLflow configured with run_id: {run_id}")
    else:
        # Disable MLflow on non-rank-0 processes
        settings.update({"mlflow": False})
        os.environ.pop('MLFLOW_EXPERIMENT_NAME', None)
        os.environ.pop('MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING', None)
        os.environ.pop('MLFLOW_RUN_ID', None)
        print(f"[Rank {rank}] MLflow disabled for this process")
    
    # Log dataset reference (rank 0 only)
    if rank == 0:
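        try:
            print(f"[Rank 0] Logging dataset reference...")
            # `names` may be a list or a {class_id: name} dict, and `nc` is optional in data.yaml
            names = yolo_config['names']
            class_names = [names[k] for k in sorted(names)] if isinstance(names, dict) else list(names)
            num_classes = yolo_config.get('nc', len(class_names))
            dataset_info = pd.DataFrame([{
                'dataset_path': dataset_path,
                'config_file': config_path,
                'num_classes': num_classes,
                'class_names': ','.join(class_names[:10]) + '...',
            }])
            
            dataset = mlflow.data.from_pandas(
                dataset_info,
                source=dataset_path,
                name=f"{ds_catalog}.{ds_schema}.{dataset_volume}",
            )
            
            mlflow.log_input(dataset, context="training")
            print(f"[Rank 0] ✓ Dataset logged - {num_classes} classes")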
        except Exception as e:
            print(f"[Rank 0] Warning: Could not log dataset - {e}")
    
    # Load model
    print(f"[Rank {rank}] Loading model: {YOLO_MODEL}")
    model = YOLO(YOLO_MODEL)
    
    # Train - each process trains on its local GPU
    print(f"[Rank {rank}] Starting training on GPU {local_rank}...")
    results = model.train(
        data=config_path,
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        optimizer='SGD',  # set explicitly: with the default 'auto', Ultralytics ignores lr0 and picks its own
        lr0=initial_lr,
        lrf=final_lr,     # final LR = lr0 * lrf
        imgsz=IMG_SIZE,
        name=run_name,
        project=training_volume_path,
        device=local_rank,
        workers=4,
        patience=50,
        save=True,
        save_period=5,
        verbose=(rank == 0),
        exist_ok=True
    )
    
    # Log wrapped model (rank 0 only)
    if rank == 0:
        try:
            print(f"[Rank 0] Loading best model for MLflow logging...")
            
            # Define YoloDetWrapper inline
            class YoloDetWrapper(nn.Module):
                def __init__(self, base: nn.Module, conf_thres=0.25, iou_thres=0.5, max_det=300):
                    super().__init__()
                    self.base = base.eval()
                    self.conf_thres, self.iou_thres, self.max_det = conf_thres, iou_thres, max_det

                @staticmethod
                def xywh_to_xyxy(b):
                    x,y,w,h = b.unbind(-1)
                    x1 = x - w/2; y1 = y - h/2; x2 = x + w/2; y2 = y + h/2
                    return torch.stack([x1,y1,x2,y2], dim=-1)

                def forward(self, x):
                    with torch.no_grad():
                        out = self.base(x)
                        if isinstance(out, (list, tuple)): out = out[0]
                        # (N, 4 + nc, A) -> (N, A, 4 + nc)
                        if out.dim() == 3 and out.shape[1] < out.shape[2]:
                            out = out.permute(0, 2, 1).contiguous()
                        
                        # YOLOv8/YOLO11 heads are anchor-free with no objectness column: each row is
                        # [cx, cy, w, h, score_0, ..., score_{nc-1}] and the scores are already sigmoided
                        boxes_xywh = out[..., :4]
                        conf, cls_id = out[..., 4:].max(-1)
                        boxes = self.xywh_to_xyxy(boxes_xywh)

                        N, A = conf.shape
                        K = self.max_det
                        out_pad = x.new_full((N, K, 6), -1.0)

                        for i in range(N):
                            mask = conf[i] >= self.conf_thres
                            if mask.sum() == 0: continue
                            b = boxes[i][mask]
                            s = conf[i][mask]
                            c = cls_id[i][mask].float()
                            keep = nms(b, s, self.iou_thres)[:K]
                            k = keep.numel()
                            if k == 0: continue
                            out_pad[i, :k, :4] = b[keep]
                            out_pad[i, :k, 4]  = s[keep]
                            out_pad[i, :k, 5]  = c[keep]
                        return out_pad
            
            # Load and wrap best model
            best_model_path = f"{training_volume_path}/{run_name}/weights/best.pt"
            best_model_raw = YOLO(best_model_path).model
            wrapped_model = YoloDetWrapper(best_model_raw, conf_thres=0.25, iou_thres=0.5, max_det=300)
            
            # Create model signature
            input_schema = Schema([
                TensorSpec(type=np.dtype(np.float32), shape=(-1, 3, IMG_SIZE, IMG_SIZE), name="images")
            ])
            output_schema = Schema([
                TensorSpec(type=np.dtype(np.float32), shape=(-1, -1, 6), name="detections")  # rows of [x1, y1, x2, y2, conf, cls]
            ])
            signature = ModelSignature(inputs=input_schema, outputs=output_schema)
            
            # Log model to MLflow
            print(f"[Rank 0] Logging wrapped model to MLflow...")
            mlflow.pytorch.log_model(
                pytorch_model=wrapped_model,
                artifact_path="best_model",
                signature=signature,
                #registered_model_name=f"{ds_catalog}.{ds_schema}.yolo_model"  # Uncomment to register
            )
            print(f"[Rank 0] ✓ Model logged successfully")
            
        except Exception as e:
            print(f"[Rank 0] Warning: Could not log model - {e}")
    
    print(f"[Rank {rank}] Training finished, cleaning up...")
    
    # Cleanup distributed process group
    if world_size > 1 and dist.is_initialized():
        dist.barrier()  # Ensure all processes finish
        dist.destroy_process_group()
        print(f"[Rank {rank}] Process group destroyed")
    
    # Force cleanup
    import gc
    gc.collect()
    
    if rank == 0:
        print("[Rank 0] Training complete!")
        return results
    else:
        return None
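The wrapper's post-processing follows the usual recipe:

1. keep each anchor's best class score;
2. drop anchors below `conf_thres`;
3. run class-agnostic NMS;
4. pad the result to `max_det` rows of `[x1, y1, x2, y2, conf, cls]`.

The plain-Python reference below performs these steps on small hand-written inputs, so thresholds can be sanity-checked without a GPU. It assumes the YOLOv8/YOLO11 detection layout, where each anchor row is `[cx, cy, w, h, score_0, ..., score_{nc-1}]` with no separate objectness column. `postprocess_reference` and its helpers are hypothetical names.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]


def xywh_to_xyxy(cx: float, cy: float, w: float, h: float) -> Box:
    """Center/size box to corner box."""
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two corner boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def postprocess_reference(
    rows: Sequence[Sequence[float]],
    conf_thres: float = 0.25,
    iou_thres: float = 0.5,
    max_det: int = 300,
) -> List[List[float]]:
    """rows: one [cx, cy, w, h, score_0, ..., score_{nc-1}] per anchor for a single image.

    Returns exactly max_det rows of [x1, y1, x2, y2, conf, cls], padded with -1.
    """
    candidates = []
    for row in rows:
        scores = row[4:]
        cls_id = max(range(len(scores)), key=lambda j: scores[j])  # best class per anchor
        conf = scores[cls_id]
        if conf >= conf_thres:
            candidates.append((conf, cls_id, xywh_to_xyxy(*row[:4])))

    # Greedy class-agnostic NMS: highest score first; drop boxes with IoU > iou_thres vs a kept box
    candidates.sort(key=lambda c: c[0], reverse=True)
    kept = []
    for conf, cls_id, box in candidates:
        if len(kept) == max_det:
            break
        if all(box_iou(box, kbox) <= iou_thres for _, _, kbox in kept):
            kept.append((conf, cls_id, box))

    out = [[*box, conf, float(cls_id)] for conf, cls_id, box in kept]
    out.extend([-1.0] * 6 for _ in range(max_det - len(out)))
    return out
```

The NMS here is class-agnostic, matching the wrapper's use of `torchvision.ops.nms`. A per-class variant would only compare boxes that share a `cls` value.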

In [None]:
from pyspark.ml.torch.distributor import TorchDistributor
import time

# Set MLflow experiment
mlflow.set_experiment(mlflow_experiment)

# Create TorchDistributor and run training
# Only the length of device_config matters: each process uses the GPU given by its LOCAL_RANK
num_processes = len(device_config) if isinstance(device_config, list) else 1
print(f"\n{'='*60}")
print(f"Starting distributed training with {num_processes} GPUs")
print(f"{'='*60}\n")

# Start MLflow run with system metrics logging
with mlflow.start_run(run_name=run_name, log_system_metrics=True) as run:
    active_run_id = run.info.run_id
    print(f"MLflow Run ID: {active_run_id}\n")
    
    # Log hyperparameters upfront
    mlflow.log_params({
        'model': YOLO_MODEL,
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'img_size': IMG_SIZE,
        'initial_lr': initial_lr,
        'final_lr': final_lr,
        'num_gpus': num_processes,
        'device_config': str(device_config),
        'run_name': run_name,
        'dataset_path': dataset_path,
    })
    
    # Create distributor
    distributor = TorchDistributor(
        num_processes=num_processes,
        local_mode=True,  # Single node multi-GPU
        use_gpu=True
    )
    
    # Run distributed training
    output = distributor.run(train_single_gpu, active_run_id)

print(f"\n{'='*60}")
print("Training complete! All MLflow logging finished.")
print(f"View results at: {mlflow_experiment}")
print(f"{'='*60}")

# Small delay to ensure all background processes fully exit
time.sleep(2)
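How much data each rank sees per epoch depends on how the processes split the dataset. The arithmetic below rests on two assumptions:

- each rank draws from a PyTorch `DistributedSampler` with its default `drop_last=False`, which pads the dataset so every rank gets `ceil(N / world_size)` samples;
- each rank uses `BATCH_SIZE` images per step.

Whether `BATCH_SIZE` really is per-rank depends on how your Ultralytics version handles `device=local_rank`. `steps_per_epoch` and `global_batch` are hypothetical helpers.

```python
import math


def steps_per_epoch(num_images: int, batch_per_rank: int, world_size: int) -> int:
    """Steps each rank takes per epoch, counting the last partial batch."""
    samples_per_rank = math.ceil(num_images / world_size)  # DistributedSampler pads to this
    return math.ceil(samples_per_rank / batch_per_rank)


def global_batch(batch_per_rank: int, world_size: int) -> int:
    """Images processed across all ranks in one step."""
    return batch_per_rank * world_size
```

For a hypothetical 10,000-image training split with the widget defaults (128 per rank, 2 GPUs), each rank takes 40 steps per epoch and one step covers 256 images across both GPUs.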

# Training Complete!

The MLflow run records the following. Hyperparameters and system metrics come from the driver cell; everything else comes from rank 0 of the training function.

**What's Logged:**
- ✅ **Hyperparameters** - Model name, training settings, GPU count, and dataset path
- ✅ **Dataset Lineage** - Dataset reference with class information
- ✅ **Training Metrics** - Logged automatically by Ultralytics' built-in MLflow callback
- ✅ **Wrapped Model** - PyTorch model with confidence filtering and NMS built in, logged with an input/output tensor signature
- ✅ **System Metrics** - GPU/CPU/memory usage during training

**Next Steps:**
- View results in the MLflow UI under the experiment path above
- Uncomment `registered_model_name` in the training function to register the model in Unity Catalog
- Use the logged model artifact for deployment or further evaluation
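For inference outside Ultralytics, inputs to the logged model must match its signature: a float32 batch shaped `(N, 3, IMG_SIZE, IMG_SIZE)`. The returned boxes are in that resized frame.

Ultralytics prepares images with a letterbox, meaning an aspect-preserving resize plus centered padding, and scales pixels to [0, 1]. The sketch below computes the letterbox geometry and maps a detected box back to original-image pixels. `letterbox_params` and `box_to_original` are hypothetical helpers. Ultralytics' own `LetterBox` transform rounds the padding slightly differently, so expect sub-pixel differences.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def letterbox_params(orig_w: int, orig_h: int, target: int = 640) -> Tuple[float, int, int, float, float]:
    """Scale factor, resized size, and per-side padding for a square letterbox."""
    r = min(target / orig_w, target / orig_h)  # same scale on both axes keeps the aspect ratio
    new_w, new_h = round(orig_w * r), round(orig_h * r)
    pad_x = (target - new_w) / 2  # left and right padding
    pad_y = (target - new_h) / 2  # top and bottom padding
    return r, new_w, new_h, pad_x, pad_y


def box_to_original(box: Box, r: float, pad_x: float, pad_y: float, orig_w: int, orig_h: int) -> Box:
    """Map an (x1, y1, x2, y2) box from letterboxed coordinates back to the original image."""

    def clip(v: float, hi: int) -> float:
        return min(max(v, 0.0), float(hi))

    x1, y1, x2, y2 = box
    return (
        clip((x1 - pad_x) / r, orig_w),
        clip((y1 - pad_y) / r, orig_h),
        clip((x2 - pad_x) / r, orig_w),
        clip((y2 - pad_y) / r, orig_h),
    )
```

A 1280×720 frame at `IMG_SIZE=640` is halved to 640×360 and padded by 140 pixels above and below. A detection at `(100, 240, 300, 440)` in the model's frame therefore corresponds to `(200, 200, 600, 600)` in the original image.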