In [None]:
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold, train_test_split, cross_val_score, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# === 1. Load CSV ===
# Raw materials table; the file name suggests every row contains at least one
# magnetic element (not checked here).
DATA_PATH = "Materials_With_AtLeast_OneMagneticElement.csv"
df = pd.read_csv(DATA_PATH)

# === 2. Parse elements column: "[Element Fe, Element O]" → ['Fe', 'O'] ===
def parse_elements(val):
    """Turn a stringified element list into a list of element symbols.

    Cells look like ``"[Element Fe, Element O]"``, possibly with quotes around
    each item (``"['Element Fe', 'Element O']"``). The last whitespace-separated
    word of every item is kept as the element symbol.

    Args:
        val: Raw value from the ``elements`` column (string or NaN/None).

    Returns:
        List of symbols, e.g. ``['Fe', 'O']``. Missing values and empty lists
        (``"[]"``) give ``[]``.
    """
    if pd.isna(val):
        return []
    symbols = []
    for item in str(val).strip().strip("[]").split(","):
        # Remove whitespace and surrounding quotes so that both "Element Fe"
        # and "'Element Fe'" reduce to "Fe".
        words = item.strip().strip("'\"").split()
        if words:  # "[]" or a trailing comma leaves an empty fragment
            symbols.append(words[-1])
    return symbols
# Replace each raw string cell with its list of element symbols.
df["elements"] = df["elements"].map(parse_elements)

# === 3. Extract crystal system from symmetry column ===
def extract_crystal_system(symmetry_str):
    """Pull the crystal-system name out of a stringified symmetry record.

    Searches for a fragment of the form
    ``crystal_system=<CrystalSystem.xxx: 'Name'>`` and returns ``Name``.

    Args:
        symmetry_str: Raw value from the ``symmetry`` column (string or NaN).

    Returns:
        The captured name, or ``"Unknown"`` when the value is missing or the
        fragment is not present.
    """
    if not pd.isna(symmetry_str):
        hit = re.search(
            r"crystal_system=<CrystalSystem\.\w+: '(\w+)'", str(symmetry_str)
        )
        if hit is not None:
            return hit.group(1)
    return "Unknown"
df["crystal_system"] = df["symmetry"].map(extract_crystal_system)

# === 4. Drop rows with missing essential features ===
# Numeric descriptors passed to the model unchanged.
numerical_features = [
    "cbm",
    "vbm",
    "energy_above_hull",
    "band_gap",
    "density_atomic",
    "numberofelements",
    "volume",
    "nsites",
    "density",
    "efermi",
    "formation_energy_per_atom",
]
# A row is kept only when every numeric feature and the target label exist.
required_columns = [*numerical_features, "ordering"]
df = df.dropna(subset=required_columns)

# === 5. Merge FM, FiM, and AFM into 'Magnetic'; other labels (e.g. NM) stay as is ===
def merge_classes(label):
    """Collapse the individual magnetic orderings into one 'Magnetic' class.

    Args:
        label: Original magnetic-ordering label.

    Returns:
        ``'Magnetic'`` for ``'FM'``, ``'FiM'`` or ``'AFM'``; any other label
        (such as ``'NM'``) is returned unchanged.
    """
    magnetic_orderings = ('FM', 'FiM', 'AFM')
    return 'Magnetic' if label in magnetic_orderings else label
df["ordering"] = df["ordering"].map(merge_classes)

# === Optional: Check new class distribution ===
ordering_counts = df["ordering"].value_counts()
print("Class distribution after merging:", ordering_counts, sep="\n")

# === 6. One-hot encode elements ===
mlb = MultiLabelBinarizer()
element_matrix = mlb.fit_transform(df["elements"])  # one 0/1 column per element
element_df = pd.DataFrame(element_matrix, columns=mlb.classes_)

# === 7. One-hot encode crystal system ===
crystal_df = pd.get_dummies(df["crystal_system"], prefix="crysys")

# === 8. Combine all features ===
# element_df has a fresh 0..n-1 index while the other pieces keep df's
# post-dropna index, so each part is re-indexed positionally before the
# column-wise concat to keep rows aligned.
feature_parts = [df[numerical_features], element_df, crystal_df]
X = pd.concat([part.reset_index(drop=True) for part in feature_parts], axis=1)

# === 9. Encode target labels ===
# LabelEncoder assigns codes in sorted order of the class names, so with the
# merged labels 'Magnetic' = 0 and 'NM' = 1 (le.classes_ holds the mapping).
le = LabelEncoder()
y = le.fit_transform(df["ordering"])

# === 10. Check data sufficiency, then train/test split ===
# train_test_split(..., stratify=y) raises ValueError when any class has fewer
# than 2 members, so the check has to happen *before* splitting.
label_counts = np.bincount(y, minlength=len(le.classes_))
enough_data = len(label_counts) > 1 and label_counts.min() >= 2

# === 11. Cross-validation setup and training ===
if enough_data:
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        X, y, test_size=0.20, stratify=y, random_state=42
    )

    skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)

    param_grid = {
        'n_estimators': [250],
        'max_depth': [65],
        'min_samples_split': [5],
        'class_weight': ['balanced']
    }

    grid_search = GridSearchCV(
        RandomForestClassifier(random_state=42),
        param_grid,
        scoring='f1_weighted',
        cv=skf,
        n_jobs=-1,
        verbose=1
    )

    grid_search.fit(X_train, y_train)
    best_clf = grid_search.best_estimator_

    print(f"\n✅ Best Parameters: {grid_search.best_params_}\n")

    # === 12. Evaluate cross-validation performance ===
    # GridSearchCV already scored every fold with this splitter, metric and
    # parameter set; reading those scores avoids refitting the forest
    # n_splits more times via cross_val_score.
    best_idx = grid_search.best_index_
    cv_scores = np.array([
        grid_search.cv_results_[f"split{fold}_test_score"][best_idx]
        for fold in range(skf.get_n_splits())
    ])
    print(f"Cross-validation scores: {cv_scores}")
    print(f"Mean CV score: {cv_scores.mean()}")

    # === 13. Evaluate on test set ===
    # Pin the full label set so the report and matrix always line up with
    # le.classes_, even if a class is absent from the hold-out labels/predictions.
    class_ids = np.arange(len(le.classes_))
    y_pred_holdout = best_clf.predict(X_holdout)
    print("Classification Report (Hold-Out Set):")
    print(classification_report(
        y_holdout, y_pred_holdout, labels=class_ids, target_names=le.classes_
    ))

    cm_holdout = confusion_matrix(y_holdout, y_pred_holdout, labels=class_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_holdout, display_labels=le.classes_)
    disp.plot(cmap="Blues", values_format="d")
    plt.title("Confusion Matrix (Hold-Out Set)")
    plt.tight_layout()
    plt.show()

else:
    print("❗ Not enough data or diversity in target classes to perform training.")
