# Prompt selection and testing

This notebook chooses the most appropriate prompt and prompt structure for the OCR correction. 

In [70]:
#import config  # Import your config.py file this contains you openai api key
import pandas as pd
import numpy as np
import os
from llm_comparison_toolkit import RateLimiter, get_response_openai, get_response_anthropic,  create_config_dict_func, use_df_to_call_llm_api, compare_request_configurations
from evaluate import load
from evaluation_funcs import evaluate_correction_performance, evaluate_correction_performance_folders, get_metric_error_reduction
import seaborn as sns
import matplotlib.pyplot as plt
from helper_functions import files_to_df_func, evaluate_ner, calculate_entity_similarity
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import re


dev_transcripts = 'data/dev_data_transcript'

#load the dev and test sets for prompt development and selection
dev_data_df = pd.read_csv('data/dev_data_raw.csv')
test_data_df = pd.read_csv('data/test_data_raw.csv')

# Explore different system prompts

This explores a range of system prompt to find the one that appears to work the best, we use gpt4 as the baseline model.

Although there is no comparison with all models we do test gpt3.5 gpt4, clause haiku and claude opus, in addition we put the prompt in the system message and the prompt after the text becuase this may affect the quality of the result. When the prompt is after the text the response has "nosm_" (no system message) appended to the file name

In [71]:
#Create a modular set of system messages that can be combined in different ways
basic_prompt = "Please recover the text from the corrupted OCR."
expertise_prompt = "You are an expert in post-OCR correction of documents."
recover_prompt = "Using the context available from the text please recover the most likely original text from the corrupted OCR."
publication_context_prompt = "The text is from an english newspaper in the 1800's."
text_context_prompt = "The text may be an advert or article and may be missing the beggining or end."
additional_instructions_prompt = "Do not add any text, commentary, or lead in sentences beyond the recovered text. Do not add a title, or any introductions."

#combine all the message parts into a variety of system messages, a tuple is used where 0 is the name of the message and 1 is the message itself
#N.B. This is not and exhaustive combination as that would be very expensive and likley not yield significantly better results
system_messages_list = [
('basic_prompt', basic_prompt),
('expert_basic_prompt', expertise_prompt + ' '+ basic_prompt),
('expert_recover_prompt', expertise_prompt + ' '+ recover_prompt),
('expert_recover_publication_prompt', expertise_prompt + ' '+ recover_prompt + ' ' + publication_context_prompt),
('expert_recover_text_prompt', expertise_prompt + ' '+ recover_prompt + ' ' + text_context_prompt),
('expert_recover_publication_text_prompt', expertise_prompt + ' '+ recover_prompt + ' ' + publication_context_prompt + ' ' + text_context_prompt),
('expert_recover_instructions_prompt', expertise_prompt + ' '+ recover_prompt + ' ' + additional_instructions_prompt),
('full_context_prompt', expertise_prompt + ' '+ recover_prompt + ' ' + publication_context_prompt + ' ' + text_context_prompt+ ' ' + additional_instructions_prompt)
]

The below function is used to make the creation of the config dictionaries for the test more compact and increase readability

In [72]:
def create_message_test_configs(system_messages_list, get_response_func, engine):
    message_test_configs = []
    for iter_system_message in system_messages_list:
        message_test_configs.append(
            create_config_dict_func(
                get_response_func=get_response_func,
                rate_limiter=RateLimiter(50000),
                engine=engine,
                system_message_template=iter_system_message[1],
                prompt_template="{content_html}",
                additional_args={'response_name': iter_system_message[0]}
            )
        )
        message_test_configs.append(
            create_config_dict_func(
                get_response_func=get_response_func,
                rate_limiter=RateLimiter(50000),
                engine=engine,
                system_message_template="",
                prompt_template="{content_html}" + f"""\n\n""" + iter_system_message[1],
                additional_args={'response_name': "nosm_"+iter_system_message[0]}
            )
        )
    return message_test_configs

## Create configs and run tests

