In [None]:
import time
import pandas as pd

from tqdm import tqdm
from google.cloud import aiplatform
from sklearn.metrics import classification_report
from vertexai.language_models import ChatModel, InputOutputTextPair
from utils import html_parsing_ncbi, html_parsing_n2c2, get_classification_report, get_digit, get_macro_average_f1

# Point the Vertex AI SDK at a Google Cloud project; 'xxx-xxx-xxx' is a placeholder to replace with your own project ID.
aiplatform.init(project='xxx-xxx-xxx')
# Single chat-bison@001 handle shared by the get_* inference helpers defined in later cells.
chat_model = ChatModel.from_pretrained("chat-bison@001")

> You will have to set up a Google Cloud project with the Vertex AI API enabled, and replace 'xxx-xxx-xxx' with your own project ID. During the free trial period, Google Cloud limits the per-minute quota of the Vertex AI API for the PaLM 2 Bison model (~60 requests per minute). If you encounter a quota-exceeded error, wait until that minute has passed and then continue from where you left off in the for loop.

# 1. NER (Named Entity Recognition)

## 1.1 NCBI-Disease Dataset

### 1.1.1 Inference

In [None]:
# Sentences for the NCBI-disease NER experiment; the 'text' column is what the inference loop sends to the model.
ncbi_df = pd.read_csv('data/NER/NCBI-disease/test_200.csv')

In [None]:
def get_ner_ncbi_disease(sentence: str, shot: int = 0) -> "tuple[str, float]":
    """Ask the PaLM 2 chat model to highlight disease mentions in one sentence.

    Args:
        sentence: Text sent to the model as the chat message.
        shot: ``1`` attaches the single worked example to the chat; any other
            value sends no examples (zero-shot).

    Returns:
        A ``(reply_text, seconds)`` tuple: the model's reply and the elapsed
        time of the ``send_message`` call.
    """

    parameters = {
        "temperature": 0.0,
    }

    chat = chat_model.start_chat(
        context="""
                "TASK: the task is to extract disease entities in a sentence."
                "INPUT: the input is a sentence."
                "OUTPUT: the output is an HTML that highlights all the disease entities in the sentence. \
                        The highlighting should only use HTML tags <span style=\"background-color: #FFFF00\"> and </span> and no other tags."
                """,
        examples=[
            InputOutputTextPair(
                input_text="In summary , inactivation of the murine ATP7B gene produces a form of cirrhotic liver disease that resembles Wilson disease in humans and the toxic milk phenotype in the mouse . .",
                # Adjacent literals are joined by the parser. A backslash
                # continuation inside the literal would drag the next line's
                # indentation into the example, so the demonstrated output
                # would no longer mirror the input token-for-token.
                output_text=(
                    'In summary , inactivation of the murine ATP7B gene produces a form of <span style="background-color: #FFFF00">cirrhotic liver disease</span> '
                    'that resembles <span style="background-color: #FFFF00">Wilson disease</span> in humans and the toxic milk phenotype in the mouse . .'
                ),
            ),
        ] if shot == 1 else []
    )

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    time_start = time.perf_counter()
    response = chat.send_message(
        sentence, **parameters
    )
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Run zero-shot and one-shot inference for every row of ncbi_df.
# Row labels are assumed to be 0..len-1 (the default index produced by read_csv).
for i in tqdm(range(len(ncbi_df))):
    # Resume support: 'palm2_one_shot_time' is the last value written for a row,
    # so a non-null entry means the row finished in an earlier (interrupted) run.
    # Re-running this cell after a quota error therefore continues where it
    # stopped; reload ncbi_df from the CSV to start from scratch.
    if 'palm2_one_shot_time' in ncbi_df.columns and pd.notna(ncbi_df.loc[i, 'palm2_one_shot_time']):
        continue
    # Two requests are issued per row, so pause after every 20th row index to
    # stay under the per-minute quota.
    if (i + 1) % 20 == 0: # in case of quota limit error per minute
        time.sleep(65)
    sentence = ncbi_df.loc[i, 'text']
    ncbi_df.loc[i, 'html_palm2_zero_shot'], ncbi_df.loc[i, 'palm2_zero_shot_time'] = get_ner_ncbi_disease(sentence, 0)
    ncbi_df.loc[i, 'html_palm2_one_shot'], ncbi_df.loc[i, 'palm2_one_shot_time'] = get_ner_ncbi_disease(sentence, 1)

