<a href="https://colab.research.google.com/github/AyobamiMichael/KCICA/blob/main/Kernel_based_Contrastive_Independent_Component_Analysis_(KCICA).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
def rbf_kernel(x, sigma=1.0):
    """Build the Gaussian (RBF) Gram matrix of ``x`` against itself.

    Entry ``(i, j)`` equals ``exp(-d_ij ** 2 / (2 * sigma ** 2))`` where
    ``d_ij`` is the Euclidean distance between rows ``i`` and ``j`` of ``x``
    as returned by ``torch.cdist(x, x, p=2)``.

    Args:
        x: Tensor of samples in a layout accepted by ``torch.cdist``
            (rows are the vectors being compared).
        sigma: Kernel bandwidth. Defaults to 1.0.

    Returns:
        Tensor holding the pairwise kernel values.
    """
    bandwidth_term = 2 * sigma ** 2
    squared_distances = torch.cdist(x, x, p=2).pow(2)
    return torch.exp(squared_distances.neg() / bandwidth_term)

In [3]:
def hsic_loss(X, Y, sigma=1.0):
    """Compute the empirical HSIC between two feature maps X and Y.

    The value is ``trace(K @ H @ L @ H) / (n - 1) ** 2`` where ``K`` and ``L``
    are the RBF Gram matrices of ``X`` and ``Y`` and ``H = I - (1/n) * 1 1^T``
    is the centering matrix.

    Args:
        X: Tensor whose first dimension indexes the ``n`` samples.
        Y: Tensor with the same number of samples as ``X`` (the Gram matrices
            are multiplied together, so their sizes must match).
        sigma: Bandwidth forwarded to ``rbf_kernel`` for both Gram matrices.
            Defaults to 1.0, the bandwidth ``rbf_kernel`` uses by default.

    Returns:
        Scalar tensor with the HSIC estimate. Note that the normaliser
        ``(n - 1) ** 2`` is zero when ``n == 1``.
    """
    n = X.size(0)
    K = rbf_kernel(X, sigma)
    L = rbf_kernel(Y, sigma)
    # Centering matrix. It is created with X's dtype and device so that the
    # matmuls below also work for non-default dtypes (e.g. float64) and no
    # host-to-device copy is needed. Subtracting the scalar 1/n broadcasts to
    # the same result as subtracting (1/n) * ones((n, n)).
    H = torch.eye(n, dtype=X.dtype, device=X.device) - 1.0 / n
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2

In [5]:
class KCICA(nn.Module):
    """Small feed-forward encoder: Linear -> ReLU -> Linear.

    The layers are stored, in that order, in ``self.encoder``.

    Args:
        input_dim: Size of each input feature vector.
        output_dim: Size of each output feature vector.
        hidden_dim: Width of the single hidden layer. Defaults to 128.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the encoder and return its output."""
        encoded = self.encoder(x)
        return encoded

In [6]:
# Simulated Example
SEED = 0          # Seed for torch's global random number generator
N_SAMPLES = 100   # Number of simulated observations
N_FEATURES = 5    # Dimensionality of both the inputs and the recovered sources

# Seed before any random draw so that a fresh "Restart & Run All" produces the
# same simulated data and the same initial model weights.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The random tensors are drawn on the CPU and then moved to `device`.
X = torch.randn(N_SAMPLES, N_FEATURES).to(device)  # Simulated mixed signals
neg_samples = torch.randn(N_SAMPLES, N_FEATURES).to(device)  # Negative samples

# Model & Optimizer
model = KCICA(input_dim=N_FEATURES, output_dim=N_FEATURES).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)


In [7]:
# Training
# Runs 1000 full-batch Adam steps on the contrastive HSIC objective and prints
# the loss every 100 epochs.
# NOTE(review): hsic_loss(S, S) measures the dependence of S with itself, not
# the dependence between individual recovered components — confirm this is the
# intended objective.
for epoch in range(1000):
    optimizer.zero_grad()
    S = model(X)  # Recovered sources

    # Contrastive HSIC Loss: self term minus the term against negative samples.
    self_term = hsic_loss(S, S)
    negative_term = hsic_loss(S, neg_samples)
    loss = self_term - negative_term

    loss.backward()
    optimizer.step()

    # Only report on every 100th epoch.
    if epoch % 100 != 0:
        continue
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

print("Training Complete! 🚀")


Epoch 0, Loss: 0.0067
Epoch 100, Loss: -0.0012
Epoch 200, Loss: -0.0026
Epoch 300, Loss: -0.0033
Epoch 400, Loss: -0.0034
Epoch 500, Loss: -0.0035
Epoch 600, Loss: -0.0035
Epoch 700, Loss: -0.0035
Epoch 800, Loss: -0.0035
Epoch 900, Loss: -0.0035
Training Complete! 🚀
