In [1]:
import numpy as np

vals = [0, 1]

# Every combination of the two levels across the three features (8 rows).
design_rows = [[a, b, c] for a in vals for b in vals for c in vals]

# Repeat the all-vals[0] row and the all-vals[1] row once more, so those two
# corners appear twice and the design has 10 rows in total.
corner_lo, corner_hi = vals[0], vals[1]
design_rows.append([corner_lo] * 3)
design_rows.append([corner_hi] * 3)

X = np.array(design_rows)

# Target: product of the first two features plus the third feature.
Y = X[:, 2] + X[:, 0] * X[:, 1]

## Fit intercept.

In [2]:
# The intercept is the mean of the targets; whatever it does not explain is
# the residual that the later cells fit.
f0 = Y.mean()
residual = Y - f0
f0

0.8

## Fit main effects.

In [3]:
def fit_main(residual, idx, features=None, levels=None):
    """Fit a single main effect to the current residual.

    For each level of feature column ``idx`` the effect is the mean of the
    residual over the samples taking that level.

    Parameters
    ----------
    residual : array-like of shape (n_samples,)
        Current residual to be explained.
    idx : int
        Column of the feature matrix to fit.
    features : array-like of shape (n_samples, n_features), optional
        Feature matrix. Defaults to the notebook-level ``X``.
    levels : sequence, optional
        Values each feature can take. Defaults to the notebook-level ``vals``.

    Returns
    -------
    f : np.ndarray of shape (n_levels,)
        Mean residual per level, ordered like ``levels``.
    f_preds : np.ndarray of shape (n_samples,)
        Per-sample prediction of the fitted effect.
    residual : np.ndarray of shape (n_samples,)
        Input residual minus ``f_preds``.
    """
    features = X if features is None else np.asarray(features)
    levels = vals if levels is None else list(levels)
    residual = np.asarray(residual, dtype=float)

    column = features[:, idx]
    f = np.array([np.mean(residual[column == level]) for level in levels])
    # Return an ndarray rather than a list: callers accumulate predictions
    # with `+=`, which concatenates lists but adds arrays element-wise.
    f_preds = np.array([f[levels.index(x)] for x in column])
    residual = residual - f_preds
    return f, f_preds, residual

# Iterative algorithm, set the tolerance here.
# Smaller tolerance values result in smaller errors.
tol = 1e-10

updates = np.inf
residual = Y - f0
residual_prev = residual

# First sweep: fit each main effect in turn to what the previous ones left over.
f1, f1_preds, residual = fit_main(residual, 0)
f2, f2_preds, residual = fit_main(residual, 1)
f3, f3_preds, residual = fit_main(residual, 2)

# fit_main may hand back plain Python lists. Hold the running predictions as
# float arrays so that `+=` below adds element-wise instead of concatenating.
f1_preds = np.array(f1_preds, dtype=float)
f2_preds = np.array(f2_preds, dtype=float)
f3_preds = np.array(f3_preds, dtype=float)

# Keep sweeping until the total absolute correction made in a sweep is <= tol.
while updates > tol:
    f1_new, f1_preds_new, residual = fit_main(residual, 0)
    f2_new, f2_preds_new, residual = fit_main(residual, 1)
    f3_new, f3_preds_new, residual = fit_main(residual, 2)
    updates = np.sum(np.abs(f1_preds_new)) + np.sum(np.abs(f2_preds_new)) + np.sum(np.abs(f3_preds_new))
    f1_preds += np.asarray(f1_preds_new)
    f2_preds += np.asarray(f2_preds_new)
    f3_preds += np.asarray(f3_preds_new)
    f1 += f1_new
    f2 += f2_new
    f3 += f3_new

## Fit pairwise interactions.

In [4]:
def fit_pair(residual, idx1, idx2, features=None, levels=None):
    """Fit a single pairwise effect to the current residual.

    For each combination of levels of feature columns ``idx1`` and ``idx2``
    the effect is the mean of the residual over the samples in that cell.

    Parameters
    ----------
    residual : array-like of shape (n_samples,)
        Current residual to be explained.
    idx1, idx2 : int
        Columns of the feature matrix forming the pair.
    features : array-like of shape (n_samples, n_features), optional
        Feature matrix. Defaults to the notebook-level ``X``.
    levels : sequence, optional
        Values each feature can take. Defaults to the notebook-level ``vals``.

    Returns
    -------
    f : np.ndarray of shape (n_levels, n_levels)
        Mean residual per cell; rows follow ``idx1`` levels, columns ``idx2``.
    f_preds : np.ndarray of shape (n_samples,)
        Per-sample prediction of the fitted effect.
    residual : np.ndarray of shape (n_samples,)
        Input residual minus ``f_preds``.
    """
    features = X if features is None else np.asarray(features)
    levels = vals if levels is None else list(levels)
    residual = np.asarray(residual, dtype=float)

    col1 = features[:, idx1]
    col2 = features[:, idx2]
    f = np.array([
        [np.mean(residual[np.logical_and(col1 == a, col2 == b)]) for b in levels]
        for a in levels
    ])
    # Return an ndarray rather than a list: callers accumulate predictions
    # with `+=`, which concatenates lists but adds arrays element-wise.
    f_preds = np.array([f[levels.index(xi), levels.index(xj)] for xi, xj in zip(col1, col2)])
    residual = residual - f_preds
    return f, f_preds, residual

updates = np.inf
residual_prev = residual

# First sweep: fit each pairwise effect in turn to the remaining residual.
f12, f12_preds, residual = fit_pair(residual, 0, 1)
f13, f13_preds, residual = fit_pair(residual, 0, 2)
f23, f23_preds, residual = fit_pair(residual, 1, 2)

