# Template confidence scores with Whisper

In [1]:
import confidence_utils

  from .autonotebook import tqdm as notebook_tqdm


## 1. Load data

In [2]:
from datasets import load_dataset

In [3]:
# Fetch the Chinese (`cmn_hans_cn`) and English (`en_us`) configurations of FLEURS.
fleurs_ch = load_dataset(path="google/fleurs", name="cmn_hans_cn")
fleurs_en = load_dataset(path="google/fleurs", name="en_us")

Found cached dataset fleurs (/Users/antonin/.cache/huggingface/datasets/google___fleurs/cmn_hans_cn/2.0.0/aabb39fb29739c495517ac904e2886819b6e344702f0a5b5283cb178b087c94a)
100%|██████████| 3/3 [00:00<00:00, 63.45it/s]
Found cached dataset fleurs (/Users/antonin/.cache/huggingface/datasets/google___fleurs/en_us/2.0.0/aabb39fb29739c495517ac904e2886819b6e344702f0a5b5283cb178b087c94a)
100%|██████████| 3/3 [00:00<00:00, 109.85it/s]


In [4]:
# Columns dropped from both datasets. Keeping the list in one place guarantees
# the Chinese and English datasets end up with the same schema.
UNUSED_COLUMNS = ['id', 'num_samples', 'path', 'gender', 'lang_id', 'language', 'lang_group_id']

# NOTE(review): this rebinds `fleurs_ch` / `fleurs_en`, so re-running the cell on
# the already-trimmed datasets presumably fails because the columns are gone —
# re-run the loading cell first.
fleurs_ch = fleurs_ch.remove_columns(UNUSED_COLUMNS)
fleurs_en = fleurs_en.remove_columns(UNUSED_COLUMNS)

In [5]:
# Number of leading 'train' examples kept per language; small so the
# inference cells below finish quickly. Change it here to resize both subsets.
N_SAMPLES = 50

fleurs_ch_tiny = fleurs_ch['train'].select(range(N_SAMPLES))
fleurs_en_tiny = fleurs_en['train'].select(range(N_SAMPLES))

## 2. Load models

In [6]:
# Checkpoint shared by both languages; defined once so the two loads cannot drift apart.
WHISPER_CHECKPOINT = 'openai/whisper-base'

# The helper returns a (processor, model) pair. The second argument is
# presumably the language the pair is configured for — confirm in confidence_utils.
processor_ch, model_ch = confidence_utils.load_whisper_with_confidence_scores(WHISPER_CHECKPOINT, 'Chinese')
processor_en, model_en = confidence_utils.load_whisper_with_confidence_scores(WHISPER_CHECKPOINT, 'English')

## 3. Run inference test

In [7]:
# Pin the decoder prompt to English transcription before running the model.
model_en.config.forced_decoder_ids = processor_en.get_decoder_prompt_ids(language="en", task="transcribe")

# Process one example per batch and drop the 'audio' column from the result.
result_en = fleurs_en_tiny.map(
    function=confidence_utils.map_to_pred_and_confidence_scores,
    batched=True,
    batch_size=1,
    remove_columns=['audio'],
    fn_kwargs={"processor": processor_en, "model": model_en, "lang": "en"},
)

100%|██████████| 50/50 [00:45<00:00,  1.11ba/s]


In [8]:
# Pin the decoder prompt to Chinese transcription before running the model.
model_ch.config.forced_decoder_ids = processor_ch.get_decoder_prompt_ids(language="zh", task="transcribe")

# Process one example per batch and drop the 'audio' column from the result.
result_ch = fleurs_ch_tiny.map(
    function=confidence_utils.map_to_pred_and_confidence_scores,
    batched=True,
    batch_size=1,
    remove_columns=['audio'],
    fn_kwargs={"processor": processor_ch, "model": model_ch, "lang": "zh"},
)

100%|██████████| 50/50 [00:54<00:00,  1.08s/ba]


## 4. Check results

In [15]:
# Inspect the first English row: reference transcriptions, the decoded prediction,
# its tokens ('tokens_pred') and the per-token probabilities ('probs_tokens_pred').
result_en[0]

