Interaction information

In [1]:
%pip install feature-engine

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import TruncatedSVD
from sklearn.utils.extmath import randomized_svd
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
import math
import pandas as pd
import pandas as pd
from sklearn.metrics import mutual_info_score
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
from feature_engine.encoding import RareLabelEncoder, OrdinalEncoder

In [3]:
import numpy as np
import pandas as pd
from scipy.special import digamma


Functions

In [4]:
def getPairwiseDistArray(data, coords = [], discrete_dist = 1):
    '''
    Input: 
    data: pandas data frame
    coords: list of indices for variables to be used
    discrete_dist: distance to be used for non-numeric differences
    Output:
    p x n x n array with pairwise distances for each variable
    '''
    n, p = data.shape
    if coords == []:
        coords = range(p)
    col_names = list(data)
    distArray = np.empty([p,n,n])
    distArray[:] = np.nan
    for coord in coords:
        thisdtype=data[col_names[coord]].dtype
        if pd.api.types.is_numeric_dtype(thisdtype):
            distArray[coord,:,:] = abs(data[col_names[coord]].to_numpy() -
                                       data[col_names[coord]].to_numpy()[:,None])
        else:
            distArray[coord,:,:] = (1 - (data[col_names[coord]].to_numpy() ==
                                    data[col_names[coord]].to_numpy()[:,None])) * discrete_dist
    return distArray

def getPointCoordDists(distArray, ind_i, coords = list()):
    '''
    Input: 
    ind_i: current observation row index
    distArray: output from getPariwiseDistArray
    coords: list of variable (column) indices
    output: n x p matrix of all distancs for row ind_i
    '''
    if not coords:
        coords = range(distArray.shape[0])
    obsDists = np.transpose(distArray[coords, :, ind_i])
    return obsDists

def countNeighbors(coord_dists, rho, coords = list()):
    '''
    input: list of coordinate distances (output of coordDistList), 
    coordinates we want (coords), distance (rho)
    output: scalar integer of number of points within ell infinity radius
    '''
    
    if not coords:
        coords = range(coord_dists.shape[1])
    dists = np.max(coord_dists[:,coords], axis = 1)
    count = np.count_nonzero(dists <= rho) - 1
    return count

def getKnnDist(distArray, k):
    '''
    input:
    distArray: numpy 2D array of pairwise, coordinate wise distances,
    output from getPairwiseDistArray
    k: nearest neighbor value
    
    output: (k, distance to knn)
    '''
    dists = np.max(distArray, axis = 1)
    ordered_dists = np.sort(dists)
    # using k, not k-1, here because this includes dist to self
    k_tilde = np.count_nonzero(dists <= ordered_dists[k]) - 1
    return k_tilde, ordered_dists[k]

def cmiPoint(point_i, x, y, z, k, distArray):
    '''
    input:
    point_i: current observation row index
    x, y, z: list of indices
    k: positive integer scalar for k in knn
    distArray: output of getPairwiseDistArray
    output:
    cmi point estimate
    '''
    n = distArray.shape[1]
    coord_dists = getPointCoordDists(distArray, point_i, x + y + z)
    k_tilde, rho = getKnnDist(coord_dists, k)
    x_coords = list(range(len(x)))
    y_coords = list(range(len(x), len(x+y)))
    z_coords = list(range(len(x+y), len(x+y+z)))
    nxz = countNeighbors(coord_dists, rho, x_coords + z_coords)
    nyz = countNeighbors(coord_dists, rho, y_coords + z_coords)
    nz = countNeighbors(coord_dists, rho, z_coords)
    xi = digamma(k_tilde) - digamma(nxz) - digamma(nyz) + digamma(nz)
    return xi

def miPoint(point_i, x, y, k, distArray):
    '''
    input:
    point_i: current observation row index
    x, y: list of indices
    k: positive integer scalar for k in knn
    distArray: output of getPairwiseDistArray
    output:
    mi point estimate
    '''
    n = distArray.shape[1]
    coord_dists = getPointCoordDists(distArray, point_i, x + y)
    k_tilde, rho = getKnnDist(coord_dists, k)
    x_coords = list(range(len(x)))
    y_coords = list(range(len(x), len(x+y)))
    nx = countNeighbors(coord_dists, rho, x_coords)
    ny = countNeighbors(coord_dists, rho, y_coords)
    xi = digamma(k_tilde) + digamma(n) - digamma(nx) - digamma(ny)
    return xi
    
