In [9]:
import numpy as np
import os, yaml, sys
# Select which section of the shared config to use ("dev" unless MY_ENV is set).
ENV = os.environ.get("MY_ENV", "dev")
# NOTE(review): the path is relative to the current working directory, so it only
# resolves when the notebook runs two directory levels below config.yaml.
with open("../../config.yaml", mode="r") as f:
    config = yaml.safe_load(f)
paths = config[ENV]["paths"]
# Put the project's source directory on the import path.
sys.path.append(paths["src_path"])
from sklearn.metrics.pairwise import pairwise_kernels
import torch
import torch.nn.functional as F

In [177]:
def center_gram(K):
    """Double-center a Gram matrix: return ``H @ K @ H`` with ``H = I - 11^T / n``.

    Computed by subtracting the column means and the row means and adding back
    the grand mean. This is algebraically identical to ``H @ K @ H`` but needs
    O(n^2) time and memory instead of an n x n centering matrix plus two
    O(n^3) matrix products.

    Parameters
    ----------
    K : ndarray of shape (n, n)
        Gram (kernel) matrix.

    Returns
    -------
    ndarray of shape (n, n)
        The centered Gram matrix (rows and columns sum to zero).
    """
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()

def hsic(K, L):
    """Hilbert-Schmidt Independence Criterion (biased estimator).

    Both Gram matrices are double-centered with ``center_gram`` and the
    statistic is ``trace(Kc @ Lc) / (n - 1)**2``.

    Parameters
    ----------
    K, L : ndarray of shape (n, n)
        Gram matrices computed over the same n samples.

    Returns
    -------
    numpy float
        The biased HSIC estimate.
    """
    n = K.shape[0]
    Kc, Lc = center_gram(K), center_gram(L)
    # trace(A @ B) == sum(A * B.T): same value without materializing the
    # n x n product, i.e. O(n^2) instead of O(n^3).
    return np.sum(Kc * Lc.T) / (n - 1) ** 2
    
def cka(X, Y, kernel="linear", **kwargs):
    """
    Centered Kernel Alignment (CKA) between two representations of the same samples.

    Parameters
    ----------
    X : ndarray of shape (n_samples, d1)
    Y : ndarray of shape (n_samples, d2)
    kernel : str or callable
        - "linear": Gram matrices are X @ X.T and Y @ Y.T
        - anything else is forwarded as ``metric`` to sklearn's
          ``pairwise_kernels`` (e.g. "rbf", "poly", or a callable)
    kwargs : extra keyword arguments forwarded to ``pairwise_kernels``
        (ignored for the linear kernel)

    Returns
    -------
    cka_value : float
        HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L) + 1e-12) using the biased
        ``hsic`` estimator; expected to lie in [0, 1] for positive
        semi-definite kernels.
    """
    if kernel != "linear":
        gram_x = pairwise_kernels(X, metric=kernel, **kwargs)
        gram_y = pairwise_kernels(Y, metric=kernel, **kwargs)
    else:
        gram_x, gram_y = X @ X.T, Y @ Y.T

    # Cross term first, then the two self-similarity terms used for normalization.
    cross = hsic(gram_x, gram_y)
    self_x = hsic(gram_x, gram_x)
    self_y = hsic(gram_y, gram_y)
    return cross / np.sqrt(self_x * self_y + 1e-12)

In [185]:
# Synthetic activations from two "layers": Y is a noisy linear function of X,
# so the two representations should come out as highly similar.
np.random.seed(0)  # make the random activations (and the printed values) reproducible
X = np.random.randn(3000, 1000)  # 3000 samples, 1000 neurons
Y = 5 * X + 2 * np.random.randn(3000, 1000)  # 3000 samples, 1000 neurons

# Linear CKA
print("Linear CKA:", cka(X, Y, kernel="linear"))

# RBF CKA
# NOTE(review): with gamma=1.0 and 1000-dimensional standard-normal inputs the
# squared pairwise distances are on the order of 2000, so exp(-gamma * d^2)
# is ~0 off the diagonal. Both Gram matrices are then ~identity and RBF CKA
# is ~1 regardless of how X and Y are related. Pick the bandwidth per
# representation (e.g. from the median pairwise distance) for a meaningful value.
print("RBF CKA:", cka(X, Y, kernel="rbf", gamma=1.0))

Linear CKA: 0.8965122658886783
RBF CKA: 0.9999955030298343


