🤖 CKA PyTorch 🤖

CKA (Centered Kernel Alignment) in PyTorch.

✒️ About


Centered Kernel Alignment (CKA) [1] is a similarity index between representations of features in neural networks, based on the Hilbert-Schmidt Independence Criterion (HSIC) [2]. Given a set of examples, CKA compares the representations of examples passed through the layers that we want to compare.

Given two matrices $\boldsymbol{X} \in \mathbb{R}^{n\times s_1}$ and $\boldsymbol{Y} \in \mathbb{R}^{n\times s_2}$ representing the output of two layers, we can define two auxiliary $n \times n$ Gram matrices $\boldsymbol{K}=\boldsymbol{XX^T}$ and $\boldsymbol{L}=\boldsymbol{YY^T}$ and compute the dot-product similarity between them

$$\langle vec(\boldsymbol{XX^T}), vec(\boldsymbol{YY^T})\rangle = tr(\boldsymbol{XX^T YY^T}) = \lVert \boldsymbol{Y^T X} \rVert_F^2.$$
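The identity above can be checked numerically. The following is a minimal sketch on random matrices; the sizes and variable names are illustrative assumptions, not part of this repository.

```python
import torch

torch.manual_seed(0)
n, s1, s2 = 8, 5, 3  # hypothetical sizes: n examples, s1 and s2 features
X = torch.randn(n, s1, dtype=torch.float64)
Y = torch.randn(n, s2, dtype=torch.float64)

K = X @ X.T  # n x n Gram matrix of X
L = Y @ Y.T  # n x n Gram matrix of Y

vec_dot = torch.dot(K.flatten(), L.flatten())            # <vec(K), vec(L)>
trace_form = torch.trace(K @ L)                          # tr(X X^T Y Y^T)
frob_form = torch.linalg.norm(Y.T @ X, ord="fro") ** 2   # ||Y^T X||_F^2

# All three expressions agree up to floating-point error.
print(torch.allclose(vec_dot, trace_form), torch.allclose(trace_form, frob_form))
```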

Then, the $HSIC$ on $K$ and $L$ is defined as

$$HSIC_0(\boldsymbol{K}, \boldsymbol{L}) = \frac{tr(\boldsymbol{KHLH})}{(n - 1)^2},$$

where $\boldsymbol{H} = \boldsymbol{I_n} - \frac{1}{n}\boldsymbol{J_n}$ is the centering matrix and $\boldsymbol{J_n}$ is an $n \times n$ matrix filled with ones. Finally, to obtain the CKA value we only need to normalize $HSIC_0$

$$CKA(\boldsymbol{K}, \boldsymbol{L}) = \frac{HSIC_0(\boldsymbol{K}, \boldsymbol{L})}{\sqrt{HSIC_0(\boldsymbol{K}, \boldsymbol{K}) HSIC_0(\boldsymbol{L}, \boldsymbol{L})}}.$$
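As an illustration, the two formulas above translate almost literally into PyTorch. This is a minimal sketch, not the code shipped in this repository: the function names `hsic0` and `linear_cka` and the toy data are assumptions made for the example.

```python
import torch


def hsic0(K, L):
    """Biased HSIC estimator tr(KHLH) / (n - 1)^2 on two n x n Gram matrices."""
    n = K.shape[0]
    # Centering matrix H = I_n - (1/n) J_n
    H = torch.eye(n, dtype=K.dtype, device=K.device) - 1.0 / n
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2


def linear_cka(X, Y):
    """Linear CKA between activation matrices X (n x s1) and Y (n x s2)."""
    K, L = X @ X.T, Y @ Y.T
    return hsic0(K, L) / torch.sqrt(hsic0(K, K) * hsic0(L, L))


torch.manual_seed(0)
X = torch.randn(16, 10, dtype=torch.float64)
Y = torch.randn(16, 6, dtype=torch.float64)

cka_xx = linear_cka(X, X)  # a representation compared with itself
cka_xy = linear_cka(X, Y)  # two unrelated random representations

# Linear CKA is unchanged by an orthogonal transform and isotropic scaling.
Q, _ = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))
cka_rot = linear_cka(X, 3.0 * X @ Q)

print(cka_xx.item(), cka_xy.item(), cka_rot.item())
```

Comparing a representation with itself, or with a rotated and rescaled copy of itself, should give a value of 1 up to floating-point error.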


However, naive computation of linear CKA (i.e., the previous equation) requires keeping the activations for the entire dataset in memory, which is challenging for wide and deep networks [3].

Therefore, we use an unbiased estimator of HSIC, so that the value of CKA is independent of the batch size:

$$HSIC_1(\boldsymbol{K}, \boldsymbol{L})=\frac{1}{n(n-3)}\left( tr(\boldsymbol{\tilde{K}\tilde{L}}) + \frac{\boldsymbol{1^T\tilde{K}11^T\tilde{L}1}}{(n-1)(n-2)} - \frac{2}{n-2}\boldsymbol{1^T\tilde{K}\tilde{L}1}\right),$$

where $\boldsymbol{\tilde{K}}$ and $\boldsymbol{\tilde{L}}$ are obtained by setting the diagonal entries of $\boldsymbol{K}$ and $\boldsymbol{L}$ to zero. Finally, we can compute the minibatch version of CKA by averaging $HSIC_1$ scores over $k$ minibatches

$$CKA_{minibatch}=\frac{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{K_i}, \boldsymbol{L_i})}{\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{K_i}, \boldsymbol{K_i})}\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{L_i}, \boldsymbol{L_i})}},$$

with $\boldsymbol{K_i}=\boldsymbol{X_iX_i^T}$ and $\boldsymbol{L_i}=\boldsymbol{Y_iY_i^T}$, where $\boldsymbol{X_i} \in \mathbb{R}^{m \times s_1}$ and $\boldsymbol{Y_i} \in \mathbb{R}^{m \times s_2}$ are now matrices containing the activations of the $i^{th}$ minibatch of $m$ examples sampled without replacement [3].
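The minibatch procedure can be sketched as follows. This is a hedged illustration of the formulas above, not the implementation used in this repository: `hsic1` and `minibatch_cka` are hypothetical names, the batches are random toy data, and each minibatch is assumed to contain more than 3 examples so that the $n(n-3)$ factor is non-zero.

```python
import torch


def hsic1(K, L):
    """Unbiased HSIC estimator on two m x m Gram matrices (needs m > 3)."""
    m = K.shape[0]
    # K~ and L~: copies of K and L with the diagonal set to zero
    K_t = K.clone().fill_diagonal_(0.0)
    L_t = L.clone().fill_diagonal_(0.0)
    ones = torch.ones(m, dtype=K.dtype, device=K.device)

    term1 = torch.trace(K_t @ L_t)
    term2 = (ones @ K_t @ ones) * (ones @ L_t @ ones) / ((m - 1) * (m - 2))
    term3 = 2.0 / (m - 2) * (ones @ K_t @ L_t @ ones)
    return (term1 + term2 - term3) / (m * (m - 3))


def minibatch_cka(batches):
    """CKA from an iterable of (X_i, Y_i) activation pairs, one per minibatch."""
    hsic_kl = hsic_kk = hsic_ll = 0.0
    k = 0
    for X_i, Y_i in batches:
        K_i, L_i = X_i @ X_i.T, Y_i @ Y_i.T
        hsic_kl = hsic_kl + hsic1(K_i, L_i)
        hsic_kk = hsic_kk + hsic1(K_i, K_i)
        hsic_ll = hsic_ll + hsic1(L_i, L_i)
        k += 1
    # Average each HSIC_1 term over the k minibatches, then normalize.
    return (hsic_kl / k) / (torch.sqrt(hsic_kk / k) * torch.sqrt(hsic_ll / k))


torch.manual_seed(0)
# Toy data: k = 4 minibatches of m = 32 examples each.
batches = [
    (torch.randn(32, 10, dtype=torch.float64), torch.randn(32, 6, dtype=torch.float64))
    for _ in range(4)
]
cka_xy = minibatch_cka(batches)
cka_xx = minibatch_cka([(X_i, X_i) for X_i, _ in batches])
print(cka_xy.item(), cka_xx.item())
```

Only three running sums need to be kept between minibatches, so the activations of the full dataset never have to be held in memory at once.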

📦 Installation

This project requires Python >= 3.10. As a first step, clone the repository

git clone

Then, you can install the necessary packages with

pip install -r requirements.txt

On Windows, this installs a PyTorch build without CUDA support. If you want to use your GPU during the computation, follow the installation instructions on the official PyTorch site.
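To see whether the installed PyTorch build can actually use a GPU, a quick check along these lines may help. It is a small sketch that only relies on standard PyTorch calls; the variable name `device` is an illustrative choice.

```python
import torch

# Use the GPU only if this PyTorch build has CUDA support and a device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"PyTorch {torch.__version__}, using device: {device}")
```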

Take a look at the examples directory to understand how to compute CKA in different scenarios.
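For orientation, here is a hedged sketch of the general workflow: collect the activations of chosen layers with forward hooks, then compare them with linear CKA. It does not use this repository's API. The helper names `collect_activations` and `linear_cka`, the toy model and the layer names are all assumptions made for the example. The CKA helper uses the equivalent feature-space form of $HSIC_0$, computed on column-centered activations.

```python
import torch
from torch import nn


def collect_activations(model, layer_names, x):
    """Run x through the model and return {layer name: (n x features) activations}."""
    acts = {}
    modules = dict(model.named_modules())

    def make_hook(name):
        def hook(_module, _inputs, output):
            # Flatten everything except the batch dimension.
            acts[name] = output.detach().flatten(start_dim=1)
        return hook

    handles = [modules[name].register_forward_hook(make_hook(name)) for name in layer_names]
    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()  # leave the model without hooks
    return acts


def linear_cka(X, Y):
    """Linear CKA on activations, using ||Yc^T Xc||_F^2 with centered columns."""
    Xc = X - X.mean(dim=0, keepdim=True)
    Yc = Y - Y.mean(dim=0, keepdim=True)
    cross = torch.linalg.norm(Yc.T @ Xc) ** 2
    return cross / (torch.linalg.norm(Xc.T @ Xc) * torch.linalg.norm(Yc.T @ Yc))


torch.manual_seed(0)
# Hypothetical toy model; in nn.Sequential the layers are named "0", "1", "2".
model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(32, 20)

acts = collect_activations(model, ["0", "2"], x)
cka_02 = linear_cka(acts["0"], acts["2"])
cka_00 = linear_cka(acts["0"], acts["0"])
print(cka_02.item(), cka_00.item())
```

Repeating the last step for every pair of layers produces the matrix of CKA values between layers.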

🖼️ Plots

[Figure: two example CKA plots, "Model compared with itself" and "Different models compared".]

📚 Bibliography

[1] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International Conference on Machine Learning. PMLR, 2019.

[2] Wang, Tinghua, Xiaolu Dai, and Yuze Liu. "Learning with Hilbert–Schmidt independence criterion: A review and new perspectives." Knowledge-based systems 234 (2021): 107567.

[3] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do wide and deep networks learn the same things? Uncovering how neural network representations vary with width and depth." arXiv preprint arXiv:2010.15327 (2020).

This project is also based on the following repositories:

📝 License

This project is MIT licensed.

