# 1- Importing necessary libraries

In [None]:
import os
import pickle
import zipfile
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from shutil import copyfile
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    BatchNormalization, Conv2D, MaxPooling2D, Activation, Flatten, Dropout, Dense
)
from tensorflow.keras import backend as K
from sklearn.preprocessing import LabelBinarizer
from keras.preprocessing import image
import matplotlib.pyplot as plt

# 3- Creating Distiller Class 

In [None]:
class Distiller(keras.Model):
    """Keras model that distills a teacher model into a student model.

    Only the student's trainable variables are updated during ``fit``; the
    teacher is run with ``training=False`` to produce soft targets.
    """

    def __init__(self, student, teacher):
        """Wrap the two models.

        Args:
            student: Keras model whose weights are trained.
            teacher: Keras model whose predictions act as soft targets.
        """
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights.
            metrics: Keras metrics for evaluation.
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth.
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions.
            alpha: Weight given to student_loss_fn; 1 - alpha is given to
                distillation_loss_fn.
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions. The
                distillation loss is multiplied by temperature**2 so its
                gradients keep a comparable magnitude as temperature grows.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step of the student on a batch.

        Args:
            data: An ``(x, y)`` batch. ``y`` must match what student_loss_fn
                expects (presumably one-hot labels here -- confirm per
                data source).

        Returns:
            Dict of the compiled metrics plus this batch's ``student_loss``
            and (temperature-scaled) ``distillation_loss``.
        """
        x, y = data

        # Teacher is only used for targets, so run it in inference mode and
        # outside the tape (no gradients flow into it).
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)

            student_loss = self.student_loss_fn(y, student_predictions)
            # Soften both outputs with the same temperature. Gradients from
            # soft targets scale as 1/T^2, so multiply by T^2 (Hinton et al.)
            # to keep the two loss terms on a comparable footing.
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            ) * (self.temperature ** 2)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Only the student is optimised.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student alone on a batch.

        Args:
            data: An ``(x, y)`` batch.

        Returns:
            Dict of the compiled metrics plus this batch's ``student_loss``.
        """
        x, y = data

        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)

        self.compiled_metrics.update_state(y, y_prediction)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    def call(self, data, training=False):
        """Forward ``data`` through the student.

        ``fit``/``evaluate`` go through ``train_step``/``test_step`` and do
        not need this, but it makes ``distiller(x)`` and
        ``distiller.predict(x)`` return the student's predictions instead
        of ``None``.
        """
        return self.student(data, training=training)
        

In [None]:
!pip install split-folders
import splitfolders
# Copy the PlantVillage "color" images into ./output, split 80/10/10 into three
# sub-folders (train/val/test per the split-folders docs; later cells read
# ./output/train). The fixed seed makes the split reproducible.
splitfolders.ratio('../input/plantvillage-dataset/color', output="output", seed=1337, ratio=(.8, 0.1,0.1))

# 4- Flowing Images from directory
In step 4, the training and validation datasets are streamed from their directories into variables called train_generator and validation_generator. We also specify the batch size and resize all the images to 250 by 250 px.

In [None]:
# Pixel values are rescaled from [0, 255] to [0, 1]; images are resized to
# 250x250, served in batches of 250, with one-hot ("categorical") labels.
TRAINING_DIR = "./output/train"
train_datagen = ImageDataGenerator(rescale=1.0 / 255)
train_generator = train_datagen.flow_from_directory(
    TRAINING_DIR,
    class_mode="categorical",
    batch_size=250,
    target_size=(250, 250),
)

# Monitor training on the "val" split produced by splitfolders; the "test"
# split stays untouched so it can give an unbiased final evaluation.
VALIDATION_DIR = "./output/val"
validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DIR,
    class_mode="categorical",
    batch_size=250,
    target_size=(250, 250),
)

# Sample Images From the Dataset
The code below defines a function plotImages that plots sample images from the dataset using matplotlib's imshow() method. A batch from the ImageDataGenerator iterator train_generator is used as input, and its first five images are displayed.

