In [1]:
# Import the necessary modules
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds

# Define the hyperparameters
batch_size = 128
num_epochs = 10
temperature = 10 # A scaling factor for the soft targets
alpha = 0.9 # A weighting factor for the soft loss
seed = 42 # Fixed seed so weight initialisation and shuffling are repeatable
shuffle_buffer_size = 10000 # Number of training examples held in the shuffle buffer

tf.random.set_seed(seed)


def transform(image):
    """Cast an image tensor to float32 and scale its values by 1/255."""
    # Plain tensor ops are used instead of
    # tf.keras.layers.experimental.preprocessing.Rescaling, whose
    # `experimental` path is deprecated and missing from newer Keras releases.
    return tf.cast(image, tf.float32) * (1. / 255)


# Load the Fashion MNIST dataset
# The training split is shuffled (and reshuffled every epoch) before batching;
# without this, SGD would see the examples in the same order on every epoch.
trainset = (
    tfds.load('fashion_mnist', split='train', as_supervised=True)
    .map(lambda x, y: (transform(x), y))
    .shuffle(shuffle_buffer_size, seed=seed)
    .batch(batch_size)
)
# The test split keeps its original order: evaluation does not depend on it.
testset = (
    tfds.load('fashion_mnist', split='test', as_supervised=True)
    .map(lambda x, y: (transform(x), y))
    .batch(batch_size)
)
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# Define the teacher model (a small convolutional neural network)
class Teacher(tf.keras.Model):
    """Teacher network: two conv/max-pool stages followed by two dense layers.

    The final dense layer has 10 units and no activation, so `call` returns
    raw (unnormalised) scores rather than probabilities.
    """

    def __init__(self):
        super().__init__()
        layers = tf.keras.layers
        # Convolutional feature extractor: 16 then 32 filters, each stage
        # followed by a 2x2 max-pool with stride 2.
        self.conv1 = layers.Conv2D(filters=16, kernel_size=3, padding='same', activation='relu')
        self.pool1 = layers.MaxPool2D(pool_size=2, strides=2)
        self.conv2 = layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu')
        self.pool2 = layers.MaxPool2D(pool_size=2, strides=2)
        # Classifier head on the flattened feature maps.
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(units=64, activation='relu')
        self.fc2 = layers.Dense(units=10)

    def call(self, x):
        """Run `x` through every layer in order and return the final output."""
        pipeline = (
            self.conv1, self.pool1,
            self.conv2, self.pool2,
            self.flatten, self.fc1, self.fc2,
        )
        features = x
        for stage in pipeline:
            features = stage(features)
        return features

# Define the student model (a smaller fully-connected neural network)
class Student(tf.keras.Model):
    """Student network: flatten the input, one 32-unit ReLU layer, 10-unit output.

    The output layer has no activation, so `call` returns raw (unnormalised)
    scores rather than probabilities.
    """

    def __init__(self):
        super().__init__()
        layers = tf.keras.layers
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(units=32, activation='relu')
        self.fc2 = layers.Dense(units=10)

    def call(self, x):
        """Flatten `x`, apply the hidden layer, and return the output layer's result."""
        flat = self.flatten(x)
        hidden = self.fc1(flat)
        return self.fc2(hidden)


ModuleNotFoundError: No module named 'tensorflow_compression'

In [None]:
# Instantiate the models
teacher = Teacher()
student = Student()

# Define the loss function and the optimizers
criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# `optimizer` is deliberately NOT attached to the teacher: a Keras optimizer
# keeps per-variable momentum state (and, in recent versions, is bound to the
# variables it was first built with), so an instance that has trained one model
# must not be reused for another. Any code that picks up `optimizer` afterwards
# therefore starts from a fresh, unbuilt instance.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
teacher_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)

# Train the teacher model on the original hard targets
teacher.compile(optimizer=teacher_optimizer, loss=criterion, metrics=['accuracy'])
teacher.fit(trainset, epochs=num_epochs, validation_data=testset)
print('Finished training the teacher model')

# Evaluate the teacher model on the test set
# With one compiled metric, evaluate() returns [loss, accuracy]; using that
# return value avoids depending on the ordering/type of `teacher.metrics`,
# which differs between Keras versions.
_, teacher_accuracy = teacher.evaluate(testset)
print('Accuracy of the teacher model on the test set: %.2f %%' % (teacher_accuracy * 100))

In [None]:
# Train the student model on the soft targets from the teacher model
# The student gets its own optimizer: momentum state (and, in recent Keras
# versions, the variable binding) of an optimizer that already trained another
# model must not be carried over.
student_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
# compile() is only needed so that student.evaluate() below has a loss and metric.
student.compile(optimizer=student_optimizer, loss=criterion, metrics=['accuracy'])
# Built once; the loss object holds no per-step state.
kl_divergence = tf.keras.losses.KLDivergence()
log_interval = 200 # Number of batches between progress reports

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainset):
        # The teacher is frozen: its forward pass runs outside the tape and is
        # wrapped in stop_gradient so nothing about it is recorded.
        teacher_outputs = tf.stop_gradient(teacher(inputs))
        soft_targets = tf.nn.softmax(teacher_outputs / temperature)
        with tf.GradientTape() as tape:
            outputs = student(inputs)
            # Compute the soft loss and the hard loss.
            # KLDivergence takes (y_true, y_pred), i.e. KL(teacher || student).
            # Gradients of the softened term shrink by 1/temperature**2, so it
            # is multiplied by temperature**2 to keep the alpha weighting
            # meaningful at any temperature.
            soft_loss = kl_divergence(soft_targets, tf.nn.softmax(outputs / temperature)) * (temperature ** 2)
            hard_loss = criterion(labels, outputs)
            # Combine the soft loss and the hard loss with a weighting factor
            loss = alpha * soft_loss + (1 - alpha) * hard_loss
        # Apply the gradients (student variables only)
        grads = tape.gradient(loss, student.trainable_variables)
        student_optimizer.apply_gradients(zip(grads, student.trainable_variables))
        running_loss += loss.numpy()
        if i % log_interval == log_interval - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0
print('Finished training the student model')

# Evaluate the student model on the test set
# With one compiled metric, evaluate() returns [loss, accuracy]; using that
# return value avoids depending on the ordering/type of `student.metrics`.
_, student_accuracy = student.evaluate(testset)
print('Accuracy of the student model on the test set: %.2f %%' % (student_accuracy * 100))