In [1]:
import json
import numpy as np
from collections import Counter

In [2]:
# Knowledge-base JSON; passed to the get_*_clean_fullname helpers below.
kb_json = '../dataset/kb.json'

# Question split files, relative to this notebook. train_json and val_json are
# loaded in the "New Experiment" section.
train_json = '../dataset/train.json'
val_json = '../dataset/val.json'
test_json = '../dataset/test.json'

### Dataset Functions

In [None]:
def string_clean(s: str) -> str:
    """Make ``s`` safe for the comma-separated output files.

    Every comma is replaced by the word `` and `` and every run of whitespace
    is collapsed to a single space; leading/trailing whitespace is dropped.
    """
    without_commas = ' and '.join(s.split(','))
    return ' '.join(without_commas.split())

def find_name(kb, id):
    """Return the name of the entity or concept identified by ``id``.

    Args:
        kb: Knowledge-base dict with ``'entities'`` and ``'concepts'``
            mappings, each from identifier to a record holding ``'name'``.
        id: Identifier to resolve. Entities are tried first, then concepts.

    Returns:
        The ``'name'`` value of the first matching record.

    Raises:
        KeyError: If ``id`` has no named record in either mapping.
    """
    for section in ('entities', 'concepts'):
        try:
            return kb[section][id]['name']
        except KeyError:
            # Only a missing key means "try the next section"; any other
            # error is a real problem and must not be swallowed.
            continue
    raise KeyError(id)

In [None]:
def get_qualifier_relational_clean_fullname(kb_json, output=False, file_name='kb_q_r_clean_fullname.txt'):
    """Collect relation facts that carry at least one string-valued qualifier.

    Each statement is a tuple ``(head, predicate, tail, qk, qv, qk, qv, ...)``
    of cleaned strings. ``head``/``tail`` are entity or concept names ordered
    according to the relation's ``'direction'``.

    Args:
        kb_json: Path to the knowledge-base JSON file.
        output: When true, also write the statements to ``file_name``, one
            comma-joined statement per line.
        file_name: Destination file used when ``output`` is true.

    Returns:
        set: The statements longer than a bare triple (unsorted).
    """
    # Read the KB inside a context manager so the file handle is closed promptly.
    with open(kb_json) as kb_file:
        kb = json.load(kb_file)

    qualifier = set()
    for entity in kb['entities'].values():
        fullname = entity['name']
        for rel_dict in entity['relations']:
            # First: the fact triple, oriented by the relation direction.
            statement = []
            direction = rel_dict['direction']
            if direction == 'forward':
                statement += [string_clean(fullname), string_clean(rel_dict['predicate']), string_clean(find_name(kb, rel_dict['object']))]
            elif direction == 'backward':
                statement += [string_clean(find_name(kb, rel_dict['object'])), string_clean(rel_dict['predicate']), string_clean(fullname)]

            # Second: one (key, value) pair per string-typed qualifier value;
            # qualifier values of any other type are skipped.
            for qk, qvs in rel_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            # Third: keep only statements that actually gained a qualifier.
            if len(statement) > 3:
                qualifier.add(tuple(statement))

    if output:
        with open(file_name, 'w') as f:
            f.writelines(",".join(q) + '\n' for q in qualifier)

    return qualifier

def get_relational_clean_fullname(kb_json, output=False, file_name='kb_r_clean_fullname.txt'):
    """Collect every ``instance of`` fact and relation fact of the knowledge base.

    ``instance of`` facts are plain triples ``(entity, 'instance of', concept)``.
    Relation facts are ``(head, predicate, tail, qk, qv, ...)`` where the
    trailing pairs are the string-valued qualifiers (possibly none).

    Args:
        kb_json: Path to the knowledge-base JSON file.
        output: When true, also write the statements to ``file_name``, one
            comma-joined statement per line.
        file_name: Destination file used when ``output`` is true.

    Returns:
        list: The de-duplicated statements, sorted.
    """
    # Read the KB inside a context manager so the file handle is closed promptly.
    with open(kb_json) as kb_file:
        kb = json.load(kb_file)

    qualifier = set()
    for entity in kb['entities'].values():
        fullname = entity['name']

        # For instance of
        for concept_id in entity['instanceOf']:
            qualifier.add((string_clean(fullname), 'instance of', string_clean(find_name(kb, concept_id))))

        # For relation
        for rel_dict in entity['relations']:
            # First: the fact triple, oriented by the relation direction.
            statement = []
            direction = rel_dict['direction']
            if direction == 'forward':
                statement += [string_clean(fullname), string_clean(rel_dict['predicate']), string_clean(find_name(kb, rel_dict['object']))]
            elif direction == 'backward':
                statement += [string_clean(find_name(kb, rel_dict['object'])), string_clean(rel_dict['predicate']), string_clean(fullname)]

            # Second: one (key, value) pair per string-typed qualifier value.
            for qk, qvs in rel_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            # Third: add the statement whether or not it has qualifiers.
            qualifier.add(tuple(statement))

    qualifier = sorted(qualifier)

    if output:
        with open(file_name, 'w') as f:
            f.writelines(",".join(q) + '\n' for q in qualifier)

    return qualifier


In [None]:
def get_qualifier_attributes_clean_fullname(kb_json, output=False, file_name='kb_q_a_clean_fullname.txt'):
    """Collect string-valued attribute facts that carry a string-valued qualifier.

    Each statement is ``(entity, attribute key, attribute value, qk, qv, ...)``
    of cleaned strings. Attributes whose value type is not ``'string'`` are
    ignored, as are qualifier values of any other type.

    Args:
        kb_json: Path to the knowledge-base JSON file.
        output: When true, also write the statements to ``file_name``, one
            comma-joined statement per line.
        file_name: Destination file used when ``output`` is true.

    Returns:
        set: The statements longer than a bare triple (unsorted).
    """
    # Read the KB inside a context manager so the file handle is closed promptly.
    with open(kb_json) as kb_file:
        kb = json.load(kb_file)

    qualifier = set()
    for entity in kb['entities'].values():
        fullname = entity['name']

        for att_dict in entity['attributes']:
            # First: skip attributes whose value is not a string.
            if att_dict['value']['type'] != 'string':
                continue

            # Second: the attribute triple itself.
            statement = [string_clean(fullname), string_clean(att_dict['key']), string_clean(att_dict['value']['value'])]

            # Third: one (key, value) pair per string-typed qualifier value.
            for qk, qvs in att_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            # Fourth: keep only statements that actually gained a qualifier.
            if len(statement) > 3:
                qualifier.add(tuple(statement))

    if output:
        with open(file_name, 'w') as f:
            f.writelines(",".join(q) + '\n' for q in qualifier)

    return qualifier

