In [1]:
import json
import os
import re
import numpy as np
import itertools
from utils.execute_query import execute_query
from utils.parse_expr import expression_to_sparql, ParseError

# Restrict the process to GPU 3. NOTE(review): this is set before the retriever
# import below, presumably so the retriever backend only sees that device —
# keep the assignment above the import.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
from retriever.semantic_retriever import SemanticRetriever

# Experiment folder under output/ that holds the predicted cores to calibrate.
exp_name = 'my_top1'
# Use context managers so the file handles are closed as soon as parsing is done.
with open(f'output/{exp_name}/cores.json', 'r') as cores_file:
    test_data = json.load(cores_file)

# One retriever per kind of symbol that can appear in an expression.
entity_retriever = SemanticRetriever('entity')
relation_retriever = SemanticRetriever('relation')
type_retriever = SemanticRetriever('type')

# Position of every id / friendly name in the retriever lists; the positions are
# later passed to `index.reconstruct` to fetch the stored vector.
rela_mid_to_faiss_index = {mid: i for i, mid in enumerate(relation_retriever.mid_list)}
rela_fn_to_faiss_index = {fn: i for i, fn in enumerate(relation_retriever.fn_list)}
type_mid_to_faiss_index = {mid: i for i, mid in enumerate(type_retriever.mid_list)}
type_fn_to_faiss_index = {fn: i for i, fn in enumerate(type_retriever.fn_list)}

# Lookup table from id to friendly name.
with open('data/wikidata_mid_to_fn.json', 'r') as mid_to_fn_file:
    wikidata_mid_to_fn = json.load(mid_to_fn_file)

In [2]:
# top_1 linking
def sub_fn_to_mid(expression):
    """Replace every friendly name in ``expression`` with its top-1 retrieved id.

    The expression is split on whitespace. Operator names and pure-digit tokens
    are left untouched; every other token is looked up with the retriever chosen
    from the preceding tokens:

    * relation retriever when the previous token is ``R`` or ``JOIN``, or the
      token two positions back is ``IS_TRUE``;
    * type retriever when the previous token is ``instance_of``;
    * entity retriever otherwise.

    Returns the rewritten expression with tokens joined by single spaces.
    """
    func_list = ['R', 'JOIN', 'AND', 'OR', 'DIFF', 'VALUES', 'DISTINCT', 'COUNT', 'GROUP_COUNT', 'GROUP_SUM', 'LT', 'LE', 'EQ', 'GE', 'GT', 'ARGMIN', 'ARGMAX', 'ALL', 'IS_TRUE']
    pieces = expression.split()
    prev_token, prev_prev_token = '', ''
    for pos, piece in enumerate(pieces):
        # Parentheses stay attached to the piece; only the bare token is looked up.
        token = piece.strip(')(')
        needs_linking = not (token in func_list or token.isdigit())
        if needs_linking:
            if prev_token in ['R', 'JOIN'] or prev_prev_token == 'IS_TRUE':
                retriever = relation_retriever
            elif prev_token == 'instance_of':
                retriever = type_retriever
            else:
                retriever = entity_retriever
            hits = retriever.semantic_search(token)
            # Each hit is a (score, mid, fn) triple; take the id of the best one.
            best_mid = hits[0][1]
            pieces[pos] = piece.replace(token, best_mid)
        # The context window tracks the original (friendly-name) tokens.
        prev_prev_token, prev_token = prev_token, token
    return ' '.join(pieces)

def sub_mid_to_fn(expression):
    """Replace ids in ``expression`` with their friendly names.

    Any whitespace-separated token that (after stripping parentheses) starts
    with ``P`` or ``Q`` is looked up in ``wikidata_mid_to_fn``; ids missing from
    that table become ``unknown_entity``. Other tokens are kept as they are.
    Returns the rewritten expression with tokens joined by single spaces.
    """
    def _to_friendly(piece):
        token = piece.strip(')(')
        if not token.startswith(('P', 'Q')):
            return piece
        friendly_name = wikidata_mid_to_fn.get(token, "unknown_entity")
        return piece.replace(token, friendly_name)

    return ' '.join(_to_friendly(piece) for piece in expression.split())

