# Advanced GAM Terms: Tensor Products, Factor Terms, and Cyclic Splines

This notebook demonstrates advanced GAM term types in `GAMCorrector`, configured through its `term_spec` parameter.

## Term Types

1. **Smooth splines (`s`)**: For continuous covariates (default)
2. **Tensor products (`te`)**: For interactions between continuous covariates
3. **Factor terms (`f`)**: For categorical variables
4. **Cyclic splines (`s` with `basis='cp'`)**: For periodic covariates (e.g., time-of-day)

## When to Use Each Term Type

- **Tensor products**: When the effect of one covariate depends on the value of another (e.g., the effect of sample age, here VenepunctureDelay, varies with time-of-day)
- **Factor terms**: For categorical variables with discrete levels (e.g., weekday, machine ID)
- **Cyclic splines**: For periodic variables whose end wraps around to the beginning (e.g., 23:59 ≈ 00:00)
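A factor term gives each level its own coefficient rather than forcing a smooth curve through the level codes. As a minimal sketch (plain numpy on synthetic data, not `GAMCorrector`'s implementation), the term corresponds to a one-hot design matrix: fitting it by least squares estimates one offset per level, and subtracting those offsets equalises the level means. The +2.0 shift on levels 5 and 6 mirrors the weekend effect injected into Feature 7 below.

```python
import numpy as np

# Synthetic data for illustration only: weekday codes 0-6, +2.0 on days 5 and 6
rng = np.random.default_rng(42)
n = 7_000
weekday = rng.integers(0, 7, size=n)
y = rng.normal(0.0, 1.0, size=n) + 2.0 * np.isin(weekday, [5, 6])

# Factor term: one indicator column per level (no separate intercept needed)
X = np.eye(7)[weekday]                        # shape (n, 7)
coef, *_ = np.linalg.lstsq(X, y, rcond=None)  # per-level effect estimates

# Remove the fitted level effects while keeping the overall mean
corrected = y - X @ coef + y.mean()

level_means = np.array([corrected[weekday == k].mean() for k in range(7)])
weekend_shift = coef[[5, 6]].mean() - coef[:5].mean()
print(f"Estimated weekend shift: {weekend_shift:.2f}")
print(f"Spread of corrected level means: {np.ptp(level_means):.2e}")
```

Because the indicator columns are mutually orthogonal, the least-squares coefficients here are simply the per-level sample means.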

In [None]:
import sys
import numpy as np
import pandas as pd
from pathlib import Path

# Make the local sysmexcbctools package importable
# (assumes this notebook sits two levels below the repository root)
sys.path.insert(0, "../..")

from sysmexcbctools.correction.sysmexcorrect import GAMCorrector

# Standard plotting imports
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# PDF-compatible fonts
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Scientific plot style
import scienceplots

plt.style.use(["science", "nature"])

# Colourblind-friendly palette
SEABORN_PALETTE = "colorblind"
seaborn_colors = sns.color_palette(SEABORN_PALETTE)

np.random.seed(42)

## 1. Load and Prepare Data

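We derive two additional covariates (Weekday and TimeOfDay) from the existing time columns and inject synthetic effects into features '6' and '7', so that the true effects each term type should capture are known.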

In [None]:
# Load data
data_path = Path("../data/data_B.csv")
df = pd.read_csv(data_path)

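# Derive a categorical weekday (0-6) from hours into the study
# (assumes TimeIntoStudy is measured in hours)
df["Weekday"] = (df["TimeIntoStudy"] / 24).astype(int) % 7

# Derive a synthetic time-of-day (0-24 hours, periodic) from the same study
# clock, so that it is not a deterministic function of VenepunctureDelay
df["TimeOfDay"] = df["TimeIntoStudy"] % 24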

# Create synthetic interaction effect:
# Feature '6' is affected by BOTH VenepunctureDelay AND TimeOfDay
# The effect of VenepunctureDelay depends on TimeOfDay
interaction_effect = (
    0.5 * np.sin(2 * np.pi * df["TimeOfDay"] / 24) * (df["VenepunctureDelay"] / 10)
)
df["6"] = df["6"] + interaction_effect

# Create synthetic weekday effect:
# Feature '7' has different baseline on weekends (days 5, 6)
weekend_mask = df["Weekday"].isin([5, 6])
df["7"] = df["7"] + 2.0 * weekend_mask

print(f"Data shape: {df.shape}")
print("\nCovariates:")
print("  - VenepunctureDelay: continuous (0-24+ hours)")
print("  - TimeOfDay: periodic (0-24 hours)")
print("  - TimeIntoStudy: continuous trend")
print("  - Weekday: categorical (0-6)")
print("\nFeatures:")
print("  - Feature '6': VenepunctureDelay × TimeOfDay interaction")
print("  - Feature '7': Weekday effect (higher on weekends)")

## 2. Visualize the Synthetic Effects

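Before fitting any model, note what the injected effects imply. The slope of Feature 6 against VenepunctureDelay is 0.05 × sin(2π × TimeOfDay / 24), so it is positive before midday and negative after it, while Feature 7 is shifted up by 2.0 on days 5 and 6. The "Original" panels in Section 6 plot both effects.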
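A quick numerical check of the interaction, on synthetic data only: time-of-day and delay are drawn uniformly and independently here, an assumption made for illustration rather than a property of `data_B.csv`. Fitting a straight line within roughly the same three time-of-day bands used in Section 6 should recover band-averaged slopes close to +0.036, 0 and -0.036 per hour of delay.

```python
import numpy as np

# Synthetic covariates (illustrative assumption: uniform and independent)
rng = np.random.default_rng(0)
n = 30_000
time_of_day = rng.uniform(0, 24, n)
delay = rng.uniform(0, 30, n)
injected = 0.5 * np.sin(2 * np.pi * time_of_day / 24) * (delay / 10)

# Least-squares slope of the injected term vs. delay within each band
slopes = {}
for label, lo, hi in [("Night", 0, 8), ("Day", 8, 16), ("Evening", 16, 24)]:
    band = (time_of_day >= lo) & (time_of_day < hi)
    slope, _intercept = np.polyfit(delay[band], injected[band], deg=1)
    slopes[label] = slope
    print(f"{label:8s} slope: {slope:+.4f} per hour of delay")
```

The band slope equals 0.05 times the average of sin(2πT/24) over the band, which is about ±0.716 for the first and last bands and zero for the middle one; a model without an interaction term can only fit one common slope.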

## 3. Standard GAM (No Interactions)

First, let's fit a standard GAM in which every covariate, including Weekday, gets a smooth term (no interactions). This will serve as our baseline.

In [None]:
# Standard GAM: all smooth terms, no interactions
corrector_standard = GAMCorrector(
    covariates=["VenepunctureDelay", "TimeOfDay", "TimeIntoStudy", "Weekday"],
    feature_columns=["6", "7"],
    n_splines=25,
    transformation="none",
    verbose=True,
)

corrector_standard.fit(df)
df_standard = corrector_standard.transform(df)

print("\n✓ Standard GAM fitted")

## 4. Advanced GAM with term_spec

Now let's use `term_spec` to give each covariate an appropriate term:
- **Tensor product** for the VenepunctureDelay × TimeOfDay interaction
- **Smooth spline** for the TimeIntoStudy trend
- **Factor term** for categorical Weekday

TimeOfDay enters this model only through the tensor product; cyclic splines are demonstrated separately in Section 7.
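The difference between an additive model and a tensor product can be seen by building the design matrices directly. The numpy sketch below uses equally spaced cubic B-spline marginals (our own choice; `GAMCorrector`'s basis, knot placement and smoothing penalties may differ), forms the tensor-product basis as the row-wise Kronecker product of the two marginals, and fits both designs to a pure interaction of the same form as the one injected into Feature 6. Only a tiny ridge term is used for numerical stability, so this is an unpenalised fit, not a penalised GAM.

```python
import numpy as np


def cubic_bspline(u):
    """Cardinal cubic B-spline evaluated at u (support |u| < 2)."""
    a = np.abs(u)
    return np.where(a < 1, 2 / 3 - a**2 + a**3 / 2,
                    np.where(a < 2, (2 - a) ** 3 / 6, 0.0))


def bspline_basis(x, lo, hi, n_segments):
    """Equally spaced cubic B-spline basis on [lo, hi]; n_segments + 3 columns."""
    h = (hi - lo) / n_segments
    centres = lo + h * np.arange(-1, n_segments + 2)
    return cubic_bspline((x[:, None] - centres[None, :]) / h)


def tensor_basis(b1, b2):
    """Row-wise Kronecker product: one column per pair of marginal functions."""
    return (b1[:, :, None] * b2[:, None, :]).reshape(b1.shape[0], -1)


def rms_residual(X, y, lam=1e-6):
    """Ridge-stabilised least-squares fit; returns the RMS residual."""
    coef = np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)
    return np.sqrt(np.mean((y - X @ coef) ** 2))


# Synthetic covariates and a pure interaction signal
rng = np.random.default_rng(1)
n = 5_000
delay = rng.uniform(0, 30, n)
tod = rng.uniform(0, 24, n)
y = 0.5 * np.sin(2 * np.pi * tod / 24) * (delay / 10)

B_delay = bspline_basis(delay, 0, 30, 6)  # 9 columns
B_tod = bspline_basis(tod, 0, 24, 6)      # 9 columns

additive = np.hstack([B_delay, B_tod])    # like s(delay) + s(tod)
tensor = tensor_basis(B_delay, B_tod)     # like te(delay, tod): 81 columns

rms_add = rms_residual(additive, y)
rms_te = rms_residual(tensor, y)
print(f"Additive RMS residual: {rms_add:.4f}")
print(f"Tensor   RMS residual: {rms_te:.4f}")
```

The additive design can only represent f(delay) + g(time-of-day), so the part of the signal that changes with both covariates at once stays in the residual; the tensor design contains products of the two marginal bases and captures it, at the cost of many more coefficients (here 81 instead of 18).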

In [None]:
# Advanced GAM with term_spec
corrector_advanced = GAMCorrector(
    covariates=["VenepunctureDelay", "TimeOfDay", "TimeIntoStudy", "Weekday"],
    feature_columns=["6", "7"],
    term_spec={
        # Tensor product: VenepunctureDelay × TimeOfDay interaction
        ("VenepunctureDelay", "TimeOfDay"): {"type": "te", "n_splines": 20},
        # Smooth term for long-term drift
        "TimeIntoStudy": {"type": "s", "n_splines": 30},
        # Factor term for categorical weekday
        "Weekday": {"type": "f"},
    },
    transformation="none",
    verbose=True,
)

corrector_advanced.fit(df)
df_advanced = corrector_advanced.transform(df)

print("\n✓ Advanced GAM fitted")

## 5. Compare Correction Quality

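Let's compare how well each model removed the injected effects. For Feature 6 we report the Pearson correlation between the corrected values and each covariate; for Feature 7 we compare the mean over weekend days (5, 6) with the mean over weekdays (0-4). Pearson correlation only measures linear association, so a value near zero does not by itself show that a non-linear or interaction effect has been removed.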

In [None]:
WEEKEND = [5, 6]
WEEKDAYS = [0, 1, 2, 3, 4]


def weekend_summary(data):
    """Print mean Feature 7 on weekend vs. weekday days and their difference."""
    means = data.groupby("Weekday")["7"].mean()
    weekend, weekday = means[WEEKEND].mean(), means[WEEKDAYS].mean()
    print(f"  Weekend mean: {weekend:.3f}")
    print(f"  Weekday mean: {weekday:.3f}")
    print(f"  Difference:   {weekend - weekday:.3f}")


print("=== Residual Correlations: Feature 6 (interaction effect) ===")
for label, data in [
    ("Original", df),
    ("Standard GAM (no interaction)", df_standard),
    ("Advanced GAM (with tensor product)", df_advanced),
]:
    print(f"\n{label}:")
    for cov in ["VenepunctureDelay", "TimeOfDay"]:
        corr = data[["6", cov]].corr().iloc[0, 1]
        print(f"  {cov:20s}: {corr:+.4f}")

print("\n" + "=" * 60)
print("=== Weekday Effect: Feature 7 (categorical effect) ===")
for label, data in [
    ("Original", df),
    ("Standard GAM (smooth term for Weekday)", df_standard),
    ("Advanced GAM (factor term for Weekday)", df_advanced),
]:
    print(f"\n{label}:")
    weekend_summary(data)
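The correlations above are only a rough check. When TimeOfDay is roughly uniform over the day and independent of the delay, the injected interaction has essentially no linear association with VenepunctureDelay, because the sine factor averages to zero; a correction that left the interaction untouched could therefore still show a near-zero correlation. The sketch below demonstrates this on synthetic data (uniform, independent covariates are an assumption for illustration) and uses a hypothetical helper, `binned_r2`, which measures the share of variance explained by cell means on a time-of-day × delay grid and does pick up the interaction.

```python
import numpy as np
import pandas as pd


def binned_r2(y, x1, x2, bins=6):
    """Share of variance in y explained by cell means on a bins x bins grid.

    Hypothetical helper for illustration; not part of sysmexcorrect.
    """
    y = np.asarray(y, dtype=float)
    cells = pd.DataFrame({
        "y": y,
        "b1": pd.cut(np.asarray(x1, dtype=float), bins),
        "b2": pd.cut(np.asarray(x2, dtype=float), bins),
    })
    fitted = cells.groupby(["b1", "b2"], observed=True)["y"].transform("mean").to_numpy()
    return 1.0 - np.sum((y - fitted) ** 2) / np.sum((y - y.mean()) ** 2)


# Synthetic covariates: uniform and independent (illustrative assumption)
rng = np.random.default_rng(0)
n = 20_000
tod = rng.uniform(0, 24, n)
delay = rng.uniform(0, 30, n)
y = 0.5 * np.sin(2 * np.pi * tod / 24) * (delay / 10)  # same form as Feature 6's term

r_delay = np.corrcoef(y, delay)[0, 1]
r2_grid = binned_r2(y, tod, delay)
print(f"Pearson r with delay:      {r_delay:+.3f}")
print(f"Binned R^2 (time x delay): {r2_grid:.2f}")
```

The same helper can be applied to the notebook's frames, e.g. `binned_r2(df_advanced["6"], df_advanced["TimeOfDay"], df_advanced["VenepunctureDelay"])`, and compared across `df`, `df_standard` and `df_advanced`. Real features also vary for reasons unrelated to these covariates, so values on real data will typically be lower than in this noise-free example.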

## 6. Visualize Correction Results

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(6.6, 4.5))

# Row 1: Interaction effect (Feature 6)
for ax, (df_plot, title) in zip(
    axes[0],
    [(df, "Original"), (df_standard, "Standard GAM"), (df_advanced, "Advanced GAM")],
):
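    tod_bins = pd.cut(
        df_plot["TimeOfDay"],
        bins=[0, 8, 16, 24],
        labels=["Night", "Day", "Evening"],
        include_lowest=True,  # keep samples at exactly 0 h in the first band
    )
    # Colourblind-friendly colours from the palette defined in the setup cell
    for tod_label, color in zip(["Night", "Day", "Evening"], seaborn_colors):
        mask = tod_bins == tod_label
        subset = df_plot[mask].sort_values("VenepunctureDelay")

        # Mean delay and Feature 6 within 10 equal-width delay bins
        bins = pd.cut(subset["VenepunctureDelay"], bins=10)
        binned = subset.groupby(bins, observed=True)[["VenepunctureDelay", "6"]].mean()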

        ax.plot(
            binned["VenepunctureDelay"],
            binned["6"],
            "o-",
            label=tod_label,
            color=color,
            linewidth=2,
            markersize=4,
        )

    ax.set_xlabel("VenepunctureDelay (hours)")
    ax.set_ylabel("Feature 6")
    ax.set_title(f"{title}\nInteraction Effect")
    ax.legend()
    ax.grid(alpha=0.3)

# Row 2: Weekday effect (Feature 7)
weekday_labels = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
colors = ["steelblue"] * 5 + ["coral", "coral"]

for ax, (df_plot, title) in zip(
    axes[1],
    [(df, "Original"), (df_standard, "Standard GAM"), (df_advanced, "Advanced GAM")],
):
    # reindex keeps one bar per day even if a weekday has no samples
    weekday_means = df_plot.groupby("Weekday")["7"].mean().reindex(range(7))
    ax.bar(range(7), weekday_means.values, color=colors, edgecolor="black")
    ax.set_xticks(range(7))
    ax.set_xticklabels(weekday_labels)
    ax.set_xlabel("Day of Week")
    ax.set_ylabel("Feature 7 (mean)")
    ax.set_title(f"{title}\nWeekday Effect")
    ax.grid(alpha=0.3, axis="y")

plt.tight_layout()
plt.show()

print("Expected observations:")
print("- Top row: the tensor product (te) models the interaction, so the")
print("  time-of-day-dependent slopes should largely disappear after correction")
print("- Bottom row: the factor term (f) should remove the weekend elevation,")
print("  leaving roughly equal means across days")

## 7. Cyclic Splines for Periodic Covariates

Cyclic splines ensure that 23:59 connects smoothly to 00:00 for periodic variables.
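A cyclic basis achieves this by placing basis functions on a circle, so that the functions near 0 h and near 24 h overlap. In pyGAM, whose term names (`s`, `te`, `f`) this API appears to mirror, `basis='cp'` denotes a cyclic P-spline; the numpy sketch below assumes that construction (cubic B-splines with wrapped knot distances plus a wrapped second-difference penalty). It is not `GAMCorrector`'s code, and the basis size and penalty weight are arbitrary choices for illustration.

```python
import numpy as np


def cubic_bspline(u):
    """Cardinal cubic B-spline evaluated at u (support |u| < 2)."""
    a = np.abs(u)
    return np.where(a < 1, 2 / 3 - a**2 + a**3 / 2,
                    np.where(a < 2, (2 - a) ** 3 / 6, 0.0))


def cyclic_bspline_basis(x, period, n_basis):
    """Cubic B-splines on a circle: distances to the knots wrap around the period."""
    h = period / n_basis
    knots = h * np.arange(n_basis)
    u = (x[:, None] - knots[None, :]) / h
    u = (u + n_basis / 2) % n_basis - n_basis / 2  # wrap to [-n_basis/2, n_basis/2)
    return cubic_bspline(u)


def cyclic_second_difference(n_basis):
    """Second-difference matrix whose rows wrap around (the cyclic P-spline penalty)."""
    eye = np.eye(n_basis)
    return eye - 2 * np.roll(eye, 1, axis=1) + np.roll(eye, 2, axis=1)


# Synthetic 24-hour periodic effect plus noise
rng = np.random.default_rng(7)
hours = rng.uniform(0, 24, 2_000)
y = 3.0 * np.sin(2 * np.pi * hours / 24) + rng.normal(0, 1, hours.size)

# Penalised least squares: (B'B + lam D'D) coef = B'y
B = cyclic_bspline_basis(hours, period=24.0, n_basis=25)
D = cyclic_second_difference(25)
lam = 1.0
coef = np.linalg.solve(B.T @ B + lam * D.T @ D, B.T @ y)

grid = np.array([0.0, 6.0, 24.0])
f0, f6, f24 = cyclic_bspline_basis(grid, 24.0, 25) @ coef
print(f"Fitted effect at 0 h: {f0:+.3f}, 6 h: {f6:+.3f}, 24 h: {f24:+.3f}")
```

Because x = 0 and x = 24 map to the same wrapped distances, their basis rows are identical, so the fitted curve takes the same value (and, for a cubic basis, the same slope and curvature) at both ends regardless of the data.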

In [None]:
# Create synthetic periodic effect (24-hour cycle)
df_cyclic = df.copy()
periodic_effect = 3.0 * np.sin(2 * np.pi * df_cyclic["TimeOfDay"] / 24)
df_cyclic["8"] = df_cyclic["8"] + periodic_effect

print("Added 24-hour periodic effect to feature '8'")
print("\nFitting with standard (non-cyclic) spline...")
corrector_noncyclic = GAMCorrector(
    covariates=["TimeOfDay"],
    feature_columns=["8"],
    n_splines=25,
    transformation="none",
    verbose=False,
)
corrector_noncyclic.fit(df_cyclic)

print("Fitting with cyclic spline...")
corrector_cyclic = GAMCorrector(
    covariates=["TimeOfDay"],
    feature_columns=["8"],
    term_spec={
        "TimeOfDay": {
            "type": "s",
            "n_splines": 25,
            "basis": "cp",  # Cyclic penalized basis
        }
    },
    transformation="none",
    verbose=False,
)
corrector_cyclic.fit(df_cyclic)

print("\n✓ Both models fitted")

In [None]:
# Get partial dependence
pdep_noncyclic = corrector_noncyclic.get_partial_dependence(
    feature="8", covariate="TimeOfDay"
)
pdep_cyclic = corrector_cyclic.get_partial_dependence(
    feature="8", covariate="TimeOfDay"
)

fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.2))

# Non-cyclic
axes[0].plot(
    pdep_noncyclic["covariate_value"],
    pdep_noncyclic["partial_dependence"],
    "b-",
    linewidth=2,
    label="Fitted curve",
)
axes[0].fill_between(
    pdep_noncyclic["covariate_value"],
    pdep_noncyclic["lower_ci"],
    pdep_noncyclic["upper_ci"],
    alpha=0.3,
)
axes[0].axvline(0, color="red", linestyle="--", linewidth=2, label="Start/End")
axes[0].axvline(24, color="red", linestyle="--", linewidth=2)
axes[0].set_xlabel("TimeOfDay (hours)")
axes[0].set_ylabel("Effect on Feature 8")
axes[0].set_title("Standard Spline (non-cyclic)\nEnds not constrained to match at 0/24 hours")
axes[0].legend()
axes[0].grid(alpha=0.3)

# Cyclic
axes[1].plot(
    pdep_cyclic["covariate_value"],
    pdep_cyclic["partial_dependence"],
    "g-",
    linewidth=2,
    label="Fitted curve",
)
axes[1].fill_between(
    pdep_cyclic["covariate_value"],
    pdep_cyclic["lower_ci"],
    pdep_cyclic["upper_ci"],
    alpha=0.3,
    color="green",
)
axes[1].axvline(0, color="red", linestyle="--", linewidth=2, label="Start/End")
axes[1].axvline(24, color="red", linestyle="--", linewidth=2)
axes[1].set_xlabel("TimeOfDay (hours)")
axes[1].set_ylabel("Effect on Feature 8")
axes[1].set_title("Cyclic Spline (basis=cp)\nSmooth connection at 0/24 hours")
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("Observations:")
print("- Standard spline: values at 0 h and 24 h are not constrained to match")
print("- Cyclic spline: values at 0 h and 24 h connect smoothly, as they should for a periodic covariate")
print(
    "\nUse cyclic basis='cp' for periodic variables like time-of-day, day-of-year, etc."
)