In [None]:
!pip install torch



In [None]:
import torch

In [None]:
# Sample data: one toy sentence, whitespace-tokenized.
SEED = 42
# Seed torch's RNG so the random Q/K/V below (and every attention output
# computed from them) are identical on every Restart & Run All.
torch.manual_seed(SEED)

test_sentences = "I love ML and DL technology"
tokens = test_sentences.split()
n = len(tokens)  # sequence length = number of tokens
d = 3            # toy query/key/value dimension

# Random Q, K, V of shape (n, d); in a real model these come from projected embeddings.
Q = torch.rand(n, d)
K = torch.rand(n, d)
V = torch.rand(n, d)

print("Q:\n", Q)
print("K:\n", K)
print("V:\n", V)

Q:
 tensor([[0.3238, 0.9859, 0.5374],
        [0.0121, 0.1617, 0.0553],
        [0.5467, 0.5247, 0.3065],
        [0.3482, 0.1910, 0.0825],
        [0.6001, 0.9960, 0.6995],
        [0.6826, 0.7340, 0.9003]])
K:
 tensor([[0.2702, 0.9027, 0.5765],
        [0.6443, 0.8890, 0.9124],
        [0.6316, 0.6336, 0.1246],
        [0.3250, 0.4294, 0.2149],
        [0.0456, 0.6882, 0.1571],
        [0.6904, 0.4364, 0.8812]])
V:
 tensor([[0.4470, 0.0603, 0.8035],
        [0.3097, 0.8499, 0.5492],
        [0.3506, 0.7157, 0.2504],
        [0.3466, 0.3883, 0.1323],
        [0.9092, 0.4529, 0.8132],
        [0.9157, 0.8571, 0.5813]])


In [None]:
def full_attention(Q, K, V):
    """Scaled dot-product attention in which every query attends to every key.

    Args:
        Q: query tensor; its last dimension is d_k.
        K: key tensor with the same last dimension as Q.
        V: value tensor whose rows are mixed by the attention weights.
        Leading batch dimensions, if any, are broadcast by matrix multiplication.

    Returns:
        softmax(Q @ K^T / sqrt(d_k), dim=-1) @ V
    """
    # sqrt(d_k) as a float32 scalar tensor keeps the logits' variance in check.
    scale = torch.tensor(Q.size(-1), dtype=torch.float32).sqrt()
    logits = (Q @ K.transpose(-2, -1)) / scale
    weights = logits.softmax(dim=-1)
    return weights @ V

# Apply full attention to the sample Q, K, V from the data cell.
result_full = full_attention(Q=Q, K=K, V=V)
print("Full Attention Output:\n", result_full)

Full Attention Output:
 tensor([[0.5340, 0.5674, 0.5406],
        [0.5450, 0.5542, 0.5249],
        [0.5374, 0.5730, 0.5273],
        [0.5418, 0.5647, 0.5211],
        [0.5321, 0.5800, 0.5396],
        [0.5360, 0.5901, 0.5380]])


In [None]:
def sliding_window_attention(Q, K, V, w=2):
    """Local attention: query i only attends to keys within w//2 positions of i.

    Args:
        Q, K, V: 2-D tensors of shape (n, d_k); Q.size() must unpack into two values.
        w: window width. The window spans positions [i - w//2, i + w//2], clipped
           to [0, n). So w=2 means +-1 token, and odd w rounds down (w=3 acts like w=2).

    Returns:
        Tensor shaped like Q; row i is the window-restricted attention output.
    """
    n, d_k = Q.size()
    scale = torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    half = w // 2
    output = torch.zeros_like(Q)
    for pos in range(n):
        lo, hi = max(0, pos - half), min(n, pos + half + 1)
        # Q[pos:pos + 1] keeps a leading row axis so the matmul yields a (1, window) score row.
        window_scores = (Q[pos:pos + 1] @ K[lo:hi].transpose(-2, -1)) / scale
        weights = torch.softmax(window_scores, dim=-1)
        output[pos] = (weights @ V[lo:hi]).squeeze(0)
    return output

# Apply sliding-window attention (w=2, i.e. +-1 neighbour) to the sample Q, K, V.
result_sliding = sliding_window_attention(Q=Q, K=K, V=V, w=2)
print("Sliding Window Attention Output:\n", result_sliding)

Sliding Window Attention Output:
 tensor([[0.3727, 0.4879, 0.6658],
        [0.3691, 0.5411, 0.5379],
        [0.3331, 0.6783, 0.3372],
        [0.5268, 0.5250, 0.3905],
        [0.7522, 0.6058, 0.5215],
        [0.9133, 0.7067, 0.6676]])


