# Transcription Pipeline for Speech-Language Audio

## Before you run the notebook...

1. All libraries below should be installed.
2. Have your HuggingFace token ready (it is only needed when speaker diarization is enabled), along with the other personal settings requested in the next section.
3. All `.wav` files to transcribe should be placed in your `raw_audio_folder` (`.mp3` and `.flac` files are also accepted).
4. Make sure the user settings in the first code block below are correct.

Once these steps are done, **you should be able to run the entire notebook to process your audio files**.

## Customize User Settings

This section contains all user-specific settings. Change each value to match your device, data, and preferences. Everything after this section should run without modification.

In [None]:
# --- 1. User Settings ---
WORKING_DIR = r"C:\Users\USERNAME\Documents\asr"  # your directory goes here
AUDIO_FOLDER = "raw_audio_folder/"  # created inside WORKING_DIR if missing
OUTPUT_FOLDER = "transcription_output/"  # created inside WORKING_DIR if missing

# --- 2. HuggingFace Token ---
# Only used when DIARIZATION is True, to download the pyannote diarization pipeline.
HUGGINGFACE_TOKEN = "YOUR_TOKEN_HERE"

# --- 3. Whisper Model Choice ---
WHISPER_MODEL = "base.en"  # Options: "tiny.en", "base.en", "small.en", "medium.en", "large-v3", etc.

# --- 4. Processing Options ---
DIARIZATION = True  # True to enable speaker diarization, otherwise False
LEVEL = "WORD"      # Options: "WORD" (one row per word) or "SEGMENT" (one row per Whisper segment)
EXPORT_AS = "CSV"   # Options: "CSV" or "TXT"

print("User settings loaded.")

## Import Libraries

In [None]:
# --- Required Libraries ---
import stable_whisper
import pandas as pd # aliases are often used to avoid having to refer to the library by its full name 
import torch
import os
import json
import gc

In [None]:
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.core import Timeline, Segment
from pyannote.database.util import load_rttm
from pyannote.audio import Pipeline

In [None]:
from collections import defaultdict

## Load Whisper Model

In [None]:
# Load the Whisper checkpoint named by WHISPER_MODEL (set in the user settings cell) through stable-ts.
# The resulting `model` is passed to process_audio_file() in the main processing loop.
model = stable_whisper.load_model(WHISPER_MODEL)

## Set Up Directories and Diarization Pipeline

In [None]:
# --- Set Working Directory ---
# All relative paths below (audio input, transcription output) are resolved against WORKING_DIR.
os.chdir(WORKING_DIR)
os.makedirs(AUDIO_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# --- Set up Pyannote Pipeline if Diarization is Enabled ---
# `pipeline` is only defined when DIARIZATION is on; diarize() relies on it.
if DIARIZATION:
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=HUGGINGFACE_TOKEN,
    )
    # NOTE(review): pyannote.audio 3.x prints a message and returns None (instead of raising) when the
    # model cannot be downloaded, e.g. a bad token or the model's user conditions were not accepted.
    # Fail here with a clear message rather than with an AttributeError on the next line.
    if pipeline is None:
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization-3.1'. Check HUGGINGFACE_TOKEN and make sure "
            "you have accepted the model's user conditions on HuggingFace."
        )
    # Run diarization on the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)

## RUN WHISPER AND PYANNOTE ON ALL .WAV FILES

### Helper Functions

In [None]:
# --- Helper Function to Check Overlap ---
def is_overlap(start1, end1, start2, end2):
    """Report whether the intervals (start1, end1) and (start2, end2) overlap.

    The intervals must share a stretch of positive length: intervals that only
    touch at an endpoint (e.g. end1 == start2) are NOT considered overlapping.
    """
    # Written out exactly like built-in max()/min() so edge values behave identically.
    latest_start = start2 if start2 > start1 else start1
    earliest_end = end2 if end2 < end1 else end1
    return latest_start < earliest_end

