In [28]:
import torch
import torch.nn.functional as F

# 1 Attention
There are 3 main ways to do attention: encoder-decoder attention, causal attention and bi-directional self-attention. I already implemented the first in this [repo](https://github.com/abtraore/nmt-rnn-pytorch-from-scratch), so I will focus on the latter two. At the heart of any attention scheme used in a transformer lies **dot product attention**, so let's focus on that first.

### 1.1 Dot product attention

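The dot product attention can be computed using the formula: $$\textrm{softmax} \left(\frac{Q K^T}{\sqrt{d}} + M \right) V$$ where $M$ is an additive mask that is $0$ at allowed positions and a very large negative number at blocked positions (so they receive zero weight after the softmax). The (optional, but default) scaling factor $\sqrt{d}$ is the square root of the key dimension $d$, i.e. the size of the last dimension of $K$.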

In [43]:
# Toy inputs: two tokens with three features each.
# q and k must share their last (feature) dimension; k and v must have the same number of rows.
q = torch.eye(2, 3, dtype=torch.float)
k = torch.arange(1, 7, dtype=torch.float).reshape(2, 3)
v = torch.tensor([[0.0, 1.0, 0.0],
                  [1.0, 0.0, 1.0]], dtype=torch.float)

# Example mask (1 = may attend, 0 = blocked): both queries may only look at the first key.
# This is an arbitrary mask chosen for illustration, not a causal one.
m = torch.tensor([[1.0, 0.0],
                  [1.0, 0.0]], dtype=torch.float)

In [49]:
def dot_product_attention(q, k, v, mask=None, scale=True):
    """Compute (optionally masked and scaled) dot product attention.

    Implements ``softmax(q @ k^T / sqrt(d) + M) @ v`` where ``d = k.shape[-1]``
    and ``M`` blocks every position where ``mask`` is 0 / False.

    Args:
        q: Query tensor of shape (..., L_q, d).
        k: Key tensor of shape (..., L_k, d).
        v: Value tensor of shape (..., L_k, d_v).
        mask: Optional float or boolean tensor broadcastable to (..., L_q, L_k).
            Entries equal to 0 / False are blocked; all other entries are kept.
        scale: If True (default), divide the scores by ``sqrt(d)``.

    Returns:
        Tensor of shape (..., L_q, d_v). The attention weights are printed as a
        side effect.
    """
    # transpose(-2, -1) swaps only the last two dimensions, so batched and
    # multi-head inputs work; ``k.T`` reverses *all* dimensions and is only
    # correct for 2-D tensors.
    scores = q @ k.transpose(-2, -1)

    if scale:
        scores = scores / (k.shape[-1] ** 0.5)

    if mask is not None:
        # Fill blocked positions with the most negative finite value of the score
        # dtype. Unlike adding ``(1 - mask) * -1e9`` this also accepts boolean
        # masks and does not rely on a constant outside float16's range.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)

    print(f"Attention weights:\n{attention_weights}")

    attention_output = attention_weights @ v

    return attention_output

# Both queries may only attend to the first key, so every output row equals v[0].
dot_product_attention(q, k, v, mask=m, scale=True)

Attention weights:
tensor([[1., 0.],
        [1., 0.]])


tensor([[0., 1., 0.],
        [0., 1., 0.]])

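### 1.2 Causal dot product attention

In causal attention a position may only attend to itself and to earlier positions, so the mask is the lower-triangular matrix of ones (`torch.tril`).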

In [55]:
def causal_dot_product_attention(q, k, v, scale=True):
    """Dot product attention in which each query attends only to itself and earlier positions.

    Builds a lower-triangular mask of ones (1 = may attend, 0 = blocked),
    prints it, and delegates to ``dot_product_attention``.

    Args:
        q, k, v: Query / key / value tensors. The mask side length is taken
            from ``q.shape[-2]``, so k and v are assumed to have the same
            sequence length as q (TODO confirm before using with cached keys).
        scale: Forwarded to ``dot_product_attention``.

    Returns:
        Whatever ``dot_product_attention`` returns for these inputs.
    """
    # The mask is square; its size equals the penultimate dimension of the query (sequence length).
    seq_len = q.shape[-2]

    # Create the mask on the same device as the inputs. A default (CPU) mask
    # cannot be combined with scores that live on another device.
    mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device))
    print(f"Causal mask:\n{mask}\n")

    return dot_product_attention(q, k, v, mask=mask, scale=scale)

# The first query can only see the first key; the second query sees both keys.
causal_dot_product_attention(q, k, v, scale=True)

Causal mask:
tensor([[1., 0.],
        [1., 1.]])

Attention weights:
tensor([[1.0000, 0.0000],
        [0.1503, 0.8497]])


tensor([[0.0000, 1.0000, 0.0000],
        [0.8497, 0.1503, 0.8497]])