# Jores et al 2021 Plotting 
**Authorship:**
Adam Klie, *09/12/2022*
***
**Description:**
Notebook to generate plots for the Jores et al (2021) dataset that are not included in the other notebooks.
 - Summary table of benchmarking results for each model type
 - Cleaner seq track plots for top sequences
 - TomTom filter annotation analysis
 - Loss and metric plots
***

In [None]:
# Load the autoreload extension only if it is not already loaded,
# so that re-running this cell does not load it a second time.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
# Autoreload mode 2: reload modules before each cell is executed.
%autoreload 2

import os
import glob
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu
import matplotlib.pyplot as plt
import matplotlib
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

# Use font type 42 in PDF and PS output (for Illustrator editing)
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

In [None]:
# Point every eugene settings directory at its jores21 location
_jores21_dirs = {
    "dataset_dir": "/cellar/users/aklie/data/eugene/jores21",
    "output_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/output/jores21",
    "logging_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/logs/jores21",
    "config_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/configs/jores21",
    "figure_dir": "/cellar/users/aklie/projects/EUGENe/EUGENe_paper/figures/jores21",
}
for _setting_name, _setting_path in _jores21_dirs.items():
    setattr(eu.settings, _setting_name, _setting_path)

# Generate performance figures

## Leaf models

In [None]:
# Load the saved test-set predictions for the leaf models
leaf_predictions_path = os.path.join(eu.settings.output_dir, "leaf", "leaf_test_predictions.h5sd")
sdata_leaf = eu.dl.read_h5sd(leaf_predictions_path)

In [None]:
# R2 summary across the leaf models; the box plot is saved as a PDF.
# Each architecture label is repeated 5 times, one per prediction column.
leaf_architectures = ("Jores21CNN", "ssCNN", "ssHybrid")
leaf_boxplot_pdf = os.path.join(eu.settings.figure_dir, "leaf_performance_boxplot.pdf")
leaf_model_scores = eu.pl.performance_summary(
    sdata_leaf,
    target_key="enrichment",
    prediction_groups=[arch for arch in leaf_architectures for _ in range(5)],
    metrics=["r2"],
    add_swarm=False,
    figsize=(6, 6),
    save=leaf_boxplot_pdf,
)

In [None]:
# Pairwise Mann-Whitney U tests on R2 between model groups,
# followed by Benjamini-Hochberg correction of the three p-values.
compare_df = leaf_model_scores.pivot(columns="prediction_groups", values="r2")
leaf_group_pairs = [("ssHybrid", "Jores21CNN"), ("ssHybrid", "ssCNN"), ("ssCNN", "Jores21CNN")]
pairwise_tests = [
    mannwhitneyu(compare_df[first].dropna(), compare_df[second].dropna()).pvalue
    for first, second in leaf_group_pairs
]
multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")

In [None]:
# Summarize performance across models for multiple metrics
leaf_model_scores = eu.pl.performance_summary(
    sdata_leaf,
    target_key="enrichment",
    prediction_groups=["Jores21CNN"]*5 + ["ssCNN"]*5 + ["ssHybrid"]*5,
    metrics=["r2", "mse", "pearson", "spearman", "kendall"],
    add_swarm=False,
    figsize=(6, 6),
)

# Identify the best model from returned model scores (highest R2 first)
r2_sorted = leaf_model_scores["r2"].sort_values(ascending=False)

# Draw the bar plot on its own figure. Without an explicit `ax`, Series.plot
# reuses the current axes if a figure is still open, which could be the one
# created by performance_summary above.
leaf_r2_fig, leaf_r2_ax = plt.subplots()
r2_sorted.plot(kind="bar", ylabel="R2", ax=leaf_r2_ax)
leaf_r2_fig.savefig(os.path.join(eu.settings.figure_dir, "leaf_performance_summary.pdf"))

In [None]:
# Write the leaf score table to a tab-separated file
leaf_summary_tsv = os.path.join(eu.settings.output_dir, "leaf_performance_summary.tsv")
leaf_model_scores.to_csv(leaf_summary_tsv, sep="\t")

In [None]:
# Scatter of targets vs. predictions for the top-R2 leaf model, grouped by "sp"
best_preds = r2_sorted.index[0]
leaf_by_sp_pdf = os.path.join(eu.settings.figure_dir, "leaf_best_model_performance_scatter_by_sp.pdf")
ax = eu.pl.performance_scatter(
    sdata_leaf,
    target_keys="enrichment",
    prediction_keys=best_preds,
    groupby="sp",
    alpha=0.5,
    rasterized=True,
    figsize=(8, 8),
    save=leaf_by_sp_pdf,
)

