Use KnowBert to predict missing words #11

Closed
jzbjyb opened this issue Dec 18, 2019 · 2 comments

@jzbjyb

jzbjyb commented Dec 18, 2019

Hi Matthew,

Thanks a bunch for the documentation on embedding sentences programmatically. It saved me a lot of time! I made a small modification so that I can use KnowBert to predict the missing word (i.e., [MASK]) in a sentence, but the results are unexpected. I am not sure whether my implementation is correct; here is the code snippet:

from kb.include_all import ModelArchiveFromParams
from kb.knowbert_utils import KnowBertBatchifier
from allennlp.common import Params
import torch
import torch.nn.functional as F

archive_file = 'https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz'

# load model and batcher
params = Params({'archive_file': archive_file})
model = ModelArchiveFromParams.from_params(params=params)
model.eval()
batcher = KnowBertBatchifier(archive_file)

# get bert vocab
vocab = list(batcher.tokenizer_and_candidate_generator.bert_tokenizer.ids_to_tokens.values())

sentences = ['Paris is located in [MASK].']
mask_ind = 5

for batch in batcher.iter_batches(sentences, verbose=False):
    model_output = model(**batch)
    # the tokenized sentence, where the 6-th token is [MASK]
    print([vocab[w] for w in batch['tokens']['tokens'][0].numpy()])
    logits, _ = model.pretraining_heads(model_output['contextual_embeddings'], model_output['pooled_output'])
    log_probs = F.log_softmax(logits, dim=-1)
    topk = torch.topk(log_probs[0, mask_ind], 10, 0)[1]
    # print the top 10 predictions
    print([vocab[t.item()] for t in topk])

The top 10 predictions are [UNK], the, itself, its, and, marne, to, them, first, lissa, while the top 10 predictions of BERT-base-uncased are france, paris, europe, italy, belgium, algeria, germany, russia, haiti, canada, which seems a little bit weird. Is my implementation correct, or do you have any suggestions? Thanks in advance!
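
For reference, this is a minimal sketch of how a BERT-base-uncased baseline like the one above could be reproduced with the HuggingFace transformers library; the library and its API calls are an assumption here, not part of the KnowBert snippet:

import torch
from transformers import BertForMaskedLM, BertTokenizer

# Plain BERT-base-uncased baseline for the same cloze query (sketch only;
# assumes the HuggingFace transformers library, not the kb package above).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertForMaskedLM.from_pretrained('bert-base-uncased')
bert.eval()

inputs = tokenizer('Paris is located in [MASK].', return_tensors='pt')
# position of the [MASK] token in the tokenized input
mask_pos = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

with torch.no_grad():
    logits = bert(**inputs).logits  # shape: (1, seq_len, vocab_size)

# top 10 predicted tokens at the [MASK] position
top10 = torch.topk(logits[0, mask_pos[0]], 10)[1]
print(tokenizer.convert_ids_to_tokens(top10.tolist()))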

@matt-peters
Contributor

If you want to fill in [MASK] tokens then it's necessary initialize batcher = KnowBertBatchifier(archive_file, masking_strategy='full_mask'). This creates batches in the same was as during pretraining. After doing so, I get ['france', 'germany', 'belgium', 'europe', 'canada', 'italy', 'paris', 'spain', 'russia', 'algeria'] as the top 10 predictions for 'Paris is located in [MASK].'
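
Applied to the snippet above, the batcher construction is the only change; the model, vocab, mask_ind, and prediction loop can be reused unchanged. A minimal sketch:

batcher = KnowBertBatchifier(archive_file, masking_strategy='full_mask')

for batch in batcher.iter_batches(sentences, verbose=False):
    model_output = model(**batch)
    logits, _ = model.pretraining_heads(
        model_output['contextual_embeddings'], model_output['pooled_output'])
    log_probs = F.log_softmax(logits, dim=-1)
    # indices of the 10 highest-probability tokens at the [MASK] position
    topk = torch.topk(log_probs[0, mask_ind], 10, 0)[1]
    print([vocab[t.item()] for t in topk])  # 'france' should now rank first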

@jzbjyb
Author

jzbjyb commented Dec 19, 2019

Thanks for your quick reply! I just noticed that full_mask is not the default, and using it produces the correct predictions!

jzbjyb closed this as completed Dec 19, 2019