# Optimized Species Distribution Modeling (SDM)

This notebook contains a streamlined and modularized version of the SDM workflow. 
It includes steps for initialization, data preparation, model training, prediction, and export.

In [None]:
# 1. Setup and Imports
import ee
import geemap
import pandas as pd
import numpy as np
from google.cloud import bigquery
import geopandas as gpd
import shapely.geometry
import json

# Project identifiers, named once so the retry below cannot drift from the first attempt.
EE_PROJECT = 'cryptic-yen-457008-p4'
BQ_PROJECT = 'rsc-cropmap-lzp'

# Initialize Earth Engine
try:
    ee.Initialize(project=EE_PROJECT)
except Exception:
    # Initialization failed: run the interactive authentication flow and retry once.
    ee.Authenticate()
    ee.Initialize(project=EE_PROJECT)

# Initialize BigQuery
client = bigquery.Client(project=BQ_PROJECT)

# Query the avocado orchard records (species, year, geometry as GeoJSON text)
sql = """
SELECT
    pred1 AS species,
    CAST(agri_year AS INT64) AS year,
    ST_AsGeoJSON(geometry) AS geometry
FROM
    `rsc-cropmap-lzp.published.Fused_Categories_Orchards`
WHERE
    pred1 IS NOT NULL
    AND pred1 = 'avocado'
"""
print("Running BigQuery...")
df = client.query(sql).to_dataframe(create_bqstorage_client=False)

# A NULL geometry comes back as a missing value, which json.loads cannot parse;
# drop those rows up front instead of failing halfway through the conversion.
df = df.dropna(subset=['geometry']).reset_index(drop=True)
if df.empty:
    raise ValueError("BigQuery returned no rows with a geometry; cannot build the presence dataset.")

# Convert the GeoJSON strings to shapely geometries and wrap them in a GeoDataFrame
df['geometry'] = df['geometry'].apply(
    lambda geojson_text: shapely.geometry.shape(json.loads(geojson_text))
)
gdf = gpd.GeoDataFrame(df, geometry='geometry', crs="EPSG:4326")

# Convert to an Earth Engine FeatureCollection (input of the SDM workflow below)
print("Converting to Earth Engine object...")
data_raw = geemap.gdf_to_ee(gdf)

In [None]:
# 2. Configuration
# Single place for band names, asset paths, display settings and export options.

CONFIG = dict(
    # Predictor bands passed to the classifier, in this order
    BANDS=['OrderST', 'aspect', 'elevation', 'slope', 'bio01', 'bio12', 'srad', 'vdf'],
    # Earth Engine asset ids
    ASSETS=dict(
        SOIL="projects/cryptic-yen-457008-p4/assets/IsraelSoilTaxonomy",
        BIO1="projects/cryptic-yen-457008-p4/assets/SDM/wc2_1_30s_bio_1",
        BIO12="projects/cryptic-yen-457008-p4/assets/SDM/wc2_1_30s_bio_12",
        SRAD="projects/cryptic-yen-457008-p4/assets/SDM/wc2_1_30s_srad",
        VAPR="projects/cryptic-yen-457008-p4/assets/SDM/wc2_1_30s_vapr",
        MODEL_OUTPUT='projects/cryptic-yen-457008-p4/assets/avocado_final_model',
    ),
    # Image collection used for the future-climate predictors
    CMIP6_COLLECTION="NASA/GDDP-CMIP6",
    # Map display parameters
    VISUALIZATION=dict(
        SUITABILITY=dict(min=0, max=1, palette=["ffffff", "cecece", "fcd163", "66a000", "204200"]),
        DIFF=dict(min=-0.3, max=0.3, palette=["d7191c", "ffffff", "2c7bb6"]),
    ),
    # Drive export settings used by export_image_to_drive
    EXPORT=dict(FOLDER='GEE_Exports', SCALE=1000),
    # Scale handed to the sampling / reprojection calls
    GRAIN_SIZE=1000,
    # Presence records of this year are held out as the test set
    TEST_YEAR=2018,
)

In [None]:
# 3. Helper Functions

