# Import

In [None]:
import numpy as np
import json
from scipy.sparse import load_npz,save_npz,diags,csr_matrix
import scipy.sparse as sp
import pandas as pd
import os
import requests
from io import BytesIO
from tqdm import tqdm
from scipy.sparse.linalg import eigsh
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.backends.backend_pdf import PdfPages
from pypdf import PdfReader, PdfWriter
from tempfile import NamedTemporaryFile
import networkx as nx
import pickle
import gseapy as gp
import mygene
from IPython.display import display, HTML
import re
from collections import deque
from goatools.obo_parser import GODag
import math
from itertools import combinations
from collections import Counter
from gseapy.parser import read_gmt
import time
import random

In [None]:
# Show wide DataFrames in full: None disables line wrapping ('display.width')
# and the column cap ('display.max_columns').
for _display_option in ('display.width', 'display.max_columns'):
    pd.set_option(_display_option, None)

# Prep

## Loading variables

In [None]:
# ---- Run configuration: disease name and the folders/files derived from it ----
DISEASE = "BIPOLAR"
DISEASE_FOLDER = f"../output/{DISEASE}/"
RESULT_FOLDER = DISEASE_FOLDER + "leiden_results"
DGIDB_DIRECTORY = f"../../Gen_Hypergraph/output/DGIDB_{DISEASE}/"
MSIGDB_DIRECTORY = "../../Gen_Hypergraph/output/MSigDB_Full/"
RESULT_COMMUNITIES_SELECTED = "result_communities_agg"
RESULT_GRAPH = "result_graph_agg"

# Gene -> index mapping for this disease; the file is required.
with open(DISEASE_FOLDER + "gene_to_index_distinct.json", "r") as fh:
    gene_to_index_distinct = json.loads(fh.read())

# The DGIdb mapping is optional: a missing file yields an empty mapping.
try:
    fh = open(DGIDB_DIRECTORY + f"gene_to_index_{DISEASE}.json", "r")
except FileNotFoundError:
    DGIDB_gene_to_index = {}
    print("File not found. Setting DGIDB_gene_to_index to be {}.")
else:
    with fh:
        DGIDB_gene_to_index = json.load(fh)

# Sparse matrix stored next to the mappings.
sim_mat = load_npz(f"{DISEASE_FOLDER}/agg_sim_mat.npz")

In [None]:
# Reverse lookup: index -> gene identifier.
index_to_gene_distinct = {idx: gene for gene, idx in gene_to_index_distinct.items()}

In [None]:
# Load the selected communities and the result graph from RESULT_FOLDER.
# The communities file name comes from RESULT_COMMUNITIES_SELECTED; the name
# RESULT_COMMUNITIES is not defined anywhere in this notebook.
# NOTE(review): pickle.load can execute arbitrary code -- only load files
# produced by this pipeline.
with open(f"{RESULT_FOLDER}/{RESULT_COMMUNITIES_SELECTED}.pkl", "rb") as f:
    communities_selected = pickle.load(f)
with open(f"{RESULT_FOLDER}/{RESULT_GRAPH}.pkl", "rb") as f:
    graph = pickle.load(f)

## Helpful functions (big object, drop NAN)

In [None]:
# Helpful functions
def drop_nan_from_communities(communities):
    """Return a copy of *communities* without missing members.

    A member counts as missing when it is ``None`` or a float NaN. The number
    of removed entries is printed per community and in total. The input
    communities are left untouched.
    """
    def is_missing(entry):
        return entry is None or (isinstance(entry, float) and math.isnan(entry))

    cleaned_communities = []
    drops_per_community = []

    for idx, members in enumerate(communities):
        members = list(members)
        kept = [m for m in members if not is_missing(m)]
        removed = len(members) - len(kept)

        cleaned_communities.append(kept)
        drops_per_community.append(removed)
        print(f"Community {idx}: dropped {removed} NaN entries")

    print(f"\nTotal dropped across all communities: {sum(drops_per_community)}")
    return cleaned_communities

def big_objects(n=10, min_mb=1):
    """
    Show the largest objects currently in memory.

    Looks at the IPython user namespace when running under IPython and at
    this module's globals otherwise (including when IPython is not installed).
    Names starting with an underscore are ignored.

    Parameters
    ----------
    n : int
        Number of top objects to show.
    min_mb : float
        Minimum size (in MB) to include.
    """
    import sys
    import numpy as np
    import pandas as pd
    import scipy.sparse as sp

    # Array attributes that hold a sparse matrix's storage, depending on format.
    sparse_array_attrs = ("data", "indices", "indptr", "row", "col", "offsets")

    def get_size(obj):
        # Best-effort byte count; objects that cannot be measured count as 0.
        try:
            if isinstance(obj, np.ndarray):
                return obj.nbytes
            if isinstance(obj, pd.DataFrame):
                # DataFrame.memory_usage is per column -> add the columns up.
                return int(obj.memory_usage(deep=True).sum())
            if isinstance(obj, pd.Series):
                # Series.memory_usage is already a single number of bytes.
                return int(obj.memory_usage(deep=True))
            if sp.issparse(obj):
                parts = (getattr(obj, attr, None) for attr in sparse_array_attrs)
                sizes = [part.nbytes for part in parts if isinstance(part, np.ndarray)]
                if sizes:
                    return sum(sizes)
            return sys.getsizeof(obj)
        except Exception:
            return 0

    try:
        from IPython import get_ipython
    except ImportError:
        ip = None
    else:
        ip = get_ipython()
    ns = globals() if ip is None else ip.user_ns

    threshold = min_mb * 1024 ** 2
    items = []
    for name, val in list(ns.items()):
        if name.startswith('_'):
            continue  # skip internals
        size = get_size(val)
        if size > threshold:
            items.append((name, type(val).__name__, size))

    items.sort(key=lambda x: x[2], reverse=True)

    print(f"{'Variable':30s} {'Type':25s} {'Size (MB)':>10s}")
    print("-" * 70)
    for name, t, size in items[:n]:
        print(f"{name:30s} {t:25s} {size / 1024 ** 2:10.2f}")

## Index to HGNC

In [None]:
# Convert index to ncbi
def index_to_ncbi(comms,index_to_ncbi = index_to_gene_distinct):
    """Translate every member of every community through a lookup dict.

    ``comms`` is an iterable of communities; each member is looked up with
    ``dict.get``, so members missing from the mapping become ``None``.
    The mapping parameter shadows this function's own name inside the body.
    """
    return [[index_to_ncbi.get(member) for member in community] for community in comms]

In [None]:
# Translate the selected communities from indices to gene identifiers, then
# print the converted communities and how many there are.
communities_ncbi = index_to_ncbi(communities_selected,index_to_gene_distinct)
print(communities_ncbi)
print(len(communities_ncbi))