def get_attributes_clean_fullname(kb_json, output=False, file_name='kb_a_clean_fullname.txt'):
    """Collect every string-valued attribute fact of the knowledge base.

    Each statement is ``(entity, attribute key, attribute value, qk, qv, ...)``
    where the trailing pairs are the string-valued qualifiers (possibly none).
    Attributes whose value type is not ``'string'`` are ignored.

    Args:
        kb_json: Path to the knowledge-base JSON file.
        output: When true, also write the statements to ``file_name``, one
            comma-joined statement per line.
        file_name: Destination file used when ``output`` is true.

    Returns:
        list: The de-duplicated statements, sorted.
    """
    # Read the KB inside a context manager so the file handle is closed promptly.
    with open(kb_json) as kb_file:
        kb = json.load(kb_file)

    qualifier = set()
    for entity in kb['entities'].values():
        fullname = entity['name']

        for att_dict in entity['attributes']:
            # First: skip attributes whose value is not a string.
            if att_dict['value']['type'] != 'string':
                continue

            # Second: the attribute triple itself.
            statement = [string_clean(fullname), string_clean(att_dict['key']), string_clean(att_dict['value']['value'])]

            # Third: one (key, value) pair per string-typed qualifier value.
            for qk, qvs in att_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            # Fourth: add the statement whether or not it has qualifiers.
            qualifier.add(tuple(statement))

    qualifier = sorted(qualifier)

    if output:
        with open(file_name, 'w') as f:
            f.writelines(",".join(q) + '\n' for q in qualifier)

    return qualifier

In [None]:
def get_all_clean_fullname(kb_json, output=False, file_name='kb_all_clean_fullname.txt'):
    """Collect all ``instance of``, string-attribute and relation facts of the KB.

    Statement shapes (all elements are cleaned strings):
      * ``(entity, 'instance of', concept)``
      * ``(entity, attribute key, attribute value, qk, qv, ...)`` for
        attributes whose value type is ``'string'``
      * ``(head, predicate, tail, qk, qv, ...)`` for relations, oriented by
        the relation's ``'direction'``
    The trailing ``qk, qv`` pairs are the string-valued qualifiers, if any.

    Args:
        kb_json: Path to the knowledge-base JSON file.
        output: When true, also write the statements to ``file_name``, one
            comma-joined statement per line.
        file_name: Destination file used when ``output`` is true.

    Returns:
        list: The de-duplicated statements, sorted.
    """
    # Read the KB inside a context manager so the file handle is closed promptly.
    with open(kb_json) as kb_file:
        kb = json.load(kb_file)

    qualifier = set()
    for entity in kb['entities'].values():
        fullname = entity['name']

        # For instance of
        for concept_id in entity['instanceOf']:
            qualifier.add((string_clean(fullname), 'instance of', string_clean(find_name(kb, concept_id))))

        # For attribute
        for att_dict in entity['attributes']:
            # Skip attributes whose value is not a string.
            if att_dict['value']['type'] != 'string':
                continue

            statement = [string_clean(fullname), string_clean(att_dict['key']), string_clean(att_dict['value']['value'])]

            # One (key, value) pair per string-typed qualifier value.
            for qk, qvs in att_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            qualifier.add(tuple(statement))

        # For relation
        for rel_dict in entity['relations']:
            # The fact triple, oriented by the relation direction.
            statement = []
            direction = rel_dict['direction']
            if direction == 'forward':
                statement += [string_clean(fullname), string_clean(rel_dict['predicate']), string_clean(find_name(kb, rel_dict['object']))]
            elif direction == 'backward':
                statement += [string_clean(find_name(kb, rel_dict['object'])), string_clean(rel_dict['predicate']), string_clean(fullname)]

            # One (key, value) pair per string-typed qualifier value.
            for qk, qvs in rel_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            qualifier.add(tuple(statement))

    qualifier = sorted(qualifier)

    if output:
        with open(file_name, 'w') as f:
            f.writelines(",".join(q) + '\n' for q in qualifier)

    return qualifier

In [None]:
# Build the full fact list and, because output=True, also write it to the
# function's default file 'kb_all_clean_fullname.txt'.
q = get_all_clean_fullname(kb_json, output=True)

