In [None]:
from tqdm import tqdm
import random
from random import sample
import networkx as nx
import re, logging
import openai, datetime, os


def load_openai_keys(path='../openai_keys_filter.txt'):
    """Read OpenAI API keys from a text file, one key per line.

    Each line is split on whitespace and its last field is taken as the key,
    so any leading fields on the line are ignored. Blank or whitespace-only
    lines are skipped instead of raising ``IndexError``.

    Args:
        path: Location of the key file. Defaults to the path the script has
            always used.

    Returns:
        The keys as a list of strings, in file order.
    """
    keys = []
    with open(path, "r") as f:
        for line in f:
            fields = line.split()
            if fields:
                keys.append(fields[-1])
    return keys
# Module-level key pool: loaded once at import time and shuffled in place so
# the starting key differs between runs. update_key() rotates through it.
openai_api_keys = load_openai_keys()
random.shuffle(openai_api_keys)
def update_key():
    """Activate the key at the front of the pool and move it to the back.

    Sets ``openai.api_key`` to the first entry of ``openai_api_keys`` and
    rotates that entry to the end of the list, so successive calls cycle
    through every key. Raises ``IndexError`` when the pool is empty.
    """
    active_key = openai_api_keys.pop(0)
    openai.api_key = active_key
    openai_api_keys.append(active_key)

def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes formatted records to ``filename`` only.

    Every handler already attached to the logger is detached and closed
    before the new file handler is added, so calling this repeatedly (for
    example when re-running a notebook cell) neither duplicates output nor
    leaks open log files.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated.
        verbosity: 0 selects DEBUG, 1 selects INFO, 2 selects WARNING.
        name: Name passed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: If ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Iterate over a snapshot: removing from logger.handlers while looping
    # over it directly skips every other handler.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Output to file
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # To mirror records on the terminal as well, attach a
    # logging.StreamHandler with the same formatter here.

    return logger

# multiple-list save to logging
def list2str(l):
    """Render each item of ``l`` with ``str`` and terminate it with '\\r'.

    Returns the concatenation of all rendered items; an empty iterable gives
    an empty string.
    """
    return ''.join(str(item) + '\r' for item in l)
def list_equal(a, answers):
    """Tell whether ``a`` matches any candidate in ``answers`` as a set.

    Order and duplicates are ignored: ``a`` matches a candidate when both
    contain exactly the same distinct elements. Returns False when
    ``answers`` is empty.
    """
    return any(set(a) == set(candidate) for candidate in answers)
def dict2str(d):
    """Render a mapping as 'key<TAB>value' entries, each terminated by '\\r'.

    Keys and values are converted with ``str``; entries follow the mapping's
    iteration order.
    """
    return ''.join(str(key) + '\t' + str(d[key]) + '\r' for key in d)

## read data

In [None]:
# Relation vocabulary shared by the rest of the notebook.
#   id2rel / rel2id : integer id <-> relation name
#   rel2sym         : relation name -> '$symbol$'
#   rel2sym_2       : relation name -> bare symbol
#   infer_rel       : relation names in the order they appear in 1.relations
rid = 0
id2rel = {}
rel2id = {}
rel2sym = {}
rel2sym_2 = {}

infer_rel = []
# Every line holds two whitespace-separated fields; only the second (the
# relation name) is used. Ids are assigned in file order starting from 0.
with open("../symbolic_tree/1.relations", 'r') as f:
    for line in f:
        _, rel = line.strip().split()
        infer_rel.append(rel)
        id2rel[rid] = rel
        rel2id[rel] = rid
        rid += 1
relation_txt = ''.join(name + ', ' for name in infer_rel)

extra_relations = ["greatAuntUncleOf","grandparentOf","greatGrandparentOf","auntUncleOf","siblingOf","secondAuntUncleOf","childOf","grandchildOf","greatGrandchildOf","nieceNephewOf","cousinOf","secondCousinOf","firstCousinOnceRemovedOf", "male", "female"]

# The extra relations continue the id numbering right after the file's ones.
for offset, rel in enumerate(extra_relations):
    id2rel[rid + offset] = rel
    rel2id[rel] = rid + offset
rid += len(extra_relations)

# Each line of rel2sym_ri.txt is "<relation> <symbol>".
with open("rel2sym_ri.txt", "r") as fr:
    for line in fr:
        rel, sym = line.strip().split()
        rel2sym_2[rel] = sym
        rel2sym[rel] = '$' + sym + '$'


