In [17]:
# Step 2: Import all required modules
import os
import random
import librosa
import numpy as np
from datasets import Dataset
from transformers import AutoProcessor, AutoModelForCTC, TrainingArguments, Trainer, AutoTokenizer, AutoFeatureExtractor
from dataclasses import dataclass
from typing import Dict, List, Union
from scipy.signal import butter, lfilter
from audiomentations import AddGaussianNoise
import torch
import tensorflow as tf

In [9]:


# Root directory of the LibriSpeech "train-clean-100" split. load_data() walks
# it as <speaker_id>/<chapter_id>/ and reads the per-chapter *.trans.txt files.
Training_dirs = "./dataset/LibriSpeech/train-clean-100/"

class Augmenter:
    """Randomly distort a mono waveform for robust speech-model training.

    Each distortion is selected independently with its own probability and
    the selected ones are applied in a fixed order: noise, reverb, segment
    shuffle, time stretch, missing gaps, frequency masking.

    All helpers work on 1-D float32 NumPy arrays, because librosa,
    audiomentations and scipy operate on NumPy data. Only ``augment``
    converts at the boundary: it accepts an array-like or an eager tensor and
    returns a float32 ``tf.Tensor``.

    Args:
        sr: Sample rate of the audio in Hz.
        noise_prob: Probability of adding Gaussian noise.
        noise_max_amp: ``max_amplitude`` passed to ``AddGaussianNoise``.
        reverb_prob: Probability of adding a single echo.
        reverb_delay: Echo delay in seconds.
        reverb_decay: Amplitude factor applied to the echo.
        shuffle_prob: Probability of local segment shuffling.
        shuffle_segments: Stored for compatibility; not used by the current
            shuffle implementation.
        time_stretch_prob: Probability of time stretching.
        time_stretch_range: ``(low, high)`` bounds for the stretch rate.
        gaps_prob: Probability of inserting gaps.
        gaps_n: Number of gaps inserted per call.
        gaps_max_duration: Upper bound for one gap, in seconds.
        freq_mask_prob: Probability of band-stop filtering.
        freq_mask_n: Number of band-stop filters applied per call.
        shuffle_seg_dur: Length of one shuffle micro-segment, in seconds.
        shuffle_overlap: Overlap between neighbouring micro-segments, in
            seconds.
        shuffle_local_range: Maximum distance (in segments) a micro-segment
            may be taken from.
    """

    def __init__(self, sr=16000,
                 noise_prob=0.4, noise_max_amp=0.01,
                 reverb_prob=0.3, reverb_delay=0.025, reverb_decay=0.2,
                 shuffle_prob=0.05, shuffle_segments=3,
                 time_stretch_prob=0.2, time_stretch_range=(0.9, 1.1),
                 gaps_prob=0.08, gaps_n=4, gaps_max_duration=0.1,
                 freq_mask_prob=0.2, freq_mask_n=1,
                 shuffle_seg_dur=0.08, shuffle_overlap=0.02, shuffle_local_range=3):

        self.sr = sr
        # p=1.0: the transform always fires once called; whether it is called
        # at all is decided by noise_prob in augment().
        self.noise_aug = AddGaussianNoise(p=1.0, max_amplitude=noise_max_amp)

        self.noise_prob = noise_prob
        self.reverb_prob = reverb_prob
        self.reverb_delay = reverb_delay
        self.reverb_decay = reverb_decay
        self.shuffle_prob = shuffle_prob
        self.shuffle_segments = shuffle_segments
        self.time_stretch_prob = time_stretch_prob
        self.time_stretch_range = time_stretch_range
        self.gaps_prob = gaps_prob
        self.gaps_n = gaps_n
        self.gaps_max_duration = gaps_max_duration
        self.freq_mask_prob = freq_mask_prob
        self.freq_mask_n = freq_mask_n
        self.shuffle_seg_dur = shuffle_seg_dur
        self.shuffle_overlap = shuffle_overlap
        self.shuffle_local_range = shuffle_local_range

    @staticmethod
    def _to_float32(audio):
        """Return ``audio`` as a float32 NumPy array (no copy when possible)."""
        if hasattr(audio, "numpy"):
            # Eager tensors expose .numpy(); unwrap them explicitly.
            audio = audio.numpy()
        return np.asarray(audio, dtype=np.float32)

    def augment(self, audio):
        """Apply a random subset of the distortions to ``audio``.

        Args:
            audio: 1-D waveform (array-like or eager tensor) sampled at
                ``self.sr``.

        Returns:
            A float32 ``tf.Tensor`` with the distorted waveform. Its length
            can differ from the input when time stretching was selected.
        """
        samples = self._to_float32(audio)

        # (probability, distortion) pairs in application order.
        candidates = (
            (self.noise_prob, self._add_noise),
            (self.reverb_prob, self._add_reverb),
            (self.shuffle_prob, self._segment_shuffle),
            (self.time_stretch_prob, self._time_stretch),
            (self.gaps_prob, self._add_missing_gaps),
            (self.freq_mask_prob, self._add_frequency_mask),
        )
        # Draw every decision first, then apply, so the selection does not
        # depend on random numbers consumed inside the distortions.
        selected = [distort for prob, distort in candidates if random.random() < prob]
        for distort in selected:
            samples = distort(samples)

        return tf.convert_to_tensor(np.asarray(samples, dtype=np.float32), dtype=tf.float32)

    def _add_noise(self, audio):
        """Add Gaussian noise via the configured audiomentations transform."""
        noisy = self.noise_aug(self._to_float32(audio), self.sr)
        return np.asarray(noisy, dtype=np.float32)

    def _add_reverb(self, audio):
        """Add one echo: the signal delayed by ``reverb_delay`` s, scaled by ``reverb_decay``."""
        audio = self._to_float32(audio)
        total = audio.shape[0]
        delay = max(0, int(self.reverb_delay * self.sr))  # delay in samples
        echo = np.zeros_like(audio)
        if delay < total:
            # Shift right by `delay`; whatever falls past the end is dropped
            # so the output keeps the input length.
            echo[delay:] = audio[:total - delay] * self.reverb_decay
        return audio + echo

    def _segment_shuffle(self, audio, n_segments=None):
        """Simulate temporal jitter by locally re-ordering micro-segments.

        The clip is cut into overlapping micro-segments. Around a few random
        positions, each slot is refilled with a segment taken from at most
        ``shuffle_local_range`` slots away (sampling with replacement, so a
        segment may be repeated or dropped). The slots are then recombined by
        weighted overlap-add, so the output has exactly the input length and
        untouched regions reproduce the input.

        ``n_segments`` is unused and kept only for signature compatibility.
        """
        audio = self._to_float32(audio)
        total = audio.shape[0]
        seg_len = max(1, int(self.shuffle_seg_dur * self.sr))
        # Keep the hop strictly positive; overlap >= seg_len would never advance.
        overlap = min(max(0, int(self.shuffle_overlap * self.sr)), seg_len - 1)
        hop = seg_len - overlap
        if total <= seg_len:
            # Nothing to shuffle in a clip no longer than one segment.
            return audio.copy()

        starts = list(range(0, total, hop))
        n_seg = len(starts)
        local_range = max(0, int(self.shuffle_local_range))

        # source[slot] = index of the segment whose audio fills that slot.
        source = list(range(n_seg))
        n_regions = min(4, max(1, n_seg // 10))
        for centre in random.sample(range(n_seg), n_regions):
            first = max(0, centre - local_range)
            last = min(n_seg, centre + local_range + 1)
            for slot in range(first, last):
                shift = random.randint(-local_range, local_range)
                source[slot] = max(0, min(n_seg - 1, slot + shift))

        # Strictly positive cross-fade window so every sample gets weight > 0.
        window = np.ones(seg_len, dtype=np.float32)
        if overlap:
            ramp = np.linspace(0.0, 1.0, overlap + 2, dtype=np.float32)[1:-1]
            window[:overlap] = ramp
            window[seg_len - overlap:] = ramp[::-1]

        mixed = np.zeros(total, dtype=np.float32)
        weight = np.zeros(total, dtype=np.float32)
        for slot, src in enumerate(source):
            dst = starts[slot]
            length = min(seg_len, total - dst)
            # Pull the source back if it sits too close to the end to fill the slot.
            src_start = min(starts[src], total - length)
            taper = window[:length]
            mixed[dst:dst + length] += audio[src_start:src_start + length] * taper
            weight[dst:dst + length] += taper
        # Normalising by the accumulated weight makes unshuffled slots exact.
        return mixed / weight

    def _time_stretch(self, audio):
        """Speed the clip up or slow it down by a random rate from ``time_stretch_range``."""
        audio = self._to_float32(audio)
        if audio.shape[0] == 0:
            return audio
        rate = random.uniform(*self.time_stretch_range)
        # librosa works on NumPy arrays, hence the conversion above.
        stretched = librosa.effects.time_stretch(audio, rate=rate)
        return np.asarray(stretched, dtype=np.float32)

    def _add_missing_gaps(self, audio, n_gaps=None, max_gap_duration=None):
        """Replace random stretches of the clip with low-level noise.

        Every gap fades the signal out, holds very quiet Gaussian noise
        (std 0.001) and fades the signal back in, so the edges do not click.

        Args:
            audio: 1-D waveform.
            n_gaps: Number of gaps; defaults to ``self.gaps_n``.
            max_gap_duration: Upper bound for one gap in seconds; defaults to
                ``self.gaps_max_duration``.

        Returns:
            A new float32 array of the same length; the input is not modified.
        """
        if n_gaps is None:
            n_gaps = self.gaps_n
        if max_gap_duration is None:
            max_gap_duration = self.gaps_max_duration

        gap_audio = self._to_float32(audio).copy()
        total = gap_audio.shape[0]
        # Gaps last at least 0.1 s unless the configured maximum is smaller.
        min_gap_duration = min(0.1, max_gap_duration)

        for _ in range(n_gaps):
            gap_samples = int(random.uniform(min_gap_duration, max_gap_duration) * self.sr)
            if gap_samples <= 0 or gap_samples >= total:
                # A gap that does not fit would wipe out the whole clip.
                continue
            start = random.randint(0, total - gap_samples)

            fade_len = min(int(0.05 * gap_samples), gap_samples // 4)
            mid_len = gap_samples - 2 * fade_len
            mid_start = start + fade_len
            mid_end = mid_start + mid_len
            end = start + gap_samples

            if fade_len > 0:
                gap_audio[start:mid_start] *= np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
                gap_audio[mid_end:end] *= np.linspace(0.0, 1.0, fade_len, dtype=np.float32)
            if mid_len > 0:
                gap_audio[mid_start:mid_end] = np.random.normal(
                    0.0, 0.001, mid_len
                ).astype(np.float32)

        return gap_audio

    def _add_frequency_mask(self, audio, n_masks=None):
        """Remove random frequency bands with 4th-order Butterworth band-stop filters.

        Filtering the waveform (rather than zeroing spectrogram bins) mimics
        frequency loss during capture or transmission and lets the distorted
        audio pass through the rest of the pipeline unchanged in format.

        Args:
            audio: 1-D waveform.
            n_masks: Number of bands to remove; defaults to ``self.freq_mask_n``.

        Returns:
            The filtered waveform as a float32 array.
        """
        if n_masks is None:
            n_masks = self.freq_mask_n
        masked_audio = self._to_float32(audio)

        # Only frequencies below sr / 2 (Nyquist) are representable; keep the
        # band edges 100 Hz under that limit.
        ceiling = self.sr / 2 - 100
        for _ in range(n_masks):
            l_freq = min(random.uniform(500, 5000), ceiling)
            h_freq = min(l_freq + random.uniform(500, 2000), ceiling)
            if not 0 < l_freq < h_freq:
                # At low sample rates both edges can clamp to the ceiling;
                # butter() rejects such an empty band, so skip this mask.
                continue
            b, a = butter(N=4, Wn=[l_freq, h_freq], btype="bandstop", fs=self.sr)
            masked_audio = lfilter(b, a, masked_audio)
        return np.asarray(masked_audio, dtype=np.float32)

In [10]:

def load_data(root_dir=None):
    """Collect (audio path, transcript) pairs from a LibriSpeech-style tree.

    The expected layout is ``<root>/<speaker_id>/<chapter_id>/`` where each
    chapter directory holds ``<speaker_id>-<chapter_id>.trans.txt`` with one
    ``<file_id> <transcript>`` entry per line, next to ``<file_id>.flac``.

    Args:
        root_dir: Directory to scan. Defaults to the module-level
            ``Training_dirs``.

    Returns:
        A ``(file_paths, transcriptions)`` tuple of two equally long lists.
        Entries whose ``.flac`` file is missing, and lines without both a
        file id and a transcript, are skipped. Directories are visited in
        sorted order so the result is reproducible across runs.
    """
    if root_dir is None:
        root_dir = Training_dirs

    file_paths, transcriptions = [], []
    print("Scanning directories for .flac files...")
    for speaker_id in sorted(os.listdir(root_dir)):
        speaker_path = os.path.join(root_dir, speaker_id)
        if not os.path.isdir(speaker_path):
            continue
        for chapter_id in sorted(os.listdir(speaker_path)):
            chapter_path = os.path.join(speaker_path, chapter_id)
            if not os.path.isdir(chapter_path):
                continue
            trans_path = os.path.join(chapter_path, f"{speaker_id}-{chapter_id}.trans.txt")
            if not os.path.exists(trans_path):
                continue
            with open(trans_path, 'r', encoding="utf-8") as f:
                for line in f:
                    # "<file_id> <transcript>"; split on the first space only.
                    file_id, _, text = line.strip().partition(' ')
                    if not file_id or not text:
                        # Blank or malformed line: nothing usable to pair up.
                        continue
                    audio_path = os.path.join(chapter_path, f"{file_id}.flac")
                    if os.path.exists(audio_path):
                        file_paths.append(audio_path)
                        transcriptions.append(text)
    return file_paths, transcriptions


In [20]:

# --- Main Fine-Tuning Workflow using the Trainer ---

# 1. Load data and create Dataset object
#    One row per utterance: the path of the .flac file and its transcript.
file_paths, transcriptions = load_data()
data_dict = {"file_path": file_paths, "transcription": transcriptions}
hf_dataset = Dataset.from_dict(data_dict)
print(f"Created a dataset with {len(hf_dataset)} samples.")

# 2. Instantiate dependencies
#    The augmenter uses its default probabilities and a 16 kHz sample rate.
augmenter = Augmenter()
model_id = "nvidia/stt_en_citrinet_256_ls"

# Alternative (disabled): load the tokenizer and feature extractor separately
#tokenizer = AutoTokenizer.from_pretrained(model_id)
#feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# Load the processor and the CTC model for model_id from the hub.
# NOTE(review): the recorded output of this cell ends with an OSError from the
# feature-extractor load: no preprocessor_config.json was found for this
# model_id. Presumably this checkpoint is not published in a layout that
# AutoProcessor/AutoModelForCTC can read -- confirm on the model hub and pick a
# loadable checkpoint before running the rest of the cell.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCTC.from_pretrained(model_id, trust_remote_code=True)

# 3. Create the preprocessing function
def prepare_dataset(batch):
    """Turn one dataset row into model inputs and CTC labels.

    Loads the audio file, runs it through the module-level ``augmenter`` and
    lets the module-level ``processor`` extract ``input_values`` from the
    audio and ``labels`` (token ids) from the transcript.

    Args:
        batch: A single example with ``"file_path"`` and ``"transcription"``.

    Returns:
        The same mapping with ``"input_values"`` and ``"labels"`` added.
    """
    # Load at the augmenter's sample rate so its second-based parameters
    # (delays, gap lengths, filter bands) are interpreted correctly.
    target_sr = augmenter.sr
    audio, _ = librosa.load(batch["file_path"], sr=target_sr)

    # augment() can hand back a tensor or a float64 array depending on which
    # distortions fired; give the processor one consistent float32 array.
    augmented_audio = np.asarray(augmenter.augment(audio), dtype=np.float32)

    batch["input_values"] = processor(audio=augmented_audio, sampling_rate=target_sr).input_values[0]
    batch["labels"] = processor(text=batch["transcription"]).input_ids
    return batch

# Pre-compute features for the whole dataset in parallel worker processes.
# num_proc must be a positive worker count (-1 is not an "all cores"
# shorthand here), so use the number of available CPUs.
# NOTE: map() runs once, so each utterance gets one fixed random augmentation
# that is reused for every epoch.
processed_ds = hf_dataset.map(
    prepare_dataset,
    remove_columns=hf_dataset.column_names,
    num_proc=os.cpu_count() or 1,
)

# 4. Define the Data Collator
@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of examples into one batch for CTC training.

    Audio inputs and label sequences have different lengths, so they are
    padded separately with the collator's own ``processor``. Padded label
    positions are set to -100 so the loss ignores them.
    """

    # Processor providing ``pad``. Quoted so the annotation is not evaluated
    # when the class is created.
    processor: "AutoProcessor"
    # Padding strategy forwarded to ``processor.pad``.
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` (each with ``input_values`` and ``labels``).

        Returns:
            The padded input batch with an added ``"labels"`` tensor.
        """
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Use the processor this collator was built with, not a module-level
        # global, so the collator works wherever it is instantiated.
        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

        # Replace padding with -100 to be ignored by the loss function
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# 5. Define Training Arguments
training_args = TrainingArguments(
    output_dir="./citrinet-finetuned-augmented",
    per_device_train_batch_size=16,
    # Effective batch size per device is 16 * 2 = 32.
    gradient_accumulation_steps=2,
    num_train_epochs=10,
    # Mixed precision needs a CUDA device; enabling it unconditionally makes
    # the run fail on CPU-only machines.
    fp16=torch.cuda.is_available(),
    learning_rate=1e-4,
    # Keep only the two most recent checkpoints on disk.
    save_total_limit=2,
)

# 6. Instantiate the Trainer
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=processed_ds,
    # Passed so the feature extractor is stored alongside the checkpoints.
    tokenizer=processor.feature_extractor,
)

print("\n--- Starting Model Fine-Tuning with citrinet and Trainer ---")
trainer.train()
print("\n--- Fine-Tuning Complete ---")
trainer.save_model("./citrinet-final-model")

Scanning directories for .flac files...
Created a dataset with 28539 samples.


OSError: Can't load feature extractor for 'nvidia/stt_en_citrinet_256_ls'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'nvidia/stt_en_citrinet_256_ls' is the correct path to a directory containing a preprocessor_config.json file