In [48]:
import pygmt
import xarray as xr
import pandas as pd
import numpy as np
import netCDF4 as nc
import os
import shutil
import dask
from astropy.io import ascii as asx
from astropy.table import Table
import astropy
import subprocess
import json
from io import StringIO
import math

import time

# Depths (km) of the model slices, ordered shallow to deep. The functions below pair
# the i-th entry with the i-th file of a sorted directory listing, so the list order
# must match the sorted order of the grid file names.
list_of_depths = [1, 47, 109, 202, 293, 396, 487, 597, 660, 726, 798, 874, 954, 1040, 1130, 1224, 1324, 1428, 1536, 1639, 1737, 1830, 1916, 1998, 2074, 2210, 2324, 2416, 2514, 2595, 2704, 2813, 2867]

#extended_list_of_depths = [2867, 2840, 2813, 2785, 2758, 2731, 2704, 2677, 2650, 2623, 2595, 2568, 2541, 2514, 2487, 2454, 2416, 2373, 2324, 2270, 2210, 2145, 2074, 1998, 1916, 1830, 1737, 1639, 1536, 1428, 1324, 1224, 1130, 1040, 954, 874, 798, 726, 660, 597, 540, 487, 439, 396, 357, 323, 293, 268, 248, 233, 217, 202, 186, 171, 155, 140, 124, 109, 93, 78, 62, 47, 31, 16, 0]


In [1]:
## Loop through all grids and re-sample them at "sample_interval" (string) grid spacing 

def resample(grids_loc, sample_interval, out_dir):
    """Resample every grid in ``grids_loc`` onto a global region with a new spacing.

    Parameters
    ----------
    grids_loc : str
        Directory whose entries are all passed to ``pygmt.grdsample`` in sorted name order.
    sample_interval : str
        Grid spacing handed to ``pygmt.grdsample`` as ``spacing``.
    out_dir : str
        Directory that receives the resampled grids; each output keeps its input file name.
    """
    for name in sorted(os.listdir(grids_loc)):
        infile = os.path.join(grids_loc, name)

        # Build the output path with os.path so it does not rely on '/' being the separator.
        outfile = os.path.join(out_dir, os.path.basename(infile))

        pygmt.grdsample(grid=infile, outgrid=outfile, spacing=sample_interval, region='d', verbose='e')
        
    

In [3]:
# pad zeros to the front of file name depths to make all file name depths(km) 4 digits (e.g. 47 km becomes 0047 km)
# ensures files are sorted correctly when read in for other loops over files

def mantle_flow_sort(list_of_depths, source_dir, out_dir):
    """Copy the grids for the requested depths (plus the 0 km grid) into ``out_dir``.

    A file is selected when its name contains a zero-padded depth tag such as
    ``0047km``. The ``0000km`` tag is always accepted, even when ``list_of_depths``
    is empty.

    Parameters
    ----------
    list_of_depths : iterable of int
        Depths whose grids should be copied.
    source_dir : str
        Directory searched for grid files.
    out_dir : str
        Destination passed to ``shutil.copy``.
    """
    surface_tag = '0'.zfill(4) + 'km'

    # Pad to four digits so the tags match the file naming scheme (47 -> '0047km').
    wanted_tags = [str(depth).zfill(4) + 'km' for depth in list_of_depths]
    wanted_tags.append(surface_tag)

    for filename in sorted(os.listdir(source_dir)):
        # Copy each matching file exactly once, whichever tag it matched.
        if any(tag in filename for tag in wanted_tags):
            shutil.copy(os.path.join(source_dir, filename), out_dir)

# Cross-sections

## Functions

In [4]:
## PyGMT grdtrack function
## LOOPS OVER NETCDF FILES AND SAMPLES THEM AT POINTS DEFINED IN A SHAPEFILE --> OUTPUTS A SET OF ASCII FILES

## First function runs in PyGMT, second runs in GMT as a subprocess

## transect = shapefile
## grids_loc = directory for input netCDF files
## out_loc = directory to write output ASCII files to
## list_of_depths = list of each depth being sampled
## wave_type = string indicating which wave type is being sampled --> _pw, _sw or _b

def make_transects(transect, grids_loc, out_loc, list_of_depths, wave_type):
    """Sample every grid in ``grids_loc`` along ``transect`` with ``pygmt.grdtrack``.

    The i-th grid (sorted by file name) is paired with ``list_of_depths[i]``; the
    result is written to ``<out_loc>/<depth padded to 4 digits><wave_type>.txt``.

    Parameters
    ----------
    transect : str
        Points to sample at (described above as a shapefile), passed as ``points``.
    grids_loc : str
        Directory of input grids.
    out_loc : str
        Directory for the output ASCII files.
    list_of_depths : sequence of int
        Depth for each grid, in the same order as the sorted grid files.
    wave_type : str
        Suffix inserted in the output name, e.g. ``_pw``, ``_sw`` or ``_b``.

    Raises
    ------
    IndexError
        If there are more grids than entries in ``list_of_depths``.
    """
    out_ext = ".txt"

    grds = [os.path.join(grids_loc, fn) for fn in sorted(os.listdir(grids_loc))]

    for count, grd in enumerate(grds):
        # zfill gives the same 4-digit padding used by mantle_flow_sort.
        ofn_string = str(list_of_depths[count]).zfill(4) + wave_type + out_ext

        outfile_name = os.path.join(out_loc, ofn_string)

        # Keywords keep the call correct regardless of the positional order of
        # grid/points, which differs between PyGMT releases.
        pygmt.grdtrack(points=transect, grid=grd, outfile=outfile_name)
        

def grdtrack_loop(pnts, grids_loc, out_loc, list_of_depths, wave_type):
    """Sample every grid in ``grids_loc`` at ``pnts`` by running ``gmt grdtrack``.

    The i-th grid (sorted by file name) is paired with ``list_of_depths[i]``; the
    command's standard output is written to
    ``<out_loc>/<depth padded to 4 digits><wave_type>.txt``.

    Parameters
    ----------
    pnts : str
        Points file handed to ``gmt grdtrack``.
    grids_loc : str
        Directory of input grids.
    out_loc : str
        Directory for the output ASCII files.
    list_of_depths : sequence of int
        Depth for each grid, in the same order as the sorted grid files.
    wave_type : str
        Suffix inserted in the output file name.

    Raises
    ------
    IndexError
        If there are more grids than entries in ``list_of_depths``.
    """
    grds = [os.path.join(grids_loc, fn) for fn in sorted(os.listdir(grids_loc))]

    for count, grd in enumerate(grds):
        ofn_string = str(list_of_depths[count]).zfill(4) + wave_type + '.txt'

        outfile = os.path.join(out_loc, ofn_string)

        gridfile = '-G' + grd

        # subprocess.run does not go through a shell, so a '>' list element would be
        # passed to gmt as a literal argument; redirect stdout to the file instead.
        with open(outfile, 'w') as out_handle:
            subprocess.run(['gmt', 'grdtrack', pnts, gridfile], stdout=out_handle)

        
####################################################################
####################################################################
####################################################################
####################################################################


## COMBINES GRDTRACK ASCII FILES INTO ONE FILE
## For cross-section production - After looping over all depths using grdtrack to sample along a transect, combine all transects into one ascii file
## Finds the lowest variation of lat/lon and replaces that column with non-dimensional depth --> e.g. if sampling along a latitude line, lat is removed


## in_dir = location of ASCII files from grdtracks function
## out_dir = location to write the output concatenated ASCII file --> single .txt file
## out_str = string giving the desired name of the output file in the form --> {name}.txt
## lonlat = integer 0 or 1 to specify whether we replace lat or lon --> set to 2 if you want it to set based on variation in each column (remove the column with the least variation)
## Generally leave lonlat set to 2

def concat_ascii(in_dir, out_dir, out_str, list_of_depths, lonlat):
    """Stack the per-depth ASCII transects in ``in_dir`` into one tab-separated table.

    In every input table one coordinate column (index ``lonlat``) is overwritten
    with the depth of that file. After stacking, that column is renamed ``depth``
    and rescaled to ``(6371 - depth) / 6371``.

    Parameters
    ----------
    in_dir : str
        Directory of ASCII files; the i-th file (sorted) is paired with ``list_of_depths[i]``.
    out_dir : str
        Directory for the combined file.
    out_str : str
        Name of the combined file.
    list_of_depths : sequence
        Depth value for each input file.
    lonlat : int
        0 to replace the first column, 1 to replace the second, 2 to pick automatically
        from the first file (the column with the smaller max-min range is replaced).
    """
    ascii_files = [os.path.join(in_dir, fn) for fn in sorted(os.listdir(in_dir))]

    if lonlat == 2:
        # Decide from the first transect which coordinate barely changes.
        first_table = asx.read(ascii_files[0])
        lon_span = first_table['col1'].max() - first_table['col1'].min()
        lat_span = first_table['col2'].max() - first_table['col2'].min()
        lonlat = 1 if abs(lon_span) > abs(lat_span) else 0

    table_list = []
    for idx, txt in enumerate(ascii_files):
        one_table = asx.read(txt)
        depth_column = [list_of_depths[idx]] * len(one_table)
        one_table[one_table.colnames[lonlat]] = depth_column
        table_list.append(one_table)

    comb_table = astropy.table.vstack(table_list)
    comb_table.rename_columns(('col1', 'col2', 'col3'), ('lon', 'lat', 'vote'))
    comb_table.rename_column(comb_table.colnames[lonlat], 'depth')

    if lonlat == 0:
        # Put the horizontal coordinate first so both cases share the same column layout.
        comb_table = comb_table['lat', 'depth', 'vote']

    # Non-dimensionalise: presumably 6371 is Earth's radius in km (confirm).
    comb_table['depth'] = (6371. - comb_table['depth'])/6371.

    asx.write(comb_table, output=os.path.join(out_dir, out_str), delimiter='\t')
        

####################################################################
####################################################################
####################################################################
####################################################################


## FUNCTION TO RUN BOTH BLOCKMEDIAN AND SURFACE
## Two variants: one runs in PyGMT, the other runs GMT as a subprocess. GMT sometimes provides more functionality, but those extra options are not always used

def cross_section(in_dat, out_bm, out_surf, inc, reg, ct, head):
    """Run ``pygmt.blockmedian`` on ``in_dat`` and grid its output with ``pygmt.surface``.

    ``inc``, ``reg`` and ``ct`` are passed to both steps as ``spacing``, ``region``
    and ``coltypes``; ``head`` is passed to blockmedian only, as ``header``.
    ``out_bm`` is the blockmedian output file and the input of surface, which
    writes the grid ``out_surf``.
    """
    shared_opts = {'spacing': inc, 'region': reg, 'coltypes': ct}

    pygmt.blockmedian(data=in_dat, outfile=out_bm, header=head, **shared_opts)

    pygmt.surface(data=out_bm, outgrid=out_surf, **shared_opts)
    
    
    

def gmt_cross_section(bm_infile, increment, region, coltypes, header, bm_out, surf_out, upper_boundary, lower_boundary, tension):
    """Run ``gmt blockmedian`` then ``gmt surface`` as subprocesses.

    Parameters
    ----------
    bm_infile : str
        Input table for blockmedian.
    increment, region, coltypes, header : str
        Ready-made GMT option strings for blockmedian; ``increment`` and ``region``
        are also passed to surface.
    bm_out : str
        File that receives blockmedian's standard output and is then read by surface.
    surf_out : str
        Output option for surface, given as a complete GMT argument.
    upper_boundary, lower_boundary, tension : str
        Further GMT option strings appended to the surface command.
    """
    # subprocess.run does not use a shell, so '>' cannot redirect; capture stdout
    # into bm_out explicitly.
    with open(bm_out, 'w') as bm_handle:
        subprocess.run(['gmt', 'blockmedian', bm_infile, increment, region, coltypes, header], stdout=bm_handle)

    subprocess.run(['gmt', 'surface', bm_out, surf_out, increment, region, upper_boundary, lower_boundary, tension])
    
    

####################################################################
####################################################################
####################################################################
####################################################################

## Function to run all above functions in order to produce a cross-section netCDF file that can be plotted as a cross-section

def make_cross_section(transect, list_of_depths, vm_loc, vm_trans_out, wave_type, mf_loc, mf_trans_out, vm_work_dir,
                      mf_work_dir, region, coltypes, header, vm_upper, vm_lower, mf_upper, mf_lower):
    """Build cross-section grids for two sets of depth slices along one transect.

    For each of the two grid sets (``vm_*`` and ``mf_*``) this samples every depth
    slice along ``transect``, concatenates the samples into one table in the work
    directory, and runs blockmedian and surface on it. Outputs are
    ``tomo_concat.txt``, ``tomo_bm.txt``, ``tomo_surf.nc`` in ``vm_work_dir`` and
    ``mf_concat.txt``, ``mf_bm.txt``, ``mf_surf.nc`` in ``mf_work_dir``.

    ``region``, ``coltypes`` and the ``*_upper`` / ``*_lower`` values are passed to
    ``gmt_cross_section`` unchanged. ``header`` is accepted but not used: ``-h1``
    is always passed instead.
    """
    make_transects(transect, vm_loc, vm_trans_out, list_of_depths, wave_type)

    make_transects(transect, mf_loc, mf_trans_out, list_of_depths, '_mf')

    concat_ascii(vm_trans_out, vm_work_dir, 'tomo_concat.txt', list_of_depths, 2)

    concat_ascii(mf_trans_out, mf_work_dir, 'mf_concat.txt', list_of_depths, 2)

    # Use os.path.join so these paths agree with the files concat_ascii wrote,
    # whether or not the work directories end with a path separator.
    vm_cs = os.path.join(vm_work_dir, 'tomo_concat.txt')

    vm_bm = os.path.join(vm_work_dir, 'tomo_bm.txt')

    vm_surf = '-G' + os.path.join(vm_work_dir, 'tomo_surf.nc')

    gmt_cross_section(vm_cs, '-I0.5/0.005', region, coltypes, '-h1', vm_bm, vm_surf, vm_upper, vm_lower, '-T0')

    mf_cs = os.path.join(mf_work_dir, 'mf_concat.txt')

    mf_bm = os.path.join(mf_work_dir, 'mf_bm.txt')

    mf_surf = '-G' + os.path.join(mf_work_dir, 'mf_surf.nc')

    gmt_cross_section(mf_cs, '-I0.5/0.005', region, coltypes, '-h1', mf_bm, mf_surf, mf_upper, mf_lower, '-T0')
    
    


In [2]:
### PLOTTING CROSS-SECTIONS CREATED ABOVE BY THE "make_cross_section" AND PLOTTING MULTIPLE CROSS-SECTIONS AS A SUBPLOT

def gmt_plot_cs(grdfile, region, projection):
    """Plot ``grdfile`` with a colour bar by issuing four ``gmt`` commands in turn.

    ``region`` and ``projection`` are passed to ``gmt grdimage`` as given, so they
    must already be complete GMT arguments.
    """
    command_sequence = (
        ['gmt', 'begin', 'map'],
        ['gmt', 'grdimage', grdfile, '-B', region, projection, '-Ei'],
        ['gmt', 'colorbar'],
        ['gmt', 'end', 'show'],
    )

    for command in command_sequence:
        subprocess.run(command)


def pygmt_polar_plot(grd, frame, dpi, reg, proj, cb_frame, out_fig):
    """Draw grid ``grd`` as a colour image with a colour bar and save it to ``out_fig``.

    ``frame``, ``dpi``, ``reg`` and ``proj`` are forwarded to ``Figure.grdimage``
    as ``frame``, ``dpi``, ``region`` and ``projection``; ``cb_frame`` is the
    ``frame`` argument of ``Figure.colorbar``.
    """
    image_options = {
        'grid': grd,
        'frame': frame,
        'dpi': dpi,
        'region': reg,
        'projection': proj,
    }

    fig = pygmt.Figure()

    # Colour map of the grid, then the colour bar that explains its scaling.
    fig.grdimage(**image_options)
    fig.colorbar(frame=cb_frame)

    fig.savefig(out_fig)
    
    
###################################################


