# Understanding PaccMann

In [None]:
%%capture
# import all the needed libraries
import numpy as np
import pandas as pd
import tempfile
from rdkit import Chem
from sklearn.model_selection import train_test_split
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import SVG, display
from depiction.models.examples.paccmann import PaccMannSmiles, PaccMannCellLine
from depiction.models.examples.paccmann.smiles import (
    get_smiles_language, smiles_attention_to_svg,
    process_smiles, get_atoms
)
from depiction.core import Task, DataType
from depiction.interpreters.u_wash.u_washer import UWasher

# Fresh temporary directory handed as `cache_dir` to the PaccMann models below.
# tempfile.mkdtemp() does not delete the directory; cleanup is left to the user.
cache_dir = tempfile.mkdtemp()

## Data

In [None]:
# Load the GDSC inputs shared by every cell of the notebook.
# drugs: tab-separated file read without a header row; the single named
# column is 'smiles' and column 1 is used as the index
drugs = pd.read_csv(
    '../data/paccmann/gdsc.smi',
    sep='\t',
    header=None,
    names=['smiles'],
    index_col=1,
)
# cell lines: column 1 of the CSV becomes the index
cell_lines = pd.read_csv(
    '../data/paccmann/gdsc.csv.gz',
    index_col=1,
)
# every column from position 3 onwards is treated as a gene feature
# (presumably the leading columns hold metadata -- confirm against the file)
genes = cell_lines.columns[3:].tolist()
# sensitivity table: column 0 of the CSV becomes the index
drug_sensitivity = pd.read_csv(
    '../data/paccmann/gdsc_sensitivity.csv.gz',
    index_col=0,
)
# class labels; position 1 is the 'Effective' class
class_names = ['Not Effective', 'Effective']

## Interpretability on the drug level for a cell line of interest

### LIME and Anchor

In [None]:
# cell line whose drug responses are going to be explained
selected_cell_line = 'NCI-H1648'
# restrict the sensitivity table to rows measured on that cell line
selected_drug_sensitivity = drug_sensitivity.loc[
    drug_sensitivity['cell_line'] == selected_cell_line
]
# look the tested drugs up in `drugs`; labels missing there produce NaN rows
# on reindex and are removed by dropna
selected_drugs = (
    drugs
    .reindex(selected_drug_sensitivity['drug'])
    .dropna()
)
# re-key the sensitivity rows by drug and align them with the retained drugs
selected_drug_sensitivity = (
    selected_drug_sensitivity
    .set_index('drug')
    .reindex(selected_drugs.index)
    .dropna()
)
# classifier built from the gene columns of the selected cell line
classifier = PaccMannSmiles(
    cell_lines.loc[selected_cell_line][genes].values,
    cache_dir=cache_dir,
)
# interpretability methods
def interpret_smiles_with_lime(example):
    """Run the LIME interpreter on one example with the current ``classifier``.

    The interpreter is built through ``UWasher('lime', ...)`` with
    ``split_expression=list``, ``bow=False`` and ``char_level=True``
    (presumably so the input is treated character by character -- confirm
    against the UWasher/LIME documentation). The explanation is requested
    for label index 1, which is 'Effective' in ``class_names``.

    Args:
        example: value handed to ``interpret``; the caller in this notebook
            passes a SMILES entry taken from ``drugs``.
    """
    lime_explainer = UWasher(
        'lime', classifier,
        class_names=class_names,
        split_expression=list,
        bow=False,
        char_level=True,
    )
    lime_explainer.interpret(example, explanation_configs={'labels': (1,)})


def interpret_smiles_with_anchor(example):
    """Run the anchors interpreter on one example with the current ``classifier``.

    The interpreter is built through ``UWasher('anchors', ...)`` using the
    language object returned by ``get_smiles_language()``, '*' as unknown
    token and an empty separator. Predictions are supplied to the interpreter
    as class indices (argmax over axis 1 of ``classifier.predict``).

    Args:
        example: value handed to ``interpret``; the caller in this notebook
            passes a SMILES entry taken from ``drugs``.
    """
    def predict_wrapper(samples):
        # turn the classifier output into one class index per sample
        return np.argmax(classifier.predict(samples), axis=1)

    anchor_explainer = UWasher(
        'anchors', classifier,
        class_names=class_names,
        nlp=get_smiles_language(),
        unk_token='*',
        sep_token='',
        use_unk_distribution=True,
    )
    anchor_explainer.interpret(
        example,
        explanation_configs={'use_proba': False, 'batch_size': 32},
        callback=predict_wrapper,
    )

    
