In [1]:
import torchaudio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchaudio.datasets import LIBRISPEECH
from torchaudio.transforms import MelSpectrogram, Resample
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

In [2]:
# Define Hyperparameters
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-3
SAMPLE_RATE = 16000  # target audio rate (Hz) for feature extraction
N_MELS = 80
HIDDEN_SIZE = 512
NUM_LAYERS = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EARLY_STOPPING_PATIENCE = 3  # epochs without val-loss improvement before stopping
SEED = 42
torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffling

# Character Mapping
# Output classes: the 26 lowercase letters and space, followed by one extra
# index reserved for the CTC blank. Transcript characters outside ALPHABET
# (digits, apostrophes, other punctuation) are dropped when encoding.
ALPHABET = "abcdefghijklmnopqrstuvwxyz "
char_map = {c: i for i, c in enumerate(ALPHABET)}
BLANK_IDX = len(ALPHABET)  # 27
char_map[''] = BLANK_IDX  # CTC blank token; '' so it decodes to nothing
idx_map = {i: c for c, i in char_map.items()}
# Derived from the mapping (instead of a hand-typed 28) so the model's output
# size can never drift out of sync with the alphabet.
NUM_CLASSES = len(char_map)

In [3]:
# Data Processing
class AudioProcessor:
    """Turn a raw waveform into a mel spectrogram computed at SAMPLE_RATE.

    Audio that is not already at SAMPLE_RATE is first resampled from the rate
    the caller reports. Previously the resampler always assumed a 48 kHz
    source. Multi-channel audio is averaged down to mono so that the final
    ``squeeze(0)`` yields a 2-D spectrogram (presumably (n_mels, frames)).
    """

    def __init__(self):
        self.mel_spec = MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=N_MELS)
        # Kept as a public attribute for backward compatibility; it also seeds
        # the per-rate cache below for the common 48 kHz case.
        self.resample = Resample(orig_freq=48000, new_freq=SAMPLE_RATE)
        # One Resample transform per source rate, built lazily and reused.
        self._resamplers = {48000: self.resample}

    def _resampler_for(self, sample_rate):
        """Return (creating and caching on first use) a resampler from sample_rate."""
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = Resample(orig_freq=sample_rate, new_freq=SAMPLE_RATE)
            self._resamplers[sample_rate] = resampler
        return resampler

    def __call__(self, waveform, sample_rate):
        """Compute the mel spectrogram of ``waveform`` recorded at ``sample_rate`` Hz."""
        # assumes a (channels, samples) layout when 2-D; average channels to mono
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != SAMPLE_RATE:
            waveform = self._resampler_for(sample_rate)(waveform)
        return self.mel_spec(waveform).squeeze(0)

def text_to_int(text):
    """Encode a transcript as a LongTensor of character indices.

    The text is lower-cased first. Any character without an entry in
    ``char_map`` (digits, punctuation such as apostrophes, ...) is silently
    skipped, so an empty or fully-unknown string yields an empty tensor.
    """
    indices = []
    for ch in text.lower():
        if ch not in char_map:
            continue
        indices.append(char_map[ch])
    return torch.tensor(indices, dtype=torch.long)

def collate_fn(batch):
    """Collate LIBRISPEECH samples into one padded CTC batch.

    Each sample is unpacked as ``(waveform, sample_rate, transcript, ...)``.
    The sample's own rate is forwarded to the global ``processor``. The old
    code always passed SAMPLE_RATE, so non-16 kHz audio was never resampled.

    Returns:
        specs: mel spectrograms zero-padded in time, laid out as
            (batch, features, max_frames) to match ``STTModel.forward``.
        labels: (batch, max_label_len) character indices, padded with 0.
        input_lengths: LongTensor of valid frame counts per sample.
        label_lengths: LongTensor of valid label lengths per sample.
    """
    specs, labels, input_lengths, label_lengths = [], [], [], []
    for waveform, sample_rate, text, *_ in batch:
        spec = processor(waveform, sample_rate)  # (features, frames)
        target = text_to_int(text)
        # pad_sequence pads along the first dim, so put time first.
        specs.append(spec.T)
        labels.append(target)
        input_lengths.append(spec.shape[1])
        label_lengths.append(len(target))
    specs = pad_sequence(specs, batch_first=True).permute(0, 2, 1)
    labels = pad_sequence(labels, batch_first=True)
    return specs, labels, torch.tensor(input_lengths), torch.tensor(label_lengths)

