In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from aicsimageio import AICSImage
import pyclesperanto as cle
from IPython.display import display
from skimage.measure import regionprops
from skimage.measure import regionprops_table
from scipy.spatial import cKDTree
import time, os, sys
from urllib.parse import urlparse
import skimage.io
import matplotlib.pyplot as plt
from cellpose import models, core, utils, io
from skimage import measure, morphology, feature
from scipy.spatial import distance
from aicsimageio import AICSImage
from cellpose import plot
import skimage.morphology
import pandas as pd
import skimage.measure
from skimage.color import label2rgb
import tifffile
from scipy.ndimage import zoom
import napari
from tifffile import imread

# Initialise Cellpose logging via its own helper before any model work.
io.logger_setup() 

# Stop the notebook early when Cellpose reports no usable GPU: the model
# below is constructed with gpu=True.
if core.use_gpu()==False:
  raise ImportError("No GPU access, change your runtime")

# Cellpose model instance used by the segmentation cell further down.
model = models.CellposeModel(gpu=True)

In [None]:

def compute_crystal_to_nucleus_distances(crystal_vol, nuclei_vol):
    """Find, for each labelled crystal, the nucleus whose centroid is closest.

    Centroids come from ``regionprops`` and are read as three components
    (indices 0, 1, 2), stored under z/y/x column names, so both label
    volumes must be 3D.

    NOTE(review): this function is defined a second time, with the same
    body, further down in this same cell; after the cell runs the later
    definition is the one bound to this name.

    Parameters
    ----------
    crystal_vol : integer label volume of crystals.
    nuclei_vol : integer label volume of nuclei.

    Returns
    -------
    pandas.DataFrame
        One row per crystal with its id and centroid, the id and centroid
        of the nearest nucleus, and the Euclidean centroid distance
        (in voxel index units).
    """


    # One record per crystal label: id plus centroid components.
    crystal_props = regionprops(crystal_vol)
    crystals = []
    for prop in crystal_props:
        crystals.append({
            'crystal_id': prop.label,
            'crystal_centroid_z': prop.centroid[0],
            'crystal_centroid_y': prop.centroid[1],
            'crystal_centroid_x': prop.centroid[2]
        })
    df_crystals = pd.DataFrame(crystals)
    

    # Same table for the nuclei.
    nucleus_props = regionprops(nuclei_vol)
    nuclei = []
    for prop in nucleus_props:
        nuclei.append({
            'nucleus_id': prop.label,
            'nucleus_centroid_z': prop.centroid[0],
            'nucleus_centroid_y': prop.centroid[1],
            'nucleus_centroid_x': prop.centroid[2]
        })
    df_nuclei = pd.DataFrame(nuclei)
    
    
    # NOTE(review): if either volume has no labels the DataFrame has no
    # columns and the selections below raise KeyError.
    crystal_centroids = df_crystals[['crystal_centroid_z', 'crystal_centroid_y', 'crystal_centroid_x']].to_numpy()
    nucleus_centroids = df_nuclei[['nucleus_centroid_z', 'nucleus_centroid_y', 'nucleus_centroid_x']].to_numpy()
    

    nearest_ids = []
    nearest_coords = []
    distances = []
    
   
    # Brute-force search: distance from each crystal centroid to every
    # nucleus centroid, keeping the first minimum.
    for cz, cy, cx in crystal_centroids:
        diffs = nucleus_centroids - np.array([cz, cy, cx])
        dists = np.linalg.norm(diffs, axis=1)
        idx_min = np.argmin(dists)
        # df_nuclei has the default 0..n-1 index, so the positional argmin
        # result is also a valid .loc label.
        nearest_ids.append(df_nuclei.loc[idx_min, 'nucleus_id'])
        nearest_coords.append(nucleus_centroids[idx_min])
        distances.append(dists[idx_min])
 
    df_crystals['nearest_nucleus_id'] = nearest_ids
    df_crystals[['nucleus_centroid_z', 'nucleus_centroid_y', 'nucleus_centroid_x']] = np.vstack(nearest_coords)
    df_crystals['distance'] = distances
    
    return df_crystals

def remove_far_crystals(crystal_vol, df_distances, max_dist):
    """Return a copy of ``crystal_vol`` without the labels that are too far away.

    A positive label survives only if ``df_distances`` lists it under
    ``'crystal_id'`` with a ``'distance'`` of at most ``max_dist``. Labels
    absent from the table are removed as well. The input is not modified.
    """
    within_range = df_distances['distance'] <= max_dist
    keep_ids = set(df_distances.loc[within_range, 'crystal_id'])

    result = crystal_vol.copy()
    is_kept = np.isin(result, list(keep_ids))
    is_foreground = result > 0
    # Only foreground voxels are touched; background stays as it is.
    result[is_foreground & ~is_kept] = 0
    return result

def remove_small_labels(label_vol, min_voxels):
    """Drop labels with fewer than ``min_voxels`` voxels and renumber the rest.

    Parameters
    ----------
    label_vol : integer label array of any shape; values <= 0 are background.
    min_voxels : minimum voxel count (inclusive) a label needs to be kept.

    Returns
    -------
    numpy.ndarray (int32, same shape as ``label_vol``)
        Surviving labels renumbered 1..K in ascending order of their old id;
        everything else is 0.
    """
    label_vol = np.asarray(label_vol)
    labels, inverse, counts = np.unique(
        label_vol, return_inverse=True, return_counts=True
    )

    # Background is identified by value, not by position in `labels`: a
    # volume with no 0 voxels must not lose its smallest real label.
    keep = (labels > 0) & (counts >= min_voxels)

    # Lookup table old-label-position -> new id; one pass over the volume
    # instead of one full-volume comparison per kept label.
    lut = np.zeros(labels.shape[0], dtype=np.int32)
    lut[keep] = np.arange(1, np.count_nonzero(keep) + 1, dtype=np.int32)

    # reshape() makes this independent of the shape np.unique gives `inverse`.
    return lut[inverse].reshape(label_vol.shape)

def remove_large_labels(label_vol, max_voxels):
    """Drop labels with more than ``max_voxels`` voxels and renumber the rest.

    Parameters
    ----------
    label_vol : integer label array of any shape; values <= 0 are background.
    max_voxels : maximum voxel count (inclusive) a label may have to be kept.

    Returns
    -------
    numpy.ndarray (int32, same shape as ``label_vol``)
        Surviving labels renumbered 1..K in ascending order of their old id;
        everything else is 0.
    """
    label_vol = np.asarray(label_vol)
    labels, inverse, counts = np.unique(
        label_vol, return_inverse=True, return_counts=True
    )

    # Background is identified by value, not by position in `labels`: a
    # volume with no 0 voxels must not lose its smallest real label.
    keep = (labels > 0) & (counts <= max_voxels)

    # Lookup table old-label-position -> new id; one pass over the volume
    # instead of one full-volume comparison per kept label.
    lut = np.zeros(labels.shape[0], dtype=np.int32)
    lut[keep] = np.arange(1, np.count_nonzero(keep) + 1, dtype=np.int32)

    # reshape() makes this independent of the shape np.unique gives `inverse`.
    return lut[inverse].reshape(label_vol.shape)
    
def assign_nuclei_to_neurons(neuron_vol, nuclei_vol):
    """Assign each nucleus to the neuron label covering most of its voxels.

    Returns
    -------
    (pandas.DataFrame, numpy.ndarray)
        The table has one row per positive nucleus label with its size, the
        chosen neuron id (0 when the nucleus lies entirely on background),
        the number of shared voxels and that number as a fraction of the
        nucleus size. The int32 volume has the shape of ``nuclei_vol`` and
        holds, at every voxel of an assigned nucleus, the neuron id.
    """
    assignment_vol = np.zeros_like(nuclei_vol, dtype=np.int32)
    rows = []

    nucleus_ids = np.unique(nuclei_vol)
    for nuc_id in nucleus_ids[nucleus_ids > 0]:
        inside = nuclei_vol == nuc_id
        n_voxels = np.count_nonzero(inside)

        # Neuron labels found under this nucleus, background excluded.
        under = neuron_vol[inside]
        candidates, hits = np.unique(under[under > 0], return_counts=True)

        if hits.size:
            # argmax returns the first maximum, so ties go to the lowest id.
            best = np.argmax(hits)
            neuron_id = int(candidates[best])
            shared = int(hits[best])
        else:
            neuron_id = 0
            shared = 0

        rows.append({
            'nucleus_id': nuc_id,
            'nucleus_size_voxels': n_voxels,
            'assigned_neuron_id': neuron_id,
            'overlap_voxels': shared,
            'overlap_fraction': shared / n_voxels
        })

        if neuron_id > 0:
            assignment_vol[inside] = neuron_id

    return pd.DataFrame.from_records(rows), assignment_vol

def filter_neurons_by_nuclei_assignments(neuron_vol, df_assign):
    """Return a copy of ``neuron_vol`` keeping only neurons that own a nucleus.

    A neuron label is kept when it appears as a non-zero value in
    ``df_assign['assigned_neuron_id']``; all other positive labels become 0.
    """
    owners = set(df_assign['assigned_neuron_id'].values) - {0}

    result = neuron_vol.copy()
    orphaned = (result > 0) & ~np.isin(result, list(owners))
    result[orphaned] = 0
    return result

def reindex_neurons_and_nuclei(neuron_vol, nuclei_vol, df_assign):
    """Give every assigned neuron, and the nuclei assigned to it, one shared new id.

    Neurons named in ``df_assign['assigned_neuron_id']`` (non-zero) are
    renumbered 1..K in ascending order of their old id. Each nucleus that
    was assigned to one of them receives that neuron's new id. Unassigned
    neurons and nuclei are dropped.

    Returns
    -------
    (new_neuron_vol, new_nuclei_vol, mapping_df)
        Two int32 volumes and a table with columns ``old_neuron_id`` and
        ``new_id``.
    """
    assigned = df_assign['assigned_neuron_id'].unique()
    sorted_ids = np.sort(np.unique(assigned[assigned > 0]))
    new_id_of = dict(zip(sorted_ids, range(1, len(sorted_ids) + 1)))

    new_neuron_vol = np.zeros_like(neuron_vol, dtype=np.int32)
    for old_neuron, new_label in new_id_of.items():
        new_neuron_vol[neuron_vol == old_neuron] = new_label

    new_nuclei_vol = np.zeros_like(nuclei_vol, dtype=np.int32)
    pairs = zip(df_assign['nucleus_id'], df_assign['assigned_neuron_id'])
    for nucleus_id, neuron_id in pairs:
        # The lookup only holds positive neuron ids, so a miss also covers
        # the "assigned to background" case.
        new_label = new_id_of.get(int(neuron_id))
        if new_label is None:
            continue
        new_nuclei_vol[nuclei_vol == int(nucleus_id)] = new_label

    mapping_df = pd.DataFrame({
        'old_neuron_id': sorted_ids,
        'new_id': np.arange(1, len(sorted_ids) + 1)
    })

    return new_neuron_vol, new_nuclei_vol, mapping_df

def generate_corrected_neuron_mask(neuron_vol, nuclei_vol):
    """Overlay ``nuclei_vol`` labels onto a copy of ``neuron_vol``.

    Wherever ``nuclei_vol`` is positive its label replaces the neuron label;
    elsewhere the neuron label is kept. The result has the dtype of
    ``neuron_vol`` and the inputs are left untouched.
    """
    nucleus_voxels = nuclei_vol > 0
    merged = neuron_vol.copy()
    merged[nucleus_voxels] = nuclei_vol[nucleus_voxels]
    return merged

def filter_labels_by_voxel_counts(nuclei_vol, neuron_vol, df_counts, 
                                  nuc_min=0, nuc_max=np.inf, 
                                  neu_min=0, neu_max=np.inf):
    """Keep labels whose nucleus and neuron voxel counts both lie in range.

    ``df_counts`` must provide ``label``, ``nucleus_voxels`` and
    ``neuron_voxels`` columns; all four bounds are inclusive.

    Returns
    -------
    (filtered_nuclei, filtered_neurons, reindexed_nuclei, reindexed_neurons, mapping_df)
        The ``filtered_*`` volumes keep the original label values, the
        int32 ``reindexed_*`` volumes renumber the kept labels 1..K in
        ascending order, and ``mapping_df`` lists ``old_label`` /
        ``new_label`` pairs.
    """
    in_range = (
        df_counts['nucleus_voxels'].between(nuc_min, nuc_max)
        & df_counts['neuron_voxels'].between(neu_min, neu_max)
    )
    kept_labels = df_counts.loc[in_range, 'label'].to_numpy()

    # Same label values as the input, rejected labels zeroed.
    nuclei_kept = np.isin(nuclei_vol, kept_labels)
    neurons_kept = np.isin(neuron_vol, kept_labels)
    filtered_nuclei = np.where(nuclei_kept, nuclei_vol, 0)
    filtered_neurons = np.where(neurons_kept, neuron_vol, 0)

    unique_kept = np.unique(kept_labels)
    unique_kept = unique_kept[unique_kept > 0]
    new_ids = np.arange(1, len(unique_kept) + 1)

    # Consecutive renumbering applied identically to both volumes.
    reindexed_nuclei = np.zeros_like(filtered_nuclei, dtype=np.int32)
    reindexed_neurons = np.zeros_like(filtered_neurons, dtype=np.int32)
    for old_label, new_label in zip(unique_kept, new_ids):
        reindexed_nuclei[filtered_nuclei == old_label] = new_label
        reindexed_neurons[filtered_neurons == old_label] = new_label

    mapping_df = pd.DataFrame({
        'old_label': unique_kept,
        'new_label': new_ids
    })

    return filtered_nuclei, filtered_neurons, reindexed_nuclei, reindexed_neurons, mapping_df

