# Impedance Data Alignment with ImpedanceTransformer

This notebook demonstrates how to align Sysmex impedance data (RBC and PLT histograms) between different analysers using the `ImpedanceTransformer` API.

## Overview

The `ImpedanceTransformer` uses Gaussian Mixture Models (GMM) combined with Optimal Transport (OT) in 1D to align impedance histogram distributions across different analysers. This is useful when:

- Combining data from multiple analyser machines
- Correcting batch effects between analyser units
- Training machine learning models on data from one analyser and applying to another

## Method

1. **Load OutputData.csv**: Impedance data in histogram format (RBC and PLT channels)
2. **Fit GMMs**: Gaussian mixture models are fit to both source and target distributions
3. **Compute transport**: An optimal transport plan is computed between the GMM components
4. **Transform**: Individual histogram bins are transformed using the transport map
5. **Validate**: Compare transformed distribution to target
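In one dimension, the classical optimal transport map between two continuous distributions is the monotone quantile map `T(x) = F_target^{-1}(F_source(x))`, where `F` denotes a CDF. The sketch below applies that construction to two 1D GMMs, each passed as a hypothetical `(weights, means, sigmas)` tuple of arrays. It only illustrates the idea. The Method above describes the plan as computed between GMM components, so `ImpedanceTransformer` may build its map differently.

```python
import math

import numpy as np

# np.vectorize lets the scalar math.erf act elementwise on arrays
_erf = np.vectorize(math.erf)


def gmm_cdf(x, weights, means, sigmas):
    """CDF of a 1D Gaussian mixture, evaluated elementwise at x."""
    x = np.asarray(x, dtype=float)[..., None]  # trailing axis broadcasts over components
    z = (x - means) / (sigmas * math.sqrt(2.0))
    return np.sum(weights * 0.5 * (1.0 + _erf(z)), axis=-1)


def ot_map_1d(x, source, target, n_grid=4001, n_sigma=6.0):
    """Monotone 1D transport map T(x) = F_target^{-1}(F_source(x)).

    source and target are (weights, means, sigmas) tuples. The target quantile
    function is approximated by inverting its CDF on a fine grid.
    """
    _, mu_t, sd_t = target
    grid = np.linspace(np.min(mu_t - n_sigma * sd_t), np.max(mu_t + n_sigma * sd_t), n_grid)
    u = gmm_cdf(x, *source)  # quantile level of each point under the source GMM
    return np.interp(u, gmm_cdf(grid, *target), grid)
```

For example, `ot_map_1d(np.arange(128.0), source, target)` would assign a target-space position to every bin index of a channel.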

## Data Format

Impedance data (`OutputData.csv`) contains:
- **Sample No.**: Sample identifier
- **RBC_RAW_000 to RBC_RAW_127**: Red blood cell impedance histogram (128 bins, 0-250 fL range, ~1.95 fL bin width)
- **PLT_RAW_000 to PLT_RAW_127**: Platelet impedance histogram (128 bins, 0-40 fL range, ~0.31 fL bin width)
- Other columns: Metadata and derived parameters

**Important**: Column names are always indexed 000-127, but they represent different femtoliter (fL) ranges:
- **RBC**: 0-250 fL (full particle size range for red blood cells)
- **PLT**: 0-40 fL (zoomed view for smaller platelets - higher resolution in small particle range)
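Because both channels share the same 000-127 column indices, a column name only maps to a volume once the channel's window is known. A minimal sketch, assuming the 128 bins are evenly spaced from 0 fL to the upper edge of each window (consistent with the ~1.95 and ~0.31 fL widths above); the helper names are hypothetical and not part of `sysmexcbctools`.

```python
import numpy as np

N_BINS = 128
# Upper edge (fL) of each channel's histogram window, as listed above
CHANNEL_RANGE_FL = {"RBC": 250.0, "PLT": 40.0}


def bin_edges_fl(channel, n_bins=N_BINS):
    """n_bins + 1 evenly spaced bin edges (fL) for a channel, starting at 0."""
    return np.linspace(0.0, CHANNEL_RANGE_FL[channel], n_bins + 1)


def bin_centres_fl(channel, n_bins=N_BINS):
    """Midpoint of each bin in fL."""
    edges = bin_edges_fl(channel, n_bins)
    return 0.5 * (edges[:-1] + edges[1:])


def column_to_fl(column, n_bins=N_BINS):
    """Map a column name such as 'PLT_RAW_042' to its (lower, upper) fL interval."""
    channel, _, index = column.split("_")
    width = CHANNEL_RANGE_FL[channel] / n_bins
    i = int(index)
    return i * width, (i + 1) * width
```

The same index therefore covers about 6x more volume in the RBC channel than in the PLT channel, which is why the plots later in this notebook use separate fL axes.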


Once again, we are using our own Sysmex raw data from the INTERVAL and STRIDES studies. You will need to replace the data-loading sections of this notebook with paths to your own data.

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# PDF-compatible fonts
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style (importing scienceplots registers the "science" styles)
import scienceplots  # noqa: F401

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

# Add parent directory to path to import sysmexcbctools
# This allows imports to work whether or not the package is installed
repo_root = (
    Path(__file__).parent.parent.parent
    if "__file__" in globals()
    else Path.cwd().parent.parent
)
sys.path.insert(0, str(repo_root))

# Import from sysmexcbctools.transfer
from sysmexcbctools.transfer.sysmexalign import ImpedanceTransformer
from sysmexcbctools.transfer.config import load_config

print("✓ Imports successful")
print(f"  Repository root: {repo_root}")

## 1. Prepare Data Paths

Impedance data is stored in `OutputData.csv` files containing RBC and PLT histogram bins.

For this example, we'll use **real data from two different Sysmex analysers**:
- **Source**: STRIDES study data
- **Target**: INTERVAL study (analyser XN-10^11036)

We'll align both RBC and PLT impedance distributions.

In [None]:
# Load configuration to get data paths
config = load_config(str(repo_root / "sysmexcbctools/transfer/config/data_paths.yaml"))

# Get dataset paths for two different analysers
# Source: STRIDES impedance data (directory containing multiple OutputData.csv files)
# Target: INTERVAL analyser XN-10^11036 impedance data (single OutputData.csv file)
SOURCE_DIR = config["datasets"]["raw"]["strides"]
TARGET_FILE = config["datasets"]["raw"]["interval_36"] + "/OutputData.csv"

# Also load sample numbers for filtering (optional but recommended)
source_samples_file = config["files"]["centile_samples"]["strides"]
target_samples_file = config["files"]["centile_samples"]["interval_baseline_36"]

# Create output directory
OUTPUT_DIR = "../outputs/impedance_transformed/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Data Configuration Loaded:")
print(f"Source directory: {SOURCE_DIR}")
print(f"Target file: {TARGET_FILE}")
print(f"Source samples file: {source_samples_file}")
print(f"Target samples file: {target_samples_file}")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Load impedance data
print("Loading impedance data...")

# Source: Find all OutputData.csv files in subdirectories and concatenate them
print("\n1. Loading source data (STRIDES)...")
source_files = sorted(Path(SOURCE_DIR).rglob("OutputData.csv"))
print(f"   Found {len(source_files)} OutputData.csv files in source directory")

if len(source_files) == 0:
    raise FileNotFoundError(f"No OutputData.csv files found in {SOURCE_DIR}")

# Load and concatenate all source files
source_dfs = []
for i, file in enumerate(source_files, 1):
    print(f"   Loading file {i}/{len(source_files)}: {file.parent.name}/OutputData.csv")
    df = pd.read_csv(file)
    source_dfs.append(df)

source_df = pd.concat(source_dfs, ignore_index=True)
print(f"   ✓ Concatenated {len(source_files)} files into single DataFrame")
print(f"   Total source samples: {len(source_df):,}")

# Target: Load single file
print("\n2. Loading target data (INTERVAL)...")
target_df = pd.read_csv(TARGET_FILE)
print(f"   Total target samples: {len(target_df):,}")

# Load sample numbers for filtering (recommended for large datasets)
source_sample_nos = np.load(source_samples_file, allow_pickle=True)
target_sample_nos = np.load(target_samples_file, allow_pickle=True)

print(f"\n3. Sample filtering:")
print(f"  Source samples: {len(source_sample_nos):,} centile samples")
print(f"  Target samples: {len(target_sample_nos):,} centile samples")

# For initial testing, we'll use centile samples
USE_CENTILE_SAMPLES = True  # Set to False to use all samples

if USE_CENTILE_SAMPLES:
    print("\nUsing centile samples for efficient computation...")
    source_df_filtered = source_df[
        source_df["Sample No."].isin(source_sample_nos)
    ].copy()
    target_df_filtered = target_df[
        target_df["Sample No."].isin(target_sample_nos)
    ].copy()

    print(f"  Filtered source: {len(source_df_filtered):,} samples")
    print(f"  Filtered target: {len(target_df_filtered):,} samples")
else:
    source_df_filtered = source_df.copy()
    target_df_filtered = target_df.copy()
    print("\nUsing all available samples...")

# Display sample of data structure
print("\nSample of source data (first 5 columns):")
print(source_df_filtered.iloc[:3, :5])

## 2. Initialize the Transformer

Create an `ImpedanceTransformer` with appropriate parameters.

**Key Parameter: `gmm_sample_size`**

This controls the number of point samples used for GMM fitting. When the total histogram counts exceed this value, the data is downsampled using **probabilistic rounding** to preserve distribution shape.
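One way probabilistic rounding can work is sketched below, assuming counts are first rescaled to sum to `gmm_sample_size` and each bin's fractional part is then rounded up with probability equal to that fraction, so every bin keeps its expected share. `downsample_histogram` is a hypothetical stand-in. The library's own `sample_impedance_array` may differ in detail.

```python
import numpy as np


def downsample_histogram(counts, sample_size, rng=None):
    """Turn a histogram into at most ~sample_size points at the bin indices.

    Counts are scaled so they sum to sample_size (if they exceed it), then each
    bin's fractional part is rounded up with probability equal to that fraction.
    """
    rng = np.random.default_rng(rng)
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    scaled = counts * sample_size / total if total > sample_size else counts
    base = np.floor(scaled)
    frac = scaled - base
    # Bernoulli(frac) extra point per bin keeps E[points in bin] == scaled count
    n_per_bin = (base + (rng.random(frac.shape) < frac)).astype(int)
    return np.repeat(np.arange(counts.size), n_per_bin).astype(float)
```

Plain rounding would zero out every bin whose scaled count falls below 0.5 and so erase sparse tails; probabilistic rounding keeps each bin's expected count unbiased.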

The transformer will:
1. Fit separate GMMs for RBC and PLT histogram distributions
2. Compute optimal transport between source and target for each channel
3. Apply transformations to align histograms

In [None]:
transformer = ImpedanceTransformer(
    gmm_sample_size=50000,  # Increased from 1000 to preserve distribution shape
    n_jobs=-1,
)

print("✓ ImpedanceTransformer initialized")
print(f"  GMM sample size: {transformer.gmm_sample_size:,}")
print(f"  Parallel jobs: {transformer.n_jobs}")

## 3. Fit the Transformation

Fit GMM models to both source and target distributions for RBC and PLT channels, then compute optimal transport maps.

**This step performs the expensive GMM fitting and transport map computation:**
1. Fits 4 GMMs (RBC source, RBC target, PLT source, PLT target)
2. Computes 2 optimal transport maps (RBC, PLT)
3. Stores these for reuse during transformation

**Note**: This can take several minutes. Once fitted, you can save the transformer and reuse it for multiple datasets without refitting.

In [None]:
# Check if we have data to work with
if len(source_df_filtered) == 0 or len(target_df_filtered) == 0:
    print("⚠ ERROR: No samples found after filtering!")
    print(f"  Source samples: {len(source_df_filtered)}")
    print(f"  Target samples: {len(target_df_filtered)}")
    print("\nPossible causes:")
    print("  - Data not available on this system (requires HPC/RDS access)")
    print("  - Incorrect paths in config/data_paths.yaml")
    print("  - Sample number filtering removed all files")
    print("\nTo proceed:")
    print("  1. Verify you have access to the RDS data")
    print("  2. Check paths in config/data_paths.yaml")
    print("  3. Try setting USE_CENTILE_SAMPLES = False in cell above")
else:
    print(
        f"✓ Found {len(source_df_filtered):,} source samples and {len(target_df_filtered):,} target samples"
    )
    print("\nFitting transformation parameters...")

    # Pass the full dataframes: when sample numbers are supplied, fit() restricts
    # fitting to those samples (the same selection as the *_filtered dataframes)
    if USE_CENTILE_SAMPLES:
        transformer.fit(
            source_df=source_df,
            target_df=target_df,
            source_sample_nos=source_sample_nos,
            target_sample_nos=target_sample_nos,
        )
    else:
        transformer.fit(
            source_df=source_df,
            target_df=target_df,
        )

    print("\n✓ Transformer fitted successfully!")
    # print(f"  Source standards: {len(transformer.source_standards_):,} samples")
    # print(f"  Target standards: {len(transformer.target_standards_):,} samples")

## 3a. Visualise Fitted GMMs

Inspect the quality of the fitted GMMs by comparing them to the data used for fitting.

This helps debug potential issues:
- Are the GMM components well-positioned?
- Does the GMM capture the distribution shape?
- Are there outliers or unusual patterns?

We'll show 4 plots:
- **RBC Source GMM** vs source histogram data
- **RBC Target GMM** vs target histogram data  
- **PLT Source GMM** vs source histogram data (log-transformed)
- **PLT Target GMM** vs target histogram data (log-transformed)
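The visual checks can be complemented by one number per fit. One option, our suggestion rather than something the notebook or library computes, is the total variation distance between the binned sampled data and the probability mass the GMM assigns to the same bins (0 means identical, 1 means disjoint). The function names are hypothetical.

```python
import math

import numpy as np

_erf = np.vectorize(math.erf)  # elementwise erf without a SciPy dependency


def gmm_bin_mass(edges, weights, means, sigmas):
    """Probability mass a 1D Gaussian mixture places in each bin."""
    z = (np.asarray(edges, dtype=float)[:, None] - means) / (sigmas * math.sqrt(2.0))
    cdf = np.sum(weights * 0.5 * (1.0 + _erf(z)), axis=1)
    return np.diff(cdf)


def gmm_total_variation(data, edges, weights, means, sigmas):
    """Total variation distance between binned data and the binned GMM."""
    empirical, _ = np.histogram(data, bins=edges)
    empirical = empirical / empirical.sum()
    model = gmm_bin_mass(edges, weights, means, sigmas)
    model = model / model.sum()  # renormalise to the mass inside the window
    return 0.5 * np.abs(empirical - model).sum()
```

With a fitted scikit-learn `GaussianMixture` such as `transformer.rbc_source_gmm_`, the parameters would be `gmm.weights_`, `gmm.means_[:, 0]` and `np.sqrt(gmm.covariances_[:, 0, 0])` (the same 'full'-covariance indexing the plotting cell uses). If the sampled points sit on integer bin indices, `np.arange(-0.5, 128.5)` gives edges centred on them.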

### Known Limitations

**PLT Distribution Shape:**
The PLT impedance distribution is Gaussian-like on the left but cut off sharply on the right, a consequence of the 0-40 fL measurement range. Gaussian mixture components cannot reproduce this asymmetric, truncated shape exactly.

**Observed effects:**
- GMM may underfit or overfit certain regions
- Transformation quality for PLT may be lower than for RBC
- Some artificial spikes may still appear in transformed PLT histograms

**Future work:**
- Consider truncated Gaussian mixture models for PLT
- Investigate alternative distributions (e.g., log-normal, gamma) for PLT components
- Explore histogram matching methods that don't assume parametric forms

For now, this is an acceptable limitation given the overall transformation quality improvements achieved with increased `gmm_sample_size` and probabilistic rounding.

In [None]:
if not transformer.is_fitted_:
    print("⚠ Transformer not fitted yet. Skipping GMM visualization.")
else:
    from types import SimpleNamespace

    from scipy.stats import norm
    from sysmexcbctools.transfer.sysmexalign.alignment_1d import sample_impedance_array

    print("Visualizing fitted GMMs...")

    # Recreate the sampled data used for fitting; sample_impedance_array reads
    # gmm_sample_size from an args-like object
    args = SimpleNamespace(gmm_sample_size=transformer.gmm_sample_size)

    # Get standard samples (same as in fit())
    source_df_copy = source_df.copy()
    target_df_copy = target_df.copy()
    source_df_copy["IsStandard"] = (
        source_df_copy["Sample No."].isin(source_sample_nos).astype(int)
    )
    target_df_copy["IsStandard"] = (
        target_df_copy["Sample No."].isin(target_sample_nos).astype(int)
    )

    source_standards = source_df_copy[source_df_copy["IsStandard"] == 1]
    target_standards = target_df_copy[target_df_copy["IsStandard"] == 1]

    # RBC data
    source_rbc_data = source_standards.filter(like="RBC_RAW_").sum(axis=0)
    target_rbc_data = target_standards.filter(like="RBC_RAW_").sum(axis=0)
    X_source_rbc = sample_impedance_array(args, source_rbc_data)
    X_target_rbc = sample_impedance_array(args, target_rbc_data)

    # PLT data (log-transformed)
    source_plt_data = source_standards.filter(like="PLT_RAW_").sum(axis=0)
    target_plt_data = target_standards.filter(like="PLT_RAW_").sum(axis=0)
    X_source_plt_log = np.log(sample_impedance_array(args, source_plt_data) + 1)
    X_target_plt_log = np.log(sample_impedance_array(args, target_plt_data) + 1)

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(4.5, 4.5))

    # Helper function to compute GMM PDF
    def plot_gmm_with_data(ax, data, gmm, title, xlabel):
        # Plot histogram of data
        ax.hist(data, bins=100, density=True, alpha=0.5, color="blue", label="Data")

        # Compute GMM PDF
        x_range = np.linspace(data.min(), data.max(), 1000).reshape(-1, 1)
        log_prob = gmm.score_samples(x_range)
        pdf = np.exp(log_prob)

        # Plot GMM PDF
        ax.plot(
            x_range,
            pdf,
            "r-",
            linewidth=2,
            label=f"GMM ({gmm.n_components} components)",
        )

        # Plot individual Gaussian components
        for i in range(gmm.n_components):
            mean = gmm.means_[i, 0]
            var = gmm.covariances_[i, 0, 0]
            weight = gmm.weights_[i]

            component_pdf = weight * norm.pdf(x_range, mean, np.sqrt(var))
            ax.plot(
                x_range,
                component_pdf,
                "--",
                alpha=0.5,
                label=f"Component {i+1} ($\mu$={mean:.1f}, $\sigma$={np.sqrt(var):.1f})",
            )

        ax.set_xlabel(xlabel)
        ax.set_ylabel("Density")
        ax.set_title(title)
        ax.legend(fontsize=5)
        ax.grid(alpha=0.3)

    # Plot 1: RBC Source
    plot_gmm_with_data(
        axes[0, 0],
        X_source_rbc,
        transformer.rbc_source_gmm_,
        "RBC Source GMM Fit",
        "Bin Index (0-127)",
    )

    # Plot 2: RBC Target
    plot_gmm_with_data(
        axes[0, 1],
        X_target_rbc,
        transformer.rbc_target_gmm_,
        "RBC Target GMM Fit",
        "Bin Index (0-127)",
    )

    # Plot 3: PLT Source (log-transformed)
    plot_gmm_with_data(
        axes[1, 0],
        X_source_plt_log,
        transformer.plt_source_gmm_,
        "PLT Source GMM Fit (Log-Transformed)",
        "log(Bin Index + 1)",
    )

    # Plot 4: PLT Target (log-transformed)
    plot_gmm_with_data(
        axes[1, 1],
        X_target_plt_log,
        transformer.plt_target_gmm_,
        "PLT Target GMM Fit (Log-Transformed)",
        "log(Bin Index + 1)",
    )

    plt.tight_layout()
    plt.savefig("../outputs/impedance_gmm_fits.png", dpi=150, bbox_inches="tight")
    plt.show()

    print("\n✓ GMM fit visualizations saved to ../outputs/impedance_gmm_fits.png")
    print("\nInterpretation:")
    print("  - Blue histogram: Actual sampled data used for GMM fitting")
    print("  - Red solid line: Overall GMM probability density function")
    print("  - Dashed lines: Individual Gaussian components")
    print("  - Good fit = red line closely follows blue histogram")
    print(f"  - RBC GMMs: {transformer.rbc_source_gmm_.n_components} components")
    print(f"  - PLT GMMs: {transformer.plt_source_gmm_.n_components} components")

    # Print component details
    print("\n" + "=" * 60)
    print("RBC SOURCE GMM COMPONENTS")
    print("=" * 60)
    for i in range(transformer.rbc_source_gmm_.n_components):
        mean = transformer.rbc_source_gmm_.means_[i, 0]
        var = transformer.rbc_source_gmm_.covariances_[i, 0, 0]
        weight = transformer.rbc_source_gmm_.weights_[i]
        print(
            f"  Component {i+1}: μ={mean:6.2f}, σ={np.sqrt(var):6.2f}, weight={weight:.3f}"
        )

    print("\n" + "=" * 60)
    print("RBC TARGET GMM COMPONENTS")
    print("=" * 60)
    for i in range(transformer.rbc_target_gmm_.n_components):
        mean = transformer.rbc_target_gmm_.means_[i, 0]
        var = transformer.rbc_target_gmm_.covariances_[i, 0, 0]
        weight = transformer.rbc_target_gmm_.weights_[i]
        print(
            f"  Component {i+1}: μ={mean:6.2f}, σ={np.sqrt(var):6.2f}, weight={weight:.3f}"
        )

## 4. Save and Load Transformer

Once fitted, you can save the transformer and reuse it later without re-fitting the GMMs.

In [None]:
# Save the fitted transformer for later use
# This allows you to skip the time-consuming fitting step in future sessions

save_path = "../outputs/impedance_transformer.pkl"
transformer.save(save_path)
print(f"✓ You can now load this transformer in future sessions to skip re-fitting")

In [None]:
# Load a previously saved transformer
# Uncomment and run this cell if you want to load a saved transformer instead of fitting

# from sysmexcbctools.transfer.sysmexalign import ImpedanceTransformer
# save_path = "../outputs/impedance_transformer.pkl"
# transformer = ImpedanceTransformer.load(save_path)
# print(f"✓ Transformer loaded from: {save_path}")
# print(f"  GMM sample size: {transformer.gmm_sample_size:,}")
# print(f"  Fitted: {transformer.is_fitted_}")

## 5. Transform Source Data

Apply the learned transformation to source data. This will:
1. Use the pre-computed GMMs and transport maps from fit()
2. Transform RBC histograms by converting bins→points→transform→re-bin
3. Transform PLT histograms (with log/exp transforms for better Gaussianity)
4. Return a transformed DataFrame with aligned distributions

**Note**: Unlike fit(), transform() is fast because it reuses the pre-computed GMMs and transport maps.
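The bins→points→transform→re-bin pipeline can be sketched for a single histogram as follows. `transform_histogram` is hypothetical, not the library's implementation: it places each counted event at its bin index, applies any vectorised 1D `transport_map` (optionally in `log(x + 1)` space, as described for PLT), and snaps the results back to the nearest of 128 bins.

```python
import numpy as np


def transform_histogram(counts, transport_map, n_bins=128, log_space=False):
    """Push one histogram through a 1D transport map: bins -> points -> map -> re-bin.

    transport_map is any vectorised callable. With log_space=True it is applied
    to log(x + 1) and the result is mapped back with exp(y) - 1.
    """
    counts = np.asarray(counts).astype(int)
    # One point per counted event, placed at its bin index
    points = np.repeat(np.arange(counts.size, dtype=float), counts)
    if log_space:
        mapped = np.expm1(transport_map(np.log1p(points)))
    else:
        mapped = transport_map(points)
    # Snap to the nearest bin; clipping keeps the total count but piles
    # out-of-range mass into the edge bins
    idx = np.clip(np.rint(mapped), 0, n_bins - 1).astype(int)
    return np.bincount(idx, minlength=n_bins)
```

Clipping versus dropping out-of-range points is a design choice: clipping preserves each sample's total count, which matters for derived counts, but can create spikes in the first and last bins.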

In [None]:
# Transform the source data using the fitted GMMs and transport maps
# This reuses the pre-computed GMMs - no refitting happens here!

if not transformer.is_fitted_:
    print("⚠ ERROR: Transformer not fitted yet. Please run the fit() cell above first.")
else:
    print("Transforming source data...")
    print(f"  Input: {len(source_df_filtered):,} samples")

    # Transform creates aligned histograms by:
    # 1. Converting histogram bins to points (based on counts)
    # 2. Transforming points using pre-computed GMMs and transport maps
    # 3. Re-binning transformed points into histograms
    transformed_df = transformer.transform(source_df_filtered.copy())

    print(f"✓ Transformation complete!")
    print(f"  Output: {len(transformed_df):,} samples")
    print(f"  Columns transformed: RBC_RAW_* and PLT_RAW_* histograms")

    # Display sample of transformed data
    print("\nSample of transformed data (first 5 columns):")
    print(transformed_df.iloc[:3, :5])

## 6. Validate Transformation Quality

Visualize the alignment quality by comparing source, target, and transformed distributions.

We'll create 4-panel plots showing:
- **Top row**: RBC impedance histograms (before and after transformation)
- **Bottom row**: PLT impedance histograms (before and after transformation)
- **Left column**: Original source (blue) vs target (orange) - shows initial misalignment
- **Right column**: Transformed (green) vs target (orange) - should overlap if transformation worked well
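Overlaid curves are easy to read but hard to compare across runs. A possible addition, not part of the original workflow, is to report the 1D Wasserstein-1 distance between the normalised average histograms: for bins of equal width `h` it equals `h` times the sum of absolute differences between the two cumulative distributions. `histogram_w1` is a hypothetical helper.

```python
import numpy as np


def histogram_w1(p, q, bin_width):
    """Wasserstein-1 distance between two histograms on the same equal-width grid.

    Both inputs are normalised to unit mass first, so only shape is compared.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p = p / p.sum()
    q = q / q.sum()
    # In 1D, W1 is the area between the two CDFs
    return bin_width * np.abs(np.cumsum(p) - np.cumsum(q)).sum()
```

After the plotting cell below, comparing `histogram_w1(source_rbc_avg, target_rbc_avg, 250 / 128)` with `histogram_w1(transformed_rbc_avg, target_rbc_avg, 250 / 128)` (and the PLT equivalents with `40 / 128`) gives a distance in fL before and after alignment. `scipy.stats.wasserstein_distance`, with bin centres as values and the histograms as weights, computes the same quantity.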

In [None]:
if not transformer.is_fitted_ or "transformed_df" not in locals():
    print(
        "⚠ Transformer not fitted or transformation not performed. Skipping visualization."
    )
else:
    print("Creating distribution comparison plots...")

    # Extract RBC and PLT histogram columns
    rbc_cols = [col for col in source_df_filtered.columns if col.startswith("RBC_RAW")]
    plt_cols = [col for col in source_df_filtered.columns if col.startswith("PLT_RAW")]

    print(f"  Found {len(rbc_cols)} RBC histogram bins")
    print(f"  Found {len(plt_cols)} PLT histogram bins")

    # Bin centres in fL: columns are indices 0-127 but cover different fL windows
    # RBC: 128 bins spanning 0-250 fL (bin width 250/128 ~ 1.95 fL)
    # PLT: 128 bins spanning 0-40 fL (bin width 40/128 ~ 0.31 fL), a zoomed view for smaller particles
    rbc_edges = np.linspace(0, 250, len(rbc_cols) + 1)
    plt_edges = np.linspace(0, 40, len(plt_cols) + 1)
    rbc_bins = 0.5 * (rbc_edges[:-1] + rbc_edges[1:])
    plt_bins = 0.5 * (plt_edges[:-1] + plt_edges[1:])

    # Compute average histograms
    source_rbc_avg = source_df_filtered[rbc_cols].mean(axis=0).values
    target_rbc_avg = target_df_filtered[rbc_cols].mean(axis=0).values
    transformed_rbc_avg = transformed_df[rbc_cols].mean(axis=0).values

    source_plt_avg = source_df_filtered[plt_cols].mean(axis=0).values
    target_plt_avg = target_df_filtered[plt_cols].mean(axis=0).values
    transformed_plt_avg = transformed_df[plt_cols].mean(axis=0).values

    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(4.5, 4.5))

    # ========== RBC Histograms ==========
    # Panel 1: RBC - Before transformation
    ax = axes[0, 0]
    ax.plot(
        rbc_bins,
        source_rbc_avg,
        label="Source (STRIDES)",
        color="blue",
        linewidth=2,
        alpha=0.7,
    )
    ax.plot(
        rbc_bins,
        target_rbc_avg,
        label="Target (INTERVAL)",
        color="orange",
        linewidth=2,
        alpha=0.7,
    )
    ax.set_xlabel("Volume (fL)")
    ax.set_ylabel("Average Count")
    ax.set_title("RBC Impedance - Before Transformation")
    ax.legend()
    ax.grid(alpha=0.3)

    # Panel 2: RBC - After transformation
    ax = axes[0, 1]
    ax.plot(
        rbc_bins,
        transformed_rbc_avg,
        label="Transformed",
        color="green",
        linewidth=2,
        alpha=0.7,
    )
    ax.plot(
        rbc_bins,
        target_rbc_avg,
        label="Target (INTERVAL)",
        color="orange",
        linewidth=2,
        alpha=0.7,
    )
    ax.set_xlabel("Volume (fL)")
    ax.set_ylabel("Average Count")
    ax.set_title("RBC Impedance - After Transformation")
    ax.legend()
    ax.grid(alpha=0.3)

    # ========== PLT Histograms ==========
    # Panel 3: PLT - Before transformation
    ax = axes[1, 0]
    ax.plot(
        plt_bins,
        source_plt_avg,
        label="Source (STRIDES)",
        color="blue",
        linewidth=2,
        alpha=0.7,
    )
    ax.plot(
        plt_bins,
        target_plt_avg,
        label="Target (INTERVAL)",
        color="orange",
        linewidth=2,
        alpha=0.7,
    )
    ax.set_xlabel("Volume (fL)")
    ax.set_ylabel("Average Count")
    ax.set_title("PLT Impedance - Before Transformation")
    ax.legend()
    ax.grid(alpha=0.3)

    # Panel 4: PLT - After transformation
    ax = axes[1, 1]
    ax.plot(
        plt_bins,
        transformed_plt_avg,
        label="Transformed",
        color="green",
        linewidth=2,
        alpha=0.7,
    )
    ax.plot(
        plt_bins,
        target_plt_avg,
        label="Target (INTERVAL)",
        color="orange",
        linewidth=2,
        alpha=0.7,
    )
    ax.set_xlabel("Volume (fL)")
    ax.set_ylabel("Average Count")
    ax.set_title("PLT Impedance - After Transformation")
    ax.legend()
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig("../outputs/impedance_comparison.png", dpi=150, bbox_inches="tight")
    plt.show()

    print("\n✓ Distribution plots saved to ../outputs/impedance_comparison.png")
    print("\nInterpretation:")
    print("  - Top row: RBC impedance histograms before and after transformation")
    print("  - Bottom row: PLT impedance histograms before and after transformation")
    print("  - Left column: Original source (blue) vs target (orange)")
    print("  - Right column: Transformed (green) should match target (orange)")
    print("  - Success = green and orange curves overlap well")