
# Generalized Frangi with Multi-modal Fusion on FIND (Illustrative Notebook)

This Colab notebook demonstrates the complete pipeline (download, unzip, Hessians, fusion, Frangi graph, HDBSCAN, MST + k-centers, animation, and metrics).


In [None]:

# Install deps
!pip -q install numpy scipy scikit-image matplotlib joblib tqdm tqdm-joblib hdbscan networkx gdown pot imageio pandas Pillow


In [None]:

import os, sys, glob, re, random, numpy as np, imageio, matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib
from joblib import Parallel, delayed

# Make the parent directory importable so that `frangi_fusion` can be found
# when the notebook is run from inside the repository (e.g. in Colab after
# uploading or mounting it). The path is appended only once.
repo_path = os.path.abspath(os.pardir)
already_on_path = repo_path in sys.path
if not already_on_path:
    sys.path.append(repo_path)

from frangi_fusion import (set_seed, auto_discover_find_structure, load_modalities_and_gt_by_index,
                           to_gray, compute_hessians_per_scale, fuse_hessians_per_scale,
                           build_frangi_similarity_graph, distances_from_similarity, triangle_connectivity_graph,
                           largest_connected_component, hdbscan_from_sparse,
                           mst_on_cluster, kcenters_on_tree, fault_graph_from_mst_and_kcenters,
                           skeletonize_lee, jaccard_index, tversky_index, wasserstein_distance_skeletons, thicken,
                           overlay_hessian_orientation, show_clusters_on_image, animate_fault_growth)


## 1) Download FIND `data.zip` and unzip

In [None]:

import gdown, zipfile

# Google Drive location of the FIND dataset archive and the local file names.
url = "https://drive.google.com/uc?id=1qnLMCeon7LJjT9H0ENiNF5sFs-F7-NvK"
zip_path = "data.zip"
extract_dir = "data_find"

# Download only when no local copy exists, so re-running the cell is cheap.
if not os.path.exists(zip_path):
    gdown.download(url, zip_path, quiet=False)

# Because an existing file short-circuits the download above, a truncated or
# otherwise invalid `data.zip` left over from an earlier attempt would make
# every re-run fail. Check explicitly and say how to recover.
if not os.path.exists(zip_path):
    raise FileNotFoundError(
        f"Download failed: {zip_path!r} was not created from {url}"
    )
if not zipfile.is_zipfile(zip_path):
    raise zipfile.BadZipFile(
        f"{zip_path!r} is not a valid zip archive (the download may have been "
        "interrupted or blocked); delete it and re-run this cell."
    )

