In [2]:
%pip install POT

Collecting POT
  Downloading POT-0.9.5-cp311-cp311-win_amd64.whl.metadata (35 kB)
Downloading POT-0.9.5-cp311-cp311-win_amd64.whl (348 kB)
   ---------------------------------------- 0.0/348.6 kB ? eta -:--:--
   --------------------- ------------------ 184.3/348.6 kB 5.6 MB/s eta 0:00:01
   ---------------------------------------- 348.6/348.6 kB 7.4 MB/s eta 0:00:00
Installing collected packages: POT
Successfully installed POT-0.9.5
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
import numpy as np
import matplotlib.pyplot as plt
import ot  # pip install POT




In [None]:
class SinkhornLowRankSolver:
    """
    A more serious Low-Rank OT solver:
    - Selects random landmarks (clusters)
    - Projects data on the landmark space
    - Solves OT between the projected distributions
    """
    def __init__(self, rank=8, epsilon=0.05):
        self.rank = rank
        self.epsilon = epsilon

    def solve(self, X, Y):
        n, m = len(X), len(Y)
        a = np.ones((n,)) / n
        b = np.ones((m,)) / m

        # Step 1: Sample landmarks
        idx_X_landmarks = np.random.choice(n, self.rank, replace=False)
        idx_Y_landmarks = np.random.choice(m, self.rank, replace=False)
        X_landmarks = X[idx_X_landmarks]
        Y_landmarks = Y[idx_Y_landmarks]

        # Step 2: Project X and Y onto landmarks
        M_X = ot.dist(X, X_landmarks)
        M_Y = ot.dist(Y, Y_landmarks)

        # Soft assignment based on proximity
        K_X = np.exp(-M_X / self.epsilon)
        K_Y = np.exp(-M_Y / self.epsilon)

        K_X /= K_X.sum(axis=1, keepdims=True)
        K_Y /= K_Y.sum(axis=1, keepdims=True)

        # Step 3: Solve OT between landmark spaces
        a_landmark = np.ones((self.rank,)) / self.rank
        b_landmark = np.ones((self.rank,)) / self.rank
        M_landmark = ot.dist(X_landmarks, Y_landmarks)

        P_landmark = ot.sinkhorn(a_landmark, b_landmark, M_landmark, reg=self.epsilon)

        # Step 4: Infer matching between points
        transport_matrix = K_X @ P_landmark @ K_Y.T

        idx_X = np.argmax(transport_matrix, axis=1)
        idx_Y = np.argmax(transport_matrix, axis=0)

        return idx_X, idx_Y


class HierarchicalRefinement:
    def __init__(self, solver, min_size=1, verbose=True):
        self.solver = solver
        self.min_size = min_size
        self.verbose = verbose
        self.matches = []

    def refine(self, X, Y):
        n = len(X)
        assert len(Y) == n, "X and Y must have the same number of points."

        self.matches = []
        clusters = [(np.arange(n), np.arange(n))]

        level = 0
        while clusters:
            new_clusters = []
            if self.verbose:
                print(f"Niveau {level}, {len(clusters)} clusters")

            for idx_X, idx_Y in clusters:
                if len(idx_X) <= self.min_size:
                    self.matches.extend(list(zip(idx_X, idx_Y)))
                    continue

                idxs_X, idxs_Y = self.solver.solve(X[idx_X], Y[idx_Y])

                unique_labels = np.unique(idxs_X)
                for label in unique_labels:
                    sub_idx_X = idx_X[idxs_X == label]
                    sub_idx_Y = idx_Y[idxs_Y == label]
                    if len(sub_idx_X) > 0 and len(sub_idx_Y) > 0:
                        new_clusters.append((sub_idx_X, sub_idx_Y))

            clusters = new_clusters
            level += 1

        return self.matches

    def plot_matches(self, X, Y):
        plt.figure(figsize=(8, 8))
        plt.scatter(X[:, 0], X[:, 1], c='blue', label='X', marker='o')
        plt.scatter(Y[:, 0], Y[:, 1], c='red', label='Y', marker='x')

        for i, j in self.matches:
            plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], 'gray', linewidth=0.5)

        plt.legend()
        plt.title("Correspondances après HiRef")
        plt.show()

In [None]:
np.random.seed(42)
n_points = 64
X = np.random.rand(n_points, 2)
Y = np.random.rand(n_points, 2)

solver = SinkhornLowRankSolver(rank=4, epsilon=0.1)
hiref = HierarchicalRefinement(solver, min_size=1)
matches = hiref.refine(X, Y)

hiref.plot_matches(X, Y)

Niveau 0, 1 clusters
Niveau 1, 37 clusters
Niveau 2, 11 clusters
Niveau 3, 8 clusters
Niveau 4, 8 clusters
Niveau 5, 8 clusters
Niveau 6, 8 clusters
Niveau 7, 8 clusters
Niveau 8, 8 clusters
Niveau 9, 8 clusters
Niveau 10, 8 clusters
Niveau 11, 8 clusters
Niveau 12, 8 clusters
Niveau 13, 8 clusters
Niveau 14, 8 clusters
Niveau 15, 8 clusters
Niveau 16, 8 clusters
Niveau 17, 8 clusters
Niveau 18, 8 clusters
Niveau 19, 8 clusters
Niveau 20, 8 clusters
Niveau 21, 8 clusters
Niveau 22, 8 clusters
Niveau 23, 8 clusters
Niveau 24, 8 clusters
Niveau 25, 8 clusters
Niveau 26, 8 clusters
Niveau 27, 8 clusters
Niveau 28, 8 clusters
Niveau 29, 8 clusters
Niveau 30, 8 clusters
Niveau 31, 8 clusters
Niveau 32, 8 clusters
Niveau 33, 8 clusters
Niveau 34, 8 clusters
Niveau 35, 8 clusters
Niveau 36, 8 clusters
Niveau 37, 8 clusters
Niveau 38, 8 clusters
Niveau 39, 8 clusters
Niveau 40, 8 clusters
Niveau 41, 8 clusters
Niveau 42, 8 clusters
Niveau 43, 8 clusters
Niveau 44, 8 clusters
Niveau 45, 8 clust

KeyboardInterrupt: 