def int_to_text(int_seq):
    """Map a sequence of class indices back to a string via ``idx_map``.

    Each element is converted with ``int()`` first, so Python lists, NumPy
    arrays and torch tensors all work. A 0-d torch tensor hashes by identity
    and would otherwise miss every dict key. The CTC blank index maps to ''
    and so vanishes. No CTC repeat-collapsing is done here.
    """
    return ''.join(idx_map[int(i)] for i in int_seq)

# Model Definition
class STTModel(nn.Module):
    """Small CNN front-end + bidirectional LSTM producing per-frame class logits.

    ``forward`` takes a (batch, n_mels, frames) spectrogram and returns raw
    logits of shape (batch, frames, num_classes). No softmax is applied;
    callers apply ``log_softmax`` before CTC.

    Every constructor argument defaults to ``None``, meaning "use the
    notebook-level hyperparameter" (N_MELS, HIDDEN_SIZE, NUM_LAYERS,
    NUM_CLASSES). ``STTModel()`` therefore builds exactly the original
    architecture, with the same submodule names, so saved state_dicts still load.
    """

    def __init__(self, n_mels=None, hidden_size=None, num_layers=None, num_classes=None):
        super().__init__()
        n_mels = N_MELS if n_mels is None else n_mels
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        num_layers = NUM_LAYERS if num_layers is None else num_layers
        num_classes = NUM_CLASSES if num_classes is None else num_classes
        # stride=1/padding=1 keep the (n_mels, frames) plane the same size.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.rnn = nn.LSTM(input_size=n_mels, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        # Bidirectional LSTM concatenates both directions -> 2 * hidden_size.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Map a (batch, n_mels, frames) spectrogram to (batch, frames, classes) logits."""
        x = x.unsqueeze(1)  # add a channel dim: (batch, 1, n_mels, frames)
        x = self.conv(x)  # (batch, 64, n_mels, frames)
        # Put time on dim 1, then average the 64 conv channels away.
        x = x.permute(0, 3, 2, 1).mean(dim=-1)  # (batch, frames, n_mels)
        x, _ = self.rnn(x)
        return self.fc(x)

