# Test Gemini vs Ensemble for MMLU


to do:

-ignore claude, add gemma, palm

In [3]:
import copy
import random
from pprint import pprint
from datetime import datetime
from datasets import load_dataset

from utils import Annotate
from config import PALM_CONFIG, GEMINI_CONFIG, CLAUDE_CONFIG

In [4]:
seed = 420
now = datetime.now().strftime("%Y%m%d")

In [5]:
dataset = load_dataset("cais/mmlu", "all")

# # take a small sample for dev purposes
dataset = dataset['test'].shuffle(seed=seed).select(range(1000))

# # user provided data description
# DESCRIPTION = """
# This is a massive multitask test consisting of multiple-choice questions from various branches of knowledge.
# The test spans subjects in the humanities, social sciences, hard sciences, and other areas that are important for some people to learn.
# To attain high accuracy on this test, models must possess extensive world knowledge and problem solving ability.
# This covers 57 subjects  across STEM, the humanities, the social sciences, and more. 
# It ranges in difficulty from an elementary level to an advanced professional level, and it tests both world knowledge and problem solving ability. 
# Subjects range from traditional areas, such as mathematics and history, to more specialized areas like law and ethics.
# """


In [6]:
gemini_prompt_template = """
<QUESTION>
{datapoint}
</QUESTION>
------------

<CHOICES>
{labels}
</choices>
------------

INSTRUCTION:
- read the above question carefully.
- you are given 4 choices seperated by comma in <CHOICES>.
- take your time and pick the precise correct answer from <CHOICES> for the given <QUESTION>.
- remember that there is always only one correct answer.
- return the exact correct answer from <CHOICES>. Don't provide explanations.
"""

In [7]:
prompt = [gemini_prompt_template.format(datapoint=x['question'],
                                        labels=x['choices']) for x in dataset]
print(len(prompt))

1000


In [16]:
models = [
    "palm",
    "gemini",
    # "claude"
    ]

PALM_CONFIG["project_config"]["qpm"] = 150

palm_1 =  copy.deepcopy(PALM_CONFIG)
palm_1['config_name'] = "temp_0.4"

palm_2 =  copy.deepcopy(PALM_CONFIG)
palm_2['config_name'] = "temp_0.9"
palm_2["generation_config"]['temperature'] = 0.9

GEMINI_CONFIG["project_config"]["qpm"] = 100

gemini_1 =  copy.deepcopy(GEMINI_CONFIG)
gemini_1['config_name'] = "-1.0-pro-002"

gemini_2 =  copy.deepcopy(GEMINI_CONFIG)
gemini_2['config_name'] = "-1.5-flash-001"
gemini_2['"model"'] = "gemini-1.5-flash-001"

gemini_3 =  copy.deepcopy(GEMINI_CONFIG)
gemini_3['config_name'] = "-1.0-ultra-001"
gemini_3['"model"'] = "gemini-1.0-ultra-001"



model_config = {
    "gemini": [
        gemini_1,
        gemini_2,
        gemini_3
         ],
    "palm": [
        palm_1, 
        palm_2
        ],
    # "claude": [
    #     CLAUDE_CONFIG
    # ]
}

ann = Annotate()


In [17]:
output_dict = await ann.classification(prompt, models, model_config)

Creating tasks: 100%|██████████| 5000/5000 [00:00<00:00, 17846.98it/s]
Gathering palm_temp_0.4 results: 100%|██████████| 1000/1000 [42:25<00:00,  2.55s/it]   
Gathering palm_temp_0.9 results: 100%|██████████| 1000/1000 [00:00<00:00, 211683.86it/s]
Gathering gemini_-1.0-pro-002 results:   0%|          | 0/1000 [00:00<?, ?it/s]2024-05-29 22:07:00,395/Annotate[ERROR]: gemini_-1.0-pro-002 Task 0 failed: Cannot get the response text.
Cannot get the Candidate text.
Response candidate content has no parts (and thus no text). The candidate is likely blocked by the safety filters.
Content:
{}
Candidate:
{
  "finish_reason": "OTHER"
}
Response:
{
  "candidates": [
    {
      "finish_reason": "OTHER"
    }
  ],
  "usage_metadata": {
    "prompt_token_count": 168,
    "total_token_count": 168
  }
}
2024-05-29 22:07:00,397/Annotate[ERROR]: gemini_-1.0-pro-002 Task 99 failed: Cannot get the response text.
Cannot get the Candidate text.
Response candidate content has no parts (and thus no text). The

In [18]:
llm_response = {}

for k in output_dict.keys():
    llm_response[k] = []
    for idx, r in enumerate(output_dict[k]):
        if r is not None:
            stripped_r = r.strip().strip("'")
            if stripped_r in dataset['choices'][idx]:
                llm_response[k].append(dataset['choices'][idx].index(stripped_r))
            else:
                # Handle case where stripped_r is not found in choices
                llm_response[k].append(None)
        else:
            # Handle None values appropriately
            llm_response[k].append(None)

llm_response

