In [1]:
import json
import os

# Name of the experiment; selects the output/<exp_name>/ directory used for
# both the input predictions and the results written by later cells.
exp_name = 'memory_core_1'

# Context managers close the file handles deterministically (the previous
# json.load(open(...)) form left them to the garbage collector). JSON is
# read as UTF-8 so the result does not depend on the platform default.
with open(f'output/{exp_name}/predicted_type.json', 'r', encoding='utf-8') as predicted_type_file:
    test_data = json.load(predicted_type_file)
with open('data/example/template_and_replacements.json', 'r', encoding='utf-8') as examples_file:
    template_and_replacements_examples = json.load(examples_file)

In [2]:
# Candidate s-expression templates offered to the model, keyed by question type.
# Placeholders: x1..xN are variables, `number` is a constant, and
# `compare` / `optimize` are function placeholders.
_compare_templates = [
    "(compare (GROUP_COUNT x1) number)",
    "(compare (GROUP_COUNT x1) x2)",
    "(compare (GROUP_SUM (GROUP_COUNT x1) (GROUP_COUNT x2)) (OR x3 x4))",
    "(compare (GROUP_SUM (GROUP_COUNT x1) (GROUP_COUNT x2)) number)",
]

# "(ALL x1 ... xN)" for N = 10 down to 1, longest template first.
_verify_templates = [
    "(ALL " + " ".join(f"x{i}" for i in range(1, arity + 1)) + ")"
    for arity in range(10, 0, -1)
]

simple_question_type_to_templates = {
    "compare": _compare_templates,
    # Every compare template wrapped in an outer COUNT.
    "compare_and_count": ["(COUNT " + template + ")" for template in _compare_templates],
    "count": [
        "(COUNT (DISTINCT (OR x1 x2)))",
        "(COUNT (DISTINCT x1))",
        "(COUNT (GROUP_COUNT x1))",
        "(COUNT (GROUP_SUM (GROUP_COUNT x1) (GROUP_COUNT x2)))",
        "(COUNT x1)",
    ],
    "optimize": [
        "(optimize (GROUP_COUNT x1))",
        "(optimize (GROUP_SUM (GROUP_COUNT x1) (GROUP_COUNT x2)))",
    ],
    "simple": [
        "(DIFF x1 x2)",
        "(DISTINCT x1)",
        "(GROUP_COUNT x1)",
        "(GROUP_SUM (GROUP_COUNT x1) (GROUP_COUNT x2))",
        "(OR x1 x2)",
        "x1",
    ],
    "verify": _verify_templates,
}

In [3]:
# Build the few-shot chat history for each question type: every example
# contributes one 'user' turn (question, candidate templates, candidate cores)
# followed by one 'assistant' turn holding the expected JSON answer.
examples = {}
for question_type, example_items in template_and_replacements_examples.items():
    turns = examples[question_type] = []
    for qa in example_items:
        cores_fn = qa['s_expression_cores_fn']
        user_content = (
            "question: " + qa['coreference_resolved_question']
            + "\ncandidate_templates: " + ','.join(simple_question_type_to_templates[question_type])
            + "\ncandidate_s_expression_cores: " + '\u001F'.join(cores_fn)
        )
        turns.append({'role': 'user', 'content': user_content})

        # The stored variable replacements are indices into the example's core
        # list; resolve them to the core text shown to the model.
        stored = qa['replacements']
        replacements = {
            "variables": {var: cores_fn[core_index] for var, core_index in stored['variables'].items()},
            "constants": stored['constants'],
            "functions": stored['functions'],
        }
        answer = json.dumps({'template': qa['template'], 'replacements': replacements})
        turns.append({'role': 'assistant', 'content': answer})
examples

