In [None]:
import os
import time
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")

import pandas as pd
from tqdm import tqdm
from typing import List
from utils import html_parsing_ncbi, html_parsing_n2c2, get_classification_report

# 1. NER (Named Entity Recognition)

## 1.1 NCBI-Disease Dataset

### 1.1.1 Inference

In [None]:
# Load the NCBI-disease test CSV into a DataFrame; the inference cell below reads its 'text' column.
ncbi_df = pd.read_csv('data/NER/NCBI-disease/test_200.csv')

In [None]:
def get_ner_ncbi_disease(sentence: str, gpt4: bool = False, shot: int = 0,
                         max_retries: int = 10, retry_delay: float = 5.0) -> "tuple[str, float]":
    """
        Get NER zero/one-shot prediction from GPT-3.5 or GPT-4 given a sentence in NCBI-disease dataset.
        Input:
            sentence: a string of sentence
            gpt4: whether to use GPT-4 (True) or GPT-3.5 (False)
            shot: zero-shot (0) or one-shot (1); any value other than 1 sends no example
            max_retries: maximum number of API attempts before giving up
            retry_delay: seconds to wait between two failed attempts
        Output:
            a tuple (html, seconds):
                html: a HTML string that highlights all the disease entities in the sentence
                seconds: wall-clock duration of the successful API call
        Raises:
            SystemExit: if every attempt fails
    """

    # Adjacent literals are joined explicitly with a single space. Backslash continuations inside
    # a string literal would embed the next line's indentation into the prompt.
    system_prompt = (
        "TASK: the task is to extract disease entities in a sentence. "
        "INPUT: the input is a sentence. "
        "OUTPUT: the output is an HTML that highlights all the disease entities in the sentence. "
        "The highlighting should only use HTML tags <span style=\"background-color: #FFFF00\"> and </span> and no other tags."
    )
    prompt = [{"role": "system", "content": system_prompt}]

    if shot == 1: # the example is from NCBI-disease dataset's training split
        example_input = (
            "In summary , inactivation of the murine ATP7B gene produces a form of cirrhotic liver disease "
            "that resembles Wilson disease in humans and the toxic milk phenotype in the mouse . ."
        )
        # The expected answer must be the input verbatim with only the <span> tags added.
        example_output = (
            'In summary , inactivation of the murine ATP7B gene produces a form of '
            '<span style="background-color: #FFFF00">cirrhotic liver disease</span> '
            'that resembles <span style="background-color: #FFFF00">Wilson disease</span> '
            'in humans and the toxic milk phenotype in the mouse . .'
        )
        prompt.append({"role": "user", "content": example_input})
        prompt.append({"role": "assistant", "content": example_output})

    prompt.append({"role": "user", "content": sentence})

    gpt = "gpt-4-0613" if gpt4 else "gpt-3.5-turbo-0613"

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            time_start = time.perf_counter() # monotonic clock, unaffected by system clock changes
            response = openai.ChatCompletion.create(
                model = gpt,
                messages = prompt,
                temperature = 0.0, # deterministic
            )
            time_end = time.perf_counter()
            # Parsing stays inside the try so that a malformed response is retried as well.
            return response['choices'][0]['message']['content'], time_end - time_start
        except Exception as error: # not a bare except: KeyboardInterrupt must still stop the cell
            last_error = error
            remaining = max_retries - attempt
            if remaining == 0:
                break # no point in sleeping after the final attempt
            print(f"Request failed ({error!r}). Retrying... {remaining} retries left")
            time.sleep(retry_delay)

    raise SystemExit("Max retries exceeded, exiting program") from last_error

In [None]:
# Query every model/shot combination for each sentence and store the returned HTML and the
# request duration in per-configuration columns. Order matches the column creation order:
# GPT-3.5 zero-shot, GPT-4 zero-shot, GPT-3.5 one-shot, GPT-4 one-shot.
ncbi_settings = (
    ('gpt3.5_zero_shot', False, 0),
    ('gpt4_zero_shot', True, 0),
    ('gpt3.5_one_shot', False, 1),
    ('gpt4_one_shot', True, 1),
)
for row in tqdm(range(len(ncbi_df))):
    for tag, use_gpt4, n_shot in ncbi_settings:
        html, seconds = get_ner_ncbi_disease(ncbi_df.iloc[row]['text'], gpt4=use_gpt4, shot=n_shot)
        ncbi_df.loc[row, f'html_{tag}'] = html
        ncbi_df.loc[row, f'{tag}_time'] = seconds

