In [38]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [45]:
class Distiller(keras.Model):
  def __init__(self, student, teacher):
    super(Distiller, self).__init__()
    self.teacher = teacher
    self.student = student

  def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
    super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
    self.student_loss_fn = student_loss_fn
    self.distillation_loss_fn = distillation_loss_fn
    self.alpha = alpha
    self.temperature = temperature

  # 训练步骤
  def train_step(self, data):
    # 获取数据
    x, y = data
    # 对原模型进行正向传播
    teacher_predictions = self.teacher(x, training=False)

    # 进行梯度下降,更新权值
    with tf.GradientTape() as tape:
      # 对新模型进行正向传播
      student_predictions = self.student(x, training=True)

      # 计算损失
      student_loss = self.student_loss_fn(y, student_predictions)
      distillation_loss = self.distillation_loss_fn(
          tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
          tf.nn.softmax(student_predictions / self.temperature, axis=1)
      )
      loss = self.alpha * student_loss + (1-self.alpha) * distillation_loss
    # 计算梯度
    trainable_vars = self.student.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # 更新权值
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # 更新指标
    self.compiled_metrics.update_state(y, student_predictions)
    # 输出指标中的值
    results = {m.name: m.result() for m in self.metrics}
    results.update(
        {'student_loss': student_loss,
         'distillation_loss': distillation_loss}
    )
    return results

  def test_step(self, data):
    x, y = data
    # 计算预测值
    y_prediction = self.student(x, training=False)
    
    # 计算损失
    student_loss = self.student_loss_fn(y, y_prediction)

    # 更新指标
    self.compiled_metrics.update_state(y, y_prediction)

    # 输出指标中的值
    results = {m.name: m.result() for m in self.metrics}
    results.update({'student_loss': student_loss})
    return results


In [40]:
# 创建原模型架构
teacher = keras.Sequential(
    [
     keras.Input(shape=(28, 28, 1)),
     layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same'),
     layers.LeakyReLU(alpha=0.2),
     layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
     layers.Conv2D(512, (3, 3), strides=(2, 2), padding='same'),
     layers.Flatten(),
     layers.Dense(10)
    ],
    name='teacher'
)

# 创建新模型架构
student = keras.Sequential(
    [
     keras.Input(shape=(28, 28, 1)),
     layers.Conv2D(16, (3, 3), strides=(2, 2), padding='same'),
     layers.LeakyReLU(alpha=0.2),
     layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
     layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same'),
     layers.Flatten(),
     layers.Dense(10)
    ],
    name='student'
)
# 克隆新模型和之后的模型对照
student_scratch = keras.models.clone_model(student)



In [41]:
# 数据预处理
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.astype('float32') / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype('float32') / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

In [42]:
# 训练原模型
teacher.compile(
    optimizer = keras.optimizers.Adam(),
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = [keras.metrics.SparseCategoricalAccuracy()]
)

teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.08362632244825363, 0.9797999858856201]

In [43]:
# 训练蒸馏器模型
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer = keras.optimizers.Adam(),
    metrics = [keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn = keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10
)

distiller.fit(x_train, y_train, epochs=3)

distiller.evaluate(x_test, y_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3


[0.9787999987602234, 1.2322904694883619e-05]

In [44]:
# 训练新模型
student_scratch.compile(
    optimizer = keras.optimizers.Adam(),
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = [keras.metrics.SparseCategoricalAccuracy()]
)

student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3


[0.06773968786001205, 0.9775000214576721]