In [ ]:
import os
from shapely.geometry import Point, Polygon
from shapely.vectorized import contains
from tifffile import imread
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
import multiprocessing as mp
from scipy.ndimage import gaussian_filter
from skimage.measure import find_contours

# Global plotting theme: seaborn's "white" style (no grid) for every figure
# drawn later in this notebook.
sns.set(color_codes=True, style="white")


# Parameters


In [ ]:
# -----------------------------------------------------------------------------
# Analysis mode
# -----------------------------------------------------------------------------
# ROI_MODE selects how the regions of interest (ROIs) are obtained:
#   'image_automated' : smooth + threshold a reference image and trace contours
#   'manual_roi'      : read ROI coordinate files from the working directory
ROI_MODE = "image_automated"

# -----------------------------------------------------------------------------
# Spatial analysis parameters
# -----------------------------------------------------------------------------
nm_per_pxl = 23.4        # pixel size, nm per pixel
r_max_nm = 1120          # upper limit of the analysed distance range, nm
ringwidth_nm = 100       # radial width of each ring (distance bin), nm
dr_slidingrings_nm = 20  # offset between consecutive, overlapping rings, nm

# -----------------------------------------------------------------------------
# File paths
# -----------------------------------------------------------------------------
# Uncomment and adapt to run from the data folder:
# folder = "/Volumes/guttman/Guoming_Gao-Resnick/Data_BIF/DEFAULT_USER/20250424_ONIdemo_Guoming/Guoming_data/dSTORM/dSTORM_tif/Malat1"
# os.chdir(folder)

# Localization tables (CSV), one per channel
ch1_file = "dSTORM1_TX_nodiff_Malat_AF647_AF488-1-cropped-left-driftcorrected10k.csv"
ch2_file = "dSTORM1_TX_nodiff_Malat_AF647_AF488-1-cropped-right-driftcorrected10k.csv"

# Settings used only when ROI_MODE == 'image_automated'
fname_segmentation_reference = "dSTORM1_TX_nodiff_Malat_AF647_AF488-1_befbleach-composite-5x.tif"
sigma_smooth = 30           # sigma of the Gaussian filter applied before thresholding
threshold_normalized = 0.2  # cut-off applied to the min-max normalised image
area_threshold_pxl = 1e4    # polygons with an area at or below this are discarded

# Settings used only when ROI_MODE == 'manual_roi'
roi_file_pattern = "cell"   # substring that a ROI ".txt" file name must contain


In [ ]:
# === DISTANCE BIN SETUP ===
# Inner radii of the sliding rings (nm). np.arange excludes its stop value,
# so the largest inner radius is strictly below r_max_nm - ringwidth_nm.
bin_starts = np.arange(
    0,
    r_max_nm - ringwidth_nm,
    dr_slidingrings_nm,
)
# Outer radii (nm): every ring is ringwidth_nm wide.
bin_ends = ringwidth_nm + bin_starts
# x-values used when plotting: each ring is drawn at its inner radius.
bins = bin_starts

# === PRECOMPUTE RING AREAS ===
# Area of a full annulus, pi * (r_outer^2 - r_inner^2), in nm^2 ...
ring_areas_nm2 = np.pi * (np.square(bin_ends) - np.square(bin_starts))
# ... and the same areas expressed in squared pixels.
ring_areas_pxl2 = ring_areas_nm2 / nm_per_pxl**2


# Functions


In [ ]:
def create_cell_polygon(roi_file, nm_per_pxl):
    """Build a Shapely Polygon from a manually drawn ROI coordinate file.

    The file is read as a header-less, tab-separated table with one vertex
    per row. Every value is multiplied by 1000 and divided by ``nm_per_pxl``
    (presumably micrometres -> nanometres -> pixels; confirm against the
    tool that exported the ROI).

    Returns the Polygon, or ``None`` after printing a message if anything
    goes wrong while reading the file or building the polygon.
    """
    try:
        table = pd.read_csv(roi_file, sep="\t", header=None)
        # Scale the whole table at once, then emit one tuple per vertex.
        scaled = table.to_numpy() * 1000 / nm_per_pxl
        vertices_pxl = [tuple(vertex) for vertex in scaled]
        return Polygon(vertices_pxl)
    except Exception as e:
        # Best effort: a broken ROI file is reported and skipped by the caller.
        print(f"Error loading {roi_file}: {e}")
        return None

def corr_within_cell_polygon_truly_vectorized(df, cell_polygon):
    """Return the x and y coordinates of the rows of ``df`` inside ``cell_polygon``.

    ``df`` must provide "x" and "y" columns. Rows are first pre-filtered with
    the polygon's bounding box (inclusive on all sides); only the survivors
    go through the point-in-polygon test.

    Returns a tuple ``(x_values, y_values)`` of arrays; both are empty when
    no row falls inside the bounding box.
    """
    x_lo, y_lo, x_hi, y_hi = cell_polygon.bounds

    # Cheap rejection: keep only rows inside the bounding box.
    in_bbox = df["x"].between(x_lo, x_hi) & df["y"].between(y_lo, y_hi)
    candidates = df[in_bbox]
    if candidates.shape[0] == 0:
        return np.empty(0), np.empty(0)

    # Exact test on the remaining candidates.
    inside = contains(cell_polygon, candidates["x"], candidates["y"])
    kept = candidates[inside]
    return kept["x"].values, kept["y"].values