In [3]:
# top_k linking

# Number of calibrated expressions `bound_to_existed` collects for one core
# before it returns early.
num_variants = 1

def is_close(expr):
    """Tell whether the parentheses in ``expr`` are balanced.

    Returns ``False`` as soon as a ``)`` appears with no matching ``(`` before
    it, and otherwise ``True`` only if every ``(`` has been closed by the end.
    Characters other than parentheses are ignored.
    """
    depth = 0
    for ch in expr:
        if ch == ')':
            if not depth:
                # Closing parenthesis without an opener.
                return False
            depth -= 1
        elif ch == '(':
            depth += 1
    return not depth

def fix_core(core):
    """Rebuild a core expression with balanced parentheses.

    All parentheses are discarded and the remaining tokens are re-parsed in
    prefix order, using the arity of each known operator:

    * ``IS_TRUE`` takes three arguments, ``JOIN`` two, ``R`` one;
    * ``AND`` consumes every remaining token;
    * ``VALUES`` consumes tokens until the next operator name.

    Any other token is a leaf. Tokens left over after the first top-level
    expression has been parsed are not included in the result.

    Args:
        core: the (possibly unbalanced) expression string.

    Returns:
        The re-parenthesised expression, or ``core`` unchanged when the tokens
        run out before an operator has received all of its arguments.
    """
    func_list = ['JOIN', 'R', 'AND', 'VALUES', 'IS_TRUE']
    bare_tokens = (token.replace('(', '').replace(')', '') for token in core.split())
    # A free-standing parenthesis (e.g. "( JOIN a b") leaves an empty string
    # behind; drop it so it is not mistaken for a leaf.
    tokens = [token for token in bare_tokens if token]
    index = 0

    def parse_core():
        nonlocal index
        token = tokens[index]
        index += 1
        if token not in func_list:
            return token

        args = []
        if token == 'IS_TRUE':
            for _ in range(3):
                args.append(parse_core())
        elif token == 'JOIN':
            for _ in range(2):
                args.append(parse_core())
        elif token == 'R':
            args.append(parse_core())
        elif token == 'AND':
            while index < len(tokens):
                args.append(parse_core())
        else:
            # VALUES: a flat list of leaves that ends at the next operator.
            while index < len(tokens) and tokens[index] not in func_list:
                args.append(parse_core())
        # Join outside the f-string: reusing the same quote character inside a
        # replacement field is only valid syntax from Python 3.12 on.
        joined_args = ' '.join(args)
        return f'({token} {joined_args})'

    try:
        return parse_core()
    except (IndexError, RecursionError):
        # Too few tokens (or pathological nesting): keep the input as is rather
        # than guessing. Unlike a bare `except`, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return core

def get_1hop_relations(entity):
    """Query the predicates that connect ``entity`` to anything, in either direction.

    Builds a SPARQL ``SELECT DISTINCT`` over the union of triples having
    ``wd:<entity>`` as object or as subject, runs it through ``execute_query``
    and keeps only the results whose first character is ``P``.
    """
    query = "".join((
        "SELECT DISTINCT ?x0 WHERE { { ?x1 ?x0 wd:",
        entity,
        " . } UNION { wd:",
        entity,
        " ?x0 ?x1 . } }",
    ))
    relations = []
    for rela in execute_query(query):
        if rela[0] == 'P':
            relations.append(rela)
    return relations

