# Knowledge distillation

In [1]:
import pathlib
import shutil
import os
import time
import datetime
import numpy as np
import wandb
from wandb.keras import WandbCallback
import tensorflow as tf
import matplotlib.pyplot as plt
import zipfile

from typing import List

def load_data(run, artifact_name = "phcd_paper_splits_tfds") -> List[tf.data.Dataset]:
    """
    Downloads datasets from a wandb artifact and loads them into a list of tf.data.Datasets.

    Args:
        run: wandb run used to declare the artifact as an input (`run.use_artifact`).
        artifact_name: Name of the artifact inside the "master-thesis" project.
            The ":latest" alias is always requested.

    Returns:
        The loaded splits in the order [train, test, val].
    """

    artifact = run.use_artifact(f"master-thesis/{artifact_name}:latest")
    artifact_dir = pathlib.Path(
        f"./artifacts/{artifact.name.replace(':', '-')}"
    ).resolve()
    # Skip the download when a local copy of the artifact directory is already present.
    if not artifact_dir.exists():
        artifact_dir = pathlib.Path(artifact.download()).resolve()

    # Pick the loader by feature detection rather than by parsing the minor
    # version number only: a minor-only comparison gives the wrong answer for
    # any major version other than 2 (e.g. "1.15" would select Dataset.load,
    # which does not exist there).
    if hasattr(tf.data.Dataset, "load"):
        load_function = tf.data.Dataset.load
    else:
        load_function = tf.data.experimental.load

    return [
        load_function(str(artifact_dir / split), compression="GZIP")
        for split in ("train", "test", "val")
    ]

def get_readable_class_labels(subset = 'phcd_paper'):
    """
    Returns the human-readable character labels for a dataset subset.

    The position of a label in the returned list is its class index.
    Supported subsets: 'phcd_paper', 'uppercase', 'lowercase', 'numbers',
    'uppercase_no_diacritics' and 'lowercase_no_diacritics'.
    Any other subset name yields None.
    """
    # Building blocks; they are rebuilt on every call so callers always get a fresh list.
    digits = list('0123456789')
    lower_ascii = list('abcdefghijklmnopqrstuvwxyz')
    upper_ascii = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    lower_diacritics = ['ą', 'ć', 'ę', 'ł', 'ń', 'ó', 'ś', 'ź', 'ż']
    upper_diacritics = ['Ą', 'Ć', 'Ę', 'Ł', 'Ń', 'Ó', 'Ś', 'Ź', 'Ż']
    symbols = ['+', '-', ':', ';', '$', '!', '?', '@', '.']

    labels_by_subset = {
        'phcd_paper': (
            digits + lower_ascii + upper_ascii
            + lower_diacritics + upper_diacritics + symbols
        ),
        'uppercase': upper_ascii + upper_diacritics,
        'lowercase': lower_ascii + lower_diacritics,
        'numbers': digits,
        'uppercase_no_diacritics': upper_ascii,
        'lowercase_no_diacritics': lower_ascii,
    }
    return labels_by_subset.get(subset)

def calculate_accuracy_per_class(model, test_dataset, test_dataset_name):
    '''
    Calculates the accuracy per class for a given model and test dataset.

    Args:
        model: Object with a `predict(dataset)` method returning one score
            vector per example (the arg-max over axis 1 is the predicted class).
        test_dataset: Dataset of (input, label) pairs; labels are gathered by
            iterating it once more. Assumes label batches are at least 1-D
            integer class indices -- TODO confirm against the data pipeline.
        test_dataset_name: Subset name understood by `get_readable_class_labels`.

    Returns dict with class labels as keys and accuracy as values.
    Classes that have no samples in the test dataset get NaN.

    Raises:
        ValueError: if `test_dataset_name` is not a known subset.
    '''
    labels = get_readable_class_labels(test_dataset_name)
    if labels is None:
        # Previously this surfaced as an opaque TypeError on len(None).
        raise ValueError(f"Unknown dataset subset: {test_dataset_name!r}")

    y_pred = np.argmax(model.predict(test_dataset), axis=1)
    # get labels
    y_true = np.concatenate(list(test_dataset.map(lambda x, y: y).as_numpy_iterator()))

    # calculate accuracy per class; NaN marks classes without any test sample
    # (explicit instead of relying on a 0/0 division and its RuntimeWarning)
    class_accuracy = np.full(len(labels), np.nan)
    for class_index in range(len(labels)):
        is_class = y_true == class_index
        support = np.sum(is_class)
        if support > 0:
            class_accuracy[class_index] = np.sum(y_pred[is_class] == class_index) / support
    return dict(zip(labels, class_accuracy))
    

