In [2]:
import numpy as np
import scipy.sparse as sp

In [4]:
def bound(S, a, b):
    """Clamp every entry of a sparse matrix to the interval [a, b].

    Computes ``min(b, max(a, S))`` element-wise and returns the result as a
    new CSR matrix; ``S`` itself is not modified.

    Parameters
    ----------
    S : scipy sparse matrix (or anything ``sp.csr_matrix`` accepts)
        Matrix whose entries are clamped.
    a, b : scalar
        Lower and upper bound.

    Returns
    -------
    scipy.sparse.csr_matrix
        The clamped matrix.
    """
    clamped = sp.csr_matrix(S, copy=True)
    # Make sure each stored value is a complete matrix entry before clipping.
    clamped.sum_duplicates()

    if a <= 0 <= b:
        # Implicit zeros already lie inside [a, b], so only the stored values
        # need clipping.  (Adding a nonzero scalar to a sparse matrix, as the
        # previous implementation attempted, is not supported by SciPy.)
        clamped.data = np.clip(clamped.data, a, b)
        clamped.eliminate_zeros()
        return clamped

    # Zero is outside [a, b]: every implicit zero becomes nonzero, so the
    # result is dense by nature.
    return sp.csr_matrix(np.clip(clamped.toarray(), a, b))

def maxprod(ai, aj, av, m, x):
    """Max-product of a triplet-form sparse matrix with a vector.

    For every stored entry ``k`` the product ``av[k] * x[aj[k]]`` is formed,
    and ``y[r]`` receives the largest product among the entries whose row
    index ``ai[k]`` equals ``r``.  Rows with no entries stay at ``-inf``.

    Parameters
    ----------
    ai, aj : sequence of int
        Row and column index of each stored entry.
    av : sequence of float
        Value of each stored entry.
    m : int
        Length of the output vector.
    x : sequence of float
        Vector being multiplied.

    Returns
    -------
    numpy.ndarray
        Vector of length ``m``.
    """
    y = np.full(m, -np.inf)
    for k, row in enumerate(ai):
        candidate = av[k] * x[aj[k]]
        # Only overwrite on a strict improvement.
        if candidate > y[row]:
            y[row] = candidate
    return y

def implicit_maxprod(n, ai, x):
    """For each element, the maximum of the *other* elements in its group.

    Element ``i`` belongs to group ``ai[i]``.  The result at position ``i`` is
    the largest ``x`` value among the remaining members of that group, or
    ``-inf`` when the element is alone in its group.

    Parameters
    ----------
    n : int
        Number of groups.
    ai : sequence of int
        Group index of each element.
    x : sequence of float
        Value of each element.

    Returns
    -------
    numpy.ndarray
        Vector with ``len(ai)`` entries.
    """
    count = len(ai)
    best = np.full(n, -np.inf)        # largest value seen per group
    runner_up = np.full(n, -np.inf)   # second largest value per group
    best_pos = np.zeros(n, dtype=int)  # position holding the largest value

    # First pass: track the top two values of every group.
    for pos in range(count):
        group = ai[pos]
        value = x[pos]
        if not value > runner_up[group]:
            continue
        if value > best[group]:
            runner_up[group] = best[group]
            best[group] = value
            best_pos[group] = pos
        else:
            runner_up[group] = value

    # Second pass: the group's maximum, except for the element that *is* the
    # maximum, which gets the runner-up instead.
    others = np.full(count, -np.inf)
    for pos in range(count):
        group = ai[pos]
        others[pos] = runner_up[group] if best_pos[group] == pos else best[group]
    return others

