In [None]:
from datasets import Dataset, concatenate_datasets, load_dataset, Audio
import pandas as pd
import os
import soundfile as sf
import librosa
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, Audio as IPyAudio
import ipywidgets as widgets


"""Files need to be downloaded manually from the corresponding sources,
    for further informations check the paper, which point to the sources."""


### 1. Load MAILABS (x-lb) MBARNIG ###
def load_mailabs(root):
    """Load the M-AILABS-style corpus stored under ``root``.

    Expected layout: ``root/<gender>/<speaker>/metadata.csv`` with the audio in
    ``root/<gender>/<speaker>/wavs/<filename>.wav``. Gender folders that do not
    exist and speaker entries without a ``metadata.csv`` are skipped.

    Args:
        root: Directory that contains the ``male`` / ``female`` sub-folders.

    Returns:
        A ``Dataset`` with the columns ``audio`` (dict holding ``array`` and
        ``sampling_rate``), ``transcription`` and ``source`` (always
        ``"Mailabs"``).
    """
    import csv  # local import: only needed for the quoting constant below

    entries = []
    for gender in ["male", "female"]:
        gender_path = os.path.join(root, gender)
        if not os.path.isdir(gender_path):
            continue
        # Sorted so the sample order does not depend on the file system.
        for speaker in sorted(os.listdir(gender_path)):
            speaker_path = os.path.join(gender_path, speaker)
            wav_dir = os.path.join(speaker_path, "wavs")
            metadata_file = os.path.join(speaker_path, "metadata.csv")
            if not os.path.exists(metadata_file):
                continue
            df = pd.read_csv(
                metadata_file,
                sep="|",
                header=None,
                names=["filename", "transcription", "_"],
                # The metadata is plain pipe-separated text. With the default
                # quoting a transcript that starts with '"' would swallow the
                # following rows until the next quote character.
                quoting=csv.QUOTE_NONE,
                # Keep every field as text; do not turn transcripts such as
                # "NA" or "null" into float NaN.
                dtype=str,
                keep_default_na=False,
            )
            for filename, transcription in zip(df["filename"], df["transcription"]):
                wav_path = os.path.join(wav_dir, filename + ".wav")
                audio, sr = sf.read(wav_path)
                if audio.ndim > 1:
                    # Multi-channel files come back as (frames, channels);
                    # average the channels to get a 1-D waveform.
                    audio = audio.mean(axis=1)
                audio = audio.astype(np.float32)
                entries.append({
                    "audio": {"array": audio, "sampling_rate": sr},
                    "transcription": transcription,
                    "source": "Mailabs"
                })
    return Dataset.from_list(entries)

# Location of the manually downloaded Mailabs "by_book" folder.
mailabs_root = "./lb-de-fr-en-pt-12800-TTS-CORPUS/mailabs/x-lb/by_book"
mailabs_ds = load_mailabs(mailabs_root)
print(f" Loaded {len(mailabs_ds)} samples from Mailabs")

### 2. Load FLEURS ###
# The train split is decoded and every sample is re-packed into the same plain
# dict layout ("audio" / "transcription" / "source") that the other loaders emit.
raw_fleurs = load_dataset("google/fleurs", "lb_lu", split="train").cast_column(
    "audio", Audio(decode=True)
)

fleurs_entries = [
    {
        "audio": {
            "array": np.array(sample["audio"]["array"], dtype=np.float32),
            "sampling_rate": sample["audio"]["sampling_rate"],
        },
        "transcription": sample["raw_transcription"],
        "source": "FLEURS",
    }
    for sample in raw_fleurs
]

fleurs_ds = Dataset.from_list(fleurs_entries)
print(f" Loaded {len(fleurs_ds)} samples from FLEURS")

### 3. Load RTL ###
def load_rtl(path):
    """Load the RTL corpus stored under ``path``.

    For each of the ``dev`` and ``test`` splits, ``<path>/<split>/<split>.tsv``
    is read as a tab-separated table. Every row's ``filename`` is loaded from
    the same split folder at its native sampling rate.

    Args:
        path: Root folder containing the ``dev`` and ``test`` sub-folders.

    Returns:
        A ``Dataset`` with the columns ``audio`` (dict holding ``array`` and
        ``sampling_rate``), ``transcription`` and ``source`` (always ``"RTL"``).
    """
    records = []
    for split_name in ("dev", "test"):
        split_dir = os.path.join(path, split_name)
        table = pd.read_csv(os.path.join(split_dir, f"{split_name}.tsv"), sep="\t")
        for _, record in table.iterrows():
            # sr=None keeps the file's own sampling rate.
            waveform, rate = librosa.load(
                os.path.join(split_dir, record["filename"]), sr=None
            )
            records.append(
                {
                    "audio": {
                        "array": waveform.astype(np.float32),
                        "sampling_rate": rate,
                    },
                    "transcription": record["transcription"],
                    "source": "RTL",
                }
            )
    return Dataset.from_list(records)

