## Canonical Correlation Analysis (CCA)

Canonical Correlation Analysis (CCA) compares two sets of variables measured on the same samples. Here these are two representations \(R, R' \in \mathbb{R}^{N \times D}\) of the same \(N\) inputs (called `X` and `Y` in the code), assumed to be mean-centered. CCA finds weight vectors \(w_{R} \in \mathbb{R}^{D}\) and \(w_{R'} \in \mathbb{R}^{D}\) whose projections \(R w_R\) and \(R' w_{R'}\) are maximally correlated. The problem has a closed-form solution via an eigendecomposition or SVD, so no iterative training is required:

$$\rho := \rho(R, R') := \max_{w_R, w_{R'}} \frac{\langle R w_R, R' w_{R'} \rangle}{\lVert R w_R \rVert \cdot \lVert R' w_{R'} \rVert}$$

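The value \(\rho\) is the correlation between the two linear combinations \(R w_R\) and \(R' w_{R'}\). Repeating the maximization, each time requiring the new projection of \(R\) (and of \(R'\)) to be uncorrelated with the earlier ones, yields \(D\) canonical correlations \(1 \ge \rho_1 \ge \rho_2 \ge \dots \ge \rho_D \ge 0\), where \(\rho_1 = \rho\). Values close to 1 mean that \(R\) and \(R'\) share strongly correlated linear directions; if all \(\rho_i\) are close to 1, the two representations span nearly the same subspace and are similar up to a linear transformation.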
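
To make the definition concrete, the sketch below evaluates the objective for arbitrary weight vectors and compares it with the first canonical correlation, obtained in closed form by whitening both representations and taking the largest singular value of their cross-covariance. It uses NumPy rather than the notebook's PyTorch to stay dependency-light; the helpers `correlation_objective`, `inv_sqrt` and `first_canonical_correlation`, as well as the synthetic data, are illustrative assumptions.

```python
import numpy as np

def correlation_objective(R, R_prime, w_R, w_Rp):
    """Cosine between R w_R and R' w_R' (a correlation when R and R' are mean-centered)."""
    a, b = R @ w_R, R_prime @ w_Rp
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def inv_sqrt(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(C)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T

def first_canonical_correlation(R, R_prime):
    """Largest canonical correlation and the weight vectors that attain it."""
    Wx, Wy = inv_sqrt(R.T @ R), inv_sqrt(R_prime.T @ R_prime)
    U, S, Vt = np.linalg.svd(Wx @ R.T @ R_prime @ Wy)
    # Map the top singular vectors back from whitened space to weights on R and R'
    return S[0], Wx @ U[:, 0], Wy @ Vt[0]

# Hypothetical data: R' is a noisy linear function of R
rng = np.random.default_rng(0)
R = rng.standard_normal((200, 5))
R_prime = R @ rng.standard_normal((5, 5)) + rng.standard_normal((200, 5))
R, R_prime = R - R.mean(axis=0), R_prime - R_prime.mean(axis=0)  # mean-center

rho, w_R, w_Rp = first_canonical_correlation(R, R_prime)
best_random = max(
    correlation_objective(R, R_prime, rng.standard_normal(5), rng.standard_normal(5))
    for _ in range(1000)
)
print(rho, correlation_objective(R, R_prime, w_R, w_Rp), best_random)
```

The random-search maximum stays at or below `rho`, while the closed-form weights attain it (up to floating-point error).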

Two metrics summarize the canonical correlations into a single similarity score:

1. **Mean canonical correlation (standard CCA):**
   
   $$m_{CCA}(R, R') = \frac{1}{D} \sum_{i=1}^D \rho_i$$

2. **Yanai's Generalized Coefficient of Determination:**

   $$m_{R^2_{CCA}}(R, R') = \frac{1}{D} \sum_{i=1}^D \rho_i^2$$

Both metrics condense all \(D\) canonical correlations into a single score in \([0, 1]\), where higher values mean more similar representations.
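
As a small worked example with made-up canonical correlations \(\rho = (0.9, 0.5, 0.1)\), the standard metric is \((0.9 + 0.5 + 0.1)/3 = 0.5\) and Yanai's coefficient is \((0.81 + 0.25 + 0.01)/3 \approx 0.357\). Because \(0 \le \rho_i \le 1\) implies \(\rho_i^2 \le \rho_i\), the squared variant is never larger and shrinks weak correlations more than strong ones. A minimal NumPy sketch (the helper names are hypothetical):

```python
import numpy as np

def mean_cca(rho):
    """Standard metric: average of the canonical correlations."""
    return float(np.mean(rho))

def yanai_cca(rho):
    """Yanai's generalized coefficient of determination: average of squared correlations."""
    return float(np.mean(np.square(rho)))

rho = np.array([0.9, 0.5, 0.1])  # illustrative values, D = 3
print(mean_cca(rho))   # (0.9 + 0.5 + 0.1) / 3
print(yanai_cca(rho))  # (0.81 + 0.25 + 0.01) / 3, roughly 0.357
```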

```python
# Implementation in PyTorch
import torch

def mean_center(X):
    """Subtract each feature's (column's) mean; CCA assumes mean-centered representations."""
    return X - X.mean(dim=0)

def compute_covariance_matrices(X, Y):
    """Sample covariance matrices of already mean-centered X and Y."""
    n_samples = X.size(0)
    sigma_X = X.T @ X / (n_samples - 1)
    sigma_Y = Y.T @ Y / (n_samples - 1)
    sigma_XY = X.T @ Y / (n_samples - 1)
    return sigma_X, sigma_Y, sigma_XY

def compute_inverse_sqrt_matrix(matrix):
    """Inverse square root of a symmetric positive semi-definite matrix."""
    eigvals, eigvecs = torch.linalg.eigh(matrix)
    eigvals = torch.clamp(eigvals, min=1e-10)  # avoid division by zero for (near-)singular matrices
    return eigvecs @ torch.diag(1.0 / torch.sqrt(eigvals)) @ eigvecs.T

def compute_cca(X, Y, output_dim):
    """Return the top `output_dim` canonical variates of X and Y and their canonical correlations."""
    # Center the data
    X_centered = mean_center(X)
    Y_centered = mean_center(Y)

    # Compute covariance matrices
    sigma_X, sigma_Y, sigma_XY = compute_covariance_matrices(X_centered, Y_centered)

    # Whitening transforms Sigma_X^{-1/2} and Sigma_Y^{-1/2}
    sqrt_inv_sigma_X = compute_inverse_sqrt_matrix(sigma_X)
    sqrt_inv_sigma_Y = compute_inverse_sqrt_matrix(sigma_Y)

    # Cross-covariance of the whitened data
    T = sqrt_inv_sigma_X @ sigma_XY @ sqrt_inv_sigma_Y

    # Its singular values are the canonical correlations, in descending order
    # (torch.svd is deprecated; torch.linalg.svd returns V transposed)
    U, S, Vh = torch.linalg.svd(T, full_matrices=False)
    V = Vh.T

    # Keep only the top output_dim components
    U, S, V = U[:, :output_dim], S[:output_dim], V[:, :output_dim]

    # Canonical variates: the centered data projected onto the canonical directions
    X_c = X_centered @ sqrt_inv_sigma_X @ U
    Y_c = Y_centered @ sqrt_inv_sigma_Y @ V
    return X_c, Y_c, S

def compute_standard_cca(X, Y):
    """m_CCA: mean of all min(D_X, D_Y) canonical correlations."""
    _, _, S = compute_cca(X, Y, min(X.size(1), Y.size(1)))
    return float(S.mean())

def compute_yanai_cca(X, Y):
    """Yanai's generalized coefficient of determination: mean squared canonical correlation."""
    _, _, S = compute_cca(X, Y, min(X.size(1), Y.size(1)))
    return float((S ** 2).mean())
```

```python
# Example usage: two independent random representations
X = torch.rand(100, 10)  # 100 samples, 10 features
Y = torch.rand(100, 10)  # 100 samples, 10 features

# Number of canonical components to keep (here all of them)
output_dim = min(X.size(1), Y.size(1))

# Perform CCA
X_c, Y_c, canonical_correlations = compute_cca(X, Y, output_dim)

# The canonical correlations (rho) are the singular values S from the SVD.
# X and Y are independent, so the nonzero values here reflect finite-sample noise.
rho = canonical_correlations
print("Canonical Correlations (rho):", rho)
```

```text
Canonical Correlations (rho): tensor([0.5262, 0.4921, 0.3742, 0.3318, 0.3189, 0.2431, 0.2374, 0.1553, 0.0787,
        0.0172])
```

```python
print(compute_standard_cca(X, Y))
print(compute_yanai_cca(X, Y))
```

```text
0.2774849832057953
0.10169340670108795
```
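
CCA is invariant to invertible linear transformations (and constant shifts) of either representation, so a representation and any invertible linear re-mixing of it have all canonical correlations equal to 1, whereas unrelated representations such as the independent random matrices above stay well below 1. The NumPy sketch below checks this numerically; the helper `canonical_correlations` is hypothetical and only mirrors the whitening and SVD steps of `compute_cca`, and the seeds and shapes are arbitrary choices.

```python
import numpy as np

def inv_sqrt(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(C)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T

def canonical_correlations(X, Y):
    """All canonical correlations of X and Y (rows = samples), in descending order."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    T = inv_sqrt(X.T @ X) @ X.T @ Y @ inv_sqrt(Y.T @ Y)
    return np.linalg.svd(T, compute_uv=False)

rng = np.random.default_rng(1)
X = rng.random((100, 10))
A = rng.standard_normal((10, 10))   # random linear map, invertible with probability 1
Y_mixed = X @ A + 3.0               # linear re-mixing plus a constant offset
Y_indep = rng.random((100, 10))     # unrelated representation

print(canonical_correlations(X, Y_mixed))  # all (numerically) 1
print(canonical_correlations(X, Y_indep))  # noticeably below 1
```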


## Singular Vector CCA (SVCCA)

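SVCCA first reduces each representation to the directions that explain most of its variance (here 99%) and then applies CCA to the reduced representations; discarding low-variance directions, which are often dominated by noise, makes the comparison more robust. The reduction uses the eigenvectors of each representation's covariance matrix. Intuitively, an eigenvector of a matrix is a vector whose direction is unchanged when the matrix is applied to it; only its length is scaled, by the corresponding eigenvalue. For a covariance matrix, the eigenvectors are the principal directions and the eigenvalues are the variances along them.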
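
A minimal NumPy illustration of both ideas: an eigenvector is only rescaled by its matrix, and the number of retained directions is the smallest \(k\) whose cumulative share of variance reaches the threshold, mirroring the `searchsorted` step of the `pca` helper below. The matrix and the eigenvalues are made-up examples.

```python
import numpy as np

# A symmetric matrix, e.g. a covariance matrix
C = np.array([[2.0, 1.0],
              [1.0, 2.0]])
eigvals, eigvecs = np.linalg.eigh(C)  # eigenvalues in ascending order: 1 and 3
v = eigvecs[:, 1]                     # eigenvector for eigenvalue 3, along (1, 1) up to sign
print(C @ v, 3 * v)                   # identical: C only rescales v

# Choosing how many principal directions to keep for a 99% variance threshold
variances = np.array([5.0, 3.0, 1.5, 0.45, 0.05])    # illustrative eigenvalues, descending
cumulative = np.cumsum(variances) / variances.sum()  # 0.5, 0.8, 0.95, 0.995, 1.0
k = int(np.searchsorted(cumulative, 0.99)) + 1       # first index reaching 99%, plus one
print(k)  # 4
```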

```python
def pca(X, variance_retained=0.99):
    """Project mean-centered X onto the principal directions retaining `variance_retained` of its variance."""
    X_centered = X - X.mean(dim=0)

    # Compute covariance matrix
    covariance_matrix = (X_centered.T @ X_centered) / (X_centered.size(0) - 1)

    # Compute eigenvalues and eigenvectors (eigh returns them in ascending order)
    eigvals, eigvecs = torch.linalg.eigh(covariance_matrix)

    # Sort eigenvalues and eigenvectors in descending order
    sorted_indices = torch.argsort(eigvals, descending=True)
    eigvals = eigvals[sorted_indices]
    eigvecs = eigvecs[:, sorted_indices]

    # Compute the cumulative explained variance
    explained_variance = eigvals / eigvals.sum()
    cumulative_variance = torch.cumsum(explained_variance, dim=0)

    # Smallest number of components whose cumulative variance reaches the threshold
    # (capped in case rounding keeps the last cumulative value just below it)
    num_components = min(
        torch.searchsorted(cumulative_variance, variance_retained).item() + 1,
        eigvals.numel(),
    )

    # Project the data onto the principal components
    principal_components = eigvecs[:, :num_components]
    X_pca = X_centered @ principal_components

    return X_pca, principal_components

def compute_standard_svcca(X, Y, variance_retained=0.99):
    """SVCCA: mean canonical correlation of the PCA-reduced representations."""
    X_pca, _ = pca(X, variance_retained)
    Y_pca, _ = pca(Y, variance_retained)

    _, _, S = compute_cca(X_pca, Y_pca, min(X_pca.size(1), Y_pca.size(1)))
    # Average over the canonical correlations actually computed, not the original width D
    return float(S.mean())

def compute_yanai_svcca(X, Y, variance_retained=0.99):
    """Yanai's coefficient on the PCA-reduced representations."""
    X_pca, _ = pca(X, variance_retained)
    Y_pca, _ = pca(Y, variance_retained)

    _, _, S = compute_cca(X_pca, Y_pca, min(X_pca.size(1), Y_pca.size(1)))
    return float((S ** 2).mean())
```

```python
compute_standard_svcca(X, Y)
```

```text
0.27748483419418335
```

The score matches \(m_{CCA}\) above up to floating-point error. This suggests that all 10 directions of the random data were needed to reach 99% of the variance, in which case the PCA step is only a rotation, and CCA is invariant to invertible linear transformations.
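
The `pca` helper finds principal directions through an eigendecomposition of the covariance matrix. An equivalent route, and the one the "singular vector" in the name points to, is an SVD of the centered data matrix: its right singular vectors are the principal directions, and its squared singular values divided by \(N - 1\) are the variances along them. A NumPy sketch checking this equivalence, with `pca_via_svd` as a hypothetical helper and arbitrary synthetic data:

```python
import numpy as np

def pca_via_svd(X, variance_retained=0.99):
    """Project centered X onto its top principal directions, found via SVD."""
    Xc = X - X.mean(axis=0)
    # Xc = U diag(s) Vt; rows of Vt are the principal directions, sorted by s (descending)
    U, s, Vt = np.linalg.svd(Xc, full_matrices=False)
    variances = s ** 2 / (Xc.shape[0] - 1)
    cumulative = np.cumsum(variances) / variances.sum()
    k = min(int(np.searchsorted(cumulative, variance_retained)) + 1, len(s))
    return Xc @ Vt[:k].T, variances

rng = np.random.default_rng(2)
X = rng.standard_normal((100, 6)) @ rng.standard_normal((6, 6))

X_reduced, svd_variances = pca_via_svd(X)
Xc = X - X.mean(axis=0)
eig_variances = np.linalg.eigvalsh(Xc.T @ Xc / (Xc.shape[0] - 1))[::-1]  # descending
print(np.allclose(svd_variances, eig_variances))  # True
print(X_reduced.shape)
```

Working on the data matrix directly avoids forming \(X^\top X\), which squares the condition number; for well-conditioned data both routes agree.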

## Projection-Weighted CCA (PWCCA)
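
Projection-weighted CCA replaces the uniform average of \(m_{CCA}\) with a weighted average that emphasizes canonical directions accounting for more of the original representation. Writing \(h_i\) for the \(i\)-th canonical variate of \(R\) (the projection \(R w_R\) for the \(i\)-th pair of weights) and \(r_j\) for the \(j\)-th column (neuron) of the centered \(R\), the weighting as we understand the standard PWCCA definition is:

$$\tilde{\alpha}_i = \sum_{j=1}^{D} \lvert \langle h_i, r_j \rangle \rvert, \qquad \alpha_i = \frac{\tilde{\alpha}_i}{\sum_{k=1}^{D} \tilde{\alpha}_k}, \qquad m_{PWCCA}(R, R') = \sum_{i=1}^{D} \alpha_i \rho_i$$

The weights are computed on the \(R\) side only, so \(m_{PWCCA}(R, R')\) and \(m_{PWCCA}(R', R)\) generally differ.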

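```python
def compute_pwcca(X, Y, output_dim=None):
    """Projection-weighted CCA: weighted mean of the canonical correlations.

    Canonical correlation rho_i is weighted by alpha_i = sum_j |<h_i, x_j>|,
    where h_i is the i-th canonical variate of X (column i of X_c) and x_j is
    the j-th centered neuron (column) of X; the weights are normalized to sum to 1.
    """
    if output_dim is None:
        output_dim = min(X.size(1), Y.size(1))  # use all canonical components by default
    X_c, _, S = compute_cca(X, Y, output_dim)

    X_centered = mean_center(X)
    # (D, k) matrix of inner products <x_j, h_i>; sum absolute values over neurons j
    weights = (X_centered.T @ X_c).abs().sum(dim=0)

    # Normalize the weights
    weights = weights / weights.sum()
    # Compute the final PWCCA measure
    return float(torch.sum(weights * S))
```

```python
compute_pwcca(X, X)
```

```text
0.9999998211860657
```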
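
Because the weights come from the first argument only, PWCCA is not symmetric even though the canonical correlations themselves are. The self-contained NumPy sketch below (a re-implementation for illustration, with a hypothetical `pwcca` helper and synthetic data in which `Y` depends on only 4 of the 8 directions of `X`) compares both argument orders and confirms that comparing a representation with itself gives about 1.

```python
import numpy as np

def inv_sqrt(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(C)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T

def pwcca(X, Y):
    """Projection-weighted mean of canonical correlations, weights computed from X."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    Wx = inv_sqrt(X.T @ X)
    U, rho, Vt = np.linalg.svd(Wx @ X.T @ Y @ inv_sqrt(Y.T @ Y), full_matrices=False)
    H = X @ Wx @ U                       # canonical variates of X, one per column
    alpha = np.abs(X.T @ H).sum(axis=0)  # sum over neurons j of |<h_i, x_j>|
    alpha /= alpha.sum()                 # normalize weights to sum to 1
    return float(alpha @ rho)

rng = np.random.default_rng(3)
X = rng.standard_normal((200, 8))
Y = X[:, :4] @ rng.standard_normal((4, 8)) + 0.5 * rng.standard_normal((200, 8))

print(pwcca(X, X))               # about 1.0
print(pwcca(X, Y), pwcca(Y, X))  # generally different values
```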