# 1- Installing & Importing Necessary Packages & Wav2Vec Model:

In [None]:
!pip install torch torchaudio transformers datasets

In [None]:
import torch
from torch import nn
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, concatenate_datasets
import soundfile as sf
import torchaudio

In [None]:
# Load the processor and the base Wav2Vec2 model (no CTC head) from the same
# pretrained checkpoint, "othrif/wav2vec2-large-xlsr-arabic".
processor = Wav2Vec2Processor.from_pretrained("othrif/wav2vec2-large-xlsr-arabic")
wav2vec_model = Wav2Vec2Model.from_pretrained("othrif/wav2vec2-large-xlsr-arabic")

In [None]:
from huggingface_hub import notebook_login

notebook_login()

# 2- Downloading & Processing Datasets:

## Downloading & Processing Common Voice 13 Dataset:

In [None]:
# Load the "ar" configuration of Common Voice 13.0; the "train+validation+test"
# split expression merges the three splits into a single dataset.
common_voice = load_dataset("mozilla-foundation/common_voice_13_0", "ar", split="train+validation+test")

In [None]:
print(common_voice)

In [None]:
# Drop the metadata columns listed below (this includes the "audio" column),
# then print the dataset summary to check which columns remain.
common_voice = common_voice.remove_columns(["client_id", "audio", "up_votes", "down_votes", "age", "gender", "accent", "locale", "segment", "variant"])

print(common_voice)

In [None]:
print(common_voice[0])

## Downloading & Processing Google Fleurs Dataset:

In [None]:
# Load the train split of the FLEURS "ar_eg" configuration.
fleurs_train = load_dataset("google/fleurs", "ar_eg", split="train")

In [None]:
# Load the test split of the FLEURS "ar_eg" configuration.
fleurs_test = load_dataset("google/fleurs", "ar_eg", split="test")

In [None]:
# Load the validation split of the FLEURS "ar_eg" configuration.
fleurs_val = load_dataset("google/fleurs", "ar_eg", split="validation")

In [None]:
print(len(fleurs_train))
print(len(fleurs_test))
print(len(fleurs_val))

In [None]:
print(fleurs_train)

In [None]:
def _insert_split_dir(data_item, split_dir):
  """Shared implementation of the ``update_audio_path_*`` map functions.

  Inserts ``split_dir`` as an extra directory component immediately before
  the last component of ``data_item["path"]`` and copies
  ``data_item["transcription"]`` into ``data_item["sentence"]``.

  Args:
    data_item: Mapping with at least the "path" and "transcription" keys.
      It is modified in place.
    split_dir: Directory name to insert in front of the file name.

  Returns:
    The same ``data_item`` object, after the update.
  """
  parts = data_item["path"].split('/')
  # insert(-1, ...) puts the new component just before the final one; a path
  # without any '/' therefore becomes "<split_dir>/<name>".
  parts.insert(-1, split_dir)
  data_item["path"] = '/'.join(parts)
  data_item["sentence"] = data_item["transcription"]
  return data_item


def update_audio_path_train(data_item):
  """Insert a "train" directory into the item's path and set "sentence"."""
  return _insert_split_dir(data_item, "train")


def update_audio_path_test(data_item):
  """Insert a "test" directory into the item's path and set "sentence"."""
  return _insert_split_dir(data_item, "test")


def update_audio_path_val(data_item):
  """Insert a "dev" directory into the item's path and set "sentence"."""
  return _insert_split_dir(data_item, "dev")

In [None]:
# Apply update_audio_path_train to every train example: it rewrites "path"
# and adds the "sentence" field.
fleurs_train = fleurs_train.map(update_audio_path_train)

In [None]:
# Apply update_audio_path_test to every test example: it rewrites "path"
# and adds the "sentence" field.
fleurs_test = fleurs_test.map(update_audio_path_test)

In [None]:
# Apply update_audio_path_val to every validation example: it rewrites "path"
# and adds the "sentence" field.
fleurs_val = fleurs_val.map(update_audio_path_val)

