# Import all the libraries and packages 

In [None]:
import os
import numpy as np
import joblib
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive, EmpiricalMarginal
from pyro.infer.autoguide import init_to_mean, init_to_median, init_to_value
from pyro.infer.inspect import get_dependencies
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

import gempy as gp
import gempy_engine
import gempy_viewer as gpv
from gempy_engine.core.backend_tensor import BackendTensor
import arviz as az
from gempy_probability.plot_posterior import default_red, default_blue, PlotPosterior

from sklearn.mixture import GaussianMixture
from sklearn.mixture import BayesianGaussianMixture
from sklearn.cluster import KMeans
from scipy.stats import multivariate_normal, norm

# Define the paths to the data files

In [None]:

# All core datasets live in one directory; only the file stem differs per core.
DATA_DIR = './Fw__Hyperspectral_datasets_from_the_KSL_cores'
filename_a = f'{DATA_DIR}/CuSp131.pkl'
filename_b = f'{DATA_DIR}/CuSp133.pkl'
filename_c = f'{DATA_DIR}/CuSp136.pkl'

# Load the data

In [None]:
# Load the first core dataset. joblib.load unpickles the file, so only load
# trusted files. The loaded object `a` is indexed below by keys such as
# 'XYZ' and 'BR_Anhydrite'.
with open(filename_a, 'rb') as myfile:
    a= joblib.load(myfile)

# Description of the hyperspectral data
## The data were obtained by scanning the cores from different boreholes with hyperspectral sensors. Initially there were around 450 channels per pixel. The data were preprocessed and separated into 10 different rock types. Each file contains the "X", "Y", "Z" coordinates of the sensor points and, for each rock type, a transformed RGB representation.

In [None]:
# The variability in "X" and "Y" is much smaller than in the "Z" direction for
# borehole data, so the model is built mostly along the "Z" direction.

# Only samples below this z-value are kept for the analysis.
DEPTH_CUTOFF = -700

# z-coordinates of the borehole samples
zz = a['XYZ'][:, 2]
print(zz.shape)
# Sort the samples by depth; `ixx` is also used to align the spectral channel.
ixx = np.argsort(zz)
zz_sorted = zz[ixx]
# Keep only samples below the cutoff. Because zz_sorted is ascending, the kept
# samples form a contiguous prefix of the sorted arrays.
mask = zz_sorted < DEPTH_CUTOFF
ah = a['BR_Anhydrite'][:, 0]  # correlates to the "anhydrite index" derived from hyperspectral data
# Dividing by 255 normalizes the values (presumably 8-bit, 0-255) to [0, 1].
position_cord, hsi_data = zz_sorted[mask], ah[ixx][mask] / 255
plt.plot(position_cord, hsi_data)
print(position_cord.shape)

# Define breakpoints.
# In general it is very difficult to define the breakpoints from the plot: preprocessing
# hyperspectral data is itself a hard task because the data are highly correlated,
# high-dimensional and noisy.
brk1 = -845
brk2 = -825

plt.axvline(brk1, color='r')
plt.axvline(brk2, color='g')

In [None]:
# Display the normalized spectral values as a single row (inspection only; not stored).
hsi_data.reshape(1,-1)

# Since hyperspectral data are difficult to classify in general, we apply classical clustering methods to obtain a starting guess.

In [None]:

# Feature matrix for clustering, one row per observation:
# column 0 = depth / 1000 (presumably to bring it closer to the spectral range),
# column 1 = normalized spectral value.
X = np.column_stack((position_cord / 1000.0, hsi_data))


In [None]:
# Fit a 3-component Bayesian Gaussian mixture with full covariances on the joint
# (scaled depth, spectral) features. The commented line is the spectral-only variant.
#gm2 = BayesianGaussianMixture(n_components=3,covariance_type="full", random_state=0).fit(hsi_data.reshape(-1,1))
gm2 = BayesianGaussianMixture(n_components=3,covariance_type="full", random_state=0).fit(X)
# Display the fitted component means and covariance matrices.
gm2.means_ , gm2.covariances_

