In [1]:
import pandas as pd
import torchaudio
import os
from pathlib import Path
from transformers import Wav2Vec2Processor

# Read csv
df = pd.read_csv("../annotations.csv")

# Add full path to audio files (assuming they're in ../data folder)
data_dir = Path("../data")
df['audio_path'] = df['audio_file'].apply(lambda x: str(data_dir / f"{x}.mp3"))

print("Dataset preview:")
print(df.head())
print(f"Total samples: {len(df)}")

# Utility: Load audio and extract segments per word/event
def extract_event(audio_path, start, end):
    """Load MP3 audio and extract a segment"""
    try:
        waveform, sr = torchaudio.load(audio_path)
        # Resample to 16kHz if needed
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)
            sr = 16000
        
        # Convert start/end times to sample indices
        start_sample = int(start * sr)
        end_sample = int(end * sr)
        segment = waveform[:, start_sample:end_sample]
        return segment.squeeze()
    except Exception as e:
        print(f"Error loading {audio_path}: {e}")
        return None

# Map event type to integer label
label_map = {'fluent': 0, 'repetition': 1, 'prolongation': 2, 'block': 3}
df['label'] = df['stutter_type'].map(label_map)

print("\nLabel mapping:")
print(label_map)

Dataset preview:
  audio_file      word  start_time  end_time  stutter_type  \
0      data1  "really"         0.0       4.0        fluent   
1      data2       "I"         0.0       0.8    repetition   
2      data3  "really"         0.0       1.5         block   
3      data4  "really"         0.5       1.5  prolongation   

                               transcript (optional)         audio_path  
0  "I really love playing the guitar in the eveni...  ..\data\data1.mp3  
1  "I really love playing the guitar in the eveni...  ..\data\data2.mp3  
2  "I really love playing the guitar in the eveni...  ..\data\data3.mp3  
3  "I really love playing the guitar in the eveni...  ..\data\data4.mp3  
Total samples: 4

Label mapping:
{'fluent': 0, 'repetition': 1, 'prolongation': 2, 'block': 3}


In [2]:
import torch
from torch.utils.data import Dataset

class StutterEventDataset(Dataset):
    """Dataset class for stuttering event classification"""
    def __init__(self, df, processor):
        self.df = df
        self.processor = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Extract audio segment
        audio_segment = extract_event(row['audio_path'], row['start_time'], row['end_time'])
        
        if audio_segment is None:
            # Return dummy data if loading fails
            input_values = torch.zeros(16000)
            label = row['label']
        else:
            # Process audio through processor
            input_values = self.processor(
                audio_segment.numpy(), 
                sampling_rate=16000, 
                return_tensors="pt"
            ).input_values.squeeze()
        
        label = torch.tensor(row['label'], dtype=torch.long)
        
        return {
            "input_values": input_values, 
            "labels": label
        }

# Create dataset instance
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
dataset = StutterEventDataset(df, processor)

print(f"Dataset created with {len(dataset)} samples")

Dataset created with 4 samples


In [3]:
from transformers import Wav2Vec2ForSequenceClassification
import torch

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load model for sequence classification (for classifying entire audio segments)
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h",
    num_labels=4,  # 4 stutter event types: fluent, repetition, prolongation, block
)

model = model.to(device)
print("Model loaded and ready for fine-tuning")

Using device: cuda


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded and ready for fine-tuning


In [6]:
from transformers import TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

# Split dataset into train and validation sets
train_indices, val_indices = train_test_split(
    range(len(df)), 
    test_size=0.2,  # 20% for validation
    random_state=42
)

df_train = df.iloc[train_indices].reset_index(drop=True)
df_val = df.iloc[val_indices].reset_index(drop=True)

print(f"Training samples: {len(df_train)}")
print(f"Validation samples: {len(df_val)}")

# Create train and validation datasets
train_dataset = StutterEventDataset(df_train, processor)
eval_dataset = StutterEventDataset(df_val, processor)

