The original Parquet dataset contained field_id entries for only two Hungarian regions. Additional regions have now been incorporated, and modeling for Vas and Budapest was completed with a GRU model for zero-shot comparison. To support satellite mapping in the application, the data must be pre-processed to include the new fields and enriched with geographic coordinates.



In [None]:
!pip install -U --force-reinstall fsspec datasets

In [None]:
# -*- coding: utf-8 -*-
"""
Pre-processing & Feature Engineering Pipeline
Author: Samantha Lee (Updated with Coordinate Extraction)

This script:
1. Loads field polygons, NDVI, Weather, SMAP, and ERA5 data.
2. Aggregates all data to a daily field-level granularity.
3. EXTRACTS LAT/LON coordinates and appends them to the Field ID.
4. Cleans, imputes, and adds derived features (Growing Season, Stress).
"""

# INSTALL DEPENDENCIES
!pip install -q datasets geopandas fastparquet scikit-learn tqdm pandas numpy shapely matplotlib seaborn

# IMPORTS
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
from shapely import from_wkb
from datasets import load_dataset
from scipy.spatial import cKDTree
import numpy as np
import os
from tqdm import tqdm
import warnings
from sklearn.impute import KNNImputer, SimpleImputer
import matplotlib.pyplot as plt
import seaborn as sns

# Global notebook settings: silence warnings and configure plot styling.
warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (15, 6)})

# CONFIGURATION
# Local NDVI input and pipeline output locations.
RAW_NDVI_INPUT_FILE = 'ndvi_by_field.parquet'
OUTPUT_PARQUET_PATH = 'zeroshot_ground_truth_with_coords.parquet'
OUTPUT_CSV_PATH = 'zeroshot_ground_truth_with_coords.csv'

# Dataset repositories passed to `datasets.load_dataset`.
FIELD_DATA_REPO = 'jaelin215/eucropmap_2022_hungary_by_field_ids'
WEATHER_DATA_REPO = 'MWasil/Hungary_weather_data_2022_2025'
SMAP_DATA_REPO = 'jaelin215/SMAP_L3_SM_P'
ERA5_DATA_REPO = 'jmdu/ERA5-Land'

# Coordinate reference systems.
CRS = 'EPSG:4326'             # WGS84 latitude/longitude (storage and output)
PROJECTED_CRS = 'EPSG:23700'  # Hungarian EOV, metric (centroids and ID suffixes)

# HELPER FUNCTIONS

def build_spatial_index_by_date(gdf, date_col='date'):
    """Pre-build one KD-tree of point coordinates per date value.

    Rows with invalid or empty geometries are discarded first. The remaining
    rows are partitioned by ``date_col`` in a single ``groupby`` pass instead
    of re-scanning the whole frame with a boolean mask once per date, which
    was O(dates x rows) and dominated the runtime on large inputs.

    Args:
        gdf: GeoDataFrame of point geometries containing ``date_col``.
        date_col: Column whose values partition the rows.

    Returns:
        dict mapping each date value to ``{'tree': cKDTree, 'data': frame}``.
        The rows of ``data`` are in the same order as the points in ``tree``,
        so a tree query index can be used directly with ``data.iloc``.
        Returns an empty dict when ``gdf`` is empty.
    """
    if gdf.empty:
        return {}

    print(f"   Building spatial indices by date for {len(gdf)} records...")
    spatial_index = {}

    gdf = gdf[gdf.geometry.is_valid & ~gdf.geometry.is_empty]

    # sort=False yields groups in first-appearance order (like .unique()) and
    # keeps the original row order inside each group; NaN/NaT keys are dropped,
    # matching the old behaviour where such a mask selected no rows.
    for date, subset in tqdm(gdf.groupby(date_col, sort=False), desc="   Indexing"):
        if subset.empty:
            continue

        # Key by the same scalar type Series.unique() yields (numpy datetime64
        # for datetime columns) so lookups made with `.unique()` values hit.
        key = date.to_datetime64() if isinstance(date, pd.Timestamp) else date

        coords = np.vstack([subset.geometry.x, subset.geometry.y]).T
        spatial_index[key] = {
            'tree': cKDTree(coords),
            'data': subset.reset_index(drop=True)
        }
    print("   ...spatial indices built.")
    return spatial_index


