In [None]:
import os
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import MelSpectrogram

# === GLOBAL CONFIGURATION ===
SAMPLE_RATE = 16_000  # rate (Hz) passed to the audio loader and the Mel transform
N_MELS = 40  # Mel bands per frame; also the model's input dimension
BATCH_SIZE = 96
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
NUM_WORKERS = 0  # keep data loading in the main process; worker processes do not work with Jupyter

# --------------------------------------------------
# CTC needs a dedicated blank symbol.
# The underscore '_' is placed first in the alphabet (index 0) to play
# that role; every real output symbol follows it.
# --------------------------------------------------
LABELS = "_" + "abcdefghijklmnopqrstuvwxyz" + "' "  # blank, letters, apostrophe, space
BLANK_IDX = 0
LABEL2IDX = {symbol: position for position, symbol in enumerate(LABELS)}
IDX2LABEL = dict(enumerate(LABELS))

# === VERBOSE PRINT FUNCTION ===
verbose = True
def verbose_print(message):
    """Print ``message`` only while the module-level ``verbose`` flag is truthy."""
    if not verbose:
        return
    print(message)

In [None]:
# === DATASET DEFINITION ===
class SpeechDataset(Dataset):
    """Pairs each audio file with its transcript for CTC training.

    Item ``idx`` is built from ``wav_paths[idx]`` and ``align_paths[idx]``, so
    the two lists must be index-aligned.

    Each item is a tuple ``(features, labels)``:
        features: float tensor of shape (time, n_mels) with the Mel-spectrogram.
        labels:   1-D long tensor of LABEL2IDX indices for the transcript
                  characters (never contains BLANK_IDX).

    Raises:
        ValueError: if the two path lists differ in length.
    """

    def __init__(self, wav_paths, align_paths):
        if len(wav_paths) != len(align_paths):
            raise ValueError(
                f"wav_paths ({len(wav_paths)}) and align_paths ({len(align_paths)}) "
                "must have the same length"
            )

        self.wav_paths = wav_paths
        self.align_paths = align_paths
        self.mel_transform = MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=N_MELS)

        verbose_print(f"Initialized SpeechDataset with {len(self.wav_paths)} samples.")

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, idx):
        wav_path = self.wav_paths[idx]
        align_path = self.align_paths[idx]

        # Load audio, resampled to SAMPLE_RATE by librosa.
        audio, _ = librosa.load(wav_path, sr=SAMPLE_RATE)

        # Extract the Mel-spectrogram and put time first.
        audio_tensor = torch.as_tensor(audio, dtype=torch.float32)
        features = self.mel_transform(audio_tensor)  # (n_mels, time)
        features = features.transpose(0, 1)  # (time, n_mels)

        # Read the transcript. LABELS only contains lowercase letters, so
        # lowercase first; otherwise capital letters would be dropped silently.
        with open(align_path, 'r') as f:
            alignment = f.read().strip().replace("\n", " ").lower()

        # NOTE(review): every character of the file is considered, so if the
        # alignment files carry non-transcript tokens (e.g. timestamps or
        # silence markers) they leak into the targets as extra spaces/letters
        # -- confirm the file format.
        #
        # Keep only known symbols, and never emit the blank index: a literal
        # '_' in the text would otherwise become BLANK_IDX, which is not a
        # valid CTC target.
        labels = [
            LABEL2IDX[char]
            for char in alignment
            if char in LABEL2IDX and LABEL2IDX[char] != BLANK_IDX
        ]

        return features, torch.tensor(labels, dtype=torch.long)