In [201]:
# TODO: check whether summing per-minibatch HSIC statistics (as done below) is
#       correct, or whether all Gram matrices should be computed first and a
#       single HSIC taken, as in the paper.
#       NOTE(review): accumulating *unbiased* per-batch HSIC and normalizing at
#       the end presumably matches the minibatch CKA of Nguyen, Raghu &
#       Kornblith (2021); confirm against the paper.
# TODO: check how the CKA trend compares with the other similarity measures.
def center_gram_torch(K):
    """Double-center a square Gram matrix (PyTorch version).

    Equivalent to ``H @ K @ H`` with ``H = I - 11^T / n``. It is computed by
    subtracting the column means and the row means and adding back the grand
    mean. The result therefore stays on ``K``'s device and in ``K``'s dtype.
    An explicitly built ``H`` defaults to float32 and cannot be matrix-multiplied
    with a float64 ``K``. This form also avoids two O(n^3) matrix products.

    Parameters
    ----------
    K : torch.Tensor of shape (n, n), floating-point dtype

    Returns
    -------
    torch.Tensor of shape (n, n), same dtype and device as ``K``
    """
    return K - K.mean(dim=0, keepdim=True) - K.mean(dim=1, keepdim=True) + K.mean()

def hsic_torch(K, L):
    """Biased HSIC estimate for two Gram tensors over the same n samples.

    Returns trace(Kc @ Lc) / (n - 1)^2, where Kc and Lc are the
    double-centered Gram matrices produced by ``center_gram_torch``.
    """
    n_samples = K.size(0)
    normalizer = (n_samples - 1) ** 2
    K_centered = center_gram_torch(K)
    L_centered = center_gram_torch(L)
    return torch.trace(K_centered @ L_centered) / normalizer

def cka_minibatch(X, Y, kernel="linear", **kwargs):
    """
    Per-minibatch HSIC statistics for minibatch CKA.

    Builds the Gram matrices of one minibatch and returns the three *unbiased*
    HSIC estimates needed for CKA. The cells below sum these over minibatches
    and combine the totals into a CKA value afterwards.

    Parameters
    ----------
    X : ndarray, shape (batch_size, d1)
    Y : ndarray, shape (batch_size, d2)
        Activations for the same batch_size samples. ``hsic_unbiased``
        requires batch_size >= 4.
    kernel : str
        "linear" computes the Gram matrices as X @ X.T and Y @ Y.T; any other
        value is passed as ``metric`` to sklearn's ``pairwise_kernels``.
    kwargs : extra keyword arguments for ``pairwise_kernels`` (e.g. gamma)

    Returns
    -------
    (hsic_xy, hsic_xx, hsic_yy) : tuple
        Unbiased HSIC(K, L), HSIC(K, K) and HSIC(L, L) for this minibatch.

    Notes
    -----
    ``hsic_unbiased`` is the NumPy implementation (it calls ``.copy()`` and
    ``np.fill_diagonal``), so pass NumPy arrays here, not torch tensors.
    """
    if kernel == "linear":
        K = X @ X.T
        L = Y @ Y.T
    else:
        # Use the function arguments, not globals named X_batch / Y_batch that
        # happen to exist in the calling cell.
        K = pairwise_kernels(X, metric=kernel, **kwargs)
        L = pairwise_kernels(Y, metric=kernel, **kwargs)

    hsic_xy = hsic_unbiased(K, L)
    hsic_xx = hsic_unbiased(K, K)
    hsic_yy = hsic_unbiased(L, L)
    return hsic_xy, hsic_xx, hsic_yy

In [205]:
def hsic_unbiased(K, L):
    """Unbiased HSIC estimator (Song et al. 2007) in NumPy.

    With K~ and L~ denoting K and L with their diagonals set to zero:

        HSIC = [ tr(K~ L~)
                 + (1^T K~ 1)(1^T L~ 1) / ((n-1)(n-2))
                 - 2/(n-2) * 1^T K~ L~ 1 ] / (n (n-3))

    Parameters
    ----------
    K, L : ndarray of shape (n, n)
        Gram matrices over the same n samples; the inputs are not modified.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If n < 4 (the (n-2) and (n-3) denominators require it).
    """
    n = K.shape[0]
    if n < 4:
        raise ValueError("Need at least 4 samples for unbiased HSIC")

    # Zero the diagonals on copies so the caller's matrices stay untouched.
    K_tilde = K.copy()
    L_tilde = L.copy()
    np.fill_diagonal(K_tilde, 0)
    np.fill_diagonal(L_tilde, 0)

    # All terms in O(n^2), without forming the n x n product K~ @ L~:
    #   tr(K~ L~)   = sum(K~ * L~^T)
    #   1^T K~ L~ 1 = (column sums of K~) . (row sums of L~)
    term1 = np.sum(K_tilde * L_tilde.T)
    term2 = K_tilde.sum() * L_tilde.sum() / ((n - 1) * (n - 2))
    term3 = 2 * (K_tilde.sum(axis=0) @ L_tilde.sum(axis=1)) / (n - 2)
    return float((term1 + term2 - term3) / (n * (n - 3)))

