In [111]:
import torch
import numpy as np
from pathlib import Path

from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics import pairwise
import pandas as pd
from openai import OpenAI
import openai

from sklearn.metrics import accuracy_score, classification_report, balanced_accuracy_score
from sklearn.metrics import balanced_accuracy_score

import random


In [30]:
embeddings_data_path = Path("./data/embeddings/")
input_data_path = Path("./data/data_splits_stratified/6-2-2_all_classes/")

## Load Embeddings and Calculate Similarities

### Load 
-> each row represents the text from one sample embedded into a 768-demnsional vector

In [9]:
embeddings_train = np.load(
    embeddings_data_path / "embeddings_microsoft_BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext_train_ds.npy"
)
embeddings_test = np.load(
    embeddings_data_path / "embeddings_microsoft_BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext_test_ds.npy"
)

In [10]:
embeddings_train.shape

(1191, 768)

In [11]:
embeddings_test.shape

(404, 768)

### Compute similarity between test dataset elements to all train elements


In [122]:
dist_matrix = pairwise_distances(embeddings_test, embeddings_train, metric='sqeuclidean')
dist_matrix.shape

(404, 1191)

In [125]:
dist_matrix_cosine = pairwise.cosine_similarity(X=embeddings_test, Y=embeddings_train)
dist_matrix.shape

(404, 1191)

The matrix has the dimensions 404 (test data elements) x 1191 (train data elements). Each of the 1191 values per test row is the similarity score between the test element to the train elements.

In [129]:
dist_matrix_cosine

array([[0.98506499, 0.98830865, 0.98870777, ..., 0.98741564, 0.98696952,
        0.98295794],
       [0.99007619, 0.99129481, 0.9916276 , ..., 0.99085124, 0.99138424,
        0.98798618],
       [0.98900018, 0.99066293, 0.9915621 , ..., 0.98766824, 0.99073964,
        0.98834762],
       ...,
       [0.98882802, 0.9873871 , 0.98857115, ..., 0.9928928 , 0.98869445,
        0.98763254],
       [0.98709908, 0.987791  , 0.9864768 , ..., 0.9930578 , 0.98749872,
        0.98896695],
       [0.98904398, 0.98862152, 0.98700841, ..., 0.9863798 , 0.99105832,
        0.99154065]])

### Find closest neighbours from the train dataset to each test example

In [126]:
values, indices = torch.topk(-torch.from_numpy(dist_matrix_cosine), k=3, dim=-1)

In [130]:
indices.shape

torch.Size([404, 3])

The indices have the dimension 404 (test data elements) x 3 (top k=3 closest train data elements).

In [127]:
indices

tensor([[ 128, 1032,  413],
        [ 128, 1032,  413],
        [ 128,  413, 1032],
        ...,
        [ 128,   11, 1032],
        [  11,  128,  356],
        [ 128,  356, 1032]])

In [None]:
df_train = pd.read_csv(input_data_path/ 'train.csv')
df_test = pd.read_csv(input_data_path/ 'test.csv')


In [103]:
df_test[df_test['accepted_label'] == 'In-vitro-study']

