In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os

from kb.include_all import ModelArchiveFromParams

from kb.knowbert_utils import KnowBertBatchifier
from allennlp.common import Params

# contains pretrained model, e.g. for Wordnet+Wikipedia
WORDNET_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wordnet_model.tar.gz"
WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_model.tar.gz"
WORDNET_WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz"

WORDNET_FOLDER = '../knowbert_wordnet_model/'
WORDNET_LINKER_FOLDER = WORDNET_FOLDER + 'entity_linker/'
WORDNET_LINKER_EMBEDDING_FILE = WORDNET_LINKER_FOLDER + 'wordnet_synsets_mask_null_vocab_embeddings_tucker_gensen.hdf5'
WORDNET_LINKER_ENTITY_FILE = WORDNET_LINKER_FOLDER + 'entities.jsonl'
WORDNET_LINKER_VOCAB_FILE = WORDNET_LINKER_FOLDER + 'wordnet_synsets_mask_null_vocab.txt'


WORDNET_MODEL_STATE_DICT_FILE = WORDNET_FOLDER+ 'weights.th'

CharacterTokenizer params None False None None
TokenCharactersIndexer params: entity <allennlp.data.tokenizers.character_tokenizer.CharacterTokenizer object at 0x7f63c1e79be0> None None 0




In [40]:
def create_test_set(original_model,batcher):

    test_sentences = [["Paris is located in [MASK].", "Michael [MASK] is a great music singer"],
                ["The Louvre contains the Mona Lisa", "The Amazon river is in Brazil"],
                "Donald Duck is a cartoon character",
                ["Hayao Miyazaki is the co-founder of Studio Ghibli and a renowned anime film maker",
                "The Alpine ibex is one of Switzerland's most famous animal along its grazing cows"]]
                
    original_model.eval()
    test_set = []
    for batch in batcher.iter_batches(test_sentences, verbose=False):
        test_case = {}
        test_case['input']=batch
        
        with torch.no_grad():
            test_case['expected']=original_model(**batch)
        test_set.append(test_case)

    return test_set

def model_correct(custom_model,test_set):
    custom_model.eval()
    for test_case in test_set:
        with torch.no_grad():
            custom_output = custom_model(**test_case["input"])

        expected_output = test_case["expected"]

        for key in test_case["expected"].keys():
            if(key in ['wiki','wordnet']):
                print(f"{key} entity_attention_probs are equal: {torch.equal(expected_output[key]['entity_attention_probs'],custom_output[key]['entity_attention_probs'])}")

                print(f"{key} output linking scores are equal: {torch.equal(expected_output[key]['linking_scores'],custom_output[key]['linking_scores'])}")
            else:
                if(key=='loss'):
                    print(f"{key} are equal : {expected_output[key]==custom_output[key]}")
                else:
                    print(f"{key} are equal : {torch.equal(expected_output[key],custom_output[key])}")

                    print(f"{key} are equal : {torch.equal(expected_output[key],custom_output[key])}")



### Wordnet Model

In [35]:
wordnet_batcher = KnowBertBatchifier(WORDNET_ARCHIVE)
params = Params({"archive_file": WORDNET_ARCHIVE})#Only contains a dictionnary with a single entry: archive_file:http://...
wordnet_original_model = ModelArchiveFromParams.from_params(params=params)