In [None]:
# Same scatter without grouping (all species together), used in Figure 2
leaf_scatter_kwargs = dict(
    target_keys="enrichment",
    prediction_keys=best_preds,
    alpha=0.5,
    figsize=(4, 4),
    rasterized=True,
    save=os.path.join(eu.settings.figure_dir, "leaf_best_model_performance_scatter.pdf"),
)
ax = eu.pl.performance_scatter(sdata_leaf, **leaf_scatter_kwargs)

## Proto models

In [None]:
# Load the saved test-set predictions for the proto models
proto_predictions_path = os.path.join(eu.settings.output_dir, "proto", "proto_test_predictions.h5sd")
sdata_proto = eu.dl.read_h5sd(proto_predictions_path)

In [None]:
# R2 summary across the proto models; the box plot is saved as a PDF.
# Each architecture label is repeated 5 times, one per prediction column.
proto_architectures = ("Jores21CNN", "ssCNN", "ssHybrid")
proto_boxplot_pdf = os.path.join(eu.settings.figure_dir, "proto_performance_boxplot.pdf")
proto_model_scores = eu.pl.performance_summary(
    sdata_proto,
    target_key="enrichment",
    prediction_groups=[arch for arch in proto_architectures for _ in range(5)],
    metrics=["r2"],
    add_swarm=False,
    figsize=(6, 6),
    save=proto_boxplot_pdf,
)

In [None]:
# Pairwise Mann-Whitney U tests on R2 between model groups,
# followed by Benjamini-Hochberg correction of the three p-values.
compare_df = proto_model_scores.pivot(columns="prediction_groups", values="r2")
proto_group_pairs = [("ssHybrid", "Jores21CNN"), ("ssHybrid", "ssCNN"), ("ssCNN", "Jores21CNN")]
pairwise_tests = [
    mannwhitneyu(compare_df[first].dropna(), compare_df[second].dropna()).pvalue
    for first, second in proto_group_pairs
]
multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")

In [None]:
# Summarize performance across models for multiple metrics
proto_model_scores = eu.pl.performance_summary(
    sdata_proto,
    target_key="enrichment",
    prediction_groups=["Jores21CNN"]*5 + ["ssCNN"]*5 + ["ssHybrid"]*5,
    metrics=["r2", "mse", "pearson", "spearman", "kendall"],
    add_swarm=False,
    figsize=(6, 6),
)

# Identify the best model from returned model scores (highest R2 first)
r2_sorted = proto_model_scores["r2"].sort_values(ascending=False)

# Draw the bar plot on its own figure. Without an explicit `ax`, Series.plot
# reuses the current axes if a figure is still open, which could be the one
# created by performance_summary above.
proto_r2_fig, proto_r2_ax = plt.subplots()
r2_sorted.plot(kind="bar", ylabel="R2", ax=proto_r2_ax)
proto_r2_fig.savefig(os.path.join(eu.settings.figure_dir, "proto_performance_summary.pdf"))

In [None]:
# Scatter of targets vs. predictions for the top-R2 proto model, grouped by "sp"
best_preds = r2_sorted.index[0]
proto_by_sp_pdf = os.path.join(eu.settings.figure_dir, "proto_best_model_performance_scatter_by_sp.pdf")
ax = eu.pl.performance_scatter(
    sdata_proto,
    target_keys="enrichment",
    prediction_keys=best_preds,
    groupby="sp",
    alpha=0.5,
    rasterized=True,
    figsize=(8, 8),
    save=proto_by_sp_pdf,
)

In [None]:
# Write the proto score table to a tab-separated file
proto_summary_tsv = os.path.join(eu.settings.output_dir, "proto_performance_summary.tsv")
proto_model_scores.to_csv(proto_summary_tsv, sep="\t")

## Combined models

In [None]:
# Load the saved test-set predictions for the combined models
combined_predictions_path = os.path.join(eu.settings.output_dir, "combined", "combined_test_predictions.h5sd")
sdata_combined = eu.dl.read_h5sd(combined_predictions_path)

