In [19]:
get_in_relations('m.04j5sk8')

{'government.government_office_category.officeholders',
 'government.government_office_or_title.office_holders',
 'government.governmental_jurisdiction.governing_officials',
 'government.politician.government_positions_held',
 'type.type.instance'}

In [1]:
from collections import defaultdict
from typing import List, Tuple
from SPARQLWrapper import SPARQLWrapper, JSON
import json
import urllib
from pathlib import Path
from tqdm import tqdm
FREEBASE_SPARQL_WRAPPER_URL = "http://localhost:3001/sparql"
FREEBASE_ODBC_PORT = 13001
sparql = SPARQLWrapper(FREEBASE_SPARQL_WRAPPER_URL)
sparql.setReturnFormat(JSON)

path = "/media/disk1/chatgpt/zh/Freebase-Setup/virtuoso-opensource"

with open('/media/disk1/chatgpt/zh/ChatKBQA/ontology/fb_roles', 'r') as f:
    contents = f.readlines()

roles = set()
for line in contents:
    fields = line.split()
    roles.add(fields[1])

# connection for freebase
odbc_conn = None
def initialize_odbc_connection():
    global odbc_conn
    odbc_conn = pyodbc.connect(
        f'DRIVER={path}/lib/virtodbc.so;Host=localhost:{FREEBASE_ODBC_PORT};UID=dba;PWD=dba'
    )
    odbc_conn.setdecoding(pyodbc.SQL_CHAR, encoding='utf8')
    odbc_conn.setdecoding(pyodbc.SQL_WCHAR, encoding='utf8')
    odbc_conn.setencoding(encoding='utf8')
    odbc_conn.timeout = 1
    print('Freebase Virtuoso ODBC connected')


def execute_query(query: str) -> List[str]:
    sparql.setQuery(query)
    try:
        results = sparql.query().convert()
    except urllib.error.URLError:
        print(query)
        # exit(0)
    rtn = []
    for result in results['results']['bindings']:
        assert len(result) == 1  # only select one variable
        for var in result:
            rtn.append(result[var]['value'].replace('http://rdf.freebase.com/ns/', '').replace("-08:00", ''))

    return rtn

def execute_query_with_odbc(query:str) -> List[str]:
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    # print('successfully connnected to Freebase ODBC')
    result_set = set()
    query2 = "SPARQL "+query
    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query2)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query2}")
        exit(0)
    
    for row in rows:
        result_set.add(row[0])

    return result_set


def get_types_with_odbc(entity: str)  -> List[str]:

    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    types = set()

    query = ("""SPARQL
    PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    PREFIX : <http://rdf.freebase.com/ns/> 
    SELECT (?x0 AS ?value) WHERE {
    SELECT DISTINCT ?x0  WHERE {
    """
             ':' + entity + ' :type.object.type ?x0 . '
                            """
    }
    }
    """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query1}")
        rows=[]
        # exit(0)
    

    for row in rows:
        types.add(row[0].replace('http://rdf.freebase.com/ns/', ''))
    
    if len(types)==0:
        return []
    else:
        return list(types)


def get_in_relations(entity: str):
    in_relations = set()

    query1 = ("""
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX : <http://rdf.freebase.com/ns/> 
            SELECT (?x0 AS ?value) WHERE {
            SELECT DISTINCT ?x0  WHERE {
            """
              '?x1 ?x0 ' + ':' + entity + '. '
                                          """
     FILTER regex(?x0, "http://rdf.freebase.com/ns/")
     }
     }
     """)
    # print(query1)

    sparql.setQuery(query1)
    try:
        results = sparql.query().convert()
    except urllib.error.URLError:
        print(query1)
        exit(0)
    for result in results['results']['bindings']:
        in_relations.add(result['value']['value'].replace('http://rdf.freebase.com/ns/', ''))

    return in_relations


def get_out_relations(entity: str):
    out_relations = set()

    query2 = ("""
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX : <http://rdf.freebase.com/ns/> 
        SELECT (?x0 AS ?value) WHERE {
        SELECT DISTINCT ?x0  WHERE {
        """
              ':' + entity + ' ?x0 ?x1 . '
                             """
    FILTER regex(?x0, "http://rdf.freebase.com/ns/")
    }
    }
    """)
    # print(query2)

    sparql.setQuery(query2)
    try:
        results = sparql.query().convert()
    except urllib.error.URLError:
        print(query2)
        exit(0)
    for result in results['results']['bindings']:
        out_relations.add(result['value']['value'].replace('http://rdf.freebase.com/ns/', ''))

    return out_relations
    

def query_two_hop_relations_gmt(entities_path, output_file):
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    res_dict = defaultdict(list)
    entities = load_json(entities_path)
    for entity in tqdm(entities, total=len(entities)):
        query = """
        SPARQL SELECT DISTINCT ?x0 as ?r0 ?y as ?r1 where {{
            {{ ?x1 ?x0 {} . ?x2 ?y ?x1 }}
            UNION
            {{ ?x1 ?x0 {} . ?x1 ?y ?x2 }}
            UNION
            {{ {} ?x0 ?x1 . ?x2 ?y ?x1 }}
            UNION
            {{ {} ?x0 ?x1 . ?x1 ?y ?x2 }}
            FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
            FILTER (?y != rdf:type && ?y != rdfs:label)
            FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
            FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
            FILTER( !regex(?x0,"wikipedia","i"))
            FILTER( !regex(?y,"wikipedia","i"))
            FILTER( !regex(?x0,"type.object","i"))
            FILTER( !regex(?y,"type.object","i"))
            FILTER( !regex(?x0,"common.topic","i"))
            FILTER( !regex(?y,"common.topic","i"))
            FILTER( !regex(?x0,"_id","i"))
            FILTER( !regex(?y,"_id","i"))
            FILTER( !regex(?x0,"#type","i"))
            FILTER( !regex(?y,"#type","i"))
            FILTER( !regex(?x0,"#label","i"))
            FILTER( !regex(?y,"#label","i"))
            FILTER( !regex(?x0,"/ns/freebase","i"))
            FILTER( !regex(?y,"/ns/freebase","i"))
            FILTER( !regex(?x0, "ns/common."))
            FILTER( !regex(?y, "ns/common."))
            FILTER( !regex(?x0, "ns/type."))
            FILTER( !regex(?y, "ns/type."))
            FILTER( !regex(?x0, "ns/kg."))
            FILTER( !regex(?y, "ns/kg."))
            FILTER( !regex(?x0, "ns/user."))
            FILTER( !regex(?y, "ns/user."))
            FILTER( !regex(?x0, "ns/base."))
            FILTER( !regex(?y, "ns/base."))
            FILTER( !regex(?x0, "ns/dataworld."))
            FILTER( !regex(?y, "ns/dataworld."))
            FILTER regex(?x0, "http://rdf.freebase.com/ns/")
            FILTER regex(?y, "http://rdf.freebase.com/ns/")
        }} 
        
        LIMIT 300
        """.format('ns:'+entity, 'ns:'+entity, 'ns:'+entity, 'ns:'+entity)
        try:
            with odbc_conn.cursor() as cursor:
                cursor.execute(query)
                # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
            res = set()
            for row in rows:
                if row[0].startswith("http://rdf.freebase.com/ns/"):
                    res.add(row[0].replace('http://rdf.freebase.com/ns/', ''))
                if row[1].startswith("http://rdf.freebase.com/ns/"):
                    res.add(row[1].replace('http://rdf.freebase.com/ns/', ''))
            res_dict[entity] = list(res)
            
        except Exception:
            # print(f"Query Execution Failed:{query1}")
            rows=[]
    
    # return list(res)
    dump_json(res_dict, output_file)


