Kfold logo

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './l_data/'

for event in range(6):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
avg_auc = np.mean(auc_scores)
avg_f1 = np.mean(f1_scores)

print(f"Avg AUC: {avg_auc}")
print(f"Avg F1 Score: {avg_f1}")

[0]	eval-auc:0.49082
[1]	eval-auc:0.46601
[2]	eval-auc:0.44634
[3]	eval-auc:0.45322
[4]	eval-auc:0.47475
[5]	eval-auc:0.49877
[6]	eval-auc:0.48513
[7]	eval-auc:0.48167
[8]	eval-auc:0.49085
[9]	eval-auc:0.48513
[10]	eval-auc:0.48171
[11]	eval-auc:0.47338
[12]	eval-auc:0.48066
[13]	eval-auc:0.47966
[14]	eval-auc:0.47624
[15]	eval-auc:0.48003
[16]	eval-auc:0.48301
[17]	eval-auc:0.47743
[18]	eval-auc:0.48375
[19]	eval-auc:0.47698
[20]	eval-auc:0.46211
[21]	eval-auc:0.45438
[22]	eval-auc:0.46152
[23]	eval-auc:0.46404
[24]	eval-auc:0.46159
[25]	eval-auc:0.46546
[26]	eval-auc:0.47007
[27]	eval-auc:0.47089
[28]	eval-auc:0.47483
[29]	eval-auc:0.47103
[30]	eval-auc:0.47252
[31]	eval-auc:0.47959
[32]	eval-auc:0.47565
[33]	eval-auc:0.46962
[34]	eval-auc:0.47089
[35]	eval-auc:0.46739
[36]	eval-auc:0.45014
[37]	eval-auc:0.44798
[38]	eval-auc:0.43348
[39]	eval-auc:0.43207
[40]	eval-auc:0.43601
[41]	eval-auc:0.43675
[42]	eval-auc:0.44017
[43]	eval-auc:0.43779
[44]	eval-auc:0.43809
[45]	eval-auc:0.4392

[160]	eval-auc:0.61161
[161]	eval-auc:0.61060
[162]	eval-auc:0.61001
[163]	eval-auc:0.60996
[164]	eval-auc:0.60862
[165]	eval-auc:0.60905
[166]	eval-auc:0.60926
[167]	eval-auc:0.60894
[168]	eval-auc:0.60883
[169]	eval-auc:0.60990
[170]	eval-auc:0.60878
[171]	eval-auc:0.60803
[172]	eval-auc:0.60883
[173]	eval-auc:0.60835
[174]	eval-auc:0.60862
[175]	eval-auc:0.60766
[176]	eval-auc:0.60878
[177]	eval-auc:0.60841
[178]	eval-auc:0.60867
[179]	eval-auc:0.60915
[180]	eval-auc:0.60851
[181]	eval-auc:0.60942
[182]	eval-auc:0.60926
[183]	eval-auc:0.60873
[184]	eval-auc:0.60894
[185]	eval-auc:0.60846
[186]	eval-auc:0.60830
[187]	eval-auc:0.60835
[188]	eval-auc:0.60761
[189]	eval-auc:0.60798
[190]	eval-auc:0.60814
[191]	eval-auc:0.60777
[192]	eval-auc:0.60782
[193]	eval-auc:0.60729
[194]	eval-auc:0.60643
[195]	eval-auc:0.60707
[196]	eval-auc:0.60750
[197]	eval-auc:0.60809
[198]	eval-auc:0.60809
[199]	eval-auc:0.60809
[200]	eval-auc:0.60739
[201]	eval-auc:0.60707
[202]	eval-auc:0.60617
[203]	eval-

[107]	eval-auc:0.56560
[108]	eval-auc:0.56669
[109]	eval-auc:0.56470
[110]	eval-auc:0.56718
[111]	eval-auc:0.56758
[112]	eval-auc:0.56421
[113]	eval-auc:0.56579
[114]	eval-auc:0.56570
[115]	eval-auc:0.56639
[116]	eval-auc:0.56411
[117]	eval-auc:0.56520
[118]	eval-auc:0.56490
[119]	eval-auc:0.56233
[120]	eval-auc:0.56173
[121]	eval-auc:0.55965
[122]	eval-auc:0.55916
[123]	eval-auc:0.55975
[124]	eval-auc:0.56094
[125]	eval-auc:0.55985
[126]	eval-auc:0.55846
[127]	eval-auc:0.55836
[128]	eval-auc:0.55906
[129]	eval-auc:0.56015
[130]	eval-auc:0.55896
[131]	eval-auc:0.55688
[132]	eval-auc:0.55767
[133]	eval-auc:0.55678
[134]	eval-auc:0.55420
[135]	eval-auc:0.55499
[136]	eval-auc:0.55450
[137]	eval-auc:0.55351
[138]	eval-auc:0.55371
[139]	eval-auc:0.55232
[140]	eval-auc:0.55291
[141]	eval-auc:0.55252
[142]	eval-auc:0.55499
[143]	eval-auc:0.55628
[144]	eval-auc:0.55598
[145]	eval-auc:0.55499
[146]	eval-auc:0.55420
[147]	eval-auc:0.55232
[148]	eval-auc:0.55272
[149]	eval-auc:0.55242
[150]	eval-

