In [7]:
from jiwer import wer, cer
import pandas as pd
import re
from tqdm import tqdm

In [8]:
def pmr(gt, pred):
    """Position match rate between two whitespace-tokenised strings.

    Tokens are compared index by index over the shorter of the two token
    sequences; the score is the fraction of those positions whose tokens
    are equal.

    Args:
        gt: Reference text.
        pred: Predicted text.

    Returns:
        Number of matching positions divided by the number of compared
        positions, or 0 (after printing 'length 0') when either text
        contains no tokens.
    """
    # zip stops at the shorter sequence, so exactly
    # min(len(gt tokens), len(pred tokens)) positions are compared.
    aligned = list(zip(gt.split(), pred.split()))
    if not aligned:
        print('length 0')
        return 0

    hits = 0
    for ref_token, hyp_token in aligned:
        if ref_token == hyp_token:
            hits += 1
    return hits / len(aligned)

In [9]:
# Directory the SymSpell-refined texts are read from (files named res_<name>.txt).
SYMSPELL_DIR = '../../text-processing/algorithm/symspell_res'
# Directory the ground-truth texts are read from (files named gt_<name>.txt).
GT_DIR = '../../data/raw/ground_truth'
# Directory the baseline OCR texts are read from (files named ocr_<name>.txt).
BASELINE_DIR = '../../data/raw/ocr_result'

In [10]:
# Read the evaluation list: one file name per line.
with open('../eval_list.txt', 'r') as file:
    content = file.read()

# Keep the part of each name before the first '.'. Blank lines (for example
# the one produced by a trailing newline) are skipped so they cannot turn into
# an empty entry that would later be looked up as a file.
test_files = []
for line in content.splitlines():
    entry = line.strip()
    if entry:
        test_files.append(entry.split('.')[0])
len(test_files)

100

In [11]:
# Per-file scores for the baseline (raw OCR) text.
baseline_cer = []
baseline_wer = []
baseline_pmr = []

# Per-file scores for the SymSpell-refined text.
symspell_cer = []
symspell_wer = []
symspell_pmr = []

In [12]:
def read_file(path):
    """Return the full text of ``path`` decoded as UTF-8.

    The file is first decoded strictly. If that raises UnicodeDecodeError,
    it is read again with ``errors='ignore'``, which drops the bytes that
    cannot be decoded instead of failing.

    Args:
        path: Path of the text file to read.

    Returns:
        The decoded file contents as a single string.
    """
    # Context managers close the handle on every path; the bare
    # open(...).read() form left closing to the garbage collector.
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except UnicodeDecodeError:
        with open(path, 'r', encoding='utf-8', errors='ignore') as handle:
            return handle.read()

In [13]:
def _normalize(text):
    """Collapse each whitespace run to one space, trim the ends, and lowercase."""
    # r"\s+" already matches newlines, so no separate newline replacement is needed.
    return re.sub(r"\s+", " ", text).strip().lower()


# Empty the accumulators first so that re-running this cell cannot append a
# second set of scores and leave the lists longer than test_files.
for scores in (baseline_wer, baseline_cer, baseline_pmr,
               symspell_wer, symspell_cer, symspell_pmr):
    scores.clear()

for filename in tqdm(test_files):
    # Load the three versions of the same document and normalise them identically.
    baseline = _normalize(read_file(f'{BASELINE_DIR}/ocr_{filename}.txt'))
    gt = _normalize(read_file(f'{GT_DIR}/gt_{filename}.txt'))
    symspell_str = _normalize(read_file(f'{SYMSPELL_DIR}/res_{filename}.txt'))

    # Report files whose ground truth is empty after normalisation.
    # NOTE(review): wer/cer may reject an empty reference - confirm against jiwer.
    if not gt:
        print(filename)

    # Score the baseline text against the ground truth.
    baseline_wer.append(wer(gt, baseline))
    baseline_cer.append(cer(gt, baseline))
    baseline_pmr.append(pmr(gt, baseline))

    # Score the SymSpell-refined text against the same ground truth.
    symspell_wer.append(wer(gt, symspell_str))
    symspell_cer.append(cer(gt, symspell_str))
    symspell_pmr.append(pmr(gt, symspell_str))
    

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:01<00:00, 51.67it/s]


In [14]:
# One row per evaluated file: its name followed by the baseline and SymSpell
# scores, in the same column order as before.
data = dict(
    name=test_files,
    baseline_wer=baseline_wer,
    baseline_cer=baseline_cer,
    baseline_pmr=baseline_pmr,
    symspell_wer=symspell_wer,
    symspell_cer=symspell_cer,
    symspell_pmr=symspell_pmr,
)

df = pd.DataFrame(data)
df.head()

Unnamed: 0,name,baseline_wer,baseline_cer,baseline_pmr,symspell_wer,symspell_cer,symspell_pmr
0,522,0.16129,0.113295,0.290323,0.169355,0.116763,0.290323
1,479,21.230769,16.631868,0.0,21.230769,16.538462,0.0
2,528,1.166667,0.885895,0.362745,1.166667,0.881141,0.362745
3,365,0.362694,0.294331,0.005181,0.362694,0.294331,0.005181
4,478,1.12782,1.413115,0.0,1.12782,1.414208,0.0


In [15]:
# Summary statistics for the numeric metric columns of df.
df.describe()

Unnamed: 0,baseline_wer,baseline_cer,baseline_pmr,symspell_wer,symspell_cer,symspell_pmr
count,100.0,100.0,100.0,100.0,100.0,100.0
mean,0.432745,0.319161,0.172997,0.437868,0.319356,0.172238
std,2.128334,1.668447,0.28395,2.127372,1.659033,0.282862
min,0.0,0.0,0.0,0.003448,0.000449,0.0
25%,0.042914,0.023549,0.007589,0.049335,0.024243,0.007589
50%,0.078184,0.041788,0.017484,0.085179,0.044333,0.017484
75%,0.21019,0.163096,0.241223,0.216763,0.163512,0.236121
max,21.230769,16.631868,1.0,21.230769,16.538462,1.0


In [16]:
# Write the per-file metrics to CSV; index=False leaves the row index out of the file.
df.to_csv('symspell_final_result.csv',index=False)