In [73]:
#gpt configs
gpt3_prompt_testing_configs = create_message_test_configs(system_messages_list, get_response_openai, "gpt-3.5-turbo")
gpt4_prompt_testing_configs = create_message_test_configs(system_messages_list, get_response_openai, 'gpt-4-turbo-preview')

#claude configs
haiku_prompt_testing_configs = create_message_test_configs(system_messages_list, get_response_anthropic, "claude-3-haiku-20240307")
opus_prompt_testing_configs = create_message_test_configs(system_messages_list, get_response_anthropic, "claude-3-opus-20240229")

#run the experiment on all the prompt configs and save to the folder
compare_request_configurations(dev_data_df, 
                               gpt3_prompt_testing_configs + gpt4_prompt_testing_configs + haiku_prompt_testing_configs + opus_prompt_testing_configs,
                               folder_path='./data/dev_system_message_variants')

## Evaluate system prompt tests

We evaluate the system prompts below to see if thre is any significant difference between the prompts


In [74]:
wer = load("wer")
cer = load("cer")

In [75]:
raw_dev_ocr_scores = evaluate_correction_performance('data/dev_raw_ocr', dev_transcripts, wer, cer, 'raw_ocr')

In [76]:
corrected_folder = './data/dev_system_message_variants'

performance_eval = evaluate_correction_performance_folders(corrected_folder, dev_transcripts, wer, cer)

performance_eval =  performance_eval.loc[(performance_eval['File Name']!='slug_ar02501_periodical_pc_issue_tec_06121884_page_number_25.txt'),:]

performance_eval['type'] = performance_eval['type'].str.replace("claude-3-haiku-20240307", "haiku").replace("gpt-3.5-turbo", "gpt-3.5")

performance_eval['model'] = performance_eval['type'].str.split('_').str[-1]

In [77]:
test = get_metric_error_reduction(performance_eval, raw_dev_ocr_scores)

In [78]:
test.loc[test['type'].str.contains('opus')].groupby('type').describe().filter(regex = '50|mean').round(2).sort_values(('CER', '50%'))

Unnamed: 0_level_0,WER,WER,CER,CER,lev_dist,lev_dist
Unnamed: 0_level_1,mean,50%,mean,50%,mean,50%
type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
nosm_expert_recover_publication_text_prompt_claude-3-opus-20240229,56.55,74.7,-29.86,43.62,-29.53,45.0
nosm_expert_recover_publication_prompt_claude-3-opus-20240229,71.64,75.47,44.57,49.56,43.02,48.07
expert_recover_publication_prompt_claude-3-opus-20240229,71.07,77.36,41.74,55.14,39.58,52.66
full_context_prompt_claude-3-opus-20240229,68.79,77.39,25.49,55.41,25.91,56.0
expert_recover_text_prompt_claude-3-opus-20240229,72.19,77.36,44.01,56.22,42.49,53.72
basic_prompt_claude-3-opus-20240229,72.25,77.36,41.81,58.38,40.48,55.6
nosm_expert_recover_prompt_claude-3-opus-20240229,72.85,76.02,49.35,58.54,47.96,53.88
nosm_expert_recover_text_prompt_claude-3-opus-20240229,71.35,75.86,50.45,58.71,49.13,58.31
nosm_basic_prompt_claude-3-opus-20240229,70.73,76.02,47.31,59.73,45.54,49.76
expert_recover_publication_text_prompt_claude-3-opus-20240229,70.79,77.36,42.16,60.0,40.97,55.6


In [79]:
performance_eval2 = performance_eval.copy()
performance_eval2['type'] = performance_eval2['type'].str.replace("claude-3-haiku-20240307", "haiku").replace("gpt-3.5-turbo", "gpt-3.5")
performance_eval2['model'] = performance_eval2['type'].str.split('_').str[-1]
performance_eval2 = performance_eval2.loc[performance_eval2['model'].str.contains('gpt-4')]
performance_eval2.drop(columns = 'File Name').groupby(['type', 'model']).describe().filter(regex = '50|mean').round(2).sort_values(('lev_dist', '50%'))

