In [1]:
import torch
import torchaudio
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram, Resample, TimeMasking, FrequencyMasking, MFCC
from jiwer import wer
import numpy as np

In [2]:
def load_lists_from_files(audio_file_path='audio_files.txt', transcript_file_path='transcriptions.txt'):
    with open(audio_file_path, 'r') as af:
        audio_files = [line.strip() for line in af]

    with open(transcript_file_path, 'r') as tf:
        transcriptions = [line.strip() for line in tf]

    return audio_files, transcriptions

# Load the lists back from text files
audio_files, transcriptions = load_lists_from_files()

In [15]:
# Create a character-level vocabulary
vocab = set(''.join(transcriptions))
vocab = {char: idx for idx, char in enumerate(sorted(vocab))}
blank_token = len(vocab)
vocab['<blank>'] = blank_token
idx_to_char = {idx: char for char, idx in vocab.items()}
print(vocab)

# Function to convert transcription to numerical labels
def text_to_labels(text):
    return [vocab[char] for char in text]

# Function to convert labels to text
def labels_to_text(labels):
    return ''.join([idx_to_char[idx] for idx in labels if idx in idx_to_char])

{' ': 0, "'": 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'J': 11, 'K': 12, 'L': 13, 'M': 14, 'N': 15, 'O': 16, 'P': 17, 'Q': 18, 'R': 19, 'S': 20, 'T': 21, 'U': 22, 'V': 23, 'W': 24, 'X': 25, 'Y': 26, 'Z': 27, '<blank>': 28}


In [4]:
# SpeechDataset class definition with spectrogram augmentation and MFCC
class SpeechDataset(Dataset):
    def __init__(self, audio_files, transcriptions, sample_rate=16000):
        self.audio_files = audio_files
        self.transcriptions = transcriptions
        self.sample_rate = sample_rate
        self.resample = Resample(orig_freq=sample_rate, new_freq=16000)
        self.melspec = MelSpectrogram(sample_rate=16000, n_mels=128)
        self.mfcc = MFCC(sample_rate=16000, n_mfcc=40)
        self.time_masking = TimeMasking(time_mask_param=30)
        self.freq_masking = FrequencyMasking(freq_mask_param=15)
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        try:
            waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        except Exception as e:
            print(f"Error loading file {self.audio_files[idx]}: {e}")
            return None, None  # Return None for both to handle it later
        
        waveform = self.resample(waveform)
        mel_spec = self.melspec(waveform)
        mel_spec = self.time_masking(mel_spec)
        mel_spec = self.freq_masking(mel_spec)
        mfcc = self.mfcc(waveform)

         # Ensure mel_spec and mfcc have the same size along the concatenation dimension
        if mel_spec.size(2) > mfcc.size(2):
            mel_spec = mel_spec[:, :, :mfcc.size(2)]
        elif mfcc.size(2) > mel_spec.size(2):
            mfcc = mfcc[:, :, :mel_spec.size(2)]
        
        features = torch.cat((mel_spec, mfcc), dim=1)
        transcription = self.transcriptions[idx]
        return features.squeeze(0).transpose(0, 1), transcription  # Transpose to [seq_len, feature_dim]

In [5]:
# Function to pad sequences
def pad_sequence(batch):
    batch = [item for item in batch]  # Ensure [seq_len, feature_dim]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    return batch

In [6]:
# Update the collate function to handle None entries
def collate_fn(batch):
    batch = [item for item in batch if item[0] is not None]  # Filter out None entries
    if len(batch) == 0:  # Handle the case where all items are None
        return None, None, None, None
    mel_specs = [item[0] for item in batch]
    transcriptions = [item[1] for item in batch]
    mel_specs_padded = pad_sequence(mel_specs)
    labels = [torch.tensor(text_to_labels(t)) for t in transcriptions]
    label_lengths = torch.tensor([len(label) for label in labels])
    labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
    input_lengths = torch.tensor([mel_spec.size(0) for mel_spec in mel_specs_padded])
    return mel_specs_padded, labels_padded, input_lengths, label_lengths

In [7]:
# Create dataset and dataloader with the expanded dataset
dataset = SpeechDataset(audio_files, transcriptions)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)



