# Kopp et al. (2021) Plotting
**Authorship:**
Adam Klie (last updated: *06/10/2023*)
***
**Description:**
Notebook to generate plots for the Kopp et al. (2021) dataset that are not included in the other notebooks:
- Performance figures
- Nicer seq track and filter visualizations
- Inspect and merge TomTom annotations
***

In [None]:
# General imports
import os
import sys
import numpy as np
import pandas as pd
from copy import deepcopy
from itertools import groupby
from operator import itemgetter
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

# EUGENe imports and settings
import eugene as eu
from eugene import preprocess as pp
from eugene import plot as pl
from eugene import settings
# EUGENe path settings used throughout this notebook (outputs are read from
# settings.output_dir and figures are saved under settings.figure_dir below)
settings.dataset_dir = "/cellar/users/aklie/data/eugene/revision/kopp21"
settings.output_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/output/revision/kopp21"
# NOTE(review): logging_dir lives under /cellar/users/dlaub and has no
# "revision/" segment, unlike every other path here — confirm this is intended.
settings.logging_dir = "/cellar/users/dlaub/projects/ML4GLand/EUGENe_paper/logs/kopp21"
settings.figure_dir = "/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/figures/revision/kopp21"

# EUGENe packages
import seqdata as sd
import motifdata as md

# Embed fonts as TrueType (Type 42) in PDF/PS output so text in the saved
# figures stays editable in Illustrator
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Report interpreter and key library versions for reproducibility
print(
    f"Python version: {sys.version}",
    f"NumPy version: {np.__version__}",
    f"Eugene version: {eu.__version__}",
    f"SeqData version: {sd.__version__}",
    sep="\n",
)

# Load in the test `SeqData`(s)

In [None]:
# Load the prediction table (TSV, first column as index) and the test SeqData
# zarr store, eagerly loaded with .load()
# NOTE(review): `predictions` is not used by any later cell shown here.
predictions = pd.read_csv(
    os.path.join(settings.output_dir, "test_predictions_all.tsv"),
    sep="\t",
    index_col=0,
)
sdata_test = sd.open_zarr(
    os.path.join(settings.output_dir, "test_predictions_all.zarr")
).load()

In [None]:
# Label every sequence with a "chrom:chromStart-chromEnd" ID
sdata_test["id"] = (
    sdata_test["chrom"]
    + ":"
    + sdata_test["chromStart"].astype(str)
    + "-"
    + sdata_test["chromEnd"].astype(str)
)

# Generate performance figures

In [None]:
# Collect the names of all variables in the test SeqData that hold predictions
preds_vars = list(filter(lambda var_name: "predictions" in var_name, sdata_test.keys()))
preds_vars

In [None]:
# Sort the prediction variables by model type (in the order listed below),
# breaking ties alphabetically by variable name. Also label each variable with
# its model type, which is used as the plotting group.
order = ["dsfcn", "kopp21_cnn", "dshybrid", "dscnn"]


def _model_type(var_name):
    """Return the model type encoded at the start of a prediction variable name.

    Names are expected to start with "<model_type>_". "kopp21_cnn" itself
    contains an underscore, so it is matched by substring instead.
    """
    return "kopp21_cnn" if "kopp21" in var_name else var_name.split("_")[0]


# order.index raises ValueError for a model type missing from `order`
preds_vars = sorted(preds_vars, key=lambda k: (order.index(_model_type(k)), k))
groups = [_model_type(k) for k in preds_vars]
preds_vars, groups

In [None]:
# Summarize auPRC per model type as a boxplot and keep the returned scores
# for the significance tests in the next cell
model_scores = pl.performance_summary(
    sdata_test,
    metrics=["average_precision"],
    target_var="target",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    order=order,
    figsize=(6, 6),
    save=os.path.join(settings.figure_dir, "all_models_auprc_boxplot.pdf"),
)


In [None]:
# Test every pair of model types for a difference in their per-model auPRC
# distributions (Mann-Whitney U, SciPy's default alternative), then adjust the
# p-values across all pairs with Benjamini-Hochberg FDR correction.
from itertools import combinations

