In [7]:
import torch
import torch.optim as optim

# Define the CCA loss function
def cca_loss(x, y):
    # Normalize the input data
    x_normalized = torch.nn.functional.normalize(x, dim=0)
    y_normalized = torch.nn.functional.normalize(y, dim=0)

    # Compute the covariance matrix of the normalized input data
    cov = torch.matmul(x_normalized.T, y_normalized)

    # Compute the singular value decomposition of the covariance matrix
    u, s, v = torch.svd(cov)

    # Compute the canonical correlation coefficients
    cca_coef = s[:min(x.shape[1], y.shape[1])]

    # Normalize the CCA coefficients
    cca_coef_norm = cca_coef / torch.max(cca_coef)

    # Compute the loss function
    loss = 1 - cca_coef_norm.sum()

    return loss

# Define the model and optimizer
x = torch.randn(100, 50)  # Fixed input X
y = torch.randn(100, 50, requires_grad=True)  # Model parameters Y

optimizer = optim.SGD([y], lr=0.1)

# Train the model to maximize the normalized CCA value
for i in range(1000):
    loss = cca_loss(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the normalized CCA value every 100 iterations
    if i % 100 == 0:
        cca_coef = torch.svd(torch.matmul(x.T, y))[1][:min(x.shape[1], y.shape[1])]
        cca_coef_norm = cca_coef / torch.max(cca_coef)
        print(f"Iteration {i}: Normalized CCA value = {cca_coef_norm.sum().item():.4f}")


Iteration 0: Normalized CCA value = 16.7968
Iteration 100: Normalized CCA value = 24.9161
Iteration 200: Normalized CCA value = 29.2382
Iteration 300: Normalized CCA value = 31.9819
Iteration 400: Normalized CCA value = 34.1554
Iteration 500: Normalized CCA value = 35.8794
Iteration 600: Normalized CCA value = 37.1309
Iteration 700: Normalized CCA value = 38.3227
Iteration 800: Normalized CCA value = 39.2704
Iteration 900: Normalized CCA value = 39.3329


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the deep neural network model
class Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the CCA loss function
def cca_loss(x, y):
    # Normalize the input data
    x_normalized = torch.nn.functional.normalize(x, dim=0)
    y_normalized = torch.nn.functional.normalize(y, dim=0)

    # Compute the covariance matrix of the normalized input data
    cov = torch.matmul(x_normalized.T, y_normalized)

    # Compute the singular value decomposition of the covariance matrix
    u, s, v = torch.svd(cov)

    # Compute the canonical correlation coefficients
    cca_coef = s[:min(x.shape[1], y.shape[1])]

    # Normalize the CCA coefficients
    cca_coef_norm = cca_coef / torch.max(cca_coef)

    # Compute the loss function
    loss = 1 - cca_coef_norm.sum()

    return loss

# Define the data and model parameters
x = torch.randn(100, 50)
y = torch.randn(100, 50, requires_grad=True)

model = Model(input_size=50, hidden_size=100, output_size=50)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Train the model to maximize the CCA value
for i in range(1000):
    y_pred = model(x)
    loss = cca_loss(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the CCA value every 100 iterations
    if i % 100 == 0:
        cca_coef = torch.svd(torch.matmul(x.T, y_pred))[1][:min(x.shape[1], y_pred.shape[1])]
        cca_coef_norm = cca_coef / torch.max(cca_coef)
        print(f"Iteration {i}: CCA value = {cca_coef_norm.sum().item():.4f}")


  from .autonotebook import tqdm as notebook_tqdm


Iteration 0: CCA value = 8.0445
Iteration 100: CCA value = 15.9398
Iteration 200: CCA value = 16.8100
Iteration 300: CCA value = 17.0379
Iteration 400: CCA value = 17.4920
Iteration 500: CCA value = 17.1443
Iteration 600: CCA value = 16.9245
Iteration 700: CCA value = 17.2129
Iteration 800: CCA value = 17.3528
Iteration 900: CCA value = 16.8542


In [16]:
import torch
import numpy as np
from scipy.spatial.distance import pdist
import math
def rbf(X, sigma=None):
    GX = np.dot(X, X.T)
    KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
    if sigma is None:
        mdist = np.median(KX[KX != 0])
        sigma = math.sqrt(mdist)
    KX *= - 0.5 / (sigma * sigma)
    KX = np.exp(KX)
    return KX

def centering(K):
    n = K.shape[0]
    unit = np.ones([n, n])
    I = np.eye(n)
    H = I - unit / n
    return np.dot(np.dot(H, K), H) 

def kernel_HSIC(X, Y, sigma):
    return np.sum(centering(rbf(X, sigma)) * centering(rbf(Y, sigma)))


def cka(X, Y, sigma=1e-2):
        hsic = kernel_HSIC(X, Y, sigma)
        var1 = np.sqrt(kernel_HSIC(X, X, sigma))
        var2 = np.sqrt(kernel_HSIC(Y, Y, sigma))

        return hsic / (var1 * var2)
    
X = torch.randn(100, 10)
Y = torch.randn(100, 10)

cka_value = cka(X, Y)

def objective(Y, X=X, gamma=1e-2):
    K_xx = kernel_HSIC(X, X, gamma)
    K_yy = kernel_HSIC(Y, Y, gamma)
    K_xy = kernel_HSIC(X, Y, gamma)
    cka = torch.mean(K_xy) / torch.sqrt(torch.mean(K_xx) * torch.mean(K_yy))
    return -cka

from torch.optim import Adam

Y = torch.randn(100, 10, requires_grad=True)
optimizer = Adam([Y], lr=1e-2)

for i in range(10):
    loss = objective(Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss)

Y = Y.detach().numpy()




RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [13]:
import sys
sys.path.append("..")
from CKA_utils.CKA import CKA, CudaCKA
np_cka = CKA()
print('RBF Kernel CKA, between diff subset: {}'.format(np_cka.kernel_CKA(X.detach().numpy(), Y)))

RBF Kernel CKA, between diff subset: 0.15928442998902606


In [7]:
cka_value = cka(X, torch.Tensor(Y))

In [8]:
cka_value

0.9999964833259583

In [9]:
X = torch.randn(100, 10)
Y = torch.randn(100, 10)

In [11]:
cka_value = cka(X, Y)
cka_value

0.996455729007721

In [23]:
####Step 1: Compute the Gram matrices of X and Y using a kernel function. 
#### The kernel function can be any function that measures the similarity between two inputs. In this case, we will use the Gaussian kernel:
import torch
import torch.nn as nn
import torch.optim as optim
def gram_matrix(x, sigma=1):
    n = x.size(0)
    x = x.view(n, -1)
    gram = torch.mm(x, x.t())
    gram = torch.exp(-torch.pow(gram, 2) / (2 * sigma**2))
    return gram
#### Step 2: Compute the centered Gram matrices of X and Y:
def centered_gram_matrix(x, sigma=1):
    n = x.size(0)
    gram = gram_matrix(x, sigma)
    ones = torch.ones(n, n) / n
    centered_gram = gram - torch.mm(ones, gram) - torch.mm(gram, ones) + torch.mm(ones, torch.mm(gram, ones))
    return centered_gram

### Step 3: Compute the Hilbert-Schmidt Independence Criterion (HSIC) between X and Y:
def hsic(x, y, sigma=1):
    x_centered = centered_gram_matrix(x, sigma)
    y_centered = centered_gram_matrix(y, sigma)
    hsic = torch.trace(torch.mm(x_centered, y_centered))
    return hsic
#### Step 4: Compute the maximum mean discrepancy (MMD) between X and Y using the same kernel function:
def mmd(x, y, sigma=1):
    x_gram = gram_matrix(x, sigma)
    y_gram = gram_matrix(y, sigma)
    mmd = torch.mean(x_gram) - 2 * torch.mean(torch.mm(x_gram, y_gram)) + torch.mean(y_gram)
    return mmd
### Step5 Step 5: Compute the squared centered kernel alignment (CKA) between X and Y using HSIC and MMD:

def cka(x, y, sigma=1):
    hsic_xy = hsic(x, y, sigma)
    hsic_xx = hsic(x, x, sigma)
    hsic_yy = hsic(y, y, sigma)
    mmd_xy = mmd(x, y, sigma)
    mmd_xx = mmd(x, x, sigma)
    mmd_yy = mmd(y, y, sigma)
    cka = torch.pow(hsic_xy, 2) / (torch.pow(hsic_xx, 2) * torch.pow(hsic_yy, 2))
    cka /= torch.pow(mmd_xy, 2) / (torch.pow(mmd_xx, 2) * torch.pow(mmd_yy, 2))
    return cka
x = torch.randn(100, 10)
y = torch.randn(100, 10)
cka_value = cka(x, y)


class MyModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        y = self.fc(x)
        return y

class CKALoss(nn.Module):
    def __init__(self, sigma=1):
        super(CKALoss, self).__init__()
        self.sigma = sigma

    def forward(self, x, y):
        x_centered = centered_gram_matrix(x, self.sigma)
        y_centered = centered_gram_matrix(y, self.sigma)
        hsic_xy = torch.trace(torch.mm(x_centered, y_centered))
        hsic_xx = torch.trace(torch.mm(x_centered, x_centered))
        hsic_yy = torch.trace(torch.mm(y_centered, y_centered))
        mmd_xy = mmd(x, y, self.sigma)
        mmd_xx = mmd(x, x, self.sigma)
        mmd_yy = mmd(y, y, self.sigma)
        cka = torch.pow(hsic_xy, 2) / (torch.pow(hsic_xx, 2) * torch.pow(hsic_yy, 2))
        cka /= torch.pow(mmd_xy, 2) / (torch.pow(mmd_xx, 2) * torch.pow(mmd_yy, 2))
        loss = 1 - cka
        return loss
# x = torch.randn(1000, 50)
# y = torch.randn(1000, 50)
# model = MyModel(10, 10)
cka_loss = CKALoss(sigma=1)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# for epoch in range(100):
#     # Forward pass
#     y_pred = model(x)
#     # Compute CKA loss
#     loss = cka_loss(x, y_pred)
#     # Backward pass
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     # Print loss
#     print(f"Epoch [{epoch+1}/{1000}], Loss: {loss.item():.4f}")
X = torch.randn(100, 10)
Y = torch.randn(100, 10, requires_grad=True)
optimizer = Adam([Y], lr=1e-2)

for i in range(10):
    loss = cka_loss(X, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss)

Y = Y.detach().numpy()



tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)
tensor(1.0000, grad_fn=<RsubBackward1>)


In [18]:
import sys
sys.path.append("..")
from CKA_utils.CKA import CKA, CudaCKA
np_cka = CKA()
print('RBF Kernel CKA, between same subset: {}'.format(np_cka.kernel_CKA(X.detach().numpy(), Y.detach().numpy())))

RBF Kernel CKA, between same subset: 0.16438511925374072


In [19]:
cka(X,Y)

tensor(2.5119e-07, grad_fn=<DivBackward0>)

In [22]:
cka(X,y_pred)

tensor(4.9260e-05, grad_fn=<DivBackward0>)

In [27]:
import torch
import numpy as np
from scipy.spatial.distance import pdist

def rbf_kernel(X, Y, gamma):
    X_norms = (X ** 2).sum(dim=1, keepdim=True)
    Y_norms = (Y ** 2).sum(dim=1, keepdim=True)
    K = torch.exp(-gamma * (X_norms + Y_norms.T - 2 * torch.mm(X, Y.T)))
    return K


def cka(X, Y, gamma=1e-2):
    K_xx = rbf_kernel(X, X, gamma)
    K_yy = rbf_kernel(Y, Y, gamma)
    K_xy = rbf_kernel(X, Y, gamma)
    cka = torch.mean(K_xy) / torch.sqrt(torch.mean(K_xx) * torch.mean(K_yy))
    return cka.item()
    
X = torch.randn(100, 10)
Y = torch.randn(100, 10)

cka_value = cka(X, Y)

def objective(Y, X=X, gamma=1e-2):
    K_xx = rbf_kernel(X, X, gamma)
    K_yy = rbf_kernel(Y, Y, gamma)
    K_xy = rbf_kernel(X, Y, gamma)
    cka = torch.mean(K_xy) / torch.sqrt(torch.mean(K_xx) * torch.mean(K_yy))
    return -cka

from torch.optim import Adam

Y = torch.randn(100, 10, requires_grad=True)
optimizer = Adam([Y], lr=1e-2)

for i in range(1000):
    loss = objective(Y)
    optimizer.zero_grad()
    loss.backward()
    print('loss',loss)
    optimizer.step()

Y = Y.detach().numpy()