In [None]:
# Convert the fitted parameters to torch tensors (dtype follows the numpy arrays)
# so they can parameterize the Pyro distributions in model_test.
loc_mean = torch.tensor(gm2.means_)
loc_cov  = torch.tensor(gm2.covariances_)

In [None]:
# Reorder the mixture components. GMM component order is arbitrary, so this
# permutation is hard-coded for this dataset and random_state.
# NOTE(review): re-check it if the data, n_components or random_state change.
correct_order = [0,2,1]
loc_mean, loc_cov = loc_mean[correct_order], loc_cov[correct_order]

In [None]:
# Plot each (reordered) prior mixture component as a filled contour map, one figure
# per component, over a grid of (scaled depth, spectral value) points.
# multivariate_normal is already imported in the first cell.
x, y = np.meshgrid(np.linspace(-0.9, -0.7, 100), np.linspace(-2, 2, 100))
pos = np.dstack((x, y))  # shape (100, 100, 2): one (x, y) pair per grid node
for i in range(3):
    component_mean, component_cov = loc_mean[i], loc_cov[i]
    pdf_values = multivariate_normal(component_mean, component_cov).pdf(pos)

    plt.figure(figsize=(8, 6))
    plt.contourf(x, y, pdf_values, cmap='viridis')
    plt.colorbar(label='Probability Density')
    plt.xlabel('spatial')
    plt.ylabel('spectral')
    plt.title('2D Gaussian Distribution')
    plt.scatter(component_mean[0], component_mean[1], color='red', label='Mean')
    plt.legend()
    plt.grid(True)
plt.show()

In [None]:
# Overlay the contour lines and means of all three prior components in one figure.
# multivariate_normal is already imported in the first cell.
plt.figure(figsize=(8, 6))
x, y = np.meshgrid(np.linspace(-0.9, -0.7, 100), np.linspace(-2, 2, 100))
pos = np.dstack((x, y))  # shape (100, 100, 2)
for i in range(3):
    component_mean = loc_mean[i]
    pdf_values = multivariate_normal(component_mean, loc_cov[i]).pdf(pos)
    plt.contour(x, y, pdf_values, cmap='viridis')
    plt.scatter(component_mean[0], component_mean[1], label=f'Mean{i + 1}')

plt.xlabel('spatial')
plt.ylabel('spectral')
plt.title('2D Gaussian Distribution')

plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Hard cluster assignment of every observation (raw 0-based GMM component ids).
y_gmm_label2 = gm2.predict(X)
print(y_gmm_label2)
# Map raw component ids to 1-based layer labels via a lookup table:
# component 0 -> 2, component 1 -> 3, component 2 -> 1.
# NOTE(review): this mapping may disagree with `correct_order = [0, 2, 1]` used to
# reorder loc_mean/loc_cov. If label k is meant to match loc_mean[k - 1] (as in
# model_test's softmax over [1, 2, 3]), that reorder implies 0 -> 1, 2 -> 2, 1 -> 3.
# Verify which convention is intended.
component_to_label = np.array([2, 3, 1])
y_gmm_label_arranged2 = torch.Tensor(component_to_label[y_gmm_label2])
y_gmm_label_arranged2 

In [None]:
# Scatter the observations (depth vs. normalized spectral value), coloured by their
# remapped GMM label 1, 2 or 3. matplotlib.pyplot is already imported as plt.
colors = ['r', 'g', 'b']
labels = y_gmm_label_arranged2

plt.figure(figsize=(8, 6))
for label_val, color in enumerate(colors, start=1):
    in_class = labels == label_val
    plt.scatter(position_cord[in_class], hsi_data[in_class], c=color, label=f'Label {label_val}')

