# A simple optimization problem: NoTransport model


In [None]:
import numpy as np
import pandas as pd
import plotly
import xarray as xr
from dask.distributed import Client
from matplotlib import pyplot as plt
from seapopym.configuration.no_transport import ForcingParameter, ForcingUnit

from seapopym_optimization.algorithm.genetic_algorithm.genetic_algorithm import (
    GeneticAlgorithm,
    GeneticAlgorithmParameters,
)
from seapopym_optimization.algorithm.genetic_algorithm.logbook import OptimizationLog
from seapopym_optimization.cost_function.cost_function import CostFunction, DayCycle, TimeSeriesObservation
from seapopym_optimization.functional_group import NoTransportFunctionalGroup, Parameter
from seapopym_optimization.functional_group.base_functional_group import FunctionalGroupSet
from seapopym_optimization.functional_group.parameter_initialization import random_uniform_exclusive
from seapopym_optimization.model_generator import NoTransportConfigurationGenerator

# Enable inline rendering of plotly figures in the notebook (offline mode,
# i.e. the default `connected=False`, written out explicitly).
plotly.offline.init_notebook_mode(connected=False)

In [None]:
import logging

# Root logging at INFO; every record is prefixed by a blank line so log output
# stands out from other cell output. The package logger is also pinned to INFO.
_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "\n%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=_LOG_LEVEL)

logger = logging.getLogger(name="seapopym_optimization")
logger.setLevel(_LOG_LEVEL)

In [None]:
# Search space for a single zooplankton functional group that stays in layer 0
# both by day and by night. Each Parameter is built as
# (name, number, number, init_method); the two numbers are presumably the
# lower/upper bounds of the search interval — TODO confirm against Parameter.
# Every parameter name follows the same "D1N1_<field>" pattern so the names
# are uniform (and free of spaces) wherever the optimizer reports them.
functional_groups = [
    NoTransportFunctionalGroup(
        name="Zooplankton",
        day_layer=0,
        night_layer=0,
        energy_transfert=Parameter("D1N1_energy_transfert", 0.001, 0.3, init_method=random_uniform_exclusive),
        gamma_tr=Parameter("D1N1_gamma_tr", -0.3, -0.001, init_method=random_uniform_exclusive),
        tr_0=Parameter("D1N1_tr_0", 0, 30, init_method=random_uniform_exclusive),
        gamma_lambda_temperature=Parameter(
            "D1N1_gamma_lambda_temperature", 1 / 300, 1, init_method=random_uniform_exclusive
        ),
        lambda_temperature_0=Parameter("D1N1_lambda_temperature_0", 0, 0.3, init_method=random_uniform_exclusive),
    ),
]
fg_set = FunctionalGroupSet(functional_groups=functional_groups)

In [None]:
# Synthetic forcing on a single grid cell (latitude 0, longitude 0, depth 0)
# with one value per day over `nb_years` years.
nb_days_by_year = 365
nb_years = 2
nb_days = nb_days_by_year * nb_years
time_index = pd.date_range("2023-01-01", periods=nb_days, freq="D")

# Seeded generator: the primary-production noise, and therefore the
# "observed" biomass simulated from this forcing later in the notebook, is
# identical on every run.
rng = np.random.default_rng(seed=42)

# Temperature: sinusoid between 15 and 25 with roughly one cycle per year.
temperature = xr.DataArray(
    data=(np.sin(np.linspace(0, 2 * np.pi * nb_years, nb_days)) * 5 + 20).reshape((nb_days, 1, 1, 1)),
    dims=["time", "latitude", "longitude", "depth"],
    coords={
        "time": time_index,
        "latitude": [0],
        "longitude": [0],
        "depth": [0],
    },
    name="temperature",
    attrs={
        "units": "Celsius",
        "long_name": "Sea surface temperature",
        "standard_name": "sea_surface_temperature",
    },
)

# Primary production: uniform noise in [0, 1) plus a cosine evaluated over
# [0, nb_years * pi], shifted by +2 and divided by 100, so every value lies in
# [0.01, 0.04) (always positive).
primary_production = xr.DataArray(
    data=((rng.random(nb_days) + np.cos(np.linspace(0, np.pi * nb_years, nb_days)) + 2) / 100).reshape(
        (nb_days, 1, 1)
    ),
    dims=["time", "latitude", "longitude"],
    coords={
        "time": time_index,
        "latitude": [0],
        "longitude": [0],
    },
    name="primary_production",
    attrs={
        "units": "kg/m^2/day",
        "long_name": "Primary production",
        "standard_name": "primary_production",
    },
)

# Tag each coordinate with its CF axis so the model can identify T/Y/X/Z.
temperature.time.attrs = {"axis": "T"}
primary_production.time.attrs = {"axis": "T"}

temperature.latitude.attrs = {"axis": "Y"}
primary_production.latitude.attrs = {"axis": "Y"}

# NOTE(review): only temperature's longitude carries a unit, and under the
# non-CF key "unit" (CF uses "units"); kept as-is — confirm whether seapopym
# reads it before harmonising.
temperature.longitude.attrs = {"axis": "X", "unit": "degrees_east"}
primary_production.longitude.attrs = {"axis": "X"}