def get_2hop_relations_with_odbc(entity: str):
    in_relations = set()
    out_relations = set()
    paths = []

    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()


    query1 = ("""SPARQL 
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX ns: <http://rdf.freebase.com/ns/>
            SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
            """
              '?x1 ?x0 ' + 'ns:' + entity + '. '
                                          """
                ?x2 ?y ?x1 .
                  FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                  FILTER (?y != rdf:type && ?y != rdfs:label)
                  FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                  FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                  FILTER( !regex(?x0,"wikipedia","i"))
                  FILTER( !regex(?y,"wikipedia","i"))
                  FILTER( !regex(?x0,"type.object","i"))
                  FILTER( !regex(?y,"type.object","i"))
                  FILTER( !regex(?x0,"common.topic","i"))
                  FILTER( !regex(?y,"common.topic","i"))
                  FILTER( !regex(?x0,"_id","i"))
                  FILTER( !regex(?y,"_id","i"))
                  FILTER( !regex(?x0,"#type","i"))
                  FILTER( !regex(?y,"#type","i"))
                  FILTER( !regex(?x0,"#label","i"))
                  FILTER( !regex(?y,"#label","i"))
                  FILTER( !regex(?x0,"/ns/freebase","i"))
                  FILTER( !regex(?y,"/ns/freebase","i"))
                  FILTER( !regex(?x0, "ns/common."))
                  FILTER( !regex(?y, "ns/common."))
                  FILTER( !regex(?x0, "ns/type."))
                  FILTER( !regex(?y, "ns/type."))
                  FILTER( !regex(?x0, "ns/kg."))
                  FILTER( !regex(?y, "ns/kg."))
                  FILTER( !regex(?x0, "ns/user."))
                  FILTER( !regex(?y, "ns/user."))
                  FILTER( !regex(?x0, "ns/dataworld."))
                  FILTER( !regex(?y, "ns/dataworld."))
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }
                  LIMIT 1000
                  """)
    # print(query1)
    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query1)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query1}")
        rows=[]
        # exit(0)


    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        in_relations.add(r0)
        in_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0, r1))
        

    query2 = ("""SPARQL 
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX ns: <http://rdf.freebase.com/ns/> 
            SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
            """
              '?x1 ?x0 ' + 'ns:' + entity + '. '
                                          """
                ?x1 ?y ?x2 .
                """
                  'FILTER (?x2 != ns:'+entity+' )'
                  """
                  FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                  FILTER (?y != rdf:type && ?y != rdfs:label)
                  FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                  FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                  FILTER( !regex(?x0,"wikipedia","i"))
                  FILTER( !regex(?y,"wikipedia","i"))
                  FILTER( !regex(?x0,"type.object","i"))
                  FILTER( !regex(?y,"type.object","i"))
                  FILTER( !regex(?x0,"common.topic","i"))
                  FILTER( !regex(?y,"common.topic","i"))
                  FILTER( !regex(?x0,"_id","i"))
                  FILTER( !regex(?y,"_id","i"))
                  FILTER( !regex(?x0,"#type","i"))
                  FILTER( !regex(?y,"#type","i"))
                  FILTER( !regex(?x0,"#label","i"))
                  FILTER( !regex(?y,"#label","i"))
                  FILTER( !regex(?x0,"/ns/freebase","i"))
                  FILTER( !regex(?y,"/ns/freebase","i"))
                  FILTER( !regex(?x0, "ns/common."))
                  FILTER( !regex(?y, "ns/common."))
                  FILTER( !regex(?x0, "ns/type."))
                  FILTER( !regex(?y, "ns/type."))
                  FILTER( !regex(?x0, "ns/kg."))
                  FILTER( !regex(?y, "ns/kg."))
                  FILTER( !regex(?x0, "ns/user."))
                  FILTER( !regex(?y, "ns/user."))
                  FILTER( !regex(?x0, "ns/dataworld."))
                  FILTER( !regex(?y, "ns/dataworld."))
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }
                  LIMIT 1000
                  """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query2)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query2}")
        rows = []
        # exit(0)
    
    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        in_relations.add(r0)
        out_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0, r1 + '#R'))

    
    query3 = ("""SPARQL 
                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX ns: <http://rdf.freebase.com/ns/>
                SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
                """
              'ns:' + entity + ' ?x0 ?x1 . '
                             """
                ?x2 ?y ?x1 .
                  FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                  FILTER (?y != rdf:type && ?y != rdfs:label)
                  FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                  FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                  FILTER( !regex(?x0,"wikipedia","i"))
                  FILTER( !regex(?y,"wikipedia","i"))
                  FILTER( !regex(?x0,"type.object","i"))
                  FILTER( !regex(?y,"type.object","i"))
                  FILTER( !regex(?x0,"common.topic","i"))
                  FILTER( !regex(?y,"common.topic","i"))
                  FILTER( !regex(?x0,"_id","i"))
                  FILTER( !regex(?y,"_id","i"))
                  FILTER( !regex(?x0,"#type","i"))
                  FILTER( !regex(?y,"#type","i"))
                  FILTER( !regex(?x0,"#label","i"))
                  FILTER( !regex(?y,"#label","i"))
                  FILTER( !regex(?x0,"/ns/freebase","i"))
                  FILTER( !regex(?y,"/ns/freebase","i"))
                  FILTER( !regex(?x0, "ns/common."))
                  FILTER( !regex(?y, "ns/common."))
                  FILTER( !regex(?x0, "ns/type."))
                  FILTER( !regex(?y, "ns/type."))
                  FILTER( !regex(?x0, "ns/kg."))
                  FILTER( !regex(?y, "ns/kg."))
                  FILTER( !regex(?x0, "ns/user."))
                  FILTER( !regex(?y, "ns/user."))
                  FILTER( !regex(?x0, "ns/dataworld."))
                  FILTER( !regex(?y, "ns/dataworld."))
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }
                  LIMIT 1000
                  """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query3)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query3}")
        rows = []
        # exit(0)
    
    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        out_relations.add(r0)
        in_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0 + '#R', r1))


    query4 = ("""SPARQL 
                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX ns: <http://rdf.freebase.com/ns/>
                SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
                """
              'ns:' + entity + ' ?x0 ?x1 . '
                             """
                ?x1 ?y ?x2 .
                """
                  'FILTER (?x2 != ns:'+entity+' )'
                """
                FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                FILTER (?y != rdf:type && ?y != rdfs:label)
                FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                FILTER( !regex(?x0,"wikipedia","i"))
                FILTER( !regex(?y,"wikipedia","i"))
                FILTER( !regex(?x0,"type.object","i"))
                FILTER( !regex(?y,"type.object","i"))
                FILTER( !regex(?x0,"common.topic","i"))
                FILTER( !regex(?y,"common.topic","i"))
                FILTER( !regex(?x0,"_id","i"))
                FILTER( !regex(?y,"_id","i"))
                FILTER( !regex(?x0,"#type","i"))
                FILTER( !regex(?y,"#type","i"))
                FILTER( !regex(?x0,"#label","i"))
                FILTER( !regex(?y,"#label","i"))
                FILTER( !regex(?x0,"/ns/freebase","i"))
                FILTER( !regex(?y,"/ns/freebase","i"))
                FILTER( !regex(?x0, "ns/common."))
                FILTER( !regex(?y, "ns/common."))
                FILTER( !regex(?x0, "ns/type."))
                FILTER( !regex(?y, "ns/type."))
                FILTER( !regex(?x0, "ns/kg."))
                FILTER( !regex(?y, "ns/kg."))
                FILTER( !regex(?x0, "ns/user."))
                FILTER( !regex(?y, "ns/user."))
                FILTER( !regex(?x0, "ns/dataworld."))
                FILTER( !regex(?y, "ns/dataworld."))
                FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                FILTER regex(?y, "http://rdf.freebase.com/ns/")
                }
                LIMIT 1000
                """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query4)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query4}")
        rows = []
        # exit(0)

    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        out_relations.add(r0)
        out_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0 + '#R', r1 + '#R'))

    return in_relations, out_relations, paths

