In [None]:
#| default_exp distill.losses

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F

## Overview

This module provides loss functions for knowledge distillation. These losses enable training a smaller "student" network to mimic a larger "teacher" network.

**Loss Categories:**
- **Output-based**: `SoftTarget`, `Logits`, `Mutual` - compare final predictions
- **Feature-based**: `Attention`, `FitNet`, `Similarity`, `ActivationBoundaries` - compare intermediate representations

## Output-Based Losses

These losses compare the final output predictions between student and teacher networks.

In [None]:
#| export
def SoftTarget(pred: torch.Tensor,          # Student predictions
               teacher_pred: torch.Tensor,  # Teacher predictions
               T: float = 5,                # Temperature for softening
               **kwargs
) -> torch.Tensor:
    "Knowledge distillation with softened distributions (Hinton et al.)"
    # KL(teacher || student) on temperature-softened class distributions (class axis = dim 1),
    # averaged over the batch. Multiplying by T**2 keeps gradient magnitudes comparable
    # across temperatures. Extra keyword arguments are accepted and ignored.
    criterion = nn.KLDivLoss(reduction='batchmean')
    scale = T * T
    log_p_student = F.log_softmax(pred / T, dim=1)   # KLDivLoss expects log-probs as input
    p_teacher = F.softmax(teacher_pred / T, dim=1)   # ...and plain probabilities as target
    return criterion(log_p_student, p_teacher) * scale

In [None]:
show_doc(SoftTarget)

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L13){target="_blank" style="float:right; font-size:smaller"}

### SoftTarget

```python

def SoftTarget(
    pred:torch.Tensor, # Student predictions
    teacher_pred:torch.Tensor, # Teacher predictions
    T:float=5, # Temperature for softening
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*Knowledge distillation with softened distributions (Hinton et al.)*

In [None]:
#| export
def Logits(pred: torch.Tensor,          # Student predictions
           teacher_pred: torch.Tensor,  # Teacher predictions
           **kwargs
) -> torch.Tensor:
    "Direct logit matching between student and teacher"
    # Mean-squared error on the raw (unsoftened) logits; extra keyword arguments are ignored.
    criterion = nn.MSELoss(reduction='mean')
    return criterion(pred, teacher_pred)

In [None]:
show_doc(Logits)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L24){target="_blank" style="float:right; font-size:smaller"}

### Logits

```python

def Logits(
    pred:torch.Tensor, # Student predictions
    teacher_pred:torch.Tensor, # Teacher predictions
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*Direct logit matching between student and teacher*

In [None]:
#| export
def Mutual(pred: torch.Tensor,          # Student predictions
           teacher_pred: torch.Tensor,  # Teacher predictions
           **kwargs
) -> torch.Tensor:
    "KL divergence between student and teacher"
    # KL(teacher || student) at temperature 1 over the class axis (dim 1), averaged over
    # the batch. F.kl_div takes log-probabilities as input and probabilities as target.
    return F.kl_div(
        F.log_softmax(pred, dim=1),
        F.softmax(teacher_pred, dim=1),
        reduction='batchmean',
    )

In [None]:
show_doc(Mutual)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L32){target="_blank" style="float:right; font-size:smaller"}

### Mutual

```python

def Mutual(
    pred:torch.Tensor, # Student predictions
    teacher_pred:torch.Tensor, # Teacher predictions
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*KL divergence between student and teacher*

---

## Feature-Based Losses

These losses compare intermediate feature representations, enabling the student to learn internal representations similar to the teacher.

In [None]:
#| export
def Attention(fm_s: dict[str, torch.Tensor],  # Student feature maps {name: tensor}
              fm_t: dict[str, torch.Tensor],  # Teacher feature maps {name: tensor}
              p: int = 2,                     # Power for attention computation
              **kwargs
) -> torch.Tensor:
    "Attention transfer loss (Zagoruyko & Komodakis)"
    # Student and teacher maps are paired by dict iteration order (not by key name);
    # `zip` stops at the shorter dict. Extra keyword arguments are ignored.
    total_loss = 0.0
    for name_st, name_t in zip(fm_s, fm_t):
        # Attention map = channel-mean of |A|^p. `abs()` keeps the map non-negative for
        # odd p (as in the paper); it is a no-op numerically for the default p=2.
        student_attention = fm_s[name_st].abs().pow(p).mean(1)
        teacher_attention = fm_t[name_t].abs().pow(p).mean(1)
        # Flatten all non-batch dims before L2-normalising each sample's map, so inputs of
        # any rank >= 3 work; for 4D maps this equals normalising over dims (1, 2).
        student_norm = F.normalize(student_attention.flatten(1), dim=1)
        teacher_norm = F.normalize(teacher_attention.flatten(1), dim=1)
        total_loss += F.mse_loss(student_norm, teacher_norm)
    return total_loss

In [None]:
show_doc(Attention)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L42){target="_blank" style="float:right; font-size:smaller"}

### Attention

```python

def Attention(
    fm_s:dict[str, torch.Tensor], # Student feature maps {name: tensor}
    fm_t:dict[str, torch.Tensor], # Teacher feature maps {name: tensor}
    p:int=2, # Power for attention computation
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*Attention transfer loss (Zagoruyko & Komodakis)*

In [None]:
#| export
def ActivationBoundaries(fm_s: dict[str, torch.Tensor],  # Student feature maps
                         fm_t: dict[str, torch.Tensor],  # Teacher feature maps
                         m: float = 2,                   # Boundary margin
                         **kwargs
) -> torch.Tensor:
    "Boundary-based knowledge distillation (Heo et al.)"
    # Maps are paired by dict iteration order; `zip` stops at the shorter dict.
    # Each pair contributes the mean squared hinge-style penalty below; results are summed.
    layer_losses = []
    for key_s, key_t in zip(fm_s, fm_t):
        s, t = fm_s[key_s], fm_t[key_t]
        # Teacher inactive (t <= 0): penalise student values still above -m.
        inactive_mask = ((s > -m) & (t <= 0)).float()
        # Teacher active (t > 0): penalise student values still at or below +m.
        active_mask = ((s <= m) & (t > 0)).float()
        penalty = (s + m).pow(2) * inactive_mask + (s - m).pow(2) * active_mask
        layer_losses.append(penalty.mean())
    # Start value 0.0 means an empty input yields a plain float 0.0.
    return sum(layer_losses, 0.0)

In [None]:
show_doc(ActivationBoundaries)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L58){target="_blank" style="float:right; font-size:smaller"}

### ActivationBoundaries

```python

def ActivationBoundaries(
    fm_s:dict[str, torch.Tensor], # Student feature maps
    fm_t:dict[str, torch.Tensor], # Teacher feature maps
    m:float=2, # Boundary margin
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*Boundary-based knowledge distillation (Heo et al.)*

In [None]:
#| export
def FitNet(fm_s: dict[str, torch.Tensor],  # Student feature maps
           fm_t: dict[str, torch.Tensor],  # Teacher feature maps
           **kwargs
) -> torch.Tensor:
    "FitNets: direct feature map matching (Romero et al.)"
    # One MSE term per (student, teacher) map, paired by dict iteration order and summed.
    # Start value 0.0 means an empty input yields a plain float 0.0.
    per_layer = (F.mse_loss(fm_s[key_s], fm_t[key_t]) for key_s, key_t in zip(fm_s, fm_t))
    return sum(per_layer, 0.0)

In [None]:
show_doc(FitNet)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L74){target="_blank" style="float:right; font-size:smaller"}

### FitNet

```python

def FitNet(
    fm_s:dict[str, torch.Tensor], # Student feature maps
    fm_t:dict[str, torch.Tensor], # Teacher feature maps
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*FitNets: direct feature map matching (Romero et al.)*

In [None]:
#| export
def Similarity(fm_s: dict[str, torch.Tensor],  # Student feature maps
               fm_t: dict[str, torch.Tensor],  # Teacher feature maps
               pred: torch.Tensor | None = None,  # Student predictions (unused, for API consistency)
               p: int = 2,                     # Normalization power
               **kwargs
) -> torch.Tensor:
    "Similarity-preserving knowledge distillation (Tung & Mori)"
    # `pred` is optional so this loss can be called like the other feature losses,
    # i.e. `Similarity(fm_s, fm_t)`; passing it positionally still works.
    # Maps are paired by dict iteration order; `zip` stops at the shorter dict.
    total_loss = 0.0
    for name_st, name_t in zip(fm_s, fm_t):
        # `.reshape` (unlike `.view`) also accepts non-contiguous maps, e.g. channels_last.
        student_flat = fm_s[name_st].reshape(fm_s[name_st].size(0), -1)
        teacher_flat = fm_t[name_t].reshape(fm_t[name_t].size(0), -1)
        # Batch x batch Gram matrices of sample-to-sample similarities, row-normalised.
        student_sim = F.normalize(student_flat @ student_flat.t(), p=p, dim=1)
        teacher_sim = F.normalize(teacher_flat @ teacher_flat.t(), p=p, dim=1)
        total_loss += F.mse_loss(student_sim, teacher_sim)
    return total_loss

In [None]:
show_doc(Similarity)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/losses.py#L85){target="_blank" style="float:right; font-size:smaller"}

### Similarity

```python

def Similarity(
    fm_s:dict[str, torch.Tensor], # Student feature maps
    fm_t:dict[str, torch.Tensor], # Teacher feature maps
    pred:torch.Tensor, # Student predictions (unused, for API consistency)
    p:int=2, # Normalization power
    kwargs:VAR_KEYWORD
)->torch.Tensor:


```

*Similarity-preserving knowledge distillation (Tung & Mori)*

In [None]:
#| hide
from fastcore.test import *

torch.manual_seed(0)  # deterministic random fixtures so Restart & Run All is reproducible

# Output-based losses return scalars
pred_s, pred_t = torch.randn(4, 10), torch.randn(4, 10)

test_eq(SoftTarget(pred_s, pred_t).dim(), 0)
test_eq(Logits(pred_s, pred_t).dim(), 0)
test_eq(Mutual(pred_s, pred_t).dim(), 0)

# Different temperature → different loss
test_ne(SoftTarget(pred_s, pred_t, T=1), SoftTarget(pred_s, pred_t, T=10))

# Identical predictions → ~0 loss
test_close(SoftTarget(pred_s, pred_s).item(), 0.0, eps=1e-5)
test_close(Logits(pred_s, pred_s).item(), 0.0, eps=1e-6)
test_close(Mutual(pred_s, pred_s).item(), 0.0, eps=1e-5)

# Feature-based losses return scalars
fm_s = {'l1': torch.randn(4, 32, 8, 8), 'l2': torch.randn(4, 64, 4, 4)}
fm_t = {'l1': torch.randn(4, 32, 8, 8), 'l2': torch.randn(4, 64, 4, 4)}

test_eq(Attention(fm_s, fm_t).dim(), 0)
test_eq(FitNet(fm_s, fm_t).dim(), 0)
test_eq(Similarity(fm_s, fm_t, pred_s).dim(), 0)

# Identical inputs → ~0 loss
fm_id = {'l1': torch.randn(4, 32, 8, 8)}
test_close(FitNet(fm_id, fm_id).item(), 0.0, eps=1e-5)
test_close(Attention(fm_id, fm_id).item(), 0.0, eps=1e-4)
test_close(Similarity(fm_id, fm_id, pred_s).item(), 0.0, eps=1e-6)

# All losses non-negative
for output_loss in (SoftTarget, Logits, Mutual):
    assert output_loss(pred_s, pred_t) >= 0, output_loss.__name__
assert Attention(fm_s, fm_t) >= 0
assert FitNet(fm_s, fm_t) >= 0
assert Similarity(fm_s, fm_t, pred_s) >= 0

# ActivationBoundaries returns scalar
test_eq(ActivationBoundaries(fm_s, fm_t).dim(), 0)
assert ActivationBoundaries(fm_s, fm_t) >= 0

---

## See Also

- [KnowledgeDistillationCallback](distillation_callback.html) - Apply these losses during training
- [Distillation Tutorial](../tutorials/distill/distill_callback.html) - Practical examples with different losses

### Loss Selection Guide

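| Loss | Best For | Complexity |
|------|----------|------------|
| **SoftTarget** | General distillation: matching temperature-softened class probabilities | Low |
| **Logits** | Direct matching of raw logits (MSE) | Low |
| **Mutual** | Matching class probabilities at temperature 1 (KL divergence) | Low |
| **Attention** | When spatial attention patterns matter | Low |
| **FitNet** | Intermediate feature matching (student/teacher map shapes should match) | Medium |
| **ActivationBoundaries** | Transferring which neurons are active vs. inactive | Medium |
| **Similarity** | Preserving pairwise sample similarities within a batch | Medium |
| **PKT** | Probability distribution matching (not implemented in this notebook) | Medium |
| **RKD** | Relational knowledge transfer (not implemented in this notebook) | High |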