In [None]:
# Re-import edited modules before each cell runs, so changes to helper .py files
# are picked up without restarting the kernel.
# NOTE(review): no project-local modules are imported in this notebook yet, so this
# only matters once helpers (e.g. handle_rare_categoricals) move into a module.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import seaborn as sns

# Global seaborn look for every figure below; these are set_theme's own defaults,
# spelled out so the chosen theme is visible and easy to tweak.
sns.set_theme(context="notebook", style="darkgrid", palette="deep")

In [None]:
from pathlib import Path

# Directory holding the input CSV files; change it here to relocate the data.
DATA_DIR = Path("../data")

X_test = pd.read_csv(DATA_DIR / "test_values.csv")
X_train = pd.read_csv(DATA_DIR / "train_values.csv")
y_train = pd.read_csv(DATA_DIR / "train_labels.csv")

In [None]:
# Split the feature columns into three groups:
#   cat_cols     - object (string) columns
#   binary_cols  - every column whose name starts with "has" (chosen by name, not dtype)
#   numeric_cols - int64 columns that are not in binary_cols
cat_cols = X_train.select_dtypes(include="object").columns
binary_cols = [name for name in X_train.columns if name.startswith("has")]
numeric_cols = [
    name
    for name in X_train.select_dtypes(include="int64").columns
    if name not in binary_cols
]

In [None]:
# Put features and labels side by side for EDA. Any column present in both frames
# would otherwise appear twice in X. pd.concat(axis=1) also pairs rows by index,
# so first check the frames line up and that the shared columns agree
# row-for-row; then keep only X_train's copy of each shared column.
assert len(X_train) == len(y_train), "X_train and y_train have different row counts"
shared_cols = X_train.columns.intersection(y_train.columns)
for shared_col in shared_cols:
    assert X_train[shared_col].equals(
        y_train[shared_col]
    ), f"column {shared_col!r} differs between X_train and y_train"
X = pd.concat([X_train, y_train.drop(columns=shared_cols)], axis=1)

In [None]:
# Categorical feature columns only, for a quick look at their values.
X_cat = X_train.loc[:, cat_cols]
X_cat.head()

In [None]:
# One count plot per categorical feature, with bars split by damage_grade.
for feature in cat_cols:
    sns.catplot(data=X, x=feature, kind="count", hue="damage_grade")

In [None]:
col = "plan_configuration"

# Relative frequency of each category in the chosen column, largest first.
counts = X[col].value_counts(normalize=True)
counts = counts.sort_values(ascending=False)
counts

In [None]:
# Categories whose share of rows is strictly below this are treated as rare.
threshold = 0.01
# Use the named threshold (not a repeated literal) so editing it takes effect.
filter_out = counts[counts < threshold].index
filter_out

In [None]:
# Fold the rare categories of `col` into a single "other" level (modifies X in place).
X[col] = X[col].mask(X[col].isin(filter_out), "other")

In [None]:
# Confirm the rare categories now show up as a single "other" level.
X[col].value_counts(normalize=True, dropna=True)

In [None]:
from typing import Literal, Optional, Sequence


def handle_rare_categoricals(
    X: pd.DataFrame,
    threshold: float = 0.01,
    method: Literal["replace", "remove"] = "replace",
    columns: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """Replace or drop rare categories in categorical columns of a DataFrame.

    A category is "rare" when its share of a column's non-missing values is
    strictly below ``threshold``.

    Parameters
    ----------
    X : pd.DataFrame
        Input frame. It is copied first and never modified in place.
    threshold : float, default 0.01
        Minimum relative frequency a category needs in order to be kept.
    method : {"replace", "remove"}, default "replace"
        "replace" maps rare categories to the string "other";
        "remove" drops every row that holds a rare category.
    columns : sequence of str, optional
        Columns to process. Defaults to all ``object``-dtype columns of ``X``.

    Returns
    -------
    pd.DataFrame
        The transformed copy of ``X``. In "remove" mode the surviving rows keep
        their original index labels.

    Raises
    ------
    ValueError
        If ``method`` is not "replace" or "remove".

    Notes
    -----
    In "remove" mode columns are processed in order, and each column's
    frequencies are computed on the rows left after the earlier columns were
    filtered, so the result can depend on column order.
    """
    # Validate up front so a bad ``method`` is reported even when there are no
    # columns to process (otherwise the loop body, and its check, never runs).
    if method not in ("replace", "remove"):
        raise ValueError("method must be either 'replace' or 'remove'")

    X = X.copy()
    cols = X.select_dtypes(include="object").columns if columns is None else list(columns)
    for col in cols:
        # value_counts drops NaN by default, so missing values are never "rare".
        counts = X[col].value_counts(normalize=True)
        filter_out = counts[counts < threshold].index

        if method == "replace":
            X[col] = X[col].replace(filter_out, "other")
        else:  # method == "remove"
            X = X.loc[~X[col].isin(filter_out)]
    return X

In [None]:
from sklearn.preprocessing import FunctionTransformer

# Wrap the helper as an sklearn transformer. Categories below 5% of a column's
# non-missing values are replaced by "other".
rare_cat_transformer = FunctionTransformer(
    func=handle_rare_categoricals,
    kw_args=dict(method="replace", threshold=0.05),
)

In [None]:
# Apply the wrapped handle_rare_categoricals to the joined features+labels frame.
# NOTE(review): X was already edited in place above (rare plan_configuration
# values set to "other"). Every category below 5% is folded here anyway, so the
# output should match running on the unmodified frame, but the cell relies on
# that earlier mutation having happened.
df_ = rare_cat_transformer.fit(X).transform(X)

In [None]:
# Category shares after the 5% rare-category folding.
df_.loc[:, "plan_configuration"].value_counts(normalize=True)

In [None]:
# Re-plot each categorical feature after rare categories were folded into "other".
for feature in cat_cols:
    sns.catplot(data=df_, x=feature, kind="count", hue="damage_grade")