# 04 - Spline Effects Model

This notebook explores nonlinear effects of age and mileage using B-splines,
while retaining partial pooling (random intercepts) on categorical predictors.

**Model comparison:**
| Model | Continuous predictors | Categorical predictors |
|-------|-----------------------|------------------------|
| Hierarchical | Linear age + random slope by generation | Random intercepts |
| Spline | `bs(age)` + `bs(mileage_scaled)` | Random intercepts |
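
Partial pooling pulls each category's intercept toward the overall mean, and pulls harder when that category has few listings. Below is a minimal sketch of the idea in the simplest case: a normal-normal model with a known within-group spread `sigma` and a known between-group spread `tau`. That is an assumption for illustration, because the fitted models estimate the group-level spread from the data. The trims, sizes, and variances are hypothetical.

```python
import numpy as np


def partial_pool(group_means, group_sizes, grand_mean, sigma, tau):
    """Precision-weighted compromise between each group mean and the grand mean.

    sigma: within-group (residual) standard deviation, assumed known.
    tau:   standard deviation of the true group intercepts, assumed known.
    """
    group_means = np.asarray(group_means, dtype=float)
    n = np.asarray(group_sizes, dtype=float)
    # The weight on a group's own data grows with its size n.
    weight = (n / sigma**2) / (n / sigma**2 + 1.0 / tau**2)
    pooled = weight * group_means + (1.0 - weight) * grand_mean
    return pooled, weight


# Hypothetical log-price means for three trims with very different sample sizes.
means = [11.2, 10.6, 11.8]
sizes = [400, 25, 3]
pooled, w = partial_pool(means, sizes, grand_mean=11.0, sigma=0.3, tau=0.2)
print(np.round(w, 3))
print(np.round(pooled, 3))
```

With these values, the three-listing group keeps only 4/7 (about 57%) of its deviation from the grand mean. The 400-listing group is left almost untouched.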

In [None]:
import logging
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display  # display() is only a builtin inside Jupyter/IPython

from price_analysis.data.cleaning import prepare_model_data
from price_analysis.models import (
    build_model,
    build_spline_model,
    compare_models_loo,
    compare_residual_stats,
    fit_model,
    fit_spline_model,
    get_residuals,
    plot_residual_diagnostics,
    plot_spline_effects_grid,
)
from price_analysis.models.hierarchical import check_diagnostics

logging.basicConfig(level=logging.INFO)
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = (12, 6)

In [None]:
DATA_DIR = Path("../data")
PROCESSED_PATH = DATA_DIR / "processed" / "cleaned_listings.parquet"

## Load and Prepare Data

In [None]:
df_cleaned = pd.read_parquet(PROCESSED_PATH)
df = prepare_model_data(df_cleaned, group_trims=True, group_trans=True)

MILEAGE_MEAN, MILEAGE_STD = df_cleaned["mileage"].mean(), df_cleaned["mileage"].std()

print(f"Model data: {len(df)} listings")
print(f"Age range: {df['age'].min()} - {df['age'].max()} years")
print(
    f"Mileage range (scaled): {df['mileage_scaled'].min():.2f} - {df['mileage_scaled'].max():.2f}"
)
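
`MILEAGE_MEAN` and `MILEAGE_STD` can map between raw miles and `mileage_scaled` only if `mileage_scaled` is a z-score built from those same moments. That is an assumption. `prepare_model_data` is not shown here, and if it drops rows, its moments may differ slightly from those of `df_cleaned`. Under that assumption, the mapping is a simple affine transform. The helper names and moments below are hypothetical.

```python
import numpy as np


def scale_mileage(miles, mean, std):
    """z-score raw mileage (assumed to match how mileage_scaled is built)."""
    return (np.asarray(miles, dtype=float) - mean) / std


def unscale_mileage(z, mean, std):
    """Invert the z-score, e.g. to label a spline-effect axis in miles."""
    return np.asarray(z, dtype=float) * std + mean


# Hypothetical moments, standing in for MILEAGE_MEAN and MILEAGE_STD.
mean, std = 45_000.0, 30_000.0
z = scale_mileage([15_000, 45_000, 105_000], mean, std)
print(z)
print(unscale_mileage(z, mean, std))
```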

## Build Spline Model
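
`AGE_DF` and `MILEAGE_DF` set how many basis columns each `bs()` term contributes. To show what those columns are, here is a minimal NumPy sketch of a cubic B-spline basis built with the Cox-de Boor recursion. It assumes the patsy-style `bs()` convention: interior knots at quantiles of the data, boundary knots at the min and max, and the first column dropped when the term has no intercept. Under that convention, `df=4` gives one interior knot at the median. The formula backend used by `build_spline_model` may differ in details, and `bs_like` is a hypothetical helper.

```python
import numpy as np


def bspline_basis(x, knots, degree=3):
    """Full B-spline basis via the Cox-de Boor recursion (clamped knot vector)."""
    x = np.asarray(x, dtype=float)
    n_spans = len(knots) - 1
    # Degree 0: indicator of the half-open knot span [t_i, t_{i+1}).
    basis = np.zeros((x.size, n_spans))
    for i in range(n_spans):
        basis[:, i] = (knots[i] <= x) & (x < knots[i + 1])
    # Put the right boundary point in the last non-empty span.
    last = np.nonzero(knots[:-1] < knots[1:])[0].max()
    basis[x == knots[-1], last] = 1.0
    # Raise the degree one step at a time.
    for d in range(1, degree + 1):
        new = np.zeros((x.size, n_spans - d))
        for i in range(n_spans - d):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            if left_den > 0:
                new[:, i] += (x - knots[i]) / left_den * basis[:, i]
            if right_den > 0:
                new[:, i] += (knots[i + d + 1] - x) / right_den * basis[:, i + 1]
        basis = new
    return basis


def bs_like(x, df=4, degree=3):
    """df columns, patsy-style: quantile interior knots, first column dropped."""
    x = np.asarray(x, dtype=float)
    n_inner = df - degree  # df=4 with a cubic basis -> one interior knot
    inner = np.quantile(x, np.linspace(0, 1, n_inner + 2)[1:-1])
    order = degree + 1
    knots = np.concatenate([[x.min()] * order, inner, [x.max()] * order])
    full = bspline_basis(x, knots, degree)
    return full[:, 1:], full


age = np.linspace(0, 25, 101)  # hypothetical age grid in years
cols, full = bs_like(age, df=4)
print(cols.shape)  # (101, 4)
```

Each column is a localized bump, and the model learns one coefficient per column. Raising `df` therefore adds interior knots and flexibility, at the cost of more coefficients and a higher risk of fitting noise.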

In [None]:
AGE_DF = 4  # Degrees of freedom for age spline
MILEAGE_DF = 4  # Degrees of freedom for mileage spline
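INCLUDE_SALE_YEAR = False  # All listings are from 2025, so a sale-year term would be constant
TARGET_ACCEPT = 0.975  # High NUTS acceptance target: smaller steps, fewer divergences, slower sampling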

In [None]:
spline_model = build_spline_model(
    df,
    age_df=AGE_DF,
    mileage_df=MILEAGE_DF,
    include_sale_year=INCLUDE_SALE_YEAR,
)
print(spline_model)

## Fit Spline Model

In [None]:
%%time
idata_spline = fit_spline_model(
    spline_model,
    draws=2000,
    tune=1000,
    chains=8,
    target_accept=TARGET_ACCEPT,
)

## Diagnostics
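
`check_diagnostics` reports the largest R-hat across parameters. ArviZ's default is the rank-normalized split R-hat. As a simpler illustration of what the statistic compares, here is the classic (non-rank-normalized) split R-hat on synthetic draws. It contrasts the spread between chain halves with the spread within them, so chains stuck in different regions push it well above 1. This is a sketch, not the computation `check_diagnostics` performs.

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for one scalar parameter; draws has shape (chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in two so that within-chain drift also inflates R-hat.
    split = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(8, 2000))  # 8 chains sampling the same distribution
stuck = mixed + np.arange(8)[:, None] * 0.5  # chains centered at different values
print(f"mixed: {split_rhat(mixed):.3f}, stuck: {split_rhat(stuck):.3f}")
```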

In [None]:
diagnostics = check_diagnostics(idata_spline)
print(f"Converged: {diagnostics['converged']}")
print(f"Divergences: {diagnostics['n_divergences']}")
print(f"Max R-hat: {diagnostics['rhat_max']:.3f}")
print(f"Min ESS (bulk): {diagnostics['ess_bulk_min']:.0f}")
if diagnostics["issues"]:
    print(f"Issues: {diagnostics['issues']}")

In [None]:
var_names = ["Intercept", "is_low_mileage"]
az.plot_trace(idata_spline, var_names=var_names)
plt.tight_layout()

## Visualize Spline Effects

These plots show the estimated nonlinear relationship between age/mileage and log(price),
holding the other predictors fixed at their median (continuous) or mode (categorical) values.
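
Each curve is the spline term evaluated on a grid. For every posterior draw, multiply the basis matrix by that draw's coefficients, then summarize across draws. Below is a minimal NumPy sketch using synthetic coefficient draws, with a piecewise-linear (hat-function) basis standing in for `bs()`. Both are assumptions for illustration, and `plot_spline_effects_grid` may center or summarize differently. Centering each curve on its grid mean removes the part that trades off against the intercept, so only the shape is compared. The interval is a central (equal-tailed) one, not an HDI.

```python
import numpy as np


def hat_basis(x, knots):
    """Piecewise-linear B-spline (hat function) basis; one column per knot."""
    eye = np.eye(len(knots))
    return np.column_stack([np.interp(x, knots, eye[j]) for j in range(len(knots))])


def effect_band(basis, coef_draws, prob=0.94):
    """Median and central interval of basis @ coef across posterior draws."""
    curves = coef_draws @ basis.T  # (n_draws, n_grid)
    curves = curves - curves.mean(axis=1, keepdims=True)  # keep shape, drop level
    tail = (1 - prob) / 2
    lower, median, upper = np.quantile(curves, [tail, 0.5, 1 - tail], axis=0)
    return lower, median, upper


rng = np.random.default_rng(1)
grid = np.linspace(0, 25, 51)  # hypothetical age grid (years)
knots = np.array([0.0, 5.0, 12.5, 25.0])
basis = hat_basis(grid, knots)  # (51, 4)
# Synthetic "posterior": steep early depreciation that flattens out.
coef_draws = np.array([0.0, -0.45, -0.8, -1.1]) + rng.normal(0, 0.05, size=(4000, 4))
lower, median, upper = effect_band(basis, coef_draws)
```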

In [None]:
fig = plot_spline_effects_grid(spline_model, idata_spline, df)
fig.suptitle("Spline Effects on log(price)", y=1.02)

## Fit Hierarchical Model for Comparison

In [None]:
hierarchical_model = build_model(
    df,
    include_sale_year=INCLUDE_SALE_YEAR,
    include_generation_slope=True,
    use_trim_tier=True,
    use_trans_type=True,
)
print(hierarchical_model)

In [None]:
%%time
idata_hierarchical = fit_model(
    hierarchical_model,
    draws=2000,
    tune=1000,
    chains=8,
    target_accept=TARGET_ACCEPT,
)

## LOO-CV Model Comparison

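Compare models using leave-one-out cross-validation, approximated with Pareto-smoothed importance sampling (PSIS-LOO).
A higher ELPD (expected log pointwise predictive density) indicates better estimated out-of-sample predictive performance. Judge an ELPD difference relative to its standard error (`dse` in ArviZ's comparison table), not by its raw size.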
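
The comparison is built from pointwise ELPD values, one per listing. Both models are scored on the same listings, so the standard error of the difference comes from the pointwise differences. It is usually much smaller than either model's own standard error. The sketch below uses the standard formula, sqrt(n · Var(diff)) with the population variance, which is what ArviZ reports as `dse`. The pointwise values are synthetic, not results from this notebook, and the 2-SE rule in the final comment is only a common rule of thumb.

```python
import numpy as np


def elpd_difference(pointwise_a, pointwise_b):
    """Total ELPD difference (a - b) and its standard error from pointwise values."""
    diff = np.asarray(pointwise_a, dtype=float) - np.asarray(pointwise_b, dtype=float)
    return diff.sum(), np.sqrt(diff.size * diff.var())  # var() uses ddof=0


def elpd_se(pointwise):
    """Standard error of a single model's total ELPD."""
    pointwise = np.asarray(pointwise, dtype=float)
    return np.sqrt(pointwise.size * pointwise.var())


rng = np.random.default_rng(2)
n_listings = 1500  # hypothetical sample size
base = rng.normal(-0.2, 0.6, size=n_listings)  # synthetic pointwise ELPD, model B
model_a = base + rng.normal(0.01, 0.05, size=n_listings)  # model A: slightly better
elpd_diff, dse = elpd_difference(model_a, base)
print(f"elpd_diff = {elpd_diff:.1f}, dse = {dse:.1f}, se(B) = {elpd_se(base):.1f}")
# Rule of thumb: a difference within about 2 * dse is weak evidence either way.
```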

In [None]:
comparison = compare_models_loo(
    {
        "hierarchical (linear + random slopes)": idata_hierarchical,
        "spline (bs age/mileage + random intercepts)": idata_spline,
    }
)
display(comparison)

In [None]:
az.plot_compare(comparison)
plt.title("Model Comparison (LOO-CV)")

## Residual Comparison

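Compare residual patterns to see whether the spline model captures nonlinear structure that the linear model misses.
If the linear age term is too rigid, its residuals should trend systematically with age (or mileage), while the spline model's residuals should not.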
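
One simple way to check this is to bin residuals by age and look at the mean residual per bin. If a linear term misses curvature, the bin means trend systematically (here, in a U-shape). A flexible fit leaves them near zero. Below is a minimal sketch on synthetic data, with a quadratic polynomial standing in for the spline. `binned_mean_residuals` is a hypothetical helper, not part of `price_analysis`.

```python
import numpy as np


def binned_mean_residuals(x, resid, n_bins=5):
    """Mean residual within equal-count bins of x."""
    edges = np.quantile(x, np.linspace(0, 1, n_bins + 1))
    bins = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_bins - 1)
    return np.array([resid[bins == b].mean() for b in range(n_bins)])


rng = np.random.default_rng(3)
age = rng.uniform(0, 25, size=2000)  # synthetic ages (years)
log_price = 11.5 - 0.12 * age + 0.003 * age**2 + rng.normal(0, 0.1, size=age.size)

# Linear fit vs. a curved fit (a quadratic stands in for a spline here).
linear_resid = log_price - np.polyval(np.polyfit(age, log_price, 1), age)
curved_resid = log_price - np.polyval(np.polyfit(age, log_price, 2), age)

linear_bins = binned_mean_residuals(age, linear_resid)
curved_bins = binned_mean_residuals(age, curved_resid)
print(np.round(linear_bins, 3))
print(np.round(curved_bins, 3))
```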

In [None]:
residuals_hierarchical = get_residuals(hierarchical_model, idata_hierarchical, df)
residuals_spline = get_residuals(spline_model, idata_spline, df)

residuals_dict = {
    "Hierarchical": residuals_hierarchical,
    "Spline": residuals_spline,
}

In [None]:
stats_df = compare_residual_stats(residuals_dict)
display(stats_df.round(4))

In [None]:
fig = plot_residual_diagnostics(residuals_dict)

## Degrees of Freedom Sensitivity

Quick check: how sensitive is the LOO comparison to the number of spline basis functions (the `df` of each `bs()` term)?

In [None]:
# Optional: compare df=3 and df=5 against the df=4 fit above (slow).
# Uncomment to run. The runs reuse the main fit's include_sale_year and
# target_accept settings; fewer draws and chains keep this quick.

# df_sensitivity = {}
# for df_val in [3, 5]:
#     model = build_spline_model(
#         df, age_df=df_val, mileage_df=df_val, include_sale_year=INCLUDE_SALE_YEAR
#     )
#     idata = fit_spline_model(
#         model, draws=1000, tune=500, chains=4, target_accept=TARGET_ACCEPT
#     )
#     df_sensitivity[f"df={df_val}"] = idata

# df_sensitivity["df=4"] = idata_spline
# sensitivity_comparison = compare_models_loo(df_sensitivity)
# display(sensitivity_comparison)

## Summary

**Key findings:**

1. **Spline effects**: [interpret the shape of age/mileage curves]

2. **Model comparison**: [ELPD difference and interpretation]

3. **Residuals**: [any patterns resolved by splines?]

**Recommendations:**

- [which model to prefer for inference vs prediction?]
- [any follow-up analyses needed?]