<img src="assets/header_notebook.jpg" />
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:2vw; color:#5A7D9F; font-weight:bold;">
    <center>Ocean subgrid parameterizations in an idealized model using machine learning</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    <center>Initialization</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one is able to:
</p>

- Load all the libraries and functions (from the PYQG benchmark or custom ones) used throughout the entire notebook.

In [None]:
# -----------------
#     Libraries
# -----------------
#
# --------- Standard ---------
import os
import sys
import json
import glob
import math
import torch
import random
import fsspec
import matplotlib
import numpy             as np
import pandas            as pd
import xarray            as xr
import seaborn           as sns
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from scipy.stats import gaussian_kde
from torch.utils.tensorboard import SummaryWriter

# --------- PYQG ---------
import pyqg
import pyqg.diagnostic_tools
from   pyqg.diagnostic_tools import calc_ispec         as _calc_ispec
import pyqg_parameterization_benchmarks.coarsening_ops as coarsening

def calc_ispec(*args, **kwargs):
    """
    Wrapper around pyqg.diagnostic_tools.calc_ispec (imported above as `_calc_ispec`).

    Every positional and keyword argument is forwarded untouched, while `averaging` and
    `truncate` are always set to False.
    """
    return _calc_ispec(*args, averaging = False, truncate = False, **kwargs)

# --------- PYQG Benchmark ---------
from pyqg_parameterization_benchmarks.utils           import *
from pyqg_parameterization_benchmarks.utils_TFE       import *
from pyqg_parameterization_benchmarks.plots_TFE       import *
from pyqg_parameterization_benchmarks.online_metrics  import diagnostic_differences
from pyqg_parameterization_benchmarks.neural_networks import FullyCNN, FCNNParameterization


# --------- Jupyter ---------
# The magic function %matplotlib inline enables inline plotting: the figures are
# displayed directly below the cell containing the plotting commands.
%matplotlib inline

# Default font size used by all the figures of the notebook
plt.rcParams.update({'font.size': 13})

# Making sure modules are reloaded when modified (useful while editing the .py files imported above)
%load_ext autoreload
%autoreload 2

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    <center>PYQG - Generating &amp; Saving dataset</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one will be able to generate:
</p>

- A **high resolution** (= HR) simulation from a quasi-geostrophic model (PYQG), 
        
- A **low resolution** (= LR) simulation.
        
- An **augmented low resolution** (= ALR) simulation.
        
Furthermore, one will be able to:
        
- Observe the **state** and **subgrid variables** associated with the dataset in the corresponding folder.

- **Save** the datasets on the hard drive. 

<hr style="color:#5A7D9F; width: 100%;" align="left">
<p align="center">
	<b style="font-size:1vw;">
	<center>Simulation type</center>
	</b>
</p>
<hr style="color:#5A7D9F; width: 100%;" align="left">
<table style="width: 100%;" border="1">
	<tbody>
		<tr>
			<td style="width: 15%;" align="center"><b>INDEX</b></td>
			<td style="width: 13%;" align="center">0</td>
			<td style="width: 13%;" align="center">1</td>
			<td style="width: 15%;" align="center">2</td>
			<td style="width: 15%;" align="center">3</td>
			<td style="width: 16%;" align="center">4</td>
			<td style="width: 18%;" align="center">5</td>
		</tr>
		<tr>
			<td style="width: 15%;"align="center"><b>TYPE</b></td>
			<td style="width: 13%;" align="center">Eddies</td>
			<td style="width: 13%;" align="center">Jets  </td>
			<td style="width: 15%;" align="center">Eddies (Debug)</td>
			<td style="width: 15%;" align="center">Jets   (Debug)</td>
			<td style="width: 16%;" align="center">Eddies (Random)</td>
			<td style="width: 18%;" align="center">Jets (Random)</td>
		</tr>
	</tbody>
</table>

<br>

<hr style="color:#5A7D9F; width: 100%;" align="left">
<p align="center">
	<b style="font-size:1vw;">
	<center>Parameters</center>
	</b>
