In [1]:
import pandas as pd
import numpy as np
import torch

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

# Basic import
import eugene as eu
eu.__version__

# ISM attributions methods

In [3]:
import pandas as pd
import numpy as np
import torch
from yuzu.utils import perturbations
import eugene as eu

Global seed set to 13


GPU is available: True
Number of GPUs: 1
Current GPU: 0
GPUs: Quadro RTX 5000


In [4]:
sdata = eu.datasets.random1000()
X_np = eu.pp.ohe_seqs(sdata[:10].seqs)
model = eu.models.DeepBind(input_len=100, output_dim=2)

One-hot encoding sequences:   0%|          | 0/10 [00:00<?, ?it/s]

## Perturb seq

In [5]:
def perturb_seq(seq):
    """Numpy version of perturbations"""
    n_choices, seq_len = seq.shape
    idxs = seq.argmax(axis=0)
    n = seq_len * (n_choices - 1)
    X = np.tile(seq, (n, 1))
    X = X.reshape(n, n_choices, seq_len)
    for k in range(1, n_choices):
        i = np.arange(seq_len) * (n_choices - 1) + (k - 1)
        X[i, idxs, np.arange(seq_len)] = 0
        X[i, (idxs + k) % n_choices, np.arange(seq_len)] = 1
    return X

def perturb_seq_torch(seq):
    """Torch version of perturbations"""
    n_choices, seq_len = seq.shape
    idxs = seq.argmax(axis=0)
    n = seq_len * (n_choices - 1)
    X = torch.tile(seq, (n, 1))
    X = X.reshape(n, n_choices, seq_len)
    for k in range(1, n_choices):
        i = torch.arange(seq_len) * (n_choices - 1) + (k - 1)
        X[i, idxs, torch.arange(seq_len)] = 0
        X[i, (idxs + k) % n_choices, torch.arange(seq_len)] = 1
    return X


In [6]:
#perturb_seq(X_np[0]).shape, eu.pp.decode_seqs(perturb_seq_torch(torch.from_numpy(X_np[0])).numpy())[:5]

## Peturb seqs

In [7]:
def perturb_seqs(seqs):
    n_seqs, n_choices, seq_len = seqs.shape
    idxs = seqs.argmax(axis=1)
    n = seq_len * (n_choices - 1)
    X = np.tile(seqs, (n, 1, 1))
    X = X.reshape(n, n_seqs, n_choices, seq_len).transpose(1, 0, 2, 3)
    for i in range(n_seqs):
        for k in range(1, n_choices):
            idx = np.arange(seq_len) * (n_choices - 1) + (k - 1)

            X[i, idx, idxs[i], np.arange(seq_len)] = 0
            X[i, idx, (idxs[i] + k) % n_choices, np.arange(seq_len)] = 1
    return X

def perturb_seqs_torch(seqs):
    n_seqs, n_choices, seq_len = seqs.shape
    idxs = seqs.argmax(axis=1)
    n = seq_len * (n_choices - 1)
    X = torch.tile(seqs, (n, 1, 1))
    X = X.reshape(n, n_seqs, n_choices, seq_len).permute(1, 0, 2, 3)
    for i in range(n_seqs):
        for k in range(1, n_choices):
            idx = torch.arange(seq_len) * (n_choices - 1) + (k - 1)

            X[i, idx, idxs[i], torch.arange(seq_len)] = 0
            X[i, idx, (idxs[i] + k) % n_choices, torch.arange(seq_len)] = 1
    return X

In [8]:
#perturb_seqs(X_np).shape, perturb_seqs_torch(torch.from_numpy(X_np)).shape

In [9]:
#eu.pp.decode_seqs(perturb_seqs_torch(torch.from_numpy(X_np)).numpy()[0])[:5]

## Naive ISM

In [10]:
def delta(y, reference):
    return (y - reference).sum(axis=-1)
def l1(y, reference):
    return (y - reference).abs().sum(axis=-1)
def l2(y, reference):
    return torch.sqrt(torch.square(y - reference).sum(axis=-1))

DIFF_REGISTRY = {
    "delta": delta,
    "l1": l1,
    "l2": l2,
}

In [11]:
X_0 = torch.tensor(X_np, dtype=torch.float32)
model = eu.models.DeepBind(input_len=100, output_dim=2)
target = 0
batch_size = 10
diff_type = "delta"