In [None]:
def test_q(kb_json, output=False, file_name='test.txt'):
    """Collect qualified relation facts, de-duplicating on ids before naming.

    Unlike ``get_qualifier_relational_clean_fullname``, statements are first
    built with the raw entity/concept identifiers as head and tail, collected
    in a set, and only afterwards are the identifiers replaced by cleaned
    names. Two distinct identifiers sharing a name therefore yield duplicate
    lines in the result.

    Args:
        kb_json: Path to the knowledge-base JSON file.
        output: When true, also write the statements to ``file_name``, one
            comma-joined statement per line.
        file_name: Destination file used when ``output`` is true.

    Returns:
        list: The named statements, sorted (duplicates possible).
    """
    # Read the KB inside a context manager so the file handle is closed promptly.
    with open(kb_json) as kb_file:
        kb = json.load(kb_file)

    qualifier = set()
    for i, entity in kb['entities'].items():
        for rel_dict in entity['relations']:
            # First: the fact triple on raw identifiers, oriented by direction.
            statement = []
            direction = rel_dict['direction']
            if direction == 'forward':
                statement += [i, string_clean(rel_dict['predicate']), rel_dict['object']]
            elif direction == 'backward':
                statement += [rel_dict['object'], string_clean(rel_dict['predicate']), i]

            # Second: one (key, value) pair per string-typed qualifier value.
            for qk, qvs in rel_dict['qualifiers'].items():
                for qv in qvs:
                    if qv['type'] == 'string':
                        statement += [string_clean(qk), string_clean(qv['value'])]

            # Third: keep only statements that actually gained a qualifier.
            if len(statement) > 3:
                qualifier.add(tuple(statement))

    # Resolve head (index 0) and tail (index 2) identifiers to cleaned names.
    new_qualifier = []
    for statement in qualifier:
        new_statement = list(statement)
        new_statement[0] = string_clean(find_name(kb, statement[0]))
        new_statement[2] = string_clean(find_name(kb, statement[2]))
        new_qualifier.append(tuple(new_statement))

    new_qualifier = sorted(new_qualifier)

    if output:
        with open(file_name, 'w') as f:
            f.writelines(",".join(q) + '\n' for q in new_qualifier)

    return new_qualifier

In [None]:
def random_sampling(s: set, split: tuple = (0.85, 0.15), train_file: str = "train.txt", test_file: str = "test.txt"):
    """Randomly split statements into a train file and a test file.

    Args:
        s: Iterable of statements, each an iterable of strings. Every
            statement becomes one comma-joined line.
        split: ``(train_fraction, test_fraction)``. Only the train fraction is
            used to size the split; the test set is whatever remains, so every
            statement lands in exactly one file regardless of rounding.
        train_file: Path the training lines are written to.
        test_file: Path the test lines are written to.

    Raises:
        ValueError: If the train fraction puts the split point outside the data.

    Note:
        The shuffle uses NumPy's global random state; call ``np.random.seed``
        beforehand for a reproducible split.
    """
    str_l = np.array([",".join(q) + '\n' for q in s])
    length = len(str_l)
    permutation = np.random.permutation(length)

    trn_length = int(np.round(length * split[0]))
    if not 0 <= trn_length <= length:
        raise ValueError(f"train fraction {split[0]!r} must be between 0 and 1")

    # Take the test set as the remainder instead of rounding it separately:
    # two independent roundings can fail to add up to `length`.
    trn = str_l[permutation[:trn_length]]
    tst = str_l[permutation[trn_length:]]

    with open(train_file, 'w') as f:
        f.writelines(trn)
    with open(test_file, 'w') as f:
        f.writelines(tst)

### New Experiment

In [40]:
# Load the question splits; context managers close the files once parsed.
with open(train_json) as train_file:
    train = json.load(train_file)
# NOTE: `test` holds the *validation* split (val_json), not test_json.
with open(val_json) as val_file:
    test = json.load(val_file)
length_train = len(train)
length_test = len(test)
length_train, length_test

(94376, 11797)

In [None]:
# Peek at the first 100 training questions (`train` is loaded in the cell
# above) and print the SPARQL, answer choices and gold answer of every
# question whose query contains COUNT.
for i in range(100):
    q = train[i]
    # Uncomment to inspect the remaining fields of each question:
    # print(q['program'])
    # print(q['sparql'])
    # print(q['answer'])
    # print(q['choices'])
    # print(q['question'])
    # print()
    if 'COUNT' in q['sparql']:
        print(q['sparql'])
        print(q['choices'])
        print(q['answer'])
        print()

##### The first type of program: QueryName

In [None]:
# Questions whose program starts with a 'Find' step and ends with a 'What' step.
QueryName = [
    statement
    for statement in train
    if statement['program'][0]['function'] == 'Find'
    and statement['program'][-1]['function'] == 'What'
]
len(QueryName)

#### Program to query graph

In [None]:
def program_to_graph(program):
    """Scan a program and note the name carried by ``Find``/``FilterConcept`` steps.

    Unfinished skeleton: only ``Find`` and ``FilterConcept`` steps are handled
    (their first input is recorded under the step index); every other
    recognised function is a placeholder, and the collected mapping is not
    returned yet.

    Args:
        program: Sequence of step dicts, each with ``'function'``,
            ``'dependencies'`` and ``'inputs'`` keys.

    Returns:
        None.
    """
    entities = dict()

    for idx, block in enumerate(program):
        function = block['function']
        dependencies = block['dependencies']  # read but not used yet
        inputs = block['inputs']

        if function == 'FindAll':
            # Find all entities in the kb. This is hard to handle.
            pass
        elif function in ('Find', 'FilterConcept'):
            # Find: all entities with the given name.
            # FilterConcept: the concept name is recorded the same way.
            entities[idx] = inputs[0]
        elif function in ['FilterStr', 'FilterNum', 'FilterYear', 'FilterDate']:
            pass
        elif function in ['QFilterStr', 'QFilterNum', 'QFilterYear', 'QFilterDate']:
            pass
        elif function == 'Relate':
            # Equality, not `in 'Relate'`: the latter is a substring test that
            # also matches '' or 'Re'.
            pass
        elif function in ['And', 'Or']:
            pass
        elif function == 'What':
            pass
        elif function == 'Count':
            pass
        elif function in ['QueryAttr', 'QueryAttrUnderCondition']:
            pass
        elif function == 'QueryRelation':
            pass
        elif function in ['SelectBetween', 'SelectAmong']:
            pass
        elif function in ['VerifyStr', 'VerifyNum', 'VerifyYear', 'VerifyDate']:
            pass
        elif function in ['QueryAttrQualifier', 'QueryRelationQualifier']:
            pass

#### SPARQL to query graph

In [4]:
import re
import shlex
import copy

def string_clean(s: str) -> str:
    """Make ``s`` safe for the comma-separated output files.

    Every comma is replaced by the word `` and `` and every run of whitespace
    is collapsed to a single space; leading/trailing whitespace is dropped.

    Note: this repeats the definition from the "Dataset Functions" section so
    that the SPARQL cells can run on their own.
    """
    without_commas = ' and '.join(s.split(','))
    return ' '.join(without_commas.split())

