In [1]:
import numpy as np
from scipy import linalg
from scipy.stats import mannwhitneyu

In [139]:
def shuffle(x):
    """Return a randomly permuted copy of ``x``, leaving ``x`` itself untouched.

    The permutation uses NumPy's global random state (``np.random.shuffle``),
    so it is reproducible under ``np.random.seed``.
    """
    permuted = x.copy()  # permute a copy so the caller's data is not modified in place
    np.random.shuffle(permuted)
    return permuted

def _residual_linear_fit(x, cov):
    coeffX = linalg.lstsq(cov.T, x)[0]
    return x - coeffX.dot(cov)

def r2(trg, srcLst):
    """Fraction of the variance of ``trg`` explained by a linear fit on ``srcLst``.

    Parameters
    ----------
    trg : ndarray
        Target samples.
    srcLst : sequence of ndarray
        Regressors; presumably each has the same length as ``trg`` — not checked.
        They are stacked row-wise and passed to ``_residual_linear_fit``, which
        fits no intercept.

    Returns
    -------
    float
        ``1 - var(residual) / var(trg)``. A constant ``trg`` gives a
        division by zero.
    """
    regressors = np.array(srcLst)  # one row per source
    unexplained = np.var(_residual_linear_fit(trg, regressors))
    return 1 - unexplained / np.var(trg)

def _add_noise(n, x, y, z, sigX=1, sigY=1, sigZ=1):
    xNew = x + np.random.normal(0, sigX, n)
    yNew = y + np.random.normal(0, sigY, n)
    zNew = z + np.random.normal(0, sigZ, n)
    return xNew, yNew, zNew

def gen_data_xor_noisy(n=1000, sigX=1, sigY=1, sigZ=1):
    """Sample a continuous "XOR-like" triplet (x, y, z) with additive noise.

    ``x`` and ``y`` are independent standard normals and ``z = x * y``.
    Gaussian noise with standard deviations ``sigX``, ``sigY``, ``sigZ`` is then
    added via ``_add_noise``.

    Returns
    -------
    tuple
        ``(x, y, z)``, each with ``n`` samples.
    """
    xClean = np.random.normal(0, 1, n)
    yClean = np.random.normal(0, 1, n)
    # For independent zero-mean x, y the product is uncorrelated with each
    # source alone, yet fully determined by the pair.
    zClean = xClean * yClean
    return _add_noise(n, xClean, yClean, zClean, sigX=sigX, sigY=sigY, sigZ=sigZ)

def quadratic_triplet_decomp_1D(src1, src2, trg):
    """Split the variance of ``trg`` explained by two sources into four terms.

    All inputs are mean-centred. The centred target is then regressed (via
    ``r2``, which fits no intercept) on subsets of the regressors
    ``{c1, c2, c12}``, where ``c12`` is the centred product ``c1 * c2``. The
    resulting explained-variance fractions are combined as:

    * ``unq1``  = R2(c1, c2, c12) - R2(c2, c12)   -- loss from dropping src1
    * ``unq2``  = R2(c1, c2, c12) - R2(c1, c12)   -- loss from dropping src2
    * ``red12`` = R2(c1) + R2(c2) - R2(c1, c2)    -- overlap of the linear fits
    * ``syn12`` = R2(c1, c2, c12) - R2(c1, c2)    -- gain from the product term

    (An earlier variant defined the unique terms as R2(c1, c2) - R2(c2) and
    R2(c1, c2) - R2(c1), and synergy against a purely linear model.)

    Parameters
    ----------
    src1, src2, trg : array-like of float
        Samples of the two sources and the target; assumed 1-D and of equal
        length — not checked here.

    Returns
    -------
    list of float
        ``[unq1, unq2, red12, syn12]``.
    """
    # First centred moments of each variable.
    c1 = src1 - np.mean(src1)
    c2 = src2 - np.mean(src2)
    cTrg = trg - np.mean(trg)
    # The product of two centred variables is not itself centred: its mean is
    # the sample covariance of src1 and src2. Since ``r2`` fits no intercept,
    # an uncentred regressor forces its coefficient to also absorb that mean,
    # which can only lower every R2 that includes it. Centre it as well.
    c12 = c1 * c2
    c12 = c12 - np.mean(c12)

    # Single-source linear fits (used for redundancy).
    rev1 = r2(cTrg, [c1])
    rev2 = r2(cTrg, [c2])

    # Full model and its three drop-one sub-models.
    revFull = r2(cTrg, [c1, c2, c12])
    revM1 = r2(cTrg, [c2, c12])   # without src1
    revM2 = r2(cTrg, [c1, c12])   # without src2
    revM12 = r2(cTrg, [c1, c2])   # without the product term

    unq1 = revFull - revM1
    unq2 = revFull - revM2
    syn12 = revFull - revM12
    red12 = rev1 + rev2 - revM12

    return [unq1, unq2, red12, syn12]

In [140]:
# Monte-Carlo comparison of the redundancy term on real vs. shuffled triplets.
N_TRIALS = 400     # number of repetitions
N_SAMPLES = 1000   # samples per synthetic dataset
SEED = 42          # fixed seed so the cell is reproducible and idempotent
RED_IDX = 2        # position of red12 in quadratic_triplet_decomp_1D's output

np.random.seed(SEED)
trueLst = []
randLst = []
for _ in range(N_TRIALS):
    x, y, z = gen_data_xor_noisy(n=N_SAMPLES, sigX=0, sigY=0, sigZ=0)

    trueLst.append(quadratic_triplet_decomp_1D(x, y, z)[RED_IDX])
    # Shuffling z destroys its dependence on (x, y), giving a null sample.
    randLst.append(quadratic_triplet_decomp_1D(x, y, shuffle(z))[RED_IDX])

In [141]:
# One-sided Mann-Whitney U test: is redundancy on the real triplets stochastically
# larger than on the shuffled null? The cell displays the p-value.
mannwhitneyu(trueLst, randLst, alternative='greater').pvalue

0.5639951937514551

In [142]:
# Work on arrays explicitly: `trueLst >= q` on a plain list only works because
# numpy's scalar reflects the comparison, and it raises TypeError if q is a float.
trueArr = np.asarray(trueLst)
randArr = np.asarray(randLst)

print(np.mean(trueArr))
print(np.mean(randArr))
print(np.std(trueArr))
print(np.std(randArr))

# Empirical one-sided check: fraction of real-data redundancies at or above
# the 99th percentile of the shuffled null distribution.
q = np.quantile(randArr, 0.99)
print(q, np.mean(trueArr >= q))

-1.0719760388436251e-05
-2.6757387264572662e-06
0.0001750191322554168
5.703592505083294e-05
0.00014030129967013456 0.1
