# How to Train RF-DETR Object Detection on a Custom Dataset

In [None]:
# Run nvidia-smi in a shell to display the GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
import numpy as np
import supervision as sv

from PIL import Image

from rfdetr import RFDETRMedium
from rfdetr.util.coco_classes import COCO_CLASSES

# Class names of the custom dataset, looked up by predicted class id in the
# final inference cell of this notebook.
CUSTOM_CLASSES = ['pen', 'Pen']

# Baseline run: the model is constructed without custom weights, so its
# predictions are labelled with COCO_CLASSES rather than CUSTOM_CLASSES.
image = Image.open("dataset/valid/00cd18c5aff4548b_jpg.rf.bd1663cdc9106926c3d3507ce963788f.jpg")

model = RFDETRMedium(resolution=640)
model.optimize_for_inference()

detections = model.predict(image, threshold=0.5)

# Colour palette shared by the box and label annotators.
palette_hex = [
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
]
color = sv.ColorPalette.from_hex(palette_hex)

# Text scale and line thickness are derived from the image size.
image_wh = image.size
text_scale = sv.calculate_optimal_text_scale(resolution_wh=image_wh)
thickness = sv.calculate_optimal_line_thickness(resolution_wh=image_wh)

bbox_annotator = sv.BoxAnnotator(color=color, thickness=thickness)
label_style = dict(
    color=color,
    text_color=sv.Color.BLACK,
    text_scale=text_scale,
    smart_position=True,
)
label_annotator = sv.LabelAnnotator(**label_style)

# One "<class name> <confidence>" caption per detection.
labels = []
for class_id, confidence in zip(detections.class_id, detections.confidence):
    labels.append(f"{COCO_CLASSES[class_id]} {confidence:.2f}")

# Annotate a copy so the original image stays untouched: boxes first, then labels.
annotated_image = label_annotator.annotate(
    bbox_annotator.annotate(image.copy(), detections),
    detections,
    labels,
)
annotated_image.thumbnail((800, 800))  # shrink in place for display
annotated_image

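The training cell below reads the dataset from the `dataset` directory, which is laid out as follows (each split has its own `_annotations.coco.json`):

```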
└── dataset
    ├── test
    │   ├── 0085364b2034b946_jpg.rf.70ae3ded7dadaec0f83fa75a1ca97a1b.jpg
    │   ├── 00cd18c5aff4548b_jpg.rf.bd1663cdc9106926c3d3507ce963788f.jpg
    │   └── _annotations.coco.json
    ├── train
    │   ├── 000fcbfa875b9eb2_jpg.rf.2fa47d1e61228c5f856a69c1c51e46ac.jpg
    │   ├── 000fcbfa875b9eb2_jpg.rf.3ab8fc8b5e049dad298b5f84e46d12cc.jpg
    │   └── _annotations.coco.json
    └── valid
        ├── 0085364b2034b946_jpg.rf.70ae3ded7dadaec0f83fa75a1ca97a1b.jpg
        ├── 00cd18c5aff4548b_jpg.rf.bd1663cdc9106926c3d3507ce963788f.jpg
        └── _annotations.coco.json

```

In [None]:
from rfdetr import RFDETRMedium

# Training settings collected in one place so they are easy to tune.
# NOTE(review): batch_size=8 with grad_accum_steps=2 presumably gives an
# effective batch of 16 — confirm against the rfdetr training docs.
train_config = {
    "dataset_dir": "dataset",
    "epochs": 10,
    "batch_size": 8,
    "grad_accum_steps": 2,
}

# Fresh model instance, fine-tuned on the dataset directory shown above.
model = RFDETRMedium()
model.train(**train_config)

In [None]:
from PIL import Image

# Display the metrics plot stored under output/
# (presumably written by the training run above — confirm).
metrics_plot_path = "output/metrics_plot.png"
Image.open(metrics_plot_path)

# Evaluate on the test set (mean average precision)

In [None]:
import supervision as sv

# Load the test split: images and the COCO annotation file live in the same folder.
test_dir = "dataset/test"
ds = sv.DetectionDataset.from_coco(
    images_directory_path=test_dir,
    annotations_path=f"{test_dir}/_annotations.coco.json",
)

In [None]:
import supervision as sv
from tqdm import tqdm
from supervision.metrics import MeanAveragePrecision

# Ground-truth annotations and model predictions, kept in matching order.
targets, predictions = [], []

# The image yielded by the dataset is not used; the file is re-opened with PIL
# from its path before being passed to the model.
for path, dataset_image, annotations in tqdm(ds):
    image = Image.open(path)
    # threshold=0 applies no confidence cut-off to the predictions.
    detections = model.predict(image, threshold=0)

    predictions.append(detections)
    targets.append(annotations)

In [None]:
# Feed all predictions/targets into the metric, then compute and print the result.
map_metric = MeanAveragePrecision()
map_metric.update(predictions, targets)
map_result = map_metric.compute()
print(map_result)

# Inference

In [None]:
# Load the checkpoint stored under output/ and run it on one validation image;
# detections are labelled with CUSTOM_CLASSES (defined in the first code cell).
model = RFDETRMedium(pretrain_weights="output/checkpoint_best_total.pth")
model.optimize_for_inference()

image = Image.open("dataset/valid/00cd18c5aff4548b_jpg.rf.bd1663cdc9106926c3d3507ce963788f.jpg")

detections = model.predict(image, threshold=0.5)

print(detections)

# Colour palette shared by the box and label annotators.
palette_hex = [
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
]
color = sv.ColorPalette.from_hex(palette_hex)

# Text scale and line thickness are derived from the image size.
image_wh = image.size
text_scale = sv.calculate_optimal_text_scale(resolution_wh=image_wh)
thickness = sv.calculate_optimal_line_thickness(resolution_wh=image_wh)

bbox_annotator = sv.BoxAnnotator(color=color, thickness=thickness)
label_style = dict(
    color=color,
    text_color=sv.Color.BLACK,
    text_scale=text_scale,
    smart_position=True,
)
label_annotator = sv.LabelAnnotator(**label_style)

# One "<class name> <confidence>" caption per detection.
labels = []
for class_id, confidence in zip(detections.class_id, detections.confidence):
    labels.append(f"{CUSTOM_CLASSES[class_id]} {confidence:.2f}")

# Annotate a copy so the original image stays untouched: boxes first, then labels.
annotated_image = label_annotator.annotate(
    bbox_annotator.annotate(image.copy(), detections),
    detections,
    labels,
)
annotated_image.thumbnail((800, 800))  # shrink in place for display
annotated_image