In [1]:
import polars as pl 
import numpy as np
from scipy.special import psi
from sklearn.datasets import make_classification
from typing import Any
from time import perf_counter

In [2]:
# Synthetic binary-classification data. A fixed random_state makes the data,
# and therefore every MI estimate below, reproducible across kernel restarts.
orig_x, orig_y = make_classification(
    n_samples=500_000,
    n_features=50,
    n_informative=25,
    n_redundant=25,
    random_state=42,
)
# Target first, then the numeric features (column_0 ... column_49).
# DataFrame.insert_at_idx is deprecated/removed in newer polars; hstack works on
# both old and new versions.
df = pl.DataFrame({"target": orig_y}).hstack(pl.from_numpy(orig_x))
target = "target"
features = [name for name in df.columns if name != target]

In [3]:
# Preview the first five rows (target + 50 features).
df.head(n=5)

target,column_0,column_1,column_2,column_3,column_4,column_5,column_6,column_7,column_8,column_9,column_10,column_11,column_12,column_13,column_14,column_15,column_16,column_17,column_18,column_19,column_20,column_21,column_22,column_23,column_24,column_25,column_26,column_27,column_28,column_29,column_30,column_31,column_32,column_33,column_34,column_35,column_36,column_37,column_38,column_39,column_40,column_41,column_42,column_43,column_44,column_45,column_46,column_47,column_48,column_49
i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
1,-11.968307,-5.590815,-1.746167,-1.012551,-5.140348,2.990342,0.266907,-9.858675,-4.652851,4.427344,2.894079,-2.131139,0.904102,-3.314921,-8.481504,2.610291,-8.301959,-5.405958,2.51331,-0.216056,-3.729986,7.927808,11.630288,-2.852285,8.109333,2.95304,-14.814188,4.362106,1.533787,-2.188023,5.383688,16.247816,16.74379,1.442346,-2.15984,6.128398,-9.206509,6.773962,1.074484,0.464216,-3.444563,8.649643,-35.455566,-12.341535,-4.165749,-3.844133,0.911746,8.286309,3.158895,-0.71184
0,-13.450774,-8.133396,-5.444318,0.140959,-4.843051,3.518212,3.553526,5.748577,-4.574156,-1.704198,-10.37389,-3.877364,1.900696,-5.072693,-3.559298,3.205881,1.755133,23.380757,1.540192,3.676254,-4.330686,3.831776,-4.226898,1.709885,1.642697,-4.734284,-4.821917,0.321153,0.944158,4.861723,5.822973,-4.011873,-4.792765,22.615943,-6.096436,1.561845,15.879205,-7.638073,4.007871,-10.37521,1.068146,8.171845,-8.655302,0.625733,0.469663,-2.537352,-5.181116,0.072743,3.33454,-1.984165
1,6.840603,15.042843,6.707545,-4.691243,2.414062,-1.603739,-2.351174,10.198456,5.336584,0.001217,0.343879,-0.083142,15.840742,-2.787194,-10.74107,6.020713,2.563522,-4.430095,2.114337,-1.506611,-1.512303,-6.431443,-7.22339,-1.541967,-2.777161,-11.276563,11.316407,-3.708968,-1.925514,1.081416,3.122988,-18.788071,0.918066,-4.407322,0.269026,2.440958,-0.507801,-11.192894,3.73611,8.995161,1.538787,11.68157,8.942071,-11.359194,2.312416,-0.542254,-0.127693,-1.41095,-3.234347,0.219932
0,3.820183,6.052279,-1.659343,-1.789607,-2.135955,4.518278,3.303078,15.753852,-10.791777,-0.816614,7.556424,21.469352,-1.881295,2.579593,-14.077286,0.850466,-6.78561,-5.837398,-4.664091,0.196875,-0.679874,-1.613173,-1.037819,-6.589579,-4.371202,0.892782,5.719641,-0.761811,-6.072487,3.587019,-4.956998,18.627286,1.026229,3.294992,-5.049486,4.227393,-11.677407,-12.580899,0.197161,-4.009421,-0.391283,4.428396,-1.321422,24.754975,-0.848659,1.660904,1.295858,4.784324,4.216042,3.772742
0,-10.805693,-7.630871,-1.788277,-4.364055,-1.565975,-1.139612,-3.093969,-3.042105,2.926529,1.584595,5.686045,-7.661655,-9.165909,1.357721,-1.576992,-3.467461,4.138005,-2.485949,-7.824919,1.392517,1.071709,3.392844,5.63585,4.365423,-1.829961,-0.324683,-13.948674,2.605558,-0.424199,0.522443,11.558255,11.903976,-0.520687,2.167889,-3.635129,-12.448553,2.557372,7.113387,0.700795,-14.884599,-3.540622,-26.062519,-9.760506,14.184144,5.198959,6.480604,0.774962,-2.943167,1.140789,-0.383678


