In [1]:
!pip install --upgrade pip
!pip install --upgrade datasets transformers accelerate soundfile librosa evaluate jiwer tensorboard gradio



In [2]:
from huggingface_hub import interpreter_login
from datasets import load_dataset, Audio
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from tqdm import tqdm

In [3]:
# Interactive Hugging Face Hub login; the saved token is what load_dataset(..., token=True) uses later.
interpreter_login()


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token is valid (permission: write).
Your token has been saved to /Users/amyguan/.cache/huggingface/token
Login success

In [17]:
# dataset_total = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="train", token=True, trust_remote_code=True, streaming=True)
# dataset_total.features

In [86]:
# dataset_total = dataset_total.shuffle(seed=42, buffer_size=10_000)
# dataset_total = dataset_total.take(60_000) # approx half of training
# dataset_total = dataset_total.filter(lambda example: example['accent'] != '') # doesn't work
# dataset_total = dataset_total.cast_column("audio", Audio(sampling_rate=16_000))

# Preprocess and Init Whisper
### Note: before a full run, switch the pipeline to GPU (its `device` argument) and make sure the `train` split is the one being loaded.

In [4]:
# Single source of truth for the checkpoint so the processor (tokenizer +
# feature extractor) and the model weights can never come from different sizes.
WHISPER_CHECKPOINT = "openai/whisper-small"

processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT, language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_CHECKPOINT)

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Word- and character-error-rate scorers from `evaluate`, reused for every accent.
wer_metric, cer_metric = (evaluate.load(metric_name) for metric_name in ("wer", "cer"))

In [6]:
# modified from a github repo: https://github.com/vasistalodagala/whisper-finetune/tree/master

def is_target_text_in_range(ref):
    """Return True if `ref` is a usable reference transcript.

    Surrounding whitespace is ignored. Blank strings and the literal marker
    "ignore time segment in scoring" are rejected; everything else is kept.
    """
    stripped = ref.strip()
    return stripped != "" and stripped != "ignore time segment in scoring"
    
def get_text(sample):
    """Return the transcript text of one dataset sample.

    The known transcript column names are checked in priority order and the
    value of the first one present is returned (for Common Voice this is
    "sentence").

    Args:
        sample: Mapping for a single dataset row.

    Returns:
        The transcript value stored under the first matching column.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists the columns the sample actually has.
    """
    for column in ("text", "sentence", "normalized_text", "transcript", "transcription"):
        if column in sample:
            return sample[column]
    # Previously the second half of this message was a plain (non-f) string, so
    # the column list was printed literally as ".join{sample.keys()}".
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
        f"'transcript' or 'transcription'. Got sample with columns: {', '.join(map(str, sample.keys()))}. "
        "Ensure a text column name is present in the dataset."
    )
    
def get_accents(sample):
    """Return the accent labels of one dataset sample as a list of strings.

    The "accent" field is split on commas. Each piece is stripped of
    surrounding whitespace and empty pieces are dropped. As a result,
    "A, B" yields ["A", "B"], so "B" lands in the same bucket as a clip tagged
    just "B". An empty field yields [].

    NOTE(review): a single label that itself contains commas (the original
    comment mentions an India-related outlier) is also split into fragments;
    the metrics cell skips some of those fragments by name.

    Args:
        sample: Mapping for a single dataset row.

    Returns:
        List of non-empty, whitespace-stripped accent labels.

    Raises:
        ValueError: If the sample has no "accent" column.
    """
    if "accent" not in sample:
        raise ValueError(
            "Expected an 'accent' column. Ensure an accent column is present in the dataset."
        )
    return [part.strip() for part in sample["accent"].split(",") if part.strip()]

# Shared Whisper text normaliser: applied to references here (norm_text) and to
# model hypotheses in the decode loop, so normalised WER/CER compare like with like.
whisper_norm = BasicTextNormalizer()


def normalise(batch):
    """Store the Whisper-normalised transcript of `batch` under "norm_text".

    The transcript is located with get_text; `batch` is updated in place and
    the same object is returned.
    """
    normalised_text = whisper_norm(get_text(batch))
    batch["norm_text"] = normalised_text
    return batch

def data(dataset):
    """Lazily yield one ASR-pipeline input dict per sample of `dataset`.

    Each dict is a copy of the sample's "audio" mapping extended with three
    extra keys that the decode loop later reads back from the pipeline output:
    "reference" (raw transcript), "norm_reference" (the sample's "norm_text")
    and "accents" (list from get_accents).
    """
    for sample in dataset:
        pipeline_input = dict(sample["audio"])
        pipeline_input["reference"] = get_text(sample)
        pipeline_input["norm_reference"] = sample["norm_text"]
        pipeline_input["accents"] = get_accents(sample)
        yield pipeline_input

In [7]:
import torch

# transformers pipelines take device=0 for the first CUDA GPU and device=-1 for
# CPU; use the GPU automatically when one exists instead of hard-coding CPU.
asr_device = 0 if torch.cuda.is_available() else -1

whisper_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=asr_device,
)

# Force English transcription. The prompt is computed once and set on both the
# model config and the generation config (the original set both as well).
forced_ids = processor.tokenizer.get_decoder_prompt_ids(language="english", task="transcribe")
whisper_asr.model.config.forced_decoder_ids = forced_ids
whisper_asr.model.generation_config.forced_decoder_ids = forced_ids

In [8]:
# Stream (no full download) the Common Voice 16.1 English train split and build a
# shuffled evaluation subset. All steps are lazy on the IterableDataset.
CV_DATASET = "mozilla-foundation/common_voice_16_1"
SHUFFLE_SEED = 42
SHUFFLE_BUFFER = 10_000
N_SAMPLES = 60_000  # approx half of training; filtering below can reduce this further
text_column_name = "sentence"

