# Spatio-temporal mineral prospectivity modelling of New Guinea

### Ehsan Farahbakhsh<sup>1</sup>, Sabin Zahirovic<sup>1</sup>, Brent I. A. McInnes<sup>2</sup>, Sara Polanco<sup>1</sup>, Fabian Kohlmann<sup>3</sup>, Maria Seton<sup>1</sup>, R. Dietmar M&uuml;ller<sup>1</sup>

<sup>1</sup>*EarthByte Group, School of Geosciences, The University of Sydney, Sydney, Australia*

<sup>2</sup>*John de Laeter Centre, Faculty of Science and Engineering, Curtin University, Perth, Australia*

<sup>3</sup>*Lithodat Pty. Ltd., Melbourne, Australia*

This notebook enables the user to create a spatio-temporal mineral prospectivity model of New Guinea. It comprises two main sections: the first extracts kinematic features, and the second applies machine learning algorithms to create a prospectivity model.

### Libraries

In [None]:
# setup the working environment
import warnings
warnings.filterwarnings('ignore')

import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cmcrameri.cm as ccm
from collections import deque
import contextily as cx
import cv2
import geopandas as gpd
import gplately
from ipywidgets import interact
import math
import matplotlib as mpl
import matplotlib.cm as cm
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
import matplotlib.patheffects as pe
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from netCDF4 import Dataset
import numpy as np
import os
from osgeo import gdal
from osgeo import osr
import pandas as pd
import pickle
import plotly.express as px
from ptt.subduction_convergence import subduction_convergence_over_time
from pulearn import BaggingPuClassifier
from gplately import pygplates
from rasterio.plot import show
import rioxarray as rxr
from scipy import ndimage, stats
from scipy.cluster import hierarchy
from scipy.interpolate import griddata, make_interp_spline
from scipy.ndimage import gaussian_filter1d
import scipy.spatial
import seaborn as sns
import shapefile
from shapely.geometry import LineString, Point
import statistics

# machine learning
from sklearn.ensemble import RandomForestClassifier
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.metrics import accuracy_score, auc, confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score, roc_curve
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from skopt import BayesSearchCV
from skopt.space import Categorical, Integer, Real

# from lib_stamp_muller2016 import *
from lib_stamp_muller2019 import *

# load parameters
# from parameters_muller2016 import parameters
from parameters_muller2019 import parameters

### Extract Convergence Kinematic Features

In [None]:
# Extract convergence kinematic features at trench points for every reconstruction time step.
plate_motion_model = parameters['plate_motion_model']

# reconstruction time range (start, end, step) taken from the parameters module
time_period = parameters['time']
start_time, end_time, time_step = (time_period[key] for key in ('start', 'end', 'step'))
time_steps = list(range(start_time, end_time + 1, time_step))  # end_time is inclusive

# directory, file-name prefix and extension of the per-time-step convergence files
conv_dir, conv_prefix, conv_ext = (
    parameters['convergence_data_dir'],
    parameters['convergence_data_filename_prefix'],
    parameters['convergence_data_filename_ext'],
)

# presumably writes the `{conv_dir}/{conv_prefix}_{time}.00.{conv_ext}` files read by later cells
trench_points_features(
    start_time, end_time, time_step,
    conv_dir, conv_prefix, conv_ext,
    plate_motion_model,
    random_state=42,
)

### Create the Plate Reconstruction Model

In [None]:
# Build the plate reconstruction model from the rotation and topology files named in the
# parameters module. `model`, `coastlines`, `continents` and `cob` are used by the plotting cells.
if plate_motion_model == 'muller2016':
    rotation_files = parameters['rotation_file']
    topology_files = parameters['topology_file']
elif plate_motion_model == 'muller2019':
    # every file found below the rotation / topology directories
    rotation_files = [os.path.join(dirpath, f) for (dirpath, dirnames, filenames) in os.walk(parameters['rotation_dir']) for f in filenames]
    topology_files = [os.path.join(dirpath, f) for (dirpath, dirnames, filenames) in os.walk(parameters['topology_dir']) for f in filenames]
else:
    # fail early instead of with a NameError on `rotation_files` below
    raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')

# a single path given as a plain string would otherwise be iterated character by character
if isinstance(rotation_files, str):
    rotation_files = [rotation_files]
if isinstance(topology_files, str):
    topology_files = [topology_files]

coastlines = parameters['coastlines_file']
static_polygons = parameters['static_polygons_file']
# NOTE(review): continents reuses the coastlines file; confirm a separate continents file was not intended
continents = parameters['coastlines_file']
cob = parameters['cob_file']

rotation_model = pygplates.RotationModel(rotation_files)

# merge all topology files into one feature collection
topology_features = pygplates.FeatureCollection()
for topology_file in topology_files:
    topology_features.add(pygplates.FeatureCollection(topology_file))

# use the PlateReconstruction object to create a plate motion model
model = gplately.PlateReconstruction(rotation_model, topology_features, static_polygons)

### Plot Kinematic Features on a Global Scale

In [None]:
selected_features = parameters['selected_features']
# 'distance_deg' is not offered in the plotting menus; guard so a feature list without it does not raise
selected_features_plot = selected_features.copy()
if 'distance_deg' in selected_features_plot:
    selected_features_plot.remove('distance_deg')

agegrid_dir = parameters['agegrid_dir']

extent_globe = [-180, 180, -90, 90]

@interact
def show_map(time=time_steps, feature=selected_features_plot):
    """Plot trench-point kinematic features on a global Mollweide map.

    `time` (Ma) and `feature` become interactive dropdowns via `@interact`. The map shows the
    seafloor age grid, continents, ridges, plate-motion vectors and trenches at `time`, with
    trench points coloured by `feature`.
    """
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    if plate_motion_model == 'muller2016':
        agegrid_name = f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_name = f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        # previously this fell through to an UnboundLocalError on the age-grid path
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')
    # os.path.join also works when agegrid_dir has no trailing separator
    agegrid_file = os.path.join(agegrid_dir, agegrid_name)

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    features_t = pd.read_csv(f'{conv_dir}/{conv_prefix}_{time}.00.{conv_ext}', index_col=False)

    # dual colour bars
    fig = plt.figure(figsize=(16, 12))
    gs = GridSpec(2, 2, hspace=-0.4, wspace=0.4, height_ratios=[1, 0.02])
    ax = fig.add_subplot(gs[0, :], projection=ccrs.Mollweide(central_longitude=150))

    ax.set_global()

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=230, alpha=0.5, zorder=1) # cmap: Blues, viridis, winter

    gplot.plot_continents(ax, edgecolor='none', facecolor='tan', zorder=2) # facecolor: tan
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=True, alpha=0.1, zorder=4)

    sc = ax.scatter(features_t['trench_lon'], features_t['trench_lat'], 50, marker='.',
                    c=features_t[feature], cmap=ccm.hawaii_r, transform=ccrs.PlateCarree(), zorder=5) # cmap: Spectral_r, YlOrRd

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.05, color='k', alpha=0.3, zorder=7)

    ax.gridlines(linestyle=':')

    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])

    fig.colorbar(sc, cax=cax2, orientation='horizontal', label=feature, extend='both')
    fig.colorbar(im, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='none', label='Continental Crust'),  # Custom handle for the filled polygon
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge')  # Custom handle for the line (ridge)
    ]

    # Add the custom legend to the plot
    ax.legend(handles=custom_handles, loc='lower left')

    ax.set_title(f'Subduction Zones {time} Ma')

    plt.show()

### Clip Trench Points based on the Target Extent

In [None]:
# Clip the per-time-step trench points to the animation target extent and cache them as CSV files.
target_extent_anim_file = parameters['target_extent_animation']
target_extent_anim_gdf = gpd.read_file(target_extent_anim_file)
target_extent_anim_bounds = target_extent_anim_gdf.bounds
# [min lon, max lon, min lat, max lat] of the first geometry in the file
target_extent_anim = [target_extent_anim_bounds.loc[0]['minx'], target_extent_anim_bounds.loc[0]['maxx'],
                      target_extent_anim_bounds.loc[0]['miny'], target_extent_anim_bounds.loc[0]['maxy']]

# clamp the extent to valid longitude / latitude ranges
if target_extent_anim[0] < -180:
    target_extent_anim[0] = -180
if target_extent_anim[1] > 180:
    target_extent_anim[1] = 180
if target_extent_anim[2] < -90:
    target_extent_anim[2] = -90
if target_extent_anim[3] > 90:
    target_extent_anim[3] = 90

features_target_extent_anim_files_lst = [
    f'{conv_dir}/{conv_prefix}_target_extent_animation_{time}.00.{conv_ext}' for time in time_steps
]

target_polygon_anim = target_extent_anim_gdf.geometry[0]

for time, features_target_extent_anim_file in zip(time_steps, features_target_extent_anim_files_lst):
    if os.path.isfile(features_target_extent_anim_file):
        continue  # already clipped in an earlier run

    # read the convergence file of THIS time step; the list index was previously used as the age,
    # which only matches when start_time == 0 and time_step == 1
    features_t = pd.read_csv(f'{conv_dir}/{conv_prefix}_{time}.00.{conv_ext}', index_col=False)

    # vectorised point-in-polygon test: keeps the column dtypes and, unlike np.row_stack([]),
    # does not fail when no trench point lies inside the extent
    trench_points = gpd.GeoSeries(gpd.points_from_xy(features_t['trench_lon'], features_t['trench_lat']))
    inside = trench_points.within(target_polygon_anim).to_numpy()

    features_target_extent_anim = features_t.loc[inside]
    features_target_extent_anim.to_csv(features_target_extent_anim_file, index=False)

In [None]:
# Clip the per-time-step trench points to the map target extent and cache them as CSV files.
target_extent_map_file = parameters['target_extent_map']
target_extent_map_gdf = gpd.read_file(target_extent_map_file)
target_extent_map_bounds = target_extent_map_gdf.bounds
# [min lon, max lon, min lat, max lat] of the first geometry in the file
target_extent_map = [target_extent_map_bounds.loc[0]['minx'], target_extent_map_bounds.loc[0]['maxx'],
                     target_extent_map_bounds.loc[0]['miny'], target_extent_map_bounds.loc[0]['maxy']]

# clamp the extent to valid longitude / latitude ranges
if target_extent_map[0] < -180:
    target_extent_map[0] = -180
if target_extent_map[1] > 180:
    target_extent_map[1] = 180
if target_extent_map[2] < -90:
    target_extent_map[2] = -90
if target_extent_map[3] > 90:
    target_extent_map[3] = 90

features_target_extent_map_files_lst = [
    f'{conv_dir}/{conv_prefix}_target_extent_map_{time}.00.{conv_ext}' for time in time_steps
]

target_polygon_map = target_extent_map_gdf.geometry[0]

for time, features_target_extent_map_file in zip(time_steps, features_target_extent_map_files_lst):
    if os.path.isfile(features_target_extent_map_file):
        continue  # already clipped in an earlier run

    # read the convergence file of THIS time step; the list index was previously used as the age,
    # which only matches when start_time == 0 and time_step == 1
    features_t = pd.read_csv(f'{conv_dir}/{conv_prefix}_{time}.00.{conv_ext}', index_col=False)

    # vectorised point-in-polygon test: keeps the column dtypes and, unlike np.row_stack([]),
    # does not fail when no trench point lies inside the extent
    trench_points = gpd.GeoSeries(gpd.points_from_xy(features_t['trench_lon'], features_t['trench_lat']))
    inside = trench_points.within(target_polygon_map).to_numpy()

    features_target_extent_map = features_t.loc[inside]
    features_target_extent_map.to_csv(features_target_extent_map_file, index=False)

### Plot Kinematic Features based on the Target Extent

In [None]:
def set_ax(ax, extent, interval_x, interval_y, font_size=None, stock_img=True, order=None):
    """Configure a cartopy GeoAxes for a regional map.

    Parameters
    ----------
    ax : cartopy GeoAxes to configure.
    extent : [lon_min, lon_max, lat_min, lat_max] passed to ``ax.set_extent``.
    interval_x, interval_y : spacing in degrees between longitude / latitude gridlines.
    font_size : font size of the gridline labels (matplotlib default when None).
    stock_img : draw cartopy's stock background image when True.
    order : zorder of the gridlines.
    """
    if stock_img:
        ax.stock_img()

    ax.set_extent(extent)

    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                      linewidth=1, color='gray', alpha=0.5, linestyle='--',
                      xlocs=np.arange(-180, 180, interval_x), ylocs=np.arange(-90, 90, interval_y), zorder=order)
    # Label only the bottom and left edges. cartopy 0.18 renamed xlabels_top / ylabels_right to
    # top_labels / right_labels and later removed the old names, so assigning those has no effect.
    gl.top_labels = False
    gl.right_labels = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    gl.xlabel_style = {'color': 'gray', 'weight': 'bold', 'fontsize': font_size}
    gl.ylabel_style = {'color': 'gray', 'weight': 'bold', 'fontsize': font_size}

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps, feature=selected_features_plot):
    """Plot trench-point kinematic features within the target map extent.

    Same layers as the global map, but on a Lambert azimuthal equal-area projection centred
    on (150, 0) and limited to `target_extent_map`. Trench points come from the clipped
    `target_extent_map` CSV files.
    """
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    if plate_motion_model == 'muller2016':
        agegrid_name = f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_name = f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        # previously this fell through to an UnboundLocalError on the age-grid path
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')
    # os.path.join also works when agegrid_dir has no trailing separator
    agegrid_file = os.path.join(agegrid_dir, agegrid_name)

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    features_t = pd.read_csv(f'{conv_dir}/{conv_prefix}_target_extent_map_{time}.00.{conv_ext}', index_col=False)

    # dual colour bars
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.75, wspace=0.1, height_ratios=[1, 0.01])
    ax = fig.add_subplot(gs[0, :], projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=8)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    sc = ax.scatter(features_t['trench_lon'], features_t['trench_lat'], 50, marker='.',
                    c=features_t[feature], cmap=ccm.hawaii_r, transform=ccrs.PlateCarree(), zorder=5)

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])

    fig.colorbar(sc, cax=cax2, orientation='horizontal', label=feature, extend='both')
    fig.colorbar(im, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge')  # Custom handle for the line (ridge)
    ]

    # Add the custom legend to the plot, outside the map on the right
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.4, 1), borderaxespad=0.)

#     plt.savefig(
#     f'./figures/muller2019/features/conv_angle_deg_edited_.png',
#     bbox_inches='tight',
#     pad_inches=0.1,
#     dpi=150
#     )

    plt.show()

### Reconstruct Mineral Occurrences

In [None]:
# Load the mineral-occurrence table, building and caching it as CSV on the first run.
min_occ_file = parameters['min_occ_file']
coreg_input_dir = parameters['coreg_input_dir']
coreg_input_files = parameters['coreg_input_files']
min_occ_data_file = coreg_input_dir + coreg_input_files[0]  # first co-registration input: the mineral-occurrence table

if not os.path.isfile(min_occ_data_file):
    # id, lon, lat, age, and plate id of mineral occurrences
    # (later cells also read the 'weight', 'lon_recon' and 'lat_recon' columns)
    min_occ_data = process_real_deposits(min_occ_file, start_time, end_time, time_step, plate_motion_model)
    # cache the table so later runs can read it directly
    min_occ_data.to_csv(min_occ_data_file, index=False)
else:
    min_occ_data = pd.read_csv(min_occ_data_file, index_col=False)

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps, feature=selected_features_plot):
    """Plot mineral occurrences, reconstructed to `time` (Ma), over trench-point features.

    Occurrences with an age below `time` are hidden (NaN coordinates). Occurrences whose age
    equals `time` use the stored reconstructed coordinates. Older ones are reconstructed from
    their present-day position. Marker size scales with the occurrence 'weight'.
    """
    lons_lats_recon = []

    for _, occ in min_occ_data.iterrows():
        if time == 0:
            # present day: use the observed coordinates
            lons_lats_recon.append((occ['lon'], occ['lat'], occ['weight']))
        elif int(occ['age']) < time:
            # presumably the occurrence had not formed yet at `time`; NaN hides it
            lons_lats_recon.append((np.nan, np.nan, np.nan))
        elif int(occ['age']) == time:
            lons_lats_recon.append((occ['lon_recon'], occ['lat_recon'], occ['weight']))
        else:
            # use the configured model (was hard-coded to 'muller2019');
            # the result is indexed as (lat, lon), hence the swap below
            lat_lon_recon = get_recon_ccords([occ['lon']],
                                             [occ['lat']],
                                             plate_motion_model=plate_motion_model,
                                             time=time)[0]
            lons_lats_recon.append((lat_lon_recon[1], lat_lon_recon[0], occ['weight']))

    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    if plate_motion_model == 'muller2016':
        agegrid_name = f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_name = f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        # previously this fell through to an UnboundLocalError on the age-grid path
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')
    # os.path.join also works when agegrid_dir has no trailing separator
    agegrid_file = os.path.join(agegrid_dir, agegrid_name)

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    features_t = pd.read_csv(f'{conv_dir}/{conv_prefix}_{time}.00.{conv_ext}', index_col=False)

    # dual colour bars
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.75, wspace=0.1, height_ratios=[1, 0.01])
    ax = fig.add_subplot(gs[0, :], projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    sc0 = ax.scatter(features_t['trench_lon'], features_t['trench_lat'], 50, marker='.',
                    c=features_t[feature], cmap=ccm.hawaii_r, transform=ccrs.PlateCarree(), zorder=5)

    sc1 = ax.scatter(
        [coords[0] for coords in lons_lats_recon],
        [coords[1] for coords in lons_lats_recon],
        transform=ccrs.PlateCarree(),
        marker='o',
        facecolor='yellow',
        edgecolor='black',
        s=[coords[2] * 20 for coords in lons_lats_recon],
        alpha=0.7,
        zorder=6
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=7)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=8)

    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])

    fig.colorbar(sc0, cax=cax2, orientation='horizontal', label=feature, extend='both')
    fig.colorbar(im, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),  # Custom handle for the line (ridge)
        Line2D([0], [0], marker='o', markerfacecolor='yellow', markeredgecolor='black', markersize=5, linestyle='None', label='Mineral Occurrence')  # Custom handle for mineral occurrences
    ]

    # Add the custom legend to the plot, outside the map on the right
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.42, 1), borderaxespad=0.)

    plt.show()