In [4]:
from scipy.spatial import KDTree

def estimate_mi_optimized(df:pl.DataFrame, cols:list[str], target:str, k=3, random_state:int=42):
    """Estimate the mutual information between each feature column and a discrete target.

    Uses the same k-nearest-neighbour estimator that
    ``sklearn.feature_selection.mutual_info_classif`` applies to continuous
    features. The neighbour searches run on ``scipy.spatial.KDTree`` with all
    cores (``workers=-1``).

    Parameters
    ----------
    df : pl.DataFrame
        Frame holding ``target`` and every column in ``cols``.
    cols : list[str]
        Feature columns to score. Each is treated as continuous.
    target : str
        Name of the discrete target column. Must not contain nulls.
    k : int, default 3
        Number of neighbours per point. Must be >= 1.
    random_state : int, default 42
        Seed for the tiny jitter added to each feature to break ties.

    Returns
    -------
    pl.DataFrame
        Columns ``feature`` (str) and ``estimated_mi`` (f64, clipped at 0).
        There is one row per entry of ``cols``, in the order given.

    Raises
    ------
    ValueError
        If ``k < 1``, the target column has nulls, or any target class has
        ``k`` or fewer rows.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}.")

    target_series = df.get_column(target)
    if target_series.null_count() > 0:
        raise ValueError(f"The target column {target} contains nulls.")

    n = len(df)
    rng = np.random.default_rng(random_state)

    target_col = target_series.to_numpy().ravel()

    # One boolean mask per class, built once and reused for every feature.
    # Each class needs at least k + 1 rows, because the per-class query below
    # asks for k + 1 neighbours (the point itself plus k others). This also
    # rules out classes with a single member.
    class_masks = []
    label_counts = np.empty(n)
    for t in np.unique(target_col):
        mask = target_col == t
        count = int(mask.sum())
        if count <= k:
            raise ValueError(f"The target class {t} must have more than {k} values in the dataset.")
        class_masks.append(mask)
        # Class sizes do not depend on the feature, so they are computed once.
        label_counts[mask] = count

    # psi(n) + psi(k) - mean(psi(label_counts)) is the same for every feature.
    feature_independent_term = psi(n) + psi(k) - np.mean(psi(label_counts))

    estimates = []
    for col in cols:
        c = df.get_column(col).to_numpy().astype(np.float64).reshape(-1, 1)
        # Rescale to unit std (as sklearn does). Neighbour counts are invariant
        # to positive rescaling, so this only normalises the jitter's size.
        std = c.std()
        if std > 0:
            c = c / std
        # Tiny jitter to break ties. Scale by max(1, mean|c|) rather than the
        # signed mean: a zero-mean column would otherwise get no jitter at all.
        c = c + 1e-10 * max(1.0, float(np.mean(np.abs(c)))) * rng.standard_normal(size=c.shape)

        radius = np.empty(n)
        for mask in class_masks:
            c_masked = c[mask]
            # k + 1 because the query point itself is returned first at distance 0
            # (sklearn's kneighbors() excludes it, hence the off-by-one).
            dd, _ = KDTree(data=c_masked).query(c_masked, k=k + 1, workers=-1)
            # Step just below the k-th distance so the inclusive ball query
            # below behaves like a strict "< r" count.
            radius[mask] = np.nextafter(dd[:, -1], 0)

        # m_i: points of any class (including i itself) within radius_i of point i.
        m_all = KDTree(data=c).query_ball_point(c, r=radius, return_length=True, workers=-1)
        mi = feature_independent_term - np.mean(psi(m_all))
        # Clip estimator noise below zero. Keep a float so the column is always f64.
        estimates.append(max(0.0, float(mi)))

    return pl.DataFrame(
        {"feature": cols, "estimated_mi": estimates},
        schema={"feature": pl.Utf8, "estimated_mi": pl.Float64},
    )
        

In [5]:
# Time only the estimator. Sorting and rendering the result are excluded.
start = perf_counter()
optimized_mi = estimate_mi_optimized(df, target=target, cols=features)
end = perf_counter()
print(f"Spent {end-start:.2f}s.")
# Last expression -> rich table display instead of print().
optimized_mi.sort("estimated_mi", descending=True)


shape: (50, 2)
┌───────────┬──────────────┐
│ feature   ┆ estimated_mi │
│ ---       ┆ ---          │
│ str       ┆ f64          │
╞═══════════╪══════════════╡
│ column_6  ┆ 0.064877     │
│ column_28 ┆ 0.05626      │
│ column_19 ┆ 0.053086     │
│ column_33 ┆ 0.038917     │
│ …         ┆ …            │
│ column_4  ┆ 0.000093     │
│ column_31 ┆ 0.000032     │
│ column_38 ┆ 0.0          │
│ column_44 ┆ 0.0          │
└───────────┴──────────────┘
Spent 22.86s.


In [6]:
from sklearn.feature_selection import mutual_info_classif

def estimate_mi_sklearn(df:pl.DataFrame, cols:list[str], target:str, k=3, random_state:int=42):
    """Reference MI estimate computed with scikit-learn's ``mutual_info_classif``.

    Every column in ``cols`` is scored against the ``target`` column with all
    features treated as continuous (``discrete_features=False``). ``k`` is
    forwarded as ``n_neighbors`` and ``random_state`` as-is.

    Returns a frame with columns ``feature`` and ``estimated_mi``, sorted by
    ``estimated_mi`` in descending order.
    """
    feature_frame = df[cols]
    labels = df[target]
    scores = mutual_info_classif(
        feature_frame,
        labels,
        n_neighbors=k,
        random_state=random_state,
        discrete_features=False,
    )
    scored = pl.from_records([cols, scores], schema=["feature", "estimated_mi"])
    return scored.sort("estimated_mi", descending=True)

In [7]:
# Time only the estimator. estimate_mi_sklearn already returns rows sorted by
# estimated_mi (descending), so no second sort is needed.
start = perf_counter()
sklearn_mi = estimate_mi_sklearn(df, target=target, cols=features)
end = perf_counter()
print(f"Spent {end-start:.2f}s.")
# Display the result (previously it was computed and discarded) so it can be
# compared with the optimized estimator's table.
sklearn_mi

Spent 146.33s.


In [None]:
# The %%timeit cells below use only the first 100,000 rows so that the repeated timing runs finish in a reasonable time.

In [8]:
%%timeit 
estimate_mi_optimized(df.limit(100_000), target=target, cols=features).sort("estimated_mi",  descending=True)

4.1 s ± 38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
estimate_mi_sklearn(df.limit(100_000), target=target, cols=features).sort("estimated_mi",  descending=True)

22.5 s ± 36.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
import sys

# Make the project's src directory importable. The membership check keeps
# re-runs of this cell from appending duplicate sys.path entries.
# NOTE(review): '../src' is resolved against the kernel's working directory.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from eda.eda_selection import mutual_info

In [11]:
# Cross-check with the project's implementation imported from eda.eda_selection.
mutual_info(df, target=target, conti_cols=features)

100%|██████████| 50/50 [00:22<00:00,  2.26it/s]


feature,estimated_mi
str,f64
"""column_6""",0.064877
"""column_28""",0.05626
"""column_19""",0.053086
"""column_33""",0.038917
"""column_11""",0.037282
"""column_37""",0.031609
"""column_40""",0.027376
"""column_24""",0.023433
"""column_48""",0.023096
"""column_26""",0.021332
