# Apply neural network model to GLORYS12V1 daily data
Created by Ivan Lima on Fri May  6 2022 15:34:42 -0400

In [None]:
%matplotlib inline
import xarray as xr
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os, datetime
from tqdm import notebook
# Record when the notebook was last executed
print(f'Last updated on {datetime.datetime.now().ctime()}')

## Load neural network model and data scaler

In [None]:
import torch, joblib
import torch.nn as nn

# Feature scaler loaded from disk; its transform() is applied to the input
# features before they are passed to the network.
scaler = joblib.load('models/scaler_nosat.joblib')

# Input columns, in the order they are selected from the dataframe and handed
# to the scaler and the model.
features = ['depth', 'bottom_depth', 'Temperature', 'Salinity', 'pCO2_monthave']

n_features = len(features) # number of input variables
n_targets = 2  # number of output variables
n_hidden = 256 # number of units in each hidden layer (layer width, not layer count)
learning_rate = 0.001 # step size passed to the Adam optimizer

class MLPReg(nn.Module):
    """Multilayer perceptron regressor with two hidden layers.

    The network is Linear -> LeakyReLU -> Linear -> LeakyReLU -> Linear, with
    no activation on the output layer.

    Parameters
    ----------
    n_features : int
        Size of the input vector.
    n_hidden : int
        Number of units in each of the two hidden layers.
    n_targets : int
        Size of the output vector.
    """

    def __init__(self, n_features, n_hidden, n_targets):
        super().__init__()
        # The attribute names below end up in the state_dict keys, so they
        # must stay as they are for saved weights to load.
        self.l1 = nn.Linear(n_features, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_hidden)
        self.l3 = nn.Linear(n_hidden, n_targets)
        self.activ = nn.LeakyReLU()

    def forward(self, x):
        """Return the network output for input tensor `x`."""
        hidden = self.activ(self.l1(x))
        hidden = self.activ(self.l2(hidden))
        return self.l3(hidden)

nn_reg = MLPReg(n_features=n_features, n_hidden=n_hidden, n_targets=n_targets) # create model instance
loss_func = nn.MSELoss()                                                       # loss function (mean square error)
optimizer = torch.optim.Adam(nn_reg.parameters(), lr=learning_rate)            # optimizer

# Load the saved weights. map_location='cpu' lets a checkpoint that was saved
# from a GPU be read on a machine without CUDA; the model itself lives on the CPU.
state_dict = torch.load('models/nn_reg_nosat_state.pth', map_location='cpu')
nn_reg.load_state_dict(state_dict)
nn_reg.eval() # switch to inference mode

## Extract bottom depth at grid points

In [None]:
from scipy.interpolate import griddata

# Open one yearly file to obtain the GLORYS12V1 longitude/latitude grid;
# the dropped variables are not used in this notebook.
ds_grid = xr.open_dataset('/bali/data/ilima/GLORYS12V1/daily/GLORYS12V1_NW_Atlantic_1993_daily.nc',
                          drop_variables = ['mlotst','zos','bottomT'])

# Build a table with one row per horizontal grid point (every
# longitude/latitude combination of the grid).
xx, yy = np.meshgrid(ds_grid.longitude.values, ds_grid.latitude.values)
df_positions = pd.DataFrame({'longitude': xx.ravel(), 'latitude': yy.ravel()})

# read bottom topography data
etopo = xr.open_dataset('data/etopo5.nc', chunks='auto')
# etopo['bath'] = etopo.bath.where(etopo.bath<0) # ocean points only
etopo = etopo.isel(X=slice(3100,4000), Y=slice(1300,1700)) # subset data to make things faster

X = np.where(etopo.X>180, etopo.X-360, etopo.X) # 0:360 -> -180:180
lon_topo, lat_topo = np.meshgrid(X, etopo.Y.values)
lon, lat = df_positions.longitude.values, df_positions.latitude.values
# Linearly interpolate the topography onto the model grid points.
# NOTE(review): pairing the raveled arrays assumes etopo.bath is laid out as (Y, X) - confirm.
# NOTE(review): with method='linear', target points outside the convex hull of the
# topography subset come back as NaN (scipy default fill_value) - confirm the subset covers the grid.
bottom_depth = griddata((lon_topo.ravel(), lat_topo.ravel()), etopo.bath.values.ravel(), (lon, lat), method='linear')
# Store bottom depth as a positive number.
# NOTE(review): the ocean-only mask above is commented out, so positive (land)
# elevations also end up as positive "depths" here - confirm this is intended.
df_positions['bottom_depth'] = np.abs(bottom_depth)
# Quick sanity check of the interpolated range, then show the first rows
print(df_positions.bottom_depth.min(), df_positions.bottom_depth.max())
df_positions.head()