# fit_pair may hand back plain Python lists. Hold the running predictions as
# float arrays so that `+=` below adds element-wise instead of concatenating.
f12_preds = np.array(f12_preds, dtype=float)
f13_preds = np.array(f13_preds, dtype=float)
f23_preds = np.array(f23_preds, dtype=float)

# Keep sweeping until the total absolute correction made in a sweep is <= tol.
while updates > tol:
    f12_new, f12_preds_new, residual = fit_pair(residual, 0, 1)
    f13_new, f13_preds_new, residual = fit_pair(residual, 0, 2)
    f23_new, f23_preds_new, residual = fit_pair(residual, 1, 2)
    updates = np.sum(np.abs(f12_preds_new)) + np.sum(np.abs(f13_preds_new)) + np.sum(np.abs(f23_preds_new))
    f12_preds += np.asarray(f12_preds_new)
    f13_preds += np.asarray(f13_preds_new)
    f23_preds += np.asarray(f23_preds_new)
    f12 += f12_new
    f13 += f13_new
    f23 += f23_new

## Evaluate the residual.

In [5]:
# Whatever is left in the residual after the fits above is the per-sample error.
errors = residual
print(errors)

# The largest absolute error must be below the tolerance.
worst_error = np.abs(errors).max()
assert worst_error < tol
print("Assertion Passed: All errors are smaller than tolerance, which was set to: {}".format(tol))

[ 5.28269701e-19 -2.68132567e-17 -2.86336455e-17 -9.70874044e-19
 -1.05653940e-18  2.68132567e-17  2.86336455e-17  4.85437022e-19
  5.28269701e-19  4.85437022e-19]
Assertion Passed: All errors are smaller than tolerance, which was set to: 1e-10


In [6]:
print("Learned functions:")
# Label/value pairs, printed one per line: intercept, mains, then pairs.
learned_terms = [
    ("f0", f0),
    ("f1", f1),
    ("f2", f2),
    ("f3", f3),
    ("f12", f12),
    ("f13", f13),
    ("f23", f23),
]
for label, term in learned_terms:
    print(label, term)

Learned functions:
f0 0.8
f1 [-0.25  0.25]
f2 [-0.25  0.25]
f3 [-0.5  0.5]
f12 [[ 0.2 -0.3]
 [-0.3  0.2]]
f13 [[ 7.14771178e-14 -1.07215677e-13]
 [ 1.07198843e-13 -7.14658956e-14]]
f23 [[-1.18859112e-14 -1.78791310e-14]
 [ 1.78288668e-14  1.19194206e-14]]


## Post-hoc Purification

In [7]:
# https://github.com/blengerich/gam_purification
from gam_purification.purify import purify
def make_density(idx1, idx2, features=None, levels=None):
    """Count how many samples fall in each cell of a feature pair.

    Parameters
    ----------
    idx1, idx2 : int
        Columns of the feature matrix forming the pair.
    features : array-like of shape (n_samples, n_features), optional
        Feature matrix. Defaults to the notebook-level ``X``.
    levels : sequence, optional
        Values each feature can take. Defaults to the notebook-level ``vals``.

    Returns
    -------
    np.ndarray of shape (n_levels, n_levels)
        Sample counts (as floats); rows follow ``idx1`` levels, columns
        follow ``idx2`` levels.
    """
    features = X if features is None else np.asarray(features)
    levels = vals if levels is None else list(levels)
    # Size the table from the number of levels instead of assuming two.
    densities = np.zeros((len(levels), len(levels)))
    for xi, xj in features[:, [idx1, idx2]]:
        densities[levels.index(xi), levels.index(xj)] += 1
    return densities


# For each pair term below, the first value returned by purify() is added to
# the intercept, the second and third are added to the two main effects of the
# pair, and the fourth is kept as the purified pair term. The fifth returned
# value is not used here.

# Purify f12 into f1 and f2
density_12 = make_density(0, 1)
shift_12, to_f1, to_f2, f12_pure, _unused = purify(f12.copy(), densities=density_12)
f0 += shift_12
f1 += to_f1
f2 += to_f2

# Purify f13 into f1 and f3
density_13 = make_density(0, 2)
shift_13, to_f1, to_f3, f13_pure, _unused = purify(f13.copy(), densities=density_13)
f0 += shift_13
f1 += to_f1
f3 += to_f3

# Purify f23 into f2 and f3
density_23 = make_density(1, 2)
shift_23, to_f2, to_f3, f23_pure, _unused = purify(f23.copy(), densities=density_23)
f0 += shift_23
f2 += to_f2
f3 += to_f3

In [8]:
print("Learned pure functions:")
# Same labels as before purification; the pair labels now show the purified terms.
pure_terms = [
    ("f0", f0),
    ("f1", f1),
    ("f2", f2),
    ("f3", f3),
    ("f12", f12_pure),
    ("f13", f13_pure),
    ("f23", f23_pure),
]
for label, term in pure_terms:
    print(label, term)

Learned pure functions:
f0 0.8
f1 [-0.25  0.25]
f2 [-0.25  0.25]
f3 [-0.5  0.5]
f12 [[ 0.2 -0.3]
 [-0.3  0.2]]
f13 [[-1.42886903e-14 -2.14498686e-14]
 [ 2.14330354e-14  1.42999124e-14]]
f23 [[-4.59351921e-16 -7.39292030e-16]
 [ 6.89027881e-16  4.92861353e-16]]
