# Addis Whisper

A simple notebook for transcribing Amharic speech with the Addis Whisper checkpoints.


## Word Error Rate (WER) reported on each model card (lower is better)

| Checkpoint | WER |
|---|---:|
| `b1n1yam/shook-medium-amharic-2k` | `9.1091` |
| `b1n1yam/shook-tiny-amharic-stage2-polish` | `20.7634` |
| `b1n1yam/shook-tiny-amharic-600hr` | `22.1786` |
| `b1n1yam/shhook-1.2k-sm` | `21.0` |

Links:
- https://huggingface.co/b1n1yam/shook-medium-amharic-2k
- https://huggingface.co/b1n1yam/shook-tiny-amharic-stage2-polish
- https://huggingface.co/b1n1yam/shook-tiny-amharic-600hr
- https://huggingface.co/b1n1yam/shhook-1.2k-sm


## Setup


In [None]:
import os
import sys

# Detect hosted notebook runtimes: Colab has the google.colab module loaded,
# and this check treats the presence of a /kaggle directory as a Kaggle kernel.
IN_COLAB = "google.colab" in sys.modules
IN_KAGGLE = os.path.exists("/kaggle")
print(f"IN_COLAB={IN_COLAB}, IN_KAGGLE={IN_KAGGLE}")

# Install ffmpeg system-wide only on hosted runtimes; presumably it is needed to
# decode audio files passed to the ASR pipeline by path.
# NOTE(review): local runs skip this step and are assumed to have ffmpeg already — confirm.
if IN_COLAB or IN_KAGGLE:
    !apt-get -qq update
    !apt-get -qq install -y ffmpeg

# Python dependencies; -U upgrades any versions already present on the runtime.
# NOTE(review): upgrading preinstalled packages on Colab may require a runtime restart — confirm.
!pip -q install -U transformers accelerate gradio librosa soundfile


## Choose Model (Uncomment One)


In [None]:
# Choose ONE model (uncomment one line only)
# Exactly one assignment should be active; if several are uncommented, the last
# one executed wins. The WER figures are copied from each model card.
model_name = "b1n1yam/shook-medium-amharic-2k"              # best WER on card: 9.1091
# model_name = "b1n1yam/shook-tiny-amharic-stage2-polish"   # WER on card: 20.7634
# model_name = "b1n1yam/shook-tiny-amharic-600hr"           # WER on card: 22.1786
# model_name = "b1n1yam/shhook-1.2k-sm"                     # WER on card: 21.0

print("Using model_name:", model_name)


In [None]:
# Model-card WER for every supported checkpoint (lower is better).
# Built from (checkpoint, wer) pairs; insertion order matches the table above.
MODEL_WER = dict(
    [
        ("b1n1yam/shook-medium-amharic-2k", 9.1091),
        ("b1n1yam/shook-tiny-amharic-stage2-polish", 20.7634),
        ("b1n1yam/shook-tiny-amharic-600hr", 22.1786),
        ("b1n1yam/shhook-1.2k-sm", 21.0),
    ]
)

# Checkpoints missing from the table report None instead of raising here;
# the load cell performs the actual validation.
selected_wer = MODEL_WER[model_name] if model_name in MODEL_WER else None
print(f"WER for {model_name}: {selected_wer}")


## Load + Inference
Rerun this cell whenever you change `model_name`, then rerun the Gradio App cell so the UI uses the newly loaded checkpoint.


In [None]:
import torch
import gradio as gr
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Checkpoints this notebook supports; keep in sync with the WER table above.
VALID_MODELS = {
    "b1n1yam/shook-medium-amharic-2k",
    "b1n1yam/shook-tiny-amharic-stage2-polish",
    "b1n1yam/shook-tiny-amharic-600hr",
    "b1n1yam/shhook-1.2k-sm",
}

if model_name not in VALID_MODELS:
    raise ValueError(
        f"Invalid model_name: {model_name!r}; expected one of {sorted(VALID_MODELS)}"
    )

# First GPU in float16 when CUDA is available, otherwise CPU in float32.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# transformers.pipeline() takes an integer device index; -1 selects the CPU.
DEVICE_INDEX = 0 if DEVICE.startswith("cuda") else -1
DTYPE = torch.float16 if DEVICE.startswith("cuda") else torch.float32

print(f"Loading {model_name} on {DEVICE} ({DTYPE})")

processor = AutoProcessor.from_pretrained(model_name)


def _load_model(use_safetensors):
    """Load the selected checkpoint, forcing the given weight-file format."""
    return AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=use_safetensors,
    )


# Prefer safetensors weights. Retry with PyTorch .bin weights only when the
# safetensors file cannot be obtained: transformers / huggingface_hub report
# missing or unreachable files as OSError subclasses. Any other failure (bad
# dtype, out of memory, incompatible versions) now propagates instead of being
# hidden behind a second load attempt.
try:
    model = _load_model(use_safetensors=True)
