# Model Distillation for Efficient AI Deployment

## Install libraries

In [1]:
!pip install datasets torch transformers rouge

Collecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl.metadata (4.1 kB)
Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: rouge
Successfully installed rouge-1.0.1


In [2]:
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_scheduler
)
import pandas as pd
import numpy as np
from datasets import load_dataset
from tqdm.notebook import tqdm
import wandb
import gc

In [3]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Check GPU availability
if torch.cuda.is_available():
    print(f"GPU Model: {torch.cuda.get_device_name(0)}")
    print(f"Available GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")
    print(f"Current GPU Memory Usage: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    !nvidia-smi

# Set seed for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

set_seed()

Using device: cuda
GPU Model: Tesla T4
Available GPU Memory: 15095.06 MB
Current GPU Memory Usage: 0.00 MB
Wed Apr  2 09:27:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   41C    P8              9W /   70W |       3MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+--------------------------------

In [4]:
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login
user_secrets = UserSecretsClient()
hugging_face_token = user_secrets.get_secret("Hugging_Face_Token")
wnb_token = user_secrets.get_secret("wnb")

# Login to Hugging Face
login(hugging_face_token) # from huggingface_hub import login

# Login to WnB
wandb.login(key=wnb_token) # import wandb
run = wandb.init(
    project='Distillation-T5', 
    job_type="training", 
    anonymous="allow"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwenxupine[0m ([33mwenxupine-tampere-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [16]:
# Define Summarization Dataset class
class SummarizationDataset(Dataset):
    def __init__(self, texts, summaries, tokenizer, max_source_length=512, max_target_length=128):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        summary = self.summaries[idx]

        # Add task prefix to input
        source_text = f"summarize: {text}"

        # Tokenize input text
        source_encoding = self.tokenizer(
            source_text,
            max_length=self.max_source_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Tokenize target summary
        target_encoding = self.tokenizer(
            summary,
            max_length=self.max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = source_encoding["input_ids"].squeeze()
        attention_mask = source_encoding["attention_mask"].squeeze()
        labels = target_encoding["input_ids"].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100  # Ignore pad tokens in loss

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "text": text,
            "summary": summary
        }

# Load dataset function
def load_dataset_for_distillation(tokenizer, batch_size=4):
    # Using CNN/DailyMail dataset as an example
    print("Loading dataset...")
    dataset = load_dataset("cnn_dailymail", "3.0.0")

    # Extract training and validation sets
    # Limiting to smaller subsets for faster training
    train_texts = dataset["train"]["article"][:1000]
    train_summaries = dataset["train"]["highlights"][:1000]

    val_texts = dataset["validation"]["article"][:200]
    val_summaries = dataset["validation"]["highlights"][:200]

    # Create datasets
    train_dataset = SummarizationDataset(train_texts, train_summaries, tokenizer)
    val_dataset = SummarizationDataset(val_texts, val_summaries, tokenizer)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

In [17]:
# Load models and tokenizers
def load_models_and_tokenizers():
    # Using T5-base as teacher model (much smaller than T5-3B)
    print("Loading teacher model (T5-base)...")
    teacher_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    teacher_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    teacher_model.eval()  # Set to evaluation mode

    print("Loading student model (T5-small)...")
    student_tokenizer = T5Tokenizer.from_pretrained("t5-small")
    student_model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)

    return teacher_model, teacher_tokenizer, student_model, student_tokenizer

# Memory optimization function
def optimize_memory():
    # Clear cache
    gc.collect()
    torch.cuda.empty_cache()

    # Print memory usage
    if torch.cuda.is_available():
        print(f"Current GPU Memory Usage: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
        print(f"GPU Memory Cached: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")

# Load models
teacher_model, teacher_tokenizer, student_model, student_tokenizer = load_models_and_tokenizers()
optimize_memory()

Loading teacher model (T5-base)...
Loading student model (T5-small)...
Current GPU Memory Usage: 3083.84 MB
GPU Memory Cached: 3356.00 MB


In [18]:
# Define Distillation Loss
class DistillationLoss(nn.Module):
    def __init__(self, temperature=2.0, alpha=0.5):
        super(DistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_outputs, teacher_outputs):
        # Get logits
        student_logits = student_outputs.logits
        teacher_logits = teacher_outputs.logits

        # Get teacher's distribution
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)

        # Calculate student's soft target loss
        soft_targets_loss = -(teacher_probs * F.log_softmax(student_logits / self.temperature, dim=-1)).sum(dim=-1).mean()

        # Use student_outputs' loss as hard target loss
        hard_targets_loss = student_outputs.loss

        # Combine soft and hard target losses
        loss = self.alpha * (self.temperature ** 2) * soft_targets_loss + (1 - self.alpha) * hard_targets_loss

        return loss, hard_targets_loss, soft_targets_loss

# Initialize distillation loss function
distillation_loss_fn = DistillationLoss(temperature=2.0, alpha=0.7)

In [19]:
# Training step function
# Fixed train step function - Run this first before starting training
def train_step(teacher_model, student_model, train_loader, optimizer, scheduler, distillation_loss_fn, epoch):
    student_model.train()
    total_loss = 0
    total_hard_loss = 0
    total_soft_loss = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training")

    for batch_idx, batch in enumerate(progress_bar):  # Added batch_idx here
        # Move data to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Clear gradients
        optimizer.zero_grad()

        # Forward pass - Student model
        student_outputs = student_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # Get Teacher model outputs (no gradient computation)
        with torch.no_grad():
            teacher_outputs = teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

        # Calculate distillation loss
        loss, hard_loss, soft_loss = distillation_loss_fn(student_outputs, teacher_outputs)

        # Backward pass
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)

        # Update parameters
        optimizer.step()
        scheduler.step()

        # Update loss
        total_loss += loss.item()
        total_hard_loss += hard_loss.item()
        total_soft_loss += soft_loss.item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': loss.item(),
            'hard_loss': hard_loss.item(),
            'soft_loss': soft_loss.item()
        })

        # Free memory
        del input_ids, attention_mask, labels, student_outputs, teacher_outputs, loss
        if batch_idx % 10 == 0:  # Every 10 batches
            optimize_memory()

    avg_loss = total_loss / len(train_loader)
    avg_hard_loss = total_hard_loss / len(train_loader)
    avg_soft_loss = total_soft_loss / len(train_loader)

    return avg_loss, avg_hard_loss, avg_soft_loss

# Evaluation function
def evaluate(student_model, val_loader):
    student_model.eval()
    generated_summaries = []
    reference_summaries = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Generate summaries
            summary_ids = student_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=128,
                num_beams=4,
                early_stopping=True
            )

            # Convert generated IDs to text
            decoded_summaries = [student_tokenizer.decode(g, skip_special_tokens=True) for g in summary_ids]
            generated_summaries.extend(decoded_summaries)
            reference_summaries.extend(batch["summary"])

            # Free memory
            del input_ids, attention_mask, summary_ids

    # Return generated summaries and references
    return generated_summaries, reference_summaries

In [21]:
# Load dataset
batch_size = 2  # Small batch size to save memory
train_loader, val_loader = load_dataset_for_distillation(student_tokenizer, batch_size=batch_size)
# Calculate total training steps
num_epochs = 3
total_steps = len(train_loader) * num_epochs
# Better optimizer configuration
optimizer = AdamW(student_model.parameters(), 
                 lr=3e-5,  # learning rate
                 weight_decay=0.01)

# Use cosine scheduler instead of linear
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=total_steps
)
print(f"Total training steps: {total_steps}")
print(f"Training with batch size: {batch_size}")
print(f"Number of training examples: {len(train_loader.dataset)}")
print(f"Number of validation examples: {len(val_loader.dataset)}")