def calculate_rela_similarity(vec1, rela_mid):
    """Cosine similarity between ``vec1`` and the stored vector of a relation.

    Args:
        vec1: query vector.
        rela_mid: relation id. When the id is not one of the retriever's ids,
            its friendly name (from ``wikidata_mid_to_fn``) is searched and the
            id of the top hit is used instead.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, where ``vec2`` is the vector
        reconstructed from the relation retriever's index.

    Raises:
        KeyError: if ``rela_mid`` is unknown to both the retriever and
            ``wikidata_mid_to_fn``.
    """
    # `rela_mid_to_faiss_index` is keyed by exactly the entries of
    # `relation_retriever.mid_list`, so this membership test is equivalent to
    # scanning the list but costs one hash lookup. This function is called once
    # per candidate relation.
    if rela_mid not in rela_mid_to_faiss_index:
        rela_mid = relation_retriever.semantic_search(wikidata_mid_to_fn[rela_mid])[0][1]
    vec2 = relation_retriever.index.reconstruct(rela_mid_to_faiss_index[rela_mid])
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def calculate_type_similarity(vec1, type_mid):
    """Cosine similarity between ``vec1`` and the stored vector of a type.

    Args:
        vec1: query vector.
        type_mid: type id. When the id is not one of the retriever's ids, its
            friendly name (from ``wikidata_mid_to_fn``) is searched and the id
            of the top hit is used instead.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, where ``vec2`` is the vector
        reconstructed from the type retriever's index.

    Raises:
        KeyError: if ``type_mid`` is unknown to both the retriever and
            ``wikidata_mid_to_fn``.
    """
    # `type_mid_to_faiss_index` is keyed by exactly the entries of
    # `type_retriever.mid_list`, so this membership test is equivalent to
    # scanning the list but costs one hash lookup. This function is called once
    # per candidate type.
    if type_mid not in type_mid_to_faiss_index:
        type_mid = type_retriever.semantic_search(wikidata_mid_to_fn[type_mid])[0][1]
    vec2 = type_retriever.index.reconstruct(type_mid_to_faiss_index[type_mid])
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def add_reverse(org_exp):
    """Expand ``org_exp`` into the variants obtained by flipping JOIN directions.

    The result starts with ``org_exp`` itself. Then, for the 1st, 2nd, ... JOIN
    in turn, every candidate collected so far is handed to
    ``add_reverse_index`` and the variants it returns are appended.
    """
    join_total = sum(1 for seg in org_exp.split(" ") if "JOIN" in seg)
    candidates = [org_exp]
    for join_id in range(1, join_total + 1):
        candidates.extend(add_reverse_index(candidates, join_id))
    return candidates

def add_reverse_index(list_of_e, join_id):
    """Flip the direction of the ``join_id``-th JOIN in each expression.

    For every expression in ``list_of_e`` the JOINs are counted from 1 in order
    of appearance. At the ``join_id``-th one:

    * if its relation is wrapped as ``(R rel)``, the wrapper is removed;
    * otherwise the relation is wrapped as ``(R rel)``, except ``P31`` which is
      never reversed (no variant is produced for that expression).

    Expressions with fewer than ``join_id`` JOINs produce no variant. Returns
    the list of new variants only; ``list_of_e`` is not modified.
    """
    variants = []
    for expression in list(list_of_e):
        segments = expression.split(" ")
        joins_seen = 0
        for pos, segment in enumerate(segments):
            if "JOIN" not in segment:
                continue
            already_reversed = segments[pos + 1] == "(R"
            joins_seen += 1
            if joins_seen != join_id:
                continue
            if already_reversed:
                # "(R rel)" -> "rel": blank the "(R" piece, drop the ")" that
                # closed it, then collapse the gap left by the blank piece.
                segments[pos + 1] = ""
                segments[pos + 2] = segments[pos + 2][:-1]
                variants.append(" ".join(" ".join(segments).split()))
            elif segments[pos + 1] != 'P31':
                segments[pos + 1] = "(R " + segments[pos + 1] + ")"
                variants.append(" ".join(segments))
            # Only the targeted JOIN is handled for each expression.
            break
    return variants

