# Jores et al. (2021) Interpretation
**Authorship:**
Adam Klie (last updated: *06/08/2023*)
***
**Description:**
Notebook to interpret the best-performing trained models on the Jores et al. (2021) dataset.
***

In [None]:
# Load IPython's autoreload extension once per kernel (the guard skips
# %load_ext when the extension is already loaded), then set autoreload mode 2
# so imported modules are reloaded before cells execute and source edits take
# effect without restarting the kernel.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

In [None]:
# General imports
import os
import sys
import glob
import torch
import numpy as np
import torch.nn.functional as F

# EUGENe imports and settings
import eugene as eu
from eugene import preprocess as pp
from eugene import models
from eugene import interpret
from eugene import plot as pl
from eugene import settings
# Point EUGENe's global settings at the Jores21 revision data, configs,
# outputs, logs and figures (all but the dataset live under the project root).
_project_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper"
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
settings.config_dir = f"{_project_dir}/configs/jores21"
settings.output_dir = f"{_project_dir}/output/revision/jores21/"
settings.logging_dir = f"{_project_dir}/logs/revision/jores21"
settings.figure_dir = f"{_project_dir}/figures/revision/jores21"

# EUGENe packages
import seqdata as sd
import motifdata as md
import seqpro as sp

# For illustrator editing
import matplotlib
import matplotlib.pyplot as plt
# Embed fonts as Type 42 (TrueType) in PDF/PS output so text stays editable in Illustrator
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Record the interpreter and key library versions used for this run
print(f"Python version: {sys.version}")
for _label, _module in (("NumPy", np), ("Eugene", eu), ("SeqData", sd), ("PyTorch", torch)):
    print(f"{_label} version: {_module.__version__}")

# Load in the `leaf`, `proto` and `combined` test `SeqData`s 

In [None]:
# Open the preprocessed test SeqData (with model predictions) for each training set
def _open_test_predictions(subset):
    """Open ``<output_dir>/<subset>/<subset>_test_predictions.zarr`` as a SeqData."""
    return sd.open_zarr(os.path.join(settings.output_dir, subset, f"{subset}_test_predictions.zarr"))


sdata_leaf = _open_test_predictions("leaf")
sdata_proto = _open_test_predictions("proto")
sdata_combined = _open_test_predictions("combined")

# Load the best model

In [None]:
# Load the best model for each training set (leaf, proto, combined); one is chosen below.
def _first_checkpoint(arch_name, trial_dir):
    """Return the first checkpoint file (sorted by path) in a trial's log directory.

    ``glob.glob`` returns matches in filesystem-dependent order, so the matches
    are sorted to make the choice deterministic when several checkpoints exist.

    Raises:
        FileNotFoundError: if the trial's ``checkpoints`` directory has no files.
    """
    pattern = os.path.join(settings.logging_dir, arch_name, trial_dir, "checkpoints", "*")
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoints found matching {pattern!r}")
    return matches[0]


leaf_model_file = _first_checkpoint("hybrid", "leaf_trial_3")
leaf_model_arch = models.load_config(config_path="hybrid.yaml")
leaf_model = models.SequenceModule.load_from_checkpoint(leaf_model_file, arch=leaf_model_arch.arch)
proto_model_file = _first_checkpoint("jores21_cnn", "proto_trial_3")
proto_model_arch = models.load_config(config_path="jores21_cnn.yaml")
proto_model = models.SequenceModule.load_from_checkpoint(proto_model_file, arch=proto_model_arch.arch)
combined_model_file = _first_checkpoint("deepstarr", "combined_trial_5")
combined_model_arch = models.load_config(config_path="deepstarr.yaml")
combined_model = models.SequenceModule.load_from_checkpoint(combined_model_file, arch=combined_model_arch.arch)

