# Initialization

In [1]:
from cbrain.imports import *
from cbrain.data_generator import *
from cbrain.cam_constants import *
from cbrain.losses import *
from cbrain.utils import limit_mem
from cbrain.layers import *
from cbrain.data_generator import DataGenerator

import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
# Enable on-demand memory growth on every visible GPU rather than on the
# hard-coded indices 0-2: fixed indices raise IndexError on hosts with fewer
# than three GPUs and leave any additional GPU unconfigured.
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

from tensorflow import math as tfm
import tensorflow_probability as tfp
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import xarray as xr
import numpy as np
from cbrain.model_diagnostics import ModelDiagnostics
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as imag
import scipy.integrate as sin
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import pickle

# Directory holding the data files; it is concatenated with the *FILEQ file
# names when the generators are built further down.
#TRAINDIR = '/local/Tom.Beucler/SPCAM_PHYS/'
TRAINDIR = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/SPCAM_PHYS/'
# NOTE(review): DATADIR and PREFIX are not referenced in the visible part of this notebook.
DATADIR = '/project/meteo/w2w/A6/S.Rasp/SP-CAM/fluxbypass_aqua/'
PREFIX = '8col009_01_'
# Change the working directory so that relative paths such as './nn_config/...' resolve.
#%cd /filer/z-sv-pool12c/t/Tom.Beucler/SPCAM/CBRAIN-CAM
%cd /export/home/tbeucler/CBRAIN-CAM
# Otherwise tensorflow will use ALL your GPU RAM for no reason
#limit_mem()

/nfspool-0/home/tbeucler/CBRAIN-CAM/notebooks/tbeucler_devlog
/nfspool-0/home/tbeucler/CBRAIN-CAM


# Custom generator (all outputs minus the residual ones)

## Build custom generator class

The custom generator takes an additional argument, **out_cut_off**, which specifies the *output* indices the network will not be trained on (by analogy, **var_cut_off** refers to the *input* indices it is not trained on). **out_cut_off** is a dictionary that maps a variable name to the single integer index to exclude from that variable's block of the output vector.

In [2]:
return_var_idxs

<function cbrain.utils.return_var_idxs(ds, var_list, var_cut_off=None)>

In [3]:
def return_var_idxs_outputcutoff(ds, var_list, out_cut_off=None):
    """
    Locate the requested variables along the stacked variable dimension,
    optionally leaving out one entry per variable.

    Parameters
    ----------
    ds: xarray dataset
        Must expose `var_names`, which is compared element-wise to each name.
    var_list: list of variables
        Names to look up; the result follows the order of this list.
    out_cut_off: dict, optional
        Maps a variable name to the position (within that variable's own
        matches) that is removed with `np.delete`. Variables missing from
        the dict keep all their indices; None removes nothing.

    Returns
    -------
    var_idxs: indices array
    """
    excluded = {} if out_cut_off is None else out_cut_off
    per_variable = []
    for name in var_list:
        matches = np.where(ds.var_names == name)[0]
        if name in excluded:
            # Drop the single requested entry of this variable's index block
            matches = np.delete(matches, excluded[name])
        per_variable.append(matches)
    return np.concatenate(per_variable)

In [4]:
class DictNormalizer_outputcutoff(object):
    """Normalizer that takes a conversion dictionary as input. Simply scales by factors in dict.

    Parameters
    ----------
    norm_ds: normalization dataset
        Not needed to build the scale vector; kept so that the constructor
        signature stays unchanged for existing callers.
    var_list: list of variables
        The scale vector is the concatenation of each variable's factors,
        in the order of this list.
    dic: dict, optional
        Maps a variable name to a scalar or array of scaling factors.
        Falls back to `conversion_dict`, which is not defined in this
        notebook (presumably supplied by one of the star imports).
    out_cut_off: dict, optional
        Maps a variable name to the position of the scaling factor that is
        removed with `np.delete`, so the scale vector matches outputs from
        which the same entry was cut. None (the default) removes nothing.
    """
    def __init__(self, norm_ds, var_list, dic=None,out_cut_off=None):
        if dic is None: dic = conversion_dict
        # The default None used to crash on `out_cut_off.keys()`;
        # treat it as "no output is cut off".
        if out_cut_off is None: out_cut_off = {}
        scale = []
        for v in var_list:
            s = np.atleast_1d(dic[v])
            # Delete the scaling factor of outputs that have been cut off via out_cut_off
            if v in out_cut_off: s = np.delete(s, out_cut_off[v])
            scale.append(s)
        self.scale = np.concatenate(scale).astype('float32')
        self.transform_arrays = {
            'scale': self.scale,
        }

    def transform(self, x):
        """Multiply `x` by the scale vector."""
        return x * self.scale

    def inverse_transform(self, x):
        """Divide `x` by the scale vector (inverse of `transform`)."""
        return x / self.scale

