In [None]:
# +-----------------------------------------------------------------------------------------+
# | NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
# |-----------------------------------------+------------------------+----------------------+
# | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
# | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
# |                                         |                        |               MIG M. |
# |=========================================+========================+======================|
# |   0  NVIDIA GeForce RTX 4080 ...    Off |   00000000:01:00.0 Off |                  N/A |
# |  0%   41C    P8              3W /  320W |    7752MiB /  16376MiB |      0%      Default |
# |                                         |                        |                  N/A |
# +-----------------------------------------+------------------------+----------------------+

In [2]:
import os
import re
import numpy as np
import pandas as pd
import librosa
import soundfile as sf
import librosa.display
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchaudio.transforms as T
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torchinfo import summary
import jiwer
import editdistance
import random
import time
import heapq
import difflib
import kenlm
from IPython.display import Audio

In [3]:
AUDIO_FOLDER = "wavs"
CSV_PATH = "data/metadata.csv"

OUTPUT_FOLDER = "processed_data_aug"  
TRAIN_FOLDER = os.path.join(OUTPUT_FOLDER, "train")
VAL_FOLDER = os.path.join(OUTPUT_FOLDER, "val")
TEST_FOLDER = os.path.join(OUTPUT_FOLDER, "test")
METADATA_FILE_TRAIN = "train_metadata.csv"
METADATA_FILE_VAL = "val_metadata.csv"
METADATA_FILE_TEST = "test_metadata.csv"
WAVS_FILE = 'aug_audio_folder'

DEVICE =  torch.device("cuda") #'gpu' if torch.cuda.is_available() else 'cpu'
os.makedirs(TRAIN_FOLDER, exist_ok=True)
os.makedirs(VAL_FOLDER, exist_ok=True)
os.makedirs(TEST_FOLDER, exist_ok=True)

df = pd.read_csv(CSV_PATH)
transcript_dict = dict(zip(df["file_id"], df["transcription"]))

SAMPLE_RATE = 16000
N_MELS = 128
HOP_LENGTH = 512
EPOCHS = 50

## **Data Preprocess**

In [4]:
abbreviation_mapping = {
    "Mr.": "Mister", "Mrs.": "Misess", "Dr.": "Doctor", "No.": "Number", "St.": "Saint", "Co.": "Company", "Jr.": "Junior",
    "Maj.": "Major", "Gen.": "General", "Drs.": "Doctors", "Rev.": "Reverend", "Lt.": "Lieutenant", "Hon.": "Honorable",
    "Sgt.": "Sergeant", "Capt.": "Captain", "Esq.": "Esquire", "Ltd.": "Limited", "Col.": "Colonel", "Ft.": "Fort"
}

class AudioAugmentation:
    def __init__(self, audio_folder, train_folder, sample_rate=SAMPLE_RATE, n_mels=N_MELS, hop_length=HOP_LENGTH, wavs_file=WAVS_FILE):
        self.audio_folder = audio_folder
        self.train_folder = train_folder
        self.wavs_file = wavs_file
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.train_metadata_records = []
    
    def clean_and_expand_text(self, text):
        if pd.isna(text):
            return ""
        text = str(text).lower()
        for abbr, full_form in abbreviation_mapping.items():
            text = text.replace(abbr.lower(), full_form.lower())
        text = text.replace(" -- ", " ")
        text = text.replace("ü", "u")
        text = text.replace("etc.", "etcetera")
        text = text.replace("i.e.", "i e ")
        text = re.sub(r'[^a-z\s]', '', text.lower())
        return text

    def time_stretch(self, y, rate=None):
        if rate is None:
            rate = random.uniform(0.9, 1.1)
        return librosa.effects.time_stretch(y, rate=rate)

    def pitch_shift(self, y, sr, n_steps=None):
        if n_steps is None:
            n_steps = random.randint(-3, 3) 
        return librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)

    def add_noise(self, y, noise_level=0.005):
        noise = np.random.normal(0, noise_level, y.shape)
        return y + noise

    def apply_specaugment(self, mel_spec_db, warp=False):
        time_mask = T.TimeMasking(time_mask_param=20)
        freq_mask = T.FrequencyMasking(freq_mask_param=8)
        
        mel_spec_db = time_mask(torch.tensor(mel_spec_db)).numpy()
        mel_spec_db = freq_mask(torch.tensor(mel_spec_db)).numpy()
        mel_spec_db[mel_spec_db == 0] = -80.0 
        if warp and random.random() < 0.05:
            time_warp = T.TimeMasking(time_mask_param=50)
            mel_spec_db = time_warp(torch.tensor(mel_spec_db)).numpy()
            mel_spec_db[mel_spec_db == 0] = -80.0 
        return mel_spec_db
    
    def save_spectrogram(self, y, sr, file_name):
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=self.n_mels, hop_length=self.hop_length)
        mel_spec_db = librosa.power_to_db(mel_spec,ref=1.0)
        output_path = os.path.join(self.train_folder, f"{file_name}.npy")
        np.save(output_path, mel_spec_db)
        return output_path
    
    def augment_and_save(self, y, sr, file_name, aug_type=None):
        y_aug = y
        if random.random() < 0.6:
            if aug_type is None:
                aug_type = random.choices(["stretch", "pitch", "noise" , "mixer"], weights=[0.20, 0.3, 0.30,0.15], k=1)[0]            
            if aug_type == "stretch":
                y_aug = self.time_stretch(y, random.choice([0.8, 1.2]))
            elif aug_type == "pitch":
                y_aug = self.pitch_shift(y, sr, random.choice([-2, -1, 1, 2]))
            elif aug_type == "noise":
                y_aug = self.add_noise(y, noise_level=random.uniform(0.002, 0.01))
            elif aug_type == "mixer":
                aug_choice = random.choices(
                    ["stretch", "pitch", "noise"], 
                    weights=[0.25, 0.35, 0.25], 
                    k=2
                )        
                for mix_aug in aug_choice:
                    if mix_aug == "stretch":
                        y_aug = self.time_stretch(y_aug, random.choice([0.8, 1.2]))  
                    elif mix_aug == "pitch":
                        y_aug = self.pitch_shift(y_aug, sr, random.choice([-2, -1, 1, 2]))  
                    elif mix_aug == "noise":
                        y_aug = self.add_noise(y_aug, noise_level=random.uniform(0.002, 0.01))
        else: 
            aug_type = "specaugment"
            
        wav_output_path = os.path.join(self.wavs_file, f"{file_name}_{aug_type}.wav")
        sf.write(wav_output_path, y_aug, sr)
        
        if aug_type == "specaugment":
            mel_spec = librosa.feature.melspectrogram(y=y_aug, sr=sr, n_mels=self.n_mels, hop_length=self.hop_length)
            mel_spec_db_aug = librosa.power_to_db(mel_spec, ref=np.max)
            mel_spec_db_aug = self.apply_specaugment(mel_spec_db_aug, warp=True)            
            aug_output_path = os.path.join(self.train_folder, f"{file_name}_specaugment.npy")
            np.save(aug_output_path, mel_spec_db_aug)
        else:
            aug_output_path = self.save_spectrogram(y_aug, sr, f"{file_name}_{aug_type}")

        self.train_metadata_records.append([file_name + f"_{aug_type}", aug_output_path, self.clean_and_expand_text(transcript_dict[file_name])])
        return aug_output_path

    def process_audio_files(self, train_df, transcript_dict):
        for file_id in tqdm(train_df["file_id"], desc="Saving original train data"):
            file_path = os.path.join(self.audio_folder, f"{file_id}.wav")
            file_name = os.path.splitext(file_id)[0]
            
            if os.path.exists(file_path):
                try:
                    y, sr = librosa.load(file_path, sr=self.sample_rate)
                    orig_output_path = self.save_spectrogram(y, sr, file_name)
                    self.train_metadata_records.append([file_name, orig_output_path, self.clean_and_expand_text(transcript_dict[file_name])])
                except Exception as e:
                    print(f"Error processing {file_name}: {e}")
        
        # Apply augmentation
        for file_id in tqdm(train_df["file_id"], desc="Applying augmentation to train data"):
            file_path = os.path.join(self.audio_folder, f"{file_id}.wav")
            file_name = os.path.splitext(file_id)[0]

            if os.path.exists(file_path):
                try:
                    y, sr = librosa.load(file_path, sr=self.sample_rate)
                    self.augment_and_save(y, sr, file_name)
                except Exception as e:
                    print(f"Error processing {file_name}: {e}")

        return self.train_metadata_records
    