In [8]:
# Define a simple model architecture
class EnhancedSpeechModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(EnhancedSpeechModel, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(64, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # Bidirectional LSTM

    def forward(self, x):
        x = x.transpose(1, 2)  # Change to [batch_size, feature_dim, seq_len] for Conv1d
        x = self.conv1(x)
        x = self.relu(x)
        x = x.transpose(1, 2)  # Change back to [batch_size, seq_len, feature_dim] for LSTM
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

In [9]:
# Initialize model
input_dim = 128 + 40  # Number of mel bands + number of MFCCs
hidden_dim = 256
output_dim = len(vocab) + 1  # Output dimension based on the size of the vocabulary + 1 for the blank token
model = EnhancedSpeechModel(input_dim, hidden_dim, output_dim)

In [10]:
# Training with Early Stopping and Learning Rate Scheduler
def train_model(model, dataloader, num_epochs=100, learning_rate=0.001, patience=10):
    criterion = nn.CTCLoss(blank=output_dim - 1, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5, verbose=True)
    best_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for batch in dataloader:
            if batch[0] is None:  # Skip if batch is None
                continue
            
            mel_specs, labels, input_lengths, label_lengths = batch
            
            if mel_specs.dim() != 3:
                print(f"Unexpected dimensions: {mel_specs.shape}")
                continue  # Skip this batch if dimensions are not as expected
            
            outputs = model(mel_specs)
            outputs = outputs.log_softmax(2)
            outputs = outputs.permute(1, 0, 2)  # (T, N, C) for CTCLoss

            loss = criterion(outputs, labels, input_lengths, label_lengths)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        epoch_loss /= len(dataloader)
        scheduler.step(epoch_loss)
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
        
        # Early stopping
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

In [11]:
# Train the model with the expanded dataset
train_model(model, dataloader)



Epoch [1/100], Loss: 3.8645
Epoch [2/100], Loss: 3.0191
Epoch [3/100], Loss: 2.9379
Epoch [4/100], Loss: 2.8875
Epoch [5/100], Loss: 2.8284
Epoch [6/100], Loss: 2.8199
Epoch [7/100], Loss: 2.7812
Epoch [8/100], Loss: 2.7701
Epoch [9/100], Loss: 2.7386
Epoch [10/100], Loss: 2.7213
Epoch [11/100], Loss: 2.7088
Epoch [12/100], Loss: 2.6705
Epoch [13/100], Loss: 2.6524
Epoch [14/100], Loss: 2.6426
Epoch [15/100], Loss: 2.6258
Epoch [16/100], Loss: 2.6104
Epoch [17/100], Loss: 2.5920
Epoch [18/100], Loss: 2.4275
Epoch [19/100], Loss: 2.1562
Epoch [20/100], Loss: 1.9644
Epoch [21/100], Loss: 1.8336
Epoch [22/100], Loss: 1.7332
Epoch [23/100], Loss: 1.6490
Epoch [24/100], Loss: 1.5743
Epoch [25/100], Loss: 1.5048
Epoch [26/100], Loss: 1.4428
Epoch [27/100], Loss: 1.3834
Epoch [28/100], Loss: 1.3262
Epoch [29/100], Loss: 1.2776
Epoch [30/100], Loss: 1.2216
Epoch [31/100], Loss: 1.1689
Epoch [32/100], Loss: 1.1177
Epoch [33/100], Loss: 1.0731
Epoch [34/100], Loss: 1.0260
Epoch [35/100], Loss: 0

In [17]:
# Function to decode model outputs into text, handling the blank token
def decode_predictions(predictions, vocab):
    idx_to_char = {idx: char for char, idx in vocab.items()}
    blank_token = len(vocab) - 1
    decoded_output = []
    for prediction in predictions:
        pred_indices = torch.argmax(prediction, dim=-1)
        pred_text = ''.join([idx_to_char[idx.item()] for idx in pred_indices if idx.item() != blank_token])
        decoded_output.append(pred_text)
    return decoded_output

In [18]:
# Function to evaluate the model on a dataset and calculate WER
def evaluate_model(model, dataloader, vocab):
    model.eval()
    total_wer = 0
    num_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            if batch[0] is None:  # Skip if batch is None
                continue
            
            mel_specs, labels, input_lengths, label_lengths = batch
            outputs = model(mel_specs)
            decoded_output = decode_predictions(outputs, vocab)
            ground_truth = [''.join([idx_to_char[idx.item()] for idx in label if idx.item() in idx_to_char]) for label in labels]
            
            for pred_text, true_text in zip(decoded_output, ground_truth):
                total_wer += wer(true_text, pred_text)
                num_samples += 1
    
    avg_wer = total_wer / num_samples
    print(f'Average WER: {avg_wer:.4f}')

In [28]:
# Evaluate the model with the created dataloader
evaluate_model(model, dataloader, vocab)

Average WER: 0.7480


In [29]:
# Function to decode the model output into text
def decode_output(output):
    _, max_indices = torch.max(output, dim=-1)
    tokens = max_indices.unique_consecutive()
    decoded = ''.join([idx_to_char[idx.item()] for idx in tokens if idx.item() in idx_to_char])
    return decoded

In [30]:
# Calculate accuracy using Word Error Rate (WER)
def calculate_accuracy(model, dataloader):
    model.eval()
    predictions = []
    ground_truths = []
    
    with torch.no_grad():
        for batch in dataloader:
            if batch[0] is None:  # Skip if batch is None
                continue
            
            mel_specs, labels, input_lengths, label_lengths = batch
            
            outputs = model(mel_specs)
            outputs = outputs.log_softmax(2)
            outputs = outputs.permute(1, 0, 2)  # (T, N, C) for CTCLoss
            
            for i in range(outputs.size(1)):  # Iterate over batch
                decoded_output = decode_output(outputs[:, i, :])
                predictions.append(decoded_output)
                ground_truths.append(''.join([idx_to_char[idx.item()] for idx in labels[i] if idx.item() in idx_to_char]))
    
    wer_score = wer(ground_truths, predictions)
    accuracy = 1 - wer_score
    return accuracy


In [31]:
# Calculate accuracy
accuracy = calculate_accuracy(model, dataloader)
print(f"Model Accuracy: {accuracy:.4f}")

Model Accuracy: 0.0113


In [23]:
# Function to preprocess a single audio file
def preprocess_audio(audio_file, sample_rate=16000):
    resample = Resample(orig_freq=sample_rate, new_freq=16000)
    melspec = MelSpectrogram(sample_rate=16000, n_mels=128)
    mfcc = MFCC(sample_rate=16000, n_mfcc=40)
    time_masking = TimeMasking(time_mask_param=30)
    freq_masking = FrequencyMasking(freq_mask_param=15)
    
    try:
        waveform, sample_rate = torchaudio.load(audio_file)
    except Exception as e:
        print(f"Error loading file {audio_file}: {e}")
        return None
    
    waveform = resample(waveform)
    mel_spec = melspec(waveform)
    mel_spec = time_masking(mel_spec)
    mel_spec = freq_masking(mel_spec)
    mfcc_feat = mfcc(waveform)

    # Ensure mel_spec and mfcc have the same size along the concatenation dimension
    if mel_spec.size(2) > mfcc_feat.size(2):
        mel_spec = mel_spec[:, :, :mfcc_feat.size(2)]
    elif mfcc_feat.size(2) > mel_spec.size(2):
        mfcc_feat = mfcc_feat[:, :, :mel_spec.size(2)]

    features = torch.cat((mel_spec, mfcc_feat), dim=1)
    return features.squeeze(0).transpose(0, 1)  # Transpose to [seq_len, feature_dim]

In [24]:
# Function to transcribe a single audio file using the trained model
def transcribe_audio(model, audio_file, vocab):
    model.eval()
    with torch.no_grad():
        features = preprocess_audio(audio_file)
        if features is None:
            return None
        features = features.unsqueeze(0)  # Add batch dimension
        outputs = model(features)
        decoded_output = decode_predictions(outputs, vocab)
        return decoded_output[0]

In [27]:
# Example usage
audio_file = "31-121972-0000.wav"  # Replace with the path to your audio file
transcription = transcribe_audio(model, audio_file, vocab)
print(f"Transcription: {transcription}")

Transcription: HOL   ER  VE    WHE WONN  HHHONMD  BETE  BEENNNTHAY GRRRETE THEN  HE MMMYTTT  A IISS SIF SHHE  CONNTEY MMA TWODDD DIIT  MMANND   NOK   A DANNT  CHOED  YUUU   ONT TOULD   I  SHSHHANNTITUU  PICTSLEE   O   OA NNDDDDI WWEETTS    SOY  HE I APEDASA  IINN THHE MMOUINN   HISAEEE  MOW