In [None]:
# NCBI to HGNC symbol
def ncbi_to_HGNC(comms_ncbi):
    """Translate communities of Entrez gene IDs into gene symbols via MyGene.

    Parameters
    ----------
    comms_ncbi : iterable of iterables
        Communities whose members are Entrez gene IDs; each member is passed
        through ``str()`` before being queried.

    Returns
    -------
    list[list[str | None]]
        One list per community, in input order and with the same length as
        the community. IDs reported as not found, or returned without a
        symbol, map to ``None``.
    """
    # A single client is enough for all communities.
    mg = mygene.MyGeneInfo()
    comms_HGNC = []

    for community in comms_ncbi:
        entrez_ids = [str(e) for e in community]
        if not entrez_ids:
            # Nothing to look up -- avoid a pointless remote query.
            comms_HGNC.append([])
            continue

        results = mg.querymany(
            entrez_ids,
            scopes="entrezgene",
            fields="symbol",
            species="human"
        )

        # Build a mapping: input ID -> symbol (or None). If the service
        # returns several hits for one query, the last one wins.
        id_to_symbol = {}
        for r in results:
            q = str(r.get("query"))
            id_to_symbol[q] = r.get("symbol") if not r.get("notfound") else None

        # Preserve original order
        comms_HGNC.append([id_to_symbol.get(e) for e in entrez_ids])

    return comms_HGNC


In [None]:
# Resolve every community's gene IDs to gene symbols (remote MyGene lookups).
COMMUNITIES_HGNC = ncbi_to_HGNC(communities_ncbi)

In [None]:
# Inspect the symbol lists before unmapped entries are removed.
print(COMMUNITIES_HGNC)

In [None]:
# Remove members that have no symbol (None) or are NaN; per-community drop
# counts are printed by the helper.
COMMUNITIES_HGNC = drop_nan_from_communities(COMMUNITIES_HGNC)

In [None]:
# Number of communities after cleaning.
num_selected_comm = len(COMMUNITIES_HGNC)

In [None]:
# Show the community count.
print(num_selected_comm)

# Categorization Prep

### GO-slim

In [None]:
# Locations of the GO ontology file and the GO-slim subsets, all under
# DATA_DIRECTORY/GO. Adjust DATA_DIRECTORY if the files live elsewhere.
DATA_DIRECTORY = "../../data"
GO_OBO = f"{DATA_DIRECTORY}/GO/go-basic.obo"            # ontology loaded below as `go`
GOSLIM_OBO = f"{DATA_DIRECTORY}/GO/goslim_generic.obo"  # generic slim
GOSLIM_PIR_OBO = f"{DATA_DIRECTORY}/GO/goslim_pir.obo"  # PIR slim
GOSLIM_YEAST_OBO = f"{DATA_DIRECTORY}/GO/goslim_yeast.obo"  # yeast slim
GOSLIM_AGR_OBO = f"{DATA_DIRECTORY}/GO/goslim_agr.obo"  # AGR slim

In [None]:
# Ontology DAG built from GO_OBO
go = GODag(GO_OBO)

# One DAG per GO-slim file
slim = GODag(GOSLIM_OBO)
slim_pir = GODag(GOSLIM_PIR_OBO)
slim_yeast = GODag(GOSLIM_YEAST_OBO)
slim_agr = GODag(GOSLIM_AGR_OBO)

# Key sets of the slim DAGs, used for membership tests
slim_ids, slim_pir_ids, slim_yeast_ids, slim_agr_ids = (
    set(dag.keys()) for dag in (slim, slim_pir, slim_yeast, slim_agr)
)

In [None]:
# A GO accession: "GO:" followed by exactly seven digits.
GO_RE = re.compile(r"(GO:\d{7})")

def get_goid(term: str):
    """Return the first GO accession contained in *term*.

    Raises
    ------
    RuntimeError
        If *term* is not a string or contains no GO accession.
    """
    match = GO_RE.search(term) if isinstance(term, str) else None
    if match is None:
        raise RuntimeError("Term not found!!")
    return match.group(1)

def get_go_ancestors(go_id):
    """Return the set of ancestor GO term IDs for the given GO ID using QuickGO.

    Always returns a ``set`` (empty when the service reports no results), so
    callers can use set operators on the result without checking its type.

    Raises
    ------
    requests.HTTPError
        If QuickGO answers with an error status.
    requests.RequestException
        On connection problems or when the request times out.
    """
    url = f"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/{go_id}/ancestors"
    headers = {"Accept": "application/json"}

    # The timeout keeps a stalled connection from blocking the notebook forever.
    r = requests.get(url, headers=headers, timeout=30)
    r.raise_for_status()

    results = r.json().get("results", [])
    if not results:
        return set()

    # Ancestors come back as a simple list of GO IDs (strings)
    return set(results[0].get("ancestors") or [])


def get_go_ancestors_in_slim(go_id):
    """Return the ancestors of *go_id* that are members of ``slim_ids``.

    The result is always a set; it is empty when the term has no ancestors
    in the slim.
    """
    # Coerce to a set first: the `&` operator rejects a non-set operand, and
    # the lookup is not guaranteed to hand back a set in every case.
    return slim_ids & set(get_go_ancestors(go_id))

In [None]:
def get_go_ancestors_at_depth(go_id, depth, include_relations=("is_a", "part_of"), dag=None):
    """
    Return the set of GO term IDs that are ancestors of `go_id` and have
    absolute depth == `depth` in the GO DAG.

    Parameters
    ----------
    go_id : str
        Starting GO term (e.g., "GO:0051310").
    depth : int
        Absolute depth in the GO DAG (e.g., 3 means all ancestors at depth=3).
    include_relations : tuple[str]
        Relation types to traverse upward, e.g. ("is_a", "part_of", "regulates", ...).
        Relations other than "is_a" are only followed when the terms carry a
        ``relationship`` attribute.
    dag : mapping, optional
        Mapping of GO ID -> term object exposing ``id``, ``parents`` and
        ``depth``. Defaults to the module-level ``go`` DAG.

    Returns
    -------
    set[str]
        Ancestor GO IDs whose term.depth == `depth`. Empty set if none, if
        `depth` is negative, or if `go_id` is not in the DAG. The start term
        itself is never included.
    """
    if dag is None:
        dag = go
    if depth < 0 or go_id not in dag:
        return set()

    # One-hop function honoring relation filter
    def parent_ids(term):
        ids = set()
        if "is_a" in include_relations:
            ids.update(p.id for p in term.parents)

        rel = getattr(term, "relationship", {}) or {}
        for r in include_relations:
            # Relationship targets may be term objects or plain ID strings;
            # normalise both to IDs so they survive the membership filter below.
            ids.update(getattr(target, "id", target) for target in rel.get(r, ()))

        # ensure IDs exist in DAG
        return {pid for pid in ids if pid in dag}

    result = set()
    frontier = {go_id}
    visited = {go_id}

    # BFS upwards, pruning branches that are already above the target depth
    while frontier:
        next_frontier = set()
        for node in frontier:
            for pid in parent_ids(dag[node]):
                if pid in visited:
                    continue
                visited.add(pid)
                d = dag[pid].depth  # absolute depth in DAG

                if d == depth:
                    # ancestor at the exact target depth
                    result.add(pid)
                elif d > depth:
                    # still "below" target depth (further from root),
                    # its parents might reach the target depth
                    next_frontier.add(pid)
                # d < depth: the branch is above the target and is dropped.
                # NOTE(review): this assumes depth never increases towards
                # the root, which may not hold along non-is_a edges -- confirm.
        frontier = next_frontier

    return result


### KEGG

