# Mapping RNAscope & Spatial-Transcriptomics Data to a Calibrated Vestibular-Ganglion Reference Frame  
*Author:* Ruiqi Liu  
*Date:* 2025-09-10

---

## 🎯 Purpose  
This notebook performs **end-to-end coordinate mapping** of RNAscope or spatial-transcriptomics (ST) spot data onto a **standardised, calibrated vestibular ganglion (VG)** atlas.  

### Three core stages  
1. Build the **standard VG frame**  
2. **Map every slice** into that frame  
3. **Summarise, visualise & analyse** spatial gene-expression patterns  

---

## 1️⃣ Construction of the Standardised VG Reference Frame  
- Import **all VG slices**, together with manually drawn **landmark points** outlining the boundary anatomical features.  
- Perform **size normalisation** on each slice based on these landmarks, then compute the **pixel-wise average** across the entire set to establish the final VG common reference frame.  
- Use **Thin-Plate-Spline (TPS)** warping to align every normalised slice to this average frame, generating an **average template** and a **surface mesh** that remain unchanged for downstream mapping steps.
---

In [2]:
# import the modules
import os
from modules import preprocessing

In [None]:
# Paths and parameters for building the reference frame
base_folder = 'path/to/your/data/'
landmarks_folder, summary_folder, data_folder = (
    os.path.join(base_folder, name) for name in ('landmarks', 'summary', 'data')
)

# Marker names and the landmark points drawn on every slice
marker_list = ['a', 'b', 'c', 'd', 'e']
point_label_list = [
    'big sharp', 'big flat',
    'mid sharp', 'mid flat',
    'small sharp', 'small flat',
    'small2mid', 'big2mid',
]
# Landmark indices in outline order; the list starts and ends at index 0.
order_list = [0, 7, 2, 6, 4, 5, 3, 1, 0]

# One sub-folder per sample. os.listdir never yields '.' or '..', so the
# directory check alone is enough.
sub_folders = [
    entry
    for entry in os.listdir(data_folder)
    if os.path.isdir(os.path.join(data_folder, entry))
]

print('\n'.join([
    f'Root Directory: {base_folder}',
    f'Total number of samples: {len(sub_folders)}',
    f'Sample List: {sub_folders}',
]))

In [None]:
def _run_step(title, step):
    """Print a banner for *title*, call the zero-argument *step*, report the outcome.

    Any exception raised by *step* is printed (not re-raised), so a failing
    step does not prevent the following ones from running.
    """
    print(f'===== {title} =====')
    try:
        step()
        print(f'✅ {title}')
    except Exception as e:
        print(f'❌ {title} Failed: {e}')


# Each step is wrapped in a lambda so the module attribute lookup and the
# argument evaluation both happen inside _run_step's try/except.
_run_step('Step 1: Data Preprocessing',
          lambda: preprocessing.preprocessing(base_folder, landmarks_folder, data_folder))
_run_step('Step 2: Calculate the standardized coordinates for each sample',
          lambda: preprocessing.norm_rot_point(data_folder, summary_folder, sub_folders))
_run_step('Step 3: Draw the average VG',
          lambda: preprocessing.registration_to_frame(
              data_folder, summary_folder, sub_folders, point_label_list, order_list))
_run_step('Step 4: Generate the landmarks table',
          lambda: preprocessing.mapping_landmarks(data_folder, summary_folder, sub_folders))
_run_step('Step 5: Generate the original coordinates of the ROI',
          lambda: preprocessing.true_mapping_roi(marker_list, data_folder, sub_folders))

## 2️⃣ Slice-to-Atlas Mapping  
❗ Before running this step, run `Script.groovy` to generate the calibrated list of cell coordinates.
- Load spot tables (`gene, x, y`) and gene matrices  
- Calibrate the transform with **fiducial landmarks**  
  - ganglion contours  
- Apply **thin-plate-spline** transform  
- QC: overlay mapped slice on reference outline  

In [5]:
# import the modules
import os
from modules import roi_distribution

In [None]:
# Paths and parameters for slice-to-atlas mapping
base_folder = 'path/to/your/data/'
summary_folder, data_folder = (
    os.path.join(base_folder, name) for name in ('summary', 'data')
)
# One sub-folder per sample (os.listdir never returns '.' or '..').
sub_folders = [
    entry
    for entry in os.listdir(data_folder)
    if os.path.isdir(os.path.join(data_folder, entry))
]
marker_names = ['a', 'b', 'c', 'd', 'e']
# Landmark indices in outline order; starts and ends at index 0.
order_list = [0, 7, 2, 6, 4, 5, 3, 1, 0]

print('\n'.join([
    f'Root Directory: {base_folder}',
    f'Total number of samples: {len(sub_folders)}',
    f'Sample List: {sub_folders}',
    f'Marker List: {marker_names}',
]))

In [None]:
# Step 1: draw the ROI map of every sample via roi_distribution.plot_individually.
# This call is not wrapped in try/except: any exception stops the cell.
print('===== Step 1: Draw the ROI map for each sample =====')
roi_distribution.plot_individually(
    summary_folder,
    sub_folders,
    data_folder,
)
print('✅ Step 1: Draw the ROI map for each sample')

In [None]:
# Step 2: summarise ROI information per subtype via
# roi_distribution.summary_subtypes_csv. Not wrapped in try/except, so any
# exception stops the cell.
print('===== Step 2: Summarize the ROI information for each subtype =====')
roi_distribution.summary_subtypes_csv(summary_folder, data_folder)
print('✅ Step 2: Summarize the ROI information for each subtype')

In [None]:
# Step 3: scatter and binned-distribution plots per subtype.
# Bin counts along x and y handed to plot_summary_scatter:
num_xbins, num_ybins = 20, 10
print('===== Step 3: Draw the ROI scatter and distribution plots for each subtype =====')
try:
    roi_distribution.plot_summary_scatter(
        summary_folder, marker_names, order_list, num_xbins, num_ybins
    )
except Exception as e:
    # Reported, not re-raised, so the notebook keeps running.
    print(f'❌ Step 3: Draw the ROI scatter and distribution plots for each subtype Failed: {e}')
else:
    print('✅ Step 3: Draw the ROI scatter and distribution plots for each subtype')

In [None]:
import os
import pandas as pd
from functools import reduce

sample_info_path = os.path.join(summary_folder, "sample_info.csv")

# --- Sample sheet: distinct values of its 'sam_name' column ---
# On any read error, sam_names falls back to an empty list.
try:
    sample_info_df = pd.read_csv(sample_info_path)
    sam_names = sample_info_df['sam_name'].unique()
except FileNotFoundError:
    print(f"❌ Error: file not found {sample_info_path}")
    print("Please check if the file path is correct")
    sam_names = []
except Exception as e:
    print(f"❌ Error while reading file: {e}")
    sam_names = []
else:
    print(f'✅ Successfully read sample info file, found {len(sam_names)} unique sam_name(s)')
    print(f'sam_name list: {list(sam_names)}')

# --- Summary folder: report the CSV files it already contains (first 10) ---
if not os.path.exists(summary_folder):
    print(f'❌ Summary folder does not exist: {summary_folder}')
else:
    summary_files = os.listdir(summary_folder)
    csv_files = [name for name in summary_files if name.endswith('.csv')]
    print(f'{len(csv_files)} CSV file(s) found in summary folder')
    print('First 10 CSV files:')
    for position, name in enumerate(csv_files[:10], start=1):
        print(f'  {position}. {name}')
    extra = len(csv_files) - 10
    if extra > 0:
        print(f'  ... and {extra} more files')


def process_sam_name(sam_name, summary_folder):
    """Merge all per-marker summary tables of one sample into a single CSV.

    Reads every ``<sam_name>_summary_<marker>.csv`` in *summary_folder*
    (excluding files ending in ``_all.csv``) and prefixes every non-``Label``
    column with its marker name. The tables are outer-joined on ``Label`` and
    the result is written to ``<sam_name>_summary_all.csv`` in the same folder.

    Missing values are filled as follows:

    * ``*_isPositive`` columns get 0.
    * All other data columns get -1.

    A ``hasPositive`` column is added: 1 when any ``*_isPositive`` column
    equals 1 for that row, otherwise 0. Fully duplicated rows are dropped
    before saving.

    Args:
        sam_name: Sample name used as the filename prefix.
        summary_folder: Folder holding the input tables; the merged table is
            written here as well.

    Returns:
        True if a merged table was written, False if no usable input was found.
    """
    print(f"\nProcessing sam_name: {sam_name}")

    prefix = f"{sam_name}_summary_"
    # Only CSV inputs, and never the merged output (*_all.csv) itself.
    # Sorted so the merged column order does not depend on listdir order.
    sam_files = sorted(
        f for f in os.listdir(summary_folder)
        if f.startswith(prefix) and f.endswith(".csv") and not f.endswith("_all.csv")
    )

    if not sam_files:
        print(f"⚠️ Warning: no files found for sam_name {sam_name}")
        return False

    print(f"{len(sam_files)} related file(s) found:")
    for file in sam_files:
        print(f"  - {file}")

    dataframes = []
    for file in sam_files:
        file_path = os.path.join(summary_folder, file)
        try:
            # The marker is everything between the prefix and ".csv".
            # Slicing off the known prefix (rather than split('_')[2]) stays
            # correct when sam_name or the marker contains underscores.
            marker = os.path.splitext(file[len(prefix):])[0]
            if not marker:
                print(f"⚠️ Warning: filename format unexpected {file}, skip marker extraction.")

            df = pd.read_csv(file_path)

            if 'Label' not in df.columns:
                print(f"⚠️ Warning: no Label column in {file_path}, skipping.")
                continue

            # Prefix every data column with the marker so columns from
            # different markers do not collide in the merge below.
            if marker:
                rename_dict = {col: f"{marker}_{col}" for col in df.columns if col != 'Label'}
                df = df.rename(columns=rename_dict)
                print(f"    Added marker prefix '{marker}' for {file}")

            dataframes.append(df)
        except Exception as e:
            print(f"❌ Error reading file {file_path}: {e}")

    if not dataframes:
        print(f"❌ No valid data found for sam_name: {sam_name}")
        return False

    print(f"Successfully read {len(dataframes)} data file(s)")

    # Horizontally merge all tables on Label (outer join keeps every cell).
    if len(dataframes) > 1:
        final_df = reduce(
            lambda left, right: pd.merge(left, right, on='Label', how='outer'),
            dataframes,
        )
        print(f"Merge completed, final dataframe shape: {final_df.shape}")
    else:
        final_df = dataframes[0]
        print(f"Only one dataframe, no merge needed, shape: {final_df.shape}")

    is_positive_cols = [col for col in final_df.columns if col.endswith('_isPositive')]
    positive_set = set(is_positive_cols)
    other_cols = [col for col in final_df.columns if col not in positive_set and col != 'Label']

    # Cells absent from a marker's table are treated as not positive (0);
    # other missing measurements are flagged with -1.
    if is_positive_cols:
        final_df[is_positive_cols] = final_df[is_positive_cols].fillna(0)
        print(f"Filled missing values with 0 for {len(is_positive_cols)} isPositive column(s)")

    if other_cols:
        final_df[other_cols] = final_df[other_cols].fillna(-1)
        print(f"Filled missing values with -1 for {len(other_cols)} other column(s)")

    if is_positive_cols:
        print(f"Found {len(is_positive_cols)} isPositive column(s): {is_positive_cols}")
        for col in is_positive_cols:
            final_df[col] = pd.to_numeric(final_df[col], errors='coerce').fillna(0)
        # hasPositive = 1 if any isPositive column equals 1, else 0 (vectorised).
        final_df['hasPositive'] = (final_df[is_positive_cols] == 1).any(axis=1).astype(int)
        print("hasPositive column calculated")
    else:
        final_df['hasPositive'] = 0
        print("No isPositive columns found, hasPositive set to all zeros")

    original_rows = len(final_df)
    final_df = final_df.drop_duplicates()
    removed_rows = original_rows - len(final_df)
    if removed_rows > 0:
        print(f"Removed {removed_rows} duplicate row(s)")

    output_file_path = os.path.join(summary_folder, f"{sam_name}_summary_all.csv")
    final_df.to_csv(output_file_path, index=False)
    print(f"✅ Summary table saved to {output_file_path}")

    return True


