In [1]:
import pandas as pd
import openai
import json
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # Import tqdm for the progress bar
from sklearn.metrics import accuracy_score, classification_report, balanced_accuracy_score
from sklearn.metrics import balanced_accuracy_score
import numpy as np

In [2]:
#pip install openai

In [3]:
from openai import OpenAI

In [4]:
openai.__version__


'1.1.1'

### Read in annotated data

In [13]:
df = pd.read_csv("../data/prodigy/annotated_output/final/full_combined_dataset_1997.csv", index_col=0)

In [14]:
df.head()

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label
0,18379746,Der Radiologe,[Pediatric stroke].,Stroke in childhood has gained increasingly mo...,Non-systematic-review
1,24660674,Journal of consulting and clinical psychology,Treatment engagement and response to CBT among...,"In the current study, we compared measures of ...",Human-RCT-non-drug-intervention
2,20159133,Archives of physical medicine and rehabilitation,Relationship between perceived exertion and ph...,To investigate the strength of the relationshi...,Remaining
3,11781147,Biochimica et biophysica acta,Characterization of a missense mutation at his...,Genetic defects in pyruvate dehydrogenase comp...,Remaining
4,31706919,Epilepsy & behavior : E&B,The role of P-glycoprotein (P-gp) and inwardly...,Sudden unexpected death in epilepsy (SUDEP) is...,Non-systematic-review


In [15]:
# Combine the columns
# Implementing custom tags for the combination of journal name, title, and abstract
df['input_journal_title_abstract'] = '<journal>' + df['journal_name'] + '</journal>' + \
                                     '<title>' + df['title'] + '</title>' + \
                                     '<abstract>' + df['abstract'] + '</abstract>'

# Implementing custom tags for the combination of title and abstract only
df['input_title_abstract'] = '<title>' + df['title'] + '</title>' + \
                             '<abstract>' + df['abstract'] + '</abstract>'

In [16]:
df.head(5)

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label,input_journal_title_abstract,input_title_abstract
0,18379746,Der Radiologe,[Pediatric stroke].,Stroke in childhood has gained increasingly mo...,Non-systematic-review,<journal>Der Radiologe</journal><title>[Pediat...,<title>[Pediatric stroke].</title><abstract>St...
1,24660674,Journal of consulting and clinical psychology,Treatment engagement and response to CBT among...,"In the current study, we compared measures of ...",Human-RCT-non-drug-intervention,<journal>Journal of consulting and clinical ps...,<title>Treatment engagement and response to CB...
2,20159133,Archives of physical medicine and rehabilitation,Relationship between perceived exertion and ph...,To investigate the strength of the relationshi...,Remaining,<journal>Archives of physical medicine and reh...,<title>Relationship between perceived exertion...
3,11781147,Biochimica et biophysica acta,Characterization of a missense mutation at his...,Genetic defects in pyruvate dehydrogenase comp...,Remaining,<journal>Biochimica et biophysica acta</journa...,<title>Characterization of a missense mutation...
4,31706919,Epilepsy & behavior : E&B,The role of P-glycoprotein (P-gp) and inwardly...,Sudden unexpected death in epilepsy (SUDEP) is...,Non-systematic-review,<journal>Epilepsy & behavior : E&B</journal><t...,<title>The role of P-glycoprotein (P-gp) and i...


### Load key for the OpenAI API 

In [17]:
def load_pass(file_path, key_to_find):
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.strip().split("=")
            if len(parts) == 2 and parts[0] == key_to_find:
                found_password = parts[1]
                break
    if found_password:
        print("Found password.")
        return found_password
    else:
        print("Password not found for key:", key_to_find)

Note: You need to create a credentials.txt file with the following content:  
OPENAI=sk-77QXXXXXXXXXXXXXXXXXXXXXXXXXXX  
replace the value after the = sign with your API key.  
Make sure the credentials.txt is added to .gitignore, you don't want to put your password on Git!

In [18]:
openai.api_key = load_pass("./credentials.txt", "OPENAI")


Found password.


In [19]:
client = OpenAI(api_key=openai.api_key)

### Query GPT

To change the task the model is solving, you need to change the text of the prompt and the content text of the system role.  
To change the GPT model used, you need to change the text of the model name when initiating the openai API.  
The function gets as input the input_raw_text, that will be text for information extraction or classification.

In [20]:
import time

DEFAULT_TEMPERATURE = 0
DEFAULT_MAX_TOKENS = 500
DEFAULT_MODEL = "gpt-3.5-turbo"

def query_gpt(input_raw_text, prompt_text, gpt_model="gpt-3.5-turbo", temperature=0, max_retries=5, retry_delay=3):
    # CHANGE gpt_model to the desired model name, see https://platform.openai.com/docs/models (gpt-3.5-turbo and gpt-4-turbo-preview)
    
    # Add a delay at the beginning of the function to avoid overloading the API if there are multiple calls
    # time.sleep(10)  

    system_msg = f"""
    You are an expert assistant specialized in text classification of PubMed abstracts. """

    retries = 0
    while retries < max_retries:
        print("Trying to call OpenAI API...")
        try:
            completion = client.chat.completions.create(
                model=gpt_model,  
                response_format={"type": "json_object"},
                temperature=temperature,
                #max_tokens=2000,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt_text + input_raw_text}
                ]
            )
            return completion.choices[0].message.content
        except Exception as e:
            # Handle API error, e.g., retry or log
            print(f"OpenAI API returned an error: {e}")
            time.sleep(retry_delay)  # Wait before retrying
            retries += 1

    raise RuntimeError("Max retries reached. Unable to complete the API call.")


