## Function definitions

In [8]:
from data_loaders import load_dataset, load_inference_data, promptTechList, modelList, accuracy, problem_list
import sys
sys.path.append('../')

def evaluate_problem(problem : str) -> dict[str, dict[str, list[int]]]:

    # formate the problem and its name
    problem_name = problem
    match problem:
        case 'Direct Boolean Computation'|'DirectBooleanComputation':
            from DirectBooleanComputation import response_evaluator
            problem = 'DirectBooleanComputation'
            problem_name = 'Direct Boolean Computation'

        case 'Indirect Boolean Computation'|'IndirectBooleanComputation':
            from IndirectBooleanComputation import response_evaluator
            problem = 'IndirectBooleanComputation'
            problem_name = 'Indirect Boolean Computation'
        
        case 'SAT':
            from SAT import response_evaluator
        
        case 'SAT Count'|'SATCount':
            from SATCount import response_evaluator
            problem = 'SATCount'
            problem_name = 'SAT Count'
        
        case 'TautologyQ':
            from TautologyQ import response_evaluator
        
        case 'EquivalentQ':
            from EquivalentQ import response_evaluator
        
        case 'CNF':
            from CNF import response_evaluator
        
        case 'DNF':
            from DNF import response_evaluator

    # load corresponding dataset
    dataset = load_dataset(problem)

    inference_data = {
        model: {p: load_inference_data(
            problem, p, model) for p in promptTechList}
        for model in modelList
    }
    
    # evaluate all the reponses
    evaluation_result = {
        model : {
            p: [
                response_evaluator(response['response'], dataObject)
                for dataObject, response in zip(dataset, inference_data[model][p])
            ]
        for p in promptTechList}
    for model in modelList
    }

    return evaluation_result

In [16]:
# Evaluate each problem exactly once. evaluate_problem() already returns the
# results for every model and prompt technique, so calling it inside the
# nested comprehension below would redo the whole evaluation for every
# (prompt technique, model) combination.
evaluation_by_problem = {
    problem: evaluate_problem(problem)
    for problem in problem_list
}

# Accuracy keyed by (model, prompt technique, problem). The loop order
# (prompt technique -> model -> problem) fixes the dict's insertion order.
accuracy_result = {
    (model, p, problem): accuracy(evaluation_by_problem[problem][model][p])
    for p in promptTechList
    for model in modelList
    for problem in problem_list
}
accuracy_result

{('llama2-13b', 'Zero-shot', 'DirectBooleanComputation'): 0.1520018778083294,
 ('llama2-13b',
  'Zero-shot',
  'IndirectBooleanComputation'): 0.08919597989949749,
 ('llama2-13b', 'Zero-shot', 'CNF'): 0.03752068602377012,
 ('llama2-13b', 'Zero-shot', 'DNF'): 0.03547465021814352,
 ('llama2-13b', 'Zero-shot', 'TautologyQ'): 0.511555215641938,
 ('llama2-13b', 'Zero-shot', 'EquivalentQ'): 0.22043265435022896,
 ('llama2-13b', 'Zero-shot', 'SAT'): 0.04887133545258967,
 ('llama2-13b', 'Zero-shot', 'SAT Count'): 0.05452247698140458,
 ('wizardmath-13b',
  'Zero-shot',
  'DirectBooleanComputation'): 0.08423311649118101,
 ('wizardmath-13b',
  'Zero-shot',
  'IndirectBooleanComputation'): 0.01870064608758076,
 ('wizardmath-13b', 'Zero-shot', 'CNF'): 0.04083044982698962,
 ('wizardmath-13b', 'Zero-shot', 'DNF'): 0.04068000601775237,
 ('wizardmath-13b', 'Zero-shot', 'TautologyQ'): 0.35886010707278954,
 ('wizardmath-13b', 'Zero-shot', 'EquivalentQ'): 0.5013027001421128,
 ('wizardmath-13b', 'Zero-shot',

In [28]:
# One table-row fragment per prompt technique: for every problem a cell of the
# form "& <llama2-13b accuracy %> / <wizardmath-13b accuracy %>", rounded to
# two decimals. Cells are joined explicitly (two leading/trailing spaces, four
# spaces between cells) rather than by post-processing the repr of a list,
# which depends on how Python prints lists and would mangle any cell that
# contained a bracket, quote or comma.
[
    "  " + "    ".join(
        f"& {round(accuracy_result[('llama2-13b', p, problem)] * 100, 2)}"
        f" / {round(accuracy_result[('wizardmath-13b', p, problem)] * 100, 2)}"
        for problem in problem_list
    ) + "  "
    for p in promptTechList
]

['  & 15.2 / 8.42    & 8.92 / 1.87    & 3.75 / 4.08    & 3.55 / 4.07    & 51.16 / 35.89    & 22.04 / 50.13    & 4.89 / 10.14    & 5.45 / 0.5  ',
 '  & 55.96 / 49.87    & 52.64 / 17.64    & 3.82 / 3.45    & 1.91 / 2.61    & 50.24 / 45.5    & 58.47 / 55.12    & 63.9 / 27.66    & 19.64 / 18.17  ',
 '  & 15.6 / 19.23    & 17.43 / 6.58    & 4.34 / 4.21    & 4.36 / 4.16    & 33.39 / 8.82    & 7.83 / 5.69    & 8.58 / 3.38    & 9.28 / 0.84  ',
 '  & 20.7 / 19.89    & 21.97 / 6.25    & 4.09 / 4.21    & 4.06 / 4.24    & 30.48 / 8.39    & 8.5 / 4.47    & 6.57 / 2.9    & 10.1 / 0.83  ',
 '  & 49.25 / 40.49    & 42.77 / 23.33    & 9.68 / 7.72    & 13.96 / 16.38    & 47.61 / 18.71    & 33.6 / 20.46    & 35.9 / 21.08    & 7.7 / 2.7  ']