rtl_root = "./rtl"
rtl_ds = load_rtl(rtl_root)
print(f" Loaded {len(rtl_ds)} samples from RTL")

### 4. Combine
# Order of the parts: FLEURS, then RTL, then Mailabs.
dataset_parts = [fleurs_ds, rtl_ds, mailabs_ds]
full_dataset = concatenate_datasets(dataset_parts)
total_samples = len(full_dataset)
print(f" Final combined dataset has {total_samples} samples.")




In [None]:
"""print(f" Loaded {len(full_dataset)} samples from combined dataset.")

# Visualize and play one sample from full_dataset
def show_example(index):
    sample = full_dataset[index]

    print(f"\n Sample {index}")
    print(f" Source: {sample['source']}")
    print(f" Transcript:\n{sample['transcription']}")

    audio = sample["audio"]["array"]
    sr = sample["audio"]["sampling_rate"]

    # Plot waveform
    plt.figure(figsize=(12, 2.5))
    plt.plot(np.linspace(0, len(audio) / sr, num=len(audio)), audio)
    plt.title("Waveform")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Audio playback
    display(IPyAudio(audio, rate=sr))

# Interactive slider
index_widget = widgets.IntSlider(
    value=0,
    min=0,
    max=len(full_dataset) - 1,
    step=1,
    description='Sample:',
    continuous_update=False
)

widgets.interact(show_example, index=index_widget)"""

In [None]:
from datasets import load_dataset, Audio, concatenate_datasets
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import torch
from jiwer import wer as jiwer_wer
import wandb
import os
from IPython.display import Audio as PlayAudio, display #remove later


# CUDA allocator option; it is set here, before the model is loaded below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Configuration (run_name, output_dir and the wandb project are left blank here
# and have to be filled in before running).
learning_rate = 3e-5
run_name = ""
output_dir = ""
wandb.init(project="", name=run_name)

# Load datasets: training uses the combined corpus built in the first cell,
# evaluation uses the FLEURS validation split.
train_dataset = full_dataset
eval_raw = load_dataset("google/fleurs", "lb_lu", split="validation").cast_column(
    "audio", Audio(decode=True)
)
eval_dataset = eval_raw

# Load model and processor
model_name = "openai/whisper-medium"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model.gradient_checkpointing_enable()

# Tokenizer setup: same language/task for the label prefix tokens and for the
# decoder prompt used during generation.
processor.tokenizer.set_prefix_tokens(language="luxembourgish", task="transcribe")
forced_decoder_ids = processor.get_decoder_prompt_ids(language="luxembourgish", task="transcribe")

# Freeze first conv layers of the encoder
for frozen_layer in (model.model.encoder.conv1, model.model.encoder.conv2):
    for param in frozen_layer.parameters():
        param.requires_grad = False

# Preprocessing
def prepare_example(example):
    """Convert one raw sample into Whisper model inputs.

    The waveform is resampled whenever its stored ``sampling_rate`` differs from
    the rate the feature extractor expects. The combined training set mixes
    sources loaded at their native rates, so the features must not be computed
    from audio that is merely labelled with the expected rate.

    Args:
        example: Sample dict with an ``audio`` entry (``array`` and
            ``sampling_rate``) and a ``raw_transcription`` or ``transcription``
            text field.

    Returns:
        The same dict with ``input_features`` (log-mel features of the sample)
        and ``labels`` (token ids of the transcription, special tokens
        included) added.
    """
    audio = example["audio"]
    target_sr = processor.feature_extractor.sampling_rate
    waveform = np.asarray(audio["array"], dtype=np.float32)
    source_sr = audio["sampling_rate"]
    if source_sr != target_sr:
        waveform = librosa.resample(waveform, orig_sr=source_sr, target_sr=target_sr)
    example["input_features"] = processor(waveform, sampling_rate=target_sr).input_features[0]
    # FLEURS rows carry "raw_transcription", the combined dataset "transcription".
    transcription = example.get("raw_transcription", example.get("transcription", ""))
    example["labels"] = processor.tokenizer(transcription, add_special_tokens=True).input_ids  # also adds EOS token
    return example

# Drop every original column so that only the keys added by prepare_example remain.
train_columns = train_dataset.column_names
train_dataset = train_dataset.map(prepare_example, remove_columns=train_columns)
eval_columns = eval_dataset.column_names
eval_dataset = eval_dataset.map(prepare_example, remove_columns=eval_columns)