In [21]:
# Define a function to apply GPT queries with a progress bar
def apply_gpt_with_progress(data_series, prompt_text, model="gpt-3.5-turbo"):
    results = []
    total_items = len(data_series)
    # Create a tqdm progress bar
    with tqdm(total=total_items, desc=f"Processing dataset") as pbar:
        for text in data_series:
            result = query_gpt(text, prompt_text, model)
            results.append(result)
            pbar.update(1)  # Update the progress bar

    return results

# Read prompts from file and query GPT

In [22]:
json_file_path = "./prompt_strategies.json"
# Load the JSON file
with open(json_file_path, 'r') as file:
    prompts_data = json.load(file)

In [24]:
# Sample 50 elements randomly in a reproducible way
sampled_df = df.sample(n=50, random_state=1)
sampled_df.head() #show first few lines of sympampled_df

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label,input_journal_title_abstract,input_title_abstract
608,18540779,Experimental and clinical psychopharmacology,Triacetyluridine (TAU) decreases depressive sy...,Eleven patients with bipolar depression were g...,Human-non-RCT-drug-intervention,<journal>Experimental and clinical psychopharm...,<title>Triacetyluridine (TAU) decreases depres...
1695,29052307,Developmental science,Instrumental learning and cognitive flexibilit...,Children who experience severe early life stre...,Remaining,<journal>Developmental science</journal><title...,<title>Instrumental learning and cognitive fle...
790,23337006,Pediatric neurology,Progressive intracranial fusiform aneurysms an...,"In the pediatric population, intracranial fusi...",Remaining,<journal>Pediatric neurology</journal><title>P...,<title>Progressive intracranial fusiform aneur...
650,24619358,Human molecular genetics,Overexpression of the calpain-specific inhibit...,"Lewy bodies, a pathological hallmark of Parkin...",Animal-drug-intervention,<journal>Human molecular genetics</journal><ti...,<title>Overexpression of the calpain-specific ...
1634,36563534,Journal of behavior therapy and experimental p...,Does fear reduction predict treatment response...,Fear activation and reduction have traditional...,Human-non-RCT-non-drug-intervention,<journal>Journal of behavior therapy and exper...,<title>Does fear reduction predict treatment r...


## Run different prompts over the data

In [26]:
# Add the IDs of Prompts that you want to test
prompt_ids_to_test = ["P1", "P2"] #, "P2", "P3", "P4"
model = "gpt-3.5-turbo"

