## Overview

1. **Setup**: Install dependencies and import libraries
2. **Data Preparation**: Download and convert KITTI to COCO format
3. **Dataset Loading**: Create PyTorch datasets
4. **Model Setup**: Load teacher and student models
5. **Training**: Train with knowledge distillation
6. **Evaluation**: Evaluate with COCO metrics
7. **Visualization**: Visualize predictions


## 1. Setup

In [None]:
# Mount Google Drive at /content/gdrive (uses the Colab-provided `google.colab` package).
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Install the third-party packages imported by this notebook.
# NOTE(review): pandas is used later (`pd.DataFrame` in `compare_results`) but is not
# listed here — confirm the runtime already provides it, or add it to this line.
!pip install -q torch torchvision transformers pycocotools pillow tqdm pyyaml matplotlib opencv-python ultralytics

In [None]:
import json
import random
import shutil
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from ultralytics import YOLO

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Central settings shared by the data-preparation, training and evaluation cells.
CONFIG = {
    # --- Paths ---
    # KITTI training folder (presumably the one holding image_2/ and label_2/).
    'kitti_root': './kitti_data/training',
    'data_root': './kitti_yolo',
    # Passed as `project` to the YOLO training runs; the comparison plot is saved here too.
    'output_dir': './output/',

    # --- Dataset ---
    'num_labels': 6,  # person, car, truck, bicycle, train, other (see KITTI_CATEGORIES)
    'train_split': 0.7,   # fraction of images used for training
    'val_split': 0.15,    # fraction used for validation; the remainder becomes the test split
    'max_samples': None,  # optional cap on the number of images converted (None = all)

    # --- Models ---
    # NOTE(review): these are Hugging Face DETR ids, while the training cells load
    # 'yolov8m.pt' / 'yolov8n.pt' directly — confirm whether these keys are still used.
    'teacher_model': 'facebook/detr-resnet-50',
    'student_model': 'facebook/detr-resnet-50',

    # --- Training ---
    'img_size': 640,  # read by train_teacher / train_baseline as `imgsz`
    'batch_size': 32,
    'num_workers': 4,
    'epochs_teacher': 30,
    'epochs_student': 30,
    'learning_rate': 1.0e-4,
    'weight_decay': 1.0e-4,
    'patience': 3,  # early-stopping patience handed to YOLO `train`

    # --- Distillation ---
    'temperature': 2.0,
    'alpha': 0.5,

    'seed': 42,
}

### Download KITTI dataset

In [None]:
import os, zipfile
from pathlib import Path
from urllib.request import urlretrieve
from tqdm import tqdm

# URLs
# Short archive name -> download URL for each KITTI object-detection zip fetched below.
KITTI_URLS = {
    "images": "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_image_2.zip",
    "labels": "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_label_2.zip",
    "calib": "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_calib.zip",
}

# Output directory
# Archives are downloaded into, and extracted under, this directory (created if missing).
output_dir = Path("/content/kitti_data")
output_dir.mkdir(parents=True, exist_ok=True)