Loading dataset...
Total training steps: 1500
Training with batch size: 2
Number of training examples: 1000
Number of validation examples: 200


In [22]:
print("Starting distillation training...")

for epoch in range(num_epochs):
    # Train one epoch
    avg_loss, avg_hard_loss, avg_soft_loss = train_step(
        teacher_model, student_model, train_loader, optimizer, scheduler, distillation_loss_fn, epoch
    )

    print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}, Hard Target Loss: {avg_hard_loss:.4f}, Soft Target Loss: {avg_soft_loss:.4f}")

    # Generate and evaluate summaries every epoch
    print("Generating summary examples...")
    generated_summaries, reference_summaries = evaluate(student_model, val_loader)

    # Print some examples
    for i in range(min(3, len(generated_summaries))):
        print(f"\nReference Summary: {reference_summaries[i]}")
        print(f"Generated Summary: {generated_summaries[i]}")
        print("-" * 50)

    # Save model checkpoint
    checkpoint_path = f"t5_small_distilled_epoch_{epoch+1}"
    student_model.save_pretrained(checkpoint_path)
    student_tokenizer.save_pretrained(checkpoint_path)
    print(f"Saved model checkpoint: {checkpoint_path}")

    # Free memory
    optimize_memory()

print("Distillation training completed!")

Starting distillation training...


Epoch 1 Training:   0%|          | 0/500 [00:00<?, ?it/s]

Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4090.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4216.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4192.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4212.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4272.00 MB
Current GPU Memory Usage: 3785.41 MB
GPU Memory Cached: 4228.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4268.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.0

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]