### Create Buffer Zones

In [None]:
# Build (or load cached) buffer zones around subduction zones for every time step, then clip
# them to the animation target extent.
buffer_zones_files_lst = [f'{coreg_input_dir}buffer_zones/buffer_zone_{time}_Ma.shp' for time in time_steps]

for time, buffer_zone_file in zip(time_steps, buffer_zones_files_lst):
    if os.path.isfile(buffer_zone_file):
        continue  # already generated in an earlier run

    resolved_topologies = []
    shared_boundary_sections = []
    # Resolve with the rotation model and topology features loaded when the plate reconstruction
    # model was created. This avoids re-reading every file at each time step and re-deriving the
    # file lists with parameter keys that disagreed with that cell ('rotation_files' vs 'rotation_file').
    pygplates.resolve_topologies(topology_features, rotation_model, resolved_topologies, time, shared_boundary_sections)

    # subduction zones
    subduction_geoms = []
    get_subduction_geometries(subduction_geoms, shared_boundary_sections)

    _, buffer_zone = generate_buffer_zones(subduction_geoms, width=3)
    buffer_zone.to_file(buffer_zone_file)
    print(f'Buffer zones saved to {buffer_zone_file}')

buffer_zones_lst = []
buffer_zones_clipped_lst = []

for buffer_zone_file in buffer_zones_files_lst:
    buffer_zone = gpd.read_file(buffer_zone_file)
    buffer_zones_lst.append(buffer_zone)
    buffer_zones_clipped_lst.append(buffer_zone.clip(target_extent_anim_gdf))

### Plot Buffer Zones on a Global Scale

In [None]:
proj = ccrs.Mollweide(central_longitude=150)

@interact
def show_map(time=time_steps):
    """Plot the subduction buffer zones at `time` (Ma) on a global Mollweide map."""
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    if plate_motion_model == 'muller2016':
        agegrid_name = f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_name = f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        # previously this fell through to an UnboundLocalError on the age-grid path
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')
    # os.path.join also works when agegrid_dir has no trailing separator
    agegrid_file = os.path.join(agegrid_dir, agegrid_name)

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    # single colour bar
    fig = plt.figure(figsize=(16, 12))
    ax = plt.axes(projection=proj)
    ax.set_global()

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=230, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='none', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=True, alpha=0.1, zorder=4)

    buffer_zones_lst[time_steps.index(time)].plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        edgecolor='none',
        facecolor='gray',
        linewidth=1,
        alpha=0.7,
        zorder=5,
    )

    gplot.plot_trenches(ax, color='k', zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.05, color='k', zorder=7)

    ax.gridlines(linestyle=':')

    fig.colorbar(im, orientation='horizontal', shrink=0.4, pad=0.05, label='Seafloor Age (Ma)', extend='max')

    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='none', label='Continental Crust'),  # Custom handle for the filled polygon
        Patch(facecolor='gray', edgecolor='none', label='Target Areas in\nBack-Arc Basins'),  # Custom handle for the buffer zones
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge')  # Custom handle for the line (ridge)
    ]

    # Add the custom legend to the plot
    ax.legend(handles=custom_handles, loc='lower left')

    plt.show()

### Plot Buffer Zones based on the Target Extent

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps):
    """Plot the clipped subduction buffer zones at `time` (Ma) within the target map extent."""
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    if plate_motion_model == 'muller2016':
        agegrid_name = f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_name = f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        # previously this fell through to an UnboundLocalError on the age-grid path
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')
    # os.path.join also works when agegrid_dir has no trailing separator
    agegrid_file = os.path.join(agegrid_dir, agegrid_name)

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    # single colour bar
    fig = plt.figure(figsize=(6, 8))
    ax = plt.axes(projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=8)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    buffer_zones_clipped_lst[time_steps.index(time)].plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        edgecolor='none',
        facecolor='gray',
        linewidth=1,
        alpha=0.7,
        zorder=5
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

    fig.colorbar(im, orientation='horizontal', shrink=0.4, pad=0.05, label='Seafloor Age (Ma)', extend='max')

    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Patch(facecolor='gray', edgecolor='none', label='Target Areas in\nBack-Arc Basins'),  # Custom handle for the buffer zones
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge')  # Custom handle for the line (ridge)
    ]

    # Add the custom legend to the plot, outside the map on the right
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.4, 1), borderaxespad=0.)

    plt.show()

### Generate Random Samples

In [None]:
# Load the cached random (unlabelled) samples, or generate them inside the clipped buffer zones.
random_data_file = coreg_input_dir + coreg_input_files[1]
num_features = len(selected_features)

if os.path.isfile(random_data_file):
    random_data = pd.read_csv(random_data_file, index_col=False)
    time_steps_random = random_data['age'].tolist()
else:
    time_steps_random, random_data = generate_random_samples(
        buffer_zones_clipped_lst,
        start_time=start_time,
        end_time=end_time,
        time_step=time_step,
        num_features=num_features,
        num_features_factor=5,
        rand_factor=20,
        # use the configured model instead of a hard-coded 'muller2019'
        plate_motion_model=plate_motion_model,
        random_state=42,
    )
    random_data.to_csv(random_data_file, index=False, float_format='%.4f')

### Plot Random Samples based on the Target Extent

In [None]:
# Unique sampled ages in ascending order for the dropdown. The previous `[*set(sorted(...))]`
# discarded the sort, because set iteration order is not guaranteed.
time_steps_random_sorted = sorted(set(time_steps_random))

proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps_random_sorted):
    """Plot the random samples drawn at `time` (Ma) over the clipped buffer zones."""
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    if plate_motion_model == 'muller2016':
        agegrid_name = f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_name = f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        # previously this fell through to an UnboundLocalError on the age-grid path
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model}')
    # os.path.join also works when agegrid_dir has no trailing separator
    agegrid_file = os.path.join(agegrid_dir, agegrid_name)

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    # single colour bar
    fig = plt.figure(figsize=(6, 8))
    ax = plt.axes(projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    buffer_zones_clipped_lst[time_steps.index(time)].plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        edgecolor='none',
        facecolor='gray',
        linewidth=1,
        alpha=0.7,
        zorder=5
    )

    random_samples = random_data.loc[random_data['age'] == time]
    ax.scatter(
        random_samples['lon'],
        random_samples['lat'],
        transform=ccrs.PlateCarree(),
        marker='X',
        edgecolor='black',
        facecolor='cyan',
        s=50,
        zorder=8
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

    fig.colorbar(im, orientation='horizontal', shrink=0.4, pad=0.05, label='Seafloor Age (Ma)', extend='max')

    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Patch(facecolor='gray', edgecolor='none', label='Target Areas in\nBack-Arc Basins'),  # Custom handle for the buffer zones
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),  # Custom handle for the line (ridge)
        # the random samples are plotted above but previously had no legend entry
        Line2D([0], [0], marker='X', markerfacecolor='cyan', markeredgecolor='black', markersize=7, linestyle='None', label='Random Sample')
    ]

    # Add the custom legend to the plot, outside the map on the right
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.4, 1), borderaxespad=0.)

    plt.show()

### Generate Target Points in Back-Arc Basins

In [None]:
# Generate (or load cached) regularly spaced target points and masks inside the clipped buffer
# zones for every time step.
target_points_coreg_in_files_lst = []
mask_coords_files_lst = []

for time in time_steps:
    target_points_coreg_in_files_lst.append(coreg_input_dir + coreg_input_files[2] + f'_{time}_Ma.csv')
    mask_coords_files_lst.append(coreg_input_dir + f'mask_{time}_Ma.csv')

for index, (target_points_file, mask_coords_file) in enumerate(
        zip(target_points_coreg_in_files_lst, mask_coords_files_lst)):
    if os.path.isfile(target_points_file) and os.path.isfile(mask_coords_file):
        continue  # both outputs already exist from an earlier run

    # generate target points (0.2 x 0.2 spacing: dist_x and dist_y) with the configured
    # plate model instead of a hard-coded 'muller2019'
    target_points, mask_coords, nx, ny = generate_samples(buffer_zones_clipped_lst[index], 0.2, 0.2,
                                                          time_steps[index], plate_motion_model=plate_motion_model)
    # save the attributes of target points
    target_points.to_csv(target_points_file, index=False, float_format='%.4f')
    # save the mask
    mask_coords.to_csv(mask_coords_file, index=False, float_format='%.4f')
    print(f'Target points saved to {target_points_file}')

target_points_coreg_in_lst = []
mask_coords_lst = []

for target_points_file, mask_coords_file in zip(target_points_coreg_in_files_lst, mask_coords_files_lst):
    target_points_coreg_in_lst.append(pd.read_csv(target_points_file, index_col=False))
    mask_coords_lst.append(pd.read_csv(mask_coords_file, index_col=False))

### Plot Target Points based on the Target Extent

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps):
    """Plot back-arc target areas and their target points at one time step.

    Parameters
    ----------
    time : element of ``time_steps``
        Reconstruction time; ``@interact`` turns the ``time_steps`` list into a
        dropdown widget.

    Raises
    ------
    ValueError
        If the global ``plate_motion_model`` is neither 'muller2016' nor
        'muller2019' (no age-grid file name is known for it).
    """
    # Resolve the age grid first so an unsupported model fails with a clear
    # message instead of an UnboundLocalError on ``agegrid_file``.
    if plate_motion_model == 'muller2016':
        agegrid_file = agegrid_dir + f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_file = agegrid_dir + f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model!r}')

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    # position of this time step in the per-time-step lists
    step = time_steps.index(time)

    # single colour bar
    fig = plt.figure(figsize=(6, 8))
    ax = plt.axes(projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    buffer_zones_clipped_lst[step].plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        edgecolor='none',
        facecolor='gray',
        linewidth=1,
        alpha=0.7,
        zorder=5
    )

    target_points = target_points_coreg_in_lst[step]
    ax.scatter(
        target_points['lon'],
        target_points['lat'],
        transform=ccrs.PlateCarree(),
        marker='.',
        c='red',
        s=1,
        zorder=8
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

    fig.colorbar(im, orientation='horizontal', shrink=0.4, pad=0.05, label='Seafloor Age (Ma)', extend='max')

    # Proxy artists for the legend (the plotted layers do not all carry labels).
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),
        Patch(facecolor='gray', edgecolor='none', label='Target Areas in\nBack-Arc Basins'),
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),
        Line2D([0], [0], marker='.', markerfacecolor='red', markeredgecolor='none', markersize=10, linestyle='None', label='Target Points')
    ]

    # legend placed outside the map, to the right
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.4, 1), borderaxespad=0.)

    plt.show()

### Generate Target Points Inside New Guinea Island

In [None]:
# Present-day (0 Ma) target points inside the New Guinea polygon (cached as CSV),
# followed by a map of them over the 0 Ma age grid.
target_polygon = parameters['target_polygon']
target_polygon_gdf = gpd.read_file(target_polygon)

target_points_ng_0_file = coreg_input_dir + coreg_input_files[3] + '_0_Ma.csv'
mask_coords_ng_0_file = coreg_input_dir + 'mask_ng_0_Ma.csv'

if os.path.isfile(target_points_ng_0_file) and os.path.isfile(mask_coords_ng_0_file):
    target_points_ng_0 = pd.read_csv(target_points_ng_0_file, index_col=False)
    mask_coords_ng_0 = pd.read_csv(mask_coords_ng_0_file, index_col=False)
else:
    # use the configured plate motion model rather than a hard-coded 'muller2019'
    target_points_ng_0, mask_coords_ng_0, nx_ng_0, ny_ng_0 = generate_samples(
        target_polygon_gdf, 0.2, 0.2,  # dist_x and dist_y
        time=0, plate_motion_model=plate_motion_model
    )
    target_points_ng_0.to_csv(target_points_ng_0_file, index=False, float_format='%.4f')
    mask_coords_ng_0.to_csv(mask_coords_ng_0_file, index=False, float_format='%.4f')

proj = ccrs.LambertAzimuthalEqualArea(150, 0)

# call the PlotTopologies object
gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=0)

# Present-day age grid; fail explicitly for an unknown model instead of silently
# reusing a stale (or missing) ``agegrid_file`` from an earlier cell.
if plate_motion_model == 'muller2016':
    agegrid_file = agegrid_dir + 'Muller_etal_2016_AREPS_v1.17_AgeGrid-0.nc'
elif plate_motion_model == 'muller2019':
    agegrid_file = agegrid_dir + 'Muller_etal_2019_Tectonics_v2.0_AgeGrid-0.nc'
else:
    raise ValueError(f'Unsupported plate motion model: {plate_motion_model!r}')

agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

# map extent [min lon, max lon, min lat, max lat] from the first polygon's bounds
ng_extent_bounds = target_polygon_gdf.bounds
ng_extent = [ng_extent_bounds.loc[0]['minx'], ng_extent_bounds.loc[0]['maxx'],
             ng_extent_bounds.loc[0]['miny'], ng_extent_bounds.loc[0]['maxy']]

# single colour bar
fig = plt.figure(figsize=(6, 8))
ax = plt.axes(projection=proj)

set_ax(ax, ng_extent, 10, 5, stock_img=False, order=9)

im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

ax.scatter(
    target_points_ng_0['lon'],
    target_points_ng_0['lat'],
    transform=ccrs.PlateCarree(),
    marker='.',
    c='green',
    s=1,
    zorder=8
)

gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

fig.colorbar(im, orientation='horizontal', shrink=0.4, pad=0.05, label='Seafloor Age (Ma)', extend='max')

# Proxy artists for the legend
custom_handles = [
    Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),
    Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),
    Line2D([0], [0], marker='.', markerfacecolor='green', markeredgecolor='none', markersize=10, linestyle='None', label='Target Points')
]

# legend placed outside the map, to the right
legend = ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.4, 1), borderaxespad=0.)

plt.show()

In [None]:
# New Guinea target points for every non-zero time step, derived with
# generate_samples_polygon from the 0 Ma points file written by the previous cell.
# Filtering (instead of list.remove(0)) does not raise when 0 is absent.
time_steps_ng = [t for t in time_steps if t != 0]

target_points_ng_coreg_in_files_lst = []
for time in time_steps_ng:
    target_points_ng_coreg_in_files_lst.append(coreg_input_dir + coreg_input_files[3] + f'_{time}_Ma.csv')

for time_ng, target_points_ng_file in zip(time_steps_ng, target_points_ng_coreg_in_files_lst):
    if os.path.isfile(target_points_ng_file):
        continue  # already cached
    # generate target points with the configured plate motion model
    target_points_ng = generate_samples_polygon(target_points_ng_0_file, time_ng, plate_motion_model=plate_motion_model)
    # save the attributes of target points
    target_points_ng.to_csv(target_points_ng_file, index=False, float_format='%.4f')
    print(f'Target points saved to {target_points_ng_file}')

target_points_ng_coreg_in_lst = [
    pd.read_csv(target_points_ng_file, index_col=False)
    for target_points_ng_file in target_points_ng_coreg_in_files_lst
]

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps_ng):
    """Plot the New Guinea target points at one non-zero time step.

    Parameters
    ----------
    time : element of ``time_steps_ng``
        Reconstruction time; ``@interact`` turns the list into a dropdown widget.

    Raises
    ------
    ValueError
        If the global ``plate_motion_model`` is neither 'muller2016' nor
        'muller2019' (no age-grid file name is known for it).
    """
    # Resolve the age grid first so an unsupported model fails with a clear
    # message instead of an UnboundLocalError on ``agegrid_file``.
    if plate_motion_model == 'muller2016':
        agegrid_file = agegrid_dir + f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_file = agegrid_dir + f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
    else:
        raise ValueError(f'Unsupported plate motion model: {plate_motion_model!r}')

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    target_points = target_points_ng_coreg_in_lst[time_steps_ng.index(time)]

    # single colour bar
    fig = plt.figure(figsize=(6, 8))
    ax = plt.axes(projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    ax.scatter(
        target_points['lon'],
        target_points['lat'],
        transform=ccrs.PlateCarree(),
        marker='.',
        c='green',
        s=0.1,
        zorder=8
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

    fig.colorbar(im, orientation='horizontal', shrink=0.4, pad=0.05, label='Seafloor Age (Ma)', extend='max')

    # Proxy artists for the legend
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),
        Line2D([0], [0], marker='.', markerfacecolor='green', markeredgecolor='none', markersize=10, linestyle='None', label='Target Points')
    ]

    # legend placed outside the map, to the right
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.4, 1), borderaxespad=0.)

    plt.show()

