In [19]:
import pandas as pd
import numpy as np
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from depiction.models.base.base_model import BaseModel
from depiction.models.examples.celltype.celltype import CellTyper
from depiction.interpreters.u_wash.u_washer import UWasher
from depiction.interpreters.aix360.rule_based_model import RuleAIX360
from depiction.models.base import BinarizedClassifier
from depiction.core import Task, DataType
from tensorflow import keras

## Data

In [4]:
# Load the dataset: one 'category' label column plus the feature columns.
datapath = 'data/single-cell/data.csv'  # '../data/single-cell/data.csv'
data_df = pd.read_csv(datapath)
# `x` is passed by keyword: recent seaborn releases treat the first positional
# argument of countplot as `data`, not as the variable to count.
sns.countplot(x=data_df.category)

# Scale every feature column to the [0, 1] range.
features_df = data_df.drop('category', axis=1)
min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
data = min_max_scaler.fit_transform(features_df.values)
# Rebuild the frame with 'category' explicitly as the LAST column. The column
# labels are derived from the feature frame so they stay aligned with the
# scaled values even if 'category' is not the last column of the CSV
# (later cells rely on `columns[:-1]` being the features).
data_df = pd.DataFrame(
    np.append(data, data_df['category'].values[:, None], axis=1),
    index=data_df.index,
    columns=list(features_df.columns) + ['category'],
)

# Split as in the training of the model.
# NOTE(review): the second split is now seeded as well so that test_df/valid_df
# (and hard-coded row positions used further down) are reproducible across
# runs - confirm the seed against the training script.
train_df, test_df = train_test_split(data_df, test_size=0.33, random_state=42, stratify=data_df.category)
test_df, valid_df = train_test_split(test_df, test_size=0.67, random_state=42, stratify=test_df.category)


train_df.head()

In [5]:
# Display the `celltype_names` attribute of CellTyper (bare expression, rendered by the notebook).
CellTyper.celltype_names

