In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Leftover kernel sanity check; it has no effect on later cells.
print("hello")

hello


In [2]:
import torch
import os

from kb.include_all import ModelArchiveFromParams

from kb.knowbert_utils import KnowBertBatchifier
from allennlp.common import Params

# URLs of the pretrained KnowBert model archives (.tar.gz), one per
# knowledge-base variant: WordNet only, Wikipedia only, and Wikipedia + WordNet.
WORDNET_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wordnet_model.tar.gz"
WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_model.tar.gz"
WORDNET_WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz"



In [12]:
# Wrap the archive location in a Params object; the wrapped dictionary has a
# single entry, "archive_file", holding the archive URL.
params = Params({"archive_file": WORDNET_ARCHIVE})
# Build the model from those params.
model = ModelArchiveFromParams.from_params(params=params)
# Build the batcher from the same archive; it is used below to turn raw
# sentences into batches the model accepts.
batcher = KnowBertBatchifier(WORDNET_ARCHIVE)



In [14]:
# Rebuild the Params wrapper and display the plain dictionary it holds.
params = Params({"archive_file": WORDNET_ARCHIVE})
params.params

{'archive_file': 'https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wordnet_model.tar.gz'}

In [4]:
# Show the concrete classes of the objects created in the cells above.
print(type(model))
print(type(params))
print(type(batcher))

<class 'kb.knowbert.KnowBert'>
<class 'allennlp.common.params.Params'>
<class 'kb.knowbert_utils.KnowBertBatchifier'>


In [34]:
# Read the KnowBert WordNet configuration from a local config.json.
batcher_config = Params.from_file('../knowbert_wordnet_model/config.json')
# NOTE(review): the result of this call is discarded (it is not the last
# expression of the cell, so it is not displayed either) -- presumably left
# over from exploration.
batcher_config['dataset_reader'].as_dict()

# Earlier, generic way of locating the same sub-config, kept for reference:
# candidate_generator_params = _find_key(
#             batcher_config['dataset_reader'].as_dict(), 'tokenizer_and_candidate_generator'
#         )

# Tokenizer / candidate-generator settings, reached through the explicit key
# path inside the dataset-reader config and converted to a plain dict.
candidate_generator_params = batcher_config['dataset_reader']['dataset_readers']['language_modeling']['base_reader']['tokenizer_and_candidate_generator'].as_dict()

# Display the extracted settings.
candidate_generator_params

{'bert_model_type': 'bert-base-uncased',
 'do_lower_case': True,
 'entity_candidate_generators': {'wordnet': {'entity_file': 'https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl',
   'type': 'wordnet_mention_generator'}},
 'entity_indexers': {'wordnet': {'namespace': 'entity',
   'tokenizer': {'type': 'word', 'word_splitter': {'type': 'just_spaces'}},
   'type': 'characters_tokenizer'}},
 'type': 'bert_tokenizer_and_candidate_generator'}

In [None]:
from typing import Dict, List, Sequence, Union
import copy

import numpy as np

from allennlp.data.fields import Field, TextField, ListField, SpanField, ArrayField, LabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data import Token
from pytorch_pretrained_bert.tokenization import BertTokenizer, BasicTokenizer

from kb.dict_field import DictField
from allennlp.common.registrable import Registrable

from kb.common import MentionGenerator

# Special word pieces added around the tokenized text: start_token opens the
# sequence and sep_token closes each sentence.
start_token = "[CLS]"
sep_token = "[SEP]"

def get_empty_candidates():
    """
    Build the placeholder candidate set used when no real entity is present.

    A fresh dictionary is created on every call. It contains exactly one
    candidate: the span ``[-1, -1]``, the single entity ``"@@PADDING@@"`` and
    a prior of ``1.0``.
    """
    padding_candidates = dict(
        candidate_spans=[[-1, -1]],
        candidate_entities=[["@@PADDING@@"]],
        candidate_entity_priors=[[1.0]],
    )
    return padding_candidates

def truncate_sequence_pair(word_piece_tokens_a, word_piece_tokens_b, max_word_piece_sequence_length):
    """
    Shorten two word-piece sequences in place until they fit a joint budget.

    Both arguments are lists with one entry per word, each entry being the
    list of word pieces of that word. While the combined number of word
    pieces exceeds ``max_word_piece_sequence_length``, the last word is
    removed from the sequence that currently has more word pieces; on a tie
    the first sequence is shortened. Nothing is returned.
    """
    total_a = sum(map(len, word_piece_tokens_a))
    total_b = sum(map(len, word_piece_tokens_b))
    while total_a + total_b > max_word_piece_sequence_length:
        if total_b > total_a:
            total_b -= len(word_piece_tokens_b.pop())
        else:
            total_a -= len(word_piece_tokens_a.pop())