In [5]:
class DataGenerator_outputcutoff(tf.keras.utils.Sequence):
    """
    https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    Data generator class.

    Serves (input, output) batches read from the `vars` array of `data_fn`.
    Individual entries of the output vector can be excluded through
    `out_cut_off`.

    Parameters
    ----------
    data_fn: path of the data file (opened with xarray, then re-opened with
        h5py unless `xarray=True`)
    input_vars, output_vars: lists of variable names selecting the input and
        output columns
    norm_fn: path of the normalization file; required when `input_transform`
        is a tuple or `output_transform` is a dict
    input_transform: None (-> `Normalizer()`), a tuple whose two entries are
        forwarded to `InputNormalizer`, or an already initialized normalizer
    output_transform: None (-> `Normalizer()`), a dict forwarded to
        `DictNormalizer_outputcutoff`, or an already initialized normalizer
    batch_size: number of samples per batch
    shuffle: only consulted by `on_epoch_end` (see the note there)
    xarray: if True, keep reading through xarray instead of h5py
    var_cut_off: forwarded to `return_var_idxs` and `InputNormalizer`
    out_cut_off: dict {variable name: index to remove}, forwarded to
        `return_var_idxs_outputcutoff` and `DictNormalizer_outputcutoff`

    Raises
    ------
    ValueError
        If a transform that needs the normalization file is requested while
        `norm_fn` is None.
    """

    def __init__(self, data_fn, input_vars, output_vars,
                 norm_fn=None, input_transform=None, output_transform=None,
                 batch_size=1024, shuffle=True, xarray=False, var_cut_off=None,
                out_cut_off=None):
        # Fail early with a clear message, before any file is opened, instead
        # of hitting an AttributeError on the missing `self.norm_ds` later on.
        if norm_fn is None and (type(input_transform) is tuple or type(output_transform) is dict):
            raise ValueError(
                'norm_fn is required when input_transform is a tuple or output_transform is a dict')

        # An absent cut-off specification means "keep every output"
        if out_cut_off is None: out_cut_off = {}

        # Just copy over the attributes
        self.data_fn, self.norm_fn = data_fn, norm_fn
        self.input_vars, self.output_vars = input_vars, output_vars
        self.batch_size, self.shuffle = batch_size, shuffle

        # Open datasets (norm_ds is always defined, None when no file was given)
        self.data_ds = xr.open_dataset(data_fn)
        self.norm_ds = xr.open_dataset(norm_fn) if norm_fn is not None else None

        # Compute number of samples and batches; floor division drops the
        # last, incomplete batch
        self.n_samples = self.data_ds.vars.shape[0]
        self.n_batches = self.n_samples // self.batch_size

        # Get input and output variable indices
        self.input_idxs = return_var_idxs(self.data_ds, input_vars, var_cut_off)
        self.output_idxs = return_var_idxs_outputcutoff(self.data_ds, output_vars, out_cut_off=out_cut_off)
        self.n_inputs, self.n_outputs = len(self.input_idxs), len(self.output_idxs)

        # Initialize input and output normalizers/transformers
        if input_transform is None:
            self.input_transform = Normalizer()
        elif type(input_transform) is tuple:
            self.input_transform = InputNormalizer(
                self.norm_ds, input_vars, input_transform[0], input_transform[1], var_cut_off)
        else:
            self.input_transform = input_transform  # Assume an initialized normalizer is passed

        if output_transform is None:
            self.output_transform = Normalizer()
        elif type(output_transform) is dict:
            # Same cut-off as for the indices, so scale factors and outputs stay aligned
            self.output_transform = DictNormalizer_outputcutoff(self.norm_ds, output_vars, output_transform,
                                                                out_cut_off=out_cut_off)
        else:
            self.output_transform = output_transform  # Assume an initialized normalizer is passed

        # Now close the xarray file and load it as an h5 file instead
        # This significantly speeds up the reading of the data...
        if not xarray:
            self.data_ds.close()
            self.data_ds = h5py.File(data_fn, 'r')

    def __len__(self):
        """Number of complete batches per epoch."""
        return self.n_batches

    def __getitem__(self, index):
        """Return the normalized (X, Y) pair for batch number `index`."""
        # Compute start and end indices for batch
        start_idx = index * self.batch_size
        end_idx = start_idx + self.batch_size

        # Grab batch from data
        batch = self.data_ds['vars'][start_idx:end_idx]

        # Split into inputs and outputs
        X = batch[:, self.input_idxs]
        Y = batch[:, self.output_idxs]

        # Normalize
        X = self.input_transform.transform(X)
        Y = self.output_transform.transform(Y)

        return X, Y

    def on_epoch_end(self):
        """Regenerate (and optionally shuffle) the array of batch indices.

        NOTE(review): `__getitem__` never reads `self.indices`, so this
        shuffle does not change which samples a given batch index returns.
        """
        self.indices = np.arange(self.n_batches)
        if self.shuffle: np.random.shuffle(self.indices)

