# Kopp et al. (2021) Plotting
**Authorship:**
Adam Klie, *08/12/2022*
***
**Description:**
Notebook to generate the plots for the Kopp et al. (2021) dataset that are not produced by the other notebooks:
- Performance figures
- Publication-quality sequence track and filter visualizations
- Inspect and merge TomTom annotations
***

In [None]:
# Enable IPython's autoreload extension so edits to imported modules (e.g. eugene)
# are picked up without restarting the kernel. The guard avoids re-loading the
# extension if it is already registered in this kernel.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

import os
import glob
import logging
import torch
import numpy as np
import pandas as pd
import eugene as eu
import matplotlib.pyplot as plt
import matplotlib
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

# Embed fonts as TrueType (Type 42) in PDF/PS output so figure text stays
# editable in vector-graphics editors.
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

In [None]:
# Point EUGENe at the Kopp et al. (2021) dataset and at the paper's
# per-dataset output, log, config and figure directories.
paper_dir = "/cellar/users/aklie/projects/EUGENe/EUGENe_paper"
eu.settings.dataset_dir = "/cellar/users/aklie/data/eugene/kopp21"
eu.settings.output_dir = f"{paper_dir}/output/kopp21"
eu.settings.logging_dir = f"{paper_dir}/logs/kopp21"
eu.settings.config_dir = f"{paper_dir}/configs/kopp21"
eu.settings.figure_dir = f"{paper_dir}/figures/kopp21"
# Quiet EUGENe's logging down to ERROR level.
eu.settings.verbosity = logging.ERROR

# Load the test `SeqData`

In [None]:
# Load the held-out test set, which already carries the saved predictions of
# all trained models (per the "_all" file name), and display it.
test_preds_path = os.path.join(eu.settings.output_dir, "jund_test_predictions_all.h5sd")
sdata_test = eu.dl.read_h5sd(filename=test_preds_path)
sdata_test

# Generate performance figures

In [None]:
# Box plot of auPRC (average precision) per architecture. Each group label is
# repeated 5 times, presumably one entry per training trial; this assumes the
# prediction columns in sdata_test are ordered to match (TODO confirm).
architectures = ["Kopp21CNN", "dsCNN", "dsFCN", "dsHybrid"]
model_scores = eu.pl.performance_summary(
    sdata_test,
    target_key="target",
    prediction_groups=[arch for arch in architectures for _ in range(5)],
    order=["dsFCN", "Kopp21CNN", "dsHybrid", "dsCNN"],
    metrics=["average_precision"],
    figsize=(6, 6),
    save=os.path.join(eu.settings.figure_dir, "jund_auprc_boxplot.pdf"),
)

In [None]:
# Test every pair of architectures for a difference in their auPRC
# distributions (two-sided Mann-Whitney U), then correct the p-values with
# Benjamini-Hochberg FDR. Pairs are generated from the plotting order, so the
# comparisons are the same six, in the same order, as the hand-written list,
# and each corrected p-value is displayed next to the pair it belongs to.
group_order = ["dsFCN", "Kopp21CNN", "dsHybrid", "dsCNN"]
compare_df = model_scores.pivot(columns="prediction_groups", values="average_precision")
pairs = [(a, b) for i, a in enumerate(group_order) for b in group_order[i + 1:]]
pairwise_tests = [
    # dropna(): the pivot leaves NaNs where a row belongs to another group
    mannwhitneyu(
        compare_df[a].dropna(), compare_df[b].dropna(), alternative="two-sided"
    ).pvalue
    for a, b in pairs
]
reject, pvals_fdr, _, _ = multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")
pd.DataFrame(
    {
        "group1": [a for a, _ in pairs],
        "group2": [b for _, b in pairs],
        "pvalue": pairwise_tests,
        "pvalue_fdr_bh": pvals_fdr,
        "reject_at_0.05": reject,
    }
)

In [None]:
# Summarize several classification metrics across architectures (same
# grouping as the auPRC box plot: each label repeated 5 times) and save the
# figure; the returned per-prediction score table is kept in model_scores.
architectures = ["Kopp21CNN", "dsCNN", "dsFCN", "dsHybrid"]
summary_metrics = ["accuracy", "precision", "recall", "f1", "average_precision", "roc_auc"]
model_scores = eu.pl.performance_summary(
    sdata_test,
    target_key="target",
    prediction_groups=[arch for arch in architectures for _ in range(5)],
    order=["dsFCN", "Kopp21CNN", "dsHybrid", "dsCNN"],
    metrics=summary_metrics,
    figsize=(6, 6),
    save=os.path.join(eu.settings.figure_dir, "jund_performance_summary.pdf"),
)

