In [8]:
# Import statements

import os
import numpy as np
from scipy.spatial import cKDTree, distance_matrix
import glob
import seaborn as sns
import matplotlib.pyplot as plt
import re
from scipy.ndimage import distance_transform_edt
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
from scipy.spatial.distance import cdist
import pandas as pd

In [None]:
# Functions (robust SWC loader + verbose pairwise computation)

# Robust SWC loader: uses pandas to handle comments and variable spacing; returns empty array on failure
def load_swc_coordinates(file_path):
    """Load the x, y, z node coordinates from an SWC skeleton file.

    SWC rows are whitespace-separated ``n type x y z radius parent``; lines
    starting with ``#`` are treated as comments. Columns 2:5 (x, y, z) are
    returned.

    Parameters
    ----------
    file_path : str
        Path to the ``.swc`` file.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_nodes, 3)``. If the file cannot be read or
        parsed, or has fewer than 5 columns (so x, y, z cannot all be present),
        a warning is printed and an empty ``(0, 3)`` array is returned so that
        callers can skip the file.
    """
    try:
        table = pd.read_csv(file_path, sep=r'\s+', comment='#', header=None, engine='python')
        if table.shape[1] < 5:
            # The former np.loadtxt fallback re-read the same columns and
            # returned an (n, <3) slice, which broke or silently skewed the
            # 3-D distance computation downstream. Reject the file instead.
            raise ValueError(
                f'expected at least 5 columns (n, type, x, y, z), found {table.shape[1]}'
            )
        return table.iloc[:, 2:5].to_numpy(dtype=float)
    except Exception as e:
        print(f'WARNING: failed to load SWC {file_path}: {e}')
        return np.empty((0, 3))