# Extract the whole archive into `extract_dir` (created if missing).
os.makedirs(extract_dir, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zf:
    zf.extractall(extract_dir)
print("Unzipped to:", extract_dir)


## 2) Pick one image and display modalities

In [None]:

# Fix the seed before drawing the sample index.
# NOTE(review): presumably `set_seed` seeds Python's `random` (used below) —
# confirm in `frangi_fusion`.
set_seed(1234)

# Discover the dataset layout under the extraction directory.
struct = auto_discover_find_structure(extract_dir)

# Count samples from the label list when it is non-empty, otherwise from the
# intensity list.
if struct["label"]:
    n_total = len(struct["label"])
else:
    n_total = len(struct["intensity"])

# Draw one index; max(0, ...) keeps the upper bound non-negative when
# n_total is 0.
index = random.randint(0, max(0, n_total - 1))
dat = load_modalities_and_gt_by_index(struct, index)
print("Selected index:", index)
for name, path in dat["paths"].items():
    print(name, "->", path)

# One panel per loaded array, laid out in a single row.
n_panels = len(dat["arrays"])
plt.figure(figsize=(4 * n_panels, 4))
for panel, (name, image) in enumerate(dat["arrays"].items(), start=1):
    plt.subplot(1, n_panels, panel)
    plt.title(name)
    plt.imshow(image, cmap='gray')
    plt.axis('off')
plt.show()


## 3) Parameters

In [None]:

# Scales handed to `compute_hessians_per_scale` for every modality.
sigmas = list(range(1, 10, 2))  # [1, 3, 5, 7, 9]

# Passed positionally to `build_frangi_similarity_graph` as
# (beta, c, ctheta, R).
# NOTE(review): their exact meaning is defined in `frangi_fusion` — confirm there.
beta, c, ctheta = 0.5, 0.25, 0.125
R = 5

# Triangle-connectivity is applied to the distance graph only when K == 2.
K = 1

# Intended for the `expZ` argument of `hdbscan_from_sparse` — TODO confirm meaning.
expZ = 2.0


## 4–5) Hessians and fusion

In [None]:

# Compute per-scale Hessians for whichever of these modalities were loaded,
# always in this fixed order; each array is converted with `to_gray` first.
mods = {
    name: compute_hessians_per_scale(to_gray(dat["arrays"][name]), sigmas)
    for name in ("intensity", "range", "fused")
    if name in dat["arrays"]
}

# Every available modality contributes with the same unit weight.
weights = dict.fromkeys(mods, 1.0)
fused_H = fuse_hessians_per_scale(mods, weights)
print("Fused using modalities:", list(mods))


## 6) Visualize Hessian overlays

In [None]:

# Background image for the visualizations: the intensity array when it was
# loaded, otherwise the first loaded array.
arrays = dat["arrays"]
base = arrays["intensity"] if "intensity" in arrays else list(arrays.values())[0]

plt.figure(figsize=(15, 4))
# Only the first three fused scales are shown, one panel each.
for panel, hess in enumerate(fused_H[:3], start=1):
    overlay = overlay_hessian_orientation(base, hess, alpha=0.5)
    plt.subplot(1, 3, panel)
    plt.title(f"Sigma={hess['sigma']}")
    plt.imshow(overlay)
    plt.axis('off')
plt.show()


## 7–8) Frangi graph and optional triangle-connectivity

In [None]:

# Build the similarity graph from the fused Hessians; the call returns node
# coordinates, a neighbor structure and the similarity matrix.
coords, neighbors, S = build_frangi_similarity_graph(
    fused_H, beta, c, ctheta, R
)

# Turn similarities into distances for the clustering stages below.
D = distances_from_similarity(S)

# Optional step: only K == 2 routes the distances through the
# triangle-connectivity graph.
use_triangle_connectivity = (K == 2)
if use_triangle_connectivity:
    D = triangle_connectivity_graph(coords, D)

n_nodes, n_nonzero = D.shape[0], D.nnz
print("Graph nodes:", n_nodes, "non-zeros:", n_nonzero)


## 9) Largest connected component

In [None]:

# Keep only the largest connected component of the distance graph.
largest_cc = largest_connected_component(D)
D_cc, idx_nodes = largest_cc
print("Largest CC nodes:", D_cc.shape[0])

# `idx_nodes` selects the rows of `coords` that belong to the kept nodes.
sub_coords = coords[idx_nodes]


## 10–11) HDBSCAN and cluster display

In [None]:

# Cluster the nodes of the largest connected component from their sparse
# distance matrix.
labels = hdbscan_from_sparse(
    D_cc,
    min_cluster_size=50,
    min_samples=5,
    allow_single_cluster=True,
    # Use the value from the parameters cell instead of a duplicated literal,
    # so that changing `expZ` there actually takes effect here.
    expZ=expZ,
)
print("Clusters:", np.unique(labels))

# Draw the cluster assignment of each node on top of the background image.
show_clusters_on_image(base, sub_coords, labels, figsize=(5,5))


## 12) MST + k-centers -> fault graph

In [None]:

# Collect fault-graph edges as rows [r0, c0, r1, c1, weight], one cluster at
# a time.
fault_edges_list = []
for lab in np.unique(labels):
    # Negative labels are skipped.
    # NOTE(review): presumably these mark HDBSCAN noise points — confirm.
    if lab < 0:
        continue
    members = np.where(labels == lab)[0]
    # Clusters with fewer than three nodes are ignored.
    if members.size < 3:
        continue

    tree = mst_on_cluster(D_cc, members)
    # At least three centers, plus roughly one per hundred cluster nodes.
    n_centers = max(3, members.size // 100)
    centers = kcenters_on_tree(tree, n_centers, objective="max")
    Gf = fault_graph_from_mst_and_kcenters(tree, centers, weight_agg="mean")

    # NOTE(review): the row/column indices of `Gf` are used directly as
    # indices into `sub_coords`. That is only right if `mst_on_cluster` keeps
    # the node numbering of `D_cc`; if it renumbers nodes within the cluster,
    # they would need to be mapped through `members` — confirm.
    src, dst = Gf.nonzero()
    for a, b in zip(src, dst):
        # Keep each undirected edge once (upper triangle only).
        if a >= b:
            continue
        weight = float(Gf[a, b])
        (r0, c0), (r1, c1) = sub_coords[a], sub_coords[b]
        fault_edges_list.append([int(r0), int(c0), int(r1), int(c1), weight])

# Always end up with an (n, 5) float32 array, even when no edge was found.
if fault_edges_list:
    fault_edges = np.array(fault_edges_list, dtype=np.float32)
else:
    fault_edges = np.zeros((0, 5), dtype=np.float32)
print("Fault edges:", fault_edges.shape)


## 13) Animation

In [None]:

from IPython.display import Image, display

# Output file for the animated GIF.
anim_path = "fault_growth.gif"
if fault_edges.shape[0] > 0:
    # Draw on `base` (the intensity image when available, otherwise the first
    # loaded array) like the other visualization cells, instead of indexing
    # dat['arrays']['intensity'] directly, which raises KeyError when the
    # intensity modality was not loaded.
    animate_fault_growth(base, fault_edges, anim_path, steps=25)
    display(Image(filename=anim_path))
else:
    print("No fault edges to animate.")


## 14) Threshold at τ = 0.3

In [None]:

# Keep only the fault edges whose weight (last column) is at most tau.
tau = 0.3

# Use the shared background image `base` (intensity when available, otherwise
# the first loaded array) rather than dat['arrays']['intensity'], which
# raises KeyError when the intensity modality was not loaded.
img = base

# Three-channel float copy of the background to draw the edges on.
overlay = np.dstack([img, img, img]).astype(np.float32)
H, W = img.shape[:2]
thr_edges = fault_edges[fault_edges[:, -1] <= tau]

for e in thr_edges:
    r0, c0, r1, c1 = e[:4]
    # Sample one point per pixel step along the longer axis of the segment.
    n_pts = int(max(abs(r1 - r0), abs(c1 - c0)) + 1)
    rr = np.linspace(r0, r1, num=n_pts).astype(int)
    cc = np.linspace(c0, c1, num=n_pts).astype(int)
    # Clamp to the image bounds before indexing.
    rr = np.clip(rr, 0, H - 1)
    cc = np.clip(cc, 0, W - 1)
    # Paint the segment pure red.
    overlay[rr, cc] = (255, 0, 0)

plt.figure(figsize=(5, 5))
plt.imshow(overlay.astype(np.uint8))
plt.axis('off')
plt.show()


## 15) Metrics vs GT (Lee skeleton)

In [None]:

# Rasterize the thresholded edges into a binary mask shaped like `img`.
mask = np.zeros_like(img, dtype=np.uint8)
for edge in thr_edges:
    r0, c0, r1, c1 = edge[:4]
    # One sample per pixel step along the longer axis of the segment.
    n_pts = int(max(abs(r1 - r0), abs(c1 - c0)) + 1)
    rr = np.clip(np.linspace(r0, r1, num=n_pts).astype(int), 0, img.shape[0] - 1)
    cc = np.clip(np.linspace(c0, c1, num=n_pts).astype(int), 0, img.shape[1] - 1)
    mask[rr, cc] = 1

# Prediction: skeletonize the mask, then thicken the skeleton.
sk_pred = thicken(skeletonize_lee(mask > 0), pixels=6)

# Ground truth: the "label" array binarized at > 0; an all-zero image is used
# when no label was loaded. It gets the same skeletonize + thicken treatment.
label_img = dat["arrays"].get("label", np.zeros_like(img))
gt = (label_img > 0).astype(np.uint8)
sk_gt = thicken(skeletonize_lee(gt), pixels=6)

# Compare the two thickened skeletons.
jac = jaccard_index(sk_pred, sk_gt)
tvs = tversky_index(sk_pred, sk_gt, alpha=1.0, beta=0.5)
wass = wasserstein_distance_skeletons(sk_pred, sk_gt)
print("Jaccard:", jac, "Tversky:", tvs, "Wasserstein:", wass)

# Panels 1-2 show each skeleton on its own; panel 3 stacks them.
plt.figure(figsize=(10, 4))
panels = [("GT (thick skel)", sk_gt), ("Pred (thick skel)", sk_pred)]
for pos, (title, skel) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    plt.title(title)
    plt.imshow(skel, cmap='gray')
    plt.axis('off')
plt.subplot(1, 3, 3)
plt.title("Overlay")
plt.imshow(sk_gt * 255, cmap='Reds', alpha=0.7)
plt.imshow(sk_pred * 255, cmap='Blues', alpha=0.5)
plt.axis('off')
plt.show()