In [None]:
def plotImages(images_arr):
    """Show images side by side in a single row using matplotlib.

    Args:
        images_arr: Iterable of images accepted by ``Axes.imshow`` (e.g.
            HxWx3 arrays). One subplot is created per image.
    """
    images = list(images_arr)
    if not images:
        # plt.subplots(1, 0) raises; there is nothing to draw anyway.
        return
    # squeeze=False keeps ``axes`` an array even when there is one image.
    _, axes = plt.subplots(1, len(images), figsize=(20, 20), squeeze=False)
    for img, ax in zip(images, axes.flatten()):
        ax.imshow(img)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    
    
# Load the first training batch once and show its first five images.
# (The previous version indexed train_generator[0][0][0] five times, which
# reloaded the batch each time and displayed the same image five times.)
sample_images, sample_labels = train_generator[0]
augmented_images = list(sample_images[:5])
plotImages(augmented_images)

# 5- Getting Categorical Labels
We obtain the label associated with each training image (each class corresponds to one directory name in the dataset) and assign the labels to a variable called train_y.

In [None]:
# Per-image class labels as assigned by the directory iterator.
train_y = train_generator.classes


# 6- Converting categorical labels to binary otherwise called (Encoding)
LabelBinarizer is a scikit-learn class that accepts categorical data as input and returns a NumPy array. Unlike LabelEncoder, it encodes the data into dummy (one-hot) variables that indicate the presence or absence of each label, producing one column per class.

In [None]:
# One-hot encode the per-image labels and save the fitted encoder
# (presumably so predictions can be mapped back to class names later).
label_binarizer = LabelBinarizer()
image_labels = label_binarizer.fit_transform(train_y)
# Use a context manager so the pickle file is flushed and closed.
with open('label_transform.pkl', 'wb') as label_file:
    pickle.dump(label_binarizer, label_file)
n_classes = len(label_binarizer.classes_)

## 7- Create student and teacher models

Initially, we create a teacher model and a smaller student model using Keras. Both models are
convolutional neural networks created using `Sequential()`,
but they could be any Keras model.

In [None]:
# Create the teacher: a deeper CNN whose final Dense layer has no activation,
# so it outputs raw scores (it is trained with from_logits=True losses).
inputShape = (250, 250, 3)
chanDim = -1
if K.image_data_format() == "channels_first":
    # With channels-first data the colour axis comes first, so both the
    # input shape and the BatchNormalization axis change.
    inputShape = (3, 250, 250)
    chanDim = 1

teacher_layers = [
    Conv2D(32, (3, 3), padding="same", input_shape=inputShape),
    Activation("relu"),
    BatchNormalization(axis=chanDim),
    MaxPooling2D(pool_size=(3, 3)),
    Dropout(0.25),
]
# Two stages (64 then 128 filters), each: two conv/ReLU/BN units, then 2x2
# max-pooling and dropout.
for n_filters in (64, 128):
    for _unit in range(2):
        teacher_layers += [
            Conv2D(n_filters, (3, 3), padding="same"),
            Activation("relu"),
            BatchNormalization(axis=chanDim),
        ]
    teacher_layers += [MaxPooling2D(pool_size=(2, 2)), Dropout(0.25)]
# Classifier head producing one score per class.
teacher_layers += [
    Flatten(),
    Dense(1024),
    Activation("relu"),
    BatchNormalization(),
    Dropout(0.5),
    Dense(n_classes),
]
teacher = keras.Sequential(teacher_layers, name="teacher")


# Create the student: a much smaller CNN (two strided conv stages) whose
# final Dense layer also has no activation.
student_layers = [keras.Input(shape=(250, 250, 3))]
for n_filters, pool_strides in ((16, (1, 1)), (32, (2, 2))):
    student_layers += [
        layers.Conv2D(n_filters, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=pool_strides, padding="same"),
    ]
student_layers += [layers.Flatten(), layers.Dense(n_classes)]
student = keras.Sequential(student_layers, name="student")

# Clone student for later comparison: same architecture, new weights,
# trained later without the teacher.
student_scratch = keras.models.clone_model(student)

# 8- Displaying the model summary
The built-in Keras method `summary()` displays the architecture of the models defined above, including the number of parameters in each model.

In [None]:
# Print the teacher's layers, output shapes and parameter counts.
teacher.summary()

