In [2]:
import numpy as np
import nltk
import json
from nltk.translate.meteor_score import meteor_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')
import os 

[nltk_data] Downloading package punkt to
[nltk_data]     /home/shixuan_leong/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/shixuan_leong/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/shixuan_leong/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


##### BLEU score

In [3]:
def calculate_bleu_score(generated: str,
                         ground_truth: str,
                         weights: tuple = (1, 0, 0, 0)) -> float:
    """Compute a sentence-level BLEU score of ``generated`` against ``ground_truth``.

    Both strings are tokenized with ``nltk.word_tokenize`` and the ground truth
    is used as the single reference. Smoothing ``method1`` is applied so that
    n-gram orders with zero matches do not force the score to 0.

    NOTE: tends to over-penalize shorter sentences.

    Args:
        generated: candidate text (e.g. an OCR caption).
        ground_truth: reference text.
        weights: n-gram weights forwarded to ``sentence_bleu``. The default
            ``(1, 0, 0, 0)`` scores unigrams only (BLEU-1).

    Returns:
        The BLEU score rounded to 3 decimal places.
    """
    smoothing_function = SmoothingFunction().method1
    generated_tokens = nltk.word_tokenize(generated)
    reference_tokens = nltk.word_tokenize(ground_truth)
    score = sentence_bleu([reference_tokens],
                          generated_tokens,
                          smoothing_function=smoothing_function,
                          weights=weights)
    return round(score, 3)

##### TER and normalized TER scores

In [4]:
#TER
def calculate_ter_score(generated: str, 
                        ground_truth: str):
    edits = edit_distance(ground_truth, generated)
    ref_length = len(ground_truth.split())

    ter_score = edits / ref_length if ref_length > 0 else float('inf')
    return ter_score


def edit_distance(ref, hyp):
    """Word-level Levenshtein distance between two whitespace-tokenized strings.

    Counts the minimum number of single-word insertions, deletions and
    substitutions needed to turn ``ref`` into ``hyp``.

    Returns:
        The distance as an int.
    """
    ref_tokens, hyp_tokens = ref.split(), hyp.split()
    n_hyp = len(hyp_tokens)

    # ``previous[j]`` is the distance between the ref tokens consumed so far and
    # the first j hyp tokens; only one earlier row of the DP table is needed.
    previous = list(range(n_hyp + 1))
    for i, ref_tok in enumerate(ref_tokens, start=1):
        current = [i] + [0] * n_hyp
        for j, hyp_tok in enumerate(hyp_tokens, start=1):
            substitution = previous[j - 1] + (0 if ref_tok == hyp_tok else 1)
            current[j] = min(previous[j] + 1, current[j - 1] + 1, substitution)
        previous = current
    return previous[n_hyp]

def calculate_normalized_ter(generated: str,
                             ground_truth: str):
    """Calculate a normalized, word-level Translation Edit Rate (TER).

    The word-level Levenshtein distance ``d`` between the whitespace-split
    ground truth and generated text is normalized as ``d / (len(gt) + d)``.
    This keeps the score in ``[0, 1]`` (0 means identical token sequences),
    even when ``generated`` is much longer than the reference.

    Returns:
        The normalized score rounded to 3 decimals as a plain ``float``.
        Returns 0.0 when both texts are empty.
    """
    gt_tokens = ground_truth.split()
    generated_tokens = generated.split()

    # Named ``distance`` rather than ``edit_distance`` so the module-level
    # ``edit_distance`` function is not shadowed in this scope.
    distance = levenshtein_distance(gt_tokens, generated_tokens)

    denominator = len(gt_tokens) + distance
    if denominator == 0:  # both texts empty -> identical; avoid division by zero
        return 0.0
    return round(float(distance / denominator), 3)