class TokenizerAndCandidateGenerator(Registrable):
    """
    Empty registry base class. Concrete tokenizer / candidate-generator
    implementations in this file attach themselves to it by name through
    ``TokenizerAndCandidateGenerator.register(...)``.
    """
    pass

@TokenizerAndCandidateGenerator.register("bert_tokenizer_and_candidate_generator")
class BertTokenizerAndCandidateGenerator(Registrable):
    """
    Tokenizes one sentence or a sentence pair into BERT word pieces and, for
    every configured mention generator, produces entity candidates whose
    spans are expressed as word-piece positions in the final
    ``[CLS] ... [SEP]`` sequence.
    """
    def __init__(self,
                 entity_candidate_generators: Dict[str, MentionGenerator],
                 entity_indexers: Dict[str, TokenIndexer],
                 bert_model_type: str,
                 do_lower_case: bool,
                 whitespace_tokenize: bool = True,
                 max_word_piece_sequence_length: int = 512) -> None:
        """
        Note: the fields need to be used with a pre-generated allennlp vocabulary
        that contains the entity id namespaces and the bert name space.
        entity_indexers = {'wordnet': indexer for wordnet entities,
                          'wiki': indexer for wiki entities}

        :param entity_candidate_generators: mention generators keyed by
            knowledge-base name; each one is queried for every sentence.
        :param entity_indexers: token indexers keyed by the same names, used
            to index the candidate entity strings.
        :param bert_model_type: name passed to ``BertTokenizer.from_pretrained``.
        :param do_lower_case: passed to the BERT tokenizer; when true, words
            are also lower-cased before word-piece splitting.
        :param whitespace_tokenize: if true the text is split on whitespace,
            otherwise a ``BasicTokenizer`` is used to split it into words.
        :param max_word_piece_sequence_length: upper bound on the number of
            word pieces, including the [CLS] / [SEP] tokens.
        """
        # load BertTokenizer from huggingface
        self.candidate_generators = entity_candidate_generators
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            bert_model_type, do_lower_case=do_lower_case
        )
        # Word-level tokenizer only; lower-casing is applied per word in
        # _word_to_word_pieces, so it is disabled here.
        self.bert_word_tokenizer = BasicTokenizer(do_lower_case=False)
        # Target length should include start and end token
        self.max_word_piece_sequence_length = max_word_piece_sequence_length

        self._entity_indexers = entity_indexers
        # for bert, we'll give an empty token indexer with empty name space
        # and do the indexing directly with the bert vocab to bypass
        # indexing in the indexer
        self._bert_single_id_indexer = {'tokens': SingleIdTokenIndexer('__bert__')}
        self.do_lowercase = do_lower_case
        self.whitespace_tokenize = whitespace_tokenize
        self.dtype = np.float32

    def _word_to_word_pieces(self, word):
        """
        Split a single word into word pieces, lower-casing it first when
        lower-casing is enabled and the word is not in the tokenizer's
        ``never_split`` collection.
        """
        if self.do_lowercase and word not in self.bert_tokenizer.basic_tokenizer.never_split:
            word = word.lower()
        return self.bert_tokenizer.wordpiece_tokenizer.tokenize(word)

    def tokenize_and_generate_candidates(self, text_a: str, text_b: str = None):
        """
        Tokenize ``text_a`` (and optionally ``text_b``), truncate to the
        maximum length, add [CLS] / [SEP], and generate entity candidates
        with spans remapped to word-piece positions of the final sequence.

        Steps:
        # split each sentence into words, then each word into word pieces
        # truncate whole words from the end until the length budget is met
        # generate candidate mentions for each generator from the kept words
        # add [CLS] and [SEP], compute segment ids and token offsets
        # combine candidate spans from both sentences, shifted to their
        #   position in the final sequence

        returns:

        {'tokens': List[str], the word piece strings with [CLS] [SEP]
         'segment_ids': List[int] the same length as 'tokens' with 0/1 for sentence1 vs 2
         'candidates': Dict[str, Dict[str, Any]],
            {'wordnet': {'candidate_spans': List[List[int]],
                         'candidate_entities': List[List[str]],
                         'candidate_entity_priors': List[List[float]],
                         'candidate_segment_ids': List[int]},
             'wiki': ...},
         'offsets_a': List[int], per-word end offsets of sentence1 in 'tokens',
         'offsets_b': List[int] for sentence2, or None for a single sentence
        }
        """
        offsets_a, grouped_wp_a, tokens_a = self._tokenize_text(text_a)

        if text_b is not None:
            offsets_b, grouped_wp_b, tokens_b = self._tokenize_text(text_b)
            # Reserve 3 positions for [CLS] and the two [SEP] tokens.
            truncate_sequence_pair(grouped_wp_a, grouped_wp_b, self.max_word_piece_sequence_length - 3)
            offsets_b = offsets_b[:len(grouped_wp_b)]
            tokens_b = tokens_b[:len(grouped_wp_b)]
            instance_b = self._generate_sentence_entity_candidates(tokens_b, offsets_b)
            word_piece_tokens_b = [word_piece for word in grouped_wp_b for word_piece in word]
        else:
            # Reserve 2 positions for [CLS] and [SEP]; drop whole words from
            # the end until the sentence fits.
            length_a = sum([len(x) for x in grouped_wp_a])
            while self.max_word_piece_sequence_length - 2 < length_a:
                discarded = grouped_wp_a.pop()
                length_a -= len(discarded)

        word_piece_tokens_a = [word_piece for word in grouped_wp_a for word_piece in word]
        # Keep offsets / words aligned with whatever survived truncation.
        offsets_a = offsets_a[:len(grouped_wp_a)]
        tokens_a = tokens_a[:len(grouped_wp_a)]
        instance_a = self._generate_sentence_entity_candidates(tokens_a, offsets_a)

        # If we got 2 sentences.
        if text_b is not None:
            # Layout: [CLS] A [SEP] B [SEP]. Segment 0 covers [CLS], A and the
            # first [SEP]; segment 1 covers B and the final [SEP].
            tokens = [start_token] + word_piece_tokens_a + [sep_token] + word_piece_tokens_b + [sep_token]
            segment_ids = (len(word_piece_tokens_a) + 2) * [0] + (len(word_piece_tokens_b) + 1) * [1]
            # Shift offsets past [CLS] (A) and past [CLS] A [SEP] (B).
            offsets_a = [x + 1 for x in offsets_a]
            offsets_b = [x + 2 + len(word_piece_tokens_a) for x in offsets_b]
        # Single sentence
        else:
            tokens = [start_token] + word_piece_tokens_a + [sep_token]
            segment_ids = len(tokens) * [0]
            offsets_a = [x + 1 for x in offsets_a]
            offsets_b = None

        # Shift the spans of sentence A by one position to account for [CLS].
        for name in instance_a.keys():
            for span in instance_a[name]['candidate_spans']:
                span[0] += 1
                span[1] += 1

        fields: Dict[str, Sequence] = {}

        # concatenating both sentences (for both tokens and ids)
        if text_b is None:
            candidates = instance_a
        else:
            candidates: Dict[str, Dict[str, list]] = {}

            # Merging candidate lists for both sentences.
            for entity_type in instance_b:
                candidate_instance_a = instance_a[entity_type]
                candidate_instance_b = instance_b[entity_type]

                candidates[entity_type] = {}

                # Shift the spans of sentence B past "[CLS] A [SEP]".
                for span in candidate_instance_b['candidate_spans']:
                    span[0] += len(word_piece_tokens_a) + 2
                    span[1] += len(word_piece_tokens_a) + 2

                # Merging each of the fields.
                for key in ['candidate_entities', 'candidate_spans', 'candidate_entity_priors']:
                    candidates[entity_type][key] = candidate_instance_a[key] + candidate_instance_b[key]


        for entity_type in candidates.keys():
            # deal with @@PADDING@@
            if len(candidates[entity_type]['candidate_entities']) == 0:
                candidates[entity_type] = get_empty_candidates()
            else:
                padding_indices = []
                has_entity = False
                for cand_i, candidate_list in enumerate(candidates[entity_type]['candidate_entities']):
                    if candidate_list == ["@@PADDING@@"]:
                        padding_indices.append(cand_i)
                        # Undo the span shifting applied above for padding entries.
                        candidates[entity_type]["candidate_spans"][cand_i] = [-1, -1]
                    else:
                        has_entity = True
                indices_to_remove = []
                if has_entity and len(padding_indices) > 0:
                    # remove all the padding entities since have some valid
                    indices_to_remove = padding_indices
                elif len(padding_indices) > 0:
                    # Only padding entries: keep exactly one of them.
                    assert len(padding_indices) == len(candidates[entity_type]['candidate_entities'])
                    indices_to_remove = padding_indices[1:]
                # Delete from the back so the remaining indices stay valid.
                for ind in reversed(indices_to_remove):
                    del candidates[entity_type]["candidate_spans"][ind]
                    del candidates[entity_type]["candidate_entities"][ind]
                    del candidates[entity_type]["candidate_entity_priors"][ind]

        # get the segment ids for the spans
        for key, cands in candidates.items():
            span_segment_ids = []
            for candidate_span in cands['candidate_spans']:
                # A padding span starts at -1, so it picks up the segment id
                # of the last position in the sequence.
                span_segment_ids.append(segment_ids[candidate_span[0]])
            candidates[key]['candidate_segment_ids'] = span_segment_ids

        fields['tokens'] = tokens
        fields['segment_ids'] = segment_ids
        fields['candidates'] = candidates
        fields['offsets_a'] = offsets_a
        fields['offsets_b'] = offsets_b
        return fields

    def _tokenize_text(self, text):
        """
        Split ``text`` into words and each word into word pieces.

        :return: ``(offsets, word_piece_tokens, tokens)`` where ``tokens`` is
            the list of words, ``word_piece_tokens[i]`` is the list of word
            pieces of word ``i``, and ``offsets[i]`` is the cumulative number
            of word pieces up to and including word ``i``.
        """
        if self.whitespace_tokenize:
            tokens = text.split()
        else:
            tokens = self.bert_word_tokenizer.tokenize(text)

        word_piece_tokens = []
        # Seed with 0 so the running sum can be built from offsets[-1].
        offsets = [0]
        for token in tokens:
            #NOTE: lowercase if necessary and tokenize
            word_pieces = self._word_to_word_pieces(token)
            offsets.append(offsets[-1] + len(word_pieces))
            word_piece_tokens.append(word_pieces)
        # Drop the seed so offsets has one entry per word.
        del offsets[0]
        return offsets, word_piece_tokens, tokens

    def _generate_sentence_entity_candidates(self, tokens, offsets):
        """
        Run every mention generator on the (already truncated) words of one
        sentence and convert the returned word-level spans to word-piece
        spans local to that sentence.

        :param tokens: the words of the sentence.
        :param offsets: cumulative word-piece counts, one per word, as
            returned by ``_tokenize_text``.
        :return: Dict[str, Dict[str, Any]],
            {'wordnet': {'candidate_spans': List[List[int]],
                         'candidate_entities': List[List[str]],
                         'candidate_entity_priors': List[List[float]]},
             'wiki': ...}

        """
        assert len(tokens) == len(offsets), f'Length of tokens {len(tokens)} must equal that of offsets {len(offsets)}.'
        entity_instances = {}
        for name, mention_generator in self.candidate_generators.items():
            entity_instances[name] = mention_generator.get_mentions_raw_text(' '.join(tokens), whitespace_tokenize=True)

        for name, entities in entity_instances.items():
            candidate_spans = entities["candidate_spans"]
            adjusted_spans = []
            for start, end in candidate_spans:
                # First word piece of word `start` is where word `start - 1`
                # ended; last word piece of word `end` is its end offset - 1.
                if 0 < start:
                    adjusted_span = [offsets[start - 1], offsets[end] - 1]
                else:
                    adjusted_span = [0, offsets[end] - 1]
                adjusted_spans.append(adjusted_span)
            entities['candidate_spans'] = adjusted_spans
            entity_instances[name] = entities
        return entity_instances

    def convert_tokens_candidates_to_fields(self, tokens_and_candidates):
        """
        tokens_and_candidates is the return from a previous call to
        tokenize_and_generate_candidates.  Converts the dict to
        a dict of fields usable with allennlp.

        :return: dict with the fields 'tokens', 'segment_ids' and
            'candidates' (a DictField holding one DictField per entity type).
        """
        fields = {}

        # Word pieces are indexed directly with the BERT vocabulary.
        fields['tokens'] = TextField(
                [Token(t, text_id=self.bert_tokenizer.vocab[t])
                    for t in tokens_and_candidates['tokens']],
                token_indexers=self._bert_single_id_indexer
        )

        # Builtin int is what the removed `np.int` alias referred to.
        fields['segment_ids'] = ArrayField(
            np.array(tokens_and_candidates['segment_ids']), dtype=int
        )

        all_candidates = {}
        for key, entity_candidates in tokens_and_candidates['candidates'].items():
            # pad the prior to create the array field
            # make a copy to avoid modifying the input
            candidate_entity_prior = copy.deepcopy(
                    entity_candidates['candidate_entity_priors']
            )
            max_cands = max(len(p) for p in candidate_entity_prior)
            for p in candidate_entity_prior:
                if len(p) < max_cands:
                    p.extend([0.0] * (max_cands - len(p)))
            np_prior = np.array(candidate_entity_prior)

            candidate_fields = {
                "candidate_entity_priors": ArrayField(np_prior, dtype=self.dtype),
                # Each candidate list is joined into one space-separated token.
                "candidate_entities": TextField(
                    [Token(" ".join(candidate_list)) for candidate_list in entity_candidates["candidate_entities"]],
                    token_indexers={'ids': self._entity_indexers[key]}),
                "candidate_spans": ListField(
                    [SpanField(span[0], span[1], fields['tokens']) for span in
                    entity_candidates['candidate_spans']]
                ),
                "candidate_segment_ids": ArrayField(
                    np.array(entity_candidates['candidate_segment_ids']), dtype=int
                )
            }
            all_candidates[key] = DictField(candidate_fields)

        fields["candidates"] = DictField(all_candidates)

        return fields


