<a href="https://colab.research.google.com/github/RolakeOkans/JO-neuron-connectivity-clustering/blob/main/jo_connectivity_clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================
# DOWNSTREAM CLUSTERING: JO-A & JO-B NEURONS
# ============================================

# ---- Install dependencies (Colab) ----
!pip -q install scikit-learn yellowbrick fafbseg

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import Normalizer, MinMaxScaler
from sklearn.metrics.pairwise import cosine_similarity

from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform

# Optional: only import flywire if you truly use it (comment out if unused)
# from fafbseg import flywire

# --------------------------------------------
# Config (set these to your private file paths)
# The three CSV constants below are placeholders; main() passes them to the
# loader functions defined further down.
# --------------------------------------------
INPUT_SYNAPSE_CSV = "PATH/TO/synapse_table.csv"
POST_CELL_MAP_CSV = "PATH/TO/post_cell_type_mapping.csv"
NBLAST_SCORES_CSV = "PATH/TO/nblast_scores.csv"

OUTDIR = "outputs"          # keep outputs local; do NOT commit
# Create the output directory up front so the save helpers can write into it.
os.makedirs(OUTDIR, exist_ok=True)

# Turn plotting OFF for public version by default (prevents generating result images).
# main() also uses this flag to decide whether the cluster-assignment CSV is written.
ENABLE_PLOTS = False

# --------------------------------------------
# Helper utilities
# --------------------------------------------
def safe_save_csv(df: pd.DataFrame, filename: str) -> None:
    """Write ``df`` as a CSV file named ``filename`` inside OUTDIR.

    The index is not written. Outputs are meant to stay local, so consider
    excluding OUTDIR via .gitignore.
    """
    df.to_csv(path_or_buf=os.path.join(OUTDIR, filename), index=False)

def safe_save_fig(filename: str) -> None:
    """Save the current matplotlib figure into OUTDIR, but only if ENABLE_PLOTS is set.

    When ENABLE_PLOTS is False the call does nothing.
    """
    if ENABLE_PLOTS:
        target = os.path.join(OUTDIR, filename)
        plt.savefig(target, dpi=300, bbox_inches="tight")

# --------------------------------------------
# 1) Load + filter synapse table
# --------------------------------------------
def load_and_filter_synapses(input_csv: str) -> pd.DataFrame:
    """
    Read the synapse table and keep only left-side, sufficiently strong rows.

    Columns used by this function:
      - cell_type (str): rows whose label ends in ``_R<number>`` are dropped
      - syn_count (numeric): rows below 5, or not parseable as a number, are dropped

    Returns a copy of the surviving rows with their original index labels.
    """
    raw = pd.read_csv(input_csv)

    # Labels such as "JO-A_R12" mark the right hemisphere.
    is_right = raw["cell_type"].astype(str).str.contains(r"_R\d+$", regex=True)

    # Non-numeric counts become NaN, which fails the >= comparison below.
    counts = pd.to_numeric(raw["syn_count"], errors="coerce")
    strong_enough = counts >= 5

    keep = ~is_right & strong_enough
    return raw[keep].copy()

# --------------------------------------------
# 2) Add post_cell_type labels via mapping file
# --------------------------------------------
def add_post_cell_type(df_left: pd.DataFrame, mapping_csv: str) -> pd.DataFrame:
    """
    Return a copy of ``df_left`` with a ``post_cell_type`` column added.

    The mapping file is expected to contain:
      - post_root_id (or post_pt_root_id, which is renamed to post_root_id)
      - post_cell_type

    When the mapping lists the same post_root_id more than once, the first
    occurrence wins. IDs absent from the mapping get NaN.

    The input frame is not modified; callers must use the returned frame.
    """
    post_map = pd.read_csv(mapping_csv)

    # Harmonize column name if needed
    if "post_pt_root_id" in post_map.columns and "post_root_id" not in post_map.columns:
        post_map = post_map.rename(columns={"post_pt_root_id": "post_root_id"})

    # Deduplicate mapping so each ID resolves to exactly one label
    post_map = post_map.drop_duplicates(subset="post_root_id")
    id_to_celltype = dict(zip(post_map["post_root_id"], post_map["post_cell_type"]))

    # Work on a copy: assigning into the caller's frame would be a hidden side
    # effect (and triggers SettingWithCopyWarning when the caller passed a slice).
    labelled = df_left.copy()
    labelled["post_cell_type"] = labelled["post_root_id"].map(id_to_celltype)

    return labelled

