In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from kb.knowbert_utils import KnowBertBatchifier

WORDNET_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wordnet_model.tar.gz"
WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_model.tar.gz"
WORDNET_WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz"

WORDNET_FOLDER = '../knowbert_wordnet_model/'
WORDNET_LINKER_FOLDER = WORDNET_FOLDER + 'entity_linker/'
WORDNET_LINKER_EMBEDDING_FILE = WORDNET_LINKER_FOLDER + 'wordnet_synsets_mask_null_vocab_embeddings_tucker_gensen.hdf5'
WORDNET_LINKER_ENTITY_FILE = WORDNET_LINKER_FOLDER + 'entities.jsonl'
WORDNET_LINKER_VOCAB_FILE = WORDNET_LINKER_FOLDER + 'wordnet_synsets_mask_null_vocab.txt'


WORDNET_MODEL_STATE_DICT_FILE = WORDNET_FOLDER+ 'weights.th'

CharacterTokenizer params None False None None
TokenCharactersIndexer params: entity <allennlp.data.tokenizers.character_tokenizer.CharacterTokenizer object at 0x7fe1197c8b38> None None 0




In [3]:
original_batcher = KnowBertBatchifier(WIKI_ARCHIVE)

WikiCandidateMentionGenerator params: None None True False None
duplicate_mentions_cnt:  6777
end of p_e_m reading. wall time: 1.2307601968447368  minutes
p_e_m_errors:  0
incompatible_ent_ids:  0
TokenCharactersIndexer params: entity <allennlp.data.tokenizers.character_tokenizer.CharacterTokenizer object at 0x7fe1197c8b38> None None 0
BertTokenizerAndCandidateGenerator params
{'wiki': <kb.wiki_linking_util.WikiCandidateMentionGenerator object at 0x7fe39c99a160>}
{'wiki': <allennlp.data.token_indexers.token_characters_indexer.TokenCharactersIndexer object at 0x7fe399172908>}
bert-base-uncased
True
True
512


In [4]:
from kb.custom_tokenizer.custom_tokenizer import CustomKnowBertBatchifier
from kb.custom_tokenizer.bert_tokenizer_and_candidate_generator import BertTokenizerAndCandidateGenerator
from kb.custom_tokenizer.wiki_linking_util import WikiCandidateMentionGenerator
from kb.custom_tokenizer.vocabulary import Vocabulary

from allennlp.common import Params


candidate_generator_params = {
                        "type": "bert_tokenizer_and_candidate_generator",
                        "bert_model_type": "bert-base-uncased",
                        "do_lower_case": True,
                        "entity_candidate_generators": {
                            "wiki": {
                                "type": "wiki"
                            }
                        },
                        "entity_indexers": {
                            "wiki": {
                                "type": "characters_tokenizer",
                                "namespace": "entity",
                                "tokenizer": {
                                    "type": "word",
                                    "word_splitter": {
                                        "type": "just_spaces"
                                    }
                                }
                            }
                        }
                    }

In [5]:
custom_candidate_mention_generator = WikiCandidateMentionGenerator()#use default params: None None True False None
entity_candidate_generators = {'wiki':custom_candidate_mention_generator}

bert_model_type = 'bert-base-uncased'
custom_tokenizer_and_candidate_generator = BertTokenizerAndCandidateGenerator(entity_candidate_generators,bert_model_type,do_lower_case=True,whitespace_tokenize=True,max_word_piece_sequence_length=512) 

vocabulary = Vocabulary.from_files("https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/vocabulary_wiki.tar.gz")
#vocabulary = Vocabulary.from_params(Params({"directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/vocabulary_wiki.tar.gz"}))

duplicate_mentions_cnt:  6777
end of p_e_m reading. wall time: 1.3197708050409953  minutes
p_e_m_errors:  0
incompatible_ent_ids:  0


In [6]:
vocabulary.get_token_index("Paris_Hilton")
vocabulary.get_token_from_index(1,'entity')
vocabulary.get_token_from_index(156993,'entity')
# vocabulary._index_to_token['entity']

'Paris_Hilton'

In [7]:
custom_batcher = CustomKnowBertBatchifier(custom_tokenizer_and_candidate_generator,vocabulary)

In [13]:
#Create test set
sentences = [["Paris is located in [MASK].", "Michael [MASK] is a great music singer"],
            ["The Louvre contains the Mona Lisa", "The Amazon river is in Brazil"],
            "Donald Duck is a cartoon character",
            ["Hayao Miyazaki is the co-founder of Studio Ghibli and a renowned anime film maker",
            "The Alpine ibex is one of Switzerland's most famous animal along its grazing cows"]]
            