# Compute metrics for evaluation (so Trainer produces eval_accuracy)
def compute_metrics(p):
    preds = p.predictions
    # some models return a tuple (logits, hidden_states,...)
    if isinstance(preds, tuple):
        preds = preds[0]
    # handle both regression and classification logits
    if preds is None:
        return {}
    if preds.ndim > 1:
        pred_labels = np.argmax(preds, axis=1)
    else:
        pred_labels = preds
    return {"accuracy": accuracy_score(p.label_ids, pred_labels)}

# Define training arguments
training_args = TrainingArguments(
    output_dir="../results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=4,  # Small batch size for small dataset
    per_device_eval_batch_size=4,
    num_train_epochs=20,  # More epochs for small dataset
    weight_decay=0.01,
    logging_steps=10,
    fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
    metric_for_best_model="accuracy",
    load_best_model_at_end=True,
)

# Create trainer (attach compute_metrics so 'eval_accuracy' will be available)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
)

print("Starting training...")
trainer.train()

Training samples: 3
Validation samples: 1


  trainer = Trainer(


Starting training...


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.390625,0.0
2,No log,1.390625,0.0
3,No log,1.390625,0.0
4,No log,1.390625,0.0
5,No log,1.390625,0.0
6,No log,1.390625,0.0
7,No log,1.390625,0.0
8,No log,1.404297,0.0
9,No log,1.404297,0.0
10,1.389100,1.417969,0.0


TrainOutput(global_step=20, training_loss=1.3756510257720946, metrics={'train_runtime': 43.5856, 'train_samples_per_second': 1.377, 'train_steps_per_second': 0.459, 'total_flos': 2178883676160000.0, 'train_loss': 1.3756510257720946, 'epoch': 20.0})

In [7]:
# Create reverse mapping
reverse_label_map = {v: k for k, v in label_map.items()}

# Evaluate on validation set
print("\nEvaluating on validation set...")
eval_results = trainer.evaluate()

print("\nValidation Results:")
for key, value in eval_results.items():
    print(f"{key}: {value}")

# Make predictions on a sample
print("\n" + "="*50)
print("Sample Predictions")
print("="*50)

sample_idx = 0
sample = eval_dataset[sample_idx]
input_values = sample["input_values"].unsqueeze(0).to(device)

with torch.no_grad():
    outputs = model(input_values)
    logits = outputs.logits
    predicted_label_id = torch.argmax(logits, dim=-1).item()
    predicted_label = reverse_label_map[predicted_label_id]
    true_label = reverse_label_map[sample["labels"].item()]
    
    print(f"True label: {true_label}")
    print(f"Predicted label: {predicted_label}")
    print(f"Confidence scores: {torch.softmax(logits, dim=-1)[0].tolist()}")

# Save the model
model.save_pretrained("../models/wav2vec2_stutter_classifier")
processor.save_pretrained("../models/wav2vec2_stutter_classifier")
print("\nModel saved to ../models/wav2vec2_stutter_classifier")


Evaluating on validation set...



Validation Results:
eval_loss: 1.390625
eval_accuracy: 0.0
eval_runtime: 1.0395
eval_samples_per_second: 0.962
eval_steps_per_second: 0.962
epoch: 20.0

Sample Predictions
True label: repetition
Predicted label: prolongation
Confidence scores: [0.24441483616828918, 0.24895991384983063, 0.27545788884162903, 0.23116734623908997]

Model saved to ../models/wav2vec2_stutter_classifier


## Inference: Detecting Stuttering in Continuous Speech

Now we can use the trained model to detect stuttering in a full MP3 file. This involves:
1. Loading the full audio file
2. Transcribing it with wav2vec2 to get word boundaries (using CTC alignment)
3. Extracting segments for each word
4. Classifying each word segment

In [46]:
# Load the base wav2vec2 model for transcription and CTC alignment
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re

# Load a fresh processor and ASR model for getting word boundaries
asr_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
asr_model.eval()

def get_word_boundaries(audio_path):
    """
    Get word-level time boundaries from an audio file using CTC alignment.
    Returns a list of dictionaries with word and time boundaries.
    """
    # Load audio
    waveform, sr = torchaudio.load(audio_path)
    if len(waveform.shape) > 1:
        waveform = waveform[0:1]
    
    # Resample to 16kHz
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        waveform = resampler(waveform)
        sr = 16000
    
    # Get model inputs
    inputs = asr_processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    input_values = inputs.input_values.to(device)
    
    # Get CTC output
    with torch.no_grad():
        outputs = asr_model(input_values)
    
    # Get logits and decode
    logits = outputs.logits[0].cpu()
    predicted_ids = torch.argmax(logits, dim=-1)
    
    # Get transcription
    transcription = asr_processor.decode(predicted_ids)
    
    # Simple approach: split by spaces to get words
    words = transcription.split()
    
    # Calculate frame-to-time ratio
    num_frames = logits.shape[0]
    total_duration = waveform.shape[1] / sr
    frame_duration = total_duration / num_frames
    
    # Rough estimation of word boundaries based on frame count
    # For better alignment, consider using Montreal Forced Aligner
    word_boundaries = []
    chars_processed = 0
    current_time = 0.0
    
    for word in words:
        # Estimate duration based on word length
        word_duration = len(word) * 0.08  # ~80ms per character (rough estimate)
        word_boundaries.append({
            'word': word,
            'start_time': current_time,
            'end_time': current_time + word_duration,
        })
        current_time += word_duration
    
    return word_boundaries, waveform, sr

print("Functions loaded for word-level stuttering detection")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Functions loaded for word-level stuttering detection


In [47]:
# Load the trained stuttering classifier
trained_model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "../models/wav2vec2_stutter_classifier"
).to(device)
trained_model.eval()

