# Knowledge Distillation with TensorFlow and HuggingFace

This notebook demonstrates how to implement knowledge distillation using TensorFlow and HuggingFace Transformers.

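Knowledge distillation trains a smaller model (the student) to reproduce the outputs of a larger pre-trained model (the teacher), typically by matching the teacher's logits or the class probabilities derived from them.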
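
To make the idea of learning from teacher outputs concrete, here is a minimal NumPy sketch of how a teacher's logits become "soft targets". The temperature parameter and the example logits are assumptions for illustration; the notebook itself does not set a temperature.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    """Softmax over the last axis, with logits divided by `temperature`."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical teacher logits for one example with two classes
teacher_logits = np.array([4.0, 1.0])

hard = softmax(teacher_logits, temperature=1.0)  # peaked distribution
soft = softmax(teacher_logits, temperature=4.0)  # flatter, "softened" distribution
print(hard, soft)
```

A higher temperature keeps the ranking of the classes but spreads probability mass more evenly, which exposes more of what the teacher believes about the non-winning classes.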
        

## Step 1: Import Required Libraries

In [1]:

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow_datasets as tfds


## Step 2: Load Teacher and Student Models and Tokenizer

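We load the pre-trained teacher model (`bert-large-uncased`) and a smaller student model (`bert-base-uncased`), along with the tokenizer. Both checkpoints use the same vocabulary, so a single tokenizer serves both models. Neither checkpoint ships with a trained classification head, so the teacher would normally be fine-tuned on the target task before its logits are used as training targets.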
        

In [2]:

# Load teacher and student models and tokenizer
teacher_model_name = "bert-large-uncased"
student_model_name = "bert-base-uncased"

teacher_model = TFAutoModelForSequenceClassification.from_pretrained(teacher_model_name)
student_model = TFAutoModelForSequenceClassification.from_pretrained(student_model_name)
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)



All PyTorch model weights were used when initializing TFBertForSequenceClassification.

Some weights or buffers of the TF 2.0 model TFBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



## Step 3: Load and Preprocess Dataset

We load the training split of the GLUE MRPC dataset from TensorFlow Datasets and define a function that tokenizes each sentence pair, padding or truncating it to a fixed length of 128 tokens.
        

In [3]:

# Load dataset
dataset = tfds.load('glue/mrpc', split='train')
batch_size = 128

# Function to tokenize and prepare inputs
def tokenize_function(sentence1, sentence2):
    sentence1 = [s.decode('utf-8') for s in sentence1.numpy()]
    sentence2 = [s.decode('utf-8') for s in sentence2.numpy()]
    inputs = tokenizer(sentence1, sentence2, truncation=True, padding='max_length', max_length=128)
    return inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']


Downloading and preparing dataset 1.43 MiB (download: 1.43 MiB, generated: 1.74 MiB, total: 3.17 MiB) to /root/tensorflow_datasets/glue/mrpc/2.0.0...


Dataset glue downloaded and prepared to /root/tensorflow_datasets/glue/mrpc/2.0.0. Subsequent calls will reuse this data.
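
The tokenizer call in `tokenize_function` returns three aligned sequences per sentence pair. As a rough illustration of what they contain, the sketch below builds the same layout by hand. `encode_pair` is a hypothetical toy helper, not the Hugging Face tokenizer: it skips WordPiece splitting and truncates naively.

```python
def encode_pair(tokens_a, tokens_b, max_length):
    """Toy sentence-pair encoder mirroring BERT's input layout (illustrative only)."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and the first [SEP]; segment 1 covers the rest
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    tokens = tokens[:max_length]
    token_type_ids = token_type_ids[:max_length]
    attention_mask = [1] * len(tokens)  # 1 marks real tokens
    pad = max_length - len(tokens)
    return {
        "tokens": tokens + ["[PAD]"] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,  # 0 marks padding
    }

encoded = encode_pair(["the", "cat", "sat"], ["a", "cat", "sat"], max_length=10)
for name, values in encoded.items():
    print(name, values)
```

The `token_type_ids` are what tell BERT where the first sentence ends and the second begins, which matters for a paraphrase task such as MRPC.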


## Step 4: Tokenize the Dataset

We batch the dataset and apply the tokenization function.
        

In [4]:

# Prepare the batched dataset
def tokenize_batch(batch):
    sentence1 = batch['sentence1']
    sentence2 = batch['sentence2']

    input_ids, attention_mask, token_type_ids = tf.py_function(
        func=tokenize_function,
        inp=[sentence1, sentence2],
        Tout=[tf.int32, tf.int32, tf.int32]
    )

    input_ids.set_shape([batch_size, 128])
    attention_mask.set_shape([batch_size, 128])
    token_type_ids.set_shape([batch_size, 128])

    return input_ids, attention_mask, token_type_ids

def format_tokenized_output(input_ids, attention_mask, token_type_ids):
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids
    }

# Tokenize each batch, then pack the three tensors into a dict keyed by model input name
tokenized_datasets = (
    dataset.batch(batch_size, drop_remainder=True)
    .map(tokenize_batch)
    .map(format_tokenized_output)
)
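
Because `tokenize_batch` fixes the batch dimension with `set_shape([batch_size, 128])`, every batch has to be full, which is why `drop_remainder=True` is used. The arithmetic below sketches the cost of that choice, assuming the MRPC training split has 3,668 examples.

```python
num_examples = 3668  # size of the MRPC training split (assumed here)
batch_size = 128

# Number of full batches and the leftover examples that drop_remainder discards
full_batches, dropped = divmod(num_examples, batch_size)
kept = full_batches * batch_size
print(f"{full_batches} batches, {kept} examples kept, {dropped} dropped")
```

If losing the final partial batch matters, one option is to leave the batch dimension unspecified (`None`) in `set_shape` instead of hard-coding `batch_size`.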


## Step 5: Extract Teacher Logits

We pass the tokenized dataset through the teacher model to obtain the logits.
        

In [5]:

# Function to get teacher logits
def get_teacher_logits_tf(model, tokenized_dataset):
    logits = []
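    for batch in tokenized_dataset:
        # Keep only the keys the model accepts
        inputs = {k: v for k, v in batch.items() if k in tokenizer.model_input_names}
        # Run the teacher in inference mode (dropout disabled)
        outputs = model(inputs, training=False)
        logits.append(outputs.logits)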
    return tf.concat(logits, axis=0)

# Get teacher logits
teacher_logits = get_teacher_logits_tf(teacher_model, tokenized_datasets)
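
The pattern in `get_teacher_logits_tf` is to run the model batch by batch and concatenate the results along the batch axis. The NumPy sketch below illustrates that shape bookkeeping with a hypothetical stand-in for the teacher (`fake_teacher` and its random weights are assumptions, not part of the notebook).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, num_labels = 8, 2
weights = rng.normal(size=(hidden_size, num_labels))  # stand-in for a frozen teacher

def fake_teacher(batch):
    """Hypothetical stand-in for teacher_model(inputs).logits."""
    return batch @ weights

# Three batches of four examples each
batches = [rng.normal(size=(4, hidden_size)) for _ in range(3)]

# Collect per-batch logits, then concatenate along the example axis
logits = np.concatenate([fake_teacher(b) for b in batches], axis=0)
print(logits.shape)
```

Concatenation preserves order, so row `i` of the result belongs to the `i`-th example seen. The notebook iterates over `tokenized_datasets` once here and again in Step 6, so the logits line up with the student inputs only if both passes visit the examples in the same order.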


## Step 6: Prepare Student Inputs

We iterate over the tokenized dataset again and collect each batch of model inputs into a list for the student.
        

In [6]:

# Prepare the student inputs dataset
def generate_student_inputs(tokenized_dataset):
    inputs = []
    for batch in tokenized_dataset:
        inputs.append({k: v for k, v in batch.items() if k in tokenizer.model_input_names})
    return inputs

student_inputs = generate_student_inputs(tokenized_datasets)


## Step 7: Create TensorFlow Datasets

We convert the inputs and logits to TensorFlow Datasets and zip them together for training.
        

In [7]:

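# Convert inputs and logits to per-example tf.data.Dataset objects
input_ids_dataset = tf.data.Dataset.from_tensor_slices(tf.concat([example['input_ids'] for example in student_inputs], axis=0))
attention_mask_dataset = tf.data.Dataset.from_tensor_slices(tf.concat([example['attention_mask'] for example in student_inputs], axis=0))
# Include token_type_ids so the student sees the same sentence-pair segments as the teacher
token_type_ids_dataset = tf.data.Dataset.from_tensor_slices(tf.concat([example['token_type_ids'] for example in student_inputs], axis=0))
teacher_logits_dataset = tf.data.Dataset.from_tensor_slices(teacher_logits)

# Zip the datasets into (inputs, target) pairs, with the teacher logits as the target
train_dataset = tf.data.Dataset.zip(({
    'input_ids': input_ids_dataset,
    'attention_mask': attention_mask_dataset,
    'token_type_ids': token_type_ids_dataset
}, teacher_logits_dataset))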
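
As a possible simplification of the zip step, `tf.data.Dataset.from_tensor_slices` also accepts a nested `(features, targets)` structure and slices every tensor along the first axis. The sketch below uses small placeholder tensors (the shapes and values are assumptions) rather than the notebook's data.

```python
import tensorflow as tf

# Placeholder tensors: 6 examples, sequence length 4, 2 classes
features = {
    "input_ids": tf.zeros([6, 4], dtype=tf.int32),
    "attention_mask": tf.ones([6, 4], dtype=tf.int32),
}
targets = tf.random.normal([6, 2])  # stands in for the teacher logits

# One call yields (features_dict, target) pairs, ready to batch for model.fit
ds = tf.data.Dataset.from_tensor_slices((features, targets)).batch(4)

shapes = [(x["input_ids"].shape.as_list(), y.shape.as_list()) for x, y in ds]
print(shapes)
```

Either approach produces the same `(inputs, target)` element structure that `model.fit` expects.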


## Step 8: Train the Student Model

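We compile the student model with a KL divergence loss and an Adam optimizer, then train it. Because `tf.keras.losses.KLDivergence` expects probability distributions rather than raw logits, the loss first applies a softmax to both the teacher and the student logits.
        

In [8]:

# Distillation loss: KL divergence between teacher and student class probabilities
kl_divergence = tf.keras.losses.KLDivergence()

def distillation_loss(teacher_logits, student_logits):
    # Keras passes (y_true, y_pred); here y_true holds the teacher logits from the dataset
    teacher_probs = tf.nn.softmax(teacher_logits, axis=-1)
    student_probs = tf.nn.softmax(student_logits, axis=-1)
    return kl_divergence(teacher_probs, student_probs)

optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

# Compile the student model with the distillation loss and optimizer
student_model.compile(optimizer=optimizer, loss=distillation_loss)

# Batch the dataset for model.fit
train_dataset = train_dataset.batch(batch_size)

# Train the student model using model.fit
student_model.fit(train_dataset, epochs=3)

print("Training complete.")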


Epoch 1/3


Epoch 2/3
Epoch 3/3
Training complete.
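
Step 8 trains the student with a KL divergence between teacher and student outputs. A common refinement, assumed here rather than taken from the notebook, divides both sets of logits by a temperature and rescales the loss by the temperature squared. The NumPy sketch below shows one way to compute that loss directly from logits; the function names and example values are hypothetical.

```python
import numpy as np

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over the last axis."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    return scaled - np.log(np.exp(scaled).sum(axis=-1, keepdims=True))

def distillation_kl(teacher_logits, student_logits, temperature=2.0):
    """Mean KL(teacher || student) over the batch, scaled by temperature**2."""
    t_log = log_softmax(teacher_logits, temperature)
    s_log = log_softmax(student_logits, temperature)
    kl_per_example = (np.exp(t_log) * (t_log - s_log)).sum(axis=-1)
    return kl_per_example.mean() * temperature ** 2

teacher = np.array([[2.0, -1.0], [0.5, 0.3]])
matching = teacher.copy()                          # student agrees with the teacher
mismatched = np.array([[-1.0, 2.0], [0.3, 0.5]])   # student prefers the other class

loss_same = distillation_kl(teacher, matching)
loss_diff = distillation_kl(teacher, mismatched)
print(loss_same, loss_diff)
```

The loss is zero when the student reproduces the teacher's distribution and grows as the two diverge, which is the signal the student is trained to minimize.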
