In [48]:
import polars as pl

import cate.dataset as cds

In [49]:
def get_biased_ds(
    ds: cds.Dataset,
    rank_flg: pl.Series,
    random_state: "int | None" = None,
) -> cds.Dataset:
    """Build a selection-biased subset of ``ds``.

    Two groups of rows are kept and concatenated:

    * rows whose ``ds.w`` equals 1 and whose ``rank_flg`` is true
      (``tg`` -- presumably the treatment group), and
    * rows whose ``ds.w`` is not 1 and whose ``rank_flg`` is false
      (``cg`` -- presumably the control group).

    Args:
        ds: Source dataset.
        rank_flg: Boolean series with one entry per row of ``ds``
            (assumed to be aligned row-by-row with ``ds`` -- TODO confirm
            that ``cds.filter`` enforces this).
        random_state: Optional seed forwarded to ``cds.sample`` so the final
            resampling step is reproducible. When ``None`` (the default) the
            call is made without a seed, exactly as before.

    Returns:
        A dataset holding the two selected groups, passed through
        ``cds.sample(..., frac=1)``.
    """
    tg_flg = pl.Series("group", ds.w) == 1
    # Each list of flags is applied together to pick one group.
    tg_ds = cds.filter(ds, [tg_flg, rank_flg])
    cg_ds = cds.filter(ds, [~tg_flg, ~rank_flg])
    biased_ds = cds.concat([tg_ds, cg_ds])
    # Only forward the seed when one was supplied, so the default call is
    # byte-for-byte the same as the original unseeded one.
    if random_state is None:
        return cds.sample(biased_ds, frac=1)
    return cds.sample(biased_ds, frac=1, random_state=random_state)


In [50]:
def tg_cg_split(
    ds: cds.Dataset,
    rank_flg: pl.Series,
    random_ratio: float,
    random_state: int,
) -> cds.Dataset:
    """Mix a selection-biased subset of ``ds`` with a randomly drawn subset.

    Let ``biased_ds_ratio = len(get_biased_ds(ds, rank_flg)) / len(ds)``.
    The randomly drawn share of ``ds`` is ``random_ratio * biased_ds_ratio``.

    * ``random_ratio == 0``: return the biased subset only.
    * ``random_ratio == 1``: return ``cds.sample`` of ``ds`` with that
      fraction, with no biased selection at all.
    * otherwise: hold out that fraction with ``cds.split`` as the random
      part, apply the biased selection to the remaining rows, and return
      the concatenation of both parts passed through
      ``cds.sample(..., frac=1)``.

    Args:
        ds: Source dataset.
        rank_flg: Flag series with one entry per row of ``ds``; cast to
            Boolean before it is applied to the retained rows.
        random_ratio: Weight of the random part (expected in ``[0, 1]`` --
            not validated here).
        random_state: Seed forwarded to ``cds.sample`` / ``cds.split``.

    Returns:
        The combined dataset.

    NOTE(review): in the mixed branch the returned frame still carries the
    helper ``"index"`` column added by ``with_row_index``, unlike the two
    early-return branches -- confirm whether callers rely on that.
    """
    fully_biased_ds = get_biased_ds(ds, rank_flg)
    if random_ratio == 0:
        return fully_biased_ds

    biased_ds_ratio = len(fully_biased_ds) / len(ds)
    random_ds_ratio = random_ratio * biased_ds_ratio
    if random_ratio == 1:
        return cds.sample(ds, frac=random_ds_ratio, random_state=random_state)

    # Tag every row with its original position so the flags can be matched
    # back to whichever rows survive the split.
    indexed_ds = cds.Dataset(
        ds.to_frame().with_row_index(), ds.x_columns, ds.y_columns, ds.w_columns
    )
    rest_ds, random_ds = cds.split(
        indexed_ds, test_frac=random_ds_ratio, random_state=random_state
    )
    # Gather the flags *in the row order of rest_ds*. An ``is_in`` filter
    # would keep the flags in original order, which misaligns them with
    # ``rest_ds`` whenever ``cds.split`` reorders rows.
    kept_positions = rest_ds.to_frame()["index"]
    rest_rank_flg = rank_flg.gather(kept_positions).cast(pl.Boolean)
    biased_ds = get_biased_ds(rest_ds, rest_rank_flg)
    return cds.sample(
        cds.concat([biased_ds, random_ds]), frac=1, random_state=random_state
    )


In [51]:
# Toy dataset for a manual check: "a" is the x column, "b" the y column and
# "c" the w column (positional order: frame, x_columns, y_columns, w_columns).
toy_frame = pl.DataFrame(
    {
        "a": [1, 2, 3, 4, 5],
        "b": [1, 1, 0, 0, 0],
        "c": [0, 0, 0, 1, 1],
    }
)
ds = cds.Dataset(toy_frame, ["a"], ["b"], ["c"])

In [57]:
# Rank flag for the manual check: false for the first two rows, true for the rest.
toy_rank_flg = pl.Series("c", [0, 0, 1, 1, 1]).cast(pl.Boolean)
_ds = get_biased_ds(ds, toy_rank_flg)

In [58]:
# Display the biased subset: only rows with (c == 1 and flag true) or
# (c == 0 and flag false) should remain.
_ds.to_frame()

a,b,c
i64,i64,i64
4,0,1
5,0,1
1,1,0
2,1,0


In [59]:
# Display the full toy dataset for comparison with the biased subset.
ds.to_frame()

a,b,c
i64,i64,i64
1,1,0
2,1,0
3,0,0
4,0,1
5,0,1
