<h1>Tutorial: Hierarchical Refinement for Large-Scale Optimal Transport

This tutorial presents the implementation of the Hierarchical Refinement (HiRef) algorithm from the paper "Hierarchical Refinement: Optimal Transport to Infinity and Beyond" (Halmos et al., 2025). This algorithm solves the optimal transport (OT) problem for two large sets of points with linear memory complexity. The key idea is to exploit the fact that low-rank optimal transport solutions co-cluster each point with its image under the (bijective) Monge map. HiRef therefore recursively constructs a multi-scale partition of the data by solving a hierarchy of low-rank OT sub-problems, ultimately leading to a complete bijective coupling. Below we detail the implementation steps, highlighting methodological choices.

<h2> Setup and import

In [22]:
from typing import List, Callable, Union, Dict, Any
import random
import operator
import functools
from functools import reduce

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.tools import sinkhorn_divergence, progot
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.models as models
import os
from PIL import Image
from tqdm import tqdm
import pickle

from ott.geometry import geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# 1. Load images & model

The tutorial begins by loading a set of images (here a subset of 5000 images from COCO/ImageNet). Each image is transformed (resized to 224×224 and converted to a tensor) then passed through a pre-trained neural network (ResNet-50). The final classification layer (model.fc = Identity) is removed to obtain only the feature vectors (embeddings) of dimension 2048.

To save time during future executions, we save the embeddings with pickle. In the end, we obtain an embeddings tensor of shape (5000, 2048) containing the vector representations of our images. So you can also directly go to part 3 if you want to start directly from the embeddings.

In [18]:
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Liste tous les fichiers .jpg
        self.image_paths = [os.path.join(root_dir, fname)
                            for fname in os.listdir(root_dir)
                            if fname.endswith('.jpg')]
        # Si le nombre d'images est impair, enlever la dernière
        if len(self.image_paths) % 2 != 0:
            self.image_paths = self.image_paths[:-1]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # Toujours convertir en RGB
        if self.transform:
            image = self.transform(image)
        return image  # Pas d'étiquette ici, juste l'image

In [19]:
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images for CNN input
    transforms.ToTensor(),
])

# Load COCO dataset from extracted path
imagenet_dataset = CustomImageDataset(root_dir='images', transform=transform)


# Create DataLoader for batching
imagenet_loader = DataLoader(imagenet_dataset, batch_size=32, shuffle=True)

print(f"Loaded {len(imagenet_dataset)} images from ImageNet!")

Loaded 5000 images from ImageNet!


In [20]:
model_path = os.path.expanduser("resnet50-0676ba61.pth")

# Load pretrained ResNet model
model = models.resnet50()
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.fc = torch.nn.Identity()  # Remove classification layer to extract features
model.eval()  # Set to evaluation mode

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

# 2. Create embeddings

In [21]:
def extract_features(dataloader, model):
    """
    Compute embeddings
    """
    embeddings = []
    with torch.no_grad():
        for idx, images in tqdm(enumerate(dataloader), desc="Extracting features", total=len(dataloader)):
            images = images.to(device)
            features = model(images)
            embeddings.append(jnp.array(features.detach().cpu().numpy()))  
    return jnp.vstack(embeddings)  # Stack all embeddings

print('extracting embeddings!')
embeddings = extract_features(imagenet_loader, model)

extracting embeddings!


Extracting features: 100%|██████████| 157/157 [04:00<00:00,  1.53s/it]


In [22]:
with open('embeddings/embeddings.pkl', "wb") as f:
    pickle.dump(embeddings, f)

print("Embeddings saved successfully to embeddings/embeddings.pkl")

Embeddings saved successfully to embeddings/embeddings.pkl


# 3.

In [23]:
# Load embeddings from the pickle file
with open('embeddings/embeddings.pkl', "rb") as f:
    embeddings = pickle.load(f)

print(f"Embeddings loaded successfully! Shape: {embeddings.shape}")

Embeddings loaded successfully! Shape: (5000, 2048)


In the paper, OT is formulated between two uniform distributions of the same size, where the problem is actually an assignment problem (Monge bijection). To reproduce this, we start with the embeddings extracted from 5000 images and do:  
Random permutation: We create a random index vector of size 5000 (fixed with a seed for reproducibility).
<b>Separation into two equal subsets:We take the first 2500 entries of this permutation to form set X and the next 2500 for Y. Thus, |X| = |Y| = 2500, with the same uniform measure a priori.
These two sets X and Y of 2048-dim vectors will be the input points for the HiRef algorithm. 

In [24]:
num_samples = embeddings.shape[0]

# Shuffle indices
key = jax.random.PRNGKey(42)
indices = jax.random.permutation(key, num_samples)

