# Test Gemini vs Ensemble for MMLU


To do:

- Ignore Claude; add Gemma and PaLM.

In [8]:
import copy
from pprint import pprint
from datetime import datetime
from datasets import load_dataset

from utils import Annotate
from config import PALM_CONFIG, GEMINI_CONFIG

In [9]:
# Random seed for the dataset shuffle, and today's date stamped as YYYYMMDD.
seed = 42
now = f"{datetime.now():%Y%m%d}"

In [10]:
# Load MMLU (the "all" configuration) and keep a small, reproducible random
# sample of the test split for development. Set SAMPLE_SIZE = None to use the
# whole (shuffled) test split.
SAMPLE_SIZE = 5
dataset = load_dataset("cais/mmlu", "all")
dataset = dataset['test'].shuffle(seed=seed)
if SAMPLE_SIZE is not None:
    # select() raises on indices past the end, so cap at the split length.
    dataset = dataset.select(range(min(SAMPLE_SIZE, len(dataset))))

# User-provided description of the dataset.
# NOTE(review): gemini_prompt_template has no {description} placeholder, so
# this text is passed to str.format but never appears in the prompts.
DESCRIPTION = """
This is a massive multitask test consisting of multiple-choice questions from various branches of knowledge.
The test spans subjects in the humanities, social sciences, hard sciences, and other areas that are important for some people to learn.
To attain high accuracy on this test, models must possess extensive world knowledge and problem solving ability.
This covers 57 subjects  across STEM, the humanities, the social sciences, and more. 
It ranges in difficulty from an elementary level to an advanced professional level, and it tests both world knowledge and problem solving ability. 
Subjects range from traditional areas, such as mathematics and history, to more specialized areas like law and ethics.
"""


In [11]:
# Prompt sent to every model for one MMLU question. Placeholders:
#   {datapoint} - the question text
#   {labels}    - the answer choices; the notebook passes x['choices']
#                 directly, so it is rendered as a Python list repr.
# Extra keyword arguments given to .format() (e.g. description=) are ignored.
gemini_prompt_template = """
<QUESTION>
{datapoint}
</QUESTION>
------------

<CHOICES>
{labels}
</CHOICES>
------------

INSTRUCTION:
- read the above question carefully.
- you are given 4 choices separated by comma in <CHOICES>.
- take your time and pick the precise correct answer from <CHOICES> for the given <QUESTION>.
- remember that there is always only one correct answer.
- return the exact correct answer from <CHOICES>. Don't provide explanations.
"""

In [12]:
def _render_prompt(example):
    """Fill the prompt template with one dataset row's question and choices.

    NOTE(review): the template has no {description} placeholder, so
    str.format silently ignores ``description=`` and DESCRIPTION is not
    part of the rendered prompt. Confirm whether that is intended.
    """
    return gemini_prompt_template.format(
        description=DESCRIPTION,
        datapoint=example['question'],
        labels=example['choices'],
    )


# One rendered prompt per sampled question, in dataset order.
prompt = list(map(_render_prompt, dataset))
print(len(prompt))

5


In [13]:
# Models to query, in this order; each name needs an entry in model_config.
# "claude" is currently disabled.
models = ["palm", "gemini"]

# Two independently named deep copies of the PaLM config, for a possible
# PaLM ensemble. NOTE(review): neither copy is referenced by model_config yet.
palm_1 = copy.deepcopy(PALM_CONFIG)
palm_2 = copy.deepcopy(PALM_CONFIG)
palm_1['name'] = "palm_1"
palm_2['name'] = "palm_2"

# Per-model settings handed to Annotate.classification.
# (An ensemble variant would map "palm" to [palm_1, palm_2] instead.)
model_config = dict(gemini=GEMINI_CONFIG, palm=PALM_CONFIG)

ann = Annotate()


In [14]:
# Ask every model in `models` to answer every prompt; the result is indexed by
# model name (output_dict[m]) with one response per prompt.
# Top-level `await` requires an environment that supports it (e.g. IPython/Jupyter).
# NOTE(review): in the recorded run gemini task 0 was blocked
# (finish_reason "OTHER"); check what Annotate stores for a failed task.
output_dict = await ann.classification(prompt, models, model_config)

Creating tasks: 20it [00:00, 30404.52it/s]            
Gathering palm results: 100%|██████████| 5/5 [00:06<00:00,  1.23s/it]
Gathering gemini results:   0%|          | 0/5 [00:00<?, ?it/s]2024-05-29 02:28:24,393/Annotate[ERROR]: gemini Task 0 failed: Cannot get the response text.
Cannot get the Candidate text.
Response candidate content has no parts (and thus no text). The candidate is likely blocked by the safety filters.
Content:
{}
Candidate:
{
  "finish_reason": "OTHER"
}
Response:
{
  "candidates": [
    {
      "finish_reason": "OTHER"
    }
  ],
  "usage_metadata": {
    "prompt_token_count": 217,
    "total_token_count": 217
  }
}
Gathering gemini results: 100%|██████████| 5/5 [00:00<00:00, 2210.32it/s]


In [None]:
# Display the raw per-model responses as the cell output.
output_dict

In [None]:
def _choice_index(response, choices):
    """Return the index in ``choices`` matching a raw model answer, or None.

    The choices are shown to the model as a Python list repr, which wraps
    strings containing a single quote in double quotes. So after trimming
    whitespace, we try removing single quotes (the original normalization),
    then double quotes, then the text as-is.
    """
    if not isinstance(response, str):
        # e.g. a failed/blocked request that produced no text
        return None
    text = response.strip()
    for candidate in (text.strip("'"), text.strip('"'), text):
        if candidate in choices:
            return choices.index(candidate)
    return None


def _responses_to_indices(model, responses, choices_column):
    """Map one model's responses (one per dataset row) to choice indices.

    Raises:
        ValueError: if a response cannot be matched to any of its row's choices.
    """
    indices = []
    for idx, response in enumerate(responses):
        label = _choice_index(response, choices_column[idx])
        if label is None:
            raise ValueError(
                f"could not map {model} response #{idx} ({response!r}) "
                f"to one of {choices_column[idx]!r}"
            )
        indices.append(label)
    return indices


# Fetch the choices column once; indexing dataset['choices'] inside the loop
# re-materializes the whole column for every row.
choices_column = dataset['choices']
llm_response = {
    m: _responses_to_indices(m, output_dict[m], choices_column) for m in models
}
llm_response

# GLAD (Generative model of Labels, Abilities, and Difficulties)

In [None]:
from utils import glad

In [None]:
# Header list for the GLAD call: [#labels, #labelers, #tasks, #classes]
# followed by one prior per class. llm_response maps each model (a labeler)
# to a list with one answer index per question (a task).
num_classes = 4  # each MMLU question has four answer choices
label_lists = list(llm_response.values())
if not label_lists:
    raise ValueError("llm_response is empty; nothing to aggregate")

num_labelers = len(llm_response)    # one labeler per model
num_tasks = len(label_lists[0])     # one task per question
if any(len(lst) != num_tasks for lst in label_lists):
    raise ValueError("every model must label the same number of questions")
num_labels = sum(len(lst) for lst in label_lists)
z = 1 / num_classes  # uniform prior over classes


tc = [num_labels, num_labelers, num_tasks, num_classes]
tc.extend([z] * num_classes)

tc

In [None]:
# import json

# with open('./data/output/annotation_output__20240515.json', 'r') as file:
#     ddddd = json.load(file)

In [None]:
# Run utils.glad on the per-model answer indices together with the header/prior
# list `tc`. NOTE(review): glad's output format is defined in utils; not visible here.
glad(llm_response, tc)