# <center>OceanBasin_Manuscript JN</center>
This Jupyter notebook generates a set of figures for the manuscript.

<center>Figure 1: ....</center>
<br>
<br>






# Leiden Detection + Post Processing with Reconstruction (Evolution)

In [None]:
#################################################
#################### Imports ####################
#################################################
import os
import copy as cp
import ExoCcycle as EC
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from netCDF4 import Dataset
import cmcrameri.cm as cmc
import itertools

# Create nodeclustering object
from cdlib import evaluation
from cdlib import NodeClustering


####################################
### Reconstruction period inputs ###
####################################
# Reconstruction ages [Ma]. The full sweep is 0, 5, ..., 80 Ma, i.e.
#     ages = list(range(0, 85, 5))
# but this run only processes present day.
ages = [0]
agestr = list(map(str, ages))

# Minimum basin count per age (the full sweep used 10 for all 17 ages).
minBasinCntV = [10]
minBasinCntVstr = list(map(str, minBasinCntV))



#########################################
### Define Community Detection Inputs ###
#########################################
# Community detection method (alternative: "Leiden-Girvan-Newman").
communityDetectionMethod = "Leiden"

# Basin merging criteria: load the named merger package and silence it.
# (Per the original note, the "Lite" package uses ['threshold'] = [0].)
mergerPackageName = "Lite"
mergerPackage = EC.utils.mergerPackages(mergerPackageName)
mergerPackage['verbose'] = False

# Resolution parameter of the quality function.
resolutions = [.01]

# Minimum number of basins to keep in the output (only used by the
# Girvan-Newman or composite algorithms).
minBasinCnts = [12]

# Ensemble size for the first stage of composite community detection, which
# runs Louvain or Leiden to reduce network complexity. An ensemble larger than
# one makes the community structure robust to the randomness of the initial
# node clustering; ensembles of 100 on 1-degree data add only 1-2 minutes.
ensembleSizes = [50]


# Fields whose weights are plotted; weights must already be calculated for
# these fields.
fieldNums = ["Field2"]
def combine_lists(*lists):
    """
    Combine N input lists into every combination (cartesian product).

    Parameters
    ----------
    *lists : iterables
        The N sequences to combine. Combinations follow ``itertools.product``
        order (the last list varies fastest). Duplicate entries in an input
        list yield duplicate rows.

    Returns
    -------
    A : numpy.ndarray
        Array of shape (number_of_combinations, N). Because one array holds
        every column, mixed int/float inputs are upcast to a common dtype
        (ints become floats), so callers may need to cast values back.
        If any input list is empty the result has shape (0, N).
    """
    combinations = list(itertools.product(*lists))
    # np.array([]) collapses to shape (0,); reshape explicitly so the
    # documented (number_of_combinations, N) shape also holds when a list
    # is empty.
    return np.array(combinations).reshape(len(combinations), len(lists))

# Preview the (resolution, ensembleSize, minBasinCnt, age) combinations the
# detection loop below iterates over; with the inputs above there is one.
parameterGrid = combine_lists(resolutions, ensembleSizes, minBasinCnts, ages)
for resolution, ensembleSize, minBasinCnt, age in parameterGrid:
    print(resolution, ensembleSize, minBasinCnt, age)

