In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchaudio.datasets import LIBRISPEECH
from torchaudio.transforms import MelSpectrogram, Resample
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from torchaudio.transforms import TimeMasking, FrequencyMasking

: 

In [None]:
# Hyperparameters
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 1e-3
SAMPLE_RATE = 16000
N_MELS = 128
HIDDEN_SIZE = 512
NUM_LAYERS = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATIENCE = 5  # Early stopping patience
SEED = 42

# Seed torch's global RNG (CPU and CUDA) so weight initialisation and DataLoader
# shuffling are reproducible under Restart & Run All.
torch.manual_seed(SEED)

# Character Mapping: 'a'..'z' -> 0..25, ' ' -> 26, and the CTC blank takes the next index.
char_map = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz ")}
BLANK_IDX = len(char_map)  # 27
char_map[''] = BLANK_IDX  # CTC blank token; maps back to '' so it vanishes when decoding
idx_map = {i: c for c, i in char_map.items()}
NUM_CLASSES = len(char_map)  # 26 letters + space + blank for CTC = 28

In [None]:
# Data Processing with Augmentation
class AudioProcessor:
    """Turn a waveform into a mel spectrogram, optionally with SpecAugment-style masking.

    ``squeeze(0)`` drops a leading channel axis, so a mono ``[1, samples]`` waveform
    (presumably what the dataset yields — confirm for other sources) becomes a
    ``[n_mels, time]`` tensor.

    Args:
        augment: when True (the default) every call applies random time masking
            (``time_mask_param=80``) and frequency masking (``freq_mask_param=30``).
            Use False for validation/inference so evaluation inputs are unaugmented
            and deterministic. Can be overridden per call.

    NOTE(review): no log compression is applied to the mel energies. Log-mel
    features are the usual ASR input; consider adding it.
    """

    def __init__(self, augment=True):
        self.mel_spec = MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=N_MELS)
        # Kept as a public attribute for compatibility; reused for 48 kHz input.
        self.resample = Resample(orig_freq=48000, new_freq=SAMPLE_RATE)
        self.time_mask = TimeMasking(time_mask_param=80)
        self.freq_mask = FrequencyMasking(freq_mask_param=30)
        self.augment = augment
        # A Resample's filter depends on orig_freq, so keep one per source rate
        # instead of pushing every input through the fixed 48 kHz resampler.
        self._resamplers = {48000: self.resample}

    def _resampler_for(self, sample_rate):
        """Return (building and caching on first use) a resampler from ``sample_rate`` to SAMPLE_RATE."""
        sample_rate = int(sample_rate)
        if sample_rate not in self._resamplers:
            self._resamplers[sample_rate] = Resample(orig_freq=sample_rate, new_freq=SAMPLE_RATE)
        return self._resamplers[sample_rate]

    def __call__(self, waveform, sample_rate, augment=None):
        """Return the (optionally masked) mel spectrogram of ``waveform``.

        Args:
            waveform: audio tensor sampled at ``sample_rate``.
            sample_rate: actual rate of ``waveform``; resampled to SAMPLE_RATE if different.
            augment: overrides ``self.augment`` for this call when not None.
        """
        if sample_rate != SAMPLE_RATE:
            waveform = self._resampler_for(sample_rate)(waveform)
        mel_spec = self.mel_spec(waveform).squeeze(0)
        apply_masks = self.augment if augment is None else augment
        if apply_masks:
            mel_spec = self.time_mask(mel_spec)  # Apply time mask
            mel_spec = self.freq_mask(mel_spec)  # Apply frequency mask
        return mel_spec

def text_to_int(text):
    """Encode a transcript as a 1-D ``torch.long`` tensor of ``char_map`` indices.

    The text is lower-cased first. Any character without a ``char_map`` entry
    (digits, apostrophes and other punctuation) is silently dropped.
    """
    indices = []
    for ch in text.lower():
        if ch in char_map:
            indices.append(char_map[ch])
    return torch.tensor(indices, dtype=torch.long)