In [None]:

# Natural-language rules: one tab-separated rule per line of natural_rules.txt.
#   rh2rules  : first field of the line -> last field of the line
#   rule_text : numbered listing "L1: ...", "L2: ..." built from the last field
rh2rules = {}
rule_lines = []
# NOTE: `id` shadows the builtin of the same name at module level; the name
# is kept because it is part of this cell's existing globals.
id = 1
with open('natural_rules.txt', 'r') as f:
    for line in f:
        fields = line.strip().split('\t')
        rh2rules[fields[0]] = fields[-1]
        rule_lines.append('L' + str(id) + ": " + fields[-1] + '\n')
        id += 1
rule_text = ''.join(rule_lines)
# print(rule_text)

In [None]:
def read_entity(path, eid, id2ent, ent2id):
    """Register the entities listed in ``path`` and return the next free id.

    Each line holds two whitespace-separated fields; the second one is the
    entity name. Names not yet present in ``ent2id`` receive consecutive ids
    starting at ``eid``; names already present are left untouched.

    Args:
        path: Entity file to read.
        eid: First id to hand out.
        id2ent: Mapping id -> name, updated in place.
        ent2id: Mapping name -> id, updated in place.

    Returns:
        The id that would be assigned to the next new entity.
    """
    with open(path, 'r') as handle:
        for record in handle:
            _, name = record.strip().split()
            if name in ent2id:
                continue
            id2ent[eid] = name
            ent2id[name] = eid
            eid += 1
    return eid
    
def get_related_triplets(h, t, G, entpair2rel):
    """Describe every edge on the simple paths between ``h`` and ``t``.

    All simple edge paths of at most five edges from ``h`` to ``t`` in ``G``
    are visited in sorted order. Each edge is written as
    ``relation(head, tail)`` on its own line. An edge that is not a key of
    ``entpair2rel`` is looked up, and written, with its endpoints swapped.

    Args:
        h: Source node.
        t: Target node.
        G: networkx graph to search.
        entpair2rel: Mapping (head, tail) -> relation label.

    Returns:
        The concatenated lines, or an empty string when no path exists.
    """
    lines = []
    for edge_path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):
        for edge in edge_path:
            # Use the stored orientation of the edge, whichever way the
            # path happens to traverse it.
            key = edge if edge in entpair2rel else (edge[1], edge[0])
            lines.append(entpair2rel[key] + '(' + key[0] + ', ' + key[1] + ')\n')
    return ''.join(lines)

def read_all_triplets(path1, path2, id2ent, text):
    """Load base and inferred triplets and append their symbolic form to ``text``.

    Both files hold lines of four whitespace-separated fields
    ``flag h r t`` where ``h``, ``r`` and ``t`` are integer ids resolved
    through ``id2ent`` and the module-level ``id2rel``.

    * Every line of ``path1`` is added to the triplet list, to
      ``entpair2rel`` and to ``text`` as ``symbol(head, tail)``.
    * Lines of ``path2`` are used only when their flag is ``'+'``; they are
      added to ``entpair2rel`` and to ``text`` as ``symbol(head,tail)``
      (no space after the comma), but not to the triplet list.

    Args:
        path1: File with the base triplets.
        path2: File with the flagged triplets.
        id2ent: Mapping entity id -> entity name.
        text: String the rendered facts are appended to.

    Returns:
        ``(triplets, entpair2rel, text)`` where ``triplets`` is a list of
        ``(head, relation, tail)`` names and ``entpair2rel`` maps
        ``(head, tail)`` to the relation name.
    """
    triplets = []
    entpair2rel = {}

    with open(path1, 'r') as f:
        for line in f:
            _flag, h, r, t = line.strip().split()
            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]
            triplets.append((head, rel, tail))
            entpair2rel[(head, tail)] = rel
            text += rel2sym[rel] + '(' + head + ', ' + tail + ')\n'

    with open(path2, 'r') as f:
        for line in f:
            flag, h, r, t = line.strip().split()
            if flag != '+':
                continue
            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]
            entpair2rel[(head, tail)] = rel
            text += rel2sym[rel] + '(' + head + ',' + tail + ')\n'

    return triplets, entpair2rel, text


