# Bone Age Estimation based on the Spatial Information

##  0. Libraries and Parameters

In [5]:
# Standard library
import itertools
import os
import pickle
from datetime import datetime

# Scientific computing
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d, Rbf
from scipy.ndimage import gaussian_filter1d
from scipy.optimize import curve_fit, minimize, root_scalar
from scipy.special import rel_entr
from scipy.spatial.distance import jensenshannon
from scipy.stats import gaussian_kde, wasserstein_distance

# Plotting
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D  
import seaborn as sns

# Image processing
from PIL import Image
from skimage import measure, morphology
from skimage.metrics import structural_similarity as compare_ssim

# Spatial
import geopandas as gpd
from matplotlib.path import Path
from shapely.geometry import Polygon, box

# File reader
from imaris_ims_file_reader.ims import ims

# Machine learning
import joblib
import shap
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.inspection import permutation_importance
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    explained_variance_score,
)
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.preprocessing import StandardScaler

In [None]:
# Output folders: a results directory plus a "models" subfolder inside it
# (the names are fixed; no timestamp is appended).
results_dir = "results_bone_age"
model_save_dir = os.path.join(results_dir, "models")
os.makedirs(results_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)

# Disable Pillow's maximum-image-size check
Image.MAX_IMAGE_PIXELS = None

# Global matplotlib defaults: no grid, and uniform font sizes
plt.rcParams.update({
    "axes.grid": False,
    "font.size": 12,          # General font size
    "legend.fontsize": 12,    # Legend font size
    "axes.titlesize": 14,     # Title font size
    "axes.labelsize": 12,     # X and Y axis labels
    "xtick.labelsize": 12,    # X-axis tick labels
    "ytick.labelsize": 12,    # Y-axis tick labels
})



In [3]:
# Colors
hsc_rd_colors = {"HSCs": "#FF0000", "RDs": "#808080"}  # Red / Grey

# HSC-only palette (same red as in hsc_rd_colors)
hsc_color = {"HSCs": hsc_rd_colors["HSCs"]}

# One color per cluster id (0-9)
cluster_color_map = {
    0: "#B15928",  # Brown
    1: "#6A3D9A",  # Purple
    2: "#FF7F00",  # Orange
    3: "#FDBF6F",  # Light Orange
    4: "#A6CEE3",  # Light Blue
    5: "#FB9A99",  # Pink
    6: "#CAB2D6",  # Lavender
    7: "#1F78B4",  # Blue
    8: "#B2DF8A",  # Light Green
    9: "#33A02C",  # Green
}

# Same palette, keyed by ("cKits", cluster id)
ckits_color_map = {
    ("cKits", cluster_id): hex_code
    for cluster_id, hex_code in cluster_color_map.items()
}

# Others
# Maps internal feature names to the labels used for display
features_name_dict = {
    "cKit Density Divergence (vs. 3mo)": "HeM distribution(vs. 3mo)",
    "cKit Density Divergence (vs. 12mo)": "HeM distribution(vs. 12mo)",
    "cKit Density Divergence (vs. 20mo)": "HeM distribution(vs. 20mo)",
    "cKit Neighborhood Affinity (vs. 3mo)": "HeM neighborhood(vs. 3mo)",
    "cKit Neighborhood Affinity (vs. 12mo)": "HeM neighborhood(vs. 12mo)",
    "cKit Neighborhood Affinity (vs. 20mo)": "HeM neighborhood(vs. 20mo)",
    "HSC Density Divergence (vs. 3mo)": "HSC distribution(vs. 3mo)",
    "HSC Density Divergence (vs. 12mo)": "HSC distribution(vs. 12mo)",
    "HSC Density Divergence (vs. 20mo)": "HSC distribution(vs. 20mo)",
    "HSC Count": "HSC numbers",
    "HSC Spatial Similarity (vs. 3mo)": "HSC-HeM association(vs. 3mo)",
    "HSC Spatial Similarity (vs. 12mo)": "HSC-HeM association(vs. 12mo)",
    "HSC Spatial Similarity (vs. 20mo)": "HSC-HeM association(vs. 20mo)",
    "HSC Composition (vs. 3mo)": "HSC composition(vs. 3mo)",
    "HSC Composition (vs. 12mo)": "HSC composition(vs. 12mo)",
    "HSC Composition (vs. 20mo)": "HSC composition(vs. 20mo)",
}

## 1. Data Loading
- CSV files: cell positions used for the probability density (PDF) estimation and the pseudo-age calculation
- IMS files: images used for metabone creation and transformation

### 1.1 Load position (csv) files

In [None]:
data_dir = "data_bone_age"
# Every entry of the data folder; the position tables (cKits, HSCs, and RDs) are the CSVs among them
position_files = os.listdir(data_dir)
position_dirs = [
    os.path.join(data_dir, entry)
    for entry in position_files
    if entry.endswith("csv")
]


In [None]:
# Columns kept from every position table
columns_to_use = ["Position.X", "Position.Y", "Position.Z", "age", "clusters","bone"] # based on the columns of data and also the interest

# One slot per cell population; a slot stays None when no matching CSV is found
positions_dict = {
    "HSCs": None,
    "RDs": None,
    "cKits": None
}

for position_dir in position_dirs:

    # Read the CSV and keep only the columns of interest.
    # .copy() makes the slice an independent frame, so adding "source" below
    # does not write into a view (avoids pandas' SettingWithCopyWarning).
    positions = pd.read_csv(position_dir)
    positions = positions[columns_to_use].copy()

    # Classify by the FILE name only: matching against the whole path would
    # misfire whenever a parent directory happens to contain "hsc", "rd" or "ckit".
    file_name = os.path.basename(position_dir)

    # Tags are tested in this order; a file matching none of them is ignored
    if "hsc" in file_name:
        positions["source"] = "HSCs"
        positions_dict["HSCs"] = positions
    elif "rd" in file_name:
        positions["source"] = "RDs"
        positions_dict["RDs"] = positions
    elif "ckit" in file_name:
        positions["source"] = "cKits"
        positions_dict["cKits"] = positions

# Stack the per-population tables into one frame (original row indices are kept)
positions = pd.concat(positions_dict.values())

In [None]:
# (Optional) Drop the data with age 5fu45d.
# .copy() detaches the filtered frame so the column assignment below does not
# write into a slice of the unfiltered table (SettingWithCopyWarning).
positions = positions[positions.age != "5fu45d"].copy()

# Scale the positions to pixels by dividing each axis by its scaling factor
# (presumably physical units per pixel/voxel along that axis; confirm against the acquisition settings)
scaling_factor = {"Position.X":0.7575, "Position.Y":0.7575, "Position.Z":2.5}
positions[["Position.X", "Position.Y", "Position.Z"]] = positions[["Position.X", "Position.Y", "Position.Z"]].div(scaling_factor)

### 1.2 Load ims files

In [None]:
# Function to load the ims files
def get_ims_files_by_conditions(data_dir, keys=("3mo", "12mo", "20mo", "5fu30d", "5fu60d"), source="Bone"):
    """
    Collect .ims files from the sub-directories of data_dir, grouped by condition.

    Each immediate sub-directory is assigned to the first key that occurs in its
    name; sub-directories matching no key are grouped under "other". Only files
    ending in ".ims" whose name contains `source` are collected.

    Parameters:
    data_dir : path of the directory whose sub-directories are scanned.
    keys : iterable of condition names, tested in order against each sub-directory name.
    source : substring a file name must contain to be collected.

    Returns:
    ims_files : dict mapping every key (plus "other" if any sub-directory matched
                no key) to a list of file paths. Directories and files are visited
                in sorted name order so the result is reproducible.
    """
    # Every requested key is present in the result, even when nothing matches it
    ims_files = {key: [] for key in keys}

    # os.listdir returns entries in arbitrary order; sort for deterministic output
    for entry in sorted(os.listdir(data_dir)):
        data_subdir = os.path.join(data_dir, entry)
        if not os.path.isdir(data_subdir):
            continue  # files directly inside data_dir are ignored

        sub_files = [
            os.path.join(data_subdir, f)
            for f in sorted(os.listdir(data_subdir))
            if f.endswith(".ims") and source in f
        ]

        # First matching key wins; unmatched directories fall back to "other"
        matched_key = next((key for key in keys if key in entry), "other")
        ims_files.setdefault(matched_key, []).extend(sub_files)

    return ims_files

In [None]:
# Locate the "Bone" .ims files of every condition, then expose one list per condition
metabone_ims_files = get_ims_files_by_conditions(
    data_dir=data_dir,
    keys=["3mo", "12mo", "20mo", "5fu30d", "5fu60d"],
    source="Bone",
)
(
    ims_files_3mo,
    ims_files_12mo,
    ims_files_20mo,
    ims_files_5fu30d,
    ims_files_5fu60d,
) = (
    metabone_ims_files[condition]
    for condition in ("3mo", "12mo", "20mo", "5fu30d", "5fu60d")
)

In [None]:
# Convert the bone images with non-zero values to dataframes with Position.X and Position.Y columns
def image_to_df(image):
    """
    List the non-zero pixels of an image as a table of coordinates.

    Parameters:
    image : array with at least two dimensions; axis 0 is used as Y and axis 1 as X.

    Returns:
    df : DataFrame with one row per non-zero pixel and the columns
         "source" (always "DAPI"), "Position.Z" (always 0), "Position.Y", "Position.X".
    """
    # np.nonzero returns one index array per axis
    nonzero_axes = np.nonzero(image)
    row_idx = nonzero_axes[0]
    col_idx = nonzero_axes[1]

    # Scalars are broadcast to every row; intensities are deliberately not stored
    table = {
        "source": "DAPI",  # label can be renamed depending on the channel used
        "Position.Z": 0,
        "Position.Y": row_idx,
        "Position.X": col_idx,
    }
    return pd.DataFrame(table)

# Function to process and store images for a given condition
def process_bone_images(bone_ims_files):
    """
    Load each .ims file and convert its cleaned channel-1 mask into a coordinate table.

    Parameters:
    bone_ims_files : iterable of paths to .ims files.

    Returns:
    bone_dict : dict mapping a bone name to the DataFrame produced by image_to_df.
                The bone name is the first two space-separated tokens of the file
                name joined by "_"; files yielding the same name overwrite each other.
    """
    bone_dict = {}

    for ims_path in bone_ims_files:
        # os.path.basename instead of split("/") so Windows-style paths work too
        file_name = os.path.basename(ims_path)
        bone_name = "_".join(file_name.split(" ")[:2])

        # Channel index 1 is taken to be DAPI (per the notebook author); the stack
        # is collapsed with a maximum projection over the first remaining axis.
        # NOTE(review): assumes the reader indexes as (time, channel, z, y, x); confirm.
        volume = ims(ims_path)
        dapi_projection = volume[0, 1, :, :, :].max(axis=0).copy()

        # Foreground = any non-zero intensity; label connected components and
        # drop those smaller than 20000 pixels
        labeled_img = measure.label(dapi_projection > 0)
        cleaned_img = morphology.remove_small_objects(labeled_img, min_size=20000)

        # cleaned_img is still a label image; image_to_df keeps every non-zero pixel
        bone_dict[bone_name] = image_to_df(cleaned_img)

    return bone_dict

In [None]:
# Build the per-bone pixel tables for every condition (processed in this order)
bone_3mo, bone_12mo, bone_20mo, bone_5fu30d, bone_5fu60d = map(
    process_bone_images,
    (ims_files_3mo, ims_files_12mo, ims_files_20mo, ims_files_5fu30d, ims_files_5fu60d),
)

### 1.3 Load cluster affinity matrices

In [None]:
# Affinity matrices live in their own sub-folder of the data directory
affinity_matrices_dir = os.path.join(data_dir, "affinity_matrices")

# condition name -> affinity matrix
affinity_matrices = {}
for f in os.listdir(affinity_matrices_dir):
    # Only files named "affinity...sum.csv" are loaded
    if f.startswith("affinity") and f.endswith("sum.csv"):
        # The CSV has no header row; columns are labelled 0-9
        affinity_matrix = pd.read_csv(os.path.join(affinity_matrices_dir, f), header=None)
        affinity_matrix.columns = range(10)
        # The condition is the third "_"-separated token of the file name
        condition_name = f.split("_")[2]
        affinity_matrices[condition_name] = affinity_matrix

## 2. Data Inspection and Preprocessing
- visualization
- preparation for transformation

In [None]:
# Visual check: for every file show the channel-0 mask, the channel-1 (DAPI) mask
# and the DAPI mask after removing small objects
for ims_files in [ims_files_3mo, ims_files_12mo, ims_files_20mo, ims_files_5fu30d, ims_files_5fu60d]:
    for i in ims_files:
        print(i)
        img = ims(i)
        print(img.shape)

        # Maximum projection of each channel, binarized at > 0
        bone_img = img[0, 0, :, :, :].max(axis=0)
        bone_img = bone_img > 0
        dapi_img = img[0, 1, :, :, :].max(axis=0)
        dapi_img = dapi_img > 0

        # Connected components of the DAPI mask, small ones (< 20000 px) removed
        labeled_img = measure.label(dapi_img)
        cleaned_img = morphology.remove_small_objects(labeled_img, min_size=20000)
        cleaned_img = cleaned_img > 0

        # Three stacked panels, one per mask
        fig, ax = plt.subplots(3, 1, figsize=(15, 15))
        panels = (
            (bone_img, "Bone Image"),
            (dapi_img, "DAPI Image"),
            (cleaned_img, "Cleaned DAPI Image"),
        )
        for axis, (mask, panel_title) in zip(ax, panels):
            axis.imshow(mask, cmap="gray")
            axis.set_title(panel_title)

        plt.show()

In [None]:
# Per-condition dictionaries of bone pixel tables, keyed by the age/condition label
bone_dicts = {
    "3mo": bone_3mo,
    "12mo": bone_12mo,
    "20mo": bone_20mo,
    "5fu30d": bone_5fu30d,
    "5fu60d": bone_5fu60d
}

# Initialize a list to hold all processed DataFrames
bone_dfs = []

# Process each dictionary
for age, bone_dict in bone_dicts.items():
    for bone, df in bone_dict.items():
        # Add the required columns (this modifies the frames stored in bone_dict in place)
        df["age"] = age
        df["clusters"] = None  # Image pixels carry no cluster label
        df["bone"] = bone
        # Reorder columns to match the positions DataFrame
        df = df[["Position.X", "Position.Y", "Position.Z", "age", "clusters", "bone", "source"]]
        bone_dfs.append(df)

# Append the bone pixel rows to the cell positions; the index is rebuilt from 0
positions = pd.concat([positions] + bone_dfs, ignore_index=True)

# Drop the intermediate objects now that everything lives in `positions`
del bone_3mo, bone_12mo, bone_20mo, bone_5fu30d, bone_5fu60d, bone_dicts, bone_dfs


In [None]:
# Flip the bone images and positions based on the visual inspection.
# The bones listed explicitly keep their orientation; every other bone is flipped along Y.
bone_names = positions["bone"].unique()
bone_y_flips = {bone: False for bone in ["210722KK_st4", "210722KK_st6", "210722KK_st8", "211110KK_st2", "211110KK_st5", "211128KK_st1", "220202KK_st8", "211110KK_st6", "211110KK_st8", "211110KK_st9"]}
bone_y_flips.update({bone: True for bone in bone_names if bone not in bone_y_flips})

# Precompute y_min and y_max for all bones (over every row of the bone, whatever its source)
y_bounds = positions.groupby("bone")["Position.Y"].agg(["min", "max"]).to_dict("index")

# Attach the per-bone flag and bounds to each row as temporary helper columns
positions["y_flip"] = positions["bone"].map(bone_y_flips)
positions["y_min"] = positions["bone"].map({bone: bounds["min"] for bone, bounds in y_bounds.items()})
positions["y_max"] = positions["bone"].map({bone: bounds["max"] for bone, bounds in y_bounds.items()})

# Apply the flipping logic: flipped bones are mirrored (y_max - y), the others are
# only shifted (y - y_min); either way the smallest Y of a bone becomes 0
positions["Position.Y"] = np.where(
    positions["y_flip"],
    positions["y_max"] - positions["Position.Y"],
    positions["Position.Y"] - positions["y_min"]
)

# Drop helper columns
positions.drop(["y_flip", "y_min", "y_max"], axis=1, inplace=True)


## 3. Bone Alignment and Transformation

### 3.1 Bone alignment with the reference bone

In [None]:
# Find the outline of the bone in the x-y plane
def find_outline(points, window_size=10):
    """
    Approximate the x-y outline of a point cloud.

    The points are sorted by x and cut into consecutive chunks of `window_size`
    points (the last chunk may be shorter). From each chunk the point with the
    lowest y and the point with the highest y are kept. The lower points,
    ordered by increasing x, followed by the upper points, ordered by decreasing
    x, form the outline, which is closed by repeating its first point.

    Parameters:
    points : np.array of shape (n, 2) holding x and y.
    window_size : number of consecutive x-sorted points per chunk; larger
                  values give a coarser outline.

    Returns:
    outline_points : np.array of shape (m, 2), closed outline.
    """
    ordered = pd.DataFrame(points, columns=["x", "y"]).sort_values(by="x")

    # Consecutive, non-overlapping chunks of the x-sorted points
    chunks = [
        ordered.iloc[start:start + window_size]
        for start in range(0, len(ordered), window_size)
    ]

    # Extreme point (full row) of every chunk
    lower_rows = [chunk.loc[chunk["y"].idxmin()] for chunk in chunks]
    upper_rows = [chunk.loc[chunk["y"].idxmax()] for chunk in chunks]

    # Lower edge runs left-to-right, upper edge right-to-left
    lower_edge = pd.DataFrame(lower_rows).drop_duplicates().sort_values(by="x").values
    upper_edge = pd.DataFrame(upper_rows).drop_duplicates().sort_values(by="x", ascending=False).values

    # Close the loop by returning to the starting point
    return np.vstack([lower_edge, upper_edge, lower_edge[0]])

# Find the center of each bone (outline) and put the center of the bones at the same position
def calculate_centroid(outline_points):
    """
    Return the mean position of the outline points.

    This is the plain average of the listed points (each row counts once), not
    an area centroid and not the midpoint of the bounding box.

    Parameters:
    outline_points : np.array of shape (n, 2)

    Returns:
    centroid : tuple containing (centroid_x, centroid_y)
    """
    x_column, y_column = 0, 1
    return np.mean(outline_points[:, x_column]), np.mean(outline_points[:, y_column])

def translate_to_origin(outline_points, centroid):
    """
    Shift the outline so that the given centroid lands on the origin.

    Parameters:
    outline_points : np.array of shape (n, 2)
    centroid : tuple containing (centroid_x, centroid_y)

    Returns:
    translated_points : float64 np.array of shape (n, 2); the input is not modified.
    """
    offset_x, offset_y = centroid[0], centroid[1]

    # astype returns a new float64 array, so the caller's data stays untouched
    shifted = outline_points.astype(np.float64)
    shifted[:, 0] = shifted[:, 0] - offset_x
    shifted[:, 1] = shifted[:, 1] - offset_y
    return shifted

# Rescale the bones to the same size (bounding box)(optional)
def get_max_dimensions(bone_dicts, source_col="Bone"):
    """
    Find the maximum width and height across all bones in the given dictionaries.

    Parameters:
    bone_dicts : list of dictionaries of bones (where each value is a DataFrame
                 with "Position.X", "Position.Y" and "source" columns)
    source_col : value of the "source" column that marks the rows describing the
                 bone itself. Defaults to "Bone"; pass e.g. "DAPI" when the bone
                 rows were created with that label.

    Returns:
    max_width : maximum extent along X found across all bones (0 if no bones are given)
    max_height : maximum extent along Y found across all bones (0 if no bones are given)

    Raises:
    ValueError : if a bone has no rows whose source equals `source_col`.
    """
    max_width = 0
    max_height = 0

    for bone_dict in bone_dicts:
        for bone_name, df in bone_dict.items():
            # Keep only the rows that describe the bone itself
            bone_points = df.loc[df["source"] == source_col, ["Position.X", "Position.Y"]].values

            # An empty selection would otherwise fail inside numpy with an opaque message
            if len(bone_points) == 0:
                raise ValueError(
                    f"Bone {bone_name!r} has no rows with source == {source_col!r}"
                )

            # Bounding-box extent of this bone
            width = bone_points[:, 0].max() - bone_points[:, 0].min()
            height = bone_points[:, 1].max() - bone_points[:, 1].min()

            max_width = max(max_width, width)
            max_height = max(max_height, height)

    return max_width, max_height

def rescale_outline(outline_points, max_width, max_height):
    """
    Rescale an outline so its bounding box measures max_width by max_height.

    X and Y are multiplied by independent factors; no translation is applied,
    so only an outline positioned around the origin stays in place.

    Parameters:
    outline_points : np.array of shape (n, 2)
    max_width : float, target width (the maximum width across all bones)
    max_height : float, target height (the maximum height across all bones)

    Returns:
    scaled_points : float64 np.array of shape (n, 2); the input is not modified.

    Raises:
    ValueError : if the outline has zero width or zero height, since no finite
                 factor can stretch it to the target size.
    """
    # Work on a float copy: multiplying an integer array in place by a float
    # factor raises a casting error in numpy
    scaled_points = np.asarray(outline_points).astype(np.float64)

    current_width = scaled_points[:, 0].max() - scaled_points[:, 0].min()
    current_height = scaled_points[:, 1].max() - scaled_points[:, 1].min()

    if current_width == 0 or current_height == 0:
        raise ValueError(
            "Cannot rescale a degenerate outline "
            f"(width={current_width}, height={current_height})"
        )

    scaled_points[:, 0] *= max_width / current_width
    scaled_points[:, 1] *= max_height / current_height

    return scaled_points

# Calculate the overlap area on the grid inside the outline
def calculate_overlap_area(reference_outline, target_outline, resolution=200):
    """
    Count how many points of a sampling grid lie inside both outlines.

    Each outline is first centred on its own centroid. The grid has
    resolution x resolution points and spans the bounding box of the centred
    reference outline, so target area outside that box is never counted.

    Parameters:
    reference_outline : array-like reshaped to (n, 2), outline of the reference bone
    target_outline : array-like reshaped to (n, 2), outline of the target bone
    resolution : int, number of grid points along each axis.

    Returns:
    overlap_area : number of grid points that are inside both outlines.
    """
    # Bring both outlines to shape (N, 2) and centre each on its own centroid
    centered = []
    for outline in (reference_outline, target_outline):
        pts = np.asarray(outline).reshape(-1, 2)
        centered.append(translate_to_origin(pts, calculate_centroid(pts)))
    ref_centered, tgt_centered = centered

    # Sampling grid over the bounding box of the reference outline
    x_low, x_high = ref_centered[:, 0].min(), ref_centered[:, 0].max()
    y_low, y_high = ref_centered[:, 1].min(), ref_centered[:, 1].max()
    xv, yv = np.meshgrid(
        np.linspace(x_low, x_high, resolution),
        np.linspace(y_low, y_high, resolution),
    )
    grid_points = np.vstack([xv.ravel(), yv.ravel()]).T

    # Point-in-polygon test against each outline
    inside_reference = Path(ref_centered).contains_points(grid_points)
    inside_target = Path(tgt_centered).contains_points(grid_points)

    return np.sum(inside_reference & inside_target)

def rotate_points(points, angle):
    """
    Rotate a set of points about the origin.

    The points are treated as row vectors and multiplied on the right by the
    matrix [[cos, -sin], [sin, cos]]. With that convention a positive angle
    maps (1, 0) to (cos, -sin), i.e. the rotation is clockwise when y points up.

    Parameters:
    points : array-like that can be reshaped to (n, 2)
    angle : float, rotation angle in radians

    Returns:
    rotated_points : np.array of shape (n, 2)
    """
    coords = np.asarray(points).reshape(-1, 2)

    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation_matrix = np.array([[cos_a, -sin_a],
                                [sin_a, cos_a]])

    return coords @ rotation_matrix


def grid_search_rotation(reference_outline, target_outline, angle_step=np.pi/36):
    """
    Try a range of rotation angles and keep the one with the largest overlap.

    Candidate angles start at -pi/4 (-45 degrees) and increase by angle_step up
    to, but not including, pi/4 (np.arange semantics). When several angles give
    the same overlap, the first one tried is kept.

    Parameters:
    reference_outline : np.array of shape (n, 2), outline of the reference bone
    target_outline : np.array of shape (n, 2), outline of the target bone
    angle_step : float, step size for angle search (in radians)

    Returns:
    best_rotated_points : np.array of shape (n, 2), the best rotated target
                          outline, re-centred on its own centroid
    best_angle : float, the selected rotation angle in radians
    final_centroid : tuple (x, y), centroid of the best rotated outline before
                     re-centring (the translation that was removed)
    """
    top_score = -np.inf
    best_angle = None
    best_rotated_points = None

    for theta in np.arange(-np.pi/4, np.pi/4, angle_step):
        candidate = rotate_points(target_outline, theta)
        score = calculate_overlap_area(reference_outline, candidate)

        # Strict comparison: an equal score never replaces an earlier angle
        if not score > top_score:
            continue
        top_score, best_angle, best_rotated_points = score, theta, candidate

    # Move the winning outline so that its centroid sits on the origin
    final_centroid = calculate_centroid(best_rotated_points)
    best_rotated_points = translate_to_origin(best_rotated_points, final_centroid)

    return best_rotated_points, best_angle, final_centroid