</p>
<hr style="color:#5A7D9F; width: 100%;" align="left">
<table style="width: 100%;" border="1">
	<tbody>
		<tr style="height: 21px;">
			<td style="width: 16%;" align="center"><b>PARAMETERS</b></td>
			<td style="width: 12%;" align="center">nx</td>
			<td style="width: 10%;" align="center">dt</td>
			<td style="width: 12%;" align="center">tmax</td>
			<td style="width: 12%;" align="center">tavestart</td>
			<td style="width: 12%;" align="center">rek</td>
			<td style="width: 12%;" align="center">Δ</td>
			<td style="width: 14%;" align="center">β</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center"><b>DESCRIPTION</b></td>
			<td style="width: 12%;" align="center">Number of real space grid points in the x direction</td>
			<td style="width: 10%;" align="center">Numerical timestep (in hours)</td>
			<td style="width: 12%;" align="center">Total time of integration (in years)</td>
			<td style="width: 12%;" align="center">Start time for averaging (in years)</td>
			<td style="width: 12%;" align="center">Linear drag in lower layer</td>
			<td style="width: 12%;" align="center">Layer thickness ratio (H1/H2)</td>
			<td style="width: 14%;" align="center">Gradient of the Coriolis parameter</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center"><b>EDDIES</b></td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 10%;" align="center">1</td>
			<td style="width: 12%;" align="center">10</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">5.789e-7</td>
			<td style="width: 12%;" align="center">0.25</td>
			<td style="width: 14%;" align="center">1.5 * 1e-11</td>
		</tr>
		<tr style="height: 21px;">
			<td style="width: 16%;" align="center"><b>JETS</b></td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 10%;" align="center">1</td>
			<td style="width: 12%;" align="center">10</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">7e-08</td>
			<td style="width: 12%;" align="center">0.1</td>
			<td style="width: 14%;" align="center">1e-11</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center"><b>EDDIES (Debug)</b></td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 10%;" align="center">1</td>
			<td style="width: 12%;" align="center">2</td>
			<td style="width: 12%;" align="center">1</td>
			<td style="width: 12%;" align="center">5.789e-7</td>
			<td style="width: 12%;" align="center">0.25</td>
			<td style="width: 14%;" align="center">1.5 * 1e-11</td>
		</tr>
		<tr style="height: 21px;">
			<td style="width: 16%;" align="center"><b>JETS (Debug)</b></td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 10%;" align="center">1</td>
			<td style="width: 12%;" align="center">2</td>
			<td style="width: 12%;" align="center">1</td>
			<td style="width: 12%;" align="center">7e-08</td>
			<td style="width: 12%;" align="center">0.1</td>
			<td style="width: 14%;" align="center">1e-11</td>
		</tr>
		<tr style="height: 21.5px;">
			<td style="width: 16%;" align="center"><b>EDDIES (Random)</b></td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 10%;" align="center">1</td>
			<td style="width: 12%;" align="center">10</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">[5.7, 5.9] * 1e-7</td>
			<td style="width: 12%;" align="center">0.25</td>
			<td style="width: 14%;" align="center">[1.45, 1.55] * 1e-11</td>
		</tr>
		<tr style="height: 21px;">
			<td style="width: 16%;" align="center"><b>JETS (Random)</b></td>
			<td style="width: 12%;" align="center">256</td>
			<td style="width: 10%;" align="center">1</td>
			<td style="width: 12%;" align="center">10</td>
			<td style="width: 12%;" align="center">5</td>
			<td style="width: 12%;" align="center">[6.9, 7.1] * 1e-8</td>
			<td style="width: 12%;" align="center">0.1</td>
			<td style="width: 14%;" align="center">[0.95, 1.05] * 1e-11</td>
		</tr>
	</tbody>
</table>

<br>

<hr style="color:#5A7D9F; width: 100%;" align="left">
<p align="center">
	<b style="font-size:1vw;">
	<center>Coarsening operators</center>
	</b>
</p>
<hr style="color:#5A7D9F; width: 100%;" align="left">
<table style="width: 100%;" border="1">
	<tbody>
		<tr>
			<td style="width: 10%;" align="center"><b>OPERATOR</b></td>
			<td style="width: 15%;" align="center">&nbsp;1</td>
			<td style="width: 15%;" align="center">&nbsp;2</td>
			<td style="width: 17%;" align="center">3</td>
		</tr>
		<tr>
			<td style="width: 10%;"align="center"><b>DESCRIPTION</b></td>
			<td style="width: 15%;"align="center">Spectral Truncation, Sharp Filter</td>
			<td style="width: 15%;"align="center">Spectral Truncation, Gaussian Filter</td>
			<td style="width: 17%;"align="center">GCM Filter, Averaging and Coarsening</td>
		</tr>
	</tbody>
