In [None]:
from tqdm import tqdm
import random
from random import sample
import networkx as nx
import re, logging
import openai, datetime, os
# Placeholder only: update_key() assigns a real key from the pool before each request.
openai.api_key = ""
def load_openai_keys(path='../openai_keys.txt'):
    """Read the API keys listed in a text file, one entry per line.

    Each non-blank line may hold several fields separated by ``----``; the
    last field is taken as the key. A line without the separator is used
    as the key in full.

    Args:
        path: Location of the key file. Defaults to the path the notebook
            has always used.

    Returns:
        A list of key strings in file order. Blank lines are skipped so
        that an empty string is never handed out as an API key.
    """
    keys = []
    with open(path, "r") as f:
        for line in f:
            entry = line.strip()
            if not entry:
                # A trailing newline or spacer line is not a key.
                continue
            keys.append(entry.split('----')[-1])
    return keys
# Pool of API keys loaded once at start-up; update_key() rotates through it.
openai_api_keys = load_openai_keys()

def update_key():
    """Activate the key at the front of the pool and move it to the back.

    Successive calls therefore cycle through ``openai_api_keys`` in order.
    Raises IndexError when the pool is empty.
    """
    active_key = openai_api_keys.pop(0)
    openai_api_keys.append(active_key)
    openai.api_key = active_key

In [None]:
# Relation vocabulary. Every line of the .relations file has two
# whitespace-separated fields; the second is the relation name. Ids are
# assigned by line order, starting at 0.
rid = 0
id2rel = dict()
rel2id = dict()
rel2sym = dict()

with open("../symbolic_tree/1.relations", 'r') as f:
    for line in f:
        _, rel = line.strip().split()
        id2rel[rid], rel2id[rel] = rel, rid
        rid += 1

# Surface form used in prompts for each relation ("<relation> <symbol>" per line).
with open("rel2sym.txt","r") as fr:
    for line in fr:
        rel, sym = line.strip().split()
        # Symbols containing 'Of' lose their last two characters so that the
        # prompt can append " of " itself.
        # NOTE(review): the test is "contains 'Of'", not "ends with 'Of'" -
        # confirm no symbol has 'Of' anywhere but at the end.
        rel2sym[rel] = sym[:-2] if 'Of' in sym else sym