def subplots(fsize, cs_grd, cs_frame, cs_reg, cs_proj, cb_frame, yshift, line, depth_reg):
    """Build a figure with a cross-section image and, after an origin shift, a global map.

    The cross-section grid ``cs_grd`` is drawn with a depth axis and a colour bar.
    The plot origin is then moved by ``yshift`` and a global ``N15c`` basemap with
    coastlines and the transect ``line`` (red) is drawn.

    Returns
    -------
    pygmt.Figure
        The assembled figure; it is not saved here.
    """
    map_projection = "N15c"

    fig = pygmt.Figure()

    # NOTE(review): in current PyGMT, subplot() and set_panel() are context managers;
    # called bare like this they may have no effect, leaving the layout to
    # shift_origin below. Confirm before relying on fsize.
    fig.subplot(nrows=2, ncols=1, figsize=fsize)

    # Cross-section, its depth axis, and the colour bar.
    fig.set_panel(panel=[0, 1])
    fig.grdimage(grid=cs_grd, frame=cs_frame, region=cs_reg, dpi='i', projection=cs_proj)
    fig.basemap(projection=cs_proj, region=depth_reg, frame=['E', 'y+lDepth(km)'])
    fig.colorbar(frame=cb_frame)

    fig.shift_origin(yshift=yshift)

    # Location map showing where the transect runs.
    fig.set_panel(panel=[0, 0])
    fig.basemap(region='d', projection=map_projection, frame=["af", "WSne"])
    fig.coast(shorelines="0.5p,black", projection=map_projection)
    fig.plot(data=line, pen='1.5p,red', projection=map_projection)

    return fig

# Comparison maps and statistics

## Functions

In [3]:
### CREATING COMPARISON MAPS --> COMPARES THE LOCATION OF SUBDUCTED SLABS IN BOTH TOMOGRAPHY AND THE CONVECTION MODEL AND PRODUCES A MAP DISPLAYING THE DIFFERENCE

# READ IN FILES FOR TOMOGRAPHY AND CONVECTION MODEL AND CLIPS EACH TO A DESIRED LIMIT
# I use mf_above(no slab) set to [limit, 3], mf_below(slab) set to [limit, 2], tomo_above(slab) set to [limit, 1] and tomo_below(no slab) set to [limit, 4]
# Those values provide unique sums for grdmath so that true positive = 3, false negative = 4, false positive = 6, true negative = 7

def comp_map(tomo_grid, tomo_out, tomo_above, tomo_below, mf_grid, mf_out, mf_above, mf_below, comp_out):
    """Clip two grids to class codes and add them into one comparison grid.

    ``tomo_grid`` and ``mf_grid`` are each clipped with ``pygmt.grdclip`` using
    their ``*_above`` / ``*_below`` settings and written to ``tomo_out`` /
    ``mf_out``. ``gmt grdmath`` then adds the two clipped grids and stores the
    sum in ``comp_out``.
    """
    clip_jobs = (
        (tomo_grid, tomo_out, tomo_above, tomo_below),
        (mf_grid, mf_out, mf_above, mf_below),
    )

    for source_grid, clipped_grid, above, below in clip_jobs:
        pygmt.grdclip(grid=source_grid, outgrid=clipped_grid, above=above, below=below)

    # Sum of the two class-coded grids identifies each combination uniquely.
    subprocess.run(['gmt', 'grdmath', tomo_out, mf_out, 'ADD', '=', comp_out])
    
###################################################  

## Function to loop through all files (in depth order) and perform the above comparison function on them

def comp_grds(mf_dir, mf_out_dir, mf_above, mf_below, tomo_dir, tomo_out_dir, tomo_above, tomo_below, out_dir, reconstruction_string):
    """Run ``comp_map`` for every depth slice except the first and last.

    Files in ``mf_dir`` and ``tomo_dir`` are sorted by name, the first and last of
    each list are dropped, and the remaining files are paired by position. The
    first six characters of the tomography file name are used as the depth label.

    Outputs per pair:
    ``<out_dir>/<label>_<reconstruction_string>.nc`` (comparison grid),
    ``<tomo_out_dir>/<tomography file name>`` (clipped tomography), and
    ``<mf_out_dir>/<first four label characters>_<reconstruction_string>.nc``
    (clipped model grid).

    Raises
    ------
    IndexError
        If ``tomo_dir`` yields fewer usable files than ``mf_dir``.
    """
    mf_files = [os.path.join(mf_dir, entry) for entry in sorted(os.listdir(mf_dir))][1:-1]

    tomo_files = [os.path.join(tomo_dir, entry) for entry in sorted(os.listdir(tomo_dir))][1:-1]

    for idx in range(len(mf_files)):
        tomo_path = tomo_files[idx]

        # File name is everything after the last '/'.
        tomo_name = tomo_path.rpartition('/')[2]

        depth = tomo_name[0:6]

        comp_out = '/'.join((out_dir, depth + '_' + reconstruction_string + '.nc'))

        tomo_out = '/'.join((tomo_out_dir, tomo_name))

        mf_out = '/'.join((mf_out_dir, depth[0:4] + '_' + reconstruction_string + '.nc'))

        comp_map(tomo_path, tomo_out, tomo_above, tomo_below, mf_files[idx], mf_out, mf_above, mf_below, comp_out)
        
        
#####################################################

## loop through grids made in above function and map them as a horizontal depth slice, label depth in the top left corner

def horizontal_fig_create(grds_dir, colour_map, output_dir, file_type):
    """Render every grid in ``grds_dir`` as a global map labelled with its depth.

    Grids are taken in sorted name order. The i-th grid is labelled with entry
    ``i + 1`` of the module-level ``list_of_depths`` (not a parameter), and saved
    as ``<output_dir>/<first four characters of the grid name>_comp_map<file_type>``.

    Parameters
    ----------
    grds_dir : str
        Directory of grids to plot.
    colour_map : str
        Passed to ``Figure.grdimage`` as ``cmap``.
    output_dir : str
        Directory for the saved figures.
    file_type : str
        Suffix appended to each figure name, including the leading dot.
    """
    map_projection = "N15c"

    grid_paths = [os.path.join(grds_dir, entry) for entry in sorted(os.listdir(grds_dir))]

    for idx in range(len(grid_paths)):
        grd = grid_paths[idx]

        fig = pygmt.Figure()

        fig.basemap(region='d', projection=map_projection, frame=["af", "WSne"])

        fig.grdimage(grid=grd, projection=map_projection, cmap=colour_map, interpolation='n+a')

        fig.coast(shorelines="0.5p,black", projection=map_projection)

        # Offset of one: the first entry of list_of_depths has no grid in this directory.
        depth_label = str(list_of_depths[idx + 1]) + " km"

        fig.text(text=depth_label, position='TL', font='14p')

        depth_tag = grd.rpartition('/')[2][0:4]

        fig.savefig(output_dir + '/' + depth_tag + '_comp_map' + file_type)


In [4]:
## read in netCDF file from above comparison grids and count the occurrences of each value, calculate comparison success from those values

def comp_accuracy(comp_grid):
    """Compute classification metrics from a comparison grid.

    Cells are counted by value: 3 = true positive, 4 = false negative,
    6 = false positive, 7 = true negative. Any other value is ignored.

    Parameters
    ----------
    comp_grid : str
        Path of the comparison grid, opened with ``xarray.open_dataarray``.

    Returns
    -------
    list of float
        ``[accuracy, TPR, TNR, F1_score, precision]``. A metric whose denominator
        is zero is reported as ``0.`` instead of raising ``ZeroDivisionError``.
    """
    def _safe_ratio(numerator, denominator):
        # An empty class makes the ratio undefined; report 0. for it.
        if denominator == 0:
            return 0.
        return numerator/denominator

    # Context manager releases the file handle once the values are in memory.
    with xr.open_dataarray(comp_grid) as comp_xrarray:
        comp_array = comp_xrarray.to_numpy()

    true_positive = np.count_nonzero(comp_array == 3)
    false_negative = np.count_nonzero(comp_array == 4)
    false_positive = np.count_nonzero(comp_array == 6)
    true_negative = np.count_nonzero(comp_array == 7)

    accuracy = _safe_ratio(true_positive + true_negative,
                           true_positive + true_negative + false_positive + false_negative)

    TPR = _safe_ratio(true_positive, true_positive + false_negative)

    TNR = _safe_ratio(true_negative, false_positive + true_negative)

    F1_score = _safe_ratio(2*true_positive, 2*true_positive + false_positive + false_negative)

    precision = _safe_ratio(true_positive, true_positive + false_positive)

    return [accuracy, TPR, TNR, F1_score, precision]

##################################################

## Read in folder containing comparison grids for all depths and calculate the success metrics, store in lists for plotting

def comp_stats(comp_dir):
    """Collect the ``comp_accuracy`` metrics for every grid in ``comp_dir``.

    Grids are processed in sorted file-name order.

    Returns
    -------
    tuple of list
        ``(acc_list, tpr_list, tnr_list, f1_list, precision_list)``, each holding
        one value per grid in processing order.
    """
    grid_paths = [os.path.join(comp_dir, entry) for entry in sorted(os.listdir(comp_dir))]

    per_grid = [comp_accuracy(path) for path in grid_paths]

    # Split the per-grid metric rows into one list per metric.
    acc_list = [row[0] for row in per_grid]
    tpr_list = [row[1] for row in per_grid]
    tnr_list = [row[2] for row in per_grid]
    f1_list = [row[3] for row in per_grid]
    precision_list = [row[4] for row in per_grid]

    return (acc_list, tpr_list, tnr_list, f1_list, precision_list)

# Find slab thresholds

## Functions

In [5]:
## For a given model type (tomography or convection model) read in the vote or temp grids for all depths and set a threshold
## Calculate the slab area at each depth for the given threshold, grids are projected to Earth's surface so area is given as a percentage of Earth surface area

def slab_area(in_dir, limit, list_of_depths, model_type):
    """Slab area per depth slice, as a percentage of Earth's surface area.

    Each selected grid is passed to ``pygmt.grdvolume`` with ``contour=limit``
    and the value in column ``1`` of the result is converted to a percentage.
    The first file (sorted order) is always skipped, and no file at index
    ``len(list_of_depths) - 1`` or beyond is used, except that the file at
    index 1 is always used when present.

    Parameters
    ----------
    in_dir : str
        Directory of horizontal depth-slice grids.
    limit : number
        Contour threshold. For ``model_type == 'mf'`` it is prefixed with ``'r'``
        so the area below the limit is measured rather than above.
    list_of_depths : sequence
        Only its length is used here, to bound how many slices are processed.
    model_type : str
        ``'mf'`` for the convection model; any other value leaves ``limit`` unchanged.

    Returns
    -------
    list of float
        One percentage per processed slice; empty when no slice qualifies.
    """
    earth_surface_area = 5.10047818418e+14

    if model_type == 'mf':
        limit = 'r' + str(limit)

    fn_list = [os.path.join(in_dir, file) for file in sorted(os.listdir(in_dir))]

    # Index 1 is always processed; later indices only while < len(list_of_depths) - 1.
    stop = max(len(list_of_depths) - 1, 2)

    area_list = []
    for fn in fn_list[1:stop]:
        df_area = pygmt.grdvolume(fn, region='d', contour=limit)
        # Collect the area column directly instead of concatenating frames in a loop.
        area_list.extend(df_area[1].tolist())

    return [area*100./earth_surface_area for area in area_list]


#################################################

## Test a range of thresholds for both tomo and convection model and store the slab areas at each depth that each threshold produces

def thresholds(mf_thresh, mf_dir, tomo_thresh, tomo_dir, list_of_depths):
    """Compare slab areas for every pairing of convection-model and tomography threshold.

    ``slab_area`` is evaluated once per threshold for each model. For each
    (mf threshold, tomo threshold) pair two scores are computed over the depth slices:
    the mean absolute difference of the area percentages, and the mean of
    ``min/max * 100`` with the first and last slices dropped before averaging.

    Parameters
    ----------
    mf_thresh, tomo_thresh : sequence
        Thresholds to test for the convection model and the tomography.
    mf_dir, tomo_dir : str
        Directories of depth-slice grids for each model.
    list_of_depths : sequence
        Passed through to ``slab_area``.

    Returns
    -------
    tuple
        ``(comb_list, hm_meandiff, hm_meanperc_diff)``: a flat list of
        ``(mf threshold, tomo threshold, mean difference)`` tuples, and two
        matrices with one row per mf threshold and one column per tomo threshold.
    """
    mf_slab_perc = [slab_area(mf_dir, mf_limit, list_of_depths, 'mf') for mf_limit in mf_thresh]

    tomo_slab_perc = [slab_area(tomo_dir, t_limit, list_of_depths, 'tomo') for t_limit in tomo_thresh]

    comb_list = []
    hm_meandiff = []
    hm_meanperc_diff = []

    for mf_ind, mf_percent_list in enumerate(mf_slab_perc):

        hm_md_row = []
        hm_mdperc_row = []

        for t_ind, tomo_percent_list in enumerate(tomo_slab_perc):

            diff_list = []
            perc_diff_list = []

            for x in range(len(mf_percent_list)):
                mf_area = mf_percent_list[x]
                tomo_area = tomo_percent_list[x]

                diff_list.append(abs(mf_area - tomo_area))

                larger = max(mf_area, tomo_area)
                if larger != 0:
                    perc_diff = (min(mf_area, tomo_area)/larger) * 100
                else:
                    # Both areas are zero: the models agree exactly, so score a
                    # full match instead of dividing by zero.
                    perc_diff = 100.

                perc_diff_list.append(perc_diff)

            mean_diff = sum(diff_list)/len(diff_list)

            # Leave the first and last slices out of the percentage average.
            perc_diff_list = perc_diff_list[1:-1]

            meanperc_diff = sum(perc_diff_list)/len(perc_diff_list)

            comb_list.append((mf_thresh[mf_ind], tomo_thresh[t_ind], mean_diff))

            hm_md_row.append(mean_diff)
            hm_mdperc_row.append(meanperc_diff)

        hm_meandiff.append(hm_md_row)
        hm_meanperc_diff.append(hm_mdperc_row)

    return (comb_list, hm_meandiff, hm_meanperc_diff)

################################################################

## Save lists storing thresholded slab areas for heatmap plotting in another environment that is compatible with matplotlib

def save_meandiff_lists(file_name, list_to_save):
    """Write ``list_to_save`` to ``file_name`` as JSON, replacing any existing file."""
    with open(file_name, 'w') as json_handle:
        json.dump(list_to_save, fp=json_handle)
    

# Analysis of existing/melted slabs

In [12]:
def slab_locs(slab_dir):
    """Read the final position stored in each file of ``slab_dir``.

    The last non-blank line of each file is split on whitespace and its first
    three fields are used as radius, colatitude and longitude. Radius is turned
    into depth as ``6371 - radius/1000`` (presumably metres to km - confirm
    against the file format); the two angles are converted from radians to degrees.

    NOTE: this notebook defines ``slab_locs`` again in a later cell; when the
    cells run in order, that later definition is the one in effect.

    Parameters
    ----------
    slab_dir : str
        Directory whose files are read in sorted name order.

    Returns
    -------
    tuple
        ``(slab_array, slab_file_key)``: an array with one ``[lon, lat, depth]``
        row per file, and the file paths in the same order.

    Raises
    ------
    ValueError
        If a file has no non-blank line, or its last data line has fewer than
        three fields or a non-numeric one among them.
    """
    rad_to_deg = 180/math.pi

    list_of_lines = []
    slab_file_key = []

    for file in sorted(os.listdir(slab_dir)):
        fn = os.path.join(slab_dir, file)

        # Keep only the last non-blank line so a trailing empty line is harmless.
        last_line = None
        with open(fn, 'r') as handle:
            for line in handle:
                if line.strip():
                    last_line = line

        if last_line is None:
            raise ValueError('no data rows found in ' + fn)

        # split() with no argument accepts tabs and repeated spaces alike.
        radius, colat, lon_rad = (float(tok) for tok in last_line.split()[:3])

        depth = 6371 - (radius/1000)
        lat = 90 - (colat * rad_to_deg)
        lon = lon_rad * rad_to_deg

        list_of_lines.append([lon, lat, depth])
        slab_file_key.append(fn)

    slab_array = np.array(list_of_lines)

    return (slab_array, slab_file_key)

##############################################################################


