In [1]:
import torch
from torch import nn

In [2]:
import math
class DotProductAttention(nn.Module):
    """Scaled dot-product attention over batched queries, keys and values.

    The raw scores ``q @ k^T / sqrt(d)`` are normalised with a softmax over
    the key axis before they are used to average the values. The weights of
    the most recent call are kept in ``self.attention_weights`` so that
    callers can inspect them after a forward pass.
    """
    def __init__(self) -> None:
        super(DotProductAttention,self).__init__()
        # Filled in by forward(); stays None until the first call.
        self.attention_weights=None

    def forward(self,q,k,v):
        """Return the attention-weighted combination of ``v``.

        Args:
            q: tensor of shape (batch, num_queries, d).
            k: tensor of shape (batch, num_keys, d).
            v: tensor of shape (batch, num_keys, value_dim).

        Returns:
            Tensor of shape (batch, num_queries, value_dim).
        """
        # Scale by sqrt(d) so the score magnitude does not grow with d.
        scores=torch.bmm(q,k.transpose(1,2))/math.sqrt(q.shape[-1])
        # Normalise over the key axis: each query's weights sum to 1.
        self.attention_weights=torch.softmax(scores,dim=-1)
        return torch.bmm(self.attention_weights,v)

In [3]:
def makeshapefine(X,num_head):
    """Split the last axis into ``num_head`` chunks and fold them into axis 0.

    A tensor of shape (B, T, D) becomes (B * num_head, T, D / num_head); the
    heads that belong to one batch entry are stored next to each other.
    """
    batch,steps=X.shape[0],X.shape[1]
    # (B, T, D) -> (B, T, H, D/H) -> (B, H, T, D/H)
    per_head=X.reshape(batch,steps,num_head,-1).transpose(1,2)
    # Merge the batch and head axes.
    return per_head.reshape(-1,steps,per_head.shape[-1])

In [4]:
def reversefineshape(X,num_head):
    """Unfold heads from axis 0 and concatenate them along the last axis.

    A tensor of shape (B * num_head, T, d) becomes (B, T, num_head * d).
    """
    steps,width=X.shape[1],X.shape[2]
    # (B*H, T, d) -> (B, H, T, d) -> (B, T, H, d)
    merged=X.reshape(-1,num_head,steps,width).transpose(1,2)
    # Flatten the head and per-head feature axes together.
    return merged.reshape(merged.shape[0],steps,-1)

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on top of ``DotProductAttention``.

    Queries, keys and values are each passed through their own linear layer
    to ``num_hidden`` features, split into ``num_head`` heads, attended, and
    merged back. No further projection is applied after the heads are merged.
    """
    def __init__(self,q_len,k_len,v_len,num_hidden,num_head) -> None:
        super().__init__()
        self.num_head=num_head
        # One input projection per role; all map to num_hidden features.
        self.q=nn.Linear(q_len,num_hidden)
        self.k=nn.Linear(k_len,num_hidden)
        self.v=nn.Linear(v_len,num_hidden)
        self.attention=DotProductAttention()

    def forward(self,q,k,v):
        """Project, split into heads, attend, and merge the heads again."""
        heads=self.num_head
        q_heads=makeshapefine(self.q(q),heads)
        k_heads=makeshapefine(self.k(k),heads)
        v_heads=makeshapefine(self.v(v),heads)
        attended=self.attention(q_heads,k_heads,v_heads)
        return reversefineshape(attended,heads)

In [6]:
class AddNorm(nn.Module):
    """Residual sum of two tensors followed by layer normalisation."""
    def __init__(self,norm_shape) -> None:
        super().__init__()
        # norm_shape is handed to nn.LayerNorm unchanged.
        self.norm=nn.LayerNorm(norm_shape)

    def forward(self,x,y):
        """Return ``LayerNorm(x + y)``."""
        residual_sum=x+y
        return self.norm(residual_sum)

In [7]:
class PositionWiseFFN(nn.Module):
    """Two linear layers with a ReLU in between, applied on the last axis."""
    def __init__(self,num_ffn_input,num_ffn_hidden,num_ffn_output) -> None:
        super().__init__()
        self.dense1=nn.Linear(num_ffn_input,num_ffn_hidden)
        self.relu=nn.ReLU()
        self.dense2=nn.Linear(num_ffn_hidden,num_ffn_output)

    def forward(self,x):
        """Return ``dense2(relu(dense1(x)))``."""
        hidden=self.dense1(x)
        activated=self.relu(hidden)
        return self.dense2(activated)

In [8]:
class EncodeBlk(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each
    wrapped in an add-and-normalise step.

    The feed-forward output width is fixed to ``num_hidden`` so that it can be
    added back to its input.
    """
    def __init__(self,q_len,k_len,v_len,num_hidden,num_head,norm_shape,num_ffn_input,num_ffn_hidden) -> None:
        super().__init__()
        self.attention=MultiHeadAttention(q_len,k_len,v_len,num_hidden,num_head)
        self.addnorm1=AddNorm(norm_shape)
        self.PWiseFFn=PositionWiseFFN(num_ffn_input,num_ffn_hidden,num_hidden)
        self.addnorm2=AddNorm(norm_shape)

    def forward(self,x):
        """Run self-attention then the feed-forward net, with residuals."""
        # Self-attention: x serves as query, key and value.
        attended=self.attention(x,x,x)
        y=self.addnorm1(x,attended)
        transformed=self.PWiseFFn(y)
        return self.addnorm2(y,transformed)

