In [None]:
import math
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import torchaudio
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torchaudio.transforms import Resample, Spectrogram
from transformers import Trainer, TrainingArguments, Wav2Vec2ForCTC, Wav2Vec2Processor

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

# Arabic tokenizer

In [None]:
# Output vocabulary: the 28 base Arabic letters followed by a space (word separator).
# Hamza/taa-marbuta/alef-maqsura variants and diacritics are not part of it.
arabic_characters = "ابتثجحخدذرزسشصضطظعغفقكلمنهوي "

# Character -> id, where the id is the character's position in `arabic_characters`.
char2id = dict(zip(arabic_characters, range(len(arabic_characters))))
# Inverse lookup (id -> character), used to turn predicted ids back into text.
id2char = dict(zip(char2id.values(), char2id.keys()))

In [None]:
print("char2id:", char2id)
print("id2char:", id2char)
# A duplicated character would collapse two positions into one char2id entry, and the
# round-trip check below would still pass. Check uniqueness explicitly.
assert len(char2id) == len(arabic_characters), "arabic_characters contains duplicate characters"
for char, idx in char2id.items():
    assert id2char[idx] == char, f"Mismatch in char2id and id2char: {char} -> {idx}"
print("Character mappings are correct.")

# Number of real output symbols (letters + space); ids 0 .. characters-1 are taken.
characters = len(arabic_characters)
print('number of arabic characters + spasse = ' , characters)

# Load a pretrained Wav2Vec2 model

In [None]:
feature_extractor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base").feature_extractor

# Wav2Vec2ForCTC uses config.pad_token_id as the CTC blank. Its default is 0, which is
# also the id of "ا" in char2id. Reserve one extra id (== characters, i.e. one past the
# last letter) for the blank so no real letter doubles as it. Decoders should read the
# blank id from model.config.pad_token_id.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    vocab_size=characters + 1,
    pad_token_id=characters,
    # Average the CTC loss over the batch instead of summing it (library default "sum").
    ctc_loss_reduction="mean",
)
# The convolutional feature encoder comes pretrained. Freezing it during CTC fine-tuning
# is the standard recipe and reduces memory use.
model.freeze_feature_encoder()

# Load the Arabic Common Voice dataset

In [None]:
def _load_common_voice_ar(split):
    """Load one slice of the Arabic Common Voice 11.0 dataset from the Hugging Face Hub.

    ``split`` uses the datasets slicing syntax. For example, "train[:50%]" is the first
    half of the train split (a contiguous slice, not a random sample).
    """
    # NOTE(review): trust_remote_code=True runs the dataset's loading script from the Hub.
    return load_dataset("mozilla-foundation/common_voice_11_0", "ar", split=split, trust_remote_code=True)


dataset_train = _load_common_voice_ar("train[:50%]")
dataset_test = _load_common_voice_ar("test[:2%]")
dataset_validation = _load_common_voice_ar("validation[:50%]")

In [None]:
# Display a summary of the loaded train slice (its repr lists the features and row count).
dataset_train

In [None]:
# Audio file path of training example 2; preprocess() loads the audio from this path.
dataset_train["path"][2]

In [None]:
# Index the row first. dataset_train[2] decodes only that clip, whereas
# dataset_train["audio"] materializes and decodes the audio column for every row.
dataset_train[2]["audio"]


In [None]:
# Sampling rate of one decoded clip. Row-first indexing avoids decoding the whole audio column.
dataset_train[2]['audio']['sampling_rate']

### The Arabic Common Voice dataset is sampled at 48 kHz,
### but the `facebook/wav2vec2-base` model expects 16 kHz input,
### so the audio must be resampled.

In [None]:
SOURCE_SAMPLING_RATE = 48_000  # rate of the Common Voice clips
TARGET_SAMPLING_RATE = 16_000  # rate expected by facebook/wav2vec2-base

# Build the common-case resampler once and reuse it for every clip.
resampler = Resample(SOURCE_SAMPLING_RATE, TARGET_SAMPLING_RATE)

