In [11]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import scipy.spatial
import ot
import torch
import ipynbname


from torch.autograd.functional import jacobian

In [2]:
# ----------------------------
# Device: run on CUDA when PyTorch reports it available, otherwise on the CPU
# ----------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# ---------------------------------------------------
# Helpers for shapes
# ---------------------------------------------------
def load_binary_shape(path, thresh=0.95):
    """Load a shape image and return the coordinates of its dark pixels.

    Parameters
    ----------
    path : str or pathlib.Path
        Image file readable by ``plt.imread`` (PNG files in this notebook).
    thresh : float, default 0.95
        Pixels whose intensity is strictly below ``thresh`` are kept, i.e.
        everything that is not (near-)white.

    Returns
    -------
    np.ndarray of shape (N, 2), integer dtype
        One ``[x, y]`` row per kept pixel, with ``x`` the column index and
        ``y = H - row`` so that the y axis points upwards. Rows are in
        row-major order (top image row first). When no pixel passes the
        threshold the result still has shape ``(0, 2)``.
    """
    img = plt.imread(path).astype(np.float64)
    if img.ndim == 3:
        # Colour image: keep a single channel, as the original notebook did
        # (channel 2 for RGB/RGBA). Grey+alpha images only have channels 0-1,
        # so fall back to the grey channel there.
        img = img[..., 2] if img.shape[-1] >= 3 else img[..., 0]
    # A 2-D (grayscale) image is used as-is.

    H = img.shape[0]
    # np.nonzero walks the mask in row-major order, matching a nested
    # "for i in rows: for j in cols" scan.
    rows, cols = np.nonzero(img < thresh)
    return np.column_stack((cols, H - rows))


def compute_cost_matrix(X):
    """Return the pairwise Euclidean distance matrix of ``X``, scaled to [0, 1].

    Parameters
    ----------
    X : array_like of shape (n, d)
        Point coordinates, one row per point.

    Returns
    -------
    np.ndarray of shape (n, n)
        ``cdist(X, X)`` divided by its largest entry. If every distance is
        zero (a single point, or all points coincide) or ``n == 0``, the
        unscaled matrix is returned. Dividing by a zero maximum would
        produce NaNs, and ``max`` of an empty array raises.
    """
    C = scipy.spatial.distance.cdist(X, X)
    c_max = C.max() if C.size else 0.0
    return C / c_max if c_max > 0 else C

In [None]:
# ---------------------------------------------------
# Load all shapes
# ---------------------------------------------------
# Shapes live in a "simple_shapes" folder next to this notebook;
# ipynbname.path() is used to obtain the notebook's own path.
this_file = ipynbname.path()
data_path = this_file.parent / "simple_shapes"
shape_files = ["square.png", "cross.png", "triangle.png", "star.png"]
Xs = [load_binary_shape(data_path / fname) for fname in shape_files]
ns = [len(X) for X in Xs]

# Fail fast with a readable message: a shape with no dark pixels would
# otherwise only surface later as an obscure error in the cost-matrix/OT code.
assert all(n > 0 for n in ns), (
    "shapes with no pixels below the threshold: "
    f"{[f for f, n in zip(shape_files, ns) if n == 0]}"
)

# ---------------------------------------------------
# Compute cost matrices and uniform measures
# ---------------------------------------------------
Cs = [compute_cost_matrix(X) for X in Xs]
ps = [ot.unif(n) for n in ns]

In [7]:
# ------------------------------------------------------
# Select one shape as target and the others as templates
# ------------------------------------------------------
i = 3  # index into shape_files / Cs; 3 -> "star.png"
# Guard the index: a negative i would never match "j != i" below, so the
# target would silently be kept among the templates as well.
assert 0 <= i < len(Cs), f"target index {i} out of range for {len(Cs)} shapes"
Cs_except_i = [C for j, C in enumerate(Cs) if j != i]
ps_except_i = [p for j, p in enumerate(ps) if j != i]
C_target_np = Cs[i]      # target cost matrix
p_target_np = ps[i]      # target measure (not strictly needed here)

In [8]:
# ---------------------------------------------------
# Convert data to torch
# ---------------------------------------------------
# Templates and target as float32 tensors on the selected device.
Cs_torch = [torch.as_tensor(C, dtype=torch.float32, device=device) for C in Cs_except_i]
ps_torch = [torch.as_tensor(p, dtype=torch.float32, device=device) for p in ps_except_i]
C_target = torch.as_tensor(C_target_np, dtype=torch.float32, device=device)

# Barycenter weights: uniform over the templates, tracked by autograd.
# No simplex constraint is enforced on lambda_vec here.
lambda_init = torch.ones(len(Cs_torch), device=device) / len(Cs_torch)
lambda_vec = lambda_init.detach().clone().requires_grad_()

In [None]:
# ---------------------------------------------------
# Select barycenter support size and background measure 
# ---------------------------------------------------
M = 30
p_b_np = ot.unif(M)
p_b_torch = torch.from_numpy(p_b_np).float().to(device)

# Seed forwarded to POT's `random_state`. POT uses it for the barycenter's
# random initial structure when no `init_C` is supplied (see POT docs).
# Fixing it makes Z*(lambda) and its Jacobian reproducible across runs.
GW_SEED = 0


def GWbarycenter(lambda_vec, random_state=GW_SEED):
    """Return the M x M GW barycenter structure matrix for weights ``lambda_vec``.

    Parameters
    ----------
    lambda_vec : torch.Tensor
        One weight per template in ``Cs_torch`` (assumed 1-D of length
        ``len(Cs_torch)``). Forwarded as ``lambdas``.
    random_state : int, optional
        Seed for POT's random initialisation (defaults to ``GW_SEED``).

    Returns
    -------
    torch.Tensor
        Barycenter structure matrix returned by
        ``ot.gromov.gromov_barycenters`` with square loss. The background
        measure is the uniform ``p_b_torch`` defined above.
    """
    GWbary = ot.gromov.gromov_barycenters(
        M,
        Cs=Cs_torch,
        ps=ps_torch,
        p=p_b_torch,
        lambdas=lambda_vec,
        loss_fun='square_loss',
        random_state=random_state,
    )
    return GWbary


# Jacobian has shape (M, M, len(lambda_vec)): d Z*[a, b] / d lambda[k].
# NOTE(review): it is unclear from here whether POT propagates gradients
# through the optimal transport plans or only through the final structure
# update. Confirm before treating this as the exact derivative of Z*(lambda).
GWBary_Jacobian = jacobian(GWbarycenter, lambda_vec)
print("Jacobian shape:", GWBary_Jacobian.shape)



Jacobian shape: torch.Size([30, 30, 3])


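Plan:

1. Input: the weights $\lambda$ and the template shapes $\{X_k\}$.
2. [x] Compute the barycenter $Z^*(\lambda)$ with the GW barycenter solver in PythonOT (POT).
3. Compute $\partial Z^* / \partial \lambda$ with automatic differentiation (AD).
4. Compute $\mathrm{GW}(Z^*, X)$ with the GW solver in PythonOT.
5. Compute $\partial\,\mathrm{GW} / \partial Z$ using the formula in Step 1 of 45.2.
6. Combine steps 3 and 5 via the chain rule to obtain the gradient of the objective with respect to $\lambda$.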