# === CUSTOM COLLATE FUNCTION ===
def collate_fn(batch):
    """
    Merge a list of (features, labels) samples into CTC-ready batch tensors.

    batch: list of tuples from SpeechDataset, where 'features' has shape
           (T, n_mels) and 'labels' has shape (L,).

    Returns a 4-tuple:
        padded features  -> (batch, max_T, n_mels), zero-padded at the end of time
        flattened labels -> 1D tensor with all label sequences concatenated
        input lengths    -> (batch,) long tensor of unpadded frame counts
        target lengths   -> (batch,) long tensor of label sequence lengths
    """
    feature_seqs, label_seqs = zip(*batch)

    frame_counts = [seq.size(0) for seq in feature_seqs]
    longest = max(frame_counts)

    # F.pad takes pairs starting from the last dimension: (0, 0) leaves n_mels
    # untouched, (0, extra) appends `extra` zero frames to the time axis.
    padded_seqs = [
        F.pad(seq, (0, 0, 0, longest - n_frames))
        for seq, n_frames in zip(feature_seqs, frame_counts)
    ]

    label_counts = [len(seq) for seq in label_seqs]
    all_labels = torch.cat(label_seqs)

    return (
        torch.stack(padded_seqs, dim=0),  # (batch, max_time, n_mels)
        all_labels,
        torch.tensor(frame_counts, dtype=torch.long),
        torch.tensor(label_counts, dtype=torch.long),
    )

