In [1]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from scipy.optimize import minimize
from scipy.stats import chi2
from scipy.linalg import sqrtm
from numpy.linalg import det
import numpy.linalg as LA
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
import numpy as np
from scipy.stats import invwishart as iw        
import matplotlib.pyplot as plt

In [2]:
def inv(A):
    """Return the matrix inverse of ``A``.

    Thin alias for ``numpy.linalg.inv``; raises ``numpy.linalg.LinAlgError``
    when ``A`` is singular.
    """
    inverse = LA.inv(A)
    return inverse

def relu(v):
    """Scalar eigenvalue adjustment applied element-wise by ``pinv``.

    Values strictly between ``1e-5`` and ``100`` are mapped to
    ``log1p(1 + exp(v)) * 1e-5 / log1p(1 + exp(1e-5))``. Every other value,
    including negatives, NaN, and anything >= 100, is returned unchanged.

    NOTE(review): despite the name, this does not clip negatives to zero.
    ``log1p(1 + exp(v))`` is ``log(2 + e**v)``, not the softplus
    ``log(1 + e**v)``. Inside the window, values are scaled by roughly
    ``1e-5 / log(3)`` (about 9.1e-6). So the output jumps from ~9e-4 to 100
    at ``v == 100``. Confirm the intended transform.
    """
    threshold = 1E-5
    # Outside the open window (threshold, 100) the input passes straight through.
    if not (threshold < v < 100):
        return v
    # Scaled so that the map is continuous at v == threshold (it returns threshold there).
    numerator = np.log1p(1 + np.exp(v)) * threshold
    return numerator / np.log1p(1 + np.exp(threshold))



def pinv(A):
    """Rebuild a symmetric matrix after passing its eigenvalues through ``relu``.

    With ``A = V diag(w) V.T``, this returns ``V diag(relu(w)) V.T``.

    NOTE(review): despite the name, the result is the eigenvalue-adjusted
    matrix itself, not a (pseudo-)inverse. The previous version also built
    ``V diag(1 / relu(w)) V.T`` but discarded it. That dead computation, which
    could divide by zero eigenvalues, has been removed. Confirm whether
    callers actually want the inverse.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Assumed real symmetric. ``eigh`` reads only its lower triangle.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        The reconstructed matrix with adjusted eigenvalues.
    """
    # eigh returns real eigenvalues and orthonormal eigenvectors, so V.T is
    # exactly V^{-1}. LA.eig guarantees neither; e.g. with repeated eigenvalues
    # its eigenvectors need not be orthogonal, which broke the V @ D @ V.T
    # reconstruction.
    eigvals, eigvecs = LA.eigh(A)
    # otypes avoids np.vectorize's "size 0 inputs" error for an empty matrix.
    adjust = np.vectorize(relu, otypes=[float])
    return eigvecs @ np.diag(adjust(eigvals)) @ eigvecs.T


def generate_covariance(true_mu, dims, df):
    S = (np.tril(iw.rvs(df, 1, size=dims**2).reshape(dims, dims)))*df
    cov = np.dot(S, S.T)
    while(abs(np.linalg.det(cov)) < 1.5):
        cov = cov + 0.5*np.diag(np.diag(cov))
    mu = np.random.multivariate_normal(true_mu, cov, 1)[0]

    return mu, cov

def mutual_covariance(cov_a, cov_b):
    """Mutual covariance Gamma of two covariances (Sijs et al., eqns. 10-11).

    ``cov_a`` is diagonalised as ``S_a D_a S_a^T`` and ``cov_b`` is whitened
    with ``T = D_a^{-1/2} S_a^T`` (eqn. 10). The eigenvalues of the whitened
    matrix are clipped from below at 1 (eqn. 11b). The result is then mapped
    back (eqn. 11a).

    The previous version produced the same result. Because ``eigh`` returns an
    orthonormal ``S_a``/``S_b``, ``inv(S)`` is just ``S.T``, and ``D_a^{1/2}``
    is an element-wise ``sqrt``. So ``sqrtm`` and every explicit ``inv`` are
    dropped here, which is cheaper and avoids extra rounding.

    Parameters
    ----------
    cov_a, cov_b : array_like, shape (n, n)
        Symmetric matrices; ``cov_a`` must be positive definite.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Symmetric Gamma.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``cov_a`` has a non-positive eigenvalue. Previously a zero
        eigenvalue raised "Singular matrix" and a negative one silently
        produced complex output.
    """
    eig_a, S_a = np.linalg.eigh(cov_a)
    if np.any(eig_a <= 0):
        raise np.linalg.LinAlgError("cov_a must be positive definite")
    sqrt_a = np.sqrt(eig_a)

    # eqn. 10: M = D_a^{-1/2} S_a^T cov_b S_a D_a^{-1/2}
    whiten = S_a.T / sqrt_a[:, None]
    M = whiten @ cov_b @ whiten.T

    eig_b, S_b = np.linalg.eigh(M)
    d_gamma = np.clip(eig_b, a_min=1.0, a_max=None)   # eqn. 11b

    # eqn. 11a: Gamma = W D_gamma W^T with W = S_a D_a^{1/2} S_b
    W = (S_a * sqrt_a) @ S_b
    gamma = (W * d_gamma) @ W.T
    # Remove round-off asymmetry so downstream inverses/eigendecompositions see a symmetric matrix.
    return 0.5 * (gamma + gamma.T)

