In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np

# https://github.com/coderlemon17/LemonScripts/blob/master/OptimalTransport/OT.ipynb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# numpy写的
def compute_optimal_transport(M, r, c, lam, epsilon=1e-5):
    """
    Computes the optimal transport matrix and Slinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization  
        - epsilon : convergence parameter

    Output:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape
    P = np.exp(- lam * M)
    # Avoiding poor math condition
    P /= P.sum()
    u = np.zeros(n)
    # Normalize this matrix so that P.sum(1) == r, P.sum(0) == c
    cnt = 0
    while np.max(np.abs(u - P.sum(1))) > epsilon:
        if cnt%100 == 0:
            print(cnt)
            print(np.max(np.abs(u - P.sum(1))))
        # Shape (n, )
        u = P.sum(1)
        P *= (r / u).reshape((-1, 1))
        P *= (c / P.sum(0)).reshape((1, -1))
    return P, np.sum(P * M)

layer_similarity = torch.tensor([[0.92, 0.92, 0.90],  [0.92, 0.91, 0.90], [0.81, 0.80, 0.79], [0.52, 0.48, 0.43], [0.61, 0.61, 0.58], [0.69, 0.69, 0.62]])
Cost = 1.0 - layer_similarity 
NA = 6 # number of layers in the Ancestry model
ND = 3 # number of layers in the Descendant model
r = torch.ones((NA), device=Cost.device, dtype=Cost.dtype)/6.0 * 3.0 
c = torch.ones((ND), device=Cost.device, dtype=Cost.dtype)

lam = 10
P, d = compute_optimal_transport(Cost.numpy(), r.numpy(), c.numpy(), lam=lam)
print("OT_matrix")
print(torch.from_numpy(P))
print("final")
print(d)

loss_matrix = [i*j for i in range(5) for j in range(5)]
print(loss_matrix)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
0
0.41085106134414673
0
0.4849617
0
0.014769554
0
0.00019866228
OT_matrix
tensor([[0.1495, 0.1659, 0.1846],
        [0.1543, 0.1550, 0.1906],
        [0.1543, 0.1550, 0.1906],
        [0.2121, 0.1578, 0.1301],
        [0.1549, 0.1720, 0.1731],
        [0.1749, 0.1941, 0.1310]])
final
0.81180024
[0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 0, 2, 4, 6, 8, 0, 3, 6, 9, 12, 0, 4, 8, 12, 16]