In [None]:
def build_kegg_name_to_id(species="hsa"):
    """Map lower-cased KEGG pathway name -> pathway ID such as 'hsa03010'.

    The list is fetched from the KEGG REST service for *species*. A trailing
    " - Homo sapiens ..." suffix is stripped from each name; lines that are
    not of the form "<id><TAB><name>" are skipped.

    Raises
    ------
    requests.HTTPError
        If KEGG answers with an error status.
    requests.RequestException
        On connection problems or when the request times out.
    """
    response = requests.get(f"https://rest.kegg.jp/list/pathway/{species}", timeout=30)
    # Fail loudly instead of trying to parse an error page as a pathway list.
    response.raise_for_status()

    name_to_id = {}
    for ln in response.text.strip().splitlines():
        pid, tab, raw = ln.partition("\t")
        if not tab:
            continue  # not an "<id>\t<name>" record
        pid = pid.replace("path:", "")  # e.g. hsa03010
        # strip " - Homo sapiens (human)" suffix
        name = re.sub(r"\s*-\s*Homo sapiens.*$", "", raw).strip()
        name_to_id[name.lower()] = pid
    return name_to_id

# Lookup table used to turn enrichment term names into KEGG pathway IDs.
name_to_id = build_kegg_name_to_id("hsa")

In [None]:
def get_kegg_level2(hsa_id: str) -> list[str] | None:
    """
    Return the KEGG Level 2 category for a pathway like 'hsa03040'.
    Example: get_kegg_level2("hsa03040") -> ['Transcription']

    Returns
    -------
    list[str] | None
        - a one-element list holding the second ';'-separated field of the
          entry's first CLASS line (or the whole CLASS text when the line
          has no ';');
        - ``[]`` when the entry has no CLASS line or *hsa_id* is not a
          non-empty string (e.g. NaN from an unmapped term name);
        - ``None`` when the lookup itself failed (network error or HTTP
          error status).
    """
    if not isinstance(hsa_id, str) or not hsa_id:
        # Unmapped names arrive here as NaN; there is nothing to fetch.
        return []

    url = f"https://rest.kegg.jp/get/{hsa_id}"
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if not response.ok:
        return None

    for line in response.text.splitlines():
        if line.startswith("CLASS"):
            # CLASS line looks like: CLASS       Genetic Information Processing; Transcription
            parts = [p.strip() for p in line[len("CLASS"):].split(";", maxsplit=2)]
            if len(parts) >= 2:
                return [parts[1]]
            return [parts[0]]
    return []

### Reactome

In [None]:
def build_reactome_level_map(level=1, species="9606"):
    """
    Build { 'R-HSA-xxxxx': ['CategoryNameAtLevel', ...], ... } from the
    Reactome events hierarchy of the given species.

    Parameters
    ----------
    level : int, default=1
        1-based depth in the Reactome pathway hierarchy:
          - level=1 -> name of the top-level pathway the node sits under
          - level=2 -> name of the second-level ancestor, etc.
        For a node whose path from the root is shorter than `level`, the
        node's own name is used instead.
    species : str, default="9606"
        Taxonomy ID ("9606") or species name ("Homo sapiens"); spaces are
        replaced by '+' in the request URL.

    Returns
    -------
    dict[str, list[str]]
        Stable ID -> sorted category names. A stable ID reachable via several
        paths can have more than one category.

    Raises
    ------
    ValueError
        If `level` is smaller than 1.
    """
    if level < 1:
        raise ValueError("level must be >= 1 (1-based depth)")

    species_path = species.replace(" ", "+")
    response = requests.get(
        f"https://reactome.org/ContentService/data/eventsHierarchy/{species_path}",
        headers={"Accept": "application/json"},
        timeout=60,
    )
    response.raise_for_status()
    trees = response.json()  # one tree per top-level pathway

    categories = {}

    # Depth-first walk with an explicit stack. Entries are pushed in reverse
    # so nodes are visited in the same pre-order as a recursive walk.
    pending = [(top, []) for top in reversed(trees)]
    while pending:
        node, ancestors = pending.pop()
        chain = ancestors + [node]  # root ... current node

        st_id = node.get("stId")
        if st_id:
            # Ancestor at the requested depth, or the node itself when the
            # chain is too short.
            cat_node = chain[level - 1] if len(chain) >= level else chain[-1]
            cat_name = cat_node.get("name")
            if cat_name:
                categories.setdefault(st_id, set()).add(cat_name)

        pending.extend((child, chain) for child in reversed(node.get("children", [])))

    # sets -> sorted lists
    return {st_id: sorted(names) for st_id, names in categories.items()}

# Category lookup used by the Reactome enrichment step.
# NOTE: despite the variable name, this map is built with level=2.
reactome_level1 = build_reactome_level_map(level = 2)
# Each value is a list of category names, e.g. ['Signal Transduction'].

In [None]:
# Spot-check the categories assigned to one pathway.
print(reactome_level1["R-HSA-9007101"])

# Run Enrichment Analysis

In [None]:
# Thresholds shared by the enrichment steps: a term is kept when its
# "Adjusted P-value" is below TERM_SCORE_CAP and the fraction encoded in its
# "Overlap" string (numerator / denominator) is above PERCENTAGE.
TERM_SCORE_CAP = 0.001
PERCENTAGE = 0.1

### GO