def batchifier_equal(original_batcher,custom_batcher,test_sentences):
    for original_batch,custom_batch in zip(original_batcher.iter_batches(test_sentences,verbose=False),custom_batcher.iter_batches(test_sentences,verbose=False)):
        
        print(f"token ids are equal: {torch.equal(original_batch['tokens']['tokens'], custom_batch['tokens']['tokens'])}")
        #Defines the segments_ids (0 for first segment and 1 for second), can be used for NSP
        #shape: (batch_size,max_seq_len)
        print(f"Segment ids are equal: {torch.equal(original_batch['segment_ids'],custom_batch['segment_ids'])}")

        original_wiki_kb = original_batch['candidates']['wiki']
        custom_wiki_kb = custom_batch['candidates']['wiki']

        print(f"Candidate entity_priors are equal: {torch.equal(original_wiki_kb['candidate_entity_priors'],custom_wiki_kb['candidate_entity_priors'])}")
        print(f"Candidate entities ids are equal: {torch.equal(original_wiki_kb['candidate_entities']['ids'],custom_wiki_kb['candidate_entities']['ids'])}")
        print(f"Candidate span are equal: {torch.equal(original_wiki_kb['candidate_spans'],custom_wiki_kb['candidate_spans'])}")

        #For each sentence entity, indicate to which segment ids it corresponds to
        print(f"Candidate segments_ids are equal: {torch.equal(original_wiki_kb['candidate_segment_ids'],custom_wiki_kb['candidate_segment_ids'])}")
    

batchifier_equal(original_batcher,custom_batcher,sentences)

Batch tokens shape (4, 42)
Batch segment_ids shape (4, 42)
token ids are equal: True
Segment ids are equal: True
Candidate entity_priors are equal: True
Candidate entities ids are equal: True
Candidate span are equal: True
Candidate segments_ids are equal: True


In [21]:
#Generate a test set for batchifier
test_sentences = [["Paris is located in [MASK].", "Michael [MASK] is a great music singer"],
            ["The Louvre contains the Mona Lisa", "The Amazon river is in Brazil"],
            "Donald Duck is a cartoon character",
            ["Hayao Miyazaki is the co-founder of Studio Ghibli and a renowned anime film maker",
            "The Alpine ibex is one of Switzerland's most famous animal along its grazing cows"]]

test = {}
test['input']=test_sentences
test_expected =[]
for original_batch in original_batcher.iter_batches(test['input'],verbose=False):
    test_expected.append(original_batch)

test['expected'] = test_expected

torch.save(test,"tokenizer_test")

test = torch.load("tokenizer_test")

for custom_batch,expected_batch in zip(custom_batcher.iter_batches(test['input'],verbose=False),test['expected']):
    print(f"token ids are equal: {torch.equal(expected_batch['tokens']['tokens'], custom_batch['tokens']['tokens'])}")
    #Defines the segments_ids (0 for first segment and 1 for second), can be used for NSP
    #shape: (batch_size,max_seq_len)
    print(f"Segment ids are equal: {torch.equal(expected_batch['segment_ids'],custom_batch['segment_ids'])}")

    expected_wiki_kb = original_batch['candidates']['wiki']
    custom_wiki_kb = custom_batch['candidates']['wiki']

    print(f"Candidate entity_priors are equal: {torch.equal(expected_wiki_kb['candidate_entity_priors'],custom_wiki_kb['candidate_entity_priors'])}")
    print(f"Candidate entities ids are equal: {torch.equal(expected_wiki_kb['candidate_entities']['ids'],custom_wiki_kb['candidate_entities']['ids'])}")
    print(f"Candidate span are equal: {torch.equal(expected_wiki_kb['candidate_spans'],custom_wiki_kb['candidate_spans'])}")

    #For each sentence entity, indicate to which segment ids it corresponds to
    print(f"Candidate segments_ids are equal: {torch.equal(expected_wiki_kb['candidate_segment_ids'],custom_wiki_kb['candidate_segment_ids'])}")


token ids are equal: True
Segment ids are equal: True
Candidate entity_priors are equal: True
Candidate entities ids are equal: True
Candidate span are equal: True
Candidate segments_ids are equal: True


