# Note: Hardware Constraints

**Note to Reviewers:**
This notebook implements a complete, memory-optimized pipeline for Adapter-based fine-tuning using Low-Rank Adaptation (LoRA). The implementation includes:
1.  **Sharded Data Handling:** A custom pipeline to process `.tar.xz` shards sequentially to minimize disk usage.
2.  **Memory Management:** A "Divide and Conquer" feature extraction strategy that processes data in small chunks, saves each chunk to disk, and forces garbage collection to keep peak RAM usage bounded.
3.  **Adapter Architecture:** Implementation of LoRA using Hugging Face `PEFT` to freeze the base model and train only small low-rank adapter matrices on the attention projections, plus the output layer.

**Hardware Limitation:**
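Due to the strict resource limits of the Google Colab Free Tier (12GB system RAM, limited disk) and the high memory cost of decoding and processing audio features (which expand by roughly 20x-50x when loaded as float arrays), the full training process encounters Out-Of-Memory (OOM) errors during the feature extraction phase, even with aggressive batch size reduction (16 rows per batch). The captured run below stops in the same phase with a disk error instead (`OSError: [Errno 28] No space left on device` while saving chunk 8 of 20 of the training features).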

The code is intended to be reproducible on an environment with adequate RAM and disk space (e.g., 32GB+ RAM or a dedicated workstation); the cells after feature extraction have no recorded output here.
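The 20x-50x expansion figure can be checked with simple arithmetic. A minimal sketch, assuming 16 kHz mono audio decoded to float32; the bitrate of the dataset's compressed `.webm` audio is not stated, so the bitrates used here are illustrative assumptions, and `decoded_bytes` / `expansion_factor` are hypothetical helpers:

```python
# Assumed parameters: 16 kHz mono audio (the rate used throughout this notebook),
# stored as 4-byte float32 samples once decoded.
SAMPLE_RATE = 16_000
BYTES_PER_FLOAT32 = 4


def decoded_bytes(seconds, sample_rate=SAMPLE_RATE):
    """Bytes needed to hold `seconds` of mono audio as a float32 array."""
    return int(seconds * sample_rate * BYTES_PER_FLOAT32)


def expansion_factor(compressed_kbps, sample_rate=SAMPLE_RATE):
    """Ratio of float32 PCM size to a compressed stream at `compressed_kbps` kilobits per second."""
    decoded_bits_per_second = sample_rate * BYTES_PER_FLOAT32 * 8
    return decoded_bits_per_second / (compressed_kbps * 1000)
```

One second of decoded audio takes 64,000 bytes (230.4 MB per hour), i.e. 512 kbit/s, so a 12 kbit/s stream expands about 43x and a 24 kbit/s stream about 21x. The same arithmetic applies on disk when the float features are saved.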

# Environment Setup

We begin by installing the necessary libraries. The core of this solution relies on the Hugging Face ecosystem:
*   **Transformers:** For the Wav2Vec2 model architecture.
*   **Datasets:** To handle audio data efficiently.
*   **PEFT (Parameter-Efficient Fine-Tuning):** This is the crucial library. It allows us to inject "Adapters" (specifically LoRA) without unfreezing the main model.
*   **JiWER:** To calculate Word Error Rate (WER), the standard metric for ASR.

In [1]:
!pip install -q torch torchaudio torchcodec
!pip install -q transformers datasets peft evaluate jiwer librosa accelerate soundfile
!pip install -q huggingface_hub

# Install system libraries (Critical for .webm files)
!sudo apt-get update -y
!sudo apt-get install -y ffmpeg libsndfile1

print("Libraries installed successfully.")

# Imports and Configuration

In [2]:
import os
import json
import re
import random
import torch
import librosa
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import tarfile
import glob
import shutil
from tqdm.auto import tqdm
import gc

# Hugging Face
from huggingface_hub import snapshot_download
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset, Audio
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, TaskType
import evaluate

In [3]:
# # @title Install TPU Dependencies (Only if using TPU Runtime)
# try:
#   import torch_xla
#   print("Torch XLA is already installed.")
# except ImportError:
#   print("Installing Torch XLA for TPU...")
#   !pip install -q torch_xla

In [4]:
# Setup Device (TPU/GPU/CPU) & Seeding
# 1. Intelligent Device Detection
try:
    # Try importing PyTorch XLA (for TPU)
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    device_type = "tpu"
    print(f"Using Device: TPU ({device})")
except ImportError:
    # Fallback to CUDA or CPU
    if torch.cuda.is_available():
        device = "cuda"
        device_type = "cuda"
    else:
        device = "cpu"
        device_type = "cpu"
    print(f"Using Device: {device}")

# 2. Set Seed for Reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if device_type == "cuda":
        torch.cuda.manual_seed_all(seed)
    elif device_type == "tpu":
        # XLA specific seeding (optional but good practice)
        import torch_xla.core.xla_model as xm
        xm.set_rng_state(seed)

set_seed(42)

Using Device: cuda


In [5]:
MAP_BATCH_SIZE = 32        # Note: the chunked feature extraction below sets its own map/writer sizes
WRITER_BATCH_SIZE = 32
NUM_PROC = 1               # Keep at 1 for Colab (Multiprocessing eats RAM)

PER_DEVICE_TRAIN_BATCH_SIZE = 2  # Small batch to prevent GPU OOM
PER_DEVICE_EVAL_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4  # Accumulate gradients to simulate larger batch (2 * 4 = 8 effective batch)

NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
OUTPUT_DIR = "./adapter_output"

print(f"Configuration Set:\nProcessing Batch: {MAP_BATCH_SIZE}\nTraining Batch: {PER_DEVICE_TRAIN_BATCH_SIZE} (Effective: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")

Configuration Set:
Processing Batch: 32
Training Batch: 2 (Effective: 8)


# Data Acquisition

We download the dataset directly from the Hugging Face Hub.

In [6]:
local_dir = "./asr_data"
repo_id = "DigitalUmuganda/ASR_Fellowship_Challenge_Dataset"

if not os.path.exists(local_dir):
    print("Downloading dataset... (This may take a moment)")
    snapshot_download(
        repo_id=repo_id,
        repo_type='dataset',
        local_dir=local_dir
    )
    print("Download complete.")
else:
    print("Dataset already exists locally.")

Downloading dataset... (This may take a moment)


Download complete.


# Data Extraction (Extract & Delete Strategy)

To handle the large dataset on Google Colab, we combine **subsampling** with a strict **disk management** strategy:

1.  **Extract & Delete:** We extract one `.tar.xz` shard at a time and **immediately delete the compressed file** before moving to the next. This keeps disk usage low.
2.  **Training Data (15 Shards):** We limit the training extraction to **15 of the 27 shards**. In the run below this yields 97,980 training samples, enough to train the Adapter effectively while keeping extraction time and disk usage down.
3.  **Validation & Test Data (Full):** We extract **all** available shards for the Validation and Test sets to ensure our evaluation metrics (WER) are accurate and comparable to the full benchmark.
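The extract-then-delete loop can be sketched with only the standard library. This is a simplified, hypothetical helper (`extract_and_delete`), not the notebook's function: it orders shards by their numeric id rather than by string, records a shard only after it extracts cleanly, and passes `filter="data"` where the running Python supports tarfile extraction filters:

```python
import os
import re
import tarfile


def shard_number(path):
    """Return N for '.../audio_N.tar.xz', or None if the name does not match."""
    m = re.search(r"audio_(\d+)\.tar\.xz$", os.path.basename(path))
    return int(m.group(1)) if m else None


def extract_and_delete(tar_paths, extract_dir, max_shards=None):
    """Extract .tar.xz shards one at a time, deleting each archive after it is unpacked.

    Shards are ordered numerically (audio_2 before audio_10), unlike plain sorted(),
    which compares the names as strings. Returns the ids of the extracted shards.
    """
    numbered = sorted((p for p in tar_paths if shard_number(p) is not None), key=shard_number)
    if max_shards is not None:
        numbered = numbered[:max_shards]

    # Extraction filters exist on Python 3.12+ (and some patch releases of older versions)
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    os.makedirs(extract_dir, exist_ok=True)
    done = []
    for path in numbered:
        with tarfile.open(path, "r:xz") as tar:
            tar.extractall(path=extract_dir, **extract_kwargs)
        os.remove(path)                   # free the disk space used by the compressed shard
        done.append(shard_number(path))   # only reached if extraction succeeded
    return done
```

With shards named `audio_1`, `audio_2` and `audio_10` and `max_shards=2`, this keeps shards 1 and 2 and leaves `audio_10.tar.xz` untouched.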

In [None]:
# Define the root directory
DATA_ROOT = "./asr_data"

def get_shard_id(filename):
    """Extracts the number from 'audio_15.tar.xz' -> 15"""
    match = re.search(r"audio_(\d+)", filename)
    return int(match.group(1)) if match else None

def unpack_and_load_split(split_name, max_shards=None):
    source_path = os.path.join(DATA_ROOT, split_name)
    extract_dir = os.path.abspath(f"./extracted_{split_name}")

    # Cleanup previous extraction
    if os.path.exists(extract_dir):
        print(f"[{split_name}] Cleaning up previous extraction directory...")
        shutil.rmtree(extract_dir)
    os.makedirs(extract_dir, exist_ok=True)

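    # Find and Sort Shards
    # We specifically look for audio_*.tar.xz.
    # Note: sorted() compares paths as strings (audio_0, audio_1, audio_10, ...), so
    # max_shards keeps the lexicographically first shards, not shards 0..max_shards-1.
    tar_files = sorted(glob.glob(f"{source_path}/**/audio_*.tar.xz", recursive=True))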

    if not tar_files:
        print(f"[{split_name}] No tar files found.")
        return None

    # Apply Limit (Max Shards)
    total_shards = len(tar_files)
    if max_shards is not None and max_shards < total_shards:
        print(f"[{split_name}] Limit applied: Using {max_shards} of {total_shards} shards.")
        tar_files = tar_files[:max_shards]
    else:
        print(f"[{split_name}] Using all {total_shards} shards.")

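    # Extract & Track Processed IDs
    processed_ids = set()

    # filter="data" blocks archive members that would be written outside extract_dir
    # (and other unsafe entries) and avoids tarfile's DeprecationWarning on recent Python.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    for tar_path in tqdm(tar_files, desc=f"Extracting {split_name}"):
        try:
            shard_id = get_shard_id(os.path.basename(tar_path))

            # Extract
            with tarfile.open(tar_path, "r:xz") as tar:
                tar.extractall(path=extract_dir, **extract_kwargs)

            # Record the shard only after a successful extraction,
            # so a failed shard's manifest is not loaded later
            if shard_id is not None:
                processed_ids.add(shard_id)

            # Delete to save space
            os.remove(tar_path)

        except Exception as e:
            print(f"Error on {tar_path}: {e}")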

    # Load ONLY Matching Manifests
    # We look for manifest_X.json where X is in processed_ids
    print(f"[{split_name}] Loading manifests for shards: {sorted(list(processed_ids))}")

    df_list = []

    # Helper to find manifest file for a specific ID
    def find_manifest(sid):
        # Search in Source (if not inside tar)
        candidates = glob.glob(f"{source_path}/**/manifest_{sid}.json", recursive=True)
        # Search in Extracted (if inside tar)
        candidates += glob.glob(f"{extract_dir}/**/manifest_{sid}.json", recursive=True)
        return candidates[0] if candidates else None

    for sid in sorted(processed_ids):  # deterministic manifest order
        m_path = find_manifest(sid)
        if m_path:
            try:
                # Read JSON (Try Lines=True first as it's common for ASR)
                try:
                    df_temp = pd.read_json(m_path, lines=True)
                except ValueError:
                    df_temp = pd.read_json(m_path)
                df_list.append(df_temp)
            except Exception as e:
                print(f"Failed to read manifest_{sid}.json: {e}")

    if not df_list:
        print("No valid manifests loaded.")
        return None

    full_df = pd.concat(df_list, ignore_index=True)

    # Path construction without per-file checks
    # Calling os.path.exists for every row is slow, so we locate the
    # 'audio_shards' folder once and build each path by string joining.
    audio_roots = glob.glob(f"{extract_dir}/**/audio_shards", recursive=True)
    if audio_roots:
        # If an 'audio_shards' folder exists, the files are assumed to be inside it
        base_audio_dir = audio_roots[0]
    else:
        # Otherwise they are assumed to be at the root of extract_dir
        base_audio_dir = extract_dir

    print(f"[{split_name}] Assuming audio files are in: {base_audio_dir}")

    # Keep only the basename from the manifest (e.g. "audio/file.webm" -> "file.webm"),
    # assuming every file sits directly in base_audio_dir
    full_df['audio'] = full_df['audio_filepath'].apply(
        lambda filename: os.path.join(base_audio_dir, os.path.basename(filename))
    )

    print(f"[{split_name}] Loaded {len(full_df)} samples.")

    # Create Dataset
    dataset = Dataset.from_pandas(full_df)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    return dataset


In [8]:
# Train: Use 15 shards (Subsampled for speed/space)
train_dataset = unpack_and_load_split("train_tarred", max_shards=15)

# Validation: Use ALL shards (To get accurate WER)
eval_dataset = unpack_and_load_split("val_tarred", max_shards=None)

# Test: Use ALL shards (For final submission)
test_dataset = unpack_and_load_split("test_tarred", max_shards=None)

print("\n=== Data Loading Complete ===")
print(f"Train: {len(train_dataset) if train_dataset else 0}")
print(f"Eval:  {len(eval_dataset) if eval_dataset else 0}")
print(f"Test:  {len(test_dataset) if test_dataset else 0}")

[train_tarred] Limit applied: Using 15 of 27 shards.


Extracting train_tarred:   0%|          | 0/15 [00:00<?, ?it/s]



[train_tarred] Loading manifests for shards: [0, 1, 2, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
[train_tarred] Assuming audio files are in: /content/extracted_train_tarred
[train_tarred] Loaded 97980 samples.
[val_tarred] Using all 1 shards.


Extracting val_tarred:   0%|          | 0/1 [00:00<?, ?it/s]



[val_tarred] Loading manifests for shards: [0]
[val_tarred] Assuming audio files are in: /content/extracted_val_tarred
[val_tarred] Loaded 1617 samples.
[test_tarred] Using all 1 shards.


Extracting test_tarred:   0%|          | 0/1 [00:00<?, ?it/s]



[test_tarred] Loading manifests for shards: [0]
[test_tarred] Assuming audio files are in: /content/extracted_test_tarred
[test_tarred] Loaded 1569 samples.

=== Data Loading Complete ===
Train: 97980
Eval:  1617
Test:  1569


In [9]:
dir_to_delete = DATA_ROOT

if os.path.exists(dir_to_delete):
    try:
        shutil.rmtree(dir_to_delete)
        print(f"Successfully deleted: {dir_to_delete}")
    except OSError as e:
        print(f"Error deleting folder: {e}")
else:
    print(f"Directory not found: {dir_to_delete}")

Successfully deleted: ./asr_data


# Text Cleaning

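ASR models trained with CTC loss generally work best with a small, consistent character set.
1.  **Normalization:** We convert all text to lowercase.
2.  **Cleaning:** The cell below only lowercases the transcripts and replaces missing test transcripts with empty strings; it does not remove punctuation, so any punctuation present in the manifests ends up in the vocabulary. Stripping it would let the model focus on the phonetic sounds rather than written punctuation.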
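If punctuation stripping is wanted, a minimal sketch could look like the following. It assumes a character policy that is not from the original notebook: keep letters, digits, spaces and apostrophes (mapping the typographic ’ to '), and drop everything else. `normalize_transcript` and `clean_text_strict` are hypothetical names, and applying them would change the vocabulary built in the next section:

```python
import re
import unicodedata

# Hypothetical character policy: anything that is not a word character,
# whitespace, or an apostrophe is treated as punctuation.
_DISALLOWED = re.compile(r"[^\w\s']")


def normalize_transcript(text):
    """Lowercase, strip punctuation, and collapse whitespace in one transcript."""
    if text is None:
        return ""
    text = unicodedata.normalize("NFC", text).replace("\u2019", "'").lower()
    text = _DISALLOWED.sub(" ", text)
    text = text.replace("_", " ")    # \w also matches "_", which should not reach the vocab
    return " ".join(text.split())    # collapse runs of spaces left by removed punctuation


def clean_text_strict(batch):
    """Drop-in alternative to clean_text for Dataset.map."""
    batch["text"] = normalize_transcript(batch["text"])
    return batch
```

Usage would mirror the existing cell, e.g. `train_dataset = train_dataset.map(clean_text_strict)`.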

In [10]:
def clean_text(batch):
    # Handle Test Set (where text is None)
    if batch["text"] is None:
        batch["text"] = ""
        return batch

    # Safety Lowercase
    batch["text"] = batch["text"].lower()

    return batch

print("Running safety check on text...")
train_dataset = train_dataset.map(clean_text)
eval_dataset = eval_dataset.map(clean_text)
test_dataset = test_dataset.map(clean_text)
print("Text verified.")

Running safety check on text...


Map:   0%|          | 0/97980 [00:00<?, ? examples/s]

Map:   0%|          | 0/1617 [00:00<?, ? examples/s]

Map:   0%|          | 0/1569 [00:00<?, ? examples/s]

Text verified.


# Create Vocabulary

The model doesn't output words directly; it outputs characters. We must build a vocabulary (JSON) containing every unique character found in the dataset.

We also add special tokens. `Wav2Vec2ForCTC` uses the `[PAD]` id (passed as `pad_token_id`) as the CTC blank, and the tokenizer is configured with `[UNK]`, `[PAD]` and the `|` delimiter. The remaining three are not used by CTC and, as far as we know, are not required by `Wav2Vec2CTCTokenizer` or `Wav2Vec2Processor`; keeping them only adds a few output classes that the model should learn never to predict:

*   **`|`**: Word delimiter (replaces the space character).
*   **`[UNK]`**: Unknown character (used when the tokenizer encounters a char not in the vocab).
*   **`[PAD]`**: Padding token for label sequences; also the CTC blank. (Audio is padded separately with `padding_value=0.0`.)
*   **`[CLS]`**: Classification start token (unused by CTC).
*   **`[SEP]`**: Separator token (unused by CTC).
*   **`[MASK]`**: Mask token (unused by CTC).
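Building ids from an unordered `set` is fragile: Python randomizes string hashing per process by default, so set iteration order (and therefore the character ids) can change between runs, which matters if a trained head is later paired with a regenerated `vocab.json`. A minimal sketch of a deterministic alternative, assuming plain lists of transcripts; `build_ctc_vocab` is a hypothetical helper that appends only `[UNK]` and `[PAD]` by default (pass more names through `special_tokens` to add `[CLS]`, `[SEP]` and `[MASK]` as well):

```python
def build_ctc_vocab(transcripts, special_tokens=("[UNK]", "[PAD]")):
    """Build a character vocabulary with deterministic (sorted) ids.

    The space character is renamed to "|" (the word delimiter) and keeps its id;
    special tokens are appended at the end.
    """
    chars = set()
    for text in transcripts:
        chars.update(text)
    vocab = {ch: idx for idx, ch in enumerate(sorted(chars))}
    if " " in vocab:
        vocab["|"] = vocab.pop(" ")
    for token in special_tokens:
        vocab[token] = len(vocab)
    return vocab
```

Because " " sorts before letters, the delimiter ends up with id 0, and the same transcripts always produce the same mapping.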

In [11]:
def extract_all_chars(batch):
    # Collect the set of unique characters in this batch of transcripts
    all_text = " ".join(batch["text"])
    return {"vocab": [list(set(all_text))]}

# Extract all chars. Selecting only "text" keeps the audio column out of the map,
# and each split removes its own column.
vocab_train = train_dataset.select_columns(["text"]).map(
    extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=["text"]
)
vocab_eval = eval_dataset.select_columns(["text"]).map(
    extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=["text"]
)

# Merge and enumerate (sorted so the ids are the same on every run)
vocab_list = sorted(set(vocab_train["vocab"][0]) | set(vocab_eval["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}

# Add special tokens
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
vocab_dict["[CLS]"] = len(vocab_dict)
vocab_dict["[SEP]"] = len(vocab_dict)
vocab_dict["[MASK]"] = len(vocab_dict)

# Save to file
with open("vocab.json", "w") as vocab_file:
    json.dump(vocab_dict, vocab_file)

print(f"Vocabulary created with {len(vocab_dict)} characters.")

Map:   0%|          | 0/97980 [00:00<?, ? examples/s]

Map:   0%|          | 0/1617 [00:00<?, ? examples/s]

Vocabulary created with 33 characters.


# Initialize Processor

The `Wav2Vec2Processor` wraps two things:
1.  **Feature Extractor:** Converts raw audio (arrays) into the numerical inputs the model understands.
2.  **Tokenizer:** Converts text strings into the integer IDs from our vocabulary.
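With `do_normalize=True` the feature extractor standardizes each utterance, and the tokenizer maps each character to its vocabulary id, writing spaces as the `|` delimiter. A plain-Python sketch of both steps, for intuition only: `zero_mean_unit_var` and `encode_text` are hypothetical helpers, and the epsilon of 1e-7 is an assumption about the library's internals:

```python
import math


def zero_mean_unit_var(samples, eps=1e-7):
    """Standardize one utterance: subtract its mean, divide by sqrt(variance + eps)."""
    n = len(samples)
    mean = sum(samples) / n
    var = sum((x - mean) ** 2 for x in samples) / n
    scale = math.sqrt(var + eps)
    return [(x - mean) / scale for x in samples]


def encode_text(text, vocab, delimiter="|", unk_token="[UNK]"):
    """Map a transcript to character ids, writing spaces as the word delimiter."""
    unk_id = vocab[unk_token]
    return [vocab.get(ch, unk_id) for ch in text.replace(" ", delimiter)]
```

For example, with `{"|": 0, "a": 1, "b": 2, "c": 3, "[UNK]": 4}`, `"ab c"` encodes to `[1, 2, 0, 3]`, and a character missing from the vocabulary falls back to the `[UNK]` id.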

In [12]:
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True
)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
print("Processor initialized.")

Processor initialized.


# Feature Extraction (RAM-Safe "Divide and Conquer")

**The Problem:** Audio processing is extremely memory-intensive. Compressed audio (such as the `.webm` files extracted from the `.tar.xz` shards) expands by roughly 20x-50x when decoded into float arrays. On Google Colab, processing a whole split in one pass makes RAM usage climb steadily until the session crashes.

**The Solution:** We implement a **Chunked Processing Strategy**.

Instead of processing the whole dataset at once, we:
1.  **Filter:** Remove rows with empty transcripts first.
2.  **Shard:** Split the dataset into small pieces (e.g., 20 chunks for training).
3.  **Process & Flush:**
    *   Load *one* small chunk into RAM.
    *   Convert audio to model inputs (`input_values`) and text to token IDs (`labels`).
    *   **Save immediately to disk** (`save_to_disk`).
    *   **Free RAM:** Delete the chunk's variables and run garbage collection (`gc.collect()`) so the memory held by the finished chunk is released before the next one is loaded.
4.  **Reassemble:** Finally, we reload the processed chunks with `load_from_disk`, which memory-maps the Arrow files instead of reading them into RAM, and concatenate them.

**Column Management:**
*   **Train/Eval:** We drop the original `audio` and `text` columns to save space.
*   **Test:** We **keep** metadata columns (like filenames) because we need them to generate the submission file later.
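The chunk-and-flush pattern can be sketched without `datasets`, with JSON-lines files standing in for `save_to_disk`/`load_from_disk` (hypothetical helpers `process_in_chunks` and `iter_chunks`; the real pipeline writes Arrow files). Note that chunking bounds RAM, not disk: every flushed chunk stays on disk, which is consistent with the run below failing with `No space left on device` partway through the training split.

```python
import gc
import json
import os


def process_in_chunks(rows, transform, out_dir, num_chunks):
    """Apply `transform` chunk by chunk, writing each chunk to its own JSON-lines file.

    Only one chunk of transformed rows is held in memory at a time.
    Returns the chunk file paths, in order.
    """
    os.makedirs(out_dir, exist_ok=True)
    chunk_size = max(1, -(-len(rows) // num_chunks))   # ceiling division
    paths = []
    for i, start in enumerate(range(0, len(rows), chunk_size)):
        processed = [transform(r) for r in rows[start:start + chunk_size]]
        path = os.path.join(out_dir, f"chunk_{i}.jsonl")
        with open(path, "w") as f:
            for row in processed:
                f.write(json.dumps(row) + "\n")
        paths.append(path)
        # CPython frees the list once its last reference is gone;
        # gc.collect() additionally clears any reference cycles
        del processed
        gc.collect()
    return paths


def iter_chunks(paths):
    """Stream rows back from disk one at a time instead of materializing them all."""
    for path in paths:
        with open(path) as f:
            for line in f:
                yield json.loads(line)
```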

In [None]:
# Create a temporary folder for shards
SHARD_DIR = "./temp_shards"
if os.path.exists(SHARD_DIR):
    shutil.rmtree(SHARD_DIR)
os.makedirs(SHARD_DIR)

# Filter Data
print("Filtering empty text...")
train_dataset = train_dataset.filter(lambda x: x["text"] is not None and len(x["text"].strip()) > 0)
eval_dataset = eval_dataset.filter(lambda x: x["text"] is not None and len(x["text"].strip()) > 0)

# Define The Processor Function
def prepare_dataset_wrapper(batch):
    audio = batch["audio"]

    # Process inputs
    batch["input_values"] = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"]
    ).input_values[0]

    # Process labels (if text exists)
    if batch.get("text"):
        batch["labels"] = processor(text=batch["text"]).input_ids

    return batch

# The Chunk Processing Engine
def process_dataset_in_chunks(dataset, name, num_shards=10):
    print(f"\n=== Processing {name} in {num_shards} chunks to save RAM ===")

    # Columns to remove (Keep metadata for Test set, drop for others)
    if "train" in name.lower() or "eval" in name.lower():
        keep_cols = ["input_values", "labels"]
        cols_to_remove = [c for c in dataset.column_names if c not in keep_cols]
    else:
        # For Test set, keep metadata files
        cols_to_remove = ["audio", "text", "raw_text", "transcriber_id"]
        cols_to_remove = [c for c in cols_to_remove if c in dataset.column_names]

    shard_paths = []

    for i in range(num_shards):
        print(f"  > Processing Chunk {i+1}/{num_shards}...")

        # Create a small slice of the data (RAM safe)
        shard = dataset.shard(num_shards=num_shards, index=i)

        # Map features
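        processed_shard = shard.map(
            prepare_dataset_wrapper,
            remove_columns=cols_to_remove,
            # Rows are processed one at a time (batched=False), so a batch_size
            # argument would have no effect here. writer_batch_size controls how many
            # processed rows are buffered in RAM before being flushed to disk.
            writer_batch_size=16,
            num_proc=NUM_PROC,
        )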

        # Save immediately to disk
        save_path = os.path.join(SHARD_DIR, f"{name}_shard_{i}")
        processed_shard.save_to_disk(save_path)
        shard_paths.append(save_path)

        # NUKE RAM (Delete variables and force Garbage Collection)
        del shard
        del processed_shard
        gc.collect()

    print(f"  > Reassembling {name} from disk...")
    # Load all shards from disk pointers (uses almost 0 RAM)
    final_dataset = concatenate_datasets([load_from_disk(p) for p in shard_paths])
    return final_dataset


# Process Train
# We split training into 20 chunks because it's the biggest cause of crashes
train_dataset = process_dataset_in_chunks(train_dataset, "train", num_shards=20)

# Process Eval (5 chunks is enough)
eval_dataset = process_dataset_in_chunks(eval_dataset, "eval", num_shards=5)

# Process Test (5 chunks)
test_dataset = process_dataset_in_chunks(test_dataset, "test", num_shards=5)

print("\n✓ All Features Extracted Successfully (RAM Safe).")
print(f"Train Size: {len(train_dataset)}")
print(f"Eval Size:  {len(eval_dataset)}")

Filtering empty text...


Filter:   0%|          | 0/97980 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1617 [00:00<?, ? examples/s]


=== Processing train in 20 chunks to save RAM ===
  > Processing Chunk 1/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/12 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 2/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/13 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 3/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/13 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 4/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/14 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 5/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/12 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 6/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/13 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 7/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/13 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

  > Processing Chunk 8/20...


Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Saving the dataset (0/14 shards):   0%|          | 0/4893 [00:00<?, ? examples/s]

OSError: [Errno 28] No space left on device

# Load Base Model & Freeze

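1.  We load `facebook/wav2vec2-large-xlsr-53`, a large (roughly 300M-parameter) wav2vec 2.0 model pre-trained on unlabelled speech in 53 languages. The checkpoint has no CTC head, so `lm_head` is newly initialized for our vocabulary.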
2.  **FREEZING:** We loop through `model.parameters()` and set `requires_grad = False`.

In [None]:
# Load Base Model
model_id = "facebook/wav2vec2-large-xlsr-53"
print(f"Loading Base Model: {model_id}...")

model = Wav2Vec2ForCTC.from_pretrained(
    model_id,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True # Harmless safeguard: the checkpoint is pretraining-only, so lm_head is newly initialized anyway
)

# FREEZE BASE MODEL
for param in model.parameters():
    param.requires_grad = False

print("Base Model loaded and FROZEN.")

# Inject Adapters (LoRA)

We use **LoRA (Low-Rank Adaptation)**.

Instead of retraining the whole model, LoRA adds a pair of small low-rank matrices alongside the query and value projections (`q_proj`, `v_proj`) of each attention layer. We only train these matrices and the final classification layer (the `lm_head`), which maps features to our Kinyarwanda character vocabulary.
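The LoRA update can be written as `y = W x + (lora_alpha / r) * B (A x)`, with `W` frozen and only `A` (r × d_in) and `B` (d_out × r) trained. The pure-Python sketch below spells that out and counts the added parameters; it is illustrative only, not PEFT's implementation. The count assumes the large wav2vec 2.0 configuration (24 Transformer layers, hidden size 1024, so `q_proj` and `v_proj` are 1024 × 1024):

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(x, W, A, B, alpha, r):
    """y = W x + (alpha / r) * B (A x): frozen weight plus a scaled low-rank update.

    W: d_out x d_in (frozen), A: r x d_in, B: d_out x r (the trainable adapter).
    """
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return r * (d_in + d_out)


# Assumed architecture: 24 layers, two targeted projections per layer, hidden size 1024
LAYERS, TARGETS_PER_LAYER, HIDDEN, RANK = 24, 2, 1024, 32
total_lora_params = LAYERS * TARGETS_PER_LAYER * lora_param_count(HIDDEN, HIDDEN, RANK)
```

With r=32 this comes to 24 × 2 × 32 × (1024 + 1024) = 3,145,728 LoRA parameters, before counting `lm_head`. Because PEFT initializes `B` to zeros by default, the adapted model initially produces the same outputs as the frozen one, and with `lora_alpha=64` the update is scaled by 64 / 32 = 2.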

In [None]:
# Configure LoRA
peft_config = LoraConfig(
    inference_mode=False,
    r=32,               # Rank of the low-rank update (larger = more expressive, more parameters)
    lora_alpha=64,      # Scaling: the update is multiplied by lora_alpha / r (= 2 here)
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"], # Query/value projections in every attention layer
    bias="none",
    # Train the newly initialized output layer (Head) as well, and include it in
    # save_pretrained(); setting requires_grad by hand would train it, but PEFT
    # would not save it with the adapter.
    modules_to_save=["lm_head"],
)

# Apply configuration to model (only LoRA matrices and modules_to_save stay trainable)
model = get_peft_model(model, peft_config)

print("=== Trainable Parameters ===")
model.print_trainable_parameters()

# Helper Functions

*   **DataCollator:** Since audio files have different lengths, we must pad them so they fit into a rectangular batch.
*   **Metric:** We load the WER metric to evaluate performance during training.
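The label handling in `DataCollatorCTCWithPadding` and the decoding inside `compute_metrics` can be mimicked in plain Python to see what they do. A minimal sketch under stated assumptions: `pad_labels`, `ctc_greedy_decode` and `corpus_wer` are hypothetical helpers; the decoder assumes `[PAD]` is the CTC blank (as `Wav2Vec2ForCTC` treats `pad_token_id`), and `corpus_wer` divides total word edits by total reference words, which is, to our understanding, how the `wer` metric aggregates a list of sentences:

```python
def pad_labels(label_seqs, ignore_index=-100):
    """Right-pad label id sequences to equal length, marking padding with ignore_index."""
    width = max(len(seq) for seq in label_seqs)
    return [list(seq) + [ignore_index] * (width - len(seq)) for seq in label_seqs]


def ctc_greedy_decode(ids, id_to_token, blank_token="[PAD]", delimiter="|"):
    """Collapse repeated ids, drop blanks, and turn the word delimiter back into spaces."""
    out = []
    prev = None
    for i in ids:
        if i != prev:
            tok = id_to_token[i]
            if tok != blank_token:
                out.append(tok)
        prev = i
    return "".join(out).replace(delimiter, " ").strip()


def _edit_distance(ref, hyp):
    """Levenshtein distance over word lists (substitution, deletion, insertion cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h))
        prev = cur
    return prev[-1]


def corpus_wer(predictions, references):
    """Total word edits divided by total reference words (corpus-level WER)."""
    edits = sum(_edit_distance(r.split(), p.split()) for p, r in zip(predictions, references))
    words = sum(len(r.split()) for r in references)
    return edits / words
```

For example, decoding the ids for `a a [PAD] a b | b b` gives `"aab b"`: the blank keeps the two separate `a`s apart, while the repeated `b`s collapse into one. A hypothesis `"a x c d"` against the reference `"a b c"` has one substitution and one insertion, so its WER is 2/3.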

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features):
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Pad Audio
        batch = self.processor.feature_extractor.pad(
            input_features, padding=self.padding, return_tensors="pt"
        )
        # Pad Labels (call the tokenizer directly; as_target_processor() is deprecated)
        labels_batch = self.processor.tokenizer.pad(
            label_features, padding=self.padding, return_tensors="pt"
        )

        # Mask padding in labels so loss isn't calculated on pads
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor)
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

# Training

We configure the `Trainer`.
*   `learning_rate=LEARNING_RATE` (1e-3): LoRA typically tolerates a higher learning rate than full fine-tuning.
*   `fp16`: Mixed precision to speed up training on a CUDA GPU.
*   `num_train_epochs=NUM_EPOCHS`: 10 epochs, as set in the configuration cell.
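A rough step budget for this configuration, assuming a single device and using the 97,980 training rows reported before filtering as an upper bound (the post-filter count is not shown). `training_schedule` is a hypothetical helper, and the Trainer's own rounding of steps per epoch may differ by one:

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, epochs,
                      num_devices=1, eval_every=200):
    """Approximate optimizer-step budget for a gradient-accumulation training loop."""
    effective_batch = per_device_batch * grad_accum * num_devices
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    total_updates = updates_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        "evaluations": total_updates // eval_every,
    }
```

With these numbers it gives an effective batch of 8, about 12,248 optimizer steps per epoch, about 122,480 in total, and roughly 612 evaluation passes at `eval_steps=200`, each over the full validation split.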

In [None]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    group_by_length=True,

    # --- MEMORY OPTIMIZATION (Using Constants) ---
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # ---------------------------------------------

    gradient_checkpointing=True,
    dataloader_num_workers=0,

    num_train_epochs=NUM_EPOCHS,
    fp16=(device_type == "cuda"),  # Mixed precision is only enabled on a CUDA GPU
    save_steps=200,
    eval_steps=200,
    logging_steps=50,
    learning_rate=LEARNING_RATE,
    save_total_limit=2,
    eval_strategy="steps",  # Renamed from evaluation_strategy in recent transformers releases
    report_to="none"
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=processor.feature_extractor,  # Replaces the deprecated tokenizer= argument
)

print("Starting training...")
trainer.train()

**Save the Adapter Weights:** We save *only* the small adapter files.

**Generate Transcriptions:** We run inference on the test set using our fine-tuned model.

**Baseline Comparison:** To check that the adapter helps, we temporarily disable it and run inference again, comparing the WER with and without the adapter.

In [None]:
# 1. Save Adapter
save_path = "./submission/adapter_weights"
model.save_pretrained(save_path)
print(f"Adapter weights saved to {save_path}")

# Helper for inference
def map_to_result(batch, model_in_use):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device=device).unsqueeze(0)
        logits = model_in_use(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)
    return batch

print("Generating predictions...")

# A. Evaluate Fine-Tuned Model
results_ft = test_dataset.map(lambda b: map_to_result(b, model), remove_columns=["input_values"])
wer_ft = wer_metric.compute(predictions=results_ft["pred_str"], references=results_ft["text"])

# B. Evaluate Base Model (Disable Adapter temporarily)
with model.disable_adapter():
    results_base = test_dataset.map(lambda b: map_to_result(b, model), remove_columns=["input_values"])
wer_base = wer_metric.compute(predictions=results_base["pred_str"], references=results_base["text"])

print(f"\n=== FINAL RESULTS ===")
print(f"Base Model WER: {wer_base:.2%}")
print(f"Adapter Model WER: {wer_ft:.2%}")

# C. Save Text Files
os.makedirs("submission", exist_ok=True)

with open("submission/finetuned_transcriptions.txt", "w") as f:
    for p, r in zip(results_ft["pred_str"], results_ft["text"]):
        f.write(f"Ref: {r}\nPred: {p}\n\n")

with open("submission/base_transcriptions.txt", "w") as f:
    for p, r in zip(results_base["pred_str"], results_base["text"]):
        f.write(f"Ref: {r}\nPred: {p}\n\n")

print("Submission files ready in ./submission")