# Multi-Covariate Correction with GAMCorrector

This notebook demonstrates how to correct for **multiple spurious covariates simultaneously** in CBC data.

## Overview

In real-world scenarios, blood test results are often affected by multiple technical factors, such as:
- **Sample age** (time between blood draw and analysis)
- **Time into study** (temporal drift in equipment or reagents)
- **Machine/batch effects** (different analyzers or reagent batches)
- **Time of day**, **day of week**, etc.

A generalised additive model (GAM) can capture the (possibly non-linear) effects of several covariates at once and remove their combined effect from each feature.

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
# Axes3D is never referenced by name in this notebook; presumably it is imported
# so that the "3d" projection requested via subplot_kw in the 3D scatter cell is
# registered (only needed on older Matplotlib releases) — TODO confirm before removing.
from mpl_toolkits.mplot3d import Axes3D

# PDF-compatible fonts: fonttype 42 embeds TrueType fonts in PDF/PS output
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style
# scienceplots is imported only for its side effect: it is not referenced by
# name, but the "science" and "nature" styles used below presumably come from it.
import scienceplots

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
# NOTE(review): seaborn_colors is defined here but the plotting cells below use
# hard-coded colour names ("red", "orange", ...) instead — confirm whether the
# palette was meant to be used.
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

# Add sysmexcorrect to path
# NOTE(review): the import below uses the fully qualified package path
# `sysmexcbctools.correction.sysmexcorrect`, which resolves from the directory
# that CONTAINS `sysmexcbctools` (or from an installed package), not from the
# `.../sysmexcbctools/correction` directory inserted here — confirm this path
# entry is what actually makes the import work.
sys.path.insert(0, "../../sysmexcbctools/correction")

from sysmexcbctools.correction.sysmexcorrect import GAMCorrector

# Seed NumPy's global RNG so any NumPy-based randomness in later cells is repeatable
np.random.seed(42)

## 1. Load Data

We'll use `data_B.csv`, which contains:
- **2 continuous covariates**: `TimeIntoStudy` and `VenepunctureDelay`
- **1 categorical variable**: `Machine` (can be used for group-specific correction)

In [None]:
# Load data
data_path = Path("../data/data_B.csv")
df = pd.read_csv(data_path)

print(f"Data shape: {df.shape}")

# Fail fast if the columns every later cell relies on are absent, instead of
# hitting an opaque KeyError halfway through the notebook.
required_cols = ["TimeIntoStudy", "VenepunctureDelay", "Machine"]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
    raise KeyError(f"{data_path} is missing required columns: {missing_cols}")

# Identify feature columns: those named "0" .. "31" that are present in the file
feature_cols = [str(idx) for idx in range(32) if str(idx) in df.columns]
if not feature_cols:
    # Later cells index feature_cols (e.g. feature_cols[0]) and fit one model
    # per feature, so an empty list would only fail later and less clearly.
    raise ValueError(f'{data_path} contains no feature columns named "0".."31"')
print(f"\nFeature columns: {len(feature_cols)}")

# Check covariates
print("\nCovariate statistics:")
print(df[["TimeIntoStudy", "VenepunctureDelay"]].describe())

# Check machine distribution
print("\nMachine distribution:")
print(df["Machine"].value_counts())

## 2. Visualise Multi-Covariate Effects

Let's visualise how features are affected by both covariates simultaneously.

In [None]:
# Four features (positions 6-9 of feature_cols), one per panel of the 2x2 grid
sample_features = feature_cols[6:10]

fig, axes = plt.subplots(2, 2, figsize=(4.5, 4.5), subplot_kw={"projection": "3d"})
axes = axes.flatten()

for ax, feat in zip(axes, sample_features):
    # Keep only rows where the feature and both covariates are present
    points = df[[feat, "TimeIntoStudy", "VenepunctureDelay"]].dropna()
    feature_values = points[feat]

    # 3D scatter: covariates on x/y, feature on z (and encoded as colour)
    scatter = ax.scatter(
        points["TimeIntoStudy"],
        points["VenepunctureDelay"],
        feature_values,
        c=feature_values,
        cmap="viridis",
        alpha=0.3,
        s=1,
    )

    ax.set_xlabel("Time Into Study")
    ax.set_ylabel("Venepuncture Delay")
    ax.set_zlabel(feat)
    ax.set_title(f"{feat} vs Both Covariates")
    fig.colorbar(scatter, ax=ax, shrink=0.5)

plt.tight_layout()
plt.show()

for message in (
    "These 3D plots show how features depend on BOTH covariates.",
    "Patterns in these plots indicate covariate effects that need correction.",
):
    print(message)

## 3. Single vs. Multi-Covariate Correction

First, let's compare correcting for one covariate vs. both covariates.

In [None]:
# Settings that are identical for both correctors in this cell
shared_settings = dict(
    feature_columns=feature_cols,
    transformation="none",
    verbose=True,
)

# Corrector 1: models VenepunctureDelay only
print("Fitting single-covariate GAM (VenepunctureDelay only)...")
corrector_single = GAMCorrector(
    covariates=["VenepunctureDelay"],
    n_splines=25,
    **shared_settings,
)
corrector_single.fit(df)
df_corrected_single = corrector_single.transform(df)

separator = "=" * 60
print("\n" + separator + "\n")

# Corrector 2: models both covariates, with a spline count chosen per covariate
print("Fitting multi-covariate GAM (TimeIntoStudy + VenepunctureDelay)...")
splines_per_covariate = {
    "TimeIntoStudy": 30,
    "VenepunctureDelay": 25,
}
corrector_multi = GAMCorrector(
    covariates=["TimeIntoStudy", "VenepunctureDelay"],
    n_splines=splines_per_covariate,
    **shared_settings,
)
corrector_multi.fit(df)
df_corrected_multi = corrector_multi.transform(df)

## 4. Compare Correction Quality

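Calculate how well each approach removes covariate effects. The measure used is the absolute Pearson correlation between each feature and each covariate (lower is better); note that it only captures residual *linear* association.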

In [None]:
# Absolute feature-covariate correlation for every (covariate, dataset) pair.
# Dict insertion order fixes the column order of corr_df:
# <covariate>_Original, <covariate>_Single, <covariate>_Multi for each covariate.
datasets = {
    "Original": df,
    "Single": df_corrected_single,
    "Multi": df_corrected_multi,
}

results = []
for feat in feature_cols:
    row = {"Feature": feat}
    for covariate in ("TimeIntoStudy", "VenepunctureDelay"):
        for dataset_name, frame in datasets.items():
            # Off-diagonal entry of the 2x2 correlation matrix
            pair_corr = frame[[feat, covariate]].corr().iloc[0, 1]
            row[f"{covariate}_{dataset_name}"] = abs(pair_corr)
    results.append(row)

corr_df = pd.DataFrame(results)

# Summary statistics
print("Average absolute correlations:")
summary_layout = (
    ("TimeIntoStudy", " (not corrected for)"),  # single-covariate model ignores this one
    ("VenepunctureDelay", ""),
)
for covariate, single_note in summary_layout:
    mean_original = corr_df[covariate + "_Original"].mean()
    mean_single = corr_df[covariate + "_Single"].mean()
    mean_multi = corr_df[covariate + "_Multi"].mean()

    print(f"\n{covariate}:")
    print(f"  Original:            {mean_original:.4f}")
    print(f"  Single-covariate:    {mean_single:.4f}{single_note}")
    print(f"  Multi-covariate:     {mean_multi:.4f}")

In [None]:
# Grouped bar charts: one panel per covariate, three bars per feature
fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

x = np.arange(len(feature_cols))
width = 0.25

# (covariate, legend text for the single-covariate bars) per panel.
# The single-covariate model never saw TimeIntoStudy, hence the different label.
panel_specs = (
    ("TimeIntoStudy", "Single-cov (not corrected)"),
    ("VenepunctureDelay", "Single-cov"),
)

for ax, (covariate, single_label) in zip(axes, panel_specs):
    # (bar offset in units of width, column suffix, legend label, colour)
    bar_specs = (
        (-1, "Original", "Original", "red"),
        (0, "Single", single_label, "orange"),
        (1, "Multi", "Multi-cov", "blue"),
    )
    for offset, suffix, legend_label, colour in bar_specs:
        ax.bar(
            x + offset * width,
            corr_df[f"{covariate}_{suffix}"],
            width,
            label=legend_label,
            alpha=0.8,
            color=colour,
        )

    ax.set_xlabel("Features")
    ax.set_ylabel("Absolute Correlation")
    ax.set_title(f"{covariate} Correlation")
    ax.set_xticks(x)
    ax.set_xticklabels(feature_cols, rotation=90)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

for message in (
    "Key observation: Multi-covariate correction removes effects from BOTH covariates,",
    "while single-covariate correction only handles one.",
):
    print(message)

## 5. Group-Specific Correction

Now let's demonstrate fitting **separate GAM models per machine**. This is useful when different instruments have different covariate relationships.

We'll compare two approaches:
1. **Grouped correction only**: Models covariates separately per machine
2. **Grouped + normalization**: Models covariates separately per machine AND removes systematic machine offsets

In [None]:
def _grouped_settings():
    """Return a fresh copy of the settings shared by both per-machine correctors."""
    return dict(
        covariates=["TimeIntoStudy", "VenepunctureDelay"],
        feature_columns=feature_cols,
        group_column="Machine",  # Fit separate models per machine
        transformation="none",
        n_splines=25,
        verbose=True,
    )


# Corrector 1: per-machine models, machine offsets left untouched
print("Fitting machine-specific GAMs (without group normalization)...")
corrector_grouped = GAMCorrector(normalize_groups=False, **_grouped_settings())
corrector_grouped.fit(df)
df_corrected_grouped = corrector_grouped.transform(df)

n_features = len(feature_cols)
n_machines = df["Machine"].nunique()
print(f"\nFitted {len(corrector_grouped.gam_models_)} GAM models")
print(
    f"(Expected: {n_features} features × {n_machines} machines = {n_features * n_machines})"
)

print("\n" + "=" * 60 + "\n")

# Corrector 2: per-machine models plus group normalization
print("Fitting machine-specific GAMs (WITH group normalization)...")
corrector_normalized = GAMCorrector(
    normalize_groups=True,  # Enable group normalization
    reference_group=None,  # normalise to overall mean (can also specify a reference machine)
    **_grouped_settings(),
)
corrector_normalized.fit(df)
df_corrected_normalized = corrector_normalized.transform(df)

n_offsets = len(corrector_normalized.group_offsets_)
print(f"\nGroup offsets calculated: {n_offsets}")

### 5.1 Visualise Machine Offsets Before and After Normalization

Let's examine the systematic differences between machines before and after normalization.

In [None]:
# Per-machine mean of one example feature under each correction strategy
example_feature = feature_cols[0]

means_original, means_grouped, means_normalized = (
    frame.groupby("Machine")[example_feature].mean().sort_index()
    for frame in (df, df_corrected_grouped, df_corrected_normalized)
)

# One panel per strategy
fig, axes = plt.subplots(1, 3, figsize=(6.6, 2.2))

machines = means_original.index
x = np.arange(len(machines))
machine_tick_labels = [f"M{int(m)}" for m in machines]

# (per-machine means, bar colour, panel title)
panel_specs = (
    (means_original, "red", "Original Data"),
    (
        means_grouped,
        "orange",
        "After Grouped GAM\n(covariates corrected, but machine offsets remain)",
    ),
    (
        means_normalized,
        "green",
        "After Grouped GAM + Normalization\n(covariates AND machine offsets removed)",
    ),
)

for ax, (machine_means, colour, title) in zip(axes, panel_specs):
    ax.bar(x, machine_means.values, alpha=0.8, color=colour)
    # Dashed reference line at the mean of the per-machine means
    ax.axhline(
        machine_means.mean(),
        color="black",
        linestyle="--",
        linewidth=2,
        label="Overall mean",
    )
    ax.set_xlabel("Machine")
    ax.set_ylabel(f"{example_feature} Mean")
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(machine_tick_labels)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

# Spread of the per-machine means: smaller variance means smaller machine offsets
spread_original = means_original.var()
spread_grouped = means_grouped.var()
spread_normalized = means_normalized.var()

print(f"Machine mean variance for {example_feature}:")
print(f"  Original:           {spread_original:.6f}")
print(f"  Grouped (no norm):  {spread_grouped:.6f}")
print(f"  normalised:         {spread_normalized:.6f}")
print(
    f"\nNormalization reduces machine mean variance by {(1 - spread_normalized/spread_grouped)*100:.1f}%"
)

### 5.2 Compare All Features

Let's quantify the machine offset reduction across all features.

In [None]:
# Between-machine variability for every feature: the variance of the
# per-machine means, computed for each correction strategy.
variance_results = []

for feat in feature_cols:
    # Variance of machine means (measure of between-machine variability)
    var_original = df.groupby("Machine")[feat].mean().var()
    var_grouped = df_corrected_grouped.groupby("Machine")[feat].mean().var()
    var_normalized = df_corrected_normalized.groupby("Machine")[feat].mean().var()

    variance_results.append(
        {
            "Feature": feat,
            "Original": var_original,
            "Grouped": var_grouped,
            "normalised": var_normalized,
        }
    )

var_df = pd.DataFrame(variance_results)

# Plot
fig, ax = plt.subplots(figsize=(4.5, 2.2))
x = np.arange(len(feature_cols))
width = 0.25

ax.bar(x - width, var_df["Original"], width, label="Original", alpha=0.8, color="red")
ax.bar(
    x, var_df["Grouped"], width, label="Grouped (no norm)", alpha=0.8, color="orange"
)
ax.bar(
    x + width, var_df["normalised"], width, label="normalised", alpha=0.8, color="green"
)

ax.set_xlabel("Features")
ax.set_ylabel("Variance of Machine Means")
ax.set_title("Between-Machine Variability: Impact of Normalization")
ax.set_xticks(x)
ax.set_xticklabels(feature_cols, rotation=90)
ax.legend()
ax.grid(axis="y", alpha=0.3)
ax.set_yscale("log")  # Log scale to see differences more clearly

plt.tight_layout()
plt.show()

# Only the plot axis is logarithmic; the figures printed here are plain
# arithmetic means, so the header must not claim "log scale".
mean_var_original = var_df["Original"].mean()
mean_var_grouped = var_df["Grouped"].mean()
mean_var_normalised = var_df["normalised"].mean()

print("Average variance of machine means:")
print(f"  Original:          {mean_var_original:.6f}")
print(f"  Grouped (no norm): {mean_var_grouped:.6f}")
print(f"  normalised:        {mean_var_normalised:.6f}")

# The relative reduction is undefined when the grouped variance is zero or NaN
# (e.g. a single machine), so report that instead of printing nan/inf.
if mean_var_grouped > 0:
    reduction_pct = (1 - mean_var_normalised / mean_var_grouped) * 100
    print(f"\nAverage reduction: {reduction_pct:.1f}%")
else:
    print("\nAverage reduction: n/a (grouped variance is zero or undefined)")

In [None]:
# Residual |correlation| with VenepunctureDelay for the sample features,
# compared across the pooled, grouped and grouped+normalized corrections
print("Comparing covariate correction quality:\n")

corrected_by_strategy = {
    "multi": df_corrected_multi,
    "grouped": df_corrected_grouped,
    "normalized": df_corrected_normalized,
}

for feat in sample_features:
    residual = {
        strategy: abs(frame[[feat, "VenepunctureDelay"]].corr().iloc[0, 1])
        for strategy, frame in corrected_by_strategy.items()
    }

    print(f"{feat}:")
    print(f"  Multi-covariate (single GAM):     {residual['multi']:.4f}")
    print(f"  Grouped (no norm):                {residual['grouped']:.4f}")
    print(f"  Grouped + normalized:             {residual['normalized']:.4f}")
    print()

for message in (
    "Note: Normalization doesn't affect covariate correlations,",
    "it only removes systematic offsets between machines.",
):
    print(message)

## 6. Visualise Machine-Specific Effects

Let's see if the covariate relationships differ by machine.

In [None]:
# Feature vs. VenepunctureDelay, coloured by machine, before and after the
# machine-specific correction
example_feature = "1"

fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

# Machine order is taken from the original frame for both panels
machine_ids = df["Machine"].unique()

# (data shown in the panel, panel title)
panel_specs = (
    (df, "Before Correction (by Machine)"),
    (df_corrected_grouped, "After Machine-Specific Correction"),
)

for ax, (frame, title) in zip(axes, panel_specs):
    for machine in machine_ids:
        df_machine = frame[frame["Machine"] == machine]
        ax.scatter(
            df_machine["VenepunctureDelay"],
            df_machine[example_feature],
            alpha=0.3,
            s=10,
            label=f"Machine {machine}",
        )
    ax.set_xlabel("Venepuncture Delay")
    ax.set_ylabel(example_feature)
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

for message in (
    "Machine-specific correction accounts for different covariate relationships",
    "across different instruments/batches.",
):
    print(message)

## 7. Inspect Learned Multi-Covariate Relationships

Visualise the partial dependence for each covariate in the multi-covariate model.

In [None]:
# Partial dependence of one feature on each covariate of the multi-covariate model
example_feature = "6"

fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

# (covariate column name, human-readable x-axis label)
panel_specs = (
    ("TimeIntoStudy", "Time Into Study"),
    ("VenepunctureDelay", "Venepuncture Delay"),
)

for ax, (covariate, axis_label) in zip(axes, panel_specs):
    pdep = corrector_multi.get_partial_dependence(
        feature=example_feature, covariate=covariate
    )
    grid_values = pdep["covariate_value"]

    # Fitted curve
    ax.plot(
        grid_values,
        pdep["partial_dependence"],
        "b-",
        linewidth=2,
        label="Partial dependence",
    )
    # Shaded band between the lower_ci and upper_ci columns
    ax.fill_between(
        grid_values,
        pdep["lower_ci"],
        pdep["upper_ci"],
        alpha=0.3,
        label="95% CI",
    )
    ax.set_xlabel(axis_label)
    ax.set_ylabel(f"Effect on {example_feature}")
    ax.set_title(f"Partial Dependence: {covariate}")
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

for message in (
    "These partial dependence plots show the learned effect of each covariate",
    "while holding the other covariate constant (marginal effects).",
):
    print(message)

## 8. Compare Correction Strategies

Let's create a comprehensive comparison of all strategies including normalization.

In [None]:
# Total residual correlation per feature: |corr with TimeIntoStudy| +
# |corr with VenepunctureDelay|, for the raw data and each correction strategy.
# Dict order fixes the column order of strategy_df.
frames_by_strategy = {
    "Original": df,
    "Single-Covariate": df_corrected_single,
    "Multi-Covariate": df_corrected_multi,
    "Grouped": df_corrected_grouped,
    "normalised": df_corrected_normalized,
}

strategies = []
for feat in feature_cols:
    row = {"Feature": feat}
    for strategy_name, frame in frames_by_strategy.items():
        time_corr = abs(frame[[feat, "TimeIntoStudy"]].corr().iloc[0, 1])
        delay_corr = abs(frame[[feat, "VenepunctureDelay"]].corr().iloc[0, 1])
        row[strategy_name] = time_corr + delay_corr
    strategies.append(row)

strategy_df = pd.DataFrame(strategies)

# Plot comparison: five bars per feature, centred on the feature's tick
fig, ax = plt.subplots(figsize=(4.5, 2.2))
x = np.arange(len(feature_cols))
width = 0.16

# (bar offset in units of width, strategy_df column, legend label, colour)
bar_specs = (
    (-2, "Original", "Original", "red"),
    (-1, "Single-Covariate", "Single-Covariate", "orange"),
    (0, "Multi-Covariate", "Multi-Covariate", "blue"),
    (1, "Grouped", "Grouped (no norm)", "purple"),
    (2, "normalised", "Grouped + normalised", "green"),
)
for offset, column, legend_label, colour in bar_specs:
    ax.bar(
        x + offset * width,
        strategy_df[column],
        width,
        label=legend_label,
        alpha=0.8,
        color=colour,
    )

ax.set_xlabel("Features")
ax.set_ylabel("Total Absolute Correlation\n(sum across both covariates)")
ax.set_title("Correction Strategy Comparison")
ax.set_xticks(x)
ax.set_xticklabels(feature_cols, rotation=90)
ax.legend()
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

# Print summary
mean_total = {column: strategy_df[column].mean() for column in frames_by_strategy}

print("Average total correlation (lower is better):")
print(f"  Original:               {mean_total['Original']:.4f}")
print(f"  Single-Covariate:       {mean_total['Single-Covariate']:.4f}")
print(f"  Multi-Covariate:        {mean_total['Multi-Covariate']:.4f}")
print(f"  Grouped (no norm):      {mean_total['Grouped']:.4f}")
print(f"  Grouped + normalised:   {mean_total['normalised']:.4f}")

## 9. Save Models

Save the multi-covariate, grouped, and normalized models for reuse.

In [None]:
# Save models
# (label used in messages, fitted corrector, destination file)
models_to_save = [
    ("multi-covariate", corrector_multi, "../notebooks/gam_corrector_multi.pkl"),
    ("grouped", corrector_grouped, "../notebooks/gam_corrector_grouped.pkl"),
    ("normalized", corrector_normalized, "../notebooks/gam_corrector_normalized.pkl"),
]

for label, corrector, model_path in models_to_save:
    corrector.save(model_path)
    print(f"Saved {label} model")

# Load and verify
# NOTE(review): the .pkl extension suggests pickle-based persistence — if so,
# only load model files from a trusted source.
# Cap the sample size so the check also works on frames with fewer than 100 rows.
df_test = df.sample(n=min(100, len(df)), random_state=42)

for label, corrector, model_path in models_to_save:
    reloaded = GAMCorrector.load(model_path)
    # A round-tripped model must reproduce the in-memory model's output;
    # merely loading without error does not show the model was preserved.
    pd.testing.assert_frame_equal(
        reloaded.transform(df_test),
        corrector.transform(df_test),
        obj=f"{label} model output",
    )

# Kept under their original names for any later use of the reloaded model
corrector_loaded = GAMCorrector.load("../notebooks/gam_corrector_multi.pkl")
df_test_corrected = corrector_loaded.transform(df_test)
print("\n✓ Models saved and loaded successfully")