In [32]:
#!/usr/bin/env python
"""
GO enrichment (BP & MF) for top-95 TF-target sets, per cluster — inclusive display version with custom background
Author: Zach Myers / Greenham Lab   Updated 20 May 2025

Requires:
    pip install pandas gprofiler-official matplotlib seaborn
"""

# ----------------------------
# Imports
# ----------------------------
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from gprofiler import GProfiler

# ----------------------------
# Parameters – edit as needed
# ----------------------------
# Tab-separated TF→target table; must contain "cluster" and "target" columns.
INPUT_TSV   = "/home/greenhamlab/snTC/20240411.GRNs.ParallelShuffled/all_adj/unshuffled_top95_all_clusters.tsv"
ORGANISM    = "athaliana"  # g:Profiler organism ID
TOP_N       = 5            # best terms taken per cluster before the inclusive union
ALPHA       = 0.01         # threshold passed to g:Profiler as user_threshold (FDR method)
# Derived from TOP_N so output file names stay truthful when TOP_N is edited.
OUT_PREFIX  = f"top{TOP_N}inclusive.95percentile_GO_py.bgCorrected"

# ----------------------------
# 1. Read and organize target sets
# ----------------------------
df = pd.read_csv(INPUT_TSV, sep="\t")

# Rows with an empty "target" cell would otherwise put NaN into the gene lists
# sent to g:Profiler and make sorted() below fail on a float/str comparison.
df_targets = df.dropna(subset=["target"])

# cluster -> list of distinct target genes (first-appearance order)
cluster_targets = (
    df_targets.groupby("cluster")["target"]
              .unique()
              .apply(list)
              .to_dict()
)

# Custom background: all observed target genes (across every cluster)
all_targets = sorted(set(df_targets["target"]))

# ----------------------------
# 2. Run g:Profiler enrichment with custom background
# ----------------------------
gp = GProfiler(return_dataframe=True)

# One multi-query request: each cluster's target list is tested against the
# custom background of all observed targets. domain_scope="custom" states
# explicitly that the supplied background is the statistical domain, rather
# than relying on the client's default handling of `background`.
gost = gp.profile(
    organism        = ORGANISM,
    query           = cluster_targets,
    sources         = ["GO:BP", "GO:MF"],
    significance_threshold_method = "fdr",
    user_threshold  = ALPHA,
    no_evidences    = False,
    domain_scope    = "custom",
    background      = all_targets,
)

# ----------------------------
# 3. Harmonize and check output
# ----------------------------
# Map alternative column spellings (e.g. R-style "p.value") onto the names
# used below; DataFrame.rename ignores keys that are not present.
rename_rules = {
    "name": "term_name", "term.name": "term_name", "native": "term_id", "term.id": "term_id",
    "p.value": "p_value", "overlap.size": "intersection_size", "query.number": "query"
}
gost = gost.rename(columns=rename_rules)

# An empty result would otherwise surface below as a misleading
# "missing columns / upgrade" error.
if gost.empty:
    raise RuntimeError(
        f"g:Profiler returned no GO:BP/GO:MF terms at threshold {ALPHA}; nothing to plot"
    )

# Every column read later in the script (filtering, ranking, gene ratio, plotting).
required = {"source", "term_name", "term_id", "p_value", "intersection_size", "term_size", "query"}
missing = required.difference(gost.columns)
if missing:
    raise RuntimeError(
        f"g:Profiler result missing expected columns: {', '.join(sorted(missing))}\n"
        "→ Please upgrade gprofiler-official"
    )

gost.to_csv(f"{OUT_PREFIX}_full_results.tsv", sep="\t", index=False)

