diff --git a/docs/code/modules.rst b/docs/code/modules.rst index 8005fd088..766c9855a 100644 --- a/docs/code/modules.rst +++ b/docs/code/modules.rst @@ -59,9 +59,9 @@ Encoders .. autoclass:: texar.modules.TransformerEncoder :members: -:hidden:`BertEncoder` +:hidden:`BERTEncoder` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: texar.modules.BertEncoder +.. autoclass:: texar.modules.BERTEncoder :members: :hidden:`GPT2Encoder` @@ -69,6 +69,12 @@ Encoders .. autoclass:: texar.modules.GPT2Encoder :members: +:hidden:`XLNetEncoder` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: texar.modules.XLNetEncoder + :members: + :exclude-members: _forward + :hidden:`Conv1DEncoder` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: texar.modules.Conv1DEncoder @@ -129,6 +135,17 @@ Decoders .. autoclass:: texar.modules.GPT2Decoder :members: +:hidden:`XLNetDecoder` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: texar.modules.XLNetDecoder + :members: + :exclude-members: initialize,step,finalize,_create_input + +:hidden:`XLNetDecoderOutput` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: texar.modules.XLNetDecoderOutput + :members: + :hidden:`TransformerDecoder` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: texar.modules.TransformerDecoder @@ -197,9 +214,9 @@ Decoders Classifiers ============ -:hidden:`BertClassifier` +:hidden:`BERTClassifier` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: texar.modules.BertClassifier +.. autoclass:: texar.modules.BERTClassifier :members: :hidden:`GPT2Classifier` @@ -212,6 +229,11 @@ Classifiers .. autoclass:: texar.modules.Conv1DClassifier :members: +:hidden:`XLNetClassifier` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: texar.modules.XLNetClassifier + :members: + Networks ======== @@ -236,14 +258,17 @@ Pre-trained .. spelling:: pooler -:hidden:`BertBase` +:hidden:`PretrainedBase` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: texar.modules.BertBase +.. autoclass:: texar.modules.PretrainedBase :members: -:hidden:`GPT2Base` +Regressor +========== + +:hidden:`XLNetRegressor` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: texar.modules.GPT2Base +.. 
autoclass:: texar.modules.XLNetRegressor :members: Connectors diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 8f6173ca3..470788086 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -58,3 +58,5 @@ fastly CUDA precompute Tokenize +Regressor +regressor diff --git a/examples/xlnet/xlnet/model/decoder.py b/examples/xlnet/xlnet/model/decoder.py index 691e73026..4f8d57f63 100644 --- a/examples/xlnet/xlnet/model/decoder.py +++ b/examples/xlnet/xlnet/model/decoder.py @@ -180,7 +180,7 @@ def forward(self, # type: ignore if not recompute_memory and start_tokens.size(0) > 1: _, memory = self._forward( memory=memory, cache_len=cache_len, - **self.create_input( + **self._create_input( self._state_previous_inputs, initial=True)) start_tokens = start_tokens[-1] diff --git a/texar/modules/__init__.py b/texar/modules/__init__.py index 19683e02f..1f6b33103 100644 --- a/texar/modules/__init__.py +++ b/texar/modules/__init__.py @@ -22,3 +22,4 @@ from texar.modules.encoders import * from texar.modules.networks import * from texar.modules.pretrained import * +from texar.modules.regressors import * diff --git a/texar/modules/classifiers/__init__.py b/texar/modules/classifiers/__init__.py index d573aa088..b2e2f90f0 100644 --- a/texar/modules/classifiers/__init__.py +++ b/texar/modules/classifiers/__init__.py @@ -15,7 +15,8 @@ Modules of Texar library classifiers. """ -from texar.modules.classifiers.bert_classifiers import * +from texar.modules.classifiers.bert_classifier import * from texar.modules.classifiers.classifier_base import * from texar.modules.classifiers.conv_classifiers import * from texar.modules.classifiers.gpt2_classifier import * +from texar.modules.classifiers.xlnet_classifier import * diff --git a/texar/modules/classifiers/bert_classifiers.py b/texar/modules/classifiers/bert_classifier.py similarity index 88% rename from texar/modules/classifiers/bert_classifiers.py rename to texar/modules/classifiers/bert_classifier.py index 0fee957f0..8922ca514 100644 --- a/texar/modules/classifiers/bert_classifiers.py +++ b/texar/modules/classifiers/bert_classifier.py @@ -22,22 +22,24 @@ from texar.hyperparams import HParams from texar.modules.classifiers.classifier_base import ClassifierBase -from texar.modules.encoders.bert_encoders import BertEncoder -from texar.utils import utils +from texar.modules.encoders.bert_encoders import BERTEncoder +from texar.utils.utils import dict_fetch -__all__ = ["BertClassifier"] +__all__ = [ + "BERTClassifier" +] -class BertClassifier(ClassifierBase): +class BERTClassifier(ClassifierBase): r"""Classifier based on BERT modules. This is a combination of the - :class:`~texar.modules.BertEncoder` with a classification + :class:`~texar.modules.BERTEncoder` with a classification layer. Both step-wise classification and sequence-level classification are supported, specified in :attr:`hparams`. Arguments are the same as in - :class:`~texar.modules.BertEncoder`. + :class:`~texar.modules.BERTEncoder`. 
Args: pretrained_model_name (optional): a str with the name @@ -62,15 +64,12 @@ def __init__(self, cache_dir: Optional[str] = None, hparams=None): - super().__init__(hparams) + super().__init__(hparams=hparams) # Create the underlying encoder - encoder_hparams = utils.dict_fetch(hparams, - BertEncoder.default_hparams()) - if encoder_hparams is not None: - encoder_hparams['name'] = None + encoder_hparams = dict_fetch(hparams, BERTEncoder.default_hparams()) - self._encoder = BertEncoder(pretrained_model_name=pretrained_model_name, + self._encoder = BERTEncoder(pretrained_model_name=pretrained_model_name, cache_dir=cache_dir, hparams=encoder_hparams) @@ -93,13 +92,14 @@ def __init__(self, if self._hparams.clas_strategy == 'all_time': self._logits_layer = nn.Linear( - self._hparams.hidden_size * self._hparams.max_seq_length, + self._encoder.output_size * + self._hparams.max_seq_length, self.num_classes, **logit_kwargs) else: - self._logits_layer = nn.Linear(self._hparams.hidden_size, - self.num_classes, - **logit_kwargs) + self._logits_layer = nn.Linear( + self._encoder.output_size, self.num_classes, + **logit_kwargs) self.is_binary = (self.num_classes == 1) or \ (self.num_classes <= 0 and @@ -126,13 +126,13 @@ def default_hparams(): Here: 1. Same hyperparameters as in - :class:`~texar.modules.BertEncoder`. - See the :meth:`~texar.modules.BertEncoder.default_hparams`. - An instance of BertEncoder is created for feature extraction. + :class:`~texar.modules.BERTEncoder`. + See the :meth:`~texar.modules.BERTEncoder.default_hparams`. + An instance of BERTEncoder is created for feature extraction. 2. Additional hyperparameters: - `num_classes`: int + `"num_classes"`: int Number of classes: - If **> 0**, an additional `Linear` @@ -142,12 +142,12 @@ def default_hparams(): classes is assumed to be the final dense layer size of the encoder. - `logit_layer_kwargs`: dict + `"logit_layer_kwargs"`: dict Keyword arguments for the logit Dense layer constructor, except for argument "units" which is set to `num_classes`. Ignored if no extra logit layer is appended. - `clas_strategy`: str + `"clas_strategy"`: str The classification strategy, one of: - **cls_time**: Sequence-level classification based on the @@ -158,18 +158,18 @@ def default_hparams(): - **time_wise**: Step-wise classification, i.e., make classification for each time step based on its output. - `max_seq_length`: int, optional + `"max_seq_length"`: int, optional Maximum possible length of input sequences. Required if `clas_strategy` is `all_time`. - `dropout`: float + `"dropout"`: float The dropout rate of the BERT encoder output. - `name`: str + `"name"`: str Name of the classifier. """ - hparams = BertEncoder.default_hparams() + hparams = BERTEncoder.default_hparams() hparams.update({ "num_classes": 2, "logit_layer_kwargs": None, @@ -187,7 +187,7 @@ def forward(self, # type: ignore -> Tuple[torch.Tensor, torch.LongTensor]: r"""Feeds the inputs through the network and makes classification. - The arguments are the same as in :class:`~texar.modules.BertEncoder`. + The arguments are the same as in :class:`~texar.modules.BERTEncoder`. 
Args: inputs: A 2D Tensor of shape `[batch_size, max_time]`, @@ -233,7 +233,7 @@ def forward(self, # type: ignore # Pad `enc_outputs` to have max_seq_length before flatten length_diff = self._hparams.max_seq_length - inputs.shape[1] logit_input = F.pad(enc_outputs, [0, 0, 0, length_diff, 0, 0]) - logit_input_dim = (self._hparams.hidden_size * + logit_input_dim = (self._encoder.output_size * self._hparams.max_seq_length) logits = logit_input.view(-1, logit_input_dim) else: diff --git a/texar/modules/classifiers/bert_classifiers_test.py b/texar/modules/classifiers/bert_classifier_test.py similarity index 60% rename from texar/modules/classifiers/bert_classifiers_test.py rename to texar/modules/classifiers/bert_classifier_test.py index d12f4d6c0..c41cc256b 100644 --- a/texar/modules/classifiers/bert_classifiers_test.py +++ b/texar/modules/classifiers/bert_classifier_test.py @@ -6,14 +6,48 @@ import torch -from texar.modules.classifiers.bert_classifiers import * +from texar.modules.classifiers.bert_classifier import * -@unittest.skip("Manual test only") -class BertClassifierTest(unittest.TestCase): - r"""Tests :class:`~texar.modules.BertClassifier` class. +class BERTClassifierTest(unittest.TestCase): + r"""Tests :class:`~texar.modules.BERTClassifier` class. """ + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + classifier = BERTClassifier(pretrained_model_name="bert-base-uncased") + _, _ = classifier(inputs) + + # case 2 + classifier = BERTClassifier(pretrained_model_name="bert-large-uncased") + _, _ = classifier(inputs) + + # case 3 + classifier = BERTClassifier(pretrained_model_name="bert-base-cased") + _, _ = classifier(inputs) + + # case 4 + classifier = BERTClassifier(pretrained_model_name="bert-large-cased") + _, _ = classifier(inputs) + + # case 5 + classifier = BERTClassifier( + pretrained_model_name="bert-base-multilingual-uncased") + _, _ = classifier(inputs) + + # case 6 + classifier = BERTClassifier( + pretrained_model_name="bert-base-multilingual-cased") + _, _ = classifier(inputs) + + # case 7 + classifier = BERTClassifier(pretrained_model_name="bert-base-chinese") + _, _ = classifier(inputs) + def test_trainable_variables(self): r"""Tests the functionality of automatically collecting trainable variables. @@ -21,36 +55,44 @@ def test_trainable_variables(self): inputs = torch.zeros(32, 16, dtype=torch.int64) # case 1 - classifier = BertClassifier() + hparams = { + "pretrained_model_name": None, + } + classifier = BERTClassifier(hparams=hparams) _, _ = classifier(inputs) self.assertEqual(len(classifier.trainable_variables), 199 + 2) # case 2 hparams = { + "pretrained_model_name": None, "clas_strategy": "all_time", "max_seq_length": 8, } - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) _, _ = classifier(inputs) self.assertEqual(len(classifier.trainable_variables), 199 + 2) # case 3 hparams = { + "pretrained_model_name": None, "clas_strategy": "time_wise", } - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) _, _ = classifier(inputs) self.assertEqual(len(classifier.trainable_variables), 199 + 2) - def test_encode(self): - r"""Tests encoding. + def test_classification(self): + r"""Tests classification. 
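# A minimal usage sketch for the renamed BERTClassifier, mirroring the test cases above.
# Illustration only, not part of the patch; the hparams values and tensor sizes are assumptions.
import torch
from texar.modules import BERTClassifier

clf_hparams = {
    "pretrained_model_name": None,   # randomly initialized BERT, as in the tests
    "num_classes": 10,
    "clas_strategy": "all_time",
    "max_seq_length": 8,
}
classifier = BERTClassifier(hparams=clf_hparams)
inputs = torch.randint(30521, (16, 8), dtype=torch.int64)
logits, preds = classifier(inputs)   # shapes: [16, 10] and [16]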
""" max_time = 8 batch_size = 16 inputs = torch.randint(30521, (batch_size, max_time), dtype=torch.int64) # case 1 - classifier = BertClassifier() + hparams = { + "pretrained_model_name": None, + } + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size( @@ -59,10 +101,11 @@ def test_encode(self): # case 2 hparams = { + "pretrained_model_name": None, "num_classes": 10, "clas_strategy": "time_wise" } - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size( @@ -71,10 +114,11 @@ def test_encode(self): # case 3 hparams = { + "pretrained_model_name": None, "num_classes": 0, "clas_strategy": "time_wise" } - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size( @@ -83,19 +127,20 @@ def test_encode(self): # case 4 hparams = { + "pretrained_model_name": None, "num_classes": 10, "clas_strategy": "all_time", "max_seq_length": max_time } inputs = torch.randint(30521, (batch_size, 6), dtype=torch.int64) - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size( [batch_size, classifier.hparams.num_classes])) self.assertEqual(preds.shape, torch.Size([batch_size])) - def test_binary(self): + def _test_binary(self): r"""Tests binary classification. """ max_time = 8 @@ -104,10 +149,11 @@ def test_binary(self): # case 1 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "time_wise" } - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size([batch_size, max_time])) @@ -115,12 +161,13 @@ def test_binary(self): # case 2 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "cls_time", "max_seq_length": max_time } inputs = torch.randint(30521, (batch_size, 6), dtype=torch.int64) - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size([batch_size])) @@ -128,12 +175,13 @@ def test_binary(self): # case 3 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "all_time", "max_seq_length": max_time } inputs = torch.randint(30521, (batch_size, 6), dtype=torch.int64) - classifier = BertClassifier(hparams=hparams) + classifier = BERTClassifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size([batch_size])) diff --git a/texar/modules/classifiers/classifier_base.py b/texar/modules/classifiers/classifier_base.py index 7606c4388..7986e70c1 100644 --- a/texar/modules/classifiers/classifier_base.py +++ b/texar/modules/classifiers/classifier_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Base class for encoders. +Base class for classifiers. """ from abc import ABC from typing import Any, Dict diff --git a/texar/modules/classifiers/gpt2_classifier.py b/texar/modules/classifiers/gpt2_classifier.py index 0f774f9e3..a39cc0612 100644 --- a/texar/modules/classifiers/gpt2_classifier.py +++ b/texar/modules/classifiers/gpt2_classifier.py @@ -14,7 +14,6 @@ """ GPT2 classifiers. 
""" - from typing import Optional, Tuple import torch @@ -24,9 +23,11 @@ from texar.hyperparams import HParams from texar.modules.classifiers.classifier_base import ClassifierBase from texar.modules.encoders.gpt2_encoder import GPT2Encoder -from texar.utils import utils +from texar.utils.utils import dict_fetch -__all__ = ["GPT2Classifier"] +__all__ = [ + "GPT2Classifier" +] class GPT2Classifier(ClassifierBase): @@ -61,14 +62,10 @@ def __init__(self, cache_dir: Optional[str] = None, hparams=None): - super().__init__(hparams) + super().__init__(hparams=hparams) # Create the underlying encoder - encoder_hparams = utils.dict_fetch(hparams, - GPT2Encoder.default_hparams()) - - if encoder_hparams is not None: - encoder_hparams['name'] = None + encoder_hparams = dict_fetch(hparams, GPT2Encoder.default_hparams()) self._encoder = GPT2Encoder(pretrained_model_name=pretrained_model_name, cache_dir=cache_dir, @@ -93,13 +90,14 @@ def __init__(self, if self._hparams.clas_strategy == 'all_time': self._logits_layer = nn.Linear( - self._hparams.decoder.dim * self._hparams.max_seq_length, + self._encoder.output_size * + self._hparams.max_seq_length, self.num_classes, **logit_kwargs) else: - self._logits_layer = nn.Linear(self._hparams.decoder.dim, - self.num_classes, - **logit_kwargs) + self._logits_layer = nn.Linear( + self._encoder.output_size, self.num_classes, + **logit_kwargs) self.is_binary = (self.num_classes == 1) or \ (self.num_classes <= 0 and @@ -132,7 +130,7 @@ def default_hparams(): 2. Additional hyperparameters: - `num_classes`: int + `"num_classes"`: int Number of classes: - If **> 0**, an additional `Linear` @@ -142,12 +140,12 @@ def default_hparams(): classes is assumed to be the final dense layer size of the encoder. - `logit_layer_kwargs`: dict + `"logit_layer_kwargs"`: dict Keyword arguments for the logit Dense layer constructor, except for argument "units" which is set to `num_classes`. Ignored if no extra logit layer is appended. - `clas_strategy`: str + `"clas_strategy"`: str The classification strategy, one of: - **cls_time**: Sequence-level classification based on the @@ -157,14 +155,14 @@ def default_hparams(): - **time_wise**: Step-wise classification, i.e., make classification for each time step based on its output. - `max_seq_length`: int, optional + `"max_seq_length"`: int, optional Maximum possible length of input sequences. Required if `clas_strategy` is `all_time`. - `dropout`: float + `"dropout"`: float The dropout rate of the GPT2 encoder output. - `name`: str + `"name"`: str Name of the classifier. """ @@ -231,7 +229,7 @@ def forward(self, # type: ignore # Pad `enc_outputs` to have max_seq_length before flatten length_diff = self._hparams.max_seq_length - inputs.shape[1] logit_input = F.pad(enc_outputs, [0, 0, 0, length_diff, 0, 0]) - logit_input_dim = (self._hparams.decoder.dim * + logit_input_dim = (self._encoder.output_size * self._hparams.max_seq_length) logits = logit_input.view(-1, logit_input_dim) else: diff --git a/texar/modules/classifiers/gpt2_classifier_test.py b/texar/modules/classifiers/gpt2_classifier_test.py index c466f7f20..ddbcb81b5 100644 --- a/texar/modules/classifiers/gpt2_classifier_test.py +++ b/texar/modules/classifiers/gpt2_classifier_test.py @@ -9,11 +9,23 @@ from texar.modules.classifiers.gpt2_classifier import * -@unittest.skip("Manual test only") class GPT2ClassifierTest(unittest.TestCase): r"""Tests :class:`~texar.modules.GPT2Classifier` class. 
""" + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + classifier = GPT2Classifier(pretrained_model_name="117M") + _, _ = classifier(inputs) + + # case 2 + classifier = GPT2Classifier(pretrained_model_name="345M") + _, _ = classifier(inputs) + def test_trainable_variables(self): r"""Tests the functionality of automatically collecting trainable variables. @@ -21,12 +33,16 @@ def test_trainable_variables(self): inputs = torch.zeros(32, 16, dtype=torch.int64) # case 1 - classifier = GPT2Classifier() + hparams = { + "pretrained_model_name": None, + } + classifier = GPT2Classifier(hparams=hparams) _, _ = classifier(inputs) self.assertEqual(len(classifier.trainable_variables), 318) # case 2 hparams = { + "pretrained_model_name": None, "clas_strategy": "all_time", "max_seq_length": 8, } @@ -36,21 +52,25 @@ def test_trainable_variables(self): # case 3 hparams = { + "pretrained_model_name": None, "clas_strategy": "time_wise", } classifier = GPT2Classifier(hparams=hparams) _, _ = classifier(inputs) self.assertEqual(len(classifier.trainable_variables), 318) - def test_encode(self): - r"""Tests encoding. + def test_classification(self): + r"""Tests classificaiton. """ max_time = 8 batch_size = 16 inputs = torch.randint(30521, (batch_size, max_time), dtype=torch.int64) # case 1 - classifier = GPT2Classifier() + hparams = { + "pretrained_model_name": None, + } + classifier = GPT2Classifier(hparams=hparams) logits, preds = classifier(inputs) self.assertEqual(logits.shape, torch.Size( @@ -59,6 +79,7 @@ def test_encode(self): # case 2 hparams = { + "pretrained_model_name": None, "num_classes": 10, "clas_strategy": "time_wise" } @@ -71,6 +92,7 @@ def test_encode(self): # case 3 hparams = { + "pretrained_model_name": None, "num_classes": 0, "clas_strategy": "time_wise" } @@ -83,6 +105,7 @@ def test_encode(self): # case 4 hparams = { + "pretrained_model_name": None, "num_classes": 10, "clas_strategy": "all_time", "max_seq_length": max_time @@ -104,6 +127,7 @@ def test_binary(self): # case 1 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "time_wise" } @@ -115,6 +139,7 @@ def test_binary(self): # case 2 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "cls_time", "max_seq_length": max_time @@ -128,6 +153,7 @@ def test_binary(self): # case 3 hparams = { + "pretrained_model_name": None, "num_classes": 1, "clas_strategy": "all_time", "max_seq_length": max_time diff --git a/texar/modules/classifiers/xlnet_classifier.py b/texar/modules/classifiers/xlnet_classifier.py new file mode 100644 index 000000000..08c0ff303 --- /dev/null +++ b/texar/modules/classifiers/xlnet_classifier.py @@ -0,0 +1,288 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +XLNet Classifier. 
+""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from texar.hyperparams import HParams +from texar.modules.classifiers.classifier_base import ClassifierBase +from texar.modules.encoders.xlnet_encoder import XLNetEncoder +from texar.modules.pretrained.xlnet_utils import params_except_in +from texar.utils.utils import dict_fetch + + +__all__ = [ + "XLNetClassifier", +] + + +class XLNetClassifier(ClassifierBase): + r"""Classifier based on XLNet modules. + + Arguments are the same as in + :class:`~texar.modules.XLNetEncoder`. + + Args: + pretrained_model_name (optional): a str with the name + of a pre-trained model to load selected in the list of: + `xlnet-base-cased`, `xlnet-large-cased`. + If `None`, will use the model name in :attr:`hparams`. + cache_dir (optional): the path to a folder in which the + pre-trained models will be cached. If `None` (default), + a default directory will be used. + hparams (dict or HParams, optional): Hyperparameters. Missing + hyperparameters will be set to default values. See + :meth:`default_hparams` for the hyperparameter structure + and default values. + """ + + def __init__(self, + pretrained_model_name: Optional[str] = None, + cache_dir: Optional[str] = None, + hparams=None): + + super().__init__(hparams=hparams) + + # Create the underlying encoder + encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams()) + + self._encoder = XLNetEncoder( + pretrained_model_name=pretrained_model_name, + cache_dir=cache_dir, + hparams=encoder_hparams) + + if self._hparams.use_projection: + if self._hparams.clas_strategy == 'all_time': + self.projection = nn.Linear( + self._encoder.output_size * self._hparams.max_seq_length, + self._encoder.output_size * self._hparams.max_seq_length) + else: + self.projection = nn.Linear(self._encoder.output_size, + self._encoder.output_size) + self.dropout = nn.Dropout(self._hparams.dropout) + + # Create an additional classification layer if needed + self.num_classes = self._hparams.num_classes + if self.num_classes <= 0: + self.hidden_to_logits = None + else: + logit_kwargs = self._hparams.logit_layer_kwargs + if logit_kwargs is None: + logit_kwargs = {} + elif not isinstance(logit_kwargs, HParams): + raise ValueError("hparams['logit_layer_kwargs'] " + "must be a dict.") + else: + logit_kwargs = logit_kwargs.todict() + + if self._hparams.clas_strategy == 'all_time': + self.hidden_to_logits = nn.Linear( + self._encoder.output_size * self._hparams.max_seq_length, + self.num_classes, + **logit_kwargs) + else: + self.hidden_to_logits = nn.Linear( + self._encoder.output_size, self.num_classes, + **logit_kwargs) + + self.is_binary = (self.num_classes == 1) or \ + (self.num_classes <= 0 and + self._hparams.hidden_dim == 1) + + @staticmethod + def default_hparams() -> Dict[str, Any]: + r"""Returns a dictionary of hyperparameters with default values. + + .. code-block:: python + + { + # (1) Same hyperparameters as in XLNetEncoder + ... + # (2) Additional hyperparameters + "clas_strategy": "cls_time", + "use_projection": True, + "num_classes": 2, + "name": "xlnet_classifier", + } + + Here: + + 1. Same hyperparameters as in + :class:`~texar.modules.XLNetEncoder`. + See the :meth:`~texar.modules.XLNetEncoder.default_hparams`. + An instance of XLNetEncoder is created for feature extraction. + + 2. 
Additional hyperparameters: + + `"clas_strategy"`: str + The classification strategy, one of: + + - **cls_time**: Sequence-level classification based on the + output of the last time step (which is the `CLS` token). + Each sequence has a class. + - **all_time**: Sequence-level classification based on + the output of all time steps. Each sequence has a class. + - **time_wise**: Step-wise classification, i.e., make + classification for each time step based on its output. + + `"use_projection"`: bool + If `True`, an additional `Linear` layer is added after the + summary step. + + `"num_classes"`: int + Number of classes: + + - If **> 0**, an additional :torch_nn:`Linear` + layer is appended to the encoder to compute the logits over + classes. + - If **<= 0**, no dense layer is appended. The number of + classes is assumed to be the final dense layer size of the + encoder. + + `"name"`: str + Name of the classifier. + """ + + hparams = XLNetEncoder.default_hparams() + hparams.update({ + "clas_strategy": "cls_time", + "use_projection": True, + "num_classes": 2, + "logit_layer_kwargs": None, + "name": "xlnet_classifier", + }) + return hparams + + def param_groups(self, + lr: Optional[float] = None, + lr_layer_scale: float = 1.0, + decay_base_params: bool = False): + r"""Create parameter groups for optimizers. When + :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form + separate groups with different base learning rates. + + Args: + lr (float): The learning rate. Can be omitted if + :attr:`lr_layer_decay_rate` is 1.0. + lr_layer_scale (float): Per-layer LR scaling rate. The `i`-th layer + will be scaled by `lr_layer_scale ^ (num_layers - i - 1)`. + decay_base_params (bool): If `True`, treat non-layer parameters + (e.g. embeddings) as if they're in layer 0. If `False`, these + parameters are not scaled. + + Returns: + The parameter groups, used as the first argument for optimizers. + """ + + if lr_layer_scale != 1.0: + if lr is None: + raise ValueError( + "lr must be specified when lr_layer_decay_rate is not 1.0") + + fine_tune_group = { + "params": params_except_in(self, ["_encoder"]), + "lr": lr + } + param_groups = [fine_tune_group] + param_group = self._encoder.param_groups(lr, lr_layer_scale, + decay_base_params) + param_groups.extend(param_group) + else: + param_groups = self.parameters() + return param_groups + + def forward(self, # type: ignore + token_ids: torch.LongTensor, + segment_ids: Optional[torch.LongTensor] = None, + input_mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, torch.LongTensor]: + r"""Feeds the inputs through the network and makes classification. + + Args: + token_ids: Shape `[batch_size, max_time]`. + segment_ids: Shape `[batch_size, max_time]`. + input_mask: Float tensor of shape `[batch_size, max_time]`. Note + that positions with value 1 are masked out. + + Returns: + A tuple `(logits, preds)`, containing the logits over classes and + the predictions, respectively. + + - If ``clas_strategy`` is ``cls_time`` or ``all_time``: + + - If ``num_classes`` == 1, ``logits`` and ``pred`` are both of + shape ``[batch_size]``. + - If ``num_classes`` > 1, ``logits`` is of shape + ``[batch_size, num_classes]`` and ``pred`` is of shape + ``[batch_size]``. + + - If ``clas_strategy`` is ``time_wise``: + + - ``num_classes`` == 1, ``logits`` and ``pred`` are both of + shape ``[batch_size, max_time]``. + - If ``num_classes`` > 1, ``logits`` is of shape + ``[batch_size, max_time, num_classes]`` and ``pred`` is of + shape ``[batch_size, max_time]``. 
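# A minimal usage sketch for the new XLNetClassifier: it follows the same pattern as the
# BERT/GPT2 classifiers, and `param_groups` builds optimizer groups with layer-wise
# learning-rate scaling. Illustration only, not part of the patch; hparams values,
# learning rates, and tensor sizes are assumptions.
import torch
from texar.modules import XLNetClassifier

classifier = XLNetClassifier(hparams={
    "pretrained_model_name": None,   # randomly initialized XLNet, as in the tests
    "num_classes": 2,                # default "cls_time" strategy
})
token_ids = torch.randint(32000, (16, 8), dtype=torch.int64)
logits, preds = classifier(token_ids)   # shapes: [16, 2] and [16]

# With lr_layer_scale < 1, earlier layers receive smaller learning rates.
optimizer = torch.optim.Adam(
    classifier.param_groups(lr=2e-5, lr_layer_scale=0.75))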
+ """ + # output: [batch_size, seq_len, hidden_dim] + output, _ = self._encoder(token_ids=token_ids, + segment_ids=segment_ids, + input_mask=input_mask) + + strategy = self._hparams.clas_strategy + if strategy == 'time_wise': + summary = output + elif strategy == 'cls_time': + summary = output[:, -1] + elif strategy == 'all_time': + length_diff = self._hparams.max_seq_length - token_ids.shape[1] + summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0]) + summary_input_dim = (self._encoder.output_size * + self._hparams.max_seq_length) + + summary = summary_input.contiguous().view(-1, summary_input_dim) + else: + raise ValueError(f"Unknown classification strategy: {strategy}.") + + if self._hparams.use_projection: + summary = torch.tanh(self.projection(summary)) + + if self.hidden_to_logits is not None: + summary = self.dropout(summary) + logits = self.hidden_to_logits(summary) + else: + logits = summary + + # Compute predictions + if strategy == "time_wise": + if self.is_binary: + logits = torch.squeeze(logits, -1) + preds = (logits > 0).long() + else: + preds = torch.argmax(logits, dim=-1) + else: + if self.is_binary: + preds = (logits > 0).long() + logits = torch.flatten(logits) + else: + preds = torch.argmax(logits, dim=-1) + preds = torch.flatten(preds) + + return logits, preds diff --git a/texar/modules/classifiers/xlnet_classifier_test.py b/texar/modules/classifiers/xlnet_classifier_test.py new file mode 100644 index 000000000..693f336b7 --- /dev/null +++ b/texar/modules/classifiers/xlnet_classifier_test.py @@ -0,0 +1,179 @@ +""" +Unit tests for XLNet classifiers. +""" + +import unittest + +import torch + +from texar.modules.classifiers.xlnet_classifier import * + + +class XLNetClassifierTest(unittest.TestCase): + r"""Tests :class:`~texar.modules.XLNetClassifier` class. + """ + + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + classifier = XLNetClassifier(pretrained_model_name="xlnet-base-cased") + _, _ = classifier(inputs) + + # case 2 + classifier = XLNetClassifier(pretrained_model_name="xlnet-large-cased") + _, _ = classifier(inputs) + + def test_trainable_variables(self): + r"""Tests the functionality of automatically collecting trainable + variables. + """ + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + hparams = { + "pretrained_model_name": None, + } + classifier = XLNetClassifier(hparams=hparams) + _, _ = classifier(inputs) + self.assertEqual(len(classifier.trainable_variables), 182 + 4) + + # case 2 + hparams = { + "pretrained_model_name": None, + "use_projection": False + } + classifier = XLNetClassifier(hparams=hparams) + _, _ = classifier(inputs) + self.assertEqual(len(classifier.trainable_variables), 182 + 2) + + # case 3 + hparams = { + "pretrained_model_name": None, + "clas_strategy": "all_time", + "max_seq_length": 8 + } + classifier = XLNetClassifier(hparams=hparams) + _, _ = classifier(inputs) + self.assertEqual(len(classifier.trainable_variables), 182 + 4) + + # case 4 + hparams = { + "pretrained_model_name": None, + "clas_strategy": "time_wise" + } + classifier = XLNetClassifier(hparams=hparams) + _, _ = classifier(inputs) + self.assertEqual(len(classifier.trainable_variables), 182 + 4) + + def test_classification(self): + r"""Tests classification. 
+ """ + max_time = 8 + batch_size = 16 + inputs = torch.randint(32000, (batch_size, max_time), dtype=torch.int64) + + # case 1 + hparams = { + "pretrained_model_name": None, + } + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [batch_size, classifier.hparams.num_classes])) + self.assertEqual(preds.shape, torch.Size([batch_size])) + + # case 2 + hparams = { + "pretrained_model_name": None, + "num_classes": 10, + "clas_strategy": "time_wise" + } + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [batch_size, max_time, classifier.hparams.num_classes])) + self.assertEqual(preds.shape, torch.Size([batch_size, max_time])) + + # case 3 + hparams = { + "pretrained_model_name": None, + "num_classes": 0, + "clas_strategy": "time_wise" + } + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [batch_size, max_time, classifier.hparams.hidden_dim])) + self.assertEqual(preds.shape, torch.Size([batch_size, max_time])) + + # case 4 + hparams = { + "pretrained_model_name": None, + "num_classes": 10, + "clas_strategy": "all_time", + "max_seq_length": max_time + } + inputs = torch.randint(30521, (batch_size, 6), dtype=torch.int64) + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [batch_size, classifier.hparams.num_classes])) + self.assertEqual(preds.shape, torch.Size([batch_size])) + + def test_binary(self): + r"""Tests binary classification. + """ + max_time = 8 + batch_size = 16 + inputs = torch.randint(32000, (batch_size, max_time), dtype=torch.int64) + + # case 1 + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "time_wise" + } + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size([batch_size, max_time])) + self.assertEqual(preds.shape, torch.Size([batch_size, max_time])) + + # case 2 + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "cls_time", + "max_seq_length": max_time + } + inputs = torch.randint(32000, (batch_size, 6), dtype=torch.int64) + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size([batch_size])) + self.assertEqual(preds.shape, torch.Size([batch_size])) + + # case 3 + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "all_time", + "max_seq_length": max_time + } + inputs = torch.randint(32000, (batch_size, 6), dtype=torch.int64) + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size([batch_size])) + self.assertEqual(preds.shape, torch.Size([batch_size])) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/modules/decoders/__init__.py b/texar/modules/decoders/__init__.py index 09ee338be..68a6d07da 100644 --- a/texar/modules/decoders/__init__.py +++ b/texar/modules/decoders/__init__.py @@ -21,3 +21,4 @@ from texar.modules.decoders.gpt2_decoder import * from texar.modules.decoders.rnn_decoders import * from texar.modules.decoders.transformer_decoders import * +from texar.modules.decoders.xlnet_decoder import * diff --git a/texar/modules/decoders/decoder_base.py b/texar/modules/decoders/decoder_base.py index 
0507c75a4..7694aba3f 100644 --- a/texar/modules/decoders/decoder_base.py +++ b/texar/modules/decoders/decoder_base.py @@ -88,7 +88,6 @@ class DecoderBase(ModuleBase, Generic[State, Output], ABC): """ def __init__(self, - input_size: int, vocab_size: Optional[int] = None, input_time_major: bool = False, output_time_major: bool = False, @@ -100,7 +99,6 @@ def __init__(self, self._input_time_major = input_time_major self._output_time_major = output_time_major - self._input_size = input_size self._vocab_size = vocab_size def create_helper(self, *, diff --git a/texar/modules/decoders/gpt2_decoder.py b/texar/modules/decoders/gpt2_decoder.py index 0e789ea09..c97bcf094 100644 --- a/texar/modules/decoders/gpt2_decoder.py +++ b/texar/modules/decoders/gpt2_decoder.py @@ -21,8 +21,10 @@ from texar.core import layers from texar.hyperparams import HParams -from texar.modules.pretrained import GPT2Base, gpt2_utils -from texar.modules.embedders import PositionEmbedder, WordEmbedder +from texar.modules.pretrained.gpt2_utils import init_gpt2_checkpoint +from texar.modules.pretrained.pretrained_base import PretrainedBase +from texar.modules.embedders.embedders import WordEmbedder +from texar.modules.embedders.position_embedders import PositionEmbedder from texar.modules.decoders.decoder_helpers import Helper from texar.modules.decoders.transformer_decoders import ( TransformerDecoder, TransformerDecoderOutput) @@ -32,7 +34,7 @@ ] -class GPT2Decoder(GPT2Base): +class GPT2Decoder(PretrainedBase): r"""Raw GPT2 Transformer for decoding sequences. This module basically stacks @@ -62,6 +64,8 @@ def __init__(self, cache_dir: Optional[str] = None, hparams=None): + self.model_name = "GPT2" + super().__init__(pretrained_model_name=pretrained_model_name, cache_dir=cache_dir, hparams=hparams) @@ -87,7 +91,7 @@ def __init__(self, hparams=self._hparams.decoder) if self.pretrained_model_dir: - gpt2_utils.init_gpt2_checkpoint(self, self.pretrained_model_dir) + init_gpt2_checkpoint(self, self.pretrained_model_dir) elif self._hparams.initializer: initialize = layers.get_initializer(self._hparams.initializer) assert initialize is not None @@ -181,33 +185,33 @@ def default_hparams(): The default parameters are values for 117M GPT2 model. - `pretrained_model_name`: str or None + `"pretrained_model_name"`: str or None The name of the pre-trained GPT2 model. If None, the model will be randomly initialized. - `embed`: dict + `"embed"`: dict Hyperparameters for word embedding layer. - `vocab_size`: int + `"vocab_size"`: int The vocabulary size of `inputs` in `GPT2Model`. - `position_embed`: dict + `"position_embed"`: dict Hyperparameters for position embedding layer. - `position_size`: int + `"position_size"`: int The maximum sequence length that this model might ever be used with. - `decoder`: dict + `"decoder"`: dict Hyperparameters for the TransformerDecoder. See :func:`~texar.modules.TransformerDecoder.default_harams` for details. - `initializer`: dict, optional + `"initializer"`: dict, optional Hyperparameters of the default initializer that initializes variables created in this module. See :func:`~texar.core.get_initializer` for details. - `name`: str + `"name"`: str Name of the module. 
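# A minimal usage sketch for building a randomly initialized, self-designed GPT2
# architecture purely from `hparams`, as exercised by the updated decoder tests below.
# Illustration only, not part of the patch; the overridden values and sizes are assumptions.
import torch
from texar.modules import GPT2Decoder

decoder = GPT2Decoder(hparams={
    "pretrained_model_name": None,   # skip loading pre-trained weights
    "decoder": {"num_blocks": 6},    # override the 117M default of 12 blocks
})
inputs = torch.zeros(32, 16, dtype=torch.int64)
outputs = decoder(inputs)            # teacher-forcing decode over the given token ids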
""" return { diff --git a/texar/modules/decoders/gpt2_decoder_test.py b/texar/modules/decoders/gpt2_decoder_test.py index 6def98bf3..143dd1ce5 100644 --- a/texar/modules/decoders/gpt2_decoder_test.py +++ b/texar/modules/decoders/gpt2_decoder_test.py @@ -10,11 +10,11 @@ from texar.modules.decoders.transformer_decoders import TransformerDecoderOutput -@unittest.skip("Manual test only") class GPT2DecoderTest(unittest.TestCase): r"""Tests :class:`~texar.modules.GPT2Decoder` """ + @unittest.skip("Manual test only") def test_hparams(self): r"""Tests the priority of the decoer arch parameter. """ @@ -24,10 +24,10 @@ def test_hparams(self): hparams = { "pretrained_model_name": "345M", } - encoder = GPT2Decoder(pretrained_model_name="117M", + decoder = GPT2Decoder(pretrained_model_name="117M", hparams=hparams) - _ = encoder(inputs) - self.assertEqual(encoder.hparams.decoder.num_blocks, 12) + _ = decoder(inputs) + self.assertEqual(decoder.hparams.decoder.num_blocks, 12) # case 2: set "pretrained_mode_name" by hparams hparams = { @@ -36,9 +36,9 @@ def test_hparams(self): "num_blocks": 6 } } - encoder = GPT2Decoder(hparams=hparams) - _ = encoder(inputs) - self.assertEqual(encoder.hparams.decoder.num_blocks, 12) + decoder = GPT2Decoder(hparams=hparams) + _ = decoder(inputs) + self.assertEqual(decoder.hparams.decoder.num_blocks, 12) # case 3: set to None in both hparams and constructor argument hparams = { @@ -47,15 +47,16 @@ def test_hparams(self): "num_blocks": 6 }, } - encoder = GPT2Decoder(hparams=hparams) - _ = encoder(inputs) - self.assertEqual(encoder.hparams.decoder.num_blocks, 6) + decoder = GPT2Decoder(hparams=hparams) + _ = decoder(inputs) + self.assertEqual(decoder.hparams.decoder.num_blocks, 6) # case 4: using default hparams - encoder = GPT2Decoder() - _ = encoder(inputs) - self.assertEqual(encoder.hparams.decoder.num_blocks, 12) + decoder = GPT2Decoder() + _ = decoder(inputs) + self.assertEqual(decoder.hparams.decoder.num_blocks, 12) + @unittest.skip("Manual test only") def test_trainable_variables(self): r"""Tests the functionality of automatically collecting trainable variables. @@ -63,17 +64,17 @@ def test_trainable_variables(self): inputs = torch.zeros(32, 16, dtype=torch.int64) # case 1: GPT2 117M - encoder = GPT2Decoder() - _ = encoder(inputs) - self.assertEqual(len(encoder.trainable_variables), 1 + 1 + 12 * 26 + 2) + decoder = GPT2Decoder() + _ = decoder(inputs) + self.assertEqual(len(decoder.trainable_variables), 1 + 1 + 12 * 26 + 2) # case 2: GPT2 345M hparams = { "pretrained_model_name": "345M" } - encoder = GPT2Decoder(hparams=hparams) - _ = encoder(inputs) - self.assertEqual(len(encoder.trainable_variables), 1 + 1 + 24 * 26 + 2) + decoder = GPT2Decoder(hparams=hparams) + _ = decoder(inputs) + self.assertEqual(len(decoder.trainable_variables), 1 + 1 + 24 * 26 + 2) # case 3: self-designed GPT2 hparams = { @@ -82,14 +83,17 @@ def test_trainable_variables(self): }, "pretrained_model_name": None } - encoder = GPT2Decoder(hparams=hparams) - _ = encoder(inputs) - self.assertEqual(len(encoder.trainable_variables), 1 + 1 + 6 * 26 + 2) + decoder = GPT2Decoder(hparams=hparams) + _ = decoder(inputs) + self.assertEqual(len(decoder.trainable_variables), 1 + 1 + 6 * 26 + 2) def test_decode_train(self): r"""Tests train_greedy. 
""" - decoder = GPT2Decoder() + hparams = { + "pretrained_model_name": None + } + decoder = GPT2Decoder(hparams=hparams) decoder.train() max_time = 8 @@ -106,7 +110,10 @@ def test_decode_train(self): def test_decode_infer_greedy(self): r"""Tests train_greedy """ - decoder = GPT2Decoder() + hparams = { + "pretrained_model_name": None + } + decoder = GPT2Decoder(hparams=hparams) decoder.eval() start_tokens = torch.full((16,), 1, dtype=torch.int64) @@ -128,7 +135,10 @@ def test_decode_infer_greedy(self): def test_decode_infer_sample(self): r"""Tests infer_sample """ - decoder = GPT2Decoder() + hparams = { + "pretrained_model_name": None + } + decoder = GPT2Decoder(hparams=hparams) decoder.eval() start_tokens = torch.full((16,), 1, dtype=torch.int64) @@ -150,7 +160,10 @@ def test_decode_infer_sample(self): def test_beam_search(self): r"""Tests beam_search """ - decoder = GPT2Decoder() + hparams = { + "pretrained_model_name": None + } + decoder = GPT2Decoder(hparams=hparams) decoder.eval() start_tokens = torch.full((16,), 1, dtype=torch.int64) @@ -175,7 +188,10 @@ def test_beam_search(self): def test_greedy_embedding_helper(self): r"""Tests with tf.contrib.seq2seq.GreedyEmbeddingHelper """ - decoder = GPT2Decoder() + hparams = { + "pretrained_model_name": None + } + decoder = GPT2Decoder(hparams=hparams) decoder.eval() start_tokens = torch.full((16,), 1, dtype=torch.int64) @@ -197,7 +213,10 @@ def test_greedy_embedding_helper(self): def test_topk_embedding_helper(self): r"""Tests TopKSampleEmbeddingHelper """ - decoder = GPT2Decoder() + hparams = { + "pretrained_model_name": None + } + decoder = GPT2Decoder(hparams=hparams) decoder.eval() start_tokens = torch.full((16,), 1, dtype=torch.int64) @@ -219,3 +238,7 @@ def test_topk_embedding_helper(self): helper=helper) self.assertIsInstance(outputs, TransformerDecoderOutput) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/modules/decoders/rnn_decoder_base.py b/texar/modules/decoders/rnn_decoder_base.py index e71228915..8af88b066 100644 --- a/texar/modules/decoders/rnn_decoder_base.py +++ b/texar/modules/decoders/rnn_decoder_base.py @@ -50,7 +50,7 @@ def __init__(self, input_time_major: bool = False, output_time_major: bool = False, hparams=None): - super().__init__(input_size, vocab_size, input_time_major, + super().__init__(vocab_size, input_time_major, output_time_major, hparams) # Make RNN cell diff --git a/texar/modules/decoders/transformer_decoders.py b/texar/modules/decoders/transformer_decoders.py index 5a18719d0..7ce931de0 100644 --- a/texar/modules/decoders/transformer_decoders.py +++ b/texar/modules/decoders/transformer_decoders.py @@ -92,7 +92,7 @@ def __init__(self, vocab_size: Optional[int] = None, output_layer: Optional[Union[nn.Module, torch.Tensor]] = None, hparams=None): - super().__init__(0, vocab_size, # dummy value for input_size + super().__init__(vocab_size, input_time_major=False, output_time_major=False, hparams=hparams) self._input_size = self._hparams.dim diff --git a/texar/modules/decoders/xlnet_decoder.py b/texar/modules/decoders/xlnet_decoder.py new file mode 100644 index 000000000..d2c0f01d0 --- /dev/null +++ b/texar/modules/decoders/xlnet_decoder.py @@ -0,0 +1,386 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from texar.core import layers +from texar.modules.decoders.decoder_base import DecoderBase +from texar.modules.decoders.decoder_helpers import Helper, SampleEmbeddingHelper +from texar.modules.encoders.xlnet_encoder import XLNetEncoder +from texar.modules.pretrained import xlnet_utils +from texar.utils import get_instance + +__all__ = [ + 'XLNetDecoderOutput', + 'XLNetDecoder', +] + + +class XLNetDecoderOutput(NamedTuple): + r"""The output of :class:`XLNetDecoder`. + """ + logits: torch.Tensor + r"""A :tensor:`Tensor` of shape ``[batch_size, max_time, vocab_size]`` + containing the logits.""" + sample_id: torch.LongTensor + r"""A :tensor:`LongTensor` of shape ``[batch_size, max_time]`` containing + the sampled token indices.""" + + +Output = XLNetDecoderOutput +State = List[torch.Tensor] + + +class XLNetDecoder(XLNetEncoder, DecoderBase[Optional[State], Output]): + r"""Raw XLNet module for decoding sequences. + + This module supports the architecture first proposed + in `(Yang et al.)` XLNet. + + Args: + pretrained_model_name (optional): a str with the name + of a pre-trained model to load selected in the list of: + `xlnet-base-cased`, `xlnet-large-cased`. + If `None`, will use the model name in :attr:`hparams`. + cache_dir (optional): the path to a folder in which the + pre-trained models will be cached. If `None` (default), + a default directory will be used. + hparams (dict or HParams, optional): Hyperparameters. Missing + hyperparameter will be set to default values. See + :meth:`default_hparams` for the hyperparameter structure + and default values. + """ + + def __init__(self, + pretrained_model_name: Optional[str] = None, + cache_dir: Optional[str] = None, + hparams=None): + + super().__init__(pretrained_model_name=pretrained_model_name, + cache_dir=cache_dir, + hparams=hparams, + init=False) + + self.lm_bias = nn.Parameter(torch.zeros(self._hparams.vocab_size)) + + if self.pretrained_model_dir: + xlnet_utils.init_xlnet_checkpoint(self, + self.pretrained_model_dir) + elif self._hparams.initializer: + initialize = layers.get_initializer(self._hparams.initializer) + assert initialize is not None + # Do not re-initialize LayerNorm modules. + for name, param in self.named_parameters(): + if name.split('.')[-1] == 'weight' \ + and 'layer_norm' not in name: + initialize(param) + else: + self.reset_parameters() + + def reset_parameters(self): + if not self._hparams.untie_r: + nn.init.normal_(self.r_w_bias, 0.0, 0.02) + nn.init.normal_(self.r_r_bias, 0.0, 0.02) + if self._hparams.use_segments: + nn.init.normal_(self.r_s_bias, 0.0, 0.02) + + # Variables persistent during decoding. + _state_cache_len: int + _state_recompute_memory: bool + # required for recomputing memory + _state_previous_inputs: List[torch.Tensor] + + @staticmethod + def default_hparams() -> Dict[str, Any]: + r"""Returns a dictionary of hyperparameters with default values. 
+ + * The decoder arch is determined by the constructor argument + :attr:`pretrained_model_name` if it's specified. In this case, + `hparams` are ignored. + * Otherwise, the decoder arch is determined by + `hparams['pretrained_model_name']` if it's specified. All other + configurations in `hparams` are ignored. + * If the above two are `None`, the decoder arch is defined by the + configurations in `hparams` and weights are randomly initialized. + + .. code-block:: python + + { + "pretrained_model_name": "xlnet-base-cased", + "untie_r": True, + "num_layers": 12, + "mem_len": 0, + "reuse_len": 0, + "num_heads": 12, + "hidden_dim": 768, + "head_dim": 64, + "dropout": 0.1, + "attention_dropout": 0.1, + "use_segments": True, + "ffn_inner_dim": 3072, + "activation": 'gelu', + "vocab_size": 32000, + "max_seq_length": 512, + "initializer": None, + "name": "xlnet_decoder", + } + + Here: + + The default parameters are values for cased XLNet-Base model. + + `"pretrained_model_name"`: str or None + The name of the pre-trained XLNet model. If None, the model + will be randomly initialized. + + `"untie_r"`: bool + Whether to untie the biases in attention. + + `"num_layers"`: int + The number of stacked layers. + + `"mem_len"`: int + The number of tokens to cache. + + `"reuse_len"`: int + The number of tokens in the current batch to be cached and reused + in the future. + + `"num_heads"`: int + The number of attention heads. + + `"hidden_dim"`: int + The hidden size. + + `"head_dim"`: int + The dimension size of each attention head. + + `"dropout"`: float + Dropout rate. + + `"attention_dropout"`: float + Dropout rate on attention probabilities. + + `"use_segments"`: bool + Whether to use segment embedding. + + `"ffn_inner_dim"`: int + The hidden size in feed-forward layers. + + `"activation"`: str + `relu` or `gelu`. + + `"vocab_size"`: int + The vocabulary size. + + `"max_seq_length"`: int + The maximum sequence length for `RelativePositionalEncoding`. + + `"initializer"`: dict, optional + Hyperparameters of the default initializer that initializes + variables created in this module. + See :func:`~texar.core.get_initializer` for details. + + `"name"`: str + Name of the module. + """ + return { + 'pretrained_model_name': 'xlnet-base-cased', + 'untie_r': True, + 'num_layers': 12, + 'mem_len': 0, + 'reuse_len': 0, + # layer + 'num_heads': 12, + 'hidden_dim': 768, + 'head_dim': 64, + 'dropout': 0.1, + 'attention_dropout': 0.1, + 'use_segments': True, + # ffn + 'ffn_inner_dim': 3072, + 'activation': 'gelu', + # embedding + 'vocab_size': 32000, + 'max_seq_length': 512, + 'initializer': None, + 'name': "xlnet_decoder", + '@no_typecheck': ['pretrained_model_name'], + } + + @staticmethod + def _create_input(inputs: List[torch.Tensor], + initial: bool = False) \ + -> Dict[str, torch.Tensor]: + r"""Create input tensors given the list of prompt tokens. + """ + word_embed = torch.stack(inputs, dim=0) + seq_len, batch_size, embed_dim = word_embed.size() + if not initial: + # Add a dummy token at the end that stands for the token + # to predict. + word_embed = torch.cat([ + word_embed, + word_embed.new_zeros(1, batch_size, embed_dim) + ], dim=0) + seq_len += 1 + segment_ids = word_embed.new_zeros( + seq_len, batch_size, dtype=torch.long) + return_dict = { + "word_embed": word_embed.permute(1, 0, 2), + "segment_ids": segment_ids.permute(1, 0), + } + + if not initial: + # Only the dummy token is considered target. 
+ target_mapping = torch.cat([ + torch.zeros(1, seq_len - 1, batch_size), + torch.ones(1, 1, batch_size) + ], dim=1).to(device=word_embed.device) + # Dummy token attends to nothing; actual tokens attend to all. + permute_mask = torch.cat([ + torch.zeros(seq_len, seq_len - 1, batch_size), + torch.ones(seq_len, 1, batch_size), + ], dim=1).to(device=word_embed.device) + return_dict.update({ + "target_mapping": target_mapping.permute(2, 0, 1), + "permute_mask": permute_mask.permute(2, 0, 1), + }) + + return return_dict + + def initialize(self, # pylint: disable=no-self-use + helper: Helper, + inputs: Optional[torch.Tensor], + sequence_length: Optional[torch.LongTensor], + initial_state: Optional[State]) \ + -> Tuple[torch.ByteTensor, torch.Tensor, Optional[State]]: + initial_finished, initial_inputs = helper.initialize( + inputs, sequence_length) + return initial_finished, initial_inputs, initial_state + + def step(self, + helper: Helper, + time: int, + inputs: torch.Tensor, + state: Optional[State]) \ + -> Tuple[Output, Optional[State], torch.Tensor, torch.ByteTensor]: + self._state_previous_inputs.append(inputs) + if self._state_recompute_memory: + net_output, memory = self._forward( + two_stream=True, + **self._create_input( + self._state_previous_inputs[-self._state_cache_len:])) + else: + assert state is not None + net_output, memory = self._forward( + memory=state, cache_len=self._state_cache_len, two_stream=True, + **self._create_input(self._state_previous_inputs[-1:])) + assert memory is not None + # Omit memory for the dummy token. + memory = [mem[:, :-1] for mem in memory] + + logits = F.linear(net_output, self.word_embed.weight, self.lm_bias) + logits = logits[:, -1] + sample_ids = helper.sample(time=time, outputs=logits) + (finished, next_inputs) = helper.next_inputs( + time=time, + outputs=logits, + sample_ids=sample_ids) + outputs = XLNetDecoderOutput(logits=logits, sample_id=sample_ids) + return outputs, memory, next_inputs, finished + + def finalize(self, outputs, final_state, sequence_lengths): + del self._state_cache_len + del self._state_recompute_memory + del self._state_previous_inputs + return super().finalize(outputs, final_state, sequence_lengths) + + def forward(self, # type: ignore + start_tokens: torch.LongTensor, + memory: Optional[State] = None, + cache_len: int = 512, + max_decoding_length: Optional[int] = 500, + recompute_memory: bool = True, + print_steps: bool = False, + helper_type: Optional[Union[str, Type[Helper]]] = None, + **helper_kwargs) \ + -> Tuple[Output, Optional[State]]: + r"""Perform autoregressive decoding using XLNet. The algorithm is + largely inspired by: https://github.com/rusiaaman/XLNet-gen. + + Args: + start_tokens: A LongTensor of shape `[batch_size, prompt_len]`, + representing the tokenized initial prompt. + memory (optional): The initial memory. + cache_len: Length of memory (number of tokens) to cache. + max_decoding_length (int): Maximum number of tokens to decode. + recompute_memory (bool): If `True`, the entire memory is recomputed + for each token to generate. This leads to better performance + because it enables every generated token to attend to each + other, compared to reusing previous memory which is equivalent + to using a causal attention mask. However, it is computationally + more expensive. Defaults to `True`. + print_steps (bool): If `True`, will print decoding progress. + helper: Type (or name of the type) of any sub-class of + :class:`~texar.modules.decoders.Helper`. 
+ helper_kwargs: The keyword arguments to pass to constructor of + the specific helper type. + + :returns: A tuple of `(output, new_memory)`: + - **`output`**: The sampled tokens as a list of integers. + - **`new_memory`**: The memory of the sampled tokens. + """ + + start_tokens = start_tokens.t() + self._state_recompute_memory = recompute_memory + self._state_cache_len = cache_len + self._state_previous_inputs = list( + self.word_embed(start_tokens).unbind(dim=0))[:-1] + + if helper_type is None: + helper_type = SampleEmbeddingHelper + + if not recompute_memory and start_tokens.size(0) > 1: + _, memory = self._forward( + memory=memory, cache_len=cache_len, + **self._create_input( + self._state_previous_inputs, initial=True)) + start_tokens = start_tokens[-1] + + helper_kwargs.update( + embedding=self.word_embed.weight, start_tokens=start_tokens) + + if helper_kwargs.get("end_token") is None: + raise ValueError("'end_token' must be specified.") + + helper = get_instance( + helper_type, helper_kwargs, + module_paths=['texar.modules.decoders.decoder_helpers']) + + step_hook = None + if print_steps: + step_hook = lambda step: print( + f"\033[2K\rDecoding step: {step}", end='') + output, new_memory, _ = self.dynamic_decode( + helper, inputs=None, sequence_length=None, initial_state=memory, + max_decoding_length=max_decoding_length, step_hook=step_hook) + if print_steps: + print("\033[2K\r", end='') + + return output, new_memory diff --git a/texar/modules/decoders/xlnet_decoder_test.py b/texar/modules/decoders/xlnet_decoder_test.py new file mode 100644 index 000000000..65fbf1e3d --- /dev/null +++ b/texar/modules/decoders/xlnet_decoder_test.py @@ -0,0 +1,119 @@ +""" +Unit tests for XLNet decoder. +""" +import unittest + +import torch + +from texar.modules.decoders.xlnet_decoder import * + + +class XLNetDecoderTest(unittest.TestCase): + r"""Tests :class:`~texar.modules.XLNetDecoder` + """ + + @unittest.skip("Manual test only") + def test_hparams(self): + r"""Tests the priority of the decoder arch parameter. + """ + # case 1: set "pretrained_model_name" by constructor argument + hparams = { + "pretrained_model_name": "xlnet-large-cased", + } + decoder = XLNetDecoder(pretrained_model_name="xlnet-base-cased", + hparams=hparams) + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + + self.assertEqual(decoder.hparams.num_layers, 12) + + # case 2: set "pretrained_model_name" by hparams + hparams = { + "pretrained_model_name": "xlnet-large-cased", + "num_layers": 6 + } + decoder = XLNetDecoder(hparams=hparams) + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + + self.assertEqual(decoder.hparams.num_layers, 24) + + # case 3: set to None in both hparams and constructor argument + hparams = { + "pretrained_model_name": None, + "num_layers": 6 + } + decoder = XLNetDecoder(hparams=hparams) + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + self.assertEqual(decoder.hparams.num_layers, 6) + + # case 4: using default hparams + decoder = XLNetDecoder() + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + self.assertEqual(decoder.hparams.num_layers, 12) + + @unittest.skip("Manual test only") + def test_trainable_variables(self): + r"""Tests the functionality of automatically collecting trainable + variables. 
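# A minimal usage sketch for autoregressive sampling from a tokenized prompt with the new
# XLNetDecoder, mirroring the test below. Illustration only, not part of the patch; the
# prompt contents, end_token id, and lengths are assumptions.
import torch
from texar.modules import XLNetDecoder

decoder = XLNetDecoder(hparams={"pretrained_model_name": None})
prompt = torch.randint(32000, (4, 8), dtype=torch.int64)   # [batch_size, prompt_len]
output, new_memory = decoder(start_tokens=prompt,
                             end_token=2,                  # required helper kwarg
                             max_decoding_length=10)
sampled_ids = output.sample_id   # XLNetDecoderOutput holding the sampled token ids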
+ """ + # case 1 + decoder = XLNetDecoder() + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + self.assertEqual(len(decoder.trainable_variables), 182 + 1) + + # case 2 + hparams = { + "pretrained_model_name": "xlnet-large-cased", + } + decoder = XLNetDecoder(hparams=hparams) + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + self.assertEqual(len(decoder.trainable_variables), 362 + 1) + + # case 3 + hparams = { + "pretrained_model_name": None, + "num_layers": 6 + } + decoder = XLNetDecoder(hparams=hparams) + + _, _ = decoder(start_tokens=torch.zeros(16, 8, dtype=torch.int64), + end_token=1, + max_decoding_length=8) + self.assertEqual(len(decoder.trainable_variables), 92 + 1) + + @unittest.skip("Manual test only") + def test_decode_infer_sample(self): + r"""Tests train_greedy.""" + hparams = { + "pretrained_model_name": None + } + decoder = XLNetDecoder(hparams=hparams) + decoder.train() + + max_time = 8 + batch_size = 16 + inputs = torch.randint(32000, (batch_size, max_time), dtype=torch.int64) + outputs, _ = decoder(inputs, max_decoding_length=10, end_token=2) + + self.assertIsInstance(outputs, XLNetDecoderOutput) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/modules/encoders/__init__.py b/texar/modules/encoders/__init__.py index d83508180..730e33c52 100644 --- a/texar/modules/encoders/__init__.py +++ b/texar/modules/encoders/__init__.py @@ -22,3 +22,4 @@ from texar.modules.encoders.multihead_attention import * from texar.modules.encoders.rnn_encoders import * from texar.modules.encoders.transformer_encoder import * +from texar.modules.encoders.xlnet_encoder import * diff --git a/texar/modules/encoders/bert_encoders.py b/texar/modules/encoders/bert_encoders.py index adbf6f20f..4033c7167 100644 --- a/texar/modules/encoders/bert_encoders.py +++ b/texar/modules/encoders/bert_encoders.py @@ -22,17 +22,20 @@ from texar.core import layers from texar.hyperparams import HParams -from texar.modules.pretrained import BertBase, bert_utils -from texar.modules.embedders import PositionEmbedder, WordEmbedder +from texar.modules.pretrained.bert_utils import init_bert_checkpoint +from texar.modules.pretrained.pretrained_base import PretrainedBase +from texar.modules.embedders.embedders import WordEmbedder +from texar.modules.embedders.position_embedders import PositionEmbedder +from texar.modules.encoders.encoder_base import EncoderBase from texar.modules.encoders.transformer_encoder import TransformerEncoder __all__ = [ - "BertEncoder", + "BERTEncoder", ] -class BertEncoder(BertBase): +class BERTEncoder(PretrainedBase, EncoderBase): r"""Raw BERT Transformer for encoding sequences. This module basically stacks @@ -59,6 +62,8 @@ class BertEncoder(BertBase): and default values. """ + model_name = "BERT" + def __init__(self, pretrained_model_name: Optional[str] = None, cache_dir: Optional[str] = None, @@ -95,7 +100,7 @@ def __init__(self, nn.Tanh()) if self.pretrained_model_dir: - bert_utils.init_bert_checkpoint(self, self.pretrained_model_dir) + init_bert_checkpoint(self, self.pretrained_model_dir) elif self._hparams.initializer: initialize = layers.get_initializer(self._hparams.initializer) assert initialize is not None @@ -183,42 +188,42 @@ def default_hparams(): The default parameters are values for uncased BERT-Base model. - `pretrained_model_name`: str or None + `"pretrained_model_name"`: str or None The name of the pre-trained BERT model. 
If None, the model will be randomly initialized. - `embed`: dict + `"embed"`: dict Hyperparameters for word embedding layer. - `vocab_size`: int + `"vocab_size"`: int The vocabulary size of `inputs` in `BertModel`. - `segment_embed`: dict + `"segment_embed"`: dict Hyperparameters for segment embedding layer. - `type_vocab_size`: int + `"type_vocab_size"`: int The vocabulary size of the `segment_ids` passed into `BertModel`. - `position_embed`: dict + `"position_embed"`: dict Hyperparameters for position embedding layer. - `position_size`: int + `"position_size"`: int The maximum sequence length that this model might ever be used with. - `encoder`: dict + `"encoder"`: dict Hyperparameters for the TransformerEncoder. See :func:`~texar.modules.TransformerEncoder.default_harams` for details. - `hidden_size`: int + `"hidden_size"`: int Size of the pooler dense layer. - `initializer`: dict, optional + `"initializer"`: dict, optional Hyperparameters of the default initializer that initializes variables created in this module. See :func:`~texar.core.get_initializer` for details. - `name`: str + `"name"`: str Name of the module. """ @@ -339,3 +344,7 @@ def forward(self, # type: ignore pooled_output = self.pooler(first_token_tensor) return output, pooled_output + + @property + def output_size(self): + return self._hparams.hidden_size diff --git a/texar/modules/encoders/bert_encoders_test.py b/texar/modules/encoders/bert_encoders_test.py index 4f802dc55..147c7fe94 100644 --- a/texar/modules/encoders/bert_encoders_test.py +++ b/texar/modules/encoders/bert_encoders_test.py @@ -1,19 +1,54 @@ """ -Unit tests for Bert encoders. +Unit tests for BERT encoders. """ import unittest import torch -from texar.modules.encoders.bert_encoders import BertEncoder +from texar.modules.encoders.bert_encoders import BERTEncoder -@unittest.skip("Manual test only") -class BertEncoderTest(unittest.TestCase): - r"""Tests :class:`~texar.modules.BertEncoder` class. +class BERTEncoderTest(unittest.TestCase): + r"""Tests :class:`~texar.modules.BERTEncoder` class. """ + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + encoder = BERTEncoder(pretrained_model_name="bert-base-uncased") + _, _ = encoder(inputs) + + # case 2 + encoder = BERTEncoder(pretrained_model_name="bert-large-uncased") + _, _ = encoder(inputs) + + # case 3 + encoder = BERTEncoder(pretrained_model_name="bert-base-cased") + _, _ = encoder(inputs) + + # case 4 + encoder = BERTEncoder(pretrained_model_name="bert-large-cased") + _, _ = encoder(inputs) + + # case 5 + encoder = BERTEncoder( + pretrained_model_name="bert-base-multilingual-uncased") + _, _ = encoder(inputs) + + # case 6 + encoder = BERTEncoder( + pretrained_model_name="bert-base-multilingual-cased") + _, _ = encoder(inputs) + + # case 7 + encoder = BERTEncoder(pretrained_model_name="bert-base-chinese") + _, _ = encoder(inputs) + + @unittest.skip("Manual test only") def test_hparams(self): r"""Tests the priority of the encoder arch parameter. 
""" @@ -23,7 +58,7 @@ def test_hparams(self): hparams = { "pretrained_model_name": "bert-large-uncased", } - encoder = BertEncoder(pretrained_model_name="bert-base-uncased", + encoder = BERTEncoder(pretrained_model_name="bert-base-uncased", hparams=hparams) _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 12) @@ -35,7 +70,7 @@ def test_hparams(self): "num_blocks": 6 } } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 24) @@ -46,15 +81,16 @@ def test_hparams(self): "num_blocks": 6 }, } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 6) # case 4: using default hparams - encoder = BertEncoder() + encoder = BERTEncoder() _, _ = encoder(inputs) self.assertEqual(encoder.hparams.encoder.num_blocks, 12) + @unittest.skip("Manual test only") def test_trainable_variables(self): r"""Tests the functionality of automatically collecting trainable variables. @@ -62,7 +98,7 @@ def test_trainable_variables(self): inputs = torch.zeros(32, 16, dtype=torch.int64) # case 1: bert base - encoder = BertEncoder() + encoder = BERTEncoder() _, _ = encoder(inputs) self.assertEqual(len(encoder.trainable_variables), 3 + 2 + 12 * 16 + 2) @@ -70,7 +106,7 @@ def test_trainable_variables(self): hparams = { "pretrained_model_name": "bert-large-uncased" } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(len(encoder.trainable_variables), 3 + 2 + 24 * 16 + 2) @@ -81,7 +117,7 @@ def test_trainable_variables(self): }, "pretrained_model_name": None } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) _, _ = encoder(inputs) self.assertEqual(len(encoder.trainable_variables), 3 + 2 + 6 * 16 + 2) @@ -89,7 +125,10 @@ def test_encode(self): r"""Tests encoding. """ # case 1: bert base - encoder = BertEncoder() + hparams = { + "pretrained_model_name": None + } + encoder = BERTEncoder(hparams=hparams) max_time = 8 batch_size = 16 @@ -148,7 +187,7 @@ def test_encode(self): }, 'hidden_size': 96 } - encoder = BertEncoder(hparams=hparams) + encoder = BERTEncoder(hparams=hparams) max_time = 8 batch_size = 16 diff --git a/texar/modules/encoders/gpt2_encoder.py b/texar/modules/encoders/gpt2_encoder.py index f2d31bc03..35ab7c16b 100644 --- a/texar/modules/encoders/gpt2_encoder.py +++ b/texar/modules/encoders/gpt2_encoder.py @@ -22,8 +22,11 @@ from texar.core import layers from texar.hyperparams import HParams -from texar.modules.pretrained import GPT2Base, gpt2_utils -from texar.modules.embedders import PositionEmbedder, WordEmbedder +from texar.modules.pretrained.gpt2_utils import init_gpt2_checkpoint +from texar.modules.pretrained.pretrained_base import PretrainedBase +from texar.modules.embedders.embedders import WordEmbedder +from texar.modules.embedders.position_embedders import PositionEmbedder +from texar.modules.encoders.encoder_base import EncoderBase from texar.modules.decoders.transformer_decoders import TransformerDecoder @@ -32,7 +35,7 @@ ] -class GPT2Encoder(GPT2Base): +class GPT2Encoder(PretrainedBase, EncoderBase): r"""Raw GPT2 Transformer for encoding sequences. This module basically stacks @@ -57,6 +60,8 @@ class GPT2Encoder(GPT2Base): and default values. 
""" + model_name = "GPT2" + def __init__(self, pretrained_model_name: Optional[str] = None, cache_dir: Optional[str] = None, @@ -87,7 +92,7 @@ def __init__(self, hparams=self._hparams.decoder) if self.pretrained_model_dir: - gpt2_utils.init_gpt2_checkpoint(self, self.pretrained_model_dir) + init_gpt2_checkpoint(self, self.pretrained_model_dir) elif self._hparams.initializer: initialize = layers.get_initializer(self._hparams.initializer) assert initialize is not None @@ -180,33 +185,33 @@ def default_hparams(): The default parameters are values for 117M GPT2 model. - `pretrained_model_name`: str or None + `"pretrained_model_name"`: str or None The name of the pre-trained GPT2 model. If None, the model will be randomly initialized. - `embed`: dict + `"embed"`: dict Hyperparameters for word embedding layer. - `vocab_size`: int + `"vocab_size"`: int The vocabulary size of `inputs` in `GPT2Model`. - `position_embed`: dict + `"position_embed"`: dict Hyperparameters for position embedding layer. - `position_size`: int + `"position_size"`: int The maximum sequence length that this model might ever be used with. - `decoder`: dict + `"decoder"`: dict Hyperparameters for the TransformerDecoder. See :func:`~texar.modules.TransformerDecoder.default_harams` for details. - `initializer`: dict, optional + `"initializer"`: dict, optional Hyperparameters of the default initializer that initializes variables created in this module. See :func:`~texar.core.get_initializer` for details. - `name`: str + `"name"`: str Name of the module. """ return { @@ -311,3 +316,7 @@ def forward(self, # type: ignore sequence_length=sequence_length) return output.logits + + @property + def output_size(self): + return self._hparams.decoder.dim diff --git a/texar/modules/encoders/gpt2_encoder_test.py b/texar/modules/encoders/gpt2_encoder_test.py index 4bc1a5d81..4540450fe 100644 --- a/texar/modules/encoders/gpt2_encoder_test.py +++ b/texar/modules/encoders/gpt2_encoder_test.py @@ -9,11 +9,24 @@ from texar.modules.encoders.gpt2_encoder import GPT2Encoder -@unittest.skip("Manual test only") class GPT2EncoderTest(unittest.TestCase): r"""Tests :class:`~texar.modules.GPT2Encoder` class. """ + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + encoder = GPT2Encoder(pretrained_model_name="117M") + _ = encoder(inputs) + + # case 2 + encoder = GPT2Encoder(pretrained_model_name="345M") + _ = encoder(inputs) + + @unittest.skip("Manual test only") def test_hparams(self): r"""Tests the priority of the encoder arch parameter. """ @@ -55,6 +68,7 @@ def test_hparams(self): _ = encoder(inputs) self.assertEqual(encoder.hparams.decoder.num_blocks, 12) + @unittest.skip("Manual test only") def test_trainable_variables(self): r"""Tests the functionality of automatically collecting trainable variables. @@ -89,7 +103,10 @@ def test_encode(self): r"""Tests encoding. """ # case 1: GPT2 117M - encoder = GPT2Encoder() + hparams = { + "pretrained_model_name": None + } + encoder = GPT2Encoder(hparams=hparams) max_time = 8 batch_size = 16 diff --git a/texar/modules/encoders/xlnet_encoder.py b/texar/modules/encoders/xlnet_encoder.py new file mode 100644 index 000000000..c13ffa4a8 --- /dev/null +++ b/texar/modules/encoders/xlnet_encoder.py @@ -0,0 +1,582 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +XLNet encoder. +""" + +from typing import Any, Dict, List, Optional, Tuple + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from texar.core import layers +from texar.hyperparams import HParams +from texar.modules.encoders.encoder_base import EncoderBase +from texar.modules.pretrained.pretrained_base import PretrainedBase +from texar.modules.pretrained.xlnet_utils import (init_xlnet_checkpoint, + params_except_in) +from texar.modules.pretrained.xlnet_model_utils import ( + PositionWiseFF, RelativePositionalEncoding, RelativeMultiheadAttention) +from texar.utils.utils import dict_fetch, sum_tensors + + +__all__ = [ + "XLNetEncoder", +] + + +class XLNetEncoder(PretrainedBase, EncoderBase): + r"""Raw XLNet module for encoding sequences. + + This module supports the architecture first proposed + in `(Yang et al.)` XLNet. + + Args: + pretrained_model_name (optional): a str with the name + of a pre-trained model to load selected in the list of: + `xlnet-base-cased`, `xlnet-large-cased`. + If `None`, will use the model name in :attr:`hparams`. + cache_dir (optional): the path to a folder in which the + pre-trained models will be cached. If `None` (default), + a default directory will be used. + hparams (dict or HParams, optional): Hyperparameters. Missing + hyperparameter will be set to default values. See + :meth:`default_hparams` for the hyperparameter structure + and default values. + init (optional): whether to initialize `XLNetEncoder`. 
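For orientation, a minimal usage sketch of `XLNetEncoder` (shapes follow the unit tests further below; feeding `new_memory` back in as `memory` is inferred from the `forward` docstring and is untested here):

```python
import torch
from texar.modules.encoders.xlnet_encoder import XLNetEncoder

# Randomly initialized model, so no checkpoint download is needed.
encoder = XLNetEncoder(hparams={"pretrained_model_name": None})
tokens = torch.randint(32000, (16, 8), dtype=torch.int64)  # [batch_size, seq_len]

# Default cache_len=0: the returned memory is None.
output, new_memory = encoder(tokens)  # output: [16, 8, hidden_dim]

# Cache hidden states and feed them back for the next segment (inferred usage).
output, new_memory = encoder(tokens, cache_len=64)
next_tokens = torch.randint(32000, (16, 8), dtype=torch.int64)
output, new_memory = encoder(next_tokens, memory=new_memory, cache_len=64)
```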
+ """ + + model_name = "XLNet" + + def __init__(self, + pretrained_model_name: Optional[str] = None, + cache_dir: Optional[str] = None, + hparams=None, + init=True): + + super().__init__(pretrained_model_name=pretrained_model_name, + cache_dir=cache_dir, + hparams=hparams) + + if self.pretrained_model_dir: + self._hparams = HParams(self.pretrained_model_hparams, + self._hparams.todict()) + + num_layers = self._hparams.num_layers + num_heads = self._hparams.num_heads + head_dim = self._hparams.head_dim + + self.word_embed = nn.Embedding(self._hparams.vocab_size, + self._hparams.hidden_dim) + self.pos_embed = RelativePositionalEncoding( + hparams={ + "dim": self._hparams.hidden_dim, + "max_seq_len": self._hparams.max_seq_length, + }) + self.dropout = nn.Dropout(self._hparams.dropout) + + self.r_r_bias = None + self.r_w_bias = None + self.r_s_bias = None + + if not self._hparams.untie_r: + self.r_r_bias = nn.Parameter(torch.Tensor(num_heads, head_dim)) + self.r_w_bias = nn.Parameter(torch.Tensor(num_heads, head_dim)) + self.r_s_bias = (nn.Parameter(torch.Tensor(num_heads, head_dim)) + if self._hparams.use_segments else None) + + self.attn_layers = nn.ModuleList() + self.ff_layers = nn.ModuleList() + rel_attn_hparams = dict_fetch( + self._hparams, RelativeMultiheadAttention.default_hparams()) + ff_hparams = dict_fetch( + self._hparams, PositionWiseFF.default_hparams()) + for _ in range(num_layers): + self.attn_layers.append(RelativeMultiheadAttention( + self.r_r_bias, self.r_w_bias, self.r_s_bias, + hparams=rel_attn_hparams)) + self.ff_layers.append(PositionWiseFF(hparams=ff_hparams)) + + self.mask_emb = nn.Parameter( + torch.Tensor(1, 1, self._hparams.hidden_dim)) + + if init: + if self.pretrained_model_dir: + init_xlnet_checkpoint(self, self.pretrained_model_dir) + elif self._hparams.initializer: + initialize = layers.get_initializer(self._hparams.initializer) + assert initialize is not None + # Do not re-initialize LayerNorm modules. + for name, param in self.named_parameters(): + if name.split('.')[-1] == 'weight' \ + and 'layer_norm' not in name: + initialize(param) + else: + self.reset_parameters() + + def reset_parameters(self): + if not self._hparams.untie_r: + nn.init.normal_(self.r_w_bias, 0.0, 0.02) + nn.init.normal_(self.r_r_bias, 0.0, 0.02) + if self._hparams.use_segments: + nn.init.normal_(self.r_s_bias, 0.0, 0.02) + + @staticmethod + def default_hparams() -> Dict[str, Any]: + r"""Returns a dictionary of hyperparameters with default values. + + * The encoder arch is determined by the constructor argument + :attr:`pretrained_model_name` if it's specified. In this case, + `hparams` are ignored. + * Otherwise, the encoder arch is determined by + `hparams['pretrained_model_name']` if it's specified. All other + configurations in `hparams` are ignored. + * If the above two are `None`, the encoder arch is defined by the + configurations in `hparams` and weights are randomly initialized. + + .. code-block:: python + + { + "pretrained_model_name": "xlnet-base-cased", + "untie_r": True, + "num_layers": 12, + "mem_len": 0, + "reuse_len": 0, + "num_heads": 12, + "hidden_dim": 768, + "head_dim": 64, + "dropout": 0.1, + "attention_dropout": 0.1, + "use_segments": True, + "ffn_inner_dim": 3072, + "activation": 'gelu', + "vocab_size": 32000, + "max_seq_length": 512, + "initializer": None, + "name": "xlnet_encoder", + } + + Here: + + The default parameters are values for cased XLNet-Base model. + + `"pretrained_model_name"`: str or None + The name of the pre-trained XLNet model. 
If None, the model + will be randomly initialized. + + `"untie_r"`: bool + Whether to untie the biases in attention. + + `"num_layers"`: int + The number of stacked layers. + + `"mem_len"`: int + The number of tokens to cache. + + `"reuse_len"`: int + The number of tokens in the current batch to be cached and reused + in the future. + + `"num_heads"`: int + The number of attention heads. + + `"hidden_dim"`: int + The hidden size. + + `"head_dim"`: int + The dimension size of each attention head. + + `"dropout"`: float + Dropout rate. + + `"attention_dropout"`: float + Dropout rate on attention probabilities. + + `"use_segments"`: bool + Whether to use segment embedding. + + `"ffn_inner_dim"`: int + The hidden size in feed-forward layers. + + `"activation"`: str + `relu` or `gelu`. + + `"vocab_size"`: int + The vocabulary size. + + `"max_seq_length"`: int + The maximum sequence length for `RelativePositionalEncoding`. + + `"initializer"`: dict, optional + Hyperparameters of the default initializer that initializes + variables created in this module. + See :func:`~texar.core.get_initializer` for details. + + `"name"`: str + Name of the module. + """ + + return { + 'pretrained_model_name': 'xlnet-base-cased', + 'untie_r': True, + 'num_layers': 12, + 'mem_len': 0, + 'reuse_len': 0, + # layer + 'num_heads': 12, + 'hidden_dim': 768, + 'head_dim': 64, + 'dropout': 0.1, + 'attention_dropout': 0.1, + 'use_segments': True, + # ffn + 'ffn_inner_dim': 3072, + 'activation': 'gelu', + # embedding + 'vocab_size': 32000, + 'max_seq_length': 512, + 'initializer': None, + 'name': "xlnet_encoder", + '@no_typecheck': ['pretrained_model_name'], + } + + def param_groups(self, + lr: Optional[float] = None, + lr_layer_scale: float = 1.0, + decay_base_params: bool = False): + r"""Create parameter groups for optimizers. When + :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form + separate groups with different base learning rates. + + Args: + lr (float): The learning rate. Can be omitted if + :attr:`lr_layer_decay_rate` is 1.0. + lr_layer_scale (float): Per-layer LR scaling rate. The `i`-th layer + will be scaled by `lr_layer_scale ^ (num_layers - i - 1)`. + decay_base_params (bool): If `True`, treat non-layer parameters + (e.g. embeddings) as if they're in layer 0. If `False`, these + parameters are not scaled. + + Returns: + The parameter groups, used as the first argument for optimizers. 
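For instance, layer-wise learning-rate decay could be hooked into an optimizer along these lines (the concrete values `2e-5` and `0.75` are illustrative, not taken from this changeset):

```python
import torch
from texar.modules.encoders.xlnet_encoder import XLNetEncoder

xlnet = XLNetEncoder(hparams={"pretrained_model_name": None})

# Lower layers get smaller learning rates: lr * 0.75 ** (num_layers - i - 1).
groups = xlnet.param_groups(lr=2e-5, lr_layer_scale=0.75)
optimizer = torch.optim.Adam(groups, lr=2e-5)
```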
+ """ + + if lr_layer_scale != 1.0: + if lr is None: + raise ValueError( + "lr must be specified when lr_layer_decay_rate is not 1.0") + + num_layers = self._hparams.num_layers + base_group = { + "params": params_except_in( + self, ['attn_layers', 'ff_layers']), + "lr": lr * (lr_layer_scale ** num_layers + if decay_base_params else 1.0) + } + param_groups = [base_group] + for idx in range(num_layers): + decay_rate = lr_layer_scale ** (num_layers - idx - 1) + param_group = { + "params": [*self.attn_layers[idx].parameters(), + *self.ff_layers[idx].parameters()], + "lr": lr * decay_rate, + } + param_groups.append(param_group) + else: + param_groups = self.parameters() + return param_groups + + @property + def output_size(self): + return self._hparams.hidden_dim + + @staticmethod + def _cache_mem(output: torch.Tensor, + prev_mem: Optional[torch.Tensor], + mem_len: int, + reuse_len: int = 0) -> torch.Tensor: + r"""Cache hidden states into memory.""" + assert mem_len > 0 + + if reuse_len is not None and reuse_len > 0: + output = output[:reuse_len] + if prev_mem is None: + new_mem = output[-mem_len:] + else: + new_mem = torch.cat([prev_mem, output], dim=0)[-mem_len:] + return new_mem.detach() + + def _create_causal_attn_mask(self, + seq_len: int, + mem_len: int, + same_length: bool = False) -> torch.Tensor: + r"""Create causal attention mask of shape + `(seq_len, mem_len + seq_len)`. + """ + assert self.r_w_bias is not None + device = self.r_w_bias.device + attn_mask = torch.ones(seq_len, seq_len, device=device) + mask_u = torch.triu(attn_mask, diagonal=1) + attn_mask_pad = torch.zeros(seq_len, mem_len, device=device) + ret = torch.cat([attn_mask_pad, mask_u], dim=1) + if same_length: + mask_l = torch.tril(attn_mask, diagonal=-1) + ret = torch.cat([ret[:, :seq_len] + mask_l, ret[:, seq_len:]], 1) + return ret + + def forward(self, # type: ignore + token_ids: torch.LongTensor, + segment_ids: Optional[torch.LongTensor] = None, + input_mask: Optional[torch.Tensor] = None, + memory: Optional[List[torch.Tensor]] = None, + permute_mask: Optional[torch.Tensor] = None, + target_mapping: Optional[torch.Tensor] = None, + bi_data: bool = False, + clamp_len: Optional[int] = None, + cache_len: int = 0, + same_length: bool = False, + attn_type: str = 'bi', + two_stream: bool = False) \ + -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + r"""Compute XLNet representations for the input. + + Args: + token_ids: Shape `[batch_size, seq_len]`. + segment_ids: Shape `[batch_size, seq_len]`. + input_mask: Float tensor of shape `[batch_size, seq_len]`. Note that + positions with value 1 are masked out. + memory: Memory from previous batches. A list of length `num_layers`, + each tensor of shape `[batch_size, mem_len, hidden_dim]`. + permute_mask: The permutation mask. Float tensor of shape + `[batch_size, seq_len, seq_len]`. + A value of 0 for ``permute_mask[i, j, k]`` indicates that + position `i` attends to position `j` in batch `k`. + target_mapping: The target token mapping. Float tensor of shape + `[batch_size, num_targets, seq_len]`. + A value of 1 for ``target_mapping[i, j, k]`` indicates that + the `i`-th target token (in order of permutation) in batch `k` + is the token at position `j`. + Each row ``target_mapping[i, :, k]`` can have no more than one + value of 1. + bi_data (bool): Whether to use bidirectional data input pipeline. + clamp_len (int): Clamp all relative distances larger than + :attr:`clamp_len`. A value of -1 means no clamping. 
+ cache_len (int): Length of memory (number of tokens) to cache. + same_length (bool): Whether to use the same attention length for + each token. + attn_type (str): Attention type. Supported values are `"uni"` + and `"bi"`. + two_stream (bool): Whether to use two-stream attention. Only set to + `True` when pre-training or generating text. Defaults to + `False`. + + :returns: A tuple of `(output, new_memory)`: + + - **`output`**: The final layer output representations. Shape + `[batch_size, seq_len, hidden_dim]`. + - **`new_memory`**: The memory of the current batch. + If `cache_len` is 0, then `new_memory` is `None`. Otherwise, it is + a list of length `num_layers`, each tensor of shape + `[batch_size, cache_len, hidden_dim]`. + This can be used as the :attr:`memory` argument in the next batch. + """ + return self._forward(self.word_embed(token_ids), + segment_ids=segment_ids, + input_mask=input_mask, + memory=memory, + permute_mask=permute_mask, + target_mapping=target_mapping, + bi_data=bi_data, + clamp_len=clamp_len, + cache_len=cache_len, + same_length=same_length, + attn_type=attn_type, + two_stream=two_stream) + + def _forward(self, + word_embed: torch.Tensor, + segment_ids: Optional[torch.LongTensor] = None, + input_mask: Optional[torch.Tensor] = None, + memory: Optional[List[torch.Tensor]] = None, + permute_mask: Optional[torch.Tensor] = None, + target_mapping: Optional[torch.Tensor] = None, + bi_data: bool = False, + clamp_len: Optional[int] = None, + cache_len: int = 0, + same_length: bool = False, + attn_type: str = 'bi', + two_stream: bool = False) \ + -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: + r"""Compute XLNet representations for the input. This layer exists + because :class:`XLNetDecoder` compute embeddings in the decoder helper. + + Args: + word_embed: Shape `[batch_size, seq_len, word_embed_dim]`. + segment_ids: Shape `[batch_size, seq_len]`. + input_mask: Float tensor of shape `[batch_size, seq_len]`. Note that + positions with value 1 are masked out. + memory: Memory from previous batches. A list of length `num_layers`, + each tensor of shape `[batch_size, mem_len, hidden_dim]`. + permute_mask: The permutation mask. Float tensor of shape + `[batch_size, seq_len, seq_len]`. + A value of 0 for ``permute_mask[i, j, k]`` indicates that + position `i` attends to position `j` in batch `k`. + target_mapping: The target token mapping. Float tensor of shape + `[batch_size, num_targets, seq_len]`. + A value of 1 for ``target_mapping[i, j, k]`` indicates that + the `i`-th target token (in order of permutation) in batch `k` + is the token at position `j`. + Each row ``target_mapping[i, :, k]`` can have no more than one + value of 1. + bi_data (bool): Whether to use bidirectional data input pipeline. + clamp_len (int): Clamp all relative distances larger than + :attr:`clamp_len`. A value of -1 means no clamping. + cache_len (int): Length of memory (number of tokens) to cache. + same_length (bool): Whether to use the same attention length for + each token. + attn_type (str): Attention type. Supported values are `"uni"` + and `"bi"`. + two_stream (bool): Whether to use two-stream attention. Only set to + `True` when pre-training or generating text. Defaults to + `False`. + + :returns: A tuple of `(output, new_memory)`: + + - **`output`**: The final layer output representations. Shape + `[batch_size, seq_len, hidden_dim]`. + - **`new_memory`**: The memory of the current batch. + If `cache_len` is 0, then `new_memory` is `None`. 
Otherwise, it is + a list of length `num_layers`, each tensor of shape + `[batch_size, cache_len, hidden_dim]`. + This can be used as the :attr:`memory` argument in the next batch. + """ + # word_embed: [seq_len, batch_size, word_embed_dim] + word_embed = word_embed.permute(1, 0, 2) + # segment_ids: [seq_len, batch_size] + if segment_ids is not None: + segment_ids = segment_ids.permute(1, 0) + # input_mask: [seq_len, batch_size] + if input_mask is not None: + input_mask = input_mask.permute(1, 0) + # memory: A list of length num_layers + # each tensor of shape [mem_len, batch_size, hidden_dim] + if memory is not None: + memory = [m.permute(1, 0, 2) for m in memory] + # permute_mask: [seq_len, seq_len, batch_size] + if permute_mask is not None: + permute_mask = permute_mask.permute(1, 2, 0) + # target_mapping: [num_targets, seq_len, batch_size] + if target_mapping is not None: + target_mapping = target_mapping.permute(1, 2, 0) + + seq_len, batch_size = word_embed.size()[:2] + mem_len = memory[0].size(0) if memory is not None else 0 + tot_len = seq_len + mem_len + reuse_len = self._hparams.reuse_len + + # Construct masks. + masks: List[Optional[torch.Tensor]] = [] + + # Causal attention mask. + if attn_type == 'uni': + causal_mask = self._create_causal_attn_mask( + seq_len, mem_len, same_length) + # attn_mask: (seq_len, tot_len, 1, 1) + causal_mask = causal_mask.unsqueeze(2).unsqueeze(3) + masks.append(causal_mask) + elif attn_type == 'bi': + pass + else: + raise ValueError(f"Unsupported attention type: {attn_type}") + + # Data mask: input mask & permutation mask. + if input_mask is not None: + input_mask = input_mask.expand(seq_len, -1, -1) + data_mask = sum_tensors([input_mask, permute_mask]) + if data_mask is not None: + # All positions in memory can be attended to. + memory_mask = data_mask.new_zeros(seq_len, mem_len, batch_size) + # data_mask: (seq_len, tot_len, batch_size, 1) + data_mask = torch.cat([memory_mask, data_mask], dim=1).unsqueeze(3) + masks.append(data_mask) + + # Exclude the main diagonal (target tokens) from the mask. + attn_mask = sum_tensors(masks) + if attn_mask is None: + final_mask = None + else: + attn_mask = (attn_mask > 0) + final_mask = -torch.eye(seq_len, device=attn_mask.device) + final_mask = torch.cat([ + final_mask.new_zeros(seq_len, mem_len), final_mask], dim=-1) + final_mask = final_mask.unsqueeze(2).unsqueeze(3) + # final_mask: (seq_len, tot_len, batch_size, 1) + final_mask = ((attn_mask.float() + final_mask) > 0) + + # Construct segment embedding. 
+ if segment_ids is not None: + concat_segment_ids = torch.cat([ + segment_ids.new_zeros(mem_len, batch_size), segment_ids]) + segment_matrix = (segment_ids.unsqueeze(1) != + concat_segment_ids.unsqueeze(0)).long() + segment_matrix = F.one_hot(segment_matrix, num_classes=2).float() + else: + segment_matrix = None + + pos_embed = self.pos_embed( + batch_size, seq_len, tot_len, clamp_len, attn_type, bi_data) + pos_embed = self.dropout(pos_embed) + + states_h = self.dropout(word_embed) + if two_stream: + if target_mapping is not None: + word_embed_q = self.mask_emb.expand( + target_mapping.size(0), batch_size, -1) + else: + word_embed_q = word_embed + states_g = self.dropout(word_embed_q) + else: + states_g = None + new_memory = [] + + for idx in range(self._hparams.num_layers): + cur_memory = memory[idx] if memory is not None else None + if cache_len > 0: + new_memory.append(self._cache_mem( + states_h, cur_memory, cache_len, reuse_len)) + attn_layer: RelativeMultiheadAttention = self.attn_layers[idx] + states_h, states_g = attn_layer( + states_h=states_h, states_g=states_g, + pos_embed=pos_embed, segment_mat=segment_matrix, + attn_mask_h=final_mask, attn_mask_g=attn_mask, + target_mapping=target_mapping, memory=cur_memory) + states_h = self.ff_layers[idx](states_h) + if states_g is not None: + states_g = self.ff_layers[idx](states_g) + + output = self.dropout(states_h if states_g is None else states_g) + + # Now output: [seq_len, batch_size, hidden_dim] + # new_memory: None or A list of length num_layers, + # each tensor of shape [cache_len, batch_size, hidden_dim] + output = output.permute(1, 0, 2) + if new_memory is not None: + new_memory = [m.permute(1, 0, 2) for m in new_memory] + + if cache_len == 0: + return output, None + + return output, new_memory diff --git a/texar/modules/encoders/xlnet_encoder_test.py b/texar/modules/encoders/xlnet_encoder_test.py new file mode 100644 index 000000000..5d821a3c4 --- /dev/null +++ b/texar/modules/encoders/xlnet_encoder_test.py @@ -0,0 +1,149 @@ +""" +Unit tests for XLNet encoder. +""" + +import unittest + +import torch + +from texar.modules.encoders.xlnet_encoder import XLNetEncoder + + +class XLNetEncoderTest(unittest.TestCase): + r"""Tests :class:`~texar.modules.XLNetEncoder` class. + """ + + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + encoder = XLNetEncoder(pretrained_model_name="xlnet-base-cased") + _, _ = encoder(inputs) + + # case 2 + encoder = XLNetEncoder(pretrained_model_name="xlnet-large-cased") + _, _ = encoder(inputs) + + @unittest.skip("Manual test only") + def test_hparams(self): + r"""Tests the priority of the encoder arch parameter. 
+ """ + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1: set "pretrained_mode_name" by constructor argument + hparams = { + "pretrained_model_name": "xlnet-large-cased", + } + encoder = XLNetEncoder(pretrained_model_name="xlnet-base-cased", + hparams=hparams) + _, _ = encoder(inputs) + self.assertEqual(encoder.hparams.num_layers, 12) + + # case 2: set "pretrained_mode_name" by hparams + hparams = { + "pretrained_model_name": "xlnet-large-cased", + "num_layers": 6 + } + encoder = XLNetEncoder(hparams=hparams) + _, _ = encoder(inputs) + self.assertEqual(encoder.hparams.num_layers, 24) + + # case 3: set to None in both hparams and constructor argument + hparams = { + "pretrained_model_name": None, + "num_layers": 6 + } + encoder = XLNetEncoder(hparams=hparams) + _, _ = encoder(inputs) + self.assertEqual(encoder.hparams.num_layers, 6) + + # case 4: using default hparams + encoder = XLNetEncoder() + _, _ = encoder(inputs) + self.assertEqual(encoder.hparams.num_layers, 12) + + @unittest.skip("Manual test only") + def test_trainable_variables(self): + r"""Tests the functionality of automatically collecting trainable + variables. + """ + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1: xlnet base + encoder = XLNetEncoder() + _, _ = encoder(inputs) + self.assertEqual(len(encoder.trainable_variables), 182) + + # Case 2: xlnet large + hparams = { + "pretrained_model_name": "xlnet-large-cased" + } + encoder = XLNetEncoder(hparams=hparams) + _, _ = encoder(inputs) + self.assertEqual(len(encoder.trainable_variables), 362) + + # case 3: self-designed bert + hparams = { + "num_layers": 6, + "pretrained_model_name": None + } + encoder = XLNetEncoder(hparams=hparams) + _, _ = encoder(inputs) + self.assertEqual(len(encoder.trainable_variables), 92) + + def test_encode(self): + r"""Tests encoding. + """ + # case 1: xlnet base + hparams = { + "pretrained_model_name": None + } + encoder = XLNetEncoder(hparams=hparams) + + max_time = 8 + batch_size = 16 + inputs = torch.randint(32000, (batch_size, max_time), dtype=torch.int64) + outputs, new_memory = encoder(inputs) + + self.assertEqual(outputs.shape, torch.Size([batch_size, + max_time, + encoder.output_size])) + self.assertEqual(new_memory, None) + + # case 2: self-designed xlnet + hparams = { + 'pretrained_model_name': None, + 'untie_r': True, + 'num_layers': 6, + 'mem_len': 0, + 'reuse_len': 0, + 'num_heads': 8, + 'hidden_dim': 32, + 'head_dim': 64, + 'dropout': 0.1, + 'attention_dropout': 0.1, + 'use_segments': True, + 'ffn_inner_dim': 256, + 'activation': 'gelu', + 'vocab_size': 32000, + 'max_seq_length': 128, + 'initializer': None, + 'name': "xlnet_encoder", + } + encoder = XLNetEncoder(hparams=hparams) + + max_time = 8 + batch_size = 16 + inputs = torch.randint(32000, (batch_size, max_time), dtype=torch.int64) + outputs, new_memory = encoder(inputs) + + self.assertEqual(outputs.shape, torch.Size([batch_size, + max_time, + encoder.output_size])) + self.assertEqual(new_memory, None) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/modules/pretrained/__init__.py b/texar/modules/pretrained/__init__.py index 242871ccf..8f0e4f360 100644 --- a/texar/modules/pretrained/__init__.py +++ b/texar/modules/pretrained/__init__.py @@ -15,7 +15,7 @@ Pre-trained modules of Texar library. 
""" -from texar.modules.pretrained.bert_base import * from texar.modules.pretrained.bert_utils import * -from texar.modules.pretrained.gpt2_base import * from texar.modules.pretrained.gpt2_utils import * +from texar.modules.pretrained.pretrained_base import * +from texar.modules.pretrained.xlnet_utils import * diff --git a/texar/modules/pretrained/bert_utils.py b/texar/modules/pretrained/bert_utils.py index 6869c5428..a09ad2644 100644 --- a/texar/modules/pretrained/bert_utils.py +++ b/texar/modules/pretrained/bert_utils.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Utils of Bert Modules. +Utils of BERT Modules. """ from typing import Dict, Optional import json import os -import sys import torch import torch.nn as nn from texar.data.data_utils import maybe_download +from texar.modules.pretrained.pretrained_utils import default_download_dir __all__ = [ @@ -169,35 +169,9 @@ def name_to_variable(model: nn.Module, name: str) -> nn.Module: return pointer -def _default_download_dir() -> str: - r"""Return the directory to which packages will be downloaded by default. - """ - package_dir = os.path.dirname(os.path.dirname( - os.path.dirname(os.path.dirname(__file__)))) - if os.access(package_dir, os.W_OK): - texar_download_dir = os.path.join(package_dir, 'texar_download') - else: - # On Windows, use %APPDATA% - if sys.platform == 'win32' and 'APPDATA' in os.environ: - home_dir = os.environ['APPDATA'] - - # Otherwise, install in the user's home directory. - else: - home_dir = os.path.expanduser('~/') - if home_dir == '~/': - raise ValueError("Could not find a default download directory") - - texar_download_dir = os.path.join(home_dir, 'texar_download') - - if not os.path.exists(texar_download_dir): - os.mkdir(texar_download_dir) - - return os.path.join(texar_download_dir, 'bert') - - def load_pretrained_bert(pretrained_model_name: str, cache_dir: Optional[str] = None) -> str: - r"""Return the directory in which the pretrained BERT is cached. + r"""Return the directory in which the pretrained `BERT` is cached. """ if pretrained_model_name in _MODEL2URL: download_path = _MODEL2URL[pretrained_model_name] @@ -206,7 +180,7 @@ def load_pretrained_bert(pretrained_model_name: str, "Pre-trained model not found: {}".format(pretrained_model_name)) if cache_dir is None: - cache_dir = _default_download_dir() + cache_dir = default_download_dir("bert") file_name = download_path.split('/')[-1] diff --git a/texar/modules/pretrained/bert_utils_test.py b/texar/modules/pretrained/bert_utils_test.py index 660417ee2..d3d89a03c 100644 --- a/texar/modules/pretrained/bert_utils_test.py +++ b/texar/modules/pretrained/bert_utils_test.py @@ -1,5 +1,5 @@ """ -Unit tests for bert utils. +Unit tests for BERT utils. """ import os @@ -8,11 +8,11 @@ from texar.modules.pretrained.bert_utils import * -class BertUtilsTest(unittest.TestCase): - r"""Tests bert utils. +class BERTUtilsTest(unittest.TestCase): + r"""Tests BERT utils. 
""" - def test_load_pretrained_model_AND_transform_bert_to_texar_config(self): + def test_load_pretrained_bert_AND_transform_bert_to_texar_config(self): pretrained_model_dir = load_pretrained_bert( pretrained_model_name="bert-base-uncased") diff --git a/texar/modules/pretrained/gpt2_utils.py b/texar/modules/pretrained/gpt2_utils.py index 9618132d3..a784d0bf1 100644 --- a/texar/modules/pretrained/gpt2_utils.py +++ b/texar/modules/pretrained/gpt2_utils.py @@ -25,6 +25,7 @@ import torch.nn as nn from texar.data.data_utils import maybe_download +from texar.modules.pretrained.pretrained_utils import default_download_dir __all__ = [ @@ -199,35 +200,9 @@ def name_to_variable(model: nn.Module, name: str) -> nn.Module: return pointer -def _default_download_dir() -> str: - r"""Return the directory to which packages will be downloaded by default. - """ - package_dir = os.path.dirname(os.path.dirname( - os.path.dirname(os.path.dirname(__file__)))) - if os.access(package_dir, os.W_OK): - texar_download_dir = os.path.join(package_dir, 'texar_download') - else: - # On Windows, use %APPDATA% - if sys.platform == 'win32' and 'APPDATA' in os.environ: - home_dir = os.environ['APPDATA'] - - # Otherwise, install in the user's home directory. - else: - home_dir = os.path.expanduser('~/') - if home_dir == '~/': - raise ValueError("Could not find a default download directory") - - texar_download_dir = os.path.join(home_dir, 'texar_download') - - if not os.path.exists(texar_download_dir): - os.mkdir(texar_download_dir) - - return os.path.join(texar_download_dir, 'gpt2') - - def load_pretrained_gpt2(pretrained_model_name: str, cache_dir: Optional[str] = None) -> str: - r"""Return the directory in which the pretrained GPT2 is cached. + r"""Return the directory in which the pretrained `GPT2` is cached. """ if pretrained_model_name in _MODEL2URL: download_path = _MODEL2URL[pretrained_model_name] @@ -236,7 +211,7 @@ def load_pretrained_gpt2(pretrained_model_name: str, "Pre-trained model not found: {}".format(pretrained_model_name)) if cache_dir is None: - cache_dir = _default_download_dir() + cache_dir = default_download_dir("gpt2") file_name = download_path.split('/')[-1] diff --git a/texar/modules/pretrained/gpt2_utils_test.py b/texar/modules/pretrained/gpt2_utils_test.py index ec06bf49a..d09d5ba35 100644 --- a/texar/modules/pretrained/gpt2_utils_test.py +++ b/texar/modules/pretrained/gpt2_utils_test.py @@ -8,11 +8,11 @@ from texar.modules.pretrained.gpt2_utils import * -class GPTUtilsTest(unittest.TestCase): +class GPT2UtilsTest(unittest.TestCase): r"""Tests GPT2 utils. """ - def test_load_pretrained_model_AND_transform_gpt2_to_texar_config(self): + def test_load_pretrained_gpt2_AND_transform_gpt2_to_texar_config(self): pretrained_model_dir = load_pretrained_gpt2( pretrained_model_name="117M") diff --git a/texar/modules/pretrained/bert_base.py b/texar/modules/pretrained/pretrained_base.py similarity index 55% rename from texar/modules/pretrained/bert_base.py rename to texar/modules/pretrained/pretrained_base.py index ffd4a3f3e..aa710596c 100644 --- a/texar/modules/pretrained/bert_base.py +++ b/texar/modules/pretrained/pretrained_base.py @@ -12,30 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Base class for Bert Modules. +Base class for Pre-trained Modules. 
""" from typing import Optional from texar.module_base import ModuleBase -from texar.modules.pretrained import bert_utils +from texar.modules.pretrained.bert_utils import ( + load_pretrained_bert, transform_bert_to_texar_config) +from texar.modules.pretrained.gpt2_utils import ( + load_pretrained_gpt2, transform_gpt2_to_texar_config) +from texar.modules.pretrained.xlnet_utils import ( + load_pretrained_xlnet, transform_xlnet_to_texar_config) + __all__ = [ - "BertBase", + "PretrainedBase", ] -class BertBase(ModuleBase): - r"""Base class for all BERT classes to inherit. +class PretrainedBase(ModuleBase): + r"""Base class for all pre-trained classes to inherit. Args: - pretrained_model_name (optional): a str with the name - of a pre-trained model to load selected in the list of: - `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`, - `bert-large-cased`, `bert-base-multilingual-uncased`, - `bert-base-multilingual-cased`, `bert-base-chinese`. - If `None`, will use the model name in :attr:`hparams`. - cache_dir (optional): the path to a folder in which the + pretrained_model_name (optional): A str with the name + of a pre-trained model to load. If `None`, will use the model + name in :attr:`hparams`. + cache_dir (optional): The path to a folder in which the pre-trained models will be cached. If `None` (default), a default directory will be used. hparams (dict or HParams, optional): Hyperparameters. Missing @@ -48,20 +51,33 @@ def __init__(self, pretrained_model_name: Optional[str] = None, cache_dir: Optional[str] = None, hparams=None): - super().__init__(hparams) + + super().__init__(hparams=hparams) self.pretrained_model_dir: Optional[str] = None + if self.model_name == "BERT": + load_func = load_pretrained_bert + transform_func = transform_bert_to_texar_config + elif self.model_name == "GPT2": + load_func = load_pretrained_gpt2 + transform_func = transform_gpt2_to_texar_config + elif self.model_name == "XLNet": + load_func = load_pretrained_xlnet + transform_func = transform_xlnet_to_texar_config + else: + raise ValueError("Could not find this pre-trained model.") + if pretrained_model_name: - self.pretrained_model_dir = bert_utils.load_pretrained_bert( + self.pretrained_model_dir = load_func( pretrained_model_name, cache_dir) elif self._hparams.pretrained_model_name is not None: - self.pretrained_model_dir = bert_utils.load_pretrained_bert( + self.pretrained_model_dir = load_func( self._hparams.pretrained_model_name, cache_dir) if self.pretrained_model_dir: - self.pretrained_model_hparams = bert_utils.\ - transform_bert_to_texar_config(self.pretrained_model_dir) + self.pretrained_model_hparams = transform_func( + self.pretrained_model_dir) @staticmethod def default_hparams(): @@ -70,12 +86,13 @@ def default_hparams(): .. code-block:: python { - "name": "bert" + "pretrained_model_name": None, + "name": "pretrained_base" } """ return { - 'pretrained_model_name': 'bert-base-uncased', - 'name': 'bert_base', + 'pretrained_model_name': None, + 'name': "pretrained_base", '@no_typecheck': ['pretrained_model_name'] } @@ -83,7 +100,7 @@ def forward(self, inputs, *args, **kwargs): r"""Encodes the inputs and (optionally) conduct downstream prediction. Args: - inputs: Inputs to the BERT module. + inputs: Inputs to the pre-trained module. *args: Other arguments. **kwargs: Keyword arguments. 
diff --git a/texar/modules/pretrained/pretrained_utils.py b/texar/modules/pretrained/pretrained_utils.py new file mode 100644 index 000000000..b5415033a --- /dev/null +++ b/texar/modules/pretrained/pretrained_utils.py @@ -0,0 +1,50 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utils of Pre-trained Modules. +""" + +import os +import sys + + +__all__ = [ + "default_download_dir", +] + + +def default_download_dir(name: str) -> str: + r"""Return the directory to which packages will be downloaded by default. + """ + package_dir = os.path.dirname(os.path.dirname( + os.path.dirname(os.path.dirname(__file__)))) + if os.access(package_dir, os.W_OK): + texar_download_dir = os.path.join(package_dir, 'texar_download') + else: + # On Windows, use %APPDATA% + if sys.platform == 'win32' and 'APPDATA' in os.environ: + home_dir = os.environ['APPDATA'] + + # Otherwise, install in the user's home directory. + else: + home_dir = os.path.expanduser('~/') + if home_dir == '~/': + raise ValueError("Could not find a default download directory") + + texar_download_dir = os.path.join(home_dir, 'texar_download') + + if not os.path.exists(texar_download_dir): + os.mkdir(texar_download_dir) + + return os.path.join(texar_download_dir, name) diff --git a/texar/modules/pretrained/xlnet_model_utils.py b/texar/modules/pretrained/xlnet_model_utils.py new file mode 100644 index 000000000..ec5a9a821 --- /dev/null +++ b/texar/modules/pretrained/xlnet_model_utils.py @@ -0,0 +1,342 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model Utils of XLNet Modules. 
+ +Adapted from +https://github.com/zihangdai/xlnet/blob/master/modeling.py +""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from texar.core import get_layer +from texar.module_base import ModuleBase + +__all__ = [ + "PositionWiseFF", + "RelativeMultiheadAttention", + "RelativePositionalEncoding", +] + + +class PositionWiseFF(ModuleBase): + + def __init__(self, hparams=None): + super().__init__(hparams) + + hidden_dim = self._hparams.hidden_dim + ffn_inner_dim = self._hparams.ffn_inner_dim + dropout = self._hparams.dropout + activation = self._hparams.activation.capitalize() + if activation == 'Relu': + activation = 'ReLU' + elif activation == 'Gelu': + activation = 'GPTGELU' + + self.linear1 = nn.Linear(hidden_dim, ffn_inner_dim) + self.activation_fn = get_layer({"type": activation}) + self.dropout = nn.Dropout(dropout, inplace=True) + self.linear2 = nn.Linear(ffn_inner_dim, hidden_dim) + self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12) + + @staticmethod + def default_hparams() -> Dict[str, Any]: + return { + "hidden_dim": 768, + "ffn_inner_dim": 3072, + "dropout": 0.1, + "activation": 'relu', + } + + def forward(self, # type: ignore + input: torch.Tensor) -> torch.Tensor: + # position-wise feed-forward + output = self.linear1(input) + output = self.activation_fn(output) + output = self.dropout(output) + output = self.linear2(output) + output = self.dropout(output) + # residual + layer norm + output = self.layer_norm(input + output) + return output + + +class PositionalEmbedding(nn.Module): + + def __init__(self, embed_dim: int): + super().__init__() + + freq_seq = torch.arange(0.0, embed_dim, 2.0) + inv_freq = 1 / (10000 ** (freq_seq / embed_dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, # type: ignore + pos_seq: torch.Tensor) -> torch.Tensor: + sinusoid = torch.ger(pos_seq, self.inv_freq) + pos_embed = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1) + return pos_embed + + +class RelativePositionalEncoding(ModuleBase): + + def __init__(self, hparams=None): + super().__init__(hparams) + self.sinusoid_embed = PositionalEmbedding(self._hparams.dim) + + @staticmethod + def default_hparams(): + return { + "dim": 768, + "max_seq_len": 512, + } + + def _create_positional_embedding(self, + start: int, + end: int, + step: int, + batch_size: int, + clamp_len: Optional[int] = None) \ + -> torch.Tensor: + embed_buffer = next(self.sinusoid_embed.buffers()) + pos_seq = torch.arange(start, end, step, device=embed_buffer.device, + dtype=embed_buffer.dtype) + + if clamp_len is not None: + pos_seq = torch.clamp(pos_seq, -clamp_len, clamp_len) + + pos_embed = self.sinusoid_embed(pos_seq) + pos_embed = pos_embed.unsqueeze(1).expand(-1, batch_size, -1) + return pos_embed + + def forward(self, # type: ignore + batch_size: int, + seq_len: int, + total_len: int, + clamp_len: Optional[int] = None, + attn_type: str = 'bi', + bi_data: bool = True) -> torch.Tensor: + if attn_type == 'bi': + start, end = total_len, -seq_len + elif attn_type == 'uni': + start, end = total_len, -1 + else: + raise ValueError(f"Unknown `attn_type` {attn_type}") + + if bi_data: + if batch_size % 2 != 0: + raise ValueError("`batch_size` must be an even number") + fwd_pos_embed = self._create_positional_embedding( + start, end, -1, batch_size // 2, clamp_len) + bwd_pos_embed = self._create_positional_embedding( + -start, -end, 1, batch_size // 2, clamp_len) + pos_embed = torch.cat([fwd_pos_embed, bwd_pos_embed], dim=1) + else: 
+ pos_embed = self._create_positional_embedding( + start, end, -1, batch_size, clamp_len) + return pos_embed + + +class RelativeMultiheadAttention(ModuleBase): + def __init__(self, + r_r_bias: Optional[nn.Parameter] = None, + r_w_bias: Optional[nn.Parameter] = None, + r_s_bias: Optional[nn.Parameter] = None, + hparams=None): + super().__init__(hparams) + + self.num_heads = self._hparams.num_heads + self.head_dim = self._hparams.head_dim + hidden_dim = self._hparams.hidden_dim + + self.head_projection = nn.Linear( + hidden_dim, 3 * self.num_heads * self.head_dim, bias=False) + self.pos_projection = nn.Linear( + hidden_dim, self.num_heads * self.head_dim, bias=False) + + self.dropout = nn.Dropout(self._hparams.dropout) + self.dropout_attn = nn.Dropout(self._hparams.attention_dropout) + self.output_projection = nn.Linear( + self.num_heads * self.head_dim, hidden_dim, bias=False) + + bias_shape = (self.num_heads, self.head_dim) + self.untie_r = r_r_bias is None + self.r_r_bias = (r_r_bias if r_r_bias is not None + else nn.Parameter(torch.Tensor(*bias_shape))) + self.r_w_bias = (r_w_bias if r_w_bias is not None + else nn.Parameter(torch.Tensor(*bias_shape))) + + if self._hparams.use_segments: + self.segment_embed = nn.Parameter(torch.Tensor( + 2, self.num_heads, self.head_dim)) + self.r_s_bias = (r_s_bias if r_s_bias is not None + else nn.Parameter(torch.Tensor(*bias_shape))) + + self.layer_norm = nn.LayerNorm(hidden_dim, eps=1e-12) + + self.scale = 1 / (self.head_dim ** 0.5) + self.reset_parameters() + + def reset_parameters(self): + if self.untie_r: + nn.init.normal_(self.r_w_bias, 0.0, 0.02) + nn.init.normal_(self.r_r_bias, 0.0, 0.02) + if self._hparams.use_segments: + nn.init.normal_(self.segment_embed, 0.0, 0.02) + if self.untie_r: + nn.init.normal_(self.r_s_bias, 0.0, 0.02) + + @staticmethod + def default_hparams() -> Dict[str, Any]: + return { + "num_heads": 12, + "hidden_dim": 768, + "head_dim": 64, + "dropout": 0.1, + "attention_dropout": 0.1, + "use_segments": True, + } + + @staticmethod + def _rel_shift(x: torch.Tensor, klen: int) -> torch.Tensor: + shape = x.size() + x = x.view(shape[1], shape[0], *shape[2:])[1:] + x = x.view(shape[0], shape[1] - 1, *shape[2:])[:, :klen] + return x + + def _compute_attention_score(self, + q_head: torch.Tensor, + k_head_h: torch.Tensor, + v_head_h: torch.Tensor, + k_head_r: torch.Tensor, + segment_mat: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None) \ + -> torch.Tensor: + # Content based attention score. + q_head_rw = q_head + self.r_w_bias + # attn_ac: (seq_len, tot_len, batch_size, n_head) + attn_ac = torch.einsum('ibnd,jbnd->ijbn', [q_head_rw, k_head_h]) + + # Position based attention score. + q_head_rr = q_head + self.r_r_bias + # attn_bd: (seq_len, tot_len, batch_size, n_head) + attn_bd = torch.einsum('ibnd,jbnd->ijbn', [q_head_rr, k_head_r]) + attn_bd = self._rel_shift(attn_bd, klen=attn_ac.size(1)) + + # Segment based attention score. + if segment_mat is None: + attn_ef = 0 + else: + q_head_rs = q_head + self.r_s_bias + attn_ef = torch.einsum( + 'ibnd,snd->ibns', [q_head_rs, self.segment_embed]) + attn_ef = torch.einsum('ijbs,ibns->ijbn', [segment_mat, attn_ef]) + + # Merge attention scores and perform masking. 
+ # attn_score: (seq_len, tot_len, batch_size, n_head) + attn_score = attn_ac + attn_bd + attn_ef + attn_score.mul_(self.scale) + if attn_mask is not None: + if attn_mask.dim() == 2: + attn_mask = attn_mask[None, :, :, None] + elif attn_mask.dim() == 3: + attn_mask = attn_mask[:, :, :, None] + attn_score = attn_score.float().masked_fill( + attn_mask, -1e30).type_as(attn_score) + + # Compute attention probability. + # attn_prob: (seq_len, tot_len, batch_size, n_head) + attn_prob = F.softmax(attn_score, dim=1) + attn_prob = self.dropout_attn(attn_prob) + + # Compute attention vector. + attn_vec = torch.einsum('ijbn,jbnd->ibnd', [attn_prob, v_head_h]) + return attn_vec.contiguous() + + def _post_attention(self, attn_vec: torch.Tensor) -> torch.Tensor: + attn_vec = attn_vec.view(*attn_vec.size()[:2], -1) + attn_out = self.output_projection(attn_vec) + attn_out = self.dropout(attn_out) + return attn_out + + def forward(self, # type: ignore + states_h: torch.Tensor, + pos_embed: torch.Tensor, + states_g: Optional[torch.Tensor] = None, + segment_mat: Optional[torch.Tensor] = None, + attn_mask_h: Optional[torch.Tensor] = None, + attn_mask_g: Optional[torch.Tensor] = None, + target_mapping: Optional[torch.Tensor] = None, + memory: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + seq_len, batch_size = states_h.size()[:2] + pos_len = pos_embed.size(0) + + if memory is not None and memory.dim() > 1: + concat_input = torch.cat([memory, states_h], dim=0) + else: + concat_input = states_h + + # Content heads. + heads = self.head_projection(concat_input) + q_head_h, k_head_h, v_head_h = torch.chunk(heads, 3, dim=-1) + q_head_h = q_head_h[-seq_len:] + tot_len = k_head_h.size(0) + + q_head_h = q_head_h.view( + seq_len, batch_size, self.num_heads, self.head_dim) + k_head_h = k_head_h.view( + tot_len, batch_size, self.num_heads, self.head_dim) + v_head_h = v_head_h.view( + tot_len, batch_size, self.num_heads, self.head_dim) + + # Positional heads. + k_head_r = self.pos_projection(pos_embed) + k_head_r = k_head_r.view( + pos_len, batch_size, self.num_heads, self.head_dim) + + # Core attention ops. + attn_vec_h = self._compute_attention_score( + q_head_h, k_head_h, v_head_h, k_head_r, + segment_mat, attn_mask_h) + + # Post attention processing. + attn_out_h = self._post_attention(attn_vec_h) + # residual + layer norm + output_h = self.layer_norm(states_h + attn_out_h) + + if states_g is not None: + proj_dim = self.num_heads * self.head_dim + proj_weight = self.head_projection.weight[:proj_dim] + q_head_g = F.linear(states_g, proj_weight) + q_head_g = q_head_g.view( + q_head_g.size(0), batch_size, self.num_heads, self.head_dim) + if target_mapping is not None: + q_head_g = torch.einsum( + 'mbnd,mlb->lbnd', [q_head_g, target_mapping]) + attn_vec_g = self._compute_attention_score( + q_head_g, k_head_h, v_head_h, k_head_r, + segment_mat, attn_mask_g) + if target_mapping is not None: + attn_vec_g = torch.einsum( + 'lbnd,mlb->mbnd', [attn_vec_g, target_mapping]) + attn_out_g = self._post_attention(attn_vec_g) + output_g = self.layer_norm(states_g + attn_out_g) + else: + output_g = None + + return output_h, output_g diff --git a/texar/modules/pretrained/xlnet_model_utils_test.py b/texar/modules/pretrained/xlnet_model_utils_test.py new file mode 100644 index 000000000..54b4db346 --- /dev/null +++ b/texar/modules/pretrained/xlnet_model_utils_test.py @@ -0,0 +1,87 @@ +""" +Unit tests for XLNet model utils. 
+""" + +import unittest + +import torch + +from texar.modules.pretrained.xlnet_model_utils import * + + +class XLNetModelUtilsTest(unittest.TestCase): + r"""Tests XLNet model utils. + """ + + def test_PositionWiseFF(self): + + # Case 1 + model = PositionWiseFF() + inputs = torch.rand(32, model._hparams.hidden_dim) + outputs = model(inputs) + self.assertEqual(outputs.shape, torch.Size([32, + model._hparams.hidden_dim])) + + # Case 2 + hparams = { + "hidden_dim": 16, + "ffn_inner_dim": 32, + "dropout": 0.1, + "activation": 'relu', + } + model = PositionWiseFF(hparams=hparams) + inputs = torch.rand(32, 16) + outputs = model(inputs) + self.assertEqual(outputs.shape, torch.Size([32, 16])) + + # Case 3 + hparams = { + "hidden_dim": 16, + "ffn_inner_dim": 32, + "dropout": 0.1, + "activation": 'gelu', + } + model = PositionWiseFF(hparams=hparams) + inputs = torch.rand(32, 16) + outputs = model(inputs) + self.assertEqual(outputs.shape, torch.Size([32, 16])) + + def test_RelativeMultiheadAttention(self): + + model = RelativeMultiheadAttention() + + states_h = torch.rand(16, 32, model._hparams.hidden_dim) + pos_embed = torch.rand(24, 32, model._hparams.hidden_dim) + + output_h, output_g = model(states_h=states_h, pos_embed=pos_embed) + + self.assertEqual(output_h.shape, + torch.Size([16, 32, model._hparams.hidden_dim])) + self.assertEqual(output_g, None) + + def test_RelativePositionalEncoding(self): + + batch_size = 16 + seq_len = 8 + total_len = 32 + + # Case 1 + model = RelativePositionalEncoding() + pos_embed = model(batch_size=batch_size, + seq_len=seq_len, + total_len=total_len) + self.assertEqual(pos_embed.shape, + torch.Size([40, 16, model._hparams.dim])) + + # Case 2 + model = RelativePositionalEncoding() + pos_embed = model(batch_size=batch_size, + seq_len=seq_len, + total_len=total_len, + attn_type='uni') + self.assertEqual(pos_embed.shape, + torch.Size([33, 16, model._hparams.dim])) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/modules/pretrained/xlnet_utils.py b/texar/modules/pretrained/xlnet_utils.py new file mode 100644 index 000000000..98c88ffd7 --- /dev/null +++ b/texar/modules/pretrained/xlnet_utils.py @@ -0,0 +1,219 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utils of XLNet Modules. 
+""" + +from typing import Callable, Dict, Iterable, List, Optional, Union + +import itertools + +import json +import os +import numpy as np + +import torch +import torch.nn as nn + +from texar.data.data_utils import maybe_download +from texar.modules.pretrained.xlnet_model_utils import \ + (PositionWiseFF, RelativeMultiheadAttention) +from texar.modules.pretrained.pretrained_utils import default_download_dir + + +__all__ = [ + "init_xlnet_checkpoint", + "load_pretrained_xlnet", + "transform_xlnet_to_texar_config", + "params_except_in", +] + + +_XLNET_PATH = "https://storage.googleapis.com/xlnet/released_models/" +_MODEL2URL = { + 'xlnet-base-cased': + _XLNET_PATH + "cased_L-12_H-768_A-12.zip", + 'xlnet-large-cased': + _XLNET_PATH + "cased_L-24_H-1024_A-16.zip", +} + + +def init_xlnet_checkpoint(model: nn.Module, cache_dir: str): + r"""Initializes XLNet model parameters from a checkpoint provided by Google. + """ + # remember to call .contiguous after trans_fn + import tensorflow as tf + ckpt = tf.train.load_checkpoint(os.path.join(cache_dir, 'xlnet_model.ckpt')) + from_params: Dict[str, np.ndarray] = { + key: ckpt.get_tensor(key) + for key in ckpt.get_variable_to_shape_map().keys()} + del from_params["global_step"] # useless variable + to_params: Dict[str, nn.Parameter] = dict(model.named_parameters()) + + def get_weight(name: str) -> torch.Tensor: + weight = from_params["model/" + name] + del from_params["model/" + name] + return torch.from_numpy(weight) + + TransFn = Callable[[torch.Tensor], torch.Tensor] + + def assign(param: nn.Parameter, weight: Union[str, torch.Tensor], + trans_fn: Optional[TransFn] = None, allow_fail: bool = False): + param_key = next(k for k, v in to_params.items() if v is param) + del to_params[param_key] # delete regardless of whether weight exists + if isinstance(weight, str): + try: + weight = get_weight(weight) + except KeyError: + if allow_fail: + print(f"Weight {weight} not found in checkpoint") + return + else: + raise + if trans_fn is not None: + weight = trans_fn(weight).contiguous() + if param.size() != weight.size(): + raise ValueError(f"Expected size {param.size()}, " + f"actual size {weight.size()}") + param.data = weight + + def assign_linear(linear: nn.Linear, prefix: str): + trans_fn = lambda p: p.view(p.size(0), -1).t() + assign(linear.weight, prefix + "kernel", trans_fn) + if linear.bias is not None: + assign(linear.bias, prefix + "bias") + + def assign_layer_norm(layer_norm: nn.LayerNorm, prefix: str): + assign(layer_norm.weight, prefix + "LayerNorm/gamma") + assign(layer_norm.bias, prefix + "LayerNorm/beta") + + def load_xlnet_model(xlnet): + n_layers = len(xlnet.attn_layers) + for bias_name in ['r_r_bias', 'r_w_bias', 'r_s_bias']: + weight = get_weight("transformer/" + bias_name) + if xlnet.hparams.untie_r: + for idx in range(n_layers): + layer: RelativeMultiheadAttention = xlnet.attn_layers[idx] + assign(getattr(layer, bias_name), weight[idx]) + else: + assign(getattr(xlnet, bias_name), weight) + assign(xlnet.word_embed.weight, + "transformer/word_embedding/lookup_table") + + for idx in range(n_layers): + layer: RelativeMultiheadAttention = xlnet.attn_layers[idx] + prefix = f"transformer/layer_{idx}/rel_attn/" + qkv_weights = [get_weight(prefix + f"{part}/kernel") + for part in "qkv"] + assign(layer.head_projection.weight, + torch.cat([ + p.view(p.size(0), -1) for p in qkv_weights + ], dim=1).t()) + assign_linear(layer.pos_projection, prefix + "r/") + assign(layer.output_projection.weight, # DO NOT TRANSPOSE THIS!!!! 
+ prefix + "o/kernel", lambda p: p.view(p.size(0), -1)) + assign_layer_norm(layer.layer_norm, prefix) + + for idx in range(n_layers): + layer: PositionWiseFF = xlnet.ff_layers[idx] + prefix = f"transformer/layer_{idx}/ff/" + for linear_idx in range(1, 2 + 1): + linear_prefix = f"{prefix}layer_{linear_idx}/" + linear_layer: nn.Linear = getattr(layer, f"linear{linear_idx}") + assign_linear(linear_layer, linear_prefix) + assign_layer_norm(layer.layer_norm, prefix) + + seg_embeds = [ + p.squeeze(0) + for p in torch.chunk( + get_weight("transformer/seg_embed"), n_layers, dim=0)] + for idx in range(n_layers): + assign(xlnet.attn_layers[idx].segment_embed, seg_embeds[idx]) + + if hasattr(xlnet, 'mask_emb') and hasattr(xlnet, 'lm_bias'): + assign(xlnet.mask_emb, "transformer/mask_emb/mask_emb") + assign(xlnet.lm_bias, "lm_loss/bias") + + load_xlnet_model(model) + + if len(from_params) > 0: + print(f"WARNING: Certain weights from checkpoint are not loaded: " + f"{list(from_params.keys())}") + + filtered_to_params = [k for k in to_params if k.startswith("xlnet")] + if len(filtered_to_params) > 0: + print(f"WARNING: Certain parameters are not initialized: " + f"{list(filtered_to_params)}") + + +def load_pretrained_xlnet(pretrained_model_name: str, + cache_dir: Optional[str] = None) -> str: + r"""Return the directory in which the pretrained `XLNet` is cached. + """ + if pretrained_model_name in _MODEL2URL: + download_path = _MODEL2URL[pretrained_model_name] + else: + raise ValueError( + "Pre-trained model not found: {}".format(pretrained_model_name)) + + if cache_dir is None: + cache_dir = default_download_dir("xlnet") + + file_name = download_path.split('/')[-1] + + cache_path = os.path.join(cache_dir, 'xlnet_' + file_name.split('.')[0]) + if not os.path.exists(cache_path): + maybe_download(download_path, cache_dir, extract=True) + else: + print("Using cached pre-trained XLNet model from: %s." % cache_path) + + return cache_path + + +def transform_xlnet_to_texar_config(cache_dir: str) -> Dict: + r"""Load the Json config file and transform it into Texar style + configuration. + """ + info = list(os.walk(cache_dir)) + root, _, files = info[0] + config_path = None + for file in files: + if file.endswith('config.json'): + config_path = os.path.join(root, file) + if config_path is None: + raise ValueError("Cannot find the config file in {}".format(cache_dir)) + + with open(config_path) as f: + config_ckpt = json.loads(f.read()) + + configs = {} + configs["head_dim"] = config_ckpt["d_head"] + configs["ffn_inner_dim"] = config_ckpt["d_inner"] + configs["hidden_dim"] = config_ckpt["d_model"] + configs["activation"] = config_ckpt["ff_activation"] + configs["num_heads"] = config_ckpt["n_head"] + configs["num_layers"] = config_ckpt["n_layer"] + configs["vocab_size"] = config_ckpt["n_token"] + configs["untie_r"] = config_ckpt["untie_r"] + + return configs + + +def params_except_in(module: nn.Module, + except_names: List[str]) \ + -> Iterable[nn.Parameter]: + return itertools.chain.from_iterable( + child.parameters() for name, child in + module.named_children() + if name not in except_names) diff --git a/texar/modules/pretrained/xlnet_utils_test.py b/texar/modules/pretrained/xlnet_utils_test.py new file mode 100644 index 000000000..ed6aeb935 --- /dev/null +++ b/texar/modules/pretrained/xlnet_utils_test.py @@ -0,0 +1,43 @@ +""" +Unit tests for XLNet utils. +""" + +import os +import unittest + +from texar.modules.pretrained.xlnet_utils import * + + +class XLNetUtilsTest(unittest.TestCase): + r"""Tests XLNet utils. 
+ """ + + def test_load_pretrained_xlnet_AND_transform_xlnet_to_texar_config(self): + + pretrained_model_dir = load_pretrained_xlnet( + pretrained_model_name="xlnet-base-cased") + + info = list(os.walk(pretrained_model_dir)) + _, _, files = info[0] + self.assertIn('spiece.model', files) + self.assertIn('xlnet_model.ckpt.meta', files) + self.assertIn('xlnet_model.ckpt.data-00000-of-00001', files) + self.assertIn('xlnet_model.ckpt.index', files) + self.assertIn('xlnet_config.json', files) + + model_config = transform_xlnet_to_texar_config(pretrained_model_dir) + + exp_config = {'head_dim': 64, + 'ffn_inner_dim': 3072, + 'hidden_dim': 768, + 'activation': 'gelu', + 'num_heads': 12, + 'num_layers': 12, + 'vocab_size': 32000, + 'untie_r': True} + + self.assertDictEqual(model_config, exp_config) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/modules/regressors/__init__.py b/texar/modules/regressors/__init__.py new file mode 100644 index 000000000..b687d161b --- /dev/null +++ b/texar/modules/regressors/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Modules of Texar library regressors. +""" + +from texar.modules.regressors.regressor_base import * +from texar.modules.regressors.xlnet_regressor import * diff --git a/texar/modules/regressors/regressor_base.py b/texar/modules/regressors/regressor_base.py new file mode 100644 index 000000000..cfeb6d49a --- /dev/null +++ b/texar/modules/regressors/regressor_base.py @@ -0,0 +1,37 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base class for regressors. +""" +from abc import ABC +from typing import Any, Dict + +from texar.module_base import ModuleBase + +__all__ = [ + "RegressorBase", +] + + +class RegressorBase(ModuleBase, ABC): + r"""Base class inherited by all regressor classes. + """ + + @staticmethod + def default_hparams() -> Dict[str, Any]: + r"""Returns a dictionary of hyperparameters with default values. + """ + return { + "name": "regressor" + } diff --git a/texar/modules/regressors/xlnet_regressor.py b/texar/modules/regressors/xlnet_regressor.py new file mode 100644 index 000000000..a43597bd9 --- /dev/null +++ b/texar/modules/regressors/xlnet_regressor.py @@ -0,0 +1,242 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +XLNet Regressors. +""" + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from texar.hyperparams import HParams +from texar.modules.regressors.regressor_base import RegressorBase +from texar.modules.encoders.xlnet_encoder import XLNetEncoder +from texar.modules.pretrained.xlnet_utils import params_except_in +from texar.utils.utils import dict_fetch + + +__all__ = [ + "XLNetRegressor", +] + + +class XLNetRegressor(RegressorBase): + r"""Regressor based on XLNet modules. + + Arguments are the same as in + :class:`~texar.modules.XLNetEncoder`. + + Args: + pretrained_model_name (optional): a str with the name + of a pre-trained model to load selected in the list of: + `xlnet-base-cased`, `xlnet-large-cased`. + If `None`, will use the model name in :attr:`hparams`. + cache_dir (optional): the path to a folder in which the + pre-trained models will be cached. If `None` (default), + a default directory will be used. + hparams (dict or HParams, optional): Hyperparameters. Missing + hyperparameters will be set to default values. See + :meth:`default_hparams` for the hyperparameter structure + and default values. + """ + + def __init__(self, + pretrained_model_name: Optional[str] = None, + cache_dir: Optional[str] = None, + hparams=None): + + super().__init__(hparams=hparams) + + # Create the underlying encoder + encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams()) + + self._encoder = XLNetEncoder( + pretrained_model_name=pretrained_model_name, + cache_dir=cache_dir, + hparams=encoder_hparams) + + if self._hparams.use_projection: + if self._hparams.regr_strategy == 'all_time': + self.projection = nn.Linear( + self._encoder.output_size * self._hparams.max_seq_length, + self._encoder.output_size * self._hparams.max_seq_length) + else: + self.projection = nn.Linear(self._encoder.output_size, + self._encoder.output_size) + self.dropout = nn.Dropout(self._hparams.dropout) + + logit_kwargs = self._hparams.logit_layer_kwargs + if logit_kwargs is None: + logit_kwargs = {} + elif not isinstance(logit_kwargs, HParams): + raise ValueError("hparams['logit_layer_kwargs'] " + "must be a dict.") + else: + logit_kwargs = logit_kwargs.todict() + + if self._hparams.regr_strategy == 'all_time': + self.hidden_to_logits = nn.Linear( + self._encoder.output_size * self._hparams.max_seq_length, + 1, **logit_kwargs) + else: + self.hidden_to_logits = nn.Linear( + self._encoder.output_size, 1, **logit_kwargs) + + @staticmethod + def default_hparams() -> Dict[str, Any]: + r"""Returns a dictionary of hyperparameters with default values. + + .. code-block:: python + + { + # (1) Same hyperparameters as in XLNetEncoder + ... + # (2) Additional hyperparameters + "regr_strategy": "cls_time", + "use_projection": True, + "logit_layer_kwargs": None, + "name": "xlnet_regressor", + } + + Here: + + 1. Same hyperparameters as in + :class:`~texar.modules.XLNetEncoder`. + See the :meth:`~texar.modules.XLNetEncoder.default_hparams`. + An instance of XLNetEncoder is created for feature extraction. + + 2. 
Additional hyperparameters:
+
+            `"regr_strategy"`: str
+                The regression strategy, one of:
+
+                - **cls_time**: Sequence-level regression based on the
+                  output of the first time step (which is the `CLS` token).
+                  Each sequence has a prediction.
+                - **all_time**: Sequence-level regression based on the
+                  output of all time steps. Each sequence has a prediction.
+                - **time_wise**: Step-wise regression, i.e., a prediction is
+                  made for each time step based on its output.
+
+            `"logit_layer_kwargs"`: dict
+                Keyword arguments for the logit :torch_nn:`Linear` layer
+                constructor. Ignored if no extra logit layer is appended.
+
+            `"use_projection"`: bool
+                If `True`, an additional :torch_nn:`Linear` layer is added
+                after the summary step.
+
+            `"name"`: str
+                Name of the regressor.
+        """
+
+        hparams = XLNetEncoder.default_hparams()
+        hparams.update({
+            "regr_strategy": "cls_time",
+            "use_projection": True,
+            "logit_layer_kwargs": None,
+            "name": "xlnet_regressor",
+        })
+        return hparams
+
+    def param_groups(self,
+                     lr: Optional[float] = None,
+                     lr_layer_scale: float = 1.0,
+                     decay_base_params: bool = False):
+        r"""Create parameter groups for optimizers. When
+        :attr:`lr_layer_scale` is not 1.0, parameters from each layer form
+        separate groups with different base learning rates.
+
+        Args:
+            lr (float): The learning rate. Can be omitted if
+                :attr:`lr_layer_scale` is 1.0.
+            lr_layer_scale (float): Per-layer LR scaling rate. The `i`-th
+                layer will be scaled by
+                `lr_layer_scale ^ (num_layers - i - 1)`.
+            decay_base_params (bool): If `True`, treat non-layer parameters
+                (e.g. embeddings) as if they're in layer 0. If `False`,
+                these parameters are not scaled.
+
+        Returns:
+            The parameter groups, used as the first argument for optimizers.
+        """
+
+        if lr_layer_scale != 1.0:
+            if lr is None:
+                raise ValueError(
+                    "lr must be specified when lr_layer_scale is not 1.0")
+
+            fine_tune_group = {
+                "params": params_except_in(self, ["_encoder"]),
+                "lr": lr
+            }
+            param_groups = [fine_tune_group]
+            param_group = self._encoder.param_groups(lr, lr_layer_scale,
+                                                     decay_base_params)
+            param_groups.extend(param_group)
+        else:
+            param_groups = self.parameters()
+        return param_groups
+
+    def forward(self,  # type: ignore
+                token_ids: torch.LongTensor,
+                segment_ids: Optional[torch.LongTensor] = None,
+                input_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        r"""Feeds the inputs through the network and makes regression
+        predictions.
+
+        Args:
+            token_ids: Shape `[batch_size, max_time]`.
+            segment_ids: Shape `[batch_size, max_time]`.
+            input_mask: Float tensor of shape `[batch_size, max_time]`. Note
+                that positions with value 1 are masked out.
+
+        Returns:
+            Regression predictions.
+
+            - If ``regr_strategy`` is ``cls_time`` or ``all_time``,
+              predictions have shape `[batch_size]`.
+
+            - If ``regr_strategy`` is ``time_wise``, predictions have shape
+              `[batch_size, max_time]`.
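+
+        Example:
+            A minimal usage sketch (the random inputs, shapes, and the
+            random-initialization hparams below are illustrative assumptions,
+            not part of the module's API):
+
+            .. code-block:: python
+
+                regressor = XLNetRegressor(
+                    hparams={"pretrained_model_name": None})
+                token_ids = torch.randint(32000, (16, 8), dtype=torch.int64)
+                preds = regressor(token_ids)
+                # With the default "cls_time" strategy, preds has shape [16].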
+ """ + # output: [batch_size, seq_len, hidden_dim] + output, _ = self._encoder(token_ids=token_ids, + segment_ids=segment_ids, + input_mask=input_mask) + + strategy = self._hparams.regr_strategy + if strategy == 'time_wise': + summary = output + elif strategy == 'cls_time': + summary = output[:, -1] + elif strategy == 'all_time': + length_diff = self._hparams.max_seq_length - token_ids.shape[1] + summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0]) + summary_input_dim = (self._encoder.output_size * + self._hparams.max_seq_length) + + summary = summary_input.contiguous().view(-1, summary_input_dim) + else: + raise ValueError('Unknown regression strategy: {}'.format( + strategy)) + + if self._hparams.use_projection: + summary = torch.tanh(self.projection(summary)) + + summary = self.dropout(summary) + + preds = self.hidden_to_logits(summary).squeeze(-1) + + return preds diff --git a/texar/modules/regressors/xlnet_regressor_test.py b/texar/modules/regressors/xlnet_regressor_test.py new file mode 100644 index 000000000..06b9210be --- /dev/null +++ b/texar/modules/regressors/xlnet_regressor_test.py @@ -0,0 +1,107 @@ +""" +Unit tests for XLNet regressor. +""" + +import unittest + +import torch + +from texar.modules.regressors.xlnet_regressor import * + + +class XLNetRegressorTest(unittest.TestCase): + r"""Tests :class:`~texar.modules.XLNetRegressor` class. + """ + + @unittest.skip("Manual test only") + def test_model_loading(self): + r"""Tests model loading functionality.""" + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + regressor = XLNetRegressor(pretrained_model_name="xlnet-base-cased") + _ = regressor(inputs) + + # case 2 + regressor = XLNetRegressor(pretrained_model_name="xlnet-large-cased") + _ = regressor(inputs) + + def test_trainable_variables(self): + r"""Tests the functionality of automatically collecting trainable + variables. + """ + inputs = torch.zeros(32, 16, dtype=torch.int64) + + # case 1 + hparams = { + "pretrained_model_name": None + } + regressor = XLNetRegressor(hparams=hparams) + _ = regressor(inputs) + self.assertEqual(len(regressor.trainable_variables), 182 + 4) + + # case 2 + hparams = { + "pretrained_model_name": None, + "use_projection": False + } + regressor = XLNetRegressor(hparams=hparams) + _ = regressor(inputs) + self.assertEqual(len(regressor.trainable_variables), 182 + 2) + + # case 3 + hparams = { + "pretrained_model_name": None, + "regr_strategy": "all_time", + "max_seq_length": 8 + } + regressor = XLNetRegressor(hparams=hparams) + _ = regressor(inputs) + self.assertEqual(len(regressor.trainable_variables), 182 + 4) + + # case 4 + hparams = { + "pretrained_model_name": None, + "regr_strategy": "time_wise" + } + regressor = XLNetRegressor(hparams=hparams) + _ = regressor(inputs) + self.assertEqual(len(regressor.trainable_variables), 182 + 4) + + def test_regression(self): + r"""Tests regression. 
+ """ + max_time = 8 + batch_size = 16 + inputs = torch.randint(32000, (batch_size, max_time), dtype=torch.int64) + + # case 1 + hparams = { + "pretrained_model_name": None + } + regressor = XLNetRegressor(hparams=hparams) + preds = regressor(inputs) + self.assertEqual(preds.shape, torch.Size([batch_size])) + + # case 2 + hparams = { + "pretrained_model_name": None, + "regr_strategy": "all_time", + "max_seq_length": max_time + } + regressor = XLNetRegressor(hparams=hparams) + preds = regressor(inputs) + self.assertEqual(preds.shape, torch.Size([batch_size])) + + # case 3 + hparams = { + "pretrained_model_name": None, + "regr_strategy": "time_wise" + } + regressor = XLNetRegressor(hparams=hparams) + preds = regressor(inputs) + self.assertEqual(preds.shape, torch.Size([batch_size, max_time])) + + +if __name__ == "__main__": + unittest.main() diff --git a/texar/utils/utils.py b/texar/utils/utils.py index eec17e46b..eff26d4a9 100644 --- a/texar/utils/utils.py +++ b/texar/utils/utils.py @@ -65,6 +65,7 @@ 'default_str', 'uniquify_str', 'ceildiv', + 'sum_tensors', ] T = TypeVar('T') # type argument @@ -1093,3 +1094,22 @@ def ceildiv(a: int, b: int) -> int: int: The quotient, rounded up. """ return -(-a // b) + + +def sum_tensors(xs: List[Optional[torch.Tensor]]) -> Optional[torch.Tensor]: + r"""Sum a list of tensors with possible `None` values. + + Args: + xs: A list of tensors. + + Returns: + The summation of all the elements in the list. + """ + idx = next((idx for idx, tensor in enumerate(xs) if tensor is not None), -1) + if idx == -1: + return None + ret = xs[idx] + for tensor in xs[(idx + 1):]: + if tensor is not None: + ret = ret + tensor + return ret diff --git a/texar/utils/utils_test.py b/texar/utils/utils_test.py index 3179e33a8..9aa142791 100644 --- a/texar/utils/utils_test.py +++ b/texar/utils/utils_test.py @@ -154,6 +154,20 @@ def test_uniquify_str(self): unique_str = utils.uniquify_str('str', str_set) self.assertEqual(unique_str, 'str_3') + def test_sum_tensors(self): + + inputs = [torch.tensor(1), torch.tensor(2)] + self.assertEqual(utils.sum_tensors(inputs), torch.tensor(3)) + + inputs = [torch.tensor(1), None, torch.tensor(2)] + self.assertEqual(utils.sum_tensors(inputs), torch.tensor(3)) + + inputs = [torch.tensor(1), None, None] + self.assertEqual(utils.sum_tensors(inputs), torch.tensor(1)) + + inputs = [None, None, None] + self.assertEqual(utils.sum_tensors(inputs), None) + # def test_map_ids_to_strs(self): # """Tests :func:`texar.utils.map_ids_to_strs`. # """