if os.path.exists('processed_data_aug/train'):
    print("File already downloaded")
else:
    train_df = df.sample(frac=0.7, random_state=42)
    remain_df = df.drop(train_df.index)

    val_df = remain_df.sample(frac=0.5,random_state=42)
    test_df = remain_df.drop(val_df.index)

    audio_augmentor = AudioAugmentation(AUDIO_FOLDER, TRAIN_FOLDER, SAMPLE_RATE, N_MELS, HOP_LENGTH,WAVS_FILE)
    train_metadata_records = audio_augmentor.process_audio_files(train_df, transcript_dict)

    train_metadata_df = pd.DataFrame(train_metadata_records, columns=["file_id", "spectrogram_path", "transcript"])
    train_metadata_df.to_csv(METADATA_FILE_TRAIN, index=False)
    print(f"✅ Training data processing complete! Metadata saved in {METADATA_FILE_TRAIN}")

# def max_time_stamps():  
#     max_time = 0 
#     for file_name in tqdm(os.listdir(trainn_folder), desc="Max time stamps computing"):
#         if file_name.endswith(".npy"):
#             file_path = os.path.join(trainn_folder, file_name)
#             mel_spec_db = np.load(file_path)
#             time_steps = mel_spec_db.shape[1]
#             max_time = max(time_steps,max_time)
#     return max_time

# def pad_or_truncate_spectrogram(mel_spec_db, max_time):        
#     if mel_spec_db.shape[1] < max_time:
#         pad_width = max_time - mel_spec_db.shape[1]
#         mel_spec_db = np.pad(mel_spec_db, ((0, 0), (0, pad_width)), mode='constant',constant_values=-80)
#     else:
#         mel_spec_db = mel_spec_db[:, :max_time]    
#     return mel_spec_db   

# trainn_folder = 'processed_data_aug/train' 
# max_time_steps = max_time_stamps()
# print(max_time_steps)

# for file_name in tqdm(os.listdir(trainn_folder), desc="Padding spectrogram files"):
#     if file_name.endswith(".npy"):
#         file_path = os.path.join(trainn_folder, file_name)
#         mel_spec_db = np.load(file_path)      
#         mel_spec_db = pad_or_truncate_spectrogram(mel_spec_db, max_time_steps)
#         np.save(file_path, mel_spec_db)

File already downloaded


In [5]:
if os.path.exists('processed_data_aug/val'):
    print("File already downloaded")
else:
    audio_augmentor = AudioAugmentation(AUDIO_FOLDER, TRAIN_FOLDER, SAMPLE_RATE, N_MELS, HOP_LENGTH)
    val_metadata_records = []
    for file_id in tqdm(val_df["file_id"], desc="Processing validation data"):
        file_path = os.path.join(AUDIO_FOLDER, f"{file_id}.wav")
        file_name = os.path.splitext(file_id)[0]

        try:
            y, sr = librosa.load(file_path, sr=SAMPLE_RATE)
            mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, hop_length=HOP_LENGTH)
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

            val_output_path = os.path.join(VAL_FOLDER, f"{file_name}.npy")
            np.save(val_output_path, mel_spec_db)
    #         val_metadata_records.append([file_name, val_output_path, audio_augmentor.clean_and_expand_text(transcript_dict[file_name])])
            cleaned_transcript = audio_augmentor.clean_and_expand_text(transcript_dict.get(file_name, "No transcript found"))

            val_metadata_records.append([file_name, val_output_path, cleaned_transcript])
        except Exception as e:
            print(f"Error processing {file_name}: {e}")

    val_metadata_df = pd.DataFrame(val_metadata_records, columns=["file_id", "spectrogram_path", "transcript"])
    val_metadata_df.to_csv(METADATA_FILE_VAL, index=False)
    print(f"✅ Validation data processing complete! Metadata saved in {METADATA_FILE_VAL}") 

    test_metadata_records = []
    for file_id in tqdm(test_df["file_id"], desc="Processing test data"):
        file_path = os.path.join(AUDIO_FOLDER, f"{file_id}.wav")
        file_name = os.path.splitext(file_id)[0]

        try:
            y, sr = librosa.load(file_path, sr=SAMPLE_RATE)
            mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, hop_length=HOP_LENGTH)
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
            test_output_path = os.path.join(TEST_FOLDER, f"{file_name}.npy")
            np.save(test_output_path, mel_spec_db)
            cleaned_transcript = audio_augmentor.clean_and_expand_text(transcript_dict.get(file_name, "No transcript found"))

            test_metadata_records.append([file_name, test_output_path, cleaned_transcript])
        except Exception as e:
            print(f"Error processing {file_name}: {e}")

    test_metadata_df = pd.DataFrame(test_metadata_records, columns=["file_id", "spectrogram_path", "transcript"])
    test_metadata_df.to_csv(METADATA_FILE_TEST, index=False)
    print(f"✅ TESTING data processing complete! Metadata saved in {METADATA_FILE_TEST}")      

File already downloaded


# MODEL TRAINING 