# ----------------------------
# 4. Inclusive top-N terms across all clusters
# ----------------------------
def top_terms_inclusive(df: pd.DataFrame, namespace: str, top_n=None) -> pd.DataFrame:
    """Select the union of each cluster's top-N terms and return every row for them.

    A term that ranks in the top *top_n* (by ``p_value``) for *any* cluster is
    kept for *all* clusters where it appears, so the dot plot shows it across
    the board ("inclusive").

    Parameters
    ----------
    df : pd.DataFrame
        Harmonized g:Profiler output with ``source``, ``query``, ``term_id``,
        ``term_name``, ``p_value``, ``intersection_size`` and ``term_size``.
    namespace : str
        Value of ``source`` to keep, e.g. ``"GO:BP"``.
    top_n : int, optional
        Terms taken per cluster; defaults to the module-level ``TOP_N``.

    Returns
    -------
    pd.DataFrame
        A new frame (the input is not modified) with three extra or changed columns:
        ``log10FDR``
            ``-log10(p_value)``, with ``p_value`` clipped to the smallest
            positive float so it stays finite.
        ``gene_ratio``
            ``intersection_size / term_size``.
        ``term_name``
            Truncated to 60 characters.
        NOTE(review): the name log10FDR assumes g:Profiler's p_value is already
        FDR-adjusted under significance_threshold_method="fdr" — confirm.
    """
    if top_n is None:
        top_n = TOP_N

    df_ns = df[df["source"] == namespace]

    # Top N per cluster (smallest p-values first)
    top_per_cluster = (
        df_ns.sort_values(["query", "p_value"])
             .groupby("query")
             .head(top_n)
    )
    top_terms_set = set(top_per_cluster["term_id"])

    # All occurrences of those terms across every cluster
    inclusive = df_ns[df_ns["term_id"].isin(top_terms_set)].copy()
    # A p-value of exactly 0 would give +inf and break the colour normalisation.
    p_safe = inclusive["p_value"].clip(lower=np.finfo(float).tiny)
    inclusive["log10FDR"] = -np.log10(p_safe)
    # NOTE(review): clusterProfiler's "GeneRatio" is intersection/query size;
    # this is intersection/term size — confirm which the legend should describe.
    inclusive["gene_ratio"] = inclusive["intersection_size"] / inclusive["term_size"]
    inclusive["term_name"] = inclusive["term_name"].str.slice(0, 60)
    return inclusive

# Inclusive top-N tables, BP first then MF; each is then saved as a TSV.
bp_top, mf_top = (top_terms_inclusive(gost, namespace=ns) for ns in ("GO:BP", "GO:MF"))

bp_top.to_csv(f"{OUT_PREFIX}_BP_top.tsv", index=False, sep="\t")
mf_top.to_csv(f"{OUT_PREFIX}_MF_top.tsv", index=False, sep="\t")

# ----------------------------
# 5. Append abbreviated cluster labels
# ----------------------------
# Cluster ID -> abbreviated cell-type label appended to cluster names on the plot axis.
cluster_labels = {
    "c0": "Photo",
    "c1": "Photo",
    "c3": "Meso",
    "c4": "Meso",
    "c11": "Meso",
    "c2": "Meso",
    "c12_0": "Meso",
    "c13": "Trichome/Epi",
    "c14": "GC",
    "c8_0": "Div",
    "c9": "Div",
    "c12_1": "Div",
    "c5": "Vasc",
    "c6": "PCC",
    "c7": "Vasc",
    "c8_1": "Vasc",
    "c10": "Vasc",
    "c12_2": "PP",
    "c15": "Xylem",
}

def annotate_clusters(df, labels=None):
    """Return a copy of *df* whose ``query`` values carry their cell-type label.

    ``"c13"`` becomes ``"c13; Trichome/Epi"``. Clusters missing from the
    mapping are labelled ``"Unknown"``. The input frame is left unchanged.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a ``query`` column of cluster IDs.
    labels : dict, optional
        Cluster ID -> label mapping; defaults to the module-level ``cluster_labels``.
    """
    if labels is None:
        labels = cluster_labels
    out = df.copy()
    out["query"] = out["query"].map(lambda q: f"{q}; {labels.get(q, 'Unknown')}")
    return out

# Attach cell-type labels to the cluster names of both tables (BP, then MF).
bp_top, mf_top = map(annotate_clusters, (bp_top, mf_top))

# ----------------------------
# 6. Define final cluster order
# ----------------------------
# X-axis order of clusters, grouped here by their cell-type label.
ordered_clusters = [
    "c0", "c1",                         # Photo
    "c3", "c4", "c11", "c2", "c12_0",   # Meso
    "c13",                              # Trichome/Epi
    "c14",                              # GC
    "c8_0", "c9", "c12_1",              # Div
    "c5",                               # Vasc
    "c6",                               # PCC
    "c7", "c8_1", "c10",                # Vasc
    "c12_2",                            # PP
    "c15",                              # Xylem
]
# Same "<cluster>; <label>" strings that the annotated query column holds.
ordered_labels = [f"{cid}; {cluster_labels.get(cid, 'Unknown')}" for cid in ordered_clusters]