def bound_to_existed(s_expression):
    """Link the friendly names in ``s_expression`` to ids that yield a query result.

    Steps:

    1. Every segment that follows ``instance_of`` becomes a SPARQL variable
       (``?t1``, ``?t2``, ...) and ``instance_of`` itself becomes ``P31``.
    2. Entity segments are replaced by the top-k ids from the entity retriever.
    3. Relation segments are replaced by the top-k candidates. When the
       relation's argument is a nested ``(JOIN``/``(AND`` the candidates come
       from the relation retriever; otherwise they are the 1-hop relations of
       the linked argument entity (or of each ``VALUES`` member), ranked by
       cosine similarity to the predicted relation.
    4. For each combination, and each JOIN-direction variant from
       ``add_reverse``, the type variables are resolved one by one to the most
       similar type among the bindings returned by the knowledge base, and the
       final expression is executed.

    Args:
        s_expression: expression with friendly names, segments separated by
            single spaces.

    Returns:
        ``(results, query_count)``. ``results`` holds up to ``num_variants``
        expressions whose query returned a boolean or a non-empty answer; if
        none is found it holds the single top-1 linking from
        ``sub_fn_to_mid``. ``query_count`` is the number of knowledge-base
        queries issued. If an expression with type variables cannot be
        converted to SPARQL, ``('', query_count)`` is returned.
    """
    results = []
    query_count = 0
    expression_segment = s_expression.split(" ")
    expression_segment_copy = expression_segment.copy()

    # --- types: turn every "instance_of <type>" into "P31 ?tN" ---
    type_count = 0
    type_replace_dict = {}
    for i, seg in enumerate(expression_segment):
        processed_seg = seg.strip(')')
        if i > 0 and expression_segment[i - 1] == 'instance_of':
            type_count += 1
            expression_segment_copy[i - 1] = 'P31'
            type_replace_dict[i] = f'?t{type_count}'
            # Re-attach the closing parentheses that were stripped off.
            expression_segment_copy[i] = type_replace_dict[i] + ')' * (len(seg) - len(processed_seg))
    # Projection list for the type query. Built outside the f-string further
    # down because reusing the same quote character inside a replacement field
    # is only valid syntax from Python 3.12 on.
    type_variables = ' '.join(type_replace_dict.values())

    # --- entities: segments that are not operators, numbers, relations or types ---
    enti_replace_dict = {}
    for i, seg in enumerate(expression_segment):
        processed_seg = seg.strip(')')
        if processed_seg[0] != '(' and not processed_seg.isdigit() and not (i > 0 and expression_segment[i - 1] in ['(R', '(JOIN'] or i > 1 and expression_segment[i - 2] == '(IS_TRUE') and not expression_segment[i - 1] == 'instance_of':
            enti_replace_dict[i] = [mid for score, mid, fn in entity_retriever.semantic_search(processed_seg)]
    # Candidates kept per entity, tiered by how many entities there are.
    # All tiers are currently 1.
    if len(enti_replace_dict) > 4:
        top_k = 1
    elif len(enti_replace_dict) > 2:
        top_k = 1
    else:
        top_k = 1

    for i in enti_replace_dict:
        enti_replace_dict[i] = enti_replace_dict[i][:top_k]

    # print({id: [wikidata_mid_to_fn[mid] for mid in mids] for id, mids in enti_replace_dict.items()})

    combinations = list(enti_replace_dict.values())
    all_iters = list(itertools.product(*combinations))  # every possible entity assignment
    enti_index = list(enti_replace_dict.keys())  # positions in expression_segment of the entities to replace
    for iters in all_iters:
        for i in range(len(iters)):
            cur_enti = expression_segment[enti_index[i]]
            suffix = ')' * (len(cur_enti) - len(cur_enti.strip(')')))
            expression_segment_copy[enti_index[i]] = iters[i] + suffix

        # --- relations: segments right after "(R" / "(JOIN", or the middle argument of "(IS_TRUE" ---
        rela_replace_dict = {}
        for j, seg in enumerate(expression_segment):
            processed_seg = seg.strip(')')
            if processed_seg[0] != '(' and not processed_seg.isdigit() and (j > 0 and expression_segment[j - 1] in ['(R', '(JOIN'] or j > 1 and expression_segment[j - 2] == '(IS_TRUE') and not processed_seg == 'instance_of':
                if expression_segment[j + 1] in ['(JOIN', '(AND']:
                    # The argument is a sub-expression with no single entity
                    # to anchor on; fall back to plain retrieval.
                    rela_replace_dict[j] = [mid for score, mid, fn in relation_retriever.semantic_search(processed_seg)]
                else:
                    possible_rela = []
                    if expression_segment[j + 1] == '(VALUES':
                        # Collect the relations of every member of the VALUES list.
                        index = j + 2
                        while index < len(expression_segment) and expression_segment[index][0] != '(':
                            possible_rela += get_1hop_relations(expression_segment_copy[index].strip(')'))
                            query_count += 1
                            index += 1
                    else:
                        possible_rela += get_1hop_relations(expression_segment_copy[j + 1].strip(')'))
                        query_count += 1
                    possible_rela = set(possible_rela)
                    # Use the predicted name if the retriever knows it, else its nearest known name.
                    seg_fn = relation_retriever.semantic_search(processed_seg)[0][2] if processed_seg not in relation_retriever.fn_list else processed_seg
                    seg_vector = relation_retriever.index.reconstruct(rela_fn_to_faiss_index[seg_fn])
                    rela_similarity = [(rela, calculate_rela_similarity(seg_vector, rela)) for rela in possible_rela if rela in wikidata_mid_to_fn]
                    rela_similarity.sort(key=lambda x: x[1], reverse=True)
                    rela_replace_dict[j] = [rela for rela, score in rela_similarity]
        # Candidates kept per relation, tiered like the entities above.
        # All tiers are currently 1.
        if len(rela_replace_dict) > 4:
            top_k = 1
        elif len(rela_replace_dict) > 2:
            top_k = 1
        else:
            top_k = 1

        for j in rela_replace_dict:
            rela_replace_dict[j] = rela_replace_dict[j][:top_k]

        # print({id: [wikidata_mid_to_fn[mid] for mid in mids] for id, mids in rela_replace_dict.items()})

        combinations_rela = list(rela_replace_dict.values())
        all_iters_rela = list(itertools.product(*combinations_rela))
        rela_index = list(rela_replace_dict.keys())
        for iter_rela in all_iters_rela:
            for k in range(len(iter_rela)):
                cur_rela = expression_segment[rela_index[k]]
                suffix = ')' * (len(cur_rela) - len(cur_rela.strip(')')))
                expression_segment_copy[rela_index[k]] = iter_rela[k] + suffix
            final = " ".join(expression_segment_copy)
            added = add_reverse(final)  # variants with reversed relations

            # Walk through the variants and keep those whose query returns something.
            for exp in added:
                # print(exp)
                if type_replace_dict:
                    try:
                        sparql = expression_to_sparql(exp)
                    except ParseError:
                        # NOTE(review): this returns '' rather than a list and
                        # skips the top-1 fallback at the end; kept for
                        # backward compatibility.
                        return '', query_count
                    sparql = sparql.replace('SELECT ?x', f'SELECT DISTINCT {type_variables}', 1)
                    possible_types = execute_query(sparql, multi_var=True)
                    # Count the attempt whether or not it produced bindings,
                    # like every other query in this function.
                    query_count += 1
                    # print([[wikidata_mid_to_fn[mid] for mid in binding] for binding in possible_types])
                    # `execute_query` can return a bool (see the answer check
                    # below); indexing it would raise TypeError, so treat it
                    # like "no bindings".
                    if isinstance(possible_types, bool) or not possible_types or not possible_types[0]:
                        continue
                    indexes = list(type_replace_dict.keys())
                    type_linking_failed = False
                    for type_id in range(len(type_replace_dict)):
                        index = indexes[type_id]
                        seg = expression_segment[index]
                        processed_seg = seg.strip(')')
                        possible_type = set([binding[type_id] for binding in possible_types])
                        seg_fn = type_retriever.semantic_search(processed_seg)[0][2] if processed_seg not in type_retriever.fn_list else processed_seg
                        seg_vector = type_retriever.index.reconstruct(type_fn_to_faiss_index[seg_fn])
                        type_similarity = [(typ, calculate_type_similarity(seg_vector, typ)) for typ in possible_type if typ in wikidata_mid_to_fn]
                        if not type_similarity:
                            type_linking_failed = True
                            break
                        # print([(wikidata_mid_to_fn[mid], score) for mid, score in type_similarity])
                        most_similar_type, highest_similarity = max(type_similarity, key=lambda x: x[1])
                        # Keep only the bindings consistent with the chosen
                        # type before resolving the next variable.
                        possible_types = [binding for binding in possible_types if binding[type_id] == most_similar_type]
                        exp = exp.replace(f'?t{type_id + 1}', most_similar_type)
                    if type_linking_failed:
                        # No usable type for this variant: skip it instead of
                        # sending an empty expression to the SPARQL converter.
                        continue
                # NOTE(review): a ParseError raised here is not caught, unlike
                # in the typed branch above.
                sparql = expression_to_sparql(exp)
                answer = execute_query(sparql)
                query_count += 1
                if isinstance(answer, bool) or answer:
                    results.append(exp)
                    if len(results) == num_variants:
                        return results, query_count
    # top_k linking failed
    if not results:
        results = [sub_fn_to_mid(s_expression)]
    return results, query_count

