## OmniXAI: A Library for Explainable AI

**Introduction:**

**OmniXAI (short for Omni eXplainable AI)** is a Python machine-learning library for explainable AI (XAI). It offers omni-way explainable AI and interpretable machine learning capabilities that address many pain points in explaining the decisions machine learning models make in practice. OmniXAI aims to be a one-stop, comprehensive library that makes explainable AI easy for data scientists, ML researchers and practitioners who need explanations for various types of data, models and explanation methods at different stages of the ML process.

In [9]:
# Import the required libraries
import pandas as pd
# Tabular is OmniXAI's container for tabular data (from omnixai.data.tabular)
from omnixai.data.tabular import Tabular
from omnixai.preprocessing.tabular import TabularTransform
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from omnixai.explainers.tabular import TabularExplainer
from sklearn.metrics import classification_report,accuracy_score
from omnixai.explainers.prediction import PredictionAnalyzer
from omnixai.visualization.dashboard import Dashboard

In [10]:
training_df = pd.read_csv('training_df.csv')
training_df = training_df.drop('Unnamed: 0', axis=1)
training_df.head()

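|   | user_id | item_id | cat_id | merchant_id | brand_id | time_stamp | action_type | age_range | gender | label |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 190023 | 424 | 662 | 3432 | 5093.0 | 1111 | 0 | 4.0 | 0.0 | 0 |
| 1 | 190023 | 424 | 662 | 3432 | 5093.0 | 1111 | 0 | 4.0 | 0.0 | 0 |
| 2 | 190023 | 424 | 662 | 3432 | 5093.0 | 1111 | 0 | 4.0 | 0.0 | 0 |
| 3 | 190023 | 424 | 662 | 3432 | 5093.0 | 1111 | 3 | 4.0 | 0.0 | 0 |
| 4 | 190023 | 424 | 662 | 3432 | 5093.0 | 1111 | 2 | 4.0 | 0.0 | 0 |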
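The `Unnamed: 0` column dropped above typically appears when a DataFrame was saved with `to_csv` without `index=False`: the old index is written as a first column with an empty header, and pandas names it `Unnamed: 0` on reading. As a small sketch (using a made-up two-row CSV in memory instead of `training_df.csv`), both dropping the column and reading it as the index give the same columns:

```python
import io

import pandas as pd

# Made-up CSV mimicking a file written with DataFrame.to_csv() without index=False:
# the first header cell is empty, so pandas names that column "Unnamed: 0".
csv_text = (
    ",user_id,item_id,label\n"
    "0,190023,424,0\n"
    "1,190023,424,0\n"
)

# Option 1: read everything, then drop the auto-named index column (as in the notebook).
df_drop = pd.read_csv(io.StringIO(csv_text)).drop("Unnamed: 0", axis=1)

# Option 2: tell pandas that the first column is the index.
df_index = pd.read_csv(io.StringIO(csv_text), index_col=0)

print(df_drop.columns.tolist())   # ['user_id', 'item_id', 'label']
print(df_index.columns.tolist())  # ['user_id', 'item_id', 'label']
```

If the CSV is regenerated, writing it with `to_csv(..., index=False)` avoids the extra column in the first place.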


In [11]:
feature_names=training_df.columns
training_df.shape

(388, 10)

**Tabular Explainer**

The package **omnixai.preprocessing** provides several useful preprocessing functions for `Tabular` instances. `TabularTransform` is a transform designed specifically for tabular data: by default, it converts categorical features into one-hot encodings and keeps continuous-valued features unchanged. Its `transform` method converts a `Tabular` instance into a NumPy array; if the `Tabular` instance has a target/label column, the last column of the array is the target/label. After preprocessing, we train an **XGBoost classifier** for this task.
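To make the array layout concrete, here is a rough pandas approximation of the transformed output: categorical columns one-hot encoded, continuous columns kept as they are, and the label appended as the last column. This is only a sketch under that description, not OmniXAI's implementation; the column order and encoding details inside `TabularTransform` may differ, and the toy DataFrame is made up.

```python
import numpy as np
import pandas as pd

# Made-up toy data: two categorical features, one continuous feature and a label.
df = pd.DataFrame({
    "gender": [0, 1, 0, 2],
    "action_type": [0, 3, 2, 0],
    "time_stamp": [1111, 1011, 1110, 1111],
    "label": [0, 1, 0, 0],
})
categorical = ["gender", "action_type"]
continuous = ["time_stamp"]

# One-hot encode categorical columns: one 0/1 indicator column per observed value.
# astype(str) makes get_dummies treat the integer codes as categories.
one_hot = pd.get_dummies(df[categorical].astype(str), prefix=categorical).astype(float)

# Keep continuous columns unchanged, then append the label as the last column.
features = pd.concat([one_hot, df[continuous].astype(float)], axis=1)
x = np.column_stack([features.to_numpy(), df["label"].to_numpy(dtype=float)])
print(x.shape)  # (4, 8): 6 indicator columns, 1 continuous column, 1 label column

# Split off the label the same way the notebook does: x[:, :-1] and x[:, -1].
X, y = x[:, :-1], x[:, -1]
```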

In [12]:
tabular_data = Tabular(
   training_df,
   categorical_columns=[feature_names[i] for i in [2, 3, 6, 7, 8]],
   target_column='label'
)
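`categorical_columns` is built from positions in `feature_names`, which is easy to get wrong when columns change. Resolving the positions against the ten columns shown by `head()` above (hard-coded here as an assumption) shows which features are treated as categorical. The remaining non-target columns are listed too, on our reading that `Tabular` handles unlisted columns as continuous; note that this includes ID-like columns such as `user_id` and `item_id`.

```python
# Column order as shown by training_df.head() after dropping "Unnamed: 0" (hard-coded).
feature_names = [
    "user_id", "item_id", "cat_id", "merchant_id", "brand_id",
    "time_stamp", "action_type", "age_range", "gender", "label",
]

# Positions used in the notebook cell above.
categorical_positions = [2, 3, 6, 7, 8]
categorical_columns = [feature_names[i] for i in categorical_positions]
print(categorical_columns)
# ['cat_id', 'merchant_id', 'action_type', 'age_range', 'gender']

# Everything else except the target column.
continuous_columns = [
    c for c in feature_names if c not in categorical_columns and c != "label"
]
print(continuous_columns)
# ['user_id', 'item_id', 'brand_id', 'time_stamp']
```

Passing the names directly, e.g. `categorical_columns=['cat_id', 'merchant_id', 'action_type', 'age_range', 'gender']`, would be equivalent and more robust to column reordering.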

In [13]:
# Data preprocessing
transformer = TabularTransform().fit(tabular_data)
class_names = transformer.class_names
x = transformer.transform(tabular_data)
# Split into training and test datasets (after transformation, the last column of `x` is the label)
train, test, train_labels, test_labels = train_test_split(x[:, :-1], x[:, -1], train_size=0.80)
# Train an XGBoost model
model = XGBClassifier(n_estimators=300, max_depth=5)
model.fit(train, train_labels)
# Convert the transformed data back to Tabular instances
train_data = transformer.invert(train)
test_data = transformer.invert(test)
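With 388 rows and `train_size=0.80`, scikit-learn's `train_test_split` puts 310 rows in the training set (0.8 * 388 = 310.4, rounded down) and the remaining 78 in the test set. The sketch below checks this on a dummy array with the same number of rows; the feature width of 20 is arbitrary. It also adds `random_state`, which the notebook does not set, so the notebook's split (and therefore its accuracy) can change between runs.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Dummy stand-in for the transformed array `x`: 388 rows, label in the last column.
# The number of feature columns (20) is arbitrary.
rng = np.random.default_rng(0)
x = np.column_stack([rng.normal(size=(388, 20)), rng.integers(0, 2, size=388)])

# Same slicing as the notebook, plus a fixed random_state for reproducibility.
train, test, train_labels, test_labels = train_test_split(
    x[:, :-1], x[:, -1], train_size=0.80, random_state=42
)
print(train.shape, test.shape)  # (310, 20) (78, 20)

# Repeating the split with the same random_state yields identical partitions.
train2, test2, _, _ = train_test_split(x[:, :-1], x[:, -1], train_size=0.80, random_state=42)
print(np.array_equal(test, test2))  # True
```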

In [14]:
pred=model.predict(test)
accuracy_score(test_labels,pred)*100

89.74358974358975
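Since 388 rows with `train_size=0.80` leave 78 test rows, the reported accuracy of about 89.74% corresponds to 70 correct predictions (70 / 78 = 0.8974...). `classification_report`, imported in the first cell but not used, breaks this down per class. The labels below are made up (78 rows, 8 mistakes) purely to illustrate the calls; they do not reproduce the notebook's per-class results.

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

# Made-up labels: 78 test rows, with 8 predictions flipped.
test_labels = np.array([0] * 60 + [1] * 18)
pred = test_labels.copy()
pred[:8] = 1 - pred[:8]  # 8 wrong predictions -> 70 correct

acc = accuracy_score(test_labels, pred) * 100
print(round(acc, 2))  # 89.74

# Per-class precision, recall and F1 (the notebook's real report would differ).
print(classification_report(test_labels, pred))
```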

Here, LIME, SHAP and MACE generate **local explanations**, while PDP (partial dependence plot) and ALE (accumulated local effects) generate **global explanations**. `explainer.explain` returns the local explanations generated by the three local methods for the given test instances, and `explainer.explain_global` returns the global explanations generated by PDP and ALE. `TabularExplainer` hides the details of the individual explainers, so we only need to call these two methods to generate explanations.
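PDP, one of the global methods above, answers the question "how does the average prediction change as one feature varies?". For each grid value v of a feature, it sets that feature to v in every row, predicts, and averages the predictions. The following is a minimal NumPy sketch of that idea with a made-up logistic model and random data; OmniXAI's PDP explainer may differ in details such as how it chooses the grid.

```python
import numpy as np


def predict_proba(X):
    """Toy model: P(class 1) from a fixed logistic function of two features."""
    logits = 2.0 * X[:, 0] - 1.0 * X[:, 1]
    return 1.0 / (1.0 + np.exp(-logits))


def partial_dependence(predict, X, feature, grid):
    """Average prediction when `feature` is forced to each value in `grid`."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value  # same value for every row
        averages.append(predict(X_mod).mean())
    return np.array(averages)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
grid = np.linspace(-2, 2, 5)

# Feature 0 has a positive weight, so its curve rises; feature 1 has a negative one.
pd_feature0 = partial_dependence(predict_proba, X, feature=0, grid=grid)
pd_feature1 = partial_dependence(predict_proba, X, feature=1, grid=grid)
print(np.round(pd_feature0, 3))
print(np.round(pd_feature1, 3))
```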

In [15]:
explainer = TabularExplainer(
    explainers = ["lime", "shap", "mace", "pdp", "ale"],
    mode = "classification",
    data = train_data,
    model = model,
    preprocess = lambda z: transformer.transform(z)
    )
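The `preprocess` argument lets the explainers work with raw `Tabular` instances while the model only ever sees transformed arrays: roughly speaking, the model is called on `preprocess(z)` rather than on the raw instances `z`. The sketch below shows that composition pattern with a plain Python wrapper. `RawInputModel` is a hypothetical name, not an OmniXAI class, and the toy `preprocess` function stands in for `transformer.transform`.

```python
from typing import Callable, List

import numpy as np
from sklearn.linear_model import LogisticRegression


class RawInputModel:
    """Hypothetical wrapper: predict on raw rows by applying a preprocess step first."""

    def __init__(self, model, preprocess: Callable):
        self.model = model
        self.preprocess = preprocess

    def predict_proba(self, raw_rows):
        return self.model.predict_proba(self.preprocess(raw_rows))


# Toy "raw" data: one categorical string feature and one numeric feature.
raw = [("a", 1.0), ("b", 2.0), ("a", 3.0), ("b", 4.0)]
labels = np.array([0, 1, 0, 1])
categories = ["a", "b"]


def preprocess(rows: List[tuple]) -> np.ndarray:
    """Stand-in for transformer.transform: one-hot the string, keep the number."""
    return np.array([[float(c == cat) for cat in categories] + [v] for c, v in rows])


clf = LogisticRegression().fit(preprocess(raw), labels)
wrapped = RawInputModel(clf, preprocess)
print(wrapped.predict_proba([("a", 2.5)]).shape)  # (1, 2)
```

Keeping the transform outside the model this way means explanations can be reported in terms of the original features rather than one-hot columns.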

In [16]:
test_instances = test_data[:5]
local_explanations = explainer.explain(X = test_instances)
global_explanations = explainer.explain_global(
    params = {"pdp": {"features": ["item_id", "cat_id", "merchant_id", "action_type", 
                                   "age_range", "gender", "time_stamp"]}}
)


Next, we create a **PredictionAnalyzer** to compute performance metrics for this classification task.

In [17]:
analyzer = PredictionAnalyzer(
    mode="classification",
    test_data=test_data,                           # The test dataset (a `Tabular` instance)
    test_targets=test_labels,                      # The test labels (a numpy array)
    model=model,                                   # The ML model
    preprocess=lambda z: transformer.transform(z)  # Converts raw features into the model inputs
)
prediction_explanations = analyzer.explain()
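As an independent cross-check of the kind of metrics a prediction analyzer reports, a few standard binary-classification metrics can be computed directly with scikit-learn. Exactly which metrics `PredictionAnalyzer` includes is not shown here, so treat this as a separate sketch; the labels and scores are made up.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, roc_auc_score

# Made-up binary labels and predicted probabilities for class 1.
y_true = np.array([0, 0, 1, 1, 0, 1, 0, 1])
proba = np.array([0.1, 0.4, 0.8, 0.35, 0.2, 0.9, 0.6, 0.7])
y_pred = (proba >= 0.5).astype(int)  # threshold at 0.5

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
auc = roc_auc_score(y_true, proba)  # uses scores, not thresholded labels

print(tn, fp, fn, tp)      # 3 1 1 3
print(precision, recall)   # 0.75 0.75
print(auc)                 # 0.875
```

In the notebook, `proba` would correspond to `model.predict_proba(test)[:, 1]` and `y_true` to `test_labels`.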

**OmniXAI Dashboard**

In [None]:
# Launch a dashboard for visualization
dashboard = Dashboard(
   instances=test_instances,                        # The instances to explain
   local_explanations=local_explanations,           # Set the local explanations
   global_explanations=global_explanations,         # Set the global explanations
   prediction_explanations=prediction_explanations, # Set the prediction metrics
   class_names=class_names,                         # Set class names
   explainer=explainer                              # The created TabularExplainer for what-if analysis
)
dashboard.show()  

Dash is running on http://127.0.0.1:8050/



A value is trying to be set on a copy of a slice from a DataFrame.
Try u