# NER using GPT-3.5

### Project name: Honos
Date: 24th May 2024

Author: Milindi Kodikara | Supervisor: Professor Karin Verspoor


Before running this notebook:
1. [Install Jupyter notebook](https://jupyter.org/install) 


2. [Setting up Azure OpenAI model](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/working-with-models?tabs=powershell#model-updates)


3. [Setting up connection to GPT-3.5 using Azure OpenAI service](https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?tabs=command-line%2Cpython-new&pivots=programming-language-python)
    - In the Environment variables section, instead of following the steps in the link, add `API-KEY`, `API-VERSION`, `ENDPOINT` and `DEPLOYMENT-NAME` to a `.env` file in the root folder (these are the exact names the notebook reads in Step 2).
        
4. Set the correct file paths in Step 1.1: the input text file loaded into `original_data` and the gold annotation file (`gold_annotation_filepath`). 


In [None]:
import pandas as pd
import re
from datetime import datetime
import os
from openai import AzureOpenAI

from dotenv import load_dotenv
load_dotenv() 


### Step 1: Load and pre-process data and prompt library 


#### Step 1.1: Load datasets

In [None]:
# Run configuration: input split, output naming and BRAT export options.
# The input file is a TSV with the columns pmid\tfilename\ttext
# (same layout as train_text.tsv).

# TODO: Replace filepath for related data file, Tom comment out this whole cell
original_data = pd.read_csv("../data/ner/genovardis_train_dev_test/test_text.tsv", sep='\t', header=0)

# Name of the split being processed; it is embedded in the result filenames
# and written to the `data_type` column of the evaluation log.
data_type = 'test'

# Set to True to additionally write the gold and predicted entities in BRAT format.
generate_brat_format = False

# Gold annotation file that is converted to BRAT format when generate_brat_format is True.
# NOTE(review): the text file above is the *test* split while this points at the
# *dev* annotations — confirm the two are meant to differ.
gold_annotation_filepath = '../data/ner/genovardis_train_dev_test/dev_annotation.tsv'

In [None]:
original_data.head(5)

In [None]:
len(original_data)

In [None]:
# TODO: remove this after testing
# original_data = original_data.head(2)
# 
# original_data

In [None]:
# original data has the pmid instances in the text
data = original_data.copy(deep=True)

In [None]:
original_data.sample()

In [None]:
# clean up text by removing the appended pmid and title abstract tags at the start of each section

# Matches "<pmid>|t|<title>\n<pmid>|a|<abstract>".
# Written as a raw string so the regex escapes (\d, \|, \w, \n) are passed to the
# regex engine untouched instead of relying on Python's deprecated handling of
# unknown string escapes.
pattern = r'(?:[\d]{1,10}\|t\|)(?P<title>[\w\W]+)(?:\n[\d]{1,20}\|a\|)(?P<abstract>[\w\W]+)'

def clean_text(text):
    """Strip the leading `<pmid>|t|` and `<pmid>|a|` tags from a document.

    Parameters
    ----------
    text : str
        Raw document text in the form "<pmid>|t|<title>\\n<pmid>|a|<abstract>".

    Returns
    -------
    str
        "<title>\\n<abstract>" when the tags are found; otherwise `text` is
        returned unchanged, so text that is already clean or has an unexpected
        layout does not abort processing of the whole dataset.
    """
    matches = re.search(pattern, text)
    if matches is None:
        # No title/abstract tags present: nothing to strip.
        return text
    return f'{matches.group("title")}\n{matches.group("abstract")}'

data['text'] = [clean_text(text) for text in data['text']]

In [None]:
data.sample()

In [None]:
len(data)


#### Step 1.2: Load prompt library

Prompt id structure:
`p_<index>_<task>_<language>_<output>`

TODO: Figure out `<guideline>_<paradigm>`

In [None]:
prompt_library = pd.read_json('prompts.json')

prompt_library

In [None]:
# TODO: remove this after testing, checking each prompt using the index
prompt_library = prompt_library.loc[[7]]

prompt_library


#### Step 1.3: Create data+prompt dataset

In [None]:
# Load in 10 examples from training data
text_training_data = pd.read_csv("../data/ner/genovardis_train_dev_test/train_text.tsv", sep='\t', header=0)
annotated_training_data = pd.read_csv('../data/ner/genovardis_train_dev_test/train_annotation.tsv', sep='\t', header=0)

text_training_data['text'] = [clean_text(text) for text in text_training_data['text']]

print(f"Text data len: {len(text_training_data)}, Ann len: {len(annotated_training_data)}")

In [None]:
text_training_data.head()

In [None]:
annotated_training_data.head()

In [None]:
# combine label, span into string
annotated_training_data['label-span'] = [f"{row['label']}\t{row['span']}" for _, row in annotated_training_data.iterrows()]

annotated_training_data

In [None]:
split_annotated_training_data = annotated_training_data.loc[:, ['pmid', 'label-span']]

len(split_annotated_training_data)

In [None]:
split_annotated_training_data = split_annotated_training_data.groupby('pmid')['label-span'].apply('\n'.join)

len(split_annotated_training_data)

In [None]:
split_annotated_training_data

In [None]:
merged_training_data = pd.merge(text_training_data, split_annotated_training_data, on="pmid")
merged_training_data['text-label-span'] = [f"Given an example text \"{row['text']}\", the output is: \"\nlabel\tspan\n{row['label-span']}\"" for _, row in merged_training_data.iterrows()]

examples_df = merged_training_data.loc[:, ['pmid', 'text-label-span']]

In [None]:
examples_df

In [None]:
examples_df.iloc[0]['text-label-span']

In [None]:
# Result shape: {'pmid': ..., 'prompts': [{'prompt_id': ..., 'prompt': ...}, ...]}
def embed_data_in_prompts(row_data):
    """Build one assembled prompt per `prompt_library` entry for a single document.

    `row_data` must provide 'pmid' and 'text'. Each prompt is the guideline,
    the few-shot examples, the instruction, the expected output description and
    the text template (with the document text substituted in), joined by blank
    lines in that order.
    """
    document_id = row_data['pmid']
    document_text = row_data['text']

    def assemble(template_row):
        # The first N rows of examples_df act as few-shot examples,
        # where N is the prompt's 'examples' value.
        shot_count = template_row['examples']
        few_shot_block = '\n'.join(examples_df.iloc[: shot_count]['text-label-span'])
        sections = [
            template_row['guideline'],
            few_shot_block,
            template_row['instruction'],
            template_row['expected_output'],
            template_row['text'].format(document_text),
        ]
        return {'prompt_id': template_row['prompt_id'], 'prompt': '\n\n'.join(sections)}

    assembled_prompts = [assemble(template_row) for _, template_row in prompt_library.iterrows()]
    return {'pmid': document_id, 'prompts': assembled_prompts}


In [None]:
embedded_prompt_data_list = [embed_data_in_prompts(row_data) for index, row_data in data.iterrows()]

In [None]:
print(embedded_prompt_data_list[0]['prompts'][0]['prompt'])


### Step 2: Setting up GPT-3.5

In [None]:

# Build the Azure OpenAI client from environment variables. load_dotenv() is
# called in the imports cell, which is where the `.env` values are expected to
# come from. A missing variable raises KeyError here.
client = AzureOpenAI(
    api_key=os.environ["API-KEY"],  
    api_version=os.environ["API-VERSION"],
    azure_endpoint=os.environ["ENDPOINT"]
    )
    
# Name of the Azure deployment; passed as `model` on every chat completion call.
deployment_name=os.environ["DEPLOYMENT-NAME"]


In [None]:
# Testing the connection
test_response = client.chat.completions.create(model=deployment_name, messages=[{"role": "user", "content": "Hello, World!"}])
print(test_response.choices[0].message.content)

In [None]:
# TODO: Ask Karin whether we should run again and again to see what gpt generates - yes! later!
results_list = []
def generate_results(prompt_items):
    """Send every prompt of one document to the model and record the replies.

    `prompt_items` is a dict with 'pmid' and 'prompts' (a list of dicts holding
    'prompt_id' and 'prompt'). One {'pmid', 'prompt_id', 'result'} dict per
    prompt is appended to the module-level `results_list`, which is also returned.
    """
    document_id = prompt_items['pmid']

    for entry in prompt_items['prompts']:
        current_prompt_id, prompt_text = entry['prompt_id'], entry['prompt']

        # TODO: Look into hyper params like temp
        chat_messages = [{"role": "user", "content": prompt_text}]
        completion = client.chat.completions.create(model=deployment_name, messages=chat_messages)

        results_list.append({
            'pmid': document_id,
            'prompt_id': current_prompt_id,
            'result': completion.choices[0].message.content,
        })

        # Progress marker: one block per completed request.
        print(f'Prompt_id:\n{current_prompt_id}\n\npmid:\n{document_id}\n----------\n')

    return results_list
    

In [None]:
for embedded_prompt_data in embedded_prompt_data_list:
    generate_results(embedded_prompt_data)

In [None]:
results_list

In [None]:
len(results_list)

### Step 3: Post-processing

In [None]:
# create df from results list and data df
# columns = pmid, prompt_id, filename, label, offset1, offset2, span
extracted_entity_results = pd.DataFrame(columns=['pmid','prompt_id','filename','label', 'offset_checked', 'offset1','offset2','span'])

In [None]:
len(extracted_entity_results)

In [None]:
# One result line: a known entity label, whitespace, then the entity span.
# Raw string so the regex escapes (\s, \w, \W) are not treated as (invalid)
# Python string escapes.
label_entity_pattern = r'^(?P<label>DNAMutation|SNP|DNAAllele|NucleotideChange-BaseChange|OtherMutation|Gene|Disease|Transcript)\s+(?P<span>[\w\W]+)$'

def extract_tuple(tuple_string):
    """Parse one model output line into a label/span pair.

    Parameters
    ----------
    tuple_string : str
        A single line such as "Gene\\tBRCA1". Surrounding whitespace is ignored.

    Returns
    -------
    dict or None
        {'label': <label>, 'span': <span>} when the line starts with one of the
        known labels followed by whitespace and a span; None for any other line
        (for example a header or free text).
    """
    matches = re.search(label_entity_pattern, tuple_string.strip())

    if not matches:
        return None

    return {
        'label': matches.group("label").strip(),
        'span': matches.group("span").strip(),
    }

In [None]:
# Split the combined GPT-3.5 result string into individual entities and
# append each one as a new row of the extracted_entity_results df.
def extract_ner_results(pmid, prompt_id, result_string):
    """Append one row per parsable result line to `extracted_entity_results`.

    Lines that `extract_tuple` cannot parse are skipped. Offsets are left empty
    and `offset_checked` is False. An empty or missing `result_string` adds
    nothing.
    """
    if not result_string:
        return

    parsed_entities = [
        entity
        for entity in (extract_tuple(line) for line in result_string.splitlines())
        if entity
    ]

    for entity in parsed_entities:
        # Name the annotation file after the document's text file.
        source_filename = data.loc[data['pmid'] == pmid, 'filename'].iloc[0]
        next_row_label = len(extracted_entity_results)
        extracted_entity_results.loc[next_row_label] = {
            "pmid": pmid,
            "prompt_id": prompt_id,
            "filename": source_filename.replace('txt', 'ann'),
            "label": entity['label'],
            "offset_checked": False,
            "offset1": '',
            "offset2": '',
            "span": entity['span'],
        }
    

In [None]:
# extract the concatenated results strings into a new line for each tuple 
for result_dict in results_list:
    extract_ner_results(result_dict['pmid'], result_dict['prompt_id'], result_dict['result'])


In [None]:
extracted_entity_results

In [None]:
len(extracted_entity_results)

In [None]:
# Find offsets
#
# Assign character offsets to every extracted entity by locating its span in
# the document text. Offsets are stored as strings; '-1' marks an entity that
# could not be placed in the text.

# loop df, find each span, calculate the word length, find the indexes of each occurrence
for _, row in extracted_entity_results.iterrows():
    pmid = row['pmid']
    prompt_id = row['prompt_id']
    # find the text from the original_data with the pmid
    # (original_data, not the cleaned `data`, so offsets are relative to the raw text)
    text = original_data.loc[original_data['pmid'] == pmid, 'text'].iloc[0]
    
    # Only start a pass for rows that have no offset yet.
    if not row['offset_checked'] and row['offset1'] == '':
        span = row['span']
        span_length = len(span)
        # Start index of every literal, case-sensitive occurrence of the span
        # (re.escape disables regex metacharacters in the span).
        span_start_indexes = [m.start() for m in re.finditer(re.escape(span), text)]
        span_count = 0
        
        # All still-unprocessed rows with the same document, prompt and span.
        # Rows already assigned by an earlier pass are excluded by the
        # offset1/offset_checked conditions, so each row is assigned at most once.
        matching_spans = extracted_entity_results[(extracted_entity_results['pmid']==pmid) & (extracted_entity_results['prompt_id']==prompt_id) & (extracted_entity_results['span']==span) & (extracted_entity_results['offset1']=='') & (extracted_entity_results['offset_checked']==False)]
        
        # The k-th matching row receives the k-th occurrence found in the text.
        for index, matched_span in matching_spans.iterrows(): 
            if span_start_indexes and span_count < len(span_start_indexes):
                extracted_entity_results.loc[index, 'offset1'] = str(span_start_indexes[span_count])
                extracted_entity_results.loc[index, 'offset2'] = str(span_start_indexes[span_count] + span_length)
                
                span_count = span_count + 1
            else: 
                # Add -1 to extra or missing ones
                # (span absent from the text, or more rows than occurrences)
                extracted_entity_results.loc[index, 'offset1'] = '-1'
                extracted_entity_results.loc[index, 'offset2'] = '-1'
                
            extracted_entity_results.loc[index, 'offset_checked'] = True
            
        # testing code
        # test_matching_spans = extracted_entity_results[(extracted_entity_results['pmid']==pmid) & (extracted_entity_results['prompt_id']==prompt_id) & (extracted_entity_results['span']==span)]
        # 
        # print(test_matching_spans)

In [None]:
extracted_entity_results

In [None]:
total_results = extracted_entity_results
len(extracted_entity_results)

In [None]:
# extract the hallucinations
hallucinated_results = extracted_entity_results[(extracted_entity_results['offset1'] == '-1') & (extracted_entity_results['offset2'] == '-1')]

In [None]:
# remove hallucinations
extracted_entity_results = extracted_entity_results[(extracted_entity_results['offset1'] != '-1') & (extracted_entity_results['offset2'] != '-1')]

In [None]:
extracted_entity_results

In [None]:
len(extracted_entity_results)


### Step 4: Evaluation

Evaluation log is found in `eval_log.tsv`

In [None]:
evaluation_log_filepath = "eval_log.tsv"
date = datetime.today().strftime('%Y-%m-%d %H:%M:%S')

In [None]:
if os.path.isfile(evaluation_log_filepath):
    eval_log_df=pd.read_csv(evaluation_log_filepath, sep='\t', header=0)
else:
    # TODO: Keep track of the variations between the runs eg: hyperparams (fixed), prompt that worked best etc. to add the metrics for result 
    # TODO: Keep track of the hallucinations, e.g. the number of entities found that got dropped because they can't be mapped to the text (confabulations)
    eval_log_df = pd.DataFrame(columns=['prompt_id', 'data_type', 'true_positive', 'false_positive', 'false_negative', 'precision', 'recall', 'f1', 'hallucination_count', 'total_result_count', 'date', 'notes'])

def update_eval_log(eval_prompt_id, hallucination_count, total_result_count):
    """Append one run record for `eval_prompt_id` to the module-level `eval_log_df`.

    The six metric positions are written as 0 placeholders. In the column list
    used when the log is created fresh, these are true_positive, false_positive,
    false_negative, precision, recall and f1. The values are assigned by
    position, not by column name.
    """
    next_row_label = len(eval_log_df.index)
    placeholder_metrics = [0, 0, 0, 0, 0, 0]
    run_record = (
        [eval_prompt_id, data_type]
        + placeholder_metrics
        + [hallucination_count, total_result_count, date, 'gpt-3.5-turbo-16k']
    )
    eval_log_df.loc[next_row_label] = run_record
        


### Step 5: Saving output

Save output files in the following forms:
1. `.tsv` file and `.zip` compressed folder containing the extracted entities in the following format: 
  `pmid   filename   label   offset1   offset2   span`.
2. `.ann` file (tab-separated) containing the gold standard annotations in the following BRAT format: 
  `mark   label offset1 offset2   span`.
3. `.ann` file (tab-separated) containing the extracted entities in the following BRAT format: 
  `mark   label offset1 offset2   span`.



In [None]:
# train_annotations.tsv
# pmid\tfilename\tmark\tlabel\toffset1\toffset2\tspan

# Read and find what other people have done 

# brat format for NER
# TODO: Ask Karin about the BRAT format correctedness
# <unique_id>   <label> <offset1> <offset2>   <span> 
def save_brat_output(brat, df_to_save=None, filename="./results/temp.tsv"):
    """Write entity rows to disk as tab-separated text.

    Parameters
    ----------
    brat : bool
        True writes BRAT-style rows `mark<TAB>label offset1 offset2<TAB>span`
        without a header to exactly `filename`. False writes the columns
        pmid, filename, label, offset1, offset2, span with a header to
        `filename` plus a `.tsv` suffix (the suffix is not added twice).
    df_to_save : pandas.DataFrame
        Rows to write. Must contain label, offset1, offset2 and span; the
        non-BRAT form also needs pmid and filename. The frame is not modified.
    filename : str
        Output path. Missing parent directories are created.

    Raises
    ------
    ValueError
        If `df_to_save` is None.
    """
    if df_to_save is None:
        raise ValueError("save_brat_output requires a DataFrame in df_to_save")

    if brat:
        # Reuse existing BRAT ids when present; otherwise derive them from the
        # row's index label (T<index + 1>).
        if 'mark' in df_to_save.columns:
            marks = df_to_save['mark'].tolist()
        else:
            marks = [f"T{row_label + 1}" for row_label in df_to_save.index]

        label_offsets = [
            f"{label} {offset1} {offset2}"
            for label, offset1, offset2 in zip(
                df_to_save['label'], df_to_save['offset1'], df_to_save['offset2']
            )
        ]

        # Build a fresh frame rather than adding columns to the caller's frame:
        # callers pass filtered slices, and a row-wise apply on an empty frame
        # does not produce a single column.
        formatted_df_to_save = pd.DataFrame({
            'mark': marks,
            'label-offsets': label_offsets,
            'span': df_to_save['span'].tolist(),
        })
        output_path = filename
    else:
        formatted_df_to_save = df_to_save.loc[:, ['pmid', 'filename', 'label', 'offset1', 'offset2', 'span']]
        output_path = filename if filename.endswith('.tsv') else f"{filename}.tsv"

    # to_csv does not create missing directories.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    formatted_df_to_save.to_csv(output_path, sep ='\t', index=False, header=not brat)

    if brat:
        print(f'----BRAT----\nOriginal data len: {len(df_to_save)}, Reformatted len: {len(formatted_df_to_save)}\n')
    else:
        print(f'Original data len: {len(df_to_save)}, Reformatted len: {len(formatted_df_to_save)}\n')
    print(f"Reformatted data:\n--------------------\n\n{formatted_df_to_save.head(5)}\n--------------------\n\n")
    

In [None]:
# Save gold annotations in BRAT format
# Runs only when a gold file path is configured and BRAT output is enabled.
if gold_annotation_filepath != "" and generate_brat_format:
    gold_annotations_df = pd.read_csv(gold_annotation_filepath, sep='\t', header=0)
    
    # Drop the existing ids so save_brat_output generates its own `mark` values.
    gold_annotations_df = gold_annotations_df.drop(['mark'], axis=1)
    
    # The same gold frame is written once per prompt id, under a filename that
    # mirrors the per-prompt prediction file in results/temp/eval/.
    for _, prompt in prompt_library.iterrows():
        prompt_id = prompt['prompt_id']
        gold_annotations_filename = f'results/temp/gold/{prompt_id}_{data_type}.ann'
        save_brat_output(True, gold_annotations_df, gold_annotations_filename)

In [None]:
# For every prompt: optionally export its predictions in BRAT format, add a row
# to the evaluation log, and save its predictions as a plain TSV.
for _, prompt in prompt_library.iterrows():
    prompt_id = prompt['prompt_id']
    # Predictions for this prompt, with hallucinated rows (offset '-1') already removed.
    results_subset = extracted_entity_results[(extracted_entity_results['prompt_id']==prompt_id)]
    
    # Save results in BRAT format
    if generate_brat_format:
        results_brat_filename = f'results/temp/eval/{prompt_id}_{data_type}.ann'
        save_brat_output(True, results_subset, results_brat_filename)
        
    # Update eval log
    # total_results is the frame from before the hallucinated rows were removed;
    # hallucinated_results holds only the rows whose offsets are '-1'.
    total_result_count = len(total_results[(total_results['prompt_id']==prompt_id)])
    hallucination_count = len(hallucinated_results[(hallucinated_results['prompt_id']==prompt_id)])
    update_eval_log(prompt_id, hallucination_count, total_result_count)
    
    # Save whole result output
    # No extension here: save_brat_output appends ".tsv" for non-BRAT output.
    results_filename = f'results/{prompt_id}_{data_type}'
    save_brat_output(False, results_subset, results_filename)
    

In [None]:
# Save eval_log file

# /Users/milindi/Documents/Honours/Projects/honos/results/brateval/gold
# /Users/milindi/Documents/Honours/Projects/honos/results/brateval/eval
eval_log_df.to_csv('eval_log.tsv', sep ='\t', index=False, header=True)

`mvn exec:java -Dexec.mainClass=au.com.nicta.csp.brateval.CompareEntities -Dexec.args="-e /Users/milindi/Documents/Honours/Projects/honos/results/brateval/eval -g /Users/milindi/Documents/Honours/Projects/honos/results/brateval/gold -s exact" -X`