In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dplearning_second_part.limu_dplearning.utils.useful_func import masked_softmax

In [60]:
# Two raw logits scored against all-ones float targets.
logit=torch.tensor([0.3,0.4])
target=torch.tensor([1,1],dtype=torch.float)
# Binary cross-entropy computed directly from the logits (no explicit sigmoid call here).
F.binary_cross_entropy_with_logits(logit,target)

In [61]:
# Same call as in the previous cell; relies on `logit` and `target` defined there.
F.binary_cross_entropy_with_logits(logit,target)

tensor(0.5337)

In [75]:
# Element-wise sigmoid of a 2x2 tensor.
F.sigmoid(torch.tensor([[0.3,0.4],[1,2]]))

tensor([[0.5744, 0.5987],
        [0.7311, 0.8808]])

In [3]:

# Positional encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    Args:
        num_hiddens: size of the last (feature) dimension of the inputs.
            Both even and odd values are supported.
        dropout: dropout probability applied to the sum of input and encoding.
        max_len: number of positions that are precomputed; inputs may have at
            most this many steps along dimension 1.
    """
    def __init__(self, num_hiddens, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # A single table of shape (1, max_len, num_hiddens). The leading 1
        # broadcasts over the batch, so every sample shares the same encoding.
        P = torch.zeros(1, max_len, num_hiddens)
        # X[pos, j] = pos / 10000^(2j / num_hiddens), one column per even feature index.
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / \
            torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        P[:, :, 0::2] = torch.sin(X)
        # With an odd num_hiddens there is one odd-indexed column fewer than
        # even-indexed ones, so drop the surplus column before assigning.
        P[:, :, 1::2] = torch.cos(X[:, :num_hiddens // 2])
        # Non-persistent buffer: follows module.to(device) but is not written
        # to the state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Return dropout(X + P[:, :steps]) for X of shape (batch, steps, num_hiddens)."""
        # .to() is a no-op when the buffer already lives on X's device.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)


# Scaled dot-product attention
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with masking and dropout on the weights.

    After each forward pass the (pre-dropout) attention weights are kept in
    ``self.attention_weights``.

    Args:
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, dropout=0.2):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None

    # q: (batch, steps, embed_size)
    # k: (batch, number of key-value pairs, embed_size)
    # v: (batch, number of key-value pairs, embed_size)
    def forward(self, q, k, v, valid_lens=None):
        """Return the attention-weighted sum of ``v``.

        ``valid_lens`` is handed unchanged to ``masked_softmax``. The default
        of ``None`` mirrors ``MultiHeadAttention.forward``, which already
        passes ``None`` through; presumably the helper treats it as
        "no masking" -- confirm against ``masked_softmax``.
        """
        # Scale by sqrt(d) using a plain Python float: no per-call tensor
        # allocation and no sqrt of an integer tensor.
        scale = q.shape[-1] ** 0.5
        scores = torch.bmm(q, k.transpose(1, 2)) / scale
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), v)