Unnamed: 0,idx,pmid,journal_name,title,abstract,accepted_label,multi_label,binary_label,input_journal_title_abstract,gpt_predictions_in_context
245,246,27983922,Oncology research,Overexpression of Protease Serine 8 Inhibits G...,"Protease serine 8 (PRSS8), a serine peptidase,...",In-vitro-study,8,0,<journal>Oncology research</journal><title>Ove...,"{\n ""gpt_label"": ""Non-systematic-review""\n}"
246,247,16332401,Neurobiology of aging,Increased cholesterol in Abeta-positive nerve ...,Synapse loss in Alzheimer's disease (AD) is po...,In-vitro-study,8,0,<journal>Neurobiology of aging</journal><title...,"{\n ""gpt_label"": ""Non-systematic-review""\n}"
247,248,14655759,"Brain pathology (Zurich, Switzerland)",TRAIL triggers apoptosis in human malignant gl...,Many malignant glioma cells express death rece...,In-vitro-study,8,0,"<journal>Brain pathology (Zurich, Switzerland)...","{\n ""gpt_label"": ""Non-systematic-review""\n}"
248,249,8297640,APMIS. Supplementum,Fetal antigen 2 (FA2): the aminopropeptide of ...,Fetal antigen 2 (FA2) was purified from second...,In-vitro-study,8,0,<journal>APMIS. Supplementum</journal><title>F...,"{\n ""gpt_label"": ""Remaining""\n}"
249,250,19630956,Molecular brain,Regulation of endosomal motility and degradati...,"Dysfunction of alsin, particularly its putativ...",In-vitro-study,8,0,<journal>Molecular brain</journal><title>Regul...,"{\n ""gpt_label"": ""Remaining""\n}"
250,251,25413246,Cell transplantation,Human umbilical cord blood cells induce neurop...,Human umbilical cord blood (HUCB) cell therapi...,In-vitro-study,8,0,<journal>Cell transplantation</journal><title>...,"{\n ""gpt_label"": ""Human-non-RCT-non-drug-inte..."
251,252,21055392,Biochemical and biophysical research communica...,In yeast redistribution of Sod1 to the mitocho...,The antioxidative enzyme copper-zinc superoxid...,In-vitro-study,8,0,<journal>Biochemical and biophysical research ...,"{\n ""gpt_label"": ""Non-systematic-review""\n}"
252,253,16696567,Chemical research in toxicology,Ebselen induced C6 glioma cell death in oxygen...,Studies have shown that ebselen is an antiinfl...,In-vitro-study,8,0,<journal>Chemical research in toxicology</jour...,"{\n ""gpt_label"": ""Non-systematic-review""\n}"
253,254,22363216,PLoS genetics,GTPase activity and neuronal toxicity of Parki...,Mutations in the leucine-rich repeat kinase 2 ...,In-vitro-study,8,0,<journal>PLoS genetics</journal><title>GTPase ...,"{\n ""gpt_label"": ""Remaining""\n}"
254,255,25960208,Biomedicine & pharmacotherapy = Biomedecine & ...,miR-25 promotes glioma cell proliferation by t...,MicroRNAs (miRNA) have oncogenic or tumor-supp...,In-vitro-study,8,0,<journal>Biomedicine & pharmacotherapy = Biome...,"{\n ""gpt_label"": ""Human-non-RCT-non-drug-inte..."


In [104]:
df_test.iloc[246]

idx                                                                           247
pmid                                                                     16332401
journal_name                                                Neurobiology of aging
title                           Increased cholesterol in Abeta-positive nerve ...
abstract                        Synapse loss in Alzheimer's disease (AD) is po...
accepted_label                                                     In-vitro-study
multi_label                                                                     8
binary_label                                                                    0
input_journal_title_abstract    <journal>Neurobiology of aging</journal><title...
gpt_predictions_in_context           {\n  "gpt_label": "Non-systematic-review"\n}
Name: 246, dtype: object

In [105]:
indices[246]

tensor([ 128,  356, 1032])

In [107]:
df_train.iloc[1032]

