Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(5):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
baseline_avg_auc = np.mean(auc_scores)
baseline_avg_f1 = np.mean(f1_scores)

print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

[0]	eval-auc:0.57363
[1]	eval-auc:0.55444
[2]	eval-auc:0.51426
[3]	eval-auc:0.50977
[4]	eval-auc:0.49553
[5]	eval-auc:0.49694
[6]	eval-auc:0.48658
[7]	eval-auc:0.48906
[8]	eval-auc:0.49920
[9]	eval-auc:0.49484
[10]	eval-auc:0.48515
[11]	eval-auc:0.48265
[12]	eval-auc:0.49464
[13]	eval-auc:0.49381
[14]	eval-auc:0.49174
[15]	eval-auc:0.51191
[16]	eval-auc:0.51014
[17]	eval-auc:0.50020
[18]	eval-auc:0.49777
[19]	eval-auc:0.49671
[20]	eval-auc:0.48821
[21]	eval-auc:0.49208
[22]	eval-auc:0.49287
[23]	eval-auc:0.49166
[24]	eval-auc:0.49169
[25]	eval-auc:0.48673
[26]	eval-auc:0.49382
[27]	eval-auc:0.49993
[28]	eval-auc:0.49225
[29]	eval-auc:0.49708
[30]	eval-auc:0.49152
[31]	eval-auc:0.50194
[32]	eval-auc:0.50769
[33]	eval-auc:0.50733
[34]	eval-auc:0.51216
[35]	eval-auc:0.51593
[36]	eval-auc:0.51262
[37]	eval-auc:0.51206
[38]	eval-auc:0.51814
[39]	eval-auc:0.51577
[40]	eval-auc:0.51873
[41]	eval-auc:0.51998
[42]	eval-auc:0.52024
[43]	eval-auc:0.51915
[44]	eval-auc:0.52024
[45]	eval-auc:0.5191

[165]	eval-auc:0.53842
[166]	eval-auc:0.53948
[167]	eval-auc:0.53899
[168]	eval-auc:0.54001
[169]	eval-auc:0.54078
[170]	eval-auc:0.54176
[171]	eval-auc:0.54278
[172]	eval-auc:0.54208
[173]	eval-auc:0.54042
[174]	eval-auc:0.54021
[175]	eval-auc:0.54062
[176]	eval-auc:0.54062
[177]	eval-auc:0.54001
[178]	eval-auc:0.54042
[179]	eval-auc:0.54086
[180]	eval-auc:0.54103
[181]	eval-auc:0.54326
[182]	eval-auc:0.54375
[183]	eval-auc:0.54412
[184]	eval-auc:0.54249
[185]	eval-auc:0.54127
[186]	eval-auc:0.54001
[187]	eval-auc:0.54066
[188]	eval-auc:0.53976
[189]	eval-auc:0.53968
[190]	eval-auc:0.53948
[191]	eval-auc:0.54098
[192]	eval-auc:0.54086
[193]	eval-auc:0.54058
[194]	eval-auc:0.54033
[195]	eval-auc:0.53972
[196]	eval-auc:0.54066
[197]	eval-auc:0.54074
[198]	eval-auc:0.53915
[199]	eval-auc:0.53907
[200]	eval-auc:0.54050
[201]	eval-auc:0.53972
[202]	eval-auc:0.54042
[203]	eval-auc:0.54070
[204]	eval-auc:0.53972
[0]	eval-auc:0.47569
[1]	eval-auc:0.46472
[2]	eval-auc:0.46425
[3]	eval-auc:0.47

[88]	eval-auc:0.51590
[89]	eval-auc:0.51675
[90]	eval-auc:0.51628
[91]	eval-auc:0.51557
[92]	eval-auc:0.51530
[93]	eval-auc:0.51563
[94]	eval-auc:0.51513
[95]	eval-auc:0.51631
[96]	eval-auc:0.51685
[97]	eval-auc:0.51736
[98]	eval-auc:0.51824
[99]	eval-auc:0.51634
[100]	eval-auc:0.51651
[101]	eval-auc:0.51624
[102]	eval-auc:0.51560
[103]	eval-auc:0.51580
[104]	eval-auc:0.51543
[105]	eval-auc:0.51699
[106]	eval-auc:0.51607
[107]	eval-auc:0.51705
[108]	eval-auc:0.51922
[109]	eval-auc:0.51932
[110]	eval-auc:0.52168
[111]	eval-auc:0.52206
[112]	eval-auc:0.52118
[113]	eval-auc:0.52185
[114]	eval-auc:0.52158
[115]	eval-auc:0.52233
[116]	eval-auc:0.52179
[117]	eval-auc:0.52084
[118]	eval-auc:0.52209
[119]	eval-auc:0.52182
[120]	eval-auc:0.52196
[121]	eval-auc:0.52263
[122]	eval-auc:0.52294
[123]	eval-auc:0.52419
[124]	eval-auc:0.52348
[125]	eval-auc:0.52476
[126]	eval-auc:0.52358
[127]	eval-auc:0.52239
[128]	eval-auc:0.52290
[129]	eval-auc:0.52250
[130]	eval-auc:0.52297
[131]	eval-auc:0.52324


In [2]:
print(f"Avg AUC: {baseline_avg_auc}")
print(f"Avg F1 Score: {baseline_avg_f1}")

Avg AUC: 0.5451685523118439
Avg F1 Score: 0.4936942754296624