Reference Summary: Zully Broussard decided to give a kidney to a stranger .
A new computer program helped her donation spur transplants for six kidney patients .
Generated Summary: "I thought I was going to help this one person who I don't know, but the fact that so many people can have a life extension, that's pretty big," a comment on a Facebook page read. five surgeons, a covey of physician assistants, nurses, anesthesiologists, and more than 40 support staff perform surgeries on 12 people. the chain of surgeries is to be wrapped up Friday.
--------------------------------------------------

Reference Summary: The 20th MLS season begins this weekend .
League has changed dramatically since its inception in 1996 .
Some question whether rules regarding salary caps and transfers need to change .
Generated Summary: MLS is the first of a new domestic television and media rights deal with FOX, ESPN and Univision. the new season is the first of a new domestic television and media rights de

Epoch 2 Training:   0%|          | 0/500 [00:00<?, ?it/s]

Current GPU Memory Usage: 3785.41 MB
GPU Memory Cached: 4272.00 MB
Current GPU Memory Usage: 3785.41 MB
GPU Memory Cached: 4228.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.0

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]


Reference Summary: Zully Broussard decided to give a kidney to a stranger .
A new computer program helped her donation spur transplants for six kidney patients .
Generated Summary: "I thought I was going to help this one person who I don't know, but the fact that so many people can have a life extension, that's pretty big," a comment on a Facebook page read. "I know this entire journey is much bigger than all of us," a comment on a Facebook page read. five surgeons, a covey of physician assistants, nurses, anesthesiologists, and more than 40 support staff perform surgeries.
--------------------------------------------------

Reference Summary: The 20th MLS season begins this weekend .
League has changed dramatically since its inception in 1996 .
Some question whether rules regarding salary caps and transfers need to change .
Generated Summary: MLS is the first of a new domestic television and media rights deal with FOX, ESPN and Univision. the new season is the first of a new domestic

Epoch 3 Training:   0%|          | 0/500 [00:00<?, ?it/s]

Current GPU Memory Usage: 3785.41 MB
GPU Memory Cached: 4272.00 MB
Current GPU Memory Usage: 3785.41 MB
GPU Memory Cached: 4228.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.00 MB
Current GPU Memory Usage: 3784.41 MB
GPU Memory Cached: 4248.0

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]


Reference Summary: Zully Broussard decided to give a kidney to a stranger .
A new computer program helped her donation spur transplants for six kidney patients .
Generated Summary: "I thought I was going to help this one person who I don't know, but the fact that so many people can have a life extension, that's pretty big," she says. "I know this entire journey is much bigger than all of us. I also know I'm just the messenger," she says. "the ages of the donors and recipients range from 26 to 70," the medical center says.
--------------------------------------------------

Reference Summary: The 20th MLS season begins this weekend .
League has changed dramatically since its inception in 1996 .
Some question whether rules regarding salary caps and transfers need to change .
Generated Summary: MLS is the first of a new domestic television and media rights deal with FOX, ESPN and Univision. the new season is the first of a new domestic television and media rights deal with FOX, ESPN and 

In [23]:
repo_id = "Wenfi/distillation-T5-cnn"

# Save model and tokenizer locally
save_path = "./student_model_distilled"
student_model.save_pretrained(save_path)
student_tokenizer.save_pretrained(save_path)

# Push to Hugging Face
student_model.push_to_hub(repo_id)
student_tokenizer.push_to_hub(repo_id)

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/Wenfi/distillation-T5-cnn/commit/8e71fdae0b63748927ff3c80eb0c1b5ed11cb068', commit_message='Upload tokenizer', commit_description='', oid='8e71fdae0b63748927ff3c80eb0c1b5ed11cb068', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Wenfi/distillation-T5-cnn', endpoint='https://huggingface.co', repo_type='model', repo_id='Wenfi/distillation-T5-cnn'), pr_revision=None, pr_num=None)

## Evaluate

In [24]:
# Inference function for the distilled model
def generate_summary(model, tokenizer, text, max_length=150):
    # Prepare input
    input_text = f"summarize: {text}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate summary
    summary_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=2
    )

    # Decode summary
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

# Compare original and distilled model's output for a given text
def compare_models(text):
    # Generate with teacher model
    teacher_summary = generate_summary(teacher_model, teacher_tokenizer, text)

    # Generate with student model
    student_summary = generate_summary(student_model, student_tokenizer, text)

    print("Original Text:")
    print(text[:500] + "..." if len(text) > 500 else text)
    print("\nTeacher Model Summary:")
    print(teacher_summary)
    print("\nDistilled Student Model Summary:")
    print(student_summary)