idx                                                                          1033
pmid                                                                     28764897
journal_name                                                    Atencion primaria
title                           [Clinical characteristics of patients with atr...
abstract                        To analyse the clinical characteristics and ma...
accepted_label                                    Human-non-RCT-drug-intervention
multi_label                                                                     3
binary_label                                                                    0
input_journal_title_abstract    <journal>Atencion primaria</journal><title>[Cl...
Name: 1032, dtype: object

In [131]:
df_test.groupby('accepted_label').size()

accepted_label
Animal-drug-intervention                20
Animal-non-drug-intervention             6
Animal-other                            21
Clinical-study-protocol                  2
Human-RCT-drug-intervention              5
Human-RCT-non-drug-intervention          8
Human-RCT-non-intervention               1
Human-case-report                       23
Human-non-RCT-drug-intervention         26
Human-non-RCT-non-drug-intervention     32
Human-systematic-review                 12
In-vitro-study                          11
Non-systematic-review                   65
Remaining                              172
dtype: int64

## Init OpenAI API

In [None]:
def load_pass(file_path, key_to_find):
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.strip().split("=")
            if len(parts) == 2 and parts[0] == key_to_find:
                found_password = parts[1]
                break
    if found_password:
        print("Found password.")
        return found_password
    else:
        print("Password not found for key:", key_to_find)

In [None]:
openai.api_key = load_pass("./credentials.txt", "OPENAI")
client = OpenAI(api_key=openai.api_key)

## Create Prompts

In [48]:
df_train['input_journal_title_abstract'] = '<journal>' + df_train['journal_name'] + '</journal>' + \
                                         '<title>' + df_train['title'] + '</title>' + \
                                         '<abstract>' + df_train['abstract'] + '</abstract>'
df_test['input_journal_title_abstract'] = '<journal>' + df_test['journal_name'] + '</journal>' + \
                                         '<title>' + df_test['title'] + '</title>' + \
                                         '<abstract>' + df_test['abstract'] + '</abstract>'

In [50]:
df_test

Unnamed: 0,idx,pmid,journal_name,title,abstract,accepted_label,multi_label,binary_label,input_journal_title_abstract
0,1,15055442,Journal of neuropathology and experimental neu...,Emerging tumor entities and variants of CNS ne...,Since the appearance in 2000 of the World Heal...,Non-systematic-review,1,0,<journal>Journal of neuropathology and experim...
1,2,11172874,Journal of affective disorders,Serotonergic gene expression and depression: i...,The development and configuration of several n...,Non-systematic-review,1,0,<journal>Journal of affective disorders</journ...
2,3,19961324,Annual review of entomology,Ekbom syndrome: the challenge of 'invisible bu...,Ekbom Syndrome is synonymous with delusory par...,Non-systematic-review,1,0,<journal>Annual review of entomology</journal>...
3,4,11077858,Rozhledy v chirurgii : mesicnik Ceskoslovenske...,[Brain injuries].,The author presents an account of contemporary...,Non-systematic-review,1,0,<journal>Rozhledy v chirurgii : mesicnik Cesko...
4,5,20362421,Archives de pediatrie : organe officiel de la ...,[Treatment of childhood dystonia].,"Dystonia is not uncommon in childhood, but is ...",Non-systematic-review,1,0,<journal>Archives de pediatrie : organe offici...
...,...,...,...,...,...,...,...,...,...
399,400,37550718,Trials,Can dexamethasone improve postoperative sleep ...,Perioperative sleep disorders (PSD) are an ind...,Human-RCT-drug-intervention,11,0,<journal>Trials</journal><title>Can dexamethas...
400,401,11279969,Headache,Intranasal lidocaine for migraine: a randomize...,To study the efficacy of intranasal lidocaine ...,Human-RCT-drug-intervention,11,0,<journal>Headache</journal><title>Intranasal l...
401,402,33393402,Developmental neurorehabilitation,Lower Limb Sensorimotor Training (LoSenseT) fo...,Motor disorders in cerebral palsy (CP) are oft...,Clinical-study-protocol,12,0,<journal>Developmental neurorehabilitation</jo...
402,403,32541457,Medicine,Association between non-alcoholic fatty liver ...,This study will systematically synthesize the ...,Clinical-study-protocol,12,0,<journal>Medicine</journal><title>Association ...


In [55]:
def create_prompt(df_train, df_test, test_index, example_indices):
    # Start the prompt with a task description (optional)
    prompt = "Classify this text, choosing one of these labels: Clinical-study-protocol, Human-systematic-review, Non-systematic-review, Human-RCT-non-drug-intervention, Human-RCT-drug-intervention, Human-RCT-non-intervention, Human-case-report, Human-non-RCT-non-drug-intervention, Human-non-RCT-drug-intervention, Animal-systematic-review, Animal-drug-intervention, Animal-non-drug-intervention, Animal-other, In-vitro-study, Remaining. Respond in json format with the key: gpt_label.\n\n"
    
    # Add examples from df_train
    for idx in example_indices:
        example_text = df_train.loc[idx, 'input_journal_title_abstract']
        example_label = df_train.loc[idx, 'accepted_label']
        prompt += f"Text: \"{example_text}\"\nCategory: {example_label}\n\n"
    
    # Add the test text needing classification
    test_text = df_test.loc[test_index, 'input_journal_title_abstract']
    prompt += f"Text: \"{test_text}\"\nCategory: "
    
    return prompt

In [78]:
example_i = 0
example_prompt = create_prompt(df_train, df_test, example_i, indices[example_i].tolist())
#example_prompt

In [70]:
len(example_prompt.split())

1483

In [75]:
query_gpt(example_prompt)

Trying to call OpenAI API...


'{\n    "gpt_label": "Non-systematic-review"\n}'

In [109]:
import time
from tqdm.auto import tqdm

DEFAULT_TEMPERATURE = 0

def create_prompt(df_train, example_indices, input_raw_text):
    prompt = "Classify this text, choosing one of these labels: Clinical-study-protocol, Human-systematic-review, Non-systematic-review, Human-RCT-non-drug-intervention, Human-RCT-drug-intervention, Human-RCT-non-intervention, Human-case-report, Human-non-RCT-non-drug-intervention, Human-non-RCT-drug-intervention, Animal-systematic-review, Animal-drug-intervention, Animal-non-drug-intervention, Animal-other, In-vitro-study, Remaining. Respond in json format with the key: gpt_label.\n\n"
    for idx in example_indices:
        example_text = df_train.loc[idx, 'input_journal_title_abstract']
        example_label = df_train.loc[idx, 'accepted_label']
        prompt += f"Text: \"{example_text}\"\nCategory: {example_label}\n\n"
    prompt += f"Text: \"{input_raw_text}\"\nCategory: "
    return prompt

def query_gpt(df_train, input_raw_text, example_indices, gpt_model="gpt-3.5-turbo", temperature=DEFAULT_TEMPERATURE, max_retries=5, retry_delay=3):
    prompt_text = create_prompt(df_train, example_indices, input_raw_text)
    system_msg = f"You are an expert assistant specialized in text classification of PubMed abstracts."

    retries = 0
    while retries < max_retries:
        print("Trying to call OpenAI API...")
        try:
            completion = client.chat.completions.create(
                model=gpt_model,  
                response_format={"type": "json_object"},
                temperature=temperature,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt_text}
                ]
            )
            return completion.choices[0].message.content
        except Exception as e:
            print(f"OpenAI API returned an error: {e}")
            time.sleep(retry_delay)
            retries += 1

    raise RuntimeError("Max retries reached. Unable to complete the API call.")