## Build custom generator and compare to standard generator

For the manuscript's purposes, we will choose the lowest levels as the residuals for direct comparison with the reference ACnet.

### Standard generator

In [6]:
# File names of the training, validation, normalization and test datasets.
# They are concatenated with TRAINDIR when the generators are created.
TRAINFILEQ = '8col009_01_train.nc'
VALIDFILEQ = '8col009_01_valid.nc'
NORMFILEQ = '8col009_01_norm.nc'
TESTFILEQ = '8col009_01_test.nc'

In [7]:
# Output scaling dictionary; it is handed to the generators as `output_transform`.
# NOTE(review): load_pickle presumably unpickles the file — only load trusted files.
scale_dictQ = load_pickle('./nn_config/scale_dicts/009_Wm2_scaling.pkl')

In [8]:
# Input and output variable names used by the reference generators
# (train_genQ / valid_genQ / test_genQ).
in_varsQ = ['QBP', 'QCBP', 'QIBP', 'TBP', 'VBP', 
           'Qdt_adiabatic', 'QCdt_adiabatic', 'QIdt_adiabatic', 'Tdt_adiabatic', 'Vdt_adiabatic',
           'PS', 'SOLIN', 'SHFLX', 'LHFLX']
out_varsQ = ['PHQ', 'PHCLDLIQ', 'PHCLDICE', 'TPHYSTND', 'QRL', 'QRS', 'DTVKE', 
            'FSNT', 'FSNS', 'FLNT', 'FLNS', 'PRECT', 'PRECTEND', 'PRECST', 'PRECSTEN']

In [9]:
# Reference training generator built with the stock cbrain DataGenerator
# (all entries of out_varsQ are kept, no output cut-off).
train_genQ = DataGenerator(
    data_fn = TRAINDIR+TRAINFILEQ,
    input_vars = in_varsQ,
    output_vars = out_varsQ,
    norm_fn = TRAINDIR+NORMFILEQ,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dictQ,
    batch_size=1024,
    shuffle=True
)

In [10]:
# Reference validation generator; same settings as train_genQ but reading the
# validation file. The normalization file is the same one used for training.
valid_genQ = DataGenerator(
    data_fn = TRAINDIR+VALIDFILEQ,
    input_vars = in_varsQ,
    output_vars = out_varsQ,
    norm_fn = TRAINDIR+NORMFILEQ,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dictQ,
    batch_size=1024,
    shuffle=True
)