# --------------------------------------------
# 3) Connectivity pivot table (for clustering)
# --------------------------------------------
def make_connectivity_pivot(df: pd.DataFrame) -> pd.DataFrame:
    """Build a cell_type x post_root_id matrix of summed synapse counts.

    Pairs that never occur are 0, and any summed count of 4 or less is also
    set to 0.
    """
    totals = df.groupby(["cell_type", "post_root_id"], as_index=False)["syn_count"].sum()

    wide = totals.pivot(index="cell_type", columns="post_root_id", values="syn_count")
    wide = wide.fillna(0)

    # Thresholding: only entries strictly greater than 4 survive.
    above_threshold = wide > 4
    return wide.where(above_threshold, 0)

# --------------------------------------------
# 4) Hierarchical clustering (connectivity)
# --------------------------------------------
def hierarchical_clustering(matrix: np.ndarray, labels, method: str = "ward"):
    """L2-normalize each row of ``matrix`` and compute a SciPy linkage on it.

    ``labels`` are only used for the dendrogram leaves, which is drawn (and
    saved as connectivity_dendrogram.png) only when ENABLE_PLOTS is True.
    Returns the linkage matrix.
    """
    unit_rows = Normalizer(norm="l2").fit_transform(matrix)
    tree = linkage(unit_rows, method=method)

    if not ENABLE_PLOTS:
        return tree

    plt.figure(figsize=(24, 12))
    dendrogram(tree, labels=labels, leaf_rotation=90, leaf_font_size=8)
    plt.title("Hierarchical Clustering Dendrogram")
    plt.xlabel("Pre Cell Type")
    plt.ylabel("Distance")
    plt.tight_layout()
    safe_save_fig("connectivity_dendrogram.png")
    plt.show()

    return tree

# --------------------------------------------
# 5) Clustermap prep (optional + no display)
# --------------------------------------------
def make_log_heatmap(df_with_types: pd.DataFrame, auditory_post_cell_types: list[str]) -> pd.DataFrame:
    """Return a log1p-transformed cell_type x post_cell_type synapse matrix.

    Columns are ordered with the (sorted) auditory types first, followed by
    every other post cell type in sorted order. Auditory types that do not
    occur in the data still appear, as all-NaN columns; pairs with no
    synapses stay NaN.
    """
    totals = df_with_types.groupby(["cell_type", "post_cell_type"])["syn_count"].sum()
    long_form = totals.reset_index()
    matrix = long_form.pivot(index="cell_type", columns="post_cell_type", values="syn_count")

    auditory_sorted = sorted(auditory_post_cell_types)
    other_sorted = sorted(col for col in matrix.columns if col not in auditory_sorted)

    ordered = matrix.reindex(columns=auditory_sorted + other_sorted)
    return np.log1p(ordered)

# --------------------------------------------
# 6) NBLAST cleaning + scaling
# --------------------------------------------
def load_clean_nblast(nblast_csv: str) -> pd.DataFrame:
    """Load the NBLAST score matrix, drop right-side entries, and min-max scale it.

    The first CSV column is used as the index. Rows and columns whose label
    ends in ``_R<number>`` are removed before scaling.
    """
    scores = pd.read_csv(nblast_csv, index_col=0)

    right_label = r"_R\d+$"
    keep_rows = ~scores.index.astype(str).str.contains(right_label, regex=True)
    scores = scores[keep_rows]
    keep_cols = ~scores.columns.astype(str).str.contains(right_label, regex=True)
    scores = scores.loc[:, keep_cols]

    # Min-max scale to [0, 1].
    # NOTE(review): MinMaxScaler rescales each column independently, so the
    # scaled matrix is presumably no longer symmetric — confirm this is intended.
    scaled = MinMaxScaler().fit_transform(scores.values)

    return pd.DataFrame(scaled, index=scores.index, columns=scores.columns)