def compute_crystal_neuron_metrics(crystal_vol, neuron_vol):
    """Measure, per crystal, the nearest neuron and the overlap with it.

    Parameters
    ----------
    crystal_vol : 3D integer label volume of crystals.
    neuron_vol : 3D integer label volume of neurons, same shape.

    Returns
    -------
    pandas.DataFrame
        One row per crystal: id, centroid (z, y, x), voxel count
        (``crystal_size``), id and centroid of the neuron with the closest
        centroid, the centroid distance, the number of crystal voxels that
        carry that neuron's label and that number as a fraction of the
        crystal size. An empty frame with the same columns is returned when
        either volume has no labels.
    """
    columns = [
        'crystal_id',
        'crystal_centroid_z', 'crystal_centroid_y', 'crystal_centroid_x',
        'crystal_size',
        'nearest_neuron_id',
        'neuron_centroid_z', 'neuron_centroid_y', 'neuron_centroid_x',
        'distance',
        'overlap_voxels',
        'overlap_fraction'
    ]

    crystal_props = regionprops(crystal_vol)
    neuron_props = regionprops(neuron_vol)
    if not crystal_props or not neuron_props:
        return pd.DataFrame(columns=columns)

    df_crystals = pd.DataFrame([
        {
            'crystal_id': prop.label,
            'crystal_centroid_z': prop.centroid[0],
            'crystal_centroid_y': prop.centroid[1],
            'crystal_centroid_x': prop.centroid[2],
            'crystal_size': prop.area
        }
        for prop in crystal_props
    ])
    df_neurons = pd.DataFrame([
        {
            'neuron_id': prop.label,
            'neuron_centroid_z': prop.centroid[0],
            'neuron_centroid_y': prop.centroid[1],
            'neuron_centroid_x': prop.centroid[2]
        }
        for prop in neuron_props
    ])

    crystal_coords = df_crystals[['crystal_centroid_z', 'crystal_centroid_y', 'crystal_centroid_x']].to_numpy()
    neuron_coords = df_neurons[['neuron_centroid_z', 'neuron_centroid_y', 'neuron_centroid_x']].to_numpy()

    # `distance` is scipy.spatial.distance from the notebook's import cell;
    # a bare `cdist` was never imported and raised NameError here.
    dists = distance.cdist(crystal_coords, neuron_coords)
    nearest_idx = np.argmin(dists, axis=1)

    # argmin yields positions, so select rows positionally.
    nearest = df_neurons.iloc[nearest_idx]
    df_crystals['nearest_neuron_id'] = nearest['neuron_id'].to_numpy()
    df_crystals['neuron_centroid_z'] = nearest['neuron_centroid_z'].to_numpy()
    df_crystals['neuron_centroid_y'] = nearest['neuron_centroid_y'].to_numpy()
    df_crystals['neuron_centroid_x'] = nearest['neuron_centroid_x'].to_numpy()
    df_crystals['distance'] = dists[np.arange(dists.shape[0]), nearest_idx]

    # Overlap of each crystal with its nearest neuron. The region's own voxel
    # coordinates are used instead of a full-volume `crystal_vol == id` mask
    # per crystal; df_crystals rows are in the same order as crystal_props.
    overlap_voxels = []
    overlap_fraction = []
    for prop, nid in zip(crystal_props, df_crystals['nearest_neuron_id']):
        total_vox = int(prop.area)
        if nid > 0:
            voxel_index = tuple(prop.coords.T)
            overlap_count = int(np.count_nonzero(neuron_vol[voxel_index] == nid))
        else:
            overlap_count = 0

        overlap_voxels.append(overlap_count)
        overlap_fraction.append(overlap_count / total_vox if total_vox > 0 else 0.0)

    df_crystals['overlap_voxels'] = overlap_voxels
    df_crystals['overlap_fraction'] = overlap_fraction

    return df_crystals

def remove_crystals_by_overlap(crystal_vol, df_metrics, threshold=0.4):
    """Return a copy of ``crystal_vol`` without weakly overlapping crystals.

    Every crystal listed in ``df_metrics`` with an ``'overlap_fraction'``
    strictly below ``threshold`` is set to 0. Labels not listed in
    ``df_metrics`` are left as they are.
    """
    too_low = df_metrics['overlap_fraction'] < threshold
    rejected_ids = df_metrics.loc[too_low, 'crystal_id'].values

    result = crystal_vol.copy()
    result[np.isin(result, rejected_ids)] = 0
    return result

def compute_crystal_to_nucleus_distances(crystal_vol, nuclei_vol):
    """Find, for each labelled crystal, the nucleus whose centroid is closest.

    Parameters
    ----------
    crystal_vol : 3D integer label volume of the objects to measure from.
    nuclei_vol : 3D integer label volume of the objects to measure to.

    Returns
    -------
    pandas.DataFrame
        One row per label in ``crystal_vol`` with columns ``crystal_id``,
        ``crystal_centroid_{z,y,x}``, ``nearest_nucleus_id``,
        ``nucleus_centroid_{z,y,x}`` and ``distance`` (Euclidean distance
        between centroids, in voxel index units). An empty frame with these
        columns is returned when either volume has no labels.
    """
    columns = [
        'crystal_id',
        'crystal_centroid_z', 'crystal_centroid_y', 'crystal_centroid_x',
        'nearest_nucleus_id',
        'nucleus_centroid_z', 'nucleus_centroid_y', 'nucleus_centroid_x',
        'distance'
    ]

    crystals = [
        {
            'crystal_id': prop.label,
            'crystal_centroid_z': prop.centroid[0],
            'crystal_centroid_y': prop.centroid[1],
            'crystal_centroid_x': prop.centroid[2]
        }
        for prop in regionprops(crystal_vol)
    ]
    nuclei = [
        {
            'nucleus_id': prop.label,
            'nucleus_centroid_z': prop.centroid[0],
            'nucleus_centroid_y': prop.centroid[1],
            'nucleus_centroid_x': prop.centroid[2]
        }
        for prop in regionprops(nuclei_vol)
    ]

    # Without this guard an empty volume produced a column-less DataFrame
    # (KeyError) or an argmin over an empty array (ValueError).
    if not crystals or not nuclei:
        return pd.DataFrame(columns=columns)

    df_crystals = pd.DataFrame(crystals)
    df_nuclei = pd.DataFrame(nuclei)

    crystal_centroids = df_crystals[['crystal_centroid_z', 'crystal_centroid_y', 'crystal_centroid_x']].to_numpy()
    nucleus_centroids = df_nuclei[['nucleus_centroid_z', 'nucleus_centroid_y', 'nucleus_centroid_x']].to_numpy()

    # Nearest-neighbour query over all crystals at once, replacing the
    # per-crystal Python loop. For exactly equidistant nuclei the tree may
    # pick a different one than a first-minimum argmin would.
    distances, nearest = cKDTree(nucleus_centroids).query(crystal_centroids, k=1)

    df_crystals['nearest_nucleus_id'] = df_nuclei['nucleus_id'].to_numpy()[nearest]
    df_crystals[['nucleus_centroid_z', 'nucleus_centroid_y', 'nucleus_centroid_x']] = nucleus_centroids[nearest]
    df_crystals['distance'] = distances

    return df_crystals

