# Jores et al 2021 Plotting 
**Authorship:**
Adam Klie (last updated: *06/08/2023*)
***
**Description:**
Notebook to generate plots for the Jores et al (2021) dataset that are not included in the other notebooks.
 - Summary table of benchmarking results for each model type
 - Cleaner seq track plots for top sequences
 - TomTom filter annotation analysis
 - Loss and metric plots
***

In [None]:
# General imports (standard library)
import os
import sys
from copy import deepcopy
from itertools import groupby
from operator import itemgetter

# Scientific stack
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

# EUGENe imports; point its settings at this dataset's data/output/log/figure locations
import eugene as eu
from eugene import plot as pl
from eugene import settings
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/jores21"
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/revision/jores21"
settings.logging_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/jores21"
settings.figure_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/figures/revision/jores21"

# Companion EUGENe packages (sequence data and motifs)
import seqdata as sd
import motifdata as md

# Plotting; Type 42 fonts keep text in saved PDF/PS figures editable in Illustrator
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Record package versions for reproducibility
for _name, _version in [
    ("Python", sys.version),
    ("NumPy", np.__version__),
    ("Eugene", eu.__version__),
    ("SeqData", sd.__version__),
]:
    print(f"{_name} version: {_version}")

# Generate performance figures

## Leaf models

In [None]:
# Load the leaf test-set predictions: the flat TSV table and the SeqData zarr store (loaded into memory)
leaf_dir = os.path.join(settings.output_dir, "leaf")
leaf_predictions = pd.read_csv(os.path.join(leaf_dir, "leaf_test_predictions.tsv"), sep="\t", index_col=0)
sdata_leaf = sd.open_zarr(os.path.join(leaf_dir, "leaf_test_predictions.zarr")).load()

In [None]:
# Keep only the data variables holding model predictions
preds_vars = [var for var in sdata_leaf.keys() if "predictions" in var]

# Sort prediction variables by architecture in the order
# cnn, hybrid, deepstarr, jores21_cnn (ties broken by variable name).
# jores21_cnn is matched by substring because its own name contains "_",
# so splitting on "_" would only give "jores21".
order = ["cnn", "hybrid", "deepstarr", "jores21_cnn"]
pred_models = [
    order.index("jores21_cnn" if "jores21" in var else var.split("_")[0])
    for var in preds_vars
]
preds_vars = [var for _, var in sorted(zip(pred_models, preds_vars))]

# Architecture label for each (sorted) prediction variable, used to group models in plots
model_groups = {name: name for name in order}
groups = [
    "jores21_cnn" if "jores21" in var else model_groups[var.split("_")[0]]
    for var in preds_vars
]

In [None]:
# Box plot of test-set R2 per architecture; returns the per-model scores
boxplot_path = os.path.join(settings.figure_dir, "leaf", "leaf_performance_boxplot.pdf")
leaf_model_scores = pl.performance_summary(
    sdata_leaf,
    target_var="enrichment",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    metrics=["r2"],
    add_swarm=False,
    figsize=(6, 6),
    save=boxplot_path,
)

In [None]:
# Compare the per-model R2 distributions of each architecture pair with a
# Mann-Whitney U test, then apply Benjamini-Hochberg FDR correction
compare_df = leaf_model_scores.pivot(columns="prediction_groups", values="r2")
comparisons = [
    ("hybrid", "jores21_cnn"),
    ("hybrid", "cnn"),
    ("cnn", "jores21_cnn"),
    ("deepstarr", "jores21_cnn"),
    ("deepstarr", "cnn"),
    ("deepstarr", "hybrid"),
]
pairwise_tests = np.array([
    mannwhitneyu(compare_df[first].dropna(), compare_df[second].dropna()).pvalue
    for first, second in comparisons
])
multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")

In [None]:
# Summarize performance across models for multiple metrics
leaf_model_scores = pl.performance_summary(
    sdata_leaf,
    target_var="enrichment",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    metrics=["r2", "mse", "pearson", "spearman", "kendall"],
    add_swarm=False,
    figsize=(6, 6),
)

# Rank individual models by test-set R2, best first
r2_sorted = leaf_model_scores["r2"].sort_values(ascending=False)

