### 编码注意力机制（Coding attention mechanisms）

In [57]:
import torch

x = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [58]:
input_query = x[1]

attention_scores = torch.empty(x.shape[0])

for i, ipx in enumerate(x):
    attention_scores[i] = torch.dot(ipx, input_query)
    
attention_weights = torch.softmax(attention_scores, dim=-1)

print(attention_weights)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [59]:
context_vec = torch.zeros(x.shape[1])
print(context_vec)

for i, w in enumerate(attention_weights):
    context_vec += w * x[i]
    
print(context_vec)

tensor([0., 0., 0.])
tensor([0.4419, 0.6515, 0.5683])


#### 3.3.2 Computing attention weights for all input tokens

In [60]:
attention_scores = x @ x.T
attention_weights = torch.softmax(attention_scores, dim=-1)

print(attention_weights)
print(x)

context_vecs = attention_weights @ x
print(context_vecs)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


#### 3.4.1 Computing the attention weights step by step

In [61]:
d_in = x.shape[1]
d_out = 2
input_2 = x[1]
torch.manual_seed(123)
Q_weights = torch.nn.Parameter(torch.rand(d_in, d_out))
K_weights = torch.nn.Parameter(torch.rand(d_in, d_out))
V_weights = torch.nn.Parameter(torch.rand(d_in, d_out))

query_2 = input_2 @ Q_weights
keys = x @ K_weights
values = x @ V_weights

d_k = keys.shape[1]

print(query_2.shape)
print(keys.shape)

atten_scores_2 = query_2 @ keys.T

atten_weights_2 = torch.softmax(atten_scores_2 / d_k**0.5, dim=-1)

context_vec_2 = atten_weights_2 @ values

print(context_vec_2)

torch.Size([2])
torch.Size([6, 2])
tensor([0.3061, 0.8210], grad_fn=<SqueezeBackward4>)


#### 3.4.2 Implementing a compact SelfAttention class

In [62]:
import torch.nn as nn

torch.manual_seed(123)
 
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.Q_weights = nn.Parameter(torch.rand(d_in, d_out))
        self.K_weights = nn.Parameter(torch.rand(d_in, d_out))
        self.V_weights = nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, inputs):
        queries = inputs @ self.Q_weights
        keys = inputs @ self.K_weights
        values = inputs @ self.V_weights
        
        atten_scores = queries @ keys.T
        atten_weights = torch.softmax(atten_scores / d_k**0.5, dim=-1)
        
        context_vec = atten_weights @ values
        
        return context_vec
    
sa = SelfAttention_v1(d_in=d_in, d_out=d_out)
context_vec = sa(x)

print(context_vec)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [63]:
torch.manual_seed(123)
 
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.Q_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.K_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.V_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        
    def forward(self, inputs):
        queries = self.Q_weights(inputs)
        keys = self.K_weights(inputs)
        values = self.V_weights(inputs)
        
        atten_scores = queries @ keys.T
        atten_weights = torch.softmax(atten_scores / d_k**0.5, dim=-1)
        
        context_vec = atten_weights @ values
        
        return context_vec
    
sa = SelfAttention_v2(d_in=d_in, d_out=d_out)
context_vec = sa(x)

print(context_vec)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


#### 3.5.1 Applying a causal attention mask

In [64]:
queries = sa.Q_weights(x)
keys = sa.K_weights(x)
values = sa.V_weights(x)
        
atten_scores = queries @ keys.T

context_len = x.shape[0]
mask = torch.triu(torch.ones(context_len, context_len), diagonal=1)
masked_simple = atten_scores.masked_fill(mask.bool(), -torch.inf)
print(masked_simple)

atten_weights = torch.softmax(masked_simple / d_k**0.5, dim=-1)
print(atten_weights)

tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


#### 3.5.2 Masking additional attention weights with dropout

In [65]:
torch.manual_seed(123)

dropout = torch.nn.Dropout(0.5)

simple = torch.ones(6, 6)

dropout(simple)

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

#### 3.5.3 Implementing a compact causal self-attention class

In [66]:
batch = torch.stack((x, x), dim=0)
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [67]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.Q_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.K_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.V_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_len, context_len), diagonal=1))
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.Q_weights(x)
        keys = self.K_weights(x)
        values = self.V_weights(x)
        
        atten_scores = queries @ keys.transpose(1, 2)
        atten_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        
        atten_weights = torch.softmax(atten_scores / keys.shape[-1]**0.5, dim=-1)
        
        atten_weights = self.dropout(atten_weights)
        
        context_vec = atten_weights @ values
        
        return context_vec

torch.manual_seed(123) 
context_len = batch.shape[1]
sa = CausalAttention(d_in=d_in, d_out=d_out, context_len=context_len, dropout=0)
context_vec = sa(batch)

print(context_vec)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


#### 3.6.1 Stacking multiple single-head attention layers

In [68]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, heads, qkv_bias=False):
        super().__init__()
        self.headers = nn.ModuleList([
            CausalAttention(d_in, d_out, context_len, dropout, qkv_bias) for _ in range(heads)
        ])
        
    def forward(self, x):
        return torch.cat([ca(x) for ca in self.headers], dim=-1)
    
torch.manual_seed(123)
context_length = batch.shape[1] # 这是词元的数量
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


#### 3.6.2 通过权重划分实现多头注意力

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, num_heads, qkv_bias=False):
        super().__init__()
        
        assert(d_out % num_heads == 0), "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        
        self.Q_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.K_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.V_weights = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_len, context_len), diagonal=1))
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.Q_weights(x)
        keys = self.K_weights(x)
        values = self.V_weights(x)
        
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        
        atten_scores = queries @ keys.transpose(2, 3)
        atten_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        
        atten_weights = torch.softmax(atten_scores / keys.shape[-1]**0.5, dim=-1)
        
        atten_weights = self.dropout(atten_weights)
        
        context_vec = (atten_weights @ values).transpose(1, 2)
        
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        
        context_vec = self.out_proj(context_vec)
        
        return context_vec

torch.manual_seed(123) 
context_len = batch.shape[1]

d_out = 4

sa = CausalAttention(d_in=d_in, d_out=d_out, context_len=context_len, dropout=0)
context_vec = sa(batch)

print(context_vec)

tensor([[[-0.3132, -0.2272,  0.4772,  0.1063],
         [-0.2320,  0.0293,  0.5789,  0.3056],
         [-0.2068,  0.1162,  0.6118,  0.3695],
         [-0.1635,  0.1328,  0.5457,  0.3531],
         [-0.1687,  0.1813,  0.5315,  0.3400],
         [-0.1411,  0.1727,  0.5063,  0.3432]],

        [[-0.3132, -0.2272,  0.4772,  0.1063],
         [-0.2320,  0.0293,  0.5789,  0.3056],
         [-0.2068,  0.1162,  0.6118,  0.3695],
         [-0.1635,  0.1328,  0.5457,  0.3531],
         [-0.1687,  0.1813,  0.5315,  0.3400],
         [-0.1411,  0.1727,  0.5063,  0.3432]]], grad_fn=<UnsafeViewBackward0>)
