# Driver Data Extraction - Step by Step with Diagnostics

This notebook walks through the climate driver importance extraction process from TFT model predictions, with detailed diagnostics at each step.

Based on: `/burg-archive/home/al4385/phenofusion/src/phenofusion/dataio/driversdata.py`

## Step 1: Import Required Libraries

In [None]:
import pandas as pd
import numpy as np
import pickle
from scipy.stats import linregress
from typing import List, Tuple, Optional
import logging
import matplotlib.pyplot as plt
import seaborn as sns

# Logging: INFO level with timestamped "<time> - <level> - <message>" lines
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

print("✓ Libraries imported successfully")

## Step 2: Configure Parameters

Set the file paths and parameters for the analysis.

In [None]:
# Configuration
PFT = "SHR"  # Plant functional type
FORECAST_WINDOW = 30  # Days per forecast sample

# File paths (absolute cluster paths)
DATA_PATH = f"/burg/glab/users/al4385/data/TFT_40_overlapping_samples/{PFT}_1982_2021.pkl"
PRED_PATH = f"/burg/glab/users/al4385/predictions/pred_40year_moresamples/{PFT}_20251023.pkl"
COORD_PATH = f"/burg/glab/users/al4385/data/coordinates/{PFT}.parquet"
OUTPUT_BASE = f"/burg/glab/users/ms7073/analysis/driversdata/oversampling/{PFT}_diagnostic"

# Phenology detection thresholds: BET and SHR use looser values than other PFTs
_LOOSE_THRESHOLD_PFTS = {"BET", "SHR"}
if PFT in _LOOSE_THRESHOLD_PFTS:
    MIN_DIFF, MIN_SLOPE = 0.01, 0.001
else:
    MIN_DIFF, MIN_SLOPE = 0.20, 0.002

print("Configuration:")
for _label, _value in (
    ("PFT", PFT),
    ("Forecast Window", f"{FORECAST_WINDOW} days"),
    ("Min CSIF Difference", MIN_DIFF),
    ("Min Slope", MIN_SLOPE),
):
    print(f"  {_label}: {_value}")

print("\nFile paths:")
for _label, _value in (
    ("Data", DATA_PATH),
    ("Predictions", PRED_PATH),
    ("Coordinates", COORD_PATH),
):
    print(f"  {_label}: {_value}")

## Step 3: Load Data Files

Load the processed data, predictions, and coordinates.

In [None]:
# Load the processed dataset.
# NOTE: pickle.load can execute arbitrary code; DATA_PATH must point at a trusted file.
print("Loading data...")
with open(DATA_PATH, "rb") as pickle_file:
    data = pickle.load(pickle_file)

print("✓ Data loaded")
print(f"  Keys: {list(data.keys())}")
print(f"  Data sets: {list(data['data_sets'].keys())}")

In [None]:
# Load model predictions (trusted pickle; see the note on the data-loading cell)
print("Loading predictions...")
with open(PRED_PATH, "rb") as pickle_file:
    preds = pickle.load(pickle_file)

print("✓ Predictions loaded")
print(f"  Keys: {list(preds.keys())}")
for _label, _key in (
    ("Attention scores", "attention_scores"),
    ("Predicted quantiles", "predicted_quantiles"),
    ("Historical selection weights", "historical_selection_weights"),
):
    print(f"  {_label} shape: {preds[_key].shape}")

In [None]:
# Load location coordinates, discarding fully duplicated rows
print("Loading coordinates...")
raw_coords = pd.read_parquet(COORD_PATH)
coords = raw_coords.drop_duplicates()

print("✓ Coordinates loaded")
print(f"  Shape: {coords.shape}")
print(f"  Columns: {list(coords.columns)}")
print(f"  Unique locations: {coords['location'].nunique()}")
print("\nFirst few rows:")
coords.head()

## Step 4: Examine Test Data Structure

In [None]:
# The test split drives every later step
test_data = data["data_sets"]["test"]

print("Test data structure:")
print(f"  Keys: {list(test_data.keys())}")
print("\nShapes:")
for name, item in test_data.items():
    description = item.shape if hasattr(item, "shape") else type(item)
    print(f"  {name}: {description}")

print("\nFirst few IDs:")
print(test_data["id"][:5].flatten())

print("\nFirst few target values (CSIF):")
print(test_data["target"][:5].flatten())

## Step 5: Create Analysis DataFrame

Combine data, predictions, and coordinates into a single DataFrame.

In [None]:
# One row per (sample, forecast step): flattened sample IDs and observed CSIF targets,
# plus the prediction at quantile column 1 — presumably the median (0.5);
# TODO confirm against the model's quantile list.
df = pd.DataFrame(
    {
        "Index": test_data["id"].flatten(),
        "CSIF": test_data["target"].flatten(),
    }
).assign(pred_05=preds["predicted_quantiles"][:, :, 1].flatten())

print("Base DataFrame created:")
print(f"  Shape: {df.shape}")
print(f"  Columns: {list(df.columns)}")
print("\nFirst few rows:")
df.head(10)

In [None]:
# Index strings are "<location>_<date>": split on the first underscore only
df[["location", "time"]] = df["Index"].str.split("_", n=1, expand=True)
df = df.astype({"location": int})
df["time"] = pd.to_datetime(df["time"])

# Chronological order within each location
df = df.sort_values(by=["location", "time"])

# Calendar features derived from the timestamp (insertion order: doy, year, month, day)
time_accessor = df["time"].dt
calendar_features = {
    "doy": time_accessor.dayofyear,
    "year": time_accessor.year,
    "month": time_accessor.month,
    "day": time_accessor.day,
}
df = df.assign(**calendar_features).drop(columns=["Index"])

print("After parsing temporal information:")
print(f"  Shape: {df.shape}")
print(f"  Columns: {list(df.columns)}")
print(f"  Date range: {df['time'].min()} to {df['time'].max()}")
print(f"  Year range: {df['year'].min()} to {df['year'].max()}")
print("\nFirst few rows:")
df.head(10)

In [None]:
# Attach coordinates to every record. This uses an inner join deliberately.
# A left join from `coords` would add one all-NaN row for each coordinate location
# that has no test samples. Each such single row shifts the fixed FORECAST_WINDOW-row
# batching in Step 7 out of alignment with the forecast samples.
coord_locations = set(coords["location"])
test_locations = set(df["location"])
n_coord_only = len(coord_locations - test_locations)
n_test_only = len(test_locations - coord_locations)

if coords["location"].duplicated().any():
    # Repeated location IDs would duplicate every record of that location in the merge
    print(
        f"WARNING: {coords['location'].duplicated().sum()} duplicated location IDs in coordinates"
    )

df = pd.merge(coords, df, on="location", how="inner")

print(f"After merging with coordinates:")
print(f"  Shape: {df.shape}")
print(f"  Columns: {list(df.columns)}")
print(f"  Unique locations: {df['location'].nunique()}")
print(f"  Coordinate locations without test data (dropped): {n_coord_only}")
print(f"  Test locations without coordinates (dropped): {n_test_only}")
print(f"  Records per location (avg): {len(df) / df['location'].nunique():.1f}")
print(f"\nSpatial extent:")
print(f"  Latitude range: {df['latitude'].min():.2f} to {df['latitude'].max():.2f}")
print(f"  Longitude range: {df['longitude'].min():.2f} to {df['longitude'].max():.2f}")
print(f"\nFirst few rows:")
df.head()

## Step 6: Visualize CSIF Time Series Examples

In [None]:
# Observed vs. predicted CSIF for the first four locations (in merged-row order)
sample_locations = df["location"].unique()[:4]

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

for ax, loc in zip(axes, sample_locations):
    loc_data = df[df["location"] == loc].sort_values("time")

    ax.plot(loc_data["time"], loc_data["CSIF"], label="Observed", alpha=0.7)
    ax.plot(loc_data["time"], loc_data["pred_05"], label="Predicted", alpha=0.7)
    lat = loc_data["latitude"].iloc[0]
    lon = loc_data["longitude"].iloc[0]
    ax.set_title(f"Location {loc}\n(Lat: {lat:.2f}, Lon: {lon:.2f})")
    ax.set_xlabel("Date")
    ax.set_ylabel("CSIF")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Sample CSIF time series plotted for 4 locations")

## Step 7: Detect Phenology Indices (SOS and EOS)

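Use CSIF slope analysis to identify Start of Season (SOS) and End of Season (EOS) samples. Each consecutive block of `FORECAST_WINDOW` rows is treated as one sample. A sample is kept only if its CSIF changes by at least `MIN_DIFF` between the first and last day. It is then classified by the slope of a linear fit to its CSIF values: SOS when the slope is ≥ `MIN_SLOPE`, EOS when it is ≤ `-MIN_SLOPE - 0.0005`.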

In [None]:
# Scan `df` in consecutive FORECAST_WINDOW-row batches (one batch per forecast sample)
# and classify each batch as SOS or EOS from the linear slope of its CSIF values.
# SOS_indices / EOS_indices store the *row position* of each batch's first row, which
# is what the index mapping (row // batch_size) and the df.iloc plotting below expect.
# NOTE(review): this assumes the row order of `df` matches the sample order of the
# prediction arrays. `df` was re-sorted by location/time in Step 5 — TODO confirm.

# EOS uses a slightly stricter (more negative) slope threshold than SOS
EOS_SLOPE_THRESHOLD = -MIN_SLOPE - 0.0005

SOS_indices = []
EOS_indices = []
slopes_sos = []  # Store slopes for diagnostics
slopes_eos = []
csif_ranges_sos = []
csif_ranges_eos = []

batch_size = FORECAST_WINDOW
sos_count = 0
eos_count = 0
n_batches = 0
skipped_incomplete = 0
skipped_multi_location = 0
skipped_low_signal = 0
skipped_nan = 0
unclassified = 0

# Every batch that reaches the regression has exactly batch_size rows
x = np.arange(batch_size)

print(f"Scanning {len(df)} records in batches of {batch_size}...\n")

for start in range(0, len(df), batch_size):
    n_batches += 1
    batch_df = df.iloc[start : start + batch_size]

    # Check for a partial trailing batch first so it is not miscounted as a NaN batch
    if len(batch_df) < batch_size:
        skipped_incomplete += 1
        continue

    if batch_df["CSIF"].isna().any():
        skipped_nan += 1
        continue

    # A batch spanning two locations is not a single sample
    if batch_df["location"].nunique() > 1:
        skipped_multi_location += 1
        continue

    csif_values = batch_df["CSIF"].to_numpy()

    # Require a minimum CSIF change between the first and last day of the window
    csif_range = abs(csif_values[-1] - csif_values[0])
    if csif_range < MIN_DIFF:
        skipped_low_signal += 1
        continue

    slope = linregress(x, csif_values).slope

    if slope >= MIN_SLOPE:
        # Rising CSIF = Start of Season
        SOS_indices.append(start)
        slopes_sos.append(slope)
        csif_ranges_sos.append(csif_range)
        sos_count += 1
    elif slope <= EOS_SLOPE_THRESHOLD:
        # Falling CSIF = End of Season
        EOS_indices.append(start)
        slopes_eos.append(slope)
        csif_ranges_eos.append(csif_range)
        eos_count += 1
    else:
        unclassified += 1

print(f"Phenology Detection Results:")
print(f"  Total batches processed: {n_batches}")
print(f"  Skipped (incomplete): {skipped_incomplete}")
print(f"  Skipped (NaNs): {skipped_nan}")
print(f"  Skipped (multi-location): {skipped_multi_location}")
print(f"  Skipped (low signal): {skipped_low_signal}")
print(f"  Unclassified (slope between thresholds): {unclassified}")
print(f"\n  SOS samples detected: {sos_count}")
print(f"  EOS samples detected: {eos_count}")

if sos_count > 0:
    print(
        f"\n  SOS slopes: min={min(slopes_sos):.6f}, max={max(slopes_sos):.6f}, mean={np.mean(slopes_sos):.6f}"
    )
    print(
        f"  SOS CSIF ranges: min={min(csif_ranges_sos):.4f}, max={max(csif_ranges_sos):.4f}, mean={np.mean(csif_ranges_sos):.4f}"
    )

if eos_count > 0:
    print(
        f"\n  EOS slopes: min={min(slopes_eos):.6f}, max={max(slopes_eos):.6f}, mean={np.mean(slopes_eos):.6f}"
    )
    print(
        f"  EOS CSIF ranges: min={min(csif_ranges_eos):.4f}, max={max(csif_ranges_eos):.4f}, mean={np.mean(csif_ranges_eos):.4f}"
    )

In [None]:
print("  Skipped (NaNs): {}".format(skipped_nan))

In [None]:
# Map each detected batch (first row position in `df`) to a sample index in the
# prediction arrays. This assumes rows are consecutive batch_size-row samples in
# prediction order — TODO confirm after the sort/merge in Step 5.
# Indices past the end of the prediction arrays are dropped.
max_pred_index = len(preds["attention_scores"]) - 1

SOS_pred_indices = [
    sample
    for sample in (row // batch_size for row in SOS_indices)
    if sample <= max_pred_index
]
EOS_pred_indices = [
    sample
    for sample in (row // batch_size for row in EOS_indices)
    if sample <= max_pred_index
]

print("Prediction Indices:")
print(f"  SOS prediction indices: {len(SOS_pred_indices)} (filtered from {sos_count})")
print(f"  EOS prediction indices: {len(EOS_pred_indices)} (filtered from {eos_count})")
print(f"  Max valid prediction index: {max_pred_index}")

## Step 8: Visualize Slope Distributions

In [None]:
# Histograms of fitted slopes for detected SOS / EOS batches, with the threshold marked
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

slope_panels = [
    (axes[0], "SOS", slopes_sos, "green", MIN_SLOPE, f"Threshold: {MIN_SLOPE}"),
    (
        axes[1],
        "EOS",
        slopes_eos,
        "brown",
        -MIN_SLOPE - 0.0005,
        f"Threshold: {-MIN_SLOPE - 0.0005:.4f}",
    ),
]

for ax, phase, slopes, colour, threshold, threshold_label in slope_panels:
    if not slopes:
        continue
    ax.hist(slopes, bins=50, alpha=0.7, color=colour, edgecolor="black")
    ax.axvline(threshold, color="red", linestyle="--", label=threshold_label)
    ax.set_title(f"{phase} Slope Distribution (n={len(slopes)})")
    ax.set_xlabel("Slope")
    ax.set_ylabel("Frequency")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Step 9: Visualize Example SOS and EOS Windows

In [None]:
# Top row: the first two SOS windows. Bottom row: the first two EOS windows.
# Each row is drawn only when at least two examples exist.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

example_rows = [
    (0, "SOS", SOS_indices, slopes_sos, "go-"),
    (1, "EOS", EOS_indices, slopes_eos, "ro-"),
]

for row, phase, starts, slopes, line_fmt in example_rows:
    if len(starts) < 2:
        continue
    for col in range(2):
        ax = axes[row, col]
        first_row = starts[col]
        window = df.iloc[first_row : first_row + batch_size]
        ax.plot(window["doy"], window["CSIF"], line_fmt, label="CSIF")
        ax.set_title(f"{phase} Example {col+1}\nSlope: {slopes[col]:.6f}")
        ax.set_xlabel("Day of Year")
        ax.set_ylabel("CSIF")
        ax.legend()
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Step 10: Extract Driver Weights for SOS Samples

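For each SOS sample, find the `FORECAST_WINDOW`-day span of the input history with the highest summed (horizon-averaged) attention. Then take the median historical variable-selection weight of each climate driver over that span.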

In [None]:
def find_max_attention_window(index: int, preds, forecast_window: int) -> int:
    """
    Find the time window with maximum attention scores.
    """
    # Get mean attention across all horizons
    att_array = np.mean(preds["attention_scores"][index], axis=0)

    max_sum = -np.inf
    best_start_index = None

    # Slide window to find maximum
    # max_start = len(att_array) - forecast_window
    max_start = 336
    for i in range(max_start):
        current_sum = np.sum(att_array[i : i + forecast_window])
        if current_sum > max_sum:
            max_sum = current_sum
            best_start_index = i

    return best_start_index if best_start_index is not None else 0


print("Extracting SOS driver weights...")

sos_driver_data = []
# Windows starting after this day would extend past the 365-step history
max_window_start = 365 - FORECAST_WINDOW
skipped_window = 0
processed = 0
# Driver order matches selection-weight columns 1..6. Column 0 is not a climate
# driver here (presumably the target or another input — TODO confirm).
driver_names = ["tmin", "tmax", "rad", "precip", "photo", "sm"]

for index in SOS_pred_indices[:5]:  # Show diagnostics for first 5
    try:
        window_start = find_max_attention_window(index, preds, FORECAST_WINDOW)

        print(f"\nSample {index}:")
        print(f"  Max attention window starts at day: {window_start}")

        if window_start > max_window_start:
            print(f"  Skipped: window extends beyond valid range")
            skipped_window += 1
            continue

        window_weights = preds["historical_selection_weights"][index][
            window_start : window_start + FORECAST_WINDOW
        ]
        weights = {}
        for column, var in enumerate(driver_names, start=1):
            key = f"hist_{var}"
            weights[key] = np.median(window_weights[:, column])
            print(f"  {var}: {weights[key]:.4f}")

        # Sample IDs look like "<location>_<date>"
        location_id = int(test_data["id"][index][0].split("_")[0])
        weights["location"] = location_id
        print(f"  Location: {location_id}")

        sos_driver_data.append(weights)
        processed += 1

    except Exception as exc:
        print(f"  Error processing index {index}: {exc}")

print(f"\nProcessed {processed} samples for diagnostics")
print(f"Total SOS samples to process: {len(SOS_pred_indices)}")

In [None]:
# Process all SOS samples. For each one, take the median historical selection weight
# of every climate driver (columns 1..6) over its max-attention window.
print("Processing all SOS samples...")

# Windows starting after this day would extend past the 365-step history
max_window_start = 365 - FORECAST_WINDOW

sos_driver_data = []
skipped_window = 0
errors = 0
error_examples = []  # first few (index, error) pairs, so failures are not silent

for index in SOS_pred_indices:
    try:
        window_start = find_max_attention_window(index, preds, FORECAST_WINDOW)

        if window_start > max_window_start:
            skipped_window += 1
            continue

        weights = {}
        hist_weights = preds["historical_selection_weights"][index]

        for i, var in enumerate(["tmin", "tmax", "rad", "precip", "photo", "sm"], 1):
            weights[f"hist_{var}"] = np.median(
                hist_weights[window_start : window_start + FORECAST_WINDOW, i]
            )

        # Sample IDs look like "<location>_<date>"
        location_id = int(data["data_sets"]["test"]["id"][index][0].split("_")[0])
        weights["location"] = location_id

        sos_driver_data.append(weights)

    except Exception as exc:
        errors += 1
        if len(error_examples) < 5:
            error_examples.append((index, repr(exc)))

# Create DataFrame (one row per extracted SOS sample)
sos_driver_df = pd.DataFrame(sos_driver_data)

print(f"\nSOS Driver Extraction Results:")
print(f"  Input samples: {len(SOS_pred_indices)}")
print(f"  Skipped (window): {skipped_window}")
print(f"  Errors: {errors}")
for bad_index, message in error_examples:
    print(f"    index {bad_index}: {message}")
print(f"  Successfully extracted: {len(sos_driver_df)}")
print(f"\nDriver DataFrame shape: {sos_driver_df.shape}")
print(f"Columns: {list(sos_driver_df.columns)}")
print(f"\nFirst few rows:")
sos_driver_df.head()

## Step 11: Extract Driver Weights for EOS Samples

In [None]:
# Process all EOS samples. For each one, take the median historical selection weight
# of every climate driver (columns 1..6) over its max-attention window.
print("Processing all EOS samples...")

# Windows starting after this day would extend past the 365-step history
max_window_start = 365 - FORECAST_WINDOW

eos_driver_data = []
skipped_window = 0
errors = 0
error_examples = []  # first few (index, error) pairs, so failures are not silent

for index in EOS_pred_indices:
    try:
        window_start = find_max_attention_window(index, preds, FORECAST_WINDOW)

        if window_start > max_window_start:
            skipped_window += 1
            continue

        weights = {}
        hist_weights = preds["historical_selection_weights"][index]

        for i, var in enumerate(["tmin", "tmax", "rad", "precip", "photo", "sm"], 1):
            weights[f"hist_{var}"] = np.median(
                hist_weights[window_start : window_start + FORECAST_WINDOW, i]
            )

        # Sample IDs look like "<location>_<date>"
        location_id = int(data["data_sets"]["test"]["id"][index][0].split("_")[0])
        weights["location"] = location_id

        eos_driver_data.append(weights)

    except Exception as exc:
        errors += 1
        if len(error_examples) < 5:
            error_examples.append((index, repr(exc)))

# Create DataFrame (one row per extracted EOS sample)
eos_driver_df = pd.DataFrame(eos_driver_data)

print(f"\nEOS Driver Extraction Results:")
print(f"  Input samples: {len(EOS_pred_indices)}")
print(f"  Skipped (window): {skipped_window}")
print(f"  Errors: {errors}")
for bad_index, message in error_examples:
    print(f"    index {bad_index}: {message}")
print(f"  Successfully extracted: {len(eos_driver_df)}")
print(f"\nDriver DataFrame shape: {eos_driver_df.shape}")
print(f"Columns: {list(eos_driver_df.columns)}")
print(f"\nFirst few rows:")
eos_driver_df.head()

## Step 12: Visualize Driver Weight Distributions

In [None]:
# One histogram per climate driver of its median selection weight across SOS samples
if len(sos_driver_df) > 0:
    driver_cols = [
        f"hist_{driver}" for driver in ("tmin", "tmax", "rad", "precip", "photo", "sm")
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for ax, col in zip(axes, driver_cols):
        column_values = sos_driver_df[col]
        ax.hist(
            column_values.dropna(),
            bins=50,
            alpha=0.7,
            color="green",
            edgecolor="black",
        )
        driver_label = col.replace("hist_", "").upper()
        ax.set_title(f"SOS - {driver_label}\nMean: {column_values.mean():.4f}")
        ax.set_xlabel("Weight")
        ax.set_ylabel("Frequency")
        ax.grid(True, alpha=0.3)

    plt.suptitle("SOS Driver Weight Distributions", fontsize=16, y=1.00)
    plt.tight_layout()
    plt.show()

    print("SOS Driver Weight Summary:")
    print(sos_driver_df[driver_cols].describe())

In [None]:
# One histogram per climate driver of its median selection weight across EOS samples
if len(eos_driver_df) > 0:
    driver_cols = [
        f"hist_{driver}" for driver in ("tmin", "tmax", "rad", "precip", "photo", "sm")
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for ax, col in zip(axes, driver_cols):
        column_values = eos_driver_df[col]
        ax.hist(
            column_values.dropna(),
            bins=50,
            alpha=0.7,
            color="brown",
            edgecolor="black",
        )
        driver_label = col.replace("hist_", "").upper()
        ax.set_title(f"EOS - {driver_label}\nMean: {column_values.mean():.4f}")
        ax.set_xlabel("Weight")
        ax.set_ylabel("Frequency")
        ax.grid(True, alpha=0.3)

    plt.suptitle("EOS Driver Weight Distributions", fontsize=16, y=1.00)
    plt.tight_layout()
    plt.show()

    print("EOS Driver Weight Summary:")
    print(eos_driver_df[driver_cols].describe())

## Step 13: Merge with Coordinates and Check for Missing Data

In [None]:
# Merge SOS drivers with coordinates. The left join keeps every coordinate location.
# Locations without any SOS sample get all-NaN driver columns (imputed in Step 14).
# sos_driver_df has one row per SOS sample, so a location can appear several times.
sos_coord_driver_df = pd.merge(coords, sos_driver_df, on="location", how="left")

n_coord_locations = coords["location"].nunique()

print("SOS Data After Coordinate Merge:")
print(f"  Total locations in coords: {n_coord_locations}")
print(f"  SOS samples with driver data: {len(sos_driver_df)}")
print(f"  Locations with SOS driver data: {sos_driver_df['location'].nunique()}")
print(f"  Final merged DataFrame: {sos_coord_driver_df.shape}")

driver_cols = [
    "hist_tmin",
    "hist_tmax",
    "hist_rad",
    "hist_precip",
    "hist_photo",
    "hist_sm",
]
missing_mask = sos_coord_driver_df[driver_cols].isnull().all(axis=1)
# Count distinct locations (not merged rows, which repeat per sample)
n_missing = sos_coord_driver_df.loc[missing_mask, "location"].nunique()

print(f"\nMissing Data:")
print(f"  Locations with all drivers missing: {n_missing}")
print(f"  Percentage missing: {100 * n_missing / n_coord_locations:.1f}%")

print(f"\nFirst few rows:")
sos_coord_driver_df.head(10)

In [None]:
# Merge EOS drivers with coordinates. The left join keeps every coordinate location.
# Locations without any EOS sample get all-NaN driver columns (imputed in Step 14).
# eos_driver_df has one row per EOS sample, so a location can appear several times.
eos_coord_driver_df = pd.merge(coords, eos_driver_df, on="location", how="left")

n_coord_locations = coords["location"].nunique()

print("EOS Data After Coordinate Merge:")
print(f"  Total locations in coords: {n_coord_locations}")
print(f"  EOS samples with driver data: {len(eos_driver_df)}")
print(f"  Locations with EOS driver data: {eos_driver_df['location'].nunique()}")
print(f"  Final merged DataFrame: {eos_coord_driver_df.shape}")

# Defined here as well so this cell does not depend on the SOS cell
driver_cols = [
    "hist_tmin",
    "hist_tmax",
    "hist_rad",
    "hist_precip",
    "hist_photo",
    "hist_sm",
]
missing_mask = eos_coord_driver_df[driver_cols].isnull().all(axis=1)
# Count distinct locations (not merged rows, which repeat per sample)
n_missing = eos_coord_driver_df.loc[missing_mask, "location"].nunique()

print(f"\nMissing Data:")
print(f"  Locations with all drivers missing: {n_missing}")
print(f"  Percentage missing: {100 * n_missing / n_coord_locations:.1f}%")

print(f"\nFirst few rows:")
eos_coord_driver_df.head(10)

## Step 14: Spatial Imputation of Missing Values

Fill missing driver values using nearby spatial locations.

In [None]:
def impute_nearby_values(
    df: pd.DataFrame,
    lat_range: float = 0.5,
    lon_range: float = 0.5,
    max_distance: float = 2.0,
) -> pd.DataFrame:
    """
    Impute missing driver values from the nearest spatial neighbour that has data.

    Row roles:
    - A row is a *recipient* when all six ``hist_*`` driver columns are NaN.
    - A row is a *donor* when none of them are NaN.
    - Rows with only some drivers missing are neither and are left untouched.

    Each recipient receives the driver values of the closest donor, measured by
    Euclidean distance in degrees of (latitude, longitude). The donor must be within
    ``lat_range`` / ``lon_range`` and within ``max_distance``. Only donors from the
    input frame are used, so imputed rows never act as donors. Longitude wrap-around
    at ±180° is not handled.

    Args:
        df: Frame with ``latitude``, ``longitude`` and the ``hist_*`` driver columns.
        lat_range: Maximum absolute latitude difference to a donor.
        lon_range: Maximum absolute longitude difference to a donor.
        max_distance: Maximum Euclidean (lat, lon) distance to a donor.

    Returns:
        A copy of ``df`` with recipients filled where a donor was found. Returns
        ``df`` itself when nothing is missing.
    """
    driver_cols = [
        "hist_tmin",
        "hist_tmax",
        "hist_rad",
        "hist_precip",
        "hist_photo",
        "hist_sm",
    ]

    # Identify rows with missing data
    missing_mask = df[driver_cols].isnull().all(axis=1)
    n_missing = missing_mask.sum()

    if n_missing == 0:
        print("No missing values to impute")
        return df

    print(f"Found {n_missing} locations with missing data")

    df_result = df.copy()

    # Positional arrays avoid label lookups (the index may contain duplicates)
    lat_all = df["latitude"].to_numpy(dtype=float)
    lon_all = df["longitude"].to_numpy(dtype=float)
    donor_mask = df[driver_cols].notnull().all(axis=1).to_numpy()
    donor_lat = lat_all[donor_mask]
    donor_lon = lon_all[donor_mask]
    donor_values = df[driver_cols].to_numpy(dtype=float)[donor_mask]
    col_positions = [df_result.columns.get_loc(col) for col in driver_cols]

    if len(donor_values) > 0:
        for pos in np.flatnonzero(missing_mask.to_numpy()):
            lat_diff = np.abs(donor_lat - lat_all[pos])
            lon_diff = np.abs(donor_lon - lon_all[pos])
            dist = np.hypot(lat_diff, lon_diff)
            # NaN coordinates compare False everywhere, so they get no donor
            in_range = (lat_diff <= lat_range) & (lon_diff <= lon_range) & (
                dist <= max_distance
            )
            if not in_range.any():
                continue
            candidates = np.flatnonzero(in_range)
            nearest = candidates[np.argmin(dist[candidates])]
            df_result.iloc[pos, col_positions] = donor_values[nearest]

    # Count remaining missing values
    n_remaining = df_result[driver_cols].isnull().all(axis=1).sum()
    print(
        f"After imputation: {n_remaining} locations still missing ({n_missing - n_remaining} filled)"
    )

    return df_result


# Impute both phases. eos_imputed_df is required by the coverage plot, the quality
# check and the output step below.
print("Imputing SOS missing values...")
sos_imputed_df = impute_nearby_values(sos_coord_driver_df)

print("\nImputing EOS missing values...")
eos_imputed_df = impute_nearby_values(eos_coord_driver_df)

## Step 15: Visualize Spatial Coverage

In [None]:
# Visualize spatial coverage for SOS before and after nearby-location imputation
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Before imputation
has_data = ~sos_coord_driver_df[driver_cols].isnull().all(axis=1)
axes[0].scatter(
    sos_coord_driver_df.loc[has_data, "longitude"],
    sos_coord_driver_df.loc[has_data, "latitude"],
    c="green",
    s=10,
    alpha=0.5,
    label="Has data",
)
axes[0].scatter(
    sos_coord_driver_df.loc[~has_data, "longitude"],
    sos_coord_driver_df.loc[~has_data, "latitude"],
    c="red",
    s=10,
    alpha=0.5,
    label="Missing",
)
axes[0].set_title(f"SOS Before Imputation\n{has_data.sum()} locations with data")
axes[0].set_xlabel("Longitude")
axes[0].set_ylabel("Latitude")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# After imputation
has_data_imp = ~sos_imputed_df[driver_cols].isnull().all(axis=1)
axes[1].scatter(
    sos_imputed_df.loc[has_data_imp, "longitude"],
    sos_imputed_df.loc[has_data_imp, "latitude"],
    c="green",
    s=10,
    alpha=0.5,
    label="Has data",
)
axes[1].scatter(
    sos_imputed_df.loc[~has_data_imp, "longitude"],
    sos_imputed_df.loc[~has_data_imp, "latitude"],
    c="red",
    s=10,
    alpha=0.5,
    label="Missing",
)
axes[1].set_title(f"SOS After Imputation\n{has_data_imp.sum()} locations with data")
axes[1].set_xlabel("Longitude")
axes[1].set_ylabel("Latitude")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Spatial coverage of EOS driver data before (left) and after (right) imputation
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

coverage_panels = [
    (axes[0], eos_coord_driver_df, "EOS Before Imputation"),
    (axes[1], eos_imputed_df, "EOS After Imputation"),
]

for ax, frame, panel_title in coverage_panels:
    with_data = ~frame[driver_cols].isnull().all(axis=1)
    ax.scatter(
        frame.loc[with_data, "longitude"],
        frame.loc[with_data, "latitude"],
        c="brown",
        s=10,
        alpha=0.5,
        label="Has data",
    )
    ax.scatter(
        frame.loc[~with_data, "longitude"],
        frame.loc[~with_data, "latitude"],
        c="red",
        s=10,
        alpha=0.5,
        label="Missing",
    )
    ax.set_title(f"{panel_title}\n{with_data.sum()} locations with data")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Step 16: Final Data Quality Check

In [None]:
def _print_data_quality(phase, imputed):
    """Print completeness and per-driver missing counts for one imputed driver table."""
    n_rows = len(imputed)
    n_complete = (~imputed[driver_cols].isnull().any(axis=1)).sum()
    print(f"{phase} Data Quality:")
    print(f"  Total locations: {n_rows}")
    print(f"  Complete records: {n_complete} ({100*n_complete/n_rows:.1f}%)")
    print("\nMissing values per driver:")
    for driver_col in driver_cols:
        n_null = imputed[driver_col].isnull().sum()
        print(f"  {driver_col}: {n_null} ({100*n_null/n_rows:.1f}%)")


_print_data_quality("SOS", sos_imputed_df)

print("\n" + "=" * 50 + "\n")

_print_data_quality("EOS", eos_imputed_df)

## Step 17: Save Output Files

In [None]:
def _save_driver_table(frame, phase, leading=""):
    """Write one imputed driver table to ``{OUTPUT_BASE}_{phase}.csv`` and report it."""
    output_path = f"{OUTPUT_BASE}_{phase}.csv"
    frame.to_csv(output_path, index=False)
    print(f"{leading}✓ {phase} driver data saved to: {output_path}")
    print(f"  Shape: {frame.shape}")
    return output_path


sos_output_path = _save_driver_table(sos_imputed_df, "SOS")
eos_output_path = _save_driver_table(eos_imputed_df, "EOS", leading="\n")

banner = "=" * 50
print("\n" + banner)
print("Driver data extraction complete!")
print(banner)

## Summary

This notebook walked through the entire driver data extraction pipeline:

1. **Data Loading**: Loaded test data, predictions, and coordinates
2. **Data Preparation**: Created analysis DataFrame with temporal features
3. **Phenology Detection**: Used CSIF slope analysis to identify SOS and EOS samples
4. **Driver Extraction**: Used the attention scores to locate each sample's most-attended window, then took the median historical variable-selection weight of each climate driver within it
5. **Spatial Imputation**: Filled missing values using nearby locations
6. **Quality Checks**: Validated data completeness and spatial coverage
7. **Output**: Saved driver data to CSV files for mapping

The diagnostic outputs at each step help understand:
- Data quality and completeness
- Phenology detection performance
- Driver weight distributions
- Spatial coverage patterns
- Impact of imputation strategies