Update CI config & type annotations (#103)
- Update Travis CI config to specify version of code tools
- Update code to adapt to changes introduced in mypy 0.720
- Update CI type checking config to also check examples
- Update type annotations for XLNet example
- Other minor tweaks that do not affect usage
huzecong committed Jul 15, 2019
1 parent 0ebb169 commit b23b5f3
Showing 11 changed files with 48 additions and 44 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -20,6 +20,7 @@ script:
   - flake8 texar/ examples/
   # type-checking
   - mypy .
+  - for dir in `echo examples/**/`; do mypy $dir; done
   # unit tests
   - pytest

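The new Travis step runs mypy once per example directory, presumably because each example is a standalone set of scripts rather than part of the texar package. A rough Python equivalent of the shell loop, as a sketch only (assumes it is run from the repository root):

    import subprocess
    from pathlib import Path

    # Type-check each example directory as an independent mypy run.
    for example_dir in sorted(Path("examples").iterdir()):
        if example_dir.is_dir():
            subprocess.run(["mypy", str(example_dir)], check=True)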
7 changes: 5 additions & 2 deletions examples/transformer/transformer_main.py
@@ -11,12 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 """Transformer model.
 """

 import argparse
+import functools
 import importlib
 import os
+from typing import Any

 import torch
 import tqdm
@@ -49,8 +52,8 @@

 args = parser.parse_args()

-config_model = importlib.import_module(args.config_model)
-config_data = importlib.import_module(args.config_data)
+config_model: Any = importlib.import_module(args.config_model)
+config_data: Any = importlib.import_module(args.config_data)

 utils.set_random_seed(config_model.random_seed)

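The new `Any` annotations are the usual way to reconcile mypy with dynamically imported config modules: the checker cannot know which attributes a module chosen at runtime will provide. A minimal sketch of the pattern (it imports `math` only so the snippet is runnable):

    import importlib
    from typing import Any

    # Annotating the variable as `Any` stops the type checker from
    # trying to verify attributes of a module selected at runtime.
    config: Any = importlib.import_module("math")
    print(config.pi)  # accepted by mypy even though it cannot be checked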
43 changes: 22 additions & 21 deletions examples/xlnet/xlnet/model/decoder.py
@@ -38,7 +38,7 @@ class XLNetDecoderOutput(NamedTuple):
 State = List[Tensor]


-class XLNetDecoder(XLNet, tx.modules.DecoderBase[State, Output]):
+class XLNetDecoder(XLNet, tx.modules.DecoderBase[Optional[State], Output]):
     def __init__(self, hparams=None):
         super().__init__(hparams)

@@ -66,10 +66,12 @@ def _create_input(inputs: List[Tensor], initial: bool = False) \
             seq_len += 1
         segment_ids = word_embed.new_zeros(
             seq_len, batch_size, dtype=torch.long)
-        if initial:
-            target_mapping = None
-            permute_mask = None
-        else:
+        return_dict = {
+            "word_embed": word_embed,
+            "segment_ids": segment_ids,
+        }
+
+        if not initial:
             # Only the dummy token is considered target.
             target_mapping = torch.cat([
                 torch.zeros(1, seq_len - 1, batch_size),
@@ -80,13 +82,12 @@ def _create_input(inputs: List[Tensor], initial: bool = False) \
                 torch.zeros(seq_len, seq_len - 1, batch_size),
                 torch.ones(seq_len, 1, batch_size),
             ], dim=1).to(device=word_embed.device)
+            return_dict.update({
+                "target_mapping": target_mapping,
+                "permute_mask": permute_mask,
+            })

-        return {
-            "word_embed": word_embed,
-            "segment_ids": segment_ids,
-            "target_mapping": target_mapping,
-            "permute_mask": permute_mask,
-        }
+        return return_dict

     def initialize(self,  # pylint: disable=no-self-use
                    helper: Helper, inputs: Optional[Tensor],
@@ -101,18 +102,19 @@ def step(self, helper: Helper, time: int, inputs: Tensor,
              state: Optional[State]) \
             -> Tuple[Output, Optional[State], Tensor, torch.ByteTensor]:
         self._state_previous_inputs.append(inputs)
-        if not self._state_recompute_memory:
+        if self._state_recompute_memory:
+            net_output, memory = self._forward(
+                two_stream=True,
+                **self._create_input(
+                    self._state_previous_inputs[-self._state_cache_len:]))
+        else:
+            assert state is not None
             net_output, memory = self._forward(
                 memory=state, cache_len=self._state_cache_len, two_stream=True,
                 **self._create_input(self._state_previous_inputs[-1:]))
+            assert memory is not None
             # Omit memory for the dummy token.
             memory = [mem[:-1] for mem in memory]
-        else:
-            net_output, memory = self._forward(
-                two_stream=True,
-                **self._create_input(
-                    self._state_previous_inputs[-self._state_cache_len:]))
         logits = F.linear(net_output, self.word_embed.weight, self.lm_bias)
         logits = logits[-1]
         sample_ids = helper.sample(time=time, outputs=logits)
@@ -123,15 +125,14 @@ def step(self, helper: Helper, time: int, inputs: Tensor,
         outputs = XLNetDecoderOutput(logits=logits, sample_id=sample_ids)
         return outputs, memory, next_inputs, finished

-    def finalize(self, outputs: Output, final_state: Optional[State],
-                 sequence_lengths: torch.LongTensor) \
-            -> Tuple[Output, Optional[State]]:
+    def finalize(self, outputs, final_state, sequence_lengths):
         del self._state_cache_len
         del self._state_recompute_memory
         del self._state_previous_inputs
         return super().finalize(outputs, final_state, sequence_lengths)

-    def forward(self, start_tokens: torch.LongTensor,
+    def forward(self,  # type: ignore
+                start_tokens: torch.LongTensor,
                 memory: Optional[State] = None,
                 cache_len: int = 512,
                 max_decoding_length: Optional[int] = 500,
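The `_create_input` rewrite above replaces the `target_mapping = None`/`permute_mask = None` branch with a dict that only gains those keys when real tensors exist, so no forwarded keyword argument is ever `None`. A self-contained sketch of that pattern (names are illustrative, not the Texar API):

    from typing import Any, Dict

    import torch

    def make_inputs(word_embed: torch.Tensor, initial: bool = False) -> Dict[str, Any]:
        # Keys shared by both branches are always present.
        inputs: Dict[str, Any] = {"word_embed": word_embed}
        if not initial:
            # Optional keys are added only when their values exist, so
            # `**inputs` never forwards an explicit None to the callee.
            inputs["target_mapping"] = torch.zeros(1, word_embed.size(0))
        return inputs

    print(sorted(make_inputs(torch.ones(3), initial=True)))  # ['word_embed']
    print(sorted(make_inputs(torch.ones(3))))  # ['target_mapping', 'word_embed']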
2 changes: 1 addition & 1 deletion examples/xlnet/xlnet/model/model.py
@@ -220,7 +220,7 @@ def forward(self, # type: ignore
         """
         return self._forward(self.word_embed(token_ids), *args, **kwargs)

-    def _forward(self, # type: ignore
+    def _forward(self,
                  word_embed: Tensor, segment_ids: Optional[LongTensor],
                  input_mask: Optional[Tensor] = None,
                  memory: Optional[List[Tensor]] = None,
1 change: 1 addition & 0 deletions mypy.ini
@@ -3,6 +3,7 @@
 [mypy]
 warn_unused_ignores = True
 warn_unused_configs = True
+warn_redundant_casts = True
 no_implicit_optional = True
 follow_imports = silent
 ignore_missing_imports = True
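For reference, `warn_redundant_casts` makes mypy report casts whose target type the expression already has. A tiny example of code it would flag:

    from typing import cast

    def get_name() -> str:
        name = "texar"
        # `name` is already inferred as `str`, so with
        # warn_redundant_casts = True mypy reports this cast as redundant.
        return cast(str, name)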
20 changes: 11 additions & 9 deletions stubs/torch/__init__.pyi
@@ -1,5 +1,6 @@
 import builtins
-from typing import Callable, ContextManager, Iterator, Optional, Sequence, Tuple, List, Type, TypeVar, Union, overload
+import pickle
+from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union, overload

 import numpy as np
 import torch.autograd
@@ -19,25 +20,26 @@ import torch.testing
 import torch.utils.backcompat
 from torch._tensor_str import set_printoptions
 from torch.random import get_rng_state, initial_seed, manual_seed, set_rng_state
-from torch.serialization import load, save
 from torch.storage import _StorageBase
 from torch.tensor import Tensor as TensorBase
 from torch.utils.hooks import RemovableHandle


-def no_grad() -> ContextManager[None]:
-    ...
+def load(f: Union[str, IO],
+         map_location: Optional[Union[Dict[str, str], str, torch.device, Callable[[str, str], str]]] = None,
+         pickle_module=pickle, **pickle_load_args): ...


-def enable_grad() -> ContextManager[None]:
-    ...
+def save(obj: Any, f: Union[str, IO], pickle_module=pickle, pickle_protocol=pickle.DEFAULT_PROTOCOL): ...


-def set_grad_enabled(mode: bool) -> ContextManager[None]:
-    ...
+no_grad: Any = ...
+enable_grad: Any = ...
+set_grad_enabled: Any = ...


-class device: ...
+class device:
+    def __init__(self, device: Union[builtins.int, builtins.str]): ...


 class finfo:
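Loosening `no_grad`, `enable_grad`, and `set_grad_enabled` to `Any` is presumably because these PyTorch objects are usable both as context managers and as decorators, which the old `ContextManager[None]` stubs could not express. For example:

    import torch

    # Used as a context manager...
    with torch.no_grad():
        y = torch.ones(2) * 2

    # ...and as a decorator; a single context-manager stub signature
    # cannot cover both uses, so the stub falls back to `Any`.
    @torch.no_grad()
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2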
5 changes: 3 additions & 2 deletions texar/data/data/data_base.py
@@ -608,13 +608,14 @@ def default_hparams():
             "parallelize_processing": True,
         }

-    def to(self, device: torch.device):
+    def to(self, device: Optional[torch.device]):
         r"""Move the dataset to the specific device. Note that we don't actually
         move data or do anything here --- we rely on correct implementations of
         :meth:`_process` and :meth:`_collate` to move data to appropriate
         devices.
         """
-        self.device = device
+        if device is not None:
+            self.device = device
         return self

     def _prefetch_processed(self, index: int):
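With the new signature, `to(None)` becomes a no-op instead of clobbering the stored device. A self-contained sketch of that contract (a stand-in class, not the actual Texar dataset):

    from typing import Optional

    import torch

    class DeviceRecorder:
        # Minimal stand-in for the dataset `to` contract shown above.
        def __init__(self) -> None:
            self.device: Optional[torch.device] = None

        def to(self, device: Optional[torch.device]) -> "DeviceRecorder":
            # Only record a real device; `to(None)` keeps the old value.
            if device is not None:
                self.device = device
            return self

    ds = DeviceRecorder().to(torch.device("cpu"))
    ds.to(None)
    print(ds.device)  # still cpu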
2 changes: 1 addition & 1 deletion texar/data/data/multi_aligned_data.py
@@ -379,7 +379,7 @@ def default_hparams():
         hparams["datasets"] = []
         return hparams

-    def to(self, device: torch.device):
+    def to(self, device: Optional[torch.device]):
         for dataset in self._databases:
             dataset.to(device)
         return super().to(device)
4 changes: 2 additions & 2 deletions texar/modules/__init__.py
@@ -15,10 +15,10 @@
 Modules of Texar library module.
 """

+from texar.modules.pretrained import *
 from texar.modules.classifiers import *
+from texar.modules.connectors import *
 from texar.modules.decoders import *
 from texar.modules.embedders import *
 from texar.modules.encoders import *
 from texar.modules.networks import *
-from texar.modules.connectors import *
-from texar.modules.pretrained import *
5 changes: 0 additions & 5 deletions texar/modules/classifiers/classifier_base.py
@@ -17,8 +17,6 @@
 from abc import ABC
 from typing import Any, Dict

-import torch
-
 from texar.module_base import ModuleBase

 __all__ = [
@@ -37,6 +35,3 @@ def default_hparams() -> Dict[str, Any]:
         return {
             "name": "classifier"
         }
-
-    def forward(self, *input: torch.Tensor):  # pylint: disable=redefined-builtin
-        raise NotImplementedError
2 changes: 1 addition & 1 deletion texar/modules/classifiers/conv_classifiers.py
@@ -57,7 +57,7 @@ class Conv1DClassifier(ClassifierBase):

     def __init__(self, in_channels: int, in_features: Optional[int] = None,
                  hparams: Optional[Union[HParams, Dict[str, Any]]] = None):
-        ClassifierBase.__init__(self, hparams)
+        super().__init__(hparams)

         encoder_hparams = utils.dict_fetch(hparams,
                                            Conv1DEncoder.default_hparams())
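Swapping the explicit `ClassifierBase.__init__(self, hparams)` call for `super().__init__(hparams)` matters under multiple inheritance: `super()` follows the method resolution order, while a hard-coded base call can silently skip cooperating classes. A generic illustration (not Texar code):

    class Base:
        def __init__(self) -> None:
            print("Base")

    class Mixin(Base):
        def __init__(self) -> None:
            print("Mixin")
            super().__init__()

    class Child(Mixin):
        def __init__(self) -> None:
            # super() walks the MRO: Mixin first, then Base. Calling
            # Base.__init__(self) directly would skip Mixin entirely.
            super().__init__()

    Child()  # prints "Mixin" then "Base"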