In [None]:
# R2 summary across the combined models; the box plot is saved as a PDF.
# Each architecture label is repeated 5 times, one per prediction column.
combined_architectures = ("Jores21CNN", "ssCNN", "ssHybrid")
combined_boxplot_pdf = os.path.join(eu.settings.figure_dir, "combined_performance_boxplot.pdf")
combined_model_scores = eu.pl.performance_summary(
    sdata_combined,
    target_key="enrichment",
    prediction_groups=[arch for arch in combined_architectures for _ in range(5)],
    metrics=["r2"],
    add_swarm=False,
    figsize=(6, 6),
    save=combined_boxplot_pdf,
)

In [None]:
# Pairwise Mann-Whitney U tests on R2 between model groups,
# followed by Benjamini-Hochberg correction of the three p-values.
compare_df = combined_model_scores.pivot(columns="prediction_groups", values="r2")
combined_group_pairs = [("ssHybrid", "Jores21CNN"), ("ssHybrid", "ssCNN"), ("ssCNN", "Jores21CNN")]
pairwise_tests = [
    mannwhitneyu(compare_df[first].dropna(), compare_df[second].dropna()).pvalue
    for first, second in combined_group_pairs
]
multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")

In [None]:
# Summarize performance across models for multiple metrics
combined_model_scores = eu.pl.performance_summary(
    sdata_combined,
    target_key="enrichment",
    prediction_groups=["Jores21CNN"]*5 + ["ssCNN"]*5 + ["ssHybrid"]*5,
    metrics=["r2", "mse", "pearson", "spearman", "kendall"],
    add_swarm=False,
    figsize=(6, 6),
)

# Identify the best model from returned model scores (highest R2 first)
r2_sorted = combined_model_scores["r2"].sort_values(ascending=False)

# Draw the bar plot on its own figure. Without an explicit `ax`, Series.plot
# reuses the current axes if a figure is still open, which could be the one
# created by performance_summary above.
combined_r2_fig, combined_r2_ax = plt.subplots()
r2_sorted.plot(kind="bar", ylabel="R2", ax=combined_r2_ax)
combined_r2_fig.savefig(os.path.join(eu.settings.figure_dir, "combined_performance_summary.pdf"))

In [None]:
# Scatter of targets vs. predictions for the top-R2 combined model, grouped by "sp"
best_preds = r2_sorted.index[0]
combined_by_sp_pdf = os.path.join(eu.settings.figure_dir, "combined_best_model_performance_scatter_by_sp.pdf")
ax = eu.pl.performance_scatter(
    sdata_combined,
    target_keys="enrichment",
    prediction_keys=best_preds,
    groupby="sp",
    alpha=0.5,
    rasterized=True,
    figsize=(8, 8),
    save=combined_by_sp_pdf,
)

In [None]:
# Write the combined score table to a tab-separated file
combined_summary_tsv = os.path.join(eu.settings.output_dir, "combined_performance_summary.tsv")
combined_model_scores.to_csv(combined_summary_tsv, sep="\t")

# Performance summary table

In [None]:
# Label each score table with its system in a "model" column (added in place),
# then stack the three tables and write the result as a TSV.
for system_label, score_table in (
    ("leaf", leaf_model_scores),
    ("proto", proto_model_scores),
    ("combined", combined_model_scores),
):
    score_table["model"] = system_label
merged_model_scores = pd.concat([leaf_model_scores, proto_model_scores, combined_model_scores])
merged_summary_tsv = os.path.join(eu.settings.output_dir, "merged_performance_summary.tsv")
merged_model_scores.to_csv(merged_summary_tsv, sep="\t")

# Cleaner seq track logos

In [None]:
# Select the system, trial number and architecture used for the seq track plots
model, trial, model_type = "combined", 3, "Jores21CNN"

In [None]:
# Load the predictions together with the stored feature attributions
interpretations_path = os.path.join(
    eu.settings.output_dir,
    f"{model}_test_predictions_and_interpretations.h5sd",
)
sdata_interpretations = eu.dl.read_h5sd(interpretations_path)

In [None]:
# Grab the highest predicted seqs for the best model.
# Both the positional indices (top5_idx) and the labels (top5) come from one
# stable descending ordering, so they always refer to the same sequences.
# Two independent sorts can disagree on ties, and a reversed ascending argsort
# would place NaN predictions first.
best_model_preds = sdata_interpretations[f"{model_type}_trial_{trial}_enrichment_predictions"]
top5_idx = np.argsort(-best_model_preds.values, kind="stable")[:5]
top5 = best_model_preds.index[top5_idx]