plt.xlabel('z')
plt.ylabel('hsi_data')
plt.title('2D Dataset with Label Information')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Depth-based reference labels for every selected observation:
#   deepest interval  (z <  brk1)        -> 3
#   middle interval   (brk1 <= z < brk2) -> 2
#   shallowest        (z >= brk2)        -> 1
# `shift` moves both breakpoints together (0 = use them as defined).
shift = 0
lower_brk = brk1 + shift
upper_brk = brk2 + shift

# Work directly on the masked, sorted depths so the label vector always has exactly
# one entry per observation, instead of a hard-coded length. The lower breakpoint is
# inclusive for the middle interval, so no point falls between the intervals.
indices_A = np.where(position_cord < lower_brk)
indices_B = np.where((position_cord >= lower_brk) & (position_cord < upper_brk))
indices_C = np.where(position_cord >= upper_brk)

y_obs_label = torch.ones(len(position_cord))
y_obs_label[indices_A] = 3
y_obs_label[indices_B] = 2
y_obs_label

In [None]:
# Agreement (in %) between the remapped GMM labels and the depth-based labels.
matched_label2 = y_obs_label == y_gmm_label_arranged2
n_observations = len(y_obs_label)
(matched_label2.sum() / n_observations) * 100

In [None]:
# Observation arrays for the later cells: the spectral values alone as a column
# vector, and the joint (scaled depth, spectral) feature matrix.
# Note: this rebinds `y_obs_label`, replacing the depth-based labels computed earlier.
y_obs_label, y_obs_label2 = hsi_data.reshape((-1, 1)), X

In [None]:
# Build the test GemPy model: x in [0, 1000], y in [-500, 500], z in [-900, -700],
# resolution 100 x 100 x 100, refinement level 3, starting from the default
# structural frame.
geo_model_test = gp.create_geomodel(
    project_name='Gempy_abc_Test',
    extent=[0, 1000, -500, 500, -900, -700],
    resolution=[100,100,100],
    refinement=3,
    structural_frame= gp.data.StructuralFrame.initialize_default_structure()
    )
p2d = gpv.plot_2d(geo_model_test)

In [None]:
# Display which grid types are currently active.
geo_model_test.grid.active_grids_bool

In [None]:
# First interface ("surface1") at the lower breakpoint brk1, defined by two points
# at x = 100 and x = 900 in the y = 0 section.
gp.add_surface_points(
    geo_model=geo_model_test,
    x=[100.0, 900.0],
    y=[0.0, 0.0],
    z=[brk1, brk1],
    elements_names=['surface1', 'surface1']
)
gpv.plot_2d(geo_model_test, cell_number=11)

In [None]:
# One orientation for surface1 with a vertical pole vector (0, 0, 1), i.e. a
# horizontal interface.
gp.add_orientations(
    geo_model=geo_model_test,
    x=[500],
    y=[0.0],
    z=[brk1],
    elements_names=['surface1'],
    pole_vector=[[0, 0, 1]]
)

gpv.plot_2d(geo_model_test, cell_number=5)

In [None]:
# Rebuild the coordinate transform with GlobalAnisotropy.NONE
# (presumably no anisotropic rescaling of the axes).
geo_model_test.update_transform(gp.data.GlobalAnisotropy.NONE)

In [None]:
# Compute the model with a default-constructed engine configuration.
gp.compute_model(geo_model_test, engine_config=gp.data.GemPyEngineConfig())

In [None]:
# Inspect the interpolation kernel options.
geo_model_test.interpolation_options.kernel_options

In [None]:
gpv.plot_2d(geo_model_test, cell_number=[5])

In [None]:
# Inspect the structural frame (groups and elements).
geo_model_test.structural_frame

In [None]:
# Second interface ("surface2") at the upper breakpoint brk2, defined by two points
# and no orientations; it is appended to the first structural group.
surface2_points = gp.data.SurfacePointsTable.from_arrays(
    x=np.array([100.0, 900.0]),
    y=np.array([0.0, 0.0]),
    z=np.array([brk2, brk2]),
    names='surface2',
)
element2 = gp.data.StructuralElement(
    name='surface2',
    color=next(geo_model_test.structural_frame.color_generator),
    surface_points=surface2_points,
    orientations=gp.data.OrientationsTable.initialize_empty(),
)

