## **Skeletonisation and Exponential Decay Fitting**

In [42]:
# Single-TIFF pipeline on the CENTRAL SQUARE crop (width = height of image)
# Steps: crop -> polarity auto-select -> remove specks -> centroids -> adaptive/connectivity dilation -> skeletonize
# Preview PNG: 2-px thicker (visual only). Exponential plot: blue dots, red line, green dashed.

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.ndimage import convolve, distance_transform_edt
from scipy.spatial import cKDTree
from scipy.optimize import curve_fit

from skimage.io import imread
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.morphology import (
    binary_dilation, binary_closing, remove_small_objects,
    disk, skeletonize
)

In [43]:
# branch lengths from skeleton (needs skan library)
"""
trying to import skan for measuring branch lengths
if it doesnt work just install it with pip install skan
"""
try:
    from skan import Skeleton, summarize
    SKAN_AVAILABLE = True
except:
    SKAN_AVAILABLE = False
    print("skan not installed - need it for proper branch lengths")
    print("run: pip install skan")

In [44]:
# config section change these paths for your computer
"""
setting up all the file paths and folders
input_tif is the image i want to analyze
the output folders are where results get saved
"""
import os
from pathlib import Path
notebook_dir = Path.cwd()
proj_root = notebook_dir.parent
org_dir = str(proj_root)
BASE_DIR = str(proj_root / "CRYO-SEM DATA" / "CRYO-SEM X30000")

input_tif = os.path.join(BASE_DIR , "SAMPLE TEST IMAGE X30000.tif")

out_skeleton_img_dir = os.path.join(org_dir , "SKELETON +EDF/SKELETON")
out_csv_dir          = os.path.join(org_dir , "SKELETON +EDF/ SKELETON +EDF/ CSVS")
out_decay_dir        = os.path.join(org_dir , "SKELETON +EDF/EXPONENTIAL DECAY FITTING")

# make sure all the output folders exist
os.makedirs(out_skeleton_img_dir, exist_ok=True)
os.makedirs(out_csv_dir, exist_ok=True)
os.makedirs(out_decay_dir, exist_ok=True)

# image scale info from the microscope
"""
metadata for converting pixels to micrometers
4.00 x 3.00 micrometers total size
1280 x 961 pixels but i crop width to 961 to make it square
"""
width_um, height_um = 4.00, 3.00
width_px_meta, height_px_meta = 1280, 961

# settings that might need tweaking
"""
these are parameters i can adjust if results look weird
DEFAULT_MIN_BRANCH_UM ignore really short branches
DEFAULT_HEAD_KEEP_PERCENT how much of branch tips to keep
DEFAULT_MIN_RADIUS_PX minimum dilation radius
"""
DEFAULT_MIN_BRANCH_UM = 0.0
DEFAULT_HEAD_KEEP_PERCENT = 100
DEFAULT_MIN_RADIUS_PX = 6
MAX_SAFE_BINARY_RADIUS_PX = 64
PREVIEW_THICKEN_DISK = 2
MIN_CC_AREA = 9  # drop tiny specks before centroiding

def exponential_decay_func(x, a, b, c):
    """exponential decay function for curve fitting"""
    return a * np.exp(-b * x) + c

def dilate_image_safely(pore_map_bool, radius_px, max_safe=MAX_SAFE_BINARY_RADIUS_PX):
    """
    dilate the binary image safely
    if radius is too big it uses distance transform instead to avoid memory errors
    """
    if radius_px <= max_safe:
        try:
            return binary_dilation(pore_map_bool, disk(int(radius_px)))
        except MemoryError:
            pass  # fallback to distance method
    
    # use distance transform method for big radius
    dist = distance_transform_edt(~pore_map_bool)
    return dist <= float(radius_px)

In [45]:
# load the tiff file and crop to central square
"""
checking if the input file exists first
then loading it and making it grayscale if needed
finally cropping to a square shape in the center
"""

# check if file actually exists
if not os.path.isfile(input_tif):
    print("Error: cant find the input file at " + input_tif)
    raise FileNotFoundError("Input file not found")

