In [1]:
import torch
import torch.nn as nn

#### **Simplified Self-Attention**


In [2]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

**Single Input Attention Calculation**


In [8]:
query = inputs[1] # Getting the second row, or the features of the token "journey"

attn_scores_2 = torch.empty(inputs.shape[0])

# loop through the tokens in the sequence
for i, x_i in enumerate(inputs):
    # take the dot product of each token embedding vector
    attn_scores_2[i] = torch.dot(x_i, query)
    
    # Dot Product example:
    # torch.Tensor([1,2,3,4,5]).dot(torch.Tensor([2,1,2,1,2]))
    # (1*2)+(2*1)+(3*2)+(4*1)+(5*2) = tensor(24.)

attn_scores_2


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [19]:
# Normalize the attention scores
attn_scores_norm = attn_scores_2 / torch.sum(attn_scores_2)
print(f"Unnormailzed: {attn_scores_2}")
print(f"Normailzed: {attn_scores_norm}")

Unnormailzed: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Normailzed: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])


In [30]:
# But softmax is more desireable for normalization
# Do note that there are underflow and overflow issues that come from this softmax implementation
print(attn_scores_2.exp() / attn_scores_2.exp().sum())

# This softmax implementation is preffered.
print(attn_scores_2.softmax(dim=0))

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [46]:
# calculate the updated embeddings
inputs.T @ attn_scores_2.softmax(dim=0)

tensor([0.4419, 0.6515, 0.5683])

**Full Attention Calculation**


In [100]:
# Get the attention scores for the query
attn_scores = inputs @ inputs.T

# normalize w/ softmax
attn_weights = attn_scores.softmax(dim=-1)

In [101]:
# Compute vectors
attn_weights @ inputs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

#### **Self-Attention**


In [103]:
x_2 = inputs[1] # second input token
d_in = inputs.shape[1] # The input embedding size, d=3
d_out = 2 # the output embedding size

In [133]:
torch.manual_seed(42)
# We will use three matricies to project the embedded tokens, into:
# query vector: What we are "interested in"
W_query = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
# key vector: What we have
W_key = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
# Value vector: what information to communicate if it is "interesting"
W_value = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

In [134]:
query = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

In [135]:
qk_attn = query @ keys.T

In [136]:
# using the key dimension
d_k = keys.shape[1]
# divide by sqrt if embedding dimension for scalling
# And apply softmax
qk_norm = (qk_attn * d_k**-0.5).softmax(dim=-1)

In [138]:
qk_norm

tensor([[0.1596, 0.1217, 0.1202, 0.2238, 0.1343, 0.2405],
        [0.1686, 0.1487, 0.1473, 0.1899, 0.1408, 0.2047],
        [0.1688, 0.1501, 0.1487, 0.1882, 0.1417, 0.2024],
        [0.1686, 0.1590, 0.1581, 0.1772, 0.1523, 0.1848],
        [0.1690, 0.1798, 0.1798, 0.1526, 0.1645, 0.1543],
        [0.1670, 0.1448, 0.1435, 0.1945, 0.1422, 0.2080]])

In [137]:
qk_norm @ values

tensor([[0.5141, 0.3639],
        [0.5633, 0.3251],
        [0.5659, 0.3221],
        [0.5839, 0.2941],
        [0.6180, 0.2539],
        [0.5575, 0.3262]])

**Putting it all together**


In [150]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))
    
    def forward(self, x:torch.Tensor):
        q = x @ self.W_query
        k = x @ self.W_key
        v = x @ self.W_value

        # multiply query and keys
        attn_scores = q @ k.T

        # Scaled normalization
        d_k = k.shape[1]
        attn_weights = (attn_scores * d_k**-0.5).softmax(dim=-1)

        # qk normalized matmul values to get output
        out = attn_weights @ v
        return out

In [151]:
torch.manual_seed(42)
self_attn = SelfAttention_v1(3, 2)

self_attn(inputs)

tensor([[0.5141, 0.3639],
        [0.5633, 0.3251],
        [0.5659, 0.3221],
        [0.5839, 0.2941],
        [0.6180, 0.2539],
        [0.5575, 0.3262]], grad_fn=<MmBackward0>)