first_group = geo_model_test.structural_frame.structural_groups[0]
first_group.append_element(element2)
# Recompute and visualize the updated model.
gp.compute_model(geo_model_test)
gpv.plot_2d(geo_model_test, cell_number=5, legend='force')
    

In [None]:
#gpv.plot_3d(geo_model_test, image=True)

In [None]:
geo_model_test.structural_frame

In [None]:
geo_model_test.structural_frame.structural_groups[0].elements[0], geo_model_test.structural_frame.structural_groups[0].elements[1] = \
geo_model_test.structural_frame.structural_groups[0].elements[1], geo_model_test.structural_frame.structural_groups[0].elements[0]

In [None]:
geo_model_test.structural_frame

In [None]:
gpv.plot_2d(geo_model_test, cell_number=5, legend='force')

# Create a custom grid at the locations where observed data are available

In [None]:
# Custom grid: one point at (x_loc, y_loc) for every observed depth, so the model
# can be evaluated exactly where the hyperspectral observations were taken.
x_loc = 300
y_loc = 0
z_loc = position_cord
xyz_coord = np.array([(x_loc, y_loc, depth) for depth in z_loc])
gp.set_custom_grid(geo_model_test.grid, xyz_coord=xyz_coord)

In [None]:
# Copy of the current surface-point coordinates. apply_inverse maps them back for
# display (presumably from GemPy's transformed space to model coordinates).
sp_coords_copy_test = geo_model_test.interpolation_input.surface_points.sp_coords.copy()
geo_model_test.transform.apply_inverse(sp_coords_copy_test)

In [None]:
# Add one point of each interface at the borehole location x = x_loc:
# surface1 at brk1 and surface2 at brk2.
# NOTE(review): these are presumably rows 2 and 5 of sp_coords, which model_test
# perturbs — verify against geo_model_test.surface_points.
gp.add_surface_points(
    geo_model=geo_model_test,
    x=[x_loc, x_loc],
    y=[0.0, 0.0],
    z=[brk1, brk2],
    elements_names=['surface1', 'surface2']
)
gpv.plot_2d(geo_model_test, cell_number=1)

In [None]:
geo_model_test.surface_points

In [None]:
# Recompute with the added borehole points and plot.
gp.compute_model(geo_model_test)
gpv.plot_2d(geo_model_test, cell_number=5, legend='force')

In [None]:
# Snapshot of the surface-point coordinates after adding the borehole points.
# model_test reads its prior means from rows 2 and 5 (z column) of this copy.
sp_coords_copy_test = geo_model_test.interpolation_input.surface_points.sp_coords.copy()
geo_model_test.transform.apply_inverse(sp_coords_copy_test)

In [None]:
# Change the backend to PyTorch for probabilistic modeling
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
# Set random seed for PyTorch backend
torch.manual_seed(42)

In [None]:
# Interpolation settings used by later recomputations (e.g. inside model_test).
# No compute runs in between, so the solutions read below still come from the
# gp.compute_model call above.
geo_model_test.interpolation_options.uni_degree = 0
geo_model_test.interpolation_options.mesh_extraction = False
geo_model_test.interpolation_options.sigmoid_slope = 1100.

In [None]:
#geo_model_test.solutions.octrees_output[0].last_output_center

In [None]:
# Values at the custom-grid (borehole) points from the last computed solution.
custom_grid_values = geo_model_test.solutions.octrees_output[0].last_output_center.custom_grid_values
custom_grid_values.shape

In [None]:
# Shapes of the observation arrays; their row count is expected to match
# the number of custom-grid values.
y_obs_label.shape, y_obs_label2.shape

In [None]:
geo_model_test.surface_points

