In [2]:
import pandas as pd
import torch as th
import numpy as np
from Levenshtein import distance

In [3]:
# Alternating Character Table (ACT): each `alternative_characters` cell holds a
# comma-separated group of characters that `look_up_ACT` treats as interchangeable.
# No `encoding=` argument: .xlsx is a binary (zipped) container and current
# pandas `read_excel` has no `encoding` parameter (passing it raises TypeError).
act_raw = pd.read_excel('wer-tools/act.xlsx')
act = (
    act_raw['alternative_characters']
    .dropna()     # empty cells are NaN floats and would crash str.split
    .astype(str)  # a purely numeric cell would otherwise be parsed as a number
    # Strip padding around each entry: text_to_df removes all spaces from
    # pred/gold, so a padded entry such as ' x' could never match anything.
    .apply(lambda cell: [ch.strip() for ch in cell.split(',') if ch.strip()])
)

In [4]:
def _read_normalized_lines(path):
    """Read a text file and return its lines with newlines and spaces removed."""
    # 'UTF-8-sig' drops a leading byte-order mark if the file has one.
    with open(path, 'r', encoding='UTF-8-sig') as f:
        return [line.replace('\n', '').replace(' ', '') for line in f]


def text_to_df(path_pred, path_gold):
    """Pair prediction and gold lines from two files into one DataFrame.

    Line i of ``path_pred`` is paired with line i of ``path_gold``. Newlines
    and every space character are removed from both sides.

    Args:
        path_pred: path of the file with one prediction per line.
        path_gold: path of the file with one reference per line.

    Returns:
        DataFrame with string columns ``'pred'`` and ``'gold'``, one row per line.

    Raises:
        ValueError: if the two files have a different number of lines.
    """
    lines_pred = _read_normalized_lines(path_pred)
    lines_gold = _read_normalized_lines(path_gold)
    if len(lines_pred) != len(lines_gold):
        raise ValueError(
            f'{path_pred} has {len(lines_pred)} lines but '
            f'{path_gold} has {len(lines_gold)}; files must be line-aligned'
        )
    return pd.DataFrame({'pred': lines_pred, 'gold': lines_gold})

In [5]:
def WER(df):
    """Word error rate of predictions, plain and after ACT normalisation.

    Each row of ``df`` is one (pred, gold) pair. A pair is correct when the
    two strings have Levenshtein distance 0. Pairs at distance 1 or 2 are also
    checked with ``look_up_ACT``; those it accepts count as correct for the
    ``'wer-act'`` score.

    Args:
        df: DataFrame with string columns ``'pred'`` and ``'gold'``
            (as produced by ``text_to_df``).

    Returns:
        dict with keys ``'wer'`` and ``'wer-act'``: the fraction of rows counted
        as wrong without / with ACT replacements.

    Raises:
        ValueError: if ``df`` has no rows (the rate is undefined).
    """
    n = len(df)
    if n == 0:
        raise ValueError('WER is undefined for an empty set of pred/gold pairs')

    # Levenshtein distance (MED) per pair. A comprehension over zip always yields
    # a flat sequence, unlike DataFrame.apply(axis=1), which returns a DataFrame
    # for empty input.
    dist = pd.Series(
        [distance(p, g) for p, g in zip(df['pred'], df['gold'])],
        index=df.index,
    )

    # zero MED means correct
    correct = int((dist == 0).sum())
    wer = 1 - correct / n

    # extract those with MED 1 or 2 and look up the ACT
    for_act = df[(dist > 0) & (dist <= 2)]
    print('The number of pred/gold pairs with MED of 1 or 2 is', len(for_act))
    # sum() over a generator is 0 when nothing qualifies; the previous
    # np.sum(for_act.apply(..., axis=1)) produced a Series in that case.
    correct_act = sum(
        bool(look_up_ACT(p, g)) for p, g in zip(for_act['pred'], for_act['gold'])
    )
    print('The number of replaceable names is', correct_act)
    wer_act = 1 - (correct + correct_act) / n

    return {
        'wer': wer,
        'wer-act': wer_act
    }