# Progress bar helper
class DownloadProgressBar(tqdm):
    """tqdm bar exposing `update_to`, used as the `reporthook` of `urlretrieve`."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar forward so that it shows `b * bsize` units in total.

        Args:
            b: Number of blocks reported so far.
            bsize: Size of one block.
            tsize: Total size; when not None it replaces the bar's `total`.
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # `update` expects an increment, so subtract what the bar already shows.
        self.update(transferred - self.n)

# Download + extract
# Each archive is fetched only when its zip is not already present, then extracted
# into `output_dir` (extraction runs every time the cell is executed).
for name, url in KITTI_URLS.items():
    filename = url.split("/")[-1]
    filepath = output_dir / filename

    if not filepath.exists():
        print(f"📦 Downloading {name}...")
        # Download to a temporary ".part" file and rename only after the transfer
        # finished, so an interrupted download is never mistaken for a complete
        # archive on the next run.
        partial_path = filepath.with_name(filename + ".part")
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
            urlretrieve(url, filename=partial_path, reporthook=t.update_to)
        partial_path.replace(filepath)

    print(f"📂 Extracting {filename}...")
    with zipfile.ZipFile(filepath, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

print("✅ KITTI dataset ready at:", output_dir)

### Extract subset of dataset and convert to COCO

In [None]:

# KITTI object type -> category name used in the COCO output.
# Types that are not listed here are skipped by the converter.
KITTI_CLASS_MAPPING = dict(
    Car='car',
    Van='car',
    Truck='truck',
    Pedestrian='person',
    Person_sitting='person',
    Cyclist='bicycle',
    Tram='train',
    Misc='other',
)

# COCO "categories" section; ids are 1-based and follow the order of the names below.
KITTI_CATEGORIES = [
    {'id': category_id, 'name': category_name}
    for category_id, category_name in enumerate(
        ('person', 'car', 'truck', 'bicycle', 'train', 'other'), start=1
    )
]

# Reverse lookup: category name -> category id.
CAT_NAME_TO_ID = {entry["name"]: entry["id"] for entry in KITTI_CATEGORIES}

def read_kitti_label(label_path: Path):
    """Parse one KITTI label file into a list of objects.

    Every line is split on whitespace; lines with fewer than 15 fields are
    ignored. For each remaining line the first field is kept as the object
    type and fields 4..7 are converted to floats as the bounding box.

    Args:
        label_path: Path of the label text file.

    Returns:
        A list of dicts with keys "type" (str) and "bbox" (list of four floats),
        in file order.
    """
    with open(label_path, "r") as handle:
        rows = (raw.strip().split() for raw in handle)
        return [
            {"type": fields[0], "bbox": [float(value) for value in fields[4:8]]}
            for fields in rows
            if len(fields) >= 15
        ]

def convert_to_coco(kitti_root, output_dir, split, image_ids):
    """Build a COCO-style dict for one split and copy its images.

    For every id in `image_ids` the image `<kitti_root>/image_2/<id>.png` is
    copied to `<output_dir>/images/<split>/` and its labels are read from
    `<kitti_root>/label_2/<id>.txt`. Ids whose image file is missing are
    skipped; images without a label file are kept with no annotations.

    Args:
        kitti_root: KITTI folder containing `image_2/` and `label_2/` (str or Path).
        output_dir: Root folder for the converted dataset (str or Path).
        split: Split name; used as the sub-folder name and progress-bar label.
        image_ids: Iterable of image file stems.

    Returns:
        Dict with keys "images", "annotations" and "categories". Boxes are
        stored as `[x, y, width, height]`; objects whose type is not in
        `KITTI_CLASS_MAPPING` or whose box has non-positive width/height are
        dropped.
    """
    kitti_root, output_dir = Path(kitti_root), Path(output_dir)
    image_dir, label_dir = kitti_root / "image_2", kitti_root / "label_2"
    out_img_dir = output_dir / "images" / split
    out_img_dir.mkdir(parents=True, exist_ok=True)

    images, annotations = [], []
    ann_id = 1
    # Image ids are 1-based positions in `image_ids`, so skipped entries leave gaps.
    for image_id, img_name in enumerate(tqdm(image_ids, desc=f"{split}"), start=1):
        img_path = image_dir / f"{img_name}.png"
        if not img_path.exists():
            continue

        # Only the size is needed; the context manager closes the file right away
        # instead of keeping one open handle per image until garbage collection.
        with Image.open(img_path) as img:
            width, height = img.size
        shutil.copy(img_path, out_img_dir / f"{img_name}.png")
        images.append({
            "id": image_id, "file_name": f"{img_name}.png",
            "width": width, "height": height,
        })

        label_path = label_dir / f"{img_name}.txt"
        if not label_path.exists():
            continue
        for obj in read_kitti_label(label_path):
            category_name = KITTI_CLASS_MAPPING.get(obj["type"])
            if category_name is None:
                continue
            # KITTI stores corner coordinates; COCO expects top-left + size.
            x1, y1, x2, y2 = obj["bbox"]
            w, h = x2 - x1, y2 - y1
            if w <= 0 or h <= 0:
                continue
            annotations.append({
                "id": ann_id, "image_id": image_id,
                "category_id": CAT_NAME_TO_ID[category_name], "bbox": [x1, y1, w, h],
                "area": w * h, "iscrowd": 0,
            })
            ann_id += 1
    return {"images": images, "annotations": annotations, "categories": KITTI_CATEGORIES}

# ---- MAIN ----
# Conversion settings. These names were previously used without being defined in
# this notebook; they are now read from CONFIG, with fallbacks for the optional
# keys. The converted dataset goes to CONFIG['data_root'] rather than the raw
# download directory.
kitti_root = Path(CONFIG['kitti_root'])
output_dir = Path(CONFIG['data_root'])
seed = CONFIG['seed']
max_samples = CONFIG.get('max_samples')       # None / 0 -> use every image
train_split = CONFIG.get('train_split', 0.7)
val_split = CONFIG.get('val_split', 0.15)     # remainder becomes the test split
if not (train_split > 0 and val_split >= 0 and train_split + val_split <= 1):
    raise ValueError(
        f"Invalid split fractions: train_split={train_split}, val_split={val_split}"
    )

random.seed(seed)
image_files = sorted(list((kitti_root / "image_2").glob("*.png")))
image_ids = [f.stem for f in image_files]
if not image_ids:
    # Fail loudly instead of silently writing three empty splits.
    raise FileNotFoundError(f"No .png images found under {kitti_root / 'image_2'}")
if max_samples:
    image_ids = image_ids[:max_samples]
# Shuffle after truncation so the split assignment is random but reproducible.
random.shuffle(image_ids)

n = len(image_ids)
train_end = int(n * train_split)
val_end = int(n * (train_split + val_split))
splits = {
    "train": image_ids[:train_end],
    "val": image_ids[train_end:val_end],
    "test": image_ids[val_end:],
}

ann_dir = output_dir / "annotations"
ann_dir.mkdir(parents=True, exist_ok=True)

# One COCO annotation file per split: annotations/instances_<split>.json
for split, ids in splits.items():
    coco = convert_to_coco(kitti_root, output_dir, split, ids)
    with open(ann_dir / f"instances_{split}.json", "w") as f:
        json.dump(coco, f)
    print(f"✅ {split}: {len(coco['images'])} images, {len(coco['annotations'])} annotations")

print("\n✅ Conversion complete! Output:", output_dir)

## 2.  Baselines

### Finetune Teacher

In [None]:
def train_teacher(data_yaml_path):
    """Fine-tune the YOLOv8m teacher and evaluate it on the test split.

    Args:
        data_yaml_path: Dataset YAML passed as `data` to `model.train` and `model.val`.

    Returns:
        Tuple `(model, metrics)`: the `YOLO` object after training and the
        object returned by `model.val(split='test')`.
    """
    print("\n" + "="*70)
    print("STEP 3: Training Teacher YOLOv8m")
    print("="*70)

    # `img_size` may be missing from CONFIG; fall back to 640 instead of raising KeyError.
    img_size = CONFIG.get('img_size', 640)
    # Larger model needs more memory, so use half the configured batch (never below 1).
    batch = max(1, CONFIG['batch_size'] // 2)

    model = YOLO('yolov8m.pt')

    model.train(
        data=data_yaml_path,
        epochs=CONFIG['epochs_teacher'],
        imgsz=img_size,
        batch=batch,
        patience=CONFIG['patience'],
        device=device,
        project=CONFIG['output_dir'],
        name='teacher_yolov8m',
        exist_ok=True,
        verbose=True,
    )

    # Evaluate
    metrics = model.val(
        data=data_yaml_path,
        split='test',
        imgsz=img_size,
        batch=batch,
        device=device,
    )

    print(f"\n📊 Teacher Results:")
    print(f"  mAP@0.50: {metrics.box.map50:.4f}")
    print(f"  mAP@0.50-0.95: {metrics.box.map:.4f}")

    return model, metrics

### Finetune Student

In [None]:
def train_baseline(data_yaml_path):
    """Fine-tune the YOLOv8n baseline (no distillation) and evaluate it on the test split.

    Args:
        data_yaml_path: Dataset YAML passed as `data` to `model.train` and `model.val`.

    Returns:
        Tuple `(model, metrics)`: the `YOLO` object after training and the
        object returned by `model.val(split='test')`.
    """
    print("\n" + "="*70)
    print("STEP 2: Training Baseline YOLOv8n (no distillation)")
    print("="*70)

    # `img_size` may be missing from CONFIG; fall back to 640 instead of raising KeyError.
    img_size = CONFIG.get('img_size', 640)
    batch = CONFIG['batch_size']

    model = YOLO('yolov8n.pt')

    model.train(
        data=data_yaml_path,
        epochs=CONFIG['epochs_student'],
        imgsz=img_size,
        batch=batch,
        patience=CONFIG['patience'],
        device=device,
        project=CONFIG['output_dir'],
        name='baseline_yolov8n',
        exist_ok=True,
        verbose=True,
    )

    # Evaluate
    metrics = model.val(
        data=data_yaml_path,
        split='test',
        imgsz=img_size,
        batch=batch,
        device=device,
    )

    print(f"\n📊 Baseline Results:")
    print(f"  mAP@0.50: {metrics.box.map50:.4f}")
    print(f"  mAP@0.50-0.95: {metrics.box.map:.4f}")

    return model, metrics

## 3.  Distillation

In [None]:
# Distillation loss
# Shared criterion objects; `bce_loss` is used by the training loop as the hard-label term.
kl_div = nn.KLDivLoss(reduction="batchmean")
bce_loss = nn.BCEWithLogitsLoss()

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """Temperature-scaled KL divergence between teacher and student outputs.

    Both inputs are divided by `T`; the teacher is turned into probabilities and
    the student into log-probabilities with a softmax over dimension 1. The
    "batchmean" KL divergence is then multiplied by `T * T`.

    Args:
        student_logits: Raw student outputs.
        teacher_logits: Raw teacher outputs, same shape as `student_logits`.
        T: Softmax temperature.

    Returns:
        Scalar loss tensor.
    """
    soft_targets = nn.functional.softmax(teacher_logits / T, dim=1)
    student_log_probs = nn.functional.log_softmax(student_logits / T, dim=1)
    temperature_scale = T * T
    return kl_div(student_log_probs, soft_targets) * temperature_scale

# Hyperparameters for the distillation run. These names were previously used
# without being defined in this notebook; they are now read from CONFIG.
batch_size = CONFIG['batch_size']
epochs = CONFIG['epochs_student']
temperature = CONFIG['temperature']
alpha = CONFIG['alpha']  # weight of the hard loss; (1 - alpha) weights the distillation loss

# NOTE(review): `student`, `teacher` and `data_yaml_path` are not defined in the
# cells shown here — they must be created before this cell runs (for example from
# the objects returned by train_baseline / train_teacher).

# Dataloader from YOLO
# NOTE(review): confirm that `student.model` really exposes `dataloaders(...)`;
# this does not look like a documented Ultralytics API.
train_loader = student.model.dataloaders(data_yaml_path, batch_size=batch_size, mode="train")

optimizer = torch.optim.Adam(student.parameters(), lr=CONFIG['learning_rate'])

print("🚀 Starting YOLOv8 → YOLOv8 distillation...")

for epoch in range(epochs):
    # NOTE(review): if `student` is an Ultralytics `YOLO` wrapper, `.train()` launches
    # a full training run rather than switching to train mode — confirm the type.
    student.train()
    total_loss = 0.0
    for imgs, targets in train_loader:
        imgs = imgs.to(device)
        # NOTE(review): `targets` is moved to the device but never used by the loss below.
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward passes (teacher without gradients)
        # NOTE(review): `.logits` assumes a Hugging Face-style output object — verify
        # that the teacher/student actually return one.
        with torch.no_grad():
            teacher_preds = teacher(imgs).logits  # teacher outputs
        student_preds = student(imgs).logits

        # Compute losses
        # NOTE(review): placeholder hard loss — the target is all zeros, so this term
        # only pushes the student logits down; replace it with the real detection loss.
        loss_student = bce_loss(student_preds, torch.zeros_like(student_preds))  # placeholder hard loss
        loss_distill = distillation_loss(student_preds, teacher_preds, T=temperature)

        loss = alpha * loss_student + (1 - alpha) * loss_distill

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Reported value is the sum of the per-batch losses for the epoch.
    print(f"Epoch [{epoch+1}/{epochs}] Loss: {total_loss:.4f}")

print("✅ Distillation finished!")
student.save("student_yolo_distilled.pt")

## 4. Evaluation

In [None]:
def _plot_metric_bars(ax, labels, scores, colors, ylabel, title):
    """Draw one labelled bar chart of `scores` on `ax`, with the value above each bar."""
    ax.bar(labels, scores, color=colors, alpha=0.7)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    top = max(scores)
    # Fall back to a unit axis when every score is 0, which would give an empty y-range.
    ax.set_ylim(0, top * 1.2 if top > 0 else 1.0)
    for i, v in enumerate(scores):
        ax.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom', fontweight='bold')


def compare_results(baseline_metrics, teacher_metrics, student_metrics):
    """Compare baseline, teacher and distilled student; print a table and save a plot.

    Args:
        baseline_metrics: Metrics of the baseline model; `.box.map50` and `.box.map` are read.
        teacher_metrics: Metrics of the teacher model (same attributes).
        student_metrics: Metrics of the distilled student (same attributes).

    Returns:
        A pandas DataFrame with one row per model and the columns
        'Model', 'Parameters', 'mAP@0.50' and 'mAP@0.50-0.95'.

    Side effects:
        Prints the table and the relative mAP@0.50 change of the student over the
        baseline, and writes `comparison.png` into `CONFIG['output_dir']`.
    """
    print("\n" + "="*70)
    print("STEP 5: Comparing Results")
    print("="*70)

    all_metrics = (baseline_metrics, teacher_metrics, student_metrics)
    results_data = {
        'Model': ['Baseline YOLOv8n', 'Teacher YOLOv8m', 'Student YOLOv8n (Distilled)'],
        'Parameters': ['3.2M', '25.9M', '3.2M'],
        'mAP@0.50': [m.box.map50 for m in all_metrics],
        'mAP@0.50-0.95': [m.box.map for m in all_metrics],
    }

    df = pd.DataFrame(results_data)

    print("\n📊 FINAL RESULTS:")
    print(df.to_string(index=False))

    # Calculate improvement (relative change of mAP@0.50, student vs. baseline)
    baseline_map = baseline_metrics.box.map50
    student_map = student_metrics.box.map50
    if baseline_map > 0:
        improvement = ((student_map - baseline_map) / baseline_map) * 100
        print(f"\n🎯 Distillation Improvement: {improvement:+.2f}%")
    else:
        # A relative change is undefined for a zero baseline; avoid dividing by zero.
        print("\n🎯 Distillation Improvement: n/a (baseline mAP@0.50 is 0)")

    # Plot comparison
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    models = ['Baseline\nYOLOv8n', 'Teacher\nYOLOv8m', 'Student\nYOLOv8n\n(Distilled)']
    colors = ['#3498db', '#e74c3c', '#2ecc71']

    _plot_metric_bars(
        axes[0], models, results_data['mAP@0.50'], colors,
        'mAP@0.50', 'Detection Performance (IoU=0.50)',
    )
    _plot_metric_bars(
        axes[1], models, results_data['mAP@0.50-0.95'], colors,
        'mAP@0.50-0.95', 'Detection Performance (IoU=0.50-0.95)',
    )

    plt.tight_layout()
    output_path = Path(CONFIG['output_dir']) / 'comparison.png'
    # Make sure the target folder exists; savefig does not create it.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"\n💾 Comparison plot saved to: {output_path}")

    return df