In [None]:
# Persist the per-prediction metric table as a tab-separated file.
summary_path = os.path.join(eu.settings.output_dir, "jund_performance_summary.tsv")
model_scores.to_csv(summary_path, sep="\t")

In [None]:
# Rank every prediction column by auPRC (highest first) and show the ranking
# as a bar chart to spot the best-performing model.
auprc_sorted = model_scores["average_precision"].sort_values(ascending=False)
auprc_sorted.plot.bar(ylabel="auPRC")

In [None]:
# Plot precision-recall curves for the best-scoring prediction of each model
# type. The model type is the prefix of the prediction key before the first "_".
model_scores["model_type"] = [model.split("_")[0] for model in model_scores.index]
# Sorting by auPRC first means head(1) keeps each type's top prediction, and
# best_preds stays ordered from best to worst overall.
best_preds = (
    model_scores.sort_values(by="average_precision", ascending=False)
    .groupby("model_type")
    .head(1)
    .index
)
eu.pl.auprc(
    sdata_test,
    # One target key per plotted prediction (previously hard-coded to 4).
    target_keys=["target"] * len(best_preds),
    prediction_keys=best_preds,
    labels=best_preds,
    save=os.path.join(eu.settings.figure_dir, "jund_best_model_auprc.pdf"),
)
plt.show()

In [None]:
# Plot the precision-recall curve for the single best prediction overall.
# best_preds comes from a descending auPRC sort, so its first entry is the top one.
top_pred = best_preds[0]
eu.pl.auprc(
    sdata_test,
    target_keys="target",
    prediction_keys=top_pred,
    labels=top_pred,
    save=os.path.join(eu.settings.figure_dir, "jund_best_single_model_auprc.pdf"),
)
plt.show()

# Sequence track visualizations

In [None]:
# Select the architecture and trial whose attributions are visualized below,
# and load the test set that holds that architecture's predictions and
# interpretations (this replaces the sdata_test loaded earlier).
model_type = "Kopp21CNN"
trial = 4
interp_path = os.path.join(
    eu.settings.output_dir,
    f"jund_test_predictions_and_interpretations_{model_type}.h5sd",
)
sdata_test = eu.dl.read_h5sd(interp_path)

In [None]:
# Rank the test sequences by the selected model's predictions and keep the top 10.
# Both the positional indices (used to index the attribution arrays in
# sdata_test.uns) and the sequence IDs (used for plotting) are derived from ONE
# ordering, so they always refer to the same sequences, even when scores tie.
preds = sdata_test[f"{model_type}_trial_{trial}_target_predictions"]
top10_idx = np.argsort(preds.values)[::-1][:10]
top10 = preds.index[top10_idx]
# Report where the summed forward-strand attribution exceeds 0.1; this guides the
# highlight windows chosen below. Assumes the imps array is (seq, base, position),
# so the result is (rank-in-top10, position) pairs -- TODO confirm.
np.where(np.sum(sdata_test.uns["GradientSHAP_forward_imps"][top10_idx], axis=1) > 0.1)

In [None]:
# Use the output from the attribution scan above to highlight a specific
# seqlet on the top-ranked sequence's forward-strand attributions.
top_seq = top10[0]
eu.pl.seq_track(
    sdata_test,
    seq_id=top_seq,
    uns_key="GradientSHAP_forward_imps",
    ylabel="GradientSHAP Forward",
    figsize=(18, 3),
    highlights=[(387, 400)],
    highlight_colors=["lightcyan"],
    save=os.path.join(
        eu.settings.figure_dir,
        f"jund_best_{model_type}_model_GradientSHAP_forward_imps_top1_with_color.pdf",
    ),
)

In [None]:
# Repeat the attribution scan for the reverse strand: locate entries of the
# top-10 sequences whose summed attribution exceeds 0.1.
reverse_totals = np.sum(sdata_test.uns["GradientSHAP_reverse_imps"][top10_idx], axis=1)
np.where(reverse_totals > 0.1)