</table>
<hr style="color:#5A7D9F; width: 100%;" align="left">

In [None]:
# ----------------------------------
#             Documentation
# ----------------------------------
# save_folder        : Name of the folder used to save the datasets
# nb_threads         : Number of threads used to run the simulation
# simulation_type    : Type of simulation used to generate the dataset
# memory             : Total number of memory allocated [GB] (used for security purpose)
# skipped_time       : Time [year] at which the sampling of the simulation starts
# save_high_res      : Choose if the whole high resolution is saved or just the last sample (memory saving)
# operator_cf        : Coarsening and filtering operator applied on the high resolution simulation
# target_sample_size : Number of samples expected to be in the datasets (nb_sample >= target_sample_size)
#-------------------------------------------
%cd ../src/pyqg_parameterization_benchmarks/
#-------------------------------------------

%run generate_dataset.py --save_folder        test                                                                                                                                   \
                         --simulation_type       2                                                                                                                                   \
                         --target_sample_size 1000                                                                                                                                   \
                         --operator_cf           1                                                                                                                                   \
                         --skipped_time          1                                                                                                                                   \
                         --nb_threads            1                                                                                                                                   \
                         --memory               10                                                                                                                                   \
                         --save_high_res     False

# Attention / are at the end hidden
#-------------------------------------------
%cd ../../notebooks
#-------------------------------------------

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    <center>Ocean subgrid parameterization - Learning</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one will be able to:
</p>
 
- **Create** and **train** a new parameterization using the datasets created previously.

<table style="width: 100%;" border="1">
	<tbody>
		<tr>
			<td style="width: 15%;" align="center"><b>Train</b></td>
			<td style="width: 13%;" align="center">Validation</td>
			<td style="width: 13%;" align="center">Offline</td>
			<td style="width: 15%;" align="center">online</td>
		</tr>
		<tr>
			<td style="width: 15%;"align="center"><b>TYPE</b></td>
			<td style="width: 13%;" align="center">Eddies</td>
			<td style="width: 13%;" align="center">Jets  </td>
			<td style="width: 15%;" align="center">Eddies (Debug)</td>
		</tr>
	</tbody>
</table>

In [None]:
# ----------------------------------
#             Documentation
# ----------------------------------
# folder_training   : Folder used to load data as training data
# folder_validation : Folder used to load data as training data
# save_directory    : Folder used to load data as training data
# inputs            : Type of inputs given to the parameterization for training
# targets           : Parameterization ouptut
# num_epochs        : Number of epochs made by the paremeterization while training
# zero_mean         : Type of pre-processing made on the datasets
# padding           : Type of padding used by the parameterization
# memory            : Total number of memory allocated [GB] (used for security purpose)
# sim_type          : Type of fluid simulation studied (used to order tensorboard folders)
#-------------------------------------------
%cd ../src/pyqg_parameterization_benchmarks/
#-------------------------------------------

%run train_parameterization.py --folder_training         eddies_training_1 eddies_training_2 eddies_training_3                                                                                                            \
                               --folder_validation     eddies_validation_1                                                                                                           \
                               --save_directory                       test1                                                                                                                          \
                               --inputs                                  q                                                                                                                     \
                               --targets                 q_subgrid_forcing                                                                                                                          \
                               --num_epochs                              1                                                                                                                          \
                               --zero_mean                            True                                                                                                                          \
                               --padding                          circular                                                                                                                          \
                               --memory                                  8                                                                                                                          \
                               --sim_type                           eddies                                                                                                                          

# Attention / are at the end hidden
#-------------------------------------------
%cd ../../notebooks
#-------------------------------------------

[comment]: <> (Section)
<hr style="color:#5A7D9F;">
<p align="center">
    <b style="font-size:1.5vw; color:#5A7D9F;">
    <center>Ocean subgrid parameterization - Testing (Offline)</center>
    </b>
</p>
<hr style="color:#5A7D9F;">

[comment]: <> (Description)
<p align="justify">
    In this section, one will be able to:
