In [21]:
import numpy as np
import polars as pl

In [22]:
# 1 select case
# 2 cross validation (train only no hyper opt)
# 3 calculate mean prediction of case level 
# - accuracy 
# - precision
# - recall
# - F1

In [23]:
df = pl.read_csv("/home/surayuth/her2/extracted_features/baseline_feat|level_128.csv")

In [24]:
min_count = 10
max_count = 30

selected_df = (
        df \
        .with_columns(
            pl.len().over("case")
            .alias("count")
        ) 
        .filter(
            pl.col("count") >= min_count
        ) 
        .with_columns(
            pl.min_horizontal(max_count, pl.col("count"))
            .alias("cap_max")
        ) 
        .with_columns(
            pl.arange(1, pl.len() + 1).over("case")
            .alias("case_idx")
        ) 
        .filter(
            pl.col("case_idx") <= pl.col("cap_max")
        )
    )


case_df = selected_df.group_by("case").agg(pl.col("label").min())

In [25]:
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

color_feat = ["color_feat"]
lbp_feat = [f"lbp{i}" for i in range(10)]
hara_feat = [
        "contrast", "dissim", "homo", "asm",
        "energy", "corrs", "entropy"
    ] 

selected_feat = ["color", "lbp", "hara"]
features = []
if "hole" in selected_feat:
    features += color_feat
if "lbp" in selected_feat:
    features += lbp_feat
if "hara" in selected_feat:
    features += hara_feat

for k in range(10):
    print(f"state: {k}")
    print("=" * 30)
    skf = StratifiedKFold(n_splits=4, random_state=k, shuffle=True)
    acc = []
    f1 = []
    auc = []
    for i, (train_index, test_index) in enumerate(skf.split(case_df.select("case"), case_df.select("label"))):
        train_case = case_df[train_index].select("case")
        test_case = case_df[test_index].select("case")
        train_df = selected_df.filter(pl.col("case").is_in(train_case)).select(*features, "label")
        test_df = selected_df.filter(pl.col("case").is_in(test_case)).select(*features, "label")

        X_train = train_df.drop("label").to_numpy()
        y_train = train_df.select("label").to_numpy().reshape(-1)

        X_test = test_df.drop("label").to_numpy()
        y_test = test_df.select("label").to_numpy().reshape(-1)

        #model = RandomForestClassifier(n_estimators=10)
        model = GradientBoostingClassifier()
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        acc.append(accuracy_score(y_test, y_pred))
        f1.append(f1_score(y_test, y_pred))
        auc.append(roc_auc_score(y_test, y_pred))
    print(np.mean(acc), np.mean(f1), np.mean(auc))

state: 0
0.7743948497485548 0.7898169768312241 0.7739769016914
state: 1
0.7916083774678198 0.8063456649722783 0.789568805717238
state: 2
0.8002651216995336 0.8151629884972572 0.7973362731238585
state: 3
0.7700036070272476 0.7875490519933762 0.7679488397824098
state: 4
0.7587744314768331 0.7675707576948236 0.7604354320105952
state: 5
0.769260320771664 0.7812382036680365 0.7665420149864155
state: 6
0.7810473580589805 0.7977965918786458 0.7824209966365214
state: 7
0.7694774637853596 0.7790290304079176 0.7694520281782004
state: 8
0.7723845111177474 0.7880551600755497 0.7713490809486534
state: 9
0.781549530257718 0.7933489642565978 0.7831004115723349


In [26]:
# level 8
# state: 0
# ==============================
# 0.7488508800271507 0.7657164627598658 0.748268890359457
# state: 1
# ==============================
# 0.7452149689873473 0.7550751379405882 0.744987709852827
# state: 2
# ==============================
# 0.7406790628636386 0.7518998877335172 0.7415026800441883
# state: 3
# ==============================
# 0.7577374467325011 0.7680383704427369 0.7532089102576551
# state: 4
# ==============================
# 0.7541877950464176 0.7702171994302747 0.7527684816973449
# state: 5
# ==============================
# 0.7596592945697114 0.7677087653041283 0.7577694850145011
# state: 6
# ==============================
# 0.7512687356175005 0.7617426539832652 0.7524431476852472
# state: 7
# ==============================
# 0.7318469817753285 0.7471749744704697 0.7330034260341411
# state: 8
# ==============================
# 0.7432958611668694 0.7596883618993681 0.7424572611843463
# state: 9
# ==============================
# 0.7455583288678863 0.7565676873230716 0.746171108818008

