In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from PipelineCacheWrapper.PipelineCacheWrapper import PipelineCacheWrapper

In [2]:
from LamaTRExData import LamaTRExData
from SentenceTypologyQueryResults import SentenceTypologyQueryResults
from relation_templates.templates import relations, nominalized_relations

In [3]:
from TypologyDiscriminator import TypologyDiscriminator 

In [4]:
# Only show transformers log records at ERROR level or above (hides model-loading warnings).
logging.set_verbosity(logging.ERROR)

In [5]:
# Mask placeholder handed to TypologyDiscriminator (In [10]); presumably it is
# inserted into query sentences for the fill-mask pipeline, so it must equal the
# model tokenizer's mask token ("[MASK]" for BERT checkpoints).
MASK = "[MASK]"
# Template variant names.
# NOTE(review): not referenced by any later cell in this notebook.
KEYS = "active passive nominalized".split()

In [6]:
# Relation subsets tried in earlier runs. Everything below is commented out, so the
# active `relations` is the list imported from relation_templates.templates (In [2]).
# Uncommenting one of the `relations = ...` lines overrides it for a smaller run;
# the last line assigns `considered_relations`, which no cell in this notebook reads.
#relations = ["P19", "P36", "P101", "P103","P106","P108", "P178", "P1001"]
#relations = ["P19", "P413", "P159", "P103"]
#relations = ["P138", "P136"]
#relations = ["P1412"]
#relations = nominalized_relations
#relations = ['P37', 'P1412', 'P178', 'P103', 'P36', 'P407', 'P1376', 'P1001', 'P140', 'P463', 'P176'] # considered small
#considered_relations = ['P159', 'P37', 'P1412', 'P17', 'P178', 'P103', 'P36', 'P407', 'P276', 'P31', 'P364', 'P1376', 'P127', 'P279', 'P1001', 'P140', 'P463', 'P176'] #big

In [7]:
# Number of fill-mask candidates to request: used as `top_k` for the unmasker
# and also handed to TypologyDiscriminator.
TOP_K: int = 10

In [8]:
# Fill-mask model used for every query. PipelineCacheWrapper (project-local) is
# called with the same arguments as `transformers.pipeline`; presumably it memoizes
# predictions so they can be persisted via `unmasker.save_to_cache()` — TODO confirm.
# MASK (In [5]) must equal this model's tokenizer mask token; for BERT checkpoints,
# cased or uncased, that is "[MASK]".
MODEL_NAME = 'bert-base-cased'

# Earlier uncached alternative (larger model, fixed 100 candidates instead of TOP_K):
#unmasker = pipeline('fill-mask', model='bert-large-uncased-whole-word-masking', top_k=100)
unmasker = PipelineCacheWrapper('fill-mask', model=MODEL_NAME, top_k=TOP_K)

In [9]:
# Load the T-REx data for the selected `relations`.
TREx = LamaTRExData(relations=relations)
TREx.load()  # presumably fills TREx.data, which querier.discriminate consumes later

In [10]:
# Build the querier from the fill-mask wrapper, the relation list, TOP_K and the
# mask placeholder. The last argument is exposed as `querier.token` (the next cell
# shows 'unk'); its role is defined inside TypologyDiscriminator.
# TODO(review): document what "unk" means and consider hoisting it into a named constant.
querier = TypologyDiscriminator(unmasker, relations, TOP_K, MASK, "unk")

In [11]:
# Sanity check: should echo the last constructor argument above ('unk').
# NOTE(review): debugging leftover; safe to remove from the final notebook.
querier.token

'unk'