# --------------------------------------------
# 7) Combine connectivity similarity + NBLAST
# --------------------------------------------
def combine_similarity(conn_pivot: pd.DataFrame, nblast_scaled: pd.DataFrame) -> pd.DataFrame:
    """Fuse connectivity and NBLAST similarity into one symmetric distance matrix.

    Steps:
      1. cosine similarity between the rows of ``conn_pivot``
      2. restrict both matrices to the labels present in both indexes
      3. add them and min-max normalize the sum over the whole matrix
      4. distance = 1 - normalized similarity, averaged with its transpose,
         with an exactly-zero diagonal

    Returns a DataFrame indexed and labelled by the shared neurons. If every
    combined similarity is identical (for example a single shared neuron),
    all distances are 0 instead of NaN from a 0/0 division.
    """
    # Cosine similarity over connectivity features
    conn_sim = cosine_similarity(conn_pivot.values)
    conn_sim_df = pd.DataFrame(conn_sim, index=conn_pivot.index, columns=conn_pivot.index)

    common = conn_sim_df.index.intersection(nblast_scaled.index)
    combined = conn_sim_df.loc[common, common] + nblast_scaled.loc[common, common]

    # Min-max normalize combined similarity (pandas min/max skip NaN)
    min_val = combined.min().min()
    max_val = combined.max().max()
    span = max_val - min_val

    # Do the arithmetic on an array we own. Writing through DataFrame.values
    # (np.fill_diagonal(df.values, 0)) depends on .values being a writable
    # view, which fails under pandas Copy-on-Write where it is read-only.
    values = combined.to_numpy(dtype=float, copy=True)

    if span == 0:
        # All similarities equal: nothing to separate, so every distance is 0.
        dist = np.zeros_like(values)
    else:
        # Convert similarity -> distance
        dist = 1.0 - (values - min_val) / span

    # Force symmetry, then pin the diagonal to exactly 0
    dist = (dist + dist.T) / 2
    np.fill_diagonal(dist, 0)

    return pd.DataFrame(dist, index=combined.index, columns=combined.columns)

# --------------------------------------------
# 8) Cluster combined distance (optional plots)
# --------------------------------------------
def cluster_distance(distance_df_sym: pd.DataFrame, threshold: float = 1.8):
    """Cluster a square distance matrix with Ward linkage and cut it at ``threshold``.

    Returns ``(linkage_matrix, cluster_series)`` where ``cluster_series`` is
    named "cluster" and shares the index of ``distance_df_sym``. The
    dendrogram is drawn and saved only when ENABLE_PLOTS is True.
    """
    neuron_index = distance_df_sym.index

    # NOTE(review): SciPy documents Ward as well-defined only for Euclidean
    # distances; a 1 - similarity matrix is presumably not Euclidean — confirm.
    tree = linkage(squareform(distance_df_sym.values), method="ward")

    if ENABLE_PLOTS:
        plt.figure(figsize=(24, 12))
        dendrogram(
            tree,
            labels=neuron_index.tolist(),
            leaf_rotation=90,
            leaf_font_size=8,
            color_threshold=threshold,
        )
        plt.title(f"Combined Clustering (Threshold = {threshold})")
        plt.xlabel("Neuron")
        plt.ylabel("Distance")
        plt.tight_layout()
        safe_save_fig("combined_dendrogram.png")
        plt.show()

    flat_labels = fcluster(tree, t=threshold, criterion="distance")
    return tree, pd.Series(flat_labels, index=neuron_index, name="cluster")

