nn.Identity bugfix (#124)
gpengzhi authored and huzecong committed Jul 26, 2019
1 parent 95d6d12 commit c9b82dc
Showing 4 changed files with 7 additions and 98 deletions.
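The gist of the fix: a decoder's output layer must behave like an `nn.Module`, but `decoder_base.py` previously assigned the bare `identity` function, while `gpt2_encoder.py` constructed `torch.nn.Identity()` directly, presumably a problem for older PyTorch versions that lack that class. Both paths now go through Texar's own `Identity` module in `texar.core.layers`. A minimal sketch of what the `identity` sentinel and the `Identity` module plausibly look like, inferred from how this diff uses them:

    import torch
    from torch import nn


    def identity(inputs: torch.Tensor) -> torch.Tensor:
        """Function form: returns its input unchanged. Also serves as a
        sentinel value that `_make_output_layer` detects via
        ``layer is identity``."""
        return inputs


    class Identity(nn.Module):
        """Module form: a no-op `nn.Module`, usable wherever the decoder
        expects its output layer to be a module."""

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            return inputs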
4 changes: 2 additions & 2 deletions texar/modules/decoders/decoder_base.py
@@ -22,7 +22,7 @@
 import torch
 from torch import nn
 
-from texar.core.layers import identity
+from texar.core.layers import identity, Identity
 from texar.module_base import ModuleBase
 from texar.modules.decoders import decoder_helpers
 from texar.modules.decoders.decoder_helpers import Embedding, Helper
@@ -71,7 +71,7 @@ def _make_output_layer(layer: Optional[Union[nn.Module, torch.Tensor]],
         layer = nn.Parameter(layer, requires_grad=False)
         output_layer.weight = layer
     elif layer is identity:
-        output_layer = identity  # type: ignore
+        output_layer = Identity()
     else:
         raise ValueError(
             f"output_layer should be an instance of `nn.Module`, a tensor,"
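With this change, callers that pass the `identity` function as `output_layer` get a real module back. A simplified sketch of the dispatch in `_make_output_layer` after the fix; the `None` and tensor branches are condensed, and the exact parameter names are assumptions:

    from typing import Optional, Union

    import torch
    from torch import nn

    from texar.core.layers import identity, Identity


    def _make_output_layer(layer: Optional[Union[nn.Module, torch.Tensor]],
                           vocab_size: int, output_size: int) -> nn.Module:
        if layer is None:
            output_layer = nn.Linear(output_size, vocab_size)
        elif isinstance(layer, nn.Module):
            output_layer = layer  # caller-supplied module is used as-is
        elif isinstance(layer, torch.Tensor):
            # Treat the tensor as frozen output weights (e.g. tied embeddings).
            output_layer = nn.Linear(output_size, vocab_size, bias=False)
            output_layer.weight = nn.Parameter(layer, requires_grad=False)
        elif layer is identity:
            output_layer = Identity()  # the fix: a module, not the bare function
        else:
            raise ValueError(
                f"output_layer should be an instance of `nn.Module`, a tensor,"
                f" the `identity` function, or `None`. Got: {type(layer)}")
        return output_layer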
5 changes: 2 additions & 3 deletions texar/modules/encoders/gpt2_encoder.py
@@ -18,7 +18,6 @@
 from typing import Optional
 
 import torch
-import torch.nn as nn
 
 from texar.core import layers
 from texar.hyperparams import HParams
@@ -86,9 +85,9 @@ def __init__(self,
             hparams=self._hparams.position_embed)
 
         # The GPT2 encoder (a TransformerDecoder)
-        self.decoder = TransformerDecoder(
+        self.decoder = TransformerDecoder(  # type: ignore
             vocab_size=self._hparams.vocab_size,
-            output_layer=nn.Identity(),
+            output_layer=layers.identity,
             hparams=self._hparams.decoder)
 
         if self.pretrained_model_dir:
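Note the direction of this change: the encoder now hands the `layers.identity` sentinel to `TransformerDecoder` instead of building `nn.Identity()` itself, so `_make_output_layer` (above) constructs Texar's `Identity` module and the dependency on `torch.nn.Identity` disappears. Assuming `identity` simply returns its argument, both forms are no-ops:

    import torch
    from texar.core.layers import identity, Identity

    x = torch.randn(2, 4)
    assert identity(x) is x                # function form returns the input itself
    assert torch.equal(Identity()(x), x)   # module form, composable like any nn.Module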
91 changes: 0 additions & 91 deletions texar/modules/pretrained/gpt2_base.py

This file was deleted.

5 changes: 3 additions & 2 deletions texar/modules/pretrained/gpt2_utils.py
@@ -24,6 +24,7 @@
 import torch
 import torch.nn as nn
 
+from texar.core import layers
 from texar.data.data_utils import maybe_download
 from texar.modules.pretrained.pretrained_utils import default_download_dir
 
@@ -111,7 +112,7 @@ def init_gpt2_checkpoint(model: nn.Module, cache_dir: str):
 
             output_pointer = name_to_variable(
                 model.decoder, "_output_layer.weight")
-            if not isinstance(output_pointer, nn.Identity):
+            if not isinstance(output_pointer, layers.Identity):
                 assert output_pointer.shape == array.shape
                 output_pointer.data = torch.from_numpy(array)
         elif name == "model/wpe":
@@ -195,7 +196,7 @@ def name_to_variable(model: nn.Module, name: str) -> nn.Module:
             num = int(m_name)
             pointer = pointer[num]  # type: ignore
         else:
-            if not isinstance(pointer, nn.Identity):
+            if not isinstance(pointer, layers.Identity):
                 pointer = getattr(pointer, m_name)
     return pointer
 
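These `isinstance` guards matter during checkpoint loading: GPT-2 ties its output logits to the token embedding, so when `name_to_variable` resolves `_output_layer.weight` to an `Identity` module there is no weight to copy from the TF checkpoint. A condensed sketch of the pattern, with `maybe_copy` as a hypothetical name rather than a real Texar function:

    import numpy as np
    import torch

    from texar.core import layers


    def maybe_copy(pointer, array: np.ndarray) -> None:
        # `pointer` is whatever name_to_variable resolved: usually a
        # parameter, but an Identity module when the decoder has no real
        # output layer.
        if isinstance(pointer, layers.Identity):
            return  # nothing to load: output weights are tied to the embedding
        assert pointer.shape == array.shape
        pointer.data = torch.from_numpy(array)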
