In [None]:
!pip install opensmile gdown transformers soundfile shap --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/996.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m996.0/996.0 kB[0m [31m60.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/70.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.3/70.3 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/41.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.9/41.9 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import tarfile
from pathlib import Path
import os
import tempfile
import concurrent.futures
from pathlib import Path
import matplotlib.pyplot as plt
import librosa
import librosa.display
from IPython.display import Audio, display
import pandas as pd
import numpy as np
import soundfile as sf
from transformers import pipeline
import torch

from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report

import shap
from sklearn.ensemble import RandomForestClassifier

### Accessing Dataset

In [None]:
# install dataset using gdown
!gdown 1p4ZQOwbHkD2RAvq2K5ekY5gcMi24XLnS

tarball_filename = "ADReSSo21-diagnosis-train.tar"

# extract contents from tarball
with tarfile.open(tarball_filename, 'r:*') as tar:
    tar.extractall(path="./")

# remove tarball after extraction
os.remove(tarball_filename)

Downloading...
From (original): https://drive.google.com/uc?id=1p4ZQOwbHkD2RAvq2K5ekY5gcMi24XLnS
From (redirected): https://drive.google.com/uc?id=1p4ZQOwbHkD2RAvq2K5ekY5gcMi24XLnS&confirm=t&uuid=b3ddf379-a455-4aa3-a747-bf6402e534b9
To: /content/ADReSSo21-diagnosis-train.tar
100% 1.75G/1.75G [00:18<00:00, 92.3MB/s]


In [None]:
def get_dataset_paths(dataset_root: str = "ADReSSo21/diagnosis/train"):
    """
    Build the lookup table of dataset directories under `dataset_root`.

    Returns:
      dict mapping "audio" / "segmentation" to their top-level directories and
      "<kind>_ad" / "<kind>_cn" to the per-group subdirectories (all `Path` objects).
    """
    root = Path(dataset_root)
    layout = {}
    for kind in ("audio", "segmentation"):
        kind_dir = root / kind
        layout[kind] = kind_dir
        for group in ("ad", "cn"):
            layout[f"{kind}_{group}"] = kind_dir / group
    return layout

def load_audio_file(file_path: Path):
    """Read `file_path` with soundfile and hand back the (waveform, sampling rate) pair."""
    waveform, sample_rate = sf.read(file_path)
    return waveform, sample_rate

def load_segmentation(seg_file: Path):
    """Parse the segmentation CSV at `seg_file` and return it as a Pandas DataFrame."""
    segmentation_table = pd.read_csv(seg_file)
    return segmentation_table

def extract_patient_segments(audio: np.ndarray, sr: int, seg_df: pd.DataFrame, speaker: str = "PAR"):
    """
    Extract one speaker's segments from the audio as specified by the segmentation table.

    Args:
      audio: Waveform array; segments are sliced along its first axis.
      sr: Sampling rate used to convert segment times to sample indices.
      seg_df: Segmentation table with "speaker", "begin" and "end" columns. "begin"/"end"
        are treated as milliseconds (they are scaled by sr / 1000).
      speaker: Value of the "speaker" column to keep (defaults to "PAR").

    Returns:
      - patient_mask: Array shaped like `audio` holding the selected samples and NaN
        elsewhere. It keeps `audio`'s dtype when that dtype can represent NaN
        (float/complex) and is float64 otherwise, so integer PCM input still gets a
        real NaN mask instead of garbage values.
      - concatenated: All selected segments concatenated into a single array
        (empty array when no segment matches).
      - segments: List of (begin_sample, end_sample) tuples.
    """
    patient_df = seg_df[seg_df["speaker"] == speaker]

    # NaN is only representable in inexact dtypes; fall back to float64 for integer audio.
    audio_dtype = np.asarray(audio).dtype
    mask_dtype = audio_dtype if np.issubdtype(audio_dtype, np.inexact) else np.float64
    patient_mask = np.full_like(audio, np.nan, dtype=mask_dtype)

    chunks = []
    segments = []
    for row in patient_df.to_dict("records"):
        begin_sample = int(float(row["begin"]) * sr / 1000)
        end_sample = int(float(row["end"]) * sr / 1000)
        chunk = audio[begin_sample:end_sample]
        patient_mask[begin_sample:end_sample] = chunk
        chunks.append(chunk)
        segments.append((begin_sample, end_sample))

    concatenated = np.concatenate(chunks) if chunks else np.array([])
    return patient_mask, concatenated, segments

def process_audio(audio_file: Path, seg_file: Path, plot: bool = False):
    """
    Read one recording plus its segmentation CSV and isolate the patient speech.

    When `plot` is true, the full waveform and the patient-only mask are drawn
    on a shared time axis before returning.

    Returns:
      audio, sr, patient_mask, concatenated (patient-only audio), segments.
    """
    audio, sr = load_audio_file(audio_file)
    patient_mask, concatenated, segments = extract_patient_segments(
        audio, sr, load_segmentation(seg_file)
    )
    outputs = (audio, sr, patient_mask, concatenated, segments)

    # Nothing more to do unless a diagnostic figure was requested.
    if not plot:
        return outputs

    n_samples = len(audio)
    seconds = np.linspace(0, n_samples / sr, num=n_samples)
    plt.figure(figsize=(14, 4))
    for series, legend_label in ((audio, "Original"), (patient_mask, "Patient-Only")):
        plt.plot(seconds, series, label=legend_label)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.title("Patient Speech Isolation")
    plt.legend()
    plt.show()
    return outputs

def init_transcriber(model_name: str = "openai/whisper-large", device: int = 0):
    """
    Build the Hugging Face automatic-speech-recognition pipeline.

    `device` is forwarded to the pipeline unchanged; the default of 0 selects
    the first GPU.
    """
    pipeline_options = {"model": model_name, "device": device}
    return pipeline("automatic-speech-recognition", **pipeline_options)

def transcribe_audio_file(file_path: str, transcriber) -> str:
    """
    Run `transcriber` on the audio file at `file_path` (timestamps enabled).

    Returns the "text" entry of the transcriber's result, or an empty string
    when the result has no such entry.
    """
    asr_output = transcriber(file_path, return_timestamps=True)
    text = asr_output.get("text", "")
    return text

def create_transcription_df(transcription_records: list) -> pd.DataFrame:
    """
    Turn a list of transcription records into a Pandas DataFrame (one row per record).
    """
    transcription_table = pd.DataFrame(data=transcription_records)
    return transcription_table

# -------------------------------------------------------------------------------
# Main Function: Get Transcripts for a Specific Group (Process All Files)
# -------------------------------------------------------------------------------
def get_transcripts_for_group(group: str):
    """
    Transcribe the patient-only speech of every recording in one diagnostic group.

    For each ``*.wav`` file of the group, the patient segments are extracted via
    the matching segmentation CSV (same stem, ``.csv`` suffix), written to a
    temporary WAV file, transcribed with Whisper on GPU device 0, and the word
    count is printed. All transcripts are then saved, sorted by file name, to
    ``sorted_patient_transcriptions_<group>.csv``.

    Args:
      group: "ad" or "cn" (case-insensitive).

    Raises:
      ValueError: if `group` is neither "ad" nor "cn".

    Note:
      Recordings without any patient segment get the literal transcript
      "No patient speech segments found." in the CSV; downstream consumers
      should treat that sentence as a missing value, not as speech.
    """
    paths = get_dataset_paths()
    device = 0  # Use GPU (device index 0)
    model_name = "openai/whisper-large"

    group_key = group.lower()
    if group_key not in ("ad", "cn"):
        raise ValueError("Unknown group. Please specify 'ad' or 'cn'.")

    # Both groups share the same directory/CSV naming scheme, keyed by the group name.
    audio_files = sorted(paths[f"audio_{group_key}"].glob("*.wav"), key=lambda f: f.name)
    seg_dir = paths[f"segmentation_{group_key}"]
    csv_filename = f"sorted_patient_transcriptions_{group_key}.csv"

    print(f"Processing {len(audio_files)} audio files from the '{group.upper()}' group on GPU.")
    transcripts = []
    transcriber = init_transcriber(model_name=model_name, device=device)  # Initialize once for efficiency

    for audio_file in audio_files:
        seg_file = seg_dir / f"{audio_file.stem}.csv"

        # Only the sampling rate and the concatenated patient audio are needed here.
        _, sr, _, concatenated, _ = process_audio(audio_file, seg_file, plot=False)
        if concatenated.size == 0:
            transcript = "No patient speech segments found."
        else:
            # Reserve a temp path, then close the handle so soundfile can write to it.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                temp_filename = tmp_file.name
            try:
                sf.write(temp_filename, concatenated, sr)
                transcript = transcribe_audio_file(temp_filename, transcriber)
            finally:
                # Always clean up, even when writing or transcription fails.
                os.remove(temp_filename)

        transcripts.append({"file_name": audio_file.name, "transcription": transcript})
        word_count = len(transcript.split())
        print(f"File '{audio_file.name}': {word_count} words in transcript.")

    # Sort the transcripts by file name and save to CSV
    transcripts = sorted(transcripts, key=lambda x: x["file_name"])
    df = create_transcription_df(transcripts)
    df.to_csv(csv_filename, index=False)
    print(f"All sorted transcriptions for the '{group.upper()}' group saved to {csv_filename}")


### Transcription


In [None]:
# Transcribe every recording, one diagnostic group at a time (AD first, then CN);
# each call writes that group's transcripts to its own CSV file.
for group_name in ("ad", "cn"):
    get_transcripts_for_group(group_name)

Processing 87 audio files from the 'AD' group on GPU.


Device set to use cuda:0


File 'adrso024.wav': 182 words in transcript.
File 'adrso025.wav': 195 words in transcript.
File 'adrso027.wav': 83 words in transcript.
File 'adrso028.wav': 47 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso031.wav': 90 words in transcript.
File 'adrso032.wav': 184 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso033.wav': 72 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso035.wav': 84 words in transcript.
File 'adrso036.wav': 40 words in transcript.
File 'adrso039.wav': 28 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso043.wav': 250 words in transcript.
File 'adrso045.wav': 291 words in transcript.
File 'adrso046.wav': 108 words in transcript.
File 'adrso047.wav': 94 words in transcript.
File 'adrso049.wav': 151 words in transcript.
File 'adrso053.wav': 201 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso054.wav': 93 words in transcript.
File 'adrso055.wav': 121 words in transcript.
File 'adrso056.wav': 289 words in transcript.
File 'adrso059.wav': 63 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso060.wav': 46 words in transcript.
File 'adrso063.wav': 27 words in transcript.
File 'adrso068.wav': 62 words in transcript.
File 'adrso070.wav': 262 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso071.wav': 87 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso072.wav': 91 words in transcript.
File 'adrso074.wav': 103 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso075.wav': 60 words in transcript.
File 'adrso077.wav': 63 words in transcript.
File 'adrso078.wav': 87 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso089.wav': 111 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso090.wav': 44 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso092.wav': 75 words in transcript.
File 'adrso093.wav': 146 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso098.wav': 111 words in transcript.
File 'adrso106.wav': 121 words in transcript.
File 'adrso109.wav': 106 words in transcript.
File 'adrso110.wav': 101 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso112.wav': 76 words in transcript.
File 'adrso116.wav': 52 words in transcript.
File 'adrso122.wav': 30 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso123.wav': 49 words in transcript.
File 'adrso125.wav': 86 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso126.wav': 64 words in transcript.
File 'adrso128.wav': 25 words in transcript.
File 'adrso130.wav': 33 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso134.wav': 31 words in transcript.
File 'adrso138.wav': 78 words in transcript.
File 'adrso141.wav': 49 words in transcript.
File 'adrso142.wav': 50 words in transcript.
File 'adrso144.wav': 78 words in transcript.
File 'adrso187.wav': 109 words in transcript.
File 'adrso188.wav': 52 words in transcript.
File 'adrso189.wav': 98 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso190.wav': 194 words in transcript.
File 'adrso192.wav': 84 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso197.wav': 91 words in transcript.
File 'adrso198.wav': 41 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso200.wav': 170 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso202.wav': 107 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso205.wav': 105 words in transcript.
File 'adrso206.wav': 95 words in transcript.
File 'adrso209.wav': 71 words in transcript.
File 'adrso211.wav': 147 words in transcript.
File 'adrso212.wav': 100 words in transcript.
File 'adrso215.wav': 53 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso216.wav': 80 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso218.wav': 49 words in transcript.
File 'adrso220.wav': 90 words in transcript.
File 'adrso222.wav': 103 words in transcript.
File 'adrso223.wav': 72 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso224.wav': 132 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso228.wav': 79 words in transcript.
File 'adrso229.wav': 140 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso232.wav': 98 words in transcript.
File 'adrso233.wav': 133 words in transcript.
File 'adrso234.wav': 100 words in transcript.
File 'adrso236.wav': 39 words in transcript.
File 'adrso237.wav': 72 words in transcript.
File 'adrso244.wav': 85 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso245.wav': 56 words in transcript.
File 'adrso246.wav': 67 words in transcript.
File 'adrso247.wav': 94 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso248.wav': 72 words in transcript.
File 'adrso249.wav': 57 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso250.wav': 148 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso253.wav': 99 words in transcript.
All sorted transcriptions for the 'AD' group saved to sorted_patient_transcriptions_ad.csv
Processing 79 audio files from the 'CN' group on GPU.


Device set to use cuda:0


File 'adrso002.wav': 171 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso003.wav': 82 words in transcript.
File 'adrso005.wav': 141 words in transcript.
File 'adrso007.wav': 63 words in transcript.
File 'adrso008.wav': 137 words in transcript.
File 'adrso010.wav': 41 words in transcript.
File 'adrso012.wav': 137 words in transcript.
File 'adrso014.wav': 138 words in transcript.
File 'adrso015.wav': 99 words in transcript.
File 'adrso016.wav': 79 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso017.wav': 117 words in transcript.
File 'adrso018.wav': 5 words in transcript.
File 'adrso019.wav': 5 words in transcript.
File 'adrso021.wav': 5 words in transcript.
File 'adrso022.wav': 5 words in transcript.
File 'adrso023.wav': 5 words in transcript.
File 'adrso148.wav': 240 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso151.wav': 49 words in transcript.
File 'adrso152.wav': 130 words in transcript.
File 'adrso153.wav': 110 words in transcript.
File 'adrso154.wav': 60 words in transcript.
File 'adrso156.wav': 95 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso157.wav': 128 words in transcript.
File 'adrso158.wav': 100 words in transcript.
File 'adrso159.wav': 57 words in transcript.
File 'adrso160.wav': 45 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso161.wav': 74 words in transcript.
File 'adrso162.wav': 135 words in transcript.
File 'adrso164.wav': 173 words in transcript.
File 'adrso165.wav': 82 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso167.wav': 151 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso168.wav': 129 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso169.wav': 63 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso170.wav': 87 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso172.wav': 133 words in transcript.
File 'adrso173.wav': 104 words in transcript.
File 'adrso177.wav': 108 words in transcript.
File 'adrso178.wav': 64 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso180.wav': 102 words in transcript.
File 'adrso182.wav': 90 words in transcript.
File 'adrso183.wav': 192 words in transcript.
File 'adrso186.wav': 79 words in transcript.
File 'adrso257.wav': 92 words in transcript.
File 'adrso259.wav': 183 words in transcript.
File 'adrso260.wav': 84 words in transcript.
File 'adrso261.wav': 114 words in transcript.
File 'adrso262.wav': 67 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso263.wav': 98 words in transcript.
File 'adrso264.wav': 89 words in transcript.
File 'adrso265.wav': 237 words in transcript.
File 'adrso266.wav': 72 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso267.wav': 117 words in transcript.
File 'adrso268.wav': 156 words in transcript.
File 'adrso270.wav': 109 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso273.wav': 101 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso274.wav': 286 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso276.wav': 493 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso277.wav': 140 words in transcript.
File 'adrso278.wav': 44 words in transcript.
File 'adrso280.wav': 229 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso281.wav': 123 words in transcript.
File 'adrso283.wav': 137 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso285.wav': 32 words in transcript.
File 'adrso286.wav': 103 words in transcript.
File 'adrso289.wav': 74 words in transcript.
File 'adrso291.wav': 186 words in transcript.
File 'adrso292.wav': 59 words in transcript.
File 'adrso296.wav': 139 words in transcript.
File 'adrso298.wav': 92 words in transcript.
File 'adrso299.wav': 107 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso300.wav': 54 words in transcript.
File 'adrso302.wav': 236 words in transcript.
File 'adrso307.wav': 170 words in transcript.
File 'adrso308.wav': 152 words in transcript.
File 'adrso309.wav': 104 words in transcript.
File 'adrso310.wav': 81 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso312.wav': 112 words in transcript.
File 'adrso315.wav': 101 words in transcript.


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


File 'adrso316.wav': 88 words in transcript.
All sorted transcriptions for the 'CN' group saved to sorted_patient_transcriptions_cn.csv


### Feature Extraction

In [None]:
import os
# Disable HF-Hub and Transformers progress bars so the cell output stays compact.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from transformers import logging
logging.disable_progress_bar()

import pandas as pd
import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertModel

# ─── Load transcripts & labels ──────────────────────────────────────────────
# These CSVs are the ones written by get_transcripts_for_group("ad") / ("cn").
df_ad = pd.read_csv("sorted_patient_transcriptions_ad.csv")
df_cn = pd.read_csv("sorted_patient_transcriptions_cn.csv")

# assume df_ad/df_cn each have columns: "file_name" (e.g. "adrso024.wav"), "transcription"
# NOTE(review): get_transcripts_for_group stores the literal sentence
# "No patient speech segments found." for recordings without patient segments, and
# nothing here filters those rows out, so they are embedded like real speech.
# Presumably that affects the CN files logged with exactly 5 words — confirm and
# decide whether to drop them.
texts = df_ad["transcription"].tolist() + df_cn["transcription"].tolist()
# Label convention: 1 for every AD transcript, 0 for every CN transcript,
# in the same AD-then-CN order as `texts`.
y     = np.array([1] * len(df_ad) + [0] * len(df_cn))

# ─── Embed with DistilBERT ─────────────────────────────────────────────────
# The model is put in eval mode and moved to the GPU when one is available.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model     = DistilBertModel.from_pretrained("distilbert-base-uncased").eval()
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# One embedding per transcript: each text is encoded on its own (batch of 1) and the
# token-level hidden states are averaged into a single vector.
# NOTE(review): the mean runs over every position the tokenizer emits, which
# presumably includes the special tokens — confirm that is intended.
features = []
with torch.no_grad():
    for text in texts:
        # truncation=True cuts transcripts that exceed the tokenizer's maximum length;
        # with a single text per call, padding=True presumably adds no padding.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(device)
        emb = (
            model(**inputs)
            .last_hidden_state   # (1, seq_len, hidden_dim)
            .squeeze(0)          # (seq_len, hidden_dim)
            .mean(dim=0)         # (hidden_dim,)
        )
        features.append(emb.cpu().numpy())

# Stack the per-transcript vectors into one matrix, one row per transcript.
X = np.vstack(features)

In [None]:
# Wrap the embedding matrix X in a DataFrame with one "f<i>" column per dimension.
feat_cols = [f"f{i}" for i in range(X.shape[1])]
df_feat   = pd.DataFrame(X, columns=feat_cols)

# strip ".wav" so file_id matches audio CSVs
def strip_ext(fn):
    """Return `fn` without its final extension (e.g. "adrso024.wav" -> "adrso024")."""
    return os.path.splitext(fn)[0]

# IDs are listed AD first, then CN — the same order used to build X and y.
file_ids = (
    df_ad["file_name"].map(strip_ext).tolist()
  + df_cn["file_name"].map(strip_ext).tolist()
)

df_feat["file_id"] = file_ids
df_feat["label"]   = y

# reorder columns: file_id, label, then features
df_feat = df_feat[["file_id", "label"] + feat_cols]

# NOTE(review): this writes "text_feat.csv", but the training cell reads
# "text_feats_fixed.csv". Confirm which file name is intended so a fresh
# Restart & Run All trains on the features produced here.
df_feat.to_csv("text_feat.csv", index=False)

print("saved text_feat.csv (clean file_id)")

saved text_feats.csv


### Training

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, classification_report
)
from tabulate import tabulate

def train_rf(df, feat_cols, split_col="split",
             train_value="train", test_value="test",
             n_estimators=100, random_state=42, n_jobs=-1):
    """
    Fit a RandomForestClassifier on the training rows of `df`.

    Args:
      df: DataFrame holding the feature columns, a "label" column and `split_col`.
      feat_cols: Names of the feature columns (cast to float32 before fitting).
      split_col: Column that marks which split each row belongs to.
      train_value: Value of `split_col` that selects the training rows.
      test_value: Accepted for signature parity with the other trainers; unused,
        because evaluation is done by the caller.
      n_estimators, random_state, n_jobs: Forwarded to RandomForestClassifier.

    Returns:
      The fitted classifier.
    """
    # Select the training rows once; the test rows are deliberately not touched here.
    train_rows = df[df[split_col] == train_value]
    X_train = train_rows[feat_cols].astype(np.float32).values
    y_train = train_rows["label"].values

    clf = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=n_jobs
    )
    clf.fit(X_train, y_train)
    return clf

