In [272]:
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier

In [273]:
def filter_case(df, min_count, max_count):
    min_count = 10
    max_count = 30
    selected_df = (
            df \
            .with_columns(
                pl.len().over("case")
                .alias("count")
            ) 
            .filter(
                pl.col("count") >= min_count
            ) 
            .with_columns(
                pl.min_horizontal(max_count, pl.col("count"))
                .alias("cap_max")
            ) 
            .with_columns(
                pl.arange(1, pl.len() + 1).over("case")
                .alias("case_idx")
            ) 
            .filter(
                pl.col("case_idx") <= pl.col("cap_max")
            ) 
        )
    return selected_df


In [274]:
MIN_COUNT = 10
MAX_COUNT = 30

df = pl.read_csv("/home/surayuth/her2/extracted_features/orig_feat_level_16_white_balance_False_scale_0.25.csv")
df = filter_case(
        df, min_count=MIN_COUNT, 
        max_count=MAX_COUNT
    ) \
    .with_columns(
        pl.when(pl.col("ihc_score") == "0").then(pl.lit(0))
        .when(pl.col("ihc_score") == "1+").then(pl.lit(1))
        .when(pl.col("ihc_score") == "2-").then(pl.lit(2))
        .when(pl.col("ihc_score") == "2+").then(pl.lit(3))
        .when(pl.col("ihc_score") == "3+").then(pl.lit(4))
        .otherwise(None)
        .alias("ihc_score")
    )

case_df = df.group_by("case") \
    .agg(
        pl.col("label").min(), 
        pl.col("ihc_score").first())

In [281]:
ALPHAS = [0.01, 0.05, 0.1]

for r in range(5):
    calib_skf = StratifiedKFold(n_splits=5, random_state=r, shuffle=True)
    calib_splits = calib_skf.split(case_df.select("case"), case_df.select("ihc_score"))

    # 1) split 1/5 folds for calib
    for train_idx, calib_idx in calib_splits:
        train_case = case_df[train_idx].select("case", "label", "ihc_score")
        calib_case = case_df[calib_idx].select("case", "label", "ihc_score")

        calib_df = df \
            .filter(
                pl.col("case")
                .is_in(calib_case.select("case"))
            ) \
            .drop("count", "cap_max", "case_idx")
        X_calib = calib_df.drop("case", "path", "ihc_score", "label").to_numpy()
        y_calib = calib_df.select("label").to_numpy().reshape(-1)
        ihc_calib = calib_df.select("ihc_score") 
        break
    # 2) used the rest of folds 4/5 for train(2)/val(1)/test(1)
    train_skf = StratifiedKFold(n_splits=4, random_state=r, shuffle=True)
    train_splits = train_skf.split(train_case.select("case"), train_case.select("ihc_score"))    
    for i, (inner_idx, outer_idx) in enumerate(train_splits):
        inner_case = train_case[inner_idx].select("case", "label", "ihc_score")
        outer_case = train_case[outer_idx].select("case", "label", "ihc_score")

        # hyperparameters tuning
        # .....
        best_params = {
            "n_estimators": 100
        }

        train_df = df \
            .filter(
                pl.col("case") 
                .is_in(train_case.select("case"))
            ) \
            .drop("count", "cap_max", "case_idx")
        X_train = train_df.drop("case", "path", "ihc_score", "label").to_numpy()
        y_train = train_df.select("label").to_numpy().reshape(-1)

        test_df = df \
            .filter(
                pl.col("case") 
                .is_in(outer_case.select("case"))
            ) \
            .drop("count", "cap_max", "case_idx")
        X_test = test_df.drop("case", "path", "ihc_score", "label").to_numpy()
        y_test = test_df.select("label").to_numpy().reshape(-1)

        # print(i, len(X_calib), len(X_train), len(X_test), len(X_calib) + len(X_train) + len(X_test))
        model = RandomForestClassifier(**best_params, random_state=0)
        model.fit(X_train, y_train)

        # calib (conformal prediction)
        for alpha in ALPHAS:
            prob_calib = model.predict_proba(X_calib)

            scores0 = 1 - prob_calib[:,0][y_calib == 0]
            scores1 = 1 - prob_calib[:,1][y_calib == 1]
            n_calib = len(prob_calib)
            q_level = np.ceil((n_calib+1) * (1 - alpha)) / n_calib
            qhat0 = np.quantile(scores0, q_level, method='higher')
            qhat1 = np.quantile(scores1, q_level, method='higher')
            # test
            prob_test = model.predict_proba(X_test)
            preds0 = (prob_test[:, 0] >= 1 - qhat0) * 1
            preds1 = (prob_test[:, 1] >= 1 - qhat1) * 1
            
            pred_df = pl.DataFrame({
                    "path": test_df.select("path"),
                    "case": test_df.select("case"),
                    "random_state": r,
                    "alpha": alpha,
                    "fold": i,
                    "pred0": preds0,
                    "pred1": preds1,
                    "label": y_test,
                    "ihc_score": test_df.select("ihc_score"),
                }) \
                .with_columns(
                    (pl.col("pred0") + pl.col("pred1"))
                    .alias("pred_size")
                ) \
                .with_columns(
                    pl.when(
                        (pl.col("pred0") == 1) & (pl.col("pred1") == 0)
                    ).then(pl.lit(0))
                    .when(
                        (pl.col("pred0") == 0) & (pl.col("pred1") == 1)
                    ).then(pl.lit(1))
                    .otherwise(pl.lit(-1))
                    .alias("final_pred")
                ) 
        

        # metric:
        # the result (experiments) should be done for each alpha
        # patch-level
        # - empirical coverage (misconverage rate could be obtained from 1 - empirical coverage)
        #   - total
        #   - only for ihc_score 0,1,2-,2+,3+
        #   - only for label 0, 1
        # - average set size
        #   - total
        #   - only for ihc_score 0,1,2-,2+,3+
        #   - only for label 0, 1
        # - accuracy/precision/recall/f1 for confident predictions only
        # - ambiguity rate (size = 2)
        #   - total
        #   - only for ihc_score 0,1,2-,2+,3+
        #   - only for label 0, 1

        # case-level
        # precision/recall/f1
        # ambiguity rate (for case)


        break