def slab_exists_arrays(slab_dir, mf_grids_dir, list_of_depths, threshold):
    """Classify each slab point as still present or not, by sampling the model grids.

    Each point from ``slab_locs`` is sampled in the grid slices on either side of
    its depth. If either sampled value is <= ``threshold`` the point goes into the
    "slab" group, otherwise into the "no slab" group.

    Returns ``(slabs_only_array, no_slabs_array, slab_exist_key, melt_key)``.

    NOTE(review): both arrays are sorted by column 2 before returning, but the two
    key lists are returned in processing order, so row i of an array does not
    necessarily belong to key i. ``slab_exists`` returns dictionaries instead.
    """
    
    slab_array, slab_file_key = slab_locs(slab_dir)
    
    mf_grids = []
    for file in sorted(os.listdir(mf_grids_dir)):
        fn = os.path.join(mf_grids_dir, file)
        
        mf_grids.append(fn)
        
    # Drop the first and last slice from both the grids and the depths.
    mf_grids = mf_grids[1:-1]
    
    depth_list = list_of_depths[1:-1]
    
    working_list = []
    
    no_slab_list = []
    
    slab_exist_key = []
    
    melt_key = []
    
    print(len(slab_array))
    
    for x, slab_data in enumerate(slab_array):
        
        try:
        
            # Index of the first slice deeper than the point (slab_data[2] is its depth).
            # The inner x is local to the generator and does not change the loop's x.
            max_depth = next(x for x, val in enumerate(depth_list) if val > slab_data[2])
        # NOTE(review): bare except. It handles StopIteration (point deeper than every
        # slice) but would also hide any other error raised above.
        except:
            # Keep the point as a 1-row 2-D array for grdtrack.
            slab_data = slab_array[x:x+1]
            
            # No deeper slice exists: sample the deepest grid only.
            above = pygmt.grdtrack(points = slab_data, grid = mf_grids[-1])
            alist = above.values.tolist()[0]
            
            if alist[-1]<=threshold:
                working_list.append(alist)
                slab_exist_key.append(slab_file_key[x])
                continue
                
            else:
                no_slab_list.append(alist)
                melt_key.append(slab_file_key[x])
                continue
            
        # NOTE(review): when max_depth is 0 this is -1, which selects the LAST grid
        # (deepest slice) rather than a shallower one.
        min_depth = max_depth-1
        
        slab_data = slab_array[x:x+1]
        
        below = pygmt.grdtrack(points = slab_data, grid = mf_grids[max_depth])
        
        above = pygmt.grdtrack(points = slab_data, grid = mf_grids[min_depth])
        
        # First output row as a list; the last element is compared with threshold.
        blist = below.values.tolist()[0]
        
        alist = above.values.tolist()[0]
        
        test_slab = len(working_list)
        
        test_melt = len(no_slab_list)
        
        # The deeper sample takes priority; the shallower one is only a fallback.
        if blist[-1]<=threshold:
            working_list.append(blist)
            slab_exist_key.append(slab_file_key[x])
            continue
            
        elif alist[-1]<=threshold:
            working_list.append(alist)
            slab_exist_key.append(slab_file_key[x])
            continue
            
        else:
            no_slab_list.append(blist)
            melt_key.append(slab_file_key[x])
            
        # Only reached from the else branch above, where working_list is unchanged,
        # so neither message below can be printed.
        if test_slab != len(working_list) and test_melt != len(no_slab_list):
            print('slab and melt')
        elif test_slab == (len(working_list)-2):
            print('slab double up')
            
        else:
            continue
    
        
    # NOTE(review): the [:, 2] indexing assumes at least one row in each group;
    # an empty list gives a 1-D array and would fail here. Confirm with callers.
    slabs_only_array = np.array(working_list)
    slabs_only_array = slabs_only_array[np.argsort(slabs_only_array[:, 2])]
    
    no_slabs_array = np.array(no_slab_list)
    no_slabs_array = no_slabs_array[np.argsort(no_slabs_array[:, 2])]
            
    return(slabs_only_array, no_slabs_array, slab_exist_key, melt_key)

###########################################################################
##########################################################################


# USE THIS ONE
def slab_exists(slab_dir, mf_grids_dir, list_of_depths, threshold):
    """Classify each slab point as still present or not, by sampling the model grids.

    Each point from ``slab_locs`` is sampled with ``pygmt.grdtrack`` in the grid
    slice just deeper than it and, if that fails the test, in the slice just
    shallower. The first sample whose last value is <= ``threshold`` marks the
    point as a slab. The first and last grid and depth are excluded.

    Edge cases: a point deeper than every remaining slice is tested against the
    deepest grid only; a point shallower than every remaining slice is tested
    against the shallowest grid only.

    Parameters
    ----------
    slab_dir : str
        Directory of slab files read by ``slab_locs``.
    mf_grids_dir : str
        Directory of model depth-slice grids, paired by sorted order with ``list_of_depths``.
    list_of_depths : sequence
        Depth of each grid.
    threshold : number
        Sampled values at or below this count as slab.

    Returns
    -------
    tuple
        ``(slabs_only_array, no_slabs_array, slab_dict, melt_dict)``. The arrays
        hold the sampled rows of each group in processing order; the dictionaries
        map each slab file path to its row.
    """
    slab_array, slab_file_key = slab_locs(slab_dir)

    mf_grids = [os.path.join(mf_grids_dir, file) for file in sorted(os.listdir(mf_grids_dir))]
    mf_grids = mf_grids[1:-1]

    depth_list = list_of_depths[1:-1]

    def _sample(points, grid_index):
        # First (only) output row of grdtrack for a single point.
        return pygmt.grdtrack(points=points, grid=mf_grids[grid_index]).values.tolist()[0]

    working_list = []
    no_slab_list = []
    slab_exist_key = []
    melt_key = []

    print(len(slab_array))

    for x, slab_point in enumerate(slab_array):

        # grdtrack needs a 2-D array, so keep the point as a 1-row slice.
        slab_data = slab_array[x:x+1]

        # Index of the first slice deeper than the point; None if there is none.
        max_depth = next((i for i, val in enumerate(depth_list) if val > slab_point[2]), None)

        if max_depth is None:
            candidates = [_sample(slab_data, -1)]
        elif max_depth == 0:
            # No shallower slice exists. Indexing with max_depth - 1 == -1 would
            # wrap around to the deepest grid, so test the deeper slice only.
            candidates = [_sample(slab_data, 0)]
        else:
            candidates = [_sample(slab_data, max_depth), _sample(slab_data, max_depth - 1)]

        # Deeper sample first, shallower one as fallback.
        hit = next((row for row in candidates if row[-1] <= threshold), None)

        if hit is not None:
            working_list.append(hit)
            slab_exist_key.append(slab_file_key[x])
        else:
            no_slab_list.append(candidates[0])
            melt_key.append(slab_file_key[x])

    slabs_only_array = np.array(working_list)
    slab_dict = dict(zip(slab_exist_key, slabs_only_array))

    no_slabs_array = np.array(no_slab_list)
    melt_dict = dict(zip(melt_key, no_slabs_array))

    return (slabs_only_array, no_slabs_array, slab_dict, melt_dict)

In [13]:
### FIND SLABS WHICH MATCH TOMOGRAPHY

def slab_match_tomo(slab_dict, melt_dict, comp_dir, list_of_depths):
    """Split slab and melt points by whether the comparison grids show a tomography slab.

    For comparison grid ``x`` (sorted order), the points whose depth lies strictly
    between ``depth_list[x-1]`` and ``depth_list[x]`` are sampled with
    nearest-neighbour ``grdtrack`` in grids ``x-1`` and ``x``. A point is a match
    when either sampled value is 3 or 4, otherwise a problem.

    NOTE(review): for ``x == 0`` the index ``x-1`` is ``-1``, so the lower bound is
    the last depth; with ascending depths no point can satisfy it. Points lying
    exactly on a slice depth are never classified either. Confirm this is intended.

    Parameters
    ----------
    slab_dict, melt_dict : dict
        File key -> row array whose element 2 is the depth.
    comp_dir : str
        Directory of comparison grids.
    list_of_depths : sequence
        Depths; the first and last entries are dropped before use.

    Returns
    -------
    tuple of list
        ``(slab_match, slab_problem, melt_match, melt_problem)`` lists of file keys.
    """
    comp_grds = [os.path.join(comp_dir, entry) for entry in sorted(os.listdir(comp_dir))]

    depth_list = list_of_depths[1:-1]

    slab_arrays = list(slab_dict.values())
    slab_keys = list(slab_dict.keys())

    melt_arrays = list(melt_dict.values())
    melt_keys = list(melt_dict.keys())

    slab_match, slab_problem = [], []
    melt_match, melt_problem = [], []

    tomo_slab_codes = (3, 4)

    def _sort_points(x, keys, arrays, matched, unmatched):
        # Positions of the points that fall inside this depth interval.
        picked = [ind for ind, array in enumerate(arrays)
                  if array[2] < depth_list[x] and array[2] > depth_list[x-1]]
        if not picked:
            return

        working_arrays = np.array([arrays[ind] for ind in picked])

        below_array = pygmt.grdtrack(points=working_arrays, grid=comp_grds[x-1], interpolation='n').to_numpy()
        above_array = pygmt.grdtrack(points=working_arrays, grid=comp_grds[x], interpolation='n').to_numpy()

        for i, point in enumerate(below_array):
            if point[-1] in tomo_slab_codes or above_array[i][-1] in tomo_slab_codes:
                matched.append(keys[picked[i]])
            else:
                unmatched.append(keys[picked[i]])

    for x in range(len(comp_grds)):
        _sort_points(x, slab_keys, slab_arrays, slab_match, slab_problem)
        _sort_points(x, melt_keys, melt_arrays, melt_match, melt_problem)

    return (slab_match, slab_problem, melt_match, melt_problem)


In [14]:
##### FIX DICTIONARY FUNCTIONS

def slab_locs(slab_dir):
    """Read the final position stored in each file of ``slab_dir``.

    The last non-blank line of each file is split on whitespace; its fields are
    radius, colatitude, longitude, time, of which the first three are used.
    Radius is turned into depth as ``6371 - radius/1000`` (presumably metres to
    km - confirm against the file format); the two angles are converted from
    radians to degrees.

    Parameters
    ----------
    slab_dir : str
        Directory whose files are read in sorted name order.

    Returns
    -------
    tuple
        ``(slab_array, slab_file_key)``: an array with one ``[lon, lat, depth]``
        row per file, and the file paths in the same order.

    Raises
    ------
    ValueError
        If a file has no non-blank line, or its last data line has fewer than
        three fields or a non-numeric one among them.
    """
    rad_to_deg = 180/math.pi

    list_of_lines = []
    slab_file_key = []

    for file in sorted(os.listdir(slab_dir)):
        fn = os.path.join(slab_dir, file)

        # Keep only the last non-blank line so a trailing empty line is harmless.
        last_line = None
        with open(fn, 'r') as handle:
            for line in handle:
                if line.strip():
                    last_line = line

        if last_line is None:
            raise ValueError('no data rows found in ' + fn)

        # split() with no argument accepts tabs and repeated spaces alike.
        radius, colat, lon_rad = (float(tok) for tok in last_line.split()[:3])

        depth = 6371 - (radius/1000)
        lat = 90 - (colat * rad_to_deg)
        lon = lon_rad * rad_to_deg

        # Row layout: [lon, lat, depth]
        list_of_lines.append([lon, lat, depth])
        slab_file_key.append(fn)

    slab_array = np.array(list_of_lines)

    return (slab_array, slab_file_key)

###############################################################################

def slab_exists_weighted(slab_dir, mf_grids_dir, list_of_depths, threshold):
    """Split slab terminals into 'slab' and 'melt' groups using depth-interpolated grid values.

    Every terminal returned by ``slab_locs(slab_dir)`` is sampled with
    ``pygmt.grdtrack`` in the two grids whose depths bracket it, and the two
    values are linearly interpolated to the terminal's depth.  A terminal below
    the deepest grid, or above the shallowest one, takes the value of that edge
    grid directly.  Values ``<= threshold`` count as slab; all others as melt.

    Parameters
    ----------
    slab_dir : str
        Directory of slab track files, passed to ``slab_locs``.
    mf_grids_dir : str
        Directory of grids.  The first and last file (sorted) are dropped, as
        are the first and last entries of ``list_of_depths``; the remaining
        files and depths are assumed to correspond one-to-one -- TODO confirm.
    list_of_depths : list
        Depths of the grids, ascending.
    threshold : float
        Largest grid value that still counts as slab.

    Returns
    -------
    slabs_only_array, no_slabs_array : numpy.ndarray
        Rows for the slab and melt groups.  Interpolated rows are
        ``[lon, lat, depth, value]``; edge rows are the raw ``grdtrack`` output
        row (assumed to have the same layout -- TODO confirm).
    slab_dict, melt_dict : dict
        The same rows keyed by the path of the track file they came from.
    """
    slab_array, slab_file_key = slab_locs(slab_dir)

    mf_grids = [os.path.join(mf_grids_dir, name) for name in sorted(os.listdir(mf_grids_dir))]
    mf_grids = mf_grids[1:-1]

    depth_list = list_of_depths[1:-1]

    working_list = []
    no_slab_list = []

    slab_exist_key = []
    melt_key = []

    for x, slab_point in enumerate(slab_array):

        point_depth = slab_point[2]

        # grdtrack is given a 2-D, single-row slice of the terminal array
        track_point = slab_array[x:x+1]

        # Index of the first grid deeper than the terminal (None when there is none)
        max_depth = next((i for i, val in enumerate(depth_list) if val > point_depth), None)

        if max_depth is None or max_depth == 0:
            # Outside the depth range covered by the grids: no bracketing pair
            # exists, so use the nearest edge grid instead of interpolating.
            # (Index 0 must not fall through to min_depth = -1, which would
            # silently pair the shallowest grid with the deepest one.)
            edge_grid = mf_grids[-1] if max_depth is None else mf_grids[0]

            sampled = pygmt.grdtrack(points = track_point, grid = edge_grid)
            row = sampled.values.tolist()[0]
            anomaly = row[-1]

        else:
            min_depth = max_depth-1

            below = pygmt.grdtrack(points = track_point, grid = mf_grids[max_depth])
            above = pygmt.grdtrack(points = track_point, grid = mf_grids[min_depth])

            # The sampled grid value is the last column of the grdtrack output
            below_temp = below.values.tolist()[0][-1]
            above_temp = above.values.tolist()[0][-1]

            above_depth = depth_list[min_depth]
            below_depth = depth_list[max_depth]

            # Linear interpolation between the two bracketing grids
            anomaly = above_temp + ((point_depth - above_depth) / (below_depth - above_depth)) * (below_temp - above_temp)

            row = np.array([slab_point[0], slab_point[1], point_depth, anomaly])

        if anomaly <= threshold:
            working_list.append(row)
            slab_exist_key.append(slab_file_key[x])
        else:
            no_slab_list.append(row)
            melt_key.append(slab_file_key[x])

    slabs_only_array = np.array(working_list)
    slab_dict = {}
    for i, array in enumerate(slabs_only_array):
        slab_dict[slab_exist_key[i]] = array

    no_slabs_array = np.array(no_slab_list)
    melt_dict = {}
    for i, array in enumerate(no_slabs_array):
        melt_dict[melt_key[i]] = array

    return(slabs_only_array, no_slabs_array, slab_dict, melt_dict)

###########################################################################

def _sort_terminals_by_tomo(entries, shallow_depth, deep_depth, shallow_grd, deep_grd, match_dict, prob_dict):
    # Sample the terminals lying strictly between two depths in both bracketing
    # grids and file each one under match_dict (either grid reads 3 or 4) or prob_dict.
    selected = [(key, array) for key, array in entries if shallow_depth < array[2] < deep_depth]

    if len(selected) == 0:
        return

    keys = [key for key, _ in selected]
    points = np.array([array for _, array in selected])

    # Nearest-neighbour sampling; the sampled grid value is the last column of each row
    shallow_vals = pygmt.grdtrack(points = points, grid = shallow_grd, interpolation = 'n').to_numpy()
    deep_vals = pygmt.grdtrack(points = points, grid = deep_grd, interpolation = 'n').to_numpy()

    for i, key in enumerate(keys):
        if shallow_vals[i][-1] in (3, 4) or deep_vals[i][-1] in (3, 4):
            match_dict[key] = points[i]
        else:
            prob_dict[key] = points[i]


