## imports and load data

In [1]:
import matplotlib.pyplot as plt
import os
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp, norm, combine_pvalues, wasserstein_distance_nd
import time
from exseq_kit import ProjectObject

# parameters
# NUM_BINS: passed as num_bins to project.analyze_cell_distribution_along_axis and used as the
#           default bin_edges of plot_cumulative_distributions / permutation_test.
NUM_BINS = 20
# NUM_PERMUTATIONS: default number of label shuffles performed by permutation_test.
NUM_PERMUTATIONS = 100
# label_size: passed as x_label_size / y_label_size to project.analyze_cell_type_population.
label_size = 15

ModuleNotFoundError: No module named 'exseq_kit'

In [None]:
# Build a project from the unfiltered organoid folder (explicit puncta file name,
# auto_calc switched off) and report the number of cells found in every sample.
project_test = ProjectObject(
    main_folder='C://Users/Moshe/OneDrive/coding/SpatialGenomicsLab/SegmentedData/Organoids/organoids_not_filtered/',
    puncta_file_format='FOV1FOV_1.output.csv.withoutRF.csv.withCells.csv.enlarge.csv',
    auto_calc=False,
    xy_scaling=1/3.3,
    z_scaling=1/3.3,
)
for samp_name, samp_data in project_test.samples.items():
    print(samp_name)
    print("Num of cells", len(samp_data.cells))


In [None]:
# Step 1: Create the project from the filtered organoid folder
project = ProjectObject(main_folder='C://Users/Moshe/OneDrive/coding/SpatialGenomicsLab/SegmentedData/Organoids/filtered_organoids', xy_scaling=1/3.3, z_scaling=1/3.3)

# Collapse the raw labels into two groups: 'type_0' -> 'immature', 'type_1'/'type_2'/'type_3' -> 'mature'.
cell_type_groups = {
    'type_0': 'immature',
    'type_1': 'mature',
    'type_2': 'mature',
    'type_3': 'mature',
}

for samp_name, samp_data in project.samples.items():
    samp_data.cells['cell_type'] = samp_data.cells['cell_type'].replace(cell_type_groups)
    # Drop cells labelled 'type_nan'. The explicit .copy() makes the filtered frame an
    # independent DataFrame, so later column assignments on it (including re-running this
    # cell) do not trigger pandas' SettingWithCopyWarning.
    samp_data.cells = samp_data.cells[samp_data.cells['cell_type'] != 'type_nan'].copy()


In [None]:
# General information: compute the per-project sample info, then display project.sample_info
# as the cell output.
project.compute_sample_info()
project.sample_info

In [None]:
# ROBO1 Moran's I: read the per-condition Moran's I CSV from the project folder, transpose it
# so that the original columns become rows, and keep only the sample name and the ROBO1 column.
moran_csv_path = os.path.join(project.path, "comparing_conditions_Moran_I.csv")
MI_df = pd.read_csv(moran_csv_path, index_col=0).T
MI_df['sample'] = MI_df.index
robo_mi_df = MI_df.loc[:, ['sample', 'ROBO1']]
robo_mi_df


In [None]:
# For every sample in the project, plot correlation versus distance using the Spearman method.
for samp_name, samp_data in project.samples.items():
    samp_data.plot_correlation_vs_distance(method='spearman')

## Graph of cell-type proportions

In [None]:
# Cell-type population analysis in 'relative' mode (1000 permutations).
# The call returns a counts table, a stats table and a pair of figures.
df_counts, df_stats, (fig1, fig3) = project.analyze_cell_type_population(
    mode='relative',
    n_permutations=1000,
    x_label_size=label_size,
    y_label_size=label_size,
)

# Render the figures produced above
plt.show()

# Print the statistics table
print(df_stats)


## Graph of cell-type counts

In [None]:
# Cell-type population analysis in 'absolute' mode (1000 permutations).
# The call returns a counts table, a stats table and a pair of figures.
df_counts, df_stats, (fig1, fig3) = project.analyze_cell_type_population(
    mode='absolute',
    n_permutations=1000,
    x_label_size=label_size,
    y_label_size=label_size,
)

