In [None]:
import os
import shutil
import struct
import subprocess
from pathlib import Path

import pygmt
import imageio
import pygplates
import numpy as np
import xarray as xr
import pandas as pd
import cartopy.crs as ccrs
from scipy.spatial import cKDTree
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
from scipy.spatial.transform import Rotation as scpRot

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

This notebook builds a temporally refined paleo-elevation reconstruction from a coarser one.

As an example, we build a paleo-elevation model at 1 Ma intervals from one available at 5 Ma intervals.

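We first define a set of functions that perform the interpolation at the finer time step using a backward/forward approach: a backward run moves the plates back in time starting from the younger coarse grid, and a forward run starts from the older coarse grid and is progressively blended with the backward result.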

In [None]:
def getPaleoTopo(dem_folder, time):
    """Open the paleo-elevation grid for ``time`` (Ma) stored under ``dem_folder``.

    The folder is searched recursively for a file named ``<int(time)>Ma.nc``. When
    several files match, the first one in sorted path order is used, so the choice
    does not depend on filesystem listing order.

    Parameters
    ----------
    dem_folder : str or path-like
        Root folder containing the netCDF elevation grids.
    time : number
        Reconstruction age in Ma (truncated to an integer for the file name).

    Returns
    -------
    xarray.Dataset
        The grid, sorted by increasing ``latitude``.

    Raises
    ------
    FileNotFoundError
        If no matching file exists under ``dem_folder``.
    """
    pattern = "**/%dMa.nc" % int(time)
    matches = sorted(Path(dem_folder).glob(pattern))
    if not matches:
        raise FileNotFoundError(
            "No file matching '%s' under '%s'" % (pattern, dem_folder)
        )
    data = xr.open_dataset(matches[0])
    return data.sortby(data.latitude)

def getPaleoRain(rain_folder, time, glon, glat):
    """Open the paleo-rainfall grid for ``time`` (Ma) and interpolate it onto a grid.

    The folder is searched recursively for ``<int(time)>Ma.nc``. When several files
    match, the first one in sorted path order is used.

    Parameters
    ----------
    rain_folder : str or path-like
        Root folder containing the netCDF rainfall grids.
    time : number
        Reconstruction age in Ma (truncated to an integer for the file name).
    glon, glat : array-like
        Target longitudes / latitudes passed to ``Dataset.interp`` on the
        ``lon`` / ``lat`` coordinates.

    Returns
    -------
    xarray.Dataset
        The interpolated dataset, sorted by increasing ``lat``.

    Raises
    ------
    FileNotFoundError
        If no matching file exists under ``rain_folder``.
    """
    pattern = "**/%dMa.nc" % int(time)
    matches = sorted(Path(rain_folder).glob(pattern))
    if not matches:
        raise FileNotFoundError(
            "No file matching '%s' under '%s'" % (pattern, rain_folder)
        )
    data = xr.open_dataset(matches[0])
    datai = data.interp(lat=glat, lon=glon)
    # Sort by the interpolated dataset's own ``lat`` coordinate. Sorting by the
    # source ``data.lat`` would first align it onto the new ``glat`` values,
    # giving NaN sort keys wherever the two latitude grids differ.
    return datai.sortby("lat")

def getPlateIDs(plate_folder, time, lonlat):
    """Assign a plate ID to each query location from a gPlates velocity export.

    Reads ``<plate_folder>vel<int(time)>Ma.xy`` (whitespace separated, no header).
    Its first two columns are used as point coordinates and its last column as the
    plate ID. Each row of ``lonlat`` receives the ID of the closest exported point
    (plain Euclidean distance in those two coordinates).

    ``plate_folder`` is concatenated as-is, so it must end with a path separator.
    """
    velocity_file = "".join((plate_folder, "vel", str(int(time)), "Ma.xy"))
    table = pd.read_csv(
        velocity_file,
        sep=r"\s+",
        engine="c",
        header=None,
        na_filter=False,
        dtype=float,
        low_memory=False,
    ).drop_duplicates().reset_index(drop=True)

    positions = table.iloc[:, 0:2].to_numpy()
    plate_ids = table.iloc[:, -1].to_numpy().astype(int)

    _, nearest = cKDTree(positions).query(lonlat, k=1)
    return plate_ids[nearest]

def polarToCartesian(radius, theta, phi, useLonLat=True):
    """Convert spherical coordinates to Cartesian XYZ.

    When ``useLonLat`` is true, ``theta`` is a longitude and ``phi`` a latitude,
    both in degrees. They are turned into the azimuth ``radians(lon + 180)`` and the
    colatitude ``radians(90 - lat)``. Otherwise ``theta``/``phi`` are used directly as
    azimuth/colatitude in radians.

    Returns an (N, 3) array when the computed coordinates are numpy arrays, and a
    length-3 array for a single point.
    """
    if useLonLat == True:
        azimuth = np.radians(theta + 180.)
        colatitude = np.radians(90. - phi)
    else:
        azimuth, colatitude = theta, phi

    sin_colat = np.sin(colatitude)
    X = radius * np.cos(azimuth) * sin_colat
    Y = radius * np.sin(azimuth) * sin_colat
    Z = radius * np.cos(colatitude)

    # Scalars (including numpy scalars) come back as a single XYZ vector.
    if type(X) is not np.ndarray:
        return np.array([X, Y, Z])
    return np.stack((X, Y, Z), axis=1)

def cartesianToPolarCoords(XYZ, useLonLat=True):
    """Convert an (N, 3) array of Cartesian points to spherical coordinates.

    Returns ``(R, azimuth, colatitude)`` in radians when ``useLonLat`` is false.
    Otherwise returns ``(R, lon, lat)`` in degrees, where ``lon = azimuth - 180``
    (wrapped back into [-180, 180)) and ``lat = 90 - colatitude``.
    """
    x, y, z = XYZ[:, 0], XYZ[:, 1], XYZ[:, 2]
    radius = (x**2 + y**2 + z**2)**0.5
    azimuth = np.arctan2(y, x)
    colatitude = np.arccos(z / radius)

    if not useLonLat == True:
        return radius, azimuth, colatitude

    lon = np.degrees(azimuth) - 180
    lat = 90 - np.degrees(colatitude)
    # azimuth lies in [-180, 180] degrees, so lon lies in [-360, 0]; shift the
    # western overflow back by one turn.
    too_far_west = lon < -180
    lon[too_far_west] = lon[too_far_west] + 360
    return radius, lon, lat

def quaternion(axis, angle):
    """Quaternion ``[x, y, z, w]`` (scalar last) for a rotation of ``angle`` radians about ``axis``.

    Only the first three components of ``axis`` are used. The axis is not normalised
    here; ``scipy.spatial.transform.Rotation.from_quat`` normalises on input.
    """
    half_angle = angle / 2
    sin_half = np.sin(half_angle)
    vector_part = [sin_half * axis[i] for i in range(3)]
    return vector_part + [np.cos(half_angle)]

def getRotations(time, deltaTime, plateIds, rotationModel):
    """Build one scipy ``Rotation`` per plate for the stage ``time`` -> ``time - deltaTime``.

    Parameters
    ----------
    time : number
        Starting age in Ma (truncated to int).
    deltaTime : number
        Step in Ma; the stage rotation goes to ``int(time - deltaTime)``.
    plateIds : array-like of int
        Plate ID of every point; one rotation is built per unique ID.
    rotationModel : pygplates.RotationModel
        Model queried with ``get_rotation(to_time, plate_id, from_time)``.

    Returns
    -------
    dict
        Maps each unique plate ID to a ``scipy.spatial.transform.Rotation`` expressed
        in the Cartesian frame used by ``polarToCartesian``.
    """
    rotations = {}
    toTime, fromTime = int(time - deltaTime), int(time)
    for plateId in np.unique(plateIds):
        stageRotation = rotationModel.get_rotation(toTime, int(plateId), fromTime)
        if stageRotation.represents_identity_rotation():
            # An identity rotation (e.g. a plate that does not move over this stage)
            # has no defined Euler pole; extracting one may raise in pygplates.
            rotations[plateId] = scpRot.identity()
            continue
        pole, angle = stageRotation.get_euler_pole_and_angle()
        poleLat, poleLon = pole.to_lat_lon()
        axis = polarToCartesian(1, poleLon, poleLat)
        rotations[plateId] = scpRot.from_quat(quaternion(axis, angle))
    return rotations

def movePlates(sphereXYZ, plateIds, rotations):
    """Rotate every point by the rotation of the plate it belongs to.

    ``sphereXYZ`` is an (N, 3) array, ``plateIds`` gives the plate of each row and
    ``rotations`` maps plate ID -> object with an ``apply`` method (scipy Rotation).
    The input array is left untouched; a rotated copy is returned.
    """
    moved = np.copy(sphereXYZ)
    for plate in np.unique(plateIds):
        on_plate = plateIds == plate
        moved[on_plate] = rotations[plate].apply(moved[on_plate])
    return moved

def interpData(data, xyz, mvxyz, ngbh=1):
    """Resample values attached to points ``mvxyz`` onto the points ``xyz``.

    With ``ngbh == 1`` each target takes the value of its nearest source point.
    Otherwise an inverse-distance (1/d) weighted mean of the ``ngbh`` nearest sources
    is used, and a target that coincides exactly with a source takes that source's value.
    """
    dist, nbrs = cKDTree(mvxyz).query(xyz, k=ngbh)
    if ngbh == 1:
        return data[nbrs]

    # 1/d weights; zero distances get weight 0 and are handled separately below.
    inv_dist = np.divide(1.0, dist, out=np.zeros_like(dist), where=dist != 0)
    weight_sum = inv_dist.sum(axis=1)
    weighted_sum = (inv_dist * data[nbrs]).sum(axis=1)
    interpZ = np.divide(
        weighted_sum, weight_sum, out=np.zeros_like(weight_sum), where=weight_sum != 0
    )

    exact_hits = np.flatnonzero(dist[:, 0] == 0)
    if exact_hits.size:
        interpZ[exact_hits] = data[nbrs[exact_hits, 0]]
    return interpZ

def runSubProcess(args, output=True, cwd="."):
    """Run an external command, optionally echoing its combined stdout/stderr live.

    Parameters
    ----------
    args : list of str
        Command and arguments passed to ``subprocess.Popen`` (no shell).
    output : bool
        If True, print each output line as it arrives.
    cwd : str
        Working directory of the command.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero status. The message is the text
        following ``ERROR: `` (or else ``what(): ``) in the output when present,
        otherwise a generic message naming the command and its exit code.
    """
    lines = []
    # The context manager closes the pipe and waits for the process to exit.
    with subprocess.Popen(
        args,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    ) as p:
        for line in p.stdout:
            lines.append(line)
            if output:
                print(line, end="")

    returncode = p.returncode
    if returncode == 0:
        return

    log = "".join(lines)
    error_msg = ""
    for marker in ("ERROR: ", "what(): "):
        if marker in log:
            error_msg = log.partition(marker)[2]
            break
    error_msg = " ".join(error_msg.split())
    if not error_msg:
        error_msg = "{} exited with code {}.".format(
            " ".join(str(a) for a in args), returncode
        )
    raise RuntimeError(error_msg)

def _dbscanClusterMask(time, mvxyz, clustdist, output, nprocs, cwd):
    """Run the external MPI ``dbscan`` on ``mvxyz`` and flag the points it clustered.

    Returns a boolean array of length ``len(mvxyz)``. An entry is True where the
    point is the nearest vertex to a position that dbscan reported with
    ``cluster_id > 0``. The temporary input/output files are removed even if
    dbscan fails.
    """
    # dbscan runs inside ``cwd``: use absolute paths so the Python side (write input,
    # read output, clean up) and the external process agree on file locations.
    fbin = os.path.abspath(os.path.join(cwd, "nodes" + str(time) + ".bin"))
    fnc = os.path.abspath(os.path.join(cwd, "clusters" + str(time) + ".nc"))
    try:
        # Input layout: two native ints (npoints, 3) followed by the coordinates as
        # native-endian float32, row by row.
        with open(fbin, mode="wb") as f:
            f.write(struct.pack("ii", len(mvxyz), 3))
            f.write(np.ascontiguousarray(mvxyz, dtype=np.float32).tobytes())
        mpi_args = [
            "mpirun",
            "-np",
            str(nprocs),
            "dbscan",
            "-i",
            fbin,
            "-b",
            "-m",
            "2",
            "-e",
            str(clustdist),
            "-o",
            fnc,
        ]
        runSubProcess(mpi_args, output, cwd)
        if output:
            print("\nGet global ID of clustered vertices")
        with xr.open_dataset(fnc) as cluster:
            isClust = cluster.cluster_id.values > 0
            clustPts = np.column_stack((
                cluster.position_col_X0.values[isClust],
                cluster.position_col_X1.values[isClust],
                cluster.position_col_X2.values[isClust],
            ))
    finally:
        for fname in (fbin, fnc):
            if os.path.exists(fname):
                os.remove(fname)

    mask = np.zeros(len(mvxyz), dtype=bool)
    if len(clustPts) > 0:
        _, ids = cKDTree(mvxyz).query(clustPts, k=1)
        mask[ids] = True
    return mask


def clusterZ(time, mvxyz, elev, clustngbh=6, clustdist=10.e3, output=False, nprocs=4, cwd="."):
    """Reassign the elevation of points that dbscan groups into clusters.

    The external MPI ``dbscan`` is run on ``mvxyz`` with ``-m 2 -e clustdist``
    (presumably a minimum of 2 points within a radius of ``clustdist``; check
    against the dbscan CLI). Each clustered point then gets the mean of the
    ``int(clustngbh / 2)`` highest elevations among its ``clustngbh - 1`` nearest
    clustered neighbours (the point itself excluded).

    Parameters
    ----------
    time : number
        Used only to name the temporary dbscan files.
    mvxyz : numpy.ndarray
        (N, 3) point coordinates.
    elev : numpy.ndarray
        Elevation of each point (length N).
    clustngbh : int
        Neighbourhood size queried in the cluster tree, including the point itself.
    clustdist : float
        Value passed to dbscan's ``-e`` option.
    output : bool
        Echo progress and dbscan output.
    nprocs : int
        Number of MPI processes.
    cwd : str
        Working directory for dbscan and its temporary files.

    Returns
    -------
    numpy.ndarray
        Copy of ``elev`` with the clustered entries replaced.
    """
    if output:
        print("\ndbscan MPI")
    idCluster = _dbscanClusterMask(time, mvxyz, clustdist, output, nprocs, cwd)

    clustZ = elev.copy()
    nClust = int(np.count_nonzero(idCluster))
    # The nearest hit of every point is itself, so at least two clustered points are
    # needed; never ask the tree for more neighbours than it holds.
    k = min(clustngbh, nClust)
    if k < 2:
        return clustZ

    ptsCluster = mvxyz[idCluster]
    _, clustNgbhs = cKDTree(ptsCluster).query(ptsCluster, k=k)
    clustNgbhs = clustNgbhs[:, 1:]
    neighbourHeights = elev[idCluster][clustNgbhs]

    # Mean of the highest int(clustngbh/2) neighbour heights (capped by the number
    # of neighbours actually available).
    nTop = max(1, min(int(clustngbh / 2), k - 1))
    clustZ[idCluster] = np.sort(neighbourHeights, axis=1)[:, -nTop:].mean(axis=1)
    return clustZ

def runSubProcess(args, output=True, cwd="."):
    """Run an external command, optionally echoing its combined stdout/stderr live.

    Note: this cell defines ``runSubProcess`` twice. This later definition is the
    one bound to the name when the cell finishes executing.

    Parameters
    ----------
    args : list of str
        Command and arguments passed to ``subprocess.Popen`` (no shell).
    output : bool
        If True, print each output line as it arrives.
    cwd : str
        Working directory of the command.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero status. The message is the text
        following ``ERROR: `` (or else ``what(): ``) in the output when present,
        otherwise a generic message naming the command and its exit code.
    """
    lines = []
    # The context manager closes the pipe and waits for the process to exit.
    with subprocess.Popen(
        args,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    ) as p:
        for line in p.stdout:
            lines.append(line)
            if output:
                print(line, end="")

    returncode = p.returncode
    if returncode == 0:
        return

    log = "".join(lines)
    error_msg = ""
    for marker in ("ERROR: ", "what(): "):
        if marker in log:
            error_msg = log.partition(marker)[2]
            break
    error_msg = " ".join(error_msg.split())
    if not error_msg:
        error_msg = "{} exited with code {}.".format(
            " ".join(str(a) for a in args), returncode
        )
    raise RuntimeError(error_msg)

We now define the main script that will perform either the forward or backward interpolation of the paleo-elevation. 

In [None]:
def _withZ(template, values, dims):
    """Return a copy of ``template`` whose ``z`` variable is replaced by ``values`` on ``dims``."""
    updated = template.copy().drop_vars(names=['z'])
    updated['z'] = (dims, values)
    return updated


def runScript(time, dt, forward=False, rain=False,
              rotation_fname='PALEOMAP_PlateModel.rot', dem_folder='ndem/',
              rain_folder='rain/', plate_folder='vel1Ma/', radius=6371*1000.,
              sigma=0.75):
    """Interpolate paleo-elevation (and optionally rainfall) at fine time steps.

    Starting from the grid at ``time[0]``, step ``s`` rotates every grid node with its
    plate from ``time[s]`` to ``time[s] - dt``. Stage rotations come from
    ``rotation_fname`` and plate IDs from the ``plate_folder`` velocity exports. The
    step then resamples the moved elevations back onto the fixed lon/lat grid and
    writes ``<out>/ndem<time[s]-dt>Ma.nc``, where ``<out>`` is ``forward`` or
    ``backward``. The starting grid is also written, after Gaussian smoothing.

    In forward mode each step is additionally blended towards the backward result for
    the same age (read from ``backward/``) with weight ``(s + 1) / len(time)``, then
    smoothed with a Gaussian filter of width ``sigma``. The backward run for that
    interval must therefore have been run first.

    Parameters
    ----------
    time : sequence of int
        Ages in Ma; ``len(time) - 1`` steps are computed.
    dt : int
        Step in Ma (negative for the backward run).
    forward : bool
        Forward mode (blend with the backward results) instead of backward mode.
    rain : bool
        Also move and resample the rainfall grids from ``rain_folder``.
    rotation_fname, dem_folder, rain_folder, plate_folder : str
        Input rotation file and data folders.
    radius : float
        Sphere radius (m) added to the elevations before the 3D rotation.
    sigma : float
        Width of the Gaussian smoothing filter.
    """
    rotationModel = pygplates.RotationModel(rotation_fname)

    paleoZ = getPaleoTopo(dem_folder, time[0])
    glon = paleoZ.longitude.values
    glat = paleoZ.latitude.values
    shape = paleoZ.z.shape
    lons, lats = np.meshgrid(glon, glat)

    lonlat = np.empty((len(lons.ravel()), 2))
    lonlat[:, 0] = lons.ravel()
    lonlat[:, 1] = lats.ravel()

    out_path = 'forward' if forward else 'backward'
    # Forward runs blend towards the reconstruction written by the backward run.
    back_path = 'backward'
    os.makedirs(out_path, exist_ok=True)

    if rain:
        paleoR = getPaleoRain(rain_folder, time[0], glon, glat)
        rdata = [paleoR]
        _withZ(paleoR, paleoR.z.values, ['lat', 'lon']).to_netcdf(
            out_path + '/nrain' + str(time[0]) + 'Ma.nc')

    # The unsmoothed grid is the one that gets moved; the smoothed one is only saved.
    ndata = [paleoZ]
    smoothZ = gaussian_filter(paleoZ.z.values, sigma=sigma)
    _withZ(paleoZ, smoothZ, ['latitude', 'longitude']).to_netcdf(
        out_path + '/ndem' + str(time[0]) + 'Ma.nc')

    print(' + start loop', time[0])
    nsteps = len(time)
    for s in range(nsteps - 1):
        plateIds = getPlateIDs(plate_folder, time[s], lonlat)
        heights = ndata[s].z.values.ravel()
        sphericalZ = heights + radius
        sXYZ = polarToCartesian(sphericalZ, lonlat[:, 0], lonlat[:, 1])
        rotations = getRotations(time[s], dt, plateIds, rotationModel)
        movXYZ = movePlates(sXYZ, plateIds, rotations)

        clusteredZ = clusterZ(time[s], movXYZ, heights)
        newZ = interpData(clusteredZ, sXYZ, movXYZ).reshape(shape)

        if forward:
            # Linear blend towards the backward solution: the weight grows with the
            # distance from the forward starting age (0.2 ... 0.8 for 5-step intervals).
            weight = (s + 1) / nsteps
            backfile = back_path + '/ndem' + str(time[s] - dt) + 'Ma.nc'
            with xr.open_dataset(backfile) as bdata:
                backZ = bdata.z.values
            newZ = gaussian_filter(newZ + (backZ - newZ) * weight, sigma=sigma)

        data = _withZ(ndata[s], newZ, ['latitude', 'longitude'])
        data.to_netcdf(out_path + '/ndem' + str(time[s] - dt) + 'Ma.nc')
        ndata.append(data)

        if rain:
            newR = interpData(rdata[s].z.values.ravel(), sXYZ, movXYZ).reshape(shape)
            rainData = _withZ(rdata[s], newR, ['lat', 'lon'])
            rainData.to_netcdf(out_path + '/nrain' + str(time[s] - dt) + 'Ma.nc')
            rdata.append(rainData)

        print('    -  done time: ', time[s] - dt)

Define the coarse time steps, every 5 Ma from 100 Ma to 0 Ma:

In [None]:
# Coarse reconstruction ages in Ma: every 5 Ma from 100 Ma down to 0 Ma (inclusive).
ntime = np.arange(100, -1, -5)
ntime

Perform the calculations; this creates a new series of netCDF files at 1 Ma intervals:

In [None]:
for k in range(len(ntime) - 1):
    older, younger = ntime[k], ntime[k + 1]
    # Backward first (from the younger grid towards older ages), then forward
    # (from the older grid), which blends with the backward files just written.
    for label, dt, start, stop, isForward in (
        ('+ Backward run', -1, younger, older, False),
        ('+ Forward run', 1, older, younger, True),
    ):
        print(label)
        time = np.arange(start, stop, -dt)
        runScript(time, dt, forward=isForward, rain=False)

Copy the present-day (0 Ma) paleo-elevation, which only the backward run writes, into the `forward` folder:

In [None]:
# The 0 Ma grid is the starting point of the last backward run, so only that run
# writes it; copy it so the forward folder holds the complete series.
# shutil.copy is portable, unlike spawning the external `cp` command.
shutil.copy("backward/ndem0Ma.nc", "forward/")

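If you ran the model with the rainfall maps turned on (`rain=True`), gather the interpolated rainfall files into a single folder. Note that moving rainfall with the plates is probably not the best approach for climate: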

In [None]:
out_path1 = 'forward'
out_path2 = 'backward'
outf = 'comb_rain'
# Within each coarse interval, the first `nforward` yearly maps (closest to the older
# coarse age) come from the forward run; the remaining ones come from the backward run.
nforward = 3

lst = []
for k in range(len(ntime) - 1):
    # `age` is distinct from the interval index `k` (the original reused `k` here).
    for p, age in enumerate(range(ntime[k], ntime[k + 1], -1)):
        folder = out_path1 if p < nforward else out_path2
        lst.append(folder + '/nrain' + str(age) + 'Ma.nc')
lst.append(out_path2 + '/nrain0Ma.nc')

shutil.copy("backward/nrain0Ma.nc", "forward/")

os.makedirs(outf, exist_ok=True)
for filename in lst:
    shutil.copy(filename, outf)

Define the functions used to plot the elevation through time:

In [None]:
def _boundaryLines(shared_boundary_section, wrapper):
    """Return the dateline-wrapped sub-segments of one shared boundary as [lons, lats] pairs."""
    lines = []
    for shared_sub_segment in shared_boundary_section.get_shared_sub_segments():
        for geometry in wrapper.wrap(shared_sub_segment.get_geometry()):
            points = list(geometry.get_points())
            lines.append([[point.get_longitude() for point in points],
                          [point.get_latitude() for point in points]])
    return lines


def getBounds(time, topology_features, rotation_model):
    """Resolve plate topologies at ``time`` and split the shared boundaries by type.

    Parameters
    ----------
    time : int
        Reconstruction age in Ma.
    topology_features : pygplates.FeatureCollection
        Topological plate-boundary features.
    rotation_model : pygplates.RotationModel
        Rotation model used to resolve the topologies.

    Returns
    -------
    tuple of three lists
        ``(subductions, oceanRidges, otherBounds)``: lines from ``SubductionZone``
        features, from ``MidOceanRidge`` features, and from every other feature type.
        Each line is a ``[lons, lats]`` pair wrapped with a dateline wrapper centred
        on longitude 0.
    """
    resolved_topologies = []
    shared_boundary_sections = []
    pygplates.resolve_topologies(topology_features, rotation_model, resolved_topologies, time, shared_boundary_sections)
    wrapper = pygplates.DateLineWrapper(0.)

    ridgeType = pygplates.FeatureType.create_gpml('MidOceanRidge')
    subductionType = pygplates.FeatureType.create_gpml('SubductionZone')

    subductions = []
    oceanRidges = []
    otherBounds = []
    for shared_boundary_section in shared_boundary_sections:
        featureType = shared_boundary_section.get_feature().get_feature_type()
        # The original appended ridges to `subductions` and subduction zones to
        # `oceanRidges`; each type now goes to the list that carries its name.
        if featureType == ridgeType:
            target = oceanRidges
        elif featureType == subductionType:
            target = subductions
        else:
            target = otherBounds
        target.extend(_boundaryLines(shared_boundary_section, wrapper))
    return subductions, oceanRidges, otherBounds

def plotElev(time, data, subductions, oceanRidges, otherBounds, out_path):
    """Draw a global elevation map with plate boundaries and save it as a PNG.

    Parameters
    ----------
    time : int
        Reconstruction age in Ma, used for the map label and the file name.
    data : xarray.Dataset
        Grid whose ``z`` variable is drawn with the ``geo`` palette over [-6000, 6000].
    subductions, oceanRidges, otherBounds : list
        Lines given as ``[lons, lats]`` pairs, drawn in red, white and purple.
    out_path : str
        Folder receiving ``elev<time>Ma.png``, written at 500 dpi.
    """
    boundaryStyles = (
        (subductions, "1p,red"),
        (oceanRidges, "1p,white"),
        (otherBounds, "1p,purple"),
    )

    fig = pygmt.Figure()
    with pygmt.config(FONT='6p,Helvetica,black'):
        pygmt.makecpt(cmap="geo", series=[-6000, 6000])
        fig.basemap(region='d', projection='W6i', frame='afg')
        fig.grdimage(data.z, shading=True, frame=False)
        fig.colorbar(
            position="jBC+o0c/-1.5c+w8c/0.3c+h",
            frame=["a2000", "x+lElevation", "y+lm"],
        )
        for lines, pen in boundaryStyles:
            for line in lines:
                fig.plot(x=line[0], y=line[1], pen=pen, transparency="0")

    # Age label, drawn after (outside) the small-font configuration block.
    fig.text(text=str(time) + " Ma", position="TL", font="8p,Helvetica-Bold,black")
    fig.savefig(fname=out_path + '/elev' + str(time) + 'Ma.png', dpi=500)

Plot the paleo-elevation at every 1 Ma step and save each map as a `png` file:

In [None]:
stime = np.arange(100, -1, -1)

rotation_fname = 'PALEOMAP_PlateModel.rot'
polygon_fname = 'PlateBoundaries.gpml'
out_elev = 'forward'

rotationModel = pygplates.RotationModel(rotation_fname)
topoFeature = pygplates.FeatureCollection(polygon_fname)

for age in stime:
    age = int(age)
    subductions, oceanRidges, otherBounds = getBounds(age, topoFeature, rotationModel)
    elevfile = out_elev + '/ndem' + str(age) + 'Ma.nc'
    # Close each grid once it is plotted instead of leaving one open dataset per step.
    with xr.open_dataset(elevfile) as data:
        plotElev(age, data, subductions, oceanRidges, otherBounds, out_elev)

Save as a movie:

In [None]:
stime = np.arange(100, -1, -1)
out_elev = 'forward'
# One frame per Ma, from 100 Ma to the present, each shown for one second (fps=1).
images = [
    imageio.imread(out_elev + '/elev' + str(int(age)) + 'Ma.png')
    for age in stime
]
imageio.mimsave('elev1Ma.mp4', images, fps=1)