In [None]:
# -*- coding: utf-8 -*-
#COLAB PREPROCESSING

# INSTALL DEPENDENCIES

!pip install -q datasets geopandas fastparquet scikit-learn tqdm pandas numpy shapely matplotlib seaborn


# CLEAR CACHE

print("Clearing Hugging Face datasets cache...")
!rm -rf /root/.cache/huggingface/datasets


# IMPORTS

import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
from shapely import from_wkb
from datasets import load_dataset
from scipy.spatial import cKDTree
import numpy as np
import os
from tqdm import tqdm
import warnings
from sklearn.impute import KNNImputer, SimpleImputer
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import seaborn as sns

# Silence every Python warning for the rest of the session. Note that this
# also hides deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')
# Default look and size for any matplotlib/seaborn figure drawn afterwards.
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (15, 6)

# CONFIGURATION
# Local parquet file with per-field NDVI observations (read by load_pixel_data).
RAW_NDVI_INPUT_FILE = 'ndvi_by_field.parquet'
# Output file names. NOTE(review): the code that writes them is not visible in
# this part of the file -- presumably main(); confirm.
OUTPUT_PARQUET_PATH = 'zeroshot_ground_truth.parquet'
OUTPUT_CSV_PATH = 'zeroshot_ground_truth.csv'

# Dataset repository ids handed to datasets.load_dataset() by the loaders below.
FIELD_DATA_REPO = 'jaelin215/eucropmap_2022_hungary_by_field_ids'
WEATHER_DATA_REPO = 'MWasil/Hungary_weather_data_2022_2025'
SMAP_DATA_REPO = 'jaelin215/SMAP_L3_SM_P'
ERA5_DATA_REPO = 'jmdu/ERA5-Land'

# CRS assigned to the field polygons and to every lon/lat point layer.
CRS = 'EPSG:4326'
# Projected CRS the field polygons are temporarily converted to while their
# centroids are computed (see load_field_data).
PROJECTED_CRS = 'EPSG:23700'

# HELPER FUNCTIONS (From 01_LSTM notebook)


def build_spatial_index_by_date(gdf, date_col='date'):
    """Build one KD-tree per distinct date so joins never have to rebuild them.

    Args:
        gdf: GeoDataFrame of point geometries carrying a date column.
        date_col: Name of the column holding the date of each record.

    Returns:
        dict mapping each date to ``{'tree': cKDTree, 'data': GeoDataFrame}``,
        where ``data`` holds that date's rows with a fresh 0..n-1 index, so
        positions returned by ``tree.query`` can be used with ``data.iloc``.
    """
    print(f"   Building spatial indices by date for {len(gdf)} records...")

    # Rows with an invalid or empty geometry have no usable coordinates.
    usable = gdf.geometry.is_valid & ~gdf.geometry.is_empty
    gdf = gdf[usable]

    index_by_date = {}
    for day in tqdm(gdf[date_col].unique(), desc="   Indexing"):
        day_rows = gdf[gdf[date_col] == day]
        if day_rows.empty:
            continue

        # One (x, y) row per record, in the same order as day_rows.
        xy = np.column_stack([day_rows.geometry.x, day_rows.geometry.y])
        index_by_date[day] = {
            'tree': cKDTree(xy),
            'data': day_rows.reset_index(drop=True),
        }

    print("   ...spatial indices built.")
    return index_by_date


def fast_nearest_join_by_date(source_gdf, target_spatial_index, target_cols, prefix):
    """Attach, per date, the values of the nearest target point to each source row.

    Args:
        source_gdf: GeoDataFrame with ``field_id`` and ``date`` columns and an
            active point geometry (one point per row).
        target_spatial_index: Mapping produced by ``build_spatial_index_by_date``
            (date -> ``{'tree': cKDTree, 'data': GeoDataFrame}``).
        target_cols: Names of the target columns to copy over.
        prefix: String prepended to every copied column name; an empty string
            keeps the original names.

    Returns:
        DataFrame with ``field_id``, ``date`` and the copied columns. Dates that
        are absent from the index, and source rows without usable coordinates,
        produce no output rows. An empty DataFrame (no columns) is returned when
        nothing could be joined.
    """
    out_cols = [f"{prefix}{col}" if prefix else col for col in target_cols]
    results = []

    for date in tqdm(source_gdf['date'].unique(), desc=f"   Joining {prefix if prefix else 'data'}"):
        entry = target_spatial_index.get(date)
        if entry is None or entry['data'].empty:
            continue

        source_subset = source_gdf[source_gdf['date'] == date]
        if source_subset.empty:
            continue

        source_coords = np.column_stack([source_subset.geometry.x, source_subset.geometry.y])

        # Missing/empty geometries yield NaN coordinates, which the KD-tree
        # query cannot handle; leave those rows out instead of failing the join.
        locatable = np.isfinite(source_coords).all(axis=1)
        if not locatable.all():
            source_subset = source_subset[locatable]
            source_coords = source_coords[locatable]
        if source_subset.empty:
            continue

        # Only the positions of the nearest neighbours are needed.
        _, nearest = entry['tree'].query(source_coords, k=1)
        matched = entry['data'].iloc[nearest]

        joined = source_subset[['field_id', 'date']].copy()
        for out_col, col in zip(out_cols, target_cols):
            # Positional copy: `matched` is ordered like `source_subset`.
            joined[out_col] = matched[col].to_numpy()
        results.append(joined)

    if not results:
        return pd.DataFrame()

    return pd.concat(results, ignore_index=True)

# PART 1: DATA AGGREGATION (Based on 01_LSTM)

def load_field_data():
    """Fetch the field polygons and derive one centroid per field.

    Returns:
        GeoDataFrame with ``field_id``, ``geometry`` and ``centroid`` columns,
        or ``None`` when loading or decoding fails (the error is printed).
    """
    print("Loading field geometry data...")
    try:
        raw_fields = load_dataset(FIELD_DATA_REPO, split='train').to_pandas()
        fields_gdf = gpd.GeoDataFrame(raw_fields)
        keep_first = ~fields_gdf.columns.duplicated()
        fields_gdf = fields_gdf.loc[:, keep_first]

        # Geometries arrive WKB-encoded; decode them before registering the
        # column as the active geometry.
        fields_gdf['geometry'] = fields_gdf['geometry'].apply(from_wkb)
        fields_gdf = fields_gdf.set_geometry('geometry')
        fields_gdf.set_crs(CRS, inplace=True)

        # Centroids are computed in the projected CRS and converted back.
        projected = fields_gdf.geometry.to_crs(PROJECTED_CRS)
        fields_gdf['centroid'] = projected.centroid.to_crs(CRS)

        print(f"   Loaded {len(fields_gdf):,} fields")
        wanted = ['field_id', 'geometry', 'centroid']
        return fields_gdf[wanted].copy()
    except Exception as e:
        print(f"ERROR: {e}")
        return None

