In [13]:
import json
import pickle

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.patches import Wedge, Patch
from scipy.sparse import load_npz,save_npz,diags,csr_matrix,issparse

In [5]:
# --- Configuration -----------------------------------------------------------
# Disease label; it is interpolated into the folder and file names below.
DISEASE = "BIPOLAR"
# Per-disease output folder (note the trailing slash).
DISEASE_FOLDER = f"../output/{DISEASE}/"
# Folder holding the pickled result files loaded later in this notebook.
RESULT_FOLDER = DISEASE_FOLDER + "leiden_results"
# Output folders of the sibling Gen_Hypergraph project.
DGIDB_DIRECTORY = f"../../Gen_Hypergraph/output/DGIDB_{DISEASE}/"
# NOTE(review): MSIGDB_DIRECTORY is not used in the cells shown here.
MSIGDB_DIRECTORY = "../../Gen_Hypergraph/output/MSigDB_Full/"
# Base names (without the ".pkl" extension) of the pickles in RESULT_FOLDER.
RESULT_COMMUNITIES = "result_communities_agg"
RESULT_GRAPH = "result_graph_agg"

# --- Load inputs -------------------------------------------------------------
# Required JSON file; a missing file raises and stops the cell.
# Presumably maps gene identifiers to matrix indices -- TODO confirm schema.
with open(DISEASE_FOLDER + "gene_to_index_distinct.json", "r") as file:
    gene_to_index_distinct = json.load(file)
    
# Optional JSON file: if it does not exist, fall back to an empty dict and
# report that on stdout instead of failing.
try:
    with open(DGIDB_DIRECTORY + f"gene_to_index_{DISEASE}.json", "r") as file:
        DGIDB_gene_to_index = json.load(file)
except FileNotFoundError:
    DGIDB_gene_to_index = {}
    print("File not found. Setting DGIDB_gene_to_index to be {}.")
    
    
# Sparse matrix stored in SciPy's .npz format.
# NOTE(review): DISEASE_FOLDER already ends with "/", so this path contains "//".
sim_mat = load_npz(f"{DISEASE_FOLDER}/agg_sim_mat.npz")

In [18]:
# Load the pickled communities and graph from RESULT_FOLDER.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by a trusted run of this pipeline.
with open(f"{RESULT_FOLDER}/{RESULT_COMMUNITIES}.pkl", "rb") as f:
    communities = pickle.load(f)
with open(f"{RESULT_FOLDER}/{RESULT_GRAPH}.pkl", "rb") as f:
    graph = pickle.load(f)

In [17]:
# Quick inspection: length of the first element of `communities`.
# NOTE(review): whether that element is a single community or a whole
# partition depends on how the pickle was written -- confirm.
len(communities[0])

3205

In [11]:
def community_similarity_from_node_similarity(S, communities):
    """
    Aggregate a node-level similarity matrix into a community-level one.

    Parameters
    ----------
    S : (n, n) sparse or dense matrix
        Node-node similarities. Dense input is converted to CSR; sparse
        input is converted to CSR format if it is not already.
    communities : sequence of sequences of int
        communities[k] holds the node indices (rows/columns of S) that
        belong to community k.

    Returns
    -------
    (K, K) numpy.ndarray
        Entry (i, j) is the mean of S[u, v] over all u in community i and
        v in community j, i.e. the block sum divided by |Ci| * |Cj|.
        Entries involving an empty community are 0.
    """
    node_sim = S.tocsr() if issparse(S) else csr_matrix(S)

    num_nodes = node_sim.shape[0]
    num_comms = len(communities)

    # Coordinates of the ones in the (num_nodes x num_comms) indicator matrix:
    # one (node, community) pair per membership.
    node_ids, comm_ids = [], []
    for comm_id, members in enumerate(communities):
        node_ids.extend(members)
        comm_ids.extend([comm_id] * len(members))

    membership = csr_matrix(
        (np.ones(len(node_ids), dtype=float), (node_ids, comm_ids)),
        shape=(num_nodes, num_comms),
    )

    # membership^T * S * membership sums S over every (Ci, Cj) block.
    block_sums = (membership.T @ node_sim @ membership).toarray()

    # Number of node pairs in each block; zero whenever either side is empty.
    comm_sizes = np.array([len(members) for members in communities], dtype=float)
    pair_counts = comm_sizes[:, None] * comm_sizes[None, :]

    # Divide only where the block is non-empty; everything else stays 0.
    averaged = np.zeros_like(block_sums, dtype=float)
    np.divide(block_sums, pair_counts, out=averaged, where=pair_counts > 0)
    return averaged


