# Fine-Tuning YOLO-NAS on a Roboflow Dataset

This notebook demonstrates the full pipeline:

1. Download a dataset from Roboflow
2. Explore the data
3. Fine-tune YOLO-NAS S starting from a COCO-pretrained backbone
4. Evaluate and visualize results
5. Export to ONNX
6. Post-Training Quantization (PTQ)
7. Quantization-Aware Training (QAT)
8. ONNX Runtime inference

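**Requirements:** A GPU is recommended for training. The notebook imports `modern_yolonas`, `torch`, `opencv-python` (`cv2`), `numpy`, `matplotlib`, and `onnxruntime`; install `roboflow` as well for the dataset download.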

In [None]:
# Cell 1: Setup & Download
# Install roboflow if it is not already available. Prefer %pip over !pip so the
# package is installed into the environment this kernel is running in:
# %pip install roboflow

import os
from getpass import getpass

from roboflow import Roboflow

# Do not paste the API key into the notebook: a saved or committed notebook
# would leak it. The key is read from the ROBOFLOW_API_KEY environment
# variable, falling back to an interactive (non-echoing) prompt.
# Get a free key at https://roboflow.com
api_key = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")
rf = Roboflow(api_key=api_key)

# Download Hard Hat Workers dataset (small, 3 classes)
# You can replace this with any Roboflow dataset
project = rf.workspace("lus-gabriel").project("hard-hat-sample-5g5ip")
version = project.version(2)
dataset = version.download("yolov5")

# dataset.location is the directory the files were written to; later cells
# should build paths from it instead of hard-coding a folder name.
print(f"Dataset downloaded to: {dataset.location}")

In [None]:
# Cell 1b: Parse data.yaml
from pathlib import Path
from modern_yolonas.data import load_dataset_config, YOLODetectionDataset

# Build the path from the directory reported by the download cell rather than
# a hard-coded folder name, so the notebook keeps working whatever directory
# Roboflow chose for this project/version.
data_yaml = Path(dataset.location) / "data.yaml"
if not data_yaml.is_file():
    raise FileNotFoundError(
        f"data.yaml not found at {data_yaml}; re-run the download cell or fix the path"
    )

# Parse the Roboflow data.yaml (passed as a string, as in the original call)
cfg = load_dataset_config(str(data_yaml))

print(f"Dataset root:  {cfg.root}")
print(f"Num classes:   {cfg.num_classes}")
print(f"Class names:   {cfg.class_names}")
print(f"Train split:   {cfg.train_split}")
print(f"Val split:     {cfg.val_split}")

In [None]:
# Cell 2: Explore Dataset
import cv2
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

# Datasets without any transforms attached, used only for inspection here.
train_raw = YOLODetectionDataset(root=cfg.root, split=cfg.train_split)
val_raw = YOLODetectionDataset(root=cfg.root, split=cfg.val_split)

print(f"Train images: {len(train_raw)}")
print(f"Val images:   {len(val_raw)}")

# Tally how many annotations of each class id appear in the training split.
# The first element of every target row is read as the class id.
class_counts = Counter()
for sample_idx in range(len(train_raw)):
    _, sample_targets = train_raw.load_raw(sample_idx)
    class_counts.update(int(row[0]) for row in sample_targets)

print("\nTraining set class distribution:")
for cls_id in sorted(class_counts):
    count = class_counts[cls_id]
    if cls_id < len(cfg.class_names):
        name = cfg.class_names[cls_id]
    else:
        name = f"class_{cls_id}"
    print(f"  {name}: {count}")

# Draw the annotations of the first four training images side by side.
# Colours are applied to the RGB copy of the image and cycle by class id.
palette = [(0, 255, 0), (255, 0, 0), (0, 0, 255)]
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for idx, ax in enumerate(axes):
    img, targets = train_raw.load_raw(idx)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]

    for cls_id, cx, cy, bw, bh in targets:
        # Targets are treated as normalised centre/size boxes and scaled to
        # pixel corner coordinates using the image width and height.
        half_w = bw / 2
        half_h = bh / 2
        top_left = (int((cx - half_w) * w), int((cy - half_h) * h))
        bottom_right = (int((cx + half_w) * w), int((cy + half_h) * h))

        cls_idx = int(cls_id)
        color = palette[cls_idx % 3]
        cv2.rectangle(img_rgb, top_left, bottom_right, color, 2)

        if cls_idx < len(cfg.class_names):
            label = cfg.class_names[cls_idx]
        else:
            label = str(cls_idx)
        text_origin = (top_left[0], top_left[1] - 5)
        cv2.putText(img_rgb, label, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    ax.imshow(img_rgb)
    ax.set_title(f"Image {idx} ({len(targets)} objects)")
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Cell 3: Fine-Tune YOLO-NAS S
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from modern_yolonas import yolo_nas_s
from modern_yolonas.data import YOLODetectionDataset
from modern_yolonas.data.transforms import (
    Compose,
    HSVAugment,
    HorizontalFlip,
    LetterboxResize,
    Normalize,
)
from modern_yolonas.data.collate import detection_collate_fn
from modern_yolonas.training.trainer import Trainer

# Seed the Python, NumPy and PyTorch RNGs before the model is built so that the
# random head initialisation and the shuffled batch order are repeatable.
# (Bit-exact GPU determinism is still not guaranteed; which RNG the
# augmentations draw from is not visible here — TODO confirm.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Resolve the device once and reuse it for the loaders and the trainer.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build model with pretrained COCO backbone, random head for our classes
model = yolo_nas_s(pretrained=True, num_classes=cfg.num_classes)
print(f"Model created: YOLO-NAS S with {cfg.num_classes} classes")
print("Backbone+neck weights: loaded from COCO pretrained")
print("Head cls_pred layers: randomly initialized")

# Transforms: colour/flip augmentation for training only; both pipelines
# letterbox to 640 and normalise.
train_transforms = Compose([
    HSVAugment(),
    HorizontalFlip(p=0.5),
    LetterboxResize(target_size=640),
    Normalize(),
])
val_transforms = Compose([
    LetterboxResize(target_size=640),
    Normalize(),
])

# Datasets & loaders
train_ds = YOLODetectionDataset(
    root=cfg.root, split=cfg.train_split, transforms=train_transforms
)
val_ds = YOLODetectionDataset(
    root=cfg.root, split=cfg.val_split, transforms=val_transforms
)

# Settings shared by both loaders. Pinned memory only helps host-to-GPU
# copies, and requesting it without an accelerator just produces a warning,
# so it is enabled only when CUDA is in use.
loader_kwargs = dict(
    batch_size=8,
    num_workers=2,
    collate_fn=detection_collate_fn,
    pin_memory=(device == "cuda"),
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

# Train
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=cfg.num_classes,
    epochs=15,
    lr=2e-4,
    warmup_steps=100,
    output_dir="runs/hardhat",
    device=device,
)

trainer.train()
print("Training complete!")

In [None]:
# Cell 4: Evaluate & Visualize
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import torch
from modern_yolonas import Detector

# Reload the weights written by the training run. This is the *last*
# checkpoint (last.pt), which is not necessarily the best-scoring one.
checkpoint = torch.load("runs/hardhat/last.pt", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
detector = Detector(
    model=model,
    device=device,
    conf_threshold=0.25,
    iou_threshold=0.45,
    retain_image=True,
)

# Run inference on (up to) the first eight validation images.
# NOTE(review): assumes images live under <root>/images/<val_split> — confirm
# this matches the layout of the downloaded dataset.
val_images = sorted((cfg.root / "images" / cfg.val_split).glob("*.*"))[:8]

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
class_detection_counts = Counter()

for ax, img_path in zip(axes.flat, val_images):
    detection = detector.detect_image(str(img_path))
    vis = detection.visualize(class_names=cfg.class_names)
    ax.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    ax.set_title(f"{len(detection.boxes)} detections")
    ax.axis("off")

    for cls_id in detection.class_ids:
        cls_idx = int(cls_id)
        name = cfg.class_names[cls_idx] if cls_idx < len(cfg.class_names) else f"class_{cls_idx}"
        class_detection_counts[name] += 1

# With fewer than eight images some panels stay empty; hide their axes so the
# figure does not show blank framed plots.
for unused_ax in axes.ravel()[len(val_images):]:
    unused_ax.axis("off")

plt.tight_layout()
plt.show()

print("\nDetection counts per class:")
for name, count in class_detection_counts.most_common():
    print(f"  {name}: {count}")

In [None]:
# Cell 5: Export to ONNX
import torch
from pathlib import Path

# Rebuild the architecture without pretrained weights and restore the
# fine-tuned state dict loaded in the evaluation cell.
export_model = yolo_nas_s(pretrained=False, num_classes=cfg.num_classes)
export_model.load_state_dict(checkpoint["model_state_dict"])
export_model.eval()

# Fuse RepVGG blocks for deployment: every submodule exposing the fusion hook
# is asked to fuse itself.
for submodule in export_model.modules():
    if hasattr(submodule, "fuse_block_residual_branches"):
        submodule.fuse_block_residual_branches()

onnx_path = "runs/hardhat/model_float32.onnx"
dummy = torch.randn(1, 3, 640, 640)

# Only the batch dimension (axis 0) is exported as dynamic, for the input and
# both outputs alike.
output_names = ["pred_bboxes", "pred_scores"]
dynamic_axes = {
    tensor_name: {0: "batch"} for tensor_name in ["images", *output_names]
}

torch.onnx.export(
    export_model,
    dummy,
    onnx_path,
    input_names=["images"],
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=17,
)

# File size in MiB; later cells compare quantized model sizes against it.
onnx_bytes = Path(onnx_path).stat().st_size
onnx_size = onnx_bytes / (1024 * 1024)
print(f"ONNX exported to: {onnx_path}")
print(f"Float32 model size: {onnx_size:.1f} MB")

# Verify output shapes by running the exported graph once on the dummy input
import onnxruntime as ort

session = ort.InferenceSession(onnx_path)
outputs = session.run(None, {"images": dummy.numpy()})
print(f"Output shapes: bboxes={outputs[0].shape}, scores={outputs[1].shape}")

In [None]:
# Cell 6: Post-Training Quantization (PTQ)
from modern_yolonas.quantization import (
    convert_quantized,
    export_quantized_onnx,
    prepare_model_ptq,
    run_calibration,
)

# Start from a fresh copy of the architecture carrying the fine-tuned weights,
# switched to eval mode before it is prepared for PTQ.
ptq_model_fresh = yolo_nas_s(pretrained=False, num_classes=cfg.num_classes)
ptq_model_fresh.load_state_dict(checkpoint["model_state_dict"])
ptq_model_fresh.eval()

ptq_model = prepare_model_ptq(ptq_model_fresh)
print("PTQ model prepared with observers")

# Calibration pass on CPU over at most 20 batches of the validation loader
# (a small set is fine for PTQ).
run_calibration(ptq_model, val_loader, num_batches=20, device="cpu")
print("Calibration complete")

# Turn the calibrated model into its quantized counterpart
quantized_model = convert_quantized(ptq_model)
print("Model converted to int8")

# Write the quantized model out as ONNX
ptq_onnx_path = "runs/hardhat/model_ptq_int8.onnx"
export_quantized_onnx(quantized_model, ptq_onnx_path, input_size=640)

# Report the file size in MiB and the saving relative to the float32 export
ptq_bytes = Path(ptq_onnx_path).stat().st_size
ptq_size = ptq_bytes / (1024 * 1024)
print(f"\nPTQ ONNX exported to: {ptq_onnx_path}")
print(f"PTQ model size: {ptq_size:.1f} MB")
size_reduction_pct = (1 - ptq_size / onnx_size) * 100
print(f"Size reduction: {size_reduction_pct:.0f}%")

In [None]:
# Cell 7: Quantization-Aware Training (QAT)
from pathlib import Path

import torch

from modern_yolonas.quantization import (
    convert_quantized,
    export_quantized_onnx,
    prepare_model_qat,
)
from modern_yolonas.quantization.qat_trainer import QATTrainer

# Prepare model for QAT from the fine-tuned checkpoint
qat_model_fresh = yolo_nas_s(pretrained=False, num_classes=cfg.num_classes)
qat_model_fresh.load_state_dict(checkpoint["model_state_dict"])

qat_model = prepare_model_qat(qat_model_fresh)
print("QAT model prepared with fake-quantization nodes")

# QAT fine-tuning (short, low LR)
qat_trainer = QATTrainer(
    model=qat_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=cfg.num_classes,
    epochs=5,
    lr=2e-5,
    warmup_steps=50,
    output_dir="runs/hardhat/qat",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

qat_trainer.train()
print("QAT training complete!")

# After training the model may still be on the GPU and in train mode.
# PyTorch's quantized conversion is normally done on an eval-mode CPU model,
# so normalise the state first. This is a no-op if convert_quantized already
# handles it — NOTE(review): confirm against the helper's implementation.
qat_model.cpu()
qat_model.eval()

# Convert QAT model and export
qat_quantized = convert_quantized(qat_model)
qat_onnx_path = "runs/hardhat/model_qat_int8.onnx"
export_quantized_onnx(qat_quantized, qat_onnx_path, input_size=640)

# Sizes in MiB, compared against the float32 export from the ONNX export cell
# and the PTQ export from the PTQ cell.
qat_size = Path(qat_onnx_path).stat().st_size / (1024 * 1024)
print(f"\nQAT ONNX exported to: {qat_onnx_path}")
print("\nModel size comparison:")
print(f"  Float32: {onnx_size:.1f} MB")
print(f"  PTQ:     {ptq_size:.1f} MB ({(1 - ptq_size / onnx_size) * 100:.0f}% smaller)")
print(f"  QAT:     {qat_size:.1f} MB ({(1 - qat_size / onnx_size) * 100:.0f}% smaller)")

In [None]:
# Cell 8: ONNX Runtime Inference
import cv2
import numpy as np
import onnxruntime as ort
import matplotlib.pyplot as plt
import torch

from modern_yolonas.inference.preprocess import preprocess
from modern_yolonas.inference.postprocess import postprocess, rescale_boxes
from modern_yolonas.inference.visualize import draw_detections

# Load the float32 ONNX model exported earlier
session = ort.InferenceSession(onnx_path)

# Load a test image. Fail with a clear message instead of an IndexError when
# no validation images were collected by the evaluation cell.
if not val_images:
    raise RuntimeError("No validation images were found; cannot run the ONNX Runtime demo")
test_image_path = str(val_images[0])

# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise surface as an AttributeError on `.shape`.
original = cv2.imread(test_image_path)
if original is None:
    raise FileNotFoundError(f"Could not read image: {test_image_path}")
orig_h, orig_w = original.shape[:2]

# Preprocess to the 640 network input; pad_info is kept so boxes can be mapped
# back to the original image afterwards.
input_tensor, pad_info = preprocess(original, input_size=640)

# Run inference
pred_bboxes, pred_scores = session.run(
    None, {"images": input_tensor.numpy()}
)

# Postprocess with the same thresholds used for the PyTorch detector
boxes, scores, class_ids = postprocess(
    torch.from_numpy(pred_bboxes),
    torch.from_numpy(pred_scores),
    conf_threshold=0.25,
    iou_threshold=0.45,
)

# Rescale boxes to original image size (skipped when nothing was detected)
if len(boxes) > 0:
    boxes = rescale_boxes(boxes, pad_info, (orig_h, orig_w))

# Visualize
annotated = draw_detections(
    original, boxes.numpy(), scores.numpy(), class_ids.numpy(),
    class_names=cfg.class_names,
)

plt.figure(figsize=(10, 8))
plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
plt.title(f"ONNX Runtime Inference ({len(boxes)} detections)")
plt.axis("off")
plt.show()

print("\nDetections:")
for i in range(len(boxes)):
    class_idx = int(class_ids[i])
    cls_name = cfg.class_names[class_idx] if class_idx < len(cfg.class_names) else f"class_{class_idx}"
    print(f"  {cls_name}: {scores[i]:.2f}")