def cka_batch_collection(xy, xx, yy):
    """Turn accumulated HSIC statistics into a CKA value.

    Returns xy / (sqrt(xx) * sqrt(yy) + 1e-12). The small constant only guards
    against a zero denominator. NOTE(review): a negative xx or yy (possible
    with unbiased HSIC on tiny batches) makes np.sqrt return nan.
    """
    denominator = np.sqrt(xx) * np.sqrt(yy) + 1e-12
    return xy / denominator

In [206]:
# Minibatch CKA: shuffle the samples once, split them into consecutive
# minibatches, accumulate the per-batch unbiased HSIC statistics and combine
# them at the end. With batch_size == n_pts this is a single full batch, i.e.
# the exact unbiased CKA, which is a useful sanity check before shrinking batch_size.
n_pts = X.shape[0]
batch_size = n_pts
MIN_BATCH = 4  # hsic_unbiased raises for fewer than 4 samples
idx = np.random.choice(np.arange(n_pts), size=n_pts, replace=False)  # random permutation

xy, xx, yy = 0, 0, 0
for start in range(0, n_pts, batch_size):
    rand_idx = idx[start:start + batch_size]
    if len(rand_idx) < MIN_BATCH:
        # A trailing remainder this small would make hsic_unbiased raise;
        # its samples are left out of the estimate.
        continue
    X_batch, Y_batch = X[rand_idx, :], Y[rand_idx, :]
    xy_n, xx_n, yy_n = cka_minibatch(X_batch, Y_batch, kernel="linear")
    xy += xy_n
    xx += xx_n
    yy += yy_n

print("Estimated minibatch CKA:", cka_batch_collection(xy, xx, yy))

Estimated minibatch CKA: 0.8620060373731601


In [124]:
# Debug: end offsets of consecutive minibatches (the final boundary n_pts is not printed).
for batch_end in range(batch_size, n_pts, batch_size):
    print(batch_end)

100
200


In [94]:
n_pts = X.shape[0]
batch_size = 10
num_batches = n_pts // batch_size  # the trailing n_pts % batch_size samples are not used

# Running totals of the per-batch unbiased HSIC statistics (unshuffled batches).
xy_sum = xx_sum = yy_sum = 0.0

for i in range(num_batches):
    start, end = i * batch_size, (i + 1) * batch_size
    X_batch = X[start:end, :]
    Y_batch = Y[start:end, :]
    batch_xy, batch_xx, batch_yy = cka_minibatch(X_batch, Y_batch, kernel="linear")
    xy_sum += batch_xy
    xx_sum += batch_xx
    yy_sum += batch_yy

# Average over batches, then normalize. Mathematically the 1/num_batches
# factors cancel in the ratio.
xy_avg, xx_avg, yy_avg = (total / num_batches for total in (xy_sum, xx_sum, yy_sum))

cka_est = xy_avg / np.sqrt(xx_avg * yy_avg)
print(cka_est)

-0.00870454970006897


In [39]:
# Debug: print the [start, end) offsets of consecutive 10-sample windows.
prev_idx = 0
for i in range(10, 100, 10):
    print(f"{prev_idx} {i}")
    prev_idx = i
    

0 10
10 20
20 30
30 40
40 50
50 60
60 70
70 80
80 90


In [13]:
# Debug: range() with start=1, stop=40 (exclusive), step=7 yields 1, 8, 15, 22, 29, 36.
for i in range(1, 40, 7):
    print(i, end="\n")

1
8
15
22
29
36


In [93]:
import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels

def center_gram(K):
    """Double-center a Gram matrix: return ``H @ K @ H`` with ``H = I - 11^T / n``.

    Computed by subtracting the column means and the row means and adding back
    the grand mean. This is algebraically identical to ``H @ K @ H`` but needs
    O(n^2) time and memory instead of an n x n centering matrix plus two
    O(n^3) matrix products.

    Parameters
    ----------
    K : ndarray of shape (n, n)
        Gram (kernel) matrix.

    Returns
    -------
    ndarray of shape (n, n)
        The centered Gram matrix (rows and columns sum to zero).
    """
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()

