41 changes: 33 additions & 8 deletions docs/code/modules.rst
@@ -59,16 +59,22 @@ Encoders
.. autoclass:: texar.modules.TransformerEncoder
:members:

-:hidden:`BertEncoder`
+:hidden:`BERTEncoder`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: texar.modules.BertEncoder
+.. autoclass:: texar.modules.BERTEncoder
:members:

:hidden:`GPT2Encoder`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.GPT2Encoder
:members:

+:hidden:`XLNetEncoder`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.modules.XLNetEncoder
+    :members:
+    :exclude-members: _forward

:hidden:`Conv1DEncoder`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.Conv1DEncoder
@@ -129,6 +135,17 @@ Decoders
.. autoclass:: texar.modules.GPT2Decoder
:members:

+:hidden:`XLNetDecoder`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.modules.XLNetDecoder
+    :members:
+    :exclude-members: initialize,step,finalize,_create_input
+
+:hidden:`XLNetDecoderOutput`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.modules.XLNetDecoderOutput
+    :members:

:hidden:`TransformerDecoder`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.TransformerDecoder
@@ -197,9 +214,9 @@ Decoders
Classifiers
============

-:hidden:`BertClassifier`
+:hidden:`BERTClassifier`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: texar.modules.BertClassifier
+.. autoclass:: texar.modules.BERTClassifier
:members:

:hidden:`GPT2Classifier`
@@ -212,6 +229,11 @@ Classifiers
.. autoclass:: texar.modules.Conv1DClassifier
:members:

+:hidden:`XLNetClassifier`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.modules.XLNetClassifier
+    :members:

Networks
========

@@ -236,14 +258,17 @@ Pre-trained
.. spelling::
pooler

-:hidden:`BertBase`
+:hidden:`PretrainedBase`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: texar.modules.BertBase
+.. autoclass:: texar.modules.PretrainedBase
:members:

-:hidden:`GPT2Base`
+Regressor
+==========
+
+:hidden:`XLNetRegressor`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: texar.modules.GPT2Base
+.. autoclass:: texar.modules.XLNetRegressor
:members:

Connectors
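The docs changes above track two things: the Bert* to BERT* renames (BERTEncoder, BERTClassifier, PretrainedBase) and the new XLNet modules with their own Regressor section. A minimal sketch of what user code looks like against the renamed API, going by the autoclass directives above (zero-argument construction with default hparams is an assumption for illustration, not something this diff shows):

    # Hypothetical usage; class names come from the directives above.
    from texar.modules import BERTEncoder, BERTClassifier, XLNetRegressor

    encoder = BERTEncoder()        # was BertEncoder before this change
    classifier = BERTClassifier()  # was BertClassifier
    regressor = XLNetRegressor()   # new, documented under "Regressor"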
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
@@ -58,3 +58,5 @@ fastly
CUDA
precompute
Tokenize
+Regressor
+regressor
2 changes: 1 addition & 1 deletion examples/xlnet/xlnet/model/decoder.py
@@ -180,7 +180,7 @@ def forward(self, # type: ignore
if not recompute_memory and start_tokens.size(0) > 1:
_, memory = self._forward(
memory=memory, cache_len=cache_len,
-**self.create_input(
+**self._create_input(
self._state_previous_inputs, initial=True))
start_tokens = start_tokens[-1]

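The one-line fix above matters because the helper only exists under its private name, so the warm-up path for decoding with multiple start tokens would raise at run time. A self-contained sketch of the failure mode (the class and method here are illustrative stand-ins, not the real XLNet decoder):

    class Decoder:
        def _create_input(self, inputs, initial=False):
            # Stand-in for the real input-packing helper.
            return {"inputs": inputs, "initial": initial}

    d = Decoder()
    d._create_input([1, 2, 3], initial=True)  # the corrected call: works
    try:
        d.create_input([1, 2, 3])  # the old spelling in this diff
    except AttributeError as err:
        print(err)  # 'Decoder' object has no attribute 'create_input'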
1 change: 1 addition & 0 deletions texar/modules/__init__.py
@@ -22,3 +22,4 @@
from texar.modules.encoders import *
from texar.modules.networks import *
from texar.modules.pretrained import *
+from texar.modules.regressors import *
3 changes: 2 additions & 1 deletion texar/modules/classifiers/__init__.py
@@ -15,7 +15,8 @@
Modules of Texar library classifiers.
"""

-from texar.modules.classifiers.bert_classifiers import *
+from texar.modules.classifiers.bert_classifier import *
from texar.modules.classifiers.classifier_base import *
from texar.modules.classifiers.conv_classifiers import *
from texar.modules.classifiers.gpt2_classifier import *
+from texar.modules.classifiers.xlnet_classifier import *
texar/modules/classifiers/{bert_classifiers.py → bert_classifier.py}
@@ -22,22 +22,24 @@

from texar.hyperparams import HParams
from texar.modules.classifiers.classifier_base import ClassifierBase
-from texar.modules.encoders.bert_encoders import BertEncoder
-from texar.utils import utils
+from texar.modules.encoders.bert_encoders import BERTEncoder
+from texar.utils.utils import dict_fetch

__all__ = ["BertClassifier"]
__all__ = [
"BERTClassifier"
]


-class BertClassifier(ClassifierBase):
+class BERTClassifier(ClassifierBase):
r"""Classifier based on BERT modules.

This is a combination of the
-:class:`~texar.modules.BertEncoder` with a classification
+:class:`~texar.modules.BERTEncoder` with a classification
layer. Both step-wise classification and sequence-level classification
are supported, specified in :attr:`hparams`.

Arguments are the same as in
-:class:`~texar.modules.BertEncoder`.
+:class:`~texar.modules.BERTEncoder`.

Args:
pretrained_model_name (optional): a str with the name
@@ -62,15 +64,12 @@ def __init__(self,
cache_dir: Optional[str] = None,
hparams=None):

-super().__init__(hparams)
+super().__init__(hparams=hparams)

# Create the underlying encoder
-encoder_hparams = utils.dict_fetch(hparams,
-                                   BertEncoder.default_hparams())
-if encoder_hparams is not None:
-    encoder_hparams['name'] = None
+encoder_hparams = dict_fetch(hparams, BERTEncoder.default_hparams())

-self._encoder = BertEncoder(pretrained_model_name=pretrained_model_name,
+self._encoder = BERTEncoder(pretrained_model_name=pretrained_model_name,
cache_dir=cache_dir,
hparams=encoder_hparams)

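Two things change in the constructor: `dict_fetch` is imported directly instead of reached through `texar.utils.utils`, and the name-clearing branch is dropped. The `dict_fetch` call keeps only the entries of the classifier's `hparams` that `BERTEncoder.default_hparams()` declares, which is how encoder options pass through the classifier's hparams. A standalone sketch of that pattern (a simplified stand-in for texar's helper, not its actual implementation):

    def dict_fetch(src, tgt):
        # Keep only the entries of `src` whose keys also appear in `tgt`.
        if src is None:
            return None
        return {k: src[k] for k in tgt if k in src}

    clf_hparams = {"num_classes": 2, "hidden_size": 768, "dropout": 0.1}
    encoder_defaults = {"hidden_size": 768, "dropout": 0.1}
    print(dict_fetch(clf_hparams, encoder_defaults))
    # {'hidden_size': 768, 'dropout': 0.1}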
@@ -93,13 +92,14 @@ def __init__(self,

if self._hparams.clas_strategy == 'all_time':
self._logits_layer = nn.Linear(
-self._hparams.hidden_size * self._hparams.max_seq_length,
+self._encoder.output_size *
+self._hparams.max_seq_length,
self.num_classes,
**logit_kwargs)
else:
-self._logits_layer = nn.Linear(self._hparams.hidden_size,
-                               self.num_classes,
-                               **logit_kwargs)
+self._logits_layer = nn.Linear(
+    self._encoder.output_size, self.num_classes,
+    **logit_kwargs)

self.is_binary = (self.num_classes == 1) or \
(self.num_classes <= 0 and
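Both branches of this hunk now size the logits layer from `self._encoder.output_size` rather than the raw `hidden_size` hparam, so the layer tracks what the encoder actually emits. A sketch of the two sizing strategies (the dimensions are made-up):

    import torch
    import torch.nn as nn

    output_size, max_seq_length, num_classes = 768, 128, 2

    # "all_time": flatten every time step into one feature vector.
    all_time = nn.Linear(output_size * max_seq_length, num_classes)

    # "cls_time" / "time_wise": one feature vector per prediction.
    per_step = nn.Linear(output_size, num_classes)

    enc_outputs = torch.randn(4, max_seq_length, output_size)
    print(all_time(enc_outputs.view(4, -1)).shape)  # torch.Size([4, 2])
    print(per_step(enc_outputs).shape)              # torch.Size([4, 128, 2])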
@@ -126,13 +126,13 @@ def default_hparams():
Here:

1. Same hyperparameters as in
-:class:`~texar.modules.BertEncoder`.
-See the :meth:`~texar.modules.BertEncoder.default_hparams`.
-An instance of BertEncoder is created for feature extraction.
+:class:`~texar.modules.BERTEncoder`.
+See the :meth:`~texar.modules.BERTEncoder.default_hparams`.
+An instance of BERTEncoder is created for feature extraction.

2. Additional hyperparameters:

-`num_classes`: int
+`"num_classes"`: int
Number of classes:

- If **> 0**, an additional `Linear`
@@ -142,12 +142,12 @@ def default_hparams():
classes is assumed to be the final dense layer size of the
encoder.

-`logit_layer_kwargs`: dict
+`"logit_layer_kwargs"`: dict
Keyword arguments for the logit Dense layer constructor,
except for argument "units" which is set to `num_classes`.
Ignored if no extra logit layer is appended.

-`clas_strategy`: str
+`"clas_strategy"`: str
The classification strategy, one of:

- **cls_time**: Sequence-level classification based on the
@@ -158,18 +158,18 @@ def default_hparams():
- **time_wise**: Step-wise classification, i.e., make
classification for each time step based on its output.

-`max_seq_length`: int, optional
+`"max_seq_length"`: int, optional
Maximum possible length of input sequences. Required if
`clas_strategy` is `all_time`.

-`dropout`: float
+`"dropout"`: float
The dropout rate of the BERT encoder output.

-`name`: str
+`"name"`: str
Name of the classifier.
"""

-hparams = BertEncoder.default_hparams()
+hparams = BERTEncoder.default_hparams()
hparams.update({
"num_classes": 2,
"logit_layer_kwargs": None,
@@ -187,7 +187,7 @@ def forward(self, # type: ignore
-> Tuple[torch.Tensor, torch.LongTensor]:
r"""Feeds the inputs through the network and makes classification.

-The arguments are the same as in :class:`~texar.modules.BertEncoder`.
+The arguments are the same as in :class:`~texar.modules.BERTEncoder`.

Args:
inputs: A 2D Tensor of shape `[batch_size, max_time]`,
@@ -233,7 +233,7 @@ def forward(self, # type: ignore
# Pad `enc_outputs` to have max_seq_length before flatten
length_diff = self._hparams.max_seq_length - inputs.shape[1]
logit_input = F.pad(enc_outputs, [0, 0, 0, length_diff, 0, 0])
-logit_input_dim = (self._hparams.hidden_size *
+logit_input_dim = (self._encoder.output_size *
self._hparams.max_seq_length)
logits = logit_input.view(-1, logit_input_dim)
else:
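In the `all_time` branch of `forward`, encoder outputs are right-padded along the time axis to `max_seq_length` and flattened before the logits layer, which the last hunk re-sizes against `self._encoder.output_size`. A standalone sketch of that pad-and-flatten step (the shapes are made-up):

    import torch
    import torch.nn.functional as F

    batch, seq_len, output_size, max_seq_length = 4, 100, 768, 128

    enc_outputs = torch.randn(batch, seq_len, output_size)
    length_diff = max_seq_length - seq_len
    # F.pad reads the pad list from the last dimension backwards:
    # (feat_left, feat_right, time_left, time_right, batch_left, batch_right)
    logit_input = F.pad(enc_outputs, [0, 0, 0, length_diff, 0, 0])
    print(logit_input.shape)  # torch.Size([4, 128, 768])
    flat = logit_input.view(-1, output_size * max_seq_length)
    print(flat.shape)         # torch.Size([4, 98304])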