In [None]:
import warnings; warnings.filterwarnings('ignore', category=FutureWarning)
import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
import pandas as pd
import numpy as np
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from depiction.models.base.base_model import BaseModel
from depiction.models.examples.celltype.celltype import CellTyper
from depiction.interpreters.u_wash.u_washer import UWasher
from depiction.interpreters.alibi import Counterfactual
from depiction.interpreters.aix360.rule_based_model import RuleAIX360
from depiction.models.base import BinarizedClassifier
from depiction.core import Task, DataType
from tensorflow import keras

## Data

In [None]:
# Read the single-cell measurements; every column except 'category' is a marker.
datapath = '../data/single-cell/data.csv'
data_df = pd.read_csv(datapath)

# Rescale each marker column to the [0, 1] range.
min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
marker_values = data_df.drop('category', axis=1).values
data = min_max_scaler.fit_transform(marker_values)

# Re-attach the unscaled labels as the last column. The original column names
# are reused, so they only line up if 'category' is the last column of the CSV.
category_column = data_df['category'].values[:, None]
data_df = pd.DataFrame(
    np.append(data, category_column, axis=1),
    index=data_df.index,
    columns=data_df.columns,
)

# Split as in the training of the model: 33% of the rows are held out
# (stratified by class), and 67% of that held-out part becomes the validation
# set while the remainder is the test set.
train_df, held_out_df = train_test_split(
    data_df, test_size=0.33, random_state=42, stratify=data_df.category
)
test_df, valid_df = train_test_split(
    held_out_df, test_size=0.67, random_state=42, stratify=held_out_df.category
)

train_df.head()

In [None]:
# Every column except the last one is treated as a marker (input feature);
# this relies on 'category' being the last column of the frame.
markers = train_df.columns[:-1]

# Feature matrices of each split as plain NumPy arrays.
X_train = train_df[markers].values
X_test = test_df[markers].values
X_valid = valid_df[markers].values

# Integer class ids of each split. The builtin `int` is used because the
# `np.int` alias was deprecated in NumPy 1.20 and removed in NumPy 1.24;
# both spellings select the same dtype.
y_train = train_df['category'].values.astype(int)
y_test = test_df['category'].values.astype(int)
y_valid = valid_df['category'].values.astype(int)


In [None]:
# Number of samples per class over the full dataset. The variable is passed as
# `x=` because newer seaborn releases no longer accept it positionally (the
# first positional argument is `data` there).
sns.countplot(x=data_df.category)
# Mapping from class id to cell type name, displayed to decode the x-axis.
CellTyper.celltype_names

## Loading a pretrained model
is actually done under the hood by a child implementation of `depiction.models.uri.HTTPModel`. 
Change `filename`, `cache_dir` (with fixed subdir `models/`) and/or `origin` to load/download a different model.
Or have a look at other URI models, e.g. `FileSystemModel` or `RESTAPIModel`.

In [None]:
# Load the pretrained cell type classifier from the model file 'celltype_model.h5'.
classifier = CellTyper(filename='celltype_model.h5')
# NOTE(review): uncomment to inspect `classifier.model_path` (presumably the local path of the model file — confirm).
# classifier.model_path

In [None]:
# Show the summary of the model wrapped by the classifier.
classifier.model.summary()

## Layer weights

In [None]:
# First weight array of the first layer, with one row per input marker
# (presumably the kernel of a dense layer — confirm against the model summary).
weights = classifier.model.layers[0].get_weights()[0]
weights_df = pd.DataFrame(weights, index=markers)
# Transposed so that the markers run along the x-axis of the heatmap.
sns.heatmap(weights_df.T)