In [11]:
# Reference test generator; same settings as train_genQ but reading the test
# file. The normalization file is the same one used for training.
test_genQ = DataGenerator(
    data_fn = TRAINDIR+TESTFILEQ,
    input_vars = in_varsQ,
    output_vars = out_varsQ,
    norm_fn = TRAINDIR+NORMFILEQ,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dictQ,
    batch_size=1024,
    shuffle=True
)

### Custom generator

In [12]:
# Variable lists for the custom generators. The inputs are identical to
# in_varsQ; the outputs are out_varsQ without 'FSNS' and 'FLNS'.
in_vars_custom = ['QBP', 'QCBP', 'QIBP', 'TBP', 'VBP', 
           'Qdt_adiabatic', 'QCdt_adiabatic', 'QIdt_adiabatic', 'Tdt_adiabatic', 'Vdt_adiabatic',
           'PS', 'SOLIN', 'SHFLX', 'LHFLX']
out_vars_custom = ['PHQ', 'PHCLDLIQ', 'PHCLDICE', 'TPHYSTND', 'QRL', 'QRS', 'DTVKE', 
            'FSNT', 'FLNT', 'PRECT', 'PRECTEND', 'PRECST', 'PRECSTEN']

In [13]:
out_cut_off_low = {}

In [14]:
# Output entries excluded from training: index 29 of PHQ and index 29 of TPHYSTND.
out_cut_off_low = {
    'PHQ': 29,
    'TPHYSTND': 29,
}

In [15]:
out_cut_off_low

{'PHQ': 29, 'TPHYSTND': 29}

In [16]:
# Training generator with output cut-off: the entries listed in
# out_cut_off_low are removed from the targets (and from the scale vector).
train_gen_custom = DataGenerator_outputcutoff(
    data_fn = TRAINDIR+TRAINFILEQ,
    input_vars = in_vars_custom,
    output_vars = out_vars_custom,
    norm_fn = TRAINDIR+NORMFILEQ,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dictQ,
    batch_size=1024,
    shuffle=True,
    out_cut_off=out_cut_off_low
)

In [17]:
# Validation generator with the same output cut-off as train_gen_custom,
# reading the validation file.
valid_gen_custom = DataGenerator_outputcutoff(
    data_fn = TRAINDIR+VALIDFILEQ,
    input_vars = in_vars_custom,
    output_vars = out_vars_custom,
    norm_fn = TRAINDIR+NORMFILEQ,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dictQ,
    batch_size=1024,
    shuffle=True,
    out_cut_off=out_cut_off_low
)

In [18]:
# Test generator with the same output cut-off as train_gen_custom,
# reading the test file.
test_gen_custom = DataGenerator_outputcutoff(
    data_fn = TRAINDIR+TESTFILEQ,
    input_vars = in_vars_custom,
    output_vars = out_vars_custom,
    norm_fn = TRAINDIR+NORMFILEQ,
    input_transform = ('mean', 'maxrs'),
    output_transform = scale_dictQ,
    batch_size=1024,
    shuffle=True,
    out_cut_off=out_cut_off_low
)

# Neural network 

In [145]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
inp = Input(shape=(304,))
# 'linear' activation because the non-linearity is applied by the separate LeakyReLU layer
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
# Four more identical hidden blocks
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214 = tf.keras.models.Model(inp, out)

In [149]:
# Checkpoint file name (without extension) and the directory it is written to
name = 'UCnet_214_1'
path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'
# Stop training once val_loss has not improved for 10 epochs
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
# Save to <path_HDF5><name>.hdf5 only when val_loss reaches a new minimum
mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

In [150]:
# Adam optimizer with default settings; `mse` is not defined in this notebook
# (presumably provided by one of the star imports).
UCnet_214.compile(tf.keras.optimizers.Adam(), loss=mse)