def load_pixel_data():
    """Load per-field NDVI observations and reduce them to one row per field-day.

    Reads ``RAW_NDVI_INPUT_FILE``, which must provide the columns
    ``dt_acquired``, ``field_id`` and ``ndvi_mean``.

    Returns:
        DataFrame with columns ``field_id``, ``date``, ``NDVI_mean``,
        ``NDVI_min``, ``NDVI_max``, ``pixel_count`` and ``NDVI_std`` holding only
        rows with a non-missing ``NDVI_mean``; ``None`` when the file is
        missing, lacks a required column, or cannot be processed.
    """
    print(f"Loading NDVI data from {RAW_NDVI_INPUT_FILE}...")

    if not os.path.exists(RAW_NDVI_INPUT_FILE):
        print(f"   ERROR: File not found!")
        return None

    try:
        pixel_df = pd.read_parquet(RAW_NDVI_INPUT_FILE)
        pixel_df = pixel_df.loc[:, ~pixel_df.columns.duplicated()]

        required = ['dt_acquired', 'field_id', 'ndvi_mean']
        missing = [col for col in required if col not in pixel_df.columns]
        if missing:
            print(f"   ERROR: Missing required columns: {missing}")
            return None

        # Work on the required columns only, so that renaming cannot collide
        # with a pre-existing 'date' or 'NDVI_value' column in the input.
        pixel_df = pixel_df[required].rename(
            columns={'ndvi_mean': 'NDVI_value', 'dt_acquired': 'date'}
        )
        pixel_df['date'] = pd.to_datetime(pixel_df['date']).dt.normalize()

        # Dates are already truncated to midnight, so grouping on the column
        # itself yields one row per field and calendar day.
        daily_pixel_df = (
            pixel_df.groupby(['field_id', 'date'])['NDVI_value']
            .mean()
            .rename('NDVI_mean')
            .reset_index()
        )
        daily_pixel_df = daily_pixel_df.dropna(subset=['NDVI_mean']).reset_index(drop=True)

        # The input only carries a per-field mean, so min/max mirror it.
        daily_pixel_df['NDVI_min'] = daily_pixel_df['NDVI_mean']
        daily_pixel_df['NDVI_max'] = daily_pixel_df['NDVI_mean']

        # Add pixel_count and NDVI_std for compatibility with 02_cleaning script
        daily_pixel_df['pixel_count'] = np.nan  # This script starts from mean, so no count
        daily_pixel_df['NDVI_std'] = np.nan     # This script starts from mean, so no std

        # Reported after dropping NaN rows so the count matches what is returned.
        print(f"   Processed {len(daily_pixel_df):,} records")
        return daily_pixel_df
    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def load_weather_data():
    """Load the weather records and aggregate them to one row per point and day.

    Besides the daily aggregates, two trailing 7-day features are added per
    (lat, lon) point: ``weather_precip_7d`` (sum) and ``weather_temp_7d_mean``
    (mean). The window is time based, so a point with missing days never
    reaches further back than seven calendar days.

    Returns:
        GeoDataFrame of daily point records with a lon/lat point geometry in
        ``CRS``, or ``None`` when loading or processing fails (the error is
        printed).
    """
    print("Loading weather data...")
    try:
        weather_ds = load_dataset(WEATHER_DATA_REPO, split='train')
        weather_df = weather_ds.to_pandas()
        weather_df = weather_df.loc[:, ~weather_df.columns.duplicated()]

        # Only rename when it cannot create a second 'datetime' column.
        if 'time' in weather_df.columns and 'datetime' not in weather_df.columns:
            weather_df = weather_df.rename(columns={'time': 'datetime'})

        weather_df['datetime'] = pd.to_datetime(weather_df['datetime'])
        weather_df['date'] = weather_df['datetime'].dt.normalize()

        print("   Aggregating to daily...")
        daily_weather = weather_df.groupby(['date', 'lat', 'lon']).agg(
            weather_temp_mean=('temperature_2m', 'mean'),
            weather_temp_min=('temperature_2m', 'min'),
            weather_temp_max=('temperature_2m', 'max'),
            weather_precip_sum=('precipitation', 'sum'),
            weather_humidity_mean=('relative_humidity_2m', 'mean'),
            weather_wind_mean=('wind_speed_10m', 'mean'),
            weather_pressure_mean=('surface_pressure', 'mean')
        ).reset_index()

        # Add 7-day rolling features here (as done in 02_lstm_cleaning).
        # Sorting makes the dates increase within every point, which the
        # offset-based window requires.
        daily_weather = daily_weather.sort_values(by=['lat', 'lon', 'date'])
        per_point = daily_weather.set_index('date').groupby(['lat', 'lon'])

        # transform() returns rows in the order of the frame it was built from,
        # which is exactly the order of daily_weather, so the values are
        # written back positionally.
        daily_weather['weather_precip_7d'] = per_point['weather_precip_sum'].transform(
            lambda s: s.rolling('7D', min_periods=1).sum()
        ).to_numpy()
        daily_weather['weather_temp_7d_mean'] = per_point['weather_temp_mean'].transform(
            lambda s: s.rolling('7D', min_periods=1).mean()
        ).to_numpy()

        weather_gdf = gpd.GeoDataFrame(
            daily_weather,
            geometry=gpd.points_from_xy(daily_weather.lon, daily_weather.lat),
            crs=CRS
        )
        print(f"   Loaded {len(weather_gdf):,} daily weather point-records")
        return weather_gdf

    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def load_smap_data():
    """Load the SMAP cell files and collapse them to one row per point and day.

    Returns:
        GeoDataFrame of daily SMAP point records with a lon/lat point geometry
        in ``CRS``, or ``None`` when loading or processing fails (the error is
        printed).
    """
    print("Loading SMAP soil data...")
    try:
        yearly_files = [
            "smap_hungary_2022_cells.parquet",
            "smap_hungary_2023_cells.parquet",
            "smap_hungary_2024_cells.parquet",
            "smap_hungary_2025_cells.parquet",
        ]
        smap_ds = load_dataset(SMAP_DATA_REPO, data_files={"train": yearly_files}, split='train')
        smap_df = smap_ds.to_pandas()
        smap_df = smap_df.loc[:, ~smap_df.columns.duplicated()]

        # The first of these candidate names that exists is used as the date.
        date_col = next(
            (name for name in ('dt', 'date', 'time') if name in smap_df.columns),
            None,
        )
        if date_col is None:
            raise ValueError("No date column found in SMAP")

        # Unparseable dates become NaT and are discarded.
        parsed_dates = pd.to_datetime(smap_df[date_col], errors='coerce')
        smap_df['date'] = parsed_dates.dt.normalize()
        smap_df = smap_df.dropna(subset=['date'])

        print("   Aggregating AM/PM to daily...")
        daily_aggregations = {
            'smap_soil_moisture': ('soil_moisture', 'mean'),
            'smap_surface_temp_C': ('surface_temp_C', 'mean'),
            'smap_veg_water': ('veg_water', 'mean'),
            'smap_clay_fraction': ('clay_fraction', 'first'),
        }
        grouped = smap_df.groupby(['date', 'lat', 'lon'])
        smap_daily = grouped.agg(**daily_aggregations).reset_index()

        points = gpd.points_from_xy(smap_daily.lon, smap_daily.lat)
        smap_gdf = gpd.GeoDataFrame(smap_daily, geometry=points, crs=CRS)
        print(f"   Loaded {len(smap_gdf):,} daily SMAP point-records")
        return smap_gdf

    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def load_era5_data():
    """Load the ERA5 daily features as point records.

    Known feature columns are renamed with an ``era5_`` prefix. When a 2 m
    temperature column is present it is converted to Celsius if its mean
    exceeds 200, and ``era5_gdd_base5`` is recomputed from it as
    ``max(0, temperature - 5)``. ``era5_gdd_cumsum`` is the running sum of
    ``era5_gdd_base5`` per (latitude, longitude) in date order.

    Returns:
        GeoDataFrame of daily ERA5 point records with a lon/lat point geometry
        in ``CRS``, or ``None`` when loading or processing fails (the error is
        printed).
    """
    print("Loading ERA5 data...")
    try:
        data_files = {
            "train": [
                "features_daily_2015-2025.parquet",
                "features_daily_2015-2025_updated_2025-09-12.parquet"
            ]
        }
        era5_ds = load_dataset(ERA5_DATA_REPO, data_files=data_files, split='train')
        era5_df = era5_ds.to_pandas()
        era5_df = era5_df.loc[:, ~era5_df.columns.duplicated()]

        era5_df['date'] = pd.to_datetime(era5_df['date']).dt.normalize()

        # Two files are concatenated; a cell/date present in both would be
        # counted twice in the cumulative sum below. Keep the later row.
        # NOTE(review): assumes rows keep the order of `data_files`, so the
        # later row comes from the "updated" file -- confirm.
        rows_before = len(era5_df)
        era5_df = era5_df.drop_duplicates(subset=['latitude', 'longitude', 'date'], keep='last')
        if len(era5_df) < rows_before:
            print(f"   Dropped {rows_before - len(era5_df):,} duplicate cell-date records")

        rename_map = {
            '2m_temperature': 'era5_2m_temperature',
            'total_precipitation': 'era5_total_precipitation',
            'gdd_base5': 'era5_gdd_base5',
            'volumetric_soil_water_layer_1': 'era5_volumetric_soil_water_layer_1',
            'sw_root': 'era5_sw_root',
            'surface_solar_radiation_downwards': 'era5_surface_solar_radiation_downwards'
        }
        era5_df = era5_df.rename(columns=rename_map)

        if 'era5_2m_temperature' in era5_df.columns:
            # A mean above 200 is taken to mean the values are in Kelvin.
            if era5_df['era5_2m_temperature'].mean() > 200:
                era5_df['era5_2m_temperature'] = era5_df['era5_2m_temperature'] - 273.15
            # Vectorised max(0, t - 5); a missing temperature contributes 0,
            # exactly as the element-wise max() did.
            era5_df['era5_gdd_base5'] = (
                (era5_df['era5_2m_temperature'] - 5).clip(lower=0).fillna(0)
            )

        # Add cumulative GDD (as done in 02_lstm_cleaning)
        # NOTE(review): the sum runs over the whole loaded period and is not
        # reset per year -- confirm this matches the intended feature.
        era5_df = era5_df.sort_values(by=['latitude', 'longitude', 'date'])
        era5_df['era5_gdd_cumsum'] = (
            era5_df['era5_gdd_base5']
            .fillna(0)
            .groupby([era5_df['latitude'], era5_df['longitude']])
            .cumsum()
        )

        era5_gdf = gpd.GeoDataFrame(
            era5_df,
            geometry=gpd.points_from_xy(era5_df.longitude, era5_df.latitude),
            crs=CRS
        )
        print(f"   Loaded {len(era5_gdf):,} daily ERA5 point-records")
        return era5_gdf

    except Exception as e:
        print(f"   ERROR: {e}")
        return None

