In [1]:
import torch
import torch.nn as nn

In [53]:
class SelfAttention(nn.Module) : 
    def __init__(self, embed_size, heads): 
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size) , "embed_size must be divisible by heads"

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias = False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias = False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias = False)

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)


    def forward(self, values, keys, query, mask) : 
        #Query : What i'm looking for ?
        #Key : What i can offer ?
        #Value : What i actually offer ?
        
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        #Split embeddings into self.heads
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd, nkhd -> nhqk", [queries, keys])
        #energy shape : (N, heads, query_len, key_len)

        if mask is not None : 
            energy = energy.masked_fill(mask == 0, float("-1e28"))

        attention = torch.softmax(energy / ((self.embed_size)**(1/2)), dim = 3)
        #attention shape : (N, heads, query_len, key_len)
        #values shape : (N, value_len, heads, head_size)

        out = torch.einsum("nhql,nlhd -> nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        #out shape : (N, query_len, heads, head_dim)

        out = self.fc_out(out)

        return out        

In [54]:
class TransformerBlock(nn.Module) : 
    def __init__(self, embed_size, heads, dropout, forward_expansion) : 
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)


    def forward(self, value, key, query, mask) :
        attention = self.attention(value, key, query, mask)
        x = self.norm1(attention + query)
        x = self.dropout(x)
        forward = self.feed_forward(x)
        out = self.norm2(forward + x)
        out = self.dropout(out)

        return out

In [55]:
class Encoder(nn.Module) : 
    def __init__(self, 
                src_vocab_size,
                embed_size,
                num_layers,
                heads,
                device,
                forward_expansion,
                dropout,
                max_length) : 
        super().__init__()
        
        self.embed_size = embed_size
        self.device = device
        self.dropout = nn.Dropout(dropout)
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion) for _ in range(num_layers)
                
            ]
        )

    def forward(self, x, mask) : 
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.word_embedding(x) + self.position_embedding(positions)
        out = self.dropout(out)

        for layer in self.layers : 
            out = layer(out, out,out, mask)

        return out

In [56]:
class DecoderBlock(nn.Module) : 
    def __init__(self, embed_size, heads, forward_expansion, dropout, device) : 
        super().__init__()

        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask) : 
        attention = self.attention(x,x,x,trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)

        return out

In [57]:
class Decoder(nn.Module) : 
    def __init__(self,
                trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length) : 
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        
        self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device) 
                                    for _ in range(num_layers)])

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)


    def forward(self, x, enc_out, src_mask, trg_mask) : 
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers : 
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)

        return out

In [58]:
class Transformer(nn.Module) : 
    def __init__(self, 
                src_vocab_size,
                trg_vocab_size,
                src_pad_idx,
                trg_pad_idx,
                embed_size = 256,
                num_layers = 6,
                forward_expansion = 4,
                heads = 8,
                dropout = 0,
                device = "cpu",
                max_length = 188) : 
        super().__init__()
        
        self.encoder = Encoder(src_vocab_size,
                embed_size,
                num_layers,
                heads,
                device,
                forward_expansion,
                dropout,
                max_length)

        self.decoder = Decoder(trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length)

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device


    def make_src_mask(self, src) :
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #(N,1,1,src_len)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg) : 
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N,1,trg_len,trg_len)
        return trg_mask.to(self.device)

    def forward(self, src, trg) :
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)

        return out

# Example

In [59]:
device = "cpu"
x = torch.tensor([[2,5,8,6,7,4,5,2,0], [0,2,3,5,8,9,6,5,7]])
trg = torch.tensor([[5,8,7,9,6,5,4,2], [4,5,8,9,2,3,5,4]])

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 12
trg_vocab_size = 12

model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx)
out = model(x, trg[: , :-1])

In [60]:
out

tensor([[[-5.3821e-01,  4.5850e-01,  4.4112e-01, -7.9740e-01, -5.0590e-01,
          -2.0741e-01,  1.5470e+00,  4.7678e-01, -1.2351e-01, -4.3349e-01,
           3.6817e-01,  4.8197e-01],
         [-2.9126e-01, -1.1054e-01,  1.6197e-01,  5.8384e-02, -6.9372e-02,
           2.4640e-01,  1.3628e+00,  9.5989e-01, -2.9523e-01, -5.9288e-01,
          -4.2636e-01,  5.4987e-05],
         [-3.6405e-01,  3.3587e-02, -5.0334e-01, -4.6275e-01, -2.2776e-01,
           4.1053e-01,  2.6704e-01,  6.0039e-01, -1.5268e-01, -8.0262e-01,
          -1.0979e-01,  7.1486e-01],
         [-2.5616e-01,  2.1733e-01, -5.8598e-01, -1.0408e+00, -3.9658e-01,
          -3.3248e-01,  7.1846e-01,  6.8377e-01,  2.0936e-01, -1.7869e-01,
          -3.5113e-01, -4.0672e-01],
         [ 1.1847e-01,  7.3039e-02, -6.4857e-03, -1.7637e-01, -4.7816e-01,
          -7.6010e-02,  1.0083e+00,  4.8404e-01, -4.7018e-01, -2.4168e-01,
          -8.6390e-01,  6.2029e-02],
         [ 7.1927e-01,  3.2494e-01,  9.8658e-02,  8.4394e-01, -2.

In [61]:
out.shape

torch.Size([2, 7, 12])