In [3]:
import pandas as pd
import openai
import json
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # Import tqdm for the progress bar
from sklearn.metrics import accuracy_score, classification_report, balanced_accuracy_score
from sklearn.metrics import balanced_accuracy_score
import numpy as np

In [4]:
#pip install openai

In [5]:
from openai import OpenAI

In [6]:
openai.__version__


'1.14.0'

### Read in data

In [73]:
# Initialize an empty list to store the parsed data
data_list = []

# Read the JSONL file line by line
with open('../data/prodigy/annotated_output/pilot_500_pubmed_abstracts_shirin_correct_id.jsonl', 'r') as file:
    for line in file:
        # Parse each line as a JSON object
        data = json.loads(line.strip())
        
        # Check if _view_id is "choice"// should be review if this was coming from a review prodigy session
        if data.get("_view_id") == "choice":
            # Split text into journal_name, title, and abstract based on ^\n symbol
            text = data.get("text", "")
            
            # Splitting text based on ^\n for journal_name, title, and abstract
            parts = text.split("^\n", 2)  # Split into three parts based on first two ^\n occurrences
            
            journal_name = parts[0].strip() if len(parts) > 0 else ""
            title = parts[1].strip() if len(parts) > 1 else ""
            abstract = parts[2].strip() if len(parts) > 2 else ""
            
            # Extract other required fields
            pmid = data.get("pmid", "")
            accept = data.get("accept", [])
            
            # Append to the list
            data_list.append({
                "pmid": pmid,
                "journal_name": journal_name,
                "title": title,
                "abstract": abstract,
                "accepted_label": accept
            })

# Convert the list of dictionaries to a pandas DataFrame
df = pd.DataFrame(data_list)
df = df.explode('accepted_label')

# Combine the columns
# Implementing custom tags for the combination of journal name, title, and abstract
df['input_journal_title_abstract'] = '<journal>' + df['journal_name'] + '</journal>' + \
                                     '<title>' + df['title'] + '</title>' + \
                                     '<abstract>' + df['abstract'] + '</abstract>'

# Implementing custom tags for the combination of title and abstract only
df['input_title_abstract'] = '<title>' + df['title'] + '</title>' + \
                             '<abstract>' + df['abstract'] + '</abstract>'

In [74]:
df.head(5)

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label,input_journal_title_abstract,input_title_abstract
0,37550718,Trials,Can dexamethasone improve postoperative sleep ...,Perioperative sleep disorders (PSD) are an ind...,Human-RCT-drug-intervention,<journal>Trials</journal><title>Can dexamethas...,<title>Can dexamethasone improve postoperative...
1,2500373,Developmental medicine and child neurology,Effects of puberty on seizure frequency.,"Seizure frequency was documented before, durin...",Remaining,<journal>Developmental medicine and child neur...,<title>Effects of puberty on seizure frequency...
2,36189588,Journal of Alzheimer's disease : JAD,Characterization of Mild Cognitive Impairment ...,"Despite tremendous advancements in the field, ...",Remaining,<journal>Journal of Alzheimer's disease : JAD<...,<title>Characterization of Mild Cognitive Impa...
3,36314672,Journal of vector ecology : journal of the Soc...,Effects of woody plant encroachment by eastern...,Woody plant encroachment into grasslands is oc...,Remaining,<journal>Journal of vector ecology : journal o...,<title>Effects of woody plant encroachment by ...
4,29172241,Depression and anxiety,The impact of resilience and subsequent stress...,There remains a dearth of research examining t...,Remaining,<journal>Depression and anxiety</journal><titl...,<title>The impact of resilience and subsequent...


### Load key for the OpenAI API 

In [57]:
def load_pass(file_path, key_to_find):
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.strip().split("=")
            if len(parts) == 2 and parts[0] == key_to_find:
                found_password = parts[1]
                break
    if found_password:
        print("Found password.")
        return found_password
    else:
        print("Password not found for key:", key_to_find)

Note: You need to create a credentials.txt file with the following content:  
OPENAI=sk-77QXXXXXXXXXXXXXXXXXXXXXXXXXXX  
replace the value after the = sign with your API key.  
Make sure the credentials.txt is added to .gitignore, you don't want to put your password on Git!

In [58]:
openai.api_key = load_pass("../credentials.txt", "OPENAI")


Found password.


In [59]:
client = OpenAI(api_key=openai.api_key)

### Query GPT

To change the task the model is solving, you need to change the text of the prompt and the content text of the system role.  
To change the GPT model used, you need to change the text of the model name when initiating the openai API.  
The function gets as input the input_raw_text, that will be text for information extraction or classification.

In [75]:
import time

DEFAULT_TEMPERATURE = 0
DEFAULT_MAX_TOKENS = 500
DEFAULT_MODEL = "gpt-3.5-turbo"