# Render the figures produced above
plt.show()

# Print the statistics table
print(df_stats)


## define functions for distribution along axis

In [None]:

def get_cumulative_distribution(data, cell_type):
    """
    Average the per-sample cumulative distribution of one cell type.

    Parameters
    ----------
    data : mapping of sample name -> object indexable by ``cell_type``
        Each ``data[name][cell_type]`` is a 1-D signal (one value per bin).
    cell_type : hashable
        Key/column selecting the signal inside every sample.

    Returns
    -------
    numpy.ndarray
        Element-wise mean, over samples, of ``cumsum(signal) / sum(signal)``.
    """
    per_sample_cdfs = [
        np.cumsum(sample_frame[cell_type]) / np.sum(sample_frame[cell_type])
        for sample_frame in data.values()
    ]
    return np.mean(per_sample_cdfs, axis=0)

def get_mean_hist(data, cell_type):
    """
    Average the per-sample signal of one cell type, bin by bin.

    The signals are taken as-is (they are assumed to already be histogram values);
    no normalisation is applied here.

    Parameters
    ----------
    data : mapping of sample name -> object indexable by ``cell_type``
    cell_type : hashable
        Key/column selecting the signal inside every sample.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of the selected signal over all samples.
    """
    per_sample_signals = [sample_frame[cell_type] for sample_frame in data.values()]
    return np.mean(per_sample_signals, axis=0)

def compute_ks_statistic(sick_data, control_data, cell_type, print_p=False):
    """
    KS-style statistic between the mean CDFs of the sick and control groups.

    The statistic is the largest absolute gap between the two group-averaged
    cumulative distributions for ``cell_type``.

    Parameters
    ----------
    sick_data, control_data : mapping of sample name -> per-sample signals
    cell_type : hashable
        Key/column selecting the signal inside every sample.
    print_p : bool
        When True, also print the p-value of ``scipy.stats.ks_2samp`` applied to
        the two mean histograms.
        NOTE(review): ks_2samp receives bin heights here, not raw observations —
        confirm this p-value is meant only as a rough indication.

    Returns
    -------
    tuple
        ``(ks_statistic, sick_cdf, control_cdf, sick_hist, control_hist)``
    """
    groups = (sick_data, control_data)
    sick_cdf, control_cdf = (get_cumulative_distribution(group, cell_type) for group in groups)
    sick_hist, control_hist = (get_mean_hist(group, cell_type) for group in groups)

    cdf_gap = np.abs(sick_cdf - control_cdf)
    ks_statistic = np.max(cdf_gap)

    if print_p:
        p_val = ks_2samp(sick_hist, control_hist).pvalue
        print("ks pval:", p_val)

    return ks_statistic, sick_cdf, control_cdf, sick_hist, control_hist
    

def plot_cumulative_distributions(sick_cdf, control_cdf, ks_statistic, bin_edges = NUM_BINS, cell_type = None, var_name=None):
    """
    Plot cumulative distributions for sick and control groups, highlighting the KS statistic.

    Parameters
    ----------
    sick_cdf, control_cdf : array-like
        Cumulative distributions, one value per bin (drawn in red and green respectively).
    ks_statistic : float
        Value printed as "KS Statistic"; the dashed marker itself is recomputed from the
        two CDFs as the position of their largest absolute difference.
    bin_edges : int, sequence of float, or None
        Either the number of equal-width bins on [0, 1], or the explicit bin edges.
        ``None`` falls back to 50 bins.
    cell_type : str, optional
        Used as the plot title.
    var_name : str, optional
        Accepted for interface compatibility; not used by this function (the x label is fixed).
    """
    # `is None` rather than `== None`: comparing an array of edges with `==` yields an
    # element-wise array whose truth value is ambiguous and raises ValueError.
    if bin_edges is None:
        bin_edges = 50
    if np.isscalar(bin_edges):
        bin_edges = np.linspace(0, 1, bin_edges + 1)
    else:
        # Lists/tuples of edges must become arrays, otherwise `+` below would concatenate.
        bin_edges = np.asarray(bin_edges, dtype=float)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    plt.plot(bin_centers, sick_cdf, color="red")
    plt.plot(bin_centers, control_cdf, color="green")

    # Highlight the KS statistic: dashed vertical segment at the bin of maximal CDF gap
    max_diff_index = np.argmax(np.abs(sick_cdf - control_cdf))
    plt.vlines(bin_centers[max_diff_index], control_cdf[max_diff_index], sick_cdf[max_diff_index], color="black", linestyle="--")
    print(f"KS Statistic: {ks_statistic:.3f}")
    plt.title(cell_type)

    plt.xlabel("Normalized 3D distance")
    plt.ylabel("Cumulative Probability")
    plt.show()


