In [None]:
from tqdm import tqdm
import random
from random import sample
import networkx as nx
import re, logging
import openai, datetime, os

# prompts = generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts,labels)
def load_openai_keys(path='../openai_keys_filter.txt'):
    """Read OpenAI API keys from a whitespace-separated text file.

    Each non-blank line contributes its last whitespace-separated token as a
    key, so both ``"<label> <key>"`` and ``"<key>"`` lines work. Blank or
    whitespace-only lines (e.g. a trailing newline at end of file) are
    skipped instead of raising ``IndexError``.

    Args:
        path: File to read. Defaults to ``'../openai_keys_filter.txt'``.

    Returns:
        List of keys in file order.
    """
    keys = []
    with open(path, "r") as f:
        for line in f:
            # str.split() with no argument already ignores surrounding whitespace.
            fields = line.split()
            if not fields:
                continue
            keys.append(fields[-1])
    return keys
openai_api_keys = load_openai_keys()


def update_key():
    """Activate the key at the front of ``openai_api_keys`` and rotate it.

    The front key is assigned to ``openai.api_key`` and then moved to the
    back of the list, so successive calls cycle through the pool
    round-robin. Raises ``IndexError`` when the pool is empty.
    """
    openai.api_key = openai_api_keys[0]
    # Move the just-activated key to the back of the queue.
    openai_api_keys.append(openai_api_keys.pop(0))


## read data

In [None]:
# Relation vocabulary shared by every symbolic-tree split:
#   infer_rel       : relation names from ../symbolic_tree/1.relations, in file order
#   relation_txt    : those names, each followed by ', '
#   id2rel / rel2id : sequential integer ids -- file relations first, then extra_relations
#   rid             : next unused relation id
#   rel2sym_2       : relation name -> bare symbol (rel2sym_ri.txt)
#   rel2sym         : relation name -> '$' + symbol + '$'
#   grounding_truth : '$' + symbol + '$' -> reference rule text (rule_symbolic_first_order.txt)
infer_rel = list()
with open("../symbolic_tree/1.relations", 'r') as f:
    for line in f:
        # Two whitespace-separated fields per line; only the second (the name) is kept.
        _, rel = line.strip().split()
        infer_rel.append(rel)

relation_txt = ''.join(name + ', ' for name in infer_rel)

extra_relations = ["greatAuntUncleOf","grandparentOf","greatGrandparentOf","auntUncleOf","siblingOf","secondAuntUncleOf","childOf","grandchildOf","greatGrandchildOf","nieceNephewOf","cousinOf","secondCousinOf","firstCousinOnceRemovedOf", "male", "female"]

id2rel = dict(enumerate(infer_rel + extra_relations))
# If a name occurs twice, the later (larger) id wins, as with sequential assignment.
rel2id = {name: idx for idx, name in id2rel.items()}
rid = len(id2rel)

rel2sym = dict()
rel2sym_2 = dict()
with open("rel2sym_ri.txt","r") as fr:
    for line in fr:
        rel, sym = line.strip().split()
        rel2sym_2[rel] = sym
        rel2sym[rel] = '$' + sym + '$'

grounding_truth = dict()
with open('rule_symbolic_first_order.txt','r') as f:
    for line in f:
        # Tab-separated: bare relation symbol, then its reference rule.
        rel, rule = line.strip().split('\t')
        grounding_truth['$' + rel + '$'] = rule


In [None]:
def get_prefix(length):
    """Return the LaTeX quantifier prefix for a chain of ``length`` atoms.

    The prefix binds ``length + 1`` consecutive capital letters starting at
    ``'A'``, e.g. ``get_prefix(3) == '\\forall A, B, C, D: '``. A
    non-positive ``length`` yields ``'\\forall A: '``.
    """
    first = ord('A')
    names = [chr(first + k) for k in range(max(length, 0) + 1)]
    return '\\forall ' + ', '.join(names) + ': '

# Read rule_tab.txt (tab-separated: relation name followed by the rule's
# columns) and record, per relation symbol (rel2sym value), how many '##'
# binary atoms its rule template needs. The count is the number of columns
# after the name minus one -- presumably the remaining column is the unary
# '++' atom; confirm against rule_tab.txt.
rule_length = dict()
with open('rule_tab.txt', 'r') as f:
    for line in f:
        lst = line.strip().split('\t')
        rule_length[rel2sym[lst[0]]] = len(lst) - 2