def plot_accuracy_per_class(class_accuracy_dict):
    """
    Shows a bar chart with one bar per class.

    `class_accuracy_dict` maps class label -> accuracy; bars appear in the
    mapping's iteration order.
    """
    class_names = list(class_accuracy_dict.keys())
    accuracies = list(class_accuracy_dict.values())

    plt.figure(figsize=(10, 5))
    plt.bar(class_names, accuracies)
    plt.xticks(class_names)
    plt.xlabel("Class")
    plt.ylabel("Accuracy")
    plt.title("Accuracy per class")
    plt.show()


def accuracy_table(class_accuracy_dict):
    """
    Converts a {class label: accuracy} mapping into a two-column wandb.Table
    ("Class", "Accuracy") with one row per class.
    """
    rows = [(label, accuracy) for label, accuracy in class_accuracy_dict.items()]
    return wandb.Table(columns=["Class", "Accuracy"], data=rows)

def get_number_of_classes(ds: tf.data.Dataset) -> int:
    """
    Returns the number of distinct label values found in a dataset.

    The dataset is iterated once in full. Only labels that actually occur are
    counted, so a class missing from `ds` does not contribute.
    """
    label_batches = [
        batch for batch in ds.map(lambda x, y: y).as_numpy_iterator()
    ]
    distinct_labels = np.unique(np.concatenate(label_batches))
    return len(distinct_labels)

def get_number_of_examples(ds: tf.data.Dataset) -> int:
    """
    Returns the number of elements produced by iterating over a dataset.

    NOTE(review): one element is counted per iteration step, so for a batched
    dataset this is presumably the number of batches -- confirm with callers.
    """
    count = 0
    for _ in ds:
        count += 1
    return count

def preprocess_dataset(ds: tf.data.Dataset, batch_size: int, cache: bool = True) -> tf.data.Dataset:
    """
    Normalizes inputs and re-batches a dataset of (input, label) pairs.

    Inputs are cast to float32 and divided by 255; labels pass through
    unchanged. The dataset is then unbatched and batched again with
    `batch_size`. When `cache` is true the result is additionally cached and
    prefetched with an auto-tuned buffer.
    """
    def _scale_inputs(x, y):
        return (tf.cast(x, tf.float32) / 255.0, y)  # normalize

    rebatched = ds.map(_scale_inputs).unbatch().batch(batch_size)
    if not cache:
        return rebatched
    return rebatched.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

def calculate_model_compressed_size_on_disk(path: str) -> int:
    """
    Zips the file at `path` with DEFLATE and returns the archive size in bytes.

    The archive is written next to the input as `path + ".zip"` (an existing
    archive of that name is overwritten) and is left on disk afterwards.
    """
    archive_path = path + ".zip"
    archive = zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED)
    with archive:
        archive.write(path)
    archive_info = pathlib.Path(archive_path).stat()
    return archive_info.st_size

def calculate_model_num_parameters(model: tf.keras.Model) -> int:
    """Returns the parameter count reported by `model.count_params()`."""
    total_params = model.count_params()
    return total_params

def calculate_model_flops(summary) -> float:
    """
    Reads the FLOPs figure from a run summary.

    The summary is looked up under "GFLOPs" first and "GFLOPS" second; the
    first key that is present wins. Returns 0 when neither key exists.
    """
    available_keys = summary.keys()
    for candidate in ("GFLOPs", "GFLOPS"):
        if candidate in available_keys:
            return summary.get(candidate)
    return 0

