In [4]:
import lightgbm as lgb
import numpy as np
import numpy.typing as npt
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

from cate.model.dataset import Dataset, split, to_rank
from cate.utils import get_logger, path_linker

In [13]:
dataset_name = "test"
pathlinker = path_linker(dataset_name)
logger = get_logger("causalml")
logger.info("load dataset")

ds = Dataset.load(pathlinker.base)
# Split into train/test with a fixed seed.
# NOTE(review): 1/3 is presumably the test fraction — confirm against `split`.
train_ds, test_ds = split(ds, 1 / 3, random_state=42)

# Add bias to the train dataset using LightGBM.
# A classifier for y is fit with 5-fold stratified CV, and the out-of-fold
# predicted probabilities are collected. Every training row is therefore
# scored by a model that never saw it. The probabilities are turned into
# ranks, which later cells use to subsample the training set.
_pred_dfs = []
skf = StratifiedKFold(5, shuffle=True, random_state=42)
# StratifiedKFold only uses X for its length, so a dummy array is enough.
for train_idx, valid_idx in skf.split(np.zeros(len(train_ds)), train_ds.y):
    train_X = train_ds.X.iloc[train_idx]
    train_y = train_ds.y.iloc[train_idx].to_numpy().reshape(-1)
    valid_X = train_ds.X.iloc[valid_idx]
    valid_y = train_ds.y.iloc[valid_idx].to_numpy().reshape(-1)

    base_classifier = lgb.LGBMClassifier(
        importance_type="gain",
        random_state=42,
        force_col_wise=True,
        n_jobs=-1,
        verbosity=0,
    )
    # eval_set/eval_metric only record validation AUC; no early stopping is configured.
    base_classifier.fit(
        train_X, train_y, eval_set=[(valid_X, valid_y)], eval_metric="auc"
    )
    # Use np.float64 (not np.float_): that alias was removed in NumPy 2.0.
    # Module-level annotations are evaluated at runtime, so the old alias
    # would raise AttributeError here.
    pred: npt.NDArray[np.float64] = base_classifier.predict_proba(valid_X)[:, 1]  # type: ignore

    # Keep the original row labels so the predictions can be re-joined by index.
    _pred_dfs.append(
        pd.DataFrame(
            {"index": train_ds.y.index[valid_idx], "pred": pred.reshape(-1)}
        ).set_index("index")
    )
pred_df = pd.concat(_pred_dfs)
# Each row falls in exactly one validation fold, so the out-of-fold
# predictions should cover the training set exactly once.
assert len(pred_df) == len(train_ds), "out-of-fold predictions do not cover train set"
# NOTE(review): assumes to_rank returns a Series named "rank" (read by later cells) — confirm.
rank = to_rank(pred_df.index.to_series(), pred_df["pred"]).to_frame()

INFO  2024-11-19 05:44:51 [causalml] load dataset
INFO  2024-11-19 05:44:51 [causalml] load dataset


In [None]:
# Attach the out-of-fold rank to each training row by index label.
# validate="one_to_one" raises on duplicate index labels instead of silently
# multiplying rows. The assert catches rows the inner join drops because
# they have no rank.
_train_base_df = train_ds.to_pandas()
train_df = pd.merge(
    _train_base_df,
    rank,
    left_index=True,
    right_index=True,
    validate="one_to_one",
)
assert len(train_df) == len(_train_base_df), "rows lost when merging ranks"

In [None]:
# Build nested training subsets of increasing size.
# Subset k keeps the rows whose out-of-fold rank is <= k, for k = 1..MAX_RANK.
# NOTE(review): assumes to_rank produces ranks on a 1..100 scale — TODO confirm.
MAX_RANK = 100
# A comprehension keeps the threshold variable local. The previous
# `for rank in ...` loop overwrote the `rank` DataFrame that the merge cell
# reads, which broke re-running that cell.
train_ds_list: list[Dataset] = [
    Dataset(
        train_df.loc[train_df["rank"] <= max_rank],
        train_ds.x_columns,
        train_ds.y_columns,
        train_ds.w_columns,
    )
    for max_rank in range(1, MAX_RANK + 1)
]