#### Copyright (C) 2017 The University of Sydney, Australia
This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License, version 2, as published by the Free Software Foundation.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program; if not, write to Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

Author:  John Cannon

## Predict output using Chebyshev polynomials modelled on 2-dimensional input

The source code in this notebook denotes the output as 'z' and the inputs as 'x' and 'y'. For each reconstruction time it samples the paleo x and y input grids and generates an output z grid.

**Note** that this notebook does not include source code to model the output z data based on the input x and y data. It only uses the resultant modelled Chebyshev coefficients to generate z output from x and y inputs. In order to model z=f(x,y) you can use the GMT 'trend2d' command. For example:

    gmt trend2d -Fxym -N4r -V -W ...

...and use its verbose output to set the 'model_coefficients' variable (see the next notebook cell below), as well as the subsequent offset and scale variables.
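The sketch below shows one way to do that bookkeeping. It assumes the modelled x and y values are available as plain Python lists. `compute_offset_scale` and `pad_coefficients` are hypothetical helpers, not part of this notebook. They apply the offset/scale formulas documented in the next cell, including the log10 step, and zero-pad a shorter coefficient list to the 10 entries the model function expects.

```python
import math


def compute_offset_scale(values, log_scale=False):
    """Return (offset, scale) that map min(values)..max(values) onto [-1, 1].

    Hypothetical helper: if log_scale is True the values are log10-transformed
    first, matching how the notebook transforms its inputs before applying
    the offset and scale.
    """
    if log_scale:
        values = [math.log10(v) for v in values]
    lo, hi = min(values), max(values)
    if hi == lo:
        raise ValueError('Modelled data must span a non-zero range.')
    offset = 0.5 * (lo + hi)
    scale = 2.0 / (hi - lo)
    return offset, scale


def pad_coefficients(coefficients, num_terms=10):
    """Pad a shorter list of model coefficients with zeros up to num_terms entries."""
    if len(coefficients) > num_terms:
        raise ValueError('At most {0} coefficients are supported.'.format(num_terms))
    return list(coefficients) + [0.0] * (num_terms - len(coefficients))


# Example: y data (bottom water temperature) spanning 0.1 to 20, modelled in log10 space.
offset_y, scale_y = compute_offset_scale([0.1, 2.5, 20.0], log_scale=True)
print(offset_y, scale_y)
print(pad_coefficients([0.475231946155, 0.590711766992, 0.872236363623, 0.960545304715]))
```

With a y range of 0.1 to 20 this reproduces the notebook's `model_offset_y` and `model_scale_y`. That suggests the paper's temperature data were modelled over that range in log10 space.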

The current model coefficients (and offsets/scales) in the next cell below are those used to generate results for the paper:

    Müller, R.D. and Dutkiewicz, A., 2018,
    Oceanic crustal carbon cycle drives 26 million-year atmospheric carbon dioxide periodicities,
    Science Advances, 4:eaaq0500, 1-7.
    
https://www.earthbyte.org/oceanic-crustal-carbon-cycle-drives-26-million-year-atmospheric-carbon-dioxide-periodicities/

...and, as such, this notebook outputs CO2 (z) from two inputs, age (x) and bottom water temperature (y), although in general any quantities could be used for x, y and z. All age and bottom water temperature grids can be found here: https://www.earthbyte.org/webdav/ftp/Data_Collections/Muller_Dutkiewicz_2018_SciAdv

...to run this notebook without any modifications you will first need to copy 'agegrid_0.nc', 'agegrid_1.nc', 'bottom_water_temp_0.nc' and 'bottom_water_temp_1.nc' from the above link to the '../data/' directory.

**NOTE**: This notebook requires GMT to be installed.
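Before running the cells below, you may want to check that both requirements are met. The following is a minimal pre-flight sketch for Python 3. `missing_prerequisites` is a hypothetical helper. It assumes the default '../data/' directory and the default filename patterns used in the next cell, and it only checks that a `gmt` executable is on the PATH, not which GMT version it is.

```python
import os
import shutil


def missing_prerequisites(data_dir='../data', times=(0, 1)):
    """Return a list of problems that would stop the notebook from running.

    Assumes the default filename patterns used by this notebook
    ('agegrid_<time>.nc' and 'bottom_water_temp_<time>.nc').
    """
    problems = []
    # GMT is invoked as the 'gmt' executable, so it must be on the PATH.
    if shutil.which('gmt') is None:
        problems.append("'gmt' executable not found on PATH")
    for time in times:
        for base in ('agegrid', 'bottom_water_temp'):
            filename = os.path.join(data_dir, '{0}_{1}.nc'.format(base, time))
            if not os.path.isfile(filename):
                problems.append('missing input grid: {0}'.format(filename))
    return problems


for problem in missing_prerequisites():
    print(problem)
```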


In [None]:
# Base filename and extension of input x and y rasters to sample.
x_filename_base, x_filename_ext = '../data/agegrid', 'nc'
y_filename_base, y_filename_ext = '../data/bottom_water_temp', 'nc'

# Base filename output z raster to generate.
z_filename_base = '../data/CO2'

# Define the time range.
# Used to get paleo x, y and z raster filenames (from base filenames).
min_time = 0
max_time = 1

# Lat/lon spacing in z output grid.
grid_spacing = 0.5

# Clamp output z minimum and maximum (individually applicable).
zrange = 0.0, None

# Whether input x and y data were modelled using the logarithm (log10) of the input data.
log_scale_input_x, log_scale_input_y = True, True

# Whether to grid output data using GMT 'surface' (if False then use 'nearneighbor' instead).
grid_using_gmt_surface = False

#
# The following model coefficients and offsets/scales are those used to generate results for the paper:
#
#   Müller, R.D. and Dutkiewicz, A., 2018,
#   Oceanic crustal carbon cycle drives 26 million-year atmospheric carbon dioxide periodicities,
#   Science Advances, 4:eaaq0500, 1-7.
#   
#   https://www.earthbyte.org/oceanic-crustal-carbon-cycle-drives-26-million-year-atmospheric-carbon-dioxide-periodicities/
#
#
# However in general the following values should be replaced with the actual values obtained by modelling
# a subset of the output z data based on the input x and y data.
#
# This can be achieved using GMT 'trend2d' such as running 'gmt trend2d -Fxym -N4r -V ...' to
# get 4 model coefficients from the stderr console output line 'Model Coefficients: '.
# The data passed to GMT 'trend2d' should be an xyz file where each row maps an x and y value to
# a z value. Each row can also include a weight (see GMT 'trend2d'), in which case the '-W' option
# should also be passed to GMT 'trend2d'.
#

# The Chebyshev coefficients obtained by modelling x and y.
# Note: There should always be 10 coefficients. If not all coefficients are used then
# the remaining ones should be set to 0.
model_coefficients = [
    0.475231946155, 0.590711766992, 0.872236363623, 0.960545304715, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
# The offset and scale that map the input x and y data into the range [-1,1].
# This is what GMT 'trend2d' does to get the data into the domain of the Chebyshev polynomials.
# They are obtained by performing the following calculations on the input x and y data used in modelling
# (after taking log10 of x and/or y when 'log_scale_input_x' and/or 'log_scale_input_y' is True,
# since log10 is applied to the inputs before the offset and scale):
#
#    model_offset_x = 0.5 * (min_modelled_x + max_modelled_x)
#    model_offset_y = 0.5 * (min_modelled_y + max_modelled_y)
#    model_scale_x  = 2.0 / (max_modelled_x - min_modelled_x)
#    model_scale_y  = 2.0 / (max_modelled_y - min_modelled_y)
#
# ...where the min/max values are those of the data used when modelling the Chebyshev polynomials
# (eg, the data passed to GMT 'trend2d').
model_offset_x, model_offset_y = 0.121022119685, 0.15051499783
model_scale_x, model_scale_y = 0.471470802081, 0.869175979354
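Inverting the offset/scale formulas above recovers the range of x and y data that was modelled. Inside that range the Chebyshev fit applies; evaluating outside it extrapolates. The sketch below assumes both inputs were modelled in log10 space, as the flags above indicate. `modelled_range` is a hypothetical helper.

```python
import math


def modelled_range(offset, scale, log_scale):
    """Invert offset = 0.5*(lo + hi) and scale = 2/(hi - lo) to recover (lo, hi).

    If the data were modelled in log10 space, the bounds are converted back
    to the original (linear) units.
    """
    half_width = 1.0 / scale
    lo, hi = offset - half_width, offset + half_width
    if log_scale:
        lo, hi = math.pow(10.0, lo), math.pow(10.0, hi)
    return lo, hi


# Offsets and scales from the cell above.
print(modelled_range(0.121022119685, 0.471470802081, True))  # x (age)
print(modelled_range(0.15051499783, 0.869175979354, True))   # y (bottom water temperature)
```

For the paper's values this works out to approximately 0.01 to 175 for x and 0.1 to 20 for y. Grid cells whose age or temperature falls outside these bounds are extrapolated by the model.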

In [None]:
from __future__ import print_function
import math
import os
import sys
# Add directory containing the 'ptt' module (Plate Tectonic Tools) to the Python path.
sys.path.append('../../')
from ptt.utils.call_system_command import call_system_command


def predict_grid_chebyshev(
        input_points, # List of (lon, lat) tuples,
        x_grid_filename,
        y_grid_filename,
        model_offset_x,
        model_offset_y,
        model_scale_x,
        model_scale_y,
        log_scale_input_x,
        log_scale_input_y,
        model_coefficients,
        # Optional 2-tuple (min, max) to clamp output z values to.
        # Also either min or max can be None, eg, (0, None) to clamp only to min of zero...
        zrange = None):
    
    # Extract optional min/max to clamp z values.
    zmin = None
    zmax = None
    if zrange is not None:
        zmin, zmax = zrange
    
    # If we're doing logarithmic scaling then x and/or y values must be positive (log10 is undefined
    # for zero and negative values), so we clamp them to a small positive minimum.
    min_x = 1e-12 if log_scale_input_x else None
    min_y = 1e-12 if log_scale_input_y else None
    
    # Get the input x and y data at the input points.
    lon_lat_x_list = get_positions_and_scalars(input_points, x_grid_filename, min_x)
    lon_lat_y_list = get_positions_and_scalars(input_points, y_grid_filename, min_y)
    if not lon_lat_x_list or not lon_lat_y_list:
        # There are no input x values or no input y values so return empty list of z values.
        return []
    
    # Merge the x and y lists.
    # Only keep points where there are x *and* y values.
    lon_lat_x_y_list = []
    x_dict = dict(((lon, lat), x) for lon, lat, x in lon_lat_x_list)
    for lon, lat, y in lon_lat_y_list:
        if (lon, lat) in x_dict:
            x = x_dict[(lon, lat)]
            lon_lat_x_y_list.append((lon, lat, x, y))
    
    # For each input point predict z using GMT trend2d polynomial model z=f(x,y).
    lon_lat_z_list = []
    for lon, lat, x, y in lon_lat_x_y_list:
        
        z = chebyshev_polynomial_function(
                x, y,
                model_offset_x, model_offset_y,
                model_scale_x, model_scale_y,
                log_scale_input_x, log_scale_input_y,
                model_coefficients)
        
        # Optional min/max z clamping.
        if zmin is not None and z < zmin:
            z = zmin
        if zmax is not None and z > zmax:
            z = zmax
        
        lon_lat_z_list.append((lon, lat, z))
    
    return lon_lat_z_list
    
    
def chebyshev_polynomial_function(
        x, y,
        offset_x, offset_y,
        scale_x, scale_y,
        log_scale_input_x, log_scale_input_y,
        m):
    """
    Combines Chebyshev basis polynomials in x and y with transformed x and y data to get output z.
    """

    # Transform x and y (optional log10, then shift/scale) into the range [-1,1] used by GMT 'trend2d' when modelling.
    x, y = transform_2d(x, y, offset_x, offset_y, scale_x, scale_y, log_scale_input_x, log_scale_input_y)
    
    # The Chebyshev linear polynomial is T1(x) = x
    # The Chebyshev quadratic polynomial is T2(x) = 2*x*x - 1
    # The Chebyshev cubic polynomial is T3(x) = 4*x*x*x - 3*x
    #
    # Note: When fewer than 10 model terms are used (eg, 'trend2d -N4'), the last (10 - number_of_terms)
    # coefficients are zero and have no effect. We do this to avoid 10 'if' statements.
    model = (m[0] +
        m[1] * x +
        m[2] * y +
        m[3] * x*y +
        m[4] * (2*x*x - 1) +
        m[5] * (2*y*y - 1) +
        m[6] * (4*x*x*x - 3*x) +
        m[7] * (2*x*x - 1) * (y) +
        m[8] * (x) * (2*y*y - 1) +
        m[9] * (4*y*y*y - 3*y))
    
    return model


# Optionally take log10 of x and/or y, then shift/scale them into the range [-1,1] used when modelling.
def transform_2d(x, y, offset_x, offset_y, scale_x, scale_y, log_scale_input_x, log_scale_input_y):
    if log_scale_input_x:
        x = math.log10(x)
    if log_scale_input_y:
        y = math.log10(y)
        
    # We need to offset and scale the input data into the range [-1,1].
    # This is what GMT 'trend2d' does to get the data in the range of Chebyshev polynomials.
    x = (x - offset_x) * scale_x
    y = (y - offset_y) * scale_y
    
    return (x, y)
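Each hard-coded term in `chebyshev_polynomial_function` is a product T_a(x)·T_b(y) of Chebyshev polynomials. As a cross-check, the same 10-term sum can be built from the three-term recurrence, using the term ordering from that function. This is a sketch, not how GMT itself evaluates the model. The names `chebyshev_t`, `TERM_DEGREES` and `evaluate_chebyshev_2d` are hypothetical.

```python
import math


def chebyshev_t(n, u):
    """Chebyshev polynomial of the first kind T_n(u) via the recurrence
    T_0 = 1, T_1 = u, T_{k+1} = 2*u*T_k - T_{k-1}."""
    t_prev, t_curr = 1.0, u
    if n == 0:
        return t_prev
    for _ in range(n - 1):
        t_prev, t_curr = t_curr, 2.0 * u * t_curr - t_prev
    return t_curr


# (degree in x, degree in y) for each of the 10 terms, in the same order as
# 'chebyshev_polynomial_function': 1, x, y, xy, x^2, y^2, x^3, x^2 y, x y^2, y^3.
TERM_DEGREES = [(0, 0), (1, 0), (0, 1), (1, 1), (2, 0), (0, 2),
                (3, 0), (2, 1), (1, 2), (0, 3)]


def evaluate_chebyshev_2d(u, v, coefficients):
    """Sum m[i] * T_a(u) * T_b(v) over the 10 terms.

    u and v must already be transformed into [-1, 1]."""
    return sum(m * chebyshev_t(a, u) * chebyshev_t(b, v)
               for m, (a, b) in zip(coefficients, TERM_DEGREES))


# Sanity check against the trigonometric definition T_n(cos(theta)) = cos(n*theta).
theta = 0.7
for n in range(4):
    assert abs(chebyshev_t(n, math.cos(theta)) - math.cos(n * theta)) < 1e-12
```

The recurrence form extends naturally to higher degrees. The hard-coded version in the notebook avoids loops and is easier to compare term by term with the 'trend2d' model.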

In [None]:
def generate_input_points_grid(
        grid_spacing_degrees):
    """
    Generate a global grid of uniform points in latitude and longitude.
    """
    
    if grid_spacing_degrees <= 0:
        raise ValueError('Grid spacing must be non-zero and positive.')
    
    input_points = []
    
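    # Data points start *on* the dateline (-180) and the south pole (-90).
    # If 180 is an integer multiple of the grid spacing then the final latitude lands on the north pole (+90)
    # and the final longitude also lands on the dateline (+180), duplicating the -180 meridian.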
    num_latitudes = int(math.floor(180.0 / grid_spacing_degrees)) + 1
    num_longitudes = int(math.floor(360.0 / grid_spacing_degrees)) + 1
    for lat_index in range(num_latitudes):
        lat = -90 + lat_index * grid_spacing_degrees
        
        for lon_index in range(num_longitudes):
            lon = -180 + lon_index * grid_spacing_degrees
            
            input_points.append((lon, lat))
    
    return (input_points, num_longitudes, num_latitudes)


def get_positions_and_scalars(
        input_points,
        scalar_grid_filename,
        min_scalar=None):
    """
    Returns a list of scalars (one per (lon, lat) point in the 'input_points' list).
    For input points outside the scalar grid then scalars will be Nan (ie, 'math.isnan(scalar)' will return True).
    """
    
    input_points_data = ''.join('{0} {1}\n'.format(lon, lat) for lon, lat in input_points)

    # The command-line strings to execute GMT 'grdtrack'.
    grdtrack_command_line = ["gmt", "grdtrack", "-nl", "-G{0}".format(scalar_grid_filename)]
    stdout_data = call_system_command(grdtrack_command_line, stdin=input_points_data, return_stdout=True)
    
    lon_lat_scalar_list = []
    
    # Read lon, lat and scalar values from the output of 'grdtrack'.
    for line in stdout_data.splitlines():
        if line.strip().startswith(('#', '>')):
            continue
        
        line_data = line.split()
        num_values = len(line_data)
        
        # If just a line containing white-space then skip to next line.
        if num_values == 0:
            continue
        
        if num_values < 3:
            print('Ignoring line "{0}" - has fewer than 3 white-space separated numbers.'.format(line), file=sys.stderr)
            continue
            
        try:
            # Convert strings to numbers.
            lon = float(line_data[0])
            lat = float(line_data[1])
            
            # The scalar got appended to the last column by 'grdtrack'.
            scalar = float(line_data[-1])
            
            # 'grdtrack' returns 'NaN' for points outside the grid (or in masked regions of the grid).
            if math.isnan(scalar):
                #print('Ignoring line "{0}" - point is outside scalar grid.'.format(line), file=sys.stderr)
                continue
            
            # Clamp to min value if requested.
            if (min_scalar is not None and
                scalar < min_scalar):
                scalar = min_scalar
            
        except ValueError:
            print('Ignoring line "{0}" - cannot read floating-point lon, lat and scalar values.'.format(line), file=sys.stderr)
            continue
        
        lon_lat_scalar_list.append((lon, lat, scalar))
    
    return lon_lat_scalar_list
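The workload follows from the counts returned by `generate_input_points_grid`, which can be worked out by hand. The small sketch below repeats the same arithmetic. `grid_point_counts` is a hypothetical helper.

```python
import math


def grid_point_counts(grid_spacing_degrees):
    """Return (num_longitudes, num_latitudes, total_points) using the same
    formulas as 'generate_input_points_grid'."""
    num_latitudes = int(math.floor(180.0 / grid_spacing_degrees)) + 1
    num_longitudes = int(math.floor(360.0 / grid_spacing_degrees)) + 1
    return num_longitudes, num_latitudes, num_longitudes * num_latitudes


print(grid_point_counts(0.5))
```

With the default 0.5 degree spacing this gives 721 longitudes × 361 latitudes = 260,281 points. Each point is passed through 'grdtrack' twice, once for the x grid and once for the y grid. Because 360 is a multiple of 0.5, the first (-180) and last (+180) longitude columns fall on the same meridian, so those points are sampled twice.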

In [None]:
def write_predicted_data(
        z_base_filename,
        lon_lat_z_data,
        grid_spacing,
        grid_using_gmt_surface=False,
        # The following are only used when 'grid_using_gmt_surface' is True...
        x_grid_filename=None,
        y_grid_filename=None,
        # Optional 2-tuple (min, max) to clamp output z values to.
        # Also either min or max can be None, eg, (0, None) to clamp only to min of zero...
        zrange=None):
    
    z_xyz_filename = u'{0}.xy'.format(z_base_filename)
    write_xyz_file(z_xyz_filename, lon_lat_z_data)
    
    z_grd_filename = u'{0}.nc'.format(z_base_filename)
    write_grd_z_file(
            z_grd_filename,
            z_xyz_filename,
            grid_spacing,
            grid_using_gmt_surface, x_grid_filename, y_grid_filename, zrange)


def write_xyz_file(output_filename, output_data):
    with open(output_filename, 'w') as output_file:
        for output_line in output_data:
            output_file.write(' '.join(str(item) for item in output_line) + '\n')


def write_grd_z_file(
        z_grd_filename,
        z_xyz_filename,
        grid_spacing,
        grid_using_gmt_surface=False,
        # The following are only used when 'grid_using_gmt_surface' is True...
        x_grd_filename=None,
        y_grd_filename=None,
        zrange=None):
    
    if grid_using_gmt_surface:
        
        # The command-line strings to execute GMT 'surface'.
        gmt_surface_command_line = [
                "gmt",
                "surface",
                z_xyz_filename,
                "-R{0}/{1}/{2}/{3}".format(-180, 180, -90, 90),
                "-I{0}".format(grid_spacing),
                "-T0.5",
                "-fg",
                # Use GMT gridline registration since our input point grid has data points on the grid lines.
                # Gridline registration is the default so we don't need to force pixel registration...
                #"-r", # Force pixel registration since data points are at centre of cells.
                "-G{0}=nf".format(z_grd_filename)]
        
        # Extract optional min/max to clamp z values.
        zmin = None
        zmax = None
        if zrange is not None:
            zmin, zmax = zrange
        if zmin is not None:
            gmt_surface_command_line.append("-Ll{0}".format(zmin))
        if zmax is not None:
            gmt_surface_command_line.append("-Lu{0}".format(zmax))
        
        call_system_command(gmt_surface_command_line)
        
        # Mask z grid file against both x and y grid files since GMT surface produces
        # z values everywhere (including masked NaN regions).
        temp_filename = "tmp_gmt_surface_masking.nc"
        gmt_grdmath_command_line = [
                "gmt",
                "grdmath",
                "-R{0}/{1}/{2}/{3}".format(-180, 180, -90, 90),
                z_grd_filename,
                x_grd_filename,
                "OR",
                "=",
                "{0}=nf".format(temp_filename)]
        call_system_command(gmt_grdmath_command_line)
        
        gmt_grdmath_command_line = [
                "gmt",
                "grdmath",
                "-R{0}/{1}/{2}/{3}".format(-180, 180, -90, 90),
                temp_filename,
                y_grd_filename,
                "OR",
                "=",
                "{0}=nf".format(z_grd_filename)]
        call_system_command(gmt_grdmath_command_line)
        
        if os.access(temp_filename, os.R_OK):
            os.remove(temp_filename)
        
    else:
        # The command-line strings to execute GMT 'nearneighbor'.
        gmt_nearneighbor_command_line = [
                "gmt",
                "nearneighbor",
                z_xyz_filename,
                "-N4",
                "-S{0}d".format(1.5 * grid_spacing),
                "-I{0}".format(grid_spacing),
                "-R{0}/{1}/{2}/{3}".format(-180, 180, -90, 90),
                "-n+bg",
                "-fg",
                # Use GMT gridline registration since our input point grid has data points on the grid lines.
                # Gridline registration is the default so we don't need to force pixel registration...
                #"-r", # Force pixel registration since data points are at centre of cells.
                "-G{0}=nf".format(z_grd_filename)]
        
        call_system_command(gmt_nearneighbor_command_line)
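The two 'grdmath ... OR' calls in the 'surface' branch rely on the GMT 'OR' operator. Our reading of the 'grdmath' operator table is that it returns NaN wherever the second operand is NaN, and the first operand elsewhere. The pure-Python sketch below applies that masking to flattened grids. `grdmath_or` is a hypothetical stand-in, not a GMT API.

```python
import math

NAN = float('nan')


def grdmath_or(a_values, b_values):
    """Emulate 'A B OR' cell by cell: NaN where B is NaN, otherwise A."""
    return [NAN if math.isnan(b) else a for a, b in zip(a_values, b_values)]


# A 'surface' result has values everywhere; the x and y input grids have NaN holes.
z_surface = [1.0, 2.0, 3.0, 4.0]
x_grid = [10.0, NAN, 30.0, 40.0]
y_grid = [5.0, 6.0, 7.0, NAN]

# Mask against x first, then y (same order as the two grdmath calls).
z_masked = grdmath_or(grdmath_or(z_surface, x_grid), y_grid)
print(z_masked)
```

Only cells where both the x and y grids are defined keep a z value. This matches the 'nearneighbor' branch, where z is predicted only at points that have both an x and a y value.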

In [None]:
# Generate the input points.
input_points, _, _ = generate_input_points_grid(grid_spacing)

# Reconstruction times.
reconstruction_times = range(min_time, max_time + 1)

# Iterate over time steps and generate the output z grids (one for each time).
for time_index, reconstruction_time in enumerate(reconstruction_times):
    
    # Determine paleo raster filenames.
    x_input_grid_filename = '{0}_{1}.{2}'.format(x_filename_base, reconstruction_time, x_filename_ext)
    y_input_grid_filename = '{0}_{1}.{2}'.format(y_filename_base, reconstruction_time, y_filename_ext)
    z_output_base_filename = '{0}_{1}'.format(z_filename_base, reconstruction_time)
    
    lon_lat_z_list = predict_grid_chebyshev(
            input_points,
            x_input_grid_filename, y_input_grid_filename,
            model_offset_x, model_offset_y,
            model_scale_x, model_scale_y,
            log_scale_input_x, log_scale_input_y,
            model_coefficients,
            zrange)
    
    write_predicted_data(
            z_output_base_filename,
            lon_lat_z_list,
            grid_spacing,
            grid_using_gmt_surface,
            x_input_grid_filename,
            y_input_grid_filename,
            zrange)