In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import logging
from transformers import pipeline
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from PipelineCacheWrapper.PipelineCacheWrapper import PipelineCacheWrapper

In [2]:
from TRExData.AbstractTRExData import AbstractTRExData
from SentenceTypologyQueryResults import SentenceTypologyQueryResults
from relation_templates.templates import relations, nominalized_relations, get_templates, get_relation_name

In [3]:
from TypologyQuerier import TypologyQuerier 

In [4]:
# Only log messages at ERROR level or above from the transformers library,
# so library warnings do not clutter the cell outputs below.
logging.set_verbosity_error()

In [5]:
#relations = ["P19", "P36", "P101", "P103","P106","P108", "P178", "P1001"]
#relations = ["P19", "P413", "P159", "P103"]
#relations = ["P364"]
#relations = nominalized_relations

In [6]:
# Mask token string handed to TypologyQuerier.
# NOTE(review): must match the mask token of the fill-mask model created below
# (bert-base-cased) — confirm if the model is changed.
MASK = "[MASK]"
# Presumably the sentence-form variants of the templates; not referenced by any
# visible cell — TODO confirm it is still needed.
KEYS = ["active", "passive", "nominalized"]
# Number of fill-mask predictions requested from the pipeline; also passed to TypologyQuerier.
TOP_K = 10

In [7]:
def check_triplets(rels, triples, verbose=True):
    """Count (and optionally print) the all-lowercase (subject, object) pairs per relation.

    For every relation id in ``rels`` the pairs in ``triples[relation]`` are scanned
    and a pair is kept when both the subject and the object satisfy ``str.islower``.
    Each kept pair is rendered with the relation's "simple" template (via
    ``get_templates``) and printed as ``"<relation>-<relation name>: <sentence>"``.

    Args:
        rels: iterable of relation ids (e.g. ``"P138"``) to inspect.
        triples: mapping from relation id to an iterable of ``(subject, object)``
            string pairs. A relation missing from the mapping is reported with a
            count of 0 instead of raising ``KeyError``.
        verbose: when True (the default, previous behaviour) print the rendered
            sentence of every kept pair; pass False to only compute the counts.

    Returns:
        dict whose first key ``"total"`` holds the overall number of kept pairs,
        followed by one key per relation in ``rels`` with its own count.
    """
    count = {"total": 0}
    for relation in rels:
        kept = 0
        # .get: relations requested but absent from the loaded data count as 0
        # rather than aborting the whole scan with a KeyError.
        for sub, obj in triples.get(relation, ()):
            if not (sub.islower() and obj.islower()):
                continue
            kept += 1
            if verbose:
                sentence = get_templates(relation, sub, obj, ["simple"])["simple"]
                print(f"{relation}-{get_relation_name(relation)}: {sentence}")
        count[relation] = kept
        count["total"] += kept
    return count

In [8]:
# Load the (subject, object) pairs for the requested relations: the `relations` list
# imported from relation_templates (unless one of the commented-out overrides above
# is enabled). TREx.data may end up with fewer relations than requested — its keys
# are inspected in the next cells.
TREx = AbstractTRExData(relations=relations)
TREx.load()

In [9]:
# Total number of loaded (subject, object) pairs across all relations.
sum(map(len, TREx.data.values()))

1715

In [10]:
# Relation ids present in the loaded data, and how many there are.
list(TREx.data), len(TREx.data)

(['P138', 'P276', 'P31', 'P937', 'P279', 'P527', 'P101', 'P361'], 8)

In [11]:
# Restrict `relations` to the ids that actually have loaded data. This rebinds the
# list imported from relation_templates; the check_triplets and TypologyQuerier
# cells below use this narrowed value.
relations = list(TREx.data.keys())

# Per-relation pair counts. Read them from TREx.data — the mapping `relations` was
# just derived from and the one every later cell uses — rather than
# TREx.abstract_data, which produced no output in the saved run of this cell.
for relation, pairs in TREx.data.items():
    print(f"{relation}-{get_relation_name(relation)} - count: {len(pairs)}")

In [12]:
# Print every all-lowercase (subject, object) pair in its "simple" template sentence
# and return per-relation counts. In the saved run the total (1715) equals the overall
# pair count, i.e. every loaded pair passed the lowercase check.
check_triplets(relations, TREx.data)