def map_crystals_to_neurons(crystal_vol, df_crystal_metrics):
    """Relabel each crystal with the id of its nearest neuron.

    The crystal-to-neuron lookup is read from the ``'crystal_id'`` and
    ``'nearest_neuron_id'`` columns of ``df_crystal_metrics``. Crystals that
    are missing from the table, or mapped to a non-positive id, stay 0 in
    the int32 result.
    """
    neuron_of = dict(zip(df_crystal_metrics['crystal_id'],
                         df_crystal_metrics['nearest_neuron_id']))

    relabelled = np.zeros_like(crystal_vol, dtype=np.int32)
    crystal_ids = np.unique(crystal_vol)
    for crystal_id in crystal_ids[crystal_ids != 0]:
        neuron_id = neuron_of.get(crystal_id, 0)
        if neuron_id > 0:
            relabelled[crystal_vol == crystal_id] = neuron_id

    return relabelled



In [None]:

# Placeholder path: replace with the actual image file before running.
image_path = "czi file"
img = AICSImage(image_path)
# Request the data with axes ordered channel, Z, Y, X.
image = img.get_image_data("CZYX")  

# Channel aliases. The channel order (0 = NeuN, 1 = DAPI, 2 = crystal) is
# assumed by the whole notebook -- TODO confirm against the acquisition.
NeuN = image[0]
DAPI = image[1]
Crystal = image[2]

# NOTE(review): every factor is 1, so this resampling does not change the
# shape, and `downsampled` is not referenced by the later cells shown (they
# index `image` directly). Confirm whether downsampling was intended.
scale_factors = (1.0, 1.0, 1, 1)

downsampled = zoom(image, zoom=scale_factors, order=1)
downsampled.shape

In [None]:
# Segment channel 0 (neurons) and channel 1 (nuclei) with the same Cellpose
# model and identical settings: do_3D=False with stitch_threshold=0.5.
# NOTE(review): `anisotropy` is passed although do_3D is False -- confirm
# that Cellpose uses it in this mode.
masks_Neuron, flows_Neuron, _ = model.eval(image[0], z_axis=0,
                                                  batch_size=8,
                                                  do_3D=False, stitch_threshold=0.5,
                                                  anisotropy= 31/7/4)
masks_Nuclei, flows_Nuclei, _ = model.eval(image[1], z_axis=0,
                                                  batch_size=8,
                                                  do_3D=False, stitch_threshold=0.5,
                                                  anisotropy= 31/7/4)

# Persist the label volumes so the segmentation need not be re-run.
# NOTE(review): the uint16 cast cannot represent label ids above 65535 --
# verify the label count stays below that.
tifffile.imwrite('mask_neurons.tif', masks_Neuron.astype(np.uint16))
tifffile.imwrite('mask_nuclei.tif', masks_Nuclei.astype(np.uint16))

In [None]:

# Reload the label volumes written by the segmentation cell. The neuron file
# name must match the one used by tifffile.imwrite there ('mask_neurons.tif');
# the previous 'mask_neuron.tif' pointed at a file this notebook never writes.
mask_neurons = imread('mask_neurons.tif')
mask_DAPI = imread('mask_nuclei.tif')

# Crystal channel: suppress everything below 10 % of the maximum intensity,
# keeping the original intensities above that threshold.
crystalraw = image[2]
max_pixel_value = np.max(crystalraw)
threshold_value = 0.10 * max_pixel_value
filtered_crystalimg = np.where(crystalraw >= threshold_value, crystalraw, 0)

# Label crystals on the GPU with pyclesperanto's Voronoi-Otsu labeling, then
# bring the label image back to host memory as a NumPy array.
gpu_crystal = cle.push(filtered_crystalimg)
labeled_crystal = cle.voronoi_otsu_labeling(gpu_crystal, spot_sigma=3.0, outline_sigma=2.0)

labeled_crystal = cle.pull(labeled_crystal)

In [None]:

# Voxel count of every crystal label, used to choose the size thresholds in
# the next cell.
props = regionprops(labeled_crystal)


label_ids = [prop.label for prop in props]
areas = [prop.area for prop in props]


df_props = pd.DataFrame({
    'label': label_ids,
    'size (voxels)': areas
}).sort_values('size (voxels)').reset_index(drop=True)

# Plot histogram of all label sizes
# NOTE(review): the author suspected that the voxel counts shown here do not
# match the ones used by the filtering functions; cross-check with the napari
# assistant. (This memo used to be part of the figure title.)
plt.figure(figsize=(6, 4))
plt.hist(df_props['size (voxels)'], bins=50)
plt.yscale('log')
plt.xlabel('Label Size (voxels)')
plt.ylabel('Number of labels')
plt.title('Distribution of Crystal Label Sizes')
plt.show()

In [None]:
# Size filters on the crystal labels (thresholds in voxels).
# NOTE(review): both calls start from `labeled_crystal`, so the large-label
# filter is not applied to the result of the small-label filter, and
# `crystals_removesmall` is not referenced by the later cells shown. Confirm
# whether the two filters were meant to be chained.
crystals_removesmall = remove_small_labels(labeled_crystal, 3000)
crystals_removelarge = remove_large_labels(labeled_crystal, 7000)

# Keep the thresholded crystal intensities only inside the surviving labels.
mask = crystals_removelarge > 0                  
filtered_image_cpu =filtered_crystalimg.copy()
filtered_image_cpu[~mask] = 0                     

In [None]:


# Voxel count of every nucleus label in the DAPI mask, used to choose the
# size thresholds applied in the following cells.
props = regionprops(mask_DAPI)


label_ids = [prop.label for prop in props]
areas = [prop.area for prop in props]


