# Interpreting a cell-type classification model

In [None]:
import warnings; warnings.filterwarnings('ignore', category=FutureWarning)
import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
import pandas as pd
import numpy as np
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from depiction.models.base.base_model import BaseModel
from depiction.models.examples.celltype.celltype import CellTyper
from depiction.core import Task, DataType
from depiction.interpreters.backprop.backpropeter import BackPropeter
from depiction.models.keras.core import KerasModel
from depiction.explanations.feature_attribution import aggregate_attributions

## Load data

In [None]:
# Load the single-cell data: one row per cell, marker columns plus a 'category' label.
datapath = '../data/single-cell/data.csv'
data_df = pd.read_csv(datapath)

# Scale every marker column to [0, 1]; the 'category' label is left untouched.
min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
feature_df = data_df.drop('category', axis=1)
data = min_max_scaler.fit_transform(feature_df.values)
# Re-attach the label as the LAST column. Column names are rebuilt from
# feature_df so they stay aligned with the data even when 'category' is not
# the last column of the CSV (reusing data_df.columns would mislabel them).
data_df = pd.DataFrame(
    np.append(data, data_df['category'].values[:, None], axis=1),
    index=data_df.index,
    columns=list(feature_df.columns) + ['category'],
)

# Split as in the training of the model.
train_df, test_df = train_test_split(data_df, test_size=0.33, random_state=42, stratify=data_df.category)
# The second split keeps 33% of the held-out rows as test and 67% as validation.
test_df, valid_df = train_test_split(test_df, test_size=0.67, random_state=42, stratify=test_df.category)

train_df.head()

# Marker (feature) columns: everything except the label, selected by name.
markers = train_df.columns.drop('category')

X_train = train_df[markers].values
X_test = test_df[markers].values
X_valid = valid_df[markers].values

# np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
y_train = train_df['category'].values.astype(int)
y_test = test_df['category'].values.astype(int)
y_valid = valid_df['category'].values.astype(int)

# A single test cell (kept 2-D via slicing) to explain below.
sample_id = 4
sample = X_test[sample_id:sample_id+1]

## Load pretrained model

In [None]:
# Import trained classifier
classifier = CellTyper(filename='celltype_model.h5')
print(classifier.celltype_names)
classifier = KerasModel(classifier.model, Task.CLASSIFICATION, DataType.TABULAR)

## Create backpropagation-based explainers and compare their attributions

In [None]:
methods = ['saliency', 'shapley_sampling', 'occlusion', 'elrp']

# One attribution per method for the same `sample`; each is passed through
# make_positive() and normalize() before the methods are stacked together.
per_method_attributions = [
    BackPropeter(classifier, method_name).interpret(sample)[0].make_positive().normalize()
    for method_name in methods
]

explanations = aggregate_attributions(per_method_attributions, mode='none')
fig, ax = explanations.visualize(feature_names=markers, y_labels=methods, cmap='Reds')

# Interpreting protein–DNA binding (DeepBind FOXA1 model)

In [None]:
import warnings; warnings.filterwarnings('ignore', category=FutureWarning)
import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
from depiction.models.examples.deepbind.deepbind import DeepBind, create_DNA_language
from depiction.interpreters.u_wash.u_washer import UWasher, to_feature_attribution
from ipywidgets import interact

In [None]:
# DeepBind model for the FOXA1 transcription factor (ChIP-seq); min_length=40 is
# forwarded to DeepBind unchanged.
classifier_foxa1 = DeepBind(
    'DeepBind/Homo_sapiens/TF/D00761.001_ChIP-seq_FOXA1',
    min_length=40,
)

In [None]:
# Index 0 = not binding, index 1 = binding.
class_names = ['NOT BINDING', 'BINDING']

# Explain only label 1 ('BINDING').
lime_explanation_configs = dict(labels=(1,))

# LIME text-explainer options: split the sequence into single characters
# (split_expression=list, char_level=True) and do not use a bag-of-words model.
lime_params = dict(
    class_names=class_names,
    split_expression=list,
    bow=False,
    char_level=True,
)

In [None]:
# Explain the FOXA1 model's prediction on a single DNA sequence with LIME.
# The sequence is defined once so the explained input and the plotted tokens
# cannot drift apart.
sequence = "TGTTTACTTT"

lime_explainer = UWasher("lime", classifier_foxa1, **lime_params)
# NOTE(review): kept after the explainer is built, as in the original;
# presumably switches the model's output mode for LIME — confirm in DeepBind.
classifier_foxa1.use_labels = False
explanation = lime_explainer.interpret(sequence, explanation_configs=lime_explanation_configs)
# Convert to a feature attribution for the same label(s) LIME was asked to explain.
explanation = to_feature_attribution(
    explanation,
    classifier_foxa1.data_type,
    labels=list(lime_explanation_configs['labels']),
)
# One token per character, matching the char-level LIME configuration.
fig, ax = explanation.visualize(tokens=list(sequence), show=True, as_logo=True)