In [None]:
# GO Analysis; save terms with small size and high p-value
def go_enrichment(communities,
                  term_score_cap,
                  percentage,
                  slim_ids = slim_yeast_ids,
                  depth = 1):
    """Run Enrichr GO enrichment per community and categorise the kept terms.

    For every community the three GO 2023 libraries are queried. A term is
    kept when its adjusted p-value is below `term_score_cap` and its overlap
    fraction is above `percentage`. Each kept term is assigned the names of
    its GO ancestors at absolute depth `depth` as categories.

    Parameters
    ----------
    communities : sequence of sequences of str
        Gene symbols per community. Empty communities are skipped without
        contacting Enrichr.
    term_score_cap : float
        Upper bound (exclusive) on "Adjusted P-value".
    percentage : float
        Lower bound (exclusive) on the overlap fraction.
    slim_ids : set
        Accepted for interface compatibility; not used by the current
        depth-based categorisation.
    depth : int
        Absolute GO depth used for the category ancestors.

    Returns
    -------
    (pandas.DataFrame, dict)
        All kept terms of all communities (with "Community Index" and
        "Community Size" columns), and a dict mapping every community index
        to {category: (number of terms, pooled overlap)}, where the pooled
        overlap is sum(numerators) / sum(denominators) of the "Overlap"
        strings in that category.
    """
    base_columns = ["Community Index","Community Size","Term", "Overlap", "Adjusted P-value","Category"]

    def overlap_parts(overlap):
        # "k/n" -> (k, n)
        hits, size = overlap.split("/")[:2]
        return int(hits), int(size)

    def overlap_fraction(overlap):
        hits, size = overlap_parts(overlap)
        return hits / size

    # Frames are collected and concatenated once at the end; concatenating
    # onto a growing frame inside the loop copies all earlier rows each time.
    frames = []
    category_counts_and_overlap_score_list = {}
    num_nonzero_communities = 0

    for i, community in enumerate(communities):
        if len(community) == 0:
            # No genes -> nothing to enrich; still record the community.
            category_counts_and_overlap_score_list[i] = {}
            continue

        # Gene Ontology enrichment
        enr_go = gp.enrichr(
            gene_list=community,
            gene_sets=['GO_Biological_Process_2023',
                    'GO_Molecular_Function_2023',
                    'GO_Cellular_Component_2023'],
            organism='Human',
            outdir=None # don't write to disk
        )
        go_df = enr_go.results

        # Filter by overlap percentage and adjusted p-value
        mask = (go_df["Adjusted P-value"] < term_score_cap) & (go_df["Overlap"].apply(overlap_fraction) > percentage)
        filtered = go_df[mask].copy()

        # Categorization: GO ancestors at the requested depth
        filtered["GO_ID"] = filtered["Term"].apply(get_goid)
        filtered["Slim_IDs"] = filtered["GO_ID"].apply(
            lambda goid: get_go_ancestors_at_depth(goid, depth=depth, include_relations=("is_a", "part_of"))
        )

        # Terms without any ancestor at that depth
        empty_count = (filtered["Slim_IDs"].apply(len) == 0).sum()

        # Category names, sorted so the output is reproducible between runs
        filtered["Category"] = filtered["Slim_IDs"].apply(lambda ids: sorted(go[t].name for t in ids))

        # Sort
        filtered['Overlap (value)'] = filtered['Overlap'].apply(overlap_fraction)
        filtered = filtered.sort_values(['Overlap (value)'], ascending=False)

        # Compute overlap score for every category:
        filtered_exploded = filtered.explode('Category').reset_index(drop=True)
        category_counts_and_overlap_score = {}
        for val, group in filtered_exploded.groupby('Category'):
            hits, sizes = zip(*(overlap_parts(o) for o in group["Overlap"]))
            category_counts_and_overlap_score[val] = (len(group), sum(hits) / sum(sizes),)

        category_counts_and_overlap_score_list[i] = category_counts_and_overlap_score

        # Add results to important terms
        if not filtered.empty:
            print(f"Size of community: {len(community)}")
            print(f"Number of filtered terms: {len(filtered)}")
            print(f"Number of unmapped terms: {empty_count}")
            print(category_counts_and_overlap_score)
            filtered.loc[:, "Community Index"] = i
            filtered.loc[:, "Community Size"] = len(community)
            frames.append(filtered)
            display(HTML(filtered[["Community Index",'Term','Overlap','Adjusted P-value',"Slim_IDs","Category"]].head(10).to_html(max_cols=None)))
            num_nonzero_communities += 1

    print(f"{num_nonzero_communities} out of {len(communities)} communities had significant GO terms.")

    if frames:
        important_terms = pd.concat(frames, ignore_index=True)
        # Keep the summary columns first, followed by everything else.
        extra_columns = [c for c in important_terms.columns if c not in base_columns]
        important_terms = important_terms[base_columns + extra_columns]
    else:
        important_terms = pd.DataFrame(columns=base_columns)
    return important_terms,category_counts_and_overlap_score_list

In [None]:
# GO enrichment for every community; categories are the depth-2 GO ancestors.
# NOTE(review): slim_ids is passed along but go_enrichment does not use it.
go_important_terms,go_category_counts_and_overlap_score = go_enrichment(COMMUNITIES_HGNC,TERM_SCORE_CAP,PERCENTAGE,slim_ids,depth = 2)

In [None]:
# Display the kept GO terms.
go_important_terms

### KEGG

In [None]:
# KEGG
def kegg_enrichment(communities,
                    term_score_cap,
                    percentage):
    """Run Enrichr KEGG enrichment per community and categorise the kept terms.

    A term is kept when its adjusted p-value is below `term_score_cap` and
    its overlap fraction is above `percentage`. Term names are mapped to KEGG
    pathway IDs through `name_to_id`, and each ID is categorised with
    `get_kegg_level2`.

    Parameters
    ----------
    communities : sequence of sequences of str
        Gene symbols per community. Empty communities are skipped without
        contacting Enrichr.
    term_score_cap : float
        Upper bound (exclusive) on "Adjusted P-value".
    percentage : float
        Lower bound (exclusive) on the overlap fraction.

    Returns
    -------
    (pandas.DataFrame, dict)
        All kept terms of all communities (with "Community Index" and
        "Community Size" columns), and a dict mapping every community index
        to {category: (number of terms, pooled overlap)}, where the pooled
        overlap is sum(numerators) / sum(denominators) of the "Overlap"
        strings in that category.
    """
    base_columns = ["Community Index","Community Size","Term", "Overlap", "Adjusted P-value","Category"]

    def overlap_parts(overlap):
        # "k/n" -> (k, n)
        hits, size = overlap.split("/")[:2]
        return int(hits), int(size)

    def overlap_fraction(overlap):
        hits, size = overlap_parts(overlap)
        return hits / size

    # Successful category lookups are remembered for the duration of this
    # call, so a pathway shared by several communities is fetched only once.
    level2_cache = {}

    def level2_for(kegg_id):
        if not isinstance(kegg_id, str):
            return []  # term name without a KEGG ID (NaN): nothing to look up
        categories = level2_cache.get(kegg_id)
        if categories is None:
            categories = get_kegg_level2(kegg_id)
            if categories is not None:  # failed lookups are retried next time
                level2_cache[kegg_id] = categories
        return categories

    # Frames are collected and concatenated once at the end; concatenating
    # onto a growing frame inside the loop copies all earlier rows each time.
    frames = []
    category_counts_and_overlap_score_list = {}
    num_nonzero_communities = 0

    for i, community in enumerate(communities):
        if len(community) == 0:
            # No genes -> nothing to enrich; still record the community.
            category_counts_and_overlap_score_list[i] = {}
            continue

        enr_path = gp.enrichr(
            gene_list=community,
            gene_sets=['KEGG_2021_Human'],
            organism='Human',
            outdir=None
        )
        KEGG_df = enr_path.results

        # Filter by overlap percentage and adjusted p-value
        mask = (KEGG_df["Adjusted P-value"] < term_score_cap) & (KEGG_df["Overlap"].apply(overlap_fraction) > percentage)
        filtered = KEGG_df[mask].copy()

        # Categorization from KEGG Level 2
        filtered["KEGG_ID"] = filtered["Term"].str.replace(r"\s*-\s*Homo sapiens.*$", "", regex=True).str.lower().map(name_to_id)
        filtered["Category"] = filtered["KEGG_ID"].map(level2_for)

        # Sort
        filtered['Overlap (value)'] = filtered['Overlap'].apply(overlap_fraction)
        filtered = filtered.sort_values(['Overlap (value)'], ascending=False)

        # Compute overlap score for every category:
        filtered_exploded = filtered.explode('Category').reset_index(drop=True)
        category_counts_and_overlap_score = {}
        for val, group in filtered_exploded.groupby('Category'):
            hits, sizes = zip(*(overlap_parts(o) for o in group["Overlap"]))
            category_counts_and_overlap_score[val] = (len(group), sum(hits) / sum(sizes),)

        category_counts_and_overlap_score_list[i] = category_counts_and_overlap_score

        # Add results to important terms
        if not filtered.empty:
            print(f"Size of community: {len(community)}")
            print(f"Number of filtered terms: {len(filtered)}")
            filtered.loc[:, "Community Index"] = i
            filtered.loc[:, "Community Size"] = len(community)
            frames.append(filtered)

            # show results
            display(HTML(filtered[["Community Index",'Term','Overlap','Adjusted P-value',"KEGG_ID","Category"]].head(10).to_html(max_cols=None)))
            print(category_counts_and_overlap_score)
            num_nonzero_communities += 1

    print(f"{num_nonzero_communities} out of {len(communities)} communities had significant KEGG terms.")

    if frames:
        important_terms = pd.concat(frames, ignore_index=True)
        # Keep the summary columns first, followed by everything else.
        extra_columns = [c for c in important_terms.columns if c not in base_columns]
        important_terms = important_terms[base_columns + extra_columns]
    else:
        important_terms = pd.DataFrame(columns=base_columns)
    return important_terms,category_counts_and_overlap_score_list