Unnamed: 0_level_0,Unnamed: 1_level_0,WER,WER,CER,CER,lev_dist,lev_dist
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,50%,mean,50%,mean,50%
type,model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
full_context_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.23,0.13,0.13,0.06,148.67,72.0
nosm_expert_recover_instructions_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.23,0.15,0.13,0.06,146.67,75.0
nosm_full_context_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.24,0.16,0.14,0.07,158.33,76.0
expert_recover_instructions_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.23,0.14,0.13,0.05,149.62,81.0
expert_recover_publication_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.26,0.16,0.15,0.05,160.71,85.0
nosm_expert_recover_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.25,0.16,0.15,0.06,168.38,105.0
expert_recover_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.26,0.21,0.15,0.09,192.76,110.0
nosm_expert_recover_text_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.25,0.14,0.14,0.06,153.86,110.0
expert_recover_publication_text_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.27,0.14,0.15,0.05,208.9,114.0
expert_recover_text_prompt_gpt-4-turbo-preview,gpt-4-turbo-preview,0.24,0.14,0.14,0.05,158.19,114.0


## Conclusions of the prompt test

It appears that the placing the prompt after the text instead of using the system prompt gives the best results. However, the prompts did give significantlty different performance.  I think that perhaps using the `full_context_prompt` and the `expert_recover_publication_prompt` with no system message and the prompt after the text may be the best option. This will require twice as much compute as I was planning to use.

# Create basic and no message responses

Having identified two different prompts and that the prompts appear to work better when places after the text we can now compare the different models


The below code creates the basic configuration dictionaries for each model and then fills in the with the two different prompt messages creating a single list of all basic prompt/model configurations. It then calls all the LLM's and saves the results.
This works in series so takes a while.

In [80]:
#Create the prompt/system message using the best performing from the previous section

full_prompt = "{content_html}"+f""" \n \n """ + f"""You are an expert in post-OCR correction of documents. Using the context available from the text please recover the most likely original text from the corrupted OCR. The text is from an english newspaper in the 1800's. The text may be an advert or article and may be missing the beggining or end. Do not add any text, commentary, or lead in sentences beyond the recovered text. Do not add a title, or any introductions."""

instruct_prompt = "{content_html}"+f""" \n \n """ + f"""You are an expert in post-OCR correction of documents. Using the context available from the text please recover the most likely original text from the corrupted OCR. Do not add any text, commentary, or lead in sentences beyond the recovered text. Do not add a title, or any introductions."""

boros_basic  = "{content_html}"+f""" \n \n """ +"Correct the text"

boros_complex  ="{content_html}"+f""" \n \n """ + f"""Please assist with reviewing and correcting errors in texts produced by automatic transcription (OCR) of historical documents.
Your task is to carefully examine the following text and correct any mistakes introduced by the OCR software. 
Do not write anything else than the corrected text."""

In [81]:
groq_alt_endpoint = {'alt_endpoint':{'base_url':'https://api.groq.com/openai/v1',
                     'api_key':os.getenv("GROQ_API_KEY")}}

basic_model_configs = pd.DataFrame({
    'get_response_func': [get_response_openai, get_response_openai, get_response_anthropic, get_response_anthropic, 
                          get_response_openai, get_response_openai, get_response_openai], 
    'engine': ['gpt-3.5-turbo', 'gpt-4-turbo-preview', "claude-3-haiku-20240307", "claude-3-opus-20240229", 
               'mixtral-8x7b-32768', 'llama2-70b-4096', 'gemma-7b-it'],
    'additional_args': [
        {}, {}, {}, {}, 
        groq_alt_endpoint, 
        groq_alt_endpoint, 
        groq_alt_endpoint
    ]
})

base_model_configs= []

for index, row in basic_model_configs.iterrows():
    #modify the response name for the type
    row['additional_args']['response_name'] = 'full_'
    base_model_configs.append(

        create_config_dict_func(
    get_response_func = row['get_response_func'],
    rate_limiter = RateLimiter(40000),
    engine = row['engine'],
    system_message_template = "",
    prompt_template =  full_prompt,
    additional_args=row['additional_args']
    )

    )

