This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
gpt2.py
38 lines (31 loc) · 1.5 KB
/
gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from overrides import overrides
from transformers.models.gpt2.modeling_gpt2 import GPT2Config, GPT2LMHeadModel
import torch
from .language_model_head import LanguageModelHead
@LanguageModelHead.register("gpt2")
class Gpt2LanguageModelHead(LanguageModelHead):
    """
    Wraps the language-model head of a pretrained `transformers.GPT2LMHeadModel`.

    The full GPT-2 model is instantiated just so its head can be pulled out,
    which makes construction a few seconds slower than strictly necessary; in
    practice that one-time loading cost is negligible for model use.
    """

    def __init__(self, model_name: str) -> None:
        super().__init__()
        # The config alone tells us the head's input/output dimensions, so read
        # those before touching the (much heavier) model weights.
        gpt2_config = GPT2Config.from_pretrained(model_name)
        self.input_dim = gpt2_config.hidden_size
        self.output_dim = gpt2_config.vocab_size
        # TODO(mattg): A weight cache (like the one in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel)
        # would let us load the GPT2 weights only once, but it's unclear how to
        # apply it here since we need `GPT2LMHeadModel`, not just `GPT2Model`.
        full_model = GPT2LMHeadModel.from_pretrained(model_name)
        self.gpt2_lm_head = full_model.lm_head

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states to vocabulary logits via the pretrained head.
        return self.gpt2_lm_head(hidden_states)