In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Max
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.datasets import mnist
from tqdm.notebook import tqdm  # Import the tqdm library for progress bars
from sklearn.model_selection import train_test_split

In [2]:
# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # Normalize the data

# Reshape for the models
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Define the teacher model (Small Scale Model)
teacher_model = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(256, activation='relu'),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# Compile and train the teacher model
teacher_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
teacher_model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

# Evaluate the teacher model
teacher_accuracy = teacher_model.evaluate(x_test, y_test)[1]
print(f"Teacher Model Accuracy: {teacher_accuracy:.2f}")

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Teacher Model Accuracy: 0.98


In [3]:
# Define the student model (Nano Scale Model)
student_model = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(32, activation='relu'),
    Dense(16, activation='relu'),
    Dense(10, activation='softmax')
])

# Knowledge Distillation Loss Function
def distillation_loss(y_true, y_pred, teacher_logits, temperature=5):
    # Cross-entropy with soft predictions from the teacher
    soft_labels = tf.nn.softmax(teacher_logits / temperature)
    student_logits = tf.nn.softmax(y_pred / temperature)
    soft_loss = CategoricalCrossentropy()(soft_labels, student_logits)
    
    # Cross-entropy with ground-truth labels
    hard_loss = CategoricalCrossentropy()(y_true, y_pred)
    
    # Combine the two losses
    return 0.5 * soft_loss + 0.5 * hard_loss

# Custom Training Loop for Knowledge Distillation
optimizer = tf.keras.optimizers.Adam()
temperature = 5  # Temperature for softening probabilities
epochs = 5
batch_size = 32


# Split the dataset into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)

# Prepare the training and validation datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)


# Training loop with progress bar
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    # Training phase
    train_loss = 0.0
    train_steps = 0
    train_accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    
    # Wrap training dataset with tqdm for progress bar
    train_dataset_tqdm = tqdm(train_dataset, desc="Training Progress", unit="batch")
    for x_batch, y_batch in train_dataset_tqdm:
        with tf.GradientTape() as tape:
            # Teacher predictions
            teacher_logits = teacher_model(x_batch, training=False)
            
            # Student predictions
            y_pred = student_model(x_batch, training=True)
            
            # Compute the distillation loss
            loss = distillation_loss(y_batch, y_pred, teacher_logits, temperature)
        
        # Backpropagation
        gradients = tape.gradient(loss, student_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))
        
        # Track training loss and accuracy
        train_loss += loss.numpy()
        train_accuracy_metric.update_state(y_batch, y_pred)
        train_steps += 1
        
        # Update tqdm description
        train_dataset_tqdm.set_postfix(loss=loss.numpy())

    avg_train_loss = train_loss / train_steps
    train_accuracy = train_accuracy_metric.result().numpy()
    print(f"  Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
    
    # Validation phase
    val_loss = 0.0
    val_steps = 0
    val_accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    
    # Wrap validation dataset with tqdm for progress bar
    val_dataset_tqdm = tqdm(val_dataset, desc="Validation Progress", unit="batch")
    for x_batch_val, y_batch_val in val_dataset_tqdm:
        # Predictions on validation data
        val_logits = student_model(x_batch_val, training=False)
        teacher_logits_val = teacher_model(x_batch_val, training=False)
        
        # Compute validation loss
        loss = distillation_loss(y_batch_val, val_logits, teacher_logits_val, temperature)
        
        # Track validation loss and accuracy
        val_loss += loss.numpy()
        val_accuracy_metric.update_state(y_batch_val, val_logits)
        val_steps += 1
        
        # Update tqdm description
        val_dataset_tqdm.set_postfix(val_loss=loss.numpy())

    avg_val_loss = val_loss / val_steps
    val_accuracy = val_accuracy_metric.result().numpy()
    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")


# Compile the student model
student_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

# Evaluate the model
student_accuracy = student_model.evaluate(x_test, y_test)[1]
print(f"Student Model Accuracy: {student_accuracy:.2f}")


Epoch 1/5


Training Progress:   0%|          | 0/1500 [00:00<?, ?batch/s]

  Training Loss: 1.3733, Training Accuracy: 0.8683


Validation Progress:   0%|          | 0/375 [00:00<?, ?batch/s]

Validation Loss: 1.2728, Validation Accuracy: 0.9300

Epoch 2/5


Training Progress:   0%|          | 0/1500 [00:00<?, ?batch/s]

  Training Loss: 1.2537, Training Accuracy: 0.9396


Validation Progress:   0%|          | 0/375 [00:00<?, ?batch/s]

Validation Loss: 1.2440, Validation Accuracy: 0.9478

Epoch 3/5


Training Progress:   0%|          | 0/1500 [00:00<?, ?batch/s]

  Training Loss: 1.2305, Training Accuracy: 0.9525


Validation Progress:   0%|          | 0/375 [00:00<?, ?batch/s]

Validation Loss: 1.2311, Validation Accuracy: 0.9549

Epoch 4/5


Training Progress:   0%|          | 0/1500 [00:00<?, ?batch/s]

  Training Loss: 1.2174, Training Accuracy: 0.9611


Validation Progress:   0%|          | 0/375 [00:00<?, ?batch/s]

Validation Loss: 1.2233, Validation Accuracy: 0.9586

Epoch 5/5


Training Progress:   0%|          | 0/1500 [00:00<?, ?batch/s]

  Training Loss: 1.2088, Training Accuracy: 0.9658


Validation Progress:   0%|          | 0/375 [00:00<?, ?batch/s]

Validation Loss: 1.2193, Validation Accuracy: 0.9615
Student Model Accuracy: 0.96
