In [None]:
from tqdm import tqdm
import random
from random import sample
import networkx as nx
import re, logging
import openai, datetime, os

def load_openai_keys(path='../openai_keys_filter.txt'):
    """Read OpenAI API keys from a whitespace-separated text file.

    The last whitespace-separated field of each line is taken as the key, so
    both "<label> <key>" and bare "<key>" lines work. Blank or whitespace-only
    lines are skipped.

    Args:
        path: file to read; defaults to the location this notebook has always
            used.

    Returns:
        List of key strings in file order.
    """
    keys = []
    with open(path, "r") as f:
        for line in f:
            fields = line.split()
            # A blank line would give an empty list and crash on fields[-1].
            if fields:
                keys.append(fields[-1])
    return keys
# Module-wide pool of API keys, loaded once and put in random order.
openai_api_keys = load_openai_keys()

random.shuffle(openai_api_keys)
def update_key():
    """Activate the key at the front of the pool, then rotate it to the back."""
    head_key = openai_api_keys[0]
    openai.api_key = head_key
    # pop(0) removes exactly the element read above (same as remove(head_key)).
    openai_api_keys.append(openai_api_keys.pop(0))

In [None]:
# Relation vocabulary from the first symbolic tree. Each line is
# "<field> <relation name>"; relations get ids 0, 1, ... in file order.
rid = 0
id2rel = dict()
rel2id = dict()
rel2sym = dict()

with open("../symbolic_tree/1.relations", 'r') as f:
    for line in f:
        _, rel = line.strip().split()
        id2rel[rid] = rel
        rel2id[rel] = rid
        rid += 1

# Surface word used when verbalising each relation. Statements in this
# notebook are rendered as "<head> is <sym> of <tail>", so a trailing "Of" is
# dropped to avoid saying "of" twice. Only a trailing "Of" is removed; names
# that merely contain "Of" elsewhere are kept intact.
with open("rel2sym.txt", "r") as fr:
    for line in fr:
        rel, sym = line.strip().split()
        rel2sym[rel] = sym[:-2] if sym.endswith('Of') else sym

In [None]:
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes only to ``filename`` (truncated on open).

    Args:
        filename: path of the log file; opened in "w" mode, so existing
            content is overwritten.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: passed to ``logging.getLogger``; ``None`` selects the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # Remove handlers left by earlier calls. Iterate over a copy: removing from
    # logger.handlers while iterating it directly skips every other handler.
    # Close each one so a previous FileHandler does not keep its file open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Output to file only.
    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger


def read_entity(path, eid, id2ent, ent2id):
    """Give consecutive integer ids to the entity names listed in ``path``.

    Each line holds two whitespace-separated fields; the second is the entity
    name. A name not yet present in ``ent2id`` receives the next id, starting
    from ``eid``, and is recorded in both ``id2ent`` and ``ent2id`` (mutated
    in place). Names already known are skipped.

    Returns:
        The next unused id once the whole file has been read.
    """
    next_id = eid
    with open(path, 'r') as handle:
        for row in handle:
            _index, name = row.split()
            if name in ent2id:
                continue
            ent2id[name] = next_id
            id2ent[next_id] = name
            next_id += 1
    return next_id
    
def read_class(path, cid, ent2class, id2ent, class_text, fid):
    """Attach a gender class to entities and append them as numbered facts.

    Line k of ``path`` (k counted from ``cid``) describes entity ``id2ent[k]``
    as "<female flag> <male flag>". A female flag equal to '1' gives the
    entity ``rel2sym['female']``; anything else gives ``rel2sym['male']``.
    The second column is not consulted. For every line, ``ent2class`` is
    updated in place and "F<fid>: <entity> is <class>.\\n" is appended to the
    fact text.

    Returns:
        ``(cid, class_text, fid)``: the next entity index, the extended fact
        text, and the next fact number.
    """
    with open(path, 'r') as handle:
        for row in handle:
            female_flag, _male_flag = row.split()
            gender_sym = rel2sym['female'] if female_flag == '1' else rel2sym['male']
            entity = id2ent[cid]
            ent2class[entity] = gender_sym
            class_text += f'F{fid}: ' + entity + ' is ' + gender_sym + '.\n'
            fid += 1
            cid += 1
        return cid, class_text, fid