### Coregistration and Data Wrangling

In [None]:
# Run coregistration for all time steps, then load the coregistered tables:
# positives (coreg_input_files[0]), unlabelled (coreg_input_files[1]) and the
# per-time-step back-arc (index 2) and New Guinea (index 3) target points.
coreg_output_dir = parameters['coreg_output_dir']
positive_data_file = coreg_output_dir + coreg_input_files[0]
unlabelled_data_file = coreg_output_dir + coreg_input_files[1]

target_points_coreg_out_files_lst = []
target_points_ng_coreg_out_files_lst = []
for time in time_steps:
    target_points_coreg_out_files_lst.append(coreg_output_dir + coreg_input_files[2] + f'_{time}_Ma.csv')
    target_points_ng_coreg_out_files_lst.append(coreg_output_dir + coreg_input_files[3] + f'_{time}_Ma.csv')

coregistration(
    coreg_input_dir,
    coreg_output_dir,
    coreg_input_files,
    conv_dir,
    conv_prefix,
    conv_ext,
    time_steps=time_steps,
    search_radius=3
)

positive_data = pd.read_csv(positive_data_file, index_col=False)
unlabelled_data = pd.read_csv(unlabelled_data_file, index_col=False)

target_points_coreg_out_lst = [
    pd.read_csv(out_path, index_col=False) for out_path in target_points_coreg_out_files_lst
]
target_points_ng_coreg_out_lst = [
    pd.read_csv(out_path, index_col=False) for out_path in target_points_ng_coreg_out_files_lst
]

In [None]:
ml_input_dir = parameters['ml_input_dir']

# Tag each population: 1 = positive samples, 0 = unlabelled samples.
positive_data['label'] = 1
unlabelled_data['label'] = 0

positive_features = positive_data[selected_features]
unlabelled_features = unlabelled_data[selected_features]
features_all = pd.concat([positive_features, unlabelled_features], ignore_index=True)

# Spearman correlation between all selected features, cached as CSV.
corr_file = ml_input_dir + 'correlation.csv'

if not os.path.isfile(corr_file):
    corr = features_all.corr(method='spearman').round(3)
    corr.to_csv(corr_file, index=True)
else:
    corr = pd.read_csv(corr_file, index_col=0)

# Displayed by the notebook as a colour-coded table.
corr.style.background_gradient(cmap='coolwarm', axis=None).format('{:.3}')

In [None]:
# Relabel the correlation matrix with display names (in place: later cells,
# e.g. the dendrogram, see these names) and draw it as a heat map.
selected_features_names = parameters['selected_features_names_nounit_01']
corr.index = selected_features_names
corr.columns = selected_features_names

f = plt.figure(figsize=(20, 15))
plt.matshow(corr, fignum=f.number, cmap='coolwarm', vmin=-1, vmax=1)

corr_tick_names = corr.select_dtypes(['number']).columns
corr_tick_locs = range(len(corr_tick_names))
plt.xticks(corr_tick_locs, corr_tick_names, fontsize=14, rotation=90)
plt.yticks(corr_tick_locs, corr_tick_names, fontsize=14)

cb = plt.colorbar(aspect=50)
cb.ax.tick_params(labelsize=14)

# plt.savefig(
#     f'./figures/muller2019/correlation.png',
#     bbox_inches='tight',
#     pad_inches=0.1,
#     dpi=150
#     )

In [None]:
def analyze_correlations(file_path, threshold=0.7):
    """Find feature pairs whose absolute correlation reaches a threshold.

    Parameters
    ----------
    file_path : str
        Path to a correlation matrix saved as CSV, with feature names in the
        header row and in the first column (read with ``index_col=0``).
    threshold : float, default 0.7
        Minimum absolute correlation for a pair to be reported.

    Returns
    -------
    dict
        Maps each feature (column) that has at least one strong correlation to
        ``{'positive': [...], 'negative': [...]}``. Each list holds
        ``(other_feature, value)`` tuples; positives are sorted by descending
        value and negatives by ascending value (strongest first in both).
        Self-correlations and NaN entries are ignored; features without any
        strong correlation are omitted.
    """
    corr_matrix = pd.read_csv(file_path, index_col=0)

    correlations = {}

    for column in corr_matrix.columns:
        positive_corr = []
        negative_corr = []

        # Iterate by label: ``series[i]`` on a string-indexed Series relied on a
        # positional fallback that pandas 2.x deprecates and later removes.
        for other, value in corr_matrix[column].items():
            # abs(NaN) >= threshold is False, so missing values are skipped here
            if other != column and abs(value) >= threshold:
                if value > 0:
                    positive_corr.append((other, value))
                else:
                    negative_corr.append((other, value))

        if positive_corr or negative_corr:
            correlations[column] = {
                'positive': sorted(positive_corr, key=lambda x: x[1], reverse=True),
                'negative': sorted(negative_corr, key=lambda x: x[1])
            }

    return correlations

def generate_report(correlations, threshold):
    """Print a readable summary of the output of ``analyze_correlations``.

    Parameters
    ----------
    correlations : dict
        Mapping of feature -> ``{'positive': [...], 'negative': [...]}`` where
        each list holds ``(other_feature, value)`` tuples.
    threshold : float
        The threshold used to build ``correlations``; shown in the header.
    """
    print(f"Correlation Analysis Report (Threshold: {threshold})")
    print("=" * 50)

    sections = (
        ('positive', "Positive Correlations:"),
        ('negative', "Negative Correlations:"),
    )

    for feature_name, groups in correlations.items():
        print(f"\nFeature: {feature_name}")
        print("-" * 30)

        for key, heading in sections:
            entries = groups[key]
            if not entries:
                continue
            print(heading)
            for other, value in entries:
                print(f"  {other}: {value:.3f}")

    print("\nTotal features with strong correlations:", len(correlations))

# Main execution: list feature pairs with |Spearman correlation| >= threshold,
# read from the cached correlation CSV.
threshold = 0.7
file_path = ml_input_dir + 'correlation.csv'

correlations = analyze_correlations(file_path, threshold=threshold)
generate_report(correlations, threshold)

In [None]:
# Hierarchical clustering of the features using the distance d = 1 - |corr|.
# The training-data cell drops features whose |Spearman correlation| exceeds 0.7,
# so that threshold is marked on the plot.
corr_threshold = 0.7

# Convert correlations to distances
distance_matrix = 1 - np.abs(corr)

# Perform hierarchical clustering
linkage_matrix = hierarchy.linkage(
    hierarchy.distance.squareform(distance_matrix),
    method='complete'
)

# Create figure
plt.figure(figsize=(12, 8))

# Create dendrogram
dendrogram = hierarchy.dendrogram(
    linkage_matrix,
    labels=corr.columns,
    orientation='right',
    leaf_font_size=12,
    leaf_rotation=0
)

# Get the axes object
ax = plt.gca()

# Get y-axis tick labels
labels = ax.get_yticklabels()

# Colour each leaf label like its branch: 'ivl' lists the leaf labels in plot
# order and 'leaves_color_list' gives the matching branch colours.
leaf_colors = dict(zip(dendrogram['ivl'], dendrogram['leaves_color_list']))
for label in labels:
    label.set_color(leaf_colors[label.get_text()])

# Customize the plot
plt.title('Hierarchical Clustering Dendrogram (Spearman Correlation)', pad=20)
plt.xlabel('|Spearman Correlation|')