def plot_histograms(sick_hist, control_hist, title=None):
    """
    Overlay the control and sick histograms on one figure and show it.

    Both inputs are per-bin heights; they are drawn over equal-width bins spanning
    [0, 1], with as many bins as ``sick_hist`` has entries.

    Parameters
    ----------
    sick_hist, control_hist : array-like
        Bin heights for the sick (red, 'STXBP1') and control (green, 'Healthy') groups.
    title : str, optional
        Figure title.
    """
    edges = np.linspace(0, 1, len(sick_hist) + 1)
    left_edges = edges[:-1]

    # One sample point per bin (its left edge), weighted by the bin height, reproduces
    # the pre-binned histogram. Control is drawn first, sick on top.
    series = (
        (control_hist, 'Healthy', 'green'),
        (sick_hist, 'STXBP1', 'red'),
    )
    for heights, group_label, group_color in series:
        plt.hist(left_edges, bins=edges, weights=heights, alpha=0.5, label=group_label, color=group_color)

    plt.xlabel("Normalized 3D distance")
    plt.ylabel("Normalized distribution")
    plt.title(title)
    plt.legend(loc='upper left')
    plt.show()
    
    
# def get_pdf
#     """
#     Calculate the average probability density function (PDF) for a specific type across multiple samples.
#     """
#     pdfs = []

#     for sample_name, df in data.items():
#         signal = df[cell_type]
#         # scale power s.t. the integral over probabilities will be 1
#         signal = signal * len(signal)
#         pdfs.append(signal)

#     mean_pdf = np.mean(pdfs, axis=0)
#     # sanity check
#     point_step = 1 / (len(signal))
#     print(f"integral over pdf is approximately {np.sum(signal) * point_step}")
#     return mean_pdf
    
# def plot_pdf(sick_pdf, control_pdf, bin_edges = None, cell_type = None, var_name='Z axis'):
#     """
#     Plot probability density functions for sick and control groups
#     """

#     if bin_edges == None:
#         bin_edges = 50
#     if np.isscalar(bin_edges):
#         bin_edges = np.linspace(0, 1, bin_edges + 1)    
#     bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
#     plt.plot(bin_centers, sick_pdf, label="Sick PDF", color="blue")
#     plt.plot(bin_centers, control_pdf, label="Control PDF", color="green")
#     plt.xlabel(var_name)
#     plt.ylabel("Probability Density")
#     plt.title(f"PDFs of type {cell_type} along {var_name}")
#     plt.legend()
#     plt.show()

