# Training of a super simple model for celltype classification and explanation of its predictions

# Part I - Training

In [None]:
import datetime
import os

import ipywidgets as widgets
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from ipywidgets import interact, interact_manual
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical

!which python
!python --version
print(tf.VERSION)
print(tf.keras.__version__)
!pwd #  start jupyter under notebooks/ for correct relative paths


## a look at the data
Labels are categories 1-20; here is the associated cell type for each:

In [None]:
# Lookup table: category id (the index, read as str) -> cell type name.
meta_series = pd.read_csv(
    '../data/single-cell/metadata.csv',
    index_col=0,
    dtype=str,
)
meta_series = meta_series.rename_axis(None)  # drop the index header for display
meta_series

There are 13 unbalanced classes and over 80k samples.

In [None]:
data_df = pd.read_csv('../data/single-cell/data.csv')
sns.countplot(data_df.category)

In [None]:
data_df.sample(n=10)


In [None]:
# class wise probabilities
def one_hot_encoding(classes, num_categories=None):
    """One-hot encode 1-based category ids.

    Args:
        classes: 1-D array of integer category ids, starting at 1.
        num_categories: optional number of categories K (ids 1..K). If None,
            it is inferred from ``max(classes)``, so the output width depends
            on which categories happen to be present. Pass it explicitly to
            get a fixed width, e.g. when encoding only a subset of the data.

    Returns:
        Array of shape (len(classes), K); column ``j`` is category ``j + 1``.
    """
    # to_categorical counts from 0, so request K + 1 columns and drop the
    # unused column for category 0.
    num_classes = None if num_categories is None else num_categories + 1
    return to_categorical(classes, num_classes=num_classes)[:, 1:]

def one_hot_decoding(labels):
    """Turn one-hot rows (or per-class probabilities) back into category ids.

    Inverse of ``one_hot_encoding``: column ``j`` maps to category ``j + 1``.
    For probability rows, the most likely category is returned.
    """
    best_column = labels.argmax(axis=1)
    return best_column + 1  # shift past the dropped category-0 column

In [None]:
classes = data_df['category'].values
labels = one_hot_encoding(classes)

In [None]:
one_hot_decoding(labels)

In [None]:
labels # model output (softmax) for classification shows probabilities per class

In [None]:
# Rescale every feature column independently to [0, 1]; the label column is excluded.
# NOTE(review): the scaler is fit on all rows before the train/test split, so
# test-set min/max values leak into the scaling. Consider fitting on the
# training split only.
min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
data = min_max_scaler.fit_transform(
    data_df.drop('category', axis=1).values
)
data.shape

In [None]:
# Stratify on the category so that rare cell types appear in train and test in
# the same proportions (the classes are strongly unbalanced). The seed makes the
# split reproducible. test_size keeps sklearn's default of 0.25.
data_train, data_test, labels_train, labels_test = train_test_split(
    data, labels,
    test_size=0.25,
    stratify=classes,
    random_state=42,
)

In [None]:
batchsize = 32

In [None]:
dataset = tf.data.Dataset.from_tensor_slices((data_train, labels_train))
dataset = dataset.shuffle(2 * batchsize).batch(batchsize)
dataset = dataset.repeat()

testset = tf.data.Dataset.from_tensor_slices((data_test, labels_test))
testset = testset.batch(batchsize)

## Implement a simple model
Start with a dense (linear) layer, train for a few epochs, and grow or adapt the network architecture as you like.
For classification, end with a softmax layer that outputs a probability per cell type.

In [None]:
model = tf.keras.Sequential()

# Baseline: a single dense layer straight into softmax, i.e. one output
# probability per category. Its kernel is what the "Layer weights" heatmap
# visualises. Add hidden layers in front of it to grow the network.
model.add(
    layers.Dense(labels.shape[1], activation='softmax', input_shape=(data.shape[1],))
)

In [None]:
model.summary()

In [None]:
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='categorical_crossentropy',
              metrics=[tf.keras.metrics.categorical_accuracy])

In [None]:
# Evaluate on the test set at the end of every epoch.
# log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
model.fit(
    dataset,
    epochs=20,
    # `dataset` repeats forever, so Keras needs an explicit integer number of
    # batches per epoch (np.ceil alone returns a float).
    steps_per_epoch=int(np.ceil(data_train.shape[0] / batchsize)),
    validation_data=testset,  # callbacks=[tensorboard_callback]
)

