In [1]:
import torch
from torch import nn
from einops import rearrange, reduce, repeat


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
N = 2
HIDDEN_DIM = 256
NUM_HEAD = 8 
INNER_DIM = 512

PAD_IDX = 0
EOS_IDX = 3

BATCH_SIZE = 64

VOCAB_SIZE = 10000
SEQ_LEN = 60

In [4]:
class Multiheadattention(nn.Module):
    def __init__(self, hidden_dim: int, num_head: int):
        super().__init__()
 
        # embedding_dim, d_model, 512 in paper
        self.hidden_dim = hidden_dim
        # 8 in paper
        self.num_head = num_head
        # head_dim, d_key, d_query, d_value, 64 in paper (= 512 / 8)
        self.head_dim = hidden_dim // num_head
        self.scale = torch.sqrt(torch.FloatTensor([hidden_dim])).to(device)

        self.fcQ = nn.Linear(hidden_dim, hidden_dim)
        self.fcK = nn.Linear(hidden_dim, hidden_dim)
        self.fcV = nn.Linear(hidden_dim, hidden_dim)
        self.fcOut = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, srcQ, srcK, srcV, mask=None):

        ##### SCALED DOT PRODUCT ATTENTION ######
        Q = self.fcQ(srcQ)
        K = self.fcK(srcK)
        V = self.fcV(srcV)
        
        Q = rearrange(
            Q, 'bs seq_len (num_head head_dim) -> bs num_head seq_len head_dim', num_head=self.num_head)
        K_T = rearrange(
            K, 'bs seq_len (num_head head_dim) -> bs num_head head_dim seq_len', num_head=self.num_head)
        V = rearrange(
            V, 'bs seq_len (num_head head_dim) -> bs num_head seq_len head_dim', num_head=self.num_head)

        attention_energy = torch.matmul(Q, K_T)/self.scale

        if mask is not None :
 
            attention_energy = torch.masked_fill(attention_energy, (mask == 0), -1e+4)
            
        attention_energy = torch.softmax(attention_energy, dim = -1)

        result = torch.matmul(attention_energy,V)

        result = rearrange(result, 'bs num_head seq_len head_dim -> bs seq_len (num_head head_dim)')

        result = self.fcOut(result)

        return result






In [5]:
class FFN(nn.Module):
    def __init__ (self, hidden_dim, inner_dim):
        super().__init__()
 
        self.hidden_dim = hidden_dim
        self.inner_dim = inner_dim 

        self.fc1 = nn.Linear(hidden_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, hidden_dim)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input):
        output = input
        output = self.fc1(output)
        output = self.relu(output)
        output = self.dropout(output)
        output = self.fc2(output)

        return output



In [6]:
class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_head, inner_dim):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_head = num_head
        self.inner_dim = inner_dim
        
        self.multiheadattention = Multiheadattention(hidden_dim, num_head)
        self.dropout1 = nn.Dropout(p=0.1)
        self.layerNorm1 = nn.LayerNorm(hidden_dim)
        self.ffn = FFN(hidden_dim, inner_dim)
        self.dropout2 = nn.Dropout(p=0.1)
        self.layerNorm2 = nn.LayerNorm(hidden_dim)
   

    def forward(self, input, mask = None):

        output = self.multiheadattention(srcQ= input, srcK = input, srcV = input, mask = mask)
        output = self.dropout1(output)
        output = input + output
        output = self.layerNorm1(output)

        output_ = self.ffn(output)
        output_ = self.dropout2(output_)
        output = output + output_
        output = self.layerNorm2(output)

        return output

In [7]:
def makeMask(tensor, option: str):
  
    if option == 'padding':
        tmp = torch.full_like(tensor, fill_value=PAD_IDX).to(device)
        mask = (tensor != tmp).float()
        mask = rearrange(mask, 'bs seq_len -> bs 1 1 seq_len ')


    elif option == 'lookahead':
        padding_mask = makeMask(tensor, 'padding')
        padding_mask = repeat(
            padding_mask, 'bs 1 1 k_len -> bs 1 new k_len', new=padding_mask.shape[3])
        
        mask = torch.ones_like(padding_mask)
        mask = torch.tril(mask)

        mask = mask * padding_mask
     



    return mask            

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden) -> None:
        super().__init__()
        self.linear_q = nn.Linear()
    
    def forward(self, x, mask):
        
        pass

In [9]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, d_model) -> None:
        super().__init__()
        pos = torch.arange(
            0, max_seq_len, dtype=torch.float
        ).unsqueeze(1)
        indices = torch.arange(d_model, dtype=torch.float) // 2
        self.positional_encoding = torch.where(
            torch.arange(0, d_model) % 2 == 0,
            torch.sin(pos / (10000 ** (indices / d_model))),
            torch.cos(pos / (10000 ** (indices / d_model)))
        )
        self.positional_encoding.requires_grad = False

    def forward(self, x):
        seq_len = x.shape[1]
        return self.positional_encoding[:seq_len, :]

PositionalEmbedding(100, 7)(torch.arange(700).reshape((100, 7)))

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000,  0.0000],
        [ 0.8415,  0.5403,  0.2651,  0.9642,  0.0719,  0.9974,  0.0193],
        [ 0.9093, -0.4161,  0.5112,  0.8595,  0.1434,  0.9897,  0.0386],
        [ 0.1411, -0.9900,  0.7207,  0.6932,  0.2142,  0.9768,  0.0579],
        [-0.7568, -0.6536,  0.8787,  0.4774,  0.2839,  0.9588,  0.0772],
        [-0.9589,  0.2837,  0.9738,  0.2274,  0.3521,  0.9360,  0.0964],
        [-0.2794,  0.9602,  0.9992, -0.0388,  0.4185,  0.9082,  0.1156]])

In [10]:
class Encoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, max_seq_len, num_head, inner_dim, num_enc_layers ,dropout_ratio=0.1) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.pos_embedding = PositionalEmbedding(max_seq_len, embedding_dim)
        self.dropout = nn.Dropout(dropout_ratio)

        self.enc_layers = nn.ModuleList([EncoderLayer(embedding_dim, num_head, inner_dim) for _ in range(num_enc_layers)])
        
    
    def forward(self, x):
        mask = makeMask(x, option='padding')
        x1 = self.embedding(x)
        x2 = self.pos_embedding(x)
        x = self.dropout(x1 + x2)

        for layer in self.enc_layers:
            x = layer(x, mask)
        return x

encoder = Encoder(15, 256, 100, 8, 512, 3)

encoder(torch.arange(6, dtype=torch.int).reshape((2, 3)))

tensor([[[-2.0795,  1.3508,  0.5941,  ..., -0.6936,  0.2948, -0.3487],
         [ 0.0893, -1.0699,  0.1096,  ...,  0.1377, -1.3342,  0.9301],
         [ 1.0097, -1.5905,  1.3774,  ...,  1.0004, -0.5331, -0.2621]],

        [[-2.1686, -1.3438, -1.2352,  ..., -0.6638, -0.3742,  1.0105],
         [-0.6628,  1.2588, -0.2769,  ...,  1.3390, -1.0327,  0.3651],
         [-0.7371,  0.0410, -1.3979,  ..., -0.8101, -1.3596,  1.7677]]],
       grad_fn=<NativeLayerNormBackward0>)