def get(dims, df):
    """Simulate two estimates ``a`` and ``b`` sharing a common component ``c``.

    Three independent (mean, covariance) pairs are drawn with
    ``generate_covariance``, in the order ac, c, bc. Estimate ``a`` fuses ac
    with c, and ``b`` fuses bc with c, both in information (inverse
    covariance) form. The fused result uses
    ``inv(C_a) + inv(C_b) - inv(C_c)``, which counts each of ac, bc and c
    exactly once.

    Parameters
    ----------
    dims : int
        State dimension.
    df : number
        Passed through to ``generate_covariance``.

    Returns
    -------
    x_a, x_b : numpy.ndarray, shape (1, dims)
    C_a, C_b, C_fus : numpy.ndarray, shape (dims, dims)
    x_fus : numpy.ndarray, shape (dims,)
    """
    origin = np.zeros(dims)

    # Draw order matters: it fixes how the global RNG stream is consumed.
    x_ac, C_ac = generate_covariance(origin, dims, df)
    x_c, C_c = generate_covariance(origin, dims, df)
    x_bc, C_bc = generate_covariance(origin, dims, df)

    # Information matrices, each computed once and reused below.
    info_ac = LA.inv(C_ac)
    info_c = LA.inv(C_c)
    info_bc = LA.inv(C_bc)

    C_a = LA.inv(info_ac + info_c)
    C_b = LA.inv(info_bc + info_c)
    x_a = C_a @ (info_ac @ x_ac + info_c @ x_c)
    x_b = C_b @ (info_bc @ x_bc + info_c @ x_c)

    C_fus = LA.inv(LA.inv(C_a) + LA.inv(C_b) - info_c)
    x_fus = C_fus @ (info_ac @ x_ac + info_bc @ x_bc + info_c @ x_c)

    row_a = x_a.reshape(1, dims)
    row_b = x_b.reshape(1, dims)
    return row_a, row_b, C_a, C_b, C_fus, x_fus

def get_critical_value(dimensions, alpha):
    """Upper-``alpha`` critical value of the chi-square distribution.

    Returns the ``1 - alpha`` quantile of chi2 with ``dimensions`` degrees of
    freedom.
    """
    confidence = 1 - alpha
    return chi2.ppf(confidence, df=dimensions)


# 2 degrees of freedom, 5% significance. Presumably a gating threshold for
# 2-D squared distances used later; confirm at call sites.
eta = get_critical_value(2, 0.05)

In [3]:
# Seed NumPy's global RNG so the sampled scenario (and every cell after it) is
# reproducible. generate_covariance draws from it via scipy's invwishart and
# np.random.multivariate_normal.
SEED = 0
np.random.seed(SEED)

df = 100  # forwarded by get() to generate_covariance
x_a, x_b, C_a, C_b, C_fus, t_x_fus = get(2, df)
# get() already returns x_a and x_b as (1, 2) rows; the former reshapes were no-ops.
assert x_a.shape == (1, 2) and x_b.shape == (1, 2)
# Float initial guess; presumably the x0 for scipy.optimize.minimize later on.
S_0 = np.array([0.0, 0.0])

In [4]:
def objective(S):
    """Squared Euclidean norm of a 2-element array ``S``.

    ``S`` is reshaped to a 1x2 row (raising if it does not have exactly two
    elements). ``trace(S.T @ S)`` for such a row is the sum of the squared
    entries, which is what is computed here.
    """
    row = S.reshape(1, 2)
    return np.sum(row * row)
    