In [None]:
# Remove the listed columns from the train split, including "audio" and the
# original "transcription" (already copied into "sentence" by the map step).
fleurs_train = fleurs_train.remove_columns(["id", "num_samples", "audio", "transcription", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"])

In [None]:
# Remove the listed columns from the test split, including "audio" and the
# original "transcription" (already copied into "sentence" by the map step).
fleurs_test = fleurs_test.remove_columns(["id", "num_samples", "audio", "transcription", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"])

In [None]:
# Remove the listed columns from the validation split, including "audio" and
# the original "transcription" (already copied into "sentence" by the map step).
fleurs_val = fleurs_val.remove_columns(["id", "num_samples", "audio", "transcription", "raw_transcription", "gender", "lang_id", "language", "lang_group_id"])

In [None]:
print(fleurs_train[0])
print(fleurs_test[0])
print(fleurs_val[0])

# 3- Combining Datasets & Creating The Vocabulary:

## Creating The Combined Dataset For Training & The Vocabulary:

In [None]:
# Training data: Common Voice plus the FLEURS train and test splits.
# fleurs_val is deliberately left out of this concatenation.
combined_dataset = concatenate_datasets([common_voice, fleurs_train, fleurs_test])

In [None]:
print(len(common_voice))
print(len(fleurs_train))
print(len(fleurs_test))
print(len(fleurs_val))
print(len(combined_dataset))

In [None]:
print(combined_dataset)

In [None]:
print(combined_dataset[0])

In [None]:
vocabulary = ['ا', 'ب', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش', 'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ف', 'ق', 'ك', 'ل', 'م', 'ن', 'ه', 'و', 'ي', 'ء', 'آ', 'أ', 'إ', 'ؤ', 'ئ', 'ة', 'ى', 'ﻻ', 'ﻷ', 'ﻹ', 'ﻵ',' ', '.']

## Processing The Training & Validation Datasets:

In [None]:
def process_transcriptions(data_item):
  """Keep only the characters of the transcription that are in `vocabulary`.

  Args:
    data_item: Mapping with a "sentence" string; it is modified in place.

  Returns:
    The same ``data_item`` with "sentence" replaced by the filtered string.
    Character order is preserved; characters not in `vocabulary` are dropped.
  """
  kept_chars = []
  for symbol in data_item["sentence"]:
    if symbol in vocabulary:
      kept_chars.append(symbol)
  data_item["sentence"] = ''.join(kept_chars)
  return data_item

In [None]:
# Strip every character that is not in `vocabulary` from the training transcriptions.
combined_dataset = combined_dataset.map(process_transcriptions)

In [None]:
print(combined_dataset[0])

In [None]:
# Apply the same transcription filtering to the validation split.
fleurs_val = fleurs_val.map(process_transcriptions)

# 4- Creating The Dataset Class:

In [None]:
class CustomDataset(Dataset):
  """Torch dataset that loads audio from disk and tokenizes its transcription.

  Each row of ``dataset`` must expose a "path" (audio file location) and a
  "sentence" (transcription text). ``processor`` is called on the waveform
  and its ``tokenizer.encode`` is used for the labels.
  """

  def __init__(self, dataset, processor):
    """Store the row source and the processor used for features and labels."""
    self.dataset = dataset
    self.processor = processor

  def __len__(self):
    """Return the number of rows in the wrapped dataset."""
    return len(self.dataset)

  def __getitem__(self, idx):
    """Load, resample and featurize one example.

    Returns:
      A tuple ``(input_values, labels, input_length)`` where ``input_length``
      is ``input_values.shape[1]`` and ``labels`` is a 1-D tensor of token ids.
    """
    # Fetch the row once instead of indexing the dataset for every field.
    item = self.dataset[idx]

    audio_input, sampling_rate = torchaudio.load(item["path"])

    # Down-mix multi-channel audio to mono, as process_audio_file does at
    # inference time; single-channel audio is left untouched.
    if audio_input.shape[0] > 1:
      audio_input = torch.mean(audio_input, dim=0, keepdim=True)

    if sampling_rate != 16000:
      resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
      audio_input = resampler(audio_input)

    input_values = self.processor(audio_input, sampling_rate=16000, return_tensors="pt").input_values
    # Drop the leading dimension added by the processor.
    # NOTE(review): presumably this leaves a (1, num_samples) tensor - confirm.
    input_values = input_values.squeeze(0)
    input_length = input_values.shape[1]

    # Build the label tensor once; wrapping an existing tensor in
    # torch.tensor() again only makes a copy and emits a UserWarning.
    labels = torch.tensor(self.processor.tokenizer.encode(item["sentence"]))

    return input_values, labels, input_length

In [None]:
# Wrap the combined training data and the FLEURS validation split in
# CustomDataset, sharing the same processor.
custom_data_set = CustomDataset(combined_dataset, processor)
custom_val_data_set = CustomDataset(fleurs_val, processor)

# 5- Model Architecture:

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [None]:
import torch.nn.functional as F

In [None]:
class CustomSTTModel(nn.Module):
  """Frozen Wav2Vec2 encoder followed by an LSTM, self-attention and a linear head.

  Args:
    wav2vec_model: Encoder module. It must expose ``config.hidden_size`` and
      ``config.vocab_size`` and return an object with ``last_hidden_state``
      when called on the input values.
    lstm_hidden_size: Hidden size of the LSTM; also the attention embed dim
      and the input size of the output layer.
    lstm_layers: Number of stacked LSTM layers.
    attention_heads: Number of heads of the self-attention layer.
  """

  def __init__(self, wav2vec_model, lstm_hidden_size, lstm_layers, attention_heads):
    super().__init__()
    self.wav2vec = wav2vec_model

    # The encoder is used as a fixed feature extractor: none of its
    # parameters receive gradients.
    for param in self.wav2vec.parameters():
      param.requires_grad = False

    feature_size = self.wav2vec.config.hidden_size

    self.lstm = nn.LSTM(input_size=feature_size,
                        hidden_size=lstm_hidden_size,
                        num_layers=lstm_layers,
                        batch_first=True)

    self.attention = nn.MultiheadAttention(embed_dim=lstm_hidden_size,
                                            num_heads=attention_heads,
                                            batch_first=True)

    self.output_layer = nn.Linear(lstm_hidden_size, wav2vec_model.config.vocab_size)

  def forward(self, input_values, input_lengths):
    """Compute per-frame vocabulary logits.

    Args:
      input_values: Batch of padded inputs passed straight to the encoder.
      input_lengths: Per-example input lengths. Kept for interface
        compatibility; it is not used, because every sequence is processed at
        the full padded frame count (see below).

    Returns:
      Tensor of shape (batch, frames, vocab_size). Row ``i`` corresponds to
      row ``i`` of ``input_values``.
    """
    # Keep the frozen encoder in eval mode even while the rest of the model
    # trains, and skip building a graph for it.
    self.wav2vec.eval()
    with torch.no_grad():
      wav2vec_output = self.wav2vec(input_values).last_hidden_state

    # Every sequence is packed at the full padded frame count, so no sorting
    # by length is required. The previous version reordered the batch by
    # `input_lengths.sort(descending=True)` and never restored the original
    # order, which could misalign the outputs with the labels whenever the
    # sort was not the identity permutation.
    # NOTE(review): frames produced from padding are treated as real frames
    # by the LSTM and the attention layer - confirm this is intended.
    batch_size, num_frames = wav2vec_output.shape[0], wav2vec_output.shape[1]
    processed_lengths = torch.full((batch_size,), num_frames, dtype=torch.int64)

    # pack_padded_sequence expects the lengths on the CPU.
    packed_input = pack_padded_sequence(wav2vec_output, processed_lengths.cpu(), batch_first=True)

    if input_values.is_cuda:
      self.lstm.flatten_parameters()

    packed_lstm_output, _ = self.lstm(packed_input)
    lstm_output, _ = pad_packed_sequence(packed_lstm_output, batch_first=True)

    # Self-attention: the LSTM output is used as query, key and value.
    attention_output, _ = self.attention(lstm_output, lstm_output, lstm_output)

    output = self.output_layer(attention_output)
    return output

In [None]:
# Build the model on top of the pretrained wav2vec_model: a 2-layer LSTM with
# hidden size 128 and 4 attention heads.
custom_model = CustomSTTModel(wav2vec_model, lstm_hidden_size=128, lstm_layers=2, attention_heads=4)

# 6- Training Phase:

## Creating The Collate Function & Initializing Train & Validation Loaders:

In [None]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
  """Merge a list of (input_values, labels, input_length) samples into a batch.

  Samples are ordered by decreasing ``input_length`` and both inputs and
  labels are right-padded with zeros to the longest item in the batch.

  Returns:
    ``(input_values_padded, labels_padded, input_lengths)`` where
    ``input_lengths`` is a long tensor holding ``shape[1]`` of every
    un-padded input, in the sorted order.
  """
  ordered = sorted(batch, key=lambda sample: sample[2], reverse=True)
  waveforms, targets, _ = zip(*ordered)

  # Each input is squeezed before padding so it stacks as (batch, samples).
  squeezed = [waveform.squeeze() for waveform in waveforms]
  input_values_padded = pad_sequence(squeezed, batch_first=True)

  # NOTE(review): labels are padded with 0 and their original lengths are not
  # returned, so callers cannot tell padding from a real token id 0.
  labels_padded = pad_sequence(targets, batch_first=True)

  # Lengths are re-derived from the tensors rather than taken from the samples.
  lengths = [waveform.shape[1] for waveform in waveforms]
  input_lengths = torch.tensor(lengths, dtype=torch.long)

  return input_values_padded, labels_padded, input_lengths

In [None]:
# Both loaders use batches of 8 and collate_fn for padding.
# NOTE(review): the validation loader is shuffled as well; shuffling is not
# needed to average a validation loss - confirm it is intentional.
batch_size = 8
train_loader = DataLoader(custom_data_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
validation_loader = DataLoader(custom_val_data_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

## Creating The Validation Loop:

In [None]:
def validate(model, data_loader, criterion, device):
  """Return the mean per-batch loss of ``model`` over ``data_loader``.

  The model is switched to eval mode and no gradients are tracked.

  Args:
    model: Callable as ``model(input_values, input_lengths)`` returning
      logits of shape (batch, frames, classes).
    data_loader: Iterable of ``(input_values, labels, input_lengths)`` batches
      that also supports ``len()``.
    criterion: Loss called as
      ``criterion(log_probs, labels, output_lengths, label_lengths)``.
    device: Device the inputs and labels are moved to.

  Returns:
    Sum of the batch losses divided by the number of batches.
  """
  model.eval()
  running_loss = 0
  with torch.no_grad():
    for input_values, labels, input_lengths in data_loader:
      input_values = input_values.to(device)
      labels = labels.to(device)
      logits = model(input_values, input_lengths.cpu())

      # Every sequence is given the full output length.
      num_sequences, num_frames = logits.shape[0], logits.shape[1]
      output_lengths = torch.full((num_sequences,), num_frames, dtype=torch.int64)

      # The criterion receives log-probabilities laid out as (frames, batch, classes).
      log_probs = torch.nn.functional.log_softmax(logits, dim=2).permute(1, 0, 2)

      # NOTE(review): labels arrive already padded, so each length here is the
      # padded width rather than the true transcription length - confirm.
      label_lengths = torch.tensor([len(row) for row in labels], dtype=torch.long, device=device)

      running_loss += criterion(log_probs, labels, output_lengths, label_lengths).item()

  return running_loss / len(data_loader)

## Creating The Training Loop:

In [None]:
# Lowest average training loss seen so far; updated by train() through `global`.
best_loss = 10000

def train(model, data_loader, val_loader, criterion, optimizer, epochs, device):
  """Train ``model`` for ``epochs`` epochs, validating after each one.

  Side effects:
    - prints the loss of every batch and the validation loss of every epoch;
    - saves "model_state_dict_epoch_{epoch}.pth" on every even epoch index;
    - saves "best_model_state_dict.pth" whenever the average training loss
      improves on the module-level ``best_loss`` (which is then updated).

  NOTE(review): the "best" checkpoint is selected on the training loss, not
  on the validation loss that is computed right before it - confirm.
  """
  global best_loss
  model.to(device)

  for epoch in range(epochs):
    model.train()
    epoch_loss = 0

    for batch_idx, (input_values, labels, input_lengths) in enumerate(data_loader):
      input_values = input_values.to(device)
      labels = labels.to(device)

      # The lengths stay on the CPU; passing them on the GPU produced errors.
      output = model(input_values, input_lengths.cpu())

      # Every sequence is given the full output length.
      output_lengths = torch.full((output.shape[0],), output.shape[1], dtype=torch.int64)

      # CTC takes log-probabilities laid out as (frames, batch, classes).
      log_probs = torch.nn.functional.log_softmax(output, dim=2).permute(1, 0, 2)

      # NOTE(review): labels arrive already padded, so each length here is the
      # padded width rather than the true transcription length - confirm.
      label_lengths = torch.tensor([len(row) for row in labels], dtype=torch.long, device=device)

      loss = criterion(log_probs, labels, output_lengths, label_lengths)
      epoch_loss += loss.item()

      optimizer.zero_grad()
      loss.backward()
      # Clip the gradient norm before the optimizer step.
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
      optimizer.step()

      print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}/{len(data_loader)}], Loss: {loss.item():.4f}")

    avg_loss = epoch_loss / len(data_loader)

    avg_val_loss = validate(model, val_loader, criterion, device)
    print("validation loss: ", avg_val_loss)

    # Periodic checkpoint on even epoch indices (0, 2, 4, ...).
    if epoch % 2 == 0:
      torch.save(model.state_dict(), f"model_state_dict_epoch_{epoch}.pth")

    # Best-so-far checkpoint, keyed on the average training loss.
    if avg_loss < best_loss:
      best_loss = avg_loss
      torch.save(model.state_dict(), "best_model_state_dict.pth")

## Training & Validation:

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from torch import nn
import torch.optim as optim

# CTC loss whose blank index is the tokenizer's pad token id.
criterion = nn.CTCLoss(blank=processor.tokenizer.pad_token_id).to(device)

In [None]:
# Adam with lr=0.001 over all parameters of custom_model; the wav2vec
# parameters had requires_grad set to False in CustomSTTModel.__init__.
optimizer = optim.Adam(custom_model.parameters(), lr=0.001)

In [None]:
# Train for 10 epochs; train() validates on validation_loader after each
# epoch and writes checkpoints to the working directory.
epochs = 10
train(custom_model, train_loader, validation_loader, criterion, optimizer, epochs, device)

# 7- Inference Phase:

## Loading Model Weights:

In [None]:
# Rebuild the architecture with the same hyperparameters used for training,
# then restore trained weights into it.
model = CustomSTTModel(wav2vec_model, lstm_hidden_size=128, lstm_layers=2, attention_heads=4)
# train() only writes "best_model_state_dict.pth" and the per-epoch
# "model_state_dict_epoch_{epoch}.pth" files, so load the best checkpoint.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("best_model_state_dict.pth", map_location=device))
model.to(device)
model.eval()

## Selecting Specific Model Outputs to Decode:

In [None]:
def select_outputs(outputs, blank_label):
  """Greedy decoding: best class per frame, repeats collapsed, blanks dropped.

  Args:
    outputs: Tensor of shape (batch, frames, classes) with per-frame scores.
    blank_label: Class index that must never be emitted.

  Returns:
    One list of class indices (Python ints) per batch element. A frame is
    emitted only if it is not the blank and differs from the frame right
    before it, so a repeated class separated by a blank is emitted twice.
  """
  best_paths = torch.argmax(outputs, dim=2).tolist()
  decodes = []
  for path in best_paths:
    collapsed = []
    previous = None
    for position, token in enumerate(path):
      repeats_previous = position != 0 and previous == token
      if token != blank_label and not repeats_previous:
        collapsed.append(token)
      previous = token
    decodes.append(collapsed)
  return decodes

## Inference:

In [None]:
from torch.nn.functional import log_softmax

In [None]:
def predict(model, processor, input_values, input_length, device):
  """Transcribe a batch of inputs with greedy decoding.

  Args:
    model: Called as ``model(input_values, input_length)``; returns logits
      with the classes on the last (third) dimension.
    processor: Provides ``tokenizer.pad_token_id`` (used as the blank label)
      and ``decode`` to turn token ids into text.
    input_values: Input tensor; moved to ``device`` before the forward pass.
    input_length: Passed to the model unchanged.
    device: Device on which inference runs.

  Returns:
    A list with one decoded string per batch element.
  """
  model.eval()
  input_values = input_values.to(device)

  with torch.no_grad():
    log_probs = log_softmax(model(input_values, input_length), dim=2)
    token_sequences = select_outputs(log_probs, blank_label=processor.tokenizer.pad_token_id)

    decoded_text = []
    for token_ids in token_sequences:
      decoded_text.append(processor.decode(token_ids))

  return decoded_text

In [None]:
def process_audio_file(file_path, processor, target_sample_rate=16000):
  """Load an audio file and turn it into model inputs.

  Args:
    file_path: Path of the audio file to load with torchaudio.
    processor: Callable feature extractor; its ``input_values`` output is used.
    target_sample_rate: Sample rate the audio is resampled to when the file
      uses a different one.

  Returns:
    ``(input_values, input_length)`` where ``input_length`` is a 1-element
    long tensor holding ``input_values.shape[1]``.
  """
  audio_input, sampling_rate = torchaudio.load(file_path)

  # Down-mix any multi-channel recording (not only stereo) to a single
  # channel by averaging over the channel dimension.
  if audio_input.shape[0] > 1:
    audio_input = torch.mean(audio_input, dim=0, keepdim=True)

  if sampling_rate != target_sample_rate:
    resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sample_rate)
    audio_input = resampler(audio_input)

  input_values = processor(audio_input, sampling_rate=target_sample_rate, return_tensors="pt").input_values
  # Drop the leading dimension added by the processor.
  # NOTE(review): presumably this leaves a (1, num_samples) tensor - confirm.
  input_values = input_values.squeeze(0)
  input_length = torch.tensor([input_values.shape[1]], dtype=torch.long)
  return input_values, input_length

In [None]:
# Preprocess the local file "r2.wav" into model inputs and its length tensor.
file_path = "r2.wav"
input_values, input_length = process_audio_file(file_path, processor)

In [None]:
print(input_values.shape)
print(input_length)

In [None]:
# Run inference with `model`, the instance whose weights were restored from
# the checkpoint in the "Loading Model Weights" cell, rather than with the
# in-memory training instance.
predictions = predict(model, processor, input_values, input_length, device)
print(predictions)