# --------------------------------------------
# MAIN PIPELINE (code-only; no outputs shown)
# --------------------------------------------
def main():
    """Run the downstream pipeline end to end (code-only; no outputs shown).

    Loads and filters the synapse table, labels postsynaptic partners, clusters
    on connectivity, combines connectivity similarity with NBLAST scores, and
    clusters the combined distance. Cluster assignments are written to OUTDIR
    only when ENABLE_PLOTS is True. Nothing is printed or returned.
    """
    # 1) Load + filter synapses
    df_left = load_and_filter_synapses(INPUT_SYNAPSE_CSV)

    # 2) Add post_cell_type
    df_left = add_post_cell_type(df_left, POST_CELL_MAP_CSV)

    # 3) Connectivity pivot
    conn_pivot = make_connectivity_pivot(df_left)

    # 4) Hierarchical clustering on connectivity (linkage itself is not needed here)
    hierarchical_clustering(conn_pivot.to_numpy(), labels=conn_pivot.index.tolist())

    # 5) Optional heatmap prep (no plotting by default)
    auditory_post_cell_types = [
        # Keep list if it is not sensitive; otherwise move to a private config
        "A1", "A2", "AVLP_pr01-1", "AVLP_pr23", "aPN2", "B1-1", "B1-2", "B1-3", "B1-4", "B1-5",
        "B1-6", "B1-7", "B1-u", "B2", "GF", "IPS_pr01-2", "IPS_pr02", "JO-A", "JO-B",
        "SAD_pr01", "SAD_pr02", "vpoEN", "WED-VLP-1", "WED-VLP-2", "WED_pr02",
        "WV-WV-1", "WV-WV-2", "WV-WV-3", "WV-WV-4",
    ]
    # The matrix is computed but not displayed or saved in the public version.
    make_log_heatmap(df_left, auditory_post_cell_types)

    # 6) Load + clean NBLAST, scale
    nblast_scaled = load_clean_nblast(NBLAST_SCORES_CSV)

    # 7) Combine similarity matrices -> distance
    dist_sym = combine_similarity(conn_pivot, nblast_scaled)

    # 8) Cluster combined distance
    _, cluster_series = cluster_distance(dist_sym, threshold=1.8)

    # Save cluster assignments locally (optional; do NOT commit)
    if ENABLE_PLOTS:
        assignments = cluster_series.reset_index()
        # Name the columns explicitly: reset_index() only produces a column
        # called "index" when the index is unnamed, so renaming "index" would
        # silently do nothing for an index named e.g. "cell_type".
        assignments.columns = ["neuron", "cluster"]
        safe_save_csv(assignments, "cluster_assignments.csv")

    # IMPORTANT: No prints, no head(), no shapes in public notebook

# Run the downstream pipeline when this cell is executed.
main()


In [None]:
# ============================================
# UPSTREAM CLUSTERING: JO-A & JO-B NEURONS
# ============================================

!pip -q install scikit-learn yellowbrick fafbseg

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import Normalizer
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster

# --------------------------------------------
# Config (replace with your PRIVATE file paths)
# The first three CSV constants are passed by main() to the upstream helpers below.
# --------------------------------------------
CODEX_CONNECTIONS_CSV = "PATH/TO/connections_princeton.csv"
JO_CLUSTER_LIST_CSV   = "PATH/TO/JO_cluster_list_with_pre_root_id_upstream.csv"  # has post_root_id + JO neuron name
PRE_CELL_MAP_CSV      = "PATH/TO/JO_Upstream_cell_types_mapping.csv"             # pre_root_id -> pre_cell_type
CLUSTER_ASSIGN_CSV    = "PATH/TO/JO_left_cluster_list_ordered_by_dendrogram.csv" # optional, for cluster color bar; not read by the code in this cell

OUTDIR = "outputs"
# Create the output directory up front so the save helpers can write into it.
os.makedirs(OUTDIR, exist_ok=True)