## Read monthly atmospheric pCO2 data

In [None]:
# Monthly atmospheric CO2 table, indexed by (year, month) so a single value
# can be looked up directly with .loc[(year, month), column]
df_pco2_monthly = pd.read_csv('work/co2_mm_mlo.csv').set_index(['year','month'])

# for i in df_pco2_monthly.loc[2016].index:
#     print(i, df_pco2_monthly.loc[(2016,i),'average'])
# atm_pco2 = [df_pco2_monthly.loc[(2016,i),'average'] for i in ds.time.dt.month.values]

## Apply NN model to GLORYS12V1 data and compute carbonate chemistry variables 

In [None]:
import PyCO2SYS as pyco2
import gsw

# Columns kept for each data point; rows with a missing value in any of them are dropped
cols = ['time', 'longitude', 'latitude', 'depth', 'bottom_depth', 'Temperature', 'Salinity', 'pCO2_monthave']

year = 2019
ds_in = xr.open_dataset('/bali/data/ilima/GLORYS12V1/daily/GLORYS12V1_NW_Atlantic_{}_daily.nc'.format(year),
                        drop_variables = ['mlotst','zos','bottomT'])
ds_out = [] # store list of datasets

# Monthly atmospheric pCO2 for the year being processed, indexed by month.
# This must follow `year`: a hard-coded year would feed the model the wrong pCO2.
pco2_by_month = df_pco2_monthly.loc[year]['average']

co2sys_vars = ['pH', 'saturation_calcite', 'saturation_aragonite']

# Process one daily time step at a time to keep the dataframes small
for t in notebook.tqdm(range(ds_in.sizes['time'])):
    ds = ds_in.isel(time=t)

    # add monthly atmospheric pCO2 for the month of this time step
    # (raises KeyError if the table has no entry for that month)
    month = int(ds.time.dt.month.item())
    ds['pCO2_monthave'] = float(pco2_by_month.loc[month])

    # merge bottom depth with GLORYS12V1 data
    df_glorys = ds[['pCO2_monthave','thetao','so']].to_dataframe()
    df_glorys = df_glorys.reset_index().rename(columns={'thetao':'Temperature', 'so':'Salinity'})
    df_glorys = pd.merge(df_positions, df_glorys, on=['longitude', 'latitude'])
    # keep only complete rows; the index still refers to rows of df_glorys
    df_data = df_glorys[cols].dropna()
    # print('{:,d} data points\n'.format(df_data.shape[0]))

    X_numpy = df_data[features].values # select features
    X_numpy_scaled = scaler.transform(X_numpy) # rescale features
    X = torch.from_numpy(X_numpy_scaled.astype(np.float32)) # convert array to tensor

    # apply model to rescaled features; convert to NumPy so pandas receives
    # plain arrays rather than torch tensors
    with torch.no_grad():
        Y_pred = nn_reg(X).numpy()

    # add estimated DIC & TA to features dataframe
    df_data['DIC'] = Y_pred[:,0]
    df_data['TA'] = Y_pred[:,1]

    # compute additional carbonate chemistry variables
    # gsw.p_from_z expects height (negative downward), hence the sign flip on depth
    pressure =  gsw.p_from_z(-df_data.depth.values, df_data.latitude.values)
    # NOTE(review): 'Temperature' comes from the GLORYS12V1 variable thetao; confirm
    # that passing it to PyCO2SYS as `temperature` without conversion is intended.
    kwargs = dict(
        par1 = df_data.TA.values,   # TA
        par2 = df_data.DIC.values,  # DIC
        par1_type = 1,             # type 1 = alkalinity
        par2_type = 2,             # type 2 = DIC
        salinity = df_data.Salinity.values,
        temperature = df_data.Temperature.values,
        pressure = pressure,
        opt_k_carbonic = 10,  # LDK00, Lueker et al 2000
        opt_k_bisulfate = 1,  # D90a, Dickson 1990
        opt_total_borate = 2, # LKB10, Lee et al 2010
        opt_k_fluoride = 2    # PF87, Perez & Fraga 1987
    )
    results = pyco2.sys(**kwargs)
    for vname in co2sys_vars:
        df_data[vname] = results[vname]

    # merge estimated carbonate chemistry variables to original dataframe
    # (assignment aligns on the index, so rows dropped above are left as NaN)
    for vname in ['DIC','TA'] + co2sys_vars:
        df_glorys[vname] = df_data[vname]

    # convert dataframe to xarray dataset
    df = df_glorys.set_index(['time','depth','latitude','longitude'])
    ds_bgc = df[['Temperature', 'Salinity', 'DIC', 'TA','pH', 'saturation_calcite', 'saturation_aragonite']].to_xarray()

    ds_out.append(ds_bgc)

