# DATA EXTRACTION

From the benchmark notebook and the dataset provided, it is clear that the data lacks depth in terms of driver (explanatory) features for the target variables.

Also, taking into account that extracting data from the Microsoft Planetary Computer takes more time, I will instead use data from Google Earth Engine.

The data extracted will include:
- Raw spectral bands, which will help us calculate turbidity and other driver features (green, blue, red, NIR, SWIR1, SWIR2)
- Environmental and terrain factors (precipitation, temperature, PET (potential evapotranspiration), elevation, slope, soil composition, etc.)
- Surrounding-area and land-cover features (water occurrence, population, land cover)

In [7]:
# Library Importation
import ee
import geemap
import google.auth
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
import pandas as pd
import numpy as np
import json
import time

In [8]:
# Initialize the Earth Engine client against the 'cleanwatai-466814' Cloud project.
# NOTE(review): no authentication is performed in this cell; it presumably relies on
# credentials created earlier (e.g. by ee.Authenticate()) — confirm on fresh machines.
ee.Initialize(project = 'cleanwatai-466814')


In [9]:
# Library Importation
import ee
import geemap
import pandas as pd
import numpy as np
import time

# Repeats the ee.Initialize call from the earlier cell (same project), presumably so
# this cell can be run on its own.
ee.Initialize(project='cleanwatai-466814')