def fast_nearest_join_by_date(source_gdf, target_spatial_index, target_cols, prefix):
    """Copy, per date, the columns of each source point's nearest target point.

    Args:
        source_gdf: GeoDataFrame with 'field_id', 'date' and point geometries.
        target_spatial_index: dict as produced by ``build_spatial_index_by_date``
            (date -> {'tree': cKDTree, 'data': frame}).
        target_cols: Target columns to copy onto the matched source rows.
        prefix: String prepended to the copied column names ('' or None for
            no prefix).

    Returns:
        DataFrame with 'field_id', 'date' and the copied columns for every
        source row whose date is indexed and whose point has finite
        coordinates. Returns an empty DataFrame if nothing matched.
    """
    prefix = prefix or ''
    out_cols = [f"{prefix}{col}" for col in target_cols]
    results = []

    for date in tqdm(source_gdf['date'].unique(), desc=f"   Joining {prefix or 'data'}"):
        entry = target_spatial_index.get(date)
        if entry is None or entry['data'].empty:
            continue

        source_subset = source_gdf[source_gdf['date'] == date]
        if source_subset.empty:
            continue

        source_coords = np.column_stack([source_subset.geometry.x, source_subset.geometry.y])
        # Missing or empty points have NaN coordinates, and a KD-tree query on
        # them gives no usable neighbour. Leave those rows unmatched; they come
        # back as NaN after the caller's left merge.
        finite = np.isfinite(source_coords).all(axis=1)
        if not finite.any():
            continue
        source_subset = source_subset[finite]
        _, indices = entry['tree'].query(source_coords[finite], k=1)

        # Gather all matched target rows once instead of once per column.
        matched = entry['data'].iloc[indices]
        joined = pd.DataFrame(source_subset[['field_id', 'date']]).reset_index(drop=True)
        for col, new_col in zip(target_cols, out_cols):
            joined[new_col] = matched[col].values
        results.append(joined)

    if not results:
        return pd.DataFrame()

    return pd.concat(results, ignore_index=True)

# PART 1: DATA LOADING

def load_field_data():
    """Load field polygons and derive one centroid point per field.

    Returns:
        GeoDataFrame with 'field_id', 'geometry' (polygons decoded from WKB and
        tagged with CRS) and 'centroid' (computed in PROJECTED_CRS, then
        converted back to CRS). Returns None if any step fails.
    """
    print("Loading field geometry data...")
    try:
        raw = load_dataset(FIELD_DATA_REPO, split='train').to_pandas()
        gdf = gpd.GeoDataFrame(raw)
        # Drop repeated column labels, keeping the first occurrence.
        gdf = gdf.loc[:, ~gdf.columns.duplicated()]

        # The geometry column is stored as WKB; decode it into shapely objects.
        gdf['geometry'] = gdf['geometry'].apply(from_wkb)
        gdf = gdf.set_geometry('geometry')
        gdf.set_crs(CRS, inplace=True)

        # Compute centroids in the metric CRS, then express them in lat/lon.
        projected = gdf.geometry.to_crs(PROJECTED_CRS)
        gdf['centroid'] = projected.centroid.to_crs(CRS)

        print(f"   Loaded {len(gdf):,} fields")
        return gdf[['field_id', 'geometry', 'centroid']].copy()
    except Exception as err:
        print(f"ERROR: {err}")
        return None