In [None]:
geo_model_test.orientations

In [None]:
geo_model_test.transform.apply_inverse(sp_coords_copy_test)

In [None]:

# Set Pyro's random seed so the sampling in the following cells is reproducible.
pyro.set_rng_seed(42)

@config_enumerate
def model_test(y_obs_label):
    """Pyro model linking the GemPy layer geometry to the observed borehole data.

    Two latent interface elevations (``mu_1`` and ``mu_2``) are sampled around the
    z-values in rows 2 and 5 of ``sp_coords_copy_test``. They are written into the
    GemPy interpolation input, and the geological model is recomputed. The model
    value at every custom-grid point is turned into soft class probabilities over
    the labels 1, 2 and 3. Each observation is then modelled as a draw from the
    multivariate normal of its (enumerated) class. That class mean is itself a latent
    variable whose prior is the GMM fit (``loc_mean`` / ``loc_cov``).

    Args:
        y_obs_label: observed feature matrix with one row per observation. Rows are
            expected to line up with the custom-grid points. In this notebook it is
            ``torch.tensor(X)``: scaled depth and normalized spectral value.

    Returns:
        The value of the ``obs`` site (``y_obs_label`` itself when conditioned).

    Side effects:
        Assigns a new ``sp_coords`` tensor on the ``interpolation_input`` obtained
        from ``geo_model_test``. Overwrites ``geo_model_test.solutions``.
    """
    # Prior means of the two interface elevations; presumably in GemPy's
    # transformed coordinates (the notebook uses transform.apply_inverse to view them).
    prior_mean_surface_1 = sp_coords_copy_test[2, 2]
    prior_mean_surface_2 = sp_coords_copy_test[5, 2]
    prior_std = torch.tensor(0.02, dtype=torch.float64)

    mu_surface_1 = pyro.sample('mu_1', dist.Normal(prior_mean_surface_1, prior_std))
    mu_surface_2 = pyro.sample('mu_2', dist.Normal(prior_mean_surface_2, prior_std))

    # Hard ordering constraint: the Delta log-probability is 0 when
    # mu_surface_1 > mu_surface_2 and -inf otherwise. Cast the boolean comparison
    # to float64 so the observed value has the same dtype as the Delta's point mass.
    pyro.sample(
        'condition',
        dist.Delta(torch.tensor(1.0, dtype=torch.float64)),
        obs=(mu_surface_1 > mu_surface_2).to(torch.float64),
    )

    # Write the sampled elevations into the z column of surface points 2 and 5.
    interpolation_input = geo_model_test.interpolation_input
    for row, elevation in ((2, mu_surface_1), (5, mu_surface_2)):
        interpolation_input.surface_points.sp_coords = torch.index_put(
            interpolation_input.surface_points.sp_coords,
            (torch.tensor([row]), torch.tensor([2])),
            elevation,
        )

    # Recompute the geological model with the updated surface points.
    geo_model_test.solutions = gempy_engine.compute_model(
        interpolation_input=interpolation_input,
        options=geo_model_test.interpolation_options,
        data_descriptor=geo_model_test.input_data_descriptor,
        geophysics_input=geo_model_test.geophysics_input,
    )

    # Model values at the custom-grid (borehole) points.
    custom_grid_values = geo_model_test.solutions.octrees_output[0].last_output_center.custom_grid_values

    # Soft class assignment: the weight of label k decays with the squared distance
    # between the model value and k; lambda_ controls how sharp the assignment is.
    lambda_ = 4
    class_values = torch.tensor([1, 2, 3], dtype=torch.float64)
    class_label = F.softmax(-lambda_ * (class_values - custom_grid_values.reshape(-1, 1)) ** 2, dim=1)

    # One latent mean per mixture component, with the GMM fit as its prior.
    component_means = torch.stack([
        pyro.sample(
            f"sample_data{i + 1}",
            dist.MultivariateNormal(loc=loc_mean[i], covariance_matrix=loc_cov[i]),
        )
        for i in range(loc_mean.shape[0])
    ], dim=0)

    n_obs = y_obs_label.shape[0]
    with pyro.plate(f'N={n_obs}', n_obs):
        # Discrete class per observation; enumerated thanks to @config_enumerate.
        assignment = pyro.sample("assignment", dist.Categorical(class_label))

        obs = pyro.sample(
            "obs",
            dist.MultivariateNormal(loc=component_means[assignment], covariance_matrix=loc_cov[assignment]),
            obs=y_obs_label,
        )

    return obs
    
