## Distillation Step by Step
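Rationale-based knowledge distillation (arguably a form of response-based distillation, since the student learns from the teacher's outputs). The student is trained on two tasks at once: predicting the answer label and generating the teacher's rationale, with a `[label]` or `[rationale]` input prefix selecting the task.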
### Imports

In [None]:
!pip install -U transformers datasets scikit-learn


In [None]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import T5Tokenizer, T5ForConditionalGeneration, get_scheduler
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score

### Sample 10% of the training data for validation
Next, log in to the Hugging Face Hub and load the teacher model's outputs (the rationales), which are stored in your own dataset repo.

In [None]:
from huggingface_hub import notebook_login
notebook_login()


In [None]:
ds = load_dataset("tosin/cos_e", split=['train[:8000]', 'test'])
train_dataset = ds[0]
test_set = ds[1]

val_size = int(0.1 * len(train_dataset))  # 10% for validation
train_size = len(train_dataset) - val_size

# Create the split
train_set, val_set = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)


In [None]:
print(f"Input: {train_set[0]['input']} | \nLabel: {train_set[0]['label']}")
print("Rationale:", train_set[0]['rationale'])

Input: What can happen if you attempt too much learning?
Answer Choices:
(a) headaches
(b) intelligence
(c) growth
(d) knowing more
(e) education | 
Label: headaches
Rationale: Attempting too much learning can lead to cognitive overload, where the brain processes an excessive amount of information simultaneously, resulting in mental fatigue and discomfort. This condition is commonly known as "overlearning" or "information overload." Symptoms include difficulty concentrating, irritability, and physical manifestations such as headaches due to strain on the brain's neural pathways. As one tries to absorb more knowledge rapidly, the brain struggles to manage the influx of data efficiently, leading to heightened stress responses and ultimately manifesting as tension and pain in the head. Headaches serve as a protective mechanism, signaling the body to slow down and rest before continuing with the intense learning process. Thus, among the given choices, headaches are the most likely conseque
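As a sanity check on the sizes involved, the arithmetic below mirrors the split cell: 8,000 training rows give an 800-example validation set and 7,200 training examples, which at `batch_size=16` means 450 optimizer steps per epoch (11,250 over 25 epochs), while the 1,221-example test split yields 77 evaluation batches. This is a small pure-Python sketch; the helper names are illustrative and not part of the notebook.

```python
import math


def split_sizes(n_examples, val_fraction=0.1):
    """Return (train_size, val_size) the same way the split cell computes them."""
    val_size = int(val_fraction * n_examples)  # int() truncates toward zero
    return n_examples - val_size, val_size


def num_batches(n_examples, batch_size):
    """Batches per pass; DataLoader keeps the last partial batch (drop_last=False)."""
    return math.ceil(n_examples / batch_size)


train_size, val_size = split_sizes(8000)
print(train_size, val_size)                # 7200 800
print(num_batches(train_size, 16))         # 450
print(num_batches(train_size, 16) * 25)    # 11250 (scheduler's num_training_steps)
print(num_batches(1221, 16))               # 77
```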

In [None]:
class T5StepByStepDistillationDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=526):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]

        # Create input with both label and rationale prefixes
        input_text_label = f"[label] {example['input']}"
        input_text_rationale = f"[rationale] {example['input']}"

        # Tokenize inputs for both tasks
        input_encoding_label = self.tokenizer(
            input_text_label,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        input_encoding_rationale = self.tokenizer(
            input_text_rationale,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Tokenize targets for both tasks
        label_encoding = self.tokenizer(
            example['label'],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        rationale_encoding = self.tokenizer(
            example['rationale'],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Replace padding token ID with -100
        label_targets = label_encoding.input_ids.clone()
        label_targets[label_targets == self.tokenizer.pad_token_id] = -100

        rationale_targets = rationale_encoding.input_ids.clone()
        rationale_targets[rationale_targets == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids_label': input_encoding_label.input_ids.squeeze(),
            'attention_mask_label': input_encoding_label.attention_mask.squeeze(),
            'labels_label': label_targets.squeeze(),
            'input_ids_rationale': input_encoding_rationale.input_ids.squeeze(),
            'attention_mask_rationale': input_encoding_rationale.attention_mask.squeeze(),
            'labels_rationale': rationale_targets.squeeze()
        }
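Setting padded target positions to `-100` works because `-100` is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, which `T5ForConditionalGeneration` uses when `labels` are passed: padded positions contribute nothing, and the mean is taken only over real target tokens. The pure-Python sketch below reproduces that behavior on toy numbers. The token IDs and logits are made up; it only assumes T5's padding ID of 0 and end-of-sequence ID of 1.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding(token_ids, pad_id=0):
    """Replace padding IDs with IGNORE_INDEX (T5's pad_token_id is 0)."""
    return [IGNORE_INDEX if t == pad_id else t for t in token_ids]


def masked_cross_entropy(logits, targets):
    """Mean token-level cross-entropy over positions whose target is not IGNORE_INDEX."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == IGNORE_INDEX:
            continue  # padded position: skipped entirely
        m = max(row)  # shift by the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count


# Toy target: two real tokens, the EOS token (ID 1), then two pads.
targets = mask_padding([3, 2, 1, 0, 0])
print(targets)  # [3, 2, 1, -100, -100]

uniform = [[0.0] * 4 for _ in targets]  # 4-token toy vocabulary, uniform predictions
print(masked_cross_entropy(uniform, targets))  # close to math.log(4) ≈ 1.3863
```

Whatever the model predicts at the two padded positions, the loss is unchanged, which is exactly why the dataset class masks them.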

In [None]:
class T5TestDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]

        # Only predict labels for testing
        input_text = f"[label] {example['input']}"

        # Tokenize inputs
        input_encoding = self.tokenizer(
            input_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Tokenize labels
        label_encoding = self.tokenizer(
            example['label'],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Replace padding token ID with -100
        labels = label_encoding.input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_encoding.input_ids.squeeze(),
            'attention_mask': input_encoding.attention_mask.squeeze(),
            'labels': labels.squeeze(),
            'reference': example['label']  # Store original label for evaluation
        }

### Validation function

In [None]:
def validate_model(model, val_loader, device, lambda_value=0.2):
    model.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            # Process label task
            input_ids_label = batch['input_ids_label'].to(device)
            attention_mask_label = batch['attention_mask_label'].to(device)
            labels_label = batch['labels_label'].to(device)

            # Process rationale task
            input_ids_rationale = batch['input_ids_rationale'].to(device)
            attention_mask_rationale = batch['attention_mask_rationale'].to(device)
            labels_rationale = batch['labels_rationale'].to(device)

            # Get losses for both tasks
            outputs_label = model(
                input_ids=input_ids_label,
                attention_mask=attention_mask_label,
                labels=labels_label
            )

            outputs_rationale = model(
                input_ids=input_ids_rationale,
                attention_mask=attention_mask_rationale,
                labels=labels_rationale
            )

            # Calculate combined loss: L = L_label + λ * L_rationale
            loss_label = outputs_label.loss
            loss_rationale = outputs_rationale.loss
            combined_loss = loss_label + lambda_value * loss_rationale

            total_val_loss += combined_loss.item()

    return total_val_loss / len(val_loader)
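Both `validate_model` and the training loop compute the same multi-task objective. Writing $x$ for the input question, $y$ for the label tokens, $r$ for the teacher's rationale tokens, $B$ for a batch, and $N_y$, $N_r$ for the number of non-masked (not `-100`) target tokens in that batch, it reads:

$$
\mathcal{L} = \mathcal{L}_{\text{label}} + \lambda\,\mathcal{L}_{\text{rationale}}, \qquad \lambda = 0.2
$$

$$
\mathcal{L}_{\text{label}} = -\frac{1}{N_y}\sum_{(x,y)\in B}\sum_{t=1}^{|y|}\log p_\theta\big(y_t \mid y_{<t},\ \text{[label]}\ x\big),
\qquad
\mathcal{L}_{\text{rationale}} = -\frac{1}{N_r}\sum_{(x,r)\in B}\sum_{t=1}^{|r|}\log p_\theta\big(r_t \mid r_{<t},\ \text{[rationale]}\ x\big)
$$

Here $|y|$ and $|r|$ count only unmasked target tokens (including the end-of-sequence token). Because the reported training and validation losses are this weighted sum averaged over batches, they mix both tasks and are not a direct measure of label accuracy.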

### Training

In [None]:
# T5-small (the student) has 60 million parameters (much smaller than the teacher)
def train_step_by_step_distillation(train_set, val_set, model_name="google-t5/t5-small", max_length=526,
                                    batch_size=16, num_epochs=25, learning_rate=3e-5, weight_decay=0.01,
                                    lambda_value=0.2):
    # Load model and tokenizer
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare datasets for combined training
    train_dataset = T5StepByStepDistillationDataset(train_set, tokenizer, max_length)
    val_dataset = T5StepByStepDistillationDataset(val_set, tokenizer, max_length)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    total_steps = len(train_loader) * num_epochs

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # Store the best model
    best_val_loss = float('inf')
    best_model_state = None

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Training - Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            # Process label task
            input_ids_label = batch['input_ids_label'].to(device)
            attention_mask_label = batch['attention_mask_label'].to(device)
            labels_label = batch['labels_label'].to(device)

            # Process rationale task
            input_ids_rationale = batch['input_ids_rationale'].to(device)
            attention_mask_rationale = batch['attention_mask_rationale'].to(device)
            labels_rationale = batch['labels_rationale'].to(device)

            # Get losses for both tasks
            outputs_label = model(
                input_ids=input_ids_label,
                attention_mask=attention_mask_label,
                labels=labels_label
            )

            outputs_rationale = model(
                input_ids=input_ids_rationale,
                attention_mask=attention_mask_rationale,
                labels=labels_rationale
            )

            # Calculate combined loss: L = L_label + λ * L_rationale
            loss_label = outputs_label.loss
            loss_rationale = outputs_rationale.loss
            combined_loss = loss_label + lambda_value * loss_rationale

            total_train_loss += combined_loss.item()

            # Backward pass
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()
            lr_scheduler.step()

            progress_bar.set_postfix({
                "loss": combined_loss.item(),
                "label_loss": loss_label.item(),
                "rationale_loss": loss_rationale.item()
            })

        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

        # Validation
        print("Running validation...")
        val_loss = validate_model(model, val_loader, device, lambda_value)
        print(f"Epoch {epoch+1} - Validation loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
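            # Clone every tensor: state_dict() shares storage with the live parameters,
            # so a shallow dict.copy() would keep changing as training continues
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}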
            print(f"New best model saved with validation loss: {val_loss:.4f}")

    # Load best model for final evaluation
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation loss: {best_val_loss:.4f}")

    print("Training completed!")
    return model, tokenizer
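Keeping a "best" checkpoint requires a real copy of the weights. `model.state_dict()` returns tensors that share storage with the live parameters, and `dict.copy()` copies only the dictionary, not the tensors, so `optimizer.step()` (which updates parameters in place) silently changes the snapshot too. If the snapshot is taken with `state_dict().copy()`, as in the run that produced the log below, the weights reloaded as "best" are most likely the final-epoch weights. The pure-Python analogue below uses lists in place of tensors to show the difference; in PyTorch, `copy.deepcopy(model.state_dict())` or cloning each tensor gives an independent snapshot.

```python
import copy

# Stand-in for a state dict: parameter name -> mutable values (lists play the role of tensors).
params = {"weight": [1.0, 2.0]}

shallow = params.copy()       # like state_dict().copy(): new dict, same value objects
deep = copy.deepcopy(params)  # independent snapshot, like cloning every tensor

params["weight"][0] = 99.0    # in-place update, as optimizer.step() does to parameters

print(shallow["weight"])  # [99.0, 2.0]  (the "best" snapshot silently changed)
print(deep["weight"])     # [1.0, 2.0]
```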

In [None]:
model, tokenizer = train_step_by_step_distillation(train_set, val_set)

Epoch 1 - Average training loss: 2.5182
Running validation...
Epoch 1 - Validation loss: 1.5370
New best model saved with validation loss: 1.5370

Epoch 2 - Average training loss: 1.7566
Running validation...
Epoch 2 - Validation loss: 1.4114
New best model saved with validation loss: 1.4114

Epoch 3 - Average training loss: 1.6300
Running validation...
Epoch 3 - Validation loss: 1.3434
New best model saved with validation loss: 1.3434

Epoch 4 - Average training loss: 1.5618
Running validation...
Epoch 4 - Validation loss: 1.3072
New best model saved with validation loss: 1.3072

Epoch 5 - Average training loss: 1.4987
Running validation...
Epoch 5 - Validation loss: 1.2800
New best model saved with validation loss: 1.2800

Epoch 6 - Average training loss: 1.4472
Running validation...
Epoch 6 - Validation loss: 1.2585
New best model saved with validation loss: 1.2585

Epoch 7 - Average training loss: 1.4162
Running validation...
Epoch 7 - Validation loss: 1.2498
New best model saved with validation loss: 1.2498

Epoch 8 - Average training loss: 1.3703
Running validation...
Epoch 8 - Validation loss: 1.2426
New best model saved with validation loss: 1.2426

Epoch 9 - Average training loss: 1.3419
Running validation...
Epoch 9 - Validation loss: 1.2341
New best model saved with validation loss: 1.2341

Epoch 10 - Average training loss: 1.3186
Running validation...
Epoch 10 - Validation loss: 1.2293
New best model saved with validation loss: 1.2293

Epoch 11 - Average training loss: 1.2977
Running validation...
Epoch 11 - Validation loss: 1.2305

Epoch 12 - Average training loss: 1.2747
Running validation...
Epoch 12 - Validation loss: 1.2258
New best model saved with validation loss: 1.2258

Epoch 13 - Average training loss: 1.2477
Running validation...
Epoch 13 - Validation loss: 1.2270

Epoch 14 - Average training loss: 1.2338
Running validation...
Epoch 14 - Validation loss: 1.2283

Epoch 15 - Average training loss: 1.2248
Running validation...
Epoch 15 - Validation loss: 1.2274

Epoch 16 - Average training loss: 1.2183
Running validation...
Epoch 16 - Validation loss: 1.2268

Epoch 17 - Average training loss: 1.2017
Running validation...
Epoch 17 - Validation loss: 1.2279

Epoch 18 - Average training loss: 1.1781
Running validation...
Epoch 18 - Validation loss: 1.2328

Epoch 19 - Average training loss: 1.1706
Running validation...
Epoch 19 - Validation loss: 1.2332

Epoch 20 - Average training loss: 1.1763
Running validation...
Epoch 20 - Validation loss: 1.2318

Epoch 21 - Average training loss: 1.1572
Running validation...
Epoch 21 - Validation loss: 1.2329

Epoch 22 - Average training loss: 1.1580
Running validation...
Epoch 22 - Validation loss: 1.2321

Epoch 23 - Average training loss: 1.1532
Running validation...
Epoch 23 - Validation loss: 1.2342

Epoch 24 - Average training loss: 1.1370
Running validation...
Epoch 24 - Validation loss: 1.2359

Epoch 25 - Average training loss: 1.1468
Running validation...
Epoch 25 - Validation loss: 1.2359
Loaded best model with validation loss: 1.2258
Training completed!
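The log shows validation loss reaching its minimum at epoch 12 (1.2258) and creeping upward afterwards, while training loss keeps falling (1.2747 at epoch 12, 1.1468 at epoch 25), a typical sign of mild overfitting. A patience-based early-stopping rule, which is an addition and not part of the original training function, would have skipped most of the later epochs. The sketch below replays such a rule on the logged validation losses; `early_stop_epoch` and the patience values are illustrative.

```python
VAL_LOSSES = [
    1.5370, 1.4114, 1.3434, 1.3072, 1.2800, 1.2585, 1.2498, 1.2426, 1.2341, 1.2293,
    1.2305, 1.2258, 1.2270, 1.2283, 1.2274, 1.2268, 1.2279, 1.2328, 1.2332, 1.2318,
    1.2329, 1.2321, 1.2342, 1.2359, 1.2359,
]  # epochs 1-25, copied from the training log


def early_stop_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch) for a patience-based early-stopping rule."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch  # training would stop after this epoch
    return len(val_losses), best_epoch  # patience never ran out


print(early_stop_epoch(VAL_LOSSES, patience=3))  # (15, 12)
print(early_stop_epoch(VAL_LOSSES, patience=5))  # (17, 12)
```

In `train_step_by_step_distillation`, the same counter could be updated next to the existing best-model check and used to `break` out of the epoch loop.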


### Evaluation

In [None]:
def evaluate_model(model, tokenizer, test_dataset, max_length=128, batch_size=16):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare test dataset
    test_dataset_processed = T5TestDataset(test_dataset, tokenizer, max_length)
    test_loader = DataLoader(test_dataset_processed, batch_size=batch_size)

    model.eval()
    predictions = []
    references = []

    print("Starting evaluation...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Generate predictions
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length
            )

            # Decode predictions
            preds = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            refs = batch["reference"]

            predictions.extend(preds)
            references.extend(refs)

    # Calculate accuracy
    accuracy = accuracy_score(references, predictions) * 100
    return accuracy, predictions, references
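`accuracy_score` on two lists of strings is plain exact-match accuracy: a prediction counts only if the decoded string is identical to the reference. The sketch below makes that explicit with a hypothetical `exact_match_accuracy` helper; its `normalize` option (lowercasing and collapsing whitespace) is an optional addition, not something the notebook does, and may be worth reporting alongside the strict score if predictions differ from references only in casing or spacing. The example strings are toy data.

```python
def exact_match_accuracy(references, predictions, normalize=False):
    """Percentage of predictions identical to their reference string."""
    if len(references) != len(predictions) or not references:
        raise ValueError("need two non-empty lists of equal length")

    def canon(text):
        # Optional normalization: lowercase and collapse runs of whitespace.
        return " ".join(text.lower().split()) if normalize else text

    matches = sum(canon(r) == canon(p) for r, p in zip(references, predictions))
    return 100.0 * matches / len(references)


refs = ["wooded area", "go downtown", "play tag", "headaches"]
preds = ["wooded area", "east", "Play  tag", "headaches"]
print(exact_match_accuracy(refs, preds))                  # 50.0 (strict, like accuracy_score * 100)
print(exact_match_accuracy(refs, preds, normalize=True))  # 75.0
```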


In [None]:
accuracy, predictions, references = evaluate_model(model, tokenizer, test_set)

Starting evaluation...




### Results

In [None]:
# 2nd run - Final Evaluation Accuracy: 42.18%
def display_evaluation_results(accuracy, predictions, references, num_samples=3):
    print(f"Final Evaluation Accuracy: {accuracy:.2f}%")

    # Display some examples
    print("\nSample predictions:")
    indices = list(range(len(predictions)))
    sample_indices = indices[:num_samples]

    for i in sample_indices:
        print(f"Reference: {references[i]}")
        print(f"Prediction: {predictions[i]}")
        print("-" * 50)

# Display results
display_evaluation_results(accuracy, predictions, references, num_samples=3)

Final Evaluation Accuracy: 42.18%

Sample predictions:
Reference: wooded area
Prediction: wooded area
--------------------------------------------------
Reference: go downtown
Prediction: go downtown
--------------------------------------------------
Reference: play tag
Prediction: play tag
--------------------------------------------------
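Showing the first three indices is not very representative: in this run all three happen to be correct, even though fewer than half of the test predictions match their references. Drawing a random sample of the mismatches gives a better picture of the errors. A minimal sketch, with a hypothetical `sample_errors` helper and toy inputs:

```python
import random


def sample_errors(references, predictions, k=3, seed=0):
    """Return up to k (index, reference, prediction) tuples where the two differ."""
    errors = [
        (i, ref, pred)
        for i, (ref, pred) in enumerate(zip(references, predictions))
        if ref != pred
    ]
    rng = random.Random(seed)  # seeded so the same examples are shown each time
    return rng.sample(errors, min(k, len(errors)))


toy_refs = ["wooded area", "go downtown", "play tag"]
toy_preds = ["wooded area", "east", "play tag"]
print(sample_errors(toy_refs, toy_preds))  # [(1, 'go downtown', 'east')]

# With the notebook's lists:
# for i, ref, pred in sample_errors(references, predictions):
#     print(f"[{i}] Reference: {ref} | Prediction: {pred}")
```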


In [None]:
# 1st run - Final Evaluation Accuracy: 41.36%
# Display results
display_evaluation_results(accuracy, predictions, references, num_samples=3)

Final Evaluation Accuracy: 41.36%

Sample predictions:
Reference: wooded area
Prediction: wooded area
--------------------------------------------------
Reference: go downtown
Prediction: east
--------------------------------------------------
Reference: play tag
Prediction: play tag
--------------------------------------------------


Observe that results vary from run to run (41.36% in the 1st run vs. 42.18% in the 2nd), and rationale-based distillation is not guaranteed to always outperform standard fine-tuning on the labels alone.
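Part of this variance comes from unseeded randomness: `random_split` receives a seeded generator, but the training `DataLoader` shuffle and T5's dropout draw from PyTorch's global random state. A minimal sketch of seeding everything before the model and `DataLoader` are created is shown below; the `seed_everything` name is hypothetical, and `transformers.set_seed` provides similar behavior. Even with fixed seeds, some CUDA kernels are nondeterministic, so small run-to-run differences can remain.

```python
import random


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs (NumPy/PyTorch only if installed)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds PyTorch's RNGs on all devices (shuffling, dropout)
    except ImportError:
        pass


seed_everything(42)
first = random.random()
seed_everything(42)
print(random.random() == first)  # True: same seed, same draw
```

Alternatively, pass `generator=torch.Generator().manual_seed(42)` to the training `DataLoader`, as the notebook already does for `random_split`, to make the shuffle order reproducible on its own.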