def load_pixel_data():
    """Load the pre-aggregated per-field NDVI table as one row per field-day.

    Reads RAW_NDVI_INPUT_FILE, which must provide 'dt_acquired', 'field_id'
    and 'ndvi_mean'. Acquisitions that fall on the same calendar day for a
    field are averaged. Field-days without any NDVI value are dropped.

    Returns:
        DataFrame sorted by field_id then date, with columns field_id, date,
        NDVI_mean, NDVI_min, NDVI_max, pixel_count and NDVI_std.
        NDVI_min/NDVI_max simply copy NDVI_mean, because the input holds only
        a mean per field. pixel_count/NDVI_std are NaN placeholders kept for
        schema compatibility. Returns None when the file is missing, a
        required column is absent, or reading fails.
    """
    print(f"Loading NDVI data from {RAW_NDVI_INPUT_FILE}...")

    if not os.path.exists(RAW_NDVI_INPUT_FILE):
        print("   ERROR: File not found!")
        return None

    try:
        pixel_df = pd.read_parquet(RAW_NDVI_INPUT_FILE)
        pixel_df = pixel_df.loc[:, ~pixel_df.columns.duplicated()]

        required = ['dt_acquired', 'field_id', 'ndvi_mean']
        missing = [col for col in required if col not in pixel_df.columns]
        if missing:
            print(f"   ERROR: Missing required columns: {missing}")
            return None

        pixel_df = pixel_df.rename(columns={'ndvi_mean': 'NDVI_value', 'dt_acquired': 'date'})
        pixel_df['date'] = pd.to_datetime(pixel_df['date']).dt.normalize()

        # Dates are already normalized to midnight, so grouping on the column
        # itself gives exactly one group per observed field-day.
        daily_pixel_df = (
            pixel_df.groupby(['field_id', 'date'])['NDVI_value']
            .agg(NDVI_mean='mean')
            .reset_index()
            .dropna(subset=['NDVI_mean'])
        )

        daily_pixel_df['NDVI_min'] = daily_pixel_df['NDVI_mean']
        daily_pixel_df['NDVI_max'] = daily_pixel_df['NDVI_mean']

        # Add placeholder columns for compatibility
        daily_pixel_df['pixel_count'] = np.nan
        daily_pixel_df['NDVI_std'] = np.nan

        # Logged after dropna so the count matches what is returned.
        print(f"   Processed {len(daily_pixel_df):,} records")
        return daily_pixel_df
    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def load_weather_data():
    """Load timestamped weather records and aggregate them per day and grid point.

    Daily aggregates per (date, lat, lon):
      - mean, min and max of temperature_2m;
      - summed precipitation;
      - mean relative humidity, wind speed and surface pressure.
    Two rolling features are then added per grid point: the precipitation sum
    and the mean daily temperature over the trailing 7 calendar days, with the
    current day included.

    Returns:
        GeoDataFrame of points in CRS, or None if anything fails.
    """
    print("Loading weather data...")
    try:
        weather_ds = load_dataset(WEATHER_DATA_REPO, split='train')
        weather_df = weather_ds.to_pandas()
        weather_df = weather_df.loc[:, ~weather_df.columns.duplicated()]

        if 'time' in weather_df.columns:
            weather_df.rename(columns={'time': 'datetime'}, inplace=True)

        weather_df['datetime'] = pd.to_datetime(weather_df['datetime'])
        weather_df['date'] = weather_df['datetime'].dt.normalize()

        print("   Aggregating to daily...")
        daily_weather = weather_df.groupby(['date', 'lat', 'lon']).agg(
            weather_temp_mean=('temperature_2m', 'mean'),
            weather_temp_min=('temperature_2m', 'min'),
            weather_temp_max=('temperature_2m', 'max'),
            weather_precip_sum=('precipitation', 'sum'),
            weather_humidity_mean=('relative_humidity_2m', 'mean'),
            weather_wind_mean=('wind_speed_10m', 'mean'),
            weather_pressure_mean=('surface_pressure', 'mean')
        ).reset_index()

        daily_weather = daily_weather.sort_values(by=['lat', 'lon', 'date'])

        # Use time-based ('7D') windows rather than a 7-row window, so a gap
        # in a grid point's daily record cannot stretch the window beyond 7
        # calendar days. With a complete daily record the result is identical.
        # Dates are sorted within each (lat, lon) group, as rolling requires.
        by_point = daily_weather.set_index('date').groupby(['lat', 'lon'])
        daily_weather['weather_precip_7d'] = by_point['weather_precip_sum'].transform(
            lambda s: s.rolling('7D', min_periods=1).sum()
        ).to_numpy()
        daily_weather['weather_temp_7d_mean'] = by_point['weather_temp_mean'].transform(
            lambda s: s.rolling('7D', min_periods=1).mean()
        ).to_numpy()

        weather_gdf = gpd.GeoDataFrame(
            daily_weather,
            geometry=gpd.points_from_xy(daily_weather.lon, daily_weather.lat),
            crs=CRS
        )
        print(f"   Loaded {len(weather_gdf):,} daily weather point-records")
        return weather_gdf

    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def load_smap_data():
    """Load SMAP soil-moisture cells (2022-2025) and average them per day and grid point.

    The timestamp is read from the first of 'dt', 'date' or 'time' that is
    present and normalized to midnight. Rows whose timestamp cannot be parsed
    are discarded.

    Returns:
        GeoDataFrame of points in CRS with date, lat, lon and the four smap_*
        columns, or None if anything fails.
    """
    print("Loading SMAP soil data...")
    try:
        yearly_files = [
            "smap_hungary_2022_cells.parquet",
            "smap_hungary_2023_cells.parquet",
            "smap_hungary_2024_cells.parquet",
            "smap_hungary_2025_cells.parquet",
        ]
        smap_ds = load_dataset(SMAP_DATA_REPO, data_files={"train": yearly_files}, split='train')
        smap_df = smap_ds.to_pandas()
        smap_df = smap_df.loc[:, ~smap_df.columns.duplicated()]

        candidates = [name for name in ('dt', 'date', 'time') if name in smap_df.columns]
        if not candidates:
            raise ValueError("No date column found in SMAP")
        source_col = candidates[0]

        parsed = pd.to_datetime(smap_df[source_col], errors='coerce')
        smap_df['date'] = parsed.dt.normalize()
        smap_df = smap_df.dropna(subset=['date'])

        grouped = smap_df.groupby(['date', 'lat', 'lon'])
        smap_daily = grouped.agg(
            smap_soil_moisture=('soil_moisture', 'mean'),
            smap_surface_temp_C=('surface_temp_C', 'mean'),
            smap_veg_water=('veg_water', 'mean'),
            # 'first' rather than a mean: presumably constant per cell.
            smap_clay_fraction=('clay_fraction', 'first'),
        ).reset_index()

        points = gpd.points_from_xy(smap_daily['lon'], smap_daily['lat'])
        smap_gdf = gpd.GeoDataFrame(smap_daily, geometry=points, crs=CRS)
        print(f"   Loaded {len(smap_gdf):,} daily SMAP point-records")
        return smap_gdf

    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def load_era5_data():
    """Load daily ERA5 features and derive growing-degree-day (GDD) columns.

    Source columns are renamed with an ``era5_`` prefix. When
    ``era5_2m_temperature`` is present:
      - it is converted from Kelvin to Celsius if its mean exceeds 200;
      - ``era5_gdd_base5`` is (re)computed as max(0, T - 5), staying NaN
        where T is missing.
    ``era5_gdd_cumsum`` is the running sum of daily GDD per
    (latitude, longitude) in date order, with missing days counted as 0. It is
    only added when a GDD column exists.

    Returns:
        GeoDataFrame of points in CRS, or None if anything fails.
    """
    print("Loading ERA5 data...")
    try:
        data_files = {
            "train": [
                "features_daily_2015-2025.parquet",
                "features_daily_2015-2025_updated_2025-09-12.parquet"
            ]
        }
        era5_ds = load_dataset(ERA5_DATA_REPO, data_files=data_files, split='train')
        era5_df = era5_ds.to_pandas()
        era5_df = era5_df.loc[:, ~era5_df.columns.duplicated()]

        era5_df['date'] = pd.to_datetime(era5_df['date']).dt.normalize()

        rename_map = {
            '2m_temperature': 'era5_2m_temperature',
            'total_precipitation': 'era5_total_precipitation',
            'gdd_base5': 'era5_gdd_base5',
            'volumetric_soil_water_layer_1': 'era5_volumetric_soil_water_layer_1',
            'sw_root': 'era5_sw_root',
            'surface_solar_radiation_downwards': 'era5_surface_solar_radiation_downwards'
        }
        era5_df = era5_df.rename(columns=rename_map)

        if 'era5_2m_temperature' in era5_df.columns:
            # Unit heuristic: a mean above 200 is taken to mean Kelvin.
            if era5_df['era5_2m_temperature'].mean() > 200:
                era5_df['era5_2m_temperature'] = era5_df['era5_2m_temperature'] - 273.15
            # Vectorized max(0, T - 5). Unlike a per-row Python max(0, ...),
            # a missing temperature stays NaN instead of becoming 0 GDD.
            era5_df['era5_gdd_base5'] = (era5_df['era5_2m_temperature'] - 5).clip(lower=0)

        era5_df = era5_df.sort_values(by=['latitude', 'longitude', 'date'])
        if 'era5_gdd_base5' in era5_df.columns:
            # NOTE(review): the running sum never resets (no per-year/season
            # grouping), so its level depends on the earliest loaded date --
            # confirm this is intended.
            era5_df['era5_gdd_cumsum'] = (
                era5_df['era5_gdd_base5']
                .fillna(0)
                .groupby([era5_df['latitude'], era5_df['longitude']])
                .cumsum()
            )

        era5_gdf = gpd.GeoDataFrame(
            era5_df,
            geometry=gpd.points_from_xy(era5_df.longitude, era5_df.latitude),
            crs=CRS
        )
        print(f"   Loaded {len(era5_gdf):,} daily ERA5 point-records")
        return era5_gdf

    except Exception as e:
        print(f"   ERROR: {e}")
        return None