In [None]:
# Print the student's layers, output shapes and parameter counts.
student.summary()

# 9- Compiling Teacher Model for training
The teacher model is compiled first with the standard Keras compile method, so that it can be trained on its own to high accuracy before its knowledge is distilled into the student.

In [None]:
# Train teacher as usual: Adam plus categorical cross-entropy on the raw
# (pre-softmax) outputs.
teacher_optimizer = keras.optimizers.Adam()
teacher_loss = keras.losses.CategoricalCrossentropy(from_logits=True)
teacher.compile(optimizer=teacher_optimizer, loss=teacher_loss, metrics=['acc'])


## 10- Train the teacher

During training Step, the model sifts through preexisting data and draws conclusions based on what it “thinks” the data represents. Every time it comes to an incorrect conclusion, that result is fed back to the system so that it “learns” from its mistake. 

This process makes connections between the artificial neurons stronger over time and increases the likelihood that the system will make accurate predictions in the future. As it’s presented with novel data, the DNN should be able to categorize and analyze new and possibly more complex information. Ultimately, it will continue to learn from its encounters and become more intuitive over time.
* To train a model, its fit() method is called. fit takes arguments such as the training data, validation data, number of epochs, and the number of training and validation steps.


In [None]:
# Fit on the training generator and report validation metrics after each epoch.
teacher_history = teacher.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=5,
)


**11- Plotting Teacher Model Accuracy and loss Curves**

In [None]:
import matplotlib.pyplot as plt

# Per-epoch values recorded by teacher.fit(); the "val_" entries come from
# validation_data.
teacher_metrics = teacher_history.history
acc = teacher_metrics['acc']
val_acc = teacher_metrics['val_acc']
loss = teacher_metrics['loss']
val_loss = teacher_metrics['val_loss']
epochs = range(1, len(acc) + 1)

# Train and validation accuracy
plt.figure()
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.xlabel('Epoch')
plt.legend()

# Train and validation loss
plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()

**12 Distilling teacher knowledge to student Model**

In [None]:
# Initialize and compile distiller. The student's loss mixes the hard-label
# loss (weight alpha) with the KL divergence to the teacher's
# temperature-softened outputs (weight 1 - alpha).
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['acc'],
    student_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
student_history = distiller.fit(
    train_generator, validation_data=validation_generator, epochs=3
)


# Plotting Student model Accuracy and loss curves
To compare the performance of the student model with that of the teacher model, the student's accuracy and loss curves are plotted below.

In [None]:
import matplotlib.pyplot as plt

# Per-epoch logs from distiller.fit(): train_step reports "acc",
# "student_loss" and "distillation_loss"; test_step on validation_data adds
# "val_acc" and "val_student_loss".
student_metrics = student_history.history
acc = student_metrics['acc']
val_acc = student_metrics['val_acc']
# Compare like with like: the student's hard-label loss on training vs
# validation data. (Previously the training curve was the distillation loss
# and the "validation" curve was actually the *training* student loss.)
loss = student_metrics['student_loss']
val_loss = student_metrics['val_student_loss']
distillation_loss = student_metrics['distillation_loss']
epochs = range(1, len(acc) + 1)

# Train and validation accuracy
plt.figure()
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.xlabel('Epoch')
plt.legend()

# Train and validation student (hard-label) loss
plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epoch')
plt.legend()

# Distillation loss exists only for training, so it gets its own plot.
plt.figure()
plt.plot(epochs, distillation_loss, 'g', label='Training distillation loss')
plt.title('Training distillation loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()

## Train student from scratch for comparison

We can also train an equivalent student model from scratch without the teacher, in order
to evaluate the performance gain obtained by knowledge distillation.

In [None]:
# Train student as done usually: same architecture, no teacher. This is the
# baseline for measuring the gain from distillation.
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['acc'],
)

# Train for the same number of epochs as the distilled student
# (distiller.fit uses epochs=3) so the comparison is fair, then evaluate on
# the validation data.
student_scratch_history = student_scratch.fit(train_generator, epochs=3)
# evaluate() returns the loss and metric values, not a History object.
std_history = student_scratch.evaluate(validation_generator)