In [4]:
!pip install torchaudio jiwer einops
!apt-get install -y libsndfile1

!mkdir -p ./data/librispeech 

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libsndfile1 is already the newest version (1.0.31-2ubuntu0.2).
0 upgraded, 0 newly installed, 0 to remove and 87 not upgraded.


In [10]:
import torch
import torch.nn as nn
import torchaudio
import time
import numpy as np
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torchaudio.transforms import MelSpectrogram, Resample
from einops import rearrange
from tqdm import tqdm
from jiwer import wer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [11]:
class LibriSpeechPreprocessor:
    def __init__(self, sample_rate=16000, n_mels=80, n_fft=512, hop_length=256):
        self.mel = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
        self.resample = Resample(orig_freq=sample_rate, new_freq=sample_rate)

    def __call__(self, waveform, sample_rate):
        if sample_rate != 16000:
            waveform = self.resample(waveform)
        mel = self.mel(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))
        mel = (mel - mel.mean()) / (mel.std() + 1e-5)
        return mel.squeeze(0).T


In [12]:
class CharTokenizer:
    def __init__(self):
        self.vocab = " abcdefghijklmnopqrstuvwxyz'"
        self.char2idx = {c: i for i, c in enumerate(self.vocab)}
        self.idx2char = {i: c for i, c in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)

    def encode(self, text):
        return [self.char2idx.get(c, 0) for c in text.lower()]

    def decode(self, indices):
        return ''.join([self.idx2char.get(i, '') for i in indices if i != 0])

tokenizer = CharTokenizer()

class LibriSpeechDataset(torch.utils.data.Dataset):
    def __init__(self, root, url, transform):
        self.dataset = torchaudio.datasets.LIBRISPEECH(root, url=url, download=True)
        self.transform = transform

    def __getitem__(self, idx):
        waveform, sr, text, *_ = self.dataset[idx]
        return self.transform(waveform, sr), text.lower()

    def __len__(self):
        return len(self.dataset)

def collate_fn(batch):
    specs, texts = zip(*batch)
    specs = [s for s in specs]
    labels = [torch.tensor(tokenizer.encode(t)) for t in texts]
    return (
        pad_sequence(specs, batch_first=True),
        torch.tensor([len(s) for s in specs]),
        pad_sequence(labels, batch_first=True),
        torch.tensor([len(l) for l in labels]),
    )