def query_gpt(input_raw_text, prompt_text, gpt_model="gpt-3.5-turbo", temperature=0, max_retries=5, retry_delay=3):
    # CHANGE gpt_model to the desired model name, see https://platform.openai.com/docs/models (gpt-3.5-turbo and gpt-4-turbo-preview)
    
    # Add a delay at the beginning of the function to avoid overloading the API if there are multiple calls
    # time.sleep(10)  

    system_msg = f"""
    You are an expert assistant specialized in text classification of PubMed abstracts. """

    retries = 0
    while retries < max_retries:
        print("Trying to call OpenAI API...")
        try:
            completion = client.chat.completions.create(
                model=gpt_model,  
                response_format={"type": "json_object"},
                temperature=temperature,
                #max_tokens=2000,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt_text + input_raw_text}
                ]
            )
            return completion.choices[0].message.content
        except Exception as e:
            # Handle API error, e.g., retry or log
            print(f"OpenAI API returned an error: {e}")
            time.sleep(retry_delay)  # Wait before retrying
            retries += 1

    raise RuntimeError("Max retries reached. Unable to complete the API call.")


In [76]:
# Define a function to apply GPT queries with a progress bar
def apply_gpt_with_progress(data_series, prompt_text, model="gpt-3.5-turbo"):
    results = []
    total_items = len(data_series)
    # Create a tqdm progress bar
    with tqdm(total=total_items, desc=f"Processing dataset") as pbar:
        for text in data_series:
            result = query_gpt(text, prompt_text, model)
            results.append(result)
            pbar.update(1)  # Update the progress bar

    return results

# Read prompts from file and query GPT

In [77]:
json_file_path = "./prompt_strategies.json"
# Load the JSON file
with open(json_file_path, 'r') as file:
    prompts_data = json.load(file)