@TokenizerAndCandidateGenerator.register("pretokenized")
class PretokenizedTokenizerAndCandidateGenerator(BertTokenizerAndCandidateGenerator):
    """
    Variant of ``BertTokenizerAndCandidateGenerator`` for input that is already
    split into words: only the word-piece splitting step is carried out.

    # TODO: mypy is not going to like us calling ``tokenize_and_generate_candidates()`` on lists
    # instead of strings. Maybe update type annotations in ``BertTokenizerAndCandidateGenerator``?
    """
    def _tokenize_text(self, tokens: List[str]):
        """
        Split each given word into word pieces and return
        ``(offsets, word_piece_tokens, tokens)``, where ``offsets[i]`` is the
        cumulative word-piece count up to and including word ``i``.
        """
        # Hack: these markers are passed through as single word pieces
        # instead of being handed to the word-piece splitter.
        passthrough_tokens = ('[SEP]', '[MASK]')
        word_piece_tokens = [
            [token] if token in passthrough_tokens else self._word_to_word_pieces(token)
            for token in tokens
        ]

        offsets = []
        pieces_so_far = 0
        for pieces in word_piece_tokens:
            pieces_so_far += len(pieces)
            offsets.append(pieces_so_far)
        return offsets, word_piece_tokens, tokens