In [50]:
def _naive_ism(
    model, 
    inputs, 
    target=None, 
    batch_size=128, 
    diff_type="delta", 
    device="cpu"
):
    
    # Get the number of sequences, choices, and sequence length
    n_seqs, n_choices, seq_len = inputs.shape
    n = seq_len * (n_choices - 1)
    X_idxs = inputs.argmax(axis=1)

    # If target not provided aggregate over all outputs
    target = np.arange(model.output_dim) if target is None else target
    
    # Move the model to eval mode
    model = model.eval()

    # Get the reference output
    reference = model(inputs)[:, target].unsqueeze(1)
    batch_starts = np.arange(0, n, batch_size)
    device="cpu"

    # Get the change in output for each perturbation
    isms = []
    for i in range(n_seqs):
        X = perturb_seq_torch(inputs[i])
        y = []
        for start in batch_starts:
            X_ = X[start : start + batch_size]
            y_ = model(X_)[:, target].unsqueeze(1)
            y.append(y_)
            del X_
        y = torch.cat(y)
        ism = DIFF_REGISTRY[diff_type](y, reference[i])
        isms.append(ism)

    # Clean up the output to be (N, A, L)
    isms = torch.stack(isms)
    isms = isms.reshape(n_seqs, seq_len, n_choices - 1)
    j_idxs = torch.arange(n_seqs * seq_len)
    X_ism = torch.zeros(n_seqs * seq_len, n_choices, device=device)
    for i in range(1, n_choices):
        i_idxs = (X_idxs.flatten() + i) % n_choices
        X_ism[j_idxs, i_idxs] = isms[:, :, i - 1].flatten()

    X_ism = X_ism.reshape(n_seqs, seq_len, n_choices).permute(0, 2, 1)
    return X_ism


In [51]:
out = _naive_ism(model, X_0, target=None, batch_size=10)

In [41]:
X_0[0][:, :5].permute(1, 0)

tensor([[0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.]])

In [42]:
orig = model(X_0[0].unsqueeze(0))

In [46]:
new = model(perturb_seq_torch(X_0[0])[0].unsqueeze(0))

In [47]:
(new - orig).sum(axis=-1)

tensor([0.0003], grad_fn=<SumBackward1>)

In [48]:
out[0][:, :5].permute(1, 0)

tensor([[ 3.0854e-04,  2.1504e-04,  1.2280e-04,  0.0000e+00],
        [ 5.5820e-05,  5.4076e-05,  0.0000e+00,  8.4758e-05],
        [-2.6880e-04,  0.0000e+00,  8.8975e-05,  2.8175e-04],
        [-1.8227e-04, -4.0574e-04,  0.0000e+00, -1.4119e-05],
        [ 0.0000e+00, -1.6875e-04,  7.9952e-03,  2.1660e-03]],
       grad_fn=<PermuteBackward0>)

In [23]:
eu.pp.decode_seq(X_0[0])

'TGCGAGGCCATGGCTCATGAGTTCTAAGGATGCGAATAACACAAAAAGCCGCGATCTTAAACGTTCTACACTTCTAAGGTCTGCATGAGCGAACCGAAAC'

In [24]:
eu.pp.decode_seq(perturb_seq_torch(X_0[0]).detach().numpy()[6])

'TGGGAGGCCATGGCTCATGAGTTCTAAGGATGCGAATAACACAAAAAGCCGCGATCTTAAACGTTCTACACTTCTAAGGTCTGCATGAGCGAACCGAAAC'

In [52]:
from typing import Union, Callable

ISM_REGISTRY = {
    "NaiveISM": _naive_ism,
}

def _ism_attributions(
    model: torch.nn.Module, 
    inputs: Union[tuple, torch.Tensor],
    method: Union[str, Callable],
    target: int = None,
    **kwargs
):
    attrs = ISM_REGISTRY[method](model=model, inputs=inputs, **kwargs)
    return attrs

In [53]:
model = eu.models.DeepBind(input_len=100, output_dim=2)

In [54]:
explains = _ism_attributions(model=model, inputs=X_0, method="NaiveISM")

# Captum registry

In [None]:
from captum.attr import InputXGradient, DeepLift, GradientShap, DeepLiftShap
CAPTUM_REGISTRY = {
    "InputXGradient": InputXGradient,
    "DeepLift": DeepLift,
    "DeepLiftShap": DeepLiftShap,
    "GradientShap": GradientShap,

}

def _captum_attributions(
    model: torch.nn.Module,
    inputs: Union[tuple, torch.Tensor],
    method: str,
    target: int = 0,
    **kwargs
):
    """
    """
    if isinstance(inputs, np.ndarray):
        inputs = torch.tensor(inputs)
    attributor = CAPTUM_REGISTRY[method](model)
    attrs = attributor.attribute(inputs=inputs, target=target, **kwargs)
    return attrs

In [None]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ds")

In [None]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [None]:
forward_ref = _get_reference(forward, "gc")
reverse_ref = _get_reference(rev, "gc")



In [None]:
explains = _captum_attributions(model, (forward, rev), "GradientShap", target=1, baselines=(forward_ref, reverse_ref))

