In [3]:
import pandas as pd
import numpy as np
import shap
from sklearn.datasets import fetch_openml
import pytest
from torsha_ml.models import *

import plotly as plotly
from sklearn.datasets import make_classification, make_regression
from torsha_ml.evaluation.classification_plots import (
    create_classifier_explainer,
    plot_confusion_matrix,
    plot_precision_recall_curve,
    plot_roc_auc_curve,
    classification_plot,
    precision_plot,
    plot_lift_curve,
    plot_density_chart,
    plot_class_prediction_error)

from sklearn.linear_model import LinearRegression, Ridge
import warnings
warnings.filterwarnings('ignore')

In [5]:
external_models = [
    XGBClassifier(verbose=0), 
    CatBoostClassifier(),
    EarthClassifier(),
    ExplainableBoostingClassifier(), 
    SymbolicClassifier(),
    LinearTreeClassifier(RidgeClassifier()), 
    LinearForestClassifier(LinearRegression()), 
    LinearBoostClassifier(RidgeClassifier())]

In [8]:
X, y = make_classification(
            n_samples=500,
            n_features=10,
            n_informative=4,
            n_classes=4,
            random_state=41)

data = pd.DataFrame(X, columns=["X_" + str(i) for i in range(X.shape[1])])

In [9]:
model = SymbolicClassifier()
model.fit(data,y)

ValueError: y contains 4 class after sample_weight trimmed classes with zero weights, while 2 classes are required.

In [6]:
X, y = make_classification(
            n_samples=500,
            n_features=10,
            random_state=41)

data = pd.DataFrame(X, columns=["X_" + str(i) for i in range(X.shape[1])])
for model in external_models:
    model.fit(data, y)
    explainer = create_classifier_explainer(data, y, model)
    fig = plot_confusion_matrix(explainer, cutoff=None, percentage=False,
    normalize=None, binary=False, pos_label=None)
    isinstance(fig, plotly.graph_objs._figure.Figure)

Parameters: { "verbose" } might not be used.

  This could be a false alarm, with some parameters getting used by language bindings but
  then being mistakenly passed down to XGBoost core, or some parameter actually being used
  but getting flagged wrongly here. Please open an issue if you find any such cases.


Detected XGBClassifier model: Changing class type to XGBClassifierExplainer...
Note: model_output=='probability'. For XGBClassifier shap values normally get calculated against X_background, but paramater X_background=None, so using X instead
Generating self.shap_explainer = shap.TreeExplainer(model, X, model_output='probability', feature_perturbation='interventional')...
Note: Shap interaction values will not be available. If shap values in probability space are not necessary you can pass model_output='logodds' to get shap values in logodds without the need for a background dataset and also working shap interaction values...
<class 'explainerdashboard.explainers.XGBClassifierEx

In [4]:
for model in external_models:
    model.fit(data, y)
    explainer = create_classifier_explainer(data, y, model)
    fig = plot_precision_recall_curve(explainer, cutoff=None, pos_label=None)
    fig.show()
                

Parameters: { "verbose" } might not be used.

  This could be a false alarm, with some parameters getting used by language bindings but
  then being mistakenly passed down to XGBoost core, or some parameter actually being used
  but getting flagged wrongly here. Please open an issue if you find any such cases.


Detected XGBClassifier model: Changing class type to XGBClassifierExplainer...
model_output=='probability' does not work with multiclass XGBClassifier models, so settings model_output='logodds'...
Generating self.shap_explainer = shap.TreeExplainer(model)
<class 'explainerdashboard.explainers.XGBClassifierExplainer'>
Calculating pr auc curves...
Calculating prediction probabilities...
<class 'plotly.graph_objs._figure.Figure'>


Learning rate set to 0.067916
0:	learn: 1.3681689	total: 48.1ms	remaining: 48.1s
1:	learn: 1.3454924	total: 49.1ms	remaining: 24.5s
2:	learn: 1.3291813	total: 49.8ms	remaining: 16.5s
3:	learn: 1.3140799	total: 50.4ms	remaining: 12.5s
4:	learn: 1.2928688	total: 50.8ms	remaining: 10.1s
5:	learn: 1.2761510	total: 51.4ms	remaining: 8.51s
6:	learn: 1.2579977	total: 51.8ms	remaining: 7.35s
7:	learn: 1.2417881	total: 52.3ms	remaining: 6.49s
8:	learn: 1.2242856	total: 52.8ms	remaining: 5.82s
9:	learn: 1.2018635	total: 53.3ms	remaining: 5.28s
10:	learn: 1.1860800	total: 53.8ms	remaining: 4.84s
11:	learn: 1.1612073	total: 54.3ms	remaining: 4.47s
12:	learn: 1.1471120	total: 54.8ms	remaining: 4.16s
13:	learn: 1.1345823	total: 55.3ms	remaining: 3.9s
14:	learn: 1.1203126	total: 55.8ms	remaining: 3.67s
15:	learn: 1.1048415	total: 56.4ms	remaining: 3.47s
16:	learn: 1.0899932	total: 56.9ms	remaining: 3.29s
17:	learn: 1.0750455	total: 57.4ms	remaining: 3.13s
18:	learn: 1.0612760	total: 58.2ms	remaining: