In [None]:
!git clone https://github.com/IBM/dl-interpretability-compbio.git
!pip3 install dl-interpretability-compbio/ 

In [None]:
!pip3 install -I scikit-image>=0.14.2

In [None]:
import skimage
print(skimage.__version__)

**TODO now:** restart the runtime so that the newly installed scikit-image version is loaded.

In [None]:
!wget https://repo.anaconda.com/archive/Anaconda3-5.2.0-Linux-x86_64.sh
!bash Anaconda3-5.2.0-Linux-x86_64.sh -b -f -p /usr/local
!conda install -y rdkit=2019.03.1=py36hc20afe1_1 -c https://conda.anaconda.org/rdkit

In [None]:
import sys
sys.path.insert(0, '/usr/local/lib/python3.6/site-packages')
%cd /content/dl-interpretability-compbio/notebooks

# Training of a super simple model for celltype classification

In [None]:
import tensorflow as tf
!which python
!python --version
print(tf.VERSION)
print(tf.keras.__version__)
!pwd #  start jupyter under notebooks/ for correct relative paths

In [None]:
import datetime
import inspect
import pandas as pd
import numpy as np
import seaborn as sns
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from depiction.models.examples.celltype.celltype import one_hot_encoding, one_hot_decoding

### a look at the data
Labels are categories 1-20; here is the cell type associated with each:

In [None]:
meta_series = pd.read_csv('../data/single-cell/metadata.csv', index_col=0)
meta_series

There are 13 imbalanced classes and over 80k samples.

In [None]:
data_df = pd.read_csv('../data/single-cell/data.csv')
data_df.groupby('category').count()['CD45']

In [None]:
data_df.sample(n=10)

In [None]:
print(inspect.getsource(one_hot_encoding)) # from keras, but taking care of 1 indexed classes
print(inspect.getsource(one_hot_decoding))

In [None]:
classes = data_df['category'].values
labels = one_hot_encoding(classes)

#scale the data from 0 to 1
min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
data = min_max_scaler.fit_transform(data_df.drop('category', axis=1).values)
data.shape

In [None]:
one_hot_decoding(labels)

In [None]:
data_train, data_test, labels_train, labels_test = train_test_split(
    data, labels, test_size=0.33, random_state=42, stratify=data_df.category)

In [None]:
labels

In [None]:
batchsize = 32

In [None]:
# Training pipeline: shuffle, batch, and repeat indefinitely; model.fit bounds
# each epoch through `steps_per_epoch`.
# NOTE(review): a shuffle buffer of 2 * batchsize only mixes neighbouring samples,
# so batch composition changes little between epochs (train_test_split has
# already shuffled the rows once) — consider a larger buffer.
dataset = tf.data.Dataset.from_tensor_slices((data_train, labels_train))
dataset = dataset.shuffle(2 * batchsize).batch(batchsize)
dataset = dataset.repeat()

# Evaluation pipeline: batched only, no shuffling or repetition.
testset = tf.data.Dataset.from_tensor_slices((data_test, labels_test))
testset = testset.batch(batchsize)

### I don't know what a simpler network would look like

In [None]:
model = tf.keras.Sequential()
# Add a softmax layer with output units per celltype:
model.add(layers.Dense(
    len(meta_series), activation='softmax',
    batch_input_shape=tf.data.get_output_shapes(dataset)[0]
))

In [None]:
model.summary()

In [None]:
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='categorical_crossentropy',
              metrics=[tf.keras.metrics.categorical_accuracy])

In [None]:
# evaluation on testset on every epoch
# log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
model.fit(
    dataset,
    epochs=20,
    # `dataset` repeats forever, so the epoch length must be given explicitly.
    # Keras expects an integer step count; np.ceil returns a float.
    steps_per_epoch=int(np.ceil(data_train.shape[0] / batchsize)),
    validation_data=testset, #  callbacks=[tensorboard_callback]
)

