# Train a YOLO-NAS Model with 3LC Metrics Collection

This notebook shows how to fine-tune a YOLO-NAS object detection model on a 3LC object detection `Table`, with 3LC metrics collection enabled during training.

## Imports

In [None]:
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer, models
from super_gradients.training.transforms.transforms import (
    DetectionHorizontalFlip,
    DetectionHSV,
    DetectionPaddedRescale,
    DetectionRandomAffine,
    DetectionStandardize,
    DetectionTargetsFormatTransform,
)
from super_gradients.training.utils.collate_fn.detection_collate_fn import DetectionCollateFN
from torch.utils.data import DataLoader

import tlc
from tlc.integration.super_gradients.datasets.detection_dataset import DetectionDataset
from tlc.integration.super_gradients.hooks import TLCLoggingCallback
from tlc.integration.super_gradients.metrics_collectors import DetectionMetricsCollector

## Project Setup

In [None]:
# Placeholders: replace each `...` with the real 3LC project name and dataset name before running.
PROJECT_NAME, DATASET_NAME = ..., ...

In [None]:
# Placeholders: replace each `...` with the 3LC object detection Tables for training and validation.
train_table, val_table = ..., ...

In [None]:
# --- Class names -----------------------------------------------------------
# The label value map of the training table supplies the class names.
# NOTE(review): assumes the map's keys are the class ids used as training
# targets (0..num_classes-1) — confirm against the table schema.
simple_value_map = train_table.get_simple_value_map("bbs.bb_list.label")
if not isinstance(simple_value_map, dict):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise TypeError(
        "Expected a dict value map for 'bbs.bb_list.label', "
        f"got {type(simple_value_map).__name__}"
    )
# Order names by class id, not by dict insertion order, so that
# class_names[i] is always the name of class id i.
class_names = [simple_value_map[class_id] for class_id in sorted(simple_value_map)]
num_classes = len(class_names)

# Model input size shared by every size-dependent transform below, so the
# train and val pipelines cannot drift apart.
INPUT_DIM = (640, 640)
BATCH_SIZE = 4

# --- Datasets --------------------------------------------------------------
# Training pipeline: random augmentation, then rescale/standardize and
# convert targets to LABEL_CXCYWH format.
train_dataset = DetectionDataset(
    train_table,
    input_dim=INPUT_DIM,
    transforms=[
        # degrees=0.0 and shear=0.0: rotation and shear are disabled; only the
        # scale range is randomized.
        DetectionRandomAffine(
            degrees=0.0,
            scales=(0.5, 1.5),
            shear=0.0,
            target_size=INPUT_DIM,
            filter_box_candidates=False,
            border_value=128,
        ),
        DetectionHSV(prob=1.0, hgain=5, vgain=30, sgain=30),
        DetectionHorizontalFlip(prob=0.5),
        DetectionPaddedRescale(input_dim=INPUT_DIM),
        DetectionStandardize(max_value=255),
        DetectionTargetsFormatTransform(input_dim=INPUT_DIM, output_format="LABEL_CXCYWH"),
    ],
)
# Validation pipeline: deterministic transforms only (no augmentation).
val_dataset = DetectionDataset(
    val_table,
    input_dim=INPUT_DIM,
    transforms=[
        DetectionPaddedRescale(input_dim=INPUT_DIM, max_targets=300),
        DetectionStandardize(max_value=255),
        DetectionTargetsFormatTransform(input_dim=INPUT_DIM, output_format="LABEL_CXCYWH"),
    ],
)

# --- Data loaders ----------------------------------------------------------
# Only the training loader shuffles; both use the detection-specific collate fn.
train_dataloader: DataLoader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=DetectionCollateFN()
)
val_dataloader: DataLoader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=DetectionCollateFN()
)

# --- Model and trainer -----------------------------------------------------
# YOLO-NAS-S initialized from COCO-pretrained weights, with the head sized for
# this dataset's classes.
model = models.get(Models.YOLO_NAS_S, num_classes=num_classes, pretrained_weights="coco")

# Checkpoints for this experiment are written under ./checkpoints.
trainer = Trainer(
    experiment_name="fine-tune-yolo-nas-detect-3lc",
    ckpt_root_dir="checkpoints",
)

In [None]:
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback

# 3LC logging callback, configured with a detection metrics collector built
# from the training table; results are logged under PROJECT_NAME.
detection_metrics_collector = DetectionMetricsCollector(train_table)
tlc_callback = TLCLoggingCallback(
    project_name=PROJECT_NAME,
    metrics_collectors=[detection_metrics_collector],
)

# Training loss.
loss_fn = PPYoloELoss(use_static_assigner=False, num_classes=num_classes, reg_max=16)

# Post-processing of raw predictions (score threshold + NMS) used by the
# validation metric below.
nms_callback = PPYoloEPostPredictionCallback(
    score_threshold=0.01,
    nms_top_k=1000,
    max_predictions=300,
    nms_threshold=0.7,
)

# Validation metric: detection mAP at IoU 0.50, with per-class AP reported
# using `class_names`.
map50_metric = DetectionMetrics_050(
    score_thres=0.1,
    top_k_predictions=300,
    num_cls=num_classes,
    normalize_targets=True,
    include_classwise_ap=True,
    class_names=class_names,
    post_prediction_callback=nms_callback,
)

train_params = {
    # Learning-rate schedule: linear warmup from 1e-5 over 1 epoch, then cosine.
    "warmup_initial_lr": 1e-5,
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.5,
    # Optimizer: AdamW, no weight decay on biases / batch-norm parameters.
    "optimizer": "AdamW",
    "zero_weight_decay_on_bias_and_bn": True,
    "lr_warmup_epochs": 1,
    "warmup_mode": "LinearEpochLRWarmup",
    "optimizer_params": {"weight_decay": 0.0001},
    # EMA is disabled; ema_params is presumably ignored while "ema" is False.
    "ema": False,
    "average_best_models": False,
    "ema_params": {"beta": 25, "decay_type": "exp"},
    "max_epochs": 5,
    "mixed_precision": True,
    "loss": loss_fn,
    "valid_metrics_list": [map50_metric],
    # Name of the metric reported by DetectionMetrics_050 that the trainer tracks.
    "metric_to_watch": "mAP@0.50",
    "phase_callbacks": [tlc_callback],
}

# Initialize the 3LC run for this project, then start training.
run = tlc.init(project_name=PROJECT_NAME)

trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)