Skip to content

Commit

Permalink
Bugfix in GPT2 model loading (#227)
Browse files: browse the repository at this point in the history
  • Loading branch information
gpengzhi committed Oct 9, 2019
1 parent c6935c2 commit 484b9b2
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions texar/torch/modules/pretrained/gpt2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -226,14 +226,12 @@ def _init_from_checkpoint(self, pretrained_model_name: str,
for name, array in zip(names, arrays):
if name in global_tensor_map:
v_name = global_tensor_map[name]

if name == "model/wte":
if load_output_layer:
pointer = self._name_to_variable(
"word_embedder.embedding")
assert pointer.shape == array.shape
pointer.data = torch.from_numpy(array)
pointer = self._name_to_variable("word_embedder.embedding")
assert pointer.shape == array.shape
pointer.data = torch.from_numpy(array)

if load_output_layer:
output_pointer = self._name_to_variable(
"_output_layer.weight")
assert output_pointer.shape == array.shape
Expand Down

0 comments on commit 484b9b2

Please sign in to comment.