In [1]:
import torch
import torchaudio
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram, Resample
from jiwer import wer

In [2]:
def load_lists_from_files(audio_file_path='audio_files.txt', transcript_file_path='transcriptions.txt'):
    with open(audio_file_path, 'r') as af:
        audio_files = [line.strip() for line in af]

    with open(transcript_file_path, 'r') as tf:
        transcriptions = [line.strip() for line in tf]

    return audio_files, transcriptions

# Load the lists back from text files
audio_files, transcriptions = load_lists_from_files()

In [3]:
# Create a character-level vocabulary
vocab = set(''.join(transcriptions))
vocab = {char: idx for idx, char in enumerate(sorted(vocab))}
idx_to_char = {idx: char for char, idx in vocab.items()}
blank_token_idx = len(vocab)
print(vocab)

# Function to convert transcription to numerical labels
def text_to_labels(text):
    return [vocab[char] for char in text]

# Function to convert labels to text
def labels_to_text(labels):
    return ''.join([idx_to_char[idx] for idx in labels if idx in idx_to_char])

{' ': 0, "'": 1, 'A': 2, 'B': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'J': 11, 'K': 12, 'L': 13, 'M': 14, 'N': 15, 'O': 16, 'P': 17, 'Q': 18, 'R': 19, 'S': 20, 'T': 21, 'U': 22, 'V': 23, 'W': 24, 'X': 25, 'Y': 26, 'Z': 27}


In [4]:
# SpeechDataset class definition
class SpeechDataset(Dataset):
    def __init__(self, audio_files, transcriptions, sample_rate=16000):
        self.audio_files = audio_files
        self.transcriptions = transcriptions
        self.sample_rate = sample_rate
        self.resample = Resample(orig_freq=sample_rate, new_freq=16000)
        self.melspec = MelSpectrogram(sample_rate=16000, n_mels=128)
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        try:
            waveform, sample_rate = torchaudio.load(self.audio_files[idx])
        except Exception as e:
            print(f"Error loading file {self.audio_files[idx]}: {e}")
            return None, None  # Return None for both to handle it later
        
        waveform = self.resample(waveform)
        mel_spec = self.melspec(waveform)
        transcription = self.transcriptions[idx]
        return mel_spec.squeeze(0).transpose(0, 1), transcription  # Transpose to [seq_len, feature_dim]

In [5]:
# Function to pad sequences
def pad_sequence(batch):
    batch = [item for item in batch]  # Ensure [seq_len, feature_dim]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    return batch

In [6]:
# Update the collate function to handle None entries
def collate_fn(batch):
    batch = [item for item in batch if item[0] is not None]  # Filter out None entries
    if len(batch) == 0:  # Handle the case where all items are None
        return None, None, None, None
    mel_specs = [item[0] for item in batch]
    transcriptions = [item[1] for item in batch]
    mel_specs_padded = pad_sequence(mel_specs)
    labels = [torch.tensor(text_to_labels(t)) for t in transcriptions]
    label_lengths = torch.tensor([len(label) for label in labels])
    labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
    input_lengths = torch.tensor([mel_spec.size(0) for mel_spec in mel_specs_padded])
    return mel_specs_padded, labels_padded, input_lengths, label_lengths

In [7]:
# Create dataset and dataloader with the expanded dataset
dataset = SpeechDataset(audio_files, transcriptions)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)



In [8]:
# Define a simple model architecture
class SimpleSpeechModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleSpeechModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # Bidirectional LSTM

    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

In [9]:
# Initialize and train the model
input_dim = 128  # Number of mel bands
hidden_dim = 256
output_dim = len(vocab) + 1  # Output dimension based on the size of the vocabulary + 1 for the blank token
model = SimpleSpeechModel(input_dim, hidden_dim, output_dim)

In [10]:
def train_model(model, dataloader, num_epochs=100, learning_rate=0.001):
    criterion = nn.CTCLoss(blank=output_dim - 1, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            if batch[0] is None:  # Skip if batch is None
                continue
            
            mel_specs, labels, input_lengths, label_lengths = batch
            
            if mel_specs.dim() != 3:
                print(f"Unexpected dimensions: {mel_specs.shape}")
                continue  # Skip this batch if dimensions are not as expected
            
            outputs = model(mel_specs)
            outputs = outputs.log_softmax(2)
            outputs = outputs.permute(1, 0, 2)  # (T, N, C) for CTCLoss

            loss = criterion(outputs, labels, input_lengths, label_lengths)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

In [11]:
# Train the model with the expanded dataset
train_model(model, dataloader)

Epoch [1/100], Loss: 4.4631
Epoch [2/100], Loss: 2.9693
Epoch [3/100], Loss: 2.9224
Epoch [4/100], Loss: 2.9113
Epoch [5/100], Loss: 2.8445
Epoch [6/100], Loss: 2.8095
Epoch [7/100], Loss: 2.7201
Epoch [8/100], Loss: 2.6751
Epoch [9/100], Loss: 2.6394
Epoch [10/100], Loss: 2.6314
Epoch [11/100], Loss: 2.6377
Epoch [12/100], Loss: 2.6178
Epoch [13/100], Loss: 2.5222
Epoch [14/100], Loss: 2.4607
Epoch [15/100], Loss: 2.3946
Epoch [16/100], Loss: 2.3194
Epoch [17/100], Loss: 2.5675
Epoch [18/100], Loss: 2.5583
Epoch [19/100], Loss: 2.4490
Epoch [20/100], Loss: 2.4033
Epoch [21/100], Loss: 2.3128
Epoch [22/100], Loss: 2.2359
Epoch [23/100], Loss: 2.1699
Epoch [24/100], Loss: 2.1086
Epoch [25/100], Loss: 2.0504
Epoch [26/100], Loss: 1.9871
Epoch [27/100], Loss: 1.9298
Epoch [28/100], Loss: 1.8696
Epoch [29/100], Loss: 1.8109
Epoch [30/100], Loss: 1.7642
Epoch [31/100], Loss: 1.7046
Epoch [32/100], Loss: 1.6584
Epoch [33/100], Loss: 1.6056
Epoch [34/100], Loss: 1.5487
Epoch [35/100], Loss: 1

In [15]:
# Function to decode the model output into text
def decode_output(output):
    _, max_indices = torch.max(output, dim=-1)
    tokens = max_indices.unique_consecutive()
    decoded = ''.join([idx_to_char[idx.item()] for idx in tokens if idx.item() in idx_to_char])
    return decoded

In [16]:
# Calculate accuracy using Word Error Rate (WER)
def calculate_accuracy(model, dataloader):
    model.eval()
    predictions = []
    ground_truths = []
    
    with torch.no_grad():
        for batch in dataloader:
            if batch[0] is None:  # Skip if batch is None
                continue
            
            mel_specs, labels, input_lengths, label_lengths = batch
            
            outputs = model(mel_specs)
            outputs = outputs.log_softmax(2)
            outputs = outputs.permute(1, 0, 2)  # (T, N, C) for CTCLoss
            
            for i in range(outputs.size(1)):  # Iterate over batch
                decoded_output = decode_output(outputs[:, i, :])
                predictions.append(decoded_output)
                ground_truths.append(''.join([idx_to_char[idx.item()] for idx in labels[i] if idx.item() in idx_to_char]))
    
    wer_score = wer(ground_truths, predictions)
    accuracy = 1 - wer_score
    return accuracy


In [18]:
# Calculate accuracy
accuracy = calculate_accuracy(model, dataloader)
print(f"Model Accuracy: {accuracy:.4f}")

Model Accuracy: 0.6648