def process_reference_point_vectorized(i, x_ref, y_ref, x_interest, y_interest, cell_polygon, bin_starts, bin_ends, nm_per_pxl, ring_areas_pxl2):
    """Edge-corrected neighbour counts per distance ring for one reference point.

    Parameters
    ----------
    i : int
        Index of the reference point inside ``x_ref`` / ``y_ref``.
    x_ref, y_ref : array-like
        Reference-point coordinates, in pixels.
    x_interest, y_interest : numpy.ndarray
        Coordinates of the points being counted, in pixels.
    cell_polygon : shapely Polygon
        ROI outline, in pixels; used to measure how much of each ring lies
        inside the ROI.
    bin_starts, bin_ends : numpy.ndarray
        Inner and outer ring radii, in nm.
    nm_per_pxl : float
        Pixel size used to convert between nm and pixels.
    ring_areas_pxl2 : numpy.ndarray
        Area of each complete ring, in squared pixels.

    Returns
    -------
    numpy.ndarray
        One value per ring: the number of interest points whose distance to
        the reference point lies in ``[bin_start, bin_end]`` (both ends
        inclusive), multiplied by ``full ring area / ring area inside the
        ROI``. Rings that do not overlap the ROI at all contribute 0 rather
        than NaN/inf.
    """
    ref_x, ref_y = x_ref[i], y_ref[i]
    # Build the centre point once; the original rebuilt it twice per ring.
    center = Point(ref_x, ref_y)

    # Edge correction: area of each ring that actually lies inside the ROI.
    # The difference of two concentric buffers is already a Polygon, so no
    # extra Polygon(...) wrapping is needed.
    radii_pxl = zip(bin_starts / nm_per_pxl, bin_ends / nm_per_pxl)
    intersect_areas = np.array([
        cell_polygon.intersection(
            center.buffer(outer).difference(center.buffer(inner)),
            grid_size=0.1,
        ).area
        for inner, outer in radii_pxl
    ], dtype=float)

    # full area / inside area. A ring with zero overlap would give a division
    # by zero (inf), and 0 counts * inf = NaN would poison the summed PCF, so
    # such rings get a factor of 0 instead.
    full_areas = np.asarray(ring_areas_pxl2, dtype=float)
    edge_correction_factors = np.divide(
        full_areas,
        intersect_areas,
        out=np.zeros_like(full_areas),
        where=intersect_areas > 0,
    )

    # Distances from the reference point to every interest point, in nm.
    distances = np.sqrt((ref_x - x_interest) ** 2 + (ref_y - y_interest) ** 2) * nm_per_pxl

    # Histogram over the (overlapping) rings: rows are rings, columns are points.
    in_ring = (bin_starts[:, np.newaxis] <= distances) & (bin_ends[:, np.newaxis] >= distances)
    hist_per_point = np.sum(in_ring, axis=1)

    return hist_per_point * edge_correction_factors

def worker_pcf(cell_data, bin_starts, bin_ends, nm_per_pxl, ring_areas_pxl2, ring_areas_nm2):
    """Compute the pair-correlation function of one ROI (multiprocessing worker).

    ``cell_data`` is the 7-tuple ``(x_ref, y_ref, x_interest, y_interest,
    cell_polygon, rho_interest_per_nm2, cell_id)``.

    Returns ``(cell_id, pcf)``, where ``pcf`` is the sum of the per-reference-
    point edge-corrected histograms divided by
    ``n_ref * ring_areas_nm2 * rho_interest_per_nm2``. ``pcf`` is ``None``
    when either point set is empty.
    """
    (x_ref, y_ref, x_interest, y_interest,
     cell_polygon, rho_interest_per_nm2, cell_id) = cell_data

    n_ref = len(x_ref)
    if n_ref == 0 or len(x_interest) == 0:
        return cell_id, None

    # One edge-corrected histogram per reference point of this ROI.
    per_point_hists = [
        process_reference_point_vectorized(
            ref_idx, x_ref, y_ref, x_interest, y_interest,
            cell_polygon, bin_starts, bin_ends, nm_per_pxl, ring_areas_pxl2,
        )
        for ref_idx in range(n_ref)
    ]

    summed_hist = np.sum(per_point_hists, axis=0)
    expected_counts = n_ref * ring_areas_nm2 * rho_interest_per_nm2
    return cell_id, summed_hist / expected_counts


# Processing


In [ ]:
print("Loading localization data...")
# The tables are read once and shared by all ROIs.
# Channel 1 (points of interest): the header row is always kept, and each
# data row is kept with probability ~0.1. The draw is unseeded, so the
# subsample differs from run to run.
df_interest_raw = pd.read_csv(
    ch1_file,
    skiprows=lambda row_idx: row_idx > 0 and random.random() > 0.1,
)
# Add pixel-unit coordinates next to the nanometre columns.
df_interest_raw = df_interest_raw.assign(
    x=df_interest_raw["x [nm]"] / nm_per_pxl,
    y=df_interest_raw["y [nm]"] / nm_per_pxl,
)