In [27]:
# level 16

# state: 0
# ==============================
# 0.712347315776149 0.722176499265892 0.7120017163568473
# state: 1
# ==============================
# 0.750006836625848 0.7712254729676793 0.7474593154797387
# state: 2
# ==============================
# 0.7547223996943145 0.7762329710633001 0.7536800372831292
# state: 3
# ==============================
# 0.7168957848801166 0.7315350982027391 0.7156097413172517
# state: 4
# ==============================
# 0.7574583311475238 0.7699086724003281 0.7537912659642565
# state: 5
# ==============================
# 0.7540571422227562 0.7713124224665199 0.7560585993449082
# state: 6
# ==============================
# 0.7452156749864991 0.7627805623771379 0.740505346968268
# state: 7
# ==============================
# 0.7279507538572076 0.752905376115133 0.725064396030968
# state: 8
# ==============================
# 0.7350519821712105 0.7540273429813007 0.7296540106702686
# state: 9
# ==============================
# 0.756926946280178 0.7793185361431626 0.7507225667518946

In [28]:
# 32

# state: 0
# ==============================
# 0.7403645905471878 0.7600958741740333 0.7359940346454854
# state: 1
# ==============================
# 0.7373904901993397 0.7514391561175626 0.7373315640235911
# state: 2
# ==============================
# 0.7531149812443451 0.7666335552764941 0.7498851145730304
# state: 3
# ==============================
# 0.7461171005593022 0.7700583339849298 0.742894257460136
# state: 4
# ==============================
# 0.7415703004118728 0.7597713245980255 0.7373041761178799
# state: 5
# ==============================
# 0.7612585172334123 0.7720026098126518 0.7544454542794321
# state: 6
# ==============================
# 0.7456623203774135 0.7579291500640545 0.7456401379088382
# state: 7
# ==============================
# 0.7476064845445634 0.7665594262120083 0.745038854014388
# state: 8
# ==============================
# 0.7480109408184205 0.7676980893234412 0.745535538014001
# state: 9
# ==============================
# 0.7496031081858276 0.7718811351600121 0.7455329613760586

In [29]:
# state: 0
# ==============================
# 0.7883921457822445 0.8047174139631045 0.7900871893244139
# state: 1
# ==============================
# 0.7857594847697456 0.8025315944345089 0.7841348292009703
# state: 2
# ==============================
# 0.7696655842140747 0.7864885019412412 0.766413361061999
# state: 3
# ==============================
# 0.8055594368740427 0.8203561978399165 0.8056026843960198
# state: 4
# ==============================
# 0.7901038018430977 0.8033051277719937 0.7870713283160762
# state: 5
# ==============================
# 0.7939206017447641 0.8086835406327242 0.7934083172331674
# state: 6
# ==============================
# 0.75441504758219 0.7681425099488016 0.7537593283465153
# state: 7
# ==============================
# 0.791840703603413 0.8094211534350535 0.7896041569607614
# state: 8
# ==============================
# 0.7587073098480523 0.7726079809118502 0.757784861509656
# state: 9
# ==============================
# 0.752689518578384 0.7617194702645947 0.7520012147100443

In [None]:
# # 128

# state: 0
# ==============================
# 0.7743948497485548 0.7898169768312241 0.7739769016914
# state: 1
# ==============================
# 0.7916083774678198 0.8063456649722783 0.789568805717238
# state: 2
# ==============================
# 0.8002651216995336 0.8151629884972572 0.7973362731238585
# state: 3
# ==============================
# 0.7700036070272476 0.7875490519933762 0.7679488397824098
# state: 4
# ==============================
# 0.7587744314768331 0.7675707576948236 0.7604354320105952
# state: 5
# ==============================
# 0.769260320771664 0.7812382036680365 0.7665420149864155
# state: 6
# ==============================
# 0.7810473580589805 0.7977965918786458 0.7824209966365214
# state: 7
# ==============================
# 0.7694774637853596 0.7790290304079176 0.7694520281782004
# state: 8
# ==============================
# 0.7723845111177474 0.7880551600755497 0.7713490809486534
# state: 9
# ==============================
# 0.781549530257718 0.7933489642565978 0.7831004115723349