nosm_model_configs = []

for index, row in basic_model_configs.iterrows():
    #modify the response name for the type
    row['additional_args']['response_name'] = 'instruct_' 
    nosm_model_configs.append(

        create_config_dict_func(
    get_response_func = row['get_response_func'],
    rate_limiter = RateLimiter(40000),
    engine = row['engine'],
    system_message_template = "",
    prompt_template =  instruct_prompt ,
    additional_args=row['additional_args']
    )

    )

boros_list = [create_config_dict_func(
    get_response_func = get_response_openai,
    rate_limiter = RateLimiter(40000),
    engine = 'gpt-4-turbo-preview',
    system_message_template = "",
    prompt_template =  boros_complex ,
    additional_args={"response_name":"boros_complex_"}
),

create_config_dict_func(
    get_response_func = get_response_openai,
    rate_limiter = RateLimiter(40000),
    engine = 'gpt-4-turbo-preview',
    system_message_template = "",
    prompt_template =  boros_basic ,
    additional_args={"response_name":"boros_basic_"}
),

create_config_dict_func(
    get_response_func = get_response_anthropic,
    rate_limiter = RateLimiter(40000),
    engine = "claude-3-opus-20240229",
    system_message_template = "",
    prompt_template =  boros_basic ,
    additional_args={"response_name":"boros_complex_"}
)]




model_configs = base_model_configs + nosm_model_configs + boros_list

compare_request_configurations(dev_data_df, model_configs, folder_path='./data/dev_corrected_base')
    

In [82]:
corrected_folder = './data/dev_corrected_base'

performance_eval =  evaluate_correction_performance_folders(corrected_folder, dev_transcripts, wer, cer)

performance_eval =  performance_eval.loc[(performance_eval['File Name']!='slug_ar02501_periodical_pc_issue_tec_06121884_page_number_25.txt') &
                     (performance_eval['type']!='gpt3_boros_blank_gpt-3.5-turbo'),:]

# Evaluate the prompts across all models

On the smaller models, Full is worse than instruct on the larger models the reverse. Maybe this is related to ability to 'focus' or hold isntructions in memory?

In [83]:
test = get_metric_error_reduction(performance_eval, raw_dev_ocr_scores)

test.groupby('type').describe().filter(regex = '50|mean').round(2).sort_values(('lev_dist', '50%'))

Unnamed: 0_level_0,WER,WER,CER,CER,lev_dist,lev_dist
Unnamed: 0_level_1,mean,50%,mean,50%,mean,50%
type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
full__gemma-7b-it,-5.62,-2.51,-178.23,-20.73,-169.37,-19.76
instruct__gemma-7b-it,-5.04,1.08,-249.75,-20.41,-237.58,-15.52
instruct__mixtral-8x7b-32768,30.89,47.79,-116.89,-6.08,-110.08,-6.95
full__mixtral-8x7b-32768,26.93,41.79,-146.28,-4.27,-138.44,-4.19
full__llama2-70b-4096,-3.91,55.4,-290.94,-4.12,-289.14,-2.54
instruct__llama2-70b-4096,37.29,45.61,-52.95,3.0,-65.63,2.54
full__claude-3-haiku-20240307,52.62,61.15,-24.53,28.05,-20.74,25.43
boros_complex__claude-3-opus-20240229,59.86,70.85,-55.37,36.0,-49.38,33.77
full__gpt-3.5-turbo,54.55,70.75,26.31,39.7,23.73,38.77
instruct__claude-3-haiku-20240307,58.14,69.89,8.32,47.51,9.39,47.09


In [84]:
performance_eval2 = performance_eval.copy()
performance_eval2['type'] = performance_eval2['type'].str.replace("claude-3-haiku-20240307", "haiku").replace("gpt-3.5-turbo", "gpt-3.5")
performance_eval2['model'] = performance_eval2['type'].str.split('_').str[-1]

