In [1]:
from transformers import logging
from transformers import BertModel, BertTokenizer
from transformers import pipeline
import torch
import torch.nn as nn
from torch.nn.functional import normalize, log_softmax

In [2]:
from TRExData.LamaTRExData import LamaTRExData
from ResultData.ScorePersistor import ScorePersistor
from relation_templates.templates import get_templates, get_relation_meta, relations, relation_names, get_relations_with_last_digit, nominalized_relations
from ModelHelpers.fill_mask_helpers import get_probability_from_pipeline_for_token
from tqdm import tqdm

In [3]:
# Silence transformers' warning/info chatter (e.g. checkpoint-loading notices); only errors are shown.
logging.set_verbosity(logging.ERROR)

In [4]:
# Sentence-complexity template labels. They are not scored in the current run and are
# kept only for the alternative KEYS selection noted below.
SIMPLE, COMPOUND, COMPLEX, COMCOM = 'simple', 'compound', 'complex', "compound-complex"

# Template variants scored for every (subject, object) pair.
# Alternative selection: [SIMPLE, COMPOUND, COMPLEX, COMCOM]
KEYS = ["active", "passive", "nominalized"]

In [5]:
# Hugging Face checkpoint name; `persister` records scores under this name.
# Other checkpoints used with this notebook: 'bert-base-cased', 'bert-large-cased',
# 'bert-large-uncased', 'bert-base-multilingual-cased', 'bert-base-multilingual-uncased'.
model = 'bert-base-uncased'

# NOTE(review): 28996 appears to be the bert-base-cased vocabulary size; the uncased
# BERT checkpoints use a larger vocabulary (30522). Confirm this matches `model`
# wherever it is used.
VOCABULARY_SIZE = 28996

# Placeholder inserted into the templates where the object should be predicted.
MASK = "[MASK]"

In [6]:
# Relations scored in this run: every nominalized relation except the first two.
# NOTE: this rebinds the `relations` name imported from relation_templates.templates.
relations = nominalized_relations[2:]
(len(relations), relations)

(6, ['P127', 'P136', 'P176', 'P178', 'P413', 'P1303'])

In [7]:
# Fill-mask pipeline for the checkpoint selected in the config cell. `persister` is
# created with the same `model` name, so the two must agree; previously this was
# hard-coded to 'bert-base-cased', which mislabelled every persisted score.
tokenizer = BertTokenizer.from_pretrained(model)
# top_k = full vocabulary of *this* checkpoint, so every vocabulary entry gets a score
# and the target token can always be looked up. A hard-coded size only fits one
# vocabulary (cased and uncased BERT differ).
# NOTE(review): uncased checkpoints return lower-cased token strings; confirm that
# get_probability_from_pipeline_for_token normalizes the object's case accordingly.
unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, top_k=tokenizer.vocab_size)

In [8]:
# Load the T-REx samples for the selected relations. The scoring loop iterates
# `TREx.data[relation]` as (subject, object) pairs.
TREx = LamaTRExData(relations=relations)
TREx.load()

In [9]:
# Sink for score rows (written via persist_score_row in the scoring loop), tagged with
# the configured `model` name. Presumably this is how runs of different checkpoints are
# told apart; confirm in ResultData.ScorePersistor.
persister = ScorePersistor(model=model)

In [10]:
@torch.no_grad()
def metric(sentence: str, token: str):
    """Score `token` as the filler of the masked slot in `sentence`.

    Args:
        sentence: Template sentence containing the mask placeholder.
        token: Candidate filler (the T-REx object) to look up in the predictions.

    Returns:
        The value ``get_probability_from_pipeline_for_token`` extracts for `token` from
        the ``unmasker`` predictions. This is presumably a probability; confirm in
        ModelHelpers.fill_mask_helpers.
    """
    predictions = unmasker(sentence)
    return get_probability_from_pipeline_for_token(predictions, token)

In [11]:
# Score every (subject, object) pair of every relation under each template variant in
# KEYS, and persist one row per (key, relation, subject, object).
# Persisting is best effort so a single bad row does not abort a multi-hour run.
# Failures are printed (including the template key, so the failing row is identifiable)
# and also collected in `failed_rows`, so they can be inspected or retried after the
# cell finishes instead of being lost in the progress-bar output.
failed_rows = []
for relation in relations:
    print(relation)
    for sub, obj in tqdm(TREx.data[relation]):
        sentences = get_templates(relation, sub, MASK, KEYS)
        for key in KEYS:
            score = metric(sentences[key], obj)
            try:
                persister.persist_score_row(key, relation, sub, obj, score)
            except Exception as e:
                print(e)
                print(key, relation, sub, obj)
                failed_rows.append((key, relation, sub, obj, score, e))

P127


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 687/687 [1:09:39<00:00,  6.08s/it]


P136


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [33:21<00:00,  2.15s/it]


P176


  0%|▎                                                                                                                                                                          | 2/982 [00:04<39:30,  2.42s/it]


KeyboardInterrupt: 

In [12]:
# Close the persistor once scoring is finished. Run this cell even if the scoring cell
# was interrupted (as in the recorded run), so the rows written so far are finalized.
# TODO confirm what close() does in ResultData.ScorePersistor.
persister.close()