In [6]:
class ASRDataset(Dataset):
    def __init__(self, metadata_df, char_to_idx):
        self.metadata_df = metadata_df
        self.char_to_idx = char_to_idx

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        spectrogram_path = self.metadata_df.iloc[idx]["spectrogram_path"]
        transcript = self.metadata_df.iloc[idx]["transcript"]
        spectrogram = np.load(spectrogram_path)
        encoded_text = [self.char_to_idx[char] for char in transcript if char in self.char_to_idx]
        return torch.tensor(spectrogram, dtype=torch.float32), torch.tensor(encoded_text, dtype=torch.long)

def collate_fn(batch):
    filtered_batch = []  
    for spectrogram, text in batch:
        spec_length = spectrogram.shape[1]
        text_length = len(text)

        if text_length == 0:
            continue 

        ratio = spec_length / text_length
        if ratio >= 1:
            filtered_batch.append((spectrogram, text))

    if len(filtered_batch) == 0:
        return torch.tensor([]), torch.tensor([]), torch.tensor([]), torch.tensor([])
    
    spectrograms, texts = zip(*filtered_batch)
    spectrogram_lengths = torch.tensor([spec.shape[1] for spec in spectrograms])
    text_lengths = torch.tensor([len(text) for text in texts])

    # Pad spectrograms
    max_time_steps = max(spectrogram_lengths)
    padded_spectrograms = torch.zeros(len(spectrograms), N_MELS, max_time_steps)
    for i, spec in enumerate(spectrograms):
        padded_spectrograms[i, :, :spec.shape[1]] = spec

    # Pad texts
    max_text_length = max(text_lengths)
    padded_texts = torch.zeros(len(texts), max_text_length, dtype=torch.long)
    
    for i, text in enumerate(texts):
        padded_texts[i, :len(text)] = text

    return  padded_spectrograms,spectrogram_lengths, padded_texts, text_lengths
def load_arpa_lm(arpa_file):
    lm = {}  # Store probabilities with words as string keys
    backoff = {}  # Store backoff weights

    with open(arpa_file, 'r', encoding="utf-8") as f:
        for line in f:
            line = line.strip()            
            if line.startswith("\\") or not line:
                continue  # Skip headers and empty lines            
            parts = line.split("\t")
            if len(parts) >= 2:
                prob = float(parts[0])
                ngram = parts[1]  # Store as a plain string instead of a tuple                
                if len(parts) == 3:  # If there's a backoff weight
                    backoff[ngram] = float(parts[2])                
                lm[ngram] = prob  # Store LM probability    
    return lm, backoff        
# def get_lm_score(seq,lm,backoff):
#     if seq in lm:
#         return lm[seq]
#     if len(seq) > 1:
#         backoff_word = seq[:-1]
#         if backoff_word in backoff:
#             return backoff[backoff_word] + get_lm_score(seq[1:], lm, backoff)
#     return -10  

# def get_best_lm_replacement(word,context, lm, backoff):
#     best_score = float('-inf')
#     best_word = word  
#     closest_match = difflib.get_close_matches(word, lm.keys(), n=3, cutoff=0.8)  # You can adjust cutoff
# #     if closest_match:
# #         best_word = closest_match[0]
# #         best_score = lm[best_word]
# #     else:
# #         best_word = word
#     for candidate in closest_match:
#         candidate_ngram = context + " " + candidate
#         score = get_lm_score(candidate_ngram, lm, backoff)
#         if score > best_score:
#             best_word = candidate
#             best_score = score 
#     return best_word

# def post_process_with_lm(decoded_seq, lm, backoff):
#     words = decoded_seq.strip().split()
#     final_transcription = []
#     for i,word in enumerate(words):
#         word = word.strip()
#         context = " ".join(final_transcription[max(0,i-2):i])
        
#         if word in lm: 
# #             lm_score = lm[word] 
#             final_transcription.append(word)  # Keep the word if it's correct
#         else:
#             # If the word is not in the LM, we look for the best replacement word from the LM
#             best_replacement = get_best_lm_replacement(word,context, lm, backoff)
#             final_transcription.append(best_replacement)
#     return " ".join(final_transcription)

# def beam_search_decoder_lm(probs, char_map, k, lm, backoff):
#     T, V = probs.shape
#     beam = [("", 0.0)]  # Start with an empty sequence
    
#     for t in range(T):
#         new_beam = {}
        
#         # Iterate through current beam sequences
#         for seq, log_probs in beam:
#             # Try appending each possible character to the sequence
#             for v in range(V):
#                 if v == 0:  # Blank token
#                     new_seq = seq  # Keep the sequence unchanged for blank
#                 elif len(seq) > 0 and seq[-1] == char_map[v]:  # Avoid repetition of the same character
#                     new_seq = seq
#                 else:
#                     new_seq = seq + char_map[v]
                
#                 new_probs = log_probs + probs[t, v]  # Update log-probabilities

#                 # Store the new sequence with updated probabilities
#                 if new_seq in new_beam:
#                     new_beam[new_seq] = max(new_beam[new_seq], new_probs)
#                 else:
#                     new_beam[new_seq] = new_probs
        
#         # Sort beam by probabilities and keep top k sequences
#         beam = heapq.nlargest(k, new_beam.items(), key=lambda x: x[1])
    
#     # Select the best sequence after beam search
#     best_seq = beam[0][0]#'there were more varied and at times especialy when ber had circulated frely more upororious diversiaons' 
# #     print(f"Best sequence before LM: {best_seq}")
#     # Join the words with spaces
#     final_transcription = post_process_with_lm(best_seq, lm, backoff)
    
#     return final_transcription
# def ctc_decode_lm(output, idx_to_char):    
#     pred_indices = output.argmax(dim=-1).cpu().numpy()  
#     decoded_text = []
    
#     for seq in pred_indices:  # Iterate over batch
#         chars = []
#         prev_char = None  
#         for idx in seq:
#             if idx != 0 and idx != prev_char:
#                 chars.append(idx_to_char[idx])
#             prev_char = idx
#         decoded_text.append("".join(chars))  
#     final_transcription = [post_process_with_lm(text, lm, backoff) for text in decoded_text]
#     return final_transcription

def ctc_decode(output, idx_to_char):    
    pred_indices = output.argmax(dim=-1).cpu().numpy()  
    decoded_text = []
    
    for seq in pred_indices:  # Iterate over batch
        chars = []
        prev_char = None  
        for idx in seq:
            if idx != 0 and idx != prev_char:  
                chars.append(idx_to_char[idx])
            prev_char = idx
        decoded_text.append("".join(chars))  
    return decoded_text

