# Mathematical Underpinnings - Lab 6

In [2]:
from sklearn.metrics import mutual_info_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm

## Useful functions

In [3]:
def discetize_2bins(X):
    """Binarise X at zero: 1 where X >= 0, 0 where X < 0.

    (The name is a misspelling of "discretize"; it is kept so existing calls work.)

    Parameters
    ----------
    X : numpy.ndarray or scalar
        Continuous values to threshold.

    Returns
    -------
    Integer 0/1 values with the same shape as ``X >= 0``.
    """
    is_nonnegative = X >= 0
    return is_nonnegative * 1

In [4]:
def conditional_permutation(X, Z, rng=None):
    """Permute X independently within each stratum of equal Z values.

    Shuffling only among positions that share the same Z value keeps the joint
    sample of (X, Z) intact while breaking any further dependence of X on
    other variables given Z -- the resampling step of a conditional
    permutation test.

    Parameters
    ----------
    X : array-like, shape (n,)
        Values to shuffle (numeric or non-numeric labels).
    Z : array-like, shape (n,)
        Conditioning labels; X is only permuted among positions with equal Z.
    rng : numpy.random.Generator, optional
        Source of randomness. If None (default), the global ``np.random``
        state is used, exactly as before.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Conditionally permuted copy of X, with X's dtype preserved.
    """
    X = np.asarray(X)
    Z = np.asarray(Z)
    permute = np.random.permutation if rng is None else rng.permutation

    # Allocate with X's dtype: the former np.zeros(n) silently turned integer
    # labels into floats and could not hold non-numeric labels at all.
    # Every position is overwritten below because each Z value forms a stratum.
    X_b = np.empty_like(X)
    for z_value in np.unique(Z):
        in_stratum = Z == z_value
        X_b[in_stratum] = permute(X[in_stratum])
    return X_b

In [5]:
def conditional_mutual_information(X, Y, Z):
    """Plug-in estimate of the conditional mutual information I(X;Y|Z).

    Computes sum over levels z of P(Z=z) * I(X;Y | Z=z). P(Z=z) is the
    empirical frequency of z, and I(X;Y | Z=z) is ``mutual_info_score``
    evaluated on the subsample where Z == z.

    Parameters
    ----------
    X, Y : numpy.ndarray, shape (n,)
        Discrete samples.
    Z : numpy.ndarray, shape (n,)
        Discrete conditioning labels (one label per observation).

    Returns
    -------
    float
    """
    sample_size = len(Z)
    estimate = 0
    for level in np.unique(Z):
        mask = (Z == level)
        weight = np.sum(mask)/sample_size
        estimate += weight*mutual_info_score(X[mask], Y[mask])
    return estimate

In [6]:
# II(X;Y;Z)
def interaction_information(X, Y, Z):
    """Interaction information II(X;Y;Z) = I(X;Y|Z) - I(X;Y).

    The value is positive when conditioning on Z increases the estimated mutual
    information between X and Y, and negative when it decreases it.

    Parameters
    ----------
    X, Y, Z : numpy.ndarray, shape (n,)
        Discrete samples.

    Returns
    -------
    float
    """
    conditional = conditional_mutual_information(X, Y, Z)
    marginal = mutual_info_score(X, Y)
    return conditional - marginal

In [7]:
# II(X;Y;Z1;Z2)
def interaction_information2(X, Y, Z1, Z2):
    """Fourth-order interaction information II(X;Y;Z1;Z2).

    Using II(X;Y;Z) = I(X;Y|Z) - I(X;Y), the returned value equals
        I(X;Y|Z1,Z2) - I(X;Y|Z1) - I(X;Y|Z2) + I(X;Y).

    Parameters
    ----------
    X, Y : numpy.ndarray, shape (n,)
        Discrete samples.
    Z1, Z2 : numpy.ndarray, shape (n,)
        Discrete conditioning variables with any number of levels.

    Returns
    -------
    float
    """
    # Merge (Z1, Z2) into one discrete variable. Working with level indices keeps
    # every distinct pair distinct for any number of levels; the previous
    # 2*Z2 + Z1 collided as soon as Z1 was not 0/1 (e.g. (2, 0) vs (0, 1)).
    # For binary inputs the induced partition -- and hence the result -- is unchanged.
    z1_levels, z1_codes = np.unique(Z1, return_inverse=True)
    _, z2_codes = np.unique(Z2, return_inverse=True)
    Z_1_and_2 = z2_codes.ravel()*len(z1_levels) + z1_codes.ravel()
    return (interaction_information(X, Y, Z_1_and_2)
            - interaction_information(X, Y, Z1)
            - interaction_information(X, Y, Z2))

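## Task 1

The statistics are truncations of the expansion of $I(X;Y \mid Z_1,\dots,Z_k)$ into interaction informations, where $II(X;Y;Z) = I(X;Y \mid Z) - I(X;Y)$:
$\mathrm{SECMI}_2 = I(X;Y) + \sum_i II(X;Y;Z_i)$ and $\mathrm{SECMI}_3 = \mathrm{SECMI}_2 + \sum_{i<j} II(X;Y;Z_i;Z_j)$.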

In [8]:
def secmi2(X, Y, Z):
    """Second-order truncation of the interaction expansion of I(X;Y|Z).

    Returns I(X;Y) + sum_i II(X;Y;Z_i), summed over the columns Z_i of Z.

    Parameters
    ----------
    X, Y : numpy.ndarray, shape (n,)
        Discrete samples.
    Z : numpy.ndarray, shape (n, k)
        Discrete conditioning variables, one per column.

    Returns
    -------
    float
    """
    pairwise_terms = sum(interaction_information(X, Y, Z[:, col])
                         for col in range(Z.shape[1]))
    return mutual_info_score(X, Y) + pairwise_terms


def secmi3(X, Y, Z):
    """Third-order truncation: secmi2 plus II(X;Y;Z_i;Z_j) over all column pairs j < i.

    For a Z with exactly two columns this equals I(X;Y|Z_1,Z_2), since the
    expansion is then complete.

    Parameters
    ----------
    X, Y : numpy.ndarray, shape (n,)
        Discrete samples.
    Z : numpy.ndarray, shape (n, k)
        Discrete conditioning variables, one per column.

    Returns
    -------
    float
    """
    n_cols = Z.shape[1]
    higher_terms = sum(interaction_information2(X, Y, Z[:, i], Z[:, j])
                       for i in range(n_cols)
                       for j in range(i))
    return secmi2(X, Y, Z) + higher_terms

### a)

In [9]:
def cond_indep_test_permutation(X, Y, Z, B, stat):
    """Conditional permutation test of H0: X independent of Y given Z.

    The statistic T is computed on the data. It is then recomputed on B
    copies in which X is permuted within strata of Z (``conditional_permutation``).

    Parameters
    ----------
    X, Y : numpy.ndarray, shape (n,)
        Discrete samples.
    Z : numpy.ndarray, shape (n, k) or (n,)
        Discrete conditioning variables, one per column; a 1-D array is
        treated as a single column.
    B : int
        Number of permutations.
    stat : {"cmi", "secmi2", "secmi3"}
        "cmi" uses ``conditional_mutual_information`` with all columns of Z
        merged into one discrete variable; "secmi2"/"secmi3" use the truncated
        interaction expansions on the columns of Z.

    Returns
    -------
    (float, float)
        ``2*n*T`` and the permutation p-value ``(1 + #{b : T_b >= T}) / (1 + B)``.

    Raises
    ------
    ValueError
        If ``stat`` is not one of the supported names.
    """
    supported = ("cmi", "secmi2", "secmi3")
    if stat not in supported:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(f"stat must be one of {supported}, got {stat!r}")

    Z = np.asarray(Z)
    if Z.ndim == 1:
        Z = Z[:, np.newaxis]

    # One discrete label per distinct row of Z. Unlike the former weighting
    # sum_j 2**j * Z_j, this never merges distinct rows when columns are not 0/1.
    # Columns are reversed so that, for binary Z, the labels sort in the same
    # order as the old binary codes (strata are then visited in the same order).
    _, Z_1dim = np.unique(Z[:, ::-1], axis=0, return_inverse=True)
    Z_1dim = Z_1dim.ravel()

    def compute_stat(x):
        """Selected statistic with X replaced by x."""
        if stat == "cmi":
            return conditional_mutual_information(x, Y, Z_1dim)
        if stat == "secmi2":
            return secmi2(x, Y, Z)
        return secmi3(x, Y, Z)

    stat_value = compute_stat(X)

    condition_p_value = 0
    for _ in range(B):
        X_b = conditional_permutation(X, Z_1dim)
        if stat_value <= compute_stat(X_b):
            condition_p_value += 1

    # The +1 in numerator and denominator counts the observed data as one permutation.
    p_value = (1 + condition_p_value)/(1 + B)

    return 2*len(X)*stat_value, p_value

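### b)

Data: $Y \sim N(0,1)$ and $\tilde Y = 2\cdot\mathbb{1}(Y>0) - 1$. $Z_1, Z_2, Z_3$ are noisy copies of $\tilde Y$ discretised to $\{0,1\}$, and $X$ is a discretised noisy copy of $Z_1$ only. Therefore $X \perp \tilde Y \mid (Z_1, Z_2)$ holds, so the rejection rate estimates the size. In contrast $X \not\perp \tilde Y \mid (Z_2, Z_3)$, so the rejection rate there estimates the power.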

In [10]:
# Task 1b data: Y ~ N(0, 1) and Y_tilde = +1 if Y > 0 else -1.
# Z1, Z2, Z3 are noisy copies of Y_tilde discretised to {0, 1}; X is a
# discretised noisy copy of Z1 only. Hence X is conditionally independent of
# Y_tilde given (Z1, Z2), but not given (Z2, Z3), which leaves out Z1.
n = 100
rng = np.random.default_rng(123)
Y = rng.normal(0, 1, n)
Y_tilde = 2*(Y > 0) -1
Z1 = discetize_2bins(rng.normal(Y_tilde, 1, n))
Z2 = discetize_2bins(rng.normal(Y_tilde, 1, n))
Z3 = discetize_2bins(rng.normal(Y_tilde, 1, n))
# Conditioning sets as (n, 2) matrices, one variable per column.
Z1_2 = np.vstack([Z1, Z2]).T
Z2_3 = np.vstack([Z2, Z3]).T
X = discetize_2bins(rng.normal(Z1, 1, n))

In [11]:
# One test of X _||_ Y_tilde | (Z2, Z3) with the secmi3 statistic; returns
# (2*n*statistic, permutation p-value). The permutations draw from the global
# np.random state, which is not seeded here, so the p-value varies between runs.
cond_indep_test_permutation(X, Y_tilde, Z2_3, B=100, stat='secmi3')

(8.993979775732782, 0.13861386138613863)

In [12]:
# Monte Carlo study for Task 1b: over N simulated data sets, record the p-value
# of each statistic when conditioning on (Z1, Z2) (H0 true) and on (Z2, Z3)
# (H0 false). Columns are named '<stat>_<conditioning set>'.
N = 100
n = 100
p_crit = 0.05
n_permutations = 50
stat_names = ['secmi2', 'secmi3', 'cmi']
results = []
for i in tqdm(range(N)):
    rng = np.random.default_rng(i)
    # conditional_permutation draws from the global np.random state; seed it as
    # well so that the p-values, not only the simulated data, are reproducible.
    np.random.seed(i)
    Y = rng.normal(0, 1, n)
    Y_tilde = 2*(Y > 0) -1
    Z1 = discetize_2bins(rng.normal(Y_tilde, 1, n))
    Z2 = discetize_2bins(rng.normal(Y_tilde, 1, n))
    Z3 = discetize_2bins(rng.normal(Y_tilde, 1, n))
    Z1_2 = np.vstack([Z1, Z2]).T
    Z2_3 = np.vstack([Z2, Z3]).T
    X = discetize_2bins(rng.normal(Z1, 1, n))
    conditioning_sets = {'z12': Z1_2, 'z23': Z2_3}
    # Same column order as before: all statistics for z12, then for z23.
    results.append({
        f'{stat}_{set_name}': cond_indep_test_permutation(
            X, Y_tilde, Z_set, B=n_permutations, stat=stat)[1]
        for set_name, Z_set in conditioning_sets.items()
        for stat in stat_names
    })
results = pd.DataFrame(results)

In [13]:
# Empirical rejection rate of each test at significance level p_crit:
# an estimate of the size where H0 holds and of the power where it does not.
(results < p_crit).mean(axis=0)

secmi2_z12    0.03
secmi3_z12    0.04
cmi_z12       0.04
secmi2_z23    0.09
secmi3_z23    0.07
cmi_z23       0.13
dtype: float64

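### c)

Data: $X, Z_1, Z_2, Z_3 \sim \mathrm{Bernoulli}(0.5)$ independently, and $Y \sim \mathrm{Bernoulli}\big(0.2 + 0.6\,[(X + Z_1 + Z_2) \bmod 2]\big)$. Given $(Z_1, Z_2)$, the parity is a function of $X$, so $X \not\perp Y \mid (Z_1, Z_2)$. However, $I(X;Y) = II(X;Y;Z_1) = II(X;Y;Z_2) = 0$ in the population, so only the fourth-order term used by $\mathrm{SECMI}_3$ (and CMI) can detect the dependence. Given $(Z_2, Z_3)$, the unobserved fair coin $Z_1$ makes the parity independent of $X$, so $X \perp Y \mid (Z_2, Z_3)$.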

In [18]:
# Task 1c data: X, Z1, Z2, Z3 are independent fair coin flips, and Y depends on
# them only through the parity of X + Z1 + Z2.
rng = np.random.default_rng(123)
n = 100
X = rng.binomial(n=1, p=0.5, size=n)
Z1 = rng.binomial(n=1, p=0.5, size=n)
Z2 = rng.binomial(n=1, p=0.5, size=n)
Z3 = rng.binomial(n=1, p=0.5, size=n)
# P(Y = 1) is 0.8 when X + Z1 + Z2 is odd and 0.2 when it is even.
prob = 0.6 * ((X + Z1 + Z2) % 2) + 0.2
Y = rng.binomial(n=1, p=prob, size=n)

In [27]:
# Test X _||_ Y | (Z1, Z2) with secmi3. Given Z1 and Z2 the parity is a function of X,
# so H0 is false here.
cond_indep_test_permutation(X, Y, np.vstack([Z1, Z2]).T, B=50, stat="secmi3")

(22.01623492018674, 0.0196078431372549)

In [28]:
# Test X _||_ Y | (Z2, Z3) with secmi3. Z1 is a fair coin not conditioned on, so the
# parity is independent of X given (Z2, Z3) and H0 holds here.
cond_indep_test_permutation(X, Y, np.vstack([Z2, Z3]).T, B=50, stat="secmi3")

(3.042810044576404, 0.5490196078431373)

In [29]:
# Monte Carlo study for Task 1c: over N simulated data sets, record the p-value
# of each statistic when conditioning on (Z1, Z2) (H0 false) and on (Z2, Z3)
# (H0 true). Columns are named '<stat>_<conditioning set>'.
N = 100
n = 100
n_permutations = 50
stat_names = ['secmi2', 'secmi3', 'cmi']
results = []
for i in tqdm(range(N)):
    rng = np.random.default_rng(i)
    # conditional_permutation draws from the global np.random state; seed it as
    # well so that the p-values, not only the simulated data, are reproducible.
    np.random.seed(i)
    X = rng.binomial(n=1, p=0.5, size=n)
    Z1 = rng.binomial(n=1, p=0.5, size=n)
    Z2 = rng.binomial(n=1, p=0.5, size=n)
    Z3 = rng.binomial(n=1, p=0.5, size=n)
    prob = 0.6 * ((X + Z1 + Z2) % 2) + 0.2
    Y = rng.binomial(n=1, p=prob, size=n)
    Z1_2 = np.vstack([Z1, Z2]).T
    Z2_3 = np.vstack([Z2, Z3]).T
    conditioning_sets = {'z12': Z1_2, 'z23': Z2_3}
    # Same column order as before: all statistics for z12, then for z23.
    results.append({
        f'{stat}_{set_name}': cond_indep_test_permutation(
            X, Y, Z_set, B=n_permutations, stat=stat)[1]
        for set_name, Z_set in conditioning_sets.items()
        for stat in stat_names
    })
results = pd.DataFrame(results)

100%|██████████| 100/100 [05:51<00:00,  3.51s/it]


In [30]:
# Empirical rejection rate of each test at significance level p_crit (defined
# in the Task 1b simulation cell): power for z12, size for z23.
(results < p_crit).mean(axis=0)

secmi2_z12    0.02
secmi3_z12    1.00
cmi_z12       1.00
secmi2_z23    0.06
secmi3_z23    0.04
cmi_z23       0.05
dtype: float64

## Task 2
 
This task was done in R, so there is no Python code for it in this notebook.