Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(5):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
baseline_avg_auc = np.mean(auc_scores)
baseline_avg_f1 = np.mean(f1_scores)

print(f"Baseline Avg AUC: {baseline_avg_auc}")
print(f"Basleline Avg F1 Score: {baseline_avg_f1}")

[0]	eval-auc:0.40676
[1]	eval-auc:0.41678
[2]	eval-auc:0.43716
[3]	eval-auc:0.45678
[4]	eval-auc:0.46667
[5]	eval-auc:0.47800
[6]	eval-auc:0.51347
[7]	eval-auc:0.51879
[8]	eval-auc:0.50625
[9]	eval-auc:0.51473
[10]	eval-auc:0.52135
[11]	eval-auc:0.52844
[12]	eval-auc:0.52210
[13]	eval-auc:0.53035
[14]	eval-auc:0.54396
[15]	eval-auc:0.54014
[16]	eval-auc:0.54042
[17]	eval-auc:0.54737
[18]	eval-auc:0.54998
[19]	eval-auc:0.54312
[20]	eval-auc:0.53949
[21]	eval-auc:0.54788
[22]	eval-auc:0.53809
[23]	eval-auc:0.52783
[24]	eval-auc:0.53026
[25]	eval-auc:0.52075
[26]	eval-auc:0.51030
[27]	eval-auc:0.50881
[28]	eval-auc:0.51497
[29]	eval-auc:0.51720
[30]	eval-auc:0.51590
[31]	eval-auc:0.51823
[32]	eval-auc:0.51497
[33]	eval-auc:0.51786
[34]	eval-auc:0.51972
[35]	eval-auc:0.52065
[36]	eval-auc:0.52107
[37]	eval-auc:0.51967
[38]	eval-auc:0.51944
[39]	eval-auc:0.52084
[40]	eval-auc:0.51776
[41]	eval-auc:0.52308
[42]	eval-auc:0.52084
[43]	eval-auc:0.51515
[44]	eval-auc:0.51636
[45]	eval-auc:0.5162

[147]	eval-auc:0.50959
[148]	eval-auc:0.50838
[149]	eval-auc:0.50783
[150]	eval-auc:0.50820
[151]	eval-auc:0.50866
[152]	eval-auc:0.50755
[153]	eval-auc:0.50774
[154]	eval-auc:0.50866
[155]	eval-auc:0.50885
[156]	eval-auc:0.50829
[157]	eval-auc:0.50653
[158]	eval-auc:0.50866
[159]	eval-auc:0.50736
[160]	eval-auc:0.50959
[161]	eval-auc:0.51172
[162]	eval-auc:0.51172
[163]	eval-auc:0.51033
[164]	eval-auc:0.50987
[165]	eval-auc:0.50931
[166]	eval-auc:0.50894
[167]	eval-auc:0.50959
[168]	eval-auc:0.51024
[169]	eval-auc:0.51014
[170]	eval-auc:0.51088
[171]	eval-auc:0.51116
[172]	eval-auc:0.51227
[173]	eval-auc:0.51237
[174]	eval-auc:0.51394
[175]	eval-auc:0.51515
[176]	eval-auc:0.51459
[177]	eval-auc:0.51579
[178]	eval-auc:0.51561
[179]	eval-auc:0.51579
[180]	eval-auc:0.51570
[181]	eval-auc:0.51542
[182]	eval-auc:0.51505
[183]	eval-auc:0.51450
[184]	eval-auc:0.51413
[185]	eval-auc:0.51440
[186]	eval-auc:0.51440
[187]	eval-auc:0.51431
[188]	eval-auc:0.51366
[189]	eval-auc:0.51496
[190]	eval-

[55]	eval-auc:0.46672
[56]	eval-auc:0.46561
[57]	eval-auc:0.46758
[58]	eval-auc:0.46770
[59]	eval-auc:0.46709
[60]	eval-auc:0.46733
[61]	eval-auc:0.46401
[62]	eval-auc:0.46352
[63]	eval-auc:0.46044
[64]	eval-auc:0.46561
[65]	eval-auc:0.45503
[66]	eval-auc:0.45687
[67]	eval-auc:0.45281
[68]	eval-auc:0.45466
[69]	eval-auc:0.45330
[70]	eval-auc:0.45293
[71]	eval-auc:0.45417
[72]	eval-auc:0.45564
[73]	eval-auc:0.45490
[74]	eval-auc:0.45293
[75]	eval-auc:0.45010
[76]	eval-auc:0.44850
[77]	eval-auc:0.44986
[78]	eval-auc:0.45010
[79]	eval-auc:0.45220
[80]	eval-auc:0.45244
[81]	eval-auc:0.45047
[82]	eval-auc:0.44998
[83]	eval-auc:0.44617
[84]	eval-auc:0.44838
[85]	eval-auc:0.45010
[86]	eval-auc:0.44924
[87]	eval-auc:0.44481
[88]	eval-auc:0.44617
[89]	eval-auc:0.44961
[90]	eval-auc:0.45097
[91]	eval-auc:0.45220
[92]	eval-auc:0.45576
[93]	eval-auc:0.45306
[94]	eval-auc:0.45675
[95]	eval-auc:0.45527
[96]	eval-auc:0.45453
[97]	eval-auc:0.45207
[98]	eval-auc:0.45269
[99]	eval-auc:0.45146
[100]	eval

[211]	eval-auc:0.53699
[212]	eval-auc:0.53783
[213]	eval-auc:0.53854
[214]	eval-auc:0.53776
[215]	eval-auc:0.53833
Baseline Avg AUC: 0.5040172863886122
Basleline Avg F1 Score: 0.48345938039997893


In [2]:
print(f"Baseline Avg AUC: {baseline_avg_auc}")
print(f"Basleline Avg F1 Score: {baseline_avg_f1}")

Baseline Avg AUC: 0.5040172863886122
Basleline Avg F1 Score: 0.48345938039997893