### 1.1.2 Evaluation

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# ncbi_df = pd.read_csv("data/NER/NCBI-disease/test_200_palm2_results.csv")

In [None]:
# Parse the zero-shot HTML first; the helper hands back a pair that is stored
# as the reference labels and the zero-shot predictions.
gt_labels, zero_shot_labels = html_parsing_ncbi(ncbi_df, 'html_palm2_zero_shot')
ncbi_df['gt_labels'] = gt_labels
ncbi_df['palm2_zero_shot_labels'] = zero_shot_labels

# Second pass over the one-shot HTML; the first element of the pair is discarded.
_, one_shot_labels = html_parsing_ncbi(ncbi_df, 'html_palm2_one_shot')
ncbi_df['palm2_one_shot_labels'] = one_shot_labels

In [None]:
# 'strict' report for the one-shot predictions; pass 'palm2_zero_shot_labels' instead to score the zero-shot run.
get_classification_report(ncbi_df, 'gt_labels', 'palm2_one_shot_labels', 'strict')

In [None]:
# 'lenient' report for the one-shot predictions; pass 'palm2_zero_shot_labels' instead to score the zero-shot run.
get_classification_report(ncbi_df, 'gt_labels', 'palm2_one_shot_labels', 'lenient')

In [None]:
# Mean per-request latency recorded by the inference loop, one line per prompting mode.
zero_shot_mean = ncbi_df['palm2_zero_shot_time'].mean()
print(f"Average PaLM 2 zero-shot prediction time: {zero_shot_mean:.2f} seconds")

one_shot_mean = ncbi_df['palm2_one_shot_time'].mean()
print(f"Average PaLM 2 one-shot prediction time: {one_shot_mean:.2f} seconds")

In [None]:
# Save the inference results: writes every column currently in ncbi_df (without the index).
# This is the same file the optional "load the llm output" cell reads back.
ncbi_df.to_csv('data/NER/NCBI-disease/test_200_palm2_results.csv', index=False)

## 1.2 2018 n2c2 Dataset

### 1.2.1 Inference

In [None]:
# Sentences for the 2018 n2c2 NER experiment; the 'text' column is what the inference loop sends to the model.
n2c2_df = pd.read_csv('data/NER/2018_n2c2/test_200.csv')

