In [1]:
import numpy as np
from scipy.spatial.distance import cdist
import math
import matplotlib.pyplot as plt
import copy
from numpy import linalg
import time
import warnings

In [2]:
def phi_(gamma, lamu, C):
    A = (-C/gamma + np.outer(lamu[:n], one) + np.outer(one, lamu[n:]))
    a = A.max()
    A-=a
    s = a+np.log(np.exp(A).sum())
    return gamma*(-lamu[:n].dot(p) - lamu[n:].dot(q) + s)

def f_(gamma, x):
    y = (x.reshape(-1)).copy()
    y[x.reshape(-1) == 0.] = 1.
    y = y.reshape(n, -1)
    return (C * x).sum() + gamma * (x * np.log(y)).sum()

In [3]:
n = 2
one = np.ones(n, np.float64)
I = np.ones([n,n], np.float64)
def f_true(x, C):
    return (C * x).sum()

def B_round(x):
    r = p_ref / x.dot(one)
    r[r>1] = 1.
    F = np.diag(r).dot(x)
    c = q_ref / (x.T).dot(one)
    c[c>1] = 1.
    F = F.dot(np.diag(c))
    err_r = p_ref - F.dot(one)
    err_c = q_ref - (F.T).dot(one)
    return F + np.outer(err_r, err_c) / abs(err_r).sum()

In [21]:
def aam(x0, C, gamma, eps, min_iter=50, min_time=0):
    L=1
    step = 2
    xi = np.zeros_like(x0)
    eta = xi.copy()
    zeta = xi.copy()
    eta_new = xi.copy()
    zeta_new = xi.copy()
    grad2 = alpha_new = alpha = 0
    ustep = np.zeros_like(x0[:n])
    vstep = np.zeros_like(ustep)
    
    f = lambda x: phi_(gamma, x, T)
    f_primal = lambda x: f_(gamma, x)
    
    
    print('C: ', C)
    K = -C/gamma
    primal_var = np.zeros_like(K)
    
    start_time = time.perf_counter()
    
    stage_i = 0
    
    while True and stage_i < min_iter:
        L_new = L/step
        while True and stage_i < min_iter:
            alpha_new = 1/2/L_new + np.sqrt(1/4/L_new/L_new + alpha*alpha*L/L_new)
            tau = 1/alpha_new/L_new
            xi = tau*zeta + (1-tau)*eta
            
            ##############
            logB = (K + np.outer(xi[:n], one) + np.outer(one, xi[n:]))
            max_logB =logB.max()
            logB_stable = logB - max_logB

            B_stable = np.exp(logB_stable)
            u_hat_stable, v_hat_stable = B_stable.dot(one), B_stable.T.dot(one)
            
            Bs_stable = u_hat_stable.sum()

            f_xi = gamma*(-xi[:n].dot(p) - xi[n:].dot(q) + np.log(Bs_stable) + max_logB)
            grad_f_xi = gamma*np.concatenate((-p + 
                                              u_hat_stable/Bs_stable, 
                                              -q + v_hat_stable/Bs_stable),0)
            
            gu, gv = (grad_f_xi[:n]**2).sum(), (grad_f_xi[n:]**2).sum()
            norm2_grad_f_xi = (gu+gv)

            if gu > gv:
                with warnings.catch_warnings():
                    warnings.filterwarnings('error')
                    try:
                        ustep = p/u_hat_stable
                    except Warning as e:
                        u_hat_stable/=u_hat_stable.max()
                        u_hat_stable[u_hat_stable<1e-150] = 1e-150
                        ustep = p/u_hat_stable
                        #print('catchu')
                    
                
                ustep/=ustep.max()
                xi[:n]+=np.log(ustep)
                Z=ustep[:,None]*B_stable
            else:
                with warnings.catch_warnings():
                    warnings.filterwarnings('error')
                    try:
                        vstep = q/v_hat_stable
                    except Warning as e:
                        v_hat_stable/=v_hat_stable.max()
                        v_hat_stable[v_hat_stable<1e-150] = 1e-150
                        vstep = q/v_hat_stable
                        #print('catchv')

                vstep/=vstep.max()
                xi[n:]+=np.log(vstep)
                Z=B_stable*vstep[None,:]
            ##############
            f_eta_new=gamma*(np.log(Z.sum())+max_logB-xi[:n].dot(p)-xi[n:].dot(q))
            
            #print(L_new)
            if f_eta_new <= f_xi - (norm2_grad_f_xi)/2/L_new: # can be optimized 2 itmes
                primal_var = (alpha_new * B_stable/Bs_stable 
                              + L * alpha**2 * primal_var) /(L_new*alpha_new**2)
                
                zeta -= alpha_new * grad_f_xi
                #eta = eta_new.copy()
                eta = xi.copy()
                alpha = alpha_new
                L = L_new
                
                break
            L_new*=step
        
        #print((C * (B_round(primal_var) - primal_var)).sum(), 
        #f_primal(primal_var) + f_eta_new, eps/6)
        
        stage_i += 1
        
        if ((C * (B_round(primal_var) - primal_var)).sum() <= eps/6 and 
        abs(f_primal(primal_var) + f_eta_new) <= eps/6):
            return xi #time.perf_counter() - start_time
        
#     return np.array(history_f), np.array(history_time)

In [22]:
n = 2
T = np.array([[-1000. , -1999.5], [ -999.5, -2000. ]])
C = T
corr_matrix = np.array([[0.25, 0.25], [0.25, 0.25]])

    
p = np.array([0.5, 0.5]) #np.nansum(corr_matrix, axis=1),
q = np.array([0.5, 0.5]) #np.nansum(corr_matrix, axis=0)

people_num = np.nansum(p)

# p = p / np.nansum(p)
# q = q / np.nansum(q)

print(p, q)

p = p.reshape(n)
q = q.reshape(n)

p_ref, q_ref = p, q

[0.5 0.5] [0.5 0.5]


In [23]:
p_ref, q_ref

(array([0.5, 0.5]), array([0.5, 0.5]))

In [24]:
x0 = [0.25, 0.25, 0.25, 0.25]
gamma = n / np.nansum(T)
xi = aam(x0, T, gamma, eps = 10**(-8), min_iter=1000, min_time=0)

C:  [[-1000.  -1999.5]
 [ -999.5 -2000. ]]
































































KeyboardInterrupt: 

In [None]:
xi