P138-named after: caffeine is named after coffee.
P138-named after: uraninite is named after uranium.
P138-named after: chlorargyrite is named after silver.
P138-named after: hepatic artery proper is named after liver.
P138-named after: molecule is named after mole.
P138-named after: molybdenum is named after lead.
P138-named after: burgomaster is named after mayor.
P138-named after: patent troll is named after patent.
P138-named after: patriarchate is named after patriarch.
P138-named after: clergy house is named after canon.
P138-named after: rutile is named after red.
P138-named after: zincite is named after zinc.
P138-named after: uranophane-alpha is named after uranium.
P138-named after: munster is named after monastery.
P138-named after: carbon planet is named after carbon.
P138-named after: omphacite is named after grape.
P276-location: parish church is located in parish.
P276-location: shopping cart is located in supermarket.
P276-location: earring is located in ear.
P31-instan

{'total': 1715,
 'P138': 16,
 'P276': 3,
 'P31': 32,
 'P937': 1,
 'P279': 532,
 'P527': 634,
 'P101': 27,
 'P361': 470}

In [13]:
# 'fill-mask' predictor over bert-base-cased returning TOP_K candidates per query.
# PipelineCacheWrapper is a project wrapper around the transformers pipeline
# (presumably caching predictions between runs — confirm in its module).
unmasker = PipelineCacheWrapper('fill-mask', model='bert-base-cased', top_k=TOP_K)

In [14]:

# Query the fill-mask model for every loaded relation (one progress bar per relation,
# sized by its number of pairs). The results are reported by the querier.print_*
# calls in the following cells.
querier = TypologyQuerier(unmasker, relations, TOP_K, MASK)
querier.query(TREx.data)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 8097.11it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4419.71it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [15]:
# Accuracy per sentence type (simple, compound, complex, compound-complex), overall
# and per relation cardinality (1:1, N:1, N:M).
# NOTE: none of the loaded relations is 1:1, so the 1:1 table has total 0 and its
# reported accuracy of 0 is not meaningful.
querier.print_global_result()

total Result
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |      1020 |    1715 |    59.4752 |
| compound         |       411 |    1715 |    23.965  |
| complex          |       541 |    1715 |    31.5452 |
| compound-complex |       308 |    1715 |    17.9592 |
Total Result for 1:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |         0 |       0 |          0 |
| compound         |         0 |       0 |          0 |
| complex          |         0 |       0 |          0 |
| compound-complex |         0 |       0 |          0 |
Total Result for N:1
| type             |   correct |   total |   accuracy |
|------------------+-----------+---------+------------|
| simple           |       655 |    1053 |    62.2032 |
| compound         |       126 |    1053 |    11.9658 |
| complex          |       290 |    1053 |    27.

In [16]:
# Same global summary as LaTeX table rows: cardinality & accuracy per sentence type.
querier.print_global_latex()

Total  & 59.48 & 23.97 & 31.55 & 17.96 \\
1:1  & 0.00 & 0.00 & 0.00 & 0.00 \\
N:1  & 62.20 & 11.97 & 27.54 & 10.35 \\
N:M  & 55.14 & 43.05 & 37.92 & 30.06 \\


In [17]:
# Per-relation LaTeX table rows: id & name & cardinality & accuracy per sentence type.
querier.print_for_latex()

P138 & named after & N:1 & 18.75 & 18.75 & 6.25 & 25.00 \\
P276 & location & N:1 & 0.00 & 33.33 & 33.33 & 0.00 \\
P279 & subclass of & N:1 & 68.98 & 17.48 & 25.38 & 14.29 \\
P31 & instance of & N:1 & 53.12 & 0.00 & 18.75 & 3.12 \\
P361 & part of & N:1 & 57.02 & 6.17 & 31.28 & 5.96 \\
P101 & field of work & N:M & 74.07 & 37.04 & 0.00 & 3.70 \\
P527 & has part or parts & N:M & 54.42 & 43.38 & 39.59 & 31.23 \\
P937 & work location & N:M & 0.00 & 0.00 & 0.00 & 0.00 \\