In [None]:
# Find the ranges in each seq where the model gives high interpretations
def _high_attribution_ranges(positions, max_gap=3, min_span=4):
    """Return (start, end) pairs for runs of nearby positions.

    ``positions`` is an ascending sequence of positions. Neighbouring
    positions stay in the same run while they are at most ``max_gap`` apart.
    A run is reported only when ``end - start > min_span``. An empty input
    gives an empty list.
    """
    if len(positions) == 0:
        return []
    runs = []
    run_start = previous = positions[0]
    for position in positions[1:]:
        if position - previous > max_gap:
            # Close the run that just ended BEFORE starting the next one.
            # Resetting the start first makes the span always zero, which
            # drops every run except the last.
            if previous - run_start > min_span:
                runs.append((run_start, previous))
            run_start = position
        previous = position
    if previous - run_start > min_span:
        runs.append((run_start, previous))
    return runs


# Attributions are summed over axis 1 and thresholded at 0.2.
# NOTE(review): axis 1 is presumably the nucleotide channel axis — confirm.
# np.where returns (index into top5_idx, position) pairs in ascending order.
seq_num, seq_pos = np.where(np.sum(sdata_interpretations.uns["DeepLift_imps"][top5_idx], axis=1) > 0.2)
ranges = [_high_attribution_ranges(seq_pos[seq_num == j]) for j in range(len(top5_idx))]

In [None]:
# One DeepLift seq track per top sequence, with its highlight ranges,
# saved under a 1-based rank in the file name.
for i in range(5):
    track_pdf = os.path.join(eu.settings.figure_dir, f"{model}_best_model_feature_attr_{i+1}.pdf")
    eu.pl.seq_track(
        sdata_interpretations,
        seq_id=top5[i],
        uns_key="DeepLift_imps",
        ylabel="DeepLift",
        highlights=ranges[i],
        figsize=(8, 1),
        save=track_pdf,
    )

# TomTom annotation analysis

In [None]:
# Select the system, trial number and architecture for the TomTom analysis
model, trial, model_type = "leaf", 5, "ssHybrid"

In [None]:
# Read the CPE and TF-cluster TomTom tables for this system and stack them row-wise
tomtom_dir = os.path.join(eu.settings.output_dir, model)
tomtom_cpe = pd.read_csv(os.path.join(tomtom_dir, f"{model}_best_model_filters_tomtom_CPE.tsv"), sep="\t")
tomtom_tf = pd.read_csv(os.path.join(tomtom_dir, f"{model}_best_model_filters_tomtom_TF.tsv"), sep="\t")
tomtom_df = pd.concat([tomtom_cpe, tomtom_tf], axis=0)

In [None]:
# Save the combined annotations as a TSV inside this system's subdirectory.
# The CPE/TF inputs are read from there, and the cell that merges all systems
# reads this file name from output_dir/<system>/, so it is written there too.
# NOTE(review): confirm no other consumer expects this file in output_dir itself.
tomtom_df.to_csv(os.path.join(eu.settings.output_dir, model, f"{model}_best_model_filters_tomtom.tsv"), sep="\t")

In [None]:
# Subset to significant hits (q-value <= 0.05).
# Take an explicit copy: a column is added to this frame in a later cell, and
# assigning into a filtered slice triggers pandas' SettingWithCopyWarning.
tomtom_sig = tomtom_df[tomtom_df["q-value"] <= 0.05].copy()

In [None]:
# Parse the integer that follows the last "filter" in Query_ID into a new column
filter_suffix = tomtom_sig["Query_ID"].str.split("filter").str[-1]
tomtom_sig["filter_num"] = filter_suffix.astype(int)

In [None]:
# Split hits into instantiated filters (filter_num <= 77) and purely learned ones (> 77)
# NOTE(review): 77 is presumably the number of motif-initialized filters — confirm
last_initialized_filter = 77
tomtom_sig_init = tomtom_sig[tomtom_sig["filter_num"].le(last_initialized_filter)]
tomtom_sig_learned = tomtom_sig[tomtom_sig["filter_num"].gt(last_initialized_filter)]

In [None]:
# Number of distinct Target_ID values among significant hits of the instantiated filters
tomtom_sig_init["Target_ID"].nunique(dropna=False)

In [None]:
# Separate the instantiated filters into CPE and TF hits, based on whether
# Target_ID contains "TF".
is_tf_hit = tomtom_sig_init["Target_ID"].str.contains("TF")
# Explicit copies: a column is added to the TF frame in a later cell, and
# assigning into a filtered slice triggers pandas' SettingWithCopyWarning.
tomtom_sig_init_tf = tomtom_sig_init[is_tf_hit].copy()
tomtom_sig_init_cpe = tomtom_sig_init[~is_tf_hit].copy()

In [None]:
# The cluster number is the integer after the last "_" in Target_ID
tf_cluster_suffix = tomtom_sig_init_tf["Target_ID"].str.split("_").str[-1]
tomtom_sig_init_tf["TF_cluster_number"] = tf_cluster_suffix.astype(int)

In [None]:
# Count initialized TF hits whose cluster number + 5 equals their filter number
# NOTE(review): the +5 offset presumably reflects where TF clusters start among the filters — confirm
expected_filter_num = tomtom_sig_init_tf["TF_cluster_number"] + 5
(expected_filter_num == tomtom_sig_init_tf["filter_num"]).sum()

In [None]:
# Build a motif-ID -> motif-name lookup from the CPE and TF-cluster MEME files
core_promoter_elements = eu.dl.motif.MinimalMEME(os.path.join(eu.settings.dataset_dir, 'CPEs.meme'))
tf_groups = eu.dl.motif.MinimalMEME(os.path.join(eu.settings.dataset_dir, 'TF-clusters.meme'))
# On a duplicate ID, the TF-cluster entry replaces the CPE entry
all_motifs = {**core_promoter_elements.motifs, **tf_groups.motifs}
id_map = {motif_id: motif_entry.name for motif_id, motif_entry in all_motifs.items()}

In [None]:
# Bar plot of how many learned-filter hits map to each motif name
plt.figure(figsize=(6, 3), dpi=300)
learned_hit_counts = tomtom_sig_learned["Target_ID"].map(id_map).value_counts()
learned_hit_counts.plot(kind="bar", ylabel="Number of filters")
tomtom_barplot_pdf = os.path.join(eu.settings.figure_dir, f"{model}_best_model_filters_tomtom_barplot.pdf")
plt.savefig(tomtom_barplot_pdf)

In [None]:
# Keep, for each Target_ID, the learned-filter hit with the smallest q-value
learned_by_qvalue = tomtom_sig_learned.sort_values("q-value")
top_tomtom_sig_learned = learned_by_qvalue.groupby("Target_ID").head(1)
top_hits_tsv = os.path.join(eu.settings.output_dir, f"{model}_best_model_filters_tomtom_top_hits.tsv")
top_tomtom_sig_learned.to_csv(top_hits_tsv, sep="\t")

In [None]:
# Merge the saved TomTom annotation tables of all three systems into one dataframe.
# The loop variable is called `system` so that the notebook-level `model`
# setting is not overwritten by this cell.
per_system_tables = []
for system in ["leaf", "proto", "combined"]:
    x = pd.read_csv(os.path.join(eu.settings.output_dir, system, f"{system}_best_model_filters_tomtom.tsv"), sep="\t", comment="#", index_col=0)
    # Record which system each row came from
    x["system"] = system
    per_system_tables.append(x)
# Concatenate once instead of growing the dataframe inside the loop
merged_df = pd.concat(per_system_tables)

In [None]:
# Keep only rows that have a Query_ID
has_query_id = merged_df["Query_ID"].notna()
merged_df = merged_df[has_query_id]

In [None]:
# Write the merged annotation table as a TSV without the index column
merged_tomtom_tsv = os.path.join(eu.settings.output_dir, "best_models_filters_tomtom.tsv")
merged_df.to_csv(merged_tomtom_tsv, sep="\t", index=False)

# Loss curve

In [None]:
# Select the system, trial number and architecture for the training curves
model, trial, model_type = "combined", 3, "Jores21CNN"

In [None]:
# Plot the training curves and the r2 metric from this trial's log directory, then save as PDF
trial_log_dir = os.path.join(eu.settings.logging_dir, model_type, f"{model}_trial_{trial}")
training_summary_pdf = os.path.join(eu.settings.figure_dir, f"{model}_best_model_training_summary.pdf")
eu.pl.training_summary(
    trial_log_dir,
    metric="r2",
    save=training_summary_pdf,
)

---