# DETR Knowledge Distillation on KITTI Dataset

This notebook demonstrates the complete pipeline for training a distilled DETR model on the KITTI dataset.

## Overview

1. **Setup**: Install dependencies and import libraries
2. **Configuration**: Load from YAML file
3. **Data Preparation**: Download and convert KITTI to COCO format
4. **Dataset Loading**: Create PyTorch datasets
5. **Model Setup**: Load teacher and student models
6. **Training**: Train with knowledge distillation
7. **Evaluation**: Evaluate with COCO metrics
8. **Visualization**: Visualize predictions

## 1. Setup and Installation

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
!pip install -q torch torchvision transformers pycocotools pillow tqdm pyyaml matplotlib opencv-python

In [3]:
import os
!git clone https://github.com/HenryNVP/object-detection.git
%cd object-detection
print(f"Working directory: {os.getcwd()}")

Cloning into 'object-detection'...
remote: Enumerating objects: 61, done.
remote: Counting objects: 100% (61/61), done.
remote: Compressing objects: 100% (47/47), done.
remote: Total 61 (delta 14), reused 57 (delta 10), pack-reused 0 (from 0)
Receiving objects: 100% (61/61), 5.68 MiB | 11.32 MiB/s, done.
Resolving deltas: 100% (14/14), done.
/content/object-detection
Working directory: /content/object-detection


## 2. Import Libraries

In [5]:
import sys
import json
import random
import yaml
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from tqdm.auto import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch.utils.data import DataLoader

from src.datasets.kitti_coco import build_kitti_coco_dataset, collate_fn
from src.models import build_teacher_student_models
from src.distillation import DistillationLoss, DistillationTrainer
from src.utils import get_device, seed_all

print("✓ All imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {get_device()}")

✓ All imports successful
PyTorch version: 2.8.0+cu126
Device: cuda


## 3. Load Configuration from YAML
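
The configuration cell reads nested keys from `configs/distillation.yaml`. As a rough sketch of the structure it expects, the helper below checks that every key path looked up in that cell is present. `REQUIRED_KEYS` is inferred from those lookups, `missing_keys` is a hypothetical helper, and the sample values are placeholders rather than the repository's actual settings.

```python
# Key paths read by the configuration cell (inferred from its lookups).
REQUIRED_KEYS = [
    ("data", "root"),
    ("data", "num_labels"),
    ("data", "batch_size"),
    ("data", "num_workers"),
    ("model", "teacher"),
    ("model", "student"),
    ("training", "learning_rate"),
    ("training", "weight_decay"),
    ("distillation", "temperature"),
    ("distillation", "alpha"),
]


def missing_keys(config, required=REQUIRED_KEYS):
    """Return the dotted key paths that are absent from a nested config dict."""
    missing = []
    for path in required:
        node = config
        for key in path:
            if not isinstance(node, dict) or key not in node:
                missing.append(".".join(path))
                break
            node = node[key]
    return missing


# Placeholder values for illustration only; the real ones live in the YAML file.
example_config = {
    "data": {"root": "./data", "num_labels": 8, "batch_size": 4, "num_workers": 2},
    "model": {"teacher": "facebook/detr-resnet-50", "student": "student-placeholder"},
    "training": {"learning_rate": 1e-4, "weight_decay": 1e-4},
    "distillation": {"temperature": 2.0},  # "alpha" deliberately left out
}
print(missing_keys(example_config))
```

Running a check like this right after `yaml.safe_load` turns a late `KeyError` into an explicit list of what the YAML file still needs.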

In [6]:
# Load configuration from YAML file
config_path = Path('configs/distillation.yaml')
with open(config_path) as f:
    config = yaml.safe_load(f)

CONFIG = {
    'kitti_root': './kitti_data/training',
    'data_root': config['data']['root'],
    'output_dir': './output/distillation_notebook',
    'num_labels': config['data']['num_labels'],
    'train_split': 0.8,
    'max_samples': 1000,
    'teacher_model': config['model']['teacher'],
    'student_model': config['model']['student'],
    'batch_size': config['data']['batch_size'],
    'num_workers': config['data']['num_workers'],
    'epochs': 3,
    'learning_rate': config['training']['learning_rate'],
    'weight_decay': config['training']['weight_decay'],
    'temperature': config['distillation']['temperature'],
    'alpha': config['distillation']['alpha'],
    'seed': 42,
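    'device': None,
}

# Seed the RNGs and resolve the device here so later cells can use `device`.
seed_all(CONFIG['seed'])
device = get_device(CONFIG['device'])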

## 4. Data Preparation

Download KITTI dataset and convert to COCO format if needed.
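
To make the conversion concrete, here is a minimal sketch of mapping one KITTI label line to a COCO annotation. KITTI stores the 2D box as left, top, right, bottom in pixels, while COCO expects `[x, y, width, height]`. The function name and the `CATEGORY_IDS` mapping are hypothetical; `scripts/prepare_kitti_coco.py` may use different ids and handle more fields.

```python
# Hypothetical category mapping; the repository's converter may use other ids.
CATEGORY_IDS = {"Car": 1, "Pedestrian": 2, "Cyclist": 3}


def kitti_line_to_coco(line, image_id, annotation_id, category_ids=CATEGORY_IDS):
    """Convert one KITTI label line to a COCO annotation dict, or None if skipped."""
    fields = line.split()
    name = fields[0]
    if name not in category_ids:  # e.g. "DontCare" regions
        return None
    # Fields 4-7 hold the 2D box as left, top, right, bottom in pixels.
    left, top, right, bottom = (float(v) for v in fields[4:8])
    width, height = right - left, bottom - top
    return {
        "id": annotation_id,
        "image_id": image_id,
        "category_id": category_ids[name],
        "bbox": [left, top, width, height],  # COCO uses [x, y, width, height]
        "area": width * height,
        "iscrowd": 0,
    }


ann = kitti_line_to_coco(
    "Car 0.00 0 -1.58 587.0 173.0 614.5 200.0 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59",
    image_id=0,
    annotation_id=1,
)
print(ann["bbox"], ann["area"])
```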


In [8]:
# Download KITTI
!python scripts/download_kitti.py --output-dir ./kitti_data

# Convert to COCO format
print("\nConverting to COCO format...")
!python scripts/prepare_kitti_coco.py \
    --kitti-root {CONFIG['kitti_root']} \
    --output-dir {CONFIG['data_root']} \
    --train-split {CONFIG['train_split']} \
    --max-samples {CONFIG['max_samples']}

print("\n✓ Dataset ready!")


KITTI Object Detection Dataset Downloader
--------------------------------------------------
Output directory: kitti_data
--------------------------------------------------

Downloading images...
data_object_image_2.zip:   1% 152M/12.6G [00:07<10:15, 20.2MB/s]
Traceback (most recent call last):
  File "/usr/lib/python3.12/urllib/request.py", line 268, in urlretrieve
    while block := fp.read(bs):
                   ^^^^^^^^^^^
  File "/usr/lib/python3.12/http/client.py", line 479, in read
    s = self.fp.read(amt)
        ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/socket.py", line 720, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/ssl.py", line 1251, in recv_into
    return self.read(nbytes, buffer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/ssl.py", line 1103, in read
    return self._sslobj.read(len, buffer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt

During handling of the abov

## 5. Load Datasets and Create Data Loaders
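
Detection batches cannot go through the default collate function, because images differ in size and each image carries a different number of boxes. A common pattern, sketched below, keeps the batch as tuples instead of stacking it. `detection_collate` is a hypothetical stand-in; the repository's `collate_fn` may do more.

```python
def detection_collate(batch):
    """Turn [(image, target), ...] into (images, targets) tuples without stacking."""
    images, targets = zip(*batch)
    return images, targets


# Stand-in samples: strings replace images, and box counts vary per sample.
batch = [
    ("img_a", {"boxes": [[0, 0, 10, 10]]}),
    ("img_b", {"boxes": [[5, 5, 20, 20], [1, 2, 3, 4]]}),
]
images, targets = detection_collate(batch)
print(images, [len(t["boxes"]) for t in targets])
```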


In [None]:
print("Loading datasets...")

# Define transforms to convert PIL images to tensors
import torchvision.transforms as T

def get_transform():
    """Basic transform to convert PIL images to tensors."""
    return T.Compose([
        T.ToTensor(),
    ])

# Note: We'll handle DETR-specific preprocessing in the trainer
train_dataset = build_kitti_coco_dataset(
    split='train',
    data_root=CONFIG['data_root'],
    transforms=None,  # We'll use image_processor in trainer
)

val_dataset = build_kitti_coco_dataset(
    split='val',
    data_root=CONFIG['data_root'],
    transforms=None,  # We'll use image_processor in trainer
)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    collate_fn=collate_fn,
)

print(f"✓ Train dataset: {len(train_dataset)} samples ({len(train_loader)} batches)")
print(f"✓ Val dataset: {len(val_dataset)} samples ({len(val_loader)} batches)")


## 6. Load Teacher and Student Models

**Note**: Using `facebook/detr-resnet-50` (official Facebook DETR model) as teacher.  
Creating a custom smaller student model for distillation with:
- **Smaller backbone**: ResNet-18 vs ResNet-50 (~11M vs ~25M params)
- **Fewer transformer layers**: 4 vs 6 (both encoder and decoder)
- **Fewer attention heads**: 4 vs 8 (changes per-head width, not the parameter count)
- **Smaller FFN dimension**: 1024 vs 2048
- **Result**: a substantially smaller student; the code cell in this section prints the measured reduction
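
As a back-of-the-envelope check on where the savings come from, the sketch below counts the parameters in the encoder and decoder layers alone. It assumes the standard transformer layer layout (four attention projections with biases, a two-layer FFN, LayerNorm after each sub-block) and ignores the backbone, embeddings, and prediction heads. The function names are illustrative.

```python
def attention_params(d_model):
    """Q, K, V and output projections, each a d_model x d_model weight plus bias."""
    return 4 * (d_model * d_model + d_model)


def ffn_params(d_model, ffn_dim):
    """Two linear layers: d_model -> ffn_dim -> d_model, with biases."""
    return (d_model * ffn_dim + ffn_dim) + (ffn_dim * d_model + d_model)


def layer_norm_params(d_model):
    """Scale and shift vectors."""
    return 2 * d_model


def transformer_params(d_model, ffn_dim, enc_layers, dec_layers):
    """Parameters in the encoder and decoder layers only (no backbone or heads)."""
    enc = (attention_params(d_model) + ffn_params(d_model, ffn_dim)
           + 2 * layer_norm_params(d_model))
    # A decoder layer has self-attention plus cross-attention, and a third LayerNorm.
    dec = (2 * attention_params(d_model) + ffn_params(d_model, ffn_dim)
           + 3 * layer_norm_params(d_model))
    return enc_layers * enc + dec_layers * dec


teacher = transformer_params(256, 2048, 6, 6)
student = transformer_params(256, 1024, 4, 4)
print(f"teacher layers: {teacher:,}")
print(f"student layers: {student:,}")
print(f"reduction: {1 - student / teacher:.1%}")
```

The head count does not appear in these formulas: splitting `d_model` across 4 or 8 heads leaves the projection matrices the same size. Under these assumptions the reduction comes from the layer count, the FFN width, and the smaller backbone.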


In [None]:
from transformers import DetrForObjectDetection, DetrImageProcessor, DetrConfig
from torchvision.models import resnet18, resnet50
import torch.nn as nn

print("Loading models...")

# Use facebook/detr-resnet-50 (available on HuggingFace)
teacher_model_name = "facebook/detr-resnet-50"
print(f"Teacher: {teacher_model_name} (ResNet-50 backbone)")

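# Load teacher model and image processor.
# Note: num_labels differs from the COCO checkpoint, so ignore_mismatched_sizes=True
# re-initializes the classification head with random weights. The backbone,
# transformer, and box head stay pretrained, but the class logits are not
# meaningful until the teacher has been fine-tuned on KITTI.
image_processor = DetrImageProcessor.from_pretrained(teacher_model_name)
teacher_model = DetrForObjectDetection.from_pretrained(
    teacher_model_name,
    num_labels=CONFIG['num_labels'],
    ignore_mismatched_sizes=True
)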
teacher_model = teacher_model.to(device)
teacher_model.eval()

# Freeze teacher
for param in teacher_model.parameters():
    param.requires_grad = False

print("✓ Teacher model loaded and frozen")

# Create smaller student model with ResNet-18 backbone
print("\nCreating smaller student model with ResNet-18 backbone...")

# Create DETR config for student with smaller dimensions
config_detr = DetrConfig.from_pretrained(teacher_model_name)
config_detr.num_labels = CONFIG['num_labels']

# Select the backbone through the config rather than swapping modules afterwards.
# Replacing `conv_encoder.model` with an nn.Sequential leaves the input projection
# sized for ResNet-50's 2048 channels and returns a single tensor where the encoder
# expects a list of feature maps. With use_timm_backbone=True (as in this
# checkpoint's config), DETR builds the backbone with timm and sizes the input
# projection from it (ResNet-18 outputs 512 channels vs 2048 for ResNet-50).
config_detr.backbone = "resnet18"
config_detr.use_pretrained_backbone = True  # ImageNet weights for the backbone

config_detr.d_model = 256  # Keep hidden dimension same
config_detr.encoder_attention_heads = 4  # Fewer attention heads (default 8)
config_detr.decoder_attention_heads = 4
config_detr.encoder_layers = 4  # Fewer encoder layers (default 6)
config_detr.decoder_layers = 4  # Fewer decoder layers (default 6)
config_detr.encoder_ffn_dim = 1024  # Smaller FFN (default 2048)
config_detr.decoder_ffn_dim = 1024

# Build the student from the config (transformer and heads start from random weights)
student_model = DetrForObjectDetection(config_detr)
student_model = student_model.to(device)

print("✓ Student model created with ResNet-18 backbone")

# Count parameters
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
student_trainable = sum(p.numel() for p in student_model.parameters() if p.requires_grad)

print(f"\n📊 Model Statistics:")
print(f"  Teacher parameters: {teacher_params:,} (ResNet-50 backbone)")
print(f"  Student parameters: {student_params:,} (ResNet-18 backbone, {student_trainable:,} trainable)")
print(f"  Compression ratio: {student_params / teacher_params:.2%}")
print(f"  Size reduction: {(1 - student_params / teacher_params):.1%}")
print(f"\nBackbone comparison:")
print(f"  Teacher backbone: ResNet-50 (~25M params)")
print(f"  Student backbone: ResNet-18 (~11M params)")


## 7. Setup Training with Distillation


In [None]:
# Setup optimizer
optimizer = torch.optim.AdamW(
    student_model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

# Setup distillation loss
distillation_loss = DistillationLoss(
    temperature=CONFIG['temperature'],
    alpha=CONFIG['alpha'],
)

# Create trainer with image_processor for PIL image handling
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    distillation_loss=distillation_loss,
    device=device,
    output_dir=CONFIG['output_dir'],
    image_processor=image_processor,  # Pass processor to handle PIL images
)

print("✓ Training setup complete")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Temperature: {CONFIG['temperature']}")
print(f"  Alpha: {CONFIG['alpha']}")


## 8. Train Model

Train the student model with knowledge distillation from the teacher.
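
Knowledge distillation trains the student on a blend of the teacher's softened predictions and the ordinary task loss. The sketch below shows the classic formulation in plain Python so the arithmetic is visible: a temperature-scaled softmax, a KL term scaled by T², and `alpha` as the mixing weight. It is an illustration under assumptions, not the repository's `DistillationLoss`: the helper names are hypothetical, `alpha` is assumed to weight the distillation term, and for DETR the logits would be compared per object query.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, task_loss, temperature, alpha):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * task_loss."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    return alpha * temperature ** 2 * kl + (1 - alpha) * task_loss


teacher_logits = [4.0, 1.0, 0.5]
student_logits = [2.0, 1.5, 0.5]

# A higher temperature flattens the teacher distribution.
for t in (1.0, 4.0):
    print(t, [round(p, 3) for p in softmax(teacher_logits, t)])

print(kd_loss(student_logits, teacher_logits, task_loss=1.2, temperature=2.0, alpha=0.5))
```

When the student matches the teacher exactly, the KL term vanishes and only the weighted task loss remains, which makes a convenient sanity check for any implementation.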


In [None]:
print(f"Starting training for {CONFIG['epochs']} epochs...")
print("=" * 70)

# Train
trainer.train(num_epochs=CONFIG['epochs'], save_every=1)

print("\n✓ Training complete!")
print(f"Checkpoints saved to: {CONFIG['output_dir']}")


## 9. Evaluate Model

Evaluate the trained student model on the validation set.
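
DETR's `pred_boxes` are normalized `[cx, cy, w, h]`, while COCO evaluation expects `[x, y, w, h]` in pixels with `(x, y)` at the top-left corner. The conversion used in the evaluation cell can be sketched on its own as follows; the function name is illustrative.

```python
def cxcywh_norm_to_coco(box, img_w, img_h):
    """Convert a normalized [cx, cy, w, h] box to COCO [x, y, w, h] in pixels."""
    cx, cy, w, h = box
    return [(cx - w / 2) * img_w, (cy - h / 2) * img_h, w * img_w, h * img_h]


# A box centered in a 1000 x 400 image, half as wide and half as tall as the image.
print(cxcywh_norm_to_coco([0.5, 0.5, 0.5, 0.5], img_w=1000, img_h=400))
```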


In [None]:
@torch.no_grad()
def evaluate_model(model, data_loader, device):
    """Evaluate model on validation set."""
    model.eval()
    predictions = []

    for images, targets in tqdm(data_loader, desc="Evaluating"):
        # The datasets yield PIL images (transforms=None), so let the notebook's
        # `image_processor` resize, normalize, and pad them into one batch.
        inputs = image_processor(images=list(images), return_tensors="pt").to(device)
        outputs = model(**inputs)

        for i, target in enumerate(targets):
            image_id = target['image_id'].item()
            logits = outputs.logits[i]
            boxes = outputs.pred_boxes[i]

            # Get predicted class and score
            scores = logits.softmax(-1)[:, :-1].max(-1)
            labels = scores.indices
            scores = scores.values

            # Filter low confidence predictions
            keep = scores > 0.3
            for box, score, label in zip(boxes[keep], scores[keep], labels[keep]):
                # Convert from normalized [cx, cy, w, h] to COCO [x, y, w, h]
                cx, cy, w, h = box.cpu().tolist()
                img_h, img_w = target['orig_size'].tolist()
                x = (cx - w/2) * img_w
                y = (cy - h/2) * img_h
                w = w * img_w
                h = h * img_h

                predictions.append({
                    'image_id': image_id,
                    'category_id': int(label.item()) + 1,
                    'bbox': [x, y, w, h],
                    'score': float(score.item()),
                })

    return predictions

print("Evaluating on validation set...")
predictions = evaluate_model(student_model, val_loader, device)
print(f"✓ Generated {len(predictions)} predictions")


In [None]:
# Run COCO evaluation
if len(predictions) > 0:
    print("\nRunning COCO evaluation...")
    ann_file = Path(CONFIG['data_root']) / 'annotations' / 'instances_val.json'

    coco_gt = COCO(str(ann_file))
    coco_dt = coco_gt.loadRes(predictions)
    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
else:
    print("\n⚠️ No predictions to evaluate")


## 10. Visualize Predictions

Visualize sample predictions from the trained model.


In [None]:
def visualize_predictions(model, dataset, device, num_samples=3):
    """Visualize predictions on random samples."""
    model.eval()
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 5))
    if num_samples == 1:
        axes = [axes]

    indices = random.sample(range(len(dataset)), num_samples)

    with torch.no_grad():
        for idx, ax in zip(indices, axes):
            image, target = dataset[idx]
            # The dataset yields PIL images (transforms=None); preprocess with the
            # notebook's `image_processor` instead of calling tensor methods on them.
            inputs = image_processor(images=image, return_tensors="pt").to(device)
            outputs = model(**inputs)

            logits = outputs.logits[0]
            boxes = outputs.pred_boxes[0]
            scores = logits.softmax(-1)[:, :-1].max(-1)
            labels = scores.indices
            scores = scores.values
            keep = scores > 0.5

            # Convert the original image to a writable RGB array for drawing
            img_np = np.array(image.convert("RGB"))
            h, w = img_np.shape[:2]

            # Draw bounding boxes
            for box, score in zip(boxes[keep], scores[keep]):
                cx, cy, bw, bh = box.cpu().numpy()
                x1 = int((cx - bw/2) * w)
                y1 = int((cy - bh/2) * h)
                x2 = int((cx + bw/2) * w)
                y2 = int((cy + bh/2) * h)
                cv2.rectangle(img_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(img_np, f"{score.item():.2f}", (x1, y1-5),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            ax.imshow(img_np)
            ax.set_title(f"Predictions (n={int(keep.sum())})")
            ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_predictions(student_model, val_dataset, device, num_samples=3)


## 11. Summary

Knowledge distillation training completed successfully!


In [None]:
print("="*70)
print("🎉 KNOWLEDGE DISTILLATION PIPELINE COMPLETE!")
print("="*70)
print("\n✓ Configuration loaded from YAML")
print("✓ Dataset prepared and loaded")
print("✓ Teacher and student models configured")
print("✓ Training completed with distillation")
print("✓ Model evaluated with COCO metrics")
print("✓ Predictions visualized")
print(f"\n📁 Output directory: {CONFIG['output_dir']}")
print("   - best.pth: Best model checkpoint")
print("   - epoch_*.pth: Epoch checkpoints")
print("\n🚀 Next Steps:")
print("   1. Train for more epochs (edit YAML config)")
print("   2. Tune hyperparameters in YAML")
print("   3. Try different model pairs")
print("   4. Deploy the model")


In [None]:
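# Alternate configuration cell with a smaller sample cap for a faster demo.
# It redefines CONFIG, seeds the RNGs, and sets `device`, so run it in place of
# the Section 3 cell, before any cell that references `device`.
# Load configuration from YAML file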
config_path = Path('configs/distillation.yaml')
with open(config_path) as f:
    config = yaml.safe_load(f)

# Create flattened CONFIG for easier access
CONFIG = {
    'kitti_root': './kitti_data/training',
    'data_root': config['data']['root'],
    'output_dir': './output/distillation_notebook',
    'num_labels': config['data']['num_labels'],
    'train_split': 0.8,
    'max_samples': 200,
    'teacher_model': config['model']['teacher'],
    'student_model': config['model']['student'],
    'batch_size': config['data']['batch_size'],
    'num_workers': config['data']['num_workers'],
    'epochs': 3,
    'learning_rate': config['training']['learning_rate'],
    'weight_decay': config['training']['weight_decay'],
    'temperature': config['distillation']['temperature'],
    'alpha': config['distillation']['alpha'],
    'seed': 42,
    'device': None,
}

print("📋 Configuration loaded from:", config_path)
print("\n🔧 Notebook overrides (for faster demo):")
print(f"  • epochs: {config['training']['epochs']} → {CONFIG['epochs']}")
print(f"  • max_samples: full dataset → {CONFIG['max_samples']}")
print(f"  • output_dir: {config['output_dir']} → {CONFIG['output_dir']}")

seed_all(CONFIG['seed'])
device = get_device(CONFIG['device'])
print(f"\n🖥️  Using device: {device}")

Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)

print("\n📊 Active Configuration:")
print(f"  Teacher: {CONFIG['teacher_model']}")
print(f"  Student: {CONFIG['student_model']}")
print(f"  Epochs: {CONFIG['epochs']}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Temperature: {CONFIG['temperature']}")
print(f"  Alpha: {CONFIG['alpha']}")