In [1]:
import pandas as pd
import numpy as np
import torch

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

# Basic import
import eugene as eu
eu.__version__

Global seed set to 13


'0.0.6'

# Test data

In [2]:
sdata = eu.datasets.random1000()

In [3]:
eu.pp.ohe_seqs_sdata(sdata)

One-hot encoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 1000 ohe_seqs added


In [4]:
eu.pp.reverse_complement_seqs_sdata(sdata)

SeqData object modified:
	ohe_rev_seqs: None -> 1000 ohe_rev_seqs added


# Clean up dinuc shuffle

In [5]:
# dinuc_shuffle
from eugene.preprocess._utils import _string_to_char_array, _char_array_to_string, _tokens_to_one_hot, _one_hot_to_tokens

# concise versions
from eugene.preprocess._utils import _one_hot2token, _token2one_hot

In [6]:
# Seqs to use
seqs = sdata.seqs
seq = seqs[0]
ohe_seqs = sdata.ohe_seqs
ohe_seq = ohe_seqs[0]
ohe_seq_T = ohe_seq.T

In [7]:
# Keep these from dincu_shuffle.py
char_array = _string_to_char_array(seq)
re_seq = _char_array_to_string(char_array)
seq, char_array, re_seq

('TGCGAGGCCATGGCTCATGAGTTCTAAGGATGCGAATAACACAAAAAGCCGCGATCTTAAACGTTCTACACTTCTAAGGTCTGCATGAGCGAACCGAAAC',
 array([84, 71, 67, 71, 65, 71, 71, 67, 67, 65, 84, 71, 71, 67, 84, 67, 65,
        84, 71, 65, 71, 84, 84, 67, 84, 65, 65, 71, 71, 65, 84, 71, 67, 71,
        65, 65, 84, 65, 65, 67, 65, 67, 65, 65, 65, 65, 65, 71, 67, 67, 71,
        67, 71, 65, 84, 67, 84, 84, 65, 65, 65, 67, 71, 84, 84, 67, 84, 65,
        67, 65, 67, 84, 84, 67, 84, 65, 65, 71, 71, 84, 67, 84, 71, 67, 65,
        84, 71, 65, 71, 67, 71, 65, 65, 67, 67, 71, 65, 65, 65, 67],
       dtype=int8),
 'TGCGAGGCCATGGCTCATGAGTTCTAAGGATGCGAATAACACAAAAAGCCGCGATCTTAAACGTTCTACACTTCTAAGGTCTGCATGAGCGAACCGAAAC')

In [8]:
# Match these with concise versions
tokens = _one_hot_to_tokens(ohe_seq_T)
re_ohe_seq = _tokens_to_one_hot(tokens, one_hot_dim=4)
ohe_seq_T[:5], tokens, re_ohe_seq[:5]

(array([[0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0]], dtype=int8),
 array([3, 2, 1, 2, 0, 2, 2, 1, 1, 0, 3, 2, 2, 1, 3, 1, 0, 3, 2, 0, 2, 3,
        3, 1, 3, 0, 0, 2, 2, 0, 3, 2, 1, 2, 0, 0, 3, 0, 0, 1, 0, 1, 0, 0,
        0, 0, 0, 2, 1, 1, 2, 1, 2, 0, 3, 1, 3, 3, 0, 0, 0, 1, 2, 3, 3, 1,
        3, 0, 1, 0, 1, 3, 3, 1, 3, 0, 0, 2, 2, 3, 1, 3, 2, 1, 0, 3, 2, 0,
        2, 1, 2, 0, 0, 1, 1, 2, 0, 0, 0, 1]),
 array([[0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.]]))

In [9]:
# Concise versions
tokens = _one_hot2token(ohe_seq)
re_ohe_seq = _token2one_hot(tokens, vocab="DNA")
ohe_seq[:5], tokens, re_ohe_seq[:5]

