In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import tensorflow as tf
import datetime
import os

import numpy as np
import matplotlib.pyplot as plt

### Prepare Datasets

 * Load the train dataset and split it into training (90%) and validation (10%) sets
 * Load the test dataset
 * Augment the training data by appending nine copies of it with random flip, rotation, translation and zoom operations applied

In [2]:
# Options shared by every split: class-per-subdirectory labels, one-hot
# encoded, images resized to 224x224.
loader_options = dict(
    labels='inferred',
    label_mode='categorical',
    image_size=(224,224),
)

# Fixed seed so the 90/10 train/validation partition is identical on every run.
train, validation = tf.keras.utils.image_dataset_from_directory(
    '../Dataset/train',
    seed=815,
    validation_split=0.1,
    subset='both',
    **loader_options,
)

test = tf.keras.utils.image_dataset_from_directory('../Dataset/test', **loader_options)

# Temporary helper only; keep the notebook namespace clean.
del loader_options

def prepare(dataset, copies=9):
    """Return `dataset` followed by `copies` randomly augmented passes over it.

    Each appended pass maps the same batches through a random flip (horizontal
    and vertical), rotation, translation and zoom. One epoch of the returned
    dataset therefore visits every original batch ``copies + 1`` times: once
    unmodified and ``copies`` times augmented. The result is prefetched.

    Args:
        dataset: A ``tf.data.Dataset`` of ``(images, labels)`` pairs.
        copies: Number of augmented passes to append. Defaults to 9, the value
            this notebook was tuned with. 0 returns the data unaugmented
            (still prefetched).

    Returns:
        The concatenated, prefetched ``tf.data.Dataset``.

    Raises:
        ValueError: If ``copies`` is negative.
    """
    if copies < 0:
        raise ValueError(f'copies must be >= 0, got {copies}')

    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.RandomRotation(0.5),
        tf.keras.layers.RandomTranslation(0.33, 0.33),
        tf.keras.layers.RandomZoom(0.33),
    ])

    def augment(images, labels):
        # training=True is required: in inference mode the Keras random
        # preprocessing layers return their input unchanged.
        return data_augmentation(images, training=True), labels

    augmented = dataset
    for _ in range(copies):
        augmented = augmented.concatenate(
            dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        )

    return augmented.prefetch(buffer_size=tf.data.AUTOTUNE)


# Replace the raw training split with its augmented, prefetched version
# (original batches followed by augmented copies); validation is left as is.
train = prepare(dataset=train)

Found 22573 files belonging to 25 classes.
Using 20316 files for training.
Using 2257 files for validation.
Found 2500 files belonging to 25 classes.


## Distiller class
Adapted from the Keras example [Knowledge Distillation](https://keras.io/examples/vision/knowledge_distillation).


In [3]:
# Adapted from https://keras.io/examples/vision/knowledge_distillation
class Distiller(tf.keras.Model):
    """Train a `student` model to mimic a fixed `teacher` (knowledge distillation).

    The training loss is
    ``alpha * student_loss + (1 - alpha) * T**2 * distillation_loss``.
    ``distillation_loss`` compares the temperature-softened teacher and student
    class distributions (https://arxiv.org/abs/1503.02531). Only the student's
    trainable variables receive gradients, and the teacher is always run with
    ``training=False``.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
        from_logits=False,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
            from_logits: Whether teacher and student output raw logits (True) or
                softmax probabilities (False, the default). With False, the
                outputs are converted to log-probabilities before temperature
                scaling. NOTE(review): both models must agree. The student built
                by `generate_model` ends in softmax; confirm the loaded teacher
                does too.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        self.from_logits = from_logits

    def call(self, inputs, training=None):
        """Forward pass through the student, so `predict` / `__call__` work."""
        return self.student(inputs, training=training)

    def _soften(self, predictions):
        """Return the temperature-softened class distribution of `predictions`.

        The Keras example divides *logits* by the temperature. Dividing softmax
        *probabilities* by T and applying softmax again yields an almost uniform
        distribution that carries little teacher signal. log(p) differs from
        the logits only by a per-sample constant, which softmax ignores, so
        softmax(log(p) / T) == softmax(logits / T).
        """
        if self.from_logits:
            logits = predictions
        else:
            # Clip so classes with zero probability give a large negative
            # value instead of -inf.
            logits = tf.math.log(tf.clip_by_value(predictions, 1e-7, 1.0))
        return tf.nn.softmax(logits / self.temperature, axis=-1)

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher (outside the tape: it is never updated)
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Hard-label loss against the ground truth
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    self._soften(teacher_predictions),
                    self._soften(student_predictions),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients for the student only
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss (reported as `val_student_loss` during `fit`)
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results


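### Create model
 * Batch normalization of the input
 * Convolutional blocks, each with
   * a residual shortcut (1×1 convolution) added to the main path
   * batch normalization after each convolutional layer
   * max pooling after the convolutional layers (2×2 with stride 2 in the configuration used below), followed by batch normalization
 * Flatten
 * Dense layers with ReLU activation
 * Output layer with softmax activation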


In [4]:
def generate_model(conv_layers, dense_layers, num_classes=25, input_shape=(224, 224, 3)):
    """Build a residual CNN classifier as an uncompiled functional ``tf.keras.Model``.

    Architecture:
      * BatchNormalization on the input.
      * One residual block per ``(filters, kernel_size, depth, pool)`` entry of
        ``conv_layers``:
          - shortcut: 1x1 ReLU Conv2D with ``filters`` channels, then
            BatchNormalization;
          - main path: ``depth`` repetitions of a ReLU Conv2D
            (``kernel_size``, 'same' padding) followed by BatchNormalization;
          - Add(shortcut, main path), then ReLU;
          - MaxPool2D(``pool``) when ``pool > 1``;
          - BatchNormalization.
      * Flatten, then one ReLU Dense layer per entry of ``dense_layers``.
      * Softmax Dense output with ``num_classes`` units.

    Args:
        conv_layers: Iterable of ``(filters, kernel_size, depth, pool)`` tuples.
        dense_layers: Iterable of unit counts for the hidden Dense layers.
        num_classes: Size of the softmax output. Defaults to 25, the number of
            class directories found in this notebook's dataset.
        input_shape: Shape of a single input image. Defaults to (224, 224, 3),
            which matches ``image_size=(224,224)`` used when loading the data.

    Returns:
        The ``tf.keras.Model`` mapping images to class probabilities.
    """
    img_input = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.BatchNormalization()(img_input)

    for filters, kernel_size, depth, pool in conv_layers:
        # 1x1 projection: brings the shortcut to `filters` channels without
        # changing its spatial size, so it can be added to the main path.
        shortcut = tf.keras.layers.Conv2D(filters, 1, activation='relu')(x)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)

        for _ in range(depth):
            # 'same' padding keeps the main path's spatial size equal to the shortcut's.
            x = tf.keras.layers.Conv2D(
                filters,
                kernel_size,
                activation='relu',
                padding='same'
            )(x)
            x = tf.keras.layers.BatchNormalization()(x)

        x = tf.keras.layers.Add()([shortcut, x])
        x = tf.keras.layers.Activation('relu')(x)

        if pool > 1:
            x = tf.keras.layers.MaxPool2D(pool_size=pool)(x)

        x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Flatten()(x)

    for units in dense_layers:
        x = tf.keras.layers.Dense(units, activation='relu')(x)

    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(img_input, outputs)

### Run
 * Load the teacher model architecture and weights using the supplied teacher id
 * Instantiate the student model with the parameters set below
 * Save the student model JSON to file
 * Initialize the learning rate scheduler
 * Instantiate and compile the distiller
 * Fit the student model on the training dataset using the distiller, with
   * an early stopping callback
   * a TensorBoard callback
 * Save the trained student weights to a checkpoint

In [5]:
def run(i, id, teacher_id, epochs=10):
    """Distill the saved teacher `teacher_id` into a freshly built student.

    Steps:
      1. Rebuild the teacher from ``models/{teacher_id}.json`` and load its
         weights from ``checkpoints/{teacher_id}``.
      2. Build the student with ``generate_model`` and write its architecture
         to ``models/{id}-{i}.json``.
      3. Train the student through ``Distiller`` on the notebook-level
         ``train`` / ``validation`` datasets. Logs go to ``logs/fit/{id}-{i}``
         for TensorBoard, and training stops early when ``val_student_loss``
         stops improving (best weights are restored).
      4. Save the student's weights to ``checkpoints/{id}-{i}``.

    Args:
        i: Index of this run within a batch; suffixes every output path.
        id: Run identifier shared by the batch (a timestamp in this notebook).
            Shadows the builtin ``id`` inside this function.
        teacher_id: Identifier of a previously saved teacher model.
        epochs: Maximum number of training epochs. Defaults to 10.
    """
    with open(f'models/{teacher_id}.json', 'r') as json_file:
        teacher_json = json_file.read()

    teacher = tf.keras.models.model_from_json(teacher_json)
    teacher.load_weights(f'checkpoints/{teacher_id}')

    student = generate_model(
        [
            # filters, kernel_size, depth, pool
            (16, 3, 2, 2),
            (32, 3, 2, 2),
            (64, 3, 2, 2),
        ],
        [
            # units
            96,
            128,
        ])

    with open(f'models/{id}-{i}.json', 'w') as f:
        f.write(student.to_json())

    # The learning rate is multiplied by 0.67 for every 6350 optimizer steps
    # (smoothly, since staircase is off).
    # NOTE(review): 6350 looks like one epoch of `train`: 20316 images at the
    # default batch size of 32 is 635 batches, times the 10 passes produced by
    # `prepare`. Confirm before changing batch size or augmentation copies.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.002,
        decay_steps=6350,
        decay_rate=0.67
    )

    distiller = Distiller(student=student, teacher=teacher)
    distiller.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        student_loss_fn=tf.keras.losses.CategoricalCrossentropy(),
        metrics=['accuracy'],
        distillation_loss_fn=tf.keras.losses.KLDivergence(),
        alpha=0.1,
        temperature=5,
    )

    distiller.fit(
        train,
        validation_data=validation,
        epochs=epochs,
        callbacks=[
            # With the default patience of 0, training stops at the first epoch
            # whose val_student_loss does not improve. restore_best_weights rolls
            # the weights back to the best epoch, so the checkpoint saved below
            # is not the degraded epoch that triggered the stop.
            tf.keras.callbacks.EarlyStopping(
                monitor='val_student_loss',
                restore_best_weights=True,
            ),
            tf.keras.callbacks.TensorBoard(
                log_dir=f'logs/fit/{id}-{i}'
            ),
        ]
    )

    student.save_weights(f'checkpoints/{id}-{i}')

In [6]:
# One timestamp names every output (model JSON, TensorBoard logs, checkpoints)
# of this batch of runs; each run adds its own index suffix.
now = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
for i in range(1):
    run(i, id=now, teacher_id='20230530-223331')

Epoch 1/10