{'transcription': 'although most agencies are willing to take on most regular bookings many agents specialise in particular types of travel budget ranges or destinations',
 'raw_transcription': 'Although most agencies are willing to take on most regular bookings, many agents specialise in particular types of travel, budget ranges or destinations.',
 'string_pred': ' Although most agencies are willing to take on most regular bookings, many agents specialize in particular types of travel, budget ranges or destinations.',
 'tokens_pred': ['ĠAlthough',
  'Ġmost',
  'Ġagencies',
  'Ġare',
  'Ġwilling',
  'Ġto',
  'Ġtake',
  'Ġon',
  'Ġmost',
  'Ġregular',
  'Ġbook',
  'ings',
  ',',
  'Ġmany',
  'Ġagents',
  'Ġspecialize',
  'Ġin',
  'Ġparticular',
  'Ġtypes',
  'Ġof',
  'Ġtravel',
  ',',
  'Ġbudget',
  'Ġranges',
  'Ġor',
  'Ġdestinations',
  '.'],
 'probs_tokens_pred': [0.9250670671463013,
  0.9942331910133362,
  0.9957910776138306,
  0.9980700612068176,
  0.9987837672233582,
  0.99620419

## 5. Display results

In [16]:
from IPython.display import HTML as html_print

In [33]:
def print_tokens_with_confidence(prediction_dataset):
    """
    Build an HTML string that shows every predicted token colored by confidence.

    Within each row the token probabilities are min-max normalized and mapped
    onto a red-to-green gradient: the least confident token of the row is red,
    the most confident one is green. Colors are therefore relative to the row,
    not absolute probabilities. When every token of a row has the same
    probability (including single-token rows) all tokens are drawn green.

    Args:
        prediction_dataset: object exposing ``num_rows`` and integer indexing,
            where each row is a mapping with the keys ``'tokens_pred'``
            (sequence of tokens), ``'probs_tokens_pred'`` (one probability per
            token) and ``'raw_transcription'`` (str).

    Returns:
        str: HTML markup containing, for each row, a colored "prediction" line
        followed by a "ground truth" line. Empty string if there are no rows.
    """
    # Imported locally so this cell does not depend on another cell's imports.
    from html import escape

    def cstr(s, color='black'):
        # Escape the token so characters such as '<' or '&' are displayed
        # literally instead of being parsed as markup.
        return "<text style=color:{}>{}</text>".format(color, escape(str(s)))

    def map_float_rgb(f, m, M):
        if M == m:
            # Degenerate range: normalizing would divide by zero. Every token
            # is simultaneously the row maximum, so use the "max" color.
            return 'rgb(0,255,0)'
        rgb = 'rgb({},{},0)'.format(int(255 * (1 - ((f - m) / (M - m)))), int(255 * (f - m) / (M - m)))
        return rgb

    pieces = []

    for row_index in range(prediction_dataset.num_rows):
        # Fetch the row once; each indexing of a dataset may decode the row again.
        row = prediction_dataset[row_index]
        tokens = row['tokens_pred']
        probs_tokens = row['probs_tokens_pred']

        # An empty prediction has no min/max; the bounds are unused in that case.
        if len(probs_tokens) > 0:
            min_prob = min(probs_tokens)
            max_prob = max(probs_tokens)
        else:
            min_prob = max_prob = 0.0

        colored_tokens = "".join(
            cstr(s=token, color=map_float_rgb(probs_tokens[idx], min_prob, max_prob))
            for idx, token in enumerate(tokens)
        )
        pieces.append("prediction &nbsp  &nbsp :  " + colored_tokens + "<br>")
        pieces.append("ground truth : " + escape(row['raw_transcription']) + "<br><br>")

    # Single join instead of repeated string concatenation in the loop.
    return "".join(pieces)

In [34]:
# Render the English predictions; token colors are relative within each row.
html_print(print_tokens_with_confidence(result_en))

In [35]:
# In the English tokens the character 'Ġ' stands for a space, so replace it
# with a real space in the generated HTML to make the prediction readable.
html_print(print_tokens_with_confidence(result_en).replace('Ġ', ' '))

In [36]:
# Render the Chinese predictions the same way (tokens are shown as-is, without any replacement).
html_print(print_tokens_with_confidence(result_ch))