In [None]:
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes only to ``filename``.

    Any handlers already attached to the named logger are detached and
    closed first. Calling this repeatedly (e.g. when a notebook cell is
    re-run) therefore neither duplicates output nor leaks open log files.

    Args:
        filename: Path of the log file; it is truncated (opened with mode "w").
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: Name passed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Iterate over a snapshot: removing from ``logger.handlers`` while
    # iterating it directly skips every other handler. Close each removed
    # handler so FileHandlers release their files.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Output to file
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Terminal output is currently disabled; uncomment to enable it.
    # sh = logging.StreamHandler()
    # sh.setFormatter(formatter)
    # logger.addHandler(sh)

    return logger

def read_entity(path, eid, id2ent, ent2id):
    """Register the entities listed in ``path`` in ``id2ent`` / ``ent2id``.

    Each line holds exactly two whitespace-separated fields; the second is
    the entity name. Names not yet in ``ent2id`` get consecutive ids starting
    at ``eid``. Already-known names are skipped.

    Returns:
        The next unused entity id.
    """
    next_id = eid
    with open(path, 'r') as handle:
        for raw in handle:
            _, name = raw.strip().split()
            if name in ent2id:
                continue
            id2ent[next_id] = name
            ent2id[name] = next_id
            next_id += 1
    return next_id
    
# def get_related_triplets(h, t, G, entpair2rel):
#     input_text = ''
#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):
#         for edge in path:
#             # print(edge)
#             if edge in entpair2rel:
#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '
#             else:
#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '
#     return input_text

def read_all_triplets(path1, path2, id2ent, text, fid):
    """Load the triplets of one symbolic-tree split.

    Both files contain lines ``flag head relation tail``. ``head``/``tail``
    are integer keys into ``id2ent`` and ``relation`` is an integer key into
    the module-level ``id2rel``.

    * Every line of ``path1`` becomes a triplet and appends two numbered
      facts to ``text``: ``F<fid>: <rel2sym[rel]>(head, tail)`` and
      ``F<fid+1>: <rel2sym['inverse_' + rel]>(tail, head)``.
    * Lines of ``path2`` become triplets only when their flag is ``'+'``;
      they add no fact text.

    Returns:
        ``(triplets, entpair2rel, text, fid)``. ``entpair2rel`` maps
        ``(head, tail)`` to the relation name, and later lines overwrite
        earlier ones. ``fid`` is the next unused fact number.
    """
    triplets = list()
    entpair2rel = dict()

    with open(path1, 'r') as observed:
        for raw in observed:
            _flag, h, r, t = raw.strip().split()
            head, name, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]
            triplets.append((head, name, tail))
            entpair2rel[(head, tail)] = name
            text += 'F' + str(fid) + ': ' + rel2sym[name] + '(' + head + ', ' + tail + ')\n'
            text += 'F' + str(fid + 1) + ': ' + rel2sym['inverse_' + name] + '(' + tail + ', ' + head + ')\n'
            fid += 2

    with open(path2, 'r') as inferred:
        for raw in inferred:
            flag, h, r, t = raw.strip().split()
            if flag != '+':
                continue
            head, name, tail = id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]
            entpair2rel[(head, tail)] = name
            triplets.append((head, name, tail))

    return triplets, entpair2rel, text, fid

    
def read_class(path, cid, ent2class, id2ent, class_text, fid):
    """Append one gender fact per line of ``path`` to ``class_text``.

    Line ``k`` (0-based) describes entity ``id2ent[cid + k]`` with two
    whitespace-separated flags. If the first flag is ``'1'``, the fact uses
    ``rel2sym["female"]``; otherwise it uses ``rel2sym["male"]``. Each fact
    has the form ``F<fid>: <symbol>(<entity>)``.

    Returns:
        ``(cid, class_text, fid)`` advanced past the lines read.
    """
    with open(path, 'r') as handle:
        for raw in handle:
            female_flag, _male_flag = raw.strip().split()
            if female_flag == '1':
                symbol = rel2sym["female"]
                person = id2ent[cid]
            else:
                person = id2ent[cid]
                # NOTE(review): only male entities are recorded in ent2class;
                # the female counterpart is disabled (commented out) in the
                # original code -- confirm this asymmetry is intended.
                ent2class[person] = 'male'
                symbol = rel2sym["male"]
            class_text += 'F' + str(fid) + ': ' + symbol + '(' + person + ')\n'
            fid += 1
            cid += 1
    return cid, class_text, fid




