In [2]:
import torch
import torch.nn as nn

In [3]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017, Sec. 3.2).

    The projected embeddings are split into ``heads`` slices of size
    ``head_dim``; every head attends independently and the concatenated head
    outputs are projected back to ``embed_size`` by ``fc_out``.

    Args:
        embed_size: Model dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"

        # Bias-free learned projections: Linear(x) = Wx where W is trainable.
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        # Applied to the concatenated head outputs (W^O in the paper).
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` tokens over ``keys``/``values`` tokens.

        Args:
            values: (N, value_len, embed_size).
            keys: (N, key_len, embed_size); ``key_len`` must equal ``value_len``
                because every key token scores the value token at the same index.
            query: (N, query_len, embed_size). In encoder self-attention all
                three inputs are the same tensor; in decoder cross-attention the
                keys/values come from the encoder, so ``query_len`` may differ.
            mask: Optional tensor broadcastable to (N, heads, query_len, key_len);
                positions where ``mask == 0`` receive zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).

        Raises:
            ValueError: if ``values`` and ``keys`` have different sequence lengths.
        """
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        if value_len != key_len:
            raise ValueError(
                f"values and keys must have the same sequence length, got {value_len} and {key_len}"
            )

        # Project, then split the last dimension into (heads, head_dim);
        # self.head_dim * self.heads == embed_size.
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)     # nvhd
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)             # nkhd
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)   # nqhd

        # QK^T for every batch element and every head at once:
        # (N, heads, query_len, key_len) -> one score per (query token, key token).
        attn_scores = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        attn_scores = attn_scores / (self.head_dim ** (1/2))

        if mask is not None:
            # Use the most negative value representable in the score dtype so
            # softmax gives masked positions zero weight. A hard-coded -1e20
            # overflows float16, which makes masked_fill raise under fp16/autocast.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        # dim=3 -> a probability distribution over key tokens for each query token.
        attention = torch.softmax(attn_scores, dim=3)

        # (N, heads, query_len, key_len) x (N, key_len, heads, head_dim)
        #   -> (N, query_len, heads, head_dim), then concatenate the heads.
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)  # (N, query_len, embed_size)

In [4]:
class TransformerBlock(nn.Module):
    """One post-LayerNorm transformer layer: multi-head attention + position-wise FFN.

    Each sub-layer computes ``LayerNorm(residual + Dropout(sublayer_output))``:
    dropout is applied to the sub-layer output *before* the residual add and
    normalisation (Vaswani et al., 2017, Sec. 5.4), so the residual stream is
    never randomly zeroed and every block output is layer-normalised.

    Args:
        embed_size: Model dimension.
        heads: Number of attention heads for ``SelfAttention``.
        dropout: Dropout probability applied to each sub-layer output.
        forward_expansion: Hidden width of the feed-forward network as a
            multiple of ``embed_size`` (the original paper uses 4).
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and FFN sub-layers.

        Args:
            value, key, query: Passed to ``self.attention``. ``query`` is also
                the residual input, so it must match the attention output shape
                (N, query_len, embed_size).
            mask: Attention mask passed to ``self.attention`` (e.g. padding mask).

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        attention = self.attention(value, key, query, mask)
        x = self.norm1(query + self.dropout(attention))
        ff_out = self.feed_forward(x)
        return self.norm2(x + self.dropout(ff_out))



In [5]:
class Encoder(nn.Module):
    """Token + learned positional embeddings followed by a stack of TransformerBlocks.

    Args:
        src_vocab_size: Size of the source vocabulary.
        embed_size: Model dimension.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        heads: Attention heads per layer.
        device: Stored as ``self.device``; position indices are created on the
            input's own device, so this does not need to match at call time.
        forward_expansion: FFN width multiplier passed to each layer.
        dropout: Dropout probability.
        max_length: Number of learned positions, i.e. the longest supported input.
    """

    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer token ids of shape (N, seq_length).
            mask: Attention mask passed to every layer (padding mask).

        Returns:
            Tensor of shape (N, seq_length, embed_size).

        Raises:
            ValueError: if ``seq_length`` exceeds the number of learned positions.
        """
        N, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise ValueError(
                f"input sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Build positions on the input's device (not the construction-time
        # self.device) so a mismatched/default device string cannot break this.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Encoder self-attention: query, key and value are all the same tensor.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [6]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, then cross-attention + FFN.

    The self-attention sub-layer computes ``LayerNorm(x + Dropout(attn(x)))``.
    Dropout is applied to the sub-layer output before the residual add
    (Vaswani et al., 2017, Sec. 5.4), so the residual stream itself is never
    randomly zeroed. The result is used as the query of a ``TransformerBlock``
    whose key/value are supplied by the caller.

    Args:
        embed_size: Model dimension.
        heads: Attention heads.
        forward_expansion: FFN width multiplier for the inner ``TransformerBlock``.
        dropout: Dropout probability.
        device: Unused here; kept for signature compatibility.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Run one decoder layer.

        Args:
            x: Decoder hidden states (N, trg_len, embed_size).
            value, key: Cross-attention inputs (the encoder output in ``Decoder``).
            src_mask: Mask for the cross-attention over ``key``/``value``.
            trg_mask: Mask for the self-attention over ``x`` (padding + look-ahead).

        Returns:
            Tensor of shape (N, trg_len, embed_size).
        """
        attention = self.attention(x, x, x, trg_mask)
        query = self.norm(x + self.dropout(attention))
        # The query comes from the decoder; key and value come from the caller
        # (the encoder output), relating target tokens to source tokens.
        return self.transformer_block(value, key, query, src_mask)

In [8]:
class Decoder(nn.Module):
    """Target embeddings, a stack of DecoderBlocks, and a projection to vocabulary logits.

    Args:
        trg_vocab_size: Size of the target vocabulary (number of output logits).
        embed_size: Model dimension.
        num_layers: Number of stacked ``DecoderBlock`` layers.
        heads: Attention heads per layer.
        forward_expansion: FFN width multiplier.
        dropout: Dropout probability.
        device: Stored as ``self.device``; position indices are created on the
            input's own device, so this does not need to match at call time.
        max_length: Number of learned positions, i.e. the longest supported target.
    """

    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)  # embed_size -> vocab logits
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode target token ids into vocabulary logits.

        Args:
            x: Integer target token ids of shape (N, trg_len).
            enc_out: Encoder output used as key/value of every cross-attention.
            src_mask: Padding mask for the source tokens.
            trg_mask: Padding + look-ahead mask for the target tokens.

        Returns:
            Logits of shape (N, trg_len, trg_vocab_size).

        Raises:
            ValueError: if ``trg_len`` exceeds the number of learned positions.
        """
        N, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise ValueError(
                f"target sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Build positions on the input's device, not the stored self.device.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            # enc_out serves as both key and value of the cross-attention.
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(x)

In [10]:
class Transformer(nn.Module):
    """Encoder-decoder transformer producing target-vocabulary logits.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        src_pad_idx: Token id marking padding in source sequences.
        trg_pad_idx: Token id marking padding in target sequences.
        embed_size: Model dimension.
        num_layers: Layers in both the encoder and the decoder.
        forward_expansion: FFN width multiplier.
        heads: Attention heads per layer.
        dropout: Dropout probability.
        device: Passed to the encoder/decoder and stored as ``self.device``.
            Masks are built on the inputs' own device.
        max_length: Longest supported source/target length (learned positions).
    """

    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )
        # Used to mask out padded tokens in the source sequence.
        self.src_pad_idx = src_pad_idx
        # Used to mask out padded tokens in the target sequence; future tokens
        # are hidden separately by the look-ahead mask in make_trg_mask.
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a boolean padding mask of shape (N, 1, 1, src_len).

        An entry is True for real tokens and False for ``src_pad_idx``. It
        broadcasts over heads and query positions in the attention scores.
        """
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return the combined padding + look-ahead mask, shape (N, 1, trg_len, trg_len).

        Entry [n, 0, i, j] is True when target token j is not padding and
        j <= i, so position i never attends to future (or padded) tokens.
        """
        trg_len = trg.shape[1]
        pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)  # (N, 1, 1, trg_len)
        # Built on trg's device so it always matches pad_mask's device.
        look_ahead = torch.tril(
            torch.ones((trg_len, trg_len), dtype=torch.bool, device=trg.device)
        )  # (trg_len, trg_len)
        return pad_mask & look_ahead

    def forward(self, src, trg):
        """Return logits of shape (N, trg_len, trg_vocab_size) for ``trg`` given ``src``."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_mask, trg_mask)


### Inference Example

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible weight initialisation
print(device)

# Toy batch: 2 source sequences of 9 tokens and 2 target sequences of 8 tokens.
# Token id 0 is padding (see src_pad_idx / trg_pad_idx below).
x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0],
                  [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0],
                    [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)

# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    # The decoder input drops the last target token.
    out = model(x, trg[:, :-1])

assert out.shape == (x.shape[0], trg.shape[1] - 1, trg_vocab_size)
print(out.shape)  # (N, trg_len - 1, trg_vocab_size)

cuda
torch.Size([2, 7, 10])