# Trace the model once with the feature matrix to inspect its dependency structure,
# then render the graphical model including the distributions.
example_args = (torch.tensor(X),)
dependencies = get_dependencies(model_test, model_args=example_args)
pyro.render_model(model_test, model_args=example_args, render_distributions=True)


In [None]:
# Observed data for inference: the joint (scaled depth, spectral) feature matrix as a tensor.
y_obs_label =torch.tensor(X)

In [None]:

# Single forward run of the model as a sanity check; the returned observations are displayed.
model_test(y_obs_label)

In [None]:
# Draw 100 prior samples of the model's sites.
prior = Predictive(model_test, num_samples=100)(y_obs_label)

# Site to leave out of the ArviZ conversion (the ordering-constraint site).
avoid_key = 'condition'

prior = {site: samples for site, samples in prior.items() if site != avoid_key}

data = az.from_pyro(prior=prior)
az.plot_trace(data.prior)
plt.show()

In [None]:
# NUTS over the continuous sites. The discrete `assignment` site is presumably
# marginalized by enumeration, since model_test is decorated with @config_enumerate.
pyro.primitives.enable_validation(is_validate=True)
nuts_kernel = NUTS(
    model_test,
    step_size=0.0085,
    adapt_step_size=True,
    target_accept_prob=0.9,
    max_tree_depth=10,
    init_strategy=init_to_mean,
)
mcmc = MCMC(
    nuts_kernel,
    num_samples=200,
    warmup_steps=50,
    disable_validation=False,
)
mcmc.run(y_obs_label)

In [None]:
posterior_samples = mcmc.get_samples()
posterior_predictive = Predictive(model_test, posterior_samples)(y_obs_label)
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
az.plot_trace(data)
plt.show()

In [None]:
posterior_samples

In [None]:
# Print a summary of every posterior site. For the mixture-component means
# (sample_data1, sample_data2, ...) also keep the posterior mean vector and the
# sample covariance; they are used below to plot the posterior components.
component_stats = {}
for key, values in posterior_samples.items():
    print(key)
    if key.startswith("sample_data"):
        mean = values.mean(dim=0)
        cov = np.cov(values.detach().numpy(), rowvar=False)
        print("mean\n", mean)
        print("cov\n", cov)
        component_stats[key] = (mean.detach().numpy(), cov)
    elif key == "sigma_data":
        print("mean\n", values.mean(dim=0), "\nstd\n", values.std(dim=0))
    else:
        print("mean\n", values.mean(), "\nstd\n", values.std())

# Order components by their numeric suffix so index i always holds component i + 1,
# independent of the order in which the sites appear in posterior_samples.
ordered_keys = sorted(component_stats, key=lambda k: int(k[len("sample_data"):]))
loc_mean_posterior = [component_stats[k][0] for k in ordered_keys]
loc_cov_posterior = [component_stats[k][1] for k in ordered_keys]