def cmi(x, y, z, k, data, discrete_dist = 1, minzero = 1):
    '''
    computes conditional mutual information, I(x,y|z)
    input:
    x: list of indices for x
    y: list of indices for y
    z: list of indices for z
    k: hyper parameter for kNN
    data: pandas dataframe
    output:
    scalar value of I(x,y|z)
    '''
    # compute CMI for I(x,y|z) using k-NN
    n, p = data.shape

    # convert variable to index if not already
    vrbls = [x,y,z]
    for i, lst in enumerate(vrbls):
        if all(type(elem) == str for elem in lst) and len(lst) > 0:
            vrbls[i] = list(data.columns.get_indexer(lst))
    x,y,z = vrbls
            
    distArray = getPairwiseDistArray(data, x + y + z, discrete_dist)
    if len(z) > 0:
        ptEsts = map(lambda obs: cmiPoint(obs, x, y, z, k, distArray), range(n))
    else:
        ptEsts = map(lambda obs: miPoint(obs, x, y, k, distArray), range(n))
    if minzero == 1:
        return(max(sum(ptEsts)/n,0))
    elif minzero == 0:
        return(sum(ptEsts)/n)

Loading data 

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
dt1 = pd.read_csv(r'drive/MyDrive/mutualinformationdataset/Lorenz.csv').iloc[:10000]
dt2 = pd.read_csv(r'drive/MyDrive/mutualinformationdataset/Rossler.csv').iloc[:10000]
dt3 = pd.read_csv(r'drive/MyDrive/mutualinformationdataset/ECG.csv')
dt4 = pd.read_csv(r'drive/MyDrive/mutualinformationdataset/random.csv')
dt5 = pd.read_csv(r'drive/MyDrive/mutualinformationdataset/sinx.csv')

Lorenz

In [22]:
dt1

Unnamed: 0.1,Unnamed: 0,X,Y,Z,T
0,0,0.500000,1.000000,0.050000,0.00
1,1,0.550000,1.129750,0.053666,0.01
2,2,0.607975,1.272157,0.058449,0.02
3,3,0.674393,1.429313,0.064624,0.03
4,4,0.749885,1.603415,0.072540,0.04
...,...,...,...,...,...
9995,9995,-5.744281,-4.953745,25.005358,99.95
9996,9996,-5.665227,-5.076228,24.623022,99.96
9997,9997,-5.606327,-5.216780,24.253906,99.97
9998,9998,-5.567372,-5.374630,23.899524,99.98


In [15]:
imlorenz = cmi(['X'],['Y'],['Z'], 4, dt1) - mutual_info_regression(dt1['X'].to_frame(), dt1['Y'], discrete_features=[False])

In [16]:
imlorenz

array([1.99762159])

Rossler

In [19]:
imrossler = cmi(['X'],['Y'],['Z'], 4, dt2) - mutual_info_regression(dt2['X'].to_frame(), dt2['Y'], discrete_features=[False])

In [20]:
imrossler


array([0.45073509])

ECG

In [23]:
ecgresults = {}
for i in range(1,13):
  for j in range(i+1,13):
    for k in range(j+1,13):
      ecgresults[str(i) + "Vs" + str(j) + "Vs" + str(k)] =  cmi([str(i)],[str(j)],[str(k)], 4, dt3) -mutual_info_regression(dt3[str(i)].to_frame(), dt3[str(j)], discrete_features=[False])[0]


In [24]:
ecgresults

{'1Vs2Vs3': 0.7152484661222913,
 '1Vs2Vs4': 0.13507854697326538,
 '1Vs2Vs5': 0.08692394537487447,
 '1Vs2Vs6': -0.8044119701678978,
 '1Vs2Vs7': -0.5619354982768365,
 '1Vs2Vs8': -0.8150895967034815,
 '1Vs2Vs9': -0.8283152859602442,
 '1Vs2Vs10': -0.8726633089071721,
 '1Vs2Vs11': -0.9911976011207138,
 '1Vs2Vs12': -1.0551672448811469,
 '1Vs3Vs4': 0.09919714787089862,
 '1Vs3Vs5': 0.5279006218138171,
 '1Vs3Vs6': -0.014997182155791267,
 '1Vs3Vs7': -0.5024985018397914,
 '1Vs3Vs8': -0.6125335035835053,
 '1Vs3Vs9': -0.6352796583137179,
 '1Vs3Vs10': -0.6355819650944927,
 '1Vs3Vs11': -0.6189088967007077,
 '1Vs3Vs12': -0.6256126816320321,
 '1Vs4Vs5': 0.1447113685660648,
 '1Vs4Vs6': -0.29594793390695084,
 '1Vs4Vs7': -0.7221914758622264,
 '1Vs4Vs8': -1.0216203864767694,
 '1Vs4Vs9': -0.9749941078879165,
 '1Vs4Vs10': -1.001245632434657,
 '1Vs4Vs11': -1.0067386879174345,
 '1Vs4Vs12': -0.9680390821460367,
 '1Vs5Vs6': -0.33628526608670484,
 '1Vs5Vs7': -0.4380510780611251,
 '1Vs5Vs8': -0.554372038654623,
 '