# Having fun with DeepBind

In [None]:
from depiction.models.deepbind import DeepBind
from depiction.core import Task, DataType
from depiction.interpreters.uw_model import UWModel
from depiction.models.deepbind import create_DNA_language
from ipywidgets import interact

## Setup task

In [None]:
# The model is treated as a classifier over text inputs. `class_names` gives
# the human-readable name for each class index (0 and 1) and is handed to the
# interpreters created further down in this notebook.
task, data_type = Task.CLASSIFICATION, DataType.TEXT
class_names = ["NOT BINDING", "BINDING"]

# Instantiate the interpreters

In [None]:
# ---- LIME text interpreter -------------------------------------------------
# `split_expression=list` turns a sequence string into a list of its single
# characters; together with `bow=False` and `char_level=True` this presumably
# makes LIME perturb individual nucleotide positions -- confirm against the
# LIME text explainer documentation.
interpreter = "lime"
explanation_configs = dict(labels=(1,))  # explain class index 1 only
interpreter_params = dict(
    class_names=class_names,
    split_expression=list,
    bow=False,
    char_level=True,
)

lime_explainer = UWModel(
    interpreter,
    task,
    data_type,
    explanation_configs,
    **interpreter_params
)

# ---- Anchor text interpreter -----------------------------------------------
# The names `interpreter`, `explanation_configs` and `interpreter_params` are
# deliberately rebound here; after this cell they hold the Anchor settings.
interpreter = "anchor"
explanation_configs = dict(use_proba=False, batch_size=100)
interpreter_params = dict(
    class_names=class_names,
    nlp=create_DNA_language(),
    unk_token="N",
    sep_token="",
    use_unk_distribution=True,
)

anchor_explainer = UWModel(
    interpreter,
    task,
    data_type,
    explanation_configs,
    **interpreter_params
)

### Wrapper for the interactive widget

In [None]:
class InteractiveWrapper:
    """Bind a classifier to a callback that can be driven by a widget.

    The callback explains the classifier's prediction for one sequence using
    the module-level ``lime_explainer``.
    """

    def __init__(self, classifier):
        """Store ``classifier``; it must expose ``predict`` and accept a
        ``use_labels`` attribute."""
        self.classifier = classifier

    def callback(self, sequence):
        """Run the LIME explanation of the classifier on ``sequence``.

        Returns nothing; any output comes from the explainer itself.
        """
        model = self.classifier

        # LIME: `use_labels` is switched off before predicting (presumably so
        # that `predict` yields scores rather than hard labels -- TODO confirm).
        model.use_labels = False
        lime_explainer.interpret(model.predict, sequence)

        # Anchors: currently disabled. To enable, set `use_labels = True` on
        # the classifier and call
        # `anchor_explainer.interpret(self.classifier.predict, sequence)`.

# Let's interpret

In [None]:
# Load the DeepBind model with id "D00328.003" and wrap it for interactive use.
tf_factor_id = "D00328.003"
classifier = DeepBind(tf_factor_id=tf_factor_id)
wrapper = InteractiveWrapper(classifier)

# Passing a list of strings makes `interact` offer the candidate sequences as
# a selection widget; every choice triggers `wrapper.callback`.
interact(
    wrapper.callback,
    sequence=[
        "AGGCTAGCTAGGGGCGCCC",
        "AGGCTAGCTAGGGGCGCTT",
        "AGGGTAGCTAGGGGCGCTT",
        "AGGGTAGCTGGGGGCGCTT",
        "AGGCTAGGTGGGGGCGCTT",
        "AGGCTCGGTGGGGGCGCTT",
        "AGGCTCGGTAGGGGGCGATT",
    ],
)

In [None]:
# Same workflow for the DeepBind model with id "D00794.047".
tf_factor_id = "D00794.047"
classifier = DeepBind(tf_factor_id=tf_factor_id)
wrapper = InteractiveWrapper(classifier)

# A plain string gives `interact` a free-text widget pre-filled with this
# sequence, so it can be edited by hand.
interact(
    wrapper.callback,
    sequence="TGGCCAACCAGGGGGCGCTT",
)