# Fine-Tune YOLO-NAS on a FiftyOne Dataset

This notebook demonstrates how to fine-tune YOLO-NAS on a dataset loaded from FiftyOne.

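We use `FiftyOneDetectionDataset` to fine-tune a pretrained YOLO-NAS S model on the FiftyOne quickstart dataset (split 80/20 into train/val) with PyTorch Lightning, and then visualize its predictions on held-out validation images.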

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.random as four

# Fixed seed so the train/val split is identical on every Restart & Run All.
SPLIT_SEED = 51

# Load quickstart dataset. Within a session the zoo may return the already-loaded
# "quickstart" dataset, including split tags applied by an earlier run of this cell.
dataset = foz.load_zoo_dataset("quickstart")

# `random_split` works by *adding* tags. Without clearing old ones, re-running this
# cell would tag some samples as both "train" and "val" and leak train data into val.
dataset.untag_samples(["train", "val"])

# Create a reproducible 80/20 train/val split using tags
four.random_split(dataset, {"train": 0.8, "val": 0.2}, seed=SPLIT_SEED)

train_view = dataset.match_tags("train")
val_view = dataset.match_tags("val")

# The two views must be disjoint and together cover the whole dataset.
assert len(train_view) + len(val_view) == len(dataset), "train/val split overlaps or is incomplete"

print(f"Train samples: {len(train_view)}")
print(f"Val samples:   {len(val_view)}")

In [None]:
from modern_yolonas import yolo_nas_s
from modern_yolonas.data import FiftyOneDetectionDataset
from modern_yolonas.data.transforms import Compose, HSVAugment, HorizontalFlip, LetterboxResize, Normalize
from modern_yolonas.inference.visualize import COCO_NAMES

# The quickstart "ground_truth" detections use COCO category names, so the COCO
# vocabulary serves as the label map for this fine-tuning run.
class_names = list(COCO_NAMES)

# Training pipeline: HSV augmentation and a random horizontal flip, then letterbox
# to 640 and normalize.
train_transforms = Compose(
    [
        HSVAugment(),
        HorizontalFlip(p=0.5),
        LetterboxResize(target_size=640),
        Normalize(),
    ]
)
# Validation pipeline: the same letterbox + normalize steps, without augmentation.
val_transforms = Compose(
    [
        LetterboxResize(target_size=640),
        Normalize(),
    ]
)

# Arguments shared by both splits; only the FiftyOne view and the transforms differ.
dataset_kwargs = dict(label_field="ground_truth", class_names=class_names)
train_ds = FiftyOneDetectionDataset(fo_dataset=train_view, transforms=train_transforms, **dataset_kwargs)
val_ds = FiftyOneDetectionDataset(fo_dataset=val_view, transforms=val_transforms, **dataset_kwargs)

print(f"Num classes: {train_ds.num_classes}")
print(f"Train dataset: {len(train_ds)} samples")
print(f"Val dataset:   {len(val_ds)} samples")

# Load pretrained weights with the number of output classes matched to the label map.
num_classes = len(class_names)
model = yolo_nas_s(pretrained=True, num_classes=num_classes)
print(f"\nModel: YOLO-NAS S with {num_classes} classes")

In [None]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from modern_yolonas.training import YoloNASLightningModule, EMACallback, DetectionDataModule

# Seed Python, NumPy and PyTorch (and dataloader workers) so that shuffling,
# augmentation and optimisation from this point on are reproducible.
# Note: the model itself was constructed in the previous cell, before this seed.
L.seed_everything(42, workers=True)

lit_model = YoloNASLightningModule(
    model=model, num_classes=len(class_names), lr=2e-4, warmup_steps=50,
)
data_module = DetectionDataModule(
    train_dataset=train_ds, val_dataset=val_ds, batch_size=8, num_workers=2,
)

# Keep a handle on the checkpoint callback. If `runs/fiftyone/last.ckpt` already
# exists from an earlier run, Lightning writes a versioned `last-vN.ckpt` instead, so
# the path this run produced is read back from the callback below.
checkpoint_callback = ModelCheckpoint(dirpath="runs/fiftyone", save_last=True)

trainer = L.Trainer(
    max_epochs=5, accelerator="auto", precision="16-mixed",
    callbacks=[EMACallback(), checkpoint_callback],
    default_root_dir="runs/fiftyone",
)
trainer.fit(lit_model, datamodule=data_module)
print("Training complete!")
print(f"Last checkpoint: {checkpoint_callback.last_model_path}")

In [None]:
import torch
import cv2
import matplotlib.pyplot as plt
from collections import Counter
from modern_yolonas import Detector
from modern_yolonas.training import extract_model_state_dict

# Load the checkpoint written by *this* training run. When `last.ckpt` already existed
# from a previous run, Lightning saves `last-v1.ckpt` (and so on) instead, so the
# hard-coded path could silently point at stale weights. Fall back to it only when
# the trainer recorded no path.
checkpoint_callback = trainer.checkpoint_callback
if checkpoint_callback is not None and checkpoint_callback.last_model_path:
    ckpt_path = checkpoint_callback.last_model_path
else:
    ckpt_path = "runs/fiftyone/last.ckpt"

state_dict = extract_model_state_dict(ckpt_path)
model.load_state_dict(state_dict)
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
detector = Detector(model=model, device=device, conf_threshold=0.25, iou_threshold=0.45, retain_image=True)

# Seeded so the same validation images are shown on every run.
val_samples = list(val_view.take(8, seed=51))
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
detection_counts = Counter()

for ax, sample in zip(axes.flat, val_samples):
    detection = detector.detect_image(sample.filepath)
    vis = detection.visualize(class_names=class_names)
    # `vis` is converted BGR -> RGB for matplotlib (assumes visualize() returns a BGR image).
    ax.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    ax.set_title(f"{len(detection.boxes)} detections")
    ax.axis("off")
    for cls_id in detection.class_ids:
        idx = int(cls_id)
        # Guard against ids outside the label map rather than raising IndexError.
        name = class_names[idx] if idx < len(class_names) else f"class_{idx}"
        detection_counts[name] += 1

# Blank out any grid cells left unused when the val split has fewer than 8 samples.
for ax in axes.ravel()[len(val_samples):]:
    ax.axis("off")

fig.suptitle("YOLO-NAS S predictions on validation samples")
plt.tight_layout()
plt.show()

print("\nDetection counts per class:")
for name, count in detection_counts.most_common(10):
    print(f"  {name}: {count}")