<a href="https://colab.research.google.com/github/Signeemmanuel/Adapter-Based-Fine-Tuning/blob/master/Adapter_Based_Fine_Tuning_for_Low_Resource_Languages.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment Setup

We begin by installing the necessary libraries. The core of this solution relies on the Hugging Face ecosystem:
*   **Transformers:** For the Wav2Vec2 model architecture.
*   **Datasets:** To handle audio data efficiently.
*   **PEFT (Parameter-Efficient Fine-Tuning):** This is the crucial library. It allows us to inject "Adapters" (specifically LoRA) without unfreezing the main model.
*   **JiWER:** To calculate Word Error Rate (WER), the standard metric for ASR.

In [1]:
!pip install -q torch torchaudio torchcodec
!pip install -q transformers datasets peft evaluate jiwer librosa accelerate
!pip install -q huggingface_hub

!pip install -q soundfile librosa

# Install system libraries (Critical for .webm files)
!sudo apt-get update -y
!sudo apt-get install -y ffmpeg libsndfile1

print("Libraries installed successfully.")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m35.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m55.2 MB/s[0m eta [36m0:00:00[0m
Hit:1 https://cli.github.com/packages stable InRelease
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:4 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:6 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:8 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ Packages [83.6 kB]
Get:9 https://developer.download.nvidia.com/compute/

# Imports and Configuration

In [2]:
import os
import json
import re
import random
import torch
import librosa
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import tarfile
import glob
import shutil
from tqdm.auto import tqdm
import gc
from datasets import load_from_disk, concatenate_datasets


# Hugging Face
from huggingface_hub import snapshot_download
from datasets import load_dataset, Dataset, Audio
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, TaskType
import evaluate

In [3]:
# # @title Install TPU Dependencies (Only if using TPU Runtime)
# try:
#   import torch_xla
#   print("Torch XLA is already installed.")
# except ImportError:
#   print("Installing Torch XLA for TPU...")
#   !pip install -q torch_xla

In [4]:
# Setup Device (TPU/GPU/CPU) & Seeding
# 1. Intelligent Device Detection
try:
    # Try importing PyTorch XLA (for TPU)
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    device_type = "tpu"
    print(f"Using Device: TPU ({device})")
except ImportError:
    # Fallback to CUDA or CPU
    if torch.cuda.is_available():
        device = "cuda"
        device_type = "cuda"
    else:
        device = "cpu"
        device_type = "cpu"
    print(f"Using Device: {device}")

# 2. Set Seed for Reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if device_type == "cuda":
        torch.cuda.manual_seed_all(seed)
    elif device_type == "tpu":
        # XLA specific seeding (optional but good practice)
        import torch_xla.core.xla_model as xm
        xm.set_rng_state(seed)

set_seed(42)

Using Device: cuda


In [5]:
MAP_BATCH_SIZE = 32
WRITER_BATCH_SIZE = 32
NUM_PROC = 1               # Keep at 1 for Colab (Multiprocessing eats RAM)

PER_DEVICE_TRAIN_BATCH_SIZE = 2  # Small batch to prevent GPU OOM
PER_DEVICE_EVAL_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4  # Accumulate gradients to simulate larger batch (2 * 4 = 8 effective batch)

NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
OUTPUT_DIR = "./adapter_output"

print(f"Configuration Set:\nProcessing Batch: {MAP_BATCH_SIZE}\nTraining Batch: {PER_DEVICE_TRAIN_BATCH_SIZE} (Effective: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")

Configuration Set:
Processing Batch: 32
Training Batch: 2 (Effective: 8)


# Data Acquisition

We download the dataset directly from the Hugging Face Hub.

In [6]:
from huggingface_hub import snapshot_download
import os

local_dir = "./asr_data"
repo_id = "DigitalUmuganda/ASR_Fellowship_Challenge_Dataset"

# Define what you want to allow.
# We use * wildcards to match folder structures.
allow_patterns = [
    # --- TRAINING SUBSET (Adjust [0-4] to change how many you get) ---
    "train_tarred/**/audio_[0-4].tar.xz",   # Downloads audio_0, audio_1, ... audio_4
    "train_tarred/**/manifest_[0-4].json",  # Downloads manifest_0, ... manifest_4

    # --- KEEP ALL VALIDATION & TEST (Needed for scoring) ---
    "val_tarred/**",
    "test_tarred/**",

    # --- REQUIRED METADATA ---
    ".gitattributes"
]

# We explicitly ignore images to save massive space
ignore_patterns = ["**/image_*.tar.xz", "**/images/**"]

print(f"Downloading SUBSET of {repo_id}...")
snapshot_download(
    repo_id=repo_id,
    repo_type='dataset',
    local_dir=local_dir,
    allow_patterns=allow_patterns,
    ignore_patterns=ignore_patterns,
    resume_download=True # Useful if connection breaks
)
print("Subset download complete.")

Downloading SUBSET of DigitalUmuganda/ASR_Fellowship_Challenge_Dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

manifest_0.json: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

train_tarred/sharded_manifests_with_imag(…):   0%|          | 0.00/2.07G [00:00<?, ?B/s]

test_tarred/sharded_manifests_with_image(…):   0%|          | 0.00/522M [00:00<?, ?B/s]

train_tarred/sharded_manifests_with_imag(…):   0%|          | 0.00/2.07G [00:00<?, ?B/s]

train_tarred/sharded_manifests_with_imag(…):   0%|          | 0.00/2.07G [00:00<?, ?B/s]

manifest_0.json: 0.00B [00:00, ?B/s]

manifest_1.json: 0.00B [00:00, ?B/s]

train_tarred/sharded_manifests_with_imag(…):   0%|          | 0.00/2.07G [00:00<?, ?B/s]

manifest_2.json: 0.00B [00:00, ?B/s]

manifest_3.json: 0.00B [00:00, ?B/s]

train_tarred/sharded_manifests_with_imag(…):   0%|          | 0.00/2.06G [00:00<?, ?B/s]

manifest_4.json: 0.00B [00:00, ?B/s]

val_tarred/sharded_manifests_with_image/(…):   0%|          | 0.00/416M [00:00<?, ?B/s]

manifest_0.json: 0.00B [00:00, ?B/s]

Subset download complete.


# Data Extraction (Extract & Delete Strategy)

To handle the large dataset on Google Colab, we combine **subsampling** with a **strict memory management** strategy:

1.  **Extract & Delete:** We extract one `.tar.xz` shard at a time and **immediately delete the compressed file** before moving to the next. This keeps disk usage low.
2.  **Training Data (15 Shards):** We limit the training extraction to **15 shards**. This provides a sufficiently large dataset (approx. 35k-40k samples) to train the Adapter effectively without hitting storage limits or causing timeouts.
3.  **Validation & Test Data (Full):** We extract **all** available shards for the Validation and Test sets to ensure our evaluation metrics (WER) are accurate and comparable to the full benchmark.

In [7]:
# Define the root directory
DATA_ROOT = "./asr_data"

def get_shard_id(filename):
    """Extracts the number from 'audio_15.tar.xz' -> 15"""
    match = re.search(r"audio_(\d+)", filename)
    return int(match.group(1)) if match else None

def unpack_and_load_split(split_name, max_shards=None):
    source_path = os.path.join(DATA_ROOT, split_name)
    extract_dir = os.path.abspath(f"./extracted_{split_name}")

    # 1. Cleanup previous extraction
    if os.path.exists(extract_dir):
        print(f"[{split_name}] Cleaning up previous extraction directory...")
        shutil.rmtree(extract_dir)
    os.makedirs(extract_dir, exist_ok=True)

    # 2. Find and Sort Shards
    # We specifically look for audio_*.tar.xz
    tar_files = sorted(glob.glob(f"{source_path}/**/audio_*.tar.xz", recursive=True))

    if not tar_files:
        print(f"[{split_name}] No tar files found.")
        return None

    # 3. Apply Limit (Max Shards)
    total_shards = len(tar_files)
    if max_shards is not None and max_shards < total_shards:
        print(f"[{split_name}] Limit applied: Using {max_shards} of {total_shards} shards.")
        tar_files = tar_files[:max_shards]
    else:
        print(f"[{split_name}] Using all {total_shards} shards.")

    # 4. Extract & Track Processed IDs
    processed_ids = set()

    for tar_path in tqdm(tar_files, desc=f"Extracting {split_name}"):
        try:
            # Get ID before deleting
            shard_id = get_shard_id(os.path.basename(tar_path))
            if shard_id is not None:
                processed_ids.add(shard_id)

            # Extract
            with tarfile.open(tar_path, "r:xz") as tar:
                tar.extractall(path=extract_dir)

            # Delete to save space
            os.remove(tar_path)

        except Exception as e:
            print(f"Error on {tar_path}: {e}")

    # 5. Load ONLY Matching Manifests
    # We look for manifest_X.json where X is in processed_ids
    print(f"[{split_name}] Loading manifests for shards: {sorted(list(processed_ids))}")

    df_list = []

    # Helper to find manifest file for a specific ID
    def find_manifest(sid):
        # Search in Source (if not inside tar)
        candidates = glob.glob(f"{source_path}/**/manifest_{sid}.json", recursive=True)
        # Search in Extracted (if inside tar)
        candidates += glob.glob(f"{extract_dir}/**/manifest_{sid}.json", recursive=True)
        return candidates[0] if candidates else None

    for sid in processed_ids:
        m_path = find_manifest(sid)
        if m_path:
            try:
                # Read JSON (Try Lines=True first as it's common for ASR)
                try:
                    df_temp = pd.read_json(m_path, lines=True)
                except ValueError:
                    df_temp = pd.read_json(m_path)
                df_list.append(df_temp)
            except Exception as e:
                print(f"Failed to read manifest_{sid}.json: {e}")

    if not df_list:
        print("No valid manifests loaded.")
        return None

    full_df = pd.concat(df_list, ignore_index=True)

    # 6. Optimized Path Construction (Vectorized)
    # Instead of checking os.path.exists for every file (Slow),
    # we construct the path using string operations (Fast).
    # We search for the 'audio_shards' folder once.

    # Find where the audio files actually are inside the extract dir
    audio_roots = glob.glob(f"{extract_dir}/**/audio_shards", recursive=True)
    if audio_roots:
        # If 'audio_shards' folder exists, files are likely inside it
        base_audio_dir = audio_roots[0]
    else:
        # Otherwise they are likely in the root of extract_dir
        base_audio_dir = extract_dir

    print(f"[{split_name}] Assuming audio files are in: {base_audio_dir}")

    # Vectorized path concatenation (Instant)
    # We assume the filename in manifest corresponds to the file on disk
    # Use apply only to handle potential sub-path differences if necessary,
    # but straight string concat is fastest.

    def quick_path_fix(filename):
        # Handle case where manifest says "audio/file.webm" but we have "file.webm"
        return os.path.join(base_audio_dir, os.path.basename(filename))

    full_df['audio'] = full_df['audio_filepath'].apply(quick_path_fix)

    print(f"[{split_name}] Loaded {len(full_df)} samples.")

    # 7. Create Dataset
    dataset = Dataset.from_pandas(full_df)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    return dataset


In [8]:
# Train: Use 2 shards (Subsampled for speed/space)
train_dataset = unpack_and_load_split("train_tarred", max_shards=2)

# Validation: Use ALL shards (To get accurate WER)
eval_dataset = unpack_and_load_split("val_tarred", max_shards=None)

# Test: Use ALL shards (For final submission)
test_dataset = unpack_and_load_split("test_tarred", max_shards=None)

print("\n=== Data Loading Complete ===")
print(f"Train: {len(train_dataset) if train_dataset else 0}")
print(f"Eval:  {len(eval_dataset) if eval_dataset else 0}")
print(f"Test:  {len(test_dataset) if test_dataset else 0}")

[train_tarred] Limit applied: Using 2 of 5 shards.


Extracting train_tarred:   0%|          | 0/2 [00:00<?, ?it/s]

  tar.extractall(path=extract_dir)


[train_tarred] Loading manifests for shards: [0, 1]
[train_tarred] Assuming audio files are in: /content/extracted_train_tarred
[train_tarred] Loaded 13064 samples.
[val_tarred] Using all 1 shards.


Extracting val_tarred:   0%|          | 0/1 [00:00<?, ?it/s]

  tar.extractall(path=extract_dir)


[val_tarred] Loading manifests for shards: [0]
[val_tarred] Assuming audio files are in: /content/extracted_val_tarred
[val_tarred] Loaded 1617 samples.
[test_tarred] Using all 1 shards.


Extracting test_tarred:   0%|          | 0/1 [00:00<?, ?it/s]

  tar.extractall(path=extract_dir)


[test_tarred] Loading manifests for shards: [0]
[test_tarred] Assuming audio files are in: /content/extracted_test_tarred
[test_tarred] Loaded 1569 samples.

=== Data Loading Complete ===
Train: 13064
Eval:  1617
Test:  1569


In [9]:
dir_to_delete = DATA_ROOT

if os.path.exists(dir_to_delete):
    try:
        shutil.rmtree(dir_to_delete)
        print(f"Successfully deleted: {dir_to_delete}")
    except OSError as e:
        print(f"Error deleting folder: {e}")
else:
    print(f"Directory not found: {dir_to_delete}")

Successfully deleted: ./asr_data


# Text Cleaning

ASR models trained with CTC loss generally require a simple character set.
1.  **Normalization:** We convert all text to lowercase.
2.  **Cleaning:** We remove punctuation (periods, commas, etc.) because we want the model to focus on the phonetic sounds, not grammatical structure.

In [10]:
def clean_text(batch):
    # Handle Test Set (where text is None)
    if batch["text"] is None:
        batch["text"] = ""
        return batch

    # Safety Lowercase
    batch["text"] = batch["text"].lower()

    return batch

print("Running safety check on text...")
train_dataset = train_dataset.map(clean_text)
eval_dataset = eval_dataset.map(clean_text)
test_dataset = test_dataset.map(clean_text)
print("Text verified.")

Running safety check on text...


Map:   0%|          | 0/13064 [00:00<?, ? examples/s]

Map:   0%|          | 0/1617 [00:00<?, ? examples/s]

Map:   0%|          | 0/1569 [00:00<?, ? examples/s]

Text verified.


# 6. Create Vocabulary

The model doesn't output words directly; it outputs characters. We must build a vocabulary (JSON) containing every unique character found in the dataset.

We also add special tokens. While CTC mainly uses the first three, the `Wav2Vec2Processor` requires the others to maintain compatibility with the library structure:

*   **`|`**: Represents a space between words (Visual delimiter).
*   **`[UNK]`**: Unknown character (if the model encounters a char not in the vocab).
*   **`[PAD]`**: Padding token (used to make audio batches the same length).
*   **`[CLS]`**: Classification start token (Standard in Hugging Face models).
*   **`[SEP]`**: Separator token (Standard in Hugging Face models).
*   **`[MASK]`**: Mask token (Used during pre-training, required by the config).

In [11]:
def extract_all_chars(batch):
    all_text = " ".join(batch["text"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

# Extract all chars
vocab_train = train_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=train_dataset.column_names)
vocab_eval = eval_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=train_dataset.column_names)

# Merge and enumerate
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_eval["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

# Add special tokens
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
vocab_dict["[CLS]"] = len(vocab_dict)
vocab_dict["[SEP]"] = len(vocab_dict)
vocab_dict["[MASK]"] = len(vocab_dict)

# Save to file
with open("vocab.json", "w") as vocab_file:
    json.dump(vocab_dict, vocab_file)

print(f"Vocabulary created with {len(vocab_dict)} characters.")

Map:   0%|          | 0/13064 [00:00<?, ? examples/s]

Map:   0%|          | 0/1617 [00:00<?, ? examples/s]

Vocabulary created with 33 characters.


# 7. Initialize Processor

The `Wav2Vec2Processor` wraps two things:
1.  **Feature Extractor:** Converts raw audio (arrays) into the numerical inputs the model understands.
2.  **Tokenizer:** Converts text strings into the integer IDs from our vocabulary.

In [12]:
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True
)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
print("Processor initialized.")

Processor initialized.


# 8. Feature Extraction (RAM-Safe "Divide and Conquer")

**The Problem:** Audio processing is extremely memory-intensive. A compressed dataset (like .tar.xz) expands by 20x-50x when loaded into RAM as float arrays. On Google Colab, standard processing causes a "Memory Leak" where RAM usage climbs linearly until the session crashes.

**The Solution:** We implement a **Chunked Processing Strategy** (also known as the "Nuclear Option" for low-RAM environments).

Instead of processing the whole dataset at once, we:
1.  **Filter:** Remove empty/corrupt rows first.
2.  **Shard:** Split the dataset into small pieces (e.g., 20 chunks for training).
3.  **Process & Flush:**
    *   Load *one* small chunk into RAM.
    *   Convert Audio to Numbers (`input_values`) and Text to IDs (`labels`).
    *   **Save immediately to disk** (`save_to_disk`).
    *   **Nuke RAM:** Manually delete variables and run Garbage Collection (`gc.collect()`) to reset memory usage to zero.
4.  **Reassemble:** Finally, we load the processed chunks from the disk using pointers (which uses almost no RAM) and concatenate them.

**Column Management:**
*   **Train/Eval:** We drop the original `audio` and `text` columns to save space.
*   **Test:** We **keep** metadata columns (like filenames) because we need them to generate the submission file later.

In [13]:
# @title 8. Feature Extraction (On-the-Fly / Zero RAM Spike)

# 1. Filter empty text (Safety)
train_dataset = train_dataset.filter(lambda x: x["text"] is not None and len(x["text"].strip()) > 0)
eval_dataset = eval_dataset.filter(lambda x: x["text"] is not None and len(x["text"].strip()) > 0)

# 2. Define the "On-the-Fly" Transform
def prepare_dataset_on_the_fly(batch):
    # This runs ONLY when the GPU asks for data.
    # Batch is a dictionary of lists: {'audio': [audio1, audio2], 'text': ['txt1', 'txt2']}

    # A. Extract Audio Arrays
    audio_arrays = [x["array"] for x in batch["audio"]]

    # B. Process Audio (Input Features)
    inputs = processor(
        audio_arrays,
        sampling_rate=16000,
        padding=True  # Pad to longest in this specific mini-batch
    )
    batch["input_values"] = inputs.input_values

    # C. Process Labels (if text exists)
    if "text" in batch:
        # --- FIX: Removed "with as_target_processor():" ---
        # Use text= argument explicitly
        batch["labels"] = processor(text=batch["text"]).input_ids

    return batch

# 3. Apply the Transform (Instantaneous)
# NOTE: We do NOT remove columns here. The transform handles selecting the right data.
print("Setting up on-the-fly processing...")

train_dataset.set_transform(prepare_dataset_on_the_fly)
eval_dataset.set_transform(prepare_dataset_on_the_fly)

# For Test, we need a slightly different transform that doesn't crash on missing text
def prepare_test_on_the_fly(batch):
    audio_arrays = [x["array"] for x in batch["audio"]]
    inputs = processor(audio_arrays, sampling_rate=16000, padding=True)
    batch["input_values"] = inputs.input_values
    return batch

test_dataset.set_transform(prepare_test_on_the_fly)

print("✓ Transforms set. RAM usage is flat. Ready to train.")

Filter:   0%|          | 0/13064 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1617 [00:00<?, ? examples/s]

Setting up on-the-fly processing...
✓ Transforms set. RAM usage is flat. Ready to train.


# 9. Load Base Model & Freeze

1.  We load `facebook/wav2vec2-large-xlsr-53`. This is a massive model pre-trained on 53 languages.
2.  **FREEZING:** We loop through `model.parameters()` and set `requires_grad = False`. This ensures we respect the challenge constraints and do not modify the base model weights.

In [14]:
# Load Base Model
model_id = "facebook/wav2vec2-large-xlsr-53"
print(f"Loading Base Model: {model_id}...")

model = Wav2Vec2ForCTC.from_pretrained(
    model_id,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True # Needed because vocab size differs from pre-training
)

# FREEZE BASE MODEL
for param in model.parameters():
    param.requires_grad = False

print("Base Model loaded and FROZEN.")

Loading Base Model: facebook/wav2vec2-large-xlsr-53...


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Base Model loaded and FROZEN.


# 10. Inject Adapters (LoRA)

We use **LoRA (Low-Rank Adaptation)**.

Instead of retraining the whole model, LoRA inserts tiny matrices into the attention layers (`q_proj`, `v_proj`). We only train these tiny matrices and the final classification layer (the `lm_head`), which maps features to our specific Kinyarwanda vocabulary.

**Why this matters:**
*   Full Fine-tuning: ~315 Million params.
*   Adapter Fine-tuning: ~1-2 Million params.

In [15]:
# @title 10. Adapter Configuration (Fixed for Wav2Vec2)
from peft import LoraConfig, get_peft_model

# --- THE FIX: Monkey-Patch the Model ---
# We manually tell the model: "If anyone asks for input embeddings, give them Nothing."
# This prevents the NotImplementedError during checkpoint saving.
model.get_input_embeddings = lambda: None

# 1. Configure LoRA
peft_config = LoraConfig(
    inference_mode=False,
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    # We explicitly ask PEFT to manage the output head (lm_head).
    # This ensures the new Kinyarwanda vocabulary weights are saved with the adapter.
    modules_to_save=["lm_head"]
)

# 2. Apply Adapter
model = get_peft_model(model, peft_config)

# 3. Verify
print("✓ Adapters injected.")
print("✓ Wav2Vec2 Input Embedding Patch applied (prevents saving crash).")
print(f"--- DATA FOR REPORT.PDF ---")
model.print_trainable_parameters()

✓ Adapters injected.
✓ Wav2Vec2 Input Embedding Patch applied (prevents saving crash).
--- DATA FOR REPORT.PDF ---
trainable params: 3,181,603 || all params: 318,656,198 || trainable%: 0.9984


# 11. Helper Functions

*   **DataCollator:** Since audio files have different lengths, we must pad them so they fit into a rectangular batch.
*   **Metric:** We load the WER metric to evaluate performance during training.

In [16]:
# @title 11. Helper Functions (Fixed Warnings)
@dataclass
class DataCollatorCTCWithPadding:
    processor: Any
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        # Only pad labels if they exist
        label_features = [{"input_ids": feature["labels"]} for feature in features if "labels" in feature]

        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        if label_features:
            # --- FIX: Removed "with as_target_processor():" ---
            # Directly use the tokenizer inside the processor to pad labels
            labels_batch = self.processor.tokenizer.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

            # Mask padding with -100 to ignore in loss calculation
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
            batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor)
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

print("✓ Training pipeline ready (Warnings Fixed).")

Downloading builder script: 0.00B [00:00, ?B/s]



# 12. Training

We configure the `Trainer`.
*   `learning_rate=1e-3`: LoRA usually allows for a higher learning rate than full fine-tuning.
*   `fp16=True`: Uses Mixed Precision to speed up training on the GPU.
*   `num_train_epochs`: 15 epochs should be sufficient for a small dataset.

In [17]:
# @title 12. Training Execution (On-the-Fly Optimized)
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./adapter_output",
    group_by_length=True,

    # --- BATCH SIZES ---
    per_device_train_batch_size=2, # Keep small
    gradient_accumulation_steps=4,

    # --- OPTIMIZATION ---
    gradient_checkpointing=True,
    fp16=True,

    # --- ON-THE-FLY SPECIFIC SETTINGS ---
    # Since we process data during training, we allow some workers
    # BUT separate them from the main process to prevent hangs
    dataloader_num_workers=2,
    remove_unused_columns=False, # CRITICAL: Required when using set_transform

    num_train_epochs=3, # Reduce to 3 epochs just to get a result quickly!
    save_steps=100,
    eval_steps=100,
    logging_steps=50,
    learning_rate=1e-3,
    save_total_limit=2,
    report_to="none"
)


In [18]:
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # --- FIX: Changed 'tokenizer' to 'processing_class' ---
    processing_class=processor.feature_extractor,
)