def build_community_similarity_graph(
    S,
    communities,
    threshold=1e-4,
    add_self_loops=False
):
    """
    Turn a node-node similarity matrix into a weighted graph of communities.

    Parameters
    ----------
    S : (n, n) sparse or dense similarity matrix (symmetric)
        Passed unchanged to community_similarity_from_node_similarity.
    communities : list of lists/arrays of node indices
        One entry per community.
    threshold : float
        A pair of communities is connected only when its similarity is
        strictly greater than this value.
    add_self_loops : bool
        When True, community i also gets an edge (i, i) if its
        within-community similarity exceeds `threshold`.

    Returns
    -------
    G_comm : networkx.Graph
        One node per community (labelled 0..K-1) carrying a 'size'
        attribute; each edge carries a 'weight' attribute equal to the
        similarity of the two communities it joins.
    sim_comm : (K, K) numpy array
        The full community-community similarity matrix, before thresholding.
    """
    sim_comm = community_similarity_from_node_similarity(S, communities)
    num_comms = sim_comm.shape[0]

    G_comm = nx.Graph()

    # One node per community, annotated with its member count.
    for comm_id in range(num_comms):
        G_comm.add_node(comm_id, size=len(communities[comm_id]))

    # Evaluate the threshold test once for the whole matrix.
    passes = sim_comm > threshold

    for src in range(num_comms):
        # Self-loop (if requested) is inserted before the edges to later nodes.
        if add_self_loops:
            if passes[src, src]:
                G_comm.add_edge(src, src, weight=sim_comm[src, src])

        # Upper triangle only: the graph is undirected.
        dst = src + 1
        while dst < num_comms:
            if passes[src, dst]:
                G_comm.add_edge(src, dst, weight=sim_comm[src, dst])
            dst += 1

    return G_comm, sim_comm


In [8]:
def _split_minor_categories(cat_counts, min_frac):
    """Return [(category, count), ...] in descending count order, with every
    category below `min_frac` of the total merged into a trailing "Other"."""
    total = sum(cat_counts.values())
    ranked = sorted(cat_counts.items(), key=lambda kv: kv[1], reverse=True)

    kept = []
    other_count = 0
    for cat, cnt in ranked:
        if cnt / total >= min_frac:
            kept.append((cat, cnt))
        else:
            other_count += cnt

    if other_count > 0:
        kept.append(("Other", other_count))
    return kept