In [6]:
def look_up_ACT(pred, gold, table=None):
    """Return True if ``pred`` and ``gold`` differ only by ACT-interchangeable characters.

    The strings must have equal length. They are compared position by
    position, and at every position where they differ, both characters must
    appear together in at least one group of the table. The caller (``WER``)
    has already restricted pairs by edit distance.

    Args:
        pred: predicted string.
        gold: reference string.
        table: iterable of character groups (each a list of strings). Defaults
            to the module-level ``act`` built from the ACT spreadsheet; it is
            only read once a mismatching position is found.

    Returns:
        bool: True when every differing position is replaceable via the table.
    """
    # The primary assumption requires pred and gold to be of the same length.
    # NOTE(review): positions are compared without alignment, so a distance-2
    # pair made of one insertion plus one deletion is judged on shifted
    # characters rather than rejected outright.
    if len(pred) != len(gold):
        return False

    for p_ch, g_ch in zip(pred, gold):
        if p_ch == g_ch:
            continue
        groups = act if table is None else table
        # any() over a generator stops at the first group containing both chars.
        if not any(p_ch in group and g_ch in group for group in groups):
            return False

    # all the distinct characters are 'replaceable' in the ACT
    return True

In [7]:
# Baseline English->Chinese predictions, each paired line-by-line with its gold file.
ch_dev = text_to_df(path_pred='baseline_en2ch/result/bs_dev.txt', path_gold='data/ch_dev.txt')
ch_tst = text_to_df(path_pred='baseline_en2ch/result/bs_tst.txt', path_gold='data/ch_tst.txt')

In [8]:
# Plain and ACT-adjusted WER of the Chinese baseline on the dev split.
WER(df=ch_dev)

The number of pred/gold pairs with MED of 1 or 2 is 1561
The number of replaceable names is 128


{'wer': 0.28242964996568287, 'wer-act': 0.2604667124227865}

In [9]:
# Plain and ACT-adjusted WER of the Chinese baseline on the test split.
WER(df=ch_tst)

The number of pred/gold pairs with MED of 1 or 2 is 1577
The number of replaceable names is 122


{'wer': 0.2834591626630062, 'wer-act': 0.26252573781743305}

In [None]:
pinyin_dev_gold = text_to_df('data/pinyin_dev.txt')
pinyin_tst_gold = text_to_df('data/pinyin_tst.txt')
pinyin_dev_pred = text_to_df('baseline_en2pinyin/result/bs_dev.txt')
pinyin_tst_pred = text_to_df('baseline_en2pinyin/result/bs_tst.txt')

In [None]:
pinyin_no_tone_dev_gold = text_to_df('data/pinyin_no_tone_dev.txt')
pinyin_no_tone_tst_gold = text_to_df('data/pinyin_no_tone_tst.txt')
pinyin_no_tone_dev_pred = text_to_df('baseline_en2pinyin_no_tone/result/bs_dev.txt')
pinyin_no_tone_tst_pred = text_to_df('baseline_en2pinyin_no_tone/result/bs_tst.txt')

In [None]:
wer = pd.DataFrame(columns=['dev', 'tst'], index=['ch', 'pinyin', 'pinyin_no_tone'])

In [None]:
wer.loc['ch'] = [WER(ch_dev_gold, ch_dev_pred), WER(ch_tst_gold, ch_tst_pred)]
wer.loc['pinyin'] = [WER(pinyin_dev_gold, pinyin_dev_pred), WER(pinyin_tst_gold, pinyin_tst_pred)]
wer.loc['pinyin_no_tone'] = [WER(pinyin_no_tone_dev_gold, pinyin_no_tone_dev_pred), WER(pinyin_no_tone_tst_gold, pinyin_no_tone_tst_pred)]

In [None]:
wer

In [None]:
ch_dev_pred_bridge = text_to_df('baseline_en2ch/with_bridge/result/bs_dev.txt')
ch_tst_pred_bridge = text_to_df('baseline_en2ch/with_bridge/result/bs_tst.txt')

In [None]:
print([WER(ch_dev_gold, ch_dev_pred_bridge), WER(ch_tst_gold, ch_tst_pred_bridge)])

In [None]:
# IPA table; keep a handle for later cells and preview only the first rows.
ipa = pd.read_csv('ipa.csv', encoding='UTF-8')
ipa.head()