In [None]:
def get_ner_2018_n2c2(sentence: str, shot: int = 0) -> "tuple[str, float]":
    """Ask the PaLM 2 chat model to colour-highlight the nine n2c2 entity types.

    Args:
        sentence: Text sent to the model as the chat message.
        shot: ``1`` attaches the single worked example to the chat; any other
            value sends no examples (zero-shot).

    Returns:
        A ``(reply_text, seconds)`` tuple: the model's reply and the elapsed
        time of the ``send_message`` call.
    """

    parameters = {
        "temperature": 0.0,
    }

    # NOTE(review): the prompt below says "disease entities" although the listed
    # types (Form, Route, ..., Drug) are medication-related. This looks like a
    # copy-paste from the NCBI prompt. It is left untouched because rewording it
    # would change the experiment; confirm before editing.
    chat = chat_model.start_chat(
        context="""
                "TASK: the task is to extract disease entities in a sentence. The entity type includes Form, Route, Frequency, Dosage, Strength, Duration, Reason, Ade, Drug."
                "INPUT: the input is a sentence."
                "OUTPUT: the output is an HTML that highlights all the disease entities in the sentence in different colors: Form(#FF0000), Route(#FFA500), Frequency(#FFFF00), Dosage(#00FF00), Strength(#0000FF), Duration(#800080), Reason(#FFC0CB), Ade(#964B00), Drug(#808080) in hex code. \
                        The highlighting should only use HTML tags <span style=\"background-color: #XXXXXX\"> and </span> and no other tags."
                """,
        examples=[
            InputOutputTextPair(
                input_text="Vitamin D 400 unit Tablet Sig : Two ( 2 ) Tablet PO once a day .",
                output_text='<span style="background-color: #808080">Vitamin D</span> <span style="background-color: #0000FF">400 unit</span> <span style="background-color: #FF0000">Tablet</span> Sig : <span style="background-color: #00FF00">Two ( 2 )</span> <span style="background-color: #FF0000">Tablet</span> <span style="background-color: #FFA500">PO</span> <span style="background-color: #FFFF00">once a day</span> .',
            ),
        ] if shot == 1 else []
    )

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    time_start = time.perf_counter()
    response = chat.send_message(
        sentence, **parameters
    )
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Run zero-shot and one-shot NER inference for every row of n2c2_df.
# Row labels are assumed to be 0..len-1 (the default index produced by read_csv).
for i in tqdm(range(len(n2c2_df))):
    # Resume support: 'palm2_one_shot_time' is the last value written for a row,
    # so a non-null entry means the row finished in an earlier (interrupted) run.
    # Re-running this cell after a quota error therefore continues where it
    # stopped; reload n2c2_df from the CSV to start from scratch.
    if 'palm2_one_shot_time' in n2c2_df.columns and pd.notna(n2c2_df.loc[i, 'palm2_one_shot_time']):
        continue
    # Two requests are issued per row, so pause after every 20th row index to
    # stay under the per-minute quota.
    if (i + 1) % 20 == 0: # in case of quota limit error per minute
        time.sleep(65)
    sentence = n2c2_df.loc[i, 'text']
    n2c2_df.loc[i, 'html_palm2_zero_shot'], n2c2_df.loc[i, 'palm2_zero_shot_time'] = get_ner_2018_n2c2(sentence, 0)
    n2c2_df.loc[i, 'html_palm2_one_shot'], n2c2_df.loc[i, 'palm2_one_shot_time'] = get_ner_2018_n2c2(sentence, 1)

### 1.2.2 Evaluation

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# n2c2_df = pd.read_csv("data/NER/2018_n2c2/test_200_palm2_results.csv")

In [None]:
# Parse the zero-shot HTML first; the helper hands back a pair that is stored
# as the reference labels and the zero-shot predictions.
gt_labels, zero_shot_labels = html_parsing_n2c2(n2c2_df, 'html_palm2_zero_shot')
n2c2_df['gt_labels'] = gt_labels
n2c2_df['palm2_zero_shot_labels'] = zero_shot_labels

# Second pass over the one-shot HTML; the first element of the pair is discarded.
_, one_shot_labels = html_parsing_n2c2(n2c2_df, 'html_palm2_one_shot')
n2c2_df['palm2_one_shot_labels'] = one_shot_labels

In [None]:
# 'strict' report for the one-shot predictions; pass 'palm2_zero_shot_labels' instead to score the zero-shot run.
get_classification_report(n2c2_df, 'gt_labels', 'palm2_one_shot_labels', 'strict')

In [None]:
# get_macro_average_f1 applied to the 'strict' one-shot report (the report is computed again here rather than reused).
get_macro_average_f1(get_classification_report(n2c2_df, 'gt_labels', 'palm2_one_shot_labels', 'strict'))

In [None]:
# 'lenient' report for the one-shot predictions; pass 'palm2_zero_shot_labels' instead to score the zero-shot run.
get_classification_report(n2c2_df, 'gt_labels', 'palm2_one_shot_labels', 'lenient')

