# Distillation

In this notebook, we will use the knowledge distillation technique to train a YOLO v3 tiny model on the pedestrian detection task. We will use the already pretrained YOLO 8n model as the teacher model

### Step 0: Setup

In [10]:
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from ultralytics import YOLO


### Step 1: Define the Dataset class

In [11]:
class YOLODataset(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.images_path = images_path
        self.labels_path = labels_path
        self.transform = transform
        self.image_files = sorted(os.listdir(images_path))
        self.label_files = sorted(os.listdir(labels_path))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image
        image_path = os.path.join(self.images_path, self.image_files[idx])
        label_path = os.path.join(self.labels_path, self.label_files[idx])

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load labels
        labels = []
        with open(label_path, 'r') as f:
            for line in f:
                class_id, x_center, y_center, width, height = map(float, line.split())
                labels.append([class_id, x_center, y_center, width, height])
        labels = torch.tensor(labels)

        # Apply transforms (if any)
        if self.transform:
            image = self.transform(image)

        return image, labels


### Step 2: Define the Knowledge Distillation Loss

In [12]:
class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=3.0):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Distillation loss (soft targets)
        distillation_loss = nn.KLDivLoss(reduction='batchmean')(
            nn.functional.log_softmax(student_logits / self.temperature, dim=1),
            nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)

        # Standard cross-entropy loss (hard targets)
        ce_loss = self.ce_loss(student_logits, labels)

        # Combine both losses
        return self.alpha * distillation_loss + (1 - self.alpha) * ce_loss


### Step 3: Prepare Dataset and DataLoaders

In [13]:
# Paths to dataset
train_images = "dataset/train/images"
train_labels = "dataset/train/labels"
valid_images = "dataset/valid/images"
valid_labels = "dataset/valid/labels"

# Transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((640, 640)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Dataset and DataLoader
train_dataset = YOLODataset(train_images, train_labels, transform=transform)
valid_dataset = YOLODataset(valid_images, valid_labels, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=lambda x: x)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False, collate_fn=lambda x: x)


### Step 4: Load Teacher and Student Models

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Load teacher model (YOLOv3-Tiny)
# teacher_model = YOLO("yolov3-tiny.pt")
model = YOLO("yolov3-tinyu.pt")
torch.save(model, "training_output/runs/detect/train/weights/yolov3-tinyu.pt")

teacher_model = torch.load("training_output/runs/detect/train/weights/best.pt")

# Load student model (YOLOv8n)
student_model = torch.load("training_output/runs/detect/train/weights/yolov3-tinyu.pt").to(device)
student_model


cpu


YOLO(
  (model): DetectionModel(
    (model): Sequential(
      (0): Conv(
        (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): Conv(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (5): MaxPool2d(kernel

### Step 5: Define Optimizer and Loss

In [16]:
optimizer = torch.optim.Adam(student_model.model.parameters(), lr=1e-4)
criterion = DistillationLoss(alpha=0.7, temperature=3.0)


### Step 6: Training Loop

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    # student_model.train()

    for batch in train_loader:
        images, labels = batch

        # Move to device
        images = images.to(device)
        labels = labels.to(device)

        # Teacher predictions (detach to avoid gradients)
        with torch.no_grad():
            teacher_outputs = teacher_model.model(images)

        # Student predictions
        student_outputs = student_model.model(images)

        # Compute loss
        loss = criterion(student_outputs, teacher_outputs, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")
    
    # Validation
    student_model.eval()  # Passer le modèle en mode évaluation

    validation_loss = 0.0
    num_batches = len(valid_loader)

    with torch.no_grad():  # Pas de calcul de gradients pendant la validation
        for batch in valid_loader:
            images, labels = batch

            # Déplacer les données vers le device
            images = images.to(device)
            labels = labels.to(device)

            # Prédictions du modèle étudiant (student)
            student_outputs = student_model.model(images)

            # Calcul de la loss
            loss = criterion(student_outputs, student_outputs, labels)  # Ici, vous pouvez ajuster selon la distillation

            # Accumuler la loss pour calculer la moyenne
            validation_loss += loss.item()

    # Calcul de la loss moyenne sur l'ensemble du set de validation
    average_validation_loss = validation_loss / num_batches
    print(f"Validation Loss: {average_validation_loss:.4f}")



New https://pypi.org/project/ultralytics/8.3.49 available  Update with 'pip install -U ultralytics'
[34m[1mengine\trainer: [0mtask=detect, mode=train, model=yolov3-tinyu.pt, data=coco.yaml, epochs=100, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=cpu, workers=8, project=None, name=train10, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_

### Step 7: Validation Loop

In [None]:
student_model.eval()
with torch.no_grad():
    for batch in valid_loader:
        images, labels = batch

        images = images.to(device)
        predictions = student_model(images)

        # Evaluate predictions (e.g., mAP or custom evaluation)
        print("Validation predictions:", predictions)


### Step 8: Save the Student Model

In [None]:
student_model.save("yolov8n_distilled.pt")