# Plot the ranking on its own figure. Series.plot without `ax` draws onto the
# current axes, which may still belong to the figure performance_summary just
# drew, so the bars could be overlaid on (and saved with) that plot.
fig, ax = plt.subplots()
r2_sorted.plot(kind="bar", ylabel="R2", ax=ax)
fig.tight_layout()  # keep the long model-name tick labels inside the saved page
fig.savefig(os.path.join(settings.figure_dir, "leaf", "leaf_performance_summary.pdf"))

In [None]:
# Persist the per-model metric table next to the predictions
summary_path = os.path.join(settings.output_dir, "leaf", "leaf_performance_summary.tsv")
leaf_model_scores.to_csv(summary_path, sep="\t")

In [None]:
# Observed vs. predicted enrichment for the selected model, grouped by species ("sp")
# NOTE(review): this takes position 4 of the descending test-set R2 ranking
# (the 5th-best model), although the outputs are named "best_model"; the proto
# section uses position 0 and the combined section position 1. Presumably chosen
# to match a model selected on other criteria — confirm.
selected_rank = 4
best_preds = r2_sorted.index[selected_rank]
scatter_by_sp_path = os.path.join(settings.figure_dir, "leaf", "leaf_best_model_performance_scatter_by_sp.pdf")
ax = pl.performance_scatter(
    sdata_leaf,
    target_vars="enrichment",
    prediction_vars=best_preds,
    alpha=0.5,
    groupby="sp",
    figsize=(8, 8),
    rasterized=True,
    save=scatter_by_sp_path,
)

In [None]:
# Same model, all species pooled into one scatter (Figure 2 panel)
scatter_path = os.path.join(settings.figure_dir, "leaf", "leaf_best_model_performance_scatter.pdf")
ax = pl.performance_scatter(
    sdata_leaf,
    target_vars="enrichment",
    prediction_vars=best_preds,
    alpha=0.5,
    figsize=(4, 4),
    rasterized=True,
    save=scatter_path,
)

## Proto models

In [None]:
# Load the proto test-set predictions: the flat TSV table and the SeqData zarr store (loaded into memory)
proto_dir = os.path.join(settings.output_dir, "proto")
proto_predictions = pd.read_csv(os.path.join(proto_dir, "proto_test_predictions.tsv"), sep="\t", index_col=0)
sdata_proto = sd.open_zarr(os.path.join(proto_dir, "proto_test_predictions.zarr")).load()

In [None]:
# Keep only the data variables holding model predictions
preds_vars = [var for var in sdata_proto.keys() if "predictions" in var]

# Sort prediction variables by architecture in the order
# cnn, hybrid, deepstarr, jores21_cnn (ties broken by variable name).
# jores21_cnn is matched by substring because its own name contains "_".
order = ["cnn", "hybrid", "deepstarr", "jores21_cnn"]
pred_models = [
    order.index("jores21_cnn" if "jores21" in var else var.split("_")[0])
    for var in preds_vars
]
preds_vars = [var for _, var in sorted(zip(pred_models, preds_vars))]

# Architecture label for each (sorted) prediction variable, used to group models in plots
model_groups = {name: name for name in order}
groups = [
    "jores21_cnn" if "jores21" in var else model_groups[var.split("_")[0]]
    for var in preds_vars
]

In [None]:
# Box plot of test-set R2 per architecture; returns the per-model scores
boxplot_path = os.path.join(settings.figure_dir, "proto", "proto_performance_boxplot.pdf")
proto_model_scores = pl.performance_summary(
    sdata_proto,
    target_var="enrichment",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    metrics=["r2"],
    add_swarm=False,
    figsize=(6, 6),
    save=boxplot_path,
)

In [None]:
# Compare the per-model R2 distributions of each architecture pair with a
# Mann-Whitney U test, then apply Benjamini-Hochberg FDR correction
compare_df = proto_model_scores.pivot(columns="prediction_groups", values="r2")
comparisons = [
    ("hybrid", "jores21_cnn"),
    ("hybrid", "cnn"),
    ("cnn", "jores21_cnn"),
    ("deepstarr", "jores21_cnn"),
    ("deepstarr", "cnn"),
    ("deepstarr", "hybrid"),
]
pairwise_tests = np.array([
    mannwhitneyu(compare_df[first].dropna(), compare_df[second].dropna()).pvalue
    for first, second in comparisons
])
multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")

