# Centered Kernel Alignment (CKA)

In [33]:
import numpy as np

This notebook is heavily inspired by [Kornblith et al. (2019)](http://proceedings.mlr.press/v97/kornblith19a/kornblith19a.pdf) ([Code](https://cka-similarity.github.io/), [Video](https://www.youtube.com/watch?v=TBjdvjdS2KM)).

The CKA is defined as the normalized Hilbert-Schmidt Independence Criterion (HSIC). Assume $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$ with centered columns. Then the HSIC is defined as
$$\mathrm{HSIC}(\mathbf{K}, \mathbf{L}) = \frac{1}{(n - 1)^2} \mathrm{tr}(\mathbf{KHLH}),$$
where $K_{ij} = k(\mathbf{x}_i, \mathbf{x}_j)$, $L_{ij} = l(\mathbf{y}_i, \mathbf{y}_j)$, and $\mathbf{H} = \mathbf{I}_n - \frac{1}{n}\mathbf{1}\mathbf{1}^T$ is the centering matrix. The CKA is then defined as
$$\mathrm{CKA}(\mathbf{K}, \mathbf{L}) = \frac{\mathrm{HSIC}(\mathbf{K}, \mathbf{L})}{\sqrt{\mathrm{HSIC}(\mathbf{K}, \mathbf{K})\mathrm{HSIC}(\mathbf{L}, \mathbf{L})}},$$
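and [Kornblith et al. (2019)](http://proceedings.mlr.press/v97/kornblith19a/kornblith19a.pdf) report that the simple linear kernel $k(\mathbf{x}, \mathbf{y}) = l(\mathbf{x}, \mathbf{y}) = \mathbf{x}^T\mathbf{y}$ often gives results similar to the RBF kernel $k(\mathbf{x}, \mathbf{y}) = \exp\left(-\|\mathbf{x} - \mathbf{y}\|^2 / (2\sigma^2)\right)$, which is also implemented below.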
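With the linear kernel the formula simplifies. A short derivation, assuming (as above) that $X$ and $Y$ have centered columns, so that $\mathbf{H}X = X$ and $\mathbf{H}Y = Y$ (and, because $\mathbf{H}$ is symmetric, $X^T\mathbf{H} = X^T$ and $Y^T\mathbf{H} = Y^T$):
$$\mathrm{tr}(\mathbf{KHLH}) = \mathrm{tr}(XX^T\mathbf{H}YY^T\mathbf{H}) = \mathrm{tr}(XX^TYY^T) = \mathrm{tr}\big((Y^TX)(Y^TX)^T\big) = \|Y^TX\|_F^2.$$
Applying the same identity to $\mathrm{HSIC}(\mathbf{K}, \mathbf{K})$ and $\mathrm{HSIC}(\mathbf{L}, \mathbf{L})$, the factors $\frac{1}{(n-1)^2}$ cancel and
$$\mathrm{CKA}_{\text{linear}}(X, Y) = \frac{\|Y^TX\|_F^2}{\|X^TX\|_F\,\|Y^TY\|_F},$$
which only involves $p_1 \times p_2$, $p_1 \times p_1$ and $p_2 \times p_2$ matrices instead of $n \times n$ ones.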

In [105]:
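def linear_kernel(X, Y):
    """Linear kernel: matrix of inner products x_i^T y_j."""
    return np.matmul(X, Y.T)

def rbf(X, Y, sigma=None):
    """
    Radial-basis function kernel exp(-||x_i - x_j||^2 / (2 sigma^2)) with the
    bandwidth sigma set by the median heuristic if not specified.
    Note: the distance computation assumes Y is X (as in CKA below).
    """
    GX = np.dot(X, Y.T)
    # Pairwise squared distances ||x_i||^2 + ||x_j||^2 - 2 x_i^T x_j (valid for Y = X)
    KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
    if sigma is None:
        # Median heuristic: square root of the median non-zero squared distance
        mdist = np.median(KX[KX != 0])
        sigma = np.sqrt(mdist)
    KX *= - 0.5 / (sigma * sigma)
    KX = np.exp(KX)
    return KX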

def HSIC(K, L):
    """
    Calculate Hilbert-Schmidt Independence Criterion on K and L.
    """
    n = K.shape[0]
    H = np.identity(n) - (1./n) * np.ones((n, n))

    KH = np.matmul(K, H)
    LH = np.matmul(L, H)
    return 1./((n-1)**2) * np.trace(np.matmul(KH, LH))

def CKA(X, Y, kernel=None):
    """
    Calculate Centered Kernel Alignment for X and Y. If no kernel
    is specified, the linear kernel will be used.
    """
    kernel = linear_kernel if kernel is None else kernel
    
    K = kernel(X, X)
    L = kernel(Y, Y)
        
    hsic = HSIC(K, L)
    varK = np.sqrt(HSIC(K, K))
    varL = np.sqrt(HSIC(L, L))
    return hsic / (varK * varL)
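Multiplying by $\mathbf{H}$ only subtracts means, so the $n \times n$ centering matrix never has to be built. Below is a minimal sketch of that idea, assuming symmetric kernel matrices (true for the linear and RBF kernels above); the names `center_gram` and `hsic_centered` are hypothetical and not part of the original code.

```python
import numpy as np

def center_gram(K):
    """Return H K H by subtracting column means and row means and adding back the grand mean."""
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()

def hsic_centered(K, L):
    """Biased HSIC estimate, equal to tr(KHLH) / (n-1)^2 for symmetric K and L."""
    n = K.shape[0]
    # tr(KHLH) = tr((HKH)(HLH)) because H is idempotent; for symmetric
    # matrices the trace of a product is the sum of the elementwise product.
    return np.sum(center_gram(K) * center_gram(L)) / (n - 1) ** 2
```

This replaces the three $n \times n$ matrix products in `HSIC` (each $O(n^3)$) with $O(n^2)$ elementwise operations, which matters once $n$ reaches a few thousand samples.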

We now simulate centered matrices $X$ and $Y$.

In [102]:
n = 100   # Samples
p1 = 64   # Representation dim model 1
p2 = 32   # Representation dim model 2

# Generate independent random X and Y
X = np.random.normal(size=(n, p1))
Y = np.random.normal(size=(n, p2))

# Center columns
X = X - np.mean(X, 0)
Y = Y - np.mean(Y, 0)

Now we calculate the CKA between these matrices with both the linear and the RBF kernel. To verify the implementation, we also calculate the CKA of each matrix with itself, where we expect a value of $1$.
Note that the bandwidth $\sigma$ of the RBF kernel is set to the square root of the median squared distance between samples, i.e. (approximately) the median Euclidean distance, following [Kornblith et al. (2019)](http://proceedings.mlr.press/v97/kornblith19a/kornblith19a.pdf).
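The `rbf` function derives pairwise distances from the diagonal of `X @ Y.T`, which is only meaningful when it is called with the same matrix twice (as `CKA` does). Below is a sketch of the same kernel with the median heuristic written out explicitly so that it also works for two different sample sets. It assumes Euclidean distances, and `pairwise_sq_dists` and `rbf_median` are hypothetical helper names.

```python
import numpy as np

def pairwise_sq_dists(X, Y):
    """Squared Euclidean distances between every row of X and every row of Y."""
    # Broadcasting builds an (n, m, p) array of differences: simple but memory-hungry
    diff = X[:, None, :] - Y[None, :, :]
    return np.sum(diff**2, axis=-1)

def rbf_median(X, Y, sigma=None):
    """RBF kernel exp(-||x - y||^2 / (2 sigma^2)), sigma from the median heuristic."""
    D2 = pairwise_sq_dists(X, Y)
    if sigma is None:
        # Square root of the median non-zero squared distance, i.e. (up to how the
        # median of an even count is averaged) the median Euclidean distance
        sigma = np.sqrt(np.median(D2[D2 > 0]))
    return np.exp(-D2 / (2 * sigma**2))

rng = np.random.default_rng(1)
X = rng.normal(size=(30, 4))
K_same = rbf_median(X, X)        # (30, 30), ones on the diagonal
K_cross = rbf_median(X, X[:5])   # (30, 5): the two inputs may differ here
```

For `Y = X` this agrees with `rbf`. For CKA itself only that case is needed, so the explicit version is mainly useful as a reference.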

In [103]:
print(f'Linear CKA, between X and Y: {CKA(X, Y):1.5f}')
print(f'Linear CKA, between X and X: {CKA(X, X):1.5f}')

print(f'RBF Kernel CKA, between X and Y: {CKA(X, Y, rbf):1.5f}')
print(f'RBF Kernel CKA, between X and X: {CKA(X, X, rbf):1.5f}')

Linear CKA, between X and Y: 0.28181
Linear CKA, between X and X: 1.00000
RBF Kernel CKA, between X and Y: 0.41206
RBF Kernel CKA, between X and X: 1.00000
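The linear CKA between $X$ and $Y$ above is clearly above $0$ even though the two matrices were drawn independently. This is largely a finite-sample effect: the HSIC estimate used here is biased, and with $p_1$ and $p_2$ comparable to $n$ the spurious similarity is sizeable. The sketch below checks the trend by growing $n$ with the dimensions fixed. It uses the feature-space form of linear CKA, and the helper `linear_cka` is hypothetical, not from the original notebook.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA via ||Y^T X||_F^2 / (||X^T X||_F ||Y^T Y||_F) on column-centered inputs."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(Y.T @ X, ord="fro") ** 2
    den = np.linalg.norm(X.T @ X, ord="fro") * np.linalg.norm(Y.T @ Y, ord="fro")
    return num / den

rng = np.random.default_rng(0)
p1, p2 = 64, 32
scores = {}
for n in [100, 1000, 10000]:
    # X and Y are independent, so the population linear CKA is 0
    X = rng.normal(size=(n, p1))
    Y = rng.normal(size=(n, p2))
    scores[n] = linear_cka(X, Y)
    print(f"n = {n:5d}: linear CKA = {scores[n]:.3f}")
```

The score should shrink steadily as $n$ grows. This is one reason to compare representations on reasonably large sets of samples.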


In [104]:
%%timeit
CKA(X, Y)

1.24 ms ± 383 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


The aim is to use the CKA metric to compare representations from different models.
The procedure requires us to pass a dataset (e.g. CIFAR-10) through a model and record the representation (i.e. the activation of each unit) of some layer. Stacking these representations gives an $n \times p$ matrix of $n$ samples with $p$-dimensional representations. One can compare either the similarities between samples (the $n \times n$ kernel matrices used above) or the similarities between features, but for the linear kernel the two views give the same result according to [Kornblith et al. (2019)](http://proceedings.mlr.press/v97/kornblith19a/kornblith19a.pdf).
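This equivalence can be checked numerically. The sketch below compares linear CKA computed from the $n \times n$ sample-similarity matrices $\mathbf{K} = XX^T$ and $\mathbf{L} = YY^T$ with the same quantity computed from feature cross-products. The helper names `cka_examples` and `cka_features` are hypothetical, and the data are synthetic stand-ins for layer activations.

```python
import numpy as np

def cka_examples(X, Y):
    """Linear CKA from n x n Gram matrices (sample similarities), via tr(KHLH)."""
    n = X.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    K, L = X @ X.T, Y @ Y.T

    def hsic(A, B):
        # Unnormalized: the 1/(n-1)^2 factors cancel in the ratio below
        return np.trace(A @ H @ B @ H)

    return hsic(K, L) / np.sqrt(hsic(K, K) * hsic(L, L))

def cka_features(X, Y):
    """Linear CKA from feature cross-products; inputs are column-centered first."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(Y.T @ X) ** 2  # Frobenius norm is the default for 2-D arrays
    return num / (np.linalg.norm(X.T @ X) * np.linalg.norm(Y.T @ Y))

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 64))
# Y is a noisy linear function of part of X, so the two are genuinely related
Y = X[:, :32] @ rng.normal(size=(32, 16)) + 0.5 * rng.normal(size=(200, 16))
print(cka_examples(X, Y), cka_features(X, Y))
```

The feature form never materializes an $n \times n$ matrix, so it is usually the cheaper choice when many samples are passed through a network and $n \gg p$.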