for prompt in prompts_data["prompts"]:
    prompt_id = prompt["id"]
    prompt_text = prompt["text"]

    if prompt_id in prompt_ids_to_test:
        # Apply GPT predictions
        sampled_df[f'gpt_predictions_{prompt_id}_raw'] = apply_gpt_with_progress(sampled_df['input_journal_title_abstract'], prompt_text, model)
        ## the below includes error handling in case the json formatting did not work as expected
        sampled_df[f'gpt_predictions_{prompt_id}'] = sampled_df[f'gpt_predictions_{prompt_id}_raw'].apply(
            lambda x: json.loads(x)['gpt_label'] if isinstance(x, str) and 'gpt_label' in json.loads(x) else x
        )
        sampled_df.to_csv(f"predictions/{model}_outputs_{'_'.join(prompt_ids_to_test)}.csv") # saving after each prompt strategy has been run, in order not to loose information in case a later strategy fails
    else:
        print(f"Skipping prompt {prompt_id}")
        

Processing dataset:   0%|                          | 0/50 [00:00<?, ?it/s]

Trying to call OpenAI API...


Processing dataset:   2%|▎                 | 1/50 [00:01<01:11,  1.47s/it]

Trying to call OpenAI API...


Processing dataset:   4%|▋                 | 2/50 [00:02<01:03,  1.33s/it]

Trying to call OpenAI API...


Processing dataset:   6%|█                 | 3/50 [00:03<00:53,  1.14s/it]

Trying to call OpenAI API...


Processing dataset:   8%|█▍                | 4/50 [00:04<00:53,  1.17s/it]

Trying to call OpenAI API...


Processing dataset:  10%|█▊                | 5/50 [00:06<01:01,  1.36s/it]

Trying to call OpenAI API...


Processing dataset:  12%|██▏               | 6/50 [00:07<00:55,  1.27s/it]

Trying to call OpenAI API...


Processing dataset:  14%|██▌               | 7/50 [00:09<01:02,  1.46s/it]

Trying to call OpenAI API...


Processing dataset:  16%|██▉               | 8/50 [00:10<00:51,  1.23s/it]

Trying to call OpenAI API...


Processing dataset:  18%|███▏              | 9/50 [00:11<00:52,  1.28s/it]

Trying to call OpenAI API...


Processing dataset:  20%|███▍             | 10/50 [00:14<01:06,  1.66s/it]

Trying to call OpenAI API...


Processing dataset:  22%|███▋             | 11/50 [00:14<00:53,  1.38s/it]

Trying to call OpenAI API...


Processing dataset:  24%|████             | 12/50 [00:15<00:48,  1.29s/it]

Trying to call OpenAI API...


Processing dataset:  26%|████▍            | 13/50 [00:16<00:43,  1.17s/it]

Trying to call OpenAI API...


Processing dataset:  28%|████▊            | 14/50 [00:18<00:44,  1.24s/it]

Trying to call OpenAI API...


Processing dataset:  30%|█████            | 15/50 [00:18<00:38,  1.09s/it]

Trying to call OpenAI API...


Processing dataset:  32%|█████▍           | 16/50 [00:21<00:52,  1.53s/it]

Trying to call OpenAI API...


Processing dataset:  34%|█████▊           | 17/50 [00:22<00:46,  1.41s/it]

Trying to call OpenAI API...


Processing dataset:  36%|██████           | 18/50 [00:23<00:40,  1.25s/it]

Trying to call OpenAI API...


Processing dataset:  38%|██████▍          | 19/50 [00:24<00:36,  1.16s/it]

Trying to call OpenAI API...


Processing dataset:  40%|██████▊          | 20/50 [00:25<00:32,  1.09s/it]

Trying to call OpenAI API...


Processing dataset:  42%|███████▏         | 21/50 [00:26<00:30,  1.04s/it]

Trying to call OpenAI API...


Processing dataset:  44%|███████▍         | 22/50 [00:27<00:28,  1.00s/it]

Trying to call OpenAI API...


Processing dataset:  46%|███████▊         | 23/50 [00:28<00:26,  1.02it/s]

Trying to call OpenAI API...


Processing dataset:  48%|████████▏        | 24/50 [00:29<00:27,  1.05s/it]

Trying to call OpenAI API...


Processing dataset:  50%|████████▌        | 25/50 [00:30<00:23,  1.08it/s]

Trying to call OpenAI API...


Processing dataset:  52%|████████▊        | 26/50 [00:30<00:22,  1.08it/s]

Trying to call OpenAI API...


