# Preprocessing Notebook

This notebook contains **all preprocessing steps** required to build the heterogeneous circRNA-disease-gene graph used throughout the project. It does *not* train models or perform inference; instead, it prepares the full graph data that is later consumed by the R-GCN (Stage A) and the evidential head (Stage B) in the training notebook.

The preprocessing pipeline implements the graph construction procedure described in Section 3 (Methodology) of the paper. Specifically, this notebook:

- Loads circRNA identifiers and putative spliced sequences (hg19 FASTA) from **circBase**  
- Loads experimentally validated circRNA-disease associations from **circRNADisease v2.0**  
- Normalizes disease names to **Human Disease Ontology (DOID)** identifiers  
- Constructs node sets for circRNAs, diseases, and genes  
- Builds four relation types:  
  • **CC** - circRNA-circRNA similarity edges via Needleman-Wunsch alignment  
  • **CD** - curated circRNA-disease associations  
  • **CG** - circRNA-gene interactions  
  • **DD** - parent-child Disease Ontology links  
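- Computes the pairwise circRNA similarity matrix; the CC relation keeps every pair with positive similarity (top-K pruning is applied only in the visualization cell of this notebook)  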
- Produces unified indexing dictionaries and weighted edge lists used by all downstream models  

The final output of this notebook is a serialized graph dictionary `G` that is loaded by the training/inference notebook.


In [1]:
!pip install cupy

Collecting cupy
  Using cached cupy-13.6.0-cp312-cp312-linux_x86_64.whl
Collecting fastrlock>=0.5 (from cupy)
  Using cached fastrlock-0.8.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Using cached fastrlock-0.8.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl (53 kB)
Installing collected packages: fastrlock, cupy
Successfully installed cupy-13.6.0 fastrlock-0.8.3


In [2]:
import re
import pandas as pd
import cupy as cp
import numpy as np
import pickle
import networkx as nx
import matplotlib.pyplot as plt

In [3]:
# Parse Disease Ontology (names, synonyms, parents)
def parse_do(obo_path):
    """Parse a Disease Ontology OBO file into name and parent lookups.

    Args:
        obo_path: Path to the OBO file; it is read as UTF-8.

    Returns:
        A ``(name2id, parent)`` tuple. ``name2id`` maps each lower-cased
        term name or quoted synonym to a DOID string; when a string is both
        a primary name of one term and a synonym of another, the primary
        name wins. ``parent`` maps a DOID to the list of identifiers found
        on its ``is_a:`` lines.
    """
    primary = {}
    synonyms = {}
    parent = {}
    cur_id = None

    with open(obo_path, 'r', encoding='utf-8') as f:
        for ln in f:
            ln = ln.strip()
            if ln.startswith('['):
                # A new stanza ([Term], [Typedef], ...) begins: forget the
                # previous id so its lines are never attributed to the
                # preceding term.
                cur_id = None
            elif ln.startswith('id: '):
                ident = ln[len('id: '):].strip()
                # Only DOID terms are tracked; any other id disables the
                # stanza instead of silently reusing the last DOID.
                cur_id = ident if ident.startswith('DOID:') else None
            elif cur_id is None:
                continue
            elif ln.startswith('name: '):
                primary[ln[len('name: '):].lower()] = cur_id
            elif ln.startswith('synonym: '):
                m = re.search(r'"(.+?)"', ln)
                if m:
                    synonyms[m.group(1).lower()] = cur_id
            elif ln.startswith('is_a: '):
                # Drop the "! label" comment and any trailing "{...}"
                # modifier block, keeping only the identifier token.
                target = ln[len('is_a: '):].split(' ! ')[0].split()
                if target:
                    parent.setdefault(cur_id, []).append(target[0])

    # Primary names are merged last so they take precedence over synonyms.
    name2id = {**synonyms, **primary}
    return name2id, parent

# Extract unique diseases from dataset
def load_dataset_diseases(path):
    """Read the association CSV and collect its distinct disease names.

    Returns:
        ``(df, ds, ds_l)``: the full dataframe, the unique non-null values
        of the 'Disease Name' column in first-seen order, and the same
        names lower-cased (index-aligned with ``ds``).
    """
    table = pd.read_csv(path)
    disease_col = table['Disease Name']
    names = disease_col.dropna().unique().tolist()
    lowered = []
    for name in names:
        lowered.append(name.lower())
    return table, names, lowered

# Build final mapping (exact/synonym + manual)
def build_mapping(ds, ds_l, name2id, manual):
    """Combine automatic ontology matches with manual overrides.

    Args:
        ds: Original disease names.
        ds_l: Lower-cased names, index-aligned with ``ds``.
        name2id: Lower-cased ontology name/synonym -> DOID.
        manual: Exact-name -> DOID overrides; these replace automatic hits
            and are added even for names that are not in ``ds``.

    Returns:
        ``(mapping, unmapped)`` where ``unmapped`` lists the entries of
        ``ds`` that received no identifier.
    """
    resolved = {
        original: name2id[lowered]
        for original, lowered in zip(ds, ds_l)
        if lowered in name2id
    }
    resolved.update(manual)
    missing = [name for name in ds if name not in resolved]
    return resolved, missing


In [4]:
# Hand-curated overrides from dataset disease-name strings to DOID identifiers.
# Keys are compared against the original (case-sensitive) dataset names in
# build_mapping, where they are applied after the automatic name/synonym
# matches and replace them.
# NOTE(review): many process-like terms below ('Inflammatory response',
# 'Hypoxia', 'Senescence', 'Angiogenesis', 'Photoaging', ...) all share
# DOID:1596 -- confirm this is the intended catch-all term in the ontology
# release being used, since it merges them into a single disease node.
manual_disease_mappings = {
    
    'Cardiovascular disease' : 'DOID:1287',
    'Triple negative breast cancer': 'DOID:0060081',
    'Glioblastoma': 'DOID:3068',
    'Cervical carcinoma': 'DOID:4362',
    'Cutaneous squamous cell carcinoma': 'DOID:3151',
    'Heart failure': 'DOID:1037',
    'Acute kidney injury': 'DOID:10284',
    'Diabetic nephropathy': 'DOID:4467',
    'Spinal cord injury': 'DOID:936',
    'T-cell lymphoblastic lymphoma': 'DOID:5075',
    'Laryngeal cancer': 'DOID:5520',
    'Oral cancer': 'DOID:5523',
    'Influenza A virus infection': 'DOID:8469',
    'Diabetic foot ulcer': 'DOID:0050731',

    'Glioma': 'DOID:3069',
    'Ischemic stroke': 'DOID:2316',
    'Neuropathic pain': 'DOID:0060164',
    'Small cell lung cancer': 'DOID:5409',
    'Liver fibrosis': 'DOID:5082',
    'Myocardial ischemia': 'DOID:5844',
    'Epithelial ovarian cancer': 'DOID:2152',
    'Age-related cataract': 'DOID:0110067',
    'Hypertrophic scar': 'DOID:0060861',
    'Cardiac fibrosis': 'DOID:0110057',
    'Acute ischemic stroke': 'DOID:2316',
    'Hypopharyngeal carcinoma': 'DOID:9261',
    'Salivary adenoid cystic carcinoma': 'DOID:5916',
    'Keloid': 'DOID:8704',
    'Primary hepatic carcinoma': 'DOID:684',
    'Peripheral nerve injury': 'DOID:1852',
    'Thoracic aortic dissection': 'DOID:0080169',
    'Vascular endothelial dysfunction': 'DOID:178',
    'Inflammatory response': 'DOID:1596',
    'Oxidative injury': 'DOID:1596',
    'Hypoxia': 'DOID:1596',
    'Senescence': 'DOID:1596',

    'Esophageal squamous cell carcinoma': 'DOID:3748',
    'Hepatic fibrosis': 'DOID:5082',
    'Alcoholic liver disease': 'DOID:409',
    'Senile dementia': 'DOID:1307',
    'Intracerebral hemorrhage': 'DOID:2246',
    'Pulmonary arterial hypertension': 'DOID:2802',
    'Idiopathic membranous nephropathy': 'DOID:0080204',
    'Primary Sjogren\'s Syndrome': 'DOID:12894',
    'Diabetic cardiomyopathy': 'DOID:0050700',
    'Cardiac hypertrophy': 'DOID:11984',
    'Temporomandibular joint osteoarthritis': 'DOID:8398',
    'Facet joint osteoarthritis': 'DOID:8398',
    'Steroid-induced osteonecrosis of the femoral head': 'DOID:10159',
    'Oxygen-induced retinopathy': 'DOID:5679',
    'Kashin-Beck disease': 'DOID:0080367',
    'Scar': 'DOID:8704',
    'Early allograft dysfunction': 'DOID:0060043',
    'Leukoaraiosis': 'DOID:2316',
    'Craniocerebral trauma': 'DOID:0081292',
    'Laryngocarcinoma': 'DOID:2876',

    'Carcinogenesis': 'DOID:162',
    'Inflammation': 'DOID:1596',
    'Cellular senescence': 'DOID:1596',
    'Immunosenescence': 'DOID:1596',
    'Osteogenesis': 'DOID:1596',
    'Ischemia-reperfusion injury': 'DOID:1596',
    'Neuronal cell injury': 'DOID:1596',

    'Ischemic heart disease': 'DOID:3393',
    'Sepsis': 'DOID:9741',
    'Acute lung injury': 'DOID:0080788',
    'Postmenopausal osteoporosis': 'DOID:11476',
    'Non-obstructive azoospermia': 'DOID:11831',
    'Nasopharyngeal cancer': 'DOID:5522',
    'Papillary thyroid microcarcinoma': 'DOID:0060708',
    'Intervertebral disc disease': 'DOID:9205',
    'T cell acute lymphoblastic leukemia': 'DOID:0050746',
    'Anorectal malformation': 'DOID:0111285',
    'Axial spondyloarthritis': 'DOID:13195',
    'Radiation-induced liver disease': 'DOID:0080604',
    'Hypoxic-ischaemic encephalopathy': 'DOID:0080599',

    'Retinal neurodegeneration': 'DOID:5679',
    'Retinal vascular dysfunction': 'DOID:5679',
    'Oral mucosal melanoma': 'DOID:8923',
    'Diabetic kidney disease': 'DOID:4467',
    'Liver injury': 'DOID:409',
    'Myocardial injury': 'DOID:5844',
    'Kidney injury': 'DOID:1074',
    'Neuronal injury': 'DOID:863',
    'Fetal growth restriction': 'DOID:0111222',
    'Aortic valve calcification': 'DOID:0060319',
    'Laryngeal squamous cell cancer': 'DOID:2876',
    'Haemangiomas': 'DOID:687',
    'Angiogenesis': 'DOID:1596',
    'Neointimal hyperplasia': 'DOID:1287',
    'Intervertebral disk degeneration': 'DOID:9205',

    'T2DM with depression': 'DOID:9352',
    'Hypopharyngeal squamous cell carcinoma': 'DOID:9261',
    'Myocardial fibrosis': 'DOID:114',
    'Lead-induced neuronal cell apoptosis': 'DOID:1596',
    'Nonfunctioning pituitary adenomas': 'DOID:3829',
    'Acute Stanford type A aortic dissection': 'DOID:0080169',
    'Doxorubicin-Induced cardiotoxicity': 'DOID:114',
    'Trauma-induced osteonecrosis of femoral head': 'DOID:10159',
    'Sepsis-induced liver damage': 'DOID:409',
    'Cardiotoxicity': 'DOID:114',
    'Neurotoxicity and neuroinflammation': 'DOID:863',
    'Photoaging': 'DOID:1596',
    'Chondrocyte apoptosis': 'DOID:8398',
}


In [5]:
# Parse the ontology into name -> DOID and child -> parents lookups, load the
# association table, then map every dataset disease name to a DOID.
name2id, parent = parse_do("datasets/doid.obo.txt")
df, ds, ds_l = load_dataset_diseases("datasets/The circRNA-disease entries.csv")
# `unmapped` holds the dataset disease names that received no DOID; those rows
# are dropped in the next cell.
mapping, unmapped = build_mapping(ds, ds_l, name2id, manual_disease_mappings)


In [6]:
# Attach the DOID to every row and keep only rows whose disease was mapped.
df['Disease_ID'] = df['Disease Name'].map(mapping)
df = df[df['Disease_ID'].notna()].copy()
# NOTE(review): astype(str) turns missing gene symbols into the literal string
# 'nan', so the later dropna() calls on this column cannot remove them and a
# placeholder 'nan' gene may end up as a node -- confirm whether that is intended.
df['Gene Symbol'] = df['Gene Symbol'].astype(str).str.strip()


In [7]:
# Parse FASTA into {id: seq}
def parse_fasta(path):
    """Read a FASTA file into a ``{record_id: sequence}`` dictionary.

    The record id is the header text after '>' up to the first '|'.
    Sequence lines are stripped and concatenated. A record with an empty id
    is not stored, and a repeated id keeps the last sequence seen.
    """
    records = {}
    header = None
    chunks = []

    with open(path, 'r') as handle:
        for raw in handle:
            line = raw.strip()
            if not line.startswith('>'):
                chunks.append(line)
                continue
            # Header line: flush the record collected so far, then start anew.
            if header:
                records[header] = ''.join(chunks)
            header = line[1:].split('|')[0]
            chunks = []

    if header:
        records[header] = ''.join(chunks)
    return records

# Simple similarity heuristic for name matching
def name_sim(a, b):
    """Score how alike two identifier strings look.

    Identical strings score 1.0. Otherwise the score is 0.4 if any digit run
    of ``a`` also occurs as a digit run of ``b`` (else 0.0), plus 0.3 times
    the Jaccard overlap of the two character sets.
    """
    if a == b:
        return 1.0

    digits_b = re.findall(r'\d+', b)
    shares_number = any(tok in digits_b for tok in re.findall(r'\d+', a))
    bonus = 0.4 if shares_number else 0.0

    chars_a = set(a)
    chars_b = set(b)
    jaccard = len(chars_a & chars_b) / len(chars_a | chars_b)
    return bonus + 0.3 * jaccard

# Resolve circRNA names to FASTA IDs
def resolve_names(cnames, fasta):
    """Match dataset circRNA names to FASTA record ids.

    For each name the first strategy that hits a non-empty FASTA record wins:
      1. the name itself (confidence 1.0);
      2. the name with 'circRNA' replaced by 'circ' (confidence 0.8);
      3. ``hsa_circ_<n>`` built from the first digit run, zero-padded to
         seven digits first and unpadded second (confidence 0.6 when
         ``name_sim`` exceeds 0.6, otherwise 0.3).

    Returns:
        ``{name: (fasta_id, confidence)}``; names without a hit are absent.
    """
    def has_sequence(key):
        return key in fasta and fasta[key]

    resolved = {}
    for name in cnames:
        if has_sequence(name):
            resolved[name] = (name, 1.0)
            continue

        if 'circRNA' in name:
            alias = name.replace('circRNA', 'circ')
            if has_sequence(alias):
                resolved[name] = (alias, 0.8)
                continue

        digits = re.search(r'(\d+)', name)
        if not digits:
            continue
        number = digits.group(1)
        for candidate in (f'hsa_circ_{number.zfill(7)}', f'hsa_circ_{number}'):
            if has_sequence(candidate):
                score = name_sim(name, candidate)
                resolved[name] = (candidate, 0.6 if score > 0.6 else 0.3)
                break
    return resolved

# Generate high-confidence subset
def high_conf_subset(df, seq_res, min_conf=0.8):
    """Return a copy of the rows whose circRNA was resolved confidently.

    Args:
        df: Association table with a 'circRNA Name' column.
        seq_res: ``{name: (fasta_id, confidence)}`` from ``resolve_names``.
        min_conf: Minimum confidence (inclusive) for a name to be kept.
    """
    names = df['circRNA Name']
    keep = [
        nm for nm in names
        if nm in seq_res and seq_res[nm][1] >= min_conf
    ]
    selected = names.isin(keep)
    return df[selected].copy()

# Build ID maps for circRNAs, diseases, genes
def build_id_maps(df):
    """Assign dense 0-based indices to circRNAs, diseases and genes.

    Each map enumerates the sorted unique values of its column
    ('circRNA Name', 'Disease_ID', 'Gene Symbol'); nulls are dropped for the
    disease and gene columns.

    Returns:
        ``(c2id, d2id, g2id)``.
    """
    def index_of(values):
        return {val: pos for pos, val in enumerate(sorted(values))}

    circ_values = df['circRNA Name'].unique()
    disease_values = df['Disease_ID'].dropna().unique()
    gene_values = df['Gene Symbol'].dropna().unique()
    return index_of(circ_values), index_of(disease_values), index_of(gene_values)


In [8]:
# Load the FASTA sequences and resolve every distinct dataset circRNA name to
# a FASTA record id with a confidence score.
fasta = parse_fasta("datasets/human_hg19_circRNAs_putative_spliced_sequence.fa")
cnames = df['circRNA Name'].dropna().unique().tolist()
res = resolve_names(cnames, fasta)

# Keep only rows whose circRNA resolved with confidence >= 0.8, then index the
# surviving circRNAs, diseases and genes.
df_hc = high_conf_subset(df, res, min_conf=0.8)
c2id, d2id, g2id = build_id_maps(df_hc)


In [9]:
# Debug: verify preprocessing state before continuing

# Summary of the mapped dataframe, the FASTA index, the name resolution and
# the resulting ID maps.
print("DataFrame Columns")
print(list(df.columns))

print("\nDisease_ID field")
disease_col = df['Disease_ID']
print("Non-null Disease_ID:", disease_col.notna().sum())
print("Unique Disease_IDs:", disease_col.nunique())
print("Sample:", df[['Disease Name', 'Disease_ID']].head(10))

print("\nFASTA size")
print("FASTA entries:", len(fasta))

print("\ncircRNA Name Resolution")
resolved = len(res)
high_conf = sum(conf >= 0.8 for _, conf in res.values())
print("Total circRNAs in dataset:", len(cnames))
print("Resolved circRNAs:", len(res))
print("High-confidence mapped:", high_conf)

print("\nHigh-confidence dataframe")
print("Rows in df_hc:", len(df_hc))
print("Unique circRNAs:", df_hc['circRNA Name'].nunique())
print("Unique Disease_IDs:", df_hc['Disease_ID'].nunique())
print("Unique genes:", df_hc['Gene Symbol'].dropna().nunique())

print("\nID map sizes")
for _plural, _id_map in (("circRNAs", c2id), ("diseases", d2id), ("genes", g2id)):
    print(f"{_plural}:", len(_id_map))

print("\nSample ID mappings")
for _singular, _id_map in (("circRNA", c2id), ("disease", d2id), ("gene", g2id)):
    print(f"{_singular} sample:", list(_id_map.items())[:5])


DataFrame Columns
['CRD ID', 'circRNA Name', 'Synonyms', 'Gene Symbol', 'Disease Name', 'Expression pattern', 'PubMed ID', 'Region', 'Strand', 'Species', 'Experimental techniques', 'Brief description', 'Title', 'Disease_ID']

Disease_ID field
Non-null Disease_ID: 4146
Unique Disease_IDs: 232
Sample:           Disease Name  Disease_ID
0         Tuberculosis    DOID:399
1         Tuberculosis    DOID:399
2         Tuberculosis    DOID:399
3         Tuberculosis    DOID:399
4         Tuberculosis    DOID:399
5         Tuberculosis    DOID:399
6  Alzheimer's disease  DOID:10652
7  Alzheimer's disease  DOID:10652
8  Alzheimer's disease  DOID:10652
9  Alzheimer's disease  DOID:10652

FASTA size
FASTA entries: 140790

circRNA Name Resolution
Total circRNAs in dataset: 3083
Resolved circRNAs: 2953
High-confidence mapped: 1487

High-confidence dataframe
Rows in df_hc: 2224
Unique circRNAs: 1487
Unique Disease_IDs: 160
Unique genes: 1179

ID map sizes
circRNAs: 1487
diseases: 160
genes: 1179

Sa

In [10]:
# Extract sequences for retained circRNAs
def build_seq_map(df_hc, res, fasta):
    """Map each retained circRNA name to its FASTA sequence.

    Args:
        df_hc: Table with a 'circRNA Name' column.
        res: ``{name: (fasta_id, confidence)}``; must contain every name.
        fasta: ``{fasta_id: sequence}``.
    """
    sequences = {}
    for name in df_hc['circRNA Name'].unique():
        fasta_id, _confidence = res[name]
        sequences[name] = fasta[fasta_id]
    return sequences

# Filter sequences by length and remap circRNA IDs
def filter_by_len(seq_map, c2id, max_len=5500):
    """Drop sequences longer than ``max_len`` and re-index the survivors.

    Args:
        seq_map: ``{circRNA name: sequence}``.
        c2id: Previous index map; accepted for interface compatibility but
            not read here.
        max_len: Maximum sequence length kept (inclusive).

    Returns:
        ``(seq_arr, new_c2id)``: the kept sequences ordered by sorted name,
        and the matching ``{name: position}`` map.
    """
    names = sorted(nm for nm, seq in seq_map.items() if len(seq) <= max_len)
    seq_arr = [seq_map[nm] for nm in names]
    new_c2id = dict(zip(names, range(len(names))))
    return seq_arr, new_c2id

# Encode sequences into integer matrix for GPU kernel
def encode_sequences(seqs):
    """Pack sequences into a zero-padded int32 matrix of character codes.

    Returns:
        ``(A, Ls, L)``: ``A[i, j]`` is the code point of the upper-cased
        j-th character of sequence ``i`` (0 beyond its end), ``Ls[i]`` is
        the length of sequence ``i``, and ``L`` is the longest length.

    Raises:
        ValueError: If ``seqs`` is empty.
    """
    lengths = [len(s) for s in seqs]
    L = max(lengths)
    A = np.zeros((len(seqs), L), dtype=np.int32)
    Ls = np.asarray(lengths, dtype=np.int32)
    # Each `row` is a view into A, so the slice assignment fills A in place.
    for row, s in zip(A, seqs):
        row[:len(s)] = [ord(ch.upper()) for ch in s]
    return A, Ls, L

# GPU Needleman-Wunsch kernel (2-row DP)
#
# Each CUDA block (blockIdx.x = i, blockIdx.y = j) scores the global alignment
# of sequence i against sequence j and writes one entry M[i*N + j]:
#   * A is the row-major N x Lmax matrix of character codes, Ls the lengths;
#   * ms / mis / gs are the match, mismatch and (linear) gap scores;
#   * a zero-length sequence yields 0.0 and the diagonal (i == j) yields 1.0;
#   * otherwise the final DP score is divided by max(L1, L2).
# The two DP rows live in dynamic shared memory, so the launch must provide at
# least 2 * (L2 + 1) floats of `shared_mem`. The kernel never reads threadIdx,
# so every thread of a block would repeat the same work on the same shared
# buffer; it is meant to be launched with one thread per block.
nw_kernel = cp.RawKernel(r'''
extern "C" __global__
void nw(
    const int* A,
    const int* Ls,
    float* M,
    int N,
    int Lmax,
    float ms,
    float mis,
    float gs
){
    int i = blockIdx.x;
    int j = blockIdx.y;
    if(i>=N || j>=N) return;

    int L1 = Ls[i];
    int L2 = Ls[j];

    if(L1==0 || L2==0){
        M[i*N+j] = 0.0f;
        return;
    }
    if(i==j){
        M[i*N+j] = 1.0f;
        return;
    }

    extern __shared__ float buf[];
    float* prev = buf;
    float* curr = buf + (L2+1);

    for(int c=0;c<=L2;c++) prev[c] = c*gs;

    for(int r=1;r<=L1;r++){
        curr[0] = r*gs;
        int idx1 = i*Lmax + (r-1);
        for(int c=1;c<=L2;c++){
            int idx2 = j*Lmax + (c-1);
            float match = prev[c-1] + ((A[idx1]==A[idx2])?ms:mis);
            float delv = prev[c] + gs;
            float insv = curr[c-1] + gs;
            float b = match;
            if(delv>b) b=delv;
            if(insv>b) b=insv;
            curr[c] = b;
        }
        for(int c=0;c<=L2;c++) prev[c] = curr[c];
    }
    float sc = prev[L2];
    float Lm = (L1>L2?L1:L2);
    M[i*N+j] = sc/Lm;
}
''', 'nw')

# Compute similarity matrix on GPU
def compute_gpu_similarity(seq_arr, match=2.0, mismatch=-0.5, gap=-1.0):
    """Compute the pairwise Needleman-Wunsch similarity matrix on the GPU.

    Args:
        seq_arr: List of sequences (strings).
        match: Score added for a matching character pair.
        mismatch: Score added for a mismatching pair.
        gap: Linear gap score.

    Returns:
        A symmetric ``(N, N)`` float32 NumPy array with 1.0 on the diagonal.
        Off-diagonal entries are the alignment score divided by the longer
        sequence length, so with the default ``match=2.0`` they are not
        confined to [0, 1]: two distinct but identical sequences score 2.0
        and poor alignments are negative.

    Raises:
        ValueError: If the longest sequence needs more dynamic shared memory
            than a default kernel launch can request.
    """
    N = len(seq_arr)
    if N == 0:
        # Nothing to align; avoids max() on an empty list and a 0-size grid.
        return np.zeros((0, 0), dtype=np.float32)

    A, Ls, Lmax = encode_sequences(seq_arr)

    shmem = (Lmax + 1) * 2 * 4  # 2 rows, float32
    max_shmem = 48 * 1024  # default per-block dynamic shared memory limit
    if shmem > max_shmem:
        raise ValueError(
            f"Longest sequence ({Lmax}) needs {shmem} bytes of shared memory, "
            f"above the {max_shmem}-byte limit; filter sequences to at most "
            f"{max_shmem // 8 - 1} characters."
        )

    A_gpu = cp.array(A, dtype=cp.int32)
    Ls_gpu = cp.array(Ls, dtype=cp.int32)
    M_gpu = cp.zeros((N, N), dtype=cp.float32)

    # One block per (i, j) pair, one thread per block (see the kernel).
    grid = (N, N)
    block = (1, 1)

    nw_kernel(
        grid,
        block,
        (A_gpu, Ls_gpu, M_gpu,
         np.int32(N), np.int32(Lmax),
         np.float32(match), np.float32(mismatch), np.float32(gap)),
        shared_mem=shmem
    )
    cp.cuda.Stream.null.synchronize()

    M = cp.asnumpy(M_gpu)
    # Average with the transpose to remove any floating-point asymmetry
    # between the (i, j) and (j, i) alignments.
    M = (M + M.T) / 2.0
    np.fill_diagonal(M, 1.0)
    return M


In [None]:
# Collect the sequences of the retained circRNAs, drop those longer than
# 5500 characters (which also re-indexes the circRNAs), and score all pairs.
seq_map = build_seq_map(df_hc, res, fasta)
seq_arr, new_c2id = filter_by_len(seq_map, c2id, max_len=5500)
M = compute_gpu_similarity(seq_arr)


In [None]:
# Sanity checks on the filtered sequences and the similarity matrix.
print("Sequence Array")
print("Number of sequences:", len(seq_arr))
seq_lens = [len(s) for s in seq_arr]  # computed once, reused below
print("Example lengths:", seq_lens[:5])

Lmax = max(seq_lens)
Lmin = min(seq_lens)
Lavg = np.mean(seq_lens)
print("Length stats -> min:", Lmin, "max:", Lmax, "avg:", round(Lavg, 2))

print("\nnew_c2id mapping")
print("Mapping size:", len(new_c2id))
print("Sample:", list(new_c2id.items())[:10])

print("\nSimilarity Matrix")
print("Shape:", M.shape)
print("dtype:", M.dtype)
print("Diagonal min/max:", M.diagonal().min(), M.diagonal().max())
print("Matrix min/mean/max:", M.min(), M.mean(), M.max())

# Symmetry check
sym_err = np.abs(M - M.T).max()
print("\nSymmetry error (should be <1e-5):", sym_err)

# Random spot checks; the sample size is capped so that sampling without
# replacement also works when fewer than five sequences remain.
print("\nRandom entries")
idx = np.random.choice(len(seq_arr), size=min(5, len(seq_arr)), replace=False)
for i in idx:
    for j in idx:
        print(f"M[{i},{j}] = {M[i,j]:.4f}")
    print("----")

In [None]:
# Keep only the rows whose circRNA survived the length filter.
df_hc2 = df_hc.loc[df_hc["circRNA Name"].isin(new_c2id.keys())].copy()

# Diseases and genes that still occur after that filter, in sorted order.
dset = sorted(df_hc2["Disease_ID"].dropna().unique())
gset = sorted(df_hc2["Gene Symbol"].dropna().unique())

# Global node numbering: circRNAs first (reusing new_c2id), then diseases,
# then genes, each group offset by the sizes of the groups before it.
circ_ids = new_c2id
n_circ = len(circ_ids)

dis_ids = dict(zip(dset, range(n_circ, n_circ + len(dset))))
n_dis = len(dis_ids)

gene_ids = dict(zip(gset, range(n_circ + n_dis, n_circ + n_dis + len(gset))))
n_gene = len(gene_ids)


In [None]:
# CC relation: one directed (i, j, weight) edge for every ordered pair of
# distinct circRNAs with a strictly positive similarity, in row-major order.
# NOTE(review): the notebook introduction mentions a top-K similarity graph,
# but no top-K pruning is applied here -- every positive pair is kept.
_sim = M[:n_circ, :n_circ]
_positive_offdiag = (_sim > 0) & ~np.eye(n_circ, dtype=bool)
CC = [
    (int(i), int(j), float(_sim[i, j]))
    for i, j in zip(*np.nonzero(_positive_offdiag))
]

In [None]:
# CD relation: circRNA <-> disease associations, stored in both directions
# with weight 1.0.
# The table can list the same (circRNA, disease) pair on several rows (e.g.
# one per publication); pairs are de-duplicated first so that each association
# yields exactly one edge per direction instead of repeated parallel edges.
CD = []
_cd_pairs = df_hc2[["circRNA Name", "Disease_ID"]].drop_duplicates()
for cn, d in _cd_pairs.itertuples(index=False, name=None):
    u = circ_ids[cn]
    v = dis_ids[d]
    CD.append((u, v, 1.0))
    CD.append((v, u, 1.0))

In [None]:
# CG relation: circRNA <-> host-gene links, stored in both directions with
# weight 1.0.
# Pairs are de-duplicated so repeated table rows do not create parallel edges.
# Rows without a usable gene symbol are skipped: the earlier astype(str)
# conversion turns missing symbols into the literal 'nan', which would
# otherwise connect unrelated circRNAs through a single placeholder gene node.
_MISSING_GENE = {"", "nan", "none"}
CG = []
_cg_pairs = df_hc2[["circRNA Name", "Gene Symbol"]].drop_duplicates()
for cn, g in _cg_pairs.itertuples(index=False, name=None):
    if pd.isna(g) or str(g).strip().lower() in _MISSING_GENE:
        continue
    u = circ_ids[cn]
    v = gene_ids[g]
    CG.append((u, v, 1.0))
    CG.append((v, u, 1.0))

In [None]:
# DD relation: (child, parent) ontology links, kept only when both diseases
# are nodes of the graph. Edges are unweighted and stored child -> parent only.
DD = [
    (dis_ids[child], dis_ids[p])
    for child, parents in parent.items()
    if child in dis_ids
    for p in parents
    if p in dis_ids
]

In [None]:
# Print node/edge counts, basic connectivity flags and a few sample edges.
_relations = (("CC", CC), ("CD", CD), ("CG", CG), ("DD", DD))
_n_nodes = n_circ + n_dis + n_gene

print("NODE COUNTS")
print("circRNAs:", n_circ)
print("diseases:", n_dis)
print("genes:", n_gene)
print("total nodes:", _n_nodes)

print("\nEDGE COUNTS")
for _rel, _edges in _relations:
    print(f"{_rel} edges:", len(_edges))

# Basic connectivity checks
print("\nCONNECTIVITY CHECKS")
for _rel, _edges in _relations:
    print(f"Any {_rel} edges?", len(_edges) > 0)

# Sample inspection
print("\nSAMPLE EDGES")
for _rel, _edges in _relations:
    print(f"{_rel}:", _edges[:3])

# Index validity checks: the set of every valid node index.
all_nodes = set(range(_n_nodes))

def check_edges(label, E):
    """Print how many edges of ``E`` touch a node index outside ``all_nodes``.

    ``E`` holds tuples whose first two items are the endpoint indices; any
    further items (e.g. a weight) are ignored.
    """
    bad = []
    for u, v, *_ in E:
        if not (u in all_nodes and v in all_nodes):
            bad.append((u, v))
    print(label, "invalid edges:", len(bad))

for _rel, _edges in (("CC", CC), ("CD", CD), ("CG", CG), ("DD", DD)):
    check_edges(_rel, _edges)


In [None]:
# Assemble the final graph dictionary (key order is kept as listed).
G = dict(
    circ_ids=circ_ids,            # circRNA name → node index
    dis_ids=dis_ids,              # DOID → node index
    gene_ids=gene_ids,            # gene symbol → node index
    CC=CC,                        # (u, v, w)
    CD=CD,                        # (u, v, w=1)
    CG=CG,                        # (u, v, w=1)
    DD=DD,                        # (u, v)
    seq_arr=seq_arr,              # circRNA sequences (ordered)
    similarity_matrix=M,          # final NW similarity matrix
    n_circ=n_circ,
    n_dis=n_dis,
    n_gene=n_gene,
    n_total=n_circ + n_dis + n_gene,
)

# Serialize the graph for the training/inference notebook.
path = "datasets/final_graph.pkl"
with open(path, "wb") as f:
    pickle.dump(G, f)

print("Saved graph to:", path)
print("Total nodes:", G["n_total"])
for _rel in ("CC", "CD", "CG", "DD"):
    print(f"{_rel} edges:", len(G[_rel]))

In [None]:
def edge_range_forward(E):
    """Return index ranges of the forward (``u < v``) edges of ``E``.

    Only edges whose source index is smaller than the target index are
    considered; with circRNAs numbered before diseases and genes these are
    the circRNA -> other direction.

    Returns:
        ``(min_source, max_source, min_target, max_target)``.

    Raises:
        ValueError: If ``E`` has no forward edge.
    """
    sources = []
    targets = []
    for u, v, *_ in E:
        if u < v:
            sources.append(u)
            targets.append(v)
    return min(sources), max(sources), min(targets), max(targets)

# Forward-edge source/target index ranges for the two bipartite relations;
# useful for checking that sources fall in the circRNA block and targets in
# the disease / gene blocks.
print("CD forward ranges:", edge_range_forward(CD))
print("CG forward ranges:", edge_range_forward(CG))


In [None]:
# Visualize a small disease-centred subgraph: pick a random seed disease, take
# (part of) its ontology component, attach circRNAs linked to those diseases
# and the genes of those circRNAs, then draw all four relation types.

# Build the DD graph. Every disease node is added up front so that a seed
# disease without any DD edge forms a singleton component instead of making
# nx.node_connected_component fail on a node missing from the graph.
G_dd = nx.Graph()
G_dd.add_nodes_from(dis_ids.values())
G_dd.add_edges_from(DD)

dis_nodes = set(dis_ids.values())

# Number of CG edges leaving each node, counted in one pass instead of
# rescanning CG for every CD edge.
cg_out_deg = {}
for x, *_ in CG:
    cg_out_deg[x] = cg_out_deg.get(x, 0) + 1

# Per disease: summed gene-edge count of the circRNAs pointing at it.
dis_gene_deg = {}
for u, d, *_ in CD:
    if d in dis_nodes:
        dis_gene_deg[d] = dis_gene_deg.get(d, 0) + cg_out_deg.get(u, 0)

cand = [d for d, deg in dis_gene_deg.items() if deg > 0]
if len(cand) == 0:
    cand = list(dis_ids.values())
d0 = int(np.random.choice(cand))

# Ontology neighbourhood of the seed disease, capped at 12 diseases.
comp = list(nx.node_connected_component(G_dd, d0))
if len(comp) > 12:
    comp = comp[:12]
comp_set = set(comp)

# Up to 20 circRNAs associated with those diseases.
C_block = list({u for u, v, *_ in CD if v in comp_set})[:20]
c_block_set = set(C_block)

# Up to 15 genes linked to the selected circRNAs.
G_block = list({g for u, g, *_ in CG if u in c_block_set})[:15]

V = set(comp + C_block + G_block)

# Top-K most similar neighbours per selected circRNA; an edge is drawn only
# when that neighbour is itself part of the selection.
K = 4
CC_edges = []
CC_map = {}
for u, v, w in CC:
    if u in c_block_set:
        CC_map.setdefault(u, []).append((v, w))
for u, lst in CC_map.items():
    lst = sorted(lst, key=lambda x: -x[1])[:K]
    for v, w in lst:
        if v in c_block_set:
            CC_edges.append((u, v))

# Node type follows from the index blocks: circRNAs, then diseases, then genes.
G_sub = nx.Graph()
for u in V:
    if u < n_circ:
        t = 'circ'
    elif u < n_circ + n_dis:
        t = 'dis'
    else:
        t = 'gene'
    G_sub.add_node(u, t=t)

# classify edges
E_DD = [(u, v) for u, v in DD if u in V and v in V]
E_CD = [(u, v) for u, v, *_ in CD if u in V and v in V]
E_CG = [(u, v) for u, v, *_ in CG if u in V and v in V]
E_CC = [(u, v) for u, v in CC_edges if u in V and v in V]

# add edges to graph
G_sub.add_edges_from(E_DD + E_CD + E_CG + E_CC)

# node colors
cmap = {'circ':'#1f77b4', 'dis':'#d62728', 'gene':'#2ca02c'}
node_color = [cmap[G_sub.nodes[u]['t']] for u in G_sub.nodes()]

pos = nx.spring_layout(G_sub, seed=3, k=0.7, iterations=90)

plt.figure(figsize=(9,9))
nx.draw_networkx_nodes(G_sub, pos, node_size=90, node_color=node_color)

# draw edges by class
nx.draw_networkx_edges(G_sub, pos, edgelist=E_DD, width=0.9, alpha=0.7, edge_color="gray")
nx.draw_networkx_edges(G_sub, pos, edgelist=E_CD, width=0.8, alpha=0.8, edge_color="black")
nx.draw_networkx_edges(G_sub, pos, edgelist=E_CG, width=0.8, alpha=0.8, edge_color="#2ecc71")
nx.draw_networkx_edges(G_sub, pos, edgelist=E_CC, width=0.8, alpha=0.8, edge_color="#2980b9")

# Reverse lookup node index -> name, built once rather than rebuilding and
# scanning the key/value lists for every node.
node_name = {}
for id_map in (circ_ids, dis_ids, gene_ids):
    for name, node in id_map.items():
        node_name.setdefault(node, name)

labels = {u: node_name[u] for u in G_sub.nodes()}

for u, (x, y) in pos.items():
    if G_sub.nodes[u]['t'] == 'circ':
        fs = 5
    else:
        fs = 8
    plt.text(x, y - 0.035, labels[u], fontsize=fs, ha='center', va='top')

circ_patch = plt.Line2D([0],[0],marker='o',color='none',
                        markerfacecolor=cmap['circ'],markersize=8)
dis_patch = plt.Line2D([0],[0],marker='o',color='none',
                       markerfacecolor=cmap['dis'],markersize=8)
gene_patch = plt.Line2D([0],[0],marker='o',color='none',
                        markerfacecolor=cmap['gene'],markersize=8)

DD_patch = plt.Line2D([0],[0],color='gray',lw=2)
CD_patch = plt.Line2D([0],[0],color='black',lw=2)
CG_patch = plt.Line2D([0],[0],color='#2ecc71',lw=2)
CC_patch = plt.Line2D([0],[0],color='#2980b9',lw=2)

plt.legend(
    [circ_patch, dis_patch, gene_patch, DD_patch, CD_patch, CG_patch, CC_patch],
    ['circRNA', 'disease', 'gene', 'DD edge', 'CD edge', 'CG edge', 'CC edge'],
    loc='upper right',
    frameon=False,
    fontsize=9
)

plt.title("Disease-cluster with ontology links and circRNA/gene associations")
plt.axis('off')
plt.tight_layout()
# plt.savefig("subgraph_visualization.png", dpi=300, bbox_inches="tight")
plt.show()