# Letter variants that are not in `arabic_characters` are mapped onto their base letter so
# they survive label encoding. Without this, e.g. "مدرسة" would lose its final letter.
ARABIC_LETTER_VARIANTS = str.maketrans({
    "أ": "ا",
    "إ": "ا",
    "آ": "ا",
    "ٱ": "ا",
    "ى": "ي",
    "ئ": "ي",
    "ؤ": "و",
    "ة": "ه",
})


def preprocess(batch):
    """Turn one Common Voice example into wav2vec2 inputs and CTC labels.

    Adds two fields to ``batch`` and returns it:
      * ``input_values``: the clip downmixed to mono, resampled to 16 kHz from its
        actual sampling rate, and normalized by ``feature_extractor``.
      * ``labels``: int64 ``char2id`` ids of the normalized transcript. Characters
        still outside the vocabulary (diacritics, punctuation, Latin text) are skipped.
    """
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    # torchaudio returns (channels, frames); average channels so stereo clips become mono.
    speech_array = speech_array.mean(dim=0)
    if sampling_rate == SOURCE_SAMPLING_RATE:
        speech_array = resampler(speech_array)
    elif sampling_rate != TARGET_SAMPLING_RATE:
        # Fallback for clips not at 48 kHz: resample from their real rate.
        speech_array = torchaudio.functional.resample(speech_array, sampling_rate, TARGET_SAMPLING_RATE)

    batch["input_values"] = feature_extractor(
        speech_array.numpy(), sampling_rate=TARGET_SAMPLING_RATE, return_tensors="pt"
    ).input_values[0]

    sentence = batch["sentence"].translate(ARABIC_LETTER_VARIANTS)
    # Drop out-of-vocabulary characters, then collapse any whitespace runs that removed
    # punctuation leaves behind.
    sentence = "".join(char for char in sentence if char in char2id)
    sentence = " ".join(sentence.split())
    # Explicit dtype: an empty list would otherwise become a float tensor.
    batch["labels"] = torch.tensor([char2id[char] for char in sentence], dtype=torch.long)

    return batch

In [None]:
# Apply preprocess to every split, in train / test / validation order.
# datasets normally caches each mapped result on disk, so re-running is cheap.
dataset_train_map, dataset_test_map, dataset_validation_map = (
    split.map(preprocess) for split in (dataset_train, dataset_test, dataset_validation)
)

In [None]:
# Raw transcript of training example 2, for comparison with its encoded labels below.
dataset_train_map[2]['sentence']

In [None]:
# Character ids produced by preprocess() for the same example (characters absent from char2id are skipped).
dataset_train_map[2]['labels']

# Visualize the waveform and spectrogram of the input before and after resampling

In [None]:

spectrogram_transform = Spectrogram()


def plot_waveform(waveform, sampling_rate, title="Waveform"):
    """Plot a 1-D waveform against time in seconds.

    ``waveform`` may be a list, NumPy array or tensor. ``sampling_rate`` (Hz) converts
    sample indices to seconds, so clips recorded at different rates share a comparable
    x-axis.
    """
    samples = torch.as_tensor(waveform, dtype=torch.float32)
    seconds = torch.arange(samples.shape[-1]) / sampling_rate
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(seconds.numpy(), samples.numpy())
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.grid()
    plt.show()


def plot_spectrogram(waveform, sampling_rate, title="Spectrogram"):
    """Plot the spectrogram of a 1-D waveform in decibels.

    ``spectrogram_transform`` yields a power spectrogram (torchaudio's default
    ``power=2.0``). It is converted with 10*log10 so the colour bar's dB label is
    accurate, and clamped first so silent frames do not become -inf. Axes are scaled to
    seconds and Hz using ``sampling_rate``.
    """
    # float32 matches the transform's window dtype, whatever dtype the caller passed.
    samples = torch.as_tensor(waveform, dtype=torch.float32)
    power = spectrogram_transform(samples)
    power_db = 10 * torch.log10(power.clamp(min=1e-10))
    duration = samples.shape[-1] / sampling_rate
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        power_db.squeeze(0).numpy(),
        cmap="viridis",
        origin="lower",
        aspect="auto",
        # One-sided spectrum: frequency bins run from 0 Hz up to the Nyquist frequency.
        extent=[0, duration, 0, sampling_rate / 2],
    )
    fig.colorbar(image, ax=ax, format="%+2.0f dB")
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    plt.show()