# def read_rules(path):
#     rules = list()
#     grounding_truth = dict()
#     rel2rules = dict()
#     with open(path, 'r') as f:
#         for line in f:
#             lst = line.strip().split('\t')
#             rules.append(lst)
#             if lst[0] not in rel2rules:
#                 rel2rules[lst[0]] = list() 
#             grounding_truth[lst[0]].append(lst) 
#     return rules, rel2rules

def get_relation_facts(triplets, rel):
    """Render every triplet whose relation symbol equals ``rel`` as a G-fact.

    ``rel`` is compared against ``rel2sym[triplet[1]]``, so it must be the
    ``'$...$'``-wrapped symbol rather than the raw relation name. Matches
    are emitted in input order as ``G<k>: <rel>(head, tail)`` lines,
    numbered from 1.

    Returns:
        ``(text, next_gid)``, where ``next_gid`` is one more than the number
        of matches (``1`` when nothing matched).
    """
    matches = [tri for tri in triplets if rel2sym[tri[1]] == rel]
    rendered = [
        'G' + str(k) + ': ' + rel2sym[tri[1]] + '(' + tri[0] + ', ' + tri[2] + ')\n'
        for k, tri in enumerate(matches, start=1)
    ]
    return ''.join(rendered), len(matches) + 1


In [None]:
# Build, for every relation symbol in rule_length, a LaTeX rule template such as
#   $\forall A, B, C: ##(A, B) \land ##(B, C) \land ++(A) \rightarrow r7(A, C)$
# '##' marks a binary-relation slot and '++' a unary slot for the model to fill.
# The body is a chain A-B-C-..., and the head links the first variable to the last.
template = dict()

for key in rule_length:
    length = rule_length[key]
    first = ord('A')
    # The last chain variable is derived from the length directly, so a
    # zero-length rule gets 'A' (consistent with get_prefix) instead of a
    # variable left over from the previous key.
    last = chr(first + max(length, 0))
    body = ''.join(
        '##(' + chr(first + k) + ', ' + chr(first + k + 1) + ') \\land '
        for k in range(length)
    )
    # key is '$<symbol>$'; strip the dollar signs for the rule head.
    template[key] = (
        '$' + get_prefix(length) + body
        + '++(A) \\rightarrow ' + key[1:-1] + '(A, ' + last + ')$'
    )

In [None]:

def dict2str(d):
    """Serialise a mapping as ``key<TAB>value`` records in iteration order.

    Keys and values go through ``str``. Every record, including the last, is
    terminated by a carriage return ``'\\r'``.
    NOTE(review): ``'\\r'`` (not ``'\\n'``) ends each record in the log
    output -- confirm this is intended.
    """
    return ''.join(str(k) + '\t' + str(d[k]) + '\r' for k in d)

In [None]:
# # based on rule_length to output the template like $\_(a,b) \land \_(b,c) \land \_(a)$ 
# template = dict()

# for key in rule_length:
#     template[key] = 'If '
#     ent_h = ord('A')
#     for i in range(rule_length[key]):
#         ent_h = ent_h
#         ent_t = ent_h + 1
#         template[key] += chr(ent_h) + ' is ## of ' + chr(ent_t) + ' and '
#         # template[key] += '##(' + chr(ent_h) + ',' + chr(ent_t) + ') \land '
#         ent_h = ent_t 
#     template[key] += 'A is ++, then A is ' + key + ' of ' + chr(ent_t) + '.'

In [None]:
# # read rule_latex.txt and get the LaTeX format of each rule
# rule_latex = dict()
# id = 1
# with open('../rule_latex.txt', 'r') as f:
#     for line in f:
        
#         rule_latex[rel2sym[infer_rel[id]]] = line.strip()
#         id += 1

In [None]:
# read data
# Main experiment: for each of the ten symbolic-tree splits, build the fact
# prompt, ask the model to fill in each target relation's rule template, and
# score the answer by substring match against grounding_truth.
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
# NOTE(review): `dir` (and `id` below) shadow Python builtins; the names are
# kept in case later notebook cells refer to them.
dir = 'logs/first_order_zero_shot_cot'
os.makedirs(dir, exist_ok=True)
logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)

