# Imports

In [21]:
import numpy as np
import os
import pandas as pd

from kge.model import KgeModel
from kge.util.io import load_checkpoint
from kge.indexing import index_relation_types

# Entities / Relations <-> IDs

In [12]:
class ERI:
    """Two-way lookup between entity/relation names and their integer ids.

    The tables are read from ``entity_ids.del`` and ``relation_ids.del`` in
    ``dataset_path``: tab-separated, no header, one ``id<TAB>name`` row per
    line. Single lookups return the value of the first matching row and raise
    ``IndexError`` when nothing matches.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        # Columns: ['id', 'entity'] and ['id', 'relation']. pandas infers the
        # dtype of the name column, so all-numeric names are parsed as ints.
        self.entities = pd.read_csv(os.path.join(dataset_path, 'entity_ids.del'), sep='\t', names=['id', 'entity'])
        self.relations = pd.read_csv(os.path.join(dataset_path, 'relation_ids.del'), sep='\t', names=['id', 'relation'])

    @staticmethod
    def _lookup(frame, key_column, value_column, key):
        """Return ``value_column`` of the first row whose ``key_column`` equals ``key``.

        Raises:
            IndexError: if no row of ``frame`` matches ``key``.
        """
        column = frame[key_column]
        # A digit-only string such as '42' must become an int to match a column
        # that pandas parsed as integers. Strings are left untouched when the
        # column itself holds strings (mixed names like '7' and '/a/b'), and
        # non-string keys (e.g. numpy ints) are compared as-is.
        # isdecimal() accepts exactly the characters int() can parse.
        if isinstance(key, str) and key.isdecimal() and pd.api.types.is_integer_dtype(column):
            key = int(key)
        return frame.loc[column == key, value_column].iat[0]

    def _get_multiple(self, getter, inputs):
        """Apply ``getter`` to every element of ``inputs`` and return the results as a list."""
        return [getter(item) for item in inputs]

    def get_entity_by_id(self, id):
        """Return the entity name stored under ``id``."""
        return self._lookup(self.entities, 'id', 'entity', id)

    def get_entities_by_id(self, ids):
        """Return the entity names for each id in ``ids``."""
        return self._get_multiple(self.get_entity_by_id, ids)

    def get_entity_id(self, entity):
        """Return the id of ``entity`` (a name, a digit string, or a number)."""
        return self._lookup(self.entities, 'entity', 'id', entity)

    def get_entity_ids(self, entities):
        """Return the ids for each entity in ``entities``."""
        return self._get_multiple(self.get_entity_id, entities)

    def get_relation_by_id(self, id):
        """Return the relation name stored under ``id``."""
        return self._lookup(self.relations, 'id', 'relation', id)

    def get_relations_by_id(self, ids):
        """Return the relation names for each id in ``ids``."""
        return self._get_multiple(self.get_relation_by_id, ids)

    def get_relation_id(self, relation):
        """Return the id of ``relation`` (a name, a digit string, or a number)."""
        return self._lookup(self.relations, 'relation', 'id', relation)

    def get_relation_ids(self, relations):
        """Return the ids for each relation in ``relations``."""
        return self._get_multiple(self.get_relation_id, relations)

    def get_all_entities(self):
        """Return the distinct entity names, in file order."""
        return self.entities['entity'].unique()

    def get_all_relations(self):
        """Return the distinct relation names, in file order."""
        return self.relations['relation'].unique()

    def get_all_entity_ids(self):
        """Return the distinct entity ids, in file order."""
        return self.entities['id'].unique()

    def get_all_relation_ids(self):
        """Return the distinct relation ids, in file order."""
        return self.relations['id'].unique()

# Relation Frequency in Training Data

In [23]:
relation_frequency_map = {}


def get_relation_frequency_in_training_data(dataset):
    """Return per-relation triple counts for the training split of ``dataset``.

    Reads ``experiments/0_datasets/<dataset>/train.txt`` (tab-separated
    head/relation/tail columns, no header) and maps each relation name to its
    id with ``ERI``. Results are memoised in ``relation_frequency_map`` keyed by
    ``dataset``; the cached DataFrame itself is returned, so callers should not
    mutate it.

    Returns:
        DataFrame with columns ``index`` (0..n-1, kept from the original
        ``reset_index`` output), ``r_id``, ``freq`` (number of training triples
        using the relation) and ``norm_freq`` (``freq`` min-max scaled to
        [0, 1]; 0.0 for every row when all counts are equal). Rows follow the
        order in which relations first appear in train.txt.
    """
    if dataset in relation_frequency_map:
        return relation_frequency_map[dataset]

    path_to_dataset = os.path.join('experiments', '0_datasets', dataset)
    path_to_training = os.path.join(path_to_dataset, 'train.txt')

    training_df = pd.read_csv(path_to_training, sep='\t', header=None)
    training_df.columns = ['h', 'r', 't']

    # Relations in order of first appearance, and their triple counts.
    relations = training_df['r'].unique()
    relation_counts = training_df['r'].value_counts()

    eri = ERI(path_to_dataset)
    # Build the frame in one go instead of appending row by row with
    # df.loc[len(df)] = ..., which is quadratic and leaves object dtypes.
    # str(): pandas may parse numeric relation names as integers, while
    # ERI.get_relation_id calls .isdigit() on its argument.
    df = pd.DataFrame({
        'r_id': [eri.get_relation_id(str(relation)) for relation in relations],
        'freq': relation_counts.loc[relations].to_numpy(),
    })

    freq_min = df['freq'].min()
    freq_range = df['freq'].max() - freq_min
    if freq_range > 0:
        df['norm_freq'] = (df['freq'] - freq_min) / freq_range
    else:
        # All relations equally frequent (e.g. a single relation): the plain
        # min-max formula would compute 0/0 and fill the column with NaN.
        df['norm_freq'] = 0.0
    df = df.reset_index()

    relation_frequency_map[dataset] = df
    return df

# Relation Class

In [24]:
"""
def get_relation_classes(dataset_name, threshold=0.85):
    path_to_dataset = os.path.join('experiments', '0_datasets', dataset_name)
    path_to_training = os.path.join(path_to_dataset, 'train.txt')
    path_to_valid = os.path.join(path_to_dataset, 'valid.txt')
    path_to_test = os.path.join(path_to_dataset, 'test.txt')

    training_df = pd.read_csv(path_to_training, sep='\t', header=None)
    valid_df = pd.read_csv(path_to_valid, sep='\t', header=None)
    test_df = pd.read_csv(path_to_test, sep='\t', header=None)

    eri = ERI(path_to_dataset)
    relations = eri.get_all_relations()

    data = pd.concat([training_df, valid_df, test_df])
    data.columns= ['h','r','t']

    df=pd.DataFrame(columns=['r_id', 'relationClass'])

    for relation in relations:
        data_for_h_r = data[data['r']==relation].groupby(['h', 'r']).agg(set)
        data_for_t_r = data[data['r']==relation].groupby(['t', 'r']).agg(set)

        OneTo = 0
        Nto = 0
        for i in range(data_for_t_r.size):
            h_count  = len(data_for_t_r['h'].iloc[i])    
            if h_count > 1:
                Nto += 1
            elif h_count == 1:
                OneTo += 1 

        toOne = 0
        toM = 0
        for i in range(data_for_h_r.size):
            t_count  = len(data_for_h_r['t'].iloc[i])
            if t_count > 1:
                toM += 1
            elif t_count == 1:
                toOne += 1          

        xTo = ''
        if OneTo/(OneTo+Nto) > threshold:
            xTo = '1'
        else:
            xTo = 'N'

        toX = ''
        if toOne/(toOne+toM) > threshold:
            toX = '1'
        else:
            toX = 'M'

        df.loc[df.shape[0]] = [eri.get_relation_id(relation), f'{xTo}to{toX}']
    
    return df 
