In [None]:
import warnings

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import xarray as xr
from dask.distributed import Client
from pint import UnitStrippedWarning
from seapopym.configuration.no_transport.parameter import ForcingParameters
from seapopym.configuration.parameters.parameter_forcing import ForcingUnit
from seapopym.standard.units import StandardUnitsLabels

from seapopym_optimization import Observation, constraint
from seapopym_optimization.cost_function import NoTransportCostFunction
from seapopym_optimization.functional_groups import FunctionalGroupOptimizeNoTransport, Parameter
from seapopym_optimization.genetic_algorithm import GeneticAlgorithm, GeneticAlgorithmParameters
from seapopym_optimization.taylor_diagram import ModTaylorDiagram, generate_mod_taylor_diagram

# Pint warns whenever units are stripped (e.g. when quantified arrays are converted
# with `.to_series()` later on); those warnings are expected here, so silence them.
warnings.simplefilter(action="ignore", category=UnitStrippedWarning)
# Collapse every section of xarray's HTML repr so displayed datasets stay compact.
collapsed_sections = dict(
    display_expand_attrs=False,
    display_expand_data_vars=False,
    display_expand_coords=False,
    display_expand_data=False,
)
xr.set_options(**collapsed_sections)

In [None]:
# Every input product lives in the same processed-data directory.
products_dir = "../../../1_data_processing/1_1_Forcing/data/1_products"
path_to_forcing = products_dir + "/Hot_cmems_climato.zarr"
path_to_npp = products_dir + "/Hot_observed_npp_climato.zarr"
path_to_obs = products_dir + "/Hot_obs_zoo_climato_monthly_2002_2015.zarr"
# Prefix shared by every file this notebook exports (logbook, figures).
export_file_name = "SeapoPym_HOT_climato_obs_npp_opti_all_parameters_1_group_night"

In [None]:
# Coordinates of the HOT station, used to pick the nearest observation point.
LATITUDE, LONGITUDE = 22.75, -158
# Simulated period; xarray label slices include both bounds.
TIME_START, TIME_END = "2005-01-02", "2009-12-27"
# Number of leading observation time steps discarded while the model spins up.
STABILIZATION_TIME = 5
# When True, figures are written to disk.
SAVE = True

## Loading


### Forcing


In [None]:
# Open the forcing store restricted to the simulated period.
forcing = xr.open_zarr(path_to_forcing).sel(time=slice(TIME_START, TIME_END))
# Relabel temperature with SeapoPym's standard unit string.
forcing["T"].attrs["units"] = StandardUnitsLabels.temperature.units
# Pull everything into memory (and display the dataset).
forcing.load()

### Epipelagic layer


In [None]:
# Thickness of the epipelagic layer (surface slice of `pelagic_layer_depth`),
# averaged to daily values and attached to pint as meters.
epi_layer_depth = (
    forcing["pelagic_layer_depth"]
    .sel(depth=0)
    .load()
    .resample(time="1D")
    .mean()
    .assign_attrs(units="meter")
    .pint.quantify()
)
epi_layer_depth

### Observed NPP


In [None]:
# Observed NPP over the simulated period: drop fully empty time steps, then
# linearly interpolate onto a daily axis.
observed_npp = (
    xr.open_zarr(path_to_npp)
    .sel(time=slice(TIME_START, TIME_END))
    .dropna("time", how="all")
    .resample(time="D")
    .interpolate("linear")
)
observed_npp.load()

### Observations


In [None]:
# Zooplankton observations at the HOT station, turned from a concentration into
# biomass per unit area over the epipelagic layer.
observations = xr.open_zarr(path_to_obs).load()
observations = observations.sel(latitude=LATITUDE, longitude=LONGITUDE, method="nearest")
# Daily resampling leaves NaN on days without a sample; dropna (default how="any")
# removes them, but also removes any date where *one* of the variables is missing.
observations = observations.resample(time="1D").mean().dropna("time")
observations = observations.pint.quantify().pint.to("mg/m^3")
# mg/m^3 x layer thickness (m) -> mg/m^2. xarray aligns on time with an inner join
# by default, so this also restricts the observations to the forcing period.
observations = observations * epi_layer_depth
# Drop the scalar `depth` coordinate (presumably inherited from `.sel(depth=0)` on
# the layer depth).
observations = observations.drop_vars("depth")
observations


Select the kind of observation you want to use.


In [None]:
# Variables of `observations` to fit against. An alternative is the LOWESS-smoothed
# pair: observations[["day_lowess_0.2", "night_lowess_0.2"]] renamed to "day"/"night".
selected_variables = ["night"]
observations_selected = observations[selected_variables]

