In [1]:
import torch
import torch.nn as nn

Lets compute a simple version of the attention for the token "journey"

In [2]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
    [0.55, 0.87, 0.66], # journey
    [0.57, 0.85, 0.64], # starts
    [0.22, 0.58, 0.33], # with
    [0.77, 0.25, 0.10], # one
    [0.05, 0.80, 0.55]] # step
)

In [3]:
# computing the attention weights for journey
query = inputs[1]
attention_weight_for_journey = torch.empty(inputs.shape[0])
for i,token_embedding in enumerate(inputs):
    attention_weight_for_journey[i] = query @ token_embedding
attention_weight_for_journey = attention_weight_for_journey.softmax(dim=0)
print(attention_weight_for_journey)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [4]:
# computing the context vector for journey which a weighted sum of the inputs embedding

journey_context_vec = attention_weight_for_journey @ inputs
print(journey_context_vec)

tensor([0.4419, 0.6515, 0.5683])


In [5]:
# context vectors for all inputs token
attention_weights = (inputs @ inputs.T).softmax(dim=-1)
print("attention_weights =",attention_weights)
context_vectors = attention_weights @ inputs
print("context_vectors =",context_vectors)

attention_weights = tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
context_vectors = tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [6]:
# creating a learnable attention
torch.manual_seed(123)
emb_dim = inputs.size(1)
W_q = nn.Parameter(torch.rand((emb_dim,emb_dim)),requires_grad=False)
W_k = nn.Parameter(torch.rand((emb_dim,emb_dim)),requires_grad=False)
W_v = nn.Parameter(torch.rand((emb_dim,emb_dim)),requires_grad=False)

In [7]:
query_2 = inputs[1] @ W_q
keys = inputs @ W_k
values = inputs @ W_v

attention_scores_2 = query_2 @ keys.T
attention_weights = (attention_scores_2 * torch.sqrt(torch.tensor(emb_dim))).softmax(dim=-1)

context_vector_2 = attention_weights @ values
print(context_vector_2)

tensor([0.7446, 1.1523, 1.2343])


In [8]:
from model.attention import SelfAttention_V2,MultiHeadAttention

In [9]:
self_attention = SelfAttention_V2(emb_dim,2)
print(self_attention(inputs))

tensor([[-0.2079,  0.2694],
        [-0.2078,  0.2744],
        [-0.2079,  0.2745],
        [-0.2089,  0.2738],
        [-0.2088,  0.2757],
        [-0.2087,  0.2729]], grad_fn=<MmBackward0>)


Causal Attention

In [10]:
queries = inputs @ W_q 
keys = inputs @ W_k
attention_scores = queries @ keys.T

In [11]:
# creating a mask for future predictions
mask = torch.triu(torch.ones_like(attention_scores),diagonal=1)
attention_scores = attention_scores.masked_fill(mask.bool(),-torch.inf)
attention_scores

tensor([[0.7616,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.7872, 2.0141,   -inf,   -inf,   -inf,   -inf],
        [1.7646, 1.9901, 1.9852,   -inf,   -inf,   -inf],
        [1.0664, 1.1947, 1.1916, 0.5897,   -inf,   -inf],
        [0.8601, 0.9968, 0.9950, 0.4947, 0.6817,   -inf],
        [1.3458, 1.4957, 1.4915, 0.7374, 0.9968, 0.8366]])

In [12]:
attention_weights = attention_scores.softmax(dim=-1)
attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4435, 0.5565, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2857, 0.3580, 0.3563, 0.0000, 0.0000, 0.0000],
        [0.2570, 0.2922, 0.2913, 0.1596, 0.0000, 0.0000],
        [0.2074, 0.2378, 0.2374, 0.1439, 0.1735, 0.0000],
        [0.1935, 0.2247, 0.2238, 0.1053, 0.1365, 0.1163]])

Multihead Atttention

In [13]:
batch = torch.cat([inputs[None],inputs[None]],dim=0)
batch.shape

torch.Size([2, 6, 3])

In [14]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2940, 0.3909],
         [0.2853, 0.3604],
         [0.2692, 0.3882],
         [0.2634, 0.3938],
         [0.2574, 0.4036]],

        [[0.3190, 0.4858],
         [0.2940, 0.3909],
         [0.2853, 0.3604],
         [0.2692, 0.3882],
         [0.2634, 0.3938],
         [0.2574, 0.4036]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
