In [1]:
%load_ext gvmagic

In [2]:
from problog.engine import DefaultEngine
from problog.formula import LogicFormula
from problog.logic import *
from problog.program import PrologString
from data_processing import clear_data_part
from classifiers import SimpleClassifier
import pandas as pd
import problog
import numpy as np
import json
from scipy.stats import friedmanchisquare

  from ._conv import register_converters as _register_converters


In [3]:
embeddings_enta = pd.read_csv('entity_embeddingsa.csv')
embeddings_entb = pd.read_csv('entity_embeddingsb.csv')
embeddings_rel = pd.read_csv('relation_embeddings.csv')
embeddings_rule = pd.read_csv('rule_embeddings.csv')

In [4]:
embeddings_rule = embeddings_rule.drop(['Unnamed: 0'],axis=1)

In [5]:
with open('dict_neighbors.json') as json_file:
    neighbors = json.load(json_file)

In [6]:
model = SimpleClassifier()

In [7]:
print('opening facts and rules')
with open('rules_and_facts.txt', 'r') as myfile:
    facts_and_rules = myfile.readlines()

facts_and_rules_Prolog = PrologString("\n".join(facts_and_rules))

opening facts and rules


In [8]:
test_data = pd.read_csv('test_data.csv')
test_data = test_data[['entity a','entity b','rel id']]
test_data = test_data.drop_duplicates()

In [9]:
row = test_data.iloc[0]
print(row)
rel_id = row['rel id']
ent_a = row['entity a']
ent_b = row['entity b']
query = Term('q'+str(rel_id),Term('ent'+str(ent_a)),Term('ent'+str(ent_b)))

entity a    2600
entity b    3465
rel id        69
Name: 0, dtype: int64


In [10]:
def get_embedding_action(action,predicates,embeddings_enta,embeddings_entb,embeddings_rel,embeddings_rule,randenta,randentb):
    """
    function that concatenates the embeddings given the action = relation, entity a, entity b and rule
    """
    node = action[1]
    functor = db.get_node(node).functor
    for pred in predicates:
        if pred.functor == functor:
            enta = term2str(pred.args[0])
            entb = term2str(pred.args[1])
            if enta.find('ent')==-1:
                enta = randenta
                enta_emb = embeddings_enta[embeddings_enta['entity a']==enta]
                enta_emb = enta_emb.drop(['entity a'],axis=1)
                enta_emb = enta_emb.values.tolist()[0]
            else:
                enta = int(enta[enta.find('ent')+3:])
                enta_emb = embeddings_enta[embeddings_enta['entity a']==enta]
                enta_emb = enta_emb.drop(['entity a'],axis=1)
                enta_emb = enta_emb.values.tolist()[0]

            if entb.find('ent')==-1:
                entb = randentb
                entb_emb = embeddings_entb[embeddings_entb['entity b']==entb]
                entb_emb = entb_emb.drop(['entity b'],axis=1)
                entb_emb = entb_emb.values.tolist()[0]
            else:
                entb = int(entb[entb.find('ent')+3:])
                entb_emb = embeddings_entb[embeddings_entb['entity b']==entb]
                entb_emb = entb_emb.drop(['entity b'],axis=1)
                entb_emb = entb_emb.values.tolist()[0]


    rule_id = db.get_node(node).args[-1]
    rel_id = int(functor[functor.find('rel')+3:])
    rule_emb = embeddings_rule[(embeddings_rule['rule id']==rule_id) & (embeddings_rule['rel id']==rel_id)]
    rule_emb = rule_emb.drop(['rel id','rule id'],axis=1)
    rule_emb = rule_emb.values.tolist()[0]
    rel_emb = embeddings_rel[embeddings_rel['rel id']==rel_id]
    rel_emb = rel_emb.drop(['rel id'],axis=1)
    rel_emb = rel_emb.values.tolist()[0]


    features = enta_emb + entb_emb + rel_emb + rule_emb
    return [rel_id] + [enta] + [entb] + [rule_id] + features

