Baseline

In [1]:
import xgboost as xgb
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

auc_scores = []
f1_scores = []

params = {
    'max_depth': 6,
    'eta': 0.3,
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    "gamma": 0,
    "subsample": 1,
    "colsample_bytree": 1,
    "colsample_bylevel": 1,
    "scale_pos_weight": 1,
    "reg_alpha": 0,
    "reg_lambda": 1,
}

data_folder = './data_baseline_2/'

for event in range(4):

    # load data
    dtrain = xgb.DMatrix(data_folder + '{}_train.data'.format(event))
    dtest = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    deval = xgb.DMatrix(data_folder + '{}_val.data'.format(event))
    
    # train
    evallist = [(dtrain, 'train'), (deval, 'eval')]
    bst = xgb.train(params, dtrain, num_boost_round=1000, evals=[(dtest, 'eval')], early_stopping_rounds=200)
    # predict probabilities
    y_pred_proba = bst.predict(dtest)
    # Convert predicted probabilities to class labels
    y_pred_labels = (y_pred_proba > 0.5).astype(int)
    
    y_real = dtest.get_label()
    
    # Calculate AUC
    auc = roc_auc_score(y_real, y_pred_proba)
    auc_scores.append(auc)
    
    # Calculate F1 score
    f1 = f1_score(y_real, y_pred_labels, average='macro')
    f1_scores.append(f1)
    
# Average AUC and F1 scores across all groups
baseline_avg_auc = np.mean(auc_scores)
baseline_avg_f1 = np.mean(f1_scores)

print(f"Baseline Avg AUC: {baseline_avg_auc}")
print(f"Basleline Avg F1 Score: {baseline_avg_f1}")

[0]	eval-auc:0.49584
[1]	eval-auc:0.58052
[2]	eval-auc:0.56837
[3]	eval-auc:0.55367
[4]	eval-auc:0.55710
[5]	eval-auc:0.54555
[6]	eval-auc:0.55120
[7]	eval-auc:0.55694
[8]	eval-auc:0.55613
[9]	eval-auc:0.53663
[10]	eval-auc:0.53210
[11]	eval-auc:0.52298
[12]	eval-auc:0.51361
[13]	eval-auc:0.51991
[14]	eval-auc:0.52310
[15]	eval-auc:0.52706
[16]	eval-auc:0.53130
[17]	eval-auc:0.53687
[18]	eval-auc:0.54252
[19]	eval-auc:0.53840
[20]	eval-auc:0.53550
[21]	eval-auc:0.54107
[22]	eval-auc:0.54276
[23]	eval-auc:0.54115
[24]	eval-auc:0.54163
[25]	eval-auc:0.54692
[26]	eval-auc:0.53618
[27]	eval-auc:0.52697
[28]	eval-auc:0.52625
[29]	eval-auc:0.53303
[30]	eval-auc:0.52924
[31]	eval-auc:0.53424
[32]	eval-auc:0.52019
[33]	eval-auc:0.52382
[34]	eval-auc:0.51672
[35]	eval-auc:0.51882
[36]	eval-auc:0.51539
[37]	eval-auc:0.51853
[38]	eval-auc:0.51829
[39]	eval-auc:0.51789
[40]	eval-auc:0.51458
[41]	eval-auc:0.51619
[42]	eval-auc:0.51595
[43]	eval-auc:0.51191
[44]	eval-auc:0.51490
[45]	eval-auc:0.5154

[164]	eval-auc:0.61642
[165]	eval-auc:0.61407
[166]	eval-auc:0.61254
[167]	eval-auc:0.61274
[168]	eval-auc:0.61274
[169]	eval-auc:0.61274
[170]	eval-auc:0.61284
[171]	eval-auc:0.61305
[172]	eval-auc:0.61356
[173]	eval-auc:0.61366
[174]	eval-auc:0.61315
[175]	eval-auc:0.61284
[176]	eval-auc:0.61315
[177]	eval-auc:0.61325
[178]	eval-auc:0.61336
[179]	eval-auc:0.61397
[180]	eval-auc:0.61417
[181]	eval-auc:0.61509
[182]	eval-auc:0.61591
[183]	eval-auc:0.61622
[184]	eval-auc:0.61591
[185]	eval-auc:0.61612
[186]	eval-auc:0.61581
[187]	eval-auc:0.61673
[188]	eval-auc:0.61601
[189]	eval-auc:0.61806
[190]	eval-auc:0.61816
[191]	eval-auc:0.61806
[192]	eval-auc:0.61867
[193]	eval-auc:0.61929
[194]	eval-auc:0.61929
[195]	eval-auc:0.61949
[196]	eval-auc:0.61755
[197]	eval-auc:0.61970
[198]	eval-auc:0.61959
[199]	eval-auc:0.61980
[200]	eval-auc:0.61908
[201]	eval-auc:0.61755
[202]	eval-auc:0.61775
[203]	eval-auc:0.61806
[204]	eval-auc:0.61693
[205]	eval-auc:0.61734
[206]	eval-auc:0.61560
[207]	eval-

[70]	eval-auc:0.50598
[71]	eval-auc:0.51277
[72]	eval-auc:0.51220
[73]	eval-auc:0.50621
[74]	eval-auc:0.50230
[75]	eval-auc:0.51001
[76]	eval-auc:0.50978
[77]	eval-auc:0.51738
[78]	eval-auc:0.51784
[79]	eval-auc:0.51450
[80]	eval-auc:0.51220
[81]	eval-auc:0.51059
[82]	eval-auc:0.50426
[83]	eval-auc:0.50426
[84]	eval-auc:0.50449
[85]	eval-auc:0.50713
[86]	eval-auc:0.50736
[87]	eval-auc:0.50345
[88]	eval-auc:0.50276
[89]	eval-auc:0.50219
[90]	eval-auc:0.50299
[91]	eval-auc:0.49919
[92]	eval-auc:0.49988
[93]	eval-auc:0.49827
[94]	eval-auc:0.49804
[95]	eval-auc:0.49747
[96]	eval-auc:0.49517
[97]	eval-auc:0.49862
[98]	eval-auc:0.50138
[99]	eval-auc:0.50012
[100]	eval-auc:0.50150
[101]	eval-auc:0.50334
[102]	eval-auc:0.50449
[103]	eval-auc:0.50587
[104]	eval-auc:0.50483
[105]	eval-auc:0.50518
[106]	eval-auc:0.50391
[107]	eval-auc:0.50391
[108]	eval-auc:0.50345
[109]	eval-auc:0.50748
[110]	eval-auc:0.50598
[111]	eval-auc:0.50621
[112]	eval-auc:0.50794
[113]	eval-auc:0.50840
[114]	eval-auc:0.5

In [2]:
print(f"Baseline Avg AUC: {baseline_avg_auc}")
print(f"Basleline Avg F1 Score: {baseline_avg_f1}")

Baseline Avg AUC: 0.5207592902450531
Basleline Avg F1 Score: 0.4837822142249584