In [151]:
# Number of training epochs
Nep = 10
# Train on the custom training generator and validate on the custom validation
# generator; the callbacks handle early stopping and checkpointing.
# NOTE(review): patience (10) is not smaller than the number of epochs, so early
# stopping presumably never triggers here — confirm this is intended.
UCnet_214.fit_generator(train_gen_custom, epochs=Nep, 
                        validation_data=valid_gen_custom,\
              callbacks=[earlyStopping, mcp_save])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fae7e79f160>

# Train the other NNs, corresponding to the five NNs with optimized $\beta$

## q8T4

In [19]:
# Output entries excluded for this network: index 8 of PHQ, index 4 of TPHYSTND.
out_cut_off_low = {'PHQ': 8, 'TPHYSTND': 4}

# The generators compute their output indices once, at construction time, so
# re-binding `out_cut_off_low` alone has no effect on the existing generator
# objects. Rebuild them so that the network trained below really excludes the
# entries chosen here.
train_gen_custom, valid_gen_custom, test_gen_custom = [
    DataGenerator_outputcutoff(
        data_fn=TRAINDIR + data_file,
        input_vars=in_vars_custom,
        output_vars=out_vars_custom,
        norm_fn=TRAINDIR + NORMFILEQ,
        input_transform=('mean', 'maxrs'),
        output_transform=scale_dictQ,
        batch_size=1024,
        shuffle=True,
        out_cut_off=out_cut_off_low,
    )
    for data_file in (TRAINFILEQ, VALIDFILEQ, TESTFILEQ)
]

In [20]:
out_cut_off_low

{'PHQ': 8, 'TPHYSTND': 4}

In [21]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
inp = Input(shape=(304,))
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
# Four more identical hidden blocks
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214_q8T4 = tf.keras.models.Model(inp, out)

In [22]:
# Adam optimizer with default settings; `mse` is not defined in this notebook
# (presumably provided by one of the star imports).
UCnet_214_q8T4.compile(tf.keras.optimizers.Adam(), loss=mse)

In [19]:
# Checkpoint file name (without extension) and the directory holding it
name = 'UCnet_214_q8T4'
path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'

In [24]:
# Restore previously saved weights from <path_HDF5><name>.hdf5; the file must
# already exist. Training then continues from these weights in the fit cell below.
UCnet_214_q8T4.load_weights(path_HDF5+name+'.hdf5')

In [25]:
# Stop training once val_loss has not improved for 10 epochs
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
# Save to <path_HDF5><name>.hdf5 only when val_loss reaches a new minimum
mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

In [26]:
# Train for 5 epochs on the custom generators; the callbacks handle early
# stopping and checkpointing.
# NOTE(review): patience (10) exceeds the number of epochs (5), so early stopping
# presumably never triggers here — confirm this is intended.
UCnet_214_q8T4.fit_generator(train_gen_custom, epochs=5, 
                        validation_data=valid_gen_custom,\
                        callbacks=[earlyStopping, mcp_save])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f787c3dcb38>

In [28]:
# inp = Input(shape=(304,))
# densout = Dense(512, activation='linear')(inp)
# densout = LeakyReLU(alpha=0.3)(densout)
# for i in range (4):
#     densout = Dense(512, activation='linear')(densout)
#     densout = LeakyReLU(alpha=0.3)(densout)
# out = Dense(214, activation='linear')(densout)
# UCnet_214_q8T4 = tf.keras.models.Model(inp, out)

# name = 'UCnet_214_q8T4'
# path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'
# earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
# mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

# UCnet_214_q8T4.compile(tf.keras.optimizers.Adam(), loss=mse)

# UCnet_214_q8T4.fit_generator(train_gen_custom, epochs=Nep, 
#                         validation_data=valid_gen_custom,\
#                         callbacks=[earlyStopping, mcp_save])

## q3T26

In [31]:
# Output entries excluded for this network: index 3 of PHQ, index 26 of TPHYSTND.
out_cut_off_low = {'PHQ': 3, 'TPHYSTND': 26}

