# Jores et al 2021 Interpretation
**Authorship:**
Adam Klie, *08/12/2022*
***
**Description:**
Notebook to interpret the best trained models on the Jores et al (2021) dataset.
***

In [None]:
# Load the IPython autoreload extension only if the extension manager does not
# already list it as loaded, then switch autoreload to mode 2.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import os
import glob
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

# Use font type 42 for both PDF and PostScript output so that exported
# figures can be edited in Illustrator
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

In [None]:
# Point the EUGENe settings at the directories used for the jores21 analysis
_jores21_dirs = {
    "dataset_dir": "/cellar/users/aklie/data/eugene/jores21",
    "output_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/output/jores21",
    "logging_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/logs/jores21",
    "config_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/configs/jores21",
    "figure_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/figures/jores21",
}
for _setting, _directory in _jores21_dirs.items():
    setattr(eu.settings, _setting, _directory)

# Load in the `leaf`, `proto` and `combined` test `SeqData`s 

In [None]:
# Load the test-set SeqData objects (with saved predictions) for each dataset
leaf_test_path = os.path.join(eu.settings.output_dir, "leaf", "leaf_test_predictions.h5sd")
proto_test_path = os.path.join(eu.settings.output_dir, "proto", "proto_test_predictions.h5sd")
combined_test_path = os.path.join(eu.settings.output_dir, "combined", "combined_test_predictions.h5sd")

sdata_leaf = eu.dl.read(leaf_test_path)
sdata_proto = eu.dl.read(proto_test_path)
sdata_combined = eu.dl.read(combined_test_path)

# Display the three objects as the cell output
sdata_leaf, sdata_proto, sdata_combined

# Load the best models

In [None]:
def _first_checkpoint(architecture, trial):
    """Return the first checkpoint file found for a given training trial.

    Searches ``<eu.settings.logging_dir>/<architecture>/<trial>/checkpoints/*``
    and returns the first path reported by ``glob``.

    Parameters
    ----------
    architecture : str
        Name of the model directory under the logging directory.
    trial : str
        Name of the trial directory under ``architecture``.

    Raises
    ------
    FileNotFoundError
        If the checkpoints directory contains no files, so the failure names
        the searched pattern instead of surfacing as a bare ``IndexError``.
    """
    pattern = os.path.join(eu.settings.logging_dir, architecture, trial, "checkpoints", "*")
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No checkpoint files found matching {pattern!r}")
    return matches[0]


# Best trial per dataset: the leaf model is a Hybrid, the others are Jores21CNN
leaf_model_file = _first_checkpoint("ssHybrid", "leaf_trial_5")
leaf_model = eu.models.Hybrid.load_from_checkpoint(leaf_model_file)
proto_model_file = _first_checkpoint("Jores21CNN", "proto_trial_2")
proto_model = eu.models.Jores21CNN.load_from_checkpoint(proto_model_file)
combined_model_file = _first_checkpoint("Jores21CNN", "combined_trial_3")
combined_model = eu.models.Jores21CNN.load_from_checkpoint(combined_model_file)

In [None]:
# Choose which model to interpret; the cells below read `name`, `model`
# and `sdata`
name, model, sdata = "leaf", leaf_model, sdata_leaf
model, sdata

# Feature attribution

In [None]:
# Compute feature attributions for all sequences with each of three saliency
# methods
saliency_methods = ["InputXGradient", "DeepLift", "GradientSHAP"]
for saliency_method in saliency_methods:
    eu.interpret.feature_attribution_sdata(model=model, sdata=sdata, method=saliency_method)

In [None]:
# Pick the 5 sequences with the highest predictions and plot their DeepLift
# attribution tracks.
# NOTE(review): the prediction column below looks specific to the leaf
# ssHybrid model -- confirm it is updated when `name`/`model` change.
prediction_key = "ssHybrid_trial_5_enrichment_predictions"
ranked_predictions = sdata[prediction_key].sort_values(ascending=False)
top5 = ranked_predictions.iloc[:5].index

track_pdf = os.path.join(eu.settings.figure_dir, f"{name}_best_model_feature_attr.pdf")
eu.pl.multiseq_track(
    sdata,
    seq_ids=top5,
    uns_keys="DeepLift_imps",
    ylabels="DeepLift",
    height=3,
    width=70,
    save=track_pdf,
)

# Filter viz

In [None]:
# Build position frequency matrices from the model's filters, following the
# method described in Minnoye et al. 2020
eu.interpret.generate_pfms_sdata(model=model, sdata=sdata, method="Minnoye20")

In [None]:
# Plot a single filter, selected by its id
eu.pl.filter_viz(sdata, filter_id=1)

In [None]:
# Plot the filters 32 at a time (8 rows x 4 columns per figure) over 8
# figures, saving each figure to its own PDF
for start_filter in range(0, 8 * 32, 32):
    end_filter = start_filter + 32
    print(f"Plotting and saving filters {start_filter+1}-{end_filter}")
    page_filter_ids = list(sdata.uns["pfms"].keys())[start_filter:end_filter]
    page_titles = [f"filter {filter_idx}" for filter_idx in range(start_filter, end_filter)]
    page_pdf = os.path.join(
        eu.settings.figure_dir,
        f"{name}_best_model_filters{start_filter+1}-{end_filter}_viz.pdf",
    )
    eu.pl.multifilter_viz(
        sdata,
        filter_ids=page_filter_ids,
        num_rows=8,
        num_cols=4,
        titles=page_titles,
        save=page_pdf,
    )

