[WIP] Pseudo self attention #2060

Draft · wants to merge 5 commits into base: master
12 changes: 12 additions & 0 deletions .github/workflows/push.yml
@@ -96,6 +96,18 @@ jobs:
-word_vec_size 5 -report_every 5 \
-coverage_attn true -lambda_coverage 0.1 \
-rnn_size 10 -train_steps 10
- name: Test Transformer training with pseudo self attention
run: |
python train.py \
-config data/align_data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.tgt \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-max_generator_batches 0 \
-encoder_type transformer -decoder_type transformer_lm_psa \
-layers 4 -word_vec_size 16 -rnn_size 16 -heads 2 -transformer_ff 64 \
-report_every 5 -train_steps 10
- name: Test Transformer training with align
run: |
python train.py \
37 changes: 28 additions & 9 deletions onmt/decoders/__init__.py
@@ -1,13 +1,32 @@
"""Module defining decoders."""
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \
StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
from onmt.decoders.cnn_decoder import CNNDecoder
from onmt.decoders.decoder import (
DecoderBase,
InputFeedRNNDecoder,
StdRNNDecoder,
)
from onmt.decoders.transformer import (
TransformerDecoder,
TransformerLMDecoder,
TransformerLMPseudoSelfAttentionDecoder,
)

str2dec = {
"rnn": StdRNNDecoder,
"ifrnn": InputFeedRNNDecoder,
"cnn": CNNDecoder,
"transformer": TransformerDecoder,
"transformer_lm": TransformerLMDecoder,
"transformer_lm_psa": TransformerLMPseudoSelfAttentionDecoder,
}

str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder,
"cnn": CNNDecoder, "transformer": TransformerDecoder,
"transformer_lm": TransformerLMDecoder}

__all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder",
"InputFeedRNNDecoder", "str2dec", "TransformerLMDecoder"]
__all__ = [
"DecoderBase",
"TransformerDecoder",
"StdRNNDecoder",
"CNNDecoder",
"InputFeedRNNDecoder",
"str2dec",
"TransformerLMDecoder",
"TransformerLMPseudoSelfAttentionDecoder",
]
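
For context, this registry is what ties the new -decoder_type transformer_lm_psa flag exercised in the CI step above to the decoder class. A minimal sketch of the lookup, assuming the usual OpenNMT-py construction path in which the decoder type is resolved through str2dec and the class's inherited from_opt (the helper name below is hypothetical):

# Illustrative only: how -decoder_type transformer_lm_psa is expected to resolve.
from onmt.decoders import str2dec

def build_decoder_sketch(opt, embeddings):
    # "transformer_lm_psa" maps to TransformerLMPseudoSelfAttentionDecoder
    dec_class = str2dec[opt.decoder_type]
    return dec_class.from_opt(opt, embeddings)
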
241 changes: 234 additions & 7 deletions onmt/decoders/transformer.py
@@ -8,6 +8,7 @@

from onmt.decoders.decoder import DecoderBase
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules import MultiHeadedPseudoSelfAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
@@ -68,10 +69,16 @@ def __init__(
self.self_attn = AverageAttention(
d_model, dropout=attention_dropout, aan_useffn=aan_useffn
)

self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout,
pos_ffn_activation_fn
)
elif self_attn_type == "pseudo-self":
self.self_attn = MultiHeadedPseudoSelfAttention(
heads,
d_model,
dropout=attention_dropout,
max_relative_positions=max_relative_positions,
)
self.feed_forward = PositionwiseFeedForward(
d_model, d_ff, dropout, pos_ffn_activation_fn
)
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
self.drop = nn.Dropout(dropout)
self.full_context_alignment = full_context_alignment
@@ -120,7 +127,8 @@ def update_dropout(self, dropout, attention_dropout):
def _forward(self, *args, **kwargs):
raise NotImplementedError

def _compute_dec_mask(self, tgt_pad_mask, future):
@staticmethod
def _compute_dec_mask(tgt_pad_mask, future):
tgt_len = tgt_pad_mask.size(-1)
if not future: # apply future_mask, result mask in (B, T, T)
future_mask = torch.ones(
@@ -253,7 +261,7 @@ def _forward(
"""
dec_mask = None

if inputs.size(1) > 1:
if step is None:
# masking is necessary when sequence length is greater than one
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

@@ -548,7 +556,7 @@ def _forward(
"""
dec_mask = None

if inputs.size(1) > 1:
if step is None or inputs.size(1) > 1:
# masking is necessary when sequence length is greater than one
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

@@ -693,3 +701,222 @@ def _init_cache(self, memory_bank=None):
if isinstance(layer.self_attn, AverageAttention):
raise NotImplementedError
self.state["cache"]["layer_{}".format(i)] = layer_cache


class TransformerLMPseudoSelfAttentionDecoderLayer(
TransformerDecoderLayerBase
):
"""Transformer Decoder only layer block in GPT style.

.. mermaid::

graph LR
%% "*SubLayer" can be self-attn, src-attn or feed forward block
A(input) --> B[Norm]
B --> C["*SubLayer"]
C --> D[Drop]
D --> E((+))
A --> E
E --> F(out)


Args:
See TransformerDecoderLayerBase
"""

def _forward(
self,
inputs,
src_memory_bank,
src_pad_mask,
tgt_pad_mask,
layer_cache=None,
step=None,
future=False,
):
"""A naive forward pass for transformer decoder.

# T: 1 during stepwise decoding, otherwise tgt_len

Args:
inputs (FloatTensor): ``(batch_size, T, model_dim)``
src_memory_bank (FloatTensor): ``(src_len, batch_size, model_dim)``
src_pad_mask (bool): ``(batch_size, 1, src_len)``
tgt_pad_mask (bool): ``(batch_size, 1, T)``
layer_cache (dict or None): cached layer info when stepwise decode
step (int or None): stepwise decoding counter
future (bool): If set True, do not apply future_mask.

Returns:
(FloatTensor, FloatTensor):

* output ``(batch_size, T, model_dim)``
* attns ``(batch_size, head, T, T)``

"""
dec_mask = None
pseudo_mask = None
if step is None:
# not decoding stepwise: build the full padding + future mask
dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
pseudo_mask = torch.cat(
[src_pad_mask.repeat(1, inputs.size(1), 1), dec_mask], dim=-1
)
else:
pseudo_mask = torch.cat(
(
src_pad_mask.repeat(1, inputs.size(1), 1),
torch.zeros(
(inputs.size(0), inputs.size(1), step + 1),
dtype=torch.bool,
device=src_pad_mask.device,
),
),
dim=-1,
)
inputs_norm = self.layer_norm_1(inputs)

query, attns = self.self_attn(
src_memory_bank.transpose(0, 1),
inputs_norm,
mask=pseudo_mask,
layer_cache=layer_cache,
attn_type="self",
)

output = self.drop(query) + inputs

output_feedforward = self.feed_forward(output)

return output_feedforward, attns


class TransformerLMPseudoSelfAttentionDecoder(TransformerDecoderBase):
"""The Transformer decoder from GPT-2 with pseudo self attention

.. mermaid::

graph BT
A[input]
B[multi-head self-attn]
C[feed forward]
O[output]
A --> B
B --> C
C --> O


Args:
num_layers (int): number of decoder layers.
d_model (int): size of the model
heads (int): number of heads
d_ff (int): size of the inner FF layer
copy_attn (bool): if using a separate copy attention
self_attn_type (str): type of self-attention scaled-dot, average
dropout (float): dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float): dropout in context_attn (and self-attn(avg))
embeddings (onmt.modules.Embeddings):
embeddings to use, should have positional encodings
max_relative_positions (int):
Max distance between inputs in relative positions representations
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
"""

def __init__(
self,
num_layers,
d_model,
heads,
d_ff,
copy_attn,
self_attn_type,
dropout,
attention_dropout,
embeddings,
max_relative_positions,
aan_useffn,
full_context_alignment=None,
alignment_layer=None,
alignment_heads=None,
pos_ffn_activation_fn=ActivationFunction.relu,
):
super(TransformerLMPseudoSelfAttentionDecoder, self).__init__(
d_model, copy_attn, embeddings, None
)
self.transformer_layers = nn.ModuleList(
[
TransformerLMPseudoSelfAttentionDecoderLayer(
d_model,
heads,
d_ff,
dropout,
attention_dropout,
self_attn_type="pseudo-self",
max_relative_positions=max_relative_positions,
aan_useffn=aan_useffn,
full_context_alignment=None,
alignment_heads=None,
pos_ffn_activation_fn=pos_ffn_activation_fn,
)
for i in range(num_layers)
]
)

def detach_state(self):
pass

def forward(self, tgt, memory_bank=None, step=None, **kwargs):
"""Decode, possibly stepwise."""
if step == 0:
self._init_cache()

tgt_words = tgt[:, :, 0].transpose(0, 1)

emb = self.embeddings(tgt, step=step)
assert emb.dim() == 3 # len x batch x embedding_dim

output = emb.transpose(0, 1).contiguous()

pad_idx = self.embeddings.word_padding_idx
src_lens = kwargs["memory_lengths"]
src_max_len = self.state["src"].shape[0]
src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]

with_align = kwargs.pop("with_align", False)
assert not with_align, "TransformerLMPseudoSelfAttentionDecoder does not support align"

for i, layer in enumerate(self.transformer_layers):
layer_cache = (
self.state["cache"]["layer_{}".format(i)]
if step is not None
else None
)
output, attn, _ = layer(
output,
memory_bank,
src_pad_mask,
tgt_pad_mask,
layer_cache=layer_cache,
step=step,
with_align=with_align,
)

output = self.layer_norm(output)
dec_outs = output.transpose(0, 1).contiguous()
attn = attn.transpose(0, 1).contiguous()

attns = {"std": attn}
if self._copy:
attns["copy"] = attn

# TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns

def _init_cache(self, memory_bank=None):
self.state["cache"] = {}

for i, layer in enumerate(self.transformer_layers):
layer_cache = {"self_keys": None, "self_values": None,
"src_keys": None, "src_values": None}
if isinstance(layer.self_attn, AverageAttention):
raise NotImplementedError
self.state["cache"]["layer_{}".format(i)] = layer_cache
36 changes: 27 additions & 9 deletions onmt/modules/__init__.py
@@ -3,16 +3,34 @@
from onmt.modules.gate import context_gate_factory, ContextGate
from onmt.modules.global_attention import GlobalAttention
from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention
from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \
CopyGeneratorLossCompute, CopyGeneratorLMLossCompute
from onmt.modules.multi_headed_attn import MultiHeadedAttention
from onmt.modules.copy_generator import (
CopyGenerator,
CopyGeneratorLoss,
CopyGeneratorLossCompute,
CopyGeneratorLMLossCompute,
)
from onmt.modules.multi_headed_attn import (
MultiHeadedAttention,
MultiHeadedPseudoSelfAttention,
)
from onmt.modules.embeddings import Embeddings, PositionalEncoding
from onmt.modules.weight_norm import WeightNormConv2d
from onmt.modules.average_attn import AverageAttention

__all__ = ["Elementwise", "context_gate_factory", "ContextGate",
"GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
"CopyGeneratorLoss", "CopyGeneratorLossCompute",
"MultiHeadedAttention", "Embeddings", "PositionalEncoding",
"WeightNormConv2d", "AverageAttention",
"CopyGeneratorLMLossCompute"]
__all__ = [
"Elementwise",
"context_gate_factory",
"ContextGate",
"GlobalAttention",
"ConvMultiStepAttention",
"CopyGenerator",
"CopyGeneratorLoss",
"CopyGeneratorLossCompute",
"MultiHeadedAttention",
"Embeddings",
"PositionalEncoding",
"WeightNormConv2d",
"AverageAttention",
"CopyGeneratorLMLossCompute",
"MultiHeadedPseudoSelfAttention",
]
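
A toy-shaped call of the newly exported module, mirroring how the decoder layer above invokes it; the constructor and forward signatures are inferred from the diff rather than a verified API, so treat every name and argument here as an assumption:

# Hypothetical smoke test; shapes mirror the layer's call pattern above.
import torch
from onmt.modules import MultiHeadedPseudoSelfAttention

heads, d_model, batch, src_len, tgt_len = 2, 16, 3, 7, 5
attn = MultiHeadedPseudoSelfAttention(
    heads, d_model, dropout=0.1, max_relative_positions=0
)
src = torch.randn(batch, src_len, d_model)  # encoder states, batch-first
tgt = torch.randn(batch, tgt_len, d_model)  # decoder-side inputs
# mask over [source; target] positions, True = do not attend
mask = torch.zeros(batch, tgt_len, src_len + tgt_len, dtype=torch.bool)
out, attns = attn(src, tgt, mask=mask, attn_type="self")
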