In [2]:
pip install ott-jax

Collecting ott-jax
  Downloading ott_jax-0.5.0-py3-none-any.whl.metadata (20 kB)
Collecting jaxopt>=0.8 (from ott-jax)
  Downloading jaxopt-0.8.5-py3-none-any.whl.metadata (3.3 kB)
Collecting lineax>=0.0.7 (from ott-jax)
  Downloading lineax-0.0.8-py3-none-any.whl.metadata (18 kB)
Collecting equinox>=0.11.10 (from lineax>=0.0.7->ott-jax)
  Downloading equinox-0.12.1-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.2.24 (from lineax>=0.0.7->ott-jax)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting wadler-lindig>=0.1.0 (from equinox>=0.11.10->lineax>=0.0.7->ott-jax)
  Downloading wadler_lindig-0.1.5-py3-none-any.whl.metadata (17 kB)
Downloading ott_jax-0.5.0-py3-none-any.whl (283 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m283.7/283.7 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxopt-0.8.5-py3-none-any.whl (172 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.4/172.4 kB[0m [31m9.4 MB/s[

In [None]:
import jax
import jax.numpy as jnp
import functools
import operator
from typing import List, Tuple, Dict, Optional, Callable
import ott
from ott.geometry import pointcloud, costs
from ott.solvers.linear import sinkhorn_lr, sinkhorn
from ott.problems.linear import linear_problem
import numpy as np
import time
from scipy.sparse import coo_matrix

In [None]:

def rank_annealing__factors(n):
    """Return the set of all positive divisors of the integer ``n``.

    Divisors are found in pairs ``(i, n // i)`` by trial division up to
    ``int(n**0.5)``; the result is a ``set``, so it is unordered.
    """
    divisors = set()
    upper = int(n**0.5) + 1
    for small in range(1, upper):
        if n % small == 0:
            # Every divisor below the square root has a partner above it.
            divisors.add(small)
            divisors.add(n // small)
    return divisors

def rank_annealing__max_factor_lX(n, max_X):
    """Return the largest factor of ``n`` that is ≤ ``max_X``.

    Returns 0 when no factor of ``n`` satisfies the bound.
    """
    eligible = [f for f in rank_annealing__factors(n) if f <= max_X]
    return max(eligible, default=0)

def rank_annealing__min_sum_partial_products_with_factors(n, k, C):
    """
    Dynamic program to compute the rank-schedule, subject to a constraint of intermediates being ≤ C.

    Finds exactly ``k`` factors ``r_1, ..., r_k`` (factors equal to 1 are allowed),
    each ≤ ``C``, whose product is ``n`` and which minimise the sum of partial
    products ``r_1 + r_1*r_2 + ... + r_1*...*r_k``.

    Only divisors of ``n`` can ever appear as sub-problems, so the table is
    indexed by divisors of ``n`` instead of by every integer in ``1..n``.

    Parameters
    ----------
    n: int
        The dataset size to be factored into a rank-scheduler. Assumed to be non-prime.
    k: int
        The depth of the hierarchy.
    C: int
        A constraint on the maximal intermediate rank across the hierarchy.

    Returns
    -------
    (cost, factors): tuple
        ``cost`` is the minimal sum of partial products and ``factors`` the list
        of ``k`` factors in the order they are chosen (outermost first).
        ``(None, [])`` is returned when no valid factorisation exists, including
        when ``n < 1`` or ``k < 1``.
    """
    INF = float('inf')

    # Nothing can be factored into zero (or negatively many) levels.
    if n < 1 or k < 1:
        return None, []

    # Enumerate the divisors of n by integer trial division (no float sqrt).
    divisor_set = set()
    i = 1
    while i * i <= n:
        if n % i == 0:
            divisor_set.add(i)
            divisor_set.add(n // i)
        i += 1
    divisors = sorted(divisor_set)

    # Candidate factors: divisors of n within the rank constraint, ascending so
    # that ties are broken in favour of the smallest factor.
    allowed = [r for r in divisors if r <= C]

    # best[(d, t)]: minimal cost of writing d as a product of exactly t factors.
    # choice[(d, t)]: the first factor used by that optimum.
    best = {(d, 1): d for d in allowed}
    choice = {(d, 1): d for d in allowed}

    for t in range(2, k + 1):
        for d in divisors:
            for r in allowed:
                if r > d:
                    break
                if d % r != 0:
                    continue
                candidate = r + r * best.get((d // r, t - 1), INF)
                if candidate < best.get((d, t), INF):
                    best[(d, t)] = candidate
                    choice[(d, t)] = r

    if (n, k) not in best:
        return None, []

    # Walk the recorded choices back from (n, k) to recover the factors.
    factors = []
    d_cur = n
    for t_cur in range(k, 0, -1):
        r_cur = choice[(d_cur, t_cur)]
        factors.append(r_cur)
        d_cur //= r_cur

    return best[(n, k)], factors

def rank_annealing__optimal_rank_schedule(n, hierarchy_depth=6, max_Q=int(2**10), max_rank=16):
    """
    A function to compute the optimal rank-scheduler of refinement.

    The largest factor ``Q ≤ max_Q`` of ``n`` is split off as the terminal rank,
    and ``n // Q`` is factored into at most ``hierarchy_depth`` intermediate
    ranks (each ≤ ``max_rank``) by the dynamic program. Factors equal to 1 are
    dropped from the returned schedule.

    Parameters
    ----------
    n: int
        Size of the input dataset -- cannot be a prime number
    hierarchy_depth: int
        Maximal permissible depth of the multi-scale hierarchy
    max_Q: int
        Maximal rank at terminal base case (before reducing the ≤ max_Q rank coupling to a 1-1 alignment)
    max_rank: int
        Maximal rank at the intermediate steps of the rank-schedule

    Returns
    -------
    list of int
        Intermediate ranks in ascending order followed by the terminal rank ``Q``;
        their product equals ``n``.

    Raises
    ------
    AssertionError
        If the computed schedule does not multiply back to ``n`` (e.g. no
        factorisation satisfying the constraints exists).
    """
    # Factoring out the max factor
    Q = rank_annealing__max_factor_lX(n, max_Q)
    # Q divides n, so integer division is exact and avoids float rounding for large n.
    ndivQ = n // Q

    # Compute partial rank schedule up to Q
    _, rank_schedule = rank_annealing__min_sum_partial_products_with_factors(ndivQ, hierarchy_depth, max_rank)
    rank_schedule.sort()
    rank_schedule.append(Q)
    rank_schedule = [x for x in rank_schedule if x != 1]

    print(f'Optimized rank-annealing schedule: {rank_schedule}')

    # The initial value 1 keeps the check well-defined when the schedule is empty.
    assert functools.reduce(operator.mul, rank_schedule, 1) == n, "Error! Rank-schedule does not factorize n!"

    return rank_schedule

def hierarchical_refinement(
    X: jnp.ndarray,
    Y: jnp.ndarray,
    hierarchy_depth: int = 6,
    max_base_rank: int = 1024,   # corresponds to max_Q
    max_rank: int = 16,
    epsilon: float = 1e-2,
    cost_fn = None,
    rank_schedule: Optional[List[int]] = None,
    threshold: float = 1e-3,
    inner_iterations: int = 20,
    verbose: bool = True
) -> Dict:
    """
    Hierarchical Refinement implemented on top of OTT-JAX.

    The point sets are recursively co-clustered with low-rank Sinkhorn: at each
    level every co-cluster larger than ``max_base_rank`` is split by assigning
    each point to the argmax column of its low-rank factor (``q`` for X, ``r``
    for Y). The remaining small co-clusters are then matched with standard
    Sinkhorn.

    Co-clusters are tracked as arrays of indices into ``X`` / ``Y``, so the
    final transport matrix is built from those indices directly.

    Known limitation: the argmax split does not enforce balanced co-clusters.
    A label that is empty on one side is dropped, in which case fewer than
    ``n`` pairs are produced and the final assertion fails.

    Args:
        X: Source points (n x d)
        Y: Target points (n x d)
        hierarchy_depth: Maximal depth of the hierarchy
        max_base_rank: Maximal size of the terminal sub-problems (max_Q)
        max_rank: Maximal rank at each intermediate level (C)
        epsilon: Entropic regularisation passed to the point-cloud geometry
        cost_fn: Cost function (default: SqEuclidean)
        rank_schedule: Manually supplied list of ranks (hierarchy_depth is ignored if given)
        threshold: Convergence threshold for the OT solvers
        inner_iterations: ``inner_iterations`` setting forwarded to each solver
        verbose: Print detailed progress information

    Returns:
        Dictionary with keys ``"mapping"`` (list of (x, y) point pairs),
        ``"rank_schedule"``, ``"transport_matrix"`` (scipy COO matrix, n x n),
        ``"level_costs"`` (one entry per refinement level plus the terminal
        level) and ``"total_cost"`` (their sum).

    Raises:
        AssertionError: if X and Y have different sizes, or if the number of
            mapped pairs differs from n.
    """
    # Function-scope import: only the terminal matching step needs it.
    from scipy.optimize import linear_sum_assignment

    n = X.shape[0]
    assert X.shape[0] == Y.shape[0], "X et Y doivent avoir le même nombre de points"

    # Compute the rank schedule when none is supplied
    if rank_schedule is None:
        rank_schedule = rank_annealing__optimal_rank_schedule(
            n=n,
            hierarchy_depth=hierarchy_depth,
            max_Q=max_base_rank,
            max_rank=max_rank
        )

    if verbose:
        print(f"Utilisation du rank schedule: {rank_schedule}")
        # Initial value 1 so an empty schedule does not raise.
        print(f"Produit des rangs: {functools.reduce(operator.mul, rank_schedule, 1)}")

    if cost_fn is None:
        cost_fn = costs.SqEuclidean()

    # Initialise the co-partition with the whole data set, as index arrays
    t = 0
    all_indices = np.arange(n)
    Gamma_t = [(all_indices, all_indices)]

    # Track the OT cost accumulated at each level
    level_costs = []

    while any(min(ix.size, iy.size) > max_base_rank for ix, iy in Gamma_t):
        Gamma_t_plus_1 = []
        level_cost = 0.0
        # Set when at least one co-cluster actually changed at this level;
        # without it the loop could spin forever on an unsplittable cluster.
        refined = False

        if verbose:
            print(f"Niveau {t}: {len(Gamma_t)} co-clusters, rang courant = {rank_schedule[t] if t < len(rank_schedule) else 'terminal'}")

        for i, (idx_x, idx_y) in enumerate(Gamma_t):
            if min(idx_x.size, idx_y.size) <= max_base_rank:
                # Already small enough for the terminal solver: keep as is
                Gamma_t_plus_1.append((idx_x, idx_y))
                continue

            X_q, Y_q = X[idx_x], Y[idx_y]

            # Uniform marginals
            a = jnp.ones(X_q.shape[0]) / X_q.shape[0]
            b = jnp.ones(Y_q.shape[0]) / Y_q.shape[0]

            geom = pointcloud.PointCloud(
                X_q, Y_q,
                epsilon=epsilon,
                cost_fn=cost_fn
            )

            # Rank for this level; reuse the last rank once the schedule is exhausted
            rank_t = rank_schedule[t] if t < len(rank_schedule) else rank_schedule[-1]

            try:
                lr_sink = sinkhorn_lr.LRSinkhorn(
                    rank=rank_t,
                    threshold=threshold,
                    inner_iterations=inner_iterations
                )
                ot_prob = linear_problem.LinearProblem(geom, a, b)
                output = lr_sink(ot_prob)

                # Hard cluster labels, computed once per co-cluster
                labels_x = np.asarray(jnp.argmax(output.q, axis=1))
                labels_y = np.asarray(jnp.argmax(output.r, axis=1))
                cluster_cost = float(output.reg_ot_cost)
            except Exception as e:
                if verbose:
                    print(f"Erreur pour le co-cluster {i} au niveau {t}: {e}")
                # Best effort: keep the original co-cluster untouched
                Gamma_t_plus_1.append((idx_x, idx_y))
                continue

            level_cost += cluster_cost

            # Partition the indices by label, keeping only labels present on both sides
            children = []
            for z in range(rank_t):
                child_x = idx_x[labels_x == z]
                child_y = idx_y[labels_y == z]
                if child_x.size > 0 and child_y.size > 0:
                    children.append((child_x, child_y))

            # A single child identical in size to its parent means nothing was split
            if (
                len(children) != 1
                or children[0][0].size < idx_x.size
                or children[0][1].size < idx_y.size
            ):
                refined = True

            Gamma_t_plus_1.extend(children)

        # Move on to the next level
        Gamma_t = Gamma_t_plus_1
        level_costs.append(level_cost)
        t += 1

        if verbose:
            print(f"  → {len(Gamma_t)} co-clusters générés, coût OT: {level_cost:.6f}")

        if not refined:
            # No co-cluster changed: hand what remains to the terminal solver
            if verbose:
                print("Aucun co-cluster n'a pu être raffiné; passage à la résolution terminale.")
            break

    # Build the final matching, as paired index arrays
    matched_x = []
    matched_y = []
    final_cost = 0.0

    if verbose:
        print(f"Résolution des {len(Gamma_t)} co-clusters terminaux...")

    for i, (idx_x, idx_y) in enumerate(Gamma_t):
        if idx_x.size == 1 and idx_y.size == 1:
            # Trivial 1-to-1 correspondence
            matched_x.append(idx_x)
            matched_y.append(idx_y)
            continue

        X_q, Y_q = X[idx_x], Y[idx_y]

        # Standard (full-rank) OT problem for the base case
        a = jnp.ones(X_q.shape[0]) / X_q.shape[0]
        b = jnp.ones(Y_q.shape[0]) / Y_q.shape[0]

        geom = pointcloud.PointCloud(X_q, Y_q, epsilon=epsilon, cost_fn=cost_fn)
        ot_prob = linear_problem.LinearProblem(geom, a, b)
        sink = sinkhorn.Sinkhorn(threshold=threshold, inner_iterations=inner_iterations)

        try:
            output = sink(ot_prob)
            P = np.asarray(output.matrix)

            if P.shape[0] == P.shape[1]:
                # Square block: pick the one-to-one matching carrying the most
                # coupling mass, so no target point is used twice.
                _, cols = linear_sum_assignment(P, maximize=True)
            else:
                # Unbalanced block: a bijection is impossible, fall back to
                # the most likely target for each source point.
                cols = P.argmax(axis=1)

            cluster_cost = float(output.reg_ot_cost)
        except Exception as e:
            if verbose:
                print(f"Erreur pour le co-cluster terminal {i}: {e}")
            # Arbitrary (non-optimal) correspondence as a last resort
            m = min(idx_x.size, idx_y.size)
            matched_x.append(idx_x[:m])
            matched_y.append(idx_y[:m])
            continue

        final_cost += cluster_cost
        matched_x.append(idx_x)
        matched_y.append(idx_y[cols])

    if matched_x:
        indices_x = np.concatenate(matched_x)
        indices_y = np.concatenate(matched_y)
    else:
        indices_x = np.empty(0, dtype=int)
        indices_y = np.empty(0, dtype=int)

    # List of (x, y) point pairs, in terminal co-cluster order
    mapping = list(zip(X[indices_x], Y[indices_y]))

    # The number of pairs must match n
    assert len(mapping) == n, f"Erreur: {len(mapping)} paires mappées pour {n} points"

    level_costs.append(final_cost)

    # Sparse transport matrix (COO format) built from the tracked indices
    transport_matrix = coo_matrix(
        (np.ones(n), (indices_x, indices_y)),
        shape=(n, n)
    )

    if verbose:
        print(f"Hierarchical Refinement terminé: {n} paires mappées")
        print(f"Coût final: {sum(level_costs):.6f}")

    return {
        "mapping": mapping,               # list of (x, y) pairs
        "rank_schedule": rank_schedule,   # rank schedule that was used
        "transport_matrix": transport_matrix,  # sparse transport matrix
        "level_costs": level_costs,       # OT cost at each level
        "total_cost": sum(level_costs)    # total OT cost
    }

In [9]:
def generate_synthetic_data(n=512, seed=42):
    """
    Generate two point sets to be matched.

    The source set is a noisy circle; the target set is the source multiplied
    by a 45-degree rotation matrix, shifted by (0.5, -0.3) and perturbed with
    Gaussian noise. Seeds NumPy's global random state with ``seed``.

    Returns a pair of ``jnp`` arrays ``(X, Y)``, each of shape (n, 2).
    """
    # Seed the global NumPy RNG for reproducibility
    np.random.seed(seed)

    # Noisy circle: evenly spaced angles, radius jittered around 1
    angles = np.linspace(0, 2*np.pi, n)
    radii = 1.0 + 0.1 * np.random.randn(n)
    source = np.column_stack([radii * np.cos(angles), radii * np.sin(angles)])

    # Rigid transform applied to the source
    phi = np.pi/4  # 45 degrees
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    rotation = np.array([
        [cos_phi, -sin_phi],
        [sin_phi, cos_phi]
    ])
    shift = np.array([0.5, -0.3])
    target = source @ rotation + shift

    # Small perturbation on the target
    target += 0.05 * np.random.randn(*target.shape)

    return jnp.array(source), jnp.array(target)

# Evaluate the quality of the alignment
def evaluate_alignment(X, Y, result):
    """
    Evaluate the alignment quality and print the results.

    Computes the Euclidean distance between the two points of every pair in
    ``result["mapping"]`` and reports the mean and the maximum. ``X`` and ``Y``
    are accepted but not used.

    Returns a dict with keys ``"mean_distance"`` and ``"max_distance"``.
    """
    # Euclidean distance of each mapped pair
    pair_distances = jnp.array(
        [jnp.sqrt(jnp.sum((src - dst)**2)) for (src, dst) in result["mapping"]]
    )

    mean_dist = jnp.mean(pair_distances)
    max_dist = jnp.max(pair_distances)

    print(f"Distance moyenne entre les paires mappées: {mean_dist:.6f}")
    print(f"Distance maximale entre les paires mappées: {max_dist:.6f}")

    return {"mean_distance": mean_dist, "max_distance": max_dist}

# Visualize the alignment
def visualize_alignment(X, Y, result, max_lines=100):
    """
    Plot the alignment between X and Y.

    Draws both point sets as scatter plots and connects a random subset of at
    most ``max_lines`` mapped pairs from ``result["mapping"]`` with faint
    lines. The title shows ``result["total_cost"]``.
    """
    # `plt` is not provided by the notebook's import cell; import it here so
    # the function does not fail with a NameError.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 10))

    # Draw the points
    plt.scatter(X[:, 0], X[:, 1], c='blue', s=10, alpha=0.6, label='Source')
    plt.scatter(Y[:, 0], Y[:, 1], c='red', s=10, alpha=0.6, label='Target')

    # Draw only a subset of the alignment lines to avoid clutter
    n_pairs = len(result["mapping"])
    indices = np.random.choice(n_pairs, min(max_lines, n_pairs), replace=False)

    for idx in indices:
        x, y = result["mapping"][idx]
        plt.plot([x[0], y[0]], [x[1], y[1]], 'k-', alpha=0.1)

    plt.title(f'Hierarchical Refinement Alignment\nTotal Cost: {result["total_cost"]:.4f}')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.axis('equal')
    plt.show()

# Run the example
def run_example():
    """
    Run Hierarchical Refinement end to end on synthetic data.

    Generates the data, runs the algorithm with fixed parameters, prints the
    elapsed wall-clock time, evaluates and plots the alignment.

    Returns the pair ``(result, eval_results)``.
    """
    # Problem size (must be factorisable); try other powers of 2: 128, 256, 512, 1024
    n = 512

    print(f"Générer {n} points synthétiques...")
    X, Y = generate_synthetic_data(n=n)

    print(f"Dimensions de X: {X.shape}")
    print(f"Dimensions de Y: {Y.shape}")

    # Hierarchical Refinement settings
    params = dict(
        hierarchy_depth=3,
        max_base_rank=8,       # maximal size of the terminal problems
        max_rank=8,            # maximal rank at intermediate levels
        epsilon=0.01,          # entropic regularisation
        threshold=1e-3,        # convergence threshold
        inner_iterations=30,   # inner iterations
    )

    print("\nDébut de Hierarchical Refinement...")
    started_at = time.time()

    # Run the algorithm
    result = hierarchical_refinement(X, Y, **params)

    elapsed = time.time() - started_at
    print(f"Temps d'exécution: {elapsed:.3f} secondes")

    # Evaluate the result
    print("\nÉvaluation de l'alignement:")
    eval_results = evaluate_alignment(X, Y, result)

    # Plot the result
    print("\nVisualisation de l'alignement...")
    visualize_alignment(X, Y, result)

    return result, eval_results

# Run the example and keep its outputs in `result` / `eval_results`
if __name__ == "__main__":
    result, eval_results = run_example()

Générer 512 points synthétiques...
Dimensions de X: (512, 2)
Dimensions de Y: (512, 2)

Début de Hierarchical Refinement...
Optimized rank-annealing schedule: [8, 8, 8]
Utilisation du rank schedule: [8, 8, 8]
Produit des rangs: 512
Niveau 0: 1 co-clusters, rang courant = 8


TypeError: LRSinkhorn.__init__() missing 1 required positional argument: 'rank'