In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from PipelineCacheWrapper.PipelineCacheWrapper import PipelineCacheWrapper

In [2]:
from TRExData.HardTRExData import HardKnownTRExData
from SentenceTypologyQueryResults import SentenceTypologyQueryResults
from relation_templates.templates import relations, nominalized_relations, get_templates, get_relation_cardinality
from ModelHelpers.data_helpers import create_known_filter_getter, create_hard_filter_getter

In [3]:
from TypologyQuerier import TypologyQuerier 

In [4]:
# Only show transformers log messages at ERROR level or above.
logging.set_verbosity_error()

In [5]:
# Placeholder token handed to TypologyQuerier (In [22]) for the masked slot.
MASK: str = "[MASK]"
# Sentence-variant labels; not referenced by any cell shown in this notebook.
KEYS: list = ["active", "passive", "nominalized"]

In [6]:
# Earlier relation subsets, all disabled. The active `relations` list is imported from
# relation_templates.templates (In [2]).
#relations = ["P19", "P36", "P101", "P103","P106","P108", "P178", "P1001"]
#relations = ["P19", "P413", "P159", "P103"]
#relations = ["P138", "P136"]
#relations = ["P413"]
#relations = nominalized_relations
#relations = ['P37', 'P1412', 'P178', 'P103', 'P36', 'P407', 'P1376', 'P1001', 'P140', 'P463', 'P176'] # considered small
#considered_relations = ['P159', 'P37', 'P1412', 'P17', 'P178', 'P103', 'P36', 'P407', 'P276', 'P31', 'P364', 'P1376', 'P127', 'P279', 'P1001', 'P140', 'P463', 'P176'] #big

In [7]:
# First run: the querier's top-k cutoff (presumably a prediction counts as correct
# only if the gold object is the single best token).
TOP_K: int = 1

In [8]:
# Alternative (disabled): larger whole-word-masking model.
#unmasker = pipeline('fill-mask', model='bert-large-uncased-whole-word-masking', top_k=100)
# Number of predictions the fill-mask wrapper returns per query in this first run.
# It differs from TOP_K (= 1 in In [7]): model_to_tokens returns all FILTER_TOP_K tokens,
# and the hard/known filter getters (In [10]) are built from model_to_tokens.
# NOTE(review): presumably the dataset filters rely on top-10 output — confirm before
# changing this value, since it changes which examples TREx.load() keeps.
FILTER_TOP_K = 10
unmasker = PipelineCacheWrapper('fill-mask', model='bert-base-cased', top_k=FILTER_TOP_K)

In [9]:
def model_to_tokens(sentence, model=None):
    """Return the fill-mask predictions for ``sentence`` as token strings, best first.

    Args:
        sentence: Input text passed to the fill-mask model (expected to contain the
            mask placeholder).
        model: Optional fill-mask callable returning a list of dicts with ``"score"``
            and ``"token_str"`` keys. When omitted, the notebook-level ``unmasker`` is
            looked up at call time. Later cells that rebind ``unmasker`` therefore
            change what the default call queries.

    Returns:
        List of ``token_str`` values ordered by descending ``score``.
    """
    # Resolve the global lazily so an explicit `model` never touches `unmasker`.
    fill_mask = unmasker if model is None else model
    predictions = fill_mask(sentence)
    # Sort explicitly rather than relying on the ordering of the wrapper/cache.
    ranked = sorted(predictions, key=lambda entry: entry["score"], reverse=True)
    return [entry["token_str"] for entry in ranked]

In [10]:
# Build the filter getters consumed by HardKnownTRExData (In [11]). Both receive
# model_to_tokens, which reads the global `unmasker` at call time, so (presumably, when
# they invoke it) they query whatever `unmasker` is bound to at that moment.
get_hard_filter = create_hard_filter_getter(model_to_tokens, get_templates)
get_known_filter = create_known_filter_getter(model_to_tokens, get_templates, get_relation_cardinality)