In [None]:
def permutation_test(sick_data, control_data, cell_type, bin_edges=NUM_BINS, n_permutations=NUM_PERMUTATIONS, plot_data_cdfs=False, plot_permutation_dist=False, var_name = 'Z axis', seed=1):
    """
    Conduct a permutation test for the KS statistic.

    Sample labels (sick / control) are shuffled ``n_permutations`` times; for each shuffle
    the KS statistic between the two relabelled groups is recomputed, giving a null
    distribution against which the observed statistic is compared.

    Parameters
    ----------
    sick_data, control_data : dict
        Mapping of sample name -> per-sample signals (indexable by ``cell_type``).
        The two mappings must not share sample names.
    cell_type : hashable
        Key/column selecting the signal inside every sample.
    bin_edges : int or sequence of float
        Forwarded to ``plot_cumulative_distributions``.
    n_permutations : int
        Number of label shuffles.
    plot_data_cdfs : bool
        Plot the observed CDFs and histograms of both groups.
    plot_permutation_dist : bool
        Plot the histogram of permuted statistics and print the empirical p-value.
    var_name : str
        Forwarded to ``plot_cumulative_distributions``.
    seed : int
        Seed for NumPy's global random generator (note: this resets the global RNG state).

    Returns
    -------
    tuple
        ``(observed_statistic, z_p_value, direct_p_value)`` where ``z_p_value`` is the upper
        tail of a normal fitted to the permuted statistics (NaN when they have zero spread)
        and ``direct_p_value`` is the empirical p-value ``(b + 1) / (n_permutations + 1)``,
        with ``b`` the number of permuted statistics >= the observed one.

    Raises
    ------
    ValueError
        If a sample name appears in both ``sick_data`` and ``control_data``.
    """
    # Set seed
    np.random.seed(seed)

    # Compute observed KS statistic
    observed_statistic, sick_cdf, control_cdf, sick_hist, control_hist = compute_ks_statistic(sick_data, control_data, cell_type, print_p=True)

    # Merging dicts with shared keys would silently drop samples and corrupt the group sizes.
    shared_names = set(sick_data) & set(control_data)
    if shared_names:
        raise ValueError(f"Sample names appear in both groups: {sorted(shared_names)}")

    # Combine all data
    all_data = {**sick_data, **control_data}
    all_sample_names = list(all_data.keys())
    n_sick = len(sick_data)

    permuted_stats = []
    for _ in range(n_permutations):
        np.random.shuffle(all_sample_names)
        permuted_sick = {name: all_data[name] for name in all_sample_names[:n_sick]}
        permuted_control = {name: all_data[name] for name in all_sample_names[n_sick:]}

        permuted_statistic, _, _, _, _ = compute_ks_statistic(permuted_sick, permuted_control, cell_type)
        permuted_stats.append(permuted_statistic)

    # Explicit array so the comparison below is a plain element-wise operation instead of
    # relying on NumPy's reflected operator between a Python list and a NumPy scalar.
    permuted_stats = np.asarray(permuted_stats, dtype=float)

    # Calculate p-value assuming normal distribution of permuted statistics.
    # norm.sf is the survival function (1 - cdf) evaluated without cancellation in the tail.
    mean_stat = np.mean(permuted_stats)
    std_stat = np.std(permuted_stats)
    z_p_value = norm.sf(observed_statistic, loc=mean_stat, scale=std_stat)

    # Empirical p-value. The observed labelling counts as one of the permutations, hence the
    # +1 in both numerator and denominator; this also keeps the p-value strictly positive.
    n_at_least_as_extreme = int(np.sum(permuted_stats >= observed_statistic))
    direct_p_value = (n_at_least_as_extreme + 1) / (n_permutations + 1)

    if plot_data_cdfs:
        plot_cumulative_distributions(sick_cdf, control_cdf, observed_statistic, bin_edges, cell_type, var_name)
        plot_histograms(sick_hist, control_hist, title=cell_type)

    if plot_permutation_dist:
        # Plot histogram of permuted statistics
        plt.hist(permuted_stats, bins=30, alpha=0.6, color="gray")

        # Mark the observed statistic
        plt.axvline(observed_statistic, color="red", linestyle="--")
        print(f"For type {cell_type} with {n_permutations} permutations, we got direct_p-value of: {direct_p_value:.3e}")

        plt.xlabel("KS Statistic")
        plt.ylabel("Counts")
        plt.title(f"Permutation Test of KS-statistic for type {cell_type}")
        plt.show()

    return observed_statistic, z_p_value, direct_p_value

## Types distribution along r axis - spherical