In [None]:
SAMPLE_INDEX = 7  # training example shown before and after resampling

print('befor resampler sr=48k')
raw_sample = dataset_train[SAMPLE_INDEX]
raw_rate = raw_sample["audio"]["sampling_rate"]  # read from the data rather than assumed
plot_waveform(raw_sample["audio"]['array'], sampling_rate=raw_rate, title="Sample Audio Waveform")
plot_spectrogram(raw_sample["audio"]['array'], sampling_rate=raw_rate, title="Sample Audio Spectrogram")
print(raw_sample["sentence"])
print('')

# --- after resampling ---------------------------------------------------------------------

print('after resampler sr=16k')
resampled_sample = dataset_train_map[SAMPLE_INDEX]
plot_waveform(resampled_sample["input_values"], sampling_rate=16_000, title="Sample Audio Waveform")
# Plot input_values: the "audio" column survives .map() unchanged and still holds the
# original 48 kHz signal, so plotting it here would repeat the "before" spectrogram.
plot_spectrogram(resampled_sample["input_values"], sampling_rate=16_000, title="Sample Audio Spectrogram")
print(resampled_sample["sentence"])

# Analyze the output of the `facebook/wav2vec2-base` model before fine-tuning

In [None]:
def analyze_model_output(dataset):

    model.to(device)
    model.eval()

    batch = dataset[3]
    input_values = batch["input_values"]

    if isinstance(input_values, list):
        input_values = torch.tensor(input_values)

    input_values = input_values.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    decoded_output = "".join([id2char[id] for id in predicted_ids[0].cpu().numpy() if id in id2char])

    print("Decoded output:", decoded_output)
    print("Reference:", batch["sentence"])

In [None]:
# Transcribe one test example with the not-yet-fine-tuned model. The CTC head has not
# been trained at this point, so the output is presumably gibberish.
analyze_model_output(dataset_test_map)

# Data collator: pad variable-length sequences into batches