Compare qualitatively to __B__ and **C** (though the image does not depict this exact dataset)
![manual_gated](https://science.sciencemag.org/content/sci/332/6030/687/F2.large.jpg?width=800&height=600&carousel=1)
from https://science.sciencemag.org/content/332/6030/687/tab-figures-data

helper/widget functions

In [None]:
def random_from_class(label):
    """Pick one random test sample of the given class.

    Args:
        label: value of the 'category' column to draw the sample from.

    Returns:
        Position of the drawn sample in test_df (the index is reset before
        sampling, so it can be used to index X_test directly).
    """
    # `@label` refers to the function argument inside the query expression.
    candidates = test_df.reset_index().query('category==@label')
    id_sample_to_explain = candidates.sample(n=1).index[0]
    print('Interpreting sample with index {} in test_df'.format(id_sample_to_explain))
    return id_sample_to_explain



In [None]:
def visualize_logits(id_sample_to_explain):
    """Plot the classifier output for one test sample as a heatmap.

    Args:
        id_sample_to_explain: position of the sample in X_test.
    """
    sample = X_test[id_sample_to_explain]
    prediction = classifier.predict([[sample]])
    celltype_columns = CellTyper.celltype_names.values()
    logits = pd.DataFrame(prediction, columns=celltype_columns)
    # Transpose so that each cell type becomes one row of the heatmap.
    sns.heatmap(logits.T)


def visualize(id_sample_to_explain, layer):
    """Plot, as a heatmap, the input of a layer multiplied by that layer's weights.

    Args:
        id_sample_to_explain: position of the sample in X_test.
        layer: index into classifier.model.layers, or None to plot the
            classifier output via visualize_logits instead.
    """
    sample = X_test[id_sample_to_explain]

    if layer is None:
        visualize_logits(id_sample_to_explain)
        return

    if layer == 0:
        # The first layer is fed the sample itself.
        layer_input = sample.transpose()
    else:
        # Any later layer is fed the output of the layer before it, which is
        # read from a sub-model that shares the classifier's input.
        previous_layer = classifier.model.layers[layer - 1]
        activation_model = keras.models.Model(
            inputs=classifier.model.input,
            outputs=previous_layer.output,
        )
        layer_input = activation_model.predict([[sample]])[0]

    layer_weights = classifier.model.layers[layer].get_weights()[0]
    # Element-wise product of the transposed weights with the layer input.
    sns.heatmap(layer_weights.transpose() * layer_input)


def visualize_random_from_class(label, layer):
    """Draw a random test sample of class `label` and visualize it at `layer`."""
    sample_id = random_from_class(label)
    visualize(sample_id, layer)



In [None]:
# Dropdown entries as (cell type name, class id) pairs.
label_options = [(name, class_id) for class_id, name in classifier.celltype_names.items()]
# Layer name -> layer index, plus a 'logits' entry mapped to None, which makes
# `visualize` plot the classifier output instead of a layer.
layer_indices = {layer.name: index for index, layer in enumerate(classifier.model.layers)}
layer_options = dict(**layer_indices, logits=None)

interact_manual(visualize_random_from_class, label=label_options, layer=layer_options);

In [None]:
# Classifier output for a fixed example: the sample at position 4368 of X_test.
visualize_logits(4368)

# Interpretability methods
starting with "local" methods, explaining a given sample.

## Lime

In [None]:
# Create a LIME tabular interpreter
lime_params = {
    'training_data': X_train,
    'training_labels': y_train,
    'feature_names': markers,
    'verbose': True,
    'class_names': classifier.celltype_names.values(),
    'discretize_continuous': False,
    'sample_around_instance': True
}

lime = UWasher("lime", classifier, **lime_params)

## Anchor

In [None]:
# Names used to label features and classes in the anchor explanation.
anchors_params = {
    'feature_names': markers,
    'class_names': classifier.celltype_names.values(),
    'categorical_names': {},
}
# depiction fits the (tabular) anchor on construction, so the training and
# validation data are passed together with the other parameters.
fit_params = {
    'train_data': X_train,
    'train_labels': y_train,
    'validation_data': X_valid,
    'validation_labels': y_valid,
}

# Fit parameters first, then the naming parameters, as keyword arguments.
anchors_kwargs = {**fit_params, **anchors_params}
anchors = UWasher('anchors', classifier, **anchors_kwargs)

## Counterfactual

In [None]:
counterfactual_params = {
    # Shape of one explained input: (batch size, number of features). The
    # feature count is read from the training data instead of hard-coding 13,
    # so it stays in sync with the marker columns.
    'shape': (1, X_train.shape[1]),
    'target_proba': 1.0,
    'tol': 0.1,  # tolerance for counterfactuals
    'max_iter': 10,
    'lam_init': 1e-1,
    'max_lam_steps': 10,
    'learning_rate_init': 0.1,
    # Global (min, max) over all training features.
    'feature_range': (X_train.min(), X_train.max())
}

counterfactual = Counterfactual(
    classifier,
    target_class='other',  # any other class
    **counterfactual_params,
)


helper/widget functions

In [None]:
def interpret_with_lime(id_sample_to_explain):
    """Run the LIME interpreter on one test sample.

    Args:
        id_sample_to_explain: position of the sample in X_test.
    """
    sample = X_test[id_sample_to_explain]
    # NOTE(review): 'top_labels': 1 presumably restricts the explanation to the
    # most probable class (LIME semantics) — confirm.
    explanation_configs = {'top_labels': 1}
    lime.interpret(sample, explanation_configs=explanation_configs)


def anchor_callback(sample, **kwargs):
    """Return the predicted class index for each row of `sample`.

    Extra keyword arguments are forwarded to `classifier.predict`; the class
    index is the argmax of the prediction along axis 1.
    """
    predictions = classifier.predict(sample, **kwargs)
    return np.argmax(predictions, axis=1)


def interpret_with_anchor(id_sample_to_explain):
    """Run the anchor interpreter on one test sample.

    Args:
        id_sample_to_explain: position of the sample in X_test.
    """
    sample = X_test[id_sample_to_explain]
    # The callback maps model outputs to class indices for the interpreter.
    anchors.interpret(sample, explanation_configs={}, callback=anchor_callback)


def interpret_with_counterfactual(id_sample_to_explain):
    """Search for a counterfactual of one test sample and print the result.

    Args:
        id_sample_to_explain: position of the sample in X_test.
    """
    # Explain the requested test sample, consistent with the lime and anchor
    # helpers; a leading batch axis is added to match the interpreter's 'shape'.
    sample = np.expand_dims(X_test[id_sample_to_explain], axis=0)
    explanation = counterfactual.interpret(sample)

    cf = explanation['cf']
    # NOTE(review): alibi reports 'cf' as None when the search finds no
    # counterfactual — confirm depiction's wrapper forwards this unchanged.
    if cf is None:
        print('No counterfactual found')
        return

    predicted_class = cf['class']
    probability = cf['proba'][0][predicted_class]
    print(f'Counterfactual prediction: {predicted_class} with probability {probability}')
    print(cf['X'])


def interpret_random_from_class(label, interpreter):
    """Explain a random test sample of class `label` with the chosen interpreter.

    Args:
        label: class to draw the sample from.
        interpreter: one of 'lime', 'anchor' or 'counterfactual'; any other
            value only draws the sample and explains nothing.
    """
    id_sample_to_explain = random_from_class(label)
    handlers = {
        'lime': interpret_with_lime,
        'anchor': interpret_with_anchor,
        'counterfactual': interpret_with_counterfactual,
    }
    handler = handlers.get(interpreter)
    if handler is not None:
        handler(id_sample_to_explain)



In [None]:
# Dropdown entries as (cell type name, class id) pairs.
label_options = [(name, class_id) for class_id, name in classifier.celltype_names.items()]
interact_manual(
    interpret_random_from_class,
    label=label_options,
    interpreter=['lime', 'anchor', 'counterfactual'],
);

In [None]:
# Anchor explanation for a fixed example: the sample at position 4368 of X_test.
interpret_with_anchor(4368)

# Global interpretation with rule-based models 

In [None]:
# Inverse of CellTyper.celltype_names: cell type name -> class id.
LABEL2ID = {name: class_id for class_id, name in CellTyper.celltype_names.items()}

In [None]:
# Cell type explained by the rule-based models; it must be a key of LABEL2ID.
LABEL_TO_EXPLAIN = 'Mature CD4+ T'
# Numeric class id of that cell type.
LABEL_ID = LABEL2ID[LABEL_TO_EXPLAIN]

## Data preparation and auxiliary functions

In [None]:
# Binarize the task to use this method
model = BinarizedClassifier(classifier, data_type=DataType.TABULAR, label_index=LABEL_ID)

## Post-Hoc explanation 

### BRCG

In [None]:
# Post-hoc BRCG: the interpreter gets the binarized classifier instead of labels.
# NOTE(review): presumably the rules are fit to the model's predictions on X_train — confirm in RuleAIX360.
interpreter = RuleAIX360('brcg', X=X_train, model=model)
interpreter.interpret()

### GLRM - Linear

In [None]:
# Post-hoc linear GLRM: the interpreter gets the binarized classifier instead of labels.
# NOTE(review): presumably the rules are fit to the model's predictions on X_train — confirm in RuleAIX360.
interpreter = RuleAIX360('glrm_linear', X=X_train, model=model)
interpreter.interpret()

### GLRM - Logistic

In [None]:
# Post-hoc logistic GLRM: the interpreter gets the binarized classifier instead of labels.
# NOTE(review): presumably the rules are fit to the model's predictions on X_train — confirm in RuleAIX360.
interpreter = RuleAIX360('glrm_logistic', X=X_train, model=model)
interpreter.interpret()

## Ante-Hoc explanation

In [None]:
# Boolean training labels: True where the sample belongs to the class LABEL_ID.
y_train_binary = y_train == LABEL_ID

### BRCG

In [None]:
# Ante-hoc BRCG: the interpreter gets the binary training labels instead of a model.
interpreter = RuleAIX360('brcg', X=X_train, y=y_train_binary)
interpreter.interpret()

### GLRM - Linear

In [None]:
# Ante-hoc linear GLRM: the interpreter gets the binary training labels instead of a model.
interpreter = RuleAIX360('glrm_linear', X=X_train, y=y_train_binary)
interpreter.interpret()

### GLRM - Logistic

In [None]:
# Ante-hoc logistic GLRM: the interpreter gets the binary training labels instead of a model.
interpreter = RuleAIX360('glrm_logistic', X=X_train, y=y_train_binary)
interpreter.interpret()