In [None]:
# Before starting "Casual Self Attention"
import torch
from torch import tensor, nn

# input-embedding:
inputs = tensor(
  [[
      [0.43, 0.15, 0.89],  # Your     (x^1)
      [0.55, 0.87, 0.66]], # journey  (x^2)

    [
      [0.57, 0.85, 0.64],  # starts   (x^3)
      [0.22, 0.58, 0.33]], # with     (x^4)

    [
      [0.77, 0.25, 0.10],  # one      (x^5)
      [0.05, 0.80, 0.55]]] # step     (x^6)
)

inputs.shape # (3 batches with 2 token-embedding each & token_dim = 3)

torch.Size([3, 2, 3])

In [3]:
torch.manual_seed(42)

d_in = inputs.shape[-1]
d_out = 4

W_q = nn.Linear(in_features=d_in, out_features=d_out)
W_k = nn.Linear(in_features=d_in, out_features=d_out)
W_v = nn.Linear(in_features=d_in, out_features=d_out)

queries = W_q(inputs) # (3, 2, 4)
keys = W_k(inputs)    # (3, 2, 4)
values = W_v(inputs)  # (3, 2, 4)

attention_scores = queries @ keys.transpose(-1, -2)
attention_weights = torch.softmax(attention_scores / 4**0.5, dim=-1)
context_vectors = attention_weights @ values # (3, 2, 2) @ (3, 2, 3)
context_vectors.shape

torch.Size([3, 2, 4])

In [4]:
context_vectors 

tensor([[[0.4616, 0.3298, 0.3860, 0.5513],
         [0.4636, 0.3298, 0.3841, 0.5518]],

        [[0.4868, 0.1773, 0.3031, 0.4641],
         [0.4869, 0.1776, 0.3032, 0.4643]],

        [[0.5248, 0.0926, 0.3180, 0.3088],
         [0.5146, 0.0927, 0.3158, 0.3225]]], grad_fn=<UnsafeViewBackward0>)

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        '''Let's Consider Batched inputs'''
        super().__init__()
        self.W_q = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_k = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_v = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
    
    def forward(self, x):
        '''x: 3D matrix, with (batch_size, n_tokens, d_in)'''
        Q = self.W_q(x) # (8, 6, 2)
        K = self.W_k(x)
        V = self.W_v(x)
        attention_score = torch.matmul(Q, K.transpose(-1, -2))  
        attention_weights = torch.softmax(attention_score/ K.shape[-1]**0.5 , dim=-1)
        context_vectors = attention_weights @ V
        return context_vectors

In [6]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [7]:
# inputs.shape
# inputs = torch.rand(4, 6, 3)

In [8]:
sa_v2 = SelfAttention(d_in=inputs.shape[-1], d_out=2)
queries = sa_v2.W_q(inputs)
keys = sa_v2.W_k(inputs)
values = sa_v2.W_v(inputs)
attn_scores = queries @ keys.transpose(-1, -2)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1582, 0.1631, 0.1634, 0.1729, 0.1726, 0.1698],
        [0.1771, 0.1581, 0.1585, 0.1691, 0.1724, 0.1649],
        [0.1772, 0.1584, 0.1587, 0.1688, 0.1721, 0.1648],
        [0.1743, 0.1610, 0.1613, 0.1680, 0.1702, 0.1653],
        [0.1760, 0.1656, 0.1656, 0.1637, 0.1651, 0.1640],
        [0.1735, 0.1582, 0.1585, 0.1705, 0.1734, 0.1659]],
       grad_fn=<SoftmaxBackward0>)


In [69]:
mask = attn_scores.triu(diagonal=1).bool()
# mask
modified_attn_weights = attn_scores.tril().masked_fill(mask, value=-torch.inf).softmax(dim=-1)
modified_attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5399, 0.4601, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3691, 0.3150, 0.3159, 0.0000, 0.0000, 0.0000],
        [0.2675, 0.2392, 0.2396, 0.2538, 0.0000, 0.0000],
        [0.2151, 0.1973, 0.1973, 0.1940, 0.1963, 0.0000],
        [0.1764, 0.1547, 0.1552, 0.1721, 0.1761, 0.1655]],
       grad_fn=<SoftmaxBackward0>)

In [73]:
modified_attn_weights @ values

tensor([[ 0.1150, -0.1660],
        [-0.0324, -0.1057],
        [-0.0855, -0.0849],
        [-0.1074, -0.0574],
        [-0.0970, -0.0705],
        [-0.1154, -0.0480]], grad_fn=<MmBackward0>)