# Spatial Optimization Example

This notebook demonstrates how to perform a spatial optimization using the `SpatialObservation` and `SpatialScoreProcessor` classes.


In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import Client
from seapopym.configuration.no_transport import ForcingParameter, ForcingUnit, KernelParameter
from seapopym.model import NoTransportModel
from seapopym.standard.labels import CoordinatesLabels

from seapopym_optimization.algorithm.genetic_algorithm.factory import GeneticAlgorithmFactory
from seapopym_optimization.algorithm.genetic_algorithm.genetic_algorithm import GeneticAlgorithmParameters
from seapopym_optimization.algorithm.genetic_algorithm.logbook import Logbook, LogbookCategory, LogbookIndex
from seapopym_optimization.configuration_generator.no_transport_configuration_generator import (
    NoTransportConfigurationGenerator,
)
from seapopym_optimization.cost_function import SpatialScoreProcessor
from seapopym_optimization.cost_function.cost_function import CostFunction
from seapopym_optimization.cost_function.metric import nrmse_std_comparator
from seapopym_optimization.functional_group import NoTransportFunctionalGroup, Parameter
from seapopym_optimization.functional_group.base_functional_group import FunctionalGroupSet
from seapopym_optimization.functional_group.parameter_initialization import random_uniform_exclusive
from seapopym_optimization.observations.observation import DayCycle
from seapopym_optimization.observations.spatial import SpatialObservation

logging.basicConfig(level=logging.INFO)

# Seed NumPy's global generator so that the synthetic primary production and the
# observation masks drawn later with `np.random.rand` are identical on every
# Restart & Run All. This only covers NumPy's global state in this process; any
# other source of randomness (e.g. inside the genetic algorithm) is not affected.
SEED = 42
np.random.seed(SEED)

# Dask distributed client created with default arguments; it is handed to the
# distributed genetic algorithm further down in the notebook.
client = Client()

## 1. Generate Synthetic Data

We generate synthetic forcing data over 365 days: temperature on a 2x2x2 grid (latitude × longitude × depth) and primary production on a 2x2 grid (latitude × longitude).


In [None]:
# Daily time axis for the synthetic forcing.
nb_days = 365
time_index = pd.date_range("2023-01-01", periods=nb_days, freq="D")

# Coordinates common to both forcing fields (2 latitudes x 2 longitudes).
surface_coords = {
    "time": time_index,
    "latitude": [0, 1],
    "longitude": [0, 1],
}

# Temperature: constant 25 degrees on (time, latitude, longitude, depth).
temperature = xr.DataArray(
    data=np.full((nb_days, 2, 2, 2), 25.0),
    dims=["time", "latitude", "longitude", "depth"],
    coords={**surface_coords, "depth": [0, 1]},
    name="temperature",
    attrs={
        "units": "Celsius",
        "long_name": "Sea surface temperature",
        "standard_name": "sea_surface_temperature",
    },
)

# Primary production: uniform random noise in [0, 0.1) on (time, latitude, longitude).
primary_production = xr.DataArray(
    data=np.random.rand(nb_days, 2, 2) * 0.1,
    dims=["time", "latitude", "longitude"],
    coords=dict(surface_coords),
    name="primary_production",
    attrs={
        "units": "kg/m^2/day",
        "long_name": "Primary production",
        "standard_name": "primary_production",
    },
)

# Tag every coordinate with its axis attribute (T/Y/X/Z).
temperature_axis_attrs = {
    "time": {"axis": "T"},
    "latitude": {"axis": "Y"},
    "longitude": {"axis": "X", "unit": "degrees_east"},
    "depth": {"axis": "Z"},
}
for coord_name, coord_attrs in temperature_axis_attrs.items():
    temperature[coord_name].attrs = coord_attrs

production_axis_attrs = {
    "time": {"axis": "T"},
    "latitude": {"axis": "Y"},
    "longitude": {"axis": "X"},
}
for coord_name, coord_attrs in production_axis_attrs.items():
    primary_production[coord_name].attrs = coord_attrs

# Bundle both fields into the forcing object consumed by the configuration generator.
temperature_unit = ForcingUnit(forcing=temperature)
primary_production_unit = ForcingUnit(forcing=primary_production)
forcing_parameter = ForcingParameter(
    temperature=temperature_unit,
    primary_production=primary_production_unit,
)