def train_xgb(df, feat_cols, split_col="split",
              train_value="train", test_value="test",
              n_estimators=100, max_depth=3, learning_rate=0.1,
              random_state=42, n_jobs=-1, eval_metric="logloss"):
    """
    Fit an XGBClassifier on the training rows of `df`.

    Args:
      df: DataFrame holding the feature columns, a "label" column and `split_col`.
      feat_cols: Names of the feature columns (cast to float32 before fitting).
      split_col: Column that marks which split each row belongs to.
      train_value: Value of `split_col` that selects the training rows.
      test_value: Accepted for signature parity with the other trainers; unused.
      n_estimators, max_depth, learning_rate, random_state, n_jobs, eval_metric:
        Forwarded to XGBClassifier.

    Returns:
      The fitted classifier.
    """
    train_rows = df[df[split_col] == train_value]
    X_train = train_rows[feat_cols].astype(np.float32).values
    y_train = train_rows["label"].values

    # `use_label_encoder` is intentionally not passed: the installed XGBoost reports
    # it as an unused parameter ('Parameters: { "use_label_encoder" } are not used.').
    clf = XGBClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        learning_rate=learning_rate,
        random_state=random_state,
        n_jobs=n_jobs,
        eval_metric=eval_metric
    )
    clf.fit(X_train, y_train)
    return clf

