# Distillation Example: Faster Speech Transcription through Model Distillation

**Model distillation** is a machine learning technique where knowledge from a large, complex model (often referred to as the "teacher" model) is transferred to a smaller, simpler model (known as the "student" model). The goal is to make the student model mimic the teacher's performance, thereby achieving high accuracy while being more efficient in terms of computational resources and speed.
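
One classic formulation of this idea (Hinton et al., 2015) trains the student on a weighted sum of two terms. The first makes the student's temperature-softened output distribution match the teacher's. The second is the ordinary cross-entropy against the true label. Below is a minimal pure-Python sketch of that loss for a single classification step. The function names, the temperature of 2.0 and the 50/50 weighting are illustrative assumptions, and this is not necessarily the exact objective used to train distil-whisper.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, true_label, temperature=2.0, alpha=0.5):
    """Weighted sum of a soft-target term (match the teacher) and a hard-label term."""
    # Soft targets: compare teacher and student distributions at a raised temperature
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # Scaling by T^2 keeps the soft-term gradients comparable across temperatures
    soft_loss = kl_divergence(p_teacher, p_student) * temperature ** 2
    # Hard targets: ordinary cross-entropy against the ground-truth class at T = 1
    hard_loss = -math.log(softmax(student_logits)[true_label])
    return alpha * soft_loss + (1 - alpha) * hard_loss


teacher = [4.0, 1.0, 0.5]
close_student = [3.9, 1.0, 0.5]
far_student = [0.5, 1.0, 4.0]
print(distillation_loss(teacher, teacher, 0, alpha=1.0))  # 0.0: identical outputs, no soft loss
print(distillation_loss(close_student, teacher, 0) < distillation_loss(far_student, teacher, 0))  # True
```

The soft targets carry more information than the hard label alone: they tell the student which wrong answers the teacher considers nearly right.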

_Our task in the following is to transcribe recorded speech into text. We want to choose a model that gives us an optimal tradeoff between transcription speed and quality. We compare the **whisper** model and its distilled variant, **distil-whisper**._

In [None]:
from ai_dojo import show

## Audio Dataset

For this demo, we use the [**LibriSpeech Corpus**](https://paperswithcode.com/dataset/librispeech), a collection of approximately 1,000 hours of read English speech from audiobooks that are part of the LibriVox project. Most of these audiobooks come from Project Gutenberg. In this dataset, the audiobook narrations are already split into chapters and short segments. We use the `dev-clean` subset.

In [None]:
from pathlib import Path


In [None]:
data_dir = Path("../data/audio/tmp")

In [None]:
import torchaudio

librispeech_dataset = torchaudio.datasets.LIBRISPEECH(
    data_dir,
    url="dev-clean",
    download=True,
)

Let's have a look at an example audio segment.

In [None]:
waveform, sample_rate, _, _, _, _ = librispeech_dataset[0]  

In [None]:
# Play the audio
show.audio(waveform, sample_rate)

Now we process the dataset into a dataframe.

In [None]:
import pandas as pd
from pathlib import Path


def create_audio_dataframe(dataset, base_path, num_samples=5):
    """
    Create a DataFrame containing audio players and metadata for a given audio dataset.
    
    Args:
        dataset (torchaudio.datasets.LIBRISPEECH): The audio dataset from which to load samples.
        base_path (str or Path): Directory of the dataset split, e.g. "<root>/LibriSpeech/dev-clean".
        num_samples (int): Number of samples to include in the DataFrame.
    
    Returns:
        pd.DataFrame: DataFrame with an 'Audio Player', file path, and metadata for each sample,
            indexed by speaker, chapter, and utterance ID.
    """
    data = []
    for i in range(min(num_samples, len(dataset))):  # Process only available samples
        waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id = dataset[i]
        
        # LibriSpeech files are named <speaker>-<chapter>-<utterance zero-padded to 4 digits>.flac
        file_name = f"{speaker_id}-{chapter_id}-{utterance_id:04d}.flac"
        # Keep the path as a str: whisper and the Hugging Face pipeline treat str inputs as file paths
        file_path = str(Path(base_path) / str(speaker_id) / str(chapter_id) / file_name)

        audio_player = show.audio(waveform, sample_rate)
        data.append({
            "Sample ID": i,
            "Audio Player": audio_player,
            "File Path": file_path,
            "Transcript": transcript,
            "Speaker ID": speaker_id,
            "Chapter ID": chapter_id,
            "Utterance ID": utterance_id,
        })

    data = pd.DataFrame(data)
    data = data.set_index(["Speaker ID", "Chapter ID", "Utterance ID"])
    return data

To limit transcription time, we take only a few rows from the top. Adjust this number if you want a bigger or smaller benchmark set. 

In [None]:
n_samples = 32
n_samples

In [None]:
# Create a DataFrame with audio players
audio_df = create_audio_dataframe(
    librispeech_dataset,
    num_samples=n_samples,
    base_path=data_dir / "LibriSpeech" / "dev-clean",
)

In [None]:
def display_dataframe_with_audio(df):
    from IPython.display import display
    # Render DataFrame with audio players using HTML representation
    display(df.style.format({'Audio Player': lambda x: x._repr_html_()}))


In [None]:

display_dataframe_with_audio(audio_df.head(5))

## Transcription with the Original Whisper Model

We now apply the original **whisper** model as published by OpenAI.

In [None]:
show.github_repo("https://github.com/openai/whisper")

In [None]:
models = {}  # collection of transcription models

In [None]:
import whisper


For a first demonstration, we use `tiny.en`, the smallest English-only variant. Model size trades off against transcript quality: larger variants generally produce more accurate transcripts but run more slowly.

In [None]:
model_type = "whisper"
model_variant = "tiny.en"

original_model_name = f"{model_type} {model_variant}"

# Load a Whisper model
models[original_model_name] = whisper.load_model(model_variant)
models[original_model_name]

We want to transcribe the audio with models from different sources. The following function handles both the original whisper models and Hugging Face pipelines:

In [None]:
def transcribe_audio(model, file_path):
    """
    Transcribe audio using the provided model (either Whisper or Hugging Face pipeline) given the file path.
    
    Args:
        model: Loaded Whisper model or Hugging Face ASR pipeline.
        file_path (str): Path to the audio file.
    
    Returns:
        str: The transcribed text.
    """
    # Check if the model is a Whisper model by checking for the 'transcribe' attribute
    if hasattr(model, 'transcribe'):
        result = model.transcribe(file_path)
        return result['text']
    # If it's not a Whisper model, assume it's a Hugging Face pipeline
    else:
        result = model(file_path)
        # For a single input, the ASR pipeline returns a dictionary with a 'text' key
        return result['text'] if result else ""
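
The dispatch in `transcribe_audio` relies on duck typing: anything with a `transcribe` attribute is treated as a whisper model, and everything else is called like a pipeline. The sketch below replicates that logic under a separate name, `transcribe_audio_sketch`, and exercises it with two hypothetical stub classes, so it runs without downloading any model. The stubs only mimic the return shapes (a dict with a `"text"` key) and are not the real classes.

```python
class StubWhisperModel:
    """Hypothetical stand-in for a whisper model: has a .transcribe() method returning a dict."""

    def transcribe(self, file_path):
        return {"text": f"whisper:{file_path}"}


class StubASRPipeline:
    """Hypothetical stand-in for a Hugging Face ASR pipeline: callable, returns a dict."""

    def __call__(self, file_path):
        return {"text": f"pipeline:{file_path}"}


def transcribe_audio_sketch(model, file_path):
    """Same duck-typing dispatch as transcribe_audio, under a separate name."""
    if hasattr(model, "transcribe"):  # whisper-style model
        return model.transcribe(file_path)["text"]
    result = model(file_path)  # otherwise treat it as a callable pipeline
    return result["text"] if result else ""


print(transcribe_audio_sketch(StubWhisperModel(), "a.flac"))  # whisper:a.flac
print(transcribe_audio_sketch(StubASRPipeline(), "a.flac"))   # pipeline:a.flac
```

Because the check is based on an attribute name rather than a type, any object with a `transcribe` method is routed to the first branch.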

Furthermore, we want to transcribe all the audio in the given dataframe. Here is a function that does this:

In [None]:
import time
from tqdm import tqdm
import warnings

def transcribe_data(df, model, output_col="Transcription"):
    """
    Adds transcriptions to the DataFrame containing audio file paths and measures the time taken to complete the transcription.
    Displays a progress bar and suppresses all warnings during the process.
    
    Args:
        df (pd.DataFrame): DataFrame containing the 'File Path' column.
        model: Loaded Whisper model or Hugging Face ASR pipeline.
        output_col (str): Name of the column to store the transcription results. Defaults to "Transcription".
    
    Returns:
        pd.DataFrame: The same DataFrame, modified in place, with the transcriptions in column `output_col`.
        float: Total time taken for the transcription process in seconds.
    """
    start_time = time.perf_counter()  # Start timing (perf_counter is monotonic and meant for intervals)
    
    # Suppress warnings
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        
        # Apply transcription to each file path in the DataFrame with a progress bar
        tqdm.pandas(desc="Transcribing Audio Files")
        df[output_col] = df['File Path'].progress_apply(lambda x: transcribe_audio(model, x))
    
    end_time = time.perf_counter()  # End timing
    total_time = end_time - start_time  # Calculate total time taken
    
    return df, total_time
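
The timing pattern in `transcribe_data` (take a timestamp before the work, subtract it afterwards) can be packaged as a reusable context manager. The following is a minimal sketch with a hypothetical helper named `timer`; it uses `time.perf_counter`, a monotonic high-resolution clock intended for measuring durations, and stores the result in a dictionary under a given key.

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(results, key):
    """Record the elapsed wall-clock seconds of the enclosed block in results[key]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the block raises, so a failed run still records its duration
        results[key] = time.perf_counter() - start


timings = {}
with timer(timings, "sleep 0.1 s"):
    time.sleep(0.1)
print(timings)
```

For instance, wrapping a transcription call in `with timer(transcription_time, "whisper tiny.en"):` would fill the same kind of dictionary that the notebook uses for its comparison plots.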

Here we test the function on a single audio segment:

In [None]:
transcript = transcribe_audio(models[original_model_name], audio_df.iloc[0]["File Path"])
show.text(transcript)

We are now ready to start the transcription job.

In [None]:
transcription_time = {} # record the time to transcribe in [s]

In [None]:
output_col = f"Transcription {original_model_name}"
audio_data_transcribed, transcription_time[original_model_name] = transcribe_data(
    audio_df,
    models[original_model_name],
    output_col=output_col,
)

... and here is the result:

In [None]:
display_dataframe_with_audio(
    audio_data_transcribed.head(10)
)

## First Evaluation

The [**Word Error Rate (WER)**](https://en.wikipedia.org/wiki/Word_error_rate) is a standard metric for transcription quality. It is defined as:


$$\text{WER} = \frac{\text{S} + \text{D} + \text{I}}{\text{N}}$$

where:
- **S** is the number of substitutions (words that are incorrectly recognized),
- **D** is the number of deletions (words that are missed),
- **I** is the number of insertions (extra words that are added),
- **N** is the total number of words in the reference transcription.
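
To make the formula concrete, here is a minimal from-scratch sketch (the function name `word_error_rate` is hypothetical; the notebook itself uses the `jiwer` library). It finds the smallest number of substitutions, deletions and insertions via a word-level edit distance and divides by the number of reference words. It assumes whitespace-separated, already-normalized text and a non-empty reference.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (S + D + I) / N, via word-level Levenshtein distance."""
    ref = reference.split()
    hyp = hypothesis.split()
    # dist[i][j] = fewest edits turning the first i reference words into the first j hypothesis words
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i  # i deletions
    for j in range(len(hyp) + 1):
        dist[0][j] = j  # j insertions
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1  # match or substitution
            dist[i][j] = min(
                dist[i - 1][j] + 1,         # deletion
                dist[i][j - 1] + 1,         # insertion
                dist[i - 1][j - 1] + cost,  # substitution / match
            )
    return dist[len(ref)][len(hyp)] / len(ref)


# One substitution ("sat" -> "sit") and one deletion ("the") over N = 6 reference words
print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))  # 0.3333333333333333
# Three insertions against a one-word reference: WER exceeds 1
print(word_error_rate("hi", "oh hi there you"))  # 3.0
```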

WER is usually reported as a percentage. Note that it is not bounded by 100%: because insertions count as errors, a transcript can contain more errors than the reference has words. A lower WER indicates better performance of the speech recognition system.

In [None]:
import re
from jiwer import wer

def preprocess_text(text):
    """
    Preprocess the text by removing punctuation and converting to lowercase.
    
    Args:
        text (str): Input text string.
    
    Returns:
        str: Preprocessed text string.
    """
    # Remove punctuation using regex
    text = re.sub(r'[^\w\s]', '', text)
    # Convert text to lowercase to ensure case insensitivity
    text = text.lower()
    return text

def calculate_wer(ground_truth, transcription):
    """
    Calculate the Word Error Rate (WER) between ground truth and transcription.
    
    Args:
        ground_truth (str): The correct text.
        transcription (str): The output text from the speech recognition system.
    
    Returns:
        float: The Word Error Rate expressed as a percentage.
    """
    # Preprocess both ground truth and transcription
    ground_truth = preprocess_text(ground_truth)
    transcription = preprocess_text(transcription)
    
    # Calculate WER using jiwer library
    error_rate = wer(ground_truth, transcription)
    return error_rate * 100  # Convert to percentage


This function adds the WER as a column by comparing two transcript columns.

In [None]:
import pandas as pd

def compute_error_rate(df, ground_truth_col, transcription_col, error_rate_col='WER'):
    """
    Compute the Word Error Rate (WER) for each row in a DataFrame and add it as a new column.
    
    Args:
        df (pd.DataFrame): DataFrame containing the transcripts.
        ground_truth_col (str): Column name for the ground truth transcripts.
        transcription_col (str): Column name for the machine-generated transcripts.
        error_rate_col (str): Column name for the resulting error rate. Defaults to 'WER'.
    
    Returns:
        pd.DataFrame: The original DataFrame with an additional column for the error rate.
    """
    df[error_rate_col] = df.apply(
        lambda row: calculate_wer(row[ground_truth_col], row[transcription_col]),
        axis=1
    )
    return df
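
`compute_error_rate` scores each utterance separately, and averaging that column (as `compare_error_rates` does further below) weights a one-word utterance as heavily as a long sentence. ASR benchmarks usually report corpus-level WER instead: total errors divided by total reference words. The sketch below contrasts the two on a tiny made-up example; the helper names are hypothetical, and the edit distance is reimplemented so the block is self-contained.

```python
def word_edits(reference: str, hypothesis: str) -> int:
    """Minimum number of word substitutions, deletions and insertions (Levenshtein)."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))  # row for zero reference words: j insertions
    for i, r in enumerate(ref, start=1):
        curr = [i]  # i deletions against an empty hypothesis
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution (1) or match (0)
        prev = curr
    return prev[-1]


def mean_utterance_wer(refs, hyps):
    """Average of per-utterance WERs (what averaging a 'WER' column computes)."""
    return sum(word_edits(r, h) / len(r.split()) for r, h in zip(refs, hyps)) / len(refs)


def corpus_wer(refs, hyps):
    """Pooled WER: total edits divided by total reference words."""
    return sum(word_edits(r, h) for r, h in zip(refs, hyps)) / sum(len(r.split()) for r in refs)


refs = ["yes", "the quick brown fox jumps over the lazy dog"]
hyps = ["no", "the quick brown fox jumps over the lazy dog"]
print(mean_utterance_wer(refs, hyps))  # 0.5: the one-word miss counts as a 100% utterance
print(corpus_wer(refs, hyps))          # 0.1: 1 error over 10 reference words
```

With only 32 short utterances, the two numbers can differ noticeably, so it is worth stating which one is reported.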

In [None]:
audio_data_transcribed = compute_error_rate(
    audio_data_transcribed,
    ground_truth_col="Transcript",
    transcription_col=output_col,
    error_rate_col=f"WER {original_model_name}"
)

In [None]:
display_dataframe_with_audio(
    audio_data_transcribed.head(10)
)

That was a pretty quick transcription process using the "tiny" variant of _whisper_. Are you happy with the results? Larger model variants may improve the transcripts.

## A Bigger Model Variant for Better Transcripts

We skip ahead three steps (tiny → base → small → medium) and choose the "medium" size variant of whisper.

In [None]:
from ai_dojo.plot import model_size_comparison

In [None]:

# Parameter counts of the two whisper variants
model_sizes = [("whisper tiny.en", 39e6), ("whisper medium.en", 769e6)]
model_size_comparison(model_sizes)

In [None]:
model_variant = "medium.en"
bigger_model_name = f"{model_type} {model_variant}"
bigger_model_name

In [None]:
models[bigger_model_name] = whisper.load_model(model_variant)
models[bigger_model_name]

We now repeat the transcription with the larger model. This is probably going to take a few minutes...

In [None]:
output_col = f"Transcription {bigger_model_name}"
audio_data_transcribed, transcription_time[bigger_model_name] = transcribe_data(
    audio_data_transcribed,
    models[bigger_model_name],
    output_col=output_col,
)

In [None]:
audio_data_transcribed = compute_error_rate(
    audio_data_transcribed,
    ground_truth_col="Transcript",
    transcription_col=output_col,
    error_rate_col=f"WER {bigger_model_name}"
)

In [None]:
display_dataframe_with_audio(
    audio_data_transcribed.head(10),
)

### Performance Comparison

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def compare_error_rates(data: pd.DataFrame):
    """
    Calculates the average of each column starting with "WER" in the DataFrame and plots a horizontal bar chart using pandas' built-in plotting capabilities.
    
    Args:
        data (pd.DataFrame): A DataFrame containing columns with names starting with "WER".
    """
    # Filter columns that start with 'WER'
    wer_columns = [col for col in data.columns if col.startswith('WER')]
    
    # Calculate the mean of these columns
    wer_means = data[wer_columns].mean()
    
    # Plotting using pandas' built-in plot method for a horizontal bar chart
    ax = wer_means.plot.barh(figsize=(8, 1), title='Average Word Error Rate (WER) Comparison')
    
    # Labeling axes
    ax.set_xlabel('Average WER (%)')
    ax.set_ylabel('Models')
    
    # Show the plot
    plt.show()

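def compare_transcription_time(transcription_time):
    """
    Plots the total transcription time per model as a horizontal bar chart.
    
    Args:
        transcription_time (dict): Mapping from model name to transcription time in seconds.
    """
    ax = pd.Series(transcription_time).plot.barh(
        figsize=(8, 1), title='Transcription Time Comparison'
    )
    ax.set_xlabel('Transcription Time [s]')
    ax.set_ylabel('Models')
    plt.show()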
    

In [None]:
compare_error_rates(audio_data_transcribed)

In [None]:
compare_transcription_time(transcription_time)

We have significantly improved transcript quality, but at the cost of several times the compute time. Can we have both: high transcript quality and speed?

## Enter Distil-Whisper

**distil-whisper** is a distilled version of whisper for English. The distilled variants are about 50% smaller, and the authors report that they transcribe up to 6 times faster with only a small loss in quality. Let's put this to the test.


In [None]:
show.github_repo("https://github.com/huggingface/distil-whisper")

In [None]:

model_sizes += [("distil-whisper medium.en", 394e6)]
model_size_comparison(model_sizes)
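
With the parameter counts listed above (769M for whisper medium.en, 394M for distil-whisper medium.en), the size reduction works out as:

$$1 - \frac{394 \times 10^{6}}{769 \times 10^{6}} = 1 - \frac{394}{769} \approx 1 - 0.512 = 0.488$$

That is, roughly 49% fewer parameters, consistent with the authors' figure of about 50%.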

We can obtain _distil-whisper_ from the Hugging Face Hub as follows:

In [None]:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


In [None]:

import torch

# Use the same device that whisper.load_model() picks by default, for a fair speed comparison
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Distilled counterpart of whisper medium.en (394M parameters), matching the model name used below
model_id = "distil-whisper/distil-medium.en"

# Load the model weights and move them to the chosen device
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    use_safetensors=True,
).to(device)

# Load the processor (tokenizer + feature extractor)
processor = AutoProcessor.from_pretrained(model_id)


In [None]:
# Create a pipeline for automatic speech recognition with the specified settings
model_type = "distil-whisper"
model_variant = "medium.en"
distilled_model_name = f"{model_type} {model_variant}"

models[distilled_model_name] = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=device,
    #max_new_tokens=128,
)

In [None]:
models[distilled_model_name]

We now rerun the transcription job with a distilled version of the bigger, higher quality model.

In [None]:
output_col = f"Transcription {distilled_model_name}"
audio_data_transcribed, transcription_time[distilled_model_name] = transcribe_data(
    audio_df,
    models[distilled_model_name],
    output_col=output_col,
)

### Performance Comparison

In [None]:
audio_data_transcribed = compute_error_rate(
    audio_data_transcribed,
    ground_truth_col="Transcript",
    transcription_col=output_col,
    error_rate_col=f"WER {distilled_model_name}"
)

In [None]:
display_dataframe_with_audio(
    audio_data_transcribed.head(10)
)

In [None]:
compare_error_rates(audio_data_transcribed)

In [None]:
compare_transcription_time(transcription_time)

In [None]:
speedup = transcription_time[bigger_model_name] / transcription_time[distilled_model_name]
print(f"Speedup factor achieved: {speedup:.2f}")

## Conclusion

Our quick experiment on a small sample of LibriSpeech utterances is in line with what the authors of _distil-whisper_ report: the distilled model runs substantially faster than the original, at the cost of only a small loss in transcription quality. In many use cases, the distilled model will be preferable.

---
_This notebook is licensed under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/). Copyright © 2024 [Christian Staudt](https://clstaudt.me), [Katharina Rasch](https://krasch.io)_