In [1]:
import numpy as np
from itertools import permutations
from collections import defaultdict
import random

# load and parse dataset

In [2]:
!file kinships -I

kinships: text/plain; charset=us-ascii


In [3]:
raw_data = []

entities = set()
with open('kinships', 'r') as to_read:
    for i, line in enumerate(to_read.readlines()):
        s, p, o = line.strip().split(' ')
        entities.add(s)
        entities.add(o)
        raw_data += [(s,p,o)]

In [4]:
A_implies_A_rules = [
    ('term18',),
]

# term5(x1, x2) => term15(x2, x1)
A_implies_B_rules = [
    ('term5', 'term15'),
]

A_B_implies_C_rules = [
    ('term2', 'term22', 'term15'),
]

#java -jar /Users/simon/Office/Dokumente/Uni/Data\ Science\ and\ Machine\ Learning\ Master/Masters\ Project/Libraries/amie-dev.jar -d " " -minc 0.8 -mins 450 -maxad 2 kinships

#java -jar /Users/simon/Office/Dokumente/Uni/Data\ Science\ and\ Machine\ Learning\ Master/Masters\ Project/Libraries/amie-dev.jar -d " " -minc 0.8 -mins 30 -maxad 3 kinships | grep '?h  ?h'

In [5]:
train = set()
valid = set()
test = set()
entities = set()

A_implies_A_rule_examples = defaultdict(lambda: [])
A_implies_B_rule_examples = defaultdict(lambda: [])
A_B_implies_C_rule_examples = defaultdict(lambda: [])

counter_A_implies_A_rules = defaultdict(lambda: 0)
counter_A_implies_B_rules = defaultdict(lambda: 0)
counter_A_B_implies_C_rules = defaultdict(lambda: 0)

for s,p,o in raw_data:
    entities.add(s)
    entities.add(o)

for x1, x2 in permutations(entities, 2):
    for (A,) in A_implies_A_rules:
        if (x1, A, x2) in raw_data and (x2, A, x1) in raw_data:
            valid.add((x2, A, x1))
            A_implies_A_rule_examples[(A,)] += [(x1, x2)]
            counter_A_implies_A_rules[(A,)] += 1

for x1, x2 in permutations(entities, 2):
    for (A, B) in A_implies_B_rules:
        if (x1, A, x2) in raw_data and (x2, B, x1) in raw_data:
            valid.add((x2, B, x1))
            A_implies_B_rule_examples[(A, B)] += [(x1, x2)]
            counter_A_implies_B_rules[(A, B)] += 1

for x1, x2, x3 in permutations(entities, 3):
    for (A, B, C) in A_B_implies_C_rules:    
        if (x1, A, x2) in raw_data and (x2, B, x3) in raw_data and (x1, C, x3) in raw_data:
            valid.add((x1, C, x3))
            A_B_implies_C_rule_examples[(A, B, C)] += [(x1, x2, x3)]
            counter_A_B_implies_C_rules[(A, B, C)] += 1

for s,p,o in raw_data:
    if (s,p,o) not in valid:
        train.add((s,p,o))

In [6]:
train = list(train)
valid = list(valid)
random.Random(42).shuffle(valid)
valid, test = valid[:len(valid) // 2], valid[len(valid) // 2:]

In [7]:
print(len(train))
print(len(valid))
print(len(test))

9582
552
552


# check that splits are mutually exclusive

In [8]:
for triple in train:
    if triple in valid:
        print("valid", triple)
    if triple in test:
        print("valid", triple)

In [9]:
for triple in valid:
    if triple in train:
        print("train", triple)
    if triple in test:
        print("test", triple)

In [10]:
for triple in test:
    if triple in train:
        print("train", triple)
    if triple in valid:
        print("valid", triple)

# save splits as .tsv

In [11]:
with open("train.tsv", "w", encoding='utf-8') as f:
    for triple in train:
        f.write("{}\t{}\t{}\n".format(*triple))

In [12]:
with open("valid.tsv", "w", encoding='utf-8') as f:
    for triple in valid:
        f.write("{}\t{}\t{}\n".format(*triple))

In [13]:
with open("test.tsv", "w", encoding='utf-8') as f:
    for triple in test:
        f.write("{}\t{}\t{}\n".format(*triple))

In [14]:
for (A,), examples in A_implies_A_rule_examples.items():
    with open("{}=>{}.tsv".format(A,A), "w", encoding='utf-8') as f:
        for (x1, x2) in examples:
            f.write("{}\t{}\t{}\n".format(x2, A, x1))

for (A, B), examples in A_implies_B_rule_examples.items():
    with open("{}=>{}.tsv".format(A,B), "w", encoding='utf-8') as f:
        for (x1, x2) in examples:
            f.write("{}\t{}\t{}\n".format(x2, B, x1))

for (A, B, C), examples in A_B_implies_C_rule_examples.items():
    with open("{},{}=>{}.tsv".format(A,B,C), "w", encoding='utf-8') as f:
        for (x1, x2, x3) in examples:
            f.write("{}\t{}\t{}\n".format(x1, C, x3))