except OSError as err:
    print(f"safetensors load failed ({err}); retrying with use_safetensors=False")
    model = _load_model(use_safetensors=False)

# Long recordings are split into 30 s windows, decoded 8 windows per batch.
asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=DEVICE_INDEX,
    torch_dtype=DTYPE,
    chunk_length_s=30,
    batch_size=8,
)


# Whisper forces up to 4 decoder prompt tokens: <|startoftranscript|>, the
# language token, the task token, and <|notimestamps|> when timestamps are off.
_WHISPER_PROMPT_TOKENS = 4


def _normalize_language(language):
    """Return a stripped language code, or None when *language* is blank.

    A cleared Gradio textbox yields "", which Whisper's language lookup rejects
    as unsupported; None leaves the language choice to generate().
    """
    if language is None:
        return None
    cleaned = str(language).strip()
    return cleaned or None


def _clamp_max_new_tokens(max_new_tokens, max_target_positions):
    """Clamp *max_new_tokens* so prompt plus generated tokens fit the decoder.

    Whisper's generate() raises ValueError when the decoder prompt length plus
    max_new_tokens exceeds config.max_target_positions.
    """
    budget = int(max_target_positions) - _WHISPER_PROMPT_TOKENS
    return max(1, min(int(max_new_tokens), budget))


def transcribe_audio(
    audio_path,
    language="am",
    return_timestamps=True,
    max_new_tokens=225,
    num_beams=1,
    temperature=0.0,
):
    """Transcribe one audio file with the ``asr`` pipeline built in this cell.

    Args:
        audio_path: Path to an audio file, or None (e.g. an empty Gradio input).
        language: Whisper language code; blank or None is passed on as None.
        return_timestamps: Whether the pipeline should return timestamped chunks.
        max_new_tokens: Generation length cap, clamped to the decoder's budget.
        num_beams: Beam width forwarded to generate().
        temperature: Decoding temperature forwarded to generate().

    Returns:
        ``(text, details)``: the stripped transcript and a dict holding the model
        name and the pipeline's "chunks" list (empty when none are returned).
        When *audio_path* is None, returns ``("", {"error": "No audio provided"})``.
    """
    if audio_path is None:
        return "", {"error": "No audio provided"}

    # `model` and `asr` are created earlier in this cell; 448 is Whisper's default.
    decoder_limit = getattr(model.config, "max_target_positions", 448)

    out = asr(
        audio_path,
        return_timestamps=bool(return_timestamps),
        generate_kwargs={
            "language": _normalize_language(language),
            "task": "transcribe",
            "max_new_tokens": _clamp_max_new_tokens(max_new_tokens, decoder_limit),
            "num_beams": int(num_beams),
            "temperature": float(temperature),
        },
    )

    text = out.get("text", "").strip()
    chunks = out.get("chunks", [])
    return text, {"model": model_name, "chunks": chunks}


## Gradio App


In [None]:
# Whisper's generate() raises ValueError when the forced decoder prompt plus
# max_new_tokens exceeds config.max_target_positions (448 for Whisper). The
# prompt here uses up to 4 tokens (start-of-transcript, language, task, and
# <|notimestamps|> when timestamps are off), so the slider is capped below that.
# `model`, `model_name` and `transcribe_audio` come from the Load + Inference cell.
MAX_NEW_TOKENS_LIMIT = getattr(model.config, "max_target_positions", 448) - 4

with gr.Blocks(title="Addis Whisper") as demo:
    gr.Markdown(f"## Addis Whisper\nModel: `{model_name}`")

    # type="filepath" makes Gradio hand transcribe_audio a path on disk.
    audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio")

    with gr.Row():
        language = gr.Textbox(value="am", label="Language")
        return_timestamps = gr.Checkbox(value=True, label="Timestamps")

    with gr.Row():
        max_new_tokens = gr.Slider(
            32,
            MAX_NEW_TOKENS_LIMIT,
            value=min(225, MAX_NEW_TOKENS_LIMIT),
            step=1,
            label="max_new_tokens",
        )
        num_beams = gr.Slider(1, 8, value=1, step=1, label="num_beams")
        temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="temperature")

    run_btn = gr.Button("Transcribe", variant="primary")

    transcript = gr.Textbox(label="Transcript", lines=8)
    details = gr.JSON(label="Details")

    # Input order must match transcribe_audio's positional parameters.
    run_btn.click(
        fn=transcribe_audio,
        inputs=[audio, language, return_timestamps, max_new_tokens, num_beams, temperature],
        outputs=[transcript, details],
    )


# share=False: no public share link is requested.
demo.launch(share=False, debug=False)


## Script Example


In [None]:
# Example without Gradio (run the "Load + Inference" cell first so that
# `transcribe_audio` and the pipeline exist):
# text, details = transcribe_audio(
#     audio_path="/content/sample.wav",  # path to any local audio file
#     language="am",
# )
# print(text)     # transcript string
# print(details)  # {"model": ..., "chunks": [...]} as returned by transcribe_audio