Processing dataset:  54%|█████████▏       | 27/50 [00:32<00:23,  1.01s/it]

Trying to call OpenAI API...


Processing dataset:  56%|█████████▌       | 28/50 [00:33<00:21,  1.01it/s]

Trying to call OpenAI API...


Processing dataset:  58%|█████████▊       | 29/50 [00:34<00:22,  1.06s/it]

Trying to call OpenAI API...


Processing dataset:  60%|██████████▏      | 30/50 [00:36<00:29,  1.48s/it]

Trying to call OpenAI API...


Processing dataset:  62%|██████████▌      | 31/50 [00:37<00:23,  1.26s/it]

Trying to call OpenAI API...


Processing dataset:  64%|██████████▉      | 32/50 [00:38<00:21,  1.21s/it]

Trying to call OpenAI API...


Processing dataset:  66%|███████████▏     | 33/50 [00:39<00:19,  1.12s/it]

Trying to call OpenAI API...


Processing dataset:  68%|███████████▌     | 34/50 [00:40<00:17,  1.06s/it]

Trying to call OpenAI API...


Processing dataset:  70%|███████████▉     | 35/50 [00:41<00:16,  1.11s/it]

Trying to call OpenAI API...


Processing dataset:  72%|████████████▏    | 36/50 [00:42<00:14,  1.06s/it]

Trying to call OpenAI API...


Processing dataset:  74%|████████████▌    | 37/50 [00:43<00:13,  1.02s/it]

Trying to call OpenAI API...


Processing dataset:  76%|████████████▉    | 38/50 [00:44<00:10,  1.12it/s]

Trying to call OpenAI API...


Processing dataset:  78%|█████████████▎   | 39/50 [00:45<00:10,  1.03it/s]

Trying to call OpenAI API...


Processing dataset:  80%|█████████████▌   | 40/50 [00:46<00:09,  1.02it/s]

Trying to call OpenAI API...


Processing dataset:  82%|█████████████▉   | 41/50 [00:47<00:10,  1.15s/it]

Trying to call OpenAI API...


Processing dataset:  84%|██████████████▎  | 42/50 [00:49<00:09,  1.17s/it]

Trying to call OpenAI API...


Processing dataset:  86%|██████████████▌  | 43/50 [00:50<00:07,  1.10s/it]

Trying to call OpenAI API...


Processing dataset:  88%|██████████████▉  | 44/50 [00:51<00:06,  1.14s/it]

Trying to call OpenAI API...


Processing dataset:  90%|███████████████▎ | 45/50 [00:52<00:05,  1.16s/it]

Trying to call OpenAI API...


Processing dataset:  92%|███████████████▋ | 46/50 [00:53<00:04,  1.09s/it]

Trying to call OpenAI API...


Processing dataset:  94%|███████████████▉ | 47/50 [00:55<00:03,  1.32s/it]

Trying to call OpenAI API...


Processing dataset:  96%|████████████████▎| 48/50 [00:56<00:02,  1.29s/it]

Trying to call OpenAI API...


Processing dataset:  98%|████████████████▋| 49/50 [00:58<00:01,  1.46s/it]

Trying to call OpenAI API...


Processing dataset: 100%|█████████████████| 50/50 [00:59<00:00,  1.19s/it]
Processing dataset:   0%|                          | 0/50 [00:00<?, ?it/s]

Trying to call OpenAI API...


Processing dataset:   2%|▎                 | 1/50 [00:01<00:49,  1.01s/it]

Trying to call OpenAI API...


Processing dataset:   4%|▋                 | 2/50 [00:01<00:45,  1.04it/s]

Trying to call OpenAI API...


Processing dataset:   6%|█                 | 3/50 [00:02<00:39,  1.20it/s]

Trying to call OpenAI API...


Processing dataset:   8%|█▍                | 4/50 [00:03<00:44,  1.04it/s]

Trying to call OpenAI API...


Processing dataset:  10%|█▊                | 5/50 [00:04<00:37,  1.19it/s]

Trying to call OpenAI API...


Processing dataset:  12%|██▏               | 6/50 [00:05<00:33,  1.32it/s]

Trying to call OpenAI API...


