In [1]:
import torch 
import torch.nn as nn
import numpy as np

In [2]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension (``embed_size``) of each input is split into ``heads`` chunks
    of size ``head_dim = embed_size // heads``. Every chunk goes through the
    value/key/query projections (one ``head_dim -> head_dim`` linear map each, with
    weights shared by all heads), attention is computed per head, and the heads are
    concatenated and mixed by ``fc_out``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size must be divisible by heads"

        # Per-head projections; the same weights are applied to every head slice.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys`` and aggregate ``values``.

        Args:
            values: tensor of shape (N, value_len, embed_size).
            keys: tensor of shape (N, key_len, embed_size); key_len must equal value_len.
            queries: tensor of shape (N, query_len, embed_size).
            mask: ``None`` or a tensor broadcastable to (N, heads, query_len, key_len);
                score positions where ``mask == 0`` receive (numerically) zero weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into `heads` slices and project each slice.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(queries.reshape(N, query_len, self.heads, self.head_dim))

        # Scores have shape (N, heads, query_len, key_len). Each score is a dot product
        # over head_dim terms, so scale by sqrt(head_dim) (d_k in "Attention Is All You Need").
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)

        if mask is not None:
            # Block masked positions (e.g. padding or future tokens). Use the most negative
            # finite value of the dtype so half precision does not overflow to -inf / NaN.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)

In [3]:
class TransformerBlock(nn.Module):
    """Attention sub-layer + feed-forward sub-layer, each with a residual add and LayerNorm.

    Args:
        embed_size: model width (last dimension of inputs and outputs).
        heads: number of attention heads; must divide ``embed_size``.
        dropout: dropout probability applied to each sub-layer's output.
        forward_expansion: the feed-forward hidden width is ``forward_expansion * embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)

        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys``/``values``, then apply the feed-forward net.

        ``mask`` is forwarded unchanged to the attention layer. The result has the
        same shape as ``queries``.
        """
        attention = self.attention(values, keys, queries, mask)

        # Residual connections: dropout regularises the sub-layer output *before* it is
        # added to the residual stream, so the skip path itself is never dropped.
        x = self.norm1(queries + self.dropout(attention))
        forward = self.feed_forward(x)
        return self.norm2(x + self.dropout(forward))

In [4]:
class Encoder(nn.Module):
    """Token embeddings plus learned positional embeddings, followed by ``num_layers`` TransformerBlocks."""

    def __init__(self, embed_size,
                 heads, src_vocab_size, num_layers,
                 device, forward_expansion, dropout,
                 max_length):
        super().__init__()
        self.embed_size = embed_size
        # Stored for backward compatibility; forward() builds positions on the input's device.
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size,
                                 heads,
                                 dropout=dropout,
                                 forward_expansion=forward_expansion)
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode token ids ``x`` of shape (N, seq_length).

        ``mask`` is passed unchanged to every layer's attention. Returns a tensor of
        shape (N, seq_length, embed_size).

        Raises:
            IndexError: if ``seq_length`` exceeds the ``max_length`` given at construction.
        """
        N, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Create positions on x's device: self.device is stale if the module was moved
        # with .to() after construction.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        # Each token's input is its word embedding plus the embedding of its position.
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            # Self-attention: values, keys and queries are all the current representation.
            out = layer(out, out, out, mask)

        return out
        

In [5]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention over the target, then a TransformerBlock
    whose queries come from the target and whose keys/values are supplied by the caller."""

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        # `device` is accepted for interface compatibility; this block does not use it.
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformerblock = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, target_mask):
        """Run one decoder layer.

        Args:
            x: current target representation; used as values, keys and queries for self-attention.
            value, key: passed to the TransformerBlock as its values/keys (the Decoder passes
                the encoder output for both).
            src_mask: mask for the attention over ``key``/``value``.
            target_mask: mask for the self-attention over ``x``.
        """
        attention = self.attention(x, x, x, target_mask)
        # Dropout on the sub-layer output before the residual add, matching TransformerBlock.
        query = self.norm(x + self.dropout(attention))
        return self.transformerblock(value, key, query, src_mask)