df_props = pd.DataFrame({
    'label': label_ids,
    'size (voxels)': areas
}).sort_values('size (voxels)').reset_index(drop=True)

# Plot histogram of all label sizes
# NOTE(review): the author suspected that the voxel counts shown here do not
# match the ones used by the filtering functions; cross-check with the napari
# assistant. (This memo used to be part of the figure title.)
plt.figure(figsize=(6, 4))
plt.hist(df_props['size (voxels)'], bins=50)
plt.yscale('log')
plt.xlabel('Label Size (voxels)')
plt.ylabel('Number of labels')
plt.title('Distribution of Nucleus Label Sizes')
plt.show()

In [None]:
# Drop neuron labels below 8000 voxels. `mask_neurons` is rebound to the
# filtered (and renumbered) volume, so later cells see this version.
mask_neurons = remove_small_labels(mask_neurons, 8000)

# The crystal/nucleus helper is reused with nuclei in the "crystal" role and
# neurons in the "nucleus" role: in this df_distances, 'crystal_id' holds DAPI
# labels and 'distance' is the gap between a nucleus centroid and the nearest
# neuron centroid.
df_distances = compute_crystal_to_nucleus_distances(mask_DAPI, mask_neurons)

plt.figure(figsize=(6, 4))
plt.hist(df_distances['distance'], bins=50)
plt.yscale('log')
plt.xlabel('Distance (voxels)')
plt.ylabel('Number of labels')
plt.title('Distribution of Distance')
plt.show()

In [None]:

# Maximum nucleus-to-neuron centroid distance for a nucleus to be kept.
max_dist = 50.0
# Crystal labels of at least 1000 voxels; used by the crystal cells below.
filtered_crystal = remove_small_labels(labeled_crystal, 1000)
# df_distances was computed with nuclei in the "crystal" role, so this call
# removes nuclei that are farther than max_dist from every neuron centroid.
filtered_neuron_nuclei = remove_far_crystals(mask_DAPI, df_distances, max_dist)
filtered_neuron_nuclei = remove_small_labels(filtered_neuron_nuclei, 10000)
df_assign, assigned_nuclei = assign_nuclei_to_neurons(mask_neurons, filtered_neuron_nuclei)

# NOTE(review): `neurons_with_nuclei` is not referenced by the later cells
# shown; the reindexing below starts from `mask_neurons` again.
neurons_with_nuclei = filter_neurons_by_nuclei_assignments(mask_neurons, df_assign)

# After reindexing, a neuron and the nuclei assigned to it share one label.
new_neurons, new_nuclei, mapping_df = reindex_neurons_and_nuclei(mask_neurons, filtered_neuron_nuclei, df_assign)
neurons = new_neurons
nuclei = new_nuclei

labels = np.unique(neurons)
labels = labels[labels > 0] 

# Per shared label: neuron size, nucleus size, and the share of nucleus voxels
# that lie inside the neuron carrying the same label.
records = []
for lab in labels:
    neuron_size = int(np.count_nonzero(neurons == lab))
    nucleus_size = int(np.count_nonzero(nuclei == lab))
    overlap_size = int(np.count_nonzero((neurons == lab) & (nuclei == lab)))
    overlap_fraction = overlap_size / nucleus_size if nucleus_size > 0 else 0.0
    records.append({
        'label': lab,
        'neuron_size': neuron_size,
        'nucleus_size': nucleus_size,
        'overlap_size': overlap_size,
        'overlap_fraction': overlap_fraction
    })

df_overlap = pd.DataFrame.from_records(records)

plt.figure(figsize=(6, 4))
plt.scatter(df_overlap['nucleus_size'], df_overlap['overlap_fraction'], alpha=0.7)
plt.xlabel('Nucleus Size (voxels)')
plt.ylabel('Overlap Fraction')
plt.title('Overlap Fraction vs. Nucleus Size')
plt.ylim(0, 1.05)
plt.grid(True)
plt.show()

In [None]:

# Keep only labels whose nucleus lies at least 40 % inside its neuron.
valid_labels = df_overlap.loc[df_overlap['overlap_fraction'] >= 0.40, 'label'].values


filtered_neurons = new_neurons.copy()
filtered_nuclei = new_nuclei.copy()

# Zero every positive label that is not in valid_labels, in both volumes.
mask_remove_neurons = ~np.isin(filtered_neurons, valid_labels)
filtered_neurons[mask_remove_neurons & (filtered_neurons > 0)] = 0

mask_remove_nuclei = ~np.isin(filtered_nuclei, valid_labels)
filtered_nuclei[mask_remove_nuclei & (filtered_nuclei > 0)] = 0


remaining_labels = np.unique(filtered_neurons)
remaining_labels = remaining_labels[remaining_labels > 0]


# Renumber the surviving labels 1..K (ascending); the same mapping is applied
# to neurons and nuclei so they keep sharing ids.
mapping = {old: new for new, old in enumerate(np.sort(remaining_labels), start=1)}


reindexed_neurons = np.zeros_like(filtered_neurons, dtype=np.int32)
for old, new in mapping.items():
    reindexed_neurons[filtered_neurons == old] = new


reindexed_nuclei = np.zeros_like(filtered_nuclei, dtype=np.int32)
for old, new in mapping.items():
    reindexed_nuclei[filtered_nuclei == old] = new


# Rebinds `mapping_df` from the previous cell to this second renumbering.
mapping_df = pd.DataFrame({
    'old_label': np.sort(remaining_labels),
    'new_label': np.arange(1, len(remaining_labels) + 1)
})


# Neuron mask with the nucleus voxels written on top, so each cell label
# covers its nucleus as well.
corrected_neuron_vol = generate_corrected_neuron_mask(reindexed_neurons, reindexed_nuclei)
filtered_nuclei = reindexed_nuclei


labels = np.unique(filtered_nuclei)
labels = labels[labels > 0]

# Voxel counts per label; these feed the size filter in the next cell.
records = []
for lab in labels:
    nucleus_count = int(np.count_nonzero(filtered_nuclei == lab))
    neuron_count = int(np.count_nonzero(corrected_neuron_vol == lab))
    records.append({
        'label': lab,
        'nucleus_voxels': nucleus_count,
        'neuron_voxels': neuron_count
    })


df_counts = pd.DataFrame.from_records(records)



plt.figure(figsize=(6, 6))
plt.scatter(df_counts['nucleus_voxels'], df_counts['neuron_voxels'], alpha=0.7)
plt.xlabel('Nucleus Voxels (filtered_nuclei)')
plt.ylabel('Neuron Voxels (corrected_neuron)')
plt.title('Nucleus vs Neuron Voxel Counts per Label')
plt.grid(True)
plt.show()

In [None]:

# Inclusive voxel-count windows for nuclei and for whole cells.
nucleus_min, nucleus_max = 10000, 50000   
neuron_min, neuron_max = 10000, 80000   

