# Setup

In [None]:
import tensorflow as tf
from tensorflow import keras

import rmsp
import sys
import os
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
from matplotlib import pyplot as plt
import numpy as np
import math
import shutil
import copy
import pygeostat as gs
from tqdm.notebook import trange
tf.__version__

%load_ext autoreload
%autoreload 2

In [None]:
# Activate RMSP once, before any of the rmsp calls in later cells.
rmsp.activate()

In [None]:
sys.path.insert(0, os.path.abspath('../Tools'))
from file_export import PickleExporter, FigureExporter
from gaussian_mv import GmmUtility
from lambda_distribution import GeneralizedLambdaDist
from utility import get_lambdas_keras
from utility import fix_ipython_autocomplete
fix_ipython_autocomplete()

In [None]:
def create_axes(n_col, n_plots, figsize, **kwargs):
    """Create a grid of axes with room for ``n_plots`` subplots.

    The grid has ``n_col`` columns and as many rows as needed. Surplus axes in
    the last row are hidden and left out of the returned array.

    Parameters
    ----------
    n_col : int
        Number of columns in the grid.
    n_plots : int
        Number of axes actually needed.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    **kwargs
        Extra keyword arguments forwarded to ``plt.subplots``.

    Returns
    -------
    fig : matplotlib Figure
    axes : numpy.ndarray
        Flat array holding exactly ``n_plots`` visible axes.
    """
    # Ceiling division: add a row when the last one is only partly filled.
    n_rows = -(-n_plots // n_col)

    fig, axes = plt.subplots(n_rows, n_col, figsize=figsize, **kwargs)

    # For a 1x1 grid plt.subplots returns a bare Axes, not an array, so
    # `.flatten()` would fail. np.ravel normalizes every case to a flat array.
    axes = np.ravel(axes)

    # Axes past the requested count only pad the last row; hide them.
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    return fig, axes[:n_plots]

# Introduction

This notebook implements a data imputation workflow for the Northwest Territories (NWT) data set. RMSP is used instead of pygeostat where applicable, to simplify the code and improve performance.

The data imputation is based on quantifying the conditional distribution of a missing variable. This conditional distribution is informed by both the multivariate and the spatial relationships. Since it is not feasible to model the full multivariate spatial distribution, its two main components are quantified separately and then combined by Bayesian updating: the collocated multivariate relationship, and the univariate spatial continuity of the missing variable.

In this notebook, multilayer perceptron (MLP) neural networks are used to quantify the conditional moments from the homotopic multivariate relationships. The conditional moments are used to fit a parametric lambda distribution. The lambda distribution is combined with the spatial conditional distribution (i.e., from simple kriging (SK) / the normal equations), and the result is sampled to generate one realization of each missing value.

Note: TensorFlow 2.0 or newer is required.

# Settings

In [None]:
# Output/input locations. Only `outdir` is created here; the other
# directories are expected to exist already.
outdir = 'Output/Imputation_MLP/'
data_dir = 'Output/DataInventory/'
mlp_dir = 'Output/LambdaDistributionMl/'
gs.mkdir(outdir)

inputdir = 'data/NWTData'

# pygeostat trimming limit and null value (presumably values below tmin are
# treated as missing; confirm against the pygeostat docs).
gs.Parameters['data.tmin'] = -998
gs.Parameters['data.null'] = -999
# Per-orientation aspect ratio used for section plots (vertical exaggeration).
aspects = {'xy': 2, 'xz': 20, 'yz': 6}

cmap = 'RdYlGn_r'

In [None]:
# Helpers that write pickles/figures into `outdir`; `save_figure_paper`
# writes figures into the paper's LaTeX figure folder instead.
pickle_data = PickleExporter(outdir)
save_figure = FigureExporter(outdir)
save_figure_paper = FigureExporter(
    "../../JournalPapers/ImputationUsingLambdaDistAndMl/Latex/elsarticle-template/Figures_Ni/"
)

## <span style='color:#5177F9;'> Helper functions </span>

In [None]:
def set_axis_label_font(ax, fontsize=12, title=True, prefix=''):
    """Reset the font size of an axis' x/y labels and, optionally, its title.

    Non-empty x and y labels get ``prefix`` prepended. Empty labels stay
    empty, but their font size is still reset. The title text is never
    prefixed.
    """
    accessors = (
        (ax.get_xlabel, ax.set_xlabel),
        (ax.get_ylabel, ax.set_ylabel),
    )
    for read_label, write_label in accessors:
        current = read_label()
        write_label(prefix + current if current else current, fontsize=fontsize)

    if title:
        ax.set_title(ax.get_title(), fontsize=fontsize)

In [None]:
def set_cabr_label_font(fig, fontsize=11, prefix=''):
    """Prefix and resize the y-label of every colorbar child of ``fig``.

    A child counts as a colorbar when its class name contains ``'cbar'``
    (case-insensitive). The prefix is applied even when the label is empty.
    """
    colorbars = (
        child for child in fig.get_children()
        if 'cbar' in type(child).__name__.lower()
    )
    for cbar in colorbars:
        cbar.set_ylabel(prefix + cbar.get_ylabel(), fontsize=fontsize)

In [None]:
def scaplot_compare(data, variables, prefix='', **kwargs):
    """Draw one scatter plot per unordered pair of ``variables`` in a single row.

    Parameters
    ----------
    data : rmsp data object
        Object exposing ``scatplot(xvar, yvar, ax=..., **kwargs)``.
    variables : sequence of str
        Variables to compare; needs at least two.
    prefix : str
        Prepended to non-empty axis labels (e.g. ``'NS:'``).
    **kwargs
        Forwarded to ``data.scatplot``. ``cmap`` defaults to the module-level
        ``cmap`` and ``stats`` to ``['count', 'corr']``; both can now be
        overridden, whereas before they raised "multiple values for keyword".

    Returns
    -------
    (fig, axes) from ``create_axes``.

    Raises
    ------
    ValueError
        If fewer than two variables are given (previously a ZeroDivisionError).
    """
    from itertools import combinations

    pairs = list(combinations(variables, 2))
    n_plots = len(pairs)
    if n_plots == 0:
        raise ValueError("scaplot_compare needs at least two variables")

    kwargs.setdefault('cmap', cmap)
    kwargs.setdefault('stats', ['count', 'corr'])

    # All pairs go in a single row.
    fig, axes = create_axes(n_plots, n_plots, (12, 3))
    for ax, (xvar, yvar) in zip(axes, pairs):
        data.scatplot(xvar, yvar, ax=ax, **kwargs)
        set_axis_label_font(ax, prefix=prefix)
    return fig, axes

# Start of Case Study

## Load data 

In [None]:
# Pooled data set produced by the data-inventory step (data_dir).
data = rmsp.from_pickle(data_dir+'PooledData.pkl')
data.describe()

In [None]:
# Variable-name lists: response variables are used as features; missing
# variables are the ones to impute.
response_variables = rmsp.from_pickle(data_dir+'response_variables.pkl')
missing_variables = rmsp.from_pickle(data_dir+'missing_variables.pkl')
n_var_miss = len(missing_variables)
variables = response_variables + missing_variables

## Declustering, Despiking and NS transform

In [None]:
# Nearest-neighbour declustering: a single-neighbour search so each estimate
# takes the closest sample, and the accumulated weights become "<var>_wt".
search = rmsp.Search(min_comps=1, max_comps=1, ranges=[50000.0] * 3)
est = rmsp.NNEstimator(search)

wts = []
for var in variables:
    weight_column = f"{var}_wt"
    est.estimate(data, data, var, accumulate_weights=True)
    data[weight_column] = est.get_cumulative_weights()
    wts.append(weight_column)
data.describe()

In [None]:
# Location maps of each variable's declustering weight, one figure per
# section orientation.
for orient in ("xy", "xz", "yz"):
    fig, axes = rmsp.ImageGrid(
        1,
        len(variables),
        figsize=(len(variables) * 5, 12),
        cbar_mode="each",
        axes_pad=(0.6, 0.4),
    )
    for ax, var in zip(axes, variables):
        data.sectionplot(
            orient=orient,
            var=var + "_wt",
            s=10,
            ax=ax,
            grid=True,
            tickangs=(45, 45),
            missing_color="r",
            aspect=aspects[orient],
        )
        set_axis_label_font(ax)
    set_cabr_label_font(fig)

In [None]:
# Multivariate spatial despiking of all variables, presumably to break ties
# before the normal-score transform. Parameter semantics are per rmsp docs.
dsp = rmsp.DespikeMVSpatial(
    num_neighbors=min(len(data), 10),
    despike_level=0.00001,
    spike_epsilon=0.00001,
    wt_to_random = 0.5
)

# Overwrite the original-unit columns with the despiked values.
data[variables] = dsp.fit_transform(
    data, variables
).values

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

# CDFs of the (despiked) original-unit variables before the NS transform.
# The second argument is assumed to be the declustering-weight column.
fig, axes = create_axes(3,len(variables), (14,3))

for i, variable in enumerate(variables):
    data.cdfplot(variable, variable + "_wt", ax=axes[i], log=False, lw=2, grid=True)
    axes[i].set_xlabel(axes[i].get_xlabel() + ' (%)')
    set_axis_label_font(axes[i])

# save_figure_paper('CSHistBeforeNS.pdf')
save_figure_paper('CdfBeforeNS.png')

In [None]:
data_ns = data.copy()

# Fit one weighted normal-score transformer per variable. Keep each one,
# keyed by variable name, so imputed values can be back-transformed later.
ns_transformers = {}
for var in variables:
    ns_transformer = rmsp.NSTransformer()
    weights = data[f"{var}_wt"]
    data_ns[var] = ns_transformer.fit_transform(data[var], weights)
    ns_transformers[var] = ns_transformer

data_ns.head()

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

# Histograms of the normal-score variables.
fig, axes = create_axes(3,len(variables), (14,3))

for i, variable in enumerate(variables):
    data_ns.histplot(variable, ax=axes[i], face_c='lightgreen', edge_c='k', xlabel=f'Ns:{variable}')
    set_axis_label_font(axes[i])
save_figure_paper('CSHistAfterNS.pdf')

## Location map for NScores

In [None]:
# Normal-score location maps of the variables with missing samples, one
# figure per section orientation (missing samples drawn in red).
stat = data_ns.describe()  # non-null counts for the "n = ..." annotation
for orient in ("xy", "xz", "yz"):
    fig, axes = rmsp.ImageGrid(
        1,
        n_var_miss,
        figsize=(n_var_miss * 5, 12),
        cbar_mode="each",
        axes_pad=(0.6, 0.4),
    )

    for ax, var in zip(axes, missing_variables):
        data_ns.sectionplot(
            var,
            orient=orient,
            s=15,
            ax=ax,
            grid=True,
            clim=[-3, 3],
            tickangs=(45, 45),
            missing_color="r",
            aspect=aspects[orient],
        )
        if orient == "xy":
            ax.text(
                0.1,
                0.9,
                "n = %d" % (stat[var]["count"]),
                transform=ax.transAxes,
                fontsize=13,
            )
        else:
            # Vertical sections are drawn with the exaggeration from `aspects`.
            ax.text(
                0.1,
                0.1,
                f"{aspects[orient]}x vertical exaggeration",
                transform=ax.transAxes,
                fontsize=13,
            )
        set_axis_label_font(ax)
    set_cabr_label_font(fig, prefix='NS:')
    plt.tight_layout(w_pad=2)
    save_figure_paper(f'locmap_ns_{orient}.png')

## Bivariate relationship plots (After NS)

In [None]:
# Same colormap as in the Settings cell; repeated so this section can be re-run on its own.
cmap = 'RdYlGn_r'

In [None]:
# Keep only rows where every missing variable is observed, so the scatter
# matrix shows the relationships among fully sampled (homotopic) rows.
mask = data_ns[missing_variables].notna().all(axis=1)
fig = data_ns[mask].scatplots(
    variables=variables,
    figsize=(8, 8),
    s=5,
    num_sample=7000,
    axes_pad=(0.15, 0.15),
    stats=['count','rankcorr', 'corr'],
    cbar=True,
    lims = {var: (-3.5,3.5) for var in variables},
    cmap=cmap
)

for ax in fig.axes:
    set_axis_label_font(ax, prefix='NS:')
    ax.grid()
    
save_figure_paper(
    r"CSMV1.png"
)

## Data spacing analysis

Data spacing analysis for variables with missing samples

In [None]:
# Horizontal data spacing (nearest sample) for the first missing variable.
ds = data_ns.horizontal_spacing(n_nearest=1, var=missing_variables[0], nexcept=10000)

In [None]:
# The spacing result is taken to be the last column added by horizontal_spacing.
ds_column = ds.columns[-1]
_ = ds.sectionplot(ds_column, s=5, tickangs=(45,45), cmap='Spectral_r')
_, ax = ds.cdfplot(ds_column, annotate_stats=["p0.05","p25", "p50", "p75", "p95"], log = True, figsize=(8,4))

## Experimental variogram calculation


In [None]:
# Get the average nearest distance as a measure of average data spacing
lag_length = np.mean(ds[ds.columns[-1]])

# Quarter of the x/y extent, to limit the lag range.
x_range =  (data_ns[data_ns.x].max() - data_ns[data_ns.x].min())/4 # Limited the range
y_range =  (data_ns[data_ns.y].max() - data_ns[data_ns.y].min())/4

# NOTE(review): n_lag_x / n_lag_y are not used in the visible cells; the
# searches below use the fixed nlags_horz / nlags_vert instead.
n_lag_x = int(x_range/lag_length)
n_lag_y = int(y_range/lag_length)


nlags_horz = 12
nlags_vert = 8

### Search setup

In [None]:
# Two experimental-variogram searches, in order: index 0 = horizontal
# (omnidirectional in plan), index 1 = vertical.
exp_vario_search = []

# Horizontal
lags, tols = rmsp.Lags.merge_lags_tols(
    [
        # rmsp.Lags(lag_length * 0.5, lag_length * 0.5 * 0.45, 1),
        rmsp.Lags(lag_length, lag_length * 0.45, nlags_horz),
    ]
)
search_vario_horz = rmsp.ExpVarioSearch(0, 0, lags, tols, azmtol=90, incltol=90)
exp_vario_search.append(search_vario_horz)

# Vertical
# NOTE(review): the vertical search is also bound to `search_vario_horz`,
# so after this cell that name refers to the vertical search.
lag_vert = rmsp.Lags(1, 1 * 0.75, nlags_vert)
lags, tols = lag_vert.get_lags_tols()
search_vario_horz = rmsp.ExpVarioSearch(0, 90, lags, tols, azmtol=45, incltol=15)
exp_vario_search.append(search_vario_horz)

In [None]:
# Traditional experimental variograms of each missing variable (NS units),
# computed with both searches defined above.
vario_exp = {
    var: rmsp.ExpVario("traditional").calculate(
        data_ns, var, searches=exp_vario_search
    )
    for var in missing_variables
}

In [None]:
# Experimental variograms: one row per search direction, one column per
# missing variable. squeeze=False keeps `axes` 2-D even for one variable.
n_miss_plot = len(missing_variables)
fig, axes = plt.subplots(2, n_miss_plot, figsize=(5 * n_miss_plot, 6), squeeze=False)

# Names of the search directions, in the order they were appended to
# exp_vario_search (0 = horizontal, 1 = vertical).
variogram_names = ['Horizontal', 'Vertical']

for i, var in enumerate(missing_variables):
    data_var = vario_exp[var].data
    for index in data_var['Search Index'].unique():
        direction = int(index)
        ax = axes[direction, i]
        vario_exp[var].plot(
            search_index=index,
            ax=ax,
            ylim=(0, 1.2),
            # Title by search direction. Indexing by the variable position `i`
            # labelled every panel of a column with the same direction.
            title=f"{var} ({variogram_names[direction]})",
            pairs_bar=True,
            c=f'C{i}',
            ms=5
        )
        set_axis_label_font(ax)
plt.tight_layout()

## Variogram Modeling

In [None]:
# Fit a nested (num_struct) exponential variogram model to each missing
# variable's experimental variograms.
vario_models = {}

num_struct = 2

# Bounds on the third (vertical) range per structure, keyed by variable name.
# NOTE(review): hard-coded for 'Fe' and 'SiO2'; any other missing variable
# raises KeyError here.
range3_dict = {'Fe': [[8, 9], [10, 11] ], 'SiO2': [[10, 12], [12, 14]]}

for var in missing_variables:

    
    vario_model = rmsp.VarioModel.fit_experimental(
        vario_exp[var],
        num_struct=num_struct,
        nugget=[0.05, 0.05],
        shapes="exponential",
        var_contribs=[[0.0, 1.0]] * num_struct,
        angle1=[0.0] * num_struct,  # a list with length 2 can be used to pass range
        angle2=[0.0] * num_struct,
        angle3=[0.0] * num_struct,
        angles_fixed_across=True,
        range1=[[40, 100], [100, 300.0]],
        # range2=[[10, 20000.0], [20000, 80000.0]],
        range3=range3_dict[var],
        ranges12_bounds=1,
        ranges13_bounds=None,
        invdist_wt=False,
        numpairs_wt=True,
        consider_early_exit=False,
        try_unique_dir_lock=True,
        lock_min_range_to_close_lag=False,
        max_no_improvement=500,
        minpairs=50,
    )

    vario_models.update({var: vario_model})

In [None]:
# Side-by-side table of the fitted variogram parameters, with columns
# grouped under each variable name (two-level column index).
df_list = []
for var, varmodel in vario_models.items():
    table = varmodel.to_table()
    table.columns = pd.MultiIndex.from_tuples([(var, col) for col in table.columns])
    df_list.append(table)

df = pd.concat(df_list, axis=1)
df

In [None]:
# Experimental variograms overlaid with the fitted models, one row per search
# direction. Assumes exactly two missing variables (2x2 grid).
fig, axes = plt.subplots(2, 2, figsize=(10, 6))
rmsp.GlobalParams["plotting.varioplot.gammasize"] = 1
for i, var in enumerate(missing_variables):
    data_var = vario_exp[var].data
    for index in data_var["Search Index"].unique():
        mask = data_var["Search Index"] == index
        # Mean search direction of this experimental variogram, used both for
        # the legend and to draw the model along the same direction.
        azim = data_var[mask].Azimuth.mean()
        incl = data_var[mask].Inclination.mean()
        vario_exp[var].plot(
            search_index=index,
            ax=axes[int(index), i],
            ylim=(0, 1.2),
            title=var,
            label = f"(azm: {azim:.2f}, incl: {incl:.2f})",
            pairs_bar=True,
            c=f"C{i}",
            tickangs=(45, 0),
            ms=5,
        )
        vario_models[var].plot_draw(
            ax=axes[int(index), i], azm=azim, incl=incl, lw=2, c=f"C{i}", ls="--"
        )
        axes[int(index), i].legend(fontsize=12)
        set_axis_label_font(axes[int(index), i])
plt.tight_layout()
save_figure_paper("CSVarg.png")

## Fitting GMM 

The fitted GMM is used only as a cross reference.

In [None]:
# 9-kernel Gaussian mixture fitted to all NS variables (used as a cross reference).
gmm_model = rmsp.GMM().fit(data_ns[variables], num_kernels=9)

In [None]:
import unittest
from test_suite import GmmUtilityTest

# Run the project's GmmUtility unit tests before relying on GmmUtility.
suite = unittest.TestLoader().loadTestsFromTestCase(GmmUtilityTest)

unittest.TextTestRunner(verbosity=2).run(suite)

In [None]:
# Unpack the fitted GMM kernels into parallel lists of means, covariances and
# weights, and wrap them in GmmUtility for conditional-distribution queries.
kernels = list(gmm_model.kernels)
mean_list = [kernel.mean for kernel in kernels]
cov_list = [kernel.cov for kernel in kernels]
contrib_list = [kernel.wt for kernel in kernels]

gmm_util = GmmUtility(
    data=data_ns,
    variable_names=variables,
    mean_vector_list=mean_list,
    covariance_matrix_list=cov_list,
    contribution_list=contrib_list,
)

In [None]:
# Visual summary of the fitted GMM against the NS data.
gmm_util.summary_plot(cmap=cmap)

## Fitting MLP for the missing variables


Set the parameters used to build and train an MLP that estimates the first four conditional moments.

In [None]:
import papermill as pm

### Helper function

In [None]:
def get_conditional_moments(
    input_vals,
    model_first_moment,
    model_central_moments,
    reference_data,
    label_variable,
    skew_bounds=(-1.1, 1.5),
    kurt_bounds=(1.8, 5.9),
):
    """Predict the first four conditional moments and sanitize them.

    Parameters
    ----------
    input_vals : array-like
        Feature row(s) passed to both models' ``predict``. Only the first
        prediction (``[0][0]``) of each output is used.
    model_first_moment : object with ``predict``
        Predicts the conditional mean.
    model_central_moments : object with ``predict``
        Returns three outputs: variance, third and fourth moments. These are
        assumed to be raw central moments, since they are standardized below
        by ``variance**1.5`` and ``variance**2``.
    reference_data : pandas.DataFrame
        Data used to bound the predicted mean.
    label_variable : str
        Column of ``reference_data`` holding the predicted variable.
    skew_bounds, kurt_bounds : (float, float)
        Limits for the standardized skewness and kurtosis. The defaults are
        the previously hard-coded values.

    Returns
    -------
    (mean, variance, skewness, kurtosis)
        The mean is clipped to the observed range padded by 1% of its span.
        If the predicted variance is non-positive or non-finite, the input is
        printed and a standard-normal shape is returned:
        ``(mean, 1.0, 0.0, 3.0)``.
    """
    mean = float(model_first_moment.predict(input_vals)[0][0])
    variance, third, fourth = model_central_moments.predict(input_vals)
    variance = float(variance[0][0])
    third = float(third[0][0])
    fourth = float(fourth[0][0])

    # Keep the mean inside the observed range, padded by 1% of its span.
    lower = reference_data[label_variable].min()
    upper = reference_data[label_variable].max()
    pad = 0.01 * (upper - lower)
    mean = min(max(mean, lower - pad), upper + pad)

    # The old `except ZeroDivisionError` never fired for numpy scalars: they
    # return inf/nan instead of raising. A negative variance also gave a nan
    # or complex power. Check the variance explicitly instead.
    if not np.isfinite(variance) or variance <= 0:
        print(input_vals)
        return mean, 1.0, 0.0, 3.0

    skewness = third / variance ** 1.5
    kurtosis = fourth / variance ** 2
    # A nan prediction would pass straight through min/max; use Gaussian values.
    if not np.isfinite(skewness):
        skewness = 0.0
    if not np.isfinite(kurtosis):
        kurtosis = 3.0

    skewness = min(max(skewness, skew_bounds[0]), skew_bounds[1])
    kurtosis = min(max(kurtosis, kurt_bounds[0]), kurt_bounds[1])

    return mean, variance, skewness, kurtosis

In [None]:
# Add a unique per-row hash of (drill-hole id, x, y). The imputation loop
# uses it to isolate the row being estimated.
data_ns['Hash'] = pd.util.hash_pandas_object(data_ns[[data_ns.dhid, data_ns.x, data_ns.y]], hash_key='0314')

In [None]:
# mlp execution function
def execute_mlp_for_conditional_mv(
    missing_variable, response_variables, input_data, mlp2_nodes_1=64, mlp2_nodes_2=16
):
    """Train the conditional-moment MLPs for one missing variable via papermill.

    Steps:
    1. Pickle ``input_data`` into ``outdir``.
    2. Run the template notebook with parameters pointing at that pickle.
    3. Render the executed notebook to HTML, then delete the executed
       ``.ipynb``.
    4. Load the two Keras models the template saved.

    Returns
    -------
    (model_first_moment, model_central_moments)
    """
    model_dir = f"Output/MlForConditionalDistCaseStudy/{missing_variable}/"
    pickle_name = f"data_ns_{missing_variable}.pkl"
    pickle_data(input_data, pickle_name, clean_name=False)

    template = "02-MlForConditionalDistributionTemplate.ipynb"
    executed = template.replace("Template", missing_variable)

    pm.execute_notebook(
        template,
        executed,
        parameters={
            "outdir": model_dir,
            "data_dir": outdir,
            "ns_data_pkl": pickle_name,
            "label_variable": missing_variable,
            "feature_variables": response_variables,
            "out_file": "data_out.pkl",
            "mlp2_nodes_1": mlp2_nodes_1,
            "mlp2_nodes_2": mlp2_nodes_2,
        },
    )

    # Keep an HTML record of the run; the executed notebook itself is removed.
    from utility import make_html
    make_html(executed, executed.replace('.ipynb', '.html'))
    rmsp.remove_file(executed)

    first_moment_model = keras.models.load_model(model_dir + "modelfirstmoment")
    central_moments_model = keras.models.load_model(model_dir + "modelcentralmoments")
    return first_moment_model, central_moments_model

In [None]:
# missing variable -> [first-moment model, central-moments model]
models_mpl = {}

### First Missing variable

In [None]:
# The first missing variable is conditioned on the response variables only.
model_first_moment, model_central_moments = execute_mlp_for_conditional_mv(
    missing_variables[0], response_variables, data_ns, mlp2_nodes_1=8, mlp2_nodes_2=8
)

models_mpl[missing_variables[0]] = [model_first_moment, model_central_moments]

### Second Variable

In [None]:
# The second missing variable also uses the first missing variable as a
# feature (it will already be imputed when the second is imputed).
model_first_moment, model_central_moments = execute_mlp_for_conditional_mv(
    missing_variables[1],
    response_variables + [missing_variables[0]],
    data_ns,
    mlp2_nodes_1=16,
    mlp2_nodes_2=4,
)

models_mpl[missing_variables[1]] = [model_first_moment, model_central_moments]

## MLP and Lambda distribution check 

This section uses an already-fitted MLP model to get the conditional moments of a missing variable. A lambda distribution is then used to parametrize the conditional distribution.

In [None]:
# Pre-trained Keras model, used with get_lambdas_keras to map four moments
# to lambda-distribution parameters.
class_name = 'Ia'
mlp_model = keras.models.load_model(mlp_dir+f'Lambda_{class_name}_Keras')

In [None]:
out_dir_lambda = "LambdaFitResults/"
# Start from an empty directory. A missing directory is fine; the bare
# `except:` used before would also have swallowed KeyboardInterrupt.
shutil.rmtree(out_dir_lambda, ignore_errors=True)
rmsp.make_dir(out_dir_lambda)
# NOTE(review): this exporter targets <outdir>/LambdaFitResults/, not the
# directory created above, and that directory is not created here.
save_figure_check = FigureExporter(os.path.join(outdir, "LambdaFitResults/"))
def get_simulated_value(a1, a2, a3, a4, out_dir, variable_name):
    """Fit a generalized lambda distribution to four moments, plot it, and sample it.

    Parameters
    ----------
    a1, a2, a3, a4 : float
        Mean, variance, skewness and kurtosis. Kurtosis is capped at 4.5
        before fitting.
    out_dir : str
        NOTE(review): unused. The figure is written through the module-level
        ``save_figure`` exporter (into ``outdir``), not into ``out_dir``.
        Confirm which location is intended.
    variable_name : str
        Used in the saved figure's file name.

    Returns
    -------
    One value drawn from the fitted distribution (``gld.simulate()``).
    """
    a4 = min(a4, 4.5)

    # Fit the lambda distribution
    lambdas = get_lambdas_keras(a1, a2, a3, a4, mlp_model)
    gld = GeneralizedLambdaDist(*lambdas)
    fig, axes = plt.subplots(1, 2, figsize=(16, 4))

    # Left: lambda distribution pdf, annotated with the input moments.
    ax = gld.pdf_plot(ax=axes[0], return_ax=True)
    axes[0].set_xlabel(r"$Y_m$", fontsize=12)
    axes[0].set_ylabel(r"pdf", fontsize=12)
    # The sigma line is a raw f-string: "\s" in a plain literal is an invalid
    # escape sequence (DeprecationWarning / SyntaxWarning in 3.12+).
    text = "\n".join([
        f"mean: {a1:.2f}",
        rf"$\sigma$: {np.sqrt(a2):.2f}",
        f"$skew$: {a3:.2f}",
        f"$kurtosis$: {a4:.2f}",
    ])
    ax.text(0.6, 0.7, text, transform=ax.transAxes, ha="left")

    # Right: the distribution plot of the same fit.
    gld.dist_plot(cdf=False, ax=axes[1], color="green")
    axes[1].set_xlabel(r"$Y_m$", fontsize=12)

    save_figure("Conditional_comparison_{}_{:.2f}.png".format(variable_name, a1))

    return gld.simulate()

In [None]:
# Spot check on the first 10 rows: for each row where the first missing
# variable is absent, predict its moments, print them, and plot/sample the
# fitted lambda distribution.
for _, row in data_ns.head(10).iterrows():

    if pd.isna(row[[missing_variables[0]]]).any():

        a1,a2, a3, a4 = get_conditional_moments(
            row[response_variables].values.reshape(-1, len(response_variables)),
            *models_mpl[missing_variables[0]],
            data_ns,
            missing_variables[0]
        )
        print(a1, a2, a3, a4)
        get_simulated_value(
            a1, a2, a3, a4, out_dir_lambda, variable_name=missing_variables[0],
        )

## Data Imputation (One Example Realization)

In [None]:
from probability_updating import (
    update_bayesian_pr, update_bayesian_ind
)

In [None]:
def simulate(cdf, x):
    """Draw one value from a tabulated distribution by inverse-transform sampling.

    A uniform random number is mapped back through the CDF by linear
    interpolation. ``cdf`` holds cumulative probabilities at the support
    points ``x``; ``np.interp`` requires it to be increasing.
    """
    u = np.random.rand()
    return np.interp(u, xp=cdf, fp=x)

### Check Updating function (using CDF)

In this section, permanence of ratios (PR) is used to combine the fitted lambda distribution, the conditional Gaussian (spatial) distribution, and the prior (global) standard normal distribution.

Below is an example that fits a lambda distribution to four example moments and then combines it with an example conditional Gaussian distribution using PR to obtain the final updated distribution.

In [None]:
# Example: a lambda distribution from fixed example moments, combined via
# update_bayesian_pr with an example spatial distribution. The positional
# 1.15 / 1.8 are presumably spatial mean / variance (the imputation loop
# passes them as keywords) — confirm.
# NOTE(review): the title says "Conditional Independence" while the text
# describes permanence of ratios; confirm which is intended.
lambdas = get_lambdas_keras(1.384, 1.071*1.071, -0.466, 2.978, mlp_model)
gld = GeneralizedLambdaDist(*lambdas)

fig, ax = plt.subplots(1,1, figsize=(7,5))

ax.set_title('Conditional Independence', fontsize= 14)
x_vals, F_gld, F_spatial, F_global, cdf_updated = update_bayesian_pr(gld, 1.15, 1.8, n_sample = 100)
ax.plot(x_vals, F_gld, 'g', label='MV')
ax.plot(x_vals, F_spatial, 'r', label = 'Spatial')
ax.plot(x_vals, F_global, 'gray', label = 'global')
ax.plot(x_vals, cdf_updated, 'k', label = 'updated')
ax.plot([-8,8], [1,1], c='gray', lw=2)
ax.set_xlim([-6,6])
ax.set_ylim([0,1.1])
ax.legend(fontsize= 14)

### Implementing the imputation workflow

Study one realization

In [None]:
# Working copy of the NS data. 'Missing' is 1 for rows where at least one
# missing variable is absent, else 0; it records where imputation happened.
data_imputation = data_ns.copy()
mask = data_imputation[missing_variables].isna().any(axis=1)
data_imputation['Missing'] = mask.astype('int64')

data_imputation.head()

In [None]:
# Per missing variable: coordinates of every visited row, plus the kriging
# estimate/variance (NaN where the value was not missing).
spatial_dict = {
    var: {data_ns.x: [], data_ns.y: [], data_ns.z: [], "Estimate": [], "Variance": []} for var in missing_variables
}

In [None]:
out_dir_imputation_plots = os.path.join(outdir, "ImputedFigures")
# Start from an empty directory. A missing directory is fine; the bare
# `except:` used before would also have swallowed KeyboardInterrupt.
shutil.rmtree(out_dir_imputation_plots, ignore_errors=True)
rmsp.make_dir(out_dir_imputation_plots)

# Sequential imputation of one realization: for each missing variable, visit
# rows in random order and, where the value is missing, combine
#   (a) a simple-kriging estimate from the other rows of data_imputation
#       (including values imputed earlier in this pass), and
#   (b) a lambda distribution fitted to MLP-predicted conditional moments,
# via update_bayesian_pr. Then draw the imputed value from the updated CDF.

# previously imputed variable will be added to the response variables
response_var_dict = {
    missing_variables[0]: response_variables,
    missing_variables[1]: response_variables + [missing_variables[0]],
}

item = 0  # count of imputed values; every 10th gets a diagnostic plot
for variable in missing_variables:

    vario_model = vario_models[variable]
#     search = rmsp.Search.from_vario_ranges(vario_model)
    search = rmsp.Search([0.0] * 3, [100.0] * 3, min_comps=1, max_comps=100)
    # SK mean is taken once per variable, before any of it is imputed.
    krig = rmsp.KrigeEstimator().set_params(
        search, vario_model, "sk", sk_mean=np.nanmean(data_imputation[variable])
    )

    # Shuffle data to get a random visiting path. reset_index(drop=False)
    # keeps the original label in an 'index' column (this assumes the
    # original index is unnamed).
    data_imputation_shuffled = data_imputation.sample(frac=1).reset_index(drop=False)

    # `row` comes from the shuffled snapshot, so for this variable it does not
    # see values imputed later in the same pass; kriging reads data_imputation.
    for idx, row in data_imputation_shuffled.iterrows():
        
        index = row['index']

        spatial_dict[variable][data_ns.x].append(row[data_ns.x])
        spatial_dict[variable][data_ns.y].append(row[data_ns.y])
        spatial_dict[variable][data_ns.z].append(row[data_ns.z])

        if pd.isna(row[variable]):

            # Split off the row being imputed (matched by Hash) as the
            # estimation target; krige it from all other rows.
            mask = data_imputation["Hash"] == row["Hash"]
            result = krig.estimate(
                data_imputation.loc[mask],
                data_imputation.loc[~mask],
                variable,
                output=["estimate", "estimate_var"],
            )

            spatial_mean, spatial_variance = (
                result["estimate"].values[0],
                result["estimate_var"].values[0],
            )

            spatial_dict[variable]["Estimate"].append(spatial_mean)
            spatial_dict[variable]["Variance"].append(spatial_variance)

            # Get the conditioning data for GMM (NaN -> None = unconditioned)
            conditioning_data = [
                None if math.isnan(val) else val for val in row[gmm_util.variable_names]
            ]
            # Get the uni/bivariate conditional distribution based on the GMM components.
            # NOTE: the GMM result is only used for the diagnostic plot below.
            (
                cond_means,
                cond_covariances,
                cond_contributions,
            ) = gmm_util.get_conditional_pdf(conditioning_data=conditioning_data)

            # Grab the marginal univariate GMM components (first unconditioned variable)
            cond_means = np.array(cond_means)
            cond_means = cond_means[:, 0].reshape(gmm_util.n_components, 1)
            cond_covariances = np.array(cond_covariances)
            cond_covariances = cond_covariances[:, 0, 0].reshape(
                gmm_util.n_components, 1, 1
            )

            # Use the MLPs to predict the conditional moments from this row's features
            a1, a2, a3, a4 = get_conditional_moments(
                row[response_var_dict[variable]].values.reshape(
                    -1, len(response_var_dict[variable])
                ),
                *models_mpl[variable],
                data_ns,
                variable,
            )

            # Diagnostic only: report rows with a zero predicted variance
            # (the GMM-moment fallback below is disabled).
            if a2 == 0:
                print(
                    a1,
                    a2,
                    a3,
                    a4,
                    index,
                    row[response_var_dict[variable]].values.reshape(
                        -1, len(response_var_dict[variable])
                    ),
                )
                # a1, a2, a3, a4 = gmm_util.get_moments(
                #     cond_means, cond_covariances, cond_contributions
                # )

            # Fit the lambda distribution given the first four moments
            lambdas = get_lambdas_keras(a1, a2, a3, a4, mlp_model)

            gld = GeneralizedLambdaDist(*lambdas)

            x_vals, F_gld, F_spatial, F_global, F_updated = update_bayesian_pr(
                gld, spatial_mean=spatial_mean, spatial_variance=spatial_variance
            )

            # Write the draw back so later kriging in this pass can use it.
            data_imputation.loc[index, variable] = simulate(F_updated, x_vals)

            item += 1
            # Plots
            if item % 10 == 0:
                fig, axes = plt.subplots(1, 2, figsize=(16, 4))

                ax = axes[0]
                # lambda distribution pdf
                _ = gld.pdf_plot(ax=ax, return_ax=False)
                # Marginal/conditional GMM pdf
                GmmUtility.univariate_pdf_from_mixture_plot(
                    cond_means,
                    cond_covariances,
                    cond_contributions,
                    variable_name=variable,
                    ax=ax,
                )
                ax.set_title(
                    "Conditional distribution (GMM and lamda distribution)", fontsize=13
                )

                ax = axes[1]
                ax.plot(x_vals, F_gld, "g", label=r"Lambda Distribution")
                ax.plot(
                    x_vals,
                    F_spatial,
                    "r",
                    label=r"Spatial ($\mu:${:.2f}, $\sigma^2:${:.2f})".format(
                        spatial_mean, spatial_variance
                    ),
                )
                ax.plot(x_vals, F_global, "gray", label="Global")
                ax.plot(x_vals, F_updated, "blue", label="Updated")
                ax.plot([-8, 8], [1, 1], c="gray", lw=2)
                ax.set_xlim([-6, 6])
                ax.set_xlabel(variable, fontsize=12)
                ax.set_ylim([0, 1.1])
                ax.set_ylabel("Cumulative Distribution Function (CDF)", fontsize=12)
                ax.legend(fontsize=12, loc="center left")
                ax.set_title("Bayesian Updating", fontsize=13)

                gs.export_image(
                    os.path.join(
                        out_dir_imputation_plots,
                        "Conditional_comparison_{}_{:g}.png".format(variable, item),
                    )
                )
                plt.show()
        else:
            # Observed value: keep spatial_dict columns aligned with NaN placeholders.
            spatial_dict[variable]["Estimate"].append(np.nan)
            spatial_dict[variable]["Variance"].append(np.nan)
        

### Check if there is any imputation issues

In [None]:
# Sanity check: after imputation no missing variable should contain NaN.
# Offending rows (if any) are displayed before the assertion fails.
for var in missing_variables:
    mask = pd.isna(data_imputation[var])
    display(data_imputation[mask])
    assert len(data_imputation[mask]) == 0

### Check the location map for the kriging

In [None]:
# Kriging (spatial) conditional mean at the imputed locations, one figure per
# section orientation. Non-imputed locations have NaN and are drawn in gray.
for orient in ("xy", "xz", "yz"):
    fig, axes = rmsp.ImageGrid(
        1,
        len(missing_variables),
        figsize=(n_var_miss * 5, 12),
        cbar_mode="each",
        axes_pad=(0.6, 0.4),
    )

    for ax, var in zip(axes, missing_variables):

        point_data = pd.DataFrame(spatial_dict[var])

        point_data = rmsp.PointData(point_data, x=data_ns.x, y=data_ns.y, z=data_ns.z)

        point_data.sectionplot(
            ax=ax,
            orient=orient,
            var="Estimate",
            missing_color='gray',
            s=10,
            tickangs=(45, 0),
            aspect=aspects[orient],
            grid=True,
            title=f"Conditional Mean ({var})",
        )
        set_axis_label_font(ax)

        # Fully populated rows are the imputed locations; observed rows
        # carry NaN Estimate/Variance.
        n_imputed = int(pd.notna(point_data).all(axis=1).sum())
        if orient == 'xy':
            ax.text(0.07, 0.92, f"n_imputed = {n_imputed:g}", transform=ax.transAxes)
    # Style the colorbar labels once, after every panel has been drawn.
    set_cabr_label_font(fig)

In [None]:
# Kriging (spatial) conditional variance at the imputed locations, one figure
# per section orientation. Non-imputed locations have NaN and are drawn in gray.
for orient in ("xy", "xz", "yz"):
    fig, axes = rmsp.ImageGrid(
        1,
        len(missing_variables),
        figsize=(n_var_miss * 5, 12),
        cbar_mode="each",
        axes_pad=(0.6, 0.4),
    )

    for ax, var in zip(axes, missing_variables):

        point_data = pd.DataFrame(spatial_dict[var])

        point_data = rmsp.PointData(point_data, x=data_ns.x, y=data_ns.y, z=data_ns.z)

        point_data.sectionplot(
            ax=ax,
            orient=orient,
            var="Variance",
            missing_color='gray',
            s=10,
            grid=True,
            aspect=aspects[orient],
            title=f"Conditional Variance ({var})",
            tickangs=(45, 0),
            cmap="RdYlGn_r",
        )
        set_axis_label_font(ax)

        # Fully populated rows are the imputed locations.
        n_imputed = int(pd.notna(point_data).all(axis=1).sum())

        # Compare the orientation, not the Axes object: the old `ax == 'xy'`
        # check was never true, so the count was never shown.
        if orient == 'xy':
            ax.text(0.07, 0.92, f"n_imputed = {n_imputed:g}", transform=ax.transAxes)

    # Style the colorbar labels once, after every panel has been drawn.
    set_cabr_label_font(fig)


### Checking the multivariate relationship (NS units)

Relationship between response variables and the imputed ones

In [None]:
# Scatter matrix restricted to rows that had at least one value imputed.
mask_miss = data_imputation.Missing>0
fig = data_imputation[mask_miss].scatplots(
    variables=variables, figsize=(10, 10), stats="all", s=8,
    lims=(-3.5, 3.5),
    num_sample=7000,
    axes_pad=(0.15, 0.15),
    cmap=cmap,
    cbar=True
)
for ax in fig.axes:
    set_axis_label_font(ax, prefix = 'NS:')

In [None]:
# Scatter matrix over all rows (observed + imputed) in NS units.
fig = data_imputation.scatplots(
    variables=variables, figsize=(10, 10), stats="all", s=8,
    lims=(-3.5, 3.5),
    num_sample=7000,
    axes_pad=(0.15, 0.15),
    cmap=cmap,
    cbar=True
)
for ax in fig.axes:
    set_axis_label_font(ax,  prefix = 'NS:')

In [None]:
# Lower/upper triangle comparison: imputed-location rows vs. all rows.
# NOTE(review): lower_variables is hard-coded to ['Ni', 'Fe', 'SiO2'] rather
# than derived from `variables`; confirm it matches the loaded variable lists.
fig = gs.scatter_plots_lu(
    data_imputation[mask_miss],
    data_imputation,
    figsize=(15, 15),
    align_orient=False,
    lower_variables=['Ni', 'Fe', 'SiO2'],
    upper_variables=variables,
    stat_blk="all",
    s=10,
    cmap=cmap,
)

for ax in fig.axes:
    set_axis_label_font(ax,  prefix = 'NS:')

save_figure_paper("LUMV.png")

In [None]:
# Pairwise NS scatter plots: original (with gaps), all rows after imputation,
# and imputed-location rows only.
fig, axes = scaplot_compare(data_ns, variables, grid=True, s=8,  prefix = 'NS:')
fig, axes = scaplot_compare(data_imputation, variables, grid=True, s=8,  prefix = 'NS:')
fig, axes = scaplot_compare(data_imputation[mask_miss], variables, grid=True, s=8,  prefix = 'NS:')

### Back transformation

In [None]:
# Back-transform every variable from normal scores to original units using
# the transformer fitted to that variable earlier.
data_imputation_final = data_imputation.copy()

for var in variables:
    data_imputation_final[var] = ns_transformers[var].inverse_transform(data_imputation[var])

In [None]:
# Histograms (log scale) of the back-transformed variables.
fig, axes = create_axes(3,len(variables), (14,4))

for i, variable in enumerate(variables):
    data_imputation_final.histplot(variable,ax=axes[i], log=True)
    set_axis_label_font(axes[i])

### Location map plot

After data imputation, every data location has a value for both missing variables.

In [None]:
# Location maps of all variables after imputation and back transformation;
# the "n = ..." counts should now match across variables.
stat = data_imputation_final.describe()
for orient in ("xy", "xz", "yz"):
    fig, axes = rmsp.ImageGrid(
        1,
        len(variables),
        figsize=(len(variables) * 5, 12),
        cbar_mode="each",
        axes_pad=(0.6, 0.4),
    )
    for ax, var in zip(axes, variables):
        data_imputation_final.sectionplot(
            orient=orient,
            var=var,
            s=10,
            ax=ax,
            grid=True,
            tickangs=(45, 45),
            missing_color="r",
            aspect=aspects[orient],
        )

        set_axis_label_font(ax)
        if orient == 'xy':
            ax.text(0.1, 0.92, 'n = %d'%(stat[var]['count']), transform=ax.transAxes)
        set_cabr_label_font(fig)

In [None]:
# Pairwise scatter plots in original units: heterotopic input, full data
# after imputation, and imputed locations only. The last one is not saved.
mask_miss = data_imputation_final.Missing>0
fig,axes = scaplot_compare(data, variables, grid=True, s=8)
fig.tight_layout()
fig.suptitle('Original (Heterotopic)', y=1.03)
save_figure_paper('mv_orig.png')
fig,axes = scaplot_compare(data_imputation_final, variables, grid=True, s=8)
fig.tight_layout()
fig.suptitle('Full (After Imputation)', y=1.03)
save_figure_paper('mv_full.png')
fig,axes = scaplot_compare(data_imputation_final[mask_miss], variables, grid=True, s=8)
fig.tight_layout()
fig.suptitle('Imputed locations', y=1.03)

## Multiple realizations

In [None]:
# True: clear the realization caches and rerun the multi-realization imputation;
# False: reload the previously pickled caches (simcache_ns.pkl / simcache_final.pkl).
fresh_simulate = True

In [None]:
# Number of imputed realizations and the directory that holds them.
n_real = 100
out_dir_reals = os.path.join(outdir, 'ImputedRealizations')
if fresh_simulate:
    # Start from an empty directory. A missing directory (first run) is fine,
    # but other failures (permissions, files in use) must not be silently
    # swallowed, which the previous bare ``except:`` did (it even caught
    # KeyboardInterrupt).
    try:
        shutil.rmtree(out_dir_reals)
    except FileNotFoundError:
        pass
    gs.mkdir(out_dir_reals)

In [None]:
# Columns stored for every realization: drillhole id, coordinates and all variables.
final_variables = [data_ns.dhid, data_ns.x, data_ns.y, data_ns.z] + variables

# Build the cache prefixes with os.path.join (trailing '' adds the separator) so
# they resolve to the same directory as ``out_dir_reals`` whether or not
# ``outdir`` ends with a path separator.
reals_prefix = os.path.join(outdir, 'ImputedRealizations', '')
reals_final_prefix = os.path.join(outdir, 'ImputedRealizationsFinal', '')

if fresh_simulate:
    # Normal-score realizations and their back-transformed counterparts.
    simcache = rmsp.SimCache(len(data_ns), file_prefix=reals_prefix, variables=final_variables)
    simcache.clear()
    simcache_final = rmsp.SimCache(len(data_ns), file_prefix=reals_final_prefix, variables=final_variables)
    simcache_final.clear()
else:
    # Paths match those written by pickle_data in the Export section.
    simcache = rmsp.from_pickle(outdir + 'simcache_ns.pkl')
    simcache_final = rmsp.from_pickle(outdir + 'simcache_final.pkl')

In [None]:
# Keep the null-index dataframe: a copy of the NS data with a 'Missing' flag that
# is 1 where any of the missing variables is NaN and 0 otherwise.
data_imputation = data_ns.copy()
data_imputation['Missing'] = 0
has_gap = data_imputation[missing_variables].isna().any(axis=1)
data_imputation.loc[has_gap, 'Missing'] = 1

data_imputation.head()

In [None]:
# Sequentially impute every missing value for ``n_real`` realizations (only when
# ``fresh_simulate`` is True). Each missing value is drawn from a distribution that
# combines, via ``update_bayesian_pr``:
#   * a spatial conditional distribution from simple kriging (estimate / variance), and
#   * a multivariate conditional distribution: a generalized lambda distribution
#     fitted to the first four conditional moments from the MLP models (or from the
#     GMM conditional distribution when the MLP second moment ``a2`` is exactly 0).
# Each realization is cached as-is in ``simcache`` and, after back transformation with
# ``ns_transformers``, in ``simcache_final``.
if fresh_simulate:
    for ireal in trange(n_real):

        # Every realization restarts from the same un-imputed data.
        data_imputation_real = data_imputation.copy()

        for variable in missing_variables:

            vario_model = vario_models[variable]
#             search = rmsp.Search.from_vario_ranges(vario_model)
            # Fixed search (presumably angles 0/0/0 and ranges 100/100/100 -- confirm)
            # using between 1 and 100 data.
            search = rmsp.Search([0.0] * 3, [100.0] * 3, min_comps=1, max_comps=100)
            # Simple kriging; the SK mean is the mean of this variable's non-NaN
            # values in the current realization.
            krig = rmsp.KrigeEstimator().set_params(
                search,
                vario_model,
                "sk",
                sk_mean=np.nanmean(data_imputation_real[variable]),
            )

            # Shuffle data (random path). reset_index(drop=False) keeps the original
            # row label in an 'index' column, used below to write the result back.
            data_imputation_shuffled = data_imputation_real.sample(frac=1).reset_index(drop=False)

            for idx, row in data_imputation_shuffled.iterrows():
                
                index= row['index']

                # Only missing values are simulated. Each simulated value is written
                # back into data_imputation_real, which is the data set passed to the
                # kriging at locations visited later in the path.
                if pd.isna(row[variable]):

                    # mask missing location to implement kriging
                    mask = data_imputation_real["Hash"] == row["Hash"]
                    result = krig.estimate(
                        data_imputation_real.loc[mask],
                        data_imputation_real,
                        variable,
                        output=["estimate", "estimate_var"],
                    )

                    spatial_mean, spatial_variance = (
                        result["estimate"].values[0],
                        result["estimate_var"].values[0],
                    )

                    # First four conditional moments from the MLP model(s) of this
                    # variable, given the location's values of its response variables.
                    a1, a2, a3, a4 = get_conditional_moments(
                        row[response_var_dict[variable]].values.reshape(
                            -1, len(response_var_dict[variable])
                        ),
                        *models_mpl[variable],
                        data_ns,
                        variable,
                    )

                    # Fallback when the MLP second moment is exactly 0: print the
                    # moments and recompute them from the GMM conditional distribution.
                    if a2 == 0:
                        print(a1, a2, a3, a4)
                        # Get the conditioning data for GMM (NaN -> None = unconditioned)
                        conditioning_data = [
                            None if math.isnan(val) else val
                            for val in row[gmm_util.variable_names]
                        ]
                        # Get the uni/bivariate conditional distribution based on the GMM components
                        (
                            cond_means,
                            cond_covariances,
                            cond_contributions,
                        ) = gmm_util.get_conditional_pdf(
                            conditioning_data=conditioning_data
                        )

                        # Grab the marginal univariate GMM components (first variable)
                        cond_means = np.array(cond_means)
                        cond_means = cond_means[:, 0].reshape(gmm_util.n_components, 1)
                        cond_covariances = np.array(cond_covariances)
                        cond_covariances = cond_covariances[:, 0, 0].reshape(
                            gmm_util.n_components, 1, 1
                        )
                        a1, a2, a3, a4 = gmm_util.get_moments(
                            cond_means, cond_covariances, cond_contributions
                        )

                    # Fit the lambda distribution given the first four moments
                    lambdas = get_lambdas_keras(a1, a2, a3, a4, mlp_model)

                    gld = GeneralizedLambdaDist(*lambdas)

                    # Combine the GLD with the kriging mean/variance; F_updated is
                    # presumably the updated CDF evaluated at x_vals.
                    x_vals, F_gld, F_spatial, F_global, F_updated = update_bayesian_pr(
                        gld,
                        spatial_mean=spatial_mean,
                        spatial_variance=spatial_variance,
                    )

                    # Draw one value from the updated distribution.
                    sim_val = simulate(F_updated, x_vals)
                    # NOTE(review): a NaN draw is only printed for diagnosis and is
                    # still written back below, leaving this value missing in the
                    # realization -- confirm this is intended.
                    if np.isnan(sim_val):
                        print(
                            variable,
                            index,
                            a1,
                            a2,
                            a3,
                            a4,
                            *lambdas,
                            spatial_mean,
                            spatial_variance,
                        )

                    data_imputation_real.loc[index, variable] = sim_val

        # Cache the realization in the working (normal-score) units.
        simcache.set_real(ireal, data_imputation_real[final_variables])

        # back transformation
        data_imputation_real_final = data_imputation_real.copy()
        for var in variables:
            data_imputation_real_final[var] = ns_transformers[var].inverse_transform(
                data_imputation_real[var]
            )

        simcache_final.set_real(ireal, data_imputation_real_final[final_variables])

        # Release the per-realization frames before the next iteration.
        del(data_imputation_real_final); del(data_imputation_real)

### Error Evaluation

In [None]:
# Complete data set; its values serve as the "True" reference in the checks below.
all_data_path = data_dir + 'AllData.pkl'
data_all = rmsp.from_pickle(all_data_path)
data_all.head()

In [None]:
# Per-location summary (e-type, stdev, ...) across realizations, with the true
# value attached for each imputed variable.
psims_final = {}
for var in missing_variables:
    summary = rmsp.postsim(simcache_final, var=var)
    summary[f'{var} True'] = data_all[var]
    psims_final[var] = summary

In [None]:
# E-type vs. true value, restricted to locations whose value varies between
# realizations (observed values are identical in every realization, stdev 0).
fig, axes = create_axes(len(missing_variables), len(missing_variables), (9, 4))
for ax, (var, psim) in zip(axes, psims_final.items()):
    varying = psim[f"{var} stdev"] > 0.0
    cv = rmsp.CrossVal(psim[varying], f"{var} e-type", f"{var} True")
    cv.scatplot(
        psim[varying],
        ax=ax,
        s=4,
        c="#838383",
        plot_worst=False,
        grid=True,
        stats=["count", "meanx", "meany", "rmse", "corr", "sor"],
    )
    set_axis_label_font(ax)
fig.tight_layout()
save_figure_paper("cv_dp_ld_etype.png")

In [None]:
# Same check as the e-type plot, but for the first realization only.
first_real = simcache_final.get_real(0)
for var in missing_variables:
    first_real[f"{var} True"] = data_all[var]

fig, axes = create_axes(len(missing_variables), len(missing_variables), (9, 4))
for ax, (var, psim) in zip(axes, psims_final.items()):
    varying = psim[f"{var} stdev"] > 0.0
    cv = rmsp.CrossVal(first_real[varying], var, f"{var} True")
    cv.scatplot(
        first_real[varying],
        ax=ax,
        s=4,
        c="#838383",
        plot_worst=False,
        grid=True,
        stats=["count", "meanx", "meany", "rmse", "corr", "sor"],
    )
    set_axis_label_font(ax)
fig.tight_layout()
save_figure_paper("cv_dp_ld_real.png")

### Realization stats

In [None]:
# For each imputed variable, collect its values at the varying (imputed)
# locations from every back-transformed realization.
mask_dict = {var: psims_final[var][f'{var} stdev'] > 0.0 for var in missing_variables}
real_list = {var: [] for var in missing_variables}

for _, realization in simcache_final.iter_realizations():
    for var, keep in mask_dict.items():
        real_list[var].append(realization[keep][var])

In [None]:
# Box plots of the realizations at the first few imputed locations of each
# variable, with the true value at each location drawn as a red dashed line.
n_show = 10  # number of imputed locations displayed per variable

fig, axes = create_axes(len(missing_variables), len(missing_variables), (14, 6))
for var, ax in zip(missing_variables, axes):
    # rows = imputed locations, columns = realizations
    real_data = pd.concat(real_list[var], axis=1)
    shown = real_data.index[:n_show]

    unistats = [rmsp.UniStats(real_data.loc[idx].values) for idx in shown]
    names = [f'Loc {i + 1}' for i in range(len(shown))]
    true_values = [data_all.loc[idx, var] for idx in shown]

    unicompare = rmsp.UniCompare(unistats, names=names, setname='Missing Locations')
    unicompare.boxplot(ax=ax)
    ax.set_ylabel(var)
    set_axis_label_font(ax)

    ticks = ax.get_xticks()
    # Half the tick spacing. With a single tick np.diff is empty and its mean is
    # NaN (the markers would silently disappear), so fall back to 0.5, which
    # assumes unit spacing between boxes.
    half_width = np.diff(ticks).mean() / 2 if len(ticks) > 1 else 0.5
    for val, tick in zip(true_values, ticks):
        ax.hlines(val, tick - half_width, tick + half_width, color='r', ls='--')
    ax.grid()

### Histogram Reproduction

#### Normal Scores

In [None]:
# Normal-score CDF check: NS data vs. each realization stored in ``simcache``.
ref_unistats = {var: rmsp.UniStats(data_ns[var]) for var in variables}
sim_unistats = {
    var: [rmsp.UniStats(simcache.get_real(i, var)[var]) for i in range(n_real)]
    for var in variables
}
fig, axes = create_axes(3, len(variables), (14, 4))

for var, ax in zip(variables, axes):
    ref_unistats[var].cdfplot_checkreals(sim_unistats[var], ax=ax)
    set_axis_label_font(ax)

#### Original Units

In [None]:
# Original-units CDF check: data weighted by the ``{var}_wt`` column vs. each
# back-transformed realization in ``simcache_final``.
ref_unistats = {var: rmsp.UniStats(data[var], data[f'{var}_wt']) for var in variables}
sim_unistats = {
    var: [rmsp.UniStats(simcache_final.get_real(i, var)[var]) for i in range(n_real)]
    for var in variables
}
fig, axes = create_axes(3, len(variables), (14, 4))
for var, ax in zip(variables, axes):
    ref_unistats[var].cdfplot_checkreals(sim_unistats[var], ax=ax, log=False, grid=True)
    set_axis_label_font(ax)

In [None]:
# Original-units CDF reproduction for the imputed variables only (saved for the paper).
# squeeze=False keeps ``axes`` an array even when there is a single missing
# variable; otherwise plt.subplots returns a bare Axes and zip() fails.
fig, axes = plt.subplots(1, len(missing_variables), figsize=(12, 4), squeeze=False)

for var, ax in zip(missing_variables, axes.flat):
    ref_unistats[var].cdfplot_checkreals(
        sim_unistats[var], ax=ax, log=False, grid=True, stats=["mean", "count"]
    )
    set_axis_label_font(ax)
save_figure_paper("hist_repro.png")

### Variogram reproduction

In [None]:
# Back-normal-score experimental variograms of the original data, per imputed variable.
vario_exp_orig = {
    var: rmsp.ExpVario("backns").calculate(
        data, var, searches=exp_vario_search, nstransformer=ns_transformers[var]
    )
    for var in missing_variables
}

In [None]:
# Back-normal-score experimental variograms of every back-transformed realization.
vario_exp_reals = []
for _, sim_points in simcache_final.iter_realizations():
    point_data = rmsp.PointData(sim_points, x=data.x, y=data.y, z=data.z)
    vario_exp_reals.append({
        var: rmsp.ExpVario("backns").calculate(
            point_data, var, searches=exp_vario_search, nstransformer=ns_transformers[var]
        )
        for var in missing_variables
    })

In [None]:
# Traditional experimental variograms of every normal-score realization.
vario_exp_reals_ns = []
for _, sim_points in simcache.iter_realizations():
    point_data = rmsp.PointData(sim_points, x=data.x, y=data.y, z=data.z)
    vario_exp_reals_ns.append({
        var: rmsp.ExpVario("traditional").calculate(
            point_data, var, searches=exp_vario_search
        )
        for var in missing_variables
    })

In [None]:
# Variogram reproduction in original units: gray lines are the experimental
# variograms of each realization (vario_exp_reals); colored dashed lines are the
# experimental variograms of the data (vario_exp_orig).
# Rows = search directions ('Search Index'), columns = imputed variables.
n_dirs = 1 + max(
    int(vario_exp_orig[var].data['Search Index'].max()) for var in missing_variables
)
n_cols = len(missing_variables)
# squeeze=False keeps ``axes`` 2-D so axes[row, col] works for any grid size;
# with 2 directions x 2 variables this is the original (10, 6) 2x2 figure.
fig, axes = plt.subplots(n_dirs, n_cols, figsize=(5 * n_cols, 3 * n_dirs), squeeze=False)

for vario_calc in vario_exp_reals:
    for i, var in enumerate(missing_variables):
        for index in vario_calc[var].data['Search Index'].unique():
            vario_calc[var].plot_draw(
                search_index=index,
                ax=axes[int(index), i],
                c='gray',
                ls='-',
                lw=2,
            )

for i, var in enumerate(missing_variables):
    data_var = vario_exp_orig[var].data
    for index in data_var['Search Index'].unique():
        # Title uses the mean azimuth/inclination of the lags in this direction.
        mask = data_var['Search Index'] == index
        azim = data_var[mask].Azimuth.mean()
        incl = data_var[mask].Inclination.mean()
        vario_exp_orig[var].plot(
            search_index=index,
            ax=axes[int(index), i],
            ylim=(0, 1.5),
            title=(f"{var} (azim: {azim:.1f}, incl: {incl:.1f})"),
            c=f'C{i}',
            ls='--',
            ms=5,
            tickangs=(45, 0),
            lw=0.5,
            grid=True,
        )

for ax in axes.flat:
    ax.set_xlabel(ax.get_xlabel(), fontsize=12)
    ax.set_title(ax.get_title(), fontsize=12)

plt.tight_layout()
save_figure_paper('vario_repro.png')

In [None]:
# Variogram reproduction in normal-score units: gray lines are the experimental
# variograms of each NS realization (vario_exp_reals_ns); colored dashed lines
# are the NS experimental variograms of the data (vario_exp).
# Rows = search directions ('Search Index'), columns = imputed variables.
n_dirs = 1 + max(
    int(vario_exp[var].data['Search Index'].max()) for var in missing_variables
)
n_cols = len(missing_variables)
# squeeze=False keeps ``axes`` 2-D so axes[row, col] works for any grid size;
# with 2 directions x 2 variables this is the original (10, 6) 2x2 figure.
fig, axes = plt.subplots(n_dirs, n_cols, figsize=(5 * n_cols, 3 * n_dirs), squeeze=False)

for vario_calc in vario_exp_reals_ns:
    for i, var in enumerate(missing_variables):
        for index in vario_calc[var].data['Search Index'].unique():
            vario_calc[var].plot_draw(
                search_index=index,
                ax=axes[int(index), i],
                c='gray',
                ls='-',
                lw=2,
            )

for i, var in enumerate(missing_variables):
    data_var = vario_exp[var].data
    for index in data_var['Search Index'].unique():
        # Title uses the mean azimuth/inclination of the lags in this direction.
        mask = data_var['Search Index'] == index
        azim = data_var[mask].Azimuth.mean()
        incl = data_var[mask].Inclination.mean()
        vario_exp[var].plot(
            search_index=index,
            ax=axes[int(index), i],
            ylim=(0, 1.5),
            title=(f"{var} (azim: {azim:.1f}, incl: {incl:.1f})"),
            c=f'C{i}',
            ls='--',
            ms=5,
            tickangs=(45, 0),
            lw=0.5,
            grid=True,
        )
        # Optional variogram-model overlay (disabled):
        # vario_models[var].plot_draw(
        #     ax=axes[int(index), i], azm=azim, incl=0.0, lw=2, c="k", ls="--"
        # )

for ax in axes.flat:
    ax.set_xlabel(ax.get_xlabel(), fontsize=12)
    ax.set_title(ax.get_title(), fontsize=12)

plt.tight_layout()
save_figure_paper('vario_repro_ns.png')

# Correlation Matrix Reproduction

In [None]:
def get_bivstat(ireal):
    """Return rmsp.BivStats over ``variables`` for realization ``ireal`` of ``simcache_final``."""
    realization = simcache_final.get_real(ireal, columns=variables)
    return rmsp.BivStats(realization, variables)


# Bivariate statistics of every realization, computed in parallel with a progress bar.
parallel_args = list(range(n_real))  # one realization index per task
all_biv_stats = rmsp.parallel_runner(
    get_bivstat,
    parallel_args,
    num_threads=rmsp.GlobalParams["core.num_threads"],
    progressbar=True,
)

In [None]:
# Reference bivariate statistics of the original (heterotopic) data.
ref_bivstat = rmsp.BivStats(data, variables)

In [None]:
# Rank-correlation matrix of the data compared against all realizations.
_ = ref_bivstat.matrixplot_checkreals(
    all_biv_stats,
    figsize=(10, 10),
    stat="rankcorr",
)
save_figure_paper('corr_matrix_original.png')

# Export

In [None]:
# Persist both realization caches so a later run can skip the simulation.
for cache, file_name in ((simcache, 'simcache_ns.pkl'), (simcache_final, 'simcache_final.pkl')):
    pickle_data(cache, file_name)

---
# <span style='color:#1B127A;'> RMSP GMM </span>

## <span style='color:#5177F9;'> Imputation </span>

In [None]:
# Fit rmsp.GMMImputer on data_ns for all variables; it generates the RMSP
# imputed realizations in the next cell.
imputer = rmsp.GMMImputer().fit(data_ns, variables)

In [None]:
# Search neighbourhood and number of search candidates passed to the imputer.
search = rmsp.Search([0.0] * 3, [100.0] * 3, min_comps=1, max_comps=100)
num_search = 20

# Variogram model per variable: the fitted models for the variables with missing
# values, plus a placeholder model for every response variable. Response
# variables have no missing samples, so their model presumably only satisfies
# the imputer's need for one model per variable.
variomod_dict = {var: vario_models[var] for var in vario_models}
for var in response_variables:
    # Fresh parameter dict per variable, as before.
    vario_dict = dict(
        num_struct=1,
        nugget=0.05,
        shapes=["spherical"],
        var_contribs=[0.95],
        angles=[0.0] * 3,
        ranges=[100] * 3,
    )
    variomod_dict[var] = rmsp.VarioModel(vario_dict)

# Imputed realizations (same units as data_ns), cached on disk.
rmsp_nscache = imputer.impute(
    variomod_dict,
    gmm_model,
    search,
    num_search,
    reals=n_real,
    cache=outdir + "rmsp_cache_ns",
)

In [None]:
# Back-transform every RMSP realization to original units into a second cache.
rmsp_finalcache = rmsp.SimCache(len(data_ns), file_prefix=outdir + "rmsp_cache_final")
for ireal, ns_real in rmsp_nscache.iter_realizations():
    # Column-less frame that only carries the data_ns index.
    back_transformed = data_ns[[]].copy()
    for col in variables:
        back_transformed[col] = ns_transformers[col].inverse_transform(ns_real[col])
    rmsp_finalcache.set_real(ireal, back_transformed[variables])

## <span style='color:#5177F9;'> Checks </span>

In [None]:
# Per-location summary across the RMSP realizations, with the true value attached.
psims_final_rms = {}
for var in missing_variables:
    summary = rmsp.postsim(rmsp_finalcache, var=var)
    summary[f'{var} True'] = data_all[var]
    psims_final_rms[var] = summary

### <span style='color:#51AFF9;'> Cross validations </span>

In [None]:
# RMSP GMM: e-type vs. true value at locations that vary between realizations.
fig, axes = create_axes(len(missing_variables), len(missing_variables), (9, 4))
for ax, (var, psim) in zip(axes, psims_final_rms.items()):
    varying = psim[f"{var} stdev"] > 0.0
    cv = rmsp.CrossVal(psim[varying], f"{var} e-type", f"{var} True")
    cv.scatplot(
        psim[varying],
        ax=ax,
        s=4,
        c="#838383",
        plot_worst=False,
        grid=True,
        stats=["count", "meanx", "meany", "rmse", "corr", "sor"],
    )
    set_axis_label_font(ax)
fig.tight_layout()
save_figure_paper("cv_gmm_etype.png")

In [None]:
# RMSP GMM: first realization vs. true value at the varying locations.
real_rms = rmsp_finalcache.get_real(0)
for var in missing_variables:
    real_rms[f"{var} True"] = data_all[var]

fig, axes = create_axes(len(missing_variables), len(missing_variables), (9, 4))
for ax, (var, psim) in zip(axes, psims_final_rms.items()):
    varying = psim[f"{var} stdev"] > 0.0
    cv = rmsp.CrossVal(real_rms[varying], var, f"{var} True")
    cv.scatplot(
        real_rms[varying],
        ax=ax,
        s=4,
        c="#838383",
        plot_worst=False,
        grid=True,
        stats=["count", "meanx", "meany", "rmse", "corr", "sor"],
    )
    set_axis_label_font(ax)
fig.tight_layout()
save_figure_paper("cv_gmm_real.png")

In [None]:
# Scatter matrix of the first RMSP-GMM realization in original units. The
# realization holds every data location (observed and imputed), so it is titled
# like the full-dataset plot rather than 'Imputed locations'.
fig, axes = scaplot_compare(rmsp_finalcache.get_real(0), variables, grid=True, s=8)
fig.tight_layout()
fig.suptitle('Full (After Imputation, RMSP GMM)', y=1.03)

### <span style='color:#51AFF9;'> NS Histogram Checks </span>

In [None]:
# Normal-score CDF check: NS data vs. each RMSP realization in ``rmsp_nscache``.
ref_unistats = {var: rmsp.UniStats(data_ns[var]) for var in variables}
sim_unistats_rms = {
    var: [rmsp.UniStats(rmsp_nscache.get_real(i, var)[var]) for i in range(n_real)]
    for var in variables
}

fig, axes = create_axes(3, len(variables), (14, 4))

for var, ax in zip(variables, axes):
    ref_unistats[var].cdfplot_checkreals(sim_unistats_rms[var], ax=ax)
    set_axis_label_font(ax)

### <span style='color:#51AFF9;'> Original histogram checks </span>

In [None]:
# Original-units CDF check for the RMSP GMM realizations: data weighted by the
# ``{var}_wt`` column vs. each back-transformed realization in rmsp_finalcache.
ref_unistats = {}
sim_unistats_rms = {}
for var in variables:
    ref_unistats[var] = rmsp.UniStats(
        data[var], data[f'{var}_wt']
    )
    sim_unistats_rms[var] = [
        rmsp.UniStats(rmsp_finalcache.get_real(ireal, var)[var])
        for ireal in range(n_real)
    ]
fig, axes = create_axes(3, len(variables), (14, 4))
for var, ax in zip(variables, axes):
    ref_unistats[var].cdfplot_checkreals(
        sim_unistats_rms[var], ax=ax, log=True, grid=True,
    )
    set_axis_label_font(ax)

# Imputed variables only, saved for the paper. This figure must use the RMSP
# realizations (sim_unistats_rms); ``sim_unistats`` holds the MLP/Lambda-
# distribution realizations from simcache_final. squeeze=False keeps ``axes``
# iterable with a single missing variable.
fig, axes = plt.subplots(1, len(missing_variables), figsize=(12, 4), squeeze=False)

for var, ax in zip(missing_variables, axes.flat):
    ref_unistats[var].cdfplot_checkreals(
        sim_unistats_rms[var], ax=ax, log=True, grid=True, stats=["mean", "count"]
    )
    set_axis_label_font(ax)
save_figure_paper("hist_repro_rmsp.png")

### <span style='color:#51AFF9;'> Variogram Reproduction </span>

In [None]:
# Back-normal-score experimental variograms of every RMSP realization.
vario_exp_reals = []
for _, sim_points in rmsp_finalcache.iter_realizations(points=data):
    point_data = rmsp.PointData(sim_points, x=data.x, y=data.y, z=data.z)
    vario_exp_reals.append({
        var: rmsp.ExpVario("backns").calculate(
            point_data, var, searches=exp_vario_search, nstransformer=ns_transformers[var]
        )
        for var in missing_variables
    })

In [None]:
# Variogram reproduction for the RMSP GMM realizations (original units): gray
# lines are the experimental variograms of each realization (vario_exp_reals);
# colored dashed lines are the experimental variograms of the data (vario_exp_orig).
# Rows = search directions ('Search Index'), columns = imputed variables.
n_dirs = 1 + max(
    int(vario_exp_orig[var].data['Search Index'].max()) for var in missing_variables
)
n_cols = len(missing_variables)
# squeeze=False keeps ``axes`` 2-D so axes[row, col] works for any grid size;
# with 2 directions x 2 variables this is the original (10, 6) 2x2 figure.
fig, axes = plt.subplots(n_dirs, n_cols, figsize=(5 * n_cols, 3 * n_dirs), squeeze=False)

for vario_calc in vario_exp_reals:
    for i, var in enumerate(missing_variables):
        for index in vario_calc[var].data['Search Index'].unique():
            vario_calc[var].plot_draw(
                search_index=index,
                ax=axes[int(index), i],
                c='gray',
                ls='-',
                lw=2,
            )

for i, var in enumerate(missing_variables):
    data_var = vario_exp_orig[var].data
    for index in data_var['Search Index'].unique():
        # Title uses the mean azimuth/inclination of the lags in this direction.
        mask = data_var['Search Index'] == index
        azim = data_var[mask].Azimuth.mean()
        incl = data_var[mask].Inclination.mean()
        vario_exp_orig[var].plot(
            search_index=index,
            ax=axes[int(index), i],
            ylim=(0, 1.5),
            title=(f"{var} (azim: {azim:.1f}, incl: {incl:.1f})"),
            c=f'C{i}',
            ls='--',
            ms=5,
            tickangs=(45, 0),
            lw=0.5,
            grid=True,
        )

for ax in axes.flat:
    ax.set_xlabel(ax.get_xlabel(), fontsize=12)
    ax.set_title(ax.get_title(), fontsize=12)

plt.tight_layout()
save_figure_paper('vario_repro_rmsp.png')

# Clean up

In [None]:
# gs.rmdir(outdir) #command to delete generated data file
# gs.rmfile('temp') 