TokenCharactersIndexer params: entity <allennlp.data.tokenizers.character_tokenizer.CharacterTokenizer object at 0x7f63c1e79be0> None None 0
BertTokenizerAndCandidateGenerator params
{'wordnet': <kb.wordnet.WordNetCandidateMentionGenerator object at 0x7f62e7358b70>}
{'wordnet': <allennlp.data.token_indexers.token_characters_indexer.TokenCharactersIndexer object at 0x7f62e7358c50>}
bert-base-uncased
True
True
512
Vocab: <class 'allennlp.data.vocabulary.Vocabulary'>
Soldered kg: {'wordnet': SolderedKG(
  (entity_linker): EntityLinkingWithCandidateMentions(
    (loss): NLLLoss()
    (_log_softmax): LogSoftmax()
    (disambiguator): EntityDisambiguator(
      (span_extractor): SelfAttentiveSpanExtractor(
        (_global_attention): TimeDistributed(
          (_module): Linear(in_features=200, out_features=1, bias=True)
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (bert_to_kg_projector): Linear(in_features=768, out_features=200, bias=True)
      (projected_span_la

In [36]:
wordnet_original_model.soldered_kgs['wordnet'].entity_linker.disambiguator.entity_embeddings

WordNetAllEmbedding(
  (pos_embeddings): Embedding(117663, 25)
  (entity_embeddings): Embedding(117663, 2248, padding_idx=0)
  (proj_feed_forward): Linear(in_features=2273, out_features=200, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [37]:
from kb.custom_knowbert import CustomKnowBert
from kb.soldered_kg import CustomSolderedKG, CustomEntityLinkingWithCandidateMentions
from kb.custom_knowledge import CustomWordNetAllEmbedding

span_attention_config = {'hidden_size': 200, 'intermediate_size': 1024, 'num_attention_heads': 4, 'num_hidden_layers': 1}
span_encoder_config = {'hidden_size': 200, 'intermediate_size': 1024, 'num_attention_heads': 4, 'num_hidden_layers': 1}

#117662
#null_entity_id = model.vocab.get_token_index('@@NULL@@', "entity")
#117662
null_entity_id = 117662
entity_dim = 200

model_entity_embedder = CustomWordNetAllEmbedding(
                 embedding_file = WORDNET_LINKER_EMBEDDING_FILE,
                 entity_dim = entity_dim,
                 entity_file = WORDNET_LINKER_ENTITY_FILE,
                 vocab_file= WORDNET_LINKER_VOCAB_FILE,
                 entity_h5_key = "tucker_gensen",
                 dropout = 0.1,
                 pos_embedding_dim = 25,
                 include_null_embedding = False)

entity_embeddings = model_entity_embedder.entity_embeddings
null_embedding = torch.zeros(entity_dim) #From wordnet code

custom_entity_linker = CustomEntityLinkingWithCandidateMentions(
                 null_entity_id=null_entity_id,
                 entity_embedding = model_entity_embedder,
                 contextual_embedding_dim =768,
                 span_encoder_config = span_encoder_config,
                 margin = 0.2,
                 decode_threshold = 0.0,
                 loss_type = 'softmax',
                 max_sequence_length = 512,
                 dropout = 0.1,
                 output_feed_forward_hidden_dim = 100,
                 initializer_range = 0.02)

custom_wordnet_kg = CustomSolderedKG(entity_linker = custom_entity_linker, 
                            span_attention_config = span_attention_config,
                            should_init_kg_to_bert_inverse = False,
                            freeze = False)

custom_soldered_kgs = {'wordnet':custom_wordnet_kg}

span_extractor_global_attention_old_name = "wordnet_soldered_kg.entity_linker.disambiguator.span_extractor._global_attention._module.weight"
span_extractor_global_attention_bias_old_name = "wordnet_soldered_kg.entity_linker.disambiguator.span_extractor._global_attention._module.bias"
state_dict_map = {span_extractor_global_attention_old_name:span_extractor_global_attention_old_name.replace("._module",""),
                span_extractor_global_attention_bias_old_name: span_extractor_global_attention_bias_old_name.replace("._module","")}

custom_wordnet_model = CustomKnowBert(soldered_kgs = custom_soldered_kgs,
                                soldered_layers ={"wordnet": 9},
                                bert_model_name = "bert-base-uncased",
                                mode=None,state_dict_file=WORDNET_MODEL_STATE_DICT_FILE,
                                strict_load_archive=True,
                                remap_segment_embeddings = None,
                                state_dict_map = state_dict_map)


In [42]:
test_set = create_test_set(wordnet_original_model,wordnet_batcher)
torch.save(test_set,'knowbert_wordnet_model_test')
test_set = torch.load("knowbert_wordnet_model_test")

model_correct(custom_wordnet_model,test_set)

wordnet entity_attention_probs are equal: True
wordnet output linking scores are equal: True
loss are equal : True
pooled_output are equal : True
pooled_output are equal : True
contextual_embeddings are equal : True
contextual_embeddings are equal : True




### Wiki Model

In [8]:
wiki_batcher = KnowBertBatchifier(WIKI_ARCHIVE)
params = Params({"archive_file": WIKI_ARCHIVE})
wiki_original_model = ModelArchiveFromParams.from_params(params=params)

WikiCandidateMentionGenerator params: None None True False None
duplicate_mentions_cnt:  6777
end of p_e_m reading. wall time: 1.357266374429067  minutes
p_e_m_errors:  0
incompatible_ent_ids:  0
TokenCharactersIndexer params: entity <allennlp.data.tokenizers.character_tokenizer.CharacterTokenizer object at 0x7f63c1e79be0> None None 0
BertTokenizerAndCandidateGenerator params
{'wiki': <kb.wiki_linking_util.WikiCandidateMentionGenerator object at 0x7f62c9743fd0>}
{'wiki': <allennlp.data.token_indexers.token_characters_indexer.TokenCharactersIndexer object at 0x7f62c9743eb8>}
bert-base-uncased
True
True
512
Vocab: <class 'allennlp.data.vocabulary.Vocabulary'>
Soldered kg: {'wiki': SolderedKG(
  (entity_linker): EntityLinkingWithCandidateMentions(
    (loss): MarginRankingLoss()
    (disambiguator): EntityDisambiguator(
      (span_extractor): SelfAttentiveSpanExtractor(
        (_global_attention): TimeDistributed(
          (_module): Linear(in_features=300, out_features=1, bias=True)
 

In [45]:
type(wiki_original_model.soldered_kgs['wiki'].entity_linker.disambiguator.entity_embeddings)

allennlp.modules.token_embedders.embedding.Embedding

In [46]:
from allennlp.modules.token_embedders.embedding import Embedding

In [64]:
wiki_original_model.soldered_kgs['wiki'].entity_linker.disambiguator.entity_embeddings

Embedding()

In [80]:
from kb.custom_tokenizer.vocabulary import Vocabulary
vocabulary = Vocabulary.from_files("https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/vocabulary_wiki.tar.gz")

In [84]:
line = "Hello I am good"

line.split(" ",1)[0]

'Hello'

In [85]:
from tqdm import tqdm
import numpy as np

def read_embeddings_from_text_file(gzip_filename: str,
                                    embedding_dim: int,
                                    vocab: Vocabulary,
                                    namespace: str = "tokens") -> torch.FloatTensor:
    """
    Read pre-trained word vectors from an eventually compressed text file, possibly contained
    inside an archive with multiple files. The text file is assumed to be utf-8 encoded with
    space-separated fields: [word] [dim 1] [dim 2] ...

    Lines that contain more numerical tokens than ``embedding_dim`` raise a warning and are skipped.

    The remainder of the docstring is identical to ``_read_pretrained_embeddings_file``.
    """
    tokens_to_keep = set(vocab.get_index_to_token_vocabulary(namespace).values())
    vocab_size = vocab.get_vocab_size(namespace)
    embeddings = {}

    # First we read the embeddings from the file, only keeping vectors for the words we need.
    print("Reading pretrained embeddings from file")

    with open(gzip_filename) as embeddings_file:
        for line in tqdm(embeddings_file):
            print(type(line))
            print(line)
            token = line.split(' ', 1)[0]
            if token in tokens_to_keep:
                fields = line.rstrip().split(' ')
                if len(fields) - 1 != embedding_dim:
                    # Sometimes there are funny unicode parsing problems that lead to different
                    # fields lengths (e.g., a word with a unicode space character that splits
                    # into more than one column).  We skip those lines.  Note that if you have
                    # some kind of long header, this could result in all of your lines getting
                    # skipped.  It's hard to check for that here; you just have to look in the
                    # embedding_misses_file and at the model summary to make sure things look
                    # like they are supposed to.
                    print("Found line with wrong number of dimensions (expected: %d; actual: %d): %s",
                                   embedding_dim, len(fields) - 1, line)
                    continue

                vector = np.asarray(fields[1:], dtype='float32')
                embeddings[token] = vector

    if not embeddings:
        raise Exception("No embeddings of correct dimension found; you probably "
                                 "misspecified your embedding_dim parameter, or didn't "
                                 "pre-populate your Vocabulary")

    all_embeddings = np.asarray(list(embeddings.values()))
    embeddings_mean = float(np.mean(all_embeddings))
    embeddings_std = float(np.std(all_embeddings))
    # Now we initialize the weight matrix for an embedding layer, starting with random vectors,
    # then filling in the word vectors we just read.
    print("Initializing pre-trained embedding layer")
    embedding_matrix = torch.FloatTensor(vocab_size, embedding_dim).normal_(embeddings_mean,
                                                                            embeddings_std)
    num_tokens_found = 0
    index_to_token = vocab.get_index_to_token_vocabulary(namespace)
    for i in range(vocab_size):
        token = index_to_token[i]

        # If we don't have a pre-trained vector for this word, we'll just leave this row alone,
        # so the word has a random initialization.
        if token in embeddings:
            embedding_matrix[i] = torch.FloatTensor(embeddings[token])
            num_tokens_found += 1
        else:
            print("Token %s was not found in the embedding file. Initialising randomly.", token)

    print("Pretrained embeddings were found for %d out of %d tokens",
                num_tokens_found, vocab_size)

    return embedding_matrix


import zipfile
import gzip
import shutil
compressed_embedding_file = cached_path("https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/entities_glove_format.gz")
embedding_file = compressed_embedding_file+'+unzipped'
with gzip.open(compressed_embedding_file, 'rb') as f_in:
    with open(embedding_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)

with open(embedding_file) as f:
    print(f.readline())
# print(embedding_file)
# tempdir = embedding_file+"_extracted"

# print(zipfile.is_zipfile(embedding_file))
# print(tarfile.is_tarfile(embedding_file))

# with gzip.open(embedding_file) as f:
#     print(f.readline())

# os.makedirs(tempdir,exist_ok=True)
# print("Extracting model archives")
# with tarfile.open(embedding_file, 'r:gz') as archive:
#     archive.extractall(tempdir)

#Remove temporary directory
# shutil.rmtree(tempdir)

read_embeddings_from_text_file(embedding_file,embedding_dim=300,vocab=vocabulary)
    

Reading pretrained embeddings from file


0it [00:00, ?it/s]

<class 'bytes'>
b'A_Forest 0.043461049344163 -0.083031912273826 0.015892639209858 -0.019632200966898 -0.030101784401457 0.037633704590477 -0.072133985662249 -0.10074446004005 0.069966410675443 -0.049433579831659 -0.039654540397471 -0.03242789842469 0.018648329481155 -0.052262736550407 -0.061408946042189 -0.070215945512216 0.052235180579016 -0.0016973801537844 -0.05759637604839 -0.015419615833859 0.03305185472591 0.11210366698168 -0.10238323753775 -0.021602179112486 -0.016465255424595 0.079950300090811 -0.093411107130336 -0.0092436798464461 0.054513717058838 0.055219987371663 -0.030644296546327 0.072715689722356 -0.0385469781022 -0.068647345030748 -0.022262602279184 0.080115352797101 -0.021307244164405 -0.022041506307571 0.014280509644071 0.041370047696809 0.026127087916798 -0.023551068390419 -0.051852362310526 0.054481135671007 0.07898001101303 0.012865536403939 -0.013461836150934 -0.062267074083521 0.023079185884535 0.049899390114566 0.062010714515155 0.0089464231112669 0.013683554246




TypeError: a bytes-like object is required, not 'str'

In [65]:
from kb.custom_knowbert import CustomKnowBert
from kb.soldered_kg import CustomSolderedKG, CustomEntityLinkingWithCandidateMentions
from kb.custom_knowledge import CustomWordNetAllEmbedding
from allennlp.common.file_utils import cached_path
import torch
import os
import tarfile
import shutil

wiki_state_dict_file = cached_path(WIKI_ARCHIVE)

# span_attention_config = {'hidden_size': 200, 'intermediate_size': 1024, 'num_attention_heads': 4, 'num_hidden_layers': 1}


#null_entity_id = model.vocab.get_token_index('@@NULL@@', "entity")

entity_dim = 300


custom_embedding = wiki_original_model.soldered_kgs['wiki'].entity_linker.disambiguator.entity_embeddings
null_entity_id = vocabulary.get_token_index('@@NULL@@', "entity")
span_encoder_config = {'hidden_size': 300, 'intermediate_size': 1024, 'num_attention_heads': 4, 'num_hidden_layers': 1}

custom_entity_linker = CustomEntityLinkingWithCandidateMentions(
                 null_entity_id=null_entity_id,
                 entity_embedding = custom_embedding,
                 contextual_embedding_dim =768,
                 span_encoder_config = span_encoder_config,
                 margin = 0.2,
                 decode_threshold = 0.0,
                 loss_type = 'softmax',
                 max_sequence_length = 512,
                 dropout = 0.1,
                 output_feed_forward_hidden_dim = 100,
                 initializer_range = 0.02)

custom_entity_linker = wiki_original_model.soldered_kgs['wiki'].entity_linker
span_attention_config = {'hidden_size': 300, 'intermediate_size': 1024, 'num_attention_heads': 4, 'num_hidden_layers': 1}

custom_wiki_kg = CustomSolderedKG(entity_linker = custom_entity_linker, 
                            span_attention_config = span_attention_config,
                            should_init_kg_to_bert_inverse = False,
                            freeze = False)

custom_soldered_kgs = {'wiki':custom_wiki_kg}
#custom_soldered_kgs = {'wiki':wiki_original_model.soldered_kgs['wiki']}

span_extractor_global_attention_old_name = "wiki_soldered_kg.entity_linker.disambiguator.span_extractor._global_attention._module.weight"
span_extractor_global_attention_bias_old_name = "wiki_soldered_kg.entity_linker.disambiguator.span_extractor._global_attention._module.bias"
state_dict_map = {span_extractor_global_attention_old_name:span_extractor_global_attention_old_name.replace("._module",""),
                span_extractor_global_attention_bias_old_name: span_extractor_global_attention_bias_old_name.replace("._module","")}

#Temporary, for using custom_soldered_k
state_dict_map = None

tempdir = wiki_state_dict_file+"_extracted"
os.makedirs(tempdir,exist_ok=True)

print("Extracting model archives")
with tarfile.open(wiki_state_dict_file, 'r:gz') as archive:
    archive.extractall(tempdir)

weights_file = tempdir+'/weights.th'


custom_wiki_model = CustomKnowBert(soldered_kgs = custom_soldered_kgs,
                                soldered_layers ={"wiki": 9},
                                bert_model_name = "bert-base-uncased",
                                mode=None,state_dict_file=weights_file,
                                strict_load_archive=True,
                                remap_segment_embeddings = None,
                                state_dict_map = state_dict_map)


Extracting model archives


In [66]:
test_set = create_test_set(wiki_original_model,wiki_batcher)
torch.save(test_set,'knowbert_wiki_model_test')
test_set = torch.load("knowbert_wiki_model_test")

custom_model = custom_wiki_model
model_correct(custom_model,test_set)

wiki entity_attention_probs are equal: True
wiki output linking scores are equal: True
loss are equal : True
pooled_output are equal : True
pooled_output are equal : True
contextual_embeddings are equal : True
contextual_embeddings are equal : True




### Check equality

In [13]:
def model_equal(original_model,custom_model,batches):

    for (name_model,param_model), (name_custom_model,param_custom_model) in zip(original_model.soldered_kgs['wordnet'].named_parameters(),custom_model.soldered_kgs['wordnet'].named_parameters()):
        # if(name_model!=name_custom_model):
        #     print(f"model_name:{name_model} is not equal to custom_model_name:{name_custom_model}")
        if(not torch.equal(param_model,param_custom_model)):
            print(f"Tensor values are not equal for custom_param:{name_custom_model} and model_param{name_model}")

    original_model.eval()
    custom_model.eval()
    # batcher takes raw untokenized sentences
    # and yields batches of tensors needed to run KnowBert
    for batch in batches:
        with torch.no_grad():
            original_output = original_model(**batch)
            custom_output = custom_model(**batch)

        for key in original_output.keys():
            if(key in ['wiki','wordnet']):
                print(f"wordnet entity_attention_probs are equal: {torch.equal(original_output[key]['entity_attention_probs'],custom_output[key]['entity_attention_probs'])}")

                print(f"Output linking scores are equal: {torch.equal(original_output[key]['linking_scores'],custom_output[key]['linking_scores'])}")
            else:
                print(f"Loss are equal : {original_output['loss']==custom_output['loss']}")

                print(f"Pooled outputs are equal : {torch.equal(original_output['pooled_output'],custom_output['pooled_output'])}")

                print(f"Contextual embeddings are equal : {torch.equal(original_output['contextual_embeddings'],custom_output['contextual_embeddings'])}\n")

In [26]:
batch = torch.load("test_batch")
custom_output = custom_model(**batch)
expected_output = torch.load("expected_output")


print(f"wordnet entity_attention_probs are equal: {torch.equal(expected_output['wordnet']['entity_attention_probs'],custom_output['wordnet']['entity_attention_probs'])}")

print(f"Output linking scores are equal: {torch.equal(expected_output['wordnet']['linking_scores'],custom_output['wordnet']['linking_scores'])}")

print(f"Loss are equal : {expected_output['loss']==custom_output['loss']}")

print(f"Pooled outputs are equal : {torch.equal(expected_output['pooled_output'],custom_output['pooled_output'])}")

print(f"Contextual embeddings are equal : {torch.equal(expected_output['contextual_embeddings'],custom_output['contextual_embeddings'])}")



FileNotFoundError: [Errno 2] No such file or directory: 'test_batch'

In [None]:
sentences = ["Paris is located in France.", "Michael Jackson is a great music singer"]
# batcher takes raw untokenized sentences
# and yields batches of tensors needed to run KnowBert
for i,batch in enumerate(batcher.iter_batches(sentences, verbose=True)):

    print(f"\nInput\n")
    print(f"Batch: {batch.keys()}") #Batch contains {tokens,segment_ids,candidates}
    #tokens: Tensor of tokens indices (used to idx an embedding) => because a batch contains multiple
    #sentences with varying # of tokens, all tokens tensors are padded with zeros 
    #shape: (batch_size (#sentences), max_seq_len)
    #print(batch['tokens'])#dict with only 'tokens'
    print(f"Tokens shape {batch['tokens']['tokens'].shape}")
    #Defines the segments_ids (0 for first segment and 1 for second), can be used for NSP
    #shape: (batch_size,max_seq_len)
    print(f"Segment ids shape: {batch['segment_ids'].shape}")

    #Dict with only wordnet
    #Candidates: stores for multiple knowledge base, the entities detected using this knowledge base
    wordnet_kb = batch['candidates']['wordnet']
    print(f"Wordnet kb: {wordnet_kb.keys()}")
    
    #Stores for each detected entities, a list of candidate KB entities that correspond to it
    #Priors: correctness probabilities estimated by the entity linker (sum to 1 (or 0 if padding) on axis 2)
    #Adds 0 padding to axis 1 when there is less detected entities in the sentence than in the max sentence
    #Adds 0 padding to axis 2 when there is less detected KB entities for an entity in the sentence than in the max candidate KB entities entity
    #shape:(batch_size, max # detected entities, max # KB candidate entities)
    print(f"Candidate entity_priors shape: {wordnet_kb['candidate_entity_priors'].shape}")
    #Ids of the KB candidate entities + 0 padding on axis 1 or 2 if necessary
    #shape: (batch_size, max # detected entities, max # KB candidate entities)
    print(f"Candidate entities ids shape: {wordnet_kb['candidate_entities']['ids'].shape}")
    #Spans of which sequence of tokens correspond to an entity in the sentence, eg: [1,2] for Michael Jackson (both bounds are included)
    #Padding with [-1,-1] when no more detected entities
    #shape: (batch_size, max # detected entities, 2)
    print(f"Candidate span shape: {wordnet_kb['candidate_spans'].shape}")

    #For each sentence entity, indicate to which segment ids it corresponds to
    #shape: (batch_size, max # detected entities)
    print(f"Candidate segments_ids shape: {wordnet_kb['candidate_segment_ids'].shape}")

    #model(**batch) <=> model(tokens = batch['tokens'],segment_ids=batch['segment_ids'],candidates=batch['candidates']) 
    model_output = model(**batch)
    
    print(f"\nOutput\n")
    print(f"Model output keys: {model_output.keys()}")
    print(f"Output wordnet keys: {model_output['wordnet'].keys()}")
    #Span attention layers scores for wordnet KB
    #shape: (batch_size,?,max_seq_len,max # detected entities)
    print(f"Output wordnet entity_attention_probs shape: {model_output['wordnet']['entity_attention_probs'].shape}")
    #Entity linker score for each text entity and possible KB entity, -1.0000e+04 padding in case of no score
    #shape: (batch_size, max # detected entities, max # KB candidate entities)
    print(f"Output wordnet linking_scores shape: {model_output['wordnet']['linking_scores'].shape}")
    
    #Scalar indicating loss over this batch (0 if not training?)
    print(f"Output loss: {model_output['loss']}")

    #Final CLS embedding for each sentence of batch
    # shape: (batch_size, hidden_size) 
    print(f"Pooled output shape: {model_output['pooled_output'].shape}")

    #For each tokens, its final embeddings
    #Important!!!, still predicts something for 0 padded tokens => ignore (or 0 padding <=> MASK???)
    print(f"Contextual embeddings: {model_output['contextual_embeddings'].shape}")

Paris is located in France.
['[CLS]', 'paris', 'is', 'located', 'in', 'france', '.', '[SEP]']
Michael Jackson is a great music singer
['[CLS]', 'michael', 'jackson', 'is', 'a', 'great', 'music', 'singer', '[SEP]']

Input

Batch: dict_keys(['tokens', 'segment_ids', 'candidates'])
Tokens shape torch.Size([2, 9])
Segment ids shape: torch.Size([2, 9])
Wordnet kb: dict_keys(['candidate_entity_priors', 'candidate_entities', 'candidate_spans', 'candidate_segment_ids'])
Candidate entity_priors shape: torch.Size([2, 8, 14])
Candidate entities ids shape: torch.Size([2, 8, 14])
Candidate span shape: torch.Size([2, 8, 2])
Candidate segments_ids shape: torch.Size([2, 8])

Output

Model output keys: dict_keys(['wordnet', 'loss', 'pooled_output', 'contextual_embeddings'])
Output wordnet keys: dict_keys(['entity_attention_probs', 'linking_scores'])
Output wordnet entity_attention_probs shape: torch.Size([2, 4, 9, 8])
Output wordnet linking_scores shape: torch.Size([2, 8, 14])
Output loss: 0.0
Pooled o



In [None]:
    #TODO: see how to add masking => 0 idx tokens embedding?
    #TODO: See how to extract from final embeddings the actual predicted tokens
    #TODO: copy paste all allennlp dependencies in an allennlp.py file that contains all classes => get rid of dependency

In [None]:
for name, param in model.named_parameters():
    print(f"{name}:{param.shape}")

pretrained_bert.bert.embeddings.word_embeddings.weight:torch.Size([30522, 768])
pretrained_bert.bert.embeddings.position_embeddings.weight:torch.Size([512, 768])
pretrained_bert.bert.embeddings.token_type_embeddings.weight:torch.Size([2, 768])
pretrained_bert.bert.embeddings.LayerNorm.weight:torch.Size([768])
pretrained_bert.bert.embeddings.LayerNorm.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.output.dense.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.la