In [None]:
# KEGG enrichment for every community with the shared thresholds.
kegg_important_terms,kegg_category_counts_and_overlap_score = kegg_enrichment(COMMUNITIES_HGNC,TERM_SCORE_CAP,PERCENTAGE)

In [None]:
# Display the kept KEGG terms.
kegg_important_terms

### Reactome

In [None]:
# Reactome enrichment
def reactome_enrichment(communities,
                        term_score_cap,
                        percentage):
    """Run Enrichr Reactome enrichment per community and categorise the kept terms.

    A term is kept when its adjusted p-value is below `term_score_cap` and
    its overlap fraction is above `percentage`. The Reactome stable ID
    ("R-XXX-123...") is extracted from the term name and looked up in
    `reactome_level1` to obtain the categories.

    Parameters
    ----------
    communities : sequence of sequences of str
        Gene symbols per community. Empty communities are skipped without
        contacting Enrichr.
    term_score_cap : float
        Upper bound (exclusive) on "Adjusted P-value".
    percentage : float
        Lower bound (exclusive) on the overlap fraction.

    Returns
    -------
    (pandas.DataFrame, dict)
        All kept terms of all communities (with "Community Index" and
        "Community Size" columns), and a dict mapping every community index
        to {category: (number of terms, pooled overlap)}, where the pooled
        overlap is sum(numerators) / sum(denominators) of the "Overlap"
        strings in that category.
    """
    base_columns = ["Community Index","Community Size","Term", "Overlap", "Adjusted P-value","Category"]

    def overlap_parts(overlap):
        # "k/n" -> (k, n)
        hits, size = overlap.split("/")[:2]
        return int(hits), int(size)

    def overlap_fraction(overlap):
        hits, size = overlap_parts(overlap)
        return hits / size

    # Frames are collected and concatenated once at the end; concatenating
    # onto a growing frame inside the loop copies all earlier rows each time.
    frames = []
    category_counts_and_overlap_score_list = {}
    num_nonzero_communities = 0

    for i, community in enumerate(communities):
        if len(community) == 0:
            # No genes -> nothing to enrich; still record the community.
            category_counts_and_overlap_score_list[i] = {}
            continue

        enr_path = gp.enrichr(
            gene_list=community,
            gene_sets=['Reactome_2022'],
            organism='Human',
            outdir=None
        )
        Reactome_df = enr_path.results

        # Filter by overlap percentage and adjusted p-value
        mask = (Reactome_df["Adjusted P-value"] < term_score_cap) & (Reactome_df["Overlap"].apply(overlap_fraction) > percentage)
        filtered = Reactome_df[mask].copy()

        # Categorization via the Reactome hierarchy map
        filtered["Category"] = filtered["Term"].str.extract(r"(R-[A-Z]+-\d+)", expand=False).map(reactome_level1)

        # Sort
        filtered['Overlap (value)'] = filtered['Overlap'].apply(overlap_fraction)
        filtered = filtered.sort_values(['Overlap (value)'], ascending=False)

        # Compute overlap score for every category:
        filtered_exploded = filtered.explode('Category').reset_index(drop=True)
        category_counts_and_overlap_score = {}
        for val, group in filtered_exploded.groupby('Category'):
            hits, sizes = zip(*(overlap_parts(o) for o in group["Overlap"]))
            category_counts_and_overlap_score[val] = (len(group), sum(hits) / sum(sizes),)

        category_counts_and_overlap_score_list[i] = category_counts_and_overlap_score

        # Add results to important terms
        if not filtered.empty:
            print(f"Size of community: {len(community)}")
            print(f"Number of filtered terms: {len(filtered)}")
            filtered.loc[:, "Community Index"] = i
            filtered.loc[:, "Community Size"] = len(community)
            frames.append(filtered)
            display(HTML(filtered[["Community Index",'Term','Overlap','Adjusted P-value',"Category"]].head(30).to_html(max_cols=None)))
            print(category_counts_and_overlap_score)
            num_nonzero_communities += 1

    print(f"{num_nonzero_communities} out of {len(communities)} communities had significant Reactome terms.")

    if frames:
        important_terms = pd.concat(frames, ignore_index=True)
        # Keep the summary columns first, followed by everything else.
        extra_columns = [c for c in important_terms.columns if c not in base_columns]
        important_terms = important_terms[base_columns + extra_columns]
    else:
        important_terms = pd.DataFrame(columns=base_columns)
    return important_terms,category_counts_and_overlap_score_list

In [None]:
# Reactome enrichment for every community with the shared thresholds.
reactome_important_terms,reactome_category_counts_and_overlap_score = reactome_enrichment(COMMUNITIES_HGNC,TERM_SCORE_CAP,PERCENTAGE)

In [None]:
# Display the kept Reactome terms.
reactome_important_terms

### Disease Data Sets

In [None]:
# disease_term_score_cap = 0.001
# disease_percentage = 0.1
# important_diseases = pd.DataFrame(columns=["Community Index","Community Size","Term", "Overlap", "Adjusted P-value"])

In [None]:
# # Disease-gene enrichment libraries
# disease_sets = [
#     'DisGeNET_2020', # curated gene–disease associations
#     'GWAS_Catalog_2023', # genome-wide association hits
#     'OMIM_Disease', # Mendelian disorders
#     'Jensen_DISEASES' # text-mined associations
# ]

# # # Disease-gene enrichment Analysis; save terms with small size and high p-value
# i = 0
# for community in communities_HGNC:
#     # Gene Ontology enrichment
#     enr_disease = gp.enrichr(
#         gene_list=community,
#         gene_sets=disease_sets,
#         organism='Human',
#         outdir=None # don't write to disk
#     )
#     enr_disease_df = enr_disease.results.sort_values('Adjusted P-value')
#     print(f"Size of community: {len(community)}")

#     mask =  (enr_disease_df["Adjusted P-value"] < disease_term_score_cap) & (enr_disease_df["Overlap"].apply(lambda x: int(x.split("/")[0])/int(x.split("/")[1]) > disease_percentage))
        
#     filtered = enr_disease_df[mask].copy()
#     if not filtered.empty:
#         filtered.loc[:, "Community Index"] = i
#         filtered.loc[:, "Community Size"] = len(community)
#         important_diseases = pd.concat([important_diseases, filtered], ignore_index=True)

#     display(HTML(filtered[['Term','Overlap','Adjusted P-value']].head(10).to_html(max_cols=None)))
#     i += 1

# Important Terms Analysis

### Constructing Important Terms df

