# Test Gemini vs Ensemble for MMLU


To do:

- Ignore Claude; add Gemma and PaLM.

In [None]:
import copy
from pprint import pprint
from datetime import datetime
from datasets import load_dataset

from utils import Annotate
from config import PALM_CONFIG, GEMINI_CONFIG

In [None]:
# Seed handed to Dataset.shuffle so the sampled subset is reproducible.
seed = 420
# Current local date as a YYYYMMDD string.
now = f"{datetime.now():%Y%m%d}"

In [None]:
# Load MMLU (all subjects), keep only the test split, and draw a small
# seeded sample of 5 rows for development purposes.
dataset = (
    load_dataset("cais/mmlu", "all")["test"]
    .shuffle(seed=seed)
    .select(range(5))
)

# User-provided description of the dataset.
# NOTE(review): this is passed to the prompt template's format() call as
# `description`, but the template defines no {description} placeholder, so
# the text currently never reaches the model — confirm whether intended.
DESCRIPTION = """
This is a massive multitask test consisting of multiple-choice questions from various branches of knowledge.
The test spans subjects in the humanities, social sciences, hard sciences, and other areas that are important for some people to learn.
To attain high accuracy on this test, models must possess extensive world knowledge and problem solving ability.
This covers 57 subjects  across STEM, the humanities, the social sciences, and more. 
It ranges in difficulty from an elementary level to an advanced professional level, and it tests both world knowledge and problem solving ability. 
Subjects range from traditional areas, such as mathematics and history, to more specialized areas like law and ethics.
"""


In [None]:
# Prompt sent to each model for a single question. It is filled in with
# str.format using `datapoint` (the question text) and `labels` (the answer
# choices).
# NOTE(review): the formatting cell also passes `description=`, but there is
# no {description} placeholder here, so the dataset description is not part
# of the prompt — confirm whether that is intended.
gemini_prompt_template = """
<QUESTION>
{datapoint}
</QUESTION>
------------

<CHOICES>
{labels}
</CHOICES>
------------

INSTRUCTION:
- read the above question carefully.
- you are given 4 choices separated by comma in <CHOICES>.
- take your time and pick the precise correct answer from <CHOICES> for the given <QUESTION>.
- remember that there is always only one correct answer.
- return the exact correct answer from <CHOICES>. Don't provide explanations.
"""

In [None]:
# Render one prompt per sampled row. `description` is handed to format()
# even though the template has no matching placeholder; str.format simply
# ignores unused keyword arguments.
prompt = []
for row in dataset:
    rendered = gemini_prompt_template.format(
        description=DESCRIPTION,
        datapoint=row['question'],
        labels=row['choices'],
    )
    prompt.append(rendered)
print(len(prompt))

In [None]:
# Model families passed to Annotate.classification.
models = [
    "palm",
    "gemini",
    # "claude"
]

# Two PaLM variants built from independent deep copies of the base config:
# the stock settings, and one with the sampling temperature raised to 0.9.
palm_1 = copy.deepcopy(PALM_CONFIG)
palm_1['config_name'] = "palm_1"
palm_2 = copy.deepcopy(PALM_CONFIG)
palm_2['config_name'] = "palm_2"
palm_2["generation_config"]['temperature'] = 0.9

# Two Gemini variants: the stock settings, and one pointing at
# gemini-1.5-flash-001.
gemini_1 = copy.deepcopy(GEMINI_CONFIG)
gemini_1['config_name'] = "gemini_1"
gemini_2 = copy.deepcopy(GEMINI_CONFIG)
gemini_2['config_name'] = "gemini_2"
# The key was previously written as '"model"' (with literal double quotes
# inside the string), which added a stray entry instead of overriding the
# model name, leaving gemini_2 on the same model as gemini_1.
# NOTE(review): assumes GEMINI_CONFIG stores the model name under 'model' —
# confirm against config.py.
gemini_2['model'] = "gemini-1.5-flash-001"

# Config variants to run, grouped by model family.
model_config = {
    "gemini": [
        gemini_1,
        gemini_2,
    ],
    "palm": [
        palm_1,
        palm_2,
    ],
}

ann = Annotate()


In [None]:
# Run classification for all prompts across the configured models.
# Top-level `await` only works where an event loop is already running
# (e.g. a notebook cell), not in a plain script.
# NOTE(review): the result is consumed below as a dict whose values are
# sequences of per-prompt answers (str or None) — confirm against Annotate.
output_dict = await ann.classification(prompt, models, model_config)

In [None]:
# Display the raw model outputs for manual inspection.
output_dict

In [None]:
# Convert each raw text answer into the integer index of the matching entry
# in that question's `choices` list. None marks a missing answer or one that
# does not exactly match any choice.
llm_response = {}

# Fetch the choices column once: indexing dataset['choices'] inside the loop
# can re-read the entire column for every single response.
all_choices = dataset['choices']

for k, responses in output_dict.items():
    indices = []
    for idx, r in enumerate(responses):
        if r is None:
            # No answer came back for this prompt.
            indices.append(None)
            continue
        # Drop surrounding whitespace and single quotes before matching.
        stripped_r = r.strip().strip("'")
        try:
            indices.append(all_choices[idx].index(stripped_r))
        except ValueError:
            # Answer text is not one of the choices (the match is exact).
            indices.append(None)
    llm_response[k] = indices

llm_response

#  GLAD

In [None]:
from utils import glad

In [None]:
def generate_task_config(response_dict: dict, num_classes: int) -> list:
    """Build the flat task-configuration list that is passed to ``glad``.

    The result is ``[num_labels, num_labelers, num_tasks, num_classes]``
    followed by ``num_classes`` copies of the uniform prior
    ``1 / num_classes``.

    Args:
        response_dict: Mapping whose values are sequences of labels.
            ``num_labels`` is the total number of entries across all values,
            ``num_tasks`` is the number of keys, and ``num_labelers`` is the
            length of the first value (other values are not checked).
        num_classes: Number of possible label classes.

    Returns:
        The configuration list described above. An empty ``response_dict``
        yields zero labels, labelers and tasks.

    Raises:
        ZeroDivisionError: If ``num_classes`` is 0.

    NOTE(review): as written, keys are counted as tasks and the entries of
    each value as labelers. The notebook calls this with a dict that appears
    to be keyed by model, with one entry per question — confirm which
    orientation ``glad`` expects.
    """
    label_lists = response_dict.values()
    num_labels = sum(len(labels) for labels in label_lists)
    # An empty mapping has no first entry; fall back to an empty sequence
    # (zero labelers) instead of raising IndexError.
    num_labelers = len(next(iter(label_lists), ()))
    num_tasks = len(response_dict)
    prior = 1 / num_classes

    header = [num_labels, num_labelers, num_tasks, num_classes]
    return header + [prior] * num_classes

In [None]:
# Build the GLAD task configuration; 4 classes matches the four answer
# choices per question stated in the prompt template.
task_conf = generate_task_config(llm_response, 4)
task_conf

In [None]:
# import json

# with open('./data/output/annotation_output__20240515.json', 'r') as file:
#     sample_data = json.load(file)


# task_conf = generate_task_config(sample_data, 2)
# sample_data, task_conf

In [None]:
# Run glad over the parsed model answers with the configuration built above.
# NOTE(review): llm_response can contain None for missing or unmatched
# answers — confirm that glad handles None labels.
glad(llm_response, task_conf)