def interpret_smiles(interpreter, drug):
    """Explain the prediction for ``drug`` with the chosen interpreter.

    Args:
        interpreter: 'lime' selects LIME; any other value selects anchors.
        drug: index label used to look the SMILES entry up in ``drugs``.
    """
    handler = (
        interpret_smiles_with_lime
        if interpreter == 'lime'
        else interpret_smiles_with_anchor
    )
    smiles = drugs.loc[drug].item()
    handler(smiles)

In [None]:
# Widget with two dropdowns (interpreter, drug) that calls interpret_smiles;
# interact_manual waits for an explicit button press before running.
# The trailing ';' keeps the notebook from echoing the call's return value.
interact_manual(
    interpret_smiles, interpreter=['lime', 'anchor'],
    drug=drugs.index
);

### What about PaccMann's attention?

In [None]:
# pick a cell line
selected_cell_line = 'NCI-H1648'
# setup a classifier for the specific cell line; `classifier` is a
# notebook-level name that other cells rebind, so it is re-created here
# before the attention cells use it
classifier = PaccMannSmiles(cell_lines.loc[selected_cell_line][genes].values, cache_dir=cache_dir)

In [None]:
def attention_smiles(drug):
    """Display PaccMann's SMILES attention for ``drug`` as an SVG.

    The SMILES entry of ``drug`` is looked up in ``drugs``, the notebook-level
    ``classifier`` is run on it, and the ``smiles_attention`` entry of the
    next item from ``classifier.predictor.predictions`` is rendered with
    ``smiles_attention_to_svg``.

    Args:
        drug: index label used to look the SMILES entry up in ``drugs``.

    Returns:
        None. Either an SVG is displayed or, when any step fails, the
        message 'Structure visualization not supported' is printed.
    """
    try:
        smiles = drugs.loc[drug].item()
        molecule = Chem.MolFromSmiles(smiles)
        atoms = get_atoms(smiles)
        # the return value is not needed; the attention is read from
        # classifier.predictor.predictions below (presumably refreshed by
        # this call -- confirm against the PaccMann wrapper)
        classifier.predict([smiles])
        smiles_attention = next(classifier.predictor.predictions)['smiles_attention'][0]
        display(SVG(smiles_attention_to_svg(smiles_attention, atoms, molecule)))
    except Exception:
        # Best-effort widget callback: ordinary failures fall back to a
        # message. Catching Exception instead of using a bare except lets
        # KeyboardInterrupt and SystemExit propagate, so the kernel can
        # still be interrupted while a prediction is running.
        print('Structure visualization not supported')

In [None]:
# Widget with a drug dropdown that calls attention_smiles for the selection.
# The trailing ';' keeps the notebook from echoing the call's return value.
interact(
    attention_smiles,
    drug=drugs.index
);

## Interpretability on the cell line level for a drug of interest

### LIME and Anchor

In [None]:
# drug whose effect on the different cell lines is going to be explained
selected_drug = 'Imatinib'
# restrict the sensitivity table to rows measured with that drug
selected_drug_sensitivity = drug_sensitivity.loc[
    drug_sensitivity['drug'] == selected_drug
]
# look the tested cell lines up in `cell_lines`; labels missing there produce
# NaN rows on reindex and are removed by dropna
selected_cell_lines = (
    cell_lines
    .reindex(selected_drug_sensitivity['cell_line'])
    .dropna()
)
# re-key the sensitivity rows by cell line and align them with the retained ones
selected_drug_sensitivity = (
    selected_drug_sensitivity
    .set_index('cell_line')
    .reindex(selected_cell_lines.index)
    .dropna()
)
# first split: train vs. held-out (default train_test_split proportions, unseeded)
X_train, X_test, y_train, y_test = train_test_split(
    selected_cell_lines[genes].values,
    selected_drug_sensitivity['effective'].values,
)
# second split: the held-out part is divided again into test and validation
X_test, X_valid, y_test, y_valid = train_test_split(X_test, y_test)
# classifier built from the SMILES entry of the selected drug
classifier = PaccMannCellLine(
    drugs.loc[selected_drug].item(),
    cache_dir=cache_dir,
)
# interpretability methods
def interpret_cell_line_with_lime(example):
    """Run the LIME interpreter on one cell-line example with the current ``classifier``.

    The interpreter is built through ``UWasher('lime', ...)`` from the
    notebook-level training split (``X_train``, ``y_train``), with ``genes``
    as feature names, continuous features left undiscretized and sampling
    around the instance enabled. The explanation is requested for label
    index 1, which is 'Effective' in ``class_names``.

    Args:
        example: value handed to ``interpret``; the caller in this notebook
            passes the ``genes`` values of one row of ``cell_lines``.
    """
    lime_explainer = UWasher(
        'lime', classifier,
        training_data=X_train,
        training_labels=y_train,
        feature_names=genes,
        class_names=class_names,
        discretize_continuous=False,
        sample_around_instance=True,
    )
    lime_explainer.interpret(example, explanation_configs={'labels': (1,)})


