# Exploratory Data Analysis — Water Quality Prediction

This notebook explores the training data, feature distributions, spatial patterns, 
and relationships between predictors and targets.

In [None]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Global plot styling shared by every figure in this notebook.
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams["figure.figsize"] = (12, 6)
plt.rcParams["font.size"] = 11

# Input CSVs are read from ../datasets, relative to the notebook's working directory.
DATA_DIR = Path("..") / "datasets"

# Later cells call plt.savefig("../outputs/figures/..."), and savefig does not
# create missing parent directories. Create the directory up front so those
# cells do not fail with FileNotFoundError on a fresh checkout.
FIG_DIR = Path("..") / "outputs" / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Path.iterdir() yields entries in arbitrary order; sort so the listing is
# stable between runs and machines.
print("Datasets available:", sorted(f.name for f in DATA_DIR.iterdir() if f.is_file()))

## 1. Load Training Data

In [None]:
def _read_dated_csv(filename):
    """Read a CSV from DATA_DIR and parse its "Sample Date" column day-first."""
    frame = pd.read_csv(DATA_DIR / filename)
    frame["Sample Date"] = pd.to_datetime(frame["Sample Date"], dayfirst=True)
    return frame


# Labelled training samples and the submission template used as the validation set.
train = _read_dated_csv("water_quality_training_dataset.csv")
val = _read_dated_csv("submission_template.csv")

# Shapes first, then the date span covered by each split.
train_rows, train_cols = train.shape
val_rows, val_cols = val.shape
print(f"Training:   {train_rows} rows x {train_cols} cols")
print(f"Validation: {val_rows} rows x {val_cols} cols")

train_dates, val_dates = train["Sample Date"], val["Sample Date"]
print(f"\nTraining date range: {train_dates.min()} to {train_dates.max()}")
print(f"Validation date range: {val_dates.min()} to {val_dates.max()}")

print(f"\nTraining columns: {list(train.columns)}")
train.head()

## 2. Target Variable Distributions

In [None]:
# The three water-quality columns treated as prediction targets throughout the notebook.
targets = [
    "Total Alkalinity",
    "Electrical Conductance",
    "Dissolved Reactive Phosphorus",
]

target_frame = train[targets]

# Summary statistics, rounded to two decimals for readability.
print("Target statistics:")
print(target_frame.describe().round(2))

# Count of missing values per target column.
print("\nMissing values:")
print(target_frame.isna().sum())

In [None]:
# One histogram per target, each with a dashed red line marking the median.
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for ax, target in zip(axes, targets):
    observed = train[target].dropna()
    median = train[target].median()

    ax.hist(observed, bins=50, edgecolor="black", alpha=0.7)
    ax.axvline(median, color="red", linestyle="--", label=f"Median: {median:.1f}")
    ax.set(title=target, xlabel="Value", ylabel="Count")
    ax.legend()