classifier_processor = Wav2Vec2Processor.from_pretrained(
    "../models/wav2vec2_stutter_classifier"
)

print("Trained stuttering classifier loaded")

Trained stuttering classifier loaded


In [48]:
def detect_stuttering_in_speech(audio_path):
    """
    Detect stuttering at word level in a speech audio file.
    
    Returns:
    - results: List of dicts with word, time boundaries, and stuttering classification
    - waveform: Audio waveform
    - sr: Sample rate
    """
    # Get word boundaries
    word_boundaries, waveform, sr = get_word_boundaries(audio_path)
    
    results = []
    
    for word_info in word_boundaries:
        word = word_info['word']
        start_time = word_info['start_time']
        end_time = word_info['end_time']
        
        try:
            # Extract audio segment for this word
            start_sample = int(start_time * sr)
            end_sample = int(end_time * sr)
            
            # Ensure boundaries are valid
            if end_sample > waveform.shape[1]:
                end_sample = waveform.shape[1]
            if start_sample >= end_sample:
                continue
            
            audio_segment = waveform[:, start_sample:end_sample].squeeze()
            
            # Skip very short segments
            if audio_segment.shape[0] < sr * 0.05:  # Less than 50ms
                results.append({
                    'word': word,
                    'start_time': start_time,
                    'end_time': end_time,
                    'stutter_type': 'unknown',
                    'confidence': 0.0,
                    'error': 'segment too short'
                })
                continue
            
            # Process through classifier
            inputs = classifier_processor(
                audio_segment.numpy(),
                sampling_rate=16000,
                return_tensors="pt"
            )
            input_values = inputs.input_values.to(device)
            
            with torch.no_grad():
                outputs = trained_model(input_values)
                logits = outputs.logits[0]
                probabilities = torch.softmax(logits, dim=-1)
                predicted_label_id = torch.argmax(probabilities).item()
                confidence = probabilities[predicted_label_id].item()
            
            stutter_type = reverse_label_map[predicted_label_id]
            
            results.append({
                'word': word,
                'start_time': round(start_time, 2),
                'end_time': round(end_time, 2),
                'stutter_type': stutter_type,
                'confidence': round(confidence, 3),
                'all_confidences': {
                    label: round(prob.item(), 3) 
                    for label, prob in zip(reverse_label_map.values(), probabilities)
                }
            })
        
        except Exception as e:
            print(f"Error processing word '{word}': {e}")
            results.append({
                'word': word,
                'start_time': start_time,
                'end_time': end_time,
                'stutter_type': 'error',
                'confidence': 0.0,
                'error': str(e)
            })
    
    return results, waveform, sr

print("Stuttering detection function ready")

Stuttering detection function ready


