In [1]:


# Step 2: Import all required modules
import os
import random
import librosa
import numpy as np
from datasets import Dataset
from transformers import AutoProcessor, AutoModelForCTC, TrainingArguments, Trainer
from dataclasses import dataclass
from typing import Dict, List, Union
from scipy.signal import butter, lfilter
from audiomentations import AddGaussianNoise
import torch

# This import is not strictly needed for the Trainer but good to have for TF config
import tensorflow as tf

# Step 3: Your Augmenter class (remains the same)
Training_dirs = "./dataset/LibriSpeech/train-clean-100/"

class Augmenter:
    # ... (Your full Augmenter class code remains here as provided) ...
    def __init__(self, sr=16000,
                 noise_prob=0.4, noise_max_amp=0.01,
                 reverb_prob=0.3, reverb_delay=0.025, reverb_decay=0.2,
                 shuffle_prob=0.05, time_stretch_prob=0.2, time_stretch_range=(0.9, 1.1),
                 gaps_prob=0.08, gaps_n=4, gaps_max_duration=0.1,
                 freq_mask_prob=0.2, freq_mask_n=1):
        self.sr, self.noise_prob, self.reverb_prob, self.shuffle_prob, self.time_stretch_prob, self.gaps_prob, self.freq_mask_prob = sr, noise_prob, reverb_prob, shuffle_prob, time_stretch_prob, gaps_prob, freq_mask_prob
        self.noise_aug = AddGaussianNoise(p=1.0, max_amplitude=noise_max_amp, sample_rate=sr)
        self.reverb_delay, self.reverb_decay, self.time_stretch_range, self.gaps_n, self.gaps_max_duration, self.freq_mask_n = reverb_delay, reverb_decay, time_stretch_range, gaps_n, gaps_max_duration, freq_mask_n
    def augment(self, audio):
        distortions, audio = [], np.array(audio, dtype=np.float32)
        if random.random() < self.noise_prob: distortions.append('noise')
        if random.random() < self.reverb_prob: distortions.append('reverb')
        if random.random() < self.shuffle_prob: distortions.append('shuffle')
        if random.random() < self.time_stretch_prob: distortions.append('time_stretch')
        if random.random() < self.gaps_prob: distortions.append('missing_gaps')
        if random.random() < self.freq_mask_prob: distortions.append('frequency_masking')
        for d in distortions:
            if d == 'noise': audio = self.noise_aug(samples=audio, sample_rate=self.sr)
            elif d == 'reverb':
                delay, reverb = int(self.reverb_delay*self.sr), np.pad(audio*self.reverb_decay, (int(self.reverb_delay*self.sr), 0), 'constant')
                audio += reverb[:len(audio)]
            elif d == 'shuffle':
                segs = np.array_split(audio, 3)
                random.shuffle(segs)
                audio = np.concatenate(segs)
            elif d == 'time_stretch': audio = librosa.effects.time_stretch(y=audio, rate=random.uniform(*self.time_stretch_range))
            elif d == 'missing_gaps':
                gap_audio = np.copy(audio)
                for _ in range(self.gaps_n):
                    gap_dur, gap_samples = random.uniform(0.1, self.gaps_max_duration), int(random.uniform(0.1, self.gaps_max_duration) * self.sr)
                    if len(gap_audio) > gap_samples:
                        start = random.randint(0, len(gap_audio) - gap_samples)
                        gap_audio[start:start+gap_samples] = 0
                audio = gap_audio
            elif d == 'frequency_masking':
                nyquist = self.sr / 2
                for _ in range(self.freq_mask_n):
                    l_freq, h_freq = random.uniform(500, 5000), random.uniform(500, 2000)
                    if l_freq + h_freq < nyquist:
                        b, a = butter(N=4, Wn=[l_freq, l_freq + h_freq], btype="bandstop", fs=self.sr)
                        audio = lfilter(b, a, audio)
        return audio

# Step 4: Your custom load_data function (remains the same)
def load_data():
    file_paths, transcriptions = [], []
    print("Scanning directories for .flac files...")
    for speaker_id in os.listdir(Training_dirs):
        speaker_path = os.path.join(Training_dirs, speaker_id)
        if not os.path.isdir(speaker_path): continue
        for chapter_id in os.listdir(speaker_path):
            chapter_path = os.path.join(speaker_path, chapter_id)
            if not os.path.isdir(chapter_path): continue
            trans_file = f"{speaker_id}-{chapter_id}.trans.txt"
            trans_path = os.path.join(chapter_path, trans_file)
            if os.path.exists(trans_path):
                with open(trans_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split(' ', 1)
                        file_id, text = parts[0], parts[1]
                        audio_path = os.path.join(chapter_path, f"{file_id}.flac")
                        if os.path.exists(audio_path):
                            file_paths.append(audio_path)
                            transcriptions.append(text)
    return file_paths, transcriptions

# --- Main Fine-Tuning Workflow for Citrinet ---

# 1. Load data and create Dataset object
file_paths, transcriptions = load_data()
data_dict = {"file_path": file_paths, "transcription": transcriptions}
hf_dataset = Dataset.from_dict(data_dict)
print(f"Created a dataset with {len(hf_dataset)} samples.")

# 2. Instantiate dependencies for Citrinet
augmenter = Augmenter()
model_id = "nvidia/stt_en_citrinet_256_ls" # <-- CHANGED to the Citrinet model
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id)

# 3. Create the preprocessing function (remains the same)
def prepare_dataset(batch):
    audio, sr = librosa.load(batch["file_path"], sr=16000)
    augmented_audio = augmenter.augment(audio)
    batch["input_values"] = processor(audio=augmented_audio, sampling_rate=16000).input_values[0]
    batch["labels"] = processor(text=batch["transcription"]).input_ids
    return batch

processed_ds = hf_dataset.map(prepare_dataset, remove_columns=hf_dataset.column_names, num_proc=1)

# 4. Define the Data Collator (remains the same)
@dataclass
class DataCollatorCTCWithPadding:
    processor: AutoProcessor
    padding: Union[bool, str] = True
    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = processor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# 5. Define Training Arguments, adjusted for Citrinet
training_args = TrainingArguments(
    output_dir="./citrinet-finetuned-augmented", # <-- CHANGED output directory
    per_device_train_batch_size=16, # <-- CHANGED batch size (Citrinet is smaller)
    gradient_accumulation_steps=2,
    num_train_epochs=10, # <-- CHANGED epochs (convolutional models can fine-tune quickly)
    fp16=True,
    learning_rate=1e-4, # <-- CHANGED learning rate (often better for Citrinet/QuartzNet)
    save_total_limit=2,
)

# 6. Instantiate the Trainer (remains the same)
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=processed_ds,
    tokenizer=processor.feature_extractor,
)

# 7. Start Fine-Tuning
print("\n--- Starting Model Fine-Tuning with Citrinet and Trainer ---")
trainer.train()
print("\n--- Fine-Tuning Complete ---")

# 8. Save the final model
trainer.save_model("./citrinet-final-model") # <-- CHANGED save path


Scanning directories for .flac files...
Created a dataset with 28539 samples.


TypeError: AddGaussianNoise.__init__() got an unexpected keyword argument 'sample_rate'