In [32]:
def sparql_to_graph(sparql):
    """Parse a SPARQL query string into a flat list of cleaned statement tuples.

    Args:
        sparql: The query text. It must contain a ``{ ... }`` body.

    Returns:
        tuple: ``(case, parse_type, target, statements)``.

        ``case`` describes how the query had to be parsed:
          * 0 - normal query
          * 1 - the query contains ``UNION``
          * 2 - the query contains a single quote
          * 3 - both 1 and 2
          * 4 - a single quote made a statement impossible to tokenize with
            ``shlex``; ``statements`` is then ``None``

        ``parse_type`` / ``target`` come from the query header: ``'entity'`` /
        ``'?e'``, ``'sort'`` / ``'first ?e'``, ``'count'`` / ``'count ?e'``,
        ``'pred'`` / ``'?p'``, ``'bool'`` / ``'bool'``, otherwise ``'attr'``
        with the third whitespace-separated token of the query as target.

        ``statements`` holds tuples of these shapes:
          * ``(subject, relation, object)`` - a triple; ``<...>`` predicates
            lose their angle brackets (and the ``pred:`` prefix for
            ``pred:instance_of``) and have ``_`` replaced by a space
          * ``(head, relation, tail, qualifier, value)`` - a ``[ ... ]`` fact node
          * ``('FILTER ...',)`` - a filter clause, tokens re-joined by spaces
          * ``('SEPARATER{',)``, ``('}SEPARATER{',)``, ``('}SEPARATER',)`` -
            open / middle / close markers around the two UNION branches
          * ``('ORDER ...',)`` - whatever follows the closing brace, appended
            last when it contains ``ORDER``
    """
    PRED_INSTANCE = 'pred:instance_of'
    # Split on '.' only when it is followed by an even number of double
    # quotes, i.e. when the dot is not inside a double-quoted literal.
    DOT_OUTSIDE_QUOTES = r'\.(?=(?:[^"]|"[^"]*")*$)'

    # --- 1. Classify the query by its header. ---
    if sparql.startswith('SELECT DISTINCT ?e'):
        parse_type, target = 'entity', '?e'
    elif sparql.startswith('SELECT ?e'):
        parse_type, target = 'sort', 'first ?e'
    elif sparql.startswith('SELECT (COUNT(DISTINCT ?e)'):
        parse_type, target = 'count', 'count ?e'
    elif sparql.startswith('SELECT DISTINCT ?p '):
        parse_type, target = 'pred', '?p'
    elif sparql.startswith('ASK'):
        parse_type, target = 'bool', 'bool'
    else:
        # Attribute selection: not analysed further yet; just record which
        # variable is selected.
        parse_type = 'attr'
        target = sparql.split()[2]

    # Text after the last '}' - holds the ORDER BY clause of sort queries.
    sort_identity = sparql.split('{', maxsplit=1)[1].rsplit('}', maxsplit=1)[1]

    # --- 2. Cut the body into raw statement strings. ---
    case = 0
    triples = None

    if 'UNION' in sparql:
        case = 1
        body = sparql.split('{', maxsplit=1)[1].rsplit('}', maxsplit=1)[0]
        match = re.fullmatch(r'''(.*?){(.*?)} UNION {(.*?)}(.*)''', body)
        # Four groups: before the union, its two branches, after the union.
        # A marker is inserted after each of the first three.
        markers = {0: 'SEPARATER{', 1: '}SEPARATER{', 2: '}SEPARATER'}
        triples = []
        for idx, group in enumerate(match.groups()):
            triples += [part.strip() for part in re.split(DOT_OUTSIDE_QUOTES, group)]
            if idx in markers:
                triples.append(markers[idx])

    if '\'' in sparql:
        case = 2 if case == 0 else 3

    if case in (0, 2):
        body = sparql.split('{')[1].split('}')[0]
        triples = [part.strip() for part in re.split(DOT_OUTSIDE_QUOTES, body)]

    # --- 3. Tokenize every statement, honouring quoted literals. ---
    seperated_triples = []
    for triple in triples:
        try:
            tokens = shlex.split(triple)
        except ValueError:
            # shlex raises ValueError on an unbalanced quote. With a single
            # quote in the query (case 2/3) this is a known limitation (a
            # question around index 3060~3070 of train triggers it); report
            # case 4 with the same field order as the normal return.
            if case in (2, 3):
                return 4, parse_type, target, None
            raise
        if tokens:
            seperated_triples.append(tokens)

    # --- 4. Normalise each tokenized statement. ---
    disposed_triples = []
    for triple in seperated_triples:
        if len(triple) == 3:
            r = triple[1]
            if r.startswith('<'):
                # '<pred:instance_of>' -> 'instance of'; '<some_rel>' -> 'some rel'
                stripped = r[6:-1] if PRED_INSTANCE in r else r[1:-1]
                relation = stripped.replace('_', ' ')
            else:
                relation = r
            disposed_triples.append((string_clean(triple[0]), string_clean(relation), string_clean(triple[2])))
        elif triple[0] == '[':
            # Fact node. Presumably tokenized as
            # [ <fact_h> H ; <fact_r> R ; <fact_t> T ] QUALIFIER VALUE
            # which puts H, R, T, QUALIFIER, VALUE at 2, 5, 8, 10, 11 - confirm
            # against the dataset.
            pred = triple[10]
            if pred.startswith('<'):
                pred = pred[1:-1].replace('_', ' ')
            attr = triple[11]
            fact_h = triple[2]
            fact_r = triple[5]
            if fact_r.startswith('<'):
                fact_r = fact_r[1:-1].replace('_', ' ')
            fact_t = triple[8]
            disposed_triples.append((
                string_clean(fact_h), string_clean(fact_r), string_clean(fact_t),
                string_clean(pred), string_clean(attr),
            ))
        elif triple[0] == 'FILTER':
            disposed_triples.append((" ".join(triple),))
        elif 'SEPARATER' in triple[0]:
            disposed_triples.append(tuple(triple))
        # Any other statement shape is dropped.

    # --- 5. Append the ORDER BY clause, if any. ---
    if 'ORDER' in sort_identity:
        disposed_triples.append((string_clean(sort_identity),))

    return case, parse_type, target, disposed_triples

