In [1]:
import polars as pl
import numpy as np
from scipy.special import psi
from sklearn.datasets import make_classification
from typing import Any
from time import perf_counter

In [2]:
# Synthetic binary-classification data: 25 informative + 25 redundant continuous features.
# A fixed random_state makes Restart & Run All regenerate the same dataset every time.
orig_x, orig_y = make_classification(
    n_samples=300_000, n_features=50, n_informative=25, n_redundant=25, random_state=42
)
# Label column first, followed by the generated feature columns (column_0 … column_49).
df = pl.concat([pl.Series("target", orig_y).to_frame(), pl.from_numpy(orig_x)], how="horizontal")
target = "target"
features = [col for col in df.columns if col != target]

In [3]:
# Peek at the first rows: the label column followed by the 50 generated features.
df.head(5)

target,column_0,column_1,column_2,column_3,column_4,column_5,column_6,column_7,column_8,column_9,column_10,column_11,column_12,column_13,column_14,column_15,column_16,column_17,column_18,column_19,column_20,column_21,column_22,column_23,column_24,column_25,column_26,column_27,column_28,column_29,column_30,column_31,column_32,column_33,column_34,column_35,column_36,column_37,column_38,column_39,column_40,column_41,column_42,column_43,column_44,column_45,column_46,column_47,column_48,column_49
i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,11.50356,0.429806,1.474147,-1.356332,-1.848187,-3.697524,-3.918588,-1.526723,12.501827,-2.248826,-12.666978,1.177657,0.724407,0.244796,1.946044,11.968625,-0.509509,0.422647,-5.059723,1.196536,3.75947,-1.091892,0.911092,-4.681064,-12.38997,1.215335,4.350664,-6.29382,2.347619,-2.081105,16.346095,-3.482772,12.401306,-5.909624,5.591609,-0.032605,6.804842,-2.77217,1.36524,1.455955,-0.312545,-5.70997,-5.777117,-2.990212,-3.775008,1.420761,1.175697,-6.862175,1.97027,-5.992461
0,-3.105294,-0.341285,-10.083761,4.818264,-2.149998,-1.83556,-2.719747,-3.38598,-5.466626,-2.270568,-17.551761,-4.684044,-6.126515,0.565376,2.941526,11.604482,5.174033,2.292629,-0.757084,-2.819696,1.579771,-3.148015,-5.968023,-9.117488,-1.946583,1.191429,-1.527178,-0.650567,-5.890211,1.160781,2.471245,-17.056095,-13.597424,-4.658944,1.722993,1.900577,-12.831776,1.190638,1.403039,-2.198171,-10.029456,-7.30485,2.632855,5.923424,1.519941,1.437473,-6.648569,-8.862297,-1.66265,19.476192
1,-3.009683,-0.173725,-5.305688,-3.39294,-1.795725,-3.259409,0.678384,-4.396279,-2.811006,-3.422661,3.121645,-0.677429,-1.869744,-0.968596,-4.689661,2.426552,0.338202,2.304795,3.681587,0.875571,-3.472632,-0.761902,6.772799,3.431409,-4.7374,0.065603,-1.793211,0.064225,6.614304,-9.888175,1.146906,-9.402501,-4.477763,-1.21895,-7.361408,0.224716,-2.121848,-11.945049,-0.046274,2.472539,-4.987473,6.968973,-5.760278,1.917678,1.474294,0.002177,5.822436,1.99796,-0.938701,5.247786
0,20.6105,5.974194,14.807029,5.768045,-0.341809,0.548961,2.556123,3.651809,21.635004,2.399569,-2.064209,-0.887618,-2.401863,0.509854,-0.071829,17.227576,-2.89244,-5.433666,-13.93419,4.081027,0.320225,1.197226,13.054601,22.727735,-3.386939,-3.993787,2.821235,4.727874,-2.834722,12.113593,-5.170227,0.694064,16.517866,4.36337,11.002745,-3.866952,0.935343,-0.628372,1.24749,-5.866572,4.820419,-0.890116,-1.971506,0.485948,-1.348603,-1.62773,-11.201212,15.09782,9.227846,1.606904
1,-3.151563,-3.118631,-8.501699,-0.102278,0.916365,0.660685,-1.048504,3.52123,-1.614859,2.445598,1.103768,-5.421418,0.531163,-2.502882,2.615444,-5.664028,-3.346394,2.015558,-5.676821,-4.999961,1.98995,3.001614,3.92288,-8.450276,17.000902,-3.969472,-4.935979,-1.077858,-11.900195,2.367959,1.022803,9.414992,-8.369752,-17.271874,1.510305,1.935489,-7.967335,-2.527683,0.658442,1.020268,4.077512,-17.679637,0.664338,0.413705,2.144528,5.518291,-6.258479,-8.626793,-0.24963,0.672597


In [4]:
from scipy.spatial import KDTree

