Resolve #63 seq2seq_attn example polish & Resolve #55 AttentionRNNDecoder interface refactor (#85)

* Polish seq2seq_attn & Refactor AttentionRNNDecoder interface
gpengzhi committed Jul 8, 2019
1 parent fc99e0d commit 6552f2f
Showing 11 changed files with 269 additions and 91 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -222,6 +222,7 @@ docs/_build
 
 ### Project ###
 /data/
+/texar_download/
 checkpoints/
 /language_models/
 /examples/language_model_ptb/simple-examples/
2 changes: 1 addition & 1 deletion examples/seq2seq_attn/config_iwslt14.py
@@ -1,5 +1,5 @@
 num_epochs = 15
-display = 1
+display = 50
 
 source_vocab_file = './data/iwslt14/vocab.de'
 target_vocab_file = './data/iwslt14/vocab.en'
5 changes: 3 additions & 2 deletions examples/seq2seq_attn/config_model.py
@@ -2,7 +2,7 @@
 # Hyperparameters not specified here will take the default values.
 
 num_units = 256
-# beam_width = 10 # TODO: implement beam search in the decoding
+beam_width = 10
 
 embedder = {
     'dim': num_units
@@ -27,7 +27,8 @@
             'num_units': num_units,
         },
         'attention_layer_size': num_units
-    }
+    },
+    'max_decoding_length_infer': 60,
 }
 
 opt = {
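Note: pieced together, the two hunks above leave the decoder section of config_model.py roughly as sketched below. The nesting and the 'attention'/'decoder' key names are inferred from the hunks and from the AttentionRNNDecoder hparams shown in config_model_full.py below, so treat this as an illustrative reconstruction, not a verbatim excerpt.

num_units = 256
beam_width = 10  # now consumed by the example at inference time (see seq2seq_attn.py below)

decoder = {
    # ... other decoder hparams unchanged ...
    'attention': {
        'kwargs': {
            'num_units': num_units,
        },
        'attention_layer_size': num_units
    },
    # replaces the hard-coded max_decoding_length=60 that the example
    # previously passed to the decoder call
    'max_decoding_length_infer': 60,
}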
6 changes: 3 additions & 3 deletions examples/seq2seq_attn/config_model_full.py
@@ -4,7 +4,7 @@
 # `config_model.py`.
 
 num_units = 256
-# beam_width = 10 # TODO: implement beam search in the decoding
+beam_width = 10
 
 # --------------------- Embedder --------------------- #
 embedder = {
@@ -94,7 +94,7 @@
         "kwargs": {
             "num_units": 256,
         },
-        "attention_layer_size": None,
+        "attention_layer_size": 256,
         "alignment_history": False,
         "output_attention": True,
     },
@@ -107,7 +107,7 @@
         'kwargs': {}
     },
     'max_decoding_length_train': None,
-    'max_decoding_length_infer': None,
+    'max_decoding_length_infer': 60,
     'output_layer_bias': True,
     'name': 'attention_rnn_decoder'
 }
2 changes: 1 addition & 1 deletion examples/seq2seq_attn/config_toy_copy.py
@@ -1,5 +1,5 @@
 num_epochs = 4
-display = 1
+display = 50
 
 source_vocab_file = './data/toy_copy/train/vocab.sources.txt'
 target_vocab_file = './data/toy_copy/train/vocab.targets.txt'
68 changes: 33 additions & 35 deletions examples/seq2seq_attn/seq2seq_attn.py
@@ -18,10 +18,9 @@
 import importlib
 
 import torch
+import torch.nn as nn
 
 import texar as tx
-from texar.core.optimization import get_optimizer, get_train_op
-from texar.module_base import ModuleBase
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--config_model',
@@ -37,8 +36,10 @@
 config_model = importlib.import_module(args.config_model)
 config_data = importlib.import_module(args.config_data)
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
-class Seq2SeqAttn(ModuleBase):
+class Seq2SeqAttn(nn.Module):
 
     def __init__(self, train_data):
 
@@ -65,27 +66,24 @@ def __init__(self, train_data):
         self.decoder = tx.modules.AttentionRNNDecoder(
             encoder_output_size=(self.encoder.cell_fw.hidden_size +
                                  self.encoder.cell_bw.hidden_size),
-            input_size=self.target_embedder.dim + config_model.decoder
-            ['attention']['attention_layer_size'],
+            input_size=self.target_embedder.dim,
             vocab_size=self.target_vocab_size,
             hparams=config_model.decoder)
 
     def forward(self, batch, mode):
 
-        if mode == "train":
-            self.train()
-        else:
-            self.eval()
-
         enc_outputs, _ = self.encoder(
             inputs=self.source_embedder(batch['source_text_ids']),
             sequence_length=batch['source_length'])
 
+        memory = torch.cat(enc_outputs, dim=2)
+
         if mode == "train":
-            helper_train = self.decoder.create_helper()
+            helper_train = self.decoder.create_helper(
+                decoding_strategy="train_greedy")
 
             training_outputs, _, _ = self.decoder(
-                memory=torch.cat(enc_outputs, dim=2),
+                memory=memory,
                 memory_sequence_length=batch['source_length'],
                 helper=helper_train,
                 inputs=self.target_embedder(batch['target_text_ids'][:, :-1]),
@@ -98,43 +96,44 @@ def forward(self, batch, mode):
 
             return mle_loss
         else:
-            start_tokens = torch.ones_like(batch['target_length']) * \
-                self.bos_token_id
+            start_tokens = memory.new_full(batch['target_length'].size(),
+                                           self.bos_token_id,
+                                           dtype=torch.int64)
 
-            helper_infer = self.decoder.create_helper(
-                decoding_strategy="infer_greedy",
-                embedding=self.target_embedder,
+            infer_outputs = self.decoder(
                 start_tokens=start_tokens,
-                end_token=self.eos_token_id.item())
-
-            infer_outputs, _, _ = self.decoder(
-                helper=helper_infer,
-                memory=torch.cat(enc_outputs, dim=2),
+                end_token=self.eos_token_id.item(),
+                embedding=self.target_embedder,
+                memory=memory,
                 memory_sequence_length=batch['source_length'],
-                max_decoding_length=60)
+                beam_width=config_model.beam_width)
 
             return infer_outputs
 
 
 def main():
     """Entrypoint.
     """
-    train_data = tx.data.PairedTextData(hparams=config_data.train)
-    val_data = tx.data.PairedTextData(hparams=config_data.val)
-    test_data = tx.data.PairedTextData(hparams=config_data.test)
+    train_data = tx.data.PairedTextData(hparams=config_data.train,
+                                        device=device)
+    val_data = tx.data.PairedTextData(hparams=config_data.val,
+                                      device=device)
+    test_data = tx.data.PairedTextData(hparams=config_data.test,
+                                       device=device)
     data_iterator = tx.data.TrainTestDataIterator(
         train=train_data, val=val_data, test=test_data)
 
     model = Seq2SeqAttn(train_data)
-    optimizer = get_optimizer(model.parameters(), config_model.opt)
-    train_op = get_train_op(optimizer, config_model.opt)
+    model.to(device)
+    train_op = tx.core.get_train_op(params=model.parameters(),
+                                    hparams=config_model.opt)
 
     def _train_epoch():
         data_iterator.switch_to_train_data()
-        iterator = data_iterator.get_iterator()
+        model.train()
 
         step = 0
-        for batch in iterator:
+        for batch in data_iterator:
             loss = model(batch, mode="train")
             loss.backward()
             train_op()
@@ -145,15 +144,14 @@ def _train_epoch():
     def _eval_epoch(mode):
         if mode == 'val':
             data_iterator.switch_to_val_data()
-            iterator = data_iterator.get_iterator()
         else:
             data_iterator.switch_to_test_data()
-            iterator = data_iterator.get_iterator()
+        model.eval()
 
         refs, hypos = [], []
-        for batch in iterator:
-            infer_outputs = model(batch, mode="infer")
-            output_ids = infer_outputs.sample_id.cpu()
+        for batch in data_iterator:
+            infer_outputs = model(batch, mode="val")
+            output_ids = infer_outputs["sample_id"][:, :, 0].cpu()
             target_texts_ori = [text[1:] for text in batch['target_text']]
             target_texts = tx.utils.strip_special_tokens(
                 target_texts_ori, is_token_list=True)
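Note: assembled from the hunks above, the refactored inference path reads roughly as follows. Everything here comes from the diff itself except the shape comment on sample_id, which is inferred from the [:, :, 0] indexing in _eval_epoch; treat it as a sketch of the new AttentionRNNDecoder interface rather than a verbatim excerpt.

# Inside Seq2SeqAttn.forward(), inference branch:
memory = torch.cat(enc_outputs, dim=2)  # concatenated bidirectional encoder states

start_tokens = memory.new_full(batch['target_length'].size(),
                               self.bos_token_id, dtype=torch.int64)

# No helper object anymore: start/end tokens, embedding and beam width are
# passed directly to the decoder call, and beam search replaces greedy decoding.
infer_outputs = self.decoder(
    start_tokens=start_tokens,
    end_token=self.eos_token_id.item(),
    embedding=self.target_embedder,
    memory=memory,
    memory_sequence_length=batch['source_length'],
    beam_width=config_model.beam_width)

# In _eval_epoch, the top beam is read from the returned 'sample_id',
# which apparently has a [batch, time, beam] layout:
output_ids = infer_outputs["sample_id"][:, :, 0].cpu()

The training branch keeps the helper-based call (create_helper(decoding_strategy="train_greedy")), so only the inference side of the interface changes.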
32 changes: 26 additions & 6 deletions texar/core/optimization.py
@@ -16,9 +16,10 @@
 """
 
 import functools
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Iterable
 
 import torch
+import torch.nn as nn
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer  # pylint: disable=no-name-in-module
 
@@ -146,7 +147,7 @@ def default_optimization_hparams() -> Dict[str, Any]:
 
 
 def get_optimizer(
-        params: Union[List[torch.Tensor], List[Dict[str, List[torch.Tensor]]]],
+        params: Iterable[Union[torch.Tensor, Dict[str, Any]]],
         hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
         Optimizer:
     r"""Creates a optimizer instance.
@@ -271,14 +272,21 @@ def get_grad_clip_fn(hparams: Optional[Union[HParams,
     return grad_clip_fn
 
 
-def get_train_op(optimizer: Optimizer,
+def get_train_op(params: Optional[Iterable[Union[torch.Tensor,
+                                                 Dict[str, Any]]]] = None,
+                 optimizer: Optional[Optimizer] = None,
                  scheduler: Optional[_LRScheduler] = None,
                  hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
         Callable[[], None]:
     r"""Creates a training op.
     Args:
+        params: an iterable of :class:`torch.Tensor` or
+            :class:`dict`. Specifies what Tensors should be optimized.
         optimizer: A :torch_docs:`torch.optim.Optimizer
             <optim.html#torch.optim.Optimizer>` instance.
+        scheduler: A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
+            <optim.html#how-to-adjust-learning-rate>` instance.
         hparams (dict or HParams, optional): hyperparameters. Missing
             hyperparameters are set to default values automatically. See
             :func:`~texar.core.default_optimization_hparams` for
@@ -289,15 +297,27 @@ def get_train_op(optimizer: Optimizer,
     """
 
     hparams = HParams(hparams, default_optimization_hparams())
-    scheduler = get_scheduler(optimizer, hparams)
+    if params is None and optimizer is None and scheduler is None:
+        raise ValueError("'params', 'optimizer' and 'scheduler' must not be "
+                         "None simultaneously.")
+
+    if scheduler is None:
+        if optimizer is None and params is not None:
+            optimizer = get_optimizer(params, hparams)
+        if optimizer is not None:
+            scheduler = get_scheduler(optimizer, hparams)
+    else:
+        optimizer = scheduler.optimizer  # type: ignore
+
     grad_clip_fn = get_grad_clip_fn(hparams)
 
-    params_list = []
+    # TODO: Support per-parameter options in the future.
+    params_list: List[nn.Parameter] = []
     for param_group in optimizer.param_groups:  # type: ignore
         params = param_group["params"]
         if isinstance(params, torch.Tensor):
             params_list.append(params)
-        else:
+        elif isinstance(params, list):
             params_list += params
 
     def _train_op():
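Note: after this refactor, get_train_op builds whatever it is not given: from params it creates an optimizer, from an optimizer it creates a scheduler, and from a scheduler it reuses scheduler.optimizer. The sketch below mirrors the three cases exercised in the test diff that follows; the nn.Linear model is a placeholder, hparams=None falls back to the defaults from default_optimization_hparams(), and get_scheduler is assumed to be importable from the same module it is defined in.

import torch.nn as nn

from texar.core.optimization import (get_optimizer, get_scheduler,
                                     get_train_op)

model = nn.Linear(10, 2)  # placeholder model
hparams = None            # missing values are filled with the defaults

# Case 1: start from raw parameters; optimizer and scheduler are built internally.
train_op = get_train_op(params=model.parameters(), hparams=hparams)

# Case 2: start from an existing optimizer; only the scheduler is built internally.
optimizer = get_optimizer(model.parameters(), hparams)
train_op = get_train_op(optimizer=optimizer, hparams=hparams)

# Case 3: start from an existing scheduler; its optimizer is reused.
optimizer = get_optimizer(model.parameters(), hparams)
scheduler = get_scheduler(optimizer, hparams)
train_op = get_train_op(scheduler=scheduler, hparams=hparams)

# In every case the returned callable is used the same way:
#     loss.backward(); train_op()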
24 changes: 23 additions & 1 deletion texar/core/optimization_test.py
@@ -222,8 +222,30 @@ def test_get_train_op(self):
             "name": None
         }
 
+        # Case 1
         optimizer = get_optimizer(self.model.parameters(), hparams)
-        train_op = get_train_op(optimizer, hparams)
+        train_op = get_train_op(optimizer=optimizer, hparams=hparams)
 
+        for t in range(50):
+            y_pred = self.model(self.x)
+            loss = self.loss_fn(y_pred, self.y)
+            loss.backward()
+            train_op()
+
+        # Case 2
+        train_op = get_train_op(params=self.model.parameters(), hparams=hparams)
+
+        for t in range(50):
+            y_pred = self.model(self.x)
+            loss = self.loss_fn(y_pred, self.y)
+            loss.backward()
+            train_op()
+
+        # Case 3
+        optimizer = get_optimizer(self.model.parameters(), hparams)
+        scheduler = get_scheduler(optimizer=optimizer,
+                                  hparams=hparams)
+        train_op = get_train_op(scheduler=scheduler, hparams=hparams)
+
         for t in range(50):
             y_pred = self.model(self.x)
