# Community Detection from Activation Maps

This notebook performs **community detection** on prompts, using the per-layer neuron activation maps produced in data prep as features.
It groups prompts into communities based on shared activation patterns for downstream theme analysis.

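**Pipeline**
1. Load the per-layer activation data and align all layers to the prompts they have in common.
2. Construct a mutual k-nearest-neighbour graph of prompts (TF-IDF + CountSketch features, FAISS HNSW search, igraph).
3. Run community detection (Leiden).
4. Post-process communities (filtering, sizes/coverage).
5. Save community assignments (community → queries) and report summary statistics.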


In [None]:
from __future__ import annotations

import json
from typing import Dict, Hashable, List, Optional, Tuple, Union

import faiss
import igraph as ig
import leidenalg as la
import numpy as np
import scipy.sparse as sp
import torch

### Data Preparation

In [None]:
# Load the per-layer activation dumps for layers 17..23, keeping only the "post" vectors.
# NOTE(review): these files carry a .json extension but are read with torch.load (pickle-based);
# presumably they were written with torch.save — confirm, and only load trusted files this way.
activations = {}
for layer in range(17, 24):
    layer_payload = torch.load(f"../layer_activations_no_test/layer_{layer}.json")
    activations[f"layer_{layer}"] = {
        qid: entry["post"] for qid, entry in layer_payload.items()
    }

In [None]:
def align_layers_to_common_prompts(
    layer_to_map: Dict[Hashable, Dict[Hashable, np.ndarray]],
    *,
    dtype: np.dtype = np.float32,
    order: str = "sorted",
    ref_layer: int = 0,
    as_sparse: bool = True,
    zero_tol: float = 0.0,
) -> Tuple[List[Hashable], List[Union[np.ndarray, sp.csr_matrix]], List[Hashable], dict]:
    """
    Align per-layer {prompt_id -> vector} dicts to a common prompt set and row order.

    Only prompts present in *every* layer are kept. Every returned matrix has the
    same number of rows, and row ``i`` always refers to ``ordered_prompt_ids[i]``.

    Args
    ----
    layer_to_map: dict[layer] -> {prompt_id: 1D activation vector for that layer}
    dtype: output dtype
    order: "sorted" (sort the common prompt ids) or "ref" (keep the insertion
        order of the prompts in the reference layer)
    ref_layer: position, in the sorted list of layer keys, of the reference layer
        used when ``order="ref"``
    as_sparse: if True return CSR matrices, otherwise dense ndarrays
    zero_tol: entries with ``|x| <= zero_tol`` are treated as zero (left out of
        the sparse structure / zeroed in dense output); 0.0 keeps every non-zero

    Returns
    -------
    ordered_prompt_ids: list of prompt_ids (row 0..P-1)
    X_list: list of matrices, one per layer (aligned rows), each (P_common, D_layer)
    layer_numbers: list of layers in the order processed (sorted by key)
    stats: dict with P_all, P_common, removed_per_layer, dims_per_layer

    Raises
    ------
    ValueError: no layers, unknown ``order``, or a vector whose shape differs from
        the other vectors of the same layer.
    """
    if not layer_to_map:
        raise ValueError("layer_to_map must contain at least one layer.")
    if order not in ("sorted", "ref"):
        raise ValueError(f"order must be 'sorted' or 'ref', got {order!r}")

    layer_numbers = sorted(layer_to_map.keys())

    # Prompts present in ALL layers.
    common = set.intersection(*(set(layer_to_map[L].keys()) for L in layer_numbers))
    P_common = len(common)

    # Global row order for the prompts.
    if order == "sorted":
        ordered_prompt_ids = sorted(common)
    else:
        ref_map = layer_to_map[layer_numbers[ref_layer]]
        ordered_prompt_ids = [qid for qid in ref_map if qid in common]

    removed_per_layer = {}
    dims_per_layer = {}
    X_list: List[Union[np.ndarray, sp.csr_matrix]] = []

    for layer in layer_numbers:
        m = layer_to_map[layer]
        removed_per_layer[layer] = len(m) - P_common
        # Infer the layer width from any vector; every aligned vector is checked against it.
        D = int(np.asarray(next(iter(m.values()))).shape[0]) if m else 0
        dims_per_layer[layer] = D

        row_vals: List[np.ndarray] = []
        row_cols: List[np.ndarray] = []
        indptr = np.zeros(P_common + 1, dtype=np.int64)
        dense = None if as_sparse else np.zeros((P_common, D), dtype=dtype)

        for i, qid in enumerate(ordered_prompt_ids):
            v = np.asarray(m[qid], dtype=dtype, order="C")
            if v.shape != (D,):
                raise ValueError(
                    f"layer {layer!r}, prompt {qid!r}: expected a 1-D vector of "
                    f"length {D}, got shape {v.shape}"
                )
            # Written as ~(|v| <= tol) so NaNs count as non-zero, like np.flatnonzero.
            keep = ~(np.abs(v) <= zero_tol)
            if as_sparse:
                nz = np.flatnonzero(keep)
                row_cols.append(nz.astype(np.int32))
                row_vals.append(v[nz])
                indptr[i + 1] = indptr[i] + nz.size
            else:
                dense[i] = np.where(keep, v, 0)

        if as_sparse:
            # Build CSR from per-row pieces without densifying the layer.
            data = np.concatenate(row_vals) if row_vals else np.empty(0, dtype=dtype)
            indices = np.concatenate(row_cols) if row_cols else np.empty(0, dtype=np.int32)
            X_list.append(
                sp.csr_matrix((data, indices, indptr), shape=(P_common, D), dtype=dtype)
            )
        else:
            X_list.append(dense)

    stats = dict(
        P_all={L: len(layer_to_map[L]) for L in layer_numbers},
        P_common=P_common,
        removed_per_layer=removed_per_layer,
        dims_per_layer=dims_per_layer,
    )
    return ordered_prompt_ids, X_list, layer_numbers, stats


In [None]:
# Align all layers to the prompts they have in common (prompts sorted by id,
# CSR output, only exact zeros dropped). Only `dtype` is passed so the call also
# matches the minimal signature of align_layers_to_common_prompts.
ordered_ids, X_list, layers, stats = align_layers_to_common_prompts(
    activations,
    dtype=np.float32,
)

print("common prompts:", stats["P_common"])
# Layer widths are read off the aligned matrices themselves.
print("dims per layer:", {L: X.shape[1] for L, X in zip(layers, X_list)})
print("removed per layer:", stats["removed_per_layer"])
# X_list is now aligned: for every layer L, row i corresponds to prompt ordered_ids[i]


#### Helper Functions

In [None]:
def _ensure_csr_float32(X) -> sp.csr_matrix:
    """Ensure SciPy CSR float32 without densifying."""
    if sp.issparse(X):
        X = X.tocsr()
        if X.dtype != np.float32:
            X = X.astype(np.float32, copy=False)
        return X
    # dense path: convert to CSR without copying huge intermediates
    X = np.asarray(X)
    if X.dtype != np.float32:
        X = X.astype(np.float32, copy=False)
    return sp.csr_matrix(X)  # respects zeros; safe if you have many zeros


def _idf_vector_from_csr(X_csr: sp.csr_matrix, smooth: float = 1.0) -> np.ndarray:
    """Compute column-wise IDF: log((P+smooth)/(df+smooth))+1, float32."""
    P, N = X_csr.shape
    # df: nonzero count per column
    df = np.diff(X_csr.tocsc().indptr).astype(np.float32, copy=False)  # length N
    idf = np.log((np.float32(P) + smooth) / (df + smooth)) + np.float32(1.0)
    return idf.astype(np.float32, copy=False)


def _apply_tfidf_rowl2(X_csr: sp.csr_matrix, idf: np.ndarray) -> sp.csr_matrix:
    """Right-multiply by diag(idf) then row L2 normalize (all sparse)."""
    X = X_csr @ sp.diags(idf, dtype=np.float32, format="csr")
    # row L2 norms
    sq = X.multiply(X).sum(axis=1).A1.astype(np.float32, copy=False)
    inv = np.float32(1.0) / np.sqrt(np.maximum(sq, np.float32(1e-12)))
    return sp.diags(inv, dtype=np.float32, format="csr") @ X


def _countsketch_batch_to_dense(
    sub_csr: sp.csr_matrix,
    buckets: np.ndarray,   # (N_l,) int32 in [0, out_dim)
    signs:   np.ndarray,   # (N_l,) int8  in {-1, +1}
    out_dim: int,
) -> np.ndarray:
    """
    CountSketch current batch (rows of CSR) into a small dense (rows x out_dim).
    We use a sparse COO to accumulate into buckets, then .toarray().
    """
    coo = sub_csr.tocoo(copy=False)
    if coo.nnz == 0:
        return np.zeros((sub_csr.shape[0], out_dim), dtype=np.float32)
    hb = buckets[coo.col].astype(np.int32, copy=False)
    sg = signs[coo.col].astype(np.float32, copy=False)
    vals = (coo.data.astype(np.float32, copy=False) * sg)
    sketch = sp.coo_matrix((vals, (coo.row.astype(np.int32, copy=False), hb)),
                           shape=(sub_csr.shape[0], out_dim), dtype=np.float32)
    sketch.sum_duplicates()
    return sketch.toarray()  # tiny dense (e.g., 2048 x 64)


def _mutualize_topk_numpy(
    idx_k: np.ndarray, sim_k: np.ndarray, P: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Turn directed top-k neighbors into mutual, undirected edges.
    Returns unique (src, dst, weight) with src<dst and weight = mean of both dirs.
    idx_k: (P, k) int32 neighbors; sim_k: (P, k) float32 similarities (cosines).
    """
    k = idx_k.shape[1]
    src = np.repeat(np.arange(P, dtype=np.int32), k)
    dst = idx_k.reshape(-1).astype(np.int32, copy=False)
    w   = sim_k.reshape(-1).astype(np.float32, copy=False)

    # drop self and invalid (-1)
    mask = (src != dst) & (dst >= 0)
    a = np.minimum(src[mask], dst[mask])
    b = np.maximum(src[mask], dst[mask])
    w = w[mask]

    order = np.lexsort((b, a))
    a, b, w = a[order], b[order], w[order]
    eq = (a[1:] == a[:-1]) & (b[1:] == b[:-1])
    hit = np.where(eq)[0]  # indices of second in each pair

    src_u = a[hit]
    dst_u = b[hit]
    w_u   = 0.5 * (w[hit] + w[hit + 1])
    return src_u, dst_u, w_u


def _leiden_cpu_chunked(
    src: np.ndarray, dst: np.ndarray, weight: np.ndarray, n: int,
    resolution: float = 1.0, seed: int = 42, use_weights: bool = True,
    edge_chunk: int = 200_000
) -> Tuple[np.ndarray, float]:
    """
    Build an undirected igraph graph and run Leiden (RBConfigurationVertexPartition).

    Edges are added ``edge_chunk`` at a time, so only one chunk's worth of Python
    ints/floats is materialised at once.

    Args
    ----
    src, dst: (m,) integer edge endpoints (vertex ids in [0, n)).
    weight: (m,) edge weights, attached as the "weight" edge attribute when
        ``use_weights`` is True.
    n: number of vertices (isolated vertices included).
    resolution, seed: forwarded to ``leidenalg.find_partition``.

    Returns
    -------
    (membership, quality): (n,) int32 community id per vertex and ``part.quality()``.
    """
    g = ig.Graph()
    g.add_vertices(int(n))
    m = int(src.shape[0])
    for s in range(0, m, edge_chunk):
        e = min(s + edge_chunk, m)
        # .tolist() hands igraph plain Python ints/floats instead of NumPy scalars.
        g.add_edges(zip(src[s:e].tolist(), dst[s:e].tolist()))
        if use_weights:
            # Edges are appended in order, so this chunk occupies edge ids [s, e).
            g.es[s:e]["weight"] = weight[s:e].tolist()

    # With no edges the "weight" attribute is never created above, so run the
    # unweighted partition instead of asking leidenalg for a missing attribute.
    weights = "weight" if (use_weights and m > 0) else None
    part = la.find_partition(
        g, la.RBConfigurationVertexPartition,
        weights=weights, resolution_parameter=resolution, seed=seed,
    )
    return np.asarray(part.membership, dtype=np.int32), float(part.quality())


### Main Pipeline Function

In [None]:
"""
Stable, memory-safe prompt community detection across layers.

- Input: list of L layer matrices X_list, each shape (P, N_l).
         Dense NumPy (float32) or SciPy CSR is accepted.
- Output: community labels for P prompts, mutual-kNN graph, and diagnostics.

Design:
- Per layer: TF-IDF (column IDF + row L2) in sparse, then CountSketch to rp_dim.
- Concatenate per-layer sketches -> (P, D_total), final row L2 normalization.
- FAISS (CPU) HNSW kNN on the compact features (batched).
- Mutual-kNN graph + Leiden (CPU), with chunked edge loads to limit RAM spikes.
"""
def cluster_prompts_multilayer_cpu_safe(
    X_list: List[Union[np.ndarray, sp.spmatrix]],  # L layers, each (P, N_l)
    rp_dim: Union[int, List[int]] = 64,            # per-layer sketch dim (int or list[L])
    k: int = 40,                                   
    hnsw_M: int = 32,                               # HNSW graph degree
    efC: int = 200,                                 # HNSW construction ef
    efS: int = 64,                                  # HNSW search ef
    batch_rows: int = 2048,                         # rows per batch for sketching
    resolution: float = 1.0,                        # Leiden resolution
    seed: int = 42,
    layer_weights: Optional[List[float]] = None,    # reliability weights per layer (√w scaling)
    edge_chunk: int = 200_000,                      # edges per chunk to add to igraph
) -> dict:

    assert len(X_list) >= 1, "Provide at least one layer."
    P = X_list[0].shape[0]
    for Xl in X_list:
        assert Xl.shape[0] == P, "All layers must have the same number of prompts."

    if isinstance(rp_dim, int):
        rp_dims = [int(rp_dim)] * len(X_list)
    else:
        assert len(rp_dim) == len(X_list)
        rp_dims = [int(d) for d in rp_dim]

    rng = np.random.RandomState(seed)
    F_parts = []

    # --- Per layer: sparse TF-IDF + CountSketch (batched) ---
    for li, (Xl, dli) in enumerate(zip(X_list, rp_dims)):
        X = _ensure_csr_float32(Xl)                      # CSR float32
        idf = _idf_vector_from_csr(X)                    # (N_l, )

        # hash functions
        buckets = rng.randint(0, dli, size=X.shape[1], dtype=np.int32)
        signs   = rng.choice(np.array([-1, 1], dtype=np.int8), size=X.shape[1])

        F_l = np.empty((P, dli), dtype=np.float32)
        for s in range(0, P, batch_rows):
            e = min(s + batch_rows, P)
            sub = X[s:e]                                 # CSR slice
            sub = _apply_tfidf_rowl2(sub, idf)           # TF-IDF + row L2 (sparse)
            F_batch = _countsketch_batch_to_dense(sub, buckets, signs, dli)
            # optional layer weight (√w scaling)
            if layer_weights is not None:
                F_batch *= np.sqrt(np.float32(layer_weights[li]))
            F_l[s:e] = F_batch

        # light per-layer row L2 (helps if layer coverage differs)
        norms = np.linalg.norm(F_l, axis=1, keepdims=True)
        norms[norms < 1e-8] = 1e-8
        F_l /= norms
        F_parts.append(F_l)

    # --- Fuse across layers: concat + FINAL row L2 normalization ---
    F_host = np.concatenate(F_parts, axis=1).astype(np.float32)  # (P, D_total)
    del F_parts
    norms = np.linalg.norm(F_host, axis=1, keepdims=True)
    norms[norms < 1e-8] = 1e-8
    F_host /= norms
    d_total = int(F_host.shape[1])

    # --- kNN via FAISS (CPU) HNSW (cosine via L2 on unit vectors) ---
    # For unit vectors, cosine(x,y) = 1 - 0.5 * ||x - y||^2
    index = faiss.IndexHNSWFlat(d_total, hnsw_M)  # L2 by default
    index.hnsw.efConstruction = int(efC)
    index.hnsw.efSearch = int(efS)

    index.add(F_host)  # build graph; multi-threaded in FAISS
    kk = min(max(2, k + 8), P)  # ask a few extras; we'll drop self

    # batched search to limit RAM
    I = np.empty((P, kk), dtype=np.int64)
    D = np.empty((P, kk), dtype=np.float32)  # squared L2 distances
    bs = 8192
    for s in range(0, P, bs):
        e = min(s + bs, P)
        Di, Ii = index.search(F_host[s:e], kk)
        D[s:e] = Di
        I[s:e] = Ii

    # Convert distances to cosine sims, drop self/invalid, keep top-k
    idx_k = np.empty((P, k), dtype=np.int32)
    sim_k = np.empty((P, k), dtype=np.float32)
    for i in range(P):
        neigh = I[i]
        dist2 = D[i]
        # FAISS returns -1 when no neighbor; remove invalids and self
        good = (neigh >= 0) & (neigh != i)
        neigh = neigh[good]
        dist2 = dist2[good]
        sim = 1.0 - 0.5 * dist2  # cosine for unit vectors
        # take top-k by similarity
        take = min(k, neigh.shape[0])
        if take == 0:
            idx_k[i, :] = i
            sim_k[i, :] = 0.0
        else:
            order = np.argsort(-sim)[:take]
            idx_k[i, :take] = neigh[order].astype(np.int32, copy=False)
            sim_k[i, :take] = sim[order].astype(np.float32, copy=False)
            if take < k:
                idx_k[i, take:] = idx_k[i, :1]
                sim_k[i, take:] = sim_k[i, :1]

    # --- Mutual kNN graph (unique undirected edges) ---
    src, dst, weight = _mutualize_topk_numpy(idx_k, sim_k, P)

    # --- Leiden (chunked edges to avoid RAM spikes) ---
    labels, quality = _leiden_cpu_chunked(
        src, dst, weight, n=P, resolution=resolution, seed=seed,
        use_weights=True, edge_chunk=edge_chunk
    )

    return {
        "labels": labels,                    # (P,) community id per prompt
        "edges": (src, dst, weight),         # mutual-kNN graph
        "quality": quality,                  # modularity-like score
        "params": {
            "P": P, "layer_dims": [X.shape[1] for X in X_list], "rp_dims": rp_dims,
            "k": int(k), "hnsw_M": int(hnsw_M), "efC": int(efC), "efS": int(efS),
            "batch_rows": int(batch_rows), "resolution": float(resolution),
        },
    }

### Running the Pipeline

In [None]:
# Run the full sketch -> kNN -> mutual graph -> Leiden pipeline on the aligned layers.
out = cluster_prompts_multilayer_cpu_safe(
    X_list,
    rp_dim=64,         # per-layer sketch width; try 64 first, 96 if stability is low
    k=40,              # neighbours per prompt before mutualisation
    hnsw_M=32,         # HNSW degree: larger => better recall, slower build; 32-48 is good
    efC=200,           # HNSW construction quality
    efS=64,            # HNSW search quality (raise to 96 for higher recall)
    batch_rows=2048,   # sketching batch size; reduce to 1024 if RAM is tight
    resolution=1.0,    # Leiden resolution
    seed=42,
)
labels = out["labels"]
n_clusters = len(set(labels))
print("Leiden quality:", out["quality"], "num clusters:", n_clusters)


### Inspecting the Communities

##### Functions

In [None]:

# =============== Basics & Graph Construction =================

def partition_summary(labels: np.ndarray) -> Dict:
    """Summarise a partition: node count, community count, and communities ordered by size.

    Returns a dict with ``n_nodes``, ``n_comms``, ``sizes_sorted`` (community sizes,
    largest first) and ``comms_sorted`` (the matching community ids).
    """
    flat = np.asarray(labels, dtype=np.int32)
    ids, sizes = np.unique(flat, return_counts=True)
    by_size_desc = np.argsort(-sizes)
    return dict(
        n_nodes=flat.size,
        n_comms=ids.size,
        sizes_sorted=sizes[by_size_desc],
        comms_sorted=ids[by_size_desc],
    )

# =============== Prototypes (per community & layer) =================

def community_prototypes_by_layer(
    X_list: List[Union[np.ndarray, sp.spmatrix]],
    labels: np.ndarray,
    normalize: bool = False  # if True: z-score neurons across prompts before averaging
) -> List[np.ndarray]:
    """
    For each layer ℓ: return a (C, N_ℓ) float32 array, mean activation per community.

    Rows follow ``np.unique(labels)`` order (ascending community id).

    If ``normalize`` is True, neurons are z-scored across all prompts first. The
    community mean is linear, so this equals ``(community_mean - mu) / sd`` and is
    computed that way. No centred copy of X is materialised, so sparse layers stay
    sparse. ``sd`` is ``sqrt(max(var, 1e-8))`` (population variance) for both dense
    and sparse inputs.
    """
    labels = np.asarray(labels).astype(np.int32)
    comms, inv = np.unique(labels, return_inverse=True)
    inv = inv.ravel()
    C = comms.size
    P = labels.shape[0]
    counts = np.bincount(inv, minlength=C).astype(np.float64)
    # (P, C) one-hot membership matrix; S.T @ X sums the rows of each community.
    S = sp.csr_matrix(
        (np.ones(P, dtype=np.float64), (np.arange(P), inv)), shape=(P, C)
    )

    protos = []
    for X in X_list:
        if sp.issparse(X):
            X64 = X.tocsr().astype(np.float64)
            sums = (S.T @ X64).toarray()
            if normalize:
                mu = np.asarray(X64.mean(axis=0)).ravel()
                ex2 = np.asarray(X64.multiply(X64).mean(axis=0)).ravel()
                var = ex2 - mu * mu  # exact: implicit zeros are included in both means
        else:
            X = np.asarray(X, dtype=np.float32)
            sums = np.asarray(S.T @ X)
            if normalize:
                mu = X.mean(axis=0, dtype=np.float64)
                var = X.var(axis=0, dtype=np.float64)
        M = sums / counts[:, None]
        if normalize:
            M = (M - mu) / np.sqrt(np.maximum(var, 1e-8))
        protos.append(M.astype(np.float32))
    return protos  # list of length L, each (C, N_l)

def top_neurons_for_community(
    layer_protos: np.ndarray,  # (C, N_l) for one layer
    community_index: int, topk: int = 20,
    compare_to: str = "global"  # "global" or "others"
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Score every neuron of one layer for a single community and return the best ``topk``.

    The score is the community's prototype minus a baseline:
      - ``compare_to == "global"``: baseline is the mean prototype over all communities;
      - any other value ("others"): baseline is the mean over the remaining
        communities. It is all zeros when there is only one community.

    Returns (neuron indices, their scores), highest score first.
    """
    n_comms, _ = layer_protos.shape
    target = layer_protos[community_index]
    if compare_to == "global":
        baseline = layer_protos.mean(axis=0)
    else:
        # Remove the target row from the column sums, then average what is left.
        baseline = (layer_protos.sum(axis=0) - target) / max(n_comms - 1, 1)
    delta = target - baseline
    ranked = np.argsort(-delta)[:topk]
    return ranked, delta[ranked]


##### Pipeline

In [None]:

edges = out["edges"]
src, dst, w = edges
P = len(labels)

# Partition overview
summ = partition_summary(labels)
print("nodes:", summ["n_nodes"], "communities:", summ["n_comms"])
print("largest sizes:", summ["sizes_sorted"][:10])

# Degree distribution of the mutual-kNN graph (degree 0 = prompt with no mutual neighbour)
deg = (np.bincount(src, minlength=P) + np.bincount(dst, minlength=P)).astype(np.int32)
print("deg=0:", (deg==0).sum(), "deg=1:", (deg==1).sum())
print("deg quantiles:", np.quantile(deg, [0, .1, .25, .5, .75, .9, .99, 1.0]))

# Connected components of the same undirected graph
edge_pairs = list(zip(src.tolist(), dst.tolist()))
g = ig.Graph(n=P, edges=edge_pairs, directed=False)
print("components:", g.components().summary())

### Back Into Query IDs

In [None]:
def ids_by_community(labels: np.ndarray, prompt_ids: np.ndarray, min_size: int = 2, sort_by_size: bool = True) -> Tuple[Dict[int, np.ndarray], Dict[int, int]]:
    """
    Group prompt ids by community label.

    Args
    ----
    labels: length-P community id per prompt.
    prompt_ids: array-like of length P (strings/ints), aligned with ``labels``.
    min_size: communities with fewer members are dropped.
    sort_by_size: order communities by decreasing size (ties: ascending id);
        otherwise by ascending community id.

    Returns
    -------
    (groups, sizes):
        groups: {community_id: np.ndarray of that community's prompt ids, in their
            original order}
        sizes: {community_id: member count} for the same communities and order.

    Raises
    ------
    ValueError: if ``labels`` and ``prompt_ids`` differ in length.
    """
    labels = np.asarray(labels, dtype=np.int32)
    prompt_ids = np.asarray(prompt_ids)
    if prompt_ids.shape[0] != labels.shape[0]:
        raise ValueError(
            f"labels ({labels.shape[0]}) and prompt_ids ({prompt_ids.shape[0]}) differ in length"
        )

    comms, counts = np.unique(labels, return_counts=True)
    keep = counts >= min_size
    comms_k, counts_k = comms[keep], counts[keep]
    if sort_by_size:
        order = np.argsort(-counts_k, kind="stable")
        comms_k, counts_k = comms_k[order], counts_k[order]

    # One stable sort groups every community's members contiguously while keeping
    # their original order, instead of scanning all labels once per community.
    perm = np.argsort(labels, kind="stable")
    starts = np.searchsorted(labels[perm], comms_k, side="left")

    groups = {
        int(c): prompt_ids[perm[s:s + n]]
        for c, s, n in zip(comms_k, starts, counts_k)
    }
    sizes = {int(c): int(n) for c, n in zip(comms_k, counts_k)}
    return groups, sizes

In [None]:
# community id -> array of prompt ids (communities below the default min_size=2 are dropped)
community_to_ids, sizes = ids_by_community(labels, prompt_ids=ordered_ids)

##### Saving community -> queries

In [None]:
# NOTE(review): despite the .json extension this file is read with torch.load (pickle-based),
# so presumably it was written with torch.save — confirm, and only load trusted files this way.
# Presumably maps query id -> query text (it is indexed by prompt ids in the next cell) — verify.
qid_to_queries = torch.load("../q_ids_to_queries.json")

In [None]:
# Replace each community's prompt ids with the corresponding queries and save as JSON
# (json turns the integer community ids into string keys).
community_to_queries = {
    com: [qid_to_queries[qid] for qid in member_ids]
    for com, member_ids in community_to_ids.items()
}

with open("communities_to_queries.json", "w") as out_file:
    json.dump(community_to_queries, out_file)