In [46]:
import torch
import torch.fft
import geomloss
import time

In [5]:
import torch
import torch.fft

def _sinkhorn_fft_stable_single_eps(
    a: torch.Tensor,
    b: torch.Tensor,
    cost_matrix: torch.Tensor,
    epsilon: float,
    max_iters: int,
    tolerance: float,
    f_init: torch.Tensor = None,
    g_init: torch.Tensor = None,
):
    """
    Internal helper that runs Sinkhorn until convergence for a single epsilon.
    """
    device = a.device
    dtype = a.dtype

    kernel = torch.exp(-cost_matrix / epsilon).to(device, dtype=dtype)
    kernel_hat = torch.fft.fft2(kernel)

    def K_op(potential):
        v = torch.exp(potential / epsilon)
        v_hat = torch.fft.fft2(v)
        Kv_hat = kernel_hat * v_hat
        K_v = torch.fft.ifft2(Kv_hat).real
        K_v_clamped = torch.clamp(K_v, min=1e-10)
        return -epsilon * torch.log(K_v_clamped)

    f = torch.zeros_like(a) if f_init is None else f_init.to(device, dtype=dtype)
    g = torch.zeros_like(b) if g_init is None else g_init.to(device, dtype=dtype)

    for i in range(max_iters):
        g_prev = g
        
        # Update f and g
        f = K_op(g) + epsilon * torch.log(torch.clamp(a, min=1e-10))
        g = K_op(f) + epsilon * torch.log(torch.clamp(b, min=1e-10))
        g = g - g.max()

        # --- CONVERGENCE CHECK ---
        # Check the change in the potential 'g'. When it stabilizes, we've converged.
        if i % 20 == 0: # Check every 20 iterations to reduce overhead
            err = torch.mean(torch.abs(g - g_prev))
            if err < tolerance:
                # print(f"  > Converged after {i+1} iterations.") # Optional: for debugging
                break

    cost = torch.sum(f * a) + torch.sum(g * b)
    return cost, f, g

def sinkhorn_with_scheduling(
    a: torch.Tensor,
    b: torch.Tensor,
    cost_matrix: torch.Tensor,
    target_epsilon: float,
    max_final_iters: int,
    tolerance: float = 1e-6, # Convergence tolerance
    schedule_stages: int = 5,
    verbose: bool = True,
):
    """
    Performs Sinkhorn's algorithm with an epsilon schedule and a convergence criterion.

    Epsilon is stepped log-uniformly from 0.1 to ``target_epsilon`` over
    ``schedule_stages`` stages; each stage is warm-started from the potentials
    returned by the previous one.

    Args:
        a: Source marginal, forwarded to the single-epsilon solver.
        b: Target marginal, forwarded to the single-epsilon solver.
        cost_matrix: Cost values, forwarded to the single-epsilon solver.
        target_epsilon: Regularisation strength of the final stage; must be > 0.
        max_final_iters: Iteration cap for the final stage. Earlier stages get
            caps spaced linearly from ``max_final_iters // 2**(stages - 1)``.
        tolerance: Convergence tolerance passed to every stage.
        schedule_stages: Number of epsilon stages; must be >= 1.
        verbose: Print per-stage progress when True.

    Returns:
        Tuple ``(cost, f, g)`` from the last stage that ran.

    Raises:
        ValueError: If ``schedule_stages < 1`` or ``target_epsilon`` is not > 0.
    """
    if schedule_stages < 1:
        # Without at least one stage nothing is solved and `cost` would be unbound.
        raise ValueError(f"schedule_stages must be >= 1, got {schedule_stages}")
    if not target_epsilon > 0:
        # log10 of a non-positive value yields -inf/NaN and poisons the schedule.
        raise ValueError(f"target_epsilon must be > 0, got {target_epsilon}")

    # Pass a plain float as `end`: 0-dim tensors are only accepted by newer PyTorch.
    log10_target = torch.log10(torch.tensor(float(target_epsilon))).item()
    epsilon_schedule = torch.logspace(-1, log10_target, steps=schedule_stages).tolist()
    # logspace(steps=1) returns only the start value (0.1), and the float32
    # round trip is inexact anyway, so pin the final stage to the exact target.
    epsilon_schedule[-1] = float(target_epsilon)

    # Use max_final_iters for the last, hardest stage
    iters_schedule = torch.linspace(
        max_final_iters // (2 ** (schedule_stages - 1)),
        max_final_iters,
        steps=schedule_stages,
        dtype=torch.int,
    ).tolist()

    f, g = None, None

    if verbose:
        print("--- Starting Sinkhorn with Epsilon Scheduling and Convergence Check ---")

    for i, (eps, iters) in enumerate(zip(epsilon_schedule, iters_schedule)):
        if verbose:
            print(f"Stage {i+1}/{schedule_stages}: ε = {eps:.6f}, max_iters = {iters}, tol = {tolerance:.1e}")

        cost, f, g = _sinkhorn_fft_stable_single_eps(
            a, b, cost_matrix,
            epsilon=eps,
            max_iters=iters,
            tolerance=tolerance,
            f_init=f,
            g_init=g
        )
        if not torch.isfinite(cost):
            # Reported even when verbose is False: the returned values are unusable.
            print(f"Error: Cost became NaN at ε = {eps}. Stopping.")
            break
        if verbose:
            print(f"  > Intermediate Cost: {cost.item():.6f}")

    return cost, f, g

