# Neural Theorem Prover using pandas and Pytorch

## 1. Symbolic Unificaiton using pandas DataFrame
- Load Files 
- Define Functions 
- Generate Meta Tables
- Run Symbolic Unification and generate batch 

## 2. NTP Model Training with PyTorch
- Define Model Structure using PyTorch
- Define Foward Function 
- Training Model

## 3. Extract Rules from Trained Embedding Vectors
- Matching Rule templates with Embedding vectors 
- Extract Induced Rules


### import packages

In [1]:
import numpy as np
import pandas as pd
import re
import collections
import random
import copy
from pprint import pprint
# from itertools import permutations
from datetime import datetime, timedelta
from collections.abc import Iterable
# from itertools import combinations
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import itertools

# to print pandas dataframe
from IPython.display import display
pd.set_option('display.max_columns', 50)

## 1. Symbolic Unificaiton using pandas DataFrame
### Load Data Files using pandas
- KG : Knowledge Graph file with triple form
- Query : query with triple form

In [2]:
#data_name options = example_7, kinship, umls, nations
data_name = 'example_7'

In [3]:
pos_per_batch = 1 #Number of positive datas to be included in one batch
neg_per_pos = 2 #The number of negative data to be sampled per positive data
batch_size = pos_per_batch + (pos_per_batch * neg_per_pos)
print('batch_size : ', batch_size)

batch_size :  3


In [4]:
KG = pd.read_csv(f'./data/{data_name}.txt', sep='\t', names=['subj','pred','obj'])
Query = pd.read_csv(f'./data/{data_name}.txt', sep='\t', names=['subj','pred','obj'])

In [5]:
KG = KG[['pred', 'subj', 'obj']]
KG.head()

Unnamed: 0,pred,subj,obj
0,nationality,BART,USA
1,placeOfBirth,BART,NEWYORK
2,locatedIn,NEWYORK,USA
3,hasFather,BART,HOMMER
4,nationality,HOMMER,USA


In [6]:
Query = Query.sample(frac=1).reset_index(drop=True)
Query = Query[['pred', 'subj', 'obj']]
Query.head()

Unnamed: 0,pred,subj,obj
0,nationality,BART,USA
1,hasFather,BART,HOMMER
2,locatedIn,NEWYORK,USA
3,placeOfBirth,BART,NEWYORK
4,placeOfBirth,HOMMER,NEWYORK


In [7]:
entity_list = sorted(set(KG.subj.values).union(set(KG.obj.values)))
len(entity_list)

4

In [8]:
start = datetime.now()

#KG index dictionary initializing
KG_index = {}
for entity in entity_list:
    KG_index[entity] = {'subj':[], 'obj':[]}
    
subj_entities = KG['subj'].tolist()
obj_entities = KG['obj'].tolist()

#KG index dictionary generation
for i in range(len(KG)):
    KG_index[subj_entities[i]]['subj'] = KG_index.get(subj_entities[i]).get('subj')+[i]
    KG_index[obj_entities[i]]['obj'] = KG_index.get(obj_entities[i]).get('obj')+[i]

end = datetime.now() 
print('converting time : ', end-start)

converting time :  0:00:00.001000


In [9]:
KG_index

{'BART': {'subj': [0, 1, 3, 6], 'obj': []},
 'HOMMER': {'subj': [4, 5], 'obj': [3]},
 'NEWYORK': {'subj': [2], 'obj': [1, 5]},
 'USA': {'subj': [], 'obj': [0, 2, 4, 6]}}

### Load Rule template and parsing using regular expression

In [10]:
def trim(string):
    """
    - function: trim whitespaces
    :param string: an input string
    
    :return: the string without trailing whitespaces
    """
    return re.sub("\A\s+|\s+\Z", "", string)

