In [2]:
import torch
import transformers
from dataclasses import dataclass, field
from typing import Union, List, Dict, Any
import string
import os
import json
import datasets
import numpy as np
import yaml
import evaluate
import salt.dataset
import salt.metrics
import salt.constants
from salt.utils import DataCollatorCTCWithPadding as dcwp
import huggingface_hub
import peft
import pandas as pd
import tqdm.notebook as tqdm
import jiwer

In [104]:
# Checkpoint to evaluate. Swap in the commented line to use the multilingual model.
config = {'pretrained_model': 'jq/whisper-large-v3-kin'}
#config = {'pretrained_model': 'jq/whisper-large-v3-kin-nyn-lug-xog'}
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(
    config['pretrained_model'])

# The language string handed to the processor is obtained by decoding the SALT
# language token for 'kin'. Decode it with a standalone tokenizer: referring to
# `processor.tokenizer` here would fail on a fresh kernel, because `processor`
# does not exist until the statement below has finished.
tokenizer = transformers.WhisperTokenizer.from_pretrained(
    config['pretrained_model'])
kin_language = tokenizer.decode(salt.constants.SALT_LANGUAGE_TOKENS_WHISPER['kin'])

processor = transformers.WhisperProcessor.from_pretrained(
    config['pretrained_model'],
    language=kin_language,
    task="transcribe")
model = transformers.WhisperForConditionalGeneration.from_pretrained(
    config['pretrained_model'])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [105]:
# Move the model to the GPU, then put it in evaluation mode for inference.
model = model.to('cuda')
model = model.eval()

In [80]:
# Clear any forced decoder ids on both the model config and the generation
# config.
for _cfg in (model.config, model.generation_config):
    _cfg.forced_decoder_ids = None

In [147]:
# True  -> transcribe the 'test' split (no labels are collected downstream).
# False -> transcribe the first 300 examples of 'dev_test' so WER/CER can be computed.
predict_full_test_set = False

split_name = 'test' if predict_full_test_set else 'dev_test[:300]'

# Load the chosen split and have the audio column decoded at 16 kHz.
test_ds = datasets.load_dataset(
    'jq/kinyarwanda-speech-hackathon', split=split_name
).cast_column("audio", datasets.Audio(sampling_rate=16000))

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/75 [00:00<?, ?it/s]

In [148]:
# Transcribe every example in `test_ds`, collecting ids, predictions and (when
# evaluating on the labelled subset) the reference texts in parallel lists.
test_ids = []
test_transcriptions = []
test_labels = []

# The language string is the same for every example, so decode it once
# instead of on every iteration.
generation_language = processor.tokenizer.decode(
    salt.constants.SALT_LANGUAGE_TOKENS_WHISPER['kin'])