"""

In [32]:
relation_classes_map = {}


def get_relation_classes(checkpoint_path):
    """Return the relation-type table for the dataset behind a LibKGE checkpoint.

    The checkpoint at ``checkpoint_path`` is loaded and its model built in order
    to reach the dataset. Each relation is then labelled with the type string
    produced by ``kge.indexing.index_relation_types`` (values such as '1-1',
    '1-N', 'M-1', 'M-N'). Results are memoised in ``relation_classes_map``
    keyed by ``checkpoint_path``.

    Returns:
        DataFrame with columns ``r_id`` (id looked up via ``ERI`` from the
        relation's string) and ``relationClass``. Presumably
        ``relation_strings()`` and ``index_relation_types()`` are both ordered
        by dataset relation index; TODO confirm against LibKGE.
    """
    try:
        return relation_classes_map[checkpoint_path]
    except KeyError:
        pass  # first request for this checkpoint; compute below

    dataset = KgeModel.create_from(load_checkpoint(checkpoint_path)).dataset
    lookup = ERI(dataset.folder)

    classes = pd.DataFrame(data={
        'r_id': list(map(lookup.get_relation_id, dataset.relation_strings())),
        'relationClass': index_relation_types(dataset),
    })

    relation_classes_map[checkpoint_path] = classes
    return classes

Loaded 11 keys from map relation_strings
  62547 distinct sp pairs in train
  40962 distinct po pairs in train


Unnamed: 0,r_id,relationClass
0,0,M-1
1,1,M-N
2,2,M-1
3,3,M-N
4,4,1-N
5,5,M-1
6,6,1-N
7,7,1-N
8,8,1-N
9,9,1-1


# Formatted Data Name

In [25]:
def get_formatted_data_name(dataset_name, symbolic_name, subsymbolic_name):
    """Build the file name ``<dataset>_<symbolic>_<subsymbolic>.txt``.

    Each argument is rendered with ``format()`` (so non-string values such as
    ints are accepted) and the three parts are joined with underscores.
    """
    template = '{}_{}_{}.txt'
    return template.format(dataset_name, symbolic_name, subsymbolic_name)