# Test Gemini vs Ensemble for MMLU


to do:

-ignore claude, add gemma, palm

In [1]:
import copy
from pprint import pprint
from datetime import datetime
from datasets import load_dataset

from utils import Annotate
from config import PALM_CONFIG, GEMINI_CONFIG

In [2]:
seed = 420
now = datetime.now().strftime("%Y%m%d")

In [3]:
dataset = load_dataset("cais/mmlu", "all")

# # take a small sample for dev purposes
dataset = dataset['test'].shuffle(seed=seed).select(range(5))

# user provided data description
DESCRIPTION = """
This is a massive multitask test consisting of multiple-choice questions from various branches of knowledge.
The test spans subjects in the humanities, social sciences, hard sciences, and other areas that are important for some people to learn.
To attain high accuracy on this test, models must possess extensive world knowledge and problem solving ability.
This covers 57 subjects  across STEM, the humanities, the social sciences, and more. 
It ranges in difficulty from an elementary level to an advanced professional level, and it tests both world knowledge and problem solving ability. 
Subjects range from traditional areas, such as mathematics and history, to more specialized areas like law and ethics.
"""


In [4]:
gemini_prompt_template = """
<QUESTION>
{datapoint}
</QUESTION>
------------

<CHOICES>
{labels}
</choices>
------------

INSTRUCTION:
- read the above question carefully.
- you are given 4 choices seperated by comma in <CHOICES>.
- take your time and pick the precise correct answer from <CHOICES> for the given <QUESTION>.
- remember that there is always only one correct answer.
- return the exact correct answer from <CHOICES>. Don't provide explanations.
"""

In [5]:
prompt = [gemini_prompt_template.format(description= DESCRIPTION,
                                        datapoint=x['question'],
                                        labels=x['choices']) for x in dataset]
print(len(prompt))

5


In [8]:
models = [
    "palm",
    "gemini",
    # "claude"
    ]


palm_1 =  copy.deepcopy(PALM_CONFIG)
palm_1['config_name'] = "palm_1"
palm_2 =  copy.deepcopy(PALM_CONFIG)
palm_2['config_name'] = "palm_2"

gemini_1 =  copy.deepcopy(GEMINI_CONFIG)
gemini_1['config_name'] = "gemini_1"
gemini_2 =  copy.deepcopy(GEMINI_CONFIG)
gemini_2['config_name'] = "gemini_2"



model_config = {
    "gemini": [gemini_1, gemini_2],
    "palm": [palm_1, palm_2]
    # "palm": PALM_CONFIG
}

ann = Annotate()


In [9]:
output_dict = await ann.classification(prompt, models, model_config)

Creating tasks: 100%|██████████| 20/20 [00:00<00:00, 65179.55it/s]
Gathering palm_palm_1 results: 100%|██████████| 5/5 [00:13<00:00,  2.60s/it]
Gathering palm_palm_2 results: 100%|██████████| 5/5 [00:00<00:00, 26990.37it/s]
Gathering gemini_gemini_1 results:   0%|          | 0/5 [00:00<?, ?it/s]2024-05-29 18:54:09,407/Annotate[ERROR]: gemini_gemini_1 Task 0 failed: Cannot get the response text.
Cannot get the Candidate text.
Response candidate content has no parts (and thus no text). The candidate is likely blocked by the safety filters.
Content:
{}
Candidate:
{
  "finish_reason": "OTHER"
}
Response:
{
  "candidates": [
    {
      "finish_reason": "OTHER"
    }
  ],
  "usage_metadata": {
    "prompt_token_count": 168,
    "total_token_count": 168
  }
}
Gathering gemini_gemini_1 results: 100%|██████████| 5/5 [00:00<00:00, 3913.33it/s]
Gathering gemini_gemini_2 results:   0%|          | 0/5 [00:00<?, ?it/s]2024-05-29 18:54:09,410/Annotate[ERROR]: gemini_gemini_2 Task 0 failed: Cannot ge

In [10]:
output_dict

{'palm_palm_1': [' Homo erectus.',
  ' simple',
  ' Perform preliminary analytical procedures to identify accounts that may represent specific risks relevant to the engagement.',
  " 'All of these options.'",
  ' Deep pyro sequencing (NGS)'],
 'palm_palm_2': [' Homo erectus.',
  ' simple',
  ' Perform preliminary analytical procedures to identify accounts that may represent specific risks relevant to the engagement.',
  " 'All of these options.'",
  ' Deep pyro sequencing (NGS)'],
 'gemini_gemini_1': [None,
  "'compound'",
  '"Make inquiries of management concerning the entity\'s procedures used in adjusting and closing the books of account."',
  'All of these options.',
  "'Deep pyro sequencing (NGS)'"],
 'gemini_gemini_2': [None,
  "'compound'",
  '"Make inquiries of management concerning the entity\'s procedures used in adjusting and closing the books of account."',
  '"All of these options."',
  'Deep pyro sequencing (NGS)']}

In [15]:
llm_response = {}

for k in output_dict.keys():
    llm_response[k] = []
    for idx, r in enumerate(output_dict[k]):
        if r is not None:
            stripped_r = r.strip().strip("'")
            if stripped_r in dataset['choices'][idx]:
                llm_response[k].append(dataset['choices'][idx].index(stripped_r))
            else:
                # Handle case where stripped_r is not found in choices
                llm_response[k].append(None)
        else:
            # Handle None values appropriately
            llm_response[k].append(None)

llm_response

{'palm_palm_1': [2, 0, 1, 3, 2],
 'palm_palm_2': [2, 0, 1, 3, 2],
 'gemini_gemini_1': [None, 1, None, 3, 2],
 'gemini_gemini_2': [None, 1, None, None, 2]}

#  GLAD

In [16]:
from utils import glad

In [17]:
num_labels = sum(len(lst) for lst in llm_response.values())
num_labelers =  len(list(llm_response.values())[0])
num_tasks = len(llm_response)
num_classes = 4
z  = 1/num_classes


tc = [num_labels, num_labelers, num_tasks, num_classes]
tc.extend([z] * tc[-1])

tc

[20, 5, 4, 4, 0.25, 0.25, 0.25, 0.25]

In [None]:
# import json

# with open('./data/output/annotation_output__20240515.json', 'r') as file:
#     ddddd = json.load(file)

In [18]:
glad(llm_response, tc)

IndexError: index 4 is out of bounds for axis 0 with size 4