ENABLE_PLOTS = False  # OFF by default for public version

# --------------------------------------------
# Helper utilities
# --------------------------------------------
def save_csv_local(df: pd.DataFrame, filename: str) -> None:
    """Write ``df`` (without its index) to ``filename`` inside OUTDIR."""
    target = os.path.join(OUTDIR, filename)
    df.to_csv(target, index=False)

def save_fig_local(filename: str) -> None:
    """Save the current matplotlib figure into OUTDIR; no-op unless ENABLE_PLOTS is True."""
    if ENABLE_PLOTS:
        target = os.path.join(OUTDIR, filename)
        plt.savefig(target, dpi=300, bbox_inches="tight")

# --------------------------------------------
# 1) Build upstream table: find all pre neurons connecting to JO (as posts)
# --------------------------------------------
def build_upstream_table(connections_csv: str, jo_csv: str) -> pd.DataFrame:
    """Return every connection whose postsynaptic partner is a listed JO neuron.

    ``connections_csv`` must contain ``post_root_id``. ``jo_csv`` must contain
    ``post_root_id`` plus the JO neuron name in either ``cell_type`` or
    ``JO_neuron`` (renamed to ``cell_type``).

    The result holds the matching connection rows with the JO neuron name
    attached, and ``cell_type`` moved to the first column. When the JO list
    repeats a ``post_root_id``, the first occurrence is used, so each
    connection row appears exactly once.
    """
    big_df = pd.read_csv(connections_csv)
    jo_df = pd.read_csv(jo_csv)

    # Normalize naming: ensure JO neuron name column becomes 'cell_type'
    if "JO_neuron" in jo_df.columns and "cell_type" not in jo_df.columns:
        jo_df = jo_df.rename(columns={"JO_neuron": "cell_type"})

    # One name per JO root ID. Without this, a repeated ID in the JO list would
    # duplicate connection rows in the merge and inflate summed synapse counts.
    jo_names = jo_df[["post_root_id", "cell_type"]].drop_duplicates(subset="post_root_id")

    # Filter connections to only JO posts (upstream means pre_root_id -> post_root_id (JO))
    is_jo_post = big_df["post_root_id"].isin(jo_names["post_root_id"])
    filtered = big_df[is_jo_post]

    # Attach JO neuron names
    merged = filtered.merge(jo_names, on="post_root_id", how="left")

    # Put cell_type first, keep the remaining columns in their existing order
    cols = ["cell_type"] + [c for c in merged.columns if c != "cell_type"]
    return merged[cols]

# --------------------------------------------
# 2) Pivot upstream connectivity for clustering
# --------------------------------------------
def make_upstream_pivot(up_df: pd.DataFrame) -> pd.DataFrame:
    """Build a cell_type x pre_root_id matrix of summed synapse counts.

    Pairs that never occur are 0, and any summed count of 4 or less is also
    set to 0.
    """
    totals = up_df.groupby(["cell_type", "pre_root_id"], as_index=False)["syn_count"].sum()

    wide = totals.pivot(index="cell_type", columns="pre_root_id", values="syn_count")
    wide = wide.fillna(0)

    # Threshold small counts: only entries strictly greater than 4 survive.
    above_threshold = wide > 4
    return wide.where(above_threshold, 0)

# --------------------------------------------
# 3) Hierarchical clustering (upstream)
# --------------------------------------------
def cluster_upstream(pivot: pd.DataFrame, method: str = "ward"):
    """L2-normalize the rows of ``pivot`` and compute a SciPy linkage on them.

    The dendrogram (leaves labelled with the pivot index) is drawn and saved
    as JO_upstream_dendrogram.png only when ENABLE_PLOTS is True.
    Returns the linkage matrix.
    """
    unit_rows = Normalizer(norm="l2").fit_transform(pivot.to_numpy())
    tree = linkage(unit_rows, method=method)

    if not ENABLE_PLOTS:
        return tree

    plt.figure(figsize=(24, 12))
    dendrogram(tree, labels=pivot.index.tolist(), leaf_rotation=90, leaf_font_size=8)
    plt.title("Hierarchical Clustering Dendrogram — JO Upstream")
    plt.xlabel("Post Cell Type (JO)")
    plt.ylabel("Distance")
    plt.tight_layout()
    save_fig_local("JO_upstream_dendrogram.png")
    plt.show()

    return tree

