In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from depiction.models.celltype import CellTyper
from depiction.interpreters.uw_model import UWModel
from depiction.core import Task, DataType

In [None]:
# Instantiate the CellTyper model that the interpreters below will explain.
# NOTE(review): presumably the constructor loads pretrained weights (the
# notebook never fits it) -- confirm against depiction.models.celltype.
classifier = CellTyper()

In [None]:
# Load the single-cell dataset and carve out train / validation / test frames.
# The later cells treat the last column as the label and every other column
# as a feature.
datapath = '../data/single-cell/data.csv'
data_df = pd.read_csv(datapath)

# Fixed seed so that "Restart & Run All" reproduces the same splits, and
# therefore the same explained sample and the same explanations.
RANDOM_STATE = 42

# Hold out 10% of the rows for testing, then 20% of the remainder for
# validation. An intermediate name is used so `train_df` is bound only once.
train_valid_df, test_df = train_test_split(
    data_df, test_size=0.1, random_state=RANDOM_STATE
)
train_df, valid_df = train_test_split(
    train_valid_df, test_size=0.2, random_state=RANDOM_STATE
)

train_df.head()

In [None]:
# Row position (within test_df.values) of the sample the interpreters explain.
id_sample_to_explain = 3

# Settings handed to every UWModel built further down.
data_type = DataType.TABULAR
task = Task.CLASSIFICATION

In [None]:
# Build a LIME interpreter for tabular data.
interpreter = "lime"

# NOTE(review): presumably forwarded to LIME's explain call, where
# `top_labels=1` limits the explanation to the top predicted class -- confirm.
explanation_configs = {"top_labels": 1}

# Every column but the last is a feature; the last column holds the label.
lime_train_values = train_df.values
interpreter_params = dict(
    training_data=lime_train_values[:, :-1],
    training_labels=lime_train_values[:, -1],
    feature_names=train_df.columns[:-1],
    verbose=True,
    # celltype_names is looked up with keys 1..N, in that order.
    class_names=[
        classifier.celltype_names[label_id]
        for label_id in range(1, len(classifier.celltype_names) + 1)
    ],
    discretize_continuous=False,
    sample_around_instance=True,
)

explainer = UWModel(
    interpreter, task, data_type, explanation_configs, **interpreter_params
)

# Explain the selected test sample (features only, label column dropped).
explainer.interpret(
    classifier.predict,
    test_df.values[id_sample_to_explain, :-1],
)

In [None]:
# Create an anchor tabular interpreter
interpreter = "anchor"
explanation_configs = {}
interpreter_params = {
    "feature_names": train_df.columns[:-1],
    # celltype_names is looked up with keys 1..N, in that order.
    "class_names": [classifier.celltype_names[i+1] for i in range(len(classifier.celltype_names))],
    "categorical_names": {}
}

explainer = UWModel(interpreter, task, data_type, explanation_configs, **interpreter_params)

# Features are every column but the last; the last column holds the label.
# The labels are cast with the builtin `int`: the `np.int` alias was deprecated
# in NumPy 1.20 and removed in 1.24, where it raises AttributeError. It was
# only ever an alias for `int`, so the resulting dtype is unchanged.
X_train = train_df.values[:, :-1]
y_train = train_df.values[:, -1].astype(int)
X_valid = valid_df.values[:, :-1]
y_valid = valid_df.values[:, -1].astype(int)

# Fit the wrapped explainer on the train and validation splits before use.
explainer.explainer.fit(
    X_train, y_train, X_valid, y_valid
)

def new_predict(sample, **kwargs):
    """Return, for each row of ``sample``, the index of the top-scoring class.

    Forwards ``sample`` and any keyword arguments to ``classifier.predict``
    and reduces the result with ``np.argmax`` along axis 1.

    NOTE(review): presumably the anchor interpreter expects class indices
    rather than per-class scores -- confirm against the depiction docs.
    """
    class_scores = classifier.predict(sample, **kwargs)
    return np.argmax(class_scores, axis=1)


# explain the chosen instance wrt the chosen labels
explainer.interpret(
    new_predict,
    test_df.values[id_sample_to_explain, :-1],
)