In [1]:
import os
import json
from tqdm import tqdm
import numpy as np
import random
from copy import deepcopy
import re

In [None]:
# Dataset variant to load; the same value is embedded in the file names
# written by the prompt-generation cell.
branching_factor = 40

with open(f"llm_tests/all_facts_{branching_factor}.json", "r", encoding='utf-8') as f:
    all_facts = json.load(f)
with open(f"llm_tests/X_values_{branching_factor}.json", "r", encoding='utf-8') as f:
    X_values = json.load(f)

# Average over every loaded example (keys "0" .. "n-1", the same indexing the
# generation loop uses) instead of a hard-coded first 20, which raised
# KeyError on smaller files and ignored the tail of larger ones.
num_examples = len(all_facts)
avg_num_facts = np.mean([len(all_facts[str(i)]) for i in range(num_examples)])
print(f"{avg_num_facts} facts on average.")

In [None]:
# Modulus applied to core-entity indices, and how many core-entity tokens
# exist per unit of that modulus.
P = 10007
core_v_multiplicity = 10
num_core_entities = core_v_multiplicity * P
# Every core-entity token is spelled "<c_i>" for i in [0, num_core_entities).
core_entities = {f"<c_{i}>" for i in range(num_core_entities)}

In [None]:
# Build a pool of distinct first names from two generators.
import names
from faker import Faker
fake = Faker()

all_names = set()
# Keep the first token of every `fake.name()` result that splits into exactly
# two whitespace-separated tokens; any other shape is skipped.
for _ in tqdm(range(100000)):
    name_parts = fake.name().split()
    if len(name_parts) == 2:
        all_names.add(name_parts[0].strip())
# Top the pool up with first names from the `names` package.
all_names.update(names.get_first_name().strip() for _ in tqdm(range(50000)))

# Convert to a list so a prefix of the pool can be sliced off later.
all_names = list(all_names)

In [None]:
def core2val(c, modulus=None):
    """Return the numeric value carried by a core-entity token.

    Args:
        c: Core-entity token of the exact form "<c_N>", where N is a
            non-negative decimal integer.
        modulus: Optional modulus applied to N. Defaults to the module-level
            constant ``P`` when omitted, which is the original behavior.

    Returns:
        ``N % modulus`` as an int.

    Raises:
        ValueError: if ``c`` is not of the form "<c_N>".
    """
    # str.strip("><c_") strips a *set of characters*, not a prefix/suffix, so
    # it silently accepted malformed tokens; match the format explicitly.
    match = re.fullmatch(r"<c_([0-9]+)>", c)
    if match is None:
        raise ValueError(f"not a core-entity token: {c!r}")
    if modulus is None:
        modulus = P
    return int(match.group(1)) % modulus

def form_rel_fact(arr, ent2name):
    """Verbalize one three-entity fact as an addition sentence.

    Args:
        arr: Sequence of exactly three entity ids; the first two are the
            addends and the third is the sum.
        ent2name: Mapping from entity id to the display name used in the text.

    Returns:
        A sentence "<A>'s number plus <B>'s number is <C>'s number." in which
        the two addends appear in the given order or swapped, each with
        probability 0.5 (one draw from ``np.random.uniform``).
    """
    swap_addends = not (np.random.uniform() < 0.5)
    left, right, total = arr
    if swap_addends:
        left, right = right, left
    return f"{ent2name[left]}'s number plus {ent2name[right]}'s number is {ent2name[total]}'s number."

def form_attr_fact(c, ent2name):
    """Verbalize the number attached to a core entity.

    Args:
        c: Core-entity id; its number is obtained from ``core2val(c)``.
        ent2name: Mapping from entity id to the display name used in the text.

    Returns:
        The sentence "The number of <name> is <value>."
    """
    display_name = ent2name[c]
    value = core2val(c)
    return "The number of {} is {}.".format(display_name, value)

# Prompt skeleton with two positional `{}` slots, filled in order via
# `str.format`: first the space-joined verbalized facts, then the name of the
# person whose number is asked for. The trailing `.strip()` removes the
# leading newline and the trailing whitespace left by the triple-quoted layout.
prompt_template = \
"""
In the following facts, names correspond to different people as long as their *spellings* are different. For example, 'Jackelyn' and 'Jackeline' represent two *different* people.

Here are the facts:

{}

Different people could have the same number.

What is {}'s number? Your response should end by 'Final Answer: XXX.' 
""".strip()

