Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(13):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
avg_auc_baseline = np.mean(auc_scores)
avg_f1_baseline = np.mean(f1_scores)

print(f"Avg AUC: {avg_auc_baseline}")
print(f"Avg F1 Score: {avg_f1_baseline}")

[0]	eval-auc:0.42433
[1]	eval-auc:0.42601
[2]	eval-auc:0.45545
[3]	eval-auc:0.42588
[4]	eval-auc:0.43815
[5]	eval-auc:0.42885
[6]	eval-auc:0.43169
[7]	eval-auc:0.49019
[8]	eval-auc:0.47004
[9]	eval-auc:0.46358
[10]	eval-auc:0.47934
[11]	eval-auc:0.46617
[12]	eval-auc:0.46643
[13]	eval-auc:0.45403
[14]	eval-auc:0.44112
[15]	eval-auc:0.45945
[16]	eval-auc:0.45480
[17]	eval-auc:0.46178
[18]	eval-auc:0.46281
[19]	eval-auc:0.47856
[20]	eval-auc:0.47624
[21]	eval-auc:0.48089
[22]	eval-auc:0.47753
[23]	eval-auc:0.47469
[24]	eval-auc:0.47624
[25]	eval-auc:0.46798
[26]	eval-auc:0.47314
[27]	eval-auc:0.48140
[28]	eval-auc:0.48295
[29]	eval-auc:0.48580
[30]	eval-auc:0.48657
[31]	eval-auc:0.50284
[32]	eval-auc:0.50026
[33]	eval-auc:0.48760
[34]	eval-auc:0.49613
[35]	eval-auc:0.49380
[36]	eval-auc:0.48812
[37]	eval-auc:0.48580
[38]	eval-auc:0.48605
[39]	eval-auc:0.48476
[40]	eval-auc:0.49354
[41]	eval-auc:0.49380
[42]	eval-auc:0.49587
[43]	eval-auc:0.49664
[44]	eval-auc:0.50852
[45]	eval-auc:0.5118

[93]	eval-auc:0.54876
[94]	eval-auc:0.54533
[95]	eval-auc:0.54716
[96]	eval-auc:0.54670
[97]	eval-auc:0.54762
[98]	eval-auc:0.54762
[99]	eval-auc:0.54602
[100]	eval-auc:0.54625
[101]	eval-auc:0.54853
[102]	eval-auc:0.54510
[103]	eval-auc:0.54418
[104]	eval-auc:0.54808
[105]	eval-auc:0.54670
[106]	eval-auc:0.53869
[107]	eval-auc:0.54212
[108]	eval-auc:0.53961
[109]	eval-auc:0.54281
[110]	eval-auc:0.53915
[111]	eval-auc:0.54006
[112]	eval-auc:0.53984
[113]	eval-auc:0.53686
[114]	eval-auc:0.53732
[115]	eval-auc:0.54052
[116]	eval-auc:0.53686
[117]	eval-auc:0.53663
[118]	eval-auc:0.53365
[119]	eval-auc:0.53274
[120]	eval-auc:0.53228
[121]	eval-auc:0.53068
[122]	eval-auc:0.53022
[123]	eval-auc:0.52587
[124]	eval-auc:0.52633
[125]	eval-auc:0.52495
[126]	eval-auc:0.52129
[127]	eval-auc:0.51992
[128]	eval-auc:0.51946
[129]	eval-auc:0.51946
[130]	eval-auc:0.51809
[131]	eval-auc:0.51763
[132]	eval-auc:0.51877
[133]	eval-auc:0.51374
[134]	eval-auc:0.51397
[135]	eval-auc:0.51099
[136]	eval-auc:0.5