In [4]:
# Calibrate every predicted core of every question, print the gold / predicted /
# calibrated cores side by side, and store the calibrated cores (as ids and as
# friendly names) plus the number of knowledge-base queries on the question.
for i, qa in enumerate(test_data):
    print('gold cores:')
    for j, core in enumerate(qa['s_expression_cores']):
        print(qa['s_expression_cores_fn'][j])
    print()

    print('predicted_cores:')
    for core in qa['predicted_cores']:
        print(core)
    print()

    print('calibrated cores:') 
    qa['calibrated_cores'] = []
    qa['calibrated_cores_fn'] = []
    qa['sparql_attempt_count'] = 0
    for core in qa['predicted_cores']:
        # Repair unbalanced parentheses before linking.
        if not is_close(core):
            core = fix_core(core)
        candidates, query_count = bound_to_existed(core)
        # Accumulate per core. Adding after the loop would record only the last
        # core's queries, and would use an undefined or stale `query_count`
        # for a question without predicted cores.
        qa['sparql_attempt_count'] += query_count
        for calibrated_core in candidates:
            calibrated_core_core_fn = sub_mid_to_fn(calibrated_core)
            print(calibrated_core_core_fn)
            qa['calibrated_cores'].append(calibrated_core)
            qa['calibrated_cores_fn'].append(calibrated_core_core_fn)
    print()
    print('--------------------------------')
    print()