# filtered_*2 keep the incoming label values; reindexed_*2 renumber the kept
# labels consecutively (mapping_df2 records old -> new).
filtered_nuclei2, filtered_neurons2, reindexed_nuclei2, reindexed_neurons2, mapping_df2 = \
    filter_labels_by_voxel_counts(
        filtered_nuclei, 
        corrected_neuron_vol, 
        df_counts, 
        nuc_min=nucleus_min, 
        nuc_max=nucleus_max, 
        neu_min=neuron_min, 
        neu_max=neuron_max
    )


# Recount voxels per new label to check the filter visually.
records2 = []
for old, new in zip(mapping_df2['old_label'], mapping_df2['new_label']):
    nvox = np.count_nonzero(reindexed_nuclei2 == new)
    mvox = np.count_nonzero(reindexed_neurons2 == new)
    records2.append({'new_label': new, 'nucleus_voxels': nvox, 'neuron_voxels': mvox})

df_new_counts = pd.DataFrame(records2)
plt.figure(figsize=(6,6))
plt.scatter(df_new_counts['nucleus_voxels'], df_new_counts['neuron_voxels'], alpha=0.7)
plt.xlabel('Nucleus Voxels (reindexed)')
plt.ylabel('Neuron Voxels (reindexed)')
plt.title('Filtered & Reindexed: Nucleus vs. Neuron Voxels')
plt.grid(True)
plt.show()


# `df_distances` is rebound here: it now holds crystal-to-nucleus distances
# (earlier it held nucleus-to-neuron distances).
df_distances = compute_crystal_to_nucleus_distances(filtered_crystal, filtered_nuclei2)

plt.figure(figsize=(6, 4))
plt.hist(df_distances['distance'], bins=50)
plt.yscale('log')
plt.xlabel('Distance (voxels)')
plt.ylabel('Number of labels')
plt.title('Distribution of Distance')
plt.show()

In [None]:

# Keep crystals whose centroid is within max_dist of a retained nucleus
# centroid, then measure each one against its nearest neuron.
max_dist = 50.0
filtered_crystals_close = remove_far_crystals(filtered_crystal, df_distances, max_dist)

df_crystal_metrics = compute_crystal_neuron_metrics(filtered_crystals_close, reindexed_neurons2)



# Distance vs. overlap per crystal; used to pick the overlap threshold applied
# in the next cell. The y label and title were missing from this figure.
plt.figure(figsize=(6, 4))
plt.scatter(df_crystal_metrics['distance'], df_crystal_metrics['overlap_fraction'], alpha=0.7)
plt.xlabel('Distance to Nearest Neuron (voxels)')
plt.ylabel('Overlap Fraction')
plt.title('Crystal Overlap Fraction vs. Distance to Nearest Neuron')

plt.show()

In [None]:

# Minimum fraction of a crystal that must lie inside its nearest neuron.
# NOTE(review): `threshold_value` is also the name of the intensity threshold
# set in the crystal-labeling cell; this assignment overwrites it.
threshold_value = 0.05
filtered_crystals_strict = remove_crystals_by_overlap(filtered_crystals_close, df_crystal_metrics, threshold_value)
# Relabel each surviving crystal with the id of its nearest neuron.
mapped_crystals = map_crystals_to_neurons(filtered_crystals_strict, df_crystal_metrics)

# Final cell mask: neuron labels, with nucleus labels and then the mapped
# crystal labels written on top. Rebinds `corrected_neuron_vol`.
corrected_neuron_vol = generate_corrected_neuron_mask(reindexed_neurons2, reindexed_nuclei2)
corrected_neuron_vol = generate_corrected_neuron_mask(corrected_neuron_vol, mapped_crystals)

In [None]:
# Visual check in napari: raw channels, the final assigned label volumes, and
# the unfiltered ("ALL ...") masks for comparison.
viewer = napari.Viewer()

viewer.add_image(image[0], name='neurons', colormap='green')
viewer.add_image(image[1], name='DAPI', colormap='blue')
# 'raw' shows the thresholded crystal intensities restricted to the labels in
# crystals_removelarge (built in the size-filter cell).
viewer.add_image(filtered_image_cpu, name='raw', colormap='red')
viewer.add_labels(reindexed_nuclei2, name='assigned DAPI')
viewer.add_labels(corrected_neuron_vol, name='assigned neurons')
viewer.add_labels(mapped_crystals, name='assigned crystals')
viewer.add_labels(mask_DAPI, name='ALL DAPI')
viewer.add_labels(mask_neurons, name='ALL neurons')
viewer.add_labels(crystals_removelarge, name='ALL crystals')


napari.run()

In [None]:

# Channel order of the exported stack: 0 = nuclei, 1 = cells (neurons),
# 2 = crystals. extract_label_table* below reads the channels in this order.
label_volumes = [
    reindexed_nuclei2,       
    corrected_neuron_vol,    
    mapped_crystals          
]

# New leading axis -> array laid out as (C, Z, Y, X).
labels_czyx = np.stack(label_volumes, axis=0)

tifffile.imwrite(
    'assigned_nuclei-neuron-crystal.tif',
    labels_czyx,
    photometric='minisblack'
)


In [None]:
import numpy as np
import pandas as pd
import tifffile
from skimage.measure import marching_cubes, mesh_surface_area


def extract_label_table(tif_path):
    """Tabulate per-label voxel counts from a (C, Z, Y, X) label TIFF.

    Channel 0 is read as nuclei, channel 1 as cells and channel 2 as
    crystals. Rows are produced for every non-zero label present in the
    cell channel, in ascending label order.

    Parameters
    ----------
    tif_path : path of a TIFF readable by ``tifffile.imread``.

    Returns
    -------
    pandas.DataFrame
        Columns ``label_id``, ``nuc_voxels``, ``cell_voxels``,
        ``crystal_voxels`` and ``crystal_present`` (1 when the crystal
        channel contains the label, else 0).

    Raises
    ------
    ValueError
        If the array is not 4D or has fewer than three channels.
    """
    labels = tifffile.imread(tif_path)
    if labels.ndim != 4 or labels.shape[0] < 3:
        raise ValueError("Expected shape (C>=3, Z, Y, X), got %s" % (labels.shape,))
    nuc, cell, crystal = labels[0], labels[1], labels[2]

    def _voxel_counts(volume):
        # One pass per channel instead of one full-volume comparison per label.
        values, counts = np.unique(volume, return_counts=True)
        return dict(zip(values.tolist(), counts.tolist()))

    nuc_counts = _voxel_counts(nuc)
    cell_counts = _voxel_counts(cell)
    crystal_counts = _voxel_counts(crystal)

    rows = []
    # np.unique returns sorted values, so the dict iterates in label order.
    for lab, n1 in cell_counts.items():
        if lab == 0:
            continue
        n0 = nuc_counts.get(lab, 0)
        n2 = crystal_counts.get(lab, 0)
        rows.append({
            'label_id':         int(lab),
            'nuc_voxels':       n0,
            'cell_voxels':      n1,
            'crystal_voxels':   n2,
            'crystal_present':  1 if n2>0 else 0,
        })

    return pd.DataFrame(rows)