In [7]:
class ASRModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim , num_layers=4, dropout=0.3,fil_dim=32):
        super(ASRModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, fil_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(fil_dim),
            nn.GELU(),            
            nn.Conv2d(fil_dim, fil_dim*2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(fil_dim*2),
            nn.GELU(),  
        )
        self.conv1d = nn.Sequential(
            nn.Conv1d(in_channels=64*input_dim, out_channels = hidden_dim, kernel_size=3, stride=1, padding=1),  
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )

        self.bi_gru = nn.GRU(
            input_size=hidden_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            batch_first=True, 
            bidirectional=True
        )
        # self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.fc = nn.Sequential(
            nn.LayerNorm(hidden_dim*2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim) 
        )

    def forward(self, x, lengths):
        x = x.unsqueeze(1) 
        x = self.conv(x) # (Batch, channels, Features, time) ---- [32, 64, 128, 395]
        x = x.permute(0, 3, 2, 1)  # (Batch, channels, Features, time) ---- [32, 64, 128, 395]
        x = x.reshape(x.size(0), x.size(1), -1)  # Flatten for RNN input ----- [32, 395, 8192]
        x = x.permute(0, 2, 1)
        
        x = self.conv1d(x)   # ---- [32, 512, 395]
        x = x.permute(0, 2, 1)   # ---- [32, 395, 512]
        
        lengths = lengths.cpu().int()
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)#---[15264, 256]
        packed_output, _ = self.bi_gru(packed_x) #----- [15264, 512]
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True) # ----- [32, 477, 512]
        return self.fc(output)
        
