# Prep

Check the installed PyTorch version and whether a CUDA-capable GPU is visible.

In [1]:
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

2.6.0+cu126 True


# Load a model

First, decide whether to start from pretrained weights.

This depends largely on the size of the dataset: the smaller the dataset, the more the model relies on fine-tuning weights that were pretrained on a larger one (here, COCO).

In [2]:
pretrained = True

if pretrained:
    # Get pretrained weights
    checkpoint = torch.hub.load_state_dict_from_url(
                url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
                map_location='cpu',
                check_hash=True)

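    # Remove the class head: its shape is tied to COCO's classes and will not match num_classes
    del checkpoint["model"]["class_embed.weight"]
    del checkpoint["model"]["class_embed.bias"]

    # Save the modified checkpoint so it can be passed to main.py via --resume
    torch.save(checkpoint,
               'detr-r50_no-class-head.pth')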
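Deleting the `class_embed` entries matters because, in upstream DETR, the COCO model's class head is a linear layer with 92 outputs (`num_classes=91` plus one "no object" logit), which cannot be copied into a head sized for our dataset. The toy sketch below uses a hypothetical `TinyDetrHead` module as a stand-in for DETR to show what happens on load: with `strict=False`, the deleted keys are reported as missing, the shared weights are copied, and the new classifier keeps its fresh initialisation. It assumes `main.py` loads the `--resume` checkpoint with `strict=False`; check your copy, since a strict load would reject the missing keys.

```python
import torch
from torch import nn


class TinyDetrHead(nn.Module):
    """Hypothetical stand-in for DETR: a shared 'body' plus a class head."""

    def __init__(self, num_classes: int, hidden_dim: int = 8):
        super().__init__()
        self.body = nn.Linear(hidden_dim, hidden_dim)
        # DETR-style head: num_classes + 1 outputs (the extra logit is "no object")
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)


# "Pretrained" model with a COCO-sized head (91 classes -> 92 logits)
source_model = TinyDetrHead(num_classes=91)
state_dict = source_model.state_dict()

# Same surgery as in the cell above
del state_dict["class_embed.weight"]
del state_dict["class_embed.bias"]

# Fresh model sized for our dataset; the class head stays randomly initialised
model = TinyDetrHead(num_classes=2)
result = model.load_state_dict(state_dict, strict=False)

print(result.missing_keys)     # ['class_embed.weight', 'class_embed.bias']
print(result.unexpected_keys)  # []
print(torch.equal(model.body.weight, source_model.body.weight))  # True
```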

# Dataset

Our dataset must be in COCO format.

This lets `main.py` load the images and annotations through pycocotools.
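A quick sanity check of the folder layout can save a failed training run. This sketch assumes the file naming used by upstream DETR's `datasets/coco.py` (`train2017/`, `val2017/`, and `annotations/instances_train2017.json` / `annotations/instances_val2017.json`); `check_coco_layout` is a hypothetical helper, not part of the repository.

```python
import json
from pathlib import Path


def check_coco_layout(root):
    """Return a list of problems found in a COCO-style dataset folder (empty if none).

    Assumes upstream DETR naming: train2017/, val2017/ and
    annotations/instances_{train,val}2017.json.
    """
    root = Path(root)
    problems = []
    for split in ("train2017", "val2017"):
        img_dir = root / split
        ann_file = root / "annotations" / f"instances_{split}.json"
        if not img_dir.is_dir():
            problems.append(f"missing image folder: {img_dir}")
        if not ann_file.is_file():
            problems.append(f"missing annotation file: {ann_file}")
            continue
        with ann_file.open() as f:
            data = json.load(f)
        # A COCO detection file normally carries these three top-level keys
        for key in ("images", "annotations", "categories"):
            if key not in data:
                problems.append(f"{ann_file.name} has no '{key}' key")
    return problems
```

Once `dataDir` is set, `check_coco_layout(dataDir)` should return an empty list.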

In [8]:
dataset_file = "coco" # alternatively, implement your own COCO-style dataset loader in datasets/ and register its key in datasets/__init__.py

dataDir='datasets/ship_playground_coco' # must contain train2017/ and val2017/ image folders plus an annotations/ folder
num_classes = 2 # max category id + 1; DETR adds its own extra "no object" class internally, so don't count it here

outDir = 'outputs'
resume = "detr-r50_no-class-head.pth" if pretrained else ""
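Upstream DETR's source notes that `num_classes` really means the maximum category id plus one (COCO's largest id is 90, so it passes 91), and the "no object" logit is added on top of that. The hypothetical helper below derives the value from a COCO annotation dict; the single `"ship"` category with id 1 is an illustrative assumption about this dataset, which under that rule gives the `num_classes = 2` used above.

```python
import json


def detr_num_classes(coco_dict):
    """Hypothetical helper: DETR-style num_classes = max category id + 1.

    The "no object" logit is added by DETR itself, so it is not counted here.
    """
    max_id = max(cat["id"] for cat in coco_dict["categories"])
    return max_id + 1


# Illustrative example: one class whose id is 1 (id 0 is then an unused dummy)
example = {"categories": [{"id": 1, "name": "ship"}]}
print(detr_num_classes(example))  # 2
```

For the real dataset, pass it the parsed training annotations, for example `detr_num_classes(json.load(open(f"{dataDir}/annotations/instances_train2017.json")))`.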

# Training

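We train with the repository's `main.py` script, fine-tuning for a single epoch with a learning rate of 1e-5 for most of the model and 1e-6 for the CNN backbone.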

In [None]:
!python main.py \
  --dataset_file $dataset_file \
  --coco_path $dataDir \
  --output_dir $outDir \
  --resume "$resume" \
  --num_classes $num_classes \
  --lr 1e-5 \
  --lr_backbone 1e-6 \
  --epochs 1
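In upstream DETR's `main.py`, `--lr` and `--lr_backbone` become two AdamW parameter groups: parameters whose name contains `backbone` use `lr_backbone`, the rest use `lr`, with a default weight decay of 1e-4. The sketch mirrors that split on a hypothetical two-layer `toy` model; it illustrates the idea and is not the script's own code.

```python
import torch
from torch import nn

# Hypothetical toy model: parameter names mimic DETR's "backbone.*" prefix
toy = nn.ModuleDict({
    "backbone": nn.Linear(4, 4),
    "transformer": nn.Linear(4, 4),
})

lr, lr_backbone = 1e-5, 1e-6  # the values passed on the command line above

param_groups = [
    # Everything outside the backbone uses the default lr
    {"params": [p for n, p in toy.named_parameters()
                if "backbone" not in n and p.requires_grad]},
    # Backbone parameters get their own, smaller learning rate
    {"params": [p for n, p in toy.named_parameters()
                if "backbone" in n and p.requires_grad],
     "lr": lr_backbone},
]
optimizer = torch.optim.AdamW(param_groups, lr=lr, weight_decay=1e-4)

for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])
# 2 1e-05
# 2 1e-06
```

The much smaller backbone rate keeps the pretrained CNN features largely intact while the detection head adapts to the new classes.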

# Results

A quick overview of the training curves, plotted from the log that `main.py` writes to `outDir`.

In [5]:
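from pathlib import Path

from util.plot_utils import plot_logs

log_directory = [Path(outDir)]

Note: `util.plot_utils` depends on pandas (and, in upstream DETR, seaborn and matplotlib). If this cell fails with `ModuleNotFoundError: No module named 'pandas'`, install the missing packages first (e.g. `!pip install pandas seaborn`) and re-run the cell.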
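If installing pandas is not an option, the per-epoch numbers can still be read directly. This sketch assumes the log format of upstream DETR's `main.py`: `log.txt` in the output directory, one JSON object per epoch, keys prefixed with `train_` / `test_`, and `test_coco_eval_bbox` holding the COCOeval statistics, whose first two entries are AP@[.5:.95] and AP@.5. `read_detr_log` is a hypothetical helper.

```python
import json
from pathlib import Path


def read_detr_log(log_path):
    """Parse a DETR-style log.txt (one JSON object per epoch) without pandas.

    Returns one dict per epoch with the train/test loss and, when present,
    the first two COCO bbox stats: AP@[.5:.95] and AP@.5.
    """
    rows = []
    with Path(log_path).open() as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            entry = json.loads(line)
            row = {
                "epoch": entry.get("epoch"),
                "train_loss": entry.get("train_loss"),
                "test_loss": entry.get("test_loss"),
            }
            stats = entry.get("test_coco_eval_bbox")
            if stats:
                row["AP"], row["AP50"] = stats[0], stats[1]
            rows.append(row)
    return rows
```

Usage: `for row in read_detr_log(Path(outDir) / "log.txt"): print(row)`.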

In [None]:
fields_of_interest = (
    'loss',
    'mAP',
    )

plot_logs(log_directory,
          fields_of_interest)

In [None]:
fields_of_interest = (
    'loss_ce',
    'loss_bbox',
    'loss_giou',
    )

plot_logs(log_directory,
          fields_of_interest)
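These three curves are the terms of DETR's set-prediction loss: classification (`loss_ce`), L1 distance between normalised boxes (`loss_bbox`) and generalised IoU (`loss_giou`). Assuming this fork keeps upstream DETR's logging in `engine.py`, the plotted values are already multiplied by their weights (defaults 1, 5 and 2 via `--bbox_loss_coef` and `--giou_loss_coef`), the raw values are logged with an `_unscaled` suffix, and `loss` is the sum of the weighted terms over the last decoder layer plus, with auxiliary losses on (the default), each intermediate layer (suffixes `_0`, `_1`, ...). The hypothetical helper below rebuilds that sum from the raw terms; adjust `WEIGHTS` if you changed the coefficients.

```python
# Upstream DETR defaults: loss_ce weight 1, --bbox_loss_coef 5, --giou_loss_coef 2
WEIGHTS = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}


def weighted_detr_loss(metrics, weights=WEIGHTS):
    """Rebuild the total loss from the raw '*_unscaled' terms of one log entry.

    Keys are expected without the 'train_'/'test_' prefix. Auxiliary
    decoder-layer terms (e.g. 'loss_ce_0_unscaled') share their base weight.
    """
    total = 0.0
    for name, value in metrics.items():
        if not name.endswith("_unscaled"):
            continue  # plain entries are already weighted; skip to avoid double counting
        base = name[: -len("_unscaled")]
        head, _, tail = base.rpartition("_")
        if tail.isdigit():  # strip the decoder-layer index
            base = head
        if base in weights:  # e.g. cardinality_error has no weight and is ignored
            total += weights[base] * value
    return total


raw = {"loss_ce_unscaled": 0.5, "loss_bbox_unscaled": 0.1,
       "loss_giou_unscaled": 0.25, "cardinality_error_unscaled": 3.0}
print(weighted_detr_loss(raw))  # 1.5
```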

In [None]:
fields_of_interest = (
    'class_error',
    'cardinality_error_unscaled',
    )

plot_logs(log_directory,
          fields_of_interest)   
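Neither metric contributes gradients; upstream DETR logs them for monitoring. `class_error` is 100 minus the top-1 accuracy (in %) on the queries that the Hungarian matcher paired with ground-truth boxes, and `cardinality_error` is the mean absolute difference between the number of queries whose highest-scoring class is not "no object" (the last logit index, i.e. `num_classes`) and the number of ground-truth boxes per image. The sketch below computes both from plain Python lists, assuming the matching has already been done; the function names are hypothetical.

```python
def class_error(pred_labels, true_labels):
    """100 minus top-1 accuracy (%) over queries already matched to ground-truth boxes."""
    correct = sum(p == t for p, t in zip(pred_labels, true_labels))
    return 100.0 - 100.0 * correct / len(true_labels)


def cardinality_error(argmax_per_image, num_targets_per_image, no_object_id):
    """Mean |#queries not predicted as 'no object' - #ground-truth boxes| over images."""
    errors = []
    for argmaxes, n_targets in zip(argmax_per_image, num_targets_per_image):
        n_pred = sum(label != no_object_id for label in argmaxes)
        errors.append(abs(n_pred - n_targets))
    return sum(errors) / len(errors)


# With num_classes = 2 the logits are indexed 0, 1, 2 and index 2 means "no object"
print(class_error([1, 1, 0, 1], [1, 1, 1, 1]))  # 25.0
print(cardinality_error([[1, 2, 2, 1], [2, 2, 2, 2]], [1, 0], no_object_id=2))  # 0.5
```

A cardinality error that stays high while the losses fall usually means the model is predicting too many or too few boxes per image.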