# PART 2: AGGREGATION

def run_aggregation():
    """Build the daily field-level table and attach coordinates.

    Steps:
      1. Load every source.
      2. Keep NDVI fields that have exactly one polygon with a usable centroid.
      3. Create one row per (field, day) across the NDVI date span.
      4. Extract lat/lon (CRS) and PROJECTED_CRS coordinates from each centroid.
      5. Nearest-neighbour join weather/SMAP/ERA5 per day, then merge NDVI.
      6. Replace ``field_id`` with ``<field_id>|<+x><+y>``, keeping the plain
         ID in ``original_field_id``.
      7. Add calendar features.

    Returns:
        pandas DataFrame without a geometry column. Returns None if field
        geometry or NDVI could not be loaded, or no NDVI field matches a polygon.
    """
    print("\n" + "="*80)
    print("PART 1: AGGREGATING DATA (LSTM-Style)")
    print("="*80)

    fields_gdf = load_field_data()
    if fields_gdf is None:
        return None

    daily_pixel_df = load_pixel_data()
    if daily_pixel_df is None:
        return None

    # Auxiliary sources are optional. A failed load becomes an empty frame,
    # which yields no spatial index and produces no joined rows.
    weather_gdf = load_weather_data()
    if weather_gdf is None:
        print("   WARNING: Weather data failed load, creating empty.")
        weather_gdf = gpd.GeoDataFrame()

    smap_gdf = load_smap_data()
    if smap_gdf is None:
        print("   WARNING: SMAP data failed load, creating empty.")
        smap_gdf = gpd.GeoDataFrame()

    era5_gdf = load_era5_data()
    if era5_gdf is None:
        print("   WARNING: ERA5 data failed load, creating empty.")
        era5_gdf = gpd.GeoDataFrame()

    # Keep one polygon per field with a usable centroid. A missing centroid
    # gives NaN coordinates, which int() cannot turn into an ID suffix.
    # Duplicate IDs would multiply rows in every (field_id, date) merge.
    has_centroid = fields_gdf['centroid'].notna() & ~fields_gdf['centroid'].is_empty
    fields_gdf = fields_gdf[has_centroid]
    n_polygons = len(fields_gdf)
    fields_gdf = fields_gdf.drop_duplicates(subset='field_id')
    if len(fields_gdf) < n_polygons:
        print(f"   WARNING: Dropped {n_polygons - len(fields_gdf):,} duplicate field polygons.")

    ndvi_field_ids = daily_pixel_df['field_id'].unique()
    mappable = pd.Index(ndvi_field_ids).isin(fields_gdf['field_id'])
    if not mappable.all():
        print(f"   WARNING: {int((~mappable).sum()):,} NDVI fields have no usable polygon; skipping them.")
    active_field_ids = ndvi_field_ids[mappable]
    if len(active_field_ids) == 0:
        print("   ERROR: No NDVI field matches a field polygon.")
        return None

    fields_gdf = fields_gdf[fields_gdf['field_id'].isin(active_field_ids)]
    daily_pixel_df = daily_pixel_df[daily_pixel_df['field_id'].isin(active_field_ids)]
    print(f"\nProcessing {len(active_field_ids):,} fields with NDVI data")

    # Daily calendar covering the NDVI observation period.
    all_dates = pd.date_range(
        start=daily_pixel_df['date'].min(),
        end=daily_pixel_df['date'].max(),
        freq='D'
    )

    print("   Creating field-date matrix...")
    baseline_df = pd.MultiIndex.from_product(
        [active_field_ids, all_dates],
        names=['field_id', 'date']
    ).to_frame(index=False)

    baseline_gdf = baseline_df.merge(
        fields_gdf[['field_id', 'centroid']],
        on='field_id',
        how='left'
    )
    # The centroid becomes the active geometry used for the spatial joins.
    baseline_gdf = gpd.GeoDataFrame(baseline_gdf, geometry='centroid', crs=CRS)

    print("\n   Extracting coordinates from field centroids...")
    # 1. Lat/Lon in CRS (WGS84)
    baseline_gdf['latitude'] = baseline_gdf.geometry.y
    baseline_gdf['longitude'] = baseline_gdf.geometry.x

    # 2. Projected (EOV) coordinates used to build the ID suffix
    projected = baseline_gdf.geometry.to_crs(PROJECTED_CRS)
    baseline_gdf['proj_x'] = projected.x
    baseline_gdf['proj_y'] = projected.y

    # New ID: '<field_id>|<+x><+y>', where x/y are the projected centroid
    # coordinates truncated to whole units. All centroids are finite here.
    baseline_gdf['field_id_with_coords'] = [
        f"{fid}|{int(x):+}{int(y):+}"
        for fid, x, y in zip(baseline_gdf['field_id'], baseline_gdf['proj_x'], baseline_gdf['proj_y'])
    ]

    print("\n   Building spatial indices for fast joins...")
    weather_index = build_spatial_index_by_date(weather_gdf)
    smap_index = build_spatial_index_by_date(smap_gdf) if not smap_gdf.empty else {}
    era5_index = build_spatial_index_by_date(era5_gdf) if not era5_gdf.empty else {}

    print("\n   Performing spatial joins...")
    weather_cols = ['weather_temp_mean', 'weather_temp_min', 'weather_temp_max',
                    'weather_precip_sum', 'weather_humidity_mean', 'weather_wind_mean',
                    'weather_pressure_mean', 'weather_precip_7d', 'weather_temp_7d_mean']
    weather_joined = fast_nearest_join_by_date(baseline_gdf, weather_index, weather_cols, '')

    smap_cols = ['smap_soil_moisture', 'smap_surface_temp_C', 'smap_veg_water', 'smap_clay_fraction']
    smap_joined = fast_nearest_join_by_date(baseline_gdf, smap_index, smap_cols, '')

    era5_cols = [col for col in era5_gdf.columns if col.startswith('era5_')]
    era5_joined = fast_nearest_join_by_date(baseline_gdf, era5_index, era5_cols, '')

    print("\n   Merging all datasets...")
    # Start from the baseline, which carries the new IDs and coordinates.
    final_df = pd.DataFrame(baseline_gdf[['field_id', 'date', 'field_id_with_coords', 'latitude', 'longitude']])

    for joined in (weather_joined, smap_joined, era5_joined):
        if not joined.empty:
            final_df = final_df.merge(joined, on=['field_id', 'date'], how='left')

    final_df = final_df.merge(daily_pixel_df, on=['field_id', 'date'], how='left')
    print("   NDVI data merged")

    # Swap IDs: expose the coordinate-suffixed ID, keep the original alongside.
    final_df['original_field_id'] = final_df['field_id']
    final_df['field_id'] = final_df['field_id_with_coords']
    final_df.drop(columns=['field_id_with_coords'], inplace=True)
    print("   Field IDs updated with coordinates.")

    # Temporal features
    final_df['day_of_year'] = final_df['date'].dt.dayofyear
    final_df['month'] = final_df['date'].dt.month
    final_df['week'] = final_df['date'].dt.isocalendar().week

    return final_df