In [7]:
import torch
import geomloss

In [8]:
# Check if a GPU is available and use it, otherwise use the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.set_device(0)
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Seed the random number generator so the two sample images -- and every
# number derived from them -- are identical after Restart Kernel -> Run All.
torch.manual_seed(0)


# 1. Create two sample 1024x1024 images as torch tensors
# In a real-world scenario, you would load your own images here.
# For this example, we'll create two random images.
print("Creating two 1024x1024 images...")
image_a = torch.rand(1, 1, 1024, 1024, device=device)
image_b = torch.rand(1, 1, 1024, 1024, device=device)

# 2. Normalize the images to be probability distributions (sum to 1)
image_a = image_a / image_a.sum()
image_b = image_b / image_b.sum()

# 3. Define the SamplesLoss for Wasserstein-2 distance
#    - p=2 specifies the Wasserstein-2 distance.
#    - blur sets the regularization strength. A small value is closer to the true Wasserstein distance.
#    - potentials=True makes the loss function return the dual potentials.
#    - backend="auto" will choose the most efficient backend, which will be KeOps-based for large inputs.
# NOTE(review): blur is measured in the units of the point coordinates. The
# point clouds built later in this notebook use raw pixel indices, so 0.05
# means a twentieth of a pixel -- confirm this scale is intended.
print("Defining the Wasserstein-2 loss function...")
loss_fn = geomloss.SamplesLoss(
    loss="sinkhorn", p=2, blur=0.05, potentials=True, backend="auto"
)

# 4. Convert the images to weighted point clouds
#    This is the standard way to use SamplesLoss with gridded data like images.
def image_to_points(image):
    """Converts a 2D image to a set of weighted 2D points.

    Args:
        image: Tensor of shape (B, C, H, W) with B * C == 1.

    Returns:
        Tuple ``(weights, points)``: ``weights`` is the flattened image with
        H * W entries; ``points`` is a float32 tensor of shape (H * W, 2)
        holding the ``(x, y)`` pixel index of each weight in row-major order.

    Raises:
        ValueError: If the image holds more than one batch/channel slice, in
            which case weights and points could not be paired one-to-one.
    """
    batch, channels, H, W = image.shape
    if batch * channels != 1:
        raise ValueError(
            f"image_to_points expects a single 1-channel image, got shape {tuple(image.shape)}"
        )
    # Build the grid on the image's own device rather than a notebook global,
    # so the function works for any tensor it is handed.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=image.device),
        torch.arange(W, dtype=torch.float32, device=image.device),
        indexing="ij",
    )
    # Reshape the grid and the image to get points and weights
    points = torch.stack([grid_x.ravel(), grid_y.ravel()], dim=1)
    weights = image.ravel()
    return weights, points

print("Converting images to weighted point clouds...")
weights_a, points_a = image_to_points(image_a)
weights_b, points_b = image_to_points(image_b)


# 5. Compute the Wasserstein-2 distance and dual potentials
print("Computing the Wasserstein distance and dual potentials...")
# The loss function returns the dual potentials F and G, from which the distance can be computed.
F, G = loss_fn(weights_a, points_a, weights_b, points_b)

# The Sinkhorn value is the pairing of each potential with its own measure,
# <a, F> + <b, G>. An unweighted mean only matches this for uniform weights,
# which these images do not have.
# NOTE(review): per the GeomLoss documentation p=2 uses the cost |x - y|^2 / 2,
# so this number approximates W2^2 / 2 rather than W2 itself -- confirm which
# quantity downstream code expects.
wasserstein_distance = (
    (weights_a * F.reshape(-1)).sum() + (weights_b * G.reshape(-1)).sum()
).item()


# 6. Display the results
print("\n--- Results ---")
print(f"Wasserstein-2 Distance: {wasserstein_distance}")
print(f"Shape of the first dual potential (F): {F.shape}")
print(f"Shape of the second dual potential (G): {G.shape}")

Using device: cuda
Creating two 1024x1024 images...
Defining the Wasserstein-2 loss function...
Converting images to weighted point clouds...
Computing the Wasserstein distance and dual potentials...


NameError: name 'generic_logsumexp' is not defined

In [4]:
# Diagnostic cell: print the directories the geomloss package was loaded from,
# to confirm which installation the kernel is using.
# NOTE(review): the recorded outputs in this notebook show a NameError for
# 'generic_logsumexp' and a ModuleNotFoundError for 'fcntl'; presumably the
# KeOps backend cannot be imported on this Windows install -- confirm.
import geomloss
print(geomloss.__path__)

['c:\\Users\\Spud\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\geomloss']


ModuleNotFoundError: No module named 'fcntl'