plt.suptitle("Target Variable Distributions", fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig("../outputs/figures/target_distributions.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Compare raw DRP against log1p(DRP) to see whether the transform helps the
# bimodal distribution.
drp = train["Dissolved Reactive Phosphorus"].dropna()

# (values, panel title, extra hist kwargs) for the left and right panels.
panels = [
    (drp, "DRP — Original", {}),
    (np.log1p(drp), "DRP — Log1p Transformed", {"color": "orange"}),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (values, title, extra) in zip(axes, panels):
    ax.hist(values, bins=50, edgecolor="black", alpha=0.7, **extra)
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 3. Target Correlations

In [None]:
# Pairwise correlation matrix of the three targets (DataFrame.corr defaults).
corr = train[targets].corr()

print("Target correlations:")
print(corr.round(3))

# Diverging palette centred on zero with a fixed [-1, 1] colour range.
heatmap_style = dict(annot=True, cmap="RdBu_r", center=0, vmin=-1, vmax=1)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(corr, ax=ax, **heatmap_style)
ax.set_title("Target Variable Correlations")
plt.tight_layout()
plt.show()

## 4. Spatial Analysis

In [None]:
# A "station" is a unique (Latitude, Longitude) pair.
station_keys = ["Latitude", "Longitude"]

# Per-station sample count (non-null Total Alkalinity values) and target means.
station_aggs = {
    "n_samples": ("Total Alkalinity", "count"),
    "mean_ta": ("Total Alkalinity", "mean"),
    "mean_ec": ("Electrical Conductance", "mean"),
    "mean_drp": ("Dissolved Reactive Phosphorus", "mean"),
}
train_stations = train.groupby(station_keys).agg(**station_aggs).reset_index()

val_stations = val[station_keys].drop_duplicates()

print(f"Training stations: {len(train_stations)}")
print(f"Validation stations: {len(val_stations)}")

# Blank separator line, then the coordinate extent of each split.
print()
for split_name, stations in (("Training", train_stations), ("Validation", val_stations)):
    for short_name, column in (("lat", "Latitude"), ("lon", "Longitude")):
        low, high = stations[column].min(), stations[column].max()
        print(f"{split_name} {short_name} range: {low:.2f} to {high:.2f}")

In [None]:
# Station map: training sites as circles (colour = mean Total Alkalinity,
# area = 2 x sample count) with validation sites overlaid as red triangles.
fig, ax = plt.subplots(figsize=(10, 10))

marker_outline = dict(edgecolors="black", linewidth=0.5)

scatter = ax.scatter(
    train_stations["Longitude"],
    train_stations["Latitude"],
    c=train_stations["mean_ta"],
    cmap="viridis",
    s=train_stations["n_samples"] * 2,
    alpha=0.7,
    label="Training",
    **marker_outline,
)

# zorder=5 draws the validation markers above the training markers.
ax.scatter(
    val_stations["Longitude"],
    val_stations["Latitude"],
    c="red",
    s=80,
    marker="^",
    label="Validation",
    zorder=5,
    **marker_outline,
)

plt.colorbar(scatter, label="Mean Total Alkalinity", shrink=0.7)
ax.set(
    xlabel="Longitude",
    ylabel="Latitude",
    title="Station Locations — Training (circles) vs Validation (triangles)",
)
ax.legend()
plt.tight_layout()
plt.savefig("../outputs/figures/station_map.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Histogram of per-station sample counts, with the median marked.
samples_per_station = train_stations["n_samples"]
median_samples = samples_per_station.median()

fig, ax = plt.subplots(figsize=(10, 4))
samples_per_station.hist(bins=30, edgecolor="black", ax=ax)
ax.axvline(median_samples, color="red", linestyle="--",
           label=f"Median: {median_samples:.0f}")
ax.set(
    xlabel="Number of Samples per Station",
    ylabel="Number of Stations",
    title="Samples per Station",
)
ax.legend()
plt.tight_layout()
plt.show()

## 5. Temporal Patterns

In [None]:
# Calendar features derived from the sample date. These columns are added to
# `train` itself, so they are present in any later use of the frame.
train["month"] = train["Sample Date"].dt.month
train["year"] = train["Sample Date"].dt.year

month_names = "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()

# One panel per target: mean by calendar month, error bars = standard deviation.
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for ax, target in zip(axes, targets):
    monthly = train.groupby("month")[target].agg(["mean", "std"])
    ax.bar(monthly.index, monthly["mean"], yerr=monthly["std"],
           capsize=3, alpha=0.7, edgecolor="black")
    # Label all twelve months regardless of which months have data.
    ax.set_xticks(range(1, 13))
    ax.set_xticklabels(month_names, rotation=45)
    ax.set(title=target, ylabel="Mean Value")

plt.suptitle("Monthly Average by Target (South Africa: DJF=Summer/Wet, JJA=Winter/Dry)", fontsize=13, y=1.02)
plt.tight_layout()
plt.savefig("../outputs/figures/monthly_patterns.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# One panel per target: mean by calendar year, error bars = standard deviation.
# Relies on the "year" column already being present on `train`.
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for ax, target in zip(axes, targets):
    yearly = train.groupby("year")[target].agg(["mean", "std"])
    ax.bar(yearly.index, yearly["mean"], yerr=yearly["std"],
           capsize=3, alpha=0.7, edgecolor="black")
    ax.set(title=target, ylabel="Mean Value", xlabel="Year")

plt.suptitle("Yearly Average by Target", fontsize=13, y=1.02)
plt.tight_layout()
plt.show()

## 6. Feature Analysis — Landsat Bands

In [None]:
# Landsat feature table; dates are parsed day-first like the training file.
landsat = pd.read_csv(DATA_DIR / "train_landsat_features.csv")
landsat["Sample Date"] = pd.to_datetime(landsat["Sample Date"], dayfirst=True)

print(f"Landsat shape: {landsat.shape}")
print(f"Columns: {list(landsat.columns)}")

# Missingness per column, as a raw count and as a percentage of all rows.
missing_counts = landsat.isna().sum()
print("\nMissing values:")
print(missing_counts)
print("\nMissing %:")
print((missing_counts / len(landsat) * 100).round(1))

In [None]:
# Attach the Landsat features to the training rows (left join on station and
# date) so band values can be correlated with the targets.
merge_keys = ["Latitude", "Longitude", "Sample Date"]
merged = train.merge(landsat, how="left", on=merge_keys)

band_cols = ["nir", "green", "swir16", "swir22", "NDMI", "MNDWI"]

# Keep only the band-by-target rectangle of the full correlation matrix.
full_corr = merged[band_cols + targets].corr()
corr_with_targets = full_corr.loc[band_cols, targets]

print("Feature-Target correlations:")
print(corr_with_targets.round(3))

fig, ax = plt.subplots(figsize=(8, 5))
sns.heatmap(
    corr_with_targets,
    ax=ax,
    annot=True,
    cmap="RdBu_r",
    center=0,
    vmin=-1,
    vmax=1,
)
ax.set_title("Landsat Band Correlations with Targets")
plt.tight_layout()
plt.show()

## 7. Feature Analysis — TerraClimate

In [None]:
# TerraClimate feature table; dates are parsed day-first like the training file.
tc = pd.read_csv(DATA_DIR / "train_terraclimate_features.csv")
tc["Sample Date"] = pd.to_datetime(tc["Sample Date"], dayfirst=True)

print(f"TerraClimate shape: {tc.shape}")
print(f"Columns: {list(tc.columns)}")
print("\nPET stats:")
print(tc["pet"].describe().round(2))

# Left-join onto the training rows, then correlate "pet" with each target
# using only rows where both values are present.
merge_keys = ["Latitude", "Longitude", "Sample Date"]
merged_tc = train.merge(tc, how="left", on=merge_keys)

for target in targets:
    complete_pairs = merged_tc[["pet", target]].dropna()
    r = complete_pairs.corr().iloc[0, 1]
    print(f"\nPET vs {target}: r = {r:.3f}")

In [None]:
# Optional analysis: runs only if the extended TerraClimate extraction has
# already been written to DATA_DIR / "processed".
tc_ext_path = DATA_DIR / "processed" / "train_terraclimate_extended.csv"
if tc_ext_path.exists():
    tc_ext = pd.read_csv(tc_ext_path)
    # NOTE(review): every other load in this notebook parses "Sample Date" with
    # dayfirst=True; this one does not. Presumably the processed file stores ISO
    # dates — confirm against src/climate_extractor.py, because mis-parsed dates
    # would silently produce unmatched rows (NaN features) in the merge below.
    tc_ext["Sample Date"] = pd.to_datetime(tc_ext["Sample Date"])

    # Every column that is not a join key is treated as a climate variable.
    # Computed once and reused for both the printout and the correlations.
    key_cols = ["Latitude", "Longitude", "Sample Date"]
    climate_cols = [c for c in tc_ext.columns if c not in key_cols]

    print(f"Extended TerraClimate: {tc_ext.shape}")
    print(f"Variables: {climate_cols}")

    # Correlations of all climate variables with the targets.
    merged_ext = train.merge(tc_ext, on=key_cols, how="left")
    corr_climate = merged_ext[climate_cols + targets].corr().loc[climate_cols, targets]

    fig, ax = plt.subplots(figsize=(10, 8))
    # Fixed [-1, 1] colour range, matching the other correlation heatmaps in
    # this notebook so colours are comparable between figures.
    sns.heatmap(corr_climate, annot=True, cmap="RdBu_r", center=0,
                vmin=-1, vmax=1, fmt=".2f", ax=ax)
    ax.set_title("TerraClimate Variable Correlations with Targets")
    plt.tight_layout()
    plt.savefig("../outputs/figures/climate_target_correlations.png", dpi=150, bbox_inches="tight")
    plt.show()
else:
    print("Extended TerraClimate not yet extracted. Run: python src/climate_extractor.py")

## 8. Spatial Overlap Analysis

In [None]:
from scipy.spatial import cKDTree

EARTH_RADIUS_KM = 6371.0


def _unit_sphere_xyz(lat_lon_deg):
    """Convert an (n, 2) array of [latitude, longitude] in degrees to (n, 3) unit vectors."""
    lat, lon = np.radians(np.asarray(lat_lon_deg, dtype=float)).T
    cos_lat = np.cos(lat)
    return np.column_stack((cos_lat * np.cos(lon), cos_lat * np.sin(lon), np.sin(lat)))


# How far are validation stations from the nearest training station?
#
# Euclidean distance on raw (lat, lon) radians overstates east-west separation,
# because a degree of longitude shrinks by cos(latitude). Instead, embed the
# stations on the unit sphere: straight-line (chord) distance there increases
# monotonically with great-circle distance, so the KD-tree's nearest neighbour
# is also the nearest neighbour along the surface.
station_cols = ["Latitude", "Longitude"]  # select by name, not by column position
train_coords = _unit_sphere_xyz(train_stations[station_cols].values)
val_coords = _unit_sphere_xyz(val_stations[station_cols].values)

tree = cKDTree(train_coords)
chord_len, idxs = tree.query(val_coords, k=1)

# Chord length c on the unit sphere -> central angle 2*asin(c/2) -> arc length in km.
# The clip guards against floating-point overshoot just outside [0, 1].
dist_km = 2.0 * EARTH_RADIUS_KM * np.arcsin(np.clip(chord_len / 2.0, 0.0, 1.0))
median_km = np.median(dist_km)

print("Nearest training station distance for each validation station:")
print(f"  Min:    {dist_km.min():.1f} km")
print(f"  Median: {median_km:.1f} km")
print(f"  Mean:   {dist_km.mean():.1f} km")
print(f"  Max:    {dist_km.max():.1f} km")

fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(range(len(dist_km)), sorted(dist_km), edgecolor="black", alpha=0.7)
ax.set_xlabel("Validation Station (sorted)")
ax.set_ylabel("Distance to Nearest Training Station (km)")
ax.set_title("Spatial Gap: How far are validation sites from training data?")
ax.axhline(median_km, color="red", linestyle="--", label=f"Median: {median_km:.0f} km")
ax.legend()
plt.tight_layout()
plt.savefig("../outputs/figures/spatial_gap.png", dpi=150, bbox_inches="tight")
plt.show()

## 9. Inter-Station Variability

In [None]:
# How much variability is between-station vs within-station?
#
# Uses the exact sum-of-squares decomposition
#     SS_total = SS_between + SS_within
# where each observation is compared with its own station's mean. Dividing all
# three terms by the number of observations gives population variances whose
# between/within shares sum to 100%. An unweighted variance of station means
# plus an unweighted mean of station variances does not have this property
# when stations have different sample counts.
station_keys = ["Latitude", "Longitude"]

for col in targets:
    # Rows lacking the target value or a station key cannot be assigned to a group.
    obs = train.dropna(subset=station_keys + [col])
    n_obs = len(obs)
    if n_obs == 0:
        print(f"\n{col}: no usable rows")
        continue

    values = obs[col]
    grand_mean = values.mean()
    # Each row's own station mean, aligned to the rows of `obs`.
    station_mean = obs.groupby(station_keys)[col].transform("mean")

    total_var = ((values - grand_mean) ** 2).sum() / n_obs
    between_var = ((station_mean - grand_mean) ** 2).sum() / n_obs
    within_var = ((values - station_mean) ** 2).sum() / n_obs

    print(f"\n{col}:")
    print(f"  Total variance:   {total_var:.2f}")
    if total_var == 0:
        # A constant target has no variance to apportion; skip the percentages.
        print("  Between-station:  0.00 (n/a)")
        print("  Within-station:   0.00 (n/a)")
        continue
    print(f"  Between-station:  {between_var:.2f} ({between_var/total_var*100:.1f}%)")
    print(f"  Within-station:   {within_var:.2f} ({within_var/total_var*100:.1f}%)")

print("\n=> High between-station variance means location-specific features (terrain, land cover, soil) are critical.")
print("=> High within-station variance means temporal features (climate, seasons) matter too.")

## 10. Feature Availability Summary

In [None]:
processed = DATA_DIR / "processed"

# Feature source label -> CSV file whose presence marks that source as available.
feature_files = {
    "Landsat (benchmark)": DATA_DIR / "train_landsat_features.csv",
    "TerraClimate PET (benchmark)": DATA_DIR / "train_terraclimate_features.csv",
    "TerraClimate Extended (14 vars)": processed / "train_terraclimate_extended.csv",
    "Terrain (SRTM elevation)": processed / "terrain_features.csv",
    "Land Cover (ESA WorldCover)": processed / "landcover_features.csv",
    "Soil (SoilGrids)": processed / "soil_features.csv",
}
features_status = {label: path.exists() for label, path in feature_files.items()}

print("Feature Dataset Status:")
print("=" * 50)
for name, available in features_status.items():
    status = "READY" if available else "NOT EXTRACTED"
    print(f"  {name:40s} {status}")

n_ready = sum(features_status.values())
print(f"\n{n_ready}/{len(features_status)} feature sources available.")

# Command hint to print for each extractable source that is still missing.
# The two benchmark sources have no hint.
extract_hints = {
    "TerraClimate Extended (14 vars)": "  python src/climate_extractor.py --split both",
    "Terrain (SRTM elevation)": "  python src/terrain_extractor.py",
    "Land Cover (ESA WorldCover)": "  python src/landcover_extractor.py",
    "Soil (SoilGrids)": "  python src/soil_extractor.py",
}

if n_ready < len(features_status):
    print("\nTo extract missing features, run:")
    for name, hint in extract_hints.items():
        if not features_status[name]:
            print(hint)

## 11. Baseline Model Check (Spatial CV)

In [None]:
import sys

# Make the project's ../src modules importable from the notebook.
sys.path.insert(0, "../src")
from data_loader import build_full_dataset
from feature_builder import build_features, get_feature_columns
from model_trainer import assign_spatial_groups
from evaluation import spatial_cv_report

# Build the training table via the project helpers, then run feature building
# in training mode. NOTE(review): which sources get joined is decided inside
# build_full_dataset — presumably every extracted feature file; confirm there.
train_full = build_features(build_full_dataset("train"), is_training=True)

feature_cols = get_feature_columns(train_full)
print(f"Total features available: {len(feature_cols)}")

# Percentage of missing values per feature column, computed in one pass.
missing_pct_by_col = train_full[feature_cols].isna().sum() / len(train_full) * 100
for c, missing_pct in missing_pct_by_col.items():
    print(f"  {c:30s} {missing_pct:5.1f}% missing")

In [None]:
from sklearn.ensemble import RandomForestRegressor


def rf_factory():
    """Return a new RandomForestRegressor with the baseline hyperparameters and a fixed seed."""
    return RandomForestRegressor(
        n_estimators=200,
        max_depth=15,
        min_samples_leaf=5,
        random_state=42,
        n_jobs=-1,
    )


# The feature matrix and spatial groups are shared by all targets.
X = train_full[feature_cols]
groups = assign_spatial_groups(train_full)

print("Spatial CV with Random Forest (baseline check):")
print("=" * 60)

# One spatial cross-validation run per target; spatial_cv_report receives the
# factory itself rather than a model instance.
for target in targets:
    y = train_full[target]

    print(f"\n--- {target} ---")
    report, oof = spatial_cv_report(rf_factory, X, y, groups)
    print(report.to_string(index=False))

## Summary

Key points to confirm against the outputs above before feature engineering and modeling:

1. **Spatial gap**: Validation stations are in a different region — models must generalize spatially
2. **Between-station variance**: How much of the signal comes from location vs. time?
3. **DRP distribution**: Likely bimodal — consider log1p transform
4. **Missing Landsat data**: Cloud cover creates informative missingness
5. **Feature extraction status**: Check which data sources still need extraction
6. **Baseline spatial CV**: Compare with benchmark R² ~0.20 on leaderboard