# PART 3: CLEANING & IMPUTATION

def _fill_within_field(df, col):
    """Forward- then backward-fill ``col`` inside each field_id group only."""
    filled = df.groupby('field_id')[col].ffill()
    return filled.groupby(df['field_id']).bfill()


def clean_and_impute_data(df):
    """Clean the aggregated daily table while preserving the coordinate columns.

    Steps:
        1. Drop helper/placeholder columns (NDVI_std, pixel_count, centroid,
           geometry, original_field_id) when present.
        2. Sort by (field_id, date). Fill NDVI gaps per field by linear
           interpolation, then forward/backward fill.
        3. Fill weather_/smap_/era5_ columns per field by forward then backward
           fill. Values still missing (a field with no data at all for that
           column) get the column's dataset-wide median.
        4. Remove rows whose NDVI_mean is missing or outside [-1, 1].
        5. Add 0/1 flags:
           - is_growing_season: month 5-9;
           - heat_stress: weather_temp_max > 35;
           - cold_stress: weather_temp_min < 5;
           - drought_stress: smap_soil_moisture < 0.2 and weather_precip_7d < 10.

    All filling happens within a single field; values never leak between fields.

    Args:
        df: Output of ``run_aggregation``; may be None or empty.

    Returns:
        The cleaned DataFrame, or None if ``df`` is None or empty.
    """
    print("\n" + "="*80)
    print("PART 2: DATA CLEANING PIPELINE")
    print("="*80)

    if df is None or df.empty:
        print("   ERROR: Input DataFrame is empty.")
        return None

    # --- STEP 1: DROP USELESS COLUMNS ---
    print("STEP 1: DROPPING USELESS COLUMNS")
    # Note: 'latitude' and 'longitude' are deliberately kept.
    columns_to_drop = ['NDVI_std', 'pixel_count', 'centroid', 'geometry', 'original_field_id']
    cols_to_drop_present = [col for col in columns_to_drop if col in df.columns]
    df = df.drop(columns=cols_to_drop_present, errors='ignore')
    print(f"  Dropped: {cols_to_drop_present}")

    # --- STEP 2: INTERPOLATE NDVI ---
    print("STEP 2: INTERPOLATING NDVI VALUES")
    df = df.sort_values(['field_id', 'date']).reset_index(drop=True)
    ndvi_cols = ['NDVI_mean', 'NDVI_min', 'NDVI_max']

    for col in [c for c in ndvi_cols if c in df.columns]:
        df[col] = df.groupby('field_id')[col].transform(
            lambda x: x.interpolate(method='linear', limit_direction='both')
        )
        # Per-field fill; a plain Series.bfill() would borrow the next field's values.
        df[col] = _fill_within_field(df, col)

    # --- STEP 3: IMPUTE WEATHER ---
    print("STEP 3: IMPUTING WEATHER DATA")
    weather_cols = [c for c in df.columns if any(x in c for x in ['weather_', 'smap_', 'era5_'])]

    for col in weather_cols:
        df[col] = _fill_within_field(df, col)
        # Fields with no value at all for this column fall back to the dataset median.
        if df[col].isnull().any():
            df[col] = df[col].fillna(df[col].median())

    # --- STEP 4: REMOVE OUTLIERS ---
    print("STEP 4: REMOVING OUTLIERS")
    if 'NDVI_mean' in df.columns:
        # NaN compares False, so rows with missing NDVI are removed as well.
        mask = (df['NDVI_mean'] >= -1) & (df['NDVI_mean'] <= 1)
        df = df[mask]

    # --- STEP 5: DERIVED FEATURES ---
    print("STEP 5: ADDING DERIVED FEATURES")
    if 'month' in df.columns:
        df['is_growing_season'] = ((df['month'] >= 5) & (df['month'] <= 9)).astype(int)

    if 'weather_temp_max' in df.columns:
        df['heat_stress'] = (df['weather_temp_max'] > 35).astype(int)

    if 'weather_temp_min' in df.columns:
        df['cold_stress'] = (df['weather_temp_min'] < 5).astype(int)

    if 'smap_soil_moisture' in df.columns and 'weather_precip_7d' in df.columns:
        # Simple drought index
        df['drought_stress'] = ((df['smap_soil_moisture'] < 0.2) & (df['weather_precip_7d'] < 10)).astype(int)

    return df

