In [14]:
import os
import re
from pathlib import Path
import pandas as pd
from datasets import Dataset, DatasetDict, Audio

def preprocess_transcript(text):
    """
    Replace non-word annotations with <NONWORD> token.
    Example: "[say Ah-P-Eee repeatedly]" -> "<NONWORD>"
    """
    return re.sub(r'\[.*?\]', '<NONWORD>', text)

def collect_data(base_dir):
    """
    Traverse through the dataset directory and collect audio and transcript paths.
    Only process directories that correspond to sessions (e.g., session1, session2_3).
    Handle both 'prompts' and 'promps' transcript directories.
    """
    data = []
    base_path = Path(base_dir)
    
    if not base_path.exists():
        print(f"Base directory {base_dir} does not exist.")
        return data
    
    speakers = [speaker for speaker in base_path.iterdir() if speaker.is_dir()]
    
    for speaker in speakers:
        print(f"Processing speaker: {speaker.name}")
        # Identify session directories (e.g., session1, session2_3)
        sessions = [session for session in speaker.iterdir() if session.is_dir() and session.name.lower().startswith('session')]
        if not sessions:
            print(f"  No session directories found for speaker {speaker.name}. Skipping.")
            continue
        
        for session in sessions:
            print(f"  Processing session: {session.name}")
            wav_arraymic_dir = session / 'wav_arraymic'
            wav_headmic_dir = session / 'wav_headmic'
            prompts_dir = session / 'prompts'
            promps_dir = session / 'promps'

            # Handle both 'prompts' and 'promps' directories
            transcript_dirs = []
            if prompts_dir.exists():
                transcript_dirs.append(prompts_dir)
                print(f"    Found prompts directory: {prompts_dir}")
            if promps_dir.exists():
                transcript_dirs.append(promps_dir)
                print(f"    Found promps directory: {promps_dir}")
            if not transcript_dirs:
                print(f"    No prompts or promps directory found in {session}. Skipping session.")
                continue  # Skip if no transcript directory is found

            # Collect audio files from both mic directories
            for mic_dir in [wav_arraymic_dir, wav_headmic_dir]:
                if mic_dir.exists():
                    print(f"    Processing microphone directory: {mic_dir}")
                    for wav_file in mic_dir.glob('*.wav'):
                        print(f"      Found audio file: {wav_file.name}")
                        transcript_file = None
                        for t_dir in transcript_dirs:
                            potential_transcript = t_dir / (wav_file.stem + '.txt')
                            if potential_transcript.exists():
                                transcript_file = potential_transcript
                                break
                        if transcript_file:
                            print(f"        Found transcript: {transcript_file.name}")
                            try:
                                with open(transcript_file, 'r', encoding='utf-8') as f:
                                    transcript = f.read().strip()
                                preprocessed_transcript = preprocess_transcript(transcript)
                                data.append({
                                    'audio': str(wav_file.resolve()),
                                    'transcript': preprocessed_transcript
                                })
                            except Exception as e:
                                print(f"        Error reading {transcript_file.name}: {e}")
                        else:
                            print(f"        Transcript for {wav_file.name} not found.")
                else:
                    print(f"    Microphone directory not found: {mic_dir}")
    return data

# Example usage
base_directory = r'D:\\datasets\\TORGO male with dysarthria'
data = collect_data(base_directory)
print(f"Total samples collected: {len(data)}")

# Convert data list to pandas DataFrame
if data:
    df = pd.DataFrame(data)
    print(df.head())
    
    # Create a Hugging Face Dataset
    dataset = Dataset.from_pandas(df)
    
    # Split into train and test sets (e.g., 80% train, 20% test)
    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    dataset = DatasetDict({
        'train': split_dataset['train'],
        'test': split_dataset['test']
    })
    
    print(dataset)
    
    # Cast the audio column to ensure correct sampling rate
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
else:
    print("No data collected. Please verify the dataset structure and file naming.")


Processing speaker: M01
  Processing session: Session1
    Found prompts directory: D:\datasets\TORGO male with dysarthria\M01\Session1\prompts
    Processing microphone directory: D:\datasets\TORGO male with dysarthria\M01\Session1\wav_arraymic
      Found audio file: 0001.wav
        Found transcript: 0001.txt
      Found audio file: 0002.wav
        Found transcript: 0002.txt
      Found audio file: 0003.wav
        Found transcript: 0003.txt
      Found audio file: 0004.wav
        Found transcript: 0004.txt
      Found audio file: 0005.wav
        Found transcript: 0005.txt
      Found audio file: 0006.wav
        Found transcript: 0006.txt
      Found audio file: 0007.wav
        Found transcript: 0007.txt
      Found audio file: 0008.wav
        Found transcript: 0008.txt
      Found audio file: 0009.wav
        Found transcript: 0009.txt
      Found audio file: 0010.wav
        Found transcript: 0010.txt
      Found audio file: 0011.wav
        Found transcript: 0011.txt
      

In [33]:
from datasets import Dataset, DatasetDict, Audio

def preprocess_dataset(batch, processor):
    """
    Preprocesses the dataset by extracting input features from audio and tokenizing transcripts.
    """
    # Extract audio features
    batch["input_features"] = processor.feature_extractor(
        batch["audio"]["array"],
        sampling_rate=batch["audio"]["sampling_rate"]
    ).input_features[0]
    
    # Tokenize the transcript
    batch["labels"] = processor.tokenizer(batch["transcript"]).input_ids
    
    return batch

# Apply the preprocessing function to both train and test datasets
processed_dataset = dataset.map(
    preprocess_dataset,
    batched=False,
    remove_columns=["audio", "transcript"],
    fn_kwargs={"processor": processor}
)

# Split into train and test sets if not already done
# (Assuming you've already split your dataset earlier)
dataset = DatasetDict({
    'train': processed_dataset['train'],
    'test': processed_dataset['test']
})