def apply_gpt_with_progress(df_train, test_data_series, example_indices_tensor=None, num_samples=3, use_random=False, model="gpt-3.5-turbo"):
    results = []
    total_items = len(test_data_series)
    with tqdm(total=total_items, desc="Processing dataset") as pbar:
        for i, text in enumerate(test_data_series):
            if use_random:
                example_indices = random.sample(range(len(df_train)), num_samples)
            else:
                example_indices = example_indices_tensor[i].tolist()
            print("Retrieved in-context learning examples with idx: ", example_indices)
            result = query_gpt(df_train, text, example_indices, model)
            results.append(result)
            pbar.update(1)
    return results


In [84]:
# Example usage:
df_test[f'gpt_predictions_in_context'] = apply_gpt_with_progress(df_train, df_test['input_journal_title_abstract'], indices)

Processing dataset:   0%|          | 0/404 [00:00<?, ?it/s]

Retrieved in-context learning examples with idx:  [128, 1032, 413]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 1032, 413]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 413, 1032]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 1032, 413]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [413, 356, 128]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 413, 11]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 11, 413]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 413, 356]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 413, 356]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [128, 1032, 413]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx

In [112]:
df_test[f'gpt_predictions_in_context_random'] = apply_gpt_with_progress(df_train, df_test['input_journal_title_abstract'], use_random=True)

Processing dataset:   0%|          | 0/404 [00:00<?, ?it/s]

Retrieved in-context learning examples with idx:  [631, 607, 853]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [41, 721, 695]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [808, 578, 556]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [107, 999, 1011]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [728, 40, 738]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [724, 627, 614]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [284, 442, 333]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [1043, 420, 682]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [652, 418, 1035]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:  [519, 617, 1027]
Trying to call OpenAI API...
Retrieved in-context learning examples with idx:

## Evaluate

In [113]:
df_test_to_eval = df_test.copy()
df_test_to_eval.head()

Unnamed: 0,idx,pmid,journal_name,title,abstract,accepted_label,multi_label,binary_label,input_journal_title_abstract,gpt_predictions_in_context,gpt_predictions_in_context_random
0,1,15055442,Journal of neuropathology and experimental neu...,Emerging tumor entities and variants of CNS ne...,Since the appearance in 2000 of the World Heal...,Non-systematic-review,1,0,<journal>Journal of neuropathology and experim...,"{\n ""gpt_label"": ""Non-systematic-review""\n}","{\n ""gpt_label"": ""Remaining""\n}"
1,2,11172874,Journal of affective disorders,Serotonergic gene expression and depression: i...,The development and configuration of several n...,Non-systematic-review,1,0,<journal>Journal of affective disorders</journ...,"{\n ""gpt_label"": ""Non-systematic-review""\n}","{\n ""gpt_label"": ""Non-systematic-review""\n}"
2,3,19961324,Annual review of entomology,Ekbom syndrome: the challenge of 'invisible bu...,Ekbom Syndrome is synonymous with delusory par...,Non-systematic-review,1,0,<journal>Annual review of entomology</journal>...,"{\n ""gpt_label"": ""Non-systematic-review""\n}","{\n ""gpt_label"": ""Human-non-RCT-non-drug-in..."
3,4,11077858,Rozhledy v chirurgii : mesicnik Ceskoslovenske...,[Brain injuries].,The author presents an account of contemporary...,Non-systematic-review,1,0,<journal>Rozhledy v chirurgii : mesicnik Cesko...,"{\n ""gpt_label"": ""Non-systematic-review""\n}","{\n ""gpt_label"": ""Non-systematic-review""\n}"
4,5,20362421,Archives de pediatrie : organe officiel de la ...,[Treatment of childhood dystonia].,"Dystonia is not uncommon in childhood, but is ...",Non-systematic-review,1,0,<journal>Archives de pediatrie : organe offici...,"{\n ""gpt_label"": ""Human-non-RCT-non-drug-inte...","{\n ""gpt_label"": ""Remaining""\n}"


