In [42]:
from datasets import load_from_disk

# Load the saved dataset and return the "audio" column as NumPy arrays;
# output_all_columns=True keeps the remaining columns in every returned row.
dataset = load_from_disk("../datasets/dataset_oai_coral").with_format(
    "numpy",
    columns=["audio"],
    output_all_columns=True,
)
dataset[0]


{'audio': array([ 0.00140381,  0.00143433,  0.00137329, ...,  0.00180054,
         0.00125122, -0.00219727], shape=(55200,), dtype=float32),
 'time_string': '12:00 AM',
 'hour': 0,
 'minute': 0,
 'period': 'AM',
 'text': "The time is twelve o'clock AM.",
 'transcribed_text': " The time is 12 o'clock AM.\n"}

In [30]:
from IPython.display import Audio

# Play back the sample at index 2.
# NOTE(review): rate=24000 presumably matches the sample rate the clips were
# generated at -- confirm against the generation pipeline.
Audio(data=dataset[2]["audio"], rate=24000)


In [43]:
# NOTE: this is a hack; it was fixed for subsequent runs.
# Adds a "ground_truth" column built from each row's "time_string".
dataset = dataset.map(
    lambda r: {
        "ground_truth": f"The time is {r['time_string']}.",
    }
)
dataset[0]

Map: 100%|██████████| 1440/1440 [00:00<00:00, 1937.21 examples/s]


{'audio': array([ 0.00140381,  0.00143433,  0.00137329, ...,  0.00180054,
         0.00125122, -0.00219727], shape=(55200,), dtype=float32),
 'ground_truth': np.str_('The time is 12:00 AM.'),
 'time_string': '12:00 AM',
 'hour': 0,
 'minute': 0,
 'period': 'AM',
 'text': "The time is twelve o'clock AM.",
 'transcribed_text': " The time is 12 o'clock AM.\n"}

In [44]:
import pandas as pd
from dateparser import parse

def grade_row(row):
    """Grade one dataset row by parsing a clock time out of its ``text`` field.

    The text must contain the prefix ``"The time is "``. The segment that
    follows the first occurrence of that prefix is handed to
    ``dateparser.parse`` and the result is compared with the row's reference
    ``hour`` and ``minute``.

    Args:
        row: Mapping with a ``"text"`` string and numeric ``"hour"`` and
            ``"minute"`` entries.

    Returns:
        A dict with two keys:
          * ``"is_correct"`` -- True only when the parsed time matches.
          * ``"error"`` -- None on success, otherwise ``"failed_prefix"``
            (prefix missing), ``"no_time"`` (the parser returned None) or
            ``"wrong_time"`` (parsed time differs from the reference).
    """
    prefix = "The time is "
    text = row["text"]

    # Guard on exactly the string we split on (trailing space included).
    # Checking for "The time is" without the space let inputs such as
    # "The time is" or "The time is5:00" through, and the split below then
    # raised IndexError instead of recording a "failed_prefix" result.
    if prefix not in text:
        return {"is_correct": False, "error": "failed_prefix"}

    date = parse(text.split(prefix)[1])
    if date is None:
        return {"is_correct": False, "error": "no_time"}

    # Hours are compared modulo 12, so AM/PM is not part of the check.
    # NOTE(review): confirm that ignoring the period is intentional.
    if date.hour % 12 == row["hour"] % 12 and date.minute == row["minute"]:
        return {"is_correct": True, "error": None}

    return {"is_correct": False, "error": "wrong_time"}


# Grade every row using 12 worker processes; the "is_correct" and "error"
# keys returned by grade_row become new dataset columns.
dataset = dataset.map(function=grade_row, num_proc=12)


Map (num_proc=12): 100%|██████████| 1440/1440 [00:01<00:00, 1365.07 examples/s]


In [45]:
# Tally correct vs. incorrect grades across the whole dataset.
(
    pd.Series(dataset["is_correct"])
    .value_counts()
)

True     1405
False      35
Name: count, dtype: int64

## Error analysis of known-good OAI output

In [47]:
# Keep only the rows the grader marked as incorrect, report how many there
# are, and preview the first five.
ds_failed = dataset.filter(
    lambda row: not row["is_correct"],
)
print(f"Failed: {len(ds_failed)}")
ds_failed[:5]

Filter: 100%|██████████| 1440/1440 [00:00<00:00, 14177.93 examples/s]

Failed: 35





{'audio': array([array([ 0.00140381,  0.00143433,  0.00137329, ...,  0.00180054,
                0.00125122, -0.00219727], shape=(55200,), dtype=float32),
        array([-0.00256348, -0.00201416, -0.0022583 , ...,  0.00164795,
                0.00091553, -0.00234985], shape=(52800,), dtype=float32),
        array([ 0.00311279,  0.00335693,  0.00323486, ...,  0.00137329,
                0.00018311, -0.00234985], shape=(80400,), dtype=float32),
        array([-0.00250244,  0.01016235,  0.01260376, ...,  0.00204468,
                0.00119019, -0.00134277], shape=(55200,), dtype=float32),
        array([-0.00250244, -0.00219727, -0.00219727, ...,  0.0017395 ,
                0.00079346, -0.0022583 ], shape=(57600,), dtype=float32)],
       dtype=object),
 'ground_truth': array(['The time is 12:00 AM.', 'The time is 2:59 AM.',
        'The time is 3:00 AM.', 'The time is 3:23 AM.',
        'The time is 3:38 AM.'], dtype='<U21'),
 'is_correct': array([False, False, False, False, False]),
 '

In [51]:
# Break the failed rows down by the error label the grader assigned.
(
    pd.Series(ds_failed["error"])
    .value_counts()
)

failed_prefix    24
no_time          10
wrong_time        1
Name: count, dtype: int64

In [53]:
# Export the failed rows to CSV with only the transcript and grading columns,
# stripping surrounding whitespace from each transcript first.
(
    ds_failed
    .select_columns(["transcribed_text", "error", "hour", "minute"])
    .map(lambda x: {"transcribed_text": x["transcribed_text"].strip()})
    .to_csv("ds_failed_text_only.csv", index=False)
)

Map: 100%|██████████| 35/35 [00:00<00:00, 6195.69 examples/s]
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1163.47ba/s]


1447

In [49]:
# Inspect one "no_time" failure (index 9 within that subset): print its text
# and reference hour/minute, then play its audio.
no_time_rows = ds_failed.filter(lambda row: row["error"] == "no_time")
failed_gen_1 = no_time_rows[9]
print(failed_gen_1["text"])
print(failed_gen_1["hour"], failed_gen_1["minute"])

Audio(data=failed_gen_1["audio"], rate=24000)

Filter: 100%|██████████| 35/35 [00:00<00:00, 5243.82 examples/s]

The time is five o'clock PM.
17 0





In [60]:
# Inspect one "failed_prefix" failure (index 7 within that subset): print its
# text and reference hour/minute, then play its audio.
failed_prefix_rows = ds_failed.filter(lambda row: row["error"] == "failed_prefix")
failed_gen_1 = failed_prefix_rows[7]
print(failed_gen_1["text"])
print(failed_gen_1["hour"], failed_gen_1["minute"])

Audio(data=failed_gen_1["audio"], rate=24000)

What time is 5:28 AM?
5 28