class ASRMetrics:
    def __init__(self):
        self.cer = []
        self.wer = []
        
    def calculate_cer(self, reference, hypothesis):
        return editdistance.eval(reference, hypothesis) / max(len(reference), 1)

    def calculate_wer(self, reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        return editdistance.eval(ref_words, hyp_words) / max(len(ref_words), 1)

    def log_metrics(self, epoch, predictions, references, writer):
        cer_epoch = np.mean([self.calculate_cer(r, p) for r, p in zip(references, predictions)])
        wer_epoch = np.mean([self.calculate_wer(r, p) for r, p in zip(references, predictions)])
        writer.add_scalar("CER", cer_epoch, epoch)
        writer.add_scalar("WER", wer_epoch, epoch)
        print(f"Epoch {epoch+1}: CER: {cer_epoch:.4f}, WER: {wer_epoch:.4f}")

class EarlyStopping:
    def __init__(self, patience=7, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [8]:
asr_metrics = ASRMetrics()
def train_model(model, train_loader, val_loader, epochs=EPOCHS, device=DEVICE):
    model.to(device)
    criterion = nn.CTCLoss(blank=0) #, reduction='mean', zero_infinity=True
    # optimizer = optim.Adam(model.parameters(), lr=0.0001)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0003) # ,weight_decay=1e-4
    early_stopper = EarlyStopping()
    writer = SummaryWriter() #log_dir=LOG_DIR
    

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        first_batch = True
        
        for spectrograms, spectrogram_lengths, texts, text_lengths in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            outputs = model(spectrograms.to(device), spectrogram_lengths.to(device))
            outputs = outputs.log_softmax(dim=-1).transpose(0, 1)
            loss = criterion(outputs, texts.to(device), spectrogram_lengths.to(device), text_lengths.to(device))
            loss.backward()
            
            # Log gradients & weights
            # for name, param in model.named_parameters():
            #     if param.requires_grad:
            #         writer.add_histogram(f"Gradients/{name}", param.grad, epoch)
            #         writer.add_histogram(f"Weights/{name}", param, epoch)

            if first_batch:
                writer.add_graph(model, (spectrograms.to(device), spectrogram_lengths.to(device)))
                first_batch = False  # Avoid logging graph repeatedly
#             for name, param in model.named_parameters():
#                 if param.grad is not None:
#                     print(f"Gradient for {name}:", param.grad.abs().mean().item())            
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            epoch_loss += loss.item()
        avg_train_loss = epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train", avg_train_loss, epoch)

# Validation
        model.eval()
        val_loss = 0
        predictions, references = [], []
        with torch.no_grad():
            for spectrograms, spectrogram_lengths, texts, text_lengths in val_loader:
                outputs = model(spectrograms.to(device), spectrogram_lengths.to(device))
                outputs = outputs.log_softmax(dim=-1).transpose(0, 1) # (time, batch, vocab)
                loss = criterion(outputs, texts.to(device), spectrogram_lengths.to(device), text_lengths.to(device))
                val_loss += loss.item()                
                outputs = outputs.transpose(0, 1) # correct shape for CER (batch, time, vocab)
                b_s = outputs.shape[0]
                # decoded_preds = ["".join(idx_to_char[idx.item()] for idx in output.argmax(dim=-1) if idx.item() != 0) for output in outputs]
                decoded_preds = ctc_decode(outputs, idx_to_char)
                decoded_refs = ["".join(idx_to_char[idx.item()] for idx in text if idx.item() != 0) for text in texts]
#                 print(f"prediction:{decoded_preds}")# , reference:{decoded_refs}
                predictions.extend(decoded_preds)
                references.extend(decoded_refs)

        avg_val_loss = val_loss / len(val_loader)
        writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        asr_metrics.log_metrics(epoch, predictions, references , writer)
        
        early_stopper(avg_val_loss)
        if early_stopper.early_stop:
            print("Early stopping triggered!")
            break

    writer.close()
    torch.save(model.state_dict(), "models/augmented/conv1d/asr_model--4.pth")

In [9]:
lm, backoff = load_arpa_lm("kenlm/build/3gram.arpa") 
train_metadata_df = pd.read_csv(METADATA_FILE_TRAIN)
val_metadata_df = pd.read_csv(METADATA_FILE_VAL)
test_metadata_df = pd.read_csv(METADATA_FILE_TEST)

train_metadata_df["transcript"] = train_metadata_df["transcript"].fillna("").astype(str)
val_metadata_df["transcript"] = val_metadata_df["transcript"].fillna("").astype(str)

# vocab = sorted(set("".join(train_metadata_df["transcript"] + "".join(val_metadata_df["transcript"]))))
vocab = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

In [None]:
val_metadata_df = val_metadata_df.drop_duplicates()
train_dataset = ASRDataset(train_metadata_df, char_to_idx)
val_dataset = ASRDataset(val_metadata_df, char_to_idx)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Train model
model = ASRModel(N_MELS,512, len(vocab) + 1)
train_model(model, train_loader, val_loader, epochs=EPOCHS, device=DEVICE)

# Testing Model

In [None]:
def get_lm_score(seq,lm,backoff):
    if seq in lm:
        return lm[seq]
    if len(seq) > 1:
        backoff_word = seq[:-1]
        if backoff_word in backoff:
            return backoff[backoff_word] + get_lm_score(seq[1:], lm, backoff)
    return -10         

def get_best_lm_replacement(word,context, lm, backoff,thres=2.0):
    best_score = float('-inf')
    best_word = word  
    closest_match = difflib.get_close_matches(word, lm.keys(), n=3, cutoff=0.78)  # You can adjust cutoff
#     if closest_match:
#         best_word = closest_match[0]
#         best_score = lm[best_word]
#     else:
#         best_word = word
    for candidate in closest_match:
        candidate_ngram = str(context) + " " + candidate
        score = get_lm_score(candidate_ngram, lm, backoff)
        if score > best_score:
            best_word = candidate
            best_score = score        
    return best_word

def post_process_with_lm(decoded_seq, lm, backoff):
    words = decoded_seq.strip().split()
    final_transcription = []
    for i,word in enumerate(words):
        word = word.strip()
        context = " ".join(final_transcription[max(0,i-2):i])
        
        if word in lm: 
#             lm_score = get_lm_score(context + word,lm,backoff)
            final_transcription.append(word)  # Keep the word if it's correct
        else:
#             If the word is not in the LM, we look for the best replacement word from the LM
            best_replacement = get_best_lm_replacement(word,context, lm, backoff)
            final_transcription.append(best_replacement) 
    return " ".join(final_transcription)

def ctc_decode_lm(output, idx_to_char):    
    pred_indices = output.argmax(dim=-1).cpu().numpy()  
    decoded_text = []
    
    for seq in pred_indices:  # Iterate over batch
        chars = []
        prev_char = None  
        for idx in seq:
            if idx != 0 and idx != prev_char:  
                chars.append(idx_to_char[idx])
            prev_char = idx
        decoded_text.append("".join(chars))  
    final_transcription = [post_process_with_lm(text, lm, backoff) for text in decoded_text]
    return final_transcription


test_metadata_df = pd.read_csv(METADATA_FILE_TEST)
test_metadata_df = test_metadata_df.drop_duplicates()
test_dataset = ASRDataset(test_metadata_df, char_to_idx)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
predictions, references = [], []
epoch = 0
writer = SummaryWriter()
model = model.to(DEVICE)
with torch.no_grad():
    for spectrograms, spectrogram_lengths, texts, text_lengths in tqdm(test_loader):
        outputs = model(spectrograms.to(DEVICE), spectrogram_lengths.to(DEVICE))
#         probs = torch.nn.functional.log_softmax(outputs, dim=-1).squeeze().numpy()
        outputs = outputs.log_softmax(dim=-1).transpose(0, 1) # (time, batch, vocab)                   
        outputs = outputs.transpose(0, 1)  # correct shape for CER (batch, time, vocab)
        decoded_preds = ctc_decode_lm(outputs, idx_to_char)
        decoded_refs = ["".join(idx_to_char[idx.item()] for idx in text if idx.item() != 0) for text in texts] 
        predictions.extend(decoded_preds)
        references.extend(decoded_refs)
    df_pred = pd.DataFrame({'prediction':predictions,'reference':references})
    df_pred.to_csv('pred_vs_ref.csv',index=False,encoding="utf-8")
    asr_metrics.log_metrics(epoch, predictions, references , writer)

In [11]:
total_data = pd.concat([train_metadata_df,val_metadata_df],ignore_index=True,axis=0)
total_data = ASRDataset(total_data, char_to_idx)
total_data_loader = DataLoader(total_data, batch_size=64, shuffle=False, collate_fn=collate_fn)
predictions, references = [], []
epoch = 0
writer = SummaryWriter()
model = model.to(DEVICE)
with torch.no_grad():
    for spectrograms, spectrogram_lengths, texts, text_lengths in tqdm(total_data_loader):
        outputs = model(spectrograms.to(DEVICE), spectrogram_lengths.to(DEVICE))
#         probs = torch.nn.functional.log_softmax(outputs, dim=-1).squeeze().numpy()
        outputs = outputs.log_softmax(dim=-1).transpose(0, 1) # (time, batch, vocab)                   
        outputs = outputs.transpose(0, 1)  # correct shape for CER (batch, time, vocab)
        decoded_preds = ctc_decode(outputs, idx_to_char)
        decoded_refs = ["".join(idx_to_char[idx.item()] for idx in text if idx.item() != 0) for text in texts]       
        predictions.extend(decoded_preds)
        references.extend(decoded_refs)
    df_pred = pd.DataFrame({'prediction':predictions,'reference':references})
    df_pred.to_csv('pred_ref_mlm.csv',index=False,encoding="utf-8")
    asr_metrics.log_metrics(epoch, predictions, references , writer)

100%|█████████████████████████████████████████████████████████████████████████████████| 318/318 [02:19<00:00,  2.27it/s]


Epoch 1: CER: 0.0277, WER: 0.0570


# Model summary

In [13]:
# this is model summary 
N_MELS = 128 
HIDDEN_DIM = 256
VOCAB_SIZE = 28  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ASRModel(N_MELS, HIDDEN_DIM, VOCAB_SIZE).to(device)
dummy_input = torch.randn(1, N_MELS, 486).to(device)
dummy_lengths = torch.tensor([120]).to(device) #DUMMY_SEQ_LEN

summary(model, input_data=(dummy_input, dummy_lengths), col_names=["input_size", "output_size", "num_params", "trainable"], depth=4)



Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
ASRModel                                 [1, 128, 486]             [1, 120, 28]              --                        True
├─Sequential: 1-1                        [1, 1, 128, 486]          [1, 64, 128, 486]         --                        True
│    └─Conv2d: 2-1                       [1, 1, 128, 486]          [1, 32, 128, 486]         320                       True
│    └─BatchNorm2d: 2-2                  [1, 32, 128, 486]         [1, 32, 128, 486]         64                        True
│    └─GELU: 2-3                         [1, 32, 128, 486]         [1, 32, 128, 486]         --                        --
│    └─Conv2d: 2-4                       [1, 32, 128, 486]         [1, 64, 128, 486]         18,496                    True
│    └─BatchNorm2d: 2-5                  [1, 64, 128, 486]         [1, 64, 128, 486]         128                       True
│    

# Manual Testing the model

In [13]:
if os.path.exists('kenlm'):
    print("File already downloaded")
else:
    text_f = train_metadata_df['transcript'].dropna().unique().tolist() + val_metadata_df['transcript'].dropna().unique().tolist() + test_metadata_df['transcript'].dropna().unique().tolist()
    print(len(text_f))

    with open('kenlm/build/corpus_full.txt' ,'w', encoding='utf-8') as f:
        for line in text_f:
            f.write(line.strip() + '\n')
    print('corpus.txt saved sucessfully')
# all_trans = pd.concat([train_metadata_df['transcript'],val_metadata_df['transcript']])
# word_dict = set()
# for transcript in all_trans.dropna():
#     words = transcript.lower().split()
#     word_dict.update(words)
# word_dict = sorted(word_dict)

File already downloaded


In [12]:
%%time
class ASRModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim , num_layers=4, dropout=0.3,fil_dim=32):
        super(ASRModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, fil_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(fil_dim),
            nn.GELU(),            
            nn.Conv2d(fil_dim, fil_dim*2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(fil_dim*2),
            nn.GELU(),  
        )
        self.conv1d = nn.Sequential(
            nn.Conv1d(in_channels=64*input_dim, out_channels = hidden_dim, kernel_size=3, stride=1, padding=1),  
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )

        self.bi_gru = nn.GRU(
            input_size=hidden_dim, 
            hidden_size=hidden_dim, 
            num_layers=num_layers, 
            batch_first=True, 
            bidirectional=True
        )
        # self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.fc = nn.Sequential(
            nn.LayerNorm(hidden_dim*2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim) 
        )

    def forward(self, x, lengths):
        x = x.unsqueeze(1) 
        x = self.conv(x) # (Batch, channels, Features, time) ---- [32, 64, 128, 395]
        x = x.permute(0, 3, 2, 1)  # (Batch, channels, Features, time) ---- [32, 64, 128, 395]
        x = x.reshape(x.size(0), x.size(1), -1)  # Flatten for RNN input ----- [32, 395, 8192]
        x = x.permute(0, 2, 1)
        
        x = self.conv1d(x)   # ---- [32, 512, 395]
        x = x.permute(0, 2, 1)   # ---- [32, 395, 512]
        lengths = torch.tensor(lengths)
        lengths = lengths.cpu().int()
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.bi_gru(packed_x)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        return self.fc(output)
        
# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def beam_search_decoder(probs,char_map,k):
    T,V = probs.shape
    beam = [("",0.0)]
    for t in range(T):
        new_beam={}
        for seq,log_probs in beam:
            for v in range(V):                
                if v == 0:
                    new_seq = seq
                elif len(seq) > 0 and seq[-1]==char_map[v]:
                    new_seq = seq
                else:
                    new_seq = seq + char_map[v]

                new_probs = log_probs + probs[t,v]
                
                if new_seq in new_beam:
                    new_beam[new_seq] = max(new_beam[new_seq], new_probs)
                else:
                    new_beam[new_seq] = new_probs
                   
        beam = heapq.nlargest(k,new_beam.items(),key=lambda x:x[1])
    return beam[0][0]         
    
    
def transcribe(audio_path, model, char_map, device="cpu"):
    model.to(device)
    model.eval()
    
    # Load and process audio
    audio, sr = librosa.load(audio_path, sr=16000)
    audio = librosa.effects.time_stretch(audio, rate=0.9)
#     Audio(audio,rate=16000)
    mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
    mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Convert to tensor
    audio_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32).unsqueeze(0).to(device)

    start_time = time.time()
    # Get model output
    with torch.no_grad():
        output = model(audio_tensor, [audio_tensor.shape[2]])  # Fix time dimension
    probs = torch.nn.functional.log_softmax(output, dim=-1).squeeze().numpy()