## Loading a pretrained model
is actually done under the hood by a child implementation of `depiction.models.Model`  
Change `filename` (there's also `cache_dir`) to load a different model.

In [7]:
# Import trained classifier
# `classifier` is reused by the visualisation and interpretation cells below.
classifier = CellTyper(filename='celltype_model.h5')

In [8]:
# Show the summary of the model wrapped by the classifier (presumably a Keras model - confirm).
classifier.model.summary()

## Layer weights

In [9]:
# First array in the weight list of the model's first layer.
weights = classifier.model.layers[0].get_weights()[0]
# Rows are labelled with the feature columns (everything but the last column
# of data_df); the frame is transposed before plotting.
sns.heatmap(pd.DataFrame(weights, index=data_df.columns[:-1]).transpose())

Compare qualitatively to __B__ and **C** (though the image does not depict this exact dataset)
![manual_gated](https://science.sciencemag.org/content/sci/332/6030/687/F2.large.jpg?width=800&height=600&carousel=1)
from https://science.sciencemag.org/content/332/6030/687/tab-figures-data

## Interpretability methods
helper functions and a widget to sample from a class

In [10]:
def random_from_class(label):
    """Draw one random sample of the given class from ``test_df``.

    Args:
        label: value compared against the ``category`` column.

    Returns:
        Positional row index (after ``reset_index``) of the drawn sample.
    """
    # reset_index() makes the index positional, so the result can be used with iloc / .values
    renumbered = test_df.reset_index()
    candidates = renumbered.query('category==@label')
    chosen = candidates.sample(n=1).index[0]
    print('Interpreting sample with index {} in test_df'.format(chosen))
    return chosen


def interpret_with_lime(id_sample_to_explain):
    """Explain one test sample with the "lime" interpreter.

    Args:
        id_sample_to_explain: positional row of the sample in ``test_df``.
    """
    # The last column of train_df is the label, all other columns are features.
    lime_interpreter = UWasher(
        "lime",
        classifier,
        training_data=train_df.values[:, :-1],
        training_labels=train_df.values[:, -1],
        feature_names=train_df.columns[:-1],
        verbose=True,
        class_names=classifier.celltype_names.values(),
        discretize_continuous=False,
        sample_around_instance=True,
    )

    # 'top_labels': 1 presumably limits the explanation to the top predicted
    # label (LIME semantics) - confirm against the interpreter documentation.
    instance = test_df.values[id_sample_to_explain, :-1]
    lime_interpreter.interpret(instance, explanation_configs={'top_labels': 1})


def interpret_with_anchor(id_sample_to_explain):
    """Explain one test sample with the 'anchors' interpreter.

    The wrapped explainer is first fitted on the training split, with the
    validation split passed alongside, and then asked to interpret the sample.

    Args:
        id_sample_to_explain: positional row of the sample in ``test_df``.
    """
    explanation_configs = {}
    interpreter_params = {
        'feature_names': train_df.columns[:-1],
        'class_names': classifier.celltype_names.values(),
        'categorical_names': {}
    }

    explainer = UWasher('anchors', classifier, **interpreter_params)
    # The last column is the label, all other columns are features.
    # The builtin `int` replaces the `np.int` alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24 (it was an alias of the builtin anyway).
    X_train = train_df.values[:, :-1]
    y_train = train_df.values[:, -1].astype(int)
    X_valid = valid_df.values[:, :-1]
    y_valid = valid_df.values[:, -1].astype(int)
    explainer.explainer.fit(
        X_train, y_train, X_valid, y_valid
    )

    # explain the chosen instance wrt the chosen labels
    def new_predict(sample, **kwargs):
        """Return, per row, the index of the highest classifier output."""
        return np.argmax(classifier.predict(sample, **kwargs), axis=1)

    explainer.interpret(
        test_df.values[id_sample_to_explain, :-1],
        explanation_configs=explanation_configs,
        callback=new_predict,
    )


def visualize_logits(id_sample_to_explain):
    """Plot the classifier output for one test sample as a heatmap.

    Args:
        id_sample_to_explain: positional row of the sample in ``test_df``.
    """
    # every column but the last one (the label) is fed to the classifier
    features = test_df.iloc[id_sample_to_explain, :-1]
    prediction = classifier.predict([[features]])
    per_class = pd.DataFrame(prediction, columns=CellTyper.celltype_names.values())
    sns.heatmap(per_class.T)


def visualize(id_sample_to_explain, layer):
    """Heatmap of a layer's weights multiplied by that layer's input.

    Args:
        id_sample_to_explain: positional row of the sample in ``test_df``.
        layer: index into ``classifier.model.layers``; ``None`` plots the
            classifier output via ``visualize_logits`` instead.

    Raises:
        ValueError: if the selected layer has no weights to display.
    """
    if layer is None:
        visualize_logits(id_sample_to_explain)
        return

    # Layers without parameters return an empty weight list; fail with a clear
    # message instead of an IndexError on `get_weights()[0]`.
    layer_weights = classifier.model.layers[layer].get_weights()
    if not layer_weights:
        raise ValueError(
            'Layer {} ({}) has no weights to visualize'.format(
                layer, classifier.model.layers[layer].name
            )
        )
    weights = layer_weights[0]

    sample = test_df.iloc[id_sample_to_explain, :-1]
    if layer == 0:
        # the input of the first layer is the sample itself
        layer_output = sample.values.transpose()
    else:
        # the input of any other layer is the output of the previous layer,
        # obtained from a sub-model that stops at that previous layer
        activation_model = keras.models.Model(
            inputs=classifier.model.input,
            outputs=classifier.model.layers[layer-1].output
        )
        layer_output = activation_model.predict([[sample]])[0]

    # broadcast the layer input over the transposed weight matrix
    weighted_output = (weights.transpose() * layer_output)
    sns.heatmap(weighted_output)


def visualize_random_from_class(label, layer):
    """Draw a random test sample of class ``label`` and visualize it for ``layer``."""
    sampled_id = random_from_class(label)
    visualize(sampled_id, layer)

    
def interpret_random_from_class(label, interpreter):
    """Draw a random test sample of class ``label`` and explain it.

    Args:
        label: class to sample from.
        interpreter: 'lime' or 'anchor'; any other value only draws the sample.
    """
    id_sample_to_explain = random_from_class(label)
    if interpreter == 'lime':
        interpret_with_lime(id_sample_to_explain)
        return
    if interpreter == 'anchor':
        interpret_with_anchor(id_sample_to_explain)

In [11]:
# Widget: choose a class and a layer, then visualize a random sample of that class.
# `label` offers (name, key) pairs built from classifier.celltype_names;
# `layer` maps every layer name to its position, plus 'logits' -> None.
interact_manual(
    visualize_random_from_class,
    label=[(name, key) for key, name in classifier.celltype_names.items()],
    layer=dict(
        **{
            model_layer.name: position
            for position, model_layer in enumerate(classifier.model.layers)
        },
        logits=None
    ),
);

In [12]:
# Fixed example: 4368 is a positional row of test_df (visualize_logits indexes it with iloc).
visualize_logits(4368)

In [13]:
# Widget: choose a class and an interpreter, then explain a random sample of that class.
interact_manual(
    interpret_random_from_class,
    label=[(name, key) for key, name in classifier.celltype_names.items()],
    interpreter=['lime', 'anchor'],
);

In [14]:
# Fixed example: explain the test_df row at position 4368 with the 'anchors' interpreter.
interpret_with_anchor(4368)

# Global interpretation with rule-based models 

In [15]:
# Reverse lookup: cell-type name -> integer label.
# Built from the mapping's own items (in ascending label order) instead of
# assuming the labels are exactly 1..N.
LABEL2ID = {name: label_id for label_id, name in sorted(CellTyper.celltype_names.items())}

In [16]:
# Cell type to explain with the rule-based models; must be a key of LABEL2ID.
LABEL_TO_EXPLAIN = 'Mature CD4+ T'
# Label value of that cell type, looked up in LABEL2ID.
LABEL_ID = LABEL2ID[LABEL_TO_EXPLAIN]

## Data preparation and auxiliary functions

In [18]:
# Feature columns: everything but the last column of train_df.
markers = train_df.columns[:-1]

X_train = train_df[markers]
X_test = test_df[markers]

# Binary targets: 1 where the sample belongs to the class to explain, else 0.
# Builtin `float` / `int` replace the `np.float` / `np.int` aliases, which were
# deprecated in NumPy 1.20 and removed in 1.24.
y_train = np.array(train_df['category'] == float(LABEL_ID), dtype=int)
y_test = np.array(test_df['category'] == float(LABEL_ID), dtype=int)

## Post-Hoc explanation 

In [20]:
# Binarize the task to use this method
# The classifier is wrapped with label_index=LABEL_ID, i.e. the class selected above.
model = BinarizedClassifier(classifier, data_type=DataType.TABULAR, label_index=LABEL_ID)

### BRCG

In [21]:
# Post-hoc BRCG: the interpreter is given the binarized `model` together with X_train (no labels).
interpreter = RuleAIX360('brcg', X=X_train, model=model)
interpreter.interpret()

### GLRM - Linear

In [22]:
# Post-hoc GLRM (linear): the interpreter is given the binarized `model` together with X_train (no labels).
interpreter = RuleAIX360('glrm_linear', X=X_train, model=model)
interpreter.interpret()

### GLRM - Logistic

In [23]:
# Post-hoc GLRM (logistic): the interpreter is given the binarized `model` together with X_train (no labels).
interpreter = RuleAIX360('glrm_logistic', X=X_train, model=model)
interpreter.interpret()

## Ante-Hoc explanation

### BRCG

In [24]:
# Ante-hoc BRCG: the interpreter is given X_train and the binary labels y_train (no model).
interpreter = RuleAIX360('brcg', X=X_train, y=y_train)
interpreter.interpret()

### GLRM - Linear

In [25]:
# Ante-hoc GLRM (linear): the interpreter is given X_train and the binary labels y_train (no model).
interpreter = RuleAIX360('glrm_linear', X=X_train, y=y_train)
interpreter.interpret()

### GLRM - Logistic

In [26]:
# Ante-hoc GLRM (logistic): the interpreter is given X_train and the binary labels y_train (no model).
interpreter = RuleAIX360('glrm_logistic', X=X_train, y=y_train)
interpreter.interpret()