# Master attribution function

In [None]:
from eugene.models._base_models import BaseModel
from eugene import settings
from typing import Union, Callable

ATTRIBUTIONS_REGISTRY = {
    "NaiveISM": _ism_attributions,
    "InputXGradient": _captum_attributions,
    "DeepLift": _captum_attributions,
    "GradientShap": _captum_attributions,
    "DeepLiftShap": _captum_attributions,
}

def _model_to_device(model, device="cpu"):
    """
    """
    # Get the model to the correct device
    device = "cuda" if settings.gpus > 0 else "cpu" if device is None else device
    model.eval()
    model.to(device)
    return model

def attribute(
    model: BaseModel,
    inputs: torch.Tensor,
    method: Union[str, Callable],
    target: int = 0,
    device: str = "cpu",
    **kwargs
):

    # Put model on device
    model = _model_to_device(model, device)

    # Check kwargs for reference
    if "reference_type" in kwargs:
        ref_type = kwargs.pop("reference_type")
        kwargs["baselines"] = _get_reference(inputs, ref_type)

    # Get attributions
    attr = ATTRIBUTIONS_REGISTRY[method](
        model=model,
        inputs=inputs,
        method=method,
        target=target,
        **kwargs
    )

    # Return attributions
    return attr
     

In [None]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ss")

In [None]:
explains = attribute(model=model, inputs=forward, method="DeepLiftShap", target=0, baselines="gc")

               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


In [None]:
explains.shape

torch.Size([32, 4, 100])

# Comparison to old implementations

In [None]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ss")

In [None]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [None]:
forward_ref = _get_reference(forward, "zero")
reverse_ref = _get_reference(rev, "zero")



In [None]:
# Old
explains = eu.interpret.nn_explain(
    model=model, 
    inputs=(forward, rev), 
    saliency_type="DeepLift", 
    target=0, 
    baselines=(forward_ref, reverse_ref)
    )

In [None]:
explains_new = attribute(model=model, inputs=forward, method="DeepLift", target=0, baselines="zero")

               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


In [None]:
explains_new[0].detach().cpu().numpy()[0]

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.00145747,
        0.        ,  0.        ,  0.        , -0.        , -0.        ,
        0.        ,  0.00177175,  0.        ,  0.        , -0.00129909,
       -0.        ,  0.        , -0.        , -0.        ,  0.        ,
        0.00035432,  0.00486775, -0.        ,  0.        ,  0.00358446,
        0.        , -0.        ,  0.        ,  0.        , -0.00655105,
        0.00031941,  0.        ,  0.00604295, -0.00020719,  0.        ,
       -0.00688955,  0.        , -0.00539981, -0.00740858,  0.00150003,
       -0.00288175,  0.00425521, -0.        ,  0.        ,  0.        ,
        0.        ,  0.        , -0.        ,  0.00194549, -0.        ,
       -0.        , -0.        ,  0.        , -0.0007722 , -0.00143565,
       -0.00253362,  0.        , -0.        , -0.        ,  0.        ,
        0.        , -0.        , -0.00087424,  0.        ,  0.00

In [None]:
explains[0][0]

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.00145747,
        0.        ,  0.        ,  0.        , -0.        , -0.        ,
        0.        ,  0.00177175,  0.        ,  0.        , -0.00129909,
       -0.        ,  0.        , -0.        , -0.        ,  0.        ,
        0.00035432,  0.00486775, -0.        ,  0.        ,  0.00358446,
        0.        , -0.        ,  0.        ,  0.        , -0.00655105,
        0.00031941,  0.        ,  0.00604295, -0.00020719,  0.        ,
       -0.00688955,  0.        , -0.00539981, -0.00740858,  0.00150003,
       -0.00288175,  0.00425521, -0.        ,  0.        ,  0.        ,
        0.        ,  0.        , -0.        ,  0.00194549, -0.        ,
       -0.        , -0.        ,  0.        , -0.0007722 , -0.00143565,
       -0.00253362,  0.        , -0.        , -0.        ,  0.        ,
        0.        , -0.        , -0.00087424,  0.        ,  0.00

---

# Scratch

## Implement DeepLiftShap