Processing dataset:  14%|██▌               | 7/50 [00:05<00:34,  1.23it/s]

Trying to call OpenAI API...


Processing dataset:  16%|██▉               | 8/50 [00:06<00:35,  1.18it/s]

Trying to call OpenAI API...


Processing dataset:  18%|███▏              | 9/50 [00:07<00:35,  1.15it/s]

Trying to call OpenAI API...


Processing dataset:  20%|███▍             | 10/50 [00:08<00:31,  1.26it/s]

Trying to call OpenAI API...


Processing dataset:  22%|███▋             | 11/50 [00:09<00:29,  1.32it/s]

Trying to call OpenAI API...


Processing dataset:  24%|████             | 12/50 [00:10<00:33,  1.13it/s]

Trying to call OpenAI API...


Processing dataset:  26%|████▍            | 13/50 [00:11<00:39,  1.08s/it]

Trying to call OpenAI API...


Processing dataset:  28%|████▊            | 14/50 [00:13<00:43,  1.22s/it]

Trying to call OpenAI API...


Processing dataset:  30%|█████            | 15/50 [00:14<00:41,  1.18s/it]

Trying to call OpenAI API...


Processing dataset:  32%|█████▍           | 16/50 [00:15<00:35,  1.04s/it]

Trying to call OpenAI API...


Processing dataset:  34%|█████▊           | 17/50 [00:15<00:29,  1.11it/s]

Trying to call OpenAI API...


Processing dataset:  36%|██████           | 18/50 [00:16<00:26,  1.20it/s]

Trying to call OpenAI API...


Processing dataset:  38%|██████▍          | 19/50 [00:17<00:29,  1.05it/s]

Trying to call OpenAI API...


Processing dataset:  40%|██████▊          | 20/50 [00:18<00:29,  1.02it/s]

Trying to call OpenAI API...


Processing dataset:  42%|███████▏         | 21/50 [00:19<00:25,  1.12it/s]

Trying to call OpenAI API...


Processing dataset:  44%|███████▍         | 22/50 [00:20<00:26,  1.07it/s]

Trying to call OpenAI API...


Processing dataset:  46%|███████▊         | 23/50 [00:21<00:27,  1.02s/it]

Trying to call OpenAI API...


Processing dataset:  48%|████████▏        | 24/50 [00:23<00:30,  1.18s/it]

Trying to call OpenAI API...


Processing dataset:  50%|████████▌        | 25/50 [00:24<00:29,  1.19s/it]

Trying to call OpenAI API...


Processing dataset:  52%|████████▊        | 26/50 [00:25<00:26,  1.08s/it]

Trying to call OpenAI API...


Processing dataset:  54%|█████████▏       | 27/50 [00:25<00:21,  1.09it/s]

Trying to call OpenAI API...


Processing dataset:  56%|█████████▌       | 28/50 [00:26<00:21,  1.03it/s]

Trying to call OpenAI API...


Processing dataset:  58%|█████████▊       | 29/50 [00:29<00:29,  1.42s/it]

Trying to call OpenAI API...


Processing dataset:  60%|██████████▏      | 30/50 [00:31<00:34,  1.73s/it]

Trying to call OpenAI API...


Processing dataset:  62%|██████████▌      | 31/50 [00:32<00:28,  1.49s/it]

Trying to call OpenAI API...


Processing dataset:  64%|██████████▉      | 32/50 [00:33<00:23,  1.29s/it]

Trying to call OpenAI API...


Processing dataset:  66%|███████████▏     | 33/50 [00:34<00:20,  1.21s/it]

Trying to call OpenAI API...


Processing dataset:  68%|███████████▌     | 34/50 [00:35<00:16,  1.03s/it]

Trying to call OpenAI API...


Processing dataset:  70%|███████████▉     | 35/50 [00:36<00:17,  1.18s/it]

Trying to call OpenAI API...


Processing dataset:  72%|████████████▏    | 36/50 [00:37<00:15,  1.10s/it]

Trying to call OpenAI API...


Processing dataset:  74%|████████████▌    | 37/50 [00:38<00:13,  1.05s/it]

Trying to call OpenAI API...


Processing dataset:  76%|████████████▉    | 38/50 [00:39<00:12,  1.01s/it]