def plot_community_similarity_pies(
    sim_matrix,
    community_category_counts,
    community_sizes=None,
    min_frac=0.05,
    edge_threshold=0.0,
    figsize=(10, 8),
    title="Community similarity (pie = category composition)",
    edge_label_fmt="{:.2f}",
    layout="spring",
    seed=42,
):
    """
    Plot a community–community similarity graph where each node is drawn
    as a pie chart showing that community's category composition.

    Parameters
    ----------
    sim_matrix : array-like or (n,n) sparse matrix
        Community–community similarity matrix. Can be dense (NumPy array,
        nested lists) or sparse. Assumed symmetric and non-negative; only
        the strict upper triangle is read.
        Entry (i,j) gives similarity between community i and j.

    community_category_counts : list of dict
        Length = n_communities.
        community_category_counts[i] is a dict:
            {category_name: count, ...}
        Dicts from several sources (e.g. Reactome/KEGG/GO) can be merged
        with `|` before being passed in. Only the *counts* are used here.
        The name "Other" is reserved for the lumped slice: a category
        literally called "Other" is drawn without hatch and left out of
        the legend.

    community_sizes : 1D array-like or None
        Size of each community, used for node radius.
        If None, size is inferred as sum of counts in
        community_category_counts[i]. Values below 1 are raised to 1.

    min_frac : float, default 0.05
        Per-community threshold: categories whose count / total_count
        is < min_frac are lumped into an "Other" slice.

    edge_threshold : float, default 0.0
        Only draw edges with similarity > edge_threshold.

    figsize : tuple
        Matplotlib figure size.

    title : str
        Title of the axes.

    edge_label_fmt : str
        Format string for edge label text.

    layout : {"spring", "kamada_kawai"}
        Which NetworkX layout to use. In both cases more similar
        communities are pulled closer together: the spring layout uses the
        similarity as attraction strength, while Kamada-Kawai is given the
        inverse similarity as edge length.

    seed : int
        Random seed for reproducible spring layouts.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes

    Raises
    ------
    ValueError
        If sim_matrix is not a square 2-D matrix, if
        community_category_counts or community_sizes do not have one entry
        per community, or if `layout` is not a supported name.
    """

    # ------------------------------------------------------------------
    # 0. Basic checks & normalization (all done before any drawing)
    # ------------------------------------------------------------------
    if issparse(sim_matrix):
        sim_dense = sim_matrix.toarray()
    else:
        sim_dense = np.asarray(sim_matrix)

    # Reject non-square input here instead of failing with an IndexError
    # halfway through the edge loop.
    if sim_dense.ndim != 2 or sim_dense.shape[0] != sim_dense.shape[1]:
        raise ValueError(
            f"sim_matrix must be a square 2-D matrix, got shape {sim_dense.shape}."
        )
    n = sim_dense.shape[0]

    if len(community_category_counts) != n:
        raise ValueError(
            f"community_category_counts has length {len(community_category_counts)}, "
            f"but sim_matrix is {n}x{n}."
        )

    if community_sizes is None:
        community_sizes = np.array(
            [sum(cat_dict.values()) for cat_dict in community_category_counts],
            dtype=float,
        )
    else:
        community_sizes = np.asarray(community_sizes, dtype=float).reshape(-1)
        if community_sizes.shape[0] != n:
            raise ValueError("community_sizes must have length n_communities")

    if layout not in ("spring", "kamada_kawai"):
        raise ValueError("layout must be 'spring' or 'kamada_kawai'")

    # Avoid zeros in size (for radius scaling)
    community_sizes = np.maximum(community_sizes, 1.0)

    # ------------------------------------------------------------------
    # 1. Build the similarity graph
    # ------------------------------------------------------------------
    G = nx.Graph()
    for i in range(n):
        G.add_node(i)

    for i in range(n):
        for j in range(i + 1, n):
            w = sim_dense[i, j]
            if w > edge_threshold:
                # 'distance' is the inverse similarity, used only by the
                # Kamada-Kawai layout. Non-positive similarities (possible
                # with a negative edge_threshold) are clamped so they count
                # as maximally distant instead of dividing by zero.
                G.add_edge(i, j, weight=w, distance=1.0 / max(float(w), 1e-12))

    # ------------------------------------------------------------------
    # 2. Layout
    # ------------------------------------------------------------------
    if layout == "spring":
        # Larger weight = stronger attraction, so similarity is used as is.
        pos = nx.spring_layout(G, seed=seed, weight="weight")
    else:
        # Kamada-Kawai treats the edge attribute as a path length; passing
        # the similarity itself would push similar communities apart.
        pos = nx.kamada_kawai_layout(G, weight="distance")

    # ------------------------------------------------------------------
    # 3. Prepare category universe and hatching patterns
    # ------------------------------------------------------------------
    # Collect all category names (for consistent legend & hatches)
    all_categories = set()
    for d in community_category_counts:
        all_categories.update(d.keys())
    all_categories = sorted(all_categories)

    # Hatching patterns to distinguish categories (repeats if needed)
    hatch_patterns = [
        "/", "\\", "|", "-", "+", "x", "o", "O", ".", "*",
        "///", "\\\\", "||", "--", "++", "xx", "oo", "OO"
    ]
    num_patterns = len(hatch_patterns)

    category_to_hatch = {
        cat: hatch_patterns[i % num_patterns] for i, cat in enumerate(all_categories)
    }

    # ------------------------------------------------------------------
    # 4. Start plotting
    # ------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])

    # Draw edges first (behind nodes)
    if G.number_of_edges() > 0:
        weights = np.array([G[u][v]["weight"] for u, v in G.edges()])
        # Map weights linearly onto line widths in [1, 5]; the epsilon keeps
        # the division defined when all weights are equal.
        if weights.max() > 0:
            widths = 1.0 + 4.0 * (weights - weights.min()) / (weights.max() - weights.min() + 1e-12)
        else:
            widths = np.ones_like(weights)
        nx.draw_networkx_edges(G, pos, ax=ax, width=widths, alpha=0.4)

        # Edge labels
        edge_labels = {
            (u, v): edge_label_fmt.format(d["weight"])
            for u, v, d in G.edges(data=True)
        }
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels=edge_labels, font_size=8, ax=ax
        )

    # ------------------------------------------------------------------
    # 5. Node pies
    # ------------------------------------------------------------------
    # radius scaling: make radii between r_min and r_max (data coordinates)
    r_min, r_max = 0.15, 0.35
    if n > 0:
        size_norm = (community_sizes - community_sizes.min()) / (
            community_sizes.max() - community_sizes.min() + 1e-12
        )
        radii = r_min + (r_max - r_min) * size_norm
    else:
        # No communities: nothing to scale (min()/max() would raise).
        radii = np.empty(0, dtype=float)

    # Legend accumulation: which categories actually appear after thresholding?
    legend_categories = set()

    for i in range(n):
        x, y = pos[i]
        r = radii[i]
        cat_counts = community_category_counts[i]

        # No categories, or no positive total: draw a plain grey disc.
        if not cat_counts or sum(cat_counts.values()) <= 0:
            circ = plt.Circle((x, y), r, facecolor="lightgrey", edgecolor="black", alpha=0.7)
            ax.add_patch(circ)
            continue

        kept = _split_minor_categories(cat_counts, min_frac)
        kept_total = sum(cnt for _, cnt in kept)

        # Draw wedges counter-clockwise starting at angle 0
        start_angle = 0.0
        for cat, cnt in kept:
            theta = 360.0 * cnt / kept_total
            wedge = Wedge(
                center=(x, y),
                r=r,
                theta1=start_angle,
                theta2=start_angle + theta,
                facecolor="lightgrey",   # uniform color; hatch encodes category
                edgecolor="black",
                linewidth=0.7,
                hatch=category_to_hatch.get(cat, "///") if cat != "Other" else "",
                alpha=0.9,
            )
            ax.add_patch(wedge)
            start_angle += theta

            if cat != "Other":
                legend_categories.add(cat)

    # Patches added with add_patch do not rescale the view, so set the limits
    # explicitly to contain every pie; otherwise pies on the boundary are
    # clipped, and with no edges the view stays at the default range.
    if n > 0:
        centers = np.array([pos[i] for i in range(n)], dtype=float)
        lower_left = (centers - radii[:, None]).min(axis=0)
        upper_right = (centers + radii[:, None]).max(axis=0)
        pad = 0.05 * float((upper_right - lower_left).max())
        ax.set_xlim(lower_left[0] - pad, upper_right[0] + pad)
        ax.set_ylim(lower_left[1] - pad, upper_right[1] + pad)

    # ------------------------------------------------------------------
    # 6. Legend on the right
    # ------------------------------------------------------------------
    legend_handles = []
    legend_labels = []

    for cat in sorted(legend_categories):
        hatch = category_to_hatch[cat]
        patch = Patch(
            facecolor="lightgrey",
            edgecolor="black",
            hatch=hatch,
            label=cat,
        )
        legend_handles.append(patch)
        legend_labels.append(cat)

    if legend_handles:
        ax.legend(
            handles=legend_handles,
            labels=legend_labels,
            title="Categories",
            bbox_to_anchor=(1.05, 0.5),
            loc="center left",
            borderaxespad=0.0,
        )

    ax.set_title(title)
    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.tight_layout()
    return fig, ax
