Refactor GPT2 module interface (#238)
gpengzhi committed Oct 15, 2019
1 parent bd4c514 commit 0fe6d47
Showing 6 changed files with 228 additions and 204 deletions.
2 changes: 1 addition & 1 deletion texar/torch/modules/classifiers/gpt2_classifier.py
@@ -115,7 +115,7 @@ def __init__(self,

self.is_binary = (self.num_classes == 1) or \
(self.num_classes <= 0 and
self._hparams.dim == 1)
self._hparams.encoder.dim == 1)

@staticmethod
def default_hparams():
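
The classifier change is a one-line fix: the binary-classification check now reads the hidden dimension from the nested `encoder` hyperparameters rather than from a top-level `dim` entry. A standalone sketch of the corrected logic, with purely illustrative hyperparameter values:

# Illustrative values only: mirrors the corrected `is_binary` check above.
hparams = {"num_classes": 0, "encoder": {"dim": 1}}
is_binary = (hparams["num_classes"] == 1) or \
            (hparams["num_classes"] <= 0 and hparams["encoder"]["dim"] == 1)
print(is_binary)  # True
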
156 changes: 80 additions & 76 deletions texar/torch/modules/decoders/gpt2_decoder.py
@@ -30,7 +30,7 @@
]


class GPT2Decoder(TransformerDecoder, PretrainedGPT2Mixin):
class GPT2Decoder(PretrainedGPT2Mixin):
r"""Raw GPT2 Transformer for decoding sequences. Please see
:class:`~texar.torch.modules.PretrainedGPT2Mixin` for a brief description
of GPT2.
@@ -58,39 +58,43 @@ class GPT2Decoder(TransformerDecoder, PretrainedGPT2Mixin):
:meth:`default_hparams` for the hyperparameter structure
and default values.
"""
_IS_DECODE = True

def __init__(self,
pretrained_model_name: Optional[str] = None,
cache_dir: Optional[str] = None,
hparams=None):
self.load_pretrained_config(pretrained_model_name, cache_dir, hparams)
super().__init__(hparams=hparams)

self.load_pretrained_config(pretrained_model_name, cache_dir)

# Word embedding
word_embedder = WordEmbedder(
self.word_embedder = WordEmbedder(
vocab_size=self._hparams.vocab_size,
hparams=self._hparams.embed)

# Position embedding
position_embedder = PositionEmbedder(
self.position_embedder = PositionEmbedder(
position_size=self._hparams.position_size,
hparams=self._hparams.position_embed)

# The GPT2 decoder (a TransformerDecoder)
super().__init__(vocab_size=self._hparams.vocab_size,
output_layer=word_embedder.embedding,
hparams=None)
def func(tokens, positions):
word_embeds = self.word_embedder(tokens)
pos_embeds = self.position_embedder(positions)
return word_embeds + pos_embeds

# Register modules after `__init__` is called.
self.word_embedder = word_embedder
self.position_embedder = position_embedder
class GPT2TransformerDecoder(TransformerDecoder):
def embed_tokens(self, tokens: torch.LongTensor,
positions: torch.LongTensor) -> torch.Tensor:
return func(tokens, positions)

self.init_pretrained_weights()
self.decoder = GPT2TransformerDecoder(
vocab_size=self._hparams.vocab_size,
output_layer=self.word_embedder.embedding,
hparams=self._hparams.decoder)

def embed_tokens(self, tokens: torch.LongTensor,
positions: torch.LongTensor) -> torch.Tensor:
word_embeds = self.word_embedder(tokens)
pos_embeds = self.position_embedder(positions)
return word_embeds + pos_embeds
self.init_pretrained_weights()

@staticmethod
def default_hparams():
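
The net effect of the hunk above is a switch from inheritance to composition: `GPT2Decoder` no longer is a `TransformerDecoder` but owns one, with the word and position embedders attached to the outer module and token embedding routed through the inner `GPT2TransformerDecoder` subclass. A hedged sketch of the resulting module layout (assuming texar-torch with this commit installed and the `gpt2-small` checkpoint available):

from texar.torch.modules import GPT2Decoder

gpt2 = GPT2Decoder(pretrained_model_name="gpt2-small")
print(type(gpt2.word_embedder).__name__)      # WordEmbedder
print(type(gpt2.position_embedder).__name__)  # PositionEmbedder
print(type(gpt2.decoder).__name__)            # GPT2TransformerDecoder, wrapping TransformerDecoder
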
@@ -197,53 +201,53 @@ def default_hparams():
Name of the module.
"""
return {
**TransformerDecoder.default_hparams(),
'dim': 768,
'num_blocks': 12,
'use_gpt_config': True,
'embedding_dropout': 0,
'residual_dropout': 0,
'multihead_attention': {
'use_bias': True,
'num_units': 768,
'num_heads': 12,
"dropout_rate": 0.0,
'output_dim': 768
},
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
'factor': 1.0,
'mode': 'FAN_AVG',
'uniform': True
}
},
'poswise_feedforward': {
'layers': [
{
'type': 'Linear',
'kwargs': {
'in_features': 768,
'out_features': 3072,
'bias': True
}
},
{
'type': 'GPTGELU',
'kwargs': {}
},
{
'type': 'Linear',
'kwargs': {
'in_features': 3072,
'out_features': 768,
'bias': True
}
'decoder': {
'dim': 768,
'num_blocks': 12,
'use_gpt_config': True,
'embedding_dropout': 0,
'residual_dropout': 0,
'multihead_attention': {
'use_bias': True,
'num_units': 768,
'num_heads': 12,
"dropout_rate": 0.0,
'output_dim': 768
},
'initializer': {
'type': 'variance_scaling_initializer',
'kwargs': {
'factor': 1.0,
'mode': 'FAN_AVG',
'uniform': True
}
],
'name': 'ffn'
},
'poswise_feedforward': {
'layers': [
{
'type': 'Linear',
'kwargs': {
'in_features': 768,
'out_features': 3072,
'bias': True
}
},
{
'type': 'GPTGELU',
'kwargs': {}
},
{
'type': 'Linear',
'kwargs': {
'in_features': 3072,
'out_features': 768,
'bias': True
}
}
],
'name': 'ffn'
},
},

'pretrained_model_name': 'gpt2-small',
'vocab_size': 50257,
'context_size': 1024,
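
With the new layout, the `TransformerDecoder`-level settings live under the nested 'decoder' key instead of at the top level of `GPT2Decoder`'s hyperparameters. A minimal sketch, mirroring case 3 of the updated tests below, of overriding them for a from-scratch model:

from texar.torch.modules import GPT2Decoder

custom_hparams = {
    "pretrained_model_name": None,  # no checkpoint, so the override is kept
    "decoder": {
        "num_blocks": 6,
    },
}
scratch_decoder = GPT2Decoder(hparams=custom_hparams)
assert scratch_decoder.hparams.decoder.num_blocks == 6
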
@@ -286,18 +290,18 @@ def forward(self,  # type: ignore
:meth:`texar.torch.modules.TransformerDecoder.forward`. Please refer to
it for the detailed usage.
"""
return super().forward(inputs=inputs,
sequence_length=sequence_length,
memory=memory,
memory_sequence_length=memory_sequence_length,
memory_attention_bias=memory_attention_bias,
context=context,
context_sequence_length=context_sequence_length,
helper=helper,
decoding_strategy=decoding_strategy,
max_decoding_length=max_decoding_length,
impute_finished=impute_finished,
infer_mode=infer_mode,
beam_width=beam_width,
length_penalty=length_penalty,
**kwargs)
return self.decoder(inputs=inputs,
sequence_length=sequence_length,
memory=memory,
memory_sequence_length=memory_sequence_length,
memory_attention_bias=memory_attention_bias,
context=context,
context_sequence_length=context_sequence_length,
helper=helper,
decoding_strategy=decoding_strategy,
max_decoding_length=max_decoding_length,
impute_finished=impute_finished,
infer_mode=infer_mode,
beam_width=beam_width,
length_penalty=length_penalty,
**kwargs)
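
Because `forward` now delegates wholesale to the wrapped `TransformerDecoder`, the call signature is unchanged for users. A hedged sketch of a teacher-forcing pass, reusing the `gpt2` instance from the earlier sketch (the `inputs`-only call matches what the updated tests below exercise; the token tensor here is hypothetical):

import torch

tokens = torch.randint(0, 50257, (2, 8))  # hypothetical [batch, time] token ids
outputs = gpt2(inputs=tokens,
               sequence_length=torch.tensor([8, 8]),
               decoding_strategy="train_greedy")
print(outputs.logits.shape)  # assumed: [2, 8, 50257]
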
20 changes: 13 additions & 7 deletions texar/torch/modules/decoders/gpt2_decoder_test.py
@@ -34,30 +34,34 @@ def test_hparams(self):
}
decoder = GPT2Decoder(pretrained_model_name="gpt2-small",
hparams=hparams)
self.assertEqual(decoder.hparams.num_blocks, 12)
self.assertEqual(decoder.hparams.decoder.num_blocks, 12)
_ = decoder(self.inputs)

# case 2: set "pretrained_mode_name" by hparams
hparams = {
"pretrained_model_name": "gpt2-small",
"num_blocks": 6,
"decoder": {
"num_blocks": 6,
},
}
decoder = GPT2Decoder(hparams=hparams)
self.assertEqual(decoder.hparams.num_blocks, 12)
self.assertEqual(decoder.hparams.decoder.num_blocks, 12)
_ = decoder(self.inputs)

# case 3: set to None in both hparams and constructor argument
hparams = {
"pretrained_model_name": None,
"num_blocks": 6,
"decoder": {
"num_blocks": 6,
},
}
decoder = GPT2Decoder(hparams=hparams)
self.assertEqual(decoder.hparams.num_blocks, 6)
self.assertEqual(decoder.hparams.decoder.num_blocks, 6)
_ = decoder(self.inputs)

# case 4: using default hparams
decoder = GPT2Decoder()
self.assertEqual(decoder.hparams.num_blocks, 12)
self.assertEqual(decoder.hparams.decoder.num_blocks, 12)
_ = decoder(self.inputs)

@pretrained_test
@@ -92,7 +96,9 @@ def get_variable_num(n_layers: int) -> int:
# case 3: self-designed GPT2
hparams = {
"pretrained_model_name": None,
"num_blocks": 6,
"decoder": {
"num_blocks": 6,
},
}
decoder = GPT2Decoder(hparams=hparams)
self.assertEqual(len(decoder.trainable_variables), get_variable_num(6))