# Find, for every bone of a group, the rotation and centroid that align it with the reference bone
def process_and_align_bones_with_overlap(bone_dict, reference_bone, window_size=500, source_col = "DAPI"):
    """
    Compute the alignment of every bone in bone_dict against a reference bone.

    For the reference and for each bone, the rows whose "source" equals
    source_col are reduced to an outline (find_outline); the reference outline
    is centred on its centroid and each bone outline is rotated to maximise
    the overlap with it (grid_search_rotation).

    Parameters:
    bone_dict : dict of bones, where each value is a DataFrame with
                "Position.X", "Position.Y" and "source" columns.
    reference_bone : DataFrame of the reference bone (same columns).
    window_size : window size passed to find_outline.
    source_col : value of the "source" column selecting the rows used to build the outlines.

    Returns:
    aligned_bones : dict of aligned (rotated and centred) bone outlines
    aligned_angles : dict of the selected rotation angle (radians) per bone
    aligned_centroids : dict of the centroid removed after rotation, per bone

    Raises:
    ValueError : if reference_bone is not a DataFrame.
    """
    if not isinstance(reference_bone, pd.DataFrame):
        raise ValueError("Invalid reference_bone_name. Must be a DataFrame.")

    xy_columns = ["Position.X", "Position.Y"]

    def outline_of(frame):
        # Outline built only from the rows of the selected source
        selected = frame[frame["source"] == source_col][xy_columns].values
        return find_outline(selected, window_size=window_size)

    # Reference outline, centred on its own centroid
    raw_reference = outline_of(reference_bone)
    reference_outline = translate_to_origin(raw_reference, calculate_centroid(raw_reference))

    aligned_bones = {}
    aligned_angles = {}
    aligned_centroids = {}
    for bone_name, df in bone_dict.items():
        outline, angle, centroid = grid_search_rotation(reference_outline, outline_of(df))
        aligned_bones[bone_name] = outline
        aligned_angles[bone_name] = angle
        aligned_centroids[bone_name] = centroid

    return aligned_bones, aligned_angles, aligned_centroids

# Perform the alignment for each age group
def align_bones_with_centroids_angles(positions_df, aligned_centroids, aligned_angles, exclude_col="DAPI"):
    """
    Apply previously computed rotations and translations to the points of each bone.

    Parameters:
    positions_df : dict mapping bone name to a DataFrame with "Position.X",
                   "Position.Y" and "source" columns.
    aligned_centroids : dict mapping bone name to the (x, y) centroid to subtract after rotation.
    aligned_angles : dict mapping bone name to the rotation angle in radians.
    exclude_col : rows whose "source" equals this value are dropped from the output.

    Returns:
    aligned_bones : dict mapping bone name to a new, transformed DataFrame
                    (the input frames are left untouched).
    """
    xy_columns = ["Position.X", "Position.Y"]
    aligned_bones = {}

    for bone_name, df in positions_df.items():
        # Independent copy of the rows that are kept
        kept = df[df["source"] != exclude_col].copy()

        centroid = aligned_centroids[bone_name]
        angle = aligned_angles[bone_name]

        # Rotate first, then remove the centroid measured on the rotated outline
        kept[xy_columns] = rotate_points(kept[xy_columns].values, angle)
        kept["Position.X"] = kept["Position.X"] - centroid[0]
        kept["Position.Y"] = kept["Position.Y"] - centroid[1]

        aligned_bones[bone_name] = kept

    return aligned_bones


In [None]:
# Split the long table into a nested dict {age: {bone: DataFrame}},
# which is the layout expected by the alignment functions
positions_dfs_dict = {}

for age in positions["age"].unique():
    # All rows of the current age group
    age_group = positions.loc[positions["age"] == age]
    positions_dfs_dict[age] = {}

    for bone in age_group["bone"].unique():
        # Rows of one bone; "age" and "bone" are now encoded by the dict keys
        bone_group = age_group.loc[age_group["bone"] == bone].drop(columns=["age", "bone"])
        positions_dfs_dict[age][bone] = bone_group

In [None]:
# Reference bone, chosen by visual inspection (copied so the original table stays intact)
ref_bone = positions_dfs_dict["3mo"]["210722KK_st1"].copy()

# Outline of the reference bone, built from its DAPI rows only
ref_bone_points = ref_bone.loc[ref_bone["source"] == "DAPI", ["Position.X", "Position.Y"]].values
ref_bone_outline = find_outline(ref_bone_points, window_size=500)

# Centre the outline on its centroid
ref_bone_centroid = calculate_centroid(ref_bone_outline)
ref_bone_outline = translate_to_origin(ref_bone_outline, ref_bone_centroid)


In [None]:
# Alignment parameters (outline, rotation angle, centroid) of every bone, per condition.
# All conditions share the same reference bone and outline settings.
_alignment_settings = {"reference_bone": ref_bone, "window_size": 500, "source_col": "DAPI"}

positions_3mo_outline, positions_3mo_angles, positions_3mo_centroids = (
    process_and_align_bones_with_overlap(positions_dfs_dict["3mo"], **_alignment_settings)
)
positions_12mo_outline, positions_12mo_angles, positions_12mo_centroids = (
    process_and_align_bones_with_overlap(positions_dfs_dict["12mo"], **_alignment_settings)
)
positions_20mo_outline, positions_20mo_angles, positions_20mo_centroids = (
    process_and_align_bones_with_overlap(positions_dfs_dict["20mo"], **_alignment_settings)
)
positions_5fu30d_outline, positions_5fu30d_angles, positions_5fu30d_centroids = (
    process_and_align_bones_with_overlap(positions_dfs_dict["5fu30d"], **_alignment_settings)
)
positions_5fu60d_outline, positions_5fu60d_angles, positions_5fu60d_centroids = (
    process_and_align_bones_with_overlap(positions_dfs_dict["5fu60d"], **_alignment_settings)
)


In [None]:
# Apply the computed rotation and translation to the cell positions of every condition
# (DAPI rows are dropped by the function's default exclude_col)
aligned_3mo_bones = align_bones_with_centroids_angles(
    positions_dfs_dict["3mo"],
    aligned_centroids=positions_3mo_centroids,
    aligned_angles=positions_3mo_angles,
)
aligned_12mo_bones = align_bones_with_centroids_angles(
    positions_dfs_dict["12mo"],
    aligned_centroids=positions_12mo_centroids,
    aligned_angles=positions_12mo_angles,
)
aligned_20mo_bones = align_bones_with_centroids_angles(
    positions_dfs_dict["20mo"],
    aligned_centroids=positions_20mo_centroids,
    aligned_angles=positions_20mo_angles,
)
aligned_5fu30d_bones = align_bones_with_centroids_angles(
    positions_dfs_dict["5fu30d"],
    aligned_centroids=positions_5fu30d_centroids,
    aligned_angles=positions_5fu30d_angles,
)
aligned_5fu60d_bones = align_bones_with_centroids_angles(
    positions_dfs_dict["5fu60d"],
    aligned_centroids=positions_5fu60d_centroids,
    aligned_angles=positions_5fu60d_angles,
)


In [None]:
# Remove the name of the nested per-age/per-bone dict; it is not used below in this section
del positions_dfs_dict

### 3.2 Outline smoothing

In [None]:
def smooth_outline(outline, sigma=2, mode="reflect"):
    """
    Smooth an outline with a 1-D Gaussian filter applied along the point sequence.

    Parameters:
    outline : np.array of shape (n, 2). If the last point differs from the first,
              the first point is appended to close the outline before filtering.
    sigma : standard deviation of the Gaussian kernel, in number of points.
    mode : boundary handling passed to gaussian_filter1d.
           "reflect" (default) keeps the original behaviour: the sequence is
           mirrored at both ends, so the two ends of the loop are smoothed
           independently of each other.
           "wrap" treats the outline as a closed ring: the duplicated closing
           point is removed, the ring is filtered periodically, and the loop is
           closed again.

    Returns:
    smoothed_outline : float64 np.array of shape (m, 2) whose first and last points are identical.
    """
    # Filter in floating point: gaussian_filter1d returns the input dtype, so an
    # integer outline would otherwise come back truncated to integers
    outline = np.asarray(outline, dtype=np.float64)

    if not np.array_equal(outline[0], outline[-1]):
        outline = np.vstack([outline, outline[0]])

    if mode == "wrap":
        # Drop the duplicated closing point so that it is not weighted twice
        ring = outline[:-1]
        ring_x = gaussian_filter1d(ring[:, 0], sigma=sigma, mode="wrap")
        ring_y = gaussian_filter1d(ring[:, 1], sigma=sigma, mode="wrap")
        smoothed_ring = np.vstack((ring_x, ring_y)).T
        return np.vstack([smoothed_ring, smoothed_ring[0]])

    smoothed_x = gaussian_filter1d(outline[:, 0], sigma=sigma, mode=mode)
    smoothed_y = gaussian_filter1d(outline[:, 1], sigma=sigma, mode=mode)
    smoothed_outline = np.vstack((smoothed_x, smoothed_y)).T

    # Non-periodic filtering usually moves the two ends apart; close the loop again
    if not np.array_equal(smoothed_outline[0], smoothed_outline[-1]):
        smoothed_outline = np.vstack([smoothed_outline, smoothed_outline[0]])
    return smoothed_outline


In [None]:
# Smooth the aligned bone outlines; one dict of smoothed outlines per condition
positions_3mo_bone_outlines_smoothed = {}
positions_12mo_bone_outlines_smoothed = {}
positions_20mo_bone_outlines_smoothed = {}
positions_5fu30d_bone_outlines_smoothed = {}
positions_5fu60d_bone_outlines_smoothed = {}

# Width of the Gaussian kernel, in outline points
sigma = 50

# (aligned outlines, destination dict) pairs, processed in this order
for aligned_outlines, smoothed_outlines in (
    (positions_3mo_outline, positions_3mo_bone_outlines_smoothed),
    (positions_12mo_outline, positions_12mo_bone_outlines_smoothed),
    (positions_20mo_outline, positions_20mo_bone_outlines_smoothed),
    (positions_5fu30d_outline, positions_5fu30d_bone_outlines_smoothed),
    (positions_5fu60d_outline, positions_5fu60d_bone_outlines_smoothed),
):
    for bone_name, outline in aligned_outlines.items():
        smoothed_outline = smooth_outline(outline, sigma=sigma)
        smoothed_outlines[bone_name] = smoothed_outline

# The reference outline gets the same smoothing
ref_outline = smooth_outline(ref_bone_outline, sigma=sigma)

### 3.3 Bone transformation

In [None]:
# Get the mask of the bone outline
def points_in_polygon(x_points, y_points, outline):
    """
    Test which points lie inside the polygon described by outline.

    Parameters:
    x_points, y_points : coordinates of the points to test (same length).
    outline : sequence of (x, y) vertices of the polygon.

    Returns:
    Boolean array with one entry per point, as returned by Path.contains_points.
    """
    polygon = Path(outline)
    # Stack the coordinates as rows, then transpose to one (x, y) pair per row
    query_xy = np.transpose(np.vstack((x_points, y_points)))
    return polygon.contains_points(query_xy)


# Define the function to exclude the data points outside the bone outline
def exclude_outside_bone_outline(df, bone_outline, exclude_source ="DAPI"):
    """
    Keep only the data points that lie within the bone outline.

    Rows whose "source" equals exclude_source are removed first. The remaining
    points are filtered in two steps: a pre-filter against the bounding box of
    the outline, then the test against the outline polygon itself.

    Parameters:
    df : DataFrame containing "Position.X", "Position.Y" and "source".
    bone_outline : sequence of (x, y) vertices describing the bone outline.
    exclude_source : source label of the rows to discard (the rows the outline was built from).

    Returns:
    df_inside : the rows located within the bone outline, without the temporary "geometry" column.
    """
    candidates = pd.DataFrame(df[df["source"] != exclude_source])

    # One point geometry per remaining row
    point_geometries = gpd.points_from_xy(candidates["Position.X"], candidates["Position.Y"])
    gdf = gpd.GeoDataFrame(candidates, geometry=point_geometries)

    # The outline as a polygon, plus its bounding box
    bone_polygon = Polygon(bone_outline)
    bounding_box = box(*bone_polygon.bounds)

    # Bounding-box pre-filter, followed by the test against the polygon
    in_bbox = gdf[gdf.geometry.within(bounding_box)]
    inside = in_bbox[in_bbox.within(bone_polygon)]

    # The geometry column was only needed for the spatial tests
    return inside.drop(columns="geometry")

def get_y_range_at_x(shape_points, x):
    """
    Return (y_min, y_max) of the crossings between the vertical line at `x` and the closed shape.

    Every edge (including the closing edge from the last vertex back to the first) whose
    x-extent contains `x` contributes one linearly interpolated y-value. Exactly vertical
    edges are ignored. If no edge is crossed, (None, None) is returned.
    """
    vertex_count = len(shape_points)
    crossings = []

    for idx in range(vertex_count):
        start = shape_points[idx]
        end = shape_points[(idx + 1) % vertex_count]  # the last vertex connects back to the first
        start_x, start_y = start[0], start[1]
        end_x, end_y = end[0], end[1]

        # Only edges whose x-extent contains x can be crossed
        spans_x = (start_x <= x <= end_x) or (end_x <= x <= start_x)
        if not spans_x:
            continue

        # A vertical edge has no unique y at x (and would divide by zero)
        if start_x == end_x:
            continue

        crossings.append(start_y + (end_y - start_y) * (x - start_x) / (end_x - start_x))

    if not crossings:
        return None, None  # the vertical line misses the shape
    return min(crossings), max(crossings)

def _cell_centers(lower, upper, count):
    """Return the midpoints of `count` equal-width cells spanning [lower, upper]."""
    edges = np.linspace(lower, upper, count + 1)  # count cells need count + 1 edges
    upper_edges = edges[1:]
    # With a single cell there is no second upper edge to difference against, so the
    # spacing is taken from the lower edge instead of indexing out of bounds.
    spacing = upper_edges[1] - upper_edges[0] if count > 1 else edges[1] - edges[0]
    return upper_edges - spacing / 2


def create_structured_grid(shape_points, x_num, y_num):
    """
    Build a structured grid of points that follows the shape.

    The x-extent of the shape is divided into `x_num` equal-width columns and a vertical
    line is placed at the centre of each column. On every vertical line that intersects
    the shape, the span between the lowest and highest intersection is divided into
    `y_num` equal cells and one point is placed at the centre of each cell.

    Parameters:
        shape_points: Sequence of (x, y) vertices of the shape.
        x_num: Number of columns (vertical grid lines); must be at least 1.
        y_num: Number of points per vertical grid line; must be at least 1.

    Returns:
        Array of shape (n_points, 2) holding the (x, y) grid points. Lines that do not
        intersect the shape contribute no points, so n_points can be smaller than
        x_num * y_num (and the array is (0, 2) when nothing intersects).

    Raises:
        ValueError: If x_num or y_num is smaller than 1.
    """
    if x_num < 1 or y_num < 1:
        raise ValueError("x_num and y_num must be at least 1")

    shape_points = np.array(shape_points)

    # Horizontal extent of the shape
    min_x, max_x = np.min(shape_points[:, 0]), np.max(shape_points[:, 0])

    grid_points = []
    for x in _cell_centers(min_x, max_x, x_num):
        y_min, y_max = get_y_range_at_x(shape_points, x)
        if y_min is None or y_max is None:
            continue  # this vertical line misses the shape
        grid_points.extend([x, y] for y in _cell_centers(y_min, y_max, y_num))

    # reshape keeps the (n, 2) layout even when no grid point was produced
    return np.array(grid_points).reshape(-1, 2)

# (Not used)
def is_point_inside_shape(point, shape_points):
    """
    Ray-casting point-in-polygon test.

    Walks over the edges of the closed shape and counts how many of them are crossed by a
    horizontal ray starting at `point`; an odd number of crossings means the point is inside.

    Parameters:
        point: (x, y) pair to test.
        shape_points: Sequence of (x, y) vertices of the shape.

    Returns:
        bool: True if the number of crossings is odd, False otherwise.
    """
    px, py = point
    vertex_count = len(shape_points)
    crossings = 0

    prev_x, prev_y = shape_points[0]
    for step in range(vertex_count + 1):
        cur_x, cur_y = shape_points[step % vertex_count]

        # The edge is relevant only if py lies in its half-open y-span
        # and the point is not to the right of the whole edge
        in_y_span = min(prev_y, cur_y) < py <= max(prev_y, cur_y)
        if in_y_span and px <= max(prev_x, cur_x):
            if prev_x == cur_x:
                # Vertical edge: the ray always hits it
                crossings += 1
            else:
                x_at_py = (py - prev_y) * (cur_x - prev_x) / (cur_y - prev_y) + prev_x
                if px <= x_at_py:
                    crossings += 1

        prev_x, prev_y = cur_x, cur_y

    return crossings % 2 == 1


def thin_plate_spline_transform(src_points, dst_points):
    """
    Build a mapping from `src_points` to `dst_points` using thin-plate radial basis functions.

    Parameters:
        src_points: Array whose columns 0 and 1 are the x and y of the control points.
        dst_points: Array of the same layout with the target position of every control point.

    Returns:
        A function `transform(points)` that maps an array of (x, y) rows to an (n, 2) array
        of transformed coordinates.
    """
    src_x, src_y = src_points[:, 0], src_points[:, 1]

    # One interpolator per output coordinate
    interp_x = Rbf(src_x, src_y, dst_points[:, 0], function="thin_plate")
    interp_y = Rbf(src_x, src_y, dst_points[:, 1], function="thin_plate")

    def transform(points):
        query_x, query_y = points[:, 0], points[:, 1]
        return np.vstack([interp_x(query_x, query_y), interp_y(query_x, query_y)]).T

    return transform

def transform_data(data_points, grid_shape_1, grid_shape_2):
    """
    Map `data_points` with the thin-plate transformation that takes grid_shape_2 to grid_shape_1.

    Note the argument order: grid_shape_2 supplies the source control points and
    grid_shape_1 the destination control points.
    """
    warp = thin_plate_spline_transform(grid_shape_2, grid_shape_1)
    return warp(data_points)


def filter_bone_positions(df, source_value=None, columns_to_keep=None):
    """
    Remove the rows of one source and return the remaining positions as NumPy arrays.

    Rows whose "source" equals `source_value` are dropped; when `source_value` is None
    the rows with source "DAPI" are dropped instead.

    Returns:
        (positions, sources) where positions holds the "Position.X"/"Position.Y" columns and
        sources the "source" column of the remaining rows. When `columns_to_keep` is given,
        a third array with those columns is appended to the tuple.
    """
    excluded_source = "DAPI" if source_value is None else source_value
    remaining = df[df["source"] != excluded_source]

    positions = remaining[["Position.X", "Position.Y"]].to_numpy()
    sources = remaining["source"].to_numpy()

    if columns_to_keep is None:
        return positions, sources
    return positions, sources, remaining[columns_to_keep].to_numpy()



def transform_bone_positions(outline_dict, position_dict, common_outline, x_num=40, y_num=20, source_value=None, columns_to_keep=None):
    """
    Warp the positions of every dataset onto the common outline with a thin-plate transformation.

    For each dataset a structured grid is built on its own outline and on the common outline;
    the two grids serve as control points of the transformation applied to the positions.

    Parameters:
        outline_dict: Dictionary mapping dataset name to its outline.
        position_dict: Dictionary mapping dataset name to its positions (DataFrame).
        common_outline: The outline all datasets are aligned to.
        x_num: Number of vertical sections of the structured grid.
        y_num: Number of points along each vertical section.
        source_value: Source to exclude before the transformation; if None, "DAPI" is excluded
            (see filter_bone_positions).
        columns_to_keep: If provided, these columns are carried over into the transformed DataFrame.
    Returns:
        Dictionary mapping dataset name to a DataFrame with the transformed "Position.X" and
        "Position.Y", the original "source", a "dataset" column, and any kept columns.
    """
    # The grid on the common outline is shared by all datasets
    reference_grid = create_structured_grid(common_outline, x_num=x_num, y_num=y_num)

    transformed_dict = {}
    for dataset_name, position_df in position_dict.items():
        # Control points on this dataset's own outline
        dataset_grid = create_structured_grid(outline_dict[dataset_name], x_num=x_num, y_num=y_num)

        # The third element is only returned when columns_to_keep is given
        positions, source_column, *extra_columns = filter_bone_positions(
            position_df, source_value=source_value, columns_to_keep=columns_to_keep
        )

        # Move the positions from the dataset outline onto the common outline
        warped_positions = transform_data(positions, reference_grid, dataset_grid)

        frame = pd.DataFrame(warped_positions, columns=["Position.X", "Position.Y"])
        frame["source"] = source_column
        frame["dataset"] = dataset_name
        if extra_columns:
            frame[columns_to_keep] = extra_columns[0]

        transformed_dict[dataset_name] = frame

    return transformed_dict



In [None]:
# Grid resolution, chosen according to the size of the bone outline
x_num = 200
y_num = 40


def _warp_group_to_reference(outlines_smoothed, aligned_bones):
    """Warp every bone of one group onto ref_outline, stack them, and keep the points inside ref_outline."""
    # The per-bone dictionaries are used (not one stacked DataFrame) because
    # bones of the same group each have their own outline.
    warped_by_bone = transform_bone_positions(
        outlines_smoothed, aligned_bones, ref_outline, x_num=x_num, y_num=y_num, columns_to_keep="clusters"
    )
    return exclude_outside_bone_outline(pd.concat(warped_by_bone), ref_outline)


transformed_3mo_bones_df = _warp_group_to_reference(positions_3mo_bone_outlines_smoothed, aligned_3mo_bones)
transformed_12mo_bones_df = _warp_group_to_reference(positions_12mo_bone_outlines_smoothed, aligned_12mo_bones)
transformed_20mo_bones_df = _warp_group_to_reference(positions_20mo_bone_outlines_smoothed, aligned_20mo_bones)
transformed_5fu30d_bones_df = _warp_group_to_reference(positions_5fu30d_bone_outlines_smoothed, aligned_5fu30d_bones)
transformed_5fu60d_bones_df = _warp_group_to_reference(positions_5fu60d_bone_outlines_smoothed, aligned_5fu60d_bones)

# The helper is only needed in this cell
del _warp_group_to_reference

## 4. Components Calculation and Estimation
- calculate the PDF of cKits and HSC
- calculate the 2D histogram of HSC
- calculate the cluster composition of HSC (based on the raw data or transformed data)

### 4.1 PDF estimation (function definition)

In [None]:
# Use the bone outline to generate the KDE for each source
# Define the KDE function for each source, with weights based on z-aggregated points
def _weighted_kde_on_grid(points_df, bw_method, bone_outline, binsize):
    """Evaluate a weighted 2D Gaussian KDE of one group of points on a regular grid.

    Returns (kde_values, (xi, yi, grid_points)) where kde_values has the shape of the
    meshgrid arrays xi / yi and grid_points is the (2, n) array the KDE was evaluated on.
    """
    # Collapse rows sharing the same (x, y): sum their weights, or count them when
    # the frame has no "weights" column
    xy_groups = points_df.groupby(["Position.X", "Position.Y"])
    if "weights" in points_df.columns:
        aggregated = xy_groups["weights"].sum().reset_index()
    else:
        aggregated = xy_groups.size().reset_index(name="weights")

    x_vals = aggregated["Position.X"]
    y_vals = aggregated["Position.Y"]
    weights = aggregated["weights"]

    # Grid extent: bounding box of the outline if given, otherwise of the data
    if bone_outline is None:
        x_min, x_max = x_vals.min(), x_vals.max()
        y_min, y_max = y_vals.min(), y_vals.max()
    else:
        outline = np.asarray(bone_outline)
        x_min, y_min = outline.min(axis=0)
        x_max, y_max = outline.max(axis=0)

    # Roughly one grid node every `binsize` units along each axis
    xi = np.linspace(x_min, x_max, int((x_max - x_min) / binsize) + 1)
    yi = np.linspace(y_min, y_max, int((y_max - y_min) / binsize) + 1)
    xi, yi = np.meshgrid(xi, yi)
    grid_points = np.vstack([xi.flatten(), yi.flatten()])

    kde = gaussian_kde(np.vstack([x_vals, y_vals]), weights=weights, bw_method=bw_method)
    kde_values = kde(grid_points).reshape(xi.shape)
    return kde_values, (xi, yi, grid_points)


def kde_for_clusters(df, bw_method="scott", bone_outline=None, binsize=10):
    """
    Estimate a weighted 2D Gaussian KDE for every source in `df` (except "GFP").

    Rows with source "cKits" are further split by their "clusters" value and get one KDE
    per cluster; every other source gets a single KDE.

    Parameters:
        df: DataFrame with "Position.X", "Position.Y", "source", and "clusters" (the latter
            is only read for "cKits" rows). An optional "weights" column is summed per
            (x, y) position; without it the number of rows per position is used as weight.
        bw_method: Bandwidth method forwarded to scipy's gaussian_kde.
        bone_outline: Optional array-like of (x, y) vertices. If given, its bounding box
            defines the evaluation grid; otherwise the grid spans the data of each group.
        binsize: Approximate spacing between grid nodes.

    Returns:
        (kde_results, common_grid):
        kde_results maps `source` (or `(source, cluster)` for "cKits") to the KDE values on
        the grid. common_grid is the (xi, yi, grid_points) tuple of the last evaluated group;
        it is the same for all groups only when `bone_outline` is given. If `df` holds no
        source other than "GFP", the result is ({}, None).
    """
    kde_results = {}
    # Stays None when nothing is evaluated (previously this case raised UnboundLocalError)
    common_grid = None

    sources = df["source"].unique()
    sources = sources[sources != "GFP"]  # "GFP" rows are left out of the KDE calculation

    for source in sources:
        source_data = df[df["source"] == source]
        if source == "cKits":
            for cluster in source_data["clusters"].unique():
                cluster_data = source_data[source_data["clusters"] == cluster]
                kde_values, common_grid = _weighted_kde_on_grid(cluster_data, bw_method, bone_outline, binsize)
                kde_results[(source, cluster)] = kde_values
        else:
            kde_values, common_grid = _weighted_kde_on_grid(source_data, bw_method, bone_outline, binsize)
            kde_results[source] = kde_values

    return kde_results, common_grid