def estimate_mi_optimized(df:pl.DataFrame, cols:list[str], target:str, k=3, random_state:int=42):
    """Estimate the mutual information between each continuous column in `cols` and the discrete `target`.

    Nearest-neighbour estimator for a continuous variable vs. a discrete label, following the same
    steps as sklearn's `mutual_info_classif` for continuous features, but running the neighbour
    queries through scipy's `KDTree` with `workers=-1`. For every feature column:

    1. add a tiny random jitter to break ties between identical values;
    2. within each target class, use the distance from each point to its k-th nearest same-class
       neighbour (shrunk by one ulp) as that point's radius;
    3. count m_i, the number of points of any class within that radius of point i;
    4. MI = psi(n) + psi(k) - mean(psi(class_size(i))) - mean(psi(m_i)), clipped at 0.

    Args:
        df: frame holding the feature columns and the target column.
        cols: names of the feature columns to score.
        target: name of the label column.
        k: number of same-class neighbours that define each point's radius; must be >= 1.
        random_state: seed for the tie-breaking jitter.

    Returns:
        DataFrame with columns `feature` (Utf8) and `estimated_mi` (Float64): one row per entry
        of `cols`, in the same order as `cols`.

    Raises:
        ValueError: if `k < 1`, or if some target class has `k` or fewer rows (the k-NN query
            needs at least k + 1 points per class).
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}.")

    n = len(df)
    rng = np.random.default_rng(random_state)

    target_col = df.get_column(target).to_numpy().ravel()

    # Class masks and per-row class sizes depend only on the target, so build them once,
    # from the boolean masks themselves, instead of re-filtering the DataFrame or redoing it per feature.
    class_masks = []
    label_counts = np.empty(n)
    for t in np.unique(target_col):
        mask = target_col == t
        class_size = int(np.sum(mask))
        if class_size <= k:
            raise ValueError(f"The target class {t} must have more than {k} values in the dataset.")
        class_masks.append(mask)
        label_counts[mask] = class_size

    # Every term of the estimator except the m_i term is identical for all features.
    feature_independent_term = psi(n) + psi(k) - np.mean(psi(label_counts))

    estimates = []
    for col in cols:
        c = df.get_column(col).to_numpy().reshape(-1, 1)
        # Tie-breaking jitter scaled by max(1, mean|c|), as sklearn does. Scaling by the signed mean
        # would make the jitter vanish for zero-mean columns, leaving ties (and zero radii) in place.
        c = c + 1e-10 * max(1.0, np.mean(np.abs(c))) * rng.standard_normal(size=c.shape)

        radius = np.empty(n)
        for mask in class_masks:
            c_masked = c[mask]
            # Ask for k + 1 neighbours because each query point comes back as its own nearest
            # neighbour (distance 0); the last column is then the distance to the k-th other point.
            dd, _ = KDTree(c_masked).query(c_masked, k=k + 1, workers=-1)
            # Shrink by one ulp so the ball count below excludes points exactly at the k-th distance.
            radius[mask] = np.nextafter(dd[:, -1], 0)

        # m_i: number of points of any class (including i itself) within radius_i of point i.
        m_all = KDTree(c).query_ball_point(c, r=radius, return_length=True, workers=-1)
        estimates.append(max(0.0, float(feature_independent_term - np.mean(psi(m_all)))))

    return pl.DataFrame(
        {"feature": cols, "estimated_mi": estimates},
        schema={"feature": pl.Utf8, "estimated_mi": pl.Float64},
    )
        

In [5]:
start = perf_counter()
# Time the scipy-KDTree estimator end to end; rows are ranked from highest to lowest MI.
output = (
    estimate_mi_optimized(df, cols=features, target=target)
    .sort("estimated_mi", descending=True)
)
end = perf_counter()
print(f"Spent {end-start:.2f}s.")
output


Spent 12.68s.


feature,estimated_mi
str,f64
"""column_16""",0.061014
"""column_17""",0.060518
"""column_10""",0.039148
"""column_0""",0.036392
"""column_24""",0.033618
"""column_21""",0.029324
"""column_41""",0.027905
"""column_7""",0.023051
"""column_32""",0.022909
"""column_25""",0.022426


In [6]:
from sklearn.feature_selection import mutual_info_classif

def estimate_mi_sklearn(df:pl.DataFrame, cols:list[str], target:str, k=3, random_state:int=42):
    """Reference MI estimates from sklearn's `mutual_info_classif`.

    Every column in `cols` is treated as continuous (`discrete_features=False`) and `target` as
    the class label. `k` is forwarded as `n_neighbors` and `random_state` seeds sklearn's jitter.

    Returns:
        DataFrame with columns `feature` and `estimated_mi`, sorted by `estimated_mi`
        in descending order.
    """
    feature_matrix = df.select(cols).to_numpy()
    labels = df.get_column(target).to_numpy()
    scores = mutual_info_classif(
        feature_matrix,
        labels,
        discrete_features=False,
        n_neighbors=k,
        random_state=random_state,
    )
    result = pl.DataFrame({"feature": cols, "estimated_mi": scores})
    return result.sort("estimated_mi", descending=True)

In [7]:
start = perf_counter()
# Same timing harness as above, this time with sklearn's reference estimator.
output = (
    estimate_mi_sklearn(df, cols=features, target=target)
    .sort("estimated_mi", descending=True)
)
end = perf_counter()
print(f"Spent {end-start:.2f}s.")
output

Spent 76.31s.


feature,estimated_mi
str,f64
"""column_16""",0.061012
"""column_17""",0.060519
"""column_10""",0.039149
"""column_0""",0.036394
"""column_24""",0.03362
"""column_21""",0.029326
"""column_41""",0.027905
"""column_7""",0.02305
"""column_32""",0.022908
"""column_25""",0.022426


In [8]:
import sys
from pathlib import Path

# Make the local `dsds` package (in ../src) importable. Resolve the path once and add it only
# if missing, so re-running this cell does not keep appending duplicate entries to sys.path.
SRC_DIR = str(Path("../src").resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from dsds.fs import mutual_info

In [9]:
# Cross-check against the `mutual_info` implementation from the local `dsds` package.
mutual_info(df, target=target, conti_cols=features)

Mutual Info: 100%|██████████| 50/50 [00:11<00:00,  4.29it/s]


feature,estimated_mi
str,f64
"""column_0""",0.036392
"""column_1""",0.013259
"""column_2""",0.008562
"""column_3""",0.016895
"""column_4""",0.004402
"""column_5""",0.007083
"""column_6""",0.013809
"""column_7""",0.023051
"""column_8""",0.016186
"""column_9""",0.001504
