- Transformer = input embedding + encoder + decoder
    - the positional encoding is added at the input-embedding stage
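- An RNN encodes position information inherently
    - $h_{t}=f(h_{t-1},x_t)$
    - $f$ is the nonlinear recurrent update, $h_{t-1}$ is the hidden state of the previous time step, and $x_t$ is the input at the current time step. Because $h_t$ depends on $h_{t-1}$, which in turn depends on $h_{t-2}$, and so on, the hidden state carries the full history from the first time step up to the current one. This recurrence is what encodes position implicitly in the hidden state.
    - An RNN encodes position implicitly through its recurrence, whereas a Transformer has to obtain position information by **explicitly adding a positional encoding**.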
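- If a Transformer Encoder is used without positional encoding, the model cannot tell the order of the tokens in the input sequence, which makes it effectively a Bag of Words model. The reason is that **self-attention is essentially a weighted sum over the inputs**, with weights that depend only on token content; without positional encoding the model has no access to any position information.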
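- Permutation Equivariance
    - **Permutation Equivariance**: if the input sequence is permuted, the model output is permuted in the same way.
    - Permutation Invariance: permuting the input sequence does not change the model output at all, i.e. the output is independent of the input order.
    - A Transformer Encoder without positional encoding is not permutation invariant; it is permutation equivariant. If we reorder the tokens of the input sequence, the elements of the output sequence are reordered in the same way: each token keeps the same output vector, but the output tensor as a whole is not unchanged.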
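
To make the two definitions concrete, here is a minimal sketch on toy functions. The names `tokenwise` and `pooled` are hypothetical helpers made up for this illustration, not part of the notebook: a function applied to each token independently is equivariant, while pooling over the token axis is invariant.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3)                    # 5 tokens, 3 features each
perm = torch.tensor([4, 0, 3, 1, 2])     # a fixed, non-identity permutation

def tokenwise(x):
    # Hypothetical example: acts on every token independently.
    return torch.tanh(x) * 2.0

def pooled(x):
    # Hypothetical example: sums over the token axis.
    return x.sum(dim=0)

# Equivariance: f(perm(x)) == perm(f(x))
equivariant_ok = torch.allclose(tokenwise(x[perm]), tokenwise(x)[perm])
# Invariance: f(perm(x)) == f(x)
invariant_ok = torch.allclose(pooled(x[perm]), pooled(x), atol=1e-6)
# An equivariant function is generally not invariant.
tokenwise_not_invariant = not torch.allclose(tokenwise(x[perm]), tokenwise(x))

print(equivariant_ok, invariant_ok, tokenwise_not_invariant)
```

Composing the two (an equivariant map followed by pooling) gives an invariant function, which is why mean pooling is a convenient probe in the experiments that follow.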

In [6]:
import torch
import torch.nn as nn

In [7]:
torch.manual_seed(42)

<torch._C.Generator at 0x77b7bc12a070>

In [8]:
# Model hyperparameters
vocab_size = 10000  # vocabulary size
d_model = 512       # embedding dimension
nhead = 8           # number of attention heads
num_layers = 1      # number of Transformer Encoder layers

### RNN

In [9]:
# Input sequence settings
sequence_length = 5  # sequence length
embedding_dim = 8    # token embedding dimension
batch_size = 1       # batch size

In [10]:
original_sequence = torch.randn(batch_size, sequence_length, embedding_dim)
original_sequence.shape

torch.Size([1, 5, 8])

In [11]:
original_sequence

tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431,
          -1.6047],
         [-0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,
           0.7624],
         [ 1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,
           1.6806],
         [ 0.0349,  0.3211,  1.5736, -0.8455,  1.3123,  0.6872, -1.0892,
          -0.3553],
         [-1.4181,  0.8963,  0.0499,  2.2667,  1.1790, -0.4345, -1.3864,
          -1.2862]]])

In [17]:
# Create a shuffled copy of the input sequence
permuted_sequence = original_sequence.clone()
permutation = torch.randperm(sequence_length)
permuted_sequence = permuted_sequence[:, permutation, :]

In [18]:
permuted_sequence.shape

torch.Size([1, 5, 8])

In [19]:
permuted_sequence

tensor([[[-1.4181,  0.8963,  0.0499,  2.2667,  1.1790, -0.4345, -1.3864,
          -1.2862],
         [ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431,
          -1.6047],
         [ 0.0349,  0.3211,  1.5736, -0.8455,  1.3123,  0.6872, -1.0892,
          -0.3553],
         [ 1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,
           1.6806],
         [-0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,
           0.7624]]])

In [33]:
import torch
from torchsummary import summary

In [37]:
hidden_dim = 16
# Hyperparameters
rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)


In [38]:
ori_output, _ = rnn(original_sequence)

In [39]:
perm_output, _ = rnn(permuted_sequence)

In [40]:
ori_output.shape, perm_output.shape

(torch.Size([1, 5, 16]), torch.Size([1, 5, 16]))

In [43]:
# global mean pooling
ori_output.squeeze(0).mean(dim=0)

tensor([-0.1647,  0.1817,  0.1086,  0.3989, -0.5853, -0.1526,  0.2821,  0.1751,
         0.0053, -0.1486,  0.4298, -0.3109,  0.2444, -0.2696,  0.1174, -0.1964],
       grad_fn=<MeanBackward1>)

In [44]:
perm_output.squeeze(0).mean(dim=0)

tensor([-0.1541,  0.2012,  0.1875,  0.4956, -0.6121, -0.1516,  0.3347,  0.2138,
         0.1248, -0.1832,  0.3795, -0.3336,  0.3319, -0.1917,  0.2129, -0.3236],
       grad_fn=<MeanBackward1>)

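**Conclusion**
- An RNN is not permutation equivariant: the mean-pooled outputs differ after shuffling, which could not happen if the per-token outputs were merely reordered.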
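
The cells above compare only mean-pooled outputs. As a complementary sketch, the equivariance condition `rnn(perm(x)) == perm(rnn(x))` can also be tested directly. The sizes mirror the ones used above, but the seed and the fixed permutation are assumptions chosen for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(1, 5, 8)                  # [batch, seq_len, embedding_dim]
perm = torch.tensor([4, 0, 3, 2, 1])      # fixed, non-identity permutation

with torch.no_grad():
    out, _ = rnn(x)                       # outputs for the original order
    out_perm, _ = rnn(x[:, perm, :])      # outputs for the shuffled order

# Equivariance would require the shuffled-input outputs to equal
# the original outputs reordered by the same permutation.
rnn_is_equivariant = torch.allclose(out_perm, out[:, perm, :], atol=1e-6)
print(rnn_is_equivariant)
```

The mismatch is expected: the hidden state at each step depends on everything seen before it, so moving a token changes the history it is combined with.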

### w/o pe

```
self_attention(perm(x)) = perm(self_attention(x)).
```
- x: input sequence
- perm: a permutation of the sequence positions

In [50]:
# Define the embedding layer and the Transformer Encoder
embedding = nn.Embedding(vocab_size, d_model)
# dropout is set to 0.0 so that repeated forward passes are deterministic
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=0.0)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

In [53]:
# Generate a random input sequence
seq_len = 10  # sequence length
# 10 random token ids in [0, vocab_size); the upper bound is exclusive
input_ids = torch.randint(0, vocab_size, (seq_len,))
input_ids.shape

torch.Size([10])

In [54]:
# Shuffle the input sequence
perm = torch.randperm(seq_len)
shuffled_input_ids = input_ids[perm]

In [55]:
perm, torch.argsort(perm)

(tensor([5, 3, 7, 4, 0, 2, 1, 6, 8, 9]),
 tensor([4, 6, 5, 1, 3, 0, 7, 2, 8, 9]))

In [72]:
# Look up the embeddings
embedded_input = embedding(input_ids)  # [seq_len, d_model]
embedded_shuffled_input = embedding(shuffled_input_ids)
print(embedded_input.shape)
print(embedded_shuffled_input.shape)

torch.Size([10, 512])
torch.Size([10, 512])


In [73]:
# With the default batch_first=False, the Transformer expects input of shape
# [seq_len, batch_size, d_model], so the tensors need an extra dimension.

# Add the batch dimension
# [seq_len, 1, d_model]
embedded_input = embedded_input.unsqueeze(1)
print(embedded_input.shape)
embedded_shuffled_input = embedded_shuffled_input.unsqueeze(1)

# Run through the Transformer Encoder
output = transformer_encoder(embedded_input)           # [seq_len, 1, d_model]
output_shuffled = transformer_encoder(embedded_shuffled_input)

torch.Size([10, 1, 512])


In [64]:
output.shape, output_shuffled.shape

(torch.Size([10, 1, 512]), torch.Size([10, 1, 512]))

In [65]:
torch.allclose(output, output_shuffled, atol=1e-6)

False

In [66]:
are_outputs_equal = torch.allclose(output.squeeze(1).mean(dim=0), output_shuffled.squeeze(1).mean(dim=0), atol=1e-6)
are_outputs_equal

True

In [67]:
inverse_perm = torch.argsort(perm)
output_shuffled_reordered = output_shuffled[inverse_perm]

In [68]:
torch.allclose(output, output_shuffled_reordered, atol=1e-6)

True

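**Conclusion**
- Without positional encoding, the Transformer Encoder is permutation equivariant: the raw outputs differ position by position, but they match once the shuffled output is put back in the original order, and the mean-pooled outputs are identical.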

### with pe

In [74]:
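# Sinusoidal positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Build the positional-encoding table, shape [max_len, 1, d_model]
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Register as a buffer: it follows .to(device) and is not a trainable parameter
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Additive positional encoding; x has shape [seq_len, batch_size, d_model]
        x = x + self.pe[:x.size(0)]
        return x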

In [75]:
# Add the positional encoding
pos_encoder = PositionalEncoding(d_model)
embedded_input_pos = pos_encoder(embedded_input)
embedded_shuffled_input_pos = pos_encoder(embedded_shuffled_input)

In [76]:
# Run through the Transformer Encoder
output_pe = transformer_encoder(embedded_input_pos)
output_shuffled_pe = transformer_encoder(embedded_shuffled_input_pos)

In [77]:
output_pe.shape, output_shuffled_pe.shape

(torch.Size([10, 1, 512]), torch.Size([10, 1, 512]))

In [78]:
# global mean pooling
torch.allclose(output_pe.squeeze(1).mean(dim=0), output_shuffled_pe.squeeze(1).mean(dim=0), atol=1e-6)

False

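**Conclusion**
- With positional encoding, the Transformer is no longer permutation equivariant: even the mean-pooled outputs change when the input is shuffled.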
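
The check above compares mean-pooled outputs only. The sketch below tests the equivariance condition itself, with and without positional encoding, on a single encoder layer. The small sizes, the seed, and the fixed permutation are assumptions picked to keep the example short, and the position table is built inline rather than through the `PositionalEncoding` class so that the block is self-contained.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model = 6, 16

# Sinusoidal position table, shape [seq_len, 1, d_model].
position = torch.arange(seq_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(seq_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2, dropout=0.0)
x = torch.randn(seq_len, 1, d_model)      # [seq_len, batch, d_model]
perm = torch.tensor([3, 0, 5, 1, 4, 2])   # fixed, non-identity permutation
inv = torch.argsort(perm)                 # inverse permutation

with torch.no_grad():
    # Without PE: un-shuffling the shuffled output recovers the original output.
    no_pe_equivariant = torch.allclose(layer(x[perm])[inv], layer(x), atol=1e-5)
    # With PE: position i always receives pe[i], so the two inputs are no longer
    # reorderings of each other and the un-shuffled outputs differ.
    with_pe_equivariant = torch.allclose(layer(x[perm] + pe)[inv], layer(x + pe), atol=1e-5)

print(no_pe_equivariant, with_pe_equivariant)
```

The key point is that the encoding is tied to the slot, not to the token: shuffling the tokens and then adding `pe` is not the same as adding `pe` and then shuffling.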

### Mathematical derivation: `self_attention(perm(x)) = perm(self_attention(x))`

In [31]:
P = torch.tensor([[1, 0, 0, 0],
                   [0, 0, 1, 0],
                   [0, 1, 0, 0],
                   [0, 0, 0, 1]])
P @ P.T, P.T @ P

(tensor([[1, 0, 0, 0],
         [0, 1, 0, 0],
         [0, 0, 1, 0],
         [0, 0, 0, 1]]),
 tensor([[1, 0, 0, 0],
         [0, 1, 0, 0],
         [0, 0, 1, 0],
         [0, 0, 0, 1]]))
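
The experiments earlier shuffle with index tensors (`x[perm]`), while the derivation uses a permutation matrix $P$. As a small sketch of how the two views relate, a permutation matrix can be built by reordering the rows of the identity. The permutation and the matrix `X` below are arbitrary values chosen for illustration.

```python
import torch

perm = torch.tensor([2, 0, 3, 1])         # index form, as used by x[perm]
P = torch.eye(4)[perm]                    # row i of P is the one-hot vector e_{perm[i]}
X = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Left-multiplying by P reorders the rows exactly like indexing with perm.
same_rows = torch.equal(P @ X, X[perm])
# The transpose of P is the matrix of the inverse permutation, argsort(perm).
inverse_is_transpose = torch.equal(P.T, torch.eye(4)[torch.argsort(perm)])
# Hence P^T P = I, which the derivation relies on.
orthogonal = torch.equal(P.T @ P, torch.eye(4))

print(same_rows, inverse_is_transpose, orthogonal)
```

This is also why `output_shuffled[torch.argsort(perm)]` undoes the shuffle: indexing with the inverse permutation corresponds to multiplying by $P^T$.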

- Let $P$ be a permutation matrix; the permuted input is $X_{\text{perm}}=PX$
- QKV
    - With $Q=XW_Q$, $K=XW_K$, $V=XW_V$: $Q_{\text{perm}}=PXW_Q=PQ$, $K_{\text{perm}}=PK$, $V_{\text{perm}}=PV$
- attention score matrix
    - $S_{\text{perm}}=\frac{Q_{\text{perm}}K^T_{\text{perm}}}{\sqrt{d_k}}=\frac{PQK^TP^T}{\sqrt{d_k}}=P\left(\frac{QK^T}{\sqrt{d_k}}\right)P^T=PSP^T$
        - $S=\frac{QK^T}{\sqrt{d_k}}$
- softmax
    - $A_{\text{perm}}=\text{softmax}(S_{\text{perm}})=\text{softmax}(PSP^T)=PAP^T$
        - $A=\text{softmax}(S)$
        - softmax is applied row by row, and $PSP^T$ only reorders the rows and the entries within each row, so softmax commutes with the permutation
- attention output
    - $Y_{\text{perm}}=A_{\text{perm}}V_{\text{perm}}=PAP^TPV=P(AV)=PY$
        - for a permutation matrix, $P^TP=I$

In [79]:
import torch.nn.functional as F

P = torch.tensor([[1, 0, 0, 0],
                   [0, 0, 1, 0],
                   [0, 1, 0, 0],
                   [0, 0, 0, 1]], dtype=torch.float32)
S = torch.randn(4, 4)

F.softmax(P @ S @ P.T, dim=1), P @ F.softmax(S, dim=1) @ P.T

(tensor([[0.5043, 0.1345, 0.1801, 0.1810],
         [0.3099, 0.0326, 0.3377, 0.3198],
         [0.2596, 0.2702, 0.1501, 0.3201],
         [0.0472, 0.3261, 0.1903, 0.4363]]),
 tensor([[0.5043, 0.1345, 0.1801, 0.1810],
         [0.3099, 0.0326, 0.3377, 0.3198],
         [0.2596, 0.2702, 0.1501, 0.3201],
         [0.0472, 0.3261, 0.1903, 0.4363]]))
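
The cell above confirms the softmax step. As a sketch of the whole chain, the following checks $Y_{\text{perm}}=PY$ end to end for a single attention head written from scratch. The projection matrices `W_q`, `W_k`, `W_v`, the sizes, and the seed are assumptions introduced for this illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_model, d_k = 4, 8, 8
X = torch.randn(n, d_model)
# Random projections, scaled to keep the attention scores moderate.
W_q, W_k, W_v = (torch.randn(d_model, d_k) / math.sqrt(d_model) for _ in range(3))

def self_attention(X):
    # Single-head scaled dot-product self-attention, no mask, no positional encoding.
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    S = Q @ K.T / math.sqrt(d_k)          # attention scores
    A = F.softmax(S, dim=-1)              # row-wise softmax
    return A @ V

P = torch.eye(n)[torch.tensor([0, 2, 1, 3])]   # same permutation matrix as above
Y = self_attention(X)
Y_perm = self_attention(P @ X)

attention_is_equivariant = torch.allclose(Y_perm, P @ Y, atol=1e-5)
print(attention_is_equivariant)
```

The remaining parts of an encoder layer (feed-forward network, residual connections, layer normalization) act on each position independently, so they preserve the same property.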