### 4.2 Cluster composition calculation

In [None]:
# Columns read from the position files, and the age labels present in the data
columns_to_use = ["Position.X", "Position.Y", "Position.Z", "age", "clusters", "bone"]
ages = ["3mo", "12mo", "20mo", "5fu30d", "5fu60d"]

# Read every position file whose path mentions "hsc" (case-insensitive) and tag it as HSCs
hsc_positions = []

for position_dir in position_dirs:
    if "hsc" not in position_dir.lower():
        continue
    df = pd.read_csv(position_dir, usecols=columns_to_use)
    df["source"] = "HSCs"
    hsc_positions.append(df)

# Stack all HSC tables; the "bone" column is referred to as "dataset" from here on
hsc_combined = pd.concat(hsc_positions, ignore_index=True)
hsc_combined = hsc_combined.rename(columns={"bone": "dataset"})

# One table per age label
hsc_3mo_raw = hsc_combined.loc[hsc_combined["age"].eq("3mo")]
hsc_12mo_raw = hsc_combined.loc[hsc_combined["age"].eq("12mo")]
hsc_20mo_raw = hsc_combined.loc[hsc_combined["age"].eq("20mo")]
hsc_5fu30d_raw = hsc_combined.loc[hsc_combined["age"].eq("5fu30d")]
hsc_5fu60d_raw = hsc_combined.loc[hsc_combined["age"].eq("5fu60d")]



In [None]:
# Calculate the cluster proportions of the HSCs in a dataset
def compute_proportions_hsc(df, clusters=np.arange(10)):
    """
    Count the HSC rows per cluster and express each count as a fraction of all HSC rows.

    Parameters:
        df: DataFrame with the columns "source" and "clusters".
        clusters: Cluster labels to report; labels without any HSC row get a count of 0.

    Returns:
        DataFrame with one row per label in `clusters` and the columns
        "clusters", "count", "total" (number of HSC rows) and "proportion" (count / total).
    """
    hsc_rows = df.loc[df["source"] == "HSCs"].copy()

    # Rows per cluster, aligned to the requested labels (absent labels become 0)
    per_cluster = hsc_rows.groupby(["clusters"]).size()
    per_cluster = per_cluster.reindex(clusters, fill_value=0)
    counts = per_cluster.reset_index(name="count")

    counts["total"] = len(hsc_rows)
    counts["proportion"] = counts["count"] / counts["total"]
    return counts[["clusters", "count", "total", "proportion"]]

# Cluster proportions of the HSCs for each age group, with "proportion" stored as float
proportions_3mo = compute_proportions_hsc(hsc_3mo_raw).astype({"proportion": float})
proportions_12mo = compute_proportions_hsc(hsc_12mo_raw).astype({"proportion": float})
proportions_20mo = compute_proportions_hsc(hsc_20mo_raw).astype({"proportion": float})
proportions_5fu30d = compute_proportions_hsc(hsc_5fu30d_raw).astype({"proportion": float})
proportions_5fu60d = compute_proportions_hsc(hsc_5fu60d_raw).astype({"proportion": float})

# Time axes for interpolating between the 3mo, 12mo, and 20mo datasets
# (only the time points are defined here; no interpolation is performed in this cell)
real_times = np.array([3, 12, 20])
fine_times = np.linspace(3, 20, 100)  # Fine-grained time points


## 5 Bone Age Estimation

### 5.1 Required functions

In [None]:
# Kullback-Leibler divergence between two discrete distributions
def calculate_kl_divergence(p, q):
    """
    Sum the element-wise relative entropy of `p` with respect to `q`.

    Both inputs are floored at 1e-10 first, so zero entries cannot cause a
    division by zero or a log of zero.
    """
    floor = 1e-10
    p_safe = np.clip(p, floor, None)
    q_safe = np.clip(q, floor, None)
    pointwise = rel_entr(p_safe, q_safe)
    return np.sum(pointwise)

# Jensen-Shannon divergence (JSD) between two distributions
def calculate_jsd(p, q):
    """
    Return the mean of KL(p || m) and KL(q || m), where m is the midpoint 0.5 * (p + q).
    """
    mixture = 0.5 * (p + q)

    kl_p_to_mixture = calculate_kl_divergence(p, mixture)
    kl_q_to_mixture = calculate_kl_divergence(q, mixture)
    return 0.5 * kl_p_to_mixture + 0.5 * kl_q_to_mixture


# Row-averaged Wasserstein distance between two matrices (not used for now)
def wasserstein_matrix_distance(A, B):
    """
    Average the Wasserstein distance between corresponding rows of `A` and `B`.

    Rows are accessed with `.iloc`, so both inputs are expected to be pandas DataFrames;
    the number of rows is taken from `A`.
    """
    per_row = []
    for row in range(A.shape[0]):
        per_row.append(wasserstein_distance(A.iloc[row, :], B.iloc[row, :]))
    return np.mean(per_row)


def wasserstein_distance_per_row(A: np.ndarray, B: np.ndarray) -> float:
    """
    Wasserstein distance between two weight vectors defined over the same integer positions.

    The entries of `A` and `B` are used as weights on the positions 0, 1, ..., len(A) - 1.

    Inputs:
        A, B: NumPy arrays of identical shape (e.g. one row of an affinity matrix each).

    Output:
        The Wasserstein distance between the two weighted distributions.

    Raises:
        AssertionError: If the shapes of A and B differ.
    """
    assert A.shape == B.shape, "Error: The arrays A and B must have the same shape!"

    # Both distributions live on the same support; only the weights differ
    positions = np.arange(A.shape[0])
    return wasserstein_distance(positions, positions, u_weights=A, v_weights=B)

# Structural similarity between two heatmaps
def calculate_heatmap_ssim(p, q):
    """
    Compute the Structural Similarity Index (SSIM) of two heatmaps.

    All-zero heatmaps are handled without calling SSIM: two all-zero inputs count as
    identical (1.0), and an all-zero input paired with a non-zero one counts as
    completely different (0.0).

    Args:
        p: 2D array-like (heatmap); its value range is used as the SSIM data range.
        q: 2D array-like (heatmap) compared against `p`.

    Returns:
        float: SSIM value between the two heatmaps.
    """
    p = np.array(p)
    q = np.array(q)

    p_is_blank = bool(np.all(p == 0))
    q_is_blank = bool(np.all(q == 0))

    if p_is_blank and q_is_blank:
        return 1.0  # nothing in either map
    if p_is_blank or q_is_blank:
        return 0.0  # exactly one map is empty

    # Regular case; NumPy divide/invalid warnings are silenced during the computation
    with np.errstate(divide="ignore", invalid="ignore"):
        value_span = p.max() - p.min()
        return compare_ssim(p, q, data_range=value_span, win_size=3)

# Mean Squared Error between two arrays
def calculate_mse(p, q):
    """
    Return the mean of the squared element-wise differences between `p` and `q`.

    Args:
        p: Array of true values.
        q: Array of predicted values.

    Returns:
        float: MSE value.
    """
    residuals = p - q
    return np.mean(np.square(residuals))


def calculate_cross_entropy(p, q):
    """
    Cross entropy H(p, q) = -sum_i p(i) * log(q(i) + 1e-10).

    The small constant added to q keeps the logarithm finite when q contains zeros.
    """
    log_q = np.log(q + 1e-10)
    return -np.sum(p * log_q)

    


In [None]:
# The divergence measures expect distributions that sum to 1, so the KDE values
# are rescaled before they are compared
def normalize_kde_values(kde_results):
    """
    Rescale every KDE array so that its entries sum to 1.

    Args:
        kde_results (dict): Dictionary mapping a key (e.g. a cluster) to its KDE values.

    Returns:
        dict: New dictionary with the same keys and each array divided by its own sum.
    """
    return {
        key: kde_values / np.sum(kde_values)
        for key, kde_values in kde_results.items()
    }

In [None]:
# Exponential model
def exponential_func(x, a, b):
    """Evaluate a * exp(b * x)."""
    growth = np.exp(b * x)
    return a * growth

# Linear model
def linear_func(x, m, c):
    """Evaluate the line with slope m and intercept c at x."""
    slope_term = m * x
    return slope_term + c

# Least-squares line fit with a non-negative intercept
def constrained_linear_fit(x, y):
    """
    Fit y ~ m * x + c by minimizing the squared error with SLSQP, subject to c >= 0.

    The optimization starts from m = 1, c = 1.

    Returns:
        The fitted parameters (m, c) as reported by the optimizer.

    Raises:
        RuntimeError: If the optimizer does not report success.
    """
    def squared_error(params):
        slope, intercept = params
        residuals = linear_func(x, slope, intercept) - y
        return np.sum(residuals ** 2)

    # An "ineq" constraint requires the function value to be non-negative, i.e. c >= 0
    intercept_non_negative = {"type": "ineq", "fun": lambda params: params[1]}

    fit = minimize(squared_error, [1, 1], constraints=intercept_non_negative, method="SLSQP")

    if not fit.success:
        raise RuntimeError("Optimization failed for constrained linear regression")
    return fit.x
    
def find_intersection(func, y_value, fine_ages, method="closest"):
    """
    Find the x-value at which a function reaches a given y-value.

    Parameters:
        func (callable): The function to intersect (e.g., regression or interpolation).
            For "root" it is called with scalars, for "closest" with the whole `fine_ages` array.
        y_value (float): The y-value of the horizontal line.
        fine_ages (np.ndarray): The grid of x-values to search.
        method (str): "root" or "closest".
            "root": every pair of neighbouring grid values is used as a bracket for Brent's
            method and the average of all distinct roots is returned.
            "closest": the grid value whose function value is nearest to `y_value` is
            returned (the first one in case of ties).

    Returns:
        float: The averaged root ("root") or the closest grid value ("closest");
        None if "root" finds no intersection.

    Raises:
        ValueError: If `method` is neither "root" nor "closest".
    """
    if method == "root":
        def offset_from_target(x):
            return func(x) - y_value

        intersections = []
        for x1, x2 in zip(fine_ages[:-1], fine_ages[1:]):
            try:
                result = root_scalar(offset_from_target, bracket=(x1, x2), method="brentq")
            except ValueError:
                # No sign change inside this bracket
                continue
            if not result.converged:
                continue
            # A root sitting exactly on a grid value is reported by both neighbouring
            # brackets; count it once so that it does not bias the average.
            if intersections and np.isclose(result.root, intersections[-1]):
                continue
            intersections.append(result.root)

        return np.mean(intersections) if intersections else None

    if method == "closest":
        y_values = func(fine_ages)
        closest_index = np.argmin(np.abs(y_values - y_value))
        return fine_ages[closest_index]

    raise ValueError("Invalid method. Choose 'root' or 'closest'.")

In [None]:
def calculate_cluster_sizes(transformed_df):
    """
    Count the rows per cluster separately for every source.

    Parameters:
        transformed_df (pd.DataFrame): DataFrame with the columns `source` and `clusters`.

    Returns:
        dict: {source: {cluster: number of rows}} with one entry per distinct source.
    """
    return {
        source: transformed_df[transformed_df["source"] == source].groupby("clusters").size().to_dict()
        for source in transformed_df["source"].unique()
    }

def _normalize_cluster_counts(cluster_dict):
    """Convert a {cluster: count} dictionary to proportions (all 0 when the total is not positive)."""
    total = sum(cluster_dict.values())
    return {cluster: (size / total if total > 0 else 0) for cluster, size in cluster_dict.items()}


def compute_weighted_cluster_sizes(ref_cluster_sizes, cluster_sizes, condition, sources, clusters):
    """
    Blend the cluster proportions of one condition with the mean reference proportions.

    For every source, the counts of the three references ("3mo_ref", "12mo_ref", "20mo_ref")
    and of `condition` are each normalized to proportions. The three reference proportions
    are averaged per cluster, and the result is the average of that reference mean and the
    proportion of the condition. Missing references, sources, or clusters count as 0.

    Parameters:
        ref_cluster_sizes (dict): Nested dictionary {reference key: {source: {cluster: count}}}
                                with the keys "3mo_ref", "12mo_ref", "20mo_ref".
        cluster_sizes (dict): Nested dictionary {condition: {source: {cluster: count}}}.
        condition (str): The key of `cluster_sizes` to use (e.g., "3mo_selected").
        sources (list): Sources to process (e.g., ["cKits", "HSCs"]).
        clusters (iterable): All possible clusters (e.g., [0,1,...,9]); any iterable of
                                hashable labels works, including NumPy arrays.

    Returns:
        dict: A dictionary mapping each source to a dictionary {cluster: blended proportion}.
    """
    reference_keys = ("3mo_ref", "12mo_ref", "20mo_ref")

    cluster_size_input = {}
    for source in sources:
        reference_props = [
            _normalize_cluster_counts(ref_cluster_sizes.get(key, {}).get(source, {}))
            for key in reference_keys
        ]
        selected_props = _normalize_cluster_counts(cluster_sizes.get(condition, {}).get(source, {}))

        blended = {}
        for cluster in clusters:
            # Mean proportion of this cluster across the three references
            reference_mean = sum(props.get(cluster, 0) for props in reference_props) / len(reference_keys)
            # Equal-weight average of the condition and the reference mean
            blended[cluster] = (selected_props.get(cluster, 0) + reference_mean) / 2
        cluster_size_input[source] = blended

    return cluster_size_input

In [None]:
# Arrange the proportion tables side by side, one column per table
def stack_proportions(*props_list):
    """
    Return an (n_clusters, n_tables) array of the "proportion" values.

    Each table is sorted by "clusters" first, so row i holds the i-th smallest cluster
    label of every table; column j belongs to the j-th argument.
    """
    columns = []
    for table in props_list:
        ordered = table.sort_values("clusters")
        columns.append(ordered["proportion"].values)
    return np.stack(columns, axis=1)


### 5.2 Training data generation and weight optimization

In [None]:
# ========================
# Hyper-parameters
# ========================
lambda_reg = 0.05  # L2 regularization strength (adjust as needed)

# ========================
# Transformed bone data (3mo, 12mo, 20mo, 5FU30d, 5FU60d)
# The string below is inert; it documents how to reload previously saved tables.
# ========================
""" if you have saved the transformed bone dataframes, you can load them here
transformed_3mo_bones_df = pd.read_csv(f"{results_dir}/transformed_3mo_bones_inside_df.csv")
transformed_12mo_bones_df = pd.read_csv(f"{results_dir}/transformed_12mo_bones_inside_df.csv")
transformed_20mo_bones_df = pd.read_csv(f"{results_dir}/transformed_20mo_bones_inside_df.csv")
transformed_5fu30d_bones_df = pd.read_csv(f"{results_dir}/transformed_5fu30d_bones_inside_df.csv")
transformed_5fu60d_bones_df = pd.read_csv(f"{results_dir}/transformed_5fu60d_bones_inside_df.csv")
"""

# Distinct dataset IDs of the three normal age groups
datasets_3mo, datasets_12mo, datasets_20mo = (
    frame["dataset"].unique()
    for frame in (transformed_3mo_bones_df, transformed_12mo_bones_df, transformed_20mo_bones_df)
)


# ========================
# Leave-one-out combinations: every triple made of one bone per age group.
# (With 4 bones per group this gives 4*4*4=64 iterations.)
# ========================
# Fix NumPy's global random seed for reproducibility
np.random.seed(42)
loo_combinations = list(itertools.product(datasets_3mo, datasets_12mo, datasets_20mo))
print(f"Total LOO iterations: {len(loo_combinations)}")

# ========================
# Loss and optimization helpers
# ========================

# True age (in months) of each normal condition
ground_truth = {
    "3mo": 3,
    "12mo": 12,
    "20mo": 20,
}


def total_loss(weights, optimal_input, ground_truth, reg_lambda):
    """
    Squared-error loss of the weighted age estimates plus an L2 penalty on the weights.

    For every condition in `ground_truth`, the estimate is the dot product of `weights`
    with that condition's values in `optimal_input`; its squared deviation from the true
    age is accumulated, and reg_lambda * sum(weights ** 2) is added at the end.

    weights: numpy array with one weight per value in each `optimal_input` entry.
    optimal_input: dict mapping condition ("3mo", "12mo", "20mo") to a list of numbers.
    ground_truth: dict mapping condition to true age.
    reg_lambda: regularization parameter.
    """
    squared_errors = (
        (np.dot(weights, np.array(optimal_input[condition])) - target) ** 2
        for condition, target in ground_truth.items()
    )
    data_term = sum(squared_errors)
    penalty = reg_lambda * np.sum(np.square(weights))
    return data_term + penalty

def weight_constraint(weights):
    """
    Residual of the "weights sum to one" constraint.

    Returns ``sum(weights) - 1``, which is zero exactly when the weights
    add up to 1.
    """
    weights_total = np.sum(weights)
    return weights_total - 1



### 5.3 Training data generation for the Gaussian model

In [None]:
# ========================
# Main LOO-CV Loop (setup)
# ========================
# The training data for the prediction model is generated with LOO-CV
# rather than from interpolation assumptions.
epoch_index = 1
affinity_flag = "CE"  # "CE" or "MSE"

# One dict per (epoch, age group) is appended to this list inside the loop;
# the keys of those dicts become the columns of the training dataframe.
training_data_list = []

# Folder that holds the affinity matrices of every condition
affinity_matrices_dir = os.path.join(data_dir, "affinity_matrices")

# Load the per-condition summed affinity matrices, keyed by condition name
affinity_matrices = {}
for f in os.listdir(affinity_matrices_dir):
    # Only files named "affinity*sum.csv" are of interest here
    if f.startswith("affinity") and f.endswith("sum.csv"):
        # The CSV has no header row; label the columns 0-9
        affinity_matrix = pd.read_csv(
            os.path.join(affinity_matrices_dir, f),
            header=None,
        )
        affinity_matrix.columns = range(10)
        # The condition name is the third "_"-separated token of the file name
        condition_name = f.split("_")[2]
        affinity_matrices[condition_name] = affinity_matrix