In [None]:
def comm_similarity_with_term(x,y):
    """Relative similarity of two scores: ``1 - |x - y| / max(x, y)``.

    Equal scores give 1.0 (including the pair 0, 0, which would otherwise
    divide by zero). For non-negative inputs the value shrinks towards 0 as
    the relative difference between the scores grows.
    """
    if x == y:
        return 1.0
    return 1-(abs(x-y)/max(x,y))

In [None]:
# Empty frame with the expected summary columns.
# NOTE(review): this value is replaced by the pd.concat in the next cell.
important_terms = pd.DataFrame(columns=["Community Index","Community Size","Term", "Overlap", "Adjusted P-value","Category"])

In [None]:
# Stack the kept terms of the three enrichment sources into one frame.
c = [go_important_terms,kegg_important_terms,reactome_important_terms]
important_terms = pd.concat(c, ignore_index=True)

In [None]:
# Order the combined terms by community (alternative ordering kept below for reference).
# important_terms = important_terms.sort_values(by="Overlap (value)",ascending=False)
important_terms = important_terms.sort_values(by="Community Index")
important_terms

In [None]:
# Community id to size dict
# Community index -> number of gene symbols in that community
com_id_to_size = {idx: len(members) for idx, members in enumerate(COMMUNITIES_HGNC)}

In [None]:
# One row per community: the first row of each "Community Index" is kept.
unique_com_id_to_size = important_terms.drop_duplicates(subset="Community Index", keep="first")

In [None]:
# "Community Index" -> "Community Size", only for communities that have at least one kept term.
comm_size_dict = dict(zip(unique_com_id_to_size["Community Index"], unique_com_id_to_size["Community Size"]))

In [None]:
# Persist the combined term table (without the row index) in the disease's output folder.
important_terms.to_csv(f"../output/{DISEASE}/important_terms_{DISEASE}.csv", index=False)

### Graph Building

In [None]:
# Build a weighted community graph: two communities are linked when they share
# an enriched term, and every shared term adds comm_similarity_with_term of
# the two overlap values to the edge weight.
# important_terms must have: "Community Index", "Term", "Overlap (value)", "Category"
work = important_terms.loc[:, ["Community Index", "Term", "Overlap (value)","Category"]].copy()
work["Overlap (value)"] = work["Overlap (value)"].astype(float)

# Ensure (community, term) uniqueness -- a repeated pair would be counted
# more than once in the edge weights below.
dupes = work.duplicated(subset=["Community Index", "Term"], keep=False)
if dupes.any():
    raise ValueError("Duplicated (Community Index, Term) rows found; ensure uniqueness first.")

# --- Build edge weights AND collect contributing terms per pair ---
edge_weights = {}              # (u, v) -> float, summed contributions
edge_counts  = {}              # (u, v) -> int, number of shared terms
edge_terms   = {}              # (u, v) -> list[(term, contrib)]

for term, sub in work.groupby("Term", sort=False):
    comms  = sub["Community Index"].to_numpy()
    scores = sub["Overlap (value)"].to_numpy()
    if len(comms) < 2:
        continue  # a term seen in a single community links nothing
    for i, j in combinations(range(len(comms)), 2):
        u, v = comms[i], comms[j]
        if u > v: u, v = v, u  # canonical ordering: smaller index first
        contrib = comm_similarity_with_term(scores[i], scores[j])

        edge_weights[(u, v)] = edge_weights.get((u, v), 0.0) + contrib
        edge_counts[(u, v)]  = edge_counts.get((u, v), 0)    + 1
        edge_terms.setdefault((u, v), []).append((term, contrib))

# Sort contributing terms by contribution desc for each edge
for key in edge_terms:
    edge_terms[key].sort(key=lambda t: t[1], reverse=True)

# --- Build edge list DataFrame (optional, useful to inspect) ---
# Rows are ordered by weight, then by number of shared terms, both descending.
edge_df = pd.DataFrame(
    [(u, v, edge_weights[(u, v)], edge_counts[(u, v)], edge_terms.get((u, v), []))
     for (u, v) in edge_weights.keys()],
    columns=["u", "v", "weight", "shared_terms", "terms_contrib"]
).sort_values(["weight", "shared_terms"], ascending=[False, False]).reset_index(drop=True)

# --- Build NetworkX graph with attributes ---
# Every community present in `work` becomes a node, even without edges.
G = nx.Graph()
G.add_nodes_from(pd.unique(work["Community Index"]))
for _, r in edge_df.iterrows():
    G.add_edge(
        int(r.u), int(r.v),
        weight=float(r.weight),
        shared_terms=int(r.shared_terms),
        terms_contrib=r.terms_contrib  # list of (term, contrib) sorted desc
    )

In [None]:
# Display the working table used for graph building.
work

### Table

In [None]:
#--------------Table------------------
# Per community pair: list the shared terms with their contribution and
# category, and count how often each category occurs.

# 1) One record per (pair, shared term)
term_contribs = []

for term, sub in work.groupby("Term", sort=False):
    comms  = sub["Community Index"].to_numpy()
    scores = sub["Overlap (value)"].to_numpy()
    if len(comms) < 2:
        continue
    for i, j in combinations(range(len(comms)), 2):
        u, v = comms[i], comms[j]
        if u > v:
            u, v = v, u
        contrib = comm_similarity_with_term(scores[i], scores[j])
        # The category is taken from the term's first row.
        cat = sub["Category"].iloc[0] if "Category" in sub.columns else None
        term_contribs.append((u, v, term, contrib, cat))

# 2) Build DataFrame
term_df = pd.DataFrame(term_contribs, columns=["u", "v", "Term", "Contribution","Category"])

# 3) Sort and aggregate terms per edge (keep per-term order)
agg_blocks = []
for (u, v), sub in term_df.groupby(["u", "v"]):
    sub_sorted = sub.sort_values("Contribution", ascending=False)

    # Create category count dictionary (each Category entry is a list of names)
    category_counts = Counter(
        c
        for cats in sub_sorted["Category"].dropna()
        for c in cats
    )
    category_counts_dict = dict(category_counts)

    block = "\n".join(
        [f"  - {t} {cat} ({c:.3f})"
        for t, c, cat in zip(sub_sorted["Term"], sub_sorted["Contribution"], sub_sorted["Category"])]
    )
    total = sub_sorted["Contribution"].sum()
    agg_blocks.append({
        "u": u,
        "v": v,
        "Total Weight": total,
        "Terms (by contribution)": block,
        "Category Count": category_counts_dict
    })

# 4) Create final block table. The columns are named explicitly so that
#    sorting also works when no community pair shares a term.
block_df = pd.DataFrame(
    agg_blocks,
    columns=["u", "v", "Total Weight", "Terms (by contribution)", "Category Count"],
).sort_values("Total Weight", ascending=False).reset_index(drop=True)

# 5) Display nicely (\u2014 is an em dash)
for _, row in block_df.iterrows():
    print(f"Community pair ({row.u}, {row.v}) \u2014 Total Weight = {row['Total Weight']:.3f}")
    print(row["Terms (by contribution)"])

    print()
    print("Category Count:")
    for key, value in sorted(row["Category Count"].items(), key=lambda x: x[1], reverse=True):
        print(f"{key}: {value}")

    print("-" * 60)

In [None]:
# block_df.to_excel("output.xlsx", index=False)

### Category Counts