In [31]:

from typing import Union, List

from allennlp.common import Params
from allennlp.data import Instance, DataIterator, Vocabulary
from allennlp.common.file_utils import cached_path


from kb.include_all import TokenizerAndCandidateGenerator
from kb.bert_pretraining_reader import replace_candidates_with_mask_entity



def _extract_config_from_archive(model_archive):
    """
    Read ``config.json`` out of a gzip-compressed tar archive.

    The member is extracted into a temporary directory, loaded with
    ``Params.from_file`` and the resulting object is returned; the temporary
    directory is removed when the function exits.
    """
    import os
    import tarfile
    import tempfile

    config_name = 'config.json'
    with tempfile.TemporaryDirectory() as scratch_dir, \
            tarfile.open(model_archive, 'r:gz') as tar:
        tar.extract(config_name, path=scratch_dir)
        return Params.from_file(os.path.join(scratch_dir, config_name))

#NOTE: Recursively traverse dictionary to find the value corresponding to a key
def _find_key(d, key):
    val = None
    stack = [d.items()]
    while len(stack) > 0 and val is None:
        s = stack.pop()
        for k, v in s:
            if k == key:
                val = v
                break
            elif isinstance(v, dict):
                stack.append(v.items())
    return val

class KnowBertBatchifierCustom:
    """
    Takes a list of sentence strings and returns a tensor dict usable with
    a KnowBert model

    Unlike a batchifier built from a model archive, this class receives the
    tokenizer / candidate-generator parameters and the vocabulary parameters
    directly from the caller.
    """
    def __init__(self, candidate_generator_params,vocab_params, batch_size=32,
                       masking_strategy=None,
                       wordnet_entity_file=None):
        """
        :param candidate_generator_params: plain dict with the
            'tokenizer_and_candidate_generator' configuration.
        :param vocab_params: parameters passed to ``Vocabulary.from_params``.
        :param batch_size: batch size of the "basic" data iterator.
        :param masking_strategy: ``None`` or ``'full_mask'``.
        :param wordnet_entity_file: optional replacement for the wordnet
            generator's ``entity_file`` setting.
        """
        # Local import so the class does not depend on another cell having
        # imported the module.
        import copy

        # Work on a private copy: the override below writes into the dict,
        # and the caller's configuration must not be modified by constructing
        # a batchifier.
        # NOTE(review): Params.from_params presumably also consumes keys of
        # the wrapped dict -- confirm; the copy protects the caller either way.
        candidate_generator_params = copy.deepcopy(candidate_generator_params)

        if wordnet_entity_file is not None:
            candidate_generator_params['entity_candidate_generators']['wordnet']['entity_file'] = wordnet_entity_file

        self.tokenizer_and_candidate_generator = TokenizerAndCandidateGenerator.\
                from_params(Params(candidate_generator_params))

        # Raw sentences are given, so do not rely on whitespace splitting.
        self.tokenizer_and_candidate_generator.whitespace_tokenize = False

        assert masking_strategy is None or masking_strategy == 'full_mask'

        self.masking_strategy = masking_strategy

        self.vocab = Vocabulary.from_params(vocab_params)

        self.iterator = DataIterator.from_params(
            Params({"type": "basic", "batch_size": batch_size})
        )
        self.iterator.index_with(self.vocab)

    def _replace_mask(self, s):
        """Surround every '[MASK]' in the string ``s`` with spaces."""
        return s.replace('[MASK]', ' [MASK] ')

    def iter_batches(self, sentences_or_sentence_pairs: Union[List[str], List[List[str]]], verbose=True):
        """
        Convert sentences (or two-element sentence pairs) to instances and
        yield the batches produced by the data iterator.

        :param sentences_or_sentence_pairs: list whose items are either a
            sentence string or a list of exactly two sentence strings.
        :param verbose: when true, print each masked input sentence (one per
            line) followed by its word-piece tokens.
        :raises AssertionError: if a pair does not have exactly two items, or
            if 'full_mask' has to add a mask span for a sentence pair.
        """
        # create instances
        instances = []
        for sentence_or_sentence_pair in sentences_or_sentence_pairs:
            # Normalize to a list of masked sentence strings (one or two), so
            # both tokenization and verbose printing work on strings only.
            if isinstance(sentence_or_sentence_pair, list):
                assert len(sentence_or_sentence_pair) == 2
                masked_inputs = [self._replace_mask(sentence)
                                 for sentence in sentence_or_sentence_pair]
            else:
                masked_inputs = [self._replace_mask(sentence_or_sentence_pair)]

            tokens_candidates = self.tokenizer_and_candidate_generator.\
                    tokenize_and_generate_candidates(*masked_inputs)

            if verbose:
                # A pair is printed as two lines; calling _replace_mask on the
                # pair itself would fail because a list has no .replace().
                for masked_sentence in masked_inputs:
                    print(masked_sentence)
                print(tokens_candidates['tokens'])

            # now modify the masking if needed
            if self.masking_strategy == 'full_mask':
                # replace the mask span with a @@mask@@ span
                masked_indices = [index for index, token in enumerate(tokens_candidates['tokens'])
                      if token == '[MASK]']

                spans_to_mask = set([(i, i) for i in masked_indices])
                replace_candidates_with_mask_entity(
                        tokens_candidates['candidates'], spans_to_mask
                )

                # now make sure the spans are actually masked
                for key in tokens_candidates['candidates'].keys():
                    kb_candidates = tokens_candidates['candidates'][key]
                    for span_to_mask in spans_to_mask:
                        found = any(tuple(span) == tuple(span_to_mask)
                                    for span in kb_candidates['candidate_spans'])
                        if not found:
                            kb_candidates['candidate_spans'].append(list(span_to_mask))
                            kb_candidates['candidate_entities'].append(['@@MASK@@'])
                            kb_candidates['candidate_entity_priors'].append([1.0])
                            kb_candidates['candidate_segment_ids'].append(0)
                            # hack, assume only one sentence
                            assert not isinstance(sentence_or_sentence_pair, list)


            fields = self.tokenizer_and_candidate_generator.\
                convert_tokens_candidates_to_fields(tokens_candidates)

            instances.append(Instance(fields))


        yield from self.iterator(instances, num_epochs=1, shuffle=False)

