In [1]:
from transformers import logging
from transformers import BertModel, BertTokenizer
from transformers import pipeline
import torch
import torch.nn as nn
from torch.nn.functional import normalize, log_softmax

In [2]:
from TRExData.LamaTRExData import LamaTRExData
from ResultData.ScorePersistor import ScorePersistor
from relation_templates.templates import get_templates, get_relation_meta, relations, relation_names, get_relations_with_last_digit
from ModelHelpers.fill_mask_helpers import get_probability_from_pipeline_for_token
from tqdm import tqdm

In [3]:
# Only surface transformers errors; info/warning chatter would drown the tqdm progress output.
logging.set_verbosity(logging.ERROR)

In [4]:
# Sentence-structure categories. Each one is used as a key into the template dict
# produced by get_templates() when scoring.
SIMPLE, COMPOUND, COMPLEX, COMCOM = "simple", "compound", "complex", "compound-complex"

# Order in which the categories are scored for every (subject, object) pair.
KEYS = [SIMPLE, COMPOUND, COMPLEX, COMCOM]

In [5]:
#model='bert-base-uncased'
model='bert-large-cased'
#model='bert-base-multilingual-cased'
#model='bert-base-multilingual-uncased'
VOCABULARY_SIZE = 28996
MASK = "[MASK]"

In [6]:
# NOTE(review): this rebinds the `relations` name imported from
# relation_templates.templates for the rest of the notebook.
# Keep the relations returned for digit 9, dropping the first two. Presumably
# those were already scored in an earlier run (TODO confirm and document).
relations = get_relations_with_last_digit(9)
relations = relations[2:]
len(relations)

3

In [7]:
# Fill-mask pipeline for the configured checkpoint. It is built from `model` rather
# than a hard-coded name, so the model that produces the scores is the same one
# passed to ScorePersistor(model=model) below.
# top_k=VOCABULARY_SIZE asks the pipeline to return every vocabulary token.
unmasker = pipeline('fill-mask', model=model, top_k=VOCABULARY_SIZE)

# Fail fast on a size mismatch. A smaller top_k silently truncates the ranking,
# and a larger one may make the top-k selection error out mid-run.
assert unmasker.model.config.vocab_size == VOCABULARY_SIZE, (
    f"VOCABULARY_SIZE={VOCABULARY_SIZE} but {model} has "
    f"{unmasker.model.config.vocab_size} output tokens"
)

In [8]:
# Load the LAMA T-REx data for the selected relations. TREx.data[relation] is
# iterated as (subject, object) pairs by the scoring loop.
TREx = LamaTRExData(relations=relations)
TREx.load()

In [9]:
# Score sink for this run, constructed with the configured checkpoint name.
# Rows are written with persist_score_row() during scoring, and the persister
# is closed with close() at the end of the notebook.
persister = ScorePersistor(model=model)

In [10]:
@torch.no_grad()
def metric(sentence: str, token: str):
    """Score `token` as the filler for the mask in `sentence`.

    `sentence` is expected to contain MASK. It is run through the fill-mask
    pipeline, and the value for `token` is extracted from the predictions by
    get_probability_from_pipeline_for_token. Gradient tracking is disabled
    because this is inference only.
    """
    predictions = unmasker(sentence)
    return get_probability_from_pipeline_for_token(predictions, token)

In [11]:
# Score every (subject, object) pair under each sentence-structure template and
# persist one row per (key, relation, subject, object).
# A persistence failure is reported and recorded in `failed_rows` instead of
# aborting the multi-hour run, so failed rows can be inspected afterwards.
failed_rows = []
for relation in relations:
    # Label each bar with its relation so the per-relation progress is identifiable.
    for sub, obj in tqdm(TREx.data[relation], desc=str(relation)):
        sentences = get_templates(relation, sub, MASK)
        for key in KEYS:
            score = metric(sentences[key], obj)
            try:
                persister.persist_score_row(key, relation, sub, obj, score)
            except Exception as e:
                print(e)
                print(relation, sub, obj)
                failed_rows.append((key, relation, sub, obj, score, e))

if failed_rows:
    print(f"{len(failed_rows)} rows failed to persist; see `failed_rows`")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 964/964 [47:12<00:00,  2.94s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 944/944 [45:55<00:00,  2.92s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 892/892 [43:09<00:00,  2.90s/it]


In [12]:
# Release the score persister once every relation has been scored.
persister.close()