In [1]:
import numpy as np
from utils import DiscreteDistrib, discrete_wasserstein_distance
from D2_clustering import D2Cluster
from BADMM import badmm_centroid_update

In [2]:
x1 = np.array([[0, 1], [2, 3]])
n1 = x1.shape[0]
x2 =  np.ones((3, 2))
n2 = x2.shape[0]

w1 = np.ones(n1) / n1
w2 = np.ones(n2) / n2

In [3]:
P1 = DiscreteDistrib(w1, x1)
P2 = DiscreteDistrib(w2, x2)

dist, pi = discrete_wasserstein_distance(P1, P2, return_coupling=True)

In [4]:
dist, pi

(1.272019649514065,
 array([[0.16666667, 0.16666667, 0.16666667],
        [0.16666667, 0.16666667, 0.16666667]]))

In [5]:
c0 = DiscreteDistrib(w=np.array([1/2, 1/2]), x=np.array([[0, 0], [1, 1]]))

In [6]:
supp = np.vstack([x1, x2]).T
stride = np.array([n1, n2])
w = np.hstack([w1, w2])

In [7]:
import numpy as np
from utils import DiscreteDistrib
from scipy import sparse
from scipy.spatial import distance_matrix


def badmm_centroid_update(stride, supp, w, c0: DiscreteDistrib,
                          rho=1e-2, nIter=10000, eps=1e-10,
                          tau=10, badmm_tol=1e-3):
    d = supp.shape[0]
    n = len(stride)
    m = len(w)

    c = c0

    support_size = len(c.w)
    posvec = np.concatenate(([0], np.cumsum(stride)))
    
    X = np.zeros((support_size, m))
    Y = np.zeros_like(X)
    Z = np.zeros((support_size, np.sum(stride)))

    spIDX_rows = np.zeros(support_size * m, dtype=int)
    spIDX_cols = np.zeros_like(spIDX_rows, dtype=int)

    for i in range(n):
        xx, yy = np.meshgrid(i * support_size + np.arange(0, support_size),
                             np.arange(posvec[i], posvec[i + 1]))
        ii = support_size * posvec[i] + np.arange(support_size * stride[i])
        spIDX_rows[ii] = xx.flatten()
        spIDX_cols[ii] = yy.flatten()

    spIDX = np.kron(np.eye(support_size), np.ones((1, n)))

    for i in range(n):
        Z[:, posvec[i]: posvec[i + 1]] = 1 / (support_size * stride[i])

    C = distance_matrix(c.x.T, supp.T) ** 2
    
    for iteration in range(nIter):
        # update X
        X = Z * np.exp((C + Y) / (-rho)) + eps
        X = X * (w / np.sum(X, axis=0)).T

        # update Z
        Z0 = Z
        Z = X * np.exp(Y / rho) + eps
#         print(spIDX_cols, spIDX_rows)
        spZ = sparse.csr_matrix((Z.T.ravel(order='F'), (spIDX_rows, spIDX_cols)),
                               shape=(support_size * n, m))
        
#         print(spZ.shape, support_size, Z.shape, spIDX_rows.shape, spIDX_cols.shape)
        tmp = np.sum(spZ, axis=1)
        tmp = np.reshape(tmp, (support_size, n))
        dg = c.w / tmp
        dg = sparse.csr_matrix((np.array(dg).flatten(),
                                (np.arange(n * support_size), np.arange(n * support_size))))
#         print(dg.shape, spZ.shape, spIDX.shape, X.shape)
#         print(m)
        Z = spIDX @ dg @spZ
    
        # update Y
        Y = Y + rho * (X - Z)

        # update c.w
        tmp = tmp / np.sum(tmp, axis=0)
#         print(np.sum(np.sqrt(tmp), axis=1))
        sumW = np.array(np.sum(np.sqrt(tmp), axis=1)) ** 2  # (R2)
        # sumW = np.sum(tmp, axis=1)  # (R1)
        c.w = sumW / np.sum(sumW)
        if iteration % tau == 0:
            c.x = supp @ X.T / np.tile(np.sum(X, axis=1), (d, 1)).T
            C = distance_matrix(c.x.T, supp.T) ** 2
        if iteration % 100 == 0:
            primres = np.linalg.norm(X - Z, 'fro') / np.linalg.norm(Z, 'fro')
            dualres = np.linalg.norm(Z - Z0, 'fro') / np.linalg.norm(Z, 'fro')
            cost = round(np.sum(C * X) / n, 3)
            print(f'Iter: {iteration}, Avg cost {cost}, Primal: {round(primres, 4)}, Dual: {round(dualres, 4)}')
            if np.sqrt(dualres * primres) < badmm_tol:
                print("Early stop activated!")
                break
    return c

In [8]:
c = badmm_centroid_update(stride, supp, w, c0=c0)

Iter: 0, Avg cost 1.25, Primal: 0.7167, Dual: 0.7167
Iter: 100, Avg cost 1.25, Primal: 0.9341, Dual: 0.0001
Iter: 200, Avg cost 2.141, Primal: 0.7714, Dual: 0.0
Iter: 300, Avg cost 2.141, Primal: 0.7717, Dual: 0.0
Iter: 400, Avg cost 2.141, Primal: 0.7719, Dual: 0.0
Iter: 500, Avg cost 2.141, Primal: 0.7721, Dual: 0.0
Iter: 600, Avg cost 2.141, Primal: 0.7721, Dual: 0.0
Iter: 700, Avg cost 2.141, Primal: 0.7722, Dual: 0.0
Iter: 800, Avg cost 2.141, Primal: 0.7722, Dual: 0.0
Iter: 900, Avg cost 2.141, Primal: 0.7723, Dual: 0.0
Iter: 1000, Avg cost 2.141, Primal: 0.7723, Dual: 0.0
Early stop activated!


In [9]:
c

[(array([1., 2.]), array([0.00121132])), (array([0.875, 1.375]), array([0.99878868]))]