def interpret_cell_line_with_anchor(example):
    """Run the anchors interpreter on one cell-line example with the current ``classifier``.

    The interpreter is built through ``UWasher('anchors', ...)`` with
    ``genes`` as feature names and no categorical features, and its inner
    ``explainer`` is fitted on the notebook-level train and validation
    splits before the example is interpreted. Predictions are supplied to
    the interpreter as class indices (argmax over axis 1 of
    ``classifier.predict``).

    Args:
        example: value handed to ``interpret``; the caller in this notebook
            passes the ``genes`` values of one row of ``cell_lines``.
    """
    def predict_wrapper(samples):
        # turn the classifier output into one class index per sample
        return np.argmax(classifier.predict(samples), axis=1)

    anchor_explainer = UWasher(
        'anchors', classifier,
        feature_names=genes,
        class_names=class_names,
        categorical_names={},
    )
    anchor_explainer.explainer.fit(X_train, y_train, X_valid, y_valid)
    anchor_explainer.interpret(
        example,
        explanation_configs={},
        callback=predict_wrapper,
    )

    
def interpret_cell_line(interpreter, cell_line):
    """Explain the prediction for ``cell_line`` with the chosen interpreter.

    Args:
        interpreter: 'lime' selects LIME; any other value selects anchors.
        cell_line: index label used to look the row up in ``cell_lines``;
            only its ``genes`` columns are passed on.
    """
    handler = (
        interpret_cell_line_with_lime
        if interpreter == 'lime'
        else interpret_cell_line_with_anchor
    )
    profile = cell_lines.loc[cell_line][genes].values
    handler(profile)

In [None]:
# Widget with two dropdowns (interpreter, cell line) that calls
# interpret_cell_line; interact_manual waits for an explicit button press.
# The trailing ';' keeps the notebook from echoing the call's return value.
interact_manual(
    interpret_cell_line, interpreter=['lime', 'anchor'],
    cell_line=cell_lines.index
);

### What about PaccMann's attention?

In [None]:
# pick a drug
selected_drug = 'Imatinib'
# classifier for the selected drug, built from its SMILES entry in `drugs`;
# `classifier` is a notebook-level name that other cells rebind, so it is
# re-created here before the attention cells use it
classifier = PaccMannCellLine(drugs.loc[selected_drug].item(), cache_dir=cache_dir)

In [None]:
def attention_cell_line(cell_line, top_k=10):
    """Plot the genes with the highest PaccMann attention for ``cell_line``.

    The notebook-level ``classifier`` is run on the ``genes`` values of the
    selected row of ``cell_lines``; the ``gene_attention`` entry of the next
    item from ``classifier.predictor.predictions`` is paired with ``genes``,
    sorted in descending order and the first ``top_k`` entries are drawn as
    a bar plot.

    Args:
        cell_line: index label used to look the row up in ``cell_lines``.
        top_k: number of highest-attention genes to plot (default 10).

    Returns:
        None. Either a bar plot is drawn or, when any step fails, the
        message 'Cell line visualization not supported' is printed.
    """
    try:
        # the return value is not needed; the attention is read from
        # classifier.predictor.predictions below (presumably refreshed by
        # this call -- confirm against the PaccMann wrapper)
        classifier.predict([cell_lines.loc[cell_line][genes].values])
        gene_attention = next(classifier.predictor.predictions)['gene_attention'][0]
        attention_by_gene = pd.Series(dict(zip(genes, gene_attention)))
        ranked = attention_by_gene.sort_values(ascending=False)
        ranked[:top_k].plot.bar()
    except Exception:
        # Best-effort widget callback: ordinary failures fall back to a
        # message. Catching Exception instead of using a bare except lets
        # KeyboardInterrupt and SystemExit propagate, so the kernel can
        # still be interrupted while a prediction is running.
        print('Cell line visualization not supported')

In [None]:
# Widget with a cell-line dropdown and a top_k slider (1 to 30, step 1)
# that calls attention_cell_line for the current selection.
# The trailing ';' keeps the notebook from echoing the call's return value.
interact(
    attention_cell_line, cell_line=cell_lines.index,
    top_k=(1, 30, 1)
);