def get_2hop_relations_with_odbc_wo_filter(entity: str):
    in_relations = set()
    out_relations = set()
    paths = []

    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()


    query1 = ("""SPARQL 
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX ns: <http://rdf.freebase.com/ns/>
            SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
            """
              '?x1 ?x0 ' + 'ns:' + entity + '. '
                                          """
                ?x2 ?y ?x1 .
                  FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                  FILTER (?y != rdf:type && ?y != rdfs:label)
                  FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                  FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                  FILTER( !regex(?x0,"wikipedia","i"))
                  FILTER( !regex(?y,"wikipedia","i"))
                  FILTER( !regex(?x0,"_id","i"))
                  FILTER( !regex(?y,"_id","i"))
                  FILTER( !regex(?x0,"#type","i"))
                  FILTER( !regex(?y,"#type","i"))
                  FILTER( !regex(?x0,"#label","i"))
                  FILTER( !regex(?y,"#label","i"))
                  FILTER( !regex(?x0,"/ns/freebase","i"))
                  FILTER( !regex(?y,"/ns/freebase","i"))
                  FILTER( !regex(?x0, "ns/kg."))
                  FILTER( !regex(?y, "ns/kg."))
                  FILTER( !regex(?x0, "ns/dataworld."))
                  FILTER( !regex(?y, "ns/dataworld."))
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }
                  LIMIT 1000
                  """)
    # print(query1)
    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query1)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query1}")
        rows=[]
        # exit(0)


    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        in_relations.add(r0)
        in_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0, r1))
        

    query2 = ("""SPARQL 
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX ns: <http://rdf.freebase.com/ns/> 
            SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
            """
              '?x1 ?x0 ' + 'ns:' + entity + '. '
                                          """
                ?x1 ?y ?x2 .
                """
                  'FILTER (?x2 != ns:'+entity+' )'
                  """
                  FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                  FILTER (?y != rdf:type && ?y != rdfs:label)
                  FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                  FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                  FILTER( !regex(?x0,"wikipedia","i"))
                  FILTER( !regex(?y,"wikipedia","i"))
                  FILTER( !regex(?x0,"_id","i"))
                  FILTER( !regex(?y,"_id","i"))
                  FILTER( !regex(?x0,"#type","i"))
                  FILTER( !regex(?y,"#type","i"))
                  FILTER( !regex(?x0,"#label","i"))
                  FILTER( !regex(?y,"#label","i"))
                  FILTER( !regex(?x0,"/ns/freebase","i"))
                  FILTER( !regex(?y,"/ns/freebase","i"))
                  FILTER( !regex(?x0, "ns/kg."))
                  FILTER( !regex(?y, "ns/kg."))
                  FILTER( !regex(?x0, "ns/dataworld."))
                  FILTER( !regex(?y, "ns/dataworld."))
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }
                  LIMIT 1000
                  """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query2)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query2}")
        rows = []
        # exit(0)
    
    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        in_relations.add(r0)
        out_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0, r1 + '#R'))

    
    query3 = ("""SPARQL 
                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX ns: <http://rdf.freebase.com/ns/>
                SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
                """
              'ns:' + entity + ' ?x0 ?x1 . '
                             """
                ?x2 ?y ?x1 .
                  FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                  FILTER (?y != rdf:type && ?y != rdfs:label)
                  FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                  FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                  FILTER( !regex(?x0,"wikipedia","i"))
                  FILTER( !regex(?y,"wikipedia","i"))
                  FILTER( !regex(?x0,"_id","i"))
                  FILTER( !regex(?y,"_id","i"))
                  FILTER( !regex(?x0,"#type","i"))
                  FILTER( !regex(?y,"#type","i"))
                  FILTER( !regex(?x0,"#label","i"))
                  FILTER( !regex(?y,"#label","i"))
                  FILTER( !regex(?x0,"/ns/freebase","i"))
                  FILTER( !regex(?y,"/ns/freebase","i"))
                  FILTER( !regex(?x0, "ns/kg."))
                  FILTER( !regex(?y, "ns/kg."))
                  FILTER( !regex(?x0, "ns/dataworld."))
                  FILTER( !regex(?y, "ns/dataworld."))
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }
                  LIMIT 1000
                  """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query3)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query3}")
        rows = []
        # exit(0)
    
    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        out_relations.add(r0)
        in_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0 + '#R', r1))


    query4 = ("""SPARQL 
                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX ns: <http://rdf.freebase.com/ns/>
                SELECT distinct ?x0 as ?r0 ?y as ?r1 WHERE {
                """
              'ns:' + entity + ' ?x0 ?x1 . '
                             """
                ?x1 ?y ?x2 .
                """
                  'FILTER (?x2 != ns:'+entity+' )'
                """
                FILTER (?x0 != rdf:type && ?x0 != rdfs:label)
                FILTER (?y != rdf:type && ?y != rdfs:label)
                FILTER(?x0 != ns:type.object.type && ?x0 != ns:type.object.instance)
                FILTER(?y != ns:type.object.type && ?y != ns:type.object.instance)
                FILTER( !regex(?x0,"wikipedia","i"))
                FILTER( !regex(?y,"wikipedia","i"))
                FILTER( !regex(?x0,"_id","i"))
                FILTER( !regex(?y,"_id","i"))
                FILTER( !regex(?x0,"#type","i"))
                FILTER( !regex(?y,"#type","i"))
                FILTER( !regex(?x0,"#label","i"))
                FILTER( !regex(?y,"#label","i"))
                FILTER( !regex(?x0,"/ns/freebase","i"))
                FILTER( !regex(?y,"/ns/freebase","i"))
                FILTER( !regex(?x0, "ns/kg."))
                FILTER( !regex(?y, "ns/kg."))
                FILTER( !regex(?x0, "ns/dataworld."))
                FILTER( !regex(?y, "ns/dataworld."))
                FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                FILTER regex(?y, "http://rdf.freebase.com/ns/")
                }
                LIMIT 1000
                """)

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query4)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query4}")
        rows = []
        # exit(0)

    for row in rows:
        r0 = row[0].replace('http://rdf.freebase.com/ns/', '')
        r1 = row[1].replace('http://rdf.freebase.com/ns/', '')
        out_relations.add(r0)
        out_relations.add(r1)

        if r0 in roles and r1 in roles:
            paths.append((r0 + '#R', r1 + '#R'))

    return in_relations, out_relations, paths


def get_label(entity: str) -> str:
    """Get the label of an entity in Freebase"""
    query = ("""
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX : <http://rdf.freebase.com/ns/> 
        SELECT (?x0 AS ?label) WHERE {
        SELECT DISTINCT ?x0  WHERE {
        """
             ':' + entity + ' rdfs:label ?x0 . '
                            """
                            FILTER (langMatches( lang(?x0), "EN" ) )
                             }
                             }
                             """)
    # # print(query)
    sparql.setQuery(query)
    try:
        results = sparql.query().convert()
    except urllib.error.URLError:
        print(query)
        exit(0)
    rtn = []
    for result in results['results']['bindings']:
        label = result['label']['value']
        rtn.append(label)
    if len(rtn) != 0:
        return rtn[0]
    else:
        return None


import pyodbc
def pyodbc_test():
    conn = pyodbc.connect(f'DRIVER={path}/lib/virtodbc.so;Host=localhost:{FREEBASE_ODBC_PORT};UID=dba;PWD=dba')
    print(conn)
    conn.setdecoding(pyodbc.SQL_CHAR, encoding='utf8')
    conn.setdecoding(pyodbc.SQL_WCHAR, encoding='utf8')
    conn.setencoding(encoding='utf8')
    
    with conn.cursor() as cursor:
        cursor.execute("SPARQL SELECT ?subject ?object WHERE { ?subject rdfs:subClassOf ?object }")
        # rows = cursor.fetchall()
        rows = cursor.fetchmany(10000)
    
    for row in rows:
        row = str(row)
        print(row)
    conn.commit()
    conn.close()



def get_label_with_odbc(entity: str) -> str:
    """Get the label of an entity in Freebase"""

    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
        
    query = ("""SPARQL
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX ns: <http://rdf.freebase.com/ns/> 
        SELECT (?x0 AS ?label) WHERE {
        SELECT DISTINCT ?x0  WHERE {
        """
             'ns:' + entity + ' rdfs:label ?x0 . '
                            """
                            FILTER (langMatches( lang(?x0), "EN" ) )
                             }
                             }
                             """)

    # query = query.replace("\n"," ")
    # print(query)
    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query}")
        exit(0)
    
    
    rtn = []
    for row in rows:
        # print(type(row))
        rtn.append(row[0])
    
    if len(rtn) != 0:
        return rtn[0]
    else:
        return None


def get_in_relations_with_odbc(entity: str) -> str:
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    in_relations = set()

    query1 = ("""SPARQL
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX : <http://rdf.freebase.com/ns/> 
            SELECT (?x0 AS ?value) WHERE {
            SELECT DISTINCT ?x0  WHERE {
            """
              '?x1 ?x0 ' + ':' + entity + '. '
                                          """
     FILTER regex(?x0, "http://rdf.freebase.com/ns/")
     }
     }
     """)
    # print(query1)


    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query1)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query1}")
        exit(0)
    

    for row in rows:
        in_relations.add(row[0].replace('http://rdf.freebase.com/ns/', ''))

    return in_relations
    

def get_out_relations_with_odbc(entity: str) -> str:
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    out_relations = set()

    query2 = ("""SPARQL
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX : <http://rdf.freebase.com/ns/> 
        SELECT (?x0 AS ?value) WHERE {
        SELECT DISTINCT ?x0  WHERE {
        """
              ':' + entity + ' ?x0 ?x1 . '
                             """
    FILTER regex(?x0, "http://rdf.freebase.com/ns/")
    }
    }
    """)
    # print(query2)
    

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query2)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
            for row in rows:
                out_relations.add(row[0].replace('http://rdf.freebase.com/ns/', ''))
    except Exception:
        # print(f"Query Execution Failed:{query2}")
        exit(0)
    

    

    return out_relations


def get_1hop_relations_with_odbc(entity):
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    relations = set()

    query = ("""SPARQL
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX : <http://rdf.freebase.com/ns/> 
            SELECT (?x0 AS ?value) WHERE {
            SELECT DISTINCT ?x0  WHERE {
            """
              '{ ?x1 ?x0 ' + ':' + entity + ' }'
              + ' UNION '
              + '{ :' + entity + ' ?x0 ?x1 ' + '}'
                                          """
     FILTER regex(?x0, "http://rdf.freebase.com/ns/")
     }
     }
     """)


    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query}")
        exit(0)
    

    for row in rows:
        relations.add(row[0].replace('http://rdf.freebase.com/ns/', ''))

    return relations


def get_freebase_mid_from_wikiID(wikiID: int):
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    mid = set()

    query2 = ("""SPARQL
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX : <http://rdf.freebase.com/ns/> 
        SELECT (?x0 AS ?value) WHERE {
        SELECT DISTINCT ?x0  WHERE {
        """
              '?x0 <http://rdf.freebase.com/key/wikipedia.en_id> ' + f'"{wikiID}"'
                             """
    FILTER regex(?x0, "http://rdf.freebase.com/ns/")
    }
    }
    """)
    # print(query2)
    

    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query2)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query2}")
        exit(0)
    

    for row in rows:
        mid.add(row[0].replace('http://rdf.freebase.com/ns/', ''))
    
    if len(mid)==0:
        return ''
    else:
        return list(mid)[0]


def load_json(fname, mode="r", encoding="utf8"):
    if "b" in mode:
        encoding = None
    with open(fname, mode=mode, encoding=encoding) as f:
        return json.load(f)


def dump_json(obj, fname, indent=4, mode='w' ,encoding="utf8", ensure_ascii=False):
    if "b" in mode:
        encoding = None
    with open(fname, "w", encoding=encoding) as f:
        return json.dump(obj, f, indent=indent, ensure_ascii=ensure_ascii)


def get_entity_labels(src_path, tgt_path):
    entities_list = load_json(src_path)
    res = dict()
    # for entity in entities_list:
    for entity in tqdm(entities_list, total=len(entities_list),desc=f'Querying entity labels'):
        label = get_label_with_odbc(entity)
        res[entity] = label
    dump_json(res, tgt_path)


def query_relation_domain_range_label_odbc(input_path, output_path):
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    relations = load_json(input_path)
    
    res_dict = dict()
    for relation in tqdm(relations):
        query = """
        SPARQL DESCRIBE {}
        """.format('ns:' + relation)
        
        try:
            with odbc_conn.cursor() as cursor:
                cursor.execute(query)
                # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
        except Exception:
            # print(f"Query Execution Failed:{query}")
            exit(0)
        
        res_dict[relation] = dict()
        for row in rows:
            if '#domain' in row[1]:
                res_dict[relation]["domain"] = row[2].replace('http://rdf.freebase.com/ns/', '')
            elif '#range' in row[1]:
                res_dict[relation]["range"] = row[2].replace('http://rdf.freebase.com/ns/', '')
            elif '#label' in row[1]:
                res_dict[relation]["label"] = row[2].replace('http://rdf.freebase.com/ns/', '')
    
    dump_json(res_dict, output_path)

def freebase_query_entity_type_with_odbc(entities_path, output_path):
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    
    res_dict = defaultdict(list)
    entities = load_json(entities_path)
    count = 0
    for entity in entities:
        query = """
        SPARQL DESCRIBE {}
        """.format('ns:' + entity)
        print('count: {}'.format(count))
        count += 1
        
        try:
            with odbc_conn.cursor() as cursor:
                cursor.execute(query)
                # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
            for row in rows:
                if row[1] == 'http://rdf.freebase.com/ns/kg.object_profile.prominent_type':
                    if row[2].startswith('http://rdf.freebase.com/ns/'):
                        # res_dict[entity].append(row[2])
                        res_dict[entity].append(row[2].replace('http://rdf.freebase.com/ns/', ''))
        except Exception:
            # print(f"Query Execution Failed:{query1}")
            rows=[]
            # exit(0)
    
    dump_json(output_path, res_dict)

"""
copied from `relation_retrieval/sparql_executor.py`
"""

def get_freebase_relations_with_odbc(data_path, limit=100):
    """Get all relations of Freebase"""
    # build connection
    global odbc_conn
    if odbc_conn == None:
        initialize_odbc_connection()
    # {{ }}: to escape
    if limit > 0:
        query = """
        SPARQL SELECT DISTINCT ?p (COUNT(?p) as ?freq) WHERE {{
            ?subject ?p ?object
        }}
        LIMIT {}
        """.format(limit)
    else:
        query = """
        SPARQL SELECT DISTINCT ?p (COUNT(?p) as ?freq) WHERE {{
            ?subject ?p ?object
        }}
        """
    print('query: {}'.format(query))
    
    try:
        with odbc_conn.cursor() as cursor:
            cursor.execute(query)
            # rows = cursor.fetchall()
            rows = cursor.fetchmany(10000)
    except Exception:
        # print(f"Query Execution Failed:{query1}")
        rows=[]
        exit(0)
    
    rtn = []
    for row in rows:
        rtn.append([row[0], int(row[1])])
    
    if len(rtn) != 0:
        dump_json(rtn, data_path)

def freebase_relations_post_process(input_path, output_path):
    input_data = load_json(input_path)
    print(f'input length: {len(input_data)}')
    output_data = [item[0] for item in input_data]
    output_data = [item for item in output_data if item.startswith("http://rdf.freebase.com/ns/")]
    output_data = [item.replace('http://rdf.freebase.com/ns/', '') for item in output_data]
    output_data = list(set(output_data))
    print(f'output length: {len(output_data)}')
    dump_json(output_data, output_path)


    
   

In [4]:
from typing import Dict, Tuple, List
from collections import defaultdict
import json
domain_info = defaultdict(lambda: 'base')
with open('/media/disk1/chatgpt/zh/ChatKBQA/rng-kbqa/framework/ontology/domain_info', 'r') as f:
    # domain_info = json.load(f)
    domain_info.update(json.load(f))

with open('/media/disk1/chatgpt/zh/ChatKBQA/rng-kbqa/framework/ontology/fb_roles', 'r') as f:
    contents = f.readlines()

with open('/media/disk1/chatgpt/zh/ChatKBQA/rng-kbqa/framework/ontology/fb_types', 'r') as f:
    type_infos = f.readlines()

subclasses = defaultdict(lambda: set())
for line in type_infos:
    fields = line.split()
    subclasses[fields[2]].add(fields[0])
    subclasses[fields[2]].add(fields[2])

# subclasses = {k: v for k, v in sorted(subclasses.items(), key=lambda x: len(x[1]), reverse=True)}

domain_dict_relations = defaultdict(lambda: set())
domain_dict_types = defaultdict(lambda: set())

relations_info: Dict[str, Tuple] = {}  # stores domain and range information for all relations
date_relations = set()
numerical_relations = set()

for line in contents:
    fields = line.split()
    domain_dict_relations[domain_info[fields[1]]].add(fields[1])
    domain_dict_types[domain_info[fields[0]]].add(fields[0])
    domain_dict_types[domain_info[fields[2]]].add(fields[2])
    relations_info[fields[1]] = (fields[0], fields[2])
    if fields[2] in ['type.int', 'type.float']:
        numerical_relations.add(fields[1])
    elif fields[2] == 'type.datetime':
        date_relations.add(fields[1])

In [2]:
def webqsp_legal_relation(r, num_entity_envolved):
    if r.endswith('#R'):
        r = r[:-2]
    # if dataset == 'webqsp':
    if r not in relations_info or r.startswith('common.') or r.startswith('type.') or r.startswith('kg.') or r.startswith('user.'):
        return False
    if num_entity_envolved == 2 and r.startswith('base.'):
        return False
    return True

In [16]:
in_count = 0
out_count = 0
for r in get_in_relations_with_odbc('m.03_r3'):
    if webqsp_legal_relation(r, 1):
        in_count += 1
for r in get_out_relations_with_odbc('m.03_r3'):
    if webqsp_legal_relation(r, 1):
        out_count += 1
print(len(get_in_relations_with_odbc('m.03_r3')))
print(len(get_out_relations_with_odbc('m.03_r3')))

88
122


In [12]:
entity = 'm.03_r3'
query4 = ("""

                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX : <http://rdf.freebase.com/ns/>
                SELECT distinct ?x1 WHERE {
                """
              ':' + entity + ' ?x0 ?x1 . '
                             """
                ?x1 ?y ?x2 .
                  FILTER regex(?x0, "http://rdf.freebase.com/ns/")
                  FILTER regex(?y, "http://rdf.freebase.com/ns/")
                  }   
                  """)  

sparql.setQuery(query4)

results = sparql.query().convert()
len(results['results']['bindings'])

3049

In [3]:
def get_another_entity(entity: str, relation: str, return_label=True):
        query1 = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX : <http://rdf.freebase.com/ns/> 
                SELECT (?x0 AS ?value) WHERE {
                SELECT DISTINCT ?x0  WHERE {
                """ + '?x0' + ' :' + relation + ' :' + entity + '. ' + """
        }
        }"""
        query2 = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX : <http://rdf.freebase.com/ns/> 
                SELECT (?x0 AS ?value) WHERE {
                SELECT DISTINCT ?x0  WHERE {
                """ + ':' + entity + ' :' + relation +  '?x0' + '. ' + """
        }  }"""
        if return_label:
                return [get_label_with_odbc(e) for e in execute_query(query1) + execute_query(query2)]

        return execute_query(query1) + execute_query(query2)


In [31]:
m_count_set  = set()
for r in results['results']['bindings']:
    if r['x1']['value'].replace('http://rdf.freebase.com/ns/', '').startswith('m.'):
        m_count_set.add(r['x1']['value'].replace('http://rdf.freebase.com/ns/', ''))
#统计entity个数，与上面的代码结果吻合
all_count = 0
m_count_set2 = set()
for r in get_out_relations_with_odbc('m.03_r3'):
    two_hop_re = set()
    for e in get_another_entity('m.03_r3', r, False):
        if e.startswith('m.'):
            m_count_set2.add(e)
            two_hop_re  = two_hop_re | get_out_relations_with_odbc(e)
    all_count += len(two_hop_re)

In [60]:
# get entity from entitty id and relation name
entity = 'm.01_2n'
query2 = """PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX : <http://rdf.freebase.com/ns/> 
        
        SELECT (?x1 AS ?value) WHERE {
        SELECT DISTINCT ?x1  WHERE {
        """ + ':' + entity + ' :tv.tv_program.regular_cast ?x1 . ' + """
    }
    }"""

