# Single Covariate Correction with GAMCorrector

This notebook demonstrates how to use `GAMCorrector` to correct for a single spurious covariate in complete blood count (CBC) data.

## Overview

Blood test results can be affected by various technical factors that are not related to the patient's true biological state. Examples:
- **Sample age**: Time between blood draw and analysis
- **Time of day**: Circadian effects or instrument drift
- **Batch effects**: Different analysers or reagent batches

We can use GAMs (Generalized Additive Models) to model these non-linear relationships and attempt to remove their effects.
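
To make the idea concrete, here is a minimal sketch of smooth-effect removal using only NumPy. It is not `GAMCorrector`'s implementation (its internals are not shown in this notebook); it fits a penalised piecewise-linear spline, which behaves like a single smooth GAM term, and subtracts the fitted effect while keeping the feature's overall level. The helper `fit_smooth`, the synthetic data, and all parameter values are assumptions made for illustration.

```python
import numpy as np


def fit_smooth(x, y, n_knots=8, penalty=1.0):
    """Fit a penalised piecewise-linear spline y ~ f(x); return a callable f."""
    # Place interior knots at evenly spaced quantiles of x
    knots = np.quantile(x, np.linspace(0, 1, n_knots + 2)[1:-1])

    def basis(v):
        v = np.asarray(v, dtype=float)
        # Columns: intercept, linear term, one hinge (v - k)+ per knot
        hinges = np.clip(v[:, None] - knots[None, :], 0.0, None)
        return np.column_stack([np.ones_like(v), v, hinges])

    B = basis(x)
    # Ridge penalty on the hinge coefficients only (controls wiggliness)
    D = np.diag([0.0, 0.0] + [penalty] * n_knots)
    coef = np.linalg.solve(B.T @ B + D, B.T @ y)
    return lambda v: basis(v) @ coef


rng = np.random.default_rng(0)
delay = rng.uniform(0, 48, size=2000)            # hypothetical delay in hours
true_value = rng.normal(10.0, 1.0, size=2000)    # hypothetical biological signal
measured = true_value + 1.5 * np.sin(delay / 8)  # add a smooth spurious effect

f = fit_smooth(delay, measured)
effect = f(delay)
# Subtract the fitted effect but keep the overall level of the feature
corrected = measured - (effect - effect.mean())
```

Centring the fitted effect before subtracting it is a design choice: it removes the dependence on the covariate without shifting the feature's mean.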

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add the directory that contains the sysmexcbctools package to the path,
# so that the package-qualified import below resolves
sys.path.insert(0, "../..")

from sysmexcbctools.correction.sysmexcorrect import GAMCorrector

# Set plotting style
sns.set_theme(style="whitegrid", context="notebook")
plt.rcParams["figure.dpi"] = 100

np.random.seed(42)

## 1. Load Data

We'll use `data_B.csv`, which contains:
- 32 features (tabular, CBC-like data)
- 3 domain factors (Machine, vendelay_binned, studytime_binned)
- 2 continuous covariates (TimeIntoStudy, VenepunctureDelay)

The **VenepunctureDelay** covariate represents the time between blood draw and analysis - a common technical factor that can affect certain blood parameters.

In [None]:
# Load data
data_path = Path("../data/data_B.csv")
df = pd.read_csv(data_path)

print(f"Data shape: {df.shape}")
print(f"\nColumns: {df.columns.tolist()}")

# Display first few rows
df.head()

In [None]:
# Identify feature columns (named "0" to "31")
feature_cols = [str(col) for col in range(32) if str(col) in df.columns]
print(f"Found {len(feature_cols)} feature columns: {feature_cols[:5]}...")

# Check covariate
print("\nVenepunctureDelay statistics:")
print(df["VenepunctureDelay"].describe())

## 2. Visualise Covariate Effects

Before correction, let's visualise how VenepunctureDelay affects some features. In real CBC data, sample age typically affects white blood cell counts (WBC) and other parameters.
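
The binned-mean trend is computed several times in this notebook, so it can help to factor it into a small helper. The sketch below assumes a hypothetical function name `binned_mean` and a toy dataframe; it uses the same `pd.cut` approach with equal-width bins.

```python
import numpy as np
import pandas as pd


def binned_mean(df, feature, covariate, n_bins=20):
    """Return (bin_centres, means) of `feature` over equal-width covariate bins."""
    data = df[[feature, covariate]].dropna()
    bins = pd.cut(data[covariate], bins=n_bins)
    # observed=False keeps empty bins, whose mean is NaN
    means = data.groupby(bins, observed=False)[feature].mean()
    centres = np.array([interval.mid for interval in means.index])
    return centres, means.to_numpy()


# Toy frame with a hypothetical linear covariate effect
rng = np.random.default_rng(1)
toy = pd.DataFrame({"delay": rng.uniform(0, 10, 5000)})
toy["feat"] = 0.5 * toy["delay"] + rng.normal(0, 1, 5000)

centres, means = binned_mean(toy, "feat", "delay")
```

The returned arrays can be passed directly to `ax.plot(centres, means)`.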

In [None]:
# Sample a few features to visualise
sample_features = feature_cols[6:12]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, feat in enumerate(sample_features):
    # Keep only rows where both the feature and the covariate are present
    df_plot = df[[feat, "VenepunctureDelay"]].dropna()

    # Scatter plot with trend
    axes[i].scatter(
        df_plot["VenepunctureDelay"], df_plot[feat], alpha=0.1, s=1, c="gray"
    )

    # Add binned means to show trend
    bins = pd.cut(df_plot["VenepunctureDelay"], bins=20)
    binned = df_plot.groupby(bins, observed=False)[feat].mean()
    bin_centers = [interval.mid for interval in binned.index]
    axes[i].plot(bin_centers, binned.values, "r-", linewidth=2, label="Binned mean")

    axes[i].set_xlabel("Venepuncture Delay")
    axes[i].set_ylabel(feat)
    axes[i].set_title(f"{feat} vs Venepuncture Delay")
    axes[i].legend()

plt.tight_layout()
plt.show()

print("Note: The red line shows the average feature value at different delay times.")
print("Any trend in this line indicates a covariate effect that should be corrected.")

## 3. Fit GAMCorrector

Now we'll use `GAMCorrector` to model and remove the effect of VenepunctureDelay on all features.

In [None]:
# Create GAMCorrector
corrector = GAMCorrector(
    covariates=["VenepunctureDelay"],  # Single covariate to correct for
    feature_columns=feature_cols,  # All CBC features
    transformation="none",  # No transformation (data appears normalized)
    n_splines=5,  # Number of spline basis functions
    verbose=True,
)

print("Fitting GAM models...")
corrector.fit(df)

print(f"\nFitted {len(corrector.gam_models_)} GAM models successfully!")

## 4. Apply Correction

Now we'll apply the fitted GAM models to correct the data.
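
A shape check alone does not show which columns were modified. Whether `transform` leaves the non-feature columns untouched is not stated in this notebook, so a check along the following lines could confirm it. The helper `changed_columns` and the toy frames are hypothetical.

```python
import pandas as pd


def changed_columns(before, after):
    """List columns whose values differ between two same-shaped frames."""
    return [
        col
        for col in before.columns
        # Series.equals treats NaNs in the same position as equal
        if not before[col].equals(after[col])
    ]


# Toy example: only column "0" is modified
before = pd.DataFrame({"0": [1.0, 2.0], "Machine": ["A", "B"]})
after = before.copy()
after["0"] = after["0"] - 0.1

print(changed_columns(before, after))
```

Applied to the real data, one would expect `changed_columns(df, df_corrected)` to be a subset of `feature_cols`.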

In [None]:
# Transform data
df_corrected = corrector.transform(df)

print("Correction applied successfully!")
print(f"\nOriginal data shape: {df.shape}")
print(f"Corrected data shape: {df_corrected.shape}")

# Check that correction preserves data structure
assert df.shape == df_corrected.shape, "Shape mismatch!"
print("\n✓ Data structure preserved")

## 5. Visualise Correction Effect

Let's visualise the same features after correction to see if the covariate effect has been removed.

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, feat in enumerate(sample_features):
    df_plot = df_corrected[[feat, "VenepunctureDelay"]].dropna()

    # Scatter plot
    axes[i].scatter(
        df_plot["VenepunctureDelay"], df_plot[feat], alpha=0.1, s=1, c="gray"
    )

    # Add binned means
    bins = pd.cut(df_plot["VenepunctureDelay"], bins=20)
    binned = df_plot.groupby(bins, observed=False)[feat].mean()
    bin_centers = [interval.mid for interval in binned.index]
    axes[i].plot(
        bin_centers, binned.values, "b-", linewidth=2, label="Binned mean (corrected)"
    )

    axes[i].set_xlabel("Venepuncture Delay")
    axes[i].set_ylabel(feat)
    axes[i].set_title(f"{feat} vs Venepuncture Delay (After Correction)")
    axes[i].legend()

plt.tight_layout()
plt.show()

print("After correction, the binned mean should be approximately flat,")
print("indicating that the covariate effect has been removed.")

## 6. Compare Before and After

Let's create side-by-side comparisons for a single feature.

In [None]:
# Select one feature for detailed comparison
example_feature = "6"

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Before correction
df_plot_before = df[[example_feature, "VenepunctureDelay"]].dropna()
axes[0].scatter(
    df_plot_before["VenepunctureDelay"],
    df_plot_before[example_feature],
    alpha=0.1,
    s=1,
    c="gray",
)
bins = pd.cut(df_plot_before["VenepunctureDelay"], bins=20)
binned_before = df_plot_before.groupby(bins, observed=False)[example_feature].mean()
bin_centers = [interval.mid for interval in binned_before.index]
axes[0].plot(
    bin_centers, binned_before.values, "r-", linewidth=3, label="Before correction"
)
axes[0].set_xlabel("Venepuncture Delay")
axes[0].set_ylabel(example_feature)
axes[0].set_title("Before Correction")
axes[0].legend()

# After correction
df_plot_after = df_corrected[[example_feature, "VenepunctureDelay"]].dropna()
axes[1].scatter(
    df_plot_after["VenepunctureDelay"],
    df_plot_after[example_feature],
    alpha=0.1,
    s=1,
    c="gray",
)
bins = pd.cut(df_plot_after["VenepunctureDelay"], bins=20)
binned_after = df_plot_after.groupby(bins, observed=False)[example_feature].mean()
bin_centers = [interval.mid for interval in binned_after.index]
axes[1].plot(
    bin_centers, binned_after.values, "b-", linewidth=3, label="After correction"
)
axes[1].set_xlabel("Venepuncture Delay")
axes[1].set_ylabel(example_feature)
axes[1].set_title("After Correction")
axes[1].legend()

plt.tight_layout()
plt.show()

## 7. Quantify Correction Quality

We can quantify how well the correction worked by measuring the absolute Pearson correlation between each feature and the covariate before and after correction.
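
Pearson correlation only measures linear association, so a non-monotonic covariate effect can leave it close to zero. As a complementary check, the sketch below computes a correlation ratio (eta squared): the fraction of a feature's variance explained by the binned covariate. The function name `correlation_ratio`, the bin count, and the synthetic data are assumptions for illustration.

```python
import numpy as np
import pandas as pd


def correlation_ratio(values, covariate, n_bins=20):
    """Fraction of variance in `values` explained by the binned covariate."""
    data = pd.DataFrame({"v": values, "c": covariate}).dropna()
    bins = pd.cut(data["c"], bins=n_bins)
    grouped = data.groupby(bins, observed=True)["v"]
    # Between-bin sum of squares divided by total sum of squares
    grand_mean = data["v"].mean()
    between = (grouped.count() * (grouped.mean() - grand_mean) ** 2).sum()
    total = ((data["v"] - grand_mean) ** 2).sum()
    return between / total


rng = np.random.default_rng(2)
x = rng.uniform(-1, 1, 5000)
y = x**2 + rng.normal(0, 0.1, 5000)  # symmetric, non-monotonic dependence

pearson = np.corrcoef(x, y)[0, 1]
eta_sq = correlation_ratio(y, x)
```

In this toy case the Pearson correlation is near zero while the correlation ratio is large, which is the situation a linear metric would miss.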

In [None]:
# Calculate correlations before and after
correlations_before = []
correlations_after = []

for feat in feature_cols:
    corr_before = df[[feat, "VenepunctureDelay"]].corr().iloc[0, 1]
    corr_after = df_corrected[[feat, "VenepunctureDelay"]].corr().iloc[0, 1]
    correlations_before.append(abs(corr_before))
    correlations_after.append(abs(corr_after))

# Create comparison dataframe
corr_df = pd.DataFrame(
    {
        "Feature": feature_cols,
        "Before": correlations_before,
        "After": correlations_after,
        "Reduction": np.array(correlations_before) - np.array(correlations_after),
    }
)

print("Top 10 features with largest correlation reduction:")
print(corr_df.nlargest(10, "Reduction")[["Feature", "Before", "After", "Reduction"]])

# Plot correlation comparison
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(feature_cols))
width = 0.35

ax.bar(
    x - width / 2, correlations_before, width, label="Before", alpha=0.8, color="red"
)
ax.bar(x + width / 2, correlations_after, width, label="After", alpha=0.8, color="blue")

ax.set_xlabel("Features")
ax.set_ylabel("Absolute Correlation with VenepunctureDelay")
ax.set_title("Feature-Covariate Correlation: Before vs After Correction")
ax.set_xticks(x)
ax.set_xticklabels(feature_cols, rotation=90)
ax.legend()
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nMean absolute correlation before: {np.mean(correlations_before):.4f}")
print(f"Mean absolute correlation after: {np.mean(correlations_after):.4f}")
print(
    f"Average reduction: {np.mean(correlations_before) - np.mean(correlations_after):.4f}"
)

## 8. Inspect Partial Dependence

We can visualise the learned GAM relationship using partial dependence plots.

In [None]:
# Get partial dependence for a few features
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, feat in enumerate(sample_features):
    try:
        pdep_df = corrector.get_partial_dependence(
            feature=feat, covariate="VenepunctureDelay"
        )

        axes[i].plot(
            pdep_df["covariate_value"],
            pdep_df["partial_dependence"],
            "b-",
            linewidth=2,
            label="Partial dependence",
        )
        axes[i].fill_between(
            pdep_df["covariate_value"],
            pdep_df["lower_ci"],
            pdep_df["upper_ci"],
            alpha=0.3,
            label="95% CI",
        )
        axes[i].set_xlabel("VenepunctureDelay")
        axes[i].set_ylabel("Effect on " + feat)
        axes[i].set_title(f"Learned GAM for {feat}")
        axes[i].legend()
        axes[i].grid(alpha=0.3)
    except Exception as e:
        axes[i].text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")

plt.tight_layout()
plt.show()

print("These plots show the non-linear relationship that the GAM learned.")
print("The shaded area represents the 95% confidence interval.")

## 9. Save and Load Model

Once fitted, you can save the GAMCorrector and reuse it on new data.
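
Spline-based models generally extrapolate poorly, and how `GAMCorrector` behaves outside the covariate range seen during fitting is not documented here. Before applying a saved model to new data, a quick range check such as the one sketched below may be worthwhile. The helper `fraction_outside_range` and the example values are hypothetical.

```python
import numpy as np


def fraction_outside_range(train_values, new_values):
    """Share of non-missing new covariate values outside the training range."""
    train = np.asarray(train_values, dtype=float)
    new = np.asarray(new_values, dtype=float)
    new = new[~np.isnan(new)]  # ignore missing values in the new data
    lo, hi = np.nanmin(train), np.nanmax(train)
    if new.size == 0:
        return 0.0
    return float(np.mean((new < lo) | (new > hi)))


train_delay = np.array([0.5, 2.0, 6.0, 12.0, 24.0])  # hypothetical hours
new_delay = np.array([1.0, 30.0, np.nan, 48.0])

print(fraction_outside_range(train_delay, new_delay))
```

A large fraction suggests the new samples were collected under conditions the fitted model has not seen, so the corrected values there deserve extra scrutiny.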

In [None]:
# Save model
model_path = "../notebooks/gam_corrector_single.pkl"
corrector.save(model_path)
print(f"Model saved to {model_path}")

# Load model
corrector_loaded = GAMCorrector.load(model_path)
print("Model loaded successfully")

# Verify loaded model works
df_test = df.sample(n=100, random_state=42)
df_test_corrected = corrector_loaded.transform(df_test)
print(f"Applied loaded model to {len(df_test)} test samples")

# Verify results match
df_original_corrected = corrector.transform(df_test)
difference = (
    (df_test_corrected[feature_cols] - df_original_corrected[feature_cols])
    .abs()
    .max()
    .max()
)
print(f"Maximum difference between original and loaded model: {difference:.2e}")
assert difference < 1e-10, "Models don't match!"
print("✓ Loaded model produces identical results")