In [12]:
import torch

# 设置样本数和维度
n, d = 30, 16
torch.manual_seed(42)

# 构造两个相同结构的表示矩阵
X = torch.randn(n, d)
Y = X.clone()


In [9]:
print(X)
print(Y)

tensor([[ 1.9269e+00,  1.4873e+00,  9.0072e-01, -2.1055e+00,  6.7842e-01,
         -1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01,  1.6487e+00,
         -3.9248e-01, -1.4036e+00, -7.2788e-01, -5.5943e-01, -7.6884e-01,
          7.6245e-01],
        [ 1.6423e+00, -1.5960e-01, -4.9740e-01,  4.3959e-01, -7.5813e-01,
          1.0783e+00,  8.0080e-01,  1.6806e+00,  1.2791e+00,  1.2964e+00,
          6.1047e-01,  1.3347e+00, -2.3162e-01,  4.1759e-02, -2.5158e-01,
          8.5986e-01],
        [-1.3847e+00, -8.7124e-01, -2.2337e-01,  1.7174e+00,  3.1888e-01,
         -4.2452e-01,  3.0572e-01, -7.7459e-01, -1.5576e+00,  9.9564e-01,
         -8.7979e-01, -6.0114e-01, -1.2742e+00,  2.1228e+00, -1.2347e+00,
         -4.8791e-01],
        [-9.1382e-01, -6.5814e-01,  7.8024e-02,  5.2581e-01, -4.8799e-01,
          1.1914e+00, -8.1401e-01, -7.3599e-01, -1.4032e+00,  3.6004e-02,
         -6.3477e-02,  6.7561e-01, -9.7807e-02,  1.8446e+00, -1.1845e+00,
          1.3835e+00],
        [ 1.4451e+00

In [22]:
# 模拟偏移（所有 token 向量都加一个偏移）
offset1 = torch.ones(1, d) * 3.0
offset2 = torch.ones(1, d) * 100000.0
print(offset1)
print(offset2)
X_shifted = X + offset1
Y_shifted = Y + offset2

print(X_shifted)
print(Y_shifted)


tensor([[3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
tensor([[100000., 100000., 100000., 100000., 100000., 100000., 100000., 100000.,
         100000., 100000., 100000., 100000., 100000., 100000., 100000., 100000.]])
tensor([[4.9269, 4.4873, 3.9007, 0.8945, 3.6784, 1.7655, 2.9569, 1.3953, 2.2479,
         4.6487, 2.6075, 1.5964, 2.2721, 2.4406, 2.2312, 3.7624],
        [4.6423, 2.8404, 2.5026, 3.4396, 2.2419, 4.0783, 3.8008, 4.6806, 4.2791,
         4.2964, 3.6105, 4.3347, 2.7684, 3.0418, 2.7484, 3.8599],
        [1.6153, 2.1288, 2.7766, 4.7174, 3.3189, 2.5755, 3.3057, 2.2254, 1.4424,
         3.9956, 2.1202, 2.3989, 1.7258, 5.1228, 1.7653, 2.5121],
        [2.0862, 2.3419, 3.0780, 3.5258, 2.5120, 4.1914, 2.1860, 2.2640, 1.5968,
         3.0360, 2.9365, 3.6756, 2.9022, 4.8446, 1.8155, 4.3835],
        [4.4451, 3.8564, 5.2181, 3.5232, 3.3466, 2.8027, 1.9454, 4.2780, 2.8278,
         3.5238, 3.0566, 3.4263, 3.5750, 2.3583, 0.7936, 2.2492],
        [3.0109, 2.6613, 1

In [23]:
# 去中心化矩阵 H
H = torch.eye(n) - torch.ones(n, n) / n

# CKA 相似度函数（Linear CKA）
def linear_CKA(X1, X2):
    K = X1 @ X1.T
    L = X2 @ X2.T
    hsic = (K * L).sum()
    norm_K = (K * K).sum().sqrt()
    norm_L = (L * L).sum().sqrt()
    return hsic / (norm_K * norm_L)

# 不去中心化的结果
cka_raw = linear_CKA(X_shifted, Y_shifted).item()

# 去中心化后的结果
X_centered = H @ X_shifted
Y_centered = H @ Y_shifted
cka_centered = linear_CKA(X_centered, Y_centered).item()

print(f"未去中心化 CKA: {cka_raw:.4f}")
print(f"去中心化后 CKA: {cka_centered:.4f}")

未去中心化 CKA: 0.9937
去中心化后 CKA: 0.9999