#     output = output.argmax(dim=-1).squeeze().flatten().tolist()

    # Greedy decoding with blank removal
    transcription = beam_search_decoder(probs,char_map,k=1)
#     prev_char = None
#     blank_idx = 0  # Adjust if your blank token index is different
    
#     for idx in output:
#         if idx != prev_char and idx != blank_idx and idx in char_map:
#             transcription.append(char_map[idx])
#         prev_char = idx  
        
    end_time = time.time()
    inference_time = end_time - start_time
    
    return "".join(transcription), inference_time
        

# Load the model
input_dim = 128 
hidden_dim = 512
output_dim = 28 
model = ASRModel(input_dim, hidden_dim, output_dim)
model.load_state_dict(torch.load("models/augmented/conv1d/asr_model--3.pth", map_location="cpu"))
model.eval()


# Test transcription
test_audio = "aug_audio"
transcription, latency = transcribe(test_audio, model, idx_to_char)


print(f"Transcription: {transcription}")
print(f"Inference Time: {latency*1000:.1f} ms")

Transcription: thes a wo icand metsouraiastion energy and wais conperform an wmiol tat charche pardical in ther pat it is unecesary for on o discusions to esamine detailed auntitetif elect rom mavenodicrerlationships a which are defined by maxfwals ecquations a 
Inference Time: 388.4 ms
CPU times: user 7.91 s, sys: 1.1 s, total: 9.01 s
Wall time: 1.02 s


# Language model

In [14]:
def get_lm_score(seq,lm,backoff):
    if seq in lm:
        return lm[seq]
    if len(seq) > 1:
        backoff_word = seq[:-1]
        if backoff_word in backoff:
            return backoff[backoff_word] + get_lm_score(seq[1:], lm, backoff)
    return -10         

def get_best_lm_replacement(word,context, lm, backoff,thres=2.0):
    best_score = float('-inf')
    best_word = word  
    closest_match = difflib.get_close_matches(word, lm.keys(), n=1, cutoff=0.85)  # You can adjust cutoff
    if closest_match:
        best_word = closest_match[0]
        best_score = lm[best_word]
    else:
        best_word = word
#     for candidate in closest_match:
#         candidate_ngram = str(context) + " " + candidate
#         score = get_lm_score(candidate_ngram, lm, backoff)
#         if score > best_score:
#             best_word = candidate
#             best_score = score        
    return best_word

def post_process_with_lm(decoded_seq, lm, backoff):
    words = decoded_seq.strip().split()
    final_transcription = []
    for i,word in enumerate(words):
        word = word.strip()
        context = " ".join(final_transcription[max(0,i-2):i])
        
        if word in lm: 
#             lm_score = get_lm_score(context + word,lm,backoff)
            final_transcription.append(word)  # Keep the word if it's correct
        else:
#             If the word is not in the LM, we look for the best replacement word from the LM
            best_replacement = get_best_lm_replacement(word,context, lm, backoff)
            final_transcription.append(best_replacement) 
    return " ".join(final_transcription)

def ctc_decode_lm(output, idx_to_char):    
    pred_indices = output.argmax(dim=-1).cpu().numpy()  
    decoded_text = []
    
    for seq in pred_indices:  # Iterate over batch
        chars = []
        prev_char = None  
        for idx in seq:
            if idx != 0 and idx != prev_char:  
                chars.append(idx_to_char[idx])
            prev_char = idx
        decoded_text.append("".join(chars))  
    final_transcription = [post_process_with_lm(text, lm, backoff) for text in decoded_text]
    return final_transcription


test_metadata_df = pd.read_csv(METADATA_FILE_TEST)
test_metadata_df = test_metadata_df.drop_duplicates()
test_dataset = ASRDataset(test_metadata_df, char_to_idx)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)
predictions, references = [], []
epoch = 0
writer = SummaryWriter()
model = model.to(DEVICE)
with torch.no_grad():
    for spectrograms, spectrogram_lengths, texts, text_lengths in tqdm(test_loader):
        outputs = model(spectrograms.to(DEVICE), spectrogram_lengths.to(DEVICE))
