In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Retrain an MLP using a SHAP feature list and Optuna best_params.
"""

import json
import numpy as np
import pandas as pd
from pathlib import Path
from joblib import dump
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, FunctionTransformer, LabelEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report

# ================== PATHS / CONFIG ==================
# Excel workbook to train on, resolved relative to the working directory.
# <-- adjust to point at your copy of the dataset.
DATA_PATH: Path = Path("dataset_enriched_filtered.xlsx")
# Name of the column that holds the class label to predict.
TARGET_COL: str = "phase_disambiguated"

# Seed reused by the train/test split and by the MLP, plus the holdout share.
RANDOM_STATE: int = 42
TEST_SIZE: float = 0.2

# ================== List of features selected by SHAP ==================
# One column name per line. The order is significant: the model is trained on
# the columns in exactly this sequence. None of the names contain whitespace,
# so a plain str.split() recovers the list.
SELECTED_FEATURES = """
    critical_temperature_max
    electrical_resist_sum
    critical_temperature_mean
    critical_temperature_sum
    critical_temperature_stdv
    electrical_resist_stdv
    electrical_resist_mean
    vickers_hardness_min
    electrical_resist_max
    atomic_en_sanderson_min_to_max_ratio
    electrical_resist_range
    atomic_hfu_std_over_range
    gibbs-oxides_min_to_max_ratio
    atomic_ea_min_to_max_ratio
    rigidity_mod_min_to_max_ratio
    deltacp-oxides_min_to_max_ratio
    deltacp-oxides_cov
    reflectivity_min
    covalent_rad_emp_sum
    poissons_ratio_min_to_max_ratio
    enthalpy-oxides_max
    thermal_conduct_min
    covalent_rad_range
    deltacp-oxides_min
    atomic_ea_std_over_range
    bulk_mod_min_to_max_ratio
    gibbs-oxides_max
    entropy-oxides_mean
    atomic_enc_sum
    covalent_rad_min_to_max_ratio
    coeff_of_lte_std_over_range
    atomic_orbital_radii_cov
    valence_d_electrons_cov
    has_zero_paul_ionic_radii_min
    vec_stdv
    atomic_en_sanderson_min
    enthalpy-oxides_std_over_range
    refract_index_stdv
    covalent_rad_emp_stdv
    oxidation_std_over_range
    rigidity_mod_std_over_range
    coeff_of_lte_min
    vec_sum
    covalent_rad_stdv
    deltacp-oxides_range
    z_std_over_range
    atomic_enc_cov
    atomic_spacegroupnum_std_over_range
    rigidity_mod_range
    deltacp-oxides_std_over_range
    boiling_point_stdv
    atomic_en_allen__max
    atomic_hatm_mean
    atomic_en_sanderson_max
    atomic_en_allredroch_max
    atomic_ebe__range