In [None]:
# Summarize performance across models for multiple metrics
proto_model_scores = pl.performance_summary(
    sdata_proto,
    target_var="enrichment",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    metrics=["r2", "mse", "pearson", "spearman", "kendall"],
    add_swarm=False,
    figsize=(6, 6),
)

# Rank individual models by test-set R2, best first
r2_sorted = proto_model_scores["r2"].sort_values(ascending=False)

# Plot the ranking on its own figure. Series.plot without `ax` draws onto the
# current axes, which may still belong to the figure performance_summary just
# drew, so the bars could be overlaid on (and saved with) that plot.
fig, ax = plt.subplots()
r2_sorted.plot(kind="bar", ylabel="R2", ax=ax)
fig.tight_layout()  # keep the long model-name tick labels inside the saved page
fig.savefig(os.path.join(settings.figure_dir, "proto", "proto_performance_summary.pdf"))

In [None]:
# Persist the per-model metric table next to the predictions
summary_path = os.path.join(settings.output_dir, "proto", "proto_performance_summary.tsv")
proto_model_scores.to_csv(summary_path, sep="\t")

In [None]:
# Observed vs. predicted enrichment for the top test-set R2 model, grouped by species ("sp")
best_preds = r2_sorted.index[0]
scatter_by_sp_path = os.path.join(settings.figure_dir, "proto", "proto_best_model_performance_scatter_by_sp.pdf")
ax = pl.performance_scatter(
    sdata_proto,
    target_vars="enrichment",
    prediction_vars=best_preds,
    alpha=0.5,
    groupby="sp",
    figsize=(8, 8),
    rasterized=True,
    save=scatter_by_sp_path,
)

In [None]:
# Same model, all species pooled into one scatter (Figure 2 panel)
scatter_path = os.path.join(settings.figure_dir, "proto", "proto_best_model_performance_scatter.pdf")
ax = pl.performance_scatter(
    sdata_proto,
    target_vars="enrichment",
    prediction_vars=best_preds,
    alpha=0.5,
    figsize=(4, 4),
    rasterized=True,
    save=scatter_path,
)

## Combined models

In [None]:
# Load the combined test-set predictions: the flat TSV table and the SeqData zarr store (loaded into memory)
combined_dir = os.path.join(settings.output_dir, "combined")
combined_predictions = pd.read_csv(os.path.join(combined_dir, "combined_test_predictions.tsv"), sep="\t", index_col=0)
sdata_combined = sd.open_zarr(os.path.join(combined_dir, "combined_test_predictions.zarr")).load()

In [None]:
# Keep only the data variables holding model predictions
preds_vars = [var for var in sdata_combined.keys() if "predictions" in var]

# Sort prediction variables by architecture in the order
# cnn, hybrid, jores21_cnn, deepstarr (ties broken by variable name).
# jores21_cnn is matched by substring because its own name contains "_".
# NOTE(review): the leaf and proto sections put deepstarr before jores21_cnn,
# so group order in the combined figures differs from theirs — confirm intended.
order = ["cnn", "hybrid", "jores21_cnn", "deepstarr"]
pred_models = [
    order.index("jores21_cnn" if "jores21" in var else var.split("_")[0])
    for var in preds_vars
]
preds_vars = [var for _, var in sorted(zip(pred_models, preds_vars))]

# Architecture label for each (sorted) prediction variable, used to group models in plots
model_groups = {name: name for name in order}
groups = [
    "jores21_cnn" if "jores21" in var else model_groups[var.split("_")[0]]
    for var in preds_vars
]

In [None]:
# Box plot of test-set R2 per architecture; returns the per-model scores
boxplot_path = os.path.join(settings.figure_dir, "combined", "combined_performance_boxplot.pdf")
combined_model_scores = pl.performance_summary(
    sdata_combined,
    target_var="enrichment",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    metrics=["r2"],
    add_swarm=False,
    figsize=(6, 6),
    save=boxplot_path,
)

In [None]:
# Compare the per-model R2 distributions of each architecture pair with a
# Mann-Whitney U test, then apply Benjamini-Hochberg FDR correction
compare_df = combined_model_scores.pivot(columns="prediction_groups", values="r2")
comparisons = [
    ("hybrid", "jores21_cnn"),
    ("hybrid", "cnn"),
    ("cnn", "jores21_cnn"),
    ("deepstarr", "jores21_cnn"),
    ("deepstarr", "cnn"),
    ("deepstarr", "hybrid"),
]
pairwise_tests = np.array([
    mannwhitneyu(compare_df[first].dropna(), compare_df[second].dropna()).pvalue
    for first, second in comparisons
])
multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")

In [None]:
# Summarize performance across models for multiple metrics
combined_model_scores = pl.performance_summary(
    sdata_combined,
    target_var="enrichment",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    metrics=["r2", "mse", "pearson", "spearman", "kendall"],
    add_swarm=False,
    figsize=(6, 6),
)

# Rank individual models by test-set R2, best first
r2_sorted = combined_model_scores["r2"].sort_values(ascending=False)

# Plot the ranking on its own figure. Series.plot without `ax` draws onto the
# current axes, which may still belong to the figure performance_summary just
# drew, so the bars could be overlaid on (and saved with) that plot.
fig, ax = plt.subplots()
r2_sorted.plot(kind="bar", ylabel="R2", ax=ax)
fig.tight_layout()  # keep the long model-name tick labels inside the saved page
fig.savefig(os.path.join(settings.figure_dir, "combined", "combined_performance_summary.pdf"))

In [None]:
# Persist the per-model metric table next to the predictions
summary_path = os.path.join(settings.output_dir, "combined", "combined_performance_summary.tsv")
combined_model_scores.to_csv(summary_path, sep="\t")

In [None]:
# Observed vs. predicted enrichment for the selected model, grouped by species ("sp")
# NOTE(review): position 1 of the descending test-set R2 ranking is the
# 2nd-best model, although the outputs are named "best_model" (leaf uses
# position 4, proto position 0). Presumably chosen to match a model selected
# on other criteria — confirm.
selected_rank = 1
best_preds = r2_sorted.index[selected_rank]
scatter_by_sp_path = os.path.join(settings.figure_dir, "combined", "combined_best_model_performance_scatter_by_sp.pdf")
ax = pl.performance_scatter(
    sdata_combined,
    target_vars="enrichment",
    prediction_vars=best_preds,
    alpha=0.5,
    groupby="sp",
    figsize=(8, 8),
    rasterized=True,
    save=scatter_by_sp_path,
)

In [None]:
# Same model, all species pooled into one scatter (Figure 2 panel)
scatter_path = os.path.join(settings.figure_dir, "combined", "combined_best_model_performance_scatter.pdf")
ax = pl.performance_scatter(
    sdata_combined,
    target_vars="enrichment",
    prediction_vars=best_preds,
    alpha=0.5,
    figsize=(4, 4),
    rasterized=True,
    save=scatter_path,
)

# Performance summary table

In [None]:
# Tag each system's score table in place with its system name, then stack them into one table
for system_name, scores in (
    ("leaf", leaf_model_scores),
    ("proto", proto_model_scores),
    ("combined", combined_model_scores),
):
    scores["model"] = system_name
merged_model_scores = pd.concat([leaf_model_scores, proto_model_scores, combined_model_scores])
merged_model_scores.to_csv(os.path.join(settings.output_dir, "merged_performance_summary.tsv"), sep="\t")

# Cleaner seq track logos

In [None]:
# Model whose interpretations are plotted below: system, trial number, architecture
model, trial, model_type = "leaf", 5, "hybrid"

In [None]:
# Open (lazily) the zarr store holding test predictions plus attribution scores for the chosen system
interp_path = os.path.join(settings.output_dir, model, f"{model}_test_predictions_and_interpretations.zarr")
sdata_interpretations = sd.open_zarr(interp_path)

In [None]:
# The five highest-predicted sequences for the selected model:
# top5 holds their index labels, top5_idx their integer positions (highest first)
pred_var = f"{model_type}_trial_{trial}_enrichment_predictions"
top5 = sdata_interpretations[pred_var].to_series().sort_values(ascending=False).head(5).index
top5_idx = np.argsort(sdata_interpretations[pred_var].values)[::-1][:5]

In [None]:
# Per-position DeepLift attribution of each top sequence, summed over axis 1
# (presumably the nucleotide/channel axis of a (seq, channel, position) array —
# TODO confirm), then thresholded to mark positions the model relies on.
high_attr = np.sum(sdata_interpretations["DeepLift_attrs"].values[top5_idx], axis=1) > 0.01
seq_num, seq_pos = np.where(high_attr)

# Collect runs of consecutive high-attribution positions longer than 3, allowing
# several runs per sequence. Iterate over every top sequence (not only those
# with hits) so ranges[i] always lines up with top5_idx[i]; a sequence with no
# hits gets an empty list instead of being silently skipped.
ranges = []
for i in range(len(top5_idx)):
    ranges_i = []
    # Consecutive positions share the same (enumeration index - position) key
    for _, g in groupby(enumerate(seq_pos[seq_num == i]), lambda x: x[0] - x[1]):
        group = list(map(itemgetter(1), g))
        if len(group) > 3:
            ranges_i.append((group[0], group[-1]))
    ranges.append(ranges_i)
ranges

In [None]:
# Draw a DeepLift attribution track for each top sequence, highlighting its
# high-attribution runs. Iterate over the actual ids instead of a hard-coded
# range(5) so a shorter top list cannot raise IndexError.
ids = sdata_interpretations["id"].values[top5_idx]
for i, seq_id in enumerate(ids):
    pl.seq_track(
        sdata_interpretations,
        seq_id=seq_id,
        attrs_var="DeepLift_attrs",
        ylab="DeepLift",
        highlights=ranges[i],
        figsize=(8, 1),
        save=os.path.join(settings.figure_dir, model, f"{model}_best_model_feature_attr_{i+1}.pdf"),
    )

# TomTom annotation analysis

In [None]:
# Model whose filters are annotated below: system, trial number, architecture
model, trial, model_type = "proto", 3, "jores21_cnn"

In [None]:
# Read the TomTom results of the filters against the CPE and the TF-cluster motif sets, then stack them
tomtom_dir = os.path.join(settings.output_dir, model)
tomtom_cpe, tomtom_tf = (
    pd.read_csv(os.path.join(tomtom_dir, f"{model}_best_model_filters_tomtom_{motif_set}.tsv"), sep="\t")
    for motif_set in ("CPE", "TF")
)
tomtom_df = pd.concat([tomtom_cpe, tomtom_tf], axis=0)

In [None]:
# Persist the stacked CPE + TF TomTom table (index included)
tomtom_path = os.path.join(settings.output_dir, model, f"{model}_best_model_filters_tomtom.tsv")
tomtom_df.to_csv(tomtom_path, sep="\t")

In [None]:
# Keep matches with q-value <= 0.05 (rows with a missing q-value are dropped too)
tomtom_sig = tomtom_df.loc[tomtom_df["q-value"].le(0.05)]

In [None]:
# Add the filter index, parsed from the text after the last "filter_" in Query_ID.
# assign() returns a new frame instead of writing a column into a filtered
# slice of tomtom_df, which triggers pandas' SettingWithCopyWarning.
tomtom_sig = tomtom_sig.assign(
    filter_num=tomtom_sig["Query_ID"].str.split("filter_").str[-1].astype(int)
)

In [None]:
# Filters 0..77 were instantiated from known motifs; higher-numbered filters were purely learned.
# NOTE(review): 77 is presumably (number of initialization motifs - 1) — confirm.
last_init_filter = 77
is_init_filter = tomtom_sig["filter_num"] <= last_init_filter
tomtom_sig_init = tomtom_sig[is_init_filter]
tomtom_sig_learned = tomtom_sig[~is_init_filter]

In [None]:
# Number of distinct motifs significantly matched by the instantiated filters.
# NOTE(review): this counts distinct Target_IDs, not distinct filters; if the
# number of filters is wanted, count unique "filter_num" instead — confirm intent.
tomtom_sig_init["Target_ID"].nunique(dropna=False)

In [None]:
# Split the instantiated-filter hits into TF-cluster targets (Target_ID contains
# "TF") and core-promoter-element targets (all other non-missing Target_IDs).
# The mask is computed once and negated with `~` instead of comparing to False;
# .copy() makes each subset independent so later column additions do not hit
# pandas' SettingWithCopyWarning.
target_ids = tomtom_sig_init["Target_ID"]
is_tf_target = target_ids.str.contains("TF", na=False)
tomtom_sig_init_tf = tomtom_sig_init[is_tf_target].copy()
tomtom_sig_init_cpe = tomtom_sig_init[~is_tf_target & target_ids.notna()].copy()

In [None]:
# TF cluster number = integer after the last "_" in the target motif ID.
# assign() builds a new frame, so this never writes into a slice of
# tomtom_sig_init (avoids pandas' chained-assignment warning).
tomtom_sig_init_tf = tomtom_sig_init_tf.assign(
    TF_cluster_number=tomtom_sig_init_tf["Target_ID"].str.split("_").str[-1].astype(int)
)

In [None]:
# Count significant TF hits whose filter is the one initialized with that cluster.
# NOTE(review): assumes TF cluster k initialized filter k + 5 (the lower filters
# presumably hold the CPEs) — confirm against the initialization code.
tf_filter_offset = 5
expected_filter = tomtom_sig_init_tf["TF_cluster_number"] + tf_filter_offset
(expected_filter == tomtom_sig_init_tf["filter_num"]).sum()

In [None]:
# Build a motif identifier -> motif name map over both the core promoter
# elements and the TF clusters, used to label TomTom targets in plots
core_promoter_elements = md.read_meme(os.path.join(settings.dataset_dir, 'CPEs.meme'))
tf_clusters = md.read_meme(os.path.join(settings.dataset_dir, 'TF-clusters.meme'))

# Merge both motif sets into one collection (a deep copy, so the CPE set is left untouched)
all_motifs = deepcopy(core_promoter_elements)
for motif in tf_clusters:
    all_motifs.add_motif(motif)
id_map = {motif.identifier: motif.name for motif in all_motifs}

In [None]:
# Bar chart: number of significant learned-filter hits per motif (by display name)
fig, ax = plt.subplots(figsize=(6, 3), dpi=300)
tomtom_sig_learned_counts = tomtom_sig_learned["Target_ID"].map(id_map).value_counts()
tomtom_sig_learned_counts.plot(kind="bar", ylabel="Number of filters", ax=ax)
fig.savefig(os.path.join(settings.figure_dir, model, f"{model}_best_model_filters_tomtom_barplot.pdf"))

In [None]:
# Save the per-motif hit counts as a table: index = motif name, column
# "Target_ID" = number of hits, plus the system name. Names are pinned
# explicitly because pandas >= 2.0 names value_counts() output "count" and
# names its index after the source column, which would break the later cells
# that read this file back through the "Target_ID" and "Unnamed: 0" columns.
tomtom_sig_learned_counts_df = (
    tomtom_sig_learned_counts.rename("Target_ID").rename_axis(None).to_frame()
)
tomtom_sig_learned_counts_df["system"] = model
tomtom_sig_learned_counts_df.to_csv(os.path.join(settings.output_dir, model, f"{model}_best_model_filters_tomtom_learned_motif_counts.tsv"), sep="\t")

In [None]:
# For each target motif, keep the learned-filter hit with the smallest q-value
hits_by_q = tomtom_sig_learned.sort_values("q-value")
top_tomtom_sig_learned = hits_by_q.groupby("Target_ID").head(1)
top_hits_path = os.path.join(settings.output_dir, model, f"{model}_best_model_filters_tomtom_top_hits.tsv")
top_tomtom_sig_learned.to_csv(top_hits_path, sep="\t")

In [None]:
# Filter indices of the top hits and the display names of the motifs they match
idxs = top_tomtom_sig_learned["filter_num"].to_numpy()
hit_names = top_tomtom_sig_learned["Target_ID"].map(id_map).to_numpy()

In [None]:
# Open (lazily) the zarr store holding test predictions plus interpretation results for the chosen system
interp_path = os.path.join(settings.output_dir, model, f"{model}_test_predictions_and_interpretations.zarr")
sdata_interpretations = sd.open_zarr(interp_path)

In [None]:
# Name of the first data variable whose name contains "pfms" (the filter PFMs)
keys = pd.Index(sdata_interpretations.data_vars.keys())
pfm_var = keys[keys.str.contains("pfms")][0]
pfm_var

In [None]:
# Visualize every top-hit filter; rank is its 0-based position in the top-hit table
for rank, filter_idx in enumerate(idxs):
    out_path = os.path.join(settings.figure_dir, model, f"{model}_best_model_filter{filter_idx}_rank{rank}_viz.pdf")
    pl.filter_viz(
        sdata_interpretations,
        pfms_var=pfm_var,
        filter_num=filter_idx,
        save=out_path,
    )

In [None]:
# Visualize one specific filter of interest
filter_num = 179
pl.filter_viz(
    sdata_interpretations,
    pfms_var=pfm_var,
    filter_num=filter_num,
    save=os.path.join(settings.figure_dir, model, f"{model}_best_model_filter{filter_num}_viz.pdf"),
    title=f"Filter {filter_num}",
)

# Save all the TomTom results

In [None]:
# Stack the per-system TomTom tables, tagging each row with its system.
# Frames are collected and concatenated once (repeated concat onto an empty
# DataFrame is quadratic and can trigger pandas' empty-entry FutureWarning).
# A distinct loop variable keeps the notebook-level `model` untouched, and
# `settings` is the same object as `eu.settings` used before.
tomtom_tables = []
for system in ["leaf", "proto", "combined"]:
    x = pd.read_csv(os.path.join(settings.output_dir, system, f"{system}_best_model_filters_tomtom.tsv"), sep="\t", comment="#", index_col=0)
    x["system"] = system
    tomtom_tables.append(x)
merged_df = pd.concat(tomtom_tables)

In [None]:
# Drop rows without a Query_ID (e.g. left over from comment lines in the TomTom output)
merged_df = merged_df.dropna(subset=["Query_ID"])

In [None]:
# Write the merged TomTom table for all systems (without the index)
merged_tomtom_path = os.path.join(eu.settings.output_dir, "best_models_filters_tomtom.tsv")
merged_df.to_csv(merged_tomtom_path, sep="\t", index=False)

In [None]:
# Stack the learned-motif count tables of all systems. Frames are read into a
# list and concatenated once, and the loop variable is distinct so the
# notebook-level `model` is not overwritten.
count_tables = [
    pd.read_csv(os.path.join(settings.output_dir, system, f"{system}_best_model_filters_tomtom_learned_motif_counts.tsv"), sep="\t", comment="#")
    for system in ["leaf", "proto", "combined"]
]
merged_counts_df = pd.concat(count_tables)
# Keep motifs hit by more than one learned filter ("Target_ID" holds the hit count in these files)
merged_counts_df = merged_counts_df[merged_counts_df["Target_ID"] > 1]

In [None]:
# Bar chart of learned-filter hits per motif, one bar colour per system
plt.figure(figsize=(4, 3), dpi=300)

# Seaborn "Set2" palette (set_palette changes the palette globally, so later plots use it too)
sns.set_palette(sns.color_palette("Set2"))
# x: motif name (the saved index column); y: hit count (stored under "Target_ID")
ax = sns.barplot(data=merged_counts_df, x="Unnamed: 0", y="Target_ID", hue="system")
ax.set(xlabel="", ylabel="Number of filters")
plt.setp(ax.get_xticklabels(), rotation=90)
plt.tight_layout()
plt.savefig(os.path.join(settings.figure_dir, "best_model_filters_tomtom_barplot.pdf"))

# Loss curve

In [None]:
# Run whose training curves are plotted below: system, trial number, architecture
model, trial, model_type = "combined", 5, "deepstarr"

In [None]:
# Plot the logged training loss and R2 metric curves for the selected run and save them
run_log_dir = os.path.join(settings.logging_dir, model_type, f"{model}_trial_{trial}")
pl.training_summary(
    run_log_dir,
    metric="r2",
    save=os.path.join(settings.figure_dir, model, f"{model}_best_model_training_summary.pdf"),
)

# DONE!

---

# Scratch