How to set the token_type_ids when using the pytorch-transformers Bert #3224

Closed
bestbzw opened this issue Sep 7, 2019 · 6 comments
@bestbzw bestbzw commented Sep 7, 2019

Question
I want to use the RoBERTa model, but I don't know how to set the token_type_ids.
It seems that only the token_ids are fed:

from pytorch_transformers import AutoModel
import torch
from overrides import overrides

from allennlp.modules.token_embedders.token_embedder import TokenEmbedder


@TokenEmbedder.register("pretrained_transformer")
class PretrainedTransformerEmbedder(TokenEmbedder):
    """
    Uses a pretrained model from ``pytorch-transformers`` as a ``TokenEmbedder``.
    """
    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.transformer_model = AutoModel.from_pretrained(model_name)
        # I'm not sure if this works for all models; open an issue on github if you find a case
        # where it doesn't work.
        self.output_dim = self.transformer_model.config.hidden_size

    @overrides
    def get_output_dim(self):
        return self.output_dim

    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:  # type: ignore
        # pylint: disable=arguments-differ
        # Note: only the token ids are passed; no token_type_ids are given to the model.
        return self.transformer_model(token_ids)[0]
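
(Aside: the underlying pytorch-transformers models do accept an optional token_type_ids argument in forward. A hypothetical sketch of an embedder that passes them through could look like the following; this is illustration only, not the current AllenNLP API:)

class PretrainedTransformerEmbedderWithTypeIds(PretrainedTransformerEmbedder):
    # Hypothetical subclass for illustration: accept token_type_ids and forward them
    # to the underlying pytorch-transformers model.
    def forward(self,
                token_ids: torch.LongTensor,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        # The huggingface models take token_type_ids as a keyword argument and
        # default it to zeros when it is None.
        return self.transformer_model(token_ids, token_type_ids=token_type_ids)[0]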
@matt-gardner matt-gardner commented Sep 10, 2019

I was just looking over the huggingface code, and I'm actually not sure where the token type ids get set in that code. It seems that roberta doesn't use these by default, so it's not surprising I couldn't find much for it in the roberta code. But bert does use them, and I couldn't find it there, either. Does anyone know where this happens in the huggingface code?

It should be pretty easy to make this happen; we have code to detect the token type ids for our previous bert implementation. But it would probably be easier to use their code, if they have some already.
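
(For what it's worth, when token_type_ids are not passed, BertModel simply fills them with zeros inside forward, so the two calls below should be equivalent. A small sketch, assuming bert-base-uncased:)

import torch
from pytorch_transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor([[101, 7592, 2088, 102]])  # toy ids; 101/102 are [CLS]/[SEP]
out_default = model(input_ids)[0]
out_explicit = model(input_ids, token_type_ids=torch.zeros_like(input_ids))[0]
# out_default and out_explicit should match, since the model defaults type ids to zeros.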

@joelgrus joelgrus commented Sep 13, 2019

@matt-gardner matt-gardner commented Sep 13, 2019

Yeah, I know that; what I don't know is where the token type ids are computed.

@maksym-del maksym-del commented Sep 15, 2019

Hi @matt-gardner,

looks like it happens here: segment_ids.
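
(Roughly, the example preprocessing there builds segment ids at featurization time. A paraphrased, self-contained sketch, not the exact huggingface source:)

tokens_a = ['how', 'are', 'you']
tokens_b = ['fine', 'thanks']

tokens = ['[CLS]'] + tokens_a + ['[SEP]']
segment_ids = [0] * len(tokens)               # first segment, including [CLS] and its [SEP]
if tokens_b:
    tokens += tokens_b + ['[SEP]']
    segment_ids += [1] * (len(tokens_b) + 1)  # second segment, including its [SEP]
# tokens:      [CLS] how are you [SEP] fine thanks [SEP]
# segment_ids:   0    0   0   0    0    1     1      1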

@mikerossgithub mikerossgithub commented Sep 18, 2019

In case it is helpful for anyone else, here is the interim fix I am currently using (for Bert only):

Replace this line:


With this:

result = {index_name: indices}
result.update(_guess_additional_indices(index_name, tokens))
return result

And add this function in the same file:

from typing import Dict, List


def _guess_additional_indices(index_name: str, tokens) -> Dict[str, List[int]]:
    if 'bert' in index_name:
        type_ids = []
        type_id = 0
        for token in tokens:
            type_ids.append(type_id)
            # Bump the segment id after each [SEP] so the next sentence gets the next id.
            if token == '[SEP]':
                type_id += 1
        return {f'{index_name}-type-ids': type_ids}
    else:
        return {}
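
(For example, on an already-wordpiece-tokenized pair, assuming the tokens are plain strings at that point in the indexer:)

tokens = ['[CLS]', 'how', 'are', 'you', '[SEP]', 'fine', '[SEP]']
print(_guess_additional_indices('bert', tokens))
# {'bert-type-ids': [0, 0, 0, 0, 0, 1, 1]}
print(_guess_additional_indices('tokens', tokens))
# {}  -- non-bert index names are left untouched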
@matt-gardner matt-gardner commented Sep 19, 2019

Thanks @maksym-del, that is in a place I didn't really expect. And it looks like there isn't logic there that we can just re-use, unfortunately. But we do already have a method to compute these segment ids ourselves. It's just too bad that there isn't some API call for this in the pytorch-transformers code, in case the models differ in how they handle it (or cap the number of token type embeddings, etc.).

And thanks for the example code @mikerossgithub. That solution will also require some changes in the model, but something like that is probably the right approach: use our existing code for determining type ids, and add a flag to the indexer for whether to use the type ids or not (possibly with a "guess" method if the flag is not explicitly passed).

I don't think we're going to get to adding this ourselves anytime soon, but a PR to add this would definitely be welcome!
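
(To make the flag-plus-guess idea concrete, a minimal sketch with hypothetical names, not existing AllenNLP code:)

from typing import Optional

def _should_use_type_ids(model_name: str, use_type_ids: Optional[bool] = None) -> bool:
    # Hypothetical helper: obey an explicit flag if given, otherwise guess from the
    # model name (BERT-style models use token type ids; RoBERTa effectively does not).
    if use_type_ids is not None:
        return use_type_ids
    return 'bert' in model_name and 'roberta' not in model_name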