In [78]:
sampled_df = df.sample(n=10, random_state=1)
sampled_df.head()

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label,input_journal_title_abstract,input_title_abstract
304,12905582,Fa yi xue za zhi,[The research of the heroin and its metabolite...,Heroin can be metabolized easily in body and t...,Remaining,<journal>Fa yi xue za zhi</journal><title>[The...,<title>[The research of the heroin and its met...
340,28645717,Vaccine,Surveillance of pneumococcal colonization and ...,Following the introduction of pneumococcal con...,Human-non-RCT-drug-intervention,<journal>Vaccine</journal><title>Surveillance ...,<title>Surveillance of pneumococcal colonizati...
47,11482695,Acta neurochirurgica,Carotid endarterectomy: a new technique replac...,Carotid endarterectomy has been reported to in...,Human-non-RCT-non-drug-intervention,<journal>Acta neurochirurgica</journal><title>...,<title>Carotid endarterectomy: a new technique...
67,15065953,Journal of consulting and clinical psychology,Traditional versus integrative behavioral coup...,A randomized clinical trial compared the effec...,Human-RCT-non-drug-intervention,<journal>Journal of consulting and clinical ps...,<title>Traditional versus integrative behavior...
479,9578881,The Journal of laryngology and otology,Vocal fold abductor paralysis as a solitary an...,A patient is presented who had bilateral abduc...,Remaining,<journal>The Journal of laryngology and otolog...,<title>Vocal fold abductor paralysis as a soli...


## Run different prompts over the data

In [81]:
# Add the IDs of Prompts that you want to test
prompt_ids_to_test = ["P1", "P2"] #, "P2", "P3", "P4"
model = "gpt-3.5-turbo"

for prompt in prompts_data["prompts"]:
    prompt_id = prompt["id"]
    prompt_text = prompt["text"]

    if prompt_id in prompt_ids_to_test:
        # Apply GPT predictions
        sampled_df[f'gpt_predictions_{prompt_id}_raw'] = apply_gpt_with_progress(sampled_df['input_journal_title_abstract'], prompt_text, model)
        ## the below includes error handling in case the json formatting did not work as expected
        sampled_df[f'gpt_predictions_{prompt_id}'] = sampled_df[f'gpt_predictions_{prompt_id}_raw'].apply(
            lambda x: json.loads(x)['gpt_label'] if isinstance(x, str) and 'gpt_label' in json.loads(x) else x
        )
        sampled_df.to_csv(f"predictions/{model}_outputs_{'_'.join(prompt_ids_to_test)}.csv") # saving after each prompt strategy has been run, in order not to loose information in case a later strategy fails
    else:
        print(f"Skipping prompt {prompt_id}")
        

Processing dataset:   0%|                                                                                                                                                                    | 0/10 [00:00<?, ?it/s]

Trying to call OpenAI API...


Processing dataset:  10%|███████████████▌                                                                                                                                            | 1/10 [00:01<00:11,  1.27s/it]

Trying to call OpenAI API...


Processing dataset:  20%|███████████████████████████████▏                                                                                                                            | 2/10 [00:01<00:07,  1.07it/s]

Trying to call OpenAI API...


Processing dataset:  30%|██████████████████████████████████████████████▊                                                                                                             | 3/10 [00:02<00:06,  1.03it/s]

Trying to call OpenAI API...


Processing dataset:  40%|██████████████████████████████████████████████████████████████▍                                                                                             | 4/10 [00:04<00:05,  1.01it/s]

Trying to call OpenAI API...


Processing dataset:  50%|██████████████████████████████████████████████████████████████████████████████                                                                              | 5/10 [00:04<00:04,  1.07it/s]

Trying to call OpenAI API...


Processing dataset:  60%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                              | 6/10 [00:05<00:03,  1.16it/s]

Trying to call OpenAI API...


Processing dataset:  70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                              | 7/10 [00:06<00:02,  1.29it/s]

Trying to call OpenAI API...


Processing dataset:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 8/10 [00:06<00:01,  1.26it/s]

Trying to call OpenAI API...


Processing dataset:  90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍               | 9/10 [00:07<00:00,  1.30it/s]

Trying to call OpenAI API...


Processing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.17it/s]
Processing dataset:   0%|                                                                                                                                                                    | 0/10 [00:00<?, ?it/s]

Trying to call OpenAI API...


Processing dataset:  10%|███████████████▌                                                                                                                                            | 1/10 [00:01<00:10,  1.12s/it]

Trying to call OpenAI API...


Processing dataset:  20%|███████████████████████████████▏                                                                                                                            | 2/10 [00:01<00:07,  1.10it/s]

Trying to call OpenAI API...


Processing dataset:  30%|██████████████████████████████████████████████▊                                                                                                             | 3/10 [00:02<00:06,  1.11it/s]

Trying to call OpenAI API...


Processing dataset:  40%|██████████████████████████████████████████████████████████████▍                                                                                             | 4/10 [00:03<00:05,  1.17it/s]

Trying to call OpenAI API...


Processing dataset:  50%|██████████████████████████████████████████████████████████████████████████████                                                                              | 5/10 [00:04<00:04,  1.04it/s]

Trying to call OpenAI API...


Processing dataset:  60%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                              | 6/10 [00:05<00:04,  1.02s/it]

Trying to call OpenAI API...


Processing dataset:  70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                              | 7/10 [00:06<00:02,  1.05it/s]

Trying to call OpenAI API...


Processing dataset:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 8/10 [00:07<00:01,  1.11it/s]

Trying to call OpenAI API...


Processing dataset:  90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍               | 9/10 [00:08<00:00,  1.02it/s]

Trying to call OpenAI API...


Processing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.09it/s]

Skipping prompt P3
Skipping prompt P4
Skipping prompt P5





In [82]:
sampled_df.head()

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label,input_journal_title_abstract,input_title_abstract,gpt_predictions_P1,gpt_predictions_P2,gpt_predictions_P1_raw,gpt_predictions_P2_raw
304,12905582,Fa yi xue za zhi,[The research of the heroin and its metabolite...,Heroin can be metabolized easily in body and t...,Remaining,<journal>Fa yi xue za zhi</journal><title>[The...,<title>[The research of the heroin and its met...,Human-non-RCT-non-drug-intervention,Human-case-report,"{\n ""gpt_label"": ""Human-non-RCT-non-drug-in...","{\n ""gpt_label"": ""Human-case-report""\n}"
340,28645717,Vaccine,Surveillance of pneumococcal colonization and ...,Following the introduction of pneumococcal con...,Human-non-RCT-drug-intervention,<journal>Vaccine</journal><title>Surveillance ...,<title>Surveillance of pneumococcal colonizati...,Human-systematic-review,Human-systematic-review,"{\n ""gpt_label"": ""Human-systematic-review""\n}","{\n ""gpt_label"": ""Human-systematic-review""\n}"
47,11482695,Acta neurochirurgica,Carotid endarterectomy: a new technique replac...,Carotid endarterectomy has been reported to in...,Human-non-RCT-non-drug-intervention,<journal>Acta neurochirurgica</journal><title>...,<title>Carotid endarterectomy: a new technique...,Human-case-report,Human-case-report,"{\n ""gpt_label"": ""Human-case-report""\n}","{\n ""gpt_label"": ""Human-case-report""\n}"
67,15065953,Journal of consulting and clinical psychology,Traditional versus integrative behavioral coup...,A randomized clinical trial compared the effec...,Human-RCT-non-drug-intervention,<journal>Journal of consulting and clinical ps...,<title>Traditional versus integrative behavior...,Human-RCT-non-drug-intervention,Human-RCT-non-drug-intervention,"{\n ""gpt_label"": ""Human-RCT-non-drug-interv...","{\n ""gpt_label"": ""Human-RCT-non-drug-interv..."
479,9578881,The Journal of laryngology and otology,Vocal fold abductor paralysis as a solitary an...,A patient is presented who had bilateral abduc...,Remaining,<journal>The Journal of laryngology and otolog...,<title>Vocal fold abductor paralysis as a soli...,Human-case-report,Human-case-report,"{\n ""gpt_label"": ""Human-case-report""\n}","{\n ""gpt_label"": ""Human-case-report""\n}"


## Evaluate each prompt

In [83]:
labels = ["Human-systematic-review", "Human-RCT-drug-intervention", "Human-RCT-non-drug-intervention", "Human-RCT-non-intervention", "Human-case-report", "Human-non-RCT-drug-intervention", "Human-non-RCT-non-drug-intervention", "Animal-systematic-review", "Animal-drug-intervention", "Animal-non-drug-intervention", "Animal-other", "Non-systematic-review", "In-vitro-study", "Clinical-study-protocol", "Remaining"]

label_to_numerical = {label: i for i, label in enumerate(labels)}
label_to_numerical["label missing"] = -1

In [84]:
def map_label_to_numerical(label):
    # Check if label is a dictionary
    if isinstance(label, dict):
        # Extract the label with the highest score/probability
        highest_label = max(label, key=label.get)
        return label_to_numerical.get(highest_label, -1)
    else:
        # Directly map string labels to numerical IDs
        return label_to_numerical.get(label, -1)
        
# Convert accepted labels to numerical
sampled_df['accepted_label_numerical'] = sampled_df['accepted_label'].apply(lambda x: label_to_numerical.get(x, -1))

# Initialize a list to hold DataFrame for each report and summary statistics
report_dfs = []
summary_stats = []

# Iterate over each GPT prediction column
for prompt_id in prompt_ids_to_test:
    print("Evaluating ", prompt_id)
    prediction_col = f'gpt_predictions_{prompt_id}'
    
    # Map GPT predictions to numerical values
    sampled_df[f'{prediction_col}_numerical'] = sampled_df[prediction_col].apply(map_label_to_numerical)
    
    # Extract arrays for evaluation
    y_true = sampled_df['accepted_label_numerical'].values
    y_pred = sampled_df[f'{prediction_col}_numerical'].values
    
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_balanced = balanced_accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0, labels=range(len(labels)), target_names=labels)
    
    # Create DataFrame from report
    report_df = pd.DataFrame(report).transpose()
    report_df['Prompt ID'] = prompt_id  # Add column to indicate the prompt ID
    report_dfs.append(report_df)
    
    # Extract summary statistics (average precision, recall, F1)
    summary = report_df.loc['weighted avg', ['precision', 'recall', 'f1-score']].to_dict()
    summary['Prompt ID'] = prompt_id
    summary_stats.append(summary)

# Combine all report DataFrames
all_reports_df = pd.concat(report_dfs)

# Create a summary table for average precision, recall, and F1-score
summary_df = pd.DataFrame(summary_stats)

Evaluating  P1
Evaluating  P2




In [85]:
summary_df

Unnamed: 0,precision,recall,f1-score,Prompt ID
0,0.1,0.1,0.1,P1
1,0.2,0.2,0.2,P2


In [86]:
all_reports_df.to_csv(f"evaluations/{model}_per_class_{'_'.join(prompt_ids_to_test)}.csv") # saving after each prompt strategy has been run, in order not to loose information in case a later strategy fails
summary_df.to_csv(f"evaluations/{model}_summary_{'_'.join(prompt_ids_to_test)}.csv") # saving after each prompt strategy has been run, in order not to loose information in case a later strategy fails


In [87]:
all_reports_df

Unnamed: 0,precision,recall,f1-score,support,Prompt ID
Human-systematic-review,0.0,0.0,0.0,0.0,P1
Human-RCT-drug-intervention,0.0,0.0,0.0,0.0,P1
Human-RCT-non-drug-intervention,1.0,1.0,1.0,1.0,P1
Human-RCT-non-intervention,0.0,0.0,0.0,0.0,P1
Human-case-report,0.0,0.0,0.0,0.0,P1
Human-non-RCT-drug-intervention,0.0,0.0,0.0,1.0,P1
Human-non-RCT-non-drug-intervention,0.0,0.0,0.0,1.0,P1
Animal-systematic-review,0.0,0.0,0.0,0.0,P1
Animal-drug-intervention,0.0,0.0,0.0,0.0,P1
Animal-non-drug-intervention,0.0,0.0,0.0,0.0,P1