# Example usage
sample_text = """
Climate change is the long-term alteration of temperature and typical weather patterns in a place.
Climate change could refer to a particular location or the planet as a whole. Climate change may
cause weather patterns to be less predictable. These unexpected weather patterns can make it
difficult to maintain and grow crops in regions that rely on farming because expected temperature
and rainfall levels can no longer be relied on. Climate change has also been connected with other
damaging weather events such as more frequent and more intense hurricanes, floods, downpours, and
winter storms. In polar regions, the warming global temperatures associated with climate change have
meant ice sheets and glaciers are melting at an accelerated rate from season to season. This contributes
to sea levels rising in different regions of the planet. Together with expanding ocean waters due to
rising temperatures, the resulting rise in sea level has begun to damage coastlines as a result of
increased flooding and erosion.
"""

compare_models(sample_text)

Original Text:

Climate change is the long-term alteration of temperature and typical weather patterns in a place.
Climate change could refer to a particular location or the planet as a whole. Climate change may
cause weather patterns to be less predictable. These unexpected weather patterns can make it
difficult to maintain and grow crops in regions that rely on farming because expected temperature
and rainfall levels can no longer be relied on. Climate change has also been connected with other
damaging weath...

Teacher Model Summary:
climate change is the long-term alteration of temperature and typical weather patterns in a place . this can make it difficult to maintain and grow crops in regions that rely on farming if temperatures and rainfall levels can no longer be relied on.

Distilled Student Model Summary:
climate change is the long-term alteration of temperature and typical weather patterns in a place. polar regions have been linked to more frequent and more intense hurricane

In [25]:
# Function to compare original and distilled models's summaries with teacher models' summaries
def compare_original_vs_distilled(text, original_model_path="t5-small"):
    """
    Compare summaries generated by the original student model and the distilled model.

    Args:
        text (str): Input text to summarize
        original_model_path (str): Path to the original model (default: "t5-small")

    Returns:
        None: Prints the comparison results
    """
    # Load the original student model
    print("Loading original student model...")
    original_model = T5ForConditionalGeneration.from_pretrained(original_model_path).to(device)
    original_tokenizer = T5Tokenizer.from_pretrained(original_model_path)

    # Load the latest distilled model checkpoint (assuming it exists)
    distilled_model_path = f"t5_small_distilled_epoch_{num_epochs}"

    try:
        print(f"Loading distilled model from {distilled_model_path}...")
        distilled_model = T5ForConditionalGeneration.from_pretrained(distilled_model_path).to(device)
        distilled_tokenizer = T5Tokenizer.from_pretrained(distilled_model_path)
    except:
        print("Couldn't find the distilled model checkpoint. Using current student model...")
        distilled_model = student_model
        distilled_tokenizer = student_tokenizer

    # Set models to evaluation mode
    original_model.eval()
    distilled_model.eval()
    teacher_model.eval()

    # Function to generate summary
    def get_summary(model, tokenizer, text):
        input_text = f"summarize: {text}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        with torch.no_grad():
            start_time = time.time()
            summary_ids = model.generate(
                input_ids,
                max_length=150,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=2
            )
            end_time = time.time()

        inference_time = end_time - start_time
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary, inference_time

    # Get summaries from all models
    original_summary, original_time = get_summary(original_model, original_tokenizer, text)
    distilled_summary, distilled_time = get_summary(distilled_model, distilled_tokenizer, text)
    teacher_summary, teacher_time = get_summary(teacher_model, teacher_tokenizer, text)

    # Calculate ROUGE score (if rouge is available)
    try:
        from rouge import Rouge
        rouge = Rouge()

        # Calculate ROUGE scores comparing to teacher model
        original_scores = rouge.get_scores(original_summary, teacher_summary)[0]
        distilled_scores = rouge.get_scores(distilled_summary, teacher_summary)[0]
        rouge_available = True
    except:
        rouge_available = False
        print("ROUGE scoring not available. Install it with: pip install rouge")

    # Print results
    print("\n" + "="*80)
    print("COMPARISON OF MODELS")
    print("="*80)

    print("\nORIGINAL TEXT:")
    print(text[:500] + "..." if len(text) > 500 else text)

    print("\nTEACHER MODEL SUMMARY (T5-base):")
    print(f"Time: {teacher_time:.4f} seconds")
    print(teacher_summary)

    print("\nORIGINAL STUDENT MODEL SUMMARY (T5-small):")
    print(f"Time: {original_time:.4f} seconds")
    print(original_summary)

    print("\nDISTILLED STUDENT MODEL SUMMARY (T5-small distilled):")
    print(f"Time: {distilled_time:.4f} seconds")
    print(distilled_summary)

    # Print ROUGE scores if available
    if rouge_available:
        print("\nROUGE SCORES (compared to teacher model):")
        print(f"Original Student Model:")
        print(f"  ROUGE-1: {original_scores['rouge-1']['f']:.4f}")
        print(f"  ROUGE-2: {original_scores['rouge-2']['f']:.4f}")
        print(f"  ROUGE-L: {original_scores['rouge-l']['f']:.4f}")

        print(f"\nDistilled Student Model:")
        print(f"  ROUGE-1: {distilled_scores['rouge-1']['f']:.4f}")
        print(f"  ROUGE-2: {distilled_scores['rouge-2']['f']:.4f}")
        print(f"  ROUGE-L: {distilled_scores['rouge-l']['f']:.4f}")

        # Calculate improvement percentage
        rouge1_improvement = ((distilled_scores['rouge-1']['f'] - original_scores['rouge-1']['f']) /
                             original_scores['rouge-1']['f'] * 100)
        rouge2_improvement = ((distilled_scores['rouge-2']['f'] - original_scores['rouge-2']['f']) /
                             original_scores['rouge-2']['f'] * 100)
        rougeL_improvement = ((distilled_scores['rouge-l']['f'] - original_scores['rouge-l']['f']) /
                             original_scores['rouge-l']['f'] * 100)

        print(f"\nImprovement from Distillation:")
        print(f"  ROUGE-1: {rouge1_improvement:.2f}%")
        print(f"  ROUGE-2: {rouge2_improvement:.2f}%")
        print(f"  ROUGE-L: {rougeL_improvement:.2f}%")

    # Performance comparison
    speed_improvement = ((original_time - distilled_time) / original_time) * 100
    print(f"\nInference Speed Improvement: {speed_improvement:.2f}%")

    # Clean up to save memory
    del original_model, distilled_model
    gc.collect()
    torch.cuda.empty_cache()