In [13]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class ConformerBlock(nn.Module):
    def __init__(self, d_model, heads, kernel_size):
        super().__init__()
        self.ff1 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            Swish(),
            nn.Linear(d_model * 4, d_model)
        )
        self.self_attn = nn.MultiheadAttention(d_model, heads, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.conv = nn.Sequential(
            nn.Conv1d(d_model, d_model * 2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2, groups=d_model),
            nn.BatchNorm1d(d_model),
            Swish(),
            nn.Conv1d(d_model, d_model, 1)
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.ff2 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            Swish(),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x):
        x = x + 0.5 * self.ff1(x)
        x = x + self.self_attn(x, x, x)[0]
        x = self.ln1(x)
        conv_input = rearrange(x, 'b t d -> b d t')
        conv_out = self.conv(conv_input)
        x = x + rearrange(conv_out, 'b d t -> b t d')
        x = x + 0.5 * self.ff2(x)
        return self.ln2(x)

class ConformerCTCModel(nn.Module):
    def __init__(self, input_dim=80, d_model=256, num_blocks=8, heads=4, vocab_size=30):
        super().__init__()
        self.frontend = nn.Sequential(
            nn.Conv1d(input_dim, d_model // 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(d_model // 2, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.encoder = nn.Sequential(*[ConformerBlock(d_model, heads, kernel_size=15) for _ in range(num_blocks)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, F, T)
        x = self.frontend(x)    # (B, D, T//4)
        x = x.permute(0, 2, 1)  # (B, T//4, D)
        x = self.encoder(x)
        return self.fc(x)


In [14]:
def greedy_decode(preds):
    return ["".join(tokenizer.decode([p for i, p in enumerate(seq) if p != 0 and (i == 0 or p != seq[i-1])])) for seq in preds]

def train_epoch(model, loader, optimizer, scheduler, criterion):
    model.train()
    total_loss, start = 0, time.time()
    for x, xlen, y, ylen in tqdm(loader, desc="Train"):
        x, y = x.to(device), y.to(device)
        out = model(x).permute(1, 0, 2)  # (T, B, C)
        loss = criterion(out, y, xlen // 4, ylen)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    return total_loss / len(loader), time.time() - start

def evaluate(model, loader, criterion):
    model.eval()
    total_loss, total_wer = 0, 0
    with torch.no_grad():
        for x, xlen, y, ylen in tqdm(loader, desc="Eval"):
            x, y = x.to(device), y.to(device)
            out = model(x).permute(1, 0, 2)
            loss = criterion(out, y, xlen // 4, ylen)
            total_loss += loss.item()
            pred = torch.argmax(out, dim=-1).permute(1, 0)
            hyp = greedy_decode(pred)
            ref = [tokenizer.decode(t.cpu().numpy()) for t in y]
            total_wer += wer(ref, hyp)
    return total_loss / len(loader), total_wer / len(loader)


In [15]:
preprocessor = LibriSpeechPreprocessor()
train_dataset = LibriSpeechDataset('./data/librispeech', 'train-clean-100', preprocessor)
train_len = int(0.9 * len(train_dataset))
train_data, val_data = random_split(train_dataset, [train_len, len(train_dataset) - train_len])
test_data = LibriSpeechDataset('./data/librispeech', 'test-clean', preprocessor)

train_loader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_data, batch_size=8, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=8, collate_fn=collate_fn)

model = ConformerCTCModel(vocab_size=tokenizer.vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-4, steps_per_epoch=len(train_loader), epochs=10)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

for epoch in range(1, 11):
    print(f"\nEpoch {epoch}")
    train_loss, train_time = train_epoch(model, train_loader, optimizer, scheduler, criterion)
    val_loss, val_wer = evaluate(model, val_loader, criterion)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val WER: {val_wer:.2%} | Time: {train_time:.2f}s")
    torch.save(model.state_dict(), f"conformer_ctc_epoch{epoch}.pth")

test_loss, test_wer = evaluate(model, test_loader, criterion)
print(f"\nTest Loss: {test_loss:.4f} | Test WER: {test_wer:.2%}")



Epoch 1


Train: 100%|██████████| 3211/3211 [06:29<00:00,  8.24it/s]
Eval: 100%|██████████| 357/357 [00:52<00:00,  6.78it/s]


Train Loss: 1.5304 | Val Loss: 1.6734 | Val WER: 100.00% | Time: 389.56s

Epoch 2


Train: 100%|██████████| 3211/3211 [06:28<00:00,  8.27it/s]
Eval: 100%|██████████| 357/357 [00:52<00:00,  6.77it/s]


Train Loss: 1.5694 | Val Loss: 1.5625 | Val WER: 100.00% | Time: 388.34s

Epoch 3


Train: 100%|██████████| 3211/3211 [06:29<00:00,  8.25it/s]
Eval: 100%|██████████| 357/357 [00:57<00:00,  6.17it/s]


Train Loss: 1.4488 | Val Loss: 1.4636 | Val WER: 100.00% | Time: 389.23s

Epoch 4


Train: 100%|██████████| 3211/3211 [06:28<00:00,  8.26it/s]
Eval: 100%|██████████| 357/357 [01:00<00:00,  5.86it/s]


Train Loss: 1.2988 | Val Loss: 1.3218 | Val WER: 100.00% | Time: 388.88s

Epoch 5


Train: 100%|██████████| 3211/3211 [06:27<00:00,  8.30it/s]
Eval: 100%|██████████| 357/357 [00:57<00:00,  6.22it/s]


Train Loss: 1.1646 | Val Loss: 1.1460 | Val WER: 100.00% | Time: 387.01s

Epoch 6


Train: 100%|██████████| 3211/3211 [06:27<00:00,  8.29it/s]
Eval: 100%|██████████| 357/357 [00:59<00:00,  5.98it/s]


Train Loss: 1.0554 | Val Loss: 1.0904 | Val WER: 100.00% | Time: 387.34s

Epoch 7


Train: 100%|██████████| 3211/3211 [06:27<00:00,  8.28it/s]
Eval: 100%|██████████| 357/357 [01:00<00:00,  5.88it/s]


Train Loss: 0.9522 | Val Loss: 0.9866 | Val WER: 100.00% | Time: 387.90s

Epoch 8


Train: 100%|██████████| 3211/3211 [06:28<00:00,  8.26it/s]
Eval: 100%|██████████| 357/357 [01:01<00:00,  5.81it/s]


Train Loss: 0.8564 | Val Loss: 0.9451 | Val WER: 100.00% | Time: 388.75s

Epoch 9


Train: 100%|██████████| 3211/3211 [06:27<00:00,  8.28it/s]
Eval: 100%|██████████| 357/357 [01:01<00:00,  5.78it/s]


Train Loss: 0.7776 | Val Loss: 0.9115 | Val WER: 100.00% | Time: 387.92s

Epoch 10


Train: 100%|██████████| 3211/3211 [06:27<00:00,  8.30it/s]
Eval: 100%|██████████| 357/357 [01:01<00:00,  5.79it/s]


Train Loss: 0.7321 | Val Loss: 0.9134 | Val WER: 100.00% | Time: 387.00s


Eval: 100%|██████████| 328/328 [00:44<00:00,  7.39it/s]


Test Loss: 0.7932 | Test WER: 100.00%



