# Import

In [None]:
import numpy as np
import json
from scipy.sparse import load_npz,save_npz,diags,csr_matrix
import scipy.sparse as sp
import pandas as pd
import os
import requests
from io import BytesIO
from tqdm import tqdm
from scipy.sparse.linalg import eigsh
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.backends.backend_pdf import PdfPages
from pypdf import PdfReader, PdfWriter
from tempfile import NamedTemporaryFile
import networkx as nx
import pickle
import gseapy as gp
import mygene
from IPython.display import display, HTML
import re
from collections import deque
from goatools.obo_parser import GODag
import math
from itertools import combinations
from collections import Counter
from gseapy.parser import read_gmt
import random
import igraph as ig

# Initialization

In [None]:
# Dataset names processed by this notebook; "NONE" is kept in the full list
# but excluded from LIST_OF_DISEASES_CLEAN.
LIST_OF_DISEASES = ["BIPOLAR","SCHIZOPHRENIA","LEUKEMIA","NONE"]
# Same list without the "NONE" entry.
LIST_OF_DISEASES_CLEAN = [x for x in LIST_OF_DISEASES if x != "NONE"]
SIZE_CAP = 30 # minimum community size: communities with len >= SIZE_CAP are kept

In [None]:
def jaccard(a, b):
    """
    Jaccard similarity |A ∩ B| / |A ∪ B| of two iterables, compared as sets.

    Returns 0.0 when both inputs are empty.
    """
    left = set(a)
    right = set(b)
    combined = left.union(right)
    if not combined:
        return 0.0
    return len(left.intersection(right)) / len(combined)

In [None]:
# One "important terms" table per disease, keyed by disease name.
disease_to_df = {
    disease: pd.read_csv(f"../output/{disease}/important_terms_{disease}.csv")
    for disease in LIST_OF_DISEASES
}

In [None]:
def load_info(disease):
    """
    Load the saved index mappings, communities and graph for one disease.

    Parameters
    ----------
    disease : str
        Disease name used to build the input paths.

    Returns
    -------
    dict
        Keys: "communities", "graph", "gene_to_index_distinct",
        "index_to_gene_distinct" (inverse of the former),
        "DGIDB_gene_to_index", "DGIDB_drug_to_index", and "DGIDB_genes" /
        "DGIDB_drugs" (the key sets of the two DGIDB mappings).

    Raises
    ------
    FileNotFoundError
        If one of the expected input files does not exist.
    """
    # Every folder constant ends with "/", so file names are appended directly
    # (the previous f-strings added a second slash: "leiden_results//...").
    DISEASE_FOLDER = f"../output/{disease}/"
    RESULT_FOLDER = DISEASE_FOLDER + "leiden_results/"
    DGIDB_DIRECTORY = f"../../Gen_Hypergraph/output/DGIDB_{disease}/"
    RESULT_COMMUNITIES = "result_communities_new"
    RESULT_GRAPH = "result_graph_new"

    with open(DISEASE_FOLDER + "gene_to_index_distinct.json", "r") as file:
        gene_to_index_distinct = json.load(file)
    with open(DGIDB_DIRECTORY + f"gene_to_index_{disease}.json", "r") as file:
        DGIDB_gene_to_index = json.load(file)
    with open(DGIDB_DIRECTORY + f"drug_to_index_{disease}.json", "r") as file:
        DGIDB_drug_to_index = json.load(file)

    # Loading result graph and communities.
    # NOTE: pickle.load can execute arbitrary code contained in the file; only
    # load pickles that were produced locally by a trusted pipeline.
    with open(f"{RESULT_FOLDER}{RESULT_COMMUNITIES}.pkl", "rb") as f:
        communities = pickle.load(f)
    with open(f"{RESULT_FOLDER}{RESULT_GRAPH}.pkl", "rb") as f:
        graph = pickle.load(f)

    DGIDB_genes = set(DGIDB_gene_to_index)
    DGIDB_drugs = set(DGIDB_drug_to_index)
    index_to_gene_distinct = {v: k for k, v in gene_to_index_distinct.items()}

    return {"communities": communities,
            "graph": graph,
            "gene_to_index_distinct": gene_to_index_distinct,
            "index_to_gene_distinct": index_to_gene_distinct,
            "DGIDB_gene_to_index": DGIDB_gene_to_index,
            "DGIDB_drug_to_index": DGIDB_drug_to_index,
            "DGIDB_genes": DGIDB_genes,
            "DGIDB_drugs": DGIDB_drugs
            }

In [None]:
# Load every disease's saved artefacts once, keyed by disease name.
all_info = {disease: load_info(disease) for disease in LIST_OF_DISEASES}

In [None]:
# Extend NONE's DGIDB gene set with the DGIDB genes of every other disease
none_info = all_info['NONE']
none_info['DGIDB_genes'] = none_info['DGIDB_genes'].union(
    *(all_info[disease]['DGIDB_genes'] for disease in LIST_OF_DISEASES_CLEAN)
)

In [None]:
# Number of entries in BIPOLAR's DGIDB gene-to-index mapping
print(len(all_info['BIPOLAR']['DGIDB_gene_to_index']))

In [None]:
# Size of NONE's DGIDB gene set (after the union with the other diseases)
len(all_info['NONE']['DGIDB_genes'])

In [None]:
# Keep only the communities that have at least SIZE_CAP members
for disease in LIST_OF_DISEASES:
    info = all_info[disease]
    info['communities'] = [
        community
        for community in info['communities']
        if len(community) >= SIZE_CAP
    ]

# Create communities_selected