# The x-axis is in distance units, so the tick at distance d must be labelled
# |corr| = 1 - d. (Spreading labels 1.0..0.0 evenly over [0, xlim_max] is only
# correct when xlim_max == 1, and scipy pads the limit past the tallest link.)
current_xlim = plt.xlim()
x_lo, x_hi = sorted(current_xlim)
corr_ticks = np.array([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
tick_positions = 1.0 - corr_ticks
# only keep ticks inside the view so setting them does not widen the axis
in_view = (tick_positions >= x_lo - 1e-9) & (tick_positions <= x_hi + 1e-9)
tick_positions = tick_positions[in_view]
tick_labels = [f'{c:.1f}' for c in corr_ticks[in_view]]

# Set new ticks and labels
plt.xticks(tick_positions, tick_labels)

# Threshold line at the distance that corresponds to |corr| = corr_threshold
line_position = 1 - corr_threshold
plt.axvline(x=line_position, color='r', linestyle='--', alpha=0.5)
# x in data units, y in axes fraction (0 = bottom edge of the axes)
plt.annotate(f'{corr_threshold}', xy=(line_position, 0), xycoords=('data', 'axes fraction'),
             xytext=(0, -20), textcoords='offset points',
             ha='center', va='top', color='r',
             arrowprops=dict(arrowstyle='->', color='r'))

# Adjust layout
plt.tight_layout()

# plt.savefig(
#     f'./figures/muller2019/dendrogram.png',
#     bbox_inches='tight',
#     pad_inches=0.1,
#     dpi=150
# )

plt.show()

In [None]:
# Build (or load cached) standardised training and positive-only test sets.
# Column layout of the saved tables: retained features..., 'weight', 'label'.
Xy_train_file = ml_input_dir + 'Xy_train.csv'
# Must be the same name the positive test set is written under below (and that
# the evaluation cells read); otherwise the cache is never found.
Xy_pos_test_file = ml_input_dir + 'Xy_pos_test.csv'

if os.path.isfile(Xy_train_file) and os.path.isfile(Xy_pos_test_file):
    Xy_train = pd.read_csv(Xy_train_file, index_col=False)
    features_list = [e for e in Xy_train.columns.tolist() if e not in ('label', 'weight')]
    Xy_pos_test = pd.read_csv(Xy_pos_test_file, index_col=False)
    print('Training data file already exists!')
else:
    # 'label' was appended last to both tables
    positive_labels = positive_data[positive_data.columns[-1]]
    unlabelled_labels = unlabelled_data[unlabelled_data.columns[-1]]

    positive_weights = positive_data['weight']
    unlabelled_weights = unlabelled_data['weight']

    labels = pd.concat([positive_labels, unlabelled_labels]).reset_index(drop=True)
    weights = pd.concat([positive_weights, unlabelled_weights]).reset_index(drop=True)

    # Drop highly correlated features: for every pair with |Spearman| > 0.7 in the
    # upper triangle, the later column of the pair is removed.
    corr_matrix = features_all.corr(method='spearman').abs()
    corr_upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
    corr_drop = [column for column in corr_upper.columns if any(corr_upper[column] > 0.7)]
    print('List of the features removed due to high correlation with other features:', corr_drop)

    features = features_all.drop(corr_drop, axis=1)
    features_list = features.columns.tolist()
    features_labels = pd.concat([features, weights, labels], axis=1).reset_index(drop=True)
    features_labels_list = features_labels.columns.tolist()

    positive_data = features_labels[features_labels['label'] == 1]
    unlabelled_data = features_labels[features_labels['label'] == 0]

    # "features" here still include the 'weight' column; only 'label' is split off
    positive_features = positive_data[positive_data.columns[:-1]]
    unlabelled_features = unlabelled_data[unlabelled_data.columns[:-1]]

    positive_labels = positive_data[positive_data.columns[-1]]
    unlabelled_labels = unlabelled_data[unlabelled_data.columns[-1]]

    # 80 % of the positives plus all unlabelled samples form the training set;
    # the remaining 20 % of positives are held out as a positive-only test set
    X_pos_train, X_pos_test, y_pos_train, y_pos_test = train_test_split(positive_features, positive_labels, train_size=0.8, random_state=42)
    X_train = np.vstack((X_pos_train, unlabelled_features))
    y_train = np.vstack((y_pos_train.values.reshape(-1, 1), unlabelled_labels.values.reshape(-1, 1)))

    Xy_train_original = np.hstack((X_train, y_train))
    Xy_train_original = pd.DataFrame(Xy_train_original, columns=features_labels_list)
    Xy_train_original.to_csv(ml_input_dir + 'Xy_train_original.csv', index=False)

    Xy_train_features = Xy_train_original[features_list]
    Xy_train_labels = Xy_train_original['label']
    Xy_train_weights = Xy_train_original['weight']

    # fit the scaler on the training features only; weights and labels are not scaled
    st_scaler = StandardScaler()
    X_train = st_scaler.fit_transform(Xy_train_features)
    Xy_train = np.hstack((X_train, Xy_train_weights.values.reshape(-1, 1), Xy_train_labels.values.reshape(-1, 1)))
    Xy_train = pd.DataFrame(Xy_train, columns=features_labels_list)

    Xy_pos_test_original = np.hstack((X_pos_test, y_pos_test.values.reshape(-1, 1)))
    Xy_pos_test_original = pd.DataFrame(Xy_pos_test_original, columns=features_labels_list)
    Xy_pos_test_original.to_csv(ml_input_dir + 'Xy_pos_test_original.csv', index=False)

    Xy_pos_test_features = Xy_pos_test_original[features_list]
    Xy_pos_test_labels = Xy_pos_test_original['label']
    Xy_pos_test_weights = Xy_pos_test_original['weight']

    X_pos_test = st_scaler.transform(Xy_pos_test_features)
    Xy_pos_test = np.hstack((X_pos_test, Xy_pos_test_weights.values.reshape(-1, 1), Xy_pos_test_labels.values.reshape(-1, 1)))
    Xy_pos_test = pd.DataFrame(Xy_pos_test, columns=features_labels_list)

    Xy_train.to_csv(Xy_train_file, index=False)
    Xy_pos_test.to_csv(Xy_pos_test_file, index=False)

    # save the standard scaler model (reloaded by the target-point cells when needed)
    with open(ml_input_dir + 'st_scaler.pkl', 'wb') as f:
        pickle.dump(st_scaler, f)

    # NOTE(review): num_features is assumed to be defined in an earlier cell
    print('\nNumber of features reduced from', num_features, 'to', len(features_list))
    print('Number of positive samples:', positive_features.shape[0])

    print('\nNumber of training samples:', X_train.shape[0])
    print('Number of training labels:', y_train.shape[0])
    print('Number of positive testing samples:', X_pos_test.shape[0])
    print('Number of testing labels:', y_pos_test.shape[0])

In [None]:
# Standardised ML inputs for the back-arc target points at each time step, cached as CSV.
target_points_ml_in_files_lst = []
for time in time_steps:
    target_points_ml_in_files_lst.append(ml_input_dir + coreg_input_files[2] + f'_{time}_Ma.csv')

target_points_ml_in_lst = []
for index, target_points_ml_in_file in enumerate(target_points_ml_in_files_lst):
    if os.path.isfile(target_points_ml_in_file):
        target_points_ml_in_lst.append(pd.read_csv(target_points_ml_in_file, index_col=False))
        continue

    # Select the retained features explicitly in features_list order, so the
    # column names attached after scaling are guaranteed to match the data.
    target_points_ml_in = target_points_coreg_out_lst[index][features_list]

    try:
        target_points_ml_in = st_scaler.transform(target_points_ml_in)
    except (NameError, ValueError):
        # st_scaler is undefined in this session (training data were loaded from
        # cache) or does not fit these features: use the scaler saved with the
        # training data. Narrowed from a bare except so Ctrl-C is not swallowed.
        with open(ml_input_dir + 'st_scaler.pkl', 'rb') as f:
            st_scaler = pickle.load(f)
        target_points_ml_in = st_scaler.transform(target_points_ml_in)

    target_points_ml_in = pd.DataFrame(target_points_ml_in, columns=features_list)
    target_points_ml_in_lst.append(target_points_ml_in)
    target_points_ml_in.to_csv(target_points_ml_in_file, index=False)

In [None]:
# Standardised ML inputs for the New Guinea target points at each time step, cached as CSV.
target_points_ng_ml_in_files_lst = []
for time in time_steps:
    target_points_ng_ml_in_files_lst.append(ml_input_dir + coreg_input_files[3] + f'_{time}_Ma.csv')

target_points_ng_ml_in_lst = []
for index, target_points_ng_ml_in_file in enumerate(target_points_ng_ml_in_files_lst):
    if os.path.isfile(target_points_ng_ml_in_file):
        target_points_ng_ml_in_lst.append(pd.read_csv(target_points_ng_ml_in_file, index_col=False))
        continue

    # Select the retained features explicitly in features_list order, so the
    # column names attached after scaling are guaranteed to match the data.
    target_points_ng_ml_in = target_points_ng_coreg_out_lst[index][features_list]

    try:
        target_points_ng_ml_in = st_scaler.transform(target_points_ng_ml_in)
    except (NameError, ValueError):
        # st_scaler is undefined in this session (training data were loaded from
        # cache) or does not fit these features: use the scaler saved with the
        # training data. Narrowed from a bare except so Ctrl-C is not swallowed.
        with open(ml_input_dir + 'st_scaler.pkl', 'rb') as f:
            st_scaler = pickle.load(f)
        target_points_ng_ml_in = st_scaler.transform(target_points_ng_ml_in)

    target_points_ng_ml_in = pd.DataFrame(target_points_ng_ml_in, columns=features_list)
    target_points_ng_ml_in_lst.append(target_points_ng_ml_in)
    target_points_ng_ml_in.to_csv(target_points_ng_ml_in_file, index=False)

### Machine Learning
#### Positive and Unlabelled Bagging

In [None]:
ml_output_dir = parameters['ml_output_dir']
model_pub_file = ml_output_dir + 'model_pub.pkl'

# Reuse a previously trained PU-bagging model when it can be unpickled.
if os.path.isfile(model_pub_file):
    print('A PUB model file exists. Attempting to load...')
    try:
        with open(model_pub_file, 'rb') as f:
            model_pub = pickle.load(f)
        print('Model loaded successfully.')
    except (ImportError, AttributeError) as e:
        # ImportError also covers ModuleNotFoundError; AttributeError is what
        # pickle raises when a class no longer exists at its recorded location.
        print(f"Error loading the model: {e}")
        print("This may be due to a version mismatch. Proceeding to train a new model.")
        model_pub = None
else:
    print('No existing model found. Proceeding to train a new model.')
    model_pub = None

if model_pub is None:
    # Random Forest base estimator wrapped in a positive-unlabelled bagging classifier
    rf_pub = RandomForestClassifier(n_jobs=-1, random_state=42)
    pub = BaggingPuClassifier(rf_pub, n_jobs=-1, random_state=42)

    n_fold = 10
    n_bayes_iter = 100  # Bayesian optimisation iterations (also sets the plot's x-range)

    # Xy_train columns: features..., 'weight', 'label'; keep 'weight' with the
    # features through the split so it stays aligned with its rows
    features = Xy_train[Xy_train.columns[:-1]]
    labels = Xy_train[Xy_train.columns[-1]]
    X_train, X_test, y_train, y_test = train_test_split(features, labels, train_size=0.8, random_state=42)

    Xy_pub_train = np.hstack((X_train, y_train.values.reshape(-1, 1)))
    Xy_pub_train = pd.DataFrame(Xy_pub_train, columns=Xy_train.columns)
    Xy_pub_train.to_csv(ml_input_dir + 'Xy_pub_train.csv', index=False)

    Xy_pub_test = np.hstack((X_test, y_test.values.reshape(-1, 1)))
    Xy_pub_test = pd.DataFrame(Xy_pub_test, columns=Xy_train.columns)
    Xy_pub_test.to_csv(ml_input_dir + 'Xy_pub_test.csv', index=False)

    # split the 'weight' column off the features
    weights_train = X_train[X_train.columns[-1]]
    X_train = X_train[X_train.columns[:-1]]
    weights_test = X_test[X_test.columns[-1]]
    X_test = X_test[X_test.columns[:-1]]

    # max_samples ranges over 50-90 % of the number of label-0 (unlabelled)
    # training samples; presumably the bag size drawn from them — confirm in pulearn
    n_unlabelled_train = len(y_train) - sum(y_train)
    search_space = {
        'estimator__bootstrap': Categorical([True, False]),  # values for bootstrap can be either True or False
        'estimator__max_depth': Integer(5, 20),  # values of max_depth are integers
        'estimator__max_features': Categorical([None, 'sqrt', 'log2']),
        'estimator__min_samples_leaf': Integer(2, 20),
        'estimator__min_samples_split': Integer(2, 30),
        'estimator__n_estimators': Integer(10, 200),
        'max_samples': Integer(int(0.5 * n_unlabelled_train), int(0.9 * n_unlabelled_train))
    }

    pub_bayes_search = BayesSearchCV(pub,
                                     search_space,
                                     n_iter=n_bayes_iter,
                                     scoring='f1',
                                     n_jobs=4,
                                     cv=n_fold,
                                     verbose=1,
                                     random_state=42)
    pub_bayes_search.fit(X_train, y_train, sample_weight=weights_train)

    # mean cross-validated F1 of each evaluated parameter set, in search order
    optimization_results = pub_bayes_search.cv_results_['mean_test_score']

    model_pub = pub_bayes_search.best_estimator_
    model_pub_acc = pub_bayes_search.best_score_
    print('The highest F1-score during cross validation:', model_pub_acc)

    # save the model
    with open(model_pub_file, 'wb') as f:
        pickle.dump(model_pub, f)

    # Plot the Bayesian optimization progress over all iterations (a fixed
    # xlim of 51 used to hide half of the 100 iterations)
    plt.figure(figsize=(12, 6))
    plt.plot(range(1, len(optimization_results) + 1), optimization_results, marker='o', color='black', markerfacecolor='red')
    plt.xlim(0, len(optimization_results) + 1)
    plt.xlabel('Bayesian Optimization Iteration')
    plt.ylabel('Mean Test F1-Score')  # scoring='f1'
    plt.title('Bayesian Optimization Progress')
    plt.grid(True, linestyle=':')
    plt.tight_layout()
    plt.savefig(ml_output_dir + 'bayesian_optimization_pub_progress.png')
    plt.show()

In [None]:
print(model_pub)

# Held-out PUB split; columns are features..., 'weight', 'label'.
Xy_pub_test_file = ml_input_dir + 'Xy_pub_test.csv'
Xy_pub_test = pd.read_csv(Xy_pub_test_file, index_col=False)

features = Xy_pub_test.iloc[:, :-2]
weights = Xy_pub_test.iloc[:, -2]
labels = Xy_pub_test.iloc[:, -1]

X_pred = model_pub.predict(features)

# Label 0 marks unlabelled samples; these metrics treat all of them as true
# negatives. The confusion matrix is unweighted, the scores are weighted.
cMatrix = confusion_matrix(labels, X_pred)
X_pred_acc, X_pred_pre, X_pred_rec, X_pred_f1 = (
    scorer(labels, X_pred, sample_weight=weights)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)

print('Confusion matrix:\n', cMatrix)
print('Accuracy:', X_pred_acc)
print('Precision:', X_pred_pre)
print('Recall:', X_pred_rec)
print('F1-Score:', X_pred_f1)

In [None]:
# Positive-only hold-out set: every row is labelled 1, so this weighted accuracy
# is the fraction of held-out positives the PU model predicts as positive.
Xy_pos_test = pd.read_csv(ml_input_dir + 'Xy_pos_test.csv', index_col=False)
X_pos_test = Xy_pos_test.iloc[:, :-2]
weights_test = Xy_pos_test.iloc[:, -2]
y_pos_test = Xy_pos_test.iloc[:, -1]

X_pos_pred = model_pub.predict(X_pos_test)
X_pos_pred_acc = accuracy_score(y_pos_test, X_pos_pred, sample_weight=weights_test)
print('Accuracy:', X_pos_pred_acc)

In [None]:
# Relabel the training set with the PU model: samples it predicts as positive
# become positives, and the known positives keep label 1.
Xy_train = pd.read_csv(ml_input_dir + 'Xy_train.csv', index_col=False)
Xy_train_features = Xy_train.iloc[:, :-2]
Xy_train_weights = Xy_train.iloc[:, -2]
Xy_train_labels = Xy_train.iloc[:, -1]
Xy_train_labels_new = model_pub.predict(Xy_train_features)

(Xy_train_labels_new_acc,
 Xy_train_labels_new_pre,
 Xy_train_labels_new_rec,
 Xy_train_labels_new_f1) = (
    scorer(Xy_train_labels, Xy_train_labels_new, sample_weight=Xy_train_weights)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)

# these scores treat every label-0 (unlabelled) training sample as a true negative
print('Accuracy:', Xy_train_labels_new_acc)
print('Precision:', Xy_train_labels_new_pre)
print('Recall:', Xy_train_labels_new_rec)
print('F1-Score:', Xy_train_labels_new_f1)

mask = Xy_train_labels == 1
Xy_train_labels_new[mask.to_numpy()] = 1
Xy_train['label'] = Xy_train_labels_new
Xy_train.to_csv(ml_output_dir + 'Xy_train_new.csv', index=False)

#### Random Forest

In [None]:
model_rf_file = ml_output_dir + 'model_rf.pkl'

# Reuse a previously trained Random Forest when it can be unpickled.
if os.path.isfile(model_rf_file):
    print('A model file exists. Attempting to load...')
    try:
        with open(model_rf_file, 'rb') as f:
            model_rf = pickle.load(f)
        print('Model loaded successfully.')
    except (ImportError, AttributeError) as e:
        # ImportError also covers ModuleNotFoundError; AttributeError is what
        # pickle raises when a class no longer exists at its recorded location.
        print(f"Error loading the model: {e}")
        print("This may be due to a version mismatch. Proceeding to train a new model.")
        model_rf = None
else:
    print('No existing model found. Proceeding to train a new model.')
    model_rf = None

if model_rf is None:
    # Random Forest model structure
    rf = RandomForestClassifier(n_jobs=-1, random_state=42)

    n_fold = 10
    n_bayes_iter = 100  # Bayesian optimisation iterations (also sets the plot's x-range)

    # PU-relabelled training data; columns: features..., 'weight', 'label'
    Xy_train_file = ml_output_dir + 'Xy_train_new.csv'
    Xy_train = pd.read_csv(Xy_train_file, index_col=False)
    features = Xy_train[Xy_train.columns[:-1]]
    labels = Xy_train[Xy_train.columns[-1]]
    X_train, X_test, y_train, y_test = train_test_split(features, labels, train_size=0.8, random_state=42)

    Xy_rf_train = np.hstack((X_train, y_train.values.reshape(-1, 1)))
    Xy_rf_train = pd.DataFrame(Xy_rf_train, columns=Xy_train.columns)
    Xy_rf_train.to_csv(ml_input_dir + 'Xy_rf_train.csv', index=False)

    Xy_rf_test = np.hstack((X_test, y_test.values.reshape(-1, 1)))
    Xy_rf_test = pd.DataFrame(Xy_rf_test, columns=Xy_train.columns)
    Xy_rf_test.to_csv(ml_input_dir + 'Xy_rf_test.csv', index=False)

    # split the 'weight' column off the features
    weights_train = X_train[X_train.columns[-1]]
    X_train = X_train[X_train.columns[:-1]]
    weights_test = X_test[X_test.columns[-1]]
    X_test = X_test[X_test.columns[:-1]]

    search_space = {
        'bootstrap': Categorical([True, False]),  # values for bootstrap can be either True or False
        'max_depth': Integer(5, 20),  # values of max_depth are integers
        'max_features': Categorical([None, 'sqrt', 'log2']),
        'min_samples_leaf': Integer(2, 20),
        'min_samples_split': Integer(2, 30),
        'n_estimators': Integer(10, 200)
    }

    rf_bayes_search = BayesSearchCV(rf,
                                    search_space,
                                    n_iter=n_bayes_iter,
                                    scoring='f1',
                                    n_jobs=4,
                                    cv=n_fold,
                                    verbose=1,
                                    random_state=42)
    rf_bayes_search.fit(X_train, y_train, sample_weight=weights_train)

    # mean cross-validated F1 of each evaluated parameter set, in search order
    optimization_results = rf_bayes_search.cv_results_['mean_test_score']

    model_rf = rf_bayes_search.best_estimator_
    model_rf_acc = rf_bayes_search.best_score_
    print('The highest F1-score during cross validation:', model_rf_acc)

    # save the model
    with open(model_rf_file, 'wb') as f:
        pickle.dump(model_rf, f)

    # Plot the Bayesian optimization progress over all iterations
    plt.figure(figsize=(12, 6))
    plt.plot(range(1, len(optimization_results) + 1), optimization_results, marker='o', color='black', markerfacecolor='red')
    plt.xlim(0, len(optimization_results) + 1)
    plt.xlabel('Bayesian Optimization Iteration')
    plt.ylabel('Mean Test F1-Score')  # scoring='f1'
    plt.title('Bayesian Optimization Progress')
    plt.grid(True, linestyle=':')
    plt.tight_layout()
    plt.savefig(ml_output_dir + 'bayesian_optimization_rf_progress.png')
    plt.show()
else:
    print("Using the loaded model. No new training performed.")
    print("To retrain the model, delete or rename the existing model file and run the script again.")

# Per-tree feature importances, one column per tree. Computed for loaded models
# too: the feature-importance cell reads `importances`, which previously only
# existed after a fresh training run.
estimators = model_rf.estimators_
importances = np.hstack([tree.feature_importances_.reshape(-1, 1) for tree in estimators])

print("Script execution completed.")

In [None]:
print(model_rf)

# Held-out RF split; columns are features..., 'weight', 'label'.
Xy_rf_test_file = ml_input_dir + 'Xy_rf_test.csv'
Xy_rf_test = pd.read_csv(Xy_rf_test_file, index_col=False)

features = Xy_rf_test.iloc[:, :-2]
weights = Xy_rf_test.iloc[:, -2]
labels = Xy_rf_test.iloc[:, -1]

X_pred = model_rf.predict(features)

# unweighted confusion matrix; weighted summary scores
cMatrix = confusion_matrix(labels, X_pred)
X_pred_acc, X_pred_pre, X_pred_rec, X_pred_f1 = (
    scorer(labels, X_pred, sample_weight=weights)
    for scorer in (accuracy_score, precision_score, recall_score, f1_score)
)

print('Confusion matrix:\n', cMatrix)
print('Accuracy:', X_pred_acc)
print('Precision:', X_pred_pre)
print('Recall:', X_pred_rec)
print('F1-Score:', X_pred_f1)

In [None]:
# Positive-only hold-out set: every row is labelled 1, so this weighted accuracy
# is the fraction of held-out positives the RF model predicts as positive.
Xy_pos_test = pd.read_csv(ml_input_dir + 'Xy_pos_test.csv', index_col=False)
X_pos_test = Xy_pos_test.iloc[:, :-2]
weights_test = Xy_pos_test.iloc[:, -2]
y_pos_test = Xy_pos_test.iloc[:, -1]

X_pos_pred = model_rf.predict(X_pos_test)
X_pos_pred_acc = accuracy_score(y_pos_test, X_pos_pred, sample_weight=weights_test)
print('Accuracy:', X_pos_pred_acc)

#### ROC Plot

In [None]:
def roc_plot(y_test, z_test, n_classes, labels_name, average='macro'):
    """Plot one-vs-rest ROC curves per class and print the overall ROC AUC score.

    Parameters
    ----------
    y_test : array-like
        True class label of each sample.
    z_test : ndarray
        Per-class scores indexed as ``z_test[:, i]`` (e.g. ``predict_proba``
        output), with columns in the same order as the sorted distinct labels.
    n_classes : int
        Number of classes (columns) to plot.
    labels_name : sequence of str
        Legend name of each class, in column order.
    average : str
        Averaging mode passed to ``roc_auc_score``.

    Raises
    ------
    ValueError
        If ``y_test`` contains fewer than ``n_classes`` distinct labels.
        ``pd.get_dummies`` only creates columns for labels that are present,
        so those classes would have no indicator column.
    """
    # One indicator column per distinct label, in sorted label order.
    y_test_dummies = pd.get_dummies(y_test).values
    if y_test_dummies.shape[1] < n_classes:
        raise ValueError(
            f'y_test contains {y_test_dummies.shape[1]} distinct label(s) but '
            f'n_classes={n_classes}; every class needs at least one sample.'
        )

    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_dummies[:, i], z_test[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # roc for each class
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic')

    for i in range(n_classes):
        ax.plot(fpr[i], tpr[i], label=f'{labels_name[i]}, AUC = {roc_auc[i]:.4f}')

    ax.legend(loc='best')
    ax.grid(alpha=0.5)
    sns.despine()
    plt.show()
    print('ROC AUC score:', roc_auc_score(y_test_dummies, z_test, average=average))

# ROC curves for the test features evaluated above: one curve per class,
# using the model's class-probability scores.
labels_name = ['Non-mineralised', 'Mineralised']
z_test = model_rf.predict_proba(features)
roc_plot(labels, z_test, len(labels_name), labels_name)

#### Feature Importance

In [None]:
feature_importances_file = ml_output_dir + 'feature_importances_muller2019.csv'

if os.path.isfile(feature_importances_file):
    # Reuse cached rows [Feature, Importance, Variance], most important first.
    feature_importances = pd.read_csv(feature_importances_file, index_col=False).to_numpy().tolist()
else:
    # Map the model's input features back to their display names.
    output_features_index = [selected_features.index(feature) for feature in features_list]
    selected_features_names = parameters['selected_features_names_nounit_01']
    selected_features_names = [selected_features_names[i] for i in output_features_index]

    # `importances` holds one row per feature and one column per estimator.
    importances_mean = importances.mean(axis=1)
    importances_var = importances.var(axis=1)

    # Rank by rounded mean importance (stable, descending). Keep each feature's
    # variance in the same tuple so the Variance column stays aligned after sorting.
    ranked = sorted(
        (
            (feature, round(mean, 4), var)
            for feature, mean, var in zip(selected_features_names, importances_mean, importances_var)
        ),
        key=lambda row: row[1],
        reverse=True,
    )
    feature_importances = [(feature, importance) for feature, importance, _ in ranked]
    feature_importances_df = pd.DataFrame(ranked, columns=['Feature', 'Importance', 'Variance'])

    feature_importances_df.to_csv(feature_importances_file, index=False)

In [None]:
# Per-feature importances in the current order of feature_importances (most to
# least important when it comes from the feature-importance cell), and their
# running total.
sorted_importances = [row[1] for row in feature_importances]
cumulative_importances = np.cumsum(sorted_importances)

# 1-based bar positions.
x_values = list(range(1, len(feature_importances) + 1))

fig = plt.figure(figsize=(4, 4))
ax2 = fig.add_subplot(111)  # left y-axis: individual importances (bars)
ax1 = ax2.twinx()           # right y-axis: cumulative importance (dashed line)
ax2.set_facecolor('whitesmoke')

ax2.bar(x_values, sorted_importances, width=1, facecolor='LightSalmon', edgecolor='gray', alpha=0.5)
ax1.plot(x_values, cumulative_importances, 'k--')

plt.xlim(0.5, len(cumulative_importances) + 0.5)
ax1.set_ylim(0, 1.05)

ax2.set_ylabel('Feature Importance')
ax1.set_ylabel('Cumulative Importance')

In [None]:
# Horizontal bar chart of all feature importances, least important at the bottom.
# Sort a copy: sorting feature_importances in place would leave it ascending and
# change what the cumulative-importance cell plots if that cell is re-run.
feature_importances_asc = sorted(feature_importances, key=lambda x: x[1])
ft_imps = [x[1] for x in feature_importances_asc]
ft_names = [x[0] for x in feature_importances_asc]

fig, ax = plt.subplots()
ax.set_facecolor('whitesmoke')
bar = ax.barh(range(len(ft_imps)), ft_imps)

def gradientbars(bars, data, labels=None):
    """Redraw each bar as a coolwarm gradient image and label the y ticks.

    Parameters
    ----------
    bars : BarContainer returned by ``Axes.barh``; must be non-empty.
    data : bar lengths; each gradient runs from 0 to ``width / max(data)``.
    labels : y tick labels, one per bar. Defaults to the feature names in the
        global ``feature_importances``, in that list's current order.
    """
    if labels is None:
        labels = [x[0] for x in feature_importances]
    ax = bars[0].axes
    # Remember the limits so they can be restored after the imshow calls.
    lim = ax.get_xlim() + ax.get_ylim()
    cmap = plt.get_cmap('coolwarm')
    data_max = max(data)
    for bar in bars:
        bar.set_zorder(1)
        bar.set_facecolor('none')
        bar.set_edgecolor('black')
        x, y = bar.get_xy()
        w, h = bar.get_width(), bar.get_height()
        grad = np.atleast_2d(np.linspace(0, w / data_max, 256))
        ax.imshow(grad, extent=[x, x + w, y, y + h], aspect='auto', zorder=0,
                  norm=mpl.colors.NoNorm(vmin=0, vmax=1), cmap=cmap, alpha=0.8)
    ax.set_yticks(np.arange(0, len(data), 1).tolist())
    ax.set_yticklabels(labels, minor=False)
    ax.axis(lim)
    ax.set_xlabel('Feature Importance')

gradientbars(bar, ft_imps, ft_names)
ax.yaxis.grid(False)

# plt.savefig(
#     f'./figures/muller2019/importances.png',
#     bbox_inches='tight',
#     pad_inches=0.1,
#     dpi=150
#     )

plt.show()

### Histogram

In [None]:
Xy_train_new = pd.read_csv(ml_output_dir + 'Xy_train_new.csv', index_col=False)
Xy_train_original = pd.read_csv(ml_input_dir + 'Xy_train_original.csv', index_col=False)

# Offer only feature columns. For 'label' each class subset is constant, and
# scipy's gaussian_kde fails on constant data (singular covariance).
@interact
def show_map(feature=Xy_train_new.columns.drop('label')):
    """Compare the negative vs positive distributions of one feature.

    Histograms (25 shared bins) and Gaussian KDEs are drawn from the
    standardised data (Xy_train_new). An invisible histogram of the actual
    values (Xy_train_original) provides the "Actual" x-axis.
    """
    fig = plt.figure(figsize=(6, 6))
    ax1 = fig.add_subplot(111)
    ax1.set_facecolor('whitesmoke')
    ax2 = ax1.twiny()  # standardised x-axis, shares y (frequency) with ax1
    ax3 = ax2.twinx()  # probability-density y-axis, shares x with ax2

    negatives = Xy_train_new.loc[Xy_train_new['label'] == 0, feature]
    positives = Xy_train_new.loc[Xy_train_new['label'] == 1, feature]

    # Bin edges shared by both classes so the bars line up.
    min_val = Xy_train_new[feature].min()
    max_val = Xy_train_new[feature].max()
    bin_edges = np.linspace(min_val, max_val, 26)  # 25 bins, 26 edges

    # Histogram for original data (transparent): only sets ax1's x range.
    ax1.hist(Xy_train_original[feature], bins=25, alpha=0.0)

    h1 = ax2.hist(negatives, bins=bin_edges,
                  color='LightSalmon', label='Negative',
                  edgecolor='black', linewidth=1)
    h2 = ax2.hist(positives, bins=bin_edges,
                  color='DarkSeaGreen', label='Positive', alpha=0.8,
                  edgecolor='black', linewidth=1)

    # Evaluate each KDE once and reuse it for the curve and the axis limit.
    kde_x = np.linspace(min_val, max_val, 100)
    density_neg = stats.gaussian_kde(negatives)(kde_x)
    density_pos = stats.gaussian_kde(positives)(kde_x)
    ax3.plot(kde_x, density_neg, color='LightSalmon')
    ax3.plot(kde_x, density_pos, color='DarkSeaGreen')

    # y-axis limits with 10% headroom.
    max_freq = max(np.max(h1[0]), np.max(h2[0]))
    max_density = max(np.max(density_neg), np.max(density_pos))
    ax2.set_ylim(0, max_freq * 1.1)
    ax3.set_ylim(0, max_density * 1.1)

    ax2.yaxis.set_major_locator(plt.MaxNLocator(5))
    ax3.yaxis.set_major_locator(plt.MaxNLocator(5))

    ax2.legend(loc='upper right')

    ax1.set_xlabel(feature + '\n(Actual)')
    ax2.set_xlabel(feature + '\n(Standardised)')
    ax1.set_ylabel('Frequency')
    ax3.set_ylabel('Probability density')

#     plt.savefig(
#         f'./figures/muller2019/histograms/conv_angle_deg_edited.png',
#         bbox_inches='tight',
#         pad_inches=0.1,
#         dpi=150
#     )

    plt.show()

### Scatter Plot

In [None]:
features_1 = Xy_train_new.columns.drop('label')
# Rotate the option lists so the three dropdowns start on different features.
features_2 = deque(features_1)
features_2.rotate()
features_3 = deque(features_2)
features_3.rotate()

def _smoothed_grid(df, x_col, y_col, z_col, num=100, sigma=3):
    """Interpolate df[z_col] onto a num x num grid over the x/y data range, then smooth it.

    Returns ``(grid, extent)``. ``grid[i, j]`` is the value at (x_j, y_i), so rows
    follow the y-axis as ``imshow(origin='lower')`` expects. ``extent`` is
    ``(x_min, x_max, y_min, y_max)``, the same range the grid was sampled on.
    """
    x_min, x_max = df[x_col].min(), df[x_col].max()
    y_min, y_max = df[y_col].min(), df[y_col].max()
    grid_x, grid_y = np.meshgrid(np.linspace(x_min, x_max, num=num),
                                 np.linspace(y_min, y_max, num=num))
    points = np.column_stack((df[x_col].to_numpy(), df[y_col].to_numpy()))
    grid = griddata(points, df[z_col].to_numpy(), (grid_x, grid_y), method='nearest')
    return ndimage.gaussian_filter(grid, sigma=sigma), (x_min, x_max, y_min, y_max)

@interact
def show_map(feature_1=features_1, feature_2=features_2, feature_3=features_3):
    """Scatter feature_1 vs feature_2 over a smoothed surface of feature_3.

    Points and surface are drawn in standardised units (Xy_train_new). Invisible
    copies in actual units (Xy_train_original) drive the "Actual" axes and the
    second colour bar.
    """
    fig = plt.figure(figsize=(6, 6))
    gs = GridSpec(2, 2, hspace=0.4, wspace=0.2, height_ratios=[1, 0.03])
    ax1 = fig.add_subplot(gs[0, :])  # actual x and actual y
    ax2 = ax1.twinx()                 # standardised y, shares x with ax1
    ax3 = ax2.twiny()                 # standardised x, shares y with ax2

    # Actual units: fully transparent (alpha=0); these only set ax1's limits and feed cb1.
    grid_orig, extent_orig = _smoothed_grid(Xy_train_original, feature_1, feature_2, feature_3)
    cb1 = ax1.imshow(grid_orig, extent=extent_orig, origin='lower', aspect='auto',
                     cmap=plt.cm.Spectral_r, alpha=0)
    for label_value, colour in ((0, 'blue'), (1, 'orange')):
        subset = Xy_train_original[Xy_train_original['label'] == label_value]
        ax1.scatter(subset[feature_1], subset[feature_2], 40, marker='.', c=colour, alpha=0)

    # Standardised units: the visible surface and points.
    grid_std, extent_std = _smoothed_grid(Xy_train_new, feature_1, feature_2, feature_3)
    cb2 = ax3.imshow(grid_std, extent=extent_std, origin='lower', aspect='auto',
                     cmap=plt.cm.Spectral_r, alpha=0.7)

    positives = Xy_train_new[Xy_train_new['label'] == 1]
    negatives = Xy_train_new[Xy_train_new['label'] == 0]
    sc3 = ax3.scatter(positives[feature_1], positives[feature_2],
                      100, marker='.', facecolor='mediumseagreen', edgecolor='black')
    sc4 = ax3.scatter(negatives[feature_1], negatives[feature_2],
                      100, marker='.', facecolor='orangered', edgecolor='black', alpha=0.7)

    ax3.legend([sc3, sc4], ['Positive', 'Negative'], loc='best', borderaxespad=0.1, fontsize=8)

    ax1.set_xlabel(feature_1 + '\n(Actual)')
    ax1.set_ylabel(feature_2 + '\n(Actual)')
    ax2.set_ylabel(feature_2 + '\n(Standardised)')
    ax3.set_xlabel(feature_1 + '\n(Standardised)')

    cax1 = fig.add_subplot(gs[1, 1])
    cax2 = fig.add_subplot(gs[1, 0])

    fig.colorbar(cb2, cax=cax2, orientation='horizontal', label=feature_3 + '\n(Standardised)', extend='both')
    fig.colorbar(cb1, cax=cax1, orientation='horizontal', label=feature_3 + '\n(Actual)', extend='both')

#     plt.savefig(
#         f'./figures/muller2019/features_three/dist_nearest_edge_deg.png',
#         bbox_inches='tight',
#         pad_inches=0.1,
#         dpi=150
#     )

    plt.show()

### Boxplot

In [None]:
# One column per (feature, label) pair; 'label' is the pivot key, not a value column.
Xy_train_pivot = Xy_train_new.pivot(columns=['label'])
Xy_train_original_pivot = Xy_train_original.pivot(columns=['label'])
nb_groups1 = Xy_train_new['label'].nunique()
nb_groups2 = Xy_train_original['label'].nunique()

# Offer only feature columns: 'label' is absent from the pivoted frames, so
# selecting it would raise a KeyError.
@interact
def show_map(feature=Xy_train_new.columns.drop('label')):
    """Box plots (5th-95th percentile whiskers) of one feature for negative vs positive samples.

    The left axis shows standardised values (Xy_train_new). An invisible box plot
    of the actual values (Xy_train_original) sets the right-hand "Actual" axis.
    """
    bplot1 = [Xy_train_pivot[feature][var].dropna() for var in Xy_train_pivot[feature]]
    bplot2 = [Xy_train_original_pivot[feature][var].dropna() for var in Xy_train_original_pivot[feature]]
    fig, ax1 = plt.subplots(figsize=(6, 6))
    ax1.set_facecolor('whitesmoke')

    # Define colors for negative and positive samples
    colors = ['LightSalmon', 'DarkSeaGreen']

    # Box plots for the standardised data
    bp1 = ax1.boxplot(bplot1, positions=np.arange(nb_groups1), patch_artist=True,
                      whis=(5, 95), widths=0.2,
                      flierprops=dict(marker='.', markeredgecolor='black', fillstyle=None),
                      medianprops=dict(color='black'))

    for patch, color in zip(bp1['boxes'], colors):
        patch.set_facecolor(color)

    # Transparent box plots for the original data: they only scale the right axis.
    box_param2 = dict(whis=(5, 95), widths=0, patch_artist=True,
                      flierprops=dict(marker='.', markeredgecolor='none', fillstyle=None),
                      medianprops=dict(color='none'), whiskerprops=dict(color='none'),
                      boxprops=dict(facecolor='none', edgecolor='none'))
    ax2 = ax1.twinx()
    ax2.boxplot(bplot2, positions=np.arange(nb_groups2), **box_param2)

    labelsize = 12
    ax1.set_xticks(np.arange(nb_groups1))
    ax1.set_xticklabels(['Negative', 'Positive'])
    ax1.tick_params(axis='x', labelsize=labelsize)
    ax1.tick_params(axis='y', labelsize=labelsize)
    ax2.tick_params(axis='y', labelsize=labelsize)

    label_fmt = dict(size=12, labelpad=15)
    ax1.set_xlabel(feature, **label_fmt)
    ax1.set_ylabel(feature + '\n(Standardised)', **label_fmt)
    ax2.set_ylabel(feature + '\n(Actual)', **label_fmt)

    plt.show()

### Violin Plot

In [None]:
def calculate_whiskers(data):
    """Return the Tukey fences ``(low, high)`` of *data*.

    The fences lie 1.5 interquartile ranges below the 25th percentile and above
    the 75th percentile. Values outside them are treated as outliers.
    """
    lower_quartile, upper_quartile = np.percentile(data, [25, 75])
    margin = 1.5 * (upper_quartile - lower_quartile)
    return lower_quartile - margin, upper_quartile + margin

# pivot dataframes: one column per (feature, label) pair
Xy_train_new_pivot = Xy_train_new.pivot(columns=['label'])
Xy_train_original_pivot = Xy_train_original.pivot(columns=['label'])

# calculate the number of unique groups
nb_groups1 = Xy_train_new['label'].nunique()
nb_groups2 = Xy_train_original['label'].nunique()

colors = ['LightSalmon', 'DarkSeaGreen']

# Offer only feature columns: 'label' is the pivot key and is absent from the
# pivoted frames, so selecting it would raise a KeyError.
@interact
def show_map(feature=Xy_train_new.columns.drop('label')):
    """Violin plots of one feature for negative vs positive samples.

    Points outside the Tukey fences (``calculate_whiskers``) are marked in red.
    The left axis shows standardised values. An invisible violin plot of the
    actual values sets the right-hand "Actual" axis.
    """
    vplot1_data = [Xy_train_new_pivot[feature][var].dropna() for var in Xy_train_new_pivot[feature]]
    vplot2_data = [Xy_train_original_pivot[feature][var].dropna() for var in Xy_train_original_pivot[feature]]

    fig, ax1 = plt.subplots(figsize=(6, 6))
    ax1.set_facecolor('whitesmoke')

    vplot1_parts = ax1.violinplot(vplot1_data, positions=np.arange(nb_groups1))
    for i, part in enumerate(vplot1_parts['bodies']):
        part.set_facecolor(colors[i])
        part.set_edgecolor(colors[i])
    for key in ('cbars', 'cmins', 'cmaxes'):
        vplot1_parts[key].set_edgecolor('black')

    # Fully transparent violins on a twin axis: they only scale the right axis.
    ax2 = ax1.twinx()
    vplot2_parts = ax2.violinplot(vplot2_data, positions=np.arange(nb_groups2))
    for part in vplot2_parts['bodies']:
        part.set_facecolor('none')
        part.set_edgecolor('none')
    for key in ('cbars', 'cmins', 'cmaxes'):
        vplot2_parts[key].set_edgecolor('none')

    # Mark the outliers of each standardised violin.
    for i, data in enumerate(vplot1_data):
        low, high = calculate_whiskers(data)
        outliers = data[(data > high) | (data < low)]
        ax1.scatter([i] * len(outliers), outliers, facecolor='red', edgecolor='black', s=20)

    ax1.set_xticks(np.arange(nb_groups1))
    ax1.set_xticklabels(['Negative', 'Positive'])

    ax1.set_ylabel(feature + '\n(Standardised)')
    ax2.set_ylabel(feature + '\n(Actual)')

    plt.show()

### Box-Violin Plot

In [None]:
# pivot dataframes: one column per (feature, label) pair
Xy_train_pivot = Xy_train_new.pivot(columns=['label'])
Xy_train_original_pivot = Xy_train_original.pivot(columns=['label'])
nb_groups1 = Xy_train_new['label'].nunique()
nb_groups2 = Xy_train_original['label'].nunique()

# Offer only feature columns: 'label' is the pivot key and is absent from the
# pivoted frames, so selecting it would raise a KeyError.
@interact
def show_combined_plot(feature=Xy_train_new.columns.drop('label')):
    """Box plots drawn over violin plots of one feature for negative vs positive samples.

    The left axis shows standardised values (Xy_train_new). An invisible box plot
    of the actual values (Xy_train_original) sets the right-hand "Actual" axis.
    """
    plot_data1 = [Xy_train_pivot[feature][var].dropna() for var in Xy_train_pivot[feature]]
    plot_data2 = [Xy_train_original_pivot[feature][var].dropna() for var in Xy_train_original_pivot[feature]]

    fig, ax1 = plt.subplots(figsize=(8, 6))
    ax1.set_facecolor('whitesmoke')

    colors = ['LightSalmon', 'DarkSeaGreen']

    # Violin plots first, so the box plots are drawn on top of them.
    vplot1_parts = ax1.violinplot(plot_data1, positions=np.arange(nb_groups1))
    for i, body in enumerate(vplot1_parts['bodies']):
        body.set_facecolor(colors[i])
        body.set_edgecolor(colors[i])
    for key in ('cbars', 'cmins', 'cmaxes'):
        vplot1_parts[key].set_edgecolor('black')

    # Box plots: 5th-95th percentile whiskers in solid black, caps hidden.
    bp1 = ax1.boxplot(plot_data1, positions=np.arange(nb_groups1), patch_artist=True,
                      whis=(5, 95), widths=0.2,
                      flierprops=dict(marker='.', markersize=8, markerfacecolor='red', markeredgecolor='black', fillstyle=None),
                      medianprops=dict(color='black', linewidth=1.5),
                      whiskerprops=dict(color='black', linestyle='-'),
                      capprops=dict(visible=False))
    for patch, color in zip(bp1['boxes'], colors):
        patch.set_facecolor(color)

    # Transparent box plots for the original data: they only scale the right axis.
    ax2 = ax1.twinx()
    box_param2 = dict(whis=(5, 95), widths=0, patch_artist=True,
                      flierprops=dict(marker='.', markeredgecolor='none', fillstyle=None),
                      medianprops=dict(color='none'), whiskerprops=dict(color='none'),
                      boxprops=dict(facecolor='none', edgecolor='none'))
    ax2.boxplot(plot_data2, positions=np.arange(nb_groups2), **box_param2)

    ax1.set_xticks(np.arange(nb_groups1))
    ax1.set_xticklabels(['Negative', 'Positive'])

    ax1.set_ylabel(feature + '\n(Standardised)')
    ax2.set_ylabel(feature + '\n(Actual)')

    plt.tight_layout()
    plt.show()

### Probability in Back-Arc Basins

In [None]:
# One CSV of scaled mineralisation probabilities per reconstruction time step.
target_points_prob_files_lst = [ml_output_dir + f'target_points_prob_{time}_Ma.csv' for time in time_steps]

# Predict and cache every time step whose CSV does not exist yet.
for i, target_points_prob_file in enumerate(target_points_prob_files_lst):
    if os.path.isfile(target_points_prob_file):
        continue

    df = target_points_ml_in_lst[i]
    # Column 1 of predict_proba: presumably the positive (mineralised) class — assumes classes_ == [0, 1].
    probs = model_rf.predict_proba(df)[:, 1].reshape(-1, 1)

    # Min-max scale within this time step so its probabilities span [0, 1].
    mm_scaler1 = MinMaxScaler()
    probs_scaled = mm_scaler1.fit_transform(probs)

    df_xy = df.copy()
    df_xy['lon'] = target_points_coreg_out_lst[i]['lon'].to_numpy()
    df_xy['lat'] = target_points_coreg_out_lst[i]['lat'].to_numpy()
    df_xy['prob'] = probs_scaled
    df_xy.to_csv(target_points_prob_file, index=False)

# Load each time step exactly once (newly written files included), so that
# target_points_prob_lst[k] always corresponds to time_steps[k].
target_points_prob_lst = [pd.read_csv(f, index_col=False) for f in target_points_prob_files_lst]

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps):
    """Map the scaled mineralisation probability of the target points at `time` Ma.

    Drawn from bottom to top: seafloor age grid, continents, ridges,
    plate-motion vectors, target-point probabilities, mineral occurrences,
    trenches and subduction teeth.

    Raises
    ------
    ValueError
        If the global ``plate_motion_model`` has no known age-grid file name.
    """
    # (lon, lat, weight) of each occurrence at `time`. Occurrences younger than
    # `time` get NaN coordinates and are not drawn.
    lons_lats_recon = []
    for _, occ in min_occ_data.iterrows():
        if time == 0:
            lons_lats_recon.append((occ['lon'], occ['lat'], occ['weight']))
            continue
        age = int(occ['age'])
        if age < time:
            lons_lats_recon.append((np.nan, np.nan, np.nan))
        elif age == time:
            # Presumably pre-computed reconstruction at the occurrence's own age.
            lons_lats_recon.append((occ['lon_recon'], occ['lat_recon'], occ['weight']))
        else:
            # Reconstruct with the same plate_motion_model that selects the age grid below.
            lat_lon_recon = get_recon_ccords([occ['lon']],
                                             [occ['lat']],
                                             plate_motion_model=plate_motion_model,
                                             time=time)[0]
            lons_lats_recon.append((lat_lon_recon[1], lat_lon_recon[0], occ['weight']))

    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    agegrid_names = {
        'muller2016': f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc',
        'muller2019': f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc',
    }
    if plate_motion_model not in agegrid_names:
        raise ValueError(f'No age grid configured for plate motion model {plate_motion_model!r}')
    agegrid = gplately.grids.read_netcdf_grid(agegrid_dir + agegrid_names[plate_motion_model])

    target_points_prob = target_points_prob_lst[time_steps.index(time)]

    # dual colour bars
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.75, wspace=0.1, height_ratios=[1, 0.01])
    ax = fig.add_subplot(gs[0, :], projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    sc0 = ax.scatter(
        target_points_prob['lon'],
        target_points_prob['lat'],
        transform=ccrs.PlateCarree(),
        marker='.',
        c=target_points_prob['prob'],
        s=30,
        cmap=ccm.hawaii_r,
        zorder=5
    )

    # Marker area proportional to the occurrence weight.
    ax.scatter(
        [coords[0] for coords in lons_lats_recon],
        [coords[1] for coords in lons_lats_recon],
        transform=ccrs.PlateCarree(),
        marker='o',
        facecolor='yellow',
        edgecolor='black',
        s=[coords[2] * 20 for coords in lons_lats_recon],
        alpha=0.7,
        zorder=6
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=7)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=8)

    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])

    fig.colorbar(sc0, cax=cax2, orientation='horizontal', label='Mineralisation probability')
    fig.colorbar(im, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

    # Custom legend handles for layers that have no automatic legend entry.
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),
        Line2D([0], [0], marker='o', markerfacecolor='yellow', markeredgecolor='black', markersize=5, linestyle='None', label='Mineral Occurrence')
    ]
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.42, 1), borderaxespad=0.)

    ax.set_title(f'Porphyry Mineralisation Probability {time} Ma')

    plt.show()

