In [None]:
import pandas as pd
from wassnmf.validation import generate_data
from  wassnmf.wassnmf import WassersteinNMF

In [None]:
# Development convenience: (re)load IPython's autoreload extension and use
# mode 2, so edited modules are reloaded before each cell runs.
%reload_ext autoreload
%autoreload 2


In [None]:
# Synthetic-data configuration passed to generate_data(scenario=...) further down.
scenario = dict(
    name="gaussian_mixture",
    n_samples=20,
    n_features=20,
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from geomloss import SamplesLoss  # Sinkhorn-Wasserstein loss

def to_tensor(x, requires_grad=False, dtype=torch.float32):
    """
    Copy array-like data into a new PyTorch tensor.

    x: data accepted by ``torch.tensor`` (e.g. a NumPy array or nested list)
    requires_grad: whether autograd should track operations on the result
    dtype: element type of the returned tensor (float32 by default)
    """
    converted = torch.tensor(x, requires_grad=requires_grad, dtype=dtype)
    return converted
# -------------------------------------------------------------------------
# Wasserstein Dictionary Learning
# -------------------------------------------------------------------------
def train_wdil(X, cost_matrix, n_components=5, lr=0.01, epochs=100,
               blur=0.025, entropy_weight=0.05, seed=None):
    """
    Train Wasserstein Dictionary Learning using Sinkhorn gradient descent with a precomputed cost matrix.

    Every column of ``X`` is treated as a histogram over ``n_features`` bins and is
    approximated by the matching column of ``D @ R``. The histograms are handed to
    GeomLoss as *weights* (measure mode) living on a shared set of dummy bin
    positions, so the Sinkhorn loss actually depends on ``D`` and ``R``. (Passing the
    two matrices as plain point clouds, together with a cost function that ignores
    its inputs, yields a loss that is constant in ``D @ R``: the dictionary would
    never receive a gradient.)

    ``D`` and ``R`` are parameterised through a column-wise softmax, which keeps
    them non-negative with unit-mass columns. Consequently every column of
    ``D @ R`` is a valid histogram and ``log(R)`` in the entropy term is always
    finite.

    X: (n_features, n_samples) non-negative histogram data; each column is rescaled
       to unit mass before being compared with the reconstruction
    cost_matrix: (n_features, n_features) cost for 1D transport. It is assumed to be
       symmetric, because GeomLoss evaluates both cost(x, y) and cost(y, x) and the
       callback returns the same matrix for both.
    n_components: number of dictionary atoms (columns of D)
    lr: Adam learning rate, applied to the softmax logits
    epochs: number of gradient steps
    blur: GeomLoss ``blur``; the entropic regularisation is ``blur ** 2`` (p = 2)
    entropy_weight: weight of ``sum(R * log R)``. This term is the negative entropy
       of R, so a positive weight favours smoother (higher-entropy) coefficients and
       a negative weight favours sparser ones.
    seed: optional integer seed for the random initialisation; ``None`` keeps using
       NumPy's global random state

    Returns ``(D, R)`` as NumPy arrays of shape (n_features, n_components) and
    (n_components, n_samples), each with non-negative columns summing to 1.

    Raises ValueError if ``cost_matrix`` is not (n_features, n_features).
    """
    # Convert data to PyTorch tensors
    X_torch = to_tensor(X, requires_grad=False)                      # (n_features, n_samples)
    cost_matrix_torch = to_tensor(cost_matrix, requires_grad=False)  # (n_features, n_features)

    n_features, n_samples = X_torch.shape
    if tuple(cost_matrix_torch.shape) != (n_features, n_features):
        raise ValueError(
            f"cost_matrix must have shape ({n_features}, {n_features}), "
            f"got {tuple(cost_matrix_torch.shape)}"
        )

    # One histogram per row, rescaled to unit mass (balanced OT compares equal masses).
    targets = X_torch.T
    targets = (targets / targets.sum(dim=1, keepdim=True).clamp_min(1e-12)).contiguous()

    # Dummy support shared by all histograms: one "point" per bin. The coordinates
    # are never used to compute costs; they only tell GeomLoss how many bins exist.
    bins = (
        torch.arange(n_features, dtype=X_torch.dtype)
        .view(1, n_features, 1)
        .repeat(n_samples, 1, 1)
    )  # (n_samples, n_features, 1)

    def cost_fn(x, y):
        """
        x, y: (batch_size, n_features, 1) bin positions supplied by GeomLoss.
        Returns the precomputed cost matrix once per batch element,
        shape (batch_size, n_features, n_features).
        """
        return cost_matrix_torch.unsqueeze(0).expand(x.shape[0], -1, -1)

    # GeomLoss anneals epsilon from diameter**2 down to blur**2. The dummy positions
    # say nothing about the scale of the costs, so derive the diameter from the cost
    # matrix itself (never below `blur`, which also avoids log(0) for an all-zero cost).
    diameter = max(max(float(cost_matrix_torch.max()), 0.0) ** 0.5, blur)

    # Define Sinkhorn loss using our custom cost function
    sinkhorn_loss = SamplesLoss(
        loss="sinkhorn",
        cost=cost_fn,   # <--- pass the function, NOT the tensor
        blur=blur,
        diameter=diameter,
        debias=False
    )

    # Initialize Dictionary (D) and Coeffs (R) randomly. The logits are the log of
    # |N(0, 1)| draws, so softmax(logits) equals those draws normalised per column.
    rng = np.random if seed is None else np.random.RandomState(seed)
    D_logits = to_tensor(
        np.log(np.abs(rng.randn(n_features, n_components)) + 1e-9), requires_grad=True
    )
    R_logits = to_tensor(
        np.log(np.abs(rng.randn(n_components, n_samples)) + 1e-9), requires_grad=True
    )

    # Optimizer
    optimizer = optim.Adam([D_logits, R_logits], lr=lr)

    for epoch in range(epochs):
        optimizer.zero_grad()

        # Simplex-constrained factors and reconstruction
        D = torch.softmax(D_logits, dim=0)  # (n_features, n_components)
        R = torch.softmax(R_logits, dim=0)  # (n_components, n_samples)
        X_hat = D @ R                       # (n_features, n_samples), unit-mass columns

        # Measure mode: (weights, positions, weights, positions) with a leading batch
        # dimension of n_samples. GeomLoss returns one value per sample; sum them.
        transport = sinkhorn_loss(targets, bins, X_hat.T.contiguous(), bins).sum()

        # Entropic regulariser on R (negative entropy; see `entropy_weight` above)
        loss = transport + entropy_weight * torch.sum(R * torch.log(R + 1e-9))

        # Backprop
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss = {loss.item():.4f}")

    # Return learned dictionary + coefficients in NumPy (after the final step)
    with torch.no_grad():
        D_final = torch.softmax(D_logits, dim=0)
        R_final = torch.softmax(R_logits, dim=0)
    return D_final.numpy(), R_final.numpy()


In [None]:
# Generate the synthetic dataset described by `scenario`.
# X, K and cost_matrix are consumed by the cells below; coord is unpacked but
# not used in the cells visible here.
# NOTE(review): presumably K is the kernel expected by WassersteinNMF and
# cost_matrix the matching ground cost — confirm against generate_data.
X, K, coord, cost_matrix = generate_data(scenario=scenario)

In [None]:
# Reference model: fit WassersteinNMF with 5 components on X and K.
# The two returned factors are multiplied back together further down
# (D_wass @ Lambda_wass) to form the reconstruction.
wnmf = WassersteinNMF(n_components=5, verbose=True)
D_wass, Lambda_wass = wnmf.fit_transform(X, K)

In [None]:
# Sinkhorn-gradient baseline: learn a 5-atom dictionary with train_wdil.
# NOTE(review): X and cost_matrix are passed transposed here, whereas
# WassersteinNMF and the heatmap comparison below use X untransposed.
# train_wdil documents X as (n_features, n_samples); confirm which orientation
# generate_data returns so that D_learned @ R_learned is comparable to X.
D_learned, R_learned = train_wdil(X.T, cost_matrix.T, n_components=5, lr=0.01, epochs=100)
print("D_learned shape:", D_learned.shape)
print("R_learned shape:", R_learned.shape)

In [None]:
import seaborn as sns

In [None]:
# Reconstruction from the factors returned by train_wdil (NumPy arrays).
X_pred = np.matmul(D_learned, R_learned)

In [None]:
# Reconstruction from the WassersteinNMF factors.
X_wassnmf = D_wass @ Lambda_wass

In [None]:
# Titled so this figure can be told apart from the reconstructions;
# the trailing ';' hides the text repr of the returned object.
sns.heatmap(X).set_title("Input data X");

In [None]:
# Titled so this figure can be told apart from the input and the other
# reconstruction; the trailing ';' hides the text repr of the returned object.
sns.heatmap(X_pred).set_title("train_wdil reconstruction (D_learned @ R_learned)");

In [None]:
# Titled so this figure can be told apart from the input and the other
# reconstruction; the trailing ';' hides the text repr of the returned object.
sns.heatmap(X_wassnmf).set_title("WassersteinNMF reconstruction (D_wass @ Lambda_wass)");

In [None]:
# First factor returned by WassersteinNMF.fit_transform; titled for readability,
# with a trailing ';' to hide the text repr of the returned object.
sns.heatmap(D_wass).set_title("WassersteinNMF factor D_wass");

In [None]:
# Second factor returned by WassersteinNMF.fit_transform; titled for readability,
# with a trailing ';' to hide the text repr of the returned object.
sns.heatmap(Lambda_wass).set_title("WassersteinNMF factor Lambda_wass");