#The below line allows you to look at an individual model
#performance_eval2 = performance_eval2.loc[performance_eval2['model'].str.contains('gpt-4')]

performance_eval2.drop(columns = 'File Name').groupby(['type', 'model']).describe().filter(regex = '50|mean').round(2).sort_values(('lev_dist', '50%'))

Unnamed: 0_level_0,Unnamed: 1_level_0,WER,WER,CER,CER,lev_dist,lev_dist
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,50%,mean,50%,mean,50%
type,model,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
full__claude-3-opus-20240229,claude-3-opus-20240229,0.23,0.11,0.13,0.05,150.52,87.0
boros_basic__gpt-4-turbo-preview,gpt-4-turbo-preview,0.26,0.18,0.15,0.06,170.9,91.0
full__gpt-4-turbo-preview,gpt-4-turbo-preview,0.26,0.16,0.15,0.06,164.62,91.0
instruct__gpt-4-turbo-preview,gpt-4-turbo-preview,0.25,0.19,0.15,0.06,186.76,91.0
boros_complex__gpt-4-turbo-preview,gpt-4-turbo-preview,0.24,0.17,0.14,0.07,183.9,98.0
instruct__claude-3-opus-20240229,claude-3-opus-20240229,0.33,0.12,0.22,0.06,380.76,111.0
instruct__haiku,haiku,0.32,0.26,0.21,0.13,323.05,118.0
instruct__gpt-3.5-turbo,gpt-3.5-turbo,0.29,0.16,0.17,0.06,189.38,119.0
boros_complex__claude-3-opus-20240229,claude-3-opus-20240229,0.32,0.3,0.23,0.2,233.48,154.0
full__haiku,haiku,0.35,0.32,0.24,0.16,329.33,156.0


In [85]:
{'Llama 2 70B':'llama2-70b-4096	',
 'Gemma 7B':'gemma-7b-it',
 'Opus':'claude-3-opus-20240229',
 'Haiku':'haiku',
 'GPT-4':'gpt-4-turbo-preview',
 'Mixtral 8c7B':'mixtral-8x7b-32768'}

{'Llama 2 70B': 'llama2-70b-4096\t',
 'Gemma 7B': 'gemma-7b-it',
 'Opus': 'claude-3-opus-20240229',
 'Haiku': 'haiku',
 'GPT-4': 'gpt-4-turbo-preview',
 'Mixtral 8c7B': 'mixtral-8x7b-32768'}

In [86]:
pd.read_csv('data/benchmarks.csv')

Unnamed: 0.1,Unnamed: 0,Llama 2 70B,Gemma 7B,Opus,Haiku,GPT-4,GPT3.5,Mixtral 8x7B
0,MMLU,69.9,64.6,86.8,75.2,86.4,70.0,70.6
1,HellaSwag,87.1,82.2,95.4,85.9,93.3,85.5,86.7
2,ARC-C,85.1,61.9,96.6,89.2,96.3,85.2,85.8
3,WinoGrande,83.2,79.0,88.5,74.2,87.5,81.6,81.2
4,MBPP,49.8,44.4,86.4,80.4,,52.2,60.7
5,DROP,,,83.1,78.9,80.9,64.1,


# Test on the test set
confusing but you know what I mean, This is jsut to make sure it all works as expected

In [87]:
transcribed_files = 'data/transcription_returned_ocr/transcription_files'

corrected_folder = 'data/transcription_returned_ocr/temp_claude'

raw_folder = "/home/jonno/redigitalize/data/transcription_raw_ocr"

performance_eval =  evaluate_correction_performance_folders(corrected_folder, transcribed_files, wer, cer)

performance_eval =  performance_eval.loc[(performance_eval['File Name']!='slug_ar02501_periodical_pc_issue_tec_06121884_page_number_25.txt') &
                     (performance_eval['type']!='gpt3_boros_blank_gpt-3.5-turbo'),:]

In [88]:
performance_eval.describe()

