# Fine-Tune YOLO-NAS on a Roboflow Dataset

This notebook demonstrates how to fine-tune YOLO-NAS S on a Roboflow dataset using PyTorch Lightning, then evaluate and visualize the results.

## Prerequisites

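Run `01_explore_dataset.ipynb` first to download the dataset from Roboflow. This notebook expects the dataset at `./hardhat-dataset/` and reads its description from `./hardhat-dataset/data.yaml`.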

In [None]:
# Load config + setup
# Read the dataset description (class names, split folders, root path)
# produced by 01_explore_dataset.ipynb when it downloaded the Roboflow data.
from modern_yolonas.data import load_dataset_config

DATASET_YAML = "./hardhat-dataset/data.yaml"

cfg = load_dataset_config(DATASET_YAML)
print(f"Training on {cfg.num_classes} classes: {cfg.class_names}")

In [None]:
# Fine-tune with Lightning
import torch
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

from modern_yolonas import yolo_nas_s
from modern_yolonas.data import YOLODetectionDataset
from modern_yolonas.data.transforms import (
    Compose, HSVAugment, HorizontalFlip, LetterboxResize, Normalize,
)
from modern_yolonas.training import (
    YoloNASLightningModule, EMACallback, DetectionDataModule,
)

# Training knobs gathered in one place so they are easy to find and tune.
SEED = 42
IMAGE_SIZE = 640
BATCH_SIZE = 8
NUM_WORKERS = 2
LEARNING_RATE = 2e-4
WARMUP_STEPS = 100
MAX_EPOCHS = 15
RUN_DIR = "runs/hardhat"

# Seed Python, NumPy and torch (including dataloader workers) so that
# shuffling and random augmentations are reproducible on Restart & Run All.
L.seed_everything(SEED, workers=True)

# Start from pretrained weights, with the head sized for this dataset's classes.
model = yolo_nas_s(pretrained=True, num_classes=cfg.num_classes)
print(f"Model created: YOLO-NAS S with {cfg.num_classes} classes")

# Random augmentations are applied to training data only; validation is just
# letterboxed and normalized so metrics are comparable across epochs.
train_transforms = Compose([
    HSVAugment(), HorizontalFlip(p=0.5), LetterboxResize(target_size=IMAGE_SIZE), Normalize(),
])
val_transforms = Compose([LetterboxResize(target_size=IMAGE_SIZE), Normalize()])

train_ds = YOLODetectionDataset(root=cfg.root, split=cfg.train_split, transforms=train_transforms)
val_ds = YOLODetectionDataset(root=cfg.root, split=cfg.val_split, transforms=val_transforms)

lit_model = YoloNASLightningModule(
    model=model, num_classes=cfg.num_classes, lr=LEARNING_RATE, warmup_steps=WARMUP_STEPS,
)
data_module = DetectionDataModule(
    train_dataset=train_ds, val_dataset=val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
)

# fp16 mixed precision needs CUDA; Lightning refuses precision="16-mixed" on a
# CPU accelerator, so fall back to full fp32 when no GPU is present.
precision = "16-mixed" if torch.cuda.is_available() else "32-true"

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS, accelerator="auto", precision=precision,
    # No metric is monitored, so save_last=True provides RUN_DIR/last.ckpt
    # holding the weights from the final epoch.
    callbacks=[EMACallback(), ModelCheckpoint(dirpath=RUN_DIR, save_last=True)],
    default_root_dir=RUN_DIR,
)
trainer.fit(lit_model, datamodule=data_module)
print("Training complete!")

In [None]:
# Evaluate & Visualize
import cv2
import matplotlib.pyplot as plt
from collections import Counter
from modern_yolonas import Detector
from modern_yolonas.training import extract_model_state_dict

# The training ModelCheckpoint monitors no metric, so there is no "best"
# checkpoint to choose from: load the weights saved after the final epoch.
ckpt_path = "runs/hardhat/last.ckpt"
sd = extract_model_state_dict(ckpt_path)
model.load_state_dict(sd)
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
detector = Detector(model=model, device=device, conf_threshold=0.25, iou_threshold=0.45, retain_image=True)

# Only preview real image files: a bare "*.*" glob would also match stray
# files (e.g. .DS_Store or cache files) that cv2 cannot decode.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
N_PREVIEW = 8
val_dir = cfg.root / "images" / cfg.val_split
val_images = sorted(
    p for p in val_dir.glob("*") if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)[:N_PREVIEW]
assert val_images, f"No validation images found in {val_dir}"

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
class_detection_counts = Counter()

for ax, img_path in zip(axes.flat, val_images):
    detection = detector.detect_image(str(img_path))
    vis = detection.visualize(class_names=cfg.class_names)
    # visualize() output is treated as BGR (OpenCV order); matplotlib expects RGB.
    ax.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    ax.set_title(f"{len(detection.boxes)} detections")
    ax.axis("off")
    for cls_id in detection.class_ids:
        # Fall back to a generic label if the model emits an id outside the config.
        name = cfg.class_names[int(cls_id)] if int(cls_id) < len(cfg.class_names) else f"class_{int(cls_id)}"
        class_detection_counts[name] += 1

# Hide panels left empty when the split has fewer images than the grid holds.
for ax in axes.ravel()[len(val_images):]:
    ax.axis("off")

plt.tight_layout()
plt.show()

print("\nDetection counts per class:")
for name, count in class_detection_counts.most_common():
    print(f"  {name}: {count}")