data_all = pd.DataFrame()
# Main LOO-CV loop. Each iteration holds out one bone ("dataset") per age group as the
# selected/test bone and uses the remaining bones of that group as the reference. The
# features of every held-out bone are computed against all three references, and three
# rows (ground truth 3, 12 and 20) are appended to training_data_list.
#
# Naming convention used below: <feature>_<A>_<B> compares the held-out bone of age
# group A with the reference built from age group B.
for test_3mo, test_12mo, test_20mo in loo_combinations:
    # Create a subfolder for this epoch
    epoch_folder = os.path.join(results_dir, f"epoch_{epoch_index}")
    if not os.path.exists(epoch_folder):
        os.makedirs(epoch_folder)
    
    print(f"Epoch {epoch_index}: Test datasets: 3mo={test_3mo}, 12mo={test_12mo}, 20mo={test_20mo}")

    # -------------------------------
    # 1. Separate test and reference data
    # -------------------------------
    # For each age group, test data = rows with dataset == test_x, reference = rows with dataset != test_x.
    selected_3mo_df = transformed_3mo_bones_df[transformed_3mo_bones_df["dataset"] == test_3mo]
    selected_12mo_df = transformed_12mo_bones_df[transformed_12mo_bones_df["dataset"] == test_12mo]
    selected_20mo_df = transformed_20mo_bones_df[transformed_20mo_bones_df["dataset"] == test_20mo]

    # Save the selected test datasets for record
    selected_3mo_df.to_csv(os.path.join(epoch_folder, "selected_3mo_bones_inside_df.csv"), index=False)
    selected_12mo_df.to_csv(os.path.join(epoch_folder, "selected_12mo_bones_inside_df.csv"), index=False)
    selected_20mo_df.to_csv(os.path.join(epoch_folder, "selected_20mo_bones_inside_df.csv"), index=False)
    
    # Define reference datasets (exclude the test dataset)
    ref_3mo_df = transformed_3mo_bones_df[transformed_3mo_bones_df["dataset"] != test_3mo]
    ref_12mo_df = transformed_12mo_bones_df[transformed_12mo_bones_df["dataset"] != test_12mo]
    ref_20mo_df = transformed_20mo_bones_df[transformed_20mo_bones_df["dataset"] != test_20mo]
    
    # -------------------------------
    # 2. Compute the building blocks of the features (KDE, HSC cluster composition,
    #    affinity matrices) for both the reference and the selected (held-out) bones
    # -------------------------------
    
    # kde_for_clusters returns a pair; only its first element (the KDE results) is kept.
    # The same outline (ref_outline) and bin size are used for every call.
    kde_results_3mo_ref, _ = kde_for_clusters(ref_3mo_df, bone_outline=ref_outline, binsize=10)
    kde_results_12mo_ref, _ = kde_for_clusters(ref_12mo_df, bone_outline=ref_outline, binsize=10)
    kde_results_20mo_ref, _ = kde_for_clusters(ref_20mo_df, bone_outline=ref_outline, binsize=10)
    
    kde_results_3mo_selected, _ = kde_for_clusters(selected_3mo_df, bone_outline=ref_outline, binsize=10)
    kde_results_12mo_selected, _ = kde_for_clusters(selected_12mo_df, bone_outline=ref_outline, binsize=10)
    kde_results_20mo_selected, _ = kde_for_clusters(selected_20mo_df, bone_outline=ref_outline, binsize=10)
    
    # Optional: Save the KDE results (disabled; kept as a bare string)
    """
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_ref.pkl"), "wb") as f:
        pickle.dump(kde_results_3mo_ref, f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_ref.pkl"), "wb") as f:
        pickle.dump(kde_results_12mo_ref, f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_ref.pkl"), "wb") as f:
        pickle.dump(kde_results_20mo_ref, f)
        
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_selected.pkl"), "wb") as f:
        pickle.dump(kde_results_3mo_selected, f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_selected.pkl"), "wb") as f:
        pickle.dump(kde_results_12mo_selected, f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_selected.pkl"), "wb") as f:
        pickle.dump(kde_results_20mo_selected, f)
    """
    
    # Optional: reload previously saved KDE results instead of recomputing them
    # (disabled; kept as a bare string)
    """
    # Read the KDE results from the saved files
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_ref.pkl"), "rb") as f:
        kde_results_3mo_ref = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_ref.pkl"), "rb") as f:
        kde_results_12mo_ref = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_ref.pkl"), "rb") as f:
        kde_results_20mo_ref = pickle.load(f)
        
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_selected.pkl"), "rb") as f:
        kde_results_3mo_selected = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_selected.pkl"), "rb") as f:
        kde_results_12mo_selected = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_selected.pkl"), "rb") as f:
        kde_results_20mo_selected = pickle.load(f)
    """  
    # Normalize the KDE results of the reference bones. The normalized results are
    # indexed below with ("cKits", cluster) and with "HSCs".
    kde_results_3mo_ref_normalized = normalize_kde_values(kde_results_3mo_ref)
    kde_results_12mo_ref_normalized = normalize_kde_values(kde_results_12mo_ref)
    kde_results_20mo_ref_normalized = normalize_kde_values(kde_results_20mo_ref)
    
    # Normalize the KDE results for selected bones
    kde_results_3mo_selected_normalized = normalize_kde_values(kde_results_3mo_selected)
    kde_results_12mo_selected_normalized = normalize_kde_values(kde_results_12mo_selected)
    kde_results_20mo_selected_normalized = normalize_kde_values(kde_results_20mo_selected)

    
    # Instead of using hsc histograms, we will use hsc cluster composition instead.
    # The results are read below through their "clusters" and "proportion" columns.
    hsc_cluster_comp_3mo_ref = compute_proportions_hsc(ref_3mo_df)
    hsc_cluster_comp_12mo_ref = compute_proportions_hsc(ref_12mo_df)
    hsc_cluster_comp_20mo_ref = compute_proportions_hsc(ref_20mo_df)
    hsc_cluster_comp_3mo_selected = compute_proportions_hsc(selected_3mo_df)
    hsc_cluster_comp_12mo_selected = compute_proportions_hsc(selected_12mo_df)
    hsc_cluster_comp_20mo_selected = compute_proportions_hsc(selected_20mo_df)
    
    # Optional: save the HSC cluster composition results as csv (disabled)
    """
    hsc_cluster_comp_3mo_ref.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_3mo_ref.csv"), index=False)
    hsc_cluster_comp_12mo_ref.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_12mo_ref.csv"), index=False)
    hsc_cluster_comp_20mo_ref.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_20mo_ref.csv"), index=False)
    hsc_cluster_comp_3mo_selected.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_3mo_selected.csv"), index=False)
    hsc_cluster_comp_12mo_selected.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_12mo_selected.csv"), index=False)
    hsc_cluster_comp_20mo_selected.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_20mo_selected.csv"), index=False)
    """

    # -------------------------------
    # Process affinity matrices: load the held-out bones' matrices and derive
    # the reference matrices from the group-level ones.
    # -------------------------------
    affinity_matrices_selected = {}
    affinity_mat_selected_dict = {"3mo": test_3mo, "12mo": test_12mo, "20mo": test_20mo}

    # Read the affinity matrices for the selected datasets
    for age, bone_name in affinity_mat_selected_dict.items():
        # Read the affinity matrix without header or index.
        # NOTE(review): these per-bone files are parsed as whitespace-delimited, while the
        # "*sum.csv" files loaded before the loop use the default comma delimiter - confirm
        # that both file formats really differ.
        affinity_matrix = pd.read_csv(f"{affinity_matrices_dir}/affinity_age_{age}_bone_{bone_name}.csv", header=None, delimiter=r"\s+")
        affinity_matrix.columns = range(10)
        affinity_matrices_selected[age] = affinity_matrix
    

    # Reference matrix = (group-level matrix - held-out bone's matrix) / 3.
    # NOTE(review): presumably affinity_matrices[age] is the sum over all bones of the age
    # group and 3 reference bones remain (4 bones per group) - confirm. The constant factor
    # cancels in the row normalization below (up to floating-point rounding).
    # .get() returns None for a missing age key, in which case the subtraction raises TypeError.
    affinity_matrices_3mo_ref = (affinity_matrices.get("3mo") - affinity_matrices_selected["3mo"]) / 3
    affinity_matrices_12mo_ref = (affinity_matrices.get("12mo") - affinity_matrices_selected["12mo"]) / 3
    affinity_matrices_20mo_ref = (affinity_matrices.get("20mo") - affinity_matrices_selected["20mo"]) / 3
    
    # Normalize the reference affinity matrices by row so that every row sums to 1
    # (the rows are compared with calculate_cross_entropy in step 3.1.5)
    affinity_matrices_3mo_ref = affinity_matrices_3mo_ref.div(affinity_matrices_3mo_ref.sum(axis=1), axis=0)
    affinity_matrices_12mo_ref = affinity_matrices_12mo_ref.div(affinity_matrices_12mo_ref.sum(axis=1), axis=0)
    affinity_matrices_20mo_ref = affinity_matrices_20mo_ref.div(affinity_matrices_20mo_ref.sum(axis=1), axis=0)

    # Optional: save the reference affinity matrices to CSV files (without header or index; disabled)
    """
    affinity_matrices_3mo_ref.to_csv(os.path.join(epoch_folder, "affinity_matrices_3mo_ref_normalized.csv"), index=False, header=False)
    affinity_matrices_12mo_ref.to_csv(os.path.join(epoch_folder, "affinity_matrices_12mo_ref_normalized.csv"), index=False, header=False)
    affinity_matrices_20mo_ref.to_csv(os.path.join(epoch_folder, "affinity_matrices_20mo_ref_normalized.csv"), index=False, header=False)
    """
    # Normalize the selected affinity matrices by row (for CE and MSE).
    # This must happen after the reference matrices were derived above, because the
    # subtraction there needs the un-normalized selected matrices.
    affinity_matrices_selected["3mo"] = affinity_matrices_selected["3mo"].div(affinity_matrices_selected["3mo"].sum(axis=1), axis=0)
    affinity_matrices_selected["12mo"] = affinity_matrices_selected["12mo"].div(affinity_matrices_selected["12mo"].sum(axis=1), axis=0)
    affinity_matrices_selected["20mo"] = affinity_matrices_selected["20mo"].div(affinity_matrices_selected["20mo"].sum(axis=1), axis=0)
    
    # Optional: save the affinity matrices for the selected datasets (disabled)
    """
    affinity_matrices_selected["3mo"].to_csv(os.path.join(epoch_folder, "affinity_matrices_3mo_selected_normalized.csv"), index=False, header=False)
    affinity_matrices_selected["12mo"].to_csv(os.path.join(epoch_folder, "affinity_matrices_12mo_selected_normalized.csv"), index=False, header=False)
    affinity_matrices_selected["20mo"].to_csv(os.path.join(epoch_folder, "affinity_matrices_20mo_selected_normalized.csv"), index=False, header=False)
    """


    # -------------------------------
    # 3. Compute features for the selected datasets (test data)
    # -------------------------------
    clusters = list(range(10))
    # 3.1 Compute feature values (without using the weights) for all selected datasets

    # 3.1.1 Calculate the cKit KDE-based feature values (KL divergence), one value per
    # cluster; each dict maps cluster index -> divergence(selected, reference)
    cKit_3mo_3mo_kl_divs = {}
    cKit_3mo_12mo_kl_divs = {}
    cKit_3mo_20mo_kl_divs = {}

    cKit_12mo_3mo_kl_divs = {}
    cKit_12mo_12mo_kl_divs = {}
    cKit_12mo_20mo_kl_divs = {}

    cKit_20mo_3mo_kl_divs = {}
    cKit_20mo_12mo_kl_divs = {}
    cKit_20mo_20mo_kl_divs = {}

    for cluster in clusters:
        cKit_3mo_3mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_3mo_selected_normalized[("cKits", cluster)], kde_results_3mo_ref_normalized[("cKits", cluster)])
        cKit_3mo_12mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_3mo_selected_normalized[("cKits", cluster)], kde_results_12mo_ref_normalized[("cKits", cluster)])
        cKit_3mo_20mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_3mo_selected_normalized[("cKits", cluster)], kde_results_20mo_ref_normalized[("cKits", cluster)])

        cKit_12mo_3mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_12mo_selected_normalized[("cKits", cluster)], kde_results_3mo_ref_normalized[("cKits", cluster)])
        cKit_12mo_12mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_12mo_selected_normalized[("cKits", cluster)], kde_results_12mo_ref_normalized[("cKits", cluster)])
        cKit_12mo_20mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_12mo_selected_normalized[("cKits", cluster)], kde_results_20mo_ref_normalized[("cKits", cluster)])

        cKit_20mo_3mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_20mo_selected_normalized[("cKits", cluster)], kde_results_3mo_ref_normalized[("cKits", cluster)])
        cKit_20mo_12mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_20mo_selected_normalized[("cKits", cluster)], kde_results_12mo_ref_normalized[("cKits", cluster)])
        cKit_20mo_20mo_kl_divs[cluster] = calculate_kl_divergence(kde_results_20mo_selected_normalized[("cKits", cluster)], kde_results_20mo_ref_normalized[("cKits", cluster)])



    
    # 3.1.2 Calculate the HSC cluster composition based features (jsd)
    # Proportion vectors are ordered by the "clusters" column before comparison.
    # NOTE(review): assumes compute_proportions_hsc returns one row per cluster for every
    # input, so that all vectors have equal length and aligned clusters - confirm.
    p_3mo_ref = hsc_cluster_comp_3mo_ref.sort_values("clusters")["proportion"].values
    p_12mo_ref = hsc_cluster_comp_12mo_ref.sort_values("clusters")["proportion"].values
    p_20mo_ref = hsc_cluster_comp_20mo_ref.sort_values("clusters")["proportion"].values
    
    p_3mo_selected = hsc_cluster_comp_3mo_selected.sort_values("clusters")["proportion"].values
    p_12mo_selected = hsc_cluster_comp_12mo_selected.sort_values("clusters")["proportion"].values
    p_20mo_selected = hsc_cluster_comp_20mo_selected.sort_values("clusters")["proportion"].values
    
    # scipy's jensenshannon returns the Jensen-Shannon distance; squaring it yields the
    # Jensen-Shannon divergence.
    HSC_3mo_3mo_cluster_comp = jensenshannon(p_3mo_selected, p_3mo_ref) **2
    HSC_3mo_12mo_cluster_comp = jensenshannon(p_3mo_selected, p_12mo_ref) **2
    HSC_3mo_20mo_cluster_comp = jensenshannon(p_3mo_selected, p_20mo_ref) **2
    HSC_12mo_3mo_cluster_comp = jensenshannon(p_12mo_selected, p_3mo_ref) **2
    HSC_12mo_12mo_cluster_comp = jensenshannon(p_12mo_selected, p_12mo_ref) **2
    HSC_12mo_20mo_cluster_comp = jensenshannon(p_12mo_selected, p_20mo_ref) **2
    HSC_20mo_3mo_cluster_comp = jensenshannon(p_20mo_selected, p_3mo_ref) **2
    HSC_20mo_12mo_cluster_comp = jensenshannon(p_20mo_selected, p_12mo_ref) **2
    HSC_20mo_20mo_cluster_comp = jensenshannon(p_20mo_selected, p_20mo_ref) **2
    
    
    # 3.1.3 Calculate the HSC numbers (rows whose "source" is "HSCs" in each held-out bone)
    HSC_3mo_selected_num = selected_3mo_df[selected_3mo_df["source"] == "HSCs"].shape[0]
    HSC_12mo_selected_num = selected_12mo_df[selected_12mo_df["source"] == "HSCs"].shape[0]
    HSC_20mo_selected_num = selected_20mo_df[selected_20mo_df["source"] == "HSCs"].shape[0]


    # 3.1.4 Calculate the HSC KDE-based feature values (KL divergence); unlike the cKit
    # features these use a single "HSCs" entry rather than one entry per cluster
    HSC_3mo_3mo_kl_div = calculate_kl_divergence(kde_results_3mo_selected_normalized["HSCs"], kde_results_3mo_ref_normalized["HSCs"])
    HSC_3mo_12mo_kl_div = calculate_kl_divergence(kde_results_3mo_selected_normalized["HSCs"], kde_results_12mo_ref_normalized["HSCs"])
    HSC_3mo_20mo_kl_div = calculate_kl_divergence(kde_results_3mo_selected_normalized["HSCs"], kde_results_20mo_ref_normalized["HSCs"])

    HSC_12mo_3mo_kl_div = calculate_kl_divergence(kde_results_12mo_selected_normalized["HSCs"], kde_results_3mo_ref_normalized["HSCs"])
    HSC_12mo_12mo_kl_div = calculate_kl_divergence(kde_results_12mo_selected_normalized["HSCs"], kde_results_12mo_ref_normalized["HSCs"])
    HSC_12mo_20mo_kl_div = calculate_kl_divergence(kde_results_12mo_selected_normalized["HSCs"], kde_results_20mo_ref_normalized["HSCs"])

    HSC_20mo_3mo_kl_div = calculate_kl_divergence(kde_results_20mo_selected_normalized["HSCs"], kde_results_3mo_ref_normalized["HSCs"])
    HSC_20mo_12mo_kl_div = calculate_kl_divergence(kde_results_20mo_selected_normalized["HSCs"], kde_results_12mo_ref_normalized["HSCs"])
    HSC_20mo_20mo_kl_div = calculate_kl_divergence(kde_results_20mo_selected_normalized["HSCs"], kde_results_20mo_ref_normalized["HSCs"])


    # 3.1.5 Calculate the cKit affinity-based feature values (Cross Entropy): row
    # `cluster` of the selected matrix is compared with the same row of the reference matrix
    cKit_3mo_3mo_ces = {}
    cKit_3mo_12mo_ces = {}
    cKit_3mo_20mo_ces = {}

    cKit_12mo_3mo_ces = {}
    cKit_12mo_12mo_ces = {}
    cKit_12mo_20mo_ces = {}

    cKit_20mo_3mo_ces = {}
    cKit_20mo_12mo_ces = {}
    cKit_20mo_20mo_ces = {}

    for cluster in clusters:
        # Matched comparisons
        cKit_3mo_3mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["3mo"].iloc[cluster, :], affinity_matrices_3mo_ref.iloc[cluster, :])
        cKit_12mo_12mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["12mo"].iloc[cluster, :], affinity_matrices_12mo_ref.iloc[cluster, :])
        cKit_20mo_20mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["20mo"].iloc[cluster, :], affinity_matrices_20mo_ref.iloc[cluster, :])

        # Unmatched comparisons
        cKit_3mo_12mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["3mo"].iloc[cluster, :], affinity_matrices_12mo_ref.iloc[cluster, :])
        cKit_3mo_20mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["3mo"].iloc[cluster, :], affinity_matrices_20mo_ref.iloc[cluster, :])

        cKit_12mo_3mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["12mo"].iloc[cluster, :], affinity_matrices_3mo_ref.iloc[cluster, :])
        cKit_12mo_20mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["12mo"].iloc[cluster, :], affinity_matrices_20mo_ref.iloc[cluster, :])

        cKit_20mo_3mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["20mo"].iloc[cluster, :], affinity_matrices_3mo_ref.iloc[cluster, :])
        cKit_20mo_12mo_ces[cluster] = calculate_cross_entropy(affinity_matrices_selected["20mo"].iloc[cluster, :], affinity_matrices_12mo_ref.iloc[cluster, :])
        
        
    
    # 3.2 Compute the cluster size and cluster weights for the selected datasets
    # Compute cluster sizes for each selected dataset using the function calculate_cluster_sizes
    cluster_sizes_selected = {}
    for transformed_df, condition in zip([selected_3mo_df, selected_12mo_df, selected_20mo_df],
                                        ["3mo_selected", "12mo_selected", "20mo_selected"]):
        cluster_sizes_selected[condition] = calculate_cluster_sizes(transformed_df)

    # Compute cluster sizes for the reference data (for normalization)
    cluster_sizes_ref = {}
    for transformed_df, condition in zip([ref_3mo_df, ref_12mo_df, ref_20mo_df],
                                        ["3mo_ref", "12mo_ref", "20mo_ref"]):
        cluster_sizes_ref[condition] = calculate_cluster_sizes(transformed_df)
    
    # Define the sources and clusters to use for weighted averaging.
    sources = ["cKits", "HSCs"]  # adjust if you use additional sources
    clusters_list = list(range(10))
    
    # Compute weighted cluster sizes for each age condition using compute_weighted_cluster_sizes.
    cluster_size_3mo_selected = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                            cluster_sizes=cluster_sizes_selected,
                                                            condition="3mo_selected",
                                                            sources=sources,
                                                            clusters=clusters_list)
    cluster_size_12mo_selected = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                            cluster_sizes=cluster_sizes_selected,
                                                            condition="12mo_selected",
                                                            sources=sources,
                                                            clusters=clusters_list)
    cluster_size_20mo_selected = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                            cluster_sizes=cluster_sizes_selected,
                                                            condition="20mo_selected",
                                                            sources=sources,
                                                            clusters=clusters_list)
    
    # Optionally, save the computed cluster sizes for the selected datasets for later inspection (disabled).
    """
    with open(os.path.join(epoch_folder, "cluster_sizes_selected.pkl"), "wb") as f:
        pickle.dump(cluster_sizes_selected, f)
    with open(os.path.join(epoch_folder, "weighted_cluster_size_3mo_selected.pkl"), "wb") as f:
        pickle.dump(cluster_size_3mo_selected, f)
    with open(os.path.join(epoch_folder, "weighted_cluster_size_12mo_selected.pkl"), "wb") as f:
        pickle.dump(cluster_size_12mo_selected, f)
    with open(os.path.join(epoch_folder, "weighted_cluster_size_20mo_selected.pkl"), "wb") as f:
        pickle.dump(cluster_size_20mo_selected, f)
    """
    
    
    # 3.3 Collapse the per-cluster values into one feature value per comparison.
    # cluster_size_<age>_selected is indexed with "cKits" below and its values serve as the
    # averaging weights; presumably it is a dict with keys "cKits" and "HSCs" that map
    # cluster indices (0..9) to normalized weights - confirm in compute_weighted_cluster_sizes.
    # NOTE(review): np.average pairs values and weights by position, so this relies on both
    # dicts iterating over the clusters in the same order (0..9) - confirm.
    # -------------------------------
    # Weighted average for cKit KDE-based features (KL divergence)
    average_cKit_3mo_3mo_kl_div = np.average(list(cKit_3mo_3mo_kl_divs.values()), 
                                            weights=list(cluster_size_3mo_selected["cKits"].values()))
    average_cKit_3mo_12mo_kl_div = np.average(list(cKit_3mo_12mo_kl_divs.values()), 
                                            weights=list(cluster_size_3mo_selected["cKits"].values()))
    average_cKit_3mo_20mo_kl_div = np.average(list(cKit_3mo_20mo_kl_divs.values()), 
                                            weights=list(cluster_size_3mo_selected["cKits"].values()))

    average_cKit_12mo_3mo_kl_div = np.average(list(cKit_12mo_3mo_kl_divs.values()), 
                                            weights=list(cluster_size_12mo_selected["cKits"].values()))
    average_cKit_12mo_12mo_kl_div = np.average(list(cKit_12mo_12mo_kl_divs.values()), 
                                            weights=list(cluster_size_12mo_selected["cKits"].values()))
    average_cKit_12mo_20mo_kl_div = np.average(list(cKit_12mo_20mo_kl_divs.values()), 
                                            weights=list(cluster_size_12mo_selected["cKits"].values()))

    average_cKit_20mo_3mo_kl_div = np.average(list(cKit_20mo_3mo_kl_divs.values()), 
                                            weights=list(cluster_size_20mo_selected["cKits"].values()))
    average_cKit_20mo_12mo_kl_div = np.average(list(cKit_20mo_12mo_kl_divs.values()), 
                                            weights=list(cluster_size_20mo_selected["cKits"].values()))
    average_cKit_20mo_20mo_kl_div = np.average(list(cKit_20mo_20mo_kl_divs.values()), 
                                            weights=list(cluster_size_20mo_selected["cKits"].values()))


    



    # Weighted average for cKit affinity-based features (Cross Entropy); the weights are
    # those of the held-out bone, i.e. the same as for the KL-divergence features above
    average_cKit_3mo_3mo_ce = np.average(list(cKit_3mo_3mo_ces.values()), 
                                        weights=list(cluster_size_3mo_selected["cKits"].values()))
    average_cKit_3mo_12mo_ce = np.average(list(cKit_3mo_12mo_ces.values()), 
                                        weights=list(cluster_size_3mo_selected["cKits"].values()))
    average_cKit_3mo_20mo_ce = np.average(list(cKit_3mo_20mo_ces.values()), 
                                        weights=list(cluster_size_3mo_selected["cKits"].values()))

    average_cKit_12mo_3mo_ce = np.average(list(cKit_12mo_3mo_ces.values()), 
                                        weights=list(cluster_size_12mo_selected["cKits"].values()))
    average_cKit_12mo_12mo_ce = np.average(list(cKit_12mo_12mo_ces.values()), 
                                        weights=list(cluster_size_12mo_selected["cKits"].values()))
    average_cKit_12mo_20mo_ce = np.average(list(cKit_12mo_20mo_ces.values()), 
                                        weights=list(cluster_size_12mo_selected["cKits"].values()))

    average_cKit_20mo_3mo_ce = np.average(list(cKit_20mo_3mo_ces.values()), 
                                        weights=list(cluster_size_20mo_selected["cKits"].values()))
    average_cKit_20mo_12mo_ce = np.average(list(cKit_20mo_12mo_ces.values()), 
                                        weights=list(cluster_size_20mo_selected["cKits"].values()))
    average_cKit_20mo_20mo_ce = np.average(list(cKit_20mo_20mo_ces.values()), 
                                        weights=list(cluster_size_20mo_selected["cKits"].values()))
    
    # 4. Add features and ground truth values to the training data dataframe
    # -------------------------------
    # Construct the training data rows for this epoch: one row per held-out bone, all
    # three sharing the same column names and the same "epoch" value


    training_data_row_3mo = {
        "epoch": epoch_index,
        "cKit Density Divergence (vs. 3mo)": average_cKit_3mo_3mo_kl_div,
        "cKit Density Divergence (vs. 12mo)": average_cKit_3mo_12mo_kl_div,
        "cKit Density Divergence (vs. 20mo)": average_cKit_3mo_20mo_kl_div,
        "HSC Density Divergence (vs. 3mo)": HSC_3mo_3mo_kl_div,
        "HSC Density Divergence (vs. 12mo)": HSC_3mo_12mo_kl_div,
        "HSC Density Divergence (vs. 20mo)": HSC_3mo_20mo_kl_div,
        "HSC Count": HSC_3mo_selected_num,
        "HSC Composition (vs. 3mo)": HSC_3mo_3mo_cluster_comp,
        "HSC Composition (vs. 12mo)": HSC_3mo_12mo_cluster_comp,
        "HSC Composition (vs. 20mo)": HSC_3mo_20mo_cluster_comp,
        "cKit Neighborhood Affinity (vs. 3mo)": average_cKit_3mo_3mo_ce,
        "cKit Neighborhood Affinity (vs. 12mo)": average_cKit_3mo_12mo_ce,
        "cKit Neighborhood Affinity (vs. 20mo)": average_cKit_3mo_20mo_ce,
        "ground_truth": 3
    }
    training_data_row_12mo = {
        "epoch": epoch_index,
        "cKit Density Divergence (vs. 3mo)": average_cKit_12mo_3mo_kl_div,
        "cKit Density Divergence (vs. 12mo)": average_cKit_12mo_12mo_kl_div,
        "cKit Density Divergence (vs. 20mo)": average_cKit_12mo_20mo_kl_div,
        "HSC Density Divergence (vs. 3mo)": HSC_12mo_3mo_kl_div,
        "HSC Density Divergence (vs. 12mo)": HSC_12mo_12mo_kl_div,
        "HSC Density Divergence (vs. 20mo)": HSC_12mo_20mo_kl_div,
        "HSC Count": HSC_12mo_selected_num,
        "HSC Composition (vs. 3mo)": HSC_12mo_3mo_cluster_comp,
        "HSC Composition (vs. 12mo)": HSC_12mo_12mo_cluster_comp,
        "HSC Composition (vs. 20mo)": HSC_12mo_20mo_cluster_comp,
        "cKit Neighborhood Affinity (vs. 3mo)": average_cKit_12mo_3mo_ce,
        "cKit Neighborhood Affinity (vs. 12mo)": average_cKit_12mo_12mo_ce,
        "cKit Neighborhood Affinity (vs. 20mo)": average_cKit_12mo_20mo_ce,
        "ground_truth": 12
    }
    training_data_row_20mo = {
        "epoch": epoch_index,
        "cKit Density Divergence (vs. 3mo)": average_cKit_20mo_3mo_kl_div,
        "cKit Density Divergence (vs. 12mo)": average_cKit_20mo_12mo_kl_div,
        "cKit Density Divergence (vs. 20mo)": average_cKit_20mo_20mo_kl_div,
        "HSC Density Divergence (vs. 3mo)": HSC_20mo_3mo_kl_div,
        "HSC Density Divergence (vs. 12mo)": HSC_20mo_12mo_kl_div,
        "HSC Density Divergence (vs. 20mo)": HSC_20mo_20mo_kl_div,
        "HSC Count": HSC_20mo_selected_num,
        "HSC Composition (vs. 3mo)": HSC_20mo_3mo_cluster_comp,
        "HSC Composition (vs. 12mo)": HSC_20mo_12mo_cluster_comp,
        "HSC Composition (vs. 20mo)": HSC_20mo_20mo_cluster_comp,
        "cKit Neighborhood Affinity (vs. 3mo)": average_cKit_20mo_3mo_ce,
        "cKit Neighborhood Affinity (vs. 12mo)": average_cKit_20mo_12mo_ce,
        "cKit Neighborhood Affinity (vs. 20mo)": average_cKit_20mo_20mo_ce,
        "ground_truth": 20
    }


    # Append rows to the list
    training_data_list.append(training_data_row_3mo)
    training_data_list.append(training_data_row_12mo)
    training_data_list.append(training_data_row_20mo)
    print(f"Epoch {epoch_index} completed.")
    
    # Manual epoch counter (1-based); after the loop it equals the number of epochs + 1
    epoch_index += 1
# End of LOO-CV loop

# Assemble all collected rows (three per epoch) into the training dataframe
training_data_df = pd.DataFrame(training_data_list)
# training_data_df.to_csv(os.path.join(results_dir, "training_data_cluster_comp.csv"), index=False)

### 5.4 Weight optimization for the linear model

In [None]:
# Use LOO-CV to estimate the optimal weights for each epoch (with the interpolation assumption)
# ========================
# Main LOO-CV Loop
# ========================
epoch_index = 1
affinity_flag = "CE" # cross entropy
for test_3mo, test_12mo, test_20mo in loo_combinations:
    # Create a subfolder for this epoch
    epoch_folder = os.path.join(results_dir, f"epoch_{epoch_index}")
    if not os.path.exists(epoch_folder):
        os.makedirs(epoch_folder)
    
    print(f"Epoch {epoch_index}: Test datasets: 3mo={test_3mo}, 12mo={test_12mo}, 20mo={test_20mo}")

    # -------------------------------
    # 1. Separate test and reference data
    # -------------------------------
    # For each age group, test data = rows with dataset == test_x, reference = rows with dataset != test_x.
    selected_3mo_df = transformed_3mo_bones_df[transformed_3mo_bones_df["dataset"] == test_3mo]
    selected_12mo_df = transformed_12mo_bones_df[transformed_12mo_bones_df["dataset"] == test_12mo]
    selected_20mo_df = transformed_20mo_bones_df[transformed_20mo_bones_df["dataset"] == test_20mo]

    # Optional: Save the selected test datasets for record
    # selected_3mo_df.to_csv(os.path.join(epoch_folder, "selected_3mo_bones_inside_df.csv"), index=False)
    # selected_12mo_df.to_csv(os.path.join(epoch_folder, "selected_12mo_bones_inside_df.csv"), index=False)
    # selected_20mo_df.to_csv(os.path.join(epoch_folder, "selected_20mo_bones_inside_df.csv"), index=False)
    
    # Define reference datasets (exclude the test dataset)
    ref_3mo_df = transformed_3mo_bones_df[transformed_3mo_bones_df["dataset"] != test_3mo]
    ref_12mo_df = transformed_12mo_bones_df[transformed_12mo_bones_df["dataset"] != test_12mo]
    ref_20mo_df = transformed_20mo_bones_df[transformed_20mo_bones_df["dataset"] != test_20mo]
    
    # -------------------------------
    # 2. Compute the features of the reference bones and of the held-out (selected) bones
    # -------------------------------

    reference_frames = (ref_3mo_df, ref_12mo_df, ref_20mo_df)
    selected_frames = (selected_3mo_df, selected_12mo_df, selected_20mo_df)

    # KDE results per age group; every call uses the same reference outline and bin size.
    # Only the first element returned by kde_for_clusters is kept.
    kde_results_3mo_ref, kde_results_12mo_ref, kde_results_20mo_ref = [
        kde_for_clusters(age_df, bone_outline=ref_outline, binsize=10)[0]
        for age_df in reference_frames
    ]
    kde_results_3mo_selected, kde_results_12mo_selected, kde_results_20mo_selected = [
        kde_for_clusters(age_df, bone_outline=ref_outline, binsize=10)[0]
        for age_df in selected_frames
    ]

    # Disabled snippet (a bare string literal, never executed): pickle the KDE results.
    """
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_ref.pkl"), "wb") as f:
        pickle.dump(kde_results_3mo_ref, f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_ref.pkl"), "wb") as f:
        pickle.dump(kde_results_12mo_ref, f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_ref.pkl"), "wb") as f:
        pickle.dump(kde_results_20mo_ref, f)
        
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_selected.pkl"), "wb") as f:
        pickle.dump(kde_results_3mo_selected, f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_selected.pkl"), "wb") as f:
        pickle.dump(kde_results_12mo_selected, f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_selected.pkl"), "wb") as f:
        pickle.dump(kde_results_20mo_selected, f)
    """

    # Disabled snippet (a bare string literal, never executed): reload pickled KDE results.
    """ If you have saved the kde results, you can load them here (approx. 17 hours)
    # Read the KDE results from the saved files
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_ref.pkl"), "rb") as f:
        kde_results_3mo_ref = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_ref.pkl"), "rb") as f:
        kde_results_12mo_ref = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_ref.pkl"), "rb") as f:
        kde_results_20mo_ref = pickle.load(f)
        
    with open(os.path.join(epoch_folder, "kde_results_3mo_bones_selected.pkl"), "rb") as f:
        kde_results_3mo_selected = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_12mo_bones_selected.pkl"), "rb") as f:
        kde_results_12mo_selected = pickle.load(f)
    with open(os.path.join(epoch_folder, "kde_results_20mo_bones_selected.pkl"), "rb") as f:
        kde_results_20mo_selected = pickle.load(f)
    """

    # HSC cluster composition (used instead of HSC histograms), reference first, then selected.
    (
        hsc_cluster_comp_3mo_ref,
        hsc_cluster_comp_12mo_ref,
        hsc_cluster_comp_20mo_ref,
        hsc_cluster_comp_3mo_selected,
        hsc_cluster_comp_12mo_selected,
        hsc_cluster_comp_20mo_selected,
    ) = [compute_proportions_hsc(age_df) for age_df in reference_frames + selected_frames]

    # Disabled snippet (a bare string literal, never executed): write the compositions to CSV.
    """
    hsc_cluster_comp_3mo_ref.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_3mo_ref.csv"), index=False)
    hsc_cluster_comp_12mo_ref.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_12mo_ref.csv"), index=False)
    hsc_cluster_comp_20mo_ref.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_20mo_ref.csv"), index=False)
    hsc_cluster_comp_3mo_selected.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_3mo_selected.csv"), index=False)
    hsc_cluster_comp_12mo_selected.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_12mo_selected.csv"), index=False)
    hsc_cluster_comp_20mo_selected.to_csv(os.path.join(epoch_folder, "hsc_cluster_comp_20mo_selected.csv"), index=False)
    """
    
    # -------------------------------
    # Affinity matrices for the reference and the held-out (selected) bones.
    # "affinity_matrices" holds the overall matrix per age; the held-out bone's matrix
    # is read from disk. With 4 bones per age, the reference matrix is
    #   (overall affinity matrix - held-out affinity matrix) / 3
    # Both reference and selected matrices are then row-normalised (each row sums to 1)
    # and written to the epoch folder without header or index.
    # -------------------------------
    affinity_mat_selected_dict = {"3mo": test_3mo, "12mo": test_12mo, "20mo": test_20mo}
    affinity_matrices_selected = {}

    for age, bone_name in affinity_mat_selected_dict.items():
        # Whitespace-delimited file with neither header nor index column.
        affinity_path = f"{affinity_matrices_dir}/affinity_age_{age}_bone_{bone_name}.csv"
        affinity_matrix = pd.read_csv(affinity_path, header=None, delimiter=r"\s+")
        affinity_matrix.columns = range(10)
        affinity_matrices_selected[age] = affinity_matrix

    # Reference matrices: remove the held-out bone, average over the remaining three,
    # then divide every row by its own sum.
    reference_raw_3mo = (affinity_matrices.get("3mo") - affinity_matrices_selected["3mo"]) / 3
    reference_raw_12mo = (affinity_matrices.get("12mo") - affinity_matrices_selected["12mo"]) / 3
    reference_raw_20mo = (affinity_matrices.get("20mo") - affinity_matrices_selected["20mo"]) / 3

    affinity_matrices_3mo_ref = reference_raw_3mo.div(reference_raw_3mo.sum(axis=1), axis=0)
    affinity_matrices_12mo_ref = reference_raw_12mo.div(reference_raw_12mo.sum(axis=1), axis=0)
    affinity_matrices_20mo_ref = reference_raw_20mo.div(reference_raw_20mo.sum(axis=1), axis=0)

    for reference_matrix, csv_name in (
        (affinity_matrices_3mo_ref, "affinity_matrices_3mo_ref_normalized.csv"),
        (affinity_matrices_12mo_ref, "affinity_matrices_12mo_ref_normalized.csv"),
        (affinity_matrices_20mo_ref, "affinity_matrices_20mo_ref_normalized.csv"),
    ):
        reference_matrix.to_csv(os.path.join(epoch_folder, csv_name), index=False, header=False)

    # Selected matrices: same row normalisation (needed for the CE and MSE comparisons).
    for age_key in ("3mo", "12mo", "20mo"):
        selected_matrix = affinity_matrices_selected[age_key]
        affinity_matrices_selected[age_key] = selected_matrix.div(selected_matrix.sum(axis=1), axis=0)

    for age_key, csv_name in (
        ("3mo", "affinity_matrices_3mo_selected_normalized.csv"),
        ("12mo", "affinity_matrices_12mo_selected_normalized.csv"),
        ("20mo", "affinity_matrices_20mo_selected_normalized.csv"),
    ):
        affinity_matrices_selected[age_key].to_csv(os.path.join(epoch_folder, csv_name), index=False, header=False)
    
    # -------------------------------
    # 3. Interpolate the reference features over age
    # The sub-sections below fill interpolated_kde_ckits and interpolated_affinity_matrices
    # and create interp_values_hsc_cluster_comp and interpolated_kde_hsc_merged.
    # -------------------------------

    # Ages (months) at which reference data exist, and the 100-point grid from 3 to 20
    # on which the interpolated features are evaluated.
    real_times = np.array([3, 12, 20])
    fine_times = np.linspace(3, 20, num=100)

    # Normalise the KDE results of the reference bones ...
    (
        kde_results_3mo_ref_normalized,
        kde_results_12mo_ref_normalized,
        kde_results_20mo_ref_normalized,
    ) = map(normalize_kde_values, (kde_results_3mo_ref, kde_results_12mo_ref, kde_results_20mo_ref))

    # ... and of the selected bones.
    (
        kde_results_3mo_selected_normalized,
        kde_results_12mo_selected_normalized,
        kde_results_20mo_selected_normalized,
    ) = map(normalize_kde_values, (kde_results_3mo_selected, kde_results_12mo_selected, kde_results_20mo_selected))

    # Containers for the interpolated results.
    interpolated_kde_ckits, interpolated_histograms_hsc, interpolated_affinity_matrices = {}, {}, {}

    # Cluster labels 0..9.
    clusters = [*range(10)]
    
    # -------------------------------
    # 3a. Interpolate the cKit KDEs over age, one cluster at a time
    # -------------------------------
    for cluster in clusters:
        ckit_key = ("cKits", cluster)
        # Stack the three reference KDEs so that axis 0 follows real_times.
        age_kdes = np.array([
            kde_results_3mo_ref_normalized[ckit_key],
            kde_results_12mo_ref_normalized[ckit_key],
            kde_results_20mo_ref_normalized[ckit_key],
        ])
        # Linear interpolation along the age axis, evaluated on the fine grid.
        f_kde = interp1d(real_times, age_kdes, axis=0, kind="linear")
        interp_values_kde = f_kde(fine_times)
        # Rescale every interpolated KDE so that its entries (axes 1 and 2) sum to 1.
        kde_totals = interp_values_kde.sum(axis=(1, 2), keepdims=True)
        interp_values_normalized_kde = interp_values_kde / kde_totals
        interpolated_kde_ckits[cluster] = interp_values_normalized_kde
    

    # -------------------------------
    # 3b. Interpolate the HSC cluster composition (used in place of the HSC histograms)
    # -------------------------------
    age_hsc_cluster_comp = stack_proportions(hsc_cluster_comp_3mo_ref, hsc_cluster_comp_12mo_ref, hsc_cluster_comp_20mo_ref)
    # The interpolation runs along axis=1 of the stacked array, i.e. axis 1 is the age axis.
    f_hsc_cluster_comp = interp1d(real_times, age_hsc_cluster_comp, axis=1, kind="linear", fill_value="extrapolate")

    # Evaluating on fine_times keeps age on axis 1; transposing puts one fine time point
    # per row (presumably 100 x 10 clusters, given the stacked input).
    composition_per_time = f_hsc_cluster_comp(fine_times).T
    # Rescale every row so that the composition at each time point sums to 1.
    interp_values_hsc_cluster_comp = composition_per_time / composition_per_time.sum(axis=1, keepdims=True)

    # -------------------------------
    # 3c. Interpolate the reference affinity matrices over age, one row (cluster) at a time
    # -------------------------------
    for cluster in clusters:
        # Row `cluster` of each reference matrix as a NumPy array; axis 0 follows real_times.
        age_affinity_matrices = np.array([
            reference_matrix.iloc[cluster, :].values
            for reference_matrix in (
                affinity_matrices_3mo_ref,
                affinity_matrices_12mo_ref,
                affinity_matrices_20mo_ref,
            )
        ])
        # Linear interpolation of this row along the age axis, evaluated on the fine grid.
        f_affinity = interp1d(real_times, age_affinity_matrices, axis=0, kind="linear")
        interp_values_affinity = f_affinity(fine_times)
        # Rescale every interpolated row so that it sums to 1 again.
        row_totals = interp_values_affinity.sum(axis=1, keepdims=True)
        interp_values_normalized_affinity = interp_values_affinity / row_totals
        interpolated_affinity_matrices[cluster] = interp_values_normalized_affinity
        
    # -------------------------------
    # 3d. Interpolate the merged HSC KDE (the "HSCs" entry of each normalised KDE result)
    # -------------------------------
    age_kdes_hsc_merged = np.array([
        reference_kdes["HSCs"]
        for reference_kdes in (
            kde_results_3mo_ref_normalized,
            kde_results_12mo_ref_normalized,
            kde_results_20mo_ref_normalized,
        )
    ])
    # Linear interpolation along the age axis (axis 0), evaluated on the fine grid.
    f_kde_hsc_merged = interp1d(real_times, age_kdes_hsc_merged, axis=0, kind="linear")
    interp_values_kde_hsc_merged = f_kde_hsc_merged(fine_times)

    # Rescale every interpolated KDE so that its entries (axes 1 and 2) sum to 1.
    hsc_merged_totals = interp_values_kde_hsc_merged.sum(axis=(1, 2), keepdims=True)
    interp_values_normalized_hsc_merged = interp_values_kde_hsc_merged / hsc_merged_totals
    interpolated_kde_hsc_merged = interp_values_normalized_hsc_merged
    
    
    
    # -------------------------------
    # 4. Compute the "optimal input" for each age condition
    # -------------------------------

    # Per-cluster age estimates (one list entry per cluster) from the cKit KDEs ...
    age_ckit_3mo_selected, age_ckit_12mo_selected, age_ckit_20mo_selected = [], [], []
    # ... and from the cKit affinity rows.
    age_affinity_3mo_selected, age_affinity_12mo_selected, age_affinity_20mo_selected = [], [], []

    # HSC-number-based estimates are a single value per condition (no clusters).
    age_num_3mo_selected = age_num_12mo_selected = age_num_20mo_selected = 0

    # Numeric ages of the three groups and the labels of the reference / selected conditions.
    age_groups_numeric = np.array([3, 12, 20])
    conditions_ref = ["3mo_ref", "12mo_ref", "20mo_ref"]
    conditions_selected = ["3mo_selected", "12mo_selected", "20mo_selected"]

    hsc_nums_ref = {}
    hsc_nums_selected = {}

    # Number of HSC rows per condition, divided by the number of distinct datasets
    # contributing to that condition (reference bones).
    for condition, transformed_df in zip(conditions_ref, (ref_3mo_df, ref_12mo_df, ref_20mo_df)):
        hsc_rows = transformed_df[transformed_df["source"] == "HSCs"]
        num_datasets = len(transformed_df["dataset"].unique())
        hsc_nums_ref[condition] = len(hsc_rows) / num_datasets

    # Same quantity for the held-out (selected) bones.
    for condition, transformed_df in zip(conditions_selected, (selected_3mo_df, selected_12mo_df, selected_20mo_df)):
        hsc_rows = transformed_df[transformed_df["source"] == "HSCs"]
        num_datasets = len(transformed_df["dataset"].unique())
        hsc_nums_selected[condition] = len(hsc_rows) / num_datasets

    # Arrange the per-condition values as arrays, in condition order.
    hsc_nums_matrix_ref = np.array([hsc_nums_ref[name] for name in conditions_ref])
    hsc_nums_matrix_selected = np.array([hsc_nums_selected[name] for name in conditions_selected])

    # HSC-number-based age estimates.
    # exponential_func is fitted ONCE to the reference HSC numbers (hsc_nums_matrix_ref)
    # against age_groups_numeric; the fit does not depend on the held-out bone, so it is
    # not repeated per condition. Each held-out HSC number is then mapped back to an age
    # on fine_times with find_intersection(..., method="closest").
    # A failure yields np.nan for the affected estimate(s) instead of aborting the epoch;
    # the reason is printed so that NaNs further down can be traced to their cause.
    intersection_3mo = intersection_12mo = intersection_20mo = np.nan
    try:
        popt_exp, _ = curve_fit(exponential_func, age_groups_numeric, hsc_nums_matrix_ref[:3], maxfev=10000)
    except Exception as e:  # deliberately broad: the estimate is best-effort
        print(f"Warning: exponential fit of the reference HSC numbers failed ({e}); "
              "HSC-number age estimates are set to NaN.")
    else:
        exp_func = lambda x: exponential_func(x, *popt_exp)
        hsc_number_estimates = []
        for position, age_label in enumerate(("3mo", "12mo", "20mo")):
            try:
                estimate = find_intersection(exp_func, hsc_nums_matrix_selected[position], fine_times, method="closest")
            except Exception as e:  # deliberately broad: keep the other conditions alive
                print(f"Warning: HSC-number age estimate for {age_label} failed ({e}); using NaN.")
                estimate = np.nan
            hsc_number_estimates.append(estimate)
        intersection_3mo, intersection_12mo, intersection_20mo = hsc_number_estimates

    age_num_3mo_selected = intersection_3mo
    age_num_12mo_selected = intersection_12mo
    age_num_20mo_selected = intersection_20mo

    # Merged HSC KDEs: for every fine time point, compare the selected bone's "HSCs" KDE
    # with the interpolated reference KDE via calculate_kl_divergence(selected, reference).
    (
        kl_div_3mo_hsc_merged_selected,
        kl_div_12mo_hsc_merged_selected,
        kl_div_20mo_hsc_merged_selected,
    ) = [
        np.array([
            calculate_kl_divergence(selected_kdes["HSCs"], interpolated_kde_hsc_merged[time_index])
            for time_index in range(len(fine_times))
        ])
        for selected_kdes in (
            kde_results_3mo_selected_normalized,
            kde_results_12mo_selected_normalized,
            kde_results_20mo_selected_normalized,
        )
    ]

    # The estimated age is the fine time point with the smallest divergence.
    age_3mo_hsc_merged_selected = fine_times[kl_div_3mo_hsc_merged_selected.argmin()]
    age_12mo_hsc_merged_selected = fine_times[kl_div_12mo_hsc_merged_selected.argmin()]
    age_20mo_hsc_merged_selected = fine_times[kl_div_20mo_hsc_merged_selected.argmin()]
    
    
    # HSC cluster composition: proportions of the selected bones, ordered by cluster label.
    p_3mo_selected, p_12mo_selected, p_20mo_selected = [
        composition.sort_values("clusters")["proportion"].values
        for composition in (
            hsc_cluster_comp_3mo_selected,
            hsc_cluster_comp_12mo_selected,
            hsc_cluster_comp_20mo_selected,
        )
    ]

    # Squared Jensen-Shannon distance (i.e. the JS divergence) between the selected
    # composition and the interpolated reference composition at every fine time point.
    jsd_3mo_hsc_cluster_comp_selected = np.array(
        [jensenshannon(p_3mo_selected, reference_q) ** 2 for reference_q in interp_values_hsc_cluster_comp]
    )
    jsd_12mo_hsc_cluster_comp_selected = np.array(
        [jensenshannon(p_12mo_selected, reference_q) ** 2 for reference_q in interp_values_hsc_cluster_comp]
    )
    jsd_20mo_hsc_cluster_comp_selected = np.array(
        [jensenshannon(p_20mo_selected, reference_q) ** 2 for reference_q in interp_values_hsc_cluster_comp]
    )

    # The estimated age is the fine time point with the smallest divergence.
    age_3mo_hsc_cluster_comp_selected = fine_times[jsd_3mo_hsc_cluster_comp_selected.argmin()]
    age_12mo_hsc_cluster_comp_selected = fine_times[jsd_12mo_hsc_cluster_comp_selected.argmin()]
    age_20mo_hsc_cluster_comp_selected = fine_times[jsd_20mo_hsc_cluster_comp_selected.argmin()]
    
    
    # Per-cluster age estimates for the two cKit modalities.
    # For every cluster and every held-out age group, two estimates are made:
    #   a) KDE-based: the fine time point whose interpolated reference cKit KDE gives the
    #      smallest calculate_kl_divergence(selected, reference);
    #   c) affinity-based: the fine time point whose interpolated reference affinity row
    #      is closest to the selected row, using the distance chosen by affinity_flag
    #      ("MSE" -> calculate_mse, "CE" -> calculate_cross_entropy).
    # Each estimate is appended to the matching age_ckit_* / age_affinity_* list, so after
    # the loop every list holds one value per cluster, in cluster order.

    # Validate affinity_flag once, before any per-cluster work is done.
    affinity_distance_by_flag = {"MSE": calculate_mse, "CE": calculate_cross_entropy}
    if affinity_flag not in affinity_distance_by_flag:
        raise ValueError("Invalid affinity flag. Choose 'MSE' or 'CE'.")
    affinity_distance = affinity_distance_by_flag[affinity_flag]

    # (affinity key, selected normalised KDEs, cKit-KDE estimates, affinity estimates) per age group.
    per_age_inputs = (
        ("3mo", kde_results_3mo_selected_normalized, age_ckit_3mo_selected, age_affinity_3mo_selected),
        ("12mo", kde_results_12mo_selected_normalized, age_ckit_12mo_selected, age_affinity_12mo_selected),
        ("20mo", kde_results_20mo_selected_normalized, age_ckit_20mo_selected, age_affinity_20mo_selected),
    )
    n_fine_times = len(fine_times)

    for cluster in clusters:
        reference_kdes = interpolated_kde_ckits[cluster]
        reference_affinity_rows = interpolated_affinity_matrices[cluster]

        for age_key, selected_kdes, ckit_estimates, affinity_estimates in per_age_inputs:
            # a) cKit KDE: age minimising the KL divergence to the reference.
            selected_kde = selected_kdes[("cKits", cluster)]
            kl_divergences = np.array([
                calculate_kl_divergence(selected_kde, reference_kdes[time_index])
                for time_index in range(n_fine_times)
            ])
            ckit_estimates.append(fine_times[np.argmin(kl_divergences)])

            # c) cKit affinity: age minimising the chosen distance to the reference row.
            selected_row = affinity_matrices_selected[age_key].iloc[cluster, :].values
            affinity_distances = np.array([
                affinity_distance(selected_row, reference_affinity_rows[time_index])
                for time_index in range(n_fine_times)
            ])
            affinity_estimates.append(fine_times[np.argmin(affinity_distances)])
    

    # -------------------------------
    # 4e. Cluster sizes used to weight the per-cluster age estimates.

    # Cluster sizes of each held-out (selected) dataset.
    cluster_sizes_selected = {
        "3mo_selected": calculate_cluster_sizes(selected_3mo_df),
        "12mo_selected": calculate_cluster_sizes(selected_12mo_df),
        "20mo_selected": calculate_cluster_sizes(selected_20mo_df),
    }

    # Cluster sizes of the reference data (passed on as the reference for the weighting).
    cluster_sizes_ref = {
        "3mo_ref": calculate_cluster_sizes(ref_3mo_df),
        "12mo_ref": calculate_cluster_sizes(ref_12mo_df),
        "20mo_ref": calculate_cluster_sizes(ref_20mo_df),
    }

    # Sources and clusters that take part in the weighted averaging.
    sources = ["cKits", "HSCs"]  # adjust if you use additional sources
    clusters_list = list(range(10))

    # Weighted cluster sizes per held-out condition; only `condition` differs between calls.
    weighting_kwargs = dict(
        ref_cluster_sizes=cluster_sizes_ref,
        cluster_sizes=cluster_sizes_selected,
        sources=sources,
        clusters=clusters_list,
    )
    cluster_size_3mo_selected = compute_weighted_cluster_sizes(condition="3mo_selected", **weighting_kwargs)
    cluster_size_12mo_selected = compute_weighted_cluster_sizes(condition="12mo_selected", **weighting_kwargs)
    cluster_size_20mo_selected = compute_weighted_cluster_sizes(condition="20mo_selected", **weighting_kwargs)

    # Disabled snippet (a bare string literal, never executed): pickle the cluster sizes.
    """
    with open(os.path.join(epoch_folder, "cluster_sizes_selected.pkl"), "wb") as f:
        pickle.dump(cluster_sizes_selected, f)
    with open(os.path.join(epoch_folder, "weighted_cluster_size_3mo_selected.pkl"), "wb") as f:
        pickle.dump(cluster_size_3mo_selected, f)
    with open(os.path.join(epoch_folder, "weighted_cluster_size_12mo_selected.pkl"), "wb") as f:
        pickle.dump(cluster_size_12mo_selected, f)
    with open(os.path.join(epoch_folder, "weighted_cluster_size_20mo_selected.pkl"), "wb") as f:
        pickle.dump(cluster_size_20mo_selected, f)
    
    print("Cluster sizes computed for selected datasets in epoch", epoch_index)
    """
    
    # -------------------------------
    # Collapse the per-cluster estimates into one value per age group.
    # The weights are the values stored under the "cKits" key of each weighted
    # cluster-size result; the same weights serve the KDE-based and the
    # affinity-based estimates.
    # -------------------------------
    ckit_weights_3mo = list(cluster_size_3mo_selected["cKits"].values())
    ckit_weights_12mo = list(cluster_size_12mo_selected["cKits"].values())
    ckit_weights_20mo = list(cluster_size_20mo_selected["cKits"].values())

    # Weighted average of the cKit KDE-based age estimates.
    average_age_3mo_cKits_selected = np.average(age_ckit_3mo_selected, weights=ckit_weights_3mo)
    average_age_12mo_cKits_selected = np.average(age_ckit_12mo_selected, weights=ckit_weights_12mo)
    average_age_20mo_cKits_selected = np.average(age_ckit_20mo_selected, weights=ckit_weights_20mo)

    # Weighted average of the cKit affinity-based age estimates.
    average_age_3mo_affinity_selected = np.average(age_affinity_3mo_selected, weights=ckit_weights_3mo)
    average_age_12mo_affinity_selected = np.average(age_affinity_12mo_selected, weights=ckit_weights_12mo)
    average_age_20mo_affinity_selected = np.average(age_affinity_20mo_selected, weights=ckit_weights_20mo)
    

    # -------------------------------
    # Assemble optimal_input: per age group, the five modality estimates in the order
    #   [cKit KDE, merged HSC KDE, cKit affinity, HSC number, HSC cluster composition]
    # -------------------------------
    optimal_input = {
        "3mo": [
            average_age_3mo_cKits_selected,
            age_3mo_hsc_merged_selected,
            average_age_3mo_affinity_selected,
            age_num_3mo_selected,
            age_3mo_hsc_cluster_comp_selected,
        ],
        "12mo": [
            average_age_12mo_cKits_selected,
            age_12mo_hsc_merged_selected,
            average_age_12mo_affinity_selected,
            age_num_12mo_selected,
            age_12mo_hsc_cluster_comp_selected,
        ],
        "20mo": [
            average_age_20mo_cKits_selected,
            age_20mo_hsc_merged_selected,
            average_age_20mo_affinity_selected,
            age_num_20mo_selected,
            age_20mo_hsc_cluster_comp_selected,
        ],
    }

    # Pickle the dictionary into the epoch folder so it can be inspected later.
    optimal_input_path = os.path.join(epoch_folder, f"optimal_input_NoNumCluster_{affinity_flag}.pkl")
    with open(optimal_input_path, "wb") as f:
        pickle.dump(optimal_input, f)

    # print("Optimal input computed for epoch", epoch_index)

    # -------------------------------
    # 5. Optimise the five modality weights
    # total_loss is minimised over the weights with optimal_input, ground_truth and
    # lambda_reg as extra arguments (lambda_reg is presumably the regularisation
    # strength; see total_loss). Each weight is bounded to [0, 1] and weight_constraint
    # is imposed as an equality constraint.
    # -------------------------------
    n_features = 5
    initial_weights = np.full(n_features, 0.2)  # start from equal weights
    bounds = [(0, 1)] * n_features
    constraints = {"type": "eq", "fun": weight_constraint}

    # SLSQP (rather than Nelder-Mead) because it supports bounds and constraints.
    res_full = minimize(
        total_loss,
        initial_weights,
        args=(optimal_input, ground_truth, lambda_reg),
        method="SLSQP",
        bounds=bounds,
        constraints=constraints,
        options={"disp": True, "maxiter": 100},
    )

    # Keep the result only if the optimiser reports success; otherwise record None.
    full_weights, full_loss = None, None
    if res_full.success:
        full_weights = res_full.x / res_full.x.sum()  # rescale so the weights sum to 1
        full_loss = res_full.fun


    # -------------------------------
    # 6. Save this epoch's result in the epoch folder:
    # the optimised full-model weights plus the final loss, one row per entry.
    # If the optimisation did not succeed, every entry is written as None.
    # -------------------------------
    weight_names = [
        "cKit_pdf_weight",
        "HSC_pdf_weight",
        "cKit_affinity_weight",
        "HSC_num_weight",
        "HSC_cluster_comp_weight",
    ]
    if full_weights is None:
        weights_dict_full = dict.fromkeys(weight_names)
    else:
        weights_dict_full = {name: full_weights[position] for position, name in enumerate(weight_names)}
    weights_dict_full["Total_loss"] = full_loss

    weights_df_full = pd.DataFrame(list(weights_dict_full.items()), columns=["Variable", "Weight"])
    weights_csv_path = os.path.join(epoch_folder, f"optimal_input_weights_full_NoNumCluster_{affinity_flag}.csv")
    weights_df_full.to_csv(weights_csv_path, index=False)

    print(f"Epoch {epoch_index} completed. Full-model weights: {full_weights}")

    # Advance the epoch counter for the next iteration.
    epoch_index += 1



In [None]:
# Combine the results from all epochs.
# Every epoch folder under results_dir contributes its weights CSV (columns "Variable"
# and "Weight"); the tables are concatenated and summarised per variable.
weights_csv_name = f"optimal_input_weights_full_NoNumCluster_{affinity_flag}.csv"

# All directories under results_dir whose name starts with "epoch" (sorted so the
# order is reproducible; plain files with such a name are ignored).
epoch_folders = sorted(
    os.path.join(results_dir, folder)
    for folder in os.listdir(results_dir)
    if folder.startswith("epoch") and os.path.isdir(os.path.join(results_dir, folder))
)

weights_full_list = []
for epoch_folder in epoch_folders:
    weights_csv_path = os.path.join(epoch_folder, weights_csv_name)
    # An epoch that was interrupted before its weights were written has no CSV yet.
    if not os.path.isfile(weights_csv_path):
        print(f"Skipping {epoch_folder}: {weights_csv_name} not found.")
        continue
    weights_full_list.append(pd.read_csv(weights_csv_path))

if not weights_full_list:
    raise FileNotFoundError(
        f"No '{weights_csv_name}' file found in any epoch folder under '{results_dir}'."
    )

# Combine the results from all epochs into a single DataFrame
weights_full_df = pd.concat(weights_full_list)

# Mean and standard deviation of every variable across epochs.
weights_by_variable = weights_full_df.groupby("Variable")["Weight"]
weights_full_mean = weights_by_variable.mean()
weights_full_std = weights_by_variable.std()

# "Total_loss" shares the column with the weights but is not a weight itself.
weights_full_mean = weights_full_mean.drop("Total_loss")
weights_full_std = weights_full_std.drop("Total_loss")

# Normalize the mean weights to sum to 1
weights_full_mean_normalized = weights_full_mean / weights_full_mean.sum()


### 5.5 Bone age estimation with linear model

In [None]:
# Ref data preparation
# Take the final weights from weights_full_mean_normalized (the epoch-averaged,
# re-normalised weights computed in the previous cell).
# weights_full_mean_normalized is a Series indexed by variable name, so .loc[name]
# already returns a scalar; it has no `.values` attribute to index into.

cKit_pdf_weight = float(weights_full_mean_normalized.loc["cKit_pdf_weight"])
HSC_pdf_weight = float(weights_full_mean_normalized.loc["HSC_pdf_weight"])
HSC_num_weight = float(weights_full_mean_normalized.loc["HSC_num_weight"])
HSC_cluster_comp_weight = float(weights_full_mean_normalized.loc["HSC_cluster_comp_weight"])
cKit_affinity_weight = float(weights_full_mean_normalized.loc["cKit_affinity_weight"])

# Same component order as the feature lists in optimal_input:
# [cKit KDE, merged HSC KDE, cKit affinity, HSC number, HSC cluster composition]
final_weights = np.array([cKit_pdf_weight, HSC_pdf_weight, cKit_affinity_weight, HSC_num_weight, HSC_cluster_comp_weight])

# ----------------------------
# 1. Read the transformed data for reference and test datasets if necessary
# ----------------------------


# ----------------------------
# 2. Process the Reference Data: Compute Features and Build Interpolators
# ----------------------------

# 2a. KDEs of the reference bones; the grid returned with the 3mo result is kept as ref_grid.
kde_results_3mo_ref, ref_grid = kde_for_clusters(transformed_3mo_bones_df, bone_outline=ref_outline, binsize=10)
kde_results_12mo_ref, kde_results_20mo_ref = [
    kde_for_clusters(age_df, bone_outline=ref_outline, binsize=10)[0]
    for age_df in (transformed_12mo_bones_df, transformed_20mo_bones_df)
]

# Pickle the reference KDE results (needed later for the visualisation) ...
for pickle_name, kde_result in (
    ("kde_results_3mo_ref.pkl", kde_results_3mo_ref),
    ("kde_results_12mo_ref.pkl", kde_results_12mo_ref),
    ("kde_results_20mo_ref.pkl", kde_results_20mo_ref),
):
    with open(os.path.join(results_dir, pickle_name), "wb") as f:
        pickle.dump(kde_result, f)

# ... and the reference grid.
with open(f"{results_dir}/ref_grid.pkl", "wb") as f:
    pickle.dump(ref_grid, f)

# Normalised versions of the reference KDE dictionaries.
(
    kde_results_3mo_ref_normalized,
    kde_results_12mo_ref_normalized,
    kde_results_20mo_ref_normalized,
) = map(normalize_kde_values, (kde_results_3mo_ref, kde_results_12mo_ref, kde_results_20mo_ref))


# 2b. HSC cluster composition of each reference age group.
(
    hsc_cluster_comp_3mo_ref,
    hsc_cluster_comp_12mo_ref,
    hsc_cluster_comp_20mo_ref,
) = (
    compute_proportions_hsc(reference_frame)
    for reference_frame in (
        transformed_3mo_bones_df,
        transformed_12mo_bones_df,
        transformed_20mo_bones_df,
    )
)

# Export each composition table as CSV (without the index column).
for _csv_name, _composition in (
    ("hsc_cluster_comp_3mo_ref.csv", hsc_cluster_comp_3mo_ref),
    ("hsc_cluster_comp_12mo_ref.csv", hsc_cluster_comp_12mo_ref),
    ("hsc_cluster_comp_20mo_ref.csv", hsc_cluster_comp_20mo_ref),
):
    _composition.to_csv(os.path.join(results_dir, _csv_name), index=False)

# 2c. Reference affinity matrices: every row is divided by its own row sum,
# so each row of the stored matrices sums to 1.
affinity_matrices_ref = {
    age_key: affinity_matrices.get(age_key).div(affinity_matrices.get(age_key).sum(axis=1), axis=0)
    for age_key in ("3mo", "12mo", "20mo")
}


# 2d. Build Interpolation Dictionaries.
# `real_times` are the ages (in months) of the three reference groups;
# `fine_times` is the 100-point grid between 3 and 20 months on which every
# reference quantity is interpolated and against which test data are matched.
real_times = np.array([3, 12, 20])
fine_times = np.linspace(3, 20, 100)

# Containers for the interpolated reference data, one per modality:
interpolated_kde_ckits = {}        # cKit KDEs, keyed by cluster
interpolated_histograms_hsc = {}     # not filled anywhere in this cell
interpolated_affinity_matrices = {}  # affinity-matrix rows, keyed by cluster
interpolated_kde_hsc_merged = None   # merged HSC KDE (assigned in step 2d.iv)

# Cluster ids 0..9; also read as a global by process_test_data.
clusters = list(range(10))

# 2d.i. Interpolate cKit KDEs (from the normalized reference KDEs)
for cluster in clusters:
    # Stack the 3/12/20-month KDEs of this cluster along a new leading (age) axis
    age_kdes = np.array([
        kde_results_3mo_ref_normalized[("cKits", cluster)],
        kde_results_12mo_ref_normalized[("cKits", cluster)],
        kde_results_20mo_ref_normalized[("cKits", cluster)]
    ])
    # Linear interpolation along the age axis; no extrapolation is configured,
    # which is fine because fine_times stays within [3, 20]
    f_kde = interp1d(real_times, age_kdes, axis=0, kind="linear")
    # interpolated_kde_ckits[cluster] = f_kde(fine_times)

    # Evaluate the interpolator at the fine-grained time points
    interp_values_kde = f_kde(fine_times)
    # Rescale the KDE at every time point so that it sums to 1
    # (the sum runs over axes 1 and 2, i.e. over each interpolated 2-D map)
    interp_values_normalized_kde = interp_values_kde / np.sum(interp_values_kde, axis=(1,2), keepdims=True)
    # Store the normalized result in the dictionary
    interpolated_kde_ckits[cluster] = interp_values_normalized_kde
    
# 2d.ii. Interpolate HSC cluster compositions
# stack_proportions combines the three reference composition tables into one array;
# the interpolator below treats axis 1 of that array as the age axis (3, 12, 20 months).
age_hsc_cluster_comp = stack_proportions(hsc_cluster_comp_3mo_ref, hsc_cluster_comp_12mo_ref, hsc_cluster_comp_20mo_ref)
# Create a linear interpolator for the HSC cluster composition
f_hsc_cluster_comp = interp1d(real_times, age_hsc_cluster_comp, axis=1, kind="linear", fill_value="extrapolate") # It is correct here to use axis=1

# Evaluate on the fine grid. Because the age axis is axis 1, the result has
# len(fine_times) entries along axis 1: (n_rows, 100), presumably one row per cluster.
interp_hsc_cluster_comp = f_hsc_cluster_comp(fine_times)

# Transpose so that each row is one time point, then rescale every row to sum to 1
interp_hsc_cluster_comp = interp_hsc_cluster_comp.T
interp_hsc_cluster_comp = interp_hsc_cluster_comp / interp_hsc_cluster_comp.sum(axis=1, keepdims=True)


# 2d.iii. Interpolate affinity matrices (row-wise, per cluster)
for cluster in clusters:
    # Row `cluster` (selected by position) of the three row-normalized reference
    # matrices, stacked along a leading age axis
    age_affinity = np.array([
        affinity_matrices_ref["3mo"].iloc[cluster, :].values,
        affinity_matrices_ref["12mo"].iloc[cluster, :].values,
        affinity_matrices_ref["20mo"].iloc[cluster, :].values
    ])
    # Linear interpolation along the age axis (valid only inside [3, 20])
    f_affinity = interp1d(real_times, age_affinity, axis=0, kind="linear")
    # interpolated_affinity_matrices[cluster] = f_affinity(fine_times)
    
    # Evaluate the interpolator at the fine-grained time points
    interp_values_affinity = f_affinity(fine_times)
    # Normalize each interpolated row so that its sum is 1
    interp_values_normalized_affinity = interp_values_affinity / np.sum(interp_values_affinity, axis=1, keepdims=True)
    # Store the normalized result in the dictionary
    interpolated_affinity_matrices[cluster] = interp_values_normalized_affinity

# 2d.iv. Interpolate merged HSC KDEs (for the entire "HSCs" key)
# Stack the three reference KDEs along a leading age axis
age_kdes_hsc_merged = np.array([
    kde_results_3mo_ref_normalized["HSCs"],
    kde_results_12mo_ref_normalized["HSCs"],
    kde_results_20mo_ref_normalized["HSCs"]
])
# Linear interpolation along the age axis (valid only inside [3, 20])
f_kde_hsc_merged = interp1d(real_times, age_kdes_hsc_merged, axis=0, kind="linear")
# interpolated_kde_hsc_merged = f_kde_hsc_merged(fine_times)

# Evaluate the interpolator at the fine-grained time points
interp_values_kde_hsc_merged = f_kde_hsc_merged(fine_times)
# Rescale the KDE at every time point so that it sums to 1 (sum over axes 1 and 2)
interp_values_normalized_kde_hsc_merged = interp_values_kde_hsc_merged / np.sum(interp_values_kde_hsc_merged, axis=(1,2), keepdims=True)
# Keep the normalized array under the name reserved for it in step 2d
interpolated_kde_hsc_merged = interp_values_normalized_kde_hsc_merged

# 2d.v. Prepare the HSC counts used for the exponential regression in process_test_data
age_groups_numeric = np.array([3, 12, 20])  # Numeric representation of age groups
conditions_ref = ["3mo_ref", "12mo_ref", "20mo_ref"]

hsc_nums_ref = {}
# Mean number of HSC rows per dataset for each reference condition (all clusters pooled)
for transformed_df, condition in zip(
    [transformed_3mo_bones_df, transformed_12mo_bones_df, transformed_20mo_bones_df],
    conditions_ref,
):
    # Rows whose "source" is "HSCs"
    hsc_condition = len(transformed_df[transformed_df["source"] == "HSCs"])
    # Normalize by the number of distinct values in the "dataset" column
    num_datasets = len(transformed_df["dataset"].unique())
    hsc_condition = hsc_condition / num_datasets
    hsc_nums_ref[condition] = hsc_condition
# 1-D array of the three per-condition means, ordered like conditions_ref
hsc_nums_matrix_ref = np.array([hsc_nums_ref[condition] for condition in conditions_ref])  # Convert dictionary to matrix

# Persist the interpolated reference data in the results directory.
# process_test_data reads four of these files back (all except hsc_nums_ref.pkl).
for _pkl_name, _payload in (
    ("interpolated_kde_ckits.pkl", interpolated_kde_ckits),
    ("interpolated_affinity_matrices.pkl", interpolated_affinity_matrices),
    ("interpolated_kde_hsc_merged.pkl", interpolated_kde_hsc_merged),
    ("hsc_nums_ref.pkl", hsc_nums_ref),
    ("interpolated_hsc_cluster_comp.pkl", interp_hsc_cluster_comp),
):
    with open(os.path.join(results_dir, _pkl_name), "wb") as f:
        pickle.dump(_payload, f)


In [None]:
# Cluster sizes per condition; they feed compute_weighted_cluster_sizes in later cells.

# Reference age groups, keyed by the labels in conditions_ref.
cluster_sizes_ref = {}
for condition, transformed_df in zip(
    conditions_ref,
    (transformed_3mo_bones_df, transformed_12mo_bones_df, transformed_20mo_bones_df),
):
    cluster_sizes_ref[condition] = calculate_cluster_sizes(transformed_df)

# 5-FU treated test groups.
cluster_sizes_5fu = {}
for condition, transformed_df in zip(
    ("5fu30d", "5fu60d"),
    (transformed_5fu30d_bones_df, transformed_5fu60d_bones_df),
):
    cluster_sizes_5fu[condition] = calculate_cluster_sizes(transformed_df)

# Optional export of both dictionaries, currently disabled by wrapping it in a string.
"""
with open(os.path.join(results_dir, "cluster_sizes_ref.pkl"), "wb") as f:
    pickle.dump(cluster_sizes_ref, f)
with open(os.path.join(results_dir, "cluster_sizes_5fu.pkl"), "wb") as f:
    pickle.dump(cluster_sizes_5fu, f)
"""


In [None]:
# Process the Test Data
def process_test_data(transformed_test_df, cluster_size_test, condition, ref_outline, save_dir, affinity_flag):
    """
    Process a test dataset (e.g., 5fu30d or 5fu60d) and compute the five component age estimates.

    Every component compares a quantity measured on the test data with the
    reference data interpolated over the notebook-level ``fine_times`` grid and
    takes the grid time with the smallest distance as its age estimate.

    Besides its arguments the function reads these notebook-level names:
    ``clusters``, ``fine_times``, ``affinity_matrices``, ``age_groups_numeric``,
    ``hsc_nums_matrix_ref`` and the helpers ``kde_for_clusters``,
    ``normalize_kde_values``, ``compute_proportions_hsc``,
    ``calculate_kl_divergence``, ``calculate_mse``, ``calculate_cross_entropy``,
    ``exponential_func`` and ``find_intersection``.

    Args:
        transformed_test_df: DataFrame of the test condition; its "source" and
            "dataset" columns are read here.
        cluster_size_test: mapping whose "cKits" entry is a dict; the values of
            that dict weight the per-cluster estimates (assumed to be ordered
            like ``clusters`` -- TODO confirm against compute_weighted_cluster_sizes).
        condition: key into ``affinity_matrices``; also part of the output file names.
        ref_outline: passed to ``kde_for_clusters`` as ``bone_outline``.
        save_dir: directory that holds the interpolated reference pickles;
            the two result pickles are written to the same directory.
        affinity_flag: "MSE" or "CE", the distance used for the affinity component.

    Returns:
        A list of 5 estimated ages in the order:
        [cKit_KDE_est, merged_HSC_KDE_est, cKit_affinity_est, HSC_numbers_est,
        HSC_cluster_comp_est]. HSC_numbers_est is NaN when the exponential fit
        or the intersection search fails.

    Raises:
        ValueError: if ``affinity_flag`` is neither "MSE" nor "CE".
        KeyError: if ``affinity_matrices`` has no entry for ``condition``.
    """
    import warnings  # local import: not part of the notebook's import cell

    # Fail fast on bad arguments, before any expensive KDE work is done.
    if affinity_flag == "MSE":
        affinity_distance = calculate_mse
    elif affinity_flag == "CE":
        affinity_distance = calculate_cross_entropy
    else:
        raise ValueError("Invalid affinity_flag. Choose 'MSE' or 'CE'.")

    test_affinity_matrix = affinity_matrices.get(condition)
    if test_affinity_matrix is None:
        raise KeyError(f"No affinity matrix available for condition {condition!r}.")
    # Normalize the test affinity matrix by row, like the reference matrices.
    test_affinity_matrix = test_affinity_matrix.div(test_affinity_matrix.sum(axis=1), axis=0)

    # One date tag for both output files, computed once so the names cannot
    # disagree if the run crosses midnight. It is built outside the f-strings:
    # reusing the enclosing quote inside an f-string only parses on Python 3.12+.
    date_tag = datetime.today().strftime("%y%m%d")
    n_times = len(fine_times)
    ckit_weights = list(cluster_size_test["cKits"].values())

    # Compute KDEs for test data and normalize
    kde_results_test, _ = kde_for_clusters(transformed_test_df, bone_outline=ref_outline, binsize=10)
    kde_results_test_normalized = normalize_kde_values(kde_results_test)

    # HSC cluster composition of the test data, ordered by cluster
    hsc_cluster_comp_test = compute_proportions_hsc(transformed_test_df)
    hsc_cluster_comp_test = hsc_cluster_comp_test.sort_values("clusters")["proportion"].values

    # Mean number of HSC rows per dataset (all clusters pooled)
    hsc_conditions = len(transformed_test_df[transformed_test_df["source"] == "HSCs"])
    hsc_counts = hsc_conditions / len(transformed_test_df["dataset"].unique())

    # Load the interpolated reference data written by the reference-preparation cell.
    # NOTE: pickle.load can execute code from the file; save_dir must only
    # contain files produced by this notebook.
    with open(os.path.join(save_dir, "interpolated_kde_ckits.pkl"), "rb") as f:
        interpolated_kde_ckits = pickle.load(f)
    with open(os.path.join(save_dir, "interpolated_hsc_cluster_comp.pkl"), "rb") as f:
        interpolated_hsc_cluster_comp = pickle.load(f)
    with open(os.path.join(save_dir, "interpolated_affinity_matrices.pkl"), "rb") as f:
        interpolated_affinity_matrices = pickle.load(f)
    with open(os.path.join(save_dir, "interpolated_kde_hsc_merged.pkl"), "rb") as f:
        interpolated_kde_hsc_merged = pickle.load(f)

    # Component 1: cKit KDE-based estimates (using KL divergence), one per cluster
    kl_divs = []
    age_ckit_estimates = []
    for cluster in clusters:
        kl_div = np.array([
            calculate_kl_divergence(kde_results_test_normalized[("cKits", cluster)],
                                    interpolated_kde_ckits[cluster][i])
            for i in range(n_times)
        ])
        kl_divs.append(kl_div)
        age_ckit_estimates.append(fine_times[np.argmin(kl_div)])
    avg_ckit_est = np.average(age_ckit_estimates, weights=ckit_weights)

    # Component 2: Merged HSC KDE-based estimate (using KL divergence)
    kl_div_hsc = np.array([
        calculate_kl_divergence(kde_results_test_normalized["HSCs"], interpolated_kde_hsc_merged[i])
        for i in range(n_times)
    ])
    avg_hsc_merged_est = fine_times[np.argmin(kl_div_hsc)]

    # Component 3: cKit affinity-based estimates (MSE or cross entropy), one per cluster
    affinity_dists = []
    age_affinity_estimates = []
    for cluster in clusters:
        test_row = test_affinity_matrix.iloc[cluster, :].values
        dist_vals = np.array([
            affinity_distance(test_row, interpolated_affinity_matrices[cluster][i])
            for i in range(n_times)
        ])
        affinity_dists.append(dist_vals)
        age_affinity_estimates.append(fine_times[np.argmin(dist_vals)])
    avg_affinity_est = np.average(age_affinity_estimates, weights=ckit_weights)

    # Component 4: HSC numbers-based estimate (exponential regression, without clusters).
    # Best effort: a failed fit yields NaN, but the failure is now reported
    # instead of being swallowed silently.
    try:
        popt_exp, _ = curve_fit(exponential_func, age_groups_numeric, hsc_nums_matrix_ref[:3], maxfev=10000)
        avg_hsc_num_est = find_intersection(
            lambda x: exponential_func(x, *popt_exp), hsc_counts, fine_times, method="closest"
        )
    except Exception as exc:
        warnings.warn(
            f"HSC-number age estimate failed for {condition}: {exc!r}; using NaN.",
            RuntimeWarning,
        )
        avg_hsc_num_est = np.nan

    # Component 5: HSC cluster composition based estimate (using Jensen-Shannon divergence)
    jsd_hsc = np.array([
        jensenshannon(hsc_cluster_comp_test, q) ** 2 for q in interpolated_hsc_cluster_comp
    ])
    avg_hsc_cluster_comp = fine_times[np.argmin(jsd_hsc)]

    # Save the distance curves of every component
    elements_path = os.path.join(
        save_dir, f"{date_tag}_elements_dist_{condition}_NoNumCluster_{affinity_flag}.pkl"
    )
    with open(elements_path, "wb") as f:
        pickle.dump({
            "cKit_KL": kl_divs,
            "merged_HSC_KL": kl_div_hsc,
            f"cKit_{affinity_flag}": affinity_dists,  # "cKit_MSE" or "cKit_CE"
            "HSC_cluster_comp_JSD": jsd_hsc,
        }, f)

    # Save the estimated ages before the per-cluster averaging
    ages_path = os.path.join(
        save_dir, f"{date_tag}_estimated_ages_{condition}_NoNumCluster_{affinity_flag}.pkl"
    )
    with open(ages_path, "wb") as f:
        pickle.dump({
            "cKit_KDE": age_ckit_estimates,
            "merged_HSC_KDE": avg_hsc_merged_est,
            "cKit_affinity": age_affinity_estimates,
            "HSC_numbers": avg_hsc_num_est,
            "HSC_cluster_comp": avg_hsc_cluster_comp,
        }, f)

    # Return the list of 5 component age estimates in order:
    return [avg_ckit_est, avg_hsc_merged_est, avg_affinity_est, avg_hsc_num_est, avg_hsc_cluster_comp]


In [None]:
# Generate test data and combine the five component estimates into one age per condition

# Define the sources and clusters to use for weighted averaging.
sources = ["cKits", "HSCs"]  # adjust if you use additional sources
clusters_list = list(range(10))