def levenshtein_distance(gt_tokens, ocr_tokens):
    """Calculate the Levenshtein distance between two lists of tokens.

    Counts the minimum number of token insertions, deletions and substitutions
    needed to turn one sequence into the other. The metric is symmetric in its
    arguments.

    Args:
        gt_tokens: reference token sequence.
        ocr_tokens: hypothesis token sequence.

    Returns:
        The distance as a ``float``.
    """
    # The distance is symmetric, so iterate over the longer sequence and keep a
    # single DP row sized by the shorter one: O(min(n, m)) memory instead of a
    # full (n+1) x (m+1) matrix.
    if len(gt_tokens) < len(ocr_tokens):
        gt_tokens, ocr_tokens = ocr_tokens, gt_tokens

    previous = list(range(len(ocr_tokens) + 1))
    for i, gt_tok in enumerate(gt_tokens, start=1):
        current = [i]
        for j, ocr_tok in enumerate(ocr_tokens, start=1):
            cost = 0 if gt_tok == ocr_tok else 1
            current.append(min(
                previous[j] + 1,         # Deletion
                current[j - 1] + 1,      # Insertion
                previous[j - 1] + cost,  # Substitution
            ))
        previous = current

    return float(previous[-1])


##### METEOR score

In [5]:
def calculate_meteor_score(generated: str,
                           ground_truth: str):
    """Compute the METEOR score of ``generated`` against ``ground_truth``.

    Both texts are tokenized with ``nltk.word_tokenize`` before scoring. The
    ground truth is passed as the single reference.

    Returns:
        The METEOR score rounded to 3 decimal places.
    """
    reference_tokens = nltk.word_tokenize(ground_truth)
    hypothesis_tokens = nltk.word_tokenize(generated)
    return round(meteor_score([reference_tokens], hypothesis_tokens), 3)

##### Batch process

In [6]:
def compare_groundtruth_ocr(groundtruth_dict_path,
                            ocr_path):
    """Score OCR captions against ground-truth captions that share the same key.

    Both files are loaded with ``json.load`` and treated as mappings from a
    caption key to caption text. For every key present in both, BLEU,
    normalized TER and METEOR are computed, with the OCR caption as the
    generated text and the ground-truth caption as the reference.

    Args:
        groundtruth_dict_path: path to the ground-truth captions JSON file.
        ocr_path: path to the OCR captions JSON file.

    Returns:
        A tuple ``(groundtruth_log, uncompared_captions)``:

        - ``groundtruth_log`` maps each shared key to
          ``(bleu, normalized_ter, meteor)``.
        - ``uncompared_captions`` is ``None`` when every key was matched.
          Otherwise it is a dict whose ``"uncompared_gt_captions"`` and
          ``"uncompared_ocr_captions"`` lists hold the keys found in only one
          of the two files.
    """
    with open(groundtruth_dict_path, 'r', encoding='utf-8') as file:
        groundtruth_dict = json.load(file)

    with open(ocr_path, 'r', encoding='utf-8') as file:
        ocr_caption = json.load(file)

    groundtruth_log = {}
    for key, value in groundtruth_dict.items():
        # Only a missing OCR key should skip a caption (it is then reported as
        # uncompared). Errors raised by the metric functions must not be
        # silently swallowed, so the lookup is checked explicitly instead of
        # catching KeyError around the whole computation.
        if key not in ocr_caption:
            continue
        ocr_value = ocr_caption[key]
        bleu_score = calculate_bleu_score(ocr_value, value)
        normalized_ter_score = calculate_normalized_ter(ocr_value, value)
        meteorscore = calculate_meteor_score(ocr_value, value)
        groundtruth_log[key] = bleu_score, normalized_ter_score, meteorscore

    # The compared keys are exactly the keys of ``groundtruth_log``; a dict
    # gives O(1) membership checks, unlike scanning a separate list per key.
    uncompared_gt_captions = [key for key in groundtruth_dict if key not in groundtruth_log]
    uncompared_ocr_captions = [ocr_key for ocr_key in ocr_caption if ocr_key not in groundtruth_log]

    if not uncompared_gt_captions and not uncompared_ocr_captions:
        return groundtruth_log, None

    uncompared_captions = {
        "uncompared_gt_captions": uncompared_gt_captions,
        "uncompared_ocr_captions": uncompared_ocr_captions,
    }
    return groundtruth_log, uncompared_captions