</p>
 
- **Load** a trained FCNN parameterization;

- Evaluate its **offline performance** on a test set, i.e. its ability to accurately **predict** the **subgrid forcing terms**.

<hr style="color:#5A7D9F;">

The two metrics used to evaluate the offline performances of the parameterization are:

- **Pearson correlation** ($\rho$) : where $\sigma$ denotes the empirical standard deviation of a quantity over the dataset. This quantity is between -1 and 1 and can remain high even when $R^2$ is negative, e.g. if predictions are wrong by a large but consistent scaling factor. Its mathematical expression is:

$$\rho = \dfrac{\text{Cov}(S, \hat{S})}{\sigma_S \sigma_{\hat{S}}}$$

<br>

- **Coefficient of determination** ($R^2$) : which is 1 when predictions are perfect, 0 when predictions are no better than always predicting the mean, and negative when worse than always predicting the mean. Its mathematical expression is:

$$R^2 = 1 - \dfrac{E[(S - \hat{S})^2]}{E[(S - E[S])^2]}$$

In [None]:
# ----------------------------------
#             Documentation
# ----------------------------------
# folder_offline : Folders used to load data as offline test data
# folder_models  : Folder (inside the model folder) used to load all the different models to be tested
# memory         : Total number of memory allocated [GB] (used for security purpose)
#-------------------------------------------
%cd ../src/pyqg_parameterization_benchmarks/
#-------------------------------------------

%run offline.py --folder_offline        eddies_offline_1 eddies_offline_2 eddies_offline_3 eddies_offline_4 eddies_offline_5 eddies_offline_6 eddies_offline_7 eddies_offline_8 eddies_offline_9 eddies_offline_10 \
                --folder_models                 baseline \
                --memory                              16
                                                                                                                         
# Attention / are at the end hidden
#-------------------------------------------
%cd ../../notebooks
#-------------------------------------------

<hr style="color:#5A7D9F; width: 100%;" align="left">
<p align="center">
	<b style="font-size:1vw;">
	<center>Online (GPU)</center>
	</b>
</p>
<hr style="color:#5A7D9F; width: 100%;" align="left">

In this section, the goal is to test the **ability of the network to make good predictions** by checking the **physics** of the results.

In [None]:
    
    # Used to plot easily the results
    def imshow(arr):
        """
        Displays a 2D field with the 'inferno' colormap (color scale fixed to [0, 1]) and
        writes its mean value, with two decimals, at the center of the image.

        Arguments:
            arr : 2D field exposing `.shape` and whose `.mean()` result exposes `.data`
                  (presumably an xarray DataArray - TODO confirm against the callers).
        """
        plt.imshow(arr, vmin = 0, vmax = 1, cmap = 'inferno')
        mean = arr.mean().data

        # The text is centered using the shape of the field, so that any resolution is
        # supported (a 64 x 64 field gives the position (32, 32)).
        center_x, center_y = arr.shape[1] / 2, arr.shape[0] / 2

        # White text on dark (low mean) backgrounds, black text on bright ones
        plt.text(center_x, center_y, f"{mean:.2f}", color = ('white' if mean < 0.75 else 'black'),
                 fontweight = 'bold', ha = 'center', va = 'center', fontsize=16)

        # Hiding the ticks
        plt.xticks([]); plt.yticks([])

    def colorbar(label):
        """
        Adds a colorbar to the current plot and labels it.

        Arguments:
            label : text of the colorbar label (written horizontally, font size 16)
        """
        bar = plt.colorbar()
        bar.set_label(label, fontsize = 16, rotation = 0, ha = 'left', va = 'center')

    def getDatasetAttributes(dataset):
        """
        Retrieves the simulation parameters stored in the attributes of a dataset.

        The attribute names may be prefixed by 'pyqg:'; the prefix is removed before the
        name is compared with the list of parameters, and the returned dictionary uses the
        names WITHOUT prefix (so it can be unpacked as keyword arguments).

        Arguments:
            dataset : object exposing an `attrs` mapping (attribute name -> value)

        Returns:
            A new dictionary containing, among "nx", "dt", "tmax", "tavestart", "rek",
            "delta" and "beta", the parameters found in the attributes of the dataset.
        """
        # Parameters to retrieve
        parameters = ("nx", "dt", "tmax", "tavestart", "rek", "delta", "beta")

        # Retrieves all the attributes
        attributes = {}

        # Constructing new attr
        for k, v in dataset.attrs.items():

            # Updating string (removing the prefix, if any)
            updt_k = k.replace('pyqg:', '')

            # Updating attribute dictionary with correct parameters; the cleaned name is
            # used for the test AND as key, otherwise prefixed attributes are never kept
            if updt_k in parameters:
                attributes[updt_k] = v

        return attributes

