In [14]:
import numpy as np
import os
import sys


def _read_triples(path):
    """Yield ``(e1, e2, r)`` tuples parsed from a tab-separated triple file.

    Blank lines are skipped. Any other line that does not split into exactly
    three tab-separated fields raises ``ValueError``.
    """
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            e1, e2, r = line.split('\t')
            yield e1, e2, r


def get_relations_by_type(data_dir, theta_1_to_M=1.5):
    """Split the relations found in ``dev.triples`` into to-M and to-1 sets.

    For each relation, the triples in ``raw.kb`` and ``train.triples`` are
    grouped by their first field, and the distinct second fields seen for
    each first field are collected. A relation whose mean number of distinct
    second fields per first field exceeds ``theta_1_to_M`` is classified as
    to-M; every other relation is classified as to-1.

    Args:
        data_dir: Directory containing ``raw.kb``, ``train.triples`` and
            ``dev.triples``. Each non-blank line is ``e1<TAB>e2<TAB>r``.
        theta_1_to_M: Threshold on the mean answer count; strictly greater
            means to-M.

    Returns:
        A ``(to_M_rels, to_1_rels)`` tuple of sets of relation strings. The
        two sets partition the relations that occur in ``dev.triples``.

    Raises:
        OSError: If one of the three files cannot be opened.
        ValueError: If a non-blank line does not have exactly three fields.

    Side effects:
        Prints the number of dev relations and the to-M / to-1 breakdown.
    """
    # relation -> e1 -> set of e2. Storing answers in sets already removes
    # duplicate triples, so no separate de-duplication pass is needed.
    query_answers = dict()
    for file_name in ('raw.kb', 'train.triples'):
        for e1, e2, r in _read_triples(os.path.join(data_dir, file_name)):
            query_answers.setdefault(r, dict()).setdefault(e1, set()).add(e2)

    dev_rels = {r for _, _, r in _read_triples(os.path.join(data_dir, 'dev.triples'))}

    num_rels = len(dev_rels)
    print('{} relations in dev dataset in total'.format(num_rels))

    to_M_rels = set()
    to_1_rels = set()
    for r in dev_rels:
        answers_by_head = query_answers.get(r)
        if not answers_by_head:
            # The relation occurs in dev but not in raw.kb/train.triples, so
            # there is no evidence of multiple answers: count it as to-1
            # rather than failing with a KeyError.
            to_1_rels.add(r)
            continue
        ratio = np.mean([len(x) for x in answers_by_head.values()])
        if ratio > theta_1_to_M:
            to_M_rels.add(r)
        else:
            to_1_rels.add(r)

    # Counts are kept as floats so the printed report keeps its format.
    num_to_M = len(to_M_rels) + 0.0
    num_to_1 = len(to_1_rels) + 0.0
    # Guard against an empty dev set, which would otherwise divide by zero.
    frac_to_M = num_to_M / num_rels if num_rels else 0.0
    frac_to_1 = num_to_1 / num_rels if num_rels else 0.0
    print('to-M: {}/{} ({})'.format(num_to_M, num_rels, frac_to_M))
    print('to-1: {}/{} ({})'.format(num_to_1, num_rels, frac_to_1))
    return to_M_rels, to_1_rels

    
def main(data_dir=None):
    """Report relation types and the number of unique dev/test facts.

    Args:
        data_dir: Directory holding the ``*.triples`` files. When omitted,
            the original hard-coded FB15K-237-10 location is used.

    Side effects:
        Prints the report from ``get_relations_by_type`` followed by the
        number of distinct non-blank lines across ``dev.triples`` and
        ``test.triples``.
    """
    if data_dir is None:
        dataset = 'FB15K-237-10'
        data_dir = os.path.join('/home/yhz/miniconda3/PAAR-main/data', dataset)
    get_relations_by_type(data_dir)

    facts = set()
    for file_name in ('dev.triples', 'test.triples'):
        with open(os.path.join(data_dir, file_name)) as f:
            for line in f:
                # Compare stripped lines so that a final line without a
                # newline still matches its duplicate, and blank lines are
                # not counted as facts.
                fact = line.strip()
                if fact:
                    facts.add(fact)
    print('{} unique facts'.format(len(facts)))


# Guarded so that importing this module does not trigger file I/O.
if __name__ == '__main__':
    main()

216 relations in dev dataset in total
to-M: 55.0/216 (0.25462962962962965)
to-1: 161.0/216 (0.7453703703703703)
33774 unique facts
