In [1]:
import json
from pathlib import Path
from collections import Counter, defaultdict
from itertools import chain
from tqdm import tqdm

import pandas as pd
import numpy as np
from spellchecker import SpellChecker

from hw_asr.base.base_text_encoder import BaseTextEncoder

# Load index

In [2]:
# Directory expected to hold one `<dataset>_index.json` file per dataset part.
index_directory = Path('saved_server/index/')
# The loading cell enumerates this path, so a plain file must be rejected as well;
# the message shows the resolved location to make a wrong working directory obvious.
assert index_directory.is_dir(), f'Index directory not found: {index_directory.resolve()}'

In [3]:
# Map dataset name -> list of observations parsed from `<name>_index.json`.
observations_by_dataset = {}
# Only `*.json` entries are read, so stray files (e.g. `.DS_Store`) no longer crash `json.load`;
# sorting makes the dataset order (and therefore every table below) reproducible across platforms.
for path in sorted(index_directory.glob('*.json')):
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        observations_by_dataset[path.name.removesuffix('_index.json')] = json.load(f)

# Compute statistics

In [4]:
# Percentile levels (in percent) reported next to the maxima.
q = [99, 95]
# The same levels as fractions, the form `np.quantile` expects.
quantile_levels = [q_i / 100 for q_i in q]
df = pd.DataFrame(columns=(
    ['n_samples', 'max_audio_len'] +
    [f'audio_len_{q_i}%' for q_i in q] +
    ['max_text_len'] +
    [f'max_text_len_{q_i}%' for q_i in q] +
    ['n_hours']
))

# One row per dataset; the value order must match the column order above.
for name, ind in observations_by_dataset.items():
    audio_len = [x['audio_len'] for x in ind]
    text_len = [len(x['text']) for x in ind]  # text length in characters

    df.loc[name, :] = (
        [len(ind), max(audio_len)] +
        # All requested quantiles are computed in a single call.
        list(np.quantile(audio_len, q=quantile_levels)) +
        [max(text_len)] +
        list(np.quantile(text_len, q=quantile_levels)) +
        # Dividing by 3600 treats `audio_len` as seconds.
        [sum(audio_len) / 60 / 60]
    )
df

Unnamed: 0,n_samples,max_audio_len,audio_len_99%,audio_len_95%,max_text_len,max_text_len_99%,max_text_len_95%,n_hours
dev-clean,2703,32.645,23.755,16.4135,516,366.96,256.9,5.387811
dev-other,2864,35.155,22.17955,15.02925,427,307.48,219.0,5.121185
test-clean,2620,34.955,25.47575,17.842,576,363.05,261.0,5.403467
test-other,2939,34.51,21.4248,15.761,618,320.24,226.0,5.341547
train-clean-100,28539,24.525,16.7031,16.085,398,289.0,262.0,100.59088
train-clean-360,104014,29.735,16.67,16.075,524,289.0,264.0,363.605608
train-other-500,148688,27.92,16.685,16.06,453,285.0,258.0,496.85791


# Calculate word counts

In [5]:
def get_words_from_index(ind: list[dict]):
    """Count word occurrences over the ``'text'`` field of every observation in ``ind``.

    Each text is passed through ``BaseTextEncoder.normalize_text`` and then split on
    whitespace; the resulting words are tallied across all observations.

    Returns:
        A ``Counter`` mapping word -> number of occurrences.
    """
    return Counter(chain.from_iterable(
        BaseTextEncoder.normalize_text(entry['text']).split() for entry in ind
    ))


# Per-dataset word counts, keyed by dataset name.
words_by_dataset = {
    dataset_name: get_words_from_index(dataset_index)
    for dataset_name, dataset_index in observations_by_dataset.items()
}

In [6]:
def _names_containing(marker: str) -> list[str]:
    """Return the dataset names that contain ``marker`` as a substring, in index order."""
    return [name for name in observations_by_dataset if marker in name]


train_names = _names_containing('train')
val_names = _names_containing('dev')
test_names = _names_containing('test')
# Every dataset must land in exactly one of the three splits.
assert sorted(train_names + val_names + test_names) == sorted(observations_by_dataset.keys())