# Example usage:
# First, import the time module if not already imported
import time

# Sample text to test
sample_text = """
Climate change is the long-term alteration of temperature and typical weather patterns in a place.
Climate change could refer to a particular location or the planet as a whole. Climate change may
cause weather patterns to be less predictable. These unexpected weather patterns can make it
difficult to maintain and grow crops in regions that rely on farming because expected temperature
and rainfall levels can no longer be relied on. Climate change has also been connected with other
damaging weather events such as more frequent and more intense hurricanes, floods, downpours, and
winter storms. In polar regions, the warming global temperatures associated with climate change have
meant ice sheets and glaciers are melting at an accelerated rate from season to season.
"""

#compare_original_vs_distilled(sample_text)

In [26]:
compare_original_vs_distilled(sample_text)

Loading original student model...
Loading distilled model from t5_small_distilled_epoch_3...

COMPARISON OF MODELS

ORIGINAL TEXT:

Climate change is the long-term alteration of temperature and typical weather patterns in a place.
Climate change could refer to a particular location or the planet as a whole. Climate change may
cause weather patterns to be less predictable. These unexpected weather patterns can make it
difficult to maintain and grow crops in regions that rely on farming because expected temperature
and rainfall levels can no longer be relied on. Climate change has also been connected with other
damaging weath...

TEACHER MODEL SUMMARY (T5-base):
Time: 1.9858 seconds
climate change is the long-term alteration of temperature and typical weather patterns in a place . this can make it difficult to maintain and grow crops in regions that rely on farming if weather is less predictable, says nina dos santos, director of climate research at the u.s. ice sheets and glaciers are m

In [19]:
!pip install -q rouge

In [27]:
import time
import torch
import gc
from rouge import Rouge
from transformers import T5ForConditionalGeneration, T5Tokenizer

