In [2]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import scipy.spatial
import ot
import torch
import ipynbname


from torch.autograd.functional import jacobian

In [3]:
# ----------------------------
# Compute device: CUDA when it is available, CPU otherwise
# ----------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ---------------------------------------------------
# Helpers for shapes
# ---------------------------------------------------
def load_binary_shape(path, thresh=0.95):
    """Load a grayscale/binary PNG and return the coordinates of its dark pixels.

    Parameters
    ----------
    path : str or path-like
        Image file, read with ``plt.imread``.
    thresh : float, default 0.95
        A pixel is kept when its intensity is strictly below this value.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_points, 2)``. Each row is ``[x, y]`` with
        ``x = column`` and ``y = H - row`` (``H`` is the image height), so the
        y axis points upwards. Rows are listed in row-major pixel order.
        An image without any dark pixel yields an empty ``(0, 2)`` array.

    Raises
    ------
    ValueError
        If the decoded image is neither 2-D (grayscale) nor 3-D (multi-channel).
    """
    img = plt.imread(path).astype(np.float64)

    if img.ndim == 3:
        # Multi-channel image: keep the channel at index 2, as the original code did.
        img = img[..., 2]
    elif img.ndim != 2:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got {img.ndim} dimension(s)"
        )
    # A 2-D array is already a single intensity channel and is used as is.

    H = img.shape[0]
    # np.nonzero scans in row-major order, matching the former nested loops.
    rows, cols = np.nonzero(img < thresh)
    return np.column_stack((cols, H - rows))


def compute_cost_matrix(X):
    """Pairwise Euclidean distance matrix of ``X``, rescaled to a maximum of 1.

    Parameters
    ----------
    X : array-like of shape (n_points, dim)
        Point coordinates, passed to ``scipy.spatial.distance.cdist``.

    Returns
    -------
    numpy.ndarray
        ``(n_points, n_points)`` matrix of distances divided by the largest
        distance. When the largest distance is zero (a single point, or all
        points coincide) or the matrix is empty, the distances are returned
        unscaled instead of dividing by zero.
    """
    C = scipy.spatial.distance.cdist(X, X)
    # Guard the normalization: 0/0 would fill the matrix with NaN, and
    # max() of an empty array raises.
    diameter = C.max() if C.size else 0.0
    if diameter > 0:
        C = C / diameter
    return C

In [4]:
# ---------------------------------------------------
# Load all shapes
# ---------------------------------------------------
try:
    this_file = ipynbname.path()
    notebook_dir = this_file.parent
except FileNotFoundError:
    # ipynbname could not identify the running notebook; fall back to the
    # kernel's working directory and look for the data folder there.
    this_file = None
    notebook_dir = Path.cwd()

data_path = notebook_dir / "simple_shapes"
shape_files = ["square.png", "cross.png", "triangle.png", "star.png"]
Xs = [load_binary_shape(data_path / fname) for fname in shape_files]
ns = [len(X) for X in Xs]
# Fail early with a readable message instead of deep inside the distance computation.
assert all(n > 0 for n in ns), f"empty shape among {shape_files}: sizes {ns}"

# ---------------------------------------------------
# Compute cost matrices and uniform measures
# ---------------------------------------------------
Cs = [compute_cost_matrix(X) for X in Xs]
ps = [ot.unif(n) for n in ns]

In [5]:
# ------------------------------------------------------
# Hold out one shape as the target; the remaining shapes are the templates
# ------------------------------------------------------
i = 3  # position of the target in shape_files / Cs / ps
Cs_except_i = [Cs[j] for j in range(len(Cs)) if j != i]
ps_except_i = [ps[j] for j in range(len(ps)) if j != i]
C_target_np, p_target_np = Cs[i], ps[i]  # target cost matrix and target weights

In [6]:
# ---------------------------------------------------
# Move templates and target to torch (float32, on `device`)
# ---------------------------------------------------
Cs_torch = [
    torch.from_numpy(cost).to(device=device, dtype=torch.float32)
    for cost in Cs_except_i
]
ps_torch = [
    torch.from_numpy(weights).to(device=device, dtype=torch.float32)
    for weights in ps_except_i
]
C_target = torch.from_numpy(C_target_np).to(device=device, dtype=torch.float32)
p_target = torch.from_numpy(p_target_np).to(device=device, dtype=torch.float32)

# Barycentric weights lambda: one entry per template, initialised to equal
# values; lambda_vec is the copy that tracks gradients.
lambda_init = torch.ones(len(Cs_torch), device=device).div(len(Cs_torch))
lambda_vec = lambda_init.clone()
lambda_vec.requires_grad_(True)

# Gradient of the GW loss with respect to the barycentric weights $\lambda$


In [7]:
# ---------------------------------------------------
# Barycenter support size and its background measure
# ---------------------------------------------------
M = 30  # number of support points of the barycenter
p_b_np = ot.unif(M)
p_b_torch = torch.from_numpy(p_b_np).to(device=device, dtype=torch.float32)

# ---------------------------------------------------
# Compute Gromov-Wasserstein barycenter and its Jacobian w.r.t. lambda
# ---------------------------------------------------

def GWbarycenter(lambda_vec, random_state=0):
    """Gromov-Wasserstein barycenter of the template spaces for weights ``lambda_vec``.

    Thin wrapper around ``ot.gromov.gromov_barycenters`` using the notebook-level
    templates ``Cs_torch`` / ``ps_torch``, support size ``M`` and square loss.

    Parameters
    ----------
    lambda_vec : torch.Tensor
        Weights of the templates, one entry per element of ``Cs_torch``.
    random_state : int or numpy.random.RandomState, default 0
        Seed forwarded to POT. Because no initial cost matrix is passed, POT
        draws a random one; with the seed pinned, repeated calls at the same
        ``lambda_vec`` start from the same initialisation. This keeps the
        Jacobian and the barycenter computed in separate calls consistent
        with each other.

    Returns
    -------
    torch.Tensor
        The barycenter cost matrix returned by POT (``M x M``).
    """
    return ot.gromov.gromov_barycenters(
        M,
        Cs=Cs_torch,
        ps=ps_torch,
        p=None,
        lambdas=lambda_vec,
        loss_fun='square_loss',
        random_state=random_state,
    )

# Jacobian of the barycenter cost matrix with respect to lambda (autograd)
GWBary_Jacobian = jacobian(func=GWbarycenter, inputs=lambda_vec)
print("Jacobian shape:", GWBary_Jacobian.shape)

# Barycenter itself, evaluated on a gradient-free copy of lambda
lambda_clean = lambda_vec.clone().detach()
Z_star = GWbarycenter(lambda_clean)

Jacobian shape: torch.Size([30, 30, 3])


In [8]:
#---------------------------------------------------
# Compute derivative of GW loss resp to barycenter cost matrix
#---------------------------------------------------

def GW_loss(C_bary):
    """Square-loss GW value between ``C_bary`` and the target.

    ``C_bary`` is paired with the weights ``p_b_torch``; the target is the
    notebook-level pair ``(C_target, p_target)``. Returns whatever
    ``ot.gromov.gromov_wasserstein2`` returns for these inputs.
    """
    return ot.gromov.gromov_wasserstein2(
        C_bary, C_target, p_b_torch, p_target, loss_fun='square_loss'
    )

# Differentiate the loss at the barycenter Z_star that the following cells also
# use, instead of solving the barycenter problem once more. detach() hands
# jacobian() a plain tensor whatever Z_star's current autograd state is.
GWloss_Jacobian = jacobian(GW_loss, Z_star.detach())
print("GW loss Jacobian shape:", GWloss_Jacobian.shape)

GW loss Jacobian shape: torch.Size([30, 30])


In [9]:
#---------------------------------------------------
# Compute derivative of GW loss resp to barycenter cost matrix via envelope theorem
#---------------------------------------------------
X = C_target
p_X_torch = p_target

# Transport plan between Z_star and X from the POT GW solver. Every input is
# detached so that nothing is back-propagated through the solver.
T_np = ot.gromov.gromov_wasserstein(
    Z_star.detach(),
    X.detach(),
    p_b_torch.detach(),
    p_X_torch.detach(),
    loss_fun='square_loss',
)

# Accept either a NumPy array or a tensor from the solver and bring it to
# Z_star's device and dtype.
T = (T_np if isinstance(T_np, torch.Tensor) else torch.from_numpy(T_np)).to(
    Z_star.device, dtype=Z_star.dtype
)

# The plan is treated as a constant from here on.
T_fixed = T.detach()

In [10]:
#---------------------------------------------------
# Define GW cost with fixed transport plan to use envelope theorem
#---------------------------------------------------
def gw_cost_fixed_plan(Z, X, T_fixed):
    """
    GW(Z, X; T_fixed) with squared loss, using T as a constant.

    Computes ``sum_{i,j,k,l} (Z[i,j] - X[k,l])**2 * T[i,k] * T[j,l]`` without
    building the (n,n,m,m) tensor, via the expansion

        p^T (Z*Z) p + q^T (X*X) q - 2 * sum(Z * (T X T^T)),

    where ``p = T.sum(1)`` and ``q = T.sum(0)``. Memory is O(n^2 + m^2 + n*m).
    The value equals the brute-force sum up to floating-point rounding.

    Z: (n,n), X: (m,m), T_fixed: (n,m). T_fixed is detached here, so no
    gradient ever flows into the plan; gradients flow to Z (and X) only.
    Returns scalar tensor.

    Raises ValueError if the shapes of Z, X and T_fixed are inconsistent.
    """
    if T_fixed.dim() != 2:
        raise ValueError(f"T_fixed must be 2-D, got shape {tuple(T_fixed.shape)}")
    n, m = T_fixed.shape
    if tuple(Z.shape) != (n, n) or tuple(X.shape) != (m, m):
        raise ValueError(
            f"incompatible shapes: Z {tuple(Z.shape)}, X {tuple(X.shape)}, "
            f"T_fixed {tuple(T_fixed.shape)}"
        )

    T_const = T_fixed.detach()

    # Element-wise products promoted mixed dtypes implicitly; matmul does not,
    # so promote explicitly to keep accepting the same inputs.
    work_dtype = torch.promote_types(
        torch.promote_types(Z.dtype, X.dtype), T_const.dtype
    )
    Z_w, X_w, T_w = Z.to(work_dtype), X.to(work_dtype), T_const.to(work_dtype)

    p = T_w.sum(dim=1)  # (n,) row marginal of the plan
    q = T_w.sum(dim=0)  # (m,) column marginal of the plan

    quad_Z = torch.dot(p, (Z_w * Z_w) @ p)           # sum Z_ij^2 p_i p_j
    quad_X = torch.dot(q, (X_w * X_w) @ q)           # sum X_kl^2 q_k q_l
    cross = (Z_w * (T_w @ X_w @ T_w.T)).sum()        # sum Z_ij X_kl T_ik T_jl

    return quad_Z + quad_X - 2.0 * cross


In [11]:
# Rebind Z_star to a fresh leaf tensor with gradient tracking enabled
Z_star = Z_star.detach().requires_grad_(True)

loss_Z = gw_cost_fixed_plan(Z_star, X, T_fixed)
loss_Z.backward()  # fills Z_star.grad with d(loss_Z)/d(Z_star), plan held fixed

grad_Z = Z_star.grad
print("Gradient via envelope theorem shape:", grad_Z.shape)

Gradient via envelope theorem shape: torch.Size([30, 30])


In [12]:
# Chain rule: contract grad_Z with the barycenter Jacobian over its first two axes
grad_lambda = torch.sum(grad_Z[:, :, None] * GWBary_Jacobian, dim=(0, 1))
print("Gradient w.r.t. lambda via envelope theorem:", grad_lambda)
print("Gradient w.r.t. lambda via envelope theorem shape:", grad_lambda.shape)

Gradient w.r.t. lambda via envelope theorem: tensor([0.0100, 0.0105, 0.0111])
Gradient w.r.t. lambda via envelope theorem shape: torch.Size([3])


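**Algorithm.** Input: weights $\lambda$, templates $\{X_k\}$, target $X$.

1. Compute $Z^*(\lambda)$ with the GW barycenter solver in PythonOT.
2. Compute $\partial Z^* / \partial \lambda$ with automatic differentiation.
3. Compute $GW(Z^*, X)$ with the GW solver in PythonOT.
4. Compute $\partial GW / \partial Z$ using the formula in Step 1 of 45.2.
5. Compute the gradient of the objective with respect to $\lambda$.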

In [13]:
# ------------------------------------------------------
# TODO: compute the (Gromov-)Wasserstein distances between the target and the barycenters over the simplex
# ------------------------------------------------------




# ------------------------------------------------------
# TODO: plot the level curves of these distances over the simplex
# ------------------------------------------------------

