# XN_SAMPLE Alignment Between Analysers

This notebook demonstrates how to align tabular XN_SAMPLE data between different Sysmex analysers using the `XNSampleTransformer` class.

**What is XN_SAMPLE alignment?**
- XN_SAMPLE.csv files contain complete blood count (CBC) parameters from Sysmex analysers
- Different analysers can produce systematically different measurements for the same sample
- A transformation based on the median and the MAD (median absolute deviation) aligns the distributions across analysers

**When to use XN_SAMPLE alignment:**
- Combining data from multiple Sysmex analysers when you do not want to use the GAM-based adjustment ("correction" module)
- Training ML models on multi-site data
- Longitudinal studies where analysers were upgraded/replaced

**Method:**
- Uses robust statistics (median and MAD) to align both the location and the spread of each distribution, making it slightly more sophisticated than the GAM method ("correction" module), which just adds a factor term
- Median and MAD are more robust to outliers than mean and standard deviation
- Transforms: `X_transformed = target_median + (X - source_median) * (target_MAD / source_MAD)`
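
The formula above can be sketched on synthetic data as follows. This is a minimal illustration, not the `XNSampleTransformer` implementation: `raw_mad` and `align_to_target` are hypothetical helper names, the MAD is assumed to be the unscaled version (no normal-consistency constant), and the source MAD is assumed to be non-zero.

```python
import numpy as np


def raw_mad(values):
    """Unscaled median absolute deviation: median(|x - median(x)|)."""
    values = np.asarray(values, dtype=float)
    return np.median(np.abs(values - np.median(values)))


def align_to_target(source, target):
    """Shift and rescale source so its median and MAD match the target's."""
    source = np.asarray(source, dtype=float)
    target = np.asarray(target, dtype=float)
    scale = raw_mad(target) / raw_mad(source)  # assumes raw_mad(source) > 0
    return np.median(target) + (source - np.median(source)) * scale


rng = np.random.default_rng(0)
source = rng.normal(loc=13.0, scale=1.5, size=5_000)  # synthetic "analyser A"
target = rng.normal(loc=14.0, scale=1.0, size=5_000)  # synthetic "analyser B"

aligned = align_to_target(source, target)
print(f"median: {np.median(source):.3f} -> {np.median(aligned):.3f}")
print(f"MAD:    {raw_mad(source):.3f} -> {raw_mad(aligned):.3f}")
print(f"target: median {np.median(target):.3f}, MAD {raw_mad(target):.3f}")
```

Because the mapping is affine with a positive scale, the aligned median and MAD match the target's up to floating-point error. Any constant factor applied to both MADs would cancel in the ratio.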

In [None]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import stats

# PDF-compatible fonts
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style
import scienceplots

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

# Add the repository root to the path to import sysmexcbctools
# This allows imports to work whether or not the package is installed
repo_root = (
    Path(__file__).parent.parent.parent
    if "__file__" in globals()
    else Path.cwd().parent.parent
)
sys.path.insert(0, str(repo_root))

from sysmexcbctools.transfer.sysmexalign import XNSampleTransformer
from sysmexcbctools.transfer.sysmexalign.alignment_1d import mad
from sysmexcbctools.transfer.config import load_config

print("Imports successful!")


# Function to escape special LaTeX characters in parameter names
def latex_escape(text):
    """Escape special characters for LaTeX rendering."""
    import re

    # FIRST: Escape LaTeX special characters (before adding math mode).
    # Backslashes, braces, and "$" are not handled here; the column names
    # used in this notebook do not contain them.
    text = text.replace("&", "\\&")
    text = text.replace("%", "\\%")
    text = text.replace("#", "\\#")
    text = text.replace("_", "\\_")
    # The trailing {} stops LaTeX from swallowing a following space
    text = text.replace("~", "\\textasciitilde{}")

    # THEN: Handle superscripts in math mode
    # Match patterns like (10^3/uL) and convert to LaTeX math mode
    # Convert (10^3/uL) -> ($10^{3}$/uL)
    text = re.sub(r"\((\d+)\^(\d+)/(\w+)\)", r"($\1^{\2}$/\3)", text)

    return text

## 1. Load XN_SAMPLE Data from Two Analysers

We'll use XN_SAMPLE data from two different blood donor studies:
- **Source**: STRIDES
- **Target**: INTERVAL (analyser 36)

These are real data from the INTERVAL and STRIDES studies, collected on different Sysmex XN analysers.

In [None]:
# Load configuration to get data paths
config = load_config(str(repo_root / "sysmexcbctools/transfer/config/data_paths.yaml"))

source_path = config["datasets"]["raw"]["strides"] + "/"
target_path = config["datasets"]["raw"]["interval_36"] + "/XN_SAMPLE.csv"

# Load data
print("Loading XN_SAMPLE data...")
# The source path is a folder with multiple subfolders, each containing an XN_SAMPLE.csv
import glob

source_files = sorted(glob.glob(source_path + "**/XN_SAMPLE.csv", recursive=True))
source_df = pd.concat(
    [pd.read_csv(f, low_memory=False) for f in source_files],
    ignore_index=True,  # avoid duplicate row labels across files
)
target_df = pd.read_csv(target_path, low_memory=False)

print(f"Source data shape: {source_df.shape}")
print(f"Target data shape: {target_df.shape}")
print(f"\nSource columns: {len(source_df.columns)} columns")
print(f"Target columns: {len(target_df.columns)} columns")

## 2. Define Columns to Transform

We'll transform standard full blood count (FBC) parameters. These are the core clinical measurements that should be aligned between analysers.

In [None]:
# Standard FBC columns to transform
fbc_columns = [
    # Red blood cell parameters
    "RBC(10^6/uL)",  # Red blood cell count
    "HGB(g/dL)",  # Hemoglobin
    "HCT(%)",  # Hematocrit
    "MCV(fL)",  # Mean corpuscular volume
    "MCH(pg)",  # Mean corpuscular hemoglobin
    "MCHC(g/dL)",  # Mean corpuscular hemoglobin concentration
    "RDW-SD(fL)",  # Red cell distribution width (SD)
    "RDW-CV(%)",  # Red cell distribution width (CV)
    # White blood cell parameters
    "WBC(10^3/uL)",  # White blood cell count
    "NEUT#(10^3/uL)",  # Neutrophil count
    "LYMPH#(10^3/uL)",  # Lymphocyte count
    "MONO#(10^3/uL)",  # Monocyte count
    "EO#(10^3/uL)",  # Eosinophil count
    "BASO#(10^3/uL)",  # Basophil count
    "NEUT%(%)",  # Neutrophil percentage
    "LYMPH%(%)",  # Lymphocyte percentage
    "MONO%(%)",  # Monocyte percentage
    "EO%(%)",  # Eosinophil percentage
    "BASO%(%)",  # Basophil percentage
    # Platelet parameters
    "PLT(10^3/uL)",  # Platelet count
    "MPV(fL)",  # Mean platelet volume
    "PCT(%)",  # Plateletcrit
    "PDW(fL)",  # Platelet distribution width
]

# Check which columns are available in both datasets (keeping the order of fbc_columns)
source_cols_available = [col for col in fbc_columns if col in source_df.columns]
target_cols_available = [col for col in fbc_columns if col in target_df.columns]
columns_to_transform = [
    col for col in source_cols_available if col in target_cols_available
]

print(f"Requested columns: {len(fbc_columns)}")
print(f"Available in source: {len(source_cols_available)}")
print(f"Available in target: {len(target_cols_available)}")
print(f"Available in both: {len(columns_to_transform)}")
print(f"\nColumns to transform: {sorted(columns_to_transform)}")

## 3. Load Centile Sample Numbers (Optional)

We can load sample numbers that identify the samples used for fitting, in lieu of an official standard (such as a parallel calibration measurement).
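
The filtering step can be sketched on a toy table. The data and the `reference_sample_nos` list below are made up for illustration. The point is that IDs are normalised to stripped strings before matching, so stray whitespace or numeric IDs do not cause silent mismatches.

```python
import pandas as pd

# Hypothetical miniature XN_SAMPLE table; real files have many more columns.
df = pd.DataFrame(
    {
        "Sample No.": [" A001", "A002 ", "A003", 1004],
        "HGB(g/dL)": [13.5, 14.1, 12.8, 15.0],
    }
)
reference_sample_nos = ["A001", "A003", "1004"]  # hypothetical reference list

# Normalise IDs to stripped strings so " A001" and 1004 both match the list.
sample_ids = df["Sample No."].astype(str).str.strip()
reference_df = df[sample_ids.isin(reference_sample_nos)]
print(reference_df)
```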

In [None]:
# Load the centile sample numbers used to filter the data before fitting
# (optional but recommended)
source_samples_path = config["files"]["centile_samples"]["strides"]
target_samples_path = config["files"]["centile_samples"]["interval_baseline_36"]

source_sample_nos = np.load(source_samples_path, allow_pickle=True)
target_sample_nos = np.load(target_samples_path, allow_pickle=True)

# Convert to list of strings for matching
source_sample_nos = [str(s).strip() for s in source_sample_nos]
target_sample_nos = [str(s).strip() for s in target_sample_nos]

print(f"Source centile samples: {len(source_sample_nos)}")
print(f"Target centile samples: {len(target_sample_nos)}")

# Check how many samples match
source_df["Sample No."] = source_df["Sample No."].astype(str).str.strip()
target_df["Sample No."] = target_df["Sample No."].astype(str).str.strip()

source_matched = source_df["Sample No."].isin(source_sample_nos).sum()
target_matched = target_df["Sample No."].isin(target_sample_nos).sum()

print(
    f"\nMatched samples in source: {source_matched:,} / {len(source_df):,} ({100*source_matched/len(source_df):.1f}%)"
)
print(
    f"Matched samples in target: {target_matched:,} / {len(target_df):,} ({100*target_matched/len(target_df):.1f}%)"
)

## 4. Create and Fit XNSampleTransformer

Now we'll create the transformer and fit it using the centile samples. The transformer will compute median and MAD for each column in both source and target distributions.
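
As a rough picture of the statistics a fit step of this kind needs, the sketch below computes a per-column median and unscaled MAD with pandas. It is an assumption-laden illustration, not the internals of `XNSampleTransformer`: `column_stats` is a hypothetical helper, and the `"----"` entry is just a placeholder for any non-numeric value.

```python
import pandas as pd


def column_stats(df, columns):
    """Per-column median and unscaled MAD (hypothetical helper)."""
    # Non-numeric entries become NaN and are ignored by median()
    numeric = df[columns].apply(pd.to_numeric, errors="coerce")
    medians = numeric.median()
    mads = (numeric - medians).abs().median()
    return pd.DataFrame({"median": medians, "mad": mads})


demo = pd.DataFrame(
    {
        "HGB(g/dL)": [12.0, 13.0, 14.0, 15.0, "----"],  # one non-numeric entry
        "WBC(10^3/uL)": [4.0, 5.0, 6.0, 7.0, 30.0],  # one outlier
    }
)
stats_df = column_stats(demo, ["HGB(g/dL)", "WBC(10^3/uL)"])
print(stats_df)
```

In this toy table the WBC outlier of 30.0 leaves the median at 6.0 and the MAD at 1.0, which is the robustness property the method relies on.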

In [None]:
# Create transformer
transformer = XNSampleTransformer(columns=columns_to_transform)

# Fit transformer
print("Fitting XNSampleTransformer...\n")
transformer.fit(
    source_df=source_df,
    target_df=target_df,
    source_sample_nos=source_sample_nos,
    target_sample_nos=target_sample_nos,
)

print("\nTransformer fitted successfully!")

## 5. Transform Source Data

Apply the fitted transformation to align the source distribution to the target distribution.

In [None]:
# Transform the source data
print("Transforming source data...\n")
transformed_df = transformer.transform(source_df.copy())

print(f"\nTransformed data shape: {transformed_df.shape}")
print("Transformation complete!")

## 6. Evaluate Transformation Quality

Let's evaluate how well the transformation aligns the distributions using multiple metrics.
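
To give a feel for two of the metrics used below, here is a small sketch on synthetic data. It applies a median/MAD alignment following the formula in the Method section (with an unscaled MAD, which is an assumption) and reports the 1D Wasserstein distance and the two-sample KS statistic before and after. The `raw_mad` helper and all values are illustrative only.

```python
import numpy as np
from scipy import stats


def raw_mad(values):
    """Unscaled median absolute deviation."""
    return np.median(np.abs(values - np.median(values)))


rng = np.random.default_rng(1)
source = rng.normal(loc=13.0, scale=1.5, size=2_000)  # synthetic source analyser
target = rng.normal(loc=14.0, scale=1.0, size=2_000)  # synthetic target analyser

# Median/MAD alignment of the source onto the target
aligned = np.median(target) + (source - np.median(source)) * (
    raw_mad(target) / raw_mad(source)
)

metrics = {}
for name, sample in [("before", source), ("after", aligned)]:
    metrics[name] = {
        "wasserstein": stats.wasserstein_distance(sample, target),
        "ks": stats.ks_2samp(sample, target).statistic,
    }
    print(name, metrics[name])
```

Both metrics compare whole distributions, so they can reveal residual shape differences that matching the median and MAD alone does not remove.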

In [None]:
# Compute evaluation metrics for baseline samples
def evaluate_transformation(
    source_df, transformed_df, target_df, sample_nos_source, sample_nos_target, columns
):
    """
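    Evaluate transformation quality using multiple metrics.

    For each column, compares source and transformed values against the target
    (restricted to the given sample numbers) using the median, MAD, Wasserstein
    distance, and two-sample KS statistic. Columns with fewer than 10 valid
    source or target values are skipped. Returns one row per evaluated column.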
    """
    results = []

    for col in columns:
        # Filter to baseline samples
        source_vals = pd.to_numeric(
            source_df.loc[source_df["Sample No."].isin(sample_nos_source), col],
            errors="coerce",
        ).dropna()

        transformed_vals = pd.to_numeric(
            transformed_df.loc[
                transformed_df["Sample No."].isin(sample_nos_source), col
            ],
            errors="coerce",
        ).dropna()

        target_vals = pd.to_numeric(
            target_df.loc[target_df["Sample No."].isin(sample_nos_target), col],
            errors="coerce",
        ).dropna()

        if len(source_vals) < 10 or len(target_vals) < 10:
            continue

        # Compute statistics
        source_median = source_vals.median()
        source_mad = mad(source_vals)

        transformed_median = transformed_vals.median()
        transformed_mad = mad(transformed_vals)

        target_median = target_vals.median()
        target_mad = mad(target_vals)

        # Median difference (before and after)
        median_diff_before = abs(source_median - target_median)
        median_diff_after = abs(transformed_median - target_median)
        median_improvement = (
            100 * (median_diff_before - median_diff_after) / median_diff_before
            if median_diff_before > 0
            else 0
        )

        # MAD difference (before and after)
        mad_diff_before = abs(source_mad - target_mad)
        mad_diff_after = abs(transformed_mad - target_mad)
        mad_improvement = (
            100 * (mad_diff_before - mad_diff_after) / mad_diff_before
            if mad_diff_before > 0
            else 0
        )

        # Wasserstein distance (1D)
        wasserstein_before = stats.wasserstein_distance(source_vals, target_vals)
        wasserstein_after = stats.wasserstein_distance(transformed_vals, target_vals)
        wasserstein_improvement = (
            100 * (wasserstein_before - wasserstein_after) / wasserstein_before
            if wasserstein_before > 0
            else 0
        )

        # Two-sample KS statistic (largest gap between the two empirical CDFs)
        ks_before = stats.ks_2samp(source_vals, target_vals).statistic
        ks_after = stats.ks_2samp(transformed_vals, target_vals).statistic
        ks_improvement = (
            100 * (ks_before - ks_after) / ks_before if ks_before > 0 else 0
        )

        results.append(
            {
                "Column": col,
                "Median (Source)": source_median,
                "Median (Transformed)": transformed_median,
                "Median (Target)": target_median,
                "Median Δ Before": median_diff_before,
                "Median Δ After": median_diff_after,
                "Median Improvement %": median_improvement,
                "MAD (Source)": source_mad,
                "MAD (Transformed)": transformed_mad,
                "MAD (Target)": target_mad,
                "MAD Δ Before": mad_diff_before,
                "MAD Δ After": mad_diff_after,
                "MAD Improvement %": mad_improvement,
                "Wasserstein Before": wasserstein_before,
                "Wasserstein After": wasserstein_after,
                "Wasserstein Improvement %": wasserstein_improvement,
                "KS Before": ks_before,
                "KS After": ks_after,
                "KS Improvement %": ks_improvement,
            }
        )

    return pd.DataFrame(results)


# Evaluate transformation
print("Evaluating transformation quality...\n")
eval_df = evaluate_transformation(
    source_df,
    transformed_df,
    target_df,
    source_sample_nos,
    target_sample_nos,
    columns_to_transform,
)

# Display results
print("Transformation Quality Metrics:")
print("=" * 80)
display(
    eval_df[
        [
            "Column",
            "Median Improvement %",
            "MAD Improvement %",
            "Wasserstein Improvement %",
            "KS Improvement %",
        ]
    ].round(2)
)

# Summary statistics
print("\nSummary Statistics:")
print(f"Average Median Improvement: {eval_df['Median Improvement %'].mean():.2f}%")
print(f"Average MAD Improvement: {eval_df['MAD Improvement %'].mean():.2f}%")
print(
    f"Average Wasserstein Improvement: {eval_df['Wasserstein Improvement %'].mean():.2f}%"
)
print(f"Average KS Improvement: {eval_df['KS Improvement %'].mean():.2f}%")

## 7. Visualize Transformation Results

Let's visualize the transformation for key FBC parameters.

In [None]:
# Select key columns to visualize
viz_columns = [
    "HGB(g/dL)",
    "RBC(10^6/uL)",
    "WBC(10^3/uL)",
    "PLT(10^3/uL)",
    "MCV(fL)",
    "NEUT#(10^3/uL)",
]
viz_columns = [col for col in viz_columns if col in columns_to_transform]

# Create figure
fig, axes = plt.subplots(2, 3, figsize=(6.6, 4.5))
axes = axes.flatten()

for idx, col in enumerate(viz_columns[:6]):
    ax = axes[idx]

    # Extract data for baseline samples
    source_vals = pd.to_numeric(
        source_df.loc[source_df["Sample No."].isin(source_sample_nos), col],
        errors="coerce",
    ).dropna()

    transformed_vals = pd.to_numeric(
        transformed_df.loc[transformed_df["Sample No."].isin(source_sample_nos), col],
        errors="coerce",
    ).dropna()

    target_vals = pd.to_numeric(
        target_df.loc[target_df["Sample No."].isin(target_sample_nos), col],
        errors="coerce",
    ).dropna()

    # Plot histograms
    ax.hist(
        source_vals,
        bins=50,
        alpha=0.5,
        label="Source (STRIDES)",
        density=True,
        color="blue",
    )
    ax.hist(
        transformed_vals,
        bins=50,
        alpha=0.5,
        label="Transformed",
        density=True,
        color="green",
    )
    ax.hist(
        target_vals,
        bins=50,
        alpha=0.5,
        label="Target (INTERVAL 36)",
        density=True,
        color="orange",
    )

    ax.set_xlabel(latex_escape(col))
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_title(latex_escape(col))

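plt.tight_layout()
os.makedirs("../outputs", exist_ok=True)  # savefig does not create missing directories
plt.savefig(
    "../outputs/xnsample_transformation_distributions.png", dpi=300, bbox_inches="tight"
)
plt.show()

print(
    "Distribution plots saved to ../outputs/xnsample_transformation_distributions.png"
)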

## 8. Quantile-Quantile Plots

Q-Q plots show how well the transformed distribution matches the target distribution.
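
The idea behind a Q-Q comparison can be sketched numerically: compute matching percentiles of two samples and measure how far the pairs sit from the diagonal. The data are synthetic and `qq_deviation` is a hypothetical helper. This sketch uses the 1st to 99th percentiles, which leaves out the minimum and maximum.

```python
import numpy as np

rng = np.random.default_rng(2)
target = rng.normal(14.0, 1.0, size=3_000)  # synthetic target sample
well_aligned = rng.normal(14.0, 1.0, size=3_000)  # same distribution as target
shifted = well_aligned + 0.5  # same shape, offset by 0.5

percentiles = np.arange(1, 100)  # 1st..99th percentile


def qq_deviation(sample, reference):
    """Mean absolute gap between matching quantiles (0 means on the diagonal)."""
    gaps = np.percentile(sample, percentiles) - np.percentile(reference, percentiles)
    return np.mean(np.abs(gaps))


print(f"well aligned: {qq_deviation(well_aligned, target):.3f}")
print(f"shifted:      {qq_deviation(shifted, target):.3f}")
```

Points that curve away from the diagonal only in the tails suggest a shape difference that a shift-and-scale transformation cannot correct.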

In [None]:
# Create Q-Q plots
fig, axes = plt.subplots(2, 3, figsize=(6.6, 4.5))
axes = axes.flatten()

for idx, col in enumerate(viz_columns[:6]):
    ax = axes[idx]

    # Extract data
    transformed_vals = (
        pd.to_numeric(
            transformed_df.loc[
                transformed_df["Sample No."].isin(source_sample_nos), col
            ],
            errors="coerce",
        )
        .dropna()
        .values
    )

    target_vals = (
        pd.to_numeric(
            target_df.loc[target_df["Sample No."].isin(target_sample_nos), col],
            errors="coerce",
        )
        .dropna()
        .values
    )

    # Compute quantiles
    quantiles = np.linspace(0, 100, 100)
    transformed_quantiles = np.percentile(transformed_vals, quantiles)
    target_quantiles = np.percentile(target_vals, quantiles)

    # Plot Q-Q
    ax.scatter(target_quantiles, transformed_quantiles, alpha=0.6, s=20)

    # Add diagonal line (perfect alignment)
    min_val = min(target_quantiles.min(), transformed_quantiles.min())
    max_val = max(target_quantiles.max(), transformed_quantiles.max())
    ax.plot(
        [min_val, max_val],
        [min_val, max_val],
        "r--",
        linewidth=2,
        label="Perfect alignment",
    )

    ax.set_xlabel("Target (INTERVAL 36) Quantiles")
    ax.set_ylabel("Transformed Quantiles")
    ax.set_title(latex_escape(col))
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(
    "../outputs/xnsample_transformation_qqplots.png", dpi=300, bbox_inches="tight"
)
plt.show()

print("Q-Q plots saved to ../outputs/xnsample_transformation_qqplots.png")

## 9. Save and Load Transformer

You can save the fitted transformer and reload it later for transforming new data.
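
For readers unfamiliar with this kind of persistence, the sketch below round-trips a dictionary of fitted parameters through `pickle`. It is only an illustration under stated assumptions: the on-disk format used by `XNSampleTransformer.save` is not shown in this notebook, and the parameter names and values here are invented.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical fitted parameters for one column; not the transformer's real attributes.
params = {
    "HGB(g/dL)": {
        "source_median": 13.4,
        "source_mad": 0.9,
        "target_median": 13.6,
        "target_mad": 0.8,
    }
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "alignment_params.pkl"
    with open(path, "wb") as fh:
        pickle.dump(params, fh)  # write parameters to disk
    with open(path, "rb") as fh:
        restored = pickle.load(fh)  # read them back

print(restored == params)
```

Pickle files can execute arbitrary code when loaded, so only load files from a trusted source.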

In [None]:
# Create output directory
output_dir = "../outputs"
os.makedirs(output_dir, exist_ok=True)

# Save transformer
transformer_path = f"{output_dir}/xnsample_transformer_strides_to_interval36.pkl"
transformer.save(transformer_path)

# Load transformer
print("\nTesting load functionality...")
loaded_transformer = XNSampleTransformer.load(transformer_path)

# Verify loaded transformer works
test_df = source_df.head(100).copy()
test_transformed = loaded_transformer.transform(test_df)
print(f"\nTest transformation successful! Transformed {len(test_transformed)} samples.")

## 10. Compare Correlations

Let's check if the transformation preserves biological correlations between parameters.
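
One thing to expect here: if the transformation is exactly the per-column shift and positive rescale given in the Method section (with no clipping or other post-processing, which is an assumption), Pearson correlations between columns are unchanged by it. The transformed matrix should then match the source matrix rather than move toward the target. A small synthetic check, with made-up column names and coefficients:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(3)
hgb = rng.normal(14.0, 1.2, size=1_000)
hct = 3.0 * hgb + rng.normal(0.0, 1.0, size=1_000)  # synthetic correlated pair
df = pd.DataFrame({"HGB": hgb, "HCT": hct})

# Per-column shift and positive scale (arbitrary illustrative values)
shifted = pd.DataFrame(
    {
        "HGB": 0.3 + 0.9 * df["HGB"],
        "HCT": -1.0 + 1.1 * df["HCT"],
    }
)

print(df.corr().round(4))
print(shifted.corr().round(4))
```

Any difference between the source and transformed heatmaps would therefore point to something other than the affine mapping, such as rows dropped for missing values.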

In [None]:
# Select a subset of columns for correlation analysis
corr_columns = [
    "HGB(g/dL)",
    "RBC(10^6/uL)",
    "HCT(%)",
    "MCV(fL)",
    "WBC(10^3/uL)",
    "PLT(10^3/uL)",
]
corr_columns = [col for col in corr_columns if col in columns_to_transform]

# Compute correlations for baseline samples
source_corr_data = (
    source_df.loc[source_df["Sample No."].isin(source_sample_nos), corr_columns]
    .apply(pd.to_numeric, errors="coerce")
    .dropna()
)

transformed_corr_data = (
    transformed_df.loc[
        transformed_df["Sample No."].isin(source_sample_nos), corr_columns
    ]
    .apply(pd.to_numeric, errors="coerce")
    .dropna()
)

target_corr_data = (
    target_df.loc[target_df["Sample No."].isin(target_sample_nos), corr_columns]
    .apply(pd.to_numeric, errors="coerce")
    .dropna()
)

# Compute correlation matrices
source_corr = source_corr_data.corr()
transformed_corr = transformed_corr_data.corr()
target_corr = target_corr_data.corr()

# Apply LaTeX escaping to column and index names before plotting
source_corr_display = source_corr.copy()
source_corr_display.columns = [latex_escape(col) for col in source_corr_display.columns]
source_corr_display.index = [latex_escape(idx) for idx in source_corr_display.index]

transformed_corr_display = transformed_corr.copy()
transformed_corr_display.columns = [
    latex_escape(col) for col in transformed_corr_display.columns
]
transformed_corr_display.index = [
    latex_escape(idx) for idx in transformed_corr_display.index
]

target_corr_display = target_corr.copy()
target_corr_display.columns = [latex_escape(col) for col in target_corr_display.columns]
target_corr_display.index = [latex_escape(idx) for idx in target_corr_display.index]

# Plot correlation matrices
fig, axes = plt.subplots(1, 3, figsize=(9.9, 3.3))

sns.heatmap(
    source_corr_display,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
    ax=axes[0],
    cbar_kws={"label": "Correlation"},
)
axes[0].set_title("Source (STRIDES) Correlations")

sns.heatmap(
    transformed_corr_display,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
    ax=axes[1],
    cbar_kws={"label": "Correlation"},
)
axes[1].set_title("Transformed Correlations")

sns.heatmap(
    target_corr_display,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
    ax=axes[2],
    cbar_kws={"label": "Correlation"},
)
axes[2].set_title("Target (INTERVAL 36) Correlations")

plt.tight_layout()
plt.savefig(
    "../outputs/xnsample_transformation_correlations.png", dpi=300, bbox_inches="tight"
)
plt.show()

# Compute correlation difference
print("\nCorrelation Structure Preservation:")
print(
    f"Mean absolute difference (Source vs Target): {np.abs(source_corr.values - target_corr.values).mean():.4f}"
)
print(
    f"Mean absolute difference (Transformed vs Target): {np.abs(transformed_corr.values - target_corr.values).mean():.4f}"
)