(array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
         0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1,
         1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
         0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0],
        [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,
         0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
         0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],
        [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0,
         0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        

In [10]:
from eugene import settings
# Added concise
def dinuc_shuffle_seq(
    seq, 
    num_shufs=None, 
    rng=None
):
    """
    Creates shuffles of the given sequence, in which dinucleotide frequencies
    are preserved.

    If `seq` is a string, returns a list of N strings of length L, each one
    being a shuffled version of `seq`. If `seq` is a 2D np array, then the
    result is an N x L x D np array of shuffled versions of `seq`, also
    one-hot encoded. If `num_shufs` is not specified, then the first dimension
    of N will not be present (i.e. a single string will be returned, or an L x D
    array).

    Parameters
    ----------
    seq : str
        The sequence to shuffle.
    num_shufs : int, optional
        The number of shuffles to create. If None, only one shuffle is created.
    rng : np.random.RandomState, optional
        The random number generator to use. If None, a new one is created.

    Returns
    -------
    list of str or np.array
        The shuffled sequences.

    Note
    ----
    This function comes from DeepLIFT's dinuc_shuffle.py.
    """
    if type(seq) is str or type(seq) is np.str_:
        arr = _string_to_char_array(seq)
    elif type(seq) is np.ndarray and len(seq.shape) == 2:
        seq_len, one_hot_dim = seq.shape
        arr = _one_hot2token(seq)
    else:
        raise ValueError("Expected string or one-hot encoded array")
    if not rng:
        rng = np.random.RandomState(seed=settings.seed)

    # Get the set of all characters, and a mapping of which positions have which
    # characters; use `tokens`, which are integer representations of the
    # original characters
    chars, tokens = np.unique(arr, return_inverse=True)

    # For each token, get a list of indices of all the tokens that come after it
    shuf_next_inds = []
    for t in range(len(chars)):
        mask = tokens[:-1] == t  # Excluding last char
        inds = np.where(mask)[0]
        shuf_next_inds.append(inds + 1)  # Add 1 for next token

    if type(seq) is str or type(seq) is np.str_:
        all_results = []
    else:
        all_results = np.empty(
            (num_shufs if num_shufs else 1, seq_len, one_hot_dim), dtype=seq.dtype
        )

    for i in range(num_shufs if num_shufs else 1):
        # Shuffle the next indices
        for t in range(len(chars)):
            inds = np.arange(len(shuf_next_inds[t]))
            inds[:-1] = rng.permutation(len(inds) - 1)  # Keep last index same
            shuf_next_inds[t] = shuf_next_inds[t][inds]

        counters = [0] * len(chars)

        # Build the resulting array
        ind = 0
        result = np.empty_like(tokens)
        result[0] = tokens[ind]
        for j in range(1, len(tokens)):
            t = tokens[ind]
            ind = shuf_next_inds[t][counters[t]]
            counters[t] += 1
            result[j] = tokens[ind]

        if type(seq) is str or type(seq) is np.str_:
            all_results.append(_char_array_to_string(chars[result]))
        else:
            all_results[i] = _token2one_hot(chars[result])
    return all_results if num_shufs else all_results[0]


def dinuc_shuffle_seqs(seqs, num_shufs=None, rng=None):
    """
    Shuffle the sequences in `seqs` in the same way as `dinuc_shuffle_seq`.
    If `num_shufs` is not specified, then the first dimension of N will not be
    present (i.e. a single string will be returned, or an L x D array).

    Parameters
    ----------
    seqs : np.ndarray
        Array of sequences to shuffle
    num_shufs : int, optional
        Number of shuffles to create, by default None
    rng : np.random.RandomState, optional
        Random state to use for shuffling, by default None

    Returns
    -------
    np.ndarray
        Array of shuffled sequences

    Note
    -------
    This is taken from DeepLIFT
    """
    if not rng:
        rng = np.random.RandomState(seed=settings.seed)

    if type(seqs) is str or type(seqs) is np.str_:
        seqs = [seqs]

    all_results = []
    for i in range(len(seqs)):
        all_results.append(dinuc_shuffle_seq(seqs[i], num_shufs=num_shufs, rng=rng))
    return np.array(all_results)

In [11]:
eu.pp.decode_seq(ohe_seqs[0])

'TGCGAGGCCATGGCTCATGAGTTCTAAGGATGCGAATAACACAAAAAGCCGCGATCTTAAACGTTCTACACTTCTAAGGTCTGCATGAGCGAACCGAAAC'

In [12]:
eu.pp.decode_seqs(dinuc_shuffle_seqs(ohe_seqs))

Decoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

array(['TTATAAATGAAAGATGGCGCGTAAACAGACGAACTCAACCTCCCAAGCTTGCGCTGGCGCTTGTCTAATCAGGCAGGAACTTCATCGTATGAACGAGAAC',
       'CCAGTTCTGCCGGCCTACAGAACTCGCCTGTGCGTGGAGCGGAAGCTGAAGGTTATTGGCGTGTTCTGCTATCGCAGCTAGCGTCTGTTCAAAATAGACA',
       'TGTCTTCCGCCTGCAACGACCTCTACGCGTGACCTAGGTCCCGCGGACTCTGTCTCACCCTTTTATGTGGAGGGAGTTGTTTGTTTGCTGATGCCTTGGC',
       'ATAGAGTCTGAACGCATAGATTACAATCCGTCGTAGCCCGTTACAGTGATTGCTTCGCAGGTTGTCCTATTCTTCACTCCGACACGTTCATCCGTGAGGG',
       'CCCTGACCGCTCCTCCATTCGTTTTCGAAACAAGTGTTGCGCTCCGCAGACGAGGTTACAGTTTACATGTCGCGATTCACATGCACACAGTAAGTGCCAG',
       'GGGCTGTGCACGTATTTAACTACTTTGTTGTGCCACTTTACTGGTGACGTAGAAGCACGCGTATGATGAGGTAGCCCGTAACCCCCGTCCGTCCAATCCC',
       'GTCCACGCTCAATAAAACACTGTCCACGAGCCCACTCCCAGCAATATCGCAGCATCGAAACACCTCTTGTTCTCTAGGGTTCATTCGGACCCCCTAGGTT',
       'CAGGGAAAATTCACCAGAATAAGAGATCCGAAGAGGTTAGGCCATGAGTCACACTGATGCCGAAAATCTATAATCTGGGGCCTTCAATTGTAATAGCCAT',
       'TCGGGAATGAAAAGGCCGGATTGCGCGATAATTGTAGCGCGTTAGAGCTGTTCCCGCAGTTCTACTTAGACAATCGGGTTATTACCCAGTGGTTCGCTTA',
 

# Reference registry

In [53]:
from typing import Union, Callable

def ablate_first_base(seqs):
    """
    Change the first base of each sequence in `seqs` to A, C, G, or T.
    This is used for computing the DeepLIFT scores for the first base
    """
    seqs[:, :, 0] = [0, 0, 0, 0]
    return seqs

def zero_ref_seqs(seqs):
    seqs = torch.zeros(seqs.shape)
    return seqs

def gc_ref_seqs(seqs, dists=None):
    if dists is None:
        dists = torch.Tensor([0.3, 0.2, 0.2, 0.3])
        seqs = dists.expand(seqs.shape[0], seqs.shape[2], 4).transpose(2, 1)
    else:
        seqs = dists.expand(seqs.shape[0], seqs.shape[2], 4).transpose(2, 1)
    return seqs

REFERENCE_REGISTRY = {
    "zero": zero_ref_seqs,
    "gc": gc_ref_seqs,
    "shuffle": dinuc_shuffle_seqs,
}

def _get_reference(
    seqs: np.ndarray,
    method: Union[str, Callable],
):
    """
    Returns torch.Tensor reference
    """
    if isinstance(method, str):
        if method not in REFERENCE_REGISTRY:
            raise ValueError(f"Reference method {method} not in {list(REFERENCE_REGISTRY.keys())}")
        if isinstance(seqs, tuple):
            return tuple([torch.Tensor(REFERENCE_REGISTRY[method](seqs[i])) for i in range(len(seqs))])
        else:
            return torch.tensor(REFERENCE_REGISTRY[method](seqs))
    elif callable(method):
        if isinstance(seqs, tuple):
            return tuple([torch.Tensor(method(seqs[i])) for i in range(len(seqs))])
        else:
            return torch.Tensor(method(seqs))
    else:
        raise ValueError(f"Reference method {method} not in {list(REFERENCE_REGISTRY.keys())}")

In [55]:
refs = _get_reference(ohe_seqs, "shuffle")

In [57]:
refs = _get_reference((ohe_seqs, ohe_seqs), "shuffle")

# ISM attributions methods

In [60]:
from typing import Union, Callable
from eugene.interpret._utils import _naive_ism

ISM_REGISTRY = {
    "NaiveISM": _naive_ism,
}

def _ism_attributions(
    model: torch.nn.Module, 
    inputs: tuple, 
    method: Union[str, Callable],
    target: int = 0,
    device: str = "cpu", 
    **kwargs
):
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    if model.strand != "ss":
        raise ValueError("ISM currrently only works for single strand models, but we are working on this!")
    attrs = ISM_REGISTRY[method](
        model=model,
        X_0=inputs,  # Rename to inputs eventually
        device=device,
        **kwargs
    )
    return attrs

In [61]:
model = eu.models.DeepBind(input_len=100, output_dim=2)

In [62]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [64]:
explains = _ism_attributions(model=model, inputs=forward, method="NaiveISM")

# Captum registry

In [65]:
from captum.attr import InputXGradient, DeepLift, GradientShap, DeepLiftShap
CAPTUM_REGISTRY = {
    "InputXGradient": InputXGradient,
    "DeepLift": DeepLift,
    "DeepLiftShap": DeepLiftShap,
    "GradientShap": GradientShap,

}

def _captum_attributions(
    model: torch.nn.Module,
    inputs: tuple,
    method: str,
    target: int = 0,
    **kwargs
):
    """
    """
    if isinstance(inputs, np.ndarray):
        inputs = torch.tensor(inputs)
    attributor = CAPTUM_REGISTRY[method](model)
    attrs = attributor.attribute(inputs=inputs, target=target, **kwargs)
    return attrs

In [66]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ds")

In [67]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [68]:
forward_ref = _get_reference(forward, "gc")
reverse_ref = _get_reference(rev, "gc")



In [69]:
explains = _captum_attributions(model, (forward, rev), "GradientShap", target=1, baselines=(forward_ref, reverse_ref))

# Master attribution function

In [70]:
from eugene.models._base_models import BaseModel
from eugene import settings
from typing import Union, Callable

ATTRIBUTIONS_REGISTRY = {
    "NaiveISM": _ism_attributions,
    "InputXGradient": _captum_attributions,
    "DeepLift": _captum_attributions,
    "GradientShap": _captum_attributions,
    "DeepLiftShap": _captum_attributions,
}

def _model_to_device(model, device="cpu"):
    """
    """
    # Get the model to the correct device
    device = "cuda" if settings.gpus > 0 else "cpu" if device is None else device
    model.eval()
    model.to(device)
    return model

def attribute(
    model: BaseModel,
    inputs: torch.Tensor,
    method: Union[str, Callable],
    target: int = 0,
    device: str = "cpu",
    **kwargs
):

    # Put model on device
    model = _model_to_device(model, device)

    # Check kwargs for reference
    if "reference_type" in kwargs:
        ref_type = kwargs.pop("reference_type")
        kwargs["baselines"] = _get_reference(inputs, ref_type)

    # Get attributions
    attr = ATTRIBUTIONS_REGISTRY[method](
        model=model,
        inputs=inputs,
        method=method,
        target=target,
        **kwargs
    )

    # Return attributions
    return attr
     

In [71]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ss")

In [72]:
explains = attribute(model=model, inputs=forward, method="DeepLiftShap", target=0, baselines="gc")

               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


In [27]:
explains.shape

torch.Size([32, 4, 100])

# Comparison to old implementations

In [73]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ss")

In [74]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [75]:
forward_ref = _get_reference(forward, "zero")
reverse_ref = _get_reference(rev, "zero")



In [76]:
# Old
explains = eu.interpret.nn_explain(
    model=model, 
    inputs=(forward, rev), 
    saliency_type="DeepLift", 
    target=0, 
    baselines=(forward_ref, reverse_ref)
    )

In [78]:
explains_new = attribute(model=model, inputs=forward, method="DeepLift", target=0, baselines="zero")

               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


In [83]:
explains_new[0].detach().cpu().numpy()[0]

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.00145747,
        0.        ,  0.        ,  0.        , -0.        , -0.        ,
        0.        ,  0.00177175,  0.        ,  0.        , -0.00129909,
       -0.        ,  0.        , -0.        , -0.        ,  0.        ,
        0.00035432,  0.00486775, -0.        ,  0.        ,  0.00358446,
        0.        , -0.        ,  0.        ,  0.        , -0.00655105,
        0.00031941,  0.        ,  0.00604295, -0.00020719,  0.        ,
       -0.00688955,  0.        , -0.00539981, -0.00740858,  0.00150003,
       -0.00288175,  0.00425521, -0.        ,  0.        ,  0.        ,
        0.        ,  0.        , -0.        ,  0.00194549, -0.        ,
       -0.        , -0.        ,  0.        , -0.0007722 , -0.00143565,
       -0.00253362,  0.        , -0.        , -0.        ,  0.        ,
        0.        , -0.        , -0.00087424,  0.        ,  0.00

In [84]:
explains[0][0]

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.00145747,
        0.        ,  0.        ,  0.        , -0.        , -0.        ,
        0.        ,  0.00177175,  0.        ,  0.        , -0.00129909,
       -0.        ,  0.        , -0.        , -0.        ,  0.        ,
        0.00035432,  0.00486775, -0.        ,  0.        ,  0.00358446,
        0.        , -0.        ,  0.        ,  0.        , -0.00655105,
        0.00031941,  0.        ,  0.00604295, -0.00020719,  0.        ,
       -0.00688955,  0.        , -0.00539981, -0.00740858,  0.00150003,
       -0.00288175,  0.00425521, -0.        ,  0.        ,  0.        ,
        0.        ,  0.        , -0.        ,  0.00194549, -0.        ,
       -0.        , -0.        ,  0.        , -0.0007722 , -0.00143565,
       -0.00253362,  0.        , -0.        , -0.        ,  0.        ,
        0.        , -0.        , -0.00087424,  0.        ,  0.00

# Implement DeepLiftShap

In [143]:
def _deepliftshap_explain(
    model: torch.nn.Module, 
    inputs: tuple,
    ref_type: str = "zero", 
    target: int = None, 
    device: str = "cpu"
):
    """
    Compute DeepLIFT feature attribution scores using a model on a set of inputs.

    Parameters
    ----------
    model : torch.nn.Module
        PyTorch model to use for computing DeepLIFT scores.
        Can be a EUGENe trained model or one you trained with PyTorch or PL.
    inputs : tuple
        Tuple of forward and reverse complement inputs to compute DeepLIFT scores on.
        If the model is a ss model, then the scores will only be computed on the forward inputs.
    ref_type: str
        Type of reference to use for computing DeepLIFT scores. By default this is an all zeros reference,
        but we also support a dinucleotide shuffled reference and one based on GC content
    target: int
        Index of the target class to compute scores for if there are multiple outputs. If there
        is a single output, this should be None
    device: str
        Device to use for computing DeepLIFT scores.
        EUGENe will always use a gpu if available
    
    Returns
    -------
    nd.array
        Array of DeepLIFT scores
    """
    # Run checks
    if model.strand == "ds":
        raise ValueError("DeepLift currently only works for ss and ts strand models")
    
    # Get the model to the correct device
    device = "cuda" if settings.gpus > 0 else "cpu" if device is None else device
    model.eval()
    model.to(device)

    # Set up the explainer
    deepliftshap_explainer = DeepLiftShap(model)

    # Prep the inputs
    forward_inputs = inputs[0]
    reverse_inputs = inputs[1]
    if isinstance(forward_inputs, torch.Tensor):
        forward_inputs = forward_inputs.detach().cpu().numpy()
        reverse_inputs = reverse_inputs.detach().cpu().numpy() 

    # Prep the reference
    if ref_type == "zero":
        forward_ref = torch.zeros(forward_inputs.shape)
        reverse_ref = torch.zeros(reverse_inputs.shape)
    elif ref_type == "shuffle":
        forward_ref = torch.tensor(dinuc_shuffle_seqs(forward_inputs))
        if model.strand != "ss":
            reverse_ref = torch.tensor(dinuc_shuffle_seqs(reverse_inputs))
    elif ref_type == "gc":
        forward_ref = torch.tensor([0.3, 0.2, 0.2, 0.3]).expand(forward.shape[0], forward.shape[2], 4).transpose(2, 1)
        reverse_ref = forward_ref.clone()
    elif callable(ref_type):
        forward_ref = torch.tensor(ref_type(forward_inputs))
        if model.strand != "ss":
            reverse_ref = torch.tensor(ref_type(reverse_inputs))
    forward_inputs = torch.tensor(forward_inputs).to(device)
    forward_ref.to(device)

    print(forward_inputs.shape, forward_ref.shape)
    
    # Compute the attribution scores
    if model.strand == "ss":
        attrs = deepliftshap_explainer.attribute(
            forward_inputs,
            baselines=(forward_ref),
            target=target,
        )
        return attrs.to("cpu").detach().numpy()
    else:
        reverse_inputs = torch.tensor(reverse_inputs).to(device)
        reverse_ref.to(device)
        attrs = deepliftshap_explainer.attribute(
            (forward_inputs, reverse_inputs),
            baselines=(forward_ref, reverse_ref),
            target=target,
        )
        return (
            attrs[0].to("cpu").detach().numpy(),
            attrs[1].to("cpu").detach().numpy(),
        )

In [131]:
def ablate_first_base(seqs):
    """
    Change the first base of each sequence in `seqs` to A, C, G, or T.
    This is used for computing the DeepLIFT scores for the first base
    """
    seqs[:, :, 0] = [0, 0, 0, 0]
    return seqs

In [132]:
ablate_first_base(ohe_seqs)[0,:,1]

array([0, 0, 1, 0], dtype=int8)

In [133]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [138]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ts", aggr="avg")

In [139]:
_deepliftshap_explain(model, (forward, rev), target=0, ref_type=ablate_first_base)

torch.Size([32, 4, 100]) torch.Size([32, 4, 100])


               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


RuntimeError: A Module MaxPool1d(kernel_size=85, stride=85, padding=0, dilation=1, ceil_mode=False) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network.Please, ensure that module is being used only once in the network.

## Direct Captum usage

In [101]:
deepliftshap_explainer = DeepLiftShap(model)

In [102]:
deepliftshap_explainer.attribute(
    forward,
    baselines=(forward_ref),
    target=0,
)

               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0015,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0066,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0033,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0003,  0.0048,  ...,  0.0005,  0.0000,  0.0000]],

        ...,

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.00

---