temperature.depth.attrs = {"axis": "Z"}

forcing_parameter = ForcingParameter(
    temperature=ForcingUnit(forcing=temperature),
    primary_production=ForcingUnit(forcing=primary_production),
)

In [None]:
# Two stacked panels sharing the time axis: primary production on top,
# temperature below.
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(10, 8))
for axis, series, label, title in (
    (ax1, primary_production, "Primary Production", "Primary Production Time Series"),
    (ax2, temperature, "Temperature", "Temperature Time Series"),
):
    series.plot(ax=axis, label=label)
    axis.legend()
    axis.set_title(title)
plt.show()

In [None]:
# Configuration generator bound to the synthetic forcing; it is used both for
# the reference run and by the cost function.
model_generator = NoTransportConfigurationGenerator(
    forcing_parameters=forcing_parameter,
)

In [None]:
# Twin experiment: simulate the model once with known parameter values and use
# the resulting biomass as the "observation" that the optimizer must fit.
# These values lie inside the search intervals declared for the Zooplankton group.
reference_parameters = {
    "energy_transfert": 0.1668,
    "day_layer": 0,
    "night_layer": 0,
    "gamma_tr": -0.11,
    "tr_0": 10.38,
    "gamma_lambda_temperature": 0.15,
    "lambda_temperature_0": 1 / 150,
}
initial_model = model_generator.generate(
    functional_group_parameters=[reference_parameters],
    functional_group_names=["Zooplankton"],
)
initial_model.run()

# Add a singleton "layer" dimension (tagged as the Z axis), then select and
# drop the single functional group. Presumably this is the layout that
# TimeSeriesObservation expects — confirm.
observed_biomass = (
    initial_model.state.biomass.expand_dims({"layer": [0]})
    .isel(functional_group=0)
    .drop_vars(["functional_group"])
)
observed_biomass.layer.attrs = {"axis": "Z"}

plt.figure(figsize=(10, 4))
observed_biomass.plot()
plt.title("Observed Biomass")
# Render explicitly (as the forcing-plot cell does) so the cell does not echo
# the Text object returned by plt.title.
plt.show()


In [None]:
# Wrap the simulated biomass as an observation. DayCycle.DAY presumably
# restricts the comparison to the daytime layer — confirm in cost_function.
observation = TimeSeriesObservation(
    observation=observed_biomass,
    name="Zooplankton Biomass",
    observation_type=DayCycle.DAY,
)

In [None]:
# Cost function combining the model generator, the observation(s) and the
# parameter search space; presumably scores a candidate parameter set by
# comparing the simulated and observed biomass.
cost_function = CostFunction(
    model_generator=model_generator,
    observations=[observation],
    functional_groups=fg_set,
)

In [None]:
# Seed the optimization log with Sobol samples of the parameter space.
# `fitness_names` reuses the observation's name so the columns line up.
logbook = OptimizationLog.from_sobol_samples(
    fg_set,
    sample_number=2,
    fitness_names=["Zooplankton Biomass"],
)
logbook

In [None]:
# Genetic-algorithm hyper-parameters. The names follow DEAP conventions; the
# meanings below are presumed from those conventions — confirm in
# GeneticAlgorithmParameters.
metaparam = GeneticAlgorithmParameters(
    ETA=20,  # presumably the distribution index of the bounded crossover/mutation operators
    INDPB=0.2,  # presumably the per-gene mutation probability
    CXPB=0.7,  # presumably the crossover probability
    MUTPB=1,  # presumably the mutation probability
    NGEN=5,  # presumably the number of generations
    POP_SIZE=10,  # population size
    cost_function_weight=(-1,),  # negative weight: presumably minimise the single fitness
)
# The optimizer evaluates individuals through a Dask client (a default local
# cluster is started by Client()).
genetic_algorithm = GeneticAlgorithm(
    meta_parameter=metaparam,
    cost_function=cost_function,
    client=Client(),
    logbook=logbook,
)

In [None]:
# Presumably pushes the shared inputs (forcing / observations) to the Dask
# workers ahead of the optimization — confirm in GeneticAlgorithm.distribute_data.
genetic_algorithm.distribute_data()

In [None]:
# Display the Dask client used by the genetic algorithm.
genetic_algorithm.client

In [None]:
# Run the genetic algorithm configured by `metaparam`; the returned object
# presumably wraps the optimization logbook (its `.dataset` is inspected below).
optimization_results = genetic_algorithm.optimize()

In [None]:
# Display the object returned by optimize().
optimization_results

In [None]:
# Display the results dataset (presumably an xarray Dataset, given the
# dimension-wise reduction and .plot() used in the next cell).
optimization_results.dataset

In [None]:
# Minimum weighted fitness over the "individual" dimension, plotted along the
# remaining dimension(s) (presumably generations).
# NOTE(review): whether the minimum is the *best* individual depends on how
# weighted_fitness applies cost_function_weight=(-1,) — confirm the sign convention.
optimization_results.dataset.weighted_fitness.min("individual").plot()
plt.show()

# Optimization is finished: close the Dask client created with Client() in the
# GA cell (which also shuts down the local cluster it started) so its worker
# processes are not left running.
genetic_algorithm.client.close()