In [96]:
import spacy
from spacy import displacy
import networkx as nx

In [19]:
from relation_templates.templates import get_all_templates, relations, get_templates

In [28]:
# spaCy English pipeline used to dependency-parse every template sentence in
# the cells below (the `en_core_web_sm` package must be installed).
nlp = spacy.load("en_core_web_sm")

In [41]:
# Placeholder texts passed to get_templates as the subject/object arguments
# (presumably filled into the template slots) and later searched for by exact
# token text in the parsed sentences.
SUBJECT_LABEL = "subject"
OBJECT_LABEL = "object"

In [15]:
#list(filter(lambda template: template[0] != "P",get_all_templates()))

In [45]:
def find_token_index(tree, correct_token):
    """Return the position of the first token whose text equals ``correct_token``.

    Args:
        tree: an iterable of tokens (e.g. a spaCy ``Doc``); each token is
            compared through ``str(token)``.
        correct_token: the exact text to look for.

    Returns:
        The 0-based index of the first match, or -1 when nothing matches.
    """
    matches = (
        position
        for position, candidate in enumerate(tree)
        if str(candidate) == correct_token
    )
    return next(matches, -1)
        

In [56]:
def getDependencies(tree, token):
    """Return the dependency subtree rooted at the token whose text is ``token``.

    Args:
        tree: a parsed spaCy ``Doc``.
        token: the token text to look up (a spaCy ``Token`` is also accepted;
            it is compared by its string form).

    Returns:
        A list of the tokens in the subtree of the first matching token, or an
        empty list when no token in ``tree`` has that text.
    """
    # Look up the requested token rather than a fixed label, so subject and
    # object get their own subtrees.
    token_index = find_token_index(tree, str(token))
    if token_index == -1:
        # Without this guard, tree[-1] would silently return the *last*
        # token's subtree for a label that is not in the sentence.
        return []
    return list(tree[token_index].subtree)

In [67]:
def getIsDependency(tree, target_token):
    """Collect every token whose subtree contains ``target_token``.

    A token qualifies when any member of its ``subtree`` has the same string
    form as ``target_token``. In spaCy a token's subtree includes the token
    itself, so the matching token is returned as well.

    Args:
        tree: an iterable of spaCy tokens (e.g. a ``Doc``).
        target_token: the text (or token) to search for; compared via ``str``.

    Returns:
        The qualifying tokens, in ``tree`` order.
    """
    wanted = str(target_token)
    return [
        head
        for head in tree
        if any(wanted == str(member) for member in head.subtree)
    ]
            

In [71]:
def tokenDependencyScore(tree, token):
    """Normalised dependency-involvement score for ``token``.

    Computed as (size of ``token``'s own subtree + number of tokens whose
    subtree contains ``token``) divided by the number of tokens in ``tree``.

    Args:
        tree: a parsed spaCy ``Doc``.
        token: the token text to score (e.g. ``SUBJECT_LABEL``).

    Returns:
        The score as a float; 0.0 for an empty ``tree``.
    """
    if len(tree) == 0:
        # An empty parse would otherwise raise ZeroDivisionError.
        return 0.0
    dependencies = getDependencies(tree, token)
    is_dependencies = getIsDependency(tree, token)
    return (len(dependencies) + len(is_dependencies)) / len(tree)

In [93]:
def newTokenDependencyScore(tree, token, labels=None):
    """Count the tokens whose subtree contains every label in ``labels``.

    Args:
        tree: a parsed spaCy ``Doc`` (any iterable of objects exposing an
            iterable ``subtree`` works).
        token: unused; kept so the signature matches the
            ``score_function(tree, token)`` call made by ``createTable``.
        labels: texts that must all appear in a token's subtree. Defaults to
            ``{SUBJECT_LABEL, OBJECT_LABEL}``.

    Returns:
        The number of qualifying tokens, as an int.
    """
    if labels is None:
        labels = {OBJECT_LABEL, SUBJECT_LABEL}
    required = set(labels)
    # Compare by text: spaCy Token objects never equal plain strings, so a
    # subset test against the raw tokens (the previous code) was always False.
    return sum(
        1
        for head in tree
        if required <= {str(member) for member in head.subtree}
    )

In [112]:
def getGraph(tree):
    """Build an undirected graph from the parse's head-to-child dependency arcs.

    Nodes are the lowercased token texts, so repeated words in a sentence
    share a single node.

    Args:
        tree: a parsed spaCy ``Doc``.

    Returns:
        A ``networkx.Graph`` with one edge per (head, child) arc.
    """
    arcs = [
        (f"{head.lower_}", f"{child.lower_}")
        for head in tree
        for child in head.children
    ]
    return nx.Graph(arcs)


def shortestPathLength(tree, src, target):
    """Number of dependency arcs between the words ``src`` and ``target``.

    Args:
        tree: a parsed spaCy ``Doc``.
        src: lowercased text of the start token (a node of ``getGraph(tree)``).
        target: lowercased text of the end token.

    Returns:
        The unweighted shortest-path length in ``getGraph(tree)``, or 0 when
        either word is absent from the graph or no path connects them. Note
        that 0 is also the genuine distance when ``src == target``.
    """
    graph = getGraph(tree)
    try:
        return nx.shortest_path_length(graph, src, target)
    except (nx.NodeNotFound, nx.NetworkXNoPath):
        # Missing word or disconnected graph: keep the 0 sentinel, but let
        # unrelated errors (and KeyboardInterrupt) propagate instead of
        # swallowing them with a bare except.
        return 0
    
    