def load_from_file(path, rule_template=False):
    """
    - function: load and parsing file
    :param path: file's location 
    :param rule_template: check rule file
    
    :return : parsed kb or rule template
    """
    with open(path, "r") as f:
        text = f.readlines()
        text = [x for x in text if not x.startswith("%") and x.strip() != ""]
        text = "".join(text)
        rules = [x for x in re.split("\.\n|\.\Z", text) if x != "" and
                 x != "\n" and not x.startswith("%")]
        kb = parse_rules(rules, rule_template=rule_template)
        return kb
    
def parse_rules(rules, rule_template=False):
    """
    - function: read file and parse rules
    input : list of strings (such as 2 #1(X, Y) :- #2(X, Z),#3(Z, W),#4(W, Y))
    output : list of lists (such as [('#1', 'X', 'Y'), ('#2', 'X', 'Z'), ('#3', 'Z', 'W'), ('#4', 'W', 'Y'), 2]) 
    """
    kb = []
    for rule in rules:
        num = rule[:rule.find('\t')]
        rule = re.findall(r'#\d+\(.*?\)', rule)
        listAtoms = [re.split('[(),]', item)[:-1] for item in rule ]
        atoms = [(pred, sub, obj.strip()) for [pred, sub, obj] in listAtoms]
        atoms.append(int(num))
        kb.append(atoms)
       
    return kb

In [11]:
rules = load_from_file(f'./data/{data_name}.nlt', rule_template=True)
rules

[[('#1', 'X', 'Y'), ('#2', 'X', 'Z'), ('#3', 'Z', 'Y'), 2],
 [('#1', 'X', 'Y'), ('#2', 'X', 'Y'), 2]]

In [12]:
rule_structure = pd.DataFrame(list(map(lambda x : [{atom[1]: 'subj', atom[2]: 'obj'} for atom in x[:-1]], rules)))
rule_structure['rule_number'] = [i for i in range(len(rules))]
rule_structure

Unnamed: 0,0,1,2,rule_number
0,"{'X': 'subj', 'Y': 'obj'}","{'X': 'subj', 'Z': 'obj'}","{'Z': 'subj', 'Y': 'obj'}",0
1,"{'X': 'subj', 'Y': 'obj'}","{'X': 'subj', 'Y': 'obj'}",,1


### Generate Dictionary from KG & Query data

In [13]:
KG_predicate_list = sorted(set(KG.pred.values).union(set(Query.pred.values)))

rule_pred_list = []
for i, rule in enumerate(rules):
    # iterate rule components
    for r in rule[:-1]:
        #iterate augmnet number
        for j in range(rule[-1]):
            suffix = '_' + str(i) + '_' + str(j)
            rule_pred_list.append(r[0]+suffix)
            
predicate_list = sorted(set(KG_predicate_list).union(set(rule_pred_list)))
# print('predicates : ',predicate_list)

In [14]:
id2sym_dict = {}
sym2id_dict = {}
sym2id_dict['UNK'] = 0
sym2id_dict['PAD'] = 1
id2sym_dict[0] = 'UNK'
id2sym_dict[1] = 'PAD'


for i, p in enumerate(predicate_list):
    sym2id_dict[p] = i+2
    id2sym_dict[i+2] = p

In [15]:
sym2id_dict

{'UNK': 0,
 'PAD': 1,
 '#1_0_0': 2,
 '#1_0_1': 3,
 '#1_1_0': 4,
 '#1_1_1': 5,
 '#2_0_0': 6,
 '#2_0_1': 7,
 '#2_1_0': 8,
 '#2_1_1': 9,
 '#3_0_0': 10,
 '#3_0_1': 11,
 'bornIn': 12,
 'hasFather': 13,
 'locatedIn': 14,
 'nationality': 15,
 'placeOfBirth': 16}

## Define Functions 

