In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from scipy.stats import pearsonr, spearmanr
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer
from tqdm import tqdm
from PipelineCacheWrapper.PipelineCacheWrapper import PipelineCacheWrapper

In [2]:
from LamaTRExData import LamaTRExData
from SentenceTypologyQueryResults import SentenceTypologyQueryResults
from SentenceComparison.SentenceComparison import SentenceComparison
from ModelHelpers.fill_mask_helpers import get_probability_from_pipeline_for_token
from relation_templates.templates import relations, nominalized_relations, get_length_for_relation, get_templates, get_relation_meta, get_relation_name, get_relation_cardinality

In [3]:
from TypologyQuerier import TypologyQuerier 

In [4]:
logging.set_verbosity_error()

In [5]:
@torch.no_grad()
def metric(sentence: str, token: str):
    prob = get_probability_from_pipeline_for_token(model(sentence), token)
    return prob

In [6]:
MASK = "[MASK]"
KEYS = ["active", "passive", "nominalized"]
TOP_K = 1

In [7]:
#relations = ["P19", "P36", "P101", "P103","P106","P108", "P178", "P1001"]
#relations = ["P19", "P413", "P159", "P103"]
#relations = ["P364"]
#relations = nominalized_relations

In [8]:
TREx = LamaTRExData(relations = relations)
TREx.load()

In [9]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
my_tokenizer = lambda sentence: tokenizer(sentence)['input_ids']

In [10]:
unmasker = PipelineCacheWrapper('fill-mask', model='bert-base-cased', top_k=TOP_K)

