In [1]:
import pandas as pd
import torch as th
import numpy as np
from Levenshtein import distance

In [2]:
# Alternating Character Table (ACT): each `alternative_characters` cell holds a
# comma-separated list of characters that are treated as interchangeable.
# No `encoding` argument: .xlsx is a binary (zip) format and recent pandas
# versions of read_excel reject that keyword with a TypeError.
act_raw = pd.read_excel('act_data/act.xlsx')
act = (
    act_raw['alternative_characters']
    .dropna()          # blank cells are NaN (a float) and would crash .split
    .astype(str)       # guard against cells Excel stored as numbers
    .str.split(',')
    # "a, b" -> ['a', 'b']: entries with stray whitespace could never equal a single char
    .apply(lambda chars: [c.strip() for c in chars])
)

In [21]:
def text_to_df(path_pred, path_gold, encoding='UTF-8-sig'):
    """Load line-aligned prediction and reference files into a two-column DataFrame.

    Line k of `path_pred` is paired with line k of `path_gold`. Newlines and
    ASCII spaces are removed from every line, so space-separated output
    (e.g. character-tokenized text) is compared as one contiguous string.

    Parameters
    ----------
    path_pred : str
        Path to the file of predictions, one example per line.
    path_gold : str
        Path to the file of references, one example per line.
    encoding : str, optional
        Text encoding of both files. The default 'UTF-8-sig' also strips a
        leading BOM if present.

    Returns
    -------
    pandas.DataFrame
        Columns 'pred' and 'gold', one row per line.

    Raises
    ------
    ValueError
        If the two files do not contain the same number of lines.
    """
    with open(path_pred, 'r', encoding=encoding) as f:
        lines_pred = f.readlines()
    with open(path_gold, 'r', encoding=encoding) as f:
        lines_gold = f.readlines()

    # Fail loudly and clearly instead of relying on pandas' length-mismatch error
    # (or, for an empty prediction file, a later NaN crash).
    if len(lines_pred) != len(lines_gold):
        raise ValueError(
            f'{path_pred} has {len(lines_pred)} lines but {path_gold} has '
            f'{len(lines_gold)}; prediction and reference files must be line-aligned'
        )

    def _clean(line):
        # Drop the line terminator and the spaces separating tokens.
        return line.replace('\n', '').replace(' ', '')

    return pd.DataFrame({
        'pred': [_clean(line) for line in lines_pred],
        'gold': [_clean(line) for line in lines_gold],
    })

In [25]:
def ACC(df):
    """Compute exact-match accuracy and ACT-relaxed accuracy for pred/gold pairs.

    A pair is correct if its minimum edit distance (Levenshtein) is 0. Pairs
    with MED 1 or 2 are additionally checked with `look_up_ACT`. Those found
    equivalent under the Alternating Character Table also count toward
    'acc-act'. The sizes of both groups are printed.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with string columns 'pred' and 'gold' (as built by `text_to_df`).

    Returns
    -------
    dict
        'acc'      : exact-match accuracy (NaN if `df` is empty).
        'acc-act'  : accuracy counting ACT-replaceable pairs as correct (NaN if empty).
        'replaced' : "<ACT-replaceable>/<pairs with MED 1 or 2>" as a string.
    """
    n = len(df)

    # Calculate MED per pair. A comprehension is used instead of
    # DataFrame.apply(axis=1), which returns an empty DataFrame (not a Series)
    # for zero rows and is slow row by row.
    dist = np.array(
        [distance(p, g) for p, g in zip(df['pred'], df['gold'])], dtype=int
    )

    # zero MED means correct
    correct = int(np.sum(dist == 0))

    # extract those with MED 1 or 2 and look up the ACT
    for_act = df[(dist > 0) & (dist <= 2)]
    print('The number of pred/gold pairs with MED of 1 or 2 is', len(for_act))
    # sum() of bools gives a plain int and is 0 when there are no candidates.
    correct_act = sum(
        look_up_ACT(p, g) for p, g in zip(for_act['pred'], for_act['gold'])
    )
    print('The number of replaceable names is', correct_act)

    if n == 0:
        # No data: accuracy is undefined rather than a division-by-zero.
        acc = acc_act = float('nan')
    else:
        acc = correct / n
        acc_act = (correct + correct_act) / n

    return {
        'acc': acc,
        'acc-act': acc_act,
        'replaced': str(correct_act) + '/' + str(len(for_act))
    }