def read_class(path, cid, ent2class, id2ent, class_text):
    """Read the gender of consecutive entities and render it as facts.

    Each line of ``path`` holds two whitespace-separated flags
    ``female male``. A first flag of ``'1'`` marks the entity as female;
    any other value marks it as male. Lines are matched to entities by
    position: the first line describes ``id2ent[cid]``, the next one
    ``id2ent[cid + 1]`` and so on.

    The gender is recorded in ``ent2class`` for both genders. Previously
    only male entities were stored, which left female entities missing from
    the mapping.

    Args:
        path: Class file to read.
        cid: Entity id that the first line refers to.
        ent2class: Mapping entity name -> ``'female'`` / ``'male'``,
            updated in place.
        id2ent: Mapping entity id -> entity name.
        class_text: String the facts ``symbol(entity)`` are appended to,
            one per line, using the module-level ``rel2sym`` symbols.

    Returns:
        ``(cid, class_text)``: the id following the last entity read and
        the extended text.
    """
    with open(path, 'r') as f:
        for line in f:
            # The second column is unpacked only to enforce the two-column
            # layout; the first column alone decides the gender.
            female, _male = line.strip().split()
            gender = 'female' if female == '1' else 'male'
            entity = id2ent[cid]
            ent2class[entity] = gender
            class_text += rel2sym[gender] + '(' + entity + ')\n'
            cid += 1
    return cid, class_text


def read_all_facts(path1, path2, path_class, id2ent, ent2class, cid, text):
    """Load relation and gender facts and number them F1, F2, ...

    Relation files hold lines of four whitespace-separated fields
    ``flag h r t`` with integer ids that are resolved through ``id2ent``
    and the module-level ``id2rel``; relation names are rendered with the
    module-level ``rel2sym`` symbols.

    * ``path1``: every line becomes a numbered fact
      ``"F<n>: <head> is <symbol> of <tail>."``, an edge and an
      ``entpair2rel`` entry.
    * ``path2``: lines flagged ``'+'`` become test triplets (with the
      relation symbol) and are also appended to ``triplets`` (with the
      relation name). They are not numbered and add no text.
    * ``path_class``: each line ``female male`` describes the entity
      ``id2ent[cid]`` for consecutive ``cid`` values and becomes a numbered
      fact ``"F<n>: <entity> is <symbol>."``. A first field of ``'1'``
      means female, anything else male.

    Args:
        path1: File with the base relation facts.
        path2: File with the flagged (inferred) relation facts.
        path_class: File with the gender flags.
        id2ent: Mapping entity id -> entity name.
        ent2class: Mapping entity name -> gender symbol, updated in place.
        cid: Entity id that the first line of ``path_class`` refers to.
        text: String the numbered facts are appended to.

    Returns:
        ``(triplets, test_triplets, entpair2rel, cid, text, edges,
        tri2number, f_id)`` where ``tri2number`` maps
        ``(head, symbol, tail)`` and ``(entity, 'gender', symbol)`` to the
        fact label and ``f_id`` is the number the next fact would receive.
    """
    f_id = 1
    triplets = []
    test_triplets = []
    entpair2rel = {}
    edges = []
    tri2number = {}

    # Base relation facts: each one is numbered and becomes a graph edge.
    with open(path1, 'r') as f:
        for line in f:
            _flag, h, r, t = line.strip().split()
            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]
            triplets.append((head, rel, tail))
            edges.append((head, tail))

            sym = rel2sym[rel]
            label = 'F' + str(f_id)
            entpair2rel[(head, tail)] = sym
            text += label + ": " + head + ' is ' + sym + ' of ' + tail + '.\n'
            tri2number[(head, sym, tail)] = label
            f_id += 1

    # Flagged facts: positive ones are the statements to be explained.
    with open(path2, 'r') as f:
        for line in f:
            flag, h, r, t = line.strip().split()
            if flag != '+':
                continue
            head, rel, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]
            test_triplets.append((head, rel2sym[rel], tail))
            triplets.append((head, rel, tail))

    # Gender facts: one line per entity, in entity-id order.
    with open(path_class, 'r') as f:
        for line in f:
            female, _male = line.strip().split()
            gender_sym = rel2sym['female'] if female == '1' else rel2sym['male']
            entity = id2ent[cid]
            label = 'F' + str(f_id)
            ent2class[entity] = gender_sym
            text += label + ": " + entity + ' is ' + gender_sym + '.\n'
            tri2number[(entity, 'gender', gender_sym)] = label
            f_id += 1
            cid += 1

    return triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, f_id



    