In [36]:

# Leftover kernel sanity check; it has no effect on later cells.
print("Hello")

Hello


In [35]:
def test_batcher(original_batcher,custom_batcher):
    sentences = ["Paris is located in France.", "Michael Jackson is a great music singer"]
    # batcher takes raw untokenized sentences
    # and yields batches of tensors needed to run KnowBert
    for original_batch,custom_batch in zip(original_batcher.iter_batches(sentences, verbose=False),custom_batcher.iter_batches(sentences, verbose=False)):

        print(f"Tokens are equal: {torch.equal(original_batch['tokens']['tokens'],custom_batch['tokens']['tokens'])}")
        print(f"Segment ids are equal: {torch.equal(original_batch['segment_ids'],custom_batch['segment_ids'])}")
        original_wordnet_kb = original_batch['candidates']['wordnet']
        custom_wordnet_kb = custom_batch['candidates']['wordnet']
        print(f"Candidate entity priors are equal: {torch.equal(original_wordnet_kb['candidate_entity_priors'],custom_wordnet_kb['candidate_entity_priors'])}")
        print(f"Candidate entity ids are equal: {torch.equal(original_wordnet_kb['candidate_entities']['ids'],custom_wordnet_kb['candidate_entities']['ids'])}")
        print(f"Candidate spans are equal: {torch.equal(original_wordnet_kb['candidate_spans'],custom_wordnet_kb['candidate_spans'])}")
        print(f"Candidate segment ids are equal: {torch.equal(original_wordnet_kb['candidate_segment_ids'],custom_wordnet_kb['candidate_segment_ids'])}")

