In [12]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)


In [None]:
!wget https://www.openslr.org/resources/12/train-clean-100.tar.gz -P ./
!tar -xf train-clean-100.tar.gz -C ./

!wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./
!tar -xf dev-clean.tar.gz -C ./

!wget https://www.openslr.org/resources/12/test-clean.tar.gz -P ./
!tar -xf test-clean.tar.gz -C ./



--2024-11-28 14:54:04--  https://www.openslr.org/resources/12/train-clean-100.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://openslr.elda.org/resources/12/train-clean-100.tar.gz [following]
--2024-11-28 14:54:05--  https://openslr.elda.org/resources/12/train-clean-100.tar.gz
Resolving openslr.elda.org (openslr.elda.org)... 141.94.109.138, 2001:41d0:203:ad8a::
Connecting to openslr.elda.org (openslr.elda.org)|141.94.109.138|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6387309499 (5.9G) [application/x-gzip]
Saving to: ‘./train-clean-100.tar.gz’


2024-11-28 14:57:40 (28.4 MB/s) - ‘./train-clean-100.tar.gz’ saved [6387309499/6387309499]

--2024-11-28 14:58:01--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.

In [None]:
# Specify the base dataset directory
base_dir = "/kaggle/working/LibriSpeech"

# Define specific dataset directories
train_dir = os.path.join(base_dir, "train-clean-100")
dev_dir = os.path.join(base_dir, "dev-clean")
test_dir = os.path.join(base_dir, "test-clean")

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import random
import csv

# Conformer Block
class ConformerBlock(nn.Module):
    def __init__(self, d_model, nhead, ff_dim, kernel_size, dropout=0.1):
        super().__init__()
        self.ff1 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
            nn.Dropout(dropout)
        )
        self.attention = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.conv_module = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2, groups=d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=1),
            nn.BatchNorm1d(d_model)
        )
        self.ff2 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, d_model),
            nn.Dropout(dropout)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_key_padding_mask=None):
        # First feedforward
        x = x + 0.5 * self.ff1(x)

        # Multi-head attention
        attn_output, _ = self.attention(x, x, x, key_padding_mask=src_key_padding_mask)
        x = x + attn_output

        # Convolutional module
        x_conv = x.permute(0, 2, 1)  # Convert to (B, C, T) for Conv1d
        x_conv = self.conv_module(x_conv)
        x_conv = x_conv.permute(0, 2, 1)  # Back to (B, T, C)
        x = x + x_conv

        # Second feedforward
        x = x + 0.5 * self.ff2(x)

        # Final normalization
        x = self.norm(x)
        return x

# Simple Conformer Model
class SimpleConformer(nn.Module):
    def __init__(self, nheads=1, d_model=768, num_layers=1, ff_dim=2048, kernel_size=15, dropout=0.1):
        super().__init__()
        bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
        self.sr = bundle.sample_rate
        self.labels = bundle.get_labels()
        self.num_chars = len(self.labels)
        self.ƒ = bundle.get_model().extract_features  # Only the feature extractor

        # Project input features to d_model for Conformer
        self.linear = nn.Linear(768, d_model)
        self.conformers = nn.ModuleList(
            [ConformerBlock(d_model, nheads, ff_dim, kernel_size, dropout) for _ in range(num_layers)]
        )

        # Final output layer to predict characters
        self.char_out = nn.Linear(d_model, self.num_chars)

    def forward(self, x):
        mask = self.create_mask(x)
        x = self.linear(x)  # Project to d_model dimension
        for conformer in self.conformers:
            x = conformer(x, src_key_padding_mask=mask)
        x = self.char_out(x)
        return F.log_softmax(x, dim=-1)

    def create_mask(self, x):
        # Create mask for padded positions (where first feature dim is zero)
        return (x[:, :, 0] == PAD_IDX)  # True where padding is applied 

PAD_IDX = 0  # Use zero for padding

# Data Augmentation Functions
def add_noise(waveform, noise_level=0.005):
    noise = torch.randn_like(waveform)
    augmented = waveform + noise_level * noise
    augmented = augmented.clamp(-1.0, 1.0)
    return augmented

def random_gain(waveform, min_gain=0.8, max_gain=1.2):
    gain = random.uniform(min_gain, max_gain)
    return waveform * gain

def time_stretch(waveform, rate=1.0):
    if rate == 1.0:
        return waveform
    stretched_waveform = torchaudio.transforms.Resample(orig_freq=decoder.sr, new_freq=int(decoder.sr * rate))(waveform)
    return stretched_waveform