#         probs = torch.nn.functional.log_softmax(outputs, dim=-1).squeeze().numpy()
        outputs = outputs.log_softmax(dim=-1).transpose(b0, 1) # (time, batch, vocab)                   
        outputs = outputs.transpose(0, 1)  # correct shape for CER (batch, time, vocab)
        decoded_preds = ctc_decode_lm(outputs, idx_to_char)
        decoded_refs = ["".join(idx_to_char[idx.item()] for idx in text if idx.item() != 0) for text in texts] 
        predictions.extend(decoded_preds)
        references.extend(decoded_refs)
    df_pred = pd.DataFrame({'prediction':predictions,'reference':references})
    df_pred.to_csv('pred_vs_ref.csv',index=False,encoding="utf-8")
    asr_metrics.log_metrics(epoch, predictions, references , writer)

100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [10:22<00:00, 38.89s/it]

Epoch 1: CER: 0.0517, WER: 0.1571





In [94]:
%%time 
import heapq
import math
import difflib
from torchaudio.models.decoder import ctc_decoder
from pyctcdecode import build_ctcdecoder
# from flashlight.lib.text.dictionary import Dictionary
import kenlm

# from flashlight.lib.text.decoder import KenLM, LexiconDecoder, LexiconDecoderOptions,CriterionType
hotwords = ['pala' , 'vinayak','model' ,'improvement']

model = ASRModel(input_dim,hidden_dim,output_dim)
model.load_state_dict(torch.load("models/augmented/asr_model--11.pth"))
model.eval()

# token_dict = Dictionary()
# for idx, token in enumerate(vocab):
#     token_dict.add_entry(token)
# lm = kenlm.Model('kenlm/build/3gram.binary')
# lm = KenLM("kenlm/build/3gram.binary" , vocab)
LM_WEIGHT = 3.23
WORD_SCORE = -0.26

beam_search_decoder =ctc_decoder(
    lexicon=None,
    tokens=vocab,
    lm='kenlm/build/3gram.arpa',
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
    blank_token=" ",
    sil_token=" ",
)


def replace_with_hotwords(sequence, hotwords):
    words = sequence.split()
    updated_words = []
    for word in words:
        closest_match = difflib.get_close_matches(word, hotwords, n=1, cutoff=0.75)
        if closest_match:
            updated_words.append(closest_match[0])  # Replace with hotword
        else:
            updated_words.append(word)

    return " ".join(updated_words)

# def beam_search_decoder(probs, char_map, k, hotwords=None, hotword_weight=10.0):
#     if hotwords is None:
#         hotwords = []
#     hotwords = set(hotwords)
#     T, V = probs.shape
#     beam = [("", 0.0)]
    
#     for t in range(T):
#         new_beam = {}
        
#         for seq, log_prob in beam:
#             for v in range(V):
#                 new_seq = seq
                
#                 if v != 0:  # Ignore blank index
#                     if len(seq) > 0 and seq[-1] == char_map[v]:
#                         new_seq = seq  # Repeat character, no addition
#                     else:
#                         new_seq = seq + char_map[v]
                
#                 # Adjust probability if the word is a hotword
#                 new_prob = log_prob + probs[t, v]
                
#                 for hotword in hotwords:
#                     if new_seq.endswith(hotword):
#                         new_prob += hotword_weight
# #                     elif hotword.startswith(new_seq):
# #                         new_prob += hotword_weight/2
# #                     elif hotword in new_seq:
# #                         new_prob += hotword_weight/3
                
#                 # Keep best probability for each sequence
#                 if new_seq in new_beam:
#                     new_beam[new_seq] = max(new_beam[new_seq], new_prob)
#                 else:
#                     new_beam[new_seq] = new_prob
                   
#         # Keep top k sequences
#         beam = heapq.nlargest(k, new_beam.items(), key=lambda x: x[1])
#         best_sequence = beam[0][0]
#         best_sequence = replace_with_hotwords(best_sequence, hotwords)
    
#     return best_sequence # Return best sequence

audio_p = 'testing_audio/into.mp3'
audio,sr = librosa.load(audio_p,sr=16000)
melspec = librosa.feature.melspectrogram(y=audio,n_mels=128,sr=sr)
melspec = librosa.power_to_db(melspec , ref = np.max)
audio_tensor = torch.tensor(melspec, dtype = torch.float32).unsqueeze(0)
with torch.no_grad():
    output = model(audio_tensor , [audio_tensor.shape[2]])
output = torch.nn.functional.log_softmax(output,dim=-1 ).cpu()

# emissions = output.squeeze(0)
emissions = torch.tensor(output, dtype=torch.float32)
# emissions = np.ascontiguousarray(emissions)
# T,N = emissions.shape
# results = decoder.decode(emissions.ctypes.data_as(np.ctypeslib.as_ctypes(np.float32)), T, N) 
# best_transcription = results[0].tokens
# print(best_transcription)
# transcription = beam_search_decoder(output, idx_to_char, k=20, hotwords=hotwords, hotword_weight=10.0)
beam_search_result = beam_search_decoder(emissions)
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()

print(f"Transcript: {beam_search_transcript}")

# decoder = build_ctcdecoder(vocab)
# text = decoder.decode(
#     output,
#     hotwords=hotwords,
#     hotword_weight = 10.0,
# )
# print(transcription)

Loading the LM will be faster if you build a binary file.
Reading kenlm/build/3gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Transcript: 
CPU times: user 5.9 s, sys: 1.07 s, total: 6.97 s
Wall time: 1.76 s


#  old methods

In [None]:
import torch
import librosa
import numpy as np

# Import your ASR model
class ASRModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim , dropout=0.3):
        super(ASRModel, self).__init__()

        # Convolutional Layers
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
        )


        # 4-layer Bi-GRU
        self.bi_gru = nn.GRU(
            input_size=64 * input_dim, 
            hidden_size=hidden_dim, 
            num_layers=3,  # Increased to 4 layers
            batch_first=True, 
            bidirectional=True
        )

        # Fully Connected Layer
        # self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim) 
        )

    def forward(self, x, lengths):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv(x)
        x = x.permute(0, 3, 2, 1)  # (Batch, Time, Features, Channels)
        x = x.reshape(x.size(0), x.size(1), -1)  # Flatten for RNN input
        lengths = torch.tensor(lengths)
        lengths = lengths.cpu().int()
        
        
        # Pack padded sequence for efficiency
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.bi_gru(packed_x)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        return self.fc(output)

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# def transcribe(audio_path, model, char_map, device="cpu"):
#     model.to(device)
#     model.eval()
    
#     # Load and process audio
#     audio, sr = librosa.load(audio_path, sr=16000)
#     mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128, fmax=8000)
#     mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
    
#     # Convert to tensor
#     audio_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32).unsqueeze(0).to(device)
    