def collate_fn(batch, audio_processor=None):
    """Collate LibriSpeech items into a zero-padded CTC batch.

    Args:
        batch: list of dataset items unpacked as
            ``(waveform, sample_rate, transcript, _, _, _)``.
        audio_processor: callable ``(waveform, sample_rate) -> [n_mels, time]``.
            Defaults to the module-level ``processor``. Bind a different one with
            ``functools.partial`` (e.g. a non-augmenting processor for validation).

    Returns:
        specs: ``[batch, n_mels, max_time]`` spectrograms, zero-padded along time.
        labels: ``[batch, max_label_len]`` long tensor, zero-padded (the padding is
            ignored by CTC because ``label_lengths`` carries the true lengths).
        input_lengths: 1-D tensor of unpadded frame counts per item.
        label_lengths: 1-D tensor of unpadded label lengths per item.
    """
    featurize = processor if audio_processor is None else audio_processor
    specs, labels, input_lengths, label_lengths = [], [], [], []
    for waveform, sample_rate, text, _, _, _ in batch:
        # Pass the clip's real rate so the processor resamples when it differs from SAMPLE_RATE.
        spec = featurize(waveform, sample_rate)  # [n_mels, time]
        label = text_to_int(text)
        specs.append(spec.T)  # pad_sequence pads along dim 0, so put time first
        labels.append(label)
        input_lengths.append(spec.shape[1])  # Time dimension
        label_lengths.append(len(label))

    # [batch, max_time, n_mels] -> [batch, n_mels, max_time]
    specs = pad_sequence(specs, batch_first=True).permute(0, 2, 1)
    labels = pad_sequence(labels, batch_first=True)
    return specs, labels, torch.tensor(input_lengths), torch.tensor(label_lengths)

def int_to_text(int_seq, mapping=None):
    """Map a sequence of class indices back to a string.

    Each element is converted with ``int()`` first, so 1-D ``torch.Tensor`` or
    numpy inputs (e.g. an argmax over model outputs) work. A 0-d tensor hashes
    by identity, not by value, and would otherwise raise ``KeyError``.

    This is a plain lookup, not a CTC decoder: the blank maps to '' and so drops
    out, but repeated characters are not collapsed.

    Args:
        int_seq: iterable of integer-like class indices.
        mapping: optional index -> character dict; defaults to ``idx_map``.
    """
    lookup = idx_map if mapping is None else mapping
    return ''.join(lookup[int(i)] for i in int_seq)

In [None]:
# Model Definition with BatchNorm and Dropout
class STTModel(nn.Module):
    """2-layer CNN front end + bidirectional LSTM producing per-frame CTC logits.

    Input ``x`` is ``[batch, n_mels, time]`` (the layout ``collate_fn`` produces).
    Output is unnormalised logits ``[batch, time, num_classes]``; apply
    ``log_softmax`` before ``nn.CTCLoss``. Both convolutions use stride 1 and
    padding 1, so the time length is preserved and data-pipeline
    ``input_lengths`` stay valid.

    Constructor arguments left as None fall back to the notebook constants
    (N_MELS, HIDDEN_SIZE, NUM_LAYERS, NUM_CLASSES), so ``STTModel()`` builds the
    same network as before.
    """

    def __init__(self, n_mels=None, hidden_size=None, num_layers=None, num_classes=None):
        super().__init__()
        n_mels = N_MELS if n_mels is None else n_mels
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        num_layers = NUM_LAYERS if num_layers is None else num_layers
        num_classes = NUM_CLASSES if num_classes is None else num_classes

        # Module order and Sequential indices are kept as-is so existing
        # state_dict checkpoints (conv.0, conv.2, ..., rnn, fc) still load.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(0.2),
        )
        # forward() averages the conv channels away, so the LSTM sees n_mels features per frame.
        self.rnn = nn.LSTM(
            input_size=n_mels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x, lengths=None):
        """Compute per-frame class logits.

        Args:
            x: spectrogram batch ``[batch, n_mels, time]``.
            lengths: optional 1-D sequence of unpadded frame counts per item.
                When given, the LSTM runs on a packed sequence, so padded frames
                do not feed the backward direction or carry state. Outputs at
                padded positions are then ``fc`` applied to zero vectors. When
                None, the LSTM runs over the padded tensor as-is.

        Returns:
            Logits ``[batch, time, num_classes]`` with the same time length as ``x``.
        """
        x = x.unsqueeze(1)  # [batch, 1, freq, time]
        x = self.conv(x)  # [batch, channels, freq, time]
        x = x.permute(0, 3, 2, 1)  # [batch, time, freq, channels]
        x = x.mean(dim=-1)  # Reduce channels: [batch, time, freq]
        if lengths is None:
            x, _ = self.rnn(x)
        else:
            from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

            total_time = x.size(1)
            # pack_padded_sequence requires lengths as a CPU int64 tensor.
            cpu_lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            packed = pack_padded_sequence(x, cpu_lengths, batch_first=True, enforce_sorted=False)
            packed_out, _ = self.rnn(packed)
            x, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=total_time)
        return self.fc(x)


