#### Multi-head self attention

[code source](https://peterbloem.nl/blog/transformers)

* Each attention head operates on one chunk of size k // h of the k-dimensional input embedding, where h is the number of heads. The per-head outputs are concatenated back into the original input shape. This is just a more efficient way of computing multi-head attention.
* x' = attention(x), where x and x' have the same shape
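
To make the chunking idea concrete, here is a minimal sketch that compares a head-by-head loop over chunks of the embedding with the batched reshape used in the `SelfAttention` class. The helper names `attention_looped` and `attention_batched` are made up for illustration, the learned linear projections are left out, and the scores are scaled by the square root of the full embedding size to mirror the class. Under these assumptions both versions should give the same result.

```python
import torch
import torch.nn.functional as F

def attention_looped(queries, keys, values, heads):
    """Hypothetical helper: run attention one head at a time on embedding chunks."""
    dim = queries.size(-1)
    s = dim // heads  # per-head chunk size
    outs = []
    for i in range(heads):
        # Slice out the i-th chunk of the embedding for this head: (b, t, s)
        q, k, v = (m[..., i * s:(i + 1) * s] for m in (queries, keys, values))
        w = F.softmax(q @ k.transpose(1, 2) / dim ** 0.5, dim=-1)  # (b, t, t)
        outs.append(w @ v)  # (b, t, s)
    # Concatenate the per-head outputs back to the input shape: (b, t, dim)
    return torch.cat(outs, dim=-1)

def attention_batched(queries, keys, values, heads):
    """Hypothetical helper: fold the heads into the batch dimension instead."""
    b, t, dim = queries.size()
    s = dim // heads

    def split(m):
        # (b, t, dim) -> (b, t, h, s) -> (b, h, t, s) -> (b * h, t, s)
        return m.view(b, t, heads, s).transpose(1, 2).contiguous().view(b * heads, t, s)

    q, k, v = split(queries), split(keys), split(values)
    w = F.softmax(torch.bmm(q, k.transpose(1, 2)) / dim ** 0.5, dim=-1)
    out = torch.bmm(w, v).view(b, heads, t, s)
    # (b, h, t, s) -> (b, t, h, s) -> (b, t, dim)
    return out.transpose(1, 2).contiguous().view(b, t, heads * s)

torch.manual_seed(0)
queries, keys, values = (torch.randn(2, 10, 6) for _ in range(3))
looped = attention_looped(queries, keys, values, heads=2)
batched = attention_batched(queries, keys, values, heads=2)
print(looped.shape)
print(torch.allclose(looped, batched, atol=1e-5))
```

The batched version avoids the Python loop by treating each head as an extra batch entry, so a single `torch.bmm` call covers all heads.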

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, k, heads=4, mask=False):
        super().__init__()
        assert k % heads == 0
        self.k, self.heads = k, heads
        # Queries, keys and values are all projected to the same dimension k,
        # so the dot products between queries and keys are well defined
        self.tokeys = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues = nn.Linear(k, k, bias=False)
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):
        # Example input shape: (batch_size=2, sequence_length=10, k=6)
        print("Input shape:", x.shape)

        b, t, k = x.size()  # b: batch size, t: sequence length, k: embedding dimension
        h = self.heads

        queries = self.toqueries(x)
        keys = self.tokeys(x)
        print(f'keys vector after linear transformation: {keys.size()}') # (2,10,6)
        print("Weight shape:", self.toqueries.weight.shape)


        values = self.tovalues(x)

        s = k // h

        keys = keys.view(b, t, h, s).transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.view(b, t, h, s).transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.view(b, t, h, s).transpose(1, 2).contiguous().view(b * h, t, s)

        # Heads are folded into the batch dimension: (b * h, t, s)
        print(f'keys vector shape after adding multiple attension heads is: {keys.size()}') # (4,10,3)

        # Raw attention scores between every pair of positions, computed per head
        dot = torch.bmm(queries, keys.transpose(1, 2))
        # Example dot shape: (batch_size * heads = 4, sequence_length=10, sequence_length=10)
        print("Dot product shape:", dot.shape)

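        # Scale the scores before the softmax. This divides by sqrt(k), the full embedding
        # size; the standard scaled dot-product formulation divides by sqrt(s), the per-head size.
        dot = dot / (k ** (1/2))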

        dot = F.softmax(dot, dim=2)
        print(f'Softmax dot product shape is {dot.shape}')

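        # Weighted sum of the values per head, then unfold the heads: (b * h, t, s) -> (b, h, t, s)
        out = torch.bmm(dot, values).view(b, h, t, s)
        # Concatenate the heads back along the embedding dimension: (b, t, k)
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        # Mix the heads with a final linear layer; compute it once and reuse the result
        out = self.unifyheads(out)
        # Example output shape: (batch_size=2, sequence_length=10, k=6)
        print("Output shape:", out.shape)

        return out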

In [16]:
# Example input tensor
# batch size = 2, sequence length = 10, embedding dimension = 6
x = torch.randn(2, 10, 6)

# Create an instance of the SelfAttention class with k=6 and heads=2
self_attention = SelfAttention(6, 2)

# Pass the input tensor to the forward method
output = self_attention(x)
output.shape

Input shape: torch.Size([2, 10, 6])
keys vector after linear transformation: torch.Size([2, 10, 6])
Weight shape: torch.Size([6, 6])
keys vector shape after adding multiple attension heads is: torch.Size([4, 10, 3])
Dot product shape: torch.Size([4, 10, 10])
Softmax dot product shape is torch.Size([4, 10, 10])
Output shape: torch.Size([2, 10, 6])
torch.Size([2, 10, 6])


#### Explanation

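* The input x has shape (batch, seq_len, seq_dim).
* The query, key and value projections are linear layers with weight matrices of shape (seq_dim, seq_dim), so each maps the input to a tensor with the same shape as x.
* Adding multiple heads: the last dimension of each projected tensor is split into chunks of size seq_dim // heads, and each chunk is one head. The heads are folded into the batch dimension, giving shape (batch * heads, seq_len, seq_dim // heads).
* Queries times transposed keys has shape (batch * heads, seq_len, seq_len). This is the attention matrix: it holds the similarity between every pair of positions in the sequence, and after scaling and softmax each of its rows sums to 1.
* The attention matrix multiplied by the values gives the attention output. Once the heads are concatenated and passed through the final linear layer, the output has the same shape as the input, but each position is now a weighted combination of information from all positions in the sequence.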
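
The steps above can be traced with a short single-head sketch. It is an illustration under stated assumptions, not the class itself: `to_q`, `to_k` and `to_v` are hypothetical stand-ins for the learned projections, there is only one head, and the final unifying layer is left out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, seq_dim = 2, 10, 6
x = torch.randn(batch, seq_len, seq_dim)

# Hypothetical stand-ins for the query, key and value projections
to_q = torch.nn.Linear(seq_dim, seq_dim, bias=False)
to_k = torch.nn.Linear(seq_dim, seq_dim, bias=False)
to_v = torch.nn.Linear(seq_dim, seq_dim, bias=False)

# Each projection keeps the input shape: (batch, seq_len, seq_dim)
q, k, v = to_q(x), to_k(x), to_v(x)

# Similarity between every pair of positions: (batch, seq_len, seq_len)
scores = q @ k.transpose(1, 2) / seq_dim ** 0.5

# Softmax over the last dimension turns each row into weights that sum to 1
weights = F.softmax(scores, dim=-1)

# Each output position is a weighted combination of all value vectors
out = weights @ v

print("scores:", scores.shape)
print("row sums close to 1:", torch.allclose(weights.sum(dim=-1), torch.ones(batch, seq_len), atol=1e-5))
print("output:", out.shape)
```

With a single head the per-head size equals `seq_dim`, so the scaling factor here is the same whichever convention is used.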