print("Starting training...")
trainer.train()

Starting training...


  torch._C._get_cudnn_allow_tf32(),


Step,Training Loss
50,6.918
100,2.9178
150,2.895
200,2.8941
250,2.9136
300,2.9263
350,2.9031
400,2.8959
450,2.8999
500,2.8878




TrainOutput(global_step=4893, training_loss=2.9322018400184855, metrics={'train_runtime': 13823.9898, 'train_samples_per_second': 2.831, 'train_steps_per_second': 0.354, 'total_flos': 2.3841243212835746e+19, 'train_loss': 2.9322018400184855, 'epoch': 3.0})

# 13. Generate Submission Files

Finally, we need to satisfy the submission checklist.
1.  **Save the Adapter Weights:** We save *only* the small adapter files.
2.  **Generate Transcriptions:** We run inference on the test set using our fine-tuned model.
3.  **Baseline Comparison:** To prove our adapter works, we (temporarily) disable the adapter and run inference on the base model to show how bad the WER is without it.

In [20]:
# @title 13. Generate Submission Artifacts (Fixed)
# Create submission folder
os.makedirs("submission", exist_ok=True)

# --- 1. Save Adapter Weights ---
print("Saving Adapter Weights...")
model.save_pretrained("./submission/adapter_weights")

# --- Helper Function for Inference ---
def generate_transcriptions(dataset, model_to_use, description):
    model_to_use.eval()
    model_to_use.to(device)

    filenames = []
    predictions = []

    print(f"Running inference: {description}...")

    # We iterate manually instead of using .map() to avoid column errors with set_transform
    for i in tqdm(range(len(dataset))):
        # Accessing dataset[i] triggers the on-the-fly transform we wrote in Step 8
        item = dataset[i]

        # Get input values (generated on the fly)
        input_values = torch.tensor(item["input_values"]).unsqueeze(0).to(device)

        # Get filename for the report
        # Handle case where audio_filepath might be inside a list or direct string
        fpath = item.get("audio_filepath", f"file_{i}")

        # Forward pass
        with torch.no_grad():
            logits = model_to_use(input_values).logits

        # Decode
        pred_ids = torch.argmax(logits, dim=-1)
        pred_str = processor.batch_decode(pred_ids)[0]

        filenames.append(fpath)
        predictions.append(pred_str)

    return filenames, predictions

