# Kopp et al 2021 Interpretation
**Authorship:**
Adam Klie (last updated: *06/10/2023*)
***
**Description:**
Notebook to interpret the best trained models on the Kopp et al (2021) dataset.
***

In [None]:
# General imports
import os
import sys
import glob
import torch
import numpy as np
import xarray as xr
import torch.nn.functional as F

# EUGENe imports and settings
import eugene as eu
from eugene import preprocess as pp
from eugene import models
from eugene import interpret
from eugene import plot as pl
from eugene import settings
# Project paths registered on EUGENe's global settings object; later cells build
# their input and output locations by joining onto these.
# dataset_dir: read below to locate the motif MEME file
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/kopp21"
# config_dir: presumably where `models.load_config` resolves the model YAML -- TODO confirm
settings.config_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/configs/kopp21"
# output_dir: the test predictions are read from here and results are written back under it
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/revision/kopp21/"
# logging_dir: searched below for the trained model checkpoints
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/kopp21"
# figure_dir: destination for the PDF figures saved below
settings.figure_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/figures/revision/kopp21"

# EUGENe packages
import seqdata as sd
import motifdata as md
import seqpro as sp

# Write fonts as type 42 (TrueType) in PDF/PS output so figure text can be
# edited in Illustrator
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Report the interpreter and library versions used for this run
print(f"Python version: {sys.version}")
_versioned_modules = (
    ("NumPy", np),
    ("Xarray", xr),
    ("Eugene", eu),
    ("SeqData", sd),
    ("MotifData", md),
    ("SeqPro", sp),
    ("PyTorch", torch),
)
for _label, _module in _versioned_modules:
    print(f"{_label} version: {_module.__version__}")

# Load in the test `SeqData`

In [None]:
# Open the zarr store with the test-set predictions and call `.load()` on it
predictions_path = os.path.join(settings.output_dir, "test_predictions_all.zarr")
sdata_test = sd.open_zarr(predictions_path).load()

In [None]:
# Build a "chrom:chromStart-chromEnd" identifier for every sequence
start_str = sdata_test["chromStart"].astype(str)
end_str = sdata_test["chromEnd"].astype(str)
sdata_test["id"] = sdata_test["chrom"] + ":" + start_str + "-" + end_str

# Load the best model

In [None]:
# Select which training trial and which architecture to interpret; both are
# used below to build checkpoint paths, variable names and output file names
trial = 1
model_type = "dsfcn"

In [None]:
# Load up the model from the checkpoint
checkpoint_dir = os.path.join(settings.logging_dir, model_type, f"trial_{trial}", "checkpoints")
# glob returns files in arbitrary order; sort so the chosen checkpoint is
# reproducible when the directory holds more than one
checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*")))
if not checkpoint_files:
    # Fail with a message naming the directory instead of a bare IndexError
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir}")
model_file = checkpoint_files[0]
model_arch = models.load_config(config_path=f"{model_type}.yaml")
model = models.SequenceModule.load_from_checkpoint(model_file, arch=model_arch.arch)

## Feature attribution

In [None]:
# Attribute the forward-strand sequences with GradientShap, using the "zero"
# reference type (an all-0s baseline)
method = "GradientShap"


def _fwd_to_tensor(x):
    """Convert a one-hot batch to a float32 tensor with axes 1 and 2 swapped."""
    return torch.tensor(x, dtype=torch.float32).permute(0, 2, 1)


interpret.attribute_sdata(
    model,
    sdata_test,
    method=method,
    batch_size=128,
    reference_type="zero",
    transforms={"ohe_seq": _fwd_to_tensor},
)

In [None]:
# Reverse-complement the one-hot sequences and attribute them with the same
# method; the "_rc" suffix keeps these results apart from the forward ones
rc_var = "rc_ohe_seq"
sdata_test[rc_var] = sp.reverse_complement(
    sdata_test["ohe_seq"], sp.alphabets.DNA, length_axis=1, ohe_axis=2
)


def _rc_to_tensor(x):
    """Convert a one-hot batch to a float32 tensor with axes 1 and 2 swapped."""
    return torch.tensor(x, dtype=torch.float32).permute(0, 2, 1)


interpret.attribute_sdata(
    model,
    sdata_test,
    method=method,
    batch_size=128,
    reference_type="zero",
    seq_var=rc_var,
    transforms={rc_var: _rc_to_tensor},
    suffix="_rc",
)