#### Some previous tests

In [None]:
# Dump the parsed statements of every training query to 'sparql_triple.txt'.
with open('sparql_triple.txt', 'w') as f:
    for i in range(length_train):
        # sparql_to_graph returns four values; the query string is not stored
        # in `test`, which would overwrite the validation split loaded earlier.
        c, p, t, d = sparql_to_graph(train[i]['sparql'])
        f.write(f'target:{t}\n')
        if d is not None:
            for a in d:
                f.write(",".join(a))
                f.write('\n')
            f.write('\n')
        else:
            f.write('None')
            f.write('\n')

In [None]:
def statement_simplification(case, target, statement):
    """Abstract a parsed query graph into a skeleton of statement kinds.

    Args:
        case: Parse case reported by ``sparql_to_graph``; 4 means the query
            could not be parsed.
        target: Unused; kept for call compatibility.
        statement: List of statement tuples as produced by ``sparql_to_graph``.

    Returns:
        ``None`` when ``case`` is 4, otherwise a list in which
          * ``instance of`` triples are kept as ``(head, 'instance of', tail)``
          * ``pred:name`` triples become ``(head, 'NAME')``
          * other ``pred:*`` triples are dropped
          * triples with the ``?p`` predicate variable are kept unchanged
          * remaining triples become ``(head, 'relational' | 'literal', tail)``
          * 5-element fact nodes become
            ``(head, 'relational' | 'literal', tail, 'qualifier', value)``
          * filters become ``('OPERATION',)`` and the UNION markers become
            ``('[',)``, ``('UNION',)``, ``(']',)``
        A fact is ``'relational'`` when its tail is an entity variable, i.e.
        contains ``'?e'``.
    """
    if case == 4:
        return None

    def kind_of(tail):
        # Test for the variable prefix '?e', not the bare letter 'e', so that
        # literals such as 'Queen' are not mistaken for entity variables.
        return 'relational' if '?e' in tail else 'literal'

    statement_refine = []
    for triple in statement:
        if len(triple) == 3:
            head, rel, tail = triple
            # sparql_to_graph emits 'instance of'; the 'pred:'-prefixed
            # spelling is still accepted for older inputs.
            if rel in ('pred:instance of', 'instance of'):
                statement_refine.append((head, 'instance of', tail))
            elif rel == 'pred:name':
                statement_refine.append((head, 'NAME'))
            elif rel.startswith('pred:'):
                continue
            elif rel == '?p':
                statement_refine.append(triple)
            else:
                statement_refine.append((head, kind_of(tail), tail))
        elif 'FILTER' in triple[0]:
            statement_refine.append(('OPERATION',))
        elif 'SEPARATER' in triple[0]:
            if triple[0] == 'SEPARATER{':
                statement_refine.append(('[',))
            elif triple[0] == '}SEPARATER{':
                statement_refine.append(('UNION',))
            else:
                statement_refine.append((']',))
        elif len(triple) == 5:  # Must be qualifier
            statement_refine.append((triple[0], kind_of(triple[2]), triple[2], 'qualifier', triple[4]))
    return statement_refine

In [None]:
# Write every training query twice: as parsed ('test_origin.txt') and as a
# simplified skeleton ('test_simply.txt').
test_t = []
with open('test_origin.txt', 'w') as f, open('test_simply.txt', 'w') as fs:
    # `length_train` replaces the undefined name `length`.
    for i in range(length_train):
        # sparql_to_graph returns four values; the query string is not stored
        # in `test`, which would overwrite the validation split loaded earlier.
        c, p, t, d = sparql_to_graph(train[i]['sparql'])
        test_t.append((t, d))
        sf = statement_simplification(c, t, d)

        if d is not None:
            f.write(train[i]['question'] + '\n')
            f.write(f'ID: {i}, target:{t}\n---\n')
            for a in d:
                f.write(",".join(a))
                f.write('\n')
            f.write('---\n')
            f.write('\n')

        if sf is not None:
            fs.write(train[i]['question'] + '\n')
            fs.write(f'ID: {i}, target:{t}\n---\n')
            for a in sf:
                fs.write(",".join(a))
                fs.write('\n')
            fs.write('---\n')
            fs.write('\n')
        

In [None]:
# Inspect how a single training question is parsed.
x = 264
print(train[x]['sparql'])
print(train[x]['program'])
# sparql_to_graph returns (case, parse_type, target, statements).
c, p, t, d = sparql_to_graph(train[x]['sparql'])
c, p, t, d

#### Filter out relational queries and simplify

