This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

Args options #520

Open · wants to merge 5 commits into base: master
31 changes: 2 additions & 29 deletions ctc_decoder_with_lm/README.md
@@ -21,34 +21,7 @@ You'll need the following pre-requisites downloaded/installed:
* [TensorFlow source and requirements](https://www.tensorflow.org/install/install_sources)
* [libsox](https://sourceforge.net/projects/sox/)


## Preparation

Create a symbolic link in your TensorFlow checkout to `ctc_decoder_with_lm` directory. If your DeepSpeech and TensorFlow checkouts are side by side in the same directory, do:

```
cd tensorflow
ln -s ../OpenSeq2Seq/ctc_decoder_with_lm ./
```

## Building

## Step 1 : Build Tensorflow
You need to re-build TensorFlow.
Follow the [instructions](https://www.tensorflow.org/install/install_sources) on the TensorFlow site for your platform, up to the end of 'Configure the installation':

```
./configure
bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
sudo pip install /tmp/tensorflow_pkg/tensorflow*.whl
sudo pip install --upgrade numpy
```

## Step 2: Build CTC beam search decoder:

```
bazel build -c opt --copt=-O3 --config=cuda //tensorflow:libtensorflow_cc.so //tensorflow:libtensorflow_framework.so //ctc_decoder_with_lm:libctc_decoder_with_kenlm.so //ctc_decoder_with_lm:generate_trie
cp bazel-bin/ctc_decoder_with_lm/*.so OpenSeq2Seq/ctc_decoder_with_lm/
cp bazel-bin/ctc_decoder_with_lm/generate_trie OpenSeq2Seq/ctc_decoder_with_lm/
```
Please see the detailed instructions in [OpenSeq2Seq documentation](https://nvidia.github.io/OpenSeq2Seq/html/installation.html#how-to-build-a-custom-native-tf-op-for-ctc-decoder-with-language-model-optional).

42 changes: 21 additions & 21 deletions ctc_decoder_with_lm/beam_search.cc
@@ -37,8 +37,8 @@ namespace ctc {

template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
typename CTCBeamComparer =
ctc_beam_search::BeamComparer<CTCBeamState>>
class CTCBeamSearchNormLogDecoder : public CTCDecoder {
ctc_beam_search::BeamComparer<float, CTCBeamState>>
class CTCBeamSearchNormLogDecoder : public CTCDecoder<float> {
// Beam Search
//
// Example (GravesTh Fig. 7.5):
@@ -70,12 +70,12 @@ class CTCBeamSearchNormLogDecoder : public CTCDecoder {
// starts at 0). This special case can be calculated as:
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
// but we calculate it recursively for speed purposes.
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability BeamProbability;
typedef ctc_beam_search::BeamEntry<float, CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<float, CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability<float> BeamProbability;

public:
typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
typedef BaseBeamScorer<float, CTCBeamState> DefaultBeamScorer;

// The beam search decoder is constructed specifying the beam_width (number of
// candidates to keep at each decoding timestep) and a beam scorer (used for
@@ -84,9 +84,9 @@ class CTCBeamSearchNormLogDecoder : public CTCDecoder {
// implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
// standard beam search.
CTCBeamSearchNormLogDecoder(int num_classes, int beam_width,
BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
BaseBeamScorer<float, CTCBeamState>* scorer, int batch_size = 1,
bool merge_repeated = false)
: CTCDecoder(num_classes, batch_size, merge_repeated),
: CTCDecoder<float>(num_classes, batch_size, merge_repeated),
beam_width_(beam_width),
leaves_(beam_width),
beam_scorer_(CHECK_NOTNULL(scorer)) {
@@ -96,10 +96,10 @@ class CTCBeamSearchNormLogDecoder : public CTCDecoder {
~CTCBeamSearchNormLogDecoder() override {}

// Run the hibernating beam search algorithm on the given input.
Status Decode(const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override;
Status Decode(const CTCDecoder<float>::SequenceLength& seq_len,
const std::vector<CTCDecoder<float>::Input>& input,
std::vector<CTCDecoder<float>::Output>* output,
CTCDecoder<float>::ScoreOutput* scores) override;

// Calculate the next step of the beam search and update the internal state.
template <typename Vector>
@@ -111,7 +111,7 @@ class CTCBeamSearchNormLogDecoder : public CTCDecoder {
std::vector<int>* top_k_indices);

// Retrieve the beam scorer instance used during decoding.
BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
BaseBeamScorer<float, CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }

// Set label selection parameters for faster decoding.
// See comments for label_selection_size_ and label_selection_margin_.
@@ -129,7 +129,7 @@ class CTCBeamSearchNormLogDecoder : public CTCDecoder {
std::vector<float>* log_probs, bool merge_repeated) const;

gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
BaseBeamScorer<CTCBeamState>* beam_scorer_;
BaseBeamScorer<float, CTCBeamState>* beam_scorer_;

private:
int beam_width_;
@@ -156,15 +156,15 @@ class CTCBeamSearchNormLogDecoder : public CTCDecoder {

template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchNormLogDecoder<CTCBeamState, CTCBeamComparer>::Decode(
const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
const CTCDecoder<float>::SequenceLength& seq_len,
const std::vector<CTCDecoder<float>::Input>& input,
std::vector<CTCDecoder<float>::Output>* output, ScoreOutput* scores) {
// Storage for top paths.
std::vector<std::vector<int>> beams;
std::vector<float> beam_log_probabilities;
int top_n = output->size();
if (std::any_of(output->begin(), output->end(),
[this](const CTCDecoder::Output& output) -> bool {
[this](const CTCDecoder<float>::Output& output) -> bool {
return output.size() < this->batch_size_;
})) {
return errors::InvalidArgument(
@@ -325,7 +325,7 @@ void CTCBeamSearchNormLogDecoder<CTCBeamState, CTCBeamComparer>::Step(
// isn't full, or the lowest probability entry in the beam has a
// lower probability than the leaf.
auto is_candidate = [this](const BeamProbability& prob) {
return (prob.total > kLogZero &&
return (prob.total > kLogZero<float>() &&
(leaves_.size() < beam_width_ ||
prob.total > leaves_.peek_bottom()->newp.total));
};
@@ -345,7 +345,7 @@ void CTCBeamSearchNormLogDecoder<CTCBeamState, CTCBeamComparer>::Step(
BeamEntry& c = b->GetChild(label);
if (!c.Active()) {
// Pblank(l=abcd @ t=6) = 0
c.newp.blank = kLogZero;
c.newp.blank = kLogZero<float>();
// If new child label is identical to beam label:
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
// Otherwise:
@@ -727,7 +727,7 @@ class CTCBeamSearchDecoderWithLMOp : public tf::OpKernel {
beam_search.Step(input_bi);
}

typedef tf::ctc::ctc_beam_search::BeamEntry<WordLMBeamState> BeamEntry;
typedef tf::ctc::ctc_beam_search::BeamEntry<float, WordLMBeamState> BeamEntry;
std::unique_ptr<std::vector<BeamEntry*>> branches(beam_search.leaves_.Extract());
beam_search.leaves_.Reset();
for (int i = 0; i < branches->size(); ++i) {
2 changes: 1 addition & 1 deletion ctc_decoder_with_lm/beam_search.h
@@ -31,7 +31,7 @@ struct WordLMBeamState {
bool new_word;
};

class WordLMBeamScorer : public tensorflow::ctc::BaseBeamScorer<WordLMBeamState> {
class WordLMBeamScorer : public tensorflow::ctc::BaseBeamScorer<float, WordLMBeamState> {
public:
WordLMBeamScorer(const std::string &kenlm_path, const std::string &trie_path,
const std::string &alphabet_path,
1 change: 1 addition & 0 deletions docs/sources/source/getting-started/asr.rst
@@ -26,6 +26,7 @@ dataset size will be around 224GB (including archives and original compressed au
Now, everything should be setup to train the model::

python run.py --config_file=example_configs/speech2text/ds2_librispeech_larc_config.py --mode=train_eval
python run.py --config_file=example_configs/speech2text/ds2_librispeech_larc_config.py --mode=train_eval --infer_dataset=example_configs/datasets/infer.csv

If you want to run evaluation/inference with the trained model, replace
``--mode=train_eval`` with ``--mode=eval`` or ``--mode=infer``.
35 changes: 35 additions & 0 deletions open_seq2seq/models/speech2text_test.py
@@ -102,6 +102,41 @@ def convergence_test(self, train_loss_threshold,
    self.assertLess(eval_loss, eval_loss_threshold)
    self.assertLess(eval_dict['Eval WER'], eval_wer_threshold)

def finetuning_test(self, train_loss_threshold,
                    eval_loss_threshold, eval_wer_threshold):
  for dtype in [tf.float32, "mixed"]:

    # pre-training
    train_config, eval_config = self.prepare_config()
    train_config.update({
        "dtype": dtype,
    })
    eval_config.update({
        "dtype": dtype,
    })
    loss, eval_loss, eval_dict = self.run_model(train_config, eval_config)

    self.assertLess(loss, train_loss_threshold)
    self.assertLess(eval_loss, eval_loss_threshold)
    self.assertLess(eval_dict['Eval WER'], eval_wer_threshold)

    # finetuning
    restore_dir = train_config['logdir']
    train_config['logdir'] = tempfile.mktemp()
    eval_config['logdir'] = train_config['logdir']
    train_config.update({
        "load_model": restore_dir,
        "lr_policy_params": {
            "learning_rate": 0.0001,
            "power": 2,
        }
    })
    loss_ft, eval_loss_ft, eval_dict_ft = self.run_model(train_config, eval_config)

    self.assertLess(loss_ft, train_loss_threshold)
    self.assertLess(eval_loss_ft, eval_loss_threshold)
    self.assertLess(eval_dict_ft['Eval WER'], eval_wer_threshold)

def convergence_with_iter_size_test(self):
  try:
    import horovod.tensorflow as hvd
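The new `finetuning_test` mirrors how a user would warm-start training: a first run writes checkpoints to `logdir`, and a second run points `load_model` at that directory while lowering the learning rate. As a rough sketch, the corresponding keys in a user config might look like the following (directory paths and the omitted model/data settings are placeholders, not part of this PR):

```python
# Hypothetical config excerpt for fine-tuning; only the keys exercised by
# finetuning_test are shown, everything else is elided.
base_params = {
    # ... model, optimizer and data layer settings ...
    "logdir": "experiments/ds2_finetuned",        # where the new run writes checkpoints
    "load_model": "experiments/ds2_pretrained",   # previous run's logdir to warm-start from
    "lr_policy_params": {
        "learning_rate": 0.0001,                  # lower LR for fine-tuning, as in the test
        "power": 2,
    },
}
```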
59 changes: 35 additions & 24 deletions open_seq2seq/utils/funcs.py
@@ -116,9 +116,33 @@ def train(train_model, eval_model=None, debug_port=None, custom_hooks=None):
# checkpoint.
restoring = load_model_dir and not tf.train.latest_checkpoint(checkpoint_dir)
if restoring:
scaffold = TransferScaffold(
local_init_op=tf.group(tf.local_variables_initializer(), init_data_layer)
vars_in_checkpoint = {}
for var_name, var_shape in tf.train.list_variables(load_model_dir):
vars_in_checkpoint[var_name] = var_shape

print('VARS_IN_CHECKPOINT:')
print(vars_in_checkpoint)

vars_to_load = []
for var in tf.global_variables():
var_name = var.name.split(':')[0]
if var_name in vars_in_checkpoint:
if var.shape == vars_in_checkpoint[var_name] and \
'global_step' not in var_name:
vars_to_load.append(var)

print('VARS_TO_LOAD:')
for var in vars_to_load:
print(var)

load_model_fn = tf.contrib.framework.assign_from_checkpoint_fn(
tf.train.latest_checkpoint(load_model_dir), vars_to_load
)
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), init_data_layer),
init_fn = lambda scaffold_self, sess: load_model_fn(sess)
)

else:
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), init_data_layer)
@@ -134,28 +158,15 @@ def train(train_model, eval_model=None, debug_port=None, custom_hooks=None):
"train model does not define get_num_objects_per_step method.")

# starting training
if restoring:
sess = TransferMonitoredTrainingSession(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir,
save_summaries_steps=train_model.params['save_summaries_steps'],
config=sess_config,
save_checkpoint_secs=None,
log_step_count_steps=train_model.params['save_summaries_steps'],
stop_grace_period_secs=300,
hooks=hooks,
load_model_dir=load_model_dir,
load_fc=train_model.params['load_fc'])
else:
sess = tf.train.MonitoredTrainingSession(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir,
save_summaries_steps=train_model.params['save_summaries_steps'],
config=sess_config,
save_checkpoint_secs=None,
log_step_count_steps=train_model.params['save_summaries_steps'],
stop_grace_period_secs=300,
hooks=hooks)
sess = tf.train.MonitoredTrainingSession(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir,
save_summaries_steps=train_model.params['save_summaries_steps'],
config=sess_config,
save_checkpoint_secs=None,
log_step_count_steps=train_model.params['save_summaries_steps'],
stop_grace_period_secs=300,
hooks=hooks)
step = 0
num_bench_updates = 0
while True:
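The restore branch above drops the custom `TransferScaffold`/`TransferMonitoredTrainingSession` pair in favour of a stock `tf.train.Scaffold` whose `init_fn` assigns only the checkpoint variables that still match the current graph. A condensed, self-contained sketch of that logic, assuming TF 1.x (the helper name and argument list are illustrative, not from the PR):

```python
import tensorflow as tf

def make_warm_start_scaffold(load_model_dir, local_init_op):
  """Sketch: build a Scaffold that restores matching variables from load_model_dir."""
  # Variable name -> shape as stored in the source checkpoint.
  ckpt_vars = dict(tf.train.list_variables(load_model_dir))

  # Keep only graph variables whose name and shape match the checkpoint,
  # and skip global_step so the new run starts counting from zero.
  vars_to_load = []
  for var in tf.global_variables():
    name = var.name.split(':')[0]
    if name in ckpt_vars and var.shape == ckpt_vars[name] \
        and 'global_step' not in name:
      vars_to_load.append(var)

  restore_fn = tf.contrib.framework.assign_from_checkpoint_fn(
      tf.train.latest_checkpoint(load_model_dir), vars_to_load)

  # Scaffold.init_fn is called as init_fn(scaffold, session); only the
  # session is needed to run the assignment ops.
  return tf.train.Scaffold(
      local_init_op=local_init_op,
      init_fn=lambda _, sess: restore_fn(sess))
```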
12 changes: 10 additions & 2 deletions open_seq2seq/utils/utils.py
@@ -506,7 +506,13 @@ def get_base_config(args):
help='whether to log output, git info, cmd args, etc.')
parser.add_argument('--use_xla_jit', dest='use_xla_jit', action='store_true',
help='whether to use XLA_JIT to compile and run the model.')
parser.add_argument('--infer_dataset', dest='infer_dataset',
help='infer_dataset csv file.')
parser.add_argument('--train_dataset', dest='train_dataset',
help='train_dataset csv file.')
args, unknown = parser.parse_known_args(args)
infer_params = args.infer_dataset
train_params = args.train_dataset

if args.mode not in [
'train',
@@ -519,7 +525,10 @@
"['train', 'eval', 'train_eval', 'infer', "
"'interactive_infer']")
config_module = runpy.run_path(args.config_file, init_globals={'tf': tf})

if infer_params:
config_module['infer_params']['data_layer_params']['dataset_files'] = infer_params.split(',')
if train_params:
config_module['train_params']['data_layer_params']['dataset_files'] = train_params.split(',')
base_config = config_module.get('base_params', None)
if base_config is None:
raise ValueError('base_config dictionary has to be '
@@ -541,7 +550,6 @@
parser_unk.add_argument('--' + pm, default=value, type=ast.literal_eval)
config_update = parser_unk.parse_args(unknown)
nested_update(base_config, nest_dict(vars(config_update)))

return args, base_config, base_model, config_module

def get_calibration_config(arguments):
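The two new flags make the dataset CSVs overridable from the command line without editing the config file: `--train_dataset` and `--infer_dataset` take comma-separated CSV paths and overwrite `dataset_files` in the loaded config module before it is consumed (the asr.rst change above shows the corresponding `run.py` invocation). A minimal sketch of the effect, with placeholder file names:

```python
# Shape of the loaded config module that the override relies on; the real
# dict comes from runpy.run_path(args.config_file, ...).
config_module = {
    'infer_params': {
        'data_layer_params': {
            'dataset_files': ['data/default_infer.csv'],   # placeholder
        },
    },
    # ... 'base_params', 'train_params', etc. ...
}

# With --infer_dataset=custom1.csv,custom2.csv on the command line,
# get_base_config() rewrites the list in place:
infer_params = 'custom1.csv,custom2.csv'
config_module['infer_params']['data_layer_params']['dataset_files'] = infer_params.split(',')
# -> ['custom1.csv', 'custom2.csv']
```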