In [1]:
import numpy as np
import pandas as pd 
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import re
import pickle
import time 
import os


In [None]:
pip install -U bitsandbytes

In [2]:
splits = {'test': 'college_mathematics/test-00000-of-00001.parquet', 'validation': 'college_mathematics/validation-00000-of-00001.parquet', 'dev': 'college_mathematics/dev-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])

In [3]:
df.head()

Unnamed: 0,question,subject,choices,answer
0,Let k be the number of real solutions of the e...,college_mathematics,"[k = 0 and n = 1, k = 1 and n = 0, k = n = 1, ...",1
1,"Up to isomorphism, how many additive abelian g...",college_mathematics,"[0, 1, 2, 3]",3
2,Suppose P is the set of polynomials with coeff...,college_mathematics,"[n = 1 and r = 6, n = 1 and r = 7, n = 2 and r...",3
3,The shortest distance from the curve xy = 8 to...,college_mathematics,"[4, 8, 16, 2sqrt(2)]",0
4,"There are 25 suitcases, 5 of which are damaged...",college_mathematics,"[2/69, 1/30, 2/23, 12/125]",2


In [4]:
def create_prompts_zs(df):
    prompts = []
    answers = []
    for index, row in df.iterrows():
        #
        prompt = "Choose the answer of the given question from the below options.\n"+row['question']+ "\n"+"{1:"+row['choices'][0]+"}" + "\n"+"{2:"+row['choices'][1]+"}" + "\n"+"{3:"+row['choices'][2]+"}"+ "\n"+"{4:"+row['choices'][3]+"}"+"\n" + "answer: {"
        prompts.append(prompt)
        answers.append(str(row['answer']))
    return prompts, answers

In [5]:
prompts_zs, answers = create_prompts_zs(df)

In [35]:
for i in range(len(answers)):
    answers[i] = int(answers[i])

In [6]:
hf_token = "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"

In [14]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it",use_auth_token = hf_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_auth_token = hf_token
)




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
model = model.to('cuda')

In [8]:
prompts_zs[8]

'Choose the answer of the given question from the below options.\nLet V be a finite-dimensional real vector space and let P be a linear transformation of V such that P^2 = P. Which of the following must be true?\nI. P is invertible.\nII. P is diagonalizable.\nIII. P is either the identity transformation or the zero transformation.\n{1:None}\n{2:I only}\n{3:II only}\n{4:III only}\nanswer: {'

In [24]:
predictions_zeroshot_gemma = []
begin = time.time()
for i in range(len(prompts_zs)):
    print(i)
    input_text = prompts_zs[i]
    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
    outputs = model.generate(**input_ids,max_length = 200)
    response = tokenizer.decode(outputs[0])
    match = re.search(r"answer: \{(\d+)", response)
    extracted_character = match.group(1) if match else '1'
    if(extracted_character in ['1','2','3','4']):
        extracted_number = int(extracted_character)-1
    else:
        extracted_number = 0
    predictions_zeroshot_gemma.append(extracted_number)
end = time.time()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [25]:
def calculate_accuracy(predictions, answers):
    valid_pairs = [(pred, ans) for pred, ans in zip(predictions, answers) if pred is not None]
    correct_predictions = sum(1 for pred, ans in valid_pairs if pred == ans)
    accuracy = correct_predictions / len(valid_pairs) if valid_pairs else 0
    return accuracy

In [31]:
filename = 'predictions_zeroshot_gemma.pkl'
with open(filename, 'wb') as file:
    pickle.dump(predictions_zeroshot_gemma, file)
with open(filename, 'rb') as file:
    predictions_zeroshot_gemma = pickle.load(file)

In [37]:
print("Zero shot prompting on Gemma")
print("Total inference time:",end-begin,"seconds")
print("Accuracy score:",100*calculate_accuracy(predictions_zeroshot_gemma,answers),"%")

Zero shot prompting on Gemma
Total inference time: 215.90406322479248 seconds
Accuracy score: 28.999999999999996 %