In [11]:
# Load the T-REx data, applying the hard and known filters, for the configured relations.
TREx = HardKnownTRExData(get_hard_filter, get_known_filter, relations=relations)
TREx.load()

In [19]:
# Restrict `relations` to the relations present in the loaded data.
# NOTE(review): this rebinds the `relations` imported from relation_templates.templates,
# which In [11] passed to HardKnownTRExData. Re-running In [11] after this cell would pass
# this list instead. The execution count also jumps from 11 to 19 here, so cells run in
# between are not part of the notebook. Verify with Restart & Run All.
relations = list(TREx.data.keys())

In [20]:
# Total number of examples across all relations. This matches the "total" column of
# print_global_result below.
sum(len(examples) for examples in TREx.data.values())

8046

In [21]:
# Number of relations in the loaded data.
len(TREx.data)

38

In [22]:
# Querier over the loaded relations, using `unmasker` as the model, TOP_K as the cutoff
# (presumably the gold object must be among the top TOP_K tokens), and MASK as the
# placeholder.
querier = TypologyQuerier(unmasker, relations, TOP_K, MASK)

In [23]:
# Run every relation's examples through the model. Results stay in the querier
# (cleared later by querier.reset()) and are shown by the print_* cells.
querier.query(TREx.data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 258/258 [00:00<00:00, 30137.03it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 19719.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [24]:
%%time
# Persist the wrapper's cached fill-mask results (presumably to disk) so re-runs can reuse them.
unmasker.save_to_cache()

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 5.25 µs


In [25]:
# Accuracy per sentence type (simple/compound/complex/compound-complex), overall and per cardinality.
querier.print_global_result()

total Result
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      5713 |    8046 |    71.0042 |
| compound         |      2204 |    8046 |    27.3925 |
| complex          |      2019 |    8046 |    25.0932 |
| compound-complex |      2602 |    8046 |    32.3391 |
Total Result for 1:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       543 |     543 |   100      |
| compound         |       517 |     543 |    95.2118 |
| complex          |       506 |     543 |    93.186  |
| compound-complex |       391 |     543 |    72.0074 |
Total Result for N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      3616 |    3616 |   100      |
| compound         |      1354 |    3616 |    37.4447 |
| complex          |      1199 |    3616 |    33.

In [26]:
# Per-relation accuracy table for each sentence type.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       258 |     258 |   100      |
| compound         |       232 |     258 |    89.9225 |
| complex          |       224 |     258 |    86.8217 |
| compound-complex |       220 |     258 |    85.2713 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       111 |     111 |   100      |
| compound         |        87 |     111 |    78.3784 |
| complex          |        59 |     111 |    53.1532 |
| compound-complex |        98 |     111 |    88.2883 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       395 |     609 |  64.8604   |
| compound         |        22 |     609 |   3.61248  

In [27]:
# Per-relation results as LaTeX rows: id & name & cardinality & accuracy for
# simple & compound & complex & compound-complex.
querier.print_for_latex()

P1376 & capital of & 1:1 & 100.00 & 95.45 & 96.75 & 95.45 \\
P36 & capital & 1:1 & 100.00 & 95.12 & 91.77 & 62.72 \\
P103 & native language & N:1 & 100.00 & 88.64 & 81.82 & 56.82 \\
P127 & owned by & N:1 & 100.00 & 0.00 & 2.70 & 49.55 \\
P131 & located in the administrative territorial entity & N:1 & 100.00 & 0.00 & 0.00 & 0.50 \\
P138 & named after & N:1 & 100.00 & 52.00 & 24.00 & 72.00 \\
P159 & headquarters location & N:1 & 100.00 & 89.92 & 86.82 & 85.27 \\
P17 & country & N:1 & 100.00 & 9.33 & 82.00 & 26.67 \\
P176 & manufacturer & N:1 & 100.00 & 75.75 & 36.99 & 91.46 \\
P19 & place of birth & N:1 & 100.00 & 0.00 & 0.00 & 0.00 \\
P20 & place of death & N:1 & 100.00 & 0.00 & 0.00 & 0.00 \\
P264 & record label & N:1 & 100.00 & 0.00 & 0.00 & 0.00 \\
P276 & location & N:1 & 100.00 & 84.21 & 72.37 & 89.47 \\
P279 & subclass of & N:1 & 100.00 & 10.51 & 27.80 & 6.10 \\
P30 & continent & N:1 & 100.00 & 2.02 & 4.84 & 1.21 \\
P31 & instance of & N:1 & 100.00 & 0.00 & 2.10 & 0.00 \\
P361 & pa

In [28]:
# NOTE(review): repeats In [26]. The output is identical, so this cell can likely be removed.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       258 |     258 |   100      |
| compound         |       232 |     258 |    89.9225 |
| complex          |       224 |     258 |    86.8217 |
| compound-complex |       220 |     258 |    85.2713 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       111 |     111 |   100      |
| compound         |        87 |     111 |    78.3784 |
| complex          |        59 |     111 |    53.1532 |
| compound-complex |        98 |     111 |    88.2883 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       395 |     609 |  64.8604   |
| compound         |        22 |     609 |   3.61248  

In [29]:
# Second run: widen the querier's cutoff to the top 10 predictions.
TOP_K: int = 10

In [30]:
# Rebind the global `unmasker` with top_k=TOP_K. model_to_tokens reads this global at call
# time, so the filter getters from In [10] would also use this model if invoked again.
unmasker = PipelineCacheWrapper('fill-mask', model='bert-base-cased', top_k=TOP_K)

In [31]:
# Presumably clears the previous run's results, then points the querier at the new cutoff and model.
querier.reset()
querier.set_top_k(TOP_K)
querier.set_model(unmasker)

In [32]:
# Re-query the same data with TOP_K = 10.
querier.query(TREx.data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 258/258 [00:00<00:00, 25080.09it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 19180.48it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [33]:
%%time
# Persist this wrapper instance's cached fill-mask results (presumably to disk).
unmasker.save_to_cache()

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 3.81 µs


In [34]:
# Global accuracy per sentence type for TOP_K = 10.
querier.print_global_result()

total Result
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      8046 |    8046 |   100      |
| compound         |      5427 |    8046 |    67.4497 |
| complex          |      5278 |    8046 |    65.5978 |
| compound-complex |      5611 |    8046 |    69.7365 |
Total Result for 1:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       543 |     543 |   100      |
| compound         |       542 |     543 |    99.8158 |
| complex          |       542 |     543 |    99.8158 |
| compound-complex |       538 |     543 |    99.0792 |
Total Result for N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      3616 |    3616 |   100      |
| compound         |      2397 |    3616 |    66.2887 |
| complex          |      2477 |    3616 |    68.

In [35]:
# Per-relation accuracy for TOP_K = 10.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       258 |     258 |   100      |
| compound         |       255 |     258 |    98.8372 |
| complex          |       252 |     258 |    97.6744 |
| compound-complex |       254 |     258 |    98.4496 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       111 |     111 |   100      |
| compound         |       111 |     111 |   100      |
| complex          |       111 |     111 |   100      |
| compound-complex |       110 |     111 |    99.0991 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       609 |     609 |   100      |
| compound         |       565 |     609 |    92.775  

In [36]:
# LaTeX rows for TOP_K = 10 (id & name & cardinality & simple & compound & complex & compound-complex).
querier.print_for_latex()

P1376 & capital of & 1:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P36 & capital & 1:1 & 100.00 & 99.74 & 99.74 & 98.71 \\
P103 & native language & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P127 & owned by & N:1 & 100.00 & 28.38 & 43.24 & 94.14 \\
P131 & located in the administrative territorial entity & N:1 & 100.00 & 51.98 & 64.85 & 37.13 \\
P138 & named after & N:1 & 100.00 & 100.00 & 76.00 & 100.00 \\
P159 & headquarters location & N:1 & 100.00 & 98.84 & 97.67 & 98.45 \\
P17 & country & N:1 & 100.00 & 88.67 & 98.67 & 98.67 \\
P176 & manufacturer & N:1 & 100.00 & 98.51 & 74.53 & 98.37 \\
P19 & place of birth & N:1 & 100.00 & 4.00 & 10.29 & 7.43 \\
P20 & place of death & N:1 & 100.00 & 6.78 & 1.69 & 3.39 \\
P264 & record label & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P276 & location & N:1 & 100.00 & 99.21 & 97.89 & 98.68 \\
P279 & subclass of & N:1 & 100.00 & 49.83 & 54.24 & 46.10 \\
P30 & continent & N:1 & 100.00 & 90.73 & 99.60 & 71.37 \\
P31 & instance of & N:1 & 100.00 & 0.00 &

In [37]:
# Third run: widen the querier's cutoff to the top 100 predictions.
TOP_K: int = 100

In [38]:
# Rebind the global `unmasker` with top_k=TOP_K (100). model_to_tokens now returns up to
# 100 tokens, so re-running the filters/TREx.load() from here would change the dataset.
unmasker = PipelineCacheWrapper('fill-mask', model='bert-base-cased', top_k=TOP_K)

In [39]:
# Presumably clears the previous run's results, then points the querier at the new cutoff and model.
querier.reset()
querier.set_top_k(TOP_K)
querier.set_model(unmasker)

In [40]:

# Re-query the same data with TOP_K = 100.
querier.query(TREx.data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 258/258 [00:00<00:00, 12208.98it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 9646.27it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [41]:
%%time
# Persist this wrapper instance's cached fill-mask results (presumably to disk).
unmasker.save_to_cache()

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.01 µs


In [42]:
# Global accuracy per sentence type for TOP_K = 100.
querier.print_global_result()

total Result
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      8046 |    8046 |   100      |
| compound         |      6789 |    8046 |    84.3773 |
| complex          |      6966 |    8046 |    86.5772 |
| compound-complex |      6759 |    8046 |    84.0045 |
Total Result for 1:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       543 |     543 |   100      |
| compound         |       542 |     543 |    99.8158 |
| complex          |       542 |     543 |    99.8158 |
| compound-complex |       543 |     543 |   100      |
Total Result for N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      3616 |    3616 |   100      |
| compound         |      2971 |    3616 |    82.1626 |
| complex          |      3154 |    3616 |    87.

In [43]:
# Per-relation accuracy for TOP_K = 100.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       258 |     258 |        100 |
| compound         |       258 |     258 |        100 |
| complex          |       258 |     258 |        100 |
| compound-complex |       258 |     258 |        100 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       111 |     111 |        100 |
| compound         |       111 |     111 |        100 |
| complex          |       111 |     111 |        100 |
| compound-complex |       111 |     111 |        100 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       609 |     609 |   100      |
| compound         |       608 |     609 |    99.8358 

In [44]:
# LaTeX rows for TOP_K = 100 (id & name & cardinality & simple & compound & complex & compound-complex).
querier.print_for_latex()

P1376 & capital of & 1:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P36 & capital & 1:1 & 100.00 & 99.74 & 99.74 & 100.00 \\
P103 & native language & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P127 & owned by & N:1 & 100.00 & 56.31 & 77.03 & 99.10 \\
P131 & located in the administrative territorial entity & N:1 & 100.00 & 95.05 & 99.01 & 89.11 \\
P138 & named after & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P159 & headquarters location & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P17 & country & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P176 & manufacturer & N:1 & 100.00 & 99.73 & 88.35 & 99.73 \\
P19 & place of birth & N:1 & 100.00 & 76.57 & 99.43 & 56.57 \\
P20 & place of death & N:1 & 100.00 & 84.75 & 79.66 & 81.36 \\
P264 & record label & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P276 & location & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P279 & subclass of & N:1 & 100.00 & 85.08 & 76.61 & 82.71 \\
P30 & continent & N:1 & 100.00 & 100.00 & 100.00 & 100.00 \\
P31 & instance of & N: