# Interpret DeepMEL model using EUGENe on `pbmc-granulocyte-sorted-3k_10x-Multiome`
Adam Klie (last updated: *09/20/2023*)
***
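This notebook shows how to interpret a trained DeepMEL model using EUGENe on the `pbmc-granulocyte-sorted-3k_10x-Multiome` dataset: it computes GradientShap attributions for one topic on the test sequences, plots saliency maps for the top-scoring sequences, and runs TF-MoDISco on the attributions.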

# Set-up

In [None]:
# Load necessary packages
import os
import sys
import torch
import numpy as np
import pandas as pd
import tfomics
import matplotlib.pyplot as plt
import seqdata as sd
import seqexplainer as se
from eugene import models
from eugene.models.zoo import DeepMEL
from eugene import plot as pl
sys.path.append("/Users/adamklie/Desktop/research/projects/ML4GLand/use_cases/DeepMEL/scripts")

%matplotlib inline

In [None]:
# Set-up the paths to data (TODO: change to your own paths)
# The dataset name is written once and reused to build the processed-data directory,
# so the two cannot drift apart when the notebook is pointed at another dataset.
dataset_name = "pbmc-granulocyte-sorted-3k_10x-Multiome"
input_dir = f"/cellar/users/aklie/projects/ML4GLand/use_cases/scBasset/{dataset_name}/processed"

# Load some data

In [None]:
# Open the test split, stored as "<dataset_name>.test.zarr" inside input_dir
test_zarr_path = os.path.join(input_dir, f"{dataset_name}.test.zarr")
test_sdata = sd.open_zarr(test_zarr_path)

In [None]:
# Rebuild the DeepMEL architecture, then restore the trained weights from a checkpoint
# NOTE(review): input_len, output_dim and conv_channels presumably have to match the
# settings the checkpoint was trained with -- confirm against the training run
deepmel_conv_kwargs = {"conv_channels": [1024]}
arch = DeepMEL(input_len=500, output_dim=37, conv_kwargs=deepmel_conv_kwargs)

# The checkpoint path is resolved relative to input_dir/dataset_name
checkpoint_path = os.path.join(
    input_dir,
    dataset_name,
    "multiome_cells_all_peaks.DeepMEL.revision/v0/checkpoints/epoch=15-step=16080.ckpt",
)
model = models.SequenceModule.load_from_checkpoint(checkpoint_path, arch=arch)

In [None]:
# Choose which topic to interpret
# Later cells use `topic_num-1` as the attribution target and as the prediction column,
# so this is presumably a 1-based topic number -- TODO confirm
topic_num = 16

In [None]:
# Run attributions with GradientShap
# The stored one-hot array has its last two axes swapped before being handed to the model
# (presumably (sequence, length, channel) -> (sequence, channel, length) -- TODO confirm)
attr_inputs = torch.tensor(
    test_sdata["ohe_seqs"].values.transpose(0, 2, 1),
    dtype=torch.float32,
)
# Fall back to the CPU when no GPU is visible instead of failing on a hard-coded "cuda"
attr_device = "cuda" if torch.cuda.is_available() else "cpu"
explains = se.attribute(
    model,
    inputs=attr_inputs,
    method="GradientShap",
    target=topic_num-1,  # output column of the chosen topic
    reference_type="shuffle",
    device=attr_device,
    batch_size=128
)

In [None]:
# Get the top5 predictions
# Score every test sequence, then keep the five with the highest score for the chosen topic
ohe_inputs = test_sdata["ohe_seqs"].transpose("_sequence", "_ohe", "length").values
pred_tensor = model.predict(ohe_inputs, batch_size=512)
test_preds = pred_tensor.cpu().numpy()
ascending_order = np.argsort(test_preds[:, topic_num-1])
top5_ind = ascending_order[::-1][:5]  # reverse to descending, keep the first five
test_preds[top5_ind]

In [None]:
# Get the attribution scores for the top5 predictions
top5_rows = top5_ind.tolist()  # plain Python ints used as row indices
top5_explains = explains[top5_rows]
top5_ind.shape, top5_explains.shape

In [None]:
def plot_saliency_map(explains, sort, width=13, height_per_explain=1):
    """
    Plot one attribution (saliency) logo per sequence, stacked vertically.

    Parameters
    ----------
    explains : array-like or torch.Tensor
        Attribution scores. Each ``explains[i]`` is a 2D array that is transposed so
        that its rows become positions and its four columns are labelled
        "A", "C", "G", "T".
    sort : sequence
        Row labels; ``sort[i]`` becomes the y-axis label of the i-th panel.
    width : float
        Figure width in inches.
    height_per_explain : float
        Figure height in inches allotted to each panel.

    Returns
    -------
    None
        The panels are drawn on a newly created matplotlib figure. Nothing is drawn
        when ``explains`` is empty.
    """
    # Accept torch tensors as well as numpy arrays: the per-row `.transpose([1,0])`
    # below is the numpy form and does not work on a tensor.
    if torch.is_tensor(explains):
        explains = explains.detach().cpu().numpy()
    num_plot = len(explains)
    if num_plot == 0:
        # Nothing to draw; avoid creating a zero-height figure
        return
    plt.figure(figsize=(width, num_plot*height_per_explain))
    for i in range(num_plot):
        # plt.subplot also makes this panel the current axes
        ax = plt.subplot(num_plot, 1, i+1)
        saliency_df = pd.DataFrame(np.asarray(explains[i]).transpose([1,0]), columns=["A","C","G","T"])
        saliency_df.index.name = "pos"
        tfomics.impress.plot_attribution_map(saliency_df, ax, figsize=(num_plot,1))
        # Label the panel through its own axes rather than the pyplot "current axes"
        ax.set_ylabel(sort[i])

In [None]:
# Plot the attributions of the five top-scoring sequences, labelled with their test-set indices.
# `top5_explains` is used (not `explains[:5]`) so that each panel matches its `top5_ind` label.
plot_saliency_map(top5_explains, top5_ind, width=30, height_per_explain=1.5)
plt.show()

In [None]:
# Run modisco
# `explains` may be a torch tensor or a numpy array depending on what se.attribute returned;
# convert either one to a numpy array.
if torch.is_tensor(explains):
    modisco_contribs = explains.detach().cpu().numpy()
else:
    modisco_contribs = np.asarray(explains)
# NOTE(review): the one-hot array is passed in its stored axis order, while the attributions
# were computed on inputs with the last two axes swapped -- confirm the layout se.modisco expects
pos_patterns, neg_patterns = se.modisco(
    one_hot=test_sdata["ohe_seqs"].values,  # same "ohe_seqs" key the attribution/prediction cells read
    hypothetical_contribs=modisco_contribs,
    input_dir=input_dir,
    output_name=f"DeepSTRESS_30v2_modisco_topic{topic_num}.h5",
)

In [None]:
# Get modisco logos
# The logo directory is named after the selected topic so that it matches the topic-specific
# h5 file and runs for different topics do not share one "topic1_logos" directory.
modisco_h5_path = os.path.join(input_dir, f"DeepSTRESS_30v2_modisco_topic{topic_num}.h5")
se.modisco_logos(
    modisco_h5_file=modisco_h5_path,
    input_dir=os.path.join(input_dir, f"topic{topic_num}_logos"),
)

In [None]:
# Create modisco report
# Build the paths first, then hand them to the report call
report_h5_path = os.path.join(input_dir, f"DeepSTRESS_30v2_modisco_topic{topic_num}.h5")
report_dir = os.path.join(input_dir, "report")
meme_db_path = "/cellar/users/aklie/data/shared/meme/motif_databases/HUMAN/HOCOMOCOv11_core_HUMAN_mono_meme_format.meme"
se.modisco_report(
    modisco_h5_file=report_h5_path,
    meme_db_file=meme_db_path,
    input_dir=report_dir,
    top_n_matches=2,
    trim_threshold=0.3,
    trim_min_length=3,
)

# DONE!

---