Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script for model converter #524

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions flashlight/app/asr/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ build_tool(
${CMAKE_CURRENT_LIST_DIR}/benchmark/ArchBenchmark.cpp
fl_asr_arch_benchmark
)
build_tool(
${CMAKE_CURRENT_LIST_DIR}/serialization/ModelConverter.cpp
${CMAKE_CURRENT_LIST_DIR}/serialization/Compat.cpp
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `build_tool` function accepts only 2 arguments, so passing a second source file here leads to a build error.

fl_asr_model_converter
)
17 changes: 17 additions & 0 deletions flashlight/app/asr/tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,20 @@ Below are baseline models usable with the tool, although any model/lexicon/token
| [baseline_dev-other](https://dl.fbaipublicfiles.com/wav2letter/audio_analysis/tds_ctc/model.bin) | LibriSpeech | dev-other | CTC | [Archfile](https://dl.fbaipublicfiles.com/wav2letter/audio_analysis/tds_ctc/arch.txt) | [Lexicon](https://dl.fbaipublicfiles.com/wav2letter/audio_analysis/tds_ctc/dict.lst) | [Tokens](https://dl.fbaipublicfiles.com/wav2letter/audio_analysis/tds_ctc/tokens.lst) |

</details>
<details>
<summary>Model Conversion</summary>

Sometimes, an ASR model trained from an `old` commit of `flashlight` can fail to load on `master` if the serialization semantics have changed. This makes it hard to reload the model and run the job in 'continue' mode or use it in the decoder. `fl_asr_model_converter` can be used to get around this problem by converting the old model to the new serialization format. This involves two steps:

Step 1:
From the `old` commit of `flashlight`, build `fl_asr_model_converter` and run
```
fl_asr_model_converter old <OLD_MODEL_PATH>;
```

Step 2:
From the `new` commit of `flashlight`, build `fl_asr_model_converter` and run
```
fl_asr_model_converter new <OLD_MODEL_PATH>;
```

The converted model can be found at `<OLD_MODEL_PATH>.new`.
</details>
27 changes: 27 additions & 0 deletions flashlight/app/asr/tools/serialization/Compat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "flashlight/app/asr/tools/serialization/Compat.h"

#include "flashlight/app/asr/common/Defines.h"
#include "flashlight/app/asr/common/Flags.h"

namespace fl {
namespace app {
namespace asr {
// gflags definitions for three s2s transformer decoder flags under the names
// `decoder_layers`, `decoder_dropout` and `decoder_layerdrop`. Each DEFINE_*
// call gives the flag name, its default value and its help string.
// NOTE(review): these look like the legacy spellings of the
// `am_decoder_tr_*` flags paired with them in initCompat() — confirm.
DEFINE_int64(decoder_layers, 1, "s2s transformer decoder: number of layers");
DEFINE_double(decoder_dropout, 0.0, "s2s transformer decoder: dropout");
DEFINE_double(decoder_layerdrop, 0.0, "s2s transformer decoder: layerdrop");

// Registers each flag defined above against its `am_decoder_tr_*`
// counterpart via DEPRECATE_FLAGS.
// NOTE(review): DEPRECATE_FLAGS is defined outside this file; presumably it
// records an old-name -> new-name mapping that is applied later (e.g. by
// handleDeprecatedFlags()) — confirm against its definition.
void initCompat() {
DEPRECATE_FLAGS(decoder_layers, am_decoder_tr_layers);
DEPRECATE_FLAGS(decoder_dropout, am_decoder_tr_dropout);
DEPRECATE_FLAGS(decoder_layerdrop, am_decoder_tr_layerdrop);
}
} // namespace asr
} // namespace app
} // namespace fl
23 changes: 23 additions & 0 deletions flashlight/app/asr/tools/serialization/Compat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <gflags/gflags.h>

namespace fl {
namespace app {
namespace asr {

// gflags declarations for the s2s transformer decoder flags that are defined
// in Compat.cpp; including this header makes the corresponding FLAGS_*
// variables visible.
DECLARE_int64(decoder_layers);
DECLARE_double(decoder_dropout);
DECLARE_double(decoder_layerdrop);

// Sets up the compatibility handling for the flags declared above (see
// Compat.cpp). The model converter calls this at the start of main().
void initCompat();
} // namespace asr
} // namespace app
} // namespace fl
204 changes: 204 additions & 0 deletions flashlight/app/asr/tools/serialization/ModelConverter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <glog/logging.h>
#include <string.h>
#include <memory>
#include <unordered_map>

#include <flashlight/fl/flashlight.h>

#include "flashlight/app/asr/common/Defines.h"
#include "flashlight/app/asr/common/Flags.h"
#include "flashlight/app/asr/criterion/criterion.h"
#include "flashlight/app/asr/data/FeatureTransforms.h"
#include "flashlight/app/asr/data/Utils.h"
#include "flashlight/app/asr/runtime/runtime.h"
#include "flashlight/app/asr/tools/serialization/Compat.h"
#include "flashlight/ext/common/SequentialBuilder.h"
#include "flashlight/ext/common/Serializer.h"
#include "flashlight/ext/plugin/ModulePlugin.h"

using namespace fl::app::asr;

namespace {

// Returns the output path of the converted model: the input path with a
// ".new" suffix appended.
auto newModelPath = [](const std::string& path) -> std::string {
  std::string converted(path);
  converted.append(".new");
  return converted;
};

// Returns the path of the intermediate binary dump: the input path with a
// ".tmp" suffix appended.
auto tempModelPath = [](const std::string& path) -> std::string {
  std::string scratch(path);
  scratch.append(".tmp");
  return scratch;
};

// Overwrites the parameters of `ntwrk` and `crit` with arrays read from the
// binary dump at `fname`.
//
// Both modules are first switched to eval mode. Parameter i of the network
// is read from the key "net-<i>" and parameter i of the criterion from the
// key "crt-<i>", matching the keys written by saveToBinaryDump().
void loadFromBinaryDump(
    const char* fname,
    std::shared_ptr<fl::Module> ntwrk,
    std::shared_ptr<fl::Module> crit) {
  ntwrk->eval();
  crit->eval();

  int netIdx = 0;
  while (netIdx < ntwrk->params().size()) {
    const std::string arrayKey = "net-" + std::to_string(netIdx);
    ntwrk->setParams(
        fl::Variable(af::readArray(fname, arrayKey.c_str()), false), netIdx);
    ++netIdx;
  }

  int critIdx = 0;
  while (critIdx < crit->params().size()) {
    const std::string arrayKey = "crt-" + std::to_string(critIdx);
    crit->setParams(
        fl::Variable(af::readArray(fname, arrayKey.c_str()), false), critIdx);
    ++critIdx;
  }
}

// Writes every parameter of `ntwrk` and `crit` to the binary dump at `fname`.
//
// Both modules are first switched to eval mode. Parameter i of the network
// is stored under the key "net-<i>" and parameter i of the criterion under
// the key "crt-<i>"; loadFromBinaryDump() reads the same keys back.
//
// The first array written (whichever module it comes from) is saved with
// append == false and all later arrays with append == true, so arrays from a
// stale dump at the same path are not carried over even when the network has
// no parameters.
void saveToBinaryDump(
    const char* fname,
    std::shared_ptr<fl::Module> ntwrk,
    std::shared_ptr<fl::Module> crit) {
  ntwrk->eval();
  crit->eval();

  // False only for the very first saveArray() call.
  bool append = false;

  // Query the counts once; they do not change while saving.
  const int numNetParams = static_cast<int>(ntwrk->params().size());
  for (int i = 0; i < numNetParams; ++i) {
    const std::string key = "net-" + std::to_string(i);
    af::saveArray(key.c_str(), ntwrk->param(i).array(), fname, append);
    append = true;
  }

  const int numCritParams = static_cast<int>(crit->params().size());
  for (int i = 0; i < numCritParams; ++i) {
    const std::string key = "crt-" + std::to_string(i);
    af::saveArray(key.c_str(), crit->param(i).array(), fname, append);
    append = true;
  }
}

int getSpeechFeatureSize() {
fl::lib::audio::FeatureParams featParams(
FLAGS_samplerate,
FLAGS_framesizems,
FLAGS_framestridems,
FLAGS_filterbanks,
FLAGS_lowfreqfilterbank,
FLAGS_highfreqfilterbank,
FLAGS_mfcccoeffs,
kLifterParam /* lifterparam */,
FLAGS_devwin /* delta window */,
FLAGS_devwin /* delta-delta window */);
featParams.useEnergy = false;
featParams.usePower = false;
featParams.zeroMeanFrame = false;
auto featureRes =
getFeatureType(FLAGS_features_type, FLAGS_channels, featParams);
return featureRes.first;
}

} // namespace

// Entry point of fl_asr_model_converter.
//
// Usage: fl_asr_model_converter [old/new] [model_path]
//
//   old: load the serialized model at `model_path` and write its network and
//        criterion parameters to `<model_path>.tmp`.
//   new: read the gflags stored in the model at `model_path`, rebuild the
//        network and criterion from them, fill in their parameters from
//        `<model_path>.tmp` and serialize the result to `<model_path>.new`.
//
// Any usage or configuration error is reported through LOG(FATAL).
int main(int argc, char** argv) {
  std::shared_ptr<fl::Module> network;
  std::shared_ptr<SequenceCriterion> criterion;
  std::unordered_map<std::string, std::string> cfg;

  initCompat();

  if (argc < 3) {
    // argv[1] is the mode and argv[2] the model path; keep this message in
    // the same order as the parsing below.
    LOG(FATAL)
        << "Incorrect usage. 'fl_asr_model_converter [old/new] [model_path]'";
  }

  std::string binaryType = argv[1];
  std::string modelPath = argv[2];
  std::string version;
  if (binaryType == "old") {
    LOG(INFO) << "Saving params from `old binary` model to a binary dump";
    fl::ext::Serializer::load(modelPath, version, cfg, network, criterion);
    saveToBinaryDump(tempModelPath(modelPath).c_str(), network, criterion);
  } else if (binaryType == "new") {
    LOG(INFO) << "Loading model params from binary dump to `new binary` model";

    // Read gflags from old model
    fl::ext::Serializer::load(modelPath, version, cfg);
    auto flags = cfg.find(kGflags);
    LOG_IF(FATAL, flags == cfg.end()) << "Invalid config loaded";
    gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);
    // Parsed after the stored flags, so command-line values are applied last.
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    handleDeprecatedFlags();

    auto numFeatures = getSpeechFeatureSize();

    fl::lib::text::Dictionary tokenDict(FLAGS_tokens);
    auto scalemode = getCriterionScaleMode(FLAGS_onorm, FLAGS_sqnorm);
    // Setup-specific modifications
    for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
      tokenDict.addEntry(std::to_string(r));
    }
    // ctc expects the blank label last
    if (FLAGS_criterion == kCtcCriterion) {
      tokenDict.addEntry(kBlankToken);
    }
    bool isSeq2seqCrit = FLAGS_criterion == kSeq2SeqTransformerCriterion ||
        FLAGS_criterion == kSeq2SeqRNNCriterion;
    if (isSeq2seqCrit) {
      tokenDict.addEntry(fl::app::asr::kEosToken);
      tokenDict.addEntry(fl::lib::text::kPadToken);
    }

    int numClasses = tokenDict.indexSize();

    // Initialize Network and Criterion
    if (fl::lib::endsWith(FLAGS_arch, ".so")) {
      network = fl::ext::ModulePlugin(FLAGS_arch).arch(numFeatures, numClasses);
    } else {
      network =
          fl::ext::buildSequentialModule(FLAGS_arch, numFeatures, numClasses);
    }
    if (FLAGS_criterion == kCtcCriterion) {
      criterion = std::make_shared<CTCLoss>(scalemode);
    } else if (FLAGS_criterion == kAsgCriterion) {
      criterion =
          std::make_shared<ASGLoss>(numClasses, scalemode, FLAGS_transdiag);
    } else if (FLAGS_criterion == kSeq2SeqRNNCriterion) {
      std::vector<std::shared_ptr<AttentionBase>> attentions;
      for (int i = 0; i < FLAGS_decoderattnround; i++) {
        attentions.push_back(createAttention());
      }
      criterion = std::make_shared<Seq2SeqCriterion>(
          numClasses,
          FLAGS_encoderdim,
          tokenDict.getIndex(fl::app::asr::kEosToken),
          tokenDict.getIndex(fl::lib::text::kPadToken),
          FLAGS_maxdecoderoutputlen,
          attentions,
          createAttentionWindow(),
          FLAGS_trainWithWindow,
          FLAGS_pctteacherforcing,
          FLAGS_labelsmooth,
          FLAGS_inputfeeding,
          FLAGS_samplingstrategy,
          FLAGS_gumbeltemperature,
          FLAGS_decoderrnnlayer,
          FLAGS_decoderattnround,
          FLAGS_decoderdropout);
    } else if (FLAGS_criterion == kSeq2SeqTransformerCriterion) {
      criterion = std::make_shared<TransformerCriterion>(
          numClasses,
          FLAGS_encoderdim,
          tokenDict.getIndex(fl::app::asr::kEosToken),
          tokenDict.getIndex(fl::lib::text::kPadToken),
          FLAGS_maxdecoderoutputlen,
          FLAGS_am_decoder_tr_layers,
          createAttention(),
          createAttentionWindow(),
          FLAGS_trainWithWindow,
          FLAGS_labelsmooth,
          FLAGS_pctteacherforcing,
          FLAGS_am_decoder_tr_dropout,
          FLAGS_am_decoder_tr_layerdrop);
    } else {
      LOG(FATAL) << "unimplemented criterion";
    }

    // Fill the freshly built modules with the parameters dumped by the
    // `old` step, then serialize them next to the original model.
    loadFromBinaryDump(tempModelPath(modelPath).c_str(), network, criterion);
    fl::ext::Serializer::save(
        newModelPath(modelPath), FL_APP_ASR_VERSION, cfg, network, criterion);

  } else {
    LOG(FATAL) << "Incorrect binary type specified: '" << binaryType
               << "'. Expected 'old' or 'new'.";
  }

  LOG(INFO) << "Done !";

  return 0;
}