In [1]:
import torch
import torch.nn as nn

#### **Simplified Self-Attention**


In [2]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

**Single Input Attention Calculation**


In [8]:
query = inputs[1] # Getting the second row, or the features of the token "journey"

attn_scores_2 = torch.empty(inputs.shape[0])

# loop through the tokens in the sequence
for i, x_i in enumerate(inputs):
    # take the dot product of each token embedding vector
    attn_scores_2[i] = torch.dot(x_i, query)
    
    # Dot Product example:
    # torch.Tensor([1,2,3,4,5]).dot(torch.Tensor([2,1,2,1,2]))
    # (1*2)+(2*1)+(3*2)+(4*1)+(5*2) = tensor(24.)

attn_scores_2


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [19]:
# Normalize the attention scores
attn_scores_norm = attn_scores_2 / torch.sum(attn_scores_2)
print(f"Unnormailzed: {attn_scores_2}")
print(f"Normailzed: {attn_scores_norm}")

Unnormailzed: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Normailzed: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])


In [30]:
# But softmax is more desireable for normalization
# Do note that there are underflow and overflow issues that come from this softmax implementation
print(attn_scores_2.exp() / attn_scores_2.exp().sum())

# This softmax implementation is preffered.
print(attn_scores_2.softmax(dim=0))

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [46]:
# calculate the updated embeddings
inputs.T @ attn_scores_2.softmax(dim=0)

tensor([0.4419, 0.6515, 0.5683])

**Full Attention Calculation**


In [100]:
# Get the attention scores for the query
attn_scores = inputs @ inputs.T

# normalize w/ softmax
attn_weights = attn_scores.softmax(dim=-1)

In [101]:
# Compute vectors
attn_weights @ inputs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

#### **Self-Attention**


In [103]:
x_2 = inputs[1] # second input token
d_in = inputs.shape[1] # The input embedding size, d=3
d_out = 2 # the output embedding size

In [133]:
torch.manual_seed(42)
# We will use three matricies to project the embedded tokens, into:
# query vector: What we are "interested in"
W_query = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
# key vector: What we have
W_key = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
# Value vector: what information to communicate if it is "interesting"
W_value = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

In [134]:
query = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

In [135]:
qk_attn = query @ keys.T

In [136]:
# using the key dimension
d_k = keys.shape[1]
# divide by sqrt if embedding dimension for scalling
# And apply softmax
qk_norm = (qk_attn * d_k**-0.5).softmax(dim=-1)

In [138]:
qk_norm

tensor([[0.1596, 0.1217, 0.1202, 0.2238, 0.1343, 0.2405],
        [0.1686, 0.1487, 0.1473, 0.1899, 0.1408, 0.2047],
        [0.1688, 0.1501, 0.1487, 0.1882, 0.1417, 0.2024],
        [0.1686, 0.1590, 0.1581, 0.1772, 0.1523, 0.1848],
        [0.1690, 0.1798, 0.1798, 0.1526, 0.1645, 0.1543],
        [0.1670, 0.1448, 0.1435, 0.1945, 0.1422, 0.2080]])

In [137]:
qk_norm @ values

tensor([[0.5141, 0.3639],
        [0.5633, 0.3251],
        [0.5659, 0.3221],
        [0.5839, 0.2941],
        [0.6180, 0.2539],
        [0.5575, 0.3262]])

**Putting it all together**


In [150]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))
    
    def forward(self, x:torch.Tensor):
        q = x @ self.W_query
        k = x @ self.W_key
        v = x @ self.W_value

        # multiply query and keys
        attn_scores = q @ k.T

        # Scaled normalization
        d_k = k.shape[1]
        attn_weights = (attn_scores * d_k**-0.5).softmax(dim=-1)

        # qk normalized matmul values to get output
        out = attn_weights @ v
        return out

In [151]:
torch.manual_seed(42)
self_attn = SelfAttention_v1(3, 2)

self_attn(inputs)

tensor([[0.5141, 0.3639],
        [0.5633, 0.3251],
        [0.5659, 0.3221],
        [0.5839, 0.2941],
        [0.6180, 0.2539],
        [0.5575, 0.3262]], grad_fn=<MmBackward0>)

