# Entropic maps 

In [11]:
import numpy as np
import ot
import matplotlib.pyplot as plt

from torchcfm.utils import * #Import the sampling functions

In [31]:
def compute_sinkhorn_coupling(P, Q, reg=0.1, nb_iter=100000):
    """
    Compute the entropic optimal transport plan between datasets P and Q using the Sinkhorn algorithm.

    Both datasets are treated as uniform empirical measures, and the ground
    cost between samples is the squared Euclidean distance.

    Parameters:
    - P (array-like): Source dataset of shape (n, d), one sample per row.
    - Q (array-like): Target dataset of shape (m, d), one sample per row.
    - reg (float): Entropic regularization strength; must be > 0.
    - nb_iter (int): Maximum number of Sinkhorn iterations (POT's numItermax).

    Returns:
    - np.array: Transport plan of shape (n, m); entry [i, j] is the mass sent
      from P[i] to Q[j].

    Raises:
    - ValueError: if P or Q is not 2D or is empty, if their feature dimensions
      differ, or if reg is not positive.
    """
    P = np.asarray(P)
    Q = np.asarray(Q)

    # Validate up front so bad inputs fail with a clear message instead of a
    # division by zero in the histograms or an opaque error inside POT.
    if P.ndim != 2 or Q.ndim != 2:
        raise ValueError("P and Q must be 2D arrays with one sample per row")
    if P.shape[0] == 0 or Q.shape[0] == 0:
        raise ValueError("P and Q must each contain at least one sample")
    if P.shape[1] != Q.shape[1]:
        raise ValueError("P and Q must have the same number of columns")
    if reg <= 0:
        raise ValueError("reg must be positive")

    # Step 1: Compute the cost matrix using the *squared* Euclidean distance
    cost_matrix = ot.dist(P, Q, metric='sqeuclidean')

    # Step 2: Uniform histograms (every sample carries the same mass)
    P_hist = np.ones((P.shape[0],)) / P.shape[0]
    Q_hist = np.ones((Q.shape[0],)) / Q.shape[0]

    # Step 3: Compute the Sinkhorn transport plan
    transport_plan = ot.sinkhorn(P_hist, Q_hist, cost_matrix, reg, numItermax=nb_iter)

    return transport_plan

def coupling_to_map(coupling, target_support, *, atol=1e-5):
    '''
    coupling_to_map - converts a diffuse coupling into a map
        by computing the conditional expectation of each slice

    Row i of the coupling is normalised into a conditional distribution over
    the targets, and the image of X_i is the weighted average of the Y_j.

    :param coupling: 2D array-like with coupling[i,j] being the mass from Xi to Yj;
        must be finite, non-negative and sum to 1 (within atol)
    :param target_support: locations of the points Yj, shape (m,) or (m, d)
    :param atol: absolute tolerance on the total mass of the coupling
    :return: numpy array whose i'th row (or entry, for 1D target_support) is the
        image of the i'th sample under the map; rows of the coupling with zero
        mass have no defined conditional expectation and map to NaN
    :raises ValueError: if the coupling is not 2D, contains NaN/inf, does not
        sum to 1, has negative entries, or does not match target_support
    '''
    coupling = np.asarray(coupling)
    target_support = np.asarray(target_support)

    if coupling.ndim != 2:
        raise ValueError("coupling must be a 2D array")

    # Must come first: every comparison with NaN is False, so the mass and
    # sign checks below would otherwise silently accept a NaN-filled plan.
    if not np.all(np.isfinite(coupling)):
        raise ValueError("coupling contains NaN or infinite entries")

    if np.abs(np.sum(coupling) - 1.0) > atol:
        raise ValueError("coupling does not sum to 1")

    if np.min(coupling) < 0.0:
        raise ValueError("coupling cannot have negative entries")

    if target_support.ndim == 0 or coupling.shape[1] != target_support.shape[0]:
        raise ValueError("coupling.shape[1] must equal target_support.shape[0]")

    row_mass = coupling.sum(axis=1)[:, np.newaxis]

    # Normalise rows into conditional distributions. Zero-mass rows are left
    # as NaN (what 0/0 would produce) but without emitting a RuntimeWarning.
    conditional = np.full(coupling.shape, np.nan)
    np.divide(coupling, row_mass, out=conditional, where=row_mass > 0)

    # Normalising before the product keeps the output shape correct for a 1D
    # target_support as well: (n, m) @ (m,) -> (n,), (n, m) @ (m, d) -> (n, d).
    return conditional @ target_support

def entropic_map(P, Q, reg=0.1, nb_iter=10000):
    """
    Push the samples of P toward Q with the entropic (barycentric) map.

    Solves the entropically regularized OT problem between the uniform
    empirical measures on P and Q with Sinkhorn, then replaces every source
    point by the plan-weighted average of the target points it sends mass to.

    Parameters:
    - P (np.array): Source samples, one per row.
    - Q (np.array): Target samples, one per row.
    - reg (float): Entropic regularization strength passed to Sinkhorn.
    - nb_iter (int): Maximum number of Sinkhorn iterations. Note this default
      (10000) is lower than compute_sinkhorn_coupling's own default.

    Returns:
    - np.array: Row i is the image of P[i] under the entropic map.
    """
    return coupling_to_map(
        compute_sinkhorn_coupling(P, Q, reg=reg, nb_iter=nb_iter),
        target_support=Q,
    )

# Source cloud from torchcfm's sample_8gaussians, target cloud from sample_moons
# (256 samples requested from each), converted from torch tensors to numpy.
# The tuple is evaluated left to right, so the sampling order is unchanged.
x0, x1 = (
    sample_8gaussians(256).numpy(),
    sample_moons(256).numpy(),
)

# Image of each x0 sample under the entropic map toward x1.
# NOTE(review): `map` shadows the Python builtin; kept because later cells may use it.
map = entropic_map(x0, x1, reg=0.2)