In [11]:
def rerank_actions(actions,predicates,embeddings_enta,embeddings_entb,embeddings_rel,embeddings_rule,model,neighbors,rand_dict):
    actions_to_rerank = []
    reranked_actions = []
    pos = 0
    for act in actions:
        node = act[1]
        if (db.get_node(node).functor,db.get_node(node).args[:-1]) in rand_dict.keys():
            randenta = rand_dict[(db.get_node(node).functor,db.get_node(node).args[:-1])][0]
            randentb = rand_dict[(db.get_node(node).functor,db.get_node(node).args[:-1])][-1]
        else:
            used_preds = []
            pred_pos = 0
            for pred in predicates:
                if pred.functor == db.get_node(node).functor and pred_pos not in used_preds:
                    used_preds.append(pred_pos)
                    pred_pos+=1
                    enta = term2str(pred.args[0])
                    entb = term2str(pred.args[1])
                    if enta.find('ent')==-1:
                        entb = int(entb[entb.find('ent')+3:])
                        randentb = entb
                        randenta = neighbors[str(entb)][np.random.permutation(len(neighbors[str(entb)]))[0]]
                        rand_dict[(db.get_node(node).functor,db.get_node(node).args[:-1])]=[randenta,randentb]
                    elif entb.find('ent')==-1:
                        enta = int(enta[enta.find('ent')+3:])
                        randenta = enta
                        randentb = neighbors[str(enta)][np.random.permutation(len(neighbors[str(enta)]))[0]]
                        rand_dict[(db.get_node(node).functor,db.get_node(node).args[:-1])]=[randenta,randentb]
                    else:
                        randenta = np.random.permutation(14000)[0]
                        randentb = np.random.permutation(14000)[0]
                        rand_dict[(db.get_node(node).functor,db.get_node(node).args[:-1])]=[randenta,randentb]

        db_term = term2str(db.get_node(node))
        if db_term[:db_term.find('(')]=='clause':
            data = get_embedding_action(act,predicates,embeddings_enta,embeddings_entb,embeddings_rel,embeddings_rule,randenta,randentb)
            actions_to_rerank.append([pos] + data)
        pos+=1
    Reranked_Meta = pd.DataFrame([])
    if len(actions_to_rerank)>1:
        actions_to_rerank = pd.DataFrame(actions_to_rerank)
        Meta = actions_to_rerank[[0,1,2,3,4]]
        Data = actions_to_rerank.drop([0,1,2,3,4],axis=1)
        predictions = model.predict_proba('./classification_models/model_transe_50/',Data)
        heuristic = predictions[:,-1]
        Meta['heuristic'] = heuristic
        Meta = Meta.rename(columns={0:'position',1:'rel id',2:'ent a',3:'ent b',4:'rule id'})
        uniques = []
        for meta_i in range(1,len(Meta)):
            row = Meta.iloc[meta_i]
            if [row['rel id'],row['ent a'], row['ent b']] not in uniques:
                uniques.append([row['rel id'],row['ent a'], row['ent b']])
        for unique in uniques:
            rel_id = unique[0]
            enta = unique[1]
            entb = unique[2]
            Selected = Meta[(Meta['rel id']==rel_id) & (Meta['ent a']== enta) & (Meta['ent b']==entb)]
            Selected = Selected.sort_values(by=['heuristic'])
            #print(Selected)
            Reranked_Meta = Reranked_Meta.append(Selected)
        act_counter = 0
        for act_i in range(0,len(actions)):
            if act_i not in Reranked_Meta['position']:
                reranked_actions.append(actions[act_i])
            else:
                reranked_actions.append(actions[Reranked_Meta.iloc[act_counter]['position']])
                act_counter+=1

        return reranked_actions, rand_dict
    else:
        return actions,rand_dict