# Data collator
def data_collator(batch):
    """Collate preprocessed samples into the tensors the trainer consumes.

    Args:
        batch: List of dicts, each with ``input_features`` and ``labels``
            (token id list) as produced by ``prepare_example``.

    Returns:
        Dict with ``input_features`` (all samples stacked into one tensor) and
        ``labels`` (token ids padded to the longest sequence, padding positions
        set to -100).
    """
    input_features = torch.tensor([b["input_features"] for b in batch])

    labels_batch = processor.tokenizer.pad(
        {"input_ids": [b["labels"] for b in batch]},
        return_tensors="pt",
        padding=True,
    )
    # Mask by attention mask, not by token id: the Whisper tokenizer uses the
    # same id for padding and EOS, so comparing against pad_token_id would also
    # blank out the real EOS of every sequence.
    labels = labels_batch.input_ids.masked_fill(labels_batch.attention_mask.ne(1), -100)

    # The model prepends the decoder start token itself when it shifts the
    # labels to build the decoder inputs; drop it here so it is not doubled.
    decoder_start_token_id = model.config.decoder_start_token_id
    if labels.size(1) > 0 and (labels[:, 0] == decoder_start_token_id).all().item():
        labels = labels[:, 1:]

    return {"input_features": input_features, "labels": labels}


# Metrics
normalizer = BasicTextNormalizer()


def normalize_text(text):
    """Strip surrounding whitespace from ``text`` and pass it through ``normalizer``."""
    stripped = text.strip()
    return normalizer(stripped)

# A few validation samples that are transcribed and printed at every
# evaluation, so the progress of the model can be followed on the same audio.
fixed_indices = [22, 56, 7]
fixed_examples = [eval_raw[idx] for idx in fixed_indices]
fixed_inputs = []
for fixed_example in fixed_examples:
    encoded = processor(
        fixed_example["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    )
    reference_text = fixed_example.get(
        "raw_transcription", fixed_example.get("transcription", "")
    )
    fixed_inputs.append(
        {"input_features": encoded.input_features[0], "reference": reference_text}
    )


def compute_metrics(pred):
    """Compute the word error rate for one evaluation pass.

    Predictions and labels are decoded, normalised with ``normalize_text`` and
    scored with ``jiwer_wer``. The fixed validation samples in ``fixed_inputs``
    are then transcribed with ``model.generate`` and printed next to their
    references. The WER is also logged to wandb.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids``. ``label_ids`` is modified in place: every -100
            is replaced by the tokenizer's pad token id.

    Returns:
        ``{"wer": <float>}``.
    """
    tokenizer = processor.tokenizer

    hypotheses = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)

    # -100 is not a valid token id; put the pad token back before decoding.
    references_ids = pred.label_ids
    references_ids[references_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(references_ids, skip_special_tokens=True)

    hypotheses_norm = [normalize_text(hyp) for hyp in hypotheses]
    references_norm = [normalize_text(ref) for ref in references]
    wer = float(jiwer_wer(references_norm, hypotheses_norm))
    print(f"\n=== WER: {wer:.4f} ===")

    print("\n--- Fixed sample predictions ---")
    model.eval()
    with torch.no_grad():
        for sample_no, fixed in enumerate(fixed_inputs):
            # Add the batch dimension and move the features to the model's device.
            batch_features = fixed["input_features"].unsqueeze(0).to(model.device)
            generated = model.generate(
                batch_features,
                max_length=128,
                early_stopping=True,
                eos_token_id=tokenizer.eos_token_id,
                forced_decoder_ids=forced_decoder_ids,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                num_beams=2,
            )
            transcript = tokenizer.decode(generated[0], skip_special_tokens=True)
            print(f"[Sample {sample_no}]")
            print(f"Reference: {fixed['reference']}")
            print(f"Prediction: {transcript}")

    metrics = {"wer": wer}
    wandb.log({"wer": wer})
    return metrics

# Training arguments
training_args = Seq2SeqTrainingArguments(
    # Where results go: local folder, Hub repository and wandb run.
    output_dir=output_dir,
    push_to_hub=True,
    hub_model_id="",
    report_to="wandb",
    run_name=run_name,
    # Optimisation: one sample per step, accumulated over 8 steps.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=learning_rate,
    lr_scheduler_type="cosine",
    warmup_steps=500,
    num_train_epochs=5,
    fp16=True,
    # Logging, checkpointing and evaluation cadence (in optimizer steps).
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=100,
    # Evaluate on generated text so compute_metrics receives token ids.
    predict_with_generate=True,
    generation_max_length=128,
)

# Trainer
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "tokenizer": processor.tokenizer,
    "data_collator": data_collator,
    "compute_metrics": compute_metrics,
}
trainer = Seq2SeqTrainer(**trainer_kwargs)

# Train
trainer.train()


In [None]:
# Write the processor into the training output directory, then upload it.
processor_dir = training_args.output_dir
processor.save_pretrained(processor_dir)

hub_repo_id = "Tun-Wellens/whisper-medium-final-test0"
processor.push_to_hub(hub_repo_id)