def train_mlp(df, feat_cols, split_col="split",
              train_value="train", test_value="test",
              hidden_layer_sizes=(100,), activation="relu",
              solver="adam", max_iter=200, random_state=42):
    """
    Fit an MLPClassifier on the rows of `df` whose `split_col` equals `train_value`.

    Features come from `feat_cols` (cast to float32) and targets from the "label"
    column; the remaining keyword arguments configure the classifier.
    `test_value` is accepted but not used. Returns the fitted classifier.
    """
    is_train = df[split_col] == train_value
    train_features = df[is_train][feat_cols].astype(np.float32).values
    train_labels = df[is_train]["label"].values

    mlp_settings = dict(
        hidden_layer_sizes=hidden_layer_sizes,
        activation=activation,
        solver=solver,
        max_iter=max_iter,
        random_state=random_state,
    )
    clf = MLPClassifier(**mlp_settings)
    clf.fit(train_features, train_labels)
    return clf

def format_results(results):
    """
    Print the evaluation results as a "pretty" tabulate table.

    Each entry of `results` must provide "model", "accuracy", "precision",
    "recall" and "f1"; the four metrics are shown with three decimals.
    """
    metric_keys = ("accuracy", "precision", "recall", "f1")
    table_rows = []
    for entry in results:
        formatted_metrics = [f"{entry[key]:.3f}" for key in metric_keys]
        table_rows.append([entry["model"], *formatted_metrics])

    column_titles = ["Model", "Accuracy", "Precision", "Recall", "F1-score"]
    print(tabulate(table_rows, column_titles, tablefmt="pretty"))

