 ### Centered Kernel Alignment（CKA）
适合分析两个视角的中间层或最终表示之间的相似性。
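公式思想：对两组表示分别计算核（Gram）矩阵 $K = k(X, X)$、$L = k(Y, Y)$，用 HSIC 衡量二者的对齐程度并做归一化：$\mathrm{CKA}(K, L) = \dfrac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}$，其中 $\mathrm{HSIC}(K, L) \propto \operatorname{tr}(KHLH)$，$H = I_n - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$ 为中心化矩阵（常数因子 $1/(n-1)^2$ 在比值中相互抵消）。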
值域：0（完全不相似）~ 1（完全对齐；对线性 CKA 而言，即两组表示仅相差正交变换与各向同性缩放）
CKA 值高 → 冗余高；CKA 值低 → 信息互补

In [16]:
import math
import numpy as np


def centering(K):
    """Double-center a square kernel (Gram) matrix.

    Returns ``H @ K @ H`` where ``H = I - ones((n, n)) / n`` is the centering
    matrix. Element-wise this is ``K[i, j] - row_mean[i] - col_mean[j] +
    grand_mean``. It is computed directly from the means in O(n^2) rather
    than by building two n x n matrices and doing two O(n^3) products.

    Note that ``H K H`` is *not* equivalent to the one-sided ``K H``, which
    only centers the columns. HSIC uses the two-sided form.

    Parameters
    ----------
    K : array_like, shape (n, n)
        Square kernel matrix.

    Returns
    -------
    numpy.ndarray, shape (n, n), dtype float64
        The double-centered matrix.

    Raises
    ------
    ValueError
        If ``K`` is not a square 2-D matrix.
    """
    K = np.asarray(K, dtype=float)
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError(
            f"centering expects a square 2-D matrix, got shape {K.shape}"
        )
    row_means = K.mean(axis=1, keepdims=True)  # shape (n, 1)
    col_means = K.mean(axis=0, keepdims=True)  # shape (1, n)
    return K - row_means - col_means + K.mean()


def rbf(X, sigma=None):
    """Gaussian (RBF) kernel matrix between the rows of ``X``.

    ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2))``

    Parameters
    ----------
    X : array_like, shape (n, d)
        One sample per row. Integer input is accepted; the result is float.
    sigma : float, optional
        Kernel bandwidth. If None, the median heuristic is used: sigma is the
        square root of the median of the non-zero pairwise squared
        distances. When there are no non-zero distances (a single row, or
        all rows identical), every kernel entry is exp(0) = 1 whatever the
        bandwidth. In that case sigma = 1.0 is used instead of the NaN that
        the median of an empty array would produce.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        The kernel matrix; 1 on the diagonal.
    """
    X = np.asarray(X)
    GX = np.dot(X, X.T)
    # diff[i, j] = ||x_j||^2 - <x_i, x_j>, so diff + diff.T is the matrix of
    # squared Euclidean distances (exactly 0 on the diagonal).
    diff = np.diag(GX) - GX
    KX = diff + diff.T
    if sigma is None:
        nonzero = KX[KX != 0]
        mdist = np.median(nonzero) if nonzero.size else 1.0
        sigma = math.sqrt(mdist)
    # Scale out-of-place: ``KX *= float`` raises a casting error when X (and
    # therefore KX) has an integer dtype.
    return np.exp(KX * (-0.5 / (sigma * sigma)))


def kernel_HSIC(X, Y, sigma):
    """Unnormalised HSIC between X and Y using RBF kernels.

    Each input is turned into an RBF Gram matrix via ``rbf`` with the same
    ``sigma`` argument; with ``sigma=None`` each input therefore gets its own
    median-heuristic bandwidth. Both matrices are double-centered and their
    Frobenius inner product is returned. The usual 1/(n-1)^2 factor is
    omitted because it cancels in CKA.
    """
    centered_x, centered_y = (centering(rbf(M, sigma)) for M in (X, Y))
    return (centered_x * centered_y).sum()


def linear_HSIC(X, Y):
    """Unnormalised HSIC between X and Y using linear kernels.

    The linear Gram matrices are ``X X^T`` and ``Y Y^T`` (so X and Y must have
    the same number of rows). Both are double-centered and their Frobenius
    inner product is returned. The 1/(n-1)^2 normalisation is omitted
    because it cancels in CKA.
    """
    centered_x, centered_y = (centering(np.dot(M, M.T)) for M in (X, Y))
    return (centered_x * centered_y).sum()


def linear_CKA(X, Y):
    """Linear Centered Kernel Alignment between two representations.

    CKA = HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y)) with linear kernels.
    X and Y must share the same number of rows (the same examples); their
    column counts may differ. If either input is constant across rows its
    self-HSIC is 0 and the division yields nan.
    """
    cross_term = linear_HSIC(X, Y)
    norm_x = np.sqrt(linear_HSIC(X, X))
    norm_y = np.sqrt(linear_HSIC(Y, Y))
    denominator = norm_x * norm_y
    return cross_term / denominator


def kernel_CKA(X, Y, sigma=None):
    """RBF-kernel Centered Kernel Alignment between two representations.

    Same normalisation as the linear variant, but HSIC is evaluated on
    Gaussian Gram matrices (see ``kernel_HSIC``). With ``sigma=None`` the
    bandwidth is chosen by the median heuristic separately for each input.
    """
    cross_term = kernel_HSIC(X, Y, sigma)
    norm_x = np.sqrt(kernel_HSIC(X, X, sigma))
    norm_y = np.sqrt(kernel_HSIC(Y, Y, sigma))
    denominator = norm_x * norm_y
    return cross_term / denominator


if __name__ == '__main__':
    # Demo on two independent Gaussian matrices (100 samples x 64 features).
    # A fixed seed makes the printed values reproducible from run to run;
    # CKA of a representation with itself should print ~1.0.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 64))
    Y = rng.standard_normal((100, 64))

    print('Linear CKA, between X and Y: {}'.format(linear_CKA(X, Y)))
    print('Linear CKA, between X and X: {}'.format(linear_CKA(X, X)))

    print('RBF Kernel CKA, between X and Y: {}'.format(kernel_CKA(X, Y)))
    print('RBF Kernel CKA, between X and X: {}'.format(kernel_CKA(X, X)))

Linear CKA, between X and Y: 0.3697040651290468
Linear CKA, between X and X: 1.0000000000000002
RBF Kernel CKA, between X and Y: 0.5033810276359184
RBF Kernel CKA, between X and X: 1.0000000000000002


### Mutual Information（MI）

In [18]:
from sklearn import metrics
import numpy as np
# NOTE: metrics.mutual_info_score treats its inputs as discrete labels (each
# distinct value is its own class) and reports MI in nats. On continuous data
# such as np.random.randn(128), every sample becomes a unique label, so the
# score collapses to ln(128) regardless of any real dependence. For continuous
# features, discretise first or use a dedicated continuous MI estimator.
X=[0,0,1,1]
Y=[1,1,0,0]  # Y is X with the two labels swapped
# MI ignores how labels are named, so MI(X, Y) == MI(X, X) == H(X) == ln 2.
print(metrics.mutual_info_score(X,Y))
print(metrics.mutual_info_score(X,X))

0.6931471805599453
0.6931471805599453