In [116]:
prompt_ids_to_test = ["in_context", "in_context_random"]

In [115]:
labels = ["Human-systematic-review", "Human-RCT-drug-intervention", "Human-RCT-non-drug-intervention", "Human-RCT-non-intervention", "Human-case-report", "Human-non-RCT-drug-intervention", "Human-non-RCT-non-drug-intervention", "Animal-systematic-review", "Animal-drug-intervention", "Animal-non-drug-intervention", "Animal-other", "Non-systematic-review", "In-vitro-study", "Clinical-study-protocol", "Remaining"]

label_to_numerical = {label: i for i, label in enumerate(labels)}
label_to_numerical["label missing"] = -1

In [117]:
def map_label_to_numerical(label):
    # Check if label is a dictionary
    if isinstance(label, dict):
        # Extract the label with the highest score/probability
        highest_label = max(label, key=label.get)
        return label_to_numerical.get(highest_label, -1)
    else:
        # Directly map string labels to numerical IDs
        return label_to_numerical.get(label, -1)
        
# Convert accepted labels to numerical
df_test_to_eval['accepted_label_numerical'] = df_test_to_eval['accepted_label'].apply(lambda x: label_to_numerical.get(x, -1))


# Initialize a list to hold DataFrame for each report and summary statistics
report_dfs = []
summary_stats = []

# Iterate over each GPT prediction column
for prompt_id in prompt_ids_to_test:
    print("Evaluating ", prompt_id)
    prediction_col = f'gpt_predictions_{prompt_id}_clean'

    df_test_to_eval[prediction_col] = df_test_to_eval[f'gpt_predictions_{prompt_id}'].apply(
            lambda x: json.loads(x)['gpt_label'] if isinstance(x, str) and 'gpt_label' in json.loads(x) else x
        )
    
    # Map GPT predictions to numerical values
    df_test_to_eval[f'{prediction_col}_numerical'] = df_test_to_eval[prediction_col].apply(map_label_to_numerical)
    
    # Extract arrays for evaluation
    y_true = df_test_to_eval['accepted_label_numerical'].values
    y_pred = df_test_to_eval[f'{prediction_col}_numerical'].values
    
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    accuracy_balanced = balanced_accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0, labels=range(len(labels)), target_names=labels)
    
    # Create DataFrame from report
    report_df = pd.DataFrame(report).transpose()
    report_df['Prompt ID'] = prompt_id  # Add column to indicate the prompt ID
    report_dfs.append(report_df)
    
    # Extract summary statistics (average precision, recall, F1)
    summary = report_df.loc['weighted avg', ['precision', 'recall', 'f1-score']].to_dict()
    summary['Prompt ID'] = prompt_id
    summary_stats.append(summary)

# Combine all report DataFrames
all_reports_df = pd.concat(report_dfs)

# Create a summary table for average precision, recall, and F1-score
summary_df = pd.DataFrame(summary_stats)


Evaluating  in_context
Evaluating  in_context_random




In [119]:
all_reports_df

Unnamed: 0,precision,recall,f1-score,support,Prompt ID
Human-systematic-review,0.75,0.25,0.375,12.0,in_context
Human-RCT-drug-intervention,1.0,0.2,0.333333,5.0,in_context
Human-RCT-non-drug-intervention,0.0,0.0,0.0,8.0,in_context
Human-RCT-non-intervention,0.0,0.0,0.0,1.0,in_context
Human-case-report,0.46875,0.652174,0.545455,23.0,in_context
Human-non-RCT-drug-intervention,0.0,0.0,0.0,26.0,in_context
Human-non-RCT-non-drug-intervention,0.142857,0.125,0.133333,32.0,in_context
Animal-systematic-review,0.0,0.0,0.0,0.0,in_context
Animal-drug-intervention,0.4,0.1,0.16,20.0,in_context
Animal-non-drug-intervention,1.0,0.166667,0.285714,6.0,in_context


In [120]:
summary_df

Unnamed: 0,precision,recall,f1-score,Prompt ID
0,0.278529,0.19802,0.154899,in_context
1,0.313046,0.324257,0.313494,in_context_random


In [128]:
print(summary_df)

   precision    recall  f1-score          Prompt ID
0   0.278529  0.198020  0.154899         in_context
1   0.313046  0.324257  0.313494  in_context_random
