In [None]:
from tqdm.notebook import tqdm
import numpy as np
from numba import njit, prange
from math import sqrt

def _loc_max(max_distance, X, Y, Field, loc_threshold=0):
    """Return the reduced set of local maxima ordered by decreasing field value.

    The search itself is delegated to ``_find_reduced_set_of_local_maxima``;
    this wrapper only reorders its output and converts it to NumPy arrays.

    Parameters
    ----------
    max_distance, X, Y, Field, loc_threshold
        Forwarded unchanged to ``_find_reduced_set_of_local_maxima``.
        ``loc_threshold`` defaults to 0.

    Returns
    -------
    tuple of numpy.ndarray
        ``(peak_idx_x, peak_idx_y, peak_x, peak_y, peak_field)``, all ordered
        so that ``peak_field`` goes from largest to smallest.  Note that the
        index arrays come first here, unlike in the helper's return value.
    """
    found = _find_reduced_set_of_local_maxima(max_distance, X, Y, Field, loc_threshold)
    xs, ys, values, cols, rows = found

    # One ascending permutation, shared by all five outputs.
    ascending = np.argsort(values)

    def by_descending_value(sequence):
        # Apply the ascending order, then reverse it to get largest-first.
        return np.asarray(sequence)[ascending][::-1]

    return (
        by_descending_value(cols),
        by_descending_value(rows),
        by_descending_value(xs),
        by_descending_value(ys),
        by_descending_value(values),
    )

In [None]:
@njit()
def _find_all_local_maxima(X, Y, Field):
            
    loc_max_x, loc_max_y, loc_max_field = [], [], []
    idx_x, idx_y = [], []
            
    for i in range(2, X.shape[0]-2):
                
        for j in range(2, Y.shape[1]-2):
                    
            if np.isfinite(Field[i, j]) and Field[i, j] > Field[i+1, j] and Field[i, j] > Field[i-1, j] and Field[i, j] > Field[i, j+1] and Field[i, j] > Field[i, j-1]:
                        
                loc_max_x.append(X[i, j])
                loc_max_y.append(Y[i, j])
                loc_max_field.append(Field[i, j])
                idx_x.append(j)
                idx_y.append(i)
            
    return idx_x, idx_y, loc_max_x, loc_max_y, loc_max_field

In [None]:
@njit()
def _find_reduced_set_of_local_maxima(max_distance, X, Y, Field, loc_threshold):
    """Keep only the local maxima that dominate their neighbourhood.

    All candidates come from ``_find_all_local_maxima(X, Y, Field)``.  A
    candidate is kept when both conditions hold:

    * its field value is strictly greater than ``loc_threshold``;
    * no other candidate with a strictly larger field value lies within a
      Euclidean distance of ``max_distance`` (inclusive), measured on the
      candidates' ``X``/``Y`` values.

    Candidates with exactly equal field values never suppress each other, so
    tied maxima closer than ``max_distance`` are all kept.

    Parameters
    ----------
    max_distance
        Suppression radius, in the same units as ``X`` and ``Y``.
    X, Y, Field
        Passed unchanged to ``_find_all_local_maxima``.
    loc_threshold
        Exclusive lower bound on the field value of a kept maximum.

    Returns
    -------
    tuple of list
        ``(peak_x, peak_y, peak_field, peak_idx_x, peak_idx_y)`` in the order
        the candidates were produced by ``_find_all_local_maxima``.
    """
    idx_x, idx_y, loc_max_x, loc_max_y, loc_max_field = _find_all_local_maxima(X, Y, Field)

    n_loc_max = len(loc_max_x)

    peak_x, peak_y, peak_field = [], [], []
    peak_idx_x, peak_idx_y = [], []

    for i in range(n_loc_max):

        field_i = loc_max_field[i]

        # A candidate at or below the threshold is rejected whatever its
        # neighbours are, so test this first and skip the O(n) scan.
        if not field_i > loc_threshold:
            continue

        x_i = loc_max_x[i]
        y_i = loc_max_y[i]

        dominated = False

        for j in range(n_loc_max):

            # No explicit i != j test: field_i < loc_max_field[i] is never true.
            if field_i < loc_max_field[j] and sqrt((x_i - loc_max_x[j])**2 + (y_i - loc_max_y[j])**2) <= max_distance:

                # One dominating neighbour settles it; stop scanning.
                dominated = True
                break

        if dominated:
            continue

        peak_x.append(x_i)
        peak_y.append(y_i)
        peak_field.append(field_i)
        peak_idx_x.append(idx_x[i])
        peak_idx_y.append(idx_y[i])

    return peak_x, peak_y, peak_field, peak_idx_x, peak_idx_y