In [28]:
import torch
import torch.nn.functional as F

# 1 - Attention
There are three main ways to do attention: encoder-decoder attention, causal attention and bi-directional self-attention. I already implemented the first one in this [repo](https://github.com/abtraore/nmt-rnn-pytorch-from-scratch), so I will focus on the latter two here. At the heart of any attention scheme used in a transformer resides a **dot product attention**, so let's focus on that.

### 1.1 - Dot product attention

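The dot product attention can be computed using the formula: $$\textrm{softmax} \left(\frac{Q K^T}{\sqrt{d}} + M \right) V$$ where the (optional, but enabled by default) scaling factor $\sqrt{d}$ is the square root of the embedding dimension, and $M$ is an optional additive mask that is $0$ at positions that may be attended to and a very large negative number at positions that must be ignored.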

In [43]:
# Toy query / key / value matrices: 2 positions, 3 features each.
# All three share the same (2, 3) shape.
q = torch.tensor(
    [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]],
    dtype=torch.float32,
)
k = torch.tensor(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]],
    dtype=torch.float32,
)
v = torch.tensor(
    [[0.0, 1.0, 0.0],
     [1.0, 0.0, 1.0]],
    dtype=torch.float32,
)

# Mask passed to the dot product attention: 1 keeps a key position,
# 0 hides it. Here the second key is hidden from both queries.
m = torch.tensor(
    [[1.0, 0.0],
     [1.0, 0.0]],
    dtype=torch.float32,
)

In [49]:
def dot_product_attention(q,k,v,mask=None,scale=True):
    """Compute (optionally scaled and masked) dot product attention.

    Evaluates ``softmax(q @ k^T / sqrt(d) + M) @ v`` where ``d`` is the size
    of the last dimension of ``k`` and ``M`` is ``0`` where ``mask == 1`` and
    ``-1e9`` where ``mask == 0``.

    Args:
        q: Query tensor of shape ``(..., seq_q, d)``.
        k: Key tensor of shape ``(..., seq_k, d)``.
        v: Value tensor of shape ``(..., seq_k, d_v)``.
        mask: Optional tensor broadcastable to ``(..., seq_q, seq_k)``;
            1 marks positions to keep, 0 marks positions to ignore.
        scale: When True, divide the scores by ``sqrt(d)``.

    Returns:
        Tensor of shape ``(..., seq_q, d_v)`` holding the attention output.

    Side effects:
        Prints the attention weights.
    """
    # Swap only the last two dimensions: `k.T` reverses *every* dimension,
    # which breaks inputs that carry leading batch/head dimensions.
    scores = q @ k.transpose(-2, -1)

    if scale:
        # Out-of-place division by a plain scalar; no temporary tensor needed.
        scores = scores / (k.shape[-1] ** 0.5)

    if mask is not None:
        # Masked positions get a very large negative score so that the
        # softmax assigns them (numerically) zero weight.
        scores = scores + (1 - mask) * -1e9

    attention_weights = F.softmax(scores,dim=-1)

    print(f"Attention weights:\n{attention_weights}")

    return attention_weights @ v

# Masked attention on the dummy tensors: `m` hides the second key from both
# queries. The returned tensor is displayed as the cell output.
dot_product_attention(q,k,v,mask=m)

Attention weights:
tensor([[1., 0.],
        [1., 0.]])


tensor([[0., 1., 0.],
        [0., 1., 0.]])

### 1.2 - Causal dot product attention



In [58]:
def causal_dot_product_attention(q,k,v,scale=True):
    """Dot product attention where each position only sees itself and the past.

    Builds a square lower-triangular mask whose side is the sequence length of
    ``q`` (its penultimate dimension) and forwards everything to
    ``dot_product_attention``.

    Args:
        q: Query tensor of shape ``(..., seq, d)``.
        k: Key tensor; forwarded unchanged.
        v: Value tensor; forwarded unchanged.
        scale: Forwarded to ``dot_product_attention``.

    Returns:
        Whatever ``dot_product_attention`` returns for the causal mask.

    Side effects:
        Prints the causal mask.
    """
    # The mask side equals the penultimate dimension of the query (seq length).
    mask_size = q.shape[-2]

    # Ones on and below the diagonal, zeros above. The mask is created on the
    # same device as `q` so it can be combined with the attention scores
    # without a device-mismatch error.
    mask = torch.tril(torch.ones(mask_size,mask_size,device=q.device))

    print(f"Causal mask:\n{mask}\n")

    return dot_product_attention(q,k,v,mask=mask,scale=scale)

# Causal attention on the same dummy tensors; the mask is built inside the
# function from the sequence length of `q`.
causal_dot_product_attention(q,k,v)

Causal mask:
tensor([[1., 0.],
        [1., 1.]])

Attention weights:
tensor([[1.0000, 0.0000],
        [0.1503, 0.8497]])


tensor([[0.0000, 1.0000, 0.0000],
        [0.8497, 0.1503, 0.8497]])

# 2 - Masking
Masking is a key component in building transformers. There are two types of mask: the ***padding mask*** and the ***look-ahead mask***. Both help compute the softmax correctly by applying the proper weights to words.

### 2.1 - Padding Mask
When feeding an input batch to a model, sentences can have different lengths, so it is important to pad the shorter ones with 0s. Sentences that are longer than the maximum length are truncated. Note that to compute the attention we proceed as we did in 1.1: the padded positions (the 0s) are pushed towards $-\infty$ (in practice, a very large negative number such as $-10^9$ is added to them) so that they don't affect the softmax.

In [71]:
def create_padding_mask(token_ids, pad_token_id=0):
    """Build a padding mask: 1.0 for real tokens, 0.0 for padding.

    Args:
        token_ids: Tensor of token ids, e.g. shape ``(batch, seq)``.
        pad_token_id: Value that marks a padded position. Defaults to 0,
            the padding value used throughout this notebook.

    Returns:
        Float tensor with an extra dimension inserted at index 1
        (``(batch, 1, seq)`` for a ``(batch, seq)`` input), so the mask can
        be broadcast against attention scores.
    """
    # Every position that is not padding is kept (1.0); padding becomes 0.0.
    mask = (token_ids != pad_token_id).float()
    # Add an extra dimension to allow broadcasting.
    return mask.unsqueeze(1)

tensor([[[1., 1., 0., 0., 0.]],

        [[1., 1., 1., 0., 0.]],

        [[1., 0., 0., 0., 0.]]])

In [82]:
# Toy batch of three zero-padded rows. `x` is used both to derive the padding
# mask (0 = padding) and as the raw scores fed to the softmax.
x = torch.tensor([[7., 6., 0., 0., 0.], [1., 2., 3., 0., 0.], [3., 0., 0., 0., 0.]])
mask = create_padding_mask(x)
# Give the scores the same rank as the mask so the two broadcast together.
x_extented = x.unsqueeze(1)

# Without masking, the padded zeros still receive some probability mass.
softmax_without_mask = F.softmax(x_extented, dim=-1)
print(f"No padding mask softmax: \n{softmax_without_mask}")

# With masking, padded positions are pushed to a huge negative score first.
masked_scores = x_extented + (1 - mask) * -1e9
softmax_with_mask = F.softmax(masked_scores, dim=-1)
print(f"\nPadding mask softmax: \n{softmax_with_mask}")

No padding mask softmax: 
tensor([[[7.2960e-01, 2.6840e-01, 6.6531e-04, 6.6531e-04, 6.6531e-04]],

        [[8.4437e-02, 2.2952e-01, 6.2391e-01, 3.1063e-02, 3.1063e-02]],

        [[8.3393e-01, 4.1519e-02, 4.1519e-02, 4.1519e-02, 4.1519e-02]]])

Padding mask softmax: 
tensor([[[0.7311, 0.2689, 0.0000, 0.0000, 0.0000]],

        [[0.0900, 0.2447, 0.6652, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])


### 2.2 - Look-ahead Mask

In [90]:
def create_look_ahead_mask(seq_length):
    """Return a ``(1, seq_length, seq_length)`` lower-triangular mask of ones.

    Row ``i`` has ones in columns ``0..i`` and zeros afterwards, so position
    ``i`` cannot look at later positions. The leading dimension of size 1
    lets the mask broadcast over a batch.
    """
    square = torch.ones(seq_length, seq_length)
    lower_triangle = square.tril()
    return lower_triangle.unsqueeze(0)

# Random placeholder input; only its second dimension (3) is used below,
# as the sequence length.
x = torch.randn(1,3)
# NOTE: `x` is rebound here from the random input to the look-ahead mask.
x = create_look_ahead_mask(x.shape[1])
# Display the mask as the cell output.
x

tensor([[[1., 0., 0.],
         [1., 1., 0.],
         [1., 1., 1.]]])