In [3]:
import torch.nn as nn
import torch

In [4]:
# Toy embedding matrix: one 3-dimensional vector per token of the
# example sentence "your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # your
        [0.55, 0.87, 0.66],  # journey
        [0.57, 0.85, 0.64],  # starts
        [0.22, 0.58, 0.33],  # with
        [0.77, 0.25, 0.10],  # one
        [0.05, 0.80, 0.55],  # step
    ]
)

In [5]:
import torch.nn as nn

class SelfAtention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors and of each
            output context vector.
        qkv_bias: Whether the query/key/value linear layers use a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute context vectors for every token in ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any leading
                batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes. For 2-D input this is exactly `keys.T`,
        # but unlike `.T` it also works when a batch dimension is present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row into weights.
        attn_weights = torch.softmax(attn_scores / (self.d_out ** 0.5), dim=-1)
        print(f"attn_weights: \n{attn_weights}")

        context_vec = attn_weights @ values

        return context_vec
        

In [6]:
# Seed before building the module so the randomly initialised projection
# weights (and therefore the printed result) are the same on every run.
torch.manual_seed(789)
sa_v2 = SelfAtention_v2(d_in=3, d_out=2)
context_vecs_v2 = sa_v2(inputs)
print(context_vecs_v2)

attn_weights: 
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [7]:
import torch.nn as nn

class SelfAtention_v3(nn.Module):
    """Step-by-step causal self-attention with dropout, printing each stage.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors and of each
            output context vector.
        qkv_bias: Whether the query/key/value linear layers use a bias term.
        dropout: Probability of zeroing an attention weight during training.
            Defaults to 0.5, the rate used by the original demonstration.
    """

    def __init__(self, d_in, d_out, qkv_bias = False, dropout = 0.5):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Dropout guards against over-reliance on particular weights (the
        # author notes that current LLMs generally do not use it).
        # Registering it as a submodule lets train()/eval() toggle it; a
        # Dropout created inside forward() would always be in training mode.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Compute causally masked context vectors, printing intermediates.

        The global RNG is deliberately not reseeded here: seed it once in the
        calling code (``torch.manual_seed``) if a reproducible dropout mask is
        needed.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any leading
                batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # transpose(-2, -1) equals `.T` for 2-D input and also handles batches.
        attn_scores = queries @ keys.transpose(-2, -1)
        print(f"attn_score: \n{attn_scores}")

        context_length = attn_scores.shape[-1]
        # torch.triu with diagonal=1 keeps the entries strictly above the main
        # diagonal, i.e. the positions of future tokens. The mask is created on
        # the same device as the scores so masked_fill works on GPU as well.
        mask = torch.triu(
            torch.ones(context_length, context_length, device=attn_scores.device),
            diagonal=1,
        )
        print(f"mask: \n{mask}")

        # masked_fill writes -inf wherever the boolean mask is True.
        masked_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)
        print(f"masked: \n{masked_scores}")

        # Softmax over the masked scores; -inf entries become weight 0.
        attn_weights = torch.softmax(masked_scores / (self.d_out ** 0.5), dim=-1)
        print(f"attn_weights: \n{attn_weights}")

        # In training mode the surviving weights are rescaled by 1 / (1 - p)
        # (a factor of 2 for p = 0.5); in eval mode this is the identity.
        attn_weights = self.dropout(attn_weights)
        print(f"attn_weights after dropout: \n{attn_weights}")

        context_vec = attn_weights @ values

        return context_vec
        

In [8]:
# Seed before building the module so the randomly initialised projection
# weights are the same on every run.
torch.manual_seed(789)
sa_v3 = SelfAtention_v3(d_in=3, d_out=2)
context_vecs_v3 = sa_v3(inputs)
print(context_vecs_v3)

attn_score: 
tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)
mask: 
tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
masked: 
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
  

In [5]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention over batched input.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors and of each
            output context vector.
        context_length: Side length of the stored causal mask; forward()
            slices the mask down to the actual number of tokens.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value linear layers use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(p=dropout)

        # A registered buffer follows the module across devices (e.g. to the
        # GPU) without being treated as a trainable parameter.
        upper_triangle = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', upper_triangle)

    def forward(self, x):
        """Return context vectors of shape ``(b, num_tokens, d_out)``.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``.
        """
        _, seq_len, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # With a batch dimension, keys and queries are (b, num_tokens, d_out),
        # so the per-sample transpose swaps dimensions 1 and 2.
        scores = queries @ keys.transpose(1, 2)

        # Hide future positions: entries above the diagonal become -inf.
        future_positions = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(future_positions, -torch.inf)

        weights = torch.softmax(scores / (self.d_out ** 0.5), dim=-1)

        # Dropout is applied to the normalised attention weights.
        # Open question left by the author: why not apply it to the raw
        # attention scores instead?
        weights = self.dropout(weights)

        return weights @ values
        

