<div style="display: flex; align-items: center;">
    <h1>Optimizing parameters in a WOFOST crop model using <code>diffWOFOST</code></h1>
    <img src="https://raw.githubusercontent.com/WUR-AI/diffWOFOST/refs/heads/main/docs/logo/diffwofost.png" width="150" style="margin-left: 20px;">
</div>


This Jupyter notebook demonstrates the optimization of parameters in a
differentiable model using the `diffwofost` package. The package provides
differentiable implementations of the WOFOST model and its associated
sub-models. As `diffwofost` is under active development, this notebook focuses on
`leaf_dynamics`. 

To enable these models to operate independently, certain state variables
required by the model are supplied as "external states" derived from the test
data. Also, at this stage, only a limited subset of model parameters has been made
differentiable.

## 1. Leaf dynamics

In this section, we demonstrate how to optimize two parameters, `TDWI` and `SPAN`, in
the `leaf_dynamics` model using its differentiable implementation.
The optimization uses the Adam optimizer from `torch.optim`.

### 1.1. Software requirements

To run this notebook, we need to install `diffwofost`, the differentiable
version of the WOFOST models. Since the package is under active development, make
sure you have the latest version of `diffwofost` installed in your
Python environment. You can install it using pip:

In [1]:
# install diffwofost
!pip install diffwofost



In [2]:
# ---- import libraries ----
import copy
import torch
import numpy
import yaml
from pathlib import Path
from diffwofost.physical_models.config import Configuration, ComputeConfig
from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics
from diffwofost.physical_models.utils import EngineTestHelper
from diffwofost.physical_models.utils import prepare_engine_input
from diffwofost.physical_models.utils import get_test_data

In [3]:
# ---- disable a warning: this will be fixed in the future ----
import warnings
warnings.filterwarnings("ignore", message="To copy construct from a tensor.*")

### 1.2. Data

A test dataset of `LAI` (leaf area index, including stem and pod area) and
`TWLV` (dry weight of total leaves, living and dead) will be used to optimize
the parameters `TDWI` (total initial dry weight) and `SPAN` (life span of leaves).
Note that in `leaf_dynamics`, changes in `SPAN` do not affect `TWLV`.

The data is stored in the PCSE tests folder and can be downloaded from the PCSE repository.
You can select any of the files for the `leaf_dynamics` model whose name follows the pattern
`test_leafdynamics_wofost72_*.yaml`. Each file contains different data depending on the location and crop type.
For example, you can download the file `test_leafdynamics_wofost72_01.yaml` as follows:

In [4]:
import urllib.request

url = "https://raw.githubusercontent.com/ajwdewit/pcse/refs/heads/master/tests/test_data/test_leafdynamics_wofost72_01.yaml"
filename = "test_leafdynamics_wofost72_01.yaml"

urllib.request.urlretrieve(url, filename)
print(f"Downloaded: {filename}")

Downloaded: test_leafdynamics_wofost72_01.yaml


In [5]:
# ---- Check the path to the files that are downloaded as explained above ----
test_data_path = "test_leafdynamics_wofost72_01.yaml"

In [6]:
# ---- Here we read the test data and set some variables ----
test_data = get_test_data(test_data_path)
(crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = (
    prepare_engine_input(test_data, ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"])
)

expected_results = test_data["ModelResults"]
expected_lai_twlv = torch.tensor(
    [[float(item["LAI"]), float(item["TWLV"])] for item in expected_results],
    dtype=ComputeConfig.get_dtype(),
    device=ComputeConfig.get_device(),
).unsqueeze(0) # shape: [1, time_steps, 2]

# ---- Don't change this: this configuration selects the differentiable version of leaf_dynamics ----
leaf_dynamics_config = Configuration(
    CROP=WOFOST_Leaf_Dynamics,
    OUTPUT_VARS=["LAI", "TWLV"],
)

### 1.3. Helper classes/functions

The model parameters should stay within a valid range. To ensure this, we use a
`BoundedParameter` class with (min, max) bounds and an initial value for each
parameter. You may change these values depending on the crop type and
location. Avoid a very narrow range, though: the gradient with respect to the
underlying raw parameter scales with the width of the range, so a narrow range
yields very small gradients and slow optimization.

In [7]:
# ---- Adjust the values if needed  ----
TDWI_MIN, TDWI_MAX, TDWI_INIT = (0.0, 1.0, 0.40)
SPAN_MIN, SPAN_MAX, SPAN_INIT = (10.0, 60.0, 25.0)

# ---- Helper for bounded parameters ----
class BoundedParameter(torch.nn.Module):
    def __init__(self, low, high, init_value):
        super().__init__()
        self.low = low
        self.high = high

        # Normalize to [0, 1]
        init_norm = (init_value - low) / (high - low)

        # Parameter in raw logit space
        self.raw = torch.nn.Parameter(
            torch.logit(
                torch.tensor(
                    init_norm, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device()
                ),
                eps=1e-6,
            )
        )

    def forward(self):
        return self.low + (self.high - self.low) * torch.sigmoid(self.raw)
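
To make the effect of the bounds concrete, the following is a minimal sketch of the same sigmoid/logit mapping written with plain Python floats instead of `torch`. The helper names (`to_raw`, `to_bounded`, `d_bounded_d_raw`) and the narrow range `(24.0, 26.0)` are assumptions introduced for illustration only; they are not part of `diffwofost`.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    return math.log(p / (1.0 - p))


def to_raw(value, low, high):
    """Map a value inside (low, high) to the unconstrained raw space."""
    return logit((value - low) / (high - low))


def to_bounded(raw, low, high):
    """Map a raw value back into (low, high), as BoundedParameter.forward does."""
    return low + (high - low) * sigmoid(raw)


def d_bounded_d_raw(raw, low, high):
    """Derivative of the bounded value with respect to the raw value."""
    s = sigmoid(raw)
    return (high - low) * s * (1.0 - s)


# SPAN bounds and initial value used in this notebook
raw_span = to_raw(25.0, 10.0, 60.0)
span_back = to_bounded(raw_span, 10.0, 60.0)
grad_wide = d_bounded_d_raw(raw_span, 10.0, 60.0)

# Hypothetical narrow range around the same initial value
raw_narrow = to_raw(25.0, 24.0, 26.0)
grad_narrow = d_bounded_d_raw(raw_narrow, 24.0, 26.0)

print(f"round trip: {span_back:.4f}")
print(f"d(SPAN)/d(raw), wide range:   {grad_wide:.4f}")
print(f"d(SPAN)/d(raw), narrow range: {grad_narrow:.4f}")
```

The round trip recovers the initial value, and the derivative is roughly 10.5 for the wide range versus 0.5 for the narrow one. By the chain rule, the gradient of the loss with respect to the raw parameter shrinks by the same factor, which is why a very narrow range slows the optimization.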

Another helper class is `OptDiffLeafDynamics`, a subclass of `torch.nn.Module`.
It wraps `EngineTestHelper` and makes it easier to run the `leaf_dynamics` model.

In [8]:
# ---- Wrap the model with torch.nn.Module----
class OptDiffLeafDynamics(torch.nn.Module):
    def __init__(self, crop_model_params_provider, weather_data_provider, agro_management_inputs, leaf_dynamics_config, external_states):
        super().__init__()
        self.crop_model_params_provider = crop_model_params_provider
        self.weather_data_provider = weather_data_provider
        self.agro_management_inputs = agro_management_inputs
        self.config = leaf_dynamics_config
        self.external_states = external_states

        # bounded parameters
        self.tdwi = BoundedParameter(TDWI_MIN, TDWI_MAX, init_value=TDWI_INIT)
        self.span = BoundedParameter(SPAN_MIN, SPAN_MAX, init_value=SPAN_INIT)

    def forward(self):
        # currently, copying is needed due to an internal issue in engine
        crop_model_params_provider_ = copy.deepcopy(self.crop_model_params_provider)
        external_states_ = copy.deepcopy(self.external_states)
        
        tdwi_val = self.tdwi()
        span_val = self.span()
        
        # pass new value of parameters to the model
        crop_model_params_provider_.set_override("TDWI", tdwi_val, check=False)
        crop_model_params_provider_.set_override("SPAN", span_val, check=False)

        engine = EngineTestHelper(
            crop_model_params_provider_,
            self.weather_data_provider,
            self.agro_management_inputs,
            self.config,
            external_states_,
        )
        engine.run_till_terminate()
        results = engine.get_output()
        
        return torch.stack(
            [torch.stack([item["LAI"], item["TWLV"]]) for item in results]
        ).unsqueeze(0) # shape: [1, time_steps, 2]

In [9]:
# ----  Create model ---- 
opt_model = OptDiffLeafDynamics(
    crop_model_params_provider,
    weather_data_provider,
    agro_management_inputs,
    leaf_dynamics_config,
    external_states,
)
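
The training loop in the next cell minimizes a relative mean absolute error: for each output variable, the MAE over time is divided by the mean absolute observed value, and the two ratios are summed. The sketch below reproduces that computation with plain Python lists so the arithmetic is easy to follow. The function name `relative_mae` and the numbers are hypothetical and chosen only for illustration.

```python
def relative_mae(predicted, observed):
    """Sum over output variables of MAE divided by the mean absolute observation.

    Both arguments are lists of rows: one row per time step,
    one column per output variable.
    """
    n_steps = len(observed)
    n_vars = len(observed[0])
    total = 0.0
    for j in range(n_vars):
        mae = sum(abs(p[j] - o[j]) for p, o in zip(predicted, observed)) / n_steps
        scale = sum(abs(o[j]) for o in observed) / n_steps
        total += mae / scale
    return total


# Hypothetical data: column 0 on an LAI-like scale, column 1 on a TWLV-like scale
observed = [[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]]
predicted = [[1.1, 110.0], [2.2, 220.0], [3.3, 330.0]]  # every value 10% too high

loss = relative_mae(predicted, observed)
print(f"relative MAE loss: {loss:.4f}")
```

Both columns are off by 10%, so each contributes 0.1 to the loss even though their absolute errors differ by a factor of 100. With a plain MAE, the column with the larger magnitude would dominate.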

In [10]:
# ----  Early stopping ---- 
best_loss = float("inf")
patience = 10  # Number of steps to wait for improvement
patience_counter = 0
min_delta = 1e-4 

# ----  Optimizer ---- 
optimizer = torch.optim.Adam(opt_model.parameters(), lr=0.1)

# ----  We use relative MAE as loss because there are two outputs with different units ----
denom = torch.mean(torch.abs(expected_lai_twlv), dim=1) 

# Training loop (example)
for step in range(101):
    optimizer.zero_grad()
    results = opt_model() 
    mae = torch.mean(torch.abs(results - expected_lai_twlv), dim=1)
    rmae = mae / denom
    loss = rmae.sum()  # example: relative mean absolute error
    loss.backward()
    optimizer.step()

    print(f"Step {step}, Loss {loss.item():.4f}, TDWI {opt_model.tdwi().item():.4f}, SPAN {opt_model.span().item():.4f}")
    # Early stopping logic
    if loss.item() < best_loss - min_delta:
        best_loss = loss.item()
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at step {step}")
            break

Step 0, Loss 0.2682, TDWI 0.4242, SPAN 26.0705
Step 1, Loss 0.2400, TDWI 0.4485, SPAN 27.1814
Step 2, Loss 0.2117, TDWI 0.4727, SPAN 28.3303
Step 3, Loss 0.1790, TDWI 0.4962, SPAN 29.5140
Step 4, Loss 0.1502, TDWI 0.5190, SPAN 30.7290
Step 5, Loss 0.1148, TDWI 0.5351, SPAN 31.8327
Step 6, Loss 0.0873, TDWI 0.5459, SPAN 32.8254
Step 7, Loss 0.0633, TDWI 0.5525, SPAN 33.6958
Step 8, Loss 0.0375, TDWI 0.5559, SPAN 34.4362
Step 9, Loss 0.0164, TDWI 0.5564, SPAN 35.0038
Step 10, Loss 0.0001, TDWI 0.5543, SPAN 35.2185
Step 11, Loss 0.0048, TDWI 0.5500, SPAN 35.0892
Step 12, Loss 0.0019, TDWI 0.5436, SPAN 34.7440
Step 13, Loss 0.0092, TDWI 0.5356, SPAN 34.3535
Step 14, Loss 0.0175, TDWI 0.5260, SPAN 33.9622
Step 15, Loss 0.0296, TDWI 0.5151, SPAN 33.5733
Step 16, Loss 0.0409, TDWI 0.5030, SPAN 33.2059
Step 17, Loss 0.0512, TDWI 0.4949, SPAN 33.2068
Step 18, Loss 0.0512, TDWI 0.4903, SPAN 33.4762
Step 19, Loss 0.0422, TDWI 0.4888, SPAN 33.9525
Step 20, Loss 0.0299, TDWI 0.4900, SPAN 34.5985
Early stopping at step 20
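
To see how the patience rule interacts with this loss curve, the early-stopping logic can be replayed on the logged losses alone. The sketch below copies the loss values from the log above (rounded to four decimals, so it only approximates the real run) and reuses the same `patience` and `min_delta`. The function name `early_stop_step` is hypothetical.

```python
def early_stop_step(losses, patience=10, min_delta=1e-4):
    """Return the step at which early stopping triggers, or None if it never does."""
    best_loss = float("inf")
    patience_counter = 0
    for step, loss in enumerate(losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                return step
    return None


# Losses copied from the printed log (rounded to four decimals)
logged_losses = [
    0.2682, 0.2400, 0.2117, 0.1790, 0.1502, 0.1148, 0.0873,
    0.0633, 0.0375, 0.0164, 0.0001, 0.0048, 0.0019, 0.0092,
    0.0175, 0.0296, 0.0409, 0.0512, 0.0512, 0.0422, 0.0299,
]

stop_step = early_stop_step(logged_losses)
best_step = min(range(len(logged_losses)), key=logged_losses.__getitem__)
print(f"stops at step {stop_step}, lowest loss at step {best_step}")
```

The lowest loss occurs ten steps before the loop stops, so the parameters held by the model at the end are not the best ones seen. Note also that the loop prints the parameters after `optimizer.step()`, so each printed loss was computed with the parameter values shown on the previous line.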

In [11]:
# ---- validate the results using test data ---- 
print(f"Actual TDWI {crop_model_params_provider['TDWI'].item():.4f}, SPAN {crop_model_params_provider['SPAN'].item():.4f}")

Actual TDWI 0.5100, SPAN 35.0000
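
As a quick check on how close the optimization came, the relative error of each estimated parameter can be computed against these reference values. The sketch below uses numbers copied from the notebook output: the estimates are the values printed at the last optimization step, which are not necessarily the best ones encountered during the run.

```python
# Values copied from the notebook output
estimated = {"TDWI": 0.4900, "SPAN": 34.5985}  # printed at the last optimization step
actual = {"TDWI": 0.5100, "SPAN": 35.0000}     # reference values from the test data

relative_error = {
    name: abs(estimated[name] - actual[name]) / abs(actual[name]) for name in actual
}

for name, err in relative_error.items():
    print(f"{name}: {100 * err:.2f}% relative error")
```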