# The generators compute their output indices once, at construction time, so
# re-binding `out_cut_off_low` alone has no effect on the existing generator
# objects. Rebuild them so that the network trained below really excludes the
# entries chosen here.
train_gen_custom, valid_gen_custom, test_gen_custom = [
    DataGenerator_outputcutoff(
        data_fn=TRAINDIR + data_file,
        input_vars=in_vars_custom,
        output_vars=out_vars_custom,
        norm_fn=TRAINDIR + NORMFILEQ,
        input_transform=('mean', 'maxrs'),
        output_transform=scale_dictQ,
        batch_size=1024,
        shuffle=True,
        out_cut_off=out_cut_off_low,
    )
    for data_file in (TRAINFILEQ, VALIDFILEQ, TESTFILEQ)
]

In [33]:
# Number of training epochs passed to fit_generator in the training cells below
Nep = 10

In [34]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
inp = Input(shape=(304,))
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
# Four more identical hidden blocks
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214_q3T26 = tf.keras.models.Model(inp, out)

# Checkpoint name/directory and training callbacks:
# the best model (lowest val_loss) is written to <path_HDF5><name>.hdf5
name = 'UCnet_214_q3T26'
path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

# Adam optimizer with default settings and `mse` loss
UCnet_214_q3T26.compile(tf.keras.optimizers.Adam(), loss=mse)

# Train for Nep epochs on the custom generators, validating on valid_gen_custom
UCnet_214_q3T26.fit_generator(train_gen_custom, epochs=Nep, 
                        validation_data=valid_gen_custom,\
                        callbacks=[earlyStopping, mcp_save])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f787c0c2cf8>

## q4T4

In [20]:
# Number of training epochs passed to fit_generator in the training cells below
Nep = 10

In [21]:
# Output entries excluded for this network: index 4 of PHQ, index 4 of TPHYSTND.
out_cut_off_low = {'PHQ': 4, 'TPHYSTND': 4}

# The generators compute their output indices once, at construction time, so
# re-binding `out_cut_off_low` alone has no effect on the existing generator
# objects. Rebuild them so that the network trained below really excludes the
# entries chosen here.
train_gen_custom, valid_gen_custom, test_gen_custom = [
    DataGenerator_outputcutoff(
        data_fn=TRAINDIR + data_file,
        input_vars=in_vars_custom,
        output_vars=out_vars_custom,
        norm_fn=TRAINDIR + NORMFILEQ,
        input_transform=('mean', 'maxrs'),
        output_transform=scale_dictQ,
        batch_size=1024,
        shuffle=True,
        out_cut_off=out_cut_off_low,
    )
    for data_file in (TRAINFILEQ, VALIDFILEQ, TESTFILEQ)
]

In [22]:
Input

<function tensorflow.python.keras.engine.input_layer.Input(shape=None, batch_size=None, name=None, dtype=None, sparse=False, tensor=None, ragged=False, **kwargs)>

In [23]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
# NOTE(review): the following cell repeats these exact statements, so the model
# built here is replaced before it is compiled or trained.
inp = Input(shape=(304,))
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214_q4T4 = tf.keras.models.Model(inp, out)

In [24]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
inp = Input(shape=(304,))
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
# Four more identical hidden blocks
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214_q4T4 = tf.keras.models.Model(inp, out)

# Checkpoint name/directory and training callbacks:
# the best model (lowest val_loss) is written to <path_HDF5><name>.hdf5
name = 'UCnet_214_q4T4'
path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

# Adam optimizer with default settings and `mse` loss
UCnet_214_q4T4.compile(tf.keras.optimizers.Adam(), loss=mse)

# Train for Nep epochs on the custom generators, validating on valid_gen_custom
UCnet_214_q4T4.fit_generator(train_gen_custom, epochs=Nep, 
                        validation_data=valid_gen_custom,\
                        callbacks=[earlyStopping, mcp_save])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f1b8031e438>

## q5T5

In [25]:
# Output entries excluded for this network: index 5 of PHQ, index 5 of TPHYSTND.
out_cut_off_low = {'PHQ': 5, 'TPHYSTND': 5}

