In [1]:
from glob import glob
import json
import re
import os
from tqdm import tqdm
from collections import Counter

def aggregate_mcq_predictions(preds, gold, method="majority"):
    if method == "majority":
        guess = Counter(preds).most_common(1)[0][0]

    elif method == "any":
        return int(gold in preds)

    else:
        raise ValueError("method must be 'majority' or 'any'")

    return int(guess == gold)
    
def parse(text):

    m = re.search(r"[A-F]", text)
    if m:
        return m.group()

In [2]:
with open('MalayMMLU_0shot.json') as fopen:
    malaymmlu = json.load(fopen)

In [3]:
malaymmlu[0]

{'id': 12822941,
 'prompt': 'Pasangan algoritma yang digunakan untuk melakukan penyulitan dan nyahsulit dikenali sebagai\nA. kunci (keys)\nB. Sifer (cipher)\nC. Teks sifer (ciphertext)',
 'answer': 'B. Sifer (cipher)',
 'year': 'Tingkatan 3',
 'subject': 'Sains Komputer',
 'subject_eng': 'Computer Science',
 'category': 'STEM',
 'level': 'Secondary',
 'options': ['A. kunci (keys)',
  'B. Sifer (cipher)',
  'C. Teks sifer (ciphertext)'],
 'num_options': 3,
 'key': 'B'}

In [4]:
folders = glob('malaymmlu-20b-*')
folders = [f for f in folders if '.zip' not in f and 'baseline' not in f]
folders = sorted(folders, key = lambda x: int(x.split('-r')[1]))
folders

['malaymmlu-20b-r16',
 'malaymmlu-20b-r32',
 'malaymmlu-20b-r64',
 'malaymmlu-20b-r128',
 'malaymmlu-20b-r256',
 'malaymmlu-20b-r512']

In [5]:
for f in folders:
    total_k1 = 0
    total_k5 = 0
    total = 0
    wrong = []
    for i in range(len(malaymmlu)):
        try:
            results = []
            for k in range(5):
                filename = os.path.join(f, f'{i}-{k}.json')
                with open(filename) as fopen:
                    d = json.load(fopen)
                p = parse(d)
                if p:
                    results.append(p)
            s = aggregate_mcq_predictions(results, malaymmlu[i]['key'])
            total_k1 += aggregate_mcq_predictions(results[:1], malaymmlu[i]['key'])
            total_k5 += s
            total += 1

            if s == 0:
                wrong.append((results, malaymmlu[i]))
        except:
            pass

    print(f.split('-r')[1], total_k1 / total, total_k5 / total)

16 0.6103171980835949 0.6688831984140096
32 0.6082934082273252 0.6673963323971585
64 0.6182629166150415 0.6735224879197126
128 0.6151082108045597 0.6728068726251446
256 0.6211126254491389 0.6804196093007888
512 0.6250361376120266 0.6845496221038285
