# MHAT Integration with Knowledge Distillation

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [None]:
class Distiller(keras.Model):
    """Train one model on hard labels plus another model's soft predictions.

    The model passed as ``server`` is the only one whose weights are updated.
    The model passed as ``client`` is run in inference mode and only supplies
    the soft targets used by the distillation term.
    """

    def __init__(self, server, client):
        """Store the model to be trained (``server``) and the model that
        provides soft targets (``client``)."""
        super().__init__()
        self.client = client
        self.server = server

    def compile(
        self,
        optimizer,
        metrics,
        server_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Optimizer applied to the ``server`` weights.
            metrics: Metrics updated with the labels and ``server`` outputs.
            server_loss_fn: Loss between the labels and the ``server`` outputs.
            distillation_loss_fn: Loss between the temperature-softened
                ``client`` and ``server`` output distributions.
            alpha: Weight of ``server_loss_fn`` in the combined loss; the
                distillation term is weighted by ``1 - alpha``.
            temperature: Divisor applied to both models' outputs before the
                softmax used for the distillation term.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.server_loss_fn = server_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on ``server`` and return metric values."""
        # Unpack data
        x, y = data

        # Forward pass of client. It is computed outside the tape and in
        # inference mode because the client is not updated here.
        client_predictions = self.client(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of server. The server is the model being optimised,
            # so it must run in training mode; otherwise training-only layers
            # (e.g. Dropout, BatchNormalization) would behave as at inference.
            server_predictions = self.server(x, training=True)

            # Compute losses
            server_loss = self.server_loss_fn(y, server_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(client_predictions / self.temperature, axis=1),
                    tf.nn.softmax(server_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * server_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients with respect to the server weights only
        trainable_vars = self.server.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, server_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"server_loss": server_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate ``server`` on one batch and return metric values."""
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.server(x, training=False)

        # Calculate the loss
        server_loss = self.server_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"server_loss": server_loss})
        return results