In [12]:
# Run the discriminator over all loaded samples (the output shows one progress bar
# per relation). Presumably this splits samples into the "easy"/"hard" subsets
# used by query_easy/query_hard and get_easy_count/get_hard_count below — TODO confirm.
querier.discriminate(TREx.data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 967/967 [00:00<00:00, 29152.86it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 966/966 [00:00<00:00, 30726.34it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [13]:
# Query the fill-mask model on the "easy" subset. Per-relation sizes in the progress
# bars (e.g. 139, 694) match get_easy_count() further down (P159, P37).
querier.query_easy()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 139/139 [00:00<00:00, 21079.95it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 694/694 [00:00<00:00, 27288.34it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [14]:
%%time
# Persist the wrapper's cached predictions (presumably) so later runs can skip the model.
# NOTE(review): the reported wall time is only a few µs, which suggests nothing is
# actually written here — confirm save_to_cache() really flushes to disk.
unmasker.save_to_cache()

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 3.1 µs


In [15]:
# Easy subset: accuracy per sentence type (simple / compound / complex /
# compound-complex), first overall, then grouped by relation cardinality (1:1, N:1, ...).
querier.print_global_result()

total Result
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      7491 |    8863 |    84.5199 |
| compound         |      4429 |    8863 |    49.9718 |
| complex          |      5305 |    8863 |    59.8556 |
| compound-complex |      4614 |    8863 |    52.0591 |
Total Result for 1:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |        71 |      74 |    95.9459 |
| compound         |        70 |      74 |    94.5946 |
| complex          |        71 |      74 |    95.9459 |
| compound-complex |        69 |      74 |    93.2432 |
Total Result for N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      5862 |    6598 |    88.8451 |
| compound         |      3661 |    6598 |    55.4865 |
| complex          |      4573 |    6598 |    69.

In [16]:
# Easy subset: the same accuracy table broken down per relation.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |        86 |     139 |    61.8705 |
| compound         |       102 |     139 |    73.3813 |
| complex          |       102 |     139 |    73.3813 |
| compound-complex |        94 |     139 |    67.6259 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       637 |     694 |    91.7867 |
| compound         |       608 |     694 |    87.6081 |
| complex          |       595 |     694 |    85.7349 |
| compound-complex |       574 |     694 |    82.7089 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       263 |     264 |    99.6212 |
| compound         |       258 |     264 |    97.7273 

In [17]:
# Easy subset: per-relation accuracies as LaTeX table rows, in the order
# id & label & cardinality & simple & compound & complex & compound-complex.
querier.print_for_latex()

P1376 & capital of & 1:1 & 95.24 & 95.24 & 95.24 & 95.24 \\
P36 & capital & 1:1 & 96.23 & 94.34 & 96.23 & 92.45 \\
P103 & native language & N:1 & 96.51 & 98.20 & 98.43 & 98.20 \\
P127 & owned by & N:1 & 89.47 & 6.58 & 14.47 & 77.63 \\
P131 & located in the administrative territorial entity & N:1 & 75.00 & 25.00 & 25.00 & 0.00 \\
P136 & genre & N:1 & 72.11 & 2.84 & 71.05 & 1.95 \\
P138 & named after & N:1 & 50.00 & 50.00 & 50.00 & 50.00 \\
P140 & religion or worldview & N:1 & 99.78 & 98.71 & 99.14 & 87.10 \\
P159 & headquarters location & N:1 & 61.87 & 73.38 & 73.38 & 67.63 \\
P17 & country & N:1 & 78.69 & 59.02 & 80.60 & 80.87 \\
P176 & manufacturer & N:1 & 97.32 & 91.96 & 89.29 & 91.96 \\
P19 & place of birth & N:1 & 69.70 & 0.00 & 4.55 & 1.52 \\
P20 & place of death & N:1 & 95.20 & 3.00 & 1.20 & 0.90 \\
P264 & record label & N:1 & 100.00 & 75.82 & 4.40 & 92.31 \\
P276 & location & N:1 & 75.00 & 77.50 & 77.50 & 75.00 \\
P279 & subclass of & N:1 & 0.00 & 0.00 & 0.00 & 0.00 \\
P30 & con

In [18]:
# NOTE(review): same call and identical output as In [16]; redundant, consider deleting this cell.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |        86 |     139 |    61.8705 |
| compound         |       102 |     139 |    73.3813 |
| complex          |       102 |     139 |    73.3813 |
| compound-complex |        94 |     139 |    67.6259 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       637 |     694 |    91.7867 |
| compound         |       608 |     694 |    87.6081 |
| complex          |       595 |     694 |    85.7349 |
| compound-complex |       574 |     694 |    82.7089 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       263 |     264 |    99.6212 |
| compound         |       258 |     264 |    97.7273 

In [19]:
# Query the fill-mask model on the "hard" subset. The global totals printed
# afterwards (25176) equal get_hard_count()['total'], so the stored results appear
# to be replaced by this run rather than accumulated with the easy run.
querier.query_hard()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 828/828 [00:00<00:00, 15735.13it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 25994.00it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [20]:
%%time
# Persist the wrapper's cached predictions again, now including the hard-subset queries (presumably).
# NOTE(review): again only a few µs, identical to the earlier save — confirm this is not a no-op.
unmasker.save_to_cache()

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 3.1 µs


In [21]:
# Hard subset: accuracy per sentence type, overall and grouped by relation cardinality.
querier.print_global_result()

total Result
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |     12162 |   25176 |    48.3079 |
| compound         |      8274 |   25176 |    32.8646 |
| complex          |      8289 |   25176 |    32.9242 |
| compound-complex |      8648 |   25176 |    34.3502 |
Total Result for 1:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       696 |     863 |    80.6489 |
| compound         |       692 |     863 |    80.1854 |
| complex          |       686 |     863 |    79.4902 |
| compound-complex |       658 |     863 |    76.2457 |
Total Result for N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      7579 |   14330 |    52.889  |
| compound         |      4354 |   14330 |    30.3838 |
| complex          |      4728 |   14330 |    32.

In [22]:
# Hard subset: the same accuracy table broken down per relation.
querier.print_result()

P159: headquarters location N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       413 |     828 |    49.8792 |
| compound         |       467 |     828 |    56.401  |
| complex          |       455 |     828 |    54.9517 |
| compound-complex |       456 |     828 |    55.0725 |
P37: official language N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       242 |     272 |    88.9706 |
| compound         |       233 |     272 |    85.6618 |
| complex          |       228 |     272 |    83.8235 |
| compound-complex |       233 |     272 |    85.6618 |
P1412: languages spoken, written or signed N:M
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       609 |     705 |    86.383  |
| compound         |       579 |     705 |    82.1277 

In [23]:
# Hard subset: per-relation accuracies as LaTeX table rows (same column order as the easy table).
querier.print_for_latex()

P1376 & capital of & 1:1 & 91.08 & 91.08 & 90.61 & 90.61 \\
P36 & capital & 1:1 & 77.23 & 76.62 & 75.85 & 71.54 \\
P103 & native language & N:1 & 95.45 & 94.32 & 86.36 & 90.91 \\
P127 & owned by & N:1 & 58.43 & 13.26 & 19.48 & 52.05 \\
P131 & located in the administrative territorial entity & N:1 & 54.39 & 16.08 & 21.55 & 10.95 \\
P136 & genre & N:1 & 1.36 & 2.17 & 16.58 & 0.82 \\
P138 & named after & N:1 & 69.52 & 64.39 & 22.86 & 68.43 \\
P140 & religion or worldview & N:1 & 37.50 & 0.00 & 0.00 & 25.00 \\
P159 & headquarters location & N:1 & 49.88 & 56.40 & 54.95 & 55.07 \\
P17 & country & N:1 & 60.28 & 48.58 & 68.97 & 71.81 \\
P176 & manufacturer & N:1 & 93.91 & 90.11 & 68.28 & 90.23 \\
P19 & place of birth & N:1 & 46.01 & 1.37 & 2.51 & 2.51 \\
P20 & place of death & N:1 & 42.42 & 3.23 & 2.10 & 1.94 \\
P264 & record label & N:1 & 13.02 & 6.21 & 1.78 & 1.78 \\
P276 & location & N:1 & 65.18 & 60.83 & 58.76 & 62.79 \\
P279 & subclass of & N:1 & 68.98 & 16.80 & 22.93 & 16.29 \\
P30 & con

In [24]:
# Number of "easy" queries per relation id, plus a 'total' entry (see output below).
# `counts`/`rels` are reassigned by the hard-count cell (In [26]), which would silently
# discard these values, so keep stage-named snapshots that survive that overwrite.
counts, rels = querier.get_easy_count()
easy_counts = dict(counts)  # shallow copy: counts is displayed as a plain dict
easy_rels = rels            # NOTE(review): type of `rels` not visible here; kept as a reference

In [25]:
# Easy-subset counts per relation; 'total' matches the easy global result (8863).
# NOTE(review): `counts` is reassigned in In [26], so re-running this cell afterwards shows hard counts.
counts

{'total': 8863,
 'P159': 139,
 'P37': 694,
 'P1412': 264,
 'P138': 2,
 'P495': 402,
 'P17': 366,
 'P103': 889,
 'P20': 333,
 'P36': 53,
 'P407': 664,
 'P190': 19,
 'P276': 40,
 'P31': 7,
 'P364': 526,
 'P27': 146,
 'P47': 80,
 'P1376': 21,
 'P106': 64,
 'P413': 421,
 'P937': 386,
 'P449': 694,
 'P131': 4,
 'P1303': 459,
 'P127': 76,
 'P740': 41,
 'P19': 66,
 'P527': 10,
 'P136': 563,
 'P264': 91,
 'P39': 45,
 'P101': 285,
 'P140': 465,
 'P361': 3,
 'P530': 271,
 'P463': 162,
 'P176': 112}

In [26]:
# Number of "hard" queries per relation id, plus a 'total' entry (see output below).
# This reuses the names `counts`/`rels` from the easy-count cell; keep stage-named
# snapshots so both subsets remain available regardless of execution order.
counts, rels = querier.get_hard_count()
hard_counts = dict(counts)  # shallow copy: counts is displayed as a plain dict
hard_rels = rels            # NOTE(review): type of `rels` not visible here; kept as a reference

In [27]:
# Hard-subset counts per relation; 'total' matches the hard global result (25176).
counts

{'total': 25176,
 'P159': 828,
 'P37': 272,
 'P1412': 705,
 'P138': 643,
 'P495': 507,
 'P17': 564,
 'P178': 592,
 'P103': 88,
 'P20': 620,
 'P36': 650,
 'P407': 213,
 'P190': 976,
 'P276': 919,
 'P108': 383,
 'P31': 915,
 'P364': 330,
 'P27': 820,
 'P30': 975,
 'P47': 842,
 'P1376': 213,
 'P106': 894,
 'P413': 531,
 'P937': 568,
 'P449': 187,
 'P131': 877,
 'P1303': 490,
 'P127': 611,
 'P279': 964,
 'P740': 895,
 'P19': 878,
 'P527': 966,
 'P136': 368,
 'P264': 338,
 'P1001': 701,
 'P39': 847,
 'P101': 411,
 'P140': 8,
 'P361': 929,
 'P530': 725,
 'P463': 63,
 'P176': 870}

In [28]:
# Easy-subset sample counts as LaTeX rows (id & label & cardinality & count), preceded by a Total row.
querier.print_count_for_latex("easy")

Total &  -  &  -  & 8863 \\
P1376 & capital of & 1:1 & 21 \\
P36 & capital & 1:1 & 53 \\
P103 & native language & N:1 & 889 \\
P127 & owned by & N:1 & 76 \\
P131 & located in the administrative territorial entity & N:1 & 4 \\
P136 & genre & N:1 & 563 \\
P138 & named after & N:1 & 2 \\
P140 & religion or worldview & N:1 & 465 \\
P159 & headquarters location & N:1 & 139 \\
P17 & country & N:1 & 366 \\
P176 & manufacturer & N:1 & 112 \\
P19 & place of birth & N:1 & 66 \\
P20 & place of death & N:1 & 333 \\
P264 & record label & N:1 & 91 \\
P276 & location & N:1 & 40 \\
P31 & instance of & N:1 & 7 \\
P361 & part of & N:1 & 3 \\
P364 & original language of film or TV show & N:1 & 526 \\
P37 & official language & N:1 & 694 \\
P407 & language of work or name & N:1 & 664 \\
P413 & position played on team / speciality & N:1 & 421 \\
P449 & original broadcaster & N:1 & 694 \\
P495 & country of origin & N:1 & 402 \\
P740 & location of formation & N:1 & 41 \\
P101 & field of work & N:M & 285 \\
P1

In [29]:
# Hard-subset sample counts as LaTeX rows (id & label & cardinality & count), preceded by a Total row.
querier.print_count_for_latex("hard")

Total &  -  &  -  & 25176 \\
P1376 & capital of & 1:1 & 213 \\
P36 & capital & 1:1 & 650 \\
P103 & native language & N:1 & 88 \\
P127 & owned by & N:1 & 611 \\
P131 & located in the administrative territorial entity & N:1 & 877 \\
P136 & genre & N:1 & 368 \\
P138 & named after & N:1 & 643 \\
P140 & religion or worldview & N:1 & 8 \\
P159 & headquarters location & N:1 & 828 \\
P17 & country & N:1 & 564 \\
P176 & manufacturer & N:1 & 870 \\
P19 & place of birth & N:1 & 878 \\
P20 & place of death & N:1 & 620 \\
P264 & record label & N:1 & 338 \\
P276 & location & N:1 & 919 \\
P279 & subclass of & N:1 & 964 \\
P30 & continent & N:1 & 975 \\
P31 & instance of & N:1 & 915 \\
P361 & part of & N:1 & 929 \\
P364 & original language of film or TV show & N:1 & 330 \\
P37 & official language & N:1 & 272 \\
P407 & language of work or name & N:1 & 213 \\
P413 & position played on team / speciality & N:1 & 531 \\
P449 & original broadcaster & N:1 & 187 \\
P495 & country of origin & N:1 & 507 \\
P740