#filter的作用是筛选包含ns的relation
sparql.setQuery(query2)
results = sparql.query().convert()

In [68]:
from src.sparql_utils import get_friendly_name, get_label
get_friendly_name('m.0w2tqh3')

'null'

In [65]:
get_another_entity('m.01_2n', 'tv.tv_program.regular_cast', False)

['m.0wk9sqk',
 'm.0x12w15',
 'm.0bnjt4f',
 'm.0bvz0wn',
 'm.0mtj7y1',
 'm.0wkph1b',
 'm.0z3vtml',
 'm.0q54wmv',
 'm.0x00v8j',
 'm.0bmmtrm',
 'm.0r5_vr1',
 'm.0nh9b29',
 'm.0tm618z',
 'm.0tm61pg',
 'm.0bvyb5k',
 'm.0x03fqc',
 'm.0x01l7p',
 'm.0w_xg_v',
 'm.0khhm3s',
 'm.0nh9b1y',
 'm.0x0g7b2',
 'm.0krnwbm',
 'm.0x18h4t',
 'm.0wk6_z0',
 'm.0bw1lcr',
 'm.0x0s1qc',
 'm.0wkd82f',
 'm.0bvxqfy',
 'm.0x0x0zk',
 'm.010vy8pb',
 'm.0w_hmd9',
 'm.0w2tqh3',
 'm.0dk_sc3',
 'm.0dk_w2b',
 'm.0wlpq02',
 'm.0mtjdyy',
 'm.0w_hlz4',
 'm.0wl1lm0',
 'm.0wlvlyc',
 'm.010dqk3_',
 'm.0q54wn8',
 'm.0wlt9sz',
 'm.0bngkjc',
 'm.0nh99s3',
 'm.011201wm',
 'm.0whxg6v',
 'm.0z4425r',
 'm.0wysr6t',
 'm.0wj60mn',
 'm.0z43zrc',
 'm.0bmnfyw',
 'm.0bngy_q',
 'm.0dl1d63',
 'm.0y_v4cv',
 'm.0wkdl6m',
 'm.0_hq072',
 'm.0w_l_ym',
 'm.0wlt3vh',
 'm.0bnj82f',
 'm.0x0m4k4',
 'm.05lgy3b',
 'm.0wk93zz',
 'm.0wlxzxz',
 'm.0khhm4t',
 'm.0whz2c8',
 'm.0bvy1td',
 'm.0bvv6xf',
 'm.0bvv2dt',
 'm.0w_gn9p',
 'm.0bvvdsv',
 'm.0q54wmd',
 'm

In [29]:
candidate_relations = set({'fictional_universe.fictional_character.married_to',
 'fictional_universe.marriage_of_fictional_characters.spouses',
 'people.marriage.from',
 'people.marriage.spouse',
 'people.person.spouse_s'})
for can in candidate_relations:
    print(get_another_entity('m.06c97', can))


[None]
[None]
[]
[]
[]


In [17]:
get_another_entity('m.0yq3p6k','http://www.w3.org/2000/01/rdf-schema#label', False)

QueryBadFormed: QueryBadFormed: a bad request has been sent to the endpoint, probably the sparql query is bad formed. 

Response:
b"Virtuoso 37000 Error SP030: SPARQL compiler, line 8: syntax error at '/' before '/'\n\nSPARQL query:\ndefine sql:big-data-const 0 \n#output-format:application/sparql-results+json\nPREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n                PREFIX : <http://rdf.freebase.com/ns/> \n                SELECT (?x0 AS ?value) WHERE {\n                SELECT DISTINCT ?x0  WHERE {\n                ?x0 :http://www.w3.org/2000/01/rdf-schema#label :m.0yq3p6k. \n        }\n        }"

In [4]:
get_label('m.02h98gq')

In [3]:
from src.sparql_utils import get_friendly_name, get_label
get_friendly_name('m.02h98gq')

'null'

In [13]:
names = ['m.0cgngk0', 'm.0cg1k5z', 'm.0cgn2z1', 'm.0cgntj8', 'm.0cs8f63', 'm.0gz6dzv', 'm.0gz7hn1']
for n in names: 
    print(get_label_with_odbc(n))

None
None
None
None
None
None
None


In [4]:
entity = 'm.01ps2h8'
initialize_odbc_connection()

query = ("""SPARQL
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX : <http://rdf.freebase.com/ns/> 
            SELECT (?x0 AS ?value) WHERE {
            SELECT DISTINCT ?x0  WHERE {
            """
              '{ ?x1 ?x0 ' + ':' + entity + ' }'
              + ' UNION '
              + '{ :' + entity + ' ?x0 ?x1 ' + '}'
                                          """
     FILTER regex(?x0, "http://rdf.freebase.com/ns/")
     }
     }
     """)
with odbc_conn.cursor() as cursor:
  cursor.execute(query)
  rows = cursor.fetchmany(10000)
#     # print(query1)

# sparql.setQuery(query)

# sparql.query().convert()

Freebase Virtuoso ODBC connected


In [6]:
relations = set()
for row in rows:
    relations.add(row[0].replace('http://rdf.freebase.com/ns/', ''))

### 存储relation embedding

In [1]:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

In [2]:
from langchain_openai import AzureOpenAIEmbeddings
import os
os.environ["AZURE_OPENAI_API_KEY"] = "2b219db0d2984f9dae28b651ab8ab3d9"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://smsh.openai.azure.com/"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-03-01-preview"
embeddings = AzureOpenAIEmbeddings(
    model="text-embedding-3-small",
)

In [4]:
from langchain.storage import LocalFileStore, RedisStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain_community.vectorstores import FAISS
store = RedisStore(redis_url="redis://localhost:6379")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
embeddings, store, namespace="openai"
)
row_string = []
with open('./data/relations', 'r') as f:
    data = f.readlines()
db = FAISS.from_texts(data, cached_embedder)

In [7]:
import pickle
# 两跳内的所有relation
def pickle_load(file_path: str, default=None):
    if not os.path.exists(file_path):
        return default
    with open(file_path, 'rb') as file:
        obj = pickle.load(file)
    return obj
pick_data = pickle_load('./data/webqsp_2hop_relations.pkl')

In [None]:
count = 0
for i in range(len(webqsp_data)):
    gold_relations = [item for d in webqsp_data[i]['Parses']  if d['InferentialChain'] is not None for item in d['InferentialChain']]
    pred = [re.search(r'<fb:(.*?)>', p.page_content).group(1) for p in db.similarity_search(webqsp_data[i]['RawQuestion'], k=10)]
    intersection_count = len(set(gold_relations) & set(pred))
    if intersection_count != 0:
        count += 1
    else:
        print(i)


### 测试entity检索的回归率和召回率


In [94]:
def compute_recall(answer: List, candidate: List):
    golden_set = set(answer)
    pred_set = set(candidate)
    return len(golden_set.intersection(pred_set)) / len(golden_set)

def compute_precision(answer: List, candidate: List):
    golden_set = set(answer)
    pred_set = set(candidate)
    return len(golden_set.intersection(pred_set)) / len(pred_set)