In [12]:
#no_rerank_results = pd.read_csv('results_no_rerank.csv')
#pd_results = pd.read_csv('pd_results.csv')
#results = list(pd_results['0'])
#no_rerank_results = list(no_rerank_results['no rerank'])
results = []
for test_sample in range(0,len(test_data)):
    try:
        print(test_sample)
        row = test_data.iloc[test_sample]
        rel_id = row['rel id']
        ent_a = row['entity a']
        ent_b = row['entity b']
        query = Term('q'+str(rel_id),Term('ent'+str(ent_a)),Term('ent'+str(ent_b)))
        # Perform incremental grounding.
        # This is a split-up of the engine.execute method.

        # Initialize the engine with options:
        #   - unbuffered: don't buffer results internally in the nodes (mimic depth-first construction of target)
        #   - rc_first: first process 'result' and 'complete' messages (allows stopping on 'evaluation' message)
        #   - label_all: (optional) label all intermediate nodes with their predicate
        engine = problog.engine.DefaultEngine(unbuffered=True, rc_first=True, label_all=True)

        # Target formula
        #   - keep_all: don't collapse non-probabilistic subformula's => only for visualization
        target = problog.formula.LogicFormula(keep_all=True)

        db = engine.prepare(facts_and_rules_Prolog)

        # Start the incremental grounding.
        # The result is a list of 'evaluation' actions.
        actions = list(reversed(engine.ground_step(db, query, gp=target)))

        i = 0
        # Execute until no more 'evaluation' actions can be performed.
        rand_dict={}
        while actions:
            actions = engine.execute_step(actions, steps=1, target=target, name=(False, query, 'query'))

            # HERE YOU CAN DO WHATEVER YOU WANT WITH THE ACTION LIST

            # Below is just generating some output.
            i += 1

            #print('==== STEP %d ====' % i)


            predicates = []
            # Go through the engine's stack and extract predicate evaluation nodes ('EvalDefine')
            for rec in engine.stack:
                if type(rec).__name__ == 'EvalDefine':  # TODO: we should also include 'EvalOr'?
                    nodes = set(b for a, b, in rec.results.results)  # 'target' nodes associated with this evaluation node
                    predicates.append(problog.logic.Term(rec.call[0], *rec.call[1]))
            #print(predicates)
            #actions,rand_dict = rerank_actions(actions,predicates,embeddings_enta,embeddings_entb,embeddings_rel,embeddings_rule,model,neighbors,rand_dict)

            #for act in actions:
            #    print(db.get_node(act[1]))

            if type(engine.stack[0]).__name__=='EvalDefine':
                trigger_nodes = set(b for a,b, in engine.stack[0].results.results)

            #print ('Active predicates:')
            active_nodes = set()  # These are the nodes in 'target' that are still active.
            # Go through the engine's stack and extract predicate evaluation nodes ('EvalDefine')
            for rec in engine.stack:
                if type(rec).__name__ == 'EvalDefine':  # TODO: we should also include 'EvalOr'?
                    nodes = set(b for a, b, in rec.results.results)  # 'target' nodes associated with this evaluation node
                    #print ('\t', problog.logic.Term(rec.call[0], *rec.call[1]), list(nodes))
                    active_nodes |= nodes # union
            #print ('Active nodes:', list(active_nodes))

            # Visualize and print the current logic program.
            #%dotstr target.to_dot(nodeprops={n: 'fillcolor="red"' for n in active_nodes})

            if len(list(trigger_nodes))>0:
                print(str(test_sample)+ ':solution found ('+str(i)+')')
                results.append(i)
                pd_results = pd.DataFrame(results)
                pd_results.to_csv('pd_results.csv',index=False)
                break

            if i>1000:
                print(str(test_sample)+ ':too long')
                results.append(0)
                pd_results = pd.DataFrame(results)
                pd_results.to_csv('pd_results.csv',index=False)
                break

            if len(actions)==0:
                print(str(test_sample)+ ':solution not found')
                results.append(-1)
                pd_results = pd.DataFrame(results)
                pd_results.to_csv('pd_results.csv',index=False)
                break

    except:
        print('cycle')
        results.append(-2)
        pass   