### Is such a simple model interpretable?

In [None]:
# Save entire model to a HDF5 file
model.save('./celltype_model.h5')

In [None]:
# tensorboard --logdir logs/fit

In [None]:
# To recreate the exact same model, including weights and optimizer.
# model = tf.keras.models.load_model('../data/models/celltype_dnn_model.h5')

## What is the effect of increasing model complexity? 
Play around by adding some layers, then train the model and save it under a new name to use with the other notebook.

![title](https://i.kym-cdn.com/photos/images/newsfeed/000/531/557/a88.jpg)

In [None]:
model = tf.keras.Sequential()
# Adds a densely-connected layers with 64 units to the model:
model.add(layers.Dense(64, activation='relu', batch_input_shape=tf.data.get_output_shapes(dataset)[0])) # 
# ...
# do whatever you want
# model.add(layers.Dense(64, activation='relu'))
# model.add(layers.Dropout(0.5))
# ...
# Add a softmax layer with output units per celltype:
model.add(layers.Dense(len(meta_series), activation='softmax'))

In [None]:
%reset

# Interpreting Cell Typer

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from depiction.models.base.base_model import BaseModel
from depiction.models.examples.celltype.celltype import CellTyper
from depiction.interpreters.u_wash.u_washer import UWasher
from depiction.interpreters.aix360.rule_based_model import RuleAIX360
from depiction.models.base import BinarizedClassifier
from depiction.core import Task, DataType
from tensorflow import keras

### Data

In [None]:
# Load data
datapath = '../data/single-cell/data.csv'
data_df = pd.read_csv(datapath)
sns.countplot(x=data_df.category)

# scale the data from 0 to 1
min_max_scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
feature_columns = data_df.columns.drop('category')
data = min_max_scaler.fit_transform(data_df.drop('category', axis=1).values)
# Rebuild the frame with 'category' explicitly as the last column. The scaled
# features come first in the array, and later cells rely on `values[:, -1]`
# and `columns[:-1]` to separate label and features.
data_df = pd.DataFrame(
    np.append(data, data_df['category'].values[:, None], axis=1),
    index=data_df.index,
    columns=list(feature_columns) + ['category'],
)

# split as in training of the model (same test_size, random_state and stratification)
train_df, test_df = train_test_split(data_df, test_size=0.33, random_state=42, stratify=data_df.category)
# Seed the second split too: the sample index hard-coded further below (4773) is a
# position in test_df and only refers to the same cell if this split is reproducible.
test_df, valid_df = train_test_split(test_df, test_size=0.67, random_state=42, stratify=test_df.category)


train_df.head()

In [None]:
CellTyper.celltype_names

### Loading a pretrained model
is actually done under the hood by a child implementation of `depiction.models.uri.HTTPModel`  
Change `filename` (there's also `cache_dir`) to load a different model.

In [None]:
# Import trained classifier
classifier = CellTyper(filename='celltype_model.h5')

In [None]:
classifier.model.summary()

### Layer weights

In [None]:
weights = classifier.model.layers[0].get_weights()[0]
sns.heatmap(pd.DataFrame(
    weights,
    index=data_df.columns[:-1],
#     columns=CellTyper.celltype_names.values()
).T)

Compare qualitatively to panels **B** and **C** (though the image does not depict this exact dataset).
![manual_gated](https://science.sciencemag.org/content/sci/332/6030/687/F2.large.jpg?width=800&height=600&carousel=1)
from https://science.sciencemag.org/content/332/6030/687/tab-figures-data

### Interpretability methods
helper functions and a widget to sample from a class

In [None]:
def random_from_class(label):
    """Draw one random test sample whose category equals `label`.

    Prints and returns the sample's positional index into `test_df`, suitable
    for `test_df.iloc[...]` / `test_df.values[...]`.
    """
    positional_df = test_df.reset_index()
    same_class = positional_df.query('category==@label')
    chosen_index = same_class.sample(n=1).index[0]
    print('Interpreting sample with index {} in test_df'.format(chosen_index))
    return chosen_index


def interpret_with_lime(id_sample_to_explain):
    """Explain the cell-type prediction for one test sample with LIME.

    Args:
        id_sample_to_explain: positional index of the sample in `test_df`
            (e.g. as returned by `random_from_class`).
    """
    # Create a LIME tabular interpreter
    explanation_configs = {
        'top_labels': 1
    }
    interpreter_params = {
        'training_data': train_df.values[:, :-1],
        'training_labels': train_df.values[:, -1],
        'feature_names': train_df.columns[:-1],
        'verbose': True,
        # Materialize as a list: a dict_values view can be neither indexed nor
        # JSON-serialized, both of which an explainer may do with class names.
        'class_names': list(classifier.celltype_names.values()),
        'discretize_continuous': False,
        'sample_around_instance': True
    }

    explainer = UWasher("lime", classifier, **interpreter_params)

    # explain the chosen instance wrt the chosen labels
    explainer.interpret(test_df.values[id_sample_to_explain, :-1], explanation_configs=explanation_configs)


def interpret_with_anchor(id_sample_to_explain):
    """Explain the cell-type prediction for one test sample with Anchors.

    The anchor explainer is first fitted on the train/validation splits, then
    queried through a predictor that returns hard labels (argmax over the
    classifier's outputs).

    Args:
        id_sample_to_explain: positional index of the sample in `test_df`.
    """
    explanation_configs = {}
    interpreter_params = {
        'feature_names': train_df.columns[:-1],
        # A real list: dict_values cannot be indexed or JSON-serialized.
        'class_names': list(classifier.celltype_names.values()),
        'categorical_names': {}
    }

    explainer = UWasher('anchors', classifier, **interpreter_params)
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
    # NOTE(review): these labels are the raw category values (described as
    # 1-indexed where one_hot_encoding is inspected), whereas new_predict returns
    # 0-based argmax indices — confirm the explainer never compares the two.
    X_train = train_df.values[:, :-1]
    y_train = train_df.values[:, -1].astype(int)
    X_valid = valid_df.values[:, :-1]
    y_valid = valid_df.values[:, -1].astype(int)
    explainer.explainer.fit(
        X_train, y_train, X_valid, y_valid
    )

    def new_predict(sample, **kwargs):
        # Anchors works on hard labels, so reduce the model output to its argmax.
        return np.argmax(classifier.predict(sample, **kwargs), axis=1)

    # explain the chosen instance wrt the chosen labels
    explainer.interpret(
        test_df.values[id_sample_to_explain, :-1],
        explanation_configs=explanation_configs,
        callback=new_predict
    )


def visualize_logits(id_sample_to_explain):
    """Heatmap of the classifier's per-celltype output for one test sample.

    `id_sample_to_explain` is a positional index into `test_df`.
    """
    features = test_df.iloc[id_sample_to_explain, :-1]
    predictions = classifier.predict([[features]])
    celltype_columns = CellTyper.celltype_names.values()
    output_df = pd.DataFrame(predictions, columns=celltype_columns)
    sns.heatmap(output_df.T)


def visualize(id_sample_to_explain, layer):
    """Visualize one test sample at a given layer, or the model's output.

    Args:
        id_sample_to_explain: positional index of the sample in `test_df`.
        layer: index into `classifier.model.layers`, or None to plot the
            model's output via `visualize_logits`.

    For a layer index, plots the layer's first weight array (the kernel for a
    Dense layer), transposed and multiplied element-wise by the layer's input:
    each input's weighted contribution to each unit, before bias/activation.
    Layers without weights (e.g. Dropout) are reported instead of crashing.
    """
    sample = test_df.iloc[id_sample_to_explain, :-1]
    if layer is None:
        visualize_logits(id_sample_to_explain)
        return

    layer_weights = classifier.model.layers[layer].get_weights()
    if not layer_weights:
        # get_weights() is empty for weightless layers; indexing [0] would raise IndexError.
        print('Layer {!r} has no weights to visualize'.format(classifier.model.layers[layer].name))
        return
    weights = layer_weights[0]

    if layer == 0:
        # the input of the first layer is the sample itself (1-D, so no transpose needed)
        layer_input = sample.values
    else:
        # the input of `layer` is the output of the preceding layer, so we build
        # a sub-model that stops there
        activation_model = keras.models.Model(
            inputs=classifier.model.input,
            outputs=classifier.model.layers[layer - 1].output
        )
        layer_input = activation_model.predict([[sample]])[0]

    weighted_output = weights.transpose() * layer_input
    sns.heatmap(weighted_output)


def visualize_random_from_class(label, layer):
    """Pick a random test sample of class `label` and visualize it at `layer`."""
    sample_index = random_from_class(label)
    visualize(sample_index, layer)

In [None]:
# Interactive explorer: choose a cell type (displayed by name, passed on as its
# key in celltype_names) and a layer (by name, or "logits" for the model output),
# then press the button to visualize a random test sample of that class.
interact_manual(
    visualize_random_from_class,
    label=[(v, k) for k, v in classifier.celltype_names.items()],
    layer=dict(
        **{layer.name: i for i, layer in enumerate(classifier.model.layers)}, logits=None
    )
);

In [None]:
visualize_logits(4773)

In [None]:
interpret_with_lime(4773)

In [None]:
interpret_with_anchor(4773)

## Global interpretation with rule-based models 

In [None]:
# Map cell-type name -> category id. Built from the mapping itself instead of
# assuming its keys are exactly 1..N (which raised KeyError otherwise).
LABEL2ID = {name: category for category, name in CellTyper.celltype_names.items()}

In [None]:
LABEL_TO_EXPLAIN = 'Mature CD4+ T'
LABEL_ID = LABEL2ID[LABEL_TO_EXPLAIN]

### Data preparation and auxiliary functions

In [None]:
markers = train_df.columns[:-1]

X_train = train_df[markers]
X_test = test_df[markers]

# One-vs-rest targets: 1 for the label being explained, 0 otherwise.
# `np.float` / `np.int` were removed in NumPy 1.24. Pandas compares the float
# category column with the int id numerically, and the builtin `int` is the
# same dtype that `np.int` aliased.
y_train = (train_df['category'] == LABEL_ID).values.astype(int)
y_test = (test_df['category'] == LABEL_ID).values.astype(int)

### Post-Hoc explanation 

In [None]:
# Binarize the task to use this method
model = BinarizedClassifier(classifier, data_type=DataType.TABULAR, label_index=LABEL_ID)

#### BRCG

In [None]:
interpreter = RuleAIX360('brcg', X=X_train, model=model)
interpreter.interpret()

#### GLRM - Linear

In [None]:
interpreter = RuleAIX360('glrm_linear', X=X_train, model=model)
interpreter.interpret()

#### GLRM - Logistic

In [None]:
interpreter = RuleAIX360('glrm_logistic', X=X_train, model=model)
interpreter.interpret()

### Ante-Hoc explanation

#### BRCG

In [None]:
interpreter = RuleAIX360('brcg', X=X_train, y=y_train)
interpreter.interpret()

#### GLRM - Linear

In [None]:
interpreter = RuleAIX360('glrm_linear', X=X_train, y=y_train)
interpreter.interpret()

#### GLRM - Logistic

In [None]:
interpreter = RuleAIX360('glrm_logistic', X=X_train, y=y_train)
interpreter.interpret()

In [None]:
%reset

# Having fun with DeepBind

In [None]:
from depiction.models.examples.deepbind.deepbind import DeepBind
from depiction.core import Task, DataType
from depiction.interpreters.u_wash.u_washer import UWasher
from depiction.models.examples.deepbind.deepbind import create_DNA_language
from ipywidgets import interact

### Setup task

In [None]:
class_names = ['NOT BINDING', 'BINDING']
tf_factor_id = 'D00328.003' # CTCF
classifier = DeepBind(tf_factor_id = tf_factor_id)

## Instantiate the interpreters

In [None]:
# Create a LIME text interpreter
# Sequences are split into single characters (`list`) without bag-of-words, and
# explanations are computed for label 1 ('BINDING').
interpreter = 'lime'
lime_explanation_configs = {
    'labels': (1,),
}
interpreter_params = {
    'class_names': class_names,
    'split_expression': list,
    'bow': False,
    'char_level': True
}

lime_explainer = UWasher(interpreter, classifier, **interpreter_params)

# Create Anchor text interpreter
# 'N' serves as the unknown token when perturbing sequences; tokens are joined
# without a separator so perturbed samples remain contiguous sequences.
interpreter = 'anchors'
anchors_explanation_configs = {
    'use_proba': False,
    'batch_size': 100
}
interpreter_params = {
    'class_names': class_names,
    'nlp': create_DNA_language(),
    'unk_token': 'N',
    'sep_token': '',
    'use_unk_distribution': True
}

anchor_explainer = UWasher(interpreter, classifier, **interpreter_params)

#### Let's use LIME and Anchors

In [None]:
# sequences = [
#        'AGGCTAGCTAGGGGCGCCC', 'AGGCTAGCTAGGGGCGCTT', 'AGGGTAGCTAGGGGCGCTT',
#        'AGGGTAGCTGGGGGCGCTT', 'AGGCTAGGTGGGGGCGCTT', 'AGGCTCGGTGGGGGCGCTT',
#        'AGGCTCGGTAGGGGGCGATT'
#    ]
sequence = 'AGGCTCGGTAGGGGGCGATT'

classifier.use_labels = False
lime_explainer.interpret(sequence, explanation_configs=lime_explanation_configs)

classifier.use_labels = True
anchor_explainer.interpret(sequence, explanation_configs=anchors_explanation_configs)

## Let's interpret

CTCF binding motif
![CTCF binding motif](https://media.springernature.com/full/springer-static/image/art%3A10.1186%2Fgb-2009-10-11-r131/MediaObjects/13059_2009_Article_2281_Fig2_HTML.jpg?as=webp)
from Essien, Kobby, et al. "CTCF binding site classes exhibit distinct evolutionary, genomic, epigenomic and transcriptomic features." Genome biology 10.11 (2009): R131.

In [None]:
tf_factor_id = 'D00761.001' # FOXA1
classifier.tf_factor_id = tf_factor_id

In [None]:
sequence = 'TGTGTGTGTG'

classifier.use_labels = False
lime_explainer.interpret(sequence, explanation_configs=lime_explanation_configs)

classifier.use_labels = True
anchor_explainer.interpret(sequence, explanation_configs=anchors_explanation_configs)

FOXA1 binding motif
![FOXA1 binding motif](https://ismara.unibas.ch/supp/dataset1_IBM_v2/ismara_report/logos/FOXA1.png)
from https://ismara.unibas.ch/supp/dataset1_IBM_v2/ismara_report/pages/FOXA1.html

In [None]:
%reset

# Understanding PaccMann

In [None]:
%%capture
# import all the needed libraries
import numpy as np
import pandas as pd
import tempfile
from rdkit import Chem
from sklearn.model_selection import train_test_split
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import SVG, display
from depiction.models.examples.paccmann import PaccMannSmiles, PaccMannCellLine
from depiction.models.examples.paccmann.smiles import (
    get_smiles_language, smiles_attention_to_svg,
    process_smiles, get_atoms
)
from depiction.core import Task, DataType
from depiction.interpreters.u_wash.u_washer import UWasher

cache_dir = tempfile.mkdtemp()

### Data

In [None]:
# Parse data from GDSC
# drugs
drugs = pd.read_csv(
    '../data/paccmann/gdsc.smi', sep='\t',
    index_col=1, header=None,
    names=['smiles']
)
# cell lines
cell_lines = pd.read_csv('../data/paccmann/gdsc.csv.gz', index_col=1)
genes = cell_lines.columns[3:].tolist()
# sensitivity data
drug_sensitivity = pd.read_csv('../data/paccmann/gdsc_sensitivity.csv.gz', index_col=0)
# labels
class_names = ['Not Effective', 'Effective']

### Interpretability on the drug level for a cell line of interest

#### LIME and Anchor

In [None]:
# pick a cell line
selected_cell_line = 'NCI-H1648'
# filter and prepare data
selected_drug_sensitivity = drug_sensitivity[
    drug_sensitivity['cell_line'] == selected_cell_line
]
selected_drugs = drugs.reindex(selected_drug_sensitivity['drug']).dropna()
selected_drug_sensitivity = selected_drug_sensitivity.set_index('drug').reindex(
    selected_drugs.index
).dropna()
# setup a classifier for the specific cell line
classifier = PaccMannSmiles(cell_lines.loc[selected_cell_line][genes].values, cache_dir=cache_dir)
# interpretablity methods
def interpret_smiles_with_lime(example):
    """Explain the sensitivity prediction for a SMILES string with LIME.

    The SMILES is split character by character (no bag-of-words), and the
    explanation targets label 1 ('Effective').
    """
    lime_params = dict(
        class_names=class_names,
        split_expression=list,
        bow=False,
        char_level=True,
    )
    lime_explainer = UWasher('lime', classifier, **lime_params)
    lime_explainer.interpret(example, explanation_configs={'labels': (1,)})


def interpret_smiles_with_anchor(example):
    """Explain the sensitivity prediction for a SMILES string with Anchors.

    '*' is used as the unknown token, and the explainer queries a predictor
    that returns hard labels (argmax of the classifier's outputs).
    """
    def predict_wrapper(samples):
        probabilities = classifier.predict(samples)
        return np.argmax(probabilities, axis=1)

    anchor_configs = {'use_proba': False, 'batch_size': 32}
    anchor_params = dict(
        class_names=class_names,
        nlp=get_smiles_language(),
        unk_token='*',
        sep_token='',
        use_unk_distribution=True,
    )
    anchor_explainer = UWasher('anchors', classifier, **anchor_params)
    anchor_explainer.interpret(example, explanation_configs=anchor_configs, callback=predict_wrapper)

In [None]:
# Dummy just to visualize the drugs neatly in Colab
interact_manual(
    lambda drug: print(drug), 
    drug=drugs.index
);

In [None]:
interpret_smiles_with_lime(drugs.loc['PHA-793887'].item())

In [None]:
interpret_smiles_with_anchor(drugs.loc['PHA-793887'].item())

#### What about PaccMann's attention?

In [None]:
# pick a cell line
selected_cell_line = 'NCI-H1648'
# setup a classifier for the specific cell line
classifier = PaccMannSmiles(cell_lines.loc[selected_cell_line][genes].values, cache_dir=cache_dir)

In [None]:
def attention_smiles(drug):
    """Render `classifier`'s SMILES attention for `drug` on its molecule drawing.

    This is best effort: any exception raised while parsing, predicting or
    drawing is reported instead of raised, so the interactive widget keeps working.
    """
    try:
        smiles = drugs.loc[drug].item()
        molecule = Chem.MolFromSmiles(smiles)
        atoms = get_atoms(smiles)
        _ = classifier.predict([smiles])
        smiles_attention = next(classifier.predictor.predictions)['smiles_attention'][0]
        display(SVG(smiles_attention_to_svg(smiles_attention, atoms, molecule)))
    except Exception as error:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; include the cause so failures are diagnosable.
        print('Structure visualization not supported ({}: {})'.format(type(error).__name__, error))

In [None]:
interact(
    attention_smiles,
    drug=drugs.index
);

### Interpretability on the cell line level for a drug of interest

#### LIME and Anchor

In [None]:
# pick a drug
selected_drug = 'Imatinib'
# filter and prepare data
selected_drug_sensitivity = drug_sensitivity[
    drug_sensitivity['drug'] == selected_drug
]
selected_cell_lines = cell_lines.reindex(selected_drug_sensitivity['cell_line']).dropna()
selected_drug_sensitivity = selected_drug_sensitivity.set_index('cell_line').reindex(
    selected_cell_lines.index
).dropna()
X_train, X_test, y_train, y_test = train_test_split(
    selected_cell_lines[genes].values, selected_drug_sensitivity['effective'].values
)
X_test, X_valid, y_test, y_valid = train_test_split(
    X_test, y_test
)
# setup a classifier for the specific drug
classifier = PaccMannCellLine(drugs.loc[selected_drug].item(), cache_dir=cache_dir)
# interpretablity methods
def interpret_cell_line_with_lime(example):
    """Explain the prediction for one cell line's gene expression with LIME.

    Uses the module-level training split (`X_train`, `y_train`) as LIME's
    background data, with `genes` as feature names; the explanation targets
    label 1 ('Effective').
    """
    lime_params = dict(
        training_data=X_train,
        training_labels=y_train,
        feature_names=genes,
        class_names=class_names,
        discretize_continuous=False,
        sample_around_instance=True,
    )
    lime_explainer = UWasher('lime', classifier, **lime_params)
    lime_explainer.interpret(example, explanation_configs={'labels': (1,)})


def interpret_cell_line_with_anchor(example):
    """Explain the prediction for one cell line's gene expression with Anchors.

    The anchor explainer is fitted on the module-level train/validation splits,
    then queried through a predictor returning hard labels (argmax).
    """
    def predict_wrapper(samples):
        scores = classifier.predict(samples)
        return np.argmax(scores, axis=1)

    anchor_explainer = UWasher(
        'anchors', classifier,
        feature_names=genes,
        class_names=class_names,
        categorical_names={},
    )
    anchor_explainer.explainer.fit(X_train, y_train, X_valid, y_valid)
    anchor_explainer.interpret(example, explanation_configs={}, callback=predict_wrapper)


In [None]:
# Dummy just to visualize the cell lines neatly
interact_manual(
    lambda cell_line: print(cell_line), 
    cell_line=cell_lines.index
);

In [None]:
interpret_cell_line_with_lime(
    cell_lines.loc['JiyoyeP-2003'][genes].values
)

In [None]:
interpret_cell_line_with_anchor(
    cell_lines.loc['JiyoyeP-2003'][genes].values
)

#### What about PaccMann's attention?

In [None]:
# pick a drug
selected_drug = 'Imatinib'
classifier = PaccMannCellLine(drugs.loc[selected_drug].item(), cache_dir=cache_dir)

In [None]:
def attention_cell_line(cell_line, top_k=10):
    """Bar-plot the `top_k` genes receiving the most attention for `cell_line`.

    This is best effort: any exception is reported instead of raised, so the
    interactive widget keeps working.
    """
    try:
        _ = classifier.predict([cell_lines.loc[cell_line][genes].values])
        gene_attention = next(classifier.predictor.predictions)['gene_attention'][0]
        pd.Series(dict(zip(genes, gene_attention))).sort_values(ascending=False)[:top_k].plot.bar()
    except Exception as error:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; include the cause so failures are diagnosable.
        print('Cell line visualization not supported ({}: {})'.format(type(error).__name__, error))

In [None]:
interact(
    attention_cell_line, cell_line=cell_lines.index,
    top_k=(1, 30, 1)
);