def get_predictors():
    """Assemble the multi-band predictor image for the current climate.

    Returns:
        ee.Image whose bands are, in order: bio01, bio12, srad, vdf, OrderST,
        elevation, slope, aspect.
    """
    assets = CONFIG['ASSETS']

    def load_layer(asset_key, band_name):
        # Load one image asset, rename its band and fill masked pixels via unmask().
        return ee.Image(assets[asset_key]).rename(band_name).unmask()

    # --- Climate (current) ---
    temp_mean = load_layer('BIO1', 'bio01')
    annual_precip = load_layer('BIO12', 'bio12')
    solar_rad = load_layer('SRAD', 'srad')
    vapour_pressure = load_layer('VAPR', 'vapr')

    # Vapour pressure deficit (VPD), kept under the band name 'vdf':
    #   VPD = es - ea, with es = 0.6108 * exp(17.27 * T / (T + 237.3))
    # NOTE(review): assumes bio01 is in degrees Celsius and vapr in kPa - confirm
    # against the uploaded assets.
    saturation_vp = temp_mean.expression(
        '0.6108 * exp((17.27 * T) / (T + 237.3))',
        {'T': temp_mean}
    )
    deficit = saturation_vp.subtract(vapour_pressure).rename('vdf')

    # --- Soil ---
    # Encode each distinct 'OrderST' value as its index in the sorted list of
    # values, rasterize that code, and mark pixels without soil data with -1.
    soil_features = ee.FeatureCollection(assets['SOIL'])
    soil_orders = soil_features.aggregate_array('OrderST').distinct().sort()

    def attach_code(feature):
        return feature.set('Code', soil_orders.indexOf(feature.get('OrderST')))

    soil_band = (
        soil_features.map(attach_code)
        .reduceToImage(['Code'], ee.Reducer.first())
        .rename('OrderST')
        .unmask(-1)
    )

    # --- Topography ---
    dem = ee.Image("USGS/SRTMGL1_003")
    topo_bands = ee.Algorithms.Terrain(dem).unmask().select(['elevation', 'slope', 'aspect'])

    # bio01 stays the first band of the stack.
    return temp_mean.addBands([annual_precip, solar_rad, deficit, soil_band, topo_bands])

def remove_duplicates(data, grain_size):
    """Keep a single presence record per cell of a ``grain_size`` grid.

    Args:
        data: presence features (anything ee.FeatureCollection accepts).
        grain_size: scale passed to the reprojection of the random raster.

    Returns:
        ee.FeatureCollection with geometries and a 'random' property, one
        feature per distinct 'random' value.
    """
    features = ee.FeatureCollection(data)
    # Every cell of this raster holds one random value, so features that share
    # a cell pick up the same 'random' property when sampled.
    cell_values = ee.Image.random().reproject("EPSG:4326", None, grain_size)
    tagged = cell_values.sampleRegions(collection=features, geometries=True)
    return tagged.distinct("random")