In [None]:
def data_collator(batch):
    """Collate preprocessed examples into a padded batch for Wav2Vec2ForCTC.

    Args:
        batch: list of examples, each with ``input_values`` (list/array/tensor of floats)
            and ``labels`` (list/array/tensor of character ids).

    Returns:
        dict with ``input_values`` (float32, right-padded with 0.0) and ``labels``
        (int64, right-padded with -100). Wav2Vec2ForCTC ignores -100 label positions in
        the CTC loss.

    Both fields are cast explicitly. An example with an empty transcript arrives as an
    empty list, which ``torch.tensor`` would turn into float32. That would leave the
    batch mixing float and int label tensors, and the CTC loss needs integer targets.
    """
    input_values = pad_sequence(
        [torch.as_tensor(item["input_values"], dtype=torch.float32) for item in batch],
        batch_first=True,
        padding_value=0.0,
    )
    labels = pad_sequence(
        [torch.as_tensor(item["labels"], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=-100,
    )
    return {"input_values": input_values, "labels": labels}

In [None]:
training_args = TrainingArguments(
    output_dir="./wav2vec2-finetuned-ar",
    eval_strategy="epoch",
    # Checkpoint once per epoch (aligned with evaluation) and keep only the two most recent.
    # The default saves every 500 steps with no limit, which over 25 epochs fills the disk.
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    num_train_epochs=25,
    warmup_steps=500,
    # Keep every dataset column; data_collator picks out input_values and labels itself.
    remove_unused_columns=False,
    run_name="arabic-asr-wav2vec2",
    report_to="none",
)

# Trainer moves the model to the training device itself, so no explicit model.to() here.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset_train_map,
    eval_dataset=dataset_validation_map,
    # NOTE(review): newer transformers releases rename this argument to `processing_class`.
    tokenizer=feature_extractor,
)

# Start training

In [None]:
# Fine-tune for num_train_epochs, evaluating on the validation split after each epoch.
trainer.train()

# CTC beam-search decoder

In [None]:

def _log_add(*log_values):
    """Return log(sum(exp(v))) over ``log_values`` without overflow; -inf if all are -inf."""
    peak = max(log_values)
    if peak == -math.inf:
        return -math.inf
    return peak + math.log(sum(math.exp(value - peak) for value in log_values))


def custom_beam_search_decoder(logits, beam_width=10, blank_id=0, id_to_char=None):
    """Decode one utterance's frame logits with CTC prefix beam search.

    Each beam is a label prefix, scored by the total probability of every alignment
    that collapses to it. That score is tracked separately for alignments ending in a
    blank and alignments ending in a label. This is what lets a repeated letter be
    emitted twice only when a blank separates the two occurrences. For example, "لل"
    in "الله" survives, while consecutive identical frames merge into one letter.

    Args:
        logits: (T, V) tensor of unnormalized scores, one row per frame.
        beam_width: number of prefixes kept after each frame.
        blank_id: id of the CTC blank; it must match the blank used in training
            (``model.config.pad_token_id`` for Wav2Vec2ForCTC).
        id_to_char: id -> character mapping, defaulting to the notebook's ``id2char``.
            Ids missing from it contribute no text.

    Returns:
        The highest-probability transcript.

    Raises:
        ValueError: if ``logits`` is not 2-D.
    """
    if logits.dim() != 2:
        raise ValueError(f"expected 2-D (frames, vocab) logits, got shape {tuple(logits.shape)}")
    if id_to_char is None:
        id_to_char = id2char

    # One log-softmax for all frames, converted to plain floats once (not per .item()).
    frame_log_probs = torch.log_softmax(logits.float(), dim=-1).tolist()

    # prefix (tuple of ids) -> (log P(prefix, ends in blank), log P(prefix, ends in a label))
    beams = {(): (0.0, -math.inf)}
    for log_probs in frame_log_probs:
        next_beams = defaultdict(lambda: (-math.inf, -math.inf))
        for prefix, (p_blank, p_label) in beams.items():
            p_prefix = _log_add(p_blank, p_label)
            last = prefix[-1] if prefix else None
            for token, p in enumerate(log_probs):
                if token == blank_id:
                    # A blank leaves the prefix unchanged.
                    b, nb = next_beams[prefix]
                    next_beams[prefix] = (_log_add(b, p_prefix + p), nb)
                    continue
                extended = prefix + (token,)
                b, nb = next_beams[extended]
                if token == last:
                    # A repeated label starts a new letter only after a blank ...
                    next_beams[extended] = (b, _log_add(nb, p_blank + p))
                    # ... otherwise it merges into the letter already emitted.
                    b, nb = next_beams[prefix]
                    next_beams[prefix] = (b, _log_add(nb, p_label + p))
                else:
                    next_beams[extended] = (b, _log_add(nb, p_prefix + p))
        ranked = sorted(next_beams.items(), key=lambda item: _log_add(*item[1]), reverse=True)
        beams = dict(ranked[:beam_width])

    best_prefix = max(beams, key=lambda prefix: _log_add(*beams[prefix]))
    return "".join(id_to_char.get(token, "") for token in best_prefix)

def predict_with_beam_search(batch, beam_width=10):
    """Transcribe one preprocessed example with beam search and print it beside the reference.

    Args:
        batch: a single row of a mapped split, with ``input_values`` (already normalized
            by ``feature_extractor`` inside ``preprocess``) and ``sentence``.
        beam_width: beam width passed to ``custom_beam_search_decoder``.

    Returns:
        The decoded transcript (also printed).
    """
    model.to(device)
    # Switch off dropout and any train-time masking, whatever mode training left the model in.
    model.eval()

    # Feed the stored input_values directly. They were normalized once in preprocess,
    # which is what the model was trained on, so running the extractor again is redundant.
    input_values = torch.as_tensor(batch["input_values"], dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_values).logits[0]

    # Decode with the blank id the CTC loss uses (config.pad_token_id).
    decoded_output = custom_beam_search_decoder(
        logits.cpu(), beam_width=beam_width, blank_id=model.config.pad_token_id
    )
    print("Prediction Decoded output:", decoded_output)
    print("Reference:", batch["sentence"])
    return decoded_output


In [None]:
# Beam-search transcripts for the first five test examples, separated by blank lines.
for example_index in range(5):
    predict_with_beam_search(dataset_test_map[example_index])
    print('')