In [1]:
import numpy as np

In [160]:
def get_relation_ids(ridfile):
    """
    Load relation names from a tab-separated ``name<TAB>id`` file.

    Each non-blank line holds a relation name followed by its integer id.
    Blank lines (for example the one produced by a trailing newline at the
    end of the file) are ignored.

    Returns a list in which position ``id`` holds the matching relation name.
    The list has one slot per non-blank line, so ids are expected to lie in
    ``0 .. number_of_lines - 1``; an id outside that range raises IndexError.
    An empty file yields an empty list.
    """
    with open(ridfile, 'r') as infile:
        # Drop blank lines up front: they carry no (name, id) pair and would
        # otherwise fail on the missing second field.
        lines = [line for line in infile.read().split('\n') if line.strip()]
    rlist = [''] * len(lines)
    for pair in lines:
        p = pair.split('\t')
        rlist[int(p[1])] = p[0]
    return rlist
    
def get_relation_vectors(rvecfile):
    """
    Load relation vectors from a text file with one tab-separated vector per line.

    Empty fields (such as the one left by a trailing tab) are skipped, and so
    are blank lines (such as the one left by a trailing newline), which would
    otherwise add a zero-length row and make the result ragged.

    Returns a numpy array with one row per non-blank line. When every line has
    the same number of values this is a 2-D float array; a file with no vectors
    gives an empty array.
    """
    with open(rvecfile, 'r') as infile:
        lines = infile.read().split('\n')
    rlist = [
        np.array([float(s) for s in vec.split('\t') if len(s) > 0])
        for vec in lines
        if vec.strip()  # ignore blank lines
    ]
    return np.array(rlist)
        

def combine_relations(R, thresh):
    """
    Greedily merge relation vectors whose cosine similarity exceeds ``thresh``.

    R is a (M, N) matrix where M is the number of relations and N is the size
    of a relation. Rows are visited in order: every row that has not been
    merged yet becomes a representative, and each later unmerged row whose
    cosine similarity with it is strictly greater than ``thresh`` is merged
    into it. Merging drops the later row; the representative's vector itself
    is left unchanged.

    Returns a tuple ``(newindices, combined)``:
      * ``newindices`` is a list of M ints. ``newindices[k]`` is the row index
        in R of the representative that row k was merged into, or k itself if
        row k was kept.
      * ``combined`` is a (K, N) copy of the kept rows of R in their original
        order, where K is the number of relations left after merging.

    Rows with zero norm have no defined cosine similarity and are never merged.
    """
    R = np.asarray(R)
    count = len(R)
    newindices = list(range(count))
    if count == 0:
        return newindices, np.copy(R)

    # Norms never change because rows are dropped, not modified; compute once.
    norms = [np.linalg.norm(row) for row in R]
    kept = np.ones(count, dtype=bool)
    for i in range(count):
        if not kept[i]:
            continue
        for j in range(i + 1, count):
            if not kept[j]:
                continue
            denom = norms[i] * norms[j]
            if denom == 0:
                # Zero vector: similarity undefined, leave the row alone.
                continue
            if np.dot(R[i], R[j]) / denom > thresh:
                kept[j] = False
                # Record the representative by its index in R, so callers can
                # use the value to index data aligned with the original rows.
                newindices[j] = i
    # Boolean-mask indexing returns a fresh array, so R is never aliased.
    return newindices, R[kept]

def reduce_relations(ridfile, rvecfile, thresh):
    """
    Map relation names through the merge computed by combine_relations.

    Reads the relation names from ``ridfile`` and the relation vectors from
    ``rvecfile``, runs combine_relations on the vectors with ``thresh``, and
    returns a list holding, for each index that combine_relations reports,
    the relation name stored at that index.
    """
    names = get_relation_ids(ridfile)
    vectors = get_relation_vectors(rvecfile)
    # Only the index list is needed here; the reduced matrix is discarded.
    merged_ids, _ = combine_relations(vectors, thresh)
    reduced = []
    for idx in merged_ids:
        reduced.append(names[idx])
    return reduced
    

In [89]:
# Smoke test for combine_relations: rows 0 and 1 are identical and row 2 is
# orthogonal to both, so with a threshold of 0.5 only rows 0 and 1 are merged.
# NOTE(review): the output saved with this cell, ([0, 2], ...), differs from what
# the combine_relations definition in this notebook returns for this input
# ([0, 0, 2] as the first element) -- presumably it predates the current
# definition; re-run to confirm.
a = np.array([[1,0,0],[1,0,0],[0,1,0]])
combine_relations(a, 0.5)

([0, 2], array([[1, 0, 0],
        [0, 1, 0]]))

In [162]:
# Collapse the relations listed in relation2id.txt using the vectors in
# relation2vec.csv; 0.1 is the cosine-similarity threshold passed through to
# combine_relations. The cell displays the returned list of relation names.
# NOTE(review): 'relation2id.txt' is also the output path of the
# create_relation_entity_files call in a later cell -- presumably that file must
# already exist when this cell runs; confirm the intended execution order.
reduce_relations('relation2id.txt', 'relation2vec.csv', 0.1)

['/people/appointed_role/appointment./people/appointment/appointed_by',
 '/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency',
 '/people/appointed_role/appointment./people/appointment/appointed_by',
 '/people/appointed_role/appointment./people/appointment/appointed_by',
 '/medicine/disease/prevention_factors',
 '/organization/organization_member/member_of./organization/organization_membership/organization',
 '/american_football/football_player/receiving./american_football/player_receiving_statistics/season',
 '/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency',
 '/sports/sports_team/roster./soccer/football_roster_position/player',
 '/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency',
 '/people/appointed_role/appointment./people/appointment/appointed_by',
 '/architecture/structure/address./location/mailing_address/citytown',
 '/people/appointed_role/appointment./people/appointment/appointed

In [187]:
def create_relation_entity_files(triplefile, relationfile, entityfile):
    """
    Build id files for the entities and relations found in a triple file.

    Each non-empty line of ``triplefile`` is split on tabs: the first two
    fields are collected as entities and the third as a relation. A non-empty
    line with fewer than three fields raises IndexError.

    ``relationfile`` and ``entityfile`` are overwritten with one
    ``name<TAB>id`` line per distinct name, joined by newlines with no trailing
    newline. Ids are assigned in sorted name order so that the same triple file
    always produces the same ids; numbering straight from set iteration order
    would change between interpreter runs because string hashing is randomized.
    """
    entities = set()
    relations = set()
    with open(triplefile, 'r') as infile:
        for line in infile.read().split('\n'):
            if line == '':
                continue
            triple = line.split('\t')
            entities.add(triple[0])
            entities.add(triple[1])
            relations.add(triple[2])

    # Sorting fixes the name -> id assignment independently of hash order.
    outputs = ((relationfile, sorted(relations)), (entityfile, sorted(entities)))
    for path, names in outputs:
        with open(path, 'w') as outfile:
            outfile.write('\n'.join(name + '\t' + str(i) for i, name in enumerate(names)))

In [188]:
# Write relation2id.txt and entity2id.txt from the triples in relation_tuples.txt.
# 'relation2id.txt' is the same path the earlier reduce_relations call reads.
create_relation_entity_files('relation_tuples.txt', 'relation2id.txt', 'entity2id.txt')