def split_data(data, predictors, aoi, test_year, grain_size):
    """Build train / validation / test samples of presences and pseudo-absences.

    Presence records are de-duplicated per grid cell, the records of
    ``test_year`` are held out as the test set and the rest is split about
    70/30 into training and validation. Pseudo-absences are drawn at random
    from the predictor area that holds no presence record and are dealt out
    so every set gets as many absences as presences, as far as enough
    absences were generated.

    Args:
        data: raw presence features carrying a 'year' property.
        predictors: ee.Image containing at least the bands in CONFIG['BANDS'].
        aoi: ee.Geometry bounding the area where pseudo-absences may fall.
        test_year: value of 'year' to hold out for testing.
        grain_size: scale used for de-duplication, masking and sampling.

    Returns:
        Tuple ``(train_data, val_data, test_data)`` of ee.FeatureCollection,
        each holding the predictor values plus a 'PresAbs' label (1 or 0).
    """

    def label_presence(feature):
        return feature.set('PresAbs', 1)

    # 1. De-duplicate presence records
    print("Removing duplicates...")
    presence = remove_duplicates(data, grain_size)
    total_pres_count = presence.size()
    print(f"Presence points after de-duplication: {total_pres_count.getInfo()}")

    # 2. Split presence: hold-out year -> test, remainder -> random 70/30
    pres_test = presence.filter(ee.Filter.eq('year', test_year)).map(label_presence)

    pres_remain = presence.filter(ee.Filter.neq('year', test_year)).randomColumn()
    pres_train = pres_remain.filter(ee.Filter.lt('random', 0.7)).map(label_presence)
    pres_val = pres_remain.filter(ee.Filter.gte('random', 0.7)).map(label_presence)

    # 3. Generate pseudo-absences
    print("Generating pseudo-absences...")

    # Image that is valid only where NO presence record was rasterized.
    presence_mask = presence.reduceToImage(properties=['random'], reducer=ee.Reducer.first()) \
        .reproject('EPSG:4326', None, grain_size).mask().neq(1).selfMask()

    # Valid predictor area, taken from the mask of the first predictor band.
    cl_mask = predictors.select(0).mask()

    # Area where pseudo-absences are allowed.
    area_for_pa = presence_mask.updateMask(cl_mask).clip(aoi)

    # Sample the predictors *through* that mask: sampling skips masked pixels,
    # so no absence can land on a presence cell. Sampling the unmasked stack
    # over area_for_pa.geometry() would only restrict to the clip footprint.
    absences = predictors.updateMask(area_for_pa).sample(
        region=aoi,
        scale=grain_size,
        # Ask for 20% more than the presence count; numPixels must be an integer.
        numPixels=total_pres_count.multiply(1.2).ceil().toInt(),
        geometries=True
    ).randomColumn().sort('random').map(lambda f: f.set('PresAbs', 0))

    # Deal the shuffled absences out in chunks matching the presence counts,
    # aiming at a 1:1 presence/absence ratio in each set. Surplus absences are
    # dropped rather than inflating the validation set.
    count_test = pres_test.size()
    count_train = pres_train.size()
    count_val = pres_val.size()
    train_end = count_test.add(count_train)
    val_end = train_end.add(count_val)

    absences_list = absences.toList(absences.size())
    abs_test = ee.FeatureCollection(absences_list.slice(0, count_test))
    abs_train = ee.FeatureCollection(absences_list.slice(count_test, train_end))
    abs_val = ee.FeatureCollection(absences_list.slice(train_end, val_end))

    # 4. Merge presences with absences and read the predictor values
    def sample_data(pres, abs_):
        merged = pres.merge(abs_)
        return predictors.select(CONFIG['BANDS']).sampleRegions(
            collection=merged,
            properties=["PresAbs"],
            scale=grain_size,
            tileScale=16
        )

    train_data = sample_data(pres_train, abs_train)
    val_data = sample_data(pres_val, abs_val)
    test_data = sample_data(pres_test, abs_test)

    return train_data, val_data, test_data

def train_model(training_data, mode='MULTIPROBABILITY'):
    """Fit a 250-tree Random Forest on ``training_data``.

    Args:
        training_data: ee.FeatureCollection holding the 'PresAbs' label and
            the predictor properties named in CONFIG['BANDS'].
        mode: classifier output mode applied to the trained classifier.

    Returns:
        The trained ee.Classifier with its output mode set to ``mode``.
    """
    forest = ee.Classifier.smileRandomForest(250)
    fitted = forest.train(
        features=training_data,
        classProperty="PresAbs",
        inputProperties=CONFIG['BANDS'],
    )
    return fitted.setOutputMode(mode)

def get_future_climate(scenario, model='ACCESS-CM2', year=2050):
    """Build future climate predictors from the CMIP6 collection.

    Averages the images of one model and scenario over a 20-year window
    (``year - 9`` through ``year + 10``, both inclusive) and converts them to
    the band names used by the current-climate predictors.

    Args:
        scenario: value of the collection's 'scenario' property, e.g. 'ssp245'.
        model: value of the collection's 'model' property. The default is a
            model of the CMIP6 collection; the earlier default 'ACCESS1-0' is
            a CMIP5-era name that matches no image there.
        year: centre year of the averaging window.

    Returns:
        ee.Image with bands bio01 (mean temperature, deg C), bio12 (annual
        precipitation, mm), srad (kJ m-2 day-1) and vdf (vapour pressure
        deficit, kPa).

    NOTE(review): if the window extends past the last image of the collection
    the mean is simply taken over the images that exist - confirm this is
    acceptable for late target years such as 2100.
    """
    start_year = year - 9
    end_year = year + 10

    # ee.Filter.date excludes its end date, so end on 1 January of the year
    # after the window to keep 31 December of the last year.
    nex = ee.ImageCollection(CONFIG['CMIP6_COLLECTION']) \
        .filter(ee.Filter.date(f'{start_year}-01-01', f'{end_year + 1}-01-01')) \
        .filter(ee.Filter.eq('scenario', scenario)) \
        .filter(ee.Filter.eq('model', model))

    def convert(img):
        # pr is in kg m-2 s-1 (mm/s); multiply by 86400 to get mm/day.
        pr = img.select('pr').multiply(86400).rename('precip_mm')

        # tas is in Kelvin; subtract 273.15 to get Celsius.
        tas = img.select('tas').subtract(273.15).rename('tmean_c')

        # rsds is in W m-2 (J s-1 m-2): x 86400 s/day / 1000 J/kJ = x 86.4
        # gives kJ m-2 day-1.
        srad = img.select('rsds').multiply(86.4).rename('srad')

        # VPD from temperature and relative humidity (hurs, %):
        #   es = 0.6108 * exp(17.27 * T / (T + 237.3))
        #   ea = es * hurs / 100
        #   vpd = es - ea
        hurs = img.select('hurs')
        es = tas.expression(
            '0.6108 * exp((17.27 * T) / (T + 237.3))',
            {'T': tas}
        )
        ea = es.multiply(hurs.divide(100))
        vdf = es.subtract(ea).rename('vdf')

        return img.addBands([pr, tas, srad, vdf])

    # Resampling is requested per image because it relies on each image's own
    # projection, which a composite no longer has.
    nex_agg = nex.map(convert).map(lambda i: i.resample('bilinear').reproject('EPSG:4326', None, 1000))

    # bio12 is annual precipitation: mean daily precipitation (mm/day) x 365.
    bio01 = nex_agg.select('tmean_c').mean().rename('bio01')
    bio12 = nex_agg.select('precip_mm').mean().multiply(365).rename('bio12')
    srad = nex_agg.select('srad').mean().rename('srad')
    vdf = nex_agg.select('vdf').mean().rename('vdf')

    return bio01.addBands([bio12, srad, vdf])