In [None]:
def diarize(audio_file_path, min_speakers=2, max_speakers=2):
    """Run speaker diarization on one audio file.

    Uses the module-level ``pipeline`` created in the set-up cell, which only exists when
    DIARIZATION is True.

    Args:
        audio_file_path: path of the audio file handed to the pipeline.
        min_speakers: lower bound on the number of speakers the pipeline should find.
        max_speakers: upper bound on the number of speakers the pipeline should find.
            The defaults (2, 2) reproduce the original fixed two-speaker assumption.

    Returns:
        DataFrame with columns ``start``, ``end`` and ``speaker``, one row per labelled
        speaker turn. If nothing is detected it is empty but still has those columns.
    """
    with ProgressHook() as hook:
        diarization = pipeline(
            audio_file_path,
            hook=hook,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

    # itertracks() is the public Annotation API. It yields every (segment, track, label)
    # triple, so a segment carrying several speakers contributes one row per speaker
    # instead of only the first one.
    rows = [
        {'start': segment.start, 'end': segment.end, 'speaker': label}
        for segment, _track, label in diarization.itertracks(yield_label=True)
    ]

    # Explicit columns keep the schema stable even when no speech is found.
    return pd.DataFrame(rows, columns=['start', 'end', 'speaker'])

In [None]:
def align_diarization_and_transcription(speaker_segs_df, df_segments):
    """Label each transcription row with the speaker it overlaps most.

    For every row of ``df_segments`` (a word or a segment with ``start``/``end`` times),
    each speaker turn in ``speaker_segs_df`` that overlaps it by a positive amount
    (touching endpoints do not count) adds its overlap duration to that speaker's total.
    The speaker with the largest total becomes the row's label. On a tie, the speaker
    whose overlapping turn appears first in ``speaker_segs_df`` wins. Rows that overlap
    no turn get ``""``.

    Args:
        speaker_segs_df: DataFrame with ``start``, ``end`` and ``speaker`` columns
            (e.g. the output of ``diarize``).
        df_segments: DataFrame with ``start`` and ``end`` columns. It is modified in
            place.

    Returns:
        ``df_segments`` itself, with a new ``predominant_speaker`` column.
    """
    import numpy as np  # local import: numpy ships with pandas but is not imported at the top of the notebook

    if df_segments.empty or speaker_segs_df.empty:
        # Nothing to compare against (or nothing to label): every row is unassigned.
        df_segments["predominant_speaker"] = [""] * len(df_segments)
        return df_segments

    # Pull the speaker turns out of the DataFrame once, instead of re-scanning it per row.
    turn_starts = speaker_segs_df['start'].to_numpy(dtype=float)
    turn_ends = speaker_segs_df['end'].to_numpy(dtype=float)
    turn_speakers = speaker_segs_df['speaker'].to_numpy()

    predominant_labels = []
    for seg_start, seg_end in zip(df_segments['start'], df_segments['end']):
        latest_start = np.maximum(seg_start, turn_starts)
        earliest_end = np.minimum(seg_end, turn_ends)
        overlapping = np.flatnonzero(latest_start < earliest_end)

        # Sum overlap per speaker in the order the turns appear (dicts keep insertion order).
        totals = {}
        for idx in overlapping:
            speaker = turn_speakers[idx]
            totals[speaker] = totals.get(speaker, 0.0) + float(earliest_end[idx] - latest_start[idx])

        # max() returns the first key with the largest total, which gives the documented tie-break.
        predominant_labels.append(max(totals, key=totals.get) if totals else "")

    df_segments["predominant_speaker"] = predominant_labels
    return df_segments

### Main Processing Loop

In [None]:
def process_audio_file(model, audio_file_path, level, diarization = False):
    """Transcribe one audio file and return its timestamps as a DataFrame.

    Args:
        model: loaded stable-whisper model; ``model.transcribe(path)`` must return an
            object whose ``ori_dict['segments']`` holds the Whisper segments.
        audio_file_path: path of the audio file to transcribe.
        level: ``"WORD"`` for one row per word (columns word/start/end/probability) or
            ``"SEGMENT"`` for one row per Whisper segment (columns text/start/end/id).
            Matching is case-insensitive.
        diarization: when truthy, also run ``diarize`` on the file and add a
            ``predominant_speaker`` column via ``align_diarization_and_transcription``.

    Returns:
        pandas DataFrame as described above.

    Raises:
        ValueError: if ``level`` is not "WORD" or "SEGMENT". This is checked before any
            transcription work is done.
    """
    level_key = str(level).upper()
    if level_key not in ("WORD", "SEGMENT"):
        raise ValueError(f"Unsupported level {level!r}; expected 'WORD' or 'SEGMENT'.")

    result = model.transcribe(audio_file_path)
    segments = result.ori_dict['segments']

    if level_key == "WORD":
        words_data = [
            {
                'word': word['word'],
                'start': word['start'],
                'end': word['end'],
                'probability': word['probability'],
            }
            for segment in segments
            # `or []`: tolerate a segment with no (or a None) word list instead of raising.
            for word in (segment.get('words') or [])
        ]
        # Explicit columns keep the schema even when no words were recognized.
        timestamp_df = pd.DataFrame(words_data, columns=['word', 'start', 'end', 'probability'])
    else:
        utterances_data = [
            {
                'text': segment['text'],
                'start': segment['start'],
                'end': segment['end'],
                'id': segment['id'],
            }
            for segment in segments
        ]
        timestamp_df = pd.DataFrame(utterances_data, columns=['text', 'start', 'end', 'id'])

    if not diarization:
        return timestamp_df

    diarization_df = diarize(audio_file_path)
    return align_diarization_and_transcription(diarization_df, timestamp_df)

In [None]:
def df_to_textgrid(df):
    """Convert a transcript DataFrame into a TextGrid object.

    When a ``predominant_speaker`` column exists, one interval tier is created per
    distinct speaker value, holding that speaker's rows. A ``transcription`` tier
    holding every row is always added. Each row becomes an interval from ``start`` to
    ``end`` labelled with its text.

    NOTE(review): ``textgrid``, ``IntervalTier`` and ``Interval`` are not imported
    anywhere in this notebook. The calls resemble the praatio API (e.g.
    ``from praatio import textgrid``); confirm the library and add the imports before
    using this function. The processing loop does not currently call it.

    Args:
        df: transcript from ``process_audio_file``. WORD level has a ``word`` column and
            SEGMENT level has a ``text`` column; both have ``start`` and ``end``.

    Returns:
        The populated TextGrid.

    Raises:
        ValueError: if ``df`` is empty or has neither a ``word`` nor a ``text`` column.
    """
    if df.empty:
        raise ValueError("Cannot build a TextGrid from an empty transcript.")

    # WORD-level transcripts carry a 'word' column and SEGMENT-level ones a 'text' column;
    # use whichever is present for every tier.
    if 'word' in df.columns:
        label_column = 'word'
    elif 'text' in df.columns:
        label_column = 'text'
    else:
        raise ValueError("Transcript must have a 'word' or 'text' column.")

    # Create a new TextGrid object
    tg = textgrid.Textgrid()

    # Every tier spans the full time range of the transcript.
    min_time = df['start'].min()
    max_time = df['end'].max()

    if 'predominant_speaker' in df.columns:
        # One IntervalTier per distinct speaker label.
        for speaker in df['predominant_speaker'].unique():
            speaker_tier = IntervalTier(str(speaker), [], minT=min_time, maxT=max_time)
            speaker_df = df[df['predominant_speaker'] == speaker]
            for _, row in speaker_df.iterrows():
                interval = Interval(row['start'], row['end'], str(row[label_column]))
                # 'merge' combines colliding intervals instead of rejecting them.
                speaker_tier.insertEntry(interval, collisionMode='merge')
            tg.addTier(speaker_tier)

    # Transcription tier with every row, created regardless of speaker info.
    trans_tier = IntervalTier('transcription', [], minT=min_time, maxT=max_time)
    for _, row in df.iterrows():
        interval = Interval(row['start'], row['end'], str(row[label_column]))
        trans_tier.insertEntry(interval, collisionMode='merge')
    tg.addTier(trans_tier)

    return tg

In [None]:
audio_folder = AUDIO_FOLDER
output_folder = OUTPUT_FOLDER

os.makedirs(output_folder, exist_ok=True)

# Validate the export format once, before any (slow) transcription runs.
if EXPORT_AS not in ("CSV", "TXT"):
    raise ValueError(f"Unsupported export format: {EXPORT_AS}")

# sorted(): process files in a stable, predictable order.
for filename in sorted(os.listdir(audio_folder)):
    if not filename.lower().endswith(('.wav', '.mp3', '.flac')):  # Add or remove audio formats as needed
        continue

    audio_file_path = os.path.join(audio_folder, filename)
    print(f"Processing: {filename}")

    # Process the audio file
    df_transcript = process_audio_file(model, audio_file_path, level=LEVEL, diarization=DIARIZATION)

    # Output files reuse the audio file's name without its extension.
    output_filename = os.path.splitext(filename)[0]

    if EXPORT_AS == "CSV":
        output_path = os.path.join(output_folder, f"{output_filename}.csv")
        df_transcript.to_csv(output_path, index=False)
    else:  # "TXT" (validated above)
        output_path = os.path.join(output_folder, f"{output_filename}.txt")
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) can fail on
        # non-ASCII transcript text.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(df_transcript.to_string(header=False, index=False))

    print(f"Saved transcription to: {os.path.basename(output_path)}")

    # Release memory between files.
    gc.collect()
    torch.cuda.empty_cache()