<a href="https://colab.research.google.com/github/BoMacArthur/WRIA_1_Irrigation_Models/blob/main/Google_Colab_Earth_Engine_Scripts/Optuna_SWBR_13_Best_Model_Selection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install optuna

In [None]:
import ee
import geemap
import optuna
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pickle
from google.colab import output
import time

In [None]:
cloud_project = 'ee-bomacarthur'

# Initialize Earth Engine for the project. If initialization fails (for example,
# because no stored credentials exist yet), run the interactive authentication
# flow and retry once. `except Exception` (not a bare `except:`) so that
# KeyboardInterrupt / SystemExit are not mistaken for an auth problem.
try:
    ee.Initialize(project=cloud_project)
except Exception:
    ee.Authenticate()
    ee.Initialize(project=cloud_project)


In [None]:
# Create an interactive geemap map widget (width=800) for visual inspection of layers.
m = geemap.Map(width=800)
# Left as the cell's last expression so the notebook renders the map widget.
m

In [None]:
# Earth Engine inputs for the water balance model defined below.
# Validation field records; the model filters these by date for each modeled month.
fields = ee.FeatureCollection(
    "projects/ee-bomacarthur/assets/Test_Fields/validationFields"
)

# Monthly ET ('ETa') and precipitation ('P') data.
# NOTE(review): loaded as a FeatureCollection but consumed as images downstream
# (wrapped in ee.ImageCollection / ee.Image); confirm ee.ImageCollection(...) isn't
# the intended loader.
etAndPrcpCol = ee.FeatureCollection("projects/ee-bomacarthur/assets/etAndPrcp")

# Chronological ordering, since the model's residual carry-over iterates the
# monthly images in collection order.
et_and_prcp = etAndPrcpCol.sort('system:time_start')

# Server-side count (ee.Number) of features in the validation collection.
size = fields.size()