Trying to call OpenAI API...


Processing dataset:  78%|█████████████▎   | 39/50 [00:40<00:10,  1.02it/s]

Trying to call OpenAI API...


Processing dataset:  80%|█████████████▌   | 40/50 [00:41<00:09,  1.07it/s]

Trying to call OpenAI API...


Processing dataset:  82%|█████████████▉   | 41/50 [00:42<00:08,  1.04it/s]

Trying to call OpenAI API...


Processing dataset:  84%|██████████████▎  | 42/50 [00:42<00:06,  1.17it/s]

Trying to call OpenAI API...


Processing dataset:  86%|██████████████▌  | 43/50 [00:44<00:06,  1.03it/s]

Trying to call OpenAI API...


Processing dataset:  88%|██████████████▉  | 44/50 [00:44<00:05,  1.05it/s]

Trying to call OpenAI API...


Processing dataset:  90%|███████████████▎ | 45/50 [00:45<00:04,  1.11it/s]

Trying to call OpenAI API...


Processing dataset:  92%|███████████████▋ | 46/50 [00:46<00:03,  1.19it/s]

Trying to call OpenAI API...


Processing dataset:  94%|███████████████▉ | 47/50 [00:47<00:02,  1.13it/s]

Trying to call OpenAI API...


Processing dataset:  96%|████████████████▎| 48/50 [00:48<00:01,  1.08it/s]

Trying to call OpenAI API...


Processing dataset:  98%|████████████████▋| 49/50 [00:49<00:00,  1.12it/s]

Trying to call OpenAI API...


Processing dataset: 100%|█████████████████| 50/50 [00:50<00:00,  1.01s/it]

Skipping prompt P3
Skipping prompt P4
Skipping prompt P5





In [27]:
sampled_df.head()

Unnamed: 0,pmid,journal_name,title,abstract,accepted_label,input_journal_title_abstract,input_title_abstract,gpt_predictions_P1_raw,gpt_predictions_P1,gpt_predictions_P2_raw,gpt_predictions_P2
608,18540779,Experimental and clinical psychopharmacology,Triacetyluridine (TAU) decreases depressive sy...,Eleven patients with bipolar depression were g...,Human-non-RCT-drug-intervention,<journal>Experimental and clinical psychopharm...,<title>Triacetyluridine (TAU) decreases depres...,"{\n ""gpt_label"": ""Human-case-report""\n}",Human-case-report,"{\n ""gpt_label"": ""Human-case-report""\n}",Human-case-report
1695,29052307,Developmental science,Instrumental learning and cognitive flexibilit...,Children who experience severe early life stre...,Remaining,<journal>Developmental science</journal><title...,<title>Instrumental learning and cognitive fle...,"{\n ""gpt_label"": ""Human-non-RCT-non-drug-in...",Human-non-RCT-non-drug-intervention,"{\n ""gpt_label"": ""Human-case-report""\n}",Human-case-report
790,23337006,Pediatric neurology,Progressive intracranial fusiform aneurysms an...,"In the pediatric population, intracranial fusi...",Remaining,<journal>Pediatric neurology</journal><title>P...,<title>Progressive intracranial fusiform aneur...,"{\n ""gpt_label"": ""Human-case-report""\n}",Human-case-report,"{\n ""gpt_label"": ""Human-case-report""\n}",Human-case-report
650,24619358,Human molecular genetics,Overexpression of the calpain-specific inhibit...,"Lewy bodies, a pathological hallmark of Parkin...",Animal-drug-intervention,<journal>Human molecular genetics</journal><ti...,<title>Overexpression of the calpain-specific ...,"{\n ""gpt_label"": ""Animal-drug-intervention""\n}",Animal-drug-intervention,"{\n ""gpt_label"": ""Animal-drug-intervention""\n}",Animal-drug-intervention
1634,36563534,Journal of behavior therapy and experimental p...,Does fear reduction predict treatment response...,Fear activation and reduction have traditional...,Human-non-RCT-non-drug-intervention,<journal>Journal of behavior therapy and exper...,<title>Does fear reduction predict treatment r...,"{\n ""gpt_label"": ""Human-RCT-drug-interventi...",Human-RCT-drug-intervention,"{\n ""gpt_label"": ""Human-RCT-drug-interventi...",Human-RCT-drug-intervention


