In [None]:
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Minimal single-head self-attention: separate Q/K/V projections, no mask, no dropout.

    Args:
        hidden_dim: size of the last dimension of the input; Q, K and V keep this size.
            NOTE(review): the default 728 looks like a typo for 768; it is kept
            unchanged so existing callers behave the same.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        # Each projection is Linear(in_features=hidden_dim, out_features=hidden_dim).
        # Creation order q -> k -> v is preserved so seeded initialisation is unchanged.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Return softmax(Q @ K^T / sqrt(hidden_dim)) @ V for input ``X``.

        ``X`` is expected to be (batch_size, seq_length, hidden_dim). As a
        demonstration side effect, the attention weights (shape
        (batch_size, seq_length, seq_length) for such input) are printed.
        """
        query, key, value = self.q_proj(X), self.k_proj(X), self.v_proj(X)
        scale = math.sqrt(self.hidden_dim)
        weights = (query @ key.transpose(-1, -2) / scale).softmax(dim=-1)
        print(weights)
        return weights @ value

In [None]:
torch.manual_seed(0)  # make the random input and the layer initialisation reproducible
X = torch.rand(3, 2, 4)  # (batch_size=3, seq_length=2, hidden_dim=4)
net = SelfAttentionV1(4)
print(X)
net(X)  # last expression: displayed as the cell output, shape (3, 2, 4)

In [None]:
import math
import torch
import torch.nn as nn

class SelfAttentionV2(nn.Module):
    """Single-head self-attention with a fused QKV projection, an optional mask,
    dropout on the attention weights, and an output projection.

    Intermediate tensors are printed at each step for demonstration purposes.

    Args:
        dim: size of the last dimension of the input; Q, K, V and the output keep this size.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, dropout_rate = 0.1) -> None:
        super().__init__()

        self.dim = dim
        # One Linear produces Q, K and V together; they are split apart in forward().
        self.proj = nn.Linear(dim, dim * 3)
        self.attention_dropout = nn.Dropout(dropout_rate)
        self.output_proj = nn.Linear(dim, dim) # optional

    def forward(self, X, attention_mask = None):
        """Apply masked self-attention to ``X``.

        Args:
            X: input whose last dimension is ``dim``, e.g. (batch_size, seq_length, dim).
            attention_mask: optional tensor broadcastable to the score shape
                (batch_size, seq_length, seq_length); positions where it equals 0
                are excluded from attention.

        Returns:
            The per-token representations after attention and the output
            projection; same shape as ``X``.
        """
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)
        attention_score = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Fill invalid positions (attention_mask == 0) with the most negative
            # finite value of the score dtype, so softmax gives them ~0 weight,
            # i.e. they are "ignored". A dtype-aware value is used instead of a
            # fixed -1e20, which overflows float16 and makes masked_fill raise.
            attention_score = attention_score.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_score.dtype).min
            )
        print(f"score before softmax:\n {attention_score}")
        attention_weight = torch.softmax(
            attention_score,
            dim=-1
        )
        print(f"weight after softmax:\n {attention_weight}")
        attention_weight = self.attention_dropout(attention_weight)
        print(f"weight after softmax & dropout:\n {attention_weight}")
        # The values must be mixed with the (dropped-out) softmax weights,
        # not the raw scores, otherwise the mask and normalisation are lost.
        attention_result = attention_weight @ V
        print(f"attention result(before linear layer):\n {attention_result}")
        output = self.output_proj(attention_result)

        return output
    
torch.manual_seed(0)  # make the random input, initialisation and dropout reproducible
X = torch.rand(3, 4, 2)  # (batch_size=3, seq_length=4, dim=2)
# Per-sequence padding mask: 1 = real token, 0 = padding; shape (batch_size, seq_length).
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0]
])
print(f"original mask shape:{mask.shape}")
seq_length = X.size(1)
mask = mask.unsqueeze(dim=1).repeat(1, seq_length, 1)
# unsqueeze(dim=1) inserts a new axis at position 1 -> (batch_size, 1, seq_length);
# repeat(1, seq_length, 1) copies it once per query position, giving the same shape
# as the attention scores, (batch_size, seq_length, seq_length).
print(f"repeat mask shape:{mask.shape}")
net = SelfAttentionV2(2)
net(X, mask)  # the attention layer keeps the input shape: (batch_size, seq_length, dim)

