Resolve Comments in #120 (#128)
* Resolve comments
gpengzhi committed Jul 27, 2019
1 parent c9b82dc commit 5018cef
Showing 3 changed files with 11 additions and 40 deletions.
4 changes: 4 additions & 0 deletions texar/modules/classifiers/xlnet_classifier.py
@@ -69,6 +69,8 @@ def __init__(self,
cache_dir=cache_dir,
hparams=encoder_hparams)

# TODO: The logic here is very similar to that in XLNetRegressor.
# We need to reduce the code redundancy.
if self._hparams.use_projection:
if self._hparams.clas_strategy == 'all_time':
self.projection = nn.Linear(
@@ -208,6 +210,8 @@ def param_groups(self,
The parameter groups, used as the first argument for optimizers.
"""

# TODO: Same logic in XLNetRegressor. Reduce code redundancy.

if lr_layer_scale != 1.0:
if lr is None:
raise ValueError(
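The TODO comments above note that XLNetClassifier and XLNetRegressor construct their projection layer (and their parameter groups) with near-identical code. A minimal sketch of one way to share that logic, assuming a hypothetical helper whose layer sizes are placeholders (the real dimensions are not visible in this diff):

import torch.nn as nn

def build_projection(strategy: str, hidden_dim: int,
                     max_seq_length: int) -> nn.Module:
    # Hypothetical shared helper for XLNetClassifier / XLNetRegressor.
    # 'all_time' strategies project the concatenation of all time steps;
    # other strategies project a single hidden vector.
    if strategy == 'all_time':
        return nn.Linear(hidden_dim * max_seq_length,
                         hidden_dim * max_seq_length)
    return nn.Linear(hidden_dim, hidden_dim)

Each __init__ would then reduce to a single call guarded by self._hparams.use_projection, passing clas_strategy or regr_strategy respectively.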
43 changes: 3 additions & 40 deletions texar/modules/encoders/xlnet_encoder.py
@@ -424,46 +424,9 @@ def _forward(self,
-> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
r"""Compute XLNet representations for the input. This layer exists
because :class:`XLNetDecoder` compute embeddings in the decoder helper.
Args:
word_embed: Shape `[batch_size, seq_len, word_embed_dim]`.
segment_ids: Shape `[batch_size, seq_len]`.
input_mask: Float tensor of shape `[batch_size, seq_len]`. Note that
positions with value 1 are masked out.
memory: Memory from previous batches. A list of length `num_layers`,
each tensor of shape `[batch_size, mem_len, hidden_dim]`.
permute_mask: The permutation mask. Float tensor of shape
`[batch_size, seq_len, seq_len]`.
A value of 0 for ``permute_mask[i, j, k]`` indicates that
position `i` attends to position `j` in batch `k`.
target_mapping: The target token mapping. Float tensor of shape
`[batch_size, num_targets, seq_len]`.
A value of 1 for ``target_mapping[i, j, k]`` indicates that
the `i`-th target token (in order of permutation) in batch `k`
is the token at position `j`.
Each row ``target_mapping[i, :, k]`` can have no more than one
value of 1.
bi_data (bool): Whether to use bidirectional data input pipeline.
clamp_len (int): Clamp all relative distances larger than
:attr:`clamp_len`. A value of -1 means no clamping.
cache_len (int): Length of memory (number of tokens) to cache.
same_length (bool): Whether to use the same attention length for
each token.
attn_type (str): Attention type. Supported values are `"uni"`
and `"bi"`.
two_stream (bool): Whether to use two-stream attention. Only set to
`True` when pre-training or generating text. Defaults to
`False`.
:returns: A tuple of `(output, new_memory)`:
- **`output`**: The final layer output representations. Shape
`[batch_size, seq_len, hidden_dim]`.
- **`new_memory`**: The memory of the current batch.
If `cache_len` is 0, then `new_memory` is `None`. Otherwise, it is
a list of length `num_layers`, each tensor of shape
`[batch_size, cache_len, hidden_dim]`.
This can be used as the :attr:`memory` argument in the next batch.
`word_embed` has shape `[batch_size, seq_len, word_embed_dim]`.
Please refer to :meth:`forward` for the detailed information of other
arguments.
"""
# word_embed: [seq_len, batch_size, word_embed_dim]
word_embed = word_embed.permute(1, 0, 2)
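The docstring trimmed above still defines the contract of _forward: it consumes memory from the previous batch and returns (output, new_memory), where new_memory is None when cache_len is 0. A hedged usage sketch of that recurrence, assuming the public forward accepts the same memory and cache_len arguments (with token ids in place of word_embed):

# Sketch only: argument names follow the docstring of `_forward`; the
# exact `forward` signature may differ.
memory = None
for token_ids in segment_batches:        # each: [batch_size, seq_len]
    output, memory = encoder(token_ids, memory=memory, cache_len=128)
    # output: [batch_size, seq_len, hidden_dim]
    # memory: list of `num_layers` tensors, each
    #         [batch_size, cache_len, hidden_dim], reused by the next call.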
4 changes: 4 additions & 0 deletions texar/modules/regressors/xlnet_regressor.py
@@ -69,6 +69,8 @@ def __init__(self,
cache_dir=cache_dir,
hparams=encoder_hparams)

# TODO: The logic here is very similar to that in XLNetClassifier.
# We need to reduce the code redundancy.
if self._hparams.use_projection:
if self._hparams.regr_strategy == 'all_time':
self.projection = nn.Linear(
@@ -188,6 +190,8 @@ def param_groups(self,
The parameter groups, used as the first argument for optimizers.
"""

# TODO: Same logic in XLNetClassifier. Reduce code redundancy.

if lr_layer_scale != 1.0:
if lr is None:
raise ValueError(
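Both param_groups hunks enforce the same rule: layer-wise learning-rate scaling (lr_layer_scale != 1.0) requires an explicit base lr, otherwise the ValueError shown above is raised. A hedged usage sketch, assuming only the keyword names visible in the diff (the construction itself is hypothetical):

import torch
from texar.modules.regressors.xlnet_regressor import XLNetRegressor

regressor = XLNetRegressor()  # hypothetical default construction
# Decay the learning rate by 0.75 per layer, starting from a base lr of 2e-5.
groups = regressor.param_groups(lr=2e-5, lr_layer_scale=0.75)
optimizer = torch.optim.Adam(groups)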
