In [1]:
import os
from pathlib import Path

import numpy as np
import polars as pl

In [2]:
# Load the pre-extracted feature table (scale 0.5 per the file name; presumably one row per
# image -- TODO confirm). The data root defaults to the original author's machine; set the
# HER2_DATA_DIR environment variable to read the same relative layout from elsewhere.
DATA_DIR = Path(os.environ.get("HER2_DATA_DIR", "/home/surayuth/her2/old_data"))
df = pl.read_csv(DATA_DIR / "extracted_features" / "combined_feat_scale_0.5.csv")

In [4]:
from sklearn.model_selection import StratifiedKFold
from utils.prep import filter_case

# Repeated nested cross-validation at case level.
cv = 4                  # outer folds; the inner (train/val) split uses cv - 1 folds
min_img = 10            # bounds forwarded to filter_case (presumably images per case -- TODO confirm)
max_img = 10
selected_features = ["color_feat"]
seeds = range(1, 10)    # random_state for each repetition of the nested CV

# NOTE: overwrites `df`; re-running this cell without re-running the load cell
# applies filter_case to the already-filtered frame.
df = filter_case(df, min_img, max_img) \
    .select("path", "case", "ihc_score", "label", *selected_features)

# One row per case: label = min label over the case's rows, ihc_score = the first row's value
# (presumably constant within a case -- TODO confirm). polars' group_by does not guarantee
# output row order, so sort by case to make the seeded splits below reproducible across runs.
case_df = df \
    .group_by("case") \
    .agg(pl.col("label").min(), pl.col("ihc_score").first()) \
    .sort("case")

# Every (seed, outer fold, inner fold) split is kept here. The loop variables
# (outer_case, train_case, val_case, ...) still end up holding the LAST split
# (seed 9, outer fold 3, inner fold 2), which is what the following cells summarise.
fold_splits = []
for r in seeds:
    skf = StratifiedKFold(n_splits=cv, random_state=r, shuffle=True)
    # Stratify on the 1-D ihc_score column; X is only used for its length.
    splits = skf.split(case_df.select("case"), case_df["ihc_score"])
    for i, (inner_idx, outer_idx) in enumerate(splits):
        inner_case = case_df[inner_idx].select("case", "label", "ihc_score")
        outer_case = case_df[outer_idx].select("case", "label", "ihc_score")

        # Split the non-test cases again into train / validation.
        inner_skf = StratifiedKFold(n_splits=cv - 1, random_state=r, shuffle=True)
        inner_splits = inner_skf.split(inner_case.select("case"), inner_case["ihc_score"])
        for j, (train_idx, val_idx) in enumerate(inner_splits):
            train_case = inner_case[train_idx].select("case", "label", "ihc_score")
            val_case = inner_case[val_idx].select("case", "label", "ihc_score")
            fold_splits.append({
                "seed": r,
                "outer_fold": i,
                "inner_fold": j,
                "train": train_case,
                "val": val_case,
                "test": outer_case,
            })

In [5]:
def split_stats(cases: pl.DataFrame, split_type: str) -> pl.DataFrame:
    """Count cases per (label, ihc_score) in one split.

    Args:
        cases: case-level frame with at least "label" and "ihc_score" columns.
        split_type: tag written to the "type" column (e.g. "train", "val", "test").

    Returns:
        Frame with columns label, ihc_score, count, type, sorted by label then ihc_score.
    """
    return cases \
        .group_by("label", "ihc_score") \
        .agg(pl.len().alias("count")) \
        .with_columns(pl.lit(split_type).alias("type")) \
        .sort("label", "ihc_score")


# These describe the single split left in train_case / val_case / outer_case
# by the cross-validation loop (i.e. its last iteration).
stat_train = split_stats(train_case, "train")
stat_val = split_stats(val_case, "val")
stat_test = split_stats(outer_case, "test")

tot_stat = pl.concat([
    stat_train,
    stat_val,
    stat_test
])

In [6]:
# Per-label split balance: how many cases of each label landed in train / val / test,
# and what percentage of that label's cases each split holds.
(
    tot_stat
    .group_by("type", "label")
    .agg(pl.col("count").sum())
    .sort("type")
    .with_columns(tot_count=pl.col("count").sum().over("label"))
    .with_columns(ratio=(pl.col("count") / pl.col("tot_count") * 100).round(2))
    # Explicit rank so rows read train -> val -> test rather than alphabetically.
    .with_columns(
        idx=pl.when(pl.col("type") == "train").then(1)
        .when(pl.col("type") == "val").then(2)
        .otherwise(3)
    )
    .sort("label", "idx")
)

type,label,count,tot_count,ratio,idx
str,i64,u32,u32,f64,i32
"""train""",0,23,45,51.11,1
"""val""",0,11,45,24.44,2
"""test""",0,11,45,24.44,3
"""train""",1,25,49,51.02,1
"""val""",1,12,49,24.49,2
"""test""",1,12,49,24.49,3


In [7]:
# Per-IHC-score split balance: cases of each ihc_score in train / val / test and the
# percentage of that score's cases each split holds. Shown via pandas, capped at 20 rows.
(
    tot_stat
    .group_by("type", "ihc_score")
    .agg(pl.col("count").sum())
    .sort("type")
    .with_columns(tot_count=pl.col("count").sum().over("ihc_score"))
    .with_columns(ratio=(pl.col("count") / pl.col("tot_count") * 100).round(2))
    # Explicit rank so rows read train -> val -> test rather than alphabetically.
    .with_columns(
        idx=pl.when(pl.col("type") == "train").then(1)
        .when(pl.col("type") == "val").then(2)
        .otherwise(3)
    )
    .sort("ihc_score", "idx")
    .to_pandas()
    .head(n=20)
)

Unnamed: 0,type,ihc_score,count,tot_count,ratio,idx
0,train,0,6,11,54.55,1
1,val,0,2,11,18.18,2
2,test,0,3,11,27.27,3
3,train,1+,6,11,54.55,1
4,val,1+,3,11,27.27,2
5,test,1+,2,11,18.18,3
6,train,2+,15,30,50.0,1
7,val,2+,8,30,26.67,2
8,test,2+,7,30,23.33,3
9,train,2-,11,23,47.83,1