def hsic_biased(K, L):
    """Biased empirical HSIC between two Gram matrices over the same n samples.

    Both matrices are double-centered with ``center_gram`` and the statistic
    is trace(Kc @ Lc) / (n - 1)^2.
    """
    Kc = center_gram(K)
    Lc = center_gram(L)
    scale = (K.shape[0] - 1) ** 2
    return np.trace(Kc @ Lc) / scale

def hsic_unbiased(K, L):
    """Unbiased HSIC estimator (Song et al. 2007).

    With K~ and L~ denoting K and L with their diagonals set to zero:

        HSIC = [ tr(K~ L~)
                 + (1^T K~ 1)(1^T L~ 1) / ((n-1)(n-2))
                 - 2/(n-2) * 1^T K~ L~ 1 ] / (n (n-3))

    Parameters
    ----------
    K, L : ndarray of shape (n, n)
        Gram matrices over the same n samples; the inputs are not modified.

    Returns
    -------
    numpy float

    Raises
    ------
    ValueError
        If n < 4 (the (n-2) and (n-3) denominators require it).
    """
    n = K.shape[0]
    if n < 4:
        raise ValueError("Need at least 4 samples for unbiased HSIC")

    # Zero out diagonals (on copies, so the caller's matrices stay untouched).
    K_tilde = K.copy()
    L_tilde = L.copy()
    np.fill_diagonal(K_tilde, 0)
    np.fill_diagonal(L_tilde, 0)

    # All terms in O(n^2), without forming the n x n product K~ @ L~:
    #   tr(K~ L~)   = sum(K~ * L~^T)
    #   1^T K~ L~ 1 = (column sums of K~) . (row sums of L~)
    term1 = np.sum(K_tilde * L_tilde.T)
    term2 = K_tilde.sum() * L_tilde.sum() / ((n - 1) * (n - 2))
    term3 = 2 * (K_tilde.sum(axis=0) @ L_tilde.sum(axis=1)) / (n - 2)

    return (term1 + term2 - term3) / (n * (n - 3))

def cka_exact(X, Y, kernel="linear", use_unbiased=False, **kwargs):
    """
    Exact (full-dataset) CKA between two representations.

    Parameters
    ----------
    X : ndarray of shape (n_samples, d1)
    Y : ndarray of shape (n_samples, d2)
    kernel : str
        "linear" uses X @ X.T and Y @ Y.T; other values are forwarded as
        ``metric`` to sklearn's ``pairwise_kernels``.
    use_unbiased : bool
        Use ``hsic_unbiased`` instead of ``hsic_biased``.
    kwargs : extra keyword arguments for ``pairwise_kernels``.

    Returns
    -------
    HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L) + 1e-12)
    """
    if kernel == "linear":
        gram_x, gram_y = X @ X.T, Y @ Y.T
    else:
        gram_x, gram_y = (pairwise_kernels(Z, metric=kernel, **kwargs) for Z in (X, Y))

    estimator = hsic_unbiased if use_unbiased else hsic_biased
    # Evaluated in the order HSIC(K, L), HSIC(K, K), HSIC(L, L).
    numerator = estimator(gram_x, gram_y)
    denominator = np.sqrt(estimator(gram_x, gram_x) * estimator(gram_y, gram_y) + 1e-12)
    return numerator / denominator