# Compute weighted cluster sizes for each 5-FU test condition using compute_weighted_cluster_sizes.

cluster_size_5fu30d = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                        cluster_sizes=cluster_sizes_5fu,
                                                        condition="5fu30d",
                                                        sources=sources,
                                                        clusters=clusters_list)
cluster_size_5fu60d = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                        cluster_sizes=cluster_sizes_5fu,
                                                        condition="5fu60d",
                                                        sources=sources,
                                                        clusters=clusters_list)



# Five component age estimates per condition.
# `affinity_flag` ("MSE" or "CE") is not set in this cell; it must already be defined.
optimal_input_5fu30d = process_test_data(transformed_5fu30d_bones_df, cluster_size_5fu30d, "5fu30d", ref_outline, results_dir, affinity_flag)
optimal_input_5fu60d = process_test_data(transformed_5fu60d_bones_df, cluster_size_5fu60d, "5fu60d", ref_outline, results_dir, affinity_flag)

# ----------------------------
# Use the final optimal weights to estimate the final age for each test dataset.
# The final estimated age is the dot product of the 5-element optimal input and final_weights.
# NOTE(review): the HSC-numbers component can be NaN, which makes the dot product NaN.
# ----------------------------
final_age_5fu30d = np.dot(final_weights, optimal_input_5fu30d)
final_age_5fu60d = np.dot(final_weights, optimal_input_5fu60d)

# ----------------------------
# Collect the final estimated ages in a DataFrame (the CSV export below is disabled)
# ----------------------------
final_estimates = pd.DataFrame({
    "Condition": ["5fu30d", "5fu60d"],
    "Estimated_Age": [final_age_5fu30d, final_age_5fu60d]
})
# final_estimates.to_csv(os.path.join(results_dir, f"final_age_estimates_5fu_avg_{weight_flag}_NoNumCluster_{affinity_flag}.csv"), index=False)

print("Final estimated ages for 5fu treated bones computed:")
print(final_estimates)




### 5.6 Bone age estimation with Gaussian model

In [None]:
# Generate features of test data for model inference
def generate_features(
    transformed_test_df, cluster_size_test, condition, affinity_matrix, ref_outline, save_dir,
    kde_results_ref, hsc_cluster_comp_ref, affinity_matrices_ref, clusters
):
    """
    Process test data and compute the 13 features used for model inference.

    For each reference age ("3mo", "12mo", "20mo") the test data are compared
    with the reference data through four distances; the mean HSC count per
    dataset is added as a single extra feature.

    Args:
        transformed_test_df (DataFrame): Preprocessed test data; the "source" and
            "dataset" columns are read here.
        cluster_size_test (dict): Mapping whose "cKits" entry is a dict; its values
            weight the per-cluster distances (assumed to be ordered like
            ``clusters`` -- TODO confirm).
        condition (str): Condition name (e.g., "5fu30d"); only used in error messages.
        affinity_matrix (DataFrame): Un-normalized affinity matrix of the test data.
        ref_outline (array): Reference bone outline for KDE calculation.
        save_dir (str): Unused; kept so existing callers keep working.
        kde_results_ref (dict): Normalized reference KDEs keyed by "3mo"/"12mo"/"20mo".
        hsc_cluster_comp_ref (dict): Reference HSC cluster-composition vectors,
            keyed by age. They must list clusters in the same order as the test
            vector, which is sorted by the "clusters" column.
        affinity_matrices_ref (dict): Row-normalized reference affinity matrices keyed by age.
        clusters (list): Cluster ids; also used as row positions into the affinity matrices.

    Returns:
        dict: Feature name -> value, in a fixed insertion order (cKit density,
        HSC density, HSC count, HSC composition, cKit neighborhood affinity).

    Raises:
        ValueError: if ``affinity_matrix`` is None.
    """
    if affinity_matrix is None:
        raise ValueError(f"No affinity matrix provided for condition {condition!r}.")

    ages = ("3mo", "12mo", "20mo")
    ckit_weights = list(cluster_size_test["cKits"].values())

    # Step 1: Compute KDEs for test data and normalize
    kde_results_test, _ = kde_for_clusters(transformed_test_df, bone_outline=ref_outline, binsize=10)
    kde_results_test_normalized = normalize_kde_values(kde_results_test)

    # Step 2: HSC cluster composition of the test data, ordered by cluster
    hsc_cluster_comp_test = compute_proportions_hsc(transformed_test_df)
    hsc_cluster_comp_test = hsc_cluster_comp_test.sort_values("clusters")["proportion"].values

    # Step 3: Mean number of HSC rows per dataset
    HSC_test_num = transformed_test_df[transformed_test_df["source"] == "HSCs"].shape[0]
    HSC_test_num = HSC_test_num / len(transformed_test_df["dataset"].unique())

    # Step 4: Row-normalize the test affinity matrix, like the reference matrices
    test_affinity_matrix = affinity_matrix.div(affinity_matrix.sum(axis=1), axis=0)

    # Step 5: Distances between test and reference data, one value per reference age
    ckit_kl_avg = {}
    hsc_kl = {}
    hsc_comp_jsd = {}
    ckit_ce_avg = {}
    for age in ages:
        # cKit KDEs: KL divergence per cluster, averaged with the cluster weights
        ckit_kl_avg[age] = np.average(
            [
                calculate_kl_divergence(
                    kde_results_test_normalized[("cKits", cluster)],
                    kde_results_ref[age][("cKits", cluster)],
                )
                for cluster in clusters
            ],
            weights=ckit_weights,
        )
        # Merged HSC KDE: a single KL divergence
        hsc_kl[age] = calculate_kl_divergence(
            kde_results_test_normalized["HSCs"], kde_results_ref[age]["HSCs"]
        )
        # HSC cluster composition: squared Jensen-Shannon distance
        hsc_comp_jsd[age] = jensenshannon(hsc_cluster_comp_test, hsc_cluster_comp_ref[age]) ** 2
        # cKit affinity rows: cross entropy per cluster, averaged with the cluster weights
        ckit_ce_avg[age] = np.average(
            [
                calculate_cross_entropy(
                    test_affinity_matrix.iloc[cluster, :],
                    affinity_matrices_ref[age].iloc[cluster, :],
                )
                for cluster in clusters
            ],
            weights=ckit_weights,
        )

    # Step 6: Assemble the features. Names and order are kept stable because
    # the result is meant to line up with the columns of the training data.
    test_features = {}
    for age in ages:
        test_features[f"cKit Density Divergence (vs. {age})"] = ckit_kl_avg[age]
    for age in ages:
        test_features[f"HSC Density Divergence (vs. {age})"] = hsc_kl[age]
    test_features["HSC Count"] = HSC_test_num
    for age in ages:
        test_features[f"HSC Composition (vs. {age})"] = hsc_comp_jsd[age]
    for age in ages:
        test_features[f"cKit Neighborhood Affinity (vs. {age})"] = ckit_ce_avg[age]

    return test_features

In [None]:
# Define the sources and clusters to use for weighted averaging.
sources = ["cKits", "HSCs"]  # adjust if you use additional sources
clusters_list = list(range(10))

# Compute weighted cluster sizes for each 5-FU test condition using compute_weighted_cluster_sizes.
cluster_size_5fu30d = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                        cluster_sizes=cluster_sizes_5fu,
                                                        condition="5fu30d",
                                                        sources=sources,
                                                        clusters=clusters_list)
cluster_size_5fu60d = compute_weighted_cluster_sizes(ref_cluster_sizes=cluster_sizes_ref,
                                                        cluster_sizes=cluster_sizes_5fu,
                                                        condition="5fu60d",
                                                        sources=sources,
                                                        clusters=clusters_list)

# Un-normalized affinity matrices of the test conditions (generate_features normalizes them).
affinity_matrix_5fu30d = affinity_matrices.get("5fu30d")
affinity_matrix_5fu60d = affinity_matrices.get("5fu60d")

# Normalized reference KDEs keyed by age group.
kde_results_ref = {
    "3mo": normalize_kde_values(kde_results_3mo_ref),
    "12mo": normalize_kde_values(kde_results_12mo_ref),
    "20mo": normalize_kde_values(kde_results_20mo_ref),
}

# Reference HSC cluster-composition vectors. generate_features sorts the test
# vector by the "clusters" column before comparing, so the reference vectors
# are sorted the same way; otherwise the Jensen-Shannon distance could pair up
# proportions that belong to different clusters.
hsc_cluster_comp_ref = {
    "3mo": hsc_cluster_comp_3mo_ref.sort_values("clusters")["proportion"].values,
    "12mo": hsc_cluster_comp_12mo_ref.sort_values("clusters")["proportion"].values,
    "20mo": hsc_cluster_comp_20mo_ref.sort_values("clusters")["proportion"].values,
}

# Row-normalized reference affinity matrices keyed by age group.
affinity_matrices_ref = {
    "3mo": affinity_matrices.get("3mo").div(affinity_matrices.get("3mo").sum(axis=1), axis=0),
    "12mo": affinity_matrices.get("12mo").div(affinity_matrices.get("12mo").sum(axis=1), axis=0),
    "20mo": affinity_matrices.get("20mo").div(affinity_matrices.get("20mo").sum(axis=1), axis=0)
}

# Feature dictionaries of the two test conditions for model inference.
features_5fu30d = generate_features(transformed_5fu30d_bones_df, cluster_size_5fu30d, "5fu30d", affinity_matrix_5fu30d, ref_outline, results_dir, kde_results_ref, hsc_cluster_comp_ref, affinity_matrices_ref, clusters_list)
features_5fu60d = generate_features(transformed_5fu60d_bones_df, cluster_size_5fu60d, "5fu60d", affinity_matrix_5fu60d, ref_outline, results_dir, kde_results_ref, hsc_cluster_comp_ref, affinity_matrices_ref, clusters_list)

In [None]:
# Cell-local import: make_pipeline is not part of the notebook's import cell.
from sklearn.pipeline import make_pipeline

# Drop non-feature columns (we keep only feature columns)
X = training_data_df.drop(columns=["epoch", "ground_truth"])
y = training_data_df["ground_truth"]

# Standardize the features on the full training set. These arrays are used for
# the final fit and the per-age subsets, NOT for cross-validation (see below).
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_scaled_3mo = X_scaled[training_data_df["ground_truth"] == 3]
X_scaled_12mo = X_scaled[training_data_df["ground_truth"] == 12]
X_scaled_20mo = X_scaled[training_data_df["ground_truth"] == 20]

# Define cross-validation strategy
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Models to compare
models = {
    "Gaussian Process": GaussianProcessRegressor(kernel=C(3.16**2) * RBF(length_scale=3.85), alpha=0.2, random_state=42)
}

# Store results for metrics
metrics = {model: {} for model in models}

# Perform cross-validation and compute metrics
for name, model in models.items():
    # Scale inside each fold: a scaler fitted on all samples would leak the
    # mean/variance of the held-out samples into training. cross_val_predict
    # clones the pipeline, so `model` itself stays unfitted here.
    cv_estimator = make_pipeline(StandardScaler(), model)
    y_pred = cross_val_predict(cv_estimator, X, y, cv=kf)

    # Compute metrics on the out-of-fold predictions
    mse = mean_squared_error(y, y_pred)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)
    evs = explained_variance_score(y, y_pred)

    # Store results
    metrics[name] = {
        "MSE": mse,
        "MAE": mae,
        "R² Score": r2,
        "Explained Variance": evs
    }

# Convert metrics to DataFrame (one row per model)
metrics_df = pd.DataFrame(metrics).T

# Print results
print(metrics_df)


In [None]:
# Fit every candidate model on all scaled training samples and write it to disk.
# NOTE(review): every model is written to the same "model.pkl", so with more
# than one entry in `models` only the last one survives on disk.
trained_models = {}
for name in models:
    model = models[name]
    model.fit(X_scaled, y)
    trained_models[name] = model

    model_path = os.path.join(model_save_dir, "model.pkl")
    joblib.dump(model, model_path)
    print(f"Saved {name} model to {model_path}")


In [None]:
# Permutation importance of the fitted Gaussian Process on the scaled training data
gpr_model = trained_models["Gaussian Process"]
perm_importance = permutation_importance(gpr_model, X_scaled, y, n_repeats=1000, random_state=42)

# One row per feature with its mean importance, most important first
perm_importance_df = pd.DataFrame({
    "Feature": X.columns,
    "Importance": perm_importance.importances_mean
})
perm_importance_df = perm_importance_df.sort_values(by="Importance", ascending=False)

# Full table, then the top and the bottom of the ranking
print("\n Feature Importance Based on Permutation Importance:", perm_importance_df, sep="\n")
print("\n Most Important Features (Highest Impact on Model):", perm_importance_df.head(), sep="\n")
print("\n Least Important Features (Lowest Impact on Model):", perm_importance_df.tail(), sep="\n")

In [None]:
# Map a feature name to its category, ignoring the "(vs. <age>)" suffix
def categorize_feature(feature):
    """
    Return the category label of a feature name.

    The name is checked against the substrings below in order and the label
    of the first substring it contains is returned. A name that matches none
    of them is returned unchanged.
    """
    category_by_marker = (
        ("HSC Count", "HSC numbers"),
        ("cKit Density Divergence", "HeM distribution"),
        ("HSC Density Divergence", "HSC distribution"),
        ("HSC Spatial Similarity", "HSC-HeM association"),
        ("cKit Neighborhood Affinity", "HeM neighborhood"),
        ("HSC Composition", "HSC composition"),
    )
    for marker, category in category_by_marker:
        if marker in feature:
            return category
    return feature  # no known marker: keep the original name

# Label every feature with its category
perm_importance_df["Feature Category"] = perm_importance_df["Feature"].map(categorize_feature)

# Total importance per category, smallest first
perm_importance_grouped = (
    perm_importance_df.groupby("Feature Category", as_index=False)["Importance"]
    .sum()
    .sort_values(by="Importance", ascending=True)
)

# Rescale so the category importances sum to 1 before visualization
perm_importance_grouped["Importance"] = (
    perm_importance_grouped["Importance"] / perm_importance_grouped["Importance"].sum()
)


## 6. Visualization

### 6.1 Element-wise bone age visualization

In [None]:
# Define the plot function (KL div, SSIM, CE, MSE) for elements
def plot_lineplot(data, method, num_clusters, labels, y_label, results_dir, colors):
    """Draw one stacked line panel per cluster and mark each curve's extreme.

    Parameters:
    - data (sequence): ``data[j][i]`` is the curve of series ``j`` for cluster ``i``.
                       Every curve is plotted against ``range(0, 100)``.
    - method (str): "min" or "max"; which extreme of each curve gets a dot marker.
                    NaN entries are ignored when locating it.
    - num_clusters (int): Number of panels (one per cluster).
    - labels (sequence of str): Legend label per series (shown on the first panel only).
    - y_label (str): Used for the shared y-axis text and the figure title.
    - results_dir: Not used by this function; kept so existing calls keep working.
    - colors (sequence): Line and marker color per series.

    Raises:
    - ValueError: If ``method`` is neither "min" nor "max".
    """
    # Validate before any figure is created so a bad argument leaves no half-drawn plot.
    extreme_finders = {
        "min": (np.nanmin, np.nanargmin),
        "max": (np.nanmax, np.nanargmax),
    }
    if method not in extreme_finders:
        raise ValueError("Invalid method. Please choose either 'min' or 'max'")
    find_value, find_index = extreme_finders[method]

    # Create a figure and set of subplots
    fig, axes = plt.subplots(nrows=num_clusters, ncols=1, figsize=(8, 1 * num_clusters), sharex=True)
    # If there's only one cluster, axes is not an array, so we handle that case
    if num_clusters == 1:
        axes = [axes]

    # Plot each cluster in its own panel
    for i, ax in enumerate(axes):
        for j in range(len(labels)):
            sns.lineplot(x=range(0, 100), y=data[j][i], ax=ax, label=labels[j] if i == 0 else "", c=colors[j])

            # nanargmin/nanargmax raise on an all-NaN curve; such a curve has no
            # extreme to mark, so only the (empty) line is drawn for it.
            curve = np.asarray(data[j][i], dtype=float)
            if np.isnan(curve).all():
                continue
            selected_kl_div = find_value(curve)
            selected_kl_div_index = find_index(curve)
            sns.lineplot(x=[selected_kl_div_index], y=[selected_kl_div], marker="o", markersize=5, ax=ax, c=colors[j])

        ax.set_ylabel(f"Cluster {i + 1}", fontsize=12)
        ax.set_xlim([-2, 101])
        # Curve indices 0 / 52.4 / 99 are labelled as 3 / 12 / 20 on the x-axis
        ax.set_xticks([0, 52.4, 99])
        ax.set_xticklabels([3, 12, 20], fontsize=12)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.tick_params(bottom=False)
        ax.grid(axis="y", linestyle="--", alpha=0.5)
        ax.tick_params(axis="both", which="major", labelsize=12)  # Change tick label size

    # Only the last panel shows the x-axis line and ticks
    axes[-1].spines["bottom"].set_visible(True)
    axes[-1].tick_params(bottom=True)

    # Set common labels
    fig.text(0.5, 0.01, "Estimated Age (months)", ha="center", va="center")
    fig.text(0, 0.5, y_label, ha="center", va="center", rotation="vertical")

    # Add title
    fig.suptitle(f"{y_label}", fontsize=16)

    # Add legend
    axes[0].legend(bbox_to_anchor=(1.5, 1), loc="upper right", fontsize=12)
    # Adjust layout to fit titles and labels and bring subplots closer
    plt.subplots_adjust(hspace=0.5, left=0.12, right=0.95, top=0.95, bottom=0.05)

    plt.show()


In [None]:
# (Precalculated when generating the reference data)
# NOTE: the file names embed today's date (yymmdd), so the files are only found when
# they were written on the same day this cell runs.
# The date tag is built outside the f-strings: reusing double quotes inside an
# f-string expression is a SyntaxError before Python 3.12. Computing it once also
# keeps both file names consistent.
date_tag = datetime.today().strftime("%y%m%d")
with open(f"{results_dir}/{date_tag}_elements_dist_5fu30d_NoNumCluster_{affinity_flag}.pkl", "rb") as f:
    elements_dist_5fu30d = pickle.load(f)
with open(f"{results_dir}/{date_tag}_elements_dist_5fu60d_NoNumCluster_{affinity_flag}.pkl", "rb") as f:
    elements_dist_5fu60d = pickle.load(f)


In [None]:
# Merged-HSC KL divergence curves of the two 5FU conditions
merged_HSC_kl_5fu30d = elements_dist_5fu30d["merged_HSC_KL"]
merged_HSC_kl_5fu60d = elements_dist_5fu60d["merged_HSC_KL"]

# Plot settings (single panel; the minimum of each curve is marked below)
labels = ["HSC KL div 5fu30d", "HSC KL div 5fu60d"]
colors = ["red", "blue"]
y_label = "KL Divergence"
num_clusters = 1
method = "min"

fig, axes = plt.subplots(nrows=num_clusters, ncols=1, figsize=(8, 1 * num_clusters), sharex=True)
# With a single row plt.subplots returns a bare Axes, so wrap it to keep the loop uniform
axes = [axes] if num_clusters == 1 else axes

kl_curves = (merged_HSC_kl_5fu30d, merged_HSC_kl_5fu60d)
for i, ax in enumerate(axes):
    # One line per condition (5fu30d first, then 5fu60d) plus a dot at its minimum
    for kl_curve, curve_label, curve_color in zip(kl_curves, labels, colors):
        sns.lineplot(x=range(0, 100), y=kl_curve, ax=ax, label=curve_label, c=curve_color)
        selected_kl_div = np.min(kl_curve)
        selected_kl_div_index = np.argmin(kl_curve)
        sns.lineplot(x=[selected_kl_div_index], y=[selected_kl_div], marker="o", markersize=5, ax=ax, c=curve_color)

    # Curve indices 0 / 52.4 / 99 are labelled as 3 / 12 / 20 on the x-axis
    ax.set_xlim([0, 99])
    ax.set_xticks([0, 52.4, 99])
    ax.set_xticklabels([3, 12, 20], fontsize=12)
    for side in ("top", "right", "bottom"):
        ax.spines[side].set_visible(False)
    ax.tick_params(bottom=False)
    ax.grid(axis="y", linestyle="--", alpha=0.5)
    ax.tick_params(axis="both", which="major", labelsize=12)  # Tick label size

# Only the last panel shows the x-axis line and ticks
bottom_ax = axes[-1]
bottom_ax.spines["bottom"].set_visible(True)
bottom_ax.tick_params(bottom=True)

# Shared axis labels, placed in figure coordinates
fig.text(0.5, -0.5, "Estimated Age (months)", ha="center", va="center")
fig.text(0, 0.5, y_label, ha="center", va="center", rotation="vertical")

# Title
fig.suptitle(f"{y_label} ", fontsize=16, y=1.2)

# Legend outside the panel, then tighten the layout
axes[0].legend(bbox_to_anchor=(1.5, 1), loc="upper right", fontsize=12)
plt.subplots_adjust(hspace=0.5, left=0.12, right=0.95, top=0.95, bottom=0.05)

plt.show()

In [None]:
# Dense age grid (3 to 20, 100 points) on which the fitted curve is evaluated
fine_ages = np.linspace(3, 20, 100)


def _count_hscs(bones_df):
    """Return the number of rows of ``bones_df`` whose "source" is "HSCs"."""
    return bones_df[bones_df["source"] == "HSCs"].shape[0]


# HSC counts of the three reference ages
HSC_number_3mo = _count_hscs(transformed_3mo_bones_df)
HSC_number_12mo = _count_hscs(transformed_12mo_bones_df)
HSC_number_20mo = _count_hscs(transformed_20mo_bones_df)

# HSC counts of the two 5FU conditions
HSC_number_5fu30d = _count_hscs(transformed_5fu30d_bones_df)
HSC_number_5fu60d = _count_hscs(transformed_5fu60d_bones_df)

# Reference counts first (indices 0-2), then 5fu30d (3) and 5fu60d (4)
hsc_nums_matrix_all = np.array([HSC_number_3mo, HSC_number_12mo, HSC_number_20mo, HSC_number_5fu30d, HSC_number_5fu60d])
reference_counts = hsc_nums_matrix_all[:3]

# Fit the exponential model to the reference ages only
try:
    popt_exp_all, _ = curve_fit(exponential_func, age_groups_numeric, reference_counts, maxfev=10000)
    exponential_regression_results_all = exponential_func(fine_ages, *popt_exp_all)
except RuntimeError:
    # Fall back to piecewise-linear interpolation between the reference points
    exponential_regression_results_all = np.interp(fine_ages, age_groups_numeric, reference_counts)

fig, ax = plt.subplots(figsize=(12, 4))

# Reference HSC counts as points
ax.scatter(age_groups_numeric, reference_counts, label="HSCs number", color="black")

# Fitted (or interpolated) curve over the dense age grid
ax.plot(
    fine_ages,
    exponential_regression_results_all,
    label="Exponential Regression",
    linestyle="-",
    color="black",
    alpha=0.7,
)

# 5FU counts as horizontal lines
ax.axhline(y=hsc_nums_matrix_all[3], color="red", linestyle="-", label="HSCs number of 5fu30d")
ax.axhline(y=hsc_nums_matrix_all[4], color="blue", linestyle="-", label="HSCs number of 5fu60d")

ax.legend()
ax.set_title("HSCs number")

# Ticks only at the reference ages
ax.set_xticks([3, 12, 20])
ax.set_xticklabels(["3", "12", "20"])

# Axis labels
ax.set_xlabel("Estimated Age (Months)")
ax.set_ylabel("Number of HSCs")

ax.grid(axis="y", linestyle="--", alpha=0.5)
# Final legend, anchored outside the axes
ax.legend(bbox_to_anchor=(1.5, 1), loc="upper right", fontsize=12)
fig.tight_layout()

plt.show()


In [None]:
# Plot the HSCs number per condition with barplot and error bars
# For each condition, the bar is the mean HSC count over its datasets, the error
# bar is the standard deviation, and every dataset is drawn as a jittered dot.

# HSC counts per dataset for each condition
hsc_numbers_per_condition_3mo = transformed_3mo_bones_df[transformed_3mo_bones_df["source"] == "HSCs"].groupby("dataset").size()
hsc_numbers_per_condition_12mo = transformed_12mo_bones_df[transformed_12mo_bones_df["source"] == "HSCs"].groupby("dataset").size()
hsc_numbers_per_condition_20mo = transformed_20mo_bones_df[transformed_20mo_bones_df["source"] == "HSCs"].groupby("dataset").size()
hsc_numbers_per_condition_5fu30d = transformed_5fu30d_bones_df[transformed_5fu30d_bones_df["source"] == "HSCs"].groupby("dataset").size()
hsc_numbers_per_condition_5fu60d = transformed_5fu60d_bones_df[transformed_5fu60d_bones_df["source"] == "HSCs"].groupby("dataset").size()

hsc_numbers_per_condition = {
    "3mo": hsc_numbers_per_condition_3mo,
    "12mo": hsc_numbers_per_condition_12mo,
    "20mo": hsc_numbers_per_condition_20mo,
    "5fu30d": hsc_numbers_per_condition_5fu30d,
    "5fu60d": hsc_numbers_per_condition_5fu60d,
}

# Calculate the mean and standard deviation for each condition
hsc_numbers_mean = {key: value.mean() for key, value in hsc_numbers_per_condition.items()}
hsc_numbers_std = {key: value.std() for key, value in hsc_numbers_per_condition.items()}