for example in tqdm.tqdm(test_ds):
    # Audio array -> model input features, moved to the GPU.
    input_features = processor(
        example["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    ).input_features.to('cuda')

    # generate() is called without prompt_ids, so the example's 'prompt' field
    # is not used and no prompt tensor is built.
    predicted_ids = model.generate(
        input_features,
        max_length=400,
        temperature=0.01,
        language=generation_language,
    )
    # One audio clip per call, so take the single decoded string.
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Reference texts are only collected when scoring on the labelled subset.
    if not predict_full_test_set:
        test_labels.append(example['text'])

    test_transcriptions.append(transcription)
    test_ids.append(example['id'])

  0%|          | 0/300 [00:00<?, ?it/s]

In [130]:
import string

def strip_punctuation(text):
    """Return `text` with every character found in `string.punctuation` removed.

    Note that this includes the apostrophe.
    """
    return ''.join(ch for ch in text if ch not in string.punctuation)
    
def normalise(texts):
    """Lower-case each string in `texts` and strip its punctuation.

    Returns a new list with one normalised string per input string.
    """
    cleaned = []
    for raw in texts:
        cleaned.append(strip_punctuation(raw.lower()))
    return cleaned
    
# Score the predictions against the references (labelled subset only).
if not predict_full_test_set:
    # Normalise each side once and reuse the result for both metrics.
    references = normalise(test_labels)
    hypotheses = normalise(test_transcriptions)

    total_wer = jiwer.wer(references, hypotheses)
    total_cer = jiwer.cer(references, hypotheses)

    # Combined score: 1 minus a 0.6/0.4 weighted mix of CER and WER.
    weighted_error = 0.6 * total_cer + 0.4 * total_wer
    score = 1 - weighted_error

    report_lines = (
        f"Word Error Rate (WER): {total_wer:.3f}",
        f"Character Error Rate (CER): {total_cer:.3f}",
        f"Score: {score:.3f}",
    )
    for report_line in report_lines:
        print(report_line)

Word Error Rate (WER): 0.100
Character Error Rate (CER): 0.027
Score: 0.944


In [137]:
# Load the test-set metadata; its top-level keys are used as the submission ids.
# Read explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open('test.json', encoding='utf-8') as f:
    test_metadata = json.load(f)

test_keys = test_metadata.keys()

In [135]:
# Map each example id to its transcription (a later duplicate id overwrites an
# earlier one).
predictions = dict(zip(test_ids, test_transcriptions))

In [144]:
import string

def strip_punctuation(text):
    """Remove ASCII punctuation from `text`, keeping apostrophes.

    The removed set is `string.punctuation` without the apostrophe, so
    contractions such as "n'imitaka" survive. It is derived from the standard
    library constant rather than a hand-typed literal, which avoids invalid
    escape sequences in the source.

    Args:
        text: The string to clean.

    Returns:
        A copy of `text` with the punctuation characters deleted.
    """
    punctuation = string.punctuation.replace("'", "")
    return text.translate(str.maketrans('', '', punctuation))
    
# Write one "id,transcription" row per test key, in the order of `test_keys`.
with open('submission.csv', "w", encoding="utf-8") as f:
    f.write('id,transcription\n')
    for k in test_keys:
        pred = predictions.get(k)
        # Normalise before testing for emptiness: a prediction made up only of
        # punctuation or whitespace would otherwise yield an empty field.
        normalised_pred = strip_punctuation(pred.lower()) if pred else ''
        if not normalised_pred.strip():
            print('No prediction for key ', k)
            # Placeholder so the row still has a non-empty transcription.
            f.write(f"{k},a\n")
        else:
            f.write(f"{k},{normalised_pred}\n")

In [145]:
# Sanity check: count the lines in the written submission file.
!wc -l submission.csv

9266 submission.csv


In [146]:
# Show the first lines of the submission file to inspect the header and row format.
!head submission.csv

id,transcription
4ibA9OLWZTajRbwnWjjY,ndabona umugabo uri kuri uhagaze wambaye kasike na jire akaba ashobora kuba atwara ibintu handitseho ngo kashi
ZarC9zz753YnLnE98mpK,pisine ku ruhande rwayo hari udutebe dutwikiriye n'imitaka tubiri ku rundi ruhande naho hakaba hari akandi gatebe konyine hirya hakaba udutebe tundi hari n'imitaka itwikiriye hirya yaho hakaba hari inzu iri kubakwa itari yuzura
1ai3w0iU2yUOeUtLoTSX,ubwishingizi ni ingenzi cyane kubera ko budufasha kandi ntaho batageze bahageze amashami yabo hano ni muri rusizi nk'iki kirango nk'uko kibigaragaza mu ibara ry'ubururu amagambo yandikishije umweru ndetse n'umuhondo ubona ko rero ushobora kuza nawe ugatanga ikibazo cyawe bakakugoboka
IQFsYcsFTsGlnqftc8jg,imodoka ihagaze iri mu ibara ritukura iriho ibirango byamamaza isoko rikorera kuri murandasi hariho nimero zabo za telefone ngendanwa ndetse n'ahandi hose ushobora kubabona
Sd3umUI1wjqp5z5poHe6,ahantu bacururiza amata hari ameza ya purasitike imwe iteretseho ishage n'umufuni