In [None]:
# Natural-language rules: each line of natural_rules.txt is
# "<relation>\t<rule text>". `rules` lists them as numbered lines
# "L<n>: <rule>\n" starting at L1, and `lid` ends one past the last number.
rules = ''
lid = 1
rel2rule = dict()
numbered_rules = []
with open("natural_rules.txt", 'r') as f1:
    for rule_number, line in enumerate(f1, start=1):
        rel, rule = line.strip().split('\t')
        rel2rule[rel] = rule
        numbered_rules.append('L' + str(rule_number) + ': ' + rule + '\n')
rules = ''.join(numbered_rules)
lid = len(numbered_rules) + 1

def dict2str(d):
    """Render ``d`` as one "key<TAB>value" line per entry, in iteration order.

    Each entry ends with '\\n' so the log file stays line-oriented. A carriage
    return would leave stray '\\r' characters that break grep and most viewers.
    """
    return ''.join(str(k) + '\t' + str(d[k]) + '\n' for k in d)

def get_negative_samples(triplets, id2ent, id2rel, labels):
    """Create one corrupted (negative) triple for every triple in ``triplets``.

    For each (h, r, t), either the head or the tail is chosen with probability
    1/2 and replaced by an entity drawn uniformly from ``id2ent``. The draw is
    repeated until the corrupted triple is not itself in ``triplets``. Each
    negative is recorded in ``labels`` with value 0.

    Args:
        triplets: list of positive (head, relation, tail) name tuples.
        id2ent: maps entity ids 0..len(id2ent)-1 to entity names.
        id2rel: unused; kept for signature compatibility.
        labels: dict updated in place with ``labels[negative] = 0``.

    Returns:
        List of negative triples, index-aligned with ``triplets``. Different
        positives may produce the same negative.

    Note:
        Negatives are only checked against ``triplets``, so a corruption that
        happens to be true for another reason is not filtered out. The inner
        loop never ends if every corruption of some triple is in ``triplets``.
    """
    # Set membership is O(1); scanning the list made this O(n^2) overall.
    positive_set = set(triplets)
    max_index = len(id2ent) - 1
    neg_samples = []

    for h, r, t in triplets:
        while True:
            # Same draw order as before: coin flip first, then the entity id.
            if random.random() < 0.5:
                candidate = (id2ent[random.randint(0, max_index)], r, t)
            else:
                candidate = (h, r, id2ent[random.randint(0, max_index)])
            if candidate not in positive_set:
                break
        neg_samples.append(candidate)
        labels[candidate] = 0

    return neg_samples

In [None]:
def compute_f_beta_score(precision, recall, logger):
    """Log the F-beta score for beta = 0.0, 0.1, ..., 1.0.

    F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). When the denominator is
    zero (P == R == 0, or beta == 0 with R == 0) the score is reported as 0
    instead of raising ZeroDivisionError.

    Args:
        precision: precision P.
        recall: recall R.
        logger: any object with an ``info(str)`` method.

    Returns:
        Dict mapping each beta to its F-beta score. Callers may ignore it.
    """
    scores = {}
    for step in range(0, 11):
        beta = step / 10
        denominator = beta * beta * precision + recall
        if denominator == 0:
            scores[beta] = 0
            logger.info('beta: ' + str(beta) + '\tF score: 0')
        else:
            score = (1 + beta * beta) * precision * recall / denominator
            scores[beta] = score
            logger.info('beta: ' + str(beta) + '\tF score: ' + str(score))
    return scores

In [None]:
import tiktoken
# Tokenizer used to count the tokens of model-generated explanations when
# enforcing the few-shot token budget in generate_few_shot_prompts. Keep the
# model name in sync with the chat `model` chosen for the evaluation run.
# enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
enc = tiktoken.encoding_for_model("gpt-4")

def _chat_completion_with_retry(model, system_content, user_content, max_attempts=10):
    """Send a system+user chat request at temperature 0, retrying on failure.

    Returns the response of the first successful call. Raises RuntimeError
    (chained to the last error) once ``max_attempts`` calls have failed, so the
    caller can never continue with a missing or stale response.
    """
    last_error = None
    for _ in range(max_attempts):
        try:
            return openai.ChatCompletion.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": user_content},
                ],
                temperature=0,
            )
        except Exception as err:  # exact error classes differ across openai versions
            last_error = err
            print('server error')
    raise RuntimeError(
        'chat completion failed after ' + str(max_attempts) + ' attempts'
    ) from last_error


def _make_demonstration(statement, explanation):
    """Return the plain and "Let's think step by step."-prefixed demo dicts."""
    plain = {
        'Statement': "Unknown fact: " + statement,
        'Answer': "Answer: " + explanation,
    }
    with_cot_prefix = {
        'Statement': "Unknown fact: " + statement,
        'Answer': "Answer: Let's think step by step. " + explanation,
    }
    return plain, with_cot_prefix


def generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts, labels):
    """Build six chain-of-thought demonstrations by querying the model itself.

    Facts are drawn at random from ``predicted_facts``. For each one the model
    first reasons freely, then gives a True/False verdict in a second call. A
    fact becomes a demonstration only when:
      * its relation has not been used by an earlier demonstration,
      * the reasoning fits the remaining token budget (1200 tokens spread over
        the shots still needed), and
      * the verdict matches ``labels``.
    Exactly three positive (label 1) and three negative demonstrations are
    collected.

    Args:
        predicted_facts: list of (head, relation, tail) tuples to sample from.
        model: chat model name passed to the OpenAI API.
        lid: one past the last rule number (rules are L1..L{lid-1}).
        fid: one past the last fact number (facts are F1..F{fid-1}).
        rules: numbered rule text inserted into the prompt.
        basic_facts: numbered fact text inserted into the prompt.
        labels: maps every triple in ``predicted_facts`` to 1 (true) or 0.

    Returns:
        ``(prompts, prompts_plus)``: two lists of six dicts with 'Statement'
        and 'Answer' keys. In ``prompts_plus`` each answer is prefixed with
        "Let's think step by step.".

    Raises:
        ValueError: if ``predicted_facts`` has fewer than six distinct
            relations (the selection loop could never finish).
        RuntimeError: if an API call keeps failing after its retries.

    Note:
        The loop keeps issuing paid API calls until the quotas are met. It
        does not terminate if the model can never reproduce enough labels.
    """
    num_shots = 6
    per_label = num_shots // 2
    if len({fact[1] for fact in predicted_facts}) < num_shots:
        raise ValueError(
            'need at least ' + str(num_shots) + ' distinct relations in predicted_facts'
        )

    number = 0
    false_number = 0
    true_number = 0
    prompts = list()
    prompts_plus = list()
    relation_list = list()
    extra_tokens = 1200  # total token budget for all demonstration explanations
    false_words = ['False', 'false', 'Unknown', 'unknown']

    # Everything up to the facts is identical for every sampled statement.
    prompt_prefix = (
        "I will provide a set of logical rules L1 to L" + str(lid - 1)
        + " and facts F1 to F" + str(fid - 1)
        + ". Please select one single logical rule from L1 to L" + str(lid - 1)
        + " and a few facts from F1 to F" + str(fid - 1)
        + " to predict True/False of the following statement.\nLogical rules:\n"
        + rules + "\nFacts:\n" + basic_facts
    )

    while number < num_shots:
        h, r, t = random.choice(predicted_facts)
        if r in relation_list:
            continue  # every demonstration must use a different relation

        statement = h + ' is ' + rel2sym[r] + ' of ' + t + '.'
        user_prompt = (
            prompt_prefix + "\nUnknown fact: " + statement
            + "\nAnswer with True or False? Let's think step by step."
        )

        # Stage 1: free-form reasoning.
        response = _chat_completion_with_retry(
            model,
            "Please select a few facts to predict True/False of the unknown fact.",
            user_prompt,
        )
        results = response['choices'][0]['message']['content']
        # Collapse runs of consecutive newlines into a single newline.
        results = re.sub(r'\n+', '\n', results)
        n_tokens = len(enc.encode(results))
        # Share the remaining budget evenly among the shots still needed.
        if n_tokens > extra_tokens / (num_shots - number):
            continue

        # Stage 2: ask for the verdict given the reasoning.
        response = _chat_completion_with_retry(
            model,
            "Please predict True/False of the unknown fact.",
            user_prompt + '\n' + results + '\nTherefore, the answer (True or False) is: ',
        )
        results_2 = response['choices'][0]['message']['content']
        # The verdict is read from the text before the first '.'
        # (str.split always returns at least one element).
        verdict = results_2.split('.')[0]
        says_false = any(word in verdict for word in false_words)
        says_true = 'True' in verdict or 'true' in verdict

        if labels[(h, r, t)] == 1:
            if says_false:
                print('correctness: ' + 'Incorrect')
            elif says_true:
                if true_number == per_label:
                    continue  # positive quota already filled
                print('correctness: ' + 'Correct')
                demo, demo_plus = _make_demonstration(statement, results)
                prompts.append(demo)
                prompts_plus.append(demo_plus)
                number += 1
                true_number += 1
                extra_tokens -= n_tokens
                print('true label')
                relation_list.append(r)
            else:
                print('correctness: ' + 'Incorrect')
                print(verdict)
        else:
            if says_false:
                if false_number == per_label:
                    continue  # negative quota already filled
                demo, demo_plus = _make_demonstration(statement, results)
                prompts.append(demo)
                prompts_plus.append(demo_plus)
                number += 1
                false_number += 1
                extra_tokens -= n_tokens
                relation_list.append(r)
                print('correctness: ' + 'Correct')
                print('false label')
            elif says_true:
                print('correctness: ' + 'Incorrect')
            else:
                print('correctness: ' + 'Unknown')

    return prompts, prompts_plus
    