In [None]:
# get_macro_average_f1 applied to the 'lenient' one-shot report (the report is computed again here rather than reused).
get_macro_average_f1(get_classification_report(n2c2_df, 'gt_labels', 'palm2_one_shot_labels', 'lenient'))

In [None]:
# Mean per-request latency recorded by the inference loop, one line per prompting mode.
zero_shot_mean = n2c2_df['palm2_zero_shot_time'].mean()
print(f"Average PaLM 2 zero-shot prediction time: {zero_shot_mean:.2f} seconds")

one_shot_mean = n2c2_df['palm2_one_shot_time'].mean()
print(f"Average PaLM 2 one-shot prediction time: {one_shot_mean:.2f} seconds")

In [None]:
# Save the inference results: writes every column currently in n2c2_df (without the index).
# This is the same file the optional "load the llm output" cell reads back.
n2c2_df.to_csv('data/NER/2018_n2c2/test_200_palm2_results.csv', index=False)

# 2. RE (Relation Extraction)

## 2.1 2018 n2c2 Dataset

### 2.1.1 Inference

In [None]:
# NOTE: this rebinds n2c2_df. From here on it holds the relation-extraction data,
# replacing the NER frame that used the same name in section 1.2.
n2c2_df = pd.read_csv('data/ER/2018_n2c2/test_200.csv')

In [None]:
def get_re_2018_n2c2(sentence: str, shot: int = 0) -> "tuple[str, float]":
    """Ask the PaLM 2 chat model to name the relation between the tagged entities.

    Args:
        sentence: Text sent to the model as the chat message; per the prompt,
            entities are wrapped in ``[E<X>]`` ... ``[E<X>/]`` markers.
        shot: ``1`` attaches the single worked example to the chat; any other
            value sends no examples (zero-shot).

    Returns:
        A ``(reply_text, seconds)`` tuple: the model's reply and the elapsed
        time of the ``send_message`` call.
    """

    parameters = {
        "temperature": 0.0,
    }

    chat = chat_model.start_chat(
        context="""
                "TASK: the task is to classify relations for a sentence."
                "INPUT: the input is a sentence where the entities are labeled within [E${X}] and [E${X}/] in a sentence, where X is an integer representing an unique entity."
                "OUTPUT: your task is to select one out of the nine types of relations ('STRENGTH-DRUG', 'ROUTE-DRUG', 'FREQUENCY-DRUG', 'FORM-DRUG', 'DOSAGE-DRUG', \
                        'REASON-DRUG', 'DURATION-DRUG', 'ADE-DRUG', and 'No relation')."
                """,
        examples=[
            InputOutputTextPair(
                input_text="[E2] Docusate/Sodium [E2/] ( Liquid ) 100/mg PO BID/:/PRN [E1] constipation [E1/] 4 .",
                output_text='REASON-DRUG',
            ),
        ] if shot == 1 else []
    )

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    time_start = time.perf_counter()
    response = chat.send_message(
        sentence, **parameters
    )
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Run zero-shot and one-shot relation classification for every row of n2c2_df.
# Row labels are assumed to be 0..len-1 (the default index produced by read_csv),
# which is why label-based .loc is used for both reading and writing.
for i in tqdm(range(len(n2c2_df))):
    # Resume support: 'palm2_one_shot_time' is the last value written for a row,
    # so a non-null entry means the row finished in an earlier (interrupted) run.
    # Re-running this cell after a quota error therefore continues where it
    # stopped; reload n2c2_df from the CSV to start from scratch.
    if 'palm2_one_shot_time' in n2c2_df.columns and pd.notna(n2c2_df.loc[i, 'palm2_one_shot_time']):
        continue
    # Two requests are issued per row, so pause after every 20th row index to
    # stay under the per-minute quota.
    if (i + 1) % 20 == 0: # in case of quota limit error per minute
        time.sleep(65)
    sentence = n2c2_df.loc[i, 'text']
    n2c2_df.loc[i, 'palm2_zero_shot'], n2c2_df.loc[i, 'palm2_zero_shot_time'] = get_re_2018_n2c2(sentence, 0)
    n2c2_df.loc[i, 'palm2_one_shot'], n2c2_df.loc[i, 'palm2_one_shot_time'] = get_re_2018_n2c2(sentence, 1)

