In [8]:
from __future__ import annotations

In [None]:
import numpy as np
import polars as pl
from typing import overload, Literal
import pandas as pd
import mrmr


# Monkey-patching for older version of Polars referenced in `mrmr`
pl.pearson_corr = pl.corr  # type: ignore

df = pl.scan_csv("fish.csv").collect()

df = df.with_columns(pl.col("Species").cast(pl.Enum(df.get_column("Species").unique())))

df.get_column("Species").value_counts().plot.bar(x="Species", y="count")

df.get_column("Weight").plot.hist()

df.select(pl.selectors.numeric()).corr()

df.describe()

df.get_column("Weight").plot.hist()


@overload
def mrmr_regression(
    df: pl.DataFrame, target_column: str, k: int, return_scores: Literal[True]
) -> tuple[list[str], pd.Series, pd.DataFrame]: ...


@overload
def mrmr_regression(
    df: pl.DataFrame, target_column: str, k: int, return_scores: Literal[False]
) -> list[str]: ...


type MRMR = list[str] | tuple[list[str], pd.Series, pd.DataFrame]


def mrmr_regression(
    df: pl.DataFrame, target_column: str, k: int, return_scores: bool
) -> MRMR:
    return mrmr.polars.mrmr_regression(
        df=df, target_column=target_column, K=k, return_scores=return_scores
    )


selected_features = mrmr_regression(
    df=df, target_column="Width", k=3, return_scores=False
)

first, second, third, *_ = selected_features

print(first, second, third)


def train_test_split_df(df: pl.DataFrame, seed: int = 0, test_size: float = 0.2):
    return df.with_columns(
        pl.int_range(pl.len(), dtype=pl.UInt32)
        .shuffle(seed=seed)
        .gt(pl.len() * test_size)
        .alias("split")
    ).partition_by("split", include_key=False)


def train_test_split(
    X: pl.DataFrame, y: pl.DataFrame, seed: int = 0, test_size: float = 0.2
):
    X_train, X_test = train_test_split_df(X, seed=seed, test_size=test_size)
    y_train, y_test = train_test_split_df(y, seed=seed, test_size=test_size)
    return X_train, X_test, y_train, y_test


a, b = train_test_split_df(df=df, seed=243)

print(a)
print(b)

100%|██████████| 3/3 [00:00<00:00, 268.40it/s]

Weight Length3 Length2
shape: (7,)
Series: 'Species' [u32]
[
	31
	11
	5
	10
	45
	15
	10
]
shape: (32, 7)
┌─────────┬────────┬─────────┬─────────┬─────────┬─────────┬────────┐
│ Species ┆ Weight ┆ Length1 ┆ Length2 ┆ Length3 ┆ Height  ┆ Width  │
│ ---     ┆ ---    ┆ ---     ┆ ---     ┆ ---     ┆ ---     ┆ ---    │
│ enum    ┆ f64    ┆ f64     ┆ f64     ┆ f64     ┆ f64     ┆ f64    │
╞═════════╪════════╪═════════╪═════════╪═════════╪═════════╪════════╡
│ Bream   ┆ 450.0  ┆ 27.6    ┆ 30.0    ┆ 35.1    ┆ 14.0049 ┆ 4.8438 │
│ Bream   ┆ 600.0  ┆ 29.4    ┆ 32.0    ┆ 37.2    ┆ 15.438  ┆ 5.58   │
│ Bream   ┆ 575.0  ┆ 31.3    ┆ 34.0    ┆ 39.5    ┆ 15.1285 ┆ 5.5695 │
│ Bream   ┆ 950.0  ┆ 38.0    ┆ 41.0    ┆ 46.5    ┆ 17.6235 ┆ 6.3705 │
│ Roach   ┆ 40.0   ┆ 12.9    ┆ 14.1    ┆ 16.2    ┆ 4.1472  ┆ 2.268  │
│ …       ┆ …      ┆ …       ┆ …       ┆ …       ┆ …       ┆ …      │
│ Pike    ┆ 567.0  ┆ 43.2    ┆ 46.0    ┆ 48.7    ┆ 7.792   ┆ 4.87   │
│ Smelt   ┆ 7.5    ┆ 10.0    ┆ 10.5    ┆ 11.6    ┆ 1.97