In [50]:
# Install nltk for English dictionary
import nltk
nltk.download('words', quiet=True)
from nltk.corpus import words
from textblob import TextBlob

# Load English dictionary
english_words_set = set(words.words())

def is_valid_english_word(word):
    """
    Check if a word is valid English using multiple methods
    """
    word_lower = word.lower()
    
    # Method 1: Check in NLTK dictionary
    if word_lower in english_words_set:
        return True
    
    # Method 2: Check if removing repeated characters gives valid word
    # (handles cases like "rrreally" -> "really")
    cleaned_word = ''.join([word_lower[i] for i in range(len(word_lower)) 
                           if i == 0 or word_lower[i] != word_lower[i-1]])
    if cleaned_word != word_lower and cleaned_word in english_words_set:
        return False  # This is a stuttering variant, mark as invalid
    
    # Method 3: Use TextBlob for spell checking
    try:
        blob = TextBlob(word_lower)
        # If the word is corrected to something different, it's likely misspelled
        corrected = str(blob.correct()).lower()
        if corrected == word_lower:
            return True
        # If it can be corrected, it's likely a stutter
        return False
    except:
        pass
    
    return False

# Example: Run inference on a test audio file
test_audio_path = "../data/data4.mp3"  # Replace with your test MP3 file

# Check if file exists, if not create a demo message
import os
if os.path.exists(test_audio_path):
    print(f"Processing: {test_audio_path}")
    results, waveform, sr = detect_stuttering_in_speech(test_audio_path)
    
    all_words = [r['word'] for r in results]
    
    # SMART DETECTION: Find words that are NOT valid English
    # These are likely stuttering transcriptions (e.g., "rrrrreally", "iiiii")
    stuttering_results = []
    for r in results:
        word_lower = r['word'].lower()
        is_valid = is_valid_english_word(r['word'])
        
        if not is_valid:
            stuttering_results.append(r)
    
    print("\n" + "="*80)
    print("STUTTERING DETECTION RESULTS (Dictionary-based)")
    print("="*80)
    print(f"Full transcript: {' '.join(all_words)}")
    print("="*80)
    
    # DEBUG: Show which words are valid/invalid English
    print("\nDEBUG: Word validation:")
    print("-" * 80)
    for i, r in enumerate(results):
        is_valid = is_valid_english_word(r['word'])
        status = "✓ VALID" if is_valid else "✗ INVALID (likely stutter)"
        print(f"{i}: '{r['word']}' - {status}")
    print("-" * 80)
    
    if len(stuttering_results) == 0:
        print("\n✓ No stuttering detected. All words are valid English.")
        print(f"\nTranscript: {' '.join(all_words)}")
    else:
        print(f"\n⚠ Stuttering detected in {len(stuttering_results)} word(s):\n")
        print(f"{'Word':<20} {'Likely Intended':<25} {'Confidence':<12}")
        print("-"*80)
        
        for result in stuttering_results:
            # Try to find the closest valid English word by removing repeated characters
            word_lower = result['word'].lower()
            # Remove repeated characters to guess the intended word
            cleaned_word = ''.join([word_lower[i] for i in range(len(word_lower)) 
                                  if i == 0 or word_lower[i] != word_lower[i-1]])
            
            print(f"{result['word']:<20} {cleaned_word:<25} {result['confidence']:<12}")
        
        print("-"*80)
        
        # Highlight stuttering in transcript
        transcript_highlighted = []
        for word_result in results:
            is_valid = is_valid_english_word(word_result['word'])
            if not is_valid:
                transcript_highlighted.append(f"[{word_result['word'].upper()}](STUTTER)")
            else:
                transcript_highlighted.append(word_result['word'])
        
        print(f"\nHighlighted transcript:")
        print(" ".join(transcript_highlighted))
        
        print(f"\nSummary:")
        print(f"  Total words: {len(results)}")
        print(f"  Valid English words: {len(results) - len(stuttering_results)}")
        print(f"  Stuttering words detected: {len(stuttering_results)}")
        
        # Show the stuttering words
        print(f"\nStuttering words identified:")
        for result in stuttering_results:
            print(f"  - '{result['word']}'")
    
    # Save full results for reference
    import pandas as pd
    results_df = pd.DataFrame(results)
    results_df.to_csv("../stuttering_detection_results.csv", index=False)
    print(f"\nFull results saved to ../stuttering_detection_results.csv")
    
    # Save only stuttering results
    if len(stuttering_results) > 0:
        stuttering_df = pd.DataFrame(stuttering_results)
        stuttering_df.to_csv("../stuttering_instances.csv", index=False)
        print(f"Stuttering instances saved to ../stuttering_instances.csv")
