Библиотеки

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


Механизм внимания

In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The module holds no learnable parameters; it only combines the tensors
    it is given.
    """

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        """Compute attention weights from Q and K and use them to mix V.

        Args:
            Q: queries ("what am I looking for"), shape (..., seq_q, d_k).
            K: keys ("what do I offer"), shape (..., seq_k, d_k).
            V: values ("what information I hand over"), shape (..., seq_k, d_v).
            mask: optional tensor broadcastable to (..., seq_q, seq_k).
                Positions where the mask equals 0 / False are excluded from
                attention. ``None`` (the default) attends to every position.

        Returns:
            Tuple ``(output, attn_weights)`` where ``attn_weights`` has shape
            (..., seq_q, seq_k) with rows summing to 1, and ``output`` is the
            weighted sum of V with shape (..., seq_q, d_v).
        """
        d_k = Q.size(-1)

        # Pairwise query/key similarity, scaled by sqrt(d_k) so that large
        # dot products do not push the softmax into its saturated region.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Use the most negative finite value instead of -inf: masked
            # positions still get exactly zero weight after the softmax, but
            # a row that is masked everywhere yields uniform weights, not NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum of the values.
        output = torch.matmul(attn_weights, V)

        return output, attn_weights


In [15]:
# Fix the RNG state so the demo tensors below are reproducible.
torch.manual_seed(42)

# Toy input standing in for one short sequence of token embeddings.
batch_size, seq_len, embed_dim = 1, 3, 4
x = torch.randn((batch_size, seq_len, embed_dim))

print("Input x:", x, sep="\n")
print("Shape:", x.shape)

Input x:
tensor([[[ 0.3367,  0.1288,  0.2345,  0.2303],
         [-1.1229, -0.1863,  2.2082, -0.6380],
         [ 0.4617,  0.2674,  0.5349,  0.8094]]])
Shape: torch.Size([1, 3, 4])


In [16]:
# Derive queries, keys and values from the same input through three
# independent bias-free linear projections. The layers are created in the
# order q, k, v so the random initialisation consumes the RNG as before.
W_q, W_k, W_v = (nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(3))

Q, K, V = W_q(x), W_k(x), W_v(x)

for heading, projected in (("\nQ:", Q), ("\nK:", K), ("\nV:", V)):
    print(heading)
    print(projected)


Q:
tensor([[[-0.2649, -0.0397,  0.1738,  0.0548],
         [ 0.3662,  1.3071, -1.0046,  0.0585],
         [-0.5627, -0.2125,  0.3637,  0.0456]]], grad_fn=<UnsafeViewBackward0>)

K:
tensor([[[ 0.1799,  0.1574, -0.1143,  0.0052],
         [ 0.0543,  0.2965, -1.1980,  0.5401],
         [ 0.4996,  0.3018, -0.3854,  0.1773]]], grad_fn=<UnsafeViewBackward0>)

V:
tensor([[[ 0.0620, -0.0369,  0.0671,  0.0050],
         [ 0.9148,  0.4785,  0.9522,  1.2772],
         [-0.0075, -0.1336,  0.0248,  0.1580]]], grad_fn=<UnsafeViewBackward0>)


In [20]:
# Apply the attention mechanism implemented above to the Q, K, V projections.
# The module instance is called directly rather than through ``.forward`` so
# that nn.Module's call machinery (including any registered hooks) runs.
attntn = ScaledDotProductAttention()
output, attn_weights = attntn(Q, K, V)
print(attn_weights)

tensor([[[0.3479, 0.3258, 0.3263],
         [0.2372, 0.4445, 0.3183],
         [0.3692, 0.3133, 0.3176]]], grad_fn=<SoftmaxBackward0>)


Multi-head attention

In [32]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended per head
    by ``self.attention``, concatenated back and passed through a final
    linear projection.
    """

    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: size of the input/output feature dimension.
            num_heads: number of attention heads; must divide ``embed_dim``.

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()

        assert embed_dim % num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for Q, K, V.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        self.attention = ScaledDotProductAttention()

        # Final projection applied after the heads are concatenated.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, tensor):
        """Reshape (batch, seq, embed_dim) into (batch, heads, seq, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        return tensor.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x, mask=None, return_attention=False):
        """Run self-attention over ``x``.

        Args:
            x: input of shape (batch_size, seq_len, embed_dim).
            mask: optional mask handed to ``self.attention`` unchanged as a
                fourth positional argument; it should be broadcastable to
                (batch_size, num_heads, seq_len, seq_len). The wrapped
                attention module must accept that argument when a mask is
                supplied.
            return_attention: when True, also return the per-head weights.

        Returns:
            ``output`` of shape (batch_size, seq_len, embed_dim), or the tuple
            ``(output, attn_weights)`` when ``return_attention`` is True, with
            ``attn_weights`` as returned by ``self.attention``.
        """
        batch_size, seq_len, _ = x.size()

        # Project, then move the head axis in front of the sequence axis.
        Q = self._split_heads(self.q_proj(x))
        K = self._split_heads(self.k_proj(x))
        V = self._split_heads(self.v_proj(x))

        # Only forward the mask when one is given, so the unmasked path does
        # not depend on the attention module accepting a fourth argument.
        if mask is None:
            attn_output, attn_weights = self.attention(Q, K, V)
        else:
            attn_output, attn_weights = self.attention(Q, K, V, mask)

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim).
        # transpose() returns a non-contiguous view, hence contiguous() before view().
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )

        output = self.out_proj(attn_output)

        if return_attention:
            return output, attn_weights

        return output


Проверка на случайных данных

In [34]:
# Sanity check of multi-head attention on random data.
batch_size = 2
seq_len = 5
embed_dim = 16
num_heads = 4

x = torch.randn(batch_size, seq_len, embed_dim)

mha = MultiHeadAttention(embed_dim, num_heads)
# Call the module instance (not ``.forward``) so nn.Module's call machinery
# and any registered hooks run.
output, attn_weights = mha(x, return_attention=True)

# Invariants: the layer keeps the input shape, yields one seq x seq attention
# map per head, and every attention row is a probability distribution.
assert output.shape == x.shape
assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)
assert torch.allclose(
    attn_weights.sum(dim=-1),
    torch.ones(batch_size, num_heads, seq_len),
    atol=1e-6,
)

print(x.shape)
print(output.shape)
print(torch.allclose(x, output))
print(attn_weights.shape)
print(attn_weights[0, 0])



torch.Size([2, 5, 16])
torch.Size([2, 5, 16])
False
torch.Size([2, 4, 5, 5])
tensor([[0.2267, 0.2042, 0.2597, 0.1530, 0.1564],
        [0.2125, 0.2126, 0.2224, 0.1817, 0.1708],
        [0.1179, 0.1815, 0.1725, 0.2948, 0.2331],
        [0.1439, 0.1855, 0.2088, 0.2448, 0.2171],
        [0.1913, 0.2026, 0.2247, 0.1939, 0.1874]], grad_fn=<SelectBackward0>)


In [36]:
# Perturb a single token (batch 0, position 0) on a copy of the input and
# compare the outputs at all OTHER positions of that sequence: a non-zero
# difference shows that tokens influence each other through attention.
x2 = x.clone()
x2[0, 0] += 10.0

# Call the module instance (not ``.forward``) so nn.Module's call machinery
# and any registered hooks run.
out1 = mha(x)
out2 = mha(x2)

print("difference in other positions:",
      torch.norm(out1[0, 1:] - out2[0, 1:]))

difference in other positions: tensor(9.5380, grad_fn=<LinalgVectorNormBackward0>)


Интеграция в простую модель

Embedding — «кто я»
Attention — «кто рядом и как он на меня влияет»
Pooling — «общее впечатление»
Classifier — «решение»

In [5]:
class SimpleAttentionModel(nn.Module):
    """Small sequence classifier: embed -> self-attend -> average -> classify."""

    def __init__(self, vocab_size, embed_dim, num_heads, num_classes):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim
        )
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        # With output size 1 this averages over the whole last axis.
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.classifier = nn.Linear(
            in_features=embed_dim, out_features=num_classes
        )

    def forward(self, x):
        """Map token ids of shape (batch_size, seq_len) to class logits.

        Returns a tensor of shape (batch_size, num_classes).
        """
        embedded = self.embedding(x)
        attended = self.attention(embedded)

        # The pooling layer reduces the last axis, so put the sequence axis
        # last before pooling, then drop the resulting size-1 axis.
        channels_first = torch.transpose(attended, 1, 2)
        pooled = self.pool(channels_first).squeeze(-1)

        logits = self.classifier(pooled)
        return logits


Проверка работы модели

In [13]:
# Hyperparameters of the toy classification setup.
vocab_size, embed_dim, num_heads, num_classes = 100, 16, 4, 3
seq_len, batch_size = 8, 10

model = SimpleAttentionModel(vocab_size, embed_dim, num_heads, num_classes)

# Random token ids stand in for a batch of tokenized text.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

# Forward pass only: no gradients are needed for this check.
with torch.no_grad():
    logits = model(x)

preds = logits.argmax(dim=1)

for caption, value in (
    ("Input tokens:\n", x),
    ("Logits:\n", logits),
    ("Predicted classes:\n", preds),
):
    print(caption, value)


Input tokens:
 tensor([[ 9, 27, 38, 83, 65, 34, 73, 26],
        [88,  2, 46, 21, 83, 67, 72, 43],
        [47,  2, 71, 52, 19, 21, 95, 93],
        [ 3, 29, 40, 38,  5,  0, 66, 60],
        [51, 81, 81, 21, 79, 94,  7, 76],
        [44, 96, 60, 19, 49, 12, 14, 74],
        [48, 45,  6, 66, 24, 52, 77, 83],
        [71, 29, 20, 54,  1, 62, 83, 40],
        [17, 75, 36, 98, 27, 89, 52, 28],
        [10, 98,  7, 89, 30,  2, 66, 80]])
Logits:
 tensor([[ 0.1607,  0.2840, -0.0180],
        [ 0.1302,  0.3509, -0.0536],
        [ 0.0769,  0.4199,  0.1015],
        [ 0.1840,  0.2547, -0.2494],
        [ 0.0160,  0.4413, -0.0203],
        [ 0.1213,  0.2651, -0.0520],
        [ 0.1867,  0.1728, -0.2080],
        [ 0.2382,  0.2816,  0.1088],
        [ 0.1682,  0.2868,  0.0510],
        [ 0.1238,  0.3252, -0.0445]])
Predicted classes:
 tensor([1, 1, 1, 1, 1, 1, 0, 1, 1, 1])
