# DINO-DETR Knowledge Distillation on KITTI Dataset

This notebook demonstrates the complete pipeline for training a distilled DINO-DETR model on the KITTI dataset.

## Overview

1. **Setup**: Install dependencies and import libraries
2. **Data Preparation**: Download and convert KITTI to COCO format
3. **Dataset Loading**: Create PyTorch datasets
4. **Model Setup**: Load teacher and student models
5. **Training**: Train with knowledge distillation
6. **Evaluation**: Evaluate with COCO metrics
7. **Visualization**: Visualize predictions

---


## 1. Setup and Installation


In [None]:
# Detect the runtime: the `google.colab` package is importable only inside Colab.
try:
    import google.colab  # noqa: F401 — imported only as an environment probe
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    # Narrow except: a bare `except:` would also hide KeyboardInterrupt and real bugs.
    IN_COLAB = False
    print("Running locally")

# Mount Google Drive if in Colab
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/gdrive')


In [None]:
# Install dependencies
!pip install -q torch torchvision transformers pycocotools pillow tqdm pyyaml matplotlib opencv-python


In [None]:
# Setup working directory
import os
if IN_COLAB:
    # Clone repository if needed
    if not os.path.exists('object-detection'):
        !git clone https://github.com/your-repo/object-detection.git
        %cd object-detection
else:
    # Navigate to repository root if in notebooks folder
    if os.path.basename(os.getcwd()) == 'notebooks':
        os.chdir('..')
print(f"Working directory: {os.getcwd()}")


## 2. Import Libraries


In [None]:
import sys
import json
import random
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from tqdm.auto import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch.utils.data import DataLoader

# Import project modules
from src.datasets import build_kitti_coco_dataset, collate_fn
from src.models import build_teacher_student_models
from src.distillation import DistillationLoss, DistillationTrainer
from src.utils import get_device, seed_all

# Environment report: confirms the project modules imported cleanly and shows the runtime.
print("✓ All imports successful")
print("PyTorch version:", torch.__version__)
print("Device:", get_device())


## 3. Configuration


In [None]:
# Central configuration — every value a reader may want to tune lives in CONFIG.
CONFIG = dict(
    # Paths
    kitti_root='./kitti_data/training',          # KITTI root; downloaded in the next cell if missing
    data_root='./kitti_coco',                    # output dir of the COCO conversion step
    output_dir='./output/distillation_notebook',

    # Data
    num_labels=4,       # car, person, bicycle + background
                        # NOTE(review): DETR-style heads usually add their own
                        # "no-object" class — confirm "background" belongs in this count.
    train_split=0.8,
    max_samples=200,    # Use subset for faster training

    # Models
    # NOTE: teacher and student name the same checkpoint, so these defaults do
    # not compress the model; pick a smaller student for real distillation.
    teacher_model='IDEA-Research/dino-detr-resnet-50',
    student_model='IDEA-Research/dino-detr-resnet-50',

    # Training
    batch_size=2,
    num_workers=2,
    epochs=3,
    learning_rate=1e-4,
    weight_decay=1e-4,

    # Distillation
    temperature=2.0,
    alpha=0.5,

    # Other
    seed=42,
    device=None,  # auto-detect
)

# Seed RNGs via the project helper before anything stochastic runs.
seed_all(CONFIG['seed'])

device = get_device(CONFIG['device'])
print(f"Using device: {device}")

Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)

print("\n📋 Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CONFIG.items()))


## 4. Data Preparation

Download KITTI (if missing) and convert it to COCO format (if missing).


In [None]:
# Download KITTI if not exists
if not Path(CONFIG['kitti_root']).exists():
    print("Downloading KITTI dataset...")
    !python scripts/download_kitti.py --output-dir ./kitti_data
else:
    print("✓ KITTI dataset already exists")

# Convert to COCO format if not exists
if not Path(CONFIG['data_root']).exists():
    print("\nConverting to COCO format...")
    !python scripts/prepare_kitti_coco.py \
        --kitti-root {CONFIG['kitti_root']} \
        --output-dir {CONFIG['data_root']} \
        --train-split {CONFIG['train_split']} \
        --max-samples {CONFIG['max_samples']}
else:
    print("✓ COCO format dataset already exists")

print("\n✓ Dataset ready!")


## 5. Load Datasets and Create Data Loaders


In [None]:
print("Loading datasets...")

# Both splits are read from the same COCO-format root, with no extra transforms.
train_dataset, val_dataset = (
    build_kitti_coco_dataset(split=split_name, data_root=CONFIG['data_root'], transforms=None)
    for split_name in ('train', 'val')
)


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook's batching settings."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=shuffle,
        num_workers=CONFIG['num_workers'],
        collate_fn=collate_fn,
    )