dataset_total = load_dataset(CV_DATASET, "en", split="train", token=True, trust_remote_code=True, streaming=True)
dataset_total = dataset_total.shuffle(seed=SHUFFLE_SEED, buffer_size=SHUFFLE_BUFFER)
dataset_total = dataset_total.take(N_SAMPLES)
dataset_total = dataset_total.cast_column("audio", Audio(sampling_rate=16000))
# Drop unusable rows first so the per-row normalisation is not wasted on them.
dataset_total = dataset_total.filter(is_target_text_in_range, input_columns=[text_column_name])
dataset_total = dataset_total.filter(lambda accent: bool(accent and accent.strip()), input_columns=["accent"])
dataset_total = dataset_total.map(normalise)

In [9]:
# Streaming IterableDataset: features are reported as Unknown here (see output), so inspect rows by iterating.
dataset_total

IterableDataset({
    features: Unknown,
    n_shards: 28
})

In [10]:
# Decode the streamed subset and bucket each (reference, hypothesis) pair under
# every accent label of the clip, so a clip with several accents counts once per
# accent. Interrupting the loop (e.g. KeyboardInterrupt) keeps whatever has been
# collected so far, so the metrics cell can still run on a partial decode.
ASR_BATCH_SIZE = 16

predictions = {}       # accent -> raw model transcriptions
references = {}        # accent -> raw reference sentences
norm_predictions = {}  # accent -> whisper_norm(model transcription)
norm_references = {}   # accent -> normalised reference sentences ("norm_text")
all_accents = []       # accent labels in first-seen order

for out in tqdm(whisper_asr(data(dataset_total), batch_size=ASR_BATCH_SIZE), desc='Decode Progress'):
    # The extra input columns are read with [0]: presumably the pipeline returns
    # them wrapped in a one-element list (one entry per audio chunk) — TODO confirm.
    hyp = out["text"]
    ref = out["reference"][0]
    norm_ref = out["norm_reference"][0]
    norm_hyp = whisper_norm(hyp)  # normalise once per clip, not once per accent
    for accent in out["accents"][0]:
        if accent not in predictions:
            all_accents.append(accent)
        predictions.setdefault(accent, []).append(hyp)
        references.setdefault(accent, []).append(ref)
        norm_predictions.setdefault(accent, []).append(norm_hyp)
        norm_references.setdefault(accent, []).append(norm_ref)

Reading metadata...: 1090061it [00:58, 18764.45it/s]




Decode Progress: 240it [10:48,  2.70s/it]


KeyboardInterrupt: 

In [48]:
# Per-accent WER/CER (raw and whisper-normalised) plus one REF/HYP report file per accent.
import os

# Presumably fragments of a comma-containing accent label split by get_accents.
SKIPPED_ACCENT_NAMES = {"pakistan", "sri_lanka"}
PREDICTIONS_DIR = "predictions"


def _percent(metric, refs, hyps):
    """Return `metric` over the given references/hypotheses as a percentage rounded to 2 dp."""
    return round(100 * metric.compute(references=refs, predictions=hyps), 2)


def _accent_file_name(accent):
    """Turn a raw accent label into a file-safe name: whisper-normalised, stripped, spaces -> '_'."""
    return whisper_norm(accent).strip().replace(' ', '_')


# Pool raw labels that map to the same file name (e.g. labels differing only by
# the whitespace left after splitting on ','). Without pooling, the later label
# overwrote the earlier one's metrics entry and prediction file. Skipped and
# empty names are dropped before any metric is computed.
accent_groups = {}
for accent in all_accents:
    acc_name = _accent_file_name(accent)
    if not acc_name or acc_name in SKIPPED_ACCENT_NAMES:
        continue
    group = accent_groups.setdefault(
        acc_name, {"labels": [], "refs": [], "hyps": [], "norm_refs": [], "norm_hyps": []}
    )
    group["labels"].append(accent)
    group["refs"].extend(references[accent])
    group["hyps"].extend(predictions[accent])
    group["norm_refs"].extend(norm_references[accent])
    group["norm_hyps"].extend(norm_predictions[accent])

os.makedirs(PREDICTIONS_DIR, exist_ok=True)
metrics = {}

for acc_name, group in accent_groups.items():
    wer = _percent(wer_metric, group["refs"], group["hyps"])
    cer = _percent(cer_metric, group["refs"], group["hyps"])
    norm_wer = _percent(wer_metric, group["norm_refs"], group["norm_hyps"])
    norm_cer = _percent(cer_metric, group["norm_refs"], group["norm_hyps"])
    metrics[acc_name] = {'wer': wer, 'cer': cer, 'norm_wer': norm_wer, 'norm_cer': norm_cer}

    # 'w' replaces any report from a previous run; utf-8 because transcripts may
    # contain non-ASCII text. `with` guarantees the file is closed on errors.
    op_file = os.path.join(PREDICTIONS_DIR, f"{acc_name}.txt")
    with open(op_file, 'w', encoding='utf-8') as result_file:
        result_file.write('ACCENT: ' + ' | '.join(group["labels"]) + '\n')
        result_file.write('\nWER: ' + str(wer) + '\n')
        result_file.write('CER: ' + str(cer) + '\n')
        result_file.write('\nNORMALIZED WER: ' + str(norm_wer) + '\n')
        result_file.write('NORMALIZED CER: ' + str(norm_cer) + '\n\n\n')
        for ref, hyp in zip(group["refs"], group["hyps"]):
            result_file.write('REF: ' + ref + '\n')
            result_file.write('HYP: ' + hyp + '\n')
            result_file.write("------------------------------------------------------" + '\n')