In [None]:
# Choose which model to interpret: the training-set name, architecture and
# trial number, plus the matching loaded model and test SeqData.
name, arch, trial = "leaf", "hybrid", 3
model, sdata = leaf_model, sdata_leaf

# Feature attribution

In [None]:
# Compute DeepLift attributions for the test sequences against an all-zeros
# reference. Before reaching the model, each "ohe_seq" batch is cast to
# float32 and its last two axes are swapped.
method = "DeepLift"


def _ohe_to_model_input(batch):
    return torch.tensor(batch, dtype=torch.float32).permute(0, 2, 1)


interpret.attribute_sdata(
    model,
    sdata,
    method=method,
    batch_size=128,
    reference_type="zero",
    transforms={"ohe_seq": _ohe_to_model_input},
)

In [None]:
# Plot attribution tracks for the 5 test sequences with the highest predicted enrichment.
pred_key = f"{arch}_trial_{trial}_enrichment_predictions"
# argsort is ascending, so reverse it to get indices of the largest predictions first
top5_idx = np.argsort(sdata[pred_key].values)[::-1][:5]
ids = sdata["id"].values[top5_idx]
pl.multiseq_track(
    sdata,
    seq_ids=ids,
    attrs_keys=f"{method}_attrs",
    ylabs=method,
    height=3,
    width=70,
    save=os.path.join(settings.figure_dir, name, f"{name}_best_model_feature_attr.pdf"),
)

# Filter viz

In [None]:
# Configure how filter PFMs are generated for the chosen architecture.
if arch == "jores21_cnn":
    # For jores21_cnn the layer activations are computed here by hand:
    # convolve the one-hot sequences with the "arch.biconv" kernels and add the biases.
    # Use the GPU when present, otherwise fall back to CPU instead of crashing.
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    layer_name = "arch.biconv"
    # Reorder dims to (_sequence, _ohe, length) so channels precede length, as conv1d expects
    seqs = sdata["ohe_seq"].transpose("_sequence", "_ohe", "length").to_numpy()
    seqs_torch = torch.tensor(seqs, dtype=torch.float32).to(model.device)
    kernel = models.get_layer(model, f"{layer_name}.kernels")[0].to(model.device)
    bias = models.get_layer(model, f"{layer_name}.biases")[0].to(model.device)
    activations = F.conv1d(seqs_torch, kernel, stride=1, padding="same")
    # Move the channel axis last so the bias (presumably one value per filter) broadcasts, then move it back
    activations = torch.add(activations.transpose(1, 2), bias).transpose(1, 2)
    activations = activations.detach().cpu().numpy()
    padding = 6  # == (kernel_size - 1) // 2, consistent with padding="same" above
    transforms = None  # inputs are passed precomputed via `seqs`/`activations` below
    kernel_size = 13
    num_filters = 256
else:
    # Other architectures pass activations=None; presumably generate_pfms_sdata
    # computes them from `layer_name` itself — confirm.
    if arch == "deepstarr":
        kernel_size = 7
        padding = 3
        layer_name = "arch.conv1d_tower.layers.2"
        num_filters = 246
    else:
        padding = 0
        kernel_size = 13
        layer_name = "arch.conv1d_tower.layers.1"
        num_filters = 256
    activations = None
    seqs = None
    transforms = {"ohe_seq": lambda x: torch.tensor(x, dtype=torch.float32).permute(0, 2, 1)}
kernel_size, num_filters, padding, transforms, layer_name

In [None]:
# Generate position frequency matrices (PFMs) for the filters of `layer_name`,
# using 100 seqlets per filter and the architecture-specific settings chosen above.
pfm_kwargs = dict(
    seq_key="ohe_seq",
    layer_name=layer_name,
    kernel_size=kernel_size,
    activations=activations,
    seqs=seqs,
    num_filters=num_filters,
    padding=padding,
    num_seqlets=100,
    transforms=transforms,
)
interpret.generate_pfms_sdata(model, sdata, **pfm_kwargs)