In [84]:
import json
gnnrag_data = []
with open('./data/test.info', 'r') as f:
    info_data = f.readlines()
    for l in info_data:
        gnnrag_data.append(json.loads(l))


In [12]:
with open('/media/disk1/chatgpt/zh/ChatKBQA/data/dataset/WebQSP/entity_linking_results/webqsp_test_oracle_mid.json', 'r') as f:
    tiara_data = json.load(f)

In [6]:
import json
rog_data = []
with open('/media/disk1/chatgpt/zh/reasoning-on-graphs/results/gen_rule_path/RoG-webqsp/RoG/test/predictions_3_False.jsonl', 'r') as f:
    for line in f.readlines():
        rog_data.append(json.loads(line))
print(len(rog_data))

1628


In [7]:
with  open('./data/WebQSP.test.expr.json', 'r') as f:
    webqsp_data = json.load(f)

In [119]:
tiara_recall = []
for i in range(len(webqsp_data)):
    gold_entities = [d['TopicEntityMid'] for d in webqsp_data[i]['Parses']]
    cand_entities = elq_data[i]['freebase_ids']
    tiara_recall.append(compute_recall(gold_entities, cand_entities))


In [96]:
for i in range(len(tiara_data)):
    tiara_recall = compute_recall(tiara_data[i]['freebase_ids'], [c[0] for c in gnnrag_data[i]['cand']])
    gnn_recall = gnnrag_data[i]['recall']
    print(tiara_recall)
    print(gnn_recall)
    break

0.0
1.0


### Few-shot


In [8]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.chains import LLMChain
from src.sparql_utils import *
examples_prompt = PromptTemplate(input_variables=["query", "new_query"], template=
"""Question: {query}
Relation Paths: 
{relations}""")
examples_dict = [{"query": "what is the name of justin bieber brother?",
                    "relations": "[people.person.gender], [people.sibling_relationship.sibling]"},
                {"query": "what character did natalie portman play in star wars?",
                    "relations": "[film.actor.film, film.film_character.portrayed_in_films], [film.actor.film, film.performance.character]"},
                {"query": "what country is the grand bahama island in?",
                    "relations": "[location.location.contains]"},]
hypo_prompt = FewShotPromptTemplate(
    examples=examples_dict,
    example_prompt=examples_prompt,
    prefix="""Please generate a variety of possible valid relation paths to address the related questions. Each path contained in a list. Below are some examples.""",
    suffix=
    """Question: {query}
Relation Paths: 
    """,
    input_variables=["query"],
)

In [9]:
from langchain_openai import AzureChatOpenAI
import os


os.environ["AZURE_OPENAI_API_KEY"] = "2b219db0d2984f9dae28b651ab8ab3d9"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://smsh.openai.azure.com/"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"
os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "gpt-35-turbo"

# os.environ["AZURE_OPENAI_API_KEY"] = "0c75de50975e4f278b882fe90da47f2f"
# os.environ["AZURE_OPENAI_ENDPOINT"] = "https://ces.openai.azure.com"
# os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"
# os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "gpt-35-turbo"

# os.environ["AZURE_OPENAI_API_KEY"] = "796da5bbe9994bd2b3eff9cbd91e35e5"
# os.environ["AZURE_OPENAI_ENDPOINT"] = "https://bxcl-prod.openai.azure.com/"
# os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"
# os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "gpt-4o"
model = AzureChatOpenAI(
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
    temperature=0.3,
    n = 3,
    max_retries=5, request_timeout=600
)


In [10]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.chains import LLMChain
reasoning_prompt = PromptTemplate(input_variables=["path", "query"], template=
"""Based on the reasoning paths and answer the question. Directly return all the possible answers. Only return the string instead of other format information.
Reasoning Paths:
{path}
Question:
{query}""")


In [11]:
import os
import json

def incrementally_add_to_json(file_path, input_dict):
    # 确保input_dict包含'id'键
    if 'id' not in input_dict:
        raise ValueError("Input dictionary must contain an 'id' key.")
    
    # 检查文件是否存在
    if os.path.exists(file_path):
        # 文件存在，读取现有内容
        with open(file_path, 'r', encoding='utf-8') as file:
            existing_data = json.load(file)
            # 确保existing_data是一个列表
            if not isinstance(existing_data, list):
                raise ValueError("JSON file must contain a list of dictionaries.")
            # 检查'id'是否已存在
            if any(item['id'] == input_dict['id'] for item in existing_data):
                # 如果'id'已存在，则不添加input_dict
                print('Exists already')
                return
            # 将输入字典添加到列表中
            existing_data.append(input_dict)
    else:
        # 文件不存在，创建新列表并添加输入字典
        existing_data = [input_dict]
    
    # 将更新后的列表写回文件
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(existing_data, file, ensure_ascii=False, indent=4)

# 示例用法
# incrementally_add_to_json_list_if_unique('path/to/your/file.json', {'id': 'unique_id', 'key': 'value'})


In [13]:
import re


with open('/media/disk1/chatgpt/zh/ChatKBQA/rng-kbqa/WebQSP/misc/webqsp_test_elq-5_mid.json', 'r') as f:
    linking_data = json.load(f)
verbose = False
llm_chain = LLMChain(llm=model, prompt=hypo_prompt, verbose=verbose)
reasoning_chain = LLMChain(llm=model, prompt=reasoning_prompt, verbose=verbose)

for i in range(1000, len(linking_data)):
    try:
        res = llm_chain.batch([{"query": linking_data[i]['text']}], return_only_outputs=True)
    except ValueError as e:
            print(f'******************Value Error {i}****************************')
            continue
    possible_start = set()
    possible_answer = set()
    useful_paths = []
    pattern = r'\[(.*?)\]'
    for r in re.findall(pattern, res[0]['text']):
        possible_start.update(set(r.split(',')))
    for pos in possible_start:
        possible_relations = set([re.search(r'<fb:(.*?)>', p.page_content).group(1) for p in db.similarity_search(pos, k=5)])
        # print(possible_relations)
        for can in possible_relations:
            for en in get_another_entity(linking_data[i]['freebase_ids'][0], can):
                if en:
                    possible_answer.add(en)
                    useful_paths.append(f"Relation: {can}, Entity: {en}")
    # Add token allocation
    try: 

        final_answer = reasoning_chain.batch([{"query": linking_data[i]['text'], "path": '\n'.join(useful_paths)[:16000]}], return_only_outputs=True)[0]['text']
    except ValueError as e:
            print(f'******************Value Error {i}****************************')
            continue
    incrementally_add_to_json('./output/test.json', dict({"id":linking_data[i]['id'], "path": useful_paths, "predict": final_answer, "label": webqsp_data[i]['Parses'][0]['Answers']}))
    # print('---------Answers--------------')
    # print('Predict: {}, Label: {}'.format(final_answer, webqsp_data[i]['Parses'][0]['Answers']))
    print(f"Process {linking_data[i]['id']}")