In [282]:
pred_df.group_by("ihc_score", "pred_size",).agg(pl.len().alias("count")).sort("ihc_score", "pred_size")

ihc_score,pred_size,count
i32,i64,u32
0,1,58
1,1,60
2,1,119
2,2,2
3,1,125
3,2,55
4,1,81
4,2,9


In [283]:
pred_df \
    .group_by("label", "ihc_score", "final_pred", "pred_size") \
    .agg(
        pl.len().alias("count")
    ) \
    .sort("label", "ihc_score", "final_pred")

label,ihc_score,final_pred,pred_size,count
i64,i32,i32,i64,u32
0,0,0,1,58
0,1,0,1,60
0,2,-1,2,2
0,2,0,1,119
1,3,-1,2,55
1,3,1,1,125
1,4,-1,2,9
1,4,1,1,81


In [284]:
pred_df

path,case,random_state,alpha,fold,pred0,pred1,label,ihc_score,pred_size,final_pred
str,str,i32,f64,i32,i64,i64,i64,i32,i64,i32
"""./Data_Chula/24 Jan HER2+ DISH…","""24 Jan HER2+ DISH+""",4,0.1,0,0,1,1,3,1,1
"""./Data_Chula/24 Jan HER2+ DISH…","""24 Jan HER2+ DISH+""",4,0.1,0,0,1,1,3,1,1
"""./Data_Chula/24 Jan HER2+ DISH…","""24 Jan HER2+ DISH+""",4,0.1,0,1,1,1,3,2,-1
"""./Data_Chula/24 Jan HER2+ DISH…","""24 Jan HER2+ DISH+""",4,0.1,0,1,1,1,3,2,-1
"""./Data_Chula/24 Jan HER2+ DISH…","""24 Jan HER2+ DISH+""",4,0.1,0,0,1,1,3,1,1
…,…,…,…,…,…,…,…,…,…,…
"""./Data_Chula/21 Sep HER2 score…","""21 Sep HER2 score 0 case 3""",4,0.1,0,1,0,0,0,1,0
"""./Data_Chula/21 Sep HER2 score…","""21 Sep HER2 score 0 case 3""",4,0.1,0,1,0,0,0,1,0
"""./Data_Chula/21 Sep HER2 score…","""21 Sep HER2 score 0 case 3""",4,0.1,0,1,0,0,0,1,0
"""./Data_Chula/21 Sep HER2 score…","""21 Sep HER2 score 0 case 3""",4,0.1,0,1,0,0,0,1,0


In [278]:
# fig, axs = plt.subplots(1, 2, figsize=(14, 5))
# bins = 20
# alpha = 0.05

# n_calib = len(prob_calib)
# q_level = np.ceil((n_calib+1) * (1-ALPHA)) / n_calib

# hist1 = np.histogram(1 - prob_calib[:,0][y_calib == 0], bins=bins)[0]
# hist1 = hist1 / hist1.sum()
# axs[0].bar(np.arange(bins), hist1, width=1)
# axs[0].set_xticks(ticks=np.arange(0, 20, 1), labels=np.arange(0, 20, 1))
# axs[0].plot([q_level * bins, q_level * bins], [0, 1])

# hist2 = np.histogram(1 - prob_calib[:,1][y_calib == 1], bins=bins)[0]
# hist2 = hist2 / hist2.sum()
# axs[1].bar(np.arange(bins), hist2, color="r", width=1)
# axs[1].set_xticks(ticks=np.arange(0, 20, 1), labels=np.arange(0, 20, 1));