In [None]:
!pip install textgrid
!pip install python-Levenshtein

Collecting textgrid
  Downloading TextGrid-1.6.1.tar.gz (9.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: textgrid
  Building wheel for textgrid (setup.py) ... [?25l[?25hdone
  Created wheel for textgrid: filename=TextGrid-1.6.1-py3-none-any.whl size=10146 sha256=aa3f4ec0012903cd5aa9a18b8733648aadeacff8c264decc84de0c8c7bc5c689
  Stored in directory: /root/.cache/pip/wheels/7a/c5/96/5e43aa4c640995fbbb0b9a7b98e6007bfd777add3c7e56d70a
Successfully built textgrid
Installing collected packages: textgrid
Successfully installed textgrid-1.6.1
Collecting python-Levenshtein
  Downloading python_levenshtein-0.27.1-py3-none-any.whl.metadata (3.7 kB)
Collecting Levenshtein==0.27.1 (from python-Levenshtein)
  Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein==0.27.1->python-Levenshtein)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylin

In [None]:
import os
import re
import gc
import glob
import zipfile
import shutil
import random
import numpy as np
import librosa
import librosa.display
import IPython.display as ipd
import pandas as pd
import soundfile as sf
import scipy.fftpack
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import Counter
from itertools import zip_longest
from google.colab import files
from google.colab import drive
from textgrid import TextGrid
from Levenshtein import distance as levenshtein
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.calibration import calibration_curve

# DATA LOADING

## AUDIO

In [None]:
DRIVE_MOUNT_POINT = '/content/drive'  # Google Drive is exposed under this path
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Extract the wav archive into the same absolute directory every later cell reads from
# (extracting to a relative "audio" path only matches if the working dir is /content).
wav_dir = "/content/audio"
zip_files = "/content/drive/path/to/wav/zipfile" # path to zip of wav files
with zipfile.ZipFile(zip_files, 'r') as zip_ref:
    zip_ref.extractall(wav_dir)

In [None]:
# List only the .wav files (case-insensitive), so the reported count is actually a WAV count.
wav_files = sorted(f for f in os.listdir(wav_dir) if f.lower().endswith('.wav'))
print("Extracted WAV files:")
for file in wav_files:
    print(file)

print(f"Total number of WAV files: {len(wav_files)}")

## TEXTGRIDS

In [None]:
# Colab file picker: choose the TextGrid zip archive. Per the Colab docs the result
# maps uploaded file names to their contents; only the names are shown here.
uploaded_textgrid = files.upload()
uploaded_names = uploaded_textgrid.keys()
print("Uploaded files:", uploaded_names)

In [None]:
# Extract the TextGrids into the absolute directory used by later cells.
# 'zipfile.zip' must match the name of the archive uploaded in the previous cell.
textgrid_dir = "/content/textgrids"
with zipfile.ZipFile('zipfile.zip', 'r') as zip_ref: # zip file of corresponding textgrids
    zip_ref.extractall(textgrid_dir)

In [None]:
# List only .TextGrid files (case-insensitive), so the reported count ignores other entries.
textgrid_files = sorted(f for f in os.listdir(textgrid_dir) if f.lower().endswith('.textgrid'))
print("Extracted TEXTGRID files:")
for file in textgrid_files:
    print(file)

print(f"Total number of TEXTGRID files: {len(textgrid_files)}")

# DATA PREPROCESSING AND FEATURE EXTRACTION

In [None]:
# Phoneme label -> integer class id. Ids follow list order:
#   0-22  stressed/unstressed vowels (trailing digit = stress marker)
#   23-42 consonants (no stress marker)
#   43    "<BLANK>", the CTC blank (hard-coded as blank=43 / blank_id=43 elsewhere)
# NOTE(review): "H" is used rather than ARPAbet "HH", and several ARPAbet phonemes are
# absent; labels not listed here are dropped during feature extraction. Confirm this
# matches the TextGrid inventory.
phoneme_to_id = {
    phoneme: class_id
    for class_id, phoneme in enumerate(
        [
            "AA0", "AA1", "AA2", "AW0", "AW1", "AY0", "AY1",
            "EH0", "EH1", "ER0", "EY1",
            "IH1", "IY0", "IY1", "IY2",
            "OW0", "OW1", "OW2", "OY0", "OY1",
            "UW0", "UW1", "UW2",
        ]
        + [
            "B", "D", "F", "G", "H", "JH", "K", "L", "M", "N",
            "NG", "P", "R", "S", "SH", "T", "V", "W", "Y", "Z",
        ]
        + ["<BLANK>"]
    )
}

In [None]:
# audio preprocessing: load each file and resample it to 16 kHz when needed
def preprocess_audio(file_path, target_sr=16000):
    """Load ``file_path`` at its native sample rate, resampling to ``target_sr`` if it differs.

    Returns:
        (audio, target_sr): the (possibly resampled) waveform and the rate it is now at.
    """
    audio, native_sr = librosa.load(file_path, sr=None)
    if native_sr == target_sr:
        return audio, target_sr
    resampled = librosa.resample(audio, orig_sr=native_sr, target_sr=target_sr)
    return resampled, target_sr

In [None]:
def mel_filterbank(sr=16000, n_fft=512, n_mels=26, fmin=0, fmax=None):
    """Build a bank of triangular filters on the mel scale.

    ``n_mels + 2`` band edges are spaced evenly in mel between ``fmin`` and ``fmax``
    (Nyquist, ``sr // 2``, when None) and mapped to FFT bin indices. Filter ``r`` rises
    linearly from edge ``r`` to a peak of 1.0 at edge ``r + 1`` and falls to edge ``r + 2``.

    Returns:
        ndarray of shape (n_mels, n_fft // 2 + 1).
    """
    upper_hz = sr // 2 if fmax is None else fmax
    edges_mel = np.linspace(librosa.hz_to_mel(fmin), librosa.hz_to_mel(upper_hz), n_mels + 2)
    edge_bins = np.floor((n_fft + 1) * librosa.mel_to_hz(edges_mel) / sr).astype(int)

    bank = np.zeros((n_mels, int(n_fft // 2 + 1)))
    triangles = zip(edge_bins[:-2], edge_bins[1:-1], edge_bins[2:])
    for row, (lo, peak, hi) in enumerate(triangles):
        # Rising slope; range() is empty when lo == peak, so no zero division.
        for k in range(lo, peak):
            bank[row, k] = (k - lo) / (peak - lo)
        # Falling slope.
        for k in range(peak, hi):
            bank[row, k] = (hi - k) / (hi - peak)
    return bank

# mel filterbank frequency warping (VTLN)
def apply_piecewise_warp(filters, warp_factor, pivot_freq, sr=16000, n_fft=512):
    """Apply a piecewise-linear frequency warp to a filterbank along its bin axis.

    FFT bins at or below the bin of ``pivot_freq`` keep their position. Above it,
    bin ``k`` is placed at ``center + (k - center) * warp_factor``. Each filter row is
    then linearly re-interpolated (``np.interp``) onto the original integer bins.
    Bins beyond the warped range become 0.

    Args:
        filters: 2-D array (n_filters, n_bins), e.g. the output of ``mel_filterbank``.
        warp_factor: positive slope of the warp above the pivot (1.0 = no warp).
        pivot_freq: frequency in Hz where the warp begins.
        sr, n_fft: sample rate and FFT size used to map ``pivot_freq`` to a bin.

    Returns:
        Array with the same shape and dtype as ``filters``.

    Raises:
        ValueError: if ``warp_factor`` is not positive. The warped bin axis would then
            not be increasing, which ``np.interp`` requires but does not check.
    """
    if warp_factor <= 0:
        raise ValueError(f"warp_factor must be positive, got {warp_factor}")

    center_bin = int(np.floor((n_fft + 1) * pivot_freq / sr))

    # The warped axis is the same for every filter, so compute it once.
    orig_bins = np.arange(filters.shape[1])
    warped_bins = np.where(
        orig_bins <= center_bin,
        orig_bins,
        center_bin + (orig_bins - center_bin) * warp_factor,
    )

    warped_filters = np.zeros_like(filters)
    for i, row in enumerate(filters):
        warped_filters[i] = np.interp(orig_bins, warped_bins, row, left=0, right=0)
    return warped_filters

def extract_mfcc_vtln(signal, sr=16000, warp_factor=0.85, n_mfcc=13, n_mels=26, n_fft=512, hop_length=160):
    """Compute MFCCs from a VTLN-warped mel filterbank.

    Steps:
    1. Pre-emphasis with coefficient 0.97.
    2. Hamming-windowed STFT with a 400-sample window, then the power spectrum.
    3. Mel filterbank warped above 1500 Hz by ``warp_factor``.
    4. Drop the first and last mel band.
    5. Convert to dB (``librosa.power_to_db``).
    6. Orthonormal DCT-II; keep the first ``n_mfcc`` coefficients.

    Returns:
        ndarray of shape (frames, n_mfcc).
    """
    # y[n] = x[n] - 0.97 * x[n-1]; the first sample passes through unchanged.
    alpha = 0.97
    emphasized = np.append(signal[0], signal[1:] - alpha * signal[:-1])

    spectrum = librosa.stft(emphasized, n_fft=n_fft, hop_length=hop_length, win_length=400, window='hamming')
    power = np.abs(spectrum) ** 2

    n_bins = int(n_fft // 2 + 1)
    warped_bank = apply_piecewise_warp(
        mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels),
        warp_factor,
        pivot_freq=1500,
        sr=sr,
        n_fft=n_fft,
    )
    # NOTE(review): the outermost mel bands are discarded, so only n_mels - 2 bands
    # reach the DCT. Confirm this is intentional.
    band_energy = np.dot(warped_bank, power[:n_bins, :])[1:-1, :]

    cepstra = scipy.fftpack.dct(librosa.power_to_db(band_energy), axis=0, type=2, norm='ortho')
    return cepstra[:n_mfcc].T

def extract_full_features(signal, sr=16000, warp_factor=0.85):
    """Build the per-frame feature matrix used for training.

    Columns, in order:
    - MFCC, delta and delta-delta (from ``extract_mfcc_vtln``), CMVN-normalised per utterance.
    - pYIN F0 in Hz (0 where pYIN returned NaN).
    - RMS energy.

    Every stream uses a 160-sample hop; all are truncated to the shortest stream's length.

    Returns:
        ndarray of shape (frames, 3 * n_mfcc + 2), i.e. 41 columns with the default 13 MFCCs.

    NOTE: deltas used to be computed along the wrong axis (see below). Features cached
    before this fix are not comparable; re-extract any saved .npz files.
    """
    mfcc = extract_mfcc_vtln(signal, sr, warp_factor=warp_factor)  # (frames, n_mfcc)

    # Deltas are time derivatives, and frames are axis 0 here. librosa.feature.delta
    # defaults to axis=-1, which would differentiate across cepstral coefficients.
    # The default window (9) must be odd and no longer than the number of frames.
    n_frames = mfcc.shape[0]
    width = min(9, n_frames if n_frames % 2 else n_frames - 1)
    if width >= 3:
        d_mfcc = librosa.feature.delta(mfcc, width=width, axis=0)
        dd_mfcc = librosa.feature.delta(mfcc, width=width, order=2, axis=0)
    else:
        # Too few frames for any delta window; treat the utterance as stationary.
        d_mfcc = np.zeros_like(mfcc)
        dd_mfcc = np.zeros_like(mfcc)

    f0, _, _ = librosa.pyin(signal, fmin=80, fmax=400, sr=sr, frame_length=1024, hop_length=160)
    f0 = np.nan_to_num(f0, nan=0.0).reshape(-1, 1)
    energy = librosa.feature.rms(y=signal, frame_length=400, hop_length=160).T

    # Streams can differ by a frame or so; align them by truncation.
    min_len = min(mfcc.shape[0], d_mfcc.shape[0], dd_mfcc.shape[0], f0.shape[0], energy.shape[0])

    # Per-utterance cepstral mean/variance normalisation of the MFCC block only.
    mfcc_stack = np.hstack([mfcc[:min_len], d_mfcc[:min_len], dd_mfcc[:min_len]])
    mean = np.mean(mfcc_stack, axis=0)
    std = np.std(mfcc_stack, axis=0) + 1e-10
    mfcc_cmvn = (mfcc_stack - mean) / std

    # NOTE(review): F0 (Hz, up to 400) and RMS energy are appended un-normalised, so
    # their scale differs greatly from the CMVN columns. Consider normalising them too.
    f0_part = f0[:min_len]
    energy_part = energy[:min_len]

    features = np.hstack([mfcc_cmvn, f0_part, energy_part])

    return features

In [None]:
# extract phoneme sequences from TextGrid files
def extract_phoneme_sequence(textgrid_path, tier_name="phones", skip_silence=True):
    """Read the interval labels of one tier from a Praat TextGrid.

    Each label is whitespace-stripped and upper-cased. With ``skip_silence`` (the
    default), intervals labelled "sil", "sp" (any case) or empty are dropped.

    Args:
        textgrid_path: path to the .TextGrid file.
        tier_name: name of the tier to read.
        skip_silence: drop silence/short-pause/empty intervals when True.

    Returns:
        List of upper-case label strings in tier order.

    Raises:
        ValueError: if the file has no tier named ``tier_name``.
    """
    tg = TextGrid.fromFile(textgrid_path)
    phoneme_tier = tg.getFirst(tier_name)
    if phoneme_tier is None:
        # getFirst yields None for a missing tier. Fail clearly instead of with a
        # TypeError from iterating over None.
        raise ValueError(f"No tier named {tier_name!r} in {textgrid_path}")

    silence_labels = {"sil", "sp", ""}
    sequence = []
    for interval in phoneme_tier:
        label = interval.mark.strip().lower()
        if skip_silence and label in silence_labels:
            continue
        sequence.append(label.upper())
    return sequence

In [None]:
# Feature extraction: for every wav with a matching TextGrid, compute the feature
# matrix and phoneme-id labels and store them together in <stem>.npz under output_dir.
num_files = 0
output_dir = '/content/train'
os.makedirs(output_dir, exist_ok=True)
dropped_phonemes = Counter()  # labels missing from phoneme_to_id, with counts

for wav_file in sorted(os.listdir(wav_dir)):
    stem, ext = os.path.splitext(wav_file)
    if ext.lower() != '.wav':
        continue

    wav_path = os.path.join(wav_dir, wav_file)
    textgrid_path = os.path.join(textgrid_dir, stem + '.TextGrid')
    if not os.path.exists(textgrid_path):
        print(f"Skipping {wav_file}: no TextGrid at {textgrid_path}")
        continue

    signal, sr = preprocess_audio(wav_path)
    features = extract_full_features(signal, sr)

    phonemes = extract_phoneme_sequence(textgrid_path)
    # Unknown labels are dropped (as before) but now counted and reported, because
    # silently shortening the label sequence changes the CTC targets.
    dropped_phonemes.update(p for p in phonemes if p not in phoneme_to_id)
    phoneme_ids = [phoneme_to_id[p] for p in phonemes if p in phoneme_to_id]

    np.savez(
        os.path.join(output_dir, f"{stem}.npz"),
        features=features,
        # Explicit dtype: np.array([]) would otherwise be float64 for an empty sequence.
        labels=np.array(phoneme_ids, dtype=np.int64),
    )

    del signal, features, phonemes, phoneme_ids
    gc.collect()

    num_files += 1
    print(f"Processed {wav_file}, {num_files} files processed")

if dropped_phonemes:
    print("WARNING: labels not in phoneme_to_id were dropped:", dict(dropped_phonemes))

In [None]:
# Zip the extracted .npz features in output_dir to Drive; make_archive appends ".zip".
features_archive_base = "/content/drive/output/path/for/features"
shutil.make_archive(features_archive_base, 'zip', output_dir)

In [None]:
class PhonemeDataset(Dataset):
    """Map-style dataset over per-utterance ``.npz`` files.

    Each file must hold a ``features`` array (frames x feature_dim) and a ``labels``
    array of phoneme ids, as written by the feature-extraction cell.
    """

    def __init__(self, npz_file_paths):
        self.paths = npz_file_paths

    def __getitem__(self, idx):
        """Return (features as float32 tensor, labels as int64 tensor) for item ``idx``."""
        # np.load on an .npz returns an NpzFile that keeps the file open. The context
        # manager closes it instead of leaking one handle per item read.
        with np.load(self.paths[idx]) as data:
            features = data['features']
            labels = data['labels']
        return torch.tensor(features, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.paths)

In [None]:
def collate_fn(batch):
    """Collate (features, labels) pairs into the tensors ``nn.CTCLoss`` consumes.

    Returns:
        padded_feats: (batch, max_frames, feat_dim), zero-padded at the end.
        concatenated_labels: every label sequence joined end to end (1-D).
        feat_lengths: number of frames per item.
        label_lengths: number of labels per item.
    """
    feats, targets = zip(*batch)
    frame_counts = torch.tensor([seq.shape[0] for seq in feats])
    target_counts = torch.tensor([len(seq) for seq in targets])
    return (
        pad_sequence(feats, batch_first=True),
        torch.cat(targets),
        frame_counts,
        target_counts,
    )

# MODEL TRAINING AND TESTING

In [None]:
# Load previously extracted features from Drive (lets you skip the extraction cells).
drive.mount('/content/drive')

with zipfile.ZipFile('/content/drive/path/to/features.zip', 'r') as zip_ref:
    # Absolute path: the listing and split cells read from /content/train.
    zip_ref.extractall("/content/train")

In [None]:
# List the extracted feature files (only .npz entries are counted).
npz_files = sorted(f for f in os.listdir('/content/train') if f.endswith('.npz'))
print("Extracted NPZ files:")
for file in npz_files:
    print(file)

print(f"Total number of NPZ files: {len(npz_files)}")

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Hyperparameters and reproducibility settings.
# NOTE: BiLSTM_CTC and train_ctc capture some of these as default arguments when their
# cells run; re-run those cells (or pass the values explicitly) after editing here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # per PyTorch docs, seeds all devices; DataLoader shuffling draws from it

input_dim = 13 * 3 + 2  # 13 MFCC + 13 delta + 13 delta-delta, plus F0 and RMS energy = 41
hidden_dim = 128
output_dim = len(phoneme_to_id)  # includes the <BLANK> class used by CTC
num_layers = 2
num_epochs = 30
batch_size = 16
learning_rate = 1e-3
dropout = 0.4

## Model Architecture

In [None]:
class BiLSTM_CTC(nn.Module):
    """Bidirectional LSTM acoustic model that emits per-frame log-probabilities for CTC.

    Input ``x`` is (batch, frames, input_dim) because the LSTM uses ``batch_first=True``.
    Output is (batch, frames, output_dim), log-softmaxed over the class dimension.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=num_layers, dropout=dropout):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # nn.LSTM applies dropout only between stacked layers and warns if it is set
        # with a single layer, so pass it only when there is more than one layer.
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, lengths=None):
        """Return per-frame log-probabilities for a zero-padded batch.

        Args:
            x: (batch, frames, input_dim) features.
            lengths: optional number of valid frames per item (list or 1-D tensor).
                When given, the batch is packed, so the backward direction starts at
                each sequence's real end instead of reading its padding first. Frames
                past ``lengths[b]`` then come from zero LSTM outputs.
        """
        # nn.LSTM uses zero initial states by default, so none are passed. Allocating
        # them explicitly tied the model to a global ``device``.
        if lengths is None:
            out, _ = self.lstm(x)
        else:
            packed = nn.utils.rnn.pack_padded_sequence(
                x, torch.as_tensor(lengths).cpu(), batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=x.size(1)
            )
        return self.log_softmax(self.fc(out))

In [None]:
# convert per-frame log-probabilities into phoneme-id sequences
def ctc_greedy_decode(log_probs, blank=43, suppress_blanks=True, lengths=None):
    """Greedy (best-path) CTC decoding.

    Takes the arg-max class at each frame, collapses runs of the same token, and
    drops blanks.

    Args:
        log_probs: tensor indexed (time, batch, classes).
        blank: class index of the CTC blank.
        suppress_blanks: when False, every blank frame is emitted as ``blank``
            instead of being dropped.
        lengths: optional number of valid frames per batch item (sequence of ints or
            1-D tensor). Frames at index >= lengths[b] are ignored for item b, so
            batch padding cannot produce spurious tokens. None decodes every frame.

    Returns:
        One list of token ids per batch item.
    """
    num_frames, batch_size = log_probs.size(0), log_probs.size(1)
    # exp() is monotonic, so argmax over log-probs equals argmax over probs.
    # A single .tolist() replaces one device sync per .item() call.
    best_path = log_probs.argmax(dim=-1).tolist()  # nested list, [time][batch]

    if lengths is None:
        limits = [num_frames] * batch_size
    else:
        limits = [min(int(n), num_frames) for n in torch.as_tensor(lengths).tolist()]

    decoded = []
    for b in range(batch_size):
        sequence = []
        prev_token = -1
        for t in range(limits[b]):
            token = best_path[t][b]
            if token != blank:
                if token != prev_token:
                    sequence.append(token)
            elif not suppress_blanks:
                sequence.append(blank)
            # Blanks also reset prev_token, so "a <blank> a" decodes to two a's.
            prev_token = token
        decoded.append(sequence)
    return decoded

In [None]:
# Linear learning-rate warmup
class WarmupScheduler:
    """Linearly ramp every param group's learning rate up to ``base_lr``.

    Call ``step()`` once per optimizer step. On call number ``n`` with
    ``n <= warmup_steps``, lr is set to ``base_lr * n / warmup_steps``. Later calls
    only advance the counter and leave lr untouched, so other schedulers take over.
    """

    def __init__(self, optimizer, warmup_steps, base_lr):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        self.step_count = 0

    def step(self):
        self.step_count += 1
        if self.step_count > self.warmup_steps:
            return
        warmed_lr = self.base_lr * (self.step_count / self.warmup_steps)
        for group in self.optimizer.param_groups:
            group["lr"] = warmed_lr

## Training Loop

In [None]:
def _decode_valid_frames(log_probs, feat_lens, blank_id):
    """Greedy-decode each utterance over its own valid (unpadded) frames only."""
    # log_probs is (time, batch, classes); slicing keeps that layout for one item.
    return [
        ctc_greedy_decode(log_probs[:n_frames, b:b + 1], blank=blank_id)[0]
        for b, n_frames in enumerate(feat_lens.tolist())
    ]


def _batch_edit_stats(log_probs, labels, feat_lens, label_lens, blank_id):
    """Return (summed Levenshtein distance, summed reference length) for one batch."""
    decoded = _decode_valid_frames(log_probs, feat_lens, blank_id)
    targets = torch.split(labels, label_lens.tolist())
    dist = ref_len = 0
    for pred_seq, target in zip(decoded, targets):
        target_seq = target.tolist()
        dist += levenshtein(pred_seq, target_seq)
        ref_len += len(target_seq)
    return dist, ref_len


def _plot_training_curves(accuracy_graph, trainloss_graph, validationsloss_graph):
    """Plot per-epoch training accuracy, training loss and validation loss side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = [
        (accuracy_graph, 'Training Accuracy'),
        (trainloss_graph, 'Training Loss'),
        (validationsloss_graph, 'Validation Loss'),
    ]
    for ax, (values, title) in zip(axes, panels):
        ax.plot(values, label=title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Value')
        ax.set_title(title)
        ax.grid(True)
        ax.legend()
    plt.show()


def train_ctc(model, dataloader, val_loader, num_epochs=num_epochs, lr=learning_rate, blank_id=43, patience=5, clip_norm=5.0):
    """Train ``model`` with CTC loss, validating after each epoch with early stopping.

    Args:
        model: module mapping (batch, frames, feat) features to per-frame log-probs.
        dataloader / val_loader: yield (features, labels, feat_lens, label_lens)
            batches as produced by ``collate_fn``.
        num_epochs: maximum number of epochs.
        lr: peak learning rate (reached after a 500-step linear warmup).
        blank_id: CTC blank class index.
        patience: epochs without validation-loss improvement before stopping.
        clip_norm: gradient-norm clipping threshold.

    Returns:
        The model with the weights from the best validation-loss epoch restored.
    """
    model = model.to(device)
    criterion = nn.CTCLoss(blank=blank_id, zero_infinity=True)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # NOTE(review): during the first 500 steps the warmup overwrites lr every step, so
    # a plateau reduction fired in that window would be undone.
    warmup_scheduler = WarmupScheduler(optimizer, warmup_steps=500, base_lr=lr)
    plateau_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=2, min_lr=1e-6)

    best_val_loss = float('inf')
    epochs_without_improvement = 0
    best_state = None
    accuracy_graph = []
    trainloss_graph = []
    validationsloss_graph = []

    for epoch in range(num_epochs):
        # ---- TRAINING ----
        model.train()
        running_loss = running_ctc = running_ent = 0.0
        train_dist = train_len = 0

        for features, labels, feat_lens, label_lens in dataloader:
            features, labels = features.to(device), labels.to(device)
            feat_lens, label_lens = feat_lens.to(device), label_lens.to(device)
            log_probs = model(features).permute(1, 0, 2)  # CTCLoss expects (T, B, C)
            ctc_loss = criterion(log_probs, labels, feat_lens, label_lens)

            # Mean per-frame entropy over non-padded frames. Subtracting a small
            # multiple of it from the loss rewards less peaky output distributions.
            entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)  # (T, B)
            mask = (
                torch.arange(entropy.size(0), device=feat_lens.device).unsqueeze(1)
                < feat_lens.unsqueeze(0)
            )
            entropy_reg = (entropy * mask).sum() / mask.sum()
            loss = ctc_loss - 0.0001 * entropy_reg

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            optimizer.step()
            warmup_scheduler.step()

            running_ctc += ctc_loss.item()
            running_ent += entropy_reg.item()
            running_loss += loss.item()

            # Accuracy (edit distance), from this step's pre-update outputs
            with torch.no_grad():
                dist, ref_len = _batch_edit_stats(log_probs, labels, feat_lens, label_lens, blank_id)
            train_dist += dist
            train_len += ref_len

        # "Accuracy" is 100 - PER and can go negative when insertions dominate.
        acc = 100 * (1 - train_dist / train_len) if train_len > 0 else 0.0
        train_per = 100 * (train_dist / train_len) if train_len > 0 else 0.0
        avg_train_loss = running_loss / len(dataloader)
        avg_ctc_loss = running_ctc / len(dataloader)
        avg_ent_loss = running_ent / len(dataloader)
        accuracy_graph.append(acc)
        trainloss_graph.append(avg_train_loss)
        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, CTC Loss: {avg_ctc_loss:.4f}, Entropy Loss: {avg_ent_loss:.4f} - Accuracy: {acc:.2f}%, PER: {train_per:.2f}%")

        # ---- VALIDATION ----
        # Fresh counters, so the validation PER covers validation data only.
        model.eval()
        val_loss = 0.0
        val_dist = val_len = 0
        with torch.no_grad():
            for features, labels, feat_lens, label_lens in val_loader:
                features, labels = features.to(device), labels.to(device)
                feat_lens, label_lens = feat_lens.to(device), label_lens.to(device)
                log_probs = model(features).permute(1, 0, 2)
                val_loss += criterion(log_probs, labels, feat_lens, label_lens).item()

                dist, ref_len = _batch_edit_stats(log_probs, labels, feat_lens, label_lens, blank_id)
                val_dist += dist
                val_len += ref_len

        val_per = 100 * (val_dist / val_len) if val_len > 0 else 0.0
        avg_val_loss = val_loss / len(val_loader)
        validationsloss_graph.append(avg_val_loss)
        print(f"Epoch {epoch+1} - Val Loss: {avg_val_loss:.4f} - PER: {val_per:.2f}%")

        plateau_scheduler.step(avg_val_loss)
        print(f"Current Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # ---- EARLY STOPPING ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_without_improvement = 0
            # state_dict() returns references to the live parameters. Clone them, or
            # later optimizer steps overwrite the "best" snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s).")
            if epochs_without_improvement >= patience:
                print("Early stopping triggered.")
                break

    # ---- Load Best Model ----
    if best_state is not None:
        model.load_state_dict(best_state)

    _plot_training_curves(accuracy_graph, trainloss_graph, validationsloss_graph)

    return model

In [None]:
# Assign each extracted .npz file to train/val/test using precomputed split CSVs
# (per the original note, the splits were built from the phoneme distribution over
# all audio files). Files are matched on their "<book>-<utterance>" key.

# Map a filename to its split key
def extract_book_utt_key(filename):
    """Return the "<book>-<utterance>" key of a dash-separated filename.

    The key is the third and fourth dash-separated fields. A trailing .wav, .npz or
    .TextGrid extension (any case) is removed first, so CSV entries and .npz files
    yield the same key whether or not they carry an extension.

    Raises:
        ValueError: if the name has fewer than four dash-separated fields.
    """
    stem = re.sub(r"\.(wav|npz|textgrid)$", "", filename, flags=re.IGNORECASE)
    parts = stem.split("-")
    if len(parts) < 4:
        raise ValueError(f"Expected at least 4 dash-separated fields in {filename!r}")
    book = parts[2]  # book number
    utt = parts[3]   # utterance number
    return f"{book}-{utt}"

# Load original speaker splits as sets of book-utterance keys (O(1) membership tests).
train_keys = set(pd.read_csv("/content/train_files.csv")["Filename"].apply(extract_book_utt_key))
val_keys   = set(pd.read_csv("/content/val_files.csv")["Filename"].apply(extract_book_utt_key))
test_keys  = set(pd.read_csv("/content/test_files.csv")["Filename"].apply(extract_book_utt_key))

overlapping_keys = (train_keys & val_keys) | (train_keys & test_keys) | (val_keys & test_keys)
if overlapping_keys:
    print(f"WARNING: {len(overlapping_keys)} keys appear in more than one split; "
          "each file goes to the first match (train, then val, then test).")

npz_dir = "/content/train"
train_npz, val_npz, test_npz = [], [], []
unassigned = 0

# sorted() makes the file order, and therefore dataset indices, deterministic.
for fname in sorted(os.listdir(npz_dir)):
    if not fname.endswith(".npz"):
        continue
    if len(fname.split("-")) < 4:
        unassigned += 1
        continue
    key = extract_book_utt_key(fname)
    if key in train_keys:
        train_npz.append(fname)
    elif key in val_keys:
        val_npz.append(fname)
    elif key in test_keys:
        test_npz.append(fname)
    else:
        unassigned += 1

# Create full file paths
train_files = [os.path.join(npz_dir, f) for f in train_npz]
val_files   = [os.path.join(npz_dir, f) for f in val_npz]
test_files  = [os.path.join(npz_dir, f) for f in test_npz]

train_dataset = PhonemeDataset(train_files)
test_dataset = PhonemeDataset(test_files)
val_dataset = PhonemeDataset(val_files)

print(f"Train files: {len(train_files)}")
print(f"Test files: {len(test_files)}")
print(f"Val files: {len(val_files)}")
if unassigned:
    print(f"WARNING: {unassigned} .npz files matched no split and were left out.")

for filez in sorted(val_files):
    print(filez)

In [None]:
# Only the training loader shuffles; val/test keep file order for repeatable evaluation.
_loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Pass every architecture hyperparameter explicitly. BiLSTM_CTC's num_layers/dropout
# defaults are captured when the class cell runs, so later edits to the hyperparameter
# cell would otherwise be silently ignored.
model = BiLSTM_CTC(input_dim, hidden_dim, output_dim, num_layers=num_layers, dropout=dropout)

In [None]:
# Pass epochs and learning rate explicitly. train_ctc's defaults are captured when its
# cell runs, so later edits to the hyperparameter cell would otherwise be ignored.
trained_model = train_ctc(model, train_loader, val_loader, num_epochs=num_epochs, lr=learning_rate)

## Test Loop

In [None]:
# alignment with stress-aware matching
def strip_stress(p):
    """Return phoneme label ``p`` with every decimal digit (the stress marker) removed."""
    return "".join(ch for ch in p if not ch.isdecimal())

def align_sequences(pred_seq, ref_seq):
    """Levenshtein-align a predicted phoneme sequence against a reference.

    Substitutions, insertions and deletions each cost 1. A substitution between two
    labels that differ only in their digits (see strip_stress) also costs 1, but is
    labelled 'stress' instead of 'sub'.

    Args:
        pred_seq: predicted phoneme label strings.
        ref_seq: reference phoneme label strings.

    Returns:
        List of (ref, pred, op) tuples in sequence order. op is one of:
        - 'ok', 'stress' or 'sub': both ref and pred are set.
        - 'del': pred is None (reference item missing from the prediction).
        - 'ins': ref is None (extra predicted item).
    """
    n, m = len(ref_seq), len(pred_seq)
    # dp[i][j] = minimum edit cost aligning ref_seq[:i] with pred_seq[:j]
    dp = np.zeros((n + 1, m + 1))
    backtrace = [[None]*(m + 1) for _ in range(n + 1)]

    # Initialize: column 0 = delete all reference items, row 0 = insert all predictions
    for i in range(n + 1):
        dp[i][0] = i
        backtrace[i][0] = 'del'
    for j in range(m + 1):
        dp[0][j] = j
        backtrace[0][j] = 'ins'
    backtrace[0][0] = None

    # Fill DP table
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref_seq[i-1] == pred_seq[j-1]:
                cost = 0
                op = 'ok'
            elif strip_stress(ref_seq[i-1]) == strip_stress(pred_seq[j-1]):
                cost = 1
                op = 'stress'
            else:
                cost = 1
                op = 'sub'

            options = [
                (dp[i-1][j-1] + cost, op),
                (dp[i-1][j] + 1, 'del'),
                (dp[i][j-1] + 1, 'ins'),
            ]
            # min() returns the first option among equal costs, so ties prefer the
            # diagonal move, then 'del', then 'ins'. This order shapes the alignment.
            dp[i][j], backtrace[i][j] = min(options, key=lambda x: x[0])

    # Backtrace from the bottom-right cell to (0, 0), collecting operations in reverse
    i, j = n, m
    alignment = []
    while i > 0 or j > 0:
        op = backtrace[i][j]
        if op == 'ok' or op == 'stress' or op == 'sub':
            alignment.append((ref_seq[i-1], pred_seq[j-1], op))
            i -= 1
            j -= 1
        elif op == 'del':
            alignment.append((ref_seq[i-1], None, 'del'))
            i -= 1
        elif op == 'ins':
            alignment.append((None, pred_seq[j-1], 'ins'))
            j -= 1

    alignment.reverse()
    return alignment

def visualize_alignment(ref_seq, pred_seq, phoneme_to_letter=None):
    """Print a three-row, column-aligned view of a reference/prediction alignment.

    The rows are REF, PRED and one mark per aligned position:
    ✅ match, 🔄 substitution, ➕ insertion, ➖ deletion, ⚠ stress-only mismatch,
    ? anything else. Each cell is left-justified to 5 characters, so longer labels
    (or emoji rendered double-width) can shift the columns.

    Args:
        ref_seq, pred_seq: phoneme label sequences, aligned with align_sequences.
        phoneme_to_letter: optional mapping applied to labels before printing;
            labels it does not contain are printed unchanged.
    """
    marks = {"ok": "✅", "sub": "🔄", "ins": "➕", "del": "➖", "stress": "⚠"}
    empty_cell = " " * 5
    ref_cells, pred_cells, mark_cells = [], [], []

    for ref, pred, op in align_sequences(pred_seq, ref_seq):
        if phoneme_to_letter:
            if ref is not None:
                ref = phoneme_to_letter.get(ref, ref)
            if pred is not None:
                pred = phoneme_to_letter.get(pred, pred)

        ref_cells.append(empty_cell if ref is None else f"{ref:<5}")
        pred_cells.append(empty_cell if pred is None else f"{pred:<5}")
        mark_cells.append(f"{marks.get(op, '?'):<5}")

    print("REF : " + "".join(ref_cells))
    print("PRED: " + "".join(pred_cells))
    print("      " + "".join(mark_cells))

In [None]:
def _collect_eval_alignments(model, dataloader, blank_id):
    """Decode every batch and gather edit-distance totals plus aligned label pairs.

    Returns:
        (total_dist, total_len, all_refs, all_preds):
        - total_dist / total_len: summed Levenshtein distance and reference length.
        - all_refs / all_preds: parallel lists of phoneme strings at aligned positions.
          Insertions and deletions have no counterpart and are left out.
    """
    id2phoneme = {v: k for k, v in phoneme_to_id.items()}
    total_dist = total_len = 0
    all_refs, all_preds = [], []

    with torch.no_grad():
        for features, labels, feat_lens, label_lens in dataloader:
            log_probs = model(features.to(device)).permute(1, 0, 2)  # (T, B, C)
            ref_splits = torch.split(labels, label_lens.tolist())

            for b, (ref_seq, n_frames) in enumerate(zip(ref_splits, feat_lens.tolist())):
                # Decode only this utterance's valid frames. Decoding the batch padding
                # too could add spurious insertions.
                pred_ids = ctc_greedy_decode(log_probs[:n_frames, b:b + 1], blank=blank_id)[0]
                ref_ids = ref_seq.tolist()
                total_dist += levenshtein(pred_ids, ref_ids)
                total_len += len(ref_ids)

                ref_str = [id2phoneme[i] for i in ref_ids]
                pred_str = [id2phoneme[i] for i in pred_ids]
                for ref, pred, _op in align_sequences(pred_str, ref_str):
                    if ref is not None and pred is not None:
                        all_refs.append(ref)
                        all_preds.append(pred)

    return total_dist, total_len, all_refs, all_preds


def _save_phoneme_reports(y_true, y_pred, phonemes, save_dir):
    """Write per-phoneme CSV reports and figures to ``save_dir``.

    ``y_true`` / ``y_pred`` hold indices into ``phonemes`` (sorted phoneme labels).
    """
    # Classification report as dict and dataframe
    report_dict = classification_report(y_true, y_pred, target_names=phonemes, zero_division=0, output_dict=True)
    report_df = pd.DataFrame(report_dict).transpose()
    # report_df also has summary rows ("accuracy", "macro avg", "weighted avg");
    # phoneme_df keeps only the per-phoneme rows.
    phoneme_df = report_df.loc[phonemes]
    report_df.to_csv(f"{save_dir}/phoneme_classification_report.csv")
    print(f"Saved classification report to {save_dir}/phoneme_classification_report.csv")

    # F1 score bar plot (height scales with the number of phonemes, with a floor)
    sorted_scores = phoneme_df["f1-score"].sort_values(ascending=True)
    fig, ax = plt.subplots(figsize=(10, max(2.0, len(sorted_scores) * 0.2)))
    sns.barplot(x=sorted_scores.values, y=sorted_scores.index, palette="viridis", ax=ax)
    ax.set_xlabel("F1 Score")
    ax.set_title("Phoneme F1 Scores")
    fig.tight_layout()
    fig.savefig(f"{save_dir}/f1_score_plot.png")
    print(f"Saved F1 score plot to {save_dir}/f1_score_plot.png")
    plt.close(fig)

    # Confusion matrix; explicit labels keep it square and ordered like `phonemes`
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(phonemes))))
    cm_df = pd.DataFrame(cm, index=phonemes, columns=phonemes)
    fig, ax = plt.subplots(figsize=(15, 15))
    ConfusionMatrixDisplay(cm, display_labels=phonemes).plot(xticks_rotation=90, ax=ax, colorbar=True)
    ax.set_title("Phoneme Confusion Matrix")
    fig.tight_layout()
    fig.savefig(f"{save_dir}/confusion_matrix.png")
    print(f"Saved confusion matrix image to {save_dir}/confusion_matrix.png")
    plt.close(fig)

    # Top Confused Phoneme Pairs (off-diagonal cells only)
    confusions = cm_df.stack().reset_index()
    confusions.columns = ['True', 'Pred', 'Count']
    confusions = confusions[confusions['True'] != confusions['Pred']]
    top_confusions = confusions.sort_values('Count', ascending=False).head(10)
    print("Top Confused Phoneme Pairs:\n", top_confusions)
    top_confusions.to_csv(f"{save_dir}/top_confusions.csv", index=False)
    print(f"Saved top confused phoneme pairs to {save_dir}/top_confusions.csv")

    # Least Predicted Phonemes. Use per-phoneme rows only, so summary rows such as
    # "accuracy" (whose support entry is not a phoneme count) cannot appear.
    least_predicted = phoneme_df.sort_values("support").head(5)
    least_predicted.to_csv(f"{save_dir}/least_predicted_phonemes.csv")
    print(f"Saved least predicted phonemes to {save_dir}/least_predicted_phonemes.csv")

    # Top-10 most frequent phonemes
    top10 = phoneme_df["support"].sort_values(ascending=False).head(10).index.tolist()
    cm_top10 = cm_df.loc[top10, top10]
    fig, ax = plt.subplots(figsize=(8, 8))
    sns.heatmap(
        cm_top10,
        annot=True, fmt="d", cmap="Blues",
        xticklabels=top10, yticklabels=top10,
        cbar_kws={"shrink": 0.5},
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix (Top-10 Frequent Phonemes)")
    fig.tight_layout()
    fig.savefig(f"{save_dir}/confusion_matrix_top10.png")
    plt.close(fig)

    # 5 best / 5 worst by F1
    f1_per_phoneme = phoneme_df["f1-score"]
    f1_subset = pd.concat([
        f1_per_phoneme.sort_values(ascending=False).head(5),
        f1_per_phoneme.sort_values(ascending=True).head(5),
    ])
    fig, ax = plt.subplots(figsize=(6, 6))
    sns.barplot(x=f1_subset.values, y=f1_subset.index, orient="h", ax=ax)
    ax.set_xlabel("F1 Score")
    ax.set_ylabel("Phoneme")
    ax.set_title("Top 5 & Bottom 5 Phoneme F1 Scores")
    fig.tight_layout()
    fig.savefig(f"{save_dir}/f1_scores_best5_worst5.png")
    plt.close(fig)

    # False positive/negative stats
    diagonal = np.diag(cm)
    error_stats = pd.DataFrame({
        "False Positives": cm_df.sum(axis=0) - diagonal,  # per predicted label
        "False Negatives": cm_df.sum(axis=1) - diagonal,  # per reference label
        "Support": cm_df.sum(axis=1),
    })
    error_stats.to_csv(f"{save_dir}/phoneme_false_pos_neg.csv")
    print(f"Saved phoneme false positive/negative stats to {save_dir}/phoneme_false_pos_neg.csv")


def evaluate_ctc(model, dataloader, blank_id=43, save_dir="/content/results"):
    """Evaluate a CTC model: corpus PER plus per-phoneme reports and figures.

    Steps:
    1. Greedily decode every utterance.
    2. Compute the corpus phoneme error rate (summed Levenshtein distance / summed
       reference length).
    3. Align each prediction with its reference (align_sequences).
    4. Write classification reports, confusion matrices and F1 plots for the aligned
       pairs to ``save_dir``.

    Args:
        model: trained model, expected to already be on ``device``.
        dataloader: yields (features, labels, feat_lens, label_lens) as from collate_fn.
        blank_id: CTC blank index.
        save_dir: output directory for CSVs and figures (created if missing).

    Returns:
        Accuracy in percent, i.e. 100 - PER. It is negative when insertions dominate.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    total_dist, total_len, all_refs, all_preds = _collect_eval_alignments(model, dataloader, blank_id)

    # Filter out <blank> pairs for metric evaluation
    valid_pairs = [(p, r) for p, r in zip(all_preds, all_refs) if p != "<BLANK>" and r != "<BLANK>"]

    if valid_pairs:
        filtered_preds = [p for p, _ in valid_pairs]
        filtered_refs = [r for _, r in valid_pairs]

        # Create label set and encode to indices
        phonemes = sorted(set(filtered_preds) | set(filtered_refs))
        phoneme_to_idx = {p: i for i, p in enumerate(phonemes)}
        y_pred = [phoneme_to_idx[p] for p in filtered_preds]
        y_true = [phoneme_to_idx[r] for r in filtered_refs]

        _save_phoneme_reports(y_true, y_pred, phonemes, save_dir)

        precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
        recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        accuracy = accuracy_score(y_true, y_pred)  # fraction in [0, 1]

        print("Test Metrics:")
        print(f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1 Score: {macro_f1:.4f} | Accuracy sklearn: {accuracy * 100:.2f}%")
        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, target_names=phonemes, zero_division=0))
    else:
        print("No aligned phoneme pairs; per-phoneme metrics skipped.")

    acc = 100 * (1 - total_dist / total_len) if total_len > 0 else 0.0
    print(f"\nAccuracy: {acc:.2f}%")
    print(f"PER: {100 - acc:.2f}%")
    return acc

evaluate_ctc(trained_model, test_loader)

In [None]:
# Save to Drive before the runtime is released. The pickled module needs the
# BiLSTM_CTC class definition to load; the state_dict saved beside it is the portable
# copy. The later weights cell writes to ephemeral local disk, after the runtime release.
model_path = "/content/drive/path/to/model.pt"
torch.save(trained_model, model_path)
torch.save(trained_model.state_dict(), os.path.splitext(model_path)[0] + "_weights.pt")

In [None]:
# Release this Colab runtime (per the Colab API). Run it last: anything not saved to
# Drive is lost, and cells below will not have a runtime to execute on.
from google.colab import runtime as colab_runtime
colab_runtime.unassign()

In [None]:
# Saving weights (state_dict only) to the runtime's local working directory.
# NOTE(review): this cell follows runtime.unassign(); run it before that cell and copy
# the file off the VM if it needs to be kept.
weights_path = "bilstm_ctc_model_weights.pt"
torch.save(trained_model.state_dict(), weights_path)

In [None]:
# Free unreachable Python objects, then release PyTorch's cached, unused CUDA memory.
collected_objects = gc.collect()
torch.cuda.empty_cache()