Optimizing the Performance of Decoder fl_asr_decode #709

160 changes: 117 additions & 43 deletions flashlight/app/asr/Decode.cpp
@@ -81,6 +81,12 @@ int main(int argc, char** argv) {
}
fl::VerboseLogging::setMaxLoggingLevel(FLAGS_fl_vlog_level);

// flashlight optim mode
auto flOptimLevel = FLAGS_fl_optim_mode.empty()
? fl::OptimLevel::DEFAULT
: fl::OptimMode::toOptimLevel(FLAGS_fl_optim_mode);
fl::OptimMode::get().setOptimLevel(flOptimLevel);
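  // Illustrative note: with this wiring, e.g. --fl_optim_mode=O1 maps to
  // fl::OptimLevel::O1, while leaving the flag empty keeps
  // fl::OptimLevel::DEFAULT.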

/* ===================== Create Network ===================== */
if (FLAGS_emission_dir.empty() && FLAGS_am.empty()) {
LOG(FATAL) << "Both flags are empty: `-emission_dir` and `-am`";
@@ -330,15 +336,16 @@ int main(int argc, char** argv) {
/*sfxConf=*/{});
auto targetTransform = targetFeatures(tokenDict, lexicon, targetGenConfig);
auto wordTransform = wordFeatures(wordDict);
-      int targetpadVal =
-          isSeq2seqCrit ? tokenDict.getIndex(fl::lib::text::kPadToken) : kTargetPadValue;
-      int wordpadVal = wordDict.getIndex(kUnkToken);
+      int targetpadVal = isSeq2seqCrit
+          ? tokenDict.getIndex(fl::lib::text::kPadToken)
+          : kTargetPadValue;
+      int wordpadVal = -1;

std::vector<std::string> testSplits = fl::lib::split(",", FLAGS_test, true);
auto ds = createDataset(
testSplits,
FLAGS_datadir,
-      1 /* batchsize */,
+      FLAGS_batchsize,
inputTransform,
targetTransform,
wordTransform,
@@ -364,7 +371,9 @@ int main(int argc, char** argv) {
&tokenDict,
&wordDict,
&emissionQueue,
-                 &isSeq2seqCrit](int tid) {
+                 &isSeq2seqCrit,
+                 targetpadVal,
+                 wordpadVal](int tid) {
// Initialize AM
fl::setDevice(tid);
std::shared_ptr<fl::Module> localNetwork = network;
@@ -387,62 +396,127 @@ int main(int argc, char** argv) {
localDs = std::make_shared<fl::PrefetchDataset>(
localDs, FLAGS_nthread, FLAGS_nthread);

auto trimPad = [](std::vector<int>& vec, const int padToken) {
if (vec.back() != padToken) {
return;
}
int left = 0;
int right = vec.size() - 1;
while (left < right) {
int mid = (left + right) / 2;
if (vec[mid] != padToken && vec[mid + 1] == padToken) {
vec.resize(mid + 1);
return;
} else if (vec[mid] == padToken) {
right = mid;
} else {
left = mid;
}
}
throw std::runtime_error("[Decoder] Could not identify a non-pad token.");
};
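          // A minimal usage sketch (illustrative; assumes padding forms a
          // contiguous suffix):
          //   std::vector<int> t = {5, 9, 2, -1, -1, -1};
          //   trimPad(t, -1); // t becomes {5, 9, 2}
          // The search keeps the invariant vec[right] == padToken, so the
          // boundary is found in O(log n) rather than by a linear scan.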

for (auto& sample : *localDs) {
-          auto sampleId = readSampleIds(sample[kSampleIdx]).front();
+          auto sampleIds = readSampleIds(sample[kSampleIdx]);
+          auto tokenTargetArray = sample[kTargetIdx];
+          auto wordTargetArray = sample[kWordIdx];
+          int effectiveBatchSize = tokenTargetArray.dims(1);

/* 2. Load Targets */
-          TargetUnit targetUnit;
-          auto tokenTarget = afToVector<int>(sample[kTargetIdx]);
-          auto wordTarget = afToVector<int>(sample[kWordIdx]);
-          // TODO: we will reform the dataset so that the loaded word
-          // targets are strings already
-          std::vector<std::string> wordTargetStr;
-          if (FLAGS_uselexicon) {
-            wordTargetStr = wrdIdx2Wrd(wordTarget, wordDict);
-          } else {
-            auto letterTarget = tknTarget2Ltr(
-                tokenTarget,
-                tokenDict,
-                FLAGS_criterion,
-                FLAGS_surround,
-                isSeq2seqCrit,
-                FLAGS_replabel,
-                FLAGS_usewordpiece,
-                FLAGS_wordseparator);
-            wordTargetStr = tkn2Wrd(letterTarget, FLAGS_wordseparator);
-          }
+          std::vector<TargetUnit> targetUnits(effectiveBatchSize);
+          auto tokenTargetLength = tokenTargetArray.dims(0);
+          auto wordTargetLength = wordTargetArray.dims(0);
+
+          // MTMD: Shouldn't we skip placing these two (i.e., tokenTargetArray and
+          // wordTargetArray) in an array in the first place?
+          auto tokenTargets = afToVector<int>(tokenTargetArray);
+          auto wordTargets = afToVector<int>(wordTargetArray);
+
+          auto tokenTarget = std::vector<int>(tokenTargetLength);
+          auto wordTarget = std::vector<int>(wordTargetLength);
+          for (int i = 0; i < effectiveBatchSize; i++) {
+            tokenTarget.resize(tokenTargetLength);
+            wordTarget.resize(wordTargetLength);
+
+            copy(
+                tokenTargets.begin() + i * tokenTargetLength,
+                tokenTargets.begin() + (i + 1) * tokenTargetLength,
+                tokenTarget.begin());
+
+            copy(
+                wordTargets.begin() + i * wordTargetLength,
+                wordTargets.begin() + (i + 1) * wordTargetLength,
+                wordTarget.begin());
+
+            trimPad(tokenTarget, targetpadVal);
+            trimPad(wordTarget, wordpadVal);
+
+            // TODO: we will reform the dataset so that the loaded word
+            // targets are strings already
+            std::vector<std::string> wordTargetStr;
+            if (FLAGS_uselexicon) {
+              wordTargetStr = wrdIdx2Wrd(wordTarget, wordDict);
+            } else {
+              auto letterTarget = tknTarget2Ltr(
+                  tokenTarget,
+                  tokenDict,
+                  FLAGS_criterion,
+                  FLAGS_surround,
+                  isSeq2seqCrit,
+                  FLAGS_replabel,
+                  FLAGS_usewordpiece,
+                  FLAGS_wordseparator);
+              wordTargetStr = tkn2Wrd(letterTarget, FLAGS_wordseparator);
+            }
+            targetUnits[i].wordTargetStr = wordTargetStr;
+            targetUnits[i].tokenTarget = tokenTarget;
+          }

-          targetUnit.wordTargetStr = wordTargetStr;
-          targetUnit.tokenTarget = tokenTarget;
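          // Layout note (illustrative): afToVector flattens the
          // (targetLength x batchsize) target arrays column-major, so sample
          // i occupies [i * targetLength, (i + 1) * targetLength) before
          // trimPad removes its pad suffix.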

/* 3. Load Emissions */
-          EmissionUnit emissionUnit;
+          std::vector<EmissionUnit> emissionUnit;
+          auto mathType = FLAGS_fl_amp_use_mixed_precision ? f16 : f32;
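          // Note: with --fl_amp_use_mixed_precision set, inputs are cast to
          // f16 below so the forward pass runs in half precision (presumably
          // mirroring the AMP path used in training).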
if (FLAGS_emission_dir.empty()) {
fl::Variable rawEmission;
if (usePlugin) {
-              rawEmission = localNetwork
-                                ->forward({fl::input(sample[kInputIdx]),
-                                           fl::noGrad(sample[kDurationIdx])})
-                                .front();
+              rawEmission =
+                  localNetwork
+                      ->forward({fl::input(sample[kInputIdx].as(mathType)),
+                                 fl::noGrad(sample[kDurationIdx])})
+                      .front();

} else {
rawEmission = fl::ext::forwardSequentialModuleWithPadMask(
fl::input(sample[kInputIdx]), localNetwork, sample[kDurationIdx]);
}
-            emissionUnit = EmissionUnit(
-                afToVector<float>(rawEmission),
-                sampleId,
-                rawEmission.dims(1),
-                rawEmission.dims(0));
+            auto emissionSize = rawEmission.dims(0) * rawEmission.dims(1);
+            auto rawEmissions = afToVector<float>(rawEmission);
+            std::vector<float> rawEmissionHost(emissionSize);
+            for (int i = 0; i < effectiveBatchSize; i++) {
+              std::copy(
+                  rawEmissions.begin() + i * emissionSize,
+                  rawEmissions.begin() + (i + 1) * emissionSize,
+                  rawEmissionHost.begin());
+              emissionUnit.emplace_back(
+                  rawEmissionHost,
+                  sampleIds[i],
+                  rawEmission.dims(1),
+                  rawEmission.dims(0));
+            }
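            // Each sample's emission occupies a contiguous block of
            // rawEmission.dims(0) * rawEmission.dims(1) floats in the
            // column-major flattening, hence the per-sample copies above.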
} else {
auto cleanTestPath = cleanFilepath(FLAGS_test);
std::string emissionDir =
pathsConcat(FLAGS_emission_dir, cleanTestPath);
-            std::string savePath = pathsConcat(emissionDir, sampleId + ".bin");
-            std::string eVersion;
-            Serializer::load(savePath, eVersion, emissionUnit);
+            for (auto sampleId : sampleIds) {
+              std::string savePath = pathsConcat(emissionDir, sampleId + ".bin");
+              std::string eVersion;
+              EmissionUnit tmpEmissionUnit;
+              Serializer::load(savePath, eVersion, tmpEmissionUnit);
+              emissionUnit.push_back(tmpEmissionUnit);
+            }
}
+          for (int i = 0; i < effectiveBatchSize; i++) {
+            emissionQueue.add({emissionUnit[i], targetUnits[i]});
+          }

-          emissionQueue.add({emissionUnit, targetUnit});
}

localNetwork.reset(); // AM is only used in running forward pass. So we will
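The slicing-and-trimming pattern above is straightforward to check in isolation. Below is a self-contained sketch (an editorial illustration, not part of this diff; `unbatch`, the lengths, and the pad value of -1 are invented for the example) that slices a flattened column-major batch and trims each sample's pad suffix, mirroring the loop in Decode.cpp:

#include <cassert>
#include <vector>

// Mirrors the PR's per-sample slicing: `flat` holds a (len x batch)
// column-major matrix; sample i is flat[i * len, (i + 1) * len).
std::vector<std::vector<int>> unbatch(
    const std::vector<int>& flat, int len, int batch, int pad) {
  std::vector<std::vector<int>> out(batch);
  for (int i = 0; i < batch; ++i) {
    out[i].assign(flat.begin() + i * len, flat.begin() + (i + 1) * len);
    // Trim the contiguous pad suffix (linear scan for clarity; the PR
    // uses a binary search).
    while (!out[i].empty() && out[i].back() == pad) {
      out[i].pop_back();
    }
  }
  return out;
}

int main() {
  // Two samples padded to length 4 with pad value -1.
  std::vector<int> flat = {5, 9, 2, -1, 7, -1, -1, -1};
  auto s = unbatch(flat, 4, 2, -1);
  assert((s[0] == std::vector<int>{5, 9, 2}));
  assert((s[1] == std::vector<int>{7}));
  return 0;
}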
4 changes: 2 additions & 2 deletions flashlight/app/asr/common/Flags.cpp
@@ -312,7 +312,7 @@ DEFINE_bool(
DEFINE_string(
fl_optim_mode,
"",
"[train] Sets the flashlight optimization mode. "
"Sets the flashlight optimization mode. "
"Optim modes can be O1, O2, or O3.");
DEFINE_string(
fl_log_level,
@@ -332,7 +332,7 @@ DEFINE_int64(
DEFINE_bool(
fl_amp_use_mixed_precision,
false,
"[train] Use mixed precision for training - scale loss and gradients up and down "
"Use mixed precision for training - scale loss and gradients up and down "
"by a scale factor that changes over time. If no fl optim mode is "
"specified with --fl_optim_mode when passing this flag, automatically "
"sets the optim mode to O1.");
2 changes: 1 addition & 1 deletion flashlight/app/asr/decoder/TranscriptionUtils.cpp
@@ -40,7 +40,7 @@ std::vector<std::string> tknIdx2Ltr(
if (result.front() == wordSep) {
result.erase(result.begin());
}
-    if (!result.empty() && result.back() == wordSep) {
+    while (!result.empty() && result.back() == wordSep) {
result.pop_back();
}
}
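The `if` → `while` change matters whenever a transcript ends with a run of word separators rather than a single one; the old code stripped only the last separator. A minimal standalone illustration (the values are invented for the example):

#include <cassert>
#include <string>
#include <vector>

int main() {
  const std::string wordSep = "_";
  std::vector<std::string> result = {"h", "i", "_", "_"};
  // With the old single `if`, one trailing separator would remain,
  // leaving {"h", "i", "_"}. The loop removes the whole run:
  while (!result.empty() && result.back() == wordSep) {
    result.pop_back();
  }
  assert((result == std::vector<std::string>{"h", "i"}));
  return 0;
}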
2 changes: 2 additions & 0 deletions flashlight/fl/autograd/CMakeLists.txt
@@ -24,6 +24,7 @@ if (FL_USE_CPU)
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/operators/AdvancedIndex.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/Conv2D.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/Pool2D.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/PositionalEmbedding.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/RNN.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/BatchNorm.cpp # generic
${CMAKE_CURRENT_LIST_DIR}/backend/cpu/DnnlUtils.cpp # generic
@@ -62,6 +63,7 @@ if (FL_USE_CUDA)
${CMAKE_CURRENT_LIST_DIR}/backend/cuda/CudnnUtils.h
${CMAKE_CURRENT_LIST_DIR}/backend/cuda/CudnnUtils.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cuda/Pool2D.cpp
${CMAKE_CURRENT_LIST_DIR}/backend/cuda/PositionalEmbedding.cu
${CMAKE_CURRENT_LIST_DIR}/backend/cuda/RNN.cpp
)

9 changes: 3 additions & 6 deletions flashlight/fl/autograd/Functions.cpp
@@ -1236,16 +1236,13 @@ Variable gelu(const Variable& in) {
fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input)));
}

-fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) {
+fl::Variable relativePositionalEmbeddingRotate(const fl::Variable& input) {
auto data = input.array();
int d0 = data.dims(0);
int d1 = data.dims(1);
int d2 = data.dims(2);
int d3 = data.dims(3);
-  data = af::join(0, data, af::constant(0.0, d1, d1, d2, d3, data.type()));
-  data = af::moddims(data, af::dim4((d0 + d1) * d1, 1, d2, d3));
-  data = data.rows(0, (d1 + d0 - 1) * d1 - 1);
-  data = af::moddims(data, af::dim4(d0 + d1 - 1, d1, d2, d3));
+  data = relativePositionalEmbeddingRotate(data);
auto gradFunc = [d0, d1, d2, d3](
std::vector<fl::Variable>& inputs,
const fl::Variable& gradOutput) {
@@ -1283,7 +1280,7 @@ fl::Variable multiheadAttention(
if (!posEmb.isempty()) {
int n = posEmb.dims(0) / 2 - offset;
auto pscores =
-        relativePositionEmbeddingRotate(matmulNT(posEmb.as(q.type()), q));
+        relativePositionalEmbeddingRotate(matmulNT(posEmb.as(q.type()), q));
scores = scores + transpose(pscores.rows(n, n + k.dims(0) - 1));
}
if (!mask.isempty()) {
7 changes: 7 additions & 0 deletions flashlight/fl/autograd/Functions.h
@@ -961,6 +961,13 @@ Variable relu(const Variable& input);
*/
Variable gelu(const Variable& input);

/**
 * Relative positional embedding for multihead attention.
 * The implementation partially follows https://arxiv.org/pdf/1803.02155.pdf.
 * An overloaded version of this function handles `Variable` inputs.
*/
af::array relativePositionalEmbeddingRotate(const af::array& input);

/**
 * Relative positional embedding for multihead attention.
 * The implementation partially follows https://arxiv.org/pdf/1803.02155.pdf.
16 changes: 16 additions & 0 deletions flashlight/fl/autograd/backend/cpu/PositionalEmbedding.cpp
@@ -0,0 +1,16 @@
#include <arrayfire.h>

namespace fl {
af::array relativePositionalEmbeddingRotate(const af::array& input) {
int d0 = input.dims(0);
int d1 = input.dims(1);
int d2 = input.dims(2);
int d3 = input.dims(3);
auto data =
af::join(0, input, af::constant(0.0, d1, d1, d2, d3, input.type()));
data = af::moddims(data, af::dim4((d0 + d1) * d1, 1, d2, d3));
data = data.rows(0, (d1 + d0 - 1) * d1 - 1);
data = af::moddims(data, af::dim4(d0 + d1 - 1, d1, d2, d3));
return data;
}
} // namespace fl
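The effect of the join/moddims/rows sequence is easiest to see on a small input: column c of the output is the input column shifted down by c rows (with zeros elsewhere), which converts relative-position scores into an absolutely aligned layout. A standalone sketch of the same steps on a 3×2 input (editorial illustration, not part of the diff):

#include <arrayfire.h>

int main() {
  // Column-major input: col 0 = {1, 2, 3}, col 1 = {4, 5, 6}.
  float h[] = {1, 2, 3, 4, 5, 6};
  af::array in(3, 2, h);
  const int d0 = 3, d1 = 2;
  auto d = af::join(0, in, af::constant(0.0, d1, d1, in.type()));
  d = af::moddims(d, af::dim4((d0 + d1) * d1, 1));
  d = d.rows(0, (d1 + d0 - 1) * d1 - 1);
  d = af::moddims(d, af::dim4(d0 + d1 - 1, d1));
  af_print(d); // col 0 = {1, 2, 3, 0}; col 1 = {0, 4, 5, 6}
  return 0;
}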
76 changes: 76 additions & 0 deletions flashlight/fl/autograd/backend/cuda/PositionalEmbedding.cu
@@ -0,0 +1,76 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <device_launch_parameters.h>

#include <arrayfire.h>

#include <flashlight/fl/common/backend/cuda/CudaUtils.h>

namespace fl {

namespace {
#define NUM_THREADS 1024
} // namespace

template <typename T>
__global__ void relativePositionalEmbeddingRotateKernel(
const T* __restrict in,
T* __restrict out,
const int input_dim_0,
const int output_dim_0,
const int dim_1,
const int dim_2) {
const int featureMapId = blockIdx.x / dim_1;
const int columnId = blockIdx.x % dim_1;

const int inputFeatureMapSize = input_dim_0 * dim_1;
const int outputFeatureMapSize = output_dim_0 * dim_1;
const T* input = &in[featureMapId * inputFeatureMapSize];
T* output = &out[featureMapId * outputFeatureMapSize];
const int inputOffset = columnId * input_dim_0 - columnId;
const int outputOffset = columnId * output_dim_0;
for (int i = threadIdx.x; i < output_dim_0; i += blockDim.x) {
T tmp = 0;
if (i >= columnId && i < input_dim_0 + columnId) {
tmp = input[inputOffset + i];
}
output[outputOffset + i] = tmp;
}
}
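// Launch configuration used below: one block per (column, feature map)
// pair (gridDim.x == dim_1 * dim_2), with NUM_THREADS threads striding
// over the output rows of that column.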

af::array relativePositionalEmbeddingRotate(const af::array& input) {
const int input_dim_0 = input.dims(0);
const int dim_1 = input.dims(1);
const int dim_2 = input.dims(2);
const int output_dim_0 = input_dim_0 + dim_1 - 1;
auto output = af::array(af::dim4(output_dim_0, dim_1, dim_2), input.type());
input.eval();
output.eval();

if (input.type() == f16) {
relativePositionalEmbeddingRotateKernel<half>
<<<dim_1 * dim_2, NUM_THREADS, 0, fl::cuda::getActiveStream()>>>(
input.device<half>(),
output.device<half>(),
input_dim_0,
output_dim_0,
dim_1,
dim_2);
} else if (input.type() == f32) {
relativePositionalEmbeddingRotateKernel<float>
<<<dim_1 * dim_2, NUM_THREADS, 0, fl::cuda::getActiveStream()>>>(
input.device<float>(),
output.device<float>(),
input_dim_0,
output_dim_0,
dim_1,
dim_2);
} else {
throw std::runtime_error("Unsupported Type in Position Embedding.");
}
input.unlock();
output.unlock();
return output;
}
} // namespace fl
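For checking the kernel's semantics, the same indexing in plain C++ (an editorial sketch; `rotateReference` is not part of this diff): per feature map, column c of the (inD0 × d1) input lands at rows [c, c + inD0) of the (inD0 + d1 − 1) × d1 output, and every other entry is zero.

#include <vector>

// CPU reference for relativePositionalEmbeddingRotateKernel's indexing
// (column-major within each feature map).
std::vector<float> rotateReference(
    const std::vector<float>& in, int inD0, int d1, int maps) {
  const int outD0 = inD0 + d1 - 1;
  std::vector<float> out(static_cast<size_t>(outD0) * d1 * maps, 0.0f);
  for (int m = 0; m < maps; ++m) {
    for (int c = 0; c < d1; ++c) {
      for (int i = 0; i < inD0; ++i) {
        out[m * outD0 * d1 + c * outD0 + (c + i)] =
            in[m * inD0 * d1 + c * inD0 + i];
      }
    }
  }
  return out;
}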