## Now what did the model learn?
Let's save the model so that we can investigate it with an explainer from `depiction`.

In [None]:
# Check the working directory (the model below is saved relative to it).
!pwd
!ls

In [None]:
# Save entire model to a HDF5 file
model.save('./celltype_model.h5')

In [None]:
# Run in a shell to inspect the optional TensorBoard logs: tensorboard --logdir logs/fit

In [None]:
# To recreate the exact same model, including weights and optimizer.
# model = tf.keras.models.load_model('./celltype_model.h5')

# Part II - Explaining

### Here we are going to apply _lime_ and _anchor_ to the model to explain its predictions for given samples
Keep in mind that these methods are themselves fitted on data.

In [None]:
# Use a DataFrame to track feature names and to choose samples of a given class.
# Assumes 'category' is the last column of data_df, since its columns are reused here.
test_df = pd.DataFrame(
    np.append(data_test, one_hot_decoding(labels_test)[:, None], axis=1), columns=data_df.columns
)
# Split the held-out set once more: 33% validation data (used to fit the anchor
# explainer) and 67% test data. test_df is split together with the arrays so it
# keeps describing exactly the samples in data_test / labels_test.
data_valid, data_test, labels_valid, labels_test, valid_df, test_df = train_test_split(
    data_test, labels_test, test_df, test_size=0.67, stratify=test_df.category
)
# Positional index, so test_df.index matches the rows of test_df.values.
test_df = test_df.reset_index(drop=True)

In [None]:
pd.Series()  # NOTE(review): looks like a leftover scratch cell; it only displays an empty Series

## Using the pretrained model with `depiction`
by implementing a class inheriting from `depiction.models.base.base_model.BaseModel`


In [None]:
# Let's find the saved model: the current directory (Part I) or the
# pretrained location ~/.keras/models.
!pwd
!ls
!ls ~/.keras/models

In [None]:
import depiction

In [None]:
from tensorflow import keras
from depiction.models.base.base_model import BaseModel
from depiction.core import Task, DataType

class CellTyper(BaseModel):
    """Classifier of single cells to be explained.

    Thin wrapper that exposes a Keras model saved on disk through depiction's
    ``BaseModel`` interface (tabular classification), so that depiction's
    interpreters can query its predictions.
    """

    def __init__(self, filename, directory):
        """Initialize the model by loading it from disk.

        Args:
            filename (str): name of the saved Keras model file (e.g. an .h5 file).
            directory (str): directory containing the file; ``~`` is expanded.
        """
        super(CellTyper, self).__init__(Task.CLASSIFICATION, DataType.TABULAR)
        self.model_path = os.path.join(os.path.expanduser(directory), filename)
        # Restores the model (architecture, weights, optimizer) saved with model.save().
        self.model = keras.models.load_model(self.model_path)

    def predict(self, sample, *args, **kwargs):
        """
        Run the model for inference on the given sample(s).

        Args:
            sample (np.ndarray): input sample(s) for the model; presumably one
                row of scaled features per cell.
            args (list): extra positional arguments; accepted for interface
                compatibility and ignored.
            kwargs (dict): extra keyword arguments; accepted for interface
                compatibility and ignored.

        Returns:
            the model's output for ``sample``. For the softmax models built in
            this notebook, that is one row of class probabilities per input row.
        """
        return self.model.predict(sample)

### We actually just need the results from prediction, not access to the full model.

In [None]:
# Import trained classifier, or use a pretrained one with directory '~/.keras/models'
classifier = CellTyper(filename='celltype_model.h5', directory='.')

In [None]:
classifier.model.summary()

## Layer weights
are only somewhat interpretable if the model has a single layer

In [None]:
meta_series.values

In [None]:
single_layer_model = tf.keras.models.load_model(os.path.expanduser('~/.keras/models/celltype_model.h5'))
weights = None  # TODO access weights from single_layer_model
sns.heatmap(pd.DataFrame(
    weights,
    index=data_df.columns[:-1],
    columns=meta_series.values
).T)

## Interpretability methods
helper functions and a widget to sample from a class

In [None]:
from depiction.interpreters.u_wash.u_washer import UWasher
from depiction.core import DataType

In [None]:
class_names = meta_series.to_dict()['cell type name']

In [None]:
explanation_configs = {
        'top_labels': 1
    }
interpreter_params = {
        # TODO provide interpreter params
}

explainer = UWasher("lime", classifier, **interpreter_params)