def spTokenDependencyScore(tree, token):
    """Subject-object dependency distance normalised by sentence length.

    Args:
        tree: a parsed spaCy ``Doc``.
        token: unused; the distance is always measured between
            ``SUBJECT_LABEL`` and ``OBJECT_LABEL``. Kept so the signature
            matches the ``score_function(tree, token)`` call in ``createTable``.

    Returns:
        ``shortestPathLength(tree, SUBJECT_LABEL, OBJECT_LABEL) / len(tree)``
        as a float; 0.0 for an empty ``tree``.
    """
    if len(tree) == 0:
        # An empty parse would otherwise raise ZeroDivisionError.
        return 0.0
    return shortestPathLength(tree, SUBJECT_LABEL, OBJECT_LABEL) / len(tree)
    

In [113]:
# Diagnostic sweep: for every relation and sentence type, parse the template
# and compute each label's own subtree (getDependencies) and the tokens whose
# subtree contains it (getIsDependency). The report print below is commented
# out, so as written this cell only spends parse time. Its results are
# discarded, apart from the module-level names left over from the last
# iteration.
for relation in relations:
    #print(relation)
    for sentence_type, sentence in get_templates(relation, SUBJECT_LABEL, OBJECT_LABEL).items():
        tree = nlp(sentence)
        sub_dependencies = getDependencies(tree, SUBJECT_LABEL)
        sub_is_dependencies = getIsDependency(tree, SUBJECT_LABEL)
        
        obj_dependencies = getDependencies(tree, OBJECT_LABEL)
        obj_is_dependencies = getIsDependency(tree, OBJECT_LABEL)
        
        
        #print(f"{sentence_type[:12]}: \t\t{SUBJECT_LABEL}: {len(sub_dependencies)}, \t\t{OBJECT_LABEL}: {len(obj_dependencies)} \t\t\t\t {SUBJECT_LABEL} is: {len(sub_is_dependencies)} \t\t {OBJECT_LABEL} is: {len(obj_is_dependencies)}")
        

In [114]:
def createTable(relations, token, score_function=tokenDependencyScore):
    """Score every template sentence of every relation.

    Args:
        relations: relation ids to include. The list is NOT modified; rows
            are emitted in sorted order.
        token: passed as the second argument to ``score_function`` for every
            sentence.
        score_function: callable ``(tree, token) -> score`` applied to each
            parsed template sentence.

    Returns:
        A list of rows. The first row is the header; each following row is
        ``[relation, score, ...]`` with one score per template returned by
        ``get_templates``, in its iteration order.
    """
    # NOTE(review): the header assumes get_templates yields exactly these four
    # sentence types in this order; confirm against relation_templates.
    table = [['relations', 'simple', 'compound', 'complex', 'compound-complex']]
    # sorted() instead of relations.sort(): don't mutate the caller's list.
    for relation in sorted(relations):
        templates = get_templates(relation, SUBJECT_LABEL, OBJECT_LABEL)
        row = [relation]
        for sentence in templates.values():
            tree = nlp(sentence)
            # Honour the caller's `token` (previously hard-coded to OBJECT_LABEL).
            row.append(score_function(tree, token))
        table.append(row)
    return table
    

In [116]:
# Sorts the imported `relations` list in place (later cells see the sorted
# order), then parses every template sentence. The per-sentence score print
# is commented out, so this cell produces no visible output.
relations.sort()
for relation in relations:
    #print(relation)
    for sentence_type, sentence in get_templates(relation, SUBJECT_LABEL, OBJECT_LABEL).items():
        tree = nlp(sentence)
        #print(f"{sentence_type[:12]}: {spTokenDependencyScore(tree, OBJECT_LABEL)}")

In [117]:
# Subject-object dependency distance divided by sentence length, one row per
# relation and one column per sentence type; the first row is the header.
createTable(relations, OBJECT_LABEL, spTokenDependencyScore)

[['relations', 'simple', 'compound', 'complex', 'compound-complex'],
 ['P1001', 0.625, 0.3, 0.25, 0.11764705882352941],
 ['P101', 0.625, 0.2, 0.2, 0.15],
 ['P103', 0.5, 0.18181818181818182, 0.4444444444444444, 0.13333333333333333],
 ['P106', 0.2857142857142857, 0.2, 0.4444444444444444, 0.125],
 ['P108', 0.6, 0.2727272727272727, 0.4444444444444444, 0.26666666666666666],
 ['P127', 0.5, 0.3, 0.5, 0.1875],
 ['P1303', 0.5, 0.4444444444444444, 0.36363636363636365, 0.26666666666666666],
 ['P131', 0.5, 0.36363636363636365, 0.3, 0.13333333333333333],
 ['P136', 0.6, 0.3, 0.16666666666666666, 0.1],
 ['P1376', 0.5714285714285714, 0.3076923076923077, 0.36363636363636365, 0.25],
 ['P138', 0.5, 0.07142857142857142, 0.3076923076923077, 0.05263157894736842],
 ['P140', 0.5555555555555556, 0.2857142857142857, 0.23076923076923078, 0.25],
 ['P1412', 0.5714285714285714, 0.25, 0.36363636363636365, 0.29411764705882354],
 ['P159', 0.625, 0.23076923076923078, 0.25, 0.23529411764705882],
 ['P17', 0.5, 0.4, 0.272