0
0:solution found (8)
1
1:solution found (8)
2
2:too long
3
3:solution found (10)
4
4:too long
5
5:solution found (7)
6
6:solution found (7)
7
7:solution found (7)
8
8:solution found (7)
9
9:solution found (9)
10
10:solution found (9)
11
11:too long
12
12:too long
13
13:too long
14
14:solution found (11)
15
15:too long
16
16:solution found (10)
17
17:too long
18
18:solution found (7)
19
19:solution found (9)
20
20:solution found (7)
21
21:too long
22
22:solution found (8)
23
23:too long
24
24:solution found (8)
25
25:solution found (9)
26
26:solution found (56)
27
27:solution found (9)
28
28:solution not found
29
29:solution found (9)
30
30:solution found (34)
31
31:too long
32
32:solution found (8)
33
33:solution found (10)
34
34:too long
35
35:too long
36
36:solution found (7)
37
37:solution found (9)
38
38:solution found (8)
39
39:solution found (9)
40
40:solution found (11)
41
41:solution found (8)
42
42:solution found (8)
43
43:too long
44
44:solution found (9)
45
45:solution fou

342:solution found (9)
343
343:solution found (7)
344
344:solution found (8)
345
345:solution found (8)
346
346:solution found (7)
347
347:solution found (14)
348
348:solution found (8)
349
349:solution found (8)
350
350:solution found (11)
351
351:solution found (8)
352
352:solution found (8)
353
353:solution found (33)
354
354:too long
355
355:solution found (54)
356
356:solution found (13)
357
357:solution found (10)
358
358:solution found (9)
359
359:too long
360
360:too long
361
361:too long
362
362:solution found (9)
363
363:solution found (8)
364
364:solution found (10)
365
365:too long
366
366:solution found (7)
367
367:solution found (9)
368
368:solution found (9)
369
369:too long
370
370:too long
371
371:too long
372
372:solution found (8)
373
373:solution found (8)
374
374:solution found (9)
375
375:too long
376
376:solution found (9)
377
377:solution found (32)
378
378:solution found (7)
379
379:solution found (12)
380
380:too long
381
381:solution found (8)
382
382:too lon

685:solution found (10)
686
686:solution found (8)
687
687:solution found (9)
688
688:solution found (12)
689
689:too long
690
690:solution found (8)
691
691:solution found (8)
692
692:too long
693
693:too long
694
694:too long
695
695:solution found (9)
696
696:solution found (9)
697
697:too long
698
698:too long
699
699:solution found (7)
700
700:too long
701
701:solution found (11)
702
702:solution found (7)
703
703:solution found (7)
704
704:solution found (9)
705
705:solution found (10)
706
706:solution found (8)
707
707:solution found (7)
708
708:too long
709
709:solution found (11)
710
710:solution found (9)
711
711:too long
712
712:solution found (7)
713
713:solution found (8)
714
714:solution found (8)
715
715:solution found (11)
716
716:too long
717
717:solution found (7)
718
718:solution found (7)
719
719:too long
720
720:solution found (10)
721
721:solution found (7)
722
722:too long
723
723:too long
724
724:solution found (8)
725
725:solution found (7)
726
726:solution fou

1024:solution found (8)
1025
1025:too long
1026
1026:solution found (12)
1027
1027:solution found (7)
1028
1028:solution found (8)
1029
1029:solution found (8)
1030
1030:too long
1031
1031:solution found (9)
1032
1032:solution found (14)
1033
1033:too long
1034
1034:solution found (10)
1035
1035:too long
1036
1036:solution found (9)
1037
1037:solution found (7)
1038
1038:solution found (8)
1039
1039:solution found (7)
1040
1040:solution found (9)
1041
1041:solution found (8)
1042
1042:solution found (9)
1043
1043:solution found (7)
1044
1044:solution found (8)
1045
1045:solution found (9)
1046
1046:solution found (7)
1047
1047:solution found (7)
1048
1048:too long
1049
1049:too long
1050
1050:solution found (8)
1051
1051:solution found (9)
1052
1052:too long
1053
1053:solution found (9)
1054
1054:solution found (10)
1055
1055:solution found (10)
1056
1056:too long
1057
1057:solution found (8)
1058
1058:solution found (8)
1059
1059:solution found (7)
1060
1060:solution found (7)
1061
10

