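## Knowledge Distillation
- Purpose: faster inference. A large model (the teacher) helps train a small model (the student), and the small model is then deployed as the prediction model. This suits latency-sensitive applications such as recommender systems, and it can be combined with other acceleration techniques. Distillation speeds up the model itself.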

- 原始定義: Knowledge Distillation is a procedure for model compression, in which a small (student) model is trained to match a large pre-trained (teacher) model.

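- From a mathematical point of view: knowledge is transferred from the teacher to the student by minimizing a loss function that aims to match both the softened teacher logits and the ground-truth labels.

- The logits are softened by applying a "temperature" scaling inside the softmax. This smooths the probability distribution and reveals the inter-class relationships the teacher has learned.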
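A minimal NumPy sketch of temperature scaling, using made-up logits for a 3-class problem. With temperature $T$, the softened probability of class $i$ is $p_i = \exp(z_i/T) / \sum_j \exp(z_j/T)$. Raising $T$ flattens the distribution, so the non-maximal classes receive visibly non-zero probability while the ranking of classes is unchanged. The function name `softmax_with_temperature` and the logit values are illustrative only.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over the last axis after dividing the logits by the temperature."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical teacher logits for one example over 3 classes
logits = np.array([8.0, 3.0, 1.0])

for t in (1.0, 3.0, 10.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>4}: {np.round(probs, 3)}")
```

This corresponds to the `tf.nn.softmax(... / self.temperature, axis=1)` calls in the `train_step` of the Distiller class.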

In [2]:
# Imports

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

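## Building the Distiller class
- The custom Distiller class overrides several methods of `keras.Model`:
    - train_step
        - Runs the forward pass (both teacher and student models), computes the alpha-weighted loss (student_loss and distillation_loss), then backpropagates and updates the parameters (student only).
    - test_step
        - In the test_step method, we evaluate the student model on the provided dataset.
    - compile
        - Stores the optimizer, metrics, both loss functions, `alpha`, and `temperature` used by the two methods above.
- Components of the Distiller
    - A trained teacher model
    - The student model to be trained
    - A student loss function that measures the gap between the student predictions and the ground-truth labels
    - A distillation loss function, together with a "temperature", that measures the difference between the soft student predictions and the soft teacher labels
    - An alpha factor that weights the student loss (alpha) against the distillation loss (1 - alpha)
    - An optimizer for the student, plus optional metrics to measure performance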
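To illustrate how these components combine, here is a NumPy sketch of the per-batch objective that `train_step` computes: `alpha` times the sparse cross-entropy against the labels, plus `(1 - alpha)` times the KL divergence between the temperature-softened teacher and student distributions. The toy logits and labels are made up, the helper names are hypothetical, and the batch-mean reduction and clipping are assumptions chosen to resemble Keras's defaults rather than reproduce them exactly.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sparse_ce_from_logits(labels, logits):
    """Mean cross-entropy between integer labels and raw logits."""
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

def kl_divergence(p, q, eps=1e-7):
    """Batch mean of sum_i p_i * log(p_i / q_i); p plays the role of y_true."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.mean(np.sum(p * np.log(p / q), axis=-1))

def distillation_objective(labels, student_logits, teacher_logits,
                           alpha=0.1, temperature=3.0):
    student_loss = sparse_ce_from_logits(labels, student_logits)  # hard labels, T = 1
    distill_loss = kl_divergence(
        softmax(teacher_logits, temperature),  # soft teacher targets
        softmax(student_logits, temperature),  # soft student predictions
    )
    total = alpha * student_loss + (1 - alpha) * distill_loss
    return total, student_loss, distill_loss

# Toy batch: 2 examples, 3 classes (values are made up)
labels = np.array([0, 2])
teacher_logits = np.array([[6.0, 2.0, 1.0], [0.5, 1.0, 5.0]])
student_logits = np.array([[2.0, 1.5, 1.0], [1.0, 1.0, 2.0]])

total, s_loss, d_loss = distillation_objective(labels, student_logits, teacher_logits)
print(f"student_loss={s_loss:.4f}  distillation_loss={d_loss:.4f}  total={total:.4f}")
```

Hinton et al. additionally multiply the soft-target term by $T^2$ so that its gradient magnitude stays comparable as the temperature changes; the Distiller in this notebook does not apply that factor.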

In [3]:
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student
        
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """
            Args:
                optimizer: Keras optimizer for the student weights
                metrics: Keras metrics for evaluation
                student_loss_fn: Loss function of difference between student
                    predictions and ground-truth
                distillation_loss_fn: Loss function of difference between soft
                    student predictions and soft teacher predictions
                alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
                temperature: Temperature for softening probability distributions.
                    Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        
    def train_step(self, data):
        # unpack data
        x, y = data
        
        # forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)
        
        with tf.GradientTape() as tape:
            # forward pass of student
            student_predictions = self.student(x, training=True)
            
            # compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1-self.alpha) * distillation_loss
        
        # compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        
        # update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # update metrics
        self.compiled_metrics.update_state(y, student_predictions)
        
        # return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {'student_loss': student_loss, 'distillation_loss': distillation_loss}
        )
        
        return results
    
    def test_step(self, data):
        # unpack the data
        x, y = data
        
        # compute predictions
        y_prediction = self.student(x, training=False)
        
        # calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
        
        # update the metrics
        self.compiled_metrics.update_state(y, y_prediction)
        
        # return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({'student_loss': student_loss})
        return results

## Building the student and teacher models

In [7]:
# teacher model

teacher = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same'),
    layers.LeakyReLU(alpha=0.2),
    layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
    layers.Conv2D(512, (3, 3), strides=(2, 2), padding='same'),
    layers.Flatten(),
    layers.Dense(10)
], name='teacher')

# student
student = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(16, (3, 3), strides=(2, 2), padding='same'),
    layers.LeakyReLU(alpha=0.2),
    layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
    layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same'),
    layers.Flatten(),
    layers.Dense(10)
], name='student')

# clone student for later comparison
student_scratch = keras.models.clone_model(student)


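To make the size gap concrete, here is a plain-Python count of the trainable parameters of the two architectures above, derived by hand from the layer shapes (LeakyReLU, MaxPooling2D, and Flatten have no weights). The helper names are hypothetical; Keras's `teacher.count_params()` and `student.count_params()` should report the same totals, but treat that as something to verify in your environment.

```python
def conv2d_params(kernel, in_ch, out_ch):
    """Kernel weights (kh * kw * in * out) plus one bias per output channel."""
    kh, kw = kernel
    return kh * kw * in_ch * out_ch + out_ch

def dense_params(in_features, units):
    """Weight matrix plus one bias per unit."""
    return in_features * units + units

def count_params(c1, c2, classes=10):
    # 28x28 input -> stride-2 'same' conv -> 14x14
    # -> stride-1 'same' max pooling keeps 14x14
    # -> stride-2 'same' conv -> 7x7, so Flatten yields 7 * 7 * c2 features
    conv1 = conv2d_params((3, 3), 1, c1)
    conv2 = conv2d_params((3, 3), c1, c2)
    dense = dense_params(7 * 7 * c2, classes)
    return conv1 + conv2 + dense

teacher_params = count_params(256, 512)
student_params = count_params(16, 32)
print(teacher_params, student_params, round(teacher_params / student_params, 1))
```

Most of the teacher's weights sit in its second convolution (256 to 512 channels), which is exactly the part the student shrinks most aggressively.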

## Dataset

In [9]:
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# normalize data
x_train = x_train.astype('float32') / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype('float32') / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

In [10]:
x_train.shape, x_test.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

In [12]:
# Labels are integer class IDs (not one-hot), so SparseCategoricalCrossentropy is needed
y_train

array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

## Train the teacher model
- The teacher must be trained first; its parameters then stay fixed during distillation.

In [13]:
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# train
teacher.fit(x_train, y_train, epochs=5, batch_size=batch_size)
teacher.evaluate(x_test, y_test)


## Distill teacher to student
- Now use the Distiller defined earlier.

In [None]:
# init
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3, batch_size=batch_size)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

## Train student from scratch for comparison
We can also train an equivalent student model from scratch without the teacher, in order to evaluate the performance gain obtained by knowledge distillation.

In [None]:
# Train the student as usual, without the teacher
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3, batch_size=batch_size)  # same batch size as distillation for a fair comparison
student_scratch.evaluate(x_test, y_test)