In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact
from sklearn.model_selection import train_test_split
from depiction.models.celltype import CellTyper
from depiction.interpreters.uw_model import UWModel
from depiction.core import Task, DataType

## Data

In [None]:
# Load data
datapath = '../data/single-cell/data.csv'
data_df = pd.read_csv(datapath)
# Class balance of the `category` label. Pass the vector as `x=`: newer seaborn
# versions treat a positional first argument as `data`, not as the variable to count.
sns.countplot(x=data_df.category)

# Split as in training of the model: 67% train / 33% held out, then the held-out part
# is split into test (1/3 of it) and validation (2/3 of it). Both splits are stratified
# on `category` and seeded, so the partition is identical across kernel restarts.
# NOTE(review): confirm the training script seeded the second split the same way.
train_df, holdout_df = train_test_split(data_df, test_size=0.33, random_state=42, stratify=data_df.category)
test_df, valid_df = train_test_split(holdout_df, test_size=0.67, random_state=42, stratify=holdout_df.category)

train_df.head()

## Loading a pretrained model
Loading is done under the hood by a child implementation of `depiction.models.Model`.
Change `filename` (there is also a `cache_dir` argument) to load a different model.

In [None]:
# Import trained classifier
classifier = CellTyper(filename='celltype_dnn_model.h5')

In [None]:
# Class id -> cell-type name mapping of the loaded model (presumably the ids are the
# integer values of the `category` column). This cell must run after the model is loaded.
classifier.celltype_names

## Layer weights

In [None]:
# First weight array of the network's last layer (for a Keras Dense layer this is
# the kernel, without the bias). Expected shape is (13, 20) — TODO confirm.
weights = classifier.model.layers[-1].get_weights()[0]
# The heatmap below labels the rows with the feature columns of `data_df`
# (all but the last, label column), so the two must line up.
assert weights.shape[0] == len(data_df.columns) - 1, weights.shape
weights.shape

In [None]:
# Heatmap of the last-layer weights: rows are the input feature columns, columns are
# the layer's output units (presumably one per class). `meta_series` was never defined
# in this notebook, so the columns are labelled from the classifier's id -> name
# mapping, falling back to the raw unit index when an id has no name.
class_labels = [classifier.celltype_names.get(i, i) for i in range(weights.shape[1])]
sns.heatmap(pd.DataFrame(weights, index=data_df.columns[:-1], columns=class_labels));

Compare qualitatively with panels **B** and **C** of the manually gated reference figure below:
![manual_gated](https://science.sciencemag.org/content/sci/332/6030/687/F2.large.jpg?width=800&height=600&carousel=1)
Source: https://science.sciencemag.org/content/332/6030/687/tab-figures-data

## Interpretability methods
Helper functions that explain a test sample with LIME or Anchor, plus a widget that picks a random test sample from a chosen class and explains it.

In [None]:
def interpret_with_lime(id_sample_to_explain):
    """Explain one test sample with a LIME tabular interpreter.

    Args:
        id_sample_to_explain: positional row index into ``test_df``.
    """
    # Only the single most probable label is explained.
    explanation_configs = {
        "top_labels": 1
    }
    interpreter_params = {
        # Features are all columns but the last; the last column is the class label.
        "training_data": train_df.values[:, :-1],
        "training_labels": train_df.values[:, -1],
        "feature_names": train_df.columns[:-1],
        "verbose": True,
        # A list rather than a dict view, so the names can be indexed and serialised.
        "class_names": list(classifier.celltype_names.values()),
        "discretize_continuous": False,
        "sample_around_instance": True
    }

    # Create a LIME tabular interpreter
    explainer = UWModel("lime", Task.CLASSIFICATION, DataType.TABULAR, explanation_configs, **interpreter_params)

    # Explain the chosen instance (feature values only, label column dropped).
    explainer.interpret(classifier.predict, test_df.values[id_sample_to_explain, :-1])


def interpret_with_anchor(id_sample_to_explain):
    """Explain one test sample with an Anchor tabular interpreter.

    The underlying anchor explainer is fitted on the train and validation
    splits on every call, before the sample is explained.

    Args:
        id_sample_to_explain: positional row index into ``test_df``.
    """
    explanation_configs = {}
    interpreter_params = {
        "feature_names": train_df.columns[:-1],
        # A list rather than a dict view, so the names can be indexed and serialised.
        "class_names": list(classifier.celltype_names.values()),
        "categorical_names": {}
    }

    explainer = UWModel("anchor", Task.CLASSIFICATION, DataType.TABULAR, explanation_configs, **interpreter_params)
    # Features are all columns but the last; the last column is the class label.
    # The builtin `int` is used because the `np.int` alias was removed in NumPy 1.24.
    X_train = train_df.values[:, :-1]
    y_train = train_df.values[:, -1].astype(int)
    X_valid = valid_df.values[:, :-1]
    y_valid = valid_df.values[:, -1].astype(int)
    explainer.explainer.fit(
        X_train, y_train, X_valid, y_valid
    )

    def new_predict(sample, **kwargs):
        """Return hard class ids: argmax over the classifier's per-class outputs."""
        return np.argmax(classifier.predict(sample, **kwargs), axis=1)

    # Explain the chosen instance (feature values only, label column dropped).
    explainer.interpret(new_predict, test_df.values[id_sample_to_explain, :-1])


def interpret_random_from_class(label, interpreter="lime"):
    """Explain one randomly chosen test sample whose ``category`` equals ``label``.

    Args:
        label: class id, matched against the ``category`` column of ``test_df``.
        interpreter: ``"lime"`` or ``"anchor"``.

    Raises:
        ValueError: if ``interpreter`` is not one of the supported names, or if
            ``test_df`` contains no sample of class ``label``.
    """
    interpreters = {"lime": interpret_with_lime, "anchor": interpret_with_anchor}
    if interpreter not in interpreters:
        raise ValueError(f"Unknown interpreter {interpreter!r}; expected one of {sorted(interpreters)}")

    # reset_index() turns the index into row positions, which is what the
    # interpret_with_* helpers expect (they index test_df.values positionally).
    candidates = test_df.reset_index().query('category==@label')
    if candidates.empty:
        raise ValueError(f"No test sample with category {label!r}")
    id_sample_to_explain = candidates.sample(n=1).index[0]
    interpreters[interpreter](id_sample_to_explain)

In [None]:
# Dropdown options are (display name, class id) pairs: the widget shows the
# cell-type name but passes the class id on to the callback.
label_options = [(name, class_id) for class_id, name in classifier.celltype_names.items()]
interact(
    interpret_random_from_class,
    label=label_options,
    interpreter=["lime", "anchor"],
);

In [None]:
# Explain one fixed test sample with Anchor. SAMPLE_ID is a positional row index
# into `test_df`; which cell it refers to depends on how `test_df` was split.
SAMPLE_ID = 5371
assert 0 <= SAMPLE_ID < len(test_df), f"SAMPLE_ID must be in [0, {len(test_df)})"
interpret_with_anchor(SAMPLE_ID)