In [33]:
def graph_simplifier_rough_no_literal(case, parse_type, target, query_graph):
    """Inline names/values into a query graph and drop queries with typed literals.

    Args:
        case: Parse case from ``sparql_to_graph``. Cases 1, 3 and 4 (UNION
            queries and unparsable queries) are rejected.
        parse_type: Unused; kept for call compatibility.
        target: Unused; kept for call compatibility.
        query_graph: List of statement tuples from ``sparql_to_graph``.

    Returns:
        ``None`` when the query is rejected - because of its ``case`` or
        because any statement's second element mentions ``pred:unit``,
        ``pred:year`` or ``pred:date``. Otherwise a new list in which
          * ``pred:name`` / ``pred:value`` statements are removed and the
            variable they describe is replaced by that name/value at
            positions 0, 2 (and 4) of every 3- and 5-element statement
            (a value takes precedence over a name),
          * statements of any other length are passed through unchanged,
          * a plain triple is removed when a 5-element qualifier statement
            with the same head, relation and tail exists.
    """
    PRED_UNIT = 'pred:unit'         # link packed value node to its unit
    PRED_YEAR = 'pred:year'         # link packed value node to its year value
    PRED_DATE = 'pred:date'         # link packed value node to its date value

    if case in (1, 3, 4):
        return None
    for statement in query_graph:
        if len(statement) > 1 and any(pred in statement[1] for pred in (PRED_UNIT, PRED_YEAR, PRED_DATE)):
            return None

    # Variable -> literal maps, taken from plain triples only.
    substitution_name_dict = dict()
    substitution_value_dict = dict()
    for statement in query_graph:
        if len(statement) == 3:
            if statement[1] == 'pred:name':
                substitution_name_dict[statement[0]] = statement[2]
            if statement[1] == 'pred:value':
                substitution_value_dict[statement[0]] = statement[2]

    def resolve(term):
        # A value substitution wins over a name substitution.
        if term in substitution_value_dict:
            return substitution_value_dict[term]
        return substitution_name_dict.get(term, term)

    output_graph = []
    for statement in query_graph:
        if len(statement) in (3, 5):
            if statement[1] in ('pred:name', 'pred:value'):
                continue
            refined = list(statement)
            # Even positions hold nodes (head, tail, qualifier value); odd
            # positions hold predicates and are left untouched.
            for pos in range(0, len(statement), 2):
                refined[pos] = resolve(statement[pos])
            output_graph.append(tuple(refined))
        else:
            # Special statement (filter, order, ...): keep as is.
            output_graph.append(statement)

    # A qualifier statement already contains its triple, so the plain triple is
    # redundant. Head, relation AND tail must match. Filtering (instead of
    # list.remove) also copes with a triple matched by several qualifiers.
    qualified_facts = {st[:3] for st in output_graph if len(st) == 5}
    return [st for st in output_graph if not (len(st) == 3 and st in qualified_facts)]