In [None]:
# Community preprocessing
def nx_to_igraph(G: nx.Graph, weight: str | None = "weight") -> ig.Graph:
    """
    Convert a NetworkX graph to an iGraph graph.

    Parameters
    ----------
    G : nx.Graph or nx.DiGraph
        Your NetworkX graph.
    weight : str or None, optional
        Name of the edge attribute to treat as weight. Edges without that
        attribute get weight 1.0. If None, no weights are copied.

    Returns
    -------
    ig.Graph
        An iGraph object with:
        - one vertex per node of G, in the iteration order of G.nodes()
        - g.vs['name'] = node labels
        - g.es['weight'] = weights (only when `weight` is not None; the
          attribute is always called "weight" on the igraph side)
    """
    # 1) Keep original node labels; a node's position is its igraph vertex id
    nodes = list(G.nodes())
    idx_map = {node: i for i, node in enumerate(nodes)}

    # 2) Convert edges (read the edge data once so edges and weights stay aligned)
    edge_data = list(G.edges(data=True))
    edges = [(idx_map[u], idx_map[v]) for u, v, _ in edge_data]

    # 3) Initialize iGraph.
    # Pass the vertex count explicitly: without `n`, igraph only creates
    # vertices up to the largest index used by an edge, so isolated nodes at
    # the end of the node list would be dropped silently.
    g_ig = ig.Graph(n=len(nodes), edges=edges, directed=G.is_directed())
    g_ig.vs["name"] = nodes

    # 4) Add weights if requested, defaulting missing values to 1.0
    if weight is not None:
        g_ig.es["weight"] = [data.get(weight, 1.0) for _, _, data in edge_data]

    return g_ig

def communities_cutoff(communities, cutoff = 100):
    """
    Keep the communities that have at least `cutoff` members.

    Returns
    -------
    tuple
        (list of kept communities in their original order, how many were kept)
    """
    kept = [members for members in communities if len(members) >= cutoff]
    return kept, len(kept)

def zscore(values):
    """
    Standardise `values` to zero mean and unit (population, ddof=0) variance.

    Returns a float array. An empty input gives an empty array; when the
    standard deviation is 0 or NaN every z-score is 0.
    """
    data = np.asarray(values, dtype=float)
    if data.size == 0:
        return data

    spread = data.std(ddof=0)
    if np.isnan(spread) or spread == 0:
        # no variation: all z-scores = 0
        return np.zeros_like(data)

    return (data - data.mean()) / spread

def intramodule_closeness(G, community, weight="weight"):
    """
    Intramodule closeness centrality using the induced subgraph of a community.

    Also prints the (unweighted) diameter of the induced subgraph.

    Parameters
    ----------
    G : nx.Graph
        Original graph.
    community : iterable
        Nodes in the community (subset of G).
    weight : str or None
        Edge attribute to use as weights. None computes unweighted closeness.

    Returns
    -------
    dict
        Mapping node -> intramodule closeness centrality.
    """
    # Induced subgraph, converted to igraph for the centrality computation
    sub_nx = G.subgraph(community).copy()
    sub_ig = nx_to_igraph(sub_nx, weight=weight)
    print(f"diameter: {sub_ig.diameter()}")

    # nx_to_igraph stores the copied weights under "weight" whatever the
    # source attribute was called, and stores nothing when weight is None.
    ig_weights = "weight" if weight is not None else None

    # Regular closeness, but only inside the community subgraph
    cl = sub_ig.closeness(weights=ig_weights, normalized=True)
    return {sub_ig.vs[i]["name"]: float(cl[i]) for i in range(sub_ig.vcount())}

def community_central_genes(G, community_nodes, weight="weight", score_cap = 0.1):
    """
    Return the community nodes whose intramodule closeness is >= score_cap.

    Parameters
    ----------
    G : nx.Graph
        Original graph.
    community_nodes : iterable
        Nodes of the community; the result keeps their order.
    weight : str or None
        Edge attribute forwarded to intramodule_closeness.
    score_cap : float
        Minimum closeness for a node to be returned.
    """
    # Forward `weight` so the caller's choice is not silently ignored
    closeness_scores = intramodule_closeness(G, community_nodes, weight=weight)
    return [u for u in community_nodes if closeness_scores[u] >= score_cap]

def pagerank_for_community(g, community_nodes, weight="weight"):
    """
    Compute PageRank for a given community inside an iGraph graph.

    Parameters
    ----------
    g : igraph.Graph
        The iGraph graph (converted from NetworkX); must have vs["name"].
    community_nodes : list
        A list of node names (the same labels used in g.vs["name"]).
    weight : str or None
        Name of the weight attribute. Use None for unweighted Pagerank.

    Returns
    -------
    dict {node_name: pagerank_score}

    Raises
    ------
    KeyError
        If a community node is not a vertex name of `g`.
    """
    # 1. Resolve names to vertex ids explicitly. igraph treats integers as
    #    vertex ids rather than names, so passing integer node labels straight
    #    to subgraph() could select the wrong vertices.
    name_to_index = {name: idx for idx, name in enumerate(g.vs["name"])}
    try:
        vertex_ids = [name_to_index[node] for node in community_nodes]
    except KeyError as err:
        raise KeyError(f"Node {err.args[0]!r} is not a vertex name of g") from err

    # 2. Build the induced subgraph of the community
    sub = g.subgraph(vertex_ids)

    # 3. Compute PageRank
    pr = sub.pagerank(weights=weight)

    # 4. Map back: {original_node_name : score}
    return {sub.vs[i]["name"]: pr[i] for i in range(sub.vcount())}

def community_central_genes_by_num(G, community_nodes, weight="weight", top_n=20):
    """
    Return the `top_n` community nodes with the highest within-community
    (weighted) degree z-score, best first.
    """
    induced = G.subgraph(set(community_nodes)).copy()
    members = list(induced)

    # within-community (weighted) degree, standardised across the community
    degrees = np.array(
        [induced.degree(node, weight=weight) for node in members], dtype=float
    )
    z_by_node = dict(zip(members, zscore(degrees)))

    # stable descending sort by z-score, then truncate
    ranked = sorted(members, key=z_by_node.__getitem__, reverse=True)
    return ranked[:top_n]

def community_central_genes_by_score(G, community_nodes, weight="weight",score_cap = 1):
    """
    Return the community nodes whose within-community (weighted) degree
    z-score is >= `score_cap`, ordered from highest to lowest z-score.
    """
    induced = G.subgraph(set(community_nodes)).copy()
    members = list(induced)

    # within-community (weighted) degree, standardised across the community
    degrees = np.array(
        [induced.degree(node, weight=weight) for node in members], dtype=float
    )
    z_by_node = dict(zip(members, zscore(degrees)))

    # stable descending sort by z-score, then apply the threshold
    ranked = sorted(members, key=z_by_node.__getitem__, reverse=True)
    return [node for node in ranked if z_by_node[node] >= score_cap]

def community_central_genes_by_pct(G, community_nodes, weight="weight",pct = 0.3):
    """
    Return the top `pct` fraction (rounded down) of community nodes ranked by
    within-community (weighted) degree z-score, best first.
    """
    induced = G.subgraph(set(community_nodes)).copy()
    members = list(induced)

    # within-community (weighted) degree, standardised across the community
    degrees = np.array(
        [induced.degree(node, weight=weight) for node in members], dtype=float
    )
    z_by_node = dict(zip(members, zscore(degrees)))

    # stable descending sort by z-score, then keep the leading fraction
    ranked = sorted(members, key=z_by_node.__getitem__, reverse=True)
    n_keep = int(len(ranked) * pct)
    return ranked[:n_keep]

def betweenness_for_community(
    g: ig.Graph,
    community_nodes,
    weight: str | None = "weight"
):
    """
    Compute rescaled betweenness centrality for nodes inside a community.

    Parameters
    ----------
    g : ig.Graph
        Full igraph graph (must have vs['name']).
    community_nodes : list
        Node names belonging to this community (labels from g.vs['name']).
    weight : str or None
        Edge weight attribute name. None = unweighted betweenness.

    Returns
    -------
    dict {node_name : betweenness_score}
        Scores are the raw betweenness values of the induced community
        subgraph passed through ig.rescale(..., clamp=True).

    Raises
    ------
    KeyError
        If a community node is not a vertex name of `g`.
    """
    # Resolve names to vertex ids explicitly. igraph treats integers as vertex
    # ids rather than names, so passing integer node labels straight to
    # subgraph() could select the wrong vertices.
    name_to_index = {name: idx for idx, name in enumerate(g.vs["name"])}
    try:
        vertex_ids = [name_to_index[node] for node in community_nodes]
    except KeyError as err:
        raise KeyError(f"Node {err.args[0]!r} is not a vertex name of g") from err

    # Build subgraph of the community
    sub = g.subgraph(vertex_ids)

    # Compute betweenness inside the community only, then rescale
    bt = sub.betweenness(weights=weight)
    bt_normalized = ig.rescale(bt, clamp=True)

    # Return mapping {node_name : score}
    return {sub.vs[i]["name"]: bt_normalized[i] for i in range(sub.vcount())}


def zscores_dict(G, community_nodes, weight="weight"):
    """
    Return {node: z-score} of the within-community (weighted) degree for every
    node of the subgraph induced by `community_nodes`.
    """
    induced = G.subgraph(set(community_nodes)).copy()
    members = list(induced)

    # within-community (weighted) degree, standardised across the community
    degrees = np.array(
        [induced.degree(node, weight=weight) for node in members], dtype=float
    )
    return dict(zip(members, zscore(degrees)))


## Betweenness Scores

### Load Betweenness Scores

In [None]:
# Load the previously saved betweenness scores: for each disease, a list with
# one {node: score} dict per community (keys are strings after JSON loading).
with open("../output/disease_to_all_betweenness_scores.json", "r") as f:
    disease_to_all_betweenness_scores = json.load(f)

In [None]:
# Display the loaded scores
disease_to_all_betweenness_scores

### Compute Betweenness Scores (very time consuming)

In [None]:
# BETWEENNESS_SCORE_CAP = 0.1

# disease_to_all_betweenness_scores = {}
# for disease in LIST_OF_DISEASES:
#     all_betweenness_scores = []
#     ds_igraph = nx_to_igraph(all_info[disease]['graph'])
#     for community in all_info[disease]['communities']:
#         betweenness_scores = betweenness_for_community(ds_igraph,community)
#         print(len(community),len([u for u in community if betweenness_scores[u] >= BETWEENNESS_SCORE_CAP]))
#         all_betweenness_scores.append(betweenness_scores)
#     disease_to_all_betweenness_scores[disease] = all_betweenness_scores

### Convert Indices to Integers

In [None]:
# Convert the string keys of every score dict back to integer node indices
# (each disease's list is updated in place)
for score_dicts in disease_to_all_betweenness_scores.values():
    score_dicts[:] = [
        {int(key): value for key, value in scores.items()}
        for scores in score_dicts
    ]

#### Save the Resulting Dictionary

In [None]:
# # Save the resulting dictionary
# with open("../output/disease_to_all_betweenness_scores.json", "w") as f:
#     json.dump(disease_to_all_betweenness_scores, f, indent=2)

## Compute Degree Z-score 

In [None]:
# Within-community degree z-scores: per disease, one {node: z} dict per community
disease_to_all_zscores = {
    disease: [
        zscores_dict(all_info[disease]['graph'], community)
        for community in all_info[disease]['communities']
    ]
    for disease in LIST_OF_DISEASES
}

In [None]:
SCORE_CAP = 0

# For every community keep only the nodes whose within-community degree
# z-score is at least SCORE_CAP; stored under 'communities_selected'.
for disease in LIST_OF_DISEASES:
    all_zscores = disease_to_all_zscores[disease]
    all_info[disease]['communities_selected'] = [
        [u for u in members if all_zscores[idx][u] >= SCORE_CAP]
        for idx, members in enumerate(all_info[disease]['communities'])
    ]

# Disease-specific Genes Comparison

### Genes Similarity Graph Between Diseases

