<a href="https://colab.research.google.com/github/Steven032/ECE1512-2023F-ProjectRepo-Xiaohu.Yang-Yixin.Feng/blob/main/Task1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project A: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [None]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union
# Set TensorFlow's global random seed.
tf.random.set_seed(1234)


# `tf` is imported from `tensorflow.compat.v2`; explicitly enable v2 behavior.
tf.enable_v2_behavior()

# Dataset builder; used below for metadata (label class count, split sizes).
builder = tfds.builder('mnist')
BATCH_SIZE = 256  # Examples per batch in the train and test pipelines.
NUM_EPOCHS = 12  # Number of passes over the training set per training run.
NUM_CLASSES = 10  # 10 total classes.

# Data loading

In [None]:
# Load train and test splits.
def preprocess(x):
  """Turn one raw dataset example into an (image, one-hot label) pair.

  Args:
    x: Feature dict with an 'image' entry and a 'label' entry.

  Returns:
    Tuple `(image, subclass_labels)`: the image converted to tf.float32 with
    `tf.image.convert_image_dtype`, and the label one-hot encoded with as many
    slots as the dataset builder reports label classes.
  """
  num_label_classes = builder.info.features['label'].num_classes
  one_hot_label = tf.one_hot(x['label'], num_label_classes)
  float_image = tf.image.convert_image_dtype(x['image'], tf.float32)
  return float_image, one_hot_label


# Training pipeline: cache the raw examples, preprocess them, shuffle with a
# buffer as large as the whole split, then form fixed-size batches (the final
# partial batch is dropped).
mnist_train = (
    tfds.load('mnist', split='train', shuffle_files=False)
    .cache()
    .map(preprocess)
    .shuffle(builder.info.splits['train'].num_examples)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# Test pipeline: same preprocessing, no shuffling, final partial batch kept.
mnist_test = (
    tfds.load('mnist', split='test')
    .cache()
    .map(preprocess)
    .batch(BATCH_SIZE)
)

In [None]:
mnist_train

<_BatchDataset element_spec=(TensorSpec(shape=(256, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(256, 10), dtype=tf.float32, name=None))>

# Model creation

In [None]:
from keras.api._v2.keras import activations
from tensorflow.python.ops.gen_nn_ops import conv2d
#@test {"output": "ignore"}

# Build CNN teacher (step 2): two conv / max-pool stages followed by a
# dropout-regularized dense head. The last Dense layer has no activation, so
# the model outputs raw logits.
cnn_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(10),
])


# Build fully connected student: flattened input, two hidden layers of 784
# ReLU units, and a 10-unit logit layer.
fc_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(10),
])






# Teacher loss function

In [None]:
def compute_teacher_loss(images, labels):
  """Compute the teacher's mean cross-entropy loss for a batch.

  The teacher (the module-level `cnn_model`) is trained on the hard labels
  only; no distillation term is involved.

  This function is intentionally NOT wrapped in `tf.function`. It reads the
  global `cnn_model`, and a traced function keeps using whatever object that
  name referred to when the trace was made. Rebuilding `cnn_model` later in
  the notebook would then silently keep training/evaluating the old teacher.
  Running eagerly looks the name up on every call.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor (mean over the batch).
  """
  # training=True so the teacher's dropout layers are active.
  logits = cnn_model(images, training=True)
  per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  # Reduce to a scalar so the return value matches the documented contract
  # and the scale used by the student losses.
  return tf.reduce_mean(per_example_loss)

# Student loss function

In [None]:
#@test {"output": "ignore"}

# Hyperparameters for distillation (need to be tuned).
# Weight of the distillation term in compute_student_loss; the hard-label
# cross-entropy term is weighted by (1 - ALPHA).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
# Softmax temperature that compute_student_loss passes to distillation_loss.
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Temperature-scaled soft-target cross-entropy between teacher and student.

  Teacher logits are divided by `temperature` and passed through a softmax to
  form soft targets. The student logits, divided by the same temperature, are
  scored against those targets with softmax cross-entropy. The batch mean is
  multiplied by `temperature ** 2` so gradient magnitudes stay roughly
  constant as the temperature changes (Hinton et al., "Distilling the
  knowledge in a neural network").

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
  softened_teacher_probs = tf.nn.softmax(teacher_logits / temperature)
  softened_student_logits = student_logits / temperature
  per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=softened_teacher_probs, logits=softened_student_logits)
  return tf.reduce_mean(per_example_loss) * temperature ** 2