# model = "gpt-3.5-turbo"
model = "gpt-4"
logger.info('model: ' + model)
for i in range(0, 10):
    eid = 0
    cid = 0
    fid = 1
    id2ent = dict()
    ent2id = dict()
    ent2class = dict()
    ent2triplets = dict()
    class_text = ''
    text = ''

    path_ent = "../symbolic_tree/" + str(i) + ".individuals"
    path_rel1 = "../symbolic_tree/" + str(i) + ".relations.data"
    path_rel2 = "../symbolic_tree/" + str(i) + ".relations.data.inf"
    path_class = "../symbolic_tree/" + str(i) + ".classes.data"

    eid = read_entity(path_ent, eid, id2ent, ent2id)
    # Class facts are numbered first; relation facts continue the F-numbering.
    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)
    triplets, entpair2rel, text, fid = read_all_triplets(path_rel1, path_rel2, id2ent, text, fid)

    record_flag = False  # full prompts are logged only once per split
    id = 0               # relations for which a prediction was obtained
    scores = 0           # of those, how many contain the reference rule
    # NOTE(review): infer_rel[0] is never evaluated; presumably it is not a
    # rule target -- confirm.
    for rel_origin in tqdm(infer_rel[1:]):
        rel = rel2sym[rel_origin]
        # A missing template or reference rule is not a transient error.
        # Check before any API call so it does not trigger paid retries.
        if rel not in template or rel not in grounding_truth:
            logger.warning("skipping " + rel_origin + " (" + rel + "): no template or grounding truth")
            continue
        relation_specific_text, gid = get_relation_facts(triplets, rel)
        server_error_cnt = 0
        while server_error_cnt < 10:
            try:
                # update_key()
                message_1 = {
                    'system': "You are a helpful assistant with inductive reasoning abilities. Please generate one single rule to match the template and logically entail the facts. Note that the symbol '##' in the template should be filled with either 'r1' or 'r45', while the symbol '++' should be filled with either 'r43' or 'r44'.",
                    'user': "I will give you a set of facts F1 to F" + str(fid - 1) + ", facts G1 to G" + str(gid - 1) + " and a template for a logical rule. Please fill in the template so that the generated rule can logically entail the facts G1 to G" + str(gid - 1) + " based on facts F1 to F" + str(fid - 1) + ".\nFacts:\n" + class_text + text + relation_specific_text + '\nTemplate: ' + template[rel]
                    + "\nNote that the symbol '##' in the template should be filled with either 'r1' or 'r45', while the symbol '++' should be filled with either 'r43' or 'r44'.\nAfter filling in the template, the generated rule is: Let's think step by step. ",
                }
                # First call: the model reasons step by step (zero-shot CoT).
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": message_1['system']},
                        {"role": "user", "content": message_1['user']},
                    ],
                    temperature=0,
                )
                results = response['choices'][0]['message']['content']

                message_2 = {
                    'system': "Please fill in the template and output the generated rule.",
                    'user': message_1['user'] + '\n' + results + '\nThe generated rule is: '}

                # NOTE(review): this second completion is requested, but its
                # content is never read. `results` from the first call is what
                # gets logged and scored. Confirm whether the answer should be
                # taken from this response instead.
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": message_2['system']},
                        {"role": "user", "content": message_2['user']},
                    ],
                    temperature=0,
                )
                # results = response['choices'][0]['message']['content']
                if not record_flag:
                    logger.info('message_1: \n' + dict2str(message_1))
                    logger.info('message_2: \n' + dict2str(message_2))
                    record_flag = True

                logger.info("template: " + template[rel])
                logger.info('prediction' + ": " + results)
                logger.info("grounding_truth: " + grounding_truth[rel])

                # Scoring is a plain substring match of the reference rule.
                if grounding_truth[rel] in results:
                    logger.info("correct")
                    scores += 1
                logger.info("============================================================")
                id += 1
                break

            except Exception as e:
                server_error_cnt += 1
                print(e)
                logger.warning("attempt %d for %s failed: %r", server_error_cnt, rel, e)
        else:
            # Relations that never succeed are excluded from `id` (the
            # accuracy denominator); record them so the omission is visible.
            logger.warning("giving up on " + rel + " after 10 failed attempts")

    if id:
        logger.info("accuracy: " + str(scores/id))
    else:
        logger.warning("accuracy: undefined (no relation was evaluated)")
    