# One column per model type; each row only has a value in its own group's
# column, hence the dropna() before testing.
compare_df = model_scores.pivot(columns="prediction_groups", values="average_precision")
# All unordered pairs of model types, enumerated in `order` order
test_pairs = list(combinations(order, 2))
pairwise_tests = [
    mannwhitneyu(compare_df[group_a].dropna(), compare_df[group_b].dropna()).pvalue
    for group_a, group_b in test_pairs
]
reject, pvals_fdr, _, _ = multipletests(pairwise_tests, alpha=0.05, method="fdr_bh")
# Label each raw and corrected p-value with the pair of model types it compares
pd.DataFrame(
    {"pvalue": pairwise_tests, "pvalue_fdr_bh": pvals_fdr, "reject": reject},
    index=pd.MultiIndex.from_tuples(test_pairs, names=["group_a", "group_b"]),
)

In [None]:
# Boxplots of every classification metric per model type; the returned
# per-model scores replace the auPRC-only table from above
model_scores = pl.performance_summary(
    sdata_test,
    metrics=["accuracy", "precision", "recall", "f1", "average_precision", "roc_auc"],
    target_var="target",
    prediction_vars=preds_vars,
    prediction_groups=groups,
    order=order,
    figsize=(6, 6),
    save=os.path.join(settings.figure_dir, "performance_summary_boxplots.pdf"),
)

In [None]:
# Write the per-model performance table to disk as a tab-separated file
model_scores.to_csv(
    path_or_buf=os.path.join(settings.output_dir, "performance_summary.tsv"),
    sep="\t",
)

In [None]:
# Rank every trained model by auPRC (highest first) and show the ranking as a bar chart
auprc_sorted = model_scores["average_precision"].sort_values(ascending=False)
auprc_sorted.plot.bar(ylabel="auPRC")

In [None]:
# Keep the highest-auPRC model of each model type and plot their auPRC together.
# The model type is the prefix before the first "_" of each prediction variable
# name; "kopp21_cnn_*" maps to "kopp21", which still identifies that type.
model_scores["model_type"] = [model.split("_")[0] for model in model_scores.index]
# GroupBy.head keeps rows in their sorted order, so best_preds is ordered by
# descending auPRC
best_preds = (
    model_scores.sort_values(by="average_precision", ascending=False)
    .groupby("model_type")
    .head(1)
    .index
)
pl.auprc(
    sdata_test,
    # One target per prediction variable, sized from best_preds rather than
    # hard-coded so it stays in sync with the number of model types
    target_vars=["target"] * len(best_preds),
    prediction_vars=best_preds,
    labels=best_preds,
    save=os.path.join(settings.figure_dir, "best_models_auprc.pdf"),
)
plt.show()

In [None]:
# Plot the auPRC of the single best model. best_preds was taken from
# model_scores sorted by descending auPRC, and GroupBy.head keeps that order,
# so its first entry is the best-scoring model overall.
pl.auprc(
    sdata_test,
    target_vars="target",
    prediction_vars=best_preds[0],
    labels=best_preds[0],
    save=os.path.join(settings.figure_dir, "best_single_model_auprc.pdf"),
)
plt.show()

# Seq track visualizations

In [None]:
# Model type and training trial whose interpretations are visualized below
model_type, trial = "dsfcn", 1

In [None]:
# Open the test set with this model's predictions and attributions, and grab
# the per-sequence IDs used to select sequences for the seq-track plots.
sdata_test = sd.open_zarr(
    os.path.join(settings.output_dir, model_type, "test_predictions_and_interpretations.zarr")
)
ids = sdata_test["id"].values

In [None]:
# Rank the test sequences by the chosen model's predictions and keep the 10
# highest-scoring positions.
# - The ranking is computed once and shared by `top10` and `top10_idx`, so the
#   attribution check below and the seq-track plots use the same sequences.
# - Negating before argsort gives a descending order that still puts NaN
#   predictions last; a reversed ascending argsort would put them first.
# - A stable sort keeps tied scores in their original order.
# Assumes the prediction variable is 1-D (one score per sequence) — TODO confirm.
top_preds = sdata_test[f"{model_type}_trial_{trial}_target_predictions"].values
top10_idx = np.argsort(-top_preds, kind="stable")[:10]
top10 = top10_idx