def run_aggregation():
    """Build the field x day table and attach weather, SMAP, ERA5 and NDVI data.

    A baseline with one row per (field, calendar day) is created over the date
    range covered by the NDVI data. For every day, each field centroid is
    matched to the nearest weather / SMAP / ERA5 point of that same day, and the
    NDVI observations are merged on (field_id, date). Finally ``day_of_year``,
    ``month`` and ``week`` columns are added.

    Only fields that have both NDVI data and a usable centroid are processed;
    a field without a location cannot be matched to any point data.

    Returns:
        The merged DataFrame, or ``None`` when the field geometries, the NDVI
        data or the weather data cannot be loaded, or when no field remains.
        SMAP and ERA5 are optional: a failure only prints a warning.
    """
    print("\n" + "="*80)
    print("PART 1: AGGREGATING DATA (LSTM-Style)")
    print("="*80)

    fields_gdf = load_field_data()
    if fields_gdf is None:
        return None

    daily_pixel_df = load_pixel_data()
    if daily_pixel_df is None:
        return None
    if daily_pixel_df.empty:
        # Without any observation there is no date range to build on.
        print("   ERROR: NDVI data contains no usable records")
        return None

    weather_gdf = load_weather_data()
    if weather_gdf is None:
        return None

    smap_gdf = load_smap_data()
    if smap_gdf is None:
        print("   WARNING: SMAP data failed, continuing without it")
        smap_gdf = gpd.GeoDataFrame()

    era5_gdf = load_era5_data()
    if era5_gdf is None:
        print("   WARNING: ERA5 data failed, continuing without it")
        era5_gdf = gpd.GeoDataFrame()

    # Keep fields that have NDVI data AND a usable centroid, one row per field
    # (a duplicated field_id would multiply baseline rows in the merge below).
    ndvi_field_ids = daily_pixel_df['field_id'].unique()
    fields_gdf = fields_gdf[fields_gdf['field_id'].isin(ndvi_field_ids)]
    centroids = gpd.GeoSeries(fields_gdf['centroid'])
    fields_gdf = fields_gdf[centroids.notna() & ~centroids.is_empty]
    fields_gdf = fields_gdf.drop_duplicates(subset='field_id')
    active_field_ids = fields_gdf['field_id'].unique()

    unlocated = len(ndvi_field_ids) - len(active_field_ids)
    if unlocated > 0:
        print(f"   WARNING: skipping {unlocated:,} NDVI fields without a usable centroid")
    if len(active_field_ids) == 0:
        print("   ERROR: No NDVI fields could be matched to field geometries")
        return None

    daily_pixel_df = daily_pixel_df[daily_pixel_df['field_id'].isin(active_field_ids)]
    print(f"\nProcessing {len(active_field_ids):,} fields with NDVI data")

    # Series.min()/max() return Timestamps regardless of what unique() yields.
    min_date = daily_pixel_df['date'].min()
    max_date = daily_pixel_df['date'].max()
    print(f"   Date range: {min_date.date()} to {max_date.date()}")

    all_dates = pd.date_range(start=min_date, end=max_date, freq='D')

    print("   Creating field-date matrix...")
    baseline_df = pd.MultiIndex.from_product(
        [active_field_ids, all_dates],
        names=['field_id', 'date']
    ).to_frame(index=False)

    baseline_gdf = baseline_df.merge(
        fields_gdf[['field_id', 'centroid']],
        on='field_id',
        how='left'
    )
    baseline_gdf = gpd.GeoDataFrame(baseline_gdf, geometry='centroid', crs=CRS)
    print(f"   Baseline: {len(baseline_gdf):,} field-date combinations")

    print("\n   Building spatial indices for fast joins...")
    weather_index = build_spatial_index_by_date(weather_gdf)
    smap_index = build_spatial_index_by_date(smap_gdf) if not smap_gdf.empty else {}
    era5_index = build_spatial_index_by_date(era5_gdf) if not era5_gdf.empty else {}

    print("\n   Performing spatial joins...")
    # The source columns already carry their weather_/smap_/era5_ prefix, so
    # the joins are run with an empty prefix to avoid doubling it.
    weather_cols = ['weather_temp_mean', 'weather_temp_min', 'weather_temp_max',
                    'weather_precip_sum', 'weather_humidity_mean', 'weather_wind_mean',
                    'weather_pressure_mean', 'weather_precip_7d', 'weather_temp_7d_mean']
    weather_joined = fast_nearest_join_by_date(baseline_gdf, weather_index, weather_cols, '')

    smap_cols = ['smap_soil_moisture', 'smap_surface_temp_C', 'smap_veg_water', 'smap_clay_fraction']
    smap_joined = fast_nearest_join_by_date(baseline_gdf, smap_index, smap_cols, '')

    era5_cols = [col for col in era5_gdf.columns if col.startswith('era5_')]
    era5_joined = fast_nearest_join_by_date(baseline_gdf, era5_index, era5_cols, '')

    print("\n   Merging all datasets...")
    final_df = baseline_df.copy()

    if not weather_joined.empty:
        final_df = final_df.merge(weather_joined, on=['field_id', 'date'], how='left')
        print("   Weather data merged")

    if not smap_joined.empty:
        final_df = final_df.merge(smap_joined, on=['field_id', 'date'], how='left')
        print("   SMAP data merged")

    if not era5_joined.empty:
        final_df = final_df.merge(era5_joined, on=['field_id', 'date'], how='left')
        print("   ERA5 data merged")

    final_df = final_df.merge(daily_pixel_df, on=['field_id', 'date'], how='left')
    print("   NDVI data merged")

    # Add temporal features (from 02_lstm_cleaning logic)
    final_df['day_of_year'] = final_df['date'].dt.dayofyear
    final_df['month'] = final_df['date'].dt.month
    final_df['week'] = final_df['date'].dt.isocalendar().week
    print("   Temporal features (day, month, week) added")

    print(f"\nAggregation complete. Shape: {final_df.shape}")
    print(f"Columns: {final_df.columns.tolist()}")
    return final_df

