In [6]:
from pathlib import Path

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_recall_curve, auc, classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelBinarizer

In [2]:
# Load the open-world train / test splits.
DATA_DIR = Path("../../../dataset/open_world")
LABEL_COL = "label"

train_df = pd.read_csv(DATA_DIR / "openworld_train.csv")
test_df = pd.read_csv(DATA_DIR / "openworld_test.csv")

# Both files must carry the same columns. `.values` is positional, so a test CSV
# whose columns are merely in a different order would otherwise silently feed the
# models misaligned features.
column_diff = set(train_df.columns) ^ set(test_df.columns)
assert not column_diff, f"train/test column mismatch: {sorted(column_diff)}"

# Feature / Target Split -- test features are selected by name, in train order.
feature_cols = [col for col in train_df.columns if col != LABEL_COL]

X_train = train_df[feature_cols].values
y_train = train_df[LABEL_COL].values

X_test = test_df[feature_cols].values
y_test = test_df[LABEL_COL].values

# A test label never seen in training has no column in `predict_proba`, so it
# cannot be scored per class later on; fail early with a clear message instead.
unseen_labels = np.setdiff1d(y_test, y_train)
assert unseen_labels.size == 0, f"labels in test but not in train: {unseen_labels}"

In [3]:
# Train the MLP with its best hyper-parameters (kept in one dict so the
# configuration can be inspected or logged as a unit).
# NOTE(review): the VotingClassifier below clones and refits its base estimators,
# so this standalone fit is only useful if `best_mlp` is also used on its own.
MLP_BEST_PARAMS = {
    "hidden_layer_sizes": (384, 512),
    "activation": "tanh",
    "solver": "adam",
    "learning_rate_init": 0.0009118650899147087,
    "alpha": 1.9155906182703047e-06,
    "batch_size": 64,
    "max_iter": 300,
    "random_state": 42,
}

best_mlp = MLPClassifier(**MLP_BEST_PARAMS)
best_mlp.fit(X_train, y_train)

In [7]:
# Train XGB (Best param.)
# The class count is derived from the training labels rather than hard-coded:
# the previous `num_class=96` contradicted its own "0~94" comment (= 95 classes).
# NOTE(review): xgboost's scikit-learn wrapper appears to set `num_class` from the
# labels at fit time anyway -- verify for the installed version; deriving it here
# keeps the stated value truthful either way.
n_classes = np.unique(y_train).size

xgb_model = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=n_classes,  # XGBoost expects labels encoded as 0..n_classes-1

    # === Optuna Best Params ===
    learning_rate=0.16862524911487684,
    max_depth=9,
    min_child_weight=4,
    subsample=0.8754925874208443,
    colsample_bytree=0.9905531435470902,
    gamma=0.004617282882739576,
    reg_lambda=1.3105248980670992,
    reg_alpha=0.0030352556227954,

    # === Recommended add-ons ===
    n_estimators=300,          # boosting rounds
    eval_metric="mlogloss",
    tree_method="hist",
    random_state=42
)

xgb_model.fit(X_train, y_train)

In [8]:
# Weighted soft-voting ensemble: class probabilities of the two models are
# averaged, with XGBoost weighted twice as heavily as the MLP.
# NOTE: VotingClassifier.fit clones every base estimator and trains the clones,
# so the models fitted in the previous cells are not reused -- training repeats.
ENSEMBLE_WEIGHTS = [1, 2]  # order matches base_estimators: [mlp, xgb]
base_estimators = [("mlp", best_mlp), ("xgb", xgb_model)]

voting_clf = VotingClassifier(
    estimators=base_estimators,
    voting="soft",
    weights=ENSEMBLE_WEIGHTS,
)
voting_clf.fit(X_train, y_train)

In [9]:
# Evaluation
# Pred probability: columns follow `voting_clf.classes_`.
test_proba = voting_clf.predict_proba(X_test)

# Pred label: soft-voting `predict` is the arg-max of the (weighted) averaged
# probabilities mapped back to class labels, so reuse `test_proba` instead of
# running both base models over X_test a second time.
test_pred = voting_clf.classes_[np.argmax(test_proba, axis=1)]

acc = accuracy_score(y_test, test_pred)
f1_macro = f1_score(y_test, test_pred, average='macro')
f1_micro = f1_score(y_test, test_pred, average='micro')
f1_weighted = f1_score(y_test, test_pred, average='weighted')

# One-vs-rest indicator matrix. Fitting on the model's classes (not on y_test)
# guarantees column k here matches column k of `test_proba`, even when some
# training class does not occur in the test set.
lb = LabelBinarizer().fit(voting_clf.classes_)
y_test_bin = lb.transform(y_test)

# binary safety check: LabelBinarizer emits a single column for two classes
if y_test_bin.shape[1] == 1:
    y_test_bin = np.hstack([1 - y_test_bin, y_test_bin])

# Per-class ROC-AUC and PR-AUC, macro-averaged over the classes that are
# actually scorable on this test set (a class with no positives has neither
# curve; ROC additionally needs at least one negative).
roc_aucs = []
pr_aucs = []
for k in range(y_test_bin.shape[1]):
    y_true_k = y_test_bin[:, k]
    y_score_k = test_proba[:, k]

    n_pos = y_true_k.sum()
    if n_pos == 0:
        continue

    if n_pos < y_true_k.size:
        roc_aucs.append(roc_auc_score(y_true_k, y_score_k))

    # Trapezoidal area under the PR curve (can be more optimistic than the
    # step-wise `average_precision_score`).
    prec, rec, _ = precision_recall_curve(y_true_k, y_score_k)
    pr_aucs.append(auc(rec, prec))

roc_auc_macro = np.mean(roc_aucs) if roc_aucs else float('nan')
pr_auc_macro = np.mean(pr_aucs) if pr_aucs else float('nan')

print("\n========== SOFT VOTING TEST RESULTS ==========")
print(f"Accuracy        : {acc:.4f}")
print(f"F1 (macro)      : {f1_macro:.4f}")
print(f"F1 (micro)      : {f1_micro:.4f}")
print(f"F1 (weighted)   : {f1_weighted:.4f}")
print(f"ROC-AUC (macro) : {roc_auc_macro:.4f}")
print(f"PR-AUC (macro)  : {pr_auc_macro:.4f}")

print("\nClassification Report:")
print(classification_report(y_test, test_pred, digits=4))


Accuracy        : 0.8101
F1 (macro)      : 0.7671
F1 (micro)      : 0.8101
F1 (weighted)   : 0.8056
ROC-AUC (macro) : 0.9922
PR-AUC (macro)  : 0.8320

Classification Report:
              precision    recall  f1-score   support

           0     0.8696    0.6667    0.7547        60
           1     0.9231    0.8000    0.8571        60
           2     0.9583    0.7667    0.8519        60
           3     0.8621    0.8333    0.8475        60
           4     0.9130    0.7000    0.7925        60
           5     0.7500    0.6000    0.6667        60
           6     0.7681    0.8833    0.8217        60
           7     0.7869    0.8000    0.7934        60
           8     0.7778    0.7000    0.7368        60
           9     0.7551    0.6167    0.6789        60
          10     0.8250    0.5500    0.6600        60
          11     0.8800    0.7333    0.8000        60
          12     0.8750    0.9333    0.9032        60
          13     0.6042    0.4833    0.5370        60
          14  