In [None]:
# For every example, write a natural-language question file (q_*) and the
# matching ground-truth answer file (a_*).
for example_id in range(len(all_facts)):

    facts = all_facts[str(example_id)]
    gt_answer = X_values[example_id]

    # Split the entities mentioned by the facts into core and non-core ones.
    # The middle entity of a fact must be core; at most one of the two outer
    # entities may be core as well.
    core_ents, other_ents = set(), set()
    for (head, middle, tail) in facts:
        assert middle in core_entities
        core_ents.add(middle)
        if head in core_entities:
            assert tail not in core_entities
            core_ents.add(head)
            other_ents.add(tail)
        elif tail in core_entities:
            assert head not in core_entities
            core_ents.add(tail)
            other_ents.add(head)
        else:
            other_ents.update((head, tail))

    # The entity being asked about must be one of the non-core entities.
    target_ent = f'<x_{example_id}>'
    assert target_ent in other_ents

    all_ents = list(core_ents | other_ents)
    # Equal sizes mean the core and non-core sets are disjoint.
    assert len(all_ents) == len(core_ents) + len(other_ents)
    assert len(all_names) >= len(all_ents)

    # Shuffle the entities, then pair them with the first len(all_ents) names.
    random.shuffle(all_ents)
    ent2name = dict(zip(all_ents, all_names[:len(all_ents)]))

    # Relational facts first, then one attribute fact per core entity.
    facts_verbalized = [form_rel_fact(fact, ent2name) for fact in facts]
    facts_verbalized.extend(form_attr_fact(core, ent2name) for core in core_ents)

    # Randomize the sentence order before joining into a single paragraph.
    random.shuffle(facts_verbalized)
    facts_verbalized = " ".join(facts_verbalized)

    # final string
    prompt = prompt_template.format(facts_verbalized, ent2name[target_ent])

    with open(f"llm_tests/q_{branching_factor}_{example_id}.txt", "w", encoding='utf-8') as f:
        f.write(prompt)

    with open(f"llm_tests/a_{branching_factor}_{example_id}.txt", "w", encoding='utf-8') as f:
        f.write(gt_answer)

In [2]:
# evaluating o3mini
#
# For each branching factor, load the saved responses, parse the text after
# the last "Final Answer:" marker as an integer, and count exact matches
# against the ground-truth answer.

# A dedicated loop variable (`bf`) keeps this cell from overwriting the
# notebook-level `branching_factor` setting used by the generation cells.
for bf in [10, 20, 30, 40]:

    with open(f"llm_tests/o3mini_results/gsm_o3mini_{bf}.json", "r", encoding='utf-8') as f:
        all_items = json.load(f)

    correct = 0
    cannot_determine = 0
    for item in all_items:
        gt_ans = int(item['answer'].strip("><"))
        try:
            # Drop surrounding dots, spaces and asterisks before parsing.
            temp = item['response'].split("Final Answer:")[-1].strip(". *")
            pred_ans = int(temp)
        except (KeyError, AttributeError, TypeError, ValueError):
            # Missing/non-string response or a non-integer answer. Named
            # exceptions only, so KeyboardInterrupt and genuine bugs are no
            # longer swallowed the way the bare `except:` did.
            cannot_determine += 1
            continue
        if pred_ans == gt_ans:
            correct += 1

    wrong = len(all_items) - correct - cannot_determine
    print(f"branching_factor: {bf}| {correct} correct, {cannot_determine} can't determine, {wrong} wrong (out of {len(all_items)})")

branching_factor: 10| 17 correct, 0 can't determine, 3 wrong (out of 20)
branching_factor: 20| 4 correct, 1 can't determine, 15 wrong (out of 20)
branching_factor: 30| 0 correct, 2 can't determine, 18 wrong (out of 20)
branching_factor: 40| 0 correct, 4 can't determine, 16 wrong (out of 20)


In [3]:
# eval gemini

import re
def extract_first_number_regex(text):
    """Return the first run of digits in ``text`` as an int.

    Args:
        text: String to scan.

    Returns:
        The integer value of the earliest ``\\d+`` match, or ``None`` when
        ``text`` contains no digits.
    """
    found = re.search(r'\d+', text)
    if found is None:
        return None
    return int(found.group())

# Number of examples evaluated per branching factor.
NUM_EXAMPLES = 20

# A dedicated loop variable (`bf`) keeps this cell from overwriting the
# notebook-level `branching_factor` setting used by the generation cells.
for bf in [10, 20, 30, 40]:

    correct = 0
    cannot_determine = 0
    for example_id in range(NUM_EXAMPLES):
        # Read with an explicit UTF-8 encoding, matching how the answer files
        # are written, instead of the platform-dependent default.
        with open(f"llm_tests/gemini_results/{bf}_{example_id}", "r", encoding='utf-8') as f:
            data = json.load(f)
        with open(f"llm_tests/a_{bf}_{example_id}.txt", "r", encoding='utf-8') as f:
            gold_ans = int(f.read().strip(" <>"))

        try:
            # Text following the first "Final Answer:" marker in the last chunk.
            answer_text = data['chunkedPrompt']['chunks'][-1]['text'].split("Final Answer:")[1].strip(". \n")
            pred_ans = extract_first_number_regex(answer_text)
        except (KeyError, IndexError, TypeError, AttributeError, ValueError):
            # Unexpected JSON shape or no "Final Answer:" marker. Named
            # exceptions only, so KeyboardInterrupt and genuine bugs are no
            # longer swallowed the way the bare `except:` did.
            pred_ans = None

        if pred_ans is None:
            # No parsable number after the marker.
            cannot_determine += 1
        elif pred_ans == gold_ans:
            correct += 1

    wrong = NUM_EXAMPLES - correct - cannot_determine
    print(f"branching_factor: {bf}| {correct} correct, {cannot_determine} can't determine, {wrong} wrong (out of {NUM_EXAMPLES})")

branching_factor: 10| 9 correct, 0 can't determine, 11 wrong (out of 20)
branching_factor: 20| 9 correct, 0 can't determine, 11 wrong (out of 20)
branching_factor: 30| 7 correct, 0 can't determine, 13 wrong (out of 20)
branching_factor: 40| 5 correct, 2 can't determine, 13 wrong (out of 20)