def compute_student_loss(images, labels):
  """Compute the student's combined distillation + hard-label loss.

  The student is the module-level `fc_model` and the teacher the module-level
  `cnn_model`. `ALPHA` and `DISTILLATION_TEMPERATURE` are read from module
  scope on every call.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor:
    `ALPHA * distillation_loss + (1 - ALPHA) * cross_entropy`.
  """
  student_subclass_logits = fc_model(images, training=True)

  # Teacher is only queried for targets, so it runs in inference mode.
  teacher_subclass_logits = cnn_model(images, training=False)
  distillation_loss_value = distillation_loss(
      teacher_subclass_logits, student_subclass_logits,
      DISTILLATION_TEMPERATURE)

  # Cross-entropy with hard targets. This must stay a Tensor: converting it
  # with `.numpy()` would turn it into a constant that the caller's
  # GradientTape cannot differentiate, so the hard-label term would contribute
  # no gradient at all.
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return (ALPHA * distillation_loss_value
          + (1 - ALPHA) * cross_entropy_loss_value)

# Train and evaluation

In [None]:
@tf.function
def compute_num_correct(model, images, labels):
  """Count correct predictions of `model` on one batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    A 3-tuple `(num_correct, predicted, expected)`: the number of correctly
    classified images as a float32 scalar, the argmax of the model's logits
    along the last axis, and the argmax of `labels` along the last axis.
  """
  predicted_classes = tf.argmax(model(images, training=False), -1)
  expected_classes = tf.argmax(labels, -1)
  hits = tf.cast(tf.math.equal(predicted_classes, expected_classes), tf.float32)
  return tf.reduce_sum(hits), predicted_classes, expected_classes


def train_and_evaluate(model, compute_loss_fn):
  """Train `model` for NUM_EPOCHS epochs, evaluating on the test set each epoch.

  Args:
    model: Instance of tf.keras.Model whose trainable variables are updated.
    compute_loss_fn: Callable `(images, labels) -> loss` used as the training
      objective. Gradients are taken with respect to
      `model.trainable_variables`, so the callable must run `model`.

  Returns:
    List with one test accuracy (in percent) per epoch.
  """
  adam = tf.optimizers.Adam(learning_rate=0.001)
  test_set_size = builder.info.splits['test'].num_examples
  accuracies = []

  for epoch in range(1, NUM_EPOCHS + 1):
    # Training pass over every batch of the training pipeline.
    print('Epoch {}: '.format(epoch), end='')
    for batch_images, batch_labels in mnist_train:
      with tf.GradientTape() as tape:
        batch_loss = compute_loss_fn(batch_images, batch_labels)
      variables = model.trainable_variables
      adam.apply_gradients(zip(tape.gradient(batch_loss, variables), variables))

    # Evaluation pass: accumulate correct predictions over the test batches.
    correct_so_far = 0
    for batch_images, batch_labels in mnist_test:
      batch_correct, _, _ = compute_num_correct(model, batch_images, batch_labels)
      correct_so_far += batch_correct
    accuracy_percent = correct_so_far / test_set_size * 100
    print("Class_accuracy: " + '{:.2f}%'.format(accuracy_percent))
    accuracies.append(accuracy_percent.numpy())
  return accuracies


# Training models

In [None]:
# your code start from here for step 5

# Train the CNN teacher with its own loss; the call returns the list of
# per-epoch test accuracies (not stored here).
train_and_evaluate(cnn_model, compute_teacher_loss)

Epoch 1: 



Class_accuracy: 98.08%
Epoch 2: Class_accuracy: 98.54%
Epoch 3: 

KeyboardInterrupt: ignored

In [None]:
# Train the student with the combined distillation / hard-label loss, using
# the current module-level ALPHA and DISTILLATION_TEMPERATURE.
train_and_evaluate(fc_model, compute_student_loss)

# Test accuracy vs. temperature curve

In [None]:
# your code start from here for step 6
def _fresh_student():
  """Return a newly initialized student with the same layers as fc_model."""
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(784, activation='relu'),
      tf.keras.layers.Dense(784, activation='relu'),
      tf.keras.layers.Dense(10),
  ])