Unnamed: 0,WER,CER,lev_dist
count,18.0,18.0,18.0
mean,0.462007,0.371882,480.388889
std,1.528589,1.360813,1549.942016
min,0.0,0.0,1.0
25%,0.023535,0.006015,35.0
50%,0.095863,0.029196,55.5
75%,0.151927,0.087611,125.25
max,6.57485,5.817863,6646.0


In [89]:
raw_dev_ocr_scores = evaluate_correction_performance(raw_folder, transcribed_files, wer, cer, 'raw_ocr')

error_reduction_df = get_metric_error_reduction(performance_eval, raw_dev_ocr_scores)

error_reduction_df.groupby('type').describe().filter(regex = '50|mean').round(2).sort_values(('lev_dist', '50%'))

Unnamed: 0_level_0,WER,WER,CER,CER,lev_dist,lev_dist
Unnamed: 0_level_1,mean,50%,mean,50%,mean,50%
type,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
claude_temp_claude-3-opus-20240229,22.89,82.24,-231.82,78.64,-205.53,76.17


In [96]:


transcribed_data_set_df = files_to_df_func(transcribed_files)
raw_data_set_df = files_to_df_func(raw_folder)

raw_data_set_df = raw_data_set_df.loc[raw_data_set_df['file_name'].isin(transcribed_data_set_df['file_name'])]

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
def preprocess_text(text):
    # Example preprocessing: substitute multiple whitespaces with a single space
    text = re.sub(r'\s+', ' ', text)
    return text

perform_ner = pipeline("ner", model=model, tokenizer=tokenizer)

perform_ner = pipeline("ner", model=model, tokenizer=tokenizer)


def perform_ner_on_text(text, ner_pipeline):
    # Preprocess the text first
    preprocessed_text = preprocess_text(text)
    # Then pass the preprocessed text to the NER pipeline
    return ner_pipeline(preprocessed_text)

transcribed_data_set_df['ner_results'] = transcribed_data_set_df['content'].apply(perform_ner_on_text, ner_pipeline=perform_ner)

raw_data_set_df['ner_results_raw'] = raw_data_set_df['content'].apply(perform_ner_on_text, ner_pipeline=perform_ner)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [97]:
#load on the transcribed data... this should be done within a loop as there are many models to test
LM_corrected_df = files_to_df_func(os.path.join(corrected_folder, 'claude_temp_claude-3-opus-20240229'))
temp_ner = transcribed_data_set_df.copy().merge(LM_corrected_df.loc[:,['content', 'file_name'] ].rename(columns={'content':'content_corrected'}), 
                                                on = 'file_name')

temp_ner['ner_results_corrected'] = temp_ner['content_corrected'].apply(perform_ner_on_text, ner_pipeline=perform_ner)

temp_ner['cosine_sim'] = temp_ner.apply(lambda row: calculate_entity_similarity(row['ner_results_corrected'],row['ner_results']), axis = 1)


print("Median cosine similarity:", temp_ner['cosine_sim'].median().round(2))

Median cosine similarity: 0.85


In [98]:
results, results_by_tag = evaluate_ner(temp_ner, 'ner_results', 'ner_results_corrected')

print(results)
print(results_by_tag)

