In [3]:
import torch

# Show which PyTorch build this notebook ran against, so the recorded
# outputs below can be reproduced with a matching installation.
installed_torch_version = torch.__version__
print(installed_torch_version)

1.2.0


In [6]:
from kb.include_all import ModelArchiveFromParams
from kb.knowbert_utils import KnowBertBatchifier
from allennlp.common import Params

import torch

# Pretrained KnowBert archives, one per knowledge-base configuration
# (WordNet only, Wikipedia only, WordNet + Wikipedia).

WORDNET_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wordnet_model.tar.gz"
WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_model.tar.gz"
WORDNET_WIKI_ARCHIVE = "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz"

# Archive actually used in this notebook: the WordNet-only model.
# It is selected exactly once so that the model and the batcher are always
# built from the same archive (presumably required so that the candidates
# produced by the batcher match the knowledge bases the model expects --
# TODO confirm against the KnowBert documentation).
ARCHIVE_FILE = WORDNET_ARCHIVE

# load model and batcher
# NOTE: the archive is downloaded on first use (~1.4 GB according to the
# progress bar recorded below), so this cell can take a while.
params = Params({"archive_file": ARCHIVE_FILE})
model = ModelArchiveFromParams.from_params(params=params)
batcher = KnowBertBatchifier(ARCHIVE_FILE)

100%|██████████| 1400916256/1400916256 [01:28<00:00, 15918868.01B/s]
100%|██████████| 563648/563648 [00:00<00:00, 662252.90B/s]


In [67]:
# Sanity check: show the concrete class of each object created by the loading cell.
for loaded_object in (model, params, batcher):
    print(type(loaded_object))

<class 'kb.knowbert.KnowBert'>
<class 'allennlp.common.params.Params'>
<class 'kb.knowbert_utils.KnowBertBatchifier'>


In [99]:
# Walk through the inputs the batcher produces and the outputs KnowBert returns
# for a couple of example sentences, printing the keys and tensor shapes.
sentences = ["Paris is located in France.", "Michael Jackson is a great music singer"]

# The batcher takes raw, untokenized sentences and yields batches of tensors
# needed to run KnowBert. With verbose=True it also prints each sentence and
# its word-piece tokens (see the recorded output below).
for batch in batcher.iter_batches(sentences, verbose=True):

    print("\nInput\n")
    # The batch is a dict with the keys {tokens, segment_ids, candidates}.
    print(f"Batch: {batch.keys()}")

    # tokens: tensor of token indices (used to index an embedding). Because a
    # batch holds sentences with different numbers of tokens, the shorter ones
    # are padded with zeros.
    # batch['tokens'] is itself a dict containing only the key 'tokens'.
    # shape: (batch_size (#sentences), max_seq_len)
    print(f"Tokens shape {batch['tokens']['tokens'].shape}")

    # segment_ids: 0 for the first segment and 1 for the second; can be used
    # for next-sentence prediction.
    # shape: (batch_size, max_seq_len)
    print(f"Segment ids shape: {batch['segment_ids'].shape}")

    # candidates: for each knowledge base, the entities detected with it.
    # With the WordNet archive this dict only contains 'wordnet'.
    wordnet_kb = batch['candidates']['wordnet']
    print(f"Wordnet kb: {wordnet_kb.keys()}")

    # candidate_entity_priors: for each detected entity, the prior probability
    # of each candidate KB entity as estimated by the entity linker
    # (sums to 1 on axis 2, or to 0 for padding).
    # Axis 1 is zero-padded when a sentence has fewer detected entities than
    # the longest one; axis 2 is zero-padded when an entity has fewer KB
    # candidates than the maximum.
    # shape: (batch_size, max # detected entities, max # KB candidate entities)
    print(f"Candidate entity_priors shape: {wordnet_kb['candidate_entity_priors'].shape}")

    # candidate_entities['ids']: ids of the candidate KB entities, zero-padded
    # on axis 1 or 2 where necessary.
    # shape: (batch_size, max # detected entities, max # KB candidate entities)
    print(f"Candidate entities ids shape: {wordnet_kb['candidate_entities']['ids'].shape}")

    # candidate_spans: which token span corresponds to each detected entity,
    # e.g. [1, 2] for "Michael Jackson" (both bounds inclusive).
    # Padded with [-1, -1] when there are no more detected entities.
    # shape: (batch_size, max # detected entities, 2)
    print(f"Candidate span shape: {wordnet_kb['candidate_spans'].shape}")

    # candidate_segment_ids: the segment id each detected entity belongs to.
    # shape: (batch_size, max # detected entities)
    print(f"Candidate segments_ids shape: {wordnet_kb['candidate_segment_ids'].shape}")

    # model(**batch) is equivalent to
    # model(tokens=batch['tokens'], segment_ids=batch['segment_ids'], candidates=batch['candidates']).
    # Nothing is back-propagated here, so run the forward pass without
    # recording the autograd graph (saves memory; output values are unchanged).
    # NOTE(review): if the loaded model is still in training mode, dropout is
    # active and outputs are not deterministic -- consider calling model.eval()
    # once after loading; confirm against the loader's behavior.
    with torch.no_grad():
        model_output = model(**batch)
    wordnet_output = model_output['wordnet']

    print("\nOutput\n")
    print(f"Model output keys: {model_output.keys()}")
    print(f"Output wordnet keys: {wordnet_output.keys()}")

    # Span-attention scores for the WordNet KB.
    # shape: (batch_size, ?, max_seq_len, max # detected entities)
    # The second axis has size 4 in the recorded run -- presumably the number
    # of span-attention heads; TODO confirm.
    print(f"Output wordnet entity_attention_probs shape: {wordnet_output['entity_attention_probs'].shape}")

    # Entity-linker score for each text entity and each candidate KB entity;
    # padded with -1.0000e+04 where there is no score.
    # shape: (batch_size, max # detected entities, max # KB candidate entities)
    print(f"Output wordnet linking_scores shape: {wordnet_output['linking_scores'].shape}")

    # Scalar loss over this batch (0 when not training? -- TODO confirm).
    print(f"Output loss: {model_output['loss']}")

    # Final [CLS] embedding for each sentence of the batch.
    # shape: (batch_size, hidden_size)
    print(f"Pooled output shape: {model_output['pooled_output'].shape}")

    # Final contextual embedding of every token.
    # Important: the model still produces an embedding for zero-padded token
    # positions => ignore those (or is zero padding equivalent to MASK? -- open question).
    print(f"Contextual embeddings: {model_output['contextual_embeddings'].shape}")

    # TODO: see how to add masking => embedding of the token with index 0?
    # TODO: see how to extract the actual predicted tokens from the final embeddings