### 2.1.2 Evaluation

In [None]:
def _strip_wrapping_quotes(answer):
    """Remove one pair of single quotes that encloses the whole answer.

    Surrounding whitespace is trimmed first so a reply such as " 'ROUTE-DRUG'"
    is still recognised. Text that merely contains an apostrophe somewhere is
    left intact (slicing off the first and last character would corrupt it),
    and non-string cells (e.g. NaN) are passed through untouched.
    """
    if not isinstance(answer, str):
        return answer
    answer = answer.strip()
    if len(answer) >= 2 and answer[0] == "'" and answer[-1] == "'":
        return answer[1:-1]
    return answer


# get rid of ' ' if any
n2c2_df['palm2_zero_shot'] = n2c2_df['palm2_zero_shot'].apply(_strip_wrapping_quotes)
n2c2_df['palm2_one_shot'] = n2c2_df['palm2_one_shot'].apply(_strip_wrapping_quotes)

In [None]:
# Run get_digit over the gold labels and both model outputs. Per the original
# note, failed LLM outputs are meant to end up as 'No relation' (the mapping
# itself lives in utils.get_digit). 'labels' is overwritten in place.
for source_col, label_col in (
    ('labels', 'labels'),
    ('palm2_zero_shot', 'palm2_zero_shot_labels'),
    ('palm2_one_shot', 'palm2_one_shot_labels'),
):
    n2c2_df[label_col] = n2c2_df[source_col].apply(get_digit)

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# n2c2_df = pd.read_csv("data/ER/2018_n2c2/test_200_palm2_results.csv")

In [None]:
# Score the one-shot predictions against the gold labels
# (use 'palm2_zero_shot_labels' to score the zero-shot run instead).
y_true, y_pred = (
    n2c2_df[column].tolist() for column in ('labels', 'palm2_one_shot_labels')
)
print(classification_report(y_true, y_pred, digits=4))

In [None]:
# Mean per-request latency recorded by the inference loop, one line per prompting mode.
zero_shot_mean = n2c2_df['palm2_zero_shot_time'].mean()
print(f"Average PaLM 2 zero-shot prediction time: {zero_shot_mean:.2f} seconds")

one_shot_mean = n2c2_df['palm2_one_shot_time'].mean()
print(f"Average PaLM 2 one-shot prediction time: {one_shot_mean:.2f} seconds")

In [None]:
# Save the inference results: writes every column currently in n2c2_df (without the index).
# This is the same file the optional "load the llm output" cell reads back.
n2c2_df.to_csv('data/ER/2018_n2c2/test_200_palm2_results.csv', index=False)

## 2.2 GAD

### 2.2.1 Inference

In [None]:
# Sentences for the GAD relation experiment; the 'text' column is sent to the model and 'labels' is used for scoring.
gad_df = pd.read_csv('data/ER/GAD/test_200.csv')

