In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """
    Compute scaled dot-product attention.

    The similarity between every query and every key is computed as
    Q . K^T and divided by the square root of the query/key vector length
    to keep the variance of the scores under control. An optional additive
    mask is applied, and a softmax over the last dimension turns the scores
    into the "attention" weights, i.e. how strongly each token attends to
    every other token. The weights are then used to average the value
    vectors.

    Parameters
    ----------
    q : tensor
        Queries of dimension batch size x number of attention heads
        x sequence length x length of the query vector of each head.
    k : tensor
        Keys of dimension batch size x number of attention heads
        x sequence length x length of the key vector of each head.
    v : tensor
        Values of dimension batch size x number of attention heads
        x sequence length x length of the value vector of each head.
    mask : tensor, optional
        Additive mask that is summed onto the scaled scores before the
        softmax (use 0 to keep a position and -inf to block it). It only
        has to be broadcastable against the score tensor; it may also
        carry extra leading dimensions (e.g. one mask per batch element).
        In the Encoder no masking is required; in the Decoder a mask is
        used so a token cannot attend to the tokens that follow it.
        By default None.

    Returns
    -------
    values : tensor
        Attention-weighted sum of ``v``; same leading dimensions as the
        attention weights, last dimension equal to that of ``v``.
    attention : tensor
        Softmax-normalised attention weights; the last two dimensions are
        sequence length x sequence length and each row sums to 1.
    """

    d_k = q.size()[-1]
    print("q.size() = ", q.size())
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Out-of-place add: an in-place `+=` cannot grow `scaled`, so it
        # rejects any mask whose broadcast shape is larger than the scores.
        # Casting back keeps the score dtype, exactly as `+=` used to.
        scaled = (scaled + mask).to(scaled.dtype)
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


In [None]:
class MultiheadAttention(nn.Module):

    """

    Multi-head self-attention layer, defined by 3 arguments:

    input_dim: This represents the vector dimension of every word
    that goes into the attention unit.

    d_model: output of the attention unit for every single word
    (i.e. after coming out as a value vector)

    num_heads: number of attention heads; must divide d_model evenly.

    """
    def __init__(self, input_dim, d_model, num_heads):

        """
        Constructor of the class.
        It carries out the following actions:
        a) Calls the superclass nn.Module constructor with
        super().__init__().
        b) Checks that d_model can be split evenly across the heads.
        c) Sets the input arguments as the attributes of the class for
        later use (self.input_dim, self.d_model, self.num_heads).
        d) Calculates the dimension of one attention head by dividing
        d_model by the number of heads.

        Further we have 2 linear layers:
        a) self.qkv_layer: takes the input vector and maps it to the
        concatenated q, k, v vectors (3 * d_model features).

        b) self.linear_layer: processes the concatenated outputs of all
        attention heads, mapping d_model features to d_model features.


        Parameters
        ----------
        input_dim : integer
            As defined above
        d_model : integer
            As defined above
        num_heads : integer
           As defined above

        Raises
        ------
        ValueError
            If num_heads is not positive or does not divide d_model.
        """

        super().__init__()
        # Fail early: with a remainder, the reshape in forward() would not
        # match the 3 * d_model features produced by qkv_layer.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Here we carry out the multihead attention mechanism.
        The multi head attention mechanism involves the following steps:
        1) We pass the input tensor through the q, k, v layer, a linear
        layer that produces the concatenated query, key and value
        features for every token.
        2) The tensor of batch size x sequence length x (3 * d_model) is
        reshaped into batch size x sequence length x no. of heads
        x (3 * head dimension).
        3) Permute - we swap the sequence and head dimensions so that each
        head sees its own sequence.
        4) Chunk - we obtain the query, key and value tensors individually
        by splitting the last dimension into 3 equal parts.
        5) We get the value tensor and the matrix of attention scores
        through the scaled_dot_product function.
        6) We move the head dimension back next to the head features and
        flatten the two, so every token ends up with the concatenation of
        its own per-head outputs.
        7) We pass the result through another linear layer in order to
        exchange information between the heads.

        The tensor shapes are printed after each step.


        Parameters
        ----------
        x : tensor
            tensor of size batch_size, sequence_length, input_dim
        mask : tensor, optional
            Passed unchanged to scaled_dot_product, where it is added to
            the attention scores before the softmax. In the Encoder we do
            not require masking; in the Decoder we do, so a token cannot
            attend to the tokens that follow it. By default None.


        Returns
        -------
        tensor
            output of size batch_size, sequence_length, d_model
        """

        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # values is (batch, heads, seq, head_dim). The head axis has to be
        # moved back behind the sequence axis before flattening; reshaping
        # directly would mix features that belong to different tokens.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out


In [None]:
# Demo: push one random batch through the attention layer; the layer prints
# the tensor shape after every step.
input_dim = 512
d_model = 512
num_heads = 8

batch_size = 1
sequence_length = 4
x = torch.randn((batch_size, sequence_length, input_dim))

model = MultiheadAttention(input_dim, d_model, num_heads)
# Call the module itself instead of .forward(): nn.Module.__call__ is the
# supported entry point and also runs any registered hooks.
out = model(x)

x.size(): torch.Size([1, 4, 512])
qkv.size(): torch.Size([1, 4, 1536])
qkv.size(): torch.Size([1, 4, 8, 192])
qkv.size(): torch.Size([1, 8, 4, 192])
q size: torch.Size([1, 8, 4, 64]), k size: torch.Size([1, 8, 4, 64]), v size: torch.Size([1, 8, 4, 64]), 
q.size() =  torch.Size([1, 8, 4, 64])
values.size(): torch.Size([1, 8, 4, 64]), attention.size:torch.Size([1, 8, 4, 4]) 
values.size(): torch.Size([1, 4, 512])
out.size(): torch.Size([1, 4, 512])