def train_fn(df,
             feat_cols,
             split_col="split",
             train_value="train",
             test_value="test",
             models=None,
             rf_params=None,
             xgb_params=None,
             mlp_params=None):
    """
    Train the requested classifiers on the training split and score them on the test split.

    Args:
      df: DataFrame holding the feature columns, a "label" column and `split_col`.
      feat_cols: Names of the feature columns.
      split_col: Column that marks which split each row belongs to.
      train_value / test_value: Values of `split_col` selecting the train / test rows.
      models: Iterable of model keys among "rf", "xgb" and "mlp"
        (defaults to all three). Unknown keys are skipped with a warning.
      rf_params / xgb_params / mlp_params: Extra keyword arguments forwarded to
        train_rf / train_xgb / train_mlp respectively.

    Returns:
      List of dicts, one per trained model, with keys "model", "accuracy",
      "precision", "recall" and "f1" computed on the test rows.
    """
    import warnings

    if models is None:
        models = ["rf", "xgb", "mlp"]

    # model key -> (training function, display name, extra keyword arguments)
    trainers = {
        "rf":  (train_rf,  "RandomForest", rf_params or {}),
        "xgb": (train_xgb, "XGBoost",      xgb_params or {}),
        "mlp": (train_mlp, "MLP",          mlp_params or {}),
    }

    test_df = df[df[split_col] == test_value]
    X_test  = test_df[feat_cols].astype(np.float32).values
    y_test  = test_df["label"].values

    results = []
    for m in models:
        if m not in trainers:
            # Still skipped (as before), but no longer silently: a typo such as
            # "xbg" would otherwise just vanish from the results table.
            warnings.warn(
                f"Skipping unknown model identifier {m!r}; expected one of {sorted(trainers)}.",
                stacklevel=2,
            )
            continue

        trainer, name, params = trainers[m]
        clf = trainer(df, feat_cols,
                      split_col=split_col,
                      train_value=train_value,
                      test_value=test_value,
                      **params)

        y_pred = clf.predict(X_test)
        results.append({
            "model":     name,
            "accuracy":  accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred),
            "recall":    recall_score(y_test, y_pred),
            "f1":        f1_score(y_test, y_pred),
        })

    return results

# example usage
split = pd.read_csv("split_master.csv")
feat  = pd.read_csv("text_feats_fixed.csv")
df    = feat.merge(split, on=["file_id", "label"])
feat_cols = [c for c in df.columns if c.startswith("f") and c[1:].isdigit()]

results = train_fn(
    df,
    feat_cols,
    models=["rf", "xgb", "mlp"],
    rf_params={"n_estimators":100, "random_state":23},
    xgb_params={"n_estimators":100, "max_depth":4, "learning_rate":0.05, "random_state":23},
    mlp_params={"hidden_layer_sizes":(100,), "max_iter":300, "random_state":23}
)

format_results(results)


Parameters: { "use_label_encoder" } are not used.



+--------------+----------+-----------+--------+----------+
|    Model     | Accuracy | Precision | Recall | F1-score |
+--------------+----------+-----------+--------+----------+
| RandomForest |  0.765   |   0.812   | 0.722  |  0.765   |
|   XGBoost    |  0.824   |   0.833   | 0.833  |  0.833   |
|     MLP      |  0.735   |   0.800   | 0.667  |  0.727   |
+--------------+----------+-----------+--------+----------+