In [9]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a (batch, steps, num_hidden) input.

    The table ``P`` has shape (1, max_len, num_hidden): even feature columns
    hold sines and odd feature columns hold cosines of
    ``position / 10000 ** (2j / num_hidden)``.

    ``P`` is registered as a non-persistent buffer, so it follows the module
    through ``.to(device)`` / ``.double()`` etc. while staying out of
    ``state_dict()``.
    """
    def __init__(self,num_hidden,max_len=1000) -> None:
        super(PositionalEncoding,self).__init__()
        table=torch.zeros((1,max_len,num_hidden))
        # angles[p, j] = p / 10000 ** (2j / num_hidden)
        angles=torch.arange(max_len,dtype=torch.float32).reshape(-1,1)/torch.pow(10000,torch.arange(0,num_hidden,2,dtype=torch.float32)/num_hidden)
        table[:,:,0::2]=torch.sin(angles)
        # With an odd num_hidden there is one fewer cosine column than sine
        # column, so drop the surplus frequency before assigning.
        table[:,:,1::2]=torch.cos(angles[:,:num_hidden//2])
        # A buffer (unlike a plain tensor attribute) moves with the module.
        self.register_buffer("P",table,persistent=False)

    def forward(self,X):
        """Return ``X`` plus the encodings of its first ``X.shape[1]`` positions."""
        return X+self.P[:,:X.shape[1],:]

In [16]:
class TransformerEncoder(nn.Module):
    """Token embedding, positional encoding and a stack of ``EncodeBlk`` layers.

    After each forward pass ``self.attention_weights`` is a list with one
    entry per layer, taken from that layer's inner attention module. An entry
    is None when the attention module does not expose ``attention_weights``.
    """
    def __init__(self,num_layer,vocab_size,q_len,k_len,v_len,num_hidden,num_head,norm_shape,num_ffn_input,num_ffn_hidden) -> None:
        super(TransformerEncoder,self).__init__()
        self.num_hidden=num_hidden
        self.embedding=nn.Embedding(vocab_size,num_hidden)
        self.posembedding=PositionalEncoding(num_hidden)
        self.blocks=nn.Sequential()
        for i in range(num_layer):
            self.blocks.add_module("block"+str(i),EncodeBlk(q_len,k_len,v_len,num_hidden,num_head,norm_shape,num_ffn_input,num_ffn_hidden))

    def forward(self,x):
        """Encode a batch of token indices ``x`` and return the last layer's output."""
        # Embeddings are multiplied by sqrt(num_hidden) before the positional
        # encoding is added.
        x=self.posembedding(self.embedding(x)*math.sqrt(self.num_hidden))
        self.attention_weights=[None]*len(self.blocks)
        for i, block in enumerate(self.blocks):
            x=block(x)
            # Read defensively: an attention module that never stored its
            # weights must not turn the forward pass into an AttributeError.
            self.attention_weights[i]=getattr(block.attention.attention,"attention_weights",None)
        return x