In [None]:
# Break the uniformity of the forcing so that the grid cells differ from each other.
# NOTE: these updates are applied in place, so re-running this cell applies the
# offsets and scalings a second time; re-run the data generation cell first.
# NOTE(review): `forcing_parameter` was built from these arrays in the previous
# cell; presumably it shares their data and therefore sees these changes — confirm.

# Temperature offsets, selected by position along the named dimension.
temperature[{"depth": 1}] -= 10
temperature[{"longitude": 0}] -= 5
temperature[{"latitude": 0}] -= 2.5

# Primary production scalings, selected the same way.
primary_production[{"latitude": 0}] *= 1.5
primary_production[{"longitude": 0}] *= 2

## 2. Define Functional Group

We define a `Zooplankton` functional group that migrates between layer 1 (day) and layer 0 (night).


In [None]:
# Generator that builds a model configuration from functional groups, forcing and kernel.
configuration_generator = NoTransportConfigurationGenerator()

# Reference Zooplankton group with fixed, known parameter values. The model run
# with this configuration provides the synthetic "true" biomass.
reference_zooplankton = NoTransportFunctionalGroup(
    name="Zooplankton",
    day_layer=1,
    night_layer=0,
    energy_transfert=0.1668,
    gamma_tr=-0.11,
    tr_0=10.38,
    gamma_lambda_temperature=0.15,
    lambda_temperature_0=1 / 150,
)

initial_config = configuration_generator.generate(
    functional_group_parameters=[reference_zooplankton],
    forcing_parameters=forcing_parameter,
    kernel=KernelParameter(),
)


## 3. Generate Spatial Observations

We generate fragmented spatial observations. We'll simulate 'true' biomass first, then sample from it.


In [None]:
def _fragment_biomass(biomass, layer, keep_fraction=0.01):
    """Build a sparse observation array from simulated biomass.

    Takes the biomass of the functional group labelled 0, keeps each value with
    probability `keep_fraction` (the others become NaN), and adds a length-1
    `CoordinatesLabels.Z` dimension whose coordinate is `layer`.

    Parameters
    ----------
    biomass : xarray.DataArray
        Simulated biomass with a `functional_group` dimension.
    layer : int
        Depth layer value stored on the added Z dimension.
    keep_fraction : float
        Probability for each value to be kept.

    Returns
    -------
    xarray.DataArray
        The fragmented biomass with the extra Z dimension.
    """
    group_biomass = biomass.sel(functional_group=0).copy()
    keep = np.random.rand(*group_biomass.shape) < keep_fraction
    fragmented = group_biomass.where(keep)
    return fragmented.expand_dims(dim={CoordinatesLabels.Z: [layer]})


# Run the reference model once to obtain the "true" biomass.
with NoTransportModel.from_configuration(initial_config) as initial_model:
    initial_model.run()
    true_biomass = initial_model.state.biomass


# Day observation: about 1% of the values, tagged with layer 1.
day_obs_data = _fragment_biomass(true_biomass, layer=1)
day_observation = SpatialObservation(name="day_obs", observation=day_obs_data, observation_type=DayCycle.DAY)

# Night observation: an independent 1% sample, tagged with layer 0.
night_obs_data = _fragment_biomass(true_biomass, layer=0)
night_observation = SpatialObservation(name="night_obs", observation=night_obs_data, observation_type=DayCycle.NIGHT)

## 4. Optimization Setup

We set up the cost function using `SpatialScoreProcessor` and run the genetic algorithm.
We define the functional group with parameters to optimize.


In [None]:
# Functional group to optimize: same layers as the reference group, but every
# biological parameter is a `Parameter` with a search interval instead of a value.
# `epsilon` (machine epsilon) keeps the bounds strictly away from zero.
# NOTE(review): the parameter names carry a "D1N1" prefix although the group uses
# day_layer=1 / night_layer=0. They are kept as is because the same names are used
# as keys when reading the best individual later in the notebook.
epsilon = np.finfo(float).eps
functional_groups = [
    NoTransportFunctionalGroup(
        name="Zooplankton",
        day_layer=1,
        night_layer=0,
        energy_transfert=Parameter("D1N1_energy_transfert", epsilon, 0.5, init_method=random_uniform_exclusive),
        gamma_tr=Parameter("D1N1_gamma_tr", -0.3, -epsilon, init_method=random_uniform_exclusive),
        tr_0=Parameter("D1N1_tr_0", epsilon, 100, init_method=random_uniform_exclusive),
        gamma_lambda_temperature=Parameter(
            "D1N1_gamma_lambda_temperature", epsilon, 1 / 4, init_method=random_uniform_exclusive
        ),
        lambda_temperature_0=Parameter("D1N1_lambda_temperature_0", epsilon, 0.3, init_method=random_uniform_exclusive),
    ),
]