In [None]:
# Plot the PFM of a single filter (index 179) from the selected layer
filter_to_plot = 179
pl.filter_viz(sdata, filter_num=filter_to_plot, pfms_key=f"{layer_name}_pfms")

In [None]:
# Visualize filter PFMs in pages of up to 32 (4 columns per row) and save each page.
# Pages are derived from `num_filters` rather than a fixed 8 x 32 = 256, so a layer
# with a different filter count (deepstarr uses 246) never requests nonexistent filters.
filters_per_page = 32
num_cols = 4
for start_filter in range(0, num_filters, filters_per_page):
    end_filter = min(start_filter + filters_per_page, num_filters)
    # Rows needed for this page: 8 for a full page, fewer for a short final page
    num_rows = (end_filter - start_filter + num_cols - 1) // num_cols
    print(f"Plotting and saving filters {start_filter+1}-{end_filter}")
    pl.multifilter_viz(
        sdata,
        filter_nums=range(start_filter, end_filter),
        pfms_key=f"{layer_name}_pfms",
        num_rows=num_rows,
        num_cols=num_cols,
        titles=[f"filter {j}" for j in range(start_filter, end_filter)],
        save=os.path.join(settings.figure_dir, name, f"{name}_best_model_filters{start_filter+1}-{end_filter}_viz.pdf"),
    )

In [None]:
# Export all filter PFMs in MEME format (e.g. for submission to TomTom)
meme_dir = os.path.join(settings.output_dir, name)
meme_file = f"{name}_best_model_filters.meme"
interpret.filters_to_meme_sdata(sdata, filters_key=f"{layer_name}_pfms", output_dir=meme_dir, filename=meme_file)

In [None]:
# Write the test SeqData (after attribution and PFM generation) back to zarr
interp_zarr_path = os.path.join(settings.output_dir, name, f"{name}_test_predictions_and_interpretations.zarr")
sd.to_zarr(sdata, interp_zarr_path, load_first=True, mode="w")

# *in silico* evolution 

In [None]:
# Load the promoter sequences evolved in the published paper into a zarr-backed
# SeqData, then one-hot encode them.
evolve_tsv = os.path.join(settings.dataset_dir, "promoters_for_evolution.tsv")
evolve_zarr = os.path.join(settings.dataset_dir, "promoters_for_evolution.zarr")
sdata_evolve = sd.read_table(
    name="seq",
    tables=evolve_tsv,
    out=evolve_zarr,
    seq_col="sequence",
    fixed_length=False,
    batch_size=310,
    overwrite=True,
)
pp.ohe_seqs_sdata(sdata_evolve)

In [None]:
# Run 10 rounds of in silico evolution on these sequences with the chosen model
n_evolution_rounds = 10
interpret.evolve_seqs_sdata(model, sdata_evolve, rounds=n_evolution_rounds)