In [None]:
mask_coords_lst = [pd.read_csv(mask_coords_file, index_col=False) for mask_coords_file in mask_coords_files_lst]

proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps):
    """Map the scaled mineralisation probability at `time` Ma as a raster on the mask grid.

    Cells whose ``include`` flag is falsy are NaN, so they are left unfilled.
    Age grid, continents, ridges, plate-motion vectors, mineral occurrences,
    trenches and subduction teeth are overlaid as in the point map.

    Raises
    ------
    ValueError
        If the global ``plate_motion_model`` has no known age-grid file name.
    """
    # (lon, lat, weight) of each occurrence at `time`. Occurrences younger than
    # `time` get NaN coordinates and are not drawn.
    lons_lats_recon = []
    for _, occ in min_occ_data.iterrows():
        if time == 0:
            lons_lats_recon.append((occ['lon'], occ['lat'], occ['weight']))
            continue
        age = int(occ['age'])
        if age < time:
            lons_lats_recon.append((np.nan, np.nan, np.nan))
        elif age == time:
            # Presumably pre-computed reconstruction at the occurrence's own age.
            lons_lats_recon.append((occ['lon_recon'], occ['lat_recon'], occ['weight']))
        else:
            # Reconstruct with the same plate_motion_model that selects the age grid below.
            lat_lon_recon = get_recon_ccords([occ['lon']],
                                             [occ['lat']],
                                             plate_motion_model=plate_motion_model,
                                             time=time)[0]
            lons_lats_recon.append((lat_lon_recon[1], lat_lon_recon[0], occ['weight']))

    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    agegrid_names = {
        'muller2016': f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc',
        'muller2019': f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc',
    }
    if plate_motion_model not in agegrid_names:
        raise ValueError(f'No age grid configured for plate motion model {plate_motion_model!r}')
    agegrid = gplately.grids.read_netcdf_grid(agegrid_dir + agegrid_names[plate_motion_model])

    mask_coords = mask_coords_lst[time_steps.index(time)]
    probs = target_points_prob_lst[time_steps.index(time)]['prob'].to_numpy()

    # The target points are the included mask cells, in mask order. Scatter their
    # probabilities back onto the full mask and leave excluded cells as NaN.
    include = mask_coords['include'].to_numpy(dtype=bool)
    probabilities = np.full(include.shape, np.nan)
    probabilities[include] = probs[:include.sum()]

    nx = mask_coords['lon'].nunique()
    ny = mask_coords['lat'].nunique()

    x_min = mask_coords['lon'].min()
    x_max = mask_coords['lon'].max()
    y_min = mask_coords['lat'].min()
    y_max = mask_coords['lat'].max()

    # NOTE(review): assumes rows are ordered lon-fastest with lat ascending (to match
    # origin='lower'); the extent uses cell-centre coordinates as edges — confirm.
    probabilities_2d = probabilities.reshape(ny, nx)

    # dual colour bars
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.75, wspace=0.1, height_ratios=[1, 0.01])
    ax = fig.add_subplot(gs[0, :], projection=proj)

    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im0 = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

    im1 = ax.imshow(
        probabilities_2d,
        transform=ccrs.PlateCarree(),
        origin='lower',
        cmap=ccm.hawaii_r,
        interpolation='nearest',
        extent=(x_min, x_max, y_min, y_max),
        zorder=5
    )

    # Marker area proportional to the occurrence weight.
    ax.scatter(
        [coords[0] for coords in lons_lats_recon],
        [coords[1] for coords in lons_lats_recon],
        transform=ccrs.PlateCarree(),
        marker='o',
        facecolor='yellow',
        edgecolor='black',
        s=[coords[2] * 20 for coords in lons_lats_recon],
        alpha=0.7,
        zorder=6
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=7)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=8)

    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])

    fig.colorbar(im1, cax=cax2, orientation='horizontal', label='Mineralisation probability')
    fig.colorbar(im0, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

    # Custom legend handles for layers that have no automatic legend entry.
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),
        Line2D([0], [0], marker='o', markerfacecolor='yellow', markeredgecolor='black', markersize=5, linestyle='None', label='Mineral Occurrence')
    ]
    ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.42, 1), borderaxespad=0.)

    ax.set_title(f'Porphyry Mineralisation Probability {time} Ma')

    plt.show()