def pitch_shift(waveform, sample_rate, n_steps=0):
    return torchaudio.transforms.PitchShift(sample_rate=sample_rate, n_steps=n_steps)(waveform)

# Custom Dataset with Data Augmentation
class DataSet(Dataset):
    def __init__(self, local_df, encoder, decoder, augment=False):
        self.df = local_df
        self.sr = decoder.sr
        self.decoder = decoder
        self.encoder = encoder
        self.augment = augment

    def __getitem__(self, index):
        dir = self.df.iloc[index].dir
        ids = self.df.iloc[index].ids
        labels = self.df.iloc[index].labels
        labels = labels.replace(" ", "|")  # Mark spaces with "|"

        # Convert characters to indices
        char_to_idx = {char: idx for idx, char in enumerate(self.decoder.labels)}
        ground_truth_token = torch.tensor([char_to_idx[c] for c in labels if c in char_to_idx])

        # Load audio and extract features
        waveform, sample_rate = torchaudio.load(dir)
        if sample_rate != self.decoder.sr:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.decoder.sr)

        if self.augment:
            waveform = self.apply_augmentation(waveform)

        with torch.no_grad():
            latent, _ = self.encoder(waveform)

        latent = latent[-1].squeeze()
        return latent, ground_truth_token

    def apply_augmentation(self, waveform):
        # Apply random augmentations
        if random.random() < 0.5:
            waveform = add_noise(waveform)
        if random.random() < 0.5:
            waveform = random_gain(waveform)
        if random.random() < 0.5:
            rate = random.uniform(0.9, 1.1)
            waveform = time_stretch(waveform, rate)
        if random.random() < 0.5:
            n_steps = random.randint(-2, 2)
            waveform = pitch_shift(waveform, self.sr, n_steps)
        return waveform

    def __len__(self):
        return len(self.df)

# Collate function to handle padding in batches
def collate_fn(batch):
    latents, labels = zip(*batch)
    latents_padded = pad_sequence(latents, batch_first=True, padding_value=PAD_IDX)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=PAD_IDX)
    return latents_padded, labels_padded

# Function to process dataset directories
def process_dataset(data_dir):
    flac_files = []
    labels_files = []
    for dir in os.listdir(data_dir):
        speaker_dir = os.path.join(data_dir, dir)
        if os.path.isdir(speaker_dir):
            for chapter in os.listdir(speaker_dir):
                chapter_dir = os.path.join(speaker_dir, chapter)
                if os.path.isdir(chapter_dir):
                    for file in os.listdir(chapter_dir):
                        if file.endswith(".flac"):
                            flac_files.append(os.path.join(chapter_dir, file))
                        if file.endswith(".txt"):
                            labels_files.append(os.path.join(chapter_dir, file))
    # Process labels and create dataframe
    dirs = []
    all_labels = []
    for labels in labels_files:
        with open(labels, "r") as file:
            lines = file.readlines()
        lines = [line.strip("\n") for line in lines]
        for _ in range(len(lines)):
            dirs.append(labels)
        all_labels.extend(lines)
    ids = [x.split(" ")[0] for x in all_labels]
    true_labels = [" ".join(x.split(" ")[1:]) for x in all_labels]
    flac_dict = {flac.split("/")[-1].split(".flac")[0]: flac for flac in flac_files}
    organized_flac = []
    for id in tqdm(ids):
        if id in flac_dict:
            organized_flac.append(flac_dict[id])
        else:
            print(f"ERROR: {id} not found in flac files.")
    df = pd.DataFrame()
    df["dir"] = organized_flac
    df["ids"] = ids
    df["labels"] = true_labels
    return df

# Initialize model to access sample rate for augmentation transforms
# This is necessary because some transforms like pitch_shift require sample rate
decoder_init = SimpleConformer(nheads=4, d_model=768, num_layers=6, ff_dim=2048, kernel_size=31, dropout=0.1)
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
decoder_init.sr = bundle.sample_rate


# Process train dataset
print("Processing train dataset...")
df_train = process_dataset(train_dir)

# Process validation dataset (dev-clean)
print("Processing validation dataset...")
df_valid = process_dataset(dev_dir)

# Process test dataset (test-clean)
print("Processing test dataset...")
df_test = process_dataset(test_dir)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
decoder = SimpleConformer(nheads=2, d_model=384, num_layers=3, ff_dim=1024, kernel_size=31, dropout=0.1)

