In [None]:
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
import pandas as pd
import pygmt
import pygplates

from gprm.datasets import Reconstructions, Zircons
from gprm.utils.spatial import topology_lookup
#from gprm.utils.create_gpml import gdf2gpml, gpml2gdf

#from collections import OrderedDict

import sys
sys.path.append('..')
import molchan

%load_ext autoreload
%autoreload 2


# Plate reconstruction models. M2021 is the one assigned to MODEL for the analysis below.
M2021 = Reconstructions.fetch_Merdith2021()
L2023 = Reconstructions.fetch_Li2023()

# Sedimentary zircon sample compilation (2021 version).
SedimentaryZircons = Zircons.get_sedimentary_samples(version=2021)


# Continent polygons of M2021 whose valid time is overridden below.
EXTENDED_VALIDITY_POLYGON_NAMES = (
    'Laurentia Parautochthon',
    'Kootenay Terrane (NAM affinity Colpron et al 2007)',
    'Cassiar Terrane (NAM affinity Colpron et al 2007)',
    'Yukon-Tanana Terrane',
    'Baikal-Muya',
)

for polygon_feature in M2021.continent_polygons[0]:
    if polygon_feature.get_name() not in EXTENDED_VALIDITY_POLYGON_NAMES:
        continue
    # Valid time is set to begin=1000, end=-999.
    # NOTE(review): -999 presumably stands for "still valid beyond present day" - confirm.
    polygon_feature.set_valid_time(1000, -999)
        


In [None]:
def prepare_rasters(reconstruction_model, boundary_lookup,
                    reconstruction_times, sampling=0.5, polygon_buffer_distance=None):
    """Build a distance raster sequence combined with a continent raster sequence.

    Two raster sequences are generated for ``reconstruction_times`` and then
    merged with ``molchan.combine_raster_sequences``:

    1. a raster sequence derived from the model's continent polygons, used as
       a mask separating pixels inside continents from those outside;
    2. a distance raster sequence derived from ``boundary_lookup``.

    Parameters
    ----------
    reconstruction_model
        Object exposing ``continent_polygons`` (only element ``[0]`` is used)
        and ``rotation_model``.
    boundary_lookup
        Passed unchanged to ``molchan.generate_distance_raster_sequence``.
    reconstruction_times
        Times for which both raster sequences are generated.
    sampling
        Forwarded as ``sampling`` to both raster generators.
    polygon_buffer_distance
        Forwarded as ``buffer_distance`` to the polygon raster generator.

    Returns
    -------
    Whatever ``molchan.combine_raster_sequences`` returns for the
    (distance sequence, polygon sequence) pair.
    """
    continent_raster_sequence = molchan.generate_raster_sequence_from_polygons(
        reconstruction_model.continent_polygons[0],
        reconstruction_model.rotation_model,
        reconstruction_times,
        sampling=sampling,
        buffer_distance=polygon_buffer_distance,
    )

    boundary_distance_sequence = molchan.generate_distance_raster_sequence(
        boundary_lookup, reconstruction_model, reconstruction_times, sampling=sampling,
    )

    # Distance rasters first, continent rasters second: same argument order as before.
    return molchan.combine_raster_sequences(boundary_distance_sequence,
                                            continent_raster_sequence)


In [None]:
# Snap each feature's age onto the analysis time grid and reconstruct it to that time
def reconstruct_target_features(features, reconstruction_model, analysis_time_step, age_field):
    """Reconstruct features to their age rounded to the analysis time step.

    Parameters
    ----------
    features
        Table of features; must contain the column named by ``age_field``.
    reconstruction_model
        Object providing ``assign_plate_ids`` and
        ``reconstruct_to_time_of_appearance``.
    analysis_time_step
        Spacing of the time grid the ages are rounded to.
    age_field
        Name of the column holding each feature's age.

    Returns
    -------
    The result of ``reconstruction_model.reconstruct_to_time_of_appearance``
    for the retained features, using the ``'AnalysisAge'`` column as the
    reconstruction time.
    """
    # The 'FROMAGE' column used below is expected on the frame returned here.
    # NOTE(review): presumably added because copy_valid_times=True - confirm.
    with_plate_ids = reconstruction_model.assign_plate_ids(features, copy_valid_times=True)

    # Nearest multiple of the time step. np.round rounds exact halves to the
    # even neighbour, e.g. 25 with a step of 10 becomes 20, not 30.
    snapped_age = np.round(with_plate_ids[age_field] / analysis_time_step) * analysis_time_step
    with_plate_ids['AnalysisAge'] = snapped_age

    # Drop features whose rounded age is greater than their FROMAGE.
    age_within_validity = with_plate_ids['AnalysisAge'] <= with_plate_ids['FROMAGE']
    reconstructable = with_plate_ids[age_within_validity]

    return reconstruction_model.reconstruct_to_time_of_appearance(
        reconstructable, ReconstructTime='AnalysisAge')


In [None]:
# Run experiments where the distance of each sample to the nearest target at the time of
# sample formation is determined from the distance raster sequence

# --- Analysis settings -------------------------------------------------------
analysis_time_step = 10.          # spacing of the reconstruction time grid
raster_sampling = 2.              # passed as `sampling` to the raster generators
polygon_buffer_distance = 500e3   # passed as `buffer_distance` for the continent rasters
max_time = 1000.                  # oldest reconstruction time / oldest sample age kept
interpolater = 'pygmt'