### 1.1.2 Evaluation

In [None]:
# Parse the HTML predictions into label columns. The first call also supplies the ground-truth
# labels; the ground-truth half of the remaining calls is discarded.
gt_labels, first_pred_labels = html_parsing_ncbi(ncbi_df, 'html_gpt3.5_zero_shot')
ncbi_df['gt_labels'] = gt_labels
ncbi_df['gpt3.5_zero_shot_labels'] = first_pred_labels
for setting in ('gpt4_zero_shot', 'gpt3.5_one_shot', 'gpt4_one_shot'):
    _, ncbi_df[f'{setting}_labels'] = html_parsing_ncbi(ncbi_df, f'html_{setting}')

In [None]:
# Report for the GPT-4 one-shot labels against 'gt_labels' in 'strict' mode
# (the matching rules are implemented by get_classification_report in utils, not shown here).
get_classification_report(ncbi_df, 'gt_labels', 'gpt4_one_shot_labels', 'strict')

In [None]:
# Same comparison in 'lenient' mode (presumably partial-overlap matching; confirm in utils).
get_classification_report(ncbi_df, 'gt_labels', 'gpt4_one_shot_labels', 'lenient')

In [None]:
# Mean request duration per model/shot configuration, as recorded by the inference cell.
for model_name, shot_name, time_column in (
    ('GPT-3.5', 'zero-shot', 'gpt3.5_zero_shot_time'),
    ('GPT-4', 'zero-shot', 'gpt4_zero_shot_time'),
    ('GPT-3.5', 'one-shot', 'gpt3.5_one_shot_time'),
    ('GPT-4', 'one-shot', 'gpt4_one_shot_time'),
):
    mean_seconds = ncbi_df[time_column].mean()
    print(f"Average {model_name} {shot_name} prediction time: {mean_seconds:.2f} seconds")

In [None]:
# Save the inference results: write the DataFrame, including every column added by the cells
# above, to CSV without the index column.
ncbi_df.to_csv('data/NER/NCBI-disease/test_200_results.csv', index=False)

## 1.2 2018 n2c2 Dataset

### 1.2.1 Inference

In [None]:
# Load the 2018 n2c2 test CSV into a DataFrame; the inference cell below reads its 'text' column.
n2c2_df = pd.read_csv('data/NER/2018_n2c2/test_200.csv')

In [None]:
def get_ner_2018_n2c2(sentence: str, gpt4: bool = False, shot: int = 0,
                      max_retries: int = 10, retry_delay: float = 5.0) -> "tuple[str, float]":
    """
        Get NER zero/one-shot prediction from GPT-3.5 or GPT-4 given a sentence in 2018 n2c2 dataset.
        Input:
            sentence: a string of sentence
            gpt4: whether to use GPT-4 (True) or GPT-3.5 (False)
            shot: zero-shot (0) or one-shot (1); any value other than 1 sends no example
            max_retries: maximum number of API attempts before giving up
            retry_delay: seconds to wait between two failed attempts
        Output:
            a tuple (html, seconds):
                html: a HTML string that highlights all the medication-related entities in the
                      sentence, one background color per entity type
                seconds: wall-clock duration of the successful API call
        Raises:
            SystemExit: if every attempt fails
    """

    # Adjacent literals are joined explicitly with a single space. Backslash continuations inside
    # a string literal would embed the next line's indentation into the prompt.
    # The wording says "medication-related entities" because the listed types are not diseases.
    system_prompt = (
        "TASK: the task is to extract medication-related entities in a sentence. "
        "The entity type includes Form, Route, Frequency, Dosage, Strength, Duration, Reason, Ade, Drug. "
        "INPUT: the input is a sentence. "
        "OUTPUT: the output is an HTML that highlights all the medication-related entities in the sentence in different colors: "
        "Form(#FF0000), Route(#FFA500), Frequency(#FFFF00), Dosage(#00FF00), Strength(#0000FF), "
        "Duration(#800080), Reason(#FFC0CB), Ade(#964B00), Drug(#808080) in hex code. "
        "The highlighting should only use HTML tags <span style=\"background-color: #XXXXXX\"> and </span> and no other tags."
    )
    prompt = [{"role": "system", "content": system_prompt}]

    if shot == 1:
        example_input = "Vitamin D 400 unit Tablet Sig : Two ( 2 ) Tablet PO once a day ."
        # The expected answer is the input verbatim with only the colored <span> tags added.
        example_output = (
            '<span style="background-color: #808080">Vitamin D</span> '
            '<span style="background-color: #0000FF">400 unit</span> '
            '<span style="background-color: #FF0000">Tablet</span> Sig : '
            '<span style="background-color: #00FF00">Two ( 2 )</span> '
            '<span style="background-color: #FF0000">Tablet</span> '
            '<span style="background-color: #FFA500">PO</span> '
            '<span style="background-color: #FFFF00">once a day</span> .'
        )
        prompt.append({"role": "user", "content": example_input})
        prompt.append({"role": "assistant", "content": example_output})

    prompt.append({"role": "user", "content": sentence})

    gpt = "gpt-4-0613" if gpt4 else "gpt-3.5-turbo-0613"

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            time_start = time.perf_counter() # monotonic clock, unaffected by system clock changes
            response = openai.ChatCompletion.create(
                model = gpt,
                messages = prompt,
                temperature = 0.0, # deterministic
            )
            time_end = time.perf_counter()
            # Parsing stays inside the try so that a malformed response is retried as well.
            return response['choices'][0]['message']['content'], time_end - time_start
        except Exception as error: # not a bare except: KeyboardInterrupt must still stop the cell
            last_error = error
            remaining = max_retries - attempt
            if remaining == 0:
                break # no point in sleeping after the final attempt
            print(f"Request failed ({error!r}). Retrying... {remaining} retries left")
            time.sleep(retry_delay)

    raise SystemExit("Max retries exceeded, exiting program") from last_error