In [None]:
# -- Computing predictions --
#
# NOTE(review): FCNN_trained, test_HR and test_LR are not defined in the visible part of the
#               notebook - they must have been created before running this cell.
#
# The online simulation receives the parameters stored in the low resolution dataset
online_preds = FCNN_trained.run_online(
    sampling_freq = 24 * 26,
    **getDatasetAttributes(test_LR),
)

# Storing simulation results as (simulation, label) pairs
online_simulation = [
    (test_HR,      'High-res'),
    (test_LR,      'Low-res'),
    (online_preds, 'Low-res + FCNN'),
]

In [None]:
# -- Energy Budget Comparison --
#
# NOTE(review): loading_folder and loading_folder_data are not defined in the visible part
#               of the notebook - they must have been created before running this cell.
fig = energy_budget_figure(online_simulation)
fig.suptitle(f"Energy Budget Comparison - {loading_folder} - {loading_folder_data}")
plt.tight_layout()

# Complete path to save the figure
online_path = f"../datasets/{loading_folder_data}/online/"

# Creating the image folder (nothing happens if it already exists)
os.makedirs(online_path, exist_ok = True)

# Save the figure (done before showing it, so the file never depends on what the
# backend does with the figure once it has been displayed)
fig.savefig(os.path.join(online_path, "energy_budget.png"))

plt.show()

In [None]:
# -- Upper Potential Vorticity Comparison --
#
# NOTE(review): loading_folder and loading_folder_data are not defined in the visible part
#               of the notebook - they must have been created before running this cell.
fig = plt.figure(figsize = (13, 4))
plt.suptitle(f"Upper Potential Vorticity - {loading_folder} - {loading_folder_data}")

# One panel per simulation, showing the last snapshot of the first level (lev = 0)
for i, (m, label) in enumerate(online_simulation):
    plt.subplot(1, 3, i + 1, title = label)
    plt.imshow(m.q.isel(lev = 0, time = -1), cmap='bwr', vmin = -3e-5, vmax = 3e-5)
plt.colorbar(label="Upper PV [$s^{-1}$]")
plt.tight_layout()

# Complete path to save the figure
online_path = f"../datasets/{loading_folder_data}/online/"

# Creating the image folder (nothing happens if it already exists)
os.makedirs(online_path, exist_ok = True)

# Save the figure
fig.savefig(os.path.join(online_path, "PV_comparison.png"))

In [None]:
# -- Vorticity Distribution Comparison (Quasi Steady-State) --
#
# NOTE(review): loading_folder and loading_folder_data are not defined in the visible part
#               of the notebook - they must have been created before running this cell.
fig = plt.figure()
plt.title(f"Differences in distributions of quasi-steady $q_1$ - {loading_folder} - {loading_folder_data}")

for i, (m, label) in enumerate(online_simulation):

    # Flattened values of the first level (lev = 0) over the last 20 snapshots
    data = m.q.isel(lev = 0, time = slice(-20, None)).data.ravel()

    # Kernel density estimate, evaluated between the 1st and 99th percentiles
    dist = gaussian_kde(data)
    x = np.linspace(*np.percentile(data, [1,99]), 1000)

    # Dashed line for the FCNN simulation, solid line for the others
    plt.plot(x, dist(x), label = label, lw = 3, ls = ('--' if 'FCNN' in label else '-'))

plt.legend()
plt.xlabel("Upper PV [$s^{-1}$]")
plt.ylabel("Probability density")
plt.tight_layout()

# Complete path to save the figure
online_path = f"../datasets/{loading_folder_data}/online/"

# Creating the image folder (nothing happens if it already exists)
os.makedirs(online_path, exist_ok = True)

# Save the figure
fig.savefig(os.path.join(online_path, "distribution.png"))