# Build the custom batcher from the local config.json (tokenizer / candidate
# generator settings plus the vocabulary section) and compare its batches
# with those of the archive-based `batcher` created earlier.
batcher_config = Params.from_file('../knowbert_wordnet_model/config.json')
candidate_generator_params = batcher_config['dataset_reader']['dataset_readers']['language_modeling']['base_reader']['tokenizer_and_candidate_generator'].as_dict()
vocab_param = batcher_config['vocabulary']
custom_batcher = KnowBertBatchifierCustom(candidate_generator_params,vocab_param)
test_batcher(batcher,custom_batcher)



Tokens are equal: True
Segment ids are equal: True
Candidate entity priors are equal: True
Candidate entity ids are equal: True
Candidate spans are equal: True
Candidate segment ids are equal: True


In [10]:
sentences = ["Paris is located in France.", "Michael Jackson is a great music singer"]
# The batcher takes raw, untokenized sentences
# and yields batches of tensors needed to run KnowBert.
for i,batch in enumerate(batcher.iter_batches(sentences, verbose=True)):

    print(f"\nInput\n")
    print(f"Batch: {batch.keys()}") # Batch keys: tokens, segment_ids, candidates
    # tokens: tensor of token indices. A batch holds sentences with different
    # numbers of tokens, so shorter rows are presumably padded with zeros (verify).
    # shape: (batch_size (number of sentences), max_seq_len)
    # batch['tokens'] is a dict with the single key 'tokens'.
    print(f"Tokens shape {batch['tokens']['tokens'].shape}")
    # Segment ids (0 for the first segment, 1 for the second).
    # shape: (batch_size, max_seq_len)
    print(f"Segment ids shape: {batch['segment_ids'].shape}")

    # candidates: one entry per knowledge base; only 'wordnet' is used here.
    # Each entry holds the entities detected with that knowledge base.
    wordnet_kb = batch['candidates']['wordnet']
    print(f"Wordnet kb: {wordnet_kb.keys()}")

    
    # For each detected mention, the list of candidate KB entities for it.
    # Priors: author's note -- probabilities from the entity linker that sum to 1
    # (or 0 for padding) along axis 2 (verify).
    # Axis 1 is zero-padded when a sentence has fewer mentions than the maximum in the batch.
    # Axis 2 is zero-padded when a mention has fewer candidates than the maximum in the batch.
    # shape: (batch_size, max number of mentions, max number of KB candidates)
    print(f"Candidate entity_priors shape: {wordnet_kb['candidate_entity_priors'].shape}")
    # Ids of the candidate KB entities, zero-padded on axis 1 or 2 where needed.
    # shape: (batch_size, max number of mentions, max number of KB candidates)
    print(f"Candidate entities ids shape: {wordnet_kb['candidate_entities']['ids'].shape}")
    # Token span of each mention in the sentence, e.g. [1,2] for Michael Jackson (both bounds inclusive).
    
    # Padded with [-1,-1] when there are no more mentions.
    # shape: (batch_size, max number of mentions, 2)
    print(f"Candidate span shape: {wordnet_kb['candidate_spans'].shape}")

    # For each mention, the segment id it belongs to.
    # shape: (batch_size, max number of mentions)
    print(f"Candidate segments_ids shape: {wordnet_kb['candidate_segment_ids'].shape}")

    # model(**batch) is the same as model(tokens=batch['tokens'], segment_ids=batch['segment_ids'], candidates=batch['candidates'])
    model_output = model(**batch)
    
    print(f"\nOutput\n")
    print(f"Model output keys: {model_output.keys()}")
    print(f"Output wordnet keys: {model_output['wordnet'].keys()}")
    # Span attention scores for the wordnet KB.
    # shape: (batch_size, ?, max_seq_len, max number of mentions) -- meaning of axis 1 not yet identified
    print(f"Output wordnet entity_attention_probs shape: {model_output['wordnet']['entity_attention_probs'].shape}")
    # Entity linker score for each mention and candidate KB entity; author's note: -1.0000e+04 is used where there is no score.
    # shape: (batch_size, max number of mentions, max number of KB candidates)
    print(f"Output wordnet linking_scores shape: {model_output['wordnet']['linking_scores'].shape}")
    
    # Scalar loss for this batch (0 when not training? -- to be confirmed).
    print(f"Output loss: {model_output['loss']}")

    # Final [CLS] embedding for each sentence of the batch.
    # shape: (batch_size, hidden_size)
    print(f"Pooled output shape: {model_output['pooled_output'].shape}")

    # Final embedding of every token.
    # Important: an embedding is still produced for zero-padded tokens => ignore them (or is 0 padding equivalent to MASK? open question).
    print(f"Contextual embeddings: {model_output['contextual_embeddings'].shape}")

    # TODO: work out how to add masking (embedding of token index 0?).
    # TODO: work out how to recover the predicted tokens from the final embeddings.
    # TODO: copy all required allennlp classes into a local allennlp.py file to drop the dependency.