In [34]:
# Write every training query that survives graph_simplifier_rough_no_literal
# to 'test_no_literal_modified.txt' and count them.
relational = 0
with open('test_no_literal_modified.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        # Named `qry`, not `test`: `test` holds the validation split and is
        # needed again by the test-set generation cell.
        qry = train[i]['sparql']
        c, p, t, d = sparql_to_graph(qry)
        d = graph_simplifier_rough_no_literal(c, p, t, d)

        if d is not None:
            f.write(train[i]['question'] + '\n')
            f.write(f'ID: {i}, target:{t}\n---\n')
            for a in d:
                f.write(",".join(a))
                f.write('\n')
            f.write('---\n')
            f.write(train[i]['answer'] + '\n')
            f.write('\n')
            relational += 1
relational

52547

#### Further filter into multihop, qualifier, or verify queries

In [50]:
'''
The graph to this step: have processed so that no "UNION", no "FILTER", no "<pred:...>"
'''

def retrieve_multihop(case, parse_type, target, query_graph):
    """Return ``query_graph`` if it mentions more than one entity variable.

    Args:
        case: Unused; kept for a uniform ``retrieve_*`` signature.
        parse_type: Query type; ``'count'`` queries are never multihop.
        target: Unused; kept for a uniform ``retrieve_*`` signature.
        query_graph: List of statement tuples, or ``None``.

    Returns:
        ``query_graph`` itself when at least two distinct terms containing
        ``'?e'`` appear as head or tail of its statements; otherwise ``None``.
    """
    if parse_type == 'count' or query_graph is None:
        return None

    entities = set()
    for statement in query_graph:
        # One-element statements (filter / order clauses) have no head or
        # tail; indexing statement[2] on them would raise IndexError.
        if len(statement) < 3:
            continue
        for term in (statement[0], statement[2]):
            if '?e' in term:
                entities.add(term)

    return query_graph if len(entities) > 1 else None

def retrieve_qualifier_qpv(case, parse_type, target, query_graph):
    """Return ``query_graph`` for attribute queries whose target is ``'?qpv'``.

    Any other combination of ``parse_type`` and ``target`` yields ``None``.
    ``case`` is unused.
    """
    asks_for_qualifier_value = (parse_type, target) == ('attr', '?qpv')
    return query_graph if asks_for_qualifier_value else None

def retrieve_qualifier_other(case, parse_type, target, query_graph):
    """Return ``query_graph`` if it uses a qualifier without asking for its value.

    Queries that ask for the qualifier value (``parse_type == 'attr'`` and
    ``target == '?qpv'``) and missing graphs yield ``None``; so do graphs
    without any 5-element statement. ``case`` is unused.
    """
    if query_graph is None or (parse_type == 'attr' and target == '?qpv'):
        return None
    has_qualifier = any(len(statement) == 5 for statement in query_graph)
    return query_graph if has_qualifier else None

def retrieve_verify(case, parse_type, target, query_graph):
    """Return ``query_graph`` for boolean queries, ``None`` otherwise.

    A query is boolean when both ``parse_type`` and ``target`` equal
    ``'bool'``. ``case`` is unused.
    """
    is_verification = (parse_type, target) == ('bool', 'bool')
    return query_graph if is_verification else None

In [36]:
def retrieve_entity(case, parse_type, target, query_graph):
    """Return ``query_graph`` when the query target is ``'?e'``, else ``None``.

    ``case`` and ``parse_type`` are unused.
    """
    return query_graph if target == '?e' else None

In [48]:
def qpv_no_literal(choice_0):
    """Tell whether an answer choice is plain text rather than a number or date.

    Args:
        choice_0: The choice as a string.

    Returns:
        ``False`` if ``choice_0`` parses as a float or is exactly of the form
        ``dddd-dd-dd`` (d = digit); ``True`` otherwise.
    """
    try:
        float(choice_0)
    except ValueError:
        # Not a number: it is still a literal if it looks like a date.
        # Raw string so that \d is a regex escape, not a string escape.
        return re.fullmatch(r'\d{4}-\d{2}-\d{2}', choice_0) is None
    return False
    

In [37]:
def add_id(target, query_graph, id):
    """Append ``'_' + str(id)`` to the variable terms of a target and its graph.

    Args:
        target: Target string; suffixed when it contains '?e', '?pv', '?qpv'
            or '?p'.
        query_graph: Iterable of statements. Statements of length 3 are
            ``(s, r, o)``; statements of length 5 are ``(s, r, o, k, v)``.
            Statements of any other length are left out of the result.
        id: Value whose ``str()`` form is used as the suffix.

    Returns:
        ``(id_target, id_graph)`` where ``id_graph`` is a new list of tuples.
        In each statement, ``s``/``o`` are suffixed when they contain '?e' or
        '?pv', ``r`` when it contains '?p', and ``v`` when it contains '?qpv';
        ``k`` is never modified.
    """
    suffix = '_' + str(id)
    node_markers = ('?e', '?pv')
    predicate_markers = ('?p',)
    qualifier_markers = ('?qpv',)

    def tag(term, markers):
        # Suffix the term only when it contains one of the given markers.
        if any(marker in term for marker in markers):
            return term + suffix
        return term

    id_target = tag(target, ('?e', '?pv', '?qpv', '?p'))

    id_graph = []
    for statement in query_graph:
        size = len(statement)
        if size == 3:
            s, r, o = statement
            id_graph.append(
                (tag(s, node_markers), tag(r, predicate_markers), tag(o, node_markers))
            )
        elif size == 5:
            s, r, o, k, v = statement
            id_graph.append(
                (
                    tag(s, node_markers),
                    tag(r, predicate_markers),
                    tag(o, node_markers),
                    k,
                    tag(v, qualifier_markers),
                )
            )

    return id_target, id_graph

#### Generate training/test

In [14]:
# Export the training questions accepted by retrieve_multihop to
# ./multihop/train.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./multihop/train.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_multihop(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the multihop filter

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1875

In [15]:
# Export the test questions accepted by retrieve_multihop to
# ./multihop/test.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./multihop/test.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_multihop(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the multihop filter

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

260

In [23]:
# Export the training questions accepted by retrieve_qualifier_qpv to
# ./qualifier/train.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./qualifier/train.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_qpv(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the '?qpv' filter (or no graph available)

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

10555

In [24]:
# Export the test questions accepted by retrieve_qualifier_qpv to
# ./qualifier/test.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./qualifier/test.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_qpv(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the '?qpv' filter (or no graph available)

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1316

In [25]:
# Export the training questions accepted by retrieve_qualifier_other to
# ./qualifier/train_2.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./qualifier/train_2.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_other(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the "other qualifier" filter

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1533

In [26]:
# Export the test questions accepted by retrieve_qualifier_other to
# ./qualifier/test_2.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./qualifier/test_2.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_other(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the "other qualifier" filter

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

196

In [12]:
# Export the training questions accepted by retrieve_verify to
# ./verify/train.txt. Each line is tab-separated: one comma-joined field per
# graph statement followed by the raw (uncleaned) answer.
count = 0
with open('./verify/train.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_verify(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the verify filter (or no graph available)

        fields = [",".join(a) for a in d]
        fields.append(sample['answer'])
        f.write('\t'.join(fields) + '\n')
        count += 1
count

4961

In [13]:
# Export the test questions accepted by retrieve_verify to
# ./verify/test.txt. Each line is tab-separated: one comma-joined field per
# graph statement followed by the raw (uncleaned) answer.
count = 0
with open('./verify/test.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_verify(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the verify filter (or no graph available)

        fields = [",".join(a) for a in d]
        fields.append(sample['answer'])
        f.write('\t'.join(fields) + '\n')
        count += 1
count

632

#### Generate ID training/test

In [19]:
# Same export as the multihop training set, but with the question index
# appended to variable terms via add_id; written to ./multihop/train_id.txt.
count = 0
with open('./multihop/train_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_multihop(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the multihop filter

        t, d = add_id(t, d, i)
        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1875

In [20]:
# Same export as the multihop test set, but with the question index
# appended to variable terms via add_id; written to ./multihop/test_id.txt.
count = 0
with open('./multihop/test_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_multihop(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the multihop filter

        t, d = add_id(t, d, i)
        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

260

In [21]:
# Same export as the '?qpv' qualifier training set, but with the question
# index appended to variable terms via add_id; written to
# ./qualifier/train_id.txt.
count = 0
with open('./qualifier/train_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_qpv(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the '?qpv' filter (or no graph available)

        t, d = add_id(t, d, i)
        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

10555

In [22]:
# Same export as the '?qpv' qualifier test set, but with the question
# index appended to variable terms via add_id; written to
# ./qualifier/test_id.txt.
count = 0
with open('./qualifier/test_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_qpv(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the '?qpv' filter (or no graph available)

        t, d = add_id(t, d, i)
        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1316

In [23]:
# Export the training questions accepted by retrieve_verify to
# ./verify/train_id.txt, with the question index appended to variable terms.
# Each line is tab-separated: one comma-joined field per graph statement
# followed by the raw (uncleaned) answer.
count = 0
with open('./verify/train_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        qry = train[i]['sparql']
        c, p, t, d = sparql_to_graph(qry)
        d = graph_simplifier_rough_no_literal(c, p, t, d)
        d = retrieve_verify(c, p, t, d)

        if d is not None:
            # add_id must run once, BEFORE iterating: calling it inside the
            # statement loop wrote the un-suffixed statements (the loop kept
            # iterating the original graph) and re-suffixed t/d every pass.
            t, d = add_id(t, d, i)
            for a in d:
                f.write(",".join(a))
                f.write('\t')
            f.write(train[i]['answer'])
            f.write('\n')
            count += 1
count

4961

In [24]:
# Export the test questions accepted by retrieve_verify to
# ./verify/test_id.txt, with the question index appended to variable terms
# via add_id. Each line is tab-separated: one comma-joined field per graph
# statement followed by the raw (uncleaned) answer.
count = 0
with open('./verify/test_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_verify(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the verify filter (or no graph available)

        t, d = add_id(t, d, i)
        fields = [",".join(a) for a in d]
        fields.append(sample['answer'])
        f.write('\t'.join(fields) + '\n')
        count += 1
count

632

#### Generate dataset with pure type

In [44]:
# Export the training questions accepted by retrieve_entity (target '?e') to
# ./entity/train.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./entity/train.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_entity(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the entity filter (or no graph available)

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count


8525

In [45]:
# Export the test questions accepted by retrieve_entity (target '?e') to
# ./entity/test.txt. Each line is tab-separated: target, one comma-joined
# field per graph statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
with open('./entity/test.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_entity(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the entity filter (or no graph available)

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1142

In [38]:
# Same export as the entity training set, but with the question index
# appended to variable terms via add_id; written to ./entity/train_id.txt.
count = 0
with open('./entity/train_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_entity(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the entity filter (or no graph available)

        t, d = add_id(t, d, i)
        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

8525

In [41]:
# Same export as the entity test set, but with the question index
# appended to variable terms via add_id; written to ./entity/test_id.txt.
count = 0
with open('./entity/test_id.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_entity(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the entity filter (or no graph available)

        t, d = add_id(t, d, i)
        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(string_clean(s) for s in sample['choices']))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

1142

In [51]:
# ID-suffixed '?qpv' qualifier training export restricted to questions whose
# first cleaned choice is neither a number nor a YYYY-MM-DD date
# (see qpv_no_literal); written to ./qualifier/train_id_clean.txt.
count = 0
with open('./qualifier/train_id_clean.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        sample = train[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_qpv(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the '?qpv' filter (or no graph available)

        t, d = add_id(t, d, i)
        cleaned_choices = [string_clean(s) for s in sample['choices']]
        if not qpv_no_literal(cleaned_choices[0]):
            continue  # first choice is a numeric or date literal

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(cleaned_choices))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

5700

In [52]:
# ID-suffixed '?qpv' qualifier test export restricted to questions whose
# first cleaned choice is neither a number nor a YYYY-MM-DD date
# (see qpv_no_literal); written to ./qualifier/test_id_clean.txt.
count = 0
with open('./qualifier/test_id_clean.txt', 'w', encoding='utf-8') as f:
    for i in range(length_test):
        sample = test[i]
        c, p, t, d = sparql_to_graph(sample['sparql'])
        d = retrieve_qualifier_qpv(c, p, t, graph_simplifier_rough_no_literal(c, p, t, d))
        if d is None:
            continue  # rejected by the '?qpv' filter (or no graph available)

        t, d = add_id(t, d, i)
        cleaned_choices = [string_clean(s) for s in sample['choices']]
        if not qpv_no_literal(cleaned_choices[0]):
            continue  # first choice is a numeric or date literal

        fields = [t]
        fields.extend(",".join(a) for a in d)
        fields.append(",".join(cleaned_choices))
        fields.append(string_clean(sample['answer']))
        f.write('\t'.join(fields) + '\n')
        count += 1
count

705

#### Find some cases

In [None]:
def find_all_target_entity_without_filter(case, parse_type, target, query_graph):
    """Pass ``query_graph`` through for 'entity' queries targeting '?e', except case 4.

    Returns ``query_graph`` unchanged when ``case != 4``, ``target == '?e'``
    and ``parse_type == 'entity'``; otherwise returns None.
    NOTE(review): judging by the function name, case 4 presumably marks
    queries with a FILTER clause -- confirm against sparql_to_graph.
    """
    wanted = case != 4 and target == '?e' and parse_type == 'entity'
    return query_graph if wanted else None


In [None]:
# Dump the training questions accepted by
# find_all_target_entity_without_filter to entity_qry.txt. Each line is
# tab-separated: target, one comma-joined field per graph statement, the
# comma-joined cleaned choices, the cleaned answer.
count = 0  # start from zero instead of inheriting a value left by an earlier cell
# Explicit UTF-8 so the output does not depend on the platform default encoding.
with open('entity_qry.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        qry = train[i]['sparql']
        c, p, t, d = sparql_to_graph(qry)
        # The filter defined for this section is
        # find_all_target_entity_without_filter; the shorter name used here
        # before is not defined anywhere in the visible part of the notebook.
        d = find_all_target_entity_without_filter(c, p, t, d)

        if d is not None:
            choices_list = [string_clean(s) for s in train[i]['choices']]
            choices = ",".join(choices_list)

            f.write(t + '\t')
            for a in d:
                f.write(",".join(a))
                f.write('\t')
            f.write(choices + '\t')
            f.write(string_clean(train[i]['answer']))
            f.write('\n')
            count += 1
print(count)


In [None]:
def find_all_compare_entity(case, parse_type, target, query_graph):
    """Pass ``query_graph`` through for 'sort' queries targeting '?e', except case 4.

    Returns ``query_graph`` unchanged when ``case != 4``, ``target == '?e'``
    and ``parse_type == 'sort'``; otherwise returns None.
    """
    wanted = case != 4 and target == '?e' and parse_type == 'sort'
    return query_graph if wanted else None


In [None]:
# Dump the training questions accepted by find_all_compare_entity to
# entity_compare.txt, with the question index appended to variable terms via
# add_id. Each line is tab-separated: target, one comma-joined field per graph
# statement, the comma-joined cleaned choices, the cleaned answer.
count = 0
# Explicit UTF-8, as in the other export cells, so the output does not depend
# on the platform default encoding.
with open('entity_compare.txt', 'w', encoding='utf-8') as f:
    for i in range(length_train):
        qry = train[i]['sparql']
        c, p, t, d = sparql_to_graph(qry)
        d = find_all_compare_entity(c, p, t, d)

        if d is not None:
            t, d = add_id(t, d, i)
            choices_list = [string_clean(s) for s in train[i]['choices']]
            choices = ",".join(choices_list)

            f.write(t + '\t')
            for a in d:
                f.write(",".join(a))
                f.write('\t')
            f.write(choices + '\t')
            f.write(string_clean(train[i]['answer']))
            f.write('\n')
            count += 1
print(count)