{'ent_type': {'correct': 170, 'incorrect': 27, 'partial': 0, 'missed': 123, 'spurious': 152, 'possible': 320, 'actual': 349, 'precision': 0.4871060171919771, 'recall': 0.53125, 'f1': 0.5082212257100149}, 'partial': {'correct': 69, 'incorrect': 0, 'partial': 128, 'missed': 123, 'spurious': 152, 'possible': 320, 'actual': 349, 'precision': 0.38108882521489973, 'recall': 0.415625, 'f1': 0.3976083707025411}, 'strict': {'correct': 65, 'incorrect': 132, 'partial': 0, 'missed': 123, 'spurious': 152, 'possible': 320, 'actual': 349, 'precision': 0.18624641833810887, 'recall': 0.203125, 'f1': 0.19431988041853512}, 'exact': {'correct': 69, 'incorrect': 128, 'partial': 0, 'missed': 123, 'spurious': 152, 'possible': 320, 'actual': 349, 'precision': 0.1977077363896848, 'recall': 0.215625, 'f1': 0.2062780269058296}}
{'B-LOC': {'ent_type': {'correct': 25, 'incorrect': 2, 'partial': 0, 'missed': 20, 'spurious': 33, 'possible': 47, 'actual': 60, 'precision': 0.4166666666666667, 'recall': 0.5319148936170

In [99]:
temp_ner = transcribed_data_set_df.copy().merge(raw_data_set_df.loc[:,['ner_results_raw', 'file_name'] ].rename(columns={'content':'content_corrected'}), 
                                                on = 'file_name')

temp_ner['cosine_sim'] = temp_ner.apply(lambda row: calculate_entity_similarity(row['ner_results_raw'],row['ner_results']), axis = 1)

print("Median cosine similarity:", temp_ner['cosine_sim'].median().round(2))

results, results_by_tag = evaluate_ner(temp_ner, 'ner_results', 'ner_results_raw')

print(results)
print(results_by_tag)

Median cosine similarity: 0.76
{'ent_type': {'correct': 62, 'incorrect': 27, 'partial': 0, 'missed': 229, 'spurious': 214, 'possible': 318, 'actual': 303, 'precision': 0.20462046204620463, 'recall': 0.1949685534591195, 'f1': 0.19967793880837362}, 'partial': {'correct': 15, 'incorrect': 0, 'partial': 74, 'missed': 229, 'spurious': 214, 'possible': 318, 'actual': 303, 'precision': 0.1716171617161716, 'recall': 0.16352201257861634, 'f1': 0.16747181964573268}, 'strict': {'correct': 15, 'incorrect': 74, 'partial': 0, 'missed': 229, 'spurious': 214, 'possible': 318, 'actual': 303, 'precision': 0.04950495049504951, 'recall': 0.04716981132075472, 'f1': 0.04830917874396135}, 'exact': {'correct': 15, 'incorrect': 74, 'partial': 0, 'missed': 229, 'spurious': 214, 'possible': 318, 'actual': 303, 'precision': 0.04950495049504951, 'recall': 0.04716981132075472, 'f1': 0.04830917874396135}}
{'B-LOC': {'ent_type': {'correct': 17, 'incorrect': 2, 'partial': 0, 'missed': 28, 'spurious': 26, 'possible': 4

In [100]:
temp_ner['ner_results_raw'][1]

[{'entity': 'B-PER',
  'score': 0.9994386,
  'index': 1,
  'word': 'Robert',
  'start': 0,
  'end': 6},
 {'entity': 'I-PER',
  'score': 0.998922,
  'index': 2,
  'word': 'Alien',
  'start': 7,
  'end': 12},
 {'entity': 'B-ORG',
  'score': 0.42875966,
  'index': 4,
  'word': 'N',
  'start': 13,
  'end': 14},
 {'entity': 'I-ORG',
  'score': 0.4714057,
  'index': 7,
  'word': '##han',
  'start': 17,
  'end': 20},
 {'entity': 'I-LOC',
  'score': 0.4741024,
  'index': 9,
  'word': 'Hu',
  'start': 24,
  'end': 26},
 {'entity': 'B-LOC',
  'score': 0.9920323,
  'index': 35,
  'word': 'A',
  'start': 77,
  'end': 78},
 {'entity': 'B-LOC',
  'score': 0.8442701,
  'index': 36,
  'word': '##yr',
  'start': 78,
  'end': 80},
 {'entity': 'B-LOC',
  'score': 0.9996904,
  'index': 38,
  'word': 'Scotland',
  'start': 83,
  'end': 91},
 {'entity': 'I-ORG',
  'score': 0.4026195,
  'index': 45,
  'word': '##K',
  'start': 103,
  'end': 104},
 {'entity': 'B-PER',
  'score': 0.9894314,
  'index': 65,
  'w