In [23]:
import torch
from torch import nn

from utils.useful_func import *

# Bahdanau attention: only the decoder needs to be modified; the seq2seq encoder is reused unchanged.

In [4]:
class AttentionDecoder(Decoder):
    """Base interface for decoders that use an attention mechanism.

    Subclasses must override the ``attention_weights`` property so callers can
    inspect the attention weights recorded during decoding.
    """
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        """Attention weights recorded by the decoder (hook for subclasses).

        Raises:
            NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError

    @property
    def attention_weight(self):
        """Backward-compatible singular alias of ``attention_weights``.

        The original interface used this singular name while subclasses
        override the plural one; delegating keeps both names working.
        """
        return self.attention_weights

In [5]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """Seq2seq GRU decoder with Bahdanau (additive) attention.

    It decodes (translates) using the encoding of the source sequence passed in
    from the encoder. At every time step, the top-layer GRU hidden state is the
    attention query over the encoder outputs. The resulting context vector is
    concatenated with the embedded input token and fed to the GRU.

    Args:
        vocab_size: target vocabulary size (embedding rows and output logits).
        embed_size: dimension of the token embeddings.
        num_hiddens: hidden size of the GRU; also passed to the attention module.
        num_layers: number of stacked GRU layers.
        dropout: dropout rate passed to both the attention module and the GRU.
    """
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        # NOTE(review): assumes AdditiveAttention(key_size, query_size, num_hiddens, dropout)
        # from utils.useful_func -- confirm the positional order there.
        self.attention = AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout=dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # GRU input = context vector (num_hiddens) concatenated with the token embedding (embed_size).
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Filled by forward(); created here so `attention_weights` is readable before the first call.
        self._attention_weights = []

    def init_state(self, encoder_out, enc_valid_lens, *args):
        """Build the decoder state from the encoder's ``(outputs, hidden_state)`` pair.

        The encoder outputs are permuted on axes 0 and 1. This presumably turns a
        time-major (num_steps, batch_size, num_hiddens) tensor into
        (batch_size, num_steps, num_hiddens) so it can serve as attention
        keys/values. TODO confirm against the encoder.

        Returns:
            ``(enc_outputs, hidden_state, enc_valid_lens)``
        """
        outputs, hidden_state = encoder_out
        return outputs.permute(1, 0, 2), hidden_state, enc_valid_lens

    def forward(self, X, state):
        """Decode all time steps of ``X`` with attention.

        Args:
            X: target token ids, indexed as (batch_size, num_steps) (the
                embedding output is permuted on axes 0 and 1 to iterate over time).
            state: ``(enc_outputs, hidden_state, enc_valid_lens)`` as produced by
                ``init_state`` or by a previous call of ``forward``.

        Returns:
            ``(logits, new_state)``. ``logits`` has shape
            (batch_size, num_steps, vocab_size). ``new_state`` is
            ``[enc_outputs, hidden_state, enc_valid_lens]``.
        """
        # enc_outputs: (batch_size, num_steps, num_hiddens) after init_state's permute
        enc_outputs, hidden_state, enc_valid_lens = state
        # (batch_size, num_steps, embed_size) -> (num_steps, batch_size, embed_size) to loop over time.
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # unsqueeze(1) adds an explicit "number of queries" axis (1 query per step).
            # This gives the (batch_size, num_queries, feature_size) layout the
            # attention module expects: query is (batch_size, 1, num_hiddens).
            query = hidden_state[-1].unsqueeze(1)
            # Keys and values are both the encoder outputs. enc_valid_lens presumably
            # masks padded source positions inside the attention module.
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate on the feature axis: (batch_size, 1, num_hiddens + embed_size).
            x = torch.cat((context, x.unsqueeze(1)), dim=-1)
            # nn.GRU is time-major here, so feed (1, batch_size, num_hiddens + embed_size).
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # Only after ALL steps: (num_steps, batch_size, num_hiddens) -> (num_steps, batch_size, vocab_size).
        outputs = self.dense(torch.cat(outputs, dim=0))
        # Return the decoder logits (not the encoder outputs) in batch-major layout.
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """Attention weights recorded at each step of the most recent ``forward`` call."""
        return self._attention_weights

In [73]:
# Batched matrix product: (b, n, m) @ (b, m, p) -> (b, n, p); here (1, 2, 3) @ (1, 3, 2) -> (1, 2, 2).
torch.randn(1, 2, 3).bmm(torch.randn(1, 3, 2))

tensor([[[-2.5753, -1.5549],
         [-2.1453, -0.8174]]])

In [72]:
# size() returns the same torch.Size object as the .shape attribute.
torch.randn(1, 2, 3).size()

torch.Size([1, 2, 3])

In [61]:
# Single-layer GRU mapping 5 input features to 10 hidden units (batch_first defaults to False).
gru = nn.GRU(input_size=5, hidden_size=10)

In [62]:
# Random input of shape (20, 10, 5). With batch_first=False the GRU reads it as (seq_len, batch, input_size).
x = torch.randn((20, 10, 5))

In [68]:
# out: per-step top-layer outputs, (seq_len, batch, hidden_size); hidden: final state per layer.
out, hidden = gru(x)
out.size()


torch.Size([20, 10, 10])