In [3]:
import math
import torch
from torch import nn
from d2l import torch as d2l
def transpose_qkv(X, num_heads):
    """Split the last axis of ``X`` into heads and fold the heads into axis 0.

    Args:
        X: tensor laid out as (batch_size, num_tokens, embed_dim).
        num_heads: number of attention heads; the last axis is split into
            ``num_heads`` equally sized pieces.

    Returns:
        Tensor of shape (batch_size * num_heads, num_tokens,
        embed_dim / num_heads).
    """
    batch_size, num_tokens = X.shape[0], X.shape[1]
    # (batch_size, num_tokens, num_heads, embed_dim / num_heads)
    per_head = X.reshape(batch_size, num_tokens, num_heads, -1)
    head_dim = per_head.shape[-1]
    # (batch_size, num_heads, num_tokens, embed_dim / num_heads)
    heads_first = per_head.permute(0, 2, 1, 3)
    # Merge batch and head axes so every head is handled as its own sample.
    return heads_first.reshape(-1, num_tokens, head_dim)
def transpose_output(X, num_heads):
    """Undo the head split: merge per-head outputs back into one embedding.

    Args:
        X: tensor of shape (batch_size * num_heads, num_tokens,
            embed_dim / num_heads).
        num_heads: number of attention heads folded into axis 0 of ``X``.

    Returns:
        Tensor of shape (batch_size, num_tokens, embed_dim).
    """
    # (batch_size, num_heads, num_tokens, embed_dim / num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # permute() returns a new tensor rather than modifying X in place, so the
    # result has to be assigned; otherwise the head axis is never moved back.
    # (batch_size, num_tokens, num_heads, embed_dim / num_heads)
    X = X.permute(0, 2, 1, 3)
    # Concatenate the heads along the last axis:
    # (batch_size, num_tokens, embed_dim)
    return X.reshape(X.shape[0], X.shape[1], -1)

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on scaled dot-product attention.

    Queries, keys and values are each projected to ``num_hiddens`` features,
    split into ``num_head`` heads, attended per head, and the concatenated
    head outputs are projected once more by ``W_o``.

    Args:
        key_size: feature size of the incoming keys.
        query_size: feature size of the incoming queries.
        value_size: feature size of the incoming values.
        num_hiddens: output size of every projection; split across the heads.
        num_head: number of attention heads.
        dropout: dropout rate handed to ``d2l.DotProductAttention``.
        bias: whether the linear projections use a bias term.
    """

    def __init__(self, key_size,
                 query_size,
                 value_size,
                 num_hiddens,
                 num_head,
                 dropout,
                 bias=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_head = num_head
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries: tensor whose last axis has ``query_size`` features.
            keys: tensor whose last axis has ``key_size`` features.
            values: tensor whose last axis has ``value_size`` features.
            valid_lens: ``None`` or a tensor of valid lengths whose first
                axis is the batch axis.

        Returns:
            The output of ``W_o`` applied to the concatenated head outputs.
        """
        # Each input must go through its own projection: W_q for queries,
        # W_k for keys and W_v for values.
        queries = transpose_qkv(self.W_q(queries), self.num_head)
        keys = transpose_qkv(self.W_k(keys), self.num_head)
        values = transpose_qkv(self.W_v(values), self.num_head)

        if valid_lens is not None:
            # transpose_qkv folded the heads into the batch axis, so every
            # batch entry's valid length is repeated once per head.
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_head,
                                                 dim=0)
        output = self.attention(queries, keys, values, valid_lens)

        # Bring the heads back out of the batch axis and concatenate them.
        output_concat = transpose_output(output, self.num_head)

        return self.W_o(output_concat)