In [154]:
class SelfAttention_v2(nn.Module):
    """Implementation using nn.Linear instead of matrix multiplication"""
    def __init__(self, d_in, d_out, bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
    
    def forward(self, x:torch.Tensor):
        q = self.W_query(x)
        k = self.W_keys(x)
        v = self.W_value(x)

        # mul q & k
        attn_scores = q @ k.T
        
        # normalize
        attn_weights = (attn_scores * k.shape[1]**-0.5).softmax(dim=-1)

        out = attn_weights @ v
        return out

In [155]:
torch.manual_seed(42)
attn_v2 = SelfAttention_v2(3, 2)
attn_v2(inputs)

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

#### **Causal Attention**


In [186]:
class CausalSelfAttention(nn.Module):
    """Decoder Only Attention"""
    def __init__(self, d_in:int, d_out:int, bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
    
    def forward(self, x:torch.Tensor):
        q: torch.Tensor = self.W_query(x)
        k: torch.Tensor = self.W_key(x)
        v: torch.Tensor = self.W_value(x)

        # compute attn scores
        attn_scores = q @ k.T

        # Since we can only attend to what is previously shown
        attn_mask = ~torch.ones(attn_scores.shape).tril().bool()
        # We need to now mask with (-inf) so softmax only deals with what we have
        masked_attn = attn_scores.masked_fill(attn_mask, -torch.inf)
        # Now scaled norm
        attn_weights = (masked_attn * k.shape[1]**-0.5).softmax(dim=-1)

        out = attn_weights @ v
        return out




In [187]:
torch.manual_seed(42)
attn_v3 = CausalSelfAttention(3,2)
attn_v3(inputs)


tensor([[0.4429, 0.1077],
        [0.4656, 0.2597],
        [0.4732, 0.3030],
        [0.4135, 0.2921],
        [0.4078, 0.2567],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

**Add in Dropout & Make it more compact**


In [255]:
class CausalAttention(nn.Module):
    """Decoder Only Attention"""
    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
        self.dropout = nn.Dropout(dropout) 
        # this is really a mask but GPT-2 and hf have buffer
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x:torch.Tensor):
        # create dimensions
        _, num_tokens, _ = x.shape # (B, tokens, )
        # print(num_tokens)
        q: torch.Tensor = self.W_query(x)
        k: torch.Tensor = self.W_key(x)
        v: torch.Tensor = self.W_value(x)

        # compute attn scores
        attn_scores = q @ k.transpose(1,2)

        # New, _ ops are in-place
        # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  

        attn_weights = torch.softmax(attn_scores * keys.size(-1)**-0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        out = attn_weights @ v
        return out


In [251]:
# set an example number of tokens
token_count = 6 
inputs = torch.randn((token_count, 3))
print(inputs)

batch = torch.stack((inputs, inputs), dim=0)
batch.shape

tensor([[-0.3885, -0.9343,  1.0533],
        [ 0.1388, -0.2044, -2.2685],
        [-0.9133, -0.4204,  1.3111],
        [-0.2199,  0.1838,  0.2293],
        [ 0.6177, -0.2876,  0.8218],
        [ 0.1512,  0.1036, -2.1996]])


torch.Size([2, 6, 3])

In [252]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

6
tensor([[[ 0.2944,  0.8538],
         [ 0.4221, -0.2458],
         [ 0.3523,  0.3993],
         [ 0.2727,  0.2175],
         [ 0.0894,  0.3412],
         [ 0.2619, -0.1882]],

        [[ 0.2944,  0.8538],
         [ 0.4221, -0.2458],
         [ 0.3523,  0.3993],
         [ 0.2727,  0.2175],
         [ 0.0894,  0.3412],
         [ 0.2619, -0.1882]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


#### **Multi-head Attention**

Adding multiple heads is essentially stacking multiple layers of single attention heads.
This creates multiple latent spaces, so the model can attend to different to different information.


In [256]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, num_heads=1, bias=False):
        super().__init__()

        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, bias) for _ in range(num_heads)]
        )
    
    def forward(self, x:torch.Tensor):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [257]:
# set an example number of tokens
token_count = 6
inputs = torch.randn((token_count, 3))
print(inputs)

batch = torch.stack((inputs, inputs), dim=0)
batch.shape

tensor([[-0.1690,  0.9178, -0.3885],
        [-0.9343, -0.4991, -1.0867],
        [ 0.9624,  0.2492, -0.4845],
        [-2.0929,  0.0983, -0.0935],
        [ 0.2662, -0.5850, -0.3430],
        [-0.6821, -0.9887, -1.7018]])


torch.Size([2, 6, 3])

In [269]:
torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = inputs.size(-1), 4
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=4
)

context_vecs = mha(batch)

print(context_vecs[0,:,:])
print("context_vecs.shape:", context_vecs.shape)

tensor([[ 0.3116,  0.5603,  0.0720,  0.4299,  0.6344, -0.1544,  0.1832,  0.2823,
          0.0813,  0.5999, -0.2545, -0.1395, -0.3446,  0.2793, -0.1714, -0.0352],
        [ 0.3746,  0.2453, -0.4253,  0.0037,  0.5332,  0.0141,  0.3457, -0.0426,
         -0.5745,  0.4919,  0.0050, -0.2978,  0.0199,  0.2265, -0.0673, -0.2861],
        [ 0.2100,  0.4458, -0.1023,  0.1582,  0.4404,  0.0066,  0.4027, -0.2082,
         -0.0324,  0.3846, -0.0236, -0.3122, -0.0088,  0.0382, -0.0439, -0.3194],
        [ 0.3731, -0.0134, -0.5013, -0.0939,  0.4094, -0.0367,  0.2778, -0.0520,
         -0.6470,  0.3848, -0.0049, -0.1563, -0.0135,  0.2293, -0.0666, -0.1699],
        [ 0.2431,  0.0763, -0.3272, -0.0530,  0.2710,  0.0839,  0.2938, -0.1943,
         -0.3923,  0.2629,  0.0495, -0.2164,  0.0562,  0.1424, -0.0290, -0.1906],
        [ 0.3435, -0.0360, -0.5659, -0.1846,  0.3624, -0.0688,  0.2670, -0.0778,
         -0.4533,  0.2617,  0.1422, -0.3615,  0.1821,  0.1153, -0.0007, -0.3660]],
       grad_fn=<Slice

**Adding Weight Splits**

Instead of creating a module list of attention heads, we instead created one MultiHeaded Attention implementation,
where we define one q,k & v weight matrix, and split then to obtain the separate heads.


In [270]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, num_heads=1, bias=False):
        super().__init__()
        assert(d_out % num_heads == 0), "The output dimension must be divisible by the number of heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # This reduces the projection dimention to match the desired output.

        # Define q, k, v. Remember these...
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)

        self.o_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)

        # Create mask for decoder block
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
    

    def forward(self, x:torch.Tensor):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

