In [1]:
import torch
from torch import nn
from torch.functional import F

# Softmax, SafeSoftmax, OnlineSoftmax

In [52]:
x = torch.tensor(range(-3, 3), dtype=torch.float16)
xlimt = torch.tensor(12, dtype=torch.float16)
xlimt.exp()


tensor(inf, dtype=torch.float16)

In [53]:
x / x.exp().sum()

tensor([-0.2571, -0.1715, -0.0858,  0.0000,  0.0858,  0.1715],
       dtype=torch.float16)

In [54]:
x_calibrated = x - x.max()
x_calibrated.exp() / x_calibrated.exp().sum()


tensor([0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338], dtype=torch.float16)

In [55]:
torch.softmax(x, dim=-1)

tensor([0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338], dtype=torch.float16)

# Self-attention

In [11]:
class NaiveSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.scale = self.embed_size ** 0.5
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)
        self.out_proj = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        B, N, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = q.view(B, N, self.embed_size)
        k = k.view(B, N, self.embed_size)
        v = v.view(B, N, self.embed_size)
        
        attn = q @ k.transpose(-2, -1) / self.scale
        # each row, the sum of the softmaxed values is 1; for each query, it's dot product with all keys should be sum of 1
        attn = attn.softmax(dim=-1)
        out = attn @ v  
        return out
        

In [12]:
x = torch.randn(1, 3, 4)
naive_self_attention = NaiveSelfAttention(4)
naive_self_attention(x)

tensor([[[-0.4288,  0.4835,  0.8056,  0.7329],
         [-0.3759,  0.4754,  0.8185,  0.6988],
         [-0.3261,  0.4747,  0.8004,  0.6714]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
class CausalAttention(nn.Module):
    def __init__(self, embed_size, context_length):
        super().__init__()
        self.embed_size = embed_size
        self.scale = self.embed_size ** 0.5
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, N, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = q.view(B, N, self.embed_size)
        k = k.view(B, N, self.embed_size)
        v = v.view(B, N, self.embed_size)
        
        attn = q @ k.transpose(-2, -1) / self.scale
        # self.mask.bool()[:N, :N] is a boolean tensor of shape (N, N); only the upper triangular part is True (does not include the diagonal), hence the -inf will the upper triangular part; we kept the values of the lower triangular part, including the diagonal as desired
        attn = attn.masked_fill_(self.mask.bool()[:N, :N], float('-inf'))
        # each row, the sum of the softmaxed values is 1; for each query, it's dot product with all keys should be sum of 1
        attn = attn.softmax(dim=-1)
        out = attn @ v  
        y = self.out_proj(out)
        return y
        

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, context_length):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.head_dim = embed_size // num_heads
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
        
    def forward(self, x):
        B, N, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        # split the query, key, and value into multiple heads
        q = q.view(B, N, self.num_heads, self.head_dim)        
        k = k.view(B, N, self.num_heads, self.head_dim)
        v = v.view(B, N, self.num_heads, self.head_dim)
        
        
        # change the shape to (B, num_heads, N, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # compute attention
        attn = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn = attn.masked_fill_(self.mask.bool()[:N, :N], float('-inf'))
        attn = attn.softmax(dim=-1)
        # (B, num_heads, N, N) @ (B, num_heads, N, head_dim) -> (B, num_heads, N, head_dim)
        out = attn @ v
        
        # change the shape to (B, N, embed_size)
        out = out.transpose(1, 2).contiguous().view(B, N, self.embed_size)
        y = self.out_proj(out)
        return y
        
        
        

In [3]:
context_length = 6
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
N = 3
attn = torch.randn(N, N)
attn = attn.masked_fill_(mask[:N, :N].bool(), float('-inf'))
attn


tensor([[ 0.4651,    -inf,    -inf],
        [ 1.0145,  0.1376,    -inf],
        [-1.2505,  0.4200,  1.4277]])

In [5]:
mask[:N, :N]

tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [4]:
attn

tensor([[ 0.4651,    -inf,    -inf],
        [ 1.0145,  0.1376,    -inf],
        [-1.2505,  0.4200,  1.4277]])

## GQA

In [None]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_size, num_query_heads, num_kv_heads, context_length):
        """
        Grouped Query Attention (GQA)
        
        Args:
            embed_size (int): Total embedding dimension.
            num_query_heads (int): Number of query heads.
            num_kv_heads (int): Number of key/value heads (usually fewer).
            context_length (int): Maximum sequence length.
        """
        super().__init__()
        assert embed_size % num_query_heads == 0, "embed_size must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, "num_query_heads must be divisible by num_kv_heads"

        self.embed_size = embed_size
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.group_size = num_query_heads // num_kv_heads
        self.head_dim = embed_size // num_query_heads

        # projections
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_size, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_size, embed_size)

        # causal mask
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, N, _ = x.shape

        # project
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # reshape to heads
        q = q.view(B, N, self.num_query_heads, self.head_dim).transpose(1, 2)  # (B, num_query_heads, N, head_dim)
        k = k.view(B, N, self.num_kv_heads, self.head_dim).transpose(1, 2)     # (B, num_kv_heads, N, head_dim)
        v = v.view(B, N, self.num_kv_heads, self.head_dim).transpose(1, 2)     # (B, num_kv_heads, N, head_dim)

        # expand k, v to match grouped queries
        if self.group_size > 1:
            k = k.repeat_interleave(self.group_size, dim=1)
            v = v.repeat_interleave(self.group_size, dim=1)

        # scaled dot-product attention
        attn = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn = attn.masked_fill_(self.mask.bool()[:N, :N], float('-inf'))
        attn = torch.softmax(attn, dim=-1)

        out = attn @ v  # (B, num_query_heads, N, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, N, self.embed_size)
        y = self.out_proj(out)
        return y

In [11]:
k = torch.randn(1, 2, 3)
print(k.shape)
print(k)
k_repeat =k.repeat_interleave(2, dim=1)
print(k_repeat.shape)
print(k_repeat)

torch.Size([1, 2, 3])
tensor([[[ 1.1628, -0.5913,  0.2090],
         [-0.0834, -1.1340, -0.3211]]])
torch.Size([1, 4, 3])
tensor([[[ 1.1628, -0.5913,  0.2090],
         [ 1.1628, -0.5913,  0.2090],
         [-0.0834, -1.1340, -0.3211],
         [-0.0834, -1.1340, -0.3211]]])


# Tensor Parallelism

$$Y_{m\times 2d} = X_{m\times n} @ W_{n\times 2d}$$

Let tp_size = 2, then 

$$[Y^1_{m\times d}, Y^2_{m\times d}] = [X_{m\times n} @ W^1_{n\times d}, X_{m\times n} @ W^2_{n\times d}]$$







In [None]:
# simplified column parallel linear
torch.manual_seed(0)    
input_dim = 2
output_dim = 4
W = torch.randn(input_dim, output_dim)
X = torch.randn(3, input_dim)
Y = X @ W
print('Y.shape', Y.shape)

tp_size = 2
W_tp = torch.split(W, output_dim // tp_size, dim=1)
Ys = [X @ W_tp[i] for i in range(tp_size)]
Ys_gathered = torch.cat(Ys, dim=1)
print('Ys_gathered.shape', Ys_gathered.shape)

torch.allclose(Y, Ys_gathered)


Y.shape torch.Size([3, 4])
Ys_gathered.shape torch.Size([3, 4])


True