<a href="https://colab.research.google.com/github/Moses05/CardiovascularDiseaseWebApp/blob/main/lingala_english_asr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

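# Table of Contents

1. Clone the dataset repository and define the data paths
2. Check that all audio is sampled at 16 kHz
3. Build the train/validation datasets from the manifest files
4. Load the feature extractor and the character vocabulary
5. Preprocess the audio and encode transcriptions as label ids
6. Define the wav2vec 2.0 + CTC model, the loss, and the optimizer
7. Train the model
8. Evaluate word error rate (WER) on the validation set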

In [None]:
# Fetch the dataset repository; later cells read audio and manifests from `lingala-english-asr/`.
!git clone https://github.com/Moses05/lingala-english-asr.git

Cloning into 'lingala-english-asr'...
remote: Enumerating objects: 2981, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 2981 (delta 0), reused 3 (delta 0), pack-reused 2975 (from 1)[K
Receiving objects: 100% (2981/2981), 432.52 MiB | 25.77 MiB/s, done.
Resolving deltas: 100% (11/11), done.


In [None]:
import torch
# Release cached, currently unused GPU memory held by PyTorch before the heavier cells run.
torch.cuda.empty_cache()

In [None]:
# Root of the Lingala data inside the cloned repository.
base_path = "lingala-english-asr/LRSC/lingala"

# Training split: audio directory and transcript file.
train_audio_path = f"{base_path}/train/audio"
train_transcript_path = f"{base_path}/train/transcript.txt"

# Validation split: audio directory and transcript file.
valid_audio_path = f"{base_path}/valid/audio"
valid_transcript_path = f"{base_path}/valid/transcript.txt"

In [None]:
# Sanity check: show the resolved training-audio directory.
print(train_audio_path)

In [None]:
import os
import wave

def check_sample_rate(dir):
  """Return the paths of WAV files in ``dir`` whose sample rate is not 16 kHz.

  Args:
    dir: Directory to scan (non-recursive). Every regular file directly inside
      it is opened with the ``wave`` module; sub-directories are ignored.

  Returns:
    A list of path strings (``f"{dir}/{name}"``), sorted by file name, one per
    file whose frame rate differs from 16000 Hz.

  Raises:
    wave.Error: If a regular file in ``dir`` cannot be parsed as a WAV file.
  """
  wrongFramerate = []

  # Sort so the report is reproducible; os.listdir returns names in arbitrary order.
  for wav in sorted(os.listdir(dir)):
    path = f"{dir}/{wav}"

    if os.path.isfile(path):
      with wave.open(path, "rb") as wav_file:
        if wav_file.getframerate() != 16000:
          # Record the path, not the Wave_read handle: the handle is closed as
          # soon as the `with` block exits and only prints as an object repr.
          wrongFramerate.append(path)

  return wrongFramerate


# Scan both splits (train first, then valid) and report every file that is not 16 kHz.
train_wrongSample, valid_wrongSample = (
    check_sample_rate(audio_dir) for audio_dir in (train_audio_path, valid_audio_path)
)

print(f"list of train audio files not 16000hz: {train_wrongSample}")
print(f"list of valid audio files not 16000hz: {valid_wrongSample}")

In [None]:
# Directory holding the manifest files for both splits.
manifest_path = f"{base_path}/manifest"

# Vocabulary file; it is parsed by `load_vocab` further down.
dict_txt = f"{manifest_path}/dict.ltr.txt"

# Training split: .ltr is read as the transcription text and .tsv as
# "<audio path>\t<duration>" rows; the .wrd path is defined but not read in this notebook.
train_letter = f"{manifest_path}/train.ltr"
train_tsv = f"{manifest_path}/train.tsv"
train_word = f"{manifest_path}/train.wrd"

# Validation split: same three files as for training.
valid_letter = f"{manifest_path}/valid.ltr"
valid_tsv = f"{manifest_path}/valid.tsv"
valid_word = f"{manifest_path}/valid.wrd"

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset, Dataset
import os

def load_manifest_data(tsv_file, ltr_file, audio_path):
    """Build a ``Dataset`` from a manifest ``.tsv`` and its line-aligned ``.ltr`` file.

    The first line of ``tsv_file`` is skipped as a header. Every following line
    must hold two tab-separated fields, ``<relative path>`` and ``<duration>``.
    Line ``i`` of ``ltr_file`` is taken as the transcription of data line ``i``
    of the tsv, so a tsv line that is skipped as malformed also drops the
    transcription at the same position and the remaining rows stay aligned.

    NOTE(review): the duration field is divided by 1000 as in the original
    code; fairseq-style manifests store a sample count in that column, so
    confirm the unit before relying on ``duration``.

    Args:
        tsv_file: Path to the manifest ``.tsv`` file.
        ltr_file: Path to the transcription file, one utterance per line.
        audio_path: Directory prepended to every relative audio path.

    Returns:
        A ``Dataset`` with the columns ``path`` (joined audio path),
        ``duration`` (integer field / 1000) and ``text`` (stripped
        transcription line).

    Raises:
        ValueError: If a valid tsv row has no transcription line, if
            non-empty transcription lines are left over after the last tsv
            row, or if a duration field is not an integer.
    """
    data = {
        "path": [],
        "duration": [],
        "text": [],
    }

    # Read both files up front so rows can be paired by position.
    # Explicit UTF-8: the transcriptions may contain non-ASCII letters.
    with open(tsv_file, 'r', encoding="utf-8") as tsv_f:
        lines = tsv_f.readlines()[1:]  # Skip header
    with open(ltr_file, 'r', encoding="utf-8") as ltr_f:
        transcriptions = ltr_f.readlines()

    for row_idx, line in enumerate(lines):
        parts = line.strip().split("\t")

        # Ensure two columns: path and duration
        if len(parts) != 2:
            print(f"skipping malformed line: {line}")
            continue

        if row_idx >= len(transcriptions):
            raise ValueError(
                f"{ltr_file} has no transcription for data line {row_idx + 1} of {tsv_file}"
            )

        path, duration = parts
        full_path = os.path.join(audio_path, path)  # prepend base path

        data["path"].append(full_path)
        data["duration"].append(int(duration) / 1000)
        # Append the transcription only for rows that were kept.
        data["text"].append(transcriptions[row_idx].strip())

    # Trailing blank lines are harmless; real leftover text means the two
    # files do not describe the same utterances.
    leftovers = [t for t in transcriptions[len(lines):] if t.strip()]
    if leftovers:
        raise ValueError(
            f"{ltr_file} has {len(leftovers)} transcription line(s) with no row in {tsv_file}"
        )

    return Dataset.from_dict(data)

# Build one Dataset per split from its manifest, transcription file and audio folder.
train_dataset, valid_dataset = (
    load_manifest_data(tsv, ltr, audio_dir)
    for tsv, ltr, audio_dir in (
        (train_tsv, train_letter, train_audio_path),
        (valid_tsv, valid_letter, valid_audio_path),
    )
)

In [None]:
# @title Default title text
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import torch
import torchaudio

# Load the feature extractor from the same checkpoint as the encoder that is
# fine-tuned later in this notebook ("facebook/wav2vec2-large-xlsr-53"), so the
# audio is preprocessed with the configuration that encoder was trained with.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")

# No encoder is loaded here: the "facebook/wav2vec2-base" model that used to be
# bound to `model` in this cell was never used before `model` is rebound to the
# CTC model, so loading it only cost download time and memory.

In [None]:
def load_vocab(vocab_path):
    """Load a symbol table from a text file and make sure it has a ``<blank>`` entry.

    Each non-empty line of ``vocab_path`` must hold exactly two
    whitespace-separated fields: a symbol and an integer. Blank lines are
    ignored.

    NOTE(review): the integer is used directly as the symbol's class index.
    fairseq-style ``dict.ltr.txt`` files store ``<symbol> <count>`` instead --
    confirm that the second column of this file really is an index.

    Args:
        vocab_path: Path to the vocabulary text file (read as UTF-8).

    Returns:
        A ``(char_to_index, index_to_char)`` pair of dicts. If the file does
        not define ``<blank>``, it is added at the smallest non-negative index
        that no other symbol uses (``len(vocab)`` for a contiguous 0-based
        file, ``0`` for a 1-based one), so it never overwrites a real symbol.

    Raises:
        ValueError: If a non-empty line does not have exactly two fields or
            its second field is not an integer.
    """
    char_to_index = {}
    index_to_char = {}

    with open(vocab_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing newline
            if len(fields) != 2:
                raise ValueError(
                    f"{vocab_path}:{line_no}: expected '<symbol> <index>', got {line!r}"
                )
            char, index = fields[0], int(fields[1])
            char_to_index[char] = index
            index_to_char[index] = char

    # Add the <blank> token if not already present
    if "<blank>" not in char_to_index:
        # Pick the first free index instead of len(vocab): with non-contiguous
        # or 1-based indices, len(vocab) can already belong to a symbol.
        blank_index = 0
        while blank_index in index_to_char:
            blank_index += 1
        char_to_index["<blank>"] = blank_index
        index_to_char[blank_index] = "<blank>"

    return char_to_index, index_to_char

# Load the vocabulary once (the file was previously parsed twice in a row).
# `vocab_size` is later used as the number of output classes of the CTC head.
char_to_index, index_to_char = load_vocab(dict_txt)
vocab_size = len(char_to_index)

print(f"Loaded vocabulary with {vocab_size} characters")

In [None]:
def text_to_ids(text, char_to_index):
    """Map each character of ``text`` to its id, silently skipping unknown characters.

    Args:
        text: String to encode; it is walked one character at a time.
        char_to_index: Mapping from character to integer id.

    Returns:
        List of ids, in order, for the characters that appear in ``char_to_index``.
    """
    ids = []
    for symbol in text:
        # Characters missing from the table (e.g. separators) produce no id.
        if symbol in char_to_index:
            ids.append(char_to_index[symbol])
    return ids

def add_labels(batch):
    """Attach a ``"labels"`` entry (ids of ``batch["text"]``) to ``batch`` and return it.

    The encoding uses the module-level ``char_to_index`` table via ``text_to_ids``.
    ``batch`` is modified in place.
    """
    label_ids = text_to_ids(batch["text"], char_to_index)
    batch["labels"] = label_ids
    return batch

In [None]:
import torchaudio

def preprocess_audio(batch):
  """Load one audio file and turn it into 1-D ``input_values`` for the model.

  Steps: read ``batch["path"]`` with torchaudio, downmix to mono, resample to
  16 kHz when the file uses another rate, peak-normalize, then run the
  module-level ``feature_extractor``.

  Args:
    batch: A single example (mapping) with a ``"path"`` entry.

  Returns:
    The same ``batch`` object with an added ``"input_values"`` entry holding a
    1-D NumPy array.
  """
  waveform, sample_rate = torchaudio.load(batch["path"])

  # torchaudio.load returns (channels, frames) by default. Average the channels
  # so multi-channel files become one mono signal; a plain squeeze() would leave
  # them 2-D and only the first channel would survive the extractor.
  if waveform.dim() > 1:
    waveform = waveform.mean(dim=0)

  if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

  # Peak-normalize, but leave silent (all-zero) or empty clips untouched:
  # dividing by a zero peak would fill the signal with NaN.
  if waveform.numel() > 0:
    peak = waveform.abs().max()
    if peak > 0:
      waveform = waveform / peak

  inputs = feature_extractor(
      waveform.numpy(),
      sampling_rate=16000,
      return_tensors="pt",
  )

  batch["input_values"] = inputs.input_values[0].numpy()
  return batch

# Decode every file into `input_values` and drop the "path" and "duration" columns.
# NOTE: this rebinds the dataset names; `preprocess_audio` reads the "path" column,
# so re-running this cell on the already-mapped datasets is expected to fail.
train_dataset = train_dataset.map(preprocess_audio, remove_columns=["path", "duration"])
valid_dataset = valid_dataset.map(preprocess_audio, remove_columns=["path", "duration"])

In [None]:
# Show the dataset summaries after audio preprocessing.
print(train_dataset)
print(valid_dataset)

In [None]:
# Add integer labels to both datasets and drop the raw "text" column,
# then print the dataset summaries again.
train_dataset = train_dataset.map(add_labels, remove_columns=["text"])
valid_dataset = valid_dataset.map(add_labels, remove_columns=["text"])

print(train_dataset)
print(valid_dataset)

In [None]:
def truncate_audio(sample, max_length=80000):
  """Cap ``sample["input_values"]`` at ``max_length`` items and return ``sample``.

  Only ``"input_values"`` is shortened; every other entry of ``sample``
  (including any labels) is left as it is. Samples that already fit are
  returned untouched.
  """
  values = sample["input_values"]
  if len(values) <= max_length:
    return sample

  sample["input_values"] = values[:max_length]
  return sample

# Cap every example at the default `max_length` of `truncate_audio` (80000 values).
train_dataset = train_dataset.map(truncate_audio)
valid_dataset = valid_dataset.map(truncate_audio)

In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
  """Pad a list of examples into one batch of tensors.

  Args:
    batch: List of examples, each with ``"input_values"`` and ``"labels"``
      sequences.

  Returns:
    Dict with ``"input_values"`` (right-padded with 0.0) and ``"labels"``
    (right-padded with -100), both batch-first.
  """
  def _stack(key, fill):
    # Convert each example's sequence to a tensor, then right-pad to the longest one.
    sequences = [torch.tensor(example[key]) for example in batch]
    return pad_sequence(sequences, batch_first=True, padding_value=fill)

  return {
      "input_values": _stack("input_values", 0.0),
      "labels": _stack("labels", -100),
  }

In [None]:
from torch.utils.data import DataLoader

# One example per batch for both loaders; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=1, collate_fn=collate_fn)

In [None]:
# Pull a single batch to check the shape of the padded inputs, then stop iterating.
for batch in train_loader:
    input_values = batch["input_values"]
    print(f"Batch Shape:", input_values.shape)
    break

In [None]:
# Release cached, currently unused GPU memory before the model is created.
torch.cuda.empty_cache()

In [None]:
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model

class Wav2Vec2CTC(nn.Module):
  """Encoder followed by a linear layer that produces per-frame CTC logits.

  Args:
    model: Encoder module. It must expose ``config.hidden_size`` and, when
      called, return an object with a ``last_hidden_state`` attribute.
    vocab_size: Number of output classes of the linear head.
  """

  def __init__(self, model, vocab_size):
    super().__init__()
    # Despite its name, this attribute holds the whole encoder passed in as
    # `model`; the name is kept so existing state_dict keys stay valid.
    self.feature_extractor = model
    self.ctc_head = nn.Linear(self.feature_extractor.config.hidden_size, vocab_size)

  def forward(self, input_values, attention_mask=None):
    """Return logits for every encoder frame.

    Args:
      input_values: Batch of input signals passed to the encoder.
      attention_mask: Optional mask marking real (non-padded) input positions.
        It is forwarded to the encoder only when given, so existing
        ``model(input_values)`` calls behave exactly as before.

    Returns:
      The linear head applied to the encoder's ``last_hidden_state``; the last
      dimension has size ``vocab_size``.
    """
    if attention_mask is None:
      encoder_out = self.feature_extractor(input_values)
    else:
      encoder_out = self.feature_extractor(input_values, attention_mask=attention_mask)

    return self.ctc_head(encoder_out.last_hidden_state)

# Load the pretrained encoder that will be fine-tuned.
wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")

# Wrap it with a linear CTC head sized to the vocabulary.
model = Wav2Vec2CTC(wav2vec2_model, vocab_size)

# Use the GPU when available. `model.to(device)` is the last expression of the
# cell, so the notebook also displays the module's repr.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
import torch.optim as optim

# The CTC blank class is the vocabulary's "<blank>" entry.
blank_index = char_to_index["<blank>"]
print(f"Blank index: {blank_index}")

# zero_infinity=True makes CTCLoss replace infinite losses (and their gradients) with zero.
criterion = nn.CTCLoss(blank=blank_index, zero_infinity=True)
# Adam over all parameters of `model` (encoder and CTC head).
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Release cached, currently unused GPU memory before training starts.
torch.cuda.empty_cache()

In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    """Train ``model`` with a CTC criterion and print the mean batch loss per epoch.

    Args:
        model: Module that maps ``input_values`` to logits of shape
            (batch, seq_len, vocab_size).
        train_loader: Sized iterable of batches, each a dict with
            ``"input_values"`` and ``"labels"`` tensors; label padding is -100.
        criterion: Loss called as ``criterion(log_probs, labels, input_lengths,
            label_lengths)``.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over ``train_loader``.

    Batches are moved to the module-level ``device``. Nothing is returned.
    """
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0

        for step_batch in train_loader:
            features = step_batch["input_values"].to(device)
            targets = step_batch["labels"].to(device)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            # Log-probabilities rearranged to (seq_len, batch, vocab_size).
            log_probs = model(features).log_softmax(2).permute(1, 0, 2)
            seq_len, batch_size = log_probs.size(0), log_probs.size(1)

            # Every item is given the full output length; label lengths count
            # the entries that are not the -100 padding value.
            input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)
            target_lengths = (targets != -100).sum(dim=1)

            step_loss = criterion(log_probs, targets, input_lengths, target_lengths)
            step_loss.backward()
            optimizer.step()

            total_loss += step_loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

# Run the training with the default number of epochs (`num_epochs=5` in `train_model`).
train_model(model, train_loader, criterion, optimizer)

In [None]:
!pip install evaluate jiwer

In [None]:
from evaluate import load
import torch

# Word-error-rate metric from the `evaluate` package; `evaluate_model` reads this module-level name.
wer_metric = load("wer")

def _greedy_ctc_decode(frame_ids, index_to_char, blank_ids):
    """Greedy CTC decoding: collapse repeated frame ids, then drop blanks."""
    symbols = []
    previous = None
    for idx in frame_ids:
        if idx != previous and idx not in blank_ids and idx in index_to_char:
            symbols.append(index_to_char[idx])
        previous = idx
    return "".join(symbols)


def _to_words(text, word_delimiter):
    """Turn the word-delimiter symbol into spaces and normalize the spacing."""
    if word_delimiter:
        text = text.replace(word_delimiter, " ")
    return " ".join(text.split())


def evaluate_model(model, valid_loader, index_to_char, word_delimiter="|"):
    """Compute, print and return the mean per-batch WER on ``valid_loader``.

    Args:
        model: Module that maps ``input_values`` to logits of shape
            (batch, time, vocab).
        valid_loader: Iterable of batches, each a dict with ``"input_values"``
            and ``"labels"`` tensors. Label ids that are not keys of
            ``index_to_char`` (such as the -100 padding) are ignored.
        index_to_char: Mapping from integer class id to symbol. Ids mapped to
            ``"<blank>"`` are treated as the CTC blank.
        word_delimiter: Symbol converted to a space before WER is computed.
            NOTE(review): the default assumes ``|`` marks word boundaries, as
            in fairseq-style letter transcripts -- confirm for this dataset;
            pass ``None`` to keep the decoded strings unchanged.

    Returns:
        The average of the per-batch WER values, or ``float("inf")`` when no
        batch has a non-empty reference.

    Inputs are moved to the module-level ``device``; the metric is the
    module-level ``wer_metric``.
    """
    model.eval()
    blank_ids = {idx for idx, symbol in index_to_char.items() if symbol == "<blank>"}
    total_wer = 0
    num_batches = 0

    with torch.no_grad():
        for batch in valid_loader:
            inputs = batch["input_values"].to(device)

            # Forward pass through model to get logits
            logits = model(inputs)

            # Most likely class per frame: argmax over the vocab (last) axis.
            # Convert to plain ints so dict lookups by id work; tensor
            # elements never match integer dict keys.
            predicted_ids = torch.argmax(logits, dim=-1).tolist()
            label_ids = batch["labels"].tolist()

            pred_texts = [
                _to_words(_greedy_ctc_decode(row, index_to_char, blank_ids), word_delimiter)
                for row in predicted_ids
            ]
            label_texts = [
                _to_words("".join(index_to_char[i] for i in row if i in index_to_char), word_delimiter)
                for row in label_ids
            ]

            # WER is undefined for empty references; skip those items.
            non_empty_indices = [i for i, label in enumerate(label_texts) if label]

            if not non_empty_indices:
                continue

            pred_texts = [pred_texts[i] for i in non_empty_indices]
            label_texts = [label_texts[i] for i in non_empty_indices]

            wer = wer_metric.compute(predictions=pred_texts, references=label_texts)
            total_wer += wer
            num_batches += 1

    avg_wer = total_wer / num_batches if num_batches > 0 else float("inf")
    print(f"Validation WER: {avg_wer:.4f}")
    return avg_wer


# Report the word error rate of the trained model on the validation loader.
evaluate_model(model, valid_loader, index_to_char)