In [None]:
def num_common_genes(d1,d2):
    """Return how many DGIDB genes diseases `d1` and `d2` have in common."""
    shared = all_info[d1]['DGIDB_genes'].intersection(all_info[d2]['DGIDB_genes'])
    return len(shared)
def num_common_drugs(d1,d2):
    """Return how many DGIDB drugs diseases `d1` and `d2` have in common."""
    shared = all_info[d1]['DGIDB_drugs'].intersection(all_info[d2]['DGIDB_drugs'])
    return len(shared)

In [None]:
# Graph with one node per disease (excluding "NONE"); each edge is weighted by
# the number of DGIDB genes the two diseases share.
G = nx.Graph()

# add nodes with size attribute (number of DGIDB genes of the disease)
for disease in LIST_OF_DISEASES_CLEAN:
    G.add_node(disease, size=len(all_info[disease]['DGIDB_genes']))

# add weighted edges for all pairs
for i in range(len(LIST_OF_DISEASES_CLEAN)):
    for j in range(i + 1, len(LIST_OF_DISEASES_CLEAN)):
        w = num_common_genes(LIST_OF_DISEASES_CLEAN[i], LIST_OF_DISEASES_CLEAN[j])
        G.add_edge(LIST_OF_DISEASES_CLEAN[i], LIST_OF_DISEASES_CLEAN[j], weight=w)

# -----------------------------
# Draw the graph
# -----------------------------
# fixed seed so the layout is the same on every run
pos = nx.spring_layout(G, seed=42)

node_sizes = [G.nodes[n]['size'] for n in G.nodes]
edge_weights = [G[u][v]['weight'] for u, v in G.edges]

plt.figure(figsize=(6, 6))

# min-max scale the edge weights to widths in [0.5, 3.5]; the 1e-9 avoids a
# division by zero when all weights are equal
w = edge_weights
w_min, w_max = min(w), max(w)
edge_widths = [(x - w_min) / (w_max - w_min + 1e-9) * 3 + 0.5 for x in w]

nx.draw(
    G,
    pos,
    with_labels=True,
    node_size=node_sizes,
    width=edge_widths,          # edge width proportional to similarity
    edge_color='gray',
    font_size=10
)

# second label (the DGIDB gene count) drawn slightly above each node
label_pos = {n: (x, y + 0.07) for n, (x, y) in pos.items()}
nx.draw_networkx_labels(G, label_pos, labels={n: G.nodes[n]['size'] for n in G.nodes})
# edge labels show the number of shared DGIDB genes
nx.draw_networkx_edge_labels(G, pos, edge_labels=nx.get_edge_attributes(G, 'weight'))

plt.show()

### Community Detection Similarity

In [None]:
def draw_community_bipartite(comms1, comms2, jaccard_threshold=0.0, cap=0, graph_label = ""):
    """
    Draw a bipartite graph comparing two community detection results.

    Parameters
    ----------
    comms1 : list of iterables
        Communities from run 1 (each element is an iterable of node IDs).
    comms2 : list of iterables
        Communities from run 2.
    jaccard_threshold : float, optional
        Only draw edges with Jaccard >= threshold.
    cap : int, optional
        Only include communities with size >= cap.
    graph_label : str, optional
        Title of the figure.

    Returns
    -------
    nx.Graph or None
        The bipartite graph that was drawn. Nodes are named "A_<i>" (comms1)
        and "B_<j>" (comms2), where i/j are positions in the input lists, and
        carry `bipartite`, `comm_idx` and `size` attributes; edges carry the
        Jaccard score as `weight`. None is returned (after printing a message)
        when no community passes the cap filter.
    """
    G = nx.Graph()

    # --- filter communities by cap (original list positions are kept) ---
    filtered1 = [(i, c) for i, c in enumerate(comms1) if len(c) >= cap]
    filtered2 = [(j, c) for j, c in enumerate(comms2) if len(c) >= cap]

    # node names
    left_nodes  = [f"A_{i}" for i, _ in filtered1]
    right_nodes = [f"B_{j}" for j, _ in filtered2]

    # add nodes with size attribute
    for i, c in filtered1:
        G.add_node(f"A_{i}", bipartite=0, comm_idx=i, size=len(c))

    for j, c in filtered2:
        G.add_node(f"B_{j}", bipartite=1, comm_idx=j, size=len(c))

    # add edges with jaccard weight
    for i, c1 in filtered1:
        for j, c2 in filtered2:
            jacc = jaccard(c1, c2)
            if jacc >= jaccard_threshold:
                G.add_edge(f"A_{i}", f"B_{j}", weight=jacc)

    # if there are no nodes at all, nothing to draw
    if G.number_of_nodes() == 0:
        print("No communities passed the cap filter; nothing to draw.")
        return

    # --- positions: two vertical columns (x=0 for comms1, x=1 for comms2) ---
    pos = {}
    for k, n in enumerate(left_nodes):
        if n in G:
            pos[n] = (0, k)
    for k, n in enumerate(right_nodes):
        if n in G:
            pos[n] = (1, k)

    # --- node sizes (min-max scaled community size, in [100, 1000]) ---
    sizes = [G.nodes[n]["size"] for n in G.nodes]
    s_min, s_max = min(sizes), max(sizes)
    node_sizes = [100 + 900 * (s - s_min) / (s_max - s_min + 1e-9) for s in sizes]

    # --- normalized edge widths from Jaccard weights ---
    weights = [d["weight"] for _, _, d in G.edges(data=True)]
    if weights:
        w_min, w_max = min(weights), max(weights)
        if w_max > w_min:
            edge_widths = [
                0.5 + 4 * (w - w_min) / (w_max - w_min)  # thickness in [0.5, 4.5]
                for w in weights
            ]
        else:
            # all weights equal: use one constant width
            edge_widths = [2.0] * len(weights)
    else:
        edge_widths = []

    # labels: A0, A1, ... for comms1; B0, B1, ... for comms2
    labels = {
        n: f"{'A' if G.nodes[n]['bipartite']==0 else 'B'}{G.nodes[n]['comm_idx']}"
        for n in G.nodes
    }

    plt.figure(figsize=(8, 6))
    nx.draw(
        G,
        pos,
        with_labels=False,
        node_size=node_sizes,
        width=edge_widths,
        edge_color="gray",
    )
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=9)

    # edge labels = Jaccard scores
    edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)}

    ax = plt.gca()
    offset = 0.04  # tweak this to move labels further/closer

    for (u, v), label in edge_labels.items():
        x1, y1 = pos[u]
        x2, y2 = pos[v]

        # midpoint of the edge
        mx, my = (x1 + x2) / 2.0, (y1 + y2) / 2.0

        # perpendicular unit vector to the edge
        dx, dy = x2 - x1, y2 - y1
        length = (dx**2 + dy**2) ** 0.5 or 1.0  # fall back to 1.0 for a zero-length edge
        px, py = -dy / length, dx / length  # (dx,dy) rotated by +90° (counter-clockwise)

        # offset midpoint along perpendicular
        lx, ly = mx + offset * px, my + offset * py

        ax.text(lx, ly, label,
                fontsize = 10,
                ha="center", va="center")

    plt.axis("off")
    plt.title(graph_label)
    plt.tight_layout()  
    plt.show()
    
    return G

