In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader
from torchaudio.datasets import LIBRISPEECH

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

SAMPLE_RATE = 16000  # Hz; waveforms at other rates are resampled to this
N_MFCC = 13          # MFCC coefficients per frame (= model input size)
N_MELS = 80          # NOTE(review): not referenced in the cells shown; the MFCC transform uses n_mels=40
BLANK_ID = 0         # CTC blank index; real characters are numbered from 1

# Reproducibility: fixed RNG seed and deterministic, non-autotuned cuDNN kernels.
torch.manual_seed(1234)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# =========================
# 2) Vocabulary & encoding
# =========================
# LibriSpeech transcripts are upper-case with apostrophes; we lower-case them and
# keep only space, apostrophe and a..z.
VOCAB = [" ", "'", *map(chr, range(ord("a"), ord("z") + 1))]  # " ", "'", a..z
# Character ids start at 1 because id 0 is reserved for the CTC blank.
char2id = {ch: idx for idx, ch in enumerate(VOCAB, start=1)}
id2char = {idx: ch for ch, idx in char2id.items()}


def normalize_text(txt: str) -> str:
    """Lower-case `txt` and drop every character outside VOCAB."""
    return "".join(filter(char2id.__contains__, txt.lower()))


def text_to_int(txt: str):
    """Normalize `txt` and map each remaining character to its id (1..len(VOCAB))."""
    return list(map(char2id.__getitem__, normalize_text(txt)))


def int_to_text(ids):
    """Map ids back to characters, silently skipping unknown ids (including the blank)."""
    return "".join(id2char.get(i, "") for i in ids)

In [None]:
# =========================
# 3) MFCC & resample
# =========================
# N_MFCC coefficients per frame; n_fft=400 / hop_length=160 samples, i.e. 25 ms
# windows with a 10 ms hop at SAMPLE_RATE = 16 kHz. center=False: no edge padding.
# NOTE(review): n_mels is hard-coded to 40 while the config cell defines N_MELS = 80.
mfcc_transform = torchaudio.transforms.MFCC(
    sample_rate=SAMPLE_RATE,
    n_mfcc=N_MFCC,
    melkwargs=dict(n_fft=400, hop_length=160, n_mels=40, center=False),
)

# Resample transforms keyed by source sample rate, created on first use and reused.
_resamplers = {}


def maybe_resample(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Return `waveform` at SAMPLE_RATE, resampling from `sr` only when they differ.

    The Resample transform for a given `sr` is built once and cached in
    `_resamplers`.
    """
    if sr != SAMPLE_RATE:
        resampler = _resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)
            _resamplers[sr] = resampler
        waveform = resampler(waveform)
    return waveform

In [None]:
# =========================
# 4) Bidirectional GRU
# =========================
class BidirGRU_CTC(nn.Module):
    """Bidirectional GRU acoustic model that emits per-frame CTC logits.

    Args:
        input_dim: number of features per frame.
        hidden_dim: hidden size of each GRU direction.
        num_classes: number of output classes, including the CTC blank.
        num_layers: number of stacked GRU layers.
        dropout: dropout between GRU layers (only when num_layers > 1) and
            before the output projection.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=3, dropout=0.3):
        super().__init__()

        self.gru = nn.GRU(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Bidirectional so each frame can also use the context that follows it.
            bidirectional=True,
            # nn.GRU dropout only acts between layers, so disable it for a single layer.
            dropout=dropout if num_layers > 1 else 0,
        )

        # Forward and backward hidden states are concatenated -> 2 * hidden_dim.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths=None):
        """Compute CTC logits.

        Args:
            x: (batch, seq_len, input_dim) batch-first, zero-padded features.
            lengths: optional (batch,) number of valid frames per sequence
                (all >= 1). When given, the batch is packed so the backward
                direction starts at each sequence's real last frame instead of
                running through the padding first. When omitted, padding is
                fed to the GRU as-is (original behaviour).

        Returns:
            (seq_len, batch, num_classes) logits, time-major as nn.CTCLoss
            expects once log_softmax is applied.
        """
        if lengths is None:
            out, _ = self.gru(x)
        else:
            # pack_padded_sequence requires an int64 lengths tensor on the CPU.
            lengths_cpu = torch.as_tensor(lengths).to("cpu", torch.int64)
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths_cpu, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.gru(packed)
            # total_length keeps the time axis at seq_len so it lines up with the input.
            out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=x.size(1)
            )

        out = self.dropout(out)
        logits = self.fc(out)

        # (B, T, C) -> (T, B, C)
        return logits.transpose(0, 1)