def run_clean_sweep(csv_path, output_name, export_folder='EE_Exports'):
    """Enrich sampling points with Earth Engine features and export them to Drive.

    Reads ``csv_path``, keeps the ``Index``, ``Latitude``, ``Longitude`` and
    ``Sample Date`` columns, and computes per-point features server-side:
    precipitation (CHIRPS), spectral bands (Landsat 7/8, falling back to MODIS),
    temperature and humidity (ERA5-Land), soil (OpenLandMap), terrain (SRTM),
    surface-water occurrence (JRC) and population (GHSL). It then starts a CSV
    export task. Missing values are encoded as ``-9999``.

    Args:
        csv_path: Path of the input CSV.
        output_name: Export task description and output file prefix.
        export_folder: Google Drive folder that receives the CSV.

    Returns:
        The started ``ee.batch.Task``.

    Raises:
        KeyError: if a required column is absent from the CSV.
        ValueError: if a required column contains missing values, which
            ``geemap.pandas_to_ee`` cannot serialise.
    """
    # ── Configuration ────────────────────────────────────────────────────────
    BUFFER_SIZE_SMALL = 30      # reduceRegion scale (m) for Landsat, SRTM and JRC
    BUFFER_SIZE_LARGE = 3000    # radius (m) of the population buffer
    SPECTRAL_BUFFER   = 90      # radius (m) averaged for spectral bands
    MISSING_VALUE     = -9999
    CHIRPS_RESOLUTION = 5566
    ERA5_RESOLUTION   = 11132
    SOIL_RESOLUTION   = 250
    # Landsat Collection 2 Level-2 SR: reflectance = DN * 0.0000275 - 0.2.
    LANDSAT_SR_SCALE  = 0.0000275
    LANDSAT_SR_OFFSET = -0.2
    MODIS_SR_SCALE    = 0.0001
    SPECTRAL_BANDS    = ['blue', 'green', 'red', 'nir', 'swir1']
    KEEP_COLS         = ['Index', 'Latitude', 'Longitude', 'Sample Date']
    sat_window        = 30      # ± days; kept tight because temporal accuracy matters

    # ── Load & validate CSV ──────────────────────────────────────────────────
    # Only the extraction columns are kept: other columns (NDMI, nir, ...) may
    # hold NaN, and geemap.pandas_to_ee cannot serialise NaN.
    df = pd.read_csv(csv_path)[KEEP_COLS]
    null_counts = df.isnull().sum()
    null_counts = null_counts[null_counts > 0]
    if not null_counts.empty:
        raise ValueError(
            f'{csv_path}: missing values in required columns '
            f'{null_counts.to_dict()}; geemap.pandas_to_ee cannot serialise NaN'
        )

    points = geemap.pandas_to_ee(df, latitude='Latitude', longitude='Longitude')
    bounds = ee.Geometry.Rectangle([16.0, -35.0, 33.0, -22.0])

    # ── Satellite collections ────────────────────────────────────────────────
    def mask_landsat(img):
        """Mask fill/cloud/cloud-shadow pixels and convert SR DNs to reflectance."""
        qa = img.select('QA_PIXEL')
        # QA_PIXEL bits: 0 = fill, 3 = cloud, 4 = cloud shadow.
        clear = (qa.bitwiseAnd(1 << 0).eq(0)
                 .And(qa.bitwiseAnd(1 << 3).eq(0))
                 .And(qa.bitwiseAnd(1 << 4).eq(0)))
        reflectance = (img.select(['SR_B.*'])
                       .multiply(LANDSAT_SR_SCALE)
                       .add(LANDSAT_SR_OFFSET)
                       .updateMask(clear))
        # copyProperties returns an Element; cast back so the collection stays an Image collection.
        return ee.Image(reflectance.copyProperties(img, ['system:time_start']))

    l7 = (ee.ImageCollection("LANDSAT/LE07/C02/T1_L2")
          .filterBounds(bounds).map(mask_landsat)
          .select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5'], SPECTRAL_BANDS))
    l8 = (ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
          .filterBounds(bounds).map(mask_landsat)
          .select(['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6'], SPECTRAL_BANDS))
    landsat_col = l7.merge(l8)

    modis_col = (ee.ImageCollection("MODIS/061/MOD09GA")
                 .filterBounds(bounds)
                 .select(['sur_refl_b03', 'sur_refl_b04', 'sur_refl_b01', 'sur_refl_b02', 'sur_refl_b06'],
                         SPECTRAL_BANDS)
                 .map(lambda img: ee.Image(img.multiply(MODIS_SR_SCALE)
                                           .copyProperties(img, ['system:time_start']))))

    chirps    = ee.ImageCollection('UCSB-CHG/CHIRPS/DAILY')
    era5      = ee.ImageCollection("ECMWF/ERA5_LAND/DAILY_AGGR").select(['temperature_2m', 'dewpoint_temperature_2m'])
    srtm      = ee.Image("USGS/SRTMGL1_003")
    jrc       = ee.Image("JRC/GSW1_4/GlobalSurfaceWater").select('occurrence')
    soil_ph   = ee.Image("OpenLandMap/SOL/SOL_PH-H2O_USDA-4C1A2A_M/v02").select('b0')
    soil_clay = ee.Image("OpenLandMap/SOL/SOL_CLAY-WFRACTION_USDA-3A1A1A_M/v02").select('b0')
    soil_sand = ee.Image("OpenLandMap/SOL/SOL_SAND-WFRACTION_USDA-3A1A1A_M/v02").select('b0')
    ghsl      = ee.ImageCollection("JRC/GHSL/P2023A/GHS_POP")

    def extract_features(feature):
        """Return ``feature`` with every extracted property set (server-side)."""
        # ── Date parsing ─────────────────────────────────────────────────────
        # '/' and '.' separators become '-'. A 4-digit prefix means ISO order
        # (yyyy-MM-dd); anything else is parsed day-first (d-M-y).
        date_str    = ee.String(feature.get('Sample Date'))
        clean_date  = date_str.replace('/', '-', 'g').replace('\\.', '-', 'g')
        is_iso      = clean_date.slice(0, 4).match('^[0-9]+$')
        date_format = ee.Algorithms.If(is_iso, 'yyyy-MM-dd', 'd-M-y')
        sample_date = ee.Date.parse(date_format, clean_date)
        day_after   = sample_date.advance(1, 'day')
        geom        = feature.geometry()
        empty_dict  = ee.Dictionary({})

        def get_safe_value(dict_obj, key, default=MISSING_VALUE):
            """Return ``dict_obj[key]`` as an ee.Number.

            Returns ``default`` when the key is absent or its value is null.
            IsEqual is used instead of truthiness so that 0 stays a valid value.
            """
            value = ee.Dictionary(dict_obj).get(key, default)
            return ee.Number(ee.Algorithms.If(
                ee.Algorithms.IsEqual(value, None), default, value))

        def sample_point(img, key, scale):
            """Return the first pixel value of band ``key`` of ``img`` at the sample point."""
            return get_safe_value(
                img.reduceRegion(ee.Reducer.first(), geom, scale, bestEffort=True), key)

        # ── Precipitation (CHIRPS daily) ─────────────────────────────────────
        # filterDate's end is exclusive, so 'precip_3d_sum' spans the sample
        # day plus the three days before it (four daily images).
        r0 = chirps.filterDate(sample_date, day_after).first()
        r1 = chirps.filterDate(sample_date.advance(-1, 'day'), sample_date).first()
        r3 = chirps.filterDate(sample_date.advance(-3, 'day'), day_after).sum()

        # ── Spectral (Landsat preferred, MODIS fallback) ─────────────────────
        def nearest_in_window(col):
            """Return (image closest in time to the sample, window-has-images flag)."""
            in_window = (col.filterDate(sample_date.advance(-sat_window, 'day'),
                                        sample_date.advance(sat_window, 'day'))
                            .filterBounds(geom))
            nearest = ee.Image(in_window.map(lambda img: img.set(
                'dist', img.date().difference(sample_date, 'day').abs()
            )).sort('dist').first())
            return nearest, in_window.size().gt(0)

        def spectral_means(img, has_img):
            """Return band means over the spectral buffer, or an empty dict if there is no image.

            The reduction runs on the masked image, so cloud and shadow pixels are
            excluded from the mean instead of being averaged in as -9999.
            """
            return ee.Dictionary(ee.Algorithms.If(
                has_img,
                img.reduceRegion(reducer=ee.Reducer.mean(),
                                 geometry=geom.buffer(SPECTRAL_BUFFER),
                                 scale=BUFFER_SIZE_SMALL, maxPixels=1e6),
                empty_dict,
            ))

        def has_clear_pixels(vals, has_img):
            """Return True when an image exists and its 'blue' mean is not null."""
            return ee.Algorithms.If(
                has_img,
                ee.Algorithms.If(ee.Algorithms.IsEqual(vals.get('blue'), None), False, True),
                False,
            )

        # Availability comes from the collection size and the reduced values,
        # not from bandNames(). A masked placeholder image still has bands, and
        # would block the MODIS fallback.
        l_img, has_l_img = nearest_in_window(landsat_col)
        m_img, has_m_img = nearest_in_window(modis_col)
        l_vals = spectral_means(l_img, has_l_img)
        m_vals = spectral_means(m_img, has_m_img)
        use_l  = has_clear_pixels(l_vals, has_l_img)
        use_m  = has_clear_pixels(m_vals, has_m_img)

        s_vals      = ee.Dictionary(ee.Algorithms.If(
            use_l, l_vals, ee.Algorithms.If(use_m, m_vals, empty_dict)))
        days_offset = ee.Number(ee.Algorithms.If(
            use_l, l_img.get('dist'),
            ee.Algorithms.If(use_m, m_img.get('dist'), MISSING_VALUE)))
        source      = ee.Algorithms.If(use_l, 'Landsat', ee.Algorithms.If(use_m, 'MODIS', 'None'))

        # ── Weather (ERA5-Land daily aggregates, Kelvin) ─────────────────────
        weather = era5.filterDate(sample_date, day_after).first()
        w_vals = ee.Dictionary(ee.Algorithms.If(
            weather,
            weather.reduceRegion(ee.Reducer.first(), geom, ERA5_RESOLUTION, bestEffort=True),
            empty_dict,
        ))

        def kelvin_to_celsius(kelvin):
            """Convert Kelvin to Celsius, keeping the MISSING_VALUE sentinel unchanged."""
            return ee.Number(ee.Algorithms.If(
                kelvin.neq(MISSING_VALUE), kelvin.subtract(273.15), MISSING_VALUE))

        t  = kelvin_to_celsius(get_safe_value(w_vals, 'temperature_2m'))
        td = kelvin_to_celsius(get_safe_value(w_vals, 'dewpoint_temperature_2m'))

        def magnus(celsius):
            """Return the Magnus saturation term exp(17.625*T / (T + 243.04))."""
            return celsius.multiply(17.625).divide(celsius.add(243.04)).exp()

        # Relative humidity (%) = 100 * magnus(dewpoint) / magnus(temperature).
        rh = ee.Number(ee.Algorithms.If(
            t.neq(MISSING_VALUE).And(td.neq(MISSING_VALUE)),
            ee.Number(100).multiply(magnus(td)).divide(magnus(t)),
            MISSING_VALUE,
        ))

        # ── Topography ───────────────────────────────────────────────────────
        topo = (srtm.addBands(ee.Terrain.slope(srtm))
                .reduceRegion(ee.Reducer.first(), geom, BUFFER_SIZE_SMALL, bestEffort=True))

        # ── Population (GHSL) ────────────────────────────────────────────────
        # Sample year → GHSL epoch: <2008 → 2000, 2008-2017 → 2015, >=2018 → 2020.
        # NOTE(review): GHS_POP P2023A may provide additional epochs (e.g. 2010);
        # confirm against the catalog before refining this mapping.
        sample_year = sample_date.get('year')
        pop_year = ee.Number(ee.Algorithms.If(
            sample_year.lt(2008), 2000,
            ee.Algorithms.If(sample_year.lt(2018), 2015, 2020)))
        pop = ghsl.filter(ee.Filter.calendarRange(pop_year, pop_year, 'year')).first()
        # Sum of 'population_count' within a 3 km buffer (a count; 0 is valid).
        pop_3km = ee.Number(ee.Algorithms.If(
            ee.Algorithms.IsEqual(pop, None),
            MISSING_VALUE,
            get_safe_value(
                pop.reduceRegion(ee.Reducer.sum(), geom.buffer(BUFFER_SIZE_LARGE), 100),
                'population_count'),
        ))

        return feature.set({
            # Precipitation — 0 (a dry day) is a valid value
            'precip_0d':        sample_point(r0, 'precipitation', CHIRPS_RESOLUTION),
            'precip_1d':        sample_point(r1, 'precipitation', CHIRPS_RESOLUTION),
            'precip_3d_sum':    sample_point(r3, 'precipitation', CHIRPS_RESOLUTION),

            # Spectral bands (surface reflectance)
            'band_blue':        get_safe_value(s_vals, 'blue'),
            'band_green':       get_safe_value(s_vals, 'green'),
            'band_red':         get_safe_value(s_vals, 'red'),
            'band_nir':         get_safe_value(s_vals, 'nir'),
            'band_swir1':       get_safe_value(s_vals, 'swir1'),
            'days_offset':      days_offset,

            # Weather
            'temp_c':           t,
            'humidity':         rh,

            # Soil
            # NOTE(review): OpenLandMap pH is believed to be stored as pH x 10 — confirm.
            'soil_ph':          sample_point(soil_ph, 'b0', SOIL_RESOLUTION),
            'soil_clay_perc':   sample_point(soil_clay, 'b0', SOIL_RESOLUTION),
            'soil_sand_perc':   sample_point(soil_sand, 'b0', SOIL_RESOLUTION),

            # Topography & water
            'elevation':        get_safe_value(topo, 'elevation'),
            'slope':            get_safe_value(topo, 'slope'),
            'water_occurrence': sample_point(jrc, 'occurrence', BUFFER_SIZE_SMALL),

            # Population
            'pop_density_3km':  pop_3km,

            'data_source':      source,
        })

    enriched_points = points.map(extract_features)

    task = ee.batch.Export.table.toDrive(
        collection=enriched_points,
        description=output_name,
        folder=export_folder,
        fileNamePrefix=output_name,
        fileFormat='CSV'
    )
    task.start()
    return task


def wait_for_task(task, label='Task', poll_interval=30, timeout=None):
    """Poll a batch task until it reaches a terminal state.

    Args:
        task: Object whose ``status()`` returns a dict with a ``'state'`` key
            (for example the ``ee.batch.Task`` returned by ``run_clean_sweep``).
        label: Name used in progress messages.
        poll_interval: Seconds to sleep between polls.
        timeout: Maximum seconds to wait. ``None`` (the default) waits
            indefinitely.

    Returns:
        True if the task reached ``COMPLETED``. False if it ``FAILED``, was
        cancelled, or ``timeout`` elapsed first. The task itself is not
        cancelled on timeout; only the monitoring stops.
    """
    print(f'  {label} started — monitoring...')
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        status = task.status()
        state = status['state']
        if state == 'COMPLETED':
            print(f'  ✅ {label} complete!')
            return True
        if state == 'FAILED':
            print(f'  ❌ {label} FAILED: {status.get("error_message", "unknown error")}')
            return False
        if state in ('CANCEL_REQUESTED', 'CANCELLED'):
            print(f'  ⚠️  {label} was cancelled.')
            return False
        # Non-terminal state (e.g. READY / RUNNING): report it, then keep polling.
        print(f'  ... {label} status: {state}')
        if deadline is not None and time.monotonic() >= deadline:
            print(f'  ⏱️  {label} still {state} after {timeout}s — stopped monitoring.')
            return False
        time.sleep(poll_interval)


if __name__ == '__main__':
    EXPORT_FOLDER = 'EE_Exports'

    # (label, input CSV, export name) for each extraction run, executed in order.
    jobs = [
        ('Training', 'Training_Dataset.csv', 'Training_Master'),
        ('Validation', 'Validation_Dataset.csv', 'Validation_Master'),
    ]

    failed = []
    for position, (label, csv_path, output_name) in enumerate(jobs):
        if position:
            print('\n' + '=' * 30)
        print(f'Starting {label} Extraction...')
        task = run_clean_sweep(csv_path, output_name, EXPORT_FOLDER)
        # Track failures instead of unconditionally reporting success at the end.
        if not wait_for_task(task, label):
            failed.append(label)

    if failed:
        print(f"\nFinished with failures: {', '.join(failed)}. "
              'Any successful exports are in Google Drive folder:', EXPORT_FOLDER)
    else:
        print('\nDone! Check your Google Drive folder:', EXPORT_FOLDER)

Starting Training Extraction...
  Training started — monitoring...
  ... Training status: READY
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ... Training status: RUNNING
  ...