## Evaluate each prompt

In [28]:
labels = ["Human-systematic-review", "Human-RCT-drug-intervention", "Human-RCT-non-drug-intervention", "Human-RCT-non-intervention", "Human-case-report", "Human-non-RCT-drug-intervention", "Human-non-RCT-non-drug-intervention", "Animal-systematic-review", "Animal-drug-intervention", "Animal-non-drug-intervention", "Animal-other", "Non-systematic-review", "In-vitro-study", "Clinical-study-protocol", "Remaining"]

label_to_numerical = {label: i for i, label in enumerate(labels)}
label_to_numerical["label missing"] = -1

In [29]:
def map_label_to_numerical(label):
    # Check if label is a dictionary
    if isinstance(label, dict):
        # Extract the label with the highest score/probability
        highest_label = max(label, key=label.get)
        return label_to_numerical.get(highest_label, -1)
    else:
        # Directly map string labels to numerical IDs
        return label_to_numerical.get(label, -1)
        
# Convert accepted labels to numerical
sampled_df['accepted_label_numerical'] = sampled_df['accepted_label'].apply(lambda x: label_to_numerical.get(x, -1))

# Initialize a list to hold DataFrame for each report and summary statistics
report_dfs = []
summary_stats = []

# Iterate over each GPT prediction column
for prompt_id in prompt_ids_to_test:
    print("Evaluating ", prompt_id)
    prediction_col = f'gpt_predictions_{prompt_id}'
    
    # Map GPT predictions to numerical values
    sampled_df[f'{prediction_col}_numerical'] = sampled_df[prediction_col].apply(map_label_to_numerical)
    
    # Extract arrays for evaluation
    y_true = sampled_df['accepted_label_numerical'].values
    y_pred = sampled_df[f'{prediction_col}_numerical'].values
    
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_balanced = balanced_accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0, labels=range(len(labels)), target_names=labels)
    
    # Create DataFrame from report
    report_df = pd.DataFrame(report).transpose()
    report_df['Prompt ID'] = prompt_id  # Add column to indicate the prompt ID
    report_dfs.append(report_df)
    
    # Extract summary statistics (average precision, recall, F1)
    summary = report_df.loc['weighted avg', ['precision', 'recall', 'f1-score']].to_dict()
    summary['Prompt ID'] = prompt_id
    summary_stats.append(summary)

# Combine all report DataFrames
all_reports_df = pd.concat(report_dfs)

# Create a summary table for average precision, recall, and F1-score
summary_df = pd.DataFrame(summary_stats)

Evaluating  P1
Evaluating  P2




In [30]:
summary_df

Unnamed: 0,precision,recall,f1-score,Prompt ID
0,0.508889,0.32,0.327273,P1
1,0.584487,0.26,0.258061,P2


In [None]:
all_reports_df.to_csv(f"evaluations/{model}_per_class_{'_'.join(prompt_ids_to_test)}.csv") # saving after each prompt strategy has been run, in order not to loose information in case a later strategy fails
summary_df.to_csv(f"evaluations/{model}_summary_{'_'.join(prompt_ids_to_test)}.csv") # saving after each prompt strategy has been run, in order not to loose information in case a later strategy fails


In [31]:
all_reports_df

Unnamed: 0,precision,recall,f1-score,support,Prompt ID
Human-systematic-review,0.0,0.0,0.0,0.0,P1
Human-RCT-drug-intervention,0.666667,1.0,0.8,4.0,P1
Human-RCT-non-drug-intervention,0.0,0.0,0.0,1.0,P1
Human-RCT-non-intervention,0.0,0.0,0.0,0.0,P1
Human-case-report,0.222222,1.0,0.363636,2.0,P1
Human-non-RCT-drug-intervention,1.0,0.333333,0.5,3.0,P1
Human-non-RCT-non-drug-intervention,0.0,0.0,0.0,4.0,P1
Animal-systematic-review,0.0,0.0,0.0,0.0,P1
Animal-drug-intervention,0.666667,0.8,0.727273,5.0,P1
Animal-non-drug-intervention,0.0,0.0,0.0,0.0,P1