T = [1,2,4,16,32,64]
acc_lst = []
for temperature in T:
  # compute_student_loss reads DISTILLATION_TEMPERATURE and fc_model from
  # module scope on every call, so rebinding them here takes effect.
  DISTILLATION_TEMPERATURE = temperature
  # Start every temperature from a fresh student. Reusing one model would let
  # later temperatures inherit the training done at earlier ones, which makes
  # the accuracy-vs-temperature comparison meaningless.
  fc_model = _fresh_student()
  acc = train_and_evaluate(fc_model, compute_student_loss)
  acc_lst.append(acc)



In [None]:
import matplotlib.pyplot as plt

# One curve per temperature: test accuracy after each epoch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title('Student Test Accuracy vs Temperature')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')

for i, t in enumerate(T):
    ax.plot(acc_lst[i], label=f'Temperature={t}', linestyle='-', marker='o')

ax.legend()
plt.show()


In [None]:
import numpy as np
# Collapse each temperature's per-epoch accuracies into a single mean value.
acc_lst_np = np.array(acc_lst)
avg_acc = np.mean(acc_lst_np,axis = 1)
T = [1,2,4,16,32,64]
plt.figure(figsize=(10,6))
plt.title('Student Test Accuracy vs Temperature')
plt.xlabel('Temperature')
plt.ylabel('Accuracy')
# Plot against evenly spaced positions and relabel the ticks with the actual
# temperatures, so the unevenly spaced values do not bunch up on the axis.
x = list(range(len(T)))

# The label gives plt.legend() an artist to show; without one it only emits a
# "no artists with labels" warning and draws an empty legend.
plt.plot(x, avg_acc, linestyle='-', marker='o',
         label='Mean test accuracy over epochs')
plt.xticks(x, T)

plt.legend()
plt.show()


In [None]:
# your code start from here for step 6
def _fresh_student():
  """Return a newly initialized student with the same layers as fc_model."""
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(784, activation='relu'),
      tf.keras.layers.Dense(784, activation='relu'),
      tf.keras.layers.Dense(10),
  ])


# This sweep varies ALPHA (the distillation weight) at a fixed temperature.
# `T` is still assigned because the original cell stored the swept values
# under that name.
T = [0.3,0.4,0.5,0.6,0.7]
acc_lst = []
DISTILLATION_TEMPERATURE = 64
for alpha in T:
  # compute_student_loss reads ALPHA and fc_model from module scope per call.
  ALPHA = alpha
  # Fresh student per setting so results for one ALPHA are not contaminated
  # by training performed under the previous ones.
  fc_model = _fresh_student()
  acc = train_and_evaluate(fc_model, compute_student_loss)
  acc_lst.append(acc)

# Train student from scratch

In [None]:
# Build the baseline student (step 7): same layer stack as the distilled
# student, to be trained on hard labels only.
fc_model_no_distillation = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(10),
])



#@test {"output": "ignore"}

def compute_plain_cross_entropy_loss(images, labels):
  """Compute the mean hard-label cross-entropy loss of the baseline student.

  Runs the module-level `fc_model_no_distillation` in training mode and scores
  its logits against the one-hot labels. No teacher and no distillation term
  are involved.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor (mean over the batch).
  """
  student_subclass_logits = fc_model_no_distillation(images, training=True)
  per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=student_subclass_logits)
  # Reduce to a scalar so the value matches the documented return contract and
  # is on the same scale as the distilled student's loss.
  return tf.reduce_mean(per_example_loss)

# NOTE(review): compute_plain_cross_entropy_loss does not read
# DISTILLATION_TEMPERATURE; this presumably just resets the global after the
# sweeps above changed it.
DISTILLATION_TEMPERATURE = 4
# Train the baseline student on hard labels only, for comparison with the
# distilled student.
train_and_evaluate(fc_model_no_distillation, compute_plain_cross_entropy_loss)

# Comparing the teacher and student models (number of parameters and FLOPs)

In [None]:
# your code start from here for step 8
# Print the Keras layer-by-layer summaries of the teacher and the student.
cnn_model.summary()
fc_model.summary()

In [None]:
import tensorflow as tf