In [6]:
class Decoder(nn.Module):
    """Target token + positional embeddings, ``num_layers`` DecoderBlocks, and a final
    linear projection to ``target_vocab_size`` outputs per position."""

    def __init__(self, target_vocab_size,
                 embed_size, num_layers,
                 heads, forward_expansion,
                 dropout, device, max_length):

        super().__init__()

        # Stored for backward compatibility; forward() builds positions on the input's device.
        self.device = device
        self.word_embedding = nn.Embedding(target_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(embed_size, target_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, target_mask):
        """Decode target token ids ``x`` of shape (N, seq_length) against ``enc_out``.

        Returns a tensor of shape (N, seq_length, target_vocab_size) (raw ``fc_out`` output).

        Raises:
            IndexError: if ``seq_length`` exceeds the ``max_length`` given at construction.
        """
        N, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Create positions on x's device rather than the (possibly stale) self.device.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            # The encoder output serves as both values and keys for cross-attention.
            x = layer(x, enc_out, enc_out, src_mask, target_mask)

        return self.fc_out(x)

In [7]:
class TransFormer(nn.Module):
    """Encoder-decoder Transformer over integer token ids.

    ``forward(src, target)`` returns the decoder's ``fc_out`` output of shape
    (N, target_len, target_vocab_size).
    """

    def __init__(self, src_vocab_size,
                 target_vocab_size, src_pad_idx,
                 target_pad_idx, embed_size=512,
                 num_layers=6, forward_expansion=4,
                 heads=8, dropout=0, device="cpu", max_length=100):
        super().__init__()

        self.encoder = Encoder(src_vocab_size=src_vocab_size,
                               embed_size=embed_size,
                               num_layers=num_layers, heads=heads,
                               device=device,
                               forward_expansion=forward_expansion,
                               dropout=dropout,
                               max_length=max_length)

        self.decoder = Decoder(target_vocab_size=target_vocab_size,
                               embed_size=embed_size,
                               num_layers=num_layers,
                               heads=heads,
                               forward_expansion=forward_expansion,
                               dropout=dropout,
                               device=device,
                               max_length=max_length)

        self.src_pad_idx = src_pad_idx
        # Stored but not applied by make_target_mask, which only masks future positions.
        self.target_pad_idx = target_pad_idx
        # Kept for backward compatibility; masks are built on the inputs' device.
        self.device = device

    def make_src_mask(self, src):
        """Boolean mask of shape (N, 1, 1, src_len); False where ``src == src_pad_idx``."""
        # Comparison result already lives on src.device, so it matches the model even after .to().
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_target_mask(self, target):
        """Causal mask of shape (N, 1, target_len, target_len): 1.0 on/below the diagonal, 0.0 above."""
        N, target_len = target.shape
        causal = torch.tril(torch.ones((target_len, target_len), device=target.device))
        return causal.expand(N, 1, target_len, target_len)

    def forward(self, src, target):
        """Encode ``src`` and decode ``target`` (both (N, length) token ids)."""
        src_mask = self.make_src_mask(src)
        target_mask = self.make_target_mask(target)

        enc_src = self.encoder(src, src_mask)
        return self.decoder(target, enc_src, src_mask, target_mask)

In [8]:
if __name__ == "__main__":
    # Smoke test: push two toy source/target sequences through an untrained model
    # and report the output shape.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(device)

    # Token id 0 is padding on both sides; both toy vocabularies have 10 ids.
    src_pad_idx, trg_pad_idx = 0, 0
    src_vocab_size, trg_vocab_size = 10, 10

    x = torch.tensor(
        [
            [1, 5, 6, 4, 3, 9, 5, 2, 0],
            [1, 8, 7, 3, 4, 5, 6, 7, 2],
        ],
        device=device,
    )
    trg = torch.tensor(
        [
            [1, 7, 4, 3, 5, 9, 2, 0],
            [1, 5, 6, 2, 4, 7, 6, 2],
        ],
        device=device,
    )

    model = TransFormer(
        src_vocab_size=src_vocab_size,
        target_vocab_size=trg_vocab_size,
        src_pad_idx=src_pad_idx,
        target_pad_idx=trg_pad_idx,
        device=device,
    ).to(device)

    # Decoder input drops the final target token (teacher-forcing style shift).
    out = model(x, trg[:, :-1])
    print(out.shape)

cpu
torch.Size([2, 7, 10])