def calculate_model_throughput(model, test_dataset, batch_size, num_runs=50, num_batches=100) -> float:
    '''
    Calculates the average prediction throughput of a model, in images per second.

    The model predicts on the first `num_batches` batches of `test_dataset`
    `num_runs` times; each run is timed separately and the per-run
    throughputs are averaged.

    Args:
        model: Object exposing `predict(dataset)`.
        test_dataset: Batched dataset exposing `take(n)`.
        batch_size: Number of images per batch of `test_dataset`.
        num_runs: How many timed repetitions to average over (default 50).
        num_batches: How many batches are predicted per repetition (default 100).

    Returns:
        Mean of the per-run throughputs.

    Note: every run is credited with `num_batches * batch_size` images, so the
    result is overstated if the dataset holds fewer (or partially filled) batches.
    '''
    throughputs = []
    for _ in range(num_runs):
        # perf_counter is monotonic and high-resolution; wall-clock time.time()
        # can jump or return identical values for short intervals.
        start = time.perf_counter()
        model.predict(test_dataset.take(num_batches))
        elapsed = time.perf_counter() - start
        throughputs.append((num_batches * batch_size) / elapsed)
    return np.mean(throughputs)

def plot_history(history, title):
    """
    Draws training curves side by side: accuracy on the left, loss on the right.

    `history` must expose a `history` mapping containing 'accuracy',
    'val_accuracy', 'loss' and 'val_loss' series; `title` becomes the figure
    super-title. The figure is created but not shown explicitly.
    """
    plt.figure(figsize=(15,7))
    plt.suptitle(title)

    # (subplot position, metric name) for the two panels
    for position, metric in ((121, 'accuracy'), (122, 'loss')):
        plt.subplot(position)
        plt.plot(history.history[metric], label='train')
        plt.plot(history.history['val_' + metric], label='val')
        plt.ylabel(metric)
        plt.xlabel('epoch')
        plt.legend()

In [2]:
# Report the devices TensorFlow can use. Both lines go through the stable
# tf.config.list_physical_devices API (the tf.config.experimental alias is deprecated).
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("Available devices: ", tf.config.list_physical_devices())