# Channel 2 (reference points) is read in full.
df_ref_raw = pd.read_csv(ch2_file)
df_ref_raw = df_ref_raw.assign(
    x=df_ref_raw["x [nm]"] / nm_per_pxl,
    y=df_ref_raw["y [nm]"] / nm_per_pxl,
)
print(f"Loaded {len(df_interest_raw)} interest points and {len(df_ref_raw)} reference points.")


In [ ]:
# Build the list of ROI polygons (pixel units) and a matching list of ids.
polygons = []
cell_ids = []

if ROI_MODE == 'image_automated':
    print(f"Segmenting {fname_segmentation_reference}...")
    img = imread(fname_segmentation_reference)
    # A 3-D stack is collapsed by averaging over its first axis.
    img_ave = np.mean(img, axis=0) if img.ndim == 3 else img
    img_smoothed = gaussian_filter(img_ave, sigma=sigma_smooth)
    # Min-max normalise to [0, 1] so the threshold is intensity-independent.
    img_norm = (img_smoothed - img_smoothed.min()) / (img_smoothed.max() - img_smoothed.min())
    img_thresh = img_norm > threshold_normalized

    # Trace the outlines of the binary mask (level 0.5 separates False/True).
    contours = find_contours(img_thresh, 0.5)
    for idx, contour in enumerate(contours):
        if len(contour) >= 3:
            # Swap the two contour columns so vertices are (x, y), not (y, x).
            poly = Polygon(contour[:, [1, 0]]) # y,x -> x,y
            # Keep only valid polygons larger than the area threshold.
            if poly.is_valid and poly.area > area_threshold_pxl:
                polygons.append(poly)
                cell_ids.append(f"auto_{idx}")
    print(f"Found {len(polygons)} polygons after filtering.")

elif ROI_MODE == 'manual_roi':
    print(f"Loading manual ROIs matching {roi_file_pattern}...")
    # ROI files are ".txt" files in the working directory whose name
    # contains roi_file_pattern.
    roi_files = sorted([f for f in os.listdir(".") if roi_file_pattern in f and f.endswith(".txt")])
    for f in roi_files:
        poly = create_cell_polygon(f, nm_per_pxl)
        if poly:
            polygons.append(poly)
            cell_ids.append(os.path.basename(f))
    print(f"Loaded {len(polygons)} polygons.")

else:
    # A mistyped mode used to fall through silently and produce zero ROIs.
    raise ValueError(
        f"Unknown ROI_MODE {ROI_MODE!r}; expected 'image_automated' or 'manual_roi'."
    )


In [ ]:
# Prepare task list: one tuple per ROI that contains points from both channels.
tasks = []
for poly, cid in zip(polygons, cell_ids):
    x_r, y_r = corr_within_cell_polygon_truly_vectorized(df_ref_raw, poly)
    x_i, y_i = corr_within_cell_polygon_truly_vectorized(df_interest_raw, poly)

    if len(x_r) > 0 and len(x_i) > 0:
        # Density of interest points inside the ROI, in points per nm^2.
        area_nm2 = poly.area * (nm_per_pxl**2)
        rho_i = len(x_i) / area_nm2
        tasks.append((x_r, y_r, x_i, y_i, poly, rho_i, cid))

print(f"Processing {len(tasks)} valid ROIs in parallel...")

results = {}
# One worker per task, leaving one core free, but never fewer than one:
# on a single-core machine cpu_count() - 1 is 0, and ProcessPoolExecutor
# rejects max_workers <= 0 with a ValueError.
n_proc = max(1, min(len(tasks), mp.cpu_count() - 1))

if tasks:
    with ProcessPoolExecutor(max_workers=n_proc) as executor:
        futures = {executor.submit(worker_pcf, t, bin_starts, bin_ends, nm_per_pxl, ring_areas_pxl2, ring_areas_nm2): t for t in tasks}
        # Collect results in completion order; ROIs without a PCF are skipped.
        for future in tqdm(as_completed(futures), total=len(futures)):
            cid, pcf = future.result()
            if pcf is not None:
                results[cid] = pcf

print(f"Completed processing. Results for {len(results)} ROIs.")


# Visualization


In [ ]:
# Plot every per-ROI curve in grey and their mean in red; skipped when there
# are no results.
if results:
    plt.figure(figsize=(8, 6))

    # Individual ROI curves, drawn faintly in the background.
    all_pcfs = list(results.values())
    for pcf in all_pcfs:
        plt.plot(bins, pcf, color='gray', alpha=0.3)

    # Mean across ROIs, bin by bin.
    mean_pcf = np.mean(all_pcfs, axis=0)
    plt.plot(bins, mean_pcf, color='red', lw=2, label='Mean PCF')

    # Dashed reference line at G(r) = 1.
    plt.axhline(1, ls='--', color='black', alpha=0.5)

    plt.xlabel('Distance (nm)')
    plt.ylabel('G(r)')
    plt.title(f'PCF Analysis (n={len(results)})')
    plt.legend()
    plt.show()