In [None]:
# Query every model/shot combination for each sentence and store the returned HTML and the
# request duration in per-configuration columns. Order matches the column creation order:
# GPT-3.5 zero-shot, GPT-4 zero-shot, GPT-3.5 one-shot, GPT-4 one-shot.
n2c2_settings = (
    ('gpt3.5_zero_shot', False, 0),
    ('gpt4_zero_shot', True, 0),
    ('gpt3.5_one_shot', False, 1),
    ('gpt4_one_shot', True, 1),
)
for row in tqdm(range(len(n2c2_df))):
    for tag, use_gpt4, n_shot in n2c2_settings:
        html, seconds = get_ner_2018_n2c2(n2c2_df.iloc[row]['text'], gpt4=use_gpt4, shot=n_shot)
        n2c2_df.loc[row, f'html_{tag}'] = html
        n2c2_df.loc[row, f'{tag}_time'] = seconds

### 1.2.2 Evaluation

In [None]:
# Parse the HTML predictions into label columns. The first call also supplies the ground-truth
# labels; the ground-truth half of the remaining calls is discarded.
gt_labels, first_pred_labels = html_parsing_n2c2(n2c2_df, 'html_gpt3.5_zero_shot')
n2c2_df['gt_labels'] = gt_labels
n2c2_df['gpt3.5_zero_shot_labels'] = first_pred_labels
for setting in ('gpt4_zero_shot', 'gpt3.5_one_shot', 'gpt4_one_shot'):
    _, n2c2_df[f'{setting}_labels'] = html_parsing_n2c2(n2c2_df, f'html_{setting}')

In [None]:
# Report for the GPT-4 one-shot labels against 'gt_labels' in 'strict' mode
# (the matching rules are implemented by get_classification_report in utils, not shown here).
get_classification_report(n2c2_df, 'gt_labels', 'gpt4_one_shot_labels', 'strict')

In [None]:
# Same comparison in 'lenient' mode (presumably partial-overlap matching; confirm in utils).
get_classification_report(n2c2_df, 'gt_labels', 'gpt4_one_shot_labels', 'lenient')

In [None]:
# Mean request duration per model/shot configuration, as recorded by the inference cell.
for model_name, shot_name, time_column in (
    ('GPT-3.5', 'zero-shot', 'gpt3.5_zero_shot_time'),
    ('GPT-4', 'zero-shot', 'gpt4_zero_shot_time'),
    ('GPT-3.5', 'one-shot', 'gpt3.5_one_shot_time'),
    ('GPT-4', 'one-shot', 'gpt4_one_shot_time'),
):
    mean_seconds = n2c2_df[time_column].mean()
    print(f"Average {model_name} {shot_name} prediction time: {mean_seconds:.2f} seconds")

In [None]:
# Save the inference results: write the DataFrame, including every column added by the cells
# above, to CSV without the index column.
n2c2_df.to_csv('data/NER/2018_n2c2/test_200_results.csv', index=False)