# Indices where the top-10 attributions, summed over axis 1, exceed 0.1
np.where(np.sum(sdata_test["GradientShap_attrs"][top10_idx], axis=1) > 0.1)

In [None]:
# For each of the top 10 sequences, plot and save the forward ("GradientShap_attrs")
# and reverse ("GradientShap_attrs_rc") GradientShap attribution tracks.
os.makedirs(os.path.join(settings.figure_dir, model_type), exist_ok=True)
strands = [
    ("GradientShap_attrs", "Forward", "forward"),
    ("GradientShap_attrs_rc", "Reverse", "reverse"),
]
for rank, seq in enumerate(top10, start=1):
    for attrs_var, label, strand in strands:
        pl.seq_track(
            sdata_test,
            seq_id=ids[seq],
            attrs_var=attrs_var,
            ylab=f"GradientShap {label}",
            figsize=(18, 3),
            save=os.path.join(
                settings.figure_dir,
                model_type,
                f"best_{model_type}_model_GradientSHAP_{strand}_imps_top{rank}.pdf",
            ),
        )

# Filter viz

In [None]:
# Module name of the layer whose filter PFMs ("{layer_name}_pfms") are plotted
# below; it depends on how each architecture names its modules.
layer_name = (
    "arch.conv"
    if model_type == "kopp21_cnn"
    else "arch.conv1d_tower.layers.0"
    if "ds" in model_type
    else "arch.conv1d_tower.layers.1"
)
layer_name

In [None]:
# Plot the first 10 filters of the chosen layer (from the "{layer_name}_pfms"
# variable) in a 2 x 5 grid and save the figure.
# The filter range is defined once so the plotted filters and their titles
# cannot drift apart.
filter_nums = range(10)
os.makedirs(os.path.join(settings.figure_dir, model_type), exist_ok=True)
pl.multifilter_viz(
    sdata_test,
    filter_nums=filter_nums,
    pfms_var=f"{layer_name}_pfms",
    num_rows=2,
    num_cols=5,
    figsize=(10, 3),
    titles=[f"filter {i}" for i in filter_nums],
    save=os.path.join(settings.figure_dir, model_type, f"best_{model_type}_model_filters_viz.pdf"),
)

# TomTom results

In [None]:
# Load the TomTom matches of the kopp21_cnn model's filters against the
# HOCOMOCO database and keep those with q-value < 0.05, most significant first
model_type = "kopp21_cnn"
res = pd.read_csv(
    os.path.join(settings.output_dir, model_type, f"best_model_{model_type}_filters_tomtom.tsv"),
    sep="\t",
    comment="#",
)
res_sig = res.loc[res["q-value"] < 0.05].sort_values(by="q-value")

In [None]:
# res_sig is sorted by q-value, so the first row of each query is its best match
res_sig.groupby(by="Query_ID").head(n=1)

In [None]:
# Combine the TomTom results of the dshybrid, kopp21_cnn and dscnn models into
# one table, tagging each row with its model type, and save it as a TSV.
# Frames are collected and concatenated once, and `res` from the earlier
# cell is left untouched.
tomtom_frames = []
for model in ["dshybrid", "kopp21_cnn", "dscnn"]:
    tomtom_df = pd.read_csv(
        os.path.join(settings.output_dir, model, f"best_model_{model}_filters_tomtom.tsv"),
        sep="\t",
        comment="#",
    )
    tomtom_df["model_type"] = model
    tomtom_frames.append(tomtom_df)
merged_df = pd.concat(tomtom_frames)
# Drop rows without a query ID before saving
merged_df = merged_df[merged_df["Query_ID"].notna()]
merged_df.to_csv(os.path.join(settings.output_dir, "all_models_filters_tomtom.tsv"), sep="\t", index=False)

# DONE!

---

# Scratch