Commit
Merge pull request #43 from haoransh/master
Modularize multihead attention
ZhitingHu committed Nov 1, 2018
2 parents 3d684fa + c46d02a commit dd972bc
Showing 9 changed files with 415 additions and 176 deletions.
49 changes: 49 additions & 0 deletions examples/transformer/README.md
@@ -86,3 +86,52 @@ python bleu_tool.py --reference=data/en_de/test.de --translation=temp/test.outpu
```
Using an Nvidia GTX 1080Ti, the model usually converges within 5 hours (~15 epochs) on IWSLT'15.


## Hands-on: run experiments on your own dataset

The following are instructions for experimenting with the Transformer on your own dataset.

### 1. Prepare your dataset

Create a directory named `data/${src_language}_${tgt_language}` and store your original parallel
dataset there. The directory should contain at least six files,
i.e., `train/dev/test.${src_language}` and `train/dev/test.${tgt_language}`.

For example, after you run `sh scripts/iwslt15_en_vi.sh`, you will find the directory `data/en_vi/`
with the six corpus files in it.

### 2. Preprocess the data

Run `preprocess_data.sh ${encoder} ${src_language} ${tgt_language}` to obtain the processed dataset.
The `encoder` parameter can be `bpe` (byte pair encoding), `spm` (SentencePiece encoding), or
`raw` (no subword encoding).

The examples above use `bpe` or `spm`. If you choose the `raw` encoding instead, note that (a conceptual sketch follows this list):

- By default, the word embedding layer is built from the combined source-language and target-language vocabularies.
- By default, the final output layer of the Transformer decoder (hidden_state -> logits) shares its parameters with the word embedding layer.
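
The second point is standard output-embedding weight tying. Below is a minimal conceptual sketch of the idea (illustrative only, not the library code; in this example the behavior is provided by `TransformerDecoder` and controlled by its `embedding_tie` hyperparameter):

```
import tensorflow as tf

# Conceptual sketch of output-embedding weight tying (illustrative only).
# decoder_output: [batch, time, dim]; embedding: [vocab_size, dim]
def tied_logits(decoder_output, embedding):
    # Reuse the embedding matrix as the output projection:
    # logits[b, t, v] = <decoder_output[b, t, :], embedding[v, :]>
    return tf.tensordot(decoder_output, embedding, axes=[[2], [1]])
```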

### 3. Define your model configuration and data configuration

Create your own Python files that define the Transformer model configuration and the data loading configuration.

You can refer to the provided templates, `config_iwslt15.py` and `config_model.py`.
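
For a rough sense of the shape of such a file, the sketch below shows a hypothetical `custom_config_data.py`. Only `vocab_file` and `max_train_epoch` are fields actually read by `transformer_main.py`; the remaining names and all values are assumptions, so take the authoritative options from the provided templates:

```
# custom_config_data.py -- hypothetical sketch; copy the authoritative option
# names and values from the provided config_iwslt15.py template instead.

input_dir = 'temp/run_en_vi_spm/data'      # assumed output location of preprocess_data.sh
vocab_file = input_dir + '/vocab.pickle'   # pickled id-to-word map loaded by transformer_main.py

batch_size = 2048        # assumed: tokens per training batch
max_train_epoch = 20     # number of training epochs (read by transformer_main.py)
display_steps = 100      # assumed: log every N steps
eval_steps = 1000        # assumed: evaluate on the dev set every N steps
```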

### 4. Train the model

Train:
```
python transformer_main.py --run_mode=train_and_evaluate --config_model=custom_config_model --config_data=custom_config_data
```

### 5. Test the model
Test:
```
python transformer_main.py --run_mode=test --config_data=custom_config_data --model_dir=./outputs
```

The inferred test samples are written to `outputs/test.output`.
If you used the `bpe` or `spm` encoder, you may want to decode the samples with the corresponding decoder first (see the sketch below).
Finally, run
`python bleu_tool.py --reference=your_reference_file --translation=your_decoded_file` to compute the BLEU score.
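
If you need to undo the subword segmentation before scoring, one possible approach is sketched below (a hypothetical helper, not part of the example scripts; it assumes the `sentencepiece` Python package for `spm` output and the usual `@@ ` continuation markers for `bpe` output):

```
# postprocess.py -- hypothetical helper, not shipped with the example.
import sentencepiece as spm

def decode_spm(in_path, out_path, model_path):
    """Join SentencePiece pieces back into plain text."""
    sp = spm.SentencePieceProcessor()
    sp.Load(model_path)  # e.g. the spm-codes.*.model produced by preprocess_data.sh
    with open(in_path) as fin, open(out_path, 'w') as fout:
        for line in fin:
            fout.write(sp.DecodePieces(line.split()) + '\n')

def decode_bpe(in_path, out_path):
    """Remove the '@@ ' continuation markers left by BPE segmentation."""
    with open(in_path) as fin, open(out_path, 'w') as fout:
        for line in fin:
            fout.write(line.replace('@@ ', ''))
```

After decoding, pass the resulting file to `bleu_tool.py` as described above.
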
11 changes: 10 additions & 1 deletion examples/transformer/preprocess_data.sh
@@ -55,9 +55,9 @@ echo "Output dir = $out"

echo "Step 1a: Preprocess inputs"

echo "Learning Word Piece or Byte Pairwise on source and target combined"
case ${encoder} in
'spm')
echo "Learning Word Piece on source and target combined"
spm_train --input=${train_src},${train_tgt} --vocab_size ${vocab_size} --model_prefix=$out/data/spm-codes.${vocab_size}
spm_encode --model $out/data/spm-codes.${vocab_size}.model --output_format=piece --infile $train_src --outfile $out/data/train.${src_language}.spm
spm_encode --model $out/data/spm-codes.${vocab_size}.model --output_format=piece --infile $valid_src --outfile $out/data/valid.${src_language}.spm
@@ -67,6 +67,7 @@ case ${encoder} in
spm_encode --model $out/data/spm-codes.${vocab_size}.model --output_format=piece --infile ${test_tgt} --outfile $out/data/test.${tgt_language}.spm
cp ${test_tgt} ${out}/test/test.${tgt_language} ;;
'bpe'):
echo "Learning Byte Pairwise on source and target combined"
cat ${train_src} ${train_tgt} | learn_bpe -s ${vocab_size} > ${out}/data/bpe-codes.${vocab_size}
apply_bpe -c ${out}/data/bpe-codes.${vocab_size} < ${train_src} > $out/data/train.${src_language}.bpe
apply_bpe -c ${out}/data/bpe-codes.${vocab_size} < ${valid_src} > ${out}/data/valid.${src_language}.bpe
@@ -75,6 +76,14 @@ case ${encoder} in
apply_bpe -c ${out}/data/bpe-codes.${vocab_size} < ${valid_tgt} > ${out}/data/valid.${tgt_language}.bpe
apply_bpe -c ${out}/data/bpe-codes.${vocab_size} < ${test_tgt} > ${out}/data/test.${tgt_language}.bpe
cp ${test_tgt} ${out}/test/test.${tgt_language} ;;
'raw'):
echo "No subword encoding is applied, just copy the corpus files into correct directory"
cp ${train_src} $out/data/train.${src_language}.raw
cp ${valid_src} $out/data/valid.${src_language}.raw
cp ${test_src} $out/data/test.${src_language}.raw
cp ${train_tgt} $out/data/train.${tgt_language}.raw
cp ${valid_tgt} $out/data/valid.${tgt_language}.raw
cp ${test_tgt} $out/data/test.${tgt_language}.raw
esac

python ${TF}/utils/preprocess.py -i ${out}/data \
8 changes: 4 additions & 4 deletions examples/transformer/transformer_main.py
@@ -28,10 +28,9 @@
from texar.modules import TransformerEncoder, TransformerDecoder
from texar.utils import transformer_utils

from utils import data_utils
from utils import utils
from utils import data_utils, utils
from bleu_tool import bleu_wrapper

from utils.preprocess import bos_token_id, eos_token_id
# pylint: disable=invalid-name, too-many-locals

flags = tf.flags
@@ -59,7 +58,6 @@ def main():
with open(config_data.vocab_file, 'rb') as f:
id2w = pickle.load(f)
vocab_size = len(id2w)
bos_token_id, eos_token_id = 1, 2

beam_width = config_model.beam_width

@@ -258,11 +256,13 @@ def _train_epoch(sess, epoch, step, smry_writer):
smry_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

if FLAGS.run_mode == 'train_and_evaluate':
logger.info('Begin running with train_and_evaluate mode')
step = 0
for epoch in range(config_data.max_train_epoch):
step = _train_epoch(sess, epoch, step, smry_writer)

elif FLAGS.run_mode == 'test':
logger.info('Begin running with test mode')
saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model_dir))
_eval_epoch(sess, 0, mode='test')

15 changes: 8 additions & 7 deletions examples/transformer/utils/preprocess.py
@@ -28,9 +28,10 @@

split_pattern = re.compile(r'([.,!?"\':;)(])')
digit_pattern = re.compile(r'\d')
Special_Seq = collections.namedtuple('Special_Seq', \
['PAD', 'BOS', 'EOS', 'UNK'])
Vocab_Pad = Special_Seq(PAD=0, BOS=1, EOS=2, UNK=3)

# Refer to https://texar.readthedocs.io/en/latest/_modules/texar/data/vocabulary.html#SpecialTokens
# these tokens will by default have token ids 0, 1, 2, 3 respectively
pad_token_id, bos_token_id, eos_token_id, unk_token_id = 0, 1, 2, 3

def split_sentence(s, tok=False):
"""split sentence with some segmentation rules."""
@@ -72,7 +73,7 @@ def count_words(path, max_vocab_size=40000, tok=False):

def make_array(word_id, words):
"""generate id numpy array from plain text words."""
ids = [word_id.get(word, Vocab_Pad.UNK) for word in words]
ids = [word_id.get(word, unk_token_id) for word in words]
return np.array(ids, 'i')

def make_dataset(path, w2id, tok=False):
@@ -84,7 +85,7 @@ def make_dataset(path, w2id, tok=False):
npy_dataset.append(array)
dataset.append(words)
token_count += array.size
unknown_count += (array == Vocab_Pad.UNK).sum()
unknown_count += (array == unk_token_id).sum()
print('# of tokens:{}'.format(token_count))
print('# of unknown {} {:.2}'.format(unknown_count,\
100. * unknown_count / token_count))
@@ -95,9 +96,9 @@ def get_preprocess_args():
class Config(): pass
config = Config()
parser = argparse.ArgumentParser(description='Preprocessing Options')
parser.add_argument('--source-vocab', type=int, default=40000,
parser.add_argument('--source_vocab', type=int, default=40000,
help='Vocabulary size of source language')
parser.add_argument('--target-vocab', type=int, default=40000,
parser.add_argument('--target_vocab', type=int, default=40000,
help='Vocabulary size of target language')
parser.add_argument('--tok', dest='tok', action='store_true',
help='tokenized and lowercased')
2 changes: 1 addition & 1 deletion texar/data/vocabulary.py
@@ -40,7 +40,7 @@

class SpecialTokens(object):
"""Special tokens, including :attr:`PAD`, :attr:`BOS`, :attr:`EOS`,
:attr:`UNK`. These tokens will by default have token ids 0, 1, 3, 4,
:attr:`UNK`. These tokens will by default have token ids 0, 1, 2, 3,
respectively.
"""
PAD = "<PAD>"
107 changes: 79 additions & 28 deletions texar/modules/decoders/transformer_decoders.py
@@ -20,7 +20,7 @@
from __future__ import print_function

# pylint: disable=no-name-in-module, too-many-arguments, too-many-locals
# pylint: disable=invalid-name, redefined-variable-type
# pylint: disable=invalid-name

import collections

@@ -32,7 +32,9 @@
from texar.modules.networks.networks import FeedForwardNetwork
from texar.modules.embedders.position_embedders import SinusoidsPositionEmbedder
from texar.modules.encoders.transformer_encoders import \
default_transformer_poswise_net_hparams
default_transformer_poswise_net_hparams
from texar.modules.encoders.multihead_attention import \
MultiheadAttentionEncoder
from texar.utils import beam_search
from texar.utils.shapes import shape_list, mask_sequences
from texar.utils import transformer_attentions as attn
@@ -60,8 +62,14 @@ class TransformerDecoderOutput(


class TransformerDecoder(ModuleBase):
"""Transformer decoder that applies multi-head self attention for
"""Transformer decoder that applies multi-head attention for
sequence decoding.
It stacks `~texar.modules.encoders.MultiheadAttentionEncoder` modules for
self attention and encoder-decoder attention, together with
`~texar.modules.FeedForwardNetwork` layers and residual connections.
The passed `embedding` variable is used as the parameters of the
transform layer from decoder output to logits.
Args:
embedding: A Tensor of shape `[vocab_size, dim]` containing the
@@ -89,8 +97,47 @@ def __init__(self, embedding, hparams=None):
self._embedding = embedding
self._vocab_size = self._embedding.get_shape().as_list()[0]

self.output_layer = \
self._build_output_layer(shape_list(self._embedding)[-1])
self.output_layer = \
self._build_output_layer(shape_list(self._embedding)[-1])
self.multihead_attentions = {
'self_att': [],
'encdec_att': []
}
self.poswise_networks = []
for i in range(self._hparams.num_blocks):
layer_name = 'layer_{}'.format(i)
with tf.variable_scope(layer_name):
with tf.variable_scope("self_attention"):
multihead_attention = MultiheadAttentionEncoder(
self._hparams.multihead_attention)
self.multihead_attentions['self_att'].append(
multihead_attention)
# pylint: disable=protected-access
if self._hparams.dim != \
multihead_attention._hparams.output_dim:
raise ValueError('The output dimension of '
'MultiheadAttentionEncoder should be equal '
'to the dim of TransformerDecoder')

with tf.variable_scope('encdec_attention'):
multihead_attention = MultiheadAttentionEncoder(
self._hparams.multihead_attention)
self.multihead_attentions['encdec_att'].append(
multihead_attention)
if self._hparams.dim != \
multihead_attention._hparams.output_dim:
raise ValueError('The output dimension of '
'MultiheadAttentionEncoder should be equal '
'to the dim of TransformerDecoder')

poswise_network = FeedForwardNetwork(
hparams=self._hparams['poswise_feedforward'])
if self._hparams.dim != \
poswise_network._hparams.layers[-1]['kwargs']['units']:
raise ValueError('The output dimension of '
'FeedForwardNetwork should be equal '
'to the dim of TransformerDecoder')
self.poswise_networks.append(poswise_network)

@staticmethod
def default_hparams():
@@ -101,13 +148,15 @@ def default_hparams():
{
# Same as in TransformerEncoder
"num_blocks": 6,
"num_heads": 8,
"dim": 512,
"position_embedder_hparams": None,
"embedding_dropout": 0.1,
"attention_dropout": 0.1,
"residual_dropout": 0.1,
"poswise_feedforward": default_transformer_poswise_net_hparams,
"multihead_attention": {
"num_units": 512,
"num_heads": 8,
},
"initializer": None,
# Additional for TransformerDecoder
"embedding_tie": True,
@@ -121,9 +170,6 @@
"num_blocks" : int
Number of stacked blocks.
"num_heads" : int
Number of heads for attention calculation.
"dim" : int
Hidden dimension of the encoder.
@@ -137,19 +183,25 @@
"embedding_dropout": float
Dropout rate of the input word and position embeddings.
"attention_dropout: : float
Dropout rate in the attention.
"residual_dropout" : float
Dropout rate of the residual connections.
"poswise_feedforward" : dict,
Hyperparameters for a feed-forward network used in residual
connections.
Make sure the dimension of the output tensor is equal to `dim`.
See :func:`~texar.modules.default_transformer_poswise_net_hparams`
for details.
"multihead_attention": dict,
Hyperparameters for the multihead attention strategy.
Make sure the `output_dim` in this module is equal to `dim`.
See :func:`~texar.modules.encoders.MultiheadAttentionEncoder.default_hparams`
for details.
"initializer" : dict, optional
Hyperparameters of the default initializer that initializes
variables created in this module.
@@ -176,17 +228,19 @@
Name of the module.
"""
return {
"num_heads": 8,
"num_blocks": 6,
"initializer": None,
"position_embedder_hparams": None,
"embedding_tie": True,
"output_layer_bias": False,
"max_decoding_length": 1e10,
"embedding_dropout": 0.1,
"attention_dropout": 0.1,
"residual_dropout": 0.1,
"poswise_feedforward": default_transformer_poswise_net_hparams(),
'multihead_attention': {
'num_units': 512,
'num_heads': 8,
},
"dim": 512,
"name": "transformer_decoder",
}
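
For context, here is a rough usage sketch of the modularized hyperparameters; the embedding variable, sizes, and everything around the decoder are placeholders, not code from this commit:

```
import tensorflow as tf
from texar.modules import TransformerDecoder

# Hypothetical sketch: attention settings now live under the nested
# "multihead_attention" hyperparameter dict.
vocab_size, dim = 32000, 512
embedding = tf.get_variable('word_embedding', shape=[vocab_size, dim])

decoder = TransformerDecoder(
    embedding=embedding,
    hparams={
        'dim': dim,
        'num_blocks': 6,
        'multihead_attention': {'num_units': dim, 'num_heads': 8},
    })
```
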
@@ -451,15 +505,14 @@ def _self_attention_stack(self,
layer_cache = cache[layer_name] if cache is not None else None
with tf.variable_scope(layer_name):
with tf.variable_scope("self_attention"):
selfatt_output = attn.multihead_attention(
multihead_attention = \
self.multihead_attentions['self_att'][i]
selfatt_output = multihead_attention(
queries=layers.layer_normalize(x),
memory=None,
memory_attention_bias=decoder_self_attention_bias,
num_units=self._hparams.dim,
num_heads=self._hparams.num_heads,
dropout_rate=self._hparams.attention_dropout,
cache=layer_cache,
scope="multihead_attention",
mode=mode,
)
x = x + tf.layers.dropout(
selfatt_output,
@@ -468,21 +521,19 @@
)
if memory is not None:
with tf.variable_scope('encdec_attention'):
encdec_output = attn.multihead_attention(
multihead_attention = \
self.multihead_attentions['encdec_att'][i]
encdec_output = multihead_attention(
queries=layers.layer_normalize(x),
memory=memory,
memory_attention_bias=memory_attention_bias,
num_units=self._hparams.dim,
num_heads=self._hparams.num_heads,
dropout_rate=self._hparams.attention_dropout,
scope="multihead_attention"
mode=mode,
)
x = x + tf.layers.dropout(encdec_output, \
rate=self._hparams.residual_dropout, \
training=is_train_mode(mode))
poswise_network = FeedForwardNetwork( \
hparams=self._hparams['poswise_feedforward'])
with tf.variable_scope(poswise_network.variable_scope):
poswise_network = self.poswise_networks[i]
with tf.variable_scope('past_poswise_ln'):
sub_output = tf.layers.dropout(
poswise_network(layers.layer_normalize(x)),
rate=self._hparams.residual_dropout,