In [5]:
import pandas as pd
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied along the last axis.

    Args:
        ffn_num_input: feature size of the input.
        ffn_num_hiddens: feature size of the hidden layer.
        ffn_num_output: feature size of the output.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_output,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dense1 = nn.Linear(in_features=ffn_num_input,
                                out_features=ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(in_features=ffn_num_hiddens,
                                out_features=ffn_num_output)

    def forward(self, X):
        """Return ``dense2(relu(dense1(X)))``."""
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [6]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Args:
        normalized_shape: shape passed to ``nn.LayerNorm``.
        droupt: dropout probability applied to the sublayer output before it
            is added to the residual input. The misspelled name is kept
            because callers pass it by keyword.
    """

    def __init__(self, normalized_shape, droupt, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.dropout = nn.Dropout(droupt)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return ``LayerNorm(dropout(Y) + X)``.

        Args:
            X: residual (sublayer input) tensor.
            Y: sublayer output tensor, same shape as ``X``.
        """
        # LayerNorm.forward takes a single tensor, so the residual has to be
        # added to the dropped-out sublayer output before normalizing.
        return self.ln(self.dropout(Y) + X)


In [7]:
class EncoderBlock(nn.Module):
    """One Transformer encoder block.

    Runs multi-head self-attention and a position-wise feed-forward network,
    each wrapped in an ``AddNorm`` (residual + layer norm) step.

    Args:
        key_size: feature size of the keys fed to the attention layer.
        query_size: feature size of the queries fed to the attention layer.
        values_size: feature size of the values fed to the attention layer.
        num_hidden: hidden size of the attention projections and output size
            of the feed-forward network.
        norm_shape: shape passed to the layer normalizations.
        ffn_num_input: input feature size of the feed-forward network.
        ffn_num_hidden: hidden feature size of the feed-forward network.
        num_head: number of attention heads.
        dropout: dropout rate used by the attention and AddNorm layers.
        use_bias: whether the attention projections use a bias term.
    """

    def __init__(self, key_size,
                 query_size,
                 values_size,
                 num_hidden,
                 norm_shape,
                 ffn_num_input,
                 ffn_num_hidden,
                 num_head,
                 dropout,
                 use_bias=False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The fourth argument of MultiHeadAttention is the hidden size and the
        # fifth is the head count; passing num_head for both would shrink the
        # projections to num_head features.
        self.attention = MultiHeadAttention(key_size, query_size, values_size,
                                            num_hidden, num_head, dropout,
                                            use_bias)

        self.addnorml1 = AddNorm(normalized_shape=norm_shape, droupt=dropout)
        self.ffn = PositionWiseFFN(ffn_num_input=ffn_num_input,
                                   ffn_num_hiddens=ffn_num_hidden,
                                   ffn_num_output=num_hidden)
        self.addnorml2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_len):
        """Apply self-attention and the feed-forward network to ``X``.

        Args:
            X: input tensor; used as queries, keys and values.
            valid_len: valid lengths forwarded to the attention layer
                (may be ``None``).

        Returns:
            The output of the second AddNorm step.
        """
        Y = self.addnorml1(X, self.attention(X, X, X, valid_len))
        return self.addnorml2(Y, self.ffn(Y))
    



In [8]:
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder: embedding, positional encoding, encoder blocks.

    Args:
        vocab_size: number of rows in the embedding table.
        ket_size: key feature size for each block's attention. The misspelled
            name is kept for backward compatibility with keyword callers.
        query_size: query feature size for each block's attention.
        value_size: value feature size for each block's attention.
        num_hiddens: embedding size and hidden size of every block.
        norm_shape: shape passed to the layer normalizations.
        ffn_num_input: input feature size of each feed-forward network.
        ffn_num_hidden: hidden feature size of each feed-forward network.
        num_head: number of attention heads per block.
        num_layer: number of stacked encoder blocks.
        droupt: dropout rate (name kept for backward compatibility).
        use_bias: whether the attention projections use a bias term.
    """

    def __init__(self, vocab_size,
                 ket_size,
                 query_size,
                 value_size,
                 num_hiddens,
                 norm_shape,
                 ffn_num_input,
                 ffn_num_hidden,
                 num_head, num_layer, droupt, use_bias=False):
        super().__init__()

        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens=num_hiddens,
                                                   dropout=droupt)
        self.blks = nn.Sequential()
        for i in range(num_layer):
            # Register each block on self.blks (not on self); forward()
            # iterates over self.blks, which would otherwise stay empty.
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(
                    key_size=ket_size, query_size=query_size,
                    values_size=value_size,
                    num_hidden=num_hiddens,
                    norm_shape=norm_shape,
                    ffn_num_hidden=ffn_num_hidden, ffn_num_input=ffn_num_input,
                    num_head=num_head,
                    dropout=droupt, use_bias=use_bias
                )
            )

    def forward(self, X, valid_lens, *args):
        """Encode a batch of token indices.

        Args:
            X: tensor of token indices, looked up in the embedding table.
            valid_lens: valid lengths forwarded to every block's attention.
            *args: ignored extra positional arguments.

        Returns:
            The output of the last encoder block.
        """
        # Embeddings are scaled by sqrt(num_hiddens) before the positional
        # encoding is added.
        X = self.pos_encoding(
            self.embedding(X) * math.sqrt(self.num_hiddens)
        )
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # blk.attention is the MultiHeadAttention module; the weights are
            # stored on its inner dot-product attention module.
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

In [9]:
class DecoderBlock(nn.Module):
    """The ``i``-th Transformer decoder block.

    Runs self-attention over the decoder inputs, attention over the encoder
    outputs, and a position-wise feed-forward network, each followed by a
    ``d2l.AddNorm`` step.

    Args:
        key_size: key feature size for both attention layers.
        query_size: query feature size for both attention layers.
        value_size: value feature size for both attention layers.
        num_hidden: hidden size of the attention projections and output size
            of the feed-forward network.
        norm_shape: shape passed to the layer normalizations.
        ffn_num_input: input feature size of the feed-forward network.
        ffn_num_hidden: hidden feature size of the feed-forward network.
        num_head: number of attention heads.
        dropout: dropout rate used by the attention and AddNorm layers.
        i: index of this block; selects its slot in ``state[2]``.
    """

    def __init__(self, key_size,
                 query_size,
                 value_size,
                 num_hidden,
                 norm_shape,
                 ffn_num_input,
                 ffn_num_hidden,
                 num_head,
                 dropout,
                 i,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.i = i
        # The fourth argument of MultiHeadAttention is the hidden size and the
        # fifth is the head count.
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size,
                                             num_hidden, num_head, dropout)

        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)

        self.attention2 = MultiHeadAttention(key_size, query_size, value_size,
                                             num_hidden, num_head, dropout)
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hidden, num_hidden)

        self.addnorm3 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block on ``X`` and update the per-block cache in ``state``.

        Args:
            X: decoder input tensor of shape (batch_size, num_steps, features).
            state: list ``[enc_outputs, enc_valid_lens, cache]`` where
                ``cache[self.i]`` is ``None`` or the inputs this block has
                seen so far.

        Returns:
            Tuple ``(output, state)``; ``state[2][self.i]`` has been replaced
            by the concatenation of the cached inputs and ``X``.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]

        # Keys/values for self-attention are everything this block has seen:
        # the cached inputs (if any) followed by the current input.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_step, _ = X.shape
            # One row [1, 2, ..., num_step] per batch entry, so that position
            # t attends only to the first t key positions.
            dec_valid_lens = torch.arange(1, num_step + 1, device=X.device).repeat(
                batch_size, 1
            )
        else:
            dec_valid_lens = None

        # Self-attention over the decoder inputs.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Attention over the encoder outputs.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        z = self.addnorm2(Y, Y2)
        return self.addnorm3(z, self.ffn(z)), state

In [10]:
class TransformerDecoder(d2l.AttentionDecoder):
    """Transformer decoder: embedding, positional encoding, decoder blocks
    and a final linear layer onto the vocabulary.

    Args:
        vocab_size: number of rows in the embedding table and output size of
            the final linear layer.
        key_size: key feature size for each block's attention layers.
        query_size: query feature size for each block's attention layers.
        value_size: value feature size for each block's attention layers.
        num_hiddens: embedding size and hidden size of every block.
        norm_shape: shape passed to the layer normalizations.
        ffn_num_input: input feature size of each feed-forward network.
        ffn_num_hidden: hidden feature size of each feed-forward network.
        num_head: number of attention heads per block.
        num_layer: number of stacked decoder blocks.
        dropout: dropout rate.
    """

    def __init__(self,
                 vocab_size,
                 key_size,
                 query_size,
                 value_size,
                 num_hiddens,
                 norm_shape,
                 ffn_num_input,
                 ffn_num_hidden,
                 num_head,
                 num_layer,
                 dropout):
        super().__init__()

        self.num_hiddens = num_hiddens
        self.num_layer = num_layer
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layer):
            self.blks.add_module(
                "block" + str(i),
                DecoderBlock(
                    key_size=key_size,
                    query_size=query_size,
                    value_size=value_size,
                    num_hidden=num_hiddens,
                    norm_shape=norm_shape,
                    ffn_num_input=ffn_num_input,
                    ffn_num_hidden=ffn_num_hidden,
                    num_head=num_head,
                    dropout=dropout,
                    i=i
                )
            )
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the initial decoder state.

        Returns:
            ``[enc_outputs, enc_valid_lens, cache]`` where ``cache`` holds one
            ``None`` slot per decoder block. Each block indexes
            ``state[2][i]``, so the third element must be a list.
        """
        return [enc_outputs, enc_valid_lens, [None] * self.num_layer]

    def forward(self, X, state):
        """Decode a batch of token indices.

        Args:
            X: tensor of token indices, looked up in the embedding table.
            state: decoder state as produced by ``init_state``.

        Returns:
            Tuple ``(logits, state)`` where ``logits`` is the output of the
            final linear layer.
        """
        # Token indices must be embedded, and scaled by sqrt(num_hiddens),
        # before the positional encoding is added.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # Row 0: self-attention weights; row 1: encoder-decoder attention
        # weights; one entry per block.
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]

        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights

        return self.dense(X), state

    @property
    def attention_weights(self):
        """Attention weights recorded by the most recent ``forward`` call."""
        return self._attention_weights
    