# PART 2: CLEANING & FEATURES (NEW - From 02_LSTM)

# Human-readable description of the cleaning pipeline. The text is printed
# verbatim by print_cleaning_methodology(); nothing in it is executed, and the
# figures it quotes (missing percentages, memory sizes, expected removal
# counts) are static text, not values computed from the data being cleaned.
# NOTE(review): STEP 3 speaks of the "field's median", whereas
# clean_and_impute_data fills the remaining gaps with a column-wide median;
# STEP 2 items 4-5 describe forward fill "at start" and backward fill "at end",
# which reads reversed (a leading gap can only be back-filled). Confirm the
# intended wording before changing this user-facing text.
CLEANING_STEPS = """
 STEP 1: DROP USELESS COLUMNS
─────────────────────────────────────────────────────────────────────────────
Problem: Some columns have no useful information
  • NDVI_std: 100% missing (0 non-null values)
    → This happens when only 1 pixel per field (no variation to compute std)
  • pixel_count: 94% missing (only 5.9% have values)
    → Not necessary for crop health prediction

Solution: Remove these columns entirely
  → Reduces dimensionality without losing information
  → Saves memory and computation time

STEP 2: INTERPOLATE MISSING NDVI VALUES
─────────────────────────────────────────────────────────────────────────────
Problem: NDVI has gaps due to cloudy days (satellites can't see through clouds)
  Example:
    Day 1: NDVI = 0.50
    Day 5: NDVI = NaN (cloudy)
    Day 10: NDVI = 0.60

Solution: Linear interpolation per field
  → Fills Day 5: NDVI = 0.55 (halfway between 0.50 and 0.60)

Why this works:
  • Plant growth is continuous and smooth
  • NDVI changes gradually over days/weeks
  • Field-specific: each field has its own growth pattern
  • Biological validity: no sudden jumps in vegetation

Method: Per-field interpolation
  1. Group data by field_id
  2. Sort by date within each field
  3. Use linear interpolation between known values
  4. Forward fill at start (use first valid value)
  5. Backward fill at end (use last valid value)

STEP 3: IMPUTE WEATHER & CLIMATE DATA
─────────────────────────────────────────────────────────────────────────────
Problem: Some weather/climate data may have gaps (sensor failures, etc.)

Solution: Three-tier imputation strategy
  1. Forward Fill (primary): Use previous day's value
     → Weather changes slowly, yesterday's value is a good proxy

  2. Backward Fill (backup): Use next day's value if no previous
     → Better than leaving NaN

  3. Median Imputation (last resort): Use field's median value
     → Statistically sound fallback

Applied to:
  • Temperature (mean, min, max)
  • Precipitation, humidity, wind, pressure
  • Soil moisture (SMAP)
  • Climate variables (ERA5)

STEP 4: REMOVE OUTLIERS
─────────────────────────────────────────────────────────────────────────────
Problem: Sensor errors, data corruption, or extreme anomalies

Solution: Domain-based filtering
  • NDVI: Must be in [-1, 1] (physical constraint of NDVI formula)
  • Temperature: Must be in [-30°C, 50°C] (Hungary climate range)
  • Precipitation: Must be in [0, 200] mm/day (physical maximum)
  • Soil moisture: Must be in [0, 1] (saturation limit)

Why this matters:
  → Outliers cause unstable LSTM training
  → Gradient explosion from extreme values
  → Model learns noise instead of signal

Expected removal: < 0.1% of data (typically 100-500 rows)

STEP 5: OPTIMIZE DATA TYPES
─────────────────────────────────────────────────────────────────────────────
Problem: Pandas defaults to float64, int64 (8 bytes per value)

Solution: Downcast to smaller types
  • float64 → float32 (50% memory savings, sufficient precision)
  • int64 → int8/int16 where appropriate (87.5-75% savings)
    - month: int8 (max value = 12)
    - day_of_year: int16 (max value = 366)

Benefits:
  → 50% less memory usage (293 MB → 150 MB)
  → Faster data loading and processing
  → More sequences fit in GPU memory

STEP 6: ADD DERIVED FEATURES
─────────────────────────────────────────────────────────────────────────────
Problem: Raw data doesn't explicitly capture critical crop events

Solution: Engineer domain-specific indicators

1. Growing Season Flag
   is_growing_season = (month >= 5) & (month <= 9)
   → Sunflowers grow May-September in Hungary
   → Binary flag: 1 during growth, 0 otherwise

2. Heat Stress
   heat_stress = (temp_max > 35°C)
   → Sunflowers suffer above 35°C
   → Reduces photosynthesis and yield

3. Cold Stress
   cold_stress = (temp_min < 5°C)
   → Damages young plants in spring
   → Can kill flowers in late season

4. Drought Stress
   drought_stress = (soil_moisture < 0.2) & (precip_7d < 10mm)
   → Low soil water + no recent rain = stress
   → Most critical factor for sunflower health

Why add these?
  → Encodes agronomic knowledge
  → Easier for LSTM to learn binary flags
  → Improves model interpretability
  → Faster convergence during training

STEP 7: VALIDATION CHECKS
─────────────────────────────────────────────────────────────────────────────
Final quality assurance:
  1. No NaN values remaining (LSTM requirement)
  2. No infinite values (causes training crashes)
  3. No duplicate field-date combinations (data integrity)
  4. All values within valid ranges (sanity check)
  5. Sequence lengths adequate for LSTM (min 30-50 observations)
  6. Data types optimized (memory efficiency)
  7. Column names consistent (no special characters)

If any check fails → Fix automatically or report to user
"""

def print_cleaning_methodology():
    """Write the CLEANING_STEPS methodology text to standard output."""
    methodology_text = CLEANING_STEPS
    print(methodology_text)


def _drop_out_of_range(df, column, lower, upper, label):
    """Keep only rows whose `column` lies in [lower, upper] and report removals.

    Rows with a missing value in `column` fail the range test and are removed as
    well; they are included in the reported count. Frames without `column` are
    returned unchanged.
    """
    if column not in df.columns:
        return df
    in_range = df[column].between(lower, upper)
    removed = int((~in_range).sum())
    if removed > 0:
        print(f"  • {label}: removed {removed:,} rows")
    return df[in_range]


def clean_and_impute_data(df):
    """Clean the aggregated field x day table and add derived features.

    Steps, each logged to standard output:
      1. Drop the NDVI_std, pixel_count, centroid and geometry columns.
      2. Fill NDVI gaps per field (linear interpolation, then per-field
         forward/backward fill). Fields without any NDVI value stay NaN.
      3. Fill weather/SMAP/ERA5 gaps per field (forward then backward fill),
         then with the column-wide median. Columns without a single observed
         value are dropped, since nothing can be imputed from them.
      4. Remove rows outside the physical ranges for NDVI, temperature,
         precipitation and soil moisture (rows still missing one of these
         values are removed too).
      5. Downcast float64 to float32 and the calendar columns to small ints.
      6. Add is_growing_season, heat_stress, cold_stress and drought_stress.
      7. Report remaining NaNs, repair infinite values, drop duplicate
         (field_id, date) rows and report per-field sequence lengths.

    Args:
        df: DataFrame with at least ``field_id`` and ``date`` columns, as
            produced by ``run_aggregation``. It is not modified.

    Returns:
        The cleaned DataFrame, or ``None`` when the input is ``None``/empty or
        no row survives the range filtering.
    """

    print("\n" + "="*80)
    print(" "*25 + "PART 2: DATA CLEANING PIPELINE (LSTM-Style)")
    print("="*80)

    print_cleaning_methodology()

    print("\n" + "="*80)
    print(" "*25 + "STARTING CLEANING PROCESS")
    print("="*80)

    if df is None or df.empty:
        print("   ERROR: Input DataFrame is empty. Cannot clean.")
        return None

    initial_rows = len(df)
    initial_memory = df.memory_usage(deep=True).sum() / 1e6
    print(f"  Initial shape: {df.shape[0]:,} rows × {df.shape[1]} columns")
    print(f"  Memory usage: {initial_memory:.1f} MB")

    # --- STEP 1: DROP USELESS COLUMNS ---
    print("\n" + "─" * 80)
    print("STEP 1: DROPPING USELESS COLUMNS")
    print("─" * 80)

    columns_to_drop = ['NDVI_std', 'pixel_count', 'centroid', 'geometry']
    cols_to_drop_present = [col for col in columns_to_drop if col in df.columns]

    for col in cols_to_drop_present:
        missing_pct = (df[col].isnull().sum() / len(df)) * 100
        print(f"  • {col}: {missing_pct:.1f}% missing → DROPPED")

    # drop() returns a new frame, so the caller's DataFrame stays untouched.
    df = df.drop(columns=cols_to_drop_present, errors='ignore')
    print(f"\n  New shape: {df.shape[0]:,} rows × {df.shape[1]} columns")

    # --- MISSING-DATA REPORT (informational only) ---
    print("\n" + "─" * 80)
    print("ANALYZING MISSING DATA BEFORE CLEANING")
    print("─" * 80)
    missing_summary = df.isnull().sum()
    missing_pct = (missing_summary / len(df) * 100).round(2)
    missing_df = pd.DataFrame({
        'Column': missing_summary.index,
        'Missing Count': missing_summary.values,
        'Missing %': missing_pct.values
    }).sort_values('Missing %', ascending=False)
    cols_with_missing = missing_df[missing_df['Missing Count'] > 0]
    if len(cols_with_missing) > 0:
        print("  Columns with missing values:")
        for _, row in cols_with_missing.iterrows():
            print(f"    • {row['Column']}: {row['Missing Count']:,} ({row['Missing %']:.2f}%)")
    else:
        print("  No missing values found!")

    # --- STEP 2: INTERPOLATE NDVI ---
    print("\n" + "─" * 80)
    print("STEP 2: INTERPOLATING NDVI VALUES")
    print("─" * 80)

    # Every per-field fill below relies on rows being in date order.
    df = df.sort_values(['field_id', 'date']).reset_index(drop=True)

    ndvi_cols = [col for col in ['NDVI_mean', 'NDVI_min', 'NDVI_max'] if col in df.columns]

    for col in ndvi_cols:
        before_missing = df[col].isnull().sum()
        if before_missing == 0:
            print(f"  • {col}: No missing values.")
            continue

        print(f"  • {col}: {before_missing:,} missing values")
        df[col] = df.groupby('field_id')[col].transform(
            lambda x: x.interpolate(method='linear', limit_direction='both')
        )
        after_missing = df[col].isnull().sum()
        filled = before_missing - after_missing
        print(f"    → Filled {filled:,} values via linear interpolation")
        if after_missing > 0:
            print(f"    → Forward/backward filling {after_missing:,} remaining gaps")
            # Both fills are done per field: an ungrouped bfill() would copy
            # the next field's NDVI into a field that has no data of its own.
            df[col] = df.groupby('field_id')[col].ffill()
            df[col] = df.groupby('field_id')[col].bfill()
            unresolved = df[col].isnull().sum()
            if unresolved > 0:
                print(f"    → {unresolved:,} values belong to fields without any valid {col}; left as NaN")
    print(f"\n  NDVI interpolation complete")

    # --- STEP 3: IMPUTE WEATHER/CLIMATE DATA ---
    print("\n" + "─" * 80)
    print("STEP 3: IMPUTING WEATHER & CLIMATE DATA")
    print("─" * 80)

    continuous_cols = [
        'weather_temp_mean', 'weather_temp_min', 'weather_temp_max',
        'weather_precip_sum', 'weather_humidity_mean', 'weather_wind_mean',
        'weather_pressure_mean', 'smap_soil_moisture', 'smap_surface_temp_C',
        'smap_veg_water', 'smap_clay_fraction', 'era5_2m_temperature',
        'era5_total_precipitation', 'era5_gdd_base5',
        'era5_volumetric_soil_water_layer_1', 'era5_sw_root',
        'era5_surface_solar_radiation_downwards', 'era5_gdd_cumsum',
        'weather_precip_7d', 'weather_temp_7d_mean'
    ]
    continuous_cols = [c for c in continuous_cols if c in df.columns]

    total_filled = 0
    for col in continuous_cols:
        before_missing = df[col].isnull().sum()
        if before_missing > 0:
            df[col] = df.groupby('field_id')[col].ffill()
            df[col] = df.groupby('field_id')[col].bfill()
            filled = before_missing - df[col].isnull().sum()
            if filled > 0:
                total_filled += filled
                print(f"  • {col}: filled {filled:,} values")

    # A column without a single observed value has no median to fall back on
    # and would otherwise wipe out every row in the range filter below.
    empty_cols = [c for c in continuous_cols if df[c].isnull().all()]
    if empty_cols:
        print(f"  WARNING: dropping columns with no observed values: {empty_cols}")
        df = df.drop(columns=empty_cols)
        continuous_cols = [c for c in continuous_cols if c not in empty_cols]

    remaining_missing = int(df[continuous_cols].isnull().sum().sum())
    if remaining_missing > 0:
        print(f"\n {remaining_missing:,} values still missing, using median imputation...")
        # Column-wide medians computed with pandas keep the column set intact.
        df[continuous_cols] = df[continuous_cols].fillna(df[continuous_cols].median())
        print(f"  Filled {remaining_missing:,} values with median")
        total_filled += remaining_missing
    print(f"\n  Total values imputed: {total_filled:,}")

    # --- STEP 4: REMOVE OUTLIERS ---
    print("\n" + "─" * 80)
    print("STEP 4: REMOVING OUTLIERS")
    print("─" * 80)
    before_outlier = len(df)

    df = _drop_out_of_range(df, 'NDVI_mean', -1, 1, "NDVI outside [-1, 1]")
    df = _drop_out_of_range(df, 'weather_temp_mean', -30, 50, "Temperature outside [-30°C, 50°C]")
    df = _drop_out_of_range(df, 'weather_precip_sum', 0, 200, "Precipitation outside [0, 200] mm")
    df = _drop_out_of_range(df, 'smap_soil_moisture', 0, 1, "Soil moisture outside [0, 1]")
    # Detach from the filtered parent so the column writes below are not made
    # on a slice.
    df = df.copy()

    total_removed = before_outlier - len(df)
    if total_removed > 0:
        print(f"\n  Removed {total_removed:,} outlier rows ({total_removed/before_outlier*100:.3f}%)")
    else:
        print(f"\n  No outliers detected")

    if df.empty:
        print("   ERROR: No rows left after outlier removal. Cannot continue.")
        return None

    # --- STEP 5: OPTIMIZE DTYPES ---
    print("\n" + "─" * 80)
    print("STEP 5: OPTIMIZING DATA TYPES")
    print("─" * 80)
    memory_before = df.memory_usage(deep=True).sum() / 1e6

    float64_cols = df.select_dtypes(include=['float64']).columns
    if len(float64_cols) > 0:
        df[float64_cols] = df[float64_cols].astype('float32')
        print(f"  • Converted {len(float64_cols)} columns: float64 → float32")

    if 'month' in df.columns:
        df['month'] = df['month'].astype('int8')
        print(f"  • Converted month: int64 → int8")

    if 'week' in df.columns:
        df['week'] = df['week'].astype('int8')
        print(f"  • Converted week: int64 → int8")

    if 'day_of_year' in df.columns:
        df['day_of_year'] = df['day_of_year'].astype('int16')
        print(f"  • Converted day_of_year: int64 → int16")

    memory_after = df.memory_usage(deep=True).sum() / 1e6
    memory_saved = memory_before - memory_after
    print(f"\n  Memory: {memory_before:.1f} MB → {memory_after:.1f} MB (Saved: {memory_saved:.1f} MB)")

    # --- STEP 6: ADD DERIVED FEATURES ---
    print("\n" + "─" * 80)
    print("STEP 6: ADDING DERIVED FEATURES")
    print("─" * 80)

    if 'month' in df.columns:
        df['is_growing_season'] = ((df['month'] >= 5) & (df['month'] <= 9)).astype('int8')
        print(f"  • is_growing_season: {df['is_growing_season'].sum():,} rows in May-Sept")
    else:
        print("  • 'month' column not found, skipping 'is_growing_season'")

    if 'weather_temp_max' in df.columns:
        df['heat_stress'] = (df['weather_temp_max'] > 35).astype('int8')
        print(f"  • heat_stress: {df['heat_stress'].sum():,} days with temp > 35°C")
    else:
        print("  • 'weather_temp_max' column not found, skipping 'heat_stress'")

    if 'weather_temp_min' in df.columns:
        df['cold_stress'] = (df['weather_temp_min'] < 5).astype('int8')
        print(f"  • cold_stress: {df['cold_stress'].sum():,} days with temp < 5°C")
    else:
        print("  • 'weather_temp_min' column not found, skipping 'cold_stress'")

    if 'smap_soil_moisture' in df.columns and 'weather_precip_7d' in df.columns:
        df['drought_stress'] = ((df['smap_soil_moisture'] < 0.2) & (df['weather_precip_7d'] < 10)).astype('int8')
        print(f"  • drought_stress: {df['drought_stress'].sum():,} days with low moisture & no rain")
    else:
        print("  • 'smap_soil_moisture' or 'weather_precip_7d' not found, skipping 'drought_stress'")
    print(f"\n  Added derived features")

    # --- STEP 7: VALIDATION ---
    print("\n" + "─" * 80)
    print("STEP 7: FINAL VALIDATION CHECKS")
    print("─" * 80)

    total_nans = df.isnull().sum().sum()
    if total_nans > 0:
        print(f" WARNING: {total_nans:,} NaN values remaining!")
    else:
        print(f"  No missing values (NaNs: 0)")

    # Only floating-point columns can hold infinities; integer flags, calendar
    # columns and ids are left alone so their values and dtypes survive.
    floating_cols = df.select_dtypes(include=[np.floating]).columns
    inf_mask = df[floating_cols].isin([np.inf, -np.inf])
    inf_count = int(inf_mask.sum().sum())
    if inf_count > 0:
        print(f" WARNING: {inf_count:,} infinite values found!")
        df[floating_cols] = df[floating_cols].mask(inf_mask)
        print("  Replaced infinites with NaN. Now filling with median...")
        df[floating_cols] = df[floating_cols].fillna(df[floating_cols].median())
        print(f"  Replaced infinites with median")
    else:
        print(f"  No infinite values (Infs: 0)")

    duplicates = df.duplicated(subset=['field_id', 'date']).sum()
    if duplicates > 0:
        print(f"Removing {duplicates:,} duplicate field-date combinations")
        df = df.drop_duplicates(subset=['field_id', 'date'], keep='first')
    else:
        print(f"  No duplicates")

    seq_lengths = df.groupby('field_id').size()
    print(f"\n  Sequence lengths per field:")
    print(f"    • Min: {seq_lengths.min()} | Max: {seq_lengths.max()} | Mean: {seq_lengths.mean():.1f}")
    min_required = 30
    short_fields = seq_lengths[seq_lengths < min_required]
    if len(short_fields) > 0:
        print(f" {len(short_fields)} fields have < {min_required} observations")
    else:
        print(f"    All fields have >= {min_required} observations")

    print("\n" + "="*80)
    print(" "*25 + "CLEANING SUMMARY")
    print("="*80)
    final_rows = len(df)
    final_memory = df.memory_usage(deep=True).sum() / 1e6
    print(f"\nDATASET CHANGES:")
    print(f"  Rows:    {initial_rows:,} → {final_rows:,} ({final_rows-initial_rows:+,})")
    print(f"  Columns: {len(df.columns)}")
    print(f"  Memory:  {initial_memory:.1f} MB → {final_memory:.1f} MB ({final_memory/initial_memory*100:.1f}%)")
    print(f"\nDATA QUALITY:")
    print(f"  Fields:      {df['field_id'].nunique()}")
    print(f"  Date range:  {df['date'].min().date()} to {df['date'].max().date()}")
    print(f"  NaN values:  {df.isnull().sum().sum()}")

    return df


