## Bahdanau 注意力

下图展示了一个带有Bahdanau注意力的循环神经网络编码器-解码器模型。

![seq2seq-attention-details.svg](https://zh-v2.d2l.ai/_images/seq2seq-attention-details.svg)

In [None]:
import sys
sys.path.append('..')
import torch
from torch import nn
import d2l

In [None]:
class AttentionDecoder(d2l.Decoder):
  """Base interface for decoders equipped with an attention mechanism.

  Concrete subclasses are expected to override ``attention_weights``;
  this base version only raises ``NotImplementedError``.
  """

  def __init__(self):
    # Holds no state of its own; initialization is delegated to d2l.Decoder.
    super(AttentionDecoder, self).__init__()

  @property
  def attention_weights(self):
    """Attention weights exposed by a concrete decoder (abstract here)."""
    raise NotImplementedError

In [None]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
    """Create the sub-modules of the Bahdanau-attention GRU decoder.

    Args:
      vocab_size: number of tokens in the target vocabulary; sets both the
        embedding table size and the output projection width.
      embed_size: dimensionality of the token embeddings.
      num_hiddens: hidden width shared by the attention module, the GRU
        and the input of the output projection.
      num_layers: number of stacked GRU layers.
      dropout: dropout probability given to both the attention module and
        the GRU.
    """
    super(Seq2SeqAttentionDecoder, self).__init__()
    # NOTE: keep the creation order (attention, embedding, rnn, dense).
    # It fixes parameter registration order and the sequence of RNG draws
    # used for weight initialization.
    # All three positional sizes are num_hiddens (argument meaning follows
    # d2l.AdditiveAttention's signature -- confirm against the local d2l).
    self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
    # The GRU input is embed_size + num_hiddens wide. Presumably this is the
    # token embedding concatenated with a num_hiddens-wide attention context;
    # the forward pass is not visible here, so verify.
    self.rnn = nn.GRU(
      input_size=embed_size + num_hiddens,
      hidden_size=num_hiddens,
      num_layers=num_layers,
      dropout=dropout,
    )
    self.dense = nn.Linear(in_features=num_hiddens, out_features=vocab_size)
  
  def init_state(self, enc_outputs, enc_valid_lens, *args):
    """Build the initial decoder state from the encoder's results.

    Args:
      enc_outputs: pair ``(outputs, hidden_state)`` returned by the encoder.
        ``outputs`` is expected to be (num_steps, batch_size, num_hiddens);
        it is permuted below to (batch_size, num_steps, num_hiddens).
        ``hidden_state`` is (num_layers, batch_size, num_hiddens).
      enc_valid_lens: valid lengths of the source sequences, passed through
        unchanged (presumably used to mask padding in attention -- confirm).
      *args: ignored. Accepted so that callers which forward extra
        positional arguments, as the d2l ``Decoder.init_state(enc_outputs,
        *args)`` interface permits, keep working.

    Returns:
      Tuple ``(batch_first_outputs, hidden_state, enc_valid_lens)``.
    """
    outputs, hidden_state = enc_outputs
    # Swap the time and batch axes so the encoder outputs are batch-first.
    return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
  
  def forward(self, X, state):
    # enc_outputs的形状为(batch_size,num_steps,num_hiddens).
    # hidden_state的形状为(num_layers,batch_size,
    # num_hiddens)
