In [None]:
import os
import pandas as pd
import numpy as np
import networkx as nx
from sklearn.metrics.pairwise import cosine_distances
import ndex2
import json
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Folder that is scanned for one sub-directory per treatment.
base_dir = "../data/embedding"

# Destination folder for the exported network; created up front if missing.
output_path = "../data/embedding_graph"
os.makedirs(output_path, exist_ok=True)
output_file = os.path.join(output_path, "if.cx")

# Shared graph that later cells fill with gene-gene edges.
G = nx.Graph()

In [None]:
def get_edge_dataframe(G):
    """Flatten the edges of ``G`` into a DataFrame with one row per edge.

    Parameters
    ----------
    G : graph-like
        Object whose ``edges(data=True)`` yields ``(u, v, attrs)`` triples,
        e.g. a ``networkx.Graph``.

    Returns
    -------
    pandas.DataFrame
        ``gene1`` and ``gene2`` hold the two endpoints of each edge and every
        edge attribute becomes an additional column. An edge attribute that
        is itself named ``gene1``/``gene2`` overrides the endpoint value.
        For a graph without edges the frame has zero rows but still carries
        the ``gene1``/``gene2`` columns, so column lookups on the endpoints
        do not fail with ``KeyError``.
    """
    edge_data = [
        {"gene1": u, "gene2": v, **attrs}
        for u, v, attrs in G.edges(data=True)
    ]
    if not edge_data:
        # pd.DataFrame([]) would have no columns at all.
        return pd.DataFrame(columns=["gene1", "gene2"])
    return pd.DataFrame(edge_data)

In [None]:
# Per-treatment table of mean embeddings, keyed by treatment folder name.
treatment_dfs = {}

# os.listdir returns entries in arbitrary order; sorting makes the order in
# which treatments are loaded (and stored in treatment_dfs) reproducible.
for treatment in sorted(os.listdir(base_dir)):
    treatment_path = os.path.join(base_dir, treatment)
    emb_file = os.path.join(treatment_path, "image_emd.tsv")

    if not os.path.isfile(emb_file):
        print(f"❌ Skipping {treatment}: image_emd.tsv not found.")
        continue

    print(f"🔍 Processing treatment: {treatment}")
    df = pd.read_csv(emb_file, sep="\t")

    # Drop rows without an id and rows whose id is "NEGATIVE". The explicit
    # copy makes the column assignment below operate on an independent frame
    # instead of a filtered slice (avoids SettingWithCopyWarning).
    df = df.loc[df["id"].notna() & (df["id"] != "NEGATIVE")].copy()

    # NOTE(review): every column after the first is treated as an embedding
    # dimension; this assumes "id" is the first column — confirm against the
    # TSV layout.
    embedding_cols = df.columns[1:]
    df[embedding_cols] = df[embedding_cols].apply(pd.to_numeric)

    # Average embeddings over all rows that share the same id.
    df_avg = df.groupby("id")[embedding_cols].mean().reset_index()
    treatment_dfs[treatment] = df_avg

In [None]:
# Get only genes common to all treatments
gene_sets = [set(df["id"]) for df in treatment_dfs.values()]

# set.intersection() needs at least one set to operate on; when no treatment
# was loaded, fall back to an empty set instead of raising TypeError.
common_genes = set.intersection(*gene_sets) if gene_sets else set()
print(f"✅ Found {len(common_genes)} common genes across treatments")

In [None]:
# Add one edge per pair of common genes and store the treatment-specific
# cosine distance between their mean embeddings as an edge attribute.
for treatment, df_avg in treatment_dfs.items():
    df_avg = df_avg[df_avg["id"].isin(common_genes)].reset_index(drop=True)
    embedding_cols = df_avg.columns[1:]

    gene_names = df_avg["id"].tolist()
    n_genes = len(gene_names)
    if n_genes < 2:
        # No gene pair to connect; an empty matrix would also make
        # cosine_distances raise.
        print(f"⚠️ Skipping {treatment}: fewer than two common genes.")
        continue

    emb_matrix = df_avg[embedding_cols].values
    distance_matrix = cosine_distances(emb_matrix)
    attr_name = f"cosine_distance_{treatment}"

    print(f"📐 Building cosine distance graph for {treatment}")
    for i in range(n_genes):
        for j in range(i + 1, n_genes):
            # add_edge creates the edge if needed and otherwise just updates
            # its attributes, so no has_edge check is required. The value is
            # stored as a plain Python float rather than a NumPy scalar.
            G.add_edge(
                gene_names[i],
                gene_names[j],
                **{attr_name: float(distance_matrix[i, j])},
            )