In [9]:
# Score every "*_cleaned.json" ground-truth caption file for one field against
# its OCR counterpart. Metric results go to "*_bleuscore.json" and any unmatched
# keys go to "*_recheck.json" in the eval directory.
field_key = "organic_synthesis/"
gt_caption_dir = os.path.join("../ocr_eval_results/captions_groundtruth/", field_key)
ocr_caption_dir = os.path.join("../ocr_eval_results/captions_ocr/", field_key)
bleu_score_dir = os.path.join("../ocr_eval_results/ocr_eval/", field_key)

# Writing the results below fails if the output directory does not exist yet.
os.makedirs(bleu_score_dir, exist_ok=True)

# os.listdir order is arbitrary; sort so runs (and their logs) are reproducible.
for file in sorted(os.listdir(gt_caption_dir)):
    if not file.endswith('_cleaned.json'):
        continue

    # The OCR file name is the ground-truth name with the "_cleaned.json" suffix
    # swapped for ".json_dict.json", lower-cased.
    ocr_file = file.replace('_cleaned.json', '.json_dict.json').lower()
    gt_file_path = os.path.join(gt_caption_dir, file)
    ocr_file_path = os.path.join(ocr_caption_dir, ocr_file)
    try:
        print(f"processing {file}")
        groundtruth_log, uncompared_captions = compare_groundtruth_ocr(gt_file_path, ocr_file_path)
        print(f"done processing {file}")

        response_name = file.replace('_cleaned.json', '_bleuscore.json')
        output_path = os.path.join(bleu_score_dir, response_name)
        with open(output_path, 'w') as f:
            json.dump(groundtruth_log, f, indent=4)
        print(f"done saving {file}")

        if uncompared_captions is not None:
            uncompared_captions_name = file.replace('_cleaned.json', '_recheck.json')
            output_path2 = os.path.join(bleu_score_dir, uncompared_captions_name)
            with open(output_path2, 'w') as f:
                json.dump(uncompared_captions, f, indent=4)
            print(f"done saving unprocessed captions for {file}")
            print()
        else:
            print("yay! there's no uncompared captions")
            print()
    except Exception as e:
        # Best effort: one bad file should not stop the batch, but the message
        # must say which file failed and with what kind of error.
        print(f"failed processing {file}: {type(e).__name__}: {e}")
        print()
                





processing 10.1016_j.gresc.2021.09.003_cleaned.json
done processing 10.1016_j.gresc.2021.09.003_cleaned.json
done saving 10.1016_j.gresc.2021.09.003_cleaned.json
done saving unprocessed captions for 10.1016_j.gresc.2021.09.003_cleaned.json

processing 10.1039_d0sc00031k_cleaned.json
done processing 10.1039_d0sc00031k_cleaned.json
done saving 10.1039_d0sc00031k_cleaned.json
yay! there's no uncompared captions

processing 10.1039_d4sc02969k_cleaned.json
done processing 10.1039_d4sc02969k_cleaned.json
done saving 10.1039_d4sc02969k_cleaned.json
yay! there's no uncompared captions

processing 10.1021_acs.joc.8b00486_cleaned.json
done processing 10.1021_acs.joc.8b00486_cleaned.json
done saving 10.1021_acs.joc.8b00486_cleaned.json
done saving unprocessed captions for 10.1021_acs.joc.8b00486_cleaned.json

processing 10.1039_c5sc00238a_cleaned.json
done processing 10.1039_c5sc00238a_cleaned.json
done saving 10.1039_c5sc00238a_cleaned.json
yay! there's no uncompared captions

processing 10.1039

##### Additional checks for unprocessed (uncompared) captions

In [None]:
# Manually score a single caption pair (e.g. one listed in a "*_recheck.json"
# file): `value` is the ground-truth caption, `ocr_value` the OCR caption.
value = ocr_value = ""

bleu_score, normalized_ter_score, meteorscore = (
    calculate_bleu_score(ocr_value, value),
    calculate_normalized_ter(ocr_value, value),
    calculate_meteor_score(ocr_value, value),
)
[bleu_score, normalized_ter_score, meteorscore]