In [None]:
# Rank the test sequences by predicted score and keep the ten highest
preds_var = f"{model_type}_trial_{trial}_target_predictions"
# Index labels of the ten highest predictions, taken from a sorted pandas Series
top10 = sdata_test[preds_var].to_series().sort_values(ascending=False).iloc[:10].index
# Integer positions of the ten highest predictions, taken from a NumPy argsort
top10_idx = np.argsort(sdata_test[preds_var].values)[::-1][:10]
# Sequence identifiers as a plain array, for lookup in the plotting cells
ids = sdata_test["id"].values

In [None]:
# Forward-strand attribution track for the highest-scoring test sequence
top_seq_id = ids[top10[0]]
pl.seq_track(
    sdata_test,
    seq_id=top_seq_id,
    attrs_var="GradientShap_attrs",
    ylab="GradientShap Forward",
    figsize=(18, 3),
)

In [None]:
# Same sequence, but showing the attributions computed on the reverse complement
rc_track_kwargs = dict(
    seq_id=ids[top10[0]],
    attrs_var="GradientShap_attrs_rc",
    ylab="GradientShap Reverse",
    figsize=(18, 3),
)
pl.seq_track(sdata_test, **rc_track_kwargs)

## Filter viz (only for convnet)

In [None]:
#TODO
# Pick the layer whose filters are visualized for the current architecture
if model_type == "kopp21_cnn":
    layer_name = "arch.conv"
elif "ds" in model_type:
    layer_name = "arch.conv1d_tower.layers.0"
else:
    layer_name = "arch.conv1d_tower.layers.1"

if model_type == "kopp21_cnn" or "ds" in model_type:
    # For these models the layer activations are computed here and handed to
    # `generate_pfms_sdata` directly, so no input transform is needed
    # Fall back to the CPU when no GPU is available instead of crashing
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    layer = models.get_layer(model, layer_name)
    seqs = sdata_test["ohe_seq"].transpose("_sequence", "_ohe", "length").to_numpy()
    seqs_torch = torch.tensor(seqs, dtype=torch.float32).to(model.device)
    # Inference only: skip building the autograd graph to save memory
    with torch.no_grad():
        activations = F.relu(layer(seqs_torch)).cpu().numpy()
    transforms = None