def compare_student_vs_distilled(texts, original_model_path="t5-small"):
    """
    Compare the original student model with the distilled model using ROUGE scores.
    High rouge indicates these two models' outputs are more similar, and it doesn't mean good quality.

    Args:
        texts (list): List of input texts to summarize
        original_model_path (str): Path to the original model (default: "t5-small")
    """
    # Initialize ROUGE
    try:
        rouge = Rouge()
        print("ROUGE scoring initialized successfully")
    except Exception as e:
        print(f"Error initializing ROUGE: {e}")
        print("Installing ROUGE...")
        !pip install -q rouge
        try:
            from rouge import Rouge
            rouge = Rouge()
            print("ROUGE installed and initialized successfully")
        except Exception as e:
            print(f"Failed to install ROUGE: {e}")
            return

    # Load the original student model
    print("Loading original student model...")
    original_model = T5ForConditionalGeneration.from_pretrained(original_model_path).to(device)
    original_tokenizer = T5Tokenizer.from_pretrained(original_model_path)

    # Load the latest distilled model checkpoint
    distilled_model_path = f"t5_small_distilled_epoch_{num_epochs}"

    try:
        print(f"Loading distilled model from {distilled_model_path}...")
        distilled_model = T5ForConditionalGeneration.from_pretrained(distilled_model_path).to(device)
        distilled_tokenizer = T5Tokenizer.from_pretrained(distilled_model_path)
    except:
        print("Couldn't find the distilled model checkpoint. Using current student model...")
        distilled_model = student_model
        distilled_tokenizer = student_tokenizer

    # Function to generate summary
    def get_summary(model, tokenizer, text):
        input_text = f"summarize: {text}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        with torch.no_grad():
            start_time = time.time()
            summary_ids = model.generate(
                input_ids,
                max_length=150,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=2
            )
            end_time = time.time()

        inference_time = end_time - start_time
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary, inference_time

    # Results storage
    all_rouge1 = []
    all_rouge2 = []
    all_rougeL = []
    all_orig_times = []
    all_dist_times = []

    # Process each text
    for i, text in enumerate(texts):
        print(f"\nProcessing text {i+1}/{len(texts)}")

        # Get summaries
        original_summary, original_time = get_summary(original_model, original_tokenizer, text)
        distilled_summary, distilled_time = get_summary(distilled_model, distilled_tokenizer, text)

        # Track times
        all_orig_times.append(original_time)
        all_dist_times.append(distilled_time)

        # Calculate ROUGE scores between original and distilled
        scores = rouge.get_scores(distilled_summary, original_summary)[0]

        # Store scores
        all_rouge1.append(scores['rouge-1']['f'])
        all_rouge2.append(scores['rouge-2']['f'])
        all_rougeL.append(scores['rouge-l']['f'])

        # Print individual results
        print(f"\nText {i+1}:")
        print(f"Original Summary ({original_time:.4f}s): {original_summary}")
        print(f"Distilled Summary ({distilled_time:.4f}s): {distilled_summary}")
        print(f"ROUGE-1: {scores['rouge-1']['f']:.4f}")
        print(f"ROUGE-2: {scores['rouge-2']['f']:.4f}")
        print(f"ROUGE-L: {scores['rouge-l']['f']:.4f}")

    # Calculate averages
    avg_rouge1 = sum(all_rouge1) / len(all_rouge1)
    avg_rouge2 = sum(all_rouge2) / len(all_rouge2)
    avg_rougeL = sum(all_rougeL) / len(all_rougeL)
    avg_orig_time = sum(all_orig_times) / len(all_orig_times)
    avg_dist_time = sum(all_dist_times) / len(all_dist_times)

    # Print summary report
    print("\n" + "="*80)
    print("STUDENT VS DISTILLED MODEL COMPARISON")
    print("="*80)

    print("\nROUGE SCORES SUMMARY (higher means more similar):")
    print(f"Average ROUGE-1: {avg_rouge1:.4f}")
    print(f"Average ROUGE-2: {avg_rouge2:.4f}")
    print(f"Average ROUGE-L: {avg_rougeL:.4f}")

    print("\nSPEED COMPARISON:")
    speed_improvement = ((avg_orig_time - avg_dist_time) / avg_orig_time) * 100
    print(f"Original model average time: {avg_orig_time:.4f} seconds")
    print(f"Distilled model average time: {avg_dist_time:.4f} seconds")
    print(f"Speed improvement: {speed_improvement:.2f}%")

    # Interpretation
    print("\nINTERPRETATION:")

    if avg_rouge1 > 0.8 and avg_rouge2 > 0.6 and avg_rougeL > 0.7:
        print("• The distilled model produces very similar summaries to the original student model")
    elif avg_rouge1 > 0.6 and avg_rouge2 > 0.4 and avg_rougeL > 0.5:
        print("• The distilled model produces moderately similar summaries to the original student model")
    else:
        print("• The distilled model produces somewhat different summaries from the original student model")

    if speed_improvement > 10:
        print(f"• The distilled model shows significant speed improvements ({speed_improvement:.1f}%)")
    elif speed_improvement > 0:
        print(f"• The distilled model shows modest speed improvements ({speed_improvement:.1f}%)")
    else:
        print(f"• The distilled model does not show speed improvements ({speed_improvement:.1f}%)")

    # Clean up memory
    del original_model, distilled_model
    gc.collect()
    torch.cuda.empty_cache()

# Example texts for testing
test_texts = [
    """
    Climate change is the long-term alteration of temperature and typical weather patterns in a place.
    Climate change could refer to a particular location or the planet as a whole. Climate change may
    cause weather patterns to be less predictable. These unexpected weather patterns can make it
    difficult to maintain and grow crops in regions that rely on farming because expected temperature
    and rainfall levels can no longer be relied on.
    """,

    """
    Artificial intelligence (AI) refers to the simulation of human intelligence in machines that are
    programmed to think like humans and mimic their actions. The term may also be applied to any machine
    that exhibits traits associated with a human mind such as learning and problem-solving. The ideal
    characteristic of artificial intelligence is its ability to rationalize and take actions that have
    the best chance of achieving a specific goal.
    """,

    """
    The COVID-19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of
    coronavirus disease 2019 (COVID-19) caused by severe acute respiratory syndrome coronavirus 2
    (SARS-CoV-2). The novel virus was first identified from an outbreak in Wuhan, China, in December
    2019. Attempts to contain it there failed, allowing it to spread across the globe.
    """
]