In [None]:
# =========================
# 5) Collate function
# =========================
def collate_fn(batch):
    """Turn a list of LibriSpeech samples into a padded CTC training batch.

    Args:
        batch: list of (waveform, sample_rate, transcript, speaker_id,
            chapter_id, utterance_id) tuples; only the first three are used.

    Returns:
        features: (B, T_max, n_mfcc) MFCC frames, zero-padded along time.
        input_lengths: (B,) int64 number of valid frames per utterance.
        targets_concat: (sum of target lengths,) int64 concatenated label ids.
        target_lengths: (B,) int64 number of labels per utterance.
        texts: list of normalized transcripts (for debugging / display).

    Raises:
        ValueError: if `batch` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    feats_list = []
    targets = []
    texts = []

    for waveform, sr, transcript, *_ in batch:
        # (channels, time) -> mono (time,)
        if waveform.dim() == 2:
            waveform = waveform.mean(dim=0)
        else:
            waveform = waveform.squeeze(0)

        waveform = maybe_resample(waveform, int(sr))

        mfcc = mfcc_transform(waveform)  # (n_mfcc, time)
        if mfcc.dim() == 3:  # safety net in case a channel axis survived
            mfcc = mfcc.squeeze(0)
        feats_list.append(mfcc.transpose(0, 1).contiguous())  # (time, n_mfcc)

        norm_txt = normalize_text(transcript)
        targets.append(torch.tensor(text_to_int(norm_txt), dtype=torch.long))
        texts.append(norm_txt)

    input_lengths = torch.tensor([feat.shape[0] for feat in feats_list], dtype=torch.long)
    target_lengths = torch.tensor([tgt.numel() for tgt in targets], dtype=torch.long)

    # Zero-pad along time into (B, T_max, n_mfcc). pad_sequence also avoids a
    # local `F` that shadowed the module-level torch.nn.functional alias.
    features = torch.nn.utils.rnn.pad_sequence(feats_list, batch_first=True)
    targets_concat = torch.cat(targets)

    return features, input_lengths, targets_concat, target_lengths, texts


In [None]:
# =========================
# 6) Greedy CTC decode
# =========================
def greedy_ctc_decode(log_probs: torch.Tensor, input_lengths: torch.Tensor):
    """Best-path CTC decoding.

    Takes the arg-max class of every frame, merges runs of identical labels,
    then removes the blank (BLANK_ID). A label repeated on both sides of a
    blank is therefore emitted twice.

    Args:
        log_probs: (T, B, C) per-frame scores; only the arg-max is used.
        input_lengths: (B,) number of valid frames for each batch item.

    Returns:
        List of B decoded strings.
    """
    best_path = log_probs.argmax(dim=-1)  # (T, B)
    decoded = []
    for col in range(best_path.shape[1]):
        n_valid = int(input_lengths[col].item())
        # Merge consecutive duplicates first, then drop blanks.
        collapsed = torch.unique_consecutive(best_path[:n_valid, col])
        labels = collapsed[collapsed != BLANK_ID].tolist()
        decoded.append(int_to_text(labels))
    return decoded

In [None]:
# =========================
# 7) Training
# =========================
def train_one_epoch(model, loader, optimizer, criterion, epoch, log_interval=50, grad_clip=5.0):
    """Run one training pass over `loader` with CTC loss.

    Every `log_interval` steps it prints the mean loss of the last interval,
    then greedy-decodes up to two utterances of the current batch and prints
    them next to their reference transcripts.

    Args:
        model: maps (B, T, F) features to (T, B, C) logits.
        loader: yields (features, input_lengths, targets_concat, target_lengths, texts).
        optimizer: optimizer over the model's parameters.
        criterion: called as criterion(log_probs, targets, input_lengths, target_lengths).
        epoch: epoch number, only used in the log line.
        log_interval: steps between log lines.
        grad_clip: maximum global gradient norm; None disables clipping.
    """
    model.train()
    window_loss = 0.0
    for step, batch in enumerate(loader, start=1):
        features, input_lengths, targets_concat, target_lengths, texts = batch
        features, input_lengths, targets_concat, target_lengths = (
            tensor.to(DEVICE)
            for tensor in (features, input_lengths, targets_concat, target_lengths)
        )

        # (T, B, C) logits -> log-probabilities, which is what CTC expects.
        log_probs = model(features).log_softmax(dim=2)
        loss = criterion(log_probs, targets_concat, input_lengths, target_lengths)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        window_loss += loss.item()
        if step % log_interval != 0:
            continue

        print(f"[epoch {epoch} step {step}] loss={window_loss / log_interval:.4f}")
        window_loss = 0.0

        # Quick qualitative check on the current mini-batch.
        with torch.no_grad():
            preds = greedy_ctc_decode(log_probs.detach(), input_lengths)
        for i in range(min(2, len(preds))):
            print(f"  tgt: {texts[i]}")
            print(f"  prd: {preds[i]}")

In [None]:
!pip install torchcodec

Collecting torchcodec
  Downloading torchcodec-0.8.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Downloading torchcodec-0.8.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m63.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchcodec
Successfully installed torchcodec-0.8.1


In [None]:
# NOTE(review): the LibriSpeech "dev-clean" split is used as the training set here.
train_dataset = LIBRISPEECH("./data", url="dev-clean", download=True)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,                         # tune to available GPU/CPU memory
    shuffle=True,
    num_workers=4,                        # tune to the machine
    collate_fn=collate_fn,
    pin_memory=(DEVICE.type == "cuda"),   # page-locked host memory only helps GPU transfers
)

num_classes = len(VOCAB) + 1  # id 0 = CTC blank, ids 1..len(VOCAB) = characters
model = BidirGRU_CTC(input_dim=N_MFCC, hidden_dim=128, num_classes=num_classes).to(DEVICE)
criterion = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 50  # total number of training epochs
for ep in range(1, epochs + 1):
    train_one_epoch(model, train_loader, optimizer, criterion, ep, log_interval=200)


[epoch 1 step 200] loss=3.1038
  tgt: for some little time that is it seemed long though i believe it was not more than a minute before two men came running from the musicians gallery
  prd: 
  tgt: i have always delighted in and reverenced beauty but i felt simply abashed in the presence of such a splendid type a compound of all that is best in egyptian greek and italian
  prd: 
[epoch 2 step 200] loss=2.3716
  tgt: the most gifted individuals in the land emulated each other in proving which entertained for him the most sincere affection
  prd:  m n   s       r rs er
  tgt: we've lost the key of the cellar and there's nothing out except water and i don't think you'd care for that
  prd: ssns d    
[epoch 3 step 200] loss=1.9733
  tgt: on the tenth of october he would meet rod at sprucewood on the black sturgeon river
  prd: o ian o re o mi s wis at onn he ucr er
  tgt: a little longer and she was compelled to yield and the silent tears flowed freely
  prd:  e mrin se wis om l l an te 

In [None]:
def _edit_distance(a, b):
    """Levenshtein distance (unit-cost insert/delete/substitute) between sequences a and b."""
    if len(a) < len(b):
        a, b = b, a  # keep the DP row as short as possible; distance is symmetric
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(
                prev[j] + 1,               # deletion
                cur[j - 1] + 1,            # insertion
                prev[j - 1] + (ca != cb),  # substitution (free on match)
            ))
        prev = cur
    return prev[-1]


def calculate_cer(predictions, targets):
    """Compute the Character Error Rate of `predictions` against `targets`.

    CER = (sum of Levenshtein distances) / (total number of reference characters).
    Pairs are formed with zip, so extra items in the longer list are ignored.
    Returns 0.0 when there are no reference characters at all.

    Earlier versions compared characters position by position, so a single
    insertion or deletion counted every following character as an error.
    """
    total_chars = 0
    total_errors = 0

    for pred, tgt in zip(predictions, targets):
        total_errors += _edit_distance(pred, tgt)
        total_chars += len(tgt)

    return total_errors / max(total_chars, 1)

In [None]:
def validate(model, loader, criterion):
    """Evaluate `model` on `loader` without gradient tracking.

    Puts the model in eval mode and leaves it there.

    Returns:
        (avg_loss, cer) where:
        - avg_loss is the mean of the per-batch criterion values, or NaN when
          the loader yields no batches. Batches are counted while iterating,
          so loaders without __len__ also work.
        - cer is the character error rate of greedy CTC decodes against the
          normalized reference texts.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for features, input_lengths, targets_concat, target_lengths, texts in loader:
            features = features.to(DEVICE)
            input_lengths = input_lengths.to(DEVICE)
            targets_concat = targets_concat.to(DEVICE)
            target_lengths = target_lengths.to(DEVICE)

            logits = model(features)
            log_probs = logits.log_softmax(dim=2)
            loss = criterion(log_probs, targets_concat, input_lengths, target_lengths)

            total_loss += loss.item()
            n_batches += 1

            # Greedy decoding for the CER.
            preds = greedy_ctc_decode(log_probs, input_lengths)
            all_preds.extend(preds)
            all_targets.extend(texts)

    # Guard against an empty loader (previously a ZeroDivisionError).
    avg_loss = total_loss / n_batches if n_batches else float("nan")
    cer = calculate_cer(all_preds, all_targets)

    return avg_loss, cer