In [None]:
def index_to_ncbi(comms,index_to_ncbi):
    """
    Translate every member of every community through the `index_to_ncbi`
    mapping. Members missing from the mapping become None.
    """
    lookup = index_to_ncbi.get
    return [[lookup(member) for member in community] for community in comms]

In [None]:
# Translate both community lists of every disease from node indices to the
# gene identifiers stored in 'index_to_gene_distinct'
for disease in LIST_OF_DISEASES:
    disease_info = all_info[disease]
    to_gene = disease_info['index_to_gene_distinct']
    disease_info['communities_ncbi'] = index_to_ncbi(
        disease_info['communities'], to_gene
    )
    disease_info['communities_selected_ncbi'] = index_to_ncbi(
        disease_info['communities_selected'], to_gene
    )

In [None]:
def draw_helper(left,right,jaccard_threshold, cap, code = "ncbi"):
    """
    Draw the community comparison of diseases `left` and `right`.

    `code` selects which community list is compared: the entry
    ``communities_<code>`` of each disease in `all_info`. Returns whatever
    draw_community_bipartite returns.
    """
    community_key = f'communities_{code}'
    left_comms = all_info[left][community_key]
    right_comms = all_info[right][community_key]
    return draw_community_bipartite(
        left_comms,
        right_comms,
        jaccard_threshold=jaccard_threshold,
        cap=cap,
        graph_label=f"{left} v.s. {right}",
    )

In [None]:
# Number of DGIDB genes shared by LEUKEMIA and SCHIZOPHRENIA
num_common_genes("LEUKEMIA","SCHIZOPHRENIA")

In [None]:
# Print the size of every selected community, one block per disease
for disease in LIST_OF_DISEASES_CLEAN:
    print(f"{disease}: ")
    for c in all_info[disease]['communities_selected']:
        print(len(c))
    print()

In [None]:
# Compare every disease's selected communities with those of "NONE" and keep
# the resulting bipartite graph per disease.
# NOTE(review): code="selected" compares the 'communities_selected' lists,
# which hold node indices (not the *_ncbi gene identifiers). That is only
# meaningful if both datasets use the same index mapping — confirm, or use
# code="selected_ncbi".
disease_to_msigdb_comparison = {}
for disease in LIST_OF_DISEASES_CLEAN:
    disease_to_msigdb_comparison[disease] = draw_helper(disease,"NONE",jaccard_threshold=0.1, cap = 0, code = "selected")

In [None]:
# Pairwise comparisons of selected communities with at least 100 members.
# NOTE(review): as with the comparison against "NONE", code="selected"
# compares node indices rather than gene identifiers — confirm the index
# mappings of the two diseases agree.
draw_helper("LEUKEMIA","SCHIZOPHRENIA",jaccard_threshold=0.1, cap = 100, code = "selected")

In [None]:
draw_helper("LEUKEMIA","BIPOLAR",jaccard_threshold=0.2, cap = 100, code = "selected")

In [None]:
draw_helper("BIPOLAR","SCHIZOPHRENIA",jaccard_threshold=0.05, cap = 100, code = "selected")

### DGIDB Genes Analysis

In [None]:
def DGIDB_count(c,DGIDB_genes_ncbi):
    """
    Return the set of members of `c` that also appear in `DGIDB_genes_ncbi`.

    Despite the name this returns the shared elements, not their count.
    """
    return set(c).intersection(DGIDB_genes_ncbi)

In [None]:
def print_comms_DGIDB_num(disease):
    """
    Print one line per community of `disease`:
    ``Community <i>: <community size>/<number of its members that are DGIDB genes>``.
    """
    disease_info = all_info[disease]
    for idx, members in enumerate(disease_info['communities_ncbi']):
        n_dgidb = len(DGIDB_count(members, disease_info['DGIDB_genes']))
        print(f"Community {idx}: {len(members)}/{n_dgidb}")

In [None]:
# Community size vs. number of DGIDB genes for every LEUKEMIA community
print_comms_DGIDB_num("LEUKEMIA")

In [None]:
# Threshold on the loaded betweenness scores used when listing
# high-betweenness nodes
BETWEENNESS_SCORE_CAP = 0.1

In [None]:
# For every community print its mean betweenness score and the scores of its
# DGIDB genes, highest first
for disease in LIST_OF_DISEASES:
    info = all_info[disease]
    gene_index = info['gene_to_index_distinct']

    print(f"\033[1m{disease}: \033[0m")
    scored_comms = zip(info['communities_ncbi'], disease_to_all_betweenness_scores[disease])
    for i, (members, scores) in enumerate(scored_comms):
        print(f"Community {i}:")
        print(f"Average score: {sum(scores.values()) / len(scores)}")
        # scores are keyed by node index, so translate each gene first
        dgidb_scores = {
            gene: scores[gene_index[gene]]
            for gene in DGIDB_count(members, info['DGIDB_genes'])
        }
        ranked = sorted(dgidb_scores.items(), key=lambda item: item[1], reverse=True)
        print(dict(ranked))
        print()

