# Flow Cytometry Alignment with FlowTransformer

This notebook demonstrates how to align Sysmex flow cytometry data (RET, WDF, WNR, PLTF channels) between different analysers using the `FlowTransformer` API.

## ⚠️ Important: Memory Requirements

**This notebook requires significant computational resources:**
- Flow cytometry data involves millions of data points
- GMM fitting with many components on large datasets is memory-intensive
- **Recommended**: Run on a system with sufficient memory
- **Laptop users**: May experience kernel crashes or slow performance

**For testing purposes**, consider:
- Reducing `n_components`
- Setting `max_samples=100000` to limit data size
- Using fewer input files for initial testing
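
To see why these settings matter, the sketch below estimates the size of the dense `n_samples × n_components` matrix of responsibilities that EM-based GMM fitting commonly materialises. It assumes float64 storage and ignores every other allocation, so treat the figures as a rough lower bound rather than a measurement of `FlowTransformer` itself. The helper `responsibility_matrix_mb` is hypothetical and not part of the package.

```python
def responsibility_matrix_mb(n_samples, n_components, bytes_per_value=8):
    """Approximate size in MB (10**6 bytes) of an n_samples x n_components array."""
    return n_samples * n_components * bytes_per_value / 1e6


# Compare the notebook default (1,000,000 samples) with the reduced test setting.
for n_samples in (1_000_000, 100_000):
    size_mb = responsibility_matrix_mb(n_samples, n_components=64)
    print(f"{n_samples:>9,} samples x 64 components: ~{size_mb:,.1f} MB")
```

Under these assumptions, cutting `max_samples` by a factor of ten cuts this particular array by the same factor, and halving `n_components` halves it.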

## Overview

The `FlowTransformer` uses Gaussian Mixture Models (GMM) combined with Optimal Transport (OT) to align flow cytometry distributions across different analysers. This is useful when:

- Combining data from multiple analyser machines
- Correcting batch effects between analyser units  
- Training machine learning models on data from one analyser and applying to another

## Method

1. **Fit GMMs**: Gaussian mixture models are fitted to both the source and target distributions
2. **Compute transport**: An optimal transport plan is computed between the GMM components
3. **Transform**: Individual data points are transformed using the transport map
4. **Validate**: The transformed distribution is compared to the target using Wasserstein distance
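
Step 2 can be illustrated with a minimal sketch that uses hand-specified mixture parameters instead of fitted GMMs. It computes the closed-form squared 2-Wasserstein distance between each pair of Gaussian components, then solves the small discrete optimal transport problem between the component weights with `scipy.optimize.linprog`. The function names are hypothetical, and the solver `FlowTransformer` uses internally may differ.

```python
import numpy as np
from scipy.linalg import sqrtm
from scipy.optimize import linprog


def gaussian_w2_squared(m0, S0, m1, S1):
    """Squared 2-Wasserstein distance between N(m0, S0) and N(m1, S1)."""
    S0_half = np.real(sqrtm(S0))
    cross = np.real(sqrtm(S0_half @ S1 @ S0_half))
    return float(np.sum((m0 - m1) ** 2) + np.trace(S0 + S1 - 2.0 * cross))


def component_transport_plan(w_src, w_tgt, cost):
    """Discrete OT between component weights; returns a (k0, k1) plan."""
    k0, k1 = cost.shape
    # The plan is flattened row-major, so entry (k, l) sits at index k * k1 + l.
    row_sums = np.kron(np.eye(k0), np.ones((1, k1)))  # each row sums to w_src[k]
    col_sums = np.kron(np.ones((1, k0)), np.eye(k1))  # each column sums to w_tgt[l]
    result = linprog(
        cost.ravel(),
        A_eq=np.vstack([row_sums, col_sums]),
        b_eq=np.concatenate([w_src, w_tgt]),
        bounds=(0, None),
        method="highs",
    )
    return result.x.reshape(k0, k1)


# Toy two-component mixtures: the target is the source shifted by (1, 1).
means_src = np.array([[0.0, 0.0], [10.0, 10.0]])
means_tgt = np.array([[1.0, 1.0], [11.0, 11.0]])
covs = np.array([np.eye(2), np.eye(2)])
weights = np.array([0.7, 0.3])

cost = np.array(
    [
        [gaussian_w2_squared(means_src[k], covs[k], means_tgt[l], covs[l]) for l in range(2)]
        for k in range(2)
    ]
)
plan = component_transport_plan(weights, weights, cost)
print(np.round(plan, 3))
```

In this toy case all mass stays on the diagonal of the plan, because each source component is matched to its nearest shifted counterpart.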


The math behind the method is from:
Julie Delon and Agnès Desolneux. A Wasserstein-Type Distance in the Space of Gaussian Mixture Models. SIAM Journal on Imaging Sciences, 13(2):936–970, January 2020. doi: 10.1137/19M1301047. URL https://epubs.siam.org/doi/abs/10.1137/19M1301047. Publisher: Society for Industrial and Applied Mathematics.


### ⚠️ Important: Data Availability

This notebook uses data from the INTERVAL and STRIDES blood donor studies, which are available to us in Cambridge. To run it, you will need to load your own Sysmex flow cytometry data.

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Import from sysmexcbctools.transfer
from sysmexcbctools.transfer.sysmexalign import FlowTransformer
from sysmexcbctools.transfer.config import load_config

import matplotlib

matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42


import scienceplots
import seaborn as sns

plt.style.use(["science", "nature"])

SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

print("✓ Imports successful")

## 1. Prepare Data Paths

Flow cytometry data is stored in `.116.csv` files whose names follow this pattern:
```
CHANNEL_[analyserID][...][YYYYMMDD_HHMMSS][SampleNumber].116.csv
```
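
As an illustration of this naming scheme, the sketch below parses a filename into its parts with a regular expression. It assumes exactly four bracketed fields, with the elided `[...]` field treated as a single opaque block, and it uses a made-up sample number. Both `FILENAME_PATTERN` and `parse_flow_filename` are hypothetical names, not part of `sysmexcbctools`.

```python
import re
from datetime import datetime

# Channel prefix, four bracketed fields, then the .116.csv suffix.
FILENAME_PATTERN = re.compile(
    r"^(?P<channel>[A-Z]+)_"
    r"\[(?P<analyser>[^\]]+)\]"
    r"\[(?P<other>[^\]]+)\]"  # elided field, kept opaque
    r"\[(?P<timestamp>\d{8}_\d{6})\]"
    r"\[(?P<sample>[^\]]+)\]"
    r"\.116\.csv$"
)


def parse_flow_filename(name):
    """Return a dict of filename fields, or None if the name does not match."""
    match = FILENAME_PATTERN.match(name)
    if match is None:
        return None
    return {
        "channel": match["channel"],
        "analyser": match["analyser"],
        "acquired": datetime.strptime(match["timestamp"], "%Y%m%d_%H%M%S"),
        "sample_number": match["sample"].strip(),  # the field is space-padded
    }


# Hypothetical filename; the sample number 123456 is invented for illustration.
example = "RET_[XN-20^14232][00-22_123][20211002_001515][      123456].116.csv"
info = parse_flow_filename(example)
print(info)
```

Stripping the sample-number field matters because it is space-padded in the filename, so an unstripped value would never match a plain sample number.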

For this example, we'll use **real data from two different Sysmex analysers**:
- **Source**: STRIDES study data
- **Target**: INTERVAL study (analyser XN-10^11036)

We'll work with the **RET (reticulocyte)** channel as the primary example.

In [None]:
# Load the data-paths configuration.
# The config file contains paths to RDS storage for the different datasets.
config = load_config("../../sysmexcbctools/transfer/config/data_paths.yaml")

# Get dataset paths for two different analysers
# Source: STRIDES merged data
# Target: INTERVAL analyser XN-10^11036
SOURCE_DIR = config["datasets"]["raw"]["strides_merged"] + "/SCT"
TARGET_DIR = config["datasets"]["raw"]["interval_36"] + "/SCT"

# Also load sample numbers for filtering (optional but recommended)
# These .npy files contain centile sample numbers that represent the population well
source_samples_file = config["files"]["centile_samples"]["strides"]
target_samples_file = config["files"]["centile_samples"]["interval_baseline_36"]

# Create output directory
OUTPUT_DIR = "../outputs/flow_transformed/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Data Configuration Loaded:")
print(f"Source directory: {SOURCE_DIR}")
print(f"Target directory: {TARGET_DIR}")
print(f"Source samples file: {source_samples_file}")
print(f"Target samples file: {target_samples_file}")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Load sample numbers for filtering (recommended for large datasets)
# These files contain sample numbers that represent population centiles well
source_sample_nos = np.load(source_samples_file, allow_pickle=True)
target_sample_nos = np.load(target_samples_file, allow_pickle=True)

print(f"\nSample filtering:")
print(f"  Source samples: {len(source_sample_nos):,} centile samples")
print(f"  Target samples: {len(target_sample_nos):,} centile samples")

# Get all RET files from both directories
# Note: This may take a moment for large directories on RDS
print("\nScanning for RET channel files...")
source_files_all = sorted(Path(SOURCE_DIR).glob("RET*.csv"))
target_files_all = sorted(Path(TARGET_DIR).glob("RET*.csv"))

print(f"  Total source RET files: {len(source_files_all):,}")
print(f"  Total target RET files: {len(target_files_all):,}")

# For initial testing, we'll use just the centile samples
# This provides a representative subset without processing millions of cells
USE_CENTILE_SAMPLES = True  # Set to False to use all files

if USE_CENTILE_SAMPLES:
    print("\nUsing centile samples for efficient computation...")
    # Filter to only files matching centile sample numbers
    # File format: RET_[analyser][...][date_time][sample_number].116.csv
    # We'll match files by sample number (extract from filename)

    def extract_sample_number(filepath):
        """Extract sample number from .116.csv filename"""
        # Format: RET_[XN-20^14232][00-22_123][20211002_001515][      XXXXXX].116.csv
        # Sample number is in the last bracketed section before .116.csv (XXXXXX)
        import re

        match = re.search(r"\[([^\]]+)\]\.116", str(filepath.name))
        if match:
            return match.group(1).strip()
        return None

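    # Filter files by sample number. Converting the loaded arrays to sets of
    # stripped strings makes each lookup O(1) and avoids a dtype mismatch
    # between the extracted strings and the values stored in the .npy files.
    source_sample_set = {str(s).strip() for s in source_sample_nos}
    target_sample_set = {str(s).strip() for s in target_sample_nos}

    source_files = [
        f for f in source_files_all if extract_sample_number(f) in source_sample_set
    ]
    target_files = [
        f for f in target_files_all if extract_sample_number(f) in target_sample_set
    ]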

    print(f"  Filtered source files: {len(source_files):,}")
    print(f"  Filtered target files: {len(target_files):,}")
else:
    source_files = source_files_all
    target_files = target_files_all
    print("\nUsing all available files (may be memory-intensive)...")

# Show example filenames
if len(source_files) > 0:
    print(f"\nExample source file:")
    print(f"  {source_files[0].name}")
if len(target_files) > 0:
    print(f"\nExample target file:")
    print(f"  {target_files[0].name}")

### Understanding the Data

**Why use centile samples?**
- Flow cytometry files can be very large (millions of data points per sample)
- Centile samples are pre-selected to represent the population distribution well (QC samples, if you have them, are an even better choice)
- This makes computation tractable while maintaining statistical validity
- You can still use all samples by setting `USE_CENTILE_SAMPLES = False` above

**What we're aligning:**
- **Source**: STRIDES study (different analyser/lab/year)
- **Target**: INTERVAL study baseline (analyser XN-10^11036)
- **Goal**: Transform STRIDES data to match INTERVAL distribution

This enables cross-study comparisons and pooling data from multiple analysers.

## 2. Initialize the Transformer

Create a `FlowTransformer` for the RET channel with sensible defaults.

**Note on `save_fitted_data` parameter:**
- By default, the transformer does NOT save the downsampled data used for fitting (for privacy and memory reasons)
- Set `save_fitted_data=True` to enable saving the data in `source_data_` and `target_data_` attributes
- This is useful for visualization and validation (as shown below)
- The saved data is the downsampled version (up to `max_samples`) after filtering out saturated values
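
The filtering and downsampling described above can be pictured with the following sketch. It assumes that "saturated" means any coordinate sitting at the top of an 8-bit range (255) and that downsampling is uniform random sampling without replacement. Neither assumption is confirmed here, and `prepare_fit_data` is a hypothetical helper rather than the package's own routine.

```python
import numpy as np


def prepare_fit_data(events, max_samples, saturation_value=255, seed=42):
    """Drop saturated events, then randomly downsample to at most max_samples."""
    events = np.asarray(events)
    # Keep only events whose coordinates are all below the assumed saturation value.
    keep = (events < saturation_value).all(axis=1)
    filtered = events[keep]
    if len(filtered) > max_samples:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(filtered), size=max_samples, replace=False)
        filtered = filtered[idx]
    return filtered


# Toy events: two of the five rows contain a saturated coordinate.
toy_events = np.array([[10, 20], [255, 30], [40, 255], [50, 60], [70, 80]])
fit_subset = prepare_fit_data(toy_events, max_samples=2)
fit_all = prepare_fit_data(toy_events, max_samples=10)
print(fit_subset.shape, fit_all.shape)
```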

In [None]:
transformer = FlowTransformer(
    channel="RET",  # Channel to transform
    n_components=64,  # Number of Gaussian components
    covariance_type="full",  # Full covariance matrices
    transport_method="rand",  # T_rand transport method
    max_samples=1_000_000,  # Downsample if more samples
    preserve_rare=False,  # Disable rare handling
    omega_threshold=0.0,  # Don't filter target components
    use_cascade_init=True,  # Use cascade initialisation (target GMM initialised from fitted source GMM)
    n_jobs=-1,  # Use all CPU cores
    random_state=42,  # For reproducibility
    # Keep the data used for fitting so it can be visualised later. Use with care:
    # the stored arrays can be very large and may contain sensitive patient data.
    save_fitted_data=True,
)

print("Transformer configuration:")
print(f"  Channel: {transformer.channel}")
print(f"  GMM components: {transformer.n_components}")
print(f"  Use cascade init: {transformer.use_cascade_init}")
print(f"  Rare threshold: {transformer.rare_threshold}")
print(f"  Omega threshold: {transformer.omega_threshold}")
print(f"  Save fitted data: {transformer.save_fitted_data}")

## 3. Fit the Transformation

Fit GMM models to both source and target distributions, then compute the optimal transport map.

**Note**: This step can take several minutes depending on data size and number of components. The transformer will:
1. Load and concatenate all source and target files
2. Downsample if needed (to `max_samples`)
3. Fit GMM models to both distributions
4. Compute the optimal transport plan between GMM components

In [None]:
# Check if we have data to work with
if len(source_files) == 0 or len(target_files) == 0:
    print("⚠ ERROR: No files found!")
    print(f"  Source files: {len(source_files)}")
    print(f"  Target files: {len(target_files)}")
    print("\nPossible causes:")
    print("  - Data not available on this system (requires HPC/RDS access)")
    print("  - Incorrect paths in config/data_paths.yaml")
    print("  - Sample number filtering removed all files")
    print("\nTo proceed:")
    print("  1. Verify you have access to the RDS data")
    print("  2. Check paths in config/data_paths.yaml")
    print("  3. Try setting USE_CENTILE_SAMPLES = False in the cell above")
else:
    print(
        f"✓ Found {len(source_files)} source files and {len(target_files)} target files"
    )
    print("\nFitting GMM-OT transformation...")
    print("This may take several minutes depending on data size.")
    print(f"  Max samples per distribution: {transformer.max_samples:,}")
    print(f"  GMM components: {transformer.n_components}")
    print(f"  Parallel jobs: {transformer.n_jobs}")

    # Fit using file paths
    # The FlowTransformer will load files, concatenate, and fit GMMs
    transformer.fit(
        source_files=[str(f) for f in source_files],
        target_files=[str(f) for f in target_files],
    )

    print("\n✓ Transformer fitted successfully!")
    print(f"  Source GMM: {transformer.n_components} components")
    print(f"  Target GMM: {transformer.n_components} components")
    print(f"  Transport computed: {transformer.is_fitted_}")

In [None]:
# Save the fitted transformer for later use
# This allows you to skip the time-consuming fitting step in future sessions

save_path = "../outputs/flow_transformer_ret.pkl"
transformer.save(save_path)
print(f"✓ Transformer saved to: {save_path}")
print(f"  You can now load this transformer in future sessions to skip re-fitting")

In [None]:
# Load a previously saved transformer
# Uncomment and run this cell if you want to load a saved transformer instead of fitting

# from sysmexcbctools.transfer.sysmexalign import FlowTransformer
# save_path = "../outputs/flow_transformer_ret.pkl"
# transformer = FlowTransformer.load(save_path)
# print(f"✓ Transformer loaded from: {save_path}")
# print(f"  Channel: {transformer.channel}")
# print(f"  GMM components: {transformer.n_components}")
# print(f"  Fitted: {transformer.is_fitted_}")

## 4. Transform Source Data

Apply the learned transformation to source data files. This will:
1. Load each source file
2. Apply the GMM-OT transformation to align with target distribution  
3. Save transformed files to the output directory

Transformed files maintain the same format as input files (`.116.csv`).
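
To give a feel for what a point-wise GMM-OT transformation involves, here is a minimal sketch based on the construction in the paper cited above, as I understand it. Each pair of source and target components has a closed-form affine map, and a random map picks a pair for each point with probability proportional to the transport plan entry times the source component density at that point. All function names are hypothetical, the per-point loop is written for clarity rather than speed, and the internals of `FlowTransformer.transform` may differ.

```python
import numpy as np
from scipy.linalg import sqrtm
from scipy.stats import multivariate_normal


def gaussian_affine_map(m0, S0, m1, S1):
    """Return (A, b) so that x -> A @ x + b pushes N(m0, S0) onto N(m1, S1)."""
    S0_half = np.real(sqrtm(S0))
    S0_half_inv = np.linalg.inv(S0_half)
    middle = np.real(sqrtm(S0_half @ S1 @ S0_half))
    A = S0_half_inv @ middle @ S0_half_inv
    return A, m1 - A @ m0


def transport_rand(X, w_src, means_src, covs_src, means_tgt, covs_tgt, plan, rng):
    """Move each point with a randomly chosen component-to-component affine map."""
    k0, k1 = plan.shape
    # Density of every source component at every point, shape (n, k0).
    dens = np.column_stack(
        [
            np.atleast_1d(multivariate_normal.pdf(X, mean=means_src[k], cov=covs_src[k]))
            for k in range(k0)
        ]
    )
    mixture = dens @ w_src
    # probs[i, k, l] = plan[k, l] * dens[i, k] / mixture[i]; sums to 1 for each i.
    probs = (dens[:, :, None] * plan[None, :, :]) / mixture[:, None, None]
    probs = probs.reshape(len(X), k0 * k1)
    maps = [
        [gaussian_affine_map(means_src[k], covs_src[k], means_tgt[l], covs_tgt[l]) for l in range(k1)]
        for k in range(k0)
    ]
    out = np.empty(X.shape, dtype=float)
    for i, x in enumerate(X):
        k, l = divmod(rng.choice(k0 * k1, p=probs[i]), k1)
        A, b = maps[k][l]
        out[i] = A @ x + b
    return out


# Single-component toy case: N(0, I) is mapped onto N((5, 5), 4I).
X = np.array([[0.0, 0.0], [1.0, 1.0]])
moved = transport_rand(
    X,
    w_src=np.array([1.0]),
    means_src=np.array([[0.0, 0.0]]),
    covs_src=np.array([np.eye(2)]),
    means_tgt=np.array([[5.0, 5.0]]),
    covs_tgt=np.array([4.0 * np.eye(2)]),
    plan=np.array([[1.0]]),
    rng=np.random.default_rng(0),
)
print(moved)
```

With a single component on each side the map is deterministic: every point is scaled by 2 and shifted by (5, 5).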

In [None]:
if not transformer.is_fitted_:
    print("⚠ Transformer not fitted. Skipping transformation.")
    print("Run the 'Fit the Transformation' cell above first.")
else:
    print("Transforming source files to match target distribution...")
    print(f"  Transforming {len(source_files)} files")
    print(f"  Output directory: {OUTPUT_DIR}")

    # Transform files and save to output directory
    output_files = transformer.transform(
        source_files=[str(f) for f in source_files],
        output_dir=OUTPUT_DIR,
    )

    print(f"\n✓ Successfully transformed {len(output_files)} files")
    print(f"  Output directory: {OUTPUT_DIR}")
    if len(output_files) > 0:
        print(f"\nExample output files:")
        for i, f in enumerate(output_files[:3]):
            print(f"  {i+1}. {Path(f).name}")
        if len(output_files) > 3:
            print(f"  ... and {len(output_files) - 3} more")

## 5. Validate Transformation Quality

Compare the transformed distribution to the target using Wasserstein distance and likelihood scores.
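
As a small illustration of the distance-based check, the sketch below computes the 1D Wasserstein distance per dimension with `scipy.stats.wasserstein_distance`, before and after alignment. It runs on synthetic data, with a known shift standing in for the transformer output, so the numbers say nothing about real analyser data. `per_dimension_wasserstein` is a hypothetical helper.

```python
import numpy as np
from scipy.stats import wasserstein_distance


def per_dimension_wasserstein(a, b):
    """1D Wasserstein distance between matching columns of two (n, d) arrays."""
    return np.array(
        [wasserstein_distance(a[:, d], b[:, d]) for d in range(a.shape[1])]
    )


rng = np.random.default_rng(0)
target = rng.normal(loc=100.0, scale=10.0, size=(5000, 2))
shift = np.array([3.0, 1.0])
source = target - shift   # synthetic "source analyser" with a constant offset
aligned = source + shift  # stand-in for the transformed source data

before = per_dimension_wasserstein(source, target)
after = per_dimension_wasserstein(aligned, target)
print("before:", np.round(before, 3))
print("after: ", np.round(after, 3))
```

A successful alignment should show the "after" distances falling well below the "before" distances in every dimension.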

In [None]:
if not transformer.is_fitted_:
    print("⚠ Transformer not fitted. Skipping visualization.")
elif transformer.source_data_ is None or transformer.target_data_ is None:
    print("⚠ Fitted data not available. Skipping visualization.")
    print(
        "  To enable visualization, create the transformer with save_fitted_data=True"
    )
else:
    print("Creating distribution comparison plots...")

    # Get the saved downsampled data used for fitting
    source_sample_data = transformer.source_data_
    target_sample_data = transformer.target_data_

    print(f"  Source data used for fitting: {source_sample_data.shape[0]:,} samples")
    print(f"  Target data used for fitting: {target_sample_data.shape[0]:,} samples")

    # Transform the source data to compare with target
    transformed_sample = transformer.transform_array(source_sample_data)

    # Downsample for plotting (use first 20,000 points for clarity)
    n_plot = min(20000, len(source_sample_data), len(target_sample_data))
    source_plot = source_sample_data[:n_plot]
    target_plot = target_sample_data[:n_plot]
    transformed_plot = transformed_sample[:n_plot]

    print(f"  Plotting {n_plot:,} cells per distribution")

    # Create figure with multiple subplots
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # 1D histograms for each dimension
    # Dimension 1 (SFL - Side Fluorescence)
    axes[0, 0].hist(
        source_plot[:, 0],
        bins=50,
        alpha=0.5,
        label="Source (STRIDES)",
        density=True,
        color="blue",
    )
    axes[0, 0].hist(
        target_plot[:, 0],
        bins=50,
        alpha=0.5,
        label="Target (INTERVAL)",
        density=True,
        color="orange",
    )
    axes[0, 0].hist(
        transformed_plot[:, 0],
        bins=50,
        alpha=0.5,
        label="Transformed",
        density=True,
        color="green",
    )
    axes[0, 0].set_xlabel("SFL (Side Fluorescence)")
    axes[0, 0].set_ylabel("Density")
    axes[0, 0].set_title("RET Channel - Dimension 1 (SFL)")
    axes[0, 0].legend()
    axes[0, 0].grid(alpha=0.3)

    # Dimension 2 (FSC - Forward Scatter)
    axes[0, 1].hist(
        source_plot[:, 1],
        bins=50,
        alpha=0.5,
        label="Source (STRIDES)",
        density=True,
        color="blue",
    )
    axes[0, 1].hist(
        target_plot[:, 1],
        bins=50,
        alpha=0.5,
        label="Target (INTERVAL)",
        density=True,
        color="orange",
    )
    axes[0, 1].hist(
        transformed_plot[:, 1],
        bins=50,
        alpha=0.5,
        label="Transformed",
        density=True,
        color="green",
    )
    axes[0, 1].set_xlabel("FSC (Forward Scatter)")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("RET Channel - Dimension 2 (FSC)")
    axes[0, 1].legend()
    axes[0, 1].grid(alpha=0.3)

    # 2D scatter plots
    # Source vs Target
    axes[1, 0].scatter(
        source_plot[:, 0],
        source_plot[:, 1],
        s=1,
        alpha=0.3,
        label="Source (STRIDES)",
        color="blue",
    )
    axes[1, 0].scatter(
        target_plot[:, 0],
        target_plot[:, 1],
        s=1,
        alpha=0.3,
        label="Target (INTERVAL)",
        color="orange",
    )
    axes[1, 0].set_xlabel("SFL (Side Fluorescence)")
    axes[1, 0].set_ylabel("FSC (Forward Scatter)")
    axes[1, 0].set_title("Before Transformation: Source vs Target")
    axes[1, 0].legend(markerscale=10)
    axes[1, 0].grid(alpha=0.3)

    # Transformed vs Target
    axes[1, 1].scatter(
        transformed_plot[:, 0],
        transformed_plot[:, 1],
        s=1,
        alpha=0.3,
        label="Transformed",
        color="green",
    )
    axes[1, 1].scatter(
        target_plot[:, 0],
        target_plot[:, 1],
        s=1,
        alpha=0.3,
        label="Target (INTERVAL)",
        color="orange",
    )
    axes[1, 1].set_xlabel("SFL (Side Fluorescence)")
    axes[1, 1].set_ylabel("FSC (Forward Scatter)")
    axes[1, 1].set_title("After Transformation: Transformed vs Target")
    axes[1, 1].legend(markerscale=10)
    axes[1, 1].grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig("../outputs/distribution_comparison.png", dpi=150, bbox_inches="tight")
    plt.show()

    print("\n✓ Distribution plots saved to ../outputs/distribution_comparison.png")
    print("\nInterpretation:")
    print("  - Top row: 1D marginal distributions for each dimension")
    print("  - Bottom left: Original source (blue) vs target (orange)")
    print("  - Bottom right: Transformed (green) should match target (orange)")
    print("  - Success = green and orange distributions overlap well")

In [None]:
if not transformer.is_fitted_:
    print("⚠ Transformer not fitted. Skipping enhanced visualization.")
elif transformer.source_data_ is None or transformer.target_data_ is None:
    print("⚠ Fitted data not available. Skipping enhanced visualization.")
else:
    import json
    from matplotlib.path import Path as MplPath
    from matplotlib.patches import Polygon

    print("Creating enhanced transformation visualizations...")

    # Get data
    source_data = transformer.source_data_
    target_data = transformer.target_data_
    transformed_data = transformer.transform_array(source_data)

    # Load gates for per-population analysis
    gate_file = "../../sysmexcbctools/transfer/flow_gates/json_gates/RET_gates.json"
    with open(gate_file, "r") as f:
        gates = json.load(f)

    # Downsample for visualization
    n_plot = min(30000, len(source_data), len(target_data))
    np.random.seed(42)
    plot_idx = np.random.choice(len(source_data), n_plot, replace=False)

    source_plot = source_data[plot_idx]
    target_plot = target_data[np.random.choice(len(target_data), n_plot, replace=False)]
    transformed_plot = transformed_data[plot_idx]

    print(f"  Visualizing {n_plot:,} cells per distribution\n")

    # ========================================================================
    # Figure 1: 3-Panel Comparison with 2D Density
    # ========================================================================
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Common settings
    xlim = (0, 260)
    ylim = (0, 260)
    bins = 100

    # Helper function to add gates
    def add_gates_to_plot(ax, gates_dict, alpha=0.6):
        colors = {"RBC": "red", "RET": "blue", "PLT": "green"}
        for pop_name, coords in gates_dict.items():
            if len(coords) > 0:
                poly = Polygon(
                    coords,
                    fill=False,
                    edgecolor=colors.get(pop_name, "white"),
                    linewidth=2,
                    alpha=alpha,
                    linestyle="--",
                )
                ax.add_patch(poly)

    # Row 1: Scatter plots
    # Panel 1: Source
    ax = axes[0, 0]
    ax.scatter(
        source_plot[:, 0],
        source_plot[:, 1],
        s=0.5,
        alpha=0.2,
        c="blue",
        rasterized=True,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title("Source (STRIDES)", fontsize=13, fontweight="bold")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(alpha=0.3)

    # Panel 2: Transformed
    ax = axes[0, 1]
    ax.scatter(
        transformed_plot[:, 0],
        transformed_plot[:, 1],
        s=0.5,
        alpha=0.2,
        c="green",
        rasterized=True,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title("Transformed", fontsize=13, fontweight="bold", color="green")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(alpha=0.3)

    # Panel 3: Target
    ax = axes[0, 2]
    ax.scatter(
        target_plot[:, 0],
        target_plot[:, 1],
        s=0.5,
        alpha=0.2,
        c="orange",
        rasterized=True,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title("Target (INTERVAL)", fontsize=13, fontweight="bold")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(alpha=0.3)

    # Row 2: 2D Density Heatmaps
    # Panel 1: Source density
    ax = axes[1, 0]
    h_source = ax.hist2d(
        source_plot[:, 0],
        source_plot[:, 1],
        bins=bins,
        cmap="Blues",
        range=[xlim, ylim],
        density=True,
        cmin=1e-6,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title("Source Density", fontsize=12)
    plt.colorbar(h_source[3], ax=ax, label="Density")

    # Panel 2: Transformed density
    ax = axes[1, 1]
    h_trans = ax.hist2d(
        transformed_plot[:, 0],
        transformed_plot[:, 1],
        bins=bins,
        cmap="Greens",
        range=[xlim, ylim],
        density=True,
        cmin=1e-6,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title("Transformed Density", fontsize=12)
    plt.colorbar(h_trans[3], ax=ax, label="Density")

    # Panel 3: Target density
    ax = axes[1, 2]
    h_target = ax.hist2d(
        target_plot[:, 0],
        target_plot[:, 1],
        bins=bins,
        cmap="Oranges",
        range=[xlim, ylim],
        density=True,
        cmin=1e-6,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title("Target Density", fontsize=12)
    plt.colorbar(h_target[3], ax=ax, label="Density")

    plt.tight_layout()
    plt.savefig(
        "../outputs/transformation_3panel_comparison.png", dpi=150, bbox_inches="tight"
    )
    plt.show()

    print(
        "✓ 3-panel comparison saved to ../outputs/transformation_3panel_comparison.png\n"
    )

    # ========================================================================
    # Figure 2: Per-Population Comparison
    # ========================================================================

    def classify_by_gate(data, gates_dict):
        """Classify points by gate membership"""
        labels = np.full(len(data), -1, dtype=int)  # -1 = ungated
        label_names = {}

        for label_idx, (pop_name, coords) in enumerate(gates_dict.items()):
            if len(coords) == 0:
                continue
            path = MplPath(coords)
            mask = path.contains_points(data)
            labels[mask] = label_idx
            label_names[label_idx] = pop_name

        return labels, label_names

    # Classify points
    source_labels, label_names = classify_by_gate(source_plot, gates)
    target_labels, _ = classify_by_gate(target_plot, gates)
    transformed_labels, _ = classify_by_gate(transformed_plot, gates)

    # Create per-population figure
    n_pops = len(label_names)
    # squeeze=False keeps `axes` two-dimensional even when there is only one population
    fig, axes = plt.subplots(n_pops, 3, figsize=(16, 4 * n_pops), squeeze=False)

    pop_colors = {"RBC": "red", "RET": "blue", "PLT": "green"}

    for row_idx, (label_idx, pop_name) in enumerate(label_names.items()):
        # Get points in this population
        source_pop = source_plot[source_labels == label_idx]
        target_pop = target_plot[target_labels == label_idx]
        transformed_pop = transformed_plot[transformed_labels == label_idx]

        color = pop_colors.get(pop_name, "gray")

        # Column 1: Source vs Target overlay for this population
        ax = axes[row_idx, 0]
        if len(source_pop) > 0:
            ax.scatter(
                source_pop[:, 0],
                source_pop[:, 1],
                s=1,
                alpha=0.3,
                c="blue",
                label=f"Source {pop_name} (n={len(source_pop):,})",
                rasterized=True,
            )
        if len(target_pop) > 0:
            ax.scatter(
                target_pop[:, 0],
                target_pop[:, 1],
                s=1,
                alpha=0.3,
                c="orange",
                label=f"Target {pop_name} (n={len(target_pop):,})",
                rasterized=True,
            )
        ax.set_xlabel("SFL", fontsize=10)
        ax.set_ylabel("FSC", fontsize=10)
        ax.set_title(f"{pop_name}: Source vs Target", fontsize=11, fontweight="bold")
        ax.legend(markerscale=5, fontsize=8)
        ax.grid(alpha=0.3)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # Column 2: Transformed vs Target overlay for this population
        ax = axes[row_idx, 1]
        if len(transformed_pop) > 0:
            ax.scatter(
                transformed_pop[:, 0],
                transformed_pop[:, 1],
                s=1,
                alpha=0.3,
                c="green",
                label=f"Transformed {pop_name} (n={len(transformed_pop):,})",
                rasterized=True,
            )
        if len(target_pop) > 0:
            ax.scatter(
                target_pop[:, 0],
                target_pop[:, 1],
                s=1,
                alpha=0.3,
                c="orange",
                label=f"Target {pop_name} (n={len(target_pop):,})",
                rasterized=True,
            )
        ax.set_xlabel("SFL", fontsize=10)
        ax.set_ylabel("FSC", fontsize=10)
        ax.set_title(
            f"{pop_name}: Transformed vs Target",
            fontsize=11,
            fontweight="bold",
            color="green",
        )
        ax.legend(markerscale=5, fontsize=8)
        ax.grid(alpha=0.3)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # Column 3: 1D marginal comparisons
        ax = axes[row_idx, 2]

        # SFL marginal (horizontal)
        ax_sfl = ax
        if len(source_pop) > 0:
            ax_sfl.hist(
                source_pop[:, 0],
                bins=50,
                alpha=0.4,
                color="blue",
                density=True,
                label="Source",
            )
        if len(transformed_pop) > 0:
            ax_sfl.hist(
                transformed_pop[:, 0],
                bins=50,
                alpha=0.4,
                color="green",
                density=True,
                label="Transformed",
            )
        if len(target_pop) > 0:
            ax_sfl.hist(
                target_pop[:, 0],
                bins=50,
                alpha=0.4,
                color="orange",
                density=True,
                label="Target",
            )
        ax_sfl.set_xlabel("SFL (Side Fluorescence)", fontsize=10)
        ax_sfl.set_ylabel("Density", fontsize=10)
        ax_sfl.set_title(f"{pop_name}: SFL Marginal Distribution", fontsize=11)
        ax_sfl.legend(fontsize=8)
        ax_sfl.grid(alpha=0.3)
        ax_sfl.set_xlim(xlim)

    plt.tight_layout()
    plt.savefig(
        "../outputs/transformation_per_population.png", dpi=150, bbox_inches="tight"
    )
    plt.show()

    print(
        "✓ Per-population comparison saved to ../outputs/transformation_per_population.png\n"
    )

    # ========================================================================
    # Figure 3: Overlay Comparison (Transformed vs Target)
    # ========================================================================
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Panel 1: Direct overlay
    ax = axes[0]
    ax.scatter(
        target_plot[:, 0],
        target_plot[:, 1],
        s=1,
        alpha=0.15,
        c="orange",
        label=f"Target (n={len(target_plot):,})",
        rasterized=True,
    )
    ax.scatter(
        transformed_plot[:, 0],
        transformed_plot[:, 1],
        s=1,
        alpha=0.15,
        c="green",
        label=f"Transformed (n={len(transformed_plot):,})",
        rasterized=True,
    )
    add_gates_to_plot(ax, gates)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title(
        "Overlay: Transformed (green) vs Target (orange)",
        fontsize=12,
        fontweight="bold",
    )
    ax.legend(markerscale=10, fontsize=10)
    ax.grid(alpha=0.3)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Panel 2: Density difference (Transformed - Target)
    ax = axes[1]

    # Compute 2D histograms
    H_trans, xedges, yedges = np.histogram2d(
        transformed_plot[:, 0],
        transformed_plot[:, 1],
        bins=bins,
        range=[xlim, ylim],
        density=True,
    )
    H_target, _, _ = np.histogram2d(
        target_plot[:, 0],
        target_plot[:, 1],
        bins=bins,
        range=[xlim, ylim],
        density=True,
    )

    # Compute difference
    H_diff = H_trans - H_target

    # Plot difference
    im = ax.imshow(
        H_diff.T,
        origin="lower",
        extent=[xlim[0], xlim[1], ylim[0], ylim[1]],
        cmap="RdBu_r",
        vmin=-np.percentile(np.abs(H_diff), 95),
        vmax=np.percentile(np.abs(H_diff), 95),
        aspect="auto",
    )
    add_gates_to_plot(ax, gates, alpha=0.8)
    ax.set_xlabel("SFL (Side Fluorescence)", fontsize=11)
    ax.set_ylabel("FSC (Forward Scatter)", fontsize=11)
    ax.set_title(
        "Density Difference: Transformed - Target", fontsize=12, fontweight="bold"
    )
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("Density Difference", fontsize=10)
    ax.grid(alpha=0.3, color="white", linewidth=0.5)

    plt.tight_layout()
    plt.savefig(
        "../outputs/transformation_overlay_comparison.png", dpi=150, bbox_inches="tight"
    )
    plt.show()

    print(
        "✓ Overlay comparison saved to ../outputs/transformation_overlay_comparison.png\n"
    )

    print("=" * 60)
    print("VISUALIZATION SUMMARY")
    print("=" * 60)
    print("\nGenerated three comprehensive visualization figures:")
    print("  1. 3-panel comparison: Source | Transformed | Target")
    print("     - Top row: Scatter plots with gate overlays")
    print("     - Bottom row: 2D density heatmaps")
    print()
    print("  2. Per-population comparison: Separate analysis for each gated population")
    print("     - Shows how well transformation works for RBC, RET, PLT separately")
    print("     - Critical for validating rare population handling")
    print()
    print("  3. Overlay comparison: Direct visual comparison")
    print("     - Left: Transformed (green) vs Target (orange) overlay")
    print("     - Right: Density difference map (red=excess, blue=deficit)")
    print()
    print("Interpretation:")
    print("  - Transformed should visually match Target distribution")
    print("  - Density difference map should be mostly white (neutral)")
    print("  - Per-population plots show if rare populations are well-aligned")
    print("  - Red/blue regions in difference map indicate misalignment")

    print("\n✓ Enhanced visualization complete")

## 5.2. Enhanced Transformation Visualization

Comprehensive visual comparison of the transformation quality:
- **3-panel comparison**: Source | Transformed | Target side-by-side
- **2D density heatmaps**: Show distribution densities for detailed comparison
- **Per-population overlays**: Evaluate transformation quality for each gated cell population
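
As a rough illustration of the density-heatmap comparison, the sketch below bins two synthetic 2D samples on a shared grid and subtracts the normalised histograms. The arrays `transformed` and `target` are hypothetical stand-ins generated at random, not notebook data, and the 0 to 260 range only mirrors the axis limits used in the plots.

```python
import numpy as np

# Synthetic stand-ins for transformed and target events (columns: SFL, FSC)
rng = np.random.default_rng(1)
target = rng.normal([100.0, 120.0], 10.0, size=(5000, 2))
transformed = rng.normal([104.0, 120.0], 10.0, size=(5000, 2))

bins = 50
value_range = [[0, 260], [0, 260]]  # shared grid so the histograms are comparable

# density=True normalises each histogram so it integrates to 1 over the grid
h_trans, x_edges, y_edges = np.histogram2d(
    transformed[:, 0], transformed[:, 1], bins=bins, range=value_range, density=True
)
h_target, _, _ = np.histogram2d(
    target[:, 0], target[:, 1], bins=bins, range=value_range, density=True
)

# Positive cells: transformed has excess density; negative cells: a deficit
h_diff = h_trans - h_target

# Integrated absolute difference, between 0 (identical) and 2 (disjoint)
bin_area = (x_edges[1] - x_edges[0]) * (y_edges[1] - y_edges[0])
l1_difference = np.abs(h_diff).sum() * bin_area
print(f"Integrated |difference|: {l1_difference:.3f}")
```

Because both histograms integrate to one, the signed difference sums to zero over the grid, so any excess in one region must be balanced by a deficit elsewhere.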

In [None]:
if not transformer.is_fitted_:
    print("⚠ Transformer not fitted. Skipping metrics computation.")
elif transformer.source_data_ is None or transformer.target_data_ is None:
    print("⚠ Fitted data not available. Skipping metrics computation.")
else:
    from scipy.stats import wasserstein_distance
    from scipy.spatial.distance import cdist
    import ot

    print("Computing distribution distance metrics...")

    # Get data
    source_data = transformer.source_data_
    target_data = transformer.target_data_
    transformed_data = transformer.transform_array(source_data)

    # Downsample for faster computation if needed
    n_metric = min(50000, len(source_data), len(target_data))
    np.random.seed(42)
    source_metric = source_data[
        np.random.choice(len(source_data), n_metric, replace=False)
    ]
    target_metric = target_data[
        np.random.choice(len(target_data), n_metric, replace=False)
    ]
    transformed_metric = transformer.transform_array(source_metric)

    source_metric = source_metric.astype(np.float64)
    target_metric = target_metric.astype(np.float64)
    transformed_metric = transformed_metric.astype(np.float64)

    print(f"  Using {n_metric:,} samples for metric computation\n")

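    # ===== 1D Wasserstein (per dimension) =====
    # Note: ot.emd2_1d defaults to metric="sqeuclidean", so the values below are
    # squared 2-Wasserstein costs rather than W1 distances.
    print("=" * 60)
    print("1D Wasserstein Cost (per dimension, squared-Euclidean ground cost)")
    print("=" * 60)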

    dim_names = ["SFL (Side Fluorescence)", "FSC (Forward Scatter)"]
    for dim in range(2):
        w_source_target = ot.emd2_1d(source_metric[:, dim], target_metric[:, dim])
        w_trans_target = ot.emd2_1d(transformed_metric[:, dim], target_metric[:, dim])
        improvement = (
            (1 - w_trans_target / w_source_target) * 100 if w_source_target > 0 else 0
        )

        print(f"\n{dim_names[dim]}:")
        print(f"  Source → Target:      {w_source_target:.4f}")
        print(f"  Transformed → Target: {w_trans_target:.4f}")
        print(f"  Improvement:          {improvement:+.1f}%")

    # ===== 2D Wasserstein Distance (exact OT on a random subsample) =====
    print("\n" + "=" * 60)
    print("2D Wasserstein Distance (exact OT on a subsample)")
    print("=" * 60)

    def sliced_wasserstein(X, Y, n_projections=500, seed=42):
        """Compute sliced Wasserstein distance using POT (used per population below)"""
        return ot.sliced_wasserstein_distance(
            X, Y, n_projections=n_projections, seed=seed
        )

    # Exact 2D optimal transport is expensive, so solve it on a smaller subsample.
    # The sw_* names are kept from an earlier sliced-Wasserstein version; the
    # values computed here are exact OT costs with a Euclidean ground metric.
    n_samples_wasserstein = min(5000, n_metric)

    # Draw the subsample indices once: transformed_metric was computed from
    # source_metric, so the same source indices are reused for both, and both
    # comparisons share one target subsample.
    src_idx = np.random.choice(len(source_metric), n_samples_wasserstein, replace=False)
    tgt_idx = np.random.choice(len(target_metric), n_samples_wasserstein, replace=False)

    sw_source_target = ot.solve_sample(
        source_metric[src_idx],
        target_metric[tgt_idx],
        metric="euclidean",
        verbose=True,
    ).value
    sw_trans_target = ot.solve_sample(
        transformed_metric[src_idx],
        target_metric[tgt_idx],
        metric="euclidean",
        verbose=True,
    ).value
    sw_improvement = (
        (1 - sw_trans_target / sw_source_target) * 100 if sw_source_target > 0 else 0
    )

    print(f"\n  Source → Target:      {sw_source_target:.4f}")
    print(f"  Transformed → Target: {sw_trans_target:.4f}")
    print(f"  Improvement:          {sw_improvement:+.1f}%")

    # ===== Maximum Mean Discrepancy (MMD) =====
    print("\n" + "=" * 60)
    print("Maximum Mean Discrepancy (MMD)")
    print("=" * 60)
    print("(Kernel-based distance, using Gaussian kernel with bandwidth=10)")

    def compute_mmd(X, Y, bandwidth=10.0):
        """Compute Maximum Mean Discrepancy with a Gaussian kernel (biased estimate)"""
        # Use subset for computational efficiency (pairwise matrices are n x n)
        n = min(5000, len(X), len(Y))
        X_sub = X[np.random.choice(len(X), n, replace=False)]
        Y_sub = Y[np.random.choice(len(Y), n, replace=False)]

        # Pairwise squared Euclidean distances
        XX = cdist(X_sub, X_sub, metric="sqeuclidean")
        YY = cdist(Y_sub, Y_sub, metric="sqeuclidean")
        XY = cdist(X_sub, Y_sub, metric="sqeuclidean")

        # Apply Gaussian kernel
        K_XX = np.exp(-XX / (2 * bandwidth**2))
        K_YY = np.exp(-YY / (2 * bandwidth**2))
        K_XY = np.exp(-XY / (2 * bandwidth**2))

        # MMD^2 estimate (means include the diagonal terms, hence biased)
        mmd_sq = K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()
        return np.sqrt(max(0, mmd_sq))  # Clamp tiny negative values from round-off

    mmd_source_target = compute_mmd(source_metric, target_metric)
    mmd_trans_target = compute_mmd(transformed_metric, target_metric)
    mmd_improvement = (
        (1 - mmd_trans_target / mmd_source_target) * 100 if mmd_source_target > 0 else 0
    )

    print(f"\n  Source → Target:      {mmd_source_target:.6f}")
    print(f"  Transformed → Target: {mmd_trans_target:.6f}")
    print(f"  Improvement:          {mmd_improvement:+.1f}%")

    # ===== Per-Population Analysis Using Gates =====
    print("\n" + "=" * 60)
    print("Per-Population Distance Metrics (Gate-Based)")
    print("=" * 60)

    # Load gates
    import json
    from matplotlib.path import Path as MplPath

    gate_file = "../../sysmexcbctools/transfer/flow_gates/json_gates/RET_gates.json"
    with open(gate_file, "r") as f:
        gates = json.load(f)

    def classify_by_gate(data, gates_dict):
        """Classify points by gate membership"""
        labels = np.full(len(data), -1, dtype=int)  # -1 = ungated
        label_names = {}

        for label_idx, (pop_name, coords) in enumerate(gates_dict.items()):
            if len(coords) == 0:
                continue
            path = MplPath(coords)
            mask = path.contains_points(data)
            labels[mask] = label_idx
            label_names[label_idx] = pop_name

        return labels, label_names

    # Classify all points
    source_labels, label_names = classify_by_gate(source_metric, gates)
    target_labels, _ = classify_by_gate(target_metric, gates)
    transformed_labels, _ = classify_by_gate(transformed_metric, gates)

    print("\nPer-population 2D Sliced Wasserstein Distance:")
    print()

    for label_idx, pop_name in label_names.items():
        # Get points in this population
        source_pop = source_metric[source_labels == label_idx]
        target_pop = target_metric[target_labels == label_idx]
        transformed_pop = transformed_metric[transformed_labels == label_idx]

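        # Skip populations with too few points in any of the three sets; the
        # transformed set can be empty even when source and target are not.
        if (
            len(source_pop) < 100
            or len(target_pop) < 100
            or len(transformed_pop) < 100
        ):
            print(
                f"  {pop_name}: Insufficient data (source={len(source_pop)}, target={len(target_pop)}, transformed={len(transformed_pop)})"
            )
            continue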

        # Compute sliced Wasserstein for this population
        sw_pop_source = sliced_wasserstein(source_pop, target_pop, n_projections=30)
        sw_pop_trans = sliced_wasserstein(transformed_pop, target_pop, n_projections=30)
        sw_pop_improvement = (
            (1 - sw_pop_trans / sw_pop_source) * 100 if sw_pop_source > 0 else 0
        )

        print(f"  {pop_name}:")
        print(
            f"    Source → Target:      {sw_pop_source:.4f} (n_source={len(source_pop):,}, n_target={len(target_pop):,})"
        )
        print(
            f"    Transformed → Target: {sw_pop_trans:.4f} (n_transformed={len(transformed_pop):,})"
        )
        print(f"    Improvement:          {sw_pop_improvement:+.1f}%")
        print()

    # ===== Summary =====
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print("\nOverall transformation effectiveness:")
    print(f"  2D Wasserstein (exact OT, subsampled): {sw_improvement:+.1f}% improvement")
    print(f"  MMD:                                   {mmd_improvement:+.1f}% improvement")
    print("\nInterpretation:")
    print("  - Positive % = transformation reduced distance (good)")
    print("  - Negative % = transformation increased distance (bad)")
    print("  - Goal: All metrics show positive improvement")
    print("  - Per-population metrics show if rare populations are handled well")

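    # Save metrics to dict for later use. The "sliced_wasserstein" key name is
    # historical: it holds the exact 2D OT values computed above on a subsample.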
    metrics_summary = {
        "sliced_wasserstein": {
            "source_target": sw_source_target,
            "transformed_target": sw_trans_target,
            "improvement_pct": sw_improvement,
        },
        "mmd": {
            "source_target": mmd_source_target,
            "transformed_target": mmd_trans_target,
            "improvement_pct": mmd_improvement,
        },
    }

    print("\n✓ Metrics computation complete")

## 5.1. Quantitative Distribution Distance Metrics

To rigorously assess transformation quality, we'll compute several distribution distance metrics comparing:
- **Source vs Target**: Original distance (what we're trying to fix)
- **Transformed vs Target**: After transformation (should be much smaller)

Metrics include:
- **Wasserstein distance**: Optimal transport cost (lower is better)
- **Maximum Mean Discrepancy (MMD)**: Kernel-based distance
- **Per-population analysis**: Evaluate each cell population separately using gates
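
To make the "before versus after" comparison concrete, here is a minimal sketch on synthetic data, using only NumPy and SciPy. The arrays `source`, `target` and `aligned` are hypothetical stand-ins (with `aligned` playing the role of a well-transformed sample), and `gaussian_mmd` is an illustrative helper, not part of the `FlowTransformer` API.

```python
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import wasserstein_distance


def gaussian_mmd(X, Y, bandwidth=10.0):
    """Biased MMD estimate with a Gaussian kernel of the given bandwidth."""
    k_xx = np.exp(-cdist(X, X, "sqeuclidean") / (2 * bandwidth**2))
    k_yy = np.exp(-cdist(Y, Y, "sqeuclidean") / (2 * bandwidth**2))
    k_xy = np.exp(-cdist(X, Y, "sqeuclidean") / (2 * bandwidth**2))
    mmd_sq = k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()
    return float(np.sqrt(max(0.0, mmd_sq)))


rng = np.random.default_rng(0)
target = rng.normal(loc=[100.0, 120.0], scale=10.0, size=(1000, 2))
source = rng.normal(loc=[115.0, 110.0], scale=10.0, size=(1000, 2))  # shifted
aligned = rng.normal(loc=[100.0, 120.0], scale=10.0, size=(1000, 2))  # well aligned

# 1D Wasserstein distance along the first dimension, before and after
w1_before = wasserstein_distance(source[:, 0], target[:, 0])
w1_after = wasserstein_distance(aligned[:, 0], target[:, 0])

# Kernel-based distance on the full 2D samples, before and after
mmd_before = gaussian_mmd(source, target)
mmd_after = gaussian_mmd(aligned, target)

# Same "improvement" convention as the notebook: positive means closer to target
improvement = (1 - mmd_after / mmd_before) * 100
print(f"W1 (dim 0): {w1_before:.2f} -> {w1_after:.2f}")
print(f"MMD:        {mmd_before:.4f} -> {mmd_after:.4f} ({improvement:+.1f}%)")
```

A residual distance remains even for the aligned sample, because two finite samples from the same distribution never match exactly; this sampling noise sets a floor on how small the "Transformed vs Target" value can be.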

## Diagnostic: GMM Fit Quality Analysis

Let's diagnose potential GMM collapse by visualizing how well the GMM components fit the actual data distribution. We'll overlay the manually-defined flow cytometry gates to see if important populations are being captured.

⚠️ **Note on Gates**: The gate definitions used here are manually derived through visual inspection of flow cytometry data and should be considered approximate. Official Sysmex-provided gates or expert-validated gates would be preferable for production use.
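
The coverage check described above amounts to a point-in-polygon test on the GMM component means. The sketch below shows the idea with `matplotlib.path.Path`; the gate polygons, means and weights are made-up values for illustration and do not correspond to the real `RET_gates.json` file or to a fitted model.

```python
import numpy as np
from matplotlib.path import Path

# Hypothetical gates: population name -> polygon vertices in (SFL, FSC) space
example_gates = {
    "A": [[0, 0], [100, 0], [100, 100], [0, 100]],
    "B": [[150, 150], [250, 150], [250, 250], [150, 250]],
}

# Stand-ins for gmm.means_ and gmm.weights_ of a fitted mixture
means = np.array([[50.0, 50.0], [60.0, 40.0], [200.0, 200.0], [120.0, 120.0]])
weights = np.array([0.6, 0.3, 0.08, 0.02])

coverage = {}
for name, vertices in example_gates.items():
    # Boolean mask: which component means fall inside this gate polygon
    inside = Path(vertices).contains_points(means)
    coverage[name] = {
        "n_components": int(inside.sum()),
        "total_weight": float(weights[inside].sum()),
    }

for name, stats in coverage.items():
    print(f"{name}: {stats['n_components']} components, {stats['total_weight']:.1%} weight")
```

A gate that contains few component means and little total weight suggests the mixture is spending its capacity elsewhere. Note that this test looks only at the means, so a broad component centred outside a gate can still contribute density inside it.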

In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Polygon
import numpy as np

# Load the manually-defined gates for RET channel
gate_file = "../../sysmexcbctools/transfer/flow_gates/json_gates/RET_gates.json"
with open(gate_file, "r") as f:
    gates = json.load(f)

print("Available gate populations:")
for pop_name, coords in gates.items():
    print(f"  - {pop_name}: {len(coords)} vertices")

# Get data and GMM fits
source_data = transformer.source_data_
target_data = transformer.target_data_
source_gmm = transformer.source_gmm_
target_gmm = transformer.target_gmm_

# Downsample for visualization
n_plot = min(50000, len(source_data), len(target_data))
np.random.seed(42)
source_plot_idx = np.random.choice(len(source_data), n_plot, replace=False)
target_plot_idx = np.random.choice(len(target_data), n_plot, replace=False)

source_plot = source_data[source_plot_idx]
target_plot = target_data[target_plot_idx]

# Create comprehensive diagnostic figure
fig, axes = plt.subplots(2, 2, figsize=(16, 14))


# Helper function to add gates to plot
def add_gates_to_plot(ax, gates_dict, alpha=0.5):
    colors = {"RBC": "red", "RET": "blue", "PLT": "green"}
    for pop_name, coords in gates_dict.items():
        if len(coords) > 0:
            poly = Polygon(
                coords,
                fill=False,
                edgecolor=colors.get(pop_name, "gray"),
                linewidth=2,
                alpha=alpha,
                label=f"{pop_name} gate",
            )
            ax.add_patch(poly)


# Helper function to add GMM components
def add_gmm_components(
    ax, gmm, color="red", marker="x", size=200, alpha=0.8, label="GMM"
):
    means = gmm.means_
    # Get component weights for sizing
    weights = gmm.weights_
    sizes = weights * size * 10  # Scale by weights
    ax.scatter(
        means[:, 0],
        means[:, 1],
        c=color,
        marker=marker,
        s=sizes,
        alpha=alpha,
        edgecolors="black",
        linewidths=1.5,
        label=label,
        zorder=10,
    )

    # Add component numbers
    for i, (mean, weight) in enumerate(zip(means, weights)):
        ax.annotate(
            f"{i}\n({weight:.2%})",
            mean,
            fontsize=6,
            ha="center",
            va="center",
            bbox=dict(
                boxstyle="round,pad=0.3", facecolor="white", alpha=0.7, edgecolor=color
            ),
        )


# ===== Plot 1: Source data with Source GMM =====
ax = axes[0, 0]
ax.scatter(
    source_plot[:, 0], source_plot[:, 1], s=1, alpha=0.1, c="blue", rasterized=True
)
add_gates_to_plot(ax, gates)
add_gmm_components(ax, source_gmm, color="red", label="Source GMM components")
ax.set_xlabel("SFL (Side Fluorescence)")
ax.set_ylabel("FSC (Forward Scatter)")
ax.set_title("Source Data + Source GMM Components + Manual Gates")
ax.set_xlim(0, 260)
ax.set_ylim(0, 260)
ax.grid(alpha=0.3)
ax.legend(loc="upper right", fontsize=8)

# ===== Plot 2: Target data with Target GMM =====
ax = axes[0, 1]
ax.scatter(
    target_plot[:, 0], target_plot[:, 1], s=1, alpha=0.1, c="orange", rasterized=True
)
add_gates_to_plot(ax, gates)
add_gmm_components(ax, target_gmm, color="darkred", label="Target GMM components")
ax.set_xlabel("SFL (Side Fluorescence)")
ax.set_ylabel("FSC (Forward Scatter)")
ax.set_title("Target Data + Target GMM Components + Manual Gates")
ax.set_xlim(0, 260)
ax.set_ylim(0, 260)
ax.grid(alpha=0.3)
ax.legend(loc="upper right", fontsize=8)

# ===== Plot 3: GMM component weight distribution =====
ax = axes[1, 0]
source_weights = source_gmm.weights_
target_weights = target_gmm.weights_
x = np.arange(len(source_weights))
width = 0.35
ax.bar(x - width / 2, source_weights, width, label="Source", alpha=0.7)
ax.bar(x + width / 2, target_weights, width, label="Target", alpha=0.7)
ax.set_xlabel("Component Index")
ax.set_ylabel("Weight")
ax.set_title("GMM Component Weights (showing collapse to few components)")
# Draw the threshold line before building the legend so its label is included
ax.axhline(y=0.01, color="r", linestyle="--", alpha=0.5, label="1% threshold")
ax.legend()
ax.grid(alpha=0.3, axis="y")

# ===== Plot 4: Population coverage analysis =====
ax = axes[1, 1]


# For each gate, count how many GMM components have their mean inside
def point_in_polygon(point, polygon):
    from matplotlib.path import Path

    path = Path(polygon)
    return path.contains_point(point)


def analyze_gate_coverage(gmm, gates_dict):
    coverage = {}
    means = gmm.means_
    weights = gmm.weights_

    for pop_name, coords in gates_dict.items():
        if len(coords) == 0:
            continue
        components_in_gate = []
        total_weight_in_gate = 0.0

        for i, (mean, weight) in enumerate(zip(means, weights)):
            if point_in_polygon(mean, coords):
                components_in_gate.append(i)
                total_weight_in_gate += weight

        coverage[pop_name] = {
            "n_components": len(components_in_gate),
            "total_weight": total_weight_in_gate,
            "components": components_in_gate,
        }

    return coverage


source_coverage = analyze_gate_coverage(source_gmm, gates)
target_coverage = analyze_gate_coverage(target_gmm, gates)

# Display coverage as text
text_str = "GMM Coverage of Manual Gate Regions:\n\n"
text_str += "SOURCE GMM:\n"
for pop, data in source_coverage.items():
    text_str += f"  {pop}: {data['n_components']} components, {data['total_weight']:.1%} weight\n"

text_str += "\nTARGET GMM:\n"
for pop, data in target_coverage.items():
    text_str += f"  {pop}: {data['n_components']} components, {data['total_weight']:.1%} weight\n"

text_str += (
    "\nWARNING: If RET/PLT gates have low weight,\nrare populations are being ignored!"
)

ax.text(
    0.1,
    0.5,
    text_str,
    transform=ax.transAxes,
    fontsize=11,
    verticalalignment="center",
    family="monospace",
    bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
)
ax.axis("off")

plt.tight_layout()
plt.savefig("../outputs/gmm_diagnostic.png", dpi=150, bbox_inches="tight")
plt.show()

print("\n✓ Diagnostic plot saved to ../outputs/gmm_diagnostic.png")
print("\nInterpretation:")
print("  - Top left: Source data with GMM means (sized by component weight)")
print("  - Top right: Target data with GMM means (sized by component weight)")
print("  - Bottom left: Component weight distribution (should be spread out)")
print("  - Bottom right: How many components cover each gated population")
print(
    "\nWARNING: If most components cluster in RBC region, rare populations (RET/PLT) are lost!"
)

In [None]:
# Same diagnostic as above, reformatted for publication (two compact panels)

# Load the manually-defined gates for RET channel
gate_file = "../../sysmexcbctools/transfer/flow_gates/json_gates/RET_gates.json"
with open(gate_file, "r") as f:
    gates = json.load(f)

print("Available gate populations:")
for pop_name, coords in gates.items():
    print(f"  - {pop_name}: {len(coords)} vertices")

# Get data and GMM fits
source_data = transformer.source_data_
target_data = transformer.target_data_
source_gmm = transformer.source_gmm_
target_gmm = transformer.target_gmm_

# Downsample for visualization
n_plot = min(50000, len(source_data), len(target_data))
np.random.seed(42)
source_plot_idx = np.random.choice(len(source_data), n_plot, replace=False)
target_plot_idx = np.random.choice(len(target_data), n_plot, replace=False)

source_plot = source_data[source_plot_idx]
target_plot = target_data[target_plot_idx]

# Create compact two-panel figure (source | target) for publication
fig, axes = plt.subplots(1, 2, figsize=(5, 2.5))


# Helper function to add gates to plot
def add_gates_to_plot(ax, gates_dict, alpha=0.5):
    colors = {"RBC": "red", "RET": "blue", "PLT": "green"}
    for pop_name, coords in gates_dict.items():
        if len(coords) > 0:
            poly = Polygon(
                coords,
                fill=False,
                edgecolor=colors.get(pop_name, "gray"),
                linewidth=2,
                alpha=alpha,
                label=f"{pop_name} gate",
            )
            ax.add_patch(poly)


# Helper function to add GMM components
def add_gmm_components(
    ax, gmm, color="red", marker="x", size=30, alpha=0.8, label="GMM"
):
    means = gmm.means_
    # Get component weights for sizing
    weights = gmm.weights_
    sizes = weights * size * 10  # Scale by weights
    ax.scatter(
        means[:, 0],
        means[:, 1],
        c=color,
        marker=marker,
        s=sizes,
        alpha=alpha,
        edgecolors="black",
        linewidths=1,
        label=label,
        zorder=10,
    )

    # Add component numbers
    # for i, (mean, weight) in enumerate(zip(means, weights)):
    #     ax.annotate(f'{i}\n({weight:.2%})', mean, fontsize=4, ha='center', va='center',
    #                bbox=dict(boxstyle='round,pad=0.1', facecolor='white', alpha=0.7, edgecolor=color))


# ===== Plot 1: Source data with Source GMM =====
# ax = axes[0, 0]
ax = axes[0]
ax.scatter(
    source_plot[:, 0], source_plot[:, 1], s=1, alpha=0.1, c="blue", rasterized=True
)
# add_gates_to_plot(ax, gates)
add_gmm_components(
    ax, source_gmm, color="red", label="Source GMM\ncomponents\n(size $\\propto$ weight)"
)
ax.set_xlabel("SFL")
ax.set_ylabel("FSC")
# ax.set_title('Source Data + Source GMM Components + Manual Gates')
ax.set_xlim(0, 260)
ax.set_ylim(0, 260)
ax.grid(alpha=0.3)
lg = ax.legend(loc="upper right", frameon=True, fancybox=False)
frame = lg.get_frame()
frame.set_edgecolor("black")
frame.set_linewidth(1.0)
frame.set_alpha(1.0)

# ===== Plot 2: Target data with Target GMM =====
# ax = axes[0, 1]
ax = axes[1]
ax.scatter(
    target_plot[:, 0], target_plot[:, 1], s=1, alpha=0.1, c="orange", rasterized=True
)
# add_gates_to_plot(ax, gates)
add_gmm_components(
    ax,
    target_gmm,
    color="darkred",
    label="Target GMM\ncomponents\n(size $\\propto$ weight)",
)
ax.set_xlabel("SFL")
ax.set_ylabel("FSC")
# ax.set_title('Target Data + Target GMM Components + Manual Gates')
ax.set_xlim(0, 260)
ax.set_ylim(0, 260)
ax.grid(alpha=0.3)
lg = ax.legend(loc="upper right", frameon=True, fancybox=False)
frame = lg.get_frame()
frame.set_edgecolor("black")
frame.set_linewidth(1.0)
frame.set_alpha(1.0)

plt.tight_layout()
# Note: the PNG path matches the full diagnostic cell above, so that file is overwritten
plt.savefig("../outputs/gmm_diagnostic.png", dpi=150, bbox_inches="tight")
plt.savefig("../outputs/gmm_diagnostic.pdf", dpi=400, bbox_inches="tight")
plt.show()

print(
    "\n✓ Diagnostic plot saved to ../outputs/gmm_diagnostic.png and ../outputs/gmm_diagnostic.pdf"
)
print("\nInterpretation:")
print("  - Left: Source data with GMM means (sized by component weight)")
print("  - Right: Target data with GMM means (sized by component weight)")
print(
    "\nWARNING: If most components cluster in RBC region, rare populations (RET/PLT) are lost!"
)

## 6. Working with Different Channels

The same workflow applies to other flow cytometry channels.