def slab_match_tomo(slab_dict, melt_dict, comp_dir, list_of_depths):
    """Classify slab and melt terminals by the values of the grids bracketing their depth.

    For each pair of consecutive depths in ``list_of_depths[1:-1]`` the terminals
    whose depth (element 2 of each array) lies strictly between them are sampled
    in the two corresponding grids of ``comp_dir``.  A terminal is a "match" when
    either sampled value equals 3 or 4, otherwise a "problem".  Terminals outside
    every bracket, or exactly on a listed depth, appear in none of the outputs.

    Parameters
    ----------
    slab_dict, melt_dict : dict
        Track-file path -> terminal array; element 2 of each array is used as
        the terminal depth.
    comp_dir : str
        Directory of grids.  The sorted files are assumed to line up one-to-one
        with ``list_of_depths[1:-1]`` -- TODO confirm against the caller.
    list_of_depths : list
        Grid depths, ascending.

    Returns
    -------
    slab_match_dict, slab_prob_dict, melt_match_dict, melt_prob_dict : dict
        Subsets of the input dictionaries, holding the terminal arrays as rows
        of the sampled point array.
    """
    comp_grds = [os.path.join(comp_dir, name) for name in sorted(os.listdir(comp_dir))]

    depth_list = list_of_depths[1:-1]

    slab_entries = list(slab_dict.items())
    melt_entries = list(melt_dict.items())

    slab_match_dict = {}
    slab_prob_dict = {}

    melt_match_dict = {}
    melt_prob_dict = {}

    # Start at 1: the first grid has no shallower neighbour to bracket with
    for x in range(1, len(comp_grds)):

        shallow_depth = depth_list[x-1]
        deep_depth = depth_list[x]

        _sort_terminals_by_tomo(slab_entries, shallow_depth, deep_depth, comp_grds[x-1], comp_grds[x],
                                slab_match_dict, slab_prob_dict)

        _sort_terminals_by_tomo(melt_entries, shallow_depth, deep_depth, comp_grds[x-1], comp_grds[x],
                                melt_match_dict, melt_prob_dict)

    return(slab_match_dict, slab_prob_dict, melt_match_dict, melt_prob_dict)


# Filter slab dictionaries by depth

In [15]:
def slab_filter_depths(input_array, depth_filter_list, threshold = 100, filter_type = 0):
    """Group slab terminals by depth.

    Parameters
    ----------
    input_array : numpy.ndarray
        2-D array of terminals; column 2 is used as the depth.
    depth_filter_list : list
        Depths to group around.  Must be ascending for ``filter_type == 0``.
    threshold : float, optional
        Half-width of the depth window, used only when ``filter_type != 0``.
    filter_type : int, optional
        ``0`` puts every terminal into the group of the depth it is closest
        to (a terminal exactly midway between two depths goes to the deeper
        one, so no terminal is lost).  Any other value returns, per depth, all
        terminals within ``threshold`` of it; groups may then overlap.

    Returns
    -------
    list of numpy.ndarray
        One array of terminals per entry of ``depth_filter_list``.
    """
    input_array = np.asarray(input_array)

    # Nothing to group (e.g. np.array([]) from an empty selection): return one
    # empty group per depth rather than failing on the column index below
    if input_array.size == 0:
        return([input_array for _ in depth_filter_list])

    point_depths = input_array[:, 2]

    list_of_depth_arrays = []

    # Group each slab terminal into groups based on the depth they are closest to
    if filter_type == 0:
        last = len(depth_filter_list) - 1

        for x, depth in enumerate(depth_filter_list):

            keep = np.ones(len(input_array), dtype = bool)

            # Closer to this depth than to the shallower neighbour (ties are kept here)
            if x > 0:
                keep &= (point_depths - depth_filter_list[x-1]) >= (depth - point_depths)

            # Strictly closer to this depth than to the deeper neighbour
            if x < last:
                keep &= (point_depths - depth) < (depth_filter_list[x+1] - point_depths)

            list_of_depth_arrays.append(input_array[keep])

    # Create groups of all slab terminals within 'x' km of a given depth --> e.g. 3 arrays for slabs within 100km of [1000, 1500, 2000]km depths
    else:
        for depth in depth_filter_list:
            list_of_depth_arrays.append(input_array[abs(point_depths - depth) <= threshold])

    return(list_of_depth_arrays)


#############################################################################################

def slab_filter_depths_dict(input_array, input_dict, depth_filter_list, threshold = 100, filter_type = 0):
    """Group slab terminals by depth, keeping the key of each terminal.

    Parameters
    ----------
    input_array : numpy.ndarray
        2-D array of terminals; column 2 is used as the depth.  Used only when
        ``filter_type == 0``; each selected row is matched back to an equal
        array among the values of ``input_dict``.
    input_dict : dict
        Key -> terminal array (element 2 is the depth).
    depth_filter_list : list
        Depths to group around.  Must be ascending for ``filter_type == 0``.
    threshold : float, optional
        Half-width of the depth window, used only when ``filter_type != 0``.
    filter_type : int, optional
        ``0`` groups terminals by the depth they are closest to (a terminal
        exactly midway between two depths goes to the deeper one).  The first
        and last depths always receive an empty dictionary.  Any other value
        returns, per depth, the entries within ``threshold`` of it.

    Returns
    -------
    list of dict
        One dictionary per entry of ``depth_filter_list``.

    Raises
    ------
    ValueError
        If a selected row of ``input_array`` equals no value of ``input_dict``.
    """
    list_of_depth_dicts = []

    # Separates entries into depth which they are closest to
    if filter_type == 0:

        input_array = np.asarray(input_array)

        # No terminals at all: every group is empty
        if input_array.size == 0:
            return([{} for _ in depth_filter_list])

        dkeys = list(input_dict.keys())
        dvals = list(input_dict.values())

        point_depths = input_array[:, 2]

        # Dictionary positions already handed out, so that identical terminal
        # rows are assigned to distinct keys instead of all landing on the first
        claimed = set()

        last = len(depth_filter_list) - 1

        for x, depth in enumerate(depth_filter_list):

            ## Append empty dictionary for first depth --> slabs seeded at 250km, first depth at 40km, should be no entries here
            ## The last depth is left empty as well
            if x == 0 or x == last:
                list_of_depth_dicts.append({})
                continue

            in_group = np.logical_and((point_depths - depth_filter_list[x-1]) >= (depth - point_depths),
                                      (point_depths - depth) < (depth_filter_list[x+1] - point_depths))

            work_dict = {}

            for array in input_array[in_group]:
                matches = [i for i, d_arr in enumerate(dvals) if np.array_equal(d_arr, array)]

                if len(matches) == 0:
                    raise ValueError('terminal row has no matching entry in input_dict: ' + str(array))

                unclaimed = [i for i in matches if i not in claimed]
                chosen = unclaimed[0] if unclaimed else matches[0]
                claimed.add(chosen)

                work_dict[dkeys[chosen]] = array

            list_of_depth_dicts.append(work_dict)

    # Creates set of dictionaries containing slab terminals within a certain threshold of each depth in a depth list
    # --> e.g. 3 dictionaries containing all terminals within 100km depth of [1000, 1500, 2000]km
    else:
        for depth in depth_filter_list:

            work_dict = {key:array for (key,array) in input_dict.items() if abs((array[2] - depth)) <= threshold}

            list_of_depth_dicts.append(work_dict)

    return(list_of_depth_dicts)
        

# Slab sinking functions

In [16]:
## Create a data frame of each slabs present-day position and order them by depth

#def find_slab_depths(slab_dir):
def find_slab_present_day_depths(slab_dir):
    """Tabulate the final record of every slab track file in a directory.

    The first four whitespace-separated columns of each file's last non-blank
    line are stored unconverted, exactly as read, under the column names
    ``depth``, ``lat``, ``long`` and ``time``.

    Parameters
    ----------
    slab_dir : str
        Directory containing one track file per slab particle.

    Returns
    -------
    df_depths : pandas.DataFrame
        One row per file, in sorted file order, with columns
        ``['depth', 'lat', 'long', 'time', 'name']``; ``name`` is the file name.
    df_depths_sort : pandas.DataFrame
        The same rows sorted by the ``depth`` column.
    list_of_lines : list of list
        The rows as plain lists ``[float, float, float, float, str]``.

    Raises
    ------
    ValueError
        If a file's last non-blank line has fewer than four columns.
    """
    column_names = ['depth', 'lat', 'long', 'time', 'name']

    list_of_lines = []

    for file_name in sorted(os.listdir(slab_dir)):
        fn = os.path.join(slab_dir, file_name)

        # Keep only the last line that holds data (ignores trailing blank lines)
        last_line = ''
        with open(fn, 'r') as track_file:
            for text in track_file:
                if text.strip():
                    last_line = text

        fields = last_line.split()

        if len(fields) < 4:
            raise ValueError('no usable final record in slab file: ' + fn)

        # The name is attached as a single field, so file names containing
        # spaces do not spill into extra columns
        list_of_lines.append([float(value) for value in fields[:4]] + [file_name])

    # Passing the column names here also gives a valid (empty) frame when the directory is empty
    df_depths = pd.DataFrame(list_of_lines, columns = column_names)

    df_depths_sort = df_depths.sort_values(by = 'depth')
    return(df_depths, df_depths_sort, list_of_lines)

#################################################################
#################################################################

def lateral_advection_distance(slab_path_file, dtype = 'deg', sinking_rate = False, timestep_years = 100000):
    """Measure how far, and how fast, one slab particle moved between consecutive records.

    The first three whitespace-separated columns of each non-blank line are read
    as ``radius colat lon`` (radius in metres, angles in radians, as implied by
    the conversions below).  For every pair of consecutive records the
    great-circle (haversine) distance is evaluated on a sphere whose radius is
    the mean of the two record radii.

    Parameters
    ----------
    slab_path_file : str
        Path of the slab track file.
    dtype : str, optional
        ``'deg'`` returns lateral rates in degrees per Myr; any other value
        returns them in cm/yr.
    sinking_rate : bool, optional
        When ``True`` the per-step vertical rates (cm/yr) are returned as a
        fifth element.
    timestep_years : float, optional
        Time between consecutive records, in years.

    Returns
    -------
    lat_dist : list
        Lateral rate for each step.
    advection_dist : float
        Total lateral distance in km (sum over steps).
    depth_list : list
        Depth in km of the midpoint radius of each step.
    sinking_dist : float
        Total vertical distance in km (sum of absolute radius changes).
    sinking_rate_list : list
        Only when ``sinking_rate == True``: vertical rate of each step in cm/yr.

    Raises
    ------
    ValueError
        If a non-blank line has fewer than three columns.
    """
    records = []

    # Read the numeric columns of every data line; blank lines are skipped
    with open(slab_path_file, 'r') as track_file:
        for text in track_file:
            fields = text.split()

            if len(fields) == 0:
                continue

            if len(fields) < 3:
                raise ValueError('malformed record in slab file: ' + slab_path_file)

            records.append([float(value) for value in fields[:3]])

    work_array = np.array(records)

    lat_dist = []
    depth_list = []

    total_lat_distance = []
    total_vert_distance = []

    sinking_rate_list = []

    for previous, current in zip(work_array[:-1], work_array[1:]):

        # Set radius = midpoint between each depth
        radius = (current[0] + previous[0])/2

        delta_d = abs(current[0] - previous[0])
        total_vert_distance.append(delta_d/1000)

        # Convert colatitude to latitude for both records
        lat1 = ((90*math.pi/180) - current[1])
        lon1 = current[2]

        lat2 = ((90*math.pi/180) - previous[1])
        lon2 = previous[2]

        del_lat = lat2 - lat1
        del_lon = lon2 - lon1

        # Haversine formula for the great-circle angle between the two records
        a = (math.sin(del_lat/2)**2) + math.cos(lat1)*math.cos(lat2)*(math.sin(del_lon/2)**2)

        # Rounding can push 'a' marginally outside [0, 1], which would make sqrt(1-a) fail
        a = min(1.0, max(0.0, a))

        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))

        # Distance in km
        dist = (radius*c)/1000
        total_lat_distance.append(dist)

        # km -> cm, divided by the length of the time step in years --> cm/yr
        dist_rate = (dist*1000*100)/timestep_years

        # cm/yr -> deg/myr: arc length l = theta*r --> theta = l/r, with the
        # radius converted from m to cm and the rate from per year to per myr
        dist_rate_deg = (np.degrees((dist_rate/(radius*100))))*1000000

        # Convert radius to depth (subtract from earth radius) and m to km
        depth_list.append((6371000-radius)/1000)

        if dtype == 'deg':
            lat_dist.append(dist_rate_deg)
        else:
            lat_dist.append(dist_rate)

        if sinking_rate == True:
            # Radius change in m per time step --> cm/yr
            sinking_rate_list.append((delta_d/timestep_years)*100)

    # Sum the per-step distances to get the totals travelled
    advection_dist = sum(total_lat_distance)

    sinking_dist = sum(total_vert_distance)

    if sinking_rate == True:
        return(lat_dist, advection_dist, depth_list, sinking_dist, sinking_rate_list)

    return(lat_dist, advection_dist, depth_list, sinking_dist)
        
def displacement_list(slab_path_dir, slab_file_list = [], dtype = 'deg', sinking_rate_calc = False):
    """Run ``lateral_advection_distance`` on a set of slab track files and collect the results.

    Parameters
    ----------
    slab_path_dir : str
        Directory of track files; read (in sorted order) only when
        ``slab_file_list`` is empty.
    slab_file_list : list of str, optional
        Explicit track file paths to process instead of the directory.
    dtype : str, optional
        Passed through to ``lateral_advection_distance``.
    sinking_rate_calc : bool, optional
        When ``True`` the per-step vertical rates of each slab are collected
        and returned as a fifth element.

    Returns
    -------
    tuple
        ``(lateral rate lists, depth lists, total lateral distances,
        total vertical distances)``, each with one entry per file, followed by
        the vertical rate lists when ``sinking_rate_calc == True``.
    """
    # An explicit file list takes priority over the directory listing
    if len(slab_file_list) == 0:
        track_files = [os.path.join(slab_path_dir, name) for name in sorted(os.listdir(slab_path_dir))]
    else:
        track_files = slab_file_list

    lateral_rates_per_slab = []
    depths_per_slab = []
    total_lateral_per_slab = []
    total_vertical_per_slab = []
    vertical_rates_per_slab = []

    for track_file in track_files:
        if sinking_rate_calc == False:
            summary = lateral_advection_distance(track_file, dtype)
            rates, total_lateral, depths, total_vertical = summary
        else:
            summary = lateral_advection_distance(track_file, dtype, sinking_rate = sinking_rate_calc)
            rates, total_lateral, depths, total_vertical, vertical_rates = summary
            vertical_rates_per_slab.append(vertical_rates)

        lateral_rates_per_slab.append(rates)
        depths_per_slab.append(depths)
        total_lateral_per_slab.append(total_lateral)
        total_vertical_per_slab.append(total_vertical)

    if sinking_rate_calc == True:
        return(lateral_rates_per_slab, depths_per_slab, total_lateral_per_slab, total_vertical_per_slab, vertical_rates_per_slab)

    return(lateral_rates_per_slab, depths_per_slab, total_lateral_per_slab, total_vertical_per_slab)


def smooth_lat_curve(lateral_motion_rate_lists, depth_lists):
    """Resample each slab's lateral advection rates at common depths.

    The sample depths run from 250 upward in steps of 100, stopping below the
    slab's greatest depth.  They are visited in order: when two consecutive
    measurements straddle the current sample depth, the mean of their rates is
    recorded and the next sample depth becomes current.

    Parameters
    ----------
    lateral_motion_rate_lists : list of list
        Per-slab lists of lateral advection rates.
    depth_lists : list of list
        Per-slab lists of the depths at which those rates were measured.

    Returns
    -------
    mean_lists : list of list
        Per-slab resampled rates.  A list may be shorter than its sample depth
        array when some sample depths are never straddled.
    sample_depth_lists : list of numpy.ndarray
        Per-slab arrays of the sample depths that were looked for.
    """
    mean_lists = []
    sample_depth_lists = []

    ## Iterate over each slab's lateral advection rate list
    for x, work_rate in enumerate(lateral_motion_rate_lists):
        # find the current loop iteration's corresponding depth list
        work_depths = depth_lists[x]

        # begin at 250km, increase in 100km increments to the greatest depth that is less than the deepest point of the slab
        # (a slab without any measurement has nothing to sample)
        if len(work_depths) != 0:
            sample_depths = np.arange(250, max(work_depths), 100)
        else:
            sample_depths = np.arange(0)

        mean_list = [] # averaged rates at each sample depth that was found

        # Index of the sample depth currently being looked for
        count = 0

        # Only pairs that exist in both lists can be averaged
        n_pairs = min(len(work_depths), len(work_rate)) - 1

        for i in range(n_pairs):
            # Every sample depth has been found
            if count >= len(sample_depths):
                break

            target = sample_depths[count]

            if work_depths[i] < target and work_depths[i+1] > target:
                mean_list.append((work_rate[i]+work_rate[i+1])/2)

                count = count+1

        mean_lists.append(mean_list)
        sample_depth_lists.append(sample_depths)

    # Returns: mean_lists = each slab's list of resampled rates; sample_depth_lists = each mean_list corresponding list of sampled depths
    return(mean_lists, sample_depth_lists)