In [None]:
# Training Setup
model = STTModel().to(DEVICE)
# Index 27 is the blank token (char_map['']). zero_infinity=True zeroes the
# infinite loss (and its gradient) of a sample whose frame sequence is too short
# to align with its label, so one bad utterance cannot inject inf/NaN into training.
ctc_loss = nn.CTCLoss(blank=27, zero_infinity=True)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# `verbose` is deprecated for LR schedulers (PyTorch >= 2.2). To see reductions,
# log optimizer.param_groups[0]["lr"] instead.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3)

# Early Stopping Setup. These are module-level so train() can update them via
# `global`; re-run this cell to reset them before a fresh training run.
best_loss = np.inf
patience_counter = 0

In [None]:
# Validation Function
def evaluate(validation_loader):
    """Return the mean per-batch CTC loss of the global ``model`` over ``validation_loader``.

    Puts ``model`` in eval mode (and leaves it there) and runs without gradient
    tracking. Every batch contributes equally to the mean, whatever its size.
    A loader with zero batches raises ZeroDivisionError.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for specs, targets, in_lens, tgt_lens in validation_loader:
            specs = specs.to(DEVICE)
            targets = targets.to(DEVICE)
            log_probs = nn.functional.log_softmax(model(specs), dim=2)
            # nn.CTCLoss expects log-probs laid out as (time, batch, classes).
            batch_loss = ctc_loss(log_probs.permute(1, 0, 2), targets, in_lens, tgt_lens)
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(validation_loader)

def train(train_loader, validation_loader):
    """Train the global ``model`` with CTC loss for up to EPOCHS epochs.

    Each epoch:
    1. Runs one optimisation pass over ``train_loader``.
    2. Writes a checkpoint to ``model_<epoch>.pth``.
    3. Computes the validation loss with ``evaluate`` and steps ``scheduler`` on it.
    4. Updates early stopping.

    Early stopping uses the module-level ``best_loss`` / ``patience_counter``, so
    that state persists across calls. Weights with the best validation loss so far
    go to ``best_model.pth``. Training stops once PATIENCE consecutive epochs pass
    without improvement.
    """
    global best_loss, patience_counter
    model.train()

    for epoch_idx in range(EPOCHS):
        epoch_no = epoch_idx + 1
        running_loss = 0
        for specs, targets, in_lens, tgt_lens in train_loader:
            specs, targets = specs.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()
            log_probs = nn.functional.log_softmax(model(specs), dim=2)
            # nn.CTCLoss expects log-probs laid out as (time, batch, classes).
            batch_loss = ctc_loss(log_probs.permute(1, 0, 2), targets, in_lens, tgt_lens)
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        mean_train_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch_no}, Training Loss: {mean_train_loss}")
        torch.save(model.state_dict(), f"model_{epoch_no}.pth")

        val_loss = evaluate(validation_loader)
        print(f"Epoch {epoch_no}, Validation Loss: {val_loss}")
        model.train()  # evaluate() switches the model to eval mode

        scheduler.step(val_loss)

        if not (val_loss < best_loss):
            patience_counter += 1
        else:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_model.pth")

        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

In [None]:
# Load Dataset and Initialize Processor
# `processor` must stay a module-level name: collate_fn looks it up globally.
# NOTE(review): validation batches also go through this processor, i.e. with its
# time/frequency masking; consider an unaugmented path for validation.
processor = AudioProcessor()

train_dataset, validation_dataset = (
    LIBRISPEECH("./data", url=split, download=True)
    for split in ("train-clean-360", "dev-clean")
)

loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
validation_loader = DataLoader(validation_dataset, shuffle=False, **loader_kwargs)

# Run Training
train(train_loader, validation_loader)