In [None]:
# DGIDB genes contained in LEUKEMIA community 1
DGIDB_count(all_info['LEUKEMIA']['communities_ncbi'][1],all_info['LEUKEMIA']['DGIDB_genes'])

In [None]:
# Node indices of LEUKEMIA community 1 whose betweenness score reaches the cap
[u for u,v in disease_to_all_betweenness_scores['LEUKEMIA'][1].items() if v >= BETWEENNESS_SCORE_CAP]

### Community Detection Similarity (Term-wise)

In [None]:
def _community_term_sets(df, community_col, term_col):
    """Return a Series mapping each community id to the set of its terms."""
    return (
        df[[community_col, term_col]]
        .drop_duplicates()
        .groupby(community_col)[term_col]
        .agg(set)
    )


def _community_sizes(df, community_col):
    """Return {community id: value of the "Community Size" column}."""
    return (
        df[[community_col, "Community Size"]]
        .drop_duplicates()
        .set_index(community_col)["Community Size"]
        .to_dict()
    )


def _best_term_matches(source_sets, target_sets, min_jaccard):
    """
    Find, for each source community, the target community with the highest
    Jaccard score. Returns {source_id: (target_id, score)}; a source is left
    out when its best score is 0 or below `min_jaccard`. Ties keep the first
    target encountered.
    """
    matches = {}
    for source_id, source_terms in source_sets.items():
        best_id, best_score = None, 0.0
        for target_id, target_terms in target_sets.items():
            score = jaccard(source_terms, target_terms)
            if score > best_score:
                best_id, best_score = target_id, score
        if best_id is not None and best_score >= min_jaccard:
            matches[source_id] = (best_id, best_score)
    return matches


def plot_comm_jaccard_bipartite_bestboth(
    df1,
    df2,
    community_col="community_id",
    term_col="term",
    min_jaccard=0.0,
):
    """
    Plot the best term-wise matches between the communities of two tables.

    Each community is represented by its set of terms. An edge is drawn from a
    community to its best Jaccard match in the other table, in both
    directions (the union of the two sets of best matches).

    Parameters
    ----------
    df1, df2 : pd.DataFrame
        Tables with the columns `community_col`, `term_col` and
        "Community Size".
    community_col : str
        Column holding the community identifier.
    term_col : str
        Column holding the term.
    min_jaccard : float
        Best matches with a score below this value are not drawn.

    Returns
    -------
    nx.Graph
        Nodes "A_<id>" (df1) and "B_<id>" (df2) with `side`, `comm_id` and
        `size` attributes; edges carry the Jaccard score as `weight`.
    """
    # --- 1. Build sets of terms per community ---
    comm1 = _community_term_sets(df1, community_col, term_col)
    comm2 = _community_term_sets(df2, community_col, term_col)

    # --- 2./3. Best matches in both directions ---
    forward = _best_term_matches(comm1, comm2, min_jaccard)    # df1 -> df2
    backward = _best_term_matches(comm2, comm1, min_jaccard)   # df2 -> df1

    # Union of edges from both directions, keyed (df1 id, df2 id).
    # Jaccard is symmetric, so an edge found both ways has the same score.
    all_edges = {(cid1, cid2): score for cid1, (cid2, score) in forward.items()}
    all_edges.update(
        {(cid1, cid2): score for cid2, (cid1, score) in backward.items()}
    )

    # --- 4. Build graph ---
    G = nx.Graph()

    # Community sizes come from the "Community Size" column of each table
    size1 = _community_sizes(df1, community_col)
    size2 = _community_sizes(df2, community_col)

    # Add nodes for df1 (left)
    for cid in comm1.index:
        G.add_node(f"A_{cid}", side="A", comm_id=cid, size=size1[cid])

    # Add nodes for df2 (right)
    for cid in comm2.index:
        G.add_node(f"B_{cid}", side="B", comm_id=cid, size=size2[cid])

    # Add edges from union of best matches (both ways)
    for (cid1, cid2), score in all_edges.items():
        G.add_edge(f"A_{cid1}", f"B_{cid2}", weight=score)

    # --- 5. Layout: two columns, each ordered by community id ---
    left_nodes  = sorted(
        [n for n, d in G.nodes(data=True) if d["side"] == "A"],
        key=lambda n: G.nodes[n]["comm_id"],
    )
    right_nodes = sorted(
        [n for n, d in G.nodes(data=True) if d["side"] == "B"],
        key=lambda n: G.nodes[n]["comm_id"],
    )

    pos = {}
    for i, n in enumerate(left_nodes):
        pos[n] = (0.0, i)
    for i, n in enumerate(right_nodes):
        pos[n] = (1.0, i)

    # Node size ∝ community size
    def size_to_marker(s):
        return 3 * s

    node_sizes = {n: size_to_marker(G.nodes[n]["size"]) for n in G.nodes()}

    # Edge width ∝ Jaccard
    edge_weights = {(u, v): d["weight"] for u, v, d in G.edges(data=True)}
    edge_widths  = [2 + 8 * w for w in edge_weights.values()]

    # Node labels = community IDs
    labels = {n: G.nodes[n]["comm_id"] for n in G.nodes()}

    # --- 6. Plot ---
    plt.figure(figsize=(10, max(4, 0.5 * max(len(left_nodes), len(right_nodes)))))

    nx.draw_networkx_nodes(
        G, pos,
        nodelist=left_nodes,
        node_size=[node_sizes[n] for n in left_nodes],
        node_color="skyblue",
    )
    nx.draw_networkx_nodes(
        G, pos,
        nodelist=right_nodes,
        node_size=[node_sizes[n] for n in right_nodes],
        node_color="salmon",
    )

    nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.5)

    # Edge labels (Jaccard scores) are placed by hand so that labels of
    # crossing edges do not all pile up on the same spot.
    for i, (u, v) in enumerate(G.edges()):
        x1, y1 = pos[u]
        x2, y2 = pos[v]

        # Midpoint of the edge
        xm = (x1 + x2) / 2.0
        ym = (y1 + y2) / 2.0

        # Perpendicular unit vector to the edge
        dx = x2 - x1
        dy = y2 - y1
        length = math.hypot(dx, dy) or 1.0
        px = -dy / length
        py = dx / length

        # Spread labels along the perpendicular direction
        # pattern: -2, -1, 0, 1, 2, -2, -1, ...
        k = (i % 5) - 2
        offset = 0.06 * k   # tune this if you want more/less separation

        xl = xm + offset * px
        yl = ym + offset * py

        w = edge_weights[(u, v)]
        label = f"{w:.2f}"

        plt.text(
            xl,
            yl,
            label,
            fontsize=7,
            ha="center",
            va="center",
            bbox=dict(boxstyle="round,pad=0.1", fc="white", ec="none", alpha=0.7),
        )

    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

    plt.axis("off")
    plt.tight_layout()
    plt.show()

    return G