# compare_student_vs_distilled(test_texts)

In [28]:
compare_student_vs_distilled(test_texts)

Error initializing ROUGE: local variable 'Rouge' referenced before assignment
Installing ROUGE...
ROUGE installed and initialized successfully
Loading original student model...
Loading distilled model from t5_small_distilled_epoch_3...

Processing text 1/3

Text 1:
Original Summary (1.0589s): climate change is the long-term alteration of temperature and typical weather patterns in a place. climate changes could refer to particular location or the planet as an whole - and may cause weather pattern to be less predictable. if the weather is not predictable, it can make it difficult to maintain and grow crops in regions that rely on farming because expected temperatures and rainfall levels can no longer be relied on.
Distilled Summary (0.7876s): climate change is the long-term alteration of temperature and typical weather patterns in a place. this could refer to particular location or the planet asa whole. Climate change can make it difficult to maintain and grow crops in regions that rely

In [35]:
import time
# Function to compare rouge metric of different models' summary against reference summary
def compare_models_with_references(test_texts, test_references, original_model_path="t5-small"):
    """
    Compare original and distilled models using reference summaries from the dataset
    
    Args:
        test_texts (list): List of input texts from test set
        test_references (list): Corresponding reference summaries
        original_model_path (str): Path to original model
    """
    # Initialize ROUGE
    try:
        rouge = Rouge()
    except:
        !pip install -q rouge
        from rouge import Rouge
        rouge = Rouge()

    # Load models 
    print("Loading tacher student model...")
    teacher_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    teacher_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    print("Loading original student model...")
    original_model = T5ForConditionalGeneration.from_pretrained(original_model_path).to(device)
    original_tokenizer = T5Tokenizer.from_pretrained(original_model_path)

    print("Loading distilled model...")
    try:
        distilled_model = T5ForConditionalGeneration.from_pretrained(f"t5_small_distilled_epoch_{num_epochs}").to(device)
        distilled_tokenizer = T5Tokenizer.from_pretrained(f"t5_small_distilled_epoch_{num_epochs}")
    except:
        distilled_model = student_model
        distilled_tokenizer = student_tokenizer

    # Modified generation function to include reference
    def evaluate_model(model, tokenizer, texts, references):
        model.eval()
        generated_summaries = []
        inference_times = []
        
        for text in texts:
            input_text = f"summarize: {text}"
            input_ids = tokenizer.encode(input_text,max_length=512,         
                truncation=True, return_tensors="pt").to(device)
            
            start_time = time.time()
            with torch.no_grad():
                summary_ids = model.generate(
                    input_ids,
                    max_length=150,
                    num_beams=4,
                    early_stopping=True
                )
            inference_time = time.time() - start_time
            
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            generated_summaries.append(summary)
            inference_times.append(inference_time)
            
        # Calculate ROUGE scores against references
        scores = rouge.get_scores(generated_summaries, references, avg=True)
        return scores, np.mean(inference_times)

    # Evaluate models
    print("\nEvaluating Teacher Model...")
    teacher_scores, teacher_avg_time = evaluate_model(teacher_model, teacher_tokenizer, test_texts, test_references)
    
    print("\nEvaluating Original Model...")
    orig_scores, orig_avg_time = evaluate_model(original_model, original_tokenizer, test_texts, test_references)
    
    print("\nEvaluating Distilled Model...")
    dist_scores, dist_avg_time = evaluate_model(distilled_model, distilled_tokenizer, test_texts, test_references)

    # Print comparison report
    print("\n" + "="*80)
    print("MODEL COMPARISON AGAINST REFERENCE SUMMARIES")
    print("="*80)
    
    print(f"\n{'Metric':<10} | {'Teacher Model':<15}|{'Original Model':<15} | {'Distilled Model':<15} | Improvement")
    print("-"*65)
    
    for metric in ['rouge-1', 'rouge-2', 'rouge-l']:
        teacher_f=teacher_scores[metric]['f']
        orig_f = orig_scores[metric]['f']
        dist_f = dist_scores[metric]['f']
        improvement = dist_f - orig_f
        print(f"{metric.upper():<10} | {teacher_f:.4f}{'':<5}|{orig_f:.4f}{'':<5} | {dist_f:.4f}{'':<5} | {improvement:+.4f}")

    print("\nSpeed Comparison:")
    print(f"Original Model Average Inference Time: {orig_avg_time:.4f}s")
    print(f"Distilled Model Average Inference Time: {dist_avg_time:.4f}s")
    print(f"Speed Improvement: {(orig_avg_time - dist_avg_time)/orig_avg_time*100:.2f}%")

    # Clean up
    del original_model, distilled_model
    torch.cuda.empty_cache()