In [154]:
class SelfAttention_v2(nn.Module):
    """Implementation using nn.Linear instead of matrix multiplication"""
    def __init__(self, d_in, d_out, bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
    
    def forward(self, x:torch.Tensor):
        q = self.W_query(x)
        k = self.W_keys(x)
        v = self.W_value(x)

        # mul q & k
        attn_scores = q @ k.T
        
        # normalize
        attn_weights = (attn_scores * k.shape[1]**-0.5).softmax(dim=-1)

        out = attn_weights @ v
        return out

In [155]:
torch.manual_seed(42)
attn_v2 = SelfAttention_v2(3, 2)
attn_v2(inputs)

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

#### **Causal Attention**


In [186]:
class CausalSelfAttention(nn.Module):
    """Decoder Only Attention"""
    def __init__(self, d_in:int, d_out:int, bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
    
    def forward(self, x:torch.Tensor):
        q: torch.Tensor = self.W_query(x)
        k: torch.Tensor = self.W_key(x)
        v: torch.Tensor = self.W_value(x)

        # compute attn scores
        attn_scores = q @ k.T

        # Since we can only attend to what is previously shown
        attn_mask = ~torch.ones(attn_scores.shape).tril().bool()
        # We need to now mask with (-inf) so softmax only deals with what we have
        masked_attn = attn_scores.masked_fill(attn_mask, -torch.inf)
        # Now scaled norm
        attn_weights = (masked_attn * k.shape[1]**-0.5).softmax(dim=-1)

        out = attn_weights @ v
        return out




In [187]:
torch.manual_seed(42)
attn_v3 = CausalSelfAttention(3,2)
attn_v3(inputs)


tensor([[0.4429, 0.1077],
        [0.4656, 0.2597],
        [0.4732, 0.3030],
        [0.4135, 0.2921],
        [0.4078, 0.2567],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

**Add in Dropout & Make it more compact**


In [235]:
class CausalAttention(nn.Module):
    """Decoder Only Attention"""
    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
        self.dropout = nn.Dropout(dropout) 
        # this is really a mask but GPT-2 and hf have buffer
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x:torch.Tensor):
        # create dimensions
        _, num_tokens, _ = x.shape # (B, tokens, )
        print(num_tokens)
        q: torch.Tensor = self.W_query(x)
        k: torch.Tensor = self.W_key(x)
        v: torch.Tensor = self.W_value(x)

        # compute attn scores
        attn_scores = q @ k.transpose(1,2)

        # New, _ ops are in-place
        # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  

        attn_weights = torch.softmax(attn_scores * keys.size(-1)**-0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        out = attn_weights @ v
        return out


In [242]:
# set an example number of tokens
token_count = 12
inputs = torch.randn((token_count, 3))
print(inputs)

batch = torch.stack((inputs, inputs), dim=0)
batch.shape

tensor([[-0.1690,  0.9178,  1.5810],
        [ 1.3010,  1.2753, -0.2010],
        [-0.1606, -0.4015, -0.4845],
        [-2.0929, -0.8199, -0.4210],
        [-0.9620,  1.2825,  0.8768],
        [ 1.6221, -0.9887, -1.7018],
        [-0.7498, -1.1285,  0.4135],
        [ 0.2892,  2.2473, -0.8036]])


torch.Size([2, 8, 3])

In [243]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

8
tensor([[[-0.5738,  0.2126],
         [-0.8129, -0.2973],
         [-0.3854, -0.1730],
         [ 0.4671,  0.1153],
         [ 0.0364, -0.0357],
         [ 0.0048, -0.0990],
         [ 0.0583,  0.0115],
         [ 0.1006, -0.1558]],

        [[-0.5738,  0.2126],
         [-0.8129, -0.2973],
         [-0.3854, -0.1730],
         [ 0.4671,  0.1153],
         [ 0.0364, -0.0357],
         [ 0.0048, -0.0990],
         [ 0.0583,  0.0115],
         [ 0.1006, -0.1558]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 8, 2])


#### **Multi-head Attention**
