In [1]:
def proc_direct(direct):
    """Map a direct-answer response to a numeric label.

    Keywords are located by case-insensitive substring search, so
    "Younger", "YOUNGER" and "younger" are all recognised.

    Args:
        direct: Text of the model's direct answer.

    Returns:
        0 if the text contains 'younger' but not 'older',
        2 if it contains 'older' but not 'younger',
        1 if it contains 'same age',
        3 if it contains 'cannot decide',
        -1 if none of the above applies.
    """
    # Normalise case here instead of relying on the caller to do it.
    text = direct.lower()
    has_younger = 'younger' in text
    has_older = 'older' in text

    # A response mentioning both directions is ambiguous, so it is only
    # accepted below if it also states 'same age' or 'cannot decide'.
    if has_younger and not has_older:
        return 0
    if has_older and not has_younger:
        return 2
    if 'same age' in text:
        return 1
    if 'cannot decide' in text:
        return 3
    return -1

def proc_cot(cot):
    """Map a chain-of-thought response to a numeric label.

    Only the text after the last 'final answer:' marker is inspected.
    Both the marker and the keywords are matched case-insensitively, so
    the result does not depend on the caller lowercasing the text first.

    Args:
        cot: Text of the model's chain-of-thought response.

    Returns:
        -1 if there is no 'final answer:' marker; otherwise, judged on
        the text after the last marker:
        3 if it contains 'cannot decide',
        0 if it contains 'younger' but not 'older',
        2 if it contains 'older' but not 'younger',
        1 if it contains 'same age',
        -1 if none of the above applies.
    """
    marker = 'final answer:'
    text = cot.lower()
    if marker not in text:
        return -1

    # Keep only what follows the last occurrence of the marker.
    verdict = text.rpartition(marker)[2].strip()

    # 'cannot decide' is tested first, so it wins even when the verdict
    # also mentions 'younger' or 'older'.
    if 'cannot decide' in verdict:
        return 3
    has_younger = 'younger' in verdict
    has_older = 'older' in verdict
    if has_younger and not has_older:
        return 0
    if has_older and not has_younger:
        return 2
    if 'same age' in verdict:
        return 1
    return -1

def evaluate(pred, gold, correct_inds=False):
    """Print the share of predictions falling into each outcome bucket.

    Each prediction is counted as 'invalid' (-1), 'cannot decide' (3),
    or, for labels 0/1/2, as 'correct' or 'wrong' depending on whether
    it equals the gold label at the same index. Counts are divided by
    len(pred), rounded to 3 decimals, and printed as a dict. With an
    empty `pred` every share is reported as 0.0.

    Args:
        pred: Sequence of predicted labels (-1, 0, 1, 2 or 3).
        gold: Sequence of gold labels, indexed in parallel with `pred`.
        correct_inds: When true, return the indices of correct
            predictions.

    Returns:
        The list of indices i with pred[i] == gold[i] (among predictions
        labelled 0/1/2) if `correct_inds` is true, otherwise None.
    """
    total = len(pred)
    counts = {"invalid": 0, "correct": 0, "wrong": 0, "cannot decide": 0}
    inds = []

    for i, p in enumerate(pred):
        g = gold[i]
        if p == -1:
            counts["invalid"] += 1
        elif p == 3:
            counts["cannot decide"] += 1
        else:
            # Gold labels are only validated for predictions that are
            # actually compared against them.
            assert p in [0, 1, 2] and g in [0, 1, 2]
            if p == g:
                inds.append(i)
                counts["correct"] += 1
            else:
                counts["wrong"] += 1

    # Guard the division so an empty prediction list does not raise
    # ZeroDivisionError.
    eval_dict = {
        key: round(val / total, 3) if total else 0.0
        for key, val in counts.items()
    }
    print(eval_dict)
    if correct_inds:
        return inds
    return None

In [2]:
# Baseline run: score the answers produced without retrieval augmentation.
print("======no retrieval augmentation======")
model = 'gemini'

# Per-example accumulators: parsed direct label, parsed CoT label,
# lowercased CoT text, and gold label.
direct_pred_l, cot_pred_l, cot_whole, gold_l = [], [], [], []

for k in range(150):
    # Resolve the three input files for example k up front.
    direct_path = "LLM/{}_directna_{}.txt".format(model, k)
    cot_path = "LLM/{}_cot_{}.txt".format(model, k)
    gold_path = "LLM/answer_{}.txt".format(k)

    with open(direct_path) as f:
        direct = f.read().strip()
    with open(cot_path) as f:
        # The CoT text is lowercased before it is stored and parsed.
        cot = f.read().strip().lower()
    with open(gold_path) as f:
        ans = f.read().strip()

    direct_pred_l.append(proc_direct(direct))
    cot_whole.append(cot)
    cot_pred_l.append(proc_cot(cot))
    gold_l.append(int(ans))

# evaluate() prints its own result dict, so each label stays on the same line.
print("model: {}, direct QA performance:".format(model), end=" ")
evaluate(direct_pred_l, gold_l)
print("model: {}, CoT performance:".format(model), end=" ")
correct_inds = evaluate(cot_pred_l, gold_l, correct_inds=True)

model: gemini, direct QA performance: {'invalid': 0.0, 'correct': 0.287, 'wrong': 0.513, 'cannot decide': 0.2}
model: gemini, CoT performance: {'invalid': 0.0, 'correct': 0.113, 'wrong': 0.18, 'cannot decide': 0.707}


In [3]:
# Retrieval-augmented runs: score each model's direct and CoT answers.
print("======retrieval augmentation======")
for model in ['gpt4turbo', 'gemini']:
    # Fresh accumulators for every model: parsed direct label, parsed CoT
    # label, lowercased CoT text, and gold label.
    direct_pred_l, cot_pred_l, cot_whole, gold_l = [], [], [], []

    for k in range(150):
        # Resolve the three input files for example k up front.
        direct_path = "LLM/{}_retrieval_directna_{}.txt".format(model, k)
        cot_path = "LLM/{}_retrieval_cot_{}.txt".format(model, k)
        gold_path = "LLM/answer_{}.txt".format(k)

        with open(direct_path) as f:
            direct = f.read().strip()
        with open(cot_path) as f:
            # The CoT text is lowercased before it is stored and parsed.
            cot = f.read().strip().lower()
        with open(gold_path) as f:
            ans = f.read().strip()

        direct_pred_l.append(proc_direct(direct))
        cot_whole.append(cot)
        cot_pred_l.append(proc_cot(cot))
        gold_l.append(int(ans))

    # evaluate() prints its own result dict, so each label stays on the same line.
    print("model: {}, direct QA performance:".format(model), end=" ")
    evaluate(direct_pred_l, gold_l)
    print("model: {}, CoT performance:".format(model), end=" ")
    correct_inds = evaluate(cot_pred_l, gold_l, correct_inds=True)

model: gpt4turbo, direct QA performance: {'invalid': 0.0, 'correct': 0.333, 'wrong': 0.667, 'cannot decide': 0.0}
model: gpt4turbo, CoT performance: {'invalid': 0.04, 'correct': 0.313, 'wrong': 0.52, 'cannot decide': 0.127}
model: gemini, direct QA performance: {'invalid': 0.0, 'correct': 0.373, 'wrong': 0.593, 'cannot decide': 0.033}
model: gemini, CoT performance: {'invalid': 0.0, 'correct': 0.12, 'wrong': 0.293, 'cannot decide': 0.587}