# The generators compute their output indices once, at construction time, so
# re-binding `out_cut_off_low` alone has no effect on the existing generator
# objects. Rebuild them so that the network trained below really excludes the
# entries chosen here.
train_gen_custom, valid_gen_custom, test_gen_custom = [
    DataGenerator_outputcutoff(
        data_fn=TRAINDIR + data_file,
        input_vars=in_vars_custom,
        output_vars=out_vars_custom,
        norm_fn=TRAINDIR + NORMFILEQ,
        input_transform=('mean', 'maxrs'),
        output_transform=scale_dictQ,
        batch_size=1024,
        shuffle=True,
        out_cut_off=out_cut_off_low,
    )
    for data_file in (TRAINFILEQ, VALIDFILEQ, TESTFILEQ)
]

In [26]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
inp = Input(shape=(304,))
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
# Four more identical hidden blocks
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214_q5T5 = tf.keras.models.Model(inp, out)

# Checkpoint name/directory and training callbacks:
# the best model (lowest val_loss) is written to <path_HDF5><name>.hdf5
name = 'UCnet_214_q5T5'
path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

# Adam optimizer with default settings and `mse` loss
UCnet_214_q5T5.compile(tf.keras.optimizers.Adam(), loss=mse)

# Train for Nep epochs on the custom generators, validating on valid_gen_custom.
# Nep is not set in this cell; it must have been defined by an earlier cell.
UCnet_214_q5T5.fit_generator(train_gen_custom, epochs=Nep, 
                        validation_data=valid_gen_custom,\
                        callbacks=[earlyStopping, mcp_save])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f1b802c67f0>

## q18T28

In [27]:
# Output entries excluded for this network: index 18 of PHQ, index 28 of TPHYSTND.
out_cut_off_low = {'PHQ': 18, 'TPHYSTND': 28}

# The generators compute their output indices once, at construction time, so
# re-binding `out_cut_off_low` alone has no effect on the existing generator
# objects. Rebuild them so that the network trained below really excludes the
# entries chosen here.
train_gen_custom, valid_gen_custom, test_gen_custom = [
    DataGenerator_outputcutoff(
        data_fn=TRAINDIR + data_file,
        input_vars=in_vars_custom,
        output_vars=out_vars_custom,
        norm_fn=TRAINDIR + NORMFILEQ,
        input_transform=('mean', 'maxrs'),
        output_transform=scale_dictQ,
        batch_size=1024,
        shuffle=True,
        out_cut_off=out_cut_off_low,
    )
    for data_file in (TRAINFILEQ, VALIDFILEQ, TESTFILEQ)
]

In [28]:
# Fully connected network: 304 inputs, five hidden Dense(512) layers each
# followed by LeakyReLU(alpha=0.3), and a linear output layer of 214 units.
inp = Input(shape=(304,))
densout = Dense(512, activation='linear')(inp)
densout = LeakyReLU(alpha=0.3)(densout)
# Four more identical hidden blocks
for i in range (4):
    densout = Dense(512, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
out = Dense(214, activation='linear')(densout)
UCnet_214_q18T28 = tf.keras.models.Model(inp, out)

# Checkpoint name/directory and training callbacks:
# the best model (lowest val_loss) is written to <path_HDF5><name>.hdf5
name = 'UCnet_214_q18T28'
path_HDF5 = '/DFS-L/DATA/pritchard/tbeucler/SPCAM/HDF5_DATA/'
earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save = ModelCheckpoint(path_HDF5+name+'.hdf5',save_best_only=True, monitor='val_loss', mode='min')

# Adam optimizer with default settings and `mse` loss
UCnet_214_q18T28.compile(tf.keras.optimizers.Adam(), loss=mse)

# Train for Nep epochs on the custom generators, validating on valid_gen_custom.
# Nep is not set in this cell; it must have been defined by an earlier cell.
UCnet_214_q18T28.fit_generator(train_gen_custom, epochs=Nep, 
                        validation_data=valid_gen_custom,\
                        callbacks=[earlyStopping, mcp_save])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f1b80221ac8>

In [29]:
2

2