In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

from magx import MagXExplainer

In [None]:
# Load seaborn's titanic dataset and keep a small set of informative columns.
titanic = sns.load_dataset("titanic").dropna(subset=["survived"])

features = ["pclass", "sex", "age", "sibsp", "parch", "fare", "embarked"]
target = "survived"

# Complete-case analysis: drop every row missing any selected feature or the target.
data = titanic[features + [target]].dropna()

X = data[features]
y = data[target].values

numeric_features = ["age", "sibsp", "parch", "fare"]
categorical_features = ["pclass", "sex", "embarked"]
# Every feature must be routed to exactly one transformer, otherwise it is silently dropped.
assert set(numeric_features) | set(categorical_features) == set(features)
assert not set(numeric_features) & set(categorical_features)

numeric_transformer = "passthrough"
categorical_transformer = OneHotEncoder(handle_unknown="ignore", drop=None)

# sparse_threshold=0 forces a dense ndarray output. Later cells wrap the transformed
# matrix in a pandas DataFrame, which needs dense input; with the default threshold
# the output type would depend on how sparse the one-hot columns make the data.
preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_features),
        ("cat", categorical_transformer, categorical_features),
    ],
    sparse_threshold=0,
)

# Stratify on the target so both splits keep the same survival ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
assert len(X_train) + len(X_test) == len(data)

display(X_train.head())


In [None]:
# Two candidate models sharing one preprocessing step.
# NOTE(review): both pipelines hold the *same* `preprocessor` object (not a copy).
# Later cells read `preprocessor` directly and appear to rely on it having been
# fitted here as a side effect — keep that in mind before cloning it.
log_reg = Pipeline(steps=[
    ("preprocess", preprocessor),
    ("model", LogisticRegression(max_iter=1000)),
])
rf_clf = Pipeline(steps=[
    ("preprocess", preprocessor),
    ("model", RandomForestClassifier(n_estimators=200, random_state=42)),
])

# Fit both models first (LogReg, then RF), then report accuracy on each split.
for pipe in (log_reg, rf_clf):
    pipe.fit(X_train, y_train)

for label, pipe in (("LogReg", log_reg), ("RF", rf_clf)):
    print(f"{label} train acc:", pipe.score(X_train, y_train))
    print(f"{label} test acc:", pipe.score(X_test, y_test))


In [None]:
# Recover the column names of the transformed (one-hot) feature space.
# Read the preprocessor from the fitted pipeline rather than the module-level
# `preprocessor`, so this cell does not depend on that object having been
# fitted as a side effect of another cell.
fitted_preprocessor = log_reg.named_steps["preprocess"]
ohe: OneHotEncoder = fitted_preprocessor.named_transformers_["cat"]
cat_feature_names = ohe.get_feature_names_out(categorical_features)
# Order must match the ColumnTransformer's transformer order: "num" first, then "cat".
all_feature_names = list(numeric_features) + list(cat_feature_names)

# Sanity check: one name per output column of the preprocessor.
assert len(all_feature_names) == fitted_preprocessor.transform(X_train.head(1)).shape[1]

print("Total transformed features:", len(all_feature_names))
print(all_feature_names)


In [None]:
# Transform X_train into the model's feature space for MagX.
# Use `transform`, not `fit_transform`: the preprocessor is already fitted and is
# embedded in both pipelines, so refitting it here would mutate their state.
X_train_transformed = preprocessor.transform(X_train)
X_train_df = pd.DataFrame(X_train_transformed, columns=all_feature_names)

# X_train_df lives in the one-hot space (e.g. "pclass_1", "sex_female"), so MagX
# must get the bare estimators that consume that space. The full pipelines select
# the raw columns ("pclass", "sex", "embarked"), which do not exist in X_train_df.
# NOTE(review): assumes MagX calls `model` on data shaped like X_train — confirm.
magx_logreg = MagXExplainer(
    model=log_reg.named_steps["model"],
    X_train=X_train_df,
    y_train=y_train,
    feature_names=all_feature_names,
    task_type="classification",
    class_names=["died", "survived"],
)

magx_rf = MagXExplainer(
    model=rf_clf.named_steps["model"],
    X_train=X_train_df,
    y_train=y_train,
    feature_names=all_feature_names,
    task_type="classification",
    class_names=["died", "survived"],
)


In [None]:
# Pick the first test passenger as a one-row DataFrame. `iloc[[0]]` keeps each
# column's dtype; `iloc[0].to_frame().T` would turn every column into `object`
# because the row mixes numbers and strings, yielding an object-dtype vector.
x0_frame = X_test.iloc[[0]]
x0_raw = X_test.iloc[0]  # Series form, used only for display

x0_pre = preprocessor.transform(x0_frame)
x0_df = pd.DataFrame(x0_pre, columns=all_feature_names)

print("Passenger features (raw):")
display(x0_raw)

print("Passenger prediction (LogReg):", log_reg.predict(x0_frame)[0])
print("Passenger prediction (RF):", rf_clf.predict(x0_frame)[0])


In [None]:
# Global explanations for both models: a plot of the top features plus a short
# text summary. `plt` comes from the imports cell at the top of the notebook.
TOP_K_PLOT = 10
TOP_K_TEXT = 5

global_explanations = {}
for label, explainer, theme in (
    ("LogReg", magx_logreg, "light"),
    ("RF", magx_rf, "dark"),
):
    global_explanations[label] = explainer.explain_global()
    explainer.plot_global(top_k=TOP_K_PLOT, theme=theme)
    plt.show()
    print(explainer.explain_global_text(top_k=TOP_K_TEXT))

# Keep the per-model names used before, for any downstream cells.
global_log = global_explanations["LogReg"]
global_rf = global_explanations["RF"]


In [None]:
# Local explanation of the selected passenger, in the transformed feature space.
x0_vector = x0_df.iloc[0]

# One entry per model: (section header, explainer, plot theme, metrics label).
local_runs = (
    ("=== Logistic Regression: Local Explanation ===", magx_logreg, "light",
     "Local explanation metrics (LogReg):"),
    ("\n=== Random Forest: Local Explanation ===", magx_rf, "dark",
     "Local explanation metrics (RF):"),
)

local_metrics = []
for header, explainer, theme, metrics_label in local_runs:
    print(header)
    explainer.plot_local(x0_vector, top_k=10, theme=theme)
    plt.show()

    print(explainer.explain_local_text(x0_vector, top_k=5, instance_id="Passenger 0"))

    run_metrics = explainer.evaluate_local(x0_vector, top_k=5)
    local_metrics.append(run_metrics)
    print(metrics_label, run_metrics)

metrics_log, metrics_rf = local_metrics