In [None]:
# Create a grid of points
x, y = np.meshgrid(np.linspace(-0.9, -0.7, 100), np.linspace(-0.2, 1.2, 100))
pos = np.dstack((x, y))  # Combine x and y grids into a (100, 100, 2) array
for i in range(3):
    # Create a multivariate normal distribution
    rv = multivariate_normal(loc_mean_posterior[i], loc_cov_posterior[i])

    # Calculate PDF values for each point in the grid
    pdf_values = rv.pdf(pos)

    # Plot the Gaussian distribution using contour plot
    plt.figure(figsize=(8, 6))
    plt.contourf(x, y, pdf_values, cmap='viridis')
    plt.colorbar(label='Probability Density')
    plt.xlabel('spatial')
    plt.ylabel('spectral')
    plt.title('2D Gaussian Distribution')
    plt.scatter(loc_mean_posterior[i][0], loc_mean_posterior[i][1], color='red', label='Mean')
    plt.legend()
    plt.grid(True)
plt.show()

In [None]:
# Overlay the contour lines and means of all three posterior components in one figure.
plt.figure(figsize=(8, 6))
# Evaluation grid: x spans the scaled depth, y the spectral axis.
x, y = np.meshgrid(np.linspace(-0.9, -0.7, 1000), np.linspace(-2, 2, 1000))
pos = np.dstack((x, y))  # Combine x and y grids into a (1000, 1000, 2) array
for i in range(3):
    rv = multivariate_normal(loc_mean_posterior[i], loc_cov_posterior[i])
    pdf_values = rv.pdf(pos)
    plt.contour(x, y, pdf_values, extend='min', cmap='viridis')
    plt.scatter(loc_mean_posterior[i][0], loc_mean_posterior[i][1], label='Mean' + str(i + 1))

plt.xlabel('spatial')
plt.ylabel('spectral')
plt.title('2D Gaussian Distribution')

plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Display the posterior group of the ArviZ InferenceData.
data.posterior

In [None]:
# Compare prior and posterior samples of the interface elevation mu_1.
# These are parameter samples, not predictive draws, so they are labelled as such.
az.plot_density(
    data=[data.posterior, data.prior],
    shade=.9,
    var_names=['mu_1'],
    data_labels=["Posterior", "Prior"],
    colors=[default_red, default_blue],
)
plt.show()

In [None]:
# Compare prior and posterior samples of the interface elevation mu_2.
# These are parameter samples, not predictive draws, so they are labelled as such.
az.plot_density(
    data=[data.posterior, data.prior],
    shade=.9,
    var_names=['mu_2'],
    data_labels=["Posterior", "Prior"],
    colors=[default_red, default_blue],
)
plt.show()

In [None]:
# Update the model with the new top layer's location
interpolation_input = geo_model_test.interpolation_input
interpolation_input.surface_points.sp_coords = torch.index_put(
    interpolation_input.surface_points.sp_coords,
    (torch.tensor([2]), torch.tensor([2])),
    posterior_samples["mu_1"].mean()
)
interpolation_input.surface_points.sp_coords = torch.index_put(
    interpolation_input.surface_points.sp_coords,
    (torch.tensor([5]), torch.tensor([2])),
    posterior_samples["mu_2"].mean()
)

#print("interpolation_input",interpolation_input.surface_points.sp_coords)

# # Compute the geological model
geo_model_test.solutions = gempy_engine.compute_model(
    interpolation_input=interpolation_input,
    options=geo_model_test.interpolation_options,
    data_descriptor=geo_model_test.input_data_descriptor,
    geophysics_input=geo_model_test.geophysics_input,
)

gpv.plot_2d(geo_model_test, cell_number=5,legend='force')

In [None]:
# Current surface-point coordinates (written from the posterior means in the
# previous cell). Note: this binds the existing tensor; it does not copy it.
sp_coords_copy_test2 =interpolation_input.surface_points.sp_coords
sp_coords_copy_test2

In [None]:
# Map the coordinates back through the inverse transform for inspection.
sp_cord= geo_model_test.transform.apply_inverse(sp_coords_copy_test2.detach().numpy())
sp_cord

In [None]:
#gpv.plot_3d(geo_model_test)

In [None]:
# Values at the custom-grid (borehole) points from the most recent solution.
custom_grid_values = geo_model_test.solutions.octrees_output[0].last_output_center.custom_grid_values
custom_grid_values