## Create output dataset including T, S, DIC, TA, pH & calcite/aragonite saturation states

In [None]:
# Metadata for the variables estimated in this notebook
var_attrs = {
    'DIC': {'long_name': 'Dissolved inorganic carbon',
            'standard_name': 'DIC',
            'units': 'micro mol/kg',
            'unit_long': 'micro mol/kg'},
    'TA': {'long_name': 'Total alkalinity',
            'standard_name': 'TA',
            'units': 'micro mol/kg',
            'unit_long': 'micro mol/kg'},
    'pH': {'long_name': 'Total pH',
            'standard_name': 'pH',
            'units': '',
            'unit_long': ''},
    'saturation_calcite': {'long_name': 'Calcite saturation state',
            'standard_name': 'Calcite saturation',
            'units': '',
            'unit_long': ''},
    'saturation_aragonite': {'long_name': 'Aragonite saturation state',
            'standard_name': 'Aragonite saturation',
            'units': '',
            'unit_long': ''},
}

# Join the list of per-time-step datasets into a single dataset along time
ds_out = xr.concat(ds_out, dim='time')

# copy variable attributes
for vname, vattrs in var_attrs.items():
    ds_out[vname].attrs.update(vattrs)

# Output variable -> input variable whose attributes it inherits.
# Attributes are read from ds_in (the opened input file) rather than from a
# variable left over from the processing loop.
attr_sources = {
    'Temperature': 'thetao',
    'Salinity': 'so',
    'depth': 'depth',
    'latitude': 'latitude',
    'longitude': 'longitude',
}
for out_name, in_name in attr_sources.items():
    for attr in ['long_name','standard_name','units','unit_long']:
        ds_out[out_name].attrs[attr] = ds_in[in_name].attrs[attr]

# Global attributes describing the file contents and provenance
now = datetime.datetime.now().ctime()
attrs={'contents':'Estimated carbonate chemistry variables for GLORYS12V1 output',
       'history':'Created by Ivan Lima <ilima@whoi.edu> on {}'.format(now)}
ds_out.attrs.update(attrs)

## Save data into monthly files 

In [None]:
outdir = '/bali/data/ilima/GLORYS12V1/daily/BGC/3D'
# Make sure the destination exists so the write cannot fail on a missing
# directory after the whole year has been computed.
os.makedirs(outdir, exist_ok=True)

# Write one netCDF file per calendar month
for mon in range(1,13):
    outfile = os.path.join(outdir, 'GLORYS12V1_NW_Atlantic_{}-{:02d}_BGC.nc'.format(year, mon))
    ds_month = ds_out.sel(time=ds_out.time.dt.month.isin([mon])) # time steps in this month
    print('writing {}'.format(outfile))
    ds_month.to_netcdf(outfile, mode='w', unlimited_dims=['time'])

In [None]:
# ds_out.to_netcdf('test.nc', mode='w', unlimited_dims=['time'])#, encoding={'zlib': True, 'complevel': 9})

In [None]:
# t = 366/2
# fig, axs = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(15, 10))
# _ = ds_out.Temperature.isel(time=t, depth=0).plot(ax=axs[0,0])
# _ = ds_out.Salinity.isel(time=t, depth=0).plot(ax=axs[0,1], robust=True)
# _ = ds_out.DIC.isel(time=t, depth=0).plot(ax=axs[0,2], robust=True)
# _ = ds_out.TA.isel(time=t, depth=0).plot(ax=axs[1,0], robust=True)
# _ = ds_out.pH.isel(time=t, depth=0).plot(ax=axs[1,1], robust=True)
# _ = ds_out.saturation_aragonite.isel(time=t, depth=0).plot(ax=axs[1,2], robust=True)