In [None]:
# One log file per run, named after the start time (YYYY-mm-dd-HH-MM-SS).
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
# NOTE: `dir` shadows the builtin of the same name; the name is kept for
# compatibility with any later code that refers to it.
dir = 'logs/natural_few_shot_cot_auto'
os.makedirs(dir, exist_ok=True)  # avoids the exists-then-create race
logger = get_logger(os.path.join(dir, nowTime + '.log'), verbosity=1)

# model = "gpt-3.5-turbo"
model = "gpt-4"
logger.info('model: ' + model)

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is 0.

    Keeps the per-file metrics from raising ZeroDivisionError, which would
    abort the remaining files. This happens when every query for a file failed
    or the model never answered True.
    """
    return numerator / denominator if denominator else 0.0


# Evaluate the few-shot CoT prompt on symbolic-tree files 1..9. For each file:
# build the numbered fact base, sample one negative per inferred positive,
# generate six self-labelled demonstrations, then query every test triple.
false_words = ['indeterminate', 'Indeterminate', 'FALSE', 'Unknown', 'unknown', 'not', 'False', ' no ', "inconclusive", "undefined", "invalid", 'false']
record_flag = False  # the full prompt is logged only once, for the first query
for i in range(1, 10):
    id2ent = dict()
    ent2id = dict()
    ent2class = dict()
    ent2triplets = dict()

    eid = 0

    cid = 0
    fid = 1
    class_text = ''
    triplets = []
    test_triplets = []
    labels = dict()
    entpair2rel = dict()
    basic_facts = ''
    statement = ''
    path_ent = "../symbolic_tree/" + str(i) + ".individuals"
    path_class = "../symbolic_tree/" + str(i) + ".classes.data"
    eid = read_entity(path_ent, eid, id2ent, ent2id)
    cid, class_text, fid = read_class(path_class, cid, ent2class, id2ent, class_text, fid)
    basic_facts += class_text

    # Base facts: "<flag> <head id> <relation id> <tail id>" per line, each
    # rendered as a numbered fact "F<fid>: <h> is <rel> of <t>.".
    path = "../symbolic_tree/" + str(i) + ".relations.data"
    with open(path, 'r') as f:
        for line in f:
            flag, h, r, t = line.strip().split()
            triplets.append((id2ent[int(h)], id2rel[int(r)], id2ent[int(t)]))
            entpair2rel[(id2ent[int(h)], id2ent[int(t)])] = rel2sym[id2rel[int(r)]]
            basic_facts += 'F' + str(fid) + ': ' + id2ent[int(h)] + ' is ' + rel2sym[id2rel[int(r)]] + ' of ' + id2ent[int(t)] + '.\n'
            fid += 1

    # Inferred facts: only lines flagged '+' are used, as positive test cases.
    path = "../symbolic_tree/" + str(i) + ".relations.data.inf"
    with open(path, 'r') as f:
        for line in tqdm(f):
            flag, h, r, t = line.strip().split()
            if flag == '+':
                fact = (id2ent[int(h)], id2rel[int(r)], id2ent[int(t)])
                test_triplets.append(fact)
                labels[fact] = 1
    negative_samples = get_negative_samples(test_triplets, id2ent, id2rel, labels)

    num = 0
    true_num = 0
    false_num = 0

    pos_true = 0
    pos_false = 0

    neg_true = 0
    neg_false = 0

    predicted_facts = test_triplets + negative_samples
    random.shuffle(predicted_facts)
    prompts, prompts_plus = generate_few_shot_prompts(predicted_facts, model, lid, fid, rules, basic_facts, labels)

    for triple in tqdm(predicted_facts):
        h, r, t = triple
        statement = h + ' is ' + rel2sym[r] + ' of ' + t + '.'

        # Prompt: instructions, six demonstrations (Q1..A6), then the query Q7.
        # Insertion order matters: it is the order logged by dict2str and
        # concatenated into last_text.
        # NOTE(review): the demonstrations use "Unknown fact: <statement>" but
        # the query uses "Statement: <statement>". Confirm this is intended.
        message_1 = {
            'system': "You are a helpful assistant with deductive reasoning abilities. Please select one single logical rule and a few facts to predict True/False of the following statement.",
            'user': "I will provide a set of logical rules L1 to L" + str(lid - 1) + " and facts F1 to F" + str(fid - 1) + ".\nLogical rules:\n" + rules + "\nFacts:\n" + basic_facts + "\nPlease select one single logical rule from L1 to L" + str(lid - 1) + " and a few facts from F1 to F" + str(fid - 1) + " to predict True/False of the following statement using deductive reasoning.\n",
        }
        for k in range(1, 7):
            message_1['Q' + str(k)] = prompts[k - 1]['Statement']
            message_1['A' + str(k)] = prompts[k - 1]['Answer']
        message_1['Q7'] = "Statement: " + statement + '\nAnswer: '

        messages = [
            {"role": "system", "content": message_1['system']},
            {"role": "user", "content": message_1['user']},
        ]
        for k in range(1, 7):
            messages.append({"role": "user", "content": message_1['Q' + str(k)]})
            messages.append({"role": "assistant", "content": message_1['A' + str(k)]})
        messages.append({"role": "user", "content": message_1['Q7']})

        # Everything except the system message, concatenated without separators.
        last_text = ''.join(value for key, value in message_1.items() if key != 'system')

        # Only the API calls are retried, so a bug in the bookkeeping below
        # cannot trigger ten rounds of paid re-queries.
        results = None
        server_error_cnt = 0
        while server_error_cnt < 10:
            try:
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=messages,
                    temperature=0,
                )
                reasoning = response['choices'][0]['message']['content']
                message_2 = {
                    'system': "Please predict True/False of the following statement.",
                    'user': last_text + '\n' + reasoning + '\nTherefore, the answer (True or False) is: '
                }
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": message_2['system']},
                        {"role": "user", "content": message_2['user']},
                    ],
                    temperature=0,
                )
                results = response['choices'][0]['message']['content']
                break
            except Exception as e:  # API/network failure: log and retry
                server_error_cnt += 1
                logger.info(e)
        if results is None:
            # Every attempt failed: this triple is left out of all counters.
            continue

        num += 1
        if not record_flag:
            logger.info('message: \n' + dict2str(message_1))
            record_flag = True

        # The verdict is read from the text before the first '.'
        # (str.split always returns at least one element).
        last_sentence = results.split('.')[0]
        says_false = any(word in last_sentence for word in false_words)
        says_true = 'True' in last_sentence or 'true' in last_sentence or 'TRUE' in last_sentence

        if labels[(h, r, t)] == 1:
            if says_false:
                false_num += 1
                pos_false += 1
                logger.info('correctness: ' + 'Incorrect')
            elif says_true:
                true_num += 1
                pos_true += 1
                logger.info('correctness: ' + 'Correct')
            else:
                false_num += 1
                pos_false += 1
                logger.info('correctness: ' + 'Incorrect')
                print(last_sentence)
        else:
            if says_false:
                true_num += 1
                neg_true += 1
                logger.info('correctness: ' + 'Correct')
            elif says_true:
                false_num += 1
                neg_false += 1
                logger.info('correctness: ' + 'Incorrect')
            else:
                # An answer that is neither True nor False counts as correct
                # for a negative triple.
                true_num += 1
                neg_true += 1
                logger.info('correctness: ' + 'Correct')
                print(last_sentence)

        logger.info('statement: ' + statement + '\tgrounding truth: ' + str(labels[(h, r, t)]) + '\tprediction: ' + results)

    TP = pos_true
    FN = pos_false
    FP = neg_false
    TN = neg_true
    precision = _safe_ratio(TP, TP + FP)
    recall = _safe_ratio(TP, TP + FN)
    logger.info(str(i) + ': ' + str(_safe_ratio(true_num, num)))
    logger.info('pos_acc: ' + str(_safe_ratio(pos_true, pos_true + pos_false)))
    logger.info('neg_acc: ' + str(_safe_ratio(neg_true, neg_true + neg_false)))
    logger.info('pos_true:' + str(pos_true))
    logger.info('neg_true:' + str(neg_true))
    logger.info('precision: ' + str(precision))
    logger.info('recall: ' + str(recall))
    compute_f_beta_score(precision, recall, logger)


