# In [1]:
"""
Minimal, runnable PyTorch implementation of the COPT distance between two graphs
(based on the closed-form objective over P with Laplacian pseudoinverses).

It optimizes the N x M transport matrix P under nonnegativity and sum
constraints (rows sum to M, columns sum to N), applying a multiplicative
Sinkhorn-style projection before each optimizer step.

Usage (example at bottom):
- Provide adjacency matrices A_x (N x N) and A_y (M x M) as torch tensors.
- Call copt_distance(A_x, A_y, ...). Returns (distance_value, P_opt).

Notes:
- The pseudoinverse L^+ is computed by eigendecomposition: (near-)zero
  eigenvalues, including the one belonging to the constant vector, are dropped
  and the remaining eigenvalues are inverted.
- tr(sqrt(S)) is computed as the sum of the square roots of the eigenvalues of
  the symmetric PSD matrix S.
- For numerical stability we clamp tiny eigenvalues and add a tiny epsilon
  inside the Sinkhorn scaling.
"""

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F

def laplacian_from_adj(A: torch.Tensor) -> torch.Tensor:
    """Build the combinatorial (unnormalized) graph Laplacian ``L = D - A``.

    ``D`` is the diagonal matrix of row sums (weighted degrees) of ``A``. The
    result is explicitly symmetrized as ``(L + L^T) / 2`` to remove round-off
    asymmetry. ``A`` is expected to be an undirected adjacency matrix with
    nonnegative entries; that expectation is not checked.

    Raises:
        AssertionError: if ``A`` is not a square 2-D tensor.
    """
    assert A.dim() == 2 and A.size(0) == A.size(1), "A must be square"
    degree_matrix = torch.diag(A.sum(dim=1))
    lap = degree_matrix - A
    return 0.5 * (lap + lap.T)


def pinv_laplacian(
    L: torch.Tensor,
    eps: float = 1e-8,
    tol: float = 1e-10,
    rtol: float | None = None,
) -> torch.Tensor:
    """Moore-Penrose pseudoinverse of a symmetric PSD matrix such as a graph Laplacian.

    The input is symmetrized and eigendecomposed with ``torch.linalg.eigh``.
    Every eigenvalue at or below a cutoff is treated as zero and the rest are
    inverted. For a connected graph the discarded eigenvalue is the one whose
    eigenvector is constant. A graph with k connected components has k of them.

    The cutoff is ``max(tol, rtol * lambda_max)``. The relative part matters in
    reduced precision. In float32, eigh's round-off on the true zero eigenvalue
    is roughly ``n * finfo.eps * lambda_max`` (about 1e-6 for small graphs).
    That is far above the absolute ``tol`` default, and inverting it would add
    an enormous spurious term along the constant vector.

    Args:
        L: Square symmetric matrix (n x n).
        eps: Lower clamp applied to kept eigenvalues before inversion.
        tol: Absolute cutoff; eigenvalues ``<= tol`` are treated as zero.
        rtol: Relative cutoff, multiplied by the largest eigenvalue. ``None``
            uses ``n * torch.finfo(dtype).eps``, the default used by
            ``torch.linalg.pinv``.

    Returns:
        Symmetric pseudoinverse with the same shape and dtype as ``L``.
    """
    # Symmetrize so eigh sees an exactly symmetric input.
    Ls = 0.5 * (L + L.T)
    evals, evecs = torch.linalg.eigh(Ls)
    # Clamp tiny negatives caused by round-off (L is PSD in exact arithmetic).
    evals = torch.clamp(evals, min=0.0)

    n = Ls.size(0)
    if rtol is None:
        rtol = n * torch.finfo(evals.dtype).eps
    lam_max = float(evals.max()) if n > 0 else 0.0
    cutoff = max(tol, rtol * lam_max)

    # Pseudoinverse spectrum: 0 for (numerically) zero eigenvalues, 1/lambda otherwise.
    mask = evals > cutoff
    inv_evals = torch.zeros_like(evals)
    inv_evals[mask] = 1.0 / torch.clamp(evals[mask], min=eps)
    L_pinv = (evecs * inv_evals) @ evecs.T
    # Remove round-off asymmetry from the reconstruction.
    return 0.5 * (L_pinv + L_pinv.T)