[255]	eval-auc:0.60450
[256]	eval-auc:0.60249
[257]	eval-auc:0.60297
[258]	eval-auc:0.60369
[259]	eval-auc:0.60378
[260]	eval-auc:0.60225
[261]	eval-auc:0.60193
[262]	eval-auc:0.60161
[263]	eval-auc:0.60185
[264]	eval-auc:0.60233
[265]	eval-auc:0.60233
[266]	eval-auc:0.60249
[267]	eval-auc:0.60321
[268]	eval-auc:0.60329
[269]	eval-auc:0.60289
[270]	eval-auc:0.60080
[271]	eval-auc:0.60032
[272]	eval-auc:0.60225
[273]	eval-auc:0.60185
[274]	eval-auc:0.60209
[275]	eval-auc:0.60225
[276]	eval-auc:0.60257
[277]	eval-auc:0.60257
[278]	eval-auc:0.60313
[279]	eval-auc:0.60233
[280]	eval-auc:0.60273
[281]	eval-auc:0.60209
[282]	eval-auc:0.60257
[283]	eval-auc:0.60233
[284]	eval-auc:0.60225
[285]	eval-auc:0.60265
[286]	eval-auc:0.60273
[287]	eval-auc:0.60233
[288]	eval-auc:0.60265
[289]	eval-auc:0.60305
[290]	eval-auc:0.60233
[291]	eval-auc:0.60177
[292]	eval-auc:0.60193
[293]	eval-auc:0.60257
[294]	eval-auc:0.60289
[295]	eval-auc:0.60177
[296]	eval-auc:0.60313
[297]	eval-auc:0.60273
[298]	eval-

[179]	eval-auc:0.56652
[180]	eval-auc:0.56675
[181]	eval-auc:0.56523
[182]	eval-auc:0.56523
[183]	eval-auc:0.56605
[184]	eval-auc:0.56605
[185]	eval-auc:0.56523
[186]	eval-auc:0.56535
[187]	eval-auc:0.56640
[188]	eval-auc:0.56640
[189]	eval-auc:0.56699
[190]	eval-auc:0.56769
[191]	eval-auc:0.56465
[192]	eval-auc:0.56535
[193]	eval-auc:0.56722
[194]	eval-auc:0.56628
[195]	eval-auc:0.56535
[196]	eval-auc:0.56547
[197]	eval-auc:0.56547
[198]	eval-auc:0.56477
[199]	eval-auc:0.56593
[200]	eval-auc:0.56628
[201]	eval-auc:0.56722
[202]	eval-auc:0.56769
[203]	eval-auc:0.57026
[204]	eval-auc:0.56862
[205]	eval-auc:0.56932
[206]	eval-auc:0.56979
[207]	eval-auc:0.56921
[208]	eval-auc:0.57084
[209]	eval-auc:0.57201
[210]	eval-auc:0.57201
[211]	eval-auc:0.56897
[212]	eval-auc:0.57073
[213]	eval-auc:0.56979
[214]	eval-auc:0.56944
[215]	eval-auc:0.57108
[216]	eval-auc:0.57084
[217]	eval-auc:0.57178
[218]	eval-auc:0.57119
[219]	eval-auc:0.56792
[220]	eval-auc:0.56699
[221]	eval-auc:0.56839
[222]	eval-

[536]	eval-auc:0.57704
[537]	eval-auc:0.57669
[538]	eval-auc:0.57564
[539]	eval-auc:0.57564
[540]	eval-auc:0.57692
[541]	eval-auc:0.57622
[542]	eval-auc:0.57552
[543]	eval-auc:0.57575
[544]	eval-auc:0.57599
[545]	eval-auc:0.57622
[546]	eval-auc:0.57564
[547]	eval-auc:0.57435
[548]	eval-auc:0.57435
[549]	eval-auc:0.57564
[550]	eval-auc:0.57552
[551]	eval-auc:0.57540
[552]	eval-auc:0.57587
[553]	eval-auc:0.57482
[554]	eval-auc:0.57494
[555]	eval-auc:0.57412
[556]	eval-auc:0.57236
[557]	eval-auc:0.57271
[558]	eval-auc:0.57435
[559]	eval-auc:0.57388
[560]	eval-auc:0.57307
[561]	eval-auc:0.57307
[562]	eval-auc:0.57435
[563]	eval-auc:0.57482
[564]	eval-auc:0.57470
[565]	eval-auc:0.57482
[566]	eval-auc:0.57575
[567]	eval-auc:0.57423
[568]	eval-auc:0.57505
[569]	eval-auc:0.57365
[570]	eval-auc:0.57400
[571]	eval-auc:0.57423
[572]	eval-auc:0.57458
[573]	eval-auc:0.57330
[574]	eval-auc:0.57412
[575]	eval-auc:0.57365
[576]	eval-auc:0.57435
[577]	eval-auc:0.57505
[578]	eval-auc:0.57494
[579]	eval-

In [2]:
print(f"Avg AUC: {avg_auc}")
print(f"Avg F1 Score: {avg_f1}")

Avg AUC: 0.542855765906472
Avg F1 Score: 0.5359202595109523