# Only the training loader shuffles; validation order stays fixed.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)

print(f"✓ Train dataset: {len(train_dataset)} samples ({len(train_loader)} batches)")
print(f"✓ Val dataset: {len(val_dataset)} samples ({len(val_loader)} batches)")


## 6. Load Teacher and Student Models


In [None]:
print("Loading models...")
print(f"Teacher: {CONFIG['teacher_model']}\nStudent: {CONFIG['student_model']}")

teacher_model, student_model, image_processor = build_teacher_student_models(
    teacher_model_name=CONFIG['teacher_model'],
    student_model_name=CONFIG['student_model'],
    num_labels=CONFIG['num_labels'],
    device=device,
)


def _count_params(module, trainable_only=False):
    """Total element count of ``module``'s parameters, optionally only those with requires_grad."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)


teacher_params = _count_params(teacher_model)
student_params = _count_params(student_model)
student_trainable = _count_params(student_model, trainable_only=True)

print("\n✓ Models loaded")
print(f"  Teacher parameters: {teacher_params:,}")
print(f"  Student parameters: {student_params:,} ({student_trainable:,} trainable)")


## 7. Setup Training with Distillation


In [None]:
# Only the student's parameters are handed to this optimizer; the teacher is not updated by it.
optimizer = torch.optim.AdamW(
    student_model.parameters(),
    weight_decay=CONFIG['weight_decay'],
    lr=CONFIG['learning_rate'],
)

# Project distillation loss, configured from CONFIG.
distillation_loss = DistillationLoss(alpha=CONFIG['alpha'], temperature=CONFIG['temperature'])

# The trainer receives every component explicitly: models, loaders, optimizer, loss, device, output dir.
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    distillation_loss=distillation_loss,
    device=device,
    output_dir=CONFIG['output_dir'],
)

print("✓ Training setup complete")
print(
    f"  Learning rate: {CONFIG['learning_rate']}\n"
    f"  Temperature: {CONFIG['temperature']}\n"
    f"  Alpha: {CONFIG['alpha']}"
)


## 8. Train Model

Train the student model with knowledge distillation from the teacher.


In [None]:
num_epochs = CONFIG['epochs']
print(f"Starting training for {num_epochs} epochs...")
print("=" * 70)

# save_every=1 presumably checkpoints after every epoch (trainer is project code — confirm).
trainer.train(num_epochs=num_epochs, save_every=1)

print("\n✓ Training complete!")
print(f"Checkpoints saved to: {CONFIG['output_dir']}")


## 9. Evaluate Model

Evaluate the trained student model on the validation set.


In [None]:
@torch.no_grad()
def evaluate_model(model, data_loader, device, score_threshold=0.3):
    """Run ``model`` over ``data_loader`` and collect COCO-format detection results.

    Args:
        model: Detection model called as ``model(images)`` with a list of image
            tensors; each per-image output must expose ``'logits'`` (queries x
            classes, last column = no-object) and ``'pred_boxes'`` (normalized
            cx, cy, w, h). NOTE(review): confirm against the wrapper returned by
            ``build_teacher_student_models``.
        data_loader: Yields ``(images, targets)``; each target dict holds a
            scalar ``'image_id'`` tensor.
        device: Device the images are moved to before the forward pass.
        score_threshold: Minimum class confidence for a detection to be kept.
            Lower values (e.g. 0.0) give COCO AP the full precision/recall curve.

    Returns:
        List of dicts with ``image_id``, ``category_id``, ``bbox`` and ``score``;
        ``bbox`` is ``[x_min, y_min, width, height]`` in absolute pixels, the
        layout ``COCO.loadRes`` expects.
    """
    model.eval()
    predictions = []

    for images, targets in tqdm(data_loader, desc="Evaluating"):
        images = [img.to(device) for img in images]
        outputs = model(images)

        for img, output, target in zip(images, outputs, targets):
            image_id = target['image_id'].item()
            # Pixel size the normalized boxes are scaled to. Assumes the dataset
            # does not resize images (transforms=None here) — TODO confirm.
            img_h, img_w = img.shape[-2:]

            probs = output['logits'].softmax(-1)[:, :-1]  # drop the no-object column
            scores, labels = probs.max(-1)
            keep = scores > score_threshold

            for box, score, label in zip(output['pred_boxes'][keep], scores[keep], labels[keep]):
                # Normalized (cx, cy, w, h) -> absolute COCO (x_min, y_min, w, h).
                cx, cy, bw, bh = box.tolist()
                predictions.append({
                    'image_id': image_id,
                    'category_id': int(label.item()) + 1,
                    'bbox': [
                        (cx - bw / 2) * img_w,
                        (cy - bh / 2) * img_h,
                        bw * img_w,
                        bh * img_h,
                    ],
                    'score': float(score.item()),
                })

    return predictions

print("Evaluating on validation set...")
# Collect the student's detections on the validation split for COCO scoring below.
predictions = evaluate_model(model=student_model, data_loader=val_loader, device=device)
print(f"✓ Generated {len(predictions)} predictions")


In [None]:
# Score the student's detections against the val annotations with COCO bbox metrics.
# The guard matters: pycocotools' loadRes indexes the first result, so an empty
# list would fail.
if not predictions:
    print("\n⚠️ No predictions to evaluate")
else:
    print("\nRunning COCO evaluation...")
    ann_file = Path(CONFIG['data_root']) / 'annotations' / 'instances_val.json'

    coco_gt = COCO(str(ann_file))
    coco_dt = coco_gt.loadRes(predictions)
    coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
    # evaluate -> accumulate -> summarize must run in this order.
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()


## 10. Visualize Predictions

Visualize sample predictions from the trained model.


In [None]:
def visualize_predictions(model, dataset, device, num_samples=3, score_threshold=0.5):
    """Draw the model's confident detections on randomly chosen dataset images.

    Args:
        model: Detection model called on a 1-image batch; element ``[0]`` of the
            result must expose ``'logits'`` (last column = no-object) and
            ``'pred_boxes'`` (normalized cx, cy, w, h). NOTE(review): confirm
            against the wrapper returned by ``build_teacher_student_models``.
        dataset: Indexable dataset yielding ``(image, target)`` where ``image``
            is a (C, H, W) tensor, assumed to hold values in [0, 1].
        device: Device used for the forward pass.
        num_samples: Number of images to show; capped at ``len(dataset)``.
        score_threshold: Minimum class confidence for a box to be drawn.
    """
    # random.sample raises if asked for more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("⚠️ Dataset is empty; nothing to visualize")
        return

    model.eval()
    # squeeze=False always yields a 2-D grid of Axes, even for a single panel.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 5), squeeze=False)
    indices = random.sample(range(len(dataset)), num_samples)

    with torch.no_grad():
        for idx, ax in zip(indices, axes[0]):
            image, _target = dataset[idx]
            outputs = model(image.unsqueeze(0).to(device))[0]

            probs = outputs['logits'].softmax(-1)[:, :-1]  # drop the no-object column
            scores = probs.max(-1).values
            keep = scores > score_threshold

            # permute() gives a non-contiguous view, and OpenCV drawing needs a
            # C-contiguous uint8 buffer. Clip first so out-of-range values do not
            # wrap around in the uint8 cast.
            img_np = image.permute(1, 2, 0).cpu().numpy()
            img_np = np.ascontiguousarray(np.clip(img_np * 255, 0, 255).astype(np.uint8))
            h, w = img_np.shape[:2]

            for box, score in zip(outputs['pred_boxes'][keep], scores[keep]):
                cx, cy, bw, bh = box.cpu().tolist()
                x1, y1 = int((cx - bw / 2) * w), int((cy - bh / 2) * h)
                x2, y2 = int((cx + bw / 2) * w), int((cy + bh / 2) * h)
                cv2.rectangle(img_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Keep the label inside the image when the box touches the top edge.
                cv2.putText(img_np, f"{score.item():.2f}", (x1, max(y1 - 5, 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            ax.imshow(img_np)
            ax.set_title(f"Predictions (n={int(keep.sum())})")
            ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_predictions(model=student_model, dataset=val_dataset, device=device, num_samples=3)  # three random val images


## 11. Summary

A recap of the steps this notebook ran, the checkpoints found in the output directory, and suggested next steps.


In [None]:
# Recap of the run. The evaluation line and checkpoint list reflect what actually
# happened above rather than fixed claims.
checkpoint_files = sorted(Path(CONFIG['output_dir']).glob('*.pth'))

print("="*70)
print("🎉 KNOWLEDGE DISTILLATION PIPELINE COMPLETE!")
print("="*70)
print("\n✓ Dataset prepared and loaded")
print("✓ Teacher and student models configured")
print("✓ Training completed with distillation")
print("✓ Model evaluated with COCO metrics" if predictions else "⚠️ COCO evaluation skipped (no predictions)")
print("✓ Predictions visualized")
print(f"\n📁 Output directory: {CONFIG['output_dir']}")
if checkpoint_files:
    for ckpt in checkpoint_files:
        print(f"   - {ckpt.name}")
else:
    print("   (no .pth checkpoints found)")
print("\n🚀 Next Steps:")
print("   1. Train for more epochs")
print("   2. Tune hyperparameters")
print("   3. Try different model pairs")
print("   4. Deploy the model")
