In [1]:
from transformers import logging
from transformers import BertModel, BertTokenizer
from transformers import pipeline
import torch
import torch.nn as nn
from torch.nn.functional import normalize, log_softmax

In [2]:
# Set TOKENIZERS_PARALLELISM to "false": this is the switch the HuggingFace
# tokenizers library reads to turn off its internal parallelism.
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Only let transformers report messages at error level.
logging.set_verbosity_error()

In [3]:
from LamaTRExData import LamaTRExData
from DistributionComparison.DistributionComparisonResult import DistributionComparisonResult 
from DistributionComparison.SentenceTypologyEmbeddingDistributionComparison import SentenceTypologyEmbeddingDistributionComparison 
from ModelHelpers.fill_mask_helpers import get_probability_from_pipeline_for_token
from relation_templates.templates import get_templates, get_relation_meta, relations, relation_names
from PipelineCacheWrapper.PipelineCacheWrapper import PipelineCacheWrapper

In [4]:
# Special marker strings. MASK is the one handed to the comparer below;
# UNKNOWN_TOKEN and USELESS_WORD are not referenced in the visible cells.
MASK = "[MASK]"
UNKNOWN_TOKEN = "[UNK]"
USELESS_WORD = "Erftwerk"

# Label keys; presumably field names of the T-REx records -- TODO confirm.
SUBJ_LABEL = "sub-label"
OBJ_LABEL = "obj-label"

# Used as `top_k` when the fill-mask pipeline is created (presumably the
# vocabulary size of the model, so every token gets a score -- confirm).
VOCABULARY_SIZE = 28996

# Word list passed to the comparer; left empty here.
words = list()

In [5]:
# Relation ids evaluated in this run. Earlier selections, kept for reference:
#   ["P140", "P127", "P36", "P159"]
#   ["P1303"]
# NOTE: this rebinds the `relations` name that was imported from
# relation_templates.templates; every later cell sees this list instead.
relations = [
    "P19",
    "P413",
    "P159",
    "P103",
]
len(relations)

4

In [6]:
%%time
model = PipelineCacheWrapper('fill-mask', model="bert-base-cased", top_k=VOCABULARY_SIZE)

CPU times: user 1min 4s, sys: 1min 15s, total: 2min 19s
Wall time: 4min 8s


In [7]:
# Create the T-REx data wrapper for the selected relations, then load it.
TREx = LamaTRExData(relations=relations)
TREx.load()

In [8]:
@torch.no_grad()
def metric(left_sentence: str, right_sentence: str, token: str):
    """Score `token` in two sentences and return the left-minus-right difference.

    Each sentence is passed to the module-level `model`, and
    `get_probability_from_pipeline_for_token` extracts the value for `token`
    from the respective output. Runs with gradient tracking disabled.

    Args:
        left_sentence: Sentence whose score is the minuend.
        right_sentence: Sentence whose score is the subtrahend.
        token: Token looked up in both model outputs.

    Returns:
        The value obtained for `left_sentence` minus the one for `right_sentence`.
    """
    # Left side is evaluated completely before the right side.
    left_output = model(left_sentence)
    score_left = get_probability_from_pipeline_for_token(left_output, token)

    right_output = model(right_sentence)
    score_right = get_probability_from_pipeline_for_token(right_output, token)

    difference = score_left - score_right
    return difference

In [9]:
# Wire the comparison together: relation ids, template/meta lookups, the
# (empty) word list, the metric defined above, the mask string, and the
# label given to the metric.
Comparer = SentenceTypologyEmbeddingDistributionComparison(
    relations,
    get_templates,
    words,
    metric,
    MASK,
    get_relation_meta,
    "probability difference",
)

In [10]:
# Second argument of Comparer.compare below.
nr_of_compared_triplets = 1_000

In [None]:
# Run the comparison on the loaded T-REx data. The recorded progress bar
# shows this cell still running after more than 40 minutes.
Comparer.compare(TREx.data, nr_of_compared_triplets)

 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                         | 724/944 [41:29<06:58,  1.90s/it]

In [None]:
%%time
# Ask the pipeline wrapper to save its cache (presumably so later runs can
# reuse the model outputs -- confirm in PipelineCacheWrapper).
model.save_to_cache()

In [None]:
# Sentence-type pairs passed to the per-relation reports; "simple" is the
# first element of every pair.
pairs_to_compare = [
    ("simple", "compound-complex"),
    ("simple", "complex"),
    ("simple", "compound"),
]
# Single-pair alternative: [("simple", "compound")]

In [None]:
# Print the subject-label comparison for every configured relation, in the
# order given by `relations`. This replaces four copy-pasted cells whose
# hard-coded ids had to be kept in sync with `relations` by hand.
for relation in relations:
    Comparer.print_comparison(relation, SUBJ_LABEL, pairs_to_compare)

In [None]:
# Show the comparison summary across relations (output format is defined by
# the project's comparer class -- not visible here).
Comparer.print_global_comparison()

In [None]:
# Same global summary, presumably formatted for pasting into LaTeX -- confirm
# against the comparer implementation.
Comparer.print_global_for_latex()

In [None]:
import numpy as np

In [None]:
# Mean metric value for "simple" vs. "compound", per configured relation.
# The previous hard-coded "P131" is not among the relations of this run, so
# looking it up was expected to fail; SUBJ_LABEL replaces the duplicated
# 'sub-label' literal. Assumes results are keyed by every id in `relations`
# -- TODO confirm.
{
    relation: np.mean(Comparer.results.results[relation][SUBJ_LABEL]["simple"]["compound"])
    for relation in relations
}

In [None]:
# Mean metric value for "simple" vs. "compound-complex", per configured
# relation. The previous hard-coded "P131" is not among the relations of this
# run, so looking it up was expected to fail; SUBJ_LABEL replaces the
# duplicated 'sub-label' literal. Assumes results are keyed by every id in
# `relations` -- TODO confirm.
{
    relation: np.mean(Comparer.results.results[relation][SUBJ_LABEL]["simple"]["compound-complex"])
    for relation in relations
}