![Header image](../assets/header.jpg)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2.5vw; color:#5A7D9F; font-weight:bold;">
    Ocean parameterizations in an idealized model using machine learning
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    Initialization
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one can initialize the notebook by loading all the libraries and helper functions needed to run it.
</p>

In [None]:
# -- Libraries --

import os
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import json
import fsspec
import xarray as xr
import pyqg
import pyqg.diagnostic_tools
import functions.coarsening_ops as coarsening

from tqdm.notebook import tqdm, trange
from pyqg.diagnostic_tools import calc_ispec as _calc_ispec
from pyqg_parameterization_benchmarks.neural_networks import FullyCNN
from pyqg_parameterization_benchmarks.neural_networks import FCNNParameterization

# The magic command %matplotlib inline enables inline plotting: figures are displayed
# directly below the cell that creates them.
%matplotlib inline
plt.rcParams.update({'font.size': 13})

# pyqg isotropic spectrum, with averaging and truncation disabled
calc_ispec = lambda *args, **kwargs: _calc_ispec(*args, averaging = False, truncate = False, **kwargs)

In [None]:
# -- Functions --

# The datasets used in L. Zanna et al.'s paper are hosted on Globus as Zarr files
# index = 0 -> eddy configuration, index = 1 -> jet configuration (high resolution)
def get_dataset(index, base_url = "https://g-402b74.00888.8540.data.globus.org"):

    paths = ['eddy/high_res', 'jet/high_res']
    mapper = fsspec.get_mapper(f"{base_url}/{paths[index]}.zarr")
    return xr.open_zarr(mapper, consolidated=True)

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    PYQG - Generating / Loading datasets
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one will be able to:
</p>

- Generate a new high-resolution (HR) dataset, or load a pre-existing one from L. Zanna et al.'s paper;

- Generate the low-resolution (LR) dataset by filtering and coarsening the HR one, using either the operators from L. Zanna et al.'s paper or one's own;

- Documentation: https://pyqg.readthedocs.io/en/latest/examples/two-layer.html#Initialize-and-Run-the-Model

<hr style="color:#5A7D9F; width: 100px;" align="left">
<p style="color:#5A7D9F;">High resolution</p>
<hr style="color:#5A7D9F; width: 100px;" align="left">
<table style="width: 100%;" border="1">
	<tbody>
		<tr>
			<td style="width: 15%;" align="center">INDEX</td>
			<td style="width: 15%;" align="center">1</td>
			<td style="width: 15%;" align="center">2</td>
			<td style="width: 15%;" align="center">3</td>
		</tr>
		<tr>
			<td style="width: 15%;" align="center">&nbsp;GENERATION</td>
			<td style="width: 15%;" align="center">Eddies - Type 1</td>
			<td style="width: 15%;" align="center">Jets - Type 1</td>
			<td style="width: 15%;" align="center">Eddies - Test</td>
		</tr>
		<tr>
			<td style="width: 15%;" align="center">LOADING</td>
			<td style="width: 15%;" align="center">Eddies</td>
			<td style="width: 15%;" align="center">Jets</td>
			<td style="width: 15%;" align="center">&nbsp;TBD.</td>
		</tr>
	</tbody>
</table>

where,

<table style="width: 100%;" border="1">
	<tbody>
		<tr style="height: 21px;">
			<td style="width: 16%;" align="center">PARAMETERS</td>
			<td style="width: 12%;" align="center">nx</td>
			<td style="width: 12%;" align="center">dt</td>
			<td style="width: 12%;" align="center">tmax</td>
			<td style="width: 12%;" align="center">tavestart</td>
			<td style="width: 12%;" align="center">rek</td>
			<td style="width: 12%;" align="center">Δ</td>
			<td style="width: 12%;" align="center">β</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center">DESCRIPTION</td>
			<td style="width: 12%;" align="center">Number of real-space grid points in the x direction</td>
			<td style="width: 12%;" align="center">Numerical timestep (in hours)</td>
			<td style="width: 12%;" align="center">Total time of integration (in years)</td>
			<td style="width: 12%;" align="center">Start time for averaging (in years)</td>
			<td style="width: 12%;" align="center">Linear drag in the lower layer (s⁻¹)</td>
			<td style="width: 12%;" align="center">Layer thickness ratio (H1/H2)</td>
			<td style="width: 12%;" align="center">Gradient of the Coriolis parameter (m⁻¹ s⁻¹)</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center">EDDIES - TYPE 1</td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 12%;" align="center">1</td>
			<td style="width: 12%;" align="center">10</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">5.789e-7</td>
			<td style="width: 12%;" align="center">0.25</td>
			<td style="width: 12%;" align="center">1.5 * 1e-11</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center">EDDIES - TEST</td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 12%;" align="center">1</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">3</td>
			<td style="width: 12%;" align="center">5.789e-7</td>
			<td style="width: 12%;" align="center">0.25</td>
			<td style="width: 12%;" align="center">1.5 * 1e-11</td>
		</tr>
		<tr style="height: 21px;">
			<td style="width: 16%;" align="center">JETS - TYPE 1</td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 12%;" align="center">1</td>
			<td style="width: 12%;" align="center">10</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">7e-08</td>
			<td style="width: 12%;" align="center">0.1</td>
			<td style="width: 12%;" align="center">1e-11</td>
		</tr>
	</tbody>
</table>
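The table gives `dt` in hours and `tmax`/`tavestart` in years, while `pyqg.QGModel` works in seconds. The sketch below shows the conversion, assuming the 360-day year used in the generation code of this notebook; `to_pyqg_units` is a hypothetical helper that only builds the keyword dictionary (it does not call pyqg). It also shows that the number of time steps is `tmax / dt` with both values expressed in seconds.

```python
HOUR = 60 * 60                  # seconds in one hour
YEAR = 360 * 24 * HOUR          # seconds in one 360-day model year

def to_pyqg_units(nx, dt_hours, tmax_years, tavestart_years, rek, delta, beta):
    """Hypothetical helper: keyword arguments for pyqg.QGModel, with times in seconds."""
    return dict(nx=nx, dt=dt_hours * HOUR, tmax=tmax_years * YEAR,
                tavestart=tavestart_years * YEAR, rek=rek, delta=delta, beta=beta)

# "Eddies - Type 1" row of the table
params = to_pyqg_units(256, 1, 10, 5, 5.789e-7, 0.25, 1.5e-11)
n_steps = params["tmax"] / params["dt"]
print(params["dt"], params["tmax"], n_steps)  # 3600 311040000 86400.0
```

Dividing a duration in seconds by a time step in hours would overestimate the step count by a factor of 3600, so both quantities should be converted before taking the ratio.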

In [None]:
# Define whether the high-resolution dataset is generated (True) or loaded (False)
generating_HR = True

# Index of the high-resolution configuration to generate (GENERATION row of the table above: 1, 2 or 3)
generating_HR_type = 3

# Index of the high-resolution dataset to load (LOADING row of the table above: 1 or 2)
loading_HR_type = 1

In [None]:
# -- High resolution simulation --
#
#---------
# Loading
#---------
if not generating_HR:

    # Displaying information over terminal (1)
    print("-- Simulation --")
    print("Loading")

    # Loading the corresponding high resolution dataset (get_dataset uses 0-based indices)
    high_res = get_dataset(loading_HR_type - 1).isel(run = 0)

    # Contains possible type of simulations
    type_sim = ["Eddies", "Jets"]

    # Displaying information over terminal (2)
    print("Done")
    print(f"\n-- Type --\n {type_sim[loading_HR_type - 1]}")
    print(f"\n-- Parameters --\n {json.loads(high_res.attrs['pyqg_params'])}")

#------------
# Generating
#------------
else:
    # Definition of the parameters for each simulation type (dt in hours, tmax and tavestart in years)
    nx        = [256        , 256   , 256]
    dt        = [1          , 1     , 1]
    tmax      = [10         , 10    , 5]
    tavestart = [5          , 5     , 3]
    rek       = [5.789e-7   , 7e-08 , 5.789e-7]
    delta     = [0.25       , 0.1   , 0.25]
    beta      = [1.5 * 1e-11, 1e-11 , 1.5 * 1e-11]

    # Conversion factors to seconds (pyqg works in SI units; a 360-day year is used)
    hour = 60 * 60
    year = 360 * 24 * hour
    idx  = generating_HR_type - 1

    # Creation of the dictionary of parameters passed to pyqg
    simulation_parameters              = {}
    simulation_parameters['nx']        = nx[idx]
    simulation_parameters['dt']        = dt[idx]        * hour
    simulation_parameters['tmax']      = tmax[idx]      * year
    simulation_parameters['tavestart'] = tavestart[idx] * year
    simulation_parameters['rek']       = rek[idx]
    simulation_parameters['delta']     = delta[idx]
    simulation_parameters['beta']      = beta[idx]

    # -- Running the simulation --

    # Displaying information over terminal (1); tmax and dt are both in seconds here
    print("-- Simulation --")
    print(f"INFO: Total steps: {simulation_parameters['tmax'] / simulation_parameters['dt']:.0f}")

    # Creation of the model
    model_hr = pyqg.QGModel(**simulation_parameters)

    # Running the simulation
    model_hr.run()

    # Conversion to xarray dataset
    high_res = model_hr.to_dataset()

    # Contains possible type of simulations
    type_sim = ["Eddies - Type 1", "Jets - Type 1", "Eddies - Test"]

    # Displaying information over terminal (2)
    print(f"\n-- Type --\n {type_sim[idx]}")
    print(f"\n-- Parameters --\n {simulation_parameters}")

<hr style="color:#5A7D9F; width: 96px;" align="left">
<p style="color:#5A7D9F;">Low resolution</p>
<hr style="color:#5A7D9F; width: 96PX;" align="left">
<table style="width: 100%;" border="1">
	<tbody>
		<tr>
			<td style="width: 10%;" align="center">&nbsp;OPERATOR</td>
			<td style="width: 15%;" align="center">&nbsp;1</td>
			<td style="width: 15%;" align="center">&nbsp;2</td>
			<td style="width: 17%;" align="center">3</td>
			<td style="width: 10%;" align="center">&nbsp;4</td>
			<td style="width: 10%;" align="center">&nbsp;5</td>
		</tr>
		<tr>
			<td style="width: 10%;" align="center">&nbsp;DESCRIPTION</td>
			<td style="width: 15%;">Spectral Truncation, Sharp Filter</td>
			<td style="width: 15%;">Spectral Truncation, Gaussian Filter</td>
			<td style="width: 17%;">GCM Filter, Averaging and Coarsening</td>
			<td style="width: 10%;">&nbsp;TBD.</td>
			<td style="width: 10%;">&nbsp;TBD.</td>
		</tr>
	</tbody>
</table>
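As a rough illustration of the idea behind spectral truncation (operators 1 and 2), the sketch below coarsens a doubly periodic field by keeping only the Fourier modes resolvable on the low-resolution grid and dropping the rest, including the low-resolution Nyquist modes. It is a simplified NumPy stand-in, not the implementation in `functions/coarsening_ops.py`, and `spectral_truncate` is a hypothetical name; it applies no additional sharp or Gaussian filter.

```python
import numpy as np

def spectral_truncate(field, low_res_nx):
    """Coarsen a doubly periodic field of shape (..., nx, nx) to (..., low_res_nx, low_res_nx)
    by keeping only the Fourier modes with |kx|, |ky| below the low-resolution Nyquist index."""
    nx = field.shape[-1]
    k = low_res_nx // 2                                # low-resolution Nyquist index
    fh = np.fft.rfft2(field)                           # (..., nx, nx // 2 + 1)
    fl = np.zeros(field.shape[:-2] + (low_res_nx, k + 1), dtype=complex)
    fl[..., :k, :k] = fh[..., :k, :k]                  # ky = 0 .. k-1,     kx = 0 .. k-1
    fl[..., -(k - 1):, :k] = fh[..., -(k - 1):, :k]    # ky = -(k-1) .. -1, kx = 0 .. k-1
    # Nyquist row and column stay zero; the factor compensates for numpy's
    # unnormalized forward transform on grids of different sizes
    return np.fft.irfft2(fl, s=(low_res_nx, low_res_nx)) * (low_res_nx / nx) ** 2

# 256 x 256 grid coarsened to 64 x 64 (the sizes used in this notebook)
n_hr, n_lr = 256, 64
x = 2 * np.pi * np.arange(n_hr) / n_hr
X, Y = np.meshgrid(x, x)
smooth = np.cos(3 * X) + np.sin(5 * Y)     # resolved on both grids
fine = np.cos(40 * X)                      # not resolvable on the 64 x 64 grid

print(np.allclose(spectral_truncate(smooth, n_lr), smooth[::4, ::4]))  # True
print(np.allclose(spectral_truncate(fine, n_lr), 0.0))                 # True
```

With `low_res_nx = 64` on a 256-point grid, low-resolution point `i` coincides with high-resolution point `4 i`, which is why a well-resolved field is reproduced by `smooth[::4, ::4]`, while the unresolved mode is removed instead of being aliased.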

In [None]:
# Define which type of filtering and coarsening operators should be applied
operator_index = 1

In [None]:
# -- Coarsening and filtering --
#
# NOTE: the coarsening operators take a pyqg model as input, so this cell needs
# `model_hr`, which only exists when generating_HR is True.

# Displaying information over terminal (1)
print("-- Simulation --")
print("Coarsening and filtering")

if operator_index == 1:
    model_lr = coarsening.Operator1(model_hr, low_res_nx = 64) # Spectral truncation + sharp filter

elif operator_index == 2:
    model_lr = coarsening.Operator2(model_hr, low_res_nx = 64) # Spectral truncation + Gaussian filter

elif operator_index == 3:
    model_lr = coarsening.Operator3(model_hr, low_res_nx = 64) # GCM-Filters + real-space coarsening

elif operator_index in (4, 5):
    raise NotImplementedError(f"Operator {operator_index} is not implemented yet.")

else:
    raise ValueError(f"Invalid operator index: {operator_index} (expected 1, 2 or 3).")

# Displaying information over terminal (2)
print("Done")

# Reminder of the simulation type (type_sim is defined in the high resolution cell)
if not generating_HR:
    print(f"\n-- Type --\n {type_sim[loading_HR_type - 1]}")
else:
    print(f"\n-- Type --\n {type_sim[generating_HR_type - 1]}")

In [None]:
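# -- Computing subgrid forcing / subgrid fluxes terms and creation of dataset --
#
# Creation of the low resolution dataset
low_res = model_lr.m2.to_dataset()

# Subgrid potential vorticity forcing S_q (array of shape (lev, y, x))
sq = model_lr.subgrid_forcing("q")

print(sq.shape)

# The dataset has a time dimension of length 1, so one is prepended to S_q before storing it
low_res["q_subgrid_forcing"] = (("time", "lev", "y", "x"), sq[np.newaxis])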


[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    PYQG - High and Low resolution dataset comparison
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one will be able to compare:
</p>

- **State variables**

<table style="width: 100%;">
	<tbody>
		<tr>
			<td width="12%;" align="center">q</td>
			<td width="15%;" align="center">u</td>
			<td width="15%;" align="center">v</td>
			<td width="15%;" align="center">ufull</td>
			<td width="15%;" align="center">vfull</td>
			<td width="23%;" align="center">streamfunction - Ψ (NW)</td>
		</tr>
		<tr>
			<td align="center">Potential vorticity</td>
			<td align="center">x-velocity relative to the background flow</td>
			<td align="center">y-velocity relative to the background flow</td>
			<td align="center">x-velocity including the background flow</td>
			<td align="center">y-velocity including the background flow</td>
			<td align="center">A particular case of a vector potential of the velocity, related to it by u = -∂Ψ/∂y and v = ∂Ψ/∂x</td>
		</tr>
	</tbody>
</table>
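The link between the streamfunction and the velocities can be made concrete with a short NumPy sketch. It assumes the standard convention u = -∂Ψ/∂y, v = ∂Ψ/∂x and a square, doubly periodic domain of side `L`, with derivatives computed spectrally; `velocities_from_streamfunction` is a hypothetical helper, not a pyqg function.

```python
import numpy as np

def velocities_from_streamfunction(psi, L=2 * np.pi):
    """Return (u, v) = (-dpsi/dy, dpsi/dx) for a doubly periodic field psi of shape (ny, nx)."""
    ny, nx = psi.shape
    kx = 2 * np.pi * np.fft.fftfreq(nx, d=L / nx)   # angular wavenumbers along x (axis 1)
    ky = 2 * np.pi * np.fft.fftfreq(ny, d=L / ny)   # angular wavenumbers along y (axis 0)
    KX, KY = np.meshgrid(kx, ky)                    # both of shape (ny, nx)
    psi_hat = np.fft.fft2(psi)
    u = np.real(np.fft.ifft2(-1j * KY * psi_hat))   # u = -dpsi/dy
    v = np.real(np.fft.ifft2(1j * KX * psi_hat))    # v =  dpsi/dx
    return u, v

# psi = sin(x) cos(y)  =>  u = sin(x) sin(y), v = cos(x) cos(y)
n = 32
x = 2 * np.pi * np.arange(n) / n
X, Y = np.meshgrid(x, x)
u, v = velocities_from_streamfunction(np.sin(X) * np.cos(Y))
print(np.allclose(u, np.sin(X) * np.sin(Y)), np.allclose(v, np.cos(X) * np.cos(Y)))  # True True
```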

- **Forcing variables**

- `q_subgrid_forcing`: $S_q \equiv \overline{(\mathbf{u} \cdot \nabla)q} - (\overline{\mathbf{u}} \cdot \overline{\nabla})\overline{q}$

- `u_subgrid_forcing`: $S_u \equiv \overline{(\mathbf{u} \cdot \nabla)u} - (\overline{\mathbf{u}} \cdot \overline{\nabla})\overline{u}$

- `v_subgrid_forcing`: $S_v \equiv \overline{(\mathbf{u} \cdot \nabla)v} - (\overline{\mathbf{u}} \cdot \overline{\nabla})\overline{v}$

- `uq_subgrid_flux`: $\phi_{uq} \equiv \overline{uq} - \bar{u}\bar{q}$

- `vq_subgrid_flux`: $\phi_{vq} \equiv \overline{vq} - \bar{v}\bar{q}$

- `uu_subgrid_flux`: $\phi_{uu} \equiv \overline{u^2} - \bar{u}^2$

- `vv_subgrid_flux`: $\phi_{vv} \equiv \overline{v^2} - \bar{v}^2$

- `uv_subgrid_flux`: $\phi_{uv} \equiv \overline{uv} - \bar{u}\bar{v}$

- `dqdt_bar`: PV tendency from the high-resolution model, filtered and coarsened to low resolution

- `dqbar_dt`: PV tendency from the low-resolution model, initialized at $\overline{q}$.

 *NOTE*

- $\nabla \cdot \langle\phi_{uq}, \phi_{vq}\rangle = S_q$.

- $\nabla \cdot \langle\phi_{uu}, \phi_{uv}\rangle = S_u$ and $\nabla \cdot \langle\phi_{uv}, \phi_{vv}\rangle = S_v$.

- `dqdt_bar - dqbar_dt` can be used as an alternative to $S_q$: it is very similar, but also accounts for numerical dissipation.
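The flux definitions above only require a filter-and-coarsen operator $\overline{(\cdot)}$. The sketch below uses plain block averaging as that operator, which is an assumption made for illustration only (the notebook's operators 1 to 3 are spectral or GCM-Filters based); `coarsen` and `subgrid_flux` are hypothetical helpers and the fields are synthetic.

```python
import numpy as np

def coarsen(field, factor):
    """Block-average a 2D field by an integer factor (stand-in for filtering + coarsening)."""
    ny, nx = field.shape
    return field.reshape(ny // factor, factor, nx // factor, factor).mean(axis=(1, 3))

def subgrid_flux(a, b, factor):
    """phi_ab = bar(a * b) - bar(a) * bar(b) under block-average coarsening."""
    return coarsen(a * b, factor) - coarsen(a, factor) * coarsen(b, factor)

# Synthetic 64 x 64 fields on a periodic domain (hypothetical data)
n = 64
x = np.linspace(0, 2 * np.pi, n, endpoint=False)
X, Y = np.meshgrid(x, x)
u = np.sin(3 * X) * np.cos(5 * Y)
q = np.cos(7 * X) + 0.5 * np.sin(11 * Y)

phi_uq = subgrid_flux(u, q, factor=4)
print(phi_uq.shape)  # (16, 16)
```

With this operator, $\phi_{uu}$ is the variance of $u$ inside each coarse cell, so it is never negative, and any flux involving a constant field vanishes because $\overline{c\,q} = c\,\bar{q}$.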

In [None]:
# -- Functions (State variables) --
#
# Displays a 2D array as an image with a symmetric color scale (from the coarsening notebook)
def imshow(arr, vlim = 3e-5):
    plt.xticks([]); plt.yticks([])
    return plt.imshow(arr, vmin = -vlim, vmax = vlim, cmap = 'bwr', interpolation = 'none')

# Plots a state variable for both layers at the last time step: high resolution (left) next to low resolution (right)
def plotStateVariable(high_res, low_res, state_variable = "q"):

    # Description of each supported state variable, used in the caption
    descriptions = {
        "q"              : "the potential vorticity q",
        "u"              : "the x-velocity u (relative to the background flow)",
        "v"              : "the y-velocity v (relative to the background flow)",
        "ufull"          : "the x-velocity ufull (including the background flow)",
        "vfull"          : "the y-velocity vfull (including the background flow)",
        "streamfunction" : "the streamfunction",
    }
    if state_variable not in descriptions:
        raise ValueError(f"Unknown state variable: {state_variable}")

    # Text for the caption
    caption = ["Upper layer: z = 1", "Lower layer: z = 2"]

    # Looping over the layers
    for l in range(2):

        # Initialization of the plot
        fig = plt.figure(figsize=(21, 6))

        # High resolution
        plt.subplot(1, 2, 1)
        high_res[state_variable].isel(lev = l, time = -1).plot()

        # Low resolution
        plt.subplot(1, 2, 2)
        low_res[state_variable].isel(lev = l, time = -1).plot()

        # Adding a caption to the plot
        fig.text(0.45, -0.1, f"$Figure$: Representation of {descriptions[state_variable]} for the high resolution (left) and low resolution (right) simulations - {caption[l]}", ha = 'center')

In [None]:
# -- Functions (Subgrid variables) --
#
# Displays a 2D array as an image with a symmetric color scale (from the coarsening notebook, version 2 with an axis label)
def imshow_2(arr, vlim = 3e-5):
    plt.xticks([]); plt.yticks([])
    plt.xlabel(r"Grid coordinates ($\mathbb{R}$) - $x$ direction")
    return plt.imshow(arr, vmin = -vlim, vmax = vlim, cmap = 'bwr', interpolation = 'none')

# Plots a subgrid term of the low resolution model: "sq" (PV forcing), "su" (x-velocity forcing) or "flux" (PV fluxes)
def plotForcingVariable(model_lr, state_variable = "sq"):

    # Subgrid - Potential vorticity
    if state_variable == "sq":

        # Initialization of the figure
        fig = plt.figure(figsize=(22, 6))
        plt.subplot(1, 2, 1, title = r'$S_{q_{total}}$')
        plt.ylabel(r"Grid coordinates ($\mathbb{R}$) - $y$ direction")

        # Total potential vorticity forcing term (Sq_tot), upper layer
        imshow_2(model_lr.q_forcing_total[0], 3e-11)

        # Subgrid potential vorticity forcing term (Sq), upper layer
        plt.subplot(1, 2, 2, title = r'$S_{q}$')
        im = imshow_2(model_lr.subgrid_forcing('q')[0], 3e-11)
        fig.colorbar(im, ax = fig.axes, pad = 0.15).set_label(r'$S_{q}$ [$s^{-2}$]')

        # Power spectrum of the subgrid potential vorticity forcing
        plt.figure(figsize=(15, 4))
        plt.title(r"Power spectrum of $S_{q}$")

        # Retrieving results
        Sq = model_lr.subgrid_forcing('q')

        # Isotropic power spectrum of each layer (index 0: upper layer, index 1: lower layer)
        power = np.abs(model_lr.m2.fft(Sq))**2
        line = plt.loglog(*calc_ispec(model_lr.m2, power[0]), label = "Low resolution - upper layer")
        plt.loglog(*calc_ispec(model_lr.m2, power[1]), color = line[0].get_color(), ls = '--', label = "Low resolution - lower layer")
        plt.legend(ncol = 2)
        plt.grid()
        plt.ylabel("Power spectrum of PV forcing")
        plt.xlabel("Isotropic wavenumber")

    # Subgrid - x-velocity
    elif state_variable == "su":

        # Initialization of the figure
        fig = plt.figure(figsize=(22, 6))
        plt.title(r'$S_{u}$')
        plt.ylabel(r"Grid coordinates ($\mathbb{R}$) - $y$ direction")

        # Subgrid x-velocity forcing term (Su), upper layer
        im = imshow_2(model_lr.subgrid_forcing('u')[0], 1.5e-7)
        fig.colorbar(im, ax = fig.axes, pad = 0.15).set_label(r'$S_{u}$ [$m\,s^{-2}$]')

        # Power spectrum of the subgrid x-velocity forcing
        plt.figure(figsize=(15, 4))
        plt.title(r"Power spectrum of $S_{u}$")

        # Retrieving results
        Su = model_lr.subgrid_forcing('u')

        # Isotropic power spectrum of each layer (index 0: upper layer, index 1: lower layer)
        power = np.abs(model_lr.m2.fft(Su))**2
        line = plt.loglog(*calc_ispec(model_lr.m2, power[0]), label = "Low resolution - upper layer")
        plt.loglog(*calc_ispec(model_lr.m2, power[1]), color = line[0].get_color(), ls = '--', label = "Low resolution - lower layer")
        plt.legend(ncol = 2)
        plt.grid()
        plt.ylabel("Power spectrum of velocity forcing")
        plt.xlabel("Isotropic wavenumber")

    # Subgrid - Potential vorticity fluxes
    elif state_variable == "flux":

        # Initialization of the figure
        fig = plt.figure(figsize=(22, 6))
        plt.subplot(1, 2, 1, title = r'$\phi_{uq}$')
        plt.ylabel(r"Grid coordinates ($\mathbb{R}$) - $y$ direction")

        # Retrieving fluxes
        uq, vq = model_lr.subgrid_fluxes('q')

        # Subgrid PV flux in the x direction, lower layer
        imshow_2(uq[1], 1.5e-8)

        # Subgrid PV flux in the y direction, lower layer
        plt.subplot(1, 2, 2, title = r'$\phi_{vq}$')
        im = imshow_2(vq[1], 1.5e-7)
        fig.colorbar(im, ax = fig.axes, pad = 0.15).set_label(r'$\phi_{q}$ [$m\,s^{-2}$]')

    else:
        raise ValueError(f"Unknown subgrid variable: {state_variable} (expected 'sq', 'su' or 'flux')")
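The spectra above rely on pyqg's `calc_ispec`. To make the notion of an isotropic power spectrum concrete, here is a simplified NumPy sketch, not pyqg's implementation (which uses its own wavenumber bins and normalization): the squared magnitudes of the 2D Fourier coefficients are summed over rings of constant |k|. `isotropic_spectrum` is a hypothetical helper and the test field is synthetic.

```python
import numpy as np

def isotropic_spectrum(field):
    """Return (k, E) where E[k] sums |FFT|^2 over all modes of a square, doubly
    periodic field whose wavenumber magnitude rounds to the integer k."""
    n = field.shape[-1]
    power = np.abs(np.fft.fft2(field)) ** 2
    k1d = np.fft.fftfreq(n, d=1.0 / n)               # integer wavenumbers 0, 1, ..., -1
    KX, KY = np.meshgrid(k1d, k1d)
    shell = np.rint(np.hypot(KX, KY)).astype(int)    # ring index of each mode
    E = np.bincount(shell.ravel(), weights=power.ravel())
    return np.arange(E.size), E

n = 64
x = 2 * np.pi * np.arange(n) / n
X, Y = np.meshgrid(x, x)
f = np.cos(3 * X + 4 * Y)                            # |k| = sqrt(3^2 + 4^2) = 5
k, E = isotropic_spectrum(f)
print(k[np.argmax(E)])  # 5
```

A diagonal mode such as (3, 4) lands in the same ring as (5, 0), which is what makes the spectrum isotropic; summing over all rings recovers the total power (Parseval's theorem).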

<hr style="color:#5A7D9F; width: 97px;" align="left">
<p style="color:#5A7D9F;">State variables</p>
<hr style="color:#5A7D9F; width: 97PX;" align="left">

In [None]:
# -- Potential vorticity q -- 
plotStateVariable(high_res, low_res, state_variable = "q")

In [None]:
# -- x-velocity u -- 
plotStateVariable(high_res, low_res, state_variable = "u")

In [None]:
# -- y-velocity v -- 
plotStateVariable(high_res, low_res, state_variable = "v")

In [None]:
# -- x-velocity including the background flow ufull -- 
plotStateVariable(high_res, low_res, state_variable = "ufull")

In [None]:
# -- y-velocity including the background flow vfull -- 
plotStateVariable(high_res, low_res, state_variable = "vfull")

<hr style="color:#5A7D9F; width: 113px;" align="left">
<p style="color:#5A7D9F;">Subgrid variables</p>
<hr style="color:#5A7D9F; width: 113PX;" align="left">

In [None]:
# -- Subgrid potential vorticity --
plotForcingVariable(model_lr, state_variable = "sq")

In [None]:
# -- Subgrid x-velocity forcing --
plotForcingVariable(model_lr, state_variable = "su")

In [None]:
# -- Subgrid potential vorticity flux --
plotForcingVariable(model_lr, state_variable = "flux")

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    Fully Convolutional Neural Network - Architecture definition, training and saving
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one will be able to:
</p>
 
- Create a new fully convolutional neural network based on the `FullyCNN` class from L. Zanna et al.'s code;

- Train the network on the dataset created in the former section;

- Save it for later use.

*NOTE*

- If you want to experiment with the **architecture of the FCNN**, you can modify it in the file *neural_networks.py*, starting at line 40.
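To illustrate what "fully convolutional" means, here is a NumPy sketch of a forward pass through a small stack of convolutions with ReLU activations and circular (wrap-around) padding, which suits the doubly periodic pyqg domain and is presumably what the `padding = 'circular'` option passed to `train_on` selects. It is not the architecture in *neural_networks.py*: the channel counts (6 inputs for q, u, v on two layers; 2 outputs for $S_q$ on two layers), the kernel size and the random weights are all hypothetical.

```python
import numpy as np

def conv2d_circular(x, w, b):
    """Multi-channel 2D cross-correlation with circular padding.
    x: (c_in, ny, nx), w: (c_out, c_in, kh, kw), b: (c_out,) -> (c_out, ny, nx)."""
    c_out, c_in, kh, kw = w.shape
    ny, nx = x.shape[1:]
    xp = np.pad(x, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2)), mode="wrap")
    out = np.zeros((c_out, ny, nx))
    for i in range(kh):
        for j in range(kw):
            # Contract over input channels for this kernel offset
            out += np.einsum("oc,cyx->oyx", w[:, :, i, j], xp[:, i:i + ny, j:j + nx])
    return out + b[:, None, None]

def tiny_fcnn(x, params):
    """Conv -> ReLU -> ... -> Conv, with no dense layers."""
    for w, b in params[:-1]:
        x = np.maximum(conv2d_circular(x, w, b), 0.0)
    w, b = params[-1]
    return conv2d_circular(x, w, b)

rng = np.random.default_rng(0)
channels = [6, 16, 16, 2]      # hypothetical: (q, u, v) x 2 layers in, S_q x 2 layers out
kernel = 5
params = [(0.1 * rng.standard_normal((c_out, c_in, kernel, kernel)), np.zeros(c_out))
          for c_in, c_out in zip(channels[:-1], channels[1:])]

print(tiny_fcnn(rng.standard_normal((6, 64, 64)), params).shape)  # (2, 64, 64)
print(tiny_fcnn(rng.standard_normal((6, 32, 32)), params).shape)  # (2, 32, 32)
```

Because there are no dense layers, the same weights apply to any grid size and the output keeps the input's grid; with circular padding, shifting the input periodically simply shifts the prediction.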

In [None]:
# Define the name of the folder to store the resulting network
result_folder = "test"

In [None]:
# -- Functions --
#
# Creates a folder (including missing parent folders) if it does not exist yet
def checkFolder(folder_name):

    os.makedirs(folder_name, exist_ok = True)

# TEMPORARY: opens any dataset hosted on Globus from its relative path (e.g. 'eddy/forcing1')
def get_dataset_2(path, base_url = "https://g-402b74.00888.8540.data.globus.org"):
    mapper = fsspec.get_mapper(f"{base_url}/{path}.zarr")
    return xr.open_zarr(mapper, consolidated=True)


In [None]:
# -- Output folder --
#
# Define the complete path to result folder
result_folder_path = f"../models/{result_folder}/"

# Checks the availability of the result folder
checkFolder(result_folder_path)

In [None]:
# -- TO BE REMOVED: temporary training dataset ('eddy/forcing1' hosted on Globus), loaded into memory
eddy_forcing1 = get_dataset_2('eddy/forcing1').isel(run=0).load()

In [None]:
%load_ext autoreload
%autoreload 2

from pyqg_parameterization_benchmarks.neural_networks import FCNNParameterization

In [None]:
FCNNParameterization.test()

In [None]:
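# The FCNN is created, trained and saved with the train_on class method.
# NOTE: it is currently trained on eddy_forcing1 (loaded above), not on the dataset generated in the previous section.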
FCNN_trained = FCNNParameterization.train_on(dataset    = eddy_forcing1, 
                                             directory  = result_folder_path,
                                             inputs     = ['q', 'u', 'v'],
                                             targets    = ['q_subgrid_forcing'],
                                             num_epochs = 1, 
                                             zero_mean  = True, 
                                             padding    = 'circular')