def export_image_to_drive(image, description, filename, region):
    """Queue a Google Drive export of ``image`` and start it immediately.

    Args:
        image: ee.Image to export.
        description: task description, also echoed in the log line.
        filename: file name prefix of the exported file.
        region: export region.

    The Drive folder and the scale come from CONFIG['EXPORT'].
    """
    export_cfg = CONFIG['EXPORT']
    task_params = {
        'image': image,
        'description': description,
        'folder': export_cfg['FOLDER'],
        'fileNamePrefix': filename,
        'scale': export_cfg['SCALE'],
        'region': region,
        'maxPixels': 1e13,
    }
    drive_task = ee.batch.Export.image.toDrive(**task_params)
    drive_task.start()
    print(f"Started export task: {description}")


In [None]:
# 4. Main Execution Flow

# --- A. Prepare Data ---
print("Preparing data...")
predictors = get_predictors()

# The area of interest is the bounding box of the soil map.
soil_fc = ee.FeatureCollection(CONFIG['ASSETS']['SOIL'])
aoi = soil_fc.geometry().bounds()

if 'data_raw' not in locals():
    print("WARNING: 'data_raw' variable not found. Please ensure BigQuery step ran successfully.")
else:
    print("Splitting data and preparing datasets...")
    # data_raw is the Earth Engine FeatureCollection built from the BigQuery result
    train_data, val_data, test_data = split_data(
        data=data_raw,
        predictors=predictors,
        aoi=aoi,
        test_year=CONFIG['TEST_YEAR'],
        grain_size=CONFIG['GRAIN_SIZE'],
    )

    # Each size() call is a round trip to the Earth Engine servers.
    print(f"Training set size: {train_data.size().getInfo()}")
    print(f"Validation set size: {val_data.size().getInfo()}")
    print(f"Test set size: {test_data.size().getInfo()}")


In [None]:
# --- Data Analysis ---
import seaborn as sns
import matplotlib.pyplot as plt