def mean_lat_advection(smoothed_lat_lists, sampled_depths_lists):
    """Average the resampled lateral advection rates of all slabs, position by position.

    Parameters
    ----------
    smoothed_lat_lists : list of list
        Per-slab lists of resampled rates; entry ``i`` of every list is
        treated as belonging to the same sample depth.
    sampled_depths_lists : list of sequence
        Per-slab sample depths.

    Returns
    -------
    lat_mean : list
        Mean rate at each position, taken over the slabs that have a value there.
    deepest_list : sequence
        The longest entry of ``sampled_depths_lists`` (``[]`` when either input
        is empty).
    """
    if len(smoothed_lat_lists) == 0 or len(sampled_depths_lists) == 0:
        return([], [])

    # max_length = length of the longest rate list (deepest slab)
    # deepest_list = the list of depths for the deepest slab
    max_length = max(len(sub_list) for sub_list in smoothed_lat_lists)
    deepest_list = max(sampled_depths_lists, key=len)

    lat_mean = []

    for i in range(max_length):
        # Slabs that do not reach position 'i' simply do not contribute
        rates_at_depth = [sub_list[i] for sub_list in smoothed_lat_lists if i < len(sub_list)]

        lat_mean.append(sum(rates_at_depth)/len(rates_at_depth))

    # Returns: lat_mean = list of the average advection rate for each depth; deepest_list = depths of the deepest slab
    return(lat_mean, deepest_list)


######################################################################
######################################################################


def slab_depth_at_time(slab_path_file, max_time, timestep_years = 100000):
    """Convert one slab track file into an array of positions, ages and sinking rates.

    The first four whitespace-separated columns of each non-blank line are read
    as ``radius colat lon time`` (radius in metres, angles in radians, as
    implied by the conversions below).  Further columns are ignored.

    Parameters
    ----------
    slab_path_file : str
        Path of the slab track file.
    max_time : float
        Value the file's time column is subtracted from; the absolute
        difference is stored as the record's time.
    timestep_years : float, optional
        Time between consecutive records in years, used for the sinking rate.

    Returns
    -------
    numpy.ndarray
        One ``[lon_deg, lat_deg, depth_km, time, sinking_rate]`` row per record.
        The sinking rate is the signed depth change since the previous record
        in cm/yr, and 0 for the first record.
    """
    list_of_lines = []

    with open(slab_path_file, 'r') as track_file:

        for text in track_file:

            fields = text.split()

            # Skip blank lines
            if len(fields) == 0:
                continue

            depth = 6371 - (float(fields[0])/1000)
            lat = 90 - (np.degrees(float(fields[1])))
            lon = np.degrees(float(fields[2]))
            age = abs(max_time - float(fields[3]))

            # Depth change (km) since the previous record --> cm/yr
            # The first record has no predecessor, so its rate is set to 0
            if len(list_of_lines) != 0:
                sink_rate = ((depth - list_of_lines[-1][2]) / timestep_years)*1000*100
            else:
                sink_rate = 0.

            list_of_lines.append([lon, lat, depth, age, sink_rate])

    # Turn the list of lines into an array
    work_array = np.array(list_of_lines)

    return(work_array)


def slabs_data(slab_path_dir, slab_file_list = []):
    """Build the sinking profile of every slab particle.

    Parameters
    ----------
    slab_path_dir : str
        Directory of track files; used (in sorted order) only when
        ``slab_file_list`` is empty.
    slab_file_list : list of str, optional
        Explicit track file paths to process instead of the directory.

    Returns
    -------
    numpy.ndarray
        The arrays returned by ``slab_depth_at_time``, one per file.  They are
        stacked into one regular array when they all share a shape, and held in
        a 1-D object array otherwise.  An empty array is returned when there
        are no files to read.

    Raises
    ------
    ValueError
        If the first file contains no non-blank line.
    """
    # Check to see if a list of slab files has been set, if not, a directory must be set
    if len(slab_file_list) != 0:
        fn_list = slab_file_list

    elif os.path.isdir(slab_path_dir) == True:
        fn_list = [os.path.join(slab_path_dir, name) for name in sorted(os.listdir(slab_path_dir))]

    else:
        return(np.array([]))

    # An existing but empty directory is treated like a missing one
    if len(fn_list) == 0:
        return(np.array([]))

    # From the first slab particle, find the time the slabs were tracked for:
    # the last column of its last non-blank line
    last_line = ''
    with open(fn_list[0], 'r') as track_file:
        for text in track_file:
            if text.strip():
                last_line = text

    fields = last_line.split()

    if len(fields) == 0:
        raise ValueError('no usable final record in slab file: ' + fn_list[0])

    max_time = float(fields[-1])

    # create an array from each slab particle file which details --> [lon, lat, depth, time, sinking rate]
    array_list = [slab_depth_at_time(fn, max_time) for fn in fn_list]

    # Profiles of different lengths cannot be stacked into a regular array, so
    # they are stored as the elements of an object array instead
    if len({np.shape(profile) for profile in array_list}) > 1:
        slab_paths = np.empty(len(array_list), dtype = object)
        for i, profile in enumerate(array_list):
            slab_paths[i] = profile
    else:
        slab_paths = np.array(array_list)

    return(slab_paths)
    

In [17]:
## Comments written for depth but works for any data recorded as a time series

def timeseries_mnstdv(input_array, var_ind):
    """Mean and one-sigma envelope of one variable across a set of slab time series.

    Parameters
    ----------
    input_array : sequence of numpy.ndarray
        One 2-D array per slab with one row per time step; column 3 is read as
        the time axis (taken from the first slab).
    var_ind : int
        Column index of the variable to summarise.

    Returns
    -------
    mean_depths : list
        Mean of the variable at each time step.
    std_dev_poly : list
        ``mean - std`` for every step followed by ``mean + std`` in reverse
        order, i.e. the outline of a closed polygon.
    time_list : list
        Time column of the first slab.
    time_2 : list
        ``time_list`` followed by its reverse, matching ``std_dev_poly``.
    """
    # One series of the chosen variable per slab
    per_slab_series = [((slab.T)[var_ind]).tolist() for slab in input_array]

    # Regroup by time step: entry 't' holds the value of every slab at step 't'
    values_by_step = list(zip(*per_slab_series))

    n_slabs = len(per_slab_series)

    mean_depths = [sum(step_values)/n_slabs for step_values in values_by_step]

    std_dev_list = [np.std(step_values) for step_values in values_by_step]

    # +-1 sigma either side of the mean
    minus_sigma = [mean - sigma for mean, sigma in zip(mean_depths, std_dev_list)]
    plus_sigma = [mean + sigma for mean, sigma in zip(mean_depths, std_dev_list)]

    time_list = ((input_array[0].T)[3]).tolist()

    # Walk forward along one edge and back along the other so the envelope can be plotted as a closed polygon
    std_dev_poly = minus_sigma + plus_sigma[::-1]
    time_2 = time_list + time_list[::-1]

    return(mean_depths, std_dev_poly, time_list, time_2)
    
    
##################################################################


# sample_depths = np.arange(range of depths to sample)

def d_vs_sr_stats(input_array, depth_ind, sr_ind):
    """Mean and one-sigma envelope of sinking rate as a function of depth.

    Sample depths run from 250 upward in steps of 100, stopping below the
    greatest depth found in any slab.  Whenever two consecutive records of a
    slab straddle a sample depth (in either direction), the mean of their two
    rates is counted as one observation at that depth.

    Parameters
    ----------
    input_array : sequence of numpy.ndarray
        One 2-D array per slab, one row per record.
    depth_ind : int
        Column index of the depth.
    sr_ind : int
        Column index of the sinking rate.

    Returns
    -------
    mean_sr_list : list
        Mean rate at each sample depth that has at least one observation.
    std_dev_poly : list
        ``mean - std`` for every such depth followed by ``mean + std`` in
        reverse order, i.e. the outline of a closed polygon.
    list_sd : list
        The sample depths corresponding to ``mean_sr_list``.  Depths without
        any observation are left out so the lists stay aligned.
    dpt2 : list
        ``list_sd`` followed by its reverse, matching ``std_dev_poly``.
    """
    ## each slab's depths and sinking rates
    slab_depths = [array[:,depth_ind].tolist() for array in input_array]
    slab_rates = [array[:,sr_ind].tolist() for array in input_array]

    maximum_depth = max([d for depths in slab_depths for d in depths])

    sample_depths = np.arange(250, maximum_depth, 100)

    list_sd = []
    mean_sr_list = []
    std_sr_list = []

    for td in sample_depths:

        # Rates observed at this sample depth, over all slabs
        rates_at_depth = []

        for depths, rates in zip(slab_depths, slab_rates):
            for x in range(len(depths)-1):
                sinks_through = depths[x] < td and depths[x+1] > td
                rises_through = depths[x] > td and depths[x+1] < td

                if sinks_through or rises_through:
                    rates_at_depth.append((rates[x]+rates[x+1])/2)

        # A depth that no slab crosses has no statistics and is not reported
        if len(rates_at_depth) == 0:
            continue

        list_sd.append(td)
        mean_sr_list.append(sum(rates_at_depth)/len(rates_at_depth))
        std_sr_list.append(np.std(rates_at_depth))

    # +- 1 sigma either side of the mean
    sdv_upper = [mean - sigma for mean, sigma in zip(mean_sr_list, std_sr_list)]
    sdv_lower = [mean + sigma for mean, sigma in zip(mean_sr_list, std_sr_list)]

    # Add std dev lists together (reverse second) to plot poly, make depth list to match
    std_dev_poly = sdv_upper+sdv_lower[::-1]
    dpt2 = list_sd + list_sd[::-1]

    return(mean_sr_list, std_dev_poly, list_sd, dpt2)



# Plotting Slab Tracks

### Example code to plot the sinking paths of slab tracer particles


# Base map: global region ('d'), projection string 'N15c', with a grey-shaded grid and coastlines
fig = pygmt.Figure()

fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])
fig.grdimage(grid = '/home/robby/Desktop/slab_tracking/gld446-temp-000Ma-2677km.grd', region='d', projection = 'N15c', cmap = 'gray')
fig.coast(shorelines="0.5p,black", projection='N15c')

# Second grid drawn over the first, coloured with its own palette file
fig.grdimage(grid = '/home/robby/honours/comp/45perc_comp/comp_grds/2074_b_gld446.nc', projection = 'N15c', cmap = '/home/robby/honours/comp_palette.cpt')

# Line work from the reconstruction .xy file
fig.plot(data = '/home/robby/honours/tectonic_recons/reconstructed_200.00Ma.xy', projection = 'N15c', pen = '0.5p,black')

# Collect the path of every slab track file, in sorted order
fn_list = []

slab_path_dir = '/home/robby/Desktop/slab_tracking/200Ma_248km'

for file in sorted(os.listdir(slab_path_dir)):
    fn = os.path.join(slab_path_dir, file)
    
    fn_list.append(fn)

# Time the reading and plotting of all tracks
start = time.time()
    
for slab_path_file in fn_list:
    with open(slab_path_file, 'r') as file:
        line_list = file.readlines()
        
        # One row per record: [depth, colat, lon, time, lon_deg, lat_deg]
        list_of_lines = []
        for line in line_list:
            line = line.strip()
            
            # Split on spaces, dropping the empty strings left by repeated spaces
            line = list(filter(None, line.split(' ')))
            
            # First column: divided by 1000 and subtracted from 6371 (radius in m -> depth in km, presumably)
            line[0] = 6371 - (float(line[0])/1000)
            line[1] = float(line[1])
            line[2] = float(line[2])
            line[3] = float(line[3])
            
            # Third column converted from radians to degrees (longitude)
            line.append(line[2] * (180/math.pi))
        
            # Second column converted from radians to degrees and subtracted from 90 (colatitude -> latitude)
            line.append(90 - (line[1] * (180/math.pi)))
            
            list_of_lines.append(line)
            
        
        w_array = np.array(list_of_lines)
        
        # Columns 4, 5, 0 = lon_deg, lat_deg, depth: each record is drawn as a circle coloured by depth.
        # NOTE(review): the column indices assume every input line has exactly four columns
        fig.plot(data = w_array, incols = [4,5,0], cmap = '/home/robby/Desktop/slab_tracking/depth.cpt', style = 'c0.1', region = 'd')
        
end = time.time()
    
fig.show()

# Seconds spent reading and plotting the tracks
print(end - start)


# Operations on tracked slabs

In [20]:
# PLOT TERMINAL LOCATIONS FOR EXISTING AND THERMALLY ASSIMILATED SLABS AT DEPTH

def terminals_vs_tomo(list_of_depths, comp_dir, slabs_array, melted_array, comp_cmap, fig_dir, age_str):
    """Plot, per grid, the slab and melted terminals that sit on grid values 3 or 4.

    The first and last entries of ``list_of_depths``, ``slabs_array`` and
    ``melted_array`` are dropped; what remains is indexed in step with the
    sorted files of ``comp_dir``.  For each grid the terminals of the matching
    group are sampled with nearest-neighbour ``grdtrack`` and kept when the
    sampled value (last column) equals 3 or 4.  A figure is produced only for
    grids with at least one kept terminal, and saved only if no file of that
    name exists yet.

    Parameters
    ----------
    list_of_depths : list
        Depths; entry ``x`` of ``list_of_depths[1:-1]`` names the figure of grid ``x``.
    comp_dir : str
        Directory of grids to plot and sample.
    slabs_array, melted_array : sequence of numpy.ndarray
        Per-depth groups of terminals (plotted as red circles and yellow
        inverted triangles respectively).
    comp_cmap : str
        Colour palette used to draw each grid.
    fig_dir : str
        Directory the figures are written to.
    age_str : str
        Label inserted into each figure's file name.
    """
    d_list = list_of_depths[1:-1]

    fn_list = [os.path.join(comp_dir, name) for name in sorted(os.listdir(comp_dir))]

    slabs_array = slabs_array[1:-1]
    melted_array = melted_array[1:-1]

    # Positions of the depth groups that actually contain terminals
    slabs_ind = {i for i, group in enumerate(slabs_array) if len(group) != 0}
    melted_ind = {i for i, group in enumerate(melted_array) if len(group) != 0}

    def terminals_on_target(points, grd):
        # Sampled rows whose grid value is 3 or 4
        sampled = pygmt.grdtrack(points = points, grid = grd, interpolation = 'n').to_numpy()
        return [row for row in sampled if row[-1] in (3, 4)]

    for x, comp_grd in enumerate(fn_list):

        plot_slabs = terminals_on_target(slabs_array[x], comp_grd) if x in slabs_ind else []
        plot_melted = terminals_on_target(melted_array[x], comp_grd) if x in melted_ind else []

        # Nothing to show for this grid
        if len(plot_slabs) == 0 and len(plot_melted) == 0:
            continue

        working_fig = pygmt.Figure()

        working_fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])
        working_fig.grdimage(grid = comp_grd, projection = 'N15c', cmap = comp_cmap, interpolation = 'n+a')
        working_fig.coast(shorelines="0.5p,black", projection='N15c')

        if len(plot_slabs) > 0:
            working_fig.plot(data = np.array(plot_slabs), projection = 'N15c', style = 'c0.15c', incols = [0,1], color = 'red')

        if len(plot_melted) > 0:
            working_fig.plot(data = np.array(plot_melted), projection = 'N15c', style = 'i0.15c', incols = [0,1], color = 'yellow')

        fig_name_str = fig_dir+'/'+str(d_list[x])+'_'+age_str+'_tomo_v_term.png'

        # Never overwrite a figure that already exists
        if not os.path.exists(fig_name_str):
            working_fig.savefig(fig_name_str)
   

In [21]:
# PLOT WORMS FOR TERMINALS WHICH MATCH TOMOGRAPHY

