# Sajad Rahmanian - 97101683
# Q2

In [None]:
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
import math
from tensorflow import keras
from tensorflow.keras import layers

# ResNet18
***TensorFlow doesn't have a standard implementation of ResNet18, so I used the implementation from [this link](https://github.com/jimmyyhwu/resnet18-tf2). I checked the layers of this network, and they match the standard torchvision implementation.***

In [None]:
# Fetch the third-party ResNet18 implementation and copy its resnet.py into
# /content (presumably the notebook's working directory) so that the later
# `from resnet import resnet18` can import it.
# NOTE(review): the clone is not pinned to a commit; pin it for reproducibility.
! git clone https://github.com/jimmyyhwu/resnet18-tf2.git
! cp '/content/resnet18-tf2/resnet.py' '/content/'

Cloning into 'resnet18-tf2'...
remote: Enumerating objects: 6, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 6 (delta 1), reused 6 (delta 1), pack-reused 0[K
Unpacking objects: 100% (6/6), 2.89 KiB | 2.89 MiB/s, done.


In [None]:
from resnet import resnet18

# Loading the CIFAR-10 Dataset

In [None]:
# Keras returns CIFAR-10 as (train, test) splits of images and integer labels;
# the official test split is used as the validation set in this notebook.
(training_images, training_labels), (validation_images, validation_labels) = (
    tf.keras.datasets.cifar10.load_data()
)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


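***Preprocessing the dataset for ResNet50 (`resnet50.preprocess_input` converts RGB to BGR and subtracts the ImageNet channel means)***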

In [None]:
def preprocess_image_input(input_images):
    """Prepare raw images for an ImageNet-pretrained Keras ResNet50.

    The array is cast to float32 and then passed through
    ``tf.keras.applications.resnet50.preprocess_input``.

    Args:
        input_images: NumPy image array (here, the CIFAR-10 images).

    Returns:
        The preprocessed float32 array.
    """
    as_float = input_images.astype('float32')
    return tf.keras.applications.resnet50.preprocess_input(as_float)

In [None]:
# Apply the ResNet50 preprocessing to the training split, then the validation split.
train_X, valid_X = (
    preprocess_image_input(split) for split in (training_images, validation_images)
)

# Linear probing: training only the classifier head of the ResNet50 model

In [None]:
def resnet50(x, num_classes=10, upsample_factor=7):
    """Attach an ImageNet-pretrained ResNet50 classifier to the symbolic input ``x``.

    The input is upsampled with bicubic interpolation (with the default factor
    of 7, 32x32 CIFAR images become 224x224). It is then fed through ResNet50
    without its top and globally average-pooled. Finally a Dense layer named
    ``'classifier'`` maps the pooled features to ``num_classes`` logits.

    Args:
        x: Keras input tensor, e.g. the ``(None, 32, 32, 3)`` input built by
            ``resnet50_model``.
        num_classes: number of output logits (default 10, for CIFAR-10).
        upsample_factor: spatial upsampling factor applied before ResNet50.

    Returns:
        Keras tensor of unnormalized logits with shape ``(batch, num_classes)``.
    """
    y = layers.UpSampling2D(size=(upsample_factor, upsample_factor),
                            interpolation='bicubic')(x)
    # Derive the backbone input shape from the upsampled tensor instead of
    # hard-coding (224, 224, 3), so other input sizes / factors also work.
    backbone_shape = tuple(y.shape[1:])
    feature_map = ResNet50(input_shape=backbone_shape, include_top=False,
                           weights='imagenet')(y)
    # GlobalAveragePooling2D already returns a flat (batch, channels) tensor,
    # so no extra Flatten layer is needed.
    y = layers.GlobalAveragePooling2D()(feature_map)
    return layers.Dense(num_classes, name='classifier')(y)

def resnet50_model():
    """Build and compile the CIFAR-10 ResNet50 classifier.

    Returns:
        A ``keras.Model`` mapping (32, 32, 3) images to the logits produced by
        ``resnet50``. It is compiled with the SGD optimizer, sparse categorical
        cross-entropy on logits, and accuracy as a metric.
    """
    cifar_input = layers.Input((32, 32, 3))
    model = keras.Model(inputs=cifar_input, outputs=resnet50(cifar_input))
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer='SGD', loss=loss_fn, metrics=['accuracy'])
    return model

In [None]:
# Linear probing: freeze the pretrained ResNet50 backbone so that only the
# Dense 'classifier' head is trained.
teacher_model = resnet50_model()
# Look the backbone up by its layer name ('resnet50', as shown in the model
# summary) instead of relying on its position in `teacher_model.layers`.
# NOTE(review): the Keras transfer-learning guide recommends calling compile()
# again after changing `trainable`; this model was compiled inside
# resnet50_model() before the freeze.
teacher_model.get_layer('resnet50').trainable = False
teacher_model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 224, 224, 3)      0         
 )                                                               
                                                                 
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 flatten (Flatten)           (None, 2048)         

In [None]:
# Confirm which devices (CPU/GPU) TensorFlow can see. This uses the public
# tf.config API rather than the private tensorflow.python.client.device_lib module.
print(tf.config.list_physical_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 5220552397384426829
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 14415560704
locality {
  bus_id: 1
  links {
  }
}
incarnation: 12960937943434928406
physical_device_desc: "device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5"
xla_global_id: 416903419
]


In [None]:
# Shared training hyper-parameters: epochs per fit() call and samples per batch.
EPOCHS, BATCH_SIZE = 3, 64

In [None]:
# Train the teacher (linear probe on the frozen backbone), evaluating on the
# CIFAR-10 test split after every epoch.
history = teacher_model.fit(
    train_X,
    training_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(valid_X, validation_labels),
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


***Freezing the teacher (making all of its layers non-trainable)***

In [None]:
# Freeze every top-level layer of the teacher so distillation never updates it.
for teacher_layer in teacher_model.layers:
    teacher_layer.trainable = False

# Knowledge Distillation Model

In [None]:
class TeacherStudent(keras.Model):
    """Knowledge-distillation wrapper that trains ``student`` against ``teacher``.

    Both networks are expected to output unnormalized logits. Each training
    step minimises

        alpha * CE(y, student)
        + (1 - alpha) * tau**2 * CE(softmax(teacher / tau), softmax(student / tau))

    Only the student's trainable variables are updated. Evaluation
    (``test_step``) uses the student alone.
    """

    def __init__(self, student_net, teacher_net):
        super().__init__()
        self.student = student_net
        self.teacher = teacher_net

    def compile(self, alpha=0.1, tau=3, optimizer=None):
        """Configure the optimizer, metrics and distillation loss.

        Args:
            alpha: weight of the hard-label loss; ``1 - alpha`` weights the
                soft-target (distillation) loss.
            tau: softmax temperature applied to teacher and student logits.
            optimizer: optimizer for the student; defaults to
                ``keras.optimizers.SGD()``.
        """
        if optimizer is None:
            optimizer = keras.optimizers.SGD()
        super().compile(optimizer=optimizer,
                        metrics=[keras.metrics.SparseCategoricalAccuracy()])
        self.student_loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        # from_logits=True: the y_pred passed to this loss must be
        # (temperature-scaled) logits; the loss applies the softmax itself.
        self.teacher_student_loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
        self.alpha = alpha
        self.tau = tau

    def augmented_loss_fn(self, student_loss, teacher_student_loss):
        """Blend the hard-label loss and the soft-target loss.

        The soft-target term is scaled by ``tau**2``. This keeps its gradient
        magnitude comparable across temperatures (Hinton et al., 2015).
        """
        loss_ = self.alpha * student_loss + (1 - self.alpha) * (self.tau ** 2) * teacher_student_loss
        return loss_

    def train_step(self, data):
        x, y = data
        # The teacher runs in inference mode, outside the tape, so no gradients
        # are recorded for it.
        teacher_predictions = self.teacher(x, training=False)
        soft_targets = tf.nn.softmax(teacher_predictions / self.tau, axis=1)
        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_predictions)
            # Pass the temperature-scaled student *logits*. Passing softmax
            # probabilities to a from_logits=True loss would apply softmax twice.
            teacher_student_loss = self.teacher_student_loss_fn(
                    soft_targets, student_predictions / self.tau)

            loss = self.augmented_loss_fn(student_loss, teacher_student_loss)

        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"true_loss": student_loss, "augmented_loss": loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student alone on hard labels."""
        x, y = data
        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)
        self.compiled_metrics.update_state(y, y_prediction)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"true_loss": student_loss})
        return results

In [None]:
def resnet18_new(x):
    """Upsample ``x`` 7x (bicubic) and feed it to the third-party ResNet18.

    Returns whatever ``resnet18(..., num_classes=10)`` produces for the
    upsampled tensor. It is used below as the student model's output.
    """
    upsampled = layers.UpSampling2D(size=(7, 7), interpolation='bicubic')(x)
    return resnet18(upsampled, num_classes=10)

# Build the ResNet18 student on 32x32x3 CIFAR inputs.
inputs = layers.Input((32, 32, 3))
outputs = resnet18_new(inputs)
student_model = keras.Model(inputs, outputs)
# clone_model creates new layers (and thus new weights), giving an
# independent copy for the no-teacher baseline.
student_scratch = keras.models.clone_model(student_model)

# Trying different values for $\tau$ and $\alpha$

In [None]:
# Grid search over the distillation temperature tau and the loss weight alpha.
# Each configuration gets a fresh clone of the student (new weights) and one
# epoch of training.
# The History objects are kept in `sweep_histories`, keyed by (tau, alpha),
# so configurations can be compared programmatically.
# NOTE(review): validation_data is the CIFAR-10 test split, so choosing
# hyper-parameters from it leaks test information into model selection.
taus = [1, 3, 5]
alphas = [0.2, 0.5, 0.8]
sweep_histories = {}
for tau in taus:
    for alpha in alphas:
        student_temp = keras.models.clone_model(student_model)
        teacher_student = TeacherStudent(student_temp, teacher_model)
        teacher_student.compile(alpha=alpha, tau=tau)
        sweep_histories[(tau, alpha)] = teacher_student.fit(
            train_X,
            training_labels,
            epochs=1,
            validation_data=(valid_X, validation_labels),
            batch_size=BATCH_SIZE,
        )



In [None]:
# Final distillation setup with the chosen hyper-parameters (tau=1, alpha=0.8),
# starting from a fresh clone of the student.
student_final = keras.models.clone_model(student_model)
teacher_student = TeacherStudent(student_final, teacher_model)
teacher_student.compile(tau=1, alpha=0.8)

***The following results are obtained after 13 epochs***

In [None]:
# Train the distilled student for 5 epochs, evaluating after each epoch.
teacher_student.fit(
    train_X,
    training_labels,
    batch_size=BATCH_SIZE,
    epochs=5,
    validation_data=(valid_X, validation_labels),
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fe3bcde93a0>

# Training ResNet18 without a teacher

In [None]:
# Baseline: the scratch ResNet18 is trained on hard labels only, with the same
# optimizer / loss / metric setup as the ResNet50 model.
student_scratch.compile(
    optimizer='SGD',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

***The following results are obtained after 13 epochs***

In [None]:
# Train the no-teacher baseline for EPOCHS + 1 epochs.
history2 = student_scratch.fit(
    x=train_X,
    y=training_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS + 1,
    validation_data=(valid_X, validation_labels),
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


***We can see that normal training (without a teacher) gives better results. However, the results seem noisy, since in some epochs the validation accuracy changes unusually. Alternatively, the better results may be due to the fact that in the second method we optimize the true loss directly.***

# Fine-tuning

In [None]:
# Full fine-tuning: a fresh ResNet50 model whose backbone is left trainable.
res50_model = resnet50_model()
history = res50_model.fit(
    x=train_X,
    y=training_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(valid_X, validation_labels),
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


***We can see that the outputs don't differ much. This is probably because the feature maps learned for ImageNet classification are also good for classifying CIFAR-10.***