Remove the first `STABILIZATION_TIME` months of observations so that the model has time to reach a stationary state before it is compared with them.


In [None]:
# Skip the first STABILIZATION_TIME samples so the model spin-up is not scored.
observations_selected_without_init = observations_selected.isel({"time": slice(STABILIZATION_TIME, None)})
observations_selected_without_init

Create the forcing structure for the SeapoPym simulation.


In [None]:
# Both forcings share one resolution and timestep (presumably 1/12 degree and
# 1 day — confirm against the SeapoPym ForcingUnit documentation).
forcing_resolution = 1 / 12
forcing_timestep = 1
forcing_parameters = ForcingParameters(
    temperature=ForcingUnit(
        forcing=forcing["T"],
        resolution=forcing_resolution,
        timestep=forcing_timestep,
    ),
    # Observed NPP (variable "l12") drives primary production.
    primary_production=ForcingUnit(
        forcing=observed_npp["l12"],
        resolution=forcing_resolution,
        timestep=forcing_timestep,
    ),
)

## Setup the parameters and the cost function


In [None]:
# A single zooplankton group that stays in layer 0 both day and night.
# Each Parameter is presumably (name, lower bound, upper bound) for the search.
group_prefix = "D1N1"
functional_groups = [
    FunctionalGroupOptimizeNoTransport(
        name="Zooplankton",
        day_layer=0,
        night_layer=0,
        energy_coefficient=Parameter(f"{group_prefix}_energy_coefficient", 0.001, 0.4),
        tr_rate=Parameter(f"{group_prefix}_tr_rate", -0.3, -0.001),
        tr_max=Parameter(f"{group_prefix}_tr_max", 0, 50),
        inv_lambda_rate=Parameter(f"{group_prefix}_inv_lambda_rate", -0.3, -0.001),
        inv_lambda_max=Parameter(f"{group_prefix}_inv_lambda_max", 100, 200),
    ),
]

In [None]:
# The HOT climatology (spin-up removed) is the only target; it is tagged as monthly.
hot_observation = Observation(
    name="Hot climato",
    observation=observations_selected_without_init,
    observation_type="monthly",
)
# Score simulations with a normalized root-mean-square error.
cost_function = NoTransportCostFunction(
    functional_groups=functional_groups,
    forcing_parameters=forcing_parameters,
    observations=[hot_observation],
    normalized_mse=True,
    root_mse=True,
)

Set the genetic algorithm meta parameters.


In [None]:
genetic_algo_parameters = GeneticAlgorithmParameters(
    # Variation operators; names presumably follow DEAP conventions (crossover
    # probability, mutation probability, per-gene mutation probability, and the
    # crowding degree ETA of the bounded polynomial mutation) — confirm.
    CXPB=0.7,
    MUTPB=0.30,
    INDPB=0.2,
    ETA=5,
    # Population size and number of generations.
    POP_SIZE=1000,
    NGEN=10,
    # One negative weight: the cost is presumably minimized.
    cost_function_weight=(-1,),
)

Finally, create the genetic algorithm.


In [None]:
# Local Dask cluster used to evaluate individuals in parallel.
client = Client()
logbook_path = f"{export_file_name}_logbook.json"
genetic_algo = GeneticAlgorithm(
    cost_function=cost_function,
    parameter_genetic_algorithm=genetic_algo_parameters,
    client=client,
    logbook_path=logbook_path,
)

Then follow the computation on the Dask dashboard:


In [None]:
# Display the Dask client (presumably the one passed above); its notebook repr links to the dashboard.
genetic_algo.client

## Run the optimization


In [None]:
# Run the genetic algorithm — the expensive step. Every cell below inspects the returned viewer.
viewer = genetic_algo.optimize()

## Optimization statistics


In [None]:
# First 10 entries of the hall of fame (presumably the best individuals ranked by fitness — confirm).
viewer.hall_of_fame.head(10)

In [None]:
# Presumably plots how fitness evolves across generations.
viewer.fitness_evolution()

In [None]:
# Presumably shows the spread of each (standardized) parameter across the population — confirm.
viewer.parameters_standardized_deviation()

In [None]:
# Pairwise scatter of the optimized parameters for the `nbest` best individuals.
# NOTE(review): nbest=2000 exceeds POP_SIZE=1000; confirm how many individuals the
# viewer keeps (all generations?) and that this value is intended.
viewer.parameters_scatter_matrix(nbest=2000)

In [None]:
# Box plots over the 500 best individuals (presumably one box per parameter).
# NOTE(review): the meaning of the positional argument `3` is not visible here — document it.
fig = viewer.box_plot(3, nbest=500)
fig.show()