# --- 2. Generate Fine-Tuned Transcriptions ---
files, ft_preds = generate_transcriptions(test_dataset, model, "Fine-Tuned Model")

with open("submission/finetuned_transcriptions.txt", "w") as f:
    for path, pred in zip(files, ft_preds):
        f.write(f"File: {path}\nPrediction: {pred}\n\n")

# --- 3. Generate Base Model Transcriptions ---
print("Disabling adapters for Base Model inference...")
with model.disable_adapter():
    _, base_preds = generate_transcriptions(test_dataset, model, "Base Model")

with open("submission/base_transcriptions.txt", "w") as f:
    for path, pred in zip(files, base_preds):
        f.write(f"File: {path}\nPrediction: {pred}\n\n")

print("\n✓ Submission artifacts generated in ./submission folder.")
print("  - adapter_weights/")
print("  - finetuned_transcriptions.txt")
print("  - base_transcriptions.txt")

# NOTE: We cannot calculate Test WER because the Test set has no labels (text is None).
# If you want WER for your report, calculate it on 'eval_dataset' instead.

Saving Adapter Weights...
Running inference: Fine-Tuned Model...


  0%|          | 0/1569 [00:00<?, ?it/s]

Disabling adapters for Base Model inference...
Running inference: Base Model...


  0%|          | 0/1569 [00:00<?, ?it/s]


✓ Submission artifacts generated in ./submission folder.
  - adapter_weights/
  - finetuned_transcriptions.txt
  - base_transcriptions.txt


In [21]:
import shutil
# Zip the submission folder which contains your weights and text files
shutil.make_archive('submission_artifacts', 'zip', './submission')
print("Download 'submission_artifacts.zip' from the file browser!")

Download 'submission_artifacts.zip' from the file browser!
