In [1]:
from pathlib import Path

import dask
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import Client
from seapopym.configuration.no_transport.parameter import ForcingParameters, ForcingUnit, KernelParameters

from seapopym_optimization import wrapper

User parameters

A batch of 1000 samples takes about 48 seconds to run on my machine.


In [2]:
# Number of parameter samples evaluated per batch; the output file is rewritten after each batch.
nb_samples_by_batch = 1000

# Statistics computed on each station's biomass time series; they form the second level
# of the output columns (station x quantity of interest).
quantity_of_interest = ["mean", "variance", "argmax"]

# The forcing is restricted to [time_start, time_end] before the model runs, while the
# quantities of interest are computed only on [time_start_analysis, time_end].
# NOTE(review): the first year presumably serves as model spin-up — confirm.
time_start = "2005-01-01"
time_start_analysis = "2006-01-01"
time_end = "2007-01-01"

In [3]:
# Read the station table and key it by station name in a single, idempotent statement.
stations_locations = pd.read_json(
    "../1_data_processing/1_3_Sensibility/stations_locations.json"
).set_index("name")
stations_locations

Unnamed: 0_level_0,longitude,latitude,temperature,primary production
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
BARENTS,26.969,74.62,4.036164,121.380569
HOT,-158.004,22.752,23.839729,254.277267
BATS,-64.2,31.604,21.537741,265.166229
PAPA,-149.996,50.006,6.785365,276.715942
GUAM,149.995,13.001,27.390701,112.10244


In [4]:
# Two-level column index: every station crossed with every quantity of interest.
multi_index_columns = pd.MultiIndex.from_product(
    [stations_locations.index, quantity_of_interest],
    names=["station", "quantity_of_interest"],
)
# Flat "<station>_<quantity>" labels built from the same pairs. The comprehension
# variables are deliberately named differently from the module-level
# `quantity_of_interest` list to avoid confusion.
column_index_flatten = pd.Index(
    [f"{station_name}_{qoi_name}" for station_name, qoi_name in multi_index_columns],
    name="station",
)
multi_index_columns

MultiIndex([('BARENTS',     'mean'),
            ('BARENTS', 'variance'),
            ('BARENTS',   'argmax'),
            (    'HOT',     'mean'),
            (    'HOT', 'variance'),
            (    'HOT',   'argmax'),
            (   'BATS',     'mean'),
            (   'BATS', 'variance'),
            (   'BATS',   'argmax'),
            (   'PAPA',     'mean'),
            (   'PAPA', 'variance'),
            (   'PAPA',   'argmax'),
            (   'GUAM',     'mean'),
            (   'GUAM', 'variance'),
            (   'GUAM',   'argmax')],
           names=['station', 'quantity_of_interest'])

In [5]:
# Start a Dask distributed client with default settings (a local cluster, as the
# summary displayed below shows) so that later `dask.compute` calls can run on its workers.
client = Client()
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 12,Total memory: 48.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:50813,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 12
Started: Just now,Total memory: 48.00 GiB

0,1
Comm: tcp://127.0.0.1:50826,Total threads: 3
Dashboard: http://127.0.0.1:50829/status,Memory: 12.00 GiB
Nanny: tcp://127.0.0.1:50816,
Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-yevi6z6o,Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-yevi6z6o

0,1
Comm: tcp://127.0.0.1:50825,Total threads: 3
Dashboard: http://127.0.0.1:50830/status,Memory: 12.00 GiB
Nanny: tcp://127.0.0.1:50818,
Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-o7wlpgki,Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-o7wlpgki

0,1
Comm: tcp://127.0.0.1:50824,Total threads: 3
Dashboard: http://127.0.0.1:50828/status,Memory: 12.00 GiB
Nanny: tcp://127.0.0.1:50820,
Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-yxvbaiah,Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-yxvbaiah

0,1
Comm: tcp://127.0.0.1:50827,Total threads: 3
Dashboard: http://127.0.0.1:50831/status,Memory: 12.00 GiB
Nanny: tcp://127.0.0.1:50822,
Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-562v5giy,Local directory: /var/folders/z_/8j3qx1mn0299kkpjgz9g53780000gq/T/dask-scratch-space/worker-562v5giy


Samples (Sobol sequence)


In [6]:
# Load the pre-generated parameter samples: one row per sample, one column per model parameter.
input_parameters = pd.read_parquet("./input_samples.parquet")
input_parameters

Unnamed: 0,energy_transfert,tr_0,gamma_tr,lambda_0,gamma_lambda
0,0.322816,9.083201,-0.319703,140.064326,-0.059983
1,0.003590,9.083201,-0.319703,140.064326,-0.059983
2,0.322816,41.212428,-0.319703,140.064326,-0.059983
3,0.322816,9.083201,-0.140274,140.064326,-0.059983
4,0.322816,9.083201,-0.319703,323.008823,-0.059983
...,...,...,...,...,...
1190695,0.251086,18.763531,-0.078224,88.508398,-0.361641
1190696,0.251086,34.402784,-0.371647,88.508398,-0.361641
1190697,0.251086,34.402784,-0.078224,473.318956,-0.361641
1190698,0.251086,34.402784,-0.078224,88.508398,-0.123918


Set up the output file. It is filled batch after batch with the quantity-of-interest (QoI) values.


In [7]:
output_sobol_index_filepath = Path("./output_sobol_index.parquet")
if not output_sobol_index_filepath.exists():
    # First run: create an empty table with the expected columns and persist it right away.
    output_sobol_index = pd.DataFrame(columns=multi_index_columns)
    output_sobol_index.to_parquet(output_sobol_index_filepath)
else:
    # Resume from the results already stored on disk.
    output_sobol_index = pd.read_parquet(output_sobol_index_filepath)
output_sobol_index

station,BARENTS,BARENTS,BARENTS,HOT,HOT,HOT,BATS,BATS,BATS,PAPA,PAPA,PAPA,GUAM,GUAM,GUAM
quantity_of_interest,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax
0,0.004000,3.479870e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006814,7.913743e-07,287.0,0.001023,6.137193e-09,113.0
1,0.000044,4.303894e-10,205.0,0.000033,2.816304e-11,50.0,0.000043,2.916630e-10,97.0,0.000076,9.787696e-11,287.0,0.000011,7.590464e-13,113.0
2,0.003988,3.623789e-06,208.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006795,9.268995e-07,289.0,0.001023,6.137193e-09,113.0
3,0.003999,3.456675e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006810,8.046640e-07,289.0,0.001023,6.137193e-09,113.0
4,0.008151,5.712565e-06,254.0,0.006751,5.638046e-07,79.0,0.008491,4.738705e-06,111.0,0.014415,2.375407e-06,300.0,0.002324,1.664002e-08,113.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1190695,0.000499,3.177044e-07,139.0,0.000066,2.351754e-10,166.0,0.000074,1.607287e-09,79.0,0.000495,3.933408e-08,112.0,0.000029,1.595918e-11,198.0
1190696,0.000507,3.413467e-07,130.0,0.000067,2.065802e-10,48.0,0.000074,1.434071e-09,91.0,0.000498,4.364409e-08,105.0,0.000029,1.609563e-11,196.0
1190697,0.002547,1.786764e-06,216.0,0.000066,3.636120e-10,139.0,0.000078,1.927438e-09,109.0,0.002297,4.228186e-07,155.0,0.000029,5.423783e-11,200.0
1190698,0.001458,1.194931e-06,191.0,0.000351,7.279864e-09,54.0,0.000544,9.299581e-08,101.0,0.002183,1.692077e-07,276.0,0.000101,1.518171e-10,76.0


---

# Cost function definition

Prepare forcing and parameters definition


In [8]:
# Open the forcing store and keep only the simulated period, in a single statement
# so that re-running the cell always starts from the file on disk.
input_forcing = xr.open_dataset(
    "../1_data_processing/1_3_Sensibility/all_stations.zarr", engine="zarr"
).sel(time=slice(time_start, time_end))
input_forcing

In [9]:
# Options shared by both forcing fields.
# NOTE(review): 0.08333 is presumably a 1/12 degree grid and timestep=1 one day — confirm.
_forcing_unit_options = {"resolution": 0.08333, "timestep": 1}
FORCING_PARAMETERS = ForcingParameters(
    temperature=ForcingUnit.from_dataset(forcing=input_forcing, name="T", **_forcing_unit_options),
    primary_production=ForcingUnit.from_dataset(forcing=input_forcing, name="npp", **_forcing_unit_options),
)

|	npp unit is milligram / day / meter ** 2, it will be converted to kilogram / day / meter ** 2.
[0m


In [10]:
def wrapper_model_generator_no_transport(fg_parameters):
    """Build a no-transport model for a single functional group.

    Parameters
    ----------
    fg_parameters
        Sequence of parameter values describing one functional group. It is
        wrapped in a one-row 2-D array before being handed to
        ``wrapper.FunctionalGroupGeneratorNoTransport``.

    Returns
    -------
    The object returned by ``wrapper.model_generator_no_transport`` for these
    functional-group parameters and the module-level ``FORCING_PARAMETERS``.
    """
    # One row per functional group: here there is exactly one.
    parameters_as_matrix = np.array([fg_parameters])
    functional_groups = wrapper.FunctionalGroupGeneratorNoTransport(parameters_as_matrix)
    return wrapper.model_generator_no_transport(
        forcing_parameters=FORCING_PARAMETERS,
        fg_parameters=functional_groups,
    )

Official scoring function


In [11]:
def compute_quantity_of_interest(biomass_forcing_station, station):
    """Summarise one station's biomass series into the three quantities of interest.

    Parameters
    ----------
    biomass_forcing_station
        Object exposing ``mean()``, ``var()`` and ``argmax("time")``, each
        returning something with a scalar ``.data`` attribute.
    station
        Station name. Currently unused; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(mean, variance, argmax)`` as ``(float, float, int)``. The argmax is
        the value returned by ``argmax("time")``, not a calendar date.
    """
    mean_biomass = float(biomass_forcing_station.mean().data)
    biomass_variance = float(biomass_forcing_station.var().data)
    # TODO: Compute the DayOfYear of the argmax
    peak_position = int(biomass_forcing_station.argmax("time").data)
    return (mean_biomass, biomass_variance, peak_position)


@dask.delayed
def cost_function(x: np.ndarray):
    """Run the model for one parameter sample and collect the quantities of interest.

    Because of ``dask.delayed``, calling this function only builds a lazy task;
    the model is actually run when the task is computed.

    Parameters
    ----------
    x
        Array unpacked into five values, in this order: energy transfer
        coefficient, ``tr_0``, ``gamma_tr``, ``inv_lambda_0``, ``gamma_inv_lambda``.

    Returns
    -------
    list
        Flat list holding, for each station of ``stations_locations`` in index
        order, the three values returned by ``compute_quantity_of_interest``.
    """
    energy_transfert, tr_0, gamma_tr, inv_lambda_0, gamma_inv_lambda = x.T

    # The two leading zeros are fixed parameters; their meaning is defined by
    # wrapper.FunctionalGroupGeneratorNoTransport (not visible in this notebook).
    model = wrapper_model_generator_no_transport(
        [0, 0, energy_transfert, tr_0, gamma_tr, inv_lambda_0, gamma_inv_lambda]
    )
    model.run()

    # Only the analysis window contributes to the quantities of interest.
    analysis_window = slice(time_start_analysis, time_end)
    biomass_forcing = model.export_biomass().sel(time=analysis_window)

    results = []
    for station in stations_locations.index:
        station_biomass = biomass_forcing.sel(
            functional_group=0,
            latitude=stations_locations.loc[station, "latitude"],
            longitude=stations_locations.loc[station, "longitude"],
        )
        results.extend(compute_quantity_of_interest(station_biomass, station))

    return results

In [12]:
def batch_cost_function_execution(input_parameters: pd.DataFrame) -> np.ndarray:
    """Evaluate ``cost_function`` on every row of ``input_parameters`` through Dask.

    Parameters
    ----------
    input_parameters
        One parameter sample per row.

    Returns
    -------
    np.ndarray
        Array built from the computed results, one entry per input row in the
        same order as the rows.
    """
    # Build one lazy task per sample, then compute them all in a single call.
    delayed_tasks = []
    for sample in input_parameters.to_numpy():
        delayed_tasks.append(cost_function(sample))
    computed = dask.compute(*delayed_tasks)
    return np.array(computed)

Test function


In [13]:
# TEST FUNCTION
# def batch_cost_function_execution(input_parameters: pd.DataFrame) -> np.ndarray:
#     return np.full((input_parameters.shape[0], len(multi_index_columns)), 1)

---


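Run as many batches as you can. Batches whose results are already in the output file are skipped, so this cell can be interrupted and restarted.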


In [None]:
# Walk through the samples in consecutive chunks of `nb_samples_by_batch` rows.
# Stepping the range by the batch size yields no trailing empty batch when the
# number of samples is an exact multiple of the batch size, and no iteration at
# all when there are no samples.
total_samples = len(input_parameters)
for batch_number, min_batch in enumerate(range(0, total_samples, nb_samples_by_batch)):
    # Inclusive position of the last sample of this batch.
    max_batch = min(min_batch + nb_samples_by_batch, total_samples) - 1
    print(f"Batch {batch_number} = {min_batch} : {max_batch}")

    batch_samples = input_parameters.iloc[min_batch : max_batch + 1]

    # A batch is appended to the output all at once, so the presence of its last
    # row label means the whole batch is done. The check uses the row *label*
    # (not the position) because the results are stored under
    # `batch_samples.index`.
    if batch_samples.index[-1] not in output_sobol_index.index:
        results = batch_cost_function_execution(batch_samples)
        results = pd.DataFrame(data=results, columns=multi_index_columns, index=batch_samples.index)

        # Checkpoint after every batch so an interrupted run can be resumed.
        output_sobol_index = pd.concat([output_sobol_index, results])
        output_sobol_index.to_parquet(output_sobol_index_filepath)

Batch 0 = 0 : 1000
Batch 1 = 1000 : 2000
Batch 2 = 2000 : 3000
Batch 3 = 3000 : 4000
Batch 4 = 4000 : 5000
Batch 5 = 5000 : 6000
Batch 6 = 6000 : 7000
Batch 7 = 7000 : 8000
Batch 8 = 8000 : 9000
Batch 9 = 9000 : 10000
Batch 10 = 10000 : 11000
Batch 11 = 11000 : 12000
Batch 12 = 12000 : 13000
Batch 13 = 13000 : 14000
Batch 14 = 14000 : 15000
Batch 15 = 15000 : 16000
Batch 16 = 16000 : 17000
Batch 17 = 17000 : 18000
Batch 18 = 18000 : 19000
Batch 19 = 19000 : 20000
Batch 20 = 20000 : 21000
Batch 21 = 21000 : 22000
Batch 22 = 22000 : 23000
Batch 23 = 23000 : 24000
Batch 24 = 24000 : 25000
Batch 25 = 25000 : 26000
Batch 26 = 26000 : 27000
Batch 27 = 27000 : 28000
Batch 28 = 28000 : 29000
Batch 29 = 29000 : 30000
Batch 30 = 30000 : 31000
Batch 31 = 31000 : 32000
Batch 32 = 32000 : 33000
Batch 33 = 33000 : 34000
Batch 34 = 34000 : 35000
Batch 35 = 35000 : 36000
Batch 36 = 36000 : 37000
Batch 37 = 37000 : 38000
Batch 38 = 38000 : 39000
Batch 39 = 39000 : 40000
Batch 40 = 40000 : 41000
Batch 4

Show output


In [16]:
# Display the accumulated results: one row per sample, one column per (station, quantity of interest).
output_sobol_index

station,BARENTS,BARENTS,BARENTS,HOT,HOT,HOT,BATS,BATS,BATS,PAPA,PAPA,PAPA,GUAM,GUAM,GUAM
quantity_of_interest,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax
0,0.004000,3.479870e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006814,7.913743e-07,287.0,0.001023,6.137193e-09,113.0
1,0.000044,4.303894e-10,205.0,0.000033,2.816304e-11,50.0,0.000043,2.916630e-10,97.0,0.000076,9.787696e-11,287.0,0.000011,7.590464e-13,113.0
2,0.003988,3.623789e-06,208.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006795,9.268995e-07,289.0,0.001023,6.137193e-09,113.0
3,0.003999,3.456675e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006810,8.046640e-07,289.0,0.001023,6.137193e-09,113.0
4,0.008151,5.712565e-06,254.0,0.006751,5.638046e-07,79.0,0.008491,4.738705e-06,111.0,0.014415,2.375407e-06,300.0,0.002324,1.664002e-08,113.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1190695,0.000499,3.177044e-07,139.0,0.000066,2.351754e-10,166.0,0.000074,1.607287e-09,79.0,0.000495,3.933408e-08,112.0,0.000029,1.595918e-11,198.0
1190696,0.000507,3.413467e-07,130.0,0.000067,2.065802e-10,48.0,0.000074,1.434071e-09,91.0,0.000498,4.364409e-08,105.0,0.000029,1.609563e-11,196.0
1190697,0.002547,1.786764e-06,216.0,0.000066,3.636120e-10,139.0,0.000078,1.927438e-09,109.0,0.002297,4.228186e-07,155.0,0.000029,5.423783e-11,200.0
1190698,0.001458,1.194931e-06,191.0,0.000351,7.279864e-09,54.0,0.000544,9.299581e-08,101.0,0.002183,1.692077e-07,276.0,0.000101,1.518171e-10,76.0
