Support Transformer models with post normalization (#469)
guillaumekln committed May 4, 2021
1 parent ca238cb commit 316ed26
Showing 7 changed files with 93 additions and 50 deletions.
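
The commit replaces the per-layer LayerNormStrategy enum with a single pre_norm flag that is threaded from the converter spec down to every attention and feed-forward block, so that models trained with post normalization can be loaded. For orientation, the sketch below is illustrative Python pseudocode, not code from this commit; sublayer and layer_norm are placeholder names for an attention or feed-forward block and its LayerNorm.

def pre_norm_block(x, sublayer, layer_norm):
    # pre_norm=True (the previous behavior): normalize the input, run the
    # sublayer, then add the residual; a final LayerNorm is applied once
    # after the last layer of the stack.
    return x + sublayer(layer_norm(x))

def post_norm_block(x, sublayer, layer_norm):
    # pre_norm=False (the new post-normalization mode): run the sublayer on
    # the raw input, add the residual, then normalize; no stack-level output
    # LayerNorm is needed.
    return layer_norm(x + sublayer(x))
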
9 changes: 2 additions & 7 deletions include/ctranslate2/layers/attention.h
@@ -10,19 +10,14 @@ namespace ctranslate2 {
dim_t max_position,
bool with_cache = false);

enum class LayerNormStrategy {
Input,
Output,
};

class MultiHeadAttention : public Layer
{
public:
MultiHeadAttention(const models::Model& model,
const std::string& scope,
dim_t num_heads,
bool self_attention,
LayerNormStrategy layer_norm_strategy = LayerNormStrategy::Input);
bool pre_norm = true);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& queries,
@@ -37,7 +32,7 @@ namespace ctranslate2 {
const dim_t _num_heads;
const bool _self_attention;
const std::vector<Dense> _linear;
const LayerNormStrategy _layer_norm_strategy;
const bool _pre_norm;
const LayerNorm _layer_norm;
const StorageView* _relative_position_keys;
const StorageView* _relative_position_values;
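
Previously each MultiHeadAttention chose between LayerNormStrategy::Input and LayerNormStrategy::Output; the new boolean covers the same two cases. An illustrative mapping (for reference only, not part of the commit):

# How the removed enum values correspond to the new constructor flag.
LAYER_NORM_STRATEGY_TO_PRE_NORM = {
    "Input": True,    # normalize the sublayer input (default)
    "Output": False,  # normalize the sublayer output after the residual add
}
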
25 changes: 16 additions & 9 deletions include/ctranslate2/layers/transformer.h
@@ -12,7 +12,9 @@ namespace ctranslate2 {
class FeedForwardNetwork : public Layer
{
public:
FeedForwardNetwork(const models::Model& model, const std::string& scope);
FeedForwardNetwork(const models::Model& model,
const std::string& scope,
const bool pre_norm = true);

void operator()(const StorageView& input, StorageView& output) const;

@@ -26,6 +28,7 @@ namespace ctranslate2 {

private:
const LayerNorm _layer_norm;
const bool _pre_norm;
const Activation _activation;
const Dense _ff1;
const Dense _ff2;
@@ -36,7 +39,8 @@ namespace ctranslate2 {
public:
TransformerEncoderLayer(const models::Model& model,
const std::string& scope,
const size_t num_heads);
const size_t num_heads,
const bool pre_norm = true);

void operator()(const StorageView& input,
const StorageView& lengths,
@@ -62,7 +66,8 @@ namespace ctranslate2 {
TransformerDecoderLayer(const models::Model& model,
const std::string& scope,
const size_t num_heads,
const bool with_encoder_attention = true);
const bool with_encoder_attention = true,
const bool pre_norm = true);

void operator()(const StorageView& input,
const StorageView* memory,
@@ -95,25 +100,26 @@ namespace ctranslate2 {
TransformerEncoder(const models::Model& model,
const std::string& scope,
const size_t num_heads,
const bool with_position_encoding = true);
const bool with_position_encoding = true,
const bool pre_norm = true);

void operator()(const StorageView& ids,
const StorageView& lengths,
StorageView& output) override;

DataType output_type() const {
return _output_norm.output_type();
return _layers.back()->output_type();
}

dim_t output_size() const {
return _output_norm.output_size();
return _layers.back()->output_size();
}

private:
const Embeddings _embeddings;
const ComputeType _compute_type;
const std::unique_ptr<PositionEncoder> _position_encoder;
const LayerNorm _output_norm;
const std::unique_ptr<LayerNorm> _output_norm;
std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
};

@@ -124,7 +130,8 @@ namespace ctranslate2 {
const std::string& scope,
const size_t num_heads,
const bool with_position_encoding = true,
const bool with_encoder_attention = true);
const bool with_encoder_attention = true,
const bool pre_norm = true);

void set_vocabulary_mask(const StorageView& ids) override;
void reset_vocabulary_mask() override;
@@ -152,7 +159,7 @@ namespace ctranslate2 {
const ComputeType _compute_type;
const Embeddings _embeddings;
const std::unique_ptr<PositionEncoder> _position_encoder;
const LayerNorm _output_norm;
const std::unique_ptr<LayerNorm> _output_norm;
std::vector<std::unique_ptr<const TransformerDecoderLayer>> _layers;
Dense _proj;
};
1 change: 1 addition & 0 deletions include/ctranslate2/models/transformer.h
@@ -25,6 +25,7 @@ namespace ctranslate2 {
private:
size_t _num_heads;
bool _with_relative_position;
bool _pre_norm;
};

}
25 changes: 18 additions & 7 deletions python/ctranslate2/specs/transformer_spec.py
@@ -14,16 +14,23 @@ class TransformerSpec(model_spec.SequenceToSequenceModelSpec):
explicitly set the number of layers and attention heads.
"""

def __init__(self, num_layers, num_heads, with_relative_position=False):
def __init__(
self,
num_layers,
num_heads,
with_relative_position=False,
pre_norm=True,
):
super().__init__()
if isinstance(num_layers, (list, tuple)):
num_encoder_layers, num_decoder_layers = num_layers
else:
num_encoder_layers, num_decoder_layers = num_layers, num_layers
self.num_heads = np.dtype("int8").type(num_heads)
self.pre_norm = np.dtype("int8").type(pre_norm)
self.with_relative_position = with_relative_position
self.encoder = TransformerEncoderSpec(num_encoder_layers)
self.decoder = TransformerDecoderSpec(num_decoder_layers)
self.encoder = TransformerEncoderSpec(num_encoder_layers, pre_norm=pre_norm)
self.decoder = TransformerDecoderSpec(num_decoder_layers, pre_norm=pre_norm)

@property
def name(self):
@@ -42,18 +49,22 @@ def vocabulary_size(self):


class TransformerEncoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers):
def __init__(self, num_layers, pre_norm=True):
self.embeddings = common_spec.EmbeddingsSpec()
self.position_encodings = PositionEncoderSpec()
self.layer_norm = common_spec.LayerNormSpec()
self.layer_norm = (
common_spec.LayerNormSpec() if pre_norm else model_spec.OPTIONAL
)
self.layer = [TransformerEncoderLayerSpec() for _ in range(num_layers)]


class TransformerDecoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers):
def __init__(self, num_layers, pre_norm=True):
self.embeddings = common_spec.EmbeddingsSpec()
self.position_encodings = PositionEncoderSpec()
self.layer_norm = common_spec.LayerNormSpec()
self.layer_norm = (
common_spec.LayerNormSpec() if pre_norm else model_spec.OPTIONAL
)
self.projection = common_spec.LinearSpec()
self.layer = [TransformerDecoderLayerSpec() for _ in range(num_layers)]

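
With the new keyword argument, a converter for a post-norm Transformer builds its spec with pre_norm=False, which marks the encoder and decoder output layer_norm as OPTIONAL. A minimal, hypothetical usage (the layer and head counts are arbitrary):

from ctranslate2.specs import transformer_spec

# 6 encoder and decoder layers, 8 attention heads, post normalization.
spec = transformer_spec.TransformerSpec(6, 8, pre_norm=False)
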
15 changes: 8 additions & 7 deletions src/layers/attention.cc
@@ -136,11 +136,11 @@ namespace ctranslate2 {
const std::string& scope,
dim_t num_heads,
bool self_attention,
LayerNormStrategy layer_norm_strategy)
bool pre_norm)
: _num_heads(num_heads)
, _self_attention(self_attention)
, _linear(make_linear_layers(model, scope, self_attention))
, _layer_norm_strategy(layer_norm_strategy)
, _pre_norm(pre_norm)
, _layer_norm(model, scope + "/layer_norm")
, _relative_position_keys(model.get_variable_if_exists(scope + "/relative_position_keys"))
, _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values"))
@@ -177,13 +177,14 @@ namespace ctranslate2 {
StorageView split_keys(dtype, device);
StorageView split_values(dtype, device);

if (_layer_norm_strategy == LayerNormStrategy::Input) {
const StorageView* q = &queries;
if (_pre_norm) {
_layer_norm(queries, queries_proj);
_linear[0](queries_proj, fused_proj);
} else {
_linear[0](queries, fused_proj);
q = &queries_proj;
}

_linear[0](*q, fused_proj);

if (!_self_attention) {
split_heads(fused_proj, split_queries);
if (cached_keys == nullptr || cached_keys->empty()) {
@@ -258,7 +259,7 @@ namespace ctranslate2 {

_linear.back()(combined, output);
ops::Add()(queries, output, output);
if (_layer_norm_strategy == LayerNormStrategy::Output) {
if (!_pre_norm) {
_layer_norm(output, output);
}
}
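
The attention block now applies its single _layer_norm either to the incoming queries (pre-norm) or to the residual sum (post-norm). A rough Python mirror of that control flow, using placeholder callables rather than the real C++ API:

def attention_block(queries, project_in, attend, project_out, layer_norm, pre_norm=True):
    x = layer_norm(queries) if pre_norm else queries
    context = attend(project_in(x))           # fused QKV projection + attention
    output = queries + project_out(context)   # residual uses the raw queries
    return output if pre_norm else layer_norm(output)
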
60 changes: 42 additions & 18 deletions src/layers/transformer.cc
@@ -4,30 +4,41 @@ namespace ctranslate2 {
namespace layers {

FeedForwardNetwork::FeedForwardNetwork(const models::Model& model,
const std::string& scope)
const std::string& scope,
const bool pre_norm)
: _layer_norm(model, scope + "/layer_norm")
, _pre_norm(pre_norm)
, _activation(ActivationType::ReLU)
, _ff1(model, scope + "/linear_0", &_activation)
, _ff2(model, scope + "/linear_1") {
}

void FeedForwardNetwork::operator()(const StorageView& input, StorageView& output) const {
const StorageView* x = &input;
if (_pre_norm) {
_layer_norm(input, output);
x = &output;
}

StorageView inner(input.dtype(), input.device());
_layer_norm(input, output);
_ff1(output, inner);
_ff1(*x, inner);
_ff2(inner, output);
ops::Add()(input, output, output);
if (!_pre_norm)
_layer_norm(output, output);
}


TransformerEncoderLayer::TransformerEncoderLayer(const models::Model& model,
const std::string& scope,
const size_t num_heads)
const size_t num_heads,
const bool pre_norm)
: _self_attention(model,
scope + "/self_attention",
num_heads,
/*self_attention=*/true)
, _ff(model, scope + "/ffn") {
/*self_attention=*/true,
pre_norm)
, _ff(model, scope + "/ffn", pre_norm) {
}

void TransformerEncoderLayer::operator()(const StorageView& input,
@@ -44,18 +55,21 @@ namespace ctranslate2 {
TransformerDecoderLayer::TransformerDecoderLayer(const models::Model& model,
const std::string& scope,
const size_t num_heads,
const bool with_encoder_attention)
const bool with_encoder_attention,
const bool pre_norm)
: _self_attention(model,
scope + "/self_attention",
num_heads,
/*self_attention=*/true)
/*self_attention=*/true,
pre_norm)
, _encoder_attention(with_encoder_attention
? std::make_unique<MultiHeadAttention>(model,
scope + "/attention",
num_heads,
/*self_attention=*/false)
/*self_attention=*/false,
pre_norm)
: nullptr)
, _ff(model, scope + "/ffn") {
, _ff(model, scope + "/ffn", pre_norm) {
}

void TransformerDecoderLayer::operator()(const StorageView& input,
@@ -99,19 +113,23 @@ namespace ctranslate2 {
TransformerEncoder::TransformerEncoder(const models::Model& model,
const std::string& scope,
const size_t num_heads,
const bool with_position_encoding)
const bool with_position_encoding,
const bool pre_norm)
: _embeddings(model, scope + "/embeddings")
, _compute_type(model.effective_compute_type())
, _position_encoder(with_position_encoding
? build_position_encoder(model, scope + "/position_encodings", _embeddings)
: nullptr)
, _output_norm(model, scope + "/layer_norm") {
, _output_norm(pre_norm
? std::make_unique<LayerNorm>(model, scope + "/layer_norm")
: nullptr) {
for (size_t l = 0;; ++l) {
const std::string layer_scope = scope + "/layer_" + std::to_string(l);
try {
auto layer = std::make_unique<TransformerEncoderLayer>(model,
layer_scope,
num_heads);
num_heads,
pre_norm);
_layers.emplace_back(std::move(layer));
} catch (std::exception&) {
if (l == 0)
@@ -143,7 +161,8 @@ namespace ctranslate2 {
if (l + 1 < _layers.size())
input = std::move(output);
}
_output_norm(output, output);
if (_output_norm)
(*_output_norm)(output, output);
if (padder)
padder->add_padding(output);
}
@@ -153,23 +172,27 @@ namespace ctranslate2 {
const std::string& scope,
const size_t num_heads,
const bool with_position_encoding,
const bool with_encoder_attention)
const bool with_encoder_attention,
const bool pre_norm)
: Decoder(model.device())
, _with_encoder_attention(with_encoder_attention)
, _compute_type(model.effective_compute_type())
, _embeddings(model, scope + "/embeddings")
, _position_encoder(with_position_encoding
? build_position_encoder(model, scope + "/position_encodings", _embeddings)
: nullptr)
, _output_norm(model, scope + "/layer_norm")
, _output_norm(pre_norm
? std::make_unique<LayerNorm>(model, scope + "/layer_norm")
: nullptr)
, _proj(model, scope + "/projection") {
for (size_t l = 0;; ++l) {
const std::string layer_scope = scope + "/layer_" + std::to_string(l);
try {
auto layer = std::make_unique<TransformerDecoderLayer>(model,
layer_scope,
num_heads,
with_encoder_attention);
with_encoder_attention,
pre_norm);
_layers.emplace_back(std::move(layer));
} catch (std::exception&) {
if (l == 0)
@@ -256,7 +279,8 @@ namespace ctranslate2 {
}

if (logits) {
_output_norm(layer_in, layer_in);
if (_output_norm)
(*_output_norm)(layer_in, layer_in);
_proj(layer_in, *logits);
}
}
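
At the stack level, the final LayerNorm is now held through a unique_ptr and only constructed for pre-norm models; post-norm encoders and decoders skip it because each layer already normalizes its own output. A short sketch of that logic, under the same placeholder assumptions as the earlier pseudocode:

def run_stack(x, layers, output_norm=None):
    for layer in layers:
        x = layer(x)               # each layer applies its norms per pre_norm
    if output_norm is not None:    # created only when pre_norm is true
        x = output_norm(x)
    return x
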