fg_set = FunctionalGroupSet(functional_groups=functional_groups)

# Cost function: one score per observation (day and night), computed by the
# spatial score processor with the `nrmse_std_comparator` metric.
cost_function = CostFunction(
    configuration_generator=configuration_generator,
    functional_groups=fg_set,
    forcing=forcing_parameter,
    kernel=KernelParameter(),
    observations=[day_observation, night_observation],
    processor=SpatialScoreProcessor(comparator=nrmse_std_comparator),
)

# Initial logbook built from Sobol samples; the fitness names match the
# observation names "day_obs" and "night_obs".
logbook = Logbook.from_sobol_samples(fg_set, sample_number=16, fitness_names=["day_obs", "night_obs"])

# Genetic algorithm settings. One weight per fitness value; the negative weights
# presumably mean both scores are minimized — confirm against the library.
ga_params = GeneticAlgorithmParameters(
    ETA=20,
    INDPB=0.2,
    CXPB=0.9,
    MUTPB=1,
    NGEN=5,
    POP_SIZE=100,
    cost_function_weight=(-1, -1),
)

# Distributed variant of the algorithm, evaluated through the Dask `client`.
genetic_algorithm = GeneticAlgorithmFactory.create_distributed(
    meta_parameter=ga_params,
    cost_function=cost_function,
    client=client,
    logbook=logbook,
)


In [None]:
# Run the optimization. `results` is queried in the next section with
# `sort_values` and column tuples, so it is presumably a pandas DataFrame — confirm.
results = genetic_algorithm.optimize()

## 5. Plot comparison


In [None]:
# Order the individuals by weighted fitness, highest first, and keep the
# parameter values ("Parametre" columns) of the top-ranked one.
weighted_fitness_column = ("Weighted_fitness", "Weighted_fitness")
ranked_results = results.sort_values(by=weighted_fitness_column, ascending=False)
best_individual = ranked_results.iloc[0]["Parametre"]
best_individual

In [None]:
# Create a configuration generator
configuration_generator = NoTransportConfigurationGenerator()

# Build a configuration from the best parameters found by the optimization.
# Distinct names (`best_config`, `best_model`) are used so that the reference
# `initial_config` / `initial_model` from the synthetic-data step are not overwritten.
best_config = configuration_generator.generate(
    functional_group_parameters=[
        NoTransportFunctionalGroup(
            name="Zooplankton",
            day_layer=1,
            night_layer=0,
            energy_transfert=best_individual["D1N1_energy_transfert"],
            gamma_tr=best_individual["D1N1_gamma_tr"],
            tr_0=best_individual["D1N1_tr_0"],
            gamma_lambda_temperature=best_individual["D1N1_gamma_lambda_temperature"],
            lambda_temperature_0=best_individual["D1N1_lambda_temperature_0"],
        )
    ],
    forcing_parameters=forcing_parameter,
    kernel=KernelParameter(),
)

# Run the model with the optimized parameters to get the predicted biomass,
# which is compared with `true_biomass` below.
with NoTransportModel.from_configuration(best_config) as best_model:
    best_model.run()
    predicted_biomass = best_model.state.biomass

In [None]:
# Display the biomass simulated with the best parameters found by the optimization.
predicted_biomass

In [None]:
# Display the reference biomass simulated with the known parameters, for comparison.
true_biomass

In [None]:
# One panel per grid cell, listed as (latitude, longitude) pairs.
# The "OBS" curve is the full reference biomass (`true_biomass`), not the sparse
# observations given to the optimizer; "PRED" is the biomass from the best parameters.
grid_cells = [(0, 0), (0, 1), (1, 0), (1, 1)]
fig, axes = plt.subplots(4, 1, figsize=(8, 9), sharex=True)
for ax, (lat, lon) in zip(axes, grid_cells):
    location = {"X": lon, "Y": lat}
    true_biomass.sel(**location).squeeze().plot(ax=ax, label="OBS")
    predicted_biomass.sel(**location).squeeze().plot(ax=ax, label="PRED")
    ax.set_title(f"X={lon} ; Y={lat}")
    ax.set_xlabel("")
    ax.legend(loc="lower right")