# ----------------------------
# 7. Dot plot function
# ----------------------------
def dotplot(df: pd.DataFrame, title: str, outfile: str, label_order=None):
    """Draw a cluster x GO-term dot plot and save it to *outfile*.

    Each row of *df* is one dot:
      * x = ``query`` (cluster)
      * y = ``term_name``
      * size = ``gene_ratio``
      * colour = ``log10FDR``

    Parameters
    ----------
    df : pd.DataFrame
        Needs ``query``, ``term_id``, ``term_name``, ``p_value``,
        ``gene_ratio`` and ``log10FDR`` columns.
    title : str
        Plot title.
    outfile : str
        Destination image path, written at 300 dpi.
    label_order : sequence of str, optional
        X-axis cluster order; defaults to the module-level ``ordered_labels``.
        Clusters present in *df* but not listed are appended (sorted) rather
        than silently dropped.
    """
    if label_order is None:
        label_order = ordered_labels

    # Y order: terms ranked by their best p-value in any cluster.
    # term_name is truncated upstream, so distinct term IDs can collide on one
    # label; pandas rejects duplicate categories, hence drop_duplicates().
    term_order = (
        df.groupby(["term_id", "term_name"])["p_value"]
          .min().sort_values().index
          .get_level_values("term_name")
          .drop_duplicates()
    )

    # X order: a query label missing from the categories would become NaN and
    # vanish from the plot, so append any clusters not listed in label_order.
    listed = set(label_order)
    x_order = list(label_order) + sorted(
        q for q in df["query"].unique() if q not in listed
    )

    plot_df = df.assign(
        query=pd.Categorical(df["query"], categories=x_order, ordered=True),
        term_name=pd.Categorical(df["term_name"], categories=term_order, ordered=True),
    )

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        sns.scatterplot(
            data=plot_df,
            x="query",
            y="term_name",
            size="gene_ratio", sizes=(20, 300),
            hue="log10FDR", palette="viridis_r",
            edgecolor="black", linewidth=0.3,
            ax=ax,
        )
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("Cluster", fontsize=14)
        ax.set_ylabel("GO term", fontsize=14)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=12)
        plt.setp(ax.get_yticklabels(), rotation=30, ha="right", fontsize=12)
        ax.legend(title="• Dot size = GeneRatio\n• Dot color = –log10(FDR)",
                  bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.)
        fig.tight_layout()
        fig.savefig(outfile, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails (matters in notebooks).
        plt.close(fig)

# ----------------------------
# 8. Generate plots
# ----------------------------
# Plot paths are defined once so the report below lists the files actually written.
bp_png = f"FIG4D{OUT_PREFIX}_BP_dotplot.png"
mf_png = f"FIG4D{OUT_PREFIX}_MF_dotplot.png"

dotplot(
    bp_top,
    f"GO Biological Process enrichment (top {TOP_N} in any cluster, "
    "inclusive; custom background)",
    bp_png,
)
dotplot(
    mf_top,
    f"GO Molecular Function enrichment (top {TOP_N} in any cluster, "
    "inclusive; custom background)",
    mf_png,
)

print("✓ Finished.")
for written in (
    bp_png,
    mf_png,
    f"{OUT_PREFIX}_BP_top.tsv",
    f"{OUT_PREFIX}_MF_top.tsv",
    f"{OUT_PREFIX}_full_results.tsv",
):
    print("  •", written)


✓ Finished.
  • top5inclusive.95percentile_GO_py.bgCorrected_BP_dotplot.png
  • top5inclusive.95percentile_GO_py.bgCorrected_MF_dotplot.png
  • top5inclusive.95percentile_GO_py.bgCorrected_BP_top.tsv
  • top5inclusive.95percentile_GO_py.bgCorrected_MF_top.tsv
  • top5inclusive.95percentile_GO_py.bgCorrected_full_results.tsv