In [None]:
# Postprocessing: compute variability across treatments
print("📊 Computing variability in cosine distances across treatments...")

# Derive the per-treatment attribute names from the loaded treatments rather
# than from the first edge of the graph. Peeking at the first edge raises
# StopIteration when the graph has no edges, and after this cell has run once
# it would also pick up the summary attributes written below (they share the
# "cosine_distance_" prefix), so a re-run would fold them into the statistics.
treatment_attrs = [f"cosine_distance_{treatment}" for treatment in treatment_dfs]

# Summary attributes written on every edge, in this order.
summary_keys = (
    "cosine_distance_min",
    "cosine_distance_max",
    "cosine_distance_range",
    "cosine_distance_mean",
    "cosine_distance_std",
)

for u, v, data in G.edges(data=True):
    # Distances actually present on this edge (missing treatments are skipped).
    dists = [d for d in (data.get(attr) for attr in treatment_attrs) if d is not None]

    if len(dists) >= 2:  # Only calculate if we have at least 2 values
        lowest, highest = min(dists), max(dists)
        data["cosine_distance_min"] = lowest
        data["cosine_distance_max"] = highest
        data["cosine_distance_range"] = highest - lowest
        data["cosine_distance_mean"] = sum(dists) / len(dists)
        data["cosine_distance_std"] = np.std(dists)  # population std (ddof=0)
    else:
        # Not enough values to measure variability: mark every summary as None.
        for key in summary_keys:
            data[key] = None

In [None]:
# Hand the NetworkX graph to ndex2 for conversion.
ndex_graph = ndex2.create_nice_cx_from_networkx(G)

# to_cx() gives the CX content as a list of aspects (not a BytesIO stream),
# which can be serialised directly as JSON.
cx_aspect_list = ndex_graph.to_cx()

# Write the CX JSON to the configured output file.
with open(output_file, "w") as cx_file:
    json.dump(cx_aspect_list, cx_file, indent=2)

print(f"✅ Graph saved to {output_file}")

In [None]:
# Flatten the graph into a table with one row per edge (gene pair); the
# plotting and summary cells below read from this frame.
df_edges = get_edge_dataframe(G)

In [None]:
# Edge-level summary statistics to visualise, one histogram panel each.
summary_metrics = [
    "cosine_distance_mean",
    "cosine_distance_std",
    "cosine_distance_range",
    "cosine_distance_min",
    "cosine_distance_max",
]

# 2 x 3 grid; only the first five slots are populated.
fig = plt.figure(figsize=(15, 10))
for panel, metric in enumerate(summary_metrics, start=1):
    ax = fig.add_subplot(2, 3, panel)
    # Missing values (edges without a summary) are left out of the histogram.
    sns.histplot(df_edges[metric].dropna(), bins=50, kde=True, ax=ax)
    ax.set_title(metric)
    ax.set_xlabel("Distance")
    ax.set_ylabel("Count")

fig.tight_layout()
plt.show()

In [None]:
# Descriptive statistics (count, mean, quartiles, ...) of the edge summaries.
described_columns = [
    "cosine_distance_min",
    "cosine_distance_max",
    "cosine_distance_mean",
    "cosine_distance_range",
    "cosine_distance_std",
]
print(df_edges[described_columns].describe())

In [None]:
# Rank gene pairs by the spread of their distance across treatments and keep
# the ten with the largest standard deviation.
ranked_by_std = df_edges.sort_values("cosine_distance_std", ascending=False)
top_var = ranked_by_std.head(10)

print("🔎 Top 10 most variable gene pairs across treatments:")
print(top_var.loc[:, ["gene1", "gene2", "cosine_distance_std", "cosine_distance_range"]])