""".split()

# ================== BEST PARAMS (from your Optuna run) ==================
# Winning trial's hyper-parameters, presumably keyed exactly as in the Optuna
# search space. main() turns them into MLPClassifier arguments:
# "n_layers" plus "n_units_l<i>" give the hidden-layer widths, and
# "batch_choice" is either "auto" or an integer batch size.
best_params = dict(
    n_layers=1,
    n_units_l1=24,
    activation="tanh",
    solver="adam",
    alpha=0.0023404113539471617,
    tol=1.2324561293087634e-05,
    max_iter=30000,
    shuffle=True,
    warm_start=False,
    batch_choice="auto",
    learning_rate_init=0.001143876930032853,
    beta_1=0.8891631637090696,
    beta_2=0.999513887975822,
    epsilon=1.6097066118907405e-10,
)


def main():
    """Retrain the MLP on the SHAP-selected features and report holdout scores.

    Steps:
      1. Load the Excel workbook at ``DATA_PATH`` and drop rows whose
         ``TARGET_COL`` is missing.
      2. Check that every column in ``SELECTED_FEATURES`` exists and keep only
         those, in list order. Coerce them to numeric and map +/-inf to NaN.
      3. Label-encode the target and make a stratified holdout split of size
         ``TEST_SIZE`` seeded with ``RANDOM_STATE``.
      4. Fit ``median impute -> RobustScaler -> MLPClassifier``, with the
         classifier configured from ``best_params``.
      5. Print holdout accuracy, macro-F1 and a per-class report that uses the
         original class names.

    Returns:
        tuple: ``(pipe, le)``, the fitted sklearn ``Pipeline`` and the fitted
        ``LabelEncoder``. ``pipe.predict`` returns encoded labels; map them
        back with ``le.inverse_transform``.

    Raises:
        FileNotFoundError: if ``DATA_PATH`` does not exist.
        KeyError: if the target column or any selected feature is missing.
    """
    # 1) Load dataset
    if not DATA_PATH.exists():
        raise FileNotFoundError(f"Dataset not found: {DATA_PATH}")

    df = pd.read_excel(DATA_PATH)
    if TARGET_COL not in df.columns:
        raise KeyError(f"Target column '{TARGET_COL}' not found in dataset. Available: {list(df.columns)[:10]} ...")

    df = df.dropna(subset=[TARGET_COL]).reset_index(drop=True)
    y_raw = df[TARGET_COL]
    X_full = df.drop(columns=[TARGET_COL])

    # 2) Ensure all SELECTED_FEATURES are present
    missing = [c for c in SELECTED_FEATURES if c not in X_full.columns]
    if missing:
        # Message (Portuguese): "The dataset does not contain ALL the required
        # fixed features. Missing (n): [...]. Make sure these columns exist
        # before training."
        raise KeyError(
            "O dataset não contém TODAS as features fixas requeridas.\n"
            f"Faltando ({len(missing)}): {missing[:10]}{' ...' if len(missing) > 10 else ''}\n"
            "Garanta que estas colunas existam antes de treinar."
        )

    # Keep only the selected features, in fixed order. Coerce to numeric FIRST
    # and only then map +/-inf to NaN. pd.to_numeric turns text cells such as
    # "inf"/"-inf" into real infinities, which would otherwise reach
    # SimpleImputer (it rejects infinite values).
    X = (
        X_full[SELECTED_FEATURES]
        .apply(pd.to_numeric, errors="coerce")
        .replace([np.inf, -np.inf], np.nan)
    )
    # NOTE(review): a column that is entirely NaN at this point is likely
    # dropped by SimpleImputer (keep_empty_features=False in recent sklearn).
    # The model would then see fewer features than SELECTED_FEATURES lists.
    # Confirm against the installed sklearn version.

    # 3) Split + Label encode
    le = LabelEncoder()
    y = le.fit_transform(y_raw)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, stratify=y, random_state=RANDOM_STATE
    )

    # 4) Preprocessing (NO LAMBDA): impute -> scale.
    # Presumably lambda-free so that the fitted pipeline can be pickled or
    # joblib-dumped.
    pre = ColumnTransformer(
        transformers=[
            ("num", Pipeline([
                ("imputer", SimpleImputer(strategy="median")),
                ("scaler", RobustScaler()),
            ]), SELECTED_FEATURES),
        ],
        remainder="drop",
        verbose_feature_names_out=False
    )

    # 5) MLP from best_params. A missing key falls back to the default given
    # in each .get() call.
    n_layers = int(best_params.get("n_layers", 1))
    hidden_layer_sizes = tuple(int(best_params.get(f"n_units_l{i+1}", 100)) for i in range(n_layers))
    batch_choice = best_params.get("batch_choice", "auto")
    batch_size = "auto" if (isinstance(batch_choice, str) and batch_choice.lower() == "auto") else int(batch_choice)

    clf = MLPClassifier(
        hidden_layer_sizes=hidden_layer_sizes,
        activation=best_params.get("activation", "relu"),
        solver=best_params.get("solver", "adam"),
        alpha=float(best_params.get("alpha", 1e-4)),
        batch_size=batch_size,
        learning_rate="constant",
        learning_rate_init=float(best_params.get("learning_rate_init", 1e-3)),
        max_iter=int(best_params.get("max_iter", 30000)),
        shuffle=bool(best_params.get("shuffle", True)),
        random_state=RANDOM_STATE,
        tol=float(best_params.get("tol", 1e-4)),
        warm_start=bool(best_params.get("warm_start", False)),
        early_stopping=False,
        validation_fraction=0.1,  # only used by sklearn when early_stopping=True
        n_iter_no_change=10,
        beta_1=float(best_params.get("beta_1", 0.9)),
        beta_2=float(best_params.get("beta_2", 0.999)),
        epsilon=float(best_params.get("epsilon", 1e-8)),
    )

    pipe = Pipeline([("pre", pre), ("clf", clf)])

    # Fit
    pipe.fit(X_train, y_train)

    y_hat = pipe.predict(X_test)
    acc = accuracy_score(y_test, y_hat)
    f1m = f1_score(y_test, y_hat, average="macro")
    print(f"Holdout Accuracy: {acc:.4f} | F1-macro: {f1m:.4f}")
    # Label rows with the original class names instead of encoded indices.
    # Passing `labels` pins every encoded class, so target_names still lines
    # up if a class happens to be absent from the holdout split.
    report = classification_report(
        y_test,
        y_hat,
        labels=np.arange(len(le.classes_)),
        target_names=[str(c) for c in le.classes_],
    )
    print("Classification report:\n", report)

    return pipe, le


if __name__ == "__main__":
    main()


Holdout Accuracy: 0.9790 | F1-macro: 0.9766
Classification report (label indices):
               precision    recall  f1-score   support

           0       1.00      0.97      0.98        96
           1       0.94      1.00      0.97        47

    accuracy                           0.98       143
   macro avg       0.97      0.98      0.98       143
weighted avg       0.98      0.98      0.98       143