In [None]:
def print_category_count_by_comm(category_count_by_comm):
    """Print, per community, its categories ordered by descending value.

    Parameters
    ----------
    category_count_by_comm : dict
        Community id -> {category: value}. Values only need to be mutually
        comparable; they are printed as-is.
    """
    for comm_id, cat_dict in category_count_by_comm.items():
        # \U0001F9E9 is the puzzle-piece emoji; escapes keep the source ASCII-only.
        print(f"\n\U0001F9E9 Community {comm_id}")
        print("-" * (14 + len(str(comm_id))))

        if not cat_dict:
            print("  (no categories)")
            continue

        # Sort categories by descending count (\u2022 is a bullet)
        for cat, count in sorted(cat_dict.items(), key=lambda x: x[1], reverse=True):
            print(f"  \u2022 {cat:<50} {count}")

In [None]:
# Merge the per-community category dicts of the three sources and order each
# merged dict by descending (count, overlap score).
# NOTE(review): when the same category name occurs in more than one source,
# the later source (KEGG over GO, Reactome over both) replaces the earlier entry.
category_count_by_comm = {}
for i in range(num_selected_comm):
    comm_cates = {
        **go_category_counts_and_overlap_score[i],
        **kegg_category_counts_and_overlap_score[i],
        **reactome_category_counts_and_overlap_score[i],
    }
    ranked = sorted(comm_cates.items(), key=lambda item: item[1], reverse=True)
    category_count_by_comm[i] = dict(ranked)

In [None]:
# Print the merged category summary for every community.
print_category_count_by_comm(category_count_by_comm)

### Visualization

# Robustness Analysis

In [None]:
def run_enrichment_func(community,term_score_cap,percentage):
    """Query Enrichr for GO, KEGG and Reactome and return the kept terms.

    The three library groups are queried one after another for the same gene
    list. From each result only rows with "Adjusted P-value" below
    `term_score_cap` and an overlap fraction above `percentage` are kept.
    The kept rows are concatenated in the order GO, KEGG, Reactome.
    """
    library_groups = (
        ['GO_Biological_Process_2023',
         'GO_Molecular_Function_2023',
         'GO_Cellular_Component_2023'],
        ['KEGG_2021_Human'],
        ['Reactome_2022'],
    )

    def keep_significant(table):
        overlap_ok = table["Overlap"].apply(lambda x: int(x.split("/")[0])/int(x.split("/")[1]) > percentage)
        pvalue_ok = table["Adjusted P-value"] < term_score_cap
        return table[pvalue_ok & overlap_ok].copy()

    kept_tables = []
    for libraries in library_groups:
        enrichment = gp.enrichr(
            gene_list=community,
            gene_sets=libraries,
            organism='Human',
            outdir=None  # don't write to disk
        )
        kept_tables.append(keep_significant(enrichment.results))

    # build result df by concatenating
    return pd.concat(kept_tables, ignore_index=True)

In [None]:
from json import JSONDecodeError

# ---------------- 1) Safe wrapper that calls YOUR enrichr function ----------------
_ENR_CACHE = {}  # key: (tuple(sorted(genes)), term_score_cap, percentage) -> DataFrame (copy)

def run_enrichment_safe(run_enrichment_func, community, retries=5, base_sleep=0.8,
                        term_score_cap=None, percentage=None):
    """
    Call run_enrichment_func(genes, term_score_cap, percentage) with retries and memoization.

    Parameters
    ----------
    run_enrichment_func : callable
        Called as run_enrichment_func(genes, term_score_cap, percentage); expected
        to return a DataFrame. A None return is treated as a failed attempt.
    community : sequence of gene symbols, or a single symbol
        Always converted to a list before being passed on (a bare string becomes
        a one-element list).
    retries : int
        Maximum number of attempts.
    base_sleep : float
        Base of the exponential back-off (seconds) between attempts; up to 0.3 s
        of random jitter is added to each wait.
    term_score_cap, percentage : optional
        Thresholds forwarded to run_enrichment_func. When left as None the
        module-level TERM_SCORE_CAP / PERCENTAGE are used.

    Returns
    -------
    DataFrame
        The enrichment result, or an empty DataFrame when `community` is empty
        or every attempt failed (a message is printed in the latter case).
        JSONDecodeError, OSError, RuntimeError and ValueError are never raised
        outward.

    Notes
    -----
    Results are cached per (gene set, thresholds). The thresholds are part of
    the key so that changing them does not return results filtered under the
    previous settings. The cache does not distinguish between different
    run_enrichment_func callables.
    """
    # Ensure we always pass a list of gene symbols (never a bare string)
    genes = np.atleast_1d(np.array(community, dtype=object)).tolist()
    if not genes:
        return pd.DataFrame()

    # Resolve thresholds lazily so the notebook globals are only read when needed.
    if term_score_cap is None:
        term_score_cap = TERM_SCORE_CAP
    if percentage is None:
        percentage = PERCENTAGE

    key = (tuple(sorted(genes)), term_score_cap, percentage)
    cached = _ENR_CACHE.get(key)
    if cached is not None:
        return cached.copy()

    last_error = None
    for attempt in range(retries):
        try:
            df = run_enrichment_func(genes, term_score_cap, percentage)
            if df is None:
                # treat as transient failure to trigger retry
                raise RuntimeError("run_enrichment_func returned None")
        except (JSONDecodeError, OSError, RuntimeError, ValueError) as e:
            # Transient errors from HTTP/JSON/file handling inside gseapy
            last_error = e
            if attempt < retries - 1:
                time.sleep(base_sleep * (2 ** attempt) + np.random.rand() * 0.3)
        else:
            _ENR_CACHE[key] = df.copy()
            return df

    # Give up: return empty so the pipeline continues, but say so, because
    # callers cannot tell this apart from "no significant terms".
    print(f"run_enrichment_safe: no result for {len(genes)} genes after "
          f"{retries} attempt(s) (last error: {last_error!r}); returning an empty DataFrame.")
    return pd.DataFrame()