def display_cleaned_data_info(df):
    """Print a human-readable report about the cleaned dataset.

    Writes to stdout, in order: the ``DataFrame.info`` report (with deep
    memory usage), the first 20 rows, and ``describe()`` statistics for
    whichever of the key feature columns are present in ``df``.

    Args:
        df: The cleaned DataFrame, or ``None``. ``None`` and empty frames
            produce a short notice instead of a report.

    Returns:
        None. ``df`` is not modified, and the pandas display options are
        restored to their previous values before returning.
    """
    if df is None or df.empty:
        print("DataFrame is empty. Nothing to display.")
        return

    print("\n" + "="*80)
    print(" "*20 + "CLEANED DATASET INFORMATION")
    print("="*80)
    print("\n DATAFRAME INFO:")
    print("─" * 80)
    df.info(memory_usage='deep')

    print("\n" + "="*80)
    print(" "*20 + "FIRST 20 ROWS OF CLEANED DATA")
    print("="*80)

    # Copy only the rows that are actually printed; copying the whole frame
    # just to reformat one column of a 20-row preview is wasted memory.
    preview = df.head(20).copy()
    # Render dates as YYYY-MM-DD, but only when the column really is
    # datetime-like, so a missing or string-typed column cannot crash
    # a purely informational report.
    if 'date' in preview.columns and pd.api.types.is_datetime64_any_dtype(preview['date']):
        preview['date'] = preview['date'].dt.strftime('%Y-%m-%d')

    key_cols = ['NDVI_mean', 'weather_temp_mean', 'weather_precip_sum',
                'smap_soil_moisture', 'era5_gdd_base5', 'era5_gdd_cumsum']
    key_cols = [c for c in key_cols if c in df.columns]

    # Scope the wide-display settings to this report instead of changing
    # the global pandas options for everything that runs afterwards.
    with pd.option_context('display.max_columns', None,
                           'display.width', None,
                           'display.max_colwidth', 50):
        print("\n" + preview.to_string(index=True))

        print("\n" + "="*80)
        print(" "*25 + "STATISTICAL SUMMARY")
        print("="*80)
        if key_cols:
            print("\nKey Feature Statistics:")
            print(df[key_cols].describe().round(3).to_string())
        else:
            print("\nNo key statistic columns found to display.")