In [274]:

# set an example number of tokens
token_count = 6
inputs = torch.randn((token_count, 3))
print(inputs)

batch = torch.stack((inputs, inputs), dim=0)
batch.shape

tensor([[-0.1606, -0.4015, -0.4845],
        [-2.0929, -0.8199, -0.4210],
        [-0.9620,  1.2825,  0.8768],
        [ 1.6221, -0.9887, -1.7018],
        [-0.7498, -1.1285,  0.4135],
        [ 0.2892,  2.2473, -0.8036]])


torch.Size([2, 6, 3])

In [275]:
torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = inputs.size(-1), 4
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=4
)

context_vecs = mha(batch)

print(context_vecs[0,:,:])
print("context_vecs.shape:", context_vecs.shape)

tensor([[ 0.0928,  0.0065, -0.3117, -0.1939,  0.0582,  0.1435,  0.3198, -0.4179,
         -0.4708,  0.0620,  0.1877, -0.2482,  0.2417, -0.0657,  0.0599, -0.3094],
        [ 0.3648, -0.6721, -1.0088, -0.6346, -0.0393,  0.3453,  0.3795, -0.5872,
         -1.2055,  0.0749,  0.2064, -0.0673,  0.3268,  0.2618,  0.0416, -0.1495],
        [ 0.2763, -0.4179, -0.5787, -0.3014, -0.1499,  0.5634,  0.3281, -0.6037,
         -0.5574,  0.1637, -0.1271,  0.2432, -0.0517,  0.4222, -0.0752,  0.2217],
        [ 0.2185, -0.1863, -0.3789, -0.1514,  0.0727,  0.1664,  0.0779, -0.0478,
         -0.2971,  0.1395,  0.3470, -0.6590,  0.1720,  0.0974,  0.0172, -0.1974],
        [ 0.1482, -0.4734, -0.4635, -0.3274, -0.0259,  0.1764,  0.1157, -0.1898,
         -0.8554, -0.0261,  0.2361, -0.1326,  0.2441,  0.0483,  0.0547, -0.1571],
        [ 0.1668,  0.7374,  0.0127,  0.2508, -0.0969,  0.3478,  0.3665, -0.6184,
          0.2565,  0.6251, -0.3679, -0.0152, -0.1581,  0.1714, -0.0993, -0.1328]],
       grad_fn=<Slice