else:
    # Otherwise pass only the input transform and no precomputed activations
    transforms = {"ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32).permute(0, 2, 1)}
    activations = None
    seqs = None

# Settings for the PFM generation
kernel_size = 11
num_filters = 10
num_seqlets = 100
layer_name, transforms

In [None]:
# Generate pfms from filters, using the kernel_size / num_filters / num_seqlets
# defined in the setup cell rather than repeating the literal values here
interpret.generate_pfms_sdata(
    model,
    sdata_test,
    seq_var="ohe_seq",
    layer_name=layer_name,
    activations=activations,
    seqs=seqs,
    kernel_size=kernel_size,
    num_filters=num_filters,
    num_seqlets=num_seqlets,
    transforms=transforms,
)

In [None]:
# Plot all the filters for the current model in a grid that follows `num_filters`
n_cols = 5
# Ceiling division so a partially filled last row still gets a grid row
n_rows = -(-num_filters // n_cols)
pl.multifilter_viz(
    sdata_test,
    filter_nums=range(num_filters),
    pfms_var=f"{layer_name}_pfms",
    num_rows=n_rows,
    num_cols=n_cols,
    figsize=(10, 3),
    titles=[f"filter {i}" for i in range(num_filters)],
)

In [None]:
# Save all the filter pfms from above as meme format for submission to TomTom
# The dimension names embed the filter count and kernel size, so build them
# from the variables that were used to generate the PFMs
filters_dim = f"_{layer_name}_{num_filters}_filters"
kernel_dim = f"_{layer_name}_{kernel_size}_kernel_size"
interpret.filters_to_meme_sdata(
    sdata_test,
    filters_var=f"{layer_name}_pfms",
    axis_order=(filters_dim, "_ohe", kernel_dim),
    output_dir=os.path.join(settings.output_dir, model_type),
    filename=f"best_model_{model_type}_filters.meme",
)

In [None]:
# Write the test data, now including the interpretation results, to a zarr
# store under the model's output folder (mode="w")
interpretations_path = os.path.join(
    settings.output_dir, model_type, "test_predictions_and_interpretations.zarr"
)
sd.to_zarr(sdata_test, interpretations_path, mode="w")

## In silico interpretation

In [None]:
# Read the MA0491.1 motif from its MEME file and derive the feature to implant
motif_id = "MA0491.1"
meme = md.read_meme(os.path.join(settings.dataset_dir, f"{motif_id}.meme"))
motif = meme.motifs[motif_id]
feat_name, pfm, consensus = motif.name, motif.pfm, motif.consensus
# One-hot encode the consensus sequence so it can be implanted into sequences
consensus_ohe = sp.ohe(consensus, sp.alphabets.DNA)
pfm.shape, consensus_ohe.shape

In [None]:
# Control features to implant alongside the motif: all zeros, a random
# sequence, and a shuffled version of the consensus
dna = sp.alphabets.DNA
zero_pfm = np.zeros(pfm.shape)
random_feature_seq = sp.random_seq(pfm.shape[0])
rand_pfm = sp.ohe(random_feature_seq, alphabet=dna)
shuffled_consensus = sp.k_shuffle(consensus, k=1).tobytes().decode()
shuffled_pfm = sp.ohe(shuffled_consensus, alphabet=dna)
zero_pfm.shape, rand_pfm.shape, shuffled_pfm.shape

In [None]:
# Generate random sequences to implant into, one-hot encode them and wrap
# them in a Dataset with ("_sequence", "_ohe", "length") dims
random_ohe = sp.ohe(sp.random_seqs(10, 500), alphabet=sp.alphabets.DNA)
random_seqs = random_ohe.transpose(0, 2, 1)
ohe_array = xr.DataArray(random_seqs, dims=("_sequence", "_ohe", "length"))
sdata_implant = xr.Dataset({"ohe_seq": ohe_array})
# Add a "name" variable holding unique ids for the sequences
pp.make_unique_ids_sdata(sdata_implant, id_var="name")

In [None]:
# Slide each feature across the random sequences and store the scores:
# the JUND motif consensus first, then the random, all-zero and shuffled controls
implant_runs = [
    (consensus_ohe, feat_name, f"slide_{feat_name}"),
    (rand_pfm, "random", "slide_random"),
    (zero_pfm, "zero", "slide_zero"),
    (shuffled_pfm, "shuffled", "slide_shuffled"),
]
for implant_feature, implant_name, implant_store_var in implant_runs:
    interpret.positional_gia_sdata(
        model,
        sdata_implant,
        seq_var="ohe_seq",
        id_var="name",
        feature=implant_feature,
        feature_name=implant_name,
        encoding="onehot",
        store_var=implant_store_var,
    )

In [None]:
# Score the unmodified random sequences to get a baseline for the implant scores
implant_inputs = torch.tensor(sdata_implant["ohe_seq"].values, dtype=torch.float32)
implant_inputs = implant_inputs.to(model.device)
orig_preds = model.predict(implant_inputs).detach().cpu().numpy()

In [None]:
#TODO
# Show the range of the baseline predictions
orig_preds.min(), orig_preds.max()

In [None]:
# Plot the implanted scores across positions
# Make sure the per-model figure folder exists before asking for a saved PDF
fig_dir = os.path.join(settings.figure_dir, model_type)
os.makedirs(fig_dir, exist_ok=True)
ax = pl.positional_gia_plot(
    sdata_implant,
    vars=[f"slide_{feat_name}", "slide_shuffled", "slide_zero", "slide_random"],
    id_var="name",
    save=os.path.join(fig_dir, f"best_{model_type}_model_feature_implant_jund.pdf"),
    ylim=(-4.5, 4.2),
    return_axes=True,
)

In [None]:
# Plot the original scores for the sequences
plt.boxplot(orig_preds)
plt.ylim(-4.5, 4.2)
# `plt.savefig` does not create missing folders, so create the per-model
# figure folder first
fig_dir = os.path.join(settings.figure_dir, model_type)
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, f"best_{model_type}_model_random_seq_scores.pdf"))

In [None]:
# Write the random sequences together with their stored implant scores to a
# zarr store under the model's output folder (mode="w")
implant_path = os.path.join(settings.output_dir, model_type, f"implant_JUND_{model_type}.zarr")
sd.to_zarr(sdata_implant, implant_path, mode="w")

# DONE!

---

# Scratch