def get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules,  tri2number, rule2number):
    """Compute the reference explanations for every test triplet.

    For a triplet ``(h, r, t)`` the rule ``rel2rules[r]`` is matched against
    the simple paths from ``h`` to ``t`` in the undirected graph built from
    ``edges``. An accepted path yields one explanation: the labels of the
    facts along the path, followed by the label of ``h``'s gender fact and
    the label of the rule.

    A path is accepted only when every edge agrees with the rule element at
    the same position and ``ent2class[h]`` equals the last rule element.
    Previously an edge stored in the traversal direction whose relation
    differed from the rule was silently skipped, so a path contradicting
    the rule still produced a (truncated) explanation; such paths are now
    rejected.

    Args:
        test_triplets: Iterable of ``(head, relation_symbol, tail)``.
        edges: Iterable of ``(head, tail)`` pairs forming the graph.
        entpair2rel: Mapping ``(head, tail)`` -> relation symbol.
        ent2class: Mapping entity -> gender symbol.
        rel2rules: Mapping relation symbol -> list of rule elements.
        tri2number: Mapping fact triple -> fact label (e.g. ``'F3'``).
        rule2number: Mapping relation symbol -> rule label (e.g. ``'L2'``).

    Returns:
        Dict mapping each ``(h, r, t)`` to a list of explanations; the list
        is empty when no path matches.
    """

    def logical_rules(entpair2rel, path, rule):
        """Return the fact labels along ``path`` if it matches ``rule``, else None."""
        path_number = []
        for i in range(len(path) - 1):
            src, dst = path[i], path[i + 1]
            if (src, dst) in entpair2rel:
                rel = entpair2rel[(src, dst)]
                if rel != rule[i]:
                    # The stored relation contradicts the rule at this
                    # position, so the whole path is not an instance of it.
                    return None
                path_number.append(tri2number[(src, rel, dst)])
            elif '$r45$' == rule[i]:
                # The edge is stored in the opposite direction. It is
                # accepted only where the rule asks for '$r45$'.
                # NOTE(review): presumably '$r45$' is the symbol of a
                # relation read against the edge direction; confirm
                # against rel2sym_ri.txt.
                path_number.append(tri2number[(dst, entpair2rel[(dst, src)], src)])
            else:
                return None
        return path_number

    # Undirected graph so that paths may traverse facts in either direction.
    G = nx.Graph()
    G.add_edges_from(edges)

    fact2explain = {}
    for h, r, t in test_triplets:
        rule = rel2rules[r]

        all_paths = []
        # Paths longer than the rule can never match, hence the cutoff.
        for path in nx.all_simple_paths(G, source=h, target=t, cutoff=len(rule)):
            path_number = logical_rules(entpair2rel, path, rule)
            if path_number and ent2class[h] == rule[-1]:
                path_number.append(tri2number[(h, 'gender', ent2class[h])])
                path_number.append(rule2number[r])
                all_paths.append(path_number)
        fact2explain[(h, r, t)] = all_paths

    return fact2explain


def read_rules(path, rel2sym):
    """Load tab-separated rules and translate every name to its symbol.

    Each line lists relation names separated by tabs. After translation
    through ``rel2sym`` the first symbol identifies the rule and the
    remaining symbols form its body. Rules are labelled ``L1``, ``L2``, ...
    in file order.

    Args:
        path: Rule file to read.
        rel2sym: Mapping relation name -> symbol.

    Returns:
        ``(rel2rules, rule2number, l_id)``: first symbol -> list of body
        symbols, first symbol -> rule label, and the number the next rule
        would receive.

    Raises:
        KeyError: If a name in the file is missing from ``rel2sym``.
    """
    rel2rules = {}
    rule2number = {}
    l_id = 1
    with open(path, 'r') as f:
        for line in f:
            symbols = [rel2sym[name] for name in line.strip().split('\t')]
            head, *body = symbols
            rel2rules[head] = body
            rule2number[head] = 'L' + str(l_id)
            l_id += 1
    return rel2rules, rule2number, l_id




In [None]:
# Experiment driver: for every inferred statement, ask the model to pick one
# rule and the supporting facts, then compare its pick with the reference
# explanations and log the outcome.

# How many times the two-step request is attempted per statement.
MAX_REQUEST_ATTEMPTS = 10


def _chat(model_name, system_prompt, user_prompt):
    """Send one system + user exchange and return the assistant's reply text.

    The API key is rotated with update_key() before the request.
    """
    update_key()
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
    )
    return response['choices'][0]['message']['content']


# read data
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
log_dir = 'logs/natural_cot_f+r'
os.makedirs(log_dir, exist_ok=True)
logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)
# model = 'gpt-4'
model = "gpt-3.5-turbo"