# Training Setup
model = STTModel().to(DEVICE)
# The blank index is tied to the '' entry of char_map instead of a literal 27.
# zero_infinity=True zeroes losses/gradients for infeasible alignments (a
# transcript too long for its frame count), which would otherwise be inf and
# poison the weights with NaNs.
ctc_loss = nn.CTCLoss(blank=char_map[''], zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_losses, val_losses = [], []  # per-epoch mean losses, appended by train()
best_val_loss = float('inf')  # best validation loss so far (early stopping)
no_improve_epochs = 0  # consecutive epochs without a validation improvement

def _batch_ctc_loss(waveforms, labels, input_lengths, label_lengths):
    """Run the global ``model`` on one collated batch and return its CTC loss."""
    # Inputs and targets go to DEVICE; the length tensors stay where they are.
    waveforms, labels = waveforms.to(DEVICE), labels.to(DEVICE)
    outputs = model(waveforms)  # (batch, frames, classes)
    log_probs = nn.functional.log_softmax(outputs, dim=2)
    # nn.CTCLoss expects log-probs laid out as (frames, batch, classes).
    return ctc_loss(log_probs.permute(1, 0, 2), labels, input_lengths, label_lengths)


def train(train_loader, val_loader, checkpoint_dir="./models"):
    """Train the global ``model`` with CTC loss, checkpointing and early stopping.

    Each epoch runs one pass over ``train_loader``, then computes the mean
    validation loss over ``val_loader``. Per-epoch means are appended to the
    global ``train_losses`` / ``val_losses`` lists. The weights are saved to
    ``<checkpoint_dir>/model_<epoch>.pth`` every epoch and to
    ``<checkpoint_dir>/best_model.pth`` whenever the validation loss improves.
    Training stops after EARLY_STOPPING_PATIENCE consecutive epochs without
    improvement.

    Note: ``best_val_loss`` and ``no_improve_epochs`` are module globals, so
    calling train() again continues from the previous best rather than
    starting fresh.
    """
    global best_val_loss, no_improve_epochs
    import os

    # torch.save does not create missing directories; without this the first
    # checkpoint write raises on a fresh checkout.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            loss = _batch_ctc_loss(*batch)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        train_loss = total_train_loss / len(train_loader)
        train_losses.append(train_loss)

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                total_val_loss += _batch_ctc_loss(*batch).item()
        val_loss = total_val_loss / len(val_loader)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_{epoch+1}.pth"))

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve_epochs = 0
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= EARLY_STOPPING_PATIENCE:
                print("Early stopping triggered.")
                break

def _ctc_greedy_decode(frame_ids, blank):
    """Best-path CTC decoding: collapse consecutive repeats, then drop blanks."""
    decoded = []
    previous = None
    for idx in frame_ids:
        if idx != previous and idx != blank:
            decoded.append(idx)
        previous = idx
    return decoded


def evaluate(val_loader, checkpoint_path="./models/best_model.pth"):
    """Load the best checkpoint and print reference/hypothesis pairs for one batch.

    Predictions are greedy (argmax per frame) and are decoded as CTC. Padded
    frames beyond each sample's input length are discarded, repeated frame
    labels are collapsed, and blanks are removed. Without this, the raw argmax
    prints stuttered text plus junk from padding. Only the first batch of
    ``val_loader`` is printed.
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.eval()
    blank = ctc_loss.blank  # decode with the same blank index the loss uses
    with torch.no_grad():
        for waveforms, labels, input_lengths, label_lengths in val_loader:
            outputs = model(waveforms.to(DEVICE))
            predictions = torch.argmax(outputs, dim=2).cpu()
            for pred, n_frames, target, n_chars in zip(predictions, input_lengths, labels, label_lengths):
                hyp_ids = _ctc_greedy_decode(pred[:int(n_frames)].tolist(), blank)
                ref_ids = target[:int(n_chars)].tolist()
                print(f"REF: {int_to_text(ref_ids)}")
                print(f"HYP: {int_to_text(hyp_ids)}")
            break  # Print only first batch


In [4]:
# Load Dataset
processor = AudioProcessor()
dataset = LIBRISPEECH("./data", url="train-clean-360", download=True)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
# Seeded split: train() picks best_model.pth by loss on this validation subset,
# so the subset must be identical across kernel restarts. Otherwise a later
# evaluate() would score the model partly on data it was trained on.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)


In [5]:
# Run Training
# Long-running: roughly 1.5 h per epoch on the full training split (see output below).
train(train_loader=train_loader, val_loader=val_loader)

Epoch 1: 100%|██████████| 2926/2926 [1:24:46<00:00,  1.74s/it]


Epoch 1, Train Loss: 1.1174, Val Loss: 0.4777


Epoch 2: 100%|██████████| 2926/2926 [1:24:54<00:00,  1.74s/it]


Epoch 2, Train Loss: 0.3867, Val Loss: 0.3363


Epoch 3: 100%|██████████| 2926/2926 [1:24:17<00:00,  1.73s/it]


Epoch 3, Train Loss: 0.2824, Val Loss: 0.2761


Epoch 4: 100%|██████████| 2926/2926 [1:24:17<00:00,  1.73s/it]


Epoch 4, Train Loss: 0.2245, Val Loss: 0.2566


Epoch 5: 100%|██████████| 2926/2926 [1:24:55<00:00,  1.74s/it]


Epoch 5, Train Loss: 0.1927, Val Loss: 0.2346


Epoch 6: 100%|██████████| 2926/2926 [1:24:39<00:00,  1.74s/it]


Epoch 6, Train Loss: 0.1691, Val Loss: 0.2132


Epoch 7: 100%|██████████| 2926/2926 [1:24:10<00:00,  1.73s/it]


Epoch 7, Train Loss: 0.1454, Val Loss: 0.2189


Epoch 8: 100%|██████████| 2926/2926 [1:24:02<00:00,  1.72s/it]


Epoch 8, Train Loss: 0.1327, Val Loss: 0.2039


Epoch 9: 100%|██████████| 2926/2926 [1:23:44<00:00,  1.72s/it]


Epoch 9, Train Loss: 0.1207, Val Loss: 0.2005


Epoch 10:  22%|██▏       | 646/2926 [18:28<1:05:12,  1.72s/it]


KeyboardInterrupt: 

In [None]:
# Show sample transcriptions for the first validation batch using ./models/best_model.pth.
evaluate(val_loader=val_loader)