print("Processing function defined!")

# Merge the per-marker tables of every sample in sam_names.
# len() rather than truthiness: sam_names may be a NumPy array from .unique().
if len(sam_names) == 0:
    print("❌ No sam_name found, cannot proceed")
else:
    print("===== Start processing data for all sam_name =====")
    success_count = sum(
        1 for sam_name in sam_names if process_sam_name(sam_name, summary_folder)
    )
    print(f"\n🎉 All sam_name data processing complete!")
    print(f"Successfully processed: {success_count}/{len(sam_names)} sam_name(s)")

# List every merged *_all.csv with its size, shape and first five column names.
print("\n===== Inspect generated files =====")
if not os.path.exists(summary_folder):
    print(f"Summary folder does not exist: {summary_folder}")
else:
    all_files = [f for f in os.listdir(summary_folder) if f.endswith('_all.csv')]
    if not all_files:
        print("No summary files found")
    else:
        print(f"Found {len(all_files)} summary file(s):")
        for file in all_files:
            file_path = os.path.join(summary_folder, file)
            print(f"  - {file} ({os.path.getsize(file_path)} bytes)")
            try:
                df = pd.read_csv(file_path)
                columns = list(df.columns)
                tail = '...' if len(columns) > 5 else ''
                print(f"    Shape: {df.shape}, Columns: {len(columns)}")
                print(f"    Column names: {columns[:5]}{tail}")
            except Exception as e:
                print(f"    Unable to read file content: {e}")

## 3️⃣ Summarisation, Visualisation & Significance Testing  
- Aggregate counts into **grid maps**  
- Create **2-D surface** and **1-D flattened** expression maps  
- Run  
  - differential expression  
- Export  
  - publication-ready PNG/PDF/EPS  
  - CSV of p-values & effect sizes  

---

### One-sample mode (per marker)

In [None]:
# import packages
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline
from modules.judge_sig import (
    create_position_bins,
    two_step_stratified_sample,
    monte_carlo_simulation,
    x_monte,
    correct_grid_pvalues,
    plot_corrected_p_heatmap,
)


In [None]:
# ---------------------------------------------------------------
# Analysis parameters (single-sample mode)
# ---------------------------------------------------------------
n_simulations = 10000          # passed to monte_carlo_simulation / x_monte
num_xbins, num_ybins = 20, 10  # number of bins along x and y
correction_method = 'fdr'      # passed to correct_grid_pvalues
# Fixed extent passed to plot_corrected_p_heatmap
bx_min, bx_max = -0.6, 0.6
by_min, by_max = -0.6, 0.6

cell_num = 0     # > 0 enables two_step_stratified_sample(data, cell_num, random_num)
random_num = 42  # passed to two_step_stratified_sample (presumably a seed)

sample_type = 'wt'  # prefix of the merged summary table to analyse

# ---------------------------------------------------------------
# Paths
# ---------------------------------------------------------------
base_folder = 'path/to/your/data/'
summary_folder = os.path.join(base_folder, "summary")
sample_info_path, wt_summary_path, results_file_path, vg_size_path = (
    os.path.join(summary_folder, fname)
    for fname in (
        "sample_info.csv",
        sample_type + "_summary_all.csv",
        'vg_averageing_labels.csv',
        'vg_size.csv',
    )
)

print('\n'.join([
    f'Root directory: {base_folder}',
    f'Summary folder: {summary_folder}',
    f'Number of simulations: {n_simulations}',
    f'Number of x-bins: {num_xbins}, Number of y-bins: {num_ybins}',
    f'Correction method: {correction_method}',
]))

# ----------------------
# load the data and preprocessing
# ----------------------
def load_data():
    """Load the merged summary table and the landmark coordinate table.

    Reads the CSV at the module-level ``wt_summary_path`` and the header-less
    CSV at ``results_file_path``. Returns ``(wt_data, BX, BY)``, where BX and
    BY are the first and second columns of the landmark table. Any error is
    printed and re-raised.
    """
    try:
        wt_frame = pd.read_csv(wt_summary_path)
        print(f"succesfullly load {sample_type} data, length : {len(wt_frame)} rows")

        label_table = pd.read_csv(results_file_path, header=None)
        bx_values, by_values = (label_table.iloc[:, k].values for k in (0, 1))
        print(f"succesfully load BX/BY, length are : {len(bx_values)}, {len(by_values)}")
    except Exception as err:
        print(f"data loading error: {err}")
        raise
    return wt_frame, bx_values, by_values


# Load the merged table and the landmark coordinates (BX, BY).
summary_data, BX, BY = load_data()

# Remove exact duplicate rows before any per-cell statistics.
original_count = len(summary_data)
summary_data = summary_data.drop_duplicates(keep='first')
deduplicated_count = len(summary_data)
removed_count = original_count - deduplicated_count
print(f"data num: {deduplicated_count} (removed {removed_count} duplicates)")

# vg_size.csv has no header. Its column-wise mean is used as the average VG
# size: column 0 scales x and column 1 scales y (see standardized_x/_y below).
vg_size = pd.read_csv(vg_size_path, header=None)
avg_vg_size = vg_size.mean(axis=0).values
print(f"VG average size: {avg_vg_size}")

# One '<marker>_isPositive' flag column per marker.
is_positive_cols = [col for col in summary_data.columns if col.endswith('_isPositive')]
if not is_positive_cols:
    raise ValueError("No columns ending with '_isPositive' were found, please check the column names.")
print(f"find {len(is_positive_cols)} positive-labeling columns: {is_positive_cols}")

# Warped ROI coordinates: first column ending with warpedROIvar1 / warpedROIvar2.
x_col = next((col for col in summary_data.columns if col.endswith('warpedROIvar1')), None)
y_col = next((col for col in summary_data.columns if col.endswith('warpedROIvar2')), None)

if not x_col or not y_col:
    missing = []
    if not x_col:
        missing.append("The column ending with warpedROIvar1")
    if not y_col:
        missing.append("The column ending with warpedROIvar2")
    raise ValueError(f"necessary cols lost: {', '.join(missing)}")

# Standardised coordinates = warped coordinates / average VG size per axis.
summary_data['standardized_x'] = summary_data[x_col] / avg_vg_size[0]
summary_data['standardized_y'] = summary_data[y_col] / avg_vg_size[1]

data = summary_data[['Label'] + is_positive_cols + ['standardized_x', 'standardized_y', 'hasPositive']].copy()

# Optional sub-sampling via two_step_stratified_sample when cell_num > 0.
if cell_num > 0:
    data = two_step_stratified_sample(data, cell_num, random_num)

# Closed reference outline: a periodic cubic spline through the landmarks in
# outline order. order_list starts and ends at index 0, so the first and last
# points coincide, as bc_type='periodic' requires.
order_list = [0, 7, 2, 6, 4, 5, 3, 1, 0]
BX_ordered = BX[order_list]
BY_ordered = BY[order_list]
cs = CubicSpline(np.arange(len(BX_ordered)), np.c_[BX_ordered, BY_ordered],
                 axis=0, bc_type='periodic')
t_fine = np.linspace(0, len(BX_ordered) - 1, 300)
x_fine, y_fine = cs(t_fine).T

# Extent of the sampled outline
actual_bx_min, actual_bx_max = x_fine.min(), x_fine.max()
actual_by_min, actual_by_max = y_fine.min(), y_fine.max()

# Bin edges over the outline extent, shared by all markers
bx_bins, by_bins = create_position_bins(
    actual_bx_min, actual_bx_max,
    actual_by_min, actual_by_max,
    num_xbins, num_ybins
)

print(f"BX: from {actual_bx_min:.2f} to {actual_bx_max:.2f}")
print(f"BY: from {actual_by_min:.2f} to {actual_by_max:.2f}")

In [None]:
print("\n===== Starting Monte Carlo simulation =====")
results = []