# Create datasets and dataloaders with augmentation for training set
train_ds = DataSet(df_train, decoder_init.ƒ, decoder, augment=True)
valid_ds = DataSet(df_valid, decoder_init.ƒ, decoder, augment=False)
test_ds = DataSet(df_test, decoder_init.ƒ, decoder, augment=False)

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)
test_dl = DataLoader(test_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Optimizer and Loss function
optim = torch.optim.Adam(decoder.parameters(), lr=0.0001)
criterion = nn.CTCLoss(blank=0, zero_infinity=False)

# Number of epochs
num_epochs = 5

# Path to latest checkpoint
checkpoint_path = 'latest_model_checkpoint.pth'

# Load checkpoint if it exists
if os.path.exists(checkpoint_path):
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    decoder.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    train_losses = checkpoint['train_losses']
    valid_losses = checkpoint['valid_losses']
else:
    print("No checkpoint found, starting training from scratch.")
    start_epoch = 0
    train_losses = []
    valid_losses = []

# Training Loop with Prediction Output
decoder.to(device)

for epoch in range(start_epoch, num_epochs):
    print(f"Epoch {epoch+1}")
    # Training Phase
    decoder.train()
    train_running_loss = 0

    # Open CSV file for training predictions
    with open(f'train_predictions_epoch_{epoch+1}.csv', 'w', newline='', encoding='utf-8') as train_csvfile:
        train_csv_writer = csv.writer(train_csvfile)
        train_csv_writer.writerow(['Prediction', 'Ground Truth'])

        for latents, labels in tqdm(train_dl):
            optim.zero_grad()
            latents, labels = latents.to(device), labels.to(device)
            prediction = decoder(latents)
            prediction = prediction.permute(1, 0, 2)  # (T, N, C)

            # Compute input and target lengths
            input_lengths = torch.tensor([l[l[:, 0] != PAD_IDX].shape[0] for l in latents]).to(device)
            target_lengths = torch.tensor([len(lbl[lbl != PAD_IDX]) for lbl in labels], dtype=torch.long).to(device)

            loss = criterion(prediction, labels, input_lengths, target_lengths)
            loss.backward()
            optim.step()

            train_running_loss += loss.item()

            # Decode the predictions and ground truth
            max_indices = torch.argmax(prediction, dim=2)  # Get predicted characters as indices
            predicted_str = "".join([decoder.labels[idx] for idx in max_indices[:, 0].cpu().numpy() if idx != PAD_IDX])
            predicted_str = predicted_str.replace("-", "").replace("|", " ")  # Clean up padding and special tokens

            # Decode the ground truth for comparison
            gt_str = "".join([decoder.labels[idx.item()] for idx in labels[0] if idx != PAD_IDX])
            gt_str = gt_str.replace("-", "").replace("|", " ")

            print(f"Batch Prediction: {predicted_str}")
            print(f"Batch Ground Truth: {gt_str}")
            print("")
            train_csv_writer.writerow([predicted_str, gt_str])

    train_losses.append(train_running_loss / len(train_dl))

    # Validation Phase
    decoder.eval()
    valid_running_loss = 0

    # Open CSV file for validation predictions
    with open(f'valid_predictions_epoch_{epoch+1}.csv', 'w', newline='', encoding='utf-8') as valid_csvfile:
        valid_csv_writer = csv.writer(valid_csvfile)
        valid_csv_writer.writerow(['Prediction', 'Ground Truth'])

        with torch.no_grad():
            for latents, labels in tqdm(valid_dl):
                latents, labels = latents.to(device), labels.to(device)
                prediction = decoder(latents)
                prediction = prediction.permute(1, 0, 2)

                input_lengths = torch.tensor([l[l[:, 0] != PAD_IDX].shape[0] for l in latents]).to(device)
                target_lengths = torch.tensor([len(lbl[lbl != PAD_IDX]) for lbl in labels], dtype=torch.long).to(device)

                loss = criterion(prediction, labels, input_lengths, target_lengths)
                valid_running_loss += loss.item()

                # Decode the predictions and ground truth
                max_indices = torch.argmax(prediction, dim=2)  # (T, N)max_indices = torch.argmax(prediction, dim=2)  # Get predicted characters as indices
        predicted_str = "".join([decoder.labels[idx] for idx in max_indices[:, 0].cpu().numpy() if idx != PAD_IDX])
        predicted_str = predicted_str.replace("-", "").replace("|", " ")  # Clean up padding and special tokens

        # Decode the ground truth for comparison
        gt_str = "".join([decoder.labels[idx.item()] for idx in labels[0] if idx != PAD_IDX])
        gt_str = gt_str.replace("-", "").replace("|", " ")

        print(f"Batch Prediction: {predicted_str}")
        print(f"Batch Ground Truth: {gt_str}")
        print("")
        valid_csv_writer.writerow([predicted_str, gt_str])

    valid_losses.append(valid_running_loss / len(valid_dl))

    print(f"Epoch {epoch+1} Summary - Train Loss: {train_losses[-1]:.4f}, Valid Loss: {valid_losses[-1]:.4f}")

    # Save model checkpoint
    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": decoder.state_dict(),
        "train_losses": train_losses,
        "valid_losses": valid_losses
    }
    torch.save(checkpoint, f"{epoch+1}_model_checkpoint.pth")
    torch.save(checkpoint, checkpoint_path)  # Save latest checkpoint

