In [2]:
# Dataset root directories. These are relative paths, so they are resolved
# against the current working directory.
FB15K_DATASET_PATH = "../data/FB15K-237"
WN18RR_DATASET_PATH = "../data/WN18RR"

# Listing files inside a dataset root; the loaders below read the first
# tab-separated column of each line as the identifier.
ENTITIES_FILENAME = "entity2id.txt"
RELATIONS_FILENAME = "relation2id.txt"

# Tab-separated sample files, one per split, inside a dataset root.
TRAINING_DATASET_FILENAME = "train.txt"
VALIDATION_DATASET_FILENAME = "valid.txt"
TEST_DATASET_FILENAME = "test.txt"

In [3]:
def get_entities_filename(dataset_path):
    """Return the path of the entities listing file inside ``dataset_path``."""
    entities_path = os.path.join(dataset_path, ENTITIES_FILENAME)
    return entities_path


def get_relations_filename(dataset_path):
    """Return the path of the relations listing file inside ``dataset_path``."""
    relations_path = os.path.join(dataset_path, RELATIONS_FILENAME)
    return relations_path


def get_training_samples_path(dataset_path):
    """Return the path of the training-split samples file inside ``dataset_path``."""
    training_path = os.path.join(dataset_path, TRAINING_DATASET_FILENAME)
    return training_path


def get_validation_samples_path(dataset_path):
    """Return the path of the validation-split samples file inside ``dataset_path``."""
    validation_path = os.path.join(dataset_path, VALIDATION_DATASET_FILENAME)
    return validation_path


def get_test_samples_path(dataset_path):
    """Return the path of the test-split samples file inside ``dataset_path``."""
    test_path = os.path.join(dataset_path, TEST_DATASET_FILENAME)
    return test_path

In [4]:
def load_entities_of_dataset(dataset_path):
    """Return the set of entity identifiers listed in the dataset's entities file.

    The file is located with ``get_entities_filename``. Each non-blank line
    contributes its first tab-separated field.

    Args:
        dataset_path: Directory containing the dataset files.

    Returns:
        A set of entity identifier strings.
    """
    entities_filename = get_entities_filename(dataset_path)
    with open(entities_filename, mode="r", encoding="utf-8") as file_stream:
        # Stream the file instead of readlines(). Blank lines (e.g. a trailing
        # newline) are skipped so "" never ends up in the set.
        stripped_lines = (line.strip() for line in file_stream)
        return {line.split("\t")[0] for line in stripped_lines if line}

    
def load_entities_of_samples(samples_path):
    """Return every entity that appears as a head or tail in a samples file.

    Each non-blank line is split on tabs. Field 0 is taken as the head entity
    and field 2 as the tail entity. Blank lines are ignored.

    Args:
        samples_path: Path to a tab-separated samples file.

    Returns:
        The set of all head and tail entity strings.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
    """
    entities = set()
    with open(samples_path, mode="r", encoding="utf-8") as file_stream:
        # Single pass: each line is split once and contributes both its head
        # and its tail.
        for line in file_stream:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split("\t")
            entities.add(fields[0])
            entities.add(fields[2])
    return entities


def load_matching_samples(samples_path, entities=None):
    """Load samples from a tab-separated file, optionally filtered by entity.

    Each non-blank line becomes a tuple of its tab-separated fields. Field 0
    is the head entity and field 2 is the tail entity. Blank lines are ignored.

    Args:
        samples_path: Path to a tab-separated samples file.
        entities: Optional collection of entity identifiers. If ``None``,
            every sample is returned. Otherwise only samples whose head or
            tail entity is in ``entities`` are returned.

    Returns:
        A set of sample tuples, one per distinct matching line.

    Raises:
        IndexError: If ``entities`` is given and a non-blank line has fewer
            than three fields.
    """
    matching_samples = set()
    with open(samples_path, mode="r", encoding="utf-8") as file_stream:
        for line in file_stream:
            stripped = line.strip()
            if not stripped:
                continue
            sample = tuple(stripped.split("\t"))
            if entities is None:
                matching_samples.add(sample)
                continue
            # The head and tail surround the relation field (index 1).
            head_entity, tail_entity = sample[0], sample[2]
            if head_entity in entities or tail_entity in entities:
                matching_samples.add(sample)
    return matching_samples

In [5]:
def load_relations_of_dataset(dataset_path):
    """Return the set of relation identifiers listed in the dataset's relations file.

    The file is located with ``get_relations_filename``, consistent with how
    entities are loaded. Each non-blank line contributes its first
    tab-separated field.

    Args:
        dataset_path: Directory containing the dataset files.

    Returns:
        A set of relation identifier strings.
    """
    relations_filename = get_relations_filename(dataset_path)
    with open(relations_filename, mode="r", encoding="utf-8") as file_stream:
        # Blank lines (e.g. a trailing newline) are skipped so "" is never
        # reported as a relation.
        stripped_lines = (line.strip() for line in file_stream)
        return {line.split("\t")[0] for line in stripped_lines if line}


def load_relations_of_samples(samples_path):
    """Return the distinct relations used in a tab-separated samples file.

    Field 1 of each non-blank line is taken as the relation. Blank lines are
    ignored.

    Args:
        samples_path: Path to a tab-separated samples file.

    Returns:
        The set of relation strings.

    Raises:
        IndexError: If a non-blank line has fewer than two fields.
    """
    with open(samples_path, mode="r", encoding="utf-8") as file_stream:
        stripped_lines = (line.strip() for line in file_stream)
        return {line.split("\t")[1] for line in stripped_lines if line}


def explore_relations_counts(dataset_path):
    """Print relation counts for a dataset and each of its sample splits.

    Output order:
    1. The dataset path.
    2. The number of relations in the dataset's relations file.
    3. The number of distinct relations in the training, validation and test
       sample files, in that order.
    """
    print(f"Dataset path: '{dataset_path}'")
    all_relations = load_relations_of_dataset(dataset_path)
    print(f"Total relations count: {len(all_relations)}")

    # (label used in the printed message, function locating the split file)
    splits = (
        ("Training", get_training_samples_path),
        ("Validation", get_validation_samples_path),
        ("Test", get_test_samples_path),
    )
    for split_label, path_getter in splits:
        split_relations = load_relations_of_samples(path_getter(dataset_path))
        print(f"{split_label} relations count: {len(split_relations)}")