1337:solution found (9)
1338
1338:too long
1339
1339:solution found (8)
1340
1340:solution found (7)
1341
1341:solution found (7)
1342
1342:too long
1343
1343:solution found (8)
1344
1344:too long
1345
1345:solution found (7)
1346
1346:solution found (10)
1347
1347:solution found (8)
1348
1348:too long
1349
1349:too long
1350
1350:solution found (8)
1351
1351:solution found (9)
1352
1352:solution found (10)
1353
1353:solution found (7)
1354
1354:solution found (8)
1355
1355:solution not found
1356
1356:solution found (9)
1357
1357:too long
1358
1358:too long
1359
1359:too long
1360
1360:too long
1361
1361:solution found (13)
1362
1362:solution found (7)
1363
1363:too long
1364
1364:too long
1365
1365:solution not found
1366
1366:solution found (10)
1367
1367:solution found (8)
1368
1368:solution found (7)
1369
1369:solution found (7)
1370
1370:solution found (11)
1371
1371:solution found (29)
1372
1372:solution found (8)
1373
1373:solution found (7)
1374
1374:solution found (7)
1375
13

1652:solution found (7)
1653
1653:solution found (9)
1654
1654:too long
1655
1655:solution found (8)
1656
1656:solution found (10)
1657
1657:solution found (9)
1658
1658:solution found (8)
1659
1659:too long
1660
1660:too long
1661
1661:solution found (8)
1662
1662:solution found (7)
1663
1663:solution found (10)
1664
1664:too long
1665
1665:solution found (9)
1666
1666:solution found (10)
1667
1667:solution found (9)
1668
1668:solution found (9)
1669
1669:solution found (7)
1670
1670:too long
1671
1671:solution found (14)
1672
1672:solution found (15)
1673
1673:too long
1674
1674:solution found (8)
1675
1675:solution found (7)
1676
1676:solution found (9)
1677
1677:solution found (8)
1678
1678:solution found (9)
1679
1679:solution found (9)
1680
1680:solution found (29)
1681
1681:too long
1682
1682:too long
1683
1683:too long
1684
1684:solution found (9)
1685
1685:too long
1686
1686:solution found (7)
1687
1687:solution found (7)
1688
1688:solution found (9)
1689
1689:solution found (

1969:solution found (9)
1970
1970:solution found (7)
1971
1971:solution found (9)
1972
1972:solution found (7)
1973
1973:too long
1974
1974:too long
1975
1975:too long
1976
1976:solution not found
1977
1977:solution found (12)
1978
1978:solution found (9)
1979
1979:solution found (30)
1980
1980:solution found (11)
1981
1981:solution found (9)
1982
1982:too long
1983
1983:solution found (11)
1984
1984:too long
1985
1985:solution found (8)
1986
1986:solution found (7)
1987
1987:solution found (9)
1988
1988:solution found (7)
1989
1989:too long
1990
1990:solution found (8)
1991
1991:too long
1992
1992:solution found (8)
1993
1993:solution found (12)
1994
1994:solution found (9)
1995
1995:too long
1996
1996:solution found (14)
1997
1997:solution found (9)
1998
1998:too long
1999
1999:solution found (7)
2000
2000:solution found (11)
2001
2001:solution found (8)
2002
2002:solution found (10)
2003
2003:solution found (7)
2004
2004:too long
2005
2005:solution found (9)
2006
2006:solution found

In [14]:
test_data['no rerank'] = results

In [15]:
test_data.to_csv('results_no_rerank.csv',index=False)

In [None]:
no_rerank_results = pd.read_csv('results_no_rerank.csv')
results = list(pd_results['0'])
#results = []
no_rerank_results = list(no_rerank_results['no rerank'])
for test_sample in range(0,len(test_data)):
    if no_rerank_results[test_sample]>30:
        try:
            print(test_sample)
            row = test_data.iloc[test_sample]
            rel_id = row['rel id']
            ent_a = row['entity a']
            ent_b = row['entity b']
            query = Term('q'+str(rel_id),Term('ent'+str(ent_a)),Term('ent'+str(ent_b)))
            # Perform incremental grounding.
            # This is a split-up of the engine.execute method.

            # Initialize the engine with options:
            #   - unbuffered: don't buffer results internally in the nodes (mimic depth-first construction of target)
            #   - rc_first: first process 'result' and 'complete' messages (allows stopping on 'evaluation' message)
            #   - label_all: (optional) label all intermediate nodes with their predicate
            engine = problog.engine.DefaultEngine(unbuffered=True, rc_first=True, label_all=True)

            # Target formula
            #   - keep_all: don't collapse non-probabilistic subformula's => only for visualization
            target = problog.formula.LogicFormula(keep_all=True)

            db = engine.prepare(facts_and_rules_Prolog)

            # Start the incremental grounding.
            # The result is a list of 'evaluation' actions.
            actions = list(reversed(engine.ground_step(db, query, gp=target)))

            i = 0
            # Execute until no more 'evaluation' actions can be performed.
            rand_dict={}
            while actions:
                actions = engine.execute_step(actions, steps=1, target=target, name=(False, query, 'query'))

                # HERE YOU CAN DO WHATEVER YOU WANT WITH THE ACTION LIST

                # Below is just generating some output.
                i += 1

                #print('==== STEP %d ====' % i)


                predicates = []
                # Go through the engine's stack and extract predicate evaluation nodes ('EvalDefine')
                for rec in engine.stack:
                    if type(rec).__name__ == 'EvalDefine':  # TODO: we should also include 'EvalOr'?
                        nodes = set(b for a, b, in rec.results.results)  # 'target' nodes associated with this evaluation node
                        predicates.append(problog.logic.Term(rec.call[0], *rec.call[1]))
                #print(predicates)
                actions,rand_dict = rerank_actions(actions,predicates,embeddings_enta,embeddings_entb,embeddings_rel,embeddings_rule,model,neighbors,rand_dict)

                #for act in actions:
                #    print(db.get_node(act[1]))

                if type(engine.stack[0]).__name__=='EvalDefine':
                    trigger_nodes = set(b for a,b, in engine.stack[0].results.results)

                #print ('Active predicates:')
                active_nodes = set()  # These are the nodes in 'target' that are still active.
                # Go through the engine's stack and extract predicate evaluation nodes ('EvalDefine')
                for rec in engine.stack:
                    if type(rec).__name__ == 'EvalDefine':  # TODO: we should also include 'EvalOr'?
                        nodes = set(b for a, b, in rec.results.results)  # 'target' nodes associated with this evaluation node
                        #print ('\t', problog.logic.Term(rec.call[0], *rec.call[1]), list(nodes))
                        active_nodes |= nodes # union
                #print ('Active nodes:', list(active_nodes))

                # Visualize and print the current logic program.
                #%dotstr target.to_dot(nodeprops={n: 'fillcolor="red"' for n in active_nodes})

                if len(list(trigger_nodes))>0:
                    print(str(test_sample)+ ':solution found ('+str(i)+')')
                    results.append(i)
                    pd_results = pd.DataFrame(results)
                    pd_results.to_csv('pd_results.csv',index=False)
                    break

                if i>100:
                    print(str(test_sample)+ ':too long')
                    results.append(0)
                    pd_results = pd.DataFrame(results)
                    pd_results.to_csv('pd_results.csv',index=False)
                    break

                if len(actions)==0:
                    print(str(test_sample)+ ':solution not found')
                    results.append(-1)
                    pd_results = pd.DataFrame(results)
                    pd_results.to_csv('pd_results.csv',index=False)
                    break

        except:
            print('cycle')
            results.append(-2)
            pd_results = pd.DataFrame(results)
            pd_results.to_csv('pd_results.csv',index=False)
            pass   

26


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


cycle
30
cycle
103
103:solution found (15)
120
cycle
138
138:too long
157
cycle
162
cycle
165
cycle
170
cycle
244
cycle
249
249:solution found (93)
353
cycle
355
355:solution found (18)
377
cycle
433
433:too long
464
cycle
577
cycle
590
