In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization,RandomFlip, RandomRotation, RandomZoom
from tensorflow.keras.optimizers import AdamW
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import numpy as np
import time
import random
import matplotlib.pyplot as plt
import os
import pandas as pd
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix

In [None]:
# Mount Google Drive (Colab only) so models under MyDrive/saved_models can be
# loaded and written by later cells.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Fetch CIFAR-10 as numpy arrays: (images, labels) for the train and test splits
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
for split_name, split_images in (("Train", x_train), ("Test", x_test)):
    print(f"Original {split_name} set: {split_images.shape}")

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step
Original Train set: (50000, 32, 32, 3)
Original Test set: (10000, 32, 32, 3)


In [None]:
# Hold out the last DEV_SIZE training images as a validation ("dev") split.
# The split point is derived from the data instead of hard-coding 45000, so it
# stays correct if the training set size changes.
# NOTE: re-running this cell re-splits the already-reduced x_train; re-run the
# load cell first.
DEV_SIZE = 5000
split_at = len(x_train) - DEV_SIZE
assert split_at > 0, "DEV_SIZE must be smaller than the training set"

x_train, x_dev = x_train[:split_at], x_train[split_at:]
y_train, y_dev = y_train[:split_at], y_train[split_at:]

print(f"Train set: {x_train.shape}")
print(f"Dev set: {x_dev.shape}")
print(f"Test set: {x_test.shape}")

Train set: (45000, 32, 32, 3)
Dev set: (5000, 32, 32, 3)
Test set: (10000, 32, 32, 3)


In [None]:
# Batch size shared by every tf.data pipeline, plus tf.data's sentinel that lets
# the runtime choose parallelism / prefetch depth.
BATCH_SIZE, AUTOTUNE = 64, tf.data.AUTOTUNE

# Per-example preprocessing used by the tf.data pipelines below
def preprocess(image, label):
    """Resize one image to 224x224 and apply ResNet50 ``preprocess_input``.

    Intended for ``tf.data.Dataset.map``; the label is passed through unchanged.

    Returns:
        ``(image, label)`` where the image is a float tensor of shape (224, 224, 3).
    """
    target_size = (224, 224)
    return preprocess_input(tf.image.resize(image, target_size)), label

# Build the train/dev/test tf.data pipelines.
def _make_pipeline(images, labels, shuffle=False):
    """Return a batched, prefetched dataset of ``preprocess``-ed (image, label) pairs.

    Args:
        images: Image array, one example per leading index.
        labels: Label array aligned with ``images``.
        shuffle: If True, shuffle examples before preprocessing. Shuffling the
            raw (pre-resize) arrays keeps the shuffle buffer small.
    """
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    return (
        dataset
        .map(preprocess, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

# Only the training split is shuffled: without it every epoch would see the
# batches in exactly the same order. Dev/test order does not affect metrics.
train_dataset = _make_pipeline(x_train, y_train, shuffle=True)
dev_dataset = _make_pipeline(x_dev, y_dev)
test_dataset = _make_pipeline(x_test, y_test)

# Print dataset structure
print(train_dataset)
print(dev_dataset)
print(test_dataset)

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))>
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))>
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))>


In [None]:
# Trained ResNet50 teacher, previously saved to Google Drive
teacher_model_path = "/content/drive/MyDrive/saved_models/resnet50_cifar10.h5"
teacher_model = load_model(teacher_model_path)

# Score the teacher on the test split. The unpacking assumes the saved model was
# compiled with exactly one metric (accuracy), so evaluate() yields [loss, acc].
teacher_scores = teacher_model.evaluate(test_dataset)
test_loss, test_acc = teacher_scores
print(f"Teacher model test accuracy: {test_acc:.4f}")