def compute_sphericity(binary_mask, spacing=(1.0,1.0,1.0)):
    """Sphericity of a 3D binary object: pi^(1/3) * (6 V)^(2/3) / A.

    ``V`` is the voxel count times the voxel volume and ``A`` is the area of
    a marching-cubes surface of the mask.

    Parameters
    ----------
    binary_mask : 3D array; non-zero / True voxels belong to the object.
    spacing : voxel size along each axis, passed to ``marching_cubes``.

    Returns
    -------
    float
        The sphericity, or NaN for an empty mask or a zero-area surface.
    """
    mask = np.asarray(binary_mask).astype(bool)
    if not mask.any():
        return np.nan

    # Crop to the object's bounding box: marching cubes then runs on a small
    # sub-volume instead of the whole image.
    box = []
    for axis in range(mask.ndim):
        other_axes = tuple(a for a in range(mask.ndim) if a != axis)
        hits = np.flatnonzero(mask.any(axis=other_axes))
        box.append(slice(hits[0], hits[-1] + 1))

    # One background voxel on every side closes the surface. Without it, an
    # object touching the array border gets an open mesh, an underestimated
    # area and a sphericity above 1.
    cropped = np.pad(mask[tuple(box)], 1, mode='constant', constant_values=False)

    verts, faces, _, _ = marching_cubes(cropped.astype(np.float32), level=0.5, spacing=spacing)
    area   = mesh_surface_area(verts, faces)
    if area <= 0:
        return np.nan
    volume = mask.sum() * np.prod(spacing)
    return (np.pi**(1/3) * (6*volume)**(2/3)) / area

def extract_label_table_with_sphericity(tif_path, spacing=(1.0,1.0,1.0)):
    """Per-label voxel counts plus nucleus and cell sphericity from a label TIFF.

    The file must hold a (C>=3, Z, Y, X) array whose first three channels are
    nuclei, cells and crystals. One row is produced for each non-zero label of
    the cell channel; ``spacing`` is forwarded to ``compute_sphericity``.

    Raises
    ------
    ValueError
        If the array is not 4D or has fewer than three channels.
    """
    stack = tifffile.imread(tif_path)
    if stack.ndim != 4 or stack.shape[0] < 3:
        raise ValueError(f"Expected shape (C>=3,Z,Y,X), got {stack.shape}")
    nuc, cell, crystal = stack[0], stack[1], stack[2]

    def _describe(lab):
        # Boolean masks of this label in each channel.
        in_nuc = nuc == lab
        in_cell = cell == lab
        crystal_count = int((crystal == lab).sum())
        return {
            'label_id':        int(lab),
            'nuc_voxels':      int(in_nuc.sum()),
            'cell_voxels':     int(in_cell.sum()),
            'crystal_voxels':  crystal_count,
            'crystal_present': int(crystal_count>0),
            'nuc_sphericity':  compute_sphericity(in_nuc,  spacing),
            'cell_sphericity': compute_sphericity(in_cell, spacing),
        }

    cell_ids = np.unique(cell)
    return pd.DataFrame([_describe(lab) for lab in cell_ids[cell_ids != 0]])

# Per-cell measurement table used by all statistics and plots below.
# spacing is the voxel size per axis -- presumably microns in (Z, Y, X) order;
# verify against the image metadata.
df = extract_label_table_with_sphericity('assigned_nuclei-neuron-crystal.tif',
                                          spacing=(0.30, 0.26, 0.26))

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ttest_ind
from matplotlib.ticker import MultipleLocator