# ---------------- 2) Minimal bootstrap to record robust terms ----------------
def get_robust_terms(communities_HGNC, run_enrichment_func,
                     R=50, leaveout=0.10, recurrence_cutoff=0.70, seed=42):
    """
    Measure how often each enriched term recurs across leave-out resamples.

    For every community, `R` resamples are drawn by removing
    max(1, floor(leaveout * n)) randomly chosen genes (n = community size).
    Each resample is passed to run_enrichment_safe(run_enrichment_func, subset)
    and every distinct term in the returned DataFrame is counted once.

    Parameters
    ----------
    communities_HGNC : sequence of gene-symbol sequences
        Empty communities are skipped.
    run_enrichment_func : callable
        Forwarded to run_enrichment_safe; its result is expected to already be
        filtered to the terms of interest and to carry a 'Term' column
        (optionally 'Gene_set').
    R : int
        Number of resamples per community; also the denominator of recurrence,
        so resamples that yield no result count as "term absent".
    leaveout : float
        Fraction of genes removed per resample.
    recurrence_cutoff : float
        Only terms with recurrence >= this value are reported.
    seed : int
        Seed for the NumPy random generator that picks the removed genes.

    Returns
    -------
    DataFrame
        Columns 'Community Index' (position of the community in
        `communities_HGNC`), 'Term', 'recurrence' and 'Gene_set' (NaN when the
        enrichment result had no such column), sorted by community and then by
        descending recurrence. An empty frame with these columns is returned
        when nothing qualifies.
    """
    columns = ['Community Index', 'Term', 'recurrence', 'Gene_set']
    rng = np.random.default_rng(seed)
    rows = []

    for cid, community in enumerate(communities_HGNC):
        n = len(community)
        if n == 0:
            continue
        members = np.array(community, dtype=object)
        drop_k = max(1, int(np.floor(leaveout * n)))
        counts = Counter()

        for _ in range(R):
            # Leave-out subset; a one-gene community leaves nothing and is skipped below.
            keep = np.ones(n, dtype=bool)
            keep[rng.choice(n, size=min(drop_k, n), replace=False)] = False
            sub = np.atleast_1d(members[keep]).tolist()
            if len(sub) == 0:
                continue

            df = run_enrichment_safe(run_enrichment_func, sub)
            if df is None or df.empty:
                continue

            if 'Term' not in df.columns:
                continue  # be defensive

            # Count each distinct term once per resample. Keys are tuples so that
            # term names containing '|' (or any other separator) stay intact, and
            # Gene_set is kept to disambiguate equal names from different libraries.
            if 'Gene_set' in df.columns:
                pairs = df[['Term', 'Gene_set']].dropna().drop_duplicates()
                terms = list(pairs.itertuples(index=False, name=None))
            else:
                terms = [(term,) for term in df['Term'].dropna().drop_duplicates()]

            counts.update(terms)

            # tiny pause helps with API rate limits if the enrichment function calls Enrichr
            time.sleep(0.03)

        # Keep only robust terms
        for key, c in counts.items():
            freq = c / max(R, 1)
            if freq >= recurrence_cutoff:
                row = {'Community Index': cid, 'Term': key[0], 'recurrence': freq}
                if len(key) == 2:
                    row['Gene_set'] = key[1]
                rows.append(row)

    if not rows:
        # sort_values on a column-less frame would raise KeyError
        return pd.DataFrame(columns=columns)

    return (pd.DataFrame(rows, columns=columns)
              .sort_values(['Community Index', 'recurrence'], ascending=[True, False])
              .reset_index(drop=True))

In [None]:
# Trial run on a single community. COMMUNITIES_HGNC[1] is wrapped in a
# one-element list, so the reported 'Community Index' is 0 (its position in
# that list), not 1. recurrence_cutoff=0 keeps every term that was counted at
# least once across the 25 resamples.
twr3 = get_robust_terms([COMMUNITIES_HGNC[1]], run_enrichment_func,
                                R=25, leaveout=0.1, recurrence_cutoff=0)

In [None]:
# Show the single-community trial result.
twr3

In [None]:
# Recurrence of enriched terms for every community in COMMUNITIES_HGNC, using
# 10 resamples with 10% of genes left out each time. recurrence_cutoff=0 means
# no term is filtered out here; the result lists every term counted at least once.
terms_with_recurrence = get_robust_terms(COMMUNITIES_HGNC, run_enrichment_func,
                                R=10, leaveout=0.1, recurrence_cutoff=0)

In [None]:
# Show the recurrence table for all communities.
terms_with_recurrence

In [None]:
# rename important terms to match terms_with_recurrence
# Rename important_terms columns to the 'community_id' / 'term' names used as
# merge keys when joining with terms_with_recurrence.
important_terms = important_terms.rename(
    columns={'index': 'community_id', 'Term': 'term'}
)

In [None]:
# get_robust_terms() labels its columns 'Community Index' and 'Term', while
# important_terms uses 'community_id' and 'term'. Align the names before
# selecting the merge columns; selecting 'community_id' / 'term' directly
# raises KeyError. rename() ignores labels that are absent, so this is a no-op
# if the columns already carry the target names.
recurrence_lookup = terms_with_recurrence.rename(
    columns={'Community Index': 'community_id', 'Term': 'term'}
)[['community_id', 'term', 'Gene_set', 'recurrence']]

# Left join: every important term is kept, with its recurrence when available.
terms_with_rec_merged = important_terms.merge(
    recurrence_lookup,
    on=['community_id', 'term', 'Gene_set'],
    how='left'
)

# Terms that never showed up in any resample have no match; score them 0.
terms_with_rec_merged['recurrence'] = terms_with_rec_merged['recurrence'].fillna(0.0)

terms_with_rec_merged = terms_with_rec_merged.sort_values(
    ['community_id', 'recurrence'],
    ascending=[True, False]
).reset_index(drop=True)

In [None]:
# Show important terms together with their recurrence scores.
terms_with_rec_merged

In [None]:
# One row per community: the mean of its recurrence values and how many
# non-null recurrence values it has.
recurrence_by_comm = terms_with_rec_merged.groupby("community_id")["recurrence"]
community_summary = recurrence_by_comm.agg(
    mean_recurrence="mean",
    term_count="count",
).reset_index()

print(community_summary)

In [None]:
# Render terms_with_recurrence as an HTML table in the cell output.
display(HTML(terms_with_recurrence.to_html(max_cols=None)))

# Checks!

In [None]:
# Print the size of every community, one per line.
for c in communities:
    print(len(c))

In [None]:
# Gene identifiers that key DGIDB_gene_to_index (presumably NCBI ids, going by the name).
DGIDB_genes_ncbi = list(DGIDB_gene_to_index)

In [None]:
# Convert every community via index_to_ncbi, using the index -> gene mapping.
all_comms_ncbi = index_to_ncbi(communities,index_to_gene_distinct)

In [None]:
# Dump the converted communities for a visual check.
print(all_comms_ncbi)

In [None]:
def DGIDB_count(c):
    """Return how many distinct elements of `c` also appear in DGIDB_genes_ncbi."""
    candidates = set(c)
    known = set(DGIDB_genes_ncbi)
    return len(candidates.intersection(known))

In [None]:
# For each converted community, print its size next to the number of its
# members that DGIDB_count finds in DGIDB_genes_ncbi.
for c in all_comms_ncbi:
    print(len(c),DGIDB_count(c))

In [None]:
def tbd(id):
    """
    Print three diagnostics for communities[id].

    Printed in order: the community's size, its size after index_to_ncbi
    conversion, and how many converted members DGIDB_count finds.
    """
    print(len(communities[id]))
    # Pass the index -> gene mapping explicitly, consistent with the
    # two-argument index_to_ncbi call used to build all_comms_ncbi.
    comm_ncbi = index_to_ncbi([communities[id]], index_to_gene_distinct)[0]
    print(len(comm_ncbi))
    print(DGIDB_count(comm_ncbi))

In [None]:
def tbd_selected(id):
    """
    Print three diagnostics for communities_selected[id].

    Printed in order: the community's size, its size after index_to_ncbi
    conversion, and how many converted members DGIDB_count finds.
    """
    print(len(communities_selected[id]))
    # Pass the index -> gene mapping explicitly, consistent with the
    # two-argument index_to_ncbi call used to build all_comms_ncbi.
    comm_ncbi = index_to_ncbi([communities_selected[id]], index_to_gene_distinct)[0]
    print(len(comm_ncbi))
    print(DGIDB_count(comm_ncbi))

In [None]:
# Spot-check the selected community at position 1.
tbd_selected(1)