In [11]:
class DecodeBlk(nn.Module):
    """The ``i``-th decoder layer: self-attention, encoder-decoder attention
    and a feed-forward net, each followed by add-and-normalise.

    ``state`` is a two-item sequence: ``state[0]`` holds the encoder outputs
    and ``state[1]`` is a per-layer list caching every input this layer has
    seen so far, concatenated along axis 1.

    NOTE(review): no mask is passed to either attention call, so the
    self-attention step sees every position in ``key_values``.
    """
    def __init__(self,i,num_ffn_input,num_ffn_hidden,norm_shape,q_len,k_len,v_len,num_hidden,num_head) -> None:
        super().__init__()
        self.i=i
        self.attention1=MultiHeadAttention(q_len,k_len,v_len,num_hidden,num_head)
        self.addnorm1=AddNorm(norm_shape)
        self.attention2=MultiHeadAttention(q_len,k_len,v_len,num_hidden,num_head)
        self.addnorm2=AddNorm(norm_shape)
        self.ffn=PositionWiseFFN(num_ffn_input,num_ffn_hidden,num_hidden)
        self.addnorm3=AddNorm(norm_shape)

    def forward(self,x,state):
        """Run the layer on ``x``; return ``(output, state)`` with the cache updated."""
        encode_outputs,cache=state[0],state[1]
        cached=cache[self.i]
        # Extend this layer's history with the new input (or start it).
        key_values=x if cached is None else torch.cat((cached,x),dim=1)
        cache[self.i]=key_values
        # Self-attention over everything cached for this layer.
        self_attended=self.attention1(x,key_values,key_values)
        y=self.addnorm1(x,self_attended)
        # Attention over the encoder outputs.
        cross_attended=self.attention2(y,encode_outputs,encode_outputs)
        z=self.addnorm2(y,cross_attended)
        return self.addnorm3(z,self.ffn(z)),state


In [20]:
class TransformerDecoder(nn.Module):
    """Token embedding, positional encoding, a stack of ``DecodeBlk`` layers
    and a final linear layer that maps to ``vocab_size`` outputs."""
    def __init__(self,num_hidden,num_layer,vocab_size,num_ffn_input,num_ffn_hidden,norm_shape,q_len,k_len,v_len,num_head) -> None:
        super().__init__()
        self.num_hidden=num_hidden
        self.num_layer=num_layer
        self.embedding=nn.Embedding(vocab_size,num_hidden)
        self.posembedding=PositionalEncoding(num_hidden)
        self.blocks=nn.Sequential()
        for layer_idx in range(num_layer):
            layer=DecodeBlk(layer_idx,num_ffn_input,num_ffn_hidden,norm_shape,q_len,k_len,v_len,num_hidden,num_head)
            self.blocks.add_module("block"+str(layer_idx),layer)
        self.dense=nn.Linear(num_hidden,vocab_size)

    def init_state(self,encode_outputs):
        """Build a fresh state: the encoder outputs plus one empty cache slot per layer."""
        per_layer_cache=[None]*self.num_layer
        return [encode_outputs,per_layer_cache]

    def forward(self,x,state):
        """Decode token indices ``x``; return ``(dense output, state)``."""
        scaled=self.embedding(x)*math.sqrt(self.num_hidden)
        x=self.posembedding(scaled)
        # Each layer returns the state it was handed, with its cache updated.
        for block in self.blocks:
            x,state=block(x,state)
        return self.dense(x),state

In [21]:
class Transformer(nn.Module):
    """Encoder-decoder wrapper.

    ``encoder`` is called on the source batch. ``decoder`` must provide
    ``init_state(encoder_output)`` and be callable as ``decoder(x, state)``.
    """
    def __init__(self,encoder,decoder) -> None:
        super(Transformer,self).__init__()
        self.encoder=encoder
        self.decoder=decoder

    def forward(self,encoder_x,decoder_x):
        """Encode ``encoder_x``, then decode ``decoder_x`` against that encoding.

        Returns:
            Whatever ``self.decoder`` returns for ``(decoder_x, state)``.
        """
        encoder_output=self.encoder(encoder_x)
        state=self.decoder.init_state(encoder_output)
        # The decoder consumes the decoder-side input, not its own output.
        decoder_output=self.decoder(decoder_x,state)
        return decoder_output

In [None]:
# Model size and batching.
# NOTE(review): dropout is defined here but none of the model classes visible
# in this notebook take a dropout argument.
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
# Optimisation settings. Fall back to the CPU when CUDA is not available so
# the notebook still runs on machines without a GPU.
lr, num_epochs, device = 0.005, 200, torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_ffn_input, num_ffn_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
# Shape handed to nn.LayerNorm inside AddNorm.
norm_shape = [32]

In [14]:
from d2l import torch as d2l
# Load the training iterator and the two vocabularies from d2l's
# load_data_nmt helper, using the batch_size and num_steps set earlier.
# NOTE(review): presumably src_vocab / tgt_vocab are the source- and
# target-language vocabularies; confirm against the d2l documentation.
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

In [22]:
# The encoder embeds source-side tokens, so its embedding table must be sized
# by the source vocabulary; the decoder embeds and predicts target-side tokens.
encoder = TransformerEncoder(num_layers,len(src_vocab),query_size,key_size,value_size,num_hiddens,num_heads,norm_shape,num_ffn_input,num_ffn_hiddens)
decoder = TransformerDecoder(num_hiddens,num_layers,len(tgt_vocab),num_ffn_input,num_ffn_hiddens,norm_shape,query_size,key_size,value_size,num_heads)
net = Transformer(encoder, decoder)