# Run monte_carlo_simulation for each marker flag on the 2-D grid. A result is
# kept (with FDR/other corrected p-values attached) only when its real
# counts sum to a positive number.
for col in is_positive_cols:
    gene = col.replace('_isPositive', '')
    x_col, y_col = 'standardized_x', 'standardized_y'

    if not {x_col, y_col}.issubset(data.columns):
        print(f"⚠️ Warning: standardized coordinate columns not found, skipping {col}")
        continue

    print(f"\n===== Analyzing {col} =====")
    print(f"Coordinate columns used: X={x_col}, Y={y_col}")

    result = monte_carlo_simulation(
        data, col, x_col, y_col, bx_bins, by_bins, n_simulations,
        save_base_dir=os.path.join(summary_folder, "test_monte"),
    )
    if result is None:
        print(f"❌ Analysis of {col} failed, no result generated")
        continue

    total_real = np.sum(result['real_counts'])
    print(f"Total real counts for {col}: {total_real}")
    if not total_real > 0:
        print(f"❌ Total real counts for {col} is 0, not adding to results")
        continue

    result['corrected_p'] = correct_grid_pvalues(result['p_values'], method=correction_method)
    results.append(result)
    print(f"✅ Analysis of {col} completed")

In [None]:
# Passed to plot_corrected_p_heatmap (presumably colour limits for the
# positive-cell proportion panel).
min_count = 0
max_count = 0.02
if results:
    print("\n===== Saving analysis results =====")
    # Folder names include cell_num when stratified sub-sampling was used.
    if cell_num > 0:
        results_folder = os.path.join(summary_folder, f"grid_pvalues_{sample_type}_{cell_num}")
        visualization_folder = os.path.join(summary_folder, f"grid_visualizations_{sample_type}_{cell_num}")
    else:
        results_folder = os.path.join(summary_folder, f"grid_pvalues_{sample_type}")
        visualization_folder = os.path.join(summary_folder, f"grid_visualizations_{sample_type}")
    # Create both output folders up front, so the heatmap output folder
    # exists before plotting.
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(visualization_folder, exist_ok=True)

    # Per marker: real counts and corrected p-values as CSV, then the heatmap.
    for result in results:
        gene = result['gene']
        np.savetxt(os.path.join(results_folder, f'{gene}_real_counts.csv'), result['real_counts'], delimiter=',', fmt='%d')
        np.savetxt(os.path.join(results_folder, f'{gene}_corrected_p.csv'), result['corrected_p'], delimiter=',', fmt='%.6f')
        plot_corrected_p_heatmap(result, data, gene, bx_min, bx_max, by_min, by_max, BX, BY, visualization_folder, min_count, max_count)
else:
    print("\n⚠️ No valid results were generated")

print("\n===== Analysis complete =====")

In [None]:
print("\n===== Starting x-axis distribution Monte Carlo simulation =====")
results = []

# Same per-marker loop as the 2-D grid test, but over the x bins only via
# x_monte. Only results whose real counts sum to a positive number are kept.
for col in is_positive_cols:
    gene = col.replace('_isPositive', '')
    x_col, y_col = 'standardized_x', 'standardized_y'

    if not {x_col, y_col}.issubset(data.columns):
        print(f"⚠️ Warning: standardized coordinate columns not found, skipping {col}")
        continue

    print(f"\n===== Analyzing {col} =====")
    print(f"Coordinate columns used: X={x_col}, Y={y_col}")

    result = x_monte(data, col, x_col, bx_bins, n_simulations)
    if result is None:
        print(f"❌ Analysis of {col} failed, no result generated")
        continue

    total_real = np.sum(result['real_counts'])
    print(f"Total real counts for {col}: {total_real}")
    if not total_real > 0:
        print(f"❌ Total real counts for {col} is 0, not adding to results")
        continue

    result['corrected_p'] = correct_grid_pvalues(result['p_values'], method=correction_method)
    results.append(result)
    print(f"✅ Analysis of {col} completed")

In [None]:
min_count = 0
max_count = 0.1
p_value_threshold = 0.05  # significance threshold
# -------------------------- Key: define target Marker order --------------------------
target_marker_order = ['a', 'b', 'c', 'd', 'e']  # top-to-bottom order

# Collect the 1-D (x-axis) results of every marker into two marker-by-bin
# tables (real counts, corrected p-values) and a per-bin edge table. All are
# restricted to the bx_bins bins spanned by the reference outline (x_fine).
if results:
    print("\n===== 1. Merge and save X-axis data for all Markers =====")
    # 1.1 Output folders with _xbins suffix (include cell_num when sub-sampling)
    if cell_num > 0:
        results_folder = os.path.join(summary_folder, f"grid_pvalues_{sample_type}_{cell_num}_xbins")
        visualization_folder = os.path.join(summary_folder, f"grid_visualizations_{sample_type}_{cell_num}_xbins")
    else:
        results_folder = os.path.join(summary_folder, f"grid_pvalues_{sample_type}_xbins")
        visualization_folder = os.path.join(summary_folder, f"grid_visualizations_{sample_type}_xbins")
    os.makedirs(results_folder, exist_ok=True)
    os.makedirs(visualization_folder, exist_ok=True)

    # 1.2 Labels over the raw landmark extent (BX). These are NOT the edges
    # x_monte counted in (those are bx_bins); they are kept only so these
    # names remain defined for later cells.
    x_bin_edges = np.linspace(BX.min(), BX.max(), num_xbins + 1)
    x_bin_labels = [f"Bin{i+1}\n[{x_bin_edges[i]:.2f}-{x_bin_edges[i+1]:.2f}]" for i in range(len(x_bin_edges) - 1)]
    x_bin_num = len(x_bin_labels)

    # 1.3 Accumulators (rows = markers, columns = kept x bins)
    marker_list = []
    real_counts_all = []
    corrected_p_all = []

    # ---------- 1. Range of bx_bins covered by the outline ----------
    bin_edges = np.asarray(bx_bins, dtype=float)
    n_bins_total = len(bin_edges) - 1
    min_bin_idx = np.digitize(x_fine.min(), bx_bins) - 1
    max_bin_idx = np.digitize(x_fine.max(), bx_bins) - 1
    zero_bin_idx = np.digitize(0, bx_bins) - 1
    # np.digitize returns len(bins) for a value on or beyond the last edge, so
    # clamp to valid bin indices and keep the last bin inclusively. This keeps
    # exactly the bins touched by the outline, whether or not the outer bin
    # edges coincide with the outline extent.
    first_bin = max(int(min_bin_idx), 0)
    last_bin = min(int(max_bin_idx), n_bins_total - 1)
    n_kept_bins = last_bin - first_bin + 1
    x_bins_indices = np.arange(first_bin, first_bin + n_kept_bins)
    x_offsets = x_bins_indices - zero_bin_idx  # offset relative to the bin containing 0

    # ---------- 2. Ticks every 5 bins, centred at 0, only at existing offsets ----------
    step = 5
    low = int(np.floor(x_offsets.min() / step) * step)
    high = int(np.ceil(x_offsets.max() / step) * step)
    tick_labels = np.arange(low, high + 1, step)  # multiples of 5
    tick_labels = tick_labels[np.isin(tick_labels, x_offsets)]  # ensure existence

    # ---------- 3. Corresponding plot positions ----------
    tick_positions = [np.where(x_offsets == lab)[0][0] for lab in tick_labels]

    for result in results:
        marker = result['gene']
        rc = np.asarray(result['real_counts'])[first_bin:last_bin + 1]
        cp = np.asarray(result['corrected_p'])[first_bin:last_bin + 1]
        marker_list.append(marker)
        real_counts_all.append(rc)
        corrected_p_all.append(cp)

    # 1.4 Save merged tables (marker × kept x bin)
    bin_columns = [f"X_Bin{i+1}" for i in range(n_kept_bins)]

    rc_df = pd.DataFrame(real_counts_all, index=marker_list, columns=bin_columns)
    rc_df.index.name = "Marker"
    rc_df.to_csv(os.path.join(results_folder, "all_markers_real_counts.csv"), index=True)
    print(f"✅ Saved real counts for all markers: all_markers_real_counts.csv")

    cp_df = pd.DataFrame(corrected_p_all, index=marker_list, columns=bin_columns)
    cp_df.index.name = "Marker"
    cp_df.to_csv(os.path.join(results_folder, "all_markers_corrected_p.csv"), index=True)
    print(f"✅ Saved corrected p values for all markers: all_markers_corrected_p.csv")

    # Per-bin left/right edges of the kept bins, in the same units as bx_bins.
    kept_edges = bin_edges[first_bin:last_bin + 2]
    x_bin_info = pd.DataFrame({
        "Bin_Index": range(1, n_kept_bins + 1),
        "Bin_Left": kept_edges[:-1],
        "Bin_Right": kept_edges[1:],
    })
    x_bin_info.to_csv(os.path.join(results_folder, "x_axis_bin_info.csv"), index=False)
    print(f"✅ Saved X-axis bin info: x_axis_bin_info.csv")

In [None]:
min_count = 0
max_count = 0.1
p_value_threshold = 0.05  # significance threshold
# -------------------------- Key: define target Marker order --------------------------
target_marker_order = ['a', 'b', 'c', 'd', 'e']  # top-to-bottom order

# Same folder naming rule as the cell that wrote the x-bin tables.
if cell_num > 0:
    results_folder = os.path.join(summary_folder, f"grid_pvalues_{sample_type}_{cell_num}_xbins")
    visualization_folder = os.path.join(summary_folder, f"grid_visualizations_{sample_type}_{cell_num}_xbins")
else:
    results_folder = os.path.join(summary_folder, f"grid_pvalues_{sample_type}_xbins")
    visualization_folder = os.path.join(summary_folder, f"grid_visualizations_{sample_type}_xbins")