In [None]:
# Client network: two strided convolutions (10 and 512 filters) followed by a
# 10-unit dense layer that outputs logits.
client = keras.Sequential(
    name="client",
    layers=[
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(
            filters=10, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(
            filters=512, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.Flatten(),
        layers.Dense(units=10),
    ],
)

# Server network: same layout as the client but with far fewer filters
# (1 and 32).
server = keras.Sequential(
    name="server",
    layers=[
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(
            filters=1, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(
            filters=32, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.Flatten(),
        layers.Dense(units=10),
    ],
)

# Two copies of the server architecture, used later as baselines trained
# without distillation.
server_scratch = keras.models.clone_model(server)
server_scratch1 = keras.models.clone_model(server)

In [None]:
# Load MNIST and build the arrays used by the rest of the notebook.
batch_size = 64  # NOTE(review): not referenced by any visible cell
(x_train1, y_train1), (x_test, y_test) = keras.datasets.mnist.load_data()

# Small training subset: the first 1000 samples of the full training set.
x_train = x_train1[:1000]
y_train = y_train1[:1000]

# Convert each image array to float32, divide by 255 and add a trailing
# channel axis so the shape matches the models' (28, 28, 1) input.
x_train1 = (x_train1.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Plain supervised training of the client. The loss uses from_logits=True
# because the final Dense layer has no activation.
client.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
)

# Fit on the full training set, then report loss/accuracy on the test set.
client.fit(x=x_train1, y=y_train1, epochs=5)
client.evaluate(x=x_test, y=y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.07770504802465439, 0.9757999777793884]

In [None]:
# # Initialize and compile distiller
# distiller = Distiller(server=server, client=client)
# distiller.compile(
#     optimizer=keras.optimizers.Adam(),
#     metrics=[keras.metrics.SparseCategoricalAccuracy()],
#     server_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#     distillation_loss_fn=keras.losses.KLDivergence(),
#     alpha=0.1,
#     temperature=10,
# )

# # Distill client to server
# distiller.fit(x_train, y_train, epochs=3)
# # y_train=server.predict(x_train)
# # print(y_train)
# # Evaluate server on test dataset
# distiller.evaluate(x_test, y_test)


In [None]:

# tp=client(x_train, training=False)
# ff=tf.nn.softmax(tp / 3, axis=1)
# print(ff)


In [None]:
# Display the test labels as the cell output.
y_test

array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

In [None]:
# Baseline: train a copy of the server architecture the usual supervised way,
# without any distillation.
server_scratch.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
)

# Fit on the full training set, then report loss/accuracy on the test set.
server_scratch.fit(x=x_train1, y=y_train1, epochs=5)
server_scratch.evaluate(x=x_test, y=y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.2695508897304535, 0.9235000014305115]

In [None]:
# Distiller that updates `server`, using `client` predictions as soft targets.
# alpha=0.995 puts almost all of the combined loss on the hard-label term;
# temperature=1 leaves the outputs unscaled before the softmax.
distiller = Distiller(client=client, server=server)
distiller.compile(
    server_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
    alpha=0.995,
    temperature=1,
)

In [None]:
# Reverse direction: the roles are swapped, so this distiller updates the
# `client` model and uses the `server` model's predictions as soft targets.
# alpha=0.01 puts almost all of the combined loss on the distillation term.
distiller1 = Distiller(client=server, server=client)
distiller1.compile(
    server_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
    alpha=0.01,
    temperature=10,
)


In [None]:
# Ten rounds of alternating distillation on the 1000-sample subset. Both
# distillers were compiled in earlier cells and are reused across rounds.
for round_idx in range(10):
    # Direction 1: update the `server` model from the `client` model.
    distiller.fit(x=x_train, y=y_train, epochs=5)

    # Test-set performance of the model trained by `distiller`.
    distiller.evaluate(x=x_test, y=y_test)

    # Direction 2: update the `client` model from the `server` model.
    distiller1.fit(x=x_train, y=y_train, epochs=1)

    # Test-set performance of the model trained by `distiller1`.
    distiller1.evaluate(x=x_test, y=y_test)
    print("---------------------------------------------------------------------------------------------------------")

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
--------------

In [None]:
# Display the shape of the full training array as the cell output.
x_train1.shape

(60000, 28, 28, 1)

In [None]:
# Baseline: train a second copy of the server architecture the usual
# supervised way, this time on the 1000-sample subset only.
server_scratch1.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
)

# Fit on the subset, then report loss/accuracy on the test set.
server_scratch1.fit(x=x_train, y=y_train, epochs=5)
server_scratch1.evaluate(x=x_test, y=y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.5020046830177307, 0.8403000235557556]

# MHAT/Distillation

In [None]:
class Distiller(keras.Model):
    """Train one model on hard labels plus another model's soft predictions.

    The model passed as ``server`` is the only one whose weights are updated.
    The model passed as ``client`` is run in inference mode and only supplies
    the soft targets used by the distillation term.
    """

    def __init__(self, server, client):
        """Store the model to be trained (``server``) and the model that
        provides soft targets (``client``)."""
        super().__init__()
        self.client = client
        self.server = server

    def compile(
        self,
        optimizer,
        metrics,
        server_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Optimizer applied to the ``server`` weights.
            metrics: Metrics updated with the labels and ``server`` outputs.
            server_loss_fn: Loss between the labels and the ``server`` outputs.
            distillation_loss_fn: Loss between the temperature-softened
                ``client`` and ``server`` output distributions.
            alpha: Weight of ``server_loss_fn`` in the combined loss; the
                distillation term is weighted by ``1 - alpha``.
            temperature: Divisor applied to both models' outputs before the
                softmax used for the distillation term.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.server_loss_fn = server_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on ``server`` and return metric values."""
        # Unpack data
        x, y = data

        # Forward pass of client. It is computed outside the tape and in
        # inference mode because the client is not updated here.
        client_predictions = self.client(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of server. The server is the model being optimised,
            # so it must run in training mode; otherwise training-only layers
            # (e.g. Dropout, BatchNormalization) would behave as at inference.
            server_predictions = self.server(x, training=True)

            # Compute losses
            server_loss = self.server_loss_fn(y, server_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(client_predictions / self.temperature, axis=1),
                    tf.nn.softmax(server_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * server_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients with respect to the server weights only
        trainable_vars = self.server.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, server_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"server_loss": server_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate ``server`` on one batch and return metric values."""
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.server(x, training=False)

        # Calculate the loss
        server_loss = self.server_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"server_loss": server_loss})
        return results

In [None]:
# Client network: two strided convolutions (10 and 512 filters) followed by a
# 10-unit dense layer that outputs logits.
client = keras.Sequential(
    name="client",
    layers=[
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(
            filters=10, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(
            filters=512, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.Flatten(),
        layers.Dense(units=10),
    ],
)

# Server network: same layout as the client but with far fewer filters
# (1 and 32).
server = keras.Sequential(
    name="server",
    layers=[
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(
            filters=1, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(
            filters=32, kernel_size=(3, 3), strides=(2, 2), padding="same"
        ),
        layers.Flatten(),
        layers.Dense(units=10),
    ],
)

# Two copies of the server architecture, created for later comparison.
server_scratch = keras.models.clone_model(server)
server_scratch1 = keras.models.clone_model(server)

In [None]:
# Load MNIST and build the arrays used by the rest of the notebook.
batch_size = 64  # NOTE(review): not referenced by any visible cell
(x_train1, y_train1), (x_test, y_test) = keras.datasets.mnist.load_data()

# Small training subset: the first 1000 samples of the full training set.
x_train = x_train1[:1000]
y_train = y_train1[:1000]

# Convert each image array to float32, divide by 255 and add a trailing
# channel axis so the shape matches the models' (28, 28, 1) input.
x_train1 = (x_train1.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 28, 28, 1)

In [None]:
# Plain supervised training of the client. The loss uses from_logits=True
# because the final Dense layer has no activation.
client.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
)

# Fit on the full training set, then report loss/accuracy on the test set.
client.fit(x=x_train1, y=y_train1, epochs=5)
client.evaluate(x=x_test, y=y_test)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.085181824862957, 0.9743000268936157]

In [None]:
# Distiller that updates `server`, using `client` predictions as soft targets.
# alpha=0 removes the hard-label term from the combined loss, so only the
# distillation term drives training ("only aggregation result sharing").
# The original note gives 0.955 as the normal value.
# NOTE(review): the earlier experiment in this notebook uses 0.995; confirm
# which value is intended.
distiller = Distiller(client=client, server=server)
distiller.compile(
    server_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
    alpha=0,
    temperature=5,
)

In [None]:
# Reverse direction: the roles are swapped, so this distiller updates the
# `client` model and uses the `server` model's predictions as soft targets.
# alpha=0 removes the hard-label term from the combined loss, so only the
# distillation term drives training ("only aggregation result sharing").
# The original note gives 0.01 as the normal value.
distiller1 = Distiller(client=server, server=client)
distiller1.compile(
    server_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
    alpha=0,
    temperature=10,
)


In [None]:
# Ten rounds of alternating distillation on the 1000-sample subset. Both
# distillers were compiled in earlier cells and are reused across rounds.
for round_idx in range(10):
    # Direction 1: update the `server` model from the `client` model.
    distiller.fit(x=x_train, y=y_train, epochs=5)

    # Test-set performance of the model trained by `distiller`.
    distiller.evaluate(x=x_test, y=y_test)

    # Direction 2: update the `client` model from the `server` model.
    distiller1.fit(x=x_train, y=y_train, epochs=1)

    # Test-set performance of the model trained by `distiller1`.
    distiller1.evaluate(x=x_test, y=y_test)
    print("---------------------------------------------------------------------------------------------------------")

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
---------------------------------------------------------------------------------------------------------
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
--------------