Scaled dot-product attention:

- Project each token into query, key and value vectors
- Calculate attention scores using the dot product as the similarity function between each query and key, and scale them by sqrt(dim_k), where dim_k is the key dimension
- Calculate attention weights by normalizing the scores with softmax
- Update each embedding vector as the weighted sum of the value vectors: x'_i = SUM_j(w_ij * v_j)
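The four steps above can be sketched end to end in NumPy. This is a minimal illustration under assumptions: the toy sizes, the random embeddings x and the random projection matrices W_q, W_k, W_v are hypothetical stand-ins for learned parameters, and softmax is written out by hand.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8  # toy sizes

# Hypothetical token embeddings and randomly initialized projections
# (in a trained model W_q, W_k, W_v are learned parameters)
x = rng.normal(size=(seq_len, d_model))
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

# Step 1: project each token into query, key and value vectors
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# Step 2: dot product of every query with every key, scaled by sqrt(d_k)
scores = Q @ K.T / np.sqrt(d_k)

# Step 3: softmax over the key axis turns each row into weights summing to 1
weights = softmax(scores, axis=-1)

# Step 4: each output row is the weighted sum of the value rows
out = weights @ V
print(out.shape)  # (4, 8)
```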

In [12]:
from transformers import AutoTokenizer, AutoConfig
from torch import nn, bmm
from math import sqrt

In [2]:
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [3]:
text = 'Its better to take action than wonder'
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)

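Dense embedding vectors (context-independent): each token ID maps to one fixed vector. Note that this nn.Embedding is freshly initialized with random weights, not BERT's pretrained embeddings; only the vocabulary size and hidden size are taken from the config.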

In [6]:
config = AutoConfig.from_pretrained(model_name)
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

Generate embeddings for the input by looking up the embedding vector of each token ID

In [8]:
input_embeds = token_emb(inputs.input_ids)

In [9]:
input_embeds.size()

torch.Size([1, 7, 768])
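An embedding lookup is just row indexing into a (vocab_size, hidden_size) table. The NumPy sketch below uses a hypothetical toy table and made-up token IDs; it also shows the equivalent one-hot formulation and that a repeated ID always gets the identical vector, regardless of its neighbors.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_size = 10, 4  # toy sizes standing in for 30522 and 768

# An embedding layer is a (vocab_size, hidden_size) table of vectors
table = rng.normal(size=(vocab_size, hidden_size))

# Hypothetical token IDs: a batch of one sequence of length 3
input_ids = np.array([[2, 7, 2]])

# Lookup = select one row per ID; shape is (batch, seq_len, hidden_size)
embeds = table[input_ids]
print(embeds.shape)  # (1, 3, 4)

# Equivalent formulation: one-hot vectors multiplied by the table
one_hot = np.eye(vocab_size)[input_ids]
assert np.allclose(one_hot @ table, embeds)
```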

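Compute query, key and value. This is a naive version: query, key and value are all set equal to the input embeddings, whereas a real attention layer derives each one with its own learned linear projection.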

In [20]:
# Keep these equal for now
query = key = value = input_embeds

dim_k = key.size(-1)

# bmm is batch-wise matrix-matrix multiplication
# Scores are divided by sqrt(dim_k) to keep them in a manageable range; otherwise large dot products would saturate the softmax applied next
scores = bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
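Why divide by sqrt(dim_k)? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product has variance dim_k, so raw scores grow with the dimension and push softmax toward a one-hot output. A quick NumPy check under that assumption (random standard-normal vectors, not real BERT activations):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n = 768, 2000

# Random query/key pairs with independent standard-normal entries
q = rng.normal(size=(n, d_k))
k = rng.normal(size=(n, d_k))

raw = (q * k).sum(axis=-1)   # unscaled dot products
scaled = raw / np.sqrt(d_k)  # scaled as in the cell above

print(raw.var())     # grows with d_k (near 768 here)
print(scaled.var())  # stays near 1 whatever d_k is
```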

In [16]:
print(scores)
print(scores.size())

tensor([[[28.3634, -2.5400, -1.0776,  0.3073,  0.0299, -0.1742,  1.0891],
         [-2.5400, 29.2486,  1.0106, -0.4001,  1.2901,  0.5422, -0.6729],
         [-1.0776,  1.0106, 29.1372,  0.9001,  0.9464, -1.7248, -1.0170],
         [ 0.3073, -0.4001,  0.9001, 25.7758, -1.0010,  0.7502, -0.1281],
         [ 0.0299,  1.2901,  0.9464, -1.0010, 27.0208,  0.8866,  0.7322],
         [-0.1742,  0.5422, -1.7248,  0.7502,  0.8866, 27.4524,  0.5691],
         [ 1.0891, -0.6729, -1.0170, -0.1281,  0.7322,  0.5691, 27.2870]]],
       grad_fn=<DivBackward0>)
torch.Size([1, 7, 7])
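The scores are dominated by the diagonal (roughly 26 to 29) while the off-diagonal entries stay near 0. That is expected from the naive setup: with query = key, the diagonal entry for token i is ||x_i||^2 / sqrt(dim_k). nn.Embedding initializes its weights from a standard normal distribution, so ||x_i||^2 is about 768 and the diagonal lands near sqrt(768) ≈ 27.7, while dot products between different random vectors divided by sqrt(768) have variance about 1. A NumPy simulation under that assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 7, 768

# Stand-in for randomly initialized embeddings: entries drawn from N(0, 1)
x = rng.normal(size=(seq_len, d))

# Same naive setup as above: query = key = x
scores = x @ x.T / np.sqrt(d)

diag = np.diag(scores)                            # ||x_i||^2 / sqrt(d)
off_diag = scores[~np.eye(seq_len, dtype=bool)]   # x_i . x_j / sqrt(d), i != j

print(np.sqrt(d))      # about 27.71
print(diag.round(2))   # values near sqrt(d)
print(off_diag.std())  # close to 1
```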


Apply softmax along the last dimension to turn the scores into attention weights; each row then sums to 1

In [19]:
weights = nn.functional.softmax(scores, dim=-1)
weights.sum(dim=-1)

tensor([[1., 1., 1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)
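softmax(s)_j = exp(s_j) / SUM_k exp(s_k). A standard way to compute it without overflow is to subtract the row maximum first, which leaves the result unchanged. The sketch below is a hand-written NumPy version (not PyTorch's implementation) applied to the first row of scores printed earlier; it puts almost all of the weight on the token itself.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # exp(x - max) avoids overflow; the shift cancels in the ratio,
    # so the result equals the textbook softmax in exact arithmetic
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

# First row of the scores printed above
row = np.array([28.3634, -2.5400, -1.0776, 0.3073, 0.0299, -0.1742, 1.0891])
w = stable_softmax(row)
print(w.round(4))  # almost all mass on position 0
print(w.sum())     # 1.0 up to floating-point rounding

# Without the shift, exp(1000) overflows to inf and the ratio becomes nan
big = np.array([1000.0, 999.0])
print(stable_softmax(big))  # finite: about [0.731, 0.269]
```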

Multiply the attention weights by the values

In [21]:
weights.size()

torch.Size([1, 7, 7])

In [22]:
attn_outputs = bmm(weights, value)
attn_outputs.shape

torch.Size([1, 7, 768])
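bmm(weights, value) computes x'_i = SUM_j(w_ij * v_j) for every token in one batched product. The NumPy sketch below, with hypothetical random weights and values, checks the matrix form against an explicit loop. With the nearly one-hot weights produced above, each row of attn_outputs should be very close to the corresponding input embedding, because every token attends almost entirely to itself.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d = 1, 7, 16  # toy hidden size

# Hypothetical attention weights (each row sums to 1) and value vectors
logits = rng.normal(size=(batch, seq_len, seq_len))
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
value = rng.normal(size=(batch, seq_len, d))

# Batched matrix product, the NumPy counterpart of torch.bmm
out = weights @ value

# Same computation written as x'_i = SUM_j(w_ij * v_j), one token at a time
loop = np.zeros_like(out)
for b in range(batch):
    for i in range(seq_len):
        for j in range(seq_len):
            loop[b, i] += weights[b, i, j] * value[b, j]

print(out.shape)  # (1, 7, 16)
```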

A simplified scaled dot-product attention function

In [23]:
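def scaled_dot_prod_attn(query, key, value):
    # Dot-product scores between every query and key, scaled by sqrt of the key dimension
    dim_k = key.size(-1)
    scores = bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    # Normalize each row of scores into attention weights
    weights = nn.functional.softmax(scores, dim=-1)
    # Each output is the weighted sum of the value vectors
    attn_outputs = bmm(weights, value)
    return attn_outputs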
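For experimentation outside PyTorch, the function can be ported to NumPy. The sketch below is a hypothetical port (scaled_dot_prod_attn_np is not part of the original code) that supports any number of leading batch dimensions and uses a max-shifted softmax. It also checks one property of the function: permuting the tokens permutes the outputs in the same way, so this attention on its own carries no information about token order. In PyTorch 2.0 and later, torch.nn.functional.scaled_dot_product_attention should give the same result as scaled_dot_prod_attn when no mask or dropout is used.

```python
import numpy as np

def scaled_dot_prod_attn_np(query, key, value):
    """Hypothetical NumPy port of scaled_dot_prod_attn for any leading batch dims."""
    dim_k = key.shape[-1]
    # swapaxes(-1, -2) transposes only the last two axes, like key.transpose(1, 2) for 3-D input
    scores = query @ np.swapaxes(key, -1, -2) / np.sqrt(dim_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ value

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 7, 32))  # (batch, seq_len, hidden), toy hidden size

out = scaled_dot_prod_attn_np(x, x, x)

# Shuffling the tokens shuffles the outputs the same way
perm = rng.permutation(7)
out_perm = scaled_dot_prod_attn_np(x[:, perm], x[:, perm], x[:, perm])
print(np.allclose(out_perm, out[:, perm]))  # True
```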