# Full Audio Analysis Pipeline: Classification, Transcription & Correction

This notebook demonstrates a multi-stage pipeline for analyzing and processing speech audio. It combines three components:

1.  **Stutter Detection (Classification):** A custom-trained `ResNet18` model classifies the speech into one of three fluency levels from a Mel spectrogram image of the audio.
2.  **Speech-to-Text (Transcription):** OpenAI's `Whisper` model (the `tiny` checkpoint in this demo) transcribes the original audio.
3.  **Corrected Audio Generation (Text-to-Speech):** The `pyttsx3` library synthesizes the transcribed text into a new audio file, so the output is free of the disfluencies in the original recording.

The demo runs each sample audio file in the `data/` folder through the whole pipeline, end to end.

In [1]:
# Import Libraries and Load All Models

# --- Core Libraries ---
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import librosa
import numpy as np
import os
import logging
from IPython.display import display, Audio
import whisper
import pyttsx3

# Configure logging for this session
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)-8s - %(message)s', force=True)

# --- 1. Load the Stutter Classification Model (ResNet18) ---

class CustomResNet(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        # No pretrained download is needed here: every weight is overwritten
        # by the trained checkpoint loaded below.
        self.model = models.resnet18(weights=None)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.model(x)

CLASSIFICATION_MODEL_PATH = 'model/resnet18_stutter_detection.pth' 

try:
    # Initialize the model structure
    classification_model = CustomResNet(num_classes=3)
    # Load the trained weights
    classification_model.load_state_dict(torch.load(CLASSIFICATION_MODEL_PATH, map_location=torch.device('cpu')))
    # Set the model to evaluation mode
    classification_model.eval()
    logging.info(f"Classification model loaded successfully from '{CLASSIFICATION_MODEL_PATH}'")
    print(" Stutter Classification model (ResNet18) loaded.")

except Exception as e:
    # Do not leave a randomly initialized model behind if loading fails
    classification_model = None
    logging.error(f"Failed to load classification model: {e}")
    print(f" Error: Could not load the classification model. Make sure '{CLASSIFICATION_MODEL_PATH}' is correct.")

# --- 2. Load the Transcription Model (Whisper) ---
try:
    # Load the 'tiny' version of the Whisper model
    transcription_model = whisper.load_model("tiny")
    logging.info("Whisper transcription model (tiny) loaded successfully.")
    print(" Transcription model (Whisper) loaded.")
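except Exception as e:
    # Keep the name defined so later cells fail with a logged error, not a NameError
    transcription_model = None
    logging.error(f"Failed to load Whisper model: {e}")
    print(" Error: Could not load the Whisper model.")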



2025-08-06 19:02:10,416 - INFO     - Classification model loaded successfully from 'model/resnet18_stutter_detection.pth'


 Stutter Classification model (ResNet18) loaded.


2025-08-06 19:02:11,058 - INFO     - Whisper transcription model (tiny) loaded successfully.


 Transcription model (Whisper) loaded.


### 1.1 Audio Preprocessing Function

This function converts a raw audio file into a spectrogram image tensor that can be fed to the trained ResNet model. The logic is adapted directly from the original project to keep the preprocessing consistent with it.

In [2]:
# Define the audio-to-spectrogram processing function

# --- Audio parameters from the original project ---
SR = 16000
N_MELS = 128
FMAX = 8000

# --- Image Transformations ---
# This transform sequence should be the same as the one used during training.
# We use standard ImageNet normalization as a robust default.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def process_audio_for_classification(audio_path):
    """
    Takes the path to an audio file, converts it to a Mel Spectrogram,
    and then transforms it into a tensor ready for the ResNet model.
    """
    try:
        # 1. Load the audio file with the specified sample rate
        y, _ = librosa.load(audio_path, sr=SR)
        if len(y) == 0:
            logging.warning(f"Audio file is empty: {audio_path}")
            return None

        # 2. Generate Mel Spectrogram
        mel_spec = librosa.feature.melspectrogram(y=y, sr=SR, n_mels=N_MELS, fmax=FMAX)
        mel_db = librosa.power_to_db(mel_spec, ref=np.max)

        # 3. Convert spectrogram to a 3-channel PIL Image
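        # Replace any NaN/inf values using the finite min/max.
        # (np.min/np.max would themselves return NaN/inf if such values were present.)
        finite_vals = mel_db[np.isfinite(mel_db)]
        db_min = finite_vals.min() if finite_vals.size else 0.0
        db_max = finite_vals.max() if finite_vals.size else 0.0
        mel_db_finite = np.nan_to_num(mel_db, nan=db_min, posinf=db_max, neginf=db_min)
        # Normalize to 0-1 range for image conversion
        mel_normalized = (mel_db_finite - db_min) / (db_max - db_min + 1e-6)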
        img_array = (mel_normalized * 255).astype(np.uint8)
        img = Image.fromarray(img_array).convert("RGB")
        
        # 4. Apply the transformations to get the final tensor
        tensor = data_transforms(img).unsqueeze(0) # Add batch dimension
        
        return tensor

    except Exception as e:
        logging.error(f"Failed to process audio file {audio_path}. Error: {e}")
        return None

print(" Audio processing function is ready.")

 Audio processing function is ready.
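
To make the image-conversion step concrete, here is a minimal sketch of the dB-to-`uint8` scaling applied to a tiny hand-written array instead of a real spectrogram. The helper name `db_to_uint8` and the demo values are assumptions for illustration; they are not part of the original project.

```python
import numpy as np


def db_to_uint8(mel_db, eps=1e-6):
    """Min-max scale a dB-valued array to 0-255, ignoring NaN/inf when finding the range."""
    finite = mel_db[np.isfinite(mel_db)]
    lo = finite.min() if finite.size else 0.0
    hi = finite.max() if finite.size else 0.0
    # Non-finite entries are clamped to the finite range before scaling
    cleaned = np.nan_to_num(mel_db, nan=lo, posinf=hi, neginf=lo)
    scaled = (cleaned - lo) / (hi - lo + eps)
    return (scaled * 255).astype(np.uint8)


# Illustrative values only: -80 dB is the quietest bin, 0 dB the loudest,
# and -inf stands in for a bin that produced an invalid value.
demo = np.array([[-80.0, -40.0], [0.0, -np.inf]])
pixels = db_to_uint8(demo)
print(pixels)
```

The quietest bin and the invalid bin both map to 0, while the loudest bin lands just under 255 because of the small `eps` in the denominator. Casting with `astype(np.uint8)` truncates rather than rounds.
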


## 2. End-to-End Analysis Pipeline

Now we will run the complete analysis pipeline on our sample audio files. For each audio clip in the `data/` folder, the system will perform:
1.  **Classification:** Predict the stuttering level using the ResNet model.
2.  **Transcription:** Transcribe the speech to text using the Whisper model.
3.  **Correction:** Generate a new, corrected audio file from the transcribed text using pyttsx3.
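
The three stages above can be sketched as a single function that takes each stage as a callable, which makes the control flow easy to test without loading any model. Everything in this sketch (`AnalysisResult`, `analyze_file`, the stand-in lambdas, and the example paths) is hypothetical and only illustrates the structure; the actual cell below works inline with the loaded models.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class AnalysisResult:
    """Outputs of the three stages for one audio file (hypothetical container)."""
    audio_path: str
    stuttering_level: str
    transcribed_text: Optional[str]
    corrected_audio_path: Optional[str]


def analyze_file(
    audio_path: str,
    classify: Callable[[str], str],
    transcribe: Callable[[str], str],
    synthesize: Callable[[str, str], str],
) -> AnalysisResult:
    """Run classification, transcription, and TTS, isolating failures per stage."""
    try:
        level = classify(audio_path)
    except Exception:
        level = "Error"

    try:
        text = transcribe(audio_path)
    except Exception:
        text = None

    # TTS only runs when a transcript is available
    corrected = None
    if text is not None:
        try:
            corrected = synthesize(text, audio_path)
        except Exception:
            corrected = None

    return AnalysisResult(audio_path, level, text, corrected)


def failing_transcribe(path):
    """Stand-in that simulates an unavailable transcription model."""
    raise RuntimeError("model unavailable")


# Stand-in stage functions so the sketch runs without any models installed.
ok = analyze_file(
    "data/example.wav",
    classify=lambda path: "Normal Fluency",
    transcribe=lambda path: "hello",
    synthesize=lambda text, path: "data/corrected_example.wav",
)
failed = analyze_file(
    "data/example.wav",
    classify=lambda path: "Normal Fluency",
    transcribe=failing_transcribe,
    synthesize=lambda text, path: "data/corrected_example.wav",
)
print(ok)
print(failed)
```

Using `None` for a missing transcript, rather than a sentinel string, means a recording whose transcript happens to read "Error" is not mistaken for a failure.
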

In [None]:
# Run the full pipeline on sample audio files

# --- Configuration ---
DATA_FOLDER = 'data/'
# Define the mapping from the model's output index to a human-readable label
class_mapping = {
    0: "Normal Fluency",
    1: "Mild Stutter Detected",
    2: "Significant Stutter Detected"
}
# Initialize the Text-to-Speech engine
try:
    tts_engine = pyttsx3.init()
except Exception as e:
    tts_engine = None
    logging.warning(f"Could not initialize pyttsx3 engine: {e}. TTS will be disabled.")


# --- Find and process audio files ---
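# Skip files named "corrected_*": those are outputs of the TTS step below, not samples
audio_files = [
    f for f in os.listdir(DATA_FOLDER)
    if f.lower().endswith(('.wav', '.mp3', '.m4a')) and not f.startswith('corrected_')
]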

if not audio_files:
    print(f" No audio files found in the '{DATA_FOLDER}' folder. Please add samples to test.")
else:
    for audio_file in audio_files:
        audio_path = os.path.join(DATA_FOLDER, audio_file)
        
        print("\n" + "="*50)
        print(f"---  Processing : {audio_file} ---")
        print("="*50)

        # --- Display Original Audio ---
        print("\n Original Audio:")
        display(Audio(audio_path))
        
        # --- 1. Classification ---
        print("\n Running Classification...")
        stuttering_level = "Error"
        try:
            tensor = process_audio_for_classification(audio_path)
            if tensor is not None:
                with torch.no_grad():
                    outputs = classification_model(tensor)
                # outputs has shape (1, num_classes); take the argmax over the class dimension
                prediction_index = torch.argmax(outputs, dim=1).item()
                stuttering_level = class_mapping.get(prediction_index, "Unknown")
        except Exception as e:
            logging.error(f"Classification failed for {audio_path}: {e}")
        print(f"   => Predicted Level: {stuttering_level}")

        # --- 2. Transcription (Speech-to-Text) ---
        print("\n Running Transcription...")
        transcribed_text = "Error"
        try:
            result = transcription_model.transcribe(audio_path)
            transcribed_text = result['text']
        except Exception as e:
            logging.error(f"Transcription failed for {audio_path}: {e}")
        print(f"   => Transcribed Text: \"{transcribed_text}\"")

        # --- 3. Corrected Audio Generation (Text-to-Speech) ---
        print("\n Generating Corrected Audio...")
        if tts_engine and transcribed_text != "Error":
            try:
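                # Drop the original extension so the output does not end in e.g. ".mp3.wav"
                base_name = os.path.splitext(audio_file)[0]
                corrected_audio_path = os.path.join(DATA_FOLDER, f"corrected_{base_name}.wav")
                tts_engine.save_to_file(transcribed_text, corrected_audio_path)
                tts_engine.runAndWait()
                print("   => Corrected audio saved. Displaying player:")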
                display(Audio(corrected_audio_path))
            except Exception as e:
                logging.error(f"TTS failed for {audio_path}: {e}")
                print("   => Could not generate corrected audio.")
        else:
            print("   => TTS engine not available or transcription failed.")


---  Processing : corrected_output_.mp3.wav ---

 Original Audio:



 Running Classification...
   => Predicted Level: Significant Stutter Detected

 Running Transcription...




   => Transcribed Text: " Hello, I am Yazen."

---  Processing : output_.mp3 ---

 Original Audio:



 Running Classification...
   => Predicted Level: Significant Stutter Detected

 Running Transcription...
   => Transcribed Text: " Hello, I am Yazan."
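
The classification step reports only the top class. One way to see how strongly the model prefers that class is to convert its raw outputs (logits) to probabilities with a softmax. The sketch below uses plain Python and made-up logits so it runs without the model; the helper names are hypothetical. With the tensors in this notebook, `torch.softmax(outputs, dim=1)` would be the equivalent call.

```python
import math

CLASS_LABELS = [
    "Normal Fluency",
    "Mild Stutter Detected",
    "Significant Stutter Detected",
]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe_prediction(logits, labels=CLASS_LABELS):
    """Return the top label together with its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


# Made-up logits for illustration only; not output from the trained model.
label, confidence = describe_prediction([0.2, 1.1, 2.3])
print(f"{label} ({confidence:.1%})")
```

A softmax probability is not a calibrated measure of certainty, so it is best read as a relative ranking of the classes.
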