In [26]:
def look_up_ACT(pred, gold, table=None):
    """Return True if `pred` and `gold` are equivalent under the Alternating Character Table.

    The strings are compared position by position. At every position where they
    differ, both characters must belong to a common ACT group. The MED
    pre-filter is done by the caller (`ACC`), not here.

    Parameters
    ----------
    pred, gold : str
        Predicted and reference strings.
    table : iterable of collections of str, optional
        ACT groups, each a collection of mutually replaceable characters.
        Defaults to the module-level `act` Series (each cell a list of chars).

    Returns
    -------
    bool
        False if the lengths differ or any differing pair shares no group,
        otherwise True.
    """
    if table is None:
        table = act  # module-level ACT built in the loading cell (read-only)

    # The primary assumption requires pred and gold to be of the same length
    if len(pred) != len(gold):
        return False

    for p_char, g_char in zip(pred, gold):
        if p_char == g_char:
            continue
        # Short-circuits on the first group containing both characters, instead
        # of building a full boolean Series per mismatch via Series.apply.
        if not any(p_char in group and g_char in group for group in table):
            return False

    # all the distinct characters are 'replaceable' in ACT
    return True

### For BiDeep

In [51]:
# Evaluate BiDeep experiment `num` and append the ACT-relaxed results to its acc.txt.
num = 4
exp_path = '../nmt/experiments/exp' + str(num)
result = ACC(text_to_df(exp_path + '/test_pred.txt', exp_path + '/test_ref.txt'))
print(result)
with open(exp_path + '/acc.txt', 'a', encoding='UTF-8') as f:
    f.write('ACC-ACT: ' + str(result['acc-act']) + '\n')
    # Trailing newline so a later append does not continue on this line.
    f.write('Replaced: ' + result['replaced'] + '\n')

The number of pred/gold pairs with MED of 1 or 2 is 1631
The number of replaceable names is 151
{'acc': 0.7101921757035004, 'acc-act': 0.7361015785861359, 'replaced': '151/1631'}


### For OpenNMT

In [49]:
# Evaluate OpenNMT experiment `num` and append the results to its result.txt.
num = 1
onmt_path = '../nmt/onmt_experiments/exp' + str(num)
# NOTE(review): reference files are read from `exp_path`, which must be set by the
# BiDeep cell above; this assumes both systems share the same test/valid
# references — confirm.
ref_path = exp_path
result = ACC(text_to_df(onmt_path + '/bs_tst.txt', ref_path + '/test_ref.txt'))
result_valid = ACC(text_to_df(onmt_path + '/bs_dev.txt', ref_path + '/valid_ref.txt'))
print(result)
with open(onmt_path + '/result.txt', 'a', encoding='UTF-8') as f:
    f.write('[OpenNMT]\n')
    f.write('Valid Acc: ' + str(result_valid['acc']) + '\n')
    f.write('Test ACC: ' + str(result['acc']) + '\n')
    f.write('Test ACC-ACT: ' + str(result['acc-act']) + '\n')
    # Trailing newline so the next appended "[OpenNMT]" block starts on its own line.
    f.write('Replaced: ' + result['replaced'] + '\n')

The number of pred/gold pairs with MED of 1 or 2 is 1554
The number of replaceable names is 134
The number of pred/gold pairs with MED of 1 or 2 is 1518
The number of replaceable names is 133
{'acc': 0.7208304735758407, 'acc-act': 0.7438229238160604, 'replaced': '134/1554'}