[244]	eval-auc:0.63958
[245]	eval-auc:0.63958
[246]	eval-auc:0.63944
[247]	eval-auc:0.63944
[248]	eval-auc:0.63931
[249]	eval-auc:0.63889
[250]	eval-auc:0.64097
[251]	eval-auc:0.64167
[252]	eval-auc:0.64056
[253]	eval-auc:0.63931
[254]	eval-auc:0.63944
[255]	eval-auc:0.63944
[256]	eval-auc:0.63986
[257]	eval-auc:0.63986
[258]	eval-auc:0.64069
[259]	eval-auc:0.64028
[260]	eval-auc:0.64194
[261]	eval-auc:0.64111
[262]	eval-auc:0.64125
[263]	eval-auc:0.64222
[264]	eval-auc:0.64250
[265]	eval-auc:0.64125
[266]	eval-auc:0.64153
[267]	eval-auc:0.64111
[268]	eval-auc:0.64125
[269]	eval-auc:0.64028
[270]	eval-auc:0.64000
[271]	eval-auc:0.64097
[272]	eval-auc:0.63986
[273]	eval-auc:0.64056
[274]	eval-auc:0.64194
[275]	eval-auc:0.64194
[276]	eval-auc:0.64097
[277]	eval-auc:0.64167
[278]	eval-auc:0.64181
[279]	eval-auc:0.64111
[280]	eval-auc:0.64167
[281]	eval-auc:0.64264
[282]	eval-auc:0.64222
[283]	eval-auc:0.64208
[284]	eval-auc:0.64264
[285]	eval-auc:0.64208
[286]	eval-auc:0.64153
[287]	eval-

[83]	eval-auc:0.44784
[84]	eval-auc:0.44518
[85]	eval-auc:0.44668
[86]	eval-auc:0.44701
[87]	eval-auc:0.45332
[88]	eval-auc:0.44983
[89]	eval-auc:0.45183
[90]	eval-auc:0.45332
[91]	eval-auc:0.45515
[92]	eval-auc:0.45432
[93]	eval-auc:0.45332
[94]	eval-auc:0.45116
[95]	eval-auc:0.45365
[96]	eval-auc:0.45781
[97]	eval-auc:0.45880
[98]	eval-auc:0.45548
[99]	eval-auc:0.45515
[100]	eval-auc:0.45565
[101]	eval-auc:0.45664
[102]	eval-auc:0.45930
[103]	eval-auc:0.45664
[104]	eval-auc:0.45631
[105]	eval-auc:0.45764
[106]	eval-auc:0.45565
[107]	eval-auc:0.45831
[108]	eval-auc:0.45631
[109]	eval-auc:0.45316
[110]	eval-auc:0.45399
[111]	eval-auc:0.45482
[112]	eval-auc:0.45432
[113]	eval-auc:0.45581
[114]	eval-auc:0.45498
[115]	eval-auc:0.45831
[116]	eval-auc:0.45864
[117]	eval-auc:0.46213
[118]	eval-auc:0.46213
[119]	eval-auc:0.46329
[120]	eval-auc:0.46495
[121]	eval-auc:0.46561
[122]	eval-auc:0.46478
[123]	eval-auc:0.46595
[124]	eval-auc:0.46628
[125]	eval-auc:0.46528
[126]	eval-auc:0.46429
[127]

[26]	eval-auc:0.60678
[27]	eval-auc:0.61642
[28]	eval-auc:0.62324
[29]	eval-auc:0.62666
[30]	eval-auc:0.61441
[31]	eval-auc:0.61140
[32]	eval-auc:0.62726
[33]	eval-auc:0.62425
[34]	eval-auc:0.63007
[35]	eval-auc:0.62244
[36]	eval-auc:0.62023
[37]	eval-auc:0.62826
[38]	eval-auc:0.63890
[39]	eval-auc:0.64492
[40]	eval-auc:0.63870
[41]	eval-auc:0.64412
[42]	eval-auc:0.64613
[43]	eval-auc:0.64071
[44]	eval-auc:0.63729
[45]	eval-auc:0.64051
[46]	eval-auc:0.64713
[47]	eval-auc:0.64372
[48]	eval-auc:0.64171
[49]	eval-auc:0.63127
[50]	eval-auc:0.63649
[51]	eval-auc:0.63107
[52]	eval-auc:0.62726
[53]	eval-auc:0.62284
[54]	eval-auc:0.63770
[55]	eval-auc:0.63208
[56]	eval-auc:0.63930
[57]	eval-auc:0.63950
[58]	eval-auc:0.64552
[59]	eval-auc:0.64492
[60]	eval-auc:0.64853
[61]	eval-auc:0.65616
[62]	eval-auc:0.65998
[63]	eval-auc:0.65616
[64]	eval-auc:0.65375
[65]	eval-auc:0.64833
[66]	eval-auc:0.64552
[67]	eval-auc:0.64492
[68]	eval-auc:0.64653
[69]	eval-auc:0.64472
[70]	eval-auc:0.64894
[71]	eval-