{'verify': [{'role': 'user',
   'content': 'question: Output: Do Didrik Ficks gränd have Västerlånggatan and Stora Nygatan as terminus?\ncandidate_templates: (ALL x1 x2 x3 x4 x5 x6 x7 x8 x9 x10),(ALL x1 x2 x3 x4 x5 x6 x7 x8 x9),(ALL x1 x2 x3 x4 x5 x6 x7 x8),(ALL x1 x2 x3 x4 x5 x6 x7),(ALL x1 x2 x3 x4 x5 x6),(ALL x1 x2 x3 x4 x5),(ALL x1 x2 x3 x4),(ALL x1 x2 x3),(ALL x1 x2),(ALL x1)\ncandidate_s_expression_cores: (IS_TRUE Didrik_Ficks_gränd terminus Västerlånggatan)\x1f(IS_TRUE Didrik_Ficks_gränd terminus Stora_Nygatan)'},
  {'role': 'assistant',
   'content': '{"template": "(ALL x1 x2)", "replacements": {"variables": {"x2": "(IS_TRUE Didrik_Ficks_gr\\u00e4nd terminus Stora_Nygatan)", "x1": "(IS_TRUE Didrik_Ficks_gr\\u00e4nd terminus V\\u00e4sterl\\u00e5nggatan)"}, "constants": {}, "functions": {}}}'},
  {'role': 'user',
   'content': 'question: Are Canada next to the border of Prince Edward Island and Greenland?\ncandidate_templates: (ALL x1 x2 x3 x4 x5 x6 x7 x8 x9 x10),(ALL x1 x2 x3 x4

In [4]:
def fix_json(s):
    """Append the closing braces missing from a truncated JSON object string.

    Braces are counted only outside of JSON string literals (with backslash
    escapes honoured), so a `{` or `}` that is part of a string value does not
    distort the count. Nothing is appended when the braces already balance or
    when there are surplus closing braces.

    Args:
        s: JSON text, possibly cut off before its final `}` characters.

    Returns:
        `s` followed by as many `}` as there are unclosed `{`.
    """
    depth = 0
    in_string = False
    escaped = False
    for ch in s:
        if in_string:
            if escaped:
                # Character after a backslash is consumed verbatim.
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
    return s + '}' * max(depth, 0)

def fix_template(template):
    """Return `template` with enough ')' appended to match its '(' count.

    When there are at least as many ')' as '(', the template is returned
    unchanged.
    """
    missing = template.count('(') - template.count(')')
    closing = ')' * missing
    return template + closing

def is_close(expr):
    """Tell whether the parentheses in `expr` are balanced.

    Returns False as soon as a ')' appears with no open '(' to match, and
    otherwise True only if every '(' has been closed by the end.
    """
    depth = 0
    for ch in expr:
        if ch == ')':
            if not depth:
                return False
            depth -= 1
        elif ch == '(':
            depth += 1
    return not depth

def fix_core(core):
    """Rebuild a core expression with well-formed parentheses.

    All parentheses are discarded and the remaining tokens are re-parsed in
    prefix order using the operators' argument counts:

    * ``IS_TRUE`` takes 3 arguments, ``JOIN`` takes 2, ``R`` takes 1;
    * ``AND`` consumes every remaining token;
    * ``VALUES`` consumes tokens up to the next operator.

    Tokens left over after the outermost expression is complete are dropped.

    Args:
        core: Core expression text, typically with unbalanced parentheses.

    Returns:
        The re-parenthesised expression, or `core` unchanged when the tokens
        cannot be parsed (e.g. an operator is missing arguments, or `core`
        is empty).
    """
    func_list = ('JOIN', 'R', 'AND', 'VALUES', 'IS_TRUE')
    fixed_arity = {'IS_TRUE': 3, 'JOIN': 2, 'R': 1}

    # A piece made only of parentheses (e.g. "( JOIN a b") would otherwise
    # become an empty-string token and be parsed as a value, so drop those.
    stripped = (piece.replace('(', '').replace(')', '') for piece in core.split())
    tokens = [token for token in stripped if token]
    index = 0

    def parse_core():
        nonlocal index
        token = tokens[index]  # IndexError here means the input ran out of tokens
        index += 1
        if token not in func_list:
            return token
        args = []
        if token in fixed_arity:
            for _ in range(fixed_arity[token]):
                args.append(parse_core())
        elif token == 'AND':
            while index < len(tokens):
                args.append(parse_core())
        else:
            # VALUES: plain values up to the next operator.
            while index < len(tokens) and tokens[index] not in func_list:
                args.append(parse_core())
        # Plain concatenation instead of an f-string with nested same-type
        # quotes, which is only valid syntax from Python 3.12 onwards.
        return '(' + token + ' ' + ' '.join(args) + ')'

    try:
        return parse_core()
    except (IndexError, RecursionError):
        # Best effort: hand back the input when it cannot be repaired.
        return core

In [None]:
import re
from utils.llm_call import LLM_Call

def concat_input(qa):
    """Format one test item as the final user message of the prompt.

    The field labels use the same "label: value" layout (with a space after
    the colon) as the few-shot example turns, so the real query looks exactly
    like the demonstrations that precede it.

    Args:
        qa: Item with 'coreference_resolved_question',
            'predicted_simple_question_type' and 'calibrated_cores_fn'.

    Returns:
        The message text: question, comma-joined candidate templates for the
        predicted question type, and the candidate cores joined by U+001F.
    """
    candidate_templates = simple_question_type_to_templates[qa['predicted_simple_question_type']]
    return (
        "question: " + qa['coreference_resolved_question']
        + "\ncandidate_templates: " + ','.join(candidate_templates)
        + "\ncandidate_s_expression_cores: " + '\u001F'.join(qa['calibrated_cores_fn'])
    )

# SECURITY: credentials must not live in the notebook. Keys that were ever
# committed here should be treated as compromised and rotated.
# Supply keys through the environment as a comma-separated list, e.g.
#   export DEEPSEEK_API_KEYS="first-key,second-key"
API_KEYS = [key.strip() for key in os.environ.get('DEEPSEEK_API_KEYS', '').split(',') if key.strip()]
if not API_KEYS:
    raise RuntimeError('Set the DEEPSEEK_API_KEYS environment variable to a comma-separated list of API keys.')
API_URL = 'https://api.deepseek.com'
API_MODEL = 'deepseek-chat'

# One (key, base URL, model) triple per configured key.
api_pool = [(k, API_URL, API_MODEL) for k in API_KEYS]

# temperature 0 and a 512-token cap are passed through as openai_params.
LLM = LLM_Call(api_pool=api_pool, openai_params = {'temperature': 0, 'max_tokens': 512})

batch_size = 20
# Cut test_data into consecutive slices of at most batch_size items each.
batches = []
for start in range(0, len(test_data), batch_size):
    batches.append(test_data[start:start + batch_size])

# System prompt placed first in every request; the few-shot example turns and
# the actual question are appended after it when the messages are assembled.
prompt = """
    # S-expression Generation Task

    ## Core Task
    Select an appropriate template from candidate_templates and replace placeholders in the template with numbers from the question, specified functions, or expressions from candidate_s_expression_cores to form a final expression that matches the question's semantics. Output the replacements made.

    ## Key Principles
    1. Semantics First: Replacements must accurately express the question's logic
    2. Strict Replacement: Only replace placeholders while preserving all other template structures
    3. Complete Coverage: Ensure all placeholders are replaced

    ## Operator Semantics
    | Operator  | Description                     | Example                      |
    |-----------|--------------------------------|-----------------------------|
    | OR        | Set union (A ∪ B)              | (OR a b) → union of a and b |
    | DIFF      | Set difference (A - B)         | (DIFF a b) → elements in a but not b |
    | DISTINCT  | Deduplication                  | (DISTINCT a) → deduplicated a |
    | GROUP_COUNT| Group counting statistics     | (GROUP_COUNT a) → count of a |
    | GROUP_SUM | Sum of grouped statistics     | (GROUP_SUM a b) → sum of a and b's stats |
    | ALL       | Universal quantifier          | (ALL a b) → a∧b            |

    ## Replacement Rules
    1. Variable placeholder replacement:
    - Must select from candidate_s_expression_cores
    - Selected core text must remain unchanged
    - Determine replacement order based on question semantics
    - Different x should use different cores

    2. Constant placeholder replacement:
    - number → Explicit number from the question

    3. Function placeholder replacement:
    - compare → Select based on question semantics:
        - LT(less than) / LE(less than or equal to) / EQ(equal to) / GE(greater than or equal to) / GT(greater than)
    - optimize → Select based on extreme value:
        - ARGMAX(maximize) / ARGMIN(minimize)

    ## Output Specification
    - Output in single-line JSON format without any explanatory text
    - Must contain three fields (variables/constants/functions)
    - Empty fields represented with empty objects (e.g. "constants": {})
    """

for i, batch in enumerate(batches):
    print(f"batch: {i + 1}/{len(batches)}")
    messages_list = []
    for qa in batch:
        if qa['predicted_simple_question_type'] not in examples:
            qa['predicted_simple_question_type'] = 'simple'
        selected_examples = examples[qa['predicted_simple_question_type']]
        messages_list.append([{'role': 'system', 'content': prompt}] + selected_examples+ [{'role': 'user', 'content': concat_input(qa)}])
        
    resp_list = await LLM._batch_generate_async(messages_list)
    
    for j, qa in enumerate(batch):
        resp = resp_list[j]
        content = resp.choices[0].message.content
        try:
            content = re.search(r'\{[\s\S]*\}', content).group(0)
            template_and_replacements = json.loads(fix_json(content))
            template = fix_template(template_and_replacements['template'])
            replacements = template_and_replacements['replacements']
            qa['predicted_template'] = template
            qa['predicted_replacements'] = replacements
        except:
            qa['predicted_template'] = ''
            qa['predicted_replacements'] = ''

    if (i + 1) % 5 == 0:
        json.dump(test_data, open(f'output/{exp_name}/prediction.json', 'w'), indent=2)


[('sk-***REDACTED***', 'https://api.deepseek.com', 'deepseek-chat'), ('sk-***REDACTED***', 'https://api.deepseek.com', 'deepseek-chat')]
batch: 1/190


100%|██████████| 20/20 [00:17<00:00,  1.14it/s]


batch: 2/190


100%|██████████| 20/20 [00:22<00:00,  1.14s/it]


batch: 3/190


100%|██████████| 20/20 [00:18<00:00,  1.09it/s]


batch: 4/190


100%|██████████| 20/20 [00:28<00:00,  1.41s/it]


batch: 5/190


100%|██████████| 20/20 [00:32<00:00,  1.62s/it]


batch: 6/190


100%|██████████| 20/20 [00:21<00:00,  1.09s/it]


batch: 7/190


100%|██████████| 20/20 [00:44<00:00,  2.24s/it]


batch: 8/190


100%|██████████| 20/20 [00:35<00:00,  1.79s/it]


batch: 9/190


100%|██████████| 20/20 [00:15<00:00,  1.32it/s]


batch: 10/190


100%|██████████| 20/20 [00:16<00:00,  1.24it/s]


batch: 11/190


100%|██████████| 20/20 [00:22<00:00,  1.12s/it]


batch: 12/190


100%|██████████| 20/20 [00:23<00:00,  1.18s/it]


batch: 13/190


100%|██████████| 20/20 [00:19<00:00,  1.00it/s]


batch: 14/190


100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


batch: 15/190


100%|██████████| 20/20 [00:25<00:00,  1.29s/it]


batch: 16/190


100%|██████████| 20/20 [00:21<00:00,  1.06s/it]


batch: 17/190


100%|██████████| 20/20 [00:47<00:00,  2.38s/it]


batch: 18/190


100%|██████████| 20/20 [00:29<00:00,  1.48s/it]


batch: 19/190


100%|██████████| 20/20 [00:32<00:00,  1.63s/it]


batch: 20/190


100%|██████████| 20/20 [01:19<00:00,  3.97s/it]


batch: 21/190


100%|██████████| 20/20 [00:26<00:00,  1.30s/it]


batch: 22/190


100%|██████████| 20/20 [00:50<00:00,  2.54s/it]


batch: 23/190


100%|██████████| 20/20 [00:15<00:00,  1.27it/s]


batch: 24/190


100%|██████████| 20/20 [00:20<00:00,  1.00s/it]


batch: 25/190


100%|██████████| 20/20 [00:39<00:00,  1.95s/it]


batch: 26/190


100%|██████████| 20/20 [00:33<00:00,  1.66s/it]


batch: 27/190


100%|██████████| 20/20 [00:22<00:00,  1.14s/it]


batch: 28/190


100%|██████████| 20/20 [00:23<00:00,  1.16s/it]


batch: 29/190


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


batch: 30/190


100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


batch: 31/190


100%|██████████| 20/20 [00:14<00:00,  1.39it/s]


batch: 32/190


100%|██████████| 20/20 [00:15<00:00,  1.33it/s]


batch: 33/190


100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


batch: 34/190


100%|██████████| 20/20 [00:16<00:00,  1.20it/s]


batch: 35/190


100%|██████████| 20/20 [00:17<00:00,  1.14it/s]


batch: 36/190


100%|██████████| 20/20 [00:11<00:00,  1.72it/s]


batch: 37/190


100%|██████████| 20/20 [00:24<00:00,  1.22s/it]


batch: 38/190


100%|██████████| 20/20 [00:17<00:00,  1.16it/s]


batch: 39/190


100%|██████████| 20/20 [00:13<00:00,  1.50it/s]


batch: 40/190


100%|██████████| 20/20 [00:14<00:00,  1.40it/s]


batch: 41/190


100%|██████████| 20/20 [00:16<00:00,  1.21it/s]


batch: 42/190


100%|██████████| 20/20 [00:27<00:00,  1.39s/it]


batch: 43/190


100%|██████████| 20/20 [00:43<00:00,  2.18s/it]


batch: 44/190


100%|██████████| 20/20 [00:16<00:00,  1.19it/s]


batch: 45/190


100%|██████████| 20/20 [00:16<00:00,  1.20it/s]


batch: 46/190


100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


batch: 47/190


100%|██████████| 20/20 [00:19<00:00,  1.00it/s]


batch: 48/190


100%|██████████| 20/20 [00:15<00:00,  1.29it/s]


batch: 49/190


100%|██████████| 20/20 [00:14<00:00,  1.43it/s]


batch: 50/190


100%|██████████| 20/20 [00:51<00:00,  2.59s/it]


batch: 51/190


100%|██████████| 20/20 [00:48<00:00,  2.45s/it]


batch: 52/190


100%|██████████| 20/20 [00:17<00:00,  1.15it/s]


batch: 53/190


100%|██████████| 20/20 [01:23<00:00,  4.16s/it]


batch: 54/190


100%|██████████| 20/20 [00:49<00:00,  2.46s/it]


batch: 55/190


100%|██████████| 20/20 [00:17<00:00,  1.16it/s]


batch: 56/190


100%|██████████| 20/20 [00:48<00:00,  2.41s/it]


batch: 57/190


100%|██████████| 20/20 [00:55<00:00,  2.76s/it]


batch: 58/190


100%|██████████| 20/20 [01:07<00:00,  3.36s/it]


batch: 59/190


100%|██████████| 20/20 [00:23<00:00,  1.18s/it]


batch: 60/190


100%|██████████| 20/20 [00:34<00:00,  1.72s/it]


batch: 61/190


100%|██████████| 20/20 [00:32<00:00,  1.62s/it]


batch: 62/190


100%|██████████| 20/20 [00:21<00:00,  1.10s/it]


batch: 63/190


100%|██████████| 20/20 [00:18<00:00,  1.08it/s]


batch: 64/190


100%|██████████| 20/20 [00:52<00:00,  2.61s/it]


batch: 65/190


100%|██████████| 20/20 [00:26<00:00,  1.34s/it]


batch: 66/190


100%|██████████| 20/20 [00:24<00:00,  1.25s/it]


batch: 67/190


100%|██████████| 20/20 [00:49<00:00,  2.47s/it]


batch: 68/190


100%|██████████| 20/20 [00:18<00:00,  1.06it/s]


batch: 69/190


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


batch: 70/190


100%|██████████| 20/20 [00:30<00:00,  1.54s/it]


batch: 71/190


100%|██████████| 20/20 [00:25<00:00,  1.28s/it]


batch: 72/190


100%|██████████| 20/20 [00:17<00:00,  1.15it/s]


batch: 73/190


100%|██████████| 20/20 [00:30<00:00,  1.55s/it]


batch: 74/190


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]


batch: 75/190


100%|██████████| 20/20 [00:23<00:00,  1.19s/it]


batch: 76/190


100%|██████████| 20/20 [00:17<00:00,  1.12it/s]


batch: 77/190


100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


batch: 78/190


100%|██████████| 20/20 [00:21<00:00,  1.08s/it]


batch: 79/190


100%|██████████| 20/20 [00:42<00:00,  2.12s/it]


batch: 80/190


100%|██████████| 20/20 [00:18<00:00,  1.06it/s]


batch: 81/190


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


batch: 82/190


100%|██████████| 20/20 [00:18<00:00,  1.06it/s]


batch: 83/190


100%|██████████| 20/20 [00:29<00:00,  1.49s/it]


batch: 84/190


100%|██████████| 20/20 [00:18<00:00,  1.08it/s]


batch: 85/190


100%|██████████| 20/20 [00:31<00:00,  1.56s/it]


batch: 86/190


100%|██████████| 20/20 [00:41<00:00,  2.07s/it]


batch: 87/190


100%|██████████| 20/20 [00:55<00:00,  2.77s/it]


batch: 88/190


100%|██████████| 20/20 [00:41<00:00,  2.07s/it]


batch: 89/190


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


batch: 90/190


100%|██████████| 20/20 [00:40<00:00,  2.00s/it]


batch: 91/190


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


batch: 92/190


100%|██████████| 20/20 [00:25<00:00,  1.27s/it]


batch: 93/190


100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


batch: 94/190


100%|██████████| 20/20 [00:22<00:00,  1.13s/it]


batch: 95/190


100%|██████████| 20/20 [00:26<00:00,  1.32s/it]


batch: 96/190


100%|██████████| 20/20 [00:51<00:00,  2.56s/it]


batch: 97/190


100%|██████████| 20/20 [00:38<00:00,  1.94s/it]


batch: 98/190


100%|██████████| 20/20 [00:18<00:00,  1.06it/s]


batch: 99/190


100%|██████████| 20/20 [00:49<00:00,  2.46s/it]


batch: 100/190


100%|██████████| 20/20 [00:39<00:00,  1.97s/it]


batch: 101/190


100%|██████████| 20/20 [00:16<00:00,  1.19it/s]


batch: 102/190


100%|██████████| 20/20 [00:32<00:00,  1.60s/it]


batch: 103/190


100%|██████████| 20/20 [00:26<00:00,  1.33s/it]


batch: 104/190


100%|██████████| 20/20 [00:39<00:00,  1.97s/it]


batch: 105/190


100%|██████████| 20/20 [00:43<00:00,  2.17s/it]


batch: 106/190


100%|██████████| 20/20 [00:52<00:00,  2.62s/it]


batch: 107/190


100%|██████████| 20/20 [00:36<00:00,  1.83s/it]


batch: 108/190


100%|██████████| 20/20 [00:22<00:00,  1.14s/it]


batch: 109/190


100%|██████████| 20/20 [01:02<00:00,  3.10s/it]


batch: 110/190


100%|██████████| 20/20 [01:00<00:00,  3.02s/it]


batch: 111/190


100%|██████████| 20/20 [00:33<00:00,  1.67s/it]


batch: 112/190


100%|██████████| 20/20 [00:19<00:00,  1.01it/s]


batch: 113/190


100%|██████████| 20/20 [00:40<00:00,  2.02s/it]


batch: 114/190


100%|██████████| 20/20 [00:37<00:00,  1.87s/it]


batch: 115/190


100%|██████████| 20/20 [00:16<00:00,  1.22it/s]


batch: 116/190


100%|██████████| 20/20 [00:16<00:00,  1.19it/s]


batch: 117/190


100%|██████████| 20/20 [01:10<00:00,  3.53s/it]


batch: 118/190


100%|██████████| 20/20 [00:28<00:00,  1.44s/it]


batch: 119/190


100%|██████████| 20/20 [00:20<00:00,  1.04s/it]


batch: 120/190


100%|██████████| 20/20 [00:16<00:00,  1.22it/s]


batch: 121/190


100%|██████████| 20/20 [00:17<00:00,  1.17it/s]


batch: 122/190


100%|██████████| 20/20 [00:23<00:00,  1.18s/it]


batch: 123/190


100%|██████████| 20/20 [00:17<00:00,  1.15it/s]


batch: 124/190


100%|██████████| 20/20 [00:50<00:00,  2.53s/it]


batch: 125/190


100%|██████████| 20/20 [00:26<00:00,  1.33s/it]


batch: 126/190


100%|██████████| 20/20 [00:35<00:00,  1.78s/it]


batch: 127/190


100%|██████████| 20/20 [00:23<00:00,  1.16s/it]


batch: 128/190


100%|██████████| 20/20 [00:22<00:00,  1.12s/it]


batch: 129/190


100%|██████████| 20/20 [00:29<00:00,  1.45s/it]


batch: 130/190


100%|██████████| 20/20 [00:19<00:00,  1.05it/s]


batch: 131/190


100%|██████████| 20/20 [00:34<00:00,  1.73s/it]


batch: 132/190


100%|██████████| 20/20 [00:50<00:00,  2.52s/it]


batch: 133/190


100%|██████████| 20/20 [00:40<00:00,  2.05s/it]


batch: 134/190


100%|██████████| 20/20 [00:55<00:00,  2.79s/it]


batch: 135/190


100%|██████████| 20/20 [00:21<00:00,  1.08s/it]


batch: 136/190


100%|██████████| 20/20 [00:23<00:00,  1.20s/it]


batch: 137/190


100%|██████████| 20/20 [00:37<00:00,  1.88s/it]


batch: 138/190


100%|██████████| 20/20 [00:17<00:00,  1.16it/s]


batch: 139/190


100%|██████████| 20/20 [00:11<00:00,  1.71it/s]


batch: 140/190


100%|██████████| 20/20 [00:27<00:00,  1.39s/it]


batch: 141/190


100%|██████████| 20/20 [00:53<00:00,  2.67s/it]


batch: 142/190


100%|██████████| 20/20 [01:04<00:00,  3.20s/it]


batch: 143/190


100%|██████████| 20/20 [00:24<00:00,  1.22s/it]


batch: 144/190


100%|██████████| 20/20 [00:36<00:00,  1.85s/it]


batch: 145/190


100%|██████████| 20/20 [00:15<00:00,  1.26it/s]


batch: 146/190


100%|██████████| 20/20 [00:14<00:00,  1.38it/s]


batch: 147/190


100%|██████████| 20/20 [00:17<00:00,  1.13it/s]


batch: 148/190


100%|██████████| 20/20 [00:16<00:00,  1.24it/s]


batch: 149/190


100%|██████████| 20/20 [00:18<00:00,  1.09it/s]


batch: 150/190


100%|██████████| 20/20 [00:47<00:00,  2.35s/it]


batch: 151/190


100%|██████████| 20/20 [00:49<00:00,  2.48s/it]


batch: 152/190


100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


batch: 153/190


100%|██████████| 20/20 [00:17<00:00,  1.12it/s]


batch: 154/190


100%|██████████| 20/20 [00:15<00:00,  1.27it/s]


batch: 155/190


100%|██████████| 20/20 [00:40<00:00,  2.04s/it]


batch: 156/190


100%|██████████| 20/20 [00:26<00:00,  1.31s/it]


batch: 157/190


100%|██████████| 20/20 [00:14<00:00,  1.37it/s]


batch: 158/190


100%|██████████| 20/20 [00:19<00:00,  1.02it/s]


batch: 159/190


100%|██████████| 20/20 [00:27<00:00,  1.35s/it]


batch: 160/190


100%|██████████| 20/20 [00:21<00:00,  1.07s/it]


batch: 161/190


 90%|█████████ | 18/20 [00:19<00:01,  1.30it/s]

In [None]:
from utils.parse_expr import expression_to_sparql
from retriever.semantic_retriever import SemanticRetriever

# One semantic retriever per lookup kind, constructed in the order listed.
entity_retriever, relation_retriever, type_retriever = [
    SemanticRetriever(kind) for kind in ('entity', 'relation', 'type')
]

def sub_fn_to_mid(expression):
    """Replace the name tokens of an expression with retrieved identifiers.

    The expression is split on whitespace. Each piece is stripped of leading
    and trailing parentheses; operator names and all-digit tokens are kept.
    Any other token is looked up with a retriever chosen from the preceding
    tokens, and substituted by field [1] of the first search hit:

    * relation retriever when the previous token is 'R' or 'JOIN', or the
      token two positions back is 'IS_TRUE';
    * type retriever when the previous token is 'P31';
    * entity retriever otherwise.

    Returns the pieces re-joined with single spaces.
    """
    operators = {
        'R', 'JOIN', 'AND', 'OR', 'DIFF', 'VALUES', 'DISTINCT', 'COUNT',
        'GROUP_COUNT', 'GROUP_SUM', 'LT', 'LE', 'EQ', 'GE', 'GT',
        'ARGMIN', 'ARGMAX', 'ALL', 'IS_TRUE',
    }
    pieces = expression.split()
    previous, before_previous = '', ''
    for position, piece in enumerate(pieces):
        token = piece.strip(')(')
        if not (token in operators or token.isdigit()):
            if previous in ('R', 'JOIN') or before_previous == 'IS_TRUE':
                retriever = relation_retriever
            elif previous == 'P31':
                retriever = type_retriever
            else:
                retriever = entity_retriever
            mid = retriever.semantic_search(token)[0][1]
            pieces[position] = piece.replace(token, mid)
        # The history tracks the original tokens, not the substituted ones.
        before_previous, previous = previous, token
    return ' '.join(pieces)

In [None]:
# import json
# exp_name = 'dev_each_type_50'
# test_data = json.load(open(f'output/{exp_name}/prediction.json', 'r'))

# Turn each predicted template + replacements into a full s-expression, in two
# forms: one built from the cores in qa['calibrated_cores'] (or sub_fn_to_mid
# for cores not in the candidate list) and one built from the *_fn core text.
for qa in test_data:
    try:
        # Read inside the try with .get(): an item that never received a
        # prediction (e.g. the LLM loop was interrupted) takes the fallback
        # path instead of raising KeyError.
        template = qa.get('predicted_template')
        replacements = qa.get('predicted_replacements')
        if not template or not replacements:
            raise ValueError('missing or empty prediction')

        # Fill the function placeholders and the numeric constant.
        for function in replacements['functions']:
            template = template.replace(function, replacements['functions'][function])
        if replacements['constants']:
            template = template.replace('number', str(replacements['constants']['number']))

        s_expression = template
        s_expression_fn = template
        # Reverse lexicographic order substitutes 'x10' before 'x1', so the
        # 'x1' replacement cannot hit the prefix of the 'x10' placeholder.
        for var in sorted(replacements['variables'], reverse=True):
            core_fn = replacements['variables'][var]
            if not is_close(core_fn):
                core_fn = fix_core(core_fn)
            if core_fn in qa['calibrated_cores_fn']:
                index = qa['calibrated_cores_fn'].index(core_fn)
                core = qa['calibrated_cores'][index]
            else:
                # The model produced a core outside the candidate list.
                core = sub_fn_to_mid(core_fn)
            s_expression = s_expression.replace(var, core, 1)
            s_expression_fn = s_expression_fn.replace(var, core_fn, 1)
        qa['predicted_s_expression'] = s_expression
        qa['predicted_s_expression_fn'] = s_expression_fn
    except Exception:
        # Best effort: any malformed prediction yields empty expressions.
        # Exception (not a bare except) keeps KeyboardInterrupt working.
        qa['predicted_s_expression_fn'] = ''
        qa['predicted_s_expression'] = ''

In [None]:
# Convert every assembled s-expression and store the result under 'actions'.
for qa in test_data:
    qa['actions'] = expression_to_sparql(qa['predicted_s_expression'])

# The context manager flushes and closes the file once the dump completes.
with open(f'output/{exp_name}/prediction.json', 'w', encoding='utf-8') as prediction_file:
    json.dump(test_data, prediction_file, indent=2)