In [None]:
# Build a batch of two identical sequences: shape (2, num_tokens, d_in).
batch = torch.stack((inputs, inputs), dim = 0)
print(f"batch shape: {batch.shape}")

# Seed before constructing the module so its initial weights are reproducible.
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in = 3, d_out=2, context_length = context_length, dropout=0.0, qkv_bias=False)
context_vecs = ca(batch)
print(f"context_vecs shape: {context_vecs.shape}")
# `printf` does not exist in Python and raised NameError; use print().
print(f"context_vecs: \n{context_vecs}")

batch shape: torch.Size([2, 6, 3])
attn_weights: 
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]]],
       grad_fn=<SoftmaxBackward0>)
attn_weights after dropout: 
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000

In [6]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent CausalAttention heads.

    Each head receives the same constructor arguments and the same input;
    the head outputs are concatenated along the last dimension.

    Args:
        d_in: Size of each input token embedding.
        d_out: Output size of each individual head.
        context_length: Passed through to every CausalAttention head.
        dropout: Dropout probability used by every head.
        num_heads: Number of heads to create.
        qkv_bias: Whether the heads' linear layers use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last axis."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [7]:
# Seed before constructing the heads so their initial weights are reproducible.
torch.manual_seed(123)

# Number of tokens in each sequence of the batch.
context_length = batch.shape[1]
# Input embedding size and per-head output size.
d_in = 3
d_out = 1

hma = MultiHeadAttentionWrapper(
    d_in=d_in, d_out=d_out, context_length=context_length, dropout=0.0, num_heads=2
)
context_vecs = hma(batch)

print(f"context_vecs shape: {context_vecs.shape}")
print(f"context_vecs: {context_vecs}")


context_vecs shape: torch.Size([2, 6, 2])
context_vecs: tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)


In [2]:
import torch.nn as nn
import torch

class MultiHeadAttention(nn.Module):
    """Causal multi-head attention using one fused projection per Q, K and V.

    Args:
        d_in: Size of each input token embedding.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the stored causal mask; forward()
            slices the mask down to the actual number of tokens.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value linear layers use a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Width of each individual head.
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Projection layer that mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(p=dropout)

        causal_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', causal_mask)

    def forward(self, x):
        """Return context vectors of shape ``(b, num_tokens, d_out)``.

        Args:
            x: Tensor of shape ``(b, num_tokens, d_in)``.
        """
        batch_size, seq_len, _ = x.shape

        # Split d_out into (num_heads, head_dim), then move the head axis
        # forward: (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim).
        # After this, each head can be processed like single-head attention.
        split_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        queries = self.W_query(x).view(split_shape).transpose(1, 2)
        keys = self.W_key(x).view(split_shape).transpose(1, 2)
        values = self.W_value(x).view(split_shape).transpose(1, 2)

        # Per-head attention scores, with future positions set to -inf.
        scores = queries @ keys.transpose(2, 3)
        future_positions = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(future_positions, -torch.inf)

        weights = torch.softmax(scores / (self.head_dim ** 0.5), dim=-1)
        weights = self.dropout(weights)

        # weights @ values is (b, num_heads, num_tokens, head_dim); transpose
        # back to (b, num_tokens, num_heads, head_dim) and flatten the heads
        # into a single d_out axis.
        per_head = (weights @ values).transpose(1, 2)
        merged = per_head.contiguous().view(batch_size, seq_len, self.d_out)

        # Final linear projection; the shape stays (b, num_tokens, d_out).
        return self.out_proj(merged)

In [9]:
# Toy embedding matrix: one 3-dimensional vector per token of the
# example sentence "your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # your
        [0.55, 0.87, 0.66],  # journey
        [0.57, 0.85, 0.64],  # starts
        [0.22, 0.58, 0.33],  # with
        [0.77, 0.25, 0.10],  # one
        [0.05, 0.80, 0.55],  # step
    ]
)

# Duplicate the sequence to form a batch of two: (2, num_tokens, d_in).
batch = torch.stack([inputs, inputs], dim=0)
print(f"batch shape: {batch.shape}")

# Seed before constructing the module so its initial weights are reproducible.
torch.manual_seed(123)
context_length = batch.shape[1]
mha = MultiHeadAttention(
    d_in=3,
    d_out=2,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
    qkv_bias=False,
)
context_vecs = mha(batch)
print(f"context_vecs shape: {context_vecs.shape}")

batch shape: torch.Size([2, 6, 3])
context_vecs shape: torch.Size([2, 6, 2])
