# FinCompress — Notebook 02: Knowledge Distillation

**RUN ON: Colab/GPU**

## Section A: Vanilla Knowledge Distillation

### How Knowledge Distillation Works

Standard supervised learning trains the student on hard labels: a one-hot vector that assigns probability 1 to a single class, such as "this is negative". But the teacher knows more: it assigns probabilities to *all* classes. For a borderline negative sentence, the teacher might output `[0.65, 0.30, 0.05]` over (negative, neutral, positive); the 0.30 neutral probability tells the student that the sentence is ambiguous. The hard label is just `[1, 0, 0]`, which discards this nuance entirely.

**Temperature T** controls how soft the teacher's distribution is:
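- T=1: the original softmax, i.e. the teacher's actual probability distribution
- T>1: softer; low-probability classes receive a larger share of the probability mass, which makes the "dark knowledge" (the teacher's relative preferences among the non-top classes) easier for the student to learn from
- T→∞: uniform; all information is lost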
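
As a rough illustration of the three regimes above, the sketch below applies a temperature-scaled softmax to made-up teacher logits. The helper `softmax_with_temperature` and the logit values are illustrative assumptions, not taken from the FinCompress code; plain Python is used so the cell runs without a GPU or framework.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one sentence (three classes).
teacher_logits = [2.0, 1.0, -1.0]

for T in (1.0, 4.0, 1000.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>6}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As T grows, the printed probabilities drift toward 1/3 each, while the ranking of the classes stays the same.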

**T² scaling** is needed because the gradients of the KL term between two T-softened distributions shrink roughly as 1/T². Multiplying the KL loss by T² compensates for this; without the correction, a high temperature would make the soft-label signal negligible relative to the CE loss.
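
The 1/T² effect can be checked numerically. The sketch below relies on the standard result that the gradient of KL(p_T ‖ q_T) with respect to student logit z_j is (q_j − p_j)/T, and evaluates its norm at several temperatures. The helper names and logit values are illustrative assumptions.

```python
import math

def softmax_t(logits, T):
    """Temperature-scaled softmax for a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_grad_norm(teacher_logits, student_logits, T):
    """L2 norm of d KL(p_T || q_T) / d student_logits, which equals (q - p) / T."""
    p = softmax_t(teacher_logits, T)
    q = softmax_t(student_logits, T)
    grad = [(qj - pj) / T for pj, qj in zip(p, q)]
    return math.sqrt(sum(g * g for g in grad))

# Hypothetical logits for a single example.
teacher_logits = [2.0, 1.0, -1.0]
student_logits = [1.0, 0.5, 0.0]

for T in (1, 5, 10, 20):
    raw = kl_grad_norm(teacher_logits, student_logits, T)
    print(f"T={T:>2}  |grad|={raw:.5f}  T^2*|grad|={T * T * raw:.5f}")
```

The raw gradient norm shrinks quickly as T grows, while the T²-scaled value settles toward a roughly constant level at higher temperatures, which is what the correction is meant to achieve.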

**α (alpha)** balances soft-label KL loss vs. hard-label CE loss. α=0.5 gives equal weight.
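
Putting temperature, T² scaling, and α together, one common form of the objective is `α · T² · KL(p_T ‖ q_T) + (1 − α) · CE(y, q_1)`. The sketch below computes it for a single example in plain Python. Which term α multiplies, and the defaults `alpha=0.5` and `temperature=4.0`, are assumptions for illustration; the actual choices live in `fincompress.distillation.soft_label_distillation` and may differ.

```python
import math

def log_softmax(logits, T=1.0):
    """Log of the temperature-scaled softmax, computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]

def kd_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=4.0):
    """Single-example KD loss: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE."""
    log_p = log_softmax(teacher_logits, temperature)  # softened teacher
    log_q = log_softmax(student_logits, temperature)  # softened student
    kl = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    ce = -log_softmax(student_logits)[label]  # hard-label cross-entropy at T=1
    total = alpha * temperature ** 2 * kl + (1.0 - alpha) * ce
    return total, kl, ce

# Made-up logits; label 0 is the hard target.
total, kl, ce = kd_loss([1.0, 0.5, 0.0], [2.0, 1.0, -1.0], label=0)
print(f"total={total:.4f}  kl={kl:.4f}  ce={ce:.4f}")
```

In a PyTorch training loop the same quantities are usually computed batch-wise with `torch.nn.functional.kl_div` (with `reduction='batchmean'`) and `torch.nn.functional.cross_entropy`.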

In [None]:
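import os
import sys

# Assumes Google Drive is already mounted at /content/drive and that the repo root lives here.
PROJECT_PATH = '/content/drive/MyDrive/fincompress'
os.chdir(PROJECT_PATH)             # relative paths below resolve against the repo root
sys.path.insert(0, PROJECT_PATH)   # make the `fincompress` package importable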

In [None]:
# Run vanilla KD distillation
!python -m fincompress.distillation.soft_label_distillation

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

log_df = pd.read_csv('fincompress/logs/vanilla_kd_training.csv')

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(log_df['epoch'], log_df['train_total_loss'], 'o-', label='Total Loss', color='steelblue')
ax.plot(log_df['epoch'], log_df['train_ce_loss'],    's-', label='CE Loss',    color='coral')
ax.plot(log_df['epoch'], log_df['train_kl_loss'],    '^-', label='KL Loss',    color='mediumseagreen')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Vanilla KD Training Loss Components')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
import json
from pathlib import Path

teacher_info = json.loads(Path('fincompress/checkpoints/teacher/checkpoint_info.json').read_text())
student_info = json.loads(Path('fincompress/checkpoints/student_vanilla_kd/checkpoint_info.json').read_text())

print('=' * 55)
print(f'{"Metric":<25} {"Teacher":>12} {"Vanilla KD":>12}')
print('-' * 55)
print(f'{"Val Macro F1":<25} {teacher_info["val_macro_f1"]:>12.4f} {student_info["val_macro_f1"]:>12.4f}')
print(f'{"Params":<25} {teacher_info["num_parameters"]:>12,} {student_info["num_parameters"]:>12,}')
print(f'{"Size (MB)":<25} {teacher_info["size_mb"]:>12.1f} {student_info["size_mb"]:>12.1f}')
compression = teacher_info['num_parameters'] / student_info['num_parameters']
f1_gap = teacher_info['val_macro_f1'] - student_info['val_macro_f1']
print('-' * 55)
print(f'{"Compression ratio":<25} {1.0:>11.1f}× {compression:>11.1f}×')  # 11 + "×" keeps columns 12 wide
print(f'{"F1 gap from teacher":<25} {0.0:>12.4f} {f1_gap:>+12.4f}')

## Section B: Intermediate Layer Distillation

### Why Logit-Level Supervision Alone Is Insufficient

Vanilla KD tells the student: "your output distribution should look like mine." But many different internal computations can produce similar output distributions. Some of them are shortcuts that happen to work on the training data but fail under distribution shift.

**Intermediate distillation** adds two constraints:

1. **Hidden state MSE**: The student's layer outputs (projected to teacher dimension) should match the teacher's layer outputs at corresponding depths. This forces the student's *representations* to be similar, not just its final predictions.

2. **Attention pattern MSE**: The student's attention weight matrices should match the teacher's at corresponding layers. This constrains the student's *routing decisions* — which tokens it attends to — to follow the teacher's proven strategy.
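
To make the two constraints concrete, here is a toy sketch in plain Python. The shapes, values, and the helpers `mse` and `project` are made up for illustration; in real training the projection matrix is learned and the tensors come from the models, so treat this as a picture of the arithmetic rather than the FinCompress implementation.

```python
def mse(a, b):
    """Mean squared error between two equally shaped 2-D nested lists."""
    diffs = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)

def project(hidden, weight):
    """Map each token vector from student width to teacher width (hidden @ weight)."""
    return [[sum(h * w for h, w in zip(token, col)) for col in zip(*weight)]
            for token in hidden]

# Toy shapes: 2 tokens, student width 2, teacher width 3.
student_hidden = [[1.0, 0.0], [0.0, 1.0]]
teacher_hidden = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
projection = [[1.0, 2.0, 3.0],   # learnable in practice; fixed here
              [4.0, 5.0, 5.0]]

# Hidden-state term: compare projected student states with teacher states.
hidden_loss = mse(project(student_hidden, projection), teacher_hidden)

# Attention term: (tokens x tokens) maps compared directly
# (assumes matching sequence length and head count).
student_attn = [[0.5, 0.5], [0.5, 0.5]]
teacher_attn = [[0.9, 0.1], [0.2, 0.8]]
attn_loss = mse(student_attn, teacher_attn)

print(hidden_loss, attn_loss)
```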

**Layer mapping** {0→2, 1→5, 2→8, 3→11} maps each of the 4 student layers to teacher layers spaced every 3 positions across the 12-layer hierarchy. This ensures supervision at early (syntactic), middle (semantic), and final (task-relevant) abstraction levels.
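
The mapping above follows a simple rule: split the teacher into as many equal blocks as the student has layers, and supervise each student layer with the last layer of its block. The helper below is hypothetical (`uniform_layer_map` is not a name from the FinCompress code) and only reproduces that rule:

```python
def uniform_layer_map(num_student_layers, num_teacher_layers):
    """Map student layer i to the last teacher layer of its i-th equal-sized block."""
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError("teacher depth must be a multiple of student depth")
    stride = num_teacher_layers // num_student_layers
    return {i: (i + 1) * stride - 1 for i in range(num_student_layers)}

layer_map = uniform_layer_map(4, 12)
print(layer_map)  # {0: 2, 1: 5, 2: 8, 3: 11}
```

Anchoring on the last layer of each block means the final student layer is always matched to the final teacher layer, which is the one closest to the task output.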

In [None]:
# Run intermediate KD distillation
!python -m fincompress.distillation.intermediate_distillation

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

log_df = pd.read_csv('fincompress/logs/intermediate_kd_training.csv')

fig, ax = plt.subplots(figsize=(12, 5))
colors = {'train_total_loss': 'steelblue', 'train_ce_loss': 'coral',
          'train_kl_loss': 'mediumseagreen', 'train_hidden_loss': 'darkorange', 'train_attn_loss': 'purple'}
labels = {'train_total_loss': 'Total', 'train_ce_loss': 'CE', 'train_kl_loss': 'KL',
          'train_hidden_loss': 'Hidden MSE', 'train_attn_loss': 'Attn MSE'}

for col, color in colors.items():
    ax.plot(log_df['epoch'], log_df[col], 'o-', label=labels[col], color=color)

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Intermediate KD — All Loss Components')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
import json
from pathlib import Path

vanilla_info = json.loads(Path('fincompress/checkpoints/student_vanilla_kd/checkpoint_info.json').read_text())
inter_info   = json.loads(Path('fincompress/checkpoints/student_intermediate_kd/checkpoint_info.json').read_text())

print('=' * 55)
print(f'{"Metric":<25} {"Vanilla KD":>12} {"Intermediate KD":>15}')
print('-' * 55)
print(f'{"Val Macro F1":<25} {vanilla_info["val_macro_f1"]:>12.4f} {inter_info["val_macro_f1"]:>15.4f}')
print(f'{"Size (MB)":<25} {vanilla_info["size_mb"]:>12.1f} {inter_info["size_mb"]:>15.1f}')
delta = inter_info['val_macro_f1'] - vanilla_info['val_macro_f1']
print('-' * 55)
print(f'{"F1 improvement":<25} {"":>12} {delta:>+15.4f}')

## Section C: Comparison and Analysis

### When Would You Choose Vanilla KD Over Intermediate KD in Practice?

| Scenario | Recommendation | Rationale |
|----------|----------------|-----------|
| Teacher is a black-box API (no internal access) | Vanilla KD | Only the teacher's logits are available |
| Limited compute budget | Vanilla KD | No projection layers, faster training |
| Student architecture differs significantly from teacher | Vanilla KD | Hidden state MSE assumes compatible representations |
| Maximum accuracy with full teacher access | Intermediate KD | Often outperforms vanilla KD, by roughly 1-3 F1 points as a rule of thumb; verify against the delta measured in Section B |
| Production: latency SLA requires smallest model | Intermediate KD | The extra training cost pays off in better student quality |

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

teacher_info = json.loads(Path('fincompress/checkpoints/teacher/checkpoint_info.json').read_text())
vanilla_info = json.loads(Path('fincompress/checkpoints/student_vanilla_kd/checkpoint_info.json').read_text())
inter_info   = json.loads(Path('fincompress/checkpoints/student_intermediate_kd/checkpoint_info.json').read_text())

models = ['Teacher', 'Vanilla KD', 'Intermediate KD']
infos  = [teacher_info, vanilla_info, inter_info]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
metrics = ['val_macro_f1', 'num_parameters', 'size_mb']
titles  = ['Val Macro F1', 'Parameters', 'Size (MB)']
colors  = ['#888888', '#4472C4', '#2196F3']

for ax, metric, title in zip(axes, metrics, titles):
    values = [info[metric] for info in infos]
    bars = ax.bar(models, values, color=colors)
    ax.set_title(title)
    ax.set_ylabel(title)
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.01,
                f'{val:.4f}' if isinstance(val, float) else f'{val:,}',
                ha='center', va='bottom', fontsize=9)

plt.suptitle('Teacher vs. Distilled Students: Key Metrics', y=1.02)
plt.tight_layout()
plt.show()