# MAIN EXECUTION

if __name__ == "__main__":
    print("Starting Pre-processing Pipeline...")

    # 1. Aggregation
    raw_df = run_aggregation()

    if raw_df is None:
        print("Aggregation failed; nothing to clean or save.")
    else:
        # 2. Cleaning
        cleaned_df = clean_and_impute_data(raw_df)

        # 3. Save
        print("\n" + "="*80)
        print("SAVING RESULTS")
        print("="*80)

        if cleaned_df is None:
            print("Cleaning produced no data; nothing saved.")
        else:
            print(f"Saving to {OUTPUT_PARQUET_PATH}...")
            cleaned_df.to_parquet(OUTPUT_PARQUET_PATH, index=False)

            # OUTPUT_CSV_PATH is configured next to the Parquet path but was
            # otherwise never written; save the same table there too.
            print(f"Saving to {OUTPUT_CSV_PATH}...")
            cleaned_df.to_csv(OUTPUT_CSV_PATH, index=False)

            print("Done!")
            print(cleaned_df.head(3))

Starting Pre-processing Pipeline...

PART 1: AGGREGATING DATA (LSTM-Style)
Loading field geometry data...
   Loaded 112,526 fields
Loading NDVI data from ndvi_by_field.parquet...
   Processed 1,172 records