# --------------------------------------------
# 4) Add pre_cell_type labels + create heatmap matrix
# --------------------------------------------
def add_pre_cell_type(up_df: pd.DataFrame, pre_map_csv: str) -> pd.DataFrame:
    """Filter the upstream table to syn_count >= 5 and attach the mapping columns.

    The mapping CSV is keyed by ``pre_root_id`` (``pre_pt_root_id`` is accepted
    and renamed). All of its columns are left-joined onto the filtered rows, so
    IDs missing from the mapping keep NaN in the added columns. When the
    mapping repeats a ``pre_root_id``, the first occurrence is used, so no
    connection row is duplicated.

    ``up_df`` itself is not modified.
    """
    pre_map = pd.read_csv(pre_map_csv)

    # Normalize naming
    if "pre_pt_root_id" in pre_map.columns and "pre_root_id" not in pre_map.columns:
        pre_map = pre_map.rename(columns={"pre_pt_root_id": "pre_root_id"})

    # One mapping row per ID. A repeated ID would otherwise multiply the
    # matching connection rows in the merge and inflate later synapse sums.
    pre_map = pre_map.drop_duplicates(subset="pre_root_id")

    # Keep rows with syn_count >= 5; unparseable counts become NaN and are dropped
    up_df2 = up_df.copy()
    up_df2["syn_count"] = pd.to_numeric(up_df2["syn_count"], errors="coerce")
    up_df2 = up_df2[up_df2["syn_count"] >= 5]

    return up_df2.merge(pre_map, on="pre_root_id", how="left")

def make_pretype_heatmap_matrix(df: pd.DataFrame) -> pd.DataFrame:
    """Sum syn_count into a pre_cell_type (rows) x cell_type (columns) matrix.

    Combinations that do not occur in ``df`` are NaN.
    """
    return pd.pivot_table(
        df,
        values="syn_count",
        index="pre_cell_type",
        columns="cell_type",
        aggfunc="sum",
    )

# --------------------------------------------
# 5) Apply manual ordering + log transform (no printing)
# --------------------------------------------
def reorder_and_log_transform(
    heatmap_df: pd.DataFrame,
    priority_types: list[str],
    jo_neuron_order: list[str],
) -> tuple[pd.DataFrame, list[str], list[str]]:
    """Reorder the heatmap matrix and log10-transform it.

    Rows: the sorted ``priority_types`` first, then every other row label of
    ``heatmap_df`` in sorted order. Columns: exactly ``jo_neuron_order``.
    Labels absent from ``heatmap_df`` appear as all-NaN rows/columns, and
    columns of ``heatmap_df`` not listed in ``jo_neuron_order`` are dropped.

    Returns ``(log_df, final_row_order, final_col_order)``. Zero counts become
    NaN in ``log_df`` (log10(0) = -inf is replaced). ``heatmap_df`` is not
    modified.
    """
    # Row order: priority first, then remaining alphabetical
    priority_sorted = sorted(priority_types)
    remaining = sorted(x for x in heatmap_df.index if x not in priority_sorted)
    final_row_order = priority_sorted + remaining

    # reindex() creates any missing JO neuron column as NaN by itself, so no
    # separate "add missing columns" pass is needed.
    final_col_order = list(jo_neuron_order)
    reordered = heatmap_df.reindex(index=final_row_order, columns=final_col_order)

    # Log-transform: use log10, handle 0 -> -inf -> NaN
    log_df = np.log10(reordered).replace(-np.inf, np.nan)

    return log_df, final_row_order, final_col_order