else:
    print(f"Test audio file not found at {test_audio_path}")
    print("Please place your MP3 file in the data folder and update the path.")

Processing: ../data/data4.mp3

STUTTERING DETECTION RESULTS (Dictionary-based)
Full transcript: I RREALLY LOVE PLAYING THE GUITAR IN THE EVENINGS

DEBUG: Word validation:
--------------------------------------------------------------------------------
0: 'I' - ✓ VALID
1: 'RREALLY' - ✗ INVALID (likely stutter)
2: 'LOVE' - ✓ VALID
3: 'PLAYING' - ✓ VALID
4: 'THE' - ✓ VALID
5: 'GUITAR' - ✓ VALID
6: 'IN' - ✓ VALID
7: 'THE' - ✓ VALID
8: 'EVENINGS' - ✓ VALID
--------------------------------------------------------------------------------

⚠ Stuttering detected in 1 word(s):

Word                 Likely Intended           Confidence  
--------------------------------------------------------------------------------
RREALLY              realy                     0.268       
--------------------------------------------------------------------------------

Highlighted transcript:
I [RREALLY](STUTTER) LOVE PLAYING THE GUITAR IN THE EVENINGS

Summary:
  Total words: 9
  Valid English words: 8
  Stu

In [51]:
from difflib import get_close_matches
from nltk.corpus import words

def get_correct_word(stuttered_word):
    """
    Find the grammatically correct version of a stuttered word.
    Uses multiple strategies to find the best match.
    """
    word_lower = stuttered_word.lower()
    english_words_list = list(english_words_set)
    
    # Strategy 1: Remove repeated characters and check if it's valid
    cleaned_word = ''.join([word_lower[i] for i in range(len(word_lower)) 
                           if i == 0 or word_lower[i] != word_lower[i-1]])
    
    if cleaned_word in english_words_set:
        return cleaned_word, "Repeated character removal"
    
    # Strategy 2: Use difflib to find closest matches (fuzzy matching)
    close_matches = get_close_matches(word_lower, english_words_list, n=1, cutoff=0.6)
    
    if close_matches:
        return close_matches[0], "Fuzzy matching"
    
    # Strategy 3: Try TextBlob spell correction
    try:
        blob = TextBlob(word_lower)
        corrected = str(blob.correct()).lower()
        if corrected != word_lower:
            return corrected, "TextBlob spell check"
    except:
        pass
    
    # Strategy 4: Remove vowels and find similar consonant patterns
    def get_consonants(word):
        """Extract consonants from word"""
        return ''.join([c for c in word if c not in 'aeiou'])
    
    consonants = get_consonants(word_lower)
    if len(consonants) > 2:
        for candidate in english_words_list:
            if get_consonants(candidate) == consonants and len(candidate) < len(word_lower) + 3:
                return candidate, "Consonant pattern matching"
    
    return word_lower, "No match found"

# Test the function with the stuttering results
print("\n" + "="*100)
print("CORRECTED STUTTERED WORDS")
print("="*100)

if len(stuttering_results) > 0:
    print(f"\n{'Stuttered Word':<20} {'Corrected Word':<20} {'Correction Method':<30} {'Confidence':<12}")
    print("-"*100)
    
    for result in stuttering_results:
        stuttered = result['word']
        corrected, method = get_correct_word(stuttered)
        confidence = result['confidence']
        
        print(f"{stuttered:<20} {corrected:<20} {method:<30} {confidence:<12}")
    
    print("-"*100)
    
    # Create corrected transcript
    corrected_transcript = []
    for word_result in results:
        is_valid = is_valid_english_word(word_result['word'])
        if not is_valid:
            corrected_word, _ = get_correct_word(word_result['word'])
            corrected_transcript.append(corrected_word)
        else:
            corrected_transcript.append(word_result['word'])
    
    print(f"\nOriginal transcript: {' '.join(all_words)}")
    print(f"Corrected transcript: {' '.join(corrected_transcript)}")