# Load CNN/DailyMail test data
cnn_test = load_dataset("cnn_dailymail", "3.0.0", split="test")
test_texts = cnn_test["article"][:10]
test_references = cnn_test["highlights"][:10]

# Run comparison
compare_models_with_references(test_texts, test_references)

Loading tacher student model...
Loading original student model...
Loading distilled model...

Evaluating Teacher Model...

Evaluating Original Model...

Evaluating Distilled Model...

MODEL COMPARISON AGAINST REFERENCE SUMMARIES

Metric     | Teacher Model  |Original Model  | Distilled Model | Improvement
-----------------------------------------------------------------
ROUGE-1    | 0.2723     |0.3205      | 0.2944      | -0.0261
ROUGE-2    | 0.0996     |0.1230      | 0.1220      | -0.0011
ROUGE-L    | 0.2495     |0.2911      | 0.2858      | -0.0053

Speed Comparison:
Original Model Average Inference Time: 0.7176s
Distilled Model Average Inference Time: 0.7134s
Speed Improvement: 0.58%


## Build User interface

In [30]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.23.3-py3-none-any.whl.metadata (16 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.8.0 (from gradio)
  Downloading gradio_client-1.8.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
  Downloading safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)
Collecting semantic-version~=2.0 (from gradio)
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
Colle

In [33]:
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
import time

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load original student model
def load_original_model():
    print("Loading original T5-small model...")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    return model, tokenizer

# Load distilled model from Hugging Face
def load_distilled_model(distilled_path="Wenfi/distillation-T5-cnn"):
    print(f"Loading distilled model from Hugging Face: {distilled_path}...")
    try:
        model = T5ForConditionalGeneration.from_pretrained(distilled_path).to(device)
        tokenizer = T5Tokenizer.from_pretrained(distilled_path)
    except Exception as e:
        print(f"Error loading distilled model: {e}")
        print("Using the original model path as fallback...")
        model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
    return model, tokenizer

# Load both models at startup
original_model, original_tokenizer = load_original_model()
distilled_model, distilled_tokenizer = load_distilled_model()

# Generate summary function
def generate_summary(model, tokenizer, text, max_length=150):
    # Prepare input
    input_text = f"summarize: {text}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate summary and measure time
    start_time = time.time()
    summary_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=2
    )
    end_time = time.time()
    inference_time = end_time - start_time

    # Decode summary
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary, inference_time

# Function for original model button
def summarize_with_original(text):
    if not text.strip():
        return "Please enter some text to summarize."

    summary, inference_time = generate_summary(original_model, original_tokenizer, text)
    return f"Original T5-small Summary (took {inference_time:.4f} seconds):\n\n{summary}"

# Function for distilled model button
def summarize_with_distilled(text):
    if not text.strip():
        return "Please enter some text to summarize."

    summary, inference_time = generate_summary(distilled_model, distilled_tokenizer, text)
    return f"Distilled T5-small Summary (took {inference_time:.4f} seconds):\n\n{summary}"

# Create Gradio interface
def create_ui():
    with gr.Blocks(title="T5 Summarization Models Comparison") as demo:
        gr.Markdown("# Compare Original vs Distilled T5 Summarization Models")
        gr.Markdown("Enter text below and click on either button to generate a summary using that model.")

        with gr.Row():
            text_input = gr.Textbox(
                lines=10,
                placeholder="Enter text to summarize here...",
                label="Input Text"
            )

        with gr.Row():
            original_button = gr.Button("Summarize with Original T5-small", variant="primary")
            distilled_button = gr.Button("Summarize with Distilled T5-small", variant="primary")

        with gr.Row():
            output = gr.Textbox(lines=8, label="Summary Output")

        original_button.click(
            fn=summarize_with_original,
            inputs=text_input,
            outputs=output
        )

        distilled_button.click(
            fn=summarize_with_distilled,
            inputs=text_input,
            outputs=output
        )

        gr.Markdown("""
        ## About the Models
        - **Original T5-small**: The baseline T5-small model (60M parameters)
        - **Distilled T5-small**: A T5-small model distilled from T5-base (220M parameters)
           from HuggingFace: [ooor/t5-small-distilled-summarization](https://huggingface.co/Wenfi/distillation-T5-cnn)

        The distilled model should provide similar quality summaries but potentially with faster inference.
        And we are ChunkyPanda:)
        from team12
        """)

    return demo

# Launch the UI
if __name__ == "__main__":
    # Install Gradio if not already installed
    # !pip install -q gradio

    # Create and launch the UI
    demo = create_ui()
    demo.launch(share=True)  # Set share=False if you don't want a public link

Using device: cuda
Loading original T5-small model...
Loading distilled model from Hugging Face: Wenfi/distillation-T5-cnn...


config.json:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/142 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.59k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://ab869e13af3d4c35aa.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