def p_to_stars(p):
    """Translate a p-value into the usual significance marker.

    Returns '***' below 0.001, '**' below 0.01, '*' below 0.05 and 'ns'
    otherwise (which includes NaN, since every comparison with it is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return 'ns'

def plot_grouped_box_strip(
    df,
    group_col,
    y_col,
    group_names=None,
    filter_y_max=None,
    y_lim=None,
    n_major_ticks=5,
    figsize=(7, 6)
):
    """Box plot with overlaid points for two groups, annotated with a t-test.

    Parameters
    ----------
    df : table holding ``group_col`` and ``y_col``.
    group_col : column that splits the rows into exactly two groups.
    y_col : numeric column to plot and test.
    group_names : optional mapping from group value to display label.
    filter_y_max : when given, rows with ``y_col`` above it are dropped
        before testing and plotting.
    y_lim : optional (low, high) y-axis limits; falls back to
        (0, filter_y_max) and then to the automatic limits.
    n_major_ticks : number of major tick intervals across the y range.
    figsize : figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
        The p-value of the equal-variance two-sample t-test is printed.

    Raises
    ------
    ValueError
        If ``group_col`` does not contain exactly two distinct values after
        filtering.
    """
    data = df.copy()
    if filter_y_max is not None:
        data = data[data[y_col] <= filter_y_max]
    data = data.dropna(subset=[group_col, y_col])

    groups = sorted(data[group_col].unique())
    if len(groups) != 2:
        raise ValueError(f"Expected exactly two groups in '{group_col}', got {groups}")

    # Display label per group: the caller's name when provided, else str().
    labels = []
    for g in groups:
        if group_names and g in group_names:
            labels.append(group_names[g])
        else:
            labels.append(str(g))
    data['_grp_str'] = data[group_col].map(dict(zip(groups, labels)))

    # Student's t-test (equal variances) between the two groups.
    first = data.loc[data[group_col] == groups[0], y_col]
    second = data.loc[data[group_col] == groups[1], y_col]
    _, p = ttest_ind(first, second, equal_var=True)
    star = p_to_stars(p)

    sns.set_style("white")
    sns.set_context("talk")
    fig, ax = plt.subplots(figsize=figsize)

    # Hollow boxes with whiskers spanning the full data range.
    black_line = {'color': 'black', 'linewidth': 1.5}
    sns.boxplot(
        x='_grp_str', y=y_col, data=data, order=labels,
        showcaps=True, whis=[0,100], showfliers=False,
        boxprops={'facecolor':'none','edgecolor':'black','linewidth':1.5},
        medianprops={'color':'black','linewidth':2},
        whiskerprops=dict(black_line),
        capprops=dict(black_line),
        flierprops={'marker':'o','markerfacecolor':'none',
                    'markeredgecolor':'black','markersize':4},
        ax=ax
    )

    # Individual observations on top of the boxes.
    sns.stripplot(
        x='_grp_str', y=y_col, data=data, order=labels,
        jitter=0.2, size=5, alpha=0.7, color='black',
        edgecolor='none', ax=ax
    )

    # Significance bracket between x positions 0 and 1, starting at the data
    # maximum and rising by 5 % of the data span.
    y_low, y_high = data[y_col].min(), data[y_col].max()
    span = y_high - y_low
    bar_top = y_high + 0.05 * span
    bracket_segments = (
        ([0, 0], [y_high, bar_top]),
        ([0, 1], [bar_top, bar_top]),
        ([1, 1], [bar_top, y_high]),
    )
    for xs, ys in bracket_segments:
        ax.plot(xs, ys, 'k-', lw=1.2)
    ax.text((0 + 1) / 2, bar_top + 0.01 * span,
            star, ha='center', va='bottom', fontsize=14, color='black')

    # y range: explicit limits win, then (0, filter_y_max), then automatic.
    if y_lim is not None:
        y0, y1 = y_lim
    elif filter_y_max is not None:
        y0, y1 = 0, filter_y_max
    else:
        y0, y1 = ax.get_ylim()
    ax.set_ylim(y0, y1)

    major_step = (y1 - y0) / n_major_ticks
    ax.yaxis.set_major_locator(MultipleLocator(major_step))
    ax.yaxis.set_minor_locator(MultipleLocator(major_step / 2))
    ax.minorticks_on()

    # Full black frame around the axes.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.2)
        spine.set_color('black')

    ax.set_xlabel(group_col, fontsize=16, color='black')
    ax.set_ylabel(y_col, fontsize=16, color='black')
    ax.set_title(f"{y_col} by {group_col}", fontsize=18, color='black')

    ax.tick_params(axis='both', which='both',
                   direction='in', length=6, width=1.2,
                   bottom=True, left=True,
                   colors='black')

    plt.tight_layout()
    print(p)
    return fig


In [None]:
# Nucleus sphericity, cells without vs. with a crystal. Rows with sphericity
# above 1.0 are excluded from both the test and the plot (filter_y_max).
fig = plot_grouped_box_strip(
    df,
    group_col='crystal_present',
    y_col='nuc_sphericity',
    group_names={0:'No Crystal',1:'With Crystal'},
    filter_y_max=1.0,
    y_lim=(0.0, 1.2),
    n_major_ticks=6
)

plt.show()

In [None]:
# Cell sphericity, cells without vs. with a crystal. Rows with sphericity
# above 1.0 are excluded from both the test and the plot (filter_y_max).
fig = plot_grouped_box_strip(
    df,
    group_col='crystal_present',
    y_col='cell_sphericity',
    group_names={0:'No Crystal',1:'With Crystal'},
    filter_y_max=1.0,
    y_lim=(0.0, 1.2),
    n_major_ticks=6
)

plt.show()


In [None]:
# Cell size in voxels, cells without vs. with a crystal. Cells above
# 100000 voxels are excluded from both the test and the plot (filter_y_max).
fig = plot_grouped_box_strip(
    df,
    group_col='crystal_present',
    y_col='cell_voxels',
    group_names={0: 'No Crystal', 1: 'With Crystal'},
    filter_y_max=100000,
    y_lim=(10000, 100000),
    n_major_ticks=9
)

plt.show()

In [None]:
# Nucleus size in voxels, cells without vs. with a crystal. Nuclei above
# 50000 voxels are excluded from both the test and the plot (filter_y_max).
fig = plot_grouped_box_strip(
    df,
    group_col='crystal_present',
    y_col='nuc_voxels',
    group_names={0: 'No Crystal', 1: 'With Crystal'},
    filter_y_max=50000,
    y_lim=(0, 60000),
    n_major_ticks=6
)

plt.show()

In [None]:
# Cells that contain a crystal. .copy() makes this an independent frame, so
# adding the 'bin' column below does not write into a slice of df (which
# triggered pandas' SettingWithCopyWarning).
crystals = df[df['crystal_present'] == 1].copy()


# Three equal-width bins over the crystal size range.
crystals['bin'] = pd.cut(crystals['crystal_voxels'], bins=3)


# Mean and standard error of the cell size per bin. observed=False keeps
# empty bins in the result and states the categorical-groupby behaviour
# explicitly.
grouped = (
    crystals
    .groupby('bin', observed=False)['cell_voxels']
    .agg(['mean', 'sem'])
    .reset_index()
)


plt.figure(figsize=(6,4))


# Bin means are drawn at the bin midpoints.
plt.errorbar(
    x=[interval.mid for interval in grouped['bin']],
    y=grouped['mean'],
    yerr=grouped['sem'],
    fmt='o-',
    color='firebrick',
    capsize=4
)


plt.scatter(
    crystals['crystal_voxels'],
    crystals['cell_voxels'],
    alpha=0.3
)
plt.xlim(0, 20000)
plt.tick_params(direction='in', length=6, width=1.2, bottom=True, left=True, colors='black')

# The x label previously read "Crystal Size)" with an unbalanced parenthesis.
plt.xlabel("Crystal Size (voxels)")
plt.ylabel("Cell Size")
plt.title("")
plt.tight_layout()
plt.ylim(0, 100000)

plt.show()


In [None]:

# `crystals` comes from the previous cell (cells containing a crystal).
# assign() returns a new frame instead of writing a column into a possibly
# sliced one; the bins are three equal-width intervals over crystal size.
crystals = crystals.assign(bin=pd.cut(crystals['crystal_voxels'], bins=3))


# Mean and standard error of nucleus sphericity per bin. observed=False keeps
# empty bins in the result and states the categorical-groupby behaviour
# explicitly.
grouped = (
    crystals
    .groupby('bin', observed=False)['nuc_sphericity']
    .agg(['mean', 'sem'])
    .reset_index()
)


plt.figure(figsize=(6,4))


# Bin means are drawn at the bin midpoints.
plt.errorbar(
    x=[interval.mid for interval in grouped['bin']],
    y=grouped['mean'],
    yerr=grouped['sem'],
    fmt='o-',
    color='firebrick',
    capsize=4
)


plt.scatter(
    crystals['crystal_voxels'],
    crystals['nuc_sphericity'],
    alpha=0.3
)

plt.xlabel("Crystal Size (voxels)")
plt.ylabel("Nucleus Sphericity")
plt.title("")
plt.ylim(0.4, 1.1)
plt.xlim(0, 20000)
# A second tick_params call repeating a subset of these settings was removed.
plt.tick_params(direction='in', length=6, width=1.2, bottom=True, left=True, colors='black')
plt.tight_layout()

plt.show()