Loading weather data...
   Aggregating to daily...
   Loaded 74,304 daily weather point-records
Loading SMAP soil data...
   Loaded 72,609 daily SMAP point-records
Loading ERA5 data...


features_daily_2015-2025.parquet:   0%|          | 0.00/270M [00:00<?, ?B/s]

features_daily_2015-2025_updated_2025-09(…):   0%|          | 0.00/273M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

   Loaded 15,164,696 daily ERA5 point-records

Processing 24 fields with NDVI data
   Creating field-date matrix...

   Extracting coordinates from field centroids...

   Building spatial indices for fast joins...
   Building spatial indices by date for 74304 records...


   Indexing: 100%|██████████| 1376/1376 [00:06<00:00, 216.94it/s]


   ...spatial indices built.
   Building spatial indices by date for 72609 records...


   Indexing: 100%|██████████| 1289/1289 [00:02<00:00, 567.02it/s]


   ...spatial indices built.
   Building spatial indices by date for 15164696 records...


   Indexing: 100%|██████████| 3908/3908 [04:59<00:00, 13.04it/s]


   ...spatial indices built.

   Performing spatial joins...


   Joining data: 100%|██████████| 242/242 [00:02<00:00, 111.47it/s]
   Joining data: 100%|██████████| 242/242 [00:01<00:00, 202.43it/s]
   Joining data: 100%|██████████| 242/242 [00:01<00:00, 139.44it/s]



   Merging all datasets...
   NDVI data merged
   Field IDs updated with coordinates.

PART 2: DATA CLEANING PIPELINE
STEP 1: DROPPING USELESS COLUMNS
  Dropped: ['NDVI_std', 'pixel_count', 'original_field_id']
STEP 2: INTERPOLATING NDVI VALUES
STEP 3: IMPUTING WEATHER DATA
STEP 4: REMOVING OUTLIERS
STEP 5: ADDING DERIVED FEATURES

SAVING RESULTS
Saving to zeroshot_ground_truth_with_coords.parquet...
Done!
                                            field_id       date   latitude  \
0  vtx|Baranya|rapeseed|0x0|+232274+269464|+59167... 2025-03-02  46.307969   
1  vtx|Baranya|rapeseed|0x0|+232274+269464|+59167... 2025-03-03  46.307969   
2  vtx|Baranya|rapeseed|0x0|+232274+269464|+59167... 2025-03-04  46.307969   

   longitude  weather_temp_mean  weather_temp_min  weather_temp_max  \
0  18.290345           3.320833              -0.7               8.3   
1  18.290345           3.008333              -2.5               9.9   
2  18.290345           4.120833              -2.3              