Permalink
Browse files

Introducing the image captioning module (#390)

Introducing the image captioning module. Type of models supported: ConvNet encoder and Sockeye NMT decoders

Features:
-   Image encoder that extracts features using pretrained nets: `image_captioning.encoder`
-   Feature extraction script to dump features to disk `image_captioning.extract_features`
-   Pass-through embedding, since we do not need it for images
-   Image-text iterator that loads features on the fly during training with the option of loading all to memory: `image_captioning.data_io`
-   Training and inference pipelines for image captioning: `image_captioning.train`, `image_captioning.inference` and `image_captioning.captioner`
-   README with instructions on how to use the image captioning module: `image_captioning/README.md`
-   Visualization script that loads images and captions (prediction+ground truth) and displays them: `image_captioning.visualize`
  • Loading branch information...
lorisbaz authored and fhieber committed May 25, 2018
1 parent 8835331 commit 09a90021453e8d8254201d92a830cd8da0c6e2c6
Showing with 3,635 additions and 42 deletions.
  1. +6 −0 CHANGELOG.md
  2. +1 −0 MANIFEST.in
  3. +1 −0 requirements.dev.txt
  4. +1 −1 requirements.txt
  5. +1 −1 setup.py
  6. +1 −1 sockeye/__init__.py
  7. +11 −0 sockeye/arguments.py
  8. +3 −1 sockeye/constants.py
  9. +94 −10 sockeye/data_io.py
  10. +31 −4 sockeye/decoder.py
  11. +53 −1 sockeye/encoder.py
  12. +12 −0 sockeye/image_captioning/__init__.py
  13. +153 −0 sockeye/image_captioning/arguments.py
  14. +153 −0 sockeye/image_captioning/captioner.py
  15. +123 −0 sockeye/image_captioning/checkpoint_decoder.py
  16. +423 −0 sockeye/image_captioning/data_io.py
  17. +229 −0 sockeye/image_captioning/encoder.py
  18. +166 −0 sockeye/image_captioning/extract_features.py
  19. +232 −0 sockeye/image_captioning/inference.py
  20. +387 −0 sockeye/image_captioning/train.py
  21. +164 −0 sockeye/image_captioning/utils.py
  22. +178 −0 sockeye/image_captioning/visualize.py
  23. +24 −11 sockeye/inference.py
  24. +5 −1 sockeye/initializer.py
  25. +11 −5 sockeye/model.py
  26. +7 −4 sockeye/train.py
  27. +306 −0 test/common_image_captioning.py
  28. +12 −0 test/integration/image_captioning/__init__.py
  29. +59 −0 test/integration/image_captioning/test_extract_features.py
  30. +89 −0 test/integration/image_captioning/test_image_captioning.py
  31. +89 −0 test/unit/image_captioning/test_arguments.py
  32. +276 −0 test/unit/image_captioning/test_data_io.py
  33. +76 −0 test/unit/image_captioning/test_encoder.py
  34. +111 −0 test/unit/image_captioning/test_utils.py
  35. +5 −2 test/unit/test_arguments.py
  36. +4 −0 tutorials/README.md
  37. +138 −0 tutorials/image_captioning/README.md
View
@@ -10,6 +10,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
## [1.18.18]
### Added
- \[Experimental\] Introducing the image captioning module. Type of models supported: ConvNet encoder - Sockeye NMT decoders. This includes also a feature extraction script,
an image-text iterator that loads features, training and inference pipelines and a visualization script that loads images and captions.
See [this tutorial](tutorials/image_captioning) for its usage. This module is experimental; therefore, its maintenance is not fully guaranteed.
## [1.18.17]
### Changed
- Updated to MXNet 1.2
View
@@ -23,3 +23,4 @@ recursive-include docs Makefile
recursive-include tutorials *.md
recursive-include tutorials *.png
recursive-include tutorials *.py
recursive-include test *.txt
View
@@ -1,5 +1,6 @@
pytest
pytest-cov
pillow
check-manifest
matplotlib
mypy>=0.6
View
@@ -1,4 +1,4 @@
pyyaml
mxnet==1.2.0
numpy>=1.12
typing
typing
View
@@ -109,7 +109,7 @@ def get_requirements(filename):
packages=find_packages(exclude=("test",)),
setup_requires=['pytest-runner'],
tests_require=['pytest', 'pytest-cov'],
tests_require=['pytest', 'pytest-cov', 'pillow'],
extras_require={
'optional': ['mxboard', 'matplotlib'],
View
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
__version__ = '1.18.17'
__version__ = '1.18.18'
View
@@ -827,6 +827,10 @@ def add_training_args(params):
default=(.0, .0),
help='Recurrent dropout without memory loss (Semeniuta, 2016) for encoder & decoder '
'LSTMs. Use "x:x" to specify separate values. Default: %(default)s.')
train_params.add_argument('--rnn-enc-last-hidden-concat-to-embedding',
action="store_true",
help='Concatenate the last hidden layer of the encoder to the input of the decoder, '
'instead of the previous state of the decoder. Default: %(default)s.')
train_params.add_argument('--rnn-decoder-hidden-dropout',
type=float,
@@ -1029,6 +1033,13 @@ def add_translate_cli_args(params):
add_logging_args(params)
def add_max_output_cli_args(params):
    """
    Register the --max-output-length CLI option on the given argparse
    parser or argument group.

    :param params: argparse parser/group to extend.
    """
    help_text = ('Maximum number of words to generate during translation. '
                 'If None, it will be computed automatically. Default: %(default)s.')
    params.add_argument('--max-output-length', type=int, default=None, help=help_text)
def add_inference_args(params):
decode_params = params.add_argument_group("Inference parameters")
View
@@ -56,9 +56,10 @@
TRANSFORMER_TYPE = "transformer"
CONVOLUTION_TYPE = "cnn"
TRANSFORMER_WITH_CONV_EMBED_TYPE = "transformer-with-conv-embed"
IMAGE_PRETRAIN_TYPE = "image-pretrain-cnn"
# available encoders
ENCODERS = [RNN_NAME, RNN_WITH_CONV_EMBED_NAME, TRANSFORMER_TYPE, TRANSFORMER_WITH_CONV_EMBED_TYPE, CONVOLUTION_TYPE]
ENCODERS = [RNN_NAME, RNN_WITH_CONV_EMBED_NAME, TRANSFORMER_TYPE, TRANSFORMER_WITH_CONV_EMBED_TYPE, CONVOLUTION_TYPE, IMAGE_PRETRAIN_TYPE]
# available decoder
DECODERS = [RNN_NAME, TRANSFORMER_TYPE, CONVOLUTION_TYPE]
@@ -256,6 +257,7 @@
DEFAULT_FACTOR_DELIMITER = '|'
# data layout strings
BATCH_MAJOR_IMAGE = "NCHW"
BATCH_MAJOR = "NTC"
TIME_MAJOR = "TNC"
View
@@ -91,6 +91,28 @@ def define_parallel_buckets(max_seq_len_source: int,
return buckets
def define_empty_source_parallel_buckets(max_seq_len_target: int,
                                         bucket_width: int = 10) -> List[Tuple[int, int]]:
    """
    Returns (source, target) buckets up to (None, max_seq_len_target). The source
    is empty since it is supposed to not contain data that can be bucketized.
    The target is used as reference to create the buckets.

    :param max_seq_len_target: Maximum target bucket size.
    :param bucket_width: Width of buckets on longer side.
    """
    step = max(1, bucket_width)
    raw_target_buckets = define_buckets(max_seq_len_target, step=step)
    # Source side carries no text, so every source bucket is 0; target bucket
    # sizes are clamped to a minimum of 2.
    pairs = [(0, max(2, tgt)) for tgt in raw_target_buckets]
    # De-duplicate (clamping can produce repeats) and return in sorted order.
    unique_buckets = list(OrderedDict.fromkeys(pairs))
    unique_buckets.sort()
    return unique_buckets
def get_bucket(seq_len: int, buckets: List[int]) -> Optional[int]:
"""
Given sequence length and a list of buckets, return corresponding bucket.
@@ -245,17 +267,21 @@ class DataStatisticsAccumulator:
def __init__(self,
buckets: List[Tuple[int, int]],
vocab_source: Dict[str, int],
vocab_source: Optional[Dict[str, int]],
vocab_target: Dict[str, int],
length_ratio_mean: float,
length_ratio_std: float) -> None:
self.buckets = buckets
num_buckets = len(buckets)
self.length_ratio_mean = length_ratio_mean
self.length_ratio_std = length_ratio_std
self.unk_id_source = vocab_source[C.UNK_SYMBOL]
if vocab_source is not None:
self.unk_id_source = vocab_source[C.UNK_SYMBOL]
self.size_vocab_source = len(vocab_source)
else:
self.unk_id_source = None
self.size_vocab_source = 0
self.unk_id_target = vocab_target[C.UNK_SYMBOL]
self.size_vocab_source = len(vocab_source)
self.size_vocab_target = len(vocab_target)
self.num_sents = 0
self.num_discarded = 0
@@ -286,7 +312,8 @@ def sequence_pair(self,
self.max_observed_len_source = max(source_len, self.max_observed_len_source)
self.max_observed_len_target = max(target_len, self.max_observed_len_target)
self.num_unks_source += source.count(self.unk_id_source)
if self.unk_id_source is not None:
self.num_unks_source += source.count(self.unk_id_source)
self.num_unks_target += target.count(self.unk_id_target)
@property
@@ -578,9 +605,14 @@ def get_data_statistics(source_readers: Sequence[Iterable],
data_stats_accumulator = DataStatisticsAccumulator(buckets, source_vocabs[0], target_vocab,
length_ratio_mean, length_ratio_std)
for sources, target in parallel_iter(source_readers, target_reader):
buck_idx, buck = get_parallel_bucket(buckets, len(sources[0]), len(target))
data_stats_accumulator.sequence_pair(sources[0], target, buck_idx)
if source_readers is not None and target_reader is not None:
for sources, target in parallel_iter(source_readers, target_reader):
buck_idx, buck = get_parallel_bucket(buckets, len(sources[0]), len(target))
data_stats_accumulator.sequence_pair(sources[0], target, buck_idx)
else: # Allow stats for target only data
for target in target_reader:
buck_idx, buck = get_target_bucket(buckets, len(target))
data_stats_accumulator.sequence_pair([], target, buck_idx)
return data_stats_accumulator.statistics
@@ -1080,6 +1112,33 @@ def parallel_iter(source_iters: Sequence[Iterable[Optional[Any]]], target_iterab
"Different number of lines in source(s) and target iterables.")
class FileListReader(Iterator):
    """
    Reads sequence samples from paths provided in a file (one relative path
    per line), yielding each path prefixed with ``path``.

    :param fname: File name containing a list of relative paths.
    :param path: Path to read data from, which is prefixed to the relative paths of fname.
    """

    def __init__(self,
                 fname: str,
                 path: str) -> None:
        self.fname = fname
        self.path = path
        self.fd = smart_open(fname)
        # Number of paths yielded so far.
        self.count = 0

    def __next__(self) -> str:
        line = self.fd.readline()
        # BUG FIX: file.readline() returns '' at EOF, never None, so the
        # original `if fname is None` check could never fire and iteration
        # never terminated. An empty read signals end of the list.
        if not line:
            self.fd.close()
            raise StopIteration
        fname = line.strip("\n")
        self.count += 1
        return os.path.join(self.path, fname)
def get_default_bucket_key(buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
"""
Returns the default bucket from a list of buckets, i.e. the largest bucket.
@@ -1108,6 +1167,24 @@ def get_parallel_bucket(buckets: List[Tuple[int, int]],
return None, None
def get_target_bucket(buckets: List[Tuple[int, int]],
                      length_target: int) -> Tuple[Optional[int], Optional[Tuple[int, int]]]:
    """
    Returns bucket index and bucket from a list of buckets, given a target length.
    Returns (None, None) if no bucket fits.

    :param buckets: List of (source, target) buckets, sorted by increasing target size.
    :param length_target: Length of target sequence.
    :return: Tuple of (bucket index, bucket), or (None, None) if not fitting.
    """
    bucket = None, None  # type: Tuple[Optional[int], Optional[Tuple[int, int]]]
    for j, (source_bkt, target_bkt) in enumerate(buckets):
        if target_bkt >= length_target:
            # Buckets are sorted, so the first fitting bucket is the tightest.
            bucket = j, (source_bkt, target_bkt)
            break
    return bucket
class ParallelDataSet(Sized):
"""
Bucketed parallel data set with labels
@@ -1182,8 +1259,12 @@ def fill_up(self,
logger.info("Replicating %d random samples from %d samples in bucket %s "
"to size it to multiple of %d",
rest, num_samples, bucket, bucket_batch_size)
random_indices = mx.nd.array(rs.randint(num_samples, size=rest))
source[bucket_idx] = mx.nd.concat(bucket_source, bucket_source.take(random_indices), dim=0)
random_indices_np = rs.randint(num_samples, size=rest)
random_indices = mx.nd.array(random_indices_np)
if isinstance(source[bucket_idx], np.ndarray):
source[bucket_idx] = np.concatenate((bucket_source, bucket_source.take(random_indices_np)), axis=0)
else:
source[bucket_idx] = mx.nd.concat(bucket_source, bucket_source.take(random_indices), dim=0)
target[bucket_idx] = mx.nd.concat(bucket_target, bucket_target.take(random_indices), dim=0)
label[bucket_idx] = mx.nd.concat(bucket_label, bucket_label.take(random_indices), dim=0)
else:
@@ -1200,7 +1281,10 @@ def permute(self, permutations: List[mx.nd.NDArray]) -> 'ParallelDataSet':
num_samples = self.source[buck_idx].shape[0]
if num_samples: # not empty bucket
permutation = permutations[buck_idx]
source.append(self.source[buck_idx].take(permutation))
if isinstance(self.source[buck_idx], np.ndarray):
source.append(self.source[buck_idx].take(np.int64(permutation.asnumpy())))
else:
source.append(self.source[buck_idx].take(permutation))
target.append(self.target[buck_idx].take(permutation))
label.append(self.label[buck_idx].take(permutation))
else:
View
@@ -454,6 +454,8 @@ class RecurrentDecoderConfig(Config):
:param context_gating: Whether to use context gating.
:param layer_normalization: Apply layer normalization.
:param attention_in_upper_layers: Pass the attention value to all layers in the decoder.
:param enc_last_hidden_concat_to_embedding: Concatenate the last hidden representation of the encoder to the
input of the decoder (e.g., context + current embedding).
:param dtype: Data type.
"""
@@ -467,7 +469,9 @@ def __init__(self,
context_gating: bool = False,
layer_normalization: bool = False,
attention_in_upper_layers: bool = False,
dtype: str = C.DTYPE_FP32) -> None:
dtype: str = C.DTYPE_FP32,
enc_last_hidden_concat_to_embedding: bool = False) -> None:
super().__init__()
self.max_seq_len_source = max_seq_len_source
self.rnn_config = rnn_config
@@ -478,6 +482,7 @@ def __init__(self,
self.context_gating = context_gating
self.layer_normalization = layer_normalization
self.attention_in_upper_layers = attention_in_upper_layers
self.enc_last_hidden_concat_to_embedding = enc_last_hidden_concat_to_embedding
self.dtype = dtype
@@ -579,6 +584,14 @@ def decode_sequence(self,
# target_embed: target_seq_len * (batch_size, num_target_embed)
target_embed = mx.sym.split(data=target_embed, num_outputs=target_embed_max_length, axis=1, squeeze_axis=True)
# Get last state from source (batch_size, num_target_embed)
enc_last_hidden = None
if self.config.enc_last_hidden_concat_to_embedding:
enc_last_hidden = mx.sym.SequenceLast(data=source_encoded,
sequence_length=source_encoded_lengths,
axis=1,
use_sequence_length=True)
# get recurrent attention function conditioned on source
attention_func = self.attention.on(source_encoded, source_encoded_lengths,
source_encoded_max_length)
@@ -599,7 +612,8 @@ def decode_sequence(self,
state,
attention_func,
attention_state,
seq_idx)
seq_idx,
enc_last_hidden=enc_last_hidden)
hidden_states.append(state.hidden)
# concatenate along time axis: (batch_size, target_embed_max_length, rnn_num_hidden)
@@ -624,6 +638,14 @@ def decode_step(self,
"""
source_encoded, prev_dynamic_source, source_encoded_length, prev_hidden, *layer_states = states
# Get last state from source (batch_size, num_target_embed)
enc_last_hidden = None
if self.config.enc_last_hidden_concat_to_embedding:
enc_last_hidden = mx.sym.SequenceLast(data=source_encoded,
sequence_length=source_encoded_length,
axis=1,
use_sequence_length=True)
attention_func = self.attention.on(source_encoded, source_encoded_length, source_encoded_max_length)
prev_state = RecurrentDecoderState(prev_hidden, list(layer_states))
@@ -636,7 +658,8 @@ def decode_step(self,
state, attention_state = self._step(target_embed_prev,
prev_state,
attention_func,
prev_attention_state)
prev_attention_state,
enc_last_hidden=enc_last_hidden)
new_states = [source_encoded,
attention_state.dynamic_source,
@@ -809,7 +832,8 @@ def _step(self, word_vec_prev: mx.sym.Symbol,
state: RecurrentDecoderState,
attention_func: Callable,
attention_state: rnn_attention.AttentionState,
seq_idx: int = 0) -> Tuple[RecurrentDecoderState, rnn_attention.AttentionState]:
seq_idx: int = 0,
enc_last_hidden: Optional[mx.sym.Symbol] = None) -> Tuple[RecurrentDecoderState, rnn_attention.AttentionState]:
"""
Performs single-time step in the RNN, given previous word vector, previous hidden state, attention function,
@@ -824,6 +848,9 @@ def _step(self, word_vec_prev: mx.sym.Symbol,
"""
# (1) RNN step
# concat previous word embedding and previous hidden state
if enc_last_hidden is not None:
word_vec_prev = mx.sym.concat(word_vec_prev, enc_last_hidden, dim=1,
name="%sconcat_target_encoder_t%d" % (self.prefix, seq_idx))
rnn_input = mx.sym.concat(word_vec_prev, state.hidden, dim=1,
name="%sconcat_target_context_t%d" % (self.prefix, seq_idx))
# rnn_pre_attention_output: (batch_size, rnn_num_hidden)
Oops, something went wrong.

0 comments on commit 09a9002

Please sign in to comment.