### graph for each sample

In [None]:
# Per-sample cell distribution along the axis with radial='spherical'.
# The two returned objects are used below as the control and sick sample collections.
control_data, sick_data = project.analyze_cell_distribution_along_axis(
    num_bins=NUM_BINS,
    normalize_counts=True,
    normalize_axis=True,
    bandwidth=2,
    show_histogram=False,
    radial='spherical',
)

### Condition analysis - graphs

In [None]:
from scipy.special import comb

# Number of distinct ways to pick which K of the N samples carry the "sick" label, i.e. the
# number of distinct relabellings the permutation test can draw from.
N = len(control_data) + len(sick_data)
K = len(sick_data)
# exact=True returns the exact integer count instead of a floating-point approximation.
print("the num of combinations: ", comb(N, K, exact=True))


In [None]:

# Run the permutation test once per cell type and collect one summary row for each.
results = []
for ct in ('immature', 'mature'):
    observed_stat, z_p_val, str_p_val = permutation_test(
        sick_data,
        control_data,
        cell_type=ct,
        plot_data_cdfs=True,
        plot_permutation_dist=True,
        var_name='radial distance (3D)',
        seed=1,
    )
    results.append({"type": ct, "KS-stat": observed_stat, "P-value (straight calc)": str_p_val, "P-value (normal assumption)": z_p_val})

# Summary table, displayed as the cell output
res_df = pd.DataFrame(results)
res_df

## Types distribution along z axis 

### graph for each sample

In [None]:
# Per-sample cell distribution along the axis, using the method's default (no `radial` argument).
# The two returned objects are used below as the control and sick sample collections.
control_data, sick_data = project.analyze_cell_distribution_along_axis(
    num_bins=NUM_BINS,
    normalize_counts=True,
    normalize_axis=True,
    bandwidth=2,
    show_histogram=False,
)

### Condition analysis - graphs

In [None]:

# Run the permutation test once per cell type for the z-axis distributions and collect
# one summary row for each.
results = []
for ct in ['immature', 'mature']:
    # var_name names the axis under analysis; this section is the z-axis analysis, so it must
    # not carry the 'radial distance (3D)' label used by the spherical sections.
    observed_stat, z_p_val, str_p_val = permutation_test(sick_data, control_data, cell_type=ct, plot_data_cdfs=True, plot_permutation_dist=True, var_name='Z axis', seed=1)
    results.append({
        "type": ct,
        "KS-stat": observed_stat,
        "P-value (straight calc)": str_p_val,
        "P-value (normal assumption)": z_p_val  
    })

# Summary table, displayed as the cell output
res_df = pd.DataFrame(results)
res_df

## Types distribution along r axis - spherical, bin cells by volume

### graph for each sample

In [None]:
# Per-sample cell distribution along the axis with radial='spherical_volume'.
# The two returned objects are used below as the control and sick sample collections.
control_data, sick_data = project.analyze_cell_distribution_along_axis(
    num_bins=NUM_BINS,
    normalize_counts=True,
    normalize_axis=True,
    bandwidth=2,
    show_histogram=False,
    radial='spherical_volume',
)

### Condition analysis - graphs

In [None]:

# Run the permutation test once per cell type for the volume-binned radial distributions and
# collect one summary row for each.
results = []
for ct in ['immature', 'mature']:
    # Use the shared NUM_PERMUTATIONS constant instead of repeating its value as a literal,
    # so every section of the notebook stays in sync when the constant is tuned.
    observed_stat, z_p_val, str_p_val = permutation_test(sick_data, control_data, cell_type=ct, n_permutations=NUM_PERMUTATIONS, plot_data_cdfs=True, plot_permutation_dist=True, var_name='radial distance (3D)', seed=1)
    results.append({
        "type": ct,
        "KS-stat": observed_stat,
        "P-value (straight calc)": str_p_val,
        "P-value (normal assumption)": z_p_val  
    })

# Summary table, displayed as the cell output
res_df = pd.DataFrame(results)
res_df