In [None]:
! pip install optuna

Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.15.2-py3-none-any.whl (231 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m231.9/231.9 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, alembic, optuna
Successfully installed alembic-1.15.2 colorlog-6.9.0 optuna-4.3.0


In [None]:
import ee
import geemap
import optuna
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pickle
from google.colab import output

In [None]:
# Google Cloud project used for all Earth Engine requests in this notebook.
cloud_project = 'ee-bomacarthur'

# Try to initialize with existing credentials first; if that fails, run the
# interactive authentication flow and initialize again.
# `except Exception` (rather than a bare `except`) lets KeyboardInterrupt and
# SystemExit propagate, so interrupting the cell does not start authentication.
try:
    ee.Initialize(project=cloud_project)
except Exception:
    ee.Authenticate()
    ee.Initialize(project=cloud_project)


In [None]:
# Create an interactive geemap map (800 wide) and display it as the cell output.
m = geemap.Map(width=800)
m

Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], widget=SearchDataGUI(childr…

In [None]:
# Load the Earth Engine assets used by the water balance model
# `fields`: test field polygons asset (the model later reads its metered-water
# properties such as 'waterMeterDepth').
fields = ee.FeatureCollection("projects/ee-bomacarthur/assets/Test_Fields/testFieldsMeterDataArea_4_16_update")
# `etAndPrcpCol`: asset loaded here as a FeatureCollection; the model later wraps
# it in ee.ImageCollection and selects bands 'P' and 'ETa'.
# NOTE(review): presumably this asset is an image collection — confirm.
etAndPrcpCol = ee.FeatureCollection("projects/ee-bomacarthur/assets/etAndPrcp")

# Sort the et and prcp collection by date
et_and_prcp = etAndPrcpCol.sort('system:time_start')

# Define the water balance model whose error the optimizer minimizes
def water_balance_model(precip_const, residual_const):
    """Run the monthly water balance model on Earth Engine and return its depth error.

    Args:
        precip_const: indexable of 12 values; element ``i`` is the multiplier
            applied to band 'P' for calendar month ``i + 1``.
        residual_const: indexable of 12 values; element ``i`` is the multiplier
            applied to the previous month's carried-over residual ('wbRadj')
            when processing calendar month ``i + 1``.

    Returns:
        The client-side value (via ``getInfo()``) of
        ``sum(|waterMeterDepth - waterModelDepth|) / 763`` over all field/month
        features produced for 2011-01-01 up to (not including) 2024-01-01.

    Relies on the notebook-level ``et_and_prcp`` and ``fields`` collections.
    """

    # One feature per calendar month holding that month's two constants, so the
    # server-side mapped functions can look them up by month number.
    xy_constants = ee.FeatureCollection([
    ee.Feature(None, {'month': i + 1, 'precipConst': precip_const[i], 'residualConst': residual_const[i]}) for i in range(12)
    ])

    # Create a function to calculate monthly Water Balance = P(x) - ET
    def calculate_water_balance(image):
        """Add 'Peff' (P * monthly precip constant) and 'wb' (Peff - ETa) bands to ``image``."""
        month = image.date().get('month')
        x_value = xy_constants.filter(ee.Filter.eq('month', month)).first().getNumber('precipConst')
        precip_effective = image.select('P').multiply(x_value).rename('Peff')
        # Local name shadows the enclosing function name inside this scope only.
        calculate_water_balance = precip_effective.subtract(image.select('ETa')).rename('wb')
        return image.addBands(precip_effective).addBands(calculate_water_balance).copyProperties(image, image.propertyNames())

    water_balance = ee.ImageCollection(et_and_prcp.filterDate('2011-01-01', '2024-01-01')).map(calculate_water_balance)

    # Seed image for the iteration: first image dated Dec 2010, with its own
    # 'wb' band (using the month-12 precip constant) and a 'wbRadj' band equal
    # to 'wb' with negative values replaced by 0.
    et_and_prcp_first = ee.Image(et_and_prcp.filterDate('2010-12-01', '2011-01-01').first())
    first_x_value = xy_constants.filter(ee.Filter.eq('month', 12)).first().getNumber('precipConst')
    wb_first = et_and_prcp_first.addBands(et_and_prcp_first.select('P').multiply(first_x_value) \
                                .subtract(et_and_prcp_first.select('ETa')).rename('wb'))
    first_image = wb_first.addBands(wb_first.select('wb').where(wb_first.select('wb').lt(0), 0).rename('wbRadj'))

    # Add stored water from previous month to create monthly adjusted water balance
    def add_residuals_function(current, previous):
        """Iterate step: append ``current`` with residual bands to the ``previous`` list.

        Bands added to the current image:
          'wbRn1'    - previous image's 'wbRadj'
          'wbRn1Adj' - 'wbRn1' multiplied by this month's residual constant
          'wbR'      - current 'wb' with negative values replaced by 0
          'wbA'      - current 'wb' plus 'wbRn1Adj'
          'wbRadj'   - 'wbA' with negative values replaced by 0 (carried to next month)
        """
        # The most recently appended image is the previous month's result.
        previous_image = ee.Image(ee.List(previous).get(-1))
        current_image = ee.Image(current)
        month = current_image.date().get('month')
        y_value = xy_constants.filter(ee.Filter.eq('month', month)).first().getNumber('residualConst')
        residual_prev_mon = previous_image.select('wbRadj').rename('wbRn1')
        residual_ratio_prev_mon = residual_prev_mon.multiply(ee.Image.constant(y_value)).rename('wbRn1Adj')
        wb_residual_cur_mon = current_image.select('wb').where(current_image.select('wb').lt(0), 0).unmask().rename('wbR')
        wb_adjusted = current_image.select('wb').add(residual_ratio_prev_mon).rename('wbA')
        residual_adjusted_cur_mon = wb_adjusted.select('wbA').where(wb_adjusted.select('wbA').lt(0), 0).unmask().rename('wbRadj')

        updated_image = current_image.addBands(residual_prev_mon).addBands(residual_ratio_prev_mon) \
                                     .addBands(wb_residual_cur_mon).addBands(wb_adjusted).addBands(residual_adjusted_cur_mon) \
                                     .copyProperties(current_image, current_image.propertyNames())

        return ee.List(previous).add(updated_image)
    # Iterate the residual function over the water balance image collection.
    # The resulting list starts with the Dec 2010 seed image.
    initial_list = ee.List([first_image])
    add_residuals = ee.ImageCollection.fromImages(water_balance.iterate(add_residuals_function, initial_list))

    # Calculate Zonal Stats for test field polygons
    def calculate_zonal_stats(image3):
        """Return per-field band means of ``image3`` for fields dated within the image's month."""
        image_date_start = image3.date()
        image_date_end = image_date_start.advance(1, 'month')
        fields_filtered = fields.filterDate(image_date_start, image_date_end)
        zonal_stats = image3.reduceRegions(**{
            'collection': fields_filtered,
            'reducer': ee.Reducer.mean(),
            'scale': 30,
            'crs': 'EPSG:32610',
            'tileScale': 16
        })
        # NOTE(review): the bands created above are named 'P' and 'Peff'; unless the
        # field features already carry 'precipitation' / 'precipEffective', those two
        # selected properties may be absent from the output — confirm.
        return zonal_stats.map(lambda feature: feature.select(['area', 'endDate', 'ETa', 'groupNumber', 'precipitation', 'precipEffective',
                                                             'startDate', 'system:time_end', 'system:time_start', 'wb', 'wbA', 'wbR', 'wbRn1', 'wbRadj',
                                                             'wbRn1Adj', 'waterMeterDepth', 'waterMeterVolume']))

    # The date filter starting 2011-01-01 drops the Dec 2010 seed image before reducing.
    zonal_stats = ee.FeatureCollection(add_residuals.filterDate('2011-01-01', '2024-01-01').map(calculate_zonal_stats)).flatten()

    # Calculate Model Water Depth

    # Filter feature collection for two cases, wbA < 0 and wbA >= 0
    non_zero_irrigation = zonal_stats.filter(ee.Filter.lt('wbA', 0))
    zero_irrigation = zonal_stats.filter(ee.Filter.gte('wbA', 0))

    # Fields with positive irrigation wbA < 0: modelled depth is |wbA|
    non_zero_irrigation_col = non_zero_irrigation.map(lambda feature: feature.set({
        'ETi': feature.getNumber('wbA').abs(),
        'waterModelDepth': feature.getNumber('wbA').abs()
    }).select(['area', 'endDate', 'ETa', 'ETi', 'groupNumber', 'precipitation', 'precipEffective',
               'startDate', 'system:time_end', 'system:time_start', 'wb', 'wbA', 'wbR', 'wbRn1', 'wbRadj',
               'wbRn1Adj', 'waterMeterDepth', 'waterMeterVolume', 'waterModelDepth']))

    # Fields with zero irrigation wbA >= 0: modelled depth is 0
    zero_irrigation_col = zero_irrigation.map(lambda feature: feature.set({
        'ETi': 0,
        'waterModelDepth': 0
    }).select(['area', 'endDate', 'ETa', 'ETi', 'groupNumber', 'precipitation', 'precipEffective',
               'startDate', 'system:time_end', 'system:time_start', 'wb', 'wbA', 'wbR', 'wbRn1', 'wbRadj',
               'wbRn1Adj', 'waterMeterDepth', 'waterMeterVolume', 'waterModelDepth']))
    # Merge feature Collections
    merged_collection = ee.FeatureCollection([non_zero_irrigation_col, zero_irrigation_col]).flatten().sort('system:time_start')

    # Calculate Meter Minus Model (observed - predicted) and its absolute value
    test_results = merged_collection.map(lambda feature: feature.set({
        'meterMinusModelDepth': feature.getNumber('waterMeterDepth').subtract(feature.getNumber('waterModelDepth')),
        'meterMinusModelDepthAbs': feature.getNumber('waterMeterDepth').subtract(feature.getNumber('waterModelDepth')).abs()
    }).copyProperties(feature))

    # Calculate Mean Absolute Error of Depth: sum of absolute errors divided by a
    # hard-coded 763.
    # NOTE(review): 763 is presumably the number of field/month observations; if
    # the collection size ever differs this is no longer a true mean — confirm.
    mae = test_results.reduceColumns(**{
        'reducer': ee.Reducer.sum(),
        'selectors': ['meterMinusModelDepthAbs']
    }).getNumber('sum').divide(763)

    # getInfo() evaluates the whole graph and returns the result as a client-side value.
    return mae.getInfo()

In [None]:
# Define the optimization function
def objective(trial):
    """Optuna objective: sample 24 monthly constants and return the model's MAE.

    For each of the 12 months, one precipitation constant and one residual
    constant are sampled from [0.0, 1.0] on a 0.01 grid. The parameter names
    (``precip_const_<m>`` / ``residual_const_<m>``) are unchanged, so trials
    stay comparable with those already stored in a resumed study.

    Args:
        trial: the Optuna trial used to suggest parameter values.

    Returns:
        The value returned by ``water_balance_model`` for the sampled constants.
    """
    # `suggest_float(..., step=...)` is the non-deprecated replacement for
    # `suggest_discrete_uniform`; each parameter is suggested exactly once.
    precip_const = [
        trial.suggest_float(f'precip_const_{i+1}', 0.0, 1.0, step=0.01)
        for i in range(12)
    ]
    residual_const = [
        trial.suggest_float(f'residual_const_{i+1}', 0.0, 1.0, step=0.01)
        for i in range(12)
    ]

    # Call the water balance model with the suggested hyperparameters
    return water_balance_model(precip_const, residual_const)

In [None]:
# Mount Google Drive in the Colab runtime; the study pickle and CSV used below
# are read from and written to paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Load the previously saved Optuna study object from Google Drive.
# NOTE: pickle.load can execute arbitrary code from the file — only load a
# pickle that this notebook (or another trusted source) wrote.
with open('/content/drive/My Drive/WaterBalanceTestResults/XY/optuna_optimization_xy_updated.pkl', 'rb') as f:
    study_xy_updated = pickle.load(f)

In [None]:
# Continue optimizing the study loaded from Google Drive for 100 more trials.
# To start a brand-new study instead of resuming, uncomment the lines below:
# storage = optuna.storages.InMemoryStorage()
# study_xy_updated = optuna.create_study(direction='minimize')
study_xy_updated.optimize(objective, n_trials=100)

# Print the best hyperparameters and MAE found across all trials in the study
print("Best hyperparameters:", study_xy_updated.best_params)
print("Best MAE:", study_xy_updated.best_value)

  precip_const = [trial.suggest_discrete_uniform(f'precip_const_{i+1}', 0.0, 1.0, 0.01) for i in range(12)]
  precip_const = [trial.suggest_discrete_uniform(f'precip_const_{i+1}', 0.0, 1.0, 0.01) for i in range(12)]
  residual_const = [trial.suggest_discrete_uniform(f'residual_const_{i+1}', 0.0, 1.0, 0.01) for i in range(12)]
[I 2025-04-18 22:33:30,236] Trial 300 finished with value: 22.3881304098131 and parameters: {'precip_const_1': 0.43, 'precip_const_2': 0.8, 'precip_const_3': 0.43, 'precip_const_4': 0.98, 'precip_const_5': 0.34, 'precip_const_6': 0.34, 'precip_const_7': 0.56, 'precip_const_8': 0.5700000000000001, 'precip_const_9': 0.17, 'precip_const_10': 0.74, 'precip_const_11': 0.62, 'precip_const_12': 0.71, 'residual_const_1': 0.09, 'residual_const_2': 0.9500000000000001, 'residual_const_3': 0.08, 'residual_const_4': 0.92, 'residual_const_5': 0.21, 'residual_const_6': 0.44, 'residual_const_7': 0.33, 'residual_const_8': 0.24, 'residual_const_9': 0.8200000000000001, 'residual_con

Best hyperparameters: {'precip_const_1': 0.4, 'precip_const_2': 0.11, 'precip_const_3': 0.37, 'precip_const_4': 0.98, 'precip_const_5': 0.27, 'precip_const_6': 0.35000000000000003, 'precip_const_7': 0.49, 'precip_const_8': 0.45, 'precip_const_9': 0.14, 'precip_const_10': 0.6900000000000001, 'precip_const_11': 0.6900000000000001, 'precip_const_12': 0.63, 'residual_const_1': 0.39, 'residual_const_2': 0.92, 'residual_const_3': 0.1, 'residual_const_4': 0.93, 'residual_const_5': 0.26, 'residual_const_6': 0.43, 'residual_const_7': 0.34, 'residual_const_8': 0.23, 'residual_const_9': 0.84, 'residual_const_10': 0.11, 'residual_const_11': 0.27, 'residual_const_12': 0.67}
Best MAE: 22.24943462523977


In [None]:
# Save the study object to Google Drive.
# This overwrites the same .pkl file the study was loaded from, so the next
# session resumes from the trials accumulated so far.
with open('/content/drive/My Drive/WaterBalanceTestResults/XY/optuna_optimization_xy_updated.pkl', 'wb') as f:
    pickle.dump(study_xy_updated, f)

In [None]:
# Export all trials of the study as a CSV file on Google Drive
study_xy_updated.trials_dataframe().to_csv('/content/drive/My Drive/WaterBalanceTestResults/XY/optuna_optimization_xy_updated.csv')

In [None]:
# Build Optuna's optimization-history figure for the study
opt_hist_plot = optuna.visualization.plot_optimization_history(study_xy_updated)

# Shrink the markers on every scatter trace so individual trials stay legible
for trace in (candidate for candidate in opt_hist_plot.data if candidate.type == 'scatter'):
    trace.marker.size = 3

# Render the adjusted figure
opt_hist_plot.show()

In [None]:
# Build Optuna's hyperparameter-importance figure for the study
importance_plot = optuna.visualization.plot_param_importances(study_xy_updated)

# Map raw hyperparameter names (e.g. "precip_const_1") to readable labels
# (e.g. "Precip Constant Jan"): precipitation constants first, then residuals.
month_abbreviations = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
                       "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
name_mapping = {
    f"{prefix}_{month_number}": f"{label} {month_name}"
    for prefix, label in (("precip_const", "Precip Constant"),
                          ("residual_const", "Residual Constant"))
    for month_number, month_name in enumerate(month_abbreviations, start=1)
}

# Relabel the y-axis ticks: positions stay at the original parameter names,
# the displayed text becomes the readable label (unmapped names are kept as-is).
plotted_params = importance_plot.data[0].y
importance_plot.update_yaxes(
    tickvals=plotted_params,
    ticktext=[name_mapping.get(param, param) for param in plotted_params],
)

# Render the relabelled figure
importance_plot.show()

In [None]:
# Build Optuna's parallel-coordinate figure for the study
parallel_plot = optuna.visualization.plot_parallel_coordinate(study_xy_updated)

# Map raw hyperparameter names (e.g. "precip_const_1") to short axis labels
# (e.g. "Prcp Const Jan"): precipitation constants first, then residuals.
month_abbreviations = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
                       "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
name_mapping = {
    f"{prefix}_{month_number}": f"{label} {month_name}"
    for prefix, label in (("precip_const", "Prcp Const"),
                          ("residual_const", "Resid Const"))
    for month_number, month_name in enumerate(month_abbreviations, start=1)
}

# The figure's single trace holds one dimension (axis) per plotted quantity
original_dimensions = parallel_plot.data[0].dimensions

# Debugging aid: list the labels as Optuna produced them
print("Original dimension names:")
for dim in original_dimensions:
    print(dim.label)

# Readable label for every dimension; labels without a mapping
# (e.g. the objective value axis) are kept unchanged
new_dimensions = [name_mapping.get(dim.label, dim.label) for dim in original_dimensions]

# Write the readable labels back onto the figure's dimensions
for dim, readable_label in zip(parallel_plot.data[0].dimensions, new_dimensions):
    dim.label = readable_label

# Title and size the figure
layout_options = dict(
    title_text="Parallel Coordinate Plot",  # Title for the whole plot
    title_x=0.5,                            # Center the title
    width=1200,
    height=600,
)
parallel_plot.update_layout(**layout_options)

# Render the relabelled figure
parallel_plot.show()

Original dimension names:
Objective Value
precip_const_1
precip_const_10
precip_const_11
precip_const_12
precip_const_2
precip_const_3
precip_const_4
precip_const_5
precip_const_6
precip_const_7
precip_const_8
precip_const_9
residual_const_1
residual_const_10
residual_const_11
residual_const_12
residual_const_2
residual_const_3
residual_const_4
residual_const_5
residual_const_6
residual_const_7
residual_const_8
residual_const_9


In [None]:
# Build Optuna's empirical distribution function (EDF) figure of the study's
# objective values and display it as the cell output.
# NOTE(review): "emperical" is a misspelling of "empirical"; the variable name is
# left unchanged in case other cells reference it.
emperical_distribution_plot = optuna.visualization.plot_edf(study_xy_updated)
emperical_distribution_plot

In [None]:
# Generate the slice plot.
# Uses `study_xy_updated`, the study loaded and optimized earlier in this
# notebook (the previous `study_xy` name is not defined anywhere in it).
slice_plot = optuna.visualization.plot_slice(study_xy_updated)

# Map raw hyperparameter names to readable subplot titles. Keys must match the
# names used in `objective`: "precip_const_<m>" and "residual_const_<m>".
month_abbreviations = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
                       "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
name_mapping = {
    f"{prefix}_{month_number}": f"{label} {month_name}"
    for prefix, label in (("precip_const", "Precip Constant"),
                          ("residual_const", "Residual Constant"))
    for month_number, month_name in enumerate(month_abbreviations, start=1)
}


def _slice_param_name(figure, trace):
    """Return the hyperparameter name for one trace of an Optuna slice figure.

    The name is read from the title of the x-axis the trace is drawn on,
    falling back to ``trace.name`` when that axis has no title.
    NOTE(review): recent Optuna versions appear to give slice traces a generic
    ``name`` and put the parameter name on the x-axis title — confirm against
    the installed version.
    """
    axis_ref = trace.xaxis or 'x'               # 'x', 'x2', 'x3', ...
    axis = figure.layout['xaxis' + axis_ref[1:]]  # -> 'xaxis', 'xaxis2', ...
    return axis.title.text or trace.name


# Determine the number of subplots and the grid layout (4 plots per row,
# at least one row so make_subplots always receives a valid grid)
num_subplots = len(slice_plot.data)
num_cols = 4
num_rows = max(1, (num_subplots + num_cols - 1) // num_cols)

# Extract the hyperparameter name behind each trace of the slice plot
hyperparameter_names = [_slice_param_name(slice_plot, trace) for trace in slice_plot.data]

# Map the hyperparameter names to the new names using name_mapping
new_titles = [name_mapping.get(name, name) for name in hyperparameter_names]

# Create a new subplot grid with the new titles
grid_plot = make_subplots(
    rows=num_rows,
    cols=num_cols,
    subplot_titles=new_titles
)

# Add each trace to its own cell of the grid, filling rows left to right
for i, trace in enumerate(slice_plot.data):
    row = (i // num_cols) + 1
    col = (i % num_cols) + 1
    grid_plot.add_trace(trace, row=row, col=col)

# Update layout for better spacing
grid_plot.update_layout(
    title_text="Hyperparameter Slice Plot",  # Title for the whole plot
    title_x=0.5,  # Center the title
    height=300 * num_rows,  # Adjust height based on the number of rows
    width=1000,  # Adjust width as needed
    showlegend=False,
)

# Show the grid figure
grid_plot.show()