In [None]:
def dilated_sliding_window_attention(Q, K, V, w=2, d=2):
    """Dilated local attention: query i attends to keys at i + k*d for |k| <= w//2.

    Args:
        Q, K, V: 2-D tensors of shape (n, d_k); Q.size() must unpack into two values.
        w: window width; w//2 dilated steps are taken on each side of i, plus i itself.
           This matches the window convention of `sliding_window_attention`.
        d: dilation, i.e. the stride between attended positions.

    Returns:
        Tensor shaped like Q; row i is attention over the dilated neighbourhood of i.

    Example: w=2, d=2, i=3 -> attends to positions 1, 3, 5 (out-of-range ones dropped).
    """
    n, d_k = Q.size()
    scale = torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # `-w//2` would parse as `(-w)//2` (floor toward -inf), giving a lopsided window
    # for odd w (w=3 -> offsets -2..1). `-half` keeps the offsets symmetric.
    half = w // 2
    output = torch.zeros_like(Q)
    for i in range(n):
        # Offset k=0 (position i itself) is always in range, so indices is never empty.
        indices = [i + k * d for k in range(-half, half + 1) if 0 <= i + k * d < n]
        K_w = K[indices]
        V_w = V[indices]
        scores = torch.matmul(Q[i:i+1], K_w.transpose(-2, -1)) / scale
        attn_weights = torch.softmax(scores, dim=-1)
        output[i] = torch.matmul(attn_weights, V_w).squeeze(0)

    return output


# Apply dilated sliding-window attention (w=2, dilation 2) to the sample Q, K, V.
result_dilated = dilated_sliding_window_attention(Q=Q, K=K, V=V, w=2, d=2)
print("Dilated Sliding Window Attention Output:\n", result_dilated)

Dilated Sliding Window Attention Output:
 tensor([[0.4042, 0.3511, 0.5580],
        [0.3276, 0.6269, 0.3478],
        [0.5492, 0.4015, 0.6155],
        [0.5270, 0.7113, 0.4327],
        [0.6078, 0.5947, 0.5095],
        [0.7000, 0.6794, 0.4111]])


In [None]:
def global_sliding_attention(Q, K, V, w=2, global_indices=(0,)):
    """Sliding-window attention combined with a set of global tokens.

    - A query whose position is in `global_indices` attends to every key.
    - Every other query i attends to the local window [i - w//2, i + w//2]
      (clipped to [0, n)) plus all in-range global positions.

    Args:
        Q, K, V: 2-D tensors of shape (n, d_k); Q.size() must unpack into two values.
        w: window width, with the same convention as `sliding_window_attention`.
        global_indices: iterable of token positions treated as global.
            Out-of-range positions are ignored.

    Returns:
        Tensor shaped like Q.
    """
    n, d_k = Q.size()
    scale = torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    half = w // 2
    # A tuple default avoids a shared mutable default; a set gives O(1) membership tests.
    global_set = set(global_indices)
    output = torch.zeros_like(Q)

    for i in range(n):
        if i in global_set:
            # Global query: full attention over the whole sequence.
            K_sel, V_sel = K, V
        else:
            # Local window plus global positions, deduplicated and kept in order.
            local = range(max(0, i - half), min(n, i + half + 1))
            indices = sorted({j for j in (*local, *global_set) if 0 <= j < n})
            # Previously the restricted K_w was built but the full K/V were used,
            # which silently turned this into full attention.
            K_sel, V_sel = K[indices], V[indices]
        scores = torch.matmul(Q[i:i+1], K_sel.transpose(-2, -1)) / scale
        attn_weights = torch.softmax(scores, dim=-1)
        output[i] = torch.matmul(attn_weights, V_sel).squeeze(0)

    return output


# Apply sliding-window + global attention (token 0 is global) to the sample Q, K, V.
result_global = global_sliding_attention(Q=Q, K=K, V=V, w=2, global_indices=[0])
print("Global + Sliding Window Attention Output:\n", result_global)

Global + Sliding Window Attention Output:
 tensor([[0.5340, 0.5674, 0.5406],
        [0.5450, 0.5542, 0.5249],
        [0.5374, 0.5730, 0.5273],
        [0.5418, 0.5647, 0.5211],
        [0.5321, 0.5800, 0.5396],
        [0.5360, 0.5901, 0.5380]])
