In [None]:
import torch
import torch.nn as nn

#### **Simplified Self-Attention**


In [2]:
# Toy 3-dimensional embeddings for the six tokens of "Your journey starts with one step".
# Shape: (num_tokens=6, embedding_dim=3); one row per token.
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

**Single Input Attention Calculation**


In [None]:
# The query is the embedding of the second token, "journey".
query = inputs[1]

# Attention score for token i = dot product of its embedding with the query.
# Dot product example:
#   torch.Tensor([1,2,3,4,5]).dot(torch.Tensor([2,1,2,1,2]))
#   (1*2)+(2*1)+(3*2)+(4*1)+(5*2) = tensor(24.)
attn_scores_2 = torch.empty(inputs.shape[0])
for token_idx in range(len(inputs)):
    attn_scores_2[token_idx] = torch.dot(inputs[token_idx], query)

attn_scores_2


In [None]:
# Naive normalization: divide each score by the total so the weights sum to 1.
attn_scores_norm = attn_scores_2 / torch.sum(attn_scores_2)
print(f"Unnormalized: {attn_scores_2}")
print(f"Normalized: {attn_scores_norm}")

In [None]:
# Softmax is a better normalizer than dividing by the sum: the weights are always positive and sum to 1.
# Naive version: exp() of large (or very negative) scores can overflow (or underflow).
exp_scores = attn_scores_2.exp()
print(exp_scores / exp_scores.sum())

# PyTorch's built-in softmax is preferred over the naive version above.
print(torch.softmax(attn_scores_2, dim=0))

In [None]:
# Context vector for the query: the attention-weighted sum of all input embeddings.
attn_weights_2 = attn_scores_2.softmax(dim=0)
inputs.T @ attn_weights_2

**Full Attention Calculation**


In [100]:
# Scores for every query at once: entry [i, j] = dot(x_i, x_j).
attn_scores = torch.matmul(inputs, inputs.T)

# Softmax along each row, so every query's weights sum to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)

In [None]:
# Context vectors for all tokens: row i is the attention-weighted average of the input rows.
torch.matmul(attn_weights, inputs)

#### **Self-Attention**


In [103]:
d_out = 2                # size of the projected query/key/value vectors
d_in = inputs.size(-1)   # input embedding size (3 for the toy inputs)
x_2 = inputs[1]          # embedding of the second token, "journey"

In [133]:
torch.manual_seed(42)
# Three frozen random (d_in, d_out) projections, created in query -> key -> value order:
#   W_query: what a token is "interested in"
#   W_key:   what a token has to offer
#   W_value: what a token communicates when it is found "interesting"
W_query, W_key, W_value = (
    nn.Parameter(torch.randn(d_in, d_out), requires_grad=False) for _ in range(3)
)

In [134]:
# Project every token into query, key and value space; each result is (num_tokens, d_out).
query, keys, values = (inputs @ W for W in (W_query, W_key, W_value))

In [135]:
qk_attn = torch.matmul(query, keys.T)  # (num_tokens, num_tokens) raw attention scores

In [136]:
# Scale by 1/sqrt(d_k), where d_k is the key dimension, so larger key sizes don't push
# softmax into an overly peaked distribution; then normalize each row with softmax.
d_k = keys.shape[1]
qk_norm = torch.softmax(qk_attn * d_k**-0.5, dim=-1)

In [None]:
# Display the scaled, softmax-normalized attention weights (each row sums to 1).
qk_norm

In [None]:
# Context vectors: attention-weighted combination of the value vectors.
torch.matmul(qk_norm, values)

**Putting it all together**


In [150]:
class SelfAttention_v1(nn.Module):
    """Single-head (non-causal) self-attention with raw parameter matrices.

    Each token is projected with three learned (d_in, d_out) matrices. The value
    vectors are then mixed using softmax-normalized q·k scores scaled by 1/sqrt(d_k).

    Args:
        d_in: size of each input embedding.
        d_out: size of the projected query/key/value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Created in query -> key -> value order so seeded initialization stays reproducible.
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x: torch.Tensor):
        """Return the context vectors for ``x``.

        Written for a 2-D (num_tokens, d_in) input: ``keys.T`` swaps the two axes.
        The result has shape (num_tokens, d_out).
        """
        queries, keys, values = (x @ W for W in (self.W_query, self.W_key, self.W_value))
        scale = keys.shape[1] ** -0.5
        weights = torch.softmax(queries @ keys.T * scale, dim=-1)
        return weights @ values

In [None]:
torch.manual_seed(42)  # fixes the randn-initialized projection matrices
self_attn = SelfAttention_v1(d_in=3, d_out=2)
self_attn(inputs)

In [154]:
class SelfAttention_v2(nn.Module):
    """Single-head self-attention using ``nn.Linear`` projections instead of raw matrices.

    Args:
        d_in: size of each input embedding.
        d_out: size of the projected query/key/value vectors.
        bias: whether the projection layers learn an additive bias (default False).
    """

    def __init__(self, d_in, d_out, bias=False):
        super().__init__()
        # NOTE: the key projection is named ``W_keys`` (plural), unlike ``W_key`` elsewhere;
        # renaming it would change this module's state_dict keys.
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)

    def forward(self, x: torch.Tensor):
        """Return (num_tokens, d_out) context vectors; written for a 2-D (num_tokens, d_in) input."""
        q = self.W_query(x)
        k = self.W_keys(x)
        v = self.W_value(x)
        inv_sqrt_dk = k.shape[1] ** -0.5
        attn_weights = torch.softmax((q @ k.T) * inv_sqrt_dk, dim=-1)
        return attn_weights @ v

In [None]:
torch.manual_seed(42)  # reproducible nn.Linear initialization
attn_v2 = SelfAttention_v2(d_in=3, d_out=2)
attn_v2(inputs)

#### **Causal Attention**


In [186]:
class CausalSelfAttention(nn.Module):
    """Decoder-only (causal) single-head self-attention for an unbatched sequence.

    Each token may only attend to itself and earlier tokens.

    Args:
        d_in: size of each input embedding.
        d_out: size of the projected query/key/value vectors.
        bias: whether the projection layers learn an additive bias (default False).
    """
    def __init__(self, d_in:int, d_out:int, bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)

    def forward(self, x:torch.Tensor):
        """Return (num_tokens, d_out) context vectors.

        Written for a 2-D (num_tokens, d_in) input: ``k.T`` swaps the two axes.
        """
        q: torch.Tensor = self.W_query(x)
        k: torch.Tensor = self.W_key(x)
        v: torch.Tensor = self.W_value(x)

        # Raw attention scores, (num_tokens, num_tokens).
        attn_scores = q @ k.T

        # True strictly above the diagonal = "future" tokens that must not be attended to.
        # Built on the scores' device so masked_fill also works for GPU inputs.
        attn_mask = torch.ones(attn_scores.shape, device=attn_scores.device).triu(diagonal=1).bool()
        # -inf scores become 0 after softmax, so future tokens get no weight.
        masked_attn = attn_scores.masked_fill(attn_mask, -torch.inf)
        # Scale by 1/sqrt(d_k), then normalize each row.
        attn_weights = (masked_attn * k.shape[1]**-0.5).softmax(dim=-1)

        return attn_weights @ v




In [None]:
torch.manual_seed(42)  # reproducible nn.Linear initialization
attn_v3 = CausalSelfAttention(d_in=3, d_out=2)
attn_v3(inputs)


**Add in Dropout & Make it more compact**


In [255]:
class CausalAttention(nn.Module):
    """Batched decoder-only (causal) single-head attention with dropout on the attention weights.

    Args:
        d_in: size of each input embedding.
        d_out: size of the projected query/key/value vectors.
        context_length: longest sequence the registered causal mask supports.
        dropout: dropout probability applied to the attention weights.
        bias: whether the projection layers learn an additive bias (default False).
    """
    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)
        self.dropout = nn.Dropout(dropout)
        # A registered buffer moves with the module (.to(device)) and is saved in the state_dict,
        # but is not a trainable parameter. 1.0 strictly above the diagonal marks future positions.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x:torch.Tensor):
        """Return (batch, num_tokens, d_out) context vectors for a (batch, num_tokens, d_in) input."""
        _, num_tokens, _ = x.shape
        q: torch.Tensor = self.W_query(x)
        k: torch.Tensor = self.W_key(x)
        v: torch.Tensor = self.W_value(x)

        # (batch, num_tokens, num_tokens) scores; transpose only the token/feature axes of k.
        attn_scores = q @ k.transpose(1,2)

        # Trailing-underscore ops are in-place. Cropping to `:num_tokens` supports sequences
        # shorter than context_length.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        # Scale by this module's own key width (the original code read a notebook-global `keys` here).
        attn_weights = torch.softmax(attn_scores * k.size(-1)**-0.5, dim=-1)

        attn_weights = self.dropout(attn_weights)

        out = attn_weights @ v
        return out


In [None]:
# Random example inputs: 6 tokens with 3-dimensional embeddings (seeded so reruns match).
torch.manual_seed(0)
token_count = 6
inputs = torch.randn((token_count, 3))
print(inputs)

# Batch of two identical sequences: (batch=2, num_tokens=6, d_in=3)
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

In [None]:
torch.manual_seed(123)

# Derive the input width from the batch itself, rather than relying on a d_in left over from an earlier cell.
context_length = batch.shape[1]  # number of tokens per sequence
ca = CausalAttention(batch.shape[-1], d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

#### **Multi-head Attention**

Multi-head attention runs several single-head attention modules in parallel over the same input and concatenates their outputs.
Each head has its own projections and therefore its own latent space, so different heads can attend to different information.


In [256]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` heads run side by side.

    Each head outputs ``d_out`` features, so the concatenated result has
    ``num_heads * d_out`` features in its last dimension.
    """

    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, num_heads=1, bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(CausalAttention(d_in, d_out, context_length, dropout, bias=bias))
        self.heads = nn.ModuleList(heads)

    def forward(self, x:torch.Tensor):
        """Run every head on ``x`` and concatenate their outputs along the last dimension."""
        head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)

In [None]:
# Random example inputs: 6 tokens with 3-dimensional embeddings (seeded so reruns match).
torch.manual_seed(0)
token_count = 6
inputs = torch.randn((token_count, 3))
print(inputs)

# Batch of two identical sequences: (batch=2, num_tokens=6, d_in=3)
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

In [None]:
torch.manual_seed(123)

d_in = inputs.size(-1)
d_out = 4
context_length = batch.shape[1]  # number of tokens per sequence
# 4 heads x d_out=4 features each -> 16 output features after concatenation.
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=4)

context_vecs = mha(batch)

print(context_vecs[0])
print("context_vecs.shape:", context_vecs.shape)

**Adding Weight Splits**

Instead of creating a module list of attention heads, we now create a single multi-head attention implementation,
where we define one q, k and v weight matrix each and split them to obtain the separate heads.


##### **To better understand this, let's do two things**

1. Initialize all of the layers, dimensions, etc in a class.
2. Step through the forward pass piece by piece.


In [314]:
class MultiHeadAttentionEX(nn.Module):
    """Holds the layers of a multi-head attention block.

    There is no ``forward`` here; the forward pass is walked through manually, cell by cell.

    Args:
        d_in: size of each input embedding.
        d_out: total output size; split evenly across ``num_heads``.
        context_length: longest sequence the causal mask supports.
        dropout: dropout probability for the attention weights.
        num_heads: number of attention heads (must divide ``d_out``).
        bias: whether the q/k/v projections learn an additive bias.
    """

    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, num_heads=1, bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "The output dimension must be divisible by the number of heads"

        self.d_out, self.num_heads = d_out, num_heads
        # Each head works on an equal slice of d_out.
        self.head_dim = d_out // num_heads

        # Full-width projections; the individual heads are split out later with .view().
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)

        # Mixes the concatenated head outputs (always has a bias, regardless of `bias`).
        self.o_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Causal mask for the decoder block: 1.0 strictly above the diagonal marks future positions.
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future)

In [352]:
torch.manual_seed(0)  # reproducible example inputs
inputs = torch.randn((6,5))  # 6 tokens with 5-dimensional embeddings

In [None]:
# Create the batch inputs: two copies of the same sequence stacked along a new batch dimension.
x_batch = torch.stack([inputs, inputs], dim=0)
print(f"Batch Shape: {x_batch.shape}")
print(f"There are {x_batch.shape[0]} examples of size {x_batch.shape[1]}x{x_batch.shape[2]} in the batch")

In [354]:
# Read the sizes off the batch and build a 2-head module (head_dim = 2 // 2 = 1).
batch_size, context_length, d_in = x_batch.size()
d_out = 2
mha = MultiHeadAttentionEX(d_in, d_out, context_length, dropout=0.0, num_heads=2)

_start of the forward pass_


In [355]:
# forward(x): think of the forward pass starting here.
x = x_batch
b, num_tokens, d_in = x.shape

# Full-width projections, each of shape (b, num_tokens, d_out).
keys, queries, values = mha.W_key(x), mha.W_query(x), mha.W_value(x)

_split the matrices using `.view()`_


In [None]:
print(f"Previous key dimension: {keys.shape}")
# Split d_out into (num_heads, head_dim): (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
head_split_shape = (b, num_tokens, mha.num_heads, mha.head_dim)
keys = keys.view(head_split_shape)
print(f"Post split key dimension: {keys.shape}")

**So what did we do here?**

`.view()` returns a view of an existing tensor's data with a new shape (no data is copied). Here we use it to split the output dimension into two dimensions: (number of heads, head dimension).

_Example:_

- Output dimension = 16 | Attention heads = 2 | Key matrix dimensions are 2x6x16 (batch, context_length, d_out)
- New head dimension is 16 // 2 => 8 (`//` is Python's integer (floor) division, so the head dimension is a whole number)
- So we now want the key matrix to look like this: 2x6x(attn_heads)x(head_dim) => 2x6x2x8

- So you can think of the last two dimensions as together forming the original output dimension (2 * 8 = 16).

- What happens if we increase the number of attention heads to 4?
- New head dimension is 16 // 4 => 4
- So we get a key matrix like 2x6x4x4 (note how the product of the last two dimensions == output dimension)


**Let's rework the example using the example numbers**


Let's define our example inputs to have 8 tokens (context_length of 8) and have an input dimension of 6


In [None]:
torch.manual_seed(0)  # reproducible example inputs
tokens, in_dim = 8, 6
inputs_v2 = torch.randn((tokens, in_dim))
# Create the batch inputs: two copies of the same sequence.
x_batch = torch.stack([inputs_v2, inputs_v2], dim=0)
print(f"Batch Shape: {x_batch.shape}")
print(f"There are {x_batch.shape[0]} examples of size {x_batch.shape[1]}x{x_batch.shape[2]} in the batch")
batch_size, context_length, d_in = x_batch.shape
d_out, num_heads = 16, 4

In [392]:
batch_size, context_length, d_in = x_batch.shape
d_out, num_heads = 16, 4

# Pass the num_heads variable (not a repeated literal) so editing it actually changes the module.
mha = MultiHeadAttentionEX(d_in, d_out, context_length, 0.05, num_heads=num_heads)

In [393]:
# forward(x): think of the forward pass starting here.
x = x_batch
b, num_tokens, d_in = x.shape

# Full-width projections (b, num_tokens, d_out); the heads are split out next.
keys, queries, values = (proj(x) for proj in (mha.W_key, mha.W_query, mha.W_value))

_split the matrices_


In [None]:
print(f"Previous key dimension: {keys.shape}")
# Split d_out into (num_heads, head_dim): (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
head_split_shape = (b, num_tokens, mha.num_heads, mha.head_dim)
keys = keys.view(head_split_shape)
print(f"Post split key dimension: {keys.shape}")

_do the same for the query and the values_


In [395]:
# Apply the same (b, num_tokens, num_heads, head_dim) split to the queries and values.
queries, values = (t.view(b, num_tokens, mha.num_heads, mha.head_dim) for t in (queries, values))

_Now we need to Transpose the key matrix_

Wait why are we doing this now?

- Our initial dimensions were (batch, context_length, number of attn heads, dim of the heads)
- We want to swap the context length with the number of attention heads, so:
  - (batch, number of attn heads, context_length, dim of the heads)
- If you cannot see why this is happening now, we will cover it in a second.


In [396]:
# Swap tokens and heads: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys, queries, values = (t.transpose(1, 2) for t in (keys, queries, values))

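We transpose so that the heads dimension sits next to the batch dimension. The attention-score matrix multiplication then operates on the last two dimensions (tokens x head_dim), and PyTorch's broadcasting handles the leading (batch, heads) dimensions automatically. [What is broadcasting?](https://pytorch.org/docs/stable/notes/broadcasting.html)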


In [None]:
queries.shape, keys.transpose(-2, -1).shape  # (b, heads, tokens, head_dim) vs (b, heads, head_dim, tokens)

Now we compute the attention scores with a matrix multiplication, as we did previously. The multiplication is batched over the first two dimensions (batch, heads), so for each batch element and head we are really multiplying an 8x4 matrix (queries) by a 4x8 matrix (transposed keys).


In [398]:
attn_scores = torch.matmul(queries, keys.transpose(-2, -1))  # (b, num_heads, num_tokens, num_tokens)

In [None]:
# Expect (batch, num_heads, num_tokens, num_tokens).
attn_scores.shape

Just like before, we need to normalize the attention scores into attention weights. As you can see, the rows of these matrices are not yet normalized.


In [None]:
# Raw scores of the first head of the first batch element (before masking).
attn_scores[0,0,:,:]

But first we need to apply the attention masking


In [None]:
# Causal mask cropped to the current sequence length; True marks "future" positions.
mask_bool = mha.mask[:num_tokens, :num_tokens].bool()

# Overwrite future positions with -inf (in place) so softmax gives them zero weight.
attn_scores.masked_fill_(mask_bool, -torch.inf)
print("mask applied")

In [None]:
# Same head after masking: entries above the diagonal are now -inf.
attn_scores[0,0,:,:]

Now let's apply the softmax (after scaling the scores by 1/sqrt(head_dim)).


In [404]:
# Scale by sqrt(head_dim) and normalize each row into attention weights.
scale = keys.shape[-1] ** 0.5
attn_weights = (attn_scores / scale).softmax(dim=-1)

In [None]:
# Each row now sums to 1 and the masked (future) positions are 0.
attn_weights[0,0,:,:]

Now dropout!


In [406]:
# Dropout on the attention weights (active only while mha is in training mode).
attn_weights = mha.dropout(attn_weights)

In [None]:
# After dropout, some weights are zeroed and the survivors are rescaled by 1/(1-p).
attn_weights[0,0,:,:]

Like before we now need to take the attention weights and combine them with our value vectors.\
\
As we can see, we are still broadcasting across the first two dimensions, so it is like we are performing the following:\
8x8 (attn_weights) @ 8x4 (values) -> 8x4 -> with batch dimension for all heads we have 2x4x8x4


In [None]:
# (b, heads, tokens, tokens) @ (b, heads, tokens, head_dim) -> (b, heads, tokens, head_dim)
attn_weights.shape, values.shape

But we want to use a transpose to swap the context_length dimension (8 here) with the number-of-heads dimension, so we can join the heads together in a later step.


In [410]:
weighted_values = attn_weights @ values          # (b, num_heads, num_tokens, head_dim)
context_vec = weighted_values.transpose(1, 2)    # (b, num_tokens, num_heads, head_dim)

In [None]:
# Expect (batch, num_tokens, num_heads, head_dim).
context_vec.shape

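Why do we need `.contiguous()` here?

- `.transpose()` only changes the tensor's strides, so the result is no longer laid out contiguously in memory, and `.view()` needs a compatible layout. `.contiguous()` first copies the data into a contiguous block.
- Check out this [post](https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch)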


In [413]:
# Merge the heads back into d_out = num_heads * head_dim features, then apply the output projection.
# .contiguous() is needed because the transpose above left a non-contiguous layout that .view() rejects.
context_vec = mha.o_proj(context_vec.contiguous().view(b, num_tokens, mha.d_out))

In [None]:
# Final output: (batch, num_tokens, d_out).
context_vec.shape

Finally, we have reached the output of our multi-head attention implementation!


#### **Let's Put It All Together!**


In [438]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention that uses one fused projection per q/k/v and splits it into heads.

    Args:
        d_in: size of each input embedding.
        d_out: total output size; split evenly across ``num_heads``.
        context_length: longest sequence the registered causal mask supports.
        dropout: dropout probability for the attention weights.
        num_heads: number of attention heads (must divide ``d_out``).
        bias: whether the q/k/v projections learn an additive bias.
    """

    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, num_heads=1, bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "The output dimension must be divisible by the number of heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # features per head

        # Full-width projections, created in query -> key -> value order.
        self.W_query = nn.Linear(d_in, d_out, bias=bias)
        self.W_key = nn.Linear(d_in, d_out, bias=bias)
        self.W_value = nn.Linear(d_in, d_out, bias=bias)

        self.o_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)

        # Causal mask: 1.0 strictly above the diagonal marks future positions.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def _split_heads(self, t, b, num_tokens):
        """Reshape (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)."""
        return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x:torch.Tensor):
        """Return (b, num_tokens, d_out) context vectors for a (b, num_tokens, d_in) input."""
        b, num_tokens, _ = x.shape

        keys = self._split_heads(self.W_key(x), b, num_tokens)
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Per-head scaled dot-product scores, (b, num_heads, num_tokens, num_tokens).
        attn_scores = queries @ keys.transpose(2, 3)
        causal = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal, -torch.inf)
        attn_weights = self.dropout(torch.softmax(attn_scores / keys.size(-1)**0.5, dim=-1))

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads into d_out features.
        context = (attn_weights @ values).transpose(1, 2)
        context = context.contiguous().view(b, num_tokens, self.d_out)
        return self.o_proj(context)

In [None]:
torch.manual_seed(0)  # reproducible example inputs
tokens, in_dim = 8, 6
inputs_v2 = torch.randn((tokens, in_dim))
# Create the batch inputs: two copies of the same sequence.
x_batch = torch.stack([inputs_v2, inputs_v2], dim=0)
print(f"Batch Shape: {x_batch.shape}")
print(f"There are {x_batch.shape[0]} examples of size {x_batch.shape[1]}x{x_batch.shape[2]} in the batch")
batch_size, context_length, d_in = x_batch.shape
d_out, num_heads = 16, 4

In [440]:
# Use the num_heads variable from the setup cell instead of repeating the literal 4.
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.1, num_heads=num_heads, bias=False)

In [None]:
# Input shape: (batch, num_tokens, d_in)
x_batch.shape

In [None]:
mha(x_batch).size()  # (batch, num_tokens, d_out)

**Shared Buffers**

- A class used alongside RoPE to reuse the attention mask and the sin/cos tables. They are computed once per configuration and shared by every layer with that configuration instead of being recomputed, which improves efficiency.


In [3]:
from rope import precompute_rope_params
class SharedBuffers:
    """Class-level cache of (mask, cos, sin) tensors keyed by attention/RoPE configuration.

    Layers created with the same configuration receive the same tensor objects,
    so the buffers are computed only once.
    """
    _buffers = {}

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):
        """Return ``(mask, cos, sin)`` for the given configuration, computing it on first use.

        Args:
            context_length: side length of the square causal mask and number of RoPE positions.
            head_dim: per-head feature size passed to ``precompute_rope_params``.
            rope_base: RoPE base frequency.
            freq_config: optional dict of RoPE frequency settings (None for none).
            dtype: if not None, cos/sin are cast to this dtype (the mask is left as float).

        Returns:
            ``mask`` is a (context_length, context_length) tensor with 1.0 strictly above the
            diagonal. ``cos``/``sin`` are whatever ``precompute_rope_params`` returns, optionally cast.
        """
        # Use the config's items (keys AND values) so different configs can't collide.
        # An empty dict maps to () instead of an unhashable {}.
        freq_key = tuple(freq_config.items()) if freq_config is not None else None
        key = (context_length, head_dim, rope_base, freq_key, dtype)

        cached = SharedBuffers._buffers.get(key)
        if cached is None:
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)
            if dtype is not None:
                cos, sin = cos.to(dtype), sin.to(dtype)
            cached = (mask, cos, sin)
            SharedBuffers._buffers[key] = cached

        return cached

In [4]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) layer; this cell defines construction only (no ``forward`` yet).

    Queries use ``num_heads`` heads of size ``head_dim = d_out // num_heads``. Keys and values
    are projected to only ``num_kv_groups`` heads of that size. ``num_kv_groups == 1``
    corresponds to multi-query attention; ``num_kv_groups == num_heads`` to ordinary MHA.

    Args:
        d_in: input embedding size.
        d_out: total query/output size; must be divisible by ``num_heads``.
        context_length: size passed to ``SharedBuffers.get_buffers`` for the mask / RoPE tables.
        num_heads: number of query heads; must be divisible by ``num_kv_groups``.
        num_kv_groups: number of key/value heads.
        rope_base: RoPE base frequency.
        rope_config: optional RoPE frequency config, passed through to the buffer cache.
        dtype: dtype of the Linear layers' parameters; also used to cast cos/sin when not None.

    Raises:
        AssertionError: if either divisibility requirement is violated.
    """
    def __init__(
            self, d_in, d_out, context_length, num_heads,
            num_kv_groups,       # NEW
            rope_base=10_000,    # NEW
            rope_config=None,    # NEW
            dtype=None
        ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        # Set the dimensions of the q, k, v projections
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Key/value projections only produce num_kv_groups heads. MHA would use d_out here.
        # num_kv_groups=1 -> multi-query attention; num_kv_groups=num_heads -> multi-head attention.
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        # Number of query heads per key/value head.
        self.group_size = num_heads // num_kv_groups

        # Query and output projections are the same as in MHA.
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # Reuse cached (mask, cos, sin) tensors shared by layers with the same configuration.
        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)

        self.register_buffer("mask", mask)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

#### Stepping through the forward pass


create fake input tokens


In [None]:
# Settings (sizes similar to an 8B Llama-3-style model)
torch.manual_seed(0)       # reproducible example inputs
batch_size = 2
context_length = 3000      # tokens in this example batch
max_context_len = 8192
embed_dim = 4096
# num_heads must divide embed_dim (4096 // 32 = 128 per head) and be divisible by the
# number of KV groups (8). A value such as 36 would fail both checks in GroupedQueryAttention.
num_heads = 32
x_batch = torch.randn((batch_size, context_length, embed_dim))
# Create the batch inputs
print(f"Batch Shape: {x_batch.shape}")
print(f"There are {x_batch.shape[0]} examples of size {x_batch.shape[1]}x{x_batch.shape[2]} in the batch")

Now that we have a sample of input tokens, let's initialize our weight matrices


In [None]:
num_kv_groups = 8 # can be tuned, but it must evenly divide num_heads
# d_in and d_out must match the embedding size of x_batch; the d_in/d_out defined in the
# multi-head attention section above (6 and 16) belong to a different example.
gqa = GroupedQueryAttention(embed_dim, embed_dim, context_length, num_heads, num_kv_groups, rope_base=500_000, rope_config=None)

In [None]:
# Shape of the example inputs that will be fed through the GQA layer.
x_batch.shape #(Batch, number of tokens, input dimension)