In [7]:
def _merge_word_counts(dataset_names):
    """Add up the per-dataset word counters of ``dataset_names`` into a single ``Counter``."""
    merged = Counter()
    for dataset_name in dataset_names:
        merged = merged + words_by_dataset[dataset_name]
    return merged


words_in_train = _merge_word_counts(train_names)
words_in_val = _merge_word_counts(val_names)
words_in_test = _merge_word_counts(test_names)

In [8]:
# Empty summary table with one column per split; rows are added afterwards.
split_columns = ['train', 'val', 'test']
df = pd.DataFrame(columns=split_columns)

In [9]:
# Same order as the table columns: train, val, test.
counts_per_split = (words_in_train, words_in_val, words_in_test)
# Distinct words per split.
df.loc['unique_words', :] = tuple(len(counts) for counts in counts_per_split)
# Total word occurrences per split.
df.loc['total_words', :] = tuple(sum(counts.values()) for counts in counts_per_split)
df

Unnamed: 0,train,val,test
unique_words,86599,11739,11836
total_words,9403555,105350,104919


In [10]:
# Words that occur in the test split but in neither train nor val.
unique_test_words = {
    word for word in words_in_test
    if word not in words_in_train and word not in words_in_val
}
print(f'Number of unique words in test: {len(unique_test_words)}')
# Total occurrences of those test-only words.
unique_test_words_count = sum(words_in_test[word] for word in unique_test_words)
# Take the denominator straight from the counter instead of the global `df` table:
# `df` is also the name of the statistics table, so re-running cells out of order
# would otherwise make this lookup fail or silently read the wrong frame.
total_test_words = sum(words_in_test.values())
print(f'Unique words in test count: {unique_test_words_count} ({unique_test_words_count / total_test_words * 100:.4f}%)')

Number of unique words in test: 517
Unique words in test count: 711 (0.6777%)


# Construct SpellChecker

In [11]:
def build_dictionary(words_in_train: Counter, words_in_val: Counter):
    """Create a ``SpellChecker`` whose vocabulary consists only of the train and val words.

    The words a freshly constructed ``SpellChecker()`` comes with are removed first.
    Every word is then loaded once per occurrence in each of the two counters.

    Args:
        words_in_train: word -> occurrence count for the training data.
        words_in_val: word -> occurrence count for the validation data.

    Returns:
        The configured ``SpellChecker``.

    Raises:
        AssertionError: if the number of unique words held by the checker differs from
            the size of the union of both vocabularies.
    """
    spell = SpellChecker()
    frequencies = spell.word_frequency
    # Drop everything the default checker was initialised with.
    frequencies.remove_words(frequencies.dictionary)
    # Repeat each word `count` times so the checker sees the corpus frequencies.
    occurrences = [
        word
        for word, count in chain(words_in_train.items(), words_in_val.items())
        for _ in range(count)
    ]
    frequencies.load_words(occurrences)
    expected_vocabulary = set(words_in_train) | set(words_in_val)
    assert frequencies.unique_words == len(expected_vocabulary)
    return spell


def correct(text: str, spell: SpellChecker):
    """Spell-correct ``text`` word by word.

    ``text`` is split on whitespace. Each word is replaced by ``spell.correction(word)``,
    or kept unchanged when that call returns a falsy value (e.g. ``None``).

    Returns:
        The resulting words joined by single spaces.
    """
    return ' '.join(spell.correction(word) or word for word in text.split())


# Spell checker restricted to the train + val vocabulary.
spell = build_dictionary(words_in_train=words_in_train, words_in_val=words_in_val)

# Autocorrect

In [12]:
# Directory expected to hold one `<part>_output.json` file per evaluated dataset part.
predictions_directory = Path('saved_server/output/')
# Must be a directory (it is enumerated below); show the resolved path on failure.
assert predictions_directory.is_dir(), f'Predictions directory not found: {predictions_directory.resolve()}'
# Map dataset part -> parsed content of `<part>_output.json`.
predictions = {}
# Only `*.json` entries are read, so stray files no longer crash `json.load`;
# sorting keeps the part order reproducible across platforms.
for path in sorted(predictions_directory.glob('*.json')):
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        predictions[path.name.removesuffix('_output.json')] = json.load(f)
predictions.keys()