In [None]:
def best_jaccard_match(
    df_source,
    df_target,
    source_community,
    *,
    community_col="community_id",
    term_col="term",
):
    """
    Score one community of df_source against every community of df_target.

    Communities are compared through the Jaccard similarity of their term
    sets. Returns a tuple (best target community, its score, Series of all
    scores sorted from highest to lowest). Raises ValueError when the source
    community has no terms in df_source.
    """
    # Term set of the requested source community
    in_source = df_source[community_col] == source_community
    source_terms = set(df_source.loc[in_source, term_col])
    if len(source_terms) == 0:
        raise ValueError(f"No terms found for community {source_community!r} in df_source")

    # community id -> set of terms, for every community of the target table
    unique_pairs = df_target[[community_col, term_col]].drop_duplicates()
    target_sets = unique_pairs.groupby(community_col)[term_col].agg(set)

    # Jaccard score of the source community against each target community
    scores = target_sets.apply(lambda term_set: jaccard(source_terms, term_set))

    # Best match and the full ranking
    best_community = scores.idxmax()
    ranking = scores.sort_values(ascending=False)
    return best_community, scores.loc[best_community], ranking

def best_jaccard_analysis_left_to_right(left,right):
    """
    Print, for every community of `left`, its best term-wise match in `right`.

    For each value of the "Community Index" column of disease_to_df[left] this
    prints the best matching community of `right`, the Jaccard score, the
    number of DGIDB genes in both communities and in their intersection, and
    the five highest scores.

    Parameters
    ----------
    left, right : str
        Disease names (keys of disease_to_df and all_info).
    """
    left_df = disease_to_df[left]
    right_df = disease_to_df[right]

    # `comm_id` instead of `id`, which would shadow the builtin
    for comm_id in left_df["Community Index"].unique():
        best_comm, best_jaccard, all_scores = best_jaccard_match(
            df_source=left_df,
            df_target=right_df,
            source_community=comm_id,
            community_col="Community Index",  # change if your column name is different
            term_col="Term",
        )
        print(f"Community {comm_id}:")
        print("Best match in right:", best_comm)
        print("Best Jaccard score:", best_jaccard)

        # NOTE(review): "Community Index" values are used as positions in
        # 'communities_ncbi'; this assumes the CSV indices line up with the
        # size-filtered community lists — confirm.
        left_DGIDB = DGIDB_count(all_info[left]['communities_ncbi'][comm_id],all_info[left]['DGIDB_genes'])
        right_DGIDB = DGIDB_count(all_info[right]['communities_ncbi'][best_comm],all_info[right]['DGIDB_genes'])
        print(f"Number of DGIDB terms left: {len(left_DGIDB)}")
        print(f"Number of DGIDB terms right: {len(right_DGIDB)}")
        print(f"Number of common DGIDB terms (with best match): {len(left_DGIDB & right_DGIDB)}")

        # The heading now sits directly above the scores it announces
        print("All scores (sorted):")
        print(all_scores.head())
        print()

In [None]:
# Term-wise best-match plot between the LEUKEMIA and NONE communities
plot_comm_jaccard_bipartite_bestboth(
    disease_to_df["LEUKEMIA"],
    disease_to_df["NONE"],
    community_col="Community Index",
    term_col="Term",
    min_jaccard=0.0,
)

In [None]:
# Display the important-terms table of NONE
disease_to_df['NONE']

In [None]:
# Best term-wise match in NONE for every BIPOLAR community
best_jaccard_analysis_left_to_right("BIPOLAR","NONE")

### L10, S11, B11, N12 Analysis

In [None]:
# Selected communities (as gene identifiers) picked for a closer comparison:
# LEUKEMIA #10, SCHIZOPHRENIA #11, BIPOLAR #11 and NONE #12
L10 = all_info['LEUKEMIA']['communities_selected_ncbi'][10]
S11 = all_info['SCHIZOPHRENIA']['communities_selected_ncbi'][11]
B11 = all_info['BIPOLAR']['communities_selected_ncbi'][11]
N12 = all_info['NONE']['communities_selected_ncbi'][12]

In [None]:
# Pairwise Jaccard similarity of the gene sets of those communities
print(jaccard(L10,S11))
print(jaccard(L10,B11))
print(jaccard(S11,B11))
print(jaccard(S11,N12))

### TBD

In [None]:
# Number of rows (important terms) per community index, for each disease
disease_to_df['BIPOLAR']["Community Index"].value_counts()

In [None]:
disease_to_df['SCHIZOPHRENIA']["Community Index"].value_counts()

In [None]:
disease_to_df['LEUKEMIA']["Community Index"].value_counts()

