# In [8]:
import numpy as np
import time
def logsumexp(M):
    """
    Compute log(sum(exp(M))) along the second axis (axis=1) of the matrix M.

    The row maximum is subtracted before exponentiating so that large
    entries do not overflow.

    :parameters M : Input matrix (2-D array-like).
    :returns: 1-D array with one entry per row of M.
    """
    M = np.asarray(M)
    rmax = np.max(M, axis=1, keepdims=True)  # Column vector of row maxima

    # A row whose maximum is +/-inf would produce inf - inf = nan below.
    # Shifting such rows by 0 instead lets exp/log yield the correct +/-inf.
    shift = np.where(np.isfinite(rmax), rmax, 0.0)

    # log(0) = -inf is the intended result for a row that is entirely -inf,
    # so the divide-by-zero warning is silenced here.
    with np.errstate(divide="ignore"):
        rs = shift + np.log(np.sum(np.exp(M - shift), axis=1, keepdims=True))

    # Drop only the reduced axis; a bare squeeze() would also collapse a
    # one-row input down to a 0-d array.
    return rs[:, 0]



def round_matrix(X_hat_k, r, c):
    """
    Rounds the transport plan matrix to satisfy marginal constraints.

    Rows are first scaled down so that no row sum exceeds its entry of `r`,
    then columns are scaled down so that no column sum exceeds its entry of
    `c`. The mass still missing afterwards is added back as a rank-one
    correction built from the remaining row and column errors.

    :param X_hat_k: Initial transport matrix (n x m).
    :param r: Supply vector (length n), the target row sums.
    :param c: Demand vector (length m), the target column sums.
    :return: Adjusted transport matrix `X_hat` that satisfies the marginal constraints.
    """
    X = np.asarray(X_hat_k, dtype=float)
    r = np.asarray(r, dtype=float).ravel()
    c = np.asarray(c, dtype=float).ravel()

    # Scale rows to match supply constraints. Rows with zero mass keep a
    # factor of 1 (there is nothing to scale) instead of producing 0/0 = nan.
    row_sums = X.sum(axis=1)
    x = np.ones_like(row_sums)
    np.divide(r, row_sums, out=x, where=row_sums != 0)
    x = np.minimum(x, 1)  # Only ever shrink a row, never inflate it
    F_1 = X * x[:, None]

    # Scale columns to match demand constraints (same zero-mass guard).
    col_sums = F_1.sum(axis=0)
    y = np.ones_like(col_sums)
    np.divide(c, col_sums, out=y, where=col_sums != 0)
    y = np.minimum(y, 1)
    F_2 = F_1 * y[None, :]

    # Compute row and column errors after scaling
    err_r = r - F_2.sum(axis=1)
    err_c = c - F_2.sum(axis=0)

    # Adjust the transport matrix to correct errors. When no row mass is
    # missing the correction would be 0/0, so the scaled plan is returned
    # unchanged instead of a matrix of NaNs.
    missing_mass = np.linalg.norm(err_r, 1)
    if missing_mass == 0:
        return F_2

    X_hat = F_2 + np.outer(err_r, err_c) / missing_mass

    return X_hat


def stable_sinkhorn(a, b, M, reg, num_iters=100000, *, tol=1e-6,
                    max_time=2000, check_every=100, verbose=True):
    """
    Run log-domain Sinkhorn iterations for entropy-regularized transport.

    The plan has the form P[i, j] = exp(A[i, j] + u[i] + v[j]), where A is
    the negated, rescaled cost matrix. Each iteration updates the column
    potentials `v` so the column sums of P equal `b`, then the row
    potentials `u` so the row sums equal `a`. Every `check_every`
    iterations the current plan is rounded with `round_matrix(P, a, b)`,
    its cost against `M` is recorded, and the stopping rules are checked.

    :param a: Supply distribution (target row sums of the plan).
    :param b: Demand distribution (target column sums of the plan).
    :param M: Cost matrix
    :param reg: Regularization parameter for entropy regularization.
    :param num_iters: Maximum number of Sinkhorn iterations
    :param tol: Stop when two consecutive recorded costs differ by less than this.
    :param max_time: Stop once the elapsed wall-clock time in seconds exceeds this.
    :param check_every: Record cost/time and test the stopping rules every this many iterations.
    :param verbose: If True, print each recorded cost and the convergence message.
    :return: Transport plan matrix P (built from the final potentials, not
        rounded), computation time list, total cost list.
    :raises ValueError: If `check_every` is smaller than 1.
    """
    if check_every < 1:
        raise ValueError("check_every must be a positive integer")

    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    M = np.asarray(M, dtype=float)
    n = len(a)

    # ** Normalize M (similar to Cinf processing) **
    M_max = np.max(M)
    gamma = reg / (4 * np.log(n))
    if M_max != 0:
        A = -M / (M_max * gamma)  # Normalize cost matrix
    else:
        # An all-zero cost matrix would give 0/0 = nan here.
        A = np.zeros_like(M)

    u = np.zeros(n)        # Row potentials, paired with a
    v = np.zeros(len(b))   # Column potentials, paired with b
    log_a = np.log(a)
    log_b = np.log(b)

    start_time = time.time()
    time_list = []  # Elapsed seconds at each recorded iteration
    cost_list = []  # Total cost of the rounded plan at each recorded iteration

    for i in range(num_iters):
        # Column update: (A.T + u)[j, i] = A[i, j] + u[i], reduced over i.
        # These are the log column sums, so they must be matched against b.
        v = log_b - logsumexp(A.T + u[None, :])
        # Row update: (A + v)[i, j] = A[i, j] + v[j], reduced over j.
        # These are the log row sums, so they must be matched against a.
        u = log_a - logsumexp(A + v[None, :])

        # Record total cost and time every `check_every` iterations
        if i % check_every == 0:
            t = time.time() - start_time
            time_list.append(t)

            # Compute transport plan matrix and round it onto (a, b)
            plan = np.exp(A + u[:, None] + v[None, :])
            X_k = round_matrix(plan, a, b)

            # Compute total cost
            total_cost = np.sum(X_k * M)
            cost_list.append(total_cost)
            if verbose:
                print(total_cost)

            # Stopping criterion
            if len(cost_list) > 1:
                change = abs(cost_list[-1] - cost_list[-2])
                if change < tol:
                    if verbose:
                        print(f"Converged at iteration {i} with total cost change {change:.16f}")
                    break
            if t > max_time:  # Stop if runtime exceeds the time budget
                break

    # Build the returned plan from the final potentials, so it is defined
    # when num_iters is 0 and is not stale when the loop ends between two
    # recording points.
    P = np.exp(A + u[:, None] + v[None, :])

    return P, np.array(time_list), np.array(cost_list)
