In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

<figure>
  <figcaption>Attention between Transformer tokens</figcaption>
  <img src="https://drive.google.com/uc?export=view&id=1zg2ZJ8Vnuxth41gGeGg41NydFyDr8TSq" alt="Diagram of attention between Transformer tokens" width="500">
</figure>

<figure>
  <figcaption>Attention Maths</figcaption>
  <img src="https://drive.google.com/uc?export=view&id=1TVgwT5BTpI_5hZBdm0RxaaLz_Qwgdlb6" alt="Scaled dot-product attention equations" width="250">
</figure>

In [None]:
class SimpleAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are produced from the same input by three
    independent ``nn.Linear(embed_size, embed_size)`` projections, and the
    result is ``softmax(Q K^T / sqrt(embed_size)) V``.

    Args:
        embed_size: width of the input features and of every projection.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        # one projection per role; creation order (q, k, v) fixes parameter order
        self.query_linear = nn.Linear(embed_size, embed_size)
        self.key_linear = nn.Linear(embed_size, embed_size)
        self.value_linear = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Attend every position of ``x`` to every position of ``x``.

        Args:
            x: tensor of shape (batch_size, seq_length, embed_size).

        Returns:
            (out, attention_weights): ``out`` is (bsz, seq, emb) and
            ``attention_weights`` is (bsz, seq, seq), each row summing to 1.
        """
        queries, keys, values = (
            projection(x)
            for projection in (self.query_linear, self.key_linear, self.value_linear)
        )

        # similarity of each query with each key, scaled by sqrt(d) so the
        # softmax does not saturate as the embedding grows
        scale = self.embed_size ** 0.5
        scores = queries @ keys.transpose(-2, -1) / scale  # (bsz, seq, seq)
        weights = scores.softmax(dim=-1)                    # normalise over keys

        return weights @ values, weights

# usage: single-head self-attention on a random (batch=2, seq=3, emb=4) input
batch_size = 2
seq_length = 3
embed_size = 4
x = torch.randn((batch_size, seq_length, embed_size))
attention_layer = SimpleAttention(embed_size)
out, attention_weights = attention_layer(x)

for caption, value in (
    ("Input:", x),
    ("Output:", out),
    ("Attention Weights:", attention_weights),
):
    print(caption, value)

Input:
 tensor([[[ 0.9393,  1.1324,  0.1013,  1.4069],
         [-0.7529,  1.5224, -1.0648, -0.5474],
         [ 0.1634, -0.8008, -0.3369, -0.4246]],

        [[-0.5970, -1.1192, -1.0540,  0.4191],
         [ 1.8294, -0.3668, -0.3958,  1.0895],
         [-0.2905, -0.2080,  1.6845, -0.9910]]])
Output:
 tensor([[[ 0.4623,  0.4141,  0.2415,  0.2815],
         [ 0.4483,  0.3984,  0.2311,  0.2502],
         [ 0.5254,  0.3633,  0.2413,  0.3072]],

        [[ 0.3711,  0.0515,  0.1035, -0.1709],
         [ 0.4296, -0.0672,  0.0026, -0.1290],
         [ 0.3522, -0.1485, -0.2304, -0.1724]]], grad_fn=<UnsafeViewBackward0>)
Attention Weights:
 tensor([[[0.3940, 0.3181, 0.2878],
         [0.4007, 0.2846, 0.3147],
         [0.3259, 0.3377, 0.3364]],

        [[0.3175, 0.4376, 0.2449],
         [0.3682, 0.3157, 0.3161],
         [0.2334, 0.2921, 0.4746]]], grad_fn=<SoftmaxBackward0>)


<figure>
  <figcaption>Multi Head Attention</figcaption>
  <img src="https://drive.google.com/uc?export=view&id=1rs5-ptxkx2ycBfll7eT3-eKclpb1MkH8" alt="Multi-head attention diagram" width="500">
</figure>

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The embedding is projected to Q/K/V, split into ``num_heads`` heads of
    width ``head_dim = embed_size // num_heads``. Each head attends
    independently, and the heads are concatenated and mixed by ``fc_out``.

    Args:
        embed_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        assert embed_size % num_heads == 0, "Embedding size must be divisible by number of heads"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        self.query_linear = nn.Linear(embed_size, embed_size)
        self.key_linear = nn.Linear(embed_size, embed_size)
        self.value_linear = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, tensor, batch_size, seq_length):
        """Reshape (bsz, seq, emb) into (bsz, num_heads, seq, head_dim)."""
        per_head = tensor.view(batch_size, seq_length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Run multi-head self-attention over ``x``.

        Args:
            x: tensor of shape (batch_size, seq_length, embed_size).

        Returns:
            (out, attention_weights): ``out`` is (bsz, seq, emb) and
            ``attention_weights`` is (bsz, num_heads, seq, seq).
        """
        batch_size, seq_length, embed_size = x.size()

        q, k, v = [
            self._split_heads(projection(x), batch_size, seq_length)
            for projection in (self.query_linear, self.key_linear, self.value_linear)
        ]

        # per-head scaling uses head_dim, not embed_size
        logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = F.softmax(logits, dim=-1)  # (bsz, num_heads, seq, seq)

        # (bsz, num_heads, seq, head_dim) -> (bsz, seq, emb): concatenate heads
        merged = (
            torch.matmul(weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, embed_size)
        )
        return self.fc_out(merged), weights

# usage: 2-head self-attention on a random (batch=2, seq=3, emb=4) input
batch_size = 2
seq_length = 3
embed_size = 4
num_heads = 2
x = torch.randn((batch_size, seq_length, embed_size))
multi_head_attention_layer = MultiHeadAttention(embed_size, num_heads)
out, attention_weights = multi_head_attention_layer(x)

for caption, value in (
    ("Input:", x),
    ("Output:", out),
    ("Attention Weights:", attention_weights),
):
    print(caption, value)

<figure>
  <figcaption>Cross Attention</figcaption>
  <img src="https://drive.google.com/uc?export=view&id=1lhGJrpdrNF2lu8URGcUdUomngPichK5i" alt="Cross attention diagram" width="500">
</figure>


In [None]:
class CrossAttention(nn.Module):
    """Single-head cross-attention: queries from ``x``, keys/values from ``context``.

    Args:
        embed_size: feature width shared by ``x``, ``context`` and all
            three projections.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.query_linear = nn.Linear(embed_size, embed_size)
        self.key_linear = nn.Linear(embed_size, embed_size)
        self.value_linear = nn.Linear(embed_size, embed_size)

    def forward(self, x, context):
        """Let each position of ``x`` attend over every position of ``context``.

        Args:
            x: (bsz, seq, emb) tensor providing the queries.
            context: (bsz, context_seq, emb) tensor providing keys and values.

        Returns:
            (out, attention_weights): ``out`` is (bsz, seq, emb) and
            ``attention_weights`` is (bsz, seq, context_seq).
        """
        queries = self.query_linear(x)
        keys = self.key_linear(context)
        values = self.value_linear(context)

        # scaled similarities between the sequence and the context
        logits = torch.matmul(queries, keys.transpose(-2, -1)).div(self.embed_size ** 0.5)
        probs = logits.softmax(dim=-1)  # normalise over context positions

        return torch.matmul(probs, values), probs

# usage: a length-3 sequence attending over a length-5 context (emb=4)
batch_size = 2
seq_length = 3
context_length = 5
embed_size = 4
x = torch.randn((batch_size, seq_length, embed_size))
context = torch.randn((batch_size, context_length, embed_size))
attention_layer = CrossAttention(embed_size)
out, attention_weights = attention_layer(x, context)

for caption, value in (
    ("Input:", x),
    ("Context:", context),
    ("Output:", out),
    ("Attention Weights:", attention_weights),
):
    print(caption, value)

<figure>
  <figcaption>Grouped Query Attention</figcaption>
  <img src="https://drive.google.com/uc?export=view&id=1g9l2FFRxJHAXwnpMSE8n05w3fqQNkySn" alt="Grouped Query Attention Image" width="500">
</figure>


In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) over a single input sequence.

    There are ``num_heads`` full-width query heads, split into
    ``num_query_groups`` groups. All query heads in a group share one
    key/value head, so the key and value projections produce only
    ``num_query_groups * head_dim`` features instead of ``embed_size``.

    Special cases:
      * ``num_query_groups == num_heads``: standard multi-head attention.
      * ``num_query_groups == 1``: multi-query attention (one shared K/V head).

    Args:
        embed_size: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads.
        num_query_groups: number of query-head groups, i.e. number of
            key/value heads; must divide ``num_heads``.
    """

    def __init__(self, embed_size, num_heads, num_query_groups):
        super(GroupedQueryAttention, self).__init__()
        assert embed_size % num_heads == 0, "Embedding size must be divisible by number of heads"
        assert num_heads % num_query_groups == 0, "Number of heads must be divisible by number of query groups"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.num_query_groups = num_query_groups
        self.head_dim = embed_size // num_heads
        # number of query heads that share a single key/value head
        self.group_size = num_heads // num_query_groups
        # width of the key/value projections: one head_dim slice per group
        self.kv_dim = num_query_groups * self.head_dim
        # kept for backward compatibility with the earlier version of this
        # class; not used by forward()
        self.group_dim = embed_size // num_query_groups

        self.query_linear = nn.Linear(embed_size, embed_size)
        self.key_linear = nn.Linear(embed_size, self.kv_dim)
        self.value_linear = nn.Linear(embed_size, self.kv_dim)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Run grouped-query self-attention over ``x``.

        Args:
            x: tensor of shape (batch_size, seq_length, embed_size).

        Returns:
            (out, attention_weights): ``out`` is (bsz, seq, emb) and
            ``attention_weights`` is (bsz, num_heads, seq, seq), one map
            per query head.
        """
        batch_size, seq_length, embed_size = x.size()

        # (bsz, seq, emb) -> (bsz, num_heads, seq, head_dim)
        Q = self.query_linear(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # (bsz, seq, kv_dim) -> (bsz, num_query_groups, seq, head_dim)
        K = self.key_linear(x).view(batch_size, seq_length, self.num_query_groups, self.head_dim).transpose(1, 2)
        V = self.value_linear(x).view(batch_size, seq_length, self.num_query_groups, self.head_dim).transpose(1, 2)

        # Query head h belongs to group h // group_size. Repeating each K/V head
        # group_size times along the head axis lines the groups up with Q.
        K = K.repeat_interleave(self.group_size, dim=1)  # (bsz, num_heads, seq, head_dim)
        V = V.repeat_interleave(self.group_size, dim=1)  # (bsz, num_heads, seq, head_dim)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (bsz, num_heads, seq, seq)
        attention_weights = F.softmax(attention_scores, dim=-1)  # (bsz, num_heads, seq, seq)
        out = torch.matmul(attention_weights, V)  # (bsz, num_heads, seq, head_dim)

        # concat heads back to (bsz, seq, emb) and mix them
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_size)
        out = self.fc_out(out)
        return out, attention_weights

# Usage: 2 query heads sharing 1 key/value group on a random (2, 3, 4) input
batch_size = 2
seq_length = 3
embed_size = 4
num_heads = 2
num_query_groups = 1
x = torch.randn((batch_size, seq_length, embed_size))
grouped_query_attention_layer = GroupedQueryAttention(embed_size, num_heads, num_query_groups)
out, attention_weights = grouped_query_attention_layer(x)

for caption, value in (
    ("Input:", x),
    ("Output:", out),
    ("Attention Weights:", attention_weights),
):
    print(caption, value)

Ghost Attention (GAtt) works together with the self-attention mechanism used in Transformer models. It is not itself a replacement for self-attention; rather, it is a way to give the self-attention mechanism better data, so that the model remembers instructions given early on over longer contexts.

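<figure>
  <figcaption>Ghost Attention (GAtt)</figcaption>
  <!-- TODO: the Drive file id below looks truncated compared to the other images; restore the full id so the image loads -->
  <img src="https://drive.google.com/uc?export=view&id=1hK5i" alt="GAtt Image" width="500">
</figure>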