In [2]:
from datasets import load_dataset, Dataset
from pprint import pprint

from tqdm.auto import tqdm
# Prompt template used by load_arc / load_mmlu through str.format.
# It has two placeholders, {question} and {choices}, and ends with the
# '[|AI|] ' turn marker followed by a single space.
# English gloss of the Vietnamese text: "A conversation between a human and an
# AI assistant. Answer the following question by choosing one of the four
# options A, B, C, D and do not give any further explanation. ... Answer:"
prompt = """Cuộc trò chuyện giữa con người và trợ lý AI.
[|Con người|] Trả lời câu hỏi sau đây bằng cách lựa chọn 1 trong 4 đáp án A, B, C, D và không đưa ra giải thích thêm.
{question}
{choices}
Đáp án:
[|AI|] """


def load_arc(path):
    """Read an ARC-style JSON file and turn every record into an evaluation row.

    Each record is expected to expose the keys read below: ``instruction``,
    ``option_a`` .. ``option_d``, ``answer`` and ``id``.

    Args:
        path: Location passed straight to ``Dataset.from_json``.

    Returns:
        The mapped dataset. All original columns are removed, leaving only
        ``input`` (the filled-in module-level ``prompt`` template), ``ref``
        (the record's ``answer``) and ``category``.
    """
    option_keys = ['option_a', 'option_b', 'option_c', 'option_d']
    labels = ['A', 'B', 'C', 'D']

    def to_eval_row(x):
        # Collect all four options up front, then number them "A. ...", "B. ...".
        options = [x[key] for key in option_keys]
        numbered = [f'{label}. {text}' for label, text in zip(labels, options)]
        row = {
            'input': prompt.format(question=x['instruction'], choices='\n'.join(numbered)),
            'ref': x['answer'],
        }
        # Category = text before the first '_' in the last '/'-separated piece of the id.
        last_piece = x['id'].split('/')[-1]
        row['category'] = last_piece.split('_')[0]
        return row

    raw = Dataset.from_json(path)
    return raw.map(to_eval_row, remove_columns=raw.column_names, num_proc=4)

def load_mmlu(path):
    """Read an MMLU-style JSON file and turn every record into an evaluation row.

    Each record is expected to expose the keys read below: ``instruction``,
    ``option_a`` .. ``option_d``, ``answer`` and ``id``.

    Args:
        path: Location passed straight to ``Dataset.from_json``.

    Returns:
        The mapped dataset. All original columns are removed, leaving only
        ``input`` (the filled-in module-level ``prompt`` template), ``ref``
        (the record's ``answer``) and ``category``.
    """
    option_keys = ['option_a', 'option_b', 'option_c', 'option_d']
    labels = ['A', 'B', 'C', 'D']

    def to_eval_row(x):
        # Collect all four options up front, then number them "A. ...", "B. ...".
        options = [x[key] for key in option_keys]
        numbered = [f'{label}. {text}' for label, text in zip(labels, options)]
        row = {
            'input': prompt.format(question=x['instruction'], choices='\n'.join(numbered)),
            'ref': x['answer'],
        }
        # Category = the part of the id that precedes the first '/'.
        row['category'] = x['id'].split('/')[0]
        return row

    raw = Dataset.from_json(path)
    return raw.map(to_eval_row, remove_columns=raw.column_names, num_proc=4)

# Default evaluation suites used by `report` when no list is supplied.
# Every load_func is called with its path and must return a dataset whose rows
# look like: {'input': prompt_str, 'ref': str[A, B, C, D], 'category': str}
temp = [
    dict(
        name='ARC',
        path='data/ARC-V1-Feb2018-2/ARC-Challenge/ARC-Challenge-Train.jsonl',
        load_func=load_arc,
    ),
    dict(
        name='MMLU',
        path='data/MMLU/MMLU.jsonl',
        load_func=load_mmlu,
    ),
]

