In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim  
        self.hidden_dim = hidden_dim
        
        self.W_q = nn.Linear(in_dim, hidden_dim) 
        self.W_k = nn.Linear(in_dim, hidden_dim) 
        self.W_v = nn.Linear(in_dim, hidden_dim)
        self.W_o = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        q = self.W_q(x) 
        # (batch_size, seq_len, in_dim) * (in_dim, hidden_dim)  = (batch_size, seq_len, hidden_dim)
        k = self.W_k(x) 
        v = self.W_v(x) 
        # the shape of q, k, v is (batch_size, seq_len, hidden_dim)
        attn = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.hidden_dim)
        # attn = q @ k.transpose(1, 2) / math.sqrt(self.hidden_dim)
        # the shape of attn is (batch_size, seq_len, seq_len)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = F.softmax(attn / math.sqrt(self.hidden_dim), dim=-1)
        attn = self.dropout(attn)
        attn = torch.matmul(attn, v)
        # attn = attn @ v 
        # the shape of attn is (batch_size, seq_len, hidden
        return self.W_o(attn)

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads 
        self.W_q = nn.Linear(in_dim, hidden_dim)
        self.W_k = nn.Linear(in_dim, hidden_dim)
        self.W_v = nn.Linear(in_dim, hidden_dim)
        self.W_o = nn.Linear(hidden_dim, out_dim)
        # self.attn = SelfAttention(hidden_dim, hidden_dim, hidden_dim, dropout)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        batch_size = x.size(0)
        q = self.W_q(x) 
        k = self.W_k(x) 
        v = self.W_v(x) 
        # the shape of q, k, v is (batch_size, seq_len, hidden_dim)
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # the shape of q, k, v is (batch_size, num_heads, seq_len, head_dim)
        # 在缩放点积注意力计算中，我们需要计算每个查询 (Q) 与所有键 (K) 之间的相似度。
        # 为了并行计算多个头的注意力，我们需要将 num_heads 维度放在前面，这样就可以将每个头的数据视为一个独立的批次进行处理。
        
        attn = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        # the shape of attn is (batch_size, num_heads, seq_len, seq_len)
        attn = attn @ v
        # the shape of attn is (batch_size, num_heads, seq_len, head_dim)
        attn = attn.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        # the shape of attn is (batch_size, seq_len, hidden_dim)
        return self.W_o(attn), attn

In [7]:
if __name__ == '__main__':
    # Example usage
    batch_size = 32
    seq_len = 10
    hidden_num = 512
    num_heads = 8

    # Create random input tensors
    x = torch.randn(batch_size, seq_len, hidden_num)

    # Create a mask (optional)
    mask = torch.ones(batch_size, 1, seq_len, seq_len)
    mask[:, :, :, :5] = 0  # Mask the first 5 positions
    # Instantiate the SelfAttention module
    self_attn = SelfAttention(hidden_num, hidden_num, hidden_num)

    # Instantiate the MultiHeadAttention module
    multihead_attn = MultiHeadAttention(
        hidden_num, hidden_num, hidden_num, num_heads)

    # Perform self-attention
    output_self_attention = self_attn(x, mask)

    # Perform multi-head attention
    output_mutlihead_attention,_ = multihead_attn(x, mask)

    # Print the output shape
    print(output_self_attention.size())
    print(output_mutlihead_attention.size())


torch.Size([32, 32, 10, 512])
torch.Size([32, 10, 512])