In [11]:
class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_head, inner_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.inner_dim = inner_dim

        self.multiheadattention1 = Multiheadattention(hidden_dim, num_head)
        self.dropout1 = nn.Dropout(p=0.1)
        self.layerNorm1 = nn.LayerNorm(hidden_dim)

        self.multiheadattention2 = Multiheadattention(hidden_dim, num_head)
        self.dropout2 = nn.Dropout(p=0.1)
        self.layerNorm2 = nn.LayerNorm(hidden_dim)

        self.ffn = FFN(hidden_dim, inner_dim)
        self.dropout3 = nn.Dropout(p=0.1)
        self.layerNorm3 = nn.LayerNorm(hidden_dim)

    def forward(self, inputs, enc_output, padding_mask, look_ahead_mask):
        inputs1 = inputs
        inputs2 = self.multiheadattention1(inputs, inputs, inputs, look_ahead_mask)
        inputs2 = self.dropout1(inputs2)
        inputs = inputs1 + inputs2
        inputs = self.layerNorm1(inputs)

        inputs1 = inputs
        inputs2 = self.multiheadattention2(inputs, enc_output, enc_output, padding_mask)
        inputs2 = self.dropout2(inputs2)
        inputs = inputs1 + inputs2

        inputs1 = inputs
        inputs2 = self.ffn(inputs)
        inputs2 = self.dropout3(inputs)
        inputs = inputs1 + inputs2
        return inputs

In [12]:
class Decoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, max_seq_len, num_head, inner_dim, num_dec_layers, dropout_ratio=0.1) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.pos_embedding = PositionalEmbedding(max_seq_len, embedding_dim)
        self.dropout = nn.Dropout(dropout_ratio)
        self.dec_layers = nn.ModuleList([DecoderLayer(embedding_dim, num_head, inner_dim) for _ in range(num_dec_layers)])
        
    def forward(self, inputs, enc_output):
        padding_mask = makeMask(inputs, option='padding')
        look_ahead_mask = makeMask(inputs, option='lookahead')
        inputs1 = self.embedding(inputs)
        inputs2 = self.pos_embedding(inputs)
        inputs = self.dropout(inputs1 + inputs2)

        for layer in self.dec_layers:
            inputs = layer(inputs, enc_output, padding_mask, look_ahead_mask)
        return inputs

In [13]:
class Transformer(nn.Module):
    def __init__(self, seq_len, num_classes, N = 2, hidden_dim = 256, num_head = 8, inner_dim = 512, vocab_size=1000):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, seq_len, num_head, inner_dim, N)
        self.decoder = Decoder(vocab_size, hidden_dim, seq_len, num_head, inner_dim, N)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(seq_len * hidden_dim, num_classes)
        self.softmax = nn.Softmax(1)
    
    def forward(self, enc_src, dec_src):
        enc_out = self.encoder(enc_src)
        dec_out = self.decoder(dec_src, enc_out)
        print(dec_out.shape)
        out = self.flatten(dec_out)
        out = self.linear(out)
        out = self.softmax(out)
        return out

In [14]:
model = Transformer(3, 10).to(device)
model

Transformer(
  (encoder): Encoder(
    (embedding): Embedding(1000, 256)
    (pos_embedding): PositionalEmbedding()
    (dropout): Dropout(p=0.1, inplace=False)
    (enc_layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (multiheadattention): Multiheadattention(
          (fcQ): Linear(in_features=256, out_features=256, bias=True)
          (fcK): Linear(in_features=256, out_features=256, bias=True)
          (fcV): Linear(in_features=256, out_features=256, bias=True)
          (fcOut): Linear(in_features=256, out_features=256, bias=True)
        )
        (dropout1): Dropout(p=0.1, inplace=False)
        (layerNorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ffn): FFN(
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout2): Dropout(p=0.1, inplace=False)
        (layer

In [15]:
model(torch.arange(12, dtype=torch.int).reshape((4, 3)), torch.arange(12, dtype=torch.int).reshape((4, 3)))

torch.Size([4, 3, 256])


tensor([[0.0191, 0.4483, 0.0056, 0.1059, 0.0913, 0.0453, 0.0905, 0.0868, 0.0370,
         0.0702],
        [0.2952, 0.1161, 0.0075, 0.0164, 0.0190, 0.0440, 0.2896, 0.0178, 0.0684,
         0.1259],
        [0.1416, 0.0931, 0.0167, 0.0069, 0.1751, 0.2000, 0.0620, 0.0477, 0.1580,
         0.0989],
        [0.1691, 0.1336, 0.0591, 0.0090, 0.1170, 0.0117, 0.1804, 0.0595, 0.0523,
         0.2084]], grad_fn=<SoftmaxBackward0>)