if results:
    # Marker-by-bin tables saved by the previous cell.
    rc_df = pd.read_csv(os.path.join(results_folder, "all_markers_real_counts.csv"), header=0, index_col=0)
    cp_df = pd.read_csv(os.path.join(results_folder, "all_markers_corrected_p.csv"), header=0, index_col=0)
    x_bin_info = pd.read_csv(os.path.join(results_folder, "x_axis_bin_info.csv"), header=0, index_col=0)
    # -------------------------- 2. Plot integrated heatmap (all markers combined) --------------------------
    print("\n===== 2. Plot integrated heatmap for all Markers =====")
    # Keep only target markers that actually have data, in target order.
    # Reindexing by the full target list would insert all-NaN rows for absent
    # markers, so the missing/empty checks below would never fire.
    rc_df_filtered = rc_df.loc[rc_df.index.isin(target_marker_order)]
    cp_df_filtered = cp_df.loc[cp_df.index.isin(target_marker_order)]
    present_markers = [m for m in target_marker_order if m in rc_df_filtered.index]
    rc_df_ordered = rc_df_filtered.reindex(present_markers)
    cp_df_ordered = cp_df_filtered.reindex(present_markers)

    missing_markers = [m for m in target_marker_order if m not in present_markers]
    if missing_markers:
        print(f"⚠️ No data for following Marker(s), skipped: {missing_markers}")
    if rc_df_ordered.empty:
        print("⚠️ No valid Marker data, skipping plot")
    else:
        # 2.1 Row-normalise counts to proportions; clip p-values to the threshold
        rc_normalized = rc_df_ordered.div(rc_df_ordered.sum(axis=1), axis=0)
        cp_clipped = cp_df_ordered.clip(upper=p_value_threshold)

        # 2.2 Two stacked panels (top = proportions, bottom = corrected p);
        # figure height scales with the number of plotted markers.
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 1 * rc_df_ordered.shape[0]))
        fig.suptitle(f"All Markers - X-axis Bins Analysis (Sample: {sample_type})", fontsize=14, y=0.98)

        # 2.3 Positive-cell proportion heatmap
        im1 = ax1.imshow(rc_normalized.values, cmap='Blues', aspect='auto', vmin=min_count, vmax=max_count)
        ax1.set_xticks(tick_positions)
        ax1.set_xticklabels(tick_labels, fontsize=15)
        ax1.set_xlabel("offset from border", fontsize=15)
        ax1.set_yticks(range(len(rc_df_ordered)))
        ax1.set_yticklabels([word.capitalize() for word in rc_df_ordered.index], fontsize=15)
        ax1.set_title("Positive Cell Proportion", fontsize=15, pad=8)

        cbar1 = plt.colorbar(im1, ax=ax1, shrink=0.8, aspect=20)
        cbar1.outline.set_visible(False)
        cbar1.set_label("Cell Proportion", fontsize=15, labelpad=8)
        cbar1.set_ticks([min_count, (min_count + max_count) / 2, max_count])
        cbar1.set_ticklabels([f"{x:.2f}" for x in [min_count, (min_count + max_count) / 2, max_count]], fontsize=12)

        # 2.4 Corrected p-value heatmap (same row order as the top panel)
        im2 = ax2.imshow(cp_clipped.values, cmap='Reds_r', aspect='auto', vmin=0, vmax=p_value_threshold)
        ax2.set_xticks(tick_positions)
        ax2.set_xticklabels(tick_labels, fontsize=15)
        ax2.set_xlabel("offset from border", fontsize=15)
        ax2.set_yticks(range(len(cp_df_ordered)))
        ax2.set_yticklabels([word.capitalize() for word in cp_df_ordered.index], fontsize=15)
        ax2.set_title(f"q-value (FDR, cutoff at {p_value_threshold})", fontsize=15, pad=8)

        cbar2 = plt.colorbar(im2, ax=ax2, shrink=0.8, aspect=20)
        cbar2.outline.set_visible(False)
        cbar2.set_label("FDR", fontsize=10, labelpad=8)
        cbar2.set_ticks([0, p_value_threshold / 2, p_value_threshold])
        cbar2.set_ticklabels([f"{x:.2f}" for x in [0, p_value_threshold / 2, p_value_threshold]], fontsize=10)

        # 2.5 Layout: reserve top space for the suptitle
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        # 2.6 Save as PNG, PDF and EPS
        heatmap_path = os.path.join(visualization_folder, "all_markers_xbin_combined_heatmap")
        plt.savefig(heatmap_path + ".png", dpi=1200, bbox_inches='tight')
        plt.savefig(heatmap_path + ".pdf", dpi=1200, bbox_inches='tight')
        plt.savefig(heatmap_path + ".eps", format='eps', dpi=1200, bbox_inches='tight')
        plt.close(fig)
        print(f"✅ Saved integrated heatmap for all markers: {heatmap_path}")

else:
    print("\n⚠️ No valid results generated, skipping data saving and plotting")

print("\n===== Analysis complete =====")

### Two-sample mode: Tmie-/- vs. WT

In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline
from modules.judge_sig import (
    create_position_bins,
    two_step_stratified_sample,
    count_positive_in_bins,
    correct_grid_pvalues,
    count_positive_in_x_bins
)

In [None]:
# ---------------------------------------------------------------
# Analysis parameters (two-sample mode)
# ---------------------------------------------------------------
n_simulations = 10000
num_xbins, num_ybins = 20, 10
correction_method = 'fdr'
bx_min, bx_max = -0.6, 0.6
by_min, by_max = -0.6, 0.6

cell_num = 0
random_num = 42

# Prefixes of the two merged summary tables to compare
sample_types = ['tmie', 'wt']

# ---------------------------------------------------------------
# Paths
# ---------------------------------------------------------------
base_folder = 'path/to/your/data/'
summary_folder = os.path.join(base_folder, "summary")
sample_info_path, results_file_path, vg_size_path = (
    os.path.join(summary_folder, fname)
    for fname in ("sample_info.csv", 'vg_averageing_labels.csv', 'vg_size.csv')
)

print('\n'.join([
    f'Root directory: {base_folder}',
    f'Summary folder: {summary_folder}',
    f'Number of simulations: {n_simulations}',
    f'Number of x-bins: {num_xbins}, Number of y-bins: {num_ybins}',
    f'Correction method: {correction_method}',
    f'Sample types to compare: {sample_types}',
]))

In [None]:
# ----------------------
# Data loading & pre-processing
# ----------------------
def load_data(sample_type):
    """Read one sample's summary table together with the landmark coordinates.

    Parameters
    ----------
    sample_type : str
        Prefix of the file ``<summary_folder>/<sample_type>_summary_all.csv``.

    Returns
    -------
    tuple
        ``(sample_data, BX, BY)``. ``sample_data`` is the summary DataFrame;
        ``BX`` and ``BY`` are the first and second columns of the header-less
        file at the module-level ``results_file_path``, as NumPy arrays.

    Any exception raised while reading is printed and then re-raised.
    """
    try:
        summary_csv = os.path.join(summary_folder, sample_type + "_summary_all.csv")
        table = pd.read_csv(summary_csv)
        print(f"Successfully loaded {sample_type} data: {len(table)} rows")

        # Landmark file has no header: column 0 -> BX, column 1 -> BY
        landmarks = pd.read_csv(results_file_path, header=None)
        bx, by = (landmarks.iloc[:, k].values for k in (0, 1))
        print(f"Successfully loaded BX/BY data: lengths {len(bx)}, {len(by)}")
    except Exception as err:
        print(f"Error loading {sample_type} data: {err}")
        raise
    return table, bx, by

def preprocess_data(data, vg_size):
    """Pre-process one sample table for the grid analysis.

    Steps: drop fully duplicated rows, add ``standardized_x`` /
    ``standardized_y`` (the warped coordinates divided by the average VG
    size), collect the ``*_isPositive`` marker columns and, when the
    module-level ``cell_num`` is > 0, sub-sample with
    ``two_step_stratified_sample(data, cell_num, random_num)``.

    Parameters
    ----------
    data : pd.DataFrame
        Raw summary table. It is not modified; a de-duplicated copy is used.
    vg_size : pd.DataFrame
        Table whose column means give the average VG size; mean of column 0
        scales x and mean of column 1 scales y.

    Returns
    -------
    tuple
        ``(data, is_positive_cols, x_col, y_col)``: the processed DataFrame,
        the list of ``*_isPositive`` column names, and the names of the first
        columns ending in ``warpedROIvar1`` / ``warpedROIvar2``.

    Raises
    ------
    ValueError
        If either coordinate column or every ``*_isPositive`` column is missing.
    """
    # de-duplicate. The explicit .copy() stops pandas from treating the result
    # as a slice of the input, which would trigger SettingWithCopyWarning when
    # the standardised-coordinate columns are added below.
    original_count = len(data)
    data = data.drop_duplicates(keep='first').copy()
    deduplicated_count = len(data)
    removed_count = original_count - deduplicated_count
    print(f"After de-duplication: {deduplicated_count} records ({removed_count} duplicates removed)")

    # compute standardised coordinates
    avg_vg_size = vg_size.mean(axis=0).values
    print(f"Average VG size: {avg_vg_size}")

    # locate coordinate columns (first match wins)
    x_col = next((col for col in data.columns if col.endswith('warpedROIvar1')), None)
    y_col = next((col for col in data.columns if col.endswith('warpedROIvar2')), None)

    if x_col is None or y_col is None:
        missing = []
        if x_col is None:
            missing.append("column ending with warpedROIvar1")
        if y_col is None:
            missing.append("column ending with warpedROIvar2")
        raise ValueError(f"Required coordinate columns missing: {', '.join(missing)}")

    print(f"Coordinate columns found: X={x_col}, Y={y_col}")

    # calculate standardised coordinates
    data['standardized_x'] = data[x_col] / avg_vg_size[0]
    data['standardized_y'] = data[y_col] / avg_vg_size[1]

    # extract positive-marker columns
    is_positive_cols = [col for col in data.columns if col.endswith('_isPositive')]
    if not is_positive_cols:
        raise ValueError("No columns ending with '_isPositive' found; please check column names")
    print(f"{len(is_positive_cols)} positive-marker columns found")

    # sub-sample if required
    if cell_num > 0:
        data = two_step_stratified_sample(data, cell_num, random_num)

    return data, is_positive_cols, x_col, y_col

# load VG size data
vg_size = pd.read_csv(vg_size_path, header=None)

# load & pre-process both sample types.
# load_data reads BX/BY from the same landmark file on every call, so the
# second assignment overwrites BX/BY with identical values.
data1, BX, BY = load_data(sample_types[0])
data2, BX, BY = load_data(sample_types[1])
data1, positive_cols1, x_col, y_col = preprocess_data(data1, vg_size)
data2, positive_cols2, x_col, y_col = preprocess_data(data2, vg_size)

# keep only positive-marker columns present in both samples, and report the
# ones that are dropped so a silently missing marker is noticed
markers1, markers2 = set(positive_cols1), set(positive_cols2)
common_markers = sorted(markers1 & markers2)
only_in_1 = sorted(markers1 - markers2)
only_in_2 = sorted(markers2 - markers1)
if only_in_1:
    print(f"⚠️ Markers only in {sample_types[0]} (ignored): {only_in_1}")
if only_in_2:
    print(f"⚠️ Markers only in {sample_types[1]} (ignored): {only_in_2}")
print(f"\nMarkers common to both samples: {common_markers}")
if not common_markers:
    print("⚠️ No markers shared by both samples; nothing to compare")

# BX/BY are the landmark coordinates from the single shared landmark file (not
# a combination of the two samples). all_BX/all_BY are plain aliases, kept in
# case later cells use them.
all_BX = BX
all_BY = BY

# Closed outline through the landmarks: order_list starts and ends at landmark
# 0, which the periodic spline boundary condition requires.
order_list = [0, 7, 2, 6, 4, 5, 3, 1, 0]
BX_ordered = BX[order_list]
BY_ordered = BY[order_list]
cs = CubicSpline(np.arange(len(BX_ordered)), np.c_[BX_ordered, BY_ordered],
                 axis=0, bc_type='periodic')
t_fine = np.linspace(0, len(BX_ordered) - 1, 300)
x_fine, y_fine = cs(t_fine).T

# bin extent = bounding box of the smoothed outline
actual_bx_min, actual_bx_max = x_fine.min(), x_fine.max()
actual_by_min, actual_by_max = y_fine.min(), y_fine.max()

# create bins
bx_bins, by_bins = create_position_bins(
    actual_bx_min, actual_bx_max,
    actual_by_min, actual_by_max,
    num_xbins, num_ybins
)

print(f"BX range: {actual_bx_min:.2f} to {actual_bx_max:.2f}")
print(f"BY range: {actual_by_min:.2f} to {actual_by_max:.2f}")

