Skip to content

Commit

Permalink
Add support for zero posEmb (#3)
Browse files Browse the repository at this point in the history
Summary:
This commit allows disabling relative positional embeddings by passing bptt = 0 to the Transformer constructor.

Pull Request resolved: fairinternal/flashlight#3

Reviewed By: syhw, chaitan3

Differential Revision: D21958749

Pulled By: jacobkahn

fbshipit-source-id: 777e5d5625458858393616e425d85f91dbfe83da
  • Loading branch information
Chaitanya Talnikar authored and facebook-github-bot committed Jul 25, 2020
1 parent 831d7c3 commit 4f159fa
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions flashlight/contrib/modules/Transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ Transformer::Transformer(
transformerInitLinear(headDim * nHeads, modelDim))),
norm1_(std::make_shared<LayerNorm>(std::vector<int>({0, 3}))),
norm2_(std::make_shared<LayerNorm>(std::vector<int>({0, 3}))) {
params_.push_back(
uniform(2 * bptt - 1, headDim, -0.1, 0.1, af::dtype::f32, true));
if (bptt > 0) {
params_.push_back(
uniform(2 * bptt - 1, headDim, -0.1, 0.1, af::dtype::f32, true));
}

add(w1_);
add(w2_);
Expand Down Expand Up @@ -141,7 +143,9 @@ Variable Transformer::selfAttention(const std::vector<Variable>& input) {
auto v = transpose((*wv_)(concatenate(input, 1)));

Variable mask, posEmb;
posEmb = tile(params_[0], af::dim4(1, 1, nHeads_ * bsz));
if (bptt_ > 0) {
posEmb = tile(params_[0], af::dim4(1, 1, nHeads_ * bsz));
}
if (useMask_ && input.back().dims(1) > 1) {
mask = getMask(n, input.size() == 2);
}
Expand Down

0 comments on commit 4f159fa

Please sign in to comment.