logger.info('model: ' + model)

for i in range(0, 1):

    eid = 0
    cid = 0
    id2ent = dict()
    ent2id = dict()
    ent2class = dict()
    ent2triplets = dict()
    class_text = ''
    text = ''

    path_ent = "../symbolic_tree/" + str(i) + ".individuals"
    path_rel1 = "../symbolic_tree/" + str(i) + ".relations.data"
    path_rel2 = "../symbolic_tree/" + str(i) + ".relations.data.inf"
    path_class = "../symbolic_tree/" + str(i) + ".classes.data"

    path_rule = 'rule_tab.txt'

    eid = read_entity(path_ent, eid, id2ent, ent2id)

    triplets, test_triplets, entpair2rel, cid, text, edges, tri2number, fid = read_all_facts(path_rel1, path_rel2, path_class, id2ent, ent2class, cid, text)

    rel2rules, rule2number, lid = read_rules(path_rule, rel2sym)
    fact2explain = get_explain_grounding_truth(test_triplets, edges, entpair2rel, ent2class, rel2rules, tri2number, rule2number)

    true_num = 0
    false_num = 0
    num = 0

    # The prompts are logged once, for the first statement that gets through.
    record_flag = False

    # Everything in the first prompt up to the statement is the same for all
    # statements of this dataset, so it is assembled once.
    prompt_prefix = (
        "I will provide a set of logical rules L1 to L" + str(lid - 1) + " and facts F1 to F" + str(fid - 1) + ". Please select one single logical rule from L1 to L" + str(lid - 1) + " and a few facts from F1 to F" + str(fid - 1) + " to explain the following statement. \nRules:\n" + rule_text
        + "\nFacts:\n" + text + "\nStatement: "
    )

    for triple in tqdm(test_triplets):
        h, r, t = triple
        text_pred = h + ' is ' + r + ' of ' + t + '.'

        # Step 1: let the model reason about the statement step by step.
        message_1 = {
            'system': "You are a helpful assistant with abductive reasoning abilities. Please select one single logical rule and a few facts to explain the following statement. ",
            'user': prompt_prefix + text_pred + "\nAnswer with the numbers of the selected rule and facts. The selected rule and facts are: Let's think step by step.",
        }

        # Reset per statement: if every attempt fails, the statement must be
        # scored on an empty answer, never on the previous statement's reply.
        results_1 = ''
        results_2 = ''
        message_2 = None
        answered = False

        server_error_cnt = 0
        while server_error_cnt < MAX_REQUEST_ATTEMPTS:
            try:
                results_1 = _chat(model, message_1['system'], message_1['user'])

                # Step 2: ask for the final labels given the reasoning above.
                message_2 = {
                    'system': "You are a helpful assistant. Please output the numbers (e.g. L1, F1, F2) of the selected logical rule and facts. ",
                    'user': message_1['user'] + '\n' + results_1 + "\nTherefore, the selected logical rule and facts are: ",
                }
                results_2 = _chat(model, message_2['system'], message_2['user'])
                answered = True
                break
            except Exception as e:
                # Any API or network failure triggers a retry of both steps.
                server_error_cnt += 1
                print(e)
                logger.warning("request failed (attempt %d of %d): %s", server_error_cnt, MAX_REQUEST_ATTEMPTS, e)

        if not answered:
            results_2 = ''
            logger.error("no answer after %d attempts; the statement is scored as incorrect", MAX_REQUEST_ATTEMPTS)

        if not record_flag and message_2 is not None:
            logger.info('message_1: \n' + dict2str(message_1))
            logger.info('message_2: \n' + dict2str(message_2))
            record_flag = True

        # Labels such as 'L3' or 'F12' mentioned in the final answer.
        number_list = re.findall(r'[FL]\d+', results_2)
        num += 1

        logger.info("statement: " + text_pred)
        logger.info("results_1: %s", results_1)
        logger.info("results_2: %s", results_2)
        logger.info('LLM: %s', number_list)
        logger.info("grounding_truth: %s", list2str(fact2explain[(h, r, t)]))

        if list_equal(number_list, fact2explain[(h, r, t)]):
            logger.info("correct")
            true_num += 1
        else:
            false_num += 1

    if num:
        logger.info("accuracy: " + str(true_num / num))
    else:
        # No positive statements in the dataset: avoid dividing by zero.
        logger.info("accuracy: n/a (no test statements)")