### Probability in New Guinea

In [None]:
# One CSV of scaled mineralisation probabilities per reconstruction time step (New Guinea points).
target_points_ng_prob_files_lst = [ml_output_dir + f'target_points_ng_prob_{time}_Ma.csv' for time in time_steps]

# Predict and cache every time step whose CSV does not exist yet.
for i, target_points_ng_prob_file in enumerate(target_points_ng_prob_files_lst):
    if os.path.isfile(target_points_ng_prob_file):
        continue

    df = target_points_ng_ml_in_lst[i]
    # Column 1 of predict_proba: presumably the positive (mineralised) class — assumes classes_ == [0, 1].
    probs = model_rf.predict_proba(df)[:, 1].reshape(-1, 1)

    # Min-max scale within this time step so its probabilities span [0, 1].
    mm_scaler1 = MinMaxScaler()
    probs_scaled = mm_scaler1.fit_transform(probs)

    df_xy = df.copy()
    df_xy['lon'] = target_points_ng_coreg_out_lst[i]['lon'].to_numpy()
    df_xy['lat'] = target_points_ng_coreg_out_lst[i]['lat'].to_numpy()
    df_xy['prob'] = probs_scaled
    df_xy.to_csv(target_points_ng_prob_file, index=False)

# Load each time step exactly once (newly written files included), so that
# target_points_ng_prob_lst[k] always corresponds to time_steps[k].
target_points_ng_prob_lst = [pd.read_csv(f, index_col=False) for f in target_points_ng_prob_files_lst]

In [None]:
# Lambert azimuthal equal-area projection centred on 150°E, 0°N, used for the map below.
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(time=time_steps):
    lons_lats_recon = []
    
    for min_occ in min_occ_data.iterrows():
        if time == 0:
            lons_lats_recon.append((min_occ[1]['lon'], min_occ[1]['lat'], min_occ[1]['weight']))
        elif int(min_occ[1]['age']) < time:
            lons_lats_recon.append((np.nan, np.nan, np.nan))
        elif int(min_occ[1]['age']) == time:
            lons_lats_recon.append((min_occ[1]['lon_recon'], min_occ[1]['lat_recon'], min_occ[1]['weight']))
        else:
            lat_lon_recon = get_recon_ccords([min_occ[1]['lon']],
                                             [min_occ[1]['lat']],
                                             plate_motion_model='muller2019',
                                             time=time)[0]
            lons_lats_recon.append(tuple((lat_lon_recon[1], lat_lon_recon[0], min_occ[1]['weight'])))
            
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)
    
    if plate_motion_model == 'muller2016':
        agegrid_file = agegrid_dir + f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_file = agegrid_dir + f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)
    
    plot_x = target_points_ng_prob_lst[time_steps.index(time)]['lon']
    plot_y = target_points_ng_prob_lst[time_steps.index(time)]['lat']
    
    # dual colour bars
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.75, wspace=0.1, height_ratios=[1, 0.01])
    ax = fig.add_subplot(gs[0, :], projection=proj)
    
    set_ax(ax, target_extent_map, 10, 5, stock_img=False, order=9)

    im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=230, alpha=0.5, zorder=1)

    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)
    
    sc0 = ax.scatter(
        plot_x,
        plot_y,
        transform=ccrs.PlateCarree(),
        marker='.',
        c=target_points_ng_prob_lst[time_steps.index(time)]['prob'],
        s=30,
        cmap=ccm.hawaii_r,
        zorder=5
    )
    
    sc1 = ax.scatter(
        [coords[0] for coords in lons_lats_recon],
        [coords[1] for coords in lons_lats_recon],
        transform=ccrs.PlateCarree(),
        marker='o',
        facecolor='yellow',
        edgecolor='black',
        s=[x * 20 for x in [coords[2] for coords in lons_lats_recon]],
        alpha=0.7,
        zorder=6
    )

    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=7)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=8)

    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])
    
    fig.colorbar(sc0, cax=cax2, orientation='horizontal', label='Mineralisation probability')
    fig.colorbar(im, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')
    
    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),  # Custom handle for the line (ridge)
        Line2D([0], [0], marker='o', markerfacecolor='yellow', markeredgecolor='black', markersize=5, linestyle='None', label='Mineral Occurrence')  # Custom handle for mineral occurrences
    ]

    # Add the custom legend to the plot
    legend = ax.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.42, 1), borderaxespad=0.)
    
    ax.set_title(f'Porphyry Mineralisation Probability {time} Ma')
    
    plt.show()

In [None]:
target_points_ng_probs_file = ml_output_dir + 'target_points_ng_prob.csv'

if not os.path.isfile(target_points_ng_probs_file):
    # Column names of the per-time-step probability table ('0_ma', '1_ma', ...),
    # derived from time_steps instead of a hard-coded range(31).
    step_columns = [f'{time}_ma' for time in time_steps]

    # For every time step (by position in time_steps, which is how the per-step
    # lists are built), map a point's persistent 'index' id to its row position
    # in that step's table. Built once per step; this replaces a list rebuild
    # plus a linear `.index()` search for every (point, step) pair.
    # setdefault keeps the first match, like list.index() did.
    position_lookups = []
    step_prob_values = []
    for step in range(len(time_steps)):
        lookup = {}
        for position, point_id in enumerate(target_points_ng_coreg_out_lst[step]['index'].tolist()):
            lookup.setdefault(point_id, position)
        position_lookups.append(lookup)
        step_prob_values.append(target_points_ng_prob_lst[step]['prob'].to_numpy())

    # Probability of every present-day point at every time step; NaN where the
    # point is absent from that step's coregistered table.
    prob_lst = []
    for point_id in target_points_ng_0['index']:
        probs = []
        for lookup, values in zip(position_lookups, step_prob_values):
            position = lookup.get(point_id)
            probs.append(np.nan if position is None else values[position])
        prob_lst.append(probs)

    prob_df = pd.DataFrame(prob_lst, columns=step_columns, index=target_points_ng_0.index, dtype=float)

    # pd.concat aligns on the index, so both parts share target_points_ng_0's
    # index; reset afterwards so the frame matches what is read back from CSV.
    target_points_ng_probs = pd.concat(
        [target_points_ng_0[['lon', 'lat']], prob_df], axis=1
    ).reset_index(drop=True)

    # Collapse each point's time series to one value: the median of its
    # probabilities above 0.5 if any exist, otherwise the median of all its
    # non-NaN probabilities; points without any probability get 0.
    prob_by_step = target_points_ng_probs[step_columns]
    high_median = prob_by_step.where(prob_by_step > 0.5).median(axis=1)
    target_points_ng_probs['prob'] = high_median.fillna(prob_by_step.median(axis=1)).fillna(0)

    target_points_ng_probs.to_csv(target_points_ng_probs_file, index=False)
