# Advanced GAM Usage with GAMCorrector

This notebook demonstrates advanced features and best practices for using `GAMCorrector`.

## Overview

Beyond basic covariate correction, `GAMCorrector` provides several advanced features for fine-tuning the correction process:

- **Custom spline specifications**: Control model flexibility per covariate
- **Reference conditions**: Correct to specific reference conditions
- **Outlier filtering**: Remove extreme values before fitting
- **Transformation options**: Handle different data distributions (log, logit)
- **Auto-detection**: Automatic handling of percentage/proportion data
- **Model diagnostics**: Understand what the GAM learned

In [None]:
import sys
import numpy as np
import pandas as pd
from pathlib import Path

# Add sysmexcorrect to path
sys.path.insert(0, "../../sysmexcbctools/correction")

from sysmexcbctools.correction.sysmexcorrect import GAMCorrector

# Standard plotting imports
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# PDF-compatible fonts
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style
import scienceplots

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

np.random.seed(42)

## 1. Load Data

We'll use `data_B.csv`. Features are named '0' through '31'.

**Note**: Feature '6' has a strong relationship with VenepunctureDelay, and feature '7' has a strong relationship with TimeIntoStudy, making them ideal for demonstrations.

In [None]:
# Load data
data_path = Path("../data/data_B.csv")
df = pd.read_csv(data_path)

# Identify feature columns (named '0' through '31')
feature_cols = [str(col) for col in range(32)]

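# Create some proportion columns to demonstrate transformations
# (simulated percentage data, e.g. reticulocyte percentage, derived from features '6' and '7').
# Min-max scaling alone would give exact 0s and 1s, where log and logit are infinite,
# so squeeze the scaled values into the open interval (eps, 1 - eps).
eps = 1e-3
for source, target in [("6", "RET_PCT"), ("7", "PLT_PCT")]:
    scaled = (df[source] - df[source].min()) / (df[source].max() - df[source].min())
    df[target] = eps + (1 - 2 * eps) * scaled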

print(f"Data shape: {df.shape}")
print(f"\nFeature columns: {feature_cols[:5]}... (32 total)")
print(f"\nKey features for demonstration:")
print(f"  - Feature '6': Strong relationship with VenepunctureDelay")
print(f"  - Feature '7': Strong relationship with TimeIntoStudy")
print(f"\nSample of percentage columns:")
print(df[["RET_PCT", "PLT_PCT"]].describe())

## 2. Custom Spline Specifications

The number of splines controls GAM flexibility:
- **More splines**: More flexible, can capture complex patterns (risk: overfitting)
- **Fewer splines**: Less flexible, smoother curves (risk: underfitting)

Let's compare different spline numbers using feature '6' (strong VenepunctureDelay relationship).
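To see why basis size matters, here is a minimal sketch that does not use `GAMCorrector`. It fits unpenalised piecewise-linear (degree-1 B-spline, "hat") bases of different sizes to a noisy sine curve with least squares. `GAMCorrector`'s own spline basis and smoothing penalty may differ, so this only illustrates the general flexibility trade-off. `hat_basis`, `true_curve` and `fit_and_score` are hypothetical helpers.

```python
import numpy as np


def hat_basis(x, n_basis, lo=0.0, hi=1.0):
    """Degree-1 B-spline ("hat") basis with n_basis evenly spaced knots on [lo, hi]."""
    knots = np.linspace(lo, hi, n_basis)
    width = knots[1] - knots[0]
    return np.clip(1.0 - np.abs(x[:, None] - knots[None, :]) / width, 0.0, None)


def true_curve(t):
    return np.sin(4 * np.pi * t)


def fit_and_score(n_basis, x, y, x_grid):
    """Unpenalised least-squares fit; RMSE against the noise-free curve on x_grid."""
    coef, *_ = np.linalg.lstsq(hat_basis(x, n_basis), y, rcond=None)
    pred = hat_basis(x_grid, n_basis) @ coef
    return float(np.sqrt(np.mean((pred - true_curve(x_grid)) ** 2)))


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, 500)
y = true_curve(x) + rng.normal(0.0, 0.2, x.size)
x_grid = np.linspace(0.0, 1.0, 1000)

errors = {n: fit_and_score(n, x, y, x_grid) for n in (5, 25, 100)}
for n, err in errors.items():
    print(f"{n:>3} basis functions: RMSE vs true curve = {err:.3f}")
```

With 5 basis functions, the knots all fall where the sine is zero, so the fit cannot follow the oscillations (underfitting). With 100, the error typically rises again because each coefficient is estimated from only a handful of noisy points.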

In [None]:
# Fit GAMs with different spline numbers
spline_configs = {
    "Few splines (5)": 5,
    "Medium splines (25)": 25,
    "Many splines (50)": 50,
}

correctors = {}
corrected_dfs = {}

for name, n_splines in spline_configs.items():
    print(f"\nFitting with {name}...")
    corrector = GAMCorrector(
        covariates=["VenepunctureDelay"],
        feature_columns=["6"],  # Feature with strong VenepunctureDelay relationship
        n_splines=n_splines,
        transformation="none",
        verbose=False,
    )
    corrector.fit(df)
    correctors[name] = corrector
    corrected_dfs[name] = corrector.transform(df)

print("\n✓ All models fitted")

In [None]:
# visualise partial dependence for different spline numbers
fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))

for ax, (name, corrector) in zip(axes, correctors.items()):
    pdep = corrector.get_partial_dependence(feature="6", covariate="VenepunctureDelay")

    ax.plot(
        pdep["covariate_value"],
        pdep["partial_dependence"],
        "b-",
        linewidth=2,
        label="Partial dependence",
    )
    ax.fill_between(
        pdep["covariate_value"],
        pdep["lower_ci"],
        pdep["upper_ci"],
        alpha=0.3,
        label="95% CI",
    )
    ax.set_xlabel("VenepunctureDelay")
    ax.set_ylabel("Effect on Feature 6")
    ax.set_title(name)
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("Observations:")
print("- Fewer splines → smoother curves (may miss fine details)")
print("- More splines → wigglier curves (may overfit noise)")
print("- Confidence intervals tend to widen with more splines (more parameters to estimate)")

### Custom Splines Per Covariate

You can also set the number of splines per covariate. To do this, pass a dictionary that maps each covariate name to its spline count.

In [None]:
# Custom splines per covariate
corrector_custom = GAMCorrector(
    covariates=["TimeIntoStudy", "VenepunctureDelay"],
    feature_columns=["6", "7"],
    n_splines={
        "TimeIntoStudy": 50,  # More splines for longer-term drift
        "VenepunctureDelay": 20,  # Fewer splines for sample age
    },
    transformation="none",
    verbose=True,
)
corrector_custom.fit(df)

print("\n✓ Custom spline configuration applied")

## 3. Reference Conditions

By default, `GAMCorrector` corrects each feature to its overall mean. Alternatively, you can specify a **reference condition** so that features are corrected to the mean of a specific sub-population.

Example: Correct all samples to look like samples with VenepunctureDelay ≤ 10 hours.
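Conceptually, a trend correction of this kind can be written as `corrected = y - f(x) + m`. Here `f` is the fitted covariate effect, and `m` is either the overall mean or the mean of the reference samples. The sketch below illustrates that idea with `numpy.polyfit` standing in for the GAM smooth. It is an assumption about the general approach, not `GAMCorrector`'s actual implementation. `correct_to_reference` is a hypothetical helper, and the delay/feature data are synthetic.

```python
import numpy as np


def correct_to_reference(y, x, reference_mask, degree=3):
    """Remove a fitted covariate trend from y and re-centre on a reference mean.

    np.polyfit is a stand-in for the GAM smooth (illustrative assumption).
    reference_mask selects the samples whose raw mean is the correction target.
    """
    coeffs = np.polyfit(x, y, degree)          # fitted trend f(x)
    residuals = y - np.polyval(coeffs, x)      # y - f(x); mean ~0 thanks to the intercept
    reference_mean = y[reference_mask].mean()  # m: target level after correction
    return residuals + reference_mean


rng = np.random.default_rng(42)
delay = rng.uniform(0, 48, 2000)                         # synthetic delays in hours
feature = 5.0 + 0.05 * delay + rng.normal(0, 0.3, 2000)  # feature drifts upward with delay

default = correct_to_reference(feature, delay, np.ones_like(delay, dtype=bool))
fresh = correct_to_reference(feature, delay, delay <= 10.0)

print(f"Overall raw mean:        {feature.mean():.3f} -> corrected mean {default.mean():.3f}")
print(f"Raw mean (delay <= 10h): {feature[delay <= 10.0].mean():.3f} -> corrected mean {fresh.mean():.3f}")
```

Both corrections remove the same trend. They differ only in the constant added back, which is why a reference condition shifts the whole corrected distribution rather than changing its shape.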

In [None]:
# Corrector WITHOUT reference condition (default)
corrector_default = GAMCorrector(
    covariates=["VenepunctureDelay"],
    feature_columns=["6", "7"],
    n_splines=25,
    transformation="none",
    verbose=False,
)
corrector_default.fit(df)
df_corrected_default = corrector_default.transform(df)

print("Fitted default corrector (corrects to overall mean)")

# Corrector WITH reference condition
corrector_ref = GAMCorrector(
    covariates=["VenepunctureDelay"],
    feature_columns=["6", "7"],
    reference_condition={
        "VenepunctureDelay": lambda x: x <= 10.0,  # Only use samples ≤ 10 hours for reference
    },
    n_splines=25,
    transformation="none",
    verbose=False,
)
corrector_ref.fit(df)
df_corrected_ref = corrector_ref.transform(df)

print("Fitted reference condition corrector (corrects to delay ≤ 10 hours)")

# Compare reference means
print(f"\nReference means for Feature 6:")
print(
    f"  Default (overall mean):        {corrector_default.feature_means_[('6', None)]:.4f}"
)
print(
    f"  With reference condition:      {corrector_ref.feature_means_[('6', None)]:.4f}"
)
print(
    f"  Actual mean (delay ≤ 10h):     {df[df['VenepunctureDelay'] <= 10.0]['6'].mean():.4f}"
)

In [None]:
# visualise the difference
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

# Default correction
axes[0].scatter(
    df["VenepunctureDelay"], df["6"], alpha=0.1, s=1, c="gray", label="Original"
)
axes[0].scatter(
    df["VenepunctureDelay"],
    df_corrected_default["6"],
    alpha=0.1,
    s=1,
    c="blue",
    label="Corrected",
)
axes[0].axhline(
    corrector_default.feature_means_[("6", None)],
    color="red",
    linestyle="--",
    linewidth=2,
    label="Reference mean",
)
axes[0].set_xlabel("VenepunctureDelay")
axes[0].set_ylabel("Feature 6")
axes[0].set_title("Default Correction\n(corrects to overall mean)")
axes[0].legend()
axes[0].grid(alpha=0.3)

# Reference condition correction
axes[1].scatter(
    df["VenepunctureDelay"], df["6"], alpha=0.1, s=1, c="gray", label="Original"
)
axes[1].scatter(
    df["VenepunctureDelay"],
    df_corrected_ref["6"],
    alpha=0.1,
    s=1,
    c="green",
    label="Corrected",
)
axes[1].axhline(
    corrector_ref.feature_means_[("6", None)],
    color="red",
    linestyle="--",
    linewidth=2,
    label=r"Reference mean ($\leq$ 10h)",
)
axes[1].axvline(
    10.0, color="orange", linestyle=":", linewidth=2, label="Reference cutoff"
)
axes[1].set_xlabel("VenepunctureDelay")
axes[1].set_ylabel("Feature 6")
axes[1].set_title("Reference Condition Correction\n(corrects to delay $\\leq$ 10 hours)")
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("The reference condition shifts the corrected data so it centres on the reference sub-population's mean.")

## 4. Outlier Filtering

Outliers can distort GAM fitting. The `centralise_threshold` parameter removes extreme values before fitting.

Values are kept if they're within `threshold × MAD` of the median, where MAD is the median absolute deviation.
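The filtering rule above can be sketched in a few lines of NumPy. It is not stated here whether `GAMCorrector` rescales the MAD by 1.4826, which makes it a consistent estimate of the standard deviation for normal data. The hypothetical `mad_filter_mask` helper therefore exposes that choice as a flag. It is an illustration, not the library's code.

```python
import numpy as np


def mad_filter_mask(values, threshold=3.5, scale_to_sigma=False):
    """Return True where a value lies within threshold * MAD of the median.

    NaNs are never kept (comparisons with NaN are False).
    scale_to_sigma=True multiplies the MAD by 1.4826 (an assumption to toggle).
    """
    values = np.asarray(values, dtype=float)
    median = np.nanmedian(values)
    mad = np.nanmedian(np.abs(values - median))
    if scale_to_sigma:
        mad *= 1.4826
    return np.abs(values - median) <= threshold * mad


rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, 10_000)
x[:100] += 10.0  # inject 100 gross outliers

for use_scaled in (False, True):
    mask = mad_filter_mask(x, threshold=3.5, scale_to_sigma=use_scaled)
    print(f"scale_to_sigma={use_scaled}: kept {mask.mean():.2%} of values, "
          f"{(~mask[:100]).sum()} of 100 injected outliers removed")
```

Both variants remove the injected outliers. However, the unscaled MAD also discards a noticeably larger share of ordinary values, which is why the fraction removed by a given threshold depends on this convention and on the data's distribution.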

In [None]:
# Create a feature with outliers
df_with_outliers = df.copy()
# Add some extreme outliers to feature '6'
outlier_indices = np.random.choice(len(df), size=100, replace=False)
df_with_outliers.loc[df_with_outliers.index[outlier_indices], "6"] += np.random.normal(10, 5, 100)

print(f"Added {len(outlier_indices)} outliers to feature '6'")
print(f"Original range: [{df['6'].min():.2f}, {df['6'].max():.2f}]")
print(
    f"With outliers range: [{df_with_outliers['6'].min():.2f}, {df_with_outliers['6'].max():.2f}]"
)

# Fit without outlier filtering
corrector_no_filter = GAMCorrector(
    covariates=["VenepunctureDelay"],
    feature_columns=["6"],
    centralise_threshold=None,  # No filtering
    n_splines=25,
    verbose=False,
)
corrector_no_filter.fit(df_with_outliers)

# Fit with outlier filtering
corrector_filtered = GAMCorrector(
    covariates=["VenepunctureDelay"],
    feature_columns=["6"],
    centralise_threshold=3.5,  # Keep values within 3.5 MADs
    n_splines=25,
    verbose=False,
)
corrector_filtered.fit(df_with_outliers)

print("\n✓ Fitted both models")

In [None]:
# Compare partial dependence
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

# Without filtering
pdep_no_filter = corrector_no_filter.get_partial_dependence(
    feature="6", covariate="VenepunctureDelay"
)
axes[0].plot(
    pdep_no_filter["covariate_value"],
    pdep_no_filter["partial_dependence"],
    "r-",
    linewidth=2,
    label="No outlier filtering",
)
axes[0].fill_between(
    pdep_no_filter["covariate_value"],
    pdep_no_filter["lower_ci"],
    pdep_no_filter["upper_ci"],
    alpha=0.3,
    color="red",
)
axes[0].set_xlabel("VenepunctureDelay")
axes[0].set_ylabel("Effect on Feature 6")
axes[0].set_title("Without Outlier Filtering\n(outliers distort the fit)")
axes[0].legend()
axes[0].grid(alpha=0.3)

# With filtering
pdep_filtered = corrector_filtered.get_partial_dependence(
    feature="6", covariate="VenepunctureDelay"
)
axes[1].plot(
    pdep_filtered["covariate_value"],
    pdep_filtered["partial_dependence"],
    "g-",
    linewidth=2,
    label="With outlier filtering",
)
axes[1].fill_between(
    pdep_filtered["covariate_value"],
    pdep_filtered["lower_ci"],
    pdep_filtered["upper_ci"],
    alpha=0.3,
    color="green",
)
axes[1].set_xlabel("VenepunctureDelay")
axes[1].set_ylabel("Effect on Feature 6")
axes[1].set_title("With Outlier Filtering (threshold=3.5 MAD)\n(more robust fit)")
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("Outlier filtering makes the GAM fit more robust to extreme values.")
print("Recommended threshold: 3.5 MAD (removes only extreme values; the exact fraction depends on the distribution)")

## 5. Transformation Options

`GAMCorrector` supports three transformation types:

1. **'none'**: No transformation (for approximately normally distributed data)
2. **'log'**: Natural log (for right-skewed, strictly positive data)
3. **'logit'**: Logit transform (for percentages/proportions strictly between 0 and 1)

Transformations can help stabilise variance and make relationships more linear.
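Here is a minimal sketch of why the logit scale keeps proportions in range. Suppose a covariate effect is removed additively on the logit scale and the result is back-transformed. The output is then always strictly between 0 and 1. By contrast, removing an additive shift on the raw scale can push small proportions below 0. The `shift` values are made up, and it is an assumption that `GAMCorrector` applies exactly this forward/back-transform. The clipping constant `eps` is likewise illustrative.

```python
import numpy as np


def logit(p, eps=1e-6):
    """Log-odds; clipping avoids infinities at exactly 0 or 1 (eps is illustrative)."""
    p = np.clip(p, eps, 1 - eps)
    return np.log(p) - np.log1p(-p)


def inv_logit(z):
    """Inverse of logit (the logistic function)."""
    return 1.0 / (1.0 + np.exp(-z))


p = np.array([0.02, 0.10, 0.50, 0.90, 0.98])
shift = np.array([1.5, -0.3, 1.0, -2.0, 0.5])  # made-up covariate effects on the logit scale

corrected = inv_logit(logit(p) - shift)  # remove effect on logit scale, then back-transform
raw_corrected = p - 0.05                 # naive additive shift on the raw scale

print("logit-scale correction:", np.round(corrected, 4))
print("raw-scale correction:  ", np.round(raw_corrected, 4))
```

The same reasoning applies to `'log'`: `np.exp(np.log(y) - shift)` is always positive for positive `y`.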

In [None]:
# Demonstrate auto-detection of percentage columns
corrector_auto = GAMCorrector(
    covariates=["VenepunctureDelay"],
    feature_columns=["6", "RET_PCT", "PLT_PCT"],  # Mix of regular and percentage
    transformation="none",  # Default for most columns
    auto_detect_percentages=True,  # Auto-apply logit to PCT columns
    n_splines=25,
    verbose=True,
)
corrector_auto.fit(df)

print(f"\nAuto-detected percentage columns: {corrector_auto.transformed_columns_}")
print("These columns were automatically logit-transformed.")

In [None]:
# Compare transformations for percentage data
transformations = {
    "No transform": "none",
    "Log transform": "log",
    "Logit transform": "logit",
}

fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))

for ax, (name, transform) in zip(axes, transformations.items()):
    corrector = GAMCorrector(
        covariates=["VenepunctureDelay"],
        feature_columns=["RET_PCT"],
        transformation=transform,
        auto_detect_percentages=False,  # Disable auto-detection
        n_splines=25,
        verbose=False,
    )
    corrector.fit(df)
    df_corrected = corrector.transform(df)

    # Plot distribution
    ax.hist(
        df_corrected["RET_PCT"], bins=50, alpha=0.6, color="blue", edgecolor="black"
    )
    ax.set_xlabel("RET_PCT (corrected)")
    ax.set_ylabel("Frequency")
    ax.set_title(name)
    ax.grid(alpha=0.3, axis="y")

plt.tight_layout()
plt.show()

print("For percentage data, logit transform is recommended.")
print("It keeps corrected values strictly inside the (0, 1) range.")

## 6. Model Diagnostics and Interpretation

Understanding what the GAM learned is crucial for validating the correction.

In [None]:
# Fit a model for diagnostic analysis
corrector_diag = GAMCorrector(
    covariates=["TimeIntoStudy", "VenepunctureDelay"],
    feature_columns=["6", "7", "8"],
    n_splines=25,
    transformation="none",
    verbose=False,
)
corrector_diag.fit(df)
df_corrected = corrector_diag.transform(df)

print("Fitted diagnostic model")

In [None]:
# Diagnostic 1: Residual correlations
print("=== Diagnostic 1: Residual Correlations ===")
print("\nPearson correlations should be near zero after correction (this checks linear association only):\n")

for feat in ["6", "7", "8"]:
    for cov in ["TimeIntoStudy", "VenepunctureDelay"]:
        corr_before = df[[feat, cov]].corr().iloc[0, 1]
        corr_after = df_corrected[[feat, cov]].corr().iloc[0, 1]

        print(f"Feature {feat} vs {cov}:")
        print(f"  Before: {corr_before:+.4f}")
        print(f"  After:  {corr_after:+.4f} ({'✓' if abs(corr_after) < 0.01 else '⚠'})")
        print()

In [None]:
# Diagnostic 2: Visual inspection of correction
print("=== Diagnostic 2: Visual Inspection ===")

fig, axes = plt.subplots(2, 3, figsize=(6.6, 4.5))

for i, feat in enumerate(["6", "7", "8"]):
    # Before correction
    axes[0, i].scatter(df["VenepunctureDelay"], df[feat], alpha=0.1, s=1, c="gray")

    # Add binned means
    bins = pd.cut(df["VenepunctureDelay"], bins=20)
    binned = df.groupby(bins, observed=True)[feat].mean()  # skip empty bins explicitly
    bin_centers = [interval.mid for interval in binned.index]
    axes[0, i].plot(bin_centers, binned.values, "r-", linewidth=3, label="Trend")

    axes[0, i].set_xlabel("VenepunctureDelay")
    axes[0, i].set_ylabel(f"Feature {feat}")
    axes[0, i].set_title(f"Feature {feat} Before Correction")
    axes[0, i].legend()
    axes[0, i].grid(alpha=0.3)

    # After correction
    axes[1, i].scatter(
        df["VenepunctureDelay"], df_corrected[feat], alpha=0.1, s=1, c="gray"
    )

    # Add binned means
    bins = pd.cut(df_corrected["VenepunctureDelay"], bins=20)
    binned = df_corrected.groupby(bins, observed=True)[feat].mean()  # skip empty bins explicitly
    bin_centers = [interval.mid for interval in binned.index]
    axes[1, i].plot(
        bin_centers, binned.values, "b-", linewidth=3, label="Trend (should be flat)"
    )

    axes[1, i].set_xlabel("VenepunctureDelay")
    axes[1, i].set_ylabel(f"Feature {feat}")
    axes[1, i].set_title(f"Feature {feat} After Correction")
    axes[1, i].legend()
    axes[1, i].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("After correction, the binned trend should be approximately flat.")
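One way to put a number on "approximately flat" is to compare the spread of the binned means with the feature's overall standard deviation. The hypothetical `binned_trend_range` helper below does this using quantile bins, so every bin holds roughly the same number of samples. It is a sketch, not part of `GAMCorrector`, and any cut-off for "flat enough" is a judgement call. The demo data are synthetic.

```python
import numpy as np
import pandas as pd


def binned_trend_range(x, y, n_bins=20):
    """(max - min) of binned means of y over quantile bins of x, in units of SD(y)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    bins = pd.qcut(x, q=n_bins, duplicates="drop")  # NaN x values get no bin
    binned_means = pd.Series(y).groupby(bins, observed=True).mean()
    return float((binned_means.max() - binned_means.min()) / np.nanstd(y))


rng = np.random.default_rng(1)
delay = rng.uniform(0, 10, 5000)
with_trend = 0.5 * delay + rng.normal(0, 1, 5000)
flat = rng.normal(0, 1, 5000)

print(f"With trend: {binned_trend_range(delay, with_trend):.2f}")
print(f"Flat:       {binned_trend_range(delay, flat):.2f}")
```

In this notebook, it could be called as `binned_trend_range(df["VenepunctureDelay"], df_corrected["6"])` and compared with the same call on `df["6"]`.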

## 7. Best Practices Summary

Here's a comprehensive example using best practices:

In [None]:
# Best practices example
corrector_best = GAMCorrector(
    # Covariates
    covariates=["TimeIntoStudy", "VenepunctureDelay"],
    # Features (let it auto-detect if None)
    feature_columns=feature_cols,
    # Group correction if applicable
    group_column="Machine",
    normalize_groups=True,  # Remove systematic group offsets
    reference_group=None,  # Normalize to overall mean
    # Transformation
    transformation="none",  # Default for most features
    auto_detect_percentages=True,  # Auto-apply logit to PCT columns
    # Splines (custom if needed)
    n_splines={
        "TimeIntoStudy": 30,  # More splines for long-term drift
        "VenepunctureDelay": 20,  # Fewer splines for sample age
    },
    # Outlier filtering (recommended)
    centralise_threshold=3.5,  # Remove extreme outliers
    # Reference condition (optional)
    # NOTE: Lambda functions cannot be pickled. If you need to save the model,
    # use a named function or don't use reference_condition when saving.
    reference_condition={
        "VenepunctureDelay": lambda x: x <= 10.0  # Correct to fresh samples
    },
    # Performance
    parallel=False,  # Set to True for large datasets
    n_jobs=-1,  # Use all CPUs when running in parallel
    verbose=True,
)

print("\nFitting best practices model...")
corrector_best.fit(df)

# Apply correction
df_corrected_best = corrector_best.transform(df)

# Save model
# Note: The lambda function in reference_condition will be excluded from the saved model
# with a warning. The model can still be used for prediction.
corrector_best.save("../notebooks/gam_corrector_best_practices.pkl")

print("\n✓ Best practices model fitted and saved")
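If the saved model should keep its reference condition, the condition must be picklable. There are two options:

- A function defined at module level pickles by reference. In a notebook, that function lives in `__main__`, so it must also be defined in the session that loads the model.
- A `functools.partial` over a function from the `operator` module pickles without any user-defined code.

The sketch below only checks picklability with the standard `pickle` module. It is an assumption that `GAMCorrector.save` uses `pickle` and would then retain such a condition. `is_fresh_sample` and `is_picklable` are hypothetical names.

```python
import functools
import operator
import pickle


def is_fresh_sample(delay):
    """Module-level reference condition: VenepunctureDelay <= 10 hours."""
    return delay <= 10.0


# operator.ge(10.0, x) evaluates 10.0 >= x, i.e. x <= 10.0
fresh_partial = functools.partial(operator.ge, 10.0)


def is_picklable(obj):
    """True if obj survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


for name, condition in [
    ("lambda", lambda x: x <= 10.0),
    ("module-level function", is_fresh_sample),
    ("functools.partial", fresh_partial),
]:
    print(f"{name:>22}: picklable={is_picklable(condition)}, f(8.0)={condition(8.0)}")
```

Either alternative could then be passed as `reference_condition={"VenepunctureDelay": is_fresh_sample}` in place of the lambda.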