def cka_minibatch_approximation(X, Y, batch_size=50, n_batches=20, kernel="linear",
                              use_unbiased=True, **kwargs):
    """
    Minibatch approximation of CKA.

    Draws ``n_batches`` minibatches of ``batch_size`` samples from the global
    NumPy RNG. Each minibatch is sampled without replacement, but different
    minibatches may overlap. The function averages the per-batch HSIC(K, L),
    HSIC(K, K) and HSIC(L, L) estimates and normalizes the averages into a
    CKA value.

    Parameters
    ----------
    X : ndarray of shape (n_samples, d1)
    Y : ndarray of shape (n_samples, d2)
    batch_size : int
        Samples per minibatch. Must not exceed n_samples, and must be >= 4
        when ``use_unbiased`` is True.
    n_batches : int
        Number of minibatches to average over; must be >= 1.
    kernel : str
        "linear" or any metric accepted by sklearn's ``pairwise_kernels``.
    use_unbiased : bool
        Use ``hsic_unbiased`` (default) instead of ``hsic_biased``. The biased
        estimator's bias depends on the number of samples it sees, so averaged
        small-batch values need not match the full-data biased CKA.
    kwargs : extra keyword arguments for ``pairwise_kernels``.

    Returns
    -------
    float-like
        hsic_xy_avg / sqrt(hsic_xx_avg * hsic_yy_avg + 1e-12)

    Raises
    ------
    ValueError
        If ``n_batches < 1``. Otherwise a negative value would silently give
        0.0 and zero would give a bare ZeroDivisionError.
    """
    if n_batches < 1:
        raise ValueError("n_batches must be at least 1")

    n_samples = X.shape[0]

    # Choose HSIC estimator
    hsic_func = hsic_unbiased if use_unbiased else hsic_biased

    def _gram(Z):
        # Gram matrix of one minibatch under the requested kernel.
        if kernel == "linear":
            return Z @ Z.T
        return pairwise_kernels(Z, metric=kernel, **kwargs)

    hsic_xy_sum = 0
    hsic_xx_sum = 0
    hsic_yy_sum = 0

    for _ in range(n_batches):
        # Sample without replacement for this minibatch
        indices = np.random.choice(n_samples, size=batch_size, replace=False)
        K_batch = _gram(X[indices])
        L_batch = _gram(Y[indices])

        # Accumulate HSIC estimates
        hsic_xy_sum += hsic_func(K_batch, L_batch)
        hsic_xx_sum += hsic_func(K_batch, K_batch)
        hsic_yy_sum += hsic_func(L_batch, L_batch)

    # Average and compute CKA
    hsic_xy_avg = hsic_xy_sum / n_batches
    hsic_xx_avg = hsic_xx_sum / n_batches
    hsic_yy_avg = hsic_yy_sum / n_batches

    return hsic_xy_avg / np.sqrt(hsic_xx_avg * hsic_yy_avg + 1e-12)

# Test the implementation
# Compare exact and minibatch CKA on independent random representations.
np.random.seed(42)  # For reproducibility
X = np.random.randn(1000, 128)  # Larger dataset
Y = np.random.randn(1000, 256)


def _compare(use_unbiased, label):
    """Print exact vs. minibatch CKA for one HSIC estimator and return both values."""
    exact = cka_exact(X, Y, kernel="linear", use_unbiased=use_unbiased)
    print(f"Exact CKA ({label} HSIC): {exact:.6f}")
    approx = cka_minibatch_approximation(X, Y, batch_size=100, n_batches=50,
                                         kernel="linear", use_unbiased=use_unbiased)
    print(f"Minibatch CKA ({label} HSIC): {approx:.6f}")
    print(f"Difference: {abs(exact - approx):.6f}")
    return exact, approx


print("=== Comparison with BIASED HSIC ===")
exact_cka_biased, minibatch_cka_biased = _compare(False, "biased")

print("\n=== Comparison with UNBIASED HSIC ===")
exact_cka_unbiased, minibatch_cka_unbiased = _compare(True, "unbiased")

print("\n=== Convergence Test ===")
# Sweep batch size and number of batches against the exact unbiased CKA.
batch_sizes = [10, 25, 50, 100]
n_batches_list = [10, 20, 50, 100]

print("Batch size | N batches | Minibatch CKA | Difference from exact")
print("-" * 65)

for batch_size, n_batches in ((b, k) for b in batch_sizes for k in n_batches_list):
    if batch_size * n_batches > 10000:  # skip configurations above the compute budget
        continue
    approx_cka = cka_minibatch_approximation(X, Y, batch_size=batch_size,
                                             n_batches=n_batches, kernel="linear",
                                             use_unbiased=True)
    diff = abs(exact_cka_unbiased - approx_cka)
    print(f"{batch_size:10d} | {n_batches:9d} | {approx_cka:13.6f} | {diff:.6f}")

=== Comparison with BIASED HSIC ===
Exact CKA (biased HSIC): 0.151661
Minibatch CKA (biased HSIC): 0.636501
Difference: 0.484841

=== Comparison with UNBIASED HSIC ===
Exact CKA (unbiased HSIC): -0.000498
Minibatch CKA (unbiased HSIC): 0.000162
Difference: 0.000660

=== Convergence Test ===
Batch size | N batches | Minibatch CKA | Difference from exact
-----------------------------------------------------------------
        10 |        10 |      0.026911 | 0.027409
        10 |        20 |     -0.048918 | 0.048420
        10 |        50 |      0.024392 | 0.024890
        10 |       100 |     -0.001390 | 0.000891
        25 |        10 |     -0.019966 | 0.019467
        25 |        20 |     -0.015000 | 0.014501
        25 |        50 |      0.018446 | 0.018945
        25 |       100 |      0.002189 | 0.002687
        50 |        10 |      0.007660 | 0.008159
        50 |        20 |     -0.006197 | 0.005699
        50 |        50 |     -0.001846 | 0.001348
        50 |       100 |     