if 'train_data' in locals():
    print("Generating Correlation Matrix...")

    # train_data already carries the predictor values as properties (they were
    # attached by sampleRegions in split_data), so it can be pulled as a table.
    # Cap the download at 5000 rows to avoid timeouts.
    train_count = train_data.size().getInfo()
    n_samples = min(5000, train_count)
    df_train = geemap.ee_to_pandas(train_data.limit(n_samples))

    # Pairwise correlation between the predictor columns
    cols = CONFIG['BANDS']
    corr_matrix = df_train[cols].corr()

    plt.figure(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f")
    plt.title("Predictor Correlation Matrix")
    plt.show()


In [None]:
# --- B. Train Model ---
if 'train_data' in locals():
    print("Training model...")
    # rf_model keeps train_model's default MULTIPROBABILITY output; the
    # suitability prediction reads the class probabilities from it.
    rf_model = train_model(train_data)

    # In MULTIPROBABILITY mode 'classification' is an array of per-class
    # probabilities, which errorMatrix cannot compare with the integer PresAbs
    # labels. Score the hold-out sets with a hard-label view of the same
    # trained forest instead.
    label_model = rf_model.setOutputMode('CLASSIFICATION')

    # --- Validation ---
    print("Validating model...")
    validated = val_data.classify(label_model)
    error_matrix = validated.errorMatrix('PresAbs', 'classification')
    print("Validation Accuracy:", error_matrix.accuracy().getInfo())
    print("Validation Kappa:", error_matrix.kappa().getInfo())

    # --- Testing ---
    print(f"Testing on year {CONFIG['TEST_YEAR']}...")
    tested = test_data.classify(label_model)
    test_matrix = tested.errorMatrix('PresAbs', 'classification')
    print("Test Accuracy:", test_matrix.accuracy().getInfo())


In [None]:
# --- Model Interpretation ---
if 'rf_model' in locals():
    print("Calculating Variable Importance...")
    # explain() describes the trained classifier; its 'importance' entry is
    # fetched client-side and plotted as one bar per key.
    importance = rf_model.explain().get('importance').getInfo()

    # Imported here so this cell also works when the analysis cell was skipped.
    import matplotlib.pyplot as plt

    band_labels = list(importance.keys())
    band_scores = list(importance.values())

    plt.figure(figsize=(10, 6))
    plt.bar(band_labels, band_scores)
    plt.xticks(rotation=45)
    plt.title("Variable Importance")
    plt.ylabel("Importance")
    plt.show()


In [None]:
# --- C. Future Predictions ---

def predict_suitability(model, climate_stack, static_predictors):
    """Classify a climate stack combined with the static soil/terrain bands.

    Args:
        model: trained classifier. NOTE(review): assumed to output per-class
            probability arrays (MULTIPROBABILITY) with presence at index 1 -
            confirm against how the model was trained.
        climate_stack: ee.Image with the climate bands (bio01, bio12, srad, vdf).
        static_predictors: ee.Image providing OrderST, elevation, slope, aspect.

    Returns:
        ee.Image holding element 1 of the classifier's output array.
    """
    static_bands = static_predictors.select(['OrderST', 'elevation', 'slope', 'aspect'])
    model_input = climate_stack.addBands(static_bands).select(CONFIG['BANDS'])
    class_probabilities = model_input.classify(model)
    return class_probabilities.arrayGet([1])

if 'rf_model' in locals():
    print("Predicting future scenarios...")

    scenarios = ['ssp245', 'ssp585']
    years = [2050, 2100]

    # One suitability map per (scenario, year), keyed as "<scenario>_<year>".
    future_maps = {}
    combinations = [(s, y) for s in scenarios for y in years]

    for scenario, year in combinations:
        print(f"Processing {scenario} - {year}...")
        future_climate = get_future_climate(scenario, year=year)
        suitability_map = predict_suitability(rf_model, future_climate, predictors)
        map_key = f"{scenario}_{year}"
        future_maps[map_key] = suitability_map

    # Difference between the two scenarios for 2050 (SSP585 minus SSP245)
    if {'ssp585_2050', 'ssp245_2050'} <= future_maps.keys():
        high_emissions = future_maps['ssp585_2050']
        diff_map = high_emissions.subtract(future_maps['ssp245_2050'])
        print("Difference map created.")


In [None]:
# 5. Visualization
# Interactive map centred on the area of interest; layers are added only for
# results that exist in this session.
Map = geemap.Map(layout={'height': '600px', 'width': '100%'})
Map.centerObject(aoi, 7)

if 'future_maps' in locals():
    vis_params = CONFIG['VISUALIZATION']

    for name, img in future_maps.items():
        clipped = img.clip(aoi)
        Map.addLayer(clipped, vis_params['SUITABILITY'], f"Suitability {name}")

    if 'diff_map' in locals():
        Map.addLayer(diff_map.clip(aoi), vis_params['DIFF'], "Diff SSP585-SSP245 (2050)")

# Last expression of the cell, so the notebook renders the map.
Map

In [None]:
# 6. Exports
# Exports are opt-in so that re-running the notebook does not queue Drive tasks
# by accident. Set RUN_EXPORTS to True to start one export task per map.
RUN_EXPORTS = False

if RUN_EXPORTS and 'future_maps' in locals():
    for name, img in future_maps.items():
        export_image_to_drive(img, f'export_{name}', f'avocado_{name}', aoi)