#@save
class MultiHeadAttention(nn.Module):
    """Multi-head attention computed for all heads in parallel.

    Queries, keys and values are each projected to ``num_hiddens`` features,
    split into ``num_heads`` equally sized heads that are folded into the
    batch dimension, passed through one ``DotProductAttention`` call, merged
    back and projected by ``W_o``.

    Args:
        key_size, query_size, value_size: feature sizes of the raw k / q / v inputs.
        num_hiddens: projected feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability handed to ``DotProductAttention``.
        bias: whether the four linear projections use a bias term.
        **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """
    def __init__(self,key_size,query_size,value_size,num_hiddens,num_heads,dropout,bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Fail at construction time rather than with an opaque reshape error
        # in forward().
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=bias)
        self.W_v=nn.Linear(value_size,num_hiddens,bias=bias)
        self.W_o=nn.Linear(num_hiddens,num_hiddens,bias=bias)

    def _split_heads(self, X):
        """(batch, steps, num_hiddens) -> (batch * num_heads, steps, num_hiddens / num_heads)."""
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def _merge_heads(self, X):
        """Inverse of ``_split_heads``: (batch * num_heads, steps, d) -> (batch, steps, num_heads * d)."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2]).permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, q, k, v,valid_lens=None):
        """Return the attended values, shape (batch, query steps, num_hiddens).

        ``valid_lens`` is repeated ``num_heads`` times along axis 0 so that it
        lines up with the head-folded batch, then passed to the inner
        attention; ``None`` is passed through unchanged.
        """
        # Project, then fold the heads into the batch dimension so a single
        # batched attention call handles every head.
        queries = self._split_heads(self.W_q(q))
        keys = self._split_heads(self.W_k(k))
        values = self._split_heads(self.W_v(v))

        if valid_lens is not None:
            # Along axis 0, copy the first item (scalar or vector) num_heads
            # times, then the second item, and so on.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        head_outputs = self.attention(queries, keys, values, valid_lens)
        # Concatenate the heads back along the feature axis before the output projection.
        return self.W_o(self._merge_heads(head_outputs))


In [4]:
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Both linear layers act on the last dimension of the input only.

    Args:
        ffn_num_input: feature size of the input.
        ffn_num_hiddens: feature size of the hidden layer.
        ffn_num_outputs: feature size of the output.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """
    def __init__(self,ffn_num_input,ffn_num_hiddens,ffn_num_outputs,**kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(in_features=ffn_num_input, out_features=ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(in_features=ffn_num_hiddens, out_features=ffn_num_outputs)

    def forward(self,x):
        """Apply dense1, ReLU and dense2 in sequence and return the result."""
        hidden = self.dense1(x)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [5]:
# Shape check: the FFN only changes the last dimension (4 -> 8).
ffn=PositionWiseFFN(ffn_num_input=4,ffn_num_hiddens=4,ffn_num_outputs=8)
ffn.eval()
ffn(torch.ones((2,3,4))).shape

torch.Size([2, 3, 8])

In [6]:
# 残差连接和层规范化

In [7]:
# A LayerNorm with normalized_shape 2 and a BatchNorm1d with 2 features.
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)

In [8]:
# Small 2x2 float tensor; display it together with its shape.
x=torch.tensor([[1,2],[2,3]],dtype=torch.float32)
x,x.shape

(tensor([[1., 2.],
         [2., 3.]]),
 torch.Size([2, 2]))

In [34]:
# Residual connection followed by layer normalization
class AddNorm(nn.Module):
    """Compute ``LayerNorm(dropout(Y) + X)``.

    Args:
        normalized_shape: shape handed to ``nn.LayerNorm``.
        dropout: dropout probability applied to ``Y`` before the addition.
        **kwargs: accepted for signature compatibility; not used.
    """
    def __init__(self,normalized_shape,dropout,**kwargs):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self,X,Y):
        """Add the dropped-out ``Y`` to the residual input ``X`` and normalize the sum."""
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)

In [35]:
# AddNorm normalizing over the trailing (3, 4) dimensions, with dropout p=0.5;
# switched to eval mode before use.
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()

AddNorm(
  (dropout): Dropout(p=0.5, inplace=False)
  (ln): LayerNorm((3, 4), eps=1e-05, elementwise_affine=True)
)

In [36]:
# The output has the same shape as the two inputs.
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape

torch.Size([2, 3, 4])

In [37]:
class EncoderBlock(nn.Module):
    """One Transformer encoder block: self-attention and an FFN, each wrapped in AddNorm.

    Args:
        key_size, query_size, value_size: input feature sizes for the attention layer.
        num_hiddens: hidden size of the attention layer and output size of the FFN.
        norm_shape: shape handed to both ``AddNorm`` layers.
        ffn_num_input, ffn_num_hiddens: input and hidden sizes of the FFN.
        num_heads: number of attention heads.
        dropout: dropout probability for the attention and both ``AddNorm`` layers.
        use_bias: whether the attention projections use a bias term.
        **kwargs: forwarded to ``PositionWiseFFN``.
    """
    def __init__(self,key_size,query_size,value_size,num_hiddens,norm_shape
                 ,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias=False,**kwargs):
        super().__init__()
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            bias=use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens, **kwargs)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self,X,valid_lens):
        """Run self-attention on ``X`` (as query, key and value), then the FFN."""
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