else:
    target_points_ng_probs = pd.read_csv(target_points_ng_probs_file, index_col=False)

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

# call the PlotTopologies object (present day)
gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=0)

if plate_motion_model == 'muller2016':
    agegrid_file = agegrid_dir + 'Muller_etal_2016_AREPS_v1.17_AgeGrid-0.nc'
elif plate_motion_model == 'muller2019':
    agegrid_file = agegrid_dir + 'Muller_etal_2019_Tectonics_v2.0_AgeGrid-0.nc'
else:
    # Fail loudly: otherwise a stale global `agegrid_file` set by an earlier
    # cell (possibly for another time step) would be read silently.
    raise ValueError(f'Unsupported plate_motion_model: {plate_motion_model!r}')

agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

plot_x = target_points_ng_probs['lon']
plot_y = target_points_ng_probs['lat']

# Map extent [lon_min, lon_max, lat_min, lat_max] from the bounds of the first
# polygon in target_polygon_gdf (positional, so independent of index labels).
ng_extent_bounds = target_polygon_gdf.bounds
first_bounds = ng_extent_bounds.iloc[0]
ng_extent = [first_bounds['minx'], first_bounds['maxx'],
             first_bounds['miny'], first_bounds['maxy']]

# map on top, two horizontal colour bars side by side underneath
fig = plt.figure(figsize=(9, 12))
gs = GridSpec(2, 2, hspace=-0.68, wspace=0.1, height_ratios=[1, 0.01])
ax = fig.add_subplot(gs[0, :], projection=proj)

set_ax(ax, ng_extent, 10, 5, stock_img=False, order=9)

im = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

# aggregated probability of each present-day target point
sc0 = ax.scatter(
    plot_x,
    plot_y,
    transform=ccrs.PlateCarree(),
    marker='.',
    c=target_points_ng_probs['prob'],
    s=10,
    cmap=ccm.hawaii_r,
    zorder=5
)

# known mineral occurrences, marker area scaled by their weight
sc1 = ax.scatter(
    min_occ_data['lon'],
    min_occ_data['lat'],
    transform=ccrs.PlateCarree(),
    marker='o',
    facecolor='yellow',
    edgecolor='black',
    s=20 * min_occ_data['weight'],
    alpha=0.7,
    zorder=6
)

gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=6)
gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=7)

cax1 = fig.add_subplot(gs[1, 0])
cax2 = fig.add_subplot(gs[1, 1])

fig.colorbar(sc0, cax=cax2, orientation='horizontal', label='Mineralisation Probability')
fig.colorbar(im, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

# Define custom legend handles (proxy artists for items without a label)
custom_handles = [
    Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # filled continent polygons
    Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),  # ridge lines
    Line2D([0], [0], marker='o', markerfacecolor='yellow', markeredgecolor='black', markersize=5, linestyle='None', label='Mineral Occurrence')  # occurrence markers
]

# Add the custom legend to the plot
legend = ax.legend(handles=custom_handles, loc='lower left')

ax.set_title('Porphyry Mineralisation Probability')

# plt.savefig(
# f'./figures/muller2019/target_points.png',
# bbox_inches='tight',
# pad_inches=0.1,
# dpi=150
# )

plt.show()

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(central_longitude=150, central_latitude=0)

# present-day plate topologies
gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=0)

if plate_motion_model == 'muller2019':
    agegrid_file = agegrid_dir + 'Muller_etal_2019_Tectonics_v2.0_AgeGrid-0.nc'
elif plate_motion_model == 'muller2016':
    agegrid_file = agegrid_dir + 'Muller_etal_2016_AREPS_v1.17_AgeGrid-0.nc'

agegrid = gplately.grids.read_netcdf_grid(agegrid_file)

# Scatter the aggregated point probabilities back onto the full rectangular
# grid: cells whose 'include' flag is false become NaN, and the k-th included
# cell takes the value with label k in target_points_ng_probs['prob'].
prob_values = target_points_ng_probs['prob']
probabilities = []
count = 0
for include in mask_coords_ng_0['include']:
    if not include:
        probabilities.append(np.nan)
        continue
    probabilities.append(prob_values[count])
    count += 1

grid_lons = mask_coords_ng_0['lon']
grid_lats = mask_coords_ng_0['lat']
nx, ny = grid_lons.nunique(), grid_lats.nunique()
x_min, x_max = grid_lons.min(), grid_lons.max()
y_min, y_max = grid_lats.min(), grid_lats.max()

# NOTE(review): reshaping to (ny, nx) assumes mask_coords_ng_0 is ordered by
# latitude then longitude (longitude varying fastest) — confirm upstream.
probabilities_2d = np.reshape(probabilities, (ny, nx))
probabilities_2d_ud = np.flipud(np.reshape(probabilities, (ny, nx)))

# map on top, two horizontal colour bars side by side underneath
fig = plt.figure(figsize=(9, 12))
gs = GridSpec(2, 2, hspace=-0.68, wspace=0.1, height_ratios=[1, 0.01])
ax = fig.add_subplot(gs[0, :], projection=proj)

set_ax(ax, ng_extent, 5, 3, stock_img=False, order=9)

im0 = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)

gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
gplot.plot_ridges(ax, color='red', alpha=0.5, zorder=3)
gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)

# the probability grid, drawn on the current axes (the map axes above)
im1 = plt.imshow(
    probabilities_2d,
    transform=ccrs.PlateCarree(),
    origin='lower',
    cmap=ccm.hawaii_r,
    interpolation='nearest',
    extent=(x_min, x_max, y_min, y_max),
    zorder=5
)

sc1 = ax.scatter(
    min_occ_data['lon'],
    min_occ_data['lat'],
    transform=ccrs.PlateCarree(),
    marker='o',
    facecolor='yellow',
    edgecolor='black',
    s=20 * min_occ_data['weight'],
    alpha=0.7,
    zorder=6
)

gplot.plot_trenches(ax, color='k', alpha=0.5, zorder=7)
gplot.plot_subduction_teeth(ax, spacing=0.01, color='k', alpha=0.5, zorder=8)

cax1 = fig.add_subplot(gs[1, 0])
cax2 = fig.add_subplot(gs[1, 1])

fig.colorbar(im1, cax=cax2, orientation='horizontal', label='Mineralisation probability')
fig.colorbar(im0, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

# proxy artists for the legend
custom_handles = [
    Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),
    Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),
    Line2D([0], [0], marker='o', markerfacecolor='yellow', markeredgecolor='black', markersize=5, linestyle='None', label='Mineral Occurrence'),
]
legend = ax.legend(handles=custom_handles, loc='lower left')

ax.set_title('Porphyry Mineralisation Probability')

# plt.savefig(
# f'./figures/muller2019/prob-pa/prob_ng.png',
# bbox_inches='tight',
# pad_inches=0.1,
# dpi=150
# )

plt.show()

In [None]:
target_points_ng_prob_map_file = ml_output_dir + 'target_points_ng_prob.tif'