# MAIN

def main():
    """Run the merged preprocessing pipeline from raw NDVI input to saved files.

    Stages, in order: check that the raw NDVI file exists, aggregate,
    clean, print the dataset report, then write the result as Parquet and
    CSV. When the input file is missing, a stage returns ``None``, the
    cleaned frame is empty, or writing the files raises, a message is
    printed and the run ends early. Always returns ``None``.
    """
    rule = "=" * 80

    def banner(title, pad, lead=""):
        # Section header: rule line, title indented by `pad` spaces, rule line.
        print(lead + rule)
        print(" " * pad + title)
        print(rule)

    banner("COLAB PREPROCESSING - MERGED PIPELINE", 15)

    input_present = os.path.exists(RAW_NDVI_INPUT_FILE)
    if not input_present:
        print(f"FATAL ERROR: Input file not found: {RAW_NDVI_INPUT_FILE}")
        print("Please upload your NDVI parquet file to the session.")
        return
    print(f"Found input file: {RAW_NDVI_INPUT_FILE}")

    # Stage 1: aggregation
    aggregated = run_aggregation()
    if aggregated is None:
        print("\nPipeline halted during aggregation.")
        return

    # Stage 2: cleaning and feature preparation
    cleaned = clean_and_impute_data(aggregated)
    if cleaned is None:
        print("\nPipeline halted during cleaning.")
        return

    # Stage 3: console report
    display_cleaned_data_info(cleaned)

    # Stage 4: persist results
    banner("SAVING OUTPUTS", 30, lead="\n")
    if cleaned.empty:
        print("   Final DataFrame is empty. Nothing to save.")
        return

    print("\n Writing files...")
    try:
        print("  • Saving Parquet...")
        cleaned.to_parquet(OUTPUT_PARQUET_PATH, index=False, compression='snappy')
        size_bytes = os.path.getsize(OUTPUT_PARQUET_PATH)
        print(f"    {OUTPUT_PARQUET_PATH} (Size: {size_bytes / 1_000_000:.2f} MB)")

        print("  • Saving CSV...")
        # The CSV gets dates as plain YYYY-MM-DD text; the cleaned frame
        # itself keeps its datetime column.
        csv_frame = cleaned.assign(date=cleaned['date'].dt.strftime('%Y-%m-%d'))
        csv_frame.to_csv(OUTPUT_CSV_PATH, index=False)
        print(f"    {OUTPUT_CSV_PATH}")
    except Exception as err:
        print(f"\nAn error occurred while saving files: {err}")
        return

    banner("MERGED PIPELINE COMPLETE!", 20, lead="\n")
    closing_lines = (
        "\nYour dataset is now:",
        "  • Clean (no missing or invalid values)",
        "  • Optimized (reduced memory usage)",
        "  • Enhanced (stress indicators added)",
        "  • LSTM-ready (validated sequences)",
        "\nOutput files:",
        f"  • {OUTPUT_PARQUET_PATH}",
        f"  • {OUTPUT_CSV_PATH}",
    )
    for line in closing_lines:
        print(line)