def get_flops(model):
    """Estimate the floating-point operation count of a Keras model.

    The model is cloned inside a private TF1-style graph and that graph is
    profiled with the v1 profiler's float-operation options.

    A fresh `tf.Graph` is used for every call. Profiling the process-wide
    default graph instead would make each call also count the ops left behind
    by earlier calls. The session is opened as a context manager so it is
    closed when the function returns.

    Args:
        model: Instance of tf.keras.Model to profile.

    Returns:
        Total float operations reported by the profiler for the cloned model's
        graph.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.compat.v1.Session(graph=graph):
        # Cloning re-creates the model's layers (and their ops) in `graph`.
        tf.keras.models.clone_model(model)
        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

        flops = tf.compat.v1.profiler.profile(graph=graph,
                                              run_meta=run_meta, cmd='op', options=opts)
        return flops.total_float_ops

In [None]:
# Both lines print the same "Total FLOPs" prefix: the first is the teacher
# (cnn_model), the second the student (fc_model).
print(f"Total FLOPs: {get_flops(cnn_model)}")
print(f"Total FLOPs: {get_flops(fc_model)}")


Source for the FLOPs measurement approach: https://stackoverflow.com/questions/45085938/tensorflow-is-there-a-way-to-measure-flops-for-a-model

# Implementing the state-of-the-art KD algorithm paper 9 - Yixin Feng

In [None]:
#@test {"output": "ignore"}
from tensorflow.keras.losses import KLDivergence

# Weight of the distillation term in compute_student_loss_new; the hard-label
# cross-entropy term is weighted by (1 - ALPHA).
ALPHA = 0.6 # task balance between cross-entropy and distillation loss
# Softmax temperature that compute_student_loss_new passes to
# distillation_loss_new.
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter


def distillation_loss_new(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Temperature-scaled KL-divergence distillation loss.

  Both logit tensors are divided by `temperature` and passed through a softmax
  over the last axis. The Keras `KLDivergence` loss is then evaluated with the
  teacher distribution as the target and the student distribution as the
  prediction, and the result is multiplied by `temperature ** 2`.

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
  teacher_probs = tf.nn.softmax(teacher_logits / temperature, axis=-1)
  student_probs = tf.nn.softmax(student_logits / temperature, axis=-1)
  kl_divergence = KLDivergence()
  divergence = kl_divergence(teacher_probs, student_probs)
  return tf.reduce_mean(divergence) * (temperature ** 2)

def compute_student_loss_new(images, labels,teacher_model,student_model):
  """Compute the combined distillation + hard-label loss for a student.

  `ALPHA` and `DISTILLATION_TEMPERATURE` are read from module scope on every
  call.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.
    teacher_model: Model providing the soft targets; run in inference mode.
    student_model: Model being trained; run in training mode.

  Returns:
    Scalar loss Tensor:
    `ALPHA * distillation_loss + (1 - ALPHA) * cross_entropy`.
  """
  student_subclass_logits = student_model(images, training=True)

  # Distillation term between student logits and softened teacher targets.
  teacher_subclass_logits = teacher_model(images, training=False)
  distillation_loss_value = distillation_loss_new(
      teacher_subclass_logits, student_subclass_logits,
      DISTILLATION_TEMPERATURE)

  # Cross-entropy with hard targets. This must stay a Tensor: converting it
  # with `.numpy()` would turn it into a constant that the caller's
  # GradientTape cannot differentiate, so the hard-label term would contribute
  # no gradient at all.
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return (ALPHA * distillation_loss_value
          + (1 - ALPHA) * cross_entropy_loss_value)

In [None]:
from keras.api._v2.keras import activations
from tensorflow.python.ops.gen_nn_ops import conv2d
#@test {"output": "ignore"}

# Build CNN teacher (step 2): two conv / max-pool stages followed by a
# dropout-regularized dense head that outputs raw logits.
cnn_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(10),
])

# Build Teacher Assistant (TA) model: a single conv / max-pool stage (32
# filters) followed by one 784-unit hidden layer and a 10-unit logit layer.
ta_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=784, activation='relu'),
    tf.keras.layers.Dense(10),
])


# Build fully connected student: two 784-unit hidden layers and a logit layer.
fc_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(10),
])






In [None]:
@tf.function
def compute_num_correct(model, images, labels):
  """Count correct predictions of `model` on one batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    A 3-tuple `(num_correct, predicted, expected)`: the number of correctly
    classified images as a float32 scalar, the argmax of the model's logits
    along the last axis, and the argmax of `labels` along the last axis.
  """
  predicted_classes = tf.argmax(model(images, training=False), -1)
  expected_classes = tf.argmax(labels, -1)
  hits = tf.cast(tf.math.equal(predicted_classes, expected_classes), tf.float32)
  return tf.reduce_sum(hits), predicted_classes, expected_classes


def train_and_evaluate_new(model, compute_loss_fn,teacher_model):
  """Train `model` against a teacher, evaluating on the test set each epoch.

  Args:
    model: Instance of tf.keras.Model whose trainable variables are updated.
    compute_loss_fn: Callable invoked as
      `compute_loss_fn(images, labels, teacher_model, model)` that returns
      the training loss.
    teacher_model: Model forwarded to `compute_loss_fn` as the teacher.

  Returns:
    List with one test accuracy (in percent) per epoch.
  """
  adam = tf.optimizers.Adam(learning_rate=0.001)
  test_set_size = builder.info.splits['test'].num_examples
  accuracies = []

  for epoch in range(1, NUM_EPOCHS + 1):
    # Training pass over every batch of the training pipeline.
    print('Epoch {}: '.format(epoch), end='')
    for batch_images, batch_labels in mnist_train:
      with tf.GradientTape() as tape:
        batch_loss = compute_loss_fn(batch_images, batch_labels, teacher_model, model)
      variables = model.trainable_variables
      adam.apply_gradients(zip(tape.gradient(batch_loss, variables), variables))

    # Evaluation pass: accumulate correct predictions over the test batches.
    correct_so_far = 0
    for batch_images, batch_labels in mnist_test:
      batch_correct, _, _ = compute_num_correct(model, batch_images, batch_labels)
      correct_so_far += batch_correct
    accuracy_percent = correct_so_far / test_set_size * 100
    print("Class_accuracy: " + '{:.2f}%'.format(accuracy_percent))
    accuracies.append(accuracy_percent.numpy())
  return accuracies


In [None]:
# Train the rebuilt teacher on hard labels.
# NOTE(review): compute_teacher_loss reads the global `cnn_model`; if it is
# wrapped in tf.function, confirm it uses the model rebuilt above rather than
# the one that was bound when it was first traced.
train_and_evaluate(cnn_model, compute_teacher_loss)

In [None]:
# First distillation stage: the TA model learns from the CNN teacher.
train_and_evaluate_new(ta_model, compute_student_loss_new,cnn_model)

In [None]:
# Second distillation stage: the student learns from the TA model, which acts
# as the teacher here.
train_and_evaluate_new(fc_model, compute_student_loss_new,ta_model)

# Implementing the state-of-the-art KD algorithm Paper 2- Xiaohu Yang

In [None]:
# Weight of the distillation term in compute_student_loss_new; the hard-label
# cross-entropy term is weighted by (1 - ALPHA).
ALPHA = 0.6 # task balance between cross-entropy and distillation loss
# Softmax temperature that compute_student_loss_new passes to
# distillation_loss_new.
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter


def distillation_loss_new(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Temperature-scaled KL-divergence distillation loss.

  Both logit tensors are divided by `temperature` and passed through a softmax
  over the last axis. The Keras `KLDivergence` loss is then evaluated with the
  teacher distribution as the target and the student distribution as the
  prediction, and the result is multiplied by `temperature ** 2`.

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
  teacher_probs = tf.nn.softmax(teacher_logits / temperature, axis=-1)
  student_probs = tf.nn.softmax(student_logits / temperature, axis=-1)
  kl_divergence = KLDivergence()
  divergence = kl_divergence(teacher_probs, student_probs)
  return tf.reduce_mean(divergence) * (temperature ** 2)

def compute_student_loss_new(images, labels,teacher_model,student_model):
  """Compute the combined distillation + hard-label loss for a student.

  `ALPHA` and `DISTILLATION_TEMPERATURE` are read from module scope on every
  call.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.
    teacher_model: Model providing the soft targets; run in inference mode.
    student_model: Model being trained; run in training mode.

  Returns:
    Scalar loss Tensor:
    `ALPHA * distillation_loss + (1 - ALPHA) * cross_entropy`.
  """
  student_subclass_logits = student_model(images, training=True)

  # Distillation term between student logits and softened teacher targets.
  teacher_subclass_logits = teacher_model(images, training=False)
  distillation_loss_value = distillation_loss_new(
      teacher_subclass_logits, student_subclass_logits,
      DISTILLATION_TEMPERATURE)

  # Cross-entropy with hard targets. This must stay a Tensor: converting it
  # with `.numpy()` would turn it into a constant that the caller's
  # GradientTape cannot differentiate, so the hard-label term would contribute
  # no gradient at all.
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return (ALPHA * distillation_loss_value
          + (1 - ALPHA) * cross_entropy_loss_value)

In [None]:
!pip install --upgrade tensorflow


In [None]:
from keras.api._v2.keras import activations
from tensorflow.python.ops.gen_nn_ops import conv2d


# Build CNN teacher (step 2): two conv / max-pool stages followed by a
# dropout-regularized dense head that outputs raw logits.
cnn_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(10),
])

# Early-stopping callback: watches 'val_loss' with a patience of 10 and
# restores the best weights.
# NOTE(review): nothing visible in this notebook passes `early_stop` to a
# training call; the custom training loops do not take callbacks. Confirm
# whether it is meant to be used.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True)

# Build fully connected student: two 784-unit hidden layers and a logit layer.
fc_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(10),
])


In [None]:
@tf.function
def compute_num_correct(model, images, labels):
  """Count correct predictions of `model` on one batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    A 3-tuple `(num_correct, predicted, expected)`: the number of correctly
    classified images as a float32 scalar, the argmax of the model's logits
    along the last axis, and the argmax of `labels` along the last axis.
  """
  predicted_classes = tf.argmax(model(images, training=False), -1)
  expected_classes = tf.argmax(labels, -1)
  hits = tf.cast(tf.math.equal(predicted_classes, expected_classes), tf.float32)
  return tf.reduce_sum(hits), predicted_classes, expected_classes


def train_and_evaluate_new(model, compute_loss_fn,teacher_model):
  """Train `model` against a teacher, evaluating on the test set each epoch.

  Args:
    model: Instance of tf.keras.Model whose trainable variables are updated.
    compute_loss_fn: Callable invoked as
      `compute_loss_fn(images, labels, teacher_model, model)` that returns
      the training loss.
    teacher_model: Model forwarded to `compute_loss_fn` as the teacher.

  Returns:
    List with one test accuracy (in percent) per epoch.
  """
  adam = tf.optimizers.Adam(learning_rate=0.001)
  test_set_size = builder.info.splits['test'].num_examples
  accuracies = []

  for epoch in range(1, NUM_EPOCHS + 1):
    # Training pass over every batch of the training pipeline.
    print('Epoch {}: '.format(epoch), end='')
    for batch_images, batch_labels in mnist_train:
      with tf.GradientTape() as tape:
        batch_loss = compute_loss_fn(batch_images, batch_labels, teacher_model, model)
      variables = model.trainable_variables
      adam.apply_gradients(zip(tape.gradient(batch_loss, variables), variables))

    # Evaluation pass: accumulate correct predictions over the test batches.
    correct_so_far = 0
    for batch_images, batch_labels in mnist_test:
      batch_correct, _, _ = compute_num_correct(model, batch_images, batch_labels)
      correct_so_far += batch_correct
    accuracy_percent = correct_so_far / test_set_size * 100
    print("Class_accuracy: " + '{:.2f}%'.format(accuracy_percent))
    accuracies.append(accuracy_percent.numpy())
  return accuracies

In [None]:
# Train the rebuilt teacher on hard labels.
# NOTE(review): compute_teacher_loss reads the global `cnn_model`; if it is
# wrapped in tf.function, confirm it uses the model rebuilt above rather than
# the one that was bound when it was first traced.
train_and_evaluate(cnn_model, compute_teacher_loss)

In [None]:
# Distil the teacher into the TA model.
# NOTE(review): this section rebuilds cnn_model and fc_model but not ta_model,
# so this presumably continues training the TA from the previous section;
# confirm that is intended.
train_and_evaluate_new(ta_model, compute_student_loss_new,cnn_model)

In [None]:
# Distil the TA model into the student; the TA acts as the teacher here.
train_and_evaluate_new(fc_model, compute_student_loss_new,ta_model)

# XAI method to explain models

In [None]:
!pip install lime

In [None]:
from lime import lime_image
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
import numpy as np

# ... [Your existing code] ...

# Assume you have the following after training:
# `cnn_model` - Your trained CNN teacher model.
# `fc_model` - Your trained fully connected student model.

# LIME works with models that output probabilities
# Define prediction functions for both models
def cnn_predict(images):
    """Return class probabilities from the CNN teacher for LIME.

    Args:
        images: Batch of images supplied by the LIME explainer.
            # assumes layout (batch, height, width, channels) - TODO confirm

    Returns:
        NumPy array of per-class probabilities, one row per input image.
    """
    # Append a trailing axis and then select index 0 on the original channel
    # axis: this keeps only the first channel while leaving a size-1 channel
    # axis for the conv layers.
    images = tf.expand_dims(images, axis=-1)
    logits = cnn_model.predict(images[:, :, :, 0])
    # The model ends in a Dense layer without activation, so predict() yields
    # logits; LIME works with probabilities, hence the softmax.
    return tf.nn.softmax(logits, axis=-1).numpy()

def fc_predict(images):
    """Return class probabilities from the distilled student for LIME.

    Args:
        images: Batch of images supplied by the LIME explainer.
            # assumes layout (batch, height, width, channels) - TODO confirm

    Returns:
        NumPy array of per-class probabilities, one row per input image.
    """
    # Keep only the first channel; the student flattens its input itself.
    logits = fc_model.predict(images[:, :, :, 0])
    # The model ends in a Dense layer without activation, so predict() yields
    # logits; LIME works with probabilities, hence the softmax.
    return tf.nn.softmax(logits, axis=-1).numpy()

def fc_nondistill_predict(images):
    """Return class probabilities from the non-distilled student for LIME.

    Args:
        images: Batch of images supplied by the LIME explainer.
            # assumes layout (batch, height, width, channels) - TODO confirm

    Returns:
        NumPy array of per-class probabilities, one row per input image.
    """
    # Keep only the first channel; the student flattens its input itself.
    logits = fc_model_no_distillation.predict(images[:, :, :, 0])
    # The model ends in a Dense layer without activation, so predict() yields
    # logits; LIME works with probabilities, hence the softmax.
    return tf.nn.softmax(logits, axis=-1).numpy()


# Create an explainer object
explainer = lime_image.LimeImageExplainer()

# Pull a single batch from the training pipeline and keep its third image as
# the sample to explain; slicing with 2:3 keeps the leading batch axis.
# NOTE(review): mnist_train is shuffled, so the image selected here presumably
# differs between runs.
for images, labels in mnist_train.take(1):
    sample_images = images
    sample_labels = labels
sample_image = sample_images[2:3]
print(sample_image.shape)
# Shape of the 2-D (height, width) slice that is handed to the explainer.
print(sample_image[0, :, :, 0].shape)
# Explain using the cnn_model: LIME queries cnn_predict on 1000 perturbed
# samples of the 2-D image slice and keeps explanations for the top 5 labels.
explanation_cnn = explainer.explain_instance(sample_image[0, :, :, 0],
                                             cnn_predict,
                                             top_labels=5,
                                             hide_color=0,
                                             num_samples=1000)

# Visualize the explanation for the top class: outline the 5 features that
# contribute positively to that label, without hiding the rest of the image.
temp, mask = explanation_cnn.get_image_and_mask(explanation_cnn.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))
plt.title('CNN Model Explanation')
plt.show()

# Explain using the fc_model (distilled student), with the same sample image
# and the same LIME settings as the teacher explanation.
explanation_fc = explainer.explain_instance(sample_image[0, :, :, 0],
                                           fc_predict,
                                           top_labels=5,
                                           hide_color=0,
                                           num_samples=1000)

# Visualize the explanation for the top class: outline the 5 features that
# contribute positively to that label, without hiding the rest of the image.
temp, mask = explanation_fc.get_image_and_mask(explanation_fc.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))
plt.title('FC Model Explanation')
plt.show()


# Explain using the fc_model_nondistill (student trained without
# distillation), with the same sample image and LIME settings as above.
explanation_fc_non_distill = explainer.explain_instance(sample_image[0, :, :, 0],
                                           fc_nondistill_predict,
                                           top_labels=5,
                                           hide_color=0,
                                           num_samples=1000)

# Visualize the explanation for the top class. The image and mask must come
# from this model's own explanation object: taking them from `explanation_fc`
# would show the distilled student's explanation under this title, and can
# fail when the two models' top labels differ.
temp, mask = explanation_fc_non_distill.get_image_and_mask(explanation_fc_non_distill.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))
plt.title('FC Model with No Distillation Explanation')
plt.show()
