# Fine-tune KenLM Hyperparams 

## Install dependancies

In [2]:
!pip install pyctcdecode jiwer torchcodec

Collecting pyctcdecode
  Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl.metadata (20 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting torchcodec
  Downloading torchcodec-0.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (9.4 kB)
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode)
  Downloading pygtrie-2.5.0-py3-none-any.whl.metadata (7.5 kB)
Collecting hypothesis<7,>=6.14 (from pyctcdecode)
  Downloading hypothesis-6.141.0-py3-none-any.whl.metadata (5.7 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading torchcodec-0.7.0-cp311-cp311-manylinux_2_28_x86_64.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading hypothesis-6.1

In [3]:
!pip install https://github.com/kpu/kenlm/archive/master.zip

Collecting https://github.com/kpu/kenlm/archive/master.zip
  Downloading https://github.com/kpu/kenlm/archive/master.zip
[2K     [32m-[0m [32m553.6 kB[0m [31m9.3 MB/s[0m [33m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: kenlm
  Building wheel for kenlm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for kenlm: filename=kenlm-0.2.0-cp311-cp311-linux_x86_64.whl size=3185029 sha256=7d7aa8b9dd7abbe712151c16a97fd33f2ab3f16405e4c4df1f3b5252d5df8389
  Stored in directory: /tmp/pip-ephem-wheel-cache-rgvp51tp/wheels/4e/ca/6a/e5da175b1396483f6f410cdb4cfe8bc8fa5e12088e91d60413
Successfully built kenlm
Installing collected packages: kenlm
Successfully installed kenlm-0.2.0


## Setup environment

In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

device: cuda


### Load Pre-trained Model

In [5]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import Wav2Vec2Config

processor = Wav2Vec2Processor.from_pretrained("/kaggle/input/wavlm-ctc-ex-2/wavlm-ctc-ex-2")
config = Wav2Vec2Config.from_pretrained(
    "/kaggle/input/wavlm-ctc-ex-2/wavlm-ctc-ex-2",
)
config.ctc_loss_reduction = "mean"
config.pad_token_id = processor.tokenizer.pad_token_id
config.final_dropout = 0.2

model = Wav2Vec2ForCTC.from_pretrained(
    "/kaggle/input/wavlm-ctc-ex-2/wavlm-ctc-ex-2",
    config=config
).to(device)

2025-10-15 14:34:05.115344: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760538845.298483      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760538845.353251      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [6]:
!wget -O 4-gram.arpa.gz https://openslr.elda.org/resources/11/4-gram.arpa.gz

--2025-10-15 14:34:45--  https://openslr.elda.org/resources/11/4-gram.arpa.gz
Resolving openslr.elda.org (openslr.elda.org)... 141.94.109.138, 2001:41d0:203:ad8a::
Connecting to openslr.elda.org (openslr.elda.org)|141.94.109.138|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1355172078 (1.3G) [application/x-gzip]
Saving to: ‘4-gram.arpa.gz’


2025-10-15 14:35:41 (23.3 MB/s) - ‘4-gram.arpa.gz’ saved [1355172078/1355172078]



In [7]:
!gunzip 4-gram.arpa.gz

In [8]:
import os
import soundfile as sf
from datasets import Dataset

### Data Preprocessing

In [9]:
base_path = "/kaggle/input/librispeech/LibriSpeech/test-clean"

audio_data = []
transcripts = []

for speaker_folder in os.listdir(base_path):
    speaker_path = os.path.join(base_path, speaker_folder)
    if not os.path.isdir(speaker_path):
        continue
    for chapter_folder in os.listdir(speaker_path):
        chapter_path = os.path.join(speaker_path, chapter_folder)
        if not os.path.isdir(chapter_path):
            continue
        
        # Read transcript file
        transcript_file = [f for f in os.listdir(chapter_path) if f.endswith(".txt")][0]
        with open(os.path.join(chapter_path, transcript_file), "r") as f:
            lines = f.readlines()
        
        # Read each FLAC file
        for file in os.listdir(chapter_path):
            if file.endswith(".flac"):
                audio_path = os.path.join(chapter_path, file)
                audio_array, sr = sf.read(audio_path)  # Load audio
                audio_data.append(audio_array)
                
                # Get transcript corresponding to the file
                file_id = os.path.splitext(file)[0]
                transcript = [l.split(" ", 1)[1].strip() for l in lines if l.startswith(file_id)][0]
                transcripts.append(transcript)

print(f"Loaded {len(audio_data)} audio files and transcripts.")

Loaded 2620 audio files and transcripts.


In [10]:
dataset = Dataset.from_dict({
    "audio": audio_data,
    "text": transcripts
})

In [11]:
dataset.column_names

['audio', 'text']

---

### Define Evaluation Metrics

In [12]:
from jiwer import wer, cer

def get_wer_cer(text, transcription):
  return wer(list(text), list(transcription)), cer(list(text), list(transcription))

---

In [13]:
def map_to_pred_model(batch):
    # process batch
    inputs = processor(batch["audio"], sampling_rate=16000, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs.input_values.to(model.device)).logits.cpu().numpy()

    batch["logits"] = logits
    return batch

In [14]:
resultbatch = dataset.map(map_to_pred_model, batched=True, batch_size=8, remove_columns=["audio"])



Map:   0%|          | 0/2620 [00:00<?, ? examples/s]

---

In [3]:
from pyctcdecode import build_ctcdecoder

In [None]:
vocab = [x[0] for x in sorted(processor.tokenizer.get_vocab().items(), key=lambda item: item[1])]

def make_decoder(alpha, beta, logits):
    decoder = build_ctcdecoder(
        labels=vocab,
        kenlm_model_path="4-gram.arpa",
        alpha=alpha,
        beta=beta
    )
    return [decoder.decode(logit) for logit in logits]

---

# Fine Tuning

In [33]:
import gc
from pyctcdecode import build_ctcdecoder

def tune_decoder_hparams(logits, kenlm_path, alphas, betas):
    """
    Tune alpha and beta for LM decoding using a small validation subset.
    """

    best_wer = float("inf")
    best_params = (None, None)

    # Loop through combinations
    for alpha in alphas:
        for beta in betas:
            beam_lm = make_decoder(alpha, beta, logits)

            # beam_lm = [decoder.decode(logit) for logit in logits]
            # resultModel = resultbatch.add_column("transcription", beam_lm)
            wer_score, cer_score = get_wer_cer(resultbatch["text"], beam_lm)
            print(f"Alpha={alpha}, Beta={beta}, WER={wer_score:.4f}")

            if wer_score < best_wer:
                best_wer = wer_score
                best_params = (alpha, beta)

            # del decoder
            del beam_lm
            gc.collect()

    print(f"\n Best Params → Alpha={best_params[0]}, Beta={best_params[1]}, WER={best_wer:.4f}")
    return best_params

In [None]:
import numpy as np
import pyctcdecode

alphas = [0.3, 0.5, 1.0, 1.5]
betas = [0.1, 0.35, 0.5]

best_alpha, best_beta = tune_decoder_hparams(
    logits=[np.array(l) for l in resultbatch["logits"]],
    kenlm_path="4-gram.arpa",
    alphas=alphas,
    betas=betas
)

Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Alpha=0.3, Beta=0.1, WER=0.0435


Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Alpha=0.3, Beta=0.35, WER=0.0437


Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Alpha=0.3, Beta=0.5, WER=0.0437


Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Alpha=0.5, Beta=0.1, WER=0.0405


Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Alpha=0.5, Beta=0.35, WER=0.0406


Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


Alpha=0.5, Beta=0.5, WER=0.0408


Loading the LM will be faster if you build a binary file.
Reading /kaggle/working/4-gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************


---

## Load fine-tuned Decoder

In [None]:
# Load 4-gram LM
logits=[np.array(l) for l in resultbatch["logits"]]
beam_lm = make_decoder(best_alpha, best_beta, logits)

In [2]:
wer_score, cer_score = get_wer_cer(resultbatch["text"], beam_lm)
print(f"WER={wer_score:.4f}, CER={cer_score:.4f}")