Num GPUs Available:  0
Available devices:  [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


# Baseline model

In [3]:
# Default hyperparameters; registered as the config of the wandb run below.
defaults = {
    "batch_size": 32 * 2,
    "epochs": 60,
    "optimizer": "adam",
}

artifact_base_name = "phcd_paper"
artifact_name = f"{artifact_base_name}_splits_tfds"  # i.e. "phcd_paper_splits_tfds"
model_name = f"architecture-1-distilled-{artifact_base_name}"
run = wandb.init(
    project="master-thesis",
    job_type="training",
    name=model_name,
    config=defaults,
    tags=[artifact_name],
)

# hyperparameters (read back from the run config)
epochs = wandb.config.epochs
bs = wandb.config.batch_size

ds_train, ds_test, ds_val = load_data(run, artifact_name=artifact_name)

# Classes are counted on the validation split as loaded, before preprocessing.
num_classes = get_number_of_classes(ds_val)

# Train and validation splits are cached; the test split is not.
ds_train = preprocess_dataset(ds_train, batch_size=bs)
ds_val = preprocess_dataset(ds_val, batch_size=bs)
ds_test = preprocess_dataset(ds_test, batch_size=bs, cache=False)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgratkadlafana[0m. Use [1m`wandb login --relogin`[0m to force relogin


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


# Download the teacher model & define the student model

In [26]:
api = wandb.Api()
runs = api.runs("gratkadlafana/master-thesis")
# Look up the baseline run by name (raises IndexError when no run matches).
baseline_run = [r for r in runs if r.name == "baseline-phcd_paper"][0]

# Fetch the teacher weights from a fixed run id. A dedicated name is used on
# purpose: rebinding the global `run` here would replace the active training
# run created by wandb.init() with a read-only API object and break the later
# `run.finish()` call.
teacher_run = api.run("gratkadlafana/master-thesis/297mmd8r")
model_baseline = teacher_run.file("model_baseline.h5").download(replace=True)
teacher_model = tf.keras.models.load_model(model_baseline.name)



In [13]:
# Small CNN student for 32x32 single-channel inputs with 89 output classes.
# The final layer emits raw logits (no softmax): every loss in this notebook is
# built with from_logits=True and the Distiller applies the temperature-scaled
# softmax itself, so a softmax here would be applied twice.
# NOTE(review): model.predict() therefore returns logits; arg-max based
# evaluation is unaffected, apply tf.nn.softmax when probabilities are needed.
student_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(32, 32, 1)),
    tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(89),
])
student_model.summary()

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_22 (Conv2D)          (None, 32, 32, 16)        160       
                                                                 
 max_pooling2d_22 (MaxPoolin  (None, 16, 16, 16)       0         
 g2D)                                                            
                                                                 
 conv2d_23 (Conv2D)          (None, 16, 16, 32)        4640      
                                                                 
 max_pooling2d_23 (MaxPoolin  (None, 8, 8, 32)         0         
 g2D)                                                            
                                                                 
 flatten_11 (Flatten)        (None, 2048)              0         
                                                                 
 dense_19 (Dense)            (None, 64)              

# Distiller class

In [None]:
# Copy of the student architecture, taken before distillation, that is later
# trained from scratch as the comparison model.
# NOTE(review): relies on clone_model creating new layers with their own
# weights rather than sharing them with student_model -- confirm in the Keras docs.
student_scratch = tf.keras.models.clone_model(student_model)

class Distiller(tf.keras.Model):
    """
    Keras model wrapper that trains a student network to imitate a teacher.

    Training loss per batch:
        alpha * student_loss_fn(y, student(x))
        + (1 - alpha) * distillation_loss_fn(softmax(teacher(x) / T), softmax(student(x) / T))

    Only the student's trainable variables are updated. Both networks' outputs
    are passed through a temperature-scaled softmax, i.e. they are treated as logits.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.05,
        temperature=1
    ):
        """
        Configures the distiller.

        Args:
            optimizer: Optimizer applied to the student's weights.
            metrics: Metrics updated with (labels, student predictions).
            student_loss_fn: Loss between true labels and student predictions.
            distillation_loss_fn: Loss between softened teacher and student predictions.
            alpha: Weight of the student loss; the distillation loss gets 1 - alpha.
            temperature: Divisor applied to both outputs before the softmax.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Runs one optimization step on the student and returns the metric logs."""
        inputs, labels = data

        # The teacher only provides targets; it runs in inference mode outside the tape.
        teacher_out = self.teacher(inputs, training=False)

        with tf.GradientTape() as tape:
            student_out = self.student(inputs, training=True)

            # Hard-label loss plus softened teacher/student agreement.
            hard_loss = self.student_loss_fn(labels, student_out)
            soft_targets = tf.nn.softmax(teacher_out / self.temperature, axis=1)
            soft_predictions = tf.nn.softmax(student_out / self.temperature, axis=1)
            soft_loss = self.distillation_loss_fn(soft_targets, soft_predictions)
            loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        # Gradients are taken with respect to the student's variables only.
        student_vars = self.student.trainable_variables
        grads = tape.gradient(loss, student_vars)
        self.optimizer.apply_gradients(zip(grads, student_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(labels, student_out)

        logs = {metric.name: metric.result() for metric in self.metrics}
        logs["student_loss"] = hard_loss
        logs["distillation_loss"] = soft_loss
        return logs

    def test_step(self, data):
        """Evaluates the student alone on one batch; the teacher is not involved."""
        inputs, labels = data

        student_out = self.student(inputs, training=False)
        hard_loss = self.student_loss_fn(labels, student_out)

        self.compiled_metrics.update_state(labels, student_out)

        logs = {metric.name: metric.result() for metric in self.metrics}
        logs["student_loss"] = hard_loss
        return logs


# Distill small student

In [None]:
# Initialize and compile distiller
distiller = Distiller(student=student_model, teacher=teacher_model)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Train for the number of epochs read from the wandb run config (`epochs`)
# instead of a second hard-coded value, so the logged config always matches
# what was actually run.
history_distiller = distiller.fit(
    ds_train,
    epochs=epochs,
    validation_data=ds_val,
    callbacks=[
        WandbCallback(
            compute_flops=True, 
            save_model=False, 
            log_weights=False, 
            log_gradients=False
        )
    ],
)

# The student was trained through the distiller; compile it on its own so that
# it can be evaluated directly afterwards.
student_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

tf.keras.models.save_model(student_model, 'model_distilled.h5', include_optimizer=False)

In [None]:
# evaluate model on ds_test and log to wandb
test_loss_before, test_acc_before = student_model.evaluate(ds_test)

# calculate model size on disk, flops and number of parameters
flops = calculate_model_flops(wandb.run.summary)
num_parameters = calculate_model_num_parameters(student_model)
compressed_disk_size = calculate_model_compressed_size_on_disk('model_distilled.h5')
class_acc = calculate_accuracy_per_class(student_model, ds_test, artifact_base_name)

data_to_log = {
    "test_loss": test_loss_before,
    "test_accuracy": test_acc_before,
    "num_parameters": num_parameters,
    "compressed_disk_size": compressed_disk_size,
    "model_flops": flops,
    "accuracy_per_class": accuracy_table(class_acc)
    }
plot_accuracy_per_class(class_acc)
print(data_to_log)
wandb.log(data_to_log)

# upload model to wandb
wandb.save('model_distilled.h5')
# Close the active run through the module-level API: the global `run` is
# reassigned by the teacher-download cell and may no longer refer to the
# training run started with wandb.init().
wandb.finish()

# Train the same architecture from scratch

In [None]:
# Start a separate wandb run for the from-scratch comparison model.
model_name = f"architecture-1-scratch-{artifact_base_name}"
run = wandb.init(project="master-thesis", job_type="training", name=model_name, config=defaults, tags=[artifact_name])
student_scratch.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
history_scratch = student_scratch.fit(
    ds_train,
    # epoch count comes from this run's config rather than a duplicated literal
    epochs=wandb.config.epochs,
    validation_data=ds_val,
    callbacks=[
        WandbCallback(
            compute_flops=True, 
            save_model=False, 
            log_weights=False, 
            log_gradients=False
        )
    ],
)
# Save the model that was just trained. The previous code passed `distiller`
# here, so 'model_scratch.h5' never contained the from-scratch student.
tf.keras.models.save_model(student_scratch, 'model_scratch.h5', include_optimizer=False)

In [None]:
# Hold-out evaluation of the from-scratch student; results go to the current wandb run.
scratch_eval = student_scratch.evaluate(ds_test)
test_loss_before, test_acc_before = scratch_eval

# Efficiency figures: FLOPs from the run summary, parameter count, zipped file size.
flops = calculate_model_flops(wandb.run.summary)
num_parameters = calculate_model_num_parameters(student_scratch)
compressed_disk_size = calculate_model_compressed_size_on_disk('model_scratch.h5')

# Per-class accuracy, logged as a table and plotted below.
class_acc = calculate_accuracy_per_class(student_scratch, ds_test, artifact_base_name)

data_to_log = dict(
    test_loss=test_loss_before,
    test_accuracy=test_acc_before,
    num_parameters=num_parameters,
    compressed_disk_size=compressed_disk_size,
    model_flops=flops,
    accuracy_per_class=accuracy_table(class_acc),
)
plot_accuracy_per_class(class_acc)
print(data_to_log)
wandb.log(data_to_log)

# Upload the saved model file and close the run opened for the scratch training.
wandb.save('model_scratch.h5')
run.finish()

# Plot training

In [None]:
# Validation accuracy per epoch: from-scratch student vs. distilled student.
plt.figure(figsize=(20, 10))
for curve_label, curve_history in (
    ("scratch", history_scratch),
    ("distilled", history_distiller),
):
    plt.plot(curve_history.history["val_accuracy"], label=curve_label)
plt.legend()