# Define the optimization function
def water_balance_model(precip_const_1, precip_const_2, precip_const_3, precip_const_4, precip_const_5, precip_const_6,
                        precip_const_7, precip_const_8, precip_const_9, precip_const_10, precip_const_11, precip_const_12,
                        residual_const):
    """Build the Earth Engine water balance model graph and return its MAE.

    For each monthly image from Jan 2011 through Dec 2023:
      * ``Peff = P * x_month``, where ``x_month`` is the ``precip_const_<month>``
        multiplier, and ``wb = Peff - ETa``.
      * Positive balance is carried forward from the previous month:
        ``wbA = wb + residual_const * wbRadj_prev``, with ``wbRadj = max(wbA, 0)``.
        The Dec 2010 image seeds the iteration with ``wbRadj = max(wb, 0)``.
      * Band means are taken over the module-level ``fields`` records dated in that
        month (30 m scale, EPSG:32610).
      * Modeled depth is ``|wbA|`` where ``wbA < 0``, and 0 otherwise. It is compared
        with each record's ``waterMeterDepth``.

    Relies on the module-level globals ``et_and_prcp`` and ``fields``.

    Args:
        precip_const_1 .. precip_const_12: Precipitation multipliers for months 1..12.
        residual_const: Multiplier applied to the previous month's ``wbRadj`` before
            it is added to the current month's balance.

    Returns:
        ee.Number: Mean of ``|waterMeterDepth - waterModelDepth|`` over the field
        records that were evaluated. The result is lazy; call ``.getInfo()`` to compute it.
    """
    # Properties kept on each zonal-statistics feature.
    zonal_properties = ['area', 'endDate', 'ETa', 'groupNumber', 'precipitation', 'precipEffective',
                        'startDate', 'system:time_end', 'system:time_start', 'wb', 'wbA', 'wbR', 'wbRn1', 'wbRadj',
                        'wbRn1Adj', 'waterMeterDepth', 'waterMeterVolume']
    # The same properties plus the modeled-irrigation values ('ETi', 'waterModelDepth').
    result_properties = ['area', 'endDate', 'ETa', 'ETi', 'groupNumber', 'precipitation', 'precipEffective',
                         'startDate', 'system:time_end', 'system:time_start', 'wb', 'wbA', 'wbR', 'wbRn1', 'wbRadj',
                         'wbRn1Adj', 'waterMeterDepth', 'waterMeterVolume', 'waterModelDepth']

    # Server-side lookup table: month (1..12) -> precipitation multiplier.
    monthly_precip_consts = [precip_const_1, precip_const_2, precip_const_3, precip_const_4,
                             precip_const_5, precip_const_6, precip_const_7, precip_const_8,
                             precip_const_9, precip_const_10, precip_const_11, precip_const_12]
    x_constants = ee.FeatureCollection([
        ee.Feature(None, {'month': month, 'precipConst': const})
        for month, const in enumerate(monthly_precip_consts, start=1)
    ])

    def calculate_water_balance(image):
        """Add 'Peff' (P * monthly multiplier) and 'wb' (Peff - ETa) bands to a monthly image."""
        month = image.date().get('month')
        x_value = x_constants.filter(ee.Filter.eq('month', month)).first().getNumber('precipConst')
        precip_effective = image.select('P').multiply(x_value).rename('Peff')
        monthly_balance = precip_effective.subtract(image.select('ETa')).rename('wb')
        return image.addBands(precip_effective).addBands(monthly_balance).copyProperties(image, image.propertyNames())

    water_balance = ee.ImageCollection(et_and_prcp.filterDate('2011-01-01', '2024-01-01')).map(calculate_water_balance)

    # Seed image for the carry-over iteration: Dec 2010 gets its own 'wb' band (using
    # the December multiplier) and 'wbRadj' = max(wb, 0). It is not a constant 0.
    et_and_prcp_first = ee.Image(et_and_prcp.filterDate('2010-12-01', '2011-01-01').first())
    first_x_value = x_constants.filter(ee.Filter.eq('month', 12)).first().getNumber('precipConst')
    wb_first = et_and_prcp_first.addBands(et_and_prcp_first.select('P').multiply(first_x_value)
                                          .subtract(et_and_prcp_first.select('ETa')).rename('wb'))
    first_image = wb_first.addBands(wb_first.select('wb').where(wb_first.select('wb').lt(0), 0).rename('wbRadj'))

    def add_residuals_function(current, previous):
        """Append ``current``, with carry-over bands, to the accumulated image list ``previous``.

        The following bands are computed from the last image in ``previous``:
          * wbRn1: previous wbRadj
          * wbRn1Adj: wbRn1 * residual_const
          * wbR: max(wb, 0), with masked pixels unmasked to 0
          * wbA: wb + wbRn1Adj
          * wbRadj: max(wbA, 0), with masked pixels unmasked to 0
        """
        previous_image = ee.Image(ee.List(previous).get(-1))
        current_image = ee.Image(current)
        residual_prev_mon = previous_image.select('wbRadj').rename('wbRn1')
        residual_ratio_prev_mon = residual_prev_mon.multiply(ee.Image.constant(residual_const)).rename('wbRn1Adj')
        wb_residual_cur_mon = current_image.select('wb').where(current_image.select('wb').lt(0), 0).unmask().rename('wbR')
        wb_adjusted = current_image.select('wb').add(residual_ratio_prev_mon).rename('wbA')
        residual_adjusted_cur_mon = wb_adjusted.select('wbA').where(wb_adjusted.select('wbA').lt(0), 0).unmask().rename('wbRadj')

        updated_image = current_image.addBands(residual_prev_mon).addBands(residual_ratio_prev_mon) \
                                     .addBands(wb_residual_cur_mon).addBands(wb_adjusted).addBands(residual_adjusted_cur_mon) \
                                     .copyProperties(current_image, current_image.propertyNames())

        return ee.List(previous).add(updated_image)

    # iterate() visits months in collection order; the list starts with the Dec 2010 seed.
    initial_list = ee.List([first_image])
    add_residuals = ee.ImageCollection.fromImages(water_balance.iterate(add_residuals_function, initial_list))

    def calculate_zonal_stats(image3):
        """Average every band over the field records dated within this image's month."""
        image_date_start = image3.date()
        image_date_end = image_date_start.advance(1, 'month')
        fields_filtered = fields.filterDate(image_date_start, image_date_end)
        zonal_stats = image3.reduceRegions(**{
            'collection': fields_filtered,
            'reducer': ee.Reducer.mean(),
            'scale': 30,
            'crs': 'EPSG:32610',
            'tileScale': 16
        })
        return zonal_stats.map(lambda feature: feature.select(zonal_properties))

    # The Dec 2010 seed image is excluded here; it only provides the first carry-over.
    zonal_stats = ee.FeatureCollection(add_residuals.filterDate('2011-01-01', '2024-01-01').map(calculate_zonal_stats)).flatten()

    # Split the rows into wbA < 0 (modeled irrigation) and wbA >= 0 (no irrigation).
    # Rows with no wbA value match neither filter and are dropped.
    non_zero_irrigation = zonal_stats.filter(ee.Filter.lt('wbA', 0))
    zero_irrigation = zonal_stats.filter(ee.Filter.gte('wbA', 0))

    # wbA < 0: the deficit |wbA| is the modeled irrigation depth.
    non_zero_irrigation_col = non_zero_irrigation.map(lambda feature: feature.set({
        'ETi': feature.getNumber('wbA').abs(),
        'waterModelDepth': feature.getNumber('wbA').abs()
    }).select(result_properties))

    # wbA >= 0: no modeled irrigation.
    zero_irrigation_col = zero_irrigation.map(lambda feature: feature.set({
        'ETi': 0,
        'waterModelDepth': 0
    }).select(result_properties))

    merged_collection = ee.FeatureCollection([non_zero_irrigation_col, zero_irrigation_col]).flatten().sort('system:time_start')

    # Per-record error: metered depth minus modeled depth, signed and absolute.
    test_results = merged_collection.map(lambda feature: feature.set({
        'meterMinusModelDepth': feature.getNumber('waterMeterDepth').subtract(feature.getNumber('waterModelDepth')),
        'meterMinusModelDepthAbs': feature.getNumber('waterMeterDepth').subtract(feature.getNumber('waterModelDepth')).abs()
    }).copyProperties(feature))

    # Mean absolute error over the records actually compared. Dividing the sum by
    # fields.size() would also count records that never reach this point (no wbA
    # value, or dated outside Jan 2011 - Dec 2023) and bias the MAE low.
    mae = test_results.reduceColumns(**{
        'reducer': ee.Reducer.mean(),
        'selectors': ['meterMinusModelDepthAbs']
    }).getNumber('mean')

    return mae