In [None]:
# Save all the filter pfms from above as meme format for submission to TomTom
eu.dl.motif.filters_to_meme_sdata(
    sdata,
    filter_ids=list(sdata.uns["pfms"].keys()),
    output_dir=os.path.join(eu.settings.output_dir),
    file_name=f"{name}_best_model_filters.meme"
)

In [None]:
# Save the SeqData, now holding predictions and interpretation results
interpretations_path = os.path.join(
    eu.settings.output_dir, f"{name}_test_predictions_and_interpretations.h5sd"
)
sdata.write_h5sd(interpretations_path)

## In silico evolution 

In [None]:
# Read the promoters that were evolved in the published paper and one-hot
# encode them
evolution_tsv = os.path.join(eu.settings.dataset_dir, "promoters_for_evolution.tsv")
sdata_evolve = eu.dl.read_csv(evolution_tsv, seq_col="sequence", name_col="name")
eu.pp.ohe_seqs_sdata(sdata_evolve)
sdata_evolve

In [None]:
# Run 10 rounds of in silico evolution on the sequences with the chosen model
eu.interpret.evolve_seqs_sdata(model, sdata_evolve, rounds=10)

In [None]:
# Compare the score distributions of the original sequences and after
# rounds 3, 5 and 10 of evolution
evolution_score_keys = ["original_score", "evolved_3_score", "evolved_5_score", "evolved_10_score"]
evolution_pdf = os.path.join(eu.settings.figure_dir, f"{name}_best_model_evolution_summary.pdf")
eu.pl.violinplot(
    sdata_evolve,
    groupby=evolution_score_keys,
    xlabel="Evolution Round",
    ylabel="Score",
    color="lightblue",
    save=evolution_pdf,
)

In [None]:
# Save the evolved sequences inside the per-dataset subdirectory
# (`<output_dir>/<name>/`): that is where the feature-implant section reads
# this file back from, and where the test predictions for `name` are stored.
evolved_path = os.path.join(eu.settings.output_dir, name, f"{name}_evolved_sequences.h5sd")
sdata_evolve.write_h5sd(evolved_path)

# In silico feature implant

In [None]:
# Reload the evolved sequences saved for the chosen dataset
evolved_seqs_path = os.path.join(eu.settings.output_dir, name, f"{name}_evolved_sequences.h5sd")
sdata_evolve = eu.dl.read_h5sd(evolved_seqs_path)

In [None]:
# Prep feature from meme file
meme = eu.dl.motif.MinimalMEME(path=os.path.join(eu.settings.dataset_dir, "CPEs.meme"))
motif = meme.motifs["TATA"]
feat_name = motif.name
pfm = motif.pfm
zero_pfm = np.zeros(pfm.shape)
rand_pfm = eu.pp.ohe_seq(eu.utils.random_seq(pfm.shape[0])).transpose()
consensus = motif.consensus
shuffled_pfm = eu.pp.ohe_seq(eu.pp.dinuc_shuffle_seq(consensus)).transpose()
feat_name, pfm, consensus

In [None]:
# Slide each feature across the sequences and store the resulting scores
# under its own `seqsm` key. In order: the TATA motif, a random sequence,
# an all-zero one-hot matrix, and the shuffled TATA consensus.
implant_runs = [
    (pfm, f"slide_{feat_name}"),
    (rand_pfm, f"slide_random"),
    (zero_pfm, f"slide_zero"),
    (shuffled_pfm, f"slide_shuffled"),
]
for implant_feature, implant_key in implant_runs:
    eu.interpret.feature_implant_seqs_sdata(
        model=model,
        sdata=sdata_evolve,
        feature=implant_feature,
        seqsm_key=implant_key,
        encoding="onehot",
        onehot=True,
    )

In [None]:
# Mean change in score from implanting the motif, as a percentage of the
# mean original score.
mean_original = sdata_evolve["original_score"].mean()

# Look up the implant scores with the same key they were stored under
# (f"slide_{feat_name}") rather than a hard-coded motif name.
implant_scores = sdata_evolve.seqsm[f"slide_{feat_name}"]

# Subtract each sequence's original score from its implant scores, then
# average along axis 1 to get one mean change per sequence.
original_column = np.expand_dims(sdata_evolve["original_score"], axis=1)
avg_increase = np.mean(np.subtract(implant_scores, original_column), axis=1)

# Ratio of the overall mean change to the mean original score, in percent
avg_increase.mean() / mean_original * 100

In [None]:
# Line plot comparing the motif implant with the three control implants
implant_keys = [f"slide_{feat_name}", "slide_shuffled", "slide_zero", "slide_random"]
implant_pdf = os.path.join(eu.settings.figure_dir, f"{name}_best_model_feature_implant_TATA.pdf")
eu.pl.feature_implant_plot(sdata_evolve, seqsm_keys=implant_keys, save=implant_pdf)

In [None]:
# Save the evolved sequences together with the feature-implant results
implant_results_path = os.path.join(
    eu.settings.output_dir, f"{name}_evolved_sequences_with_TATA_implant.h5sd"
)
sdata_evolve.write_h5sd(implant_results_path)

---