def report(generate_func, ds_path_name_load=temp): # generate_func: input -> pred
    """Run ``generate_func`` over every configured dataset and score it per category.

    Args:
        generate_func: Callable taking a row's ``input`` and returning the prediction.
        ds_path_name_load: Iterable of dicts with the keys ``name``, ``path`` and
            ``load_func``; ``load_func(path)`` must return a dataset whose rows
            carry ``input``, ``ref`` and ``category``.

    Returns:
        ``{dataset_name: {category: {'ref': [...], 'pred': [...], 'acc': float}}}``
        where ``acc`` is the fraction of rows whose ``pred`` equals ``ref``.
    """
    def add_prediction(row):
        # Rebuild the row with the prediction slotted in between 'ref' and 'category'.
        scored = {'input': row['input'], 'ref': row['ref']}
        scored['pred'] = generate_func(row['input'])
        scored['category'] = row['category']
        return scored

    def summarize(ds):
        # One bucket per distinct category, filled in dataset order.
        buckets = {}
        for category in ds.unique('category'):
            buckets[category] = {'ref': [], 'pred': []}
        for row in ds:
            bucket = buckets[row['category']]
            bucket['ref'].append(row['ref'])
            bucket['pred'].append(row['pred'])
        for bucket in buckets.values():
            hits = 0
            for expected, predicted in zip(bucket['ref'], bucket['pred']):
                if expected == predicted:
                    hits += 1
            bucket['acc'] = hits / len(bucket['ref'])
        return buckets

    # Phase 1: load and predict for every dataset before any scoring starts.
    predicted_sets = {}
    for entry in ds_path_name_load:
        name = entry['name']
        predicted_sets[name] = entry['load_func'](entry['path']).map(add_prediction)

    # Phase 2: per-category accuracy for each dataset.
    results = {}
    for name in predicted_sets:
        results[name] = summarize(predicted_sets[name])
    return results

In [None]:
# Check whether a prompt value is a string.
# A separate name is used on purpose: assigning to `prompt` here would overwrite
# the module-level template that load_arc / load_mmlu format, so every loader
# run after this cell would produce the input "hello" for all rows.
sample_prompt = "hello"
print(isinstance(sample_prompt, str))

In [None]:
def get_answer(prompts):
    """Generate a reply for ``prompts`` and extract the one-character answer.

    The answer is the single character located one position after the first
    '[|AI|]' marker in the generated text (the prompt template in this notebook
    puts exactly one space after that marker).

    Args:
        prompts: Prompt text handed to ``chatbot.inferencer.generate``.
            NOTE(review): ``chatbot`` is not defined in the cells visible here;
            it must exist in the kernel before this function is called.

    Returns:
        The extracted character, or ``''`` when generation fails or the marker
        is absent. ``''`` never equals a reference letter, so such rows are
        scored as wrong instead of aborting the whole evaluation.
    """
    import warnings

    marker = '[|AI|]'
    try:
        # Use the argument; the earlier version read the global `prompt` by mistake.
        res = chatbot.inferencer.generate(prompts)
    except Exception as exc:
        # Best effort: report the failure loudly but keep the evaluation running.
        warnings.warn(f'get_answer: generation failed: {exc!r}')
        return ''

    index = res.find(marker)
    if index == -1:
        warnings.warn('get_answer: No answer found')
        return ''

    # Skip the marker plus the one separator character that follows it.
    start = index + len(marker) + 1
    return res[start:start + 1]

In [4]:
# Load the translated ARC and MMLU files into evaluation datasets.
# NOTE(review): these paths differ from the ones listed in `temp`, and `arc` /
# `mmlu` are not referenced by any other cell visible here — confirm which
# paths the evaluation is meant to use.
arc = load_arc('../../data/translated/arc/arc_all.json')
mmlu = load_mmlu('../../data/translated/mmlu/mmlu_all.json')

Found cached dataset json (/Users/phamhoang1408/.cache/huggingface/datasets/json/default-f63ca446b2d21b89/0.0.0)
Loading cached processed dataset at /Users/phamhoang1408/.cache/huggingface/datasets/json/default-f63ca446b2d21b89/0.0.0/cache-27108160d296fd3b_*_of_00004.arrow
Found cached dataset json (/Users/phamhoang1408/.cache/huggingface/datasets/json/default-b67d9addebc8e93d/0.0.0)
Loading cached processed dataset at /Users/phamhoang1408/.cache/huggingface/datasets/json/default-b67d9addebc8e93d/0.0.0/cache-1d94efbccab2c05c_*_of_00004.arrow


In [None]:
# `report` has a required first parameter (the input -> prediction callable);
# calling it with no arguments raises TypeError. Evaluate with `get_answer`,
# which is defined in an earlier cell.
report(get_answer)