else:
    print("\nNo stuttering detected. Transcript is already correct.")
    print(f"Transcript: {' '.join(all_words)}")

print("="*100)


CORRECTED STUTTERED WORDS

Stuttered Word       Corrected Word       Correction Method              Confidence  
----------------------------------------------------------------------------------------------------
RREALLY              really               Fuzzy matching                 0.268       
----------------------------------------------------------------------------------------------------
RREALLY              really               Fuzzy matching                 0.268       
----------------------------------------------------------------------------------------------------

Original transcript: I RREALLY LOVE PLAYING THE GUITAR IN THE EVENINGS
Corrected transcript: I really LOVE PLAYING THE GUITAR IN THE EVENINGS

Original transcript: I RREALLY LOVE PLAYING THE GUITAR IN THE EVENINGS
Corrected transcript: I really LOVE PLAYING THE GUITAR IN THE EVENINGS


## Generate Audio Pronunciation of Corrected Words

Use Google Text-to-Speech (gTTS) to generate proper pronunciation of the corrected stuttering words:

In [53]:
from gtts import gTTS
from IPython.display import Audio
import os

def generate_pronunciation(word, output_dir="../audio_pronunciations"):
    """
    Generate audio pronunciation of a word using Google Text-to-Speech
    
    Args:
        word: The word to pronounce
        output_dir: Directory to save audio files
    
    Returns:
        path to the audio file
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Create filename
    filename = f"{output_dir}/{word.lower()}_pronunciation.mp3"
    
    try:
        # Generate speech using gTTS
        tts = gTTS(text=word, lang='en', slow=False)
        tts.save(filename)
        print(f"✓ Generated pronunciation: {word} -> {filename}")
        return filename
    except Exception as e:
        print(f"✗ Error generating pronunciation for '{word}': {e}")
        return None

# Generate pronunciations for all stuttering words
print("\n" + "="*100)
print("GENERATING PRONUNCIATIONS FOR CORRECTED WORDS")
print("="*100)

if len(stuttering_results) > 0:
    print(f"\nGenerating audio pronunciations...\n")
    
    pronunciation_files = {}
    
    for result in stuttering_results:
        stuttered = result['word']
        corrected, method = get_correct_word(stuttered)
        
        # Generate pronunciation for corrected word
        audio_file = generate_pronunciation(corrected)
        
        if audio_file:
            pronunciation_files[corrected] = audio_file
    
    print(f"\n" + "-"*100)
    print("\nPlayable Pronunciations:")
    print("-"*100)
    
    for result in stuttering_results:
        stuttered = result['word']
        corrected, method = get_correct_word(stuttered)
        
        if corrected in pronunciation_files:
            print(f"\nStuttered: {stuttered}")
            print(f"Corrected: {corrected}")
            print(f"Pronunciation file: {pronunciation_files[corrected]}")
            print(f"Correction method: {method}")
            
            # Display audio player
            try:
                display(Audio(pronunciation_files[corrected]))
            except:
                print(f"(Audio file saved but cannot display in this environment)")
    
    print(f"\n" + "-"*100)
    print(f"\nAll pronunciation files saved to: ../audio_pronunciations/")
    print(f"Total pronunciations generated: {len(pronunciation_files)}")
    
else:
    print("\nNo stuttering detected. No pronunciations needed.")

print("="*100)


GENERATING PRONUNCIATIONS FOR CORRECTED WORDS

Generating audio pronunciations...

✓ Generated pronunciation: really -> ../audio_pronunciations/really_pronunciation.mp3

----------------------------------------------------------------------------------------------------

Playable Pronunciations:
----------------------------------------------------------------------------------------------------

Stuttered: RREALLY
Corrected: really
Pronunciation file: ../audio_pronunciations/really_pronunciation.mp3
Correction method: Fuzzy matching



----------------------------------------------------------------------------------------------------

All pronunciation files saved to: ../audio_pronunciations/
Total pronunciations generated: 1