# Plot the HSC numbers per condition with barplot and error bars.
# Plain lists are passed instead of dict views so matplotlib gets ordinary sequences.
condition_names = list(hsc_numbers_per_condition)
fig, ax = plt.subplots(figsize=(15, 5))
ax.bar(
    condition_names,
    [hsc_numbers_mean[name] for name in condition_names],
    yerr=[hsc_numbers_std[name] for name in condition_names],
    capsize=5,
    color="skyblue",
)

# Show one dot per dataset, jittered horizontally around the bar position.
# Dedicated names are used here: rebinding the notebook-level `y` would overwrite
# the regression target that earlier cells pass to the model.
# NOTE: the jitter uses the unseeded global NumPy RNG, so dot positions vary between runs.
for bar_position, dataset_counts in enumerate(hsc_numbers_per_condition.values()):
    jitter_x = np.random.normal(bar_position, 0.05, size=len(dataset_counts))
    ax.plot(jitter_x, dataset_counts, "o", color="black", alpha=1, markersize=5)
ax.set_title("HSC Numbers per Condition")
ax.set_ylabel("HSC Number")
ax.set_xlabel("Condition")
ax.grid(axis="y", linestyle="--", alpha=0.5)

plt.show()


In [None]:
# cKit KL divergence curves of the two 5FU conditions
cKit_KL_5fu30d = elements_dist_5fu30d["cKit_KL"]
cKit_KL_5fu60d = elements_dist_5fu60d["cKit_KL"]

# Settings shared by every line plot in this cell
colors = ["red", "blue"]
num_clusters = 10
method = "min"

# KL divergence panels
labels = ["cKit KL div 5fu30d", "cKit KL div 5fu60d"]
y_label = "KL Divergence"
plot_lineplot([cKit_KL_5fu30d, cKit_KL_5fu60d], method, num_clusters, labels, y_label, results_dir, colors)

# cKit neighborhood panels: the metric is selected by affinity_flag
if affinity_flag not in ("MSE", "CE"):
    raise ValueError("Invalid affinity flag. Please choose either 'MSE' or 'CE'")

if affinity_flag == "MSE":
    # Mean squared error of the cKit neighborhood
    cKit_MSE_5fu30d = elements_dist_5fu30d["cKit_MSE"]
    cKit_MSE_5fu60d = elements_dist_5fu60d["cKit_MSE"]

    labels = ["cKit MSE 5fu30d", "cKit MSE 5fu60d"]
    y_label = "MSE"
    plot_lineplot([cKit_MSE_5fu30d, cKit_MSE_5fu60d], method, num_clusters, labels, y_label, results_dir, colors)
else:
    # Cross entropy (CE) of the cKit neighborhood
    cKit_cross_entropy_5fu30d = elements_dist_5fu30d["cKit_CE"]
    cKit_cross_entropy_5fu60d = elements_dist_5fu60d["cKit_CE"]

    labels = ["cKit CE 5fu30d", "cKit CE 5fu60d"]
    y_label = "Cross Entropy"
    plot_lineplot([cKit_cross_entropy_5fu30d, cKit_cross_entropy_5fu60d], method, num_clusters, labels, y_label, results_dir, colors)



### 6.2 Spatial density map of HSCs, RDs, and cKits

In [None]:
# Visualize the KDE for the metabone of the HSCs, RDs
def plot_SPDM(df, kde_results, common_grid, save_dir, color_map, bone_outline, filename=None, mesh = True, scatter = False, anno_text = None):
    """Plot one spatial density panel per source listed in ``color_map``.

    Parameters:
    - df (pd.DataFrame): Point table with "source", "Position.X" and "Position.Y"
                         columns; only read when ``scatter`` is True.
    - kde_results (dict): Maps each source name to its density array on the common
                          grid. The arrays are not modified.
    - common_grid (tuple): Three items; the first two are the x and y coordinate
                           grids, the third is ignored.
    - save_dir: Not used by this function; kept so existing calls keep working.
    - color_map (dict): Maps source name to a hex color. The key "GFP" is skipped.
    - bone_outline (np.ndarray): Outline vertices, x in column 0 and y in column 1.
    - filename (str, optional): Text inserted into the figure title.
    - mesh (bool): Draw the density as a color mesh whose alpha follows the density.
    - scatter (bool): Overlay the individual points of each source.
    - anno_text (sequence, optional): Per-panel text appended to the panel title.
    """
    # One panel per color-map entry, in color-map order, excluding "GFP"
    sources = [key for key in color_map.keys() if key != "GFP"]

    fig, axs = plt.subplots(len(sources), 1, figsize=(15 , 5 *len(sources)), sharex=True, sharey=True, constrained_layout=True)
    if len(sources) == 1:
        axs = [axs]  # since we only have one row, make it iterable

    # The grid, the inside-outline mask and the axis limits are identical for every
    # panel, so they are computed once instead of once per source.
    xi, yi, _ = common_grid # TODO: use the dict to save the common grid so we can use string to index
    mask = points_in_polygon(xi.flatten(), yi.flatten(), bone_outline).reshape(xi.shape)

    xlim_min, ylim_min = bone_outline.min(axis=0)
    xlim_max, ylim_max = bone_outline.max(axis=0)
    # Integer limits padded by a fixed margin around the outline
    xlim_min, xlim_max = int(xlim_min)-500, int(xlim_max)+500
    ylim_min, ylim_max = int(ylim_min)-300, int(ylim_max)+300

    percentiles = np.arange(0, 81, 20)

    for i, source in enumerate(sources):
        ax = axs[i]

        # Copy before masking: zeroing the array in place would alter the caller's
        # KDE results, which are reused by later cells.
        zi = kde_results[source].copy()
        # Zero the density outside the bone outline
        zi[~mask] = 0

        # Density values inside the outline define the normalization range
        zi_inside_mask = zi[mask]

        # Normalize against the inside-outline range; outside stays fully transparent
        norm = Normalize(vmin=zi_inside_mask.min(), vmax=zi_inside_mask.max())
        normed_z = norm(zi)
        normed_z[~mask] = 0

        # Constant RGB from the color map, alpha taken from the normalized density
        rgb = np.array(plt.cm.colors.hex2color(color_map[source]))
        rgba_colors = np.zeros((*zi.shape, 4))
        rgba_colors[..., :3] = rgb[:3]  # RGB values
        rgba_colors[..., -1] = normed_z  # Alpha transparency

        if anno_text is None:
            ax.set_title(source)
        else:
            ax.set_title(f"{source}, {anno_text[i]}")

        ax.set_xlim(xlim_min, xlim_max)
        ax.set_ylim(ylim_min, ylim_max)
        ax.set_xlabel("Position.X")
        ax.set_ylabel("Position.Y")
        if mesh:
            ax.pcolormesh(xi, yi, rgba_colors, shading="auto", rasterized=True)

        # Contour lines at the 0/20/40/60/80th percentiles of the inside-outline
        # density; skipped when they collapse to a single level
        contour_levels = np.unique(np.percentile(norm(zi_inside_mask), percentiles))
        if len(contour_levels) > 1:
            ax.contour(xi, yi, normed_z, levels=contour_levels, linewidths=1, colors="black", alpha=0.5)

        # Overlay the individual points of this source
        if scatter:
            source_positions = df[df["source"] == source]
            ax.scatter(source_positions["Position.X"], source_positions["Position.Y"], s=1, c=color_map[source], alpha=1)

        ax.plot(bone_outline[:,0], bone_outline[:, 1], color="black")
        ax.grid(False)

    # Figure title (the figure is shown, not written to disk)
    fig.suptitle(f"Spatial Distribution Estimation of {filename}", fontsize=16)

    plt.show()

In [None]:
def _dump_pickle(obj, path):
    """Write ``obj`` to ``path`` with pickle."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


# KDE of both 5FU conditions on the reference outline (second return value is discarded)
kde_results_5fu30d_bones, _ = kde_for_clusters(transformed_5fu30d_bones_df, bone_outline=ref_outline, binsize=10)
kde_results_5fu60d_bones, _ = kde_for_clusters(transformed_5fu60d_bones_df, bone_outline=ref_outline, binsize=10)

# Persist the KDE results in the results directory
_dump_pickle(kde_results_5fu30d_bones, f"{results_dir}/kde_results_5fu30d_bones.pkl")
_dump_pickle(kde_results_5fu60d_bones, f"{results_dir}/kde_results_5fu60d_bones.pkl")
    

In [None]:
def _load_pickle(path):
    """Return the object unpickled from ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# KDE results of the reference ages and of the 5FU conditions
kde_results_3mo_bones = _load_pickle(f"{results_dir}/kde_results_3mo_ref.pkl")
kde_results_12mo_bones = _load_pickle(f"{results_dir}/kde_results_12mo_ref.pkl")
kde_results_20mo_bones = _load_pickle(f"{results_dir}/kde_results_20mo_ref.pkl")
kde_results_5fu30d_bones = _load_pickle(f"{results_dir}/kde_results_5fu30d_bones.pkl")
kde_results_5fu60d_bones = _load_pickle(f"{results_dir}/kde_results_5fu60d_bones.pkl")

# Grid the KDE results are defined on
ref_grid = _load_pickle(f"{results_dir}/ref_grid.pkl")


In [None]:
# HSC / RD density maps for every condition, with the individual points overlaid
for bones_df, kde_result, figure_name in (
    (transformed_3mo_bones_df, kde_results_3mo_bones, "3mo_HSCs_RDs"),
    (transformed_12mo_bones_df, kde_results_12mo_bones, "12mo_HSCs_RDs"),
    (transformed_20mo_bones_df, kde_results_20mo_bones, "20mo_HSCs_RDs"),
    (transformed_5fu30d_bones_df, kde_results_5fu30d_bones, "5fu30d_HSCs_RDs"),
    (transformed_5fu60d_bones_df, kde_results_5fu60d_bones, "5fu60d_HSCs_RDs"),
):
    plot_SPDM(bones_df, kde_result, ref_grid, results_dir, hsc_rd_colors, ref_outline, figure_name, mesh=True, scatter=True)


In [None]:
# cKit density maps for every condition, without the point overlay
for bones_df, kde_result, figure_name in (
    (transformed_3mo_bones_df, kde_results_3mo_bones, "3mo_cKits"),
    (transformed_12mo_bones_df, kde_results_12mo_bones, "12mo_cKits"),
    (transformed_20mo_bones_df, kde_results_20mo_bones, "20mo_cKits"),
    (transformed_5fu30d_bones_df, kde_results_5fu30d_bones, "5fu30d_cKits"),
    (transformed_5fu60d_bones_df, kde_results_5fu60d_bones, "5fu60d_cKits"),
):
    plot_SPDM(bones_df, kde_result, ref_grid, results_dir, ckits_color_map, ref_outline, figure_name, mesh=True, scatter=False)


### 6.3 Heatmap of the cKits affinity matrix

In [None]:
# Plot the interpolated affinity matrices 
def plot_affinity_heatmap(df, title, save_path=None, diagonal_text=None):
    """
    Plots a heatmap for the given affinity matrix with the diagonal blacked out.

    The input is never modified: the diagonal is blanked on an internal float copy.

    Parameters:
    - df (pd.DataFrame or array-like): Square affinity matrix. Anything that is not
                                       a DataFrame is converted to one.
    - title (str): Title for the heatmap.
    - save_path (str, optional): Path to save the figure. If None, the figure is not saved.
    - diagonal_text (dict, optional): Dictionary mapping row indices to text to display on the diagonal.
                                    If None, the diagonal will remain black without text.
    """
    # Convert matrix to DataFrame if needed
    source = df if isinstance(df, pd.DataFrame) else pd.DataFrame(df)

    # Set the diagonal to NaN on a float copy so it can be masked. Writing into
    # `df.values` directly would change the caller's matrix and fails when the data
    # are integer-typed or exposed as a read-only array.
    values = source.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(values, np.nan)
    df = pd.DataFrame(values, index=source.index, columns=source.columns)

    # Create the figure and axis
    fig, ax = plt.subplots(figsize=(10, 8))

    # Create a mask for NaN values
    mask = df.isna()

    # Plot heatmap on a fixed 0-0.55 color scale
    sns.heatmap(df, ax=ax, cmap="coolwarm", annot=False, square=True, mask=mask, 
                cbar_kws={"shrink": 0.7, "aspect": 30, "ticks":[0.0, 0.2, 0.4, 0.55]}, rasterized=True, vmin=0, vmax=0.55)

    # Set title
    ax.set_title(title)

    # Set x and y labels
    ax.set_xlabel("Cluster")
    ax.set_ylabel("Neighborhood")

    # Get matrix size
    matrix_size = df.shape[0]

    # Add black horizontal lines between rows
    for i in range(1, matrix_size):
        ax.hlines(i, *ax.get_xlim(), colors="black", linewidth=1)

    # Overlay black patches for diagonal and optional text
    for i in range(matrix_size):
        ax.add_patch(plt.Rectangle((i, i), 1, 1, color="black", ec=None))  # Black diagonal

        if diagonal_text:  # If diagonal_text is provided, add text
            ax.text(i + 0.5, i + 0.5, diagonal_text.get(i, ""), ha="center", va="center", 
                    color="white", fontsize=10, fontweight="bold")

    # Set x-ticks and y-ticks from 1 to matrix_size (centered on the cells)
    ax.set_xticks(np.arange(matrix_size) + 0.5)
    ax.set_yticks(np.arange(matrix_size) + 0.5)
    ax.set_xticklabels(np.arange(1, matrix_size + 1))
    ax.set_yticklabels(np.arange(1, matrix_size + 1))

    # Save the figure if save_path is provided
    if save_path:
        fig.savefig(save_path, dpi=300)

    plt.show()

In [None]:
# Plot the normalized affinity matrix for all the data
# The date tag (yymmdd) is built outside the f-string: reusing double quotes inside an
# f-string expression is a SyntaxError before Python 3.12.
date_tag = datetime.today().strftime("%y%m%d")
for key in ["3mo", "12mo", "20mo", "5fu30d", "5fu60d"]:
    # Normalize the affinity matrix by row (every row is divided by its own sum)
    affinity_matrix_plot = affinity_matrices[key].copy()
    row_sums = affinity_matrix_plot.sum(axis=1).to_numpy()[:, np.newaxis]
    affinity_matrix_plot = affinity_matrix_plot / row_sums
    # Pick specific rows (not used by the plotting call below)
    row_picked = [4, 8]

    file_path = os.path.join(results_dir, f"{date_tag}_HeM-HeM Association_{key}.pdf")
    plot_affinity_heatmap(affinity_matrix_plot, f"Normalized HeM-HeM Association for {key}", save_path=file_path)

### 6.4 Bone age visualization with the linear model

In [None]:
# Turn the five model weights into marker-size multipliers for the plot.
# Each weight is expressed in percent and log-scaled (base 2); `min_size` is added
# as an offset. NOTE: the result falls below `min_size` when weight * 100 < 1.
min_size = 1


def _weight_to_marker_scale(weight):
    """Return ``min_size + log2(weight * 100)`` for a single weight."""
    return min_size + np.log(weight * 100) / np.log(2)


ckit_pdf_size = _weight_to_marker_scale(final_weights[0])
hsc_pdf_size = _weight_to_marker_scale(final_weights[1])
ckit_affinity_size = _weight_to_marker_scale(final_weights[2])
hsc_num_size = _weight_to_marker_scale(final_weights[3])
hsc_cluster_comp_size = _weight_to_marker_scale(final_weights[4])


In [None]:
# Plotting
size_factor = 1000
final_weights = np.array([cKit_pdf_weight, HSC_pdf_weight, cKit_affinity_weight, HSC_num_weight, HSC_cluster_comp_weight])

y_scale = 1.3
alpha = 1

fig, ax = plt.subplots(figsize=(12, 8*y_scale)) 

# Read the estimated ages for 5fu30d and 5fu60d with HSC cluster composition.
# The date tag (yymmdd) is built outside the f-strings: reusing double quotes inside an
# f-string expression is a SyntaxError before Python 3.12.
# NOTE: the file names embed today's date, so the files must have been written today.
date_tag = datetime.today().strftime("%y%m%d")
with open(os.path.join(results_dir, f"{date_tag}_estimated_ages_5fu30d_NoNumCluster_{affinity_flag}.pkl"), "rb") as f:
    estimated_ages_5fu30d = pickle.load(f)
with open(os.path.join(results_dir, f"{date_tag}_estimated_ages_5fu60d_NoNumCluster_{affinity_flag}.pkl"), "rb") as f:
    estimated_ages_5fu60d = pickle.load(f)


def _scatter_condition(condition, color, ages, cluster_sizes):
    """Draw the per-feature age estimates of one condition.

    Rows 0-2 hold HSC numbers, HSC distribution and HSC cluster composition;
    the per-cluster HeM estimates start at row 3. Returns the HeM-distribution
    scatter artist.
    """
    hem_rows = [c + 3 for c in clusters]
    ckit_cluster_sizes = list(cluster_sizes["cKits"].values())
    hem_scatter = ax.scatter(ages["cKit_KDE"], hem_rows, color=color, marker="o", s=[int(s*size_factor*ckit_pdf_size) for s in ckit_cluster_sizes], label=f"{condition} HeM distribution", alpha=alpha, edgecolors="black")
    ax.scatter(ages["merged_HSC_KDE"], 1, color=color, marker="X", s=50*hsc_pdf_size, label=f"{condition} HSC distribution", alpha=alpha, edgecolors="black")
    ax.scatter(ages["cKit_affinity"], hem_rows, color=color, marker="^", s=[int(s*size_factor*ckit_affinity_size) for s in ckit_cluster_sizes], label=f"{condition} HeM neighborhood", alpha=alpha, edgecolors="black")
    ax.scatter(ages["HSC_numbers"], 0, color=color, marker="D", s=50*hsc_num_size, label=f"{condition} HSC numbers", alpha=alpha, edgecolors="black")
    ax.scatter(ages["HSC_cluster_comp"], 2, color=color, marker="s", s=50*hsc_cluster_comp_size, label=f"{condition} HSC Cluster Composition", alpha=alpha, edgecolors="black")
    return hem_scatter


# Plot the estimated ages for 5fu30d (red) and 5fu60d (blue) as scatter plots
scatter_5fu30d = _scatter_condition("5fu30d", "red", estimated_ages_5fu30d, cluster_size_5fu30d)
_scatter_condition("5fu60d", "blue", estimated_ages_5fu60d, cluster_size_5fu60d)

# Add the final estimated ages for 5fu30d and 5fu60d
ax.axvline(final_age_5fu30d, color="red", linestyle="--", label=f"Weighted Average 5fu30d:{final_age_5fu30d:.2f}", alpha = alpha)
ax.axvline(final_age_5fu60d, color="blue", linestyle="--", label=f"Weighted Average 5fu60d:{final_age_5fu60d:.2f}", alpha = alpha)

# Reference vertical lines
ax.axvline(3, color="black", linestyle=":", label="3mo", alpha = 0.3)
ax.axvline(12, color="black", linestyle=":", label="12mo", alpha = 0.3)
ax.axvline(20, color="black", linestyle=":", label="20mo", alpha = 0.3)

# Axes labels, title, and limits
ax.set_xlim(2, 21)
ax.set_xticks([3, 12, 20])
ax.set_xlabel("Age (months)")
ax.set_ylabel("Cluster")
ax.set_title("Estimated Age of 5fu30d and 5fu60d")
ax.set_yticks(clusters + [len(clusters), len(clusters)+1, len(clusters)+2])
# The loop variable has its own name so it no longer shadows `clusters`
ax.set_yticklabels(["HSC numbers"]+["HSC distribution"]+["HSC Cluster Composition"]+[f"HeM# {cluster+1}" for cluster in clusters])
# Invert the y-axis so the HSC rows come first, followed by the HeM clusters
ax.invert_yaxis()

# Combine legends: plotted artists first, then one text-only entry per model weight
handles1, labels1 = ax.get_legend_handles_labels()
labels2 = [f"HeM distribution weight: {final_weights[0]:.4f}", f"HSC distribution weight: {final_weights[1]:.4f}", 
        f"HeM neighborhood weight: {final_weights[2]:.4f}", f"HSC numbers weight: {final_weights[3]:.4f}", f"HSC cluster composition weight: {final_weights[4]:.4f}"]

# White-on-white markers make the weight entries appear as plain text in the legend
handles2 = [Line2D([0], [0], marker="o", color="w", label=label,
                        markerfacecolor="white", markersize=10) for label in labels2]

# Create combined legend
handles = handles1 + handles2
labels = labels1 + labels2 

ax.legend(handles, labels, title="Estimated Age and Cluster Size", bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=14, title_fontsize=14)

# Grid
ax.yaxis.grid(True, linestyle="--", alpha=0.3)

# Show plot
plt.show()



In [None]:
# Heatmap matrix: one row per 5FU condition, 23 columns laid out as
#   0-9   HeM distribution (clusters 1-10)
#   10-19 HeM-HeM neighborhood (clusters 1-10)
#   20    HSC distribution, 21 HSC numbers, 22 HSC cluster composition
heatmap_data = np.zeros((2, 23))

for row, ages in enumerate((estimated_ages_5fu30d, estimated_ages_5fu60d)):
    for cluster in range(10):
        heatmap_data[row, cluster] = ages["cKit_KDE"][cluster]
        heatmap_data[row, 10 + cluster] = ages["cKit_affinity"][cluster]
    heatmap_data[row, 20] = ages["merged_HSC_KDE"]
    heatmap_data[row, 21] = ages["HSC_numbers"]
    heatmap_data[row, 22] = ages["HSC_cluster_comp"]

fig, ax = plt.subplots(figsize=(20, 6))

# Tick labels: cluster numbers 1-10 for both HeM blocks, blank for the three HSC columns
cluster_tick_labels = [str(number) for number in range(1, 11)]
x_labels = cluster_tick_labels + cluster_tick_labels + ["", "", ""]

# Square cells on a fixed 3-20 month color scale
sns.heatmap(
    heatmap_data,
    cmap="Greens",
    xticklabels=x_labels,
    yticklabels=["5FU30d", "5FU60d"],
    cbar_kws={"label": "Estimated Age (months)", "ticks": [3, 6, 9, 12, 15, 18, 20]},
    vmin=3,
    vmax=20,
    square=True,
    ax=ax,
)

# Keep the cluster numbers horizontal
ax.set_xticklabels(x_labels, rotation=0)

# Name of each feature block, written below the heatmap
for text_x, text_y, block_name in (
    (5, 3.5, "HeM distribution"),
    (15, 3.5, "HeM-HeM neighborhood"),
    (20.5, 3.5, "HSC distribution"),
    (21.5, 3, "HSC numbers"),
    (22.5, 2.5, "HSC cluster composition"),
):
    ax.text(text_x, text_y, block_name, ha="center", va="center", fontsize=12, fontweight="bold")

# Mark cells whose estimate lies near a reference age:
#   "<" below 4.5, ">" above 18.5, "o" strictly between 10.5 and 13.5
for row, col in np.ndindex(2, 23):
    estimate = heatmap_data[row, col]
    if estimate < 4.5:
        cell_marker = "<"
    elif estimate > 18.5:
        cell_marker = ">"
    elif 10.5 < estimate < 13.5:
        cell_marker = "o"
    else:
        continue
    ax.scatter(col + 0.5, row + 0.5, marker=cell_marker, color="black", s=100)

# Legend entries explaining the three markers
legend_elements = [
    Line2D([0], [0], marker=marker_symbol, color="w", label=marker_label, markerfacecolor="black", markersize=10)
    for marker_symbol, marker_label in (
        ("<", "3mo conserve"),
        ("o", "12mo conserve"),
        (">", "20mo conserve"),
    )
]

# Horizontal legend placed below the plot
ax.legend(handles=legend_elements, loc="lower center", bbox_to_anchor=(0.5, -2),
        title="Marker Legend", fontsize=12, title_fontsize=14, ncol=3, frameon=False)

# Title and labels
ax.set_title("Estimated Biological Age of 5FU-Treated Bone Marrow Samples")
ax.set_ylabel("Treatment")

# Show plot
plt.show()

### 6.5 SHAP (Gaussian model) visualization

In [None]:
# SHAP explains the GPR mean prediction only (the std output is switched off)
def _gpr_mean_prediction(x):
    """Return the Gaussian-process mean prediction for ``x``."""
    return gpr_model.predict(x, return_std=False)


shap_feature_names = list(X.columns.map(features_name_dict))
explainer = shap.Explainer(_gpr_mean_prediction, X_scaled, feature_names=shap_feature_names)

for i, features_test in enumerate([features_5fu30d, features_5fu60d]):
    # The first sample is 5fu30d, the second 5fu60d
    dataset_name = ("5fu30d", "5fu60d")[i]
    print(f"Test Condition: {dataset_name}")

    # Single-row DataFrame, scaled with the already-fitted scaler
    features_test = pd.DataFrame([features_test])
    test_features_scaled = scaler.transform(features_test)

    # SHAP values of this sample, drawn as a waterfall showing every feature
    shap_values = explainer(test_features_scaled)
    shap.plots.waterfall(shap_values[0], max_display=len(X.columns), show=False)

    # Mean prediction and its standard deviation for the title
    y_pred, y_std = gpr_model.predict(test_features_scaled, return_std=True)
    plt.title(f"SHAP Waterfall Plot for {dataset_name}\nPrediction: {y_pred[0]:.2f} ± {y_std[0]:.2f}", fontsize=14)

    plt.show()