Paris is located in France.
['[CLS]', 'paris', 'is', 'located', 'in', 'france', '.', '[SEP]']
Michael Jackson is a great music singer
['[CLS]', 'michael', 'jackson', 'is', 'a', 'great', 'music', 'singer', '[SEP]']

Input

Batch: dict_keys(['tokens', 'segment_ids', 'candidates'])
Tokens shape torch.Size([2, 9])
Segment ids shape: torch.Size([2, 9])
Wordnet kb: dict_keys(['candidate_entity_priors', 'candidate_entities', 'candidate_spans', 'candidate_segment_ids'])
Candidate entity_priors shape: torch.Size([2, 8, 14])
Candidate entities ids shape: torch.Size([2, 8, 14])
Candidate span shape: torch.Size([2, 8, 2])
Candidate segments_ids shape: torch.Size([2, 8])

Output

Model output keys: dict_keys(['wordnet', 'loss', 'pooled_output', 'contextual_embeddings'])
Output wordnet keys: dict_keys(['entity_attention_probs', 'linking_scores'])
Output wordnet entity_attention_probs shape: torch.Size([2, 4, 9, 8])
Output wordnet linking_scores shape: torch.Size([2, 8, 14])
Output loss: 0.0
Pooled o



In [None]:
# TODO: work out how to add masking (embedding of token index 0?).
# TODO: work out how to recover the predicted tokens from the final embeddings.
# TODO: copy all required allennlp classes into a local allennlp.py file to drop the dependency.

In [6]:
# Print the name and tensor shape of every model parameter.
for name, param in model.named_parameters():
    print(f"{name}:{param.shape}")

pretrained_bert.bert.embeddings.word_embeddings.weight:torch.Size([30522, 768])
pretrained_bert.bert.embeddings.position_embeddings.weight:torch.Size([512, 768])
pretrained_bert.bert.embeddings.token_type_embeddings.weight:torch.Size([2, 768])
pretrained_bert.bert.embeddings.LayerNorm.weight:torch.Size([768])
pretrained_bert.bert.embeddings.LayerNorm.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.output.dense.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.la