### unification
- goal: query (e.g. nationality BART USA)
- rule: rule template (e.g. #1(X,Y) :- #2(X,Z), #3(Z,Y))

- 1. 주어진 rule template의 conclusion과 query를 unify
    - unify된 트리플은 rule component substitution에 key를 rule component(e.g. #1(X,Y))로  
        value를 unified triples(dataframe)으로 저장   
    
        #1(X, Y) :
    
            |     pred    | subj | obj |
            |-------------|------|-----|
            | nationality | BART | USA |
    
    - conclusion의 X,Y와 같은 variable에 대하여 unify된 트리플을 참조하여 variable substitution에  
    X : [BART], Y: [USA] 와 같이 binding




- 2. 앞서 binding된 variable을 참조하여 각 rule body에 맞는 트리플을 unify
    - #1(X,Y)를 통해 binding된 X에 대한 variable substitution을 참조하여 #2(X,Z)와 같은 body에 트리플을 unify하는 작업을 수행
        - 위 경우에는 variable substitution을 참조하여 X가 subject인 트리플을 찾아 unify   
    - unify된 트리플은 rule component substitution에 key를 rule component(e.g. #2(X,Y))로  
        value를 unified triples(dataframe)으로 저장   
           
       #2(X, Z) :

            |     pred     | subj | obj     |
            |--------------|------|---------|
            | placeOfBirth | BART | NEWYORK |
            | hasFather    | BART | HOMMER  |    
        
    - 규칙 body의 X,Z와 같은 variable에 대하여 unify된 트리플을 참조하여 variable substitution에  
    Z : [NEWYORK, HOMMER] 와 같이 binding
    
### proof path completion
- rule component substitution : key가 rule compnent (e.g. #1(X,Y)) value가 각 rule component에 unify된 트리플(dataframe)인 dictionary 
- rule: rule template (e.g. #1(X,Y) :- #2(X,Z), #3(Z,Y))
- 1. rule template을 분석하여 인접한 rule component간의 common variable 도출 
- 2. common variable을 기준으로 unified triple을 join하여 proof path를 생성

In [16]:
#stackoverflow Check multi-value duplication in pandas
def check_duplicate(proof_path):
    
    triple_size = 3
    column_len = len(proof_path.columns)
    num_unique_triples = int(column_len/triple_size) 

    proof_path['n_unique_triples'] = \
        proof_path.apply(lambda row: len(set([tuple(row[ i*triple_size : (i+1)*triple_size ]) 
                                                    for i in range(num_unique_triples)])), axis=1)

    proof_path = proof_path[proof_path.n_unique_triples == num_unique_triples]
    proof_path
    
    return proof_path

#stackoverflow How to insert a dropped join key column from Dataframe join in order
def merge_from_common_key(current_rComp, next_rComp):

    key = current_rComp.columns.intersection(next_rComp.columns, sort=False)
    new_key = key + '_'
    d = dict(zip(key, new_key))

    proof_path = pd.merge(current_rComp, next_rComp.rename(columns=d), how='inner', 
                          left_on=key.tolist(), right_on=new_key.tolist())

    return proof_path

def proof_path_completion(rComp_substitution, rule):
    
    for depth in range(len(rule)-2):    
        if depth == 0:
            #Load dataframes
            current_rComp = rComp_substitution[rule[depth]] 
            next_rComp = rComp_substitution[rule[depth+1]]
            #merge from common key
            partial_proof_path = merge_from_common_key(current_rComp, next_rComp)

        else:
            #Load dataframes
            current_rComp = partial_proof_path
            next_rComp = rComp_substitution[rule[depth+1]]
            #merge from common key
            partial_proof_path = merge_from_common_key(current_rComp, next_rComp)

    #remove duplicate  Adjacent triple
    proof_path = partial_proof_path
    proof_path = check_duplicate(proof_path)
    
    #stackoverflow How to get that result without for loop (python)
    unified_rel_path = proof_path[[r[0] for r in rule[:-1]]].drop_duplicates()\
           .apply(lambda x: list(zip([r[0] for r in rule[:-1]], x)), axis=1).values.tolist()
    
    return unified_rel_path

def unification(goal, rule, KG, KG_index, rule_num, depth = 0, 
                variable_substitution={}, rComp_substitution={}, rule_structure=None):
    '''
    - function: 
        1. Unify Variables and store information in substitution dictionary
        2. Check Common Variable from Rules and Join each Triples
        
    :param goal: a query triple (e.g. [nationality BART USA])
    :parma rule: a given rule template (e.g. [#1(X,Y) :- #2(Y,X), 2])
    :param depth: an integer indicates rule depth
    :param variable_substitution: a dictionary which has information of unified variables
        - key: variable / value : unified entity (list)
    :param rComp_substitution : a dictionary which has information of unified rule components
        - key: rule compoenet / value : unified triples (dataframe)
    :return: proof paths generated by Symbolic Unification 
    '''

    if depth == 0 :
        rComp_substitution[rule[depth]] = pd.DataFrame(goal, index=[rule[depth][0],rule[depth][1],rule[depth][2]]).transpose()
        # subject variable binding
        if rule[depth][1] not in variable_substitution.keys():
            variable_substitution[rule[depth][1]] = [goal[1]]
        # object variable binding
        if rule[depth][2] not in variable_substitution.keys():
            variable_substitution[rule[depth][2]] = [goal[2]]   
        depth += 1
        
    if depth == len(rule)-1:

        return proof_path_completion(rComp_substitution, rule)

    else :
        common_variable = []
        current_body = rule_structure[rule_structure['rule_number'] == rule_num].iloc[0,depth]
        current_variable = list(current_body.keys())
        common_variable = [variable for variable in current_variable if variable in variable_substitution]
        cVar_position = list(map(current_body.get, common_variable)) 
        unified_cVar= list(map(variable_substitution.get, common_variable))
        
        if len(common_variable) == 1:
            cVar_index = set(itertools.chain.from_iterable(
                             list(map(lambda x : KG_index.get(x).get(cVar_position[0]), unified_cVar[0]))))

            sub_goal = KG.loc[cVar_index]
            sub_goal.columns = [rule[depth][0], rule[depth][1], rule[depth][2]]
            rComp_substitution[rule[depth]] = sub_goal

            # subject variable binding
            if rule[depth][1] not in variable_substitution.keys():
                variable_substitution[rule[depth][1]] = list(set(sub_goal[rule[depth][1]].values))
            # object variable binding
            if rule[depth][2] not in variable_substitution.keys():
                variable_substitution[rule[depth][2]] = list(set(sub_goal[rule[depth][2]].values))

        else : 
            subj_cVar_index = set(itertools.chain.from_iterable(
                                  list(map(lambda x : KG_index.get(x).get(cVar_position[0]), unified_cVar[0]))))
            obj_cVar_index = set(itertools.chain.from_iterable(
                                  list(map(lambda x : KG_index.get(x).get(cVar_position[1]), unified_cVar[1]))))
            sub_goal = KG.loc[subj_cVar_index &obj_cVar_index]
            sub_goal.columns = [rule[depth][0], rule[depth][1], rule[depth][2]]
            rComp_substitution[rule[depth]] = sub_goal      
        depth += 1
       
        return unification(goal, rule, KG, KG_index, rule_num, depth, 
                           variable_substitution, rComp_substitution,rule_structure)

## Run Symbolic Unification 


In [17]:
def create_unify_dict(unified_rel_path):
    # stakcoverflow Easy way to map values from list of list to dictionary
    unify_dict = collections.defaultdict(set)
    unified_rel_path = list(itertools.chain.from_iterable(unified_rel_path))
    for key, value in unified_rel_path:
        unify_dict[key].add(value)
    return unify_dict

def negative_sampling(unified_rel_path, unify_dict, KG):
    
    neg_unified_rel_path = copy.deepcopy(unified_rel_path)
    for path_idx, path in enumerate(neg_unified_rel_path):
        for group_idx, group in enumerate(path):
            rule_pred = group[0]
            pos_pred = unify_dict[rule_pred]
            KG_pred = set(KG['pred'])
            neg_pred = KG_pred - pos_pred

            if len(neg_pred) == 0:
                neg_unified_rel_path[path_idx][group_idx] = (rule_pred, 'UNK')
            else:
                neg_sym = random.choice(list(neg_pred))
                neg_unified_rel_path[path_idx][group_idx] = (rule_pred, neg_sym)
                
    return neg_unified_rel_path

In [18]:
def generate_batches(query, KG, rules):
    number = 0
    relation_path = []
    rule_temp_path = []
    max_path = 0
    for row in Query.itertuples(index=False):
        number += 1
        if number%100 == 0:
            print('generate proof paths : '+str(number)+'/'+str(len(Query)))
        elif number == len(Query):
            print('complete generating proof paths! : '+str(number)+'/'+str(len(Query)))
        goal = list(row)
        aug_rel_path_list = []
        aug_rule_temp_path_list = []
        for rule_num, rule in enumerate(rules):     
            unified_rel_path = unification(goal, rule, KG, KG_index,
                                                rule_num = rule_num, depth = 0, variable_substitution={}, 
                                                rComp_substitution={}, rule_structure = rule_structure)
            if len(unified_rel_path) > 0 :
                if max_path < len(unified_rel_path):
                    max_path = len(unified_rel_path)
                
                aug_rel_path = [[list(map(lambda x : sym2id_dict[x[1]], path))]*rule[-1] 
                                 for path in unified_rel_path]
                aug_rule_temp_path = [[list(map(lambda x : sym2id_dict[f'{x[0]}_{str(rule_num)}_{str(aug_num)}'], path)) 
                                      for aug_num in range(rule[-1])] for path in unified_rel_path]
                
                #negative sampling
                unify_dict = create_unify_dict(unified_rel_path)
                for i in range(neg_per_pos):
                    neg_unified_rel_path = negative_sampling(unified_rel_path, unify_dict, KG)
                    neg_aug_rel_path = [[list(map(lambda x : sym2id_dict[x[1]], path))]*rule[-1] 
                                      for path in neg_unified_rel_path]
                    aug_rel_path += neg_aug_rel_path
                aug_rule_temp_path += aug_rule_temp_path * neg_per_pos

                aug_rel_path_list.append(aug_rel_path)
                aug_rule_temp_path_list.append(aug_rule_temp_path)
            else :
                aug_rel_path_list.append([])
                aug_rule_temp_path_list.append([])
        relation_path.append(tuple(aug_rel_path_list))
        rule_temp_path.append(tuple(aug_rule_temp_path_list))

    return relation_path, rule_temp_path, max_path

relation_path, rule_temp_path, max_path = generate_batches(Query, KG, rules)

complete generating proof paths! : 7/7


In [19]:
#for debuging
# relation_path

In [20]:
#for debuging
# rule_temp_path

### data filtering
    - proof path가 없는 데이터 제거

In [21]:
def flatten(iter_object):
    for element in iter_object:
        if isinstance(element, Iterable):
            yield from flatten(element)
        else:
            yield element
            
def data_filter(path_to_query):
    path_existence = True
    if len(list(flatten(path_to_query))) == 0:
        path_existence = False
    return path_existence

relation_path = list(filter(data_filter, relation_path))
rule_temp_path = list(filter(data_filter, rule_temp_path))


## padding

In [22]:
def padding(relation_path, rule_temp_path, rules, max_path):
    
    comp_each_template = []
    for rule in rules:
        comp_each_template.append(len(rule)-1)

    single_temp_size = 1 + (1 * neg_per_pos)
    for query_idx, (rel_path_to_query, rule_temp_path_to_query) in \
        enumerate(zip(relation_path, rule_temp_path)):
        rel_path_to_query = list(rel_path_to_query)
        rule_temp_path_to_query = list(rule_temp_path_to_query)
        
        for template_idx, (rel_path_to_template, rule_temp_path_to_template) in \
            enumerate(zip(rel_path_to_query, rule_temp_path_to_query)):
            if len(rel_path_to_template) == 0:
                padding = np.ones((max_path*single_temp_size, rules[template_idx][-1], comp_each_template[template_idx]), 
                                  dtype=int).tolist()
                rel_path_to_query[template_idx] = padding
                rule_temp_path_to_query[template_idx] = padding

            elif len(rel_path_to_template) != 0:
                pad_rel_path_to_template = []
                pad_rule_temp_path_to_template = []
                num_pos_path = int(len(rel_path_to_template)/single_temp_size)
                for i in range(0, len(rel_path_to_template), num_pos_path):
                    pad_rel_path_to_template += rel_path_to_template[i:i+num_pos_path]
                    pad_rule_temp_path_to_template += rule_temp_path_to_template[i:i+num_pos_path]
                    padding = np.ones((max_path-num_pos_path, rules[template_idx][-1],comp_each_template[template_idx]),
                                      dtype=int).tolist()
                    pad_rel_path_to_template += padding
                    pad_rule_temp_path_to_template += padding
                rel_path_to_query[template_idx] = pad_rel_path_to_template
                rule_temp_path_to_query[template_idx] = pad_rule_temp_path_to_template
        relation_path[query_idx] = tuple(rel_path_to_query)
        rule_temp_path[query_idx] = tuple(rule_temp_path_to_query)
        
    return relation_path, rule_temp_path

relation_path, rule_temp_path = padding(relation_path, rule_temp_path, rules, max_path)

In [23]:
#for debuging
#rule_temp_path

In [24]:
#for debuging
#relation_path

## Train Relation Embedding

In [25]:
def l2_sim(embed_aug_rule_temp_path, embed_aug_rel_path):
    eps = 1e-6
    #stackoverflow To calculate euclidean distance between vectors in a torch tensor with multiple dimensions
    dist = torch.sqrt((embed_aug_rel_path - embed_aug_rule_temp_path).pow(2).sum(3)+eps)
    sim = torch.exp(-dist)
    return sim

In [26]:
class NTP(nn.Module):
    
    def __init__(self, vocab_size, embedding_size, batch_size, num_templates):
        super(NTP, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.embedding_matrix = nn.Embedding(self.vocab_size, self.embedding_size)
        self.loss = torch.nn.BCELoss()
        self.batch_size = batch_size
        self.template_size = num_templates
    
    def calculate_sim_avg(self, aug_rule_temp_path, aug_rel_path):
        sims_list = []

        for i in range(self.template_size):
            if len(aug_rule_temp_path[i]) == 0:
                continue
            if len(aug_rel_path[i]) == 0:
                continue
            lookup_tensor_aug_rule_temp_path = torch.tensor(aug_rule_temp_path[i], dtype=torch.long)
            lookup_tensor_aug_rel_path = torch.tensor(aug_rel_path[i], dtype=torch.long)
            embed_aug_rule_temp_path = self.embedding_matrix(lookup_tensor_aug_rule_temp_path)
            embed_aug_rel_path = self.embedding_matrix(lookup_tensor_aug_rel_path)
            sims=l2_sim(embed_aug_rule_temp_path, embed_aug_rel_path)
            avg_sims = torch.mean(sims, 2)
            sims_list.append(avg_sims)
            
        avg_sims_ = torch.cat(sims_list, dim=1)

        return avg_sims_
        
        
    def forward(self, aug_rule_temp_path, aug_rel_path):
        avg_sims = self.calculate_sim_avg(aug_rule_temp_path, aug_rel_path)
        x = torch.chunk(avg_sims, self.batch_size, dim=0)
        x = list(x)
        for i, t in enumerate(x):
            x[i] = torch.cat(torch.chunk(t, chunks =self.template_size ,dim=1), dim=0)#template
        sims = torch.cat(x, dim=0)
        max_sims = torch.max(sims, axis=1)[0]
        max_sims = max_sims.reshape(self.batch_size, -1)
        min_sims = torch.min(max_sims, axis=1)[0]
        
        return min_sims

In [27]:
num_templates = len(rules)
vocab_size = len(sym2id_dict)
embedding_size = 100
ntp = NTP(vocab_size, embedding_size, batch_size, num_templates)

In [28]:
answer = []
for i in range(pos_per_batch):
    answer += [1]
    for j in range(neg_per_pos):
        answer += [0]
answer = torch.tensor(answer, dtype=torch.float32)
answer = answer
answer

tensor([1., 0., 0.])

In [29]:
epochs = 100
report_interver_epoch = 10
optimizer = torch.optim.Adam(ntp.parameters(), lr = 0.08, weight_decay = 0.00001)
data_size = len(relation_path)
ntp.train()
time1 = datetime.now()
for epoch in range(1, epochs+1):
    for i in range(0, data_size, pos_per_batch):
        optimizer.zero_grad()
        aug_rel_path = []
        aug_rule_temp_path = []
        r1 = rule_temp_path[i:i+pos_per_batch]
        r2 = relation_path[i:i+pos_per_batch]
        if len(r1)<pos_per_batch:
            continue
        #stackoverflow An easy way to create a torch tensor from multiple elements of tuple through concatenate
        aug_rule_temp_path = [torch.Tensor([ atom for element in x for atom in element ]) for x in zip(*r1)]
        aug_rel_path = [torch.Tensor([ atom for element in x for atom in element ]) for x in zip(*r2)]
        y_hat = ntp.forward(aug_rule_temp_path, aug_rel_path)
        answer = answer
        loss = ntp.loss(y_hat, answer)
        loss.backward()
        optimizer.step()
    if epoch%report_interver_epoch == 0:
        print('Epoch: ',epoch, '      Loss : ', loss.item())

time2 = datetime.now()
print('\ntraining time : ', time2-time1)



Epoch:  10       Loss :  0.6187537908554077
Epoch:  20       Loss :  0.6109741926193237
Epoch:  30       Loss :  0.577885627746582
Epoch:  40       Loss :  0.8521082997322083
Epoch:  50       Loss :  0.48950350284576416
Epoch:  60       Loss :  0.4862099885940552
Epoch:  70       Loss :  0.453716903924942
Epoch:  80       Loss :  0.5597609877586365
Epoch:  90       Loss :  0.5289940237998962
Epoch:  100       Loss :  0.4599451720714569

training time :  0:00:03.649883


## write rule file

In [30]:
def representation_match(x, emb):
    dist = torch.torch.nn.functional.pairwise_distance(x, emb)
    sim = torch.exp(-dist)
    return sim

In [31]:
#get trained embedding matrix
for i in enumerate(ntp.parameters()):
    print(i[1])
    embeddings = i[1]

Parameter containing:
tensor([[ 1.6943e-10, -1.1226e-10, -1.0804e-09,  ...,  9.3667e-11,
         -9.7855e-11,  6.3358e-10],
        [ 9.4268e-10,  2.4035e-10,  1.9959e-09,  ...,  2.1840e-09,
         -5.1814e-11,  4.1839e-10],
        [-1.7726e-01,  1.4352e-01,  3.1908e-01,  ...,  3.9006e-01,
         -3.7725e-01,  1.7340e-01],
        ...,
        [-6.9433e-03,  1.7421e-01,  3.9733e-02,  ...,  5.4037e-01,
         -2.0568e-01,  8.2744e-01],
        [-1.9272e-01, -4.5514e-03,  3.3609e-01,  ...,  2.5940e-01,
         -3.5859e-01,  1.1728e-01],
        [-8.4288e-01, -1.6064e-01, -1.8080e-01,  ..., -2.9404e-01,
         -6.1306e-01,  3.3202e-01]], requires_grad=True)


In [32]:
#get parameterized rule template
rule_templates = {}
ids_rule_templates = {}
for rule_number, template in enumerate(rules):
    result_template_key = []
    ids_result_template_value = []
    ids_result_template_values = []
    for i in range(len(template)-1):
        rule_element=('p'+ str(int(template[i][0][1])-1), template[i][1], template[i][2])       
        result_template_key.append(rule_element)
        rule_element = ()

    for aug in range(template[-1]):
        for j in range(len(template)-1):
            ids_result_template_value.append([sym2id_dict[template[j][0]+'_'+str(rule_number)+'_'+
                                                           str(aug)], template[j][1], template[j][2]])
        ids_result_template_values.append(ids_result_template_value)
        ids_result_template_value = []
    ids_rule_templates[tuple(result_template_key)] = ids_result_template_values

ids_rule_templates

{(('p0', 'X', 'Y'),
  ('p1', 'X', 'Z'),
  ('p2', 'Z', 'Y')): [[[2, 'X', 'Y'],
   [6, 'X', 'Z'],
   [10, 'Z', 'Y']], [[3, 'X', 'Y'], [7, 'X', 'Z'], [11, 'Z', 'Y']]],
 (('p0', 'X', 'Y'), ('p1', 'X', 'Y')): [[[4, 'X', 'Y'], [8, 'X', 'Y']],
  [[5, 'X', 'Y'], [9, 'X', 'Y']]]}

In [33]:
#get rule instance & write rule file
masking_index = []
for key, value in ids_rule_templates.items():
    for rule in value:
        for element in rule:
            masking_index.append(element[0])
        
masking_index

total_reuslt = []
with open(data_name+'_rule.nl', 'w') as f:
    for key, value in ids_rule_templates.items():
        f.write(str(key)+'\n')
        for rule in value:
            result = []
            confidence_score = []
            rule_result = []
            for element in rule:
                masking_index = masking_index+[element[0]]+[0, 1]
                x = ntp.embedding_matrix(torch.tensor([element[0]]))
                match = representation_match(x, embeddings)
                match[masking_index] = 0
                top_k = torch.topk(match, 1)
                rule_result.append(id2sym_dict[top_k.indices.item()]+'('+element[1]+','+element[2]+')')
                confidence_score.append(match[top_k.indices])
            f.write(str(min(confidence_score).item())+'\t')
            head = rule_result[0]
            body = rule_result[1:]
            f.write(head + ' :- ' +", ".join(body)+'\n')  
            result.append((key, min(confidence_score).item(), rule_result))
            total_reuslt.append(result)
        f.write('\n')
total_reuslt

[[((('p0', 'X', 'Y'), ('p1', 'X', 'Z'), ('p2', 'Z', 'Y')),
   0.00034720447729341686,
   ['nationality(X,Y)', 'hasFather(X,Z)', 'nationality(Z,Y)'])],
 [((('p0', 'X', 'Y'), ('p1', 'X', 'Z'), ('p2', 'Z', 'Y')),
   0.39551132917404175,
   ['nationality(X,Y)', 'placeOfBirth(X,Z)', 'locatedIn(Z,Y)'])],
 [((('p0', 'X', 'Y'), ('p1', 'X', 'Y')),
   0.035376399755477905,
   ['bornIn(X,Y)', 'nationality(X,Y)'])],
 [((('p0', 'X', 'Y'), ('p1', 'X', 'Y')),
   0.10968772321939468,
   ['nationality(X,Y)', 'bornIn(X,Y)'])]]

In [34]:
with open(data_name+'_rule_batch'+str(batch_size)+'_epoch'+str(epochs)+
          '_sorted_'+str(time2)[11:13]+str(time2)[14:16]+'.nl', 'w') as file:
    with open(data_name+'_rule.nl', 'r') as f:
        augment = rules[0][-1]
        scores = []
        total_scores = []
        rule = []
        total_rules = []
        count = 0
        for line in f:

            if '.' not in line.split('\t')[0]:
                file.write(line.split('\t')[0])
            if '.' in line.split('\t')[0]:
                count+=1

                scores.append(round(float(line.split('\t')[0]), 8))
                rule.append(line.split('\t')[-1])
                if count % augment == 0:
                    count = 0
                    total_scores.append(scores)
                    total_rules.append(rule)
                    s = torch.sort(torch.tensor(scores), descending=True).values
                    r = torch.sort(torch.tensor(scores), descending=True).indices

                    for i in range(augment):
                        file.write(str(round(s[i].item(), 8))+'\t')
                        file.write(rule[r[i].item()])
                    scores = []
                    rule = []