In [38]:
# Dummy input of shape (2, 100, 24); note this rebinds the notebook-level name `x`.
x=torch.ones((2,100,24))
# One valid length per sequence in the batch; also reused by a later cell.
valid_lens=torch.tensor([3,2])
encoder_blk=EncoderBlock(key_size=24,query_size=24,value_size=24,num_hiddens=24,norm_shape=[100,24],ffn_num_input=24,ffn_num_hiddens=48,num_heads=8,dropout=0.5)
encoder_blk.eval()

EncoderBlock(
  (attention): MultiHeadAttention(
    (attention): DotProductAttention(
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (W_q): Linear(in_features=24, out_features=24, bias=False)
    (W_k): Linear(in_features=24, out_features=24, bias=False)
    (W_v): Linear(in_features=24, out_features=24, bias=False)
    (W_o): Linear(in_features=24, out_features=24, bias=False)
  )
  (addnorm1): AddNorm(
    (dropout): Dropout(p=0.5, inplace=False)
    (ln): LayerNorm((100, 24), eps=1e-05, elementwise_affine=True)
  )
  (ffn): PositionWiseFFN(
    (dense1): Linear(in_features=24, out_features=48, bias=True)
    (relu): ReLU()
    (dense2): Linear(in_features=48, out_features=24, bias=True)
  )
  (addnorm2): AddNorm(
    (dropout): Dropout(p=0.5, inplace=False)
    (ln): LayerNorm((100, 24), eps=1e-05, elementwise_affine=True)
  )
)

In [39]:
# The encoder block returns a tensor with the same shape as its input.
encoder_blk(x,valid_lens).shape

torch.Size([2, 100, 24])

In [54]:
# Attention weights stored by the inner DotProductAttention during its last forward pass.
encoder_blk.attention.attention.attention_weights

tensor([[[0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.

In [56]:
class TransformerEncoder(nn.Module):
    """Transformer encoder: token embedding, positional encoding and a stack of encoder blocks.

    Args:
        vocab_size: number of rows in the embedding table.
        key_size, query_size, value_size, num_hiddens, norm_shape,
        ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias:
            handed to every ``EncoderBlock``; ``num_hiddens`` and ``dropout``
            also configure the embedding and the positional encoding.
        num_layers: number of encoder blocks, registered as ``block0``, ``block1``, ...
        **kwargs: accepted for signature compatibility; not used.
    """
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(
                key_size, query_size, value_size, num_hiddens,
                norm_shape, ffn_num_input, ffn_num_hiddens,
                num_heads, dropout, use_bias)
            self.blks.add_module(f'block{layer_idx}', block)

    def forward(self,X,valid_lens,*args):
        """Encode token indices ``X`` and return the output of the last block.

        As a side effect, ``self.attention_weights`` is rebuilt as a list
        holding each block's attention weights from this pass.
        """
        # Scale the embeddings by sqrt(num_hiddens) before adding the
        # positional encoding.
        scale = torch.sqrt(torch.tensor(self.num_hiddens))
        X = self.pos_encoding(self.embedding(X) * scale)
        self.attention_weights = [None] * len(self.blks)
        for idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            weights = blk.attention.attention.attention_weights
            self.attention_weights[idx] = weights
        return X

In [57]:
# Two-layer encoder over a vocabulary of 200 tokens, with 24 hidden units and 8 heads.
encoder = TransformerEncoder(
    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
# Reuses `valid_lens` defined in an earlier cell.
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

In [None]:
class DecoderBlock(nn.Module):
    """解码器中的第i个块"""
    def __init__(self,key_size,query_size,value_size,num_hiddens,norm_shape
                 ,ffn_num_input,ffn_num_hiddens,num_heads,dropout,i,**kwargs):
        super(DecoderBlock,self).__init__()
        self.i=i
        self.attention1=MultiHeadAttention(key_size,query_size,value_size,num_hiddens,num_heads,dropout)
        self.addnorm1=AddNorm(norm_shape, dropout)
        self.attention2=MultiHeadAttention(key_size,query_size,value_size,num_hiddens,num_heads,dropout)
        self.addnorm2=AddNorm(norm_shape,dropout)
        self.ffn=PositionWiseFFN(ffn_num_input,ffn_num_hiddens,num_hiddens,**kwargs)
        self.addnorm3=AddNorm(norm_shape, dropout)
    def forward(self,X,state):
        