# Spatial Partial Dependence Plots (GeoPDP) Demo

This notebook demonstrates how to use `geopdp` to visualize geographic model predictions. We will:

1.  **Generate Synthetic Data**: Create a dataset with a realistic, non-linear geographic disease pattern.
2.  **Train a Model**: Use a Gradient Boosting model to learn this complex spatial relationship.
3.  **Analyze with PDPs**: Use standard univariate PDPs to see marginal effects.
4.  **Analyze with GeoPDP**: Use `geopdp` to reveal the true geographic risk pattern on a map.

## Step 1: Setup and Imports

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.inspection import PartialDependenceDisplay
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from geopdp import compute_geopdp, plot_geopdp
from geopdp.data import TANZANIA_GEOJSON

## Step 2: Load Geographic Data

We'll use Tanzania's regions as our geographic context.

In [None]:
gdf = gpd.read_file(TANZANIA_GEOJSON)
print(f"Loaded {len(gdf)} regions")

## Step 3: Generate Synthetic Data

We simulate a disease outbreak with a **sharp geographic boundary**. 

-   **Geographic Risk**: High risk within a narrow band (about 0.7° on either side) around a diagonal line, low risk everywhere else.
-   **Demographic Risk**: Older age and male gender slightly increase risk.
-   **Quasi-Deterministic Labels**: To clearly demonstrate the model's ability to learn spatial patterns, we use an almost deterministic threshold with very little noise.

In [None]:
np.random.seed(42)
n_samples = 10_000

regions = np.random.choice(gdf["NAME_1"].values, size=n_samples)
coords = []
for region in regions:
    geom = gdf[gdf["NAME_1"] == region].geometry.iloc[0]
    point = geom.representative_point()
    # Add noise around the representative point (guaranteed to lie inside the region, unlike a centroid)
    lon = point.x + np.random.normal(0, 0.5)
    lat = point.y + np.random.normal(0, 0.5)
    coords.append((lon, lat))

coords = np.array(coords)

df = pd.DataFrame(
    {
        "longitude": coords[:, 0],
        "latitude": coords[:, 1],
        "region": regions,
        "age": np.random.randint(18, 80, size=n_samples),
        "is_male": np.random.binomial(1, 0.5, size=n_samples),
    }
)


def generate_disease_risk(row):
    """Generate disease with SHARP geographic boundary."""

    lat = row["latitude"]
    lon = row["longitude"]

    # Define the diagonal line endpoints
    x1, y1 = 29, -12
    x2, y2 = 41, -0.5

    # Compute distance from point to the line
    num = (y2 - y1) * lon - (x2 - x1) * lat + x2 * y1 - y2 * x1
    den = np.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
    distance = np.abs(num / den)

    if distance < 0.7:
        geo_risk = 0.99
    elif distance < 1:
        geo_risk = (1 - distance) / 0.3
    else:
        geo_risk = 0.01

    # Add tiny age/gender effects
    age_effect = (row["age"] - 50) * 0.001
    gender_effect = (row["is_male"] == 1) * 0.01

    risk = geo_risk + age_effect + gender_effect

    return np.clip(risk, 0, 1)


# Generate labels
df["disease_prob"] = df.apply(generate_disease_risk, axis=1)

# Near-deterministic labels: Bernoulli draws from probabilities that are mostly close to 0 or 1
df["has_disease"] = np.random.binomial(1, df["disease_prob"])

print(f"Disease prevalence: {df['has_disease'].mean():.1%}")
df[["longitude", "latitude", "age", "disease_prob", "has_disease"]].head()
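The row-wise function above applies the point-to-line distance formula one sample at a time. As a minimal sketch, the same geographic component can be computed for all points at once with NumPy. The function name `geo_risk_vectorized` and its `p1`/`p2` parameters are hypothetical and not part of the notebook; the age and gender terms are left out for brevity.

```python
import numpy as np


def geo_risk_vectorized(lon, lat, p1=(29.0, -12.0), p2=(41.0, -0.5)):
    """Geographic risk for arrays of coordinates (hypothetical helper)."""
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    (x1, y1), (x2, y2) = p1, p2

    # Perpendicular distance from each point to the line through p1 and p2
    num = (y2 - y1) * lon - (x2 - x1) * lat + x2 * y1 - y2 * x1
    den = np.hypot(y2 - y1, x2 - x1)
    distance = np.abs(num) / den

    # Same piecewise rule as the row-wise version: band, ramp, background
    return np.select(
        [distance < 0.7, distance < 1.0],
        [0.99, (1.0 - distance) / 0.3],
        default=0.01,
    )


# Two points on the line and one far away from it
example = geo_risk_vectorized([29, 35, 30], [-12, -6.25, -2])
```

Because the distance is an absolute value, points on both sides of the line get the same risk, which is why the high-risk zone is a band rather than a half-plane.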

### Visualize the Ground Truth

Let's visualize the true disease distribution. Notice the narrow high-risk band running along the diagonal line, with low risk on either side of it.

In [None]:
# Plot the geographic pattern
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
scatter = plt.scatter(
    df["longitude"],
    df["latitude"],
    c=df["disease_prob"],
    cmap="RdYlGn_r",
    alpha=0.5,
    s=1,
)
plt.colorbar(scatter, label="Disease Probability")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title("True Disease Risk (Continuous)")

# Draw the diagonal line
x_line = [29, 41]
y_line = [-12, -0.5]
plt.plot(x_line, y_line, "b--", linewidth=2, label="Boundary")
plt.legend()

plt.subplot(1, 2, 2)
scatter = plt.scatter(
    df["longitude"],
    df["latitude"],
    c=df["has_disease"],
    cmap="RdYlGn_r",
    alpha=0.5,
    s=1,
)
plt.colorbar(scatter, label="Has Disease")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title("True Disease Status (Binary)")
plt.plot(x_line, y_line, "b--", linewidth=2, label="Boundary")
plt.legend()

plt.tight_layout()
plt.show()

## Step 4: Train a Model

We use a **HistGradientBoostingClassifier**. Like a Random Forest, it is built from decision trees with axis-aligned splits, so it can only approximate a diagonal band with many small rectangular steps. Boosting a few hundred trees in sequence usually approximates this kind of pattern closely, which is what we need for the demonstration.

In [None]:
# Features and target
feature_cols = ["longitude", "latitude", "age", "is_male"]
X = df[feature_cols]
y = df["has_disease"]

# Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Train pipeline
pipe = Pipeline(
    [
        ("scaler", StandardScaler()),
        (
            "clf",
            HistGradientBoostingClassifier(
                max_iter=200, learning_rate=0.1, max_depth=8, random_state=42
            ),
        ),
    ]
)

pipe.fit(X_train, y_train)

print(f"Test accuracy: {pipe.score(X_test, y_test):.3f}")
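Because the labels are Bernoulli draws from `disease_prob`, even a model that knew the true probabilities could not reach 100% accuracy. A rough ceiling is the expected accuracy of always predicting the more likely class. The sketch below computes it; `accuracy_ceiling` is a hypothetical helper, not part of the notebook or of scikit-learn.

```python
import numpy as np


def accuracy_ceiling(prob):
    """Expected accuracy when always predicting the more likely class."""
    prob = np.asarray(prob, dtype=float)
    # For each sample, the chance of being right is max(p, 1 - p)
    return float(np.maximum(prob, 1.0 - prob).mean())


# Toy probabilities: two near-certain samples and one coin flip
example_probs = np.array([0.99, 0.01, 0.5])
ceiling = accuracy_ceiling(example_probs)
```

In the notebook, `accuracy_ceiling(df["disease_prob"])` gives a reference value to compare the test accuracy against.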

## Step 5: Partial Dependence Analysis

Now we analyze what the model learned. We'll start with standard Univariate PDPs and then use GeoPDP.

In [None]:
# 1D PDPs
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Longitude PDP
PartialDependenceDisplay.from_estimator(
    pipe, X_test, features=["longitude"], ax=axes[0], kind="average", grid_resolution=50
)
axes[0].set_title("Partial Dependence: Longitude", fontsize=14)
axes[0].set_ylabel("Predicted Disease Probability", fontsize=12)
axes[0].set_xlabel("Longitude", fontsize=12)

# Latitude PDP
PartialDependenceDisplay.from_estimator(
    pipe, X_test, features=["latitude"], ax=axes[1], kind="average", grid_resolution=50
)
axes[1].set_title("Partial Dependence: Latitude", fontsize=14)
axes[1].set_ylabel("Predicted Disease Probability", fontsize=12)
axes[1].set_xlabel("Latitude", fontsize=12)

plt.tight_layout()
plt.show()

### Interpretation of Univariate PDPs

The univariate plots show the **marginal effect** of longitude and latitude:
-   **Longitude**: Predicted risk peaks at a longitude of about 36°E and is lower on either side of it.
-   **Latitude**: The model's predictions increase steadily toward latitudes of about -4° to -3°, where they peak.

However, these plots **fail to capture the diagonal boundary**. They treat longitude and latitude independently, missing the crucial interaction that defines the risk zone.
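To see why averaging hides the interaction, here is a minimal sketch of a one-dimensional partial dependence computed by hand. The `band_model` function is a hypothetical stand-in for a fitted model on a unit square, and `partial_dependence_1d` is an illustrative helper, not scikit-learn's implementation.

```python
import numpy as np


def partial_dependence_1d(predict, X, col, grid):
    """Average prediction when column `col` is forced to each grid value."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, col] = value  # overwrite one feature for every row
        averages.append(predict(X_mod).mean())
    return np.array(averages)


def band_model(X):
    # Hypothetical model: risk is 1 inside a diagonal band, 0 outside
    return (np.abs(X[:, 0] - X[:, 1]) < 0.1).astype(float)


rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(5000, 2))
grid = np.linspace(0.2, 0.8, 7)
pd_x = partial_dependence_1d(band_model, X, col=0, grid=grid)
```

Every entry of `pd_x` comes out close to 0.2, the width of the band. The curve is nearly flat even though the model's output depends entirely on how the two features combine.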

### 2D Partial Dependence Plot

scikit-learn can also show **2D interactions** between longitude and latitude. Let's visualize this:

In [None]:
# 2D PDP showing longitude-latitude interaction
fig, ax = plt.subplots(figsize=(10, 8))

display = PartialDependenceDisplay.from_estimator(
    pipe,
    X_test,
    features=[(0, 1)],  # longitude and latitude interaction
    ax=ax,
    kind="average",
    grid_resolution=50,
)

ax.set_title("2D Partial Dependence: Longitude x Latitude", fontsize=14, pad=20)
ax.set_xlabel("Longitude", fontsize=12)
ax.set_ylabel("Latitude", fontsize=12)

plt.tight_layout()
plt.show()
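The idea behind a two-way partial dependence can be sketched by hand: force both features to each pair of grid values and average the predictions. `band_model` below is a hypothetical stand-in for a fitted model, and `partial_dependence_2d` is an illustrative helper rather than scikit-learn's implementation.

```python
import numpy as np


def partial_dependence_2d(predict, X, cols, grid_a, grid_b):
    """Average prediction with two columns forced to each grid pair."""
    col_a, col_b = cols
    surface = np.empty((len(grid_a), len(grid_b)))
    for i, a in enumerate(grid_a):
        for j, b in enumerate(grid_b):
            X_mod = X.copy()
            X_mod[:, col_a] = a
            X_mod[:, col_b] = b
            surface[i, j] = predict(X_mod).mean()
    return surface


def band_model(X):
    # Hypothetical model: risk is 1 inside a diagonal band, 0 outside
    return (np.abs(X[:, 0] - X[:, 1]) < 0.1).astype(float)


rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(200, 3))  # third column is an unrelated feature
grid = np.linspace(0, 1, 5)
surface = partial_dependence_2d(band_model, X, (0, 1), grid, grid)
```

Here `surface` is 1 on its diagonal and 0 elsewhere, so the two-way grid recovers the band that a single-feature average smooths away.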

### Why This Still Isn't Enough

The 2D PDP heatmap **does** show the diagonal boundary pattern! You can see the high-risk zone separated from the low-risk zone. However, it has critical limitations:

-   **No Geographic Context**: The plot is just a grid of longitude/latitude coordinates; it doesn't show actual country boundaries, regions, cities, or landmarks.
-   **Hard to Interpret**: Without a map overlay, it's difficult to know which **real-world locations** correspond to high/low risk areas.
-   **No Actionable Insights**: Policy makers and stakeholders need to see risks on an actual map to make decisions about resource allocation.

This is where **GeoPDP** shines: it combines the statistical power of PDPs with the intuitive clarity of a geographic map.

## Step 6: Geographic Partial Dependence (GeoPDP)

Now we use `geopdp` to visualize the **joint geographic effect**. By computing predictions for each region's centroid (while holding other features constant), we can map the predicted risk directly onto the geography.

In [None]:
# Compute GeoPDP
pdp_results = compute_geopdp(
    X_test,
    pipe,
    col_index_to_predict=1,  # Probability of disease (class 1)
    geojson=TANZANIA_GEOJSON,
    region_col="region",
    geojson_region_property="NAME_1",
    lon_col="longitude",
    lat_col="latitude",
)

pdp_results.head()
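Conceptually, the per-region computation can be sketched as follows. This is an illustration under stated assumptions, not `geopdp`'s actual implementation: it assumes that every row's coordinates are replaced by one reference point per region and that predictions are then averaged over the observed values of the remaining features, as in a standard PDP. The names `region_partial_dependence`, `toy_predict`, and `region_points` are hypothetical.

```python
import numpy as np


def region_partial_dependence(predict_proba, X, lon_col, lat_col, region_points):
    """Average predicted probability with coordinates forced to each region's point."""
    results = {}
    for name, (lon, lat) in region_points.items():
        X_mod = X.copy()
        X_mod[:, lon_col] = lon  # move every sample to this region
        X_mod[:, lat_col] = lat
        results[name] = float(predict_proba(X_mod).mean())
    return results


def toy_predict(X):
    # Hypothetical model: higher risk east of 35 degrees longitude
    return np.where(X[:, 0] > 35, 0.9, 0.1)


# Columns: longitude, latitude, age
X = np.array([[30.0, -5.0, 40.0], [38.0, -3.0, 60.0]])
region_points = {"west": (31.0, -6.0), "east": (39.0, -6.0)}
results = region_partial_dependence(toy_predict, X, 0, 1, region_points)
```

The result is one value per region, which is the shape of data a choropleth map needs.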

In [None]:
# Visualize GeoPDP
fig = plot_geopdp(
    pdp_results,
    geojson=TANZANIA_GEOJSON,
    region_col="region",
    geojson_region_property="NAME_1",
    color_scale="Viridis",
    title="Predicted Disease Risk by Region (Spatial PDP)",
)

fig.show()

## Interpretation

The GeoPDP reveals the **true geographic pattern**:
-   We clearly see the **diagonal band** of elevated risk crossing the country.
-   Regions along the diagonal appear at the bright end of the Viridis scale (high risk), while the remaining regions appear at the dark end (low risk).
-   This insight was **completely invisible** in the univariate PDPs!

This demonstrates why `geopdp` is essential for models with geographic features: it translates complex spatial interactions into intuitive maps.

## Bonus: Geometry Simplification

For faster rendering, we can simplify the GeoJSON. Zooming in will show the effect of the simplification.

In [None]:
from geopdp.geometry import simplify_geojson
from geopdp.visualization import compare_geojson_geometry

# Simplify
simplified_gdf = simplify_geojson(TANZANIA_GEOJSON, tolerance=0.05, precision=0.001)

# Compare
fig = compare_geojson_geometry(
    TANZANIA_GEOJSON, simplified_gdf, title="Geometry Simplification Comparison"
)
fig.show()
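The internals of `simplify_geojson` are not shown here. A widely used algorithm for tolerance-based simplification is Ramer-Douglas-Peucker, and the sketch below illustrates it in plain Python. Whether `geopdp` uses this algorithm is an assumption, and `perpendicular_distance` and `simplify_line` are hypothetical helpers.

```python
import math


def perpendicular_distance(point, start, end):
    """Distance from `point` to the line through `start` and `end`."""
    (x, y), (x1, y1), (x2, y2) = point, start, end
    dx, dy = x2 - x1, y2 - y1
    if dx == 0 and dy == 0:
        return math.hypot(x - x1, y - y1)
    return abs(dy * x - dx * y + x2 * y1 - y2 * x1) / math.hypot(dx, dy)


def simplify_line(points, tolerance):
    """Drop vertices that deviate from the chord by no more than `tolerance`."""
    if len(points) < 3:
        return list(points)
    distances = [
        perpendicular_distance(p, points[0], points[-1]) for p in points[1:-1]
    ]
    max_dist = max(distances)
    if max_dist <= tolerance:
        return [points[0], points[-1]]
    # Keep the farthest vertex and simplify both halves recursively
    split = distances.index(max_dist) + 1
    left = simplify_line(points[: split + 1], tolerance)
    right = simplify_line(points[split:], tolerance)
    return left[:-1] + right


line = [(0, 0), (1, 0.01), (2, -0.01), (3, 0.0), (4, 2.0), (5, 0.0)]
simplified = simplify_line(line, tolerance=0.05)
```

A larger tolerance removes more vertices, which speeds up rendering at the cost of coarser region outlines.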

## Summary

This demo showed:
1.  **Synthetic Data Generation**: Creating a dataset with a sharp, realistic geographic risk boundary.
2.  **Model Training**: Using Gradient Boosting to accurately learn non-linear spatial patterns.
3.  **Univariate PDPs**: Showing marginal effects of longitude and latitude (which miss the diagonal interaction).
4.  **GeoPDP**: Revealing the true diagonal risk boundary on a map, providing clear and actionable geographic insights.

**Key Insight**: Spatial PDPs are essential for interpreting models with geographic features!