# --------------------------------------------
# 6) Optional plotting (disabled by default)
# --------------------------------------------
def plot_upstream_heatmap(log_df: pd.DataFrame, outname: str):
    """Draw the log-scaled upstream heatmap and save it as ``outname``.

    Does nothing when ENABLE_PLOTS is False or when ``log_df`` has no finite
    values. NaN cells are masked and shown in light grey.
    """
    if not ENABLE_PLOTS:
        return

    # Colour limits come from the finite entries only (NaNs are ignored).
    cell_values = log_df.values
    finite = cell_values[np.isfinite(cell_values)]
    if finite.size == 0:
        return
    vmin = float(np.min(finite))
    vmax = float(np.max(finite))

    cmap = plt.get_cmap("viridis_r").copy()
    cmap.set_bad(color="#d0d0d0")

    heatmap_options = {
        "mask": log_df.isna(),
        "cmap": cmap,
        "vmin": vmin,
        "vmax": vmax,
        "linewidths": 0.25,
        "linecolor": "black",
        "xticklabels": True,
        "yticklabels": True,
        "cbar_kws": {"label": "Synapse count (log10)"},
    }

    plt.figure(figsize=(24, 18))
    ax = sns.heatmap(log_df, **heatmap_options)

    # Row labels go on the right-hand side of the plot.
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=6)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=6)

    plt.tight_layout()
    save_fig_local(outname)
    plt.show()

# --------------------------------------------
# MAIN (code-only)
# --------------------------------------------
def main():
    """Run the upstream pipeline end to end (code-only).

    Builds the upstream connection table, clusters it, attaches pre cell
    types, prepares the log-scaled heatmap matrix, and writes the intermediate
    tables plus a values-free copy of the heatmap layout into OUTDIR.
    """
    # Upstream table (saved locally only)
    upstream = build_upstream_table(CODEX_CONNECTIONS_CSV, JO_CLUSTER_LIST_CSV)
    save_csv_local(upstream, "filtered_JO_left_pre_ids_CODEONLY.csv")

    # Pivot + cluster; the linkage itself is not used further here
    cluster_upstream(make_upstream_pivot(upstream))

    # Attach pre_cell_type labels and save the labelled table
    upstream_pretyped = add_pre_cell_type(upstream, PRE_CELL_MAP_CSV)
    save_csv_local(upstream_pretyped, "upstream_with_pre_cell_type_CODEONLY.csv")

    # Manual ordering lists (keep here, or move to a private config)
    priority_types = [
        "AVLP_pr01-1", "B2", "IPS_pr02", "JO-A", "JO-B", "JO-EDM", "JO-DP", "JO-mz",
        "SAD_pr01", "SAD_pr02", "WED_pr02", "WV-WV-3",
    ]
    jo_neuron_order = [
        # Template only: replace with the complete JO neuron order from the
        # private notebook.
        "JO-A_L16", "JO-A_L21", "JO-A_L31", "JO-A_L1",
        "JO-B_L124", "JO-B_L40", "JO-B_L132", "JO-B_L145",
        # ...
    ]

    # Heatmap matrix, reordered and log-transformed; the order lists are not needed here
    log_df, _, _ = reorder_and_log_transform(
        make_pretype_heatmap_matrix(upstream_pretyped),
        priority_types,
        jo_neuron_order,
    )

    # Optional plot (OFF by default)
    plot_upstream_heatmap(log_df, "JO_upstream_heatmap_CODEONLY.png")

    # Export only the *structure* of the plotted matrix (labels, no values);
    # local only, do NOT commit.
    structure_only = pd.DataFrame(index=log_df.index, columns=log_df.columns)
    structure_path = os.path.join(OUTDIR, "JO_upstream_heatmap_structure_ONLY.csv")
    structure_only.to_csv(structure_path)

# Run the upstream pipeline when this cell is executed.
main()
