This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

add BERT token embedder #2067

Merged
merged 19 commits into from
Nov 26, 2018

Conversation

@joelgrus (Contributor) commented Nov 16, 2018

this is ready for review. in addition to the included unit tests, I trained two NER models using these embeddings (unfortunately, I realized this morning, I used the uncased BERT model, which seems like a bad idea for NER)

(1) only BERT embeddings: https://beaker-internal.allenai.org/ex/ex_rnk3mcplnpjz/tasks
(2) BERT embeddings + character embeddings: https://beaker-internal.allenai.org/ex/ex_nrq8d5vw5cb2/tasks

(apologies to non-AI2 people for the beaker-internal links)

as discussed offline, because of the positional encodings the BERT embedding has a max sequence length and will crash if you feed it longer sequences. this implementation simply truncates longer sequences and logs a warning. I left a TODO to come up with something better.
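
For readers wondering what that truncation looks like, here is a minimal sketch of the idea, assuming a hypothetical max_pieces limit and a plain list of wordpiece ids (not the actual indexer code):

    import logging
    from typing import List

    logger = logging.getLogger(__name__)

    def truncate_wordpieces(wordpiece_ids: List[int], max_pieces: int = 512) -> List[int]:
        # BERT's positional embeddings only cover max_pieces positions, so anything
        # longer has to be cut down before it reaches the model.
        if len(wordpiece_ids) > max_pieces:
            logger.warning("Truncating sequence of %d wordpieces to %d",
                           len(wordpiece_ids), max_pieces)
            return wordpiece_ids[:max_pieces]
        return wordpiece_ids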

@SparkJiao

I really appreciate your work on BERT. I'm considering using BERT for my task, so could I use your implementation now?

@joelgrus (Contributor, Author)

I wouldn't recommend using it for anything important; the code hasn't been thoroughly tested yet.

(you can try it and tell me how well it works though. 😀)

@SparkJiao commented Nov 17, 2018 via email

@hzeng-otterai (Contributor)

Thanks for all the work!
I'd prefer that the BERT code be included in the allennlp library rather than pip-installed. One reason is that people may want to experiment with the code using the allennlp framework.
Also, having ELMo and BERT in the same place is fun.

@SparkJiao

Unfortunately, I think I've found a bug.
When sorting batches using the sorting key ("passage", "num_tokens"), it is reported that the key num_tokens doesn't exist. The cause is that the bert token indexer produces token arrays whose lengths differ from those produced by the single_id token indexer and the character tokenizer, so the key "num_tokens" never gets added to the "padding_lengths" dict.
The config file is dialog_qa.jsonnet; I will try modifying the sorting keys as a temporary workaround.

@SparkJiao

So are there problems when using GloVe, single_id token characters, and BERT at the same time? Since BERT produces sequences with different lengths from the others, does that mean we can't concatenate the word embeddings? Sorry to bother you.

@thomwolf (Contributor)

Hi @joelgrus, I've released our implementation on pip (see https://github.com/huggingface/pytorch-pretrained-BERT). Sorry for the delay! Tell me if I can change anything to make it easier for you to integrate it into AllenNLP!

@joelgrus (Contributor, Author)

the timing is perfect, thanks so much!

@joelgrus (Contributor, Author)

@SparkJiao the indexer produces a bert-offsets field that contains the indices of the last wordpiece for each word. if you pass this to the bert token embedder, it will only return the embedding of the last wordpiece per original token, which gives it the right size. (whether this is the optimal way to accomplish this I'm less sure of)

you can do this from your config file by adding something like the following to your text_field_embedder:

            "allow_unmatched_keys": true,
            "embedder_to_indexer_map": {
                "tokens": ["tokens"],  // etc
                "bert": ["bert", "bert-offsets"]
            },

so that the model knows to pass the offsets into the bert token embedder.
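
Conceptually, the embedder then uses those offsets to pull out one wordpiece vector per original token. Here is a minimal sketch of that selection step, assuming a (batch_size, num_wordpieces, embedding_dim) tensor and a (batch_size, num_tokens) offsets tensor (illustrative names, not the actual AllenNLP internals):

    import torch

    def select_wordpiece_embeddings(wordpiece_embeddings: torch.Tensor,
                                    offsets: torch.Tensor) -> torch.Tensor:
        # wordpiece_embeddings: (batch_size, num_wordpieces, embedding_dim)
        # offsets: (batch_size, num_tokens), index of the last wordpiece of each token
        batch_size, _, embedding_dim = wordpiece_embeddings.size()
        # Expand the offsets so they can index into the embedding dimension.
        expanded = offsets.unsqueeze(-1).expand(batch_size, offsets.size(1), embedding_dim)
        # Result: (batch_size, num_tokens, embedding_dim)
        return torch.gather(wordpiece_embeddings, 1, expanded)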

@SparkJiao

@joelgrus Sorry for the late reply, and thank you very much for your help. I've recently run into some other problems that I can't solve, so I may wait for your official tutorial :(
Thanks a lot!

@susht3 commented Nov 19, 2018

This is mostly implemented, with two caveats:

(1) the end to end test for it is failing (and is itself not fully written), but for reasons that I'm pretty sure are unrelated to BERT or the new code (and that I'm pretty sure are related to me doing something really dumb somewhere that I haven't figured out)

(2) huggingface said they're going to pip release their implementation, but afaik they haven't, so for now I copied over the relevant files, but I'm using them as if they were a library, so when the library gets released I should (in theory) just have to change a handful of import statements and everything should work.

Now that they have released their implementation on pip, could you please add BERT?

@joelgrus changed the title from "WIP: add BERT token embedder" to "add BERT token embedder" on Nov 21, 2018
@matt-gardner (Contributor) left a comment

Code largely looks good. The big question to me is how to use BERT embeddings if you're not fine tuning (or making sure that it works as expected if you are fine tuning).

@@ -182,7 +182,7 @@ expected-line-ending-format=
 [BASIC]

 # Good variable names which should always be accepted, separated by a comma
-good-names=i,j,k,ex,Run,_
+good-names=i,j,k,ex,Run,_,logger
Contributor:

Hmm, I didn't realize you could do this. Nice find.

@@ -351,13 +351,14 @@ def _check_is_dict(self, new_history, value):
         return value

     @staticmethod
-    def from_file(params_file: str, params_overrides: str = "") -> 'Params':
+    def from_file(params_file: str, params_overrides: str = "", ext_vars: dict = {}) -> 'Params':
+        # pylint: disable=dangerous-default-value
Contributor:

I think I agree with pylint here - why not just have this be None, and add a line to the logic below? Much less error-prone in future maintenance of this code.
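
A minimal sketch of the pattern the reviewer is suggesting, with an illustrative stand-in for Params.from_file rather than the real implementation:

    from typing import Optional

    def from_file(params_file: str, params_overrides: str = "",
                  ext_vars: Optional[dict] = None) -> dict:
        # Default to None and create the dict inside the body; this avoids the
        # shared-mutable-default pitfall that pylint's dangerous-default-value warns about.
        if ext_vars is None:
            ext_vars = {}
        # ... load params_file, apply params_overrides, pass ext_vars to jsonnet ...
        return {"file": params_file, "overrides": params_overrides, "ext_vars": ext_vars}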

Contributor:

It'd also be nice to document what this does.


    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        # If we only use pretrained models, we don't need to do anything here.
Contributor:

In the docstring above, it looked like you were recommending a different class for pretrained models. Are you talking about pretrained WordpieceTokenizers here?


    def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None:
        # pylint: disable=protected-access
        for word, idx in self.vocab.items():
Contributor:

If this takes a while, you might consider putting a logging statement (or even a tqdm) in here.

logger = logging.getLogger(__name__)


class BertIndexer(TokenIndexer[int]):
Contributor:

Maybe call this a WordpieceIndexer? This is more general than BERT. The class below is the BERT-specific one.

bert_model: ``BertModel``
The BERT model being wrapped.
"""
def __init__(self,
Contributor:

All on one line? Also, what's the motivation for splitting this class up into two? So that people can either instantiate the BertModel themselves or use a string to reference it? Do we really need both of these?

Contributor:

Oh, looking at the config fixture, looks like this will let you train your own bert model if you want? Ok, yeah, that's definitely sufficient motivation to split this up.

Parameters
----------
input_ids: ``torch.LongTensor``
The wordpiece ids for each input sentence.
Contributor:

Sentence? I'm not sure that's correct. It'd probably clear up what you meant here if you gave an expected shape.

If an input consists of two sentences (as in the BERT paper),
tokens from the first sentence should have type 0 and tokens from
the second sentence should have type 1. If you don't provide this
(the default BertIndexer doesn't) then it's assumed to be all 0s.
Contributor:

Is there a way to modify the indexer to provide this? It's fine to do it in another PR, but I'd at least open an issue to track adding that. Seems pretty important if you want to use this for SQuAD.
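
For reference, a minimal sketch of what such type ids look like for a sentence pair, following the [CLS] ... [SEP] ... [SEP] layout from the BERT paper (the helper name and id arguments are illustrative):

    from typing import List, Tuple

    def make_pair_input(first_piece_ids: List[int], second_piece_ids: List[int],
                        cls_id: int, sep_id: int) -> Tuple[List[int], List[int]]:
        # [CLS] sentence A [SEP] sentence B [SEP]
        input_ids = [cls_id] + first_piece_ids + [sep_id] + second_piece_ids + [sep_id]
        # Type 0 up to and including the first [SEP], type 1 for the rest.
        token_type_ids = [0] * (len(first_piece_ids) + 2) + [1] * (len(second_piece_ids) + 1)
        return input_ids, token_type_ids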

passage1 = "There were four major HDTV systems tested by SMPTE in the late 1970s, and in 1979 an SMPTE study group released A Study of High Definition Television Systems:"
question1 = "Who released A Study of High Definition Television Systems?"

passage2 = """Broca, being what today would be called a neurosurgeon, had taken an interest in the pathology of speech. He wanted to localize the difference between man and the other animals, which appeared to reside in speech. He discovered the speech center of the human brain, today called Broca's area after him. His interest was mainly in Biological anthropology, but a German philosopher specializing in psychology, Theodor Waitz, took up the theme of general and social anthropology in his six-volume work, entitled Die Anthropologie der Naturvölker, 1859–1864. The title was soon translated as "The Anthropology of Primitive Peoples". The last two volumes were published posthumously."""
Contributor:

You're using triple quotes already; can you make this into multiple shorter lines instead of one super long one? Same with the passage above.

        input_mask = (input_ids != 0).long()

        all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)
        sequence_output = all_encoder_layers[-1]
Contributor:

The top layer is a particularly bad layer to use for transfer (talk to Nelson about why, or wait for his final internship presentation). If we're fine-tuning, this is ok, but if we're not, we probably need some kind of scalar mix or something here, instead of just taking the last layer. This is probably why your performance is so poor for your NER experiment.

Contributor:

Table 2 of the original elmo paper suggests otherwise (unless this is specific to transformers)? Perhaps this accounts for ~1 F1 difference, not 20+.

Contributor:

Layer 1 is pretty consistently better than high layers across tasks, in Nelson's experiments (with 2 and 4 layer ELMo). The top layer is specialized for language modeling, and only a few tasks are close enough to benefit from that specialization (turns out NER is one of those). For transformers, the story is a bit different, but middle layers are still better than the top layer.

Table 2 of the ELMo paper didn't try just using layer 1.

Contributor:

I wasn't saying there isn't a difference - just that even with the top layer, it should be there or thereabouts. So that isn't the problem/difference to look into here for the NER model that Joel trained.

Contributor:

I missed that slack convo. Nevermind.
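
For reference, a minimal sketch of the scalar-mix idea discussed above, using AllenNLP's ScalarMix module; the surrounding wiring is assumed here, not taken from the embedder in this PR:

    from typing import List

    import torch
    from allennlp.modules.scalar_mix import ScalarMix

    # In a real model this would be created once in __init__ so that the mixture
    # weights are trained along with everything else.
    scalar_mix = ScalarMix(mixture_size=12, do_layer_norm=False)  # 12 layers for BERT-base

    def mix_bert_layers(all_encoder_layers: List[torch.Tensor]) -> torch.Tensor:
        # all_encoder_layers: one (batch_size, num_wordpieces, embedding_dim) tensor per
        # transformer layer. Instead of taking all_encoder_layers[-1], learn a
        # softmax-normalized weighted average over all the layers.
        return scalar_mix(all_encoder_layers)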

@matthew-z commented Dec 2, 2018

@joelgrus Thank you for the response.

Yes, I tried removing bert-offsets from embedder_to_indexer_map, and then the BERT-encoded input became (batch_size, wordpiece_sequence_length) instead of (batch_size, token_sequence_length), so this code won't work because #tokens != #wordpieces:

mask = util.get_text_field_mask(inputs)
encoded_inputs = self.text_field_embedder(inputs)
logits = self.seq2vec(encoded_inputs, mask)

In other words, I want a mask with this shape: (batch_size, wordpiece_sequence_length).
I used a simple workaround: use the BERT tokenizer as the word splitter so that token = wordpiece:

from typing import List

from overrides import overrides
# These import paths assume the allennlp 0.x and pytorch_pretrained_bert APIs.
from allennlp.data.tokenizers import Token
from allennlp.data.tokenizers.word_splitter import WordSplitter
from pytorch_pretrained_bert import BertTokenizer


@WordSplitter.register("wordPiece")
class WordPiece(WordSplitter):

    def __init__(self, pretrained_model: str, do_lowercase: bool) -> None:
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower_case=do_lowercase)

    @overrides
    def split_words(self, sentence: str) -> List[Token]:
        return [Token(t) for t in self.tokenizer.tokenize(sentence) if t]

config looks like:

  "dataset_reader": {
    "type": "mydataset",
    "tokenizer": {
         "type":"word",  
         "word_splitter":{
              "type":"wordPiece", 
              "pretrained_model": "bert-base-uncased",
              "do_lowercase": true
         }
    },
    "token_indexers": {
      "bert": {
          "type": "bert-pretrained",
          "pretrained_model": "bert-base-uncased",
          "do_lowercase": true,
      },
    },
  },

Then mask will become (batch_size, wordpiece_sequence_length).

@joelgrus (Contributor, Author) commented Dec 2, 2018

the problem here is that util.get_text_field_mask sees the "mask" key in your inputs and uses that as the mask even though here it's not what you want.

I think a simpler solution here is just not to use util.get_text_field_mask; you can do something like

mask = inputs["bert"] != 0

which is what the token embedder does internally:

https://github.com/allenai/allennlp/blob/master/allennlp/modules/token_embedders/bert_token_embedder.py#L85

@matthew-z

This solution is indeed much simpler.

Thank you!

@bheinzerling commented Dec 4, 2018

Did anybody manage to get a CoNLL'03 dev score close to the reported 96.4 F1 with BERT_base? Best I've managed to get is 94.6 using the last transformer layer for classification, as described in the paper (huggingface/transformers#64 (comment)).

@matt-gardner (Contributor)

@bheinzerling, this is from a slack conversation a couple of weeks ago with @matt-peters:

I've run a bunch of combinations, trying to initially reproduce their results in table 7.

They left out some important details in the paper for the NER task, namely that they used document context for each word since it's available in the raw data. So I was never able to reproduce their results since I was using sentence context.

It also makes their results not directly comparable to previous work, since your standard glove + biLSTM would also presumably improve with document context...

FYI: with my implementation, I got dev F1 95.09 +/- 0.07 for 2x200 dim LSTM + CRF for second-to-last layer (they reported 95.6 in table 7 without CRF). This uses sentence context to compute the BERT activations.

(This was done outside of AllenNLP.)

The last layer is slightly worse than the second-to-last layer, so your number seems to agree with @matt-peters' results.

@matt-peters (Contributor)

FWIW, adding document context improves F1 a little, enough to reproduce the results in Table 7 (+/- noise from random seeds). Note that these numbers are from extracting features from BERT, not fine-tuning.

@qiuwei (Contributor) commented Dec 5, 2018

@joelgrus I had a look at the NER config you provided. It seems that you didn't prepend the special token [CLS] to the start of each sentence.

However, I found this quite crucial in my local experiments. Did I miss anything?

@joelgrus (Contributor, Author) commented Dec 5, 2018

let me take a look

@qiuwei (Contributor) commented Dec 5, 2018

@matt-peters Hi Matt, could you explain a bit more about adding document context?
Did you use the whole document (I believe that would often exceed the max length allowed by BERT) or just add a few sentences around the target sentence?

When a larger context is used, did your model predict the NER labels for the context as well?

@joelgrus (Contributor, Author) commented Dec 5, 2018

@qiuwei it looks like you are right and it's a "bug" in the token indexer, I'll open an issue for it, thanks for finding this
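
A minimal sketch of the fix under discussion, i.e. wrapping the wordpiece ids in the special tokens before they reach the model (the helper and its arguments are illustrative, not the eventual patch):

    from typing import List

    def add_special_tokens(wordpiece_ids: List[int], cls_id: int, sep_id: int) -> List[int]:
        # BERT expects [CLS] at the start and [SEP] at the end of every input.
        # Note that any offsets computed over the bare wordpieces then shift by one.
        return [cls_id] + wordpiece_ids + [sep_id]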

@matt-peters (Contributor)

@qiuwei - I took the easy / simple implementation approach and just chunked the document into non-overlapping segments if it exceeded the maximum length usable by BERT. The only wrinkle is ensuring not to chunk the document in the middle of an entity span. This way each token still has an annotation, and the NER model still predicts labels for every token.
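
A minimal sketch of that chunking strategy, assuming BIO-style labels so that a cut point can be pushed back until it no longer falls inside an entity span (illustrative, not matt-peters' actual code):

    from typing import List, Tuple

    def chunk_document(tokens: List[str], labels: List[str],
                       max_len: int) -> List[Tuple[List[str], List[str]]]:
        # Greedily cut the document into non-overlapping chunks of at most max_len
        # tokens, moving each cut point left while it would split an entity
        # (a BIO span continues as long as the label starts with "I-").
        chunks = []
        start = 0
        while start < len(tokens):
            end = min(start + max_len, len(tokens))
            while end < len(tokens) and end > start and labels[end].startswith("I-"):
                end -= 1
            if end <= start:  # pathological case: a single span longer than max_len
                end = min(start + max_len, len(tokens))
            chunks.append((tokens[start:end], labels[start:end]))
            start = end
        return chunks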

@bikestra commented Mar 6, 2019

@matt-peters Were you able to reproduce the BERT paper's results once you introduced the document context? I was able to get dev F1 95.3 using sentence context, but this is still 1.1 points behind what the authors claim, and I didn't see much of a boost from using document context.

This was work done outside of allennlp, but I thought independently reproducing their result with any tool would help all of us make progress; I'm having trouble finding anyone who has been able to successfully reproduce their NER results. Sorry if this is a distraction for the allennlp contributors.

@matt-peters (Contributor)

I haven't tried to reproduce the fine tuning result (96.4 dev F1, BERT base, Table 3).

@kamalkraj

https://github.com/kamalkraj/BERT-NER
Replicated results from BERT paper

@pasinit commented Oct 27, 2019

Did anyone try computing the average of the wordpieces instead of using the first wordpiece of each word? For example, if I have the token "longtoken", which is split into "long" and "token", my understanding is that currently one simply takes the embedding of "long" for the whole token. How easy would it be to take the average of "long" and "token" instead?
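
For what it's worth, a minimal sketch of averaging the wordpieces of each word; it assumes you have (start, end) wordpiece spans per original token, which the indexer in this PR does not record (it only keeps last-wordpiece offsets):

    from typing import List, Tuple

    import torch

    def average_wordpieces(wordpiece_embeddings: torch.Tensor,
                           token_spans: List[Tuple[int, int]]) -> torch.Tensor:
        # wordpiece_embeddings: (num_wordpieces, embedding_dim) for a single sequence.
        # token_spans: half-open (start, end) wordpiece indices per original token,
        # e.g. "longtoken" -> ("long", "##token") -> (i, i + 2).
        token_vectors = [wordpiece_embeddings[start:end].mean(dim=0)
                         for start, end in token_spans]
        return torch.stack(token_vectors)  # (num_tokens, embedding_dim)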

@wangxinyu0922

So has anyone successfully reproduced the score reported in the paper? I want to reproduce it, but I'm finding it very hard to do (only ~91.5 with document context + fine-tuning).

@dsindex commented Apr 3, 2021

@wangxinyu0922

With document context, I reached 92.35% (bert-base-cased):

dsindex/ntagger#4 (comment)