def round_messages(messages, S, w, alpha, beta, rp, ci, tripi, n, m, perm):
    """Round a message vector to a matching and score it.

    The messages are used as edge weights of a bipartite matching problem;
    the resulting matching is then evaluated against the alignment objective
    ``alpha * matchweight + beta * overlap``.

    Parameters
    ----------
    messages : numpy.ndarray
        One value per edge, in original (triplet) edge order.
    S : sparse matrix
        Matrix multiplied with the matching indicator to compute the overlap.
    w : numpy.ndarray
        Edge weights, in original edge order.
    alpha, beta : scalar
        Weights of the matching weight and of the overlap in the objective.
    rp, ci, tripi, n, m
        Matching structure, passed through to
        ``bipartite_matching_primal_dual``.
    perm : array of int
        1-based triplet ids of the real edges, in CSR slot order (that is,
        ``tripi[tripi > 0]``).

    Returns
    -------
    numpy.ndarray
        ``[objective, matchweight, cardinality, overlap]``.
    """
    slot_weights = np.zeros(len(tripi))
    # `perm` is 1-based while `messages` is a 0-based array, hence the -1.
    # Indexing with `perm` directly reads the wrong message for every edge
    # and runs off the end of `messages` for the last triplet id.
    slot_weights[tripi > 0] = messages[np.asarray(perm) - 1]

    val, ma, mb, mi = bipartite_matching_primal_dual(rp, ci, slot_weights, tripi, n, m)

    matchweight = np.sum(w[mi])
    cardinality = np.sum(mi)
    # Halved on the assumption that S is symmetric, so each pair is counted
    # twice — TODO confirm against the construction of S.
    overlap = (mi @ (S @ mi.astype(np.float64))) / 2
    f = alpha * matchweight + beta * overlap
    return np.array([f, matchweight, cardinality, overlap])

def bipartite_matching_setup(w, li, lj, m, n):
    """Build the CSR structure used by ``bipartite_matching_primal_dual``.

    Edge ``e`` connects row vertex ``li[e]`` to column vertex ``lj[e]`` with
    weight ``w[e]`` (all indices 0-based).  Every row additionally receives
    one zero-weight "dummy" edge to a private column ``m + i``.

    Note the argument order: ``m`` is the number of *column* vertices
    (indexed by ``lj``) and ``n`` is the number of *row* vertices (indexed by
    ``li``).

    Parameters
    ----------
    w : sequence of float
        Edge weights.
    li, lj : sequence of int
        0-based row / column index of each edge.
    m : int
        Number of column vertices; every ``lj`` value must be in ``[0, m)``.
    n : int
        Number of row vertices; every ``li`` value must be in ``[0, n)``.

    Returns
    -------
    rp : numpy.ndarray of int, length ``n + 1``
        Row pointers stored 1-based: the slots of row ``i`` are
        ``range(rp[i] - 1, rp[i + 1] - 1)``.
    ci : numpy.ndarray of int
        0-based column index of each slot (dummy columns are ``m + i``).
    ai : numpy.ndarray of float
        Weight of each slot (0 for dummy edges).
    tripi : numpy.ndarray of int
        1-based index of the original edge stored in each slot, ``-1`` for
        dummy edges.
    n, m : int
        The row and column counts, returned in that order.

    Raises
    ------
    ValueError
        If the index arrays disagree in length with ``w``, contain an index
        outside the declared vertex ranges, or describe the same edge twice.
    """
    nedges = len(w)
    li = np.asarray(li)
    lj = np.asarray(lj)
    if len(li) != nedges or len(lj) != nedges:
        raise ValueError("w, li and lj must all have the same length")
    if nedges:
        # Out-of-range indices would either overflow rp or collide with the
        # dummy columns m..m+n-1, so reject them up front.
        if li.min() < 0 or li.max() >= n:
            raise ValueError(f"row indices in li must lie in [0, {n})")
        if lj.min() < 0 or lj.max() >= m:
            raise ValueError(f"column indices in lj must lie in [0, {m})")

    ci = np.zeros(nedges + n, dtype=int)
    ai = np.zeros(nedges + n)
    tripi = np.zeros(nedges + n, dtype=int)

    # 1. build csr representation with a set of extra edges from vertex i to
    #    vertex m+i.  Each row owns (number of its edges + 1) slots.
    rp = np.ones(n + 1, dtype=int)
    rp[0] = 0
    for e in range(nedges):
        rp[li[e] + 1] += 1
    rp = np.cumsum(rp)  # rp[i] is now the first (0-based) slot of row i

    # The fill below advances a *copy* of the row starts, so `rp` itself stays
    # valid and needs no "restore" shift afterwards (shifting it, as was done
    # previously, left row 0 empty and dropped the last row's slots).
    next_slot = rp[:-1].copy()
    for e in range(nedges):
        row = li[e]
        slot = next_slot[row]
        tripi[slot] = e + 1  # 1-based index for triplet index
        ai[slot] = w[e]
        ci[slot] = lj[e]
        next_slot[row] += 1

    for i in range(n):  # add the extra edges
        slot = next_slot[i]
        tripi[slot] = -1
        ai[slot] = 0
        ci[slot] = m + i

    # Row pointers are handed out 1-based; consumers subtract 1 again.
    rp = rp + 1

    # 1a. check for duplicates in the data
    colind = np.zeros(m + n, dtype=bool)
    for i in range(n):
        for rpi in range(rp[i] - 1, rp[i + 1] - 1):
            if colind[ci[rpi]]:
                raise ValueError(f"Duplicate edge detected ({i}, {ci[rpi]})")
            colind[ci[rpi]] = True
        for rpi in range(rp[i] - 1, rp[i + 1] - 1):
            colind[ci[rpi]] = False  # reset indicator

    return rp, ci, ai, tripi, n, m