In [8]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionV3(nn.Module):
    """Single-head scaled dot-product self-attention with separate Q/K/V projections.

    Args:
        dim: size of the last dimension of the input; Q, K and V keep this size.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, dim: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.dim = dim
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X, attention_mask=None):
        """Apply self-attention to ``X``.

        Args:
            X: input whose last dimension is ``dim``, e.g. (batch, seq_len, dim).
            attention_mask: optional tensor broadcastable to the score shape
                (batch, seq_len, seq_len); positions where it equals 0 are masked out.

        Returns:
            Attention output with the same shape as ``X``. A query row whose keys
            are all masked receives uniform weights (the mean of V) rather than NaN.
        """
        Q = self.q(X)
        K = self.k(X)
        V = self.v(X)

        scores = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Use the dtype's most negative finite value instead of -inf: masked keys
            # still get ~0 weight, but a fully masked row no longer computes
            # softmax(-inf, ..., -inf) = NaN. That NaN would otherwise propagate
            # through the residual, LayerNorm and FFN layers.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = attention_weights @ V
        return output

class SimpleTransformerBlock(nn.Module):
    """One post-LayerNorm Transformer block followed by a projection to vocabulary logits.

    Data flow: h = LN1(X + Attn(X)); h = LN2(h + FFN(h)); logits = Linear(h).

    Args:
        dim: model (hidden) dimension of the input.
        vocab_size: number of logits produced per position.
        dropout_rate: dropout used inside the attention layer and after the FFN.
    """

    def __init__(self, dim, vocab_size, dropout_rate=0.1):
        super().__init__()
        expanded_dim = dim * 4  # 4x hidden expansion inside the feed-forward network
        self.attention = SelfAttentionV3(dim, dropout_rate)
        self.ln1 = nn.LayerNorm(dim)
        # Kept as a positional Sequential so parameter names stay ffn.0.* / ffn.2.*.
        ffn_layers = [
            nn.Linear(dim, expanded_dim),
            nn.ReLU(),
            nn.Linear(expanded_dim, dim),
            nn.Dropout(dropout_rate),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.ln2 = nn.LayerNorm(dim)
        self.output_layer = nn.Linear(dim, vocab_size)

    def forward(self, X, attention_mask=None):
        """Map ``X`` (presumably (batch, seq_len, dim)) to logits of shape
        (batch, seq_len, vocab_size)."""
        # Sub-layer 1: self-attention, residual connection, then LayerNorm.
        hidden = self.ln1(X + self.attention(X, attention_mask))
        # Sub-layer 2: position-wise feed-forward, residual connection, then LayerNorm.
        hidden = self.ln2(hidden + self.ffn(hidden))
        # Project every position's hidden state onto the vocabulary.
        return self.output_layer(hidden)

# Smoke test: one forward pass, then greedily pick the next token.
torch.manual_seed(0)  # reproducible weights and input

batch_size = 1
seq_len = 4
hidden_dim = 8
vocab_size = 1000

# Fake input: random embeddings standing in for an embedded token sequence.
X = torch.randn(batch_size, seq_len, hidden_dim)

# All-ones mask: every token may attend to every token.
attention_mask = torch.ones(batch_size, seq_len, seq_len)

model = SimpleTransformerBlock(hidden_dim, vocab_size)
model.eval()  # disable dropout so the prediction is deterministic

with torch.no_grad():  # inference only: no autograd graph needed
    logits = model(X, attention_mask)

# Use the logits at the last position to predict the next token.
last_token_logits = logits[:, -1, :]
probs = torch.softmax(last_token_logits, dim=-1)
predicted_token_id = torch.argmax(probs, dim=-1)

print("预测的下一个 token id:", predicted_token_id.item())


预测的下一个 token id: 872
