In [2]:
import torch
import torch.nn as nn

In [3]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected by bias-free linear layers into queries, keys and
    values, split along the embedding dimension into ``heads`` pieces of
    ``head_dim`` features each, attended independently per head, then
    concatenated and passed through a final bias-free linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"

        # Trainable projections (y = Wx, no bias). The creation order is kept
        # as values, keys, queries, fc_out so parameter initialisation under
        # a fixed seed and state_dict ordering stay the same.
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        # Applied on top of the concatenated head outputs.
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` tokens over ``keys``/``values`` tokens.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size). ``key_len`` must
                equal ``value_len`` because every key token is paired with
                the value token at the same position.
            query: Tensor of shape (N, query_len, embed_size).
            mask: Optional tensor broadcastable to
                (N, heads, query_len, key_len). Positions where
                ``mask == 0`` are excluded from attention (e.g. padding).

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]  # number of examples in the batch

        # Token counts. They can differ between query and key/value, e.g.
        # when the keys/values come from another sequence.
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the last dimension into (heads, head_dim).
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # Batched QK^T per head: for every example, head and query token,
        # one score per key token -> (N, heads, query_len, key_len).
        attn_scores = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(head_dim) before masking so the fill value below is
        # not rescaled afterwards.
        attn_scores = attn_scores / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value of the scores' dtype. A
            # hard-coded constant such as -1e20 is not representable in
            # float16 and makes masked_fill fail in half precision.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        # Normalise over the key dimension: each query token gets a
        # probability distribution over all key tokens.
        attention = torch.softmax(attn_scores, dim=-1)

        # Weighted sum of values per head -> (N, query_len, heads, head_dim),
        # then concatenate the heads back into the last dimension.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)
        out = out.reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)  # (N, query_len, embed_size)

In [4]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(layer_norm(sublayer_out + input))``.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads passed to ``SelfAttention``.
        dropout: Dropout probability applied after each layer norm.
        forward_expansion: Multiplier for the hidden width of the
            feed-forward network (the original paper uses 4).
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Feed-forward network: widen to forward_expansion * embed_size,
        # apply ReLU, then project back to embed_size.
        hidden_size = forward_expansion * embed_size
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask=None):
        """Run the block on ``query`` attending over ``key``/``value``.

        Args:
            value: Tensor passed to the attention layer as values.
            key: Tensor passed to the attention layer as keys.
            query: Tensor passed to the attention layer as queries; it is
                also the residual input of the attention sub-layer.
            mask: Optional attention mask forwarded unchanged to the
                attention layer (e.g. to hide padded tokens). Defaults to
                ``None``, meaning no masking.

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)

        # Residual connection around attention, then norm and dropout.
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        # Residual connection around the feed-forward network.
        return self.dropout(self.norm2(forward + x))



In [None]:
class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        