print("Training completed!")



Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /home/nrelab-titan/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
100%|██████████| 360M/360M [00:05<00:00, 66.9MB/s] 


Processing train dataset...


FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/working/LibriSpeech/train-clean-100'

In [None]:
# Function to perform CTC decoding
def ctc_decode(predictions, blank_index=0):
    # predictions: list or numpy array of label indices
    output = []
    prev_label = None
    for l in predictions:
        if l != blank_index and l != prev_label:
            output.append(l)
        prev_label = l
    return output

# Function to compute edit distance
def edit_distance(seq1, seq2):
    # Initialize the matrix
    m = len(seq1)
    n = len(seq2)
    D = [[0]*(n+1) for _ in range(m+1)]
    for i in range(m+1):
        D[i][0] = i
    for j in range(n+1):
        D[0][j] = j
    # Compute the edit distance
    for i in range(1, m+1):
        for j in range(1, n+1):
            if seq1[i-1] == seq2[j-1]:
                cost = 0
            else:
                cost = 1
            D[i][j] = min(D[i-1][j] +1,      # Deletion
                          D[i][j-1] +1,      # Insertion
                          D[i-1][j-1] + cost)  # Substitution
    return D[m][n]



In [None]:
%pip install jiwer

In [None]:
%

In [None]:
from jiwer import wer

# Evaluate on Test Set
print("Evaluating on test dataset...")
decoder.eval()
test_running_loss = 0
total_distance = 0
total_words = 0

manual_pred_sentences = []
manual_gt_sentences = []

with torch.no_grad():
    for latents, labels in tqdm(test_dl):
        latents, labels = latents.to(device), labels.to(device)
        prediction = decoder(latents)
        prediction = prediction.permute(1, 0, 2)

        input_lengths = torch.tensor([l[l[:, 0] != PAD_IDX].shape[0] for l in latents]).to(device)
        target_lengths = torch.tensor([len(lbl[lbl != PAD_IDX]) for lbl in labels], dtype=torch.long).to(device)

        loss = criterion(prediction, labels, input_lengths, target_lengths)
        test_running_loss += loss.item()

        # Decode predictions and ground truths
        max_indices = torch.argmax(prediction, dim=2)
        for i in range(latents.size(0)):
            predicted_indices = max_indices[:, i].cpu().numpy()
            predicted_indices = ctc_decode(predicted_indices, blank_index=0)
            predicted_str = "".join([decoder.labels[idx] for idx in predicted_indices])
            predicted_str = predicted_str.replace("-", "").replace("|", " ").strip()

            gt_indices = labels[i]
            gt_str = "".join([decoder.labels[idx.item()] for idx in gt_indices if idx != PAD_IDX])
            gt_str = gt_str.replace("-", "").replace("|", " ").strip()

            # Collect sentences for `jiwer` calculation
            manual_pred_sentences.append(predicted_str)
            manual_gt_sentences.append(gt_str)

            # Compute manual WER
            pred_words = predicted_str.strip().split()
            gt_words = gt_str.strip().split()
            distance = edit_distance(pred_words, gt_words)
            total_distance += distance
            total_words += len(gt_words)

# Calculate manual WER
manual_wer_score = total_distance / total_words
print(f"Test Loss: {test_running_loss / len(test_dl):.4f}")
print(f"Manual WER: {manual_wer_score:.4f}")

# Calculate WER using jiwer
jiwer_wer_score = wer(manual_gt_sentences, manual_pred_sentences)
print(f"jiwer WER: {jiwer_wer_score:.4f}")