In [None]:
# ----------------------
# Modified plotting function: only the difference in count proportion
# ----------------------
def plot_count_diff_heatmap(
    result,
    gene,
    bx_min, bx_max, by_min, by_max,
    BX, BY,
    output_folder,
    zero_color=(0.5, 0.5, 0.5),
    order_list=(0, 7, 2, 6, 4, 5, 3, 1, 0)
):
    """Plot the per-bin proportion difference tmie − other between samples.

    Cells higher in tmie are red and cells lower in tmie are blue, on a colour
    scale symmetric around zero ('coolwarm'). The landmark outline (a periodic
    cubic spline through BX/BY in ``order_list`` order) and the landmark
    points are drawn on top. The figure is written to
    ``<output_folder>/<gene>_count_diff.png``.

    Parameters
    ----------
    result : dict
        Comparison result; ``real_counts1`` / ``real_counts2`` hold the
        per-bin values of ``sample_types[0]`` / ``sample_types[1]``.
    gene : str
        Label used in the title and the file name.
    bx_min, bx_max, by_min, by_max : float
        Extent over which the matrix is drawn; also the axis limits.
    BX, BY : np.ndarray
        Landmark coordinates.
    output_folder : str
        Destination folder (created if missing).
    zero_color : sequence of 3 floats
        Base grey. The outline uses ``1 - zero_color`` and the landmark
        markers use ``0.6 * zero_color``.
    order_list : sequence of int
        Landmark order for the outline; indices >= len(BX) are ignored.

    Notes
    -----
    Reads the module-level ``sample_types`` to decide which sample is tmie.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Decide which sample is tmie
    if 'tmie' in sample_types[0]:
        tmie_idx, other_idx = 0, 1
    elif 'tmie' in sample_types[1]:
        tmie_idx, other_idx = 1, 0
    else:
        print("⚠️ tmie sample not identified; using default order")
        tmie_idx, other_idx = 1, 0

    tmie_counts = result['real_counts1'] if tmie_idx == 0 else result['real_counts2']
    other_counts = result['real_counts1'] if other_idx == 0 else result['real_counts2']

    # Validate the inputs before doing any arithmetic on them
    if tmie_counts is None or other_counts is None:
        print(f"⚠️ No count-difference data for {gene}, skipping plot")
        return

    # Difference: tmie − other (positive = red, negative = blue)
    count_diff = np.asarray(tmie_counts) - np.asarray(other_counts)
    if count_diff.ndim != 2 or count_diff.size == 0:
        print(f"⚠️ No count-difference data for {gene}, skipping plot")
        return
    rows, cols = count_diff.shape

    # Symmetric colour limits from the finite values only. An all-zero or
    # all-NaN matrix would give vmin == vmax, and matplotlib then maps every
    # cell to the bottom of the colormap (deep blue), not the neutral centre.
    finite_abs = np.abs(count_diff[np.isfinite(count_diff)])
    max_abs = finite_abs.max() if finite_abs.size else 0.0
    if max_abs == 0:
        max_abs = 1.0
    vmin, vmax = -max_abs, max_abs

    base_color = np.asarray(zero_color)

    # ---------------------- Canvas setup ----------------------
    fig = plt.figure(figsize=(8, 6))
    try:
        ax = plt.gca()

        # Hide top & right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # ---------------------- Landmarks & smooth contour ----------------------
        safe_order = [i for i in order_list if i < len(BX)] if len(BX) > 0 else []
        BX_ordered = BX[safe_order] if safe_order else BX
        BY_ordered = BY[safe_order] if safe_order else BY

        if len(BX_ordered) >= 3:
            cs = CubicSpline(np.arange(len(BX_ordered)),
                             np.c_[BX_ordered, BY_ordered],
                             axis=0, bc_type='periodic')
            t_fine = np.linspace(0, len(BX_ordered) - 1, 300)
            x_fine, y_fine = cs(t_fine).T
            ax.plot(x_fine, y_fine, color=1 - base_color, lw=2)
        ax.plot(BX, BY, '+', ms=10, lw=2, color=base_color * 0.6)

        # ---------------------- Main heatmap: count difference ----------------------
        # NOTE(review): cells are spread evenly over [bx_min, bx_max] x
        # [by_min, by_max]; these limits must match the extent used to bin the
        # data, otherwise the heatmap is misaligned with the outline. Confirm
        # at the call site.
        im = ax.pcolormesh(
            np.linspace(bx_min, bx_max, cols),
            np.linspace(by_min, by_max, rows),
            count_diff,
            cmap='coolwarm',  # blue → white → red
            vmin=vmin,
            vmax=vmax
        )
        ax.set_xlim(bx_min, bx_max)
        ax.set_ylim(by_min, by_max)
        ax.set_title(f'{gene} - count difference (tmie − {sample_types[other_idx]})', fontsize=12, pad=10)

        # ---------------------- Colorbar ----------------------
        cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.15])
        cbar = plt.colorbar(im, cax=cbar_ax)
        cbar.set_label('count difference', rotation=270, labelpad=15)
        cbar.outline.set_visible(False)

        cbar_ax.text(0.5, 1.10, 'higher in tmie →',
                     ha='center', va='bottom', fontsize=10, transform=cbar_ax.transAxes)
        cbar_ax.text(0.5, -0.20, '← lower in tmie',
                     ha='center', va='top', fontsize=10, transform=cbar_ax.transAxes)

        # ---------------------- Save difference heatmap ----------------------
        plt.subplots_adjust(left=0.1, right=0.9, bottom=0.15, top=0.85)
        output_path = os.path.join(output_folder, f'{gene}_count_diff.png')
        fig.savefig(output_path, dpi=1200, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails
        plt.close(fig)
    print(f"✅ Count-difference heatmap saved: {output_path}")

In [None]:
# ----------------------
# 2. Single-panel function 2: heatmap of inter-sample p-value differences
# ----------------------
def plot_sample_diff_p_heatmap(
    result,
    gene,
    bx_min, bx_max, by_min, by_max,
    BX, BY,
    output_folder,
    zero_color=(0.5, 0.5, 0.5),
    p_threshold=0.05,
    order_list=(0, 7, 2, 6, 4, 5, 3, 1, 0)
):
    """Plot the corrected p-value map of the inter-sample comparison.

    P-values are clipped to ``[0, p_threshold]`` and drawn with 'Reds_r'
    (darker red = smaller p). The landmark outline (a periodic cubic spline
    through BX/BY in ``order_list`` order) and the landmark points are drawn
    on top. The figure is written to ``<output_folder>/<gene>_diff_p.png``.

    Parameters
    ----------
    result : dict
        Comparison result; ``corrected_p`` is the 2-D corrected p-value grid.
    gene : str
        Label used in the title and the file name.
    bx_min, bx_max, by_min, by_max : float
        Extent over which the matrix is drawn; also the axis limits.
    BX, BY : np.ndarray
        Landmark coordinates.
    output_folder : str
        Destination folder (created if missing).
    zero_color : sequence of 3 floats
        Base grey. The outline uses ``1 - zero_color`` and the landmark
        markers use ``0.6 * zero_color``.
    p_threshold : float
        Upper end of the colour scale.
    order_list : sequence of int
        Landmark order for the outline; indices >= len(BX) are ignored.

    Notes
    -----
    Reads the module-level ``sample_types`` for the title.
    """
    os.makedirs(output_folder, exist_ok=True)
    corrected_p = result['corrected_p']

    # Skip if the p-value matrix is missing, empty or not 2-D
    if corrected_p is None:
        print(f"⚠️ No p-value data for {gene}, skipping p-value heatmap")
        return
    corrected_p = np.asarray(corrected_p)
    if corrected_p.ndim != 2 or corrected_p.size == 0:
        print(f"⚠️ No p-value data for {gene}, skipping p-value heatmap")
        return
    rows, cols = corrected_p.shape

    base_color = np.asarray(zero_color)

    # -------------------------- Single-panel layout (p-value heatmap only) --------------------------
    fig = plt.figure(figsize=(8, 6))
    try:
        ax = plt.gca()  # one single main axis

        # Hide top & right spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # -------------------------- Landmarks & smooth contour (safe handling) --------------------------
        safe_order = [i for i in order_list if i < len(BX)] if len(BX) > 0 else []
        BX_ordered = BX[safe_order] if safe_order else BX
        BY_ordered = BY[safe_order] if safe_order else BY

        if len(BX_ordered) >= 3:
            cs = CubicSpline(np.arange(len(BX_ordered)), np.c_[BX_ordered, BY_ordered],
                             axis=0, bc_type='periodic')
            t_fine = np.linspace(0, len(BX_ordered) - 1, 300)
            x_fine, y_fine = cs(t_fine).T
            ax.plot(x_fine, y_fine, color=1 - base_color, linewidth=2)
        ax.plot(BX, BY, '+', markersize=10, linewidth=2, color=base_color * 0.6)

        # -------------------------- Main panel: corrected p-value heatmap --------------------------
        # Clip p-values above threshold to avoid colour overflow
        p_clipped = np.clip(corrected_p, 0, p_threshold)
        im = ax.pcolormesh(
            np.linspace(bx_min, bx_max, cols),
            np.linspace(by_min, by_max, rows),
            p_clipped,
            cmap='Reds_r',  # darker red = smaller p-value
            vmin=0,
            vmax=p_threshold
        )
        ax.set_xlim(bx_min, bx_max)
        ax.set_ylim(by_min, by_max)
        ax.set_title(f'{gene} - {sample_types[0]} vs {sample_types[1]} FDR Heatmap', fontsize=12, pad=10)

        # -------------------------- Colour bar (separate axes) --------------------------
        cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.15])
        cbar = plt.colorbar(im, cax=cbar_ax)
        tick_values = [0, p_threshold / 2, p_threshold]
        cbar.set_ticks(tick_values)
        cbar.set_ticklabels([f'{x:.3f}' for x in tick_values])
        cbar.outline.set_visible(False)
        cbar_ax.text(
            x=0.5, y=1.10, s='FDR',
            ha='center', va='bottom', fontsize=10, transform=cbar_ax.transAxes
        )

        # -------------------------- Save single p-value figure --------------------------
        plt.subplots_adjust(left=0.1, right=0.9, bottom=0.15, top=0.85)
        output_path = os.path.join(output_folder, f'{gene}_diff_p.png')
        fig.savefig(output_path, dpi=1200, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails
        plt.close(fig)
    print(f"✅ Saved {gene} difference p-value heatmap: {output_path}")

In [None]:
# ----------------------
# Revised Monte-Carlo simulation function
# ----------------------
def monte_carlo_simulation(data1, data2, is_positive_col=None,
                          x_col='standardized_x', y_col='standardized_y',
                          bx_bins=None, by_bins=None, n_simulations=10000):
    """
    Two-sample Monte-Carlo comparison of where positive cells fall on the grid.

    In each sample, rows whose ``is_positive_col`` equals -1 are discarded and
    the remaining labels are cast to bool. The statistic is the per-bin
    difference of positive-cell proportions ``counts1 / sum1 - counts2 / sum2``.
    Counts come from ``count_positive_in_bins``; ``sum`` is the number of
    positives in the sample. The null distribution shuffles the positive
    labels within each sample ``n_simulations`` times. The per-bin p-value is
    ``(#{|sim| >= |real|} + 1) / (n_simulations + 1)``.

    Only the two-sample mode is implemented. If ``data2`` or
    ``is_positive_col`` is None, a warning is printed and None is returned.

    NOTE(review): shuffling within each sample makes the simulated differences
    reflect each sample's own cell layout. They are therefore centred on the
    cell-density difference between samples rather than on zero. A pooled
    permutation of sample labels among positive cells is the usual two-sample
    test; confirm which null is intended.

    Parameters
    ----------
    data1, data2 : pd.DataFrame
        Per-cell tables of the two samples.
    is_positive_col : str
        Name of the positive-marker column (-1 = invalid).
    x_col, y_col : str
        Coordinate columns.
    bx_bins, by_bins : array-like
        Bin definitions forwarded to ``count_positive_in_bins``.
    n_simulations : int
        Number of Monte-Carlo iterations.

    Returns
    -------
    dict | None
        Keys: ``is_positive_col``, ``gene``, ``real_counts1`` /
        ``real_counts2`` (per-bin proportions), ``real_diff``, ``sim_diffs``
        (shape ``(n_simulations, *real_diff.shape)``), ``p_values`` and
        ``corrected_p`` (None; filled in by the caller). Returns None when the
        comparison cannot be made: no data2, no positives in a sample, or an
        all-zero observed difference.
    """
    if data2 is None or is_positive_col is None:
        print("⚠️ Single-sample mode (data2=None) is not implemented; returning None")
        return None

    # ---------- Two-sample comparison ----------
    # Filter valid entries and cast labels to bool
    filtered1 = data1[data1[is_positive_col] != -1].copy()
    filtered1[is_positive_col] = filtered1[is_positive_col].astype(bool)
    sum1 = np.sum(filtered1[is_positive_col])

    filtered2 = data2[data2[is_positive_col] != -1].copy()
    filtered2[is_positive_col] = filtered2[is_positive_col].astype(bool)
    sum2 = np.sum(filtered2[is_positive_col])

    # Proportions are undefined without positives; dividing by zero would
    # silently fill every downstream array with NaN/inf.
    if sum1 == 0 or sum2 == 0:
        print(f"⚠️ {is_positive_col}: no positive cells in at least one sample, skipping simulation")
        return None

    # Real observed proportions and their difference
    counts1 = count_positive_in_bins(filtered1, is_positive_col, x_col, y_col, bx_bins, by_bins)
    counts2 = count_positive_in_bins(filtered2, is_positive_col, x_col, y_col, bx_bins, by_bins)
    real_counts1 = counts1 / sum1
    real_counts2 = counts2 / sum2
    real_diff = real_counts1 - real_counts2
    if np.sum(np.abs(real_diff)) == 0:
        print(f"⚠️ Real difference for {is_positive_col} is zero between samples, skipping simulation")
        return None

    sim_diffs = np.zeros((n_simulations, *real_diff.shape))

    # Reuse one working copy per sample and only overwrite the label column
    # each iteration, instead of copying both DataFrames every iteration.
    # Assumes count_positive_in_bins does not modify its input (TODO confirm
    # in modules.judge_sig). The shuffle calls and their order are unchanged,
    # so the random stream is identical.
    labels1 = filtered1[is_positive_col].values
    labels2 = filtered2[is_positive_col].values
    temp1 = filtered1.copy()
    temp2 = filtered2.copy()

    for i in tqdm(range(n_simulations), desc=f"MC sim {is_positive_col} sample comparison"):
        # Shuffle positive labels within each sample
        shuffled1 = labels1.copy()
        shuffled2 = labels2.copy()
        np.random.shuffle(shuffled1)
        np.random.shuffle(shuffled2)
        temp1[is_positive_col] = shuffled1
        temp2[is_positive_col] = shuffled2

        sim_counts1 = count_positive_in_bins(temp1, is_positive_col, x_col, y_col, bx_bins, by_bins)
        sim_counts2 = count_positive_in_bins(temp2, is_positive_col, x_col, y_col, bx_bins, by_bins)
        sim_diffs[i] = sim_counts1 / sum1 - sim_counts2 / sum2

    # P-value: proportion of |simulated| >= |real| (with +1 correction)
    p_values = (np.sum(np.abs(sim_diffs) >= np.abs(real_diff), axis=0) + 1) / (n_simulations + 1)

    return {
        'is_positive_col': is_positive_col,
        'gene': is_positive_col.replace('_isPositive', ''),
        'real_counts1': real_counts1,
        'real_counts2': real_counts2,
        'real_diff': real_diff,
        'sim_diffs': sim_diffs,
        'p_values': p_values,
        'corrected_p': None
    }


In [None]:
# ----------------------
# Execute analysis
# ----------------------
# ----------------------
# Save results & plotting
# ----------------------
# NOTE(review): min_count / max_count are not used in this cell; kept in case
# later cells read them.
min_count = 0
max_count = 0.02

# Create output folder (suffixed with cell_num when sub-sampling is enabled)
if cell_num > 0:
    comparison_results_folder = os.path.join(
        summary_folder,
        f"grid_comparison_{sample_types[0]}_vs_{sample_types[1]}_{cell_num}"
    )
else:
    comparison_results_folder = os.path.join(
        summary_folder,
        f"grid_comparison_{sample_types[0]}_vs_{sample_types[1]}"
    )
os.makedirs(comparison_results_folder, exist_ok=True)

print("\n===== Starting Monte-Carlo comparison =====")
comparison_results = []

# Compare every common marker between the two samples
for marker in common_markers:
    gene = marker.replace('_isPositive', '')
    x_col = 'standardized_x'
    y_col = 'standardized_y'

    # Ensure each sample has both coordinate columns and the marker column
    valid = True
    for st, sample_df in zip(sample_types, (data1, data2)):
        if x_col not in sample_df.columns or y_col not in sample_df.columns:
            print(f"⚠️ Warning: {st} missing standardised coordinates, skipping {marker}")
            valid = False
            break
        if marker not in sample_df.columns:
            print(f"⚠️ Warning: {st} missing marker {marker}, skipping")
            valid = False
            break

    if not valid:
        continue

    print(f"\n===== Comparing {marker} between {sample_types[0]} and {sample_types[1]} =====")

    # Run two-sample Monte-Carlo simulation
    result = monte_carlo_simulation(
        data1, data2, marker,
        x_col, y_col, bx_bins, by_bins, n_simulations
    )

    if result is not None:
        total_diff = np.sum(np.abs(result['real_diff']))
        print(f"Total real difference for {marker}: {total_diff}")

        if total_diff > 0:
            # Correct p-values
            result['corrected_p'] = correct_grid_pvalues(result['p_values'], method=correction_method)
            comparison_results.append(result)
            print(f"✅ Comparison for {marker} completed")
        else:
            print(f"❌ Total real difference for {marker} is zero, not added to results")
    else:
        print(f"❌ Comparison for {marker} failed, no result generated")

# ----------------------
# Revised plotting logic: only difference and p-value heatmaps
# ----------------------
if comparison_results:
    print("\n===== Saving inter-sample comparison results and plotting difference heatmaps =====")
    for result in comparison_results:
        gene = result['gene']

        # 1. Save raw data. real_counts1/2 and real_diff are per-bin
        # proportions (fractions), so they are written as floats; an integer
        # format would truncate every value to 0.
        np.savetxt(
            os.path.join(comparison_results_folder, f'{gene}_{sample_types[0]}_counts.csv'),
            result['real_counts1'], delimiter=',', fmt='%.6f'
        )
        np.savetxt(
            os.path.join(comparison_results_folder, f'{gene}_{sample_types[1]}_counts.csv'),
            result['real_counts2'], delimiter=',', fmt='%.6f'
        )
        np.savetxt(
            os.path.join(comparison_results_folder, f'{gene}_diff_counts.csv'),
            result['real_diff'], delimiter=',', fmt='%.6f'
        )
        np.savetxt(
            os.path.join(comparison_results_folder, f'{gene}_corrected_p.csv'),
            result['corrected_p'], delimiter=',', fmt='%.6f'
        )

        # 2. Data validation
        required_cols = [f'{gene}_isPositive', 'standardized_x', 'standardized_y']
        s1_valid = all(col in data1.columns for col in required_cols)
        s2_valid = all(col in data2.columns for col in required_cols)

        if not s1_valid or not s2_valid:
            missing = []
            if not s1_valid:
                missing.append(
                    f"{sample_types[0]}: "
                    f"{', '.join([c for c in required_cols if c not in data1.columns])}"
                )
            if not s2_valid:
                missing.append(
                    f"{sample_types[1]}: "
                    f"{', '.join([c for c in required_cols if c not in data2.columns])}"
                )
            print(f"⚠️ {gene} missing columns: {'; '.join(missing)}, skipping plots")
            continue

        # 3. Plot only difference and p-value heatmaps.
        # NOTE(review): the grid was binned over the landmark-outline extent
        # (actual_bx_min..actual_bx_max etc.), but the fixed bx_min..by_max
        # limits are passed here as the drawing extent. Confirm these should
        # not be the actual bin limits.
        # Difference heatmap (red = higher in tmie, blue = lower)
        plot_count_diff_heatmap(
            result=result,
            gene=gene,
            bx_min=bx_min,
            bx_max=bx_max,
            by_min=by_min,
            by_max=by_max,
            BX=BX,
            BY=BY,
            output_folder=comparison_results_folder
        )

        # p-value heatmap of inter-sample difference
        plot_sample_diff_p_heatmap(
            result=result,
            gene=gene,
            bx_min=bx_min,
            bx_max=bx_max,
            by_min=by_min,
            by_max=by_max,
            BX=BX,
            BY=BY,
            output_folder=comparison_results_folder
        )
else:
    print("\n⚠️ No valid comparison results generated")

print("\n===== All plotting finished =====")

In [None]:
def xcol_two_sample_monte(data1, data2, is_positive_col, x_col, bx_bins, n_simulations):
    """
    Assess between-sample differences in the x-axis (column) distribution of
    positive cells with a Monte-Carlo simulation.

    Rows with ``is_positive_col == -1`` are dropped and labels cast to bool.
    For each sample, positives are counted per x-bin with
    ``count_positive_in_x_bins`` and divided by the total binned count. The
    statistic is ``real_diff = proportions1 - proportions2``. Each simulation
    shuffles only the sample-1 labels; sample 2 stays as observed. The
    per-bin p-value is ``(#{|sim| >= |real|} + 1) / (n_simulations + 1)``.

    NOTE(review): unlike ``monte_carlo_simulation``, which shuffles both
    samples, only sample 1 is permuted here. Confirm this asymmetry is
    intended.

    Returns
    -------
    dict | None
        Keys: ``is_positive_col``, ``gene``, ``real_counts1`` /
        ``real_counts2`` (x-bin proportions), ``real_diff``, ``sim_diffs``,
        ``p_values`` and ``corrected_p`` (None; filled in by the caller).
        Returns None if counting fails, a sample has no binned positives, or
        the observed difference is all zero.
    """
    # Remove invalid data and convert to Boolean
    filtered1 = data1[data1[is_positive_col] != -1].copy()
    filtered1[is_positive_col] = filtered1[is_positive_col].astype(bool)
    filtered2 = data2[data2[is_positive_col] != -1].copy()
    filtered2[is_positive_col] = filtered2[is_positive_col].astype(bool)

    # STEP 1: compute real X-bin counts for each sample
    result1 = count_positive_in_x_bins(filtered1, is_positive_col, x_col, bx_bins)
    result2 = count_positive_in_x_bins(filtered2, is_positive_col, x_col, bx_bins)

    # Check for a failed count before using the results
    if result1 is None or result2 is None:
        print(f"⚠️ {is_positive_col}: single-sample x-bin counting failed, two-sample comparison skipped")
        return None

    sum1 = np.sum(result1)
    sum2 = np.sum(result2)
    # Proportions are undefined when a sample has no binned positives
    if sum1 == 0 or sum2 == 0:
        print(f"⚠️ {is_positive_col}: no positive cells binned in at least one sample, simulation skipped")
        return None

    # Normalised column counts (shape: num_xbins,)
    real_counts1 = result1 / sum1
    real_counts2 = result2 / sum2

    # Observed difference (sample1 - sample2)
    real_diff = real_counts1 - real_counts2
    total_real_diff = np.sum(np.abs(real_diff))
    if total_real_diff == 0:
        print(f"⚠️ {is_positive_col}: total real difference between samples is zero, simulation skipped")
        return None

    # STEP 2: Monte-Carlo simulation (shuffle sample-1 labels, re-compute differences)
    orig1 = filtered1[is_positive_col].values.copy()   # keep original labels
    sim_diffs = np.zeros((n_simulations, len(real_diff)))

    # Sample 2 is never shuffled, so its simulated proportions equal
    # real_counts2 on every iteration. This assumes count_positive_in_x_bins
    # is deterministic and does not modify its input (TODO confirm), so the
    # term is computed once. One working copy of sample 1 is reused and only
    # its label column is overwritten.
    temp1 = filtered1.copy()

    for i in tqdm(range(n_simulations), desc=f"Simulating column diff {is_positive_col}"):
        shuffled1 = orig1.copy()
        np.random.shuffle(shuffled1)
        temp1[is_positive_col] = shuffled1

        sim_result1 = count_positive_in_x_bins(temp1, is_positive_col, x_col, bx_bins)
        if sim_result1 is not None:
            sim_diffs[i] = sim_result1 / sum1 - real_counts2

    # STEP 3: compute p-values (proportion of |simulated| >= |real|)
    p_matrix = np.abs(sim_diffs) >= np.abs(real_diff)
    p_values = (np.sum(p_matrix, axis=0) + 1) / (n_simulations + 1)

    # Return column-difference results
    return {
        'is_positive_col': is_positive_col,
        'gene': is_positive_col.replace('_isPositive', ''),
        'real_counts1': real_counts1,   # sample-1 X-bin proportions
        'real_counts2': real_counts2,   # sample-2 X-bin proportions
        'real_diff': real_diff,         # column difference (sample1 - sample2)
        'sim_diffs': sim_diffs,         # array of simulated differences
        'p_values': p_values,           # uncorrected p-values
        'corrected_p': None             # slot for corrected p-values
    }

In [None]:
# ----------------------
# 5. Batch-run column-wise difference analysis for all markers
# ----------------------
print("\n===== Starting two-sample column-difference significance analysis along X-axis =====")
xcol_diff_results = []  # store results for all markers

# real_counts1 belongs to data1 / sample_types[0] and real_counts2 to
# data2 / sample_types[1]. Map them to genotype names from sample_types
# instead of assuming the order (same rule as plot_count_diff_heatmap: if
# neither name contains 'tmie', sample_types[1] is treated as tmie).
if 'tmie' in sample_types[0]:
    tmie_key, wt_key = 'real_counts1', 'real_counts2'
else:
    tmie_key, wt_key = 'real_counts2', 'real_counts1'

for col in common_markers:
    gene = col.replace('_isPositive', '')
    x_col = 'standardized_x'  # same as in the main pipeline

    # check coordinate column existence
    if x_col not in data1.columns or x_col not in data2.columns:
        print(f"⚠️ Warning: {sample_types[0]} or {sample_types[1]} missing X-axis coordinate column {x_col}, skipping {col}")
        continue

    print(f"\n===== Analysing marker: {col} (gene: {gene}) =====")
    print(f"X-axis coordinate column: {x_col}")
    print(f"Number of X-bins: {num_xbins}")

    # perform two-sample column-difference analysis
    result = xcol_two_sample_monte(
        data1=data1,
        data2=data2,
        is_positive_col=col,
        x_col=x_col,
        bx_bins=bx_bins,
        n_simulations=n_simulations
    )

    if result is None:
        print(f"❌ Column-difference analysis for {col} failed, no result generated")
        continue

    total_real = np.sum(np.abs(result['real_diff']))
    print(f"Column-difference sum for {col}: {total_real}")

    if total_real > 0:
        result['corrected_p'] = correct_grid_pvalues(result['p_values'], method=correction_method)

        # genotype-keyed proportions for saving and plotting
        result['tmie_proportion'] = result[tmie_key]
        result['wt_proportion'] = result[wt_key]
        result['proportion_diff'] = result['tmie_proportion'] - result['wt_proportion']

        xcol_diff_results.append(result)
        print(f"✅ Column-difference analysis for {col} completed (proportions & differences calculated)")
    else:
        print(f"❌ Column-difference sum for {col} is zero, not added to results")

print(f"\n===== X-axis column-difference analysis finished: {len(xcol_diff_results)} marker(s) succeeded =====")

In [None]:
if xcol_diff_results:
    print("\n===== Saving X-axis column-difference results (including proportions & differences) =====")
    # Define paths (separate result files and visualizations)
    results_folder = comparison_results_folder
    visualization_folder = os.path.join(comparison_results_folder, "grid_visualizations")
    os.makedirs(visualization_folder, exist_ok=True)

    # (file-name suffix, result key). real_counts1/2 follow sample order
    # (data1 = sample_types[0]), so the per-genotype files read the
    # genotype-keyed entries. This keeps the wt/tmie names matched to the
    # right data.
    xcol_outputs = (
        ('wt_xcol_counts', 'wt_proportion'),
        ('tmie_xcol_counts', 'tmie_proportion'),
        ('wt_xcol_proportions', 'wt_proportion'),
        ('tmie_xcol_proportions', 'tmie_proportion'),
        ('tmie_minus_wt_proportion_diff', 'proportion_diff'),
        ('xcol_real_diff', 'real_diff'),
        ('xcol_corrected_p', 'corrected_p'),
    )

    for result in xcol_diff_results:
        gene = result['gene']
        for file_suffix, result_key in xcol_outputs:
            np.savetxt(
                os.path.join(results_folder, f'{gene}_{file_suffix}.csv'),
                result[result_key],
                delimiter=',',
                fmt='%.6f'
            )

        print(f"✅ Results for {gene} saved:")
        print(f"   - Proportion-difference file: {gene}_tmie_minus_wt_proportion_diff.csv")
        print(f"   - Other files: counts / proportions / p-values (original format)")

    # 1. Extract X-axis bin information (edges spanning the landmark outline)
    x_bin_edges = np.linspace(x_fine.min(), x_fine.max(), num_xbins + 1)
    x_bin_labels = [f"Bin{i+1}\n[{x_bin_edges[i]:.2f}-{x_bin_edges[i+1]:.2f}]"
                    for i in range(len(x_bin_edges) - 1)]
    x_bin_num = len(x_bin_labels)  # total number of x-bins

    # 2. Merge difference and p-value data for all genes
    marker_list = []               # valid gene list
    proportion_diff_all = []       # proportion difference (tmie − wt): rows=genes, cols=bins
    corrected_p_all = []           # corrected p-values

    # ---------- Bin indices of the outline's left/right edge and of x = 0 ----------
    min_bin_idx = np.digitize(x_fine.min(), bx_bins) - 1
    max_bin_idx = np.digitize(x_fine.max(), bx_bins) - 1
    zero_bin_idx = np.digitize(0, bx_bins) - 1
    # x_bin_num consecutive bins starting at the outline's left edge
    x_bins_indices = np.arange(min_bin_idx, min_bin_idx + x_bin_num)
    x_offsets = x_bins_indices - zero_bin_idx  # offset relative to the bin containing 0

    # ---------- Ticks every `step` bins, aligned to 0, only where an offset exists ----------
    step = 5
    low = int(np.floor(x_offsets.min() / step) * step)
    high = int(np.ceil(x_offsets.max() / step) * step)
    tick_labels = np.arange(low, high + 1, step)  # multiples of step
    tick_labels = tick_labels[np.isin(tick_labels, x_offsets)]  # ensure existence

    # ---------- Corresponding plot positions ----------
    tick_positions = [np.where(x_offsets == lab)[0][0] for lab in tick_labels]

    # Keep exactly x_bin_num bins per marker so every row matches the columns
    bin_window = slice(min_bin_idx, min_bin_idx + x_bin_num)
    for result in xcol_diff_results:
        gene = result['gene']
        diff = result['proportion_diff'][bin_window]
        corr_p = result['corrected_p'][bin_window]

        marker_list.append(gene)
        proportion_diff_all.append(diff)
        corrected_p_all.append(corr_p)

    bin_columns = [f"Bin{i+1}" for i in range(x_bin_num)]
    diff_df = pd.DataFrame(proportion_diff_all, index=marker_list, columns=bin_columns)
    corr_p_df = pd.DataFrame(corrected_p_all, index=marker_list, columns=bin_columns)

    diff_df.to_csv(os.path.join(results_folder, 'all_tmie_minus_wt_proportion_diff.csv'), index=True)
    corr_p_df.to_csv(os.path.join(results_folder, 'all_tmie_minus_wt_p_adj.csv'), index=True)

In [None]:
# Reload the merged x-bin tables written by the previous cell. That cell only
# defines results_folder (and writes the files) when xcol_diff_results is
# non-empty, so guard the read to avoid a NameError/FileNotFoundError. The
# plotting below also only runs under the same condition.
if xcol_diff_results:
    diff_df = pd.read_csv(os.path.join(results_folder, 'all_tmie_minus_wt_proportion_diff.csv'),
                          index_col=0, header=0)
    corr_p_df = pd.read_csv(os.path.join(results_folder, 'all_tmie_minus_wt_p_adj.csv'),
                            index_col=0, header=0)

# ----------------------
# 6. Plot integrated x-bin heatmaps (proportion difference + corrected p-values)
# ----------------------
target_marker_order = ['a', 'b', 'c', 'd', 'e']  # top-to-bottom order

# Plot the x-bin comparison between genotypes as two stacked heatmaps:
#   top    - proportion difference (tmie - wt), symmetric diverging colour scale
#   bottom - FDR-corrected p-values, colour scale spanning [0, p_threshold]
# Figures are written as PNG/PDF/EPS into visualization_folder.
if xcol_diff_results:
    print("\n===== Plotting integrated heatmap: proportion difference & corrected p-values =====")
    # Plotting parameters (adjustable)
    diff_abs_max = 0.1   # symmetric colour-scale limit for the proportion difference (e.g. ±0.1)
    p_threshold = 0.05   # significance threshold; p-values are capped here for colouring

    # X-axis bin edges/labels over the range of x_fine. They are not used by the
    # plotting below (which uses tick_positions / tick_labels); kept at module
    # level in case later cells reference them.
    x_bin_edges = np.linspace(x_fine.min(), x_fine.max(), num_xbins + 1)
    x_bin_labels = [f"Bin{i+1}\n[{x_bin_edges[i]:.2f}-{x_bin_edges[i+1]:.2f}]"
                    for i in range(len(x_bin_edges) - 1)]
    x_bin_num = len(x_bin_labels)  # total number of x-bins

    # Keep only the target markers, ordered top-to-bottom as in target_marker_order.
    # The isin() filter runs first so duplicated non-target labels cannot break reindex().
    missing_markers = [m for m in target_marker_order
                       if m not in diff_df.index or m not in corr_p_df.index]
    if missing_markers:
        # reindex() turns absent markers into all-NaN (empty) rows; say so instead of hiding it
        print(f"⚠️ Markers missing from results, shown as empty rows: {missing_markers}")
    diff_df = diff_df.loc[diff_df.index.isin(target_marker_order)].reindex(target_marker_order)
    corr_p_df = corr_p_df.loc[corr_p_df.index.isin(target_marker_order)].reindex(target_marker_order)

    # Skip when the target markers carry no data at all (no columns, or every value NaN).
    # The previous check on the constant marker_list could never trigger.
    if diff_df.isna().values.all():
        print("⚠️ No valid gene data, heatmap plotting skipped")
    else:
        # Clip so extreme values saturate the colour map instead of overflowing it
        diff_clipped = diff_df.clip(lower=-diff_abs_max, upper=diff_abs_max)
        corr_p_clipped = corr_p_df.clip(upper=p_threshold)  # non-significant p all map to the lightest colour

        # Twin subplots: top = proportion difference, bottom = corrected p-values
        fig, (ax1, ax2) = plt.subplots(
            2, 1,
            figsize=(12, 1 * len(target_marker_order)),  # height scales with number of genes
            gridspec_kw={'height_ratios': [1, 0.8]}      # difference map slightly taller
        )
        fig.suptitle("X-axis Bins Analysis: wt vs tmie", fontsize=14, y=0.98)

        # ---------------------- Subplot 1: proportion-difference heatmap ----------------------
        # RdBu_r: red = tmie higher, blue = tmie lower, white = no difference
        im1 = ax1.imshow(
            diff_clipped.values,
            cmap='RdBu_r',
            aspect='auto',
            vmin=-diff_abs_max,   # symmetric around 0 so white means "no difference"
            vmax=diff_abs_max
        )
        ax1.set_xticks(tick_positions)
        ax1.set_xticklabels(tick_labels, fontsize=15)
        ax1.set_xlabel("offset from border", fontsize=10)

        ax1.set_yticks(range(len(target_marker_order)))
        ax1.set_yticklabels([word.capitalize() for word in diff_clipped.index], fontsize=15)
        ax1.set_ylabel("Genes", fontsize=12, labelpad=10)
        ax1.set_title(
            "Proportion Difference (tmie - wt)\n(Blue: tmie < wt | Red: tmie > wt)",
            fontsize=15, pad=8
        )

        cbar1 = fig.colorbar(im1, ax=ax1, shrink=0.8, aspect=20)
        cbar1.outline.set_visible(False)
        cbar1.set_label("tmie - wt Proportion", fontsize=15, labelpad=8)
        cbar1.set_ticks([-diff_abs_max, 0, diff_abs_max])
        cbar1.set_ticklabels([f"-{diff_abs_max:.3f}", "0", f"+{diff_abs_max:.3f}"])

        # ---------------------- Subplot 2: corrected p-value heatmap ----------------------
        # Reds_r: dark red = small p (significant), light = p at/above the threshold
        im2 = ax2.imshow(
            corr_p_clipped.values,
            cmap='Reds_r',
            aspect='auto',
            vmin=0,
            vmax=p_threshold
        )
        # Same x-axis and gene order (with labels) as the top panel
        ax2.set_xticks(tick_positions)
        ax2.set_xticklabels(tick_labels, fontsize=15)
        ax2.set_xlabel("offset from border", fontsize=10)

        ax2.set_yticks(range(len(target_marker_order)))
        ax2.set_yticklabels([word.capitalize() for word in corr_p_clipped.index], fontsize=15)
        ax2.set_title(f"q-value (FDR ≤ {p_threshold})", fontsize=15, pad=8)

        cbar2 = fig.colorbar(im2, ax=ax2, shrink=0.8, aspect=20)
        cbar2.outline.set_visible(False)
        cbar2.set_label("FDR", fontsize=10, labelpad=8)
        p_ticks = [0, p_threshold / 2, p_threshold]
        cbar2.set_ticks(p_ticks)
        cbar2.set_ticklabels([f"{x:.3f}" for x in p_ticks])

        # Leave room for the suptitle and avoid label clipping
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save in raster and vector formats
        heatmap_path = os.path.join(visualization_folder, "wt_vs_tmie_proportion_diff_heatmap")
        for ext in ("png", "pdf", "eps"):
            fig.savefig(f"{heatmap_path}.{ext}", format=ext, dpi=1200, bbox_inches='tight')
        plt.close(fig)

        print(f"✅ Integrated heatmap saved: {heatmap_path}")
        print("   - Top panel: proportion difference (red = tmie higher, blue = tmie lower)")
        print("   - Bottom panel: significance p-values (red = significant, white = non-significant)")

else:
    print("\n⚠️ No valid X-axis column-difference results, skipping saving and plotting")

print(f"\n===== All X-axis column-difference analyses finished! Results path: {comparison_results_folder} =====")