In [None]:
# Plot reverse-strand attributions for the top-ranked sequence (the same one
# shown in the forward-strand plot), highlighting the two windows picked from
# the reverse-strand scan. The top-ranked ID lives in `top10`; the previously
# referenced `top2` is never defined and raised a NameError.
eu.pl.seq_track(
    sdata_test,
    seq_id=top10[0],
    uns_key="GradientSHAP_reverse_imps",
    ylabel="GradientSHAP Reverse",
    figsize=(18, 3),
    highlights=[(105, 114), (178, 190)],
    highlight_colors=["lightcyan", "honeydew"],
    save=os.path.join(eu.settings.figure_dir, f"jund_best_{model_type}_model_GradientSHAP_reverse_imps_top1_with_color.pdf"),
)

In [None]:
# Plot and save forward- then reverse-strand attribution tracks for each of
# the top 10 sequences; file names carry the 1-based rank.
strands = (("forward", "Forward"), ("reverse", "Reverse"))
for rank, seq_id in enumerate(top10, start=1):
    for strand, strand_label in strands:
        eu.pl.seq_track(
            sdata_test,
            seq_id=seq_id,
            uns_key=f"GradientSHAP_{strand}_imps",
            ylabel=f"GradientSHAP {strand_label}",
            figsize=(18, 3),
            save=os.path.join(
                eu.settings.figure_dir,
                f"jund_best_{model_type}_model_GradientSHAP_{strand}_imps_top{rank}.pdf",
            ),
        )

# Filter visualizations

In [None]:
# Visualize the model's filters in figures of 10 (2 rows x 5 columns).
# Only the first chunk (filters 1-10) is drawn; widen range(...) to plot
# further chunks. Titles use 0-based filter numbers; file names use 1-based.
filter_names = list(sdata_test.uns["pfms"].keys())
for chunk in range(1):
    start_filter = chunk * 10
    end_filter = start_filter + 10
    print(f"Plotting and saving filters {start_filter+1}-{end_filter}")
    eu.pl.multifilter_viz(
        sdata_test,
        filter_ids=filter_names[start_filter:end_filter],
        num_rows=2,
        num_cols=5,
        titles=[f"filter {idx}" for idx in range(start_filter, end_filter)],
        save=os.path.join(
            eu.settings.figure_dir,
            f"jund_best_{model_type}_model_filters{start_filter+1}-{end_filter}_viz.pdf",
        ),
    )

# TomTom results

In [None]:
# Load the TomTom matches of the CNN's filters against the HOCOMOCO database
# and keep those with q-value < 0.05, most significant first.
model_type = "CNN"
tomtom_path = os.path.join(eu.settings.output_dir, f"jund_best_{model_type}_model_filters_tomtom.tsv")
res = pd.read_csv(tomtom_path, sep="\t", comment="#")
significant = res["q-value"] < 0.05
res_sig = res.loc[significant].sort_values(by="q-value")

In [None]:
# Show only the best match for each filter: res_sig is already sorted by
# ascending q-value, so the first row per Query_ID is that filter's top hit.
best_hit_per_filter = res_sig.groupby(by="Query_ID").head(n=1)
best_hit_per_filter

In [None]:
# Merge the TomTom results of every model's filters into one table and save it
# as TSV. Each model's own file is read (the path uses the loop variable; it
# previously used `model_type`, so the CNN file was read three times and
# mislabeled). Rows are tagged with their model, and rows without a Query_ID
# are dropped.
tomtom_tables = []
for model in ["CNN", "Hybrid", "Kopp21CNN"]:
    tomtom_df = pd.read_csv(
        os.path.join(eu.settings.output_dir, f"jund_best_{model}_model_filters_tomtom.tsv"),
        sep="\t",
        comment="#",
    )
    tomtom_df["model_type"] = model
    tomtom_tables.append(tomtom_df)
# Concatenate once instead of growing an (initially empty) DataFrame in the loop.
merged_df = pd.concat(tomtom_tables)
merged_df = merged_df[merged_df["Query_ID"].notna()]
merged_df.to_csv(os.path.join(eu.settings.output_dir, "all_models_filters_tomtom.tsv"), sep="\t", index=False)

---