In [11]:
querier = TypologyQuerier(unmasker, relations, TOP_K, MASK)
querier.query(TREx.data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 967/967 [00:00<00:00, 31868.91it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 966/966 [00:00<00:00, 33166.32it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [12]:
k10querier = TypologyQuerier(unmasker, relations, 10, MASK)
k10querier.query(TREx.data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 967/967 [00:00<00:00, 34677.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 966/966 [00:00<00:00, 34014.44it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [13]:
unmasker = PipelineCacheWrapper('fill-mask', model='bert-base-cased', top_k=500)

In [14]:
@torch.no_grad()
def metric(sentence: str, token: str):
    prob = get_probability_from_pipeline_for_token(unmasker(sentence), token)
    return prob

In [15]:
Comparer = SentenceComparison(relations, get_templates, metric, MASK, get_relation_meta)
Comparer.compare(TREx.data)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 967/967 [00:00<00:00, 5370.31it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 966/966 [00:00<00:00, 5534.27it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [16]:
lower_limit, upper_limit = 25.0, 100.0 
def filter_for_relevancek1(rel):
    if get_relation_cardinality(rel) == "N:M":
        return lower_limit < k10querier.results.relation_results[rel].get()["simple"]["accuracy"] < upper_limit
    else:
        return lower_limit < querier.results.relation_results[rel].get()["simple"]["accuracy"] < upper_limit


In [17]:
considered_relations = list(filter(filter_for_relevancek1, relations))

In [18]:
considered_relations

['P159',
 'P37',
 'P1412',
 'P17',
 'P178',
 'P103',
 'P20',
 'P36',
 'P407',
 'P276',
 'P31',
 'P364',
 'P30',
 'P1376',
 'P937',
 'P127',
 'P279',
 'P1001',
 'P140',
 'P463',
 'P176']

In [19]:
good_relations = ['P37', 'P1412', 'P178', 'P103', 'P36', 'P407', 'P1376', 'P1001', 'P140', 'P463', 'P176']
middle_relations = ['P159', 'P17', 'P20', 'P276', 'P31', 'P364', 'P30', 'P937', 'P127', 'P279']

In [20]:
accuracies = querier.results_for_calculation()
np_accuracies = np.array(accuracies)[1:,1:].flatten().astype(float)

In [21]:
accuracies = querier.results_for_calculation()
accuracies = list(filter( lambda row: row[0] in considered_relations, accuracies))
np_accuracies = np.array(accuracies)[1:,1:].flatten().astype(float)

In [22]:
k10accuracies = k10querier.results_for_calculation()
k10accuracies = list(filter( lambda row: row[0] in considered_relations, k10accuracies))
np_k10accuracies = np.array(k10accuracies)[1:,1:].flatten().astype(float)

In [23]:
probabilities = Comparer.make_global_comparison()
probabilities = list(filter( lambda row: row[0] in considered_relations, probabilities))
np_probabilities = np.array(probabilities)[:,1:].flatten().astype(float)

In [24]:
simple_dom = lambda row: max(row[1:]) == row[1]
simple_not_zero = lambda row: row[1] > 0.0
percentage_of_simple = lambda row: [row[0], get_relation_name(row[0]) ,row[1], (sum(row[2:])/3), (sum(row[2:])/3)/row[1]*100]

In [25]:
new_accuracies = list(filter(simple_not_zero, accuracies[1:]))
new_accuracies = list(map(percentage_of_simple, new_accuracies))
new_accuracies = sorted(new_accuracies, key=lambda row: row[4]) 

In [26]:
for row in new_accuracies:
    print(" & ".join(map(lambda entry: entry if type(entry)==str else f"{entry:.1f}", row))+"\\% \\\\")

P937 & work location & 28.0 & 0.0 & 0.0\% \\
P20 & place of death & 27.9 & 0.0 & 0.1\% \\
P31 & instance of & 36.7 & 0.3 & 0.9\% \\
P1412 & languages spoken, written or signed & 65.0 & 1.3 & 2.0\% \\
P30 & continent & 25.4 & 0.7 & 2.8\% \\
P463 & member of  & 67.1 & 3.1 & 4.6\% \\
P279 & subclass of & 30.6 & 5.2 & 16.8\% \\
P178 & developer & 58.6 & 11.1 & 18.9\% \\
P127 & owned by & 34.8 & 9.1 & 26.1\% \\
P407 & language of work or name & 59.2 & 26.6 & 45.0\% \\
P17 & country & 31.3 & 18.3 & 58.4\% \\
P37 & official language & 54.6 & 35.2 & 64.6\% \\
P364 & original language of film or TV show & 44.5 & 30.3 & 68.0\% \\
P140 & religion or worldview & 75.1 & 53.3 & 71.1\% \\
P176 & manufacturer & 85.6 & 62.4 & 72.8\% \\
P36 & capital & 62.2 & 53.8 & 86.6\% \\
P276 & location & 41.5 & 38.5 & 92.7\% \\
P1376 & capital of & 73.9 & 73.4 & 99.2\% \\
P103 & native language & 72.2 & 74.1 & 102.7\% \\
P159 & headquarters location & 32.4 & 34.2 & 105.8\% \\


In [27]:
map_to_float = lambda row: [row[0]] + list(map(float, row[1:]))

In [28]:
new_probs = list(map(map_to_float, probabilities[1:]))
new_probs = list(filter(simple_not_zero, new_probs))
new_probs = list(map(percentage_of_simple, new_probs))
new_probs = sorted(new_probs, key=lambda row: row[4]) 

In [29]:
for row in new_probs:
    print(" & ".join(map(lambda entry: entry if type(entry)==str else f"{entry:.1f}", row))+" \\\\")

P937 & work location & 0.0 & 0.0 & 0.0 \\
P31 & instance of & 0.2 & 0.0 & 0.4 \\
P20 & place of death & 0.1 & 0.0 & 1.5 \\
P279 & subclass of & 0.2 & 0.0 & 9.5 \\
P30 & continent & 0.1 & 0.0 & 12.4 \\
P463 & member of  & 0.5 & 0.1 & 12.4 \\
P127 & owned by & 0.2 & 0.0 & 27.5 \\
P178 & developer & 0.3 & 0.1 & 34.4 \\
P176 & manufacturer & 0.7 & 0.3 & 43.3 \\
P140 & religion or worldview & 0.6 & 0.3 & 45.0 \\
P1001 & applies to jurisdiction & 0.2 & 0.1 & 49.9 \\
P364 & original language of film or TV show & 0.2 & 0.1 & 52.2 \\
P37 & official language & 0.4 & 0.2 & 53.0 \\
P1412 & languages spoken, written or signed & 0.1 & 0.1 & 53.2 \\
P17 & country & 0.2 & 0.1 & 80.0 \\
P36 & capital & 0.5 & 0.4 & 83.1 \\
P1376 & capital of & 0.6 & 0.6 & 95.9 \\
P407 & language of work or name & 0.1 & 0.2 & 104.0 \\
P103 & native language & 0.5 & 0.5 & 106.3 \\
P276 & location & 0.2 & 0.3 & 109.6 \\


In [30]:
class1 = []
class2 = []
class3 = []
class4 = []
class5 = []
class6 = []
class7 = []

for acc_row in new_accuracies:
    if acc_row[4] < 0.01:
        class1.append(acc_row[0])
    elif acc_row[4] < 1.0:
        class2.append(acc_row[0])
    elif acc_row[4] < 10.0:
        class3.append(acc_row[0])
    elif acc_row[4] < 50.0:
        class4.append(acc_row[0])
    elif acc_row[4] < 80.0:
        class5.append(acc_row[0])
    elif acc_row[4] < 100.0:
        class6.append(acc_row[0])
    else:
        class7.append(acc_row[0])

In [31]:
def print_class(clss):
    for rel in clss: 
        print("\n", rel, get_relation_meta(rel))
        for key, val in get_templates(rel, "[SUBJ]", MASK).items():
            print(key[0:12],": \t", val)

In [33]:
 print_class(middle_relations)


 P159 headquarters location N:1
simple : 	 The headquarter of [SUBJ] is in [MASK].
compound : 	 [SUBJ] is located in [MASK] and that is where they have their headquarters
complex : 	 [SUBJ] is located in [MASK] because they have their headquarters there.
compound-com : 	 [SUBJ] has a location in [MASK] and this location is important because it is their headquarters.

 P17 country N:1
simple : 	 [SUBJ] is located in [MASK].
compound : 	 [MASK] is a country and [SUBJ] is located there.
complex : 	 [SUBJ] is located in [MASK], which is a country.
compound-com : 	 [MASK] is a country and it is comprised of many locations, for example [SUBJ].

 P20 place of death N:1
simple : 	 [SUBJ] died in [MASK].
compound : 	 [MASK] is the name of a place and [SUBJ] died there.
complex : 	 [MASK] is the name of the place, where [SUBJ] died.
compound-com : 	 [MASK] is the name of the place, where some people died, for example [SUBJ].

 P276 location N:1
simple : 	 [SUBJ] is located in [MASK].
compound :