# Run community detection for every (resolution, ensembleSize, minBasinCnt, age)
# combination.
for resolution, ensembleSize, minBasinCnt, age in combine_lists(resolutions, ensembleSizes, minBasinCnts, ages):
    # combine_lists packs every parameter into one numpy array, so integer
    # parameters arrive as floats; cast them back before use.
    minBasinCnt = int(minBasinCnt)
    ensembleSize = int(ensembleSize)
    age = int(age)

    # Community detection settings for this combination.
    detectionMethod = {"method": communityDetectionMethod,
                       "resolution": resolution,
                       "minBasinCnt": minBasinCnt,
                       "ensembleSize": ensembleSize,
                       "minBasinLargerThanSmallMergers": True,
                       "mergerPackage": mergerPackage}

    # Edge weight scheme for node connections.
    # Options:
    #    "useGlobalDifference", "useEdgeDifference", "useEdgeGravity"
    #    "useLogistic", "useNormPDFFittedSigmoid", "useQTGaussianSigmoid"
    #    "useQTGaussianShiftedGaussianWeightDistribution"
    edgeWeightMethod = {"method": "useQTGaussianShiftedGaussianWeightDistribution",
                        "shortenFactor": 5,
                        "shiftFactor": .5,
                        "minWeight": 0.01}

    # Make the folder that holds the figures. os.makedirs replaces the IPython
    # `!mkdir -p` shell escape so this cell is also valid plain Python.
    os.makedirs('figures/GMD_Manuscript/CodeOutputs/ReconBathymetry', exist_ok=True)
    fldName = EC.utils.makeFolderSeries(fldBase='figures/GMD_Manuscript/CodeOutputs/ReconBathymetry/Reconstruction_{}Ma-{}-PP'.format(age, communityDetectionMethod))
    print("Storing images in {}".format(fldName))

    # Short readme written next to the images. The shift/shorten factors and
    # the merger-package call are interpolated from the settings actually
    # used (the text previously hard-coded a 1-sigma offset although
    # shiftFactor is 0.5, and named a non-existent EC.utils.mergePackage).
    readmeLines = [
        "Note that the Bathymetry values are shown with a colorbar that represents 1 std that are area weighted.",
        "Using model S <- QTG with useQTGaussianShiftedGaussianWeightDistribution (cdfCenter  = qtDissSTD*{0} and cdfStretch = qtDissSTD/{1}) for edge weights".format(edgeWeightMethod["shiftFactor"], edgeWeightMethod["shortenFactor"]),
        "Using S/distanceV edge weight",
        "Where S = 1-CDF(difference), CDF=cumulative density function.",
        "Where difference = values1-values2, the differnce between node property value after a Quantile Transformation values1-values2.",
        "The CDF used for S is calculated as follows:",
        "The absolute value of node values differences are collected into a vector (dataEdgeDiff).",
        "Outliers removed using the IQR method to make a filtered dataset (dataEdgeDiffIQRFiltered).",
        "dataEdgeDiffIQRFiltered (all positive) is mirror about 0 (symmetric about 0) is converted to a gaussian (z-score space) using a Quantile Transformation.",
        "The distribution (dist1) of differences in gaussian (z-score space) is used to construct a CDF function",
        "A new distribution created from dist1 by offsetting it by {0} sigma_dist1 and shortening it by setting a new 1 sigma_dist2 of sigma_dist1/{1}".format(edgeWeightMethod["shiftFactor"], edgeWeightMethod["shortenFactor"]),
        "Then |difference| of node properties can be expressed in z-score space and S can be calculated as S = 1-CDF(difference)",
        "The minimum weight for this method is set to {}".format(edgeWeightMethod['minWeight']),
        "Using {} algorithm".format(detectionMethod["method"]),
        "{} resolution: {}".format(detectionMethod["method"], detectionMethod["resolution"]),
        "{} ensemble size: {}".format(detectionMethod["method"], detectionMethod["ensembleSize"]),
        "Girvan-Newman minimum unisolated basins: {}".format(detectionMethod['minBasinCnt']),
        "Community merger package is EC.utils.mergerPackages('{}')".format(mergerPackageName),
    ]
    readmetxt = "\n".join(readmeLines)



    #################################################################
    ### Create basin object and set Field for Community detection ###
    #################################################################

    # Planetary body to load (the other choices listed here originally were
    # "Mars", "Venus" and "Moon").
    body = "Earth"

    # Basins object built from the body's resampled 1-degree grid.
    basins = EC.utils.BasinsEA(dataDir=f"{os.getcwd()}/bathymetries/{body}",
                               filename=f"{body}_resampled_1deg.nc",
                               body=body)

    ####################################
    # Add the reconstructed bathymetry for this age as a new field, reusing
    # Field1's resolution, unit and name.
    field1Meta = basins.Fields["Field1"]
    basins.addField(resolution=field1Meta["resolution"],
                    dataGrid=f"{os.getcwd()}/PNAS_Bogumil_Results/bathymetryNCFiles/Bathymetry_{age}Ma.nc",
                    parameter="z",
                    parameterUnit=field1Meta["parameterUnit"],
                    parameterName=field1Meta["parameterName"])

    # Community detection uses Field2 only (presumably the field just added).
    basins.useFields(fieldList=np.array(["Field2"]))

    # Show every stored field, then only those used for community detection.
    for onlyUsed in (False, True):
        basins.getFields(usedFields=onlyUsed)

    # Field mask parameters (not used further below in this cell).
    fieldMaskParameter = {"usedField": 0, "fliprl": False, "flipud": True}

    #########################################
    ### Run Community Detection Algorithm ###
    #########################################

    # Detect basins. For the Louvain-Girvan-Newman composite algorithm,
    # minBasinCnt is the number of basins to keep that are not completely
    # isolated after the Louvain stage.
    basins.defineBasins(detectionMethod=detectionMethod,
                        edgeWeightMethod=edgeWeightMethod,
                        reducedRes={"on": True, "factor": 1},
                        read=False,
                        write=True,
                        verbose=False)

    # Merge communities according to the merger package's criteria.
    basins.applyMergeBasinMethods(mergerID=0, mergerPackage=mergerPackage)

    # Convert the equal-area basin IDs onto the regular lat/lon grid.
    basins.interp2regularGrid(mask=True)

    #####################################
    ### Plot results of community IDs ###
    #####################################
    basinColorOpts = {"cmap": "jet",
                      "cbar-title": "cbar-title",
                      "cbar-range": [0, np.nanmax(basins.BasinIDA)]}
    basinPlotOpts = {"valueType": "BasinID divided by {}".format(basins.Fields["Field1"]['parameterName']),
                     "valueUnits": "-",
                     "plotTitle": "",
                     "plotZeroContour": False,
                     "plotIntegerContours": True,
                     "transparent": True}
    EC.utils.plotGlobal(basins.lat, basins.lon, basins.BasinIDA,
                        outputDir=os.getcwd() + "/" + fldName,
                        fidName="plotGlobal.png",
                        cmapOpts=basinColorOpts,
                        pltOpts=basinPlotOpts,
                        savePNG=True,
                        saveSVG=False)

    for fieldNum in fieldNums:
        # Reconstructed bathymetry for this age. load_netcdf_data is not
        # defined in this cell (presumably elsewhere in the notebook). Each
        # array is thinned by 10 along its first and last axes.
        lon, lat, bathymetry, continents1 = load_netcdf_data(age)
        lon, lat, bathymetry = (grid[::10].T[::10].T for grid in (lon, lat, bathymetry))

        # Area-weighted mean and standard deviation of basins.bathymetry,
        # used below for a +/- 1 std colour range.
        gridRes = basins.Fields["Field1"]['resolution']
        (areaWeights, longitudes, latitudes,
         totalArea, totalAreaCalculated) = EC.utils.areaWeights(
            resolution=gridRes,
            LonStEd=[np.min(basins.lon), np.max(basins.lon) + gridRes],
            LatStEd=[np.min(basins.lat), np.max(basins.lat) + gridRes])
        ave, std = EC.utils.weightedAvgAndStd(basins.bathymetry, areaWeights)

        #########################
        ### Plot input fields ###
        #########################
        fieldInfo = basins.Fields[fieldNum]
        EC.utils.plotGlobal(lat, lon, bathymetry,
                            outputDir=os.getcwd() + "/" + fldName,
                            fidName="plotGlobal_{0}.png".format(fieldInfo['parameterName']),
                            cmapOpts={"cmap": "jet",
                                      "cbar-title": "cbar-title",
                                      "cbar-range": [ave - std, ave + std]},
                            pltOpts={"valueType": "{0}".format(fieldInfo['parameterName']),
                                     "valueUnits": "{}".format(fieldInfo['parameterUnit']),
                                     "plotTitle": "",
                                     "plotZeroContour": False,
                                     "transparent": True},
                            savePNG=True,
                            saveSVG=False)
    
    
    ###########################
    ### Plot DQT-CDF Values ###
    ###########################
    for fieldNum in fieldNums:
        #####################################################
        ####### Find average node neighbor difference #######
        #####################################################
        # For each node, average |value(neighbour) - value(node)| of the
        # current field over its neighbours. Nodes without neighbours (rare)
        # get 0. This is the case the old bare `except:` was catching, but
        # that also hid any other error.
        nodeData = basins.G.nodes
        attrs = {}
        for node in basins.G.nodes:
            diffs = [np.abs(nodeData[conNode][fieldNum] - nodeData[node][fieldNum])
                     for conNode in basins.G.neighbors(node)]
            attrs[node] = {"diffAve": sum(diffs) / len(diffs) if diffs else 0}

        # Store the averages as the "diffAve" node attribute. basins.G is
        # updated in place (nx.set_node_attributes returns None, so nothing
        # is assigned). For a graph without nodes attrs is an empty dict and
        # this is a no-op, whereas the previous `None` start value made
        # networkx raise.
        nx.set_node_attributes(basins.G, attrs)

        # List of (node, diffAve) pairs.
        diffAve_values = list(basins.G.nodes(data="diffAve"))


        #####################################################
        ################ Interpolate to grid ################
        #####################################################
        def interp2regularGrid(basins, dataIrregular=None, mask=True):
            """
            Nearest-neighbour interpolation of irregularly spaced values onto
            the regular latitude/longitude grid of ``basins``.

            Parameters
            -----------
            basins : OBJECT
                Object providing ``G`` (graph whose nodes carry "pos" and
                "diffAve" attributes), ``lat``/``lon`` (2-D regular grids of
                equal shape) and ``maskValue``.
            dataIrregular : NUMPY ARRAY, optional
                Nx3 array whose columns are longitude, latitude, value. The
                default, None, builds it from ``basins.G``: longitude and
                latitude come from the "pos" node attribute (read as
                (lat, lon)) and the value from the "diffAve" attribute.
                Previously this argument was ignored and always overwritten.
            mask : BOOL
                If True (default), cells where ``basins.maskValue`` is NaN are
                set to NaN.

            Returns
            --------
            array : NUMPY ARRAY
                Array with the shape of ``basins.lat`` that holds, for every
                cell, the value of the nearest irregular point.
            """

            import copy as cp

            if dataIrregular is None:
                # One (lon, lat, value) row per node that has a "diffAve"
                # value, in graph node order. Rows used to be indexed by the
                # node key itself, which only worked for integer keys 0..N-1.
                nodeValues = nx.get_node_attributes(basins.G, "diffAve")
                nodePositions = nx.get_node_attributes(basins.G, "pos")
                dataIrregular = np.array(
                    [[nodePositions[node][1], nodePositions[node][0], value]
                     for node, value in nodeValues.items()],
                    dtype=float).reshape(-1, 3)
            else:
                dataIrregular = np.asarray(dataIrregular, dtype=float)

            # Copy basins.lat so the output has the regular grid's shape.
            array = cp.deepcopy(basins.lat)

            # Brute-force nearest neighbour: for each regular cell, compute the
            # haversine distance (unit radius) to every irregular point and
            # take the value of the closest one. Cost is O(cells * points).
            for i, j in np.ndindex(*np.shape(basins.lat)):
                distances = EC.utils.haversine_distance(lat2=dataIrregular[:, 1], lat1=basins.lat[i, j],
                                                        lon2=dataIrregular[:, 0], lon1=basins.lon[i, j],
                                                        radius=1)
                # nanargmin returns the first index of the smallest non-NaN
                # distance, matching the previous argwhere(nanmin == x)[0][0].
                array[i, j] = dataIrregular[np.nanargmin(distances), 2]

            # Apply the mask.
            if mask:
                array[np.isnan(basins.maskValue)] = np.nan

            return array

        # Interpolate the per-node averages from the equal-area graph onto the
        # regular latitude/longitude grid (masked cells become NaN).
        diffAveGrd = interp2regularGrid(basins, mask=True)

        #####################################################
        ####################### Plot ########################
        #####################################################
        # Map every non-NaN cell into z-score space with the field's fitted
        # quantile transformer (fed as a single (n, 1) column); NaN cells are
        # left as NaN.
        qt = basins.Fields[fieldNum]['weightMethodPara']['qt']
        DQT_zscore = np.reshape(cp.deepcopy(diffAveGrd), (np.size(diffAveGrd), 1))
        validCells = ~np.isnan(DQT_zscore)
        transformed = qt.transform(np.reshape(DQT_zscore[validCells], (-1, 1)))
        DQT_zscore[validCells] = np.reshape(transformed, (-1,))
        DQT_zscore = np.reshape(DQT_zscore, np.shape(diffAveGrd))

        # NOTE(review): this file name uses parameterUnit, while the
        # input-field plot names its file by parameterName -- confirm intended.
        EC.utils.plotGlobal(basins.lat, basins.lon, DQT_zscore,
                            outputDir=os.getcwd() + "/" + fldName,
                            fidName="plotGlobal_DQT_{}.png".format(basins.Fields[fieldNum]['parameterUnit']),
                            cmapOpts={"cmap": cmc.batlow,
                                      "cbar-title": "cbar-title",
                                      "cbar-range": [0, 2]},
                            pltOpts={"valueType": "DQT",
                                     "valueUnits": "zscore",
                                     "plotTitle": "",
                                     "plotZeroContour": False,
                                     "transparent": True},
                            savePNG=True,
                            saveSVG=False)

    ##############################
    ### Plot DQT weight values ###
    ##############################
    # Same as the DQT plot above except that it shows edge-weight values.
    # These weights leave out the distance dependence between nodes, which
    # should be negligible.
    def complementCDF(diff,
                      transformer,
                      std,
                      shortenFactor = edgeWeightMethod['shortenFactor'],
                      shiftFactor = edgeWeightMethod['shiftFactor'],
                      minWeight = edgeWeightMethod['minWeight']):
        """
        complementCDF calculates node-pair connection weights from node
        property differences.

        Parameters
        -----------
        diff : FLOAT or ARRAY-LIKE
            Node-pair difference value(s) of a data field. Scalars and arrays
            of any shape are accepted; NaN entries are passed through to the
            transformer unchanged.
        transformer : OBJECT
            Fitted quantile transformation (``.transform`` on an (n, 1)
            column) that converts diff into z-score space.
        std : FLOAT
            Standard deviation used to place the shifted/shortened CDF: its
            centre is std*shiftFactor and its scale is std/shortenFactor.
            The caller passes the field's ``qtDissSTD``.
        shortenFactor : FLOAT
            Factor to shorten the CDF distribution by.
        shiftFactor : FLOAT
            Factor to shift the CDF distribution by.
        minWeight : FLOAT
            A value between 0 and 1 that sets the minimum weight assigned to
            a diff value.

        Returns
        --------
        Ss : NUMPY ARRAY
            Edge weight(s) with the same shape as ``diff``. A 1-element list
            now yields shape (1,) rather than the former (1, 1).
        """
        from scipy import stats
        values = np.asarray(diff, dtype=float)
        # Transform from diff-space to gaussian-space. Flatten to one column,
        # transform, and restore the input shape. This replaces the old
        # `len(diff) == 1` branch, which failed for scalars and for a grid
        # with a single row.
        QTGdiff = np.reshape(transformer.transform(np.reshape(values, (-1, 1))), values.shape)
        # CDF of the shifted and shortened normal distribution.
        cdfCenter = std * shiftFactor
        cdfStretch = std / shortenFactor
        CDF = stats.norm.cdf(QTGdiff, loc=cdfCenter, scale=cdfStretch)
        # Complement of the CDF, raised by minWeight and renormalized. Since
        # CDF lies in [0, 1], S lies in [minWeight/(1+minWeight), 1], and S
        # decreases as the transformed difference grows.
        Ss = ((1 - CDF) + minWeight) / (minWeight + 1)

        return Ss

    # NOTE(review): `fieldNum` and `diffAveGrd` are left over from the last
    # pass of the DQT loop above. With more than one entry in fieldNums, only
    # the last field's weights are plotted; move this into that loop if more
    # fields are added.
    weightPara = basins.Fields[fieldNum]['weightMethodPara']
    Ss = complementCDF(diffAveGrd,
                       transformer=weightPara['qt'],
                       std=weightPara['qtDissSTD'],
                       shortenFactor=edgeWeightMethod['shortenFactor'],
                       shiftFactor=edgeWeightMethod['shiftFactor'],
                       minWeight=edgeWeightMethod['minWeight'])

    # Colour range fixed to [0, 1]. NOTE(review): as with the DQT plot, the
    # file name uses parameterUnit rather than parameterName.
    fieldUnit = basins.Fields[fieldNum]['parameterUnit']
    EC.utils.plotGlobal(basins.lat, basins.lon, Ss,
                        outputDir=os.getcwd() + "/" + fldName,
                        fidName="plotGlobal_DQT_{}_weights.png".format(fieldUnit),
                        cmapOpts={"cmap": cmc.batlow,
                                  "cbar-title": "cbar-title",
                                  "cbar-range": [0, 1]},
                        pltOpts={"valueType": "DQT",
                                 "valueUnits": "Weights",
                                 "plotTitle": "",
                                 "plotZeroContour": False,
                                 "transparent": True},
                        savePNG=True,
                        saveSVG=False)
    
    ###########################################
    ### Report community evaluation metrics ###
    ###########################################

    # Wrap both partitions as cdlib NodeClustering objects over basins.G.
    # Per the original note, small-basin mergers are not included in
    # LGNClusters; only the large-basin mergers are, such that the small
    # mergers result in the chosen number of basins.
    # NOTE(review): "consensus_ledien_fixed" looks like a misspelling of
    # "leiden"; it is kept verbatim because it is stored as metadata.
    def _asNodeClustering(communities):
        """Return ``communities`` as a NodeClustering with this run's parameters."""
        return NodeClustering(communities=communities,
                              graph=basins.G,
                              method_name="consensus_ledien_fixed",
                              method_parameters={"resolution_parameter": resolution,
                                                 "runs": ensembleSize,
                                                 "distance_threshold": 0.3})

    LeidenClusters = _asNodeClustering(basins.LDcommunities)
    LGNClusters = _asNodeClustering(basins.communitiesFinal)

    # cdlib evaluation functions, in the order they are evaluated.
    metricNames = ("newman_girvan_modularity", "internal_edge_density",
                   "erdos_renyi_modularity", "modularity_density",
                   "avg_embeddedness", "conductance", "surprise")
    # (metric, readme line template, whether the line also reports std), in
    # the order the lines are written to the readme.
    reportLayout = (
        ("newman_girvan_modularity", "newman_girvan_modularity:\t {}\n", False),
        ("erdos_renyi_modularity", "erdos_renyi_modularity:\t\t {}\n", False),
        ("modularity_density", "modularity_density:\t\t {}\n", False),
        ("internal_edge_density", "internal_edge_density:\t\t {} +- {} (std)\n", True),
        ("avg_embeddedness", "avg_embeddedness:\t\t {} +- {} (std)\n", True),
        ("conductance", "conductance:\t\t\t {} +- {} (std)\n", True),
        ("surprise", "surprise:\t\t\t {}\n", False),
    )

    for method, cluster in (("LeidenClusters", LeidenClusters), ("LGNClusters", LGNClusters)):
        results = {name: getattr(evaluation, name)(basins.G, cluster) for name in metricNames}

        # Append this partition's metrics to the readme text.
        readmetxt += "\n\nCommunity evaluation metrics ({}):\n".format(method)
        for name, template, withStd in reportLayout:
            fitness = results[name]
            if withStd:
                readmetxt += template.format(fitness.score, fitness.std)
            else:
                readmetxt += template.format(fitness.score)

    # Write the accumulated readme next to the figures.
    with open(fldName + "/readme.txt", "w") as readmeFile:
        readmeFile.write(readmetxt)