id_sample_to_explain = 0
# explain the chosen instance wrt the chosen labels
explainer.interpret(test_df.values[id_sample_to_explain, :-1], explanation_configs=explanation_configs)

## Widgets to look at random sample from class

In [None]:
def random_from_class(label):
    """Pick a random row of ``test_df`` whose category equals ``label``.

    Args:
        label: category id. An int, a float or a numeric string is accepted;
            the widget passes the string keys of ``class_names``.

    Returns:
        Positional index of the chosen row in ``test_df``, usable with
        ``test_df.values[...]``.

    Raises:
        ValueError: if ``test_df`` contains no sample of that category.
    """
    # test_df.category is numeric (built with np.append), while the class ids
    # coming from the metadata are strings, so normalise before comparing.
    category = float(label)
    positional = test_df.reset_index(drop=True)
    candidates = positional[positional['category'] == category]
    if candidates.empty:
        raise ValueError('No sample of category {} in test_df'.format(label))
    id_sample_to_explain = candidates.sample(n=1).index[0]
    print('Interpreting sample with index {} in test_df'.format(id_sample_to_explain))
    return id_sample_to_explain


def interpret_with_lime(id_sample_to_explain):
    """Explain the row of test_df at position ``id_sample_to_explain`` with LIME.

    Builds a fresh "lime" UWasher around ``classifier`` and requests an
    explanation for the sample's top label.
    """
    interpreter_params = {
        # TODO provide interpreter params
    }
    explainer = UWasher("lime", classifier, **interpreter_params)

    # Feature vector of the chosen row (the last column of test_df is the category).
    features = test_df.values[id_sample_to_explain, :-1]
    explainer.interpret(features, explanation_configs={'top_labels': 1})


def interpret_with_anchor(id_sample_to_explain):
    """Explain the row of test_df at position ``id_sample_to_explain`` with anchors.

    The anchor explainer is first fitted on the training and validation splits.
    It is then asked to explain the chosen sample, using the arg-max class of
    ``classifier`` as the prediction.
    """
    explanation_configs = {}
    interpreter_params = {
        'feature_names': data_df.columns[:-1],
        # Materialise the dict view: a dict_values object cannot be indexed
        # by position.
        'class_names': list(class_names.values()),
        'categorical_names': {}
    }

    explainer = UWasher('anchors', classifier, **interpreter_params)

    # NOTE(review): the labels are passed one-hot encoded. Confirm that the
    # anchor explainer's fit() does not expect 1-D integer labels instead.
    explainer.explainer.fit(
        data_train, labels_train, data_valid, labels_valid
    )

    def new_predict(sample, **kwargs):
        # Convert class probabilities into hard (0-based) class indices.
        return np.argmax(classifier.predict(sample, **kwargs), axis=1)

    # explain the chosen instance wrt the predicted label
    explainer.interpret(
        test_df.values[id_sample_to_explain, :-1],
        explanation_configs=explanation_configs,
        callback=new_predict,
    )

    
def interpret_random_from_class(label, interpreter):
    """Explain a random test sample of class ``label`` with ``interpreter``.

    Args:
        label: category id to sample from (see ``random_from_class``).
        interpreter: either 'lime' or 'anchor'.

    Raises:
        ValueError: for any other interpreter name. This is checked before
            sampling, so nothing is printed for a bad name.
    """
    interpreters = {
        'lime': interpret_with_lime,
        'anchor': interpret_with_anchor,
    }
    if interpreter not in interpreters:
        raise ValueError(
            'Unknown interpreter {!r}; expected one of {}'.format(interpreter, sorted(interpreters))
        )
    id_sample_to_explain = random_from_class(label)
    interpreters[interpreter](id_sample_to_explain)

In [None]:
# Widget options are (display name, category id). The ids are converted to
# float because the class_names keys are strings (the metadata is read with
# dtype=str), while test_df.category is numeric.
interact_manual(
    interpret_random_from_class,
    label=[(name, float(category)) for category, name in class_names.items()],
    interpreter=['lime', 'anchor'],
);

In [None]:
# Explain one fixed sample with anchors. The value is a positional row index
# into test_df, so it must be smaller than len(test_df).
interpret_with_anchor(id_sample_to_explain=8373)

Compare qualitatively to **B** and **C** (though the image does not depict this exact dataset).
![manual_gated](https://science.sciencemag.org/content/sci/332/6030/687/F2.large.jpg?width=800&height=600&carousel=1)
from https://science.sciencemag.org/content/332/6030/687/tab-figures-data