#     # Get model output
#     with torch.no_grad():
#         output = model(audio_tensor, [audio_tensor.shape[1]])
#     print(output)
#     # Convert to characters
#     output = output.argmax(dim=-1).squeeze(0).tolist()
#     transcription = "".join([char_map[idx] for idx in output if idx in char_map])
    
#     return transcription
def transcribe(audio_path, model, char_map, device="cpu"):
    model.to(device)
    model.eval()
    
    # Load and process audio
    audio, sr = librosa.load(audio_path, sr=16000)
    mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128, fmax=8000)
    mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
    
    # Convert to tensor
    audio_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32).unsqueeze(0).to(device)
    
    # Get model output
    with torch.no_grad():
        output = model(audio_tensor, [audio_tensor.shape[2]])  # Fix time dimension

    # Take argmax and ensure 1D list
    output = output.argmax(dim=-1).squeeze().flatten().tolist()

    # Greedy decoding with blank removal
    transcription = []
    prev_char = None
    blank_idx = 0  # Adjust if your blank token index is different
    
    for idx in output:
        if idx != prev_char and idx != blank_idx and idx in char_map:
            transcription.append(char_map[idx])
        prev_char = idx  

    return "".join(transcription)
                      
# Load the model
input_dim = 128 
hidden_dim = 256 
output_dim = 28 
model = ASRModel(input_dim, hidden_dim, output_dim)
model.load_state_dict(torch.load("models/asr_model--8.pth", map_location="cpu"))
model.eval()


# Test transcription
test_audio = "wavs/LJ007-0198.wav"
print("Transcription:", transcribe(test_audio, model, idx_to_char))


In [None]:
the windows were to be glazed and painted to prevent prisoners from looking out

In [None]:
priting in the only sense with which we are at presene concerned differs fror most ivf not from all the arts incrafts represented in the exivition

In [None]:
%%time
from torchaudio.models.decoder import ctc_decoder
    
LM_WEIGHT = 1.5
WORD_SCORE = -0.26

beam_search_decoder = ctc_decoder(
    lexicon=None, 
    tokens= vocab,
    lm= 'test.arpa',
    nbest=3,
    beam_size=10,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
    blank_token=' ',
    sil_token = ' '
)


model = ASRModel(input_dim,hidden_dim,output_dim)
model.load_state_dict(torch.load("models/augmented/asr_model--7.pth"))
model.eval()

audio_path = 'testing_audio/trouble.mp3'
audio , sr = librosa.load(audio_path, sr=16000)
mel_spec = librosa.feature.melspectrogram(y=audio,n_mels=128,sr=sr)
mel_spec = librosa.power_to_db(mel_spec, ref = np.max)
audio_tensor = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
    output = model(audio_tensor, [audio_tensor.shape[2]])
emissions  = torch.nn.functional.log_softmax(output, dim=-1).squeeze()
emissions = emissions.unsqueeze(0)
result = beam_search_decoder(emissions)

token_str = "".join(beam_search_decoder.idxs_to_tokens(result[0][0].tokens))
transcript = " ".join(token_str.strip())

print(transcript)

In [None]:
char_to_idx

In [None]:
import Levenshtein
def calculate_cer(reference: str, hypothesis: str) -> float:
    reference = reference.replace(" ", "")  # CER ignores spaces
    hypothesis = hypothesis.replace(" ", "")    
    cer = Levenshtein.distance(reference, hypothesis) / max(1, len(reference))
    return cer

def calculate_wer(reference: str, hypothesis: str) -> float:
    ref_words = reference.split()
    hyp_words = hypothesis.split()    
    wer = Levenshtein.distance(ref_words, hyp_words) / max(1, len(ref_words))
    return wer

reference_text = "We can measure radiation energy and waves can perform and move a charged particle in their path It is unnecessary for quantum discussions to examine detailed quantitative electromagnetic relationships which are defined by Maxwells equations"
hypothesis_text = "witc an measure radiition energy and weofs con performan wol a charged particle in their pat it is unecessary for qoune of discussions to examine detailed quantity of electrom agneatic relationships  which are defied by maxwallsequations "
cer = calculate_cer(reference_text, hypothesis_text)
wer = calculate_wer(reference_text, hypothesis_text)

print(f"CER: {cer:.4f}, WER: {wer:.4f}") #

In [56]:
def mask_asr_errors(asr_sentence, correct_sentence):
    asr_words = asr_sentence.split()
    correct_words = correct_sentence.split()
    
    # Identify incorrect words using SequenceMatcher
    matcher = difflib.SequenceMatcher(None, asr_words, correct_words)
    
    masked_sentence = asr_words.copy()
    
    for opcode, i1, i2, j1, j2 in matcher.get_opcodes():
        if opcode in ["replace", "insert", "delete"]:
            # Replace incorrect words with [MASK]
            for idx in range(i1, i2):
                masked_sentence[idx] = "[MASK]"
    return " ".join(masked_sentence), correct_sentence

pred_df = pd.read_csv('pred_vs_ref.csv')
pred_df = pred_df.dropna()
processed_data = [mask_asr_errors(row["prediction"], row["reference"]) for _,row in pred_df.iterrows()]

df_mask = pd.DataFrame(processed_data, columns=["masked_input", "target_output"])
print(df_mask)


                                           masked_input  \
0                         in being comparatively modern   
1     the invention of movable metal letters in the ...   
2     and it is worth mention in passing that as an ...   
3     now as all books not primarily intended as [MA...   
4     especially as no more time is occupied or cost...   
...                                                 ...   
2612  the secret service has been receiving full coo...   
2613  even if the [MASK] [MASK] [MASK] resources of ...   
2614  with the office of science and technology and ...   
2615  made certain recommendations which it believes...   
2616  as has been pointed out the commission has not...   

                                          target_output  
0                         in being comparatively modern  
1     the invention of movable metal letters in the ...  
2     and it is worth mention in passing that as an ...  
3     now as all books not primarily intended as pic...  
4

In [6]:
from transformers import AlbertTokenizer, AlbertForMaskedLM
import torch

# Load pre-trained ALBERT model and tokenizer
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

def correct_asr_text(text):
    """
    Takes noisy ASR text and corrects it using ALBERT's masked language modeling.
    """
    # Tokenize the input sentence
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs).logits

    # Decode predicted sentence
    corrected_text = tokenizer.decode(torch.argmax(outputs, dim=2)[0], skip_special_tokens=True)

    return corrected_text

# Example: Noisy ASR Output
noisy_text = "I no their going to the park."
corrected_text = correct_asr_text(noisy_text)

print("Noisy ASR Output:", noisy_text)
print("Corrected Sentence:", corrected_text)


Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForMaskedLM: ['albert.pooler.bias', 'albert.pooler.weight']
- This IS expected if you are initializing AlbertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Noisy ASR Output: I no their going to the park.
Corrected Sentence: my i am theyre going to the park!!! your