def average_bidirectional_distance(A, B):
    """Symmetric mean nearest-neighbour distance between two point sets.

    For every point in ``A``, find the Euclidean distance to its closest point
    in ``B``, and vice versa. The two directional means are then averaged.

    Parameters
    ----------
    A, B : array_like
        Coordinates of shape ``(n_points, n_dims)``. Both must have the same
        number of columns (3 for SWC skeletons).

    Returns
    -------
    float
        ``(mean_A_to_B + mean_B_to_A) / 2``, or NaN when either input is empty
        or contains non-finite coordinates.

    Raises
    ------
    ValueError
        If a non-empty input is not 2-D, or the two inputs have a different
        number of columns.
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # An empty side makes the comparison meaningless; signal it with NaN.
    if A.size == 0 or B.size == 0:
        return float('nan')
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Both A and B should be 2D arrays representing the coordinates (n_points, 3).")
    if A.shape[1] != B.shape[1]:
        raise ValueError(f"A and B must have the same number of columns, got {A.shape[1]} and {B.shape[1]}.")
    # A dense distance matrix would yield NaN here; keep that contract explicit.
    if not (np.isfinite(A).all() and np.isfinite(B).all()):
        return float('nan')
    # KD-tree queries need O(n + m) memory. The previous dense cdist matrix
    # needed O(n * m), which can exhaust RAM for large skeletons.
    dist_a_to_b, _ = cKDTree(B).query(A)  # nearest B point for each A point
    dist_b_to_a, _ = cKDTree(A).query(B)  # nearest A point for each B point
    return (dist_a_to_b.mean() + dist_b_to_a.mean()) / 2

def compute_pairwise_distances(swc_folder):
    """Mean bidirectional distance for every unique pair of SWC skeletons in a folder.

    Every ``*.swc`` file directly inside ``swc_folder`` (sorted by path) is
    loaded once with ``load_swc_coordinates``. Files that fail to load, or that
    yield no coordinates, are reported and skipped.
    ``average_bidirectional_distance`` is then evaluated for every unordered
    pair (i < j) of the remaining skeletons.

    Parameters
    ----------
    swc_folder : str
        Directory to search (non-recursively) for ``.swc`` files.

    Returns
    -------
    list of float
        One value per valid pair, in sorted-file pair order. NaN results are
        dropped. The list is empty when fewer than two skeletons load.
    """
    swc_files = sorted(glob.glob(os.path.join(swc_folder, '*.swc')))
    print(f'INFO: compute_pairwise_distances - folder={swc_folder} files_found={len(swc_files)}')

    # Load each skeleton exactly once. The previous nested loop re-read file j
    # for every i (O(n_files**2) reads) and repeated SKIP messages. Trade-off:
    # all loaded skeletons are now held in memory at the same time.
    skeletons = []
    for path in swc_files:
        coords = load_swc_coordinates(path)
        if coords.size == 0:
            print(f'  SKIP: could not load or empty coords for {path}')
            continue
        skeletons.append(coords)

    pairwise_means = []
    for i, coords_i in enumerate(skeletons):
        for coords_j in skeletons[i + 1:]:
            val = average_bidirectional_distance(coords_i, coords_j)
            if not np.isnan(val):
                pairwise_means.append(val)
    print(f'INFO: computed {len(pairwise_means)} pairwise means for folder {swc_folder}')
    return pairwise_means


In [None]:
# Sanity check before the distance computation: report how many SWC skeletons
# each GABA folder contains and show up to five example paths per folder.
# (os and glob come from the notebook's import cell.)
gaba_folders = [
    os.path.join('data', 'skeletons', registration_dir, 'GABA', hemisphere)
    for registration_dir in ('skeletons_affine+rigid', 'skeletons_diffeomorphic')
    for hemisphere in ('left', 'right')
]

for folder in gaba_folders:
    swc_paths = sorted(glob.glob(os.path.join(folder, '*.swc')))
    print(f'DEBUG: folder={folder} count={len(swc_paths)}')
    if swc_paths:
        print('  sample->', swc_paths[:5])


In [None]:
## Data wrangling for GABA skeletons
# Compute pairwise skeleton distances for each (registration, hemisphere)
# folder, summarise them, and build tidy DataFrames with the columns
# pairwise_mean, set (Left/Right) and Registration.

left_GABA_affine_folder = 'data/skeletons/skeletons_affine+rigid/GABA/left/'
right_GABA_affine_folder = 'data/skeletons/skeletons_affine+rigid/GABA/right/'
left_GABA_diffeo_folder = 'data/skeletons/skeletons_diffeomorphic/GABA/left/'
right_GABA_diffeo_folder = 'data/skeletons/skeletons_diffeomorphic/GABA/right/'

left_GABA_affine_pairwise_means = compute_pairwise_distances(left_GABA_affine_folder)
right_GABA_affine_pairwise_means = compute_pairwise_distances(right_GABA_affine_folder)
left_GABA_diffeo_pairwise_means = compute_pairwise_distances(left_GABA_diffeo_folder)
right_GABA_diffeo_pairwise_means = compute_pairwise_distances(right_GABA_diffeo_folder)


def safe_stats(lst):
    """Return ``(mean, std)`` of ``lst``, or ``(nan, nan)`` when it is empty.

    ``np.std`` is called with its default ``ddof=0`` (population std). The
    emptiness test uses ``len()`` so NumPy arrays work as well as lists.
    """
    if len(lst) == 0:
        return np.nan, np.nan
    return np.mean(lst), np.std(lst)


mean_GABA_affine_left, std_GABA_affine_left = safe_stats(left_GABA_affine_pairwise_means)
mean_GABA_affine_right, std_GABA_affine_right = safe_stats(right_GABA_affine_pairwise_means)
mean_GABA_diffeo_left, std_GABA_diffeo_left = safe_stats(left_GABA_diffeo_pairwise_means)
mean_GABA_diffeo_right, std_GABA_diffeo_right = safe_stats(right_GABA_diffeo_pairwise_means)

# One frame per hemisphere ...
df_GABA_affine_left = pd.DataFrame({'pairwise_mean': left_GABA_affine_pairwise_means, 'set': 'Left'})
df_GABA_affine_right = pd.DataFrame({'pairwise_mean': right_GABA_affine_pairwise_means, 'set': 'Right'})
df_GABA_diffeo_left = pd.DataFrame({'pairwise_mean': left_GABA_diffeo_pairwise_means, 'set': 'Left'})
df_GABA_diffeo_right = pd.DataFrame({'pairwise_mean': right_GABA_diffeo_pairwise_means, 'set': 'Right'})

# ... then one per registration method, tagged without in-place mutation.
# A frame built from an empty list gets an object-dtype column; casting keeps
# pairwise_mean numeric for the groupby/median/std calls in the figure cells.
df_GABA_affine = (
    pd.concat([df_GABA_affine_left, df_GABA_affine_right], ignore_index=True)
    .astype({'pairwise_mean': float})
    .assign(Registration='affine+rigid')
)
df_GABA_diffeo = (
    pd.concat([df_GABA_diffeo_left, df_GABA_diffeo_right], ignore_index=True)
    .astype({'pairwise_mean': float})
    .assign(Registration='diffeomorphic')
)

# Combined frame for plotting. The former try/except fallback built a frame
# without a 'Registration' column, which only moved the failure to the
# plotting cell's groupby; any error now surfaces here, where it originates.
df_GABA = pd.concat([df_GABA_affine, df_GABA_diffeo], ignore_index=True)

print('DEBUG: sizes', len(left_GABA_affine_pairwise_means), len(right_GABA_affine_pairwise_means), len(left_GABA_diffeo_pairwise_means), len(right_GABA_diffeo_pairwise_means))


In [None]:
# Data wrangling for inotocin skeletons
# Same pipeline as the GABA cell: pairwise skeleton distances per
# (registration, hemisphere) folder, summary statistics, and tidy DataFrames
# with the columns pairwise_mean, set (Left/Right) and Registration.

left_inotocin_affine_folder = 'data/skeletons/skeletons_affine+rigid/inotocin/left/'
right_inotocin_affine_folder = 'data/skeletons/skeletons_affine+rigid/inotocin/right/'
left_inotocin_diffeo_folder = 'data/skeletons/skeletons_diffeomorphic/inotocin/left/'
right_inotocin_diffeo_folder = 'data/skeletons/skeletons_diffeomorphic/inotocin/right/'

left_inotocin_affine_pairwise_means = compute_pairwise_distances(left_inotocin_affine_folder)
right_inotocin_affine_pairwise_means = compute_pairwise_distances(right_inotocin_affine_folder)
left_inotocin_diffeo_pairwise_means = compute_pairwise_distances(left_inotocin_diffeo_folder)
right_inotocin_diffeo_pairwise_means = compute_pairwise_distances(right_inotocin_diffeo_folder)

# safe_stats (defined in the GABA wrangling cell) returns (nan, nan) for an
# empty folder instead of emitting NumPy's "Mean of empty slice" warning.
mean_inotocin_affine_left, std_inotocin_affine_left = safe_stats(left_inotocin_affine_pairwise_means)
mean_inotocin_affine_right, std_inotocin_affine_right = safe_stats(right_inotocin_affine_pairwise_means)
mean_inotocin_diffeo_left, std_inotocin_diffeo_left = safe_stats(left_inotocin_diffeo_pairwise_means)
mean_inotocin_diffeo_right, std_inotocin_diffeo_right = safe_stats(right_inotocin_diffeo_pairwise_means)

# One frame per hemisphere ...
df_inotocin_affine_left = pd.DataFrame({"pairwise_mean": left_inotocin_affine_pairwise_means, "set": "Left"})
df_inotocin_affine_right = pd.DataFrame({"pairwise_mean": right_inotocin_affine_pairwise_means, "set": "Right"})
df_inotocin_diffeo_left = pd.DataFrame({"pairwise_mean": left_inotocin_diffeo_pairwise_means, "set": "Left"})
df_inotocin_diffeo_right = pd.DataFrame({"pairwise_mean": right_inotocin_diffeo_pairwise_means, "set": "Right"})

# ... then one per registration method. A frame built from an empty list gets
# an object-dtype column; casting keeps pairwise_mean numeric downstream.
df_inotocin_affine = (
    pd.concat([df_inotocin_affine_left, df_inotocin_affine_right], ignore_index=True)
    .astype({"pairwise_mean": float})
    .assign(Registration="affine+rigid")
)
df_inotocin_diffeo = (
    pd.concat([df_inotocin_diffeo_left, df_inotocin_diffeo_right], ignore_index=True)
    .astype({"pairwise_mean": float})
    .assign(Registration="diffeomorphic")
)

# Combined frame, parallel to df_GABA.
df_inotocin = pd.concat([df_inotocin_affine, df_inotocin_diffeo], ignore_index=True)

In [None]:
## Generating Figure 3G -- GABA skeleton distance
# NOTE(review): this header says Figure 3G, but the files saved at the end of
# the cell are named Fig3E_*; confirm which panel label is correct.

# Single long-format frame (one row per skeleton pair) tagged with the
# registration method. .assign() works on copies, so the inputs are not mutated.
try:
    df_GABA = pd.concat(
        [df_GABA_affine.assign(Registration='affine+rigid'),
         df_GABA_diffeo.assign(Registration='diffeomorphic')],
        ignore_index=True,
    ).astype({'pairwise_mean': float})
except Exception as e:
    # Best effort: fall back to a df_GABA from an earlier cell, but say so
    # rather than silently plotting possibly stale data.
    print(f'WARNING: could not rebuild df_GABA ({e!r}); using existing df_GABA')

# Seeded generator for reproducible horizontal jitter.
rng = np.random.default_rng(321)
if 'x_jitter' not in df_GABA.columns:
    # One uniform draw per row, stored as plain floats. The former
    # dict-of-lambdas mapping stored 1-element arrays in an object column.
    # The number of draws is unchanged, so the generator stays at the same
    # stream position and the scatter jitter drawn below matches earlier
    # renderings of this figure.
    df_GABA['x_jitter'] = rng.uniform(-0.15, 0.15, size=len(df_GABA))

# Per-group summaries; pandas .std() is the sample std (ddof=1).
grouped = df_GABA.groupby('Registration')['pairwise_mean']
means = grouped.mean()
medians = grouped.median()
stds = grouped.std()

print('means:\n', means)
print('stds:\n', stds)

affine_vals = df_GABA.loc[df_GABA['Registration'] == 'affine+rigid', 'pairwise_mean']
diffeo_vals = df_GABA.loc[df_GABA['Registration'] == 'diffeomorphic', 'pairwise_mean']

# Welch's two-sample t-test: equal_var=False is passed explicitly, because
# scipy's default is Student's equal-variance test. NaNs are dropped first
# (ttest_ind propagates them by default), and each group needs >= 2 values
# for a variance estimate.
affine_clean = affine_vals.dropna()
diffeo_clean = diffeo_vals.dropna()
p_val = float('nan')
if len(affine_clean) >= 2 and len(diffeo_clean) >= 2:
    try:
        t_stat, p_val = ttest_ind(affine_clean, diffeo_clean, equal_var=False)
    except Exception as e:
        print('WARNING: t-test failed', e)
        p_val = float('nan')

# Plot: jittered points per group, a median bar, and a median ± std error bar.
fig, ax = plt.subplots(figsize=(6, 6))
colors = {'affine+rigid': 'lime', 'diffeomorphic': 'lime'}
order = ['affine+rigid', 'diffeomorphic']
for i, reg in enumerate(order):
    subset = df_GABA[df_GABA['Registration'] == reg]
    if subset.empty:
        print(f'No data for {reg} — skipping')
        continue
    jitter = rng.uniform(-0.15, 0.15, size=len(subset)) + i
    ax.scatter(
        jitter, subset['pairwise_mean'],
        facecolors=colors.get(reg, 'gray'), edgecolors='black', linewidth=1.5, alpha=0.8, s=80
    )
    # Series.get returns NaN when a group is missing from the summaries.
    med_val = medians.get(reg, np.nan)
    std_val = stds.get(reg, np.nan)
    if not np.isnan(med_val):
        ax.hlines(med_val, xmin=i - 0.2, xmax=i + 0.2, colors='black', linewidth=5)
    if not np.isnan(std_val):
        ax.errorbar(x=i, y=med_val, yerr=std_val, fmt='none', ecolor='black', linewidth=3, capsize=0)

# Significance bracket just above the highest point.
max_val = df_GABA['pairwise_mean'].max()
if np.isfinite(p_val):
    y_annotation = max_val + (0.05 * max_val)
    ax.plot([0, 1], [y_annotation] * 2, color='black', linewidth=1)
    ax.text(0.5, y_annotation + 0.02 * max_val,
            f"p = {p_val:.3g}" if p_val >= 0.001 else "p < 0.001",
            ha='center', va='bottom', fontsize=12)

ax.set_xticks([0, 1])
ax.set_xticklabels(['affine+rigid', 'diffeomorphic'])
ax.set_xlim(-0.5, 1.5)
# 35 is the intended y-range. Widen it only when points or the p-value bracket
# (drawn at ~1.05-1.07 x the maximum) would otherwise be clipped.
ax.set_ylim(0, max(35.0, 1.15 * max_val) if np.isfinite(max_val) else 35.0)
ax.set_ylabel("Mean Pairwise Distance (units)")

# Keep only the x and y axis spines.
sns.despine(ax=ax, top=True, right=True)
fig.tight_layout()

# Save into the figures directory.
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/Fig3E_GABAskeletondistance.png', format='png', dpi=600)
fig.savefig('figures/Fig3E_GABAskeletondistance.eps', format='eps', dpi=300)
plt.close(fig)
print('WROTE: figures/Fig3E_GABAskeletondistance.png and .eps')


In [None]:
# Plot inotocin (Fig3F): per-pair mean skeleton distance by registration method.

# Single long-format frame tagged with the registration method (copies via .assign).
try:
    df_inotocin = pd.concat(
        [df_inotocin_affine.assign(Registration='affine+rigid'),
         df_inotocin_diffeo.assign(Registration='diffeomorphic')],
        ignore_index=True,
    ).astype({'pairwise_mean': float})
except Exception as e:
    # Best effort: fall back to a df_inotocin from an earlier cell, but say so.
    print(f'WARNING: could not rebuild df_inotocin ({e!r}); using existing df_inotocin')

# Per-group summaries; pandas .std() is the sample std (ddof=1).
grouped_in = df_inotocin.groupby('Registration')['pairwise_mean']
means_in = grouped_in.mean()
medians_in = grouped_in.median()
stds_in = grouped_in.std()

# Welch's two-sample t-test (equal_var=False). NaNs are dropped first, since
# ttest_ind propagates them by default, and each group needs >= 2 values.
affine_vals_in = df_inotocin.loc[df_inotocin['Registration'] == 'affine+rigid', 'pairwise_mean'].dropna()
diffeo_vals_in = df_inotocin.loc[df_inotocin['Registration'] == 'diffeomorphic', 'pairwise_mean'].dropna()
p_val_in = float('nan')
if len(affine_vals_in) >= 2 and len(diffeo_vals_in) >= 2:
    try:
        t_stat, p_val_in = ttest_ind(affine_vals_in, diffeo_vals_in, equal_var=False)
    except Exception as e:
        print('WARNING t-test inotocin failed', e)
        p_val_in = float('nan')

# Plot: jittered points per group, a median bar, and a median ± std error bar.
fig, ax = plt.subplots(figsize=(6, 6))
colors = {'affine+rigid': 'cyan', 'diffeomorphic': 'cyan'}
order = ['affine+rigid', 'diffeomorphic']
# One seeded generator for the whole figure. Re-seeding inside the loop (as
# before) gave both groups an identical jitter pattern.
rng_in = np.random.default_rng(123)
for i, reg in enumerate(order):
    subset = df_inotocin[df_inotocin['Registration'] == reg]
    if subset.empty:
        print(f'No data for {reg} in inotocin — skipping')
        continue
    jitter = rng_in.uniform(-0.15, 0.15, size=len(subset)) + i
    ax.scatter(jitter, subset['pairwise_mean'], facecolors=colors.get(reg, 'gray'), edgecolors='black', s=60, alpha=0.9)
    # Series.get returns NaN when a group is missing from the summaries.
    med_val = medians_in.get(reg, np.nan)
    std_val = stds_in.get(reg, np.nan)
    if not np.isnan(med_val):
        ax.hlines(med_val, xmin=i - 0.2, xmax=i + 0.2, colors='black', linewidth=4)
    if not np.isnan(std_val):
        ax.errorbar(x=i, y=med_val, yerr=std_val, fmt='none', ecolor='black', capsize=0)

# Significance bracket just above the highest point.
max_val = df_inotocin['pairwise_mean'].max()
if np.isfinite(p_val_in):
    y_annotation = max_val + (0.05 * max_val)
    ax.plot([0, 1], [y_annotation] * 2, color='black', linewidth=1)
    ax.text(0.5, y_annotation + 0.02 * max_val,
            f"p = {p_val_in:.3g}" if p_val_in >= 0.001 else "p < 0.001",
            ha='center', va='bottom', fontsize=12)

ax.set_xticks([0, 1])
ax.set_xticklabels(['affine+rigid', 'diffeomorphic'])
ax.set_xlim(-0.5, 1.5)
# 120 is the intended y-range. Widen it only when points or the p-value
# bracket would otherwise be clipped.
ax.set_ylim(0, max(120.0, 1.15 * max_val) if np.isfinite(max_val) else 120.0)
ax.set_ylabel('Mean Pairwise Distance (units)')
sns.despine(ax=ax, top=True, right=True)
fig.tight_layout()

os.makedirs('figures', exist_ok=True)
fig.savefig('figures/Fig3F_inotocinskeletondistance.png', dpi=600)
fig.savefig('figures/Fig3F_inotocinskeletondistance.eps', dpi=300)
plt.close(fig)
print('WROTE: figures/Fig3F_inotocinskeletondistance.png and .eps')