In [None]:
def get_re_gad(sentence: str, shot: int = 0) -> "tuple[str, float]":
    """Ask the PaLM 2 chat model whether the tagged gene and disease are related.

    Args:
        sentence: Text sent to the model as the chat message; per the prompt,
            the entities appear as ``@GENE$`` and ``@DISEASE$``.
        shot: ``1`` attaches the single worked example to the chat; any other
            value sends no examples (zero-shot).

    Returns:
        A ``(reply_text, seconds)`` tuple: the model's reply (the prompt asks
        for ``0`` or ``1``) and the elapsed time of the ``send_message`` call.
    """

    parameters = {
        "temperature": 0.0,
    }

    chat = chat_model.start_chat(
        context="""
                "TASK: the task is to classify relations between a disease and a gene for a sentence."
                "INPUT: the input is a sentence where the disease is labeled as @DISEASE$ and the gene is labeled as @GENE$ accordingly in a sentence. "
                "OUTPUT: your task is to select one out of the two types of relations (0 and 1) for the gene and disease without any explanation or other characters: \n \
                        0, no relations \n \
                        1, has relations"
                """,
        examples=[
            InputOutputTextPair(
                input_text="We found evidence for association between @GENE$ and COGA @DISEASE$, history of blackouts, age at first drunkenness, and level of response to alcohol.",
                output_text='1',
            ),
        ] if shot == 1 else []
    )

    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    time_start = time.perf_counter()
    response = chat.send_message(
        sentence, **parameters
    )
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Run zero-shot and one-shot relation classification for every row of gad_df.
# Row labels are assumed to be 0..len-1 (the default index produced by read_csv),
# which is why label-based .loc is used for both reading and writing.
for i in tqdm(range(len(gad_df))):
    # Resume support: 'palm2_one_shot_time' is the last value written for a row,
    # so a non-null entry means the row finished in an earlier (interrupted) run.
    # Re-running this cell after a quota error therefore continues where it
    # stopped; reload gad_df from the CSV to start from scratch.
    if 'palm2_one_shot_time' in gad_df.columns and pd.notna(gad_df.loc[i, 'palm2_one_shot_time']):
        continue
    # Two requests are issued per row, so pause after every 20th row index to
    # stay under the per-minute quota.
    if (i + 1) % 20 == 0: # in case of quota limit error per minute
        time.sleep(65)
    sentence = gad_df.loc[i, 'text']
    gad_df.loc[i, 'palm2_zero_shot'], gad_df.loc[i, 'palm2_zero_shot_time'] = get_re_gad(sentence, 0)
    gad_df.loc[i, 'palm2_one_shot'], gad_df.loc[i, 'palm2_one_shot_time'] = get_re_gad(sentence, 1)

### 2.2.2 Evaluation

In [None]:
def _parse_int_or_zero(answer) -> int:
    """Turn a model reply into an int label, falling back to 0 ('No relation').

    The value is stringified and trimmed before the digit check, so that
    (a) a reply padded with whitespace such as " 1" is not misread as a failed
    output, and (b) re-running this cell on already-converted integer values
    does not raise AttributeError on ``int.isdigit``.
    """
    cleaned = str(answer).strip()
    return int(cleaned) if cleaned.isdigit() else 0


# convert some strings to int while considering failed LLM outputs as 'No relation (0)'
gad_df['palm2_zero_shot'] = gad_df['palm2_zero_shot'].apply(_parse_int_or_zero)
gad_df['palm2_one_shot'] = gad_df['palm2_one_shot'].apply(_parse_int_or_zero)

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# gad_df = pd.read_csv("data/ER/GAD/test_200_palm2_results.csv")

In [None]:
# Score the one-shot predictions against the gold labels
# (use 'palm2_zero_shot' to score the zero-shot run instead).
y_true, y_pred = (
    gad_df[column].tolist() for column in ('labels', 'palm2_one_shot')
)
print(classification_report(y_true, y_pred, digits=4))

In [None]:
# Mean per-request latency recorded by the inference loop, one line per prompting mode.
zero_shot_mean = gad_df['palm2_zero_shot_time'].mean()
print(f"Average PaLM 2 zero-shot prediction time: {zero_shot_mean:.2f} seconds")

one_shot_mean = gad_df['palm2_one_shot_time'].mean()
print(f"Average PaLM 2 one-shot prediction time: {one_shot_mean:.2f} seconds")

In [None]:
# Save the inference results: writes every column currently in gad_df (without the index).
# This is the same file the optional "load the llm output" cell reads back.
gad_df.to_csv('data/ER/GAD/test_200_palm2_results.csv', index=False)