In [1]:
import matplotlib.pyplot as plt
import copy
import utils.gol as gol
import numpy as np
import torch
gol._init()
gol.set_value('device', 'cuda:1')
from environments.mujoco.ant_goal_cluster import AntGoalClusterEnv
from environments.mujoco.ant_dir_cluster import AntDirClusterEnv

In [3]:
# the original optimal transport problem solver

def compute_optimal_transport(M, r, c, lam, epsilon=1e-5):     
    n, m = M.shape     
    # P = np.exp(- lam * M)
    # Avoiding poor math condition     
    P = M
    P /= P.sum()     
    u = np.zeros(n)
    # Normalize this matrix so that P.sum(1) == r, P.sum(0) == c     
    for i in range(10):         
        # Shape (n, )         
        u = P.sum(1)
        P *= (r / u).reshape((-1, 1))         
        P *= (c / P.sum(0)).reshape((1, -1))     
    
    return P, np.sum(P * M)


M = np.random.randn(4, 4)
print(M)
print("the first method:", compute_optimal_transport(copy.deepcopy(M), np.ones(4), np.ones(4), lam=0.2, epsilon=0.01))

# the sinkhorn-knoll in the current code
@torch.no_grad()
def sinkhorn(Q):
    # Q = torch.exp(scores / self.args.epsilon).t() # Q is K-by-B for consistency with notations from our paper
    B = Q.shape[1]  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sums to 1
    sum_Q = torch.sum(Q)
    # dist.all_reduce(sum_Q)
    Q /= sum_Q
    Q *= B # the colomns must sum to 1 so that Q is an assignment
    
    for it in range(10):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        # dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        # Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        # Q /= B

    
    return Q

print(M)
print("the second method:", sinkhorn(torch.from_numpy(M)))

[[ 0.04327524  0.69078721 -0.57207887  1.29217531]
 [ 0.26322138 -1.87145274 -1.48656541 -2.00062924]
 [ 0.28909475  1.46297013 -0.10675882  0.97841841]
 [ 0.47832591 -0.77301679 -1.66078182  0.97493419]]
the first method: (array([[ 0.13204075,  4.10995605,  0.36984113,  1.42719716],
       [ 0.3411623 , -4.72980429,  0.40823923, -0.93864325],
       [ 0.25193344,  2.48602508,  0.01971243,  0.3086486 ],
       [ 0.27486352, -0.86617684,  0.20220721,  0.2027975 ]]), 49.865269079484165)
[[ 0.04327524  0.69078721 -0.57207887  1.29217531]
 [ 0.26322138 -1.87145274 -1.48656541 -2.00062924]
 [ 0.28909475  1.46297013 -0.10675882  0.97841841]
 [ 0.47832591 -0.77301679 -1.66078182  0.97493419]]
the second method: tensor([[ 0.1320,  4.1100,  0.3698,  1.4272],
        [ 0.3412, -4.7298,  0.4082, -0.9386],
        [ 0.2519,  2.4860,  0.0197,  0.3086],
        [ 0.2749, -0.8662,  0.2022,  0.2028]], dtype=torch.float64)


In [6]:
A = np.random.randn(4, 4)
B = np.random.randn(4, 4)
print(np.sum(A * B))
print(np.dot(A.T, B).trace())

-1.810956853079488
-1.8109568530794875


In [26]:
M = np.random.randn(10, 4)
print(M)
@torch.no_grad()
def sinkhorn(scores):
    Q = torch.exp(scores / 0.1).t() # Q is K-by-B for consistency with notations from our paper
    B = Q.shape[1]  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sums to 1
    sum_Q = torch.sum(Q)
    # dist.all_reduce(sum_Q)
    Q /= sum_Q
    Q *= B # the colomns must sum to 1 so that Q is an assignment
    
    for it in range(200):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        # dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        # Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        # Q /= B

    
    return Q.t()

# print(M)
test_results = sinkhorn(torch.from_numpy(M))
print("the second method:", sinkhorn(torch.from_numpy(M)))

[[-0.66626409 -3.28637241 -0.11529508  0.96383359]
 [-1.33434205  2.06294486 -1.54455869 -0.83923558]
 [-1.00172206  2.20889372 -0.59780513  0.14056552]
 [ 0.75693611 -1.07924075  1.01225128  0.29206287]
 [ 0.4814222   0.28424167 -1.57707098 -0.46884366]
 [ 1.25558147 -1.01464578  1.55881364  0.50988612]
 [-0.12671393  0.20937435  1.52011385  0.48144126]
 [-1.43945091 -1.2030839  -0.89878742 -1.65948718]
 [-0.88681909  2.20409042  1.28839136  1.73192686]
 [-0.12319705  0.64734347 -0.15474894  0.39944466]]
the second method: tensor([[2.7652e-07, 2.2670e-21, 1.9696e-06, 1.0000e+00],
        [8.9755e-13, 1.0000e+00, 3.1617e-15, 3.8205e-11],
        [5.8045e-12, 1.0000e+00, 9.5011e-12, 1.5977e-07],
        [7.2819e-01, 1.5155e-11, 2.6971e-01, 2.0998e-03],
        [9.9970e-01, 2.7305e-04, 3.3096e-11, 2.2478e-05],
        [6.2569e-01, 1.6967e-13, 3.7420e-01, 1.0881e-04],
        [2.4432e-06, 1.3813e-07, 9.9968e-01, 3.2209e-04],
        [1.3370e-01, 2.7885e-03, 8.5905e-01, 4.4603e-03],
      

In [6]:
test_results.sum(0)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], dtype=torch.float64)