# Split into two tensors
X = embeddings[indices[:num_samples // 2]]
Y = embeddings[indices[num_samples // 2:]]

print(f"X shape: {X.shape}, Y shape: {Y.shape}")

X shape: (2500, 2048), Y shape: (2500, 2048)


# 4. Sinkhorn with epsilon-schedule

Entropic optimal transport with the Sinkhorn algorithm is the foundation on which we'll build. Here's an implementation of the log-stabilized Sinkhorn with the utility functions for optimal transport.

In [25]:
def ott_log_sinkhorn(grad,
                     a,
                     b,
                     gamma_k,
                     max_iter = 50,
                     balanced = True,
                     unbalanced = False,
                     tau = None,
                     tau2 = None):
    """
    grad: cost matrix (n, m)
    a: source histogram (n,)
    b: target histogram (m,)
    gamma_k: régularisation inverse (1/epsilon)
    """
    epsilon = 1.0 / gamma_k

    # Choix des tau pour marges
    if balanced and not unbalanced:
        tau_a, tau_b = 1.0, 1.0
    elif not balanced and unbalanced:
        tau_a = tau / (tau + epsilon)
        tau_b = tau2 / (tau2 + epsilon) if tau2 is not None else tau_a
    else:  # semi-relaxed
        tau_a, tau_b = 1.0, tau / (tau + epsilon)

    # Géométrie entropique sur la matrice de coût
    geom = geometry.Geometry(cost_matrix=grad, epsilon=epsilon)

    # Construction du problème linéaire
    prob = linear_problem.LinearProblem(
        geom,
        a=a,
        b=b,
        tau_a=tau_a,
        tau_b=tau_b
    )

    # Solveur Sinkhorn
    solver = sinkhorn.Sinkhorn(max_iterations=max_iter)
    out = solver(prob)

    return out.matrix

def utils__Delta(vark, varkm1, gamma_k):
    return (gamma_k**-2) * (jnp.linalg.norm(vark[0] - varkm1[0]) + jnp.linalg.norm(vark[1] - varkm1[1]) + jnp.linalg.norm(vark[2] - varkm1[2]))

def utils__random_simplex_sample(key, N, dtype = jnp.float64):
    """
    Draws a random point from the (N-1)-simplex using normalized exponentiated Gaussian variates.

    Args:
        key: PRNGKey for random number generation.
        N: Dimensionality of the simplex (vector length).
        dtype: Desired floating-point type of the output.

    Returns:
        A 1D array of shape (N,) with non-negative entries summing to 1.
    """
    # Sample N independent standard normals
    z = jax.random.normal(key, shape=(N,), dtype=dtype)
    # Exponentiate
    e = jnp.exp(z)
    # Normalize to sum to 1
    return e / jnp.sum(e)

def utils__initialize_couplings(a, b, gQ, gR, gamma, full_rank = True, key = jax.random.PRNGKey(0), dtype = float, rank2_random = False, max_iter = 50):
    """
    Initialize coupling factors in JAX.
    """
    N1 = a.shape[0]
    N2 = b.shape[0]
    r = gQ.shape[0]
    r2 = gR.shape[0]

    one_N1 = jnp.ones((N1,), dtype=dtype)
    one_N2 = jnp.ones((N2,), dtype=dtype)

    if full_rank:
        # Full-rank initialization via log-Sinkhorn
        key, subkey = jax.random.split(key)
        C_random = jax.random.uniform(subkey, (N1, r), dtype=dtype)
        Q = ott_log_sinkhorn(C_random, a, gQ, gamma,
                                max_iter=max_iter,
                                balanced=True)

        key, subkey = jax.random.split(key)
        C_random = jax.random.uniform(subkey, (N2, r2), dtype=dtype)
        R = ott_log_sinkhorn(C_random, b, gR, gamma,
                                max_iter=max_iter,
                                balanced=True)

        # Compute updated inner marginals
        gR_new = R.T @ one_N2
        gQ_new = Q.T @ one_N1

        key, subkey = jax.random.split(key)
        C_random = jax.random.uniform(subkey, (r, r2), dtype=dtype)
        T = ott_log_sinkhorn(C_random, gQ_new, gR_new, gamma,
                                max_iter=max_iter,
                                balanced=True)

        # Inner inverse coupling
        if r == r2:
            Lambda = jnp.linalg.inv(T)
        else:
            Lambda = jnp.diag(1.0 / gQ_new) @ T @ jnp.diag(1.0 / gR_new)

    else:
        # Rank-2 initialization (Scetbon et al. 2021)
        if r != r2:
            raise ValueError("Rank-2 init requires equal inner ranks.")
        g = gQ
        lambd = jnp.minimum(jnp.min(a), jnp.min(b))
        lambd = jnp.minimum(lambd, jnp.min(g)) / 2.0

        # Sample or deterministic
        if rank2_random:
            key, *splits = random.split(key, 4)
            a1 = utils__random_simplex_sample(N1, splits[0], dtype)
            b1 = utils__random_simplex_sample(N2, splits[1], dtype)
            g1 = utils__random_simplex_sample(r, splits[2], dtype)
        else:
            g1 = jnp.arange(1, r + 1, dtype=dtype)
            g1 = g1 / jnp.sum(g1)
            a1 = jnp.arange(1, N1 + 1, dtype=dtype)
            a1 = a1 / jnp.sum(a1)
            b1 = jnp.arange(1, N2 + 1, dtype=dtype)
            b1 = b1 / jnp.sum(b1)

        a2 = (a - lambd * a1) / (1 - lambd)
        b2 = (b - lambd * b1) / (1 - lambd)
        g2 = (g - lambd * g1) / (1 - lambd)

        Q = lambd * jnp.outer(a1, g1) + (1 - lambd) * jnp.outer(a2, g2)
        R = lambd * jnp.outer(b1, g1) + (1 - lambd) * jnp.outer(b2, g2)

        gR_new = R.T @ one_N2
        gQ_new = Q.T @ one_N1

        T = (1 - lambd) * jnp.diag(g) + lambd * jnp.outer(gR_new, gQ_new)
        Lambda = jnp.linalg.inv(T)

    return Q, R, T, Lambda

In [None]:
def gd__Wasserstein_Grad(C_or_Cfactors, Q, R, Lambda, full_grad=True, low_rank=False):
    if low_rank:
        C1, C2 = C_or_Cfactors
        gradQ = C1 @ ((C2 @ R) @ Lambda.T)
    else:
        C = C_or_Cfactors
        gradQ = (C @ R) @ Lambda.T

    if full_grad:
        N1 = Q.shape[0]
        one_N1 = jnp.ones((N1,), dtype=Q.dtype)
        gQ = Q.T @ one_N1
        w1 = jnp.diag((gradQ.T @ Q) @ jnp.diag(1.0 / gQ))
        gradQ = gradQ - jnp.outer(one_N1, w1)

    if low_rank:
        gradR = C2.T @ ((C1.T @ Q) @ Lambda)
    else:
        gradR = (C.T @ Q) @ Lambda

    if full_grad:
        N2 = R.shape[0]
        one_N2 = jnp.ones((N2,), dtype=R.dtype)
        gR = R.T @ one_N2
        w2 = jnp.diag(jnp.diag(1.0 / gR) @ (R.T @ gradR))
        gradR = gradR - jnp.outer(one_N2, w2)

    return gradQ, gradR

def gd__compute_grad_A(C, Q, R, Lambda, gamma,
                       semiRelaxedLeft, semiRelaxedRight,
                       Wasserstein=True, FGW=False,
                       A=None, B=None,
                       alpha=0.0,
                       unbalanced=False,
                       full_grad=True,
                       low_rank=False,
                       C_factors=None, A_factors=None, B_factors=None):
    """
    JAX version of gradient computation for Wasserstein, GW and FGW.
    If `low_rank` is True, it uses low-rank factorized cost matrices.
    """

    r = Lambda.shape[0]
    one_r = jnp.ones((r,))
    One_rr = jnp.outer(one_r, one_r)

    if low_rank:
        # Assume: A_factors = (A1, A2), B_factors = (B1, B2), C_factors = (C1, C2)
        A1, A2 = A_factors
        B1, B2 = B_factors

        gradQ = -4 * (A1 @ (A2 @ (Q @ Lambda @ ((R.T @ B1) @ (B2 @ R)) @ Lambda.T)))
        gradR = -4 * (B1 @ (B2 @ (R @ (Lambda.T @ ((Q.T @ A1) @ (A2 @ Q)) @ Lambda))))

        if full_grad:
            N1, N2 = Q.shape[0], R.shape[0]
            one_N1 = jnp.ones((N1,))
            one_N2 = jnp.ones((N2,))
            gQ = Q.T @ one_N1
            gR = R.T @ one_N2

            MR = Lambda.T @ ((Q.T @ A1) @ (A2 @ Q)) @ Lambda @ ((R.T @ B1) @ (B2 @ R)) @ jnp.diag(1. / gR)
            MQ = Lambda @ ((R.T @ B1) @ (B2 @ R)) @ Lambda.T @ ((Q.T @ A1) @ (A2 @ Q)) @ jnp.diag(1. / gQ)
            gradQ += 4 * jnp.outer(one_N1, jnp.diag(MQ))
            gradR += 4 * jnp.outer(one_N2, jnp.diag(MR))

        # Wasserstein gradients in low-rank form
        gradQW, gradRW = gd__Wasserstein_Grad(C_factors, Q, R, Lambda, full_grad=full_grad, low_rank=True)

        gradQ = (1 - alpha) * gradQW + (alpha / 2) * gradQ
        gradR = (1 - alpha) * gradRW + (alpha / 2) * gradR

    else:
        if Wasserstein:
            gradQ, gradR = gd__Wasserstein_Grad(C, Q, R, Lambda, full_grad=full_grad)
        elif A is not None and B is not None:
            if not semiRelaxedLeft and not semiRelaxedRight and not unbalanced:
                gradQ = -4 * (A @ Q) @ Lambda @ (R.T @ B @ R) @ Lambda.T
                gradR = -4 * (B @ R @ Lambda.T) @ (Q.T @ A @ Q) @ Lambda
            elif semiRelaxedRight:
                gradQ = -4 * (A @ Q) @ Lambda @ (R.T @ B @ R) @ Lambda.T
                gradR = 2 * (B @ B) @ R @ One_rr - 4 * (B @ R @ Lambda.T) @ (Q.T @ A @ Q) @ Lambda
            elif semiRelaxedLeft:
                gradQ = 2 * (A @ A) @ Q @ One_rr - 4 * (A @ Q) @ Lambda @ (R.T @ B @ R) @ Lambda.T
                gradR = -4 * (B @ R @ Lambda.T) @ (Q.T @ A @ Q) @ Lambda
            elif unbalanced:
                gradQ = 2 * (A @ A) @ Q @ One_rr - 4 * (A @ Q) @ Lambda @ (R.T @ B @ R) @ Lambda.T
                gradR = 2 * (B @ B) @ R @ One_rr - 4 * (B @ R @ Lambda.T) @ (Q.T @ A @ Q) @ Lambda

            if full_grad:
                N1, N2 = Q.shape[0], R.shape[0]
                one_N1 = jnp.ones((N1,))
                one_N2 = jnp.ones((N2,))
                gQ = Q.T @ one_N1
                gR = R.T @ one_N2
                F = Q @ Lambda @ R.T
                MR = Lambda.T @ Q.T @ A @ F @ B @ R @ jnp.diag(1. / gR)
                MQ = Lambda @ R.T @ B @ F.T @ A @ Q @ jnp.diag(1. / gQ)
                gradQ += 4 * jnp.outer(one_N1, jnp.diag(MQ))
                gradR += 4 * jnp.outer(one_N2, jnp.diag(MR))

            if FGW:
                gradQW, gradRW = gd__Wasserstein_Grad(C, Q, R, Lambda, full_grad=full_grad)
                gradQ = (1 - alpha) * gradQW + alpha * gradQ
                gradR = (1 - alpha) * gradRW + alpha * gradR
        else:
            raise ValueError("Provide either Wasserstein=True or distance matrices A and B for GW problem.")

    normalizer = jnp.max(jnp.array([jnp.max(jnp.abs(gradQ)), jnp.max(jnp.abs(gradR))]))
    gamma_k = gamma / normalizer

    return gradQ, gradR, gamma_k



def gd__compute_grad_B(C=None, Q=None, R=None, Lambda=None, gQ=None, gR=None, gamma=1.0,
                       Wasserstein=True, FGW=False, A=None, B=None, alpha=0.0,
                       low_rank=False, C_factors=None, A_factors=None, B_factors=None):
    '''
    JAX version of the Wasserstein / GW / FGW gradient w.r.t. the transport plan T.
    Supports both full and low-rank computation.
    '''
    if low_rank:
        # Gradient using low-rank approximation
        C1, C2 = C_factors
        gradLambda = (1 - alpha) * ((Q.T @ C1) @ (C2 @ R))

        if A_factors is not None and B_factors is not None:
            A1, A2 = A_factors
            B1, B2 = B_factors
            grad_GW = -4 * ((Q.T @ A1) @ (A2 @ Q)) @ Lambda @ ((R.T @ B1) @ (B2 @ R))
            gradLambda += (alpha / 2.0) * grad_GW
    else:
        if Wasserstein:
            gradLambda = Q.T @ C @ R
        else:
            gradLambda = -4 * Q.T @ A @ Q @ Lambda @ R.T @ B @ R
            if FGW:
                gradLambda = (1 - alpha) * (Q.T @ C @ R) + alpha * gradLambda

    # Final gradient
    gradT = jnp.diag(1.0 / gQ) @ gradLambda @ jnp.diag(1.0 / gR)
    gamma_T = gamma / jnp.max(jnp.abs(gradT))
    return gradT, gamma_T

<h2> Implementation of the FRLC Solver (Low-Rank Optimal Transport)

The paper relies on a low-rank optimal transport solver called FRLC (Factor Relaxation with Latent Coupling). Here's its implementation:

In [None]:
def FRLC_opt(C, A=None, B=None, C_factors=None, A_factors=None, B_factors=None,
             a=None, b=None, tau_in=50, tau_out=50, gamma=90, r=10, r2=None,
             max_iter=200, Wasserstein=True, returnFull=False, FGW=False, alpha=0.0,
             initialization='Full', init_args=None, full_grad=True,
             convergence_criterion=True, tol=1e-5, min_iter=25,
             max_inneriters_balanced=300, max_inneriters_relaxed=50,
             diagonalize_return=False, low_rank=False):
    """
    FRLC Optimal Transport solver. Supports dense and low-rank cost formulations.
    """
    if r2 is None:
        r2 = r

    n, m = (C.shape[0], C.shape[1]) if C is not None else (C_factors[0].shape[0], C_factors[1].shape[0])

    if a is None:
        a = jnp.ones(n) / n
    if b is None:
        b = jnp.ones(m) / m

    # Initialize coupling decomposition Q, R, Lambda
    if initialization == 'Full':
        T0 = ott_log_sinkhorn(C, a, b, max_inneriters_balanced)
    elif initialization == 'Identity':
        T0 = jnp.outer(a, b)
    elif initialization == 'Rank-1':
        T0 = jnp.outer(a, b)
    elif initialization == 'Random':
        T0 = jax.random.uniform(jax.random.PRNGKey(0), shape=(n, m))
        T0 = T0 / T0.sum()
    elif initialization == 'Given' and init_args is not None:
        T0 = init_args
    else:
        raise ValueError("Unsupported initialization method")

    U, s, Vt = jnp.linalg.svd(T0, full_matrices=False)
    Q = U[:, :r]
    R = Vt[:r, :].T
    Lambda = jnp.diag(s[:r])

    err = jnp.inf
    it = 0
    converged = False

    while not converged and it < max_iter:
        # Gradient via factorized formulation or dense formulation
        gradQ, gradR, gamma_k = gd__compute_grad_A(
            C_factors, A_factors, B_factors, Q, R, Lambda, gamma, alpha=alpha, full_grad=full_grad, low_rank=low_rank
        )

        # Update Q
        Q -= gamma_k * gradQ
        Q = Q / jnp.linalg.norm(Q, axis=0, keepdims=True)

        T_mid = Q @ Lambda @ R.T

        # Solve subproblem for R
        T_R = ott_log_sinkhorn(C if not low_rank else None, a, b, max_inneriters_relaxed, init=T_mid)
        U, s, Vt = jnp.linalg.svd(T_R, full_matrices=False)
        R = Vt[:r2, :].T
        Lambda = jnp.diag(s[:r2])

        gradQ, gradR, gamma_k = gd__compute_grad_B(
            C, A, B, Q, R, Lambda, gamma, alpha=alpha, full_grad=full_grad, low_rank==low_rank, C_factors=C_factors, A_factors=A_factors, B_factors=B_factors
        )

        # Update R
        R -= gamma_k * gradR
        R = R / jnp.linalg.norm(R, axis=0, keepdims=True)

        T_mid = Q @ Lambda @ R.T

        # Solve subproblem for Q
        T_Q = ott_log_sinkhorn(C if not low_rank else None, a, b, max_inneriters_relaxed, init=T_mid)
        U, s, Vt = jnp.linalg.svd(T_Q, full_matrices=False)
        Q = U[:, :r]
        Lambda = jnp.diag(s[:r])

        # Convergence check
        if convergence_criterion:
            err = jnp.linalg.norm(T_Q - T_mid, ord='fro') / jnp.linalg.norm(T_Q, ord='fro')
            if it >= min_iter and err < tol:
                converged = True

        it += 1

    if returnFull:
        return Q @ Lambda @ R.T
    elif diagonalize_return:
        return Q, Lambda, R.T
    else:
        return Q, R, Lambda

    
def FRLC_compute_OT_cost(X, Y, C = None, Monge_clusters = None, sq_Euclidean = True):
    """
    Compute the optimal transport cost in linear space and time (without coupling), in JAX.
    Supports squared Euclidean cost via OTT cost object.
    """
    if Monge_clusters is None or len(Monge_clusters) == 0:
        return 0.0

    def compute_pair_cost(pair):
        idx1, idx2 = pair
        if C is not None:
            return C[idx1, idx2]
        else:
            diff = X[idx1] - Y[idx2]
            if sq_Euclidean:
                return jnp.sum(diff**2)
            else:
                return jnp.linalg.norm(diff)

    pair_costs = jax.vmap(compute_pair_cost)(jnp.array(Monge_clusters))
    total_cost = jnp.sum(pair_costs)
    return total_cost / len(Monge_clusters)

In [None]:
# 1. Calcul de la matrice des coûts (distance euclidienne)
def cdist_jax(X, Y):
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>
    X_norm = jnp.sum(X ** 2, axis=1)[:, None]
    Y_norm = jnp.sum(Y ** 2, axis=1)[None, :]
    C = jnp.sqrt(jnp.maximum(X_norm + Y_norm - 2 * jnp.dot(X, Y.T), 0.0))
    return C

C = cdist_jax(X, Y)

try:
    # 2. Appel à FRLC_opt (on passe dtype=jnp.float32)
    Q, R, T, errs = FRLC_opt(
        C=C,
        gamma=30,
        r=40,
        max_iter=100,
        tau_in=100000,
        low_rank=False
    )

    # 3. Calcul de la matrice de couplage P complète
    inv_sum_Q = 1.0 / jnp.sum(Q, axis=0)  # shape (r,)
    inv_sum_R = 1.0 / jnp.sum(R, axis=0)  # shape (r,)
    P = (Q
         @ jnp.diag(inv_sum_Q)
         @ T
         @ jnp.diag(inv_sum_R)
         @ R.T)  # shape (n_X, n_Y)

    # 4. Extraction des paires (i,j) où P[i,j] > 0
    ij = jnp.argwhere(P > 0)  # shape (num_pairs, 2)
    Monge_clusters = [(int(i), int(j)) for i, j in ij]

    # 5. Calcul du coût OT via la fonction JAXisée
    cost_frlc = FRLC_compute_OT_cost(
        X, Y,
        C=C,
        Monge_clusters=Monge_clusters,
        sq_Euclidean=True
    )

    print(f'FRLC cost: {cost_frlc}')

    # 6. Approximation du couplage pour l'extraction de correspondances
    P_approx = Q @ T @ R.T  # shape (n_X, n_Y)
    matches_Y = jnp.argmax(P_approx, axis=1)  # shape (n_X,)

    # 7. Construction de la liste F de paires indices
    F = [
        (jnp.array([i], dtype=jnp.int32),
         jnp.array([int(matches_Y[i])], dtype=jnp.int32))
        for i in range(X.shape[0])
    ]

except Exception as e:
    print(f'FRLC failed for sample size {X.shape[0]}: {e}')
    raise e

Iteration: 0
Iteration: 25
Iteration: 50
Iteration: 75
FRLC cost: 23.961475372314453


# 5.

In [None]:
'''
--------------
Code for gradients assuming low-rank distance matrices C, A, B
--------------
'''

def gd__compute_grad_A_LR(C_factors, A_factors, B_factors, Q, R, Lambda, gamma, alpha=0.0, full_grad=False):
    
    N1, N2 = C_factors[0].shape[0], C_factors[1].shape[1]

    if A_factors is not None and B_factors is not None:
        A1, A2 = A_factors
        B1, B2 = B_factors

        # GW gradients
        gradQ = -4 * (A1 @ (A2 @ (Q @ Lambda @ ((R.T @ B1) @ (B2 @ R)) @ Lambda.T)))
        gradR = -4 * (B1 @ (B2 @ (R @ (Lambda.T @ ((Q.T @ A1) @ (A2 @ Q)) @ Lambda))))

        one_N1 = jnp.ones((N1,), dtype=Q.dtype)
        one_N2 = jnp.ones((N2,), dtype=R.dtype)

        if full_grad:
            gQ = Q.T @ one_N1
            gR = R.T @ one_N2

            MR = (Lambda.T @ ((Q.T @ A1) @ (A2 @ Q)) @ Lambda
                  @ ((R.T @ B1) @ (B2 @ R)) @ jnp.diag(1.0 / gR))
            MQ = (Lambda @ ((R.T @ B1) @ (B2 @ R)) @ Lambda.T
                  @ ((Q.T @ A1) @ (A2 @ Q)) @ jnp.diag(1.0 / gQ))

            gradQ += 4 * jnp.outer(one_N1, jnp.diag(MQ))
            gradR += 4 * jnp.outer(one_N2, jnp.diag(MR))
    else:
        gradQ = jnp.zeros_like(Q)
        gradR = jnp.zeros_like(R)

    # Appel à une version jaxifiée de gd__Wasserstein_Grad_LR
    gradQW, gradRW = gd__Wasserstein_Grad_LR(C_factors, Q, R, Lambda, full_grad=full_grad)

    gradQ = (1 - alpha) * gradQW + (alpha / 2.0) * gradQ
    gradR = (1 - alpha) * gradRW + (alpha / 2.0) * gradR

    normalizer = jnp.maximum(jnp.max(jnp.abs(gradQ)), jnp.max(jnp.abs(gradR)))
    gamma_k = gamma / normalizer

    return gradQ, gradR, gamma_k

def gd__compute_grad_B_LR(C_factors, A_factors, B_factors, Q, R, Lambda, gQ, gR, gamma, alpha=0.0):
    """
    Low-rank gradient computation in JAX for Wasserstein / Gromov-Wasserstein.
    """
    C1, C2 = C_factors  # (N1, rC), (rC, N2)
    gradLambda = 0.0

    if A_factors is not None and B_factors is not None:
        A1, A2 = A_factors  # (N1, rA), (rA, N1)
        B1, B2 = B_factors  # (N2, rB), (rB, N2)
        term_A = (Q.T @ A1) @ (A2 @ Q)         # shape: (r, r)
        term_B = (R.T @ B1) @ (B2 @ R)         # shape: (r, r)
        gradLambda = -4.0 * term_A @ Lambda @ term_B

    term_C = (Q.T @ C1) @ (C2 @ R)             # shape: (r, r)
    gradLambda = (1 - alpha) * term_C + (alpha / 2.0) * gradLambda

    gradT = jnp.diag(1.0 / gQ) @ gradLambda @ jnp.diag(1.0 / gR)
    gamma_T = gamma / jnp.max(jnp.abs(gradT))
    return gradT, gamma_T

def gd__Wasserstein_Grad_LR(C_factors, Q, R, Lambda, full_grad=True):
    """
    JAX version of Wasserstein gradient with low-rank cost approximation:
    C ≈ C1 @ C2.T
    """
    C1, C2 = C_factors

    gradQ = C1 @ ((C2 @ R) @ Lambda.T)
    if full_grad:
        N1 = Q.shape[0]
        one_N1 = jnp.ones((N1,), dtype=Q.dtype)
        gQ = Q.T @ one_N1
        w1 = jnp.diag((gradQ.T @ Q) @ jnp.diag(1.0 / gQ))
        gradQ = gradQ - jnp.outer(one_N1, w1)

    gradR = C2.T @ ((C1.T @ Q) @ Lambda)
    if full_grad:
        N2 = R.shape[0]
        one_N2 = jnp.ones((N2,), dtype=R.dtype)
        gR = R.T @ one_N2
        w2 = jnp.diag(jnp.diag(1.0 / gR) @ (R.T @ gradR))
        gradR = gradR - jnp.outer(one_N2, w2)

    return gradQ, gradR

In [14]:
def rank_annealing__factors(n):
    """
    Return list of all factors of an integer
    """
    n = int(n)  # Conversion pour compatibilité avec jnp.arange
    candidates = jnp.arange(1, jnp.floor(jnp.sqrt(n)) + 1).astype(int)
    divisible = (n % candidates) == 0
    factors1 = candidates[divisible]
    factors2 = n // factors1
    all_factors = jnp.concatenate([factors1, factors2])
    unique_factors = jnp.unique(all_factors)
    return unique_factors

def rank_annealing__max_factor_lX(n, max_X):
    """
    Find max factor of n , such that max_factor \leq max_X
    """
    factor_lst = rank_annealing__factors(n)
    factors_leq_max = factor_lst[factor_lst <= max_X]
    return jnp.max(factors_leq_max)

def rank_annealing__min_sum_partial_products_with_factors(n, k, C):
    """
    Dynamic program to compute the rank-schedule, subject to a constraint of intermediates being \leq C

    Parameters
    ----------
    n: int
        The dataset size to be factored into a rank-scheduler. Assumed to be non-prime.
    k: int
        The depth of the hierarchy.
    C: int
        A constraint on the maximal intermediate rank across the hierarchy.
    
    """
    INF = 1e10  # Large constant instead of float('inf') for JAX compatibility

    dp = jnp.full((n+1, k+1), INF)
    choice = jnp.full((n+1, k+1), -1)

    def init_base_case(dp, choice):
        d = jnp.arange(1, n+1)
        mask = d <= C
        dp = dp.at[d[mask], 1].set(d[mask])
        choice = choice.at[d[mask], 1].set(d[mask])
        return dp, choice

    dp, choice = init_base_case(dp, choice)

    for t in range(2, k+1):
        for d in range(1, n+1):
            if dp[d, t-1] >= INF:
                continue
            for r in range(1, min(C, d)+1):
                if d % r == 0:
                    candidate = r + r * dp[d // r, t-1]
                    if candidate < dp[d, t]:
                        dp = dp.at[d, t].set(candidate)
                        choice = choice.at[d, t].set(r)

    if dp[n, k] >= INF:
        return None, []

    # Backtracking
    factors = []
    d_cur, t_cur = n, k
    while t_cur > 0:
        r_cur = int(choice[d_cur, t_cur])
        factors.append(r_cur)
        d_cur //= r_cur
        t_cur -= 1

    return dp[n, k], factors

def rank_annealing__optimal_rank_schedule(n, hierarchy_depth=6, max_Q=int(2**10), max_rank=16):
    """
    A function to compute the optimal rank-scheduler of refinement.
    
    Parameters
    ----------
    n: int
        Size of the input dataset -- cannot be a prime number
    hierarchy_depth: int
        Maximal permissible depth of the multi-scale hierarchy
    max_Q: int
        Maximal rank at terminal base case (before reducing the \leq max_Q rank coupling to a 1-1 alignment)
    max_rank: int
        Maximal rank at the intermediate steps of the rank-schedule
        
    """
    Q = int(rank_annealing__max_factor_lX(n, max_Q))
    ndivQ = int(n // Q)

    _, rank_schedule = rank_annealing__min_sum_partial_products_with_factors(ndivQ, hierarchy_depth, max_rank)
    rank_schedule = sorted(rank_schedule)
    rank_schedule.append(Q)
    rank_schedule = [x for x in rank_schedule if x != 1]

    print(f'Optimized rank-annealing schedule: {rank_schedule}')

    assert functools.reduce(operator.mul, rank_schedule, 1) == n, "Error! Rank-schedule does not factorize n!"

    return rank_schedule

This function allows representing the squared Euclidean cost matrix in factorized form to avoid storing the full matrix, which is crucial for scaling the algorithm.

In [None]:
def compute_lr_sqeuclidean_matrix(X_s,
                                  X_t,
                                  rescale_cost: bool = False):
    """
    Adapted to JAX from the low-rank squared Euclidean cost decomposition,
    as in Scetbon, Cuturi & Peyré (2021), Section 3.5, Proposition 1.
    """

    ns, dim = X_s.shape
    nt, _ = X_t.shape

    # First low-rank term (source-side)
    sum_Xs_sq = jnp.sum(X_s ** 2, axis=1, keepdims=True)         # (ns, 1)
    ones_ns = jnp.ones((ns, 1), dtype=X_s.dtype)                 # (ns, 1)
    neg_two_Xs = -2.0 * X_s                                      # (ns, d)
    M1 = jnp.concatenate([sum_Xs_sq, ones_ns, neg_two_Xs], axis=1)  # (ns, d+2)

    # Second low-rank term (target-side)
    ones_nt = jnp.ones((nt, 1), dtype=X_t.dtype)                 # (nt, 1)
    sum_Xt_sq = jnp.sum(X_t ** 2, axis=1, keepdims=True)         # (nt, 1)
    M2 = jnp.concatenate([ones_nt, sum_Xt_sq, X_t], axis=1)      # (nt, d+2)

    if rescale_cost:
        # Use jnp.max over entire arrays
        max_M1 = jnp.max(jnp.abs(M1))
        max_M2 = jnp.max(jnp.abs(M2))

        # Avoid division by zero
        if max_M1 > 0:
            M1 = M1 / jnp.sqrt(max_M1)
        if max_M2 > 0:
            M2 = M2 / jnp.sqrt(max_M2)

    return M1, M2.T

<h2> Implementation of Hierarchical Refinement

The core of the paper is the "Hierarchical Refinement" algorithm that uses low-rank decompositions recursively to build a full transport plan between two datasets. \
The key insight is that optimal factors of low-rank optimal transport co-cluster points with their image under the Monge map

In [None]:
class HierarchicalRefinementOT:
    """
    A class to perform Hierarchical OT refinement with optional (CPU) parallelization.
    
    Attributes
    ----------
    C : torch.tensor
        The cost matrix of shape (N, N), currently assumed square for hierarchical OT.
        Can represent general user-defined costs.
    rank_schedule : list
        The list of ranks for each hierarchical level -- i.e. the rank-annealing schedule.
    solver : callable
        A low-rank OT solver that takes a cost submatrix and returns Q, R, diagG, errs.
    solver_params: Dict[str, Any], optional
        Additional parameters for the low-rank solver. If None, default values are used.
    device : str
        The device ('cpu' or 'cuda') to be used for computations.
    base_rank : int
        Base-case rank at which to stop subdividing clusters.
    clustering_type : str
        'soft' or 'hard'. Determines how cluster assignments are computed after each OT solve.
    plot_clusterings : bool
        Whether to plot the Q and R matrices at each step for debugging.
    parallel : bool
        Whether to execute each subproblem at a level in parallel.
    num_processes : int or None
        Number of worker processes to spawn (if `parallel=True`). Defaults to `None` which uses `mp.cpu_count()`.
    X, Y : torch.tensor
        The point-clouds for the first dataset (X) and the second dataset (Y)
    N : int
        The size of the dataset.
    Monge_clusters : list (tuples of type torch.float)
        A list containing the Monge-map pairings
    """
    
    def __init__(self,
                C: torch.Tensor,
                 rank_schedule: List[int],
                 solver: Callable = FRLC_opt,
                 solver_params: Union[Dict[str, Any] , None] = None,
                 device: str = 'cpu',
                 base_rank: int = 1,
                 clustering_type: str = 'soft',
                 plot_clusterings: bool = False,
                 parallel: bool = False,
                 num_processes: Union[int, None] = None
                ):
    
        self.C = C.to(device)
        self.rank_schedule = rank_schedule
        self.solver = solver
        self.device = device
        self.base_rank = base_rank
        self.clustering_type = clustering_type
        self.plot_clusterings =  plot_clusterings
        self.parallel = parallel
        self.num_processes = num_processes
        
        # Point clouds optional attributes
        self.X, self.Y = None, None
        self.N = C.shape[0]
        self.Monge_clusters = None
        # This is a dummy line -- this init doesn't compute C or its factorization
        self.sq_Euclidean = False

        # Setting parameters to use with the FRLC solver
        default_solver_params = {
            'gamma' : 30,
            'max_iter' : 60,
            'min_iter' : 25,
            'max_inneriters_balanced' : 100,
            'max_inneriters_relaxed' : 40,
            'printCost' : False,
            'tau_in' : 100000
        }
        if solver_params is not None:
            default_solver_params.update(solver_params)
        self.solver_params = default_solver_params
        
        assert C.shape[0] == C.shape[1], "Currently assume square costs so that |X| = |Y| = N"
    
    @classmethod
    def init_from_point_clouds(cls,
                            X: torch.Tensor,
                            Y: torch.Tensor,
                            rank_schedule: List[int],
                            distance_rank_schedule: Union[List[int], None] = None,
                            solver: Callable = lambda *args: FRLC_opt(*args, low_rank=True),
                            solver_params: Union[Dict[str, Any] , None] = None,
                            device: str = 'cpu',
                            base_rank: int = 1,
                            clustering_type: str = 'soft',
                            plot_clusterings: bool = False,
                            parallel: bool = False,
                            num_processes: Union[int, None] = None,
                              sq_Euclidean = False):
        r"""
        Constructor for initializing from point clouds.
        
        Attributes
        ----------
        X : torch.tensor
            The point-cloud of shape N for measure \mu
        Y: torch.tensor
            Point cloud of shape N for measure \nu
        distance_rank_schedule: List[int]
            A separate rank-schedule for the low-rank distance matrix being factorized.
        sq_Euclidean : bool
            If True, assumes squared Euclidean cost. Otherwise, defaults to Euclidean.
            Needed for the point-cloud variant, in order to define a distance metric
            to use for the low-rank approximation of C.
        """
        
        obj = cls.__new__(cls)
        
        obj.X = X
        obj.Y = Y
        obj.rank_schedule = rank_schedule
        
        if distance_rank_schedule is None:
            # Default: assume distance rank schedule is identical to rank schedule for coupling.
            obj.distance_rank_schedule = rank_schedule
        else:
            obj.distance_rank_schedule = distance_rank_schedule
        
        obj.solver = solver
        obj.device = device
        obj.base_rank = base_rank
        obj.clustering_type = clustering_type
        obj.plot_clusterings =  plot_clusterings
        obj.parallel = parallel
        obj.num_processes = num_processes
        obj.N = X.shape[0]
        
        # Cost-mat an optional attribute
        obj.C = None
        obj.Monge_clusters = None
        obj.sq_Euclidean = sq_Euclidean

        # Setting parameters to use with the FRLC solver
        default_solver_params = {
            'gamma' : 30,
            'max_iter' : 60,
            'min_iter' : 25,
            'max_inneriters_balanced' : 100,
            'max_inneriters_relaxed' : 40,
            'printCost' : False,
            'tau_in' : 100000
        }
        if solver_params is not None:
            default_solver_params.update(solver_params)
        obj.solver_params = default_solver_params
        
        assert X.shape[0] == Y.shape[0], "Currently assume square costs so that |X| = |Y| = N"
        
        return obj

    def run(self, return_as_coupling: bool = False):
        """
        Routine to run hierarchical refinement.
        
        Parameters
        ----------
        return_as_coupling : bool
            Whether to return a full coupling matrix (size NxN) 
            or a list of (idxX, idxY) co-clusters / assignments.
        
        Returns
        -------
        list of (idxX, idxY) pairs OR torch.tensor
            If return_as_coupling=False: returns a list of tuples (idxX, idxY) for each co-cluster.
            If return_as_coupling=True: returns a dense coupling matrix of shape (N, N).
        """
        if self.parallel:
            """ WARNING: not currently implemented in this class! See HR_OT_parallelized instead! """
            return self._hierarchical_refinement_parallelized(return_as_coupling = return_as_coupling)
        else:
            return self._hierarchical_refinement(return_as_coupling = return_as_coupling)

    def _hierarchical_refinement(self, return_as_coupling: bool = False):
        """
        Single-process (serial) Hierarchical Refinement
        """

        # Define partitions
        F_t = [(torch.arange( self.N , device=self.device), 
                torch.arange( self.N , device=self.device))]
        
        for i, rank_level in enumerate(self.rank_schedule):
            # Iterate over ranks in the scheduler
            F_tp1 = []

            if i == len(self.rank_schedule)-1:
                fin_iters = int(self.N / rank_level)
                print(f'Last level, rank chunk-size {rank_level} with {fin_iters} iterations to completion.')
                j = 0
            
            for (idxX, idxY) in F_t:

                if i == len(self.rank_schedule)-1:
                    print(f'{j}/{fin_iters} of final-level iterations to completion')
                    j += 1
                
                if len(idxX) <=self.base_rank or len(idxY) <= self.base_rank:
                    # Return tuple of base-rank sized index sets (e.g. (x,T(x)) for base_rank=1)
                    F_tp1.append( ( idxX, idxY ) )
                    continue
                
                if self.C is not None:
                    Q,R = self._solve_prob( idxX, idxY, rank_level)
                else:
                    rank_D = self.distance_rank_schedule[i]
                    Q,R = self._solve_LR_prob( idxX, idxY, rank_level, rank_D )
                
                if self.plot_clusterings:
                    # If visualizing the Q - R clustering matrices.
                    
                    plt.figure(figsize=(12, 5))
                    plt.subplot(1, 2, 1)
                    plt.imshow(Q.detach().cpu().numpy(), aspect='auto', cmap='viridis')
                    plt.title(f"Q Clustering Level {i+1}")
                    plt.colorbar()
    
                    plt.subplot(1, 2, 2)
                    plt.imshow(R.detach().cpu().numpy(), aspect='auto', cmap='viridis')
                    plt.title(f"R Clustering Level {i+1}")
                    plt.colorbar()
                    plt.show()
                
                # Next level cluster capacity
                capacity = int(self.N) / int(torch.prod(torch.Tensor(self.rank_schedule[0:i+1])))
                capacity = int(capacity)
                
                idx_seenX, idx_seenY = torch.arange(Q.shape[0], device=self.device), \
                                                    torch.arange(R.shape[0], device=self.device)
                
                # Split by hard or soft-clustering
                if self.clustering_type == 'soft':
                    # If using a solver which returns "soft" clusterings, must strictly fill partitions to capacities.
                    
                    for z in range(rank_level):
                        
                        topk_values, topk_indices_X = torch.topk( Q[idx_seenX][:,z], k=capacity )
                        idxX_z = idxX[idx_seenX[topk_indices_X]]
                        topk_values, topk_indices_Y = torch.topk( R[idx_seenY][:,z], k=capacity )
                        idxY_z = idxY[idx_seenY[topk_indices_Y]]
                        
                        F_tp1.append(( idxX_z, idxY_z ))
                        
                        idx_seenX = idx_seenX[~torch.isin(idx_seenX, idx_seenX[topk_indices_X])]
                        idx_seenY = idx_seenY[~torch.isin(idx_seenY, idx_seenY[topk_indices_Y])]
                
                elif self.clustering_type == 'hard':
                    # If using a solver which returns "hard" clusterings, can exactly take argmax.
                    
                    zX = torch.argmax(Q, axis=1) # X-assignments
                    zY = torch.argmax(R, axis=1) # Y-assignments
                    
                    for z in range(rank_level):
                        
                        idxX_z = idxX[zX == z]
                        idxY_z = idxY[zY == z]
    
                        assert len(idxX_z) == len(idxY_z) == capacity, \
                                            "Assertion failed! Not a hard-clustering function, or point sets of unequal size!"
                        
                        F_tp1.append((idxX_z, idxY_z))
                        
            F_t = F_tp1
        
        self.Monge_clusters = F_t

        
        if return_as_coupling is False:
            return self.Monge_clusters
        else:
            return self._compute_coupling_from_Ft()

    def _solve_LR_prob(self, idxX, idxY, rank_level, rankD, eps=0.04):
        
        """
        Solve problem for low-rank coupling under a low-rank factorization of distance matrix.
        """
        
        _x0, _x1 = torch.index_select(self.X, 0, idxX), torch.index_select(self.Y, 0, idxY)
        
        if rankD < _x0.shape[0]:
            
            C_factors, A_factors, B_factors = self.get_dist_mats(_x0, _x1, 
                                                                 rankD, eps, 
                                                                 self.sq_Euclidean )
            
            # Solve a low-rank OT sub-problem with black-box solver
            Q, R, diagG, errs = self.solver(C_factors, A_factors, B_factors,
                                       gamma = self.solver_params['gamma'],
                                       r = rank_level,
                                       max_iter = self.solver_params['max_iter'],
                                       device=self.device,
                                       min_iter = self.solver_params['min_iter'],
                                       max_inneriters_balanced = self.solver_params['max_inneriters_balanced'],
                                       max_inneriters_relaxed = self.solver_params['max_inneriters_relaxed'],
                                       diagonalize_return = True,
                                       printCost = False, tau_in = self.solver_params['tau_in'],
                                        dtype = _x0.dtype)
        
        else:
            
            # Final base instance -- cost within-cluster costs explicitly
            if self.sq_Euclidean:
                C_XY = torch.cdist(_x0, _x1)**2
                
            else:
                # normal Euclidean distance otherwise
                C_XY = torch.cdist(_x0, _x1)
            
            Q, R, diagG, errs = FRLC_opt(C_XY,
                                   gamma = self.solver_params['gamma'],
                                   r = rank_level,
                                   max_iter = self.solver_params['max_iter'],
                                   device = self.device,
                                   min_iter = self.solver_params['min_iter'],
                                   max_inneriters_balanced = self.solver_params['max_inneriters_balanced'],
                                   max_inneriters_relaxed = self.solver_params['max_inneriters_relaxed'],
                                   diagonalize_return=True,
                                   printCost=False, tau_in = self.solver_params['tau_in'],
                                       dtype = C_XY.dtype)
            
        return Q, R
        
    def _solve_prob(self, idxX, idxY, rank_level):
        """
        Solve problem for low-rank coupling assuming cost sub-matrix.
        """
        
        # Index into sub-cost
        submat = torch.index_select(self.C, 0, idxX)
        C_XY = torch.index_select(submat, 1, idxY)
        
        # Solve a low-rank OT sub-problem with black-box solver
        Q, R, diagG, errs = self.solver(C_XY,
                                   gamma = self.solver_params['gamma'],
                                   r = rank_level,
                                   max_iter = self.solver_params['max_iter'],
                                   device=self.device,
                                   min_iter = self.solver_params['min_iter'],
                                   max_inneriters_balanced = self.solver_params['max_inneriters_balanced'],
                                   max_inneriters_relaxed = self.solver_params['max_inneriters_relaxed'],
                                   diagonalize_return=True,
                                   printCost=False, tau_in = self.solver_params['tau_in'],
                                       dtype = C_XY.dtype)
        return Q, R
    
    def _compute_coupling_from_Ft(self):
        """
        Returns coupling as a full-rank matrix rather than as a set of (x, T(x)) pairs.
        """
        size = (self.N, self.N)
        P = torch.zeros(size)
        # Fill sparse coupling with entries
        for pair in self.Monge_clusters:
            idx1, idx2 = pair
            P[idx1, idx2] = 1
        # Return, trivially normalized to satisfy standard OT constraints
        return P / self.N
    
    def compute_OT_cost(self):
        """
        Compute the optimal transport in linear space and time (w/o coupling).
        """
        
        cost = 0
        for clus in self.Monge_clusters:
            idx1, idx2 = clus
            if self.C is not None:
                # If C saved, index into general cost directly
                cost += self.C[idx1, idx2]
            else:
                # In case point-cloud init used, must directly compute distances between point pairs in X, Y.
                if self.sq_Euclidean:
                    # squared Euclidean case
                    cost += torch.norm(self.X[idx1,:] - self.Y[idx2,:])**2
                else:
                    # normal Euclidean cost
                    cost += torch.norm(self.X[idx1,:] - self.Y[idx2,:])
        # Appropriately normalize the cost
        cost = cost / self.N
        return cost

    
    def get_dist_mats(self, _x0, _x1, rankD, eps , sq_Euclidean ):
        
        # Wasserstein-only, setting A and B factors to be NoneType
        A_factors = None
        B_factors = None
        
        if sq_Euclidean:
            # Sq Euclidean
            C_factors = compute_lr_sqeuclidean_matrix(_x0, _x1, True)
        else:
            # Standard Euclidean dist
            C_factors = self.ret_normalized_cost(_x0, _x1, rankD, eps)
        
        return C_factors, A_factors, B_factors
    
    def ret_normalized_cost(self, X, Y, rankD, eps):
        
        C1, C2 = utils__low_rank_distance_factorization(X,
                                                      Y,
                                                      r=rankD,
                                                      eps=eps,
                                                      device=self.device)
        # Normalize appropriately
        c = ( C1.max()**1/2 ) * ( C2.max()**1/2 )
        C1, C2 = C1/c, C2/c
        C_factors = (C1.to(X.dtype), C2.to(X.dtype))
        
        return C_factors

In [None]:
rank_schedule = rank_annealing__optimal_rank_schedule( X.shape[0] , hierarchy_depth = 6, max_Q = int(2**11), max_rank = 64 )

try:
    hrot_lr = HierarchicalRefinementOT.init_from_point_clouds(X, Y, rank_schedule, base_rank=1, device=device)
    del X, Y
    F = hrot_lr.run(return_as_coupling=False)
    cost_hrot_lr = hrot_lr.compute_OT_cost()
    print(f'HR-OT cost: {cost_hrot_lr}')
except Exception as e:
    print(f'HROT-LR failed for sample size {n}: {e}')

Optimized rank-annealing schedule: [2, 1250]


NameError: name 'n' is not defined

<h2> Conclusion

In this tutorial, we explored the Hierarchical Refinement (HiRef) algorithm for large-scale optimal transport. This innovative algorithm allows computing bijective correspondences between large datasets with linear space complexity, thus overcoming one of the main limitations of traditional optimal transport methods.
We implemented the key components of the algorithm, including:

The low-rank optimal transport solver (FRLC)
The calculation of the optimal rank annealing schedule
The main hierarchical refinement algorithm