In [None]:
def get_logger(filename, verbosity=1, name=None):
    """Return a logger whose only handler writes to ``filename``.

    Args:
        filename: Path of the log file. It is opened in "w" mode, so an
            existing file is truncated.
        verbosity: 0 for DEBUG, 1 for INFO, 2 for WARNING.
        name: Name passed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: If ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Remove any existing handlers. Iterate over a copy: removeHandler
    # mutates logger.handlers, and looping over the live list skips every
    # other handler, so repeated calls would keep writing to old files.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()  # release the previous log file, if any

    # Output to file
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # # Output to terminal
    # sh = logging.StreamHandler()
    # sh.setFormatter(formatter)
    # logger.addHandler(sh)

    return logger


def read_entity(path, eid, id2ent, ent2id ):
    """Register the entity names listed in ``path``.

    Each line holds two whitespace-separated fields; the second is the
    entity name. Names not yet present in ``ent2id`` receive consecutive
    ids starting at ``eid`` and are written into both mappings.

    Args:
        path: Entity file to read.
        eid: First id to hand out.
        id2ent: Mapping id -> name, updated in place.
        ent2id: Mapping name -> id, updated in place.

    Returns:
        The next unused id.
    """
    with open(path, 'r') as handle:
        for row in handle:
            _, name = row.strip().split()
            if name in ent2id:
                # Already registered: keep its original id.
                continue
            id2ent[eid] = name
            ent2id[name] = eid
            eid += 1
    return eid
    
def read_class(path, cid, ent2class, id2ent, class_text, fid):
    """Read one gender row per entity and append the matching fact lines.

    Each line of ``path`` holds two whitespace-separated flags. A first
    flag of '1' marks the entity as female; any other value marks it as
    male (the second flag is not consulted). Rows are matched to entities
    by position: the first row belongs to ``id2ent[cid]``, the next to
    ``id2ent[cid + 1]``, and so on.

    Args:
        path: Class file to read.
        cid: Entity id the first row refers to.
        ent2class: Mapping entity name -> gender symbol, updated in place.
        id2ent: Mapping entity id -> entity name.
        class_text: Fact text accumulated so far.
        fid: Number to use for the next fact line.

    Returns:
        ``(cid, class_text, fid)`` after all rows have been consumed.
    """
    with open(path, 'r') as handle:
        for row in handle:
            female_flag, _ = row.strip().split()
            symbol = rel2sym['female' if female_flag == '1' else 'male']
            entity = id2ent[cid]
            ent2class[entity] = symbol
            class_text += ''.join(
                ['F', str(fid), ': ', entity, ' is ', symbol, '.\n']
            )
            fid += 1
            cid += 1
    return cid, class_text, fid
    
# def get_related_triplets(h, t, G, entpair2rel):
#     input_text = ''
#     for path in sorted(nx.all_simple_edge_paths(G, h, t, cutoff=5)):
#         for edge in path:
#             # print(edge)
#             if edge in entpair2rel:
#                 input_text += edge[0] + ' is ' + entpair2rel[edge] + ' of ' + edge[1] + '. '
#             else:
#                 input_text += edge[1] + ' is ' + entpair2rel[(edge[1],edge[0])] + ' of ' + edge[0] + '. '
#     return input_text
# def get_rule_text(rule_heads, rule2text):
#     rule_text = ''
#     for rule_head in rule_heads:
#         rule_text += rule2text[rule_head]
#     return rule_text

In [None]:
# Natural-language rules, one "<relation>\t<rule text>" pair per line.
# `rel2rule` maps a relation to its rule text; `rules` is the numbered
# listing ("L1: ...", "L2: ...") in file order.
rules = ''
rel2rule = dict()
with open("natural_rules.txt", 'r') as f1:
    # The numbering comes from enumerate instead of a hand-maintained
    # module-level counter named `id`, which shadowed the builtin id().
    for rule_no, line in enumerate(f1, start=1):
        rel, rule = line.strip().split('\t')
        rel2rule[rel] = rule
        rules += 'L' + str(rule_no) + ': ' + rule + '\n'

def dict2str(d):
    """Render a mapping as text, one "key<TAB>value" entry after another.

    Every entry is terminated by a carriage return ('\\r').
    NOTE(review): a newline may have been intended - confirm before changing.
    """
    return ''.join(str(key) + '\t' + str(d[key]) + '\r' for key in d)
def get_negative_samples(triplets, id2ent, id2rel, labels):
    """Create one corrupted (negative) triplet for every input triplet.

    For each ``(head, relation, tail)`` a coin flip decides whether the
    head or the tail is replaced by a uniformly drawn entity. Candidates
    that are themselves in ``triplets`` are rejected and redrawn.

    Args:
        triplets: Positive ``(head, relation, tail)`` tuples.
        id2ent: Mapping entity id -> entity name. Ids are assumed to be the
            contiguous range ``0 .. len(id2ent) - 1``.
        id2rel: Unused; kept so existing callers do not need to change.
        labels: Mapping triplet -> label, updated in place with ``0`` for
            every negative sample produced.

    Returns:
        A list of negative triplets, one per input triplet, in input order.

    Note:
        Candidates are only checked against ``triplets``. The loop does not
        terminate if every possible corruption of some triplet is itself a
        member of ``triplets``.
    """
    # Build the lookup once: testing against the list costs O(n) per draw.
    known = set(triplets)
    last_id = len(id2ent) - 1
    neg_samples = []

    for head, rel, tail in triplets:
        while True:
            if random.random() < 0.5:
                # Corrupt the head.
                candidate = (id2ent[random.randint(0, last_id)], rel, tail)
            else:
                # Corrupt the tail.
                candidate = (head, rel, id2ent[random.randint(0, last_id)])
            if candidate not in known:
                neg_samples.append(candidate)
                labels[candidate] = 0
                break

    return neg_samples

In [None]:
def compute_f_beta_score(precision, recall, logger):
    """Log the F-beta score for beta = 0.0, 0.1, ..., 1.0.

    Each score is written as one INFO line of the form
    ``beta: <beta>\\tF score: <score>``.

    Args:
        precision: Precision of the classifier.
        recall: Recall of the classifier.
        logger: Object with an ``info`` method that receives each line.

    A score is reported as 0 whenever the formula's denominator
    ``beta^2 * precision + recall`` is zero. That covers the case where
    precision and recall are both zero, and also beta = 0 with zero recall,
    which would otherwise raise ZeroDivisionError.
    """
    for step in range(0, 11):
        beta = step / 10
        beta_sq = beta * beta
        denominator = beta_sq * precision + recall
        if denominator == 0:
            score = 0
        else:
            score = (1 + beta_sq) * precision * recall / denominator
        # Same tab-separated layout for every line, including the zero case.
        logger.info('beta: ' + str(beta) + '\tF score: ' + str(score))
def output_relation_acc(relation_true_num, relation_num, logger):
    """Log the accuracy of every relation as "<relation>\\t<accuracy>".

    Args:
        relation_true_num: Mapping relation -> number of correct predictions.
        relation_num: Mapping relation -> number of predictions; must hold a
            non-zero count for every key of ``relation_true_num``.
        logger: Object with an ``info`` method that receives each line.
    """
    for relation, correct in relation_true_num.items():
        accuracy = correct / relation_num[relation]
        logger.info(relation + '\t' + str(accuracy))


In [None]:
# Every run logs to its own timestamped file under logs/natural_standard_facts.
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
# Named log_dir rather than `dir` so the builtin dir() is not shadowed.
log_dir = 'logs/natural_standard_facts'
# exist_ok avoids the separate existence check and the race it leaves open.
os.makedirs(log_dir, exist_ok=True)
logger = get_logger(os.path.join(log_dir, nowTime + '.log'), verbosity=1)

model = "gpt-3.5-turbo"
# model = "gpt-4"
# Log through the configured logger instead of the module-level logging.info,
# so the line reaches the run's log file whichever logger get_logger returns.
logger.info('model: ' + model)

# Becomes True once the prompt has been written to the log.
record_flag = False

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Maximum number of API attempts per statement before the statement is skipped.
MAX_API_ATTEMPTS = 10

# For each dataset: build the numbered fact list, collect the inferred
# positives, sample one negative per positive, ask the model about every
# statement and log accuracy, precision, recall, F-beta and per-relation
# accuracy.
for i in range(0, 1):
    id2ent = dict()
    ent2id = dict()
    ent2class = dict()
    ent2triplets = dict()

    eid = 0

    cid = 0
    fid = 1  # facts are numbered F1, F2, ... across class and relation facts
    class_text = ''
    triplets = []
    test_triplets = []
    labels = dict()  # (head, relation, tail) -> 1 for positives, 0 for negatives
    entpair2rel = dict()
    basic_facts = ''
    statement = ''
    path_ent = "../symbolic_tree/" + str(i) + ".individuals"
    path_class = "../symbolic_tree/" + str(i) + ".classes.data"
    eid = read_entity(path_ent, eid, id2ent, ent2id)
    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)
    path = "../symbolic_tree/" + str(i) + ".relations.data"
    basic_facts += class_text

    relation_num = dict()
    relation_true_num = dict()

    # Known relation facts: each becomes a numbered "F<k>: h is <rel> of t." line.
    with open(path, 'r') as f:
        for line in f:
            flag, h, r, t = line.strip().split()
            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))
            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]

            basic_facts += 'F' + str(fid) + ': ' + id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '.\n'
            fid += 1
    path = "../symbolic_tree/" + str(i) + ".relations.data.inf"

    # Inferred facts flagged '+' are the positive test statements.
    with open(path, 'r') as f:
        for line in tqdm(f):
            flag, h, r, t = line.strip().split()
            if flag == '+':
                test_triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))
                labels[(id2ent[int(h)], id2rel[int(r)], id2ent[int(t)])] = 1
    negative_samples = get_negative_samples(test_triplets, id2ent, id2rel, labels)

    num = 0
    true_num = 0
    false_num = 0

    pos_true = 0
    pos_false = 0

    neg_true = 0
    neg_false = 0

    predicted_facts = test_triplets + negative_samples
    # Query the statements in random order.
    random.shuffle(predicted_facts)
    for triple in tqdm(predicted_facts):
        h, r, t = triple
        statement = h + ' is ' + rel2sym[r] + ' of ' + t + '.'
        message = {
                    'system': "Please select a few facts to predict True/False of the unknown fact. ",
                    'user': "I will provide a set of facts F1 to F" + str(fid - 1) + ". Please select a few facts from F1 to F" + str(fid - 1) + " to predict True/False of the unknown fact.\nFacts:\n" + basic_facts + "\nUnknown fact: " + statement + "\nThe answer (True or False) is: "
                }

        # Only the request itself is retried. Scoring happens after the loop
        # so that an error while scoring cannot trigger a second request or
        # count the same statement twice.
        results = None
        for attempt in range(1, MAX_API_ATTEMPTS + 1):
            try:
                update_key()
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": message['system']},
                        {"role": "user", "content": message['user']},
                    ],
                    temperature=0,
                )
                results = response['choices'][0]['message']['content']
                break
            except Exception as e:
                # Deliberately broad: any failure of the request is logged
                # and retried with the next key.
                logger.warning('API call failed (attempt %d/%d): %s', attempt, MAX_API_ATTEMPTS, e)

        if results is None:
            # No usable answer: leave the statement out of every counter.
            logger.warning('no answer obtained, skipping: ' + statement)
            continue

        if not record_flag:
            logger.info('message: \n' + dict2str(message))
            record_flag = True

        # Count the statement only now that an answer exists, so failed
        # requests do not lower the per-relation accuracy.
        num += 1
        relation_num[r] = relation_num.get(r, 0) + 1
        relation_true_num.setdefault(r, 0)

        # Only the text before the first '.' is inspected for the verdict.
        ans = results.split('.')[0]
        label = labels[(h, r, t)]
        if 'True' in ans:
            predicted_true = True
        elif 'False' in ans:
            predicted_true = False
        elif 'Unknown' in ans:
            # "Unknown" is scored like "False".
            predicted_true = False
            print(results)
        else:
            predicted_true = None

        if predicted_true is None:
            # Still part of `num` and `relation_num`, but of no other counter.
            logger.warning('correctness: Unparsed')
        else:
            is_correct = predicted_true == (label == 1)
            if is_correct:
                true_num += 1
                relation_true_num[r] += 1
            else:
                false_num += 1
            if label == 1:
                if is_correct:
                    pos_true += 1
                else:
                    pos_false += 1
            else:
                if is_correct:
                    neg_true += 1
                else:
                    neg_false += 1
            logger.info('correctness: ' + ('Correct' if is_correct else 'Incorrect'))

        logger.info('triplet: ' + statement + '\tgrounding truth: ' + str(label) + '\tprediction: ' + results)

    # Ratios fall back to 0.0 when their denominator is zero (no answers at
    # all, or no statement predicted True), so the run still logs a summary.
    logger.info(str(i) + ': ' + str(_safe_ratio(true_num, num)))
    logger.info('pos_acc: ' + str(_safe_ratio(pos_true, pos_true + pos_false)))
    logger.info('neg_acc: ' + str(_safe_ratio(neg_true, neg_true + neg_false)))
    TP = pos_true
    FN = pos_false
    FP = neg_false
    TN = neg_true
    precision = _safe_ratio(TP, TP + FP)
    recall = _safe_ratio(TP, TP + FN)
    logger.info('precision: ' + str(precision))
    logger.info('recall: ' + str(recall))
    compute_f_beta_score(precision, recall, logger)
    output_relation_acc(relation_true_num, relation_num, logger)