In [None]:
# One parallel-coordinates figure per parameter group (a single group here).
groups = [
    [
        "D1N1_energy_coefficient",
        "D1N1_tr_rate",
        "D1N1_tr_max",
        "D1N1_inv_lambda_rate",
        "D1N1_inv_lambda_max",
    ]
]

fig = viewer.parallel_coordinates(
    nbest=100,
    unselected_opacity=0,
    parameter_groups=groups,
    uniformed=True,
)

for group_figure in fig:
    display(group_figure)

In [None]:
if SAVE:
    # Export each parallel-coordinates figure to its own numbered HTML file.
    for index, group_figure in enumerate(fig):
        group_figure.write_html(f"Parallel_coordinates_{export_file_name}_{index}.html")

## Plots


### Time series of the best individuals


In [None]:
# Rebuild the viewer's time-series figure on a single set of axes and overlay the
# simulation obtained with the original (non-optimized) parameterization.
# `plotly.graph_objects as go` is imported in the top import cell.
interval = 50  # first positional argument of viewer.time_series — TODO document its meaning
old_figure = viewer.time_series(interval, title=["HOT"], client=client)[0]

figure_update = go.Figure()

# Re-target the copied traces to the primary axes (the source figure may use subplots).
figure_update.add_trace(old_figure.data[0].update(name="Optimal parameterization", xaxis="x", yaxis="y"))
figure_update.add_trace(old_figure.data[1].update(xaxis="x", yaxis="y", showlegend=False))
figure_update.add_trace(old_figure.data[2].update(name="Observations", xaxis="x", yaxis="y"))

# Original simulation as areal biomass, averaged over groups and grid cells, then
# flattened into a (time, biomass) table for plotting.
original = (
    viewer.original_simulation.pint.quantify()
    .pint.to("mg/m2")
    .mean(["functional_group", "latitude", "longitude"])
    .to_series()
    .reset_index()
)
figure_update.add_trace(
    go.Scatter(
        x=original["time"],
        y=original["biomass"],
        mode="lines",
        name="Original parameterization",
        line=dict(color="red", width=2),
    )
)

figure_update.update_layout(
    title="Comparison of parameterization at the HOT station",
    xaxis_title="Time",
    yaxis_title="Biomass (mg/m2)",
    legend_title="",
    width=700,
    height=400,
)
figure_update

In [None]:
if SAVE:
    # Export the comparison figure built in the previous cell. The name `fig` still
    # holds the parallel-coordinates figures (iterated over earlier), which is not a
    # single plotly figure — writing it here was a bug.
    figure_update.write_html(f"Biomass_best_individuals_{export_file_name}.html")

### Taylor Diagram


In [None]:
fig = viewer.taylor_diagram(1, client=client)
# Hide the legend; update_layout returns the same figure, so show it directly.
fig.update_layout(showlegend=False).show()

In [None]:
# Night-time observed biomass in mg/m2, averaged spatially, over 2006-2007.
night_biomass = observations["night"].pint.quantify().pint.to("mg/m2")
monthly_obs = night_biomass.mean(["latitude", "longitude", "layer"]).sel(time=slice("2006", "2007"))
monthly_obs

In [None]:
# Best individual's simulation, reduced like the observations and interpolated
# onto the observation timestamps.
best_simulation = viewer.best_individuals_simulations(1)
monthly_pred_opti = (
    best_simulation.pint.quantify()
    .pint.to("mg/m2")
    .mean(["latitude", "longitude", "individual", "functional_group"])
    .sel(time=slice("2006", "2007"))
    .interp_like(monthly_obs)
)
monthly_pred_opti

In [None]:
# Original-parameterization simulation, reduced and aligned the same way.
original_biomass = viewer.original_simulation.pint.quantify().pint.to("mg/m2")
monthly_pred_original = (
    original_biomass.mean(["latitude", "longitude", "functional_group"])
    .sel(time=slice("2006", "2007"))
    .interp_like(monthly_obs)
)
monthly_pred_original

In [None]:
diagram = ModTaylorDiagram()

# Both simulations are compared against the same observed series.
all_model = [monthly_pred_opti, monthly_pred_original]
all_obs = [monthly_obs] * len(all_model)
all_names = ["Prediction", "Original parameterization"]

for name, model, obs in zip(all_names, all_model, all_obs):
    diagram = generate_mod_taylor_diagram(
        diagram,
        obs=obs.to_series(),
        model=model.to_series(),
        name=name,
    )

diagram.plot()
plt.title("Comparison of the original and optimized parameterization\n")

# Save before plt.show() so the current figure is the diagram being exported.
if SAVE:
    plt.savefig(f"Taylor_{export_file_name}.png")

plt.show()