[129]	eval-auc:0.50877
[130]	eval-auc:0.50692
[131]	eval-auc:0.50778
[132]	eval-auc:0.50840
[133]	eval-auc:0.50519
[134]	eval-auc:0.50469
[135]	eval-auc:0.50408
[136]	eval-auc:0.50494
[137]	eval-auc:0.50544
[138]	eval-auc:0.50618
[139]	eval-auc:0.50667
[140]	eval-auc:0.50371
[141]	eval-auc:0.50642
[142]	eval-auc:0.50457
[143]	eval-auc:0.50383
[144]	eval-auc:0.50124
[145]	eval-auc:0.49951
[146]	eval-auc:0.49951
[147]	eval-auc:0.49741
[148]	eval-auc:0.49666
[149]	eval-auc:0.49592
[150]	eval-auc:0.49518
[151]	eval-auc:0.49629
[152]	eval-auc:0.49728
[153]	eval-auc:0.49815
[154]	eval-auc:0.49914
[155]	eval-auc:0.49938
[156]	eval-auc:0.49988
[157]	eval-auc:0.49963
[158]	eval-auc:0.49778
[159]	eval-auc:0.49666
[160]	eval-auc:0.49419
[161]	eval-auc:0.49493
[162]	eval-auc:0.49444
[163]	eval-auc:0.49456
[164]	eval-auc:0.49222
[165]	eval-auc:0.49024
[166]	eval-auc:0.48950
[167]	eval-auc:0.48802
[168]	eval-auc:0.49012
[169]	eval-auc:0.48765
[170]	eval-auc:0.48987
[171]	eval-auc:0.49197
[172]	eval-

[57]	eval-auc:0.67861
[58]	eval-auc:0.67832
[59]	eval-auc:0.67852
[60]	eval-auc:0.67685
[61]	eval-auc:0.67773
[62]	eval-auc:0.68018
[63]	eval-auc:0.68204
[64]	eval-auc:0.68312
[65]	eval-auc:0.68077
[66]	eval-auc:0.67920
[67]	eval-auc:0.67460
[68]	eval-auc:0.67362
[69]	eval-auc:0.67822
[70]	eval-auc:0.67822
[71]	eval-auc:0.67705
[72]	eval-auc:0.67636
[73]	eval-auc:0.67685
[74]	eval-auc:0.67558
[75]	eval-auc:0.67460
[76]	eval-auc:0.67430
[77]	eval-auc:0.67372
[78]	eval-auc:0.67274
[79]	eval-auc:0.67009
[80]	eval-auc:0.66804
[81]	eval-auc:0.66716
[82]	eval-auc:0.66569
[83]	eval-auc:0.66588
[84]	eval-auc:0.66481
[85]	eval-auc:0.66206
[86]	eval-auc:0.66187
[87]	eval-auc:0.66089
[88]	eval-auc:0.65727
[89]	eval-auc:0.66157
[90]	eval-auc:0.66040
[91]	eval-auc:0.66255
[92]	eval-auc:0.66490
[93]	eval-auc:0.66363
[94]	eval-auc:0.66275
[95]	eval-auc:0.66304
[96]	eval-auc:0.66520
[97]	eval-auc:0.66265
[98]	eval-auc:0.65962
[99]	eval-auc:0.66011
[100]	eval-auc:0.66157
[101]	eval-auc:0.65903
[102]	ev