def sqrtm_psd(M: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Square root of a symmetric PSD matrix, computed with ``torch.linalg.eigh``.

    ``M`` is symmetrized first. Its eigenvalues are clamped from below at
    ``eps``, so zero or negative eigenvalues map to ``sqrt(eps)`` rather than 0.
    The square-rooted spectrum is recombined as ``V diag(sqrt(lambda)) V^T``,
    and the result is symmetrized once more to remove round-off asymmetry.
    """
    eigvals, eigvecs = torch.linalg.eigh(0.5 * (M + M.T))
    root = (eigvecs * torch.clamp(eigvals, min=eps).sqrt()) @ eigvecs.T
    return 0.5 * (root + root.T)


def trace_sqrt_psd(M: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Return ``tr(sqrt(M))`` for a symmetric PSD ``M`` as a 0-dim tensor.

    Uses ``tr(sqrt(M)) = sum_i sqrt(lambda_i)`` over the eigenvalues of the
    symmetrized ``M``, each clamped from below at ``eps``. Only eigenvalues are
    computed (``eigvalsh``), and the result stays on the autograd graph.
    """
    spectrum = torch.linalg.eigvalsh(0.5 * (M + M.T)).clamp(min=eps)
    return spectrum.sqrt().sum()


def sinkhorn_project(P: torch.Tensor, row_sum: float, col_sum: float, iters: int = 10, eps: float = 1e-16) -> torch.Tensor:
    """Rescale ``P`` towards a positive matrix with prescribed column and row sums.

    Mind the naming: afterwards every *column* of ``P`` sums (approximately)
    to ``row_sum`` and every *row* sums to ``col_sum``. For an N x M plan
    called with ``row_sum=N, col_sum=M``, rows end up summing to M and columns
    to N.

    Negative entries are clamped to 0 and ``eps`` is added so no entry is
    exactly zero. Then ``iters`` rounds of alternating multiplicative scaling
    are applied, columns first and rows second. Because rows are scaled last,
    row sums are met up to ``eps``/round-off, while column sums are only
    approximate. With ``iters=0`` the clamped-and-shifted ``P`` is returned.

    Nothing is detached, so gradients can flow through the scaling.
    """
    Q = torch.clamp(P, min=0.0) + eps
    for _ in range(iters):
        col_scale = row_sum / (Q.sum(dim=0, keepdim=True) + eps)
        Q = Q * col_scale
        row_scale = col_sum / (Q.sum(dim=1, keepdim=True) + eps)
        Q = Q * row_scale
    return Q


def copt_objective_from_Ldag(Lx_dag: torch.Tensor, Ly_dag: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Evaluate the COPT objective at transport plan ``P`` (lower is better).

        obj = M * tr(Lx^+) + N * tr(Ly^+)
              - 2 * tr( sqrt( (Ly^+)^{1/2} P^T Lx^+ P (Ly^+)^{1/2} ) )

    ``P`` is N x M, ``Lx_dag`` is N x N and ``Ly_dag`` is M x M, as required
    by the matrix products below. The inner matrix is symmetrized before its
    trace-sqrt is taken. The result is a 0-dim tensor that stays on the
    autograd graph with respect to ``P``.
    """
    n_rows, n_cols = P.size(0), P.size(1)
    half_y = sqrtm_psd(Ly_dag)
    inner = half_y @ P.T @ Lx_dag @ P @ half_y
    cross = trace_sqrt_psd(0.5 * (inner + inner.T))
    return n_cols * torch.trace(Lx_dag) + n_rows * torch.trace(Ly_dag) - 2.0 * cross


def copt_distance_from_laplacians(
    Lx: torch.Tensor,
    Ly: torch.Tensor,
    steps: int = 300,
    lr: float = 0.3,
    sinkhorn_iters: int = 10,
    seed: int | None = 0,
    verbose: bool = True,
):
    """Minimise the COPT objective over transport plans P; return ``(value, P)``.

    This is a projected-gradient scheme. Each step does the following:
    - push the plan back onto the constraint set with ``sinkhorn_project``
      (rows sum to M, columns to N, entries positive);
    - evaluate the objective at that feasible plan;
    - let Adam update ``P``.

    The best feasible plan seen is kept and re-projected with 50 Sinkhorn
    iterations. It is then re-scored, so the returned value is the objective
    of the returned plan.

    Args:
        Lx: Laplacian of the first graph (N x N).
        Ly: Laplacian of the second graph (M x M).
        steps: Number of Adam steps. With ``steps < 1`` no optimisation is
            done and the initial (projected random) plan is scored and returned.
        lr: Adam learning rate.
        sinkhorn_iters: Sinkhorn iterations per re-projection.
        seed: If not None, passed to ``torch.manual_seed`` (global RNG) before
            the random initial plan is drawn.
        verbose: Print the objective periodically.

    Returns:
        ``(distance_value, P_opt)``: a Python float and an N x M tensor on
        ``Lx``'s device with ``Lx``'s dtype.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device, dtype = Lx.device, Lx.dtype
    N = Lx.size(0)
    M = Ly.size(0)

    # Pseudoinverses are constants of the optimisation; build them outside autograd.
    with torch.no_grad():
        Lx_dag = pinv_laplacian(Lx).to(device=device, dtype=dtype)
        Ly_dag = pinv_laplacian(Ly).to(device=device, dtype=dtype)

    # The plan must share Lx's dtype. The global default dtype (often float32)
    # would make `P.T @ Lx_dag` fail for float64 inputs.
    P0 = torch.rand(N, M, device=device, dtype=dtype) + 1.0
    P = nn.Parameter(sinkhorn_project(P0, row_sum=N, col_sum=M, iters=30))

    opt = torch.optim.Adam([P], lr=lr)

    best_val = None
    best_P = None

    for t in range(1, steps + 1):
        opt.zero_grad()
        # Re-project onto the feasible set, in place so Adam keeps tracking P.
        with torch.no_grad():
            P.copy_(sinkhorn_project(P, row_sum=N, col_sum=M, iters=sinkhorn_iters))
        obj = copt_objective_from_Ldag(Lx_dag, Ly_dag, P)
        val = obj.item()
        # Snapshot the feasible plan that produced `val` *before* opt.step()
        # moves P (possibly out of the feasible set).
        if best_val is None or val < best_val:
            best_val = val
            best_P = P.detach().clone()
        obj.backward()
        opt.step()

        if verbose and (t % max(1, steps // 10) == 0 or t == 1):
            print(f"[Step {t:4d}] COPT obj = {val:.6f}")

    with torch.no_grad():
        if best_P is None:  # steps < 1: no optimisation ran; use the initial plan
            best_P = P.detach().clone()
        # Final, tighter projection; re-score so value and plan agree.
        best_P = sinkhorn_project(best_P, row_sum=N, col_sum=M, iters=50)
        best_val = copt_objective_from_Ldag(Lx_dag, Ly_dag, best_P).item()
    return best_val, best_P


def copt_distance(
    A_x: torch.Tensor,
    A_y: torch.Tensor,
    **kwargs,
):
    """Compute the COPT distance directly from two adjacency matrices.

    Builds ``L = D - A`` for each graph with ``laplacian_from_adj`` and
    forwards all keyword arguments (``steps``, ``lr``, ``sinkhorn_iters``,
    ``seed``, ``verbose``) to ``copt_distance_from_laplacians``.

    A_x: (N, N) and A_y: (M, M) nonnegative symmetric floating-point tensors.
    Returns ``(distance_value, P_optimal)``.
    """
    laplacians = [laplacian_from_adj(adj) for adj in (A_x, A_y)]
    return copt_distance_from_laplacians(*laplacians, **kwargs)


# -----------------------------
# Example Usage
# -----------------------------
if __name__ == "__main__":
    torch.set_default_dtype(torch.float64)  # better numerical precision

    # Build two small random undirected graphs (Erdos-Renyi) for demo
    def er_graph(n, p, seed=0):
        """Symmetric 0/1 adjacency of a G(n, p) graph without self-loops.

        Draws from a private generator seeded with ``seed`` (previously the
        argument was ignored and the global RNG was used), so each graph is
        reproducible independently of other RNG use.
        """
        gen = torch.Generator().manual_seed(seed)
        g = torch.rand(n, n, generator=gen)
        # Keep the strict upper triangle (no self-loops), then mirror it.
        g = torch.triu((g < p).double(), diagonal=1)
        return g + g.T

    N, M = 30, 25
    A1 = er_graph(N, p=0.15, seed=0)
    A2 = er_graph(M, p=0.20, seed=1)

    dist, Popt = copt_distance(A1, A2, steps=200, lr=0.5, sinkhorn_iters=10, verbose=True)

    print("\nCOPT distance:", dist)
    print("P shape:", Popt.shape)
    print("Row sums (should be ~M):", Popt.sum(dim=1)[:5])
    print("Col sums (should be ~N):", Popt.sum(dim=0)[:5])


# Example output from one run of the demo above (exact values depend on the
# randomly generated graphs):
#
# [Step    1] COPT obj = 514.562386
# [Step   20] COPT obj = 423.730343
# [Step   40] COPT obj = 417.419936
# [Step   60] COPT obj = 410.964449
# [Step   80] COPT obj = 402.721435
# [Step  100] COPT obj = 389.977649
# [Step  120] COPT obj = 380.059954
# [Step  140] COPT obj = 368.885516
# [Step  160] COPT obj = 359.842705
# [Step  180] COPT obj = 352.344428
# [Step  200] COPT obj = 348.508909
#
# COPT distance: 348.5089088111093
# P shape: torch.Size([30, 25])
# Row sums (should be ~M): tensor([25.0000, 25.0000, 25.0000, 25.0000, 25.0000])
# Col sums (should be ~N): tensor([30.0000, 30.0000, 30.0000, 30.0000, 30.0000])
