In [1]:
import torch
import dltools
from torch import nn
import math

In [2]:
class PositionWsieFNN(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Both linear layers act on the last dimension of the input only, so every
    leading dimension is carried through unchanged.

    Args:
        fnn_num_input: size of the last dimension of the input.
        fnn_num_hidden: width of the hidden layer.
        fnn_num_output: size of the last dimension of the output.
        **kwargs: forwarded untouched to ``nn.Module.__init__``.
    """

    def __init__(self, fnn_num_input, fnn_num_hidden, fnn_num_output, **kwargs):
        super().__init__(**kwargs)
        # Attribute names double as state_dict keys, so they are kept stable.
        self.dense1 = nn.Linear(in_features=fnn_num_input,
                                out_features=fnn_num_hidden)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(in_features=fnn_num_hidden,
                                out_features=fnn_num_output)

    def forward(self, X):
        """Map ``X`` of shape (..., fnn_num_input) to (..., fnn_num_output)."""
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [3]:
# Shape check: the last dimension goes 2 -> 8, the leading (2, 3) dimensions
# are left as they are.
fnn = PositionWsieFNN(fnn_num_input=2, fnn_num_hidden=4, fnn_num_output=8)
fnn.eval()
X = torch.ones(2, 3, 2)
fnn(X).shape

torch.Size([2, 3, 8])

In [4]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(X + Dropout(Y))``.

    Args:
        normalized_shape: shape handed to ``nn.LayerNorm``.
        dropout: drop probability applied to ``Y`` before the addition.
        **kwargs: forwarded untouched to ``nn.Module.__init__``.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return the normalised sum of ``X`` and the dropped-out ``Y``.

        Only ``Y`` goes through dropout; ``X`` is added as-is.
        """
        residual = self.dropout(Y) + X
        return self.ln(residual)


In [5]:
# Shape check: AddNorm returns a tensor shaped like its inputs.
# eval() switches dropout off, so the result is deterministic.
addnorm = AddNorm(normalized_shape=[8, 4], dropout=0.2)
addnorm.eval()
X = torch.ones(2, 8, 4)
Y = torch.ones(2, 8, 4)
addnorm(X, Y).shape

torch.Size([2, 8, 4])

In [7]:
class EncoderBlock(nn.Module):
    """One encoder layer: attention and feed-forward, each wrapped in AddNorm.

    Args:
        query_num, key_num, value_num: sizes handed to
            ``dltools.MultiHeadAttention`` (passed in key, query, value order).
        num_hiddens: hidden size handed to the attention layer.
        fnn_num_input, fnn_num_hiddens, fnn_num_output: layer sizes of the
            feed-forward sub-network.
        norm_shape: ``normalized_shape`` used by both AddNorm layers.
        num_heads: number of attention heads.
        dropout: drop probability shared by attention and both AddNorm layers.
        use_bias: forwarded to the attention layer.
        **kwargs: forwarded untouched to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        query_num,
        key_num,
        value_num,
        num_hiddens,
        fnn_num_input,
        fnn_num_hiddens,
        fnn_num_output,
        norm_shape,
        num_heads,
        dropout,
        use_bias=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE(review): positional order (key, query, value, ...) is assumed to
        # match dltools.MultiHeadAttention's signature -- confirm against dltools.
        self.attention = dltools.MultiHeadAttention(
            key_num, query_num, value_num, num_hiddens, num_heads, dropout, use_bias
        )
        self.fnn = PositionWsieFNN(
            fnn_num_input=fnn_num_input,
            fnn_num_hidden=fnn_num_hiddens,
            fnn_num_output=fnn_num_output,
        )
        self.addNorm1 = AddNorm(normalized_shape=norm_shape, dropout=dropout)
        self.addNorm2 = AddNorm(normalized_shape=norm_shape, dropout=dropout)

    def forward(self, X, valid_len):
        """Run attention then the feed-forward net, with a residual around each.

        ``X`` is handed to the attention layer three times (presumably as
        queries, keys and values) together with ``valid_len``.
        """
        attended = self.attention(X, X, X, valid_len)
        Y = self.addNorm1(X, attended)
        transformed = self.fnn(Y)
        return self.addNorm2(Y, transformed)

In [33]:
class TransformerEncoder(dltools.Encoder):
    """Stack of ``EncoderBlock`` layers on top of a scaled token embedding.

    Args:
        vocab_size: number of rows in the embedding table.
        num_layers: number of ``EncoderBlock`` layers to stack.
        query_size, key_size, value_size: sizes forwarded to every block's
            ``query_num`` / ``key_num`` / ``value_num``.
        num_hiddens: embedding width and hidden size of every block.
        fnn_num_input, fnn_num_hiddens, fnn_num_output: feed-forward sizes
            forwarded to every block.
        norm_shape: ``normalized_shape`` for the blocks' AddNorm layers.
        num_heads: number of attention heads per block.
        dropout: drop probability for positional encoding and every block.
        use_bias: forwarded to every block's attention layer.
        **kwargs: forwarded untouched to ``dltools.Encoder.__init__``.

    Attributes:
        attention: list with one entry per block; filled on every forward
            pass with that block's attention weights (``None`` before that).
    """

    def __init__(self, vocab_size, num_layers, query_size, key_size, value_size,
                 num_hiddens, fnn_num_input, fnn_num_hiddens, fnn_num_output,
                 norm_shape, num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.numhiddens = num_hiddens
        self.embed = nn.Embedding(vocab_size, num_hiddens)
        self.posit_encode = dltools.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            # Bind by keyword: EncoderBlock's positional order is
            # (query, key, value), so passing (key, query, value) positionally
            # would swap the two sizes. use_bias must be forwarded explicitly
            # or the caller's choice is lost.
            self.blks.add_module(
                "block" + str(i),
                EncoderBlock(
                    query_num=query_size,
                    key_num=key_size,
                    value_num=value_size,
                    num_hiddens=num_hiddens,
                    fnn_num_input=fnn_num_input,
                    fnn_num_hiddens=fnn_num_hiddens,
                    fnn_num_output=fnn_num_output,
                    norm_shape=norm_shape,
                    num_heads=num_heads,
                    dropout=dropout,
                    use_bias=use_bias,
                ),
            )

        self.attention = [None] * len(self.blks)

    def forward(self, X, valid_len, *args):
        """Encode token indices ``X``.

        Embeddings are multiplied by ``sqrt(num_hiddens)`` before the
        positional encoding is applied, then passed through every block in
        order. ``valid_len`` is handed to each block unchanged; extra
        positional arguments are accepted and ignored.

        Returns:
            The output of the last block.
        """
        X = self.posit_encode(self.embed(X) * math.sqrt(self.numhiddens))
        # Reset on every call so stale weights from a previous batch never linger.
        self.attention = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_len)
            self.attention[i] = blk.attention.attention.attention_weights
        return X

In [34]:
# Smoke test: a batch of 2 sequences of 100 token ids goes through a 2-layer
# encoder. norm_shape must match the (seq_len, num_hiddens) = (100, 24) slice
# that the AddNorm layers normalise.
encoder = TransformerEncoder(
    vocab_size=200,
    num_layers=2,
    query_size=24,
    key_size=24,
    value_size=24,
    num_hiddens=24,
    fnn_num_input=24,
    fnn_num_hiddens=48,
    fnn_num_output=24,
    norm_shape=[100, 24],
    num_heads=8,
    dropout=0.4,
)
encoder.eval()

valid_len = torch.tensor([3, 2])
X = torch.ones(2, 100, dtype=torch.long)
out = encoder(X, valid_len)

print(out.shape)


torch.Size([2, 100, 24])
