# Test Gemini vs Ensemble for MMLU


to do:

-ignore claude, add gemma, palm

In [1]:
import sys
sys.path.append('/usr/local/google/home/amirimani/Desktop/projects/llm_annotator/')


import ast
import os
import copy
import json
import random
import pandas as pd
from pprint import pprint
from datetime import datetime
from datasets import load_dataset

from utils import Annotate, Evaluate
from config import PALM_CONFIG, GEMINI_CONFIG, CLAUDE_CONFIG

In [2]:
seed = 42
now = datetime.now().strftime("%Y%m%d")

In [3]:
dataset = load_dataset("bigbio/bc5cdr", trust_remote_code=True)

# Access the different splits (train, validation, test)
train_data = dataset["train"]
validation_data = dataset["validation"]
test_data = dataset["test"]

In [4]:
full_abstracts_and_entities = []

for example in dataset['train']:
    passages = example['passages']
    abstract_text = ""
    entities = []

    # Extract the abstract text and entities
    for passage in passages:
        if passage['type'] == 'abstract':
            abstract_text = passage['text']
        entities.extend(passage['entities'])

    full_abstracts_and_entities.append({
        'abstract': abstract_text,
        'entities': entities
    })



full_abstracts_and_entities[0]

{'abstract': 'In unanesthetized, spontaneously hypertensive rats the decrease in blood pressure and heart rate produced by intravenous clonidine, 5 to 20 micrograms/kg, was inhibited or reversed by nalozone, 0.2 to 2 mg/kg. The hypotensive effect of 100 mg/kg alpha-methyldopa was also partially reversed by naloxone. Naloxone alone did not affect either blood pressure or heart rate. In brain membranes from spontaneously hypertensive rats clonidine, 10(-8) to 10(-5) M, did not influence stereoselective binding of [3H]-naloxone (8 nM), and naloxone, 10(-8) to 10(-4) M, did not influence clonidine-suppressible binding of [3H]-dihydroergocryptine (1 nM). These findings indicate that in spontaneously hypertensive rats the effects of central alpha-adrenoceptor stimulation involve activation of opiate receptors. As naloxone and clonidine do not appear to interact with the same receptor site, the observed functional antagonism suggests the release of an endogenous opiate by clonidine or alpha-m

In [5]:
gemini_prompt_template = """
    You are an accurate and context-aware biomedical information extractor.

    BACKGROUND:
    The BC5CDR (BioCreative V Chemical Disease Relation) dataset is a collection of PubMed articles annotated with mentions of chemicals, diseases, and their relationships.

    TASK:
    You are given an abstract from the BC5CDR dataset. Extract all chemical and disease entities and format them as a JSON array of dictionaries. Each dictionary must have these keys:

    * offsets: A list of lists. Each inner list has two integers: the start and end character positions of the entity in the text (the first character is at position 0).
    * text: A list containing the string of the entity.
    * type: A string, either "Chemical" or "Disease".

    EXAMPLES:

    Input: "Doxorubicin is used to treat breast cancer in patients."
    Output:
    ```json
    [
      {{
        "offsets": [[0, 11]],
        "text": ["Doxorubicin"],
        "type": "Chemical"
      }},
      {{
        "offsets": [[28, 42]],
        "text": ["breast cancer"],
        "type": "Disease"
      }},
      {{
        "offsets": [[28, 49]],
        "text": ["breast cancer patients"],
        "type": "Disease"
      }}
    ]
    ```

    Input: "Alcohol consumption can lead to liver disease."
    Output:
    ```json
    [
      {{
        "offsets": [[0, 7]],
        "text": ["Alcohol"],
        "type": "Chemical"
      }},
      {{
        "offsets": [[31, 44]],
        "text": ["liver disease"],
        "type": "Disease"
      }}
    ]
    ```

    Input: "The patient showed no evidence of pneumonia."
    Output:
    ```json
    [
      {{
        "offsets": [[33, 42]],
        "text": ["pneumonia"],
        "type": "Disease"
      }}
    ]
    ```

    Input: "This study investigates the efficacy of a new surgical procedure."
    Output: []

    NOTES:

    * If a span of text can be annotated as both a chemical and a disease, create separate entries for each type.
    * Do not normalize chemical names or expand abbreviations/acronyms.
    * Include negated mentions as regular entities.

-----
    Abstract:
    {abstract}

"""


In [6]:
prompt = [gemini_prompt_template.format(abstract=x['abstract']) for x in full_abstracts_and_entities[:10]]
print(len(prompt))

10


In [7]:
ann = Annotate()

In [8]:
GEMINI_CONFIG["project_config"]["qpm"] = 100
GEMINI_CONFIG['generation_config']['response_mime_type'] = "application/json"

GEMINI_CONFIG


{'config_name': 'default',
 'model': 'gemini-1.5-pro',
 'project_config': {'qpm': 100,
  'project': 'amir-genai-bb',
  'location': 'us-central1'},
 'generation_config': {'max_output_tokens': 2048,
  'temperature': 0.4,
  'top_p': 1,
  'response_mime_type': 'application/json'}}

In [9]:
output_dict = await ann.classification(prompt, ['gemini'], {'gemini': [GEMINI_CONFIG]})

predicted_entities = [ast.literal_eval(x) for x in output_dict['gemini_default']]

Creating tasks: 100%|██████████| 10/10 [00:00<00:00, 18808.54it/s]
Gathering gemini_default results: 100%|██████████| 10/10 [01:23<00:00,  8.40s/it]


In [10]:
def extract_entity_details(data):
    extracted_entities = []

    for entry in data:
        entities = entry.get('entities', [])
        row_entities = []

        for entity in entities:
            entity_details = {
                'offsets': entity.get('offsets', []),
                'text': entity.get('text', []),
                'type': entity.get('type', '')
            }
            row_entities.append(entity_details)

        extracted_entities.append(row_entities)
    return extracted_entities

In [11]:
true_entities = extract_entity_details(full_abstracts_and_entities[:10])

In [12]:
assert len(predicted_entities) == len(true_entities)

In [13]:
evl = Evaluate()

In [14]:
evl.bootstrap_evaluation(predicted_entities, true_entities)

{'Precision CI': array([0.35     , 0.5737013]),
 'Recall CI': array([0.46341463, 0.64102564]),
 'F1 Score CI': array([0.41162971, 0.57123823]),
 'Exact Match Ratio CI': array([0.46341463, 0.64102564])}