if not os.path.isfile(target_points_ng_prob_map_file):
    # Export the present-day probability grid to a GeoTIFF in EPSG:4326.
    # x_min/x_max and y_min/y_max are the outermost grid-point coordinates,
    # i.e. the centres of the edge pixels. The pixel size is therefore the
    # point spacing, and the top-left corner lies half a pixel outside them.
    # (Previously the size was fixed at 0.2 and the corner sat on the first
    # point, shifting the raster half a pixel south-east.) With a single
    # column/row the spacing cannot be measured, so 0.2 is used as before.
    pixel_w = (x_max - x_min) / (nx - 1) if nx > 1 else 0.2
    pixel_h = (y_max - y_min) / (ny - 1) if ny > 1 else 0.2
    geotransform = (x_min - pixel_w / 2, pixel_w, 0, y_max + pixel_h / 2, 0, -pixel_h)

    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(target_points_ng_prob_map_file, nx, ny, 1, gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError(f'GDAL could not create {target_points_ng_prob_map_file}')
    dataset.SetGeoTransform(geotransform)
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    dataset.SetProjection(srs.ExportToWkt())
    # the flipped grid is written so the first raster row is the northernmost one
    dataset.GetRasterBand(1).WriteArray(probabilities_2d_ud)
    dataset.FlushCache()
    dataset = None  # releasing the dataset closes the file

In [None]:
# Load the geotiff file (mineral prospectivity map) using rioxarray
# Load the GeoTIFF (mineral prospectivity map); masked=True turns nodata into NaN
prospectivity_data = rxr.open_rasterio(ml_output_dir + 'target_points_ng_prob.tif', masked=True).squeeze()

# Load the shapefile (mineral occurrences)
mineral_occurrences = gpd.read_file('../GIS/min_occ_ng.shp')

# Probability of the nearest raster cell at each known occurrence; occurrences
# that fall on NaN cells are dropped.
occurrence_probabilities = np.array(
    [prospectivity_data.sel(x=geom.x, y=geom.y, method='nearest').item()
     for geom in mineral_occurrences.geometry],
    dtype=float,
)
occurrence_probabilities = occurrence_probabilities[~np.isnan(occurrence_probabilities)]

# Valid (non-NaN) cells of the map; their count is the study area in cells
cell_values = np.asarray(prospectivity_data.values, dtype=float)
cell_values = cell_values[~np.isnan(cell_values)]
total_area = cell_values.size
total_occurrences = occurrence_probabilities.size
if total_area == 0 or total_occurrences == 0:
    raise ValueError('P-A plot needs at least one valid map cell and one occurrence on the map')

# Define the probability thresholds
thresholds = np.linspace(0, 1, 100)

# Percentage of occurrences / of study area at or above each threshold
occurrence_percentages = np.array(
    [np.count_nonzero(occurrence_probabilities >= t) for t in thresholds]
) / total_occurrences * 100
area_percentages = np.array(
    [np.count_nonzero(cell_values >= t) for t in thresholds]
) / total_area * 100

# The two curves cross (the area axis is inverted) where
# occurrence% == 100 - area%. Take the threshold closest to that crossing
# instead of hard-coding it, so the annotation stays correct when the model changes.
best = int(np.argmin(np.abs(occurrence_percentages - (100 - area_percentages))))
intersection_x = thresholds[best]
intersection_y1 = occurrence_percentages[best]
intersection_y2 = area_percentages[best]

# Set the font sizes
SMALL_SIZE = 12
MEDIUM_SIZE = 14
LARGE_SIZE = 16

# Create the P-A plot
fig, ax1 = plt.subplots(figsize=(10, 6))

# Plot the occurrence percentages
ax1.plot(thresholds, occurrence_percentages, 'b-', linewidth=2)
ax1.set_xlabel('Probability', fontsize=MEDIUM_SIZE)
ax1.set_ylabel('Percentage of Known Mineral Occurrences', color='b', fontsize=MEDIUM_SIZE)
ax1.set_ylim(0, 100)
ax1.set_xlim(0, 1)
ax1.grid(True)

# Increase tick label sizes
ax1.tick_params(axis='both', labelsize=SMALL_SIZE)

# Create a second y-axis to plot the area percentages
ax2 = ax1.twinx()
ax2.plot(thresholds, area_percentages, 'r-', linewidth=2)
ax2.set_ylabel('Percentage of Study Area', color='r', fontsize=MEDIUM_SIZE)
ax2.set_ylim(100, 0)  # Inverted y-axis
ax2.tick_params(axis='y', labelsize=SMALL_SIZE)

# Mark the intersection with horizontal and vertical lines
h_line = ax1.hlines(y=intersection_y1, xmin=0, xmax=1, color='g', linestyles=':')
v_line = ax1.vlines(x=intersection_x, ymin=0, ymax=100, color='g', linestyles=':')

# Add a text box to explain the intersection point
ax1.annotate(f'At probability {intersection_x:.2f}:\n{intersection_y1:.0f}% of occurrences\n{intersection_y2:.0f}% of area',
             xy=(0.72, 0.82), xycoords='axes fraction',
             ha='center', va='top',
             fontsize=SMALL_SIZE,
             bbox=dict(boxstyle='round', fc='white', ec='gray', alpha=0.8))

plt.tight_layout()  # Adjust layout to prevent clipping of labels

# Adjust the plot margins to make room for annotations
plt.subplots_adjust(left=0.15, right=0.85)

# plt.savefig(
# f'./figures/muller2019/prob-pa/pa_ng.png',
# bbox_inches='tight',
# pad_inches=0.1,
# dpi=150
# )

plt.show()

### Coregistration of Mineral Occurrences

In [None]:
min_occ_prob_dir = parameters['min_occ_prob_dir']

# one raw-feature file and one standardised-feature file per mineral occurrence
min_occ_prob_files_lst = [
    min_occ_prob_dir + f'min_occ_features_{index}.csv' for index in min_occ_data['index']
]
min_occ_prob_tran_files_lst = [
    min_occ_prob_dir + f'min_occ_features_tran_{index}.csv' for index in min_occ_data['index']
]

coregistration_point(
    min_occ_data,
    conv_dir,
    conv_prefix,
    conv_ext,
    min_occ_prob_dir,
    file_prefix='min_occ_features',
    time_steps=time_steps,
    search_radius=3,
    plate_motion_model='muller2019'
)

for min_occ_prob_file, min_occ_prob_tran_file in zip(min_occ_prob_files_lst, min_occ_prob_tran_files_lst):
    if os.path.isfile(min_occ_prob_tran_file):
        continue  # already processed in an earlier run

    # Reuse the scaler of this session if it exists, otherwise load the pickled
    # copy saved with the ML inputs. Only the "not defined" case is handled, so
    # genuine transform errors are no longer swallowed by a bare except.
    # NOTE: pickle.load can execute code; only load trusted files.
    if 'st_scaler' not in globals():
        with open(ml_input_dir + 'st_scaler.pkl', 'rb') as f:
            st_scaler = pickle.load(f)

    min_occ_prob = pd.read_csv(min_occ_prob_file, index_col=False)
    min_occ_prob_tran = min_occ_prob.copy()

    # Rows with any missing feature get a NaN probability. All complete rows
    # are standardised and scored in one batch, and the standardised values
    # are written back to the same rows (previously they were written to rows
    # whose 'age' equalled the row label, which only matched by coincidence).
    features = min_occ_prob[features_list]
    complete = features.notnull().all(axis=1).to_numpy()
    probs = np.full(len(min_occ_prob), np.nan)
    if complete.any():
        features_std = st_scaler.transform(features.to_numpy()[complete])
        min_occ_prob_tran.loc[complete, features_list] = features_std
        probs[complete] = model_rf.predict_proba(features_std)[:, 1]

    # min-max rescale this occurrence's probability series to [0, 1] (NaNs kept)
    mm_scaler2 = MinMaxScaler()
    probs_scaled = mm_scaler2.fit_transform(probs.reshape(-1, 1)).ravel()
    min_occ_prob['prob'] = probs_scaled
    min_occ_prob_tran['prob'] = probs_scaled
    min_occ_prob.to_csv(min_occ_prob_file, index=False)
    min_occ_prob_tran.to_csv(min_occ_prob_tran_file, index=False)

### Probability Changes

In [None]:
@interact
def show_map(file1=min_occ_prob_files_lst, file2=min_occ_prob_tran_files_lst, feature=features_list):
    """Compare an occurrence's probability history with one feature through time.

    Left panel: probability (red) and the raw feature values (blue) read from
    ``file1``. Right panel: the same from ``file2`` (standardised features).
    In both panels a dotted line marks the occurrence's ``age`` from
    ``min_occ_data``, and the age axis runs from the oldest age down to 0.
    """
    raw_table = pd.read_csv(file1, index_col=False)
    std_table = pd.read_csv(file2, index_col=False)

    fig = plt.figure(figsize=(12, 4))

    def draw_panel(position, table, file_list, chosen_file, suffix):
        # probability on the primary axis, feature on a twin y-axis
        prob_ax = fig.add_subplot(position, xlim=[table['age'].max(), 0])
        feature_ax = prob_ax.twinx()

        prob_ax.plot(table['age'], table['prob'], c='red')

        formation_age = min_occ_data.iloc[file_list.index(chosen_file)]['age']
        prob_ax.vlines(x=formation_age, ymin=0, ymax=1, color='k', linestyles=':')

        prob_ax.set_ylim(0, 1)
        feature_ax.plot(table['age'], table[feature], c='blue')

        prob_ax.set_ylabel('Probability')
        feature_ax.set_ylabel(feature + suffix)

    draw_panel(121, raw_table, min_occ_prob_files_lst, file1, ' (Actual)')
    draw_panel(122, std_table, min_occ_prob_tran_files_lst, file2, ' (Standardised)')

    fig.tight_layout()

    plt.show()

#### Smoothed Plot

In [None]:
# smoothing function for the plots
def smooth_data(x, y, points=200, k=2):
    """Resample the curve (x, y) on a dense grid with an interpolating B-spline.

    Parameters
    ----------
    x, y : array-like
        Sample coordinates. ``x`` need not be sorted: the pairs are sorted by
        ``x`` first, because ``make_interp_spline`` requires increasing x
        (previously unsorted input raised a ValueError).
    points : int, optional
        Number of evenly spaced x values in the output (default 200).
    k : int, optional
        Spline degree (default 2, the previously hard-coded value).

    Returns
    -------
    x_new, y_smooth : numpy.ndarray
        ``points`` evenly spaced values from ``min(x)`` to ``max(x)`` and the
        spline evaluated at them.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    order = np.argsort(x, kind='stable')
    x, y = x[order], y[order]

    x_new = np.linspace(x.min(), x.max(), points)
    spl = make_interp_spline(x, y, k=k)  # b-spline
    y_smooth = spl(x_new)
    return x_new, y_smooth

@interact
def show_map(file1=min_occ_prob_files_lst, feature=features_list):
    """Plot the smoothed probability and one feature of an occurrence against age.

    Rows containing NaN or infinite values are dropped before both curves are
    resampled with ``smooth_data``. The dotted vertical line marks the
    occurrence's ``age`` from ``min_occ_data``.
    """
    history = pd.read_csv(file1, index_col=False)
    history = history.replace([np.inf, -np.inf], np.nan).dropna()

    fig = plt.figure(figsize=(6, 4))
    prob_ax = fig.add_subplot(111, xlim=[history['age'].max(), 0])  # age axis runs old -> young
    feature_ax = prob_ax.twinx()

    prob_ax.set_facecolor('whitesmoke')

    # probability curve on the left axis
    ages_dense, prob_dense = smooth_data(history['age'], history['prob'])
    prob_ax.plot(ages_dense, prob_dense, c='orangered')

    formation_age = min_occ_data.iloc[min_occ_prob_files_lst.index(file1)]['age']
    prob_ax.vlines(x=formation_age, ymin=0, ymax=1, color='k', linestyles=':', label='Age of Formation')

    prob_ax.set_ylim(0, 1)

    # selected feature on the right axis
    ages_dense, feature_dense = smooth_data(history['age'], history[feature])
    feature_ax.plot(ages_dense, feature_dense, c='royalblue')

    # legend holds only the dotted formation-age line
    feature_ax.legend(
        handles=[Line2D([0], [0], color='k', linestyle=':', label='Age of formation')],
        loc='lower left',
    )

    prob_ax.set_xlabel('Age (Ma)')
    prob_ax.set_ylabel('Mineralisation probability', color='orangered')
    # NOTE(review): this label is fixed and does not follow the selected `feature`
    feature_ax.set_ylabel('Orthogonal component of the\nrelative motion vector (cm/yr)', color='royalblue')

    fig.tight_layout()

#     plt.savefig(
#         f'./figures/muller2019/panguna/conv_ortho_cm_yr.png',
#         bbox_inches='tight',
#         pad_inches=0.1,
#         dpi=150
#     )

    plt.show()

### Traceplot

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(150, 0)

@interact
def show_map(file=min_occ_prob_files_lst, time=time_steps):
    """Trace one mineral occurrence's coregistered positions over a probability map.

    Reads the per-occurrence table ``file`` and keeps the rows with
    ``age >= time``. Row 0 of that selection is drawn as a star and the other
    rows as dots. Points are grey when ``valid`` is false, royal blue when
    ``before_mineralisation`` is true, and lime otherwise. The background is
    the seafloor age grid at ``time`` plus the probability grid rebuilt from
    ``mask_coords_lst`` and ``target_points_prob_lst`` for the same time step.

    NOTE(review): presumably the rows are sorted by increasing age, so row 0
    is the occurrence's position at ``time`` (the "last point") — confirm
    against coregistration_point. If no row has ``age >= time``, ``val[0]``
    raises IndexError.
    """
    df = pd.read_csv(file, index_col=False)
    # positions and status flags of all rows at or older than the selected time
    lons = df.loc[df['age'] >= time]['lon'].tolist()
    lats = df.loc[df['age'] >= time]['lat'].tolist()
    bm = df.loc[df['age'] >= time]['before_mineralisation'].tolist()
    val = df.loc[df['age'] >= time]['valid'].tolist()
    
    lons_inval = []
    lats_inval = []
    lons_bm = []
    lats_bm = []
    lons_am = []
    lats_am = []
    
    # colour of the last point
    if not val[0]:
        last_point = 'invalid'
    elif bm[0]:
        last_point = 'before_mineralisation'
    else:
        last_point = 'after_mineralisation'
    
    # split the remaining rows (1..n-1) into invalid / before / after mineralisation
    for index in range(1, len(lons)):
        if not val[index]:
            lons_inval.append(lons[index])
            lats_inval.append(lats[index])
        elif bm[index]:
            lons_bm.append(lons[index])
            lats_bm.append(lats[index])
        else:
            lons_am.append(lons[index])
            lats_am.append(lats[index])
        
    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)

    # no else branch: an unknown model leaves agegrid_file unbound (UnboundLocalError below)
    if plate_motion_model == 'muller2016':
        agegrid_file = agegrid_dir + f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_file = agegrid_dir + f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'
        
    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)
    
    mask_coords = mask_coords_lst[time_steps.index(time)]
    
    probabilities = []
    count = 0
            
    # rebuild the full grid: NaN where 'include' is false, otherwise the next
    # probability (by label `count`) of this time step's point table
    for mask in mask_coords['include']:
        if mask:
            probabilities.append(target_points_prob_lst[time_steps.index(time)]['prob'][count])
            count += 1
        else:
            probabilities.append(np.nan)
    
    nx = mask_coords_lst[time_steps.index(time)]['lon'].nunique()
    ny = mask_coords_lst[time_steps.index(time)]['lat'].nunique()
    
    x_min = mask_coords_lst[time_steps.index(time)]['lon'].min()
    x_max = mask_coords_lst[time_steps.index(time)]['lon'].max()
    y_min = mask_coords_lst[time_steps.index(time)]['lat'].min()
    y_max = mask_coords_lst[time_steps.index(time)]['lat'].max()
    
    # probabilities_2d_ud is computed but not used in this function
    probabilities_2d = np.reshape(probabilities, (ny, nx))
    probabilities_2d_ud = np.flipud(np.reshape(probabilities, (ny, nx)))
    
    # map on top with two colour bars underneath
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.4, wspace=0.1, height_ratios=[1, 0.02])
    ax = fig.add_subplot(gs[0, :], projection=proj)
    
    set_ax(ax, target_extent_anim, 15, 15, stock_img=False, order=10)

    im0 = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)
    
    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges_and_transforms(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)
    
    # plt.imshow draws on the current axes, i.e. the map axes created above
    im1 = plt.imshow(
        probabilities_2d,
        transform=ccrs.PlateCarree(),
        origin='lower',
        cmap=ccm.hawaii_r,
        interpolation='nearest',
        extent=(x_min, x_max, y_min, y_max),
        alpha=0.7,
        zorder=5
    )
    
    # star for row 0, coloured by its status
    if last_point == 'invalid':
        sc = ax.scatter(lons[0], lats[0], transform=ccrs.PlateCarree(), marker='*', facecolor='gray', s=20, zorder=6)
    elif last_point == 'before_mineralisation':
        sc = ax.scatter(lons[0], lats[0], transform=ccrs.PlateCarree(), marker='*', facecolor='royalblue', s=20, zorder=6)
    else:
        sc = ax.scatter(lons[0], lats[0], transform=ccrs.PlateCarree(), marker='*', facecolor='lime', s=20, zorder=6)
    
    # dots for the remaining rows
    sc = ax.scatter(lons_bm, lats_bm, transform=ccrs.PlateCarree(), marker='.', facecolor='royalblue', s=20, zorder=7)
    sc = ax.scatter(lons_am, lats_am, transform=ccrs.PlateCarree(), marker='.', facecolor='lime', s=20, zorder=7)
    sc = ax.scatter(lons_inval, lats_inval, transform=ccrs.PlateCarree(), marker='.', facecolor='gray', s=20, zorder=7)
    
    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=8)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=9)
    
    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])
    
    fig.colorbar(im1, cax=cax2, orientation='horizontal', label='Mineralisation Probability')
    fig.colorbar(im0, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')

    # Define custom legend handles (the grey "invalid" points have no entry)
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),  # Custom handle for the line (ridge)
        Line2D([0], [0], marker='.', markerfacecolor='royalblue', markeredgecolor='none', markersize=10, linestyle='None', label='Pre-Mineralisation'),
        Line2D([0], [0], marker='.', markerfacecolor='lime', markeredgecolor='none', markersize=10, linestyle='None', label='Post-Mineralisation')
    ]

    # Add the custom legend to the plot, above the other artists
    legend = ax.legend(handles=custom_handles, loc='lower left')
    legend.set_zorder(11)
    
    plt.show()

In [None]:
proj = ccrs.LambertAzimuthalEqualArea(central_longitude=150, central_latitude=0)  # equal-area map projection centred on 150E, 0N

@interact
def show_map(file=min_occ_prob_files_lst, time=time_steps):
    df = pd.read_csv(file, index_col=False)
    lons = df.loc[df['age'] >= time]['lon'].tolist()
    lats = df.loc[df['age'] >= time]['lat'].tolist()
    bm = df.loc[df['age'] >= time]['before_mineralisation'].tolist()
    val = df.loc[df['age'] >= time]['valid'].tolist()
        
    # colour of the last point
    if not val[0]:
        last_point = 'invalid'
    elif bm[0]:
        last_point = 'before_mineralisation'
    else:
        last_point = 'after_mineralisation'

    # list of invalid lines
    lons_inval = []
    lats_inval = []
    lons_inval_temp = []
    lats_inval_temp = []

    for index2 in range(len(lons)-1, -1, -1):
        if not val[index2]:
            lons_inval_temp.append(lons[index2])
            lats_inval_temp.append(lats[index2])
            if index2 == 0:
                lons_inval.append(lons_inval_temp)
                lats_inval.append(lats_inval_temp)
        else:
            if len(lons_inval_temp) != 0:
                lons_inval_temp.append(lons[index2])
                lats_inval_temp.append(lats[index2])
                lons_inval.append(lons_inval_temp)
                lats_inval.append(lats_inval_temp)
                lons_inval_temp = []
                lats_inval_temp = []
            else:
                continue

    # list of lines created before mineralisation
    lons_bm = []
    lats_bm = []
    lons_bm_temp = []
    lats_bm_temp = []

    for index2 in range(len(lons)-1, -1, -1):
        if bm[index2] and val[index2]:
            lons_bm_temp.append(lons[index2])
            lats_bm_temp.append(lats[index2])
            if index2 == 0:
                lons_bm.append(lons_bm_temp)
                lats_bm.append(lats_bm_temp)
        else:
            if len(lons_bm_temp) != 0:
                lons_bm_temp.append(lons[index2])
                lats_bm_temp.append(lats[index2])
                lons_bm.append(lons_bm_temp)
                lats_bm.append(lats_bm_temp)
                lons_bm_temp = []
                lats_bm_temp = []
            else:
                continue

    # list of lines created after mineralisation
    lons_am = []
    lats_am = []
    lons_am_temp = []
    lats_am_temp = []

    for index2 in range(len(lons)-1, -1, -1):
        if not bm[index2] and val[index2]:
            lons_am_temp.append(lons[index2])
            lats_am_temp.append(lats[index2])
            if index2 == 0:
                lons_am.append(lons_am_temp)
                lats_am.append(lats_am_temp)
        else:
            if len(lons_am_temp) != 0:
                lons_am_temp.append(lons[index2])
                lats_am_temp.append(lats[index2])
                lons_am.append(lons_am_temp)
                lats_am.append(lats_am_temp)
                lons_am_temp = []
                lats_am_temp = []
            else:
                continue

    # call the PlotTopologies object
    gplot = gplately.PlotTopologies(model, coastlines, continents, cob, time=time)
    
    if plate_motion_model == 'muller2016':
        agegrid_file = agegrid_dir + f'Muller_etal_2016_AREPS_v1.17_AgeGrid-{time}.nc'
    elif plate_motion_model == 'muller2019':
        agegrid_file = agegrid_dir + f'Muller_etal_2019_Tectonics_v2.0_AgeGrid-{time}.nc'

    agegrid = gplately.grids.read_netcdf_grid(agegrid_file)
    
    mask_coords = mask_coords_lst[time_steps.index(time)]
    
    probabilities = []
    count = 0
            
    for mask in mask_coords['include']:
        if mask:
            probabilities.append(target_points_prob_lst[time_steps.index(time)]['prob'][count])
            count += 1
        else:
            probabilities.append(np.nan)
    
    nx = mask_coords_lst[time_steps.index(time)]['lon'].nunique()
    ny = mask_coords_lst[time_steps.index(time)]['lat'].nunique()
    
    x_min = mask_coords_lst[time_steps.index(time)]['lon'].min()
    x_max = mask_coords_lst[time_steps.index(time)]['lon'].max()
    y_min = mask_coords_lst[time_steps.index(time)]['lat'].min()
    y_max = mask_coords_lst[time_steps.index(time)]['lat'].max()
    
    probabilities_2d = np.reshape(probabilities, (ny, nx))
    probabilities_2d_ud = np.flipud(np.reshape(probabilities, (ny, nx)))
    
    # single colour bar
    fig = plt.figure(figsize=(6, 8))
    gs = GridSpec(2, 2, hspace=-0.4, wspace=0.1, height_ratios=[1, 0.02])
    ax = fig.add_subplot(gs[0, :], projection=proj)
    
    set_ax(ax, target_extent_anim, 15, 15, stock_img=False, order=10)

    im0 = gplot.plot_grid(ax, agegrid.data, cmap=ccm.lapaz_r, vmin=0, vmax=153, alpha=0.5, zorder=1)
    
    gplot.plot_continents(ax, edgecolor='gray', facecolor='tan', zorder=2)
    gplot.plot_ridges_and_transforms(ax, color='red', alpha=0.5, zorder=3)
    gplot.plot_plate_motion_vectors(ax, spacingX=10, spacingY=10, normalise=False, regrid_shape=10, alpha=0.2, zorder=4)
    
    im1 = plt.imshow(
        probabilities_2d,
        transform=ccrs.PlateCarree(),
        origin='lower',
        cmap=ccm.hawaii_r,
        interpolation='nearest',
        extent=(x_min, x_max, y_min, y_max),
        alpha=0.7,
        zorder=5
    )
        
    for lons_, lats_ in zip(lons_bm, lats_bm):
        sc = ax.plot(lons_, lats_, transform=ccrs.PlateCarree(), color='royalblue', zorder=6)
        
    for lons_, lats_ in zip(lons_am, lats_am):
        sc = ax.plot(lons_, lats_, transform=ccrs.PlateCarree(), color='lime', zorder=6)
        
    for lons_, lats_ in zip(lons_inval, lats_inval):
        sc = ax.plot(lons_, lats_, transform=ccrs.PlateCarree(), color='gray', zorder=6)
        
    if last_point == 'invalid':
        sc = ax.scatter(lons[0], lats[0], transform=ccrs.PlateCarree(), marker='*', facecolor='gray', s=20, zorder=7)
    elif last_point == 'before_mineralisation':
        sc = ax.scatter(lons[0], lats[0], transform=ccrs.PlateCarree(), marker='*', facecolor='royalblue', s=20, zorder=7)
    else:
        sc = ax.scatter(lons[0], lats[0], transform=ccrs.PlateCarree(), marker='*', facecolor='lime', s=20, zorder=7)
    
    gplot.plot_trenches(ax, color='k', alpha=0.3, zorder=8)
    gplot.plot_subduction_teeth(ax, spacing=0.03, color='k', alpha=0.3, zorder=9)
    
    cax1 = fig.add_subplot(gs[1, 0])
    cax2 = fig.add_subplot(gs[1, 1])
    
    fig.colorbar(im1, cax=cax2, orientation='horizontal', label='Mineralisation Probability')
    fig.colorbar(im0, cax=cax1, orientation='horizontal', label='Seafloor age (Ma)', extend='max')
    
    # Define custom legend handles
    custom_handles = [
        Patch(facecolor='tan', edgecolor='gray', label='Continental Crust'),  # Custom handle for the filled polygon
        Line2D([0], [0], color='red', lw=2, label='Mid-Ocean Ridge'),  # Custom handle for the line (ridge)
        Line2D([0], [0], color='royalblue', lw=2, label='Pre-Mineralisation'),
        Line2D([0], [0], color='lime', lw=2, label='Post-Mineralisation')
    ]

    # Add the custom legend to the plot
    legend = ax.legend(handles=custom_handles, loc='lower left')
    legend.set_zorder(11)
    
    plt.show()