In [None]:
# === MODEL DEFINITION ===
class CTCModel(nn.Module):
    """Bidirectional LSTM followed by a per-frame linear classifier.

    Args:
        input_dim:  size of each input frame (e.g. number of Mel bands).
        hidden_dim: LSTM hidden size per direction.
        output_dim: number of output classes (including the CTC blank).
        dropout:    dropout between stacked LSTM layers.
        num_layers: number of stacked LSTM layers (default 2, the value that
                    was previously hard-coded).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies dropout between stacked layers and warns
            # when a non-zero value is passed with a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # bidirectional doubles hidden_dim

    def forward(self, x):
        """
        x: (batch, time, input_dim)
        returns: (batch, time, output_dim) unnormalised class scores per frame
        """
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

# === CTC LOSS WRAPPER ===
def ctc_loss_fn(logits, targets, input_lengths, target_lengths):
    """
    Compute the CTC loss for a batch of frame-level class scores.

    logits: (batch, time, num_classes)
    targets: (N) 1D tensor of all targets concatenated
    input_lengths / target_lengths: per-sample frame and label counts
    """
    # nn.CTCLoss wants time-major log-probabilities: (time, batch, num_classes).
    time_major = logits.permute(1, 0, 2)
    log_probs = time_major.log_softmax(dim=2)

    # zero_infinity=True replaces infinite losses (and their gradients) with zero.
    criterion = nn.CTCLoss(blank=BLANK_IDX, zero_infinity=True)
    return criterion(log_probs, targets, input_lengths, target_lengths)

# === TRAINING LOOP ===
def train_model(model, dataloader, optimizer, num_epochs, device):
    """Train ``model`` with CTC loss for ``num_epochs`` passes over ``dataloader``.

    Args:
        model:      module mapping (batch, time, n_mels) features to per-frame scores.
        dataloader: yields (features, labels, input_lengths, target_lengths) batches.
        optimizer:  optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        device:     device the batches are moved to before the forward pass.

    Returns:
        List with the average per-batch loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches (previously this ended
            in a ZeroDivisionError when averaging the epoch loss).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(
            "train_model received a dataloader with no batches; with drop_last=True "
            "the dataset must contain at least one full batch."
        )

    model.train()
    total_steps = num_batches * num_epochs  # total number of training steps
    verbose_print(f"[train_model] Total epochs: {num_epochs}, total steps: {total_steps}")

    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        verbose_print(f"[train_model] Starting epoch {epoch+1}/{num_epochs} (batches in epoch: {num_batches})...")

        for batch_idx, (features, labels, input_lengths, target_lengths) in enumerate(dataloader):
            global_step = epoch * num_batches + batch_idx  # current global step

            # Move data to device
            features = features.to(device)
            labels = labels.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            # Forward
            optimizer.zero_grad()
            logits = model(features)

            # Compute CTC loss and update the parameters
            loss = ctc_loss_fn(logits, labels, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # read once; each .item() is a host sync
            total_loss += batch_loss

            # Periodic progress line (every 25th batch) when verbose
            if verbose and (batch_idx % 25 == 0):
                verbose_print(
                    f"  Epoch [{epoch+1}/{num_epochs}], "
                    f"Batch [{batch_idx}/{num_batches}], "
                    f"Global Step: {global_step}/{total_steps}, "
                    f"Loss: {batch_loss:.4f}, "
                    f"Features shape: {features.shape}, "
                    f"Labels shape: {labels.shape}"
                )

        # Compute avg loss per batch
        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        verbose_print(f"Epoch [{epoch+1}/{num_epochs}] completed - Avg Loss: {avg_loss:.4f}\n")

    return epoch_losses


# === DECODE FUNCTION (GREEDY) ===
def decode_ctc(logits, input_lengths=None):
    """
    Greedy (best-path) CTC decoding.

    logits: (batch, time, num_classes)
    input_lengths: optional per-sample frame counts (tensor or sequence of
        ints). When given, only the first ``input_lengths[i]`` frames of
        sample ``i`` are decoded, so padded frames cannot add characters.
        When omitted, every frame is decoded.

    Returns a list of decoded strings (one per batch element).
    Raises ValueError if ``input_lengths`` does not have one entry per sample.
    """
    # One bulk transfer to Python ints instead of an .item() call per frame.
    best_paths = torch.argmax(logits, dim=2).tolist()  # batch x time nested lists

    if input_lengths is None:
        frame_limits = [None] * len(best_paths)
    else:
        if torch.is_tensor(input_lengths):
            frame_limits = input_lengths.tolist()
        else:
            frame_limits = [int(n) for n in input_lengths]
        if len(frame_limits) != len(best_paths):
            raise ValueError(
                f"input_lengths has {len(frame_limits)} entries but the batch "
                f"holds {len(best_paths)} sequences"
            )

    decoded_batch = []
    for path, limit in zip(best_paths, frame_limits):
        if limit is not None:
            path = path[:limit]

        decoded_seq = []
        prev_idx = None
        for idx in path:
            # Collapse repeats, then drop blanks
            if idx != prev_idx and idx != BLANK_IDX:
                decoded_seq.append(IDX2LABEL[idx])
            prev_idx = idx
        decoded_batch.append("".join(decoded_seq))

    return decoded_batch

# ------------------------------------------------------------------
# Helper function to gather data for speakers s1..s34,
# excluding s1, s2, s20, s21, and s22 by default.
# ------------------------------------------------------------------
def gather_speaker_data(wav_root, align_root, exclude_speakers=None):
    """
    Gathers matching .wav and .align file paths from speaker directories
    (s1 through s34), skipping the speaker IDs in ``exclude_speakers``
    (default: 1, 2, 20, 21, 22).

    Files are paired per speaker by base name (file name without extension).
    A .wav without a same-named .align (or the reverse) is dropped, so a
    single missing file can no longer shift every later pair out of step.

    Returns two index-aligned lists (all_wav_paths, all_align_paths), ordered
    by wav path.
    """
    if exclude_speakers is None:
        exclude_speakers = [1, 2, 20, 21, 22]
    excluded = set(exclude_speakers)

    pairs = []  # (wav_path, align_path) tuples

    for spk_id in range(1, 35):
        if spk_id in excluded:
            verbose_print(f"[gather_speaker_data] Skipping speaker s{spk_id} (excluded).")
            continue

        spk_dir = f"s{spk_id}"
        spk_wav_dir = os.path.join(wav_root, spk_dir)
        spk_align_dir = os.path.join(align_root, spk_dir)

        if not os.path.isdir(spk_wav_dir) or not os.path.isdir(spk_align_dir):
            verbose_print(f"[gather_speaker_data] Speaker s{spk_id} directory missing, skipping.")
            continue

        verbose_print(f"[gather_speaker_data] Collecting data for speaker s{spk_id}...")

        # Index both directories by base name so files can be matched up.
        wav_by_stem = {
            os.path.splitext(fname)[0]: os.path.join(spk_wav_dir, fname)
            for fname in os.listdir(spk_wav_dir)
            if fname.endswith(".wav")
        }
        align_by_stem = {
            os.path.splitext(fname)[0]: os.path.join(spk_align_dir, fname)
            for fname in os.listdir(spk_align_dir)
            if fname.endswith(".align")
        }

        verbose_print(f"  Found {len(wav_by_stem)} wavs and {len(align_by_stem)} aligns in s{spk_id}.")

        shared_stems = wav_by_stem.keys() & align_by_stem.keys()
        unmatched = len(wav_by_stem) + len(align_by_stem) - 2 * len(shared_stems)
        if unmatched:
            verbose_print(
                f"  Dropping {unmatched} file(s) in s{spk_id} without a matching .wav/.align partner."
            )

        pairs.extend((wav_by_stem[stem], align_by_stem[stem]) for stem in shared_stems)

    # Sort the pairs together so the ordering is consistent and stays aligned.
    pairs.sort()
    all_wav_paths = [wav_path for wav_path, _ in pairs]
    all_align_paths = [align_path for _, align_path in pairs]

    return all_wav_paths, all_align_paths

# ------------------------------------------------------------------
# MAIN FUNCTION that uses data from s1..s34 except s1, s2, s20, s21, s22
# ------------------------------------------------------------------
def main_excluding_some_speakers(wav_root, align_root):
    """Train a CTC model on speakers s1..s34 minus s1, s2, s20, s21, s22.

    Collects the (.wav, .align) paths under ``wav_root`` / ``align_root``,
    trains a CTCModel for NUM_EPOCHS epochs, then prints greedy-decoded
    predictions for the batches of the same (training) dataloader.

    Raises:
        ValueError: if the path lists differ in length or hold fewer samples
            than one full batch (drop_last=True would then yield no batches).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    verbose_print(f"Using device: {device}")

    wav_paths, align_paths = gather_speaker_data(
        wav_root, align_root,
        exclude_speakers=[1, 2, 20, 21, 22]
    )
    verbose_print(f"Total .wav files collected: {len(wav_paths)}")
    verbose_print(f"Total .align files collected: {len(align_paths)}")

    # Explicit exceptions instead of assert: asserts vanish under `python -O`.
    if len(wav_paths) != len(align_paths):
        raise ValueError("Number of .wav and .align files must match")
    if len(wav_paths) < BATCH_SIZE:
        raise ValueError(
            f"Found {len(wav_paths)} samples under {wav_root!r}, but at least "
            f"{BATCH_SIZE} are needed to form one batch (drop_last=True)."
        )

    dataset = SpeechDataset(wav_paths, align_paths)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=NUM_WORKERS
    )

    input_dim = N_MELS
    hidden_dim = 256
    output_dim = len(LABELS)  # includes blank
    model = CTCModel(input_dim, hidden_dim, output_dim, dropout=0.3)
    model.to(device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train
    train_model(model, dataloader, optimizer, num_epochs=NUM_EPOCHS, device=device)

    # Print predictions (note: on the training data, no held-out set is used)
    model.eval()
    with torch.no_grad():
        for features, _, input_lengths, _ in dataloader:
            features = features.to(device)
            logits = model(features)  # (batch, time, output_dim)
            verbose_print("[main_excluding_some_speakers] Predictions on a batch:")
            # Decode each sample on its real frames only, so zero-padded
            # frames at the end cannot add characters to the prediction.
            for sample_idx, n_frames in enumerate(input_lengths.tolist()):
                pred_str = decode_ctc(logits[sample_idx:sample_idx + 1, :n_frames])[0]
                print("  ", pred_str)

In [None]:
if __name__ == "__main__":
    # Build the dataset locations with os.path.join so the separators are
    # correct on every OS (the previous backslash literals only resolved on
    # Windows). Paths are relative to the current working directory.
    wav_root = os.path.join("GRID", "audio_25k", "audio_25k")
    align_root = os.path.join("GRID", "alignments", "alignments")
    main_excluding_some_speakers(wav_root, align_root)