def term_v_tomo_dict(in_slab_dict, comp_dir, list_of_depths, comp_cmap, fig_dir, age_str, plot_option = 'slab',
                     depth_cmap = '/media/robby/arbiter/honours/colour_maps/depth.cpt'):
    """Plot the full tracks ("worms") of slabs whose terminal point matches tomography.

    For every comparison grid in ``comp_dir`` (sorted file-name order) the terminal
    points stored for that depth are sampled on the grid with ``pygmt.grdtrack``.
    A terminal whose sampled grid value is 3 or 4 counts as a match; the track file
    of every matching terminal is read and plotted over the grid, coloured with
    ``depth_cmap``. One PNG is written per grid that has at least one match.
    Existing PNGs are never overwritten.

    Parameters
    ----------
    in_slab_dict : sequence of dict
        One dictionary per entry of ``list_of_depths``; the first and last entries
        are dropped. Keys are track file paths (they are opened for reading), values
        are the terminal points passed to ``grdtrack``
        (presumably [lon, lat, depth, ...] -- confirm against the caller).
    comp_dir : str
        Directory holding one comparison grid per retained depth.
    list_of_depths : sequence
        Depth labels; the first and last are dropped, the rest name the figures.
    comp_cmap : str
        Colour map used for the comparison grid image.
    fig_dir : str
        Directory the figures are written to.
    age_str : str
        Label inserted in the figure file name.
    plot_option : str, optional
        Label inserted in the figure file name (default ``'slab'``).
    depth_cmap : str, optional
        Colour map used to colour the track points; defaults to the path that was
        previously hard-coded.

    Notes
    -----
    Each non-blank line of a track file must start with three numeric columns.
    They are converted as: column 0 -> ``6371 - value/1000``, column 1 ->
    ``90 - degrees(value)``, column 2 -> ``degrees(value)`` (presumably radius in
    metres, colatitude and longitude in radians -- confirm with the track writer).
    """

    # Cut out top and bottom entries so depths/dictionaries line up with the comparison grids
    d_list = list_of_depths[1:-1]
    slab_dict = in_slab_dict[1:-1]

    fn_list = [os.path.join(comp_dir, file) for file in sorted(os.listdir(comp_dir))]

    for x, comp_grd in enumerate(fn_list):

        # Nothing to do for grids without any terminal at this depth
        if x >= len(slab_dict) or len(slab_dict[x]) == 0:
            continue

        work_dict = slab_dict[x]
        fn_keys = list(work_dict.keys())
        tarray = np.array(list(work_dict.values()))

        work_array = pygmt.grdtrack(points = tarray, grid = comp_grd, interpolation = 'n')
        work_array = work_array.to_numpy()

        # Keep the track files whose terminal sits on a grid value of 3 or 4
        plot_slabs = [fn_keys[i] for i, point in enumerate(work_array) if point[-1] == 3 or point[-1] == 4]

        if len(plot_slabs) == 0:
            continue

        fig_name_str = fig_dir+'/'+str(d_list[x])+'_'+age_str+'_'+plot_option+'_tracks_v_tomo.png'

        # Existing figures are never overwritten, so skip the rendering work entirely
        if os.path.exists(fig_name_str):
            continue

        working_fig = pygmt.Figure()

        working_fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])
        working_fig.grdimage(grid = comp_grd, projection = 'N15c', cmap = comp_cmap, interpolation = 'n+a')
        working_fig.coast(shorelines="0.5p,black", projection='N15c')

        for slab_path_file in plot_slabs:
            track_rows = []
            with open(slab_path_file, 'r') as file:
                for line in file:
                    fields = line.split()

                    # Blank lines (e.g. a trailing newline) carry no track point
                    if len(fields) == 0:
                        continue

                    depth_km = 6371 - (float(fields[0])/1000)
                    lat_deg = 90 - (float(fields[1]) * (180/math.pi))
                    lon_deg = float(fields[2]) * (180/math.pi)

                    # Only the plotted columns are kept so the array is always purely numeric
                    track_rows.append([lon_deg, lat_deg, depth_km])

            if len(track_rows) == 0:
                continue

            plot_array = np.array(track_rows)

            # Columns are already ordered x, y, colour value, so no incols remapping is needed
            working_fig.plot(data = plot_array, cmap = depth_cmap, style = 'c0.1', region = 'd')

        working_fig.savefig(fig_name_str)
                    
                    
################################################################3


def subset_term_v_tomo(list_of_depths, depth_filter, comp_dir, slabs_array, melted_array, comp_cmap, fig_dir, age_str):
    """Plot slab and melted-slab terminals that match tomography for a subset of depths.

    Only the comparison grids whose depth appears in ``depth_filter`` are used.
    For each of them the slab and melted points are sampled on the grid with
    ``pygmt.grdtrack``; points with a sampled value of 3 or 4 are plotted
    (slabs as red circles, melted slabs as yellow inverted triangles).
    Existing PNGs are never overwritten.

    Parameters
    ----------
    list_of_depths : sequence
        Depth labels; the first and last are dropped before matching grids.
    depth_filter : container
        Depths to plot; membership is tested with ``in``.
    comp_dir : str
        Directory holding the comparison grids (used in sorted file-name order).
    slabs_array, melted_array : sequence
        Point sets indexed by position within the filtered depths (0 for the first
        depth that passes the filter, and so on), not by grid index.
        Each entry is passed to ``grdtrack`` as ``points``; its first two columns
        are plotted as x and y.
    comp_cmap : str
        Colour map for the comparison grid image.
    fig_dir : str
        Output directory.
    age_str : str
        Label inserted in the figure file name.
    """

    d_list = list_of_depths[1:-1]

    fn_list = [os.path.join(comp_dir, file) for file in sorted(os.listdir(comp_dir))]

    comp_ind = [x for x, depth in enumerate(d_list) if depth in depth_filter]

    def tomo_matches(points, grid):
        # Return the sampled rows whose grid value is 3 or 4 (empty array when there are none)
        if len(points) == 0:
            # grdtrack has nothing to sample, so skip the call instead of feeding it empty input
            return np.array([])
        sampled = pygmt.grdtrack(points = points, grid = grid, interpolation = 'n')
        sampled = sampled.to_numpy()
        return np.array([point for point in sampled if point[-1] == 3 or point[-1] == 4])

    for i, x in enumerate(comp_ind):

        comp_grd = fn_list[x]

        plot_slabs = tomo_matches(slabs_array[i], comp_grd)
        plot_melted = tomo_matches(melted_array[i], comp_grd)

        lps = len(plot_slabs)
        lpm = len(plot_melted)

        if lps+lpm == 0:
            continue

        fig_name_str = fig_dir+'/'+str(d_list[x])+'_'+age_str+'_tomo_v_term.png'

        # Existing figures are never overwritten, so do not render them either
        if os.path.exists(fig_name_str):
            continue

        working_fig = pygmt.Figure()

        working_fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])
        working_fig.grdimage(grid = comp_grd, projection = 'N15c', cmap = comp_cmap, interpolation = 'n+a')
        working_fig.coast(shorelines="0.5p,black", projection='N15c')

        if lps != 0:
            working_fig.plot(data = plot_slabs, projection = 'N15c', style = 'c0.15c', incols = [0,1], color = 'red')

        if lpm != 0:
            working_fig.plot(data = plot_melted, projection = 'N15c', style = 'i0.15c', incols = [0,1], color = 'yellow')

        working_fig.savefig(fig_name_str)


In [22]:
### CREATE MAPS WHICH SHOW ALL TERMINALS MATCHING TOMOGRAPHY AT X DEPTH COLOUR CODED BY AGE

def slab_terms_all_ages(all_slab_dicts, all_melt_dicts, comp_dir, list_of_depths, age_key, fig_dir, comp_cpt, age_cpt):
    """Plot, per comparison grid, every terminal that matches tomography, colour coded by age.

    For each grid in ``comp_dir`` (sorted file-name order) the terminals of every
    age are sampled with ``pygmt.grdtrack``. Terminals with a sampled grid value of
    3 or 4 are kept and tagged with ``age_key[age_index]``. Slab terminals are drawn
    as circles and melted terminals as inverted triangles, both coloured with
    ``age_cpt``. A figure is saved (overwriting any existing file) for every grid
    with at least one match.

    Parameters
    ----------
    all_slab_dicts, all_melt_dicts : sequence of sequence of dict
        One entry per age; each entry holds one dictionary per depth in
        ``list_of_depths`` (first and last are dropped). Dictionary values are the
        terminal points handed to ``grdtrack``; their first two columns are plotted
        as x and y (presumably lon, lat -- confirm against the caller).
    comp_dir : str
        Directory holding one comparison grid per retained depth.
    list_of_depths : sequence
        Depth labels; the first and last are dropped, the rest name the figures.
    age_key : sequence
        Age value for each entry of ``all_slab_dicts`` / ``all_melt_dicts``.
    fig_dir : str
        Output directory.
    comp_cpt : str
        Colour map for the comparison grid image.
    age_cpt : str
        Colour map for the age-coded symbols and the colour bar.
    """

    ## Cut out top and bottom entries for depth list and all dictionaries
    d_list = list_of_depths[1:-1]

    all_slab_dicts = [dct[1:-1] for dct in all_slab_dicts]
    all_melt_dicts = [dct[1:-1] for dct in all_melt_dicts]

    comp_fn_list = [os.path.join(comp_dir, file) for file in sorted(os.listdir(comp_dir))]

    def age_coded_matches(dicts_by_age, x, comp_grid):
        # Collect [x, y, age] rows for every terminal at depth index x that matches the grid
        rows = []
        for age, depth_dicts in enumerate(dicts_by_age):
            work_dict = depth_dicts[x]
            if len(work_dict) == 0:
                continue

            tarray = np.array(list(work_dict.values()))

            # check grid value for terminals; grdtrack appends it as the last column
            work_array = pygmt.grdtrack(points = tarray, grid = comp_grid, interpolation = 'n')
            work_array = work_array.to_numpy()

            for point in work_array:
                if point[-1] == 3 or point[-1] == 4:
                    # Keep only what is plotted so the colour column does not
                    # depend on how many columns the input points carry
                    rows.append([point[0], point[1], age_key[age]])
        return np.array(rows)

    for x, comp_grid in enumerate(comp_fn_list):

        plot_slab_terms = age_coded_matches(all_slab_dicts, x, comp_grid)
        plot_melt_terms = age_coded_matches(all_melt_dicts, x, comp_grid)

        slab_flag = len(plot_slab_terms)
        melt_flag = len(plot_melt_terms)

        if slab_flag + melt_flag == 0:
            continue

        work_fig = pygmt.Figure()

        work_fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])
        work_fig.grdimage(grid = comp_grid, cmap = comp_cpt, projection="N15c", region = 'd', interpolation='n+a')
        work_fig.coast(shorelines="0.5p,black", projection='N15c')

        if slab_flag != 0:
            work_fig.plot(data = plot_slab_terms,
                          projection = 'N15c',
                          style = 'c0.15c',
                          pen = '0.25p,black',
                          cmap = age_cpt)

        if melt_flag != 0:
            work_fig.plot(data = plot_melt_terms,
                          projection = 'N15c',
                          style = 'i0.15c',
                          pen = '0.25p,black',
                          cmap = age_cpt)

        work_fig.colorbar(cmap = age_cpt)

        fig_name = fig_dir+'/'+str(d_list[x])+'_age_coded_terminals.png'

        work_fig.savefig(fig_name)
            

#######################################################################################
            
# slab_array_list = list of arrays --> each array contains the existing slabs for a slab tracking start time filtered into depth categories

def terminals_age_coded(comp_dir, slab_array_list, melt_array_list, age_list, depth_filter, fig_dir):
    """Plot age-coded slab and melted-slab points that match tomography for selected depths.

    For every comparison grid whose depth is in ``depth_filter`` the points of each
    start age are sampled on the grid with ``pygmt.grdtrack``. Points with a sampled
    value of 3 or 4 are plotted, coloured by their start age (slabs as circles,
    melted slabs as inverted triangles). One figure is saved per depth that has at
    least one matching point.

    Parameters
    ----------
    comp_dir : str
        Directory holding the comparison grids (used in sorted file-name order).
    slab_array_list, melt_array_list : sequence
        One entry per slab tracking start time; each entry is indexed by position
        within the filtered depths and yields the points handed to ``grdtrack``
        (first two columns are plotted as x and y).
    age_list : sequence
        Age value for each entry of ``slab_array_list`` / ``melt_array_list``.
    depth_filter : container
        Depths to plot; membership is tested with ``in``.
    fig_dir : str
        Output directory.

    Notes
    -----
    Depth labels come from the module-level ``list_of_depths`` (first and last
    entries dropped), not from a parameter.
    NOTE(review): a later cell defines another function with this same name, which
    replaces this one once that cell has been run -- confirm which one is wanted.
    """

    comp_cpt = '/media/robby/arbiter/honours/colour_maps/comp_palette.cpt'
    age_cpt = '/media/robby/arbiter/honours/colour_maps/terminals_age.cpt'

    # Cut out top and bottom depths
    comp_d_list = list_of_depths[1:-1]

    fn_list = [os.path.join(comp_dir, file) for file in sorted(os.listdir(comp_dir))]

    # find indices for comparison grids which occur in depth filter
    comp_ind = [x for x, depth in enumerate(comp_d_list) if depth in depth_filter]

    def age_coded_matches(array_list, i, grid):
        # Collect [x, y, age] rows for the points of every start age that match the grid
        rows = []
        for x, at_age_array in enumerate(array_list):
            # at_age_array = points for one start age, split by depth; i selects the current depth
            w_array = at_age_array[i]
            if len(w_array) == 0:
                # nothing to sample for this age at this depth
                continue

            # Plain float array: object-dtype input is not reliably accepted by grdtrack
            sampled = pygmt.grdtrack(points = np.asarray(w_array, dtype = float), grid = grid, interpolation = 'n')
            sampled = sampled.to_numpy()

            for point in sampled:
                if point[-1] == 3 or point[-1] == 4:
                    rows.append([point[0], point[1], age_list[x]])
        return np.array(rows)

    for i, d_ind in enumerate(comp_ind):

        plot_exist = age_coded_matches(slab_array_list, i, fn_list[d_ind])
        plot_melt = age_coded_matches(melt_array_list, i, fn_list[d_ind])

        # Skip depths without any match; plotting an empty table would fail
        if len(plot_exist) + len(plot_melt) == 0:
            continue

        work_fig = pygmt.Figure()

        work_fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])

        work_fig.grdimage(grid = fn_list[d_ind], projection = 'N15c', cmap = comp_cpt, interpolation = 'n+a')

        work_fig.coast(shorelines="0.5p,black", projection='N15c')

        if len(plot_exist) != 0:
            work_fig.plot(data = plot_exist, projection = 'N15c', style = 'c0.15c', cmap = age_cpt)

        if len(plot_melt) != 0:
            work_fig.plot(data = plot_melt, projection = 'N15c', style = 'i0.15c', cmap = age_cpt)

        # The colour scale encodes age, so label it as such (was labelled as slab depth in km)
        work_fig.colorbar(cmap = age_cpt, frame = ["a40", 'x+l"age"', 'y+lmyr'])

        # The file name was previously empty (fig_dir + '/'), which cannot be saved
        figure_save_str = fig_dir+'/'+str(comp_d_list[d_ind])+'_age_coded_terminals.png'

        work_fig.savefig(figure_save_str)
            