[189]	eval-auc:0.54901
[190]	eval-auc:0.54705
[191]	eval-auc:0.54350
[192]	eval-auc:0.54634
[193]	eval-auc:0.54670
[194]	eval-auc:0.54439
[195]	eval-auc:0.54350
[196]	eval-auc:0.54457
[197]	eval-auc:0.54439
[198]	eval-auc:0.54616
[199]	eval-auc:0.54759
[200]	eval-auc:0.54492
[201]	eval-auc:0.54474
[202]	eval-auc:0.54403
[203]	eval-auc:0.54368
[204]	eval-auc:0.54634
[205]	eval-auc:0.54705
[206]	eval-auc:0.54812
[207]	eval-auc:0.54865
[208]	eval-auc:0.54901
[209]	eval-auc:0.54830
[210]	eval-auc:0.54616
[211]	eval-auc:0.54723
[212]	eval-auc:0.54688
[213]	eval-auc:0.54634
[214]	eval-auc:0.54599
[215]	eval-auc:0.54616
[216]	eval-auc:0.54599
[217]	eval-auc:0.54155
[218]	eval-auc:0.54048
[0]	eval-auc:0.52888
[1]	eval-auc:0.57667
[2]	eval-auc:0.62837
[3]	eval-auc:0.58348
[4]	eval-auc:0.57503
[5]	eval-auc:0.58373
[6]	eval-auc:0.59786
[7]	eval-auc:0.58108
[8]	eval-auc:0.58512
[9]	eval-auc:0.58235
[10]	eval-auc:0.55914
[11]	eval-auc:0.54124
[12]	eval-auc:0.55839
[13]	eval-auc:0.56242
[14]	eval-au

[134]	eval-auc:0.51368
[135]	eval-auc:0.51442
[136]	eval-auc:0.51356
[137]	eval-auc:0.51319
[138]	eval-auc:0.51049
[139]	eval-auc:0.51172
[140]	eval-auc:0.51343
[141]	eval-auc:0.51319
[142]	eval-auc:0.51442
[143]	eval-auc:0.51184
[144]	eval-auc:0.51221
[145]	eval-auc:0.51221
[146]	eval-auc:0.51110
[147]	eval-auc:0.51466
[148]	eval-auc:0.51208
[149]	eval-auc:0.51221
[150]	eval-auc:0.50988
[151]	eval-auc:0.50656
[152]	eval-auc:0.50791
[153]	eval-auc:0.50791
[154]	eval-auc:0.50730
[155]	eval-auc:0.50902
[156]	eval-auc:0.50877
[157]	eval-auc:0.51000
[158]	eval-auc:0.51147
[159]	eval-auc:0.51061
[160]	eval-auc:0.51037
[161]	eval-auc:0.50975
[162]	eval-auc:0.50853
[163]	eval-auc:0.50877
[164]	eval-auc:0.50853
[165]	eval-auc:0.50853
[166]	eval-auc:0.50767
[167]	eval-auc:0.50804
[168]	eval-auc:0.50816
[169]	eval-auc:0.50693
[170]	eval-auc:0.50742
[171]	eval-auc:0.50497
[172]	eval-auc:0.50620
[173]	eval-auc:0.50632
[174]	eval-auc:0.50252
[175]	eval-auc:0.50141
[176]	eval-auc:0.50178
[177]	eval-

[291]	eval-auc:0.59291
[292]	eval-auc:0.59372
[293]	eval-auc:0.59439
[294]	eval-auc:0.59452
[295]	eval-auc:0.59599
[296]	eval-auc:0.59599
[297]	eval-auc:0.59465
[298]	eval-auc:0.59532
[299]	eval-auc:0.59519
[300]	eval-auc:0.59532
[301]	eval-auc:0.59639
[302]	eval-auc:0.59679
[303]	eval-auc:0.59786
[304]	eval-auc:0.59853
[305]	eval-auc:0.59799
[306]	eval-auc:0.59773
[307]	eval-auc:0.59706
[308]	eval-auc:0.59786
Avg AUC: 0.5418403523690837
Avg F1 Score: 0.48823523498773863


In [2]:
print(f"Baseline Avg AUC: {avg_auc_baseline}")
print(f"Baseline Avg F1 Score: {avg_f1_baseline}")

Baseline Avg AUC: 0.5418403523690837
Baseline Avg F1 Score: 0.48823523498773863