In [None]:
def _deepliftshap_explain(
    model: torch.nn.Module, 
    inputs: tuple,
    ref_type: str = "zero", 
    target: int = None, 
    device: str = "cpu"
):
    """
    Compute DeepLIFT feature attribution scores using a model on a set of inputs.

    Parameters
    ----------
    model : torch.nn.Module
        PyTorch model to use for computing DeepLIFT scores.
        Can be a EUGENe trained model or one you trained with PyTorch or PL.
    inputs : tuple
        Tuple of forward and reverse complement inputs to compute DeepLIFT scores on.
        If the model is a ss model, then the scores will only be computed on the forward inputs.
    ref_type: str
        Type of reference to use for computing DeepLIFT scores. By default this is an all zeros reference,
        but we also support a dinucleotide shuffled reference and one based on GC content
    target: int
        Index of the target class to compute scores for if there are multiple outputs. If there
        is a single output, this should be None
    device: str
        Device to use for computing DeepLIFT scores.
        EUGENe will always use a gpu if available
    
    Returns
    -------
    nd.array
        Array of DeepLIFT scores
    """
    # Run checks
    if model.strand == "ds":
        raise ValueError("DeepLift currently only works for ss and ts strand models")
    
    # Get the model to the correct device
    device = "cuda" if settings.gpus > 0 else "cpu" if device is None else device
    model.eval()
    model.to(device)

    # Set up the explainer
    deepliftshap_explainer = DeepLiftShap(model)

    # Prep the inputs
    forward_inputs = inputs[0]
    reverse_inputs = inputs[1]
    if isinstance(forward_inputs, torch.Tensor):
        forward_inputs = forward_inputs.detach().cpu().numpy()
        reverse_inputs = reverse_inputs.detach().cpu().numpy() 

    # Prep the reference
    if ref_type == "zero":
        forward_ref = torch.zeros(forward_inputs.shape)
        reverse_ref = torch.zeros(reverse_inputs.shape)
    elif ref_type == "shuffle":
        forward_ref = torch.tensor(dinuc_shuffle_seqs(forward_inputs))
        if model.strand != "ss":
            reverse_ref = torch.tensor(dinuc_shuffle_seqs(reverse_inputs))
    elif ref_type == "gc":
        forward_ref = torch.tensor([0.3, 0.2, 0.2, 0.3]).expand(forward.shape[0], forward.shape[2], 4).transpose(2, 1)
        reverse_ref = forward_ref.clone()
    elif callable(ref_type):
        forward_ref = torch.tensor(ref_type(forward_inputs))
        if model.strand != "ss":
            reverse_ref = torch.tensor(ref_type(reverse_inputs))
    forward_inputs = torch.tensor(forward_inputs).to(device)
    forward_ref.to(device)

    print(forward_inputs.shape, forward_ref.shape)
    
    # Compute the attribution scores
    if model.strand == "ss":
        attrs = deepliftshap_explainer.attribute(
            forward_inputs,
            baselines=(forward_ref),
            target=target,
        )
        return attrs.to("cpu").detach().numpy()
    else:
        reverse_inputs = torch.tensor(reverse_inputs).to(device)
        reverse_ref.to(device)
        attrs = deepliftshap_explainer.attribute(
            (forward_inputs, reverse_inputs),
            baselines=(forward_ref, reverse_ref),
            target=target,
        )
        return (
            attrs[0].to("cpu").detach().numpy(),
            attrs[1].to("cpu").detach().numpy(),
        )

In [None]:
def ablate_first_base(seqs):
    """
    Change the first base of each sequence in `seqs` to A, C, G, or T.
    This is used for computing the DeepLIFT scores for the first base
    """
    seqs[:, :, 0] = [0, 0, 0, 0]
    return seqs

In [None]:
ablate_first_base(ohe_seqs)[0,:,1]

array([0, 0, 1, 0], dtype=int8)

In [None]:
sdataset = sdata.to_dataset(target_keys=None, transform_kwargs={})
sdataloader = sdataset.to_dataloader(batch_size=32, shuffle=False)
batch = next(iter(sdataloader))
forward, rev = batch[1], batch[2]

No transforms given, assuming just need to tensorize.


In [None]:
model = eu.models.DeepBind(input_len=100, output_dim=2, strand="ts", aggr="avg")

In [None]:
_deepliftshap_explain(model, (forward, rev), target=0, ref_type=ablate_first_base)

torch.Size([32, 4, 100]) torch.Size([32, 4, 100])


               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


RuntimeError: A Module MaxPool1d(kernel_size=85, stride=85, padding=0, dilation=1, ceil_mode=False) was detected that does not contain some of the input/output attributes that are required for DeepLift computations. This can occur, for example, if your module is being used more than once in the network.Please, ensure that module is being used only once in the network.

## Direct Captum usage

In [None]:
deepliftshap_explainer = DeepLiftShap(model)

In [None]:
deepliftshap_explainer.attribute(
    forward,
    baselines=(forward_ref),
    target=0,
)

               activations. The hooks and attributes will be removed
            after the attribution is finished
  after the attribution is finished"""
                be used as the gradients of the module's input tensor.
                See MaxPool1d as an example.
  module


tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0015,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0066,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0033,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0003,  0.0048,  ...,  0.0005,  0.0000,  0.0000]],

        ...,

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.00