# load the image
img = imread(input_tif)
if img.ndim > 2:
    img = img[..., 0]  # take first channel if its color

# crop to center square
"""
making the image square by cropping the width
keeping the full height and centering the width crop
this makes analysis consistent across different images
"""
H, W = img.shape
side = H  # width equals height now
x0 = max(0, (W - side) // 2)  # start position for crop
x1 = x0 + side  # end position
img = img[:, x0:x1]  # do the actual cropping

In [46]:
# figure out if structures are bright or dark pixels
"""
trying to automatically detect whether the structures show up as 
bright pixels or dark pixels in the image
this saves me from having to check manually every time
"""

# get threshold value for splitting bright vs dark
try:
    t = threshold_otsu(img)  # otsu method usually works well
except Exception:
    t = float(np.mean(img))  # fallback to average if otsu fails

# make binary masks for both possibilities
bright_fg = img > t  # foreground is bright pixels
dark_fg   = img < t  # foreground is dark pixels

def score_mask(mask):
    """
    count how many objects look like real structure pieces
    not too tiny and not too huge
    """
    lbl = label(mask, connectivity=2)  # find connected components
    if lbl.max() == 0:
        return 0  # no objects found
    
    # get sizes of all objects
    areas = np.bincount(lbl.ravel())[1:]  # skip background label 0
    total = mask.size
    
    # set reasonable size limits
    area_min = max(MIN_CC_AREA, 3)  # not too small
    area_max = max(int(0.01 * total), area_min + 1)  # not too big
    
    # count objects in good size range
    return int(((areas >= area_min) & (areas <= area_max)).sum())

# pick whichever polarity gives more reasonable objects
chosen_mask = dark_fg if score_mask(dark_fg) >= score_mask(bright_fg) else bright_fg

In [47]:
# clean up the mask by removing tiny specks
"""
getting rid of really small dots that are probably just noise
doing this before finding centroids so we dont get random scattered points
"""
clean_mask = remove_small_objects(chosen_mask, min_size=MIN_CC_AREA, connectivity=2)

# find all the separate objects and get their centers
lbl = label(clean_mask, connectivity=2)  # label each connected piece
props = regionprops(lbl)  # get properties of each piece

# check if we actually found anything useful
if len(props) == 0:
    print("Error: no objects found after cleaning")
    print("maybe try changing MIN_CC_AREA or check if polarity is wrong")
    raise RuntimeError("No usable components found")

# get image dimensions
height_px, width_px = clean_mask.shape  # both should be the same since we made it square

# extract the center points of each object
"""
centroid gives (row, col) but i want (x, y) coordinates
so i flip them with [::-1]
then make sure they stay inside the image boundaries
"""
centroids = np.array([p.centroid[::-1] for p in props], dtype=float)  # flip to get (x, y)

# convert to integer pixel coordinates and keep them in bounds
xs = np.clip(np.round(centroids[:, 0]).astype(int), 0, width_px - 1)
ys = np.clip(np.round(centroids[:, 1]).astype(int), 0, height_px - 1)

In [48]:
# calculate how many micrometers each pixel represents
"""
need to convert from pixels to real world measurements
using the original image dimensions and field of view size
cropping doesnt change the scale just removes some pixels
"""
um_per_px_x = width_um  / width_px_meta  # micrometers per pixel in x direction
um_per_px_y = height_um / height_px_meta  # micrometers per pixel in y direction
um_per_px = float((um_per_px_x + um_per_px_y) / 2.0)  # average the two directions

# make a boolean map showing where the pore centers are
"""
creating a binary image where True means theres a pore center at that pixel
starting with all False then setting True at each centroid location
"""
pore_map = np.zeros((height_px, width_px), dtype=bool)  # start with all False
pore_map[ys, xs] = True  # mark the pore center locations

In [49]:
# figure out how big to make the dilation radius
"""
want to connect nearby pores but not merge everything into one blob
using the distance between neighboring pores to guess a good starting radius
"""
pts = np.column_stack([xs, ys])  # combine x and y coordinates into pairs

if len(pts) >= 2:
    # find distance to nearest neighbor for each point
    dists, _ = cKDTree(pts).query(pts, k=2)  # k=2 gives point itself and nearest neighbor
    nn_med = float(np.median(dists[:, 1]))  # take the neighbor distances (skip self at index 0)
    radius_start = max(DEFAULT_MIN_RADIUS_PX, int(np.ceil(nn_med / 2.0)))  # start with half the median distance
else:
    radius_start = DEFAULT_MIN_RADIUS_PX  # fallback if only one point

def connect_nearby_pores(pore_map, start_r, max_r=MAX_SAFE_BINARY_RADIUS_PX):
    """
    grow the radius until we get a reasonable number of connected components
    dont want too many tiny pieces or one giant blob
    """
    # set a target for max number of separate pieces
    target_max_components = max(50, int(0.02 * pore_map.sum()))  # generous cap
    best = None
    r = int(start_r)
    
    # try bigger and bigger radius until we connect enough stuff
    while r <= max_r:
        blobs = dilate_image_safely(pore_map, r)  # grow each pore by radius r
        blobs = binary_closing(blobs, footprint=disk(1))  # fill small gaps
        n_comp = label(blobs, connectivity=2).max()  # count separate pieces
        best = (blobs, r, n_comp)  # save this result
        
        if n_comp <= target_max_components:
            break  # good enough, stop growing
        r += 1  # try bigger radius
    
    return best  # give back the final blobs, radius used, and component count

# do the actual connection
pore_blobs, radius, n_comp = connect_nearby_pores(pore_map, radius_start)

In [50]:
# Skeletonize analysis mask (1 px)
skeleton = skeletonize(pore_blobs).astype(bool)


In [51]:
# count junction points in the skeleton
"""
a junction is where 3 or more branches meet
using a convolution to count neighbors around each pixel
"""

# make a kernel that counts neighbors plus gives center pixel extra weight
branch_kernel = np.array([[1, 1, 1],
                          [1,10, 1],
                          [1, 1, 1]])

# apply the kernel to count neighbors at each pixel
convolved = convolve(skeleton.astype(int), branch_kernel, mode='constant')

# find junction points
"""
center pixel gets value 10 if its part of skeleton
each neighbor adds 1 if its also skeleton
so total is 10 + number_of_neighbors
if >= 13 that means 10 + 3+ neighbors = junction point
"""
junctions = int(np.sum((convolved >= 13) & (skeleton == 1)))

In [52]:
# measure the length of each branch in the skeleton
"""
using skan library to trace all the branches and get their lengths
converting from pixels to micrometers for real measurements
"""
path_lengths_um = []  # start with empty list

# check if skan library is installed and working
if SKAN_AVAILABLE:
    # analyze the skeleton to find all branches
    sk = Skeleton(skeleton)
    branch_df = summarize(sk, separator="_")  # get dataframe with branch info
    
    # figure out which column has the distance data
    """
    different versions of skan use different column names
    so checking both possibilities
    """
    col = "branch_distance" if "branch_distance" in branch_df.columns else "branch-distance"
    
    # get branch lengths in pixels
    lengths_px = branch_df[col].to_numpy(dtype=float)
    
    # convert to micrometers
    path_lengths_um = (lengths_px * um_per_px)
    
    # filter out bad values and tiny branches
    """
    removing infinite values and branches that are too short to be real
    keeping only reasonable sized branches
    """
    path_lengths_um = path_lengths_um[np.isfinite(path_lengths_um) & (path_lengths_um > DEFAULT_MIN_BRANCH_UM)]
    path_lengths_um = path_lengths_um.tolist()  # convert back to list
else:
    print("Skipping branch length export: skan library not available")

# calculate summary statistics
total_paths = len(path_lengths_um)  # how many branches we found
mean_um = float(np.mean(path_lengths_um)) if path_lengths_um else 0.0  # average length
max_um  = float(np.max(path_lengths_um))  if path_lengths_um else 0.0   # longest branch
min_um  = float(np.min(path_lengths_um))  if path_lengths_um else 0.0   # shortest branch

In [53]:
# get the filename without extension for saving outputs
"""
taking just the base filename and removing the .tif part
using this as prefix for all the output files
"""
prefix = os.path.splitext(os.path.basename(input_tif))[0]

# make skeleton thicker for better visibility in preview image
"""
the skeleton is only 1 pixel wide which is hard to see
making it a bit thicker just for the preview image
not changing the actual analysis just making it easier to look at
"""
skel_preview = binary_dilation(skeleton, footprint=disk(PREVIEW_THICKEN_DISK))

# create and save the skeleton preview image
plt.imshow(skel_preview, cmap='gray')  # show in grayscale
title_text = "Skeleton Network\nJunctions: " + str(junctions) + " | Paths: " + str(total_paths) + " | Mean Length: " + str(round(mean_um, 2)) + " µm"
plt.title(title_text)
plt.axis('off')  # hide the axis numbers
plt.tight_layout()  # make it look neat

# save the image
skeleton_img_path = os.path.join(out_skeleton_img_dir, prefix + "_skeleton_network_detailed.png")
plt.savefig(skeleton_img_path, dpi=300)  # high quality
plt.close()  # close to free up memory

In [54]:
# save the path lengths to a csv file
"""
if we found any branches then save all their lengths
each row will be one branch with its length in micrometers
"""
if path_lengths_um:
    path_df = pd.DataFrame({"Path_Length_um": path_lengths_um})
    path_csv_name = prefix + "_skeleton_path_lengths.csv"
    path_df.to_csv(os.path.join(out_csv_dir, path_csv_name), index=False)

# create summary table with all the important results
"""
putting all the key numbers in one table
includes counts, measurements, and settings used
"""
summary_df = pd.DataFrame({
    "Metric": [
        "Total Junctions",
        "Total Path Segments", 
        "Mean Path Length (µm)",
        "Max Path Length (µm)",
        "Min Path Length (µm)",
        "µm per pixel (used)",
        "Dilation radius (px)",
        "Head kept for fit (%)",
        "Min branch used (µm)",
        "Connected components after dilation",
        "Crop used",
    ],
    "Value": [
        junctions,
        total_paths,
        round(mean_um, 3),
        round(max_um, 3), 
        round(min_um, 3),
        round(float(um_per_px), 6),
        int(radius),
        int(DEFAULT_HEAD_KEEP_PERCENT),
        float(DEFAULT_MIN_BRANCH_UM),
        int(n_comp),
        "central square: H=" + str(height_px) + ", W=" + str(width_px),
    ],
})

# save the summary table
summary_csv_name = prefix + "_skeleton_summary_metrics.csv"
summary_df.to_csv(os.path.join(out_csv_dir, summary_csv_name), index=False)

In [55]:
# try to fit a curve to see how branch lengths change
"""
taking all the branch lengths and sorting them
then seeing if they follow an exponential pattern
like how tree branches get shorter as you go out
"""
def exp_decay(x, a, b, c):
    """Exponential decay function: y = a * exp(-b * x) + c"""
    return a * np.exp(-b * x) + c

lengths = np.array(path_lengths_um, dtype=float)
lengths = lengths[np.isfinite(lengths) & (lengths > 0)]  # remove bad values

# need enough data points to make a curve
if lengths.size >= 4:
    # sort branches from biggest to smallest
    lengths_sorted = np.sort(lengths)[::-1]  # flip to get biggest first
    ranks = np.arange(len(lengths_sorted))  # just 0, 1, 2, 3...
    
    # maybe use only the longer branches
    if DEFAULT_HEAD_KEEP_PERCENT < 100:
        cutoff = np.percentile(lengths_sorted, 100 - DEFAULT_HEAD_KEEP_PERCENT)
        mask = lengths_sorted > cutoff
        x_data = ranks[mask]
        y_data = lengths_sorted[mask]
    else:
        x_data = ranks
        y_data = lengths_sorted
    
    # guess some starting values for the curve fitting
    a_guess = max(y_data) - min(y_data) if len(y_data) else lengths_sorted[0]
    c_guess = min(y_data) if len(y_data) else lengths_sorted[-1]
    half_way = max(1, int(0.10 * len(x_data)))
    b_guess = np.log(2) / half_way
    starting_guess = (a_guess, b_guess, c_guess)
    
    # try to fit the curve
    try:
        fitted_params, _ = curve_fit(
            exp_decay, x_data, y_data,
            p0=starting_guess,
            bounds=([0, 0, 0], [np.inf, np.inf, np.inf]),
            maxfev=20000
        )
        
        # make a reference line for comparison
        b_ref = np.log(2) / half_way
        a_ref = lengths_sorted[0]
        c_ref = lengths_sorted[-1]
        reference_line = exp_decay(ranks, a_ref, b_ref, c_ref)
        
        # draw the plot
        fig, ax = plt.subplots(figsize=(9, 6))
        ax.scatter(ranks, lengths_sorted, label='Data Points', color='blue')
        ax.plot(ranks, exp_decay(ranks, *fitted_params), 'r-', label='Fitted Curve')
        ax.plot(ranks, reference_line, 'g--', label='Reference')
        
        # add labels
        title_text = prefix + " - How Branch Lengths Change"
        ax.set_title(title_text, pad=16)
        ax.set_xlabel("Branch Number (sorted)", labelpad=14)
        ax.set_ylabel("Branch Length (micrometers)", labelpad=14)
        ax.grid(True, alpha=0.4)
        ax.legend(loc='center left', bbox_to_anchor=(1.02, 0.5))
        plt.tight_layout(rect=[0, 0, 0.80, 1])
        
        # save the picture
        plot_filename = prefix + "_decay_fit.png"
        decay_png = os.path.join(out_decay_dir, plot_filename)
        plt.savefig(decay_png, dpi=300, bbox_inches='tight')
        plt.close()
        
        print("Saved plot: " + decay_png)
        
    except Exception as e:
        print("Could not fit curve for " + prefix + ": " + str(e))
        
else:
    print(prefix + ": not enough branches to fit curve (only " + str(lengths.size) + ")")

Saved plot: c:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\SKELETON +EDF/EXPONENTIAL DECAY FITTING\SAMPLE TEST IMAGE X30000_decay_fit.png


In [56]:
# print summary of what we found
"""
showing all the main results so we can see if things worked right
"""
print("")
print("Done processing: " + input_tif)
print("Used radius: " + str(radius) + "px, scale: " + str(round(um_per_px, 6)) + " µm/px")
print("")
print("Results:")
print(" Junctions found: " + str(junctions))
print(" Path segments: " + str(total_paths))
print("")
print("Path lengths:")
print(" Average: " + str(round(mean_um, 2)) + " µm")
print(" Longest: " + str(round(max_um, 2)) + " µm") 
print(" Shortest: " + str(round(min_um, 2)) + " µm")
print("")
print("Output files:")
print(" Skeleton image: " + skeleton_img_path)
print(" CSV files saved in: " + out_csv_dir)


Done processing: c:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\CRYO-SEM DATA\CRYO-SEM X30000\SAMPLE TEST IMAGE X30000.tif
Used radius: 22px, scale: 0.003123 µm/px

Results:
 Junctions found: 682
 Path segments: 415

Path lengths:
 Average: 0.13 µm
 Longest: 0.79 µm
 Shortest: 0.0 µm

Output files:
 Skeleton image: c:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\SKELETON +EDF/SKELETON\SAMPLE TEST IMAGE X30000_skeleton_network_detailed.png
 CSV files saved in: c:\Users\walsh\Documents\GitHub\AGAROSE-HYDROGEL-TRENDS-USING-AI-ML\SKELETON +EDF/ SKELETON +EDF/ CSVS