In [10]:
for batch in original_batcher.iter_batches(sentences, verbose=False):
    print(f"Batch: {batch.keys()}") #Batch contains {tokens,segment_ids,candidates}
    #tokens: Tensor of tokens indices (used to idx an embedding) => because a batch contains multiple
    #sentences with varying # of tokens, all tokens tensors are padded with zeros 
    #shape: (batch_size (#full_sentences(if two sentences => link them with NSP)), max_seq_len)
    #print(batch['tokens'])#dict with only 'tokens'
    print(f"Tokens shape {batch['tokens']['tokens'].shape}")
    print(f"Tokens type {batch['tokens']['tokens'].dtype}")
    #Defines the segments_ids (0 for first segment and 1 for second), can be used for NSP
    #shape: (batch_size,max_seq_len)
    print(f"Segment ids shape: {batch['segment_ids'].shape}")
    print(f"Segment ids type {batch['segment_ids'].dtype}")

    #Dict with only wordnet
    #Candidates: stores for multiple knowledge base, the entities detected using this knowledge base
    wiki_kb = batch['candidates']['wiki']
    # print(f"Wordnet kb: {wordnet_kb.keys()}")

    #Stores for each detected entities, a list of candidate KB entities that correspond to it
    #Priors: correctness probabilities estimated by the entity linker (sum to 1 (or 0 if padding) on axis 2)
    #Adds 0 padding to axis 1 when there is less detected entities in the sentence than in the max sentence
    #Adds 0 padding to axis 2 when there is less detected KB entities for an entity in the sentence than in the max candidate KB entities entity
    #shape:(batch_size, max # detected entities, max # KB candidate entities)
    print(f"Candidate entity_priors shape: {wiki_kb['candidate_entity_priors'].shape}")
    print(f"Candidate entity_priors type: {wiki_kb['candidate_entity_priors'].dtype}")
    #Ids of the KB candidate entities + 0 padding on axis 1 or 2 if necessary
    #shape: (batch_size, max # detected entities, max # KB candidate entities)
    print(f"Candidate entities ids shape: {wiki_kb['candidate_entities']['ids'].shape}")
    print(f"Candidate entities ids type: {wiki_kb['candidate_entities']['ids'].dtype}")
    #Spans of which sequence of tokens correspond to an entity in the sentence, eg: [1,2] for Michael Jackson (both bounds are included)
    #Padding with [-1,-1] when no more detected entities
    #shape: (batch_size, max # detected entities, 2)
    print(f"Candidate span shape: {wiki_kb['candidate_spans'].shape}")
    print(f"Candidate span type: {wiki_kb['candidate_spans'].dtype}")

    #For each sentence entity, indicate to which segment ids it corresponds to
    print(f"Candidate segments_ids shape: {wiki_kb['candidate_segment_ids'].shape}")
    print(f"Candidate segments_ids type: {wiki_kb['candidate_segment_ids'].dtype}")
    #break



    # 
    #shape: (batch_size, max # detected entities)
    # model(**batch)

Batch: dict_keys(['tokens', 'segment_ids', 'candidates'])
Tokens shape torch.Size([4, 42])
Tokens type torch.int64
Segment ids shape: torch.Size([4, 42])
Segment ids type torch.int64
Candidate entity_priors shape: torch.Size([4, 26, 30])
Candidate entity_priors type: torch.float32
Candidate entities ids shape: torch.Size([4, 26, 30])
Candidate entities ids type: torch.int64
Candidate span shape: torch.Size([4, 26, 2])
Candidate span type: torch.int64
Candidate segments_ids shape: torch.Size([4, 26])
Candidate segments_ids type: torch.int64


In [11]:
print(f"\nInput\n")
print(f"Batch: {batch.keys()}") #Batch contains {tokens,segment_ids,candidates}
#tokens: Tensor of tokens indices (used to idx an embedding) => because a batch contains multiple
#sentences with varying # of tokens, all tokens tensors are padded with zeros 
#shape: (batch_size (#sentences), max_seq_len)
#print(batch['tokens'])#dict with only 'tokens'
print(f"Tokens shape {batch['tokens']['tokens'].shape}")
#Defines the segments_ids (0 for first segment and 1 for second), can be used for NSP
#shape: (batch_size,max_seq_len), 0 padding
print(f"Segment ids shape: {batch['segment_ids'].shape}")

#Dict with only wordnet
#Candidates: stores for multiple knowledge base, the entities detected using this knowledge base
wordnet_kb = batch['candidates']['wordnet']
print(f"Wordnet kb: {wordnet_kb.keys()}")

#Stores for each detected entities, a list of candidate KB entities that correspond to it
#Priors: correctness probabilities estimated by the entity linker (sum to 1 (or 0 if padding) on axis 2)
#Adds 0 padding to axis 1 when there is less detected entities in the sentence than in the max sentence
#Adds 0 padding to axis 2 when there is less detected KB entities for an entity in the sentence than in the max candidate KB entities entity
#shape:(batch_size, max # detected entities, max # KB candidate entities)
print(f"Candidate entity_priors shape: {wordnet_kb['candidate_entity_priors'].shape}")
#Ids of the KB candidate entities + 0 padding on axis 1 or 2 if necessary
#shape: (batch_size, max # detected entities, max # KB candidate entities)
print(f"Candidate entities ids shape: {wordnet_kb['candidate_entities']['ids'].shape}")
#Spans of which sequence of tokens correspond to an entity in the sentence, eg: [1,2] for Michael Jackson (both bounds are included)
#Padding with [-1,-1] when no more detected entities
#shape: (batch_size, max # detected entities, 2)
print(f"Candidate span shape: {wordnet_kb['candidate_spans'].shape}")

#For each sentence entity, indicate to which segment ids it corresponds to
#shape: (batch_size, max # detected entities)
print(f"Candidate segments_ids shape: {wordnet_kb['candidate_segment_ids'].shape}")


Input

Batch: dict_keys(['tokens', 'segment_ids', 'candidates'])
Tokens shape torch.Size([4, 42])
Segment ids shape: torch.Size([4, 42])


KeyError: 'wordnet'