def bipartite_matching_primal_dual(rp, ci, ai, tripi, n, m):
    """Maximum-weight bipartite matching via a primal-dual search.

    The input is the CSR structure produced by ``bipartite_matching_setup``:
    ``n`` rows, ``m`` real columns, and one zero-weight dummy column per row
    (column ids ``m .. m + n - 1``), so every row can always be matched.

    Parameters
    ----------
    rp : array of int, length ``n + 1``
        1-based row pointers; row ``i`` owns slots
        ``range(rp[i] - 1, rp[i + 1] - 1)``.
    ci : array of int
        0-based column index of each slot.
    ai : array of float
        Weight of each slot.
    tripi : array of int or None
        1-based original-edge id of each slot (``-1`` for dummy edges).  When
        ``None`` the edge indicator ``mi`` is not computed.
    n, m : int
        Number of rows and number of real columns.

    Returns
    -------
    val : float
        Total weight of the matching.
    m1 : numpy.ndarray of int
        Rows matched to a real column, as 1-based ids (``i + 1``), kept that
        way for compatibility with existing callers.
    m2 : numpy.ndarray of int
        The matched column of each row in ``m1``, as it appears in ``ci``
        (0-based).
    mi : numpy.ndarray of bool, only when ``tripi`` is not None
        Indicator over the original edges (length ``len(tripi) - n``) marking
        the edges that are in the matching.

    Raises
    ------
    ValueError
        If the search cannot make progress, which happens when a row has no
        usable edge (for example a malformed ``rp``).
    """
    ncols = n + m
    alpha = np.zeros(n)      # dual variable per row
    beta = np.zeros(ncols)   # dual variable per column
    queue = np.zeros(ncols, dtype=int)  # rows of the current search tree

    # -1 marks "none" everywhere.  Column 0 is a legitimate column id, so 0
    # cannot double as the "unmatched" marker.
    pred = np.full(ncols, -1, dtype=int)       # row a column was reached from
    row_match = np.full(n, -1, dtype=int)      # column matched to each row
    col_match = np.full(ncols, -1, dtype=int)  # row matched to each column
    touched = np.zeros(ncols, dtype=int)       # columns labelled this search
    ntouched = 0

    def row_slots(k):
        # Slots of row k (rp is stored 1-based).
        return range(rp[k] - 1, rp[k + 1] - 1)

    # Start each row dual at its largest weight (never below zero), which
    # makes alpha[i] + beta[j] >= ai for every edge.
    for i in range(n):
        for rpi in row_slots(i):
            if ai[rpi] > alpha[i]:
                alpha[i] = ai[rpi]

    i = 0
    while i < n:
        # Clear the labels left over from the previous search.
        for q in range(ntouched):
            pred[touched[q]] = -1
        ntouched = 0

        queue[0] = i
        head = 1  # next free position in the queue
        tail = 0  # next position to process

        # Breadth-first search over tight edges for an augmenting path.
        while head > tail and row_match[i] < 0:
            k = queue[tail]
            tail += 1
            for rpi in row_slots(k):
                j = ci[rpi]
                if ai[rpi] < alpha[k] + beta[j] - 1e-8:
                    continue  # edge is not tight
                if pred[j] >= 0:
                    continue  # column already labelled in this search
                pred[j] = k
                touched[ntouched] = j
                ntouched += 1
                if col_match[j] < 0:
                    # Free column reached: flip the path back to the root.
                    while j >= 0:
                        k = pred[j]
                        col_match[j] = k
                        previous = row_match[k]
                        row_match[k] = j
                        j = previous
                    break
                # Column is taken: continue the search from its current row.
                queue[head] = col_match[j]
                head += 1

        if row_match[i] < 0:
            # No augmenting path yet: shift the duals by the smallest slack
            # between a searched row and an unlabelled column, then retry.
            theta = np.inf
            for q in range(tail):
                t1 = queue[q]
                for rpi in row_slots(t1):
                    t2 = ci[rpi]
                    if pred[t2] < 0:
                        slack = alpha[t1] + beta[t2] - ai[rpi]
                        if slack < theta:
                            theta = slack

            if not np.isfinite(theta):
                raise ValueError(
                    f"row {i} cannot be matched; check that every row has a dummy edge"
                )

            for q in range(tail):
                alpha[queue[q]] -= theta
            for q in range(ntouched):
                beta[touched[q]] += theta
            continue

        i += 1

    # Total weight of the matching (dummy edges contribute zero).
    val = 0.0
    for i in range(n):
        for rpi in row_slots(i):
            if ci[rpi] == row_match[i]:
                val += ai[rpi]

    # Keep only rows matched to a real column: real columns are 0..m-1, the
    # dummy columns start at m.
    real_rows = [i for i in range(n) if 0 <= row_match[i] < m]
    m1 = np.array([i + 1 for i in real_rows], dtype=int)
    m2 = np.array([row_match[i] for i in real_rows], dtype=int)

    if tripi is not None:
        mi = np.zeros(len(tripi) - n, dtype=bool)
        for i in real_rows:
            for rpi in row_slots(i):
                if ci[rpi] == row_match[i]:
                    mi[tripi[rpi] - 1] = True  # tripi is 1-based

        return val, m1, m2, mi

    return val, m1, m2