[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 170ms/step - accuracy: 0.9530 - loss: 0.2024
Teacher model test accuracy: 0.9545


In [None]:
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

# Load MobileNetV1 with ImageNet weights, without its classification head
base_model = MobileNet(input_shape=(224, 224, 3), include_top=False, weights='imagenet')

# Number of trailing MobileNet layers left trainable during distillation
N_UNFROZEN_LAYERS = 5

# Freeze the whole backbone, then unfreeze only the last N_UNFROZEN_LAYERS layers.
# Compute the start index explicitly: layers[-0:] would select *every* layer.
base_model.trainable = False
first_unfrozen = max(len(base_model.layers) - N_UNFROZEN_LAYERS, 0)
for layer in base_model.layers[first_unfrozen:]:
    layer.trainable = True

# CIFAR-10 classification head (10 classes). The softmax output means the
# student emits probabilities, not logits.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax')(x)

# Create and compile the student model
student_model = Model(inputs=base_model.input, outputs=output)
student_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# summary() prints the table itself and returns None; wrapping it in print()
# only adds a stray "None" line to the output.
student_model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_1_0_224_tf_no_top.h5
[1m17225924/17225924[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


None


In [None]:
import os
import keras
from keras import layers
from keras import ops
import numpy as np

class Distiller(keras.Model):
    """Knowledge-distillation wrapper that trains ``student`` against ``teacher``.

    The forward pass (``call``) runs only the student. The teacher is queried
    inside ``compute_loss`` to produce soft targets.

    Args:
        student: Keras model being trained.
        teacher: Pre-trained Keras model providing soft targets. Its
            ``trainable`` flag is set to ``False`` here (this mutates the model
            passed in) so the optimizer never updates it.
    """

    def __init__(self, student, teacher):
        super().__init__()
        # The teacher is tracked as a sub-layer. Without freezing it, its weights
        # would appear in ``self.trainable_weights`` and the default train step
        # would apply distillation-loss gradients to the teacher as well.
        teacher.trainable = False
        self.teacher = teacher
        self.student = student
        # Default output convention; overwritten by compile().
        self.from_logits = False

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
        from_logits=False,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights.
            metrics: Keras metrics for evaluation.
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth.
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions.
            alpha: Weight on ``student_loss_fn``; ``1 - alpha`` weights
                ``distillation_loss_fn``.
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
            from_logits: Whether teacher and student outputs are raw logits.
                Defaults to False, i.e. outputs are probabilities (e.g. models
                ending in softmax). Both models are assumed to share the same
                convention.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        self.from_logits = from_logits

    def _soften(self, preds):
        """Return temperature-softened class probabilities for model outputs ``preds``."""
        if self.from_logits:
            logits = preds
        else:
            # Temperature must scale logits, not probabilities. log(softmax(z))
            # equals z minus a per-row constant, and softmax is shift-invariant,
            # so log-probabilities are an exact stand-in for logits. Clip to
            # avoid log(0).
            logits = ops.log(ops.clip(preds, keras.config.epsilon(), 1.0))
        return ops.softmax(logits / self.temperature, axis=-1)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        """Blend the hard-label student loss with the temperature-scaled distillation loss."""
        # Soft targets are constants from the student's point of view.
        teacher_pred = ops.stop_gradient(self.teacher(x, training=False))
        student_loss = self.student_loss_fn(y, y_pred)

        # T^2 keeps the soft-target gradient magnitude comparable across temperatures.
        distillation_loss = self.distillation_loss_fn(
            self._soften(teacher_pred),
            self._soften(y_pred),
        ) * (self.temperature**2)

        return self.alpha * student_loss + (1 - self.alpha) * distillation_loss

    def call(self, x):
        """Inference runs the student only."""
        return self.student(x)


In [None]:
class AlphaScheduler(keras.callbacks.Callback):
    """Linearly anneal ``distiller.alpha`` from ``start_alpha`` to ``end_alpha``.

    ``alpha`` is the weight on the hard-label student loss in the distiller's
    loss blend. The schedule reaches ``end_alpha`` exactly on the last of
    ``total_epochs`` epochs and stays there if training runs longer.

    NOTE(review): the distiller stores ``alpha`` as a plain Python float. If
    ``fit()`` runs a traced/compiled train step, the value may be captured at
    trace time, so updates made here might not take effect. Confirm (e.g. with
    ``run_eagerly=True``) before relying on this schedule.
    """

    def __init__(self, distiller, start_alpha=0.1, end_alpha=0.9, total_epochs=20):
        super().__init__()
        self.distiller = distiller
        self.start_alpha = start_alpha
        self.end_alpha = end_alpha
        self.total_epochs = total_epochs

    def _alpha_for_epoch(self, epoch):
        """Alpha for 0-based ``epoch``, clamped to [start_alpha, end_alpha]."""
        # Divide by (total_epochs - 1) so the final epoch hits end_alpha; guard
        # total_epochs <= 1 against division by zero.
        span = max(self.total_epochs - 1, 1)
        progress = min(max(epoch / span, 0.0), 1.0)
        return self.start_alpha + (self.end_alpha - self.start_alpha) * progress

    def on_epoch_begin(self, epoch, logs=None):
        self.distiller.alpha = self._alpha_for_epoch(epoch)
        print(f"\nEpoch {epoch + 1}: Updated alpha to {self.distiller.alpha:.4f}")

DISTILL_EPOCHS = 5

# Pair the trained teacher with the student and configure the distillation objective
distiller = Distiller(student=student_model, teacher=teacher_model)
distill_config = {
    "optimizer": keras.optimizers.Adam(learning_rate=1e-3),
    "metrics": [keras.metrics.SparseCategoricalAccuracy()],
    # student_model ends in softmax, so the hard-label loss receives probabilities
    "student_loss_fn": keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    "distillation_loss_fn": keras.losses.KLDivergence(),
    # weight on the hard-label loss; (1 - alpha) weights the distillation loss
    "alpha": 0.1,
    "temperature": 10,
}
distiller.compile(**distill_config)

# Train the student against the teacher, then score it on the held-out test split
distiller.fit(train_dataset, validation_data=dev_dataset, epochs=DISTILL_EPOCHS)
distiller.evaluate(test_dataset)

Epoch 1/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m312s[0m 387ms/step - loss: 0.2581 - sparse_categorical_accuracy: 0.2577 - val_loss: 0.1840 - val_sparse_categorical_accuracy: 0.5434
Epoch 2/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 327ms/step - loss: 0.1721 - sparse_categorical_accuracy: 0.5580 - val_loss: 0.1629 - val_sparse_categorical_accuracy: 0.6220
Epoch 3/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m236s[0m 335ms/step - loss: 0.1611 - sparse_categorical_accuracy: 0.6099 - val_loss: 0.1540 - val_sparse_categorical_accuracy: 0.6500
Epoch 4/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 335ms/step - loss: 0.1544 - sparse_categorical_accuracy: 0.6386 - val_loss: 0.1536 - val_sparse_categorical_accuracy: 0.6382
Epoch 5/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m256s[0m 326ms/step - loss: 0.1506 - sparse_categorical_accuracy: 0.6561 - val_loss: 0.1551 - val_sparse_categorical_ac

[0.15838569402694702, 0.6432999968528748]

In [None]:
# Random train-time transforms applied after resizing:
#   - horizontal flips
#   - rotations of up to 10% of a full turn in either direction
#   - zoom by up to 10%
_augmentation_layers = [
    RandomFlip("horizontal"),
    RandomRotation(0.1),
    RandomZoom(0.1),
]
data_augmentation = tf.keras.Sequential(_augmentation_layers)

# Function to resize, optionally augment, and preprocess images
def preprocess(image, label, augment=False):
    """Resize an image to 224x224, optionally augment it, then apply ResNet50 preprocessing.

    This redefines the earlier ``preprocess``. With ``augment=False`` it performs
    the same resize + ``preprocess_input`` steps, so the existing dev/test
    pipelines and this training pipeline see consistently preprocessed images.

    Args:
        image: A single image tensor (called per element from ``tf.data.map``).
        label: Label passed through unchanged.
        augment: If True, run ``data_augmentation`` on the resized image.

    Returns:
        ``(image, label)`` tuple.
    """
    image = tf.image.resize(image, (224, 224))

    if augment:
        image = data_augmentation(image)

    # This is the ResNet50 ``preprocess_input`` (imported from
    # tensorflow.keras.applications.resnet50), not MobileNet's. It keeps training
    # inputs consistent with the dev/test pipelines and with the teacher.
    # NOTE(review): MobileNet's ImageNet weights expect
    # ``mobilenet.preprocess_input`` scaling; confirm this mismatch is intended.
    # Switching would require rebuilding the dev/test pipelines too.
    image = preprocess_input(image)
    return image, label


def _preprocess_with_augmentation(image, label):
    """``preprocess`` with train-time augmentation switched on."""
    return preprocess(image, label, augment=True)

# Rebuild the training pipeline: shuffle raw images, augment + preprocess, batch, prefetch
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=10000)
train_dataset = train_dataset.map(_preprocess_with_augmentation, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)


# Print dataset structure
print(train_dataset)
print(dev_dataset)
print(test_dataset)

# Make every student layer trainable (including the MobileNet layers frozen
# when the model was built) for end-to-end fine-tuning on hard labels.
for layer in student_model.layers:
    layer.trainable = True

# Same learning rate as the distillation stage (1e-3), not a lower one.
# NOTE(review): a smaller value (e.g. 1e-4) is a common choice for full-network
# fine-tuning of pretrained weights; adjust here if needed.
FINE_TUNE_LR = 1e-3
FINE_TUNE_EPOCHS = 5

# Compile the student model for fine-tuning
student_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=FINE_TUNE_LR),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Fine-tune on the augmented training data; keep the history for later inspection
fine_tune_history = student_model.fit(
    train_dataset, validation_data=dev_dataset, epochs=FINE_TUNE_EPOCHS
)
student_model.evaluate(test_dataset)

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))>
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))>
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))>
Epoch 1/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m587s[0m 792ms/step - loss: 0.7495 - sparse_categorical_accuracy: 0.7431 - val_loss: 0.9087 - val_sparse_categorical_accuracy: 0.7114
Epoch 2/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m551s[0m 782ms/step - loss: 0.3908 - sparse_categorical_accuracy: 0.8672 - val_loss: 0.4992 - val_sparse_categorical_accuracy: 0.8372
Epoch 3/5
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m555s[0m 788ms/step - loss: 0.3122 - sparse_categorica

[0.43759921193122864, 0.8607000112533569]

In [None]:
# Save the fine-tuned student model only; the Distiller wrapper is not saved.
# The .h5 suffix selects the legacy HDF5 format. NOTE(review): Keras 3 recommends
# the native `.keras` format; confirm downstream consumers before switching.
student_model.save('/content/drive/MyDrive/saved_models/MobileNet_student_model.h5')



In [None]:
TFLITE_PATH = '/content/drive/MyDrive/saved_models/student_model.tflite'

# Convert the in-memory Keras student to a TFLite flatbuffer. No optimization or
# quantization flags are set on the converter.
converter = tf.lite.TFLiteConverter.from_keras_model(student_model)
tflite_model = converter.convert()

# Persist the serialized bytes to Drive
with open(TFLITE_PATH, 'wb') as tflite_file:
    tflite_file.write(tflite_model)


Saved artifact at '/tmp/tmpvtedvsom'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_374')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  137691982733072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691982737872: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794443920: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691982732304: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691982732112: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794444304: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794444880: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794445264: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794445072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794443536: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794

In [None]:
# Export an inference-only SavedModel (Protobuf format) exposing a 'serve' endpoint
EXPORT_DIR = '/content/drive/MyDrive/saved_models/student_model_pb'
student_model.export(EXPORT_DIR)


Saved artifact at '/content/drive/MyDrive/saved_models/student_model_pb'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor_374')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  137691982733072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691982737872: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794443920: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691982732304: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691982732112: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794444304: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794444880: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794445264: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794445072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  137691794443536: TensorSpec(shape=(), dtype