# Close the output file deterministically so the JSON is fully flushed to disk.
with open(f'output/{exp_name}/calibrated_cores.json', 'w') as calibrated_file:
    json.dump(test_data, calibrated_file, indent=2)
        

gold cores:
(AND (JOIN (R father) Ludovico_II,_Marquess_of_Saluzzo) (JOIN instance_of common_name))

predicted_cores:
(AND (JOIN (R father) Ludovico_II,_Marquess_of_Saluzzo) (JOIN sex_or_gender Male))

calibrated cores:
(AND (JOIN (R father) Ludovico_II,_Marquess_of_Saluzzo) (JOIN sex_or_gender Male))

--------------------------------

gold cores:
(AND (JOIN (R main_subject) Lessons_of_a_Dream) (JOIN instance_of sport))

predicted_cores:
(AND (JOIN main_subject Lessons_of_a_Dream) (JOIN instance_of concept))

calibrated cores:
(AND (JOIN (R main_subject) Lessons_of_a_Dream) (JOIN instance_of sport))

--------------------------------

gold cores:
(AND (JOIN (R member_of) Chris_Broderick) (JOIN instance_of musical_ensemble))

predicted_cores:
(AND (JOIN (R member_of) Chris_Broderick) (JOIN instance_of musical_ensemble))

calibrated cores:
(AND (JOIN (R member_of) Chris_Broderick) (JOIN instance_of musical_ensemble))

--------------------------------

gold cores:
(AND (JOIN (R place_of_de