{'palm_temp_0.4': [2,
  0,
  1,
  3,
  2,
  2,
  1,
  2,
  2,
  0,
  2,
  2,
  None,
  2,
  0,
  1,
  0,
  1,
  2,
  1,
  3,
  2,
  1,
  2,
  0,
  1,
  1,
  3,
  2,
  1,
  3,
  2,
  0,
  1,
  1,
  1,
  0,
  1,
  1,
  0,
  3,
  2,
  1,
  2,
  None,
  1,
  1,
  1,
  3,
  3,
  2,
  1,
  2,
  3,
  1,
  1,
  0,
  None,
  3,
  2,
  3,
  3,
  1,
  2,
  2,
  3,
  None,
  1,
  3,
  1,
  1,
  None,
  3,
  3,
  1,
  3,
  1,
  3,
  1,
  1,
  2,
  2,
  0,
  1,
  0,
  2,
  2,
  0,
  0,
  1,
  0,
  0,
  2,
  0,
  1,
  0,
  1,
  2,
  2,
  2,
  2,
  1,
  0,
  3,
  1,
  2,
  1,
  2,
  2,
  0,
  2,
  3,
  2,
  2,
  3,
  0,
  3,
  1,
  0,
  2,
  1,
  3,
  3,
  0,
  3,
  2,
  3,
  0,
  2,
  3,
  3,
  0,
  None,
  0,
  None,
  2,
  1,
  2,
  3,
  3,
  3,
  0,
  3,
  1,
  3,
  1,
  3,
  1,
  2,
  3,
  3,
  2,
  1,
  2,
  0,
  1,
  2,
  1,
  0,
  None,
  3,
  1,
  2,
  0,
  2,
  2,
  2,
  0,
  2,
  2,
  0,
  2,
  1,
  1,
  2,
  3,
  3,
  3,
  2,
  0,
  2,
  1,
  2,
  0,
  0,
  3,
  1,
  0,
  1,
  1,
  1,
  0,

In [19]:
def replace_none_with_random(data, n):
    for key, lst in data.items():
        for i, value in enumerate(lst):
            if value is None:
                lst[i] = random.randint(0, n - 1)



def convert_dict_to_indexed_list(data_dict):
    number_map = {key: index for index, key in enumerate(data_dict.keys())}
    max_len = len(next(iter(data_dict.values())))

    result = []
    for index in range(max_len):
        for key, value_list in data_dict.items():
            value = value_list[index]
            # Convert to 0 if not an integer
            converted_value = 0 if not isinstance(value, int) else value 
            result.append([index, number_map[key], converted_value])
    return result
    

def generate_task_config(response_dict, num_classes):

    num_labels = sum(len(lst) for lst in response_dict.values())
    num_tasks =  len(list(response_dict.values())[0])
    num_labelers = len(response_dict)
    z  = 1/num_classes


    tc = [num_labels, num_labelers, num_tasks, num_classes]
    tc.extend([z] * tc[-1])

    return tc

n_class = 4
replace_none_with_random(llm_response, n_class)

In [20]:
task_conf = generate_task_config(llm_response, n_class)
llm_result_list = convert_dict_to_indexed_list(llm_response)
llm_result_list.insert(0, task_conf)

In [21]:
task_conf

[5000, 5, 1000, 4, 0.25, 0.25, 0.25, 0.25]

In [22]:
filename = f"./data/llm_response__{now}.txt"

with open(filename, "w") as file:
    for sublist in llm_result_list:
        line = " ".join(str(num) for num in sublist)  # Convert to string, join with spaces
        file.write(line + "\n")  # Write line and add newline
filename

'./data/llm_response__20240529.txt'

#  GLAD

In [23]:
from utils import glad

In [24]:
file_name = "./data/llm_response__20240529.txt"

In [25]:
glad_output = glad(file_name)

In [26]:
glad_output.keys()

dict_keys(['alpha', 'beta', 'probZ', 'labels'])

In [27]:
glad_output['labels']

{0: 2,
 1: 0,
 2: 2,
 3: 2,
 4: 2,
 5: 2,
 6: 1,
 7: 2,
 8: 2,
 9: 2,
 10: 1,
 11: 2,
 12: 0,
 13: 2,
 14: 1,
 15: 1,
 16: 1,
 17: 1,
 18: 2,
 19: 1,
 20: 3,
 21: 2,
 22: 1,
 23: 2,
 24: 3,
 25: 1,
 26: 1,
 27: 1,
 28: 2,
 29: 1,
 30: 3,
 31: 2,
 32: 1,
 33: 1,
 34: 1,
 35: 1,
 36: 0,
 37: 2,
 38: 1,
 39: 0,
 40: 3,
 41: 2,
 42: 1,
 43: 2,
 44: 1,
 45: 3,
 46: 3,
 47: 1,
 48: 3,
 49: 3,
 50: 2,
 51: 1,
 52: 0,
 53: 3,
 54: 1,
 55: 1,
 56: 0,
 57: 2,
 58: 3,
 59: 2,
 60: 3,
 61: 3,
 62: 1,
 63: 2,
 64: 2,
 65: 3,
 66: 1,
 67: 2,
 68: 3,
 69: 1,
 70: 1,
 71: 1,
 72: 3,
 73: 3,
 74: 1,
 75: 3,
 76: 1,
 77: 2,
 78: 1,
 79: 1,
 80: 2,
 81: 2,
 82: 0,
 83: 1,
 84: 0,
 85: 2,
 86: 2,
 87: 0,
 88: 0,
 89: 1,
 90: 2,
 91: 0,
 92: 2,
 93: 0,
 94: 1,
 95: 0,
 96: 1,
 97: 2,
 98: 1,
 99: 2,
 100: 2,
 101: 1,
 102: 0,
 103: 3,
 104: 0,
 105: 2,
 106: 1,
 107: 2,
 108: 3,
 109: 0,
 110: 2,
 111: 3,
 112: 2,
 113: 2,
 114: 0,
 115: 0,
 116: 0,
 117: 1,
 118: 0,
 119: 2,
 120: 1,
 121: 3,
 122: 3,
 12