# Run the pipeline when this file is executed as the main module.
if __name__ == "__main__":
    main()

Clearing Hugging Face datasets cache...
               COLAB PREPROCESSING - MERGED PIPELINE
Found input file: ndvi_by_field.parquet

PART 1: AGGREGATING DATA (LSTM-Style)
Loading field geometry data...


Generating train split:   0%|          | 0/112526 [00:00<?, ? examples/s]

   Loaded 112,526 fields
Loading NDVI data from ndvi_by_field.parquet...
   Processed 2,237 records
Loading weather data...


Generating train split:   0%|          | 0/1816320 [00:00<?, ? examples/s]

   Aggregating to daily...
   Loaded 74,304 daily weather point-records
Loading SMAP soil data...


Generating train split: 0 examples [00:00, ? examples/s]

   Aggregating AM/PM to daily...
   Loaded 72,609 daily SMAP point-records
Loading ERA5 data...


Generating train split: 0 examples [00:00, ? examples/s]

   Loaded 15,164,696 daily ERA5 point-records

Processing 52 fields with NDVI data
   Date range: 2025-03-02 to 2025-10-29
   Creating field-date matrix...
   Baseline: 12,584 field-date combinations

   Building spatial indices for fast joins...
   Building spatial indices by date for 74304 records...


   Indexing: 100%|██████████| 1376/1376 [00:03<00:00, 379.57it/s]


   ...spatial indices built.
   Building spatial indices by date for 72609 records...


   Indexing: 100%|██████████| 1289/1289 [00:03<00:00, 325.62it/s]


   ...spatial indices built.
   Building spatial indices by date for 15164696 records...


   Indexing: 100%|██████████| 3908/3908 [04:29<00:00, 14.50it/s]


   ...spatial indices built.

   Performing spatial joins...


   Joining data: 100%|██████████| 242/242 [00:02<00:00, 110.98it/s]
   Joining data: 100%|██████████| 242/242 [00:01<00:00, 205.05it/s]
   Joining data: 100%|██████████| 242/242 [00:01<00:00, 143.62it/s]



   Merging all datasets...
   Weather data merged
   SMAP data merged
   ERA5 data merged
   NDVI data merged
   Temporal features (day, month, week) added

Aggregation complete. Shape: (12584, 30)
Columns: ['field_id', 'date', 'weather_temp_mean', 'weather_temp_min', 'weather_temp_max', 'weather_precip_sum', 'weather_humidity_mean', 'weather_wind_mean', 'weather_pressure_mean', 'weather_precip_7d', 'weather_temp_7d_mean', 'smap_soil_moisture', 'smap_surface_temp_C', 'smap_veg_water', 'smap_clay_fraction', 'era5_2m_temperature', 'era5_gdd_base5', 'era5_volumetric_soil_water_layer_1', 'era5_sw_root', 'era5_total_precipitation', 'era5_surface_solar_radiation_downwards', 'era5_gdd_cumsum', 'NDVI_mean', 'NDVI_min', 'NDVI_max', 'pixel_count', 'NDVI_std', 'day_of_year', 'month', 'week']

                         PART 2: DATA CLEANING PIPELINE (LSTM-Style)

 STEP 1: DROP USELESS COLUMNS
─────────────────────────────────────────────────────────────────────────────
Problem: Some columns have n