# Verify the dataset structure again
print("Processed Train Dataset Columns:", dataset["train"].column_names)
print("Processed Test Dataset Columns:", dataset["test"].column_names)


Map: 100%|██████████| 3028/3028 [01:00<00:00, 49.82 examples/s] 
Map: 100%|██████████| 757/757 [00:10<00:00, 73.91 examples/s] 

Processed Train Dataset Columns: ['input_features', 'labels']
Processed Test Dataset Columns: ['input_features', 'labels']





In [15]:
import evaluate

# Load the WER metric using the evaluate library
wer_metric = evaluate.load("wer")
wer_metric

EvaluationModule(name: "wer", module_type: "metric", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Compute WER score of transcribed segments against references.

Args:
    references: List of references for each speech input.
    predictions: List of transcriptions to score.
    concatenate_texts (bool, default=False): Whether to concatenate all input texts or compute WER iteratively.

Returns:
    (float): the word error rate

Examples:

    >>> predictions = ["this is the prediction", "there is an other sample"]
    >>> references = ["this is the reference", "there is another one"]
    >>> wer = evaluate.load("wer")
    >>> wer_score = wer.compute(predictions=predictions, references=references)
    >>> print(wer_score)
    0.5
""", stored examples: 0)

In [22]:
from transformers import WhisperProcessor

# Load the Whisper processor
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="english",
    task="transcribe"
)

# Extract the tokenizer from the processor
tokenizer = processor.tokenizer

# Handle special tokens (e.g., <NONWORD>)
if '<NONWORD>' not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({'additional_special_tokens': ['<NONWORD>']})
    processor = WhisperProcessor(
        feature_extractor=processor.feature_extractor,
        tokenizer=tokenizer
    )


In [34]:
def compute_metrics(pred):
    """
    Compute Word Error Rate (WER) between predictions and references.

    Args:
        pred (EvalPrediction): Contains 'predictions' and 'label_ids'.

    Returns:
        dict: A dictionary with the WER metric.
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with the pad token id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode the predictions and references
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Compute WER
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


In [24]:
from transformers import WhisperForConditionalGeneration

# Load the pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Update the model configuration for transcription
model.config.forced_decoder_ids = None  # Disable forced decoding if necessary
model.config.update({"task": "transcribe"})

# Resize token embeddings to accommodate new tokens
model.resize_token_embeddings(len(tokenizer))


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(51866, 768, padding_idx=50257)

In [35]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Tokenize labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore in loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Remove the BOS token if present
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Initialize the data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=processor.tokenizer.bos_token_id,
)


In [41]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-dysarthric",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    warmup_steps=100,
    gradient_checkpointing=False,
    fp16=False,
    evaluation_strategy="steps",
    save_steps=1000,
    eval_steps=1000,
    logging_steps=100,
    predict_with_generate=True,
    generation_max_length=225,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    report_to=["none"],            # Ensure reporting is disabled to prevent wandb interference
    disable_tqdm=False,            # Explicitly enable tqdm
)


In [42]:
from transformers import Seq2SeqTrainer