In [23]:
def terminals_age_coded(all_slab_match_dicts, all_melt_match_dicts, comp_dir, fig_dir, reconstruction, list_of_depths, comp_cmap, age_cmap,
                        age_key = [40,80,120,160,200,240,280], cb_truncate = None):
    """Plot, per comparison grid, the matched slab and melted-slab points colour coded by age.

    The values of every age dictionary are stacked into an array and passed through
    ``slab_filter_depths`` (defined elsewhere in this notebook; presumably it splits
    the points into one group per entry of ``list_of_depths`` -- confirm). The first
    and last groups are dropped so the remaining groups line up with the grids in
    ``comp_dir`` (sorted file-name order). For each grid with at least one point, the
    points are tagged with ``age_key[age_index]`` and plotted: slabs as circles,
    melted slabs as inverted triangles. The figure is always saved (overwriting).

    Parameters
    ----------
    all_slab_match_dicts, all_melt_match_dicts : sequence of dict
        One dictionary per age; only the values are used. After the age column is
        appended, columns 0, 1 and 4 are plotted as x, y and colour value.
    comp_dir : str
        Directory holding the comparison grids.
    fig_dir : str
        Output directory.
    reconstruction : str
        Prefix of the figure file name.
    list_of_depths : sequence
        Depth labels; first and last are dropped, the rest name the figures.
    comp_cmap, age_cmap : str
        Colour maps for the grid image and for the age-coded symbols / colour bar.
    age_key : sequence, optional
        Age value for each dictionary, by position.
    cb_truncate : None or True, optional
        ``None`` draws the full colour bar, ``True`` truncates it to [40, 160];
        any other value draws no colour bar.
    """

    slab_points_by_age = [np.array(list(age_dict.values())) for age_dict in all_slab_match_dicts]
    melt_points_by_age = [np.array(list(age_dict.values())) for age_dict in all_melt_match_dicts]

    slab_groups_by_age = [slab_filter_depths(points, list_of_depths) for points in slab_points_by_age]
    melt_groups_by_age = [slab_filter_depths(points, list_of_depths) for points in melt_points_by_age]

    comp_grds = [os.path.join(comp_dir, name) for name in sorted(os.listdir(comp_dir))]

    # Drop the top and bottom depth so labels and groups line up with the grids
    working_depths = list_of_depths[1:-1]
    slab_groups_by_age = [groups[1:-1] for groups in slab_groups_by_age]
    melt_groups_by_age = [groups[1:-1] for groups in melt_groups_by_age]

    def with_age_column(points_per_age):
        # Append the age of each non-empty point set as a last column and flatten into one array of rows
        tagged_rows = []
        for age_index, points in enumerate(points_per_age):
            if len(points) != 0:
                ages = [age_key[age_index]]*len(points)
                for row in np.c_[points, ages]:
                    tagged_rows.append(row)
        return np.array(tagged_rows)

    for grid_index, comp_grd in enumerate(comp_grds):

        slabs_here = [groups[grid_index] for groups in slab_groups_by_age]
        melts_here = [groups[grid_index] for groups in melt_groups_by_age]

        slab_total = sum([len(points) for points in slabs_here])
        melt_total = sum([len(points) for points in melts_here])

        if (slab_total+melt_total) == 0:
            continue

        working_fig = pygmt.Figure()

        working_fig.basemap(region='d', projection="N15c", frame=["af", "WSne"])
        working_fig.grdimage(grid = comp_grd, projection = 'N15c', cmap = comp_cmap, interpolation = 'n+a')
        working_fig.coast(shorelines="0.5p,black", projection='N15c')

        if slab_total != 0:
            plot_exist = with_age_column(slabs_here)
            working_fig.plot(data = plot_exist, projection = 'N15c', style = 'c0.2c', pen='0.1,black', incols = [0,1,4], cmap = age_cmap)

        if melt_total != 0:
            plot_melt = with_age_column(melts_here)
            working_fig.plot(data = plot_melt, projection = 'N15c', style = 'i0.2c', pen='0.1,black', incols = [0,1,4], cmap = age_cmap)

        colorbar_frame = ["a40", 'x+l"age"', 'y+lmyr']
        if cb_truncate is None:
            working_fig.colorbar(cmap = age_cmap, frame = colorbar_frame)
        elif cb_truncate == True:
            working_fig.colorbar(cmap = age_cmap, frame = colorbar_frame, truncate = [40,160])

        fig_name = fig_dir+'/'+reconstruction+'_'+str(working_depths[grid_index])+'_age_coded_terminals.png'

        working_fig.savefig(fig_name)

In [24]:
### CREATE LIST OF SLABS MATCHING TOMOGRAPHY TO BE PLOTTED AS WORMS 

def working_slab_tracks(in_slab_dict, in_melt_dict, comp_dir):
    """List, per comparison grid, the slab and melt track files whose terminal matches tomography.

    The first and last dictionaries of both inputs are dropped so the remaining
    ones line up with the grids in ``comp_dir`` (sorted file-name order). For each
    grid the dictionary values are sampled with ``pygmt.grdtrack``; a key is kept
    when the sampled grid value (last column) is 3 or 4.

    Parameters
    ----------
    in_slab_dict, in_melt_dict : sequence of dict
        One dictionary per depth; keys identify the slab (file names, per the
        original comments) and values are the terminal points.
    comp_dir : str
        Directory holding the comparison grids.

    Returns
    -------
    tuple of (list, list)
        ``(plot_slab_worms, plot_melt_worms)``; each holds one list of matching
        keys per grid, empty when the grid has no terminals or no matches.
    """

    slab_dict_list = in_slab_dict[1:-1]
    melt_dict_list = in_melt_dict[1:-1]

    def matching_keys(terminals, grid_file):
        # Keys of the terminals that sit over a tomography slab (grid value 3 or 4)
        if len(terminals) == 0:
            return []

        names = list(terminals.keys())
        points = np.array(list(terminals.values()))

        sampled = pygmt.grdtrack(points = points, grid = grid_file, interpolation = 'n').to_numpy()

        return [names[row_ind] for row_ind, row in enumerate(sampled) if row[-1] in (3, 4)]

    grid_files = [os.path.join(comp_dir, name) for name in sorted(os.listdir(comp_dir))]

    plot_slab_worms = []
    plot_melt_worms = []
    for depth_ind, grid_file in enumerate(grid_files):
        plot_slab_worms.append(matching_keys(slab_dict_list[depth_ind], grid_file))
        plot_melt_worms.append(matching_keys(melt_dict_list[depth_ind], grid_file))

    return(plot_slab_worms, plot_melt_worms)

In [25]:
## Plot slabs vs subduction zones based on above/below checks

def plot_initial_slabs(list_of_slab_files):
    """Build [initial_lon, initial_lat, terminal_depth] rows from slab track files.

    Parameters
    ----------
    list_of_slab_files : list
        Either a flat list of track file paths or a list of such lists (detected
        by checking whether the first element is a ``list``).

    Returns
    -------
    list of numpy.ndarray
        One array per file list (a single array for a flat list). Each row is
        ``[lon, lat, depth]`` where lon and lat come from the first non-blank line
        of the file (``column 2 * 180/pi`` and ``90 - column 1 * 180/pi``) and depth
        from the last non-blank line (``6371 - column 0 / 1000``). An empty file
        list gives an empty array; an empty input gives ``[empty array]``.

    Raises
    ------
    IndexError
        If a track file has no non-blank line or too few columns.
    ValueError
        If a required column is not numeric.

    Notes
    -----
    Presumably the columns are radius in metres, colatitude and longitude in
    radians -- confirm with whatever writes the track files.
    """

    def start_and_terminal(track_file):
        # Return [initial_lon, initial_lat, terminal_depth] for one track file
        with open(track_file, 'r') as slab_file:
            # Ignore blank lines so a trailing newline does not become the "last line"
            lines = [line for line in slab_file if line.strip()]

        first_line = lines[0].split()
        last_line = lines[-1].split()

        return [float(first_line[2]) * 180/math.pi,
                90 - (float(first_line[1]) * (180/math.pi)),
                6371 - (float(last_line[0])/1000)]

    # A flat list of files is treated as a single group; empty input is an empty flat list
    if len(list_of_slab_files) != 0 and isinstance(list_of_slab_files[0], list):
        file_groups = list_of_slab_files
    else:
        file_groups = [list_of_slab_files]

    # One array per group giving the start location and end depth of each slab
    plot_initial_arrays = [np.array([start_and_terminal(file) for file in slab_file_list])
                           for slab_file_list in file_groups]

    return(plot_initial_arrays)

##################################################################
##################################################################

### PLOT SLABS VS SUBDUCTION ZONES USING NEAREST DEPTH CHECK

def slabs_vs_subzones(in_slab_dict, in_melt_dict, comp_dir, list_of_depths):
    """Split slab and melt tracks by whether their terminal matches tomography and return their start points.

    For each comparison grid in ``comp_dir`` (sorted file-name order) the terminals
    stored for that depth are sampled with ``pygmt.grdtrack``. A track whose sampled
    grid value is 3 or 4 is a match; every other track is a "problem" track. The
    four resulting file lists are converted with ``plot_initial_slabs``.

    Parameters
    ----------
    in_slab_dict, in_melt_dict : sequence of dict
        One dictionary per depth mapping a track file path to its terminal point;
        the first and last dictionaries are dropped.
    comp_dir : str
        Directory holding the comparison grids.
    list_of_depths : sequence
        Not used by the computation; kept so existing calls keep working.

    Returns
    -------
    list of numpy.ndarray
        Four arrays in the order: matching slabs, problem slabs, matching melts,
        problem melts. Each row is ``[initial_lon, initial_lat, terminal_depth]``
        as produced by ``plot_initial_slabs``.
    """

    slab_dict = in_slab_dict[1:-1]
    melt_dict = in_melt_dict[1:-1]

    fn_list = [os.path.join(comp_dir, file) for file in sorted(os.listdir(comp_dir))]

    def split_by_match(work_dict, comp_grd, matched, unmatched):
        # Append each track file to `matched` or `unmatched` depending on the grid value under its terminal
        fn_keys = list(work_dict.keys())
        tarray = np.array(list(work_dict.values()))

        # check grid value for terminals; grdtrack appends it as the last column
        work_array = pygmt.grdtrack(points = tarray, grid = comp_grd, interpolation = 'n')
        work_array = work_array.to_numpy()

        for i, point in enumerate(work_array):
            if point[-1] == 3 or point[-1] == 4:
                matched.append(fn_keys[i])
            else:
                unmatched.append(fn_keys[i])

    plot_slabs = []
    plot_problem_slabs = []

    plot_melt = []
    plot_problem_melt = []

    for x, comp_grd in enumerate(fn_list):

        if x < len(slab_dict) and len(slab_dict[x]) != 0:
            split_by_match(slab_dict[x], comp_grd, plot_slabs, plot_problem_slabs)

        if x < len(melt_dict) and len(melt_dict[x]) != 0:
            split_by_match(melt_dict[x], comp_grd, plot_melt, plot_problem_melt)

    plot_sub_zones = [plot_slabs, plot_problem_slabs, plot_melt, plot_problem_melt]

    ## The first element is always a list, so plot_initial_slabs returns one
    ## [lon, lat, depth] array per category, in the order given above
    plot_initial_points = plot_initial_slabs(plot_sub_zones)

    return(plot_initial_points)

# AOU Slab Comparison

In [None]:
## READS IN Atlas of the Underworld DATA AND EXTRACTS THE RELEVANT DATA

# Final aou_array columns = depth, lon, lat, base_depth, error, top_depth, error
# first 3 columns = midpoint data

# Load the whole CSV as a numpy array (mixed text/number content)
aou_array_pd = (pd.read_csv('/media/robby/arbiter/honours/others_data/atlas_of_the_underworld/aou_data.csv')).to_numpy()

# Drop the first three rows of the table
# NOTE(review): presumably extra header/unit rows in the CSV -- confirm against the file
aou_array_full = np.delete(aou_array_pd, [0,1,2], axis = 0)

# NOTE: despite the name, these are COLUMN indices (they index axis 1 below)
keep_rows = [1,4,5,6,9,10,11,12]

aou_array = aou_array_full[:,keep_rows]

# First kept column (CSV column 1) holds the slab names
slab_name_list = list(aou_array[:,0])

# Remove the name column and convert the remaining seven columns to float
aou_array = np.delete(aou_array,0,axis = 1)
aou_array = aou_array.astype(float)

# Number of slabs read in
print(len(aou_array))

# Column indices selecting the slab base / slab top data from the full table
# (columns 9-12 are the base_depth, error, top_depth, error columns listed above;
#  the meaning of columns 13-16 is not visible here -- TODO confirm against the CSV)
keep_bases = [13,9,14,10]
keep_tops = [15,11,16,12]

# Left unconverted here; the conversion to float happens in the next cell
aou_bases_str = aou_array_full[:,keep_bases]
aou_tops_str = aou_array_full[:,keep_tops]


In [27]:
## CREATES ARRAYS OF AOU SLAB BASE AND TOP DATA

def _aou_mean_and_error_rows(raw_rows, upper_limit = 4.5e3):
    """Convert raw AOU base/top rows to float and replace the column 0 / column 2 pair by its mean and half-range.

    A row is kept only if all of its entries convert to float, its first value is
    below ``upper_limit`` and it is not entirely zero. For kept rows, column 0
    becomes the mean of columns 0 and 2 and column 2 becomes the absolute
    difference between that mean and the original column 0. The other columns are
    left unchanged.

    NOTE(review): columns 0 and 2 are presumably the two bounds of an age range in
    Myr (the values are called `age` / `error` below and 4.5e3 looks like the age
    of the Earth in Myr) -- confirm against the CSV.

    Parameters
    ----------
    raw_rows : iterable of numpy.ndarray
        Rows of unconverted values.
    upper_limit : float, optional
        Rows whose first value is not below this are discarded.

    Returns
    -------
    list of numpy.ndarray
        The converted rows, in input order.
    """
    kept_rows = []
    for raw in raw_rows:
        try:
            warray = raw.astype(float)
        except (ValueError, TypeError):
            # Rows with non-numeric entries are skipped; anything else
            # (e.g. KeyboardInterrupt) is no longer silently swallowed
            continue

        if warray[0] < upper_limit and not np.all(warray == 0):

            age = (warray[0] + warray[2])/2
            error = abs(age - warray[0])

            warray[0] = age
            warray[2] = error
            kept_rows.append(warray)

    return kept_rows


work_aou_bases = _aou_mean_and_error_rows(aou_bases_str)
aou_bases = np.array(work_aou_bases)

work_aou_tops = _aou_mean_and_error_rows(aou_tops_str)
aou_tops = np.array(work_aou_tops)


In [28]:
## Calculate the great-circle distance between two points --> the default radius is the Earth's surface; optionally pass a different radius
## Input lat/lon are in degrees

def earth_surf_dist(lat1, lon1, lat2, lon2, radius = 6371000):
    """Great-circle (haversine) distance between two points on a sphere.

    Parameters
    ----------
    lat1, lon1 : float or array-like
        Latitude and longitude of the first point(s), in degrees.
    lat2, lon2 : float or array-like
        Latitude and longitude of the second point(s), in degrees.
    radius : float, optional
        Sphere radius; the result is ``radius * angle / 1000``, so a radius in
        metres (the default, 6371000) gives a distance in kilometres.

    Returns
    -------
    float or numpy.ndarray
        A Python float for scalar input, otherwise an array following numpy
        broadcasting of the inputs.
    """

    # Convert every coordinate to radians
    lat1 = np.radians(lat1)
    lat2 = np.radians(lat2)
    lon1 = np.radians(lon1)
    lon2 = np.radians(lon2)

    # Differences between start and end lat/lon
    del_lat = lat2 - lat1
    del_lon = lon2 - lon1

    a = (np.sin(del_lat/2)**2) + np.cos(lat1)*np.cos(lat2)*(np.sin(del_lon/2)**2)

    # Rounding can push `a` marginally outside [0, 1] (e.g. near-antipodal points),
    # which would make the square roots below invalid
    a = np.clip(a, 0.0, 1.0)

    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))

    # Calculate distance and convert from m to km
    dist = (radius*c)/1000

    # Scalar input keeps returning a plain float, as before
    if np.ndim(dist) == 0:
        return(float(dist))
    return(dist)

