In [None]:
import polars as pl
import numpy as np
from scipy.special import psi
from sklearn.datasets import make_classification
from typing import Any
from time import perf_counter

In [None]:
# Fixed seed so that a fresh "Restart & Run All" regenerates the same synthetic dataset
# and the estimates reported further down are reproducible.
SEED = 42
target = "target"

orig_x, orig_y = make_classification(
    n_samples=300_000,
    n_features=50,
    n_informative=25,
    n_redundant=25,
    random_state=SEED,
)
# The label is inserted as the first column, ahead of the auto-named feature columns.
# NOTE(review): `insert_at_idx` has been renamed `insert_column` in newer polars releases —
# confirm against the polars version this notebook is pinned to.
df = pl.from_numpy(orig_x).insert_at_idx(0, pl.Series(target, orig_y))
# Every column except the label is a candidate feature.
features = [name for name in df.columns if name != target]

In [None]:
# Peek at the first rows to sanity-check the frame before estimating anything.
df.head(5)

In [None]:
from scipy.spatial import KDTree

def estimate_mi_optimized(df:pl.DataFrame, cols:list[str], target:str, k:int=3, random_state:int=42) -> pl.DataFrame:
    """Estimate the mutual information between each continuous column and a discrete target.

    For every sample the distance to its k-th nearest neighbour *within the same target
    class* is taken as a radius; the number of samples of any class lying strictly inside
    that radius is then counted, and the estimate is assembled from digamma terms:

        psi(n) + psi(k) - mean(psi(class_size)) - mean(psi(neighbours_within_radius))

    This follows the same nearest-neighbour scheme as scikit-learn's
    ``mutual_info_classif`` for continuous features. Negative estimates are clipped to 0.

    Parameters
    ----------
    df
        Frame holding both the feature columns and the target column.
    cols
        Names of the continuous feature columns to score.
    target
        Name of the discrete target column.
    k
        Number of same-class neighbours that defines each sample's radius.
    random_state
        Seed for the tiny noise added to each feature to break ties.

    Returns
    -------
    pl.DataFrame
        Two columns, ``feature`` and ``estimated_mi``, one row per entry of ``cols``
        (in the order given; not sorted).

    Raises
    ------
    ValueError
        If any target class has ``k`` or fewer samples.
    """
    n = len(df)
    rng = np.random.default_rng(random_state)

    target_col = df.get_column(target).to_numpy().ravel()
    unique_targets, class_sizes = np.unique(target_col, return_counts=True)

    # The class masks and per-sample class sizes depend only on the target, so they are
    # built once here instead of once per feature column.
    all_masks = {}
    label_counts = np.empty(n)
    for t, size in zip(unique_targets, class_sizes):
        # A class needs more than k members, otherwise the k-th same-class neighbour
        # does not exist. This also rules out samples with a unique label.
        if size <= k:
            raise ValueError(f"The target class {t} must have more than {k} values in the dataset.")
        mask = target_col == t
        all_masks[t] = mask
        label_counts[mask] = size

    # Everything in the estimate that does not depend on the feature values.
    psi_const = psi(n) + psi(k) - np.mean(psi(label_counts))

    estimates = []
    for col in cols:
        c = df.get_column(col).to_numpy().reshape(-1, 1)
        # Tie-breaking noise. Scale by the mean *absolute* value with a floor of 1: the
        # plain mean is ~0 for centred features, which would make the noise vanish.
        noise_scale = 1e-10 * max(1.0, np.mean(np.abs(c)))
        c = c + noise_scale * rng.standard_normal(size=c.shape)

        radius = np.empty(n)
        for mask in all_masks.values():
            c_masked = c[mask]
            same_class_tree = KDTree(data=c_masked)
            # dd = distances from each point to its nearest points. k + 1 because the
            # query set is the tree's own data, so the closest hit is the point itself
            # at distance 0.
            dd, _ = same_class_tree.query(c_masked, k=k + 1, workers=-1)
            # Step just below the k-th neighbour distance so that the ball query below
            # counts only points strictly closer than that neighbour.
            radius[mask] = np.nextafter(dd[:, -1], 0)

        full_tree = KDTree(data=c)
        # Number of samples (any class, including the point itself) inside each radius.
        m_all = full_tree.query_ball_point(c, r=radius, return_length=True, workers=-1)
        mi = psi_const - np.mean(psi(m_all))
        # Clip at zero; float() keeps the result column homogeneous instead of mixing
        # an int 0 with numpy floats.
        estimates.append(max(0.0, float(mi)))

    return pl.DataFrame({"feature": cols, "estimated_mi": estimates})
        

In [None]:
# Time the KDTree-based estimator (the sort is part of the measured window), then
# display the features ranked from most to least informative.
start = perf_counter()
output = (
    estimate_mi_optimized(df, cols=features, target=target)
    .sort("estimated_mi", descending=True)
)
end = perf_counter()
print(f"Spent {end-start:.2f}s.")
output


In [None]:
from sklearn.feature_selection import mutual_info_classif

def estimate_mi_sklearn(df:pl.DataFrame, cols:list[str], target:str, k:int=3, random_state:int=42) -> pl.DataFrame:
    """Score continuous features against a discrete target with scikit-learn.

    Thin wrapper around ``mutual_info_classif`` with all features treated as continuous
    (``discrete_features=False``).

    Parameters
    ----------
    df
        Frame holding both the feature columns and the target column.
    cols
        Names of the feature columns to score.
    target
        Name of the target column.
    k
        Forwarded as ``n_neighbors``.
    random_state
        Forwarded as ``random_state``.

    Returns
    -------
    pl.DataFrame
        Columns ``feature`` and ``estimated_mi``, sorted by ``estimated_mi`` descending.
    """
    # Hand scikit-learn plain NumPy arrays instead of polars objects so the call does not
    # rely on implicit DataFrame -> array conversion.
    x = df.select(cols).to_numpy()
    y = df.get_column(target).to_numpy()
    mi_estimates = mutual_info_classif(
        x,
        y,
        n_neighbors=k,
        random_state=random_state,
        discrete_features=False,
    )

    return pl.DataFrame({"feature": cols, "estimated_mi": mi_estimates}).sort("estimated_mi", descending=True)

In [None]:
# Time scikit-learn's estimator on the same data.
# `estimate_mi_sklearn` already returns its result sorted by `estimated_mi` (descending),
# so no second sort is applied here.
start = perf_counter()
output = estimate_mi_sklearn(df, cols=features, target=target)
end = perf_counter()
print(f"Spent {end-start:.2f}s.")
output

In [None]:
import sys
sys.path.append('../src')
from dsds.fs import mutual_info

In [None]:
# Run the project's own implementation from `dsds.fs` on the same frame, features and
# target — presumably to compare it with the estimators above.
mutual_info(df, target=target, conti_cols=features)