# Time grid from 0 to max_time inclusive. The upper bound is derived from max_time
# (it was previously hard-coded as 1001) so the grid follows the setting above.
# np.arange excludes its stop value, hence the half-step padding to keep max_time itself.
reconstruction_times = np.arange(0, max_time + analysis_time_step / 2, analysis_time_step)

MODEL = M2021

# Create a lookup table of the subduction zones reconstructed from topologies
sz_lookup = topology_lookup(MODEL,
                            reconstruction_times,
                            boundary_types=['subduction'])


# Distance-to-subduction-zone rasters combined with the (buffered) continent rasters
sz_distance_dict_mask = prepare_rasters(MODEL, sz_lookup,
                                        reconstruction_times,
                                        sampling=raster_sampling,
                                        polygon_buffer_distance=polygon_buffer_distance)


In [None]:
# Attach a tectonic classification to the sedimentary zircon samples.
ABC = Zircons.tectonic_category(SedimentaryZircons)

# Keep samples whose TectonicClass matches the pattern 'A'
# (str.match anchors at the start of the string).
Azircons = ABC.query("TectonicClass.str.match('A')")

# Restrict to samples with an estimated depositional age no older than max_time,
# then reconstruct them to their age rounded to the analysis time step.
is_within_time_range = Azircons['Est_Depos_Age_Ma'] <= max_time
Analysis_Result = reconstruct_target_features(
    Azircons.loc[is_within_time_range],
    MODEL,
    analysis_time_step,
    age_field='Est_Depos_Age_Ma',
)

Analysis_Result


In [None]:
# Now get the distances from the raster sequence mapped to each point.
# Each reconstructed sample is looked up in the raster sequence using its
# 'AnalysisAge'; the returned table carries a 'distance' column that the
# Molchan test in the next cell consumes.
#
# NOTE: an earlier per-row prototype (DataFrame.apply around
# molchan.molchan_point with interpolater='scipy') used to be parked here inside
# a triple-quoted string. It never executed and has been removed, so the
# `interpolater` chosen in the configuration cell is the one in effect.
results_sz = molchan.space_time_distances(sz_distance_dict_mask,
                                          Analysis_Result,
                                          age_field_name='AnalysisAge',
                                          interpolater=interpolater)

In [None]:
# To get the skill, we need to know not only the distances of the points to the features
# but also how these distances compare to the % of all points across all times within
# the same distance contour
sz_fractions = molchan.space_time_molchan_test(
    sz_distance_dict_mask,
    results_sz['distance'].dropna(),   # samples without a distance value are left out
    healpix_resolution=16,
    #distance_step=1e5,
    interpolater=interpolater)

# Element 2 is printed and reported as the skill; elements 0 and 1 are the
# x and y values of the curve plotted below.
skill = sz_fractions[2]
print(skill)
grid_fraction = sz_fractions[0]
fraction_missed = sz_fractions[1]

fig, ax = plt.subplots()
ax.plot(grid_fraction, fraction_missed, 'r',
        label='Samples to subduction zone  | Skill = {:0.3f}'.format(skill))
ax.grid()
ax.legend()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel('Grid Fraction')
ax.set_ylabel('Fraction of points missed')

#plt.savefig('./images/molchan_space_time_1000Ma.png')
#plt.close()
plt.show()


In [None]:
# Display the table returned by molchan.space_time_distances for inspection.
results_sz


In [None]:
# Scratch check: evaluate molchan.scipy_interpolater on the raster of every
# reconstruction time at a few fixed points, spreading the work over joblib workers.
from joblib import Parallel, delayed

num_cpus = 2

# NOTE(review): pairs look like (longitude, latitude) since the second value
# reaches -88 - confirm against molchan.scipy_interpolater.
probe_points = [(0,-88),(0,-80),(10,-75)]

interpolate_raster = delayed(molchan.scipy_interpolater)
tmp = Parallel(n_jobs=num_cpus)(
    interpolate_raster(sz_distance_dict_mask[reconstruction_time], probe_points)
    for reconstruction_time in reconstruction_times
)

# One row block per reconstruction time.
np.vstack(tmp)

In [None]:
import rasterio
import xarray as xr

#src = rasterio.open(sz_distance_dict_mask)

#da = xr.open_dataarray(sz_distance_dict_mask[])
#da = sz_distance_dict_mask[200]
#da.sel



In [None]:
from gprm.utils.proximity import boundary_proximity

# Continent-polygon rasters for the first five reconstruction times only.
# buffer_distance is not passed here, so the generator's own default applies.
preview_times = reconstruction_times[:5]
brrr = molchan.generate_raster_sequence_from_polygons(
    MODEL.continent_polygons[0],
    MODEL.rotation_model,
    preview_times,
    sampling=raster_sampling,
)


In [None]:
# Work on a deep copy of the raster stored under key 20 so that the sequence in
# `brrr` is not modified by the in-place assignment below.
br = brrr[20].copy(deep=True)
bn = boundary_proximity(br)

# Set every pixel whose boundary proximity is below the buffer distance to 1.
# The threshold reuses `polygon_buffer_distance` from the configuration cell
# (500e3, the value previously hard-coded here as 500000) so the two stay in sync.
# NOTE(review): assumes boundary_proximity returns values in the same units as
# polygon_buffer_distance - confirm.
br.data[bn.data < polygon_buffer_distance] = 1
br.plot()