In [30]:
def aou_slab_finder(list_of_depths, comp_dir, aou_name_keys, aou_slab_array, depth_range, reg_buffer, search_rad):
    """Find the comparison-grid depths at which each AOU slab is present in the tomography grids.

    For every AOU slab, each comparison grid whose depth lies between the slab's
    top depth (column 5) and base depth (column 3) is cut to a square region
    around the slab location. Nodes further than ``search_rad`` from the slab
    (measured with ``earth_surf_dist``) are discarded. The slab counts as present
    at that depth when at least 25 % of the remaining nodes have a value of 3 or 4.
    Slabs that are not found at any depth are printed.

    Parameters
    ----------
    list_of_depths : sequence
        Depth labels; the first and last are dropped so the rest line up with the
        grids in ``comp_dir`` (sorted file-name order).
    comp_dir : str
        Directory holding the comparison grids.
    aou_name_keys : sequence
        Slab names, by position in ``aou_slab_array``; used as dictionary keys.
    aou_slab_array : sequence of rows
        Per slab: index 0 depth, 1 lon, 2 lat, 3 base depth, 5 top depth
        (indices 4 and 6 are not read here).
    depth_range : any
        Not used; kept so existing calls keep working.
    reg_buffer : float
        Half side length, in degrees, of the square search region.
    search_rad : float
        Search radius, in the units returned by ``earth_surf_dist``
        (kilometres with its default radius).

    Returns
    -------
    dict
        ``{slab name: ([lon, lat, depth], [depths at which the slab was found])}``.
    """

    # Cut out top and bottom depths from list of depths
    d_list = list_of_depths[1:-1]

    # Create list of comparison grid files
    grd_list = [os.path.join(comp_dir, file) for file in sorted(os.listdir(comp_dir))]

    # dictionary to store aou slab search data
    aou_dict = {}

    # Iterate over slab points in aou data
    for x, aou_slab in enumerate(aou_slab_array):

        # Indices of all comparison grids between the top and base depth of the aou slab
        search_grids = [i for i, depth in enumerate(d_list) if aou_slab[5] <= depth <= aou_slab[3]]

        # Search region --> square centred on the aou slab lon/lat with side length 2*reg_buffer (degrees)
        reg = [aou_slab[1]-reg_buffer, aou_slab[1]+reg_buffer, aou_slab[2]-reg_buffer, aou_slab[2]+reg_buffer]

        working_list = []

        for grid_ind in search_grids:

            # cut comparison grid down to size of reg and store nodes (x, y, z) in an array
            work_xyz = pygmt.grd2xyz(grid = grd_list[grid_ind], region = reg, output_type = 'numpy')

            if len(work_xyz) == 0:
                continue

            # keep only the nodes within search_rad of the aou slab point
            # (one boolean mask instead of deleting rows one at a time)
            within_radius = np.array([earth_surf_dist(aou_slab[2], aou_slab[1], node[1], node[0]) <= search_rad
                                      for node in work_xyz], dtype = bool)
            node_values = work_xyz[within_radius, 2]

            # No node inside the search radius: nothing to compare, and the ratio would divide by zero
            if node_values.size == 0:
                continue

            # Count nodes with z values of 3 or 4 (true positive / false negative, i.e. nodes
            # that are slab in the tomography model used for comparison)
            count_slab = np.count_nonzero((node_values == 3) | (node_values == 4))

            # If >= 25% of nodes correspond to a tomography slab, declare the slab as existing at that depth
            slab_ratio = count_slab/node_values.size

            if slab_ratio >= 0.25:
                working_list.append(d_list[grid_ind])

        # Print out the names and location data of any slabs which don't match tomography at all
        if len(working_list) == 0:
            print(aou_name_keys[x], aou_slab)

        # tuple matching aou slab location (lon,lat,depth) to the depths that slab is found in tomography
        work_tup = ([aou_slab[1], aou_slab[2], aou_slab[0]], working_list)

        # dictionary entry keyed by slab name so it can be compared to data in the aou
        aou_dict[aou_name_keys[x]] = work_tup

    return(aou_dict)
    

In [31]:
def tracks_vs_aou(track_dictionary, aou_dictionary, depth_tolerance, deg_tolerance):
    """
    List the tracked slabs that lie close to an AOU slab centre.

    Parameters
    ----------
    track_dictionary : dict
        Maps a slab-track file name to a sequence read here as
        [lon, lat, depth, ...].
    aou_dictionary : dict
        Maps an AOU slab name to a pair whose first element is
        [lon, lat, depth]. Only that first element is read.
    depth_tolerance : number
        An AOU slab is considered only if its depth differs from the tracked
        slab's depth by strictly less than this value.
    deg_tolerance : number
        A pair matches when its angular separation in degrees is strictly
        less than this value.

    Returns
    -------
    list
        Track file names in dictionary order. A name is appended once per
        matching AOU slab, so the same name can appear several times.
    """
    # Strip each AOU entry down to its [lon, lat, depth] location
    aou_points = [entry[0] for entry in aou_dictionary.values()]
    aou_depths = [point[2] for point in aou_points]

    matched_tracks = []

    for track_name, track_point in track_dictionary.items():

        # Depth filter: keep the AOU slabs close enough in depth to this track
        candidates = [point for point, depth in zip(aou_points, aou_depths)
                      if abs(depth - track_point[2]) < depth_tolerance]

        if not candidates:
            continue

        track_lon = track_point[0]
        track_lat = track_point[1]
        track_depth = track_point[2]

        for aou_point in candidates:

            # Mean of the two depths scaled by 1000 (presumably km -> m; TODO confirm).
            # The same value is used below to turn the distance back into an angle.
            mean_rad = ((aou_point[2]+track_depth)/2)*1000

            separation = earth_surf_dist(track_lat, track_lon, aou_point[1], aou_point[0], mean_rad)

            separation_deg = np.degrees((separation*1000)/mean_rad)

            # Lateral filter: record the track when it is within the angular tolerance
            if separation_deg < deg_tolerance:
                matched_tracks.append(track_name)

    return(matched_tracks)

# Distance Travelled vs Thermal Anomaly

In [32]:
def distance_v_temp(tracked_slab_dict, age_grid, sub_stats_array, check_age_conv = True, list_of_slab_files = []):
    """
    Build one row of displacement data per tracked slab file.

    Parameters
    ----------
    tracked_slab_dict : dict or other
        If a dict, its keys are the slab files and the last element of each
        value fills the fourth column. For any other type the files come
        from list_of_slab_files and the fourth column is 0.
    age_grid
        Forwarded to sample_age_grids (used only when check_age_conv == True).
    sub_stats_array
        Forwarded to sample_conv_rate (used only when check_age_conv == True).
    check_age_conv : bool
        When equal to True, three extra columns are appended to each row.
    list_of_slab_files : list
        Slab files to use when tracked_slab_dict is not a dict.

    Returns
    -------
    numpy.ndarray
        One row per slab file:
        [lateral distance, vertical distance, lateral + vertical, fourth column]
        followed by [age, convergence rate, migration rate] when
        check_age_conv == True.
    """
    have_dict = isinstance(tracked_slab_dict, dict)

    # NOTE(review): in the non-dict case with no list supplied, this is the shared
    # default list object and it is handed to the helpers below as-is.
    tracked_slab_files = list(tracked_slab_dict.keys()) if have_dict else list_of_slab_files

    # displacement_list returns four values; only the two distance lists are used here
    _, _, lateral_dist_list, vertical_dist_list = displacement_list('filler', tracked_slab_files)

    add_age_conv = (check_age_conv == True)

    if add_age_conv:
        age_list, _ = sample_age_grids(tracked_slab_files, age_grid)
        convergence_rates, _, migration_rates = sample_conv_rate(tracked_slab_files, sub_stats_array)

    # Total distance is the plain sum of the lateral and vertical components
    total_dist_list = [lateral_dist_list[i] + vertical_dist_list[i] for i in range(len(lateral_dist_list))]

    plot_list = []

    for x, slab_file in enumerate(tracked_slab_files):

        # Last element of the tracked entry when a dict was given, otherwise a 0 placeholder
        fourth_col = tracked_slab_dict[slab_file][-1] if have_dict else 0

        plot_point = [lateral_dist_list[x], vertical_dist_list[x], total_dist_list[x], fourth_col]

        if add_age_conv:
            plot_point.extend([age_list[x], convergence_rates[x], migration_rates[x]])

        plot_list.append(plot_point)

    return(np.array(plot_list))
    

# Success Metric

In [33]:
# Slab dicts should be in the order (slab_match, slab_prob, thermal_equilibrium_match, thermal_equilibrium_prob)

# list_of_list_of_dicts = [list_dicts_40ma, list_dicts_80ma, list_dicts_120ma, etc...]
# list_dicts_40ma = [gcm30_slab_match_dict_40ma, gcm30_slab_prob_dict_40ma, gcm30_melt_match_dict_40ma, gcm30_melt_prob_dict_40ma]

def all_ages_success_rate(list_of_list_of_dicts):
    """
    Compute two overall success percentages, pooled over every age.

    Parameters
    ----------
    list_of_list_of_dicts : list
        One entry per age. Each entry is indexed 0..3 and, per the ordering
        documented above this function, holds
        (slab_match, slab_prob, thermal_equilibrium_match, thermal_equilibrium_prob).
        Only the length of each of the four containers is used.

    Returns
    -------
    tuple
        (tomo_success_rate, track_match_tp_rate), both in percent.
        tomo_success_rate counts entries 0 and 2 as successes;
        track_match_tp_rate counts entry 0 only. Both are divided by the
        total number of items across all four containers.

    Raises
    ------
    ZeroDivisionError
        If every container at every age is empty.
    """
    slab_match = 0
    slab_prob = 0
    thermal_match = 0
    thermal_prob = 0

    # Accumulate the size of each of the four categories across all ages
    for age_dicts in list_of_list_of_dicts:
        slab_match += len(age_dicts[0])
        slab_prob += len(age_dicts[1])
        thermal_match += len(age_dicts[2])
        thermal_prob += len(age_dicts[3])

    matched = slab_match + thermal_match
    total = matched + slab_prob + thermal_prob

    tomo_success_rate = (matched/total)*100

    track_match_tp_rate = (slab_match/total)*100

    return(tomo_success_rate, track_match_tp_rate)


def success_rate_per_age(list_of_list_of_dicts):
    """
    Compute the two success percentages separately for each age.

    Parameters
    ----------
    list_of_list_of_dicts : list
        One entry per age, each indexed 0..3. Only the length of each of the
        four containers is used.

    Returns
    -------
    tuple of lists
        (tomo_success_rates, model_success_rates), one percentage per age.
        The tomo rate counts entries 0 and 2 as successes; the model rate
        counts entry 0 only. Both use the total of all four entries as the
        denominator.

    Raises
    ------
    ZeroDivisionError
        If all four containers are empty for some age.
    """
    tomo_success_rates = []
    model_success_rates = []

    for at_age_list in list_of_list_of_dicts:

        # Sizes of the four categories at this age
        counts = [len(at_age_list[k]) for k in range(4)]

        matched = counts[0] + counts[2]
        total = matched + (counts[1] + counts[3])

        tomo_success_rates.append((matched/total)*100)
        model_success_rates.append((counts[0]/total)*100)

    return(tomo_success_rates, model_success_rates)

# Sample Age Grids

In [34]:
def sample_age_grids_defunct(slab_initials_list, age_grid, test_index = None, input_type = 'list_of_lists'):
    """
    Sample age_grid at slab starting points (older variant of sample_age_grids).

    Behaviour depends on the arguments:

    * input_type other than 'list_of_lists': slab_initials_list is passed to
      plot_initial_slabs and the result is sampled as a single set of points.
    * input_type 'list_of_lists' with a test_index: only
      slab_initials_list[test_index] is sampled.
    * input_type 'list_of_lists' without a test_index: every element of
      slab_initials_list is sampled in turn.

    Returns
    -------
    tuple
        (age_list, mean_age) in the first two cases;
        (list_of_age_lists, list_of_mean_ages) in the third.

    Raises
    ------
    ZeroDivisionError
        If a sampled set of points yields no values.
    """
    def ages_and_mean(points):
        # Sample the grid at the points, flatten the rows, and average the values
        sampled = pygmt.grdtrack(grid = age_grid, points = points, z_only = True)
        ages = [age for row in sampled.values.tolist() for age in row]
        return(ages, sum(ages)/len(ages))

    # Anything that is not a list of point lists goes through plot_initial_slabs first
    if input_type != 'list_of_lists':
        return ages_and_mean(plot_initial_slabs(slab_initials_list))

    # A single entry was requested
    if test_index != None:
        return ages_and_mean(slab_initials_list[test_index])

    # Otherwise sample every entry and split the (ages, mean) pairs into two lists
    per_slab = [ages_and_mean(slab_list) for slab_list in slab_initials_list]

    list_of_age_lists = [ages for ages, _ in per_slab]
    list_of_mean_ages = [mean for _, mean in per_slab]

    return(list_of_age_lists, list_of_mean_ages)
            
    

In [35]:
def sample_age_grids(age_slab_files, age_grid):
    """
    Sample age_grid at the starting points of the given slab files.

    Parameters
    ----------
    age_slab_files
        Passed to plot_initial_slabs; the first element of its return value
        is used as the points to sample.
    age_grid
        Grid passed to pygmt.grdtrack.

    Returns
    -------
    tuple
        (age_list, mean_age): the sampled values flattened into one list, and
        their arithmetic mean.

    Raises
    ------
    ZeroDivisionError
        If no values were sampled.
    """
    initial_points = plot_initial_slabs(age_slab_files)[0]

    # z_only = True: keep only the sampled grid values, not the input coordinates
    sampled = pygmt.grdtrack(grid = age_grid, points = initial_points, z_only = True)

    # Flatten the row-per-point table into a single list of ages
    age_list = []
    for row in sampled.values.tolist():
        age_list.extend(row)

    mean_age = sum(age_list)/len(age_list)

    return(age_list, mean_age)

In [36]:
def clip_age_grids(agegrid_dir, output_dir, above_list):
    """
    Run pygmt.grdclip on every entry of agegrid_dir and write each result,
    under the same file name, into output_dir.

    Parameters
    ----------
    agegrid_dir : str
        Directory whose entries are all treated as input grids, processed in
        sorted name order.
    output_dir : str
        Directory that receives the clipped grids. It is not created here.
    above_list
        Forwarded unchanged as the ``above`` argument of pygmt.grdclip.

    Returns
    -------
    None
    """
    for grid_name in sorted(os.listdir(agegrid_dir)):

        in_grid = os.path.join(agegrid_dir, grid_name)

        # Reuse the directory entry name instead of splitting the joined path on '/',
        # which gives the wrong name when the platform separator is not '/'.
        output = os.path.join(output_dir, grid_name)

        # region = 'd' is GMT's shorthand for the global domain
        pygmt.grdclip(grid = in_grid, outgrid = output, region = 'd', above = above_list)

# Sample Convergence Rates

In [37]:
def _stats_in_window(stats, column, centre, limit, start = 5, step = 2.5):
    """Return the rows of stats whose `column` value lies within a half-width of `centre`, widening the half-width until at least one row matches."""
    half_width = start
    selected = stats[abs(centre - stats[:, column]) < half_width]

    while len(selected) == 0:
        # Past the limit no finite value can ever match (empty table or NaN
        # coordinate); stop instead of looping forever.
        if half_width > limit:
            raise ValueError('no subduction statistics found near coordinate ' + str(centre))
        half_width = half_width + step
        selected = stats[abs(centre - stats[:, column]) < half_width]

    return selected


def sample_conv_rate(tracked_slab_dict, sub_stats_array):
    """
    Look up, for each slab starting point, the values of the nearest row of
    sub_stats_array.

    Parameters
    ----------
    tracked_slab_dict : dict or list
        If a dict, its keys are used as the track files; otherwise the value
        itself is used. The files are passed to plot_initial_slabs, and the
        first element of its return value is iterated as [lon, lat, ...] points.
    sub_stats_array : numpy.ndarray
        2-D array read with column 0 as longitude, column 1 as latitude,
        column 2 as convergence rate, column 4 as seafloor age and column 5
        as migration rate.

    Returns
    -------
    tuple of lists
        (convergence_rates, seafloor_ages, migration_rates), one value per
        starting point.

    Raises
    ------
    ValueError
        If no row of sub_stats_array can be matched to a starting point.
    """
    if isinstance(tracked_slab_dict, dict):
        track_files = list(tracked_slab_dict.keys())
    else:
        track_files = tracked_slab_dict
    work_initials = plot_initial_slabs(track_files)[0]

    convergence_rates = []
    seafloor_ages = []
    migration_rates = []

    for point in work_initials:
        lon = point[0]
        lat = point[1]

        # Narrow by longitude first, then by latitude, starting at +/-5 degrees
        # and widening in 2.5 degree steps until something is found.
        # NOTE(review): the longitude difference is a plain subtraction, so rows
        # across the +/-180 (or 0/360) seam are not treated as neighbours -- confirm
        # the longitude convention of sub_stats_array.
        work_stats = _stats_in_window(sub_stats_array, 0, lon, 360)
        final_stats = _stats_in_window(work_stats, 1, lat, 180)

        if len(final_stats) == 1:
            nearest = 0
        else:
            work_dists = [earth_surf_dist(lat, lon, stats[1], stats[0]) for stats in final_stats]

            # Take the CLOSEST candidate; the previous code used max() and so
            # picked the farthest row inside the search window.
            nearest = work_dists.index(min(work_dists))

        convergence_rates.append(final_stats[nearest, 2])
        seafloor_ages.append(final_stats[nearest, 4])
        migration_rates.append(final_stats[nearest, 5])

    return(convergence_rates, seafloor_ages, migration_rates)