In [None]:
# Plot the distribution of scores for the original sequences and after evolution rounds 3, 5 and 10
ax = pl.violinplot(
    sdata_evolve,
    groupby=["original_score", "evolved_3_score", "evolved_5_score", "evolved_10_score"],
    xlabel="Evolution Round",
    ylabel="Score",
    color="lightblue",
    return_axes=True,
)
ax.set_ylim(-3.5, 13)
# Save the figure that owns `ax` explicitly rather than pyplot's implicit "current" figure
ax.figure.savefig(os.path.join(settings.figure_dir, name, f"{name}_best_model_evolution_summary.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Persist the evolved sequences and their scores (alongside the original sequences)
evolved_zarr_path = os.path.join(settings.output_dir, name, f"jores21_{name}_evolved_sequences.zarr")
sd.to_zarr(sdata_evolve, evolved_zarr_path, load_first=True, mode="w")

# Positional GIA

In [None]:
# Re-read the evolved sequences and their scores. The file name must match the one
# written by the evolution section ("jores21_" prefix); no unprefixed
# f"{name}_evolved_sequences.zarr" is written by this notebook.
sdata_evolve = sd.open_zarr(os.path.join(settings.output_dir, name, f"jores21_{name}_evolved_sequences.zarr"))

In [None]:
# Read the TATA motif from CPEs.meme and one-hot encode its consensus sequence
motif_set = md.read_meme(os.path.join(settings.dataset_dir, "CPEs.meme"))
motif = motif_set["TATA"]
feat_name, pfm, consensus = motif.name, motif.pfm, motif.consensus
dna = sp.alphabets.DNA
consensus_ohe = sp.ohe(consensus, alphabet=dna)

# Baseline features to implant for comparison: all zeros, a random sequence of
# length pfm.shape[0], and a 1-mer shuffle of the consensus
zero_pfm = np.zeros(pfm.shape)
motif_len = pfm.shape[0]
rand_pfm = sp.ohe(sp.random_seq(motif_len), alphabet=dna)
shuffled_consensus = sp.k_shuffle(consensus, k=1).tobytes().decode()
shuffled_pfm = sp.ohe(shuffled_consensus, alphabet=dna)
zero_pfm.shape, rand_pfm.shape, shuffled_pfm.shape

In [None]:
# Slide each feature across every evolved sequence with positional GIA; the scores
# for feature "<feature_name>" are stored under "slide_<feature_name>".
gia_features = [
    (consensus_ohe, feat_name),  # the TATA motif consensus
    (rand_pfm, "random"),        # random sequence baseline
    (zero_pfm, "zero"),          # all-zeros baseline
    (shuffled_pfm, "shuffled"),  # 1-mer shuffled TATA consensus baseline
]
for gia_feature, gia_feature_name in gia_features:
    interpret.positional_gia_sdata(
        model,
        sdata_evolve,
        seq_key="ohe_seq",
        id_key="name",
        feature=gia_feature,
        feature_name=gia_feature_name,
        encoding="onehot",
        store_key=f"slide_{gia_feature_name}",
    )

In [None]:
# Average percentage change in score from implanting the motif: the mean change
# over implant positions per sequence, averaged over sequences, relative to the
# mean original score.
mean_original = sdata_evolve["original_score"].values.mean()
# Use the same key the GIA run stored under, instead of hard-coding "slide_TATA"
implant_scores = sdata_evolve[f"slide_{feat_name}"].values  # assumes (sequence, position) — TODO confirm
avg_increase = (implant_scores - sdata_evolve["original_score"].values[:, None]).mean(axis=1)
avg_increase.mean() / mean_original * 100

In [None]:
# Line plot of positional GIA scores: motif implant vs. shuffled, zero and random baselines
gia_keys = [f"slide_{feat_name}", "slide_shuffled", "slide_zero", "slide_random"]
gia_fig_path = os.path.join(settings.figure_dir, name, f"{name}_best_model_feature_implant_TATA.pdf")
pl.positional_gia_plot(sdata_evolve, keys=gia_keys, id_key="name", save=gia_fig_path)

In [None]:
# Save the evolved sequences again, now including the TATA implant scores
implant_zarr_path = os.path.join(settings.output_dir, name, f"{name}_evolved_sequences_with_TATA_implant.zarr")
sd.to_zarr(sdata_evolve, implant_zarr_path, load_first=True, mode="w")

# DONE!

---

# Scratch

In [None]:
# Reopen the saved test predictions + interpretations; the file name matches the
# one written in the feature-attribution section (no "jores21_" prefix).
sd.open_zarr(os.path.join(settings.output_dir, name, f"{name}_test_predictions_and_interpretations.zarr"))

In [None]:
# Reopen the saved evolved sequences for inspection
scratch_evolved_path = os.path.join(settings.output_dir, name, f"jores21_{name}_evolved_sequences.zarr")
sd.open_zarr(scratch_evolved_path)