dict_keys(['test-clean', 'test-other'])

In [13]:
from hw_asr.metric.utils import calc_cer, calc_wer
from textblob import TextBlob


def autocorrect_sentence(sentence):
    """Return ``sentence`` after ``TextBlob(sentence).correct()``, converted to a plain ``str``."""
    return str(TextBlob(sentence).correct())


def add_predictions(observation: dict[str, str]):
    """Return a copy of ``observation`` extended with two spell-corrected predictions.

    Both new entries are derived from ``observation['pred_text_beam_search']``:

    * ``'pred_text_corrected_my'``: output of ``correct`` with the module-level ``spell``.
    * ``'pred_text_corrected_library'``: output of ``autocorrect_sentence``.

    The input mapping itself is not modified.
    """
    beam_search_text = observation['pred_text_beam_search']
    enriched = observation.copy()
    enriched.update(
        pred_text_corrected_my=correct(beam_search_text, spell),
        pred_text_corrected_library=autocorrect_sentence(beam_search_text),
    )
    return enriched


def compute_metrics(predictions: dict[str, list[dict[str, str]]]) -> dict[str, dict[str, dict[str, float]]]:
    """Compute the mean WER and CER (scaled by 100) of every prediction variant per dataset part.

    Args:
        predictions: maps a dataset part to its observations. Each observation must contain
            ``'ground_truth'``; after ``add_predictions`` has added the two spell-corrected
            variants, every other key is treated as a prediction to score.

    Returns:
        ``{part: {'WER' | 'CER': {prediction_name: mean score}}}``.
    """
    metric_functions = {'WER': calc_wer, 'CER': calc_cer}
    metrics_by_part = {}
    for part, observations in predictions.items():
        # metric name -> prediction name -> list of per-observation scores
        metrics = {metric: defaultdict(list) for metric in metric_functions}
        for observation in tqdm(observations):
            observation = add_predictions(observation)
            ground_truth = observation['ground_truth']
            for name, prediction in observation.items():
                if name == 'ground_truth':
                    continue
                for metric, func in metric_functions.items():
                    # Score each (metric, prediction) pair exactly once. The original appended
                    # the same value twice, which doubled the work without changing the mean.
                    metrics[metric][name].append(func(ground_truth, prediction) * 100)
        metrics_by_part[part] = {
            metric: {
                name: np.array(values).mean() for name, values in metric_by_predictions.items()
            } for metric, metric_by_predictions in metrics.items()
        }
    return metrics_by_part


def print_metrics(metrics: dict[str, dict[str, dict[str, float]]]):
    """Print one table per dataset part: rows are prediction names, columns are metric names.

    Each table is preceded by a ``==========`` line and the part name, and followed by
    another ``==========`` line.

    Args:
        metrics: ``{part: {metric_name: {prediction_name: value}}}``.
    """
    separator = '=' * 10
    for part, metrics_by_part in metrics.items():
        print(separator, part, sep='\n')
        table = pd.DataFrame(columns=metrics_by_part.keys())
        # Fill the table cell by cell; missing rows are created on first assignment.
        for metric_name in metrics_by_part:
            for predictions_name, metric_value in metrics_by_part[metric_name].items():
                table.loc[predictions_name, metric_name] = metric_value
        print(table)
        print(separator)


# Score every loaded prediction file. This is the slow cell: the progress bars of the
# recorded run show roughly 7.5 and 13.5 minutes for the two test parts.
metrics = compute_metrics(predictions)
# Print one WER/CER table per dataset part.
print_metrics(metrics)

100%|██████████| 2620/2620 [07:37<00:00,  5.73it/s]
100%|██████████| 2939/2939 [13:34<00:00,  3.61it/s]


test-clean
                                   WER       CER
pred_text_argmax             14.286005  4.347183
pred_text_beam_search        13.994843   4.25308
pred_text_corrected_my       12.204572   4.31496
pred_text_corrected_library  13.407942  4.859289
test-other
                                   WER        CER
pred_text_argmax             31.420353  11.985951
pred_text_beam_search        30.795859  11.682252
pred_text_corrected_my       27.821106  11.995854
pred_text_corrected_library  28.315971  12.566642
