Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(4):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
baseline_avg_auc = np.mean(auc_scores)
baseline_avg_f1 = np.mean(f1_scores)

print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

[0]	eval-auc:0.49254
[1]	eval-auc:0.49492
[2]	eval-auc:0.52627
[3]	eval-auc:0.55227
[4]	eval-auc:0.56930
[5]	eval-auc:0.57949
[6]	eval-auc:0.58911
[7]	eval-auc:0.59431
[8]	eval-auc:0.59404
[9]	eval-auc:0.60142
[10]	eval-auc:0.59114
[11]	eval-auc:0.57047
[12]	eval-auc:0.57654
[13]	eval-auc:0.57588
[14]	eval-auc:0.56774
[15]	eval-auc:0.56369
[16]	eval-auc:0.55965
[17]	eval-auc:0.55656
[18]	eval-auc:0.56924
[19]	eval-auc:0.56957
[20]	eval-auc:0.55869
[21]	eval-auc:0.56178
[22]	eval-auc:0.56517
[23]	eval-auc:0.55855
[24]	eval-auc:0.55618
[25]	eval-auc:0.55470
[26]	eval-auc:0.55104
[27]	eval-auc:0.54918
[28]	eval-auc:0.54677
[29]	eval-auc:0.54524
[30]	eval-auc:0.54961
[31]	eval-auc:0.55284
[32]	eval-auc:0.55432
[33]	eval-auc:0.55426
[34]	eval-auc:0.55803
[35]	eval-auc:0.56131
[36]	eval-auc:0.56733
[37]	eval-auc:0.56738
[38]	eval-auc:0.56689
[39]	eval-auc:0.56771
[40]	eval-auc:0.56878
[41]	eval-auc:0.56725
[42]	eval-auc:0.57752
[43]	eval-auc:0.57506
[44]	eval-auc:0.57211
[45]	eval-auc:0.5711

[156]	eval-auc:0.61717
[157]	eval-auc:0.61856
[158]	eval-auc:0.61684
[159]	eval-auc:0.61591
[160]	eval-auc:0.61684
[161]	eval-auc:0.61717
[162]	eval-auc:0.61711
[163]	eval-auc:0.61704
[164]	eval-auc:0.61684
[165]	eval-auc:0.61651
[166]	eval-auc:0.61644
[167]	eval-auc:0.61651
[168]	eval-auc:0.61625
[169]	eval-auc:0.61724
[170]	eval-auc:0.61823
[171]	eval-auc:0.61770
[172]	eval-auc:0.61810
[173]	eval-auc:0.61889
[174]	eval-auc:0.61856
[175]	eval-auc:0.62042
[176]	eval-auc:0.62088
[177]	eval-auc:0.61903
[178]	eval-auc:0.61982
[179]	eval-auc:0.61790
[180]	eval-auc:0.61797
[181]	eval-auc:0.61816
[182]	eval-auc:0.61816
[183]	eval-auc:0.61803
[184]	eval-auc:0.61750
[185]	eval-auc:0.61823
[186]	eval-auc:0.61658
[187]	eval-auc:0.61644
[188]	eval-auc:0.61505
[189]	eval-auc:0.61525
[190]	eval-auc:0.61585
[191]	eval-auc:0.61757
[192]	eval-auc:0.61850
[193]	eval-auc:0.61949
[194]	eval-auc:0.61883
[195]	eval-auc:0.61856
[196]	eval-auc:0.61737
[197]	eval-auc:0.61744
[198]	eval-auc:0.61625
[199]	eval-

[265]	eval-auc:0.57531
[266]	eval-auc:0.57553
[267]	eval-auc:0.57566
[268]	eval-auc:0.57575
[269]	eval-auc:0.57588
[270]	eval-auc:0.57690
[271]	eval-auc:0.57751
[272]	eval-auc:0.57729
[273]	eval-auc:0.57632
[274]	eval-auc:0.57654
[275]	eval-auc:0.57637
[276]	eval-auc:0.57676
[277]	eval-auc:0.57637
[278]	eval-auc:0.57663
[279]	eval-auc:0.57606
[280]	eval-auc:0.57610
[281]	eval-auc:0.57632
[282]	eval-auc:0.57535
[283]	eval-auc:0.57615
[284]	eval-auc:0.57646
[285]	eval-auc:0.57575
[286]	eval-auc:0.57654
[287]	eval-auc:0.57756
[288]	eval-auc:0.57778
[289]	eval-auc:0.57778
[290]	eval-auc:0.57720
[291]	eval-auc:0.57637
[292]	eval-auc:0.57690
[293]	eval-auc:0.57751
[294]	eval-auc:0.57813
[295]	eval-auc:0.57795
[296]	eval-auc:0.57734
[297]	eval-auc:0.57720
[298]	eval-auc:0.57584
[299]	eval-auc:0.57681
[300]	eval-auc:0.57756
[301]	eval-auc:0.57773
[302]	eval-auc:0.57769
[303]	eval-auc:0.57804
[304]	eval-auc:0.57756
[305]	eval-auc:0.57760
[306]	eval-auc:0.57756
[307]	eval-auc:0.57698
[308]	eval-

[179]	eval-auc:0.52942
[180]	eval-auc:0.52986
[181]	eval-auc:0.53026
[182]	eval-auc:0.52932
[183]	eval-auc:0.52883
[184]	eval-auc:0.52873
[185]	eval-auc:0.52770
[186]	eval-auc:0.52602
[187]	eval-auc:0.52489
[188]	eval-auc:0.52558
[189]	eval-auc:0.52740
[190]	eval-auc:0.52784
[191]	eval-auc:0.52750
[192]	eval-auc:0.52661
[193]	eval-auc:0.52661
[194]	eval-auc:0.52740
[195]	eval-auc:0.52696
[196]	eval-auc:0.52661
[197]	eval-auc:0.52602
[198]	eval-auc:0.52499
[199]	eval-auc:0.52494
[200]	eval-auc:0.52430
[201]	eval-auc:0.52277
[202]	eval-auc:0.52257
[203]	eval-auc:0.52380
[204]	eval-auc:0.52252
[205]	eval-auc:0.52183
[206]	eval-auc:0.52016
[207]	eval-auc:0.52094
[208]	eval-auc:0.52188
[209]	eval-auc:0.52055
[210]	eval-auc:0.52099
[211]	eval-auc:0.52025
[212]	eval-auc:0.52040
[213]	eval-auc:0.51956
[214]	eval-auc:0.51981
[215]	eval-auc:0.52114
[216]	eval-auc:0.52075
[217]	eval-auc:0.52045
[218]	eval-auc:0.52075
[219]	eval-auc:0.52134
[220]	eval-auc:0.52282
[221]	eval-auc:0.52380
[222]	eval-

In [2]:
print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

Avg AUC: 0.5706641538139068
Avg F1 Score: 0.49340357046278144