# Initialize the Trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,  # Use the feature extractor as tokenizer
    compute_metrics=compute_metrics,
)


  trainer = Seq2SeqTrainer(


In [44]:
# Start training
trainer.train()


  0%|          | 1/1134 [1:13:08<1381:03:23, 4388.18s/it]
  0%|          | 0/1134 [1:07:57<?, ?it/s]


KeyboardInterrupt: 

In [32]:
# Inspect the dataset columns
print("Train Dataset Columns:", dataset["train"].column_names)
print("Test Dataset Columns:", dataset["test"].column_names)


Train Dataset Columns: ['audio', 'transcript']
Test Dataset Columns: ['audio', 'transcript']


In [1]:
pip install datasets

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: C:\Users\capta\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [2]:
pip install nltk pandas


Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
   ---------------------------------------- 0.0/1.5 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.5 MB ? eta -:--:--
   - -------------------------------------- 0.1/1.5 MB 1.1 MB/s eta 0:00:02
   -------- ------------------------------- 0.3/1.5 MB 2.7 MB/s eta 0:00:01
   -------- ------------------------------- 0.3/1.5 MB 2.6 MB/s eta 0:00:01
   -------- ------------------------------- 0.3/1.5 MB 2.6 MB/s eta 0:00:01
   -------- ------------------------------- 0.3/1.5 MB 2.6 MB/s eta 0:00:01
   ---------------- ----------------------- 0.6/1.5 MB 2.1 MB/s eta 0:00:01
   ------------------------- -------------- 0.9/1.5 MB 2.8 MB/s eta 0:00:01
   ---------------------------------------- 1.5/1.5 MB 4.0 MB/s eta 0:00:00
Installing collected packages: nltk
Successfully installed nltk-3.9.1
Note: you may need to restart the kernel to use updated pac


[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: C:\Users\capta\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [4]:
import os
import re
from pathlib import Path
import pandas as pd
from datasets import Dataset, DatasetDict, Audio
import nltk
from nltk.corpus import cmudict

# Step 1: Download NLTK data
nltk.download('cmudict')

# Step 2: Load the CMU Pronouncing Dictionary
cmu_dict = cmudict.dict()

# Step 3: Define the preprocessing function
def preprocess_transcript(text):
    """
    Replace non-word annotations with <NONWORD> token.
    Example: "[say Ah-P-Eee repeatedly]" -> "<NONWORD>"
    """
    preprocessed_text = re.sub(r'\[.*?\]', '<NONWORD>', text)
    # Optionally remove other unwanted characters
    preprocessed_text = re.sub(r'[^\w\s<NONWORD>]', '', preprocessed_text)
    return preprocessed_text

# Step 4: Define the text-to-phoneme conversion function
def text_to_phonemes(text):
    """
    Converts a preprocessed sentence into a list of phonemes.
    Words not found in the CMU dictionary are replaced with <UNK>.
    
    Args:
        text (str): The preprocessed transcript text.
        
    Returns:
        list: A list of phonemes representing the sentence.
    """
    phoneme_sequence = []
    words = text.split()  # Split by whitespace
    
    for word in words:
        if word == '<NONWORD>':
            phoneme_sequence.append('<NONWORD>')
            continue
        
        # Remove any trailing punctuation from the word
        word_clean = re.sub(r'[^\w\s]', '', word).lower()
        
        if word_clean in cmu_dict:
            # Take the first pronunciation variant
            phonemes = cmu_dict[word_clean][0]
            # Remove stress markers from phonemes (e.g., AH0 -> AH)
            phonemes = [re.sub(r'\d', '', p) for p in phonemes]
            phoneme_sequence.extend(phonemes)
        else:
            # Handle unknown words by adding a placeholder
            phoneme_sequence.append('<UNK>')
    
    return phoneme_sequence

# Step 5: Define the data collection function (as provided by the user)
def collect_data(base_dir):
    """
    Traverse through the dataset directory and collect audio and transcript paths.
    Only process directories that correspond to sessions (e.g., session1, session2_3).
    Handle both 'prompts' and 'promps' transcript directories.
    """
    data = []
    base_path = Path(base_dir)

    if not base_path.exists():
        print(f"Base directory {base_dir} does not exist.")
        return data

    speakers = [speaker for speaker in base_path.iterdir() if speaker.is_dir()]

    for speaker in speakers:
        print(f"Processing speaker: {speaker.name}")
        # Identify session directories (e.g., session1, session2_3)
        sessions = [session for session in speaker.iterdir() if session.is_dir() and session.name.lower().startswith('session')]
        if not sessions:
            print(f"  No session directories found for speaker {speaker.name}. Skipping.")
            continue

        for session in sessions:
            print(f"  Processing session: {session.name}")
            wav_arraymic_dir = session / 'wav_arraymic'
            wav_headmic_dir = session / 'wav_headmic'
            prompts_dir = session / 'prompts'
            promps_dir = session / 'promps'

            # Handle both 'prompts' and 'promps' directories
            transcript_dirs = []
            if prompts_dir.exists():
                transcript_dirs.append(prompts_dir)
                print(f"    Found prompts directory: {prompts_dir}")
            if promps_dir.exists():
                transcript_dirs.append(promps_dir)
                print(f"    Found promps directory: {promps_dir}")
            if not transcript_dirs:
                print(f"    No prompts or promps directory found in {session}. Skipping session.")
                continue  # Skip if no transcript directory is found

            # Collect audio files from both mic directories
            for mic_dir in [wav_arraymic_dir, wav_headmic_dir]:
                if mic_dir.exists():
                    print(f"    Processing microphone directory: {mic_dir}")
                    for wav_file in mic_dir.glob('*.wav'):
                        print(f"      Found audio file: {wav_file.name}")
                        transcript_file = None
                        for t_dir in transcript_dirs:
                            potential_transcript = t_dir / (wav_file.stem + '.txt')
                            if potential_transcript.exists():
                                transcript_file = potential_transcript
                                break
                        if transcript_file:
                            print(f"        Found transcript: {transcript_file.name}")
                            try:
                                with open(transcript_file, 'r', encoding='utf-8') as f:
                                    transcript = f.read().strip()
                                preprocessed_transcript = preprocess_transcript(transcript)
                                phoneme_seq = text_to_phonemes(preprocessed_transcript)
                                phoneme_str = ' '.join(phoneme_seq)
                                data.append({
                                    'audio': str(wav_file.resolve()),
                                    'transcript': preprocessed_transcript,
                                    'phoneme_sequence': phoneme_str
                                })
                            except Exception as e:
                                print(f"        Error reading {transcript_file.name}: {e}")
                        else:
                            print(f"        Transcript for {wav_file.name} not found.")
                else:
                    print(f"    Microphone directory not found: {mic_dir}")
    return data

# Step 6: Collect data using the provided directory path
base_directory = r'D:\\datasets\\TORGO male with dysarthria'  # Replace with your actual path
data = collect_data(base_directory)
print(f"Total samples collected: {len(data)}")

# Step 7: Convert data list to pandas DataFrame
if data:
    df = pd.DataFrame(data)
    print(df.head())

    # Create a Hugging Face Dataset
    dataset = Dataset.from_pandas(df)

    # Split into train and test sets (e.g., 80% train, 20% test)
    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    dataset = DatasetDict({
        'train': split_dataset['train'],
        'test': split_dataset['test']
    })

    print(dataset)

    # Cast the audio column to ensure correct sampling rate
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
else:
    print("No data collected. Please verify the dataset structure and file naming.")


[nltk_data] Downloading package cmudict to
[nltk_data]     C:\Users\capta\AppData\Roaming\nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


Processing speaker: .tmp.drivedownload
  No session directories found for speaker .tmp.drivedownload. Skipping.
Processing speaker: .tmp.driveupload
  No session directories found for speaker .tmp.driveupload. Skipping.
Processing speaker: M01
  Processing session: Session1
    Found prompts directory: D:\datasets\TORGO male with dysarthria\M01\Session1\prompts
    Processing microphone directory: D:\datasets\TORGO male with dysarthria\M01\Session1\wav_arraymic
      Found audio file: 0001.wav
        Found transcript: 0001.txt
      Found audio file: 0002.wav
        Found transcript: 0002.txt
      Found audio file: 0003.wav
        Found transcript: 0003.txt
      Found audio file: 0004.wav
        Found transcript: 0004.txt
      Found audio file: 0005.wav
        Found transcript: 0005.txt
      Found audio file: 0006.wav
        Found transcript: 0006.txt
      Found audio file: 0007.wav
        Found transcript: 0007.txt
      Found audio file: 0008.wav
        Found transcript:

In [5]:
# Split the phoneme_sequence string into a list of phonemes
def split_phonemes(example):
    example["phonemes"] = example["phoneme_sequence"].split()
    return example

# Apply the function to both train and test sets
dataset = dataset.map(split_phonemes)


Map: 100%|██████████| 3028/3028 [00:00<00:00, 8869.82 examples/s]
Map: 100%|██████████| 757/757 [00:00<00:00, 7700.54 examples/s]


In [6]:
# Extract all unique phonemes from the training set
phoneme_list = []
for phoneme_seq in dataset["train"]["phonemes"]:
    phoneme_list.extend(phoneme_seq)

unique_phonemes = sorted(list(set(phoneme_list)))

# Create phoneme to ID and ID to phoneme mappings
phoneme_to_id = {phoneme: idx for idx, phoneme in enumerate(unique_phonemes)}
phoneme_to_id["<blank>"] = len(phoneme_to_id)  # Add a blank token for CTC

id_to_phoneme = {idx: phoneme for phoneme, idx in phoneme_to_id.items()}


In [7]:
def tokenize_phonemes(example):
    """
    Convert phoneme list to a list of IDs.
    """
    example["labels"] = [phoneme_to_id.get(p, phoneme_to_id["<blank>"]) for p in example["phonemes"]]
    return example

# Apply the tokenization to both train and test sets
dataset = dataset.map(tokenize_phonemes)


Map: 100%|██████████| 3028/3028 [00:00<00:00, 13692.98 examples/s]
Map: 100%|██████████| 757/757 [00:00<00:00, 12712.61 examples/s]


In [8]:
# Select only the necessary columns
columns = ["audio", "labels"]

# Update the dataset to include only these columns
dataset = dataset.remove_columns([col for col in dataset["train"].column_names if col not in columns])


In [10]:
!pip install --upgrade transformers

Collecting transformers
  Downloading transformers-4.48.0-py3-none-any.whl.metadata (44 kB)
     ---------------------------------------- 0.0/44.4 kB ? eta -:--:--
     ----------------- -------------------- 20.5/44.4 kB 330.3 kB/s eta 0:00:01
     ----------------------------------- -- 41.0/44.4 kB 495.5 kB/s eta 0:00:01
     -------------------------------------- 44.4/44.4 kB 437.8 kB/s eta 0:00:00
Downloading transformers-4.48.0-py3-none-any.whl (9.7 MB)
   ---------------------------------------- 0.0/9.7 MB ? eta -:--:--
   ---------------------------------------- 0.1/9.7 MB 1.1 MB/s eta 0:00:09
    --------------------------------------- 0.2/9.7 MB 2.4 MB/s eta 0:00:04
   - -------------------------------------- 0.3/9.7 MB 1.8 MB/s eta 0:00:06
   - -------------------------------------- 0.4/9.7 MB 2.5 MB/s eta 0:00:04
   -- ------------------------------------- 0.7/9.7 MB 2.9 MB/s eta 0:00:04
   ----- ---------------------------------- 1.2/9.7 MB 4.5 MB/s eta 0:00:02
   ------- --


[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: C:\Users\capta\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [13]:
pip install --upgrade torch


Collecting torch
  Downloading torch-2.5.1-cp311-cp311-win_amd64.whl.metadata (28 kB)
Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Downloading torch-2.5.1-cp311-cp311-win_amd64.whl (203.1 MB)
   ---------------------------------------- 0.0/203.1 MB ? eta -:--:--
   ---------------------------------------- 0.1/203.1 MB 2.6 MB/s eta 0:01:17
   ---------------------------------------- 0.5/203.1 MB 7.0 MB/s eta 0:00:29
   ---------------------------------------- 1.2/203.1 MB 9.4 MB/s eta 0:00:22
   ---------------------------------------- 1.9/203.1 MB 11.3 MB/s eta 0:00:18
   ---------------------------------------- 2.4/203.1 MB 11.1 MB/s eta 0:00:19
    --------------------------------------- 3.0/203.1 MB 11.2 MB/s eta 0:00:18
    --------------------------------------- 3.7/203.1 MB 11.9 MB/s eta 0:00:17
    --------------------------------------- 4.4/203.1 MB 12.3 MB/s eta 0:00:17
    --------------------------------------- 5.0/203.1 

  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.1.0+cu121 requires torch==2.1.0+cu121, but you have torch 2.5.1 which is incompatible.
torchvision 0.16.0 requires torch==2.1.0, but you have torch 2.5.1 which is incompatible.

[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: C:\Users\capta\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [20]:
from transformers import Wav2Vec2Processor, DataCollatorWithPadding

# Initialize the processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# Create the data collator
data_collator = DataCollatorWithPadding(processor.tokenizer)


  if kwargs.get("gradient_checkpointing", False):


In [21]:
import os
import json
import torch
import librosa
import pandas as pd
from datasets import Dataset, DatasetDict, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Trainer, TrainingArguments, DataCollatorWithPadding
import nltk
from nltk.corpus import cmudict
import re
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# Step 1: Download NLTK data
nltk.download('cmudict')

# Step 2: Define paths
base_directory = r'D:\\datasets\\TORGO male with dysarthria'  # Replace with your actual path
output_dir = "./wav2vec2-phoneme"

# Step 3: Load CMU Pronouncing Dictionary
cmu_dict = cmudict.dict()

def preprocess_transcript(text):
    """
    Replace non-word annotations with <NONWORD> token.
    Example: "[say Ah-P-Eee repeatedly]" -> "<NONWORD>"
    """
    preprocessed_text = re.sub(r'\[.*?\]', '<NONWORD>', text)
    # Optionally remove other unwanted characters
    preprocessed_text = re.sub(r'[^\w\s<NONWORD>]', '', preprocessed_text)
    return preprocessed_text

def text_to_phonemes(text):
    """
    Converts a preprocessed sentence into a list of phonemes.
    Words not found in the CMU dictionary are replaced with <UNK>.
    
    Args:
        text (str): The preprocessed transcript text.
        
    Returns:
        list: A list of phonemes representing the sentence.
    """
    phoneme_sequence = []
    words = text.split()  # Split by whitespace
    
    for word in words:
        if word == '<NONWORD>':
            phoneme_sequence.append('<NONWORD>')
            continue
        
        # Remove any trailing punctuation from the word
        word_clean = re.sub(r'[^\w\s]', '', word).lower()
        
        if word_clean in cmu_dict:
            # Take the first pronunciation variant
            phonemes = cmu_dict[word_clean][0]
            # Remove stress markers from phonemes (e.g., AH0 -> AH)
            phonemes = [re.sub(r'\d', '', p) for p in phonemes]
            phoneme_sequence.extend(phonemes)
        else:
            # Handle unknown words by adding a placeholder
            phoneme_sequence.append('<UNK>')
    
    return phoneme_sequence

# Step 4: Define the data collection function (as provided by the user)
def collect_data(base_dir):
    """
    Traverse through the dataset directory and collect audio and transcript paths.
    Only process directories that correspond to sessions (e.g., session1, session2_3).
    Handle both 'prompts' and 'promps' transcript directories.
    """
    data = []
    base_path = Path(base_dir)

    if not base_path.exists():
        print(f"Base directory {base_dir} does not exist.")
        return data

    speakers = [speaker for speaker in base_path.iterdir() if speaker.is_dir()]

    for speaker in speakers:
        print(f"Processing speaker: {speaker.name}")
        # Identify session directories (e.g., session1, session2_3)
        sessions = [session for session in speaker.iterdir() if session.is_dir() and session.name.lower().startswith('session')]
        if not sessions:
            print(f"  No session directories found for speaker {speaker.name}. Skipping.")
            continue

        for session in sessions:
            print(f"  Processing session: {session.name}")
            wav_arraymic_dir = session / 'wav_arraymic'
            wav_headmic_dir = session / 'wav_headmic'
            prompts_dir = session / 'prompts'
            promps_dir = session / 'promps'

            # Handle both 'prompts' and 'promps' directories
            transcript_dirs = []
            if prompts_dir.exists():
                transcript_dirs.append(prompts_dir)
                print(f"    Found prompts directory: {prompts_dir}")
            if promps_dir.exists():
                transcript_dirs.append(promps_dir)
                print(f"    Found promps directory: {promps_dir}")
            if not transcript_dirs:
                print(f"    No prompts or promps directory found in {session}. Skipping session.")
                continue  # Skip if no transcript directory is found

            # Collect audio files from both mic directories
            for mic_dir in [wav_arraymic_dir, wav_headmic_dir]:
                if mic_dir.exists():
                    print(f"    Processing microphone directory: {mic_dir}")
                    for wav_file in mic_dir.glob('*.wav'):
                        print(f"      Found audio file: {wav_file.name}")
                        transcript_file = None
                        for t_dir in transcript_dirs:
                            potential_transcript = t_dir / (wav_file.stem + '.txt')
                            if potential_transcript.exists():
                                transcript_file = potential_transcript
                                break
                        if transcript_file:
                            print(f"        Found transcript: {transcript_file.name}")
                            try:
                                with open(transcript_file, 'r', encoding='utf-8') as f:
                                    transcript = f.read().strip()
                                preprocessed_transcript = preprocess_transcript(transcript)
                                phoneme_seq = text_to_phonemes(preprocessed_transcript)
                                phoneme_str = ' '.join(phoneme_seq)
                                data.append({
                                    'audio': str(wav_file.resolve()),
                                    'transcript': preprocessed_transcript,
                                    'phoneme_sequence': phoneme_str
                                })
                            except Exception as e:
                                print(f"        Error reading {transcript_file.name}: {e}")
                        else:
                            print(f"        Transcript for {wav_file.name} not found.")
                else:
                    print(f"    Microphone directory not found: {mic_dir}")
    return data

from pathlib import Path

# Step 5: Collect data
data = collect_data(base_directory)
print(f"Total samples collected: {len(data)}")

# Step 6: Convert data list to pandas DataFrame
if data:
    df = pd.DataFrame(data)
    print(df.head())

    # Step 7: Create a Hugging Face Dataset
    dataset = Dataset.from_pandas(df)

    # Step 8: Split into train and test sets (e.g., 80% train, 20% test)
    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    dataset = DatasetDict({
        'train': split_dataset['train'],
        'test': split_dataset['test']
    })

    print(dataset)

    # Step 9: Cast the audio column to ensure correct sampling rate
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
else:
    print("No data collected. Please verify the dataset structure and file naming.")

# Proceed only if data is collected
if len(data) > 0:
    # Step 10: Apply phoneme splitting
    def split_phonemes(example):
        example["phonemes"] = example["phoneme_sequence"].split()
        return example

    # Apply phoneme splitting to both train and test
    dataset = dataset.map(split_phonemes)

    # Step 11: Create phoneme vocabulary
    phoneme_list = []
    for phoneme_seq in dataset["train"]["phonemes"]:
        phoneme_list.extend(phoneme_seq)
    unique_phonemes = sorted(list(set(phoneme_list)))

    # Create phoneme to ID and ID to phoneme mappings
    phoneme_to_id = {phoneme: idx for idx, phoneme in enumerate(unique_phonemes)}
    phoneme_to_id["<blank>"] = len(phoneme_to_id)  # Add a blank token for CTC
    id_to_phoneme = {idx: phoneme for phoneme, idx in phoneme_to_id.items()}

    # Step 12: Tokenize phoneme sequences
    def tokenize_phonemes(example):
        phoneme_tokens = example["phonemes"]
        # Map phonemes to IDs, use <blank> ID for unknown phonemes
        example["labels"] = [phoneme_to_id.get(p, phoneme_to_id["<blank>"]) for p in phoneme_tokens]
        return example

    # Apply tokenization to both train and test
    dataset = dataset.map(tokenize_phonemes)

    # Step 13: Select only the necessary columns
    columns = ["audio", "labels"]
    dataset = dataset.remove_columns([col for col in dataset["train"].column_names if col not in columns])

    # Step 14: Set dataset format for PyTorch
    dataset.set_format(type="torch", columns=columns)

    # Step 15: Initialize processor
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

    # Step 16: Define the custom data collator
    @dataclass
    class CustomDataCollatorCTCWithPadding:
        """
        Custom data collator that dynamically pads the inputs and labels for CTC.
        """
        processor: Wav2Vec2Processor
        padding: Union[bool, str] = True
        max_length: Union[int, None] = None
        pad_to_multiple_of: Union[int, None] = None

        def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
            # Separate inputs and labels
            input_features = [{"input_values": feature["audio"]["array"]} for feature in features]
            labels = [feature["labels"] for feature in features]

            # Use DataCollatorWithPadding to pad input_values
            data_collator = DataCollatorWithPadding(
                processor=self.processor,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
            batch = data_collator(input_features)

            # Pad labels manually
            label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
            max_label_length = label_lengths.max()

            # Initialize padded_labels with -100
            padded_labels = torch.full(
                (len(labels), max_label_length), fill_value=-100, dtype=torch.long
            )

            for i, label in enumerate(labels):
                padded_labels[i, :label_lengths[i]] = torch.tensor(label, dtype=torch.long)

            # Add labels and label_lengths to the batch
            batch["labels"] = padded_labels
            batch["labels_length"] = label_lengths

            return batch

    # Step 17: Initialize the custom data collator
    data_collator = CustomDataCollatorCTCWithPadding(
        processor=processor,
        padding=True
    )

    # Step 18: Load the pre-trained Wav2Vec2 model
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")

    # Step 19: Resize token embeddings to match the number of phonemes
    model.resize_token_embeddings(len(phoneme_to_id))

    # Step 20: Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        group_by_length=True,
        per_device_train_batch_size=8,  # Adjust based on your GPU
        per_device_eval_batch_size=8,   # Adjust based on your GPU
        evaluation_strategy="steps",
        num_train_epochs=10,
        fp16=True,                      # Set to False if not using GPU
        save_steps=400,
        eval_steps=400,
        logging_steps=100,
        learning_rate=1e-4,
        warmup_steps=500,
        save_total_limit=2,
    )

    # Step 21: Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
        tokenizer=processor.feature_extractor,  # Use feature extractor as tokenizer
    )

    # Step 22: Start training
    trainer.train()

    # Step 23: Save the fine-tuned model and processor
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Save phoneme mappings
    with open(os.path.join(output_dir, "phoneme_to_id.json"), "w") as f:
        json.dump(phoneme_to_id, f)

    with open(os.path.join(output_dir, "id_to_phoneme.json"), "w") as f:
        json.dump(id_to_phoneme, f)

    print("Model, processor, and phoneme mappings saved successfully!")

    # Step 24: Define the transcription function
    def transcribe_audio_to_phonemes(audio_path):
        """
        Transcribes an audio file into phonemes.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            list: List of phonemes.
        """
        # Load audio
        speech, rate = librosa.load(audio_path, sr=16000)

        # Process audio
        inputs = processor(speech, sampling_rate=16000, return_tensors="pt")

        # Get logits from the model
        with torch.no_grad():
            logits = model(inputs.input_values).logits

        # Take argmax and convert to phonemes
        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_phonemes = [id_to_phoneme.get(str(id.item()), "<UNK>") for id in predicted_ids[0]]

        # Remove consecutive duplicates and blanks
        final_phonemes = []
        previous = None
        for phoneme in predicted_phonemes:
            if phoneme != previous and phoneme != "<blank>":
                final_phonemes.append(phoneme)
            previous = phoneme

        return final_phonemes

    # Step 25: Test the transcription function
    test_audio = "D:/datasets/TORGO male with dysarthria/M01/Session1/wav_arrayMic/0019.wav"  # Replace with your test audio file path
    phoneme_output = transcribe_audio_to_phonemes(test_audio)
    print("Phoneme Transcription:")
    print(phoneme_output)
else:
    print("No data available to proceed with training.")


[nltk_data] Downloading package cmudict to
[nltk_data]     C:\Users\capta\AppData\Roaming\nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


Processing speaker: .tmp.drivedownload
  No session directories found for speaker .tmp.drivedownload. Skipping.
Processing speaker: .tmp.driveupload
  No session directories found for speaker .tmp.driveupload. Skipping.
Processing speaker: M01
  Processing session: Session1
    Found prompts directory: D:\datasets\TORGO male with dysarthria\M01\Session1\prompts
    Processing microphone directory: D:\datasets\TORGO male with dysarthria\M01\Session1\wav_arraymic
      Found audio file: 0001.wav
        Found transcript: 0001.txt
      Found audio file: 0002.wav
        Found transcript: 0002.txt
      Found audio file: 0003.wav
        Found transcript: 0003.txt
      Found audio file: 0004.wav
        Found transcript: 0004.txt
      Found audio file: 0005.wav
        Found transcript: 0005.txt
      Found audio file: 0006.wav
        Found transcript: 0006.txt
      Found audio file: 0007.wav
        Found transcript: 0007.txt
      Found audio file: 0008.wav
        Found transcript:

Map: 100%|██████████| 3028/3028 [00:00<00:00, 10853.91 examples/s]
Map: 100%|██████████| 757/757 [00:00<00:00, 8837.66 examples/s]
Map: 100%|██████████| 3028/3028 [00:00<00:00, 12719.83 examples/s]
Map: 100%|██████████| 757/757 [00:00<00:00, 12168.68 examples/s]
  if kwargs.get("gradient_checkpointing", False):
  
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


NotImplementedError: 

In [27]:
import os
import re
import json
from pathlib import Path
import pandas as pd
from datasets import Dataset, DatasetDict, Audio
import nltk
from nltk.corpus import cmudict
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Trainer, TrainingArguments, DataCollatorWithPadding
import torch
import librosa

# ----------------------------------------
# Step 1: Download and Load CMU Pronouncing Dictionary
# ----------------------------------------
nltk.download('cmudict')
cmu_dict = cmudict.dict()

def text_to_phonemes(text):
    """
    Converts a preprocessed sentence into a list of phonemes.
    Words not found in the CMU dictionary are replaced with <UNK>.
    
    Args:
        text (str): The preprocessed transcript text.
        
    Returns:
        list: A list of phonemes representing the sentence.
    """
    phoneme_sequence = []
    words = text.split()  # Split by whitespace
    
    for word in words:
        if word == '<NONWORD>':
            phoneme_sequence.append('<NONWORD>')
            continue
        
        # Remove any trailing punctuation from the word
        word_clean = re.sub(r'[^\w\s]', '', word).lower()
        
        if word_clean in cmu_dict:
            # Take the first pronunciation variant
            phonemes = cmu_dict[word_clean][0]
            # Remove stress markers from phonemes (e.g., AH0 -> AH)
            phonemes = [re.sub(r'\d', '', p) for p in phonemes]
            phoneme_sequence.extend(phonemes)
        else:
            # Handle unknown words by adding a placeholder
            phoneme_sequence.append('<UNK>')
    
    return phoneme_sequence

# ----------------------------------------
# Step 2: Collect and Prepare Data
# ----------------------------------------
def preprocess_transcript(text):
    """
    Replace non-word annotations with <NONWORD> token.
    Example: "[say Ah-P-Eee repeatedly]" -> "<NONWORD>"
    """
    preprocessed_text = re.sub(r'\[.*?\]', '<NONWORD>', text)
    # Optionally remove other unwanted characters
    preprocessed_text = re.sub(r'[^\w\s<NONWORD>]', '', preprocessed_text)
    return preprocessed_text

def collect_data(base_dir):
    """
    Traverse through the dataset directory and collect audio and transcript paths.
    Only process directories that correspond to sessions (e.g., session1, session2_3).
    Handle both 'prompts' and 'promps' transcript directories.
    """
    data = []
    base_path = Path(base_dir)

    if not base_path.exists():
        print(f"Base directory {base_dir} does not exist.")
        return data

    speakers = [speaker for speaker in base_path.iterdir() if speaker.is_dir()]

    for speaker in speakers:
        print(f"Processing speaker: {speaker.name}")
        # Identify session directories (e.g., session1, session2_3)
        sessions = [session for session in speaker.iterdir() if session.is_dir() and session.name.lower().startswith('session')]
        if not sessions:
            print(f"  No session directories found for speaker {speaker.name}. Skipping.")
            continue

        for session in sessions:
            print(f"  Processing session: {session.name}")
            wav_arraymic_dir = session / 'wav_arraymic'
            wav_headmic_dir = session / 'wav_headmic'
            prompts_dir = session / 'prompts'
            promps_dir = session / 'promps'

            # Handle both 'prompts' and 'promps' directories
            transcript_dirs = []
            if prompts_dir.exists():
                transcript_dirs.append(prompts_dir)
                print(f"    Found prompts directory: {prompts_dir}")
            if promps_dir.exists():
                transcript_dirs.append(promps_dir)
                print(f"    Found promps directory: {promps_dir}")
            if not transcript_dirs:
                print(f"    No prompts or promps directory found in {session}. Skipping session.")
                continue  # Skip if no transcript directory is found

            # Collect audio files from both mic directories
            for mic_dir in [wav_arraymic_dir, wav_headmic_dir]:
                if mic_dir.exists():
                    print(f"    Processing microphone directory: {mic_dir}")
                    for wav_file in mic_dir.glob('*.wav'):
                        print(f"      Found audio file: {wav_file.name}")
                        transcript_file = None
                        for t_dir in transcript_dirs:
                            potential_transcript = t_dir / (wav_file.stem + '.txt')
                            if potential_transcript.exists():
                                transcript_file = potential_transcript
                                break
                        if transcript_file:
                            print(f"        Found transcript: {transcript_file.name}")
                            try:
                                with open(transcript_file, 'r', encoding='utf-8') as f:
                                    transcript = f.read().strip()
                                preprocessed_transcript = preprocess_transcript(transcript)
                                phoneme_seq = text_to_phonemes(preprocessed_transcript)
                                phoneme_str = ' '.join(phoneme_seq)
                                data.append({
                                    'audio': str(wav_file.resolve()),
                                    'transcript': preprocessed_transcript,
                                    'phoneme_sequence': phoneme_str
                                })
                            except Exception as e:
                                print(f"        Error reading {transcript_file.name}: {e}")
                        else:
                            print(f"        Transcript for {wav_file.name} not found.")
                else:
                    print(f"    Microphone directory not found: {mic_dir}")
    return data

# Example usage
base_directory = r'D:\\datasets\\TORGO male with dysarthria'  # Replace with your actual path
data = collect_data(base_directory)
print(f"Total samples collected: {len(data)}")

# Convert data list to pandas DataFrame
if data:
    df = pd.DataFrame(data)
    print(df.head())

    # Create a Hugging Face Dataset
    dataset = Dataset.from_pandas(df)

    # Split into train and test sets (e.g., 80% train, 20% test)
    split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
    dataset = DatasetDict({
        'train': split_dataset['train'],
        'test': split_dataset['test']
    })

    print(dataset)

    # Cast the audio column to ensure correct sampling rate
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
else:
    print("No data collected. Please verify the dataset structure and file naming.")

# Proceed only if data is collected
if len(data) > 0:
    # ----------------------------------------
    # Step 3: Create Phoneme Vocabulary and Mappings
    # ----------------------------------------
    # Split phoneme_sequence into list of phonemes
    def split_phonemes(example):
        example["phonemes"] = example["phoneme_sequence"].split()
        return example

    # Apply the function to both train and test sets
    dataset = dataset.map(split_phonemes)

    # Extract all unique phonemes from the training set
    phoneme_list = []
    for phoneme_seq in dataset["train"]["phonemes"]:
        phoneme_list.extend(phoneme_seq)

    unique_phonemes = sorted(list(set(phoneme_list)))

    # Create phoneme to ID and ID to phoneme mappings
    phoneme_to_id = {phoneme: idx for idx, phoneme in enumerate(unique_phonemes)}
    phoneme_to_id["<blank>"] = len(phoneme_to_id)  # Add a blank token for CTC

    id_to_phoneme = {idx: phoneme for phoneme, idx in phoneme_to_id.items()}

    # ----------------------------------------
    # Step 4: Tokenize Phoneme Sequences
    # ----------------------------------------
    def tokenize_phonemes(example):
        """
        Convert phoneme list to a list of IDs.
        """
        phoneme_tokens = example["phonemes"]
        # Map phonemes to IDs, use <blank> ID for unknown phonemes
        example["labels"] = [phoneme_to_id.get(p, phoneme_to_id["<blank>"]) for p in phoneme_tokens]
        return example

    # Apply to both train and test sets
    dataset = dataset.map(tokenize_phonemes)

    # ----------------------------------------
    # Step 5: Select Only Necessary Columns
    # ----------------------------------------
    columns = ["audio", "labels"]
    dataset = dataset.remove_columns([col for col in dataset["train"].column_names if col not in columns])

    # ----------------------------------------
    # Step 6: Set Dataset Format for PyTorch
    # ----------------------------------------
    dataset.set_format(type="torch", columns=columns)

    # ----------------------------------------
    # Step 7: Initialize Processor and Model
    # ----------------------------------------
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")

    # Resize the token embeddings to match the number of phonemes
    model.resize_token_embeddings(len(phoneme_to_id))

    # ----------------------------------------
    # Step 8: Define Training Arguments
    # ----------------------------------------
    training_args = TrainingArguments(
        output_dir="./wav2vec2-phoneme",          # Output directory
        group_by_length=True,                     # Group samples by length for efficiency
        per_device_train_batch_size=8,            # Adjust based on your GPU
        per_device_eval_batch_size=8,             # Adjust based on your GPU
        evaluation_strategy="steps",              # Evaluate every 'eval_steps'
        num_train_epochs=10,                      # Number of training epochs
        fp16=True,                                # Use mixed precision if GPU supports it
        save_steps=400,                           # Save checkpoint every 400 steps
        eval_steps=400,                           # Evaluate every 400 steps
        logging_steps=100,                        # Log training progress every 100 steps
        learning_rate=1e-4,                       # Learning rate
        warmup_steps=500,                         # Number of warmup steps
        save_total_limit=2,                       # Limit the total amount of checkpoints
    )

    # ----------------------------------------
    # Step 9: Define Data Collator
    # ----------------------------------------
    # Using DataCollatorWithPadding instead of DataCollatorCTCWithPadding
    data_collator = DataCollatorWithPadding(
        processor=processor,
        padding=True
    )

    # ----------------------------------------
    # Step 10: Initialize Trainer
    # ----------------------------------------
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
        tokenizer=processor.feature_extractor,  # Use feature extractor as tokenizer
    )

    # ----------------------------------------
    # Step 11: Start Training
    # ----------------------------------------
    trainer.train()

    # ----------------------------------------
    # Step 12: Save the Fine-Tuned Model and Mappings
    # ----------------------------------------
    model.save_pretrained("./wav2vec2-phoneme")
    processor.save_pretrained("./wav2vec2-phoneme")

    # Save phoneme mappings
    with open("./wav2vec2-phoneme/phoneme_to_id.json", "w") as f:
        json.dump(phoneme_to_id, f)

    with open("./wav2vec2-phoneme/id_to_phoneme.json", "w") as f:
        json.dump(id_to_phoneme, f)

    print("Model, processor, and phoneme mappings saved successfully!")
else:
    print("No data available to proceed with training.")


[nltk_data] Downloading package cmudict to
[nltk_data]     C:\Users\capta\AppData\Roaming\nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


Processing speaker: .tmp.drivedownload
  No session directories found for speaker .tmp.drivedownload. Skipping.
Processing speaker: .tmp.driveupload
  No session directories found for speaker .tmp.driveupload. Skipping.
Processing speaker: M01
  Processing session: Session1
    Found prompts directory: D:\datasets\TORGO male with dysarthria\M01\Session1\prompts
    Processing microphone directory: D:\datasets\TORGO male with dysarthria\M01\Session1\wav_arraymic
      Found audio file: 0001.wav
        Found transcript: 0001.txt
      Found audio file: 0002.wav
        Found transcript: 0002.txt
      Found audio file: 0003.wav
        Found transcript: 0003.txt
      Found audio file: 0004.wav
        Found transcript: 0004.txt
      Found audio file: 0005.wav
        Found transcript: 0005.txt
      Found audio file: 0006.wav
        Found transcript: 0006.txt
      Found audio file: 0007.wav
        Found transcript: 0007.txt
      Found audio file: 0008.wav
        Found transcript:

Map: 100%|██████████| 3028/3028 [00:00<00:00, 7149.61 examples/s]
Map: 100%|██████████| 757/757 [00:00<00:00, 10425.85 examples/s]
Map: 100%|██████████| 3028/3028 [00:00<00:00, 9425.85 examples/s]
Map: 100%|██████████| 757/757 [00:00<00:00, 11097.17 examples/s]
  if kwargs.get("gradient_checkpointing", False):
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


NotImplementedError: 