Process WebQTest-1254
Process WebQTest-1255
Process WebQTest-1256
Process WebQTest-1257
Process WebQTest-1259
Process WebQTest-1260
Process WebQTest-1261
Process WebQTest-1262
Process WebQTest-1263
Process WebQTest-1264
Process WebQTest-1265
Process WebQTest-1266
Process WebQTest-1267
Process WebQTest-1268
Process WebQTest-1269
Process WebQTest-1270
Process WebQTest-1271
Process WebQTest-1272
Process WebQTest-1274
Process WebQTest-1275
Process WebQTest-1276
Process WebQTest-1277
******************Value Error 1022****************************
Process WebQTest-1279
Process WebQTest-1280
Process WebQTest-1281
Process WebQTest-1282
Process WebQTest-1284
Process WebQTest-1285
Process WebQTest-1287
Process WebQTest-1289
Process WebQTest-1291
Process WebQTest-1292
Process WebQTest-1293
Process WebQTest-1294
Process WebQTest-1295
Process WebQTest-1296
Process WebQTest-1297
Process WebQTest-1298
Process WebQTest-1300
Process WebQTest-1301
Process WebQTest-1303
Process WebQTest-1304
Process WebQT

In [59]:
get_another_entity('m.01_2n', ' tv.tv_program.regular_cast', False)

QueryBadFormed: QueryBadFormed: a bad request has been sent to the endpoint, probably the sparql query is bad formed. 

Response:
b"Virtuoso 37000 Error SP030: SPARQL compiler, line 8: syntax error at 'tv.tv_program.regular_cast' before ':m.01_2n'\n\nSPARQL query:\ndefine sql:big-data-const 0 \n#output-format:application/sparql-results+json\nPREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n                PREFIX : <http://rdf.freebase.com/ns/> \n                SELECT (?x0 AS ?value) WHERE {\n                SELECT DISTINCT ?x0  WHERE {\n                ?x0 : tv.tv_program.regular_cast :m.01_2n. \n        }\n        }"

In [17]:
print(webqsp_data[0]['Parses'][0]['Answers'])

[{'AnswerType': 'Entity', 'AnswerArgument': 'm.01428y', 'EntityName': 'Jamaican English'}, {'AnswerType': 'Entity', 'AnswerArgument': 'm.04ygk0', 'EntityName': 'Jamaican Creole English Language'}]


In [50]:

test_query = "PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT DISTINCT ?x\nWHERE {\n  {\n    SELECT ?pFrom \n    WHERE {\n      ns:m.042f1 ns:government.politician.government_positions_held ?y . # James K. Polk\n      ?y ns:government.government_position_held.office_position_or_title ?x ; \n         ns:government.government_position_held.basic_title ns:m.060c4 ; # President\n         ns:government.government_position_held.from ?pFrom .\n    }\n  }\n  ns:m.042f1 ns:government.politician.government_positions_held ?y . # James K. Polk\n  ?y ns:government.government_position_held.office_position_or_title ?x ; \n     ns:government.government_position_held.from ?from .\n  \n  FILTER(xsd:dateTime(?pFrom) - xsd:dateTime(?from) > 0)\n}"
execute_query_with_odbc(test_query)

Freebase Virtuoso ODBC connected


{'http://rdf.freebase.com/ns/m.02_bcst',
 'http://rdf.freebase.com/ns/m.04x_n9q',
 'http://rdf.freebase.com/ns/m.0cgqx'}

In [70]:
rog_data[2]

{'id': 'WebQTest-3',
 'question': 'who plays ken barlow in coronation street',
 'prediction': [['tv.tv_program.country_of_origin',
   'people.person.nationality'],
  ['tv.regular_tv_appearance.series', 'tv.regular_tv_appearance.actor'],
  ['tv.regular_tv_appearance.series', 'tv.tv_actor.starring_roles']],
 'ground_paths': [],
 'input': '[INST] <<SYS>>\n<</SYS>>\nPlease generate a valid relation path that can be helpful for answering the following question: who plays ken barlow in coronation street [/INST]',
 'raw_output': {'paths': ['<PATH> tv.tv_program.country_of_origin <SEP> people.person.nationality </PATH>',
   '<PATH> tv.regular_tv_appearance.series <SEP> tv.regular_tv_appearance.actor </PATH>',
   '<PATH> tv.regular_tv_appearance.series <SEP> tv.tv_actor.starring_roles </PATH>'],
  'scores': [-0.017409520223736763,
   -0.026553694158792496,
   -0.03162452578544617],
  'norm_scores': [0.335933119058609, 0.3328752815723419, 0.3311915993690491]}}

### Evaluate

In [94]:
with open('./output/test.json', 'r') as f:
    predict_data = json.load(f)

In [97]:
predict_data[0]['id']

'WebQTest-0'

In [96]:
from src.utils import extract_topk_prediction, eval_acc, eval_f1, eval_hit

In [102]:
acc_list = []
hit_list = []
f1_list = []
precission_list = []
recall_list = []
cal_f1 =True
for data in predict_data:
    id = data['id']
    prediction = data['predict']
    answer = [l['EntityName'] if l['AnswerType'] == "Entity" else l['AnswerArgument'] for l in data['label'] ]
    if cal_f1:
        if not isinstance(prediction, list):
            prediction = prediction.split(",")
        else:
            prediction = extract_topk_prediction(prediction, 2)
        f1_score, precision_score, recall_score = eval_f1(prediction, answer)
        f1_list.append(f1_score)
        precission_list.append(precision_score)
        recall_list.append(recall_score)
        prediction_str = ' '.join(prediction)
        acc = eval_acc(prediction_str, answer)
        hit = eval_hit(prediction_str, answer)
        acc_list.append(acc)
        hit_list.append(hit)
        incrementally_add_to_json('./output/eval.json', 
                                  dict({'id': id, 'prediction': prediction, 'ground_truth': answer, 'acc': acc, 'hit': hit, 'f1': f1_score, 'precission': precision_score, 'recall': recall_score}))
    else:
        acc = eval_acc(prediction, answer)
        hit = eval_hit(prediction, answer)
        acc_list.append(acc)
        hit_list.append(hit)
        incrementally_add_to_json('./output/eval.json', dict({'id': id, 'prediction': prediction, 'ground_truth': answer, 'acc': acc, 'hit': hit}))

if len(f1_list) > 0:
    result_str = "Accuracy: " + str(sum(acc_list) * 100 / len(acc_list)) + " Hit: " + str(sum(hit_list) * 100 / len(hit_list)) + " F1: " + str(sum(f1_list) * 100 / len(f1_list)) + " Precision: " + str(sum(precission_list) * 100 / len(precission_list)) + " Recall: " + str(sum(recall_list) * 100 / len(recall_list))
else:
    result_str = "Accuracy: " + str(sum(acc_list) * 100 / len(acc_list)) + " Hit: " + str(sum(hit_list) * 100 / len(hit_list))
print(result_str)


Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Exists already
Accuracy: 64.4120707596254 Hit: 77.41935483870968 F1: 67.29646697388633 Precision: 75.60483870967742 Recall: 64.4120707596254


In [None]:
##### grounding by graph
#grounding by relation path

def process_data(data, remove_duplicate=False):
    question = data['question']
    graph  =  utils.build_graph(data['graph'])
    paths = utils.get_truth_paths(data['q_entity'], data['a_entity'], graph)
    result = []
    # Split each Q-P pair into a single data
    rel_paths = []
    for path in paths:
        rel_path = [p[1] for p in path] # extract relation path
        if remove_duplicate:
            if tuple(rel_path) in rel_paths:
                continue
        rel_paths.append(tuple(rel_path))
    for rel_path in rel_paths:
        result.append({"question": question, "path": rel_path})
    return result