In [5]:
def netalignmbp(S, w, a=1, b=1, li=None, lj=None, gamma=0.99, dtype=2, maxiter=100, verbose=1):
    """Iterative message-passing heuristic for the network alignment objective.

    Each iteration updates two message vectors (``y`` and ``z``) and a sparse
    message matrix ``Sk``, applies damping, rounds both message vectors to a
    matching with ``round_messages`` and remembers the best-scoring one.

    Parameters
    ----------
    S : scipy sparse matrix, shape (nedges, nedges)
        Matrix over pairs of candidate edges, used for the overlap term.
    w : array of float, length nedges
        Weight of each candidate edge.
    a, b : scalar
        Weight of the matching term and of the overlap term.
    li, lj : array of int, length nedges
        0-based endpoints of each candidate edge.  Required despite the
        ``None`` defaults.
    gamma : float
        Damping base; the damping factor at iteration ``k`` is ``gamma ** k``.
    dtype : {1, 2, 3}
        Damping scheme to apply.
    maxiter : int
        Number of iterations.
    verbose : bool or int
        When truthy, print one progress line per iteration.

    Returns
    -------
    mbest : numpy.ndarray
        Message vector that produced the best rounded objective (all zeros if
        no iterate scored above 0).
    hista, histb : numpy.ndarray, shape (maxiter, 4)
        Per-iteration ``[objective, weight, cardinality, overlap]`` obtained
        by rounding ``y`` and ``z`` respectively.

    Raises
    ------
    ValueError
        If ``li``/``lj`` are missing or ``dtype`` is not 1, 2 or 3.
    """
    if li is None or lj is None:
        raise ValueError("li and lj (the edge endpoints) are required")
    if dtype not in (1, 2, 3):
        # Any other value would skip every damping branch and leave the
        # messages at zero forever.
        raise ValueError(f"dtype must be 1, 2 or 3, got {dtype!r}")

    w = np.asarray(w, dtype=float)
    nedges = len(li)
    m = max(li) + 1  # number of distinct li vertices
    n = max(lj) + 1  # number of distinct lj vertices

    # Initialize the messages
    y = np.zeros(nedges)
    z = np.zeros(nedges)
    Sk = sp.csr_matrix(S.shape)
    d = np.zeros(nedges)

    # Initialize a few parameters
    damping = gamma
    curdamp = 1

    # Initialize history
    hista = np.zeros((maxiter, 4))
    histb = np.zeros((maxiter, 4))
    fbest = 0
    fbestiter = 0
    # Defined up front so the return value exists even when no iterate ever
    # beats the initial objective of 0 (or when maxiter is 0).
    mbest = np.zeros(nedges)

    if verbose:
        print(f'{"best":<4} {"iter":<4} {"obj_ma":<7} {"wght_ma":<7} {"card_ma":<7} {"over_ma":<7} {"obj_mb":<7} {"wght_mb":<7} {"card_mb":<7} {"over_mb":<7}')

    # setup the matching problem once
    # bipartite_matching_setup takes (column count, row count) in that order,
    # with rows indexed by li and columns by lj -- hence (n, m) here.
    rp, ci, ai, tripi, matn, matm = bipartite_matching_setup(w, li, lj, n, m)
    mperm = tripi[tripi > 0]  # 1-based triplet ids of the real edges, CSR order

    for it in range(1, maxiter + 1):
        curdamp = damping * curdamp

        Sknew = bound(Sk.transpose() + b * S, 0, b)
        dold = d
        d = np.asarray(Sknew.sum(axis=1)).ravel()

        # implicit_maxprod gives, per edge, the best competing message among
        # the other edges sharing the same lj (resp. li) endpoint.
        ynew = a * w - np.maximum(0, implicit_maxprod(n, lj, z)) + d
        znew = a * w - np.maximum(0, implicit_maxprod(m, li, y)) + d

        Skt = sp.diags(ynew + znew - a * w - d) @ S - Sknew

        if dtype == 1:
            Sk = curdamp * Skt + (1 - curdamp) * Sk
            y = curdamp * ynew + (1 - curdamp) * y
            z = curdamp * znew + (1 - curdamp) * z
        elif dtype == 2:
            prev = y + z - a * w + dold
            y = ynew + (1 - curdamp) * prev
            z = znew + (1 - curdamp) * prev
            Sk = Skt + (1 - curdamp) * (Sk + Sk.transpose() - b * S)
        else:  # dtype == 3
            prev = y + z - a * w + dold
            y = curdamp * ynew + (1 - curdamp) * prev
            z = curdamp * znew + (1 - curdamp) * prev
            Sk = curdamp * Skt + (1 - curdamp) * (Sk + Sk.transpose() - b * S)

        hista[it - 1] = round_messages(y, S, w, a, b, rp, ci, tripi, matn, matm, mperm)
        histb[it - 1] = round_messages(z, S, w, a, b, rp, ci, tripi, matn, matm, mperm)

        # fbestiter is +it when y gave the best objective so far, -it for z.
        if hista[it - 1, 0] > fbest:
            fbestiter = it
            mbest = y
            fbest = hista[it - 1, 0]

        if histb[it - 1, 0] > fbest:
            fbestiter = -it
            mbest = z
            fbest = histb[it - 1, 0]

        if verbose:
            bestchar = '*a' if fbestiter == it else '*b' if fbestiter == -it else ''
            print(f'{bestchar:<4} {it:<4} {hista[it - 1, 0]:<7g} {hista[it - 1, 1]:<7g} {hista[it - 1, 2]:<7g} {hista[it - 1, 3]:<7g} {histb[it - 1, 0]:<7g} {histb[it - 1, 1]:<7g} {histb[it - 1, 2]:<7g} {histb[it - 1, 3]:<7g}')

    return mbest, hista, histb