In [None]:
# Mount Google Drive under /content/drive so the pickled Optuna study can be read below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the previously pickled Optuna study from the mounted Google Drive.
# NOTE: unpickling can execute arbitrary code; only load study files you produced yourself.
study_pickle_path = '/content/drive/My Drive/WaterBalanceTestResults/XY_static/optuna_optimization_x_y_static_no_outliers_1000.pkl'
with open(study_pickle_path, 'rb') as f:
    study_x_y_static_no_outliers_1000 = pickle.load(f)

In [None]:
# Rank trials by objective value (lower MAE is better) and keep the best N_TOP_TRIALS.
# Only COMPLETE trials are ranked. Failed or pruned trials have value None, which
# would make sorted() raise TypeError.
N_TOP_TRIALS = 225
completed_trials = study_x_y_static_no_outliers_1000.get_trials(
    deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
top_trials = sorted(completed_trials, key=lambda t: t.value)[:N_TOP_TRIALS]

# The cutoff is a fixed count, not a 1% threshold, so report the actual spread
# between the best and worst kept trials instead of asserting "within 1%".
if top_trials:
    best_value = top_trials[0].value
    worst_value = top_trials[-1].value
    spread = f"{(worst_value - best_value) / abs(best_value):.2%}" if best_value else "n/a"
    print(f"Top {len(top_trials)} trials (worst is {spread} above best):")
else:
    print("No completed trials found in the study.")
for trial in top_trials:
    print(f"Value: {trial.value}, Params: {trial.params}")

In [None]:
# Re-run the Earth Engine model for each selected trial and record its MAE.
# Each entry in `results` is the trial's params plus an 'mae' key (None if the run failed).
results = []
n_trials = len(top_trials)

for i, trial in enumerate(top_trials):
    print(f"Processing trial {i+1}/{n_trials}...")

    params = trial.params
    # Positional order expected by water_balance_model: 12 monthly constants, then residual_const.
    param_list = [params[f'precip_const_{m}'] for m in range(1, 13)] + [params['residual_const']]

    try:
        # getInfo() triggers the server-side computation and returns a Python value.
        mae_value = water_balance_model(*param_list).getInfo()
    except Exception as e:
        # Best effort: a failed trial is recorded with mae=None and the loop continues.
        print(f"Trial {i+1} failed: {e}")
        mae_value = None
    else:
        print(f"Trial {i+1} MAE: {mae_value}")
        print(f"Params: {params}")

    results.append({**params, 'mae': mae_value})

    # Pause 1 second between trials (presumably to stay within Earth Engine request limits).
    time.sleep(1)

print("All trials processed.")

# Select the lowest MAE among trials that produced one. A None MAE (failed run or
# null result) cannot be compared with floats and would make min() raise TypeError.
scored_results = [(idx, r) for idx, r in enumerate(results) if r['mae'] is not None]
if not scored_results:
    print("\nNo trial produced an MAE; no best trial could be selected.")
else:
    best_index, best_result = min(scored_results, key=lambda pair: pair[1]['mae'])

    print(f"\nBest trial SWBR-13: (Trial #{best_index + 1}):")
    print(f"MAE: {best_result['mae']}")
    print("Parameters:")
    for k, v in best_result.items():
        if k != 'mae':
            print(f"  {k}: {v}")