In [None]:
# Display the important-terms table of LEUKEMIA
disease_to_df['LEUKEMIA']

### Preliminary Gene Comparison (to introduce a new disease)

In [None]:
# Candidate disease whose DGIDB genes are compared against the existing ones
NEW_DISEASE = "PARKINSON"

In [None]:
# NOTE(review): the chained assignment also creates a global DGIDB_DIRECTORY;
# this looks like a leftover and is probably unintended — confirm.
NEW_DGIDB_DIRECTORY = DGIDB_DIRECTORY = f"../../Gen_Hypergraph/output/DGIDB_{NEW_DISEASE}/"

In [None]:
# Load the DGIDB gene and drug index mappings of the new disease
with open(NEW_DGIDB_DIRECTORY + f"gene_to_index_{NEW_DISEASE}.json", "r") as file:
    new_DGIDB_gene_to_index = json.load(file)
with open(NEW_DGIDB_DIRECTORY + f"drug_to_index_{NEW_DISEASE}.json", "r") as file:
    new_DGIDB_drug_to_index = json.load(file)

In [None]:
# Only the keys (gene / drug identifiers) are needed for the set comparisons
new_genes = set(new_DGIDB_gene_to_index.keys())
new_drugs = set(new_DGIDB_drug_to_index.keys())

In [None]:
print(new_genes)

In [None]:
# Number of DGIDB genes shared by the new disease and BIPOLAR
len(new_genes & all_info["BIPOLAR"]['DGIDB_genes']) 

In [None]:
# Hard-coded list of integer gene identifiers.
# NOTE(review): presumably the genes associated with clozapine — confirm the
# source of this list.
CLOZA_genes = [     1813,      6714,     64816,      9429,      1543,        12,
            3107,       262,      2952,      3356,     25970,       718,
            7365,      4363,      3127,      2904,      1562,     54658,
               1,         3,        11,      2740,      2675,      2668,
            4524,      1557,      4907,      3350,      4781,      3125,
              24,       215,       217,      3359,      2944,      1565,
             265, 100302144,      2166,      1576,      1812,      3363,
       100302251,        26,      1128,     28755,        22,       218,
            5294,      3358,     11201,      2194,      3558,      1993,
             887,      5565,      3115,      4160,        23,      6505,
            3952,      3699,     23216,      5617,      2641,     84062,
            3106,     50852,      9135,      7957,      2289,      6532,
              10,       216,      6531,      1544,      5020,        13,
           10951,      3727,      3760,      1559,     54657,      5563,
            4915,      3176,      4835,         2,       214,     59340,
            1815,      3269]

In [None]:
# Convert the integer identifiers to strings so they can be intersected with
# the string gene identifiers used by the DGIDB gene sets
CLOZA_genes = set([str(i) for i in CLOZA_genes])

In [None]:
print(CLOZA_genes)

In [None]:
# Common genes in communities 8 (both diseases) that are also in CLOZA_genes
print(len(CLOZA_genes & DGIDB_count(all_info['BIPOLAR']['communities_ncbi'][8],all_info['BIPOLAR']['DGIDB_genes']) & DGIDB_count(all_info['SCHIZOPHRENIA']['communities_ncbi'][8],all_info['SCHIZOPHRENIA']['DGIDB_genes'])))

# Community Description

In [None]:
def create_communities_attributes_df(disease):
    """
    Build a per-community summary table for `disease`.

    The returned DataFrame is indexed by community position and has the
    columns:
    - "msigdb_connection_weight": sum of the edge weights attached to the
      community in disease_to_msigdb_comparison[disease]
    - "num_dgidb_genes": number of the community's genes that are DGIDB genes
    - "avg_btn_score": mean betweenness score over the community's score dict
    - "dgidb_genes_btn_score": {gene: betweenness score} of its DGIDB genes,
      ordered from highest to lowest score
    """
    info = all_info[disease]
    comms_ncbi = info['communities_ncbi']
    dgidb_genes = info['DGIDB_genes']
    gene_index = info['gene_to_index_distinct']
    comparison_graph = disease_to_msigdb_comparison[disease]
    betweenness_scores = disease_to_all_betweenness_scores[disease]

    # Total weight of the comparison edges per community; the community
    # position is parsed from the first endpoint's name ("<side>_<position>").
    msigdb_weight = dict.fromkeys(range(len(info['communities'])), 0)
    for first_node, _, data in comparison_graph.edges(data=True):
        msigdb_weight[int(first_node.split("_")[1])] += data["weight"]

    # Number of DGIDB genes per community
    num_dgidb = {
        idx: len(DGIDB_count(members, dgidb_genes))
        for idx, members in enumerate(comms_ncbi)
    }

    # Mean betweenness per community and the scores of its DGIDB genes
    avg_score = {}
    dgidb_btn_scores = {}
    for idx, (members, scores) in enumerate(zip(comms_ncbi, betweenness_scores)):
        avg_score[idx] = sum(scores.values()) / len(scores)
        # scores are keyed by node index, so translate each gene first
        per_gene = {
            gene: scores[gene_index[gene]]
            for gene in DGIDB_count(members, dgidb_genes)
        }
        dgidb_btn_scores[idx] = dict(
            sorted(per_gene.items(), key=lambda item: item[1], reverse=True)
        )

    return pd.DataFrame({
        "msigdb_connection_weight": msigdb_weight,
        "num_dgidb_genes": num_dgidb,
        "avg_btn_score": avg_score,
        "dgidb_genes_btn_score": dgidb_btn_scores,
    })
    

In [None]:
# Per-community summary table for SCHIZOPHRENIA
p = create_communities_attributes_df('SCHIZOPHRENIA')

In [None]:
# Communities whose total comparison weight exceeds 0.60
p[p['msigdb_connection_weight'] > 0.60]

# Filtering

In [None]:
# components = list(nx.connected_components(G))
# for i, comp in enumerate(components):
#     print(f"Component {i} ({len(comp)} nodes):")
#     print(comp)