Paris is located in France.
['[CLS]', 'paris', 'is', 'located', 'in', 'france', '.', '[SEP]']
Michael Jackson is a great music singer
['[CLS]', 'michael', 'jackson', 'is', 'a', 'great', 'music', 'singer', '[SEP]']

Input

Batch: dict_keys(['tokens', 'segment_ids', 'candidates'])
Tokens shape torch.Size([2, 9])
Segment ids shape: torch.Size([2, 9])
Wordnet kb: dict_keys(['candidate_entity_priors', 'candidate_entities', 'candidate_spans', 'candidate_segment_ids'])
Candidate entity_priors shape: torch.Size([2, 8, 14])
Candidate entities ids shape: torch.Size([2, 8, 14])
Candidate span shape: torch.Size([2, 8, 2])
Candidate segments_ids shape: torch.Size([2, 8])

Output

Model output keys: dict_keys(['wordnet', 'loss', 'pooled_output', 'contextual_embeddings'])
Output wordnet keys: dict_keys(['entity_attention_probs', 'linking_scores'])
Output wordnet entity_attention_probs shape: torch.Size([2, 4, 9, 8])
Output wordnet linking_scores shape: torch.Size([2, 8, 14])
Output loss: 0.0
Pooled o



In [71]:
# List every parameter tensor of the loaded model as "<qualified name>:<shape>".
for param_name, weights in model.named_parameters():
    print(param_name, weights.shape, sep=":")

pretrained_bert.bert.embeddings.word_embeddings.weight:torch.Size([30522, 768])
pretrained_bert.bert.embeddings.position_embeddings.weight:torch.Size([512, 768])
pretrained_bert.bert.embeddings.token_type_embeddings.weight:torch.Size([2, 768])
pretrained_bert.bert.embeddings.LayerNorm.weight:torch.Size([768])
pretrained_bert.bert.embeddings.LayerNorm.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.query.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.key.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.layer.0.attention.self.value.bias:torch.Size([768])
pretrained_bert.bert.encoder.layer.0.attention.output.dense.weight:torch.Size([768, 768])
pretrained_bert.bert.encoder.la