Using Bayesian inference for learning synthesis-structure-property relationship via coregionalized piecewise function determination.
Examples for paper:
Code and examples are presented first for the N-D algorithm with one structure input and multiple functional-property inputs, where all functional properties are measured on the same set of materials (this set is not required to be the same as the set on which structure is measured).

Table of Contents:
* Libraries to Install
* Import Libraries
* N-dimensional Case:
     * .py file for ND functions.
     * 2D Edge Case Challenges
         * Set up challenge data
         * Multicore .py files & 1 core scripts
         * Visualize results
     * (Bi,Sm)(Sc,Fe)O3 Challenge
         * Set up challenge data
         * Multicore .py files & 1 core scripts
         * Visualize results
     * Compute performance measures available in paper Table 1.

### Libraries to Install: ** please see requirements.txt and SAGEn_241025a.yml **

### Import Libraries

In [None]:
import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch

import dill
from torch.distributions import constraints

from torch.nn import Parameter
from torch.nn import Softmax
from torch.nn.functional import one_hot

import pyro
from pyro.infer import MCMC, NUTS, HMC, Predictive, SVI, Trace_ELBO
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.infer.autoguide import initialization as init

from scipy.spatial.distance import pdist, squareform
from scipy.spatial import Voronoi
from scipy.stats import multivariate_normal, entropy
import scipy.io as sio
from scipy.special import softmax as softnp
from scipy.stats.mstats import mquantiles
from scipy.interpolate import griddata
from scipy.stats import gamma, gennorm

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import Predictive as nPredictive
import numpyro.distributions as ndist

import tensorflow_probability as tfp

from sklearn.metrics import precision_recall_fscore_support as prfs
from sklearn.metrics.cluster import fowlkes_mallows_score as fmi
from sklearn.metrics import fowlkes_mallows_score as fms

torch.set_default_dtype(torch.float64)
from tqdm import trange

from applied_active_learning_191228a import *
from cameo_240821a import *

## N-dimensional Cases

### ND Functions: Create .py file

In [None]:
%%writefile sage_2D_functions_230804a.py

import numpyro
import numpy as np
from numpy.random import default_rng

import torch

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot
import gpjax as gpx

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import Predictive as nPredictive
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import tqdm
numpyro.set_host_device_count(100)

# ND ------------  
def model_SAGE_ND_230628a(xs, ys, xf, yf, num_regions,
                          gpc_var_bounds = jnp.asarray([0.1,10.]),
                          gpc_ls_bounds = jnp.asarray([.5,10.]),
                          gpr_var_bounds = jnp.asarray([0.1, 5.]),
                          gpr_ls_bounds = jnp.asarray([.01,5.]),
                          gpr_bias_bounds = jnp.asarray([-2.,2.]),
                          gpr_noise_bounds = jnp.asarray([0.0001,.1]),
                          differential_entropy = False):
    """NumPyro model: joint region segmentation + per-region property regression.

    Structure labels ``ys`` (observed at ``xs``) and functional-property values
    ``yf`` (observed at ``xf``) share one set of region-membership probabilities.
    The property prediction at each ``xf`` row is the probability-weighted sum
    of ``num_regions`` region-specific latent functions.

    Args:
        xs: structure-measurement locations, stacked on top of ``xf``.
        ys: region labels for ``xs``; flattened and scored with a Categorical.
        xf: functional-property measurement locations.
        yf: property values, indexed as ``yf[:, j]`` for property ``j``. All
            properties are assumed to be measured at the same ``xf`` rows.
        num_regions: number of regions (segments); must be a Python int.
        gpc_var_bounds, gpc_ls_bounds: (low, high) of the Uniform priors on the
            segmentation variance and length scale.
        gpr_var_bounds, gpr_ls_bounds, gpr_bias_bounds, gpr_noise_bounds:
            (low, high) of the Uniform priors on the regression
            hyper-parameters.
        differential_entropy: accepted but not used in this function.

    Sample sites: ``gpc_var``, ``gpc_lengthscale``, ``gpc_bias``, ``gpr_noise``,
    ``gpr_var``, ``gpr_lengthscale_x``, ``gpr_lengthscale_y``, ``gpr_bias``,
    ``gpc_latent_<i>``, ``gpr_latent_<i>_Mf_<j>``; deterministic site ``llk``;
    factor site ``obs``.

    NOTE(review): ``compute_f_matern52_jax``, ``compute_f_jax`` and
    ``logits_to_probs_jax`` are defined elsewhere in this file; their exact
    semantics are not visible here.
    """
    n_struct = ys.shape[0]
    n_func = yf.shape[0]
    n_props = yf.shape[1]
    n_all = xs.shape[0] + xf.shape[0]
    x_all = jnp.vstack((xs, xf))

    # --- Hyper-priors of the segmentation latent functions -----------------
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1]))
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1]))
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1))

    # --- Hyper-priors of the regression functions --------------------------
    # One entry per (region, property) pair; the noise level is a single scalar.
    ones_rp = jnp.ones((num_regions, n_props))
    length_prior = ndist.Uniform(gpr_ls_bounds[0]*ones_rp, gpr_ls_bounds[1]*ones_rp)

    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bounds[0]*ones_rp, gpr_var_bounds[1]*ones_rp))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", length_prior)
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", length_prior)
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bounds[0]*ones_rp, gpr_bias_bounds[1]*ones_rp))

    # --- One latent segmentation function per region, over xs and xf -------
    region_logits = jnp.zeros((n_struct + n_func, num_regions))
    for r in range(num_regions):
        with numpyro.plate('gpc_latent_response' + str(r), n_all):
            white_c = numpyro.sample('gpc_latent_' + str(r), ndist.Normal(0, 1))
        logit_r = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, white_c, x_all)
        region_logits = region_logits.at[:, r].set(logit_r)

    probs = logits_to_probs_jax(region_logits)
    probs_fp = probs[n_struct:, :]  # rows belonging to the xf locations

    # --- One latent regression function per (region, property) -------------
    region_funcs = jnp.zeros((n_func, num_regions, n_props))
    for p in range(n_props):
        for r in range(num_regions):
            with numpyro.plate('gpr_latent_response' + str(r), n_func):
                white_r = numpyro.sample('gpr_latent_'+str(r)+'_Mf_'+str(p), ndist.Normal(0, 1))
            ls_pair = jnp.array([gpr_lengthscale_x[r, p], gpr_lengthscale_y[r, p]])
            values = compute_f_jax(gpr_var[r, p], ls_pair, gpr_bias[r, p], white_r, xf)
            region_funcs = region_funcs.at[:, r, p].set(values)

    # --- Blend the region functions with the membership probabilities ------
    f_piecewise = jnp.zeros((n_func, n_props))
    for p in range(n_props):
        blended = f_piecewise[:, p]
        for r in range(num_regions):
            blended = blended + probs_fp[:, r] * region_funcs[:, r, p]
        f_piecewise = f_piecewise.at[:, p].set(blended)

    # --- Log-likelihood: structure labels first, then each property --------
    llk = ndist.Categorical(probs=probs[:n_struct, :]).log_prob(ys.flatten()).sum()
    noise_scale = jnp.sqrt(gpr_noise)  # gpr_noise enters as a variance
    for p in range(n_props):
        llk = llk + ndist.Normal(f_piecewise[:, p], noise_scale).log_prob(yf[:, p]).sum()

    numpyro.deterministic("llk", llk)
    numpyro.factor("obs", llk)

def model_SAGE_ND_FP_230628a(xf, yf, num_regions,
                             gpc_var_bounds = jnp.asarray([0.1,10.]),
                             gpc_ls_bounds = jnp.asarray([.5,10.]),
                             gpr_var_bounds = jnp.asarray([0.1, 5.]),
                             gpr_ls_bounds = jnp.asarray([.01,5.]),
                             gpr_bias_bounds = jnp.asarray([-2.,2.]),
                             gpr_noise_bounds = jnp.asarray([0.0001,.1]),
                             differential_entropy = False):
    """NumPyro model using functional-property data only (no structure labels).

    Region-membership probabilities are inferred at the ``xf`` locations from
    the property data alone. The property prediction is the
    probability-weighted sum of ``num_regions`` region-specific latent
    functions, and only the Normal likelihood of ``yf`` contributes to the
    log-likelihood.

    Args:
        xf: functional-property measurement locations.
        yf: property values, indexed as ``yf[:, j]`` for property ``j``. All
            properties are assumed to be measured at the same ``xf`` rows.
        num_regions: number of regions (segments); must be a Python int.
        gpc_var_bounds, gpc_ls_bounds: (low, high) of the Uniform priors on the
            segmentation variance and length scale.
        gpr_var_bounds, gpr_ls_bounds, gpr_bias_bounds, gpr_noise_bounds:
            (low, high) of the Uniform priors on the regression
            hyper-parameters.
        differential_entropy: accepted but not used in this function.

    Sample sites: ``gpc_var``, ``gpc_lengthscale``, ``gpc_bias``, ``gpr_noise``,
    ``gpr_var``, ``gpr_lengthscale_x``, ``gpr_lengthscale_y``, ``gpr_bias``,
    ``gpc_latent_<i>``, ``gpr_latent_<i>_Mf_<j>``; deterministic site ``llk``;
    factor site ``obs``.

    NOTE(review): ``compute_f_matern52_jax``, ``compute_f_jax`` and
    ``logits_to_probs_jax`` are defined elsewhere in this file; their exact
    semantics are not visible here.
    """
    n_func = yf.shape[0]
    n_props = yf.shape[1]

    # --- Hyper-priors of the segmentation latent functions -----------------
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1]))
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1]))
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1))

    # --- Hyper-priors of the regression functions --------------------------
    # One entry per (region, property) pair; the noise level is a single scalar.
    ones_rp = jnp.ones((num_regions, n_props))
    length_prior = ndist.Uniform(gpr_ls_bounds[0]*ones_rp, gpr_ls_bounds[1]*ones_rp)

    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bounds[0]*ones_rp, gpr_var_bounds[1]*ones_rp))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", length_prior)
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", length_prior)
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bounds[0]*ones_rp, gpr_bias_bounds[1]*ones_rp))

    # --- One latent segmentation function per region, over xf --------------
    region_logits = jnp.zeros((n_func, num_regions))
    for r in range(num_regions):
        with numpyro.plate('gpc_latent_response' + str(r), n_func):
            white_c = numpyro.sample('gpc_latent_' + str(r), ndist.Normal(0, 1))
        logit_r = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, white_c, xf)
        region_logits = region_logits.at[:, r].set(logit_r)

    probs_fp = logits_to_probs_jax(region_logits)

    # --- One latent regression function per (region, property) -------------
    region_funcs = jnp.zeros((n_func, num_regions, n_props))
    for p in range(n_props):
        for r in range(num_regions):
            with numpyro.plate('gpr_latent_response' + str(r), n_func):
                white_r = numpyro.sample('gpr_latent_'+str(r)+'_Mf_'+str(p), ndist.Normal(0, 1))
            ls_pair = jnp.array([gpr_lengthscale_x[r, p], gpr_lengthscale_y[r, p]])
            values = compute_f_jax(gpr_var[r, p], ls_pair, gpr_bias[r, p], white_r, xf)
            region_funcs = region_funcs.at[:, r, p].set(values)

    # --- Blend the region functions with the membership probabilities ------
    f_piecewise = jnp.zeros((n_func, n_props))
    for p in range(n_props):
        blended = f_piecewise[:, p]
        for r in range(num_regions):
            blended = blended + probs_fp[:, r] * region_funcs[:, r, p]
        f_piecewise = f_piecewise.at[:, p].set(blended)

    # --- Log-likelihood: property observations only ------------------------
    llk = 0.
    noise_scale = jnp.sqrt(gpr_noise)  # gpr_noise enters as a variance
    for p in range(n_props):
        llk = llk + ndist.Normal(f_piecewise[:, p], noise_scale).log_prob(yf[:, p]).sum()

    numpyro.deterministic("llk", llk)
    numpyro.factor("obs", llk)

def model_SAGE_ND_PM_230628a(xs, ys, xf, num_regions,
                             gpc_var_bounds = jnp.asarray([0.1,10.]),
                             gpc_ls_bounds = jnp.asarray([.5,10.])):
    """NumPyro model using structure (region-label) data only.

    Only the Categorical likelihood of ``ys`` at ``xs`` contributes to the
    log-likelihood. ``xf`` is used solely for its row count: each
    ``gpc_latent_<i>`` site is sampled with length ``len(xs) + len(xf)``, but
    only the first ``len(xs)`` entries are passed to the latent-function helper.

    Args:
        xs: structure-measurement locations.
        ys: region labels for ``xs``; flattened and scored with a Categorical.
        xf: functional-property locations (only ``xf.shape[0]`` is read).
        num_regions: number of regions (segments); must be a Python int.
        gpc_var_bounds, gpc_ls_bounds: (low, high) of the Uniform priors on the
            segmentation variance and length scale.

    Sample sites: ``gpc_var``, ``gpc_lengthscale``, ``gpc_bias``,
    ``gpc_latent_<i>``; deterministic site ``llk``; factor site ``obs``.

    NOTE(review): ``compute_f_matern52_jax`` and ``logits_to_probs_jax`` are
    defined elsewhere in this file; their exact semantics are not visible here.
    """
    n_struct = xs.shape[0]
    n_all = n_struct + xf.shape[0]

    # --- Hyper-priors of the segmentation latent functions -----------------
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1]))
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1]))
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1))

    # --- One latent segmentation function per region, evaluated at xs ------
    region_logits = jnp.zeros((n_struct, num_regions))
    for r in range(num_regions):
        with numpyro.plate('gpc_latent_response' + str(r), n_all):
            white_c = numpyro.sample('gpc_latent_' + str(r), ndist.Normal(0, 1))
        # Only the leading len(xs) latent entries drive the structure likelihood.
        logit_r = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, white_c[:n_struct], xs)
        region_logits = region_logits.at[:, r].set(logit_r)

    probs = logits_to_probs_jax(region_logits)

    # --- Log-likelihood: structure labels only -----------------------------
    llk = ndist.Categorical(probs=probs[:n_struct, :]).log_prob(ys.flatten()).sum()

    numpyro.deterministic("llk", llk)
    numpyro.factor("obs", llk)
      
def predict_SAGE_ND_230628a(Xnew, xs, ys, xf, yf, num_regions, eps=1E-6, gpc_var_bounds=jnp.asarray([0.1,10.]), gpc_ls_bounds=jnp.asarray([.5,10.]), \
        gpr_var_bounds=jnp.asarray([0.1, 5.]), gpr_ls_bounds=jnp.asarray([.01,5.]), gpr_bias_bounds=jnp.asarray([-2.,2.]), \
        gpr_noise_bounds = jnp.asarray([0.0001,.1])):
    """Predictive counterpart of ``model_SAGE_ND_230628a`` at new locations.

    Re-declares the sample sites of the joint model (same site names) and, from
    their values, predicts region probabilities and piecewise property values
    at ``Xnew``. The results are exposed through ``Delta`` sample sites.

    NOTE(review): the latent sites are declared here as scalar ``Normal(0, 1)``
    draws but are later used as per-location vectors; presumably this function
    is meant to be run with posterior samples substituted for those sites
    (e.g. via ``numpyro.infer.Predictive``) -- confirm against the caller.

    Args:
        Xnew: locations at which to predict.
        xs: structure-measurement locations (stacked on top of ``xf``).
        ys: structure labels; not read by this function.
        xf: functional-property measurement locations.
        yf: property values; only ``yf.shape[1]`` (number of properties) is read.
        num_regions: number of regions; must be a Python int.
        eps: diagonal jitter added to the predictive covariance before sampling.
        gpc_var_bounds, gpc_ls_bounds, gpr_var_bounds, gpr_ls_bounds,
        gpr_bias_bounds, gpr_noise_bounds: (low, high) of the Uniform priors,
            matching the training model.

    Returns:
        Tuple ``(gpc_new_probs, f_piecewise, f_sample, Fr_new, v_piecewise)``:
        region probabilities at ``Xnew``; probability-weighted predictive mean
        of shape ``(Nnew, Mf, 1)``; a noisy draw around that mean, same shape;
        per-region predictive means of shape ``(Nnew, num_regions, Mf)``;
        probability-weighted sum of the per-region predictive variances, shape
        ``(Nnew, Mf, 1)``.
    """
    # assumes all function property measurements measured at same locations.
    Mf = yf.shape[1]
    Nnew = Xnew.shape[0]
    x_ = jnp.vstack((xs,xf))

    # One independent PRNG key per random draw. The original code reused a
    # single subkey for every region and every property, which made all of
    # those draws share the same underlying random stream.
    # NOTE(review): the base seed is still fixed, so the noise is identical on
    # every call; switch to a handler-provided key if per-call randomness is
    # wanted.
    draw_keys = jax.random.split(jax.random.PRNGKey(0), num_regions + Mf)
    region_keys = draw_keys[:num_regions]
    property_keys = draw_keys[num_regions:]

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Priors: GPR, one entry per (region, property) pair.
    gpr_var_bound_min = gpr_var_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_var_bound_max = gpr_var_bounds[1]*jnp.ones((num_regions,Mf))
    gpr_lengthscale_bound_min = gpr_ls_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_lengthscale_bound_max = gpr_ls_bounds[1]*jnp.ones((num_regions,Mf))
    gpr_bias_bound_min = gpr_bias_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_bias_bound_max = gpr_bias_bounds[1]*jnp.ones((num_regions,Mf))

    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bound_min, gpr_var_bound_max))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", ndist.Uniform(gpr_lengthscale_bound_min, gpr_lengthscale_bound_max))
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", ndist.Uniform(gpr_lengthscale_bound_min, gpr_lengthscale_bound_max))
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bound_min, gpr_bias_bound_max))

    # Segmentation latents, one site per region (same names as in training).
    gpc_latent = [
        numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1))
        for i in range(num_regions)
    ]

    # Region logits at Xnew: condition on the training latent function, then
    # draw from the resulting multivariate normal.
    gpc_noise = 1E-6
    gpc_new_latent = jnp.zeros((Nnew,num_regions))
    for j in range(num_regions):
        f = compute_f_matern52_jax(gpc_var,
                  gpc_lengthscale,
                  gpc_bias,
                  gpc_latent[j], x_)
        mean, cov, _ = gpr_forward_matern52_jax(gpc_var, gpc_lengthscale ,x_,f, Xnew, gpc_noise, include_noise=False)
        # eps on the diagonal keeps the covariance numerically positive definite.
        fhat = ndist.MultivariateNormal(mean, cov + jnp.eye(Nnew) * eps).sample(region_keys[j])
        gpc_new_latent = gpc_new_latent.at[:,j].set(fhat)
    gpc_new_probs =  logits_to_probs_jax(gpc_new_latent)

    # Diagnostics: NaN counts in the last region's training latent and in the
    # predicted probabilities.
    temp_f = jnp.sum(jnp.isnan(f.flatten()))
    jax.debug.print("Pred, NaN f: {t}", t = temp_f)
    temp_prob = jnp.sum(jnp.isnan(gpc_new_probs.flatten()))
    jax.debug.print("Pred, NaN prob: {t}", t = temp_prob)

    # Regression latents, indexed [region][property]; declared property-major
    # to keep the same site order as before.
    gpr_latent = [ [0]*Mf for i in range(num_regions)]
    for j in range(Mf):
        for i in range(num_regions):
            gpr_latent[i][j] = numpyro.sample('gpr_latent_'+str(i)+'_Mf_'+str(j), ndist.Normal(0, 1))

    # Per-region predictive mean and variance of every property at Xnew.
    Fr_new = jnp.zeros((Nnew,num_regions,Mf))
    Vr_new = jnp.zeros((Nnew,num_regions,Mf))
    for k in range(Mf):
        for j in range(num_regions):
            eta = gpr_latent[j][k]
            gpr_lengthscale_array = jnp.array([gpr_lengthscale_x[j,k], gpr_lengthscale_y[j,k]])
            f = compute_f_jax(gpr_var[j,k],
                                gpr_lengthscale_array,
                                gpr_bias[j,k], eta, xf)
            mean, _, var = gpr_forward_jax(gpr_var[j,k],
                                         gpr_lengthscale_array,
                                         xf,f, Xnew, gpr_noise, include_noise=False)
            Fr_new = Fr_new.at[:,j,k].set(mean)
            Vr_new = Vr_new.at[:,j,k].set(var)

    # Blend the region-wise predictions with the region probabilities.
    f_piecewise = jnp.zeros((Nnew, Mf, 1))
    v_piecewise = jnp.zeros((Nnew, Mf, 1))
    f_sample = jnp.zeros((Nnew, Mf, 1))
    for k in range(Mf):
        for j in range(num_regions):
            f_piecewise = f_piecewise.at[:,k,0].set( f_piecewise[:,k,0] + gpc_new_probs[:,j] * Fr_new[:,j,k] )
            v_piecewise = v_piecewise.at[:,k,0].set( v_piecewise[:,k,0] + gpc_new_probs[:,j] * Vr_new[:,j,k] )
        # gpr_noise enters as a variance, hence the square root for the scale.
        f_sample = f_sample.at[:,k,0].set( ndist.Normal(f_piecewise[:,k,0], jnp.sqrt( gpr_noise ) ).sample(property_keys[k]) )

    # Expose the derived quantities as named sites so they can be collected.
    gpc_new_probs_ = numpyro.sample('gpc_new_probs', ndist.Delta(gpc_new_probs))
    f_piecewise_ = numpyro.sample('f_piecewise', ndist.Delta(f_piecewise))
    f_sample_ = numpyro.sample('f_sample', ndist.Delta(f_sample))
    Fr_new_ = numpyro.sample('Fr_new', ndist.Delta(Fr_new))
    v_piecewise_ = numpyro.sample('v_piecewise', ndist.Delta(v_piecewise))

    return gpc_new_probs_, f_piecewise_, f_sample_, Fr_new_, v_piecewise_

def predict_SAGE_ND_240712a(Xnew, xs, ys, xf, yf, num_regions, eps=1E-6, gpc_var_bounds=jnp.asarray([0.1,10.]), gpc_ls_bounds=jnp.asarray([.5,10.]), \
        gpr_var_bounds=jnp.asarray([0.1, 5.]), gpr_ls_bounds=jnp.asarray([.01,5.]), gpr_bias_bounds=jnp.asarray([-2.,2.]), \
        gpr_noise_bounds = jnp.asarray([0.0001,.1]), idx_Xnew_exclude_xs=None, idx_Xnew_match_xs=None, idx_xs_match_Xnew=None, idx_xf_exclude_xs=None):
    """Predictive counterpart of the joint model that pins labels at measured points.

    Same outputs as ``predict_SAGE_ND_230628a``, with two differences in how the
    region probabilities are produced:

    * the segmentation latent is conditioned on all of ``xs`` plus the ``xf``
      rows selected by ``idx_xf_exclude_xs``, and is predicted only at
      ``Xnew[idx_Xnew_exclude_xs]``;
    * rows ``idx_Xnew_match_xs`` of the region probabilities are overwritten
      with the one-hot encoding of ``ys[idx_xs_match_Xnew]``.

    The property regression is predicted at every row of ``Xnew``.

    NOTE(review): the latent sites are declared here as scalar ``Normal(0, 1)``
    draws but are indexed as vectors; presumably this function is meant to be
    run with posterior samples substituted for those sites (e.g. via
    ``numpyro.infer.Predictive``) -- confirm against the caller.

    Args:
        Xnew: all locations at which to predict.
        xs: structure-measurement locations (stacked on top of ``xf``).
        ys: structure labels; ``ys.shape[0]`` sets the number of structure rows
            and ``ys[idx_xs_match_Xnew]`` supplies the pinned labels.
        xf: functional-property measurement locations.
        yf: property values; only ``yf.shape[1]`` (number of properties) is read.
        num_regions: number of regions; must be a Python int.
        eps: diagonal jitter added to the predictive covariance before sampling.
        gpc_var_bounds, gpc_ls_bounds, gpr_var_bounds, gpr_ls_bounds,
        gpr_bias_bounds, gpr_noise_bounds: (low, high) of the Uniform priors,
            matching the training model.
        idx_Xnew_exclude_xs: row indices of ``Xnew`` at which the segmentation
            latent is predicted.
        idx_Xnew_match_xs: row indices of ``Xnew`` whose probabilities are
            replaced by one-hot labels.
        idx_xs_match_Xnew: indices into ``ys`` giving those labels, in the same
            order as ``idx_Xnew_match_xs``.
        idx_xf_exclude_xs: row indices of ``xf`` kept in the segmentation
            training set.
        The four index arguments default to ``None`` but must be supplied by
        the caller; the function does not work with the defaults.

    Returns:
        Tuple ``(gpc_new_probs, f_piecewise, f_sample, Fr_new, v_piecewise)``:
        region probabilities of shape ``(N_Xnew, num_regions)``;
        probability-weighted predictive mean of shape ``(N_Xnew, Mf, 1)``; a
        noisy draw around that mean, same shape; per-region predictive means of
        shape ``(N_Xnew, num_regions, Mf)``; probability-weighted sum of the
        per-region predictive variances, shape ``(N_Xnew, Mf, 1)``.
    """
    # assumes all function property measurements measured at same locations.
    Ns = ys.shape[0]
    Mf = yf.shape[1]
    Xnew_no_xs = Xnew[idx_Xnew_exclude_xs,:]
    N_Xnew_no_xs = Xnew_no_xs.shape[0] # number of prediction points excluding xs
    N_Xnew = Xnew.shape[0] # number of all prediction points.
    x_ = jnp.vstack((xs,xf),dtype=jnp.float64)
    # Training rows for the segmentation latent: all of xs, then the selected
    # xf rows (offset by Ns because xf sits below xs in x_).
    idx_x_exclude_overlap_with_xs = jnp.concatenate( (jnp.arange(Ns), jnp.array(idx_xf_exclude_xs) + Ns) )
    x_train = x_[idx_x_exclude_overlap_with_xs,:]

    # One independent PRNG key per random draw. The original code reused a
    # single subkey for every region and every property, which made all of
    # those draws share the same underlying random stream.
    # NOTE(review): the base seed is still fixed, so the noise is identical on
    # every call; switch to a handler-provided key if per-call randomness is
    # wanted.
    draw_keys = jax.random.split(jax.random.PRNGKey(0), num_regions + Mf)
    region_keys = draw_keys[:num_regions]
    property_keys = draw_keys[num_regions:]

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Priors: GPR, one entry per (region, property) pair.
    gpr_var_bound_min = gpr_var_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_var_bound_max = gpr_var_bounds[1]*jnp.ones((num_regions,Mf))
    gpr_lengthscale_bound_min = gpr_ls_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_lengthscale_bound_max = gpr_ls_bounds[1]*jnp.ones((num_regions,Mf))
    gpr_bias_bound_min = gpr_bias_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_bias_bound_max = gpr_bias_bounds[1]*jnp.ones((num_regions,Mf))

    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bound_min, gpr_var_bound_max))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", ndist.Uniform(gpr_lengthscale_bound_min, gpr_lengthscale_bound_max))
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", ndist.Uniform(gpr_lengthscale_bound_min, gpr_lengthscale_bound_max))
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bound_min, gpr_bias_bound_max))

    # Segmentation latents, one site per region (same names as in training).
    gpc_latent = [
        numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1))
        for i in range(num_regions)
    ]

    # Region logits at the non-measured prediction points: condition on the
    # training latent function, then draw from the resulting multivariate normal.
    gpc_noise = 1E-6
    gpc_new_latent = jnp.zeros((N_Xnew_no_xs,num_regions),dtype=jnp.float64) # Num of predict points excluding xs
    for j in range(num_regions):
        f = compute_f_matern52_jax(gpc_var,
                  gpc_lengthscale,
                  gpc_bias,
                  gpc_latent[j][idx_x_exclude_overlap_with_xs], x_train)
        mean, cov, _ = gpr_forward_matern52_jax(gpc_var, gpc_lengthscale ,x_train,
                                                f, Xnew_no_xs, gpc_noise, include_noise=False)
        # eps on the diagonal keeps the covariance numerically positive definite.
        fhat = ndist.MultivariateNormal(mean, cov + jnp.eye(N_Xnew_no_xs) * eps).sample(region_keys[j])
        gpc_new_latent = gpc_new_latent.at[:,j].set(fhat)

    # Assemble probabilities for all of Xnew: predicted where unmeasured,
    # one-hot of the known label where Xnew coincides with a structure point.
    gpc_new_probs = jnp.zeros((N_Xnew,num_regions),dtype=jnp.float64) # Num of all predict points
    gpc_new_probs = gpc_new_probs.at[idx_Xnew_exclude_xs,:].set( logits_to_probs_jax(gpc_new_latent) )
    gpc_new_probs = gpc_new_probs.at[idx_Xnew_match_xs,:].set( jax_one_hot(ys[idx_xs_match_Xnew], num_regions) )

    # Regression latents, indexed [region][property]; declared property-major
    # to keep the same site order as before.
    gpr_latent = [ [0]*Mf for i in range(num_regions)]
    for j in range(Mf):
        for i in range(num_regions):
            gpr_latent[i][j] = numpyro.sample('gpr_latent_'+str(i)+'_Mf_'+str(j), ndist.Normal(0, 1))

    # Per-region predictive mean and variance of every property at Xnew.
    Fr_new = jnp.zeros((N_Xnew,num_regions,Mf),dtype=jnp.float64)
    Vr_new = jnp.zeros((N_Xnew,num_regions,Mf),dtype=jnp.float64)
    for k in range(Mf):
        for j in range(num_regions):
            eta = gpr_latent[j][k]
            gpr_lengthscale_array = jnp.array([gpr_lengthscale_x[j,k], gpr_lengthscale_y[j,k]])
            f = compute_f_jax(gpr_var[j,k],
                                gpr_lengthscale_array,
                                gpr_bias[j,k], eta, xf)
            mean, _, var = gpr_forward_jax(gpr_var[j,k],
                                         gpr_lengthscale_array,
                                         xf,f, Xnew, gpr_noise, include_noise=False)
            Fr_new = Fr_new.at[:,j,k].set(mean)
            Vr_new = Vr_new.at[:,j,k].set(var)

    # Blend the region-wise predictions with the region probabilities.
    f_piecewise = jnp.zeros((N_Xnew, Mf, 1),dtype=jnp.float64)
    v_piecewise = jnp.zeros((N_Xnew, Mf, 1),dtype=jnp.float64)
    f_sample = jnp.zeros((N_Xnew, Mf, 1),dtype=jnp.float64)
    for k in range(Mf):
        for j in range(num_regions):
            f_piecewise = f_piecewise.at[:,k,0].set( f_piecewise[:,k,0] + gpc_new_probs[:,j] * Fr_new[:,j,k] )
            v_piecewise = v_piecewise.at[:,k,0].set( v_piecewise[:,k,0] + gpc_new_probs[:,j] * Vr_new[:,j,k] )
        # gpr_noise enters as a variance, hence the square root for the scale.
        f_sample = f_sample.at[:,k,0].set( ndist.Normal(f_piecewise[:,k,0], jnp.sqrt( gpr_noise ) ).sample(property_keys[k]) )

    # Expose the derived quantities as named sites so they can be collected.
    gpc_new_probs_ = numpyro.sample('gpc_new_probs', ndist.Delta(gpc_new_probs))
    f_piecewise_ = numpyro.sample('f_piecewise', ndist.Delta(f_piecewise))
    f_sample_ = numpyro.sample('f_sample', ndist.Delta(f_sample))
    Fr_new_ = numpyro.sample('Fr_new', ndist.Delta(Fr_new))
    v_piecewise_ = numpyro.sample('v_piecewise', ndist.Delta(v_piecewise))

    return gpc_new_probs_, f_piecewise_, f_sample_, Fr_new_, v_piecewise_

def predict_SAGE_ND_FP_230628a(Xnew, xf, yf, num_regions, eps=1E-6, gpc_var_bounds=jnp.asarray([0.1,10.]), gpc_ls_bounds=jnp.asarray([.5,10.]), \
        gpr_var_bounds=jnp.asarray([0.1, 5.]), gpr_ls_bounds=jnp.asarray([.01,5.]), gpr_bias_bounds=jnp.asarray([-2.,2.]), \
        gpr_noise_bounds = jnp.asarray([0.0001,.1])):
    """Predictive counterpart of ``model_SAGE_ND_FP_230628a`` at new locations.

    Re-declares the sample sites of the functional-property-only model (same
    site names) and, from their values, predicts region probabilities and
    piecewise property values at ``Xnew``. Both the segmentation latent and the
    regression functions are conditioned on ``xf`` only. The results are
    exposed through ``Delta`` sample sites.

    NOTE(review): the latent sites are declared here as scalar ``Normal(0, 1)``
    draws but are later used as per-location vectors; presumably this function
    is meant to be run with posterior samples substituted for those sites
    (e.g. via ``numpyro.infer.Predictive``) -- confirm against the caller.

    Args:
        Xnew: locations at which to predict.
        xf: functional-property measurement locations.
        yf: property values; only ``yf.shape[1]`` (number of properties) is read.
        num_regions: number of regions; must be a Python int.
        eps: diagonal jitter added to the predictive covariance before sampling.
        gpc_var_bounds, gpc_ls_bounds, gpr_var_bounds, gpr_ls_bounds,
        gpr_bias_bounds, gpr_noise_bounds: (low, high) of the Uniform priors,
            matching the training model.

    Returns:
        Tuple ``(gpc_new_probs, f_piecewise, f_sample, Fr_new, v_piecewise)``:
        region probabilities at ``Xnew``; probability-weighted predictive mean
        of shape ``(Nnew, Mf, 1)``; a noisy draw around that mean, same shape;
        per-region predictive means of shape ``(Nnew, num_regions, Mf)``;
        probability-weighted sum of the per-region predictive variances, shape
        ``(Nnew, Mf, 1)``.
    """
    # assumes all function property measurements measured at same locations.
    Mf = yf.shape[1]
    Nnew = Xnew.shape[0]

    # One independent PRNG key per random draw. The original code reused a
    # single subkey for every region and every property, which made all of
    # those draws share the same underlying random stream.
    # NOTE(review): the base seed is still fixed, so the noise is identical on
    # every call; switch to a handler-provided key if per-call randomness is
    # wanted.
    draw_keys = jax.random.split(jax.random.PRNGKey(0), num_regions + Mf)
    region_keys = draw_keys[:num_regions]
    property_keys = draw_keys[num_regions:]

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Priors: GPR, one entry per (region, property) pair.
    gpr_var_bound_min = gpr_var_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_var_bound_max = gpr_var_bounds[1]*jnp.ones((num_regions,Mf))
    gpr_lengthscale_bound_min = gpr_ls_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_lengthscale_bound_max = gpr_ls_bounds[1]*jnp.ones((num_regions,Mf))
    gpr_bias_bound_min = gpr_bias_bounds[0]*jnp.ones((num_regions,Mf))
    gpr_bias_bound_max = gpr_bias_bounds[1]*jnp.ones((num_regions,Mf))

    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bound_min, gpr_var_bound_max))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", ndist.Uniform(gpr_lengthscale_bound_min, gpr_lengthscale_bound_max))
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", ndist.Uniform(gpr_lengthscale_bound_min, gpr_lengthscale_bound_max))
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bound_min, gpr_bias_bound_max))

    # Segmentation latents, one site per region (same names as in training).
    gpc_latent = [
        numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1))
        for i in range(num_regions)
    ]

    # Region logits at Xnew: condition on the training latent function, then
    # draw from the resulting multivariate normal.
    gpc_noise = 1E-6
    gpc_new_latent = jnp.zeros((Nnew,num_regions))
    for j in range(num_regions):
        f = compute_f_matern52_jax(gpc_var,
                  gpc_lengthscale,
                  gpc_bias,
                  gpc_latent[j], xf)
        mean, cov, _ = gpr_forward_matern52_jax(gpc_var, gpc_lengthscale ,xf,f, Xnew, gpc_noise, include_noise=False)
        # eps on the diagonal keeps the covariance numerically positive definite.
        fhat = ndist.MultivariateNormal(mean, cov + jnp.eye(Nnew) * eps).sample(region_keys[j])
        gpc_new_latent = gpc_new_latent.at[:,j].set(fhat)
    gpc_new_probs = logits_to_probs_jax(gpc_new_latent)

    # Regression latents, indexed [region][property]; declared property-major
    # to keep the same site order as before.
    gpr_latent = [ [0]*Mf for i in range(num_regions)]
    for j in range(Mf):
        for i in range(num_regions):
            gpr_latent[i][j] = numpyro.sample('gpr_latent_'+str(i)+'_Mf_'+str(j), ndist.Normal(0, 1))

    # Per-region predictive mean and variance of every property at Xnew.
    Fr_new = jnp.zeros((Nnew,num_regions,Mf))
    Vr_new = jnp.zeros((Nnew,num_regions,Mf))
    for k in range(Mf):
        for j in range(num_regions):
            eta = gpr_latent[j][k]
            gpr_lengthscale_array = jnp.array([gpr_lengthscale_x[j,k], gpr_lengthscale_y[j,k]])
            f = compute_f_jax(gpr_var[j,k],
                                gpr_lengthscale_array,
                                gpr_bias[j,k], eta, xf)
            mean, _, var = gpr_forward_jax(gpr_var[j,k],
                                         gpr_lengthscale_array,
                                         xf,f, Xnew, gpr_noise, include_noise=False)
            Fr_new = Fr_new.at[:,j,k].set(mean)
            Vr_new = Vr_new.at[:,j,k].set(var)

    # Blend the region-wise predictions with the region probabilities.
    f_piecewise = jnp.zeros((Nnew, Mf, 1))
    v_piecewise = jnp.zeros((Nnew, Mf, 1))
    f_sample = jnp.zeros((Nnew, Mf, 1))
    for k in range(Mf):
        for j in range(num_regions):
            f_piecewise = f_piecewise.at[:,k,0].set( f_piecewise[:,k,0] + gpc_new_probs[:,j] * Fr_new[:,j,k] )
            v_piecewise = v_piecewise.at[:,k,0].set( v_piecewise[:,k,0] + gpc_new_probs[:,j] * Vr_new[:,j,k] )
        # gpr_noise enters as a variance, hence the square root for the scale.
        f_sample = f_sample.at[:,k,0].set( ndist.Normal(f_piecewise[:,k,0], jnp.sqrt( gpr_noise ) ).sample(property_keys[k]) )

    # Expose the derived quantities as named sites so they can be collected.
    gpc_new_probs_ = numpyro.sample('gpc_new_probs', ndist.Delta(gpc_new_probs))
    f_piecewise_ = numpyro.sample('f_piecewise', ndist.Delta(f_piecewise))
    f_sample_ = numpyro.sample('f_sample', ndist.Delta(f_sample))
    Fr_new_ = numpyro.sample('Fr_new', ndist.Delta(Fr_new))
    v_piecewise_ = numpyro.sample('v_piecewise', ndist.Delta(v_piecewise))

    return gpc_new_probs_, f_piecewise_, f_sample_, Fr_new_, v_piecewise_
  
def predict_SAGE_ND_PM_230628a(Xnew, xs, ys, num_regions, eps=1E-6, gpc_var_bounds=jnp.asarray([0.1,10.]), gpc_ls_bounds=jnp.asarray([.5,10.])):
    """Predictive counterpart of ``model_SAGE_ND_PM_230628a`` at new locations.

    Re-declares the segmentation sample sites (same site names) and predicts
    region logits and probabilities at ``Xnew``, conditioning on ``xs`` only.
    Only the first ``len(ys)`` entries of each ``gpc_latent_<i>`` site are used,
    mirroring the training model. The results are exposed through ``Delta``
    sample sites.

    NOTE(review): the latent sites are declared here as scalar ``Normal(0, 1)``
    draws but are sliced as vectors; presumably this function is meant to be
    run with posterior samples substituted for those sites (e.g. via
    ``numpyro.infer.Predictive``) -- confirm against the caller.

    Args:
        Xnew: locations at which to predict.
        xs: structure-measurement locations.
        ys: structure labels; only ``ys.shape[0]`` is read.
        num_regions: number of regions; must be a Python int.
        eps: diagonal jitter added to the predictive covariance before sampling.
        gpc_var_bounds, gpc_ls_bounds: (low, high) of the Uniform priors on the
            segmentation variance and length scale.

    Returns:
        Tuple ``(gpc_new_probs, gpc_new_latent)``: region probabilities at
        ``Xnew`` and the sampled region logits they were computed from, the
        latter of shape ``(Nnew, num_regions)``.
    """
    Ns = ys.shape[0]
    Nnew = Xnew.shape[0]

    # One independent PRNG key per region. The original code reused a single
    # subkey for every region, which made all region draws share the same
    # underlying random stream.
    # NOTE(review): the base seed is still fixed, so the noise is identical on
    # every call; switch to a handler-provided key if per-call randomness is
    # wanted.
    region_keys = jax.random.split(jax.random.PRNGKey(0), num_regions)

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Segmentation latents, one site per region; keep only the entries that
    # belong to the structure locations.
    gpc_latent = [0]*num_regions
    for i in range(num_regions):
        temp = numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1))
        gpc_latent[i] = temp[:Ns]

    # Region logits at Xnew: condition on the training latent function, then
    # draw from the resulting multivariate normal.
    gpc_noise = 1E-6
    gpc_new_latent = jnp.zeros((Nnew,num_regions))
    for j in range(num_regions):
        f = compute_f_matern52_jax(gpc_var,
                  gpc_lengthscale,
                  gpc_bias,
                  gpc_latent[j], xs)
        mean, cov, _ = gpr_forward_matern52_jax(gpc_var, gpc_lengthscale ,xs,f, Xnew, gpc_noise, include_noise=False)
        # eps on the diagonal keeps the covariance numerically positive definite.
        fhat = ndist.MultivariateNormal(mean, cov + jnp.eye(Nnew) * eps).sample(region_keys[j])
        gpc_new_latent = gpc_new_latent.at[:,j].set(fhat)
    gpc_new_probs = logits_to_probs_jax(gpc_new_latent)

    # Expose the derived quantities as named sites so they can be collected.
    gpc_new_probs_ = numpyro.sample('gpc_new_probs', ndist.Delta(gpc_new_probs))
    gpc_new_latent_ = numpyro.sample('gpc_new_latent', ndist.Delta(gpc_new_latent))

    return gpc_new_probs_, gpc_new_latent_
  
# Coreg -----------
def model_SAGE_Coreg_ND_230628a(xs_, ys_, xf_, yf_, num_regions, gpc_var_bounds = jnp.asarray([0.1,10.]), gpc_ls_bounds = jnp.asarray([.5,10.]), \
                gpr_var_bounds = jnp.asarray([0.1, 5.]), gpr_ls_bounds = jnp.asarray([.01,5.]), gpr_bias_bounds = jnp.asarray([-2.,2.]), \
                gpr_noise_bounds = jnp.asarray([0.0001,.1])):
    """NumPyro model: shared segmentation plus per-region GP regression.

    All structure and functional-property inputs are stacked into one matrix
    and share a single set of ``num_regions`` latent segmentation functions.
    Each functional-property data set gets its own per-region regression
    functions, blended by the region probabilities.

    Args:
        xs_: list of structure input arrays (one per structure data set).
        ys_: list of integer region labels matching ``xs_``.
        xf_: list of functional-property input arrays. Two length scales
            (x and y) are sampled, so inputs are treated as 2-D.
        yf_: list of functional-property values matching ``xf_``.
        num_regions: number of regions (segments).
        gpc_var_bounds, gpc_ls_bounds: uniform-prior bounds of the
            segmentation kernel variance / length scale.
        gpr_var_bounds, gpr_ls_bounds, gpr_bias_bounds, gpr_noise_bounds:
            uniform-prior bounds of the regression hyper-parameters.

    The total log-likelihood is recorded as deterministic site ``llk`` and
    added to the joint density through the factor ``obs``.
    """
    Ns = np.array([x.shape[0] for x in xs_], dtype=np.int64)
    Nf = np.array([x.shape[0] for x in xf_], dtype=np.int64)

    # Row offsets of each data set inside its stacked block.
    Ns_indices = np.concatenate((np.zeros(1, dtype=np.int64), Ns.cumsum()))
    Nf_indices = np.concatenate((np.zeros(1, dtype=np.int64), Nf.cumsum()))

    Mf = len(xf_) # number of functional property data sets
    Ms = len(xs_) # number of structure data sets.

    # Structure rows first, functional-property rows after.
    x_ = jnp.vstack([jnp.vstack(xs_), jnp.vstack(xf_)])
    Nsf = x_.shape[0] # number of all data points across all sets.

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Priors: GPR (one hyper-parameter per region and property data set).
    per_region_set = jnp.ones((num_regions, Mf))
    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bounds[0]*per_region_set, gpr_var_bounds[1]*per_region_set))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", ndist.Uniform(gpr_ls_bounds[0]*per_region_set, gpr_ls_bounds[1]*per_region_set))
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", ndist.Uniform(gpr_ls_bounds[0]*per_region_set, gpr_ls_bounds[1]*per_region_set))
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bounds[0]*per_region_set, gpr_bias_bounds[1]*per_region_set))

    # Get latent functions, one for each region (i.e., segment).
    Fc = jnp.zeros((Nsf, num_regions))
    for i in range(num_regions):
        with numpyro.plate('gpc_latent_response' + str(i), Nsf):
            gpc_latent = numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1))
        f = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, gpc_latent, x_)
        Fc = Fc.at[:, i].set(f)

    probs = logits_to_probs_jax(Fc)

    # Region probabilities per data set. All num_regions columns are kept
    # (a fixed width of 2 would silently drop regions when num_regions > 2).
    probs_st_ = [probs[int(Ns_indices[i]):int(Ns_indices[i + 1]), :] for i in range(Ms)]

    # Functional-property rows start after all structure rows in x_.
    fp_start = int(Ns.sum())
    probs_fp_ = [probs[fp_start + int(Nf_indices[j]):fp_start + int(Nf_indices[j + 1]), :] for j in range(Mf)]

    # gpr for each region.
    Fr_ = []
    for j in range(Mf):
        fr = jnp.zeros((Nf[j], num_regions))
        for i in range(num_regions):
            with numpyro.plate('gpr_latent_response' + str(i), Nf[j]):
                gpr_latent = numpyro.sample('gpr_latent_'+str(i)+'_Mf_'+str(j), ndist.Normal(0, 1))

            gpr_lengthscale_array = jnp.array([gpr_lengthscale_x[i,j], gpr_lengthscale_y[i,j]])
            f = compute_f_jax(gpr_var[i,j], gpr_lengthscale_array, gpr_bias[i,j], gpr_latent, xf_[j])
            fr = fr.at[:, i].set(f)
        Fr_.append(fr)

    # Piecewise function: probability-weighted sum of the per-region functions.
    f_piecewise_ = []
    for j in range(Mf):
        fpw = jnp.zeros((Nf[j]))
        for i in range(num_regions):
            fpw = fpw + probs_fp_[j][:, i] * Fr_[j][:, i]
        f_piecewise_.append(fpw)

    # Structure likelihood: categorical over region labels.
    llk = ndist.Categorical(probs=probs_st_[0]).log_prob(ys_[0].flatten()).sum()
    for i in range(1, Ms):
        llk += ndist.Categorical(probs=probs_st_[i]).log_prob(ys_[i].flatten()).sum()

    # Functional-property likelihood. Observations are flattened so that a
    # column-shaped (N, 1) target cannot broadcast against the (N,) mean.
    for j in range(Mf):
        llk = llk + ndist.Normal(f_piecewise_[j], jnp.sqrt(gpr_noise)).log_prob(yf_[j].flatten()).sum()

    numpyro.deterministic("llk", llk)
    numpyro.factor("obs", llk)
    
def model_SAGE_Coreg_ND_PM_230628a(xs_, ys_, xf_, num_regions, gpc_var_bounds = jnp.asarray([0.1,10.]), gpc_ls_bounds = jnp.asarray([.5,10.])):
    """NumPyro model: segmentation only, fitted to structure labels.

    The latent segmentation functions are defined over the stacked structure
    and functional-property inputs, but only the structure labels enter the
    likelihood.

    Args:
        xs_: list of structure input arrays (one per structure data set).
        ys_: list of integer region labels matching ``xs_``.
        xf_: list of functional-property input arrays (locations only).
        num_regions: number of regions (segments).
        gpc_var_bounds, gpc_ls_bounds: uniform-prior bounds of the
            segmentation kernel variance / length scale.

    The log-likelihood is recorded as deterministic site ``llk`` and added to
    the joint density through the factor ``obs``.
    """
    Ns = np.array([x.shape[0] for x in xs_], dtype=np.int64)

    # Row offsets of each structure data set inside the stacked inputs.
    Ns_indices = np.concatenate((np.zeros(1, dtype=np.int64), Ns.cumsum()))

    Ms = len(xs_) # number of structure data sets.

    # Structure rows first, functional-property rows after.
    x_ = jnp.vstack([jnp.vstack(xs_), jnp.vstack(xf_)])
    Nsf = x_.shape[0] # number of all data points across all sets.

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Get latent functions, one for each region (i.e., segment).
    Fc = jnp.zeros((Nsf, num_regions))
    for i in range(num_regions):
        with numpyro.plate('gpc_latent_response' + str(i), Nsf):
            gpc_latent = numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1))

        f = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, gpc_latent, x_)
        Fc = Fc.at[:, i].set(f)

    probs = logits_to_probs_jax(Fc)

    # Region probabilities for each structure data set. All num_regions columns
    # are kept (a fixed width of 2 would drop regions when num_regions > 2).
    probs_st_ = [probs[int(Ns_indices[i]):int(Ns_indices[i + 1]), :] for i in range(Ms)]

    llk = ndist.Categorical(probs=probs_st_[0]).log_prob(ys_[0].flatten()).sum()
    for i in range(1, Ms):
        llk += ndist.Categorical(probs=probs_st_[i]).log_prob(ys_[i].flatten()).sum()

    numpyro.deterministic("llk", llk)
    numpyro.factor("obs", llk)

def predict_SAGE_Coreg_ND_230628a(Xnew, xs_, ys_, xf_, yf_, num_regions, eps=1E-6, gpc_var_bounds = jnp.asarray([0.1,10.]), gpc_ls_bounds = jnp.asarray([.5,10.]), \
                gpr_var_bounds = jnp.asarray([0.1, 5.]), gpr_ls_bounds = jnp.asarray([.01,5.]), gpr_bias_bounds = jnp.asarray([-2.,2.]), \
                gpr_noise_bounds = jnp.asarray([0.0001,.1])):
    """Predict region probabilities and piecewise functional properties at ``Xnew``.

    Intended to be run with the hyper-parameter and latent sample sites
    conditioned on posterior / MAP values. ``ys_`` and ``yf_`` are accepted for
    signature parity with the model but are not read: the training functions
    are rebuilt from the latent sites.

    Args:
        Xnew: inputs at which to predict; ``Xnew.shape[0]`` rows.
        xs_: list of structure input arrays.
        ys_: list of structure labels (unused).
        xf_: list of functional-property input arrays (treated as 2-D:
            separate x and y length scales are used).
        yf_: list of functional-property values (unused).
        num_regions: number of regions.
        eps: jitter added to the diagonal of the latent predictive covariance.
        *_bounds: uniform-prior bounds, as in the matching model.

    Returns:
        ``(gpc_new_probs, f_piecewise, f_sample, Fr_new, v_piecewise)``.
        Shapes: ``gpc_new_probs`` (Nnew, num_regions); ``f_piecewise``,
        ``f_sample`` and ``v_piecewise`` (Nnew, Mf, 1); ``Fr_new``
        (Nnew, num_regions, Mf). Each is also recorded as a Delta sample site.
    """
    Mf = len(xf_) # number of functional property data sets
    Nnew = Xnew.shape[0]
    gpc_noise = 1E-5  # observation noise used when conditioning on the training latent

    # Independent PRNG keys: one per region (latent draw) and one per property
    # data set (noisy sample). Reusing one key would make all draws share the
    # same underlying random numbers.
    keys = jax.random.split(jax.random.PRNGKey(0), num_regions + Mf)
    region_keys = keys[:num_regions]
    sample_keys = keys[num_regions:]

    # Structure rows first, functional-property rows after.
    x_ = jnp.vstack([jnp.vstack(xs_), jnp.vstack(xf_)])

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Priors: GPR (one hyper-parameter per region and property data set).
    per_region_set = jnp.ones((num_regions, Mf))
    gpr_noise = numpyro.sample("gpr_noise", ndist.Uniform(gpr_noise_bounds[0], gpr_noise_bounds[1]))
    gpr_var = numpyro.sample("gpr_var", ndist.Uniform(gpr_var_bounds[0]*per_region_set, gpr_var_bounds[1]*per_region_set))
    gpr_lengthscale_x = numpyro.sample("gpr_lengthscale_x", ndist.Uniform(gpr_ls_bounds[0]*per_region_set, gpr_ls_bounds[1]*per_region_set))
    gpr_lengthscale_y = numpyro.sample("gpr_lengthscale_y", ndist.Uniform(gpr_ls_bounds[0]*per_region_set, gpr_ls_bounds[1]*per_region_set))
    gpr_bias = numpyro.sample("gpr_bias", ndist.Uniform(gpr_bias_bounds[0]*per_region_set, gpr_bias_bounds[1]*per_region_set))

    # Whitened segmentation latents, one per region.
    gpc_latent = [numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1)) for i in range(num_regions)]

    # Region labels: rebuild each latent at the training inputs, then draw it
    # at Xnew from the GP predictive distribution.
    gpc_new_latent = jnp.zeros((Nnew, num_regions))
    for j in range(num_regions):
        f = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, gpc_latent[j], x_)
        mean, cov, _ = gpr_forward_matern52_jax(gpc_var, gpc_lengthscale, x_, f, Xnew, gpc_noise, include_noise=False)
        fhat = ndist.MultivariateNormal(mean, cov + jnp.eye(Nnew) * eps).sample(region_keys[j])
        gpc_new_latent = gpc_new_latent.at[:, j].set(fhat)

    gpc_new_probs = logits_to_probs_jax(gpc_new_latent)

    # Whitened regression latents, indexed [region][property set]. The sampling
    # order (property set outer, region inner) is kept as in the model.
    gpr_latent = [[None] * Mf for _ in range(num_regions)]
    for j in range(Mf):
        for i in range(num_regions):
            gpr_latent[i][j] = numpyro.sample('gpr_latent_'+str(i)+'_Mf_'+str(j), ndist.Normal(0, 1))

    # Per-region GP regression: predictive mean and variance at Xnew.
    Fr_new = jnp.zeros((Nnew, num_regions, Mf))
    Vr_new = jnp.zeros((Nnew, num_regions, Mf))
    for k in range(Mf):
        for j in range(num_regions):
            gpr_lengthscale_array = jnp.array([gpr_lengthscale_x[j,k], gpr_lengthscale_y[j,k]])
            f = compute_f_jax(gpr_var[j,k], gpr_lengthscale_array, gpr_bias[j,k], gpr_latent[j][k], xf_[k])
            mean, _, var = gpr_forward_jax(gpr_var[j,k], gpr_lengthscale_array, xf_[k], f, Xnew, gpr_noise, include_noise=False)
            Fr_new = Fr_new.at[:, j, k].set(mean)
            Vr_new = Vr_new.at[:, j, k].set(var)

    # Probability-weighted blend of the per-region predictions.
    f_piecewise = jnp.zeros((Nnew, Mf, 1)) # last dimension added for stacking purposes in plotting func.
    v_piecewise = jnp.zeros((Nnew, Mf, 1))
    f_sample = jnp.zeros((Nnew, Mf, 1))
    for k in range(Mf):
        for j in range(num_regions):
            f_piecewise = f_piecewise.at[:, k, 0].add(gpc_new_probs[:, j] * Fr_new[:, j, k])
            v_piecewise = v_piecewise.at[:, k, 0].add(gpc_new_probs[:, j] * Vr_new[:, j, k])
        noisy = ndist.Normal(f_piecewise[:, k, 0], jnp.sqrt(gpr_noise)).sample(sample_keys[k])
        f_sample = f_sample.at[:, k, 0].set(noisy)

    # Expose the results as trace sites so callers can read them from a trace.
    gpc_new_probs_ = numpyro.sample('gpc_new_probs', ndist.Delta(gpc_new_probs))
    f_piecewise_ = numpyro.sample('f_piecewise', ndist.Delta(f_piecewise))
    f_sample_ = numpyro.sample('f_sample', ndist.Delta(f_sample))
    Fr_new_ = numpyro.sample('Fr_new', ndist.Delta(Fr_new))
    v_piecewise_ = numpyro.sample('v_piecewise', ndist.Delta(v_piecewise))

    return gpc_new_probs_, f_piecewise_, f_sample_, Fr_new_, v_piecewise_

def predict_SAGE_Coreg_ND_PM_230628a(Xnew, xs_, ys_, xf_, num_regions, eps=1E-6, gpc_var_bounds = jnp.asarray([0.1,10.]), gpc_ls_bounds = jnp.asarray([.5,10.])):
    """Predict region probabilities at ``Xnew`` for the segmentation-only model.

    Intended to be run with the ``gpc_*`` sample sites conditioned on
    posterior / MAP values. ``ys_`` is accepted for signature parity with the
    model but is not read.

    Args:
        Xnew: inputs at which to predict; ``Xnew.shape[0]`` rows.
        xs_: list of structure input arrays.
        ys_: list of structure labels (unused).
        xf_: list of functional-property input arrays (locations only).
        num_regions: number of regions.
        eps: jitter added to the diagonal of the latent predictive covariance.
        gpc_var_bounds, gpc_ls_bounds: uniform-prior bounds of the
            segmentation kernel variance / length scale.

    Returns:
        Region probabilities of shape ``(Xnew.shape[0], num_regions)``; also
        recorded as Delta sample site ``gpc_new_probs``.
    """
    # One independent PRNG key per region; reusing one key would make every
    # region's latent draw share the same underlying random numbers.
    region_keys = jax.random.split(jax.random.PRNGKey(0), num_regions)

    Nnew = Xnew.shape[0]
    gpc_noise = 1E-5  # observation noise used when conditioning on the training latent

    # Structure rows first, functional-property rows after.
    x_ = jnp.vstack([jnp.vstack(xs_), jnp.vstack(xf_)])

    # Priors: Segmentation.
    gpc_var = numpyro.sample('gpc_var', ndist.Uniform(gpc_var_bounds[0], gpc_var_bounds[1])) # variance
    gpc_lengthscale = numpyro.sample('gpc_lengthscale', ndist.Uniform(gpc_ls_bounds[0], gpc_ls_bounds[1])) # ls
    gpc_bias = numpyro.sample('gpc_bias', ndist.Normal(0, 1)) # bias

    # Whitened segmentation latents, one per region.
    gpc_latent = [numpyro.sample('gpc_latent_' + str(i), ndist.Normal(0, 1)) for i in range(num_regions)]

    # Region labels: rebuild each latent at the training inputs, then draw it
    # at Xnew from the GP predictive distribution.
    gpc_new_latent = jnp.zeros((Nnew, num_regions))
    for j in range(num_regions):
        f = compute_f_matern52_jax(gpc_var, gpc_lengthscale, gpc_bias, gpc_latent[j], x_)
        mean, cov, _ = gpr_forward_matern52_jax(gpc_var, gpc_lengthscale, x_, f, Xnew, gpc_noise, include_noise=False)
        fhat = ndist.MultivariateNormal(mean, cov + jnp.eye(Nnew) * eps).sample(region_keys[j])
        gpc_new_latent = gpc_new_latent.at[:, j].set(fhat)

    gpc_new_probs = logits_to_probs_jax(gpc_new_latent)

    gpc_new_probs_ = numpyro.sample('gpc_new_probs', ndist.Delta(gpc_new_probs))

    return gpc_new_probs_
       
# -------------------------------------------------   
def logits_to_probs_jax(logits):
    """Convert logits to probabilities along the last axis.

    ``logits`` is laid out as observations x categories. The logits are first
    shifted by their log-sum-exp and then passed through a softmax.
    """
    log_normalizer = jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    return jax.nn.softmax(logits - log_normalizer, axis=-1)

# Joint analysis with coregionalized functional properties.
def remap_array(v):
    """Relabel the values of ``v`` with their rank among its unique values.

    Returns a new tensor of zeros with the shape of ``v`` in which every entry
    equal to the i-th value of ``torch.unique(v)`` is set to ``i``.
    """
    relabeled = torch.zeros(v.shape)
    for new_label, old_value in enumerate(torch.unique(v)):
        relabeled[v == old_value] = new_label
    return relabeled

def flip_keys_and_indices(samples, step = 1):
    """Turn a dict of stacked samples into a list of per-draw dicts.

    ``samples`` maps site names to arrays indexed by draw along the first
    axis; the number of draws is read from ``samples['gpr_noise']``. Every
    ``step``-th draw ``n`` becomes one dict holding ``samples[k][n]`` for each
    key, plus a ``'seed'`` entry set to ``n``. A progress bar is shown.
    """
    site_names = list(samples.keys())
    num_draws = samples['gpr_noise'].shape[0]

    per_draw = []
    for n in tqdm(np.arange(0,num_draws,step)):
        draw = {name: samples[name][n] for name in site_names}
        draw['seed'] = n
        per_draw.append(draw)
    return per_draw

def gpr_forward_jax(variance,lengthscales,xtrain,ytrain,xnew,noise_var,include_noise = True):
    """GP regression prediction at ``xnew`` using the ``RBF_jax`` kernel.

    Conditions on ``(xtrain, ytrain)`` with noise ``noise_var`` (plus a 1e-6
    jitter) on the training covariance diagonal. When ``include_noise`` is
    true, ``noise_var`` is also added to the predictive covariance diagonal.

    Returns:
        ``(mean, cov, var)``: flattened predictive mean, full predictive
        covariance, and its diagonal.
    """
    # "new" vs "train" covariance blocks.
    cov_new_train = RBF_jax(variance, lengthscales, xnew, xtrain)
    cov_train = RBF_jax(variance, lengthscales, xtrain, xtrain)
    cov_new = RBF_jax(variance, lengthscales, xnew, xnew)

    n_train = cov_train.shape[0]
    train_precision = jnp.linalg.inv(cov_train + jnp.eye(n_train)*(noise_var + 1E-6))

    targets = ytrain.flatten()[:,None]
    mean = cov_new_train @ (train_precision @ targets)
    cov = cov_new - cov_new_train @ (train_precision @ cov_new_train.T)
    if include_noise:
        cov = cov + jnp.eye(cov.shape[0])*noise_var
    return mean.flatten(), cov, jnp.diagonal(cov).flatten()

def gpr_forward_matern52_jax(variance,lengthscale,xtrain,ytrain,xnew,noise_var,include_noise = True):
    """GP regression prediction at ``xnew`` using the ``Matern52_2D_jax`` kernel.

    Conditions on ``(xtrain, ytrain)`` with noise ``noise_var`` (plus a 1e-6
    jitter) on the training covariance diagonal. When ``include_noise`` is
    true, ``noise_var`` is also added to the predictive covariance diagonal.

    Returns:
        ``(mean, cov, var)``: flattened predictive mean, full predictive
        covariance, and its diagonal.
    """
    # "new" vs "train" covariance blocks.
    cov_new_train = Matern52_2D_jax(variance, lengthscale, xnew, xtrain)
    cov_train = Matern52_2D_jax(variance, lengthscale, xtrain, xtrain)
    cov_new = Matern52_2D_jax(variance, lengthscale, xnew, xnew)

    n_train = cov_train.shape[0]
    train_precision = jnp.linalg.inv(cov_train + jnp.eye(n_train)*(noise_var + 1E-6))

    targets = ytrain.flatten()[:,None]
    mean = cov_new_train @ (train_precision @ targets)
    cov = cov_new - cov_new_train @ (train_precision @ cov_new_train.T)
    if include_noise:
        cov = cov + jnp.eye(cov.shape[0])*noise_var
    return mean.flatten(), cov, jnp.diagonal(cov).flatten()

def RBF_jax(variance, lengthscales, X, Z = None):
    """Squared-exponential (RBF) covariance between the rows of ``X`` and ``Z``.

    ``K[i, j] = variance * exp(-0.5 * ||(X[i] - Z[j]) / lengthscales||^2)``

    Args:
        variance: kernel variance (the value of ``K`` at zero distance).
        lengthscales: scalar or per-dimension length scales; broadcast against
            the columns of ``X`` and ``Z``.
        X: array of shape (N, D).
        Z: array of shape (M, D); defaults to ``X``.

    Returns:
        Covariance matrix of shape (N, M).
    """
    if Z is None:
        Z = X
    scaled_X = X / lengthscales
    scaled_Z = Z / lengthscales
    X2 = (scaled_X**2).sum(1, keepdims=True)
    Z2 = (scaled_Z**2).sum(1, keepdims=True)
    XZ = jnp.matmul(scaled_X, scaled_Z.T)
    # ||x - z||^2 = ||x||^2 - 2 x.z + ||z||^2. The cross term needs the factor
    # of 2; without it the diagonal is not `variance` and the kernel is not
    # stationary. Clamp at zero to absorb negative round-off.
    r2 = jnp.maximum(X2 - 2.0 * XZ + Z2.T, 0.0)
    return variance * jnp.exp(-0.5 * r2)

def Matern52_2D_jax(variance, lengthscale, X, Z = None):
    """Cross-covariance of ``X`` and ``Z`` under a product of two Matern-5/2 kernels.

    Both factors are built with the same ``lengthscale`` and ``variance``.
    ``Z`` defaults to a copy of ``X``.

    NOTE(review): multiplying two identical kernels presumably squares the
    Matern-5/2 covariance (including the variance) - confirm this is intended.
    """
    Z = X.copy() if Z is None else Z

    factors = [gpx.kernels.Matern52(lengthscale=lengthscale, variance=variance) for _ in range(2)]
    return gpx.kernels.ProductKernel(kernels=factors).cross_covariance(X, Z)
    
def euclidean_jax(X1, X2 = None):
    """Pairwise Euclidean distances between the rows of ``X1`` and ``X2``.

    ``X2`` defaults to a copy of ``X1``. For inputs of shape (N, D) and (M, D)
    the result has shape (N, M).
    """
    other = X1.copy() if X2 is None else X2
    pairwise_diff = X1[:,None] - other[None,:]
    squared_dist = jnp.sum(pairwise_diff**2, axis = 2)
    return jnp.sqrt(squared_dist)
    
def compute_f_jax(variance, lengthscales, bias, eta, X):
    """Map whitened values ``eta`` to function values at ``X`` (RBF kernel).

    Computes ``L @ eta + bias`` where ``L`` is the Cholesky factor of the
    ``RBF_jax`` covariance of ``X`` with a 1e-6 jitter on the diagonal.
    """
    num_points = X.shape[0]
    jittered_cov = RBF_jax(variance, lengthscales, X) + jnp.eye(num_points) * 1e-6
    chol = jnp.linalg.cholesky(jittered_cov)
    return chol @ eta + bias

def compute_f_matern52_jax(variance, lengthscale, bias, eta, X):
    """Map whitened values ``eta`` to function values at ``X`` (Matern-5/2 kernel).

    Computes ``L @ eta + bias`` where ``L`` is the Cholesky factor of the
    ``Matern52_2D_jax`` covariance of ``X`` with a 1e-6 jitter on the diagonal.
    """
    num_points = X.shape[0]
    jittered_cov = Matern52_2D_jax(variance, lengthscale, X) + jnp.eye(num_points) * 1e-6
    chol = jnp.linalg.cholesky(jittered_cov)
    return chol @ eta + bias

def gen_data_2D_example(x,y):
    """Generate synthetic region labels and two piecewise functions on 2-D points.

    Points whose distance from the origin is below 1 get label 1, all others
    label 0. Each of the two output functions uses one Gaussian-shaped
    expression inside the disc and a different one outside.

    Returns:
        ``(L, f)``: labels of length ``x.shape[0]`` and the two functions
        stacked as columns of ``f``.
    """
    sq_radius = x**2 + y**2
    radius = torch.sqrt(sq_radius)

    # Region labels: 1 inside the unit disc, 0 elsewhere.
    L = torch.zeros((x.shape[0]))
    for i, radius_i in zip(range(x.shape[0]), radius):
        if radius_i < 1.:
            L[i] = 1

    def _piecewise(outside_values, inside_values):
        # Take `outside_values` where L == 0 and `inside_values` where L == 1.
        blended = torch.zeros(x.shape)
        blended[L == 0] = outside_values[L == 0]
        blended[L == 1] = inside_values[L == 1]
        return blended

    # First functional property.
    f0_inside = 1.2-.5*torch.exp(-.5*sq_radius/ 2.)
    f0_outside = torch.exp(-.5*((x-1.5)**2+(y-1.5)**2)/ .2)
    f0 = _piecewise(f0_outside, f0_inside)

    # Second functional property.
    f1_inside = 1.5*torch.exp(-.5*sq_radius/ 1)
    f1_outside = torch.exp(-.5*((x+1.5)**2+(y+1.5)**2)/ .2)
    f1 = _piecewise(f1_outside, f1_inside)

    return L, torch.hstack((f0[:,None], f1[:,None]))

def compare_inputs_jax(Xnew, x):
    """Find which rows of ``x`` coincide with rows of ``Xnew``.

    Rows are compared with ``diff_mat_row_jax``. The Python-level ``if`` on an
    array value means this is meant to run outside of ``jit``.

    Returns:
        ``(m_Xn_x, idx_Xnew_match_x, idx_x_match_Xnew)``: for each row of
        ``Xnew`` the number of rows of ``x`` matching it; the index of the
        first matching ``Xnew`` row for every row of ``x`` that has a match;
        and the indices of those rows of ``x``.
    """
    match_counts = jnp.zeros(Xnew.shape[0], dtype=jnp.integer)
    rows_in_Xnew = []
    rows_in_x = []
    for row in range(x.shape[0]):
        is_match = diff_mat_row_jax(Xnew, x[row,:][None,:])
        match_counts = match_counts + is_match
        # First matching row of Xnew, or -1 when there is none.
        first_hit = jnp.nonzero(is_match, size=1, fill_value=-1)[0][0]
        if first_hit > -1:
            rows_in_Xnew.append(first_hit)
            rows_in_x.append(row)
    return (match_counts,
            jnp.asarray(rows_in_Xnew, dtype=jnp.integer),
            jnp.asarray(rows_in_x, dtype=jnp.integer))
        
def diff_mat_row_jax(M,r):
    """Flag the rows of ``M`` whose squared distance to row ``r`` is below 1e-6."""
    repeated_row = jnp.tile(r,(M.shape[0],1))
    squared_dist = jnp.sum((M - repeated_row)**2, axis = 1)
    return squared_dist < 1e-6

### 2D Challenges

#### Set up 2D Challenge data.
- Challenge 1: Structure data is more informative of phase boundaries.
- Challenge 2: Functional property is more informative of phase boundaries.
- Challenge 3: Demonstrate N-Dimensional Coregionalization

In [None]:
# Challenge 1 ------------------------------
# Dense 41 x 41 grid over [-2, 2]^2 (coordinates rounded to 2 decimals).
# Xp holds the grid points as rows; Lp holds their region labels from
# gen_data_2D_example (1 inside the unit circle, 0 outside).
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
Lp, _ = gen_data_2D_example(xp,yp)

# Coarser 21 x 21 grid over the same square, with its own labels.
N = 21
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = torch.round(x_.flatten(), decimals=2)
y = torch.round(y_.flatten(), decimals=2)
X = torch.hstack((x[:,None],y[:,None])).double()
L, _ = gen_data_2D_example(x,y)
Xnew = X

# Ground-truth functional properties: 5 draws each from two Gaussians with the
# same RBF covariance over Xp, one with mean 5 (Z1) and one with mean 0 (Z2).
# NOTE(review): both draws use default_rng(0) and identical covariances, so Z1
# is presumably Z2 shifted by the mean offset - confirm this is intended.
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)
k1 = gp.kernels.RBF(2, variance=torch.tensor([1.]), lengthscale=torch.tensor([1.,1.])).forward(Xp).detach().numpy()
Z1 = np.random.default_rng(0).multivariate_normal(5*torch.ones(Xp.shape[0]), k1, 5)

k2 = gp.kernels.RBF(2, variance=torch.tensor([1.]), lengthscale=torch.tensor([1.,1.])).forward(Xp).detach().numpy()
Z2 = np.random.default_rng(0).multivariate_normal(torch.zeros(Xp.shape[0]), k2, 5)

# Piecewise field: Z1 inside the unit circle (r < 1), Z2 outside.
Zj = Z2.copy()
for i in range(5):
    Zj[i,r<1] = Z1[i,r<1]
    
Zj = torch.tensor( Zj )
# f: the first two piecewise draws as columns (two functional properties).
f = torch.cat([Zj[0,:][:,None],Zj[1,:][:,None]],axis=1)

# Saved so the stand-alone scripts can load the same functional properties.
with open(r"2D_2a_and_2b_fp_231030a.dill", "wb") as output_file:
    dill.dump(f, output_file)


num_data_points_st = 20
num_data_points_fp = 60

seed = 0
# Hand-picked indices into Xp.
# NOTE(review): presumably points around the upper / lower part of the region
# boundary - verify against the index plot produced below.
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
# Additional random (seeded) indices into Xp, drawn without replacement.
extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

def map_indices(Xp, idx):
    """Shift grid indices in place based on the parity of ten times the coordinates.

    For each entry of ``idx``: if ``10 * x % 2`` is non-zero at the indexed
    point the index is advanced by 1; then, at the (possibly updated) point,
    if ``10 * y % 2`` is non-zero the index is advanced by 41. ``idx`` is
    modified in place and also returned.
    """
    for pos in range(idx.shape[0]):
        x_remainder = 10*Xp[idx[pos],0] % 2
        if x_remainder:
            idx[pos] += 1
        # Evaluated after the x-shift, so it looks at the updated point.
        y_remainder = 10*Xp[idx[pos],1] % 2
        if y_remainder:
            idx[pos] = idx[pos] + 41
    return idx

# Structure measurement locations: shifted random indices plus the hand-picked
# `top` / `bottom` indices and three extra fixed indices, de-duplicated.
extra = map_indices(Xp, extra)
kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

# Functional-property locations: seeded random indices, shifted by
# map_indices, keeping only points with radius > 1.5 or < 1 (the ring between
# is left unsampled).
kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
kp_fp = map_indices(Xp, kp_fp)
temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

# NumPy copies of the challenge-1 indices, written to disk further below.
kp_st_2d1 = kp_st.numpy().copy()
kp_fp_2d1 = kp_fp.numpy().copy()

# Challenge-1 data: structure inputs/labels and functional-property
# inputs/values (first functional property only).
xs = Xp[kp_st,:].double()
ys = Lp[kp_st].double()
xf = Xp[kp_fp,:].double()
yf = f[kp_fp,0][:,None].double()
# yf += torch.normal(torch.zeros(yf.shape),.01)

# Keep challenge-1 copies under dedicated names (xs, ys, ... are reused later).
xs_2a = xs.clone()
ys_2a = ys.clone()
xf_2a = xf.clone()
yf_2a = yf.clone()

# Index map: coarse-grid labels with every Xp point annotated by its index
# (useful for hand-picking indices such as `top` / `bottom`).
plt.figure(figsize = (12,12))
plt.scatter(x,y,c=L,s=10)
for i in range(Xp.shape[0]):
    plt.text(Xp[i,0],Xp[i,1],str(i),fontsize=5)

# Ground truth with measurement locations marked by red squares:
# left = region labels + structure points, right = property 0 + property points.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xs_2a[:,0],xs_2a[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xf_2a[:,0],xf_2a[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2a_ground_truth.png',transparent=True)

# Challenge 2 ------------------------
# Reuses Xp, Lp, f, top and bottom from the challenge-1 cell.
seed = 0
num_data_points_st = 60
num_data_points_fp = 40

# Structure locations start from seeded random indices into Xp.
kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
# Regularly strided index ranges into Xp.
# NOTE(review): presumably a horizontal, a vertical and two diagonal lines
# across the 41 x 41 grid - verify against the plot.
lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

# NOTE(review): this `temp` is drawn from the coarse grid size (X.shape[0]) and
# is not used: it is commented out of the concatenation below and then
# overwritten.
temp = torch.tensor( default_rng(seed+1).choice(X.shape[0],num_data_points_fp,replace=False) )
# Functional-property locations: the lines plus the hand-picked indices.
kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

# Keep only structure points with radius > 1.4 or < 0.5, then add a few fixed
# indices.
temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

# NumPy copies of the challenge-2 indices, written to disk further below.
kp_st_2d2 = kp_st.numpy().copy()
kp_fp_2d2 = kp_fp.numpy().copy()

# Challenge-2 data (first functional property only).
xs_2b = Xp[kp_st,:].double().clone()
ys_2b = Lp[kp_st].double().clone()
xf_2b = Xp[kp_fp,:].double().clone()
yf_2b = f[kp_fp,0][:,None].double().clone()

# Ground truth with measurement locations marked by red squares:
# left = region labels + structure points, right = property 0 + property points.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xs_2b[:,0],xs_2b[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xf_2b[:,0],xf_2b[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2b_ground_truth.png',transparent=True)

plt.show()
# Challenge 3 -----------------------------
# Two structure data sets and two functional-property data sets, each with its
# own measurement locations. Reuses Xp, Lp, f, top, bottom and map_indices
# from the cells above.
seed = 1
num_data_points_st0 = 30
num_data_points_st1 = 30
num_data_points_fp0 = 40
num_data_points_fp1 = 40
# Structure set 0: `top` indices + seeded random indices + one fixed index.
temp = torch.tensor( default_rng(seed+0).permutation(Xp.shape[0])[:num_data_points_st0] )
kp_st0 = torch.cat([top,temp,torch.tensor([1680])])
# Structure set 1: `bottom` indices + seeded random indices + one fixed index.
temp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_st1] )
kp_st1 = torch.cat([bottom,temp,torch.tensor([178])])
# Functional-property sets: seeded random indices (different seed per set).
kp_fp0 = torch.tensor( default_rng(seed+4).permutation(Xp.shape[0])[:num_data_points_fp0] )
kp_fp1 = torch.tensor( default_rng(seed+3).permutation(Xp.shape[0])[:num_data_points_fp1] )

# Shift all index sets with map_indices (modifies the tensors in place).
kp_st0 = map_indices(Xp, kp_st0)
kp_st1 = map_indices(Xp, kp_st1)
kp_fp0 = map_indices(Xp, kp_fp0)
kp_fp1 = map_indices(Xp, kp_fp1)

# These should be lists.
# One entry per data set; property set 0 observes f[:, 0] and set 1 f[:, 1].
Xs_ = [Xp[kp_st0,:].double(), Xp[kp_st1,:].double()]
Xf_ = [Xp[kp_fp0,:].double(), Xp[kp_fp1,:].double()]
ys_ = [Lp[kp_st0].double(), Lp[kp_st1].double()]
yf_ = [f[kp_fp0,0].double(), f[kp_fp1,1].double()]

# Region labels with the locations of structure set 0 (left) and 1 (right).
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(Xs_[0][:,0],Xs_[0][:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(Xs_[1][:,0],Xs_[1][:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2c_st_ground_truth.png',transparent=True)

# Functional properties 0 (left) and 1 (right) with their measurement locations.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(Xf_[0][:,0],Xf_[0][:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,1],s=10)
plt.plot(Xf_[1][:,0],Xf_[1][:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2c_fp_ground_truth.png',transparent=True)

plt.show()

# form_graph / plot_graph are defined elsewhere in the notebook.
S = form_graph(Xp)
plt.figure(figsize = (10,10))
plot_graph(S, Xp)

# Persist the grid plus the challenge-1 (2a) and challenge-2 (2b) indices and
# data as NumPy arrays.
with open(r"2D_2a_and_2b_points_240718a.dill", "wb") as output_file:
    dill.dump([Xp.numpy(), kp_st_2d1,kp_fp_2d1, kp_st_2d2, kp_fp_2d2, xs_2a.numpy(), ys_2a.numpy(), xf_2a.numpy(), yf_2a.numpy(), xs_2b.numpy(), ys_2b.numpy(), xf_2b.numpy(), yf_2b.numpy()], output_file)

#### Present multicore and 1 core scripts for each challenge and algorithm

##### Challenge 1: SAGE-ND, multicore

In [None]:
%%writefile sage_2D_2an_matern52_with_1init_230804a.py
# Unified for init 2a
from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

if __name__ == '__main__':
    num_proc = 100

    
    N = 41
    xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
    xp = torch.round(xp_.flatten(),decimals=2)
    yp = torch.round(yp_.flatten(),decimals=2)
    Xp = torch.hstack((xp[:,None],yp[:,None])).double()
    Lp, _ = gen_data_2D_example(xp,yp)

    r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

    with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
        f = dill.load(input_file)


    num_data_points_st = 20
    num_data_points_fp = 60

    seed = 0
    top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                        1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
    bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
    extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

    def map_indices(Xp, idx):  
        for i in range(idx.shape[0]):
            if (10*Xp[idx[i],0] % 2):
                idx[i] += 1
            if (10*Xp[idx[i],1] % 2):
                idx[i] = idx[i] + 41
        return idx

    extra = map_indices(Xp, extra)
    kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

    kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
    kp_fp = map_indices(Xp, kp_fp)
    temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
    kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

    xs = Xp[kp_st,:].double()
    ys = Lp[kp_st].double()
    xf = Xp[kp_fp,:].double()
    yf = f[kp_fp,0][:,None].double()
    # yf += torch.normal(torch.zeros(yf.shape),.01)
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    xs_2a = xs.clone()
    ys_2a = ys.clone()
    xf_2a = xf.clone()
    yf_2a = yf.clone()

    xs = jnp.asarray( xs_2a.detach().numpy(), dtype=jnp.float64).copy()
    ys = jnp.asarray( ys_2a.detach().numpy(), dtype=jnp.integer).copy()
    xf = jnp.asarray( xf_2a.detach().numpy(), dtype=jnp.float64).copy()
    yf = jnp.asarray( yf_2a.detach().numpy(), dtype=jnp.float64).copy()

    Ns = xs.shape[0] + xf.shape[0]

    Nn = 40
    xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
    xn = torch.round(xn_.flatten(),decimals=2)
    yn = torch.round(yn_.flatten(),decimals=2)
    X40 = torch.hstack((xn[:,None],yn[:,None])).double()
    Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 2

    def predict_structure(post_samples, model, *args, **kwargs):
        key = jax.random.PRNGKey(0)
        model = handlers.seed(handlers.condition(model, post_samples), key)
        model_trace = handlers.trace(model).get_trace(*args, **kwargs)
        return model_trace["gpc_new_probs"]["value"], model_trace["gpc_new_latent"]["value"]

    def predict_sage(post_samples, model, *args, **kwargs):
        key = jax.random.PRNGKey(0)
        model = handlers.seed(handlers.condition(model, post_samples), key)
        model_trace = handlers.trace(model).get_trace(*args, **kwargs)
        return model_trace["Fr_new"]["value"], model_trace["f_piecewise"]["value"], model_trace["f_sample"]["value"], model_trace["gpc_new_probs"]["value"], model_trace["v_piecewise"]["value"]

    predict_fn_structure = lambda samples: predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        )
    predict_fn_sage_1core = lambda samples: predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
        )
    def subsample(samples, step):
        tamples = {}
        for k in samples.keys():
            tamples[k] = samples[k][::step]  
        return tamples  

    def split_samples(samples, num_proc, length):
        """Cut ``samples`` into consecutive batches of ``num_proc`` draws.

        Args:
            samples: mapping from site name to a sliceable sequence of draws.
            num_proc: number of draws per batch.
            length: number of draws available in every entry of ``samples``.

        Returns:
            A list of ``length // num_proc`` independent dicts; batch ``i`` holds
            ``value[i*num_proc:(i+1)*num_proc]`` for every key. Trailing draws
            that do not fill a whole batch are dropped.
        """
        splits = int(length // num_proc)
        # Build a fresh dict per batch: reusing one dict across iterations would
        # make every list entry alias the same object (i.e. the last batch).
        return [
            {name: values[i * num_proc:(i + 1) * num_proc] for name, values in samples.items()}
            for i in range(splits)
        ]

    def get_samples_split(samples, num_proc, length, i):
        """Return the ``i``-th batch of ``num_proc`` consecutive draws for every key.

        ``length`` is accepted but not used.
        """
        start = i * num_proc
        stop = start + num_proc
        return {name: values[start:stop] for name, values in samples.items()}
    
    
    # ------------------------
    
    # Positional arguments forwarded to the structure-only SVI model below.
    data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
             jnp.asarray([1.,2.], dtype=jnp.float64)]
        
    key = jax.random.PRNGKey(0)
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.05)

    # SVI must optimize the same model the guide was built from. `data` carries no yf,
    # so it cannot match the (xs, ys, xf, yf, num_regions, ...) order that
    # model_SAGE_ND_230628a is called with in nuts.run below.
    svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 100000, *data)

    params = svi_result.params
    mle_2a_st = autoguide_mle.median(params)
    preds_st = predict_fn_structure(mle_2a_st)
    
    # predict_structure uses a fixed PRNG key, so reuse its result rather than re-running it.
    gpc_new_probs_, gpc_new_latent_ = preds_st
    
    gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

    preds_fp = None
    
    # !!!!!!!!!!!!!!!!!!!!!!!!

    # Start NUTS from the SVI median of the gpc_* sites.
    init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'],
                   'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
                   'gpc_bias': mle_2a_st['gpc_bias']}
    init_strategy=init_to_value(values=init_params)
    
    # The timer spans NUTS sampling and the posterior-prediction loop.
    tic = time.perf_counter()
    nuts = nMCMC(nNUTS(model_SAGE_ND_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
                num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(key, xs, ys, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
             gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
             gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()
    
    import dill
    # Persist the raw posterior samples before thinning.
    with open(r"2D_2an_100chains_4E3samples_2init_231011a.dill", "wb") as output_file:
        dill.dump(nuts_posterior_samples, output_file)

    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpr_noise'].shape[0]) 

    num_length = samples['gpr_noise'].shape[0]
    
    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')
    
    splits = np.array(num_length / num_proc).astype(int)
   
    # Each batch of num_proc draws is mapped over its leading axis by jax.pmap.
    predict_fn_sage = jax.pmap(
        lambda samples: predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew=Xnew_, xs=xs, ys=ys, xf=xf, yf=yf, num_regions=num_regions
        ), axis_name = 0
    )

    print('starting pred analysis, for #', num_length)
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    # The first batch initializes the accumulators; later batches are stacked row-wise.
    for i in trange(splits):
        preds = predict_fn_sage(sl[i])
        if i == 0:
            preds_stacked = {label: pred.squeeze() for label, pred in zip(labels, preds)}
        else:
            for label, pred in zip(labels, preds):
                preds_stacked[label] = np.vstack((preds_stacked[label], pred.squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")    
                
    print('done pred analysis')
    
    output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
    with open(r"2D_2an_matern52_N41_10ksamples_2init_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 1: SAGE-ND, 1 core

In [None]:
from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Number of posterior draws per prediction batch (see split_samples / jax.pmap below).
num_proc = 1

# N x N grid over [-2, 2]^2, flattened to one (x, y) row per grid point.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# First output of gen_data_2D_example; indexed below to obtain the structure labels ys.
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin (not referenced again in this cell).
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

# NOTE: dill.load can execute arbitrary code; only load files from a trusted source.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)

# Sizes of the random index draws for structure (st) and functional-property (fp) points.
num_data_points_st = 20
num_data_points_fp = 60

seed = 0
# Hand-picked grid indices that are added to the structure measurement set below.
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
# Random grid indices, drawn without replacement with a fixed seed.
extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

def map_indices(Xp, idx):
    """Shift grid indices, in place, until both scaled coordinates are even.

    For each entry: if 10 * (first coordinate) is not a multiple of 2 the index
    is advanced by 1; then, using the updated index, if 10 * (second
    coordinate) is not a multiple of 2 it is advanced by 41 (the value of N
    used to build the grid in this script).

    ``idx`` is modified in place and the same tensor is returned.
    """
    for pos in range(idx.shape[0]):
        x_scaled = 10 * Xp[idx[pos], 0]
        if x_scaled % 2:
            idx[pos] = idx[pos] + 1
        # Re-read the index: the y test must use the already-shifted position.
        y_scaled = 10 * Xp[idx[pos], 1]
        if y_scaled % 2:
            idx[pos] += 41
    return idx

# Structure measurement indices: snapped random draw plus the hand-picked sets, de-duplicated.
extra = map_indices(Xp, extra)
kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

# Functional-property indices: snapped random draw, keeping only points with
# radius > 1.5 or < 1. Unlike kp_st, kp_fp is not passed through torch.unique.
kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
kp_fp = map_indices(Xp, kp_fp)
temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

# xs/ys: structure inputs and labels; xf/yf: functional-property inputs and
# column 0 of f as a column vector.
xs = Xp[kp_st,:].double()
ys = Lp[kp_st].double()
xf = Xp[kp_fp,:].double()
yf = f[kp_fp,0][:,None].double()
# yf += torch.normal(torch.zeros(yf.shape),.01)
# Torch versions of the data, saved alongside the predictions at the end of the cell.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

xs_2a = xs.clone()
ys_2a = ys.clone()
xf_2a = xf.clone()
yf_2a = yf.clone()

# Rebind xs/ys/xf/yf to JAX arrays for numpyro.
xs = jnp.asarray( xs_2a.detach().numpy(), dtype=jnp.float64).copy()
ys = jnp.asarray( ys_2a.detach().numpy(), dtype=jnp.integer).copy()
xf = jnp.asarray( xf_2a.detach().numpy(), dtype=jnp.float64).copy()
yf = jnp.asarray( yf_2a.detach().numpy(), dtype=jnp.float64).copy()

Ns = xs.shape[0] + xf.shape[0]

# Prediction grid: Nn x Nn points over [-2, 2]^2 (Nn differs from the data grid's N).
Nn = 40
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = torch.round(xn_.flatten(),decimals=2)
yn = torch.round(yn_.flatten(),decimals=2)
X40 = torch.hstack((xn[:,None],yn[:,None])).double()
Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()

key = jax.random.PRNGKey(0)
num_regions = 2

def predict_structure(post_samples, model, *args, **kwargs):
    """Run ``model`` conditioned on ``post_samples`` and read the structure sites.

    The model is conditioned on the given sample values, seeded with a fixed
    PRNG key (so repeated calls give the same result), and traced once with
    ``*args``/``**kwargs``.

    Returns:
        Tuple of the traced values at sites ``gpc_new_probs`` and
        ``gpc_new_latent``.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    probs = trace["gpc_new_probs"]["value"]
    latent = trace["gpc_new_latent"]["value"]
    return probs, latent

def predict_sage(post_samples, model, *args, **kwargs):
    """Run ``model`` conditioned on ``post_samples`` and read the prediction sites.

    The model is conditioned on the given sample values, seeded with a fixed
    PRNG key, and traced once with ``*args``/``**kwargs``.

    Returns:
        Tuple of the traced values at sites ``Fr_new``, ``f_piecewise``,
        ``f_sample``, ``gpc_new_probs`` and ``v_piecewise`` (in that order).
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

# Structure-only predictor: evaluates predict_SAGE_ND_PM_230628a on the Xnew_ grid
# for one set of sample values.
predict_fn_structure = lambda samples: predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )
# Non-pmapped joint predictor (data passed positionally); it is not called in this
# cell, which builds a jax.pmap version further down instead.
predict_fn_sage_1core = lambda samples: predict_sage(
        samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
    )
def subsample(samples, step):
    """Thin every entry of ``samples`` by keeping each ``step``-th element.

    Returns a new dict with the same keys; each value is ``value[::step]``.
    """
    return {name: values[::step] for name, values in samples.items()}

def split_samples(samples, num_proc, length):
    """Cut ``samples`` into consecutive batches of ``num_proc`` draws.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        num_proc: number of draws per batch.
        length: number of draws available in every entry of ``samples``.

    Returns:
        A list of ``length // num_proc`` independent dicts; batch ``i`` holds
        ``value[i*num_proc:(i+1)*num_proc]`` for every key. Trailing draws
        that do not fill a whole batch are dropped.
    """
    splits = int(length // num_proc)
    # Build a fresh dict per batch: reusing one dict across iterations would
    # make every list entry alias the same object (i.e. the last batch).
    return [
        {name: values[i * num_proc:(i + 1) * num_proc] for name, values in samples.items()}
        for i in range(splits)
    ]

def get_samples_split(samples, num_proc, length, i):
    """Return the ``i``-th batch of ``num_proc`` consecutive draws for every key.

    ``length`` is accepted but not used.
    """
    start = i * num_proc
    stop = start + num_proc
    return {name: values[start:stop] for name, values in samples.items()}


# ------------------------

# Positional arguments forwarded to the structure-only SVI model below.
data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
         jnp.asarray([1.,2.], dtype=jnp.float64)]

key = jax.random.PRNGKey(0)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.05)

# Step 1: fit the structure-only model with SVI and take the guide median.
svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

params = svi_result.params
mle_2a_st = autoguide_mle.median(params)
preds_st = predict_fn_structure(mle_2a_st)

# NOTE: repeats the call that produced preds_st; the unpacked names are not used later in this cell.
gpc_new_probs_, gpc_new_latent_ = predict_fn_structure(mle_2a_st)

gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

preds_fp = None

# !!!!!!!!!!!!!!!!!!!!!!!!

# Step 2: start NUTS on the joint model from the SVI median of the gpc_* sites.
init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'],
               'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
               'gpc_bias': mle_2a_st['gpc_bias']}
init_strategy=init_to_value(values=init_params)

# The timer spans NUTS sampling and the posterior-prediction loop.
tic = time.perf_counter()
nuts = nMCMC(nNUTS(model_SAGE_ND_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
            num_samples=2000, num_warmup=100, num_chains=1)
nuts.run(key, xs, ys, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
         gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill
# with open(r"2D_2an_100chains_4E3samples_2init_231011a.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

# Step 3: keep every 10th draw, then cut the remainder into batches of num_proc.
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step = 10)
print('after subsampling', samples['gpr_noise'].shape[0]) 

num_length = samples['gpr_noise'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

splits = np.array(num_length / num_proc).astype(int)

# Each batch is mapped over its leading axis by jax.pmap.
predict_fn_sage = jax.pmap(
    lambda samples: predict_sage(
        samples, predict_SAGE_ND_230628a, Xnew=Xnew_, xs=xs, ys=ys, xf=xf, yf=yf, num_regions=num_regions
    ), axis_name = 0
)

print('starting pred analysis, for #', num_length)
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

# Step 4: predict batch by batch; the first batch initializes the accumulators
# and later batches are stacked row-wise onto them.
for i in trange(splits):
    if i == 0:
        preds = predict_fn_sage(sl[i])
        preds_stacked = {labels[0]:preds[0].squeeze(), labels[1]:preds[1].squeeze(), labels[2]:preds[2].squeeze(), labels[3]:preds[3].squeeze(), labels[4]:preds[4].squeeze()}
    else:
        preds = predict_fn_sage(sl[i])
        for j in range(len(labels)):
            preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")    

print('done pred analysis')

# Step 5: save predictions together with the data they were computed from.
output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
with open(r"2D_2an_matern52_N41_1core_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

##### Challenge 1: SAGE-ND-FP, multicore

In [None]:
%%writefile sage_2D_2an_fp_matern52_230804a.py

from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Request 200 host (CPU) devices from numpyro; executed at import time, before the main block.
numpyro.set_host_device_count(200)


if __name__ == '__main__':
    # Number of posterior draws per prediction batch (see split_samples / jax.pmap below).
    num_proc = 100

    
    # N x N grid over [-2, 2]^2, flattened to one (x, y) row per grid point.
    N = 41
    xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
    xp = torch.round(xp_.flatten(),decimals=2)
    yp = torch.round(yp_.flatten(),decimals=2)
    Xp = torch.hstack((xp[:,None],yp[:,None])).double()
    # First output of gen_data_2D_example; indexed below to obtain the structure labels ys.
    Lp, _ = gen_data_2D_example(xp,yp)

    # Distance of every grid point from the origin (not referenced again in this script).
    r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

    # NOTE: dill.load can execute arbitrary code; only load files from a trusted source.
    with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
        f = dill.load(input_file)


    # Sizes of the random index draws for structure (st) and functional-property (fp) points.
    num_data_points_st = 20
    num_data_points_fp = 60

    seed = 0
    # Hand-picked grid indices that are added to the structure measurement set below.
    top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                        1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
    bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
    # Random grid indices, drawn without replacement with a fixed seed.
    extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

    def map_indices(Xp, idx):
        """Shift grid indices, in place, until both scaled coordinates are even.

        For each entry: if 10 * (first coordinate) is not a multiple of 2 the
        index is advanced by 1; then, using the updated index, if
        10 * (second coordinate) is not a multiple of 2 it is advanced by 41
        (the value of N used to build the grid in this script).

        ``idx`` is modified in place and the same tensor is returned.
        """
        for pos in range(idx.shape[0]):
            x_scaled = 10 * Xp[idx[pos], 0]
            if x_scaled % 2:
                idx[pos] = idx[pos] + 1
            # Re-read the index: the y test must use the already-shifted position.
            y_scaled = 10 * Xp[idx[pos], 1]
            if y_scaled % 2:
                idx[pos] += 41
        return idx

    # Structure measurement indices: snapped random draw plus the hand-picked sets, de-duplicated.
    extra = map_indices(Xp, extra)
    kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

    # Functional-property indices: snapped random draw, keeping only points with
    # radius > 1.5 or < 1. Unlike kp_st, kp_fp is not passed through torch.unique.
    kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
    kp_fp = map_indices(Xp, kp_fp)
    temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
    kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

    # xs/ys: structure inputs and labels; xf/yf: functional-property inputs and
    # column 0 of f as a column vector.
    xs = Xp[kp_st,:].double()
    ys = Lp[kp_st].double()
    xf = Xp[kp_fp,:].double()
    yf = f[kp_fp,0][:,None].double()
    # yf += torch.normal(torch.zeros(yf.shape),.01)
    # Torch versions of the data, saved alongside the predictions at the end of the script.
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    xs_2a = xs.clone()
    ys_2a = ys.clone()
    xf_2a = xf.clone()
    yf_2a = yf.clone()

    # Rebind xs/ys/xf/yf to JAX arrays for numpyro.
    xs = jnp.asarray( xs_2a.detach().numpy(), dtype=jnp.float64).copy()
    ys = jnp.asarray( ys_2a.detach().numpy(), dtype=jnp.integer).copy()
    xf = jnp.asarray( xf_2a.detach().numpy(), dtype=jnp.float64).copy()
    yf = jnp.asarray( yf_2a.detach().numpy(), dtype=jnp.float64).copy()

    Ns = xs.shape[0] + xf.shape[0]

    # Prediction grid: Nn x Nn points over [-2, 2]^2 (Nn differs from the data grid's N).
    Nn = 40
    xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
    xn = torch.round(xn_.flatten(),decimals=2)
    yn = torch.round(yn_.flatten(),decimals=2)
    X40 = torch.hstack((xn[:,None],yn[:,None])).double()
    Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 2


    def predict_fp(post_samples, model, *args, **kwargs):
        """Run ``model`` conditioned on ``post_samples`` and read the prediction sites.

        The model is conditioned on the given sample values, seeded with a fixed
        PRNG key, and traced once with ``*args``/``**kwargs``.

        Returns:
            Tuple of the traced values at sites ``Fr_new``, ``f_piecewise``,
            ``f_sample``, ``gpc_new_probs`` and ``v_piecewise`` (in that order).
        """
        rng_key = jax.random.PRNGKey(0)
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        return tuple(trace[name]["value"] for name in site_names)

    def subsample(samples, step):
        """Thin every entry of ``samples`` by keeping each ``step``-th element.

        Returns a new dict with the same keys; each value is ``value[::step]``.
        """
        return {name: values[::step] for name, values in samples.items()}

    def split_samples(samples, num_proc, length):
        """Cut ``samples`` into consecutive batches of ``num_proc`` draws.

        Args:
            samples: mapping from site name to a sliceable sequence of draws.
            num_proc: number of draws per batch.
            length: number of draws available in every entry of ``samples``.

        Returns:
            A list of ``length // num_proc`` independent dicts; batch ``i`` holds
            ``value[i*num_proc:(i+1)*num_proc]`` for every key. Trailing draws
            that do not fill a whole batch are dropped.
        """
        splits = int(length // num_proc)
        # Build a fresh dict per batch: reusing one dict across iterations would
        # make every list entry alias the same object (i.e. the last batch).
        return [
            {name: values[i * num_proc:(i + 1) * num_proc] for name, values in samples.items()}
            for i in range(splits)
        ]

    def get_samples_split(samples, num_proc, length, i):
        """Return the ``i``-th batch of ``num_proc`` consecutive draws for every key.

        ``length`` is accepted but not used.
        """
        start = i * num_proc
        stop = start + num_proc
        return {name: values[start:stop] for name, values in samples.items()}


    # !!!!!!!!!!!!!!!!!!!!!!!!
    
    # The script's import block does not bring in predict_SAGE_ND_FP_230628a, so the
    # prediction lambda below would raise NameError only after MCMC has finished.
    # Import it up front so a missing name fails fast.
    # NOTE(review): assumes sage_2D_functions_230804a defines it alongside the other
    # predict_* functions; confirm.
    from sage_2D_functions_230804a import predict_SAGE_ND_FP_230628a

    # The timer spans NUTS sampling and the posterior-prediction loop.
    tic = time.perf_counter()
    nuts = nMCMC(nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5),
                num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(key, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
             gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
             gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()
    
    import dill
    # with open(r"2D_temp.dill", "wb") as output_file:
    #     dill.dump(nuts_posterior_samples, output_file)

    # Keep every 10th draw, then cut the remainder into batches of num_proc.
    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpr_noise'].shape[0]) 

    num_length = samples['gpr_noise'].shape[0]
    
    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')
    
    splits = np.array(num_length / num_proc).astype(int)
   
    # Each batch of num_proc draws is mapped over its leading axis by jax.pmap.
    predict_fn_sage = jax.pmap(
        lambda samples: predict_fp(
            samples, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
        ), axis_name = 0
    )

    print('starting pred analysis, for #', num_length)
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    # The first batch initializes the accumulators; later batches are stacked row-wise.
    for i in trange(splits):
        preds = predict_fn_sage(sl[i])
        if i == 0:
            preds_stacked = {label: pred.squeeze() for label, pred in zip(labels, preds)}
        else:
            for label, pred in zip(labels, preds):
                preds_stacked[label] = np.vstack((preds_stacked[label], pred.squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")    
                
    print('done pred analysis')
    
    output = {'preds': preds_stacked, 'starting_data':starting_data}
    with open(r"2D_2an_fp_matern52_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 1: SAGE-ND-FP, 1 core

In [None]:
%%writefile sage_2D_2an_fp_matern52_1core_230804a.py

from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# NOTE(review): this cell is titled "1 core" yet requests 200 host devices, whereas
# the SAGE-ND 1-core cell makes no such call. Possibly a copy-paste leftover; confirm.
numpyro.set_host_device_count(200)

# Number of posterior draws per prediction batch (see split_samples below).
num_proc = 1

# N x N grid over [-2, 2]^2, flattened to one (x, y) row per grid point.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# First output of gen_data_2D_example; indexed below to obtain the structure labels ys.
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin (not referenced again in this cell).
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

# NOTE: dill.load can execute arbitrary code; only load files from a trusted source.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)


# Sizes of the random index draws for structure (st) and functional-property (fp) points.
num_data_points_st = 20
num_data_points_fp = 60

seed = 0
# Hand-picked grid indices that are added to the structure measurement set below.
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
# Random grid indices, drawn without replacement with a fixed seed.
extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

def map_indices(Xp, idx):
    """Shift grid indices, in place, until both scaled coordinates are even.

    For each entry: if 10 * (first coordinate) is not a multiple of 2 the index
    is advanced by 1; then, using the updated index, if 10 * (second
    coordinate) is not a multiple of 2 it is advanced by 41 (the value of N
    used to build the grid in this script).

    ``idx`` is modified in place and the same tensor is returned.
    """
    for pos in range(idx.shape[0]):
        x_scaled = 10 * Xp[idx[pos], 0]
        if x_scaled % 2:
            idx[pos] = idx[pos] + 1
        # Re-read the index: the y test must use the already-shifted position.
        y_scaled = 10 * Xp[idx[pos], 1]
        if y_scaled % 2:
            idx[pos] += 41
    return idx

# Structure measurement indices: snapped random draw plus the hand-picked sets, de-duplicated.
extra = map_indices(Xp, extra)
kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

# Functional-property indices: snapped random draw, keeping only points with
# radius > 1.5 or < 1. Unlike kp_st, kp_fp is not passed through torch.unique.
kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
kp_fp = map_indices(Xp, kp_fp)
temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

# xs/ys: structure inputs and labels; xf/yf: functional-property inputs and
# column 0 of f as a column vector.
xs = Xp[kp_st,:].double()
ys = Lp[kp_st].double()
xf = Xp[kp_fp,:].double()
yf = f[kp_fp,0][:,None].double()
# yf += torch.normal(torch.zeros(yf.shape),.01)
# Torch versions of the data, saved alongside the predictions at the end of the cell.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

xs_2a = xs.clone()
ys_2a = ys.clone()
xf_2a = xf.clone()
yf_2a = yf.clone()

# Rebind xs/ys/xf/yf to JAX arrays for numpyro.
xs = jnp.asarray( xs_2a.detach().numpy(), dtype=jnp.float64).copy()
ys = jnp.asarray( ys_2a.detach().numpy(), dtype=jnp.integer).copy()
xf = jnp.asarray( xf_2a.detach().numpy(), dtype=jnp.float64).copy()
yf = jnp.asarray( yf_2a.detach().numpy(), dtype=jnp.float64).copy()

Ns = xs.shape[0] + xf.shape[0]

# Prediction grid: Nn x Nn points over [-2, 2]^2 (Nn differs from the data grid's N).
Nn = 40
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = torch.round(xn_.flatten(),decimals=2)
yn = torch.round(yn_.flatten(),decimals=2)
X40 = torch.hstack((xn[:,None],yn[:,None])).double()
Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


key = jax.random.PRNGKey(0)
num_regions = 2


def predict_fp(post_samples, model, *args, **kwargs):
    """Run ``model`` conditioned on ``post_samples`` and read the prediction sites.

    The model is conditioned on the given sample values, seeded with a fixed
    PRNG key, and traced once with ``*args``/``**kwargs``.

    Returns:
        Tuple of the traced values at sites ``Fr_new``, ``f_piecewise``,
        ``f_sample``, ``gpc_new_probs`` and ``v_piecewise`` (in that order).
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

def subsample(samples, step):
    """Thin every entry of ``samples`` by keeping each ``step``-th element.

    Returns a new dict with the same keys; each value is ``value[::step]``.
    """
    return {name: values[::step] for name, values in samples.items()}

def split_samples(samples, num_proc, length):
    """Cut ``samples`` into consecutive batches of ``num_proc`` draws.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        num_proc: number of draws per batch.
        length: number of draws available in every entry of ``samples``.

    Returns:
        A list of ``length // num_proc`` independent dicts; batch ``i`` holds
        ``value[i*num_proc:(i+1)*num_proc]`` for every key. Trailing draws
        that do not fill a whole batch are dropped.
    """
    splits = int(length // num_proc)
    # Build a fresh dict per batch: reusing one dict across iterations would
    # make every list entry alias the same object (i.e. the last batch).
    return [
        {name: values[i * num_proc:(i + 1) * num_proc] for name, values in samples.items()}
        for i in range(splits)
    ]

def get_samples_split(samples, num_proc, length, i):
    """Return the ``i``-th batch of ``num_proc`` consecutive draws for every key.

    ``length`` is accepted but not used.
    """
    start = i * num_proc
    stop = start + num_proc
    return {name: values[start:stop] for name, values in samples.items()}


# !!!!!!!!!!!!!!!!!!!!!!!!

# The cell's import block does not bring in predict_SAGE_ND_FP_230628a, so the
# prediction lambda below would raise NameError only after MCMC has finished.
# Import it up front so a missing name fails fast.
# NOTE(review): assumes sage_2D_functions_230804a defines it alongside the other
# predict_* functions; confirm.
from sage_2D_functions_230804a import predict_SAGE_ND_FP_230628a

# The timer spans NUTS sampling and the posterior-prediction loop.
tic = time.perf_counter()
# NOTE(review): this "1 core" cell keeps num_chains=100, while the SAGE-ND 1-core
# cell uses num_chains=1. Confirm which setting is intended.
nuts = nMCMC(nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5),
            num_samples=2000, num_warmup=100, num_chains=100)
nuts.run(key, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
         gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill
# with open(r"2D_temp.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

# Keep every 10th draw, then cut the remainder into batches of num_proc.
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step = 10)
print('after subsampling', samples['gpr_noise'].shape[0]) 

num_length = samples['gpr_noise'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

splits = np.array(num_length / num_proc).astype(int)

# Plain (non-pmapped) predictor: each batch is passed to predict_fp as-is.
predict_fn_sage = lambda samples: predict_fp(
        samples, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
    )

print('starting pred analysis, for #', num_length)
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

# The first batch initializes the accumulators; later batches are stacked row-wise.
for i in trange(splits):
    preds = predict_fn_sage(sl[i])
    if i == 0:
        preds_stacked = {label: pred.squeeze() for label, pred in zip(labels, preds)}
    else:
        for label, pred in zip(labels, preds):
            preds_stacked[label] = np.vstack((preds_stacked[label], pred.squeeze()))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")    

print('done pred analysis')

output = {'preds': preds_stacked, 'starting_data':starting_data}
with open(r"2D_2an_fp_1core_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

##### Challenge 1: SAGE-ND-PM, multicore

In [None]:
%%writefile sage_2D_2an_structure_matern_with_1init_230804a.py

from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)



if __name__ == '__main__':
    # Batch size used later when the posterior samples are split for jax.pmap
    # (passed to split_samples further down in this script).
    num_proc = 100
    
    # 41 x 41 grid on [-2, 2]^2, flattened row-major (index = row*41 + col) and
    # rounded to 2 decimals so the coordinates are clean multiples of 0.1.
    N = 41
    xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
    xp = torch.round(xp_.flatten(),decimals=2)
    yp = torch.round(yp_.flatten(),decimals=2)
    Xp = torch.hstack((xp[:,None],yp[:,None])).double()
    # presumably Lp holds the per-point structure/region labels — confirm in gen_data_2D_example
    Lp, _ = gen_data_2D_example(xp,yp)

    # Distance of every grid point from the origin.
    r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

    # Pre-computed functional-property values; indexed below as f[kp_fp, 0].
    # NOTE(review): dill.load executes arbitrary code — only load trusted files.
    with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
        f = dill.load(input_file)


    num_data_points_st = 20
    num_data_points_fp = 60

    seed = 0
    # Hand-picked flat grid indices that are always added to the structure
    # measurements (see kp_st below).
    top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                        1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
    bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
    # Additional structure locations drawn without replacement from the grid.
    extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

    def map_indices(Xp, idx):
        """Snap flat grid indices onto the every-other-point sub-grid, in place.

        For each entry of `idx`: if 10 * x at that grid point is not a multiple
        of 2 the index is advanced by 1; then, using the updated index, if
        10 * y is not a multiple of 2 the index is advanced by 41.

        `idx` is modified in place and also returned.
        """
        row_stride = 41
        for pos in range(idx.shape[0]):
            x_tenths = 10 * Xp[idx[pos], 0]
            if x_tenths % 2:
                idx[pos] = idx[pos] + 1
            # Re-read the (possibly shifted) index before checking y.
            y_tenths = 10 * Xp[idx[pos], 1]
            if y_tenths % 2:
                idx[pos] += row_stride
        return idx

    # Structure measurement indices: snapped random picks + the hand-picked
    # index tables + three extra fixed points, de-duplicated (torch.unique also sorts).
    extra = map_indices(Xp, extra)
    kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

    # Functional-property indices: snapped random picks, keeping only points with
    # radius > 1.5 or < 1.0 (the ring 1.0 <= r <= 1.5 is left unmeasured).
    kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
    kp_fp = map_indices(Xp, kp_fp)
    temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
    kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

    xs = Xp[kp_st,:].double()
    ys = Lp[kp_st].double()
    xf = Xp[kp_fp,:].double()
    yf = f[kp_fp,0][:,None].double()
    # yf += torch.normal(torch.zeros(yf.shape),.01)
    # Saved with the results; note these are the torch tensors (xs..yf are
    # rebound to jax arrays below).
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    xs_2a = xs.clone()
    ys_2a = ys.clone()
    xf_2a = xf.clone()
    yf_2a = yf.clone()

    # Convert to jax arrays for numpyro.
    # NOTE(review): jnp.integer is an abstract scalar type; confirm the dtype it resolves to.
    xs = jnp.asarray( xs_2a.detach().numpy(), dtype=jnp.float64).copy()
    ys = jnp.asarray( ys_2a.detach().numpy(), dtype=jnp.integer).copy()
    xf = jnp.asarray( xf_2a.detach().numpy(), dtype=jnp.float64).copy()
    yf = jnp.asarray( yf_2a.detach().numpy(), dtype=jnp.float64).copy()

    Ns = xs.shape[0] + xf.shape[0]

    # Prediction grid: same 41 x 41 layout as Xp.
    Nn = 41
    xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
    xn = torch.round(xn_.flatten(),decimals=2)
    yn = torch.round(yn_.flatten(),decimals=2)
    X40 = torch.hstack((xn[:,None],yn[:,None])).double()
    Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 2

    def predict_structure(post_samples, model, *args, **kwargs):
        """Run `model` once with `post_samples` substituted for its sample sites.

        The model is conditioned on `post_samples`, seeded with a fixed PRNG key
        (so repeated calls are deterministic) and traced; `*args` / `**kwargs`
        are forwarded to the model call.

        Returns:
            Tuple with the traced values of the "gpc_new_probs" and
            "gpc_new_latent" sites.
        """
        key = jax.random.PRNGKey(0)
        model = handlers.seed(handlers.condition(model, post_samples), key)
        model_trace = handlers.trace(model).get_trace(*args, **kwargs)
        return model_trace["gpc_new_probs"]["value"], model_trace["gpc_new_latent"]["value"]

    def predict_sage(post_samples, model, *args, **kwargs):
        """Same as `predict_structure`, but returns the full set of SAGE sites.

        Returns:
            Tuple of the traced values of "Fr_new", "f_piecewise", "f_sample",
            "gpc_new_probs" and "v_piecewise", in that order.
        """
        key = jax.random.PRNGKey(0)
        model = handlers.seed(handlers.condition(model, post_samples), key)
        model_trace = handlers.trace(model).get_trace(*args, **kwargs)
        sites = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        return tuple(model_trace[name]["value"] for name in sites)

    # FIX: the original referenced predict_model_joint_*_numpyro_230628a, which
    # this script never imports (NameError). Use the imported predict_SAGE_*
    # functions, as the 1-core version of this cell does.
    predict_fn_structure = lambda samples: predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        )
    predict_fn_sage_1core = lambda samples: predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
        )

    def subsample(samples, step):
        """Return a new dict keeping every `step`-th entry of each value in `samples`."""
        return {k: v[::step] for k, v in samples.items()}

    def split_samples(samples, num_proc, length):
        """Split `samples` into consecutive batches of `num_proc` entries.

        Args:
            samples: dict mapping site name to a sliceable array of draws.
            num_proc: batch size (leading dimension handed to jax.pmap).
            length: number of draws available; a trailing partial batch is dropped.

        Returns:
            List of `int(length / num_proc)` dicts, one per batch.
        """
        splits = int(length / num_proc)
        sample_list = []
        for i in trange(splits):
            lo, hi = i * num_proc, (i + 1) * num_proc
            # FIX: build a fresh dict per batch. The original reused one dict,
            # so every list entry aliased the last slice.
            sample_list.append({k: v[lo:hi] for k, v in samples.items()})
        return sample_list

    def get_samples_split(samples, num_proc, length, i):
        """Return only the `i`-th batch of `num_proc` draws (`length` is unused)."""
        return {k: v[(i*num_proc):((i+1)*num_proc)] for k, v in samples.items()}

    # -------------------------------------------------
    # Stage 1: variational fit of the structure-only model; its median is used
    # both as a point prediction and to initialise NUTS.
    # NOTE(review): the last two entries are presumably gpc_var_bounds and
    # gpc_ls_bounds passed positionally; the lengthscale bounds here ([1, 2])
    # differ from those given to NUTS below ([.1, 2]) — confirm this is intended.
    data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64),
             jnp.asarray([1.,2.], dtype=jnp.float64)]

    key = jax.random.PRNGKey(0)
    # FIX: model_joint_structure_ND_matern52_numpyro_230628a is not imported;
    # model_SAGE_ND_PM_230628a is the imported structure-only model.
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.05)

    svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 100000, *data)

    params = svi_result.params
    mle_2a_st = autoguide_mle.median(params)
    preds_st = predict_fn_structure(mle_2a_st)

    # The prediction is deterministic (fixed PRNG key), so reuse it instead of
    # running the model a second time.
    gpc_new_probs_, gpc_new_latent_ = preds_st

    gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

    preds_fp = None

    # Stage 2: NUTS, started from the SVI median.
    init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'],
                   'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
                   'gpc_bias': mle_2a_st['gpc_bias']}
    init_strategy=init_to_value(values=init_params)

    tic = time.perf_counter()
    nuts = nMCMC(nNUTS(model_SAGE_ND_PM_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
                num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(key, xs, ys, xf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
             gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()

    with open(r"2D_2an_structure.dill", "wb") as output_file:
        dill.dump(nuts_posterior_samples, output_file)

    # Stage 3: thin the chain and predict on the grid, num_proc draws at a time.
    print('start', nuts_posterior_samples['gpc_bias'].shape[0])
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpc_bias'].shape[0])

    num_length = samples['gpc_bias'].shape[0]

    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')

    splits = num_length // num_proc

    predict_fn_st_multicore = jax.pmap(
        lambda samples: predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        ), axis_name = 0
    )

    print('starting pred analysis, for #', num_length)
    labels = ['gpc_new_probs']

    for i in trange(splits):
        preds = predict_fn_st_multicore(sl[i])
        if i == 0:
            preds_stacked = {label: preds[j].squeeze() for j, label in enumerate(labels)}
        else:
            for j, label in enumerate(labels):
                preds_stacked[label] = np.vstack((preds_stacked[label], preds[j].squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")

    print('done pred analysis')

    output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
    with open(r"2D_2an_structure_matern52_N41_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 1: SAGE-ND-PM, 1 core

In [None]:
%%writefile sage_2D_2an_structure_matern_with_1init_230804a.py

from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)


# NOTE(review): this cell is labelled "1 core" but keeps num_proc = 100, and its
# %%writefile target is the same filename as the multicore cell (so it overwrites
# that script) — confirm both are intended.
num_proc = 100

# 41 x 41 grid on [-2, 2]^2, flattened row-major (index = row*41 + col) and
# rounded to 2 decimals so the coordinates are clean multiples of 0.1.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# presumably Lp holds the per-point structure/region labels — confirm in gen_data_2D_example
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin.
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

# Pre-computed functional-property values; indexed below as f[kp_fp, 0].
# NOTE(review): dill.load executes arbitrary code — only load trusted files.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)


num_data_points_st = 20
num_data_points_fp = 60

seed = 0
# Hand-picked flat grid indices that are always added to the structure
# measurements (see kp_st below).
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])
# Additional structure locations drawn without replacement from the grid.
extra = torch.tensor( default_rng(seed).choice(Xp.shape[0],num_data_points_st,replace=False) )

def map_indices(Xp, idx):
    """Snap flat grid indices onto the every-other-point sub-grid, in place.

    For each entry of `idx`: if 10 * x at that grid point is not a multiple
    of 2 the index is advanced by 1; then, using the updated index, if
    10 * y is not a multiple of 2 the index is advanced by 41.

    `idx` is modified in place and also returned.
    """
    row_stride = 41
    for pos in range(idx.shape[0]):
        x_tenths = 10 * Xp[idx[pos], 0]
        if x_tenths % 2:
            idx[pos] = idx[pos] + 1
        # Re-read the (possibly shifted) index before checking y.
        y_tenths = 10 * Xp[idx[pos], 1]
        if y_tenths % 2:
            idx[pos] += row_stride
    return idx

# Structure measurement indices: snapped random picks + the hand-picked
# index tables + three extra fixed points, de-duplicated (torch.unique also sorts).
extra = map_indices(Xp, extra)
kp_st = torch.unique(torch.cat([extra, top, bottom, torch.tensor([902, 1680, 10])]))

# Functional-property indices: snapped random picks, keeping only points with
# radius > 1.5 or < 1.0 (the ring 1.0 <= r <= 1.5 is left unmeasured).
kp_fp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_fp] )
kp_fp = map_indices(Xp, kp_fp)
temp = torch.sqrt(Xp[kp_fp,0]**2 + Xp[kp_fp,1]**2)
kp_fp = kp_fp[torch.logical_or( temp > 1.5, temp < 1.)]

xs = Xp[kp_st,:].double()
ys = Lp[kp_st].double()
xf = Xp[kp_fp,:].double()
yf = f[kp_fp,0][:,None].double()
# yf += torch.normal(torch.zeros(yf.shape),.01)
# Saved with the results; note these are the torch tensors (xs..yf are
# rebound to jax arrays below).
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

xs_2a = xs.clone()
ys_2a = ys.clone()
xf_2a = xf.clone()
yf_2a = yf.clone()

# Convert to jax arrays for numpyro.
# NOTE(review): jnp.integer is an abstract scalar type; confirm the dtype it resolves to.
xs = jnp.asarray( xs_2a.detach().numpy(), dtype=jnp.float64).copy()
ys = jnp.asarray( ys_2a.detach().numpy(), dtype=jnp.integer).copy()
xf = jnp.asarray( xf_2a.detach().numpy(), dtype=jnp.float64).copy()
yf = jnp.asarray( yf_2a.detach().numpy(), dtype=jnp.float64).copy()

Ns = xs.shape[0] + xf.shape[0]

# Prediction grid: same 41 x 41 layout as Xp.
Nn = 41
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = torch.round(xn_.flatten(),decimals=2)
yn = torch.round(yn_.flatten(),decimals=2)
X40 = torch.hstack((xn[:,None],yn[:,None])).double()
Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


key = jax.random.PRNGKey(0)
num_regions = 2

def predict_structure(post_samples, model, *args, **kwargs):
    """Run `model` once with `post_samples` substituted for its sample sites.

    The model is conditioned on `post_samples`, seeded with a fixed PRNG key
    and traced; `*args` / `**kwargs` are forwarded to the model call.

    Returns:
        Tuple with the traced values of the "gpc_new_probs" and
        "gpc_new_latent" sites.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    probs = trace["gpc_new_probs"]["value"]
    latent = trace["gpc_new_latent"]["value"]
    return probs, latent

def predict_sage(post_samples, model, *args, **kwargs):
    """Run `model` once conditioned on `post_samples` and collect the SAGE outputs.

    The model is seeded with a fixed PRNG key and traced; `*args` / `**kwargs`
    are forwarded to the model call.

    Returns:
        Tuple of the traced values of "Fr_new", "f_piecewise", "f_sample",
        "gpc_new_probs" and "v_piecewise", in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

def predict_fn_structure(samples):
    """Structure-only prediction on the grid `Xnew_` for one set of samples."""
    return predict_structure(
        samples,
        predict_SAGE_ND_PM_230628a,
        Xnew=Xnew_,
        xs=xs,
        ys=ys,
        num_regions=num_regions,
    )


def predict_fn_sage_1core(samples):
    """Full SAGE prediction on the grid `Xnew_` for one set of samples."""
    return predict_sage(
        samples,
        predict_SAGE_ND_230628a,
        Xnew_,
        xs,
        ys,
        xf,
        yf,
        num_regions=num_regions,
    )
def subsample(samples, step):
    """Thin every entry of `samples` by keeping each `step`-th element.

    Returns a new dict with the same keys; the input dict is not modified.
    """
    return {name: samples[name][::step] for name in samples.keys()}

def split_samples(samples, num_proc, length):
    """Split `samples` into consecutive batches of `num_proc` entries.

    Args:
        samples: dict mapping site name to a sliceable array of draws.
        num_proc: batch size (leading dimension of each batch).
        length: number of draws available; a trailing partial batch is dropped.

    Returns:
        List of `int(length / num_proc)` dicts, one per batch.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        lo, hi = i * num_proc, (i + 1) * num_proc
        # FIX: build a fresh dict per batch. The original reused one dict, so
        # every list entry aliased the last slice.
        sample_list.append({k: v[lo:hi] for k, v in samples.items()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return the `i`-th batch of `num_proc` consecutive draws from `samples`.

    `length` is accepted for signature compatibility but is not used.
    """
    start = i * num_proc
    stop = (i + 1) * num_proc
    return {name: samples[name][start:stop] for name in samples.keys()}



# -------------------------------------------------



# --- Stage 1: variational fit of the structure-only model --------------------
# NOTE(review): the last two entries are presumably gpc_var_bounds and
# gpc_ls_bounds passed positionally — confirm against model_SAGE_ND_PM_230628a.
data = [
    xs,
    ys,
    xf,
    num_regions,
    jnp.asarray([5., 10.], dtype=jnp.float64),
    jnp.asarray([1., 2.], dtype=jnp.float64),
]

key = jax.random.PRNGKey(0)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.05)

svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

# Median of the fitted guide: point prediction and NUTS starting point.
params = svi_result.params
mle_2a_st = autoguide_mle.median(params)
preds_st = predict_fn_structure(mle_2a_st)

gpc_new_probs_, gpc_new_latent_ = predict_fn_structure(mle_2a_st)

gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'], mle_2a_st['gpc_latent_1']))

preds_fp = None

# --- Stage 2: NUTS initialised at the SVI median -----------------------------
init_params = {
    site: mle_2a_st[site]
    for site in ('gpc_latent_0', 'gpc_latent_1', 'gpc_var', 'gpc_lengthscale', 'gpc_bias')
}
init_strategy = init_to_value(values=init_params)

tic = time.perf_counter()
kernel = nNUTS(
    model_SAGE_ND_PM_230628a,
    target_accept_prob=0.8,
    max_tree_depth=5,
    init_strategy=init_strategy,
)
nuts = nMCMC(kernel, num_samples=2000, num_warmup=100, num_chains=100)
nuts.run(
    key, xs, ys, xf, num_regions,
    gpc_var_bounds=jnp.asarray([5., 10.], dtype=jnp.float64),
    gpc_ls_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
)

nuts_posterior_samples = nuts.get_samples()

import dill
with open(r"2D_2an_structure.dill", "wb") as output_file:
    dill.dump(nuts_posterior_samples, output_file)

# --- Stage 3: thin the chain and predict batch by batch ----------------------
print('start', nuts_posterior_samples['gpc_bias'].shape[0])
samples = subsample(nuts_posterior_samples, step=10)
print('after subsampling', samples['gpc_bias'].shape[0])

num_length = samples['gpc_bias'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

splits = np.array(num_length / num_proc).astype(int)


def _predict_structure_batch(samples):
    """Structure prediction for a single draw; mapped over a batch by pmap."""
    return predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )


predict_fn_st_multicore = jax.pmap(_predict_structure_batch, axis_name=0)

print('starting pred analysis, for #', num_length)
labels = ['gpc_new_probs']

for i in trange(splits):
    preds = predict_fn_st_multicore(sl[i])
    if i == 0:
        # First batch seeds the accumulator.
        preds_stacked = {labels[0]: preds[0].squeeze()}
        continue
    for j, label in enumerate(labels):
        preds_stacked[label] = np.vstack((preds_stacked[label], preds[j].squeeze()))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

output = {
    'preds': preds_stacked,
    'preds_st': preds_st,
    'preds_fp': preds_fp,
    'starting_data': starting_data,
}
with open(r"2D_2an_structure_1core_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

##### Challenge 2: SAGE-ND, multicore

In [None]:
%%writefile sage_2D_2bn_matern52_with_init_230804a.py
# Unified for init 2b
from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

# 41 x 41 grid on [-2, 2]^2, flattened row-major (index = row*41 + col) and
# rounded to 2 decimals so the coordinates are clean multiples of 0.1.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# presumably Lp holds the per-point structure/region labels — confirm in gen_data_2D_example
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin.
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

# Pre-computed functional-property values; indexed below as f[kp_fp, 0].
# NOTE(review): dill.load executes arbitrary code — only load trusted files.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)


def map_indices(Xp, idx):
    """Snap flat grid indices onto the every-other-point sub-grid, in place.

    For each entry of `idx`: if 10 * x at that grid point is not a multiple
    of 2 the index is advanced by 1; then, using the updated index, if
    10 * y is not a multiple of 2 the index is advanced by 41.

    `idx` is modified in place and also returned.
    """
    row_stride = 41
    for pos in range(idx.shape[0]):
        x_tenths = 10 * Xp[idx[pos], 0]
        if x_tenths % 2:
            idx[pos] = idx[pos] + 1
        # Re-read the (possibly shifted) index before checking y.
        y_tenths = 10 * Xp[idx[pos], 1]
        if y_tenths % 2:
            idx[pos] += row_stride
    return idx

seed = 0
num_data_points_st = 60
num_data_points_fp = 40

# Hand-picked flat grid indices; in this challenge they are part of the
# functional-property measurements (see kp_fp below).
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

# Random structure locations (filtered by radius further down).
kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
# With index = row*41 + col these strides trace, at every other grid point, a
# horizontal line (row 20), a vertical line (col 20) and the two diagonals.
lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

# Functional-property indices: the lines plus the hand-picked tables, de-duplicated.
kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

# Keep structure points with radius > 1.4 or < 0.5, then add seven fixed points.
temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

xs_2b = Xp[kp_st,:].double().clone()
ys_2b = Lp[kp_st].double().clone()
xf_2b = Xp[kp_fp,:].double().clone()
yf_2b = f[kp_fp,0][:,None].double().clone()

# Convert to jax arrays for numpyro.
# NOTE(review): jnp.integer is an abstract scalar type; confirm the dtype it resolves to.
xs = jnp.asarray( xs_2b.detach().numpy(), dtype=jnp.float64)
ys = jnp.asarray( ys_2b.detach().numpy(), dtype=jnp.integer)
xf = jnp.asarray( xf_2b.detach().numpy(), dtype=jnp.float64)
yf = jnp.asarray( yf_2b.detach().numpy(), dtype=jnp.float64)

# Saved with the results (here xs..yf are already the jax arrays).
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# Prediction grid: 41 x 41 on [-2, 2]^2 (not rounded, unlike Xp).
Nn = 41
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = xn_.flatten()
yn = yn_.flatten()
X40 = torch.hstack((xn[:,None],yn[:,None])).double()

Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64)

Ns = xs.shape[0] + xf.shape[0]

key = jax.random.PRNGKey(0)
num_regions = 2

def predict_st(post_samples, model, *args, **kwargs):
    """Run `model` once conditioned on `post_samples` and return "gpc_new_probs".

    The model is seeded with a fixed PRNG key and traced; `*args` / `**kwargs`
    are forwarded to the model call.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return trace["gpc_new_probs"]["value"]

def predict_sage(post_samples, model, *args, **kwargs):
    """Run `model` once conditioned on `post_samples` and collect the SAGE outputs.

    The model is seeded with a fixed PRNG key and traced; `*args` / `**kwargs`
    are forwarded to the model call.

    Returns:
        Tuple of the traced values of "Fr_new", "f_piecewise", "f_sample",
        "gpc_new_probs" and "v_piecewise", in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

def predict_fn_st(samples):
    """Structure-only prediction on the grid `Xnew_` for one set of samples."""
    return predict_st(
        samples,
        predict_SAGE_ND_PM_230628a,
        Xnew=Xnew_,
        xs=xs,
        ys=ys,
        num_regions=num_regions,
    )


def _predict_sage_single(samples):
    """Full SAGE prediction for a single draw; mapped over a batch by pmap."""
    return predict_sage(
        samples,
        predict_SAGE_ND_230628a,
        Xnew=Xnew_,
        xs=xs,
        ys=ys,
        xf=xf,
        yf=yf,
        num_regions=num_regions,
    )


# Batched version: the leading axis of every entry in `samples` is mapped by jax.pmap.
predict_fn_sage = jax.pmap(_predict_sage_single, axis_name=0)

def subsample(samples, step):
    """Thin every entry of `samples` by keeping each `step`-th element.

    Returns a new dict with the same keys; the input dict is not modified.
    """
    return {name: samples[name][::step] for name in samples.keys()}

def split_samples(samples, num_proc, length):
    """Split `samples` into consecutive batches of `num_proc` entries.

    Args:
        samples: dict mapping site name to a sliceable array of draws.
        num_proc: batch size (leading dimension of each batch).
        length: number of draws available; a trailing partial batch is dropped.

    Returns:
        List of `int(length / num_proc)` dicts, one per batch.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        lo, hi = i * num_proc, (i + 1) * num_proc
        # FIX: build a fresh dict per batch. The original reused one dict, so
        # every list entry aliased the last slice.
        sample_list.append({k: v[lo:hi] for k, v in samples.items()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return the `i`-th batch of `num_proc` consecutive draws from `samples`.

    `length` is accepted for signature compatibility but is not used.
    """
    start = i * num_proc
    stop = (i + 1) * num_proc
    return {name: samples[name][start:stop] for name in samples.keys()}

if __name__ == '__main__':
    tic = time.perf_counter()
    num_proc = 100

    # --- Stage 1: variational fit of the structure-only model ---------------
    # NOTE(review): the last two entries are presumably gpc_var_bounds and
    # gpc_ls_bounds passed positionally — confirm against model_SAGE_ND_PM_230628a.
    data = [
        xs,
        ys,
        xf,
        num_regions,
        jnp.asarray([5., 10.], dtype=jnp.float64),
        jnp.asarray([1., 2.], dtype=jnp.float64),
    ]

    key = jax.random.PRNGKey(1)
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.01)

    svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 100000, *data)

    # Median of the fitted guide: point prediction and NUTS starting point.
    params = svi_result.params
    mle_2a_st = autoguide_mle.median(params)
    preds_st = predict_fn_st(mle_2a_st)

    # --- Stage 2: NUTS on the full model, started from the SVI median -------
    init_params = {
        site: mle_2a_st[site]
        for site in ('gpc_latent_0', 'gpc_latent_1', 'gpc_var', 'gpc_lengthscale', 'gpc_bias')
    }
    init_strategy = init_to_value(values=init_params)

    kernel = nNUTS(
        model_SAGE_ND_230628a,
        target_accept_prob=0.8,
        max_tree_depth=5,
        init_strategy=init_strategy,
    )
    nuts = nMCMC(kernel, num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(
        key, xs, ys, xf, yf, num_regions,
        gpc_var_bounds=jnp.asarray([5., 10.], dtype=jnp.float64),
        gpc_ls_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
        gpr_var_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
        gpr_ls_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
        gpr_noise_bounds=jnp.asarray([0.001, .01], dtype=jnp.float64),
    )

    nuts_posterior_samples = nuts.get_samples()

    import dill
    # (Optional dump of the raw posterior, left disabled:)
    # with open(r"2D_2b_100chains_4E3samples_matern52_with_init_230906a.dill", "wb") as output_file:
    #     dill.dump(nuts_posterior_samples, output_file)

    # --- Stage 3: thin the chain and predict batch by batch -----------------
    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step=100)
    print('after subsampling', samples['gpr_noise'].shape[0])

    num_length = samples['gpr_noise'].shape[0]

    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')

    splits = np.array(num_length / num_proc).astype(int)

    print('starting pred analysis, for #', num_length)
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    for i in trange(splits):
        preds = predict_fn_sage(sl[i])
        if i == 0:
            # First batch seeds the accumulator, one entry per output site.
            preds_stacked = {label: preds[j].squeeze() for j, label in enumerate(labels)}
            continue
        for j, label in enumerate(labels):
            preds_stacked[label] = np.vstack((preds_stacked[label], preds[j].squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")

    print('done pred analysis')

    output = {'preds': preds_stacked, 'preds_st': preds_st, 'starting_data': starting_data}

    with open(r"2D_2bn_matern52_N41_pred_init_230906a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 2: SAGE-ND, 1 core

In [None]:
%%writefile sage_2D_2bn_matern52_with_init_230804a.py
from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

# NOTE(review): this cell's %%writefile target is the same filename as the
# multicore Challenge 2 cell, so running both overwrites that script — confirm intended.
# 41 x 41 grid on [-2, 2]^2, flattened row-major (index = row*41 + col) and
# rounded to 2 decimals so the coordinates are clean multiples of 0.1.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# presumably Lp holds the per-point structure/region labels — confirm in gen_data_2D_example
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin.
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)


# Pre-computed functional-property values; indexed below as f[kp_fp, 0].
# NOTE(review): dill.load executes arbitrary code — only load trusted files.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)


def map_indices(Xp, idx):
    """Snap flat grid indices onto the every-other-point sub-grid, in place.

    For each entry of `idx`: if 10 * x at that grid point is not a multiple
    of 2 the index is advanced by 1; then, using the updated index, if
    10 * y is not a multiple of 2 the index is advanced by 41.

    `idx` is modified in place and also returned.
    """
    row_stride = 41
    for pos in range(idx.shape[0]):
        x_tenths = 10 * Xp[idx[pos], 0]
        if x_tenths % 2:
            idx[pos] = idx[pos] + 1
        # Re-read the (possibly shifted) index before checking y.
        y_tenths = 10 * Xp[idx[pos], 1]
        if y_tenths % 2:
            idx[pos] += row_stride
    return idx

seed = 0
num_data_points_st = 60
num_data_points_fp = 40

# Hand-picked flat grid indices; in this challenge they are part of the
# functional-property measurements (see kp_fp below).
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

# Random structure locations (filtered by radius further down).
kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
# With index = row*41 + col these strides trace, at every other grid point, a
# horizontal line (row 20), a vertical line (col 20) and the two diagonals.
lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

# Functional-property indices: the lines plus the hand-picked tables, de-duplicated.
kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

# Keep structure points with radius > 1.4 or < 0.5, then add seven fixed points.
temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

xs_2b = Xp[kp_st,:].double().clone()
ys_2b = Lp[kp_st].double().clone()
xf_2b = Xp[kp_fp,:].double().clone()
yf_2b = f[kp_fp,0][:,None].double().clone()

# Convert to jax arrays for numpyro.
# NOTE(review): jnp.integer is an abstract scalar type; confirm the dtype it resolves to.
xs = jnp.asarray( xs_2b.detach().numpy(), dtype=jnp.float64)
ys = jnp.asarray( ys_2b.detach().numpy(), dtype=jnp.integer)
xf = jnp.asarray( xf_2b.detach().numpy(), dtype=jnp.float64)
yf = jnp.asarray( yf_2b.detach().numpy(), dtype=jnp.float64)

# Saved with the results (here xs..yf are already the jax arrays).
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# Prediction grid: 41 x 41 on [-2, 2]^2 (not rounded, unlike Xp).
Nn = 41
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = xn_.flatten()
yn = yn_.flatten()
X40 = torch.hstack((xn[:,None],yn[:,None])).double()

Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64)

Ns = xs.shape[0] + xf.shape[0]

key = jax.random.PRNGKey(0)
num_regions = 2

def predict_st(post_samples, model, *args, **kwargs):
    """Run `model` once conditioned on `post_samples` and return "gpc_new_probs".

    The model is seeded with a fixed PRNG key and traced; `*args` / `**kwargs`
    are forwarded to the model call.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return trace["gpc_new_probs"]["value"]

def predict_sage(post_samples, model, *args, **kwargs):
    """Run `model` once conditioned on `post_samples` and collect the SAGE outputs.

    The model is seeded with a fixed PRNG key and traced; `*args` / `**kwargs`
    are forwarded to the model call.

    Returns:
        Tuple of the traced values of "Fr_new", "f_piecewise", "f_sample",
        "gpc_new_probs" and "v_piecewise", in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

def predict_fn_st(samples):
    """Structure-only prediction on the grid `Xnew_` for one set of samples."""
    return predict_st(
        samples,
        predict_SAGE_ND_PM_230628a,
        Xnew=Xnew_,
        xs=xs,
        ys=ys,
        num_regions=num_regions,
    )


def _predict_sage_single(samples):
    """Full SAGE prediction for a single draw; mapped over a batch by pmap."""
    return predict_sage(
        samples,
        predict_SAGE_ND_230628a,
        Xnew=Xnew_,
        xs=xs,
        ys=ys,
        xf=xf,
        yf=yf,
        num_regions=num_regions,
    )


# Batched version: the leading axis of every entry in `samples` is mapped by jax.pmap.
predict_fn_sage = jax.pmap(_predict_sage_single, axis_name=0)

def subsample(samples, step):
    """Thin every entry of `samples` by keeping each `step`-th element.

    Returns a new dict with the same keys; the input dict is not modified.
    """
    return {name: samples[name][::step] for name in samples.keys()}

def split_samples(samples, num_proc, length):
    """Split `samples` into consecutive batches of `num_proc` entries.

    Args:
        samples: dict mapping site name to a sliceable array of draws.
        num_proc: batch size (leading dimension of each batch).
        length: number of draws available; a trailing partial batch is dropped.

    Returns:
        List of `int(length / num_proc)` dicts, one per batch.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        lo, hi = i * num_proc, (i + 1) * num_proc
        # FIX: build a fresh dict per batch. The original reused one dict, so
        # every list entry aliased the last slice.
        sample_list.append({k: v[lo:hi] for k, v in samples.items()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return the `i`-th batch of `num_proc` consecutive draws from `samples`.

    `length` is accepted for signature compatibility but is not used.
    """
    start = i * num_proc
    stop = (i + 1) * num_proc
    return {name: samples[name][start:stop] for name in samples.keys()}

tic = time.perf_counter()
# One posterior draw per prediction batch.
# NOTE(review): the MCMC below still requests num_chains=100 — confirm that is
# what "1 core" is meant to cover.
num_proc = 1

# --- Stage 1: variational fit of the structure-only model --------------------
# NOTE(review): the last two entries are presumably gpc_var_bounds and
# gpc_ls_bounds passed positionally — confirm against model_SAGE_ND_PM_230628a.
data = [
    xs,
    ys,
    xf,
    num_regions,
    jnp.asarray([5., 10.], dtype=jnp.float64),
    jnp.asarray([1., 2.], dtype=jnp.float64),
]

key = jax.random.PRNGKey(1)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.01)

svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

# Median of the fitted guide: point prediction and NUTS starting point.
params = svi_result.params
mle_2a_st = autoguide_mle.median(params)
preds_st = predict_fn_st(mle_2a_st)

# --- Stage 2: NUTS on the full model, started from the SVI median ------------
init_params = {
    site: mle_2a_st[site]
    for site in ('gpc_latent_0', 'gpc_latent_1', 'gpc_var', 'gpc_lengthscale', 'gpc_bias')
}
init_strategy = init_to_value(values=init_params)

kernel = nNUTS(
    model_SAGE_ND_230628a,
    target_accept_prob=0.8,
    max_tree_depth=5,
    init_strategy=init_strategy,
)
nuts = nMCMC(kernel, num_samples=2000, num_warmup=100, num_chains=100)
nuts.run(
    key, xs, ys, xf, yf, num_regions,
    gpc_var_bounds=jnp.asarray([5., 10.], dtype=jnp.float64),
    gpc_ls_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
    gpr_var_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
    gpr_ls_bounds=jnp.asarray([.1, 2.], dtype=jnp.float64),
    gpr_noise_bounds=jnp.asarray([0.001, .01], dtype=jnp.float64),
)

nuts_posterior_samples = nuts.get_samples()

import dill

# --- Stage 3: thin the chain and predict batch by batch ----------------------
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step=100)
print('after subsampling', samples['gpr_noise'].shape[0])

num_length = samples['gpr_noise'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

splits = np.array(num_length / num_proc).astype(int)

print('starting pred analysis, for #', num_length)
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

for i in trange(splits):
    preds = predict_fn_sage(sl[i])
    if i == 0:
        # First batch seeds the accumulator, one entry per output site.
        preds_stacked = {label: preds[j].squeeze() for j, label in enumerate(labels)}
        continue
    for j, label in enumerate(labels):
        preds_stacked[label] = np.vstack((preds_stacked[label], preds[j].squeeze()))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

output = {'preds': preds_stacked, 'preds_st': preds_st, 'starting_data': starting_data}

with open(r"2D_2bn_matern52_1core_230906a.dill", "wb") as output_file:
    dill.dump(output, output_file)

##### Challenge 2: SAGE-ND-FP, multicore

In [None]:
%%writefile sage_2D_2bn_FP_matern52_230804a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

if __name__ == '__main__':
    # Chunk size used when splitting posterior draws for jax.pmap.
    num_proc = 100
    # Full 41 x 41 grid on [-2, 2]^2, flattened row-major and rounded to two
    # decimals; Xp has one (x, y) row per grid point.
    N = 41
    xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
    xp = torch.round(xp_.flatten(),decimals=2)
    yp = torch.round(yp_.flatten(),decimals=2)
    Xp = torch.hstack((xp[:,None],yp[:,None])).double()
    # Only the first return value of gen_data_2D_example is kept
    # (presumably per-point region labels -- confirm against the helper).
    Lp, _ = gen_data_2D_example(xp,yp)

    # Distance of every grid point from the origin.
    r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)


    # NOTE(review): dill.load can execute arbitrary code; only load trusted files.
    with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
        f = dill.load(input_file)


    def map_indices(Xp, idx):
        """Advance flat grid indices, in place, based on coordinate parity.

        For each entry of ``idx`` the first column of ``Xp`` at that index is
        tested with ``10 * value % 2``; a non-zero remainder advances the
        index by 1.  The second column is then read at the (possibly
        advanced) index and a non-zero remainder advances it by another 41.

        ``idx`` is modified in place and the same object is returned.
        """
        for pos in range(idx.shape[0]):
            x_remainder = 10 * Xp[idx[pos], 0] % 2
            if x_remainder:
                idx[pos] = idx[pos] + 1
            # The second coordinate is looked up only after the first shift.
            y_remainder = 10 * Xp[idx[pos], 1] % 2
            if y_remainder:
                idx[pos] += 41
        return idx

    seed = 0
    num_data_points_st = 60
    # NOTE(review): num_data_points_fp is not referenced again in this script.
    num_data_points_fp = 40

    # Hand-picked flat indices into the row-major grid (41 points per row).
    top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                        1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
    bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

    # Structure candidates: the first num_data_points_st entries of a seeded
    # permutation of all grid indices.
    kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
    # Every second point along the middle row, the middle column and both
    # diagonals of the 41 x 41 grid.
    lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

    # Locations of the f measurements: union of the three index sets, sorted
    # and de-duplicated by torch.unique.
    kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

    # Keep only candidates whose radius is > 1.4 or < 0.5, then add seven
    # fixed indices and de-duplicate.
    temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
    kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
    kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

    # Inputs/targets at the selected points; yf uses column 0 of f as a column
    # vector (assumes f is a 2-D tensor indexed by grid point -- TODO confirm).
    xs_2b = Xp[kp_st,:].double().clone()
    ys_2b = Lp[kp_st].double().clone()
    xf_2b = Xp[kp_fp,:].double().clone()
    yf_2b = f[kp_fp,0][:,None].double().clone()

    # Same data as JAX arrays.
    xs = jnp.asarray( xs_2b.detach().numpy(), dtype=jnp.float64)
    ys = jnp.asarray( ys_2b.detach().numpy(), dtype=jnp.integer)
    xf = jnp.asarray( xf_2b.detach().numpy(), dtype=jnp.float64)
    yf = jnp.asarray( yf_2b.detach().numpy(), dtype=jnp.float64)

    # Saved alongside the predictions at the end of the script.
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    # Prediction grid: 40 x 40 points on [-2, 2]^2 (not rounded).
    Nn = 40
    xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
    xn = xn_.flatten()
    yn = yn_.flatten()
    X40 = torch.hstack((xn[:,None],yn[:,None])).double()

    Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64)

    # Total number of structure plus functional-property points.
    Ns = xs.shape[0] + xf.shape[0]

    key = jax.random.PRNGKey(0)
    num_regions = 2

    def predict_fp(post_samples, model, *args, **kwargs):
        """Trace ``model`` once with its sites fixed to ``post_samples``.

        The model is conditioned on ``post_samples`` and seeded with the
        constant ``PRNGKey(0)``, then executed with ``*args``/``**kwargs``.

        Returns:
            Tuple of the traced values of the sites ``Fr_new``,
            ``f_piecewise``, ``f_sample``, ``gpc_new_probs`` and
            ``v_piecewise``, in that order.
        """
        rng_key = jax.random.PRNGKey(0)
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        return tuple(trace[name]["value"] for name in site_names)

    def subsample(samples, step):
        """Thin a dict of samples by keeping every ``step``-th entry per key.

        Returns a new dict with the same keys; each value is ``value[::step]``.
        """
        return {name: draws[::step] for name, draws in samples.items()}

    def split_samples(samples, num_proc, length):
        """Split a dict of samples into consecutive chunks of ``num_proc`` entries.

        Args:
            samples: Mapping from site name to a sliceable sequence or array.
            num_proc: Number of entries per chunk.
            length: Number of entries available under each key.

        Returns:
            A list of ``length // num_proc`` independent dicts; chunk ``i``
            holds ``samples[k][i * num_proc:(i + 1) * num_proc]`` for every
            key.  Entries beyond the last complete chunk are dropped.
        """
        num_chunks = length // num_proc
        # A fresh dict is built for every chunk.  Reusing a single dict would
        # make every list entry alias the same object, i.e. the last chunk.
        return [
            {
                name: draws[i * num_proc:(i + 1) * num_proc]
                for name, draws in samples.items()
            }
            for i in range(num_chunks)
        ]

    def get_samples_split(samples, num_proc, length, i):
        """Return the ``i``-th block of ``num_proc`` consecutive entries per key.

        ``length`` is accepted for signature compatibility but is not used.
        """
        start = i * num_proc
        stop = start + num_proc
        return {name: draws[start:stop] for name, draws in samples.items()}


    # !!!!!!!!!!!!!!!!!!!!!!!!
    
    # Wall-clock timer covering sampling and posterior prediction.
    tic = time.perf_counter()

    # NUTS on the functional-property-only model: 100 chains, each with 100
    # warmup steps and 2000 retained draws; tree depth capped at 5.
    nuts_kernel = nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5)
    nuts = nMCMC(nuts_kernel, num_samples=2000, num_warmup=100, num_chains=100)
    prior_bounds = {
        'gpc_var_bounds': jnp.asarray([5., 10.], dtype=jnp.float64),
        'gpc_ls_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
        'gpr_var_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
        'gpr_ls_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
        'gpr_noise_bounds': jnp.asarray([0.001, .1], dtype=jnp.float64),
    }
    nuts.run(key, xf, yf, num_regions, **prior_bounds)
    nuts_posterior_samples = nuts.get_samples()

    # Optional checkpoint of the raw posterior draws:
    # with open(r"2D_temp.dill", "wb") as output_file:
    #     dill.dump(nuts_posterior_samples, output_file)

    # Thin the draws (every 10th) before running predictions.
    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step=10)
    num_length = samples['gpr_noise'].shape[0]
    print('after subsampling', num_length)

    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')

    # Number of complete chunks of num_proc draws.
    splits = np.array(num_length / num_proc).astype(int)

    def _predict_chunk(chunk):
        """Predict on the Xnew_ grid for one draw of the posterior sites."""
        return predict_fp(
            chunk, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
        )

    # One draw per device along the leading axis of each chunk.
    predict_fn_sage = jax.pmap(_predict_chunk, axis_name=0)

    print('starting pred analysis, for #', num_length)
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    for i in trange(splits):
        preds = predict_fn_sage(sl[i])
        squeezed = [preds[j].squeeze() for j in range(len(labels))]
        if i == 0:
            # The first chunk initialises the accumulator.
            preds_stacked = dict(zip(labels, squeezed))
        else:
            # Later chunks are appended row-wise.
            for label, chunk_pred in zip(labels, squeezed):
                preds_stacked[label] = np.vstack((preds_stacked[label], chunk_pred))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")

    print('done pred analysis')

    output = {'preds': preds_stacked, 'starting_data': starting_data}
    with open(r"2D_2bn_fp_matern52_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 2: SAGE-ND-FP, 1 core

In [None]:
%%writefile sage_2D_2bn_FP_matern52_230804a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

# Chunk size used when splitting posterior draws for jax.pmap.
num_proc = 1
# Full 41 x 41 grid on [-2, 2]^2, flattened row-major and rounded to two
# decimals; Xp has one (x, y) row per grid point.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# Only the first return value of gen_data_2D_example is kept
# (presumably per-point region labels -- confirm against the helper).
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin.
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)


# NOTE(review): dill.load can execute arbitrary code; only load trusted files.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)


def map_indices(Xp, idx):
    """Advance flat grid indices, in place, based on coordinate parity.

    For each entry of ``idx`` the first column of ``Xp`` at that index is
    tested with ``10 * value % 2``; a non-zero remainder advances the index
    by 1.  The second column is then read at the (possibly advanced) index
    and a non-zero remainder advances it by another 41.

    ``idx`` is modified in place and the same object is returned.
    """
    for pos in range(idx.shape[0]):
        x_remainder = 10 * Xp[idx[pos], 0] % 2
        if x_remainder:
            idx[pos] = idx[pos] + 1
        # The second coordinate is looked up only after the first shift.
        y_remainder = 10 * Xp[idx[pos], 1] % 2
        if y_remainder:
            idx[pos] += 41
    return idx

seed = 0
num_data_points_st = 60
# NOTE(review): num_data_points_fp is not referenced again in this script.
num_data_points_fp = 40

# Hand-picked flat indices into the row-major grid (41 points per row).
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

# Structure candidates: the first num_data_points_st entries of a seeded
# permutation of all grid indices.
kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
# Every second point along the middle row, the middle column and both
# diagonals of the 41 x 41 grid.
lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

# Locations of the f measurements: union of the three index sets, sorted and
# de-duplicated by torch.unique.
kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

# Keep only candidates whose radius is > 1.4 or < 0.5, then add seven fixed
# indices and de-duplicate.
temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

# Inputs/targets at the selected points; yf uses column 0 of f as a column
# vector (assumes f is a 2-D tensor indexed by grid point -- TODO confirm).
xs_2b = Xp[kp_st,:].double().clone()
ys_2b = Lp[kp_st].double().clone()
xf_2b = Xp[kp_fp,:].double().clone()
yf_2b = f[kp_fp,0][:,None].double().clone()

# Same data as JAX arrays.
xs = jnp.asarray( xs_2b.detach().numpy(), dtype=jnp.float64)
ys = jnp.asarray( ys_2b.detach().numpy(), dtype=jnp.integer)
xf = jnp.asarray( xf_2b.detach().numpy(), dtype=jnp.float64)
yf = jnp.asarray( yf_2b.detach().numpy(), dtype=jnp.float64)

# Saved alongside the predictions at the end of the script.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# Prediction grid: 40 x 40 points on [-2, 2]^2 (not rounded).
Nn = 40
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = xn_.flatten()
yn = yn_.flatten()
X40 = torch.hstack((xn[:,None],yn[:,None])).double()

Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64)

# Total number of structure plus functional-property points.
Ns = xs.shape[0] + xf.shape[0]

key = jax.random.PRNGKey(0)
num_regions = 2

def predict_fp(post_samples, model, *args, **kwargs):
    """Trace ``model`` once with its sites fixed to ``post_samples``.

    The model is conditioned on ``post_samples`` and seeded with the constant
    ``PRNGKey(0)``, then executed with ``*args``/``**kwargs``.

    Returns:
        Tuple of the traced values of the sites ``Fr_new``, ``f_piecewise``,
        ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``, in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

def subsample(samples, step):
    """Thin a dict of samples by keeping every ``step``-th entry per key.

    Returns a new dict with the same keys; each value is ``value[::step]``.
    """
    return {name: draws[::step] for name, draws in samples.items()}

def split_samples(samples, num_proc, length):
    """Split a dict of samples into consecutive chunks of ``num_proc`` entries.

    Args:
        samples: Mapping from site name to a sliceable sequence or array.
        num_proc: Number of entries per chunk.
        length: Number of entries available under each key.

    Returns:
        A list of ``length // num_proc`` independent dicts; chunk ``i`` holds
        ``samples[k][i * num_proc:(i + 1) * num_proc]`` for every key.
        Entries beyond the last complete chunk are dropped.
    """
    num_chunks = length // num_proc
    # A fresh dict is built for every chunk.  Reusing a single dict would make
    # every list entry alias the same object, i.e. the last chunk.
    return [
        {
            name: draws[i * num_proc:(i + 1) * num_proc]
            for name, draws in samples.items()
        }
        for i in range(num_chunks)
    ]

def get_samples_split(samples, num_proc, length, i):
    """Return the ``i``-th block of ``num_proc`` consecutive entries per key.

    ``length`` is accepted for signature compatibility but is not used.
    """
    start = i * num_proc
    stop = start + num_proc
    return {name: draws[start:stop] for name, draws in samples.items()}


# !!!!!!!!!!!!!!!!!!!!!!!!

# Wall-clock timer covering sampling and posterior prediction.
tic = time.perf_counter()

# NUTS on the functional-property-only model: 100 chains, each with 100
# warmup steps and 2000 retained draws; tree depth capped at 5.
nuts_kernel = nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5)
nuts = nMCMC(nuts_kernel, num_samples=2000, num_warmup=100, num_chains=100)
prior_bounds = {
    'gpc_var_bounds': jnp.asarray([5., 10.], dtype=jnp.float64),
    'gpc_ls_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
    'gpr_var_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
    'gpr_ls_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
    'gpr_noise_bounds': jnp.asarray([0.001, .1], dtype=jnp.float64),
}
nuts.run(key, xf, yf, num_regions, **prior_bounds)
nuts_posterior_samples = nuts.get_samples()

# Optional checkpoint of the raw posterior draws:
# with open(r"2D_temp.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

# Thin the draws (every 10th) before running predictions.
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step=10)
num_length = samples['gpr_noise'].shape[0]
print('after subsampling', num_length)

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

# Number of complete chunks of num_proc draws.
splits = np.array(num_length / num_proc).astype(int)


def _predict_chunk(chunk):
    """Predict on the Xnew_ grid for one draw of the posterior sites."""
    return predict_fp(
        chunk, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
    )


# One draw per device along the leading axis of each chunk.
predict_fn_sage = jax.pmap(_predict_chunk, axis_name=0)

print('starting pred analysis, for #', num_length)
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

for i in trange(splits):
    preds = predict_fn_sage(sl[i])
    squeezed = [preds[j].squeeze() for j in range(len(labels))]
    if i == 0:
        # The first chunk initialises the accumulator.
        preds_stacked = dict(zip(labels, squeezed))
    else:
        # Later chunks are appended row-wise.
        for label, chunk_pred in zip(labels, squeezed):
            preds_stacked[label] = np.vstack((preds_stacked[label], chunk_pred))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

output = {'preds': preds_stacked, 'starting_data': starting_data}
with open(r"2D_2bn_fp_1core_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

##### Challenge 2: SAGE-ND-PM, multicore

In [None]:
%%writefile sage_2D_2bn_structure_matern_with_1init_230804a.py
# Unified for init 2a
from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)


if __name__ == '__main__':
    # Chunk size used when splitting posterior draws for jax.pmap.
    num_proc = 100
    
    # Full 41 x 41 grid on [-2, 2]^2, flattened row-major and rounded to two
    # decimals; Xp has one (x, y) row per grid point.
    N = 41
    xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
    xp = torch.round(xp_.flatten(),decimals=2)
    yp = torch.round(yp_.flatten(),decimals=2)
    Xp = torch.hstack((xp[:,None],yp[:,None])).double()
    # Only the first return value of gen_data_2D_example is kept
    # (presumably per-point region labels -- confirm against the helper).
    Lp, _ = gen_data_2D_example(xp,yp)

    # Distance of every grid point from the origin.
    r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

    # NOTE(review): dill.load can execute arbitrary code; only load trusted files.
    with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
        f = dill.load(input_file)


    seed = 0
    num_data_points_st = 60
    # NOTE(review): num_data_points_fp is not referenced again in this script.
    num_data_points_fp = 40

    # Hand-picked flat indices into the row-major grid (41 points per row).
    top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                        1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
    bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

    # Structure candidates: the first num_data_points_st entries of a seeded
    # permutation of all grid indices.
    kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
    # Every second point along the middle row, the middle column and both
    # diagonals of the 41 x 41 grid.
    lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

    # Locations of the f measurements: union of the three index sets, sorted
    # and de-duplicated by torch.unique.
    kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

    # Keep only candidates whose radius is > 1.4 or < 0.5, then add seven
    # fixed indices and de-duplicate.
    temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
    kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
    kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

    # Inputs/targets at the selected points; yf uses column 0 of f as a column
    # vector (assumes f is a 2-D tensor indexed by grid point -- TODO confirm).
    xs_2b = Xp[kp_st,:].double().clone()
    ys_2b = Lp[kp_st].double().clone()
    xf_2b = Xp[kp_fp,:].double().clone()
    yf_2b = f[kp_fp,0][:,None].double().clone()

    # Same data as JAX arrays.
    xs = jnp.asarray( xs_2b.detach().numpy(), dtype=jnp.float64)
    ys = jnp.asarray( ys_2b.detach().numpy(), dtype=jnp.integer)
    xf = jnp.asarray( xf_2b.detach().numpy(), dtype=jnp.float64)
    yf = jnp.asarray( yf_2b.detach().numpy(), dtype=jnp.float64)

    # Saved alongside the predictions at the end of the script.
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]
    
    # Total number of structure plus functional-property points.
    Ns = xs.shape[0] + xf.shape[0]

    # Prediction grid: 41 x 41 points on [-2, 2]^2, rounded to two decimals
    # (the variable is still called X40).
    Nn = 41
    xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
    xn = torch.round(xn_.flatten(),decimals=2)
    yn = torch.round(yn_.flatten(),decimals=2)
    X40 = torch.hstack((xn[:,None],yn[:,None])).double()
    Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 2

    def predict_structure(post_samples, model, *args, **kwargs):
        """Trace ``model`` once with its sites fixed to ``post_samples``.

        The model is conditioned on ``post_samples`` and seeded with the
        constant ``PRNGKey(0)``, then executed with ``*args``/``**kwargs``.

        Returns:
            ``(gpc_new_probs, gpc_new_latent)``: the traced values of those
            two sites.
        """
        rng_key = jax.random.PRNGKey(0)
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        probs = trace["gpc_new_probs"]["value"]
        latent = trace["gpc_new_latent"]["value"]
        return probs, latent

    def predict_sage(post_samples, model, *args, **kwargs):
        """Trace ``model`` once with its sites fixed to ``post_samples``.

        The model is conditioned on ``post_samples`` and seeded with the
        constant ``PRNGKey(0)``, then executed with ``*args``/``**kwargs``.

        Returns:
            Tuple of the traced values of the sites ``Fr_new``,
            ``f_piecewise``, ``f_sample``, ``gpc_new_probs`` and
            ``v_piecewise``, in that order.
        """
        rng_key = jax.random.PRNGKey(0)
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        return tuple(trace[name]["value"] for name in site_names)

    def predict_fn_structure(samples):
        """Structure-only prediction on the Xnew_ grid for one set of site values."""
        return predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        )

    def predict_fn_sage_1core(samples):
        """Joint-model prediction on the Xnew_ grid for one set of site values."""
        return predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
        )
    def subsample(samples, step):
        """Thin a dict of samples by keeping every ``step``-th entry per key.

        Returns a new dict with the same keys; each value is ``value[::step]``.
        """
        return {name: draws[::step] for name, draws in samples.items()}

    def split_samples(samples, num_proc, length):
        """Split a dict of samples into consecutive chunks of ``num_proc`` entries.

        Args:
            samples: Mapping from site name to a sliceable sequence or array.
            num_proc: Number of entries per chunk.
            length: Number of entries available under each key.

        Returns:
            A list of ``length // num_proc`` independent dicts; chunk ``i``
            holds ``samples[k][i * num_proc:(i + 1) * num_proc]`` for every
            key.  Entries beyond the last complete chunk are dropped.
        """
        num_chunks = length // num_proc
        # A fresh dict is built for every chunk.  Reusing a single dict would
        # make every list entry alias the same object, i.e. the last chunk.
        return [
            {
                name: draws[i * num_proc:(i + 1) * num_proc]
                for name, draws in samples.items()
            }
            for i in range(num_chunks)
        ]

    def get_samples_split(samples, num_proc, length, i):
        """Return the ``i``-th block of ``num_proc`` consecutive entries per key.

        ``length`` is accepted for signature compatibility but is not used.
        """
        start = i * num_proc
        stop = start + num_proc
        return {name: draws[start:stop] for name, draws in samples.items()}
    
    
    #------------------------------------
    
    
    # Positional arguments for the structure-only model during SVI.  The two
    # trailing arrays are presumably gpc_var_bounds and gpc_ls_bounds --
    # confirm against the model signature.
    # NOTE(review): the second pair is [1, 2] here, whereas the NUTS run later
    # in this script passes gpc_ls_bounds = [.1, 2.].
    data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64),
            jnp.asarray([1.,2.], dtype=jnp.float64)]

    key = jax.random.PRNGKey(0)
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.05)

    # Fit the guide with 100000 optimisation steps.
    svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 100000, *data)

    # Point estimate of every site: the median of the fitted guide.
    params = svi_result.params
    mle_2a_st = autoguide_mle.median(params)

    # predict_structure seeds the model with a fixed key, so a single call is
    # enough; its two return values are unpacked from the stored result
    # instead of running the prediction a second time.
    preds_st = predict_fn_structure(mle_2a_st)
    gpc_new_probs_, gpc_new_latent_ = preds_st

    gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

    # This script produces no functional-property-only predictions.
    preds_fp = None
    
    # !!!!!!!!!!!!!!!!!!!!!!!!

    # gpc_new_probs_st = predict_fn_st(mle_2a_st)
    # Initialise every chain at the guide medians of the classifier sites.
    init_sites = ('gpc_latent_0', 'gpc_latent_1', 'gpc_var', 'gpc_lengthscale', 'gpc_bias')
    init_params = {site: mle_2a_st[site] for site in init_sites}
    init_strategy = init_to_value(values=init_params)

    # Wall-clock timer covering sampling and posterior prediction.
    tic = time.perf_counter()

    # NUTS on the structure-only model: 100 chains, each with 100 warmup
    # steps and 2000 retained draws; tree depth capped at 5.
    nuts_kernel = nNUTS(
        model_SAGE_ND_PM_230628a,
        target_accept_prob=0.8,
        max_tree_depth=5,
        init_strategy=init_strategy,
    )
    nuts = nMCMC(nuts_kernel, num_samples=2000, num_warmup=100, num_chains=100)
    prior_bounds = {
        'gpc_var_bounds': jnp.asarray([5., 10.], dtype=jnp.float64),
        'gpc_ls_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
    }
    nuts.run(key, xs, ys, xf, num_regions, **prior_bounds)
    nuts_posterior_samples = nuts.get_samples()

    # Optional checkpoint of the raw posterior draws:
    # with open(r"2D_2bn_structure.dill", "wb") as output_file:
    #     dill.dump(nuts_posterior_samples, output_file)

    # Thin the draws (every 10th) before running predictions.
    print('start', nuts_posterior_samples['gpc_bias'].shape[0])
    samples = subsample(nuts_posterior_samples, step=10)
    num_length = samples['gpc_bias'].shape[0]
    print('after subsampling', num_length)

    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')

    # Number of complete chunks of num_proc draws.
    splits = np.array(num_length / num_proc).astype(int)

    def _predict_chunk(chunk):
        """Structure prediction on the Xnew_ grid for one posterior draw."""
        return predict_structure(
            chunk, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        )

    # One draw per device along the leading axis of each chunk.
    predict_fn_st_multicore = jax.pmap(_predict_chunk, axis_name=0)

    print('starting pred analysis, for #', num_length)
    # Only the first returned value (the probabilities) is kept.
    labels = ['gpc_new_probs']

    for i in trange(splits):
        preds = predict_fn_st_multicore(sl[i])
        squeezed = [preds[j].squeeze() for j in range(len(labels))]
        if i == 0:
            # The first chunk initialises the accumulator.
            preds_stacked = dict(zip(labels, squeezed))
        else:
            # Later chunks are appended row-wise.
            for label, chunk_pred in zip(labels, squeezed):
                preds_stacked[label] = np.vstack((preds_stacked[label], chunk_pred))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")

    print('done pred analysis')

    output = {
        'preds': preds_stacked,
        'preds_st': preds_st,
        'preds_fp': preds_fp,
        'starting_data': starting_data,
    }
    with open(r"2D_2bn_structure_matern52_N41_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 2: SAGE-ND-PM, 1 core

In [None]:
%%writefile sage_2D_2bn_structure_matern_with_1init_230804a.py
# Unified for init 2a
from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

# Chunk size used when splitting posterior draws for jax.pmap.
num_proc = 1

# Full 41 x 41 grid on [-2, 2]^2, flattened row-major and rounded to two
# decimals; Xp has one (x, y) row per grid point.
N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
# Only the first return value of gen_data_2D_example is kept
# (presumably per-point region labels -- confirm against the helper).
Lp, _ = gen_data_2D_example(xp,yp)

# Distance of every grid point from the origin.
r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

# NOTE(review): dill.load can execute arbitrary code; only load trusted files.
with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)


seed = 0
num_data_points_st = 60
# NOTE(review): num_data_points_fp is not referenced again in this script.
num_data_points_fp = 40

# Hand-picked flat indices into the row-major grid (41 points per row).
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

# Structure candidates: the first num_data_points_st entries of a seeded
# permutation of all grid indices.
kp_st = torch.tensor( default_rng(seed).permutation(Xp.shape[0])[:num_data_points_st] )
# Every second point along the middle row, the middle column and both
# diagonals of the 41 x 41 grid.
lines = torch.cat([torch.arange(820,860,2), torch.arange(20,1660,82), torch.arange(0,1680,84), torch.arange(40,1640,80)]).long()

# Locations of the f measurements: union of the three index sets, sorted and
# de-duplicated by torch.unique.
kp_fp = torch.unique(torch.cat([lines, top, bottom]))#,temp ,in_center])) #, torch.tensor([398,150,130,229,269,166,167,211,213,356,357,353,277,399,339,396,209])]))

# Keep only candidates whose radius is > 1.4 or < 0.5, then add seven fixed
# indices and de-duplicate.
temp = torch.sqrt(Xp[kp_st,0]**2 + Xp[kp_st,1]**2)
kp_st = kp_st[torch.logical_or( temp > 1.4, temp < .5)]
kp_st = torch.unique(torch.cat([kp_st, torch.tensor([76, 18, 189, 192,229,185,129])]))

# Inputs/targets at the selected points; yf uses column 0 of f as a column
# vector (assumes f is a 2-D tensor indexed by grid point -- TODO confirm).
xs_2b = Xp[kp_st,:].double().clone()
ys_2b = Lp[kp_st].double().clone()
xf_2b = Xp[kp_fp,:].double().clone()
yf_2b = f[kp_fp,0][:,None].double().clone()

# Same data as JAX arrays.
xs = jnp.asarray( xs_2b.detach().numpy(), dtype=jnp.float64)
ys = jnp.asarray( ys_2b.detach().numpy(), dtype=jnp.integer)
xf = jnp.asarray( xf_2b.detach().numpy(), dtype=jnp.float64)
yf = jnp.asarray( yf_2b.detach().numpy(), dtype=jnp.float64)

# Saved alongside the predictions at the end of the script.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# Total number of structure plus functional-property points.
Ns = xs.shape[0] + xf.shape[0]

# Prediction grid: 41 x 41 points on [-2, 2]^2, rounded to two decimals
# (the variable is still called X40).
Nn = 41
xn_,yn_ = torch.meshgrid(torch.linspace(-2,2,Nn),torch.linspace(-2,2,Nn),indexing='xy')
xn = torch.round(xn_.flatten(),decimals=2)
yn = torch.round(yn_.flatten(),decimals=2)
X40 = torch.hstack((xn[:,None],yn[:,None])).double()
Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64).copy()


key = jax.random.PRNGKey(0)
num_regions = 2

def predict_structure(post_samples, model, *args, **kwargs):
    """Trace ``model`` once with its sites fixed to ``post_samples``.

    The model is conditioned on ``post_samples`` and seeded with the constant
    ``PRNGKey(0)``, then executed with ``*args``/``**kwargs``.

    Returns:
        ``(gpc_new_probs, gpc_new_latent)``: the traced values of those two
        sites.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    probs = trace["gpc_new_probs"]["value"]
    latent = trace["gpc_new_latent"]["value"]
    return probs, latent

def predict_sage(post_samples, model, *args, **kwargs):
    """Trace ``model`` once with its sites fixed to ``post_samples``.

    The model is conditioned on ``post_samples`` and seeded with the constant
    ``PRNGKey(0)``, then executed with ``*args``/``**kwargs``.

    Returns:
        Tuple of the traced values of the sites ``Fr_new``, ``f_piecewise``,
        ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``, in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

def predict_fn_structure(samples):
    """Structure-only prediction on the Xnew_ grid for one set of site values."""
    return predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )


def predict_fn_sage_1core(samples):
    """Joint-model prediction on the Xnew_ grid for one set of site values."""
    return predict_sage(
        samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
    )
def subsample(samples, step):
    """Thin a dict of samples by keeping every ``step``-th entry per key.

    Returns a new dict with the same keys; each value is ``value[::step]``.
    """
    return {name: draws[::step] for name, draws in samples.items()}

def split_samples(samples, num_proc, length):
    """Split a dict of samples into consecutive chunks of ``num_proc`` entries.

    Args:
        samples: Mapping from site name to a sliceable sequence or array.
        num_proc: Number of entries per chunk.
        length: Number of entries available under each key.

    Returns:
        A list of ``length // num_proc`` independent dicts; chunk ``i`` holds
        ``samples[k][i * num_proc:(i + 1) * num_proc]`` for every key.
        Entries beyond the last complete chunk are dropped.
    """
    num_chunks = length // num_proc
    # A fresh dict is built for every chunk.  Reusing a single dict would make
    # every list entry alias the same object, i.e. the last chunk.
    return [
        {
            name: draws[i * num_proc:(i + 1) * num_proc]
            for name, draws in samples.items()
        }
        for i in range(num_chunks)
    ]

def get_samples_split(samples, num_proc, length, i):
    """Return the ``i``-th block of ``num_proc`` consecutive entries per key.

    ``length`` is accepted for signature compatibility but is not used.
    """
    start = i * num_proc
    stop = start + num_proc
    return {name: draws[start:stop] for name, draws in samples.items()}


#------------------------------------


# Positional arguments for the structure-only model during SVI.  The two
# trailing arrays are presumably gpc_var_bounds and gpc_ls_bounds -- confirm
# against the model signature.
# NOTE(review): the second pair is [1, 2] here, whereas the NUTS run later in
# this script passes gpc_ls_bounds = [.1, 2.].
data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64),
        jnp.asarray([1.,2.], dtype=jnp.float64)]

key = jax.random.PRNGKey(0)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.05)

# Fit the guide with 100000 optimisation steps.
svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

# Point estimate of every site: the median of the fitted guide.
params = svi_result.params
mle_2a_st = autoguide_mle.median(params)

# predict_structure seeds the model with a fixed key, so a single call is
# enough; its two return values are unpacked from the stored result instead
# of running the prediction a second time.
preds_st = predict_fn_structure(mle_2a_st)
gpc_new_probs_, gpc_new_latent_ = preds_st

gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

# This script produces no functional-property-only predictions.
preds_fp = None

# !!!!!!!!!!!!!!!!!!!!!!!!

# gpc_new_probs_st = predict_fn_st(mle_2a_st)
# Initialise every chain at the guide medians of the classifier sites.
init_sites = ('gpc_latent_0', 'gpc_latent_1', 'gpc_var', 'gpc_lengthscale', 'gpc_bias')
init_params = {site: mle_2a_st[site] for site in init_sites}
init_strategy = init_to_value(values=init_params)

# Wall-clock timer covering sampling and posterior prediction.
tic = time.perf_counter()

# NUTS on the structure-only model: 100 chains, each with 100 warmup steps
# and 2000 retained draws; tree depth capped at 5.
nuts_kernel = nNUTS(
    model_SAGE_ND_PM_230628a,
    target_accept_prob=0.8,
    max_tree_depth=5,
    init_strategy=init_strategy,
)
nuts = nMCMC(nuts_kernel, num_samples=2000, num_warmup=100, num_chains=100)
prior_bounds = {
    'gpc_var_bounds': jnp.asarray([5., 10.], dtype=jnp.float64),
    'gpc_ls_bounds': jnp.asarray([.1, 2.], dtype=jnp.float64),
}
nuts.run(key, xs, ys, xf, num_regions, **prior_bounds)
nuts_posterior_samples = nuts.get_samples()

# Optional checkpoint of the raw posterior draws:
# with open(r"2D_2bn_structure.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

# Thin the draws (every 10th) before running predictions.
print('start', nuts_posterior_samples['gpc_bias'].shape[0])
samples = subsample(nuts_posterior_samples, step=10)
num_length = samples['gpc_bias'].shape[0]
print('after subsampling', num_length)

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

# Number of complete chunks of num_proc draws.
splits = np.array(num_length / num_proc).astype(int)


def _predict_chunk(chunk):
    """Structure prediction on the Xnew_ grid for one posterior draw."""
    return predict_structure(
        chunk, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )


# One draw per device along the leading axis of each chunk.
predict_fn_st_multicore = jax.pmap(_predict_chunk, axis_name=0)

print('starting pred analysis, for #', num_length)
# Only the first returned value (the probabilities) is kept.
labels = ['gpc_new_probs']

for i in trange(splits):
    preds = predict_fn_st_multicore(sl[i])
    squeezed = [preds[j].squeeze() for j in range(len(labels))]
    if i == 0:
        # The first chunk initialises the accumulator.
        preds_stacked = dict(zip(labels, squeezed))
    else:
        # Later chunks are appended row-wise.
        for label, chunk_pred in zip(labels, squeezed):
            preds_stacked[label] = np.vstack((preds_stacked[label], chunk_pred))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

output = {
    'preds': preds_stacked,
    'preds_st': preds_st,
    'preds_fp': preds_fp,
    'starting_data': starting_data,
}
with open(r"2D_2bn_structure_1core_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

##### Challenge 3: SAGE-ND Multi Inputs, multicore

In [None]:
%%writefile sage_2D_2cn_matern52_with_init_230804a.py
# Unified for init 2a
from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
Lp, _ = gen_data_2D_example(xp,yp)

r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)

seed = 0
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

def map_indices(Xp, idx):
    """Shift indices whose referenced coordinates sit on an odd tenth.

    Each entry of ``idx`` is updated in place. If ten times the first
    coordinate of the referenced row of ``Xp`` is not a multiple of two, the
    index is advanced by 1. Then, using the possibly advanced index, if ten
    times the second coordinate is not a multiple of two, the index is
    advanced by a further 41.

    NOTE(review): 41 presumably equals the side length N of the flattened
    N x N grid, so +1 / +41 would step one column / one row -- confirm before
    changing N. The parity test is a floating-point modulo, so it depends on
    the coordinates being (near-)exact tenths.

    Returns the same ``idx`` object that was passed in.
    """
    row_step = 41
    for pos in range(idx.shape[0]):
        x_off_grid = 10 * Xp[idx[pos], 0] % 2
        if x_off_grid:
            idx[pos] += 1
        # Re-read idx[pos]: the second test uses the index after the first shift.
        y_off_grid = 10 * Xp[idx[pos], 1] % 2
        if y_off_grid:
            idx[pos] = idx[pos] + row_step
    return idx


if __name__ == '__main__':
    
    seed = 1
    num_data_points_st0 = 30
    num_data_points_st1 = 30
    num_data_points_fp0 = 40
    num_data_points_fp1 = 40
    temp = torch.tensor( default_rng(seed+0).permutation(Xp.shape[0])[:num_data_points_st0] )
    kp_st0 = torch.cat([top,temp,torch.tensor([1680])])
    temp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_st1] )
    kp_st1 = torch.cat([bottom,temp,torch.tensor([178])])
    kp_fp0 = torch.tensor( default_rng(seed+4).permutation(Xp.shape[0])[:num_data_points_fp0] )
    kp_fp1 = torch.tensor( default_rng(seed+3).permutation(Xp.shape[0])[:num_data_points_fp1] )

    kp_st0 = map_indices(Xp, kp_st0)
    kp_st1 = map_indices(Xp, kp_st1)
    kp_fp0 = map_indices(Xp, kp_fp0)
    kp_fp1 = map_indices(Xp, kp_fp1)

    Xs_ = [jnp.asarray(Xp[kp_st0,:].detach().numpy(), dtype=jnp.float64),
           jnp.asarray(Xp[kp_st1,:].detach().numpy(), dtype=jnp.float64)]
    Xf_ = [jnp.asarray(Xp[kp_fp0,:].detach().numpy(), dtype=jnp.float64),
           jnp.asarray(Xp[kp_fp1,:].detach().numpy(), dtype=jnp.float64)]
    ys_ = [jnp.asarray(Lp[kp_st0].detach().numpy(), dtype=jnp.int16),
           jnp.asarray(Lp[kp_st1].detach().numpy(), dtype=jnp.int16)]
    yf_ = [jnp.asarray(f[kp_fp0,0].detach().numpy(), dtype=jnp.float64),
           jnp.asarray(f[kp_fp1,1].detach().numpy(), dtype=jnp.float64)]

    starting_data = [Xp, Lp, f, Xs_, ys_, Xf_, yf_]

    N = 41
    x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
    x = x_.flatten()
    y = y_.flatten()
    X40 = torch.hstack((x[:,None],y[:,None])).double()
    Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64)

    key = jax.random.PRNGKey(0)
    num_regions = 2

    def predict_st(post_samples, model, *args, **kwargs):
        """Trace `model` conditioned on `post_samples`; return "gpc_new_probs".

        The model is seeded with the fixed PRNG key 0, and `*args` / `**kwargs`
        are forwarded to the model when the trace is taken.
        """
        rng_key = jax.random.PRNGKey(0)
        pinned = handlers.condition(model, post_samples)
        seeded = handlers.seed(pinned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        return trace["gpc_new_probs"]["value"]

    def predict_sage(post_samples, model, *args, **kwargs):
        """Trace `model` conditioned on `post_samples`; return five site values.

        The model is seeded with the fixed PRNG key 0. The returned tuple holds
        the traced values of "Fr_new", "f_piecewise", "f_sample",
        "gpc_new_probs" and "v_piecewise", in that order.
        """
        rng_key = jax.random.PRNGKey(0)
        pinned = handlers.condition(model, post_samples)
        seeded = handlers.seed(pinned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        return tuple(trace[name]["value"] for name in site_names)

    predict_fn_st = lambda samples: predict_st(
            samples, predict_SAGE_ND_PM_Coreg_230628a, Xnew=Xnew_, xs_=Xs_, ys_=ys_, xf_=Xf_, num_regions=num_regions
        )

    predict_fn_sage = jax.pmap(
        lambda samples: predict_sage(
            samples, predict_SAGE_ND_Coreg_230628a, Xnew=Xnew_, xs_=Xs_, ys_=ys_, xf_=Xf_, yf_=yf_, num_regions=num_regions
        ), axis_name = 0
    )

    def subsample(samples, step):
        """Thin every entry of `samples` by keeping each `step`-th leading item."""
        return {name: draws[::step] for name, draws in samples.items()}

    def split_samples(samples, num_proc, length):
        """Cut `samples` into consecutive batches of `num_proc` draws each.

        Parameters
        ----------
        samples : dict
            Maps a site name to a sliceable collection of draws.
        num_proc : int
            Number of draws per batch.
        length : int
            Total number of draws to split. Only ``int(length / num_proc)``
            full batches are produced; a trailing partial batch is dropped.

        Returns
        -------
        list of dict
            One independent dict per batch, with the same keys as `samples`.
        """
        splits = int(length / num_proc)
        sample_list = []
        for i in trange(splits):
            lo, hi = i * num_proc, (i + 1) * num_proc
            # Build a fresh dict per batch. Re-using one dict across
            # iterations made every list entry alias the same object, so all
            # batches ended up equal to the last one.
            sample_list.append({k: v[lo:hi] for k, v in samples.items()})
        return sample_list

    def get_samples_split(samples, num_proc, length, i):
        """Return the `i`-th batch of `num_proc` consecutive draws per entry.

        `length` is accepted but not used.
        """
        start = i * num_proc
        stop = (i + 1) * num_proc
        return {name: draws[start:stop] for name, draws in samples.items()}

    
    # Start the wall-clock timer covering SVI, NUTS and the prediction loop.
    tic = time.perf_counter()
    # Number of posterior draws handed to the pmapped predictor per call.
    num_proc = 100
    
    # Positional arguments forwarded to the SVI run of the structure-only model.
    data = [Xs_, ys_, Xf_, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
             jnp.asarray([1.,2.], dtype=jnp.float64)]
        
    key = jax.random.PRNGKey(0)
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_Coreg_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.01)

    # Stage 1: fit the structure-only model with SVI for 10000 steps.
    svi = nSVI(model_SAGE_ND_PM_Coreg_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 10000, *data)

    params = svi_result.params
    # Median of the fitted guide; used as a point estimate for the prediction
    # below and to initialise the NUTS chains.
    mle_st = autoguide_mle.median(params)
    preds_st = predict_fn_st(mle_st)  

    # gpc_new_probs_st = predict_fn_st(mle_2a_st)
    # Only the gpc_* sites are initialised from the SVI estimate.
    init_params = {'gpc_latent_0': mle_st['gpc_latent_0'], 'gpc_latent_1': mle_st['gpc_latent_1'],
                  'gpc_var': mle_st['gpc_var'],'gpc_lengthscale': mle_st['gpc_lengthscale'],
                   'gpc_bias': mle_st['gpc_bias']}
    init_strategy=init_to_value(values=init_params)

    # Stage 2: NUTS on the joint model (2000 samples, 100 warmup, 100 chains).
    nuts = nMCMC(nNUTS(model_SAGE_ND_Coreg_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
            num_samples=2000, num_warmup=100, num_chains = 100)
    nuts.run(key, Xs_, ys_, Xf_, yf_, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
         gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.01], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()

    import dill
#     with open(r"2D_2c_100chains_200samples_matern52_with_init_samples.dill", "wb") as output_file:
#         dill.dump(nuts_posterior_samples, output_file)    
    
    # Thin the posterior: keep every 10th draw of each site.
    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpr_noise'].shape[0]) 

    num_length = samples['gpr_noise'].shape[0]
    
    # Pre-cut the thinned draws into batches of num_proc for the predictor.
    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')
    
    # Number of full batches; must match the length of `sl`.
    splits = np.array(num_length / num_proc).astype(int)

    print('starting pred analysis, for #', num_length)
    # Same order as the tuple returned by predict_sage.
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    # Stage 3: predict batch by batch and stack the results along axis 0.
    for i in trange(splits):
        if i == 0:
            # First batch seeds the accumulator dict.
            preds = predict_fn_sage(sl[i])
            preds_stacked = {labels[0]:preds[0].squeeze(), labels[1]:preds[1].squeeze(), labels[2]:preds[2].squeeze(), labels[3]:preds[3].squeeze(), labels[4]:preds[4].squeeze()}
        else:
            preds = predict_fn_sage(sl[i])
            for j in range(len(labels)):
                # NOTE(review): vstack re-copies the growing array on every
                # batch; fine for a few hundred batches, quadratic in general.
                preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")    
                
    print('done pred analysis')
    
    # Bundle the stacked predictions, the SVI-based prediction and the inputs.
    output = {'preds': preds_stacked, 'preds_st':preds_st, 'starting_data':starting_data} 
    
    with open(r"2D_2cn_matern52_2ksamples_N41_with_init_230906a.dill", "wb") as output_file:
        dill.dump(output, output_file)

##### Challenge 3: SAGE-ND Multi Inputs, 1 core

In [None]:
%%writefile sage_2D_2cn_matern52_with_init_1core_230804a.py
# Unified for init 2a
from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt
import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

N = 41
xp_,yp_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
xp = torch.round(xp_.flatten(),decimals=2)
yp = torch.round(yp_.flatten(),decimals=2)
Xp = torch.hstack((xp[:,None],yp[:,None])).double()
Lp, _ = gen_data_2D_example(xp,yp)

r = torch.sqrt(Xp[:,0]**2 + Xp[:,1]**2)

with open(r"2D_2a_and_2b_fp_231030a.dill", "rb") as input_file:
    f = dill.load(input_file)

seed = 0
top = torch.tensor([912,914, 994,996, 1078,1080, 1162,1164, 1244,1246, 1248,1166, 1250,1168, 1252,1170, 1254,1172, 1256,1174, 1176,1092,
                    1160, 1076, 1094, 1096, 1012,1014, 930,932, 848,850])
bottom = torch.tensor([504, 520, 584, 604, 620, 766,768, 684,686, 600,602, 516,518, 436, 434, 514,432, 512,430, 510,428, 508,426, 424, 506, 586,588, 666,668, 748,750, 830,832])

def map_indices(Xp, idx):
    """Shift indices whose referenced coordinates sit on an odd tenth.

    Each entry of ``idx`` is updated in place. If ten times the first
    coordinate of the referenced row of ``Xp`` is not a multiple of two, the
    index is advanced by 1. Then, using the possibly advanced index, if ten
    times the second coordinate is not a multiple of two, the index is
    advanced by a further 41.

    NOTE(review): 41 presumably equals the side length N of the flattened
    N x N grid, so +1 / +41 would step one column / one row -- confirm before
    changing N. The parity test is a floating-point modulo, so it depends on
    the coordinates being (near-)exact tenths.

    Returns the same ``idx`` object that was passed in.
    """
    row_step = 41
    for pos in range(idx.shape[0]):
        x_off_grid = 10 * Xp[idx[pos], 0] % 2
        if x_off_grid:
            idx[pos] += 1
        # Re-read idx[pos]: the second test uses the index after the first shift.
        y_off_grid = 10 * Xp[idx[pos], 1] % 2
        if y_off_grid:
            idx[pos] = idx[pos] + row_step
    return idx

seed = 1
num_data_points_st0 = 30
num_data_points_st1 = 30
num_data_points_fp0 = 40
num_data_points_fp1 = 40
temp = torch.tensor( default_rng(seed+0).permutation(Xp.shape[0])[:num_data_points_st0] )
kp_st0 = torch.cat([top,temp,torch.tensor([1680])])
temp = torch.tensor( default_rng(seed+1).permutation(Xp.shape[0])[:num_data_points_st1] )
kp_st1 = torch.cat([bottom,temp,torch.tensor([178])])
kp_fp0 = torch.tensor( default_rng(seed+4).permutation(Xp.shape[0])[:num_data_points_fp0] )
kp_fp1 = torch.tensor( default_rng(seed+3).permutation(Xp.shape[0])[:num_data_points_fp1] )

kp_st0 = map_indices(Xp, kp_st0)
kp_st1 = map_indices(Xp, kp_st1)
kp_fp0 = map_indices(Xp, kp_fp0)
kp_fp1 = map_indices(Xp, kp_fp1)

Xs_ = [jnp.asarray(Xp[kp_st0,:].detach().numpy(), dtype=jnp.float64),
       jnp.asarray(Xp[kp_st1,:].detach().numpy(), dtype=jnp.float64)]
Xf_ = [jnp.asarray(Xp[kp_fp0,:].detach().numpy(), dtype=jnp.float64),
       jnp.asarray(Xp[kp_fp1,:].detach().numpy(), dtype=jnp.float64)]
ys_ = [jnp.asarray(Lp[kp_st0].detach().numpy(), dtype=jnp.int16),
       jnp.asarray(Lp[kp_st1].detach().numpy(), dtype=jnp.int16)]
yf_ = [jnp.asarray(f[kp_fp0,0].detach().numpy(), dtype=jnp.float64),
       jnp.asarray(f[kp_fp1,1].detach().numpy(), dtype=jnp.float64)]

starting_data = [Xp, Lp, f, Xs_, ys_, Xf_, yf_]

N = 41
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()
X40 = torch.hstack((x[:,None],y[:,None])).double()
Xnew_ = jnp.asarray( X40.detach().numpy(), dtype=jnp.float64)

key = jax.random.PRNGKey(0)
num_regions = 2

def predict_st(post_samples, model, *args, **kwargs):
    """Trace `model` conditioned on `post_samples`; return "gpc_new_probs".

    The model is seeded with the fixed PRNG key 0, and `*args` / `**kwargs`
    are forwarded to the model when the trace is taken.
    """
    rng_key = jax.random.PRNGKey(0)
    pinned = handlers.condition(model, post_samples)
    seeded = handlers.seed(pinned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return trace["gpc_new_probs"]["value"]

def predict_sage(post_samples, model, *args, **kwargs):
    """Trace `model` conditioned on `post_samples`; return five site values.

    The model is seeded with the fixed PRNG key 0. The returned tuple holds
    the traced values of "Fr_new", "f_piecewise", "f_sample",
    "gpc_new_probs" and "v_piecewise", in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    pinned = handlers.condition(model, post_samples)
    seeded = handlers.seed(pinned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(trace[name]["value"] for name in site_names)

predict_fn_st = lambda samples: predict_st(
        samples, predict_SAGE_ND_PM_Coreg_230628a, Xnew=Xnew_, xs_=Xs_, ys_=ys_, xf_=Xf_, num_regions=num_regions
    )

predict_fn_sage = jax.pmap(
    lambda samples: predict_sage(
        samples, predict_SAGE_ND_Coreg_230628a, Xnew=Xnew_, xs_=Xs_, ys_=ys_, xf_=Xf_, yf_=yf_, num_regions=num_regions
    ), axis_name = 0
)

def subsample(samples, step):
    """Thin every entry of `samples` by keeping each `step`-th leading item."""
    return {name: draws[::step] for name, draws in samples.items()}

def split_samples(samples, num_proc, length):
    """Cut `samples` into consecutive batches of `num_proc` draws each.

    Parameters
    ----------
    samples : dict
        Maps a site name to a sliceable collection of draws.
    num_proc : int
        Number of draws per batch.
    length : int
        Total number of draws to split. Only ``int(length / num_proc)``
        full batches are produced; a trailing partial batch is dropped.

    Returns
    -------
    list of dict
        One independent dict per batch, with the same keys as `samples`.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        lo, hi = i * num_proc, (i + 1) * num_proc
        # Build a fresh dict per batch. Re-using one dict across iterations
        # made every list entry alias the same object, so all batches ended
        # up equal to the last one.
        sample_list.append({k: v[lo:hi] for k, v in samples.items()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return the `i`-th batch of `num_proc` consecutive draws per entry.

    `length` is accepted but not used.
    """
    start = i * num_proc
    stop = (i + 1) * num_proc
    return {name: draws[start:stop] for name, draws in samples.items()}


# Start the wall-clock timer covering SVI, NUTS and the prediction loop.
tic = time.perf_counter()
# One posterior draw per predictor call (single-core variant).
num_proc = 1

# Positional arguments forwarded to the SVI run of the structure-only model.
data = [Xs_, ys_, Xf_, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
         jnp.asarray([1.,2.], dtype=jnp.float64)]

key = jax.random.PRNGKey(0)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_Coreg_230628a)
optimizer = numpyro.optim.Adam(step_size=0.01)

# Stage 1: fit the structure-only model with SVI for 10000 steps.
svi = nSVI(model_SAGE_ND_PM_Coreg_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 10000, *data)

params = svi_result.params
# Median of the fitted guide; used as a point estimate for the prediction
# below and to initialise the NUTS chains.
mle_st = autoguide_mle.median(params)
preds_st = predict_fn_st(mle_st)  

# gpc_new_probs_st = predict_fn_st(mle_2a_st)
# Only the gpc_* sites are initialised from the SVI estimate.
init_params = {'gpc_latent_0': mle_st['gpc_latent_0'], 'gpc_latent_1': mle_st['gpc_latent_1'],
              'gpc_var': mle_st['gpc_var'],'gpc_lengthscale': mle_st['gpc_lengthscale'],
               'gpc_bias': mle_st['gpc_bias']}
init_strategy=init_to_value(values=init_params)

# Stage 2: NUTS on the joint model (2000 samples, 100 warmup, 100 chains).
nuts = nMCMC(nNUTS(model_SAGE_ND_Coreg_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
        num_samples=2000, num_warmup=100, num_chains = 100)
nuts.run(key, Xs_, ys_, Xf_, yf_, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
     gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
     gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.01], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill
#     with open(r"2D_2c_100chains_200samples_matern52_with_init_samples.dill", "wb") as output_file:
#         dill.dump(nuts_posterior_samples, output_file)    

# Thin the posterior: keep every 10th draw of each site.
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step = 10)
print('after subsampling', samples['gpr_noise'].shape[0]) 

num_length = samples['gpr_noise'].shape[0]

# Pre-cut the thinned draws into batches of num_proc for the predictor.
print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

# Number of full batches; must match the length of `sl`.
splits = np.array(num_length / num_proc).astype(int)

print('starting pred analysis, for #', num_length)
# Same order as the tuple returned by predict_sage.
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

# Stage 3: predict batch by batch and stack the results along axis 0.
# NOTE(review): with num_proc = 1 each batch has a leading axis of length 1,
# which .squeeze() also removes; np.vstack then concatenates along what was
# the next axis instead of a per-draw axis -- confirm the stacked shapes are
# what downstream code expects.
for i in trange(splits):
    if i == 0:
        # First batch seeds the accumulator dict.
        preds = predict_fn_sage(sl[i])
        preds_stacked = {labels[0]:preds[0].squeeze(), labels[1]:preds[1].squeeze(), labels[2]:preds[2].squeeze(), labels[3]:preds[3].squeeze(), labels[4]:preds[4].squeeze()}
    else:
        preds = predict_fn_sage(sl[i])
        for j in range(len(labels)):
            # NOTE(review): vstack re-copies the growing array on every
            # iteration, which is quadratic in the number of batches.
            preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")    

print('done pred analysis')

# Bundle the stacked predictions, the SVI-based prediction and the inputs.
output = {'preds': preds_stacked, 'preds_st':preds_st, 'starting_data':starting_data} 

with open(r"2D_2cn_1core_230906a.dill", "wb") as output_file:
    dill.dump(output, output_file)

#### Visualize Results

###### Challenge 1: SAGE-ND

In [None]:
# !!!!!!!!!! 1 INIT !!!!!!!!!
# Using 2Init for 2a MCMC Matern52 - N=40

import dill
with open(r"2D_2an_matern52_N41_10ksamples_2init_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
preds_st = output['preds_st']
preds_fp = output['preds_fp']
starting_data = output['starting_data']

# Unpack the inputs that were stored alongside the predictions.
Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']
# Ground-truth coordinates come from the loaded Xp, so the ground-truth
# scatter plots below do not depend on xp / yp left over from another cell.
xp = Xp[:,0]
yp = Xp[:,1]

# Grid on which the predictions are displayed.
# NOTE(review): the loaded file name says N41 while this grid uses N = 40 --
# confirm the grid matches the prediction locations stored in the file.
N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()

print(preds_sage['gpc_new_probs'].shape)

# for i in range(len(labels)):
#     print(preds[labels[i]].shape)

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.savefig('2a_ground_truth.png',transparent=True)


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y,c=np.argmax(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y,c=entropy(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')

# plt.figure(figsize = (6,2.5),dpi=300)
# plt.subplot(1,2,1)
# plt.scatter(x, y,c=preds_fp[1])
# plt.scatter(xfi[:,0],xfi[:,1],c=yfi,s=10,edgecolors='r',marker='s')
# plt.subplot(1,2,2)
# plt.scatter(x, y,c=preds_fp[4])
# plt.scatter(xfi[:,0],xfi[:,1],c=yfi,s=10,edgecolors='r',marker='s')


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
print(gpc_mean.shape)
gpc_est = np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)
# print(gpc_mean.shape, gpc_est.shape)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
# plt.title('GPC mean')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.title('GPC entropy')
plt.savefig('2a_GPC_N40.png',transparent=True)
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0), s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yfi,edgecolor='r',marker='s')
# plt.title('GPR mean')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0), s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
# plt.title('GPR var');
plt.savefig('2a_GPR_N40.png',transparent=True)

###### Challenge 1: SAGE-ND-PM

In [None]:
import dill
with open(r"2D_2an_structure_matern52_N40_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
preds_st = output['preds_st']
starting_data = output['starting_data']

Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']
xp = Xp[:,0]
yp = Xp[:,1]

N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()

print(preds_sage['gpc_new_probs'].shape)

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2a_ground_truth.png',transparent=True)


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y,c=np.argmax(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y,c=entropy(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
print(gpc_mean.shape)
gpc_est = np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)
# print(gpc_mean.shape, gpc_est.shape)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
# plt.title('GPC mean')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.title('GPC entropy')
plt.savefig('2a_GPC_N40_structure.png',transparent=True)

###### Challenge 1: SAGE-ND-FP

In [None]:
# !!!!!!!!!! 1 INIT !!!!!!!!!
# Using 2Init for 2a MCMC Matern52 - N=40

import dill
with open(r"2D_2an_fp_matern52_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
starting_data = output['starting_data']

# Unpack the inputs that were stored alongside the predictions.
Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']
# Ground-truth coordinates come from the loaded Xp, so the ground-truth
# scatter plots below do not depend on xp / yp left over from another cell.
xp = Xp[:,0]
yp = Xp[:,1]

# Grid on which the predictions are displayed.
N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()

print(preds_sage['gpc_new_probs'].shape)

# for i in range(len(labels)):
#     print(preds[labels[i]].shape)

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2a_ground_truth.png',transparent=True)


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
print(gpc_mean.shape)
gpc_est = -np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)
# print(gpc_mean.shape, gpc_est.shape)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
# plt.title('GPC mean')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.title('GPC entropy')
# plt.savefig('2a_GPC_N40.png',transparent=True)
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0), s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yfi,edgecolor='r',marker='s')
# plt.title('GPR mean')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0), s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
# plt.title('GPR var');
# plt.savefig('2a_GPR_N40.png',transparent=True)

###### Challenge 2: SAGE-ND

In [None]:
# Using Init for 2b MATERN52

import dill
with open(r"2D_2bn_matern52_N40_pred_init_230906a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
preds_st = output['preds_st']
starting_data = output['starting_data']

Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()

print(preds_sage['gpc_new_probs'].shape)

# for i in range(len(labels)):
#     print(preds[labels[i]].shape)

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y,c=np.argmax(preds_st,axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y,c=entropy(preds_st,axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
print(gpc_mean.shape)
gpc_est = np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)
# print(gpc_mean.shape, gpc_est.shape)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
# plt.title('GPC mean')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.title('GPC entropy')
plt.savefig('2b_GPC_N40.png',transparent=True)
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0), s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yfi,edgecolor='r',marker='s')
# plt.title('GPR mean')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0), s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
# plt.title('GPR var');
plt.savefig('2b_GPR_N40.png',transparent=True)  

###### Challenge 2: SAGE-ND-PM

In [None]:
import dill
import sklearn

with open(r"2D_2bn_structure_matern52_N40_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
preds_st = output['preds_st']
starting_data = output['starting_data']

Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']
xp = Xp[:,0]
yp = Xp[:,1]

N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()

print(preds_sage['gpc_new_probs'].shape)

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2a_ground_truth.png',transparent=True)


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y,c=np.argmax(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y,c=entropy(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
print(gpc_mean.shape)
gpc_est = np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)
# print(gpc_mean.shape, gpc_est.shape)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
# plt.title('GPC mean')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.title('GPC entropy')
plt.savefig('2b_GPC_N40_structure.png',transparent=True)


###### Challenge 2: SAGE-ND-FP

In [None]:
# !!!!!!!!!! 1 INIT !!!!!!!!!
# Using 2Init for 2a MCMC Matern52 - N=40

import dill
with open(r"2D_2bn_fp_matern52_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
starting_data = output['starting_data']

# Unpack the inputs that were stored alongside the predictions.
Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']
# Ground-truth coordinates come from the loaded Xp, so the ground-truth
# scatter plots below do not depend on xp / yp left over from another cell.
xp = Xp[:,0]
yp = Xp[:,1]

# Grid on which the predictions are displayed.
N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()

print(preds_sage['gpc_new_probs'].shape)

# for i in range(len(labels)):
#     print(preds[labels[i]].shape)

plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('2a_ground_truth.png',transparent=True)


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
print(gpc_mean.shape)
gpc_est = -np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)
# print(gpc_mean.shape, gpc_est.shape)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
# plt.title('GPC mean')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.title('GPC entropy')
# plt.savefig('2a_GPC_N40.png',transparent=True)
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0), s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yfi,edgecolor='r',marker='s')
# plt.title('GPR mean')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0), s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
# plt.title('GPR var');
# plt.savefig('2a_GPR_N40.png',transparent=True)

###### Challenge 3

In [None]:
import dill
with open(r"2D_2cn_matern52_2ksamples_N40_with_init_230906a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
preds_st = output['preds_st']
starting_data = output['starting_data']

Xp, Lp, f, Xsi_, ysi_, Xfi_, yfi_ = starting_data

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_est = np.argmax(preds_st,axis=1)
gpc_ent = entropy(preds_st,axis=1)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(Xsi_[0][:,0],Xsi_[0][:,1],c=ysi_[0],s=10,edgecolors='r',marker='s')
plt.scatter(Xsi_[1][:,0],Xsi_[1][:,1],c=ysi_[1],s=10,edgecolors='m',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y, c=gpc_ent, s=10)
plt.plot(Xsi_[0][:,0],Xsi_[0][:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4.5)
plt.plot(Xsi_[1][:,0],Xsi_[1][:,1],'s',markerfacecolor="none",markeredgecolor='m',markersize=4.5)


plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
gpc_est = np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)

plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(Xsi_[0][:,0],Xsi_[0][:,1],c=ysi_[0],s=10,edgecolors='r',marker='s')
plt.scatter(Xsi_[1][:,0],Xsi_[1][:,1],c=ysi_[1],s=10,edgecolors='m',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y, c=gpc_ent, s=10)
plt.plot(Xsi_[0][:,0],Xsi_[0][:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4.5)
plt.plot(Xsi_[1][:,0],Xsi_[1][:,1],'s',markerfacecolor="none",markeredgecolor='m',markersize=4.5)
plt.savefig('2c_GPC_N40.png',transparent=True)
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0)[:,0], s=10)
plt.scatter(Xfi_[0][:,0],Xfi_[0][:,1],s=10,c=yfi_[0],edgecolor='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0)[:,0], s=10)
plt.plot(Xfi_[0][:,0],Xfi_[0][:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4.5)
plt.savefig('2c_GPR1_N40.png',transparent=True)
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0)[:,1], s=10)
plt.scatter(Xfi_[1][:,0],Xfi_[1][:,1],s=10,c=yfi_[1],edgecolor='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0)[:,1], s=10)
plt.plot(Xfi_[1][:,0],Xfi_[1][:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4.5)
plt.savefig('2c_GPR2_N40.png',transparent=True)

#### (Bi,Sm)(Sc,Fe)O3 (aka BSF) Challenge

###### Data and Data Visualizations

In [None]:
import dill
import scipy.io as sio
import numpy as np
from matplotlib import pyplot as plt

# Build the (Bi,Sm)(Sc,Fe)O3 (BSF) challenge: load the data, hand-label the
# regions, choose which rows act as structure / functional-property
# measurements, and save those row indices for the model cells that follow.
from numpy.random import default_rng  # this cell's own import list omits it

BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep rows whose 'Ecoer_sub' entry is positive; rows 57 and 16 are dropped by hand.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

plt.figure(figsize = (8,2.5),dpi = 300)
plt.subplot(1,2,1)
plt.scatter(xy[kp,0],xy[kp,1],s=10,c=ecoer[kp])
plt.colorbar()

# x2, f2 and s2 are all indexed by the kept rows only; every index array built
# below (idx_fp, idx_st) refers to this filtered row numbering, not to `xy`.
x2 = xy[kp,:].astype('double')
f2 = ecoer[kp]
# 22.5, 7
# 30.5, 15
# m = (15-7)/(30.5-22.5)
# y = (15-7)/(30.5-22.5)*22.5-15.5
# Hand-assigned labels: 1 by default, 2 to the right of a per-row x threshold,
# and 0 above the line y = 1.2*x - 8 (assigned last, so it overrides the others).
s2 = np.ones((x2.shape[0]))
s2[np.logical_and(x2[:,1]==7, x2[:,0]>22.5)] = 2
s2[np.logical_and(x2[:,1]==9, x2[:,0]>25.5)] = 2
s2[np.logical_and(x2[:,1]==11, x2[:,0]>27)] = 2
s2[np.logical_and(x2[:,1]==15, x2[:,0]>30.5)] = 2

idx = x2[:,1] > 1.2*x2[:,0] - 8
s2[idx] = 0

plt.subplot(1,2,2)
plt.scatter(x2[:,0],x2[:,1],s=10,c=s2)

print(f2.shape)
print(f2.max())

# L[i] is set where the label changes between row i-1 and row i; the hand-picked
# rows in `drop` are cleared before the mark is also spread to the preceding row.
temp = np.abs(np.diff(s2)) > 0.
L = np.zeros(s2.shape)
L[1:] = temp
drop = np.asarray([19,54,103,93])
L[drop]=0
# L[0] is always 0 here, so `L[i-1]` never wraps around to the last row.
for i in range(L.shape[0]):
    if L[i]:
        L[i-1] = 1
L = L > 0

N = 20
seed = 0
# idx_fp = np.asarray([1,4,78,107,136,128,142,152,125,7,40,71,91,112,12,17,36])
# Draw the functional-property rows from the rows of x2. The original drew from
# `x2_`, which this cell never defines (it would come from stale notebook state).
idx_fp = default_rng(seed).choice(x2.shape[0],N,replace=False)
# idx_fp = np.concatenate((idx_fp,np.asarray([0,18])))
idx_st = np.nonzero(L)[0]

plt.figure()
for i in range(x2.shape[0]):
    plt.text(x2[i,0],x2[i,1],str(i))

plt.plot(x2[L,0],x2[L,1],'ro')
plt.plot(x2[:,0],x2[:,1],'k.')
plt.plot(x2[idx_fp,0],x2[idx_fp,1],'r.')

# idx_st / idx_fp index the filtered rows, so select from x2 / f2 (the original
# indexed the unfiltered `xy` and an `f` that is not defined in this cell).
data = [x2[idx_st,:], s2[idx_st], x2[idx_fp,:], f2[idx_fp,0][:,None], xy]

# Persist only the two index arrays; later cells reload them as kp_st, kp_fp.
with open(r"BSF_st_and_fp_samples_231030a.dill", "wb") as output_file:
    dill.dump([idx_st, idx_fp],output_file)

##### Hermes implementation

In [None]:
import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng
import sklearn

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

import sys
sys.path.insert(0, r'C:/Users/gkusne/Documents/GitHub/')
import hermes
from hermes.joint import SAGE_ND

# Run the packaged hermes SAGE_ND model on the BSF challenge and print
# agreement scores against the full hand-labelled data set.
numpyro.set_host_device_count(1)

num_proc = 1
    
# ------ Load data -----------
BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep rows whose 'Ecoer_sub' entry is positive, then drop rows 57 and 16 by hand.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

kp = kp.flatten()  # redundant: kp is already 1-D after the line above
# Xp: kept coordinates; f: kept 'Ecoer_sub' values divided by 500.
Xp = xy[kp,:].astype('double')
f = ecoer[kp]/500.

# Hand-assigned region labels: 1 by default, 2 to the right of a per-row x
# threshold, and 0 above the line y = 1.2*x - 8 (assigned last, so it wins).
Lp = np.ones((Xp.shape[0]))
Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

idx = Xp[:,1] > 1.2*Xp[:,0] - 8
Lp[idx] = 0
# Coordinates are rescaled only after the labels were assigned in original units.
Xp = (Xp-20)/10.

# N = 50
# kp_fp = default_rng(0).permutation(Xp.shape[0])[:N]
# kp_st = default_rng(1).permutation(Xp.shape[0])[:N]

# Row indices (into the filtered rows) of the structure and functional-property
# measurements, as saved by the data-preparation cell.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)

xs = Xp[kp_st,:]
ys = Lp[kp_st]
xf = Xp[kp_fp,:]
yf = f[kp_fp,0][:,None]
starting_data = [Xp, Lp, f, xs, ys, xf, yf]
# NOTE(review): Lp only takes the values 0, 1, 2 and the hand-written cells use
# num_regions = 3, yet num_phase_regions is 4 here -- confirm this is intended.
sage_nd = SAGE_ND(
    num_phase_regions=4,
    num_samples=1000,
    num_warmup=50,
    num_chains = 1,
    target_accept_prob = 0.8,
    max_tree_depth = 5,
    jitter = 1E-6,
    phase_map_SVI_num_steps = 100000,
    Adam_step_size = 0.05,
    posterior_sampling = 1,
    locations_structure = np.asarray(xs),
    locations_functional_property = np.asarray(xf),
    target_structure_labels = np.asarray(ys),
    target_functional_properties = np.asarray(yf),
    locations_prediction = np.asarray(Xp),
    gpc_variance_bounds = np.asarray([5.,10.]),
    gpc_lengthscale_bounds = np.asarray([.1,2.]),
    gpr_variance_bounds = np.asarray([.1, 2.]),
    gpr_lengthscale_bounds = np.asarray([.1,5.]),
    gpr_noise_bounds= np.asarray([0.001,.1]),
    gpr_bias_bounds = np.asarray([-2., 2.]),
    )

sage_nd.run()
predictions_bsf = sage_nd.predictions
sage_pm_est_joint = predictions_bsf['phase_region_labels_mean_estimate']
# Scores are computed over every kept row, including the rows used as measurements.
# NOTE(review): only `import sklearn` appears above; `sklearn.metrics` resolves
# only if the submodule has been loaded -- confirm for the sklearn version in use.
print(sklearn.metrics.r2_score(f[:,0],predictions_bsf['functional_property_mean'].flatten()))
print(sklearn.metrics.f1_score(Lp, sage_pm_est_joint, average='micro'))

In [None]:
# Save the hermes SAGE_ND predictions.
# NOTE(review): the "BSF: SAGE-ND, 1 core" cell writes a different object (its
# `output` dict) to this same file name; whichever cell runs last overwrites it.
with open(r"2D_BSF_1core_231031a.dill", "wb") as output_file:
    dill.dump(predictions_bsf, output_file)

###### BSF: SAGE-ND, multicore

In [None]:
%%writefile sage_2D_BSF_231031a.py

from sage_2D_functions_230804a import predict_SAGE_ND_PM_230628a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

if __name__ == '__main__':
    # Batch size used when splitting posterior samples for jax.pmap below.
    num_proc = 100
    
    # ------ Load data -----------
    BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

    ecoer = BSF['Ecoer_sub']
    R = BSF['X']
    xy = BSF['xy']

    # Keep rows whose 'Ecoer_sub' entry is positive, then drop rows 57 and 16 by hand.
    kp = ecoer > 0
    kp[[57,16]] = False
    kp = kp.flatten()

    kp = kp.flatten()  # redundant: kp is already 1-D after the line above
    # Xp: kept coordinates; f: kept 'Ecoer_sub' values divided by 500.
    Xp = xy[kp,:].astype('double')
    f = ecoer[kp]/500.

    # Hand-assigned region labels: 1 by default, 2 to the right of a per-row x
    # threshold, and 0 above the line y = 1.2*x - 8 (assigned last, so it wins).
    Lp = np.ones((Xp.shape[0]))
    Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
    Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
    Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
    Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

    idx = Xp[:,1] > 1.2*Xp[:,0] - 8
    Lp[idx] = 0
    # Coordinates are rescaled only after the labels were assigned in original units.
    Xp = (Xp-20)/10.

    # N = 50
    # kp_fp = default_rng(0).permutation(Xp.shape[0])[:N]
    # kp_st = default_rng(1).permutation(Xp.shape[0])[:N]
    
    # Row indices (into the filtered rows) of the structure and
    # functional-property measurements, as saved by the data-preparation cell.
    with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
        kp_st, kp_fp = dill.load(input_file)

    xs = Xp[kp_st,:]
    ys = Lp[kp_st]
    xf = Xp[kp_fp,:]
    yf = f[kp_fp,0][:,None]
    # Snapshot of the NumPy inputs, stored alongside the predictions at the end.
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    # Convert the model inputs to JAX arrays (float64 inputs, integer labels).
    xs = jnp.asarray( xs, dtype=jnp.float64).copy()
    ys = jnp.asarray( ys, dtype=jnp.integer).copy()
    xf = jnp.asarray( xf, dtype=jnp.float64).copy()
    yf = jnp.asarray( yf, dtype=jnp.float64).copy()

    Ns = xs.shape[0] + xf.shape[0]

    # Prediction locations: every kept row.
    Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 3

    def predict_structure(post_samples, model, *args, **kwargs):
        """Trace `model` once, conditioned on `post_samples`, and return two sites.

        The model is conditioned on `post_samples`, seeded with PRNGKey(0), and
        executed with `*args` / `**kwargs`. Returns the recorded values of the
        'gpc_new_probs' and 'gpc_new_latent' sites, in that order.
        """
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, jax.random.PRNGKey(0))
        sites = handlers.trace(seeded).get_trace(*args, **kwargs)
        return tuple(sites[name]["value"] for name in ("gpc_new_probs", "gpc_new_latent"))

    def predict_sage(post_samples, model, *args, **kwargs):
        """Trace `model` once, conditioned on `post_samples`, and return five sites.

        The model is conditioned on `post_samples`, seeded with PRNGKey(0), and
        executed with `*args` / `**kwargs`. Returns the recorded values of
        'Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs' and 'v_piecewise',
        in that order.
        """
        wanted = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, jax.random.PRNGKey(0))
        sites = handlers.trace(seeded).get_trace(*args, **kwargs)
        return tuple(sites[name]["value"] for name in wanted)

    def predict_fn_structure(samples):
        """Run the structure-only predictor at Xnew_ for one set of `samples`."""
        return predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        )

    def predict_fn_sage_1core(samples):
        """Run the joint predictor at Xnew_ for one set of `samples` (no pmap)."""
        return predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
        )
    def subsample(samples, step):
        """Return a new dict that keeps every `step`-th element of each entry."""
        return {name: values[::step] for name, values in samples.items()}

    def split_samples(samples, num_proc, length):
        """Split every entry of `samples` into consecutive batches of `num_proc`.

        Args:
            samples: dict mapping a name to a sliceable sequence/array.
            num_proc: number of items per batch.
            length: number of items to consider in each entry.

        Returns:
            A list of int(length / num_proc) independent dicts; dict `i` holds
            items [i*num_proc, (i+1)*num_proc) of every entry. Items beyond the
            last full batch are not included.
        """
        splits = int(length / num_proc)
        sample_list = []
        for i in trange(splits):
            start = i * num_proc
            # Build a fresh dict per batch. Reusing one dict across iterations
            # made every list entry alias the same object (the last batch).
            sample_list.append({k: v[start:start + num_proc] for k, v in samples.items()})
        return sample_list

    def get_samples_split(samples, num_proc, length, i):
        """Return batch `i` (of size `num_proc`) of every entry in `samples`.

        `length` is accepted but not used.
        """
        window = slice(i * num_proc, (i + 1) * num_proc)
        return {name: samples[name][window] for name in samples}
    
    # --- Stage 1: fit the structure-only model with SVI ---------------------
    # Positional arguments forwarded to model_SAGE_ND_PM_230628a by svi.run().
    # The two arrays are presumably the GPC variance and lengthscale bounds.
    # NOTE(review): the second array is [1., 2.] here, while the 1-core cell
    # uses [0.1, 2.] and the NUTS call below passes gpc_ls_bounds=[.1, 2.].
    data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
             jnp.asarray([1.,2.], dtype=jnp.float64)]
        
    key = jax.random.PRNGKey(0)
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.05)

    svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 100000, *data)

    # Median of the fitted guide; used for a structure-only prediction and to
    # initialise the joint NUTS run.
    params = svi_result.params
    mle_2a_st = autoguide_mle.median(params)
    preds_st = predict_fn_structure(mle_2a_st)
    
    # Same call as the line above, unpacked; neither name is used again in this script.
    gpc_new_probs_, gpc_new_latent_ = predict_fn_structure(mle_2a_st)
    
    # Stacks exactly three latents (matches num_regions = 3); not used again in this script.
    gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1'],mle_2a_st['gpc_latent_2']))

    preds_fp = None
    
    # !!!!!!!!!!!!!!!!!!!!!!!!

    # --- Stage 2: joint model sampled with NUTS, started from the SVI medians ---
    # gpc_new_probs_st = predict_fn_st(mle_2a_st)
    init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'], 'gpc_latent_2': mle_2a_st['gpc_latent_2'],
                   'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
                   'gpc_bias': mle_2a_st['gpc_bias']}
    init_strategy=init_to_value(values=init_params)
    
    # The timer covers both NUTS sampling and the prediction loop below.
    tic = time.perf_counter()
    nuts = nMCMC(nNUTS(model_SAGE_ND_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
                num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(key, xs, ys, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
             gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
             gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.01], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()
    
    # Raw posterior samples are saved before thinning.
    import dill
    with open(r"2D_BSF_samples.dill", "wb") as output_file:
        dill.dump(nuts_posterior_samples, output_file)

    # --- Stage 3: thin the samples and predict in batches of num_proc ----------
    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpr_noise'].shape[0]) 

    num_length = samples['gpr_noise'].shape[0]
    
    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')
    
    # Number of full batches; any remainder is not processed.
    splits = np.array(num_length / num_proc).astype(int)
   
    # pmap maps predict_sage over the leading axis of each entry of the batch dict.
    predict_fn_sage = jax.pmap(
        lambda samples: predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew=Xnew_, xs=xs, ys=ys, xf=xf, yf=yf, num_regions=num_regions
        ), axis_name = 0
    )

    print('starting pred analysis, for #', num_length)
    # Names for the five outputs of predict_sage, in return order.
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    # The first batch initialises the dict; later batches are appended with np.vstack.
    for i in trange(splits):
        if i == 0:
            preds = predict_fn_sage(sl[i])
            preds_stacked = {labels[0]:preds[0].squeeze(), labels[1]:preds[1].squeeze(), labels[2]:preds[2].squeeze(), labels[3]:preds[3].squeeze(), labels[4]:preds[4].squeeze()}
        else:
            preds = predict_fn_sage(sl[i])
            for j in range(len(labels)):
                preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")    
                
    print('done pred analysis')
    
    # Bundle the joint predictions, the SVI structure prediction and the inputs.
    output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
    with open(r"2D_BSF_new_pred_231031a.dill", "wb") as output_file:
        dill.dump(output, output_file)

###### BSF: SAGE-ND, 1 core

In [None]:
# %%writefile sage_2D_BSF_231031a.py

from sage_2D_functions_230804a import compare_inputs_jax, predict_SAGE_ND_PM_230628a, predict_SAGE_ND_240712a, predict_SAGE_ND_230628a, model_SAGE_ND_230628a
from sage_2D_functions_230804a import gen_data_2D_example, model_SAGE_ND_FP_230628a, model_SAGE_ND_PM_230628a

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Single-device version of the SAGE-ND BSF script: load and label the data,
# then convert the model inputs to JAX arrays.
numpyro.set_host_device_count(1)

# Batch size used when splitting posterior samples for the prediction loop.
num_proc = 1
    
# ------ Load data -----------
BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep rows whose 'Ecoer_sub' entry is positive, then drop rows 57 and 16 by hand.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

kp = kp.flatten()  # redundant: kp is already 1-D after the line above
# Xp: kept coordinates; f: kept 'Ecoer_sub' values divided by 500.
Xp = xy[kp,:].astype('double')
f = ecoer[kp]/500.

# Hand-assigned region labels: 1 by default, 2 to the right of a per-row x
# threshold, and 0 above the line y = 1.2*x - 8 (assigned last, so it wins).
Lp = np.ones((Xp.shape[0]))
Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

idx = Xp[:,1] > 1.2*Xp[:,0] - 8
Lp[idx] = 0
# Coordinates are rescaled only after the labels were assigned in original units.
Xp = (Xp-20)/10.

# N = 50
# kp_fp = default_rng(0).permutation(Xp.shape[0])[:N]
# kp_st = default_rng(1).permutation(Xp.shape[0])[:N]

# Row indices (into the filtered rows) of the structure and functional-property
# measurements, as saved by the data-preparation cell.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)

xs = Xp[kp_st,:]
ys = Lp[kp_st]
xf = Xp[kp_fp,:]
yf = f[kp_fp,0][:,None]
# Snapshot of the NumPy inputs (taken before the JAX conversion below).
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# unmeasured = np.setdiff1d(np.arange(Xp.shape[0]),np.concatenate((kp_st,kp_fp)))
# exclude_xs = np.setdiff1d(np.arange(Xp.shape[0]),kp_st)

# Convert the model inputs to JAX arrays (float64 inputs, integer labels).
xs = jnp.asarray( xs, dtype=jnp.float64).copy()
ys = jnp.asarray( ys, dtype=jnp.integer).copy()
xf = jnp.asarray( xf, dtype=jnp.float64).copy()
yf = jnp.asarray( yf, dtype=jnp.float64).copy()
Xp = jnp.asarray( Xp, dtype=jnp.float64).copy()
# Prediction locations: every kept row.
Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()

Ns = xs.shape[0] + xf.shape[0]

def identify_X_overlap_with_x(Xnew, xs):
    """Return index arrays relating the rows of `Xnew` to the rows of `xs`.

    The second and third outputs of compare_inputs_jax(Xnew, xs) are passed
    through unchanged (its first output is discarded). The first returned array
    holds the row indices of `Xnew` that are absent from the second output.

    Returns:
        (idx_Xnew_exclude_xs, idx_Xnew_match_xs, idx_xs_match_Xnew)
    """
    _, idx_Xnew_match_xs, idx_xs_match_Xnew = compare_inputs_jax(Xnew, xs)
    all_rows = np.arange(Xnew.shape[0])
    idx_Xnew_exclude_xs = np.setdiff1d(all_rows, idx_Xnew_match_xs)
    return idx_Xnew_exclude_xs, idx_Xnew_match_xs, idx_xs_match_Xnew

# Index bookkeeping passed to predict_SAGE_ND_240712a below: how the prediction
# locations (Xnew_) and the functional-property locations (xf) relate to xs.
idx_Xnew_exclude_xs, idx_Xnew_match_xs, idx_xs_match_Xnew = identify_X_overlap_with_x(Xnew_, xs)
idx_xf_exclude_xs, idx_xf_match_xs, idx_xs_match_xf = identify_X_overlap_with_x(xf, xs)

key = jax.random.PRNGKey(0)
num_regions = 3

def predict_structure(post_samples, model, *args, **kwargs):
    """Trace `model` once, conditioned on `post_samples`, and return two sites.

    The model is conditioned on `post_samples`, seeded with PRNGKey(0), and
    executed with `*args` / `**kwargs`. Returns the recorded values of the
    'gpc_new_probs' and 'gpc_new_latent' sites, in that order.
    """
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, jax.random.PRNGKey(0))
    sites = handlers.trace(seeded).get_trace(*args, **kwargs)
    return tuple(sites[name]["value"] for name in ("gpc_new_probs", "gpc_new_latent"))

def predict_sage(post_samples, model, *args, **kwargs):
    """Trace `model` once, conditioned on `post_samples`, and return five sites.

    The model is conditioned on `post_samples`, seeded with PRNGKey(0), and
    executed with `*args` / `**kwargs`. Returns the recorded values of
    'Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs' and 'v_piecewise',
    in that order.
    """
    wanted = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, jax.random.PRNGKey(0))
    sites = handlers.trace(seeded).get_trace(*args, **kwargs)
    return tuple(sites[name]["value"] for name in wanted)

def predict_fn_structure(samples):
    """Run the structure-only predictor at Xnew_ for one set of `samples`."""
    return predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )

def predict_fn_sage_1core(samples):
    """Run the joint predictor at Xnew_ for one set of `samples` (no pmap)."""
    return predict_sage(
        samples, predict_SAGE_ND_240712a, Xnew_, xs, ys, xf, yf, num_regions=num_regions, idx_Xnew_exclude_xs=idx_Xnew_exclude_xs, idx_Xnew_match_xs=idx_Xnew_match_xs, idx_xs_match_Xnew=idx_xs_match_Xnew, idx_xf_exclude_xs=idx_xf_exclude_xs
    )
def subsample(samples, step):
    """Thin every entry of `samples`, keeping each `step`-th element, in a new dict."""
    thinned = {}
    for name, values in samples.items():
        thinned[name] = values[::step]
    return thinned

def split_samples(samples, num_proc, length):
    """Split every entry of `samples` into consecutive batches of `num_proc`.

    Args:
        samples: dict mapping a name to a sliceable sequence/array.
        num_proc: number of items per batch.
        length: number of items to consider in each entry.

    Returns:
        A list of int(length / num_proc) independent dicts; dict `i` holds
        items [i*num_proc, (i+1)*num_proc) of every entry. Items beyond the
        last full batch are not included.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        start = i * num_proc
        # Build a fresh dict per batch. Reusing one dict across iterations made
        # every list entry alias the same object (the last batch).
        sample_list.append({k: v[start:start + num_proc] for k, v in samples.items()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return batch `i` (of size `num_proc`) of every entry in `samples`.

    `length` is accepted but not used.
    """
    window = slice(i * num_proc, (i + 1) * num_proc)
    return {name: samples[name][window] for name in samples}

# --- Stage 1: fit the structure-only model with SVI -------------------------
# Positional arguments forwarded to model_SAGE_ND_PM_230628a by svi.run().
# The two arrays are presumably the GPC variance and lengthscale bounds.
data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
         jnp.asarray([0.1,2.], dtype=jnp.float64)]

key = jax.random.PRNGKey(0)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.05)

svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

# Median of the fitted guide; used for a structure-only prediction and to
# initialise the joint NUTS run.
params = svi_result.params
mle_2a_st = autoguide_mle.median(params)
preds_st = predict_fn_structure(mle_2a_st)

# Same call as the line above, unpacked; neither name is used again in this cell.
gpc_new_probs_, gpc_new_latent_ = predict_fn_structure(mle_2a_st)

# Stack the per-region latents 'gpc_latent_0' .. 'gpc_latent_{num_regions-1}';
# the result is not used again in this cell.
gpc_latent_ = mle_2a_st['gpc_latent_0']
for i in range(1,num_regions):
    gpc_latent_ = jnp.vstack((gpc_latent_,mle_2a_st['gpc_latent_' + str(i)]))

preds_fp = None

# !!!!!!!!!!!!!!!!!!!!!!!!

# --- Stage 2: joint model sampled with NUTS, started from the SVI medians ------
# gpc_new_probs_st = predict_fn_st(mle_2a_st)
init_params = {'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'], 'gpc_bias': mle_2a_st['gpc_bias']}
for i in range(num_regions):
    init_params['gpc_latent_'+str(i)] = mle_2a_st['gpc_latent_'+str(i)]
init_strategy=init_to_value(values=init_params)

# The timer covers both NUTS sampling and the prediction loop below.
tic = time.perf_counter()
nuts = nMCMC(nNUTS(model_SAGE_ND_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
            num_samples=1000, num_warmup=100, num_chains=1)
nuts.run(key, xs, ys, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
         gpr_ls_bounds = jnp.asarray([.1,5.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.01], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill
# with open(r"2D_BSF_samples.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

# --- Stage 3: predict for every posterior sample (step = 1 keeps all of them) --
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step = 1)
print('after subsampling', samples['gpr_noise'].shape[0]) 

num_length = samples['gpr_noise'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

# Number of full batches; any remainder is not processed.
splits = np.array(num_length / num_proc).astype(int)

# pmap maps predict_sage over the leading axis of each entry of the batch dict.
predict_fn_sage = jax.pmap(
    lambda samples: predict_sage(
        samples, predict_SAGE_ND_240712a, Xnew=Xnew_, xs=xs, ys=ys, xf=xf, yf=yf, num_regions=num_regions, idx_Xnew_exclude_xs=idx_Xnew_exclude_xs, idx_Xnew_match_xs=idx_Xnew_match_xs, idx_xs_match_Xnew=idx_xs_match_Xnew, idx_xf_exclude_xs=idx_xf_exclude_xs
    ), axis_name = 0
)

print('starting pred analysis, for #', num_length)
# Names for the five outputs of predict_sage, in return order.
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

# The first batch initialises the dict; later batches are appended with np.vstack.
for i in trange(splits):
    if i == 0:
        preds = predict_fn_sage(sl[i])
        preds_stacked = {labels[0]:preds[0].squeeze(), labels[1]:preds[1].squeeze(), labels[2]:preds[2].squeeze(), labels[3]:preds[3].squeeze(), labels[4]:preds[4].squeeze()}
    else:
        preds = predict_fn_sage(sl[i])
        for j in range(len(labels)):
            preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")    

print('done pred analysis')

# Restore a leading per-sample axis on 'Fr_new' and 'gpc_new_probs', which the
# vstack above flattened into (samples * locations, last_dim).
# NOTE(review): the 'gpc_new_probs' reshape reuses the last dimension of
# 'Fr_new' (labels[0]) -- confirm the two always share it.
preds_stacked_ = preds_stacked.copy()
preds_stacked_[labels[0]] = preds_stacked[labels[0]].reshape(-1,Xnew_.shape[0],preds_stacked[labels[0]].shape[-1])
preds_stacked_[labels[3]] = preds_stacked[labels[3]].reshape(-1,Xnew_.shape[0],preds_stacked[labels[0]].shape[-1])

output = {'preds': preds_stacked_, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}

with open(r"2D_BSF_1core_231031a.dill", "wb") as output_file:
    dill.dump(output, output_file)
    
    
# This cell never imports sklearn (it previously relied on an earlier cell having
# done so), and a bare `import sklearn` does not guarantee that the `metrics`
# submodule is loaded, so import it explicitly here.
import sklearn.metrics

preds_sage = output['preds']
# R^2 between column 0 of f (all kept rows) and 'f_piecewise' averaged over
# axis 0 with NaNs ignored.
print(sklearn.metrics.r2_score(f[:,0],np.nanmean(preds_sage['f_piecewise'], axis=0).flatten()))

In [None]:
import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng
import sklearn

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

import sys
sys.path.insert(0, r'C:/Users/gkusne/Documents/GitHub/')
import hermes
from hermes.joint import SAGE_ND

# Second hermes SAGE_ND run on the BSF challenge; compared with the earlier
# hermes cell it uses num_warmup=100 and prints only the R^2 score.
numpyro.set_host_device_count(1)

num_proc = 1
    
# ------ Load data -----------
BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep rows whose 'Ecoer_sub' entry is positive, then drop rows 57 and 16 by hand.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

kp = kp.flatten()  # redundant: kp is already 1-D after the line above
# Xp: kept coordinates; f: kept 'Ecoer_sub' values divided by 500.
Xp = xy[kp,:].astype('double')
f = ecoer[kp]/500.

# Hand-assigned region labels: 1 by default, 2 to the right of a per-row x
# threshold, and 0 above the line y = 1.2*x - 8 (assigned last, so it wins).
Lp = np.ones((Xp.shape[0]))
Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

idx = Xp[:,1] > 1.2*Xp[:,0] - 8
Lp[idx] = 0
# Coordinates are rescaled only after the labels were assigned in original units.
Xp = (Xp-20)/10.

# N = 50
# kp_fp = default_rng(0).permutation(Xp.shape[0])[:N]
# kp_st = default_rng(1).permutation(Xp.shape[0])[:N]

# Row indices (into the filtered rows) of the structure and functional-property
# measurements, as saved by the data-preparation cell.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)

xs = Xp[kp_st,:]
ys = Lp[kp_st]
xf = Xp[kp_fp,:]
yf = f[kp_fp,0][:,None]
starting_data = [Xp, Lp, f, xs, ys, xf, yf]
# NOTE(review): Lp only takes the values 0, 1, 2 and the hand-written cells use
# num_regions = 3, yet num_phase_regions is 4 here -- confirm this is intended.
sage_nd = SAGE_ND(
    num_phase_regions=4,
    num_samples=1000,
    num_warmup=100,
    num_chains = 1,
    target_accept_prob = 0.8,
    max_tree_depth = 5,
    jitter = 1E-6,
    phase_map_SVI_num_steps = 100000,
    Adam_step_size = 0.05,
    posterior_sampling = 1,
    locations_structure = np.asarray(xs),
    locations_functional_property = np.asarray(xf),
    target_structure_labels = np.asarray(ys),
    target_functional_properties = np.asarray(yf),
    locations_prediction = np.asarray(Xp),
    gpc_variance_bounds = np.asarray([5.,10.]),
    gpc_lengthscale_bounds = np.asarray([.1,2.]),
    gpr_variance_bounds = np.asarray([.1, 2.]),
    gpr_lengthscale_bounds = np.asarray([.1,5.]),
    gpr_noise_bounds= np.asarray([0.001,.1]),
    gpr_bias_bounds = np.asarray([-2., 2.]),
    )

sage_nd.run()
predictions_bsf = sage_nd.predictions
# R^2 over every kept row, including the rows used as measurements.
print(sklearn.metrics.r2_score(f[:,0],predictions_bsf['functional_property_mean'].flatten()))

In [None]:
# Save the hermes SAGE_ND predictions.
# NOTE(review): the "BSF: SAGE-ND, 1 core" cell writes a different object (its
# `output` dict) to this same file name; whichever cell runs last overwrites it.
with open(r"2D_BSF_1core_231031a.dill", "wb") as output_file:
    dill.dump(predictions_bsf, output_file)

###### BSF: SAGE-ND-FP, multicore

In [None]:
%%writefile sage_BSF_fp_231031a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

numpyro.set_host_device_count(200)

if __name__ == '__main__':
    # Batch size used when splitting posterior samples for jax.pmap below.
    num_proc = 100
    
    # ------ Load data -----------
    BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

    ecoer = BSF['Ecoer_sub']
    R = BSF['X']
    xy = BSF['xy']

    # Keep rows whose 'Ecoer_sub' entry is positive, then drop rows 57 and 16 by hand.
    kp = ecoer > 0
    kp[[57,16]] = False
    kp = kp.flatten()

    kp = kp.flatten()  # redundant: kp is already 1-D after the line above
    # Xp: kept coordinates; f: kept 'Ecoer_sub' values divided by 500.
    Xp = xy[kp,:].astype('double')
    f = ecoer[kp]/500.

    # Hand-assigned region labels: 1 by default, 2 to the right of a per-row x
    # threshold, and 0 above the line y = 1.2*x - 8 (assigned last, so it wins).
    Lp = np.ones((Xp.shape[0]))
    Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
    Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
    Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
    Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

    idx = Xp[:,1] > 1.2*Xp[:,0] - 8
    Lp[idx] = 0
    # Coordinates are rescaled only after the labels were assigned in original units.
    Xp = (Xp-20)/10.
    
    # Row indices (into the filtered rows) of the structure and
    # functional-property measurements, as saved by the data-preparation cell.
    with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
        kp_st, kp_fp = dill.load(input_file)

    xs = Xp[kp_st,:]
    ys = Lp[kp_st]
    xf = Xp[kp_fp,:]
    yf = f[kp_fp,0][:,None]
    # Snapshot of the NumPy inputs, stored alongside the predictions at the end.
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    # Convert to JAX arrays. xs / ys are built as well, although the
    # functional-property-only model below is only given xf and yf.
    xs = jnp.asarray( xs, dtype=jnp.float64).copy()
    ys = jnp.asarray( ys, dtype=jnp.integer).copy()
    xf = jnp.asarray( xf, dtype=jnp.float64).copy()
    yf = jnp.asarray( yf, dtype=jnp.float64).copy()

    Ns = xs.shape[0] + xf.shape[0]

    # Prediction locations: every kept row.
    Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 3

    def predict_fp(post_samples, model, *args, **kwargs):
        """Trace `model` once, conditioned on `post_samples`, and return five sites.

        The model is conditioned on `post_samples`, seeded with PRNGKey(0), and
        executed with `*args` / `**kwargs`. Returns the recorded values of
        'Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs' and 'v_piecewise',
        in that order.
        """
        wanted = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, jax.random.PRNGKey(0))
        sites = handlers.trace(seeded).get_trace(*args, **kwargs)
        return tuple(sites[name]["value"] for name in wanted)

    def subsample(samples, step):
        """Return a new dict that keeps every `step`-th element of each entry."""
        return {name: values[::step] for name, values in samples.items()}

    def split_samples(samples, num_proc, length):
        """Split every entry of `samples` into consecutive batches of `num_proc`.

        Args:
            samples: dict mapping a name to a sliceable sequence/array.
            num_proc: number of items per batch.
            length: number of items to consider in each entry.

        Returns:
            A list of int(length / num_proc) independent dicts; dict `i` holds
            items [i*num_proc, (i+1)*num_proc) of every entry. Items beyond the
            last full batch are not included.
        """
        splits = int(length / num_proc)
        sample_list = []
        for i in trange(splits):
            start = i * num_proc
            # Build a fresh dict per batch. Reusing one dict across iterations
            # made every list entry alias the same object (the last batch).
            sample_list.append({k: v[start:start + num_proc] for k, v in samples.items()})
        return sample_list

    def get_samples_split(samples, num_proc, length, i):
        """Return batch `i` (of size `num_proc`) of every entry in `samples`.

        `length` is accepted but not used.
        """
        first = i * num_proc
        last = first + num_proc
        return {name: values[first:last] for name, values in samples.items()}


    # !!!!!!!!!!!!!!!!!!!!!!!!
    
    # --- Sample the functional-property-only model with NUTS -------------------
    # Unlike the joint script there is no SVI stage and no init_strategy here;
    # the model is given only xf and yf. The timer covers sampling and prediction.
    tic = time.perf_counter()
    nuts = nMCMC(nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5),
                num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(key, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
             gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
             gpr_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()
    
    import dill

    # --- Thin the samples and predict in batches of num_proc --------------------
    print('start', nuts_posterior_samples['gpr_noise'].shape[0])
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpr_noise'].shape[0]) 

    num_length = samples['gpr_noise'].shape[0]
    
    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')
    
    # Number of full batches; any remainder is not processed.
    splits = np.array(num_length / num_proc).astype(int)
   
    # pmap maps predict_fp over the leading axis of each entry of the batch dict.
    predict_fn_sage = jax.pmap(
        lambda samples: predict_fp(
            samples, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
        ), axis_name = 0
    )

    print('starting pred analysis, for #', num_length)
    # Names for the five outputs of predict_fp, in return order.
    labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

    # The first batch initialises the dict; later batches are appended with np.vstack.
    for i in trange(splits):
        if i == 0:
            preds = predict_fn_sage(sl[i])
            preds_stacked = {labels[0]:preds[0].squeeze(), labels[1]:preds[1].squeeze(), labels[2]:preds[2].squeeze(), labels[3]:preds[3].squeeze(), labels[4]:preds[4].squeeze()}
        else:
            preds = predict_fn_sage(sl[i])
            for j in range(len(labels)):
                preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")    
                
    print('done pred analysis')
    
    # Bundle the predictions with the NumPy inputs and save them.
    output = {'preds': preds_stacked, 'starting_data':starting_data}
    with open(r"BSF_fp_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

###### BSF: SAGE-ND-FP, 1 core

In [None]:
# %%writefile sage_BSF_fp_231031a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Ask NumPyro to expose 200 host (CPU) devices to JAX.
# NOTE(review): per the NumPyro docs this only takes effect at program start -- confirm it is
# still effective after the JAX imports/config calls above.
numpyro.set_host_device_count(200)

# Number of posterior samples passed to each pmap call further down (one per device).
num_proc = 1

# ------ Load data -----------
BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep only entries with a positive 'Ecoer_sub' value and explicitly drop rows 57 and 16.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

# NOTE(review): this second flatten is a no-op on the already flattened mask.
kp = kp.flatten()
Xp = xy[kp,:].astype('double')
# Functional property of the kept points, divided by a fixed factor of 500.
f = ecoer[kp]/500.

# Structure labels: every point starts as 1; points on the rows y == 7, 9, 11, 15 whose x
# exceeds a row-specific threshold become 2.
Lp = np.ones((Xp.shape[0]))
Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

# Points above the line y = 1.2*x - 8 become 0; applied last, so it overrides the labels above.
idx = Xp[:,1] > 1.2*Xp[:,0] - 8
Lp[idx] = 0
# Shift and scale the coordinates; the visualization cells undo this with `*10. + 20.`.
Xp = (Xp-20)/10.

# Index sets selecting the structure (kp_st) and functional-property (kp_fp) training points.
# NOTE: dill.load can execute arbitrary code -- only load files from a trusted source.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)

xs = Xp[kp_st,:]
ys = Lp[kp_st]
xf = Xp[kp_fp,:]
yf = f[kp_fp,0][:,None]
# Snapshot of the NumPy inputs, stored next to the predictions when results are saved.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# Convert the training data to JAX arrays.
# NOTE(review): jnp.integer is an abstract scalar type -- confirm it resolves to the intended
# concrete integer dtype on the installed JAX version.
xs = jnp.asarray( xs, dtype=jnp.float64).copy()
ys = jnp.asarray( ys, dtype=jnp.integer).copy()
xf = jnp.asarray( xf, dtype=jnp.float64).copy()
yf = jnp.asarray( yf, dtype=jnp.float64).copy()

# NOTE(review): Ns is not used in the visible part of this cell.
Ns = xs.shape[0] + xf.shape[0]

# Predictions are made at every kept point.
Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


key = jax.random.PRNGKey(0)
num_regions = 3

def predict_fp(post_samples, model, *args, **kwargs):
    """Replay ``model`` with its sample sites fixed to ``post_samples``.

    The model is conditioned on ``post_samples``, seeded with the fixed key
    ``PRNGKey(0)`` and traced once with ``*args`` / ``**kwargs``.

    Returns:
        Tuple with the traced values of the sites ``Fr_new``, ``f_piecewise``,
        ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``, in that order.
    """
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return tuple(trace[name]["value"] for name in site_names)

def subsample(samples, step):
    """Return a new dict in which every entry keeps only each ``step``-th element."""
    return {name: values[::step] for name, values in samples.items()}

def split_samples(samples, num_proc, length):
    """Cut ``samples`` into consecutive batches of ``num_proc`` entries each.

    Args:
        samples: Mapping from site name to a sliceable sequence of draws.
        num_proc: Number of draws per batch.
        length: Number of draws available per site.

    Returns:
        A list of ``length // num_proc`` independent dicts; batch ``i`` holds
        ``samples[k][i*num_proc:(i+1)*num_proc]`` for every key ``k``. Draws
        that do not fill a complete batch are dropped.
    """
    num_batches = int(length // num_proc)
    # A fresh dict is built per batch: reusing a single dict would make every list
    # entry alias the same object, so all batches would show the last slice.
    return [
        {name: values[start:start + num_proc] for name, values in samples.items()}
        for start in range(0, num_batches * num_proc, num_proc)
    ]

def get_samples_split(samples, num_proc, length, i):
    """Return batch ``i`` (``num_proc`` consecutive entries) of every site in ``samples``.

    ``length`` is accepted but not used.
    """
    first = i * num_proc
    last = (i + 1) * num_proc
    return {name: values[first:last] for name, values in samples.items()}


# SAGE-ND-FP on the BSF data with a single device: draw posterior samples with NUTS, replay
# the predictive model once per batch of samples, save the stacked predictions and print the
# R^2 score of the mean piecewise prediction against all `f` values.

# !!!!!!!!!!!!!!!!!!!!!!!!

tic = time.perf_counter()
nuts = nMCMC(nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5),
            num_samples=1000, num_warmup=100, num_chains=1)
nuts.run(key, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
         gpr_ls_bounds = jnp.asarray([.1,5.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill

print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step = 1)
print('after subsampling', samples['gpr_noise'].shape[0])

num_length = samples['gpr_noise'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

splits = np.array(num_length / num_proc).astype(int)

# pmap maps over the leading axis of every entry in a batch, one sample per device.
predict_fn_sage = jax.pmap(
    lambda samples: predict_fp(
        samples, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
    ), axis_name = 0
)

print('starting pred analysis, for #', num_length)
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']


def _squeeze_keep_batch(pred):
    """Drop length-1 axes of ``pred`` except the leading (device/sample) axis."""
    arr = np.asarray(pred)
    return arr.reshape(arr.shape[:1] + tuple(n for n in arr.shape[1:] if n != 1))


# A plain .squeeze() would also remove the leading axis when num_proc == 1, so stacking
# 2-D sites would merge the sample axis into the point axis. Keeping the leading axis makes
# every stacked array (num_samples, ...), as the np.nanmean(..., axis=0) calls expect.
stacked_parts = {label: [] for label in labels}
for i in trange(splits):
    preds = predict_fn_sage(sl[i])
    for label, pred in zip(labels, preds):
        stacked_parts[label].append(_squeeze_keep_batch(pred))
# Concatenate once at the end instead of re-copying the growing array every iteration.
preds_stacked = {label: np.concatenate(parts, axis=0) for label, parts in stacked_parts.items()}
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

output = {'preds': preds_stacked, 'starting_data':starting_data}
with open(r"BSF_fp_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

# with open(r"BSF_fp_matern52_231011a.dill", "rb") as input_file:
#     output = dill.load(input_file)
preds_sage = output['preds']

# NOTE(review): `sklearn` is not imported in this cell; this line relies on an import made elsewhere.
print(sklearn.metrics.r2_score(f[:,0],np.nanmean(preds_sage['f_piecewise'], axis=0).flatten()))

###### BSF: SAGE-ND-PM, multicore

In [None]:
%%writefile sage_BSF_structure_231031a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Ask NumPyro to expose 200 host (CPU) devices to JAX.
# NOTE(review): per the NumPyro docs this only takes effect at program start -- confirm it is
# still effective after the JAX imports/config calls above.
numpyro.set_host_device_count(200)

if __name__ == '__main__':
    # Number of posterior samples passed to each pmap call further down (one per device).
    num_proc = 100
    
    # ------ Load data -----------
    BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

    ecoer = BSF['Ecoer_sub']
    R = BSF['X']
    xy = BSF['xy']

    # Keep only entries with a positive 'Ecoer_sub' value and explicitly drop rows 57 and 16.
    kp = ecoer > 0
    kp[[57,16]] = False
    kp = kp.flatten()

    # NOTE(review): this second flatten is a no-op on the already flattened mask.
    kp = kp.flatten()
    Xp = xy[kp,:].astype('double')
    # Functional property of the kept points, divided by a fixed factor of 500.
    f = ecoer[kp]/500.

    # Structure labels: every point starts as 1; points on the rows y == 7, 9, 11, 15 whose x
    # exceeds a row-specific threshold become 2.
    Lp = np.ones((Xp.shape[0]))
    Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
    Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
    Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
    Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

    # Points above the line y = 1.2*x - 8 become 0; applied last, so it overrides the labels above.
    idx = Xp[:,1] > 1.2*Xp[:,0] - 8
    Lp[idx] = 0
    # Shift and scale the coordinates; the visualization cells undo this with `*10. + 20.`.
    Xp = (Xp-20)/10.

    # N = 50
    # kp_fp = default_rng(0).permutation(Xp.shape[0])[:N]
    # kp_st = default_rng(1).permutation(Xp.shape[0])[:N]
    # Index sets selecting the structure (kp_st) and functional-property (kp_fp) training points.
    # NOTE: dill.load can execute arbitrary code -- only load files from a trusted source.
    with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
        kp_st, kp_fp = dill.load(input_file)

    xs = Xp[kp_st,:]
    ys = Lp[kp_st]
    xf = Xp[kp_fp,:]
    yf = f[kp_fp,0][:,None]
    # Snapshot of the NumPy inputs, stored next to the predictions when results are saved.
    starting_data = [Xp, Lp, f, xs, ys, xf, yf]

    # Convert the training data to JAX arrays.
    # NOTE(review): jnp.integer is an abstract scalar type -- confirm it resolves to the intended
    # concrete integer dtype on the installed JAX version.
    xs = jnp.asarray( xs, dtype=jnp.float64).copy()
    ys = jnp.asarray( ys, dtype=jnp.integer).copy()
    xf = jnp.asarray( xf, dtype=jnp.float64).copy()
    yf = jnp.asarray( yf, dtype=jnp.float64).copy()

    # NOTE(review): Ns is not used in the visible part of this script.
    Ns = xs.shape[0] + xf.shape[0]

    # Predictions are made at every kept point.
    Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


    key = jax.random.PRNGKey(0)
    num_regions = 3
    def predict_structure(post_samples, model, *args, **kwargs):
        """Replay ``model`` with its sample sites fixed to ``post_samples``.

        The model is conditioned on ``post_samples``, seeded with the fixed key
        ``PRNGKey(0)`` and traced once with ``*args`` / ``**kwargs``.

        Returns:
            Tuple with the traced values of the sites ``gpc_new_probs`` and
            ``gpc_new_latent``, in that order.
        """
        rng_key = jax.random.PRNGKey(0)
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        return tuple(trace[name]["value"] for name in ("gpc_new_probs", "gpc_new_latent"))

    def predict_sage(post_samples, model, *args, **kwargs):
        """Replay ``model`` with its sample sites fixed to ``post_samples``.

        The model is conditioned on ``post_samples``, seeded with the fixed key
        ``PRNGKey(0)`` and traced once with ``*args`` / ``**kwargs``.

        Returns:
            Tuple with the traced values of the sites ``Fr_new``, ``f_piecewise``,
            ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``, in that order.
        """
        site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
        rng_key = jax.random.PRNGKey(0)
        conditioned = handlers.condition(model, post_samples)
        seeded = handlers.seed(conditioned, rng_key)
        trace = handlers.trace(seeded).get_trace(*args, **kwargs)
        return tuple(trace[name]["value"] for name in site_names)

    def predict_fn_structure(samples):
        """Structure-only prediction at ``Xnew_`` for one set of site values."""
        return predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        )

    def predict_fn_sage_1core(samples):
        """Joint structure / functional-property prediction at ``Xnew_`` for one set of site values."""
        return predict_sage(
            samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
        )
    def subsample(samples, step):
        """Return a new dict in which every entry keeps only each ``step``-th element."""
        return {name: values[::step] for name, values in samples.items()}

    def split_samples(samples, num_proc, length):
        """Cut ``samples`` into consecutive batches of ``num_proc`` entries each.

        Args:
            samples: Mapping from site name to a sliceable sequence of draws.
            num_proc: Number of draws per batch.
            length: Number of draws available per site.

        Returns:
            A list of ``length // num_proc`` independent dicts; batch ``i`` holds
            ``samples[k][i*num_proc:(i+1)*num_proc]`` for every key ``k``. Draws
            that do not fill a complete batch are dropped.
        """
        num_batches = int(length // num_proc)
        # A fresh dict is built per batch: reusing a single dict would make every list
        # entry alias the same object, so all batches would show the last slice.
        return [
            {name: values[start:start + num_proc] for name, values in samples.items()}
            for start in range(0, num_batches * num_proc, num_proc)
        ]

    def get_samples_split(samples, num_proc, length, i):
        """Return batch ``i`` (``num_proc`` consecutive entries) of every site in ``samples``.

        ``length`` is accepted but not used.
        """
        first = i * num_proc
        last = (i + 1) * num_proc
        return {name: values[first:last] for name, values in samples.items()}
    
    # SAGE-ND-PM on the BSF data, multi-device: fit a variational guide with SVI, use its median
    # to initialise NUTS, then replay the structure predictive model over the thinned posterior
    # samples in batches of `num_proc` and save everything with dill.

    # Positional arguments for the SVI run of model_SAGE_ND_PM_230628a.
    # NOTE(review): the last two entries are presumably gpc_var_bounds and gpc_ls_bounds; the
    # lengthscale bounds here are [1., 2.] while the NUTS run below uses [.1, 2.] -- confirm intended.
    data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
             jnp.asarray([1.,2.], dtype=jnp.float64)]
        
    key = jax.random.PRNGKey(0)
    autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
    optimizer = numpyro.optim.Adam(step_size=0.05)

    svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
    svi_result = svi.run(key, 100000, *data)

    params = svi_result.params
    # Median of the fitted guide, used as a point estimate of the latent sites.
    mle_2a_st = autoguide_mle.median(params)
    preds_st = predict_fn_structure(mle_2a_st)
    
    # predict_structure seeds the model with a fixed PRNGKey(0), so calling it a second time
    # with the same inputs would return the same values; reuse the result computed above.
    gpc_new_probs_, gpc_new_latent_ = preds_st
    
    gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

    # No functional-property prediction is made in this script; stored as a placeholder.
    preds_fp = None
    
    # !!!!!!!!!!!!!!!!!!!!!!!!

    # gpc_new_probs_st = predict_fn_st(mle_2a_st)
    # Start every NUTS chain from the SVI median.
    init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'],
                   'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
                   'gpc_bias': mle_2a_st['gpc_bias']}
    init_strategy=init_to_value(values=init_params)
    
    tic = time.perf_counter()
    nuts = nMCMC(nNUTS(model_SAGE_ND_PM_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
                num_samples=2000, num_warmup=100, num_chains=100)
    nuts.run(key, xs, ys, xf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
             gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64))

    nuts_posterior_samples = nuts.get_samples()
    
    import dill
    # NOTE(review): the file name mentions "2D_2bn" although this script processes the BSF
    # data -- confirm this is the intended output name.
    with open(r"2D_2bn_structure.dill", "wb") as output_file:
        dill.dump(nuts_posterior_samples, output_file)

    print('start', nuts_posterior_samples['gpc_bias'].shape[0])
    # Keep every 10th posterior sample.
    samples = subsample(nuts_posterior_samples, step = 10)
    print('after subsampling', samples['gpc_bias'].shape[0]) 

    num_length = samples['gpc_bias'].shape[0]
    
    print('splitting')
    sl = split_samples(samples, num_proc, num_length)
    print('done splitting')
    
    splits = np.array(num_length / num_proc).astype(int)
   

    # pmap maps over the leading axis of every entry in a batch, one sample per device.
    predict_fn_st_multicore = jax.pmap(
        lambda samples: predict_structure(
            samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
        ), axis_name = 0
    )

    print('starting pred analysis, for #', num_length)
    # Only the first value returned by predict_structure (gpc_new_probs) is accumulated.
    labels = ['gpc_new_probs']

    for i in trange(splits):
        if i == 0:
            preds = predict_fn_st_multicore(sl[i])
            preds_stacked = {labels[0]:preds[0].squeeze()}
        else:
            preds = predict_fn_st_multicore(sl[i])
            for j in range(len(labels)):
                preds_stacked[labels[j]] = np.vstack((preds_stacked[labels[j]],preds[j].squeeze()))
    toc = time.perf_counter()
    print(f"Run in {toc - tic:0.4f} seconds")
                
    print('done pred analysis')
    
    output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
    with open(r"BSF_structure_matern52_231011a.dill", "wb") as output_file:
        dill.dump(output, output_file)

###### BSF: SAGE-ND-PM, 1 core

In [None]:
# %%writefile sage_BSF_structure_231031a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Ask NumPyro to expose 200 host (CPU) devices to JAX.
# NOTE(review): per the NumPyro docs this only takes effect at program start -- confirm it is
# still effective after the JAX imports/config calls above.
numpyro.set_host_device_count(200)

# Number of posterior samples passed to each pmap call further down (one per device).
num_proc = 1

# ------ Load data -----------
BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep only entries with a positive 'Ecoer_sub' value and explicitly drop rows 57 and 16.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

# NOTE(review): this second flatten is a no-op on the already flattened mask.
kp = kp.flatten()
Xp = xy[kp,:].astype('double')
# Functional property of the kept points, divided by a fixed factor of 500.
f = ecoer[kp]/500.

# Structure labels: every point starts as 1; points on the rows y == 7, 9, 11, 15 whose x
# exceeds a row-specific threshold become 2.
Lp = np.ones((Xp.shape[0]))
Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

# Points above the line y = 1.2*x - 8 become 0; applied last, so it overrides the labels above.
idx = Xp[:,1] > 1.2*Xp[:,0] - 8
Lp[idx] = 0
# Shift and scale the coordinates; the visualization cells undo this with `*10. + 20.`.
Xp = (Xp-20)/10.

# N = 50
# kp_fp = default_rng(0).permutation(Xp.shape[0])[:N]
# kp_st = default_rng(1).permutation(Xp.shape[0])[:N]
# Index sets selecting the structure (kp_st) and functional-property (kp_fp) training points.
# NOTE: dill.load can execute arbitrary code -- only load files from a trusted source.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)

xs = Xp[kp_st,:]
ys = Lp[kp_st]
xf = Xp[kp_fp,:]
yf = f[kp_fp,0][:,None]
# Snapshot of the NumPy inputs, stored next to the predictions when results are saved.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

# Convert the training data to JAX arrays.
# NOTE(review): jnp.integer is an abstract scalar type -- confirm it resolves to the intended
# concrete integer dtype on the installed JAX version.
xs = jnp.asarray( xs, dtype=jnp.float64).copy()
ys = jnp.asarray( ys, dtype=jnp.integer).copy()
xf = jnp.asarray( xf, dtype=jnp.float64).copy()
yf = jnp.asarray( yf, dtype=jnp.float64).copy()

# NOTE(review): Ns is not used in the visible part of this cell.
Ns = xs.shape[0] + xf.shape[0]

# Predictions are made at every kept point.
Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


key = jax.random.PRNGKey(0)
num_regions = 3
def predict_structure(post_samples, model, *args, **kwargs):
    """Replay ``model`` with its sample sites fixed to ``post_samples``.

    The model is conditioned on ``post_samples``, seeded with the fixed key
    ``PRNGKey(0)`` and traced once with ``*args`` / ``**kwargs``.

    Returns:
        Tuple with the traced values of the sites ``gpc_new_probs`` and
        ``gpc_new_latent``, in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return tuple(trace[name]["value"] for name in ("gpc_new_probs", "gpc_new_latent"))

def predict_sage(post_samples, model, *args, **kwargs):
    """Replay ``model`` with its sample sites fixed to ``post_samples``.

    The model is conditioned on ``post_samples``, seeded with the fixed key
    ``PRNGKey(0)`` and traced once with ``*args`` / ``**kwargs``.

    Returns:
        Tuple with the traced values of the sites ``Fr_new``, ``f_piecewise``,
        ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``, in that order.
    """
    site_names = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    trace = handlers.trace(seeded).get_trace(*args, **kwargs)
    return tuple(trace[name]["value"] for name in site_names)

def predict_fn_structure(samples):
    """Structure-only prediction at ``Xnew_`` for one set of site values."""
    return predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )


def predict_fn_sage_1core(samples):
    """Joint structure / functional-property prediction at ``Xnew_`` for one set of site values."""
    return predict_sage(
        samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
    )
def subsample(samples, step):
    """Return a new dict in which every entry keeps only each ``step``-th element."""
    return {name: values[::step] for name, values in samples.items()}

def split_samples(samples, num_proc, length):
    """Cut ``samples`` into consecutive batches of ``num_proc`` entries each.

    Args:
        samples: Mapping from site name to a sliceable sequence of draws.
        num_proc: Number of draws per batch.
        length: Number of draws available per site.

    Returns:
        A list of ``length // num_proc`` independent dicts; batch ``i`` holds
        ``samples[k][i*num_proc:(i+1)*num_proc]`` for every key ``k``. Draws
        that do not fill a complete batch are dropped.
    """
    num_batches = int(length // num_proc)
    # A fresh dict is built per batch: reusing a single dict would make every list
    # entry alias the same object, so all batches would show the last slice.
    return [
        {name: values[start:start + num_proc] for name, values in samples.items()}
        for start in range(0, num_batches * num_proc, num_proc)
    ]

def get_samples_split(samples, num_proc, length, i):
    """Return batch ``i`` (``num_proc`` consecutive entries) of every site in ``samples``.

    ``length`` is accepted but not used.
    """
    first = i * num_proc
    last = (i + 1) * num_proc
    return {name: values[first:last] for name, values in samples.items()}

# SAGE-ND-PM on the BSF data with a single device: fit a variational guide with SVI, use its
# median to initialise NUTS, then replay the structure predictive model once per posterior
# sample and save everything with dill.

# Positional arguments for the SVI run of model_SAGE_ND_PM_230628a.
# NOTE(review): the last two entries are presumably gpc_var_bounds and gpc_ls_bounds; the
# lengthscale bounds here are [1., 2.] while the NUTS run below uses [.1, 2.] -- confirm intended.
data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
         jnp.asarray([1.,2.], dtype=jnp.float64)]

key = jax.random.PRNGKey(0)
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.05)

svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

params = svi_result.params
# Median of the fitted guide, used as a point estimate of the latent sites.
mle_2a_st = autoguide_mle.median(params)
preds_st = predict_fn_structure(mle_2a_st)

# predict_structure seeds the model with a fixed PRNGKey(0), so calling it a second time
# with the same inputs would return the same values; reuse the result computed above.
gpc_new_probs_, gpc_new_latent_ = preds_st

gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

# No functional-property prediction is made in this cell; stored as a placeholder.
preds_fp = None

# !!!!!!!!!!!!!!!!!!!!!!!!

# gpc_new_probs_st = predict_fn_st(mle_2a_st)
# Start the NUTS chain from the SVI median.
init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'],
               'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
               'gpc_bias': mle_2a_st['gpc_bias']}
init_strategy=init_to_value(values=init_params)

tic = time.perf_counter()
nuts = nMCMC(nNUTS(model_SAGE_ND_PM_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
            num_samples=1000, num_warmup=100, num_chains=1)
nuts.run(key, xs, ys, xf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill
# with open(r"2D_2bn_structure.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

print('start', nuts_posterior_samples['gpc_bias'].shape[0])
samples = subsample(nuts_posterior_samples, step = 1)
print('after subsampling', samples['gpc_bias'].shape[0])

num_length = samples['gpc_bias'].shape[0]

print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

splits = np.array(num_length / num_proc).astype(int)


# pmap maps over the leading axis of every entry in a batch, one sample per device.
predict_fn_st_multicore = jax.pmap(
    lambda samples: predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    ), axis_name = 0
)

print('starting pred analysis, for #', num_length)
# Only the first value returned by predict_structure (gpc_new_probs) is accumulated.
labels = ['gpc_new_probs']


def _squeeze_keep_batch(pred):
    """Drop length-1 axes of ``pred`` except the leading (device/sample) axis."""
    arr = np.asarray(pred)
    return arr.reshape(arr.shape[:1] + tuple(n for n in arr.shape[1:] if n != 1))


# A plain .squeeze() would also remove the leading axis when num_proc == 1, so stacking the
# 2-D class probabilities would merge the sample axis into the point axis. Keeping the leading
# axis makes the stacked array (num_samples, ...), as the np.nanmean(..., axis=0) calls expect.
stacked_parts = {label: [] for label in labels}
for i in trange(splits):
    preds = predict_fn_st_multicore(sl[i])
    for label, pred in zip(labels, preds):
        stacked_parts[label].append(_squeeze_keep_batch(pred))
# Concatenate once at the end instead of re-copying the growing array every iteration.
preds_stacked = {label: np.concatenate(parts, axis=0) for label, parts in stacked_parts.items()}
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
with open(r"BSF_structure_matern52_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

In [None]:
# Reload the saved structure predictions and print the shape of the stacked class probabilities.
# NOTE: dill.load can execute arbitrary code -- only load files from a trusted source.
with open(r"BSF_structure_matern52_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

t = output['preds']['gpc_new_probs']
print(t.shape)

###### Visualize Results

In [None]:
import dill
from scipy.stats import multivariate_normal, entropy
# Plot ground truth, the SVI point estimate and the sample-averaged predictions for the BSF data.
# Uses whatever `output` dict is currently in the notebook namespace.
# NOTE(review): this cell reads 'f_piecewise' and 'v_piecewise' from output['preds'], so `output`
# must come from a run that stored those sites -- confirm which file is meant to be loaded.
# with open(r"2D_BSF_new_pred_231031a.dill", "rb") as input_file:
#     output = dill.load(input_file)
# output = predictions_bsf
preds_sage = output['preds']
preds_st = output['preds_st']
preds_fp = output['preds_fp']
starting_data = output['starting_data']

# Summaries over axis 0 of the stacked predictions (the stacking axis of the run scripts).
phase_region_labels_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
phase_region_labels_std = np.nanstd(preds_sage['gpc_new_probs'], axis=0)
phase_region_labels_mean_estimate = np.argmax(phase_region_labels_mean,axis=1)
phase_region_labels_mean_entropy = entropy(phase_region_labels_mean,axis=1)
functional_properties_mean = np.nanmean(preds_sage['f_piecewise'], axis=0)
functional_properties_std = np.sqrt( np.nanmean(preds_sage['v_piecewise'], axis=0) )

Xp, Lp, f, xsi, ysi, xfi, yfi = starting_data
# Undo the (x - 20) / 10 scaling applied when the data was prepared.
Xp = Xp*10.+20.
xsi = xsi*10.+20.
xfi = xfi*10.+20.
print(Xp.shape, Lp.shape, f.shape, xsi.shape, ysi.shape, xfi.shape, yfi.shape)

labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

xp = Xp[:,0]
yp = Xp[:,1]
x = xp.copy()
y = yp.copy()


# Figure 1: labels (left) and functional property (right) at all points; red squares mark the
# structure / functional-property training locations.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('BSF_ground_truth.png',transparent=True)
plt.show()

# Figure 2: arg-max label (left) and entropy (right) of the class probabilities in preds_st[0].
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y,c=np.argmax(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y,c=entropy(preds_st[0],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.title('VI approx');
plt.show()

gpc_mean = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
gpc_est = np.argmax(gpc_mean,axis=1)
gpc_ent = entropy(gpc_mean,axis=1)

# Figure 3: mean class probabilities plotted against point index.
plt.figure()
plt.plot(gpc_mean);
plt.show()

# Figure 4: arg-max label (left) and entropy (right) of the mean class probabilities.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=gpc_est, s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(x, y, c=gpc_ent, s=10)
# plt.savefig('BSF_GPC.png',transparent=True)
plt.show()

# Figure 5: mean of 'f_piecewise' (left) and mean of 'v_piecewise' (right).
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(x, y, c=np.nanmean(preds_sage['f_piecewise'], axis=0), s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yfi,edgecolor='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(x, y, c=np.nanmean(preds_sage['v_piecewise'], axis=0), s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
# plt.savefig('BSF_GPR.png',transparent=True)
plt.show()

In [None]:
import dill
from scipy.stats import multivariate_normal, entropy

# Rebuild the BSF inputs from the raw .mat file and plot them next to the entries of `output`.
# NOTE(review): `output` is taken from the notebook namespace and must provide the keys
# 'phase_region_labels_SVI', 'phase_region_labels_mean_estimate',
# 'phase_region_labels_mean_entropy', 'functional_property_mean' and 'functional_property_std';
# the dill files written by the run scripts in this section use different keys -- confirm source.
BSF = sio.loadmat('Raman_with_matched_Ecoercivity_180416a.mat')

ecoer = BSF['Ecoer_sub']
R = BSF['X']
xy = BSF['xy']

# Keep only entries with a positive 'Ecoer_sub' value and explicitly drop rows 57 and 16.
kp = ecoer > 0
kp[[57,16]] = False
kp = kp.flatten()

# NOTE(review): this second flatten is a no-op on the already flattened mask.
kp = kp.flatten()
Xp = xy[kp,:].astype('double')
f = ecoer[kp]/500.

# Structure labels: default 1, 2 beyond row-specific x thresholds, 0 above the line y = 1.2*x - 8.
Lp = np.ones((Xp.shape[0]))
Lp[np.logical_and(Xp[:,1]==7, Xp[:,0]>22.5)] = 2
Lp[np.logical_and(Xp[:,1]==9, Xp[:,0]>25.5)] = 2
Lp[np.logical_and(Xp[:,1]==11, Xp[:,0]>27)] = 2
Lp[np.logical_and(Xp[:,1]==15, Xp[:,0]>30.5)] = 2

idx = Xp[:,1] > 1.2*Xp[:,0] - 8
Lp[idx] = 0
Xp = (Xp-20)/10.

# NOTE: dill.load can execute arbitrary code -- only load files from a trusted source.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)

xs = Xp[kp_st,:]
ys = Lp[kp_st]
xf = Xp[kp_fp,:]
yf = f[kp_fp,0][:,None]

# Coordinates mapped back to the original scale for plotting.
Xpi = Xp*10.+20.
xsi = xs*10.+20.
xfi = xf*10.+20.

xpi = Xpi[:,0]
ypi = Xpi[:,1]

# Figure 1: labels (left) and functional property (right); red squares mark training locations.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xpi,ypi,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xpi,ypi,c=f[:,0],s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('BSF_ground_truth.png',transparent=True)
plt.show()

# Figure 2: arg-max (left) and entropy (right) of output['phase_region_labels_SVI'].
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xpi, ypi,c=np.argmax(output['phase_region_labels_SVI'],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ys,s=10,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(xpi, ypi,c=entropy(output['phase_region_labels_SVI'],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ys,s=10,edgecolors='r',marker='s')
plt.title('VI approx');
plt.show()

# Figure 3: label estimate (left) and its entropy (right) as stored in `output`.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xpi, ypi, c=output['phase_region_labels_mean_estimate'], s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ys,s=20,edgecolors='r',marker='s')
plt.subplot(1,2,2)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(xpi, ypi, c=output['phase_region_labels_mean_entropy'], s=10)
# plt.savefig('BSF_GPC.png',transparent=True)
plt.show()

# Figure 4: functional-property mean (left) and std (right) as stored in `output`.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xpi, ypi, c=output['functional_property_mean'], s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yf,edgecolor='r',marker='s')
plt.subplot(1,2,2)
plt.scatter(xpi, ypi, c=output['functional_property_std'], s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
# plt.savefig('BSF_GPR.png',transparent=True)
plt.show()

#### FeGaPd

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng
import scipy.io as sio

# Load the FeGaPd data set, choose the structure / functional-property training indices and
# plot them on top of the labels and of the functional property.
# NOTE(review): the data path is an absolute, machine-specific location.
FGP = sio.loadmat(r'G:\My Drive\Data\FeGaPd\FeGaPd_full_data_200817a.mat')
C = FGP['C']
f = FGP['Mag_modified']/10.
X = FGP['X']
Xp = FGP['XY']
xp = Xp[:,0]
yp = Xp[:,1]
L = FGP['labels_col'][0][1].astype(int)
# Shift the labels down by one.
L = L - 1

# Report the coordinate range of this data set. The original printed `xy`, which is not
# defined in this cell (it is the coordinate array of the BSF cells), and then opened an
# empty figure; both were leftovers.
print(Xp.min(), Xp.max())

# Hand-picked indices used as structure training points.
# NOTE(review): the list contains repeated indices (189, 235, 268) -- confirm this is intended.
edge = np.asarray([266,267,238,213,189,189,158,159,156,144,153,147, \
                   268,235,216,183,165,89,52,53,40,16,166,119,88,48,15,236,237, \
                  269,235,268,234,180,181,182,168,178,274,131,130,177,275])

kp_st = np.concatenate((edge,[61,200,256,92,93,185,186,215,214]))

# NOTE(review): N is not used in this cell, and kp_fp below holds 46 indices rather than 40.
N = 40
# Hand-picked indices used as functional-property training points.
kp_fp = [ 0,8,13,  14, 19, 20,  23,  27,  32,  35,36,42,  45,  60,  71,  72, 80,  91,  99, 105, 108, 124, 126, 132,
 137,138,139,142, 145, 152, 155, 157, 162, 163, 167, 171, 219, 221, 224, 232, 239, 241, 244, 254, 265, 273 ]
# Indices that appear in both training sets.
print( np.intersect1d(kp_st, kp_fp))
xs = Xp[kp_st,:]
xf = Xp[kp_fp,:]

# Overview: every point annotated with its index; '.' marks structure points, 'x' marks
# functional-property points.
plt.figure(figsize = (10,10))
plt.scatter(Xp[:,0], Xp[:,1],c=L)
for i in range(Xp.shape[0]):
    plt.text(Xp[i,0], Xp[i,1],str(i))
plt.plot(Xp[kp_st,0], Xp[kp_st,1],'r.')
plt.plot(Xp[kp_fp,0], Xp[kp_fp,1],'rx')

# Labels (left) and functional property (right) with the respective training points outlined.
plt.figure(figsize = (6,2.5),dpi=300)
plt.subplot(1,2,1)
plt.scatter(xp,yp,c=L,s=10)
plt.plot(xs[:,0],xs[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
plt.subplot(1,2,2)
plt.scatter(xp,yp,c=f[:,0],s=10)
plt.plot(xf[:,0],xf[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
# plt.savefig('BSF_ground_truth.png',transparent=True)
plt.show()

##### SAGE Joint

In [None]:
import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

import sys
sys.path.insert(0, r'C:/Users/gkusne/Documents/GitHub/')
import hermes
from hermes.joint import SAGE_ND

# Run the joint SAGE_ND model from the `hermes` package on the FeGaPd data and print an R^2
# score for the functional property and a micro-averaged F1 score for the structure labels.
numpyro.set_host_device_count(1)

num_proc = 1
    
# ------ Load data -----------
import scipy.io as sio
import numpy as np
from numpy.random import default_rng

# NOTE(review): the data path is an absolute, machine-specific location.
FGP = sio.loadmat(r'G:\My Drive\Data\FeGaPd\FeGaPd_full_data_200817a.mat')
C = FGP['C']
Mag = FGP['Mag_modified']/10.
X = FGP['X']
Xp = FGP['XY']
L = FGP['labels_col'][0][1].astype(int)
# Shift the labels down by one.
L = L - 1

# Hand-picked indices used as structure training points.
# NOTE(review): the list contains repeated indices (189, 235, 268) -- confirm this is intended.
edge = np.asarray([266,267,238,213,189,189,158,159,156,144,153,147, \
                   268,235,216,183,165,89,52,53,40,16,166,119,88,48,15,236,237, \
                  269,235,268,234,180,181,182,168,178,274,131,130,177,275])

kp_st = np.concatenate((edge,[61,200,256,92,93,185,186,215,214]))

# NOTE(review): N is not used in this cell, and kp_fp below holds 46 indices rather than 40.
N = 40
# Hand-picked indices used as functional-property training points.
kp_fp = [ 0,8,13,  14, 19, 20,  23,  27,  32,  35,36,42,  45,  60,  71,  72, 80,  91,  99, 105, 108, 124, 126, 132,
 137,138,139,142, 145, 152, 155, 157, 162, 163, 167, 171, 219, 221, 224, 232, 239, 241, 244, 254, 265, 273 ]

xs = Xp[kp_st,:]
ys = L[kp_st].flatten()
xf = Xp[kp_fp,:]
yf = Mag[kp_fp]
f = Mag.copy()
Lp = L.copy()

# Configure the joint model: training locations/targets, prediction locations (all points)
# and the bounds passed for the GP-classifier / GP-regressor hyperparameters.
sage_nd = SAGE_ND(
    num_phase_regions=5,
    num_samples=1000,
    num_warmup=100,
    num_chains = 1,
    target_accept_prob = 0.8,
    max_tree_depth = 5,
    jitter = 1E-6,
    phase_map_SVI_num_steps = 100000,
    Adam_step_size = 0.05,
    posterior_sampling = 1,
    locations_structure = np.asarray(xs),
    locations_functional_property = np.asarray(xf),
    target_structure_labels = np.asarray(ys),
    target_functional_properties = np.asarray(yf),
    locations_prediction = np.asarray(Xp),
    gpc_variance_bounds = np.asarray([.1,10.]),
    gpc_lengthscale_bounds = np.asarray([.1,2.]),
    gpr_variance_bounds = np.asarray([.1, 2.]),
    gpr_lengthscale_bounds = np.asarray([.1,2.]),
    gpr_noise_bounds= np.asarray([0.001,.1]),
    gpr_bias_bounds = np.asarray([-2., 2.]),
    )

sage_nd.run()
predictions_fgp = sage_nd.predictions
# NOTE(review): `sklearn` is not imported in this cell; these lines rely on an import made elsewhere.
print(sklearn.metrics.r2_score(f[:,0],predictions_fgp['functional_property_mean'].flatten()))

sage_pm_est_joint = predictions_fgp['phase_region_labels_mean_estimate']
print(sklearn.metrics.f1_score(Lp, sage_pm_est_joint, average='micro'))

##### Visualize Results

In [None]:
import dill
from scipy.stats import multivariate_normal, entropy
import ternary

# Plot the FeGaPd ground truth and the joint SAGE predictions (`predictions_fgp`) on ternary
# axes, saving three of the figures as PNG files.
# NOTE(review): the data path is an absolute, machine-specific location.
FGP = sio.loadmat(r'G:\My Drive\Data\FeGaPd\FeGaPd_full_data_200817a.mat')
C = FGP['C']
Mag = FGP['Mag_modified']/10.
X = FGP['X']
Xp = FGP['XY']
L = FGP['labels_col'][0][1].astype(int)
# Shift the labels down by one.
L = L - 1

# Hand-picked indices used as structure training points.
edge = np.asarray([266,267,238,213,189,189,158,159,156,144,153,147, \
                   268,235,216,183,165,89,52,53,40,16,166,119,88,48,15,236,237, \
                  269,235,268,234,180,181,182,168,178,274,131,130,177,275])

kp_st = np.concatenate((edge,[61,200,256,92,93,185,186,215,214]))

N = 40
# Hand-picked indices used as functional-property training points.
kp_fp = [ 0,8,13,  14, 19, 20,  23,  27,  32,  35,36,42,  45,  60,  71,  72, 80,  91,  99, 105, 108, 124, 126, 132,
 137,138,139,142, 145, 152, 155, 157, 162, 163, 167, 171, 219, 221, 224, 232, 239, 241, 244, 254, 265, 273 ]

xsi = Xp[kp_st,:]
ysi = L[kp_st].flatten()
xfi = Xp[kp_fp,:]
yfi = Mag[kp_fp]
f = Mag.copy()
Lp = L.copy()

output = predictions_fgp

Xpi = Xp.copy()

xpi = Xpi[:,0]
ypi = Xpi[:,1]

# Figure 1: labels (left) and functional property scaled back by 10 (right); red squares mark
# the training locations.
plt.figure(figsize = (6,2.5),dpi=300)
ax1 = plt.subplot(1,2,1)
fig1, tax1 = ternary.figure(ax=ax1, scale=.6)
tax1.boundary(linewidth=2)
tax1.gridlines(color="blue", multiple=.1)
plt.scatter(xpi,ypi,c=Lp,s=10)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
tax1.clear_matplotlib_ticks()
tax1.get_axes().axis('off')
# plt.colorbar()
ax2 = plt.subplot(1,2,2)
fig2, tax2 = ternary.figure(ax = ax2, scale=.6)
tax2.boundary(linewidth=2)
tax2.gridlines(color="blue", multiple=.1)
plt.scatter(xpi,ypi,c=f[:,0]*10.,s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor='none',markeredgecolor='r',markersize=4.5)
tax2.clear_matplotlib_ticks()
tax2.get_axes().axis('off')
plt.colorbar()
plt.savefig('FGP_ground_truth.png',transparent=True)
plt.show()

# Figure 2: arg-max (left) and entropy (right) of output['phase_region_labels_SVI'].
# The training markers are coloured with this cell's own `ysi` / `yfi` (same contents as the
# `ys` / `yf` of the run cell) so the cell does not depend on variables left over elsewhere.
plt.figure(figsize = (6,2.5),dpi=300)
ax1 = plt.subplot(1,2,1)
fig1, tax1 = ternary.figure(ax=ax1, scale=.6)
tax1.boundary(linewidth=2)
tax1.gridlines(color="blue", multiple=.1)
plt.scatter(xpi, ypi,c=np.argmax(output['phase_region_labels_SVI'],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
tax1.clear_matplotlib_ticks()
tax1.get_axes().axis('off')
ax2 = plt.subplot(1,2,2)
fig2, tax2 = ternary.figure(ax=ax2, scale=.6)
tax2.boundary(linewidth=2)
tax2.gridlines(color="blue", multiple=.1)
plt.scatter(xpi, ypi,c=entropy(output['phase_region_labels_SVI'],axis=1),s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=10,edgecolors='r',marker='s')
plt.title('VI approx');
tax2.clear_matplotlib_ticks()
tax2.get_axes().axis('off')
plt.show()

# Figure 3: label estimate (left) and its entropy (right) as stored in `output`.
plt.figure(figsize = (6,2.5),dpi=300)
ax1 = plt.subplot(1,2,1)
fig1, tax1 = ternary.figure(ax=ax1, scale=.6)
tax1.boundary(linewidth=2)
tax1.gridlines(color="blue", multiple=.1)
plt.scatter(xpi, ypi, c=output['phase_region_labels_mean_estimate'], s=10)
plt.scatter(xsi[:,0],xsi[:,1],c=ysi,s=20,edgecolors='r',marker='s')
tax1.clear_matplotlib_ticks()
tax1.get_axes().axis('off')
ax2 = plt.subplot(1,2,2)
fig2, tax2 = ternary.figure(ax=ax2, scale=.6)
tax2.boundary(linewidth=2)
tax2.gridlines(color="blue", multiple=.1)
plt.plot(xsi[:,0],xsi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
plt.scatter(xpi, ypi, c=output['phase_region_labels_mean_entropy'], s=10)
tax2.clear_matplotlib_ticks()
tax2.get_axes().axis('off')
plt.savefig('FGP_GPC.png',transparent=True)
plt.show()

# Figure 4: functional-property mean (left) and std (right) as stored in `output`.
plt.figure(figsize = (6,2.5),dpi=300)
ax1 = plt.subplot(1,2,1)
fig1, tax1 = ternary.figure(ax=ax1, scale=.6)
tax1.boundary(linewidth=2)
tax1.gridlines(color="blue", multiple=.1)
plt.scatter(xpi, ypi, c=output['functional_property_mean'], s=10)
plt.scatter(xfi[:,0],xfi[:,1],s=20,c=yfi,edgecolor='r',marker='s')
tax1.clear_matplotlib_ticks()
tax1.get_axes().axis('off')
ax2 = plt.subplot(1,2,2)
fig2, tax2 = ternary.figure(ax=ax2, scale=.6)
tax2.boundary(linewidth=2)
tax2.gridlines(color="blue", multiple=.1)
plt.scatter(xpi, ypi, c=output['functional_property_std'], s=10)
plt.plot(xfi[:,0],xfi[:,1],'s',markerfacecolor="none",markeredgecolor='r',markersize=4)
tax2.clear_matplotlib_ticks()
tax2.get_axes().axis('off')
plt.savefig('FGP_GPR.png',transparent=True)
plt.show()


##### SAGE-FP

In [None]:
# %%writefile sage_BSF_fp_231031a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Request 200 host (CPU) devices from numpyro/JAX so that jax.pmap has devices to map over.
numpyro.set_host_device_count(200)

# Chunk size used by split_samples: number of posterior draws passed to each pmap call.
num_proc = 1

# ------ Load data -----------
# NOTE(review): absolute, machine-specific path; the cell only runs where this drive is mounted.
FGP = sio.loadmat(r'G:\My Drive\Data\FeGaPd\FeGaPd_full_data_200817a.mat')
C = FGP['C']
# Functional property, scaled down by a factor of 10.
Mag = FGP['Mag_modified']/10.
X = FGP['X']
# Xp holds the coordinates used both for training points and for prediction (see Xnew_ below).
Xp = FGP['XY']
L = FGP['labels_col'][0][1].astype(int)
# Shift labels down by one (presumably 1-based in the .mat file -- TODO confirm).
L = L - 1

# Row indices of the structure-labelled samples.
# NOTE(review): 189, 235 and 268 each appear twice in `edge`, so those rows are duplicated in xs/ys.
edge = np.asarray([266,267,238,213,189,189,158,159,156,144,153,147, \
                   268,235,216,183,165,89,52,53,40,16,166,119,88,48,15,236,237, \
                  269,235,268,234,180,181,182,168,178,274,131,130,177,275])

kp_st = np.concatenate((edge,[61,200,256,92,93,185,186,215,214]))

# NOTE(review): N is not referenced again in the visible part of this cell.
N = 40
# Row indices of the samples with a functional-property measurement.
kp_fp = [ 0,8,13,  14, 19, 20,  23,  27,  32,  35,36,42,  45,  60,  71,  72, 80,  91,  99, 105, 108, 124, 126, 132,
 137,138,139,142, 145, 152, 155, 157, 162, 163, 167, 171, 219, 221, 224, 232, 239, 241, 244, 254, 265, 273 ]

# Training subsets: (xs, ys) for structure labels, (xf, yf) for the functional property.
xs = Xp[kp_st,:]
ys = L[kp_st].flatten()
xf = Xp[kp_fp,:]
yf = Mag[kp_fp]
# Full-data copies kept as ground truth for scoring.
f = Mag.copy()
Lp = L.copy()

# Convert the training subsets to JAX arrays for numpyro.
xs = jnp.asarray( xs, dtype=jnp.float64).copy()
ys = jnp.asarray( ys, dtype=jnp.integer).copy()
xf = jnp.asarray( xf, dtype=jnp.float64).copy()
yf = jnp.asarray( yf, dtype=jnp.float64).copy()

# NOTE(review): Ns is not referenced again in the visible part of this cell.
Ns = xs.shape[0] + xf.shape[0]

# Prediction locations: every row of Xp.
Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


# Fixed seed for MCMC; num_regions is passed to the model and predictive functions below.
key = jax.random.PRNGKey(0)
num_regions = 5

def predict_fp(post_samples, model, *args, **kwargs):
    """Trace ``model`` conditioned on ``post_samples`` and return its predictive sites.

    The model is conditioned on the posterior draw(s), seeded with the fixed
    key ``PRNGKey(0)`` and executed once with ``*args``/``**kwargs``.

    Returns:
        A 5-tuple with the traced values of the sites ``Fr_new``,
        ``f_piecewise``, ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``,
        in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    sites = handlers.trace(seeded).get_trace(*args, **kwargs)
    wanted = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(sites[name]["value"] for name in wanted)

def subsample(samples, step):
    """Thin every entry of ``samples`` by keeping each ``step``-th element.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        step: slicing stride; ``1`` keeps every draw.

    Returns:
        A new dict with the same keys and the thinned sequences.
    """
    return {name: samples[name][::step] for name in samples.keys()}

def split_samples(samples, num_proc, length):
    """Split posterior samples into consecutive chunks of ``num_proc`` draws.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        num_proc: number of draws per chunk.
        length: total number of draws; ``int(length / num_proc)`` chunks are
            produced, so a trailing partial chunk is dropped.

    Returns:
        A list of dicts; entry ``i`` maps every key of ``samples`` to the
        slice ``[i * num_proc:(i + 1) * num_proc]``.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        start = i * num_proc
        stop = start + num_proc
        # A fresh dict is built for every chunk. Re-using one dict across
        # iterations made every list entry alias the same object, so all
        # entries ended up holding the last chunk.
        sample_list.append({k: samples[k][start:stop] for k in samples.keys()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return the ``i``-th chunk of ``num_proc`` consecutive draws.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        num_proc: number of draws per chunk.
        length: accepted for signature compatibility; not used.
        i: zero-based chunk index.

    Returns:
        A dict mapping every key of ``samples`` to the slice
        ``[i * num_proc:(i + 1) * num_proc]``.
    """
    lo = i * num_proc
    hi = (i + 1) * num_proc
    return {name: samples[name][lo:hi] for name in samples.keys()}


# !!!!!!!!!!!!!!!!!!!!!!!!

# Start the wall-clock timer; it is read after the prediction loop further down.
tic = time.perf_counter()
# NUTS on the functional-property-only model: 100 warm-up steps, 1000 draws, one chain.
nuts = nMCMC(nNUTS(model_SAGE_ND_FP_230628a, target_accept_prob=0.8, max_tree_depth=5),
            num_samples=1000, num_warmup=100, num_chains=1)
# The *_bounds arguments are [low, high] pairs handed to the model
# (presumably hyper-parameter prior bounds -- TODO confirm against the model definition).
nuts.run(key, xf, yf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64), gpr_var_bounds = jnp.asarray([.1, 2.], dtype=jnp.float64),
         gpr_ls_bounds = jnp.asarray([.1,5.], dtype=jnp.float64), gpr_noise_bounds = jnp.asarray([0.001,.1], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill

# step = 1 keeps every draw, so the two printed counts are equal.
print('start', nuts_posterior_samples['gpr_noise'].shape[0])
samples = subsample(nuts_posterior_samples, step = 1)
print('after subsampling', samples['gpr_noise'].shape[0]) 

num_length = samples['gpr_noise'].shape[0]

# Cut the draws into chunks of num_proc; sl[i] is the input of the i-th pmap call.
print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

# Number of chunks (same formula as inside split_samples).
splits = np.array(num_length / num_proc).astype(int)

# pmap maps predict_fp over the leading axis of each chunk of draws.
predict_fn_sage = jax.pmap(
    lambda samples: predict_fp(
        samples, predict_SAGE_ND_FP_230628a, Xnew=Xnew_, xf=xf, yf=yf, num_regions=num_regions
    ), axis_name = 0
)

print('starting pred analysis, for #', num_length)
# Must stay in the same order as the tuple returned by predict_fp.
labels = ['Fr_new', 'f_piecewise', 'f_sample', 'gpc_new_probs', 'v_piecewise']

# Evaluate the predictive model chunk by chunk. The squeezed outputs are
# gathered per label and stacked once after the loop: the result is the same as
# growing each array with np.vstack on every iteration, without re-copying the
# accumulated array each time and without a special case for the first chunk.
pred_chunks = {label: [] for label in labels}
for i in trange(splits):
    preds = predict_fn_sage(sl[i])
    # zip pairs every label with the output at the same tuple position.
    for label, pred in zip(labels, preds):
        pred_chunks[label].append(pred.squeeze())
# With a single chunk the squeezed output is kept as is (no extra leading axis),
# matching what the previous first-iteration branch produced.
preds_stacked = {
    label: chunks[0] if len(chunks) == 1 else np.vstack(chunks)
    for label, chunks in pred_chunks.items()
}
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")    

print('done pred analysis')

# Bundle the inputs of this run with the predictions. `starting_data` was never
# assigned in this script, so the dump either raised NameError or silently
# stored a stale variable left over from another notebook cell. The order
# matches how the analysis cells unpack it: Xp, Lp, f, xs, ys, xf, yf.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

output = {'preds': preds_stacked, 'starting_data':starting_data}
with open(r"FGP_fp_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

# with open(r"BSF_fp_matern52_231011a.dill", "rb") as input_file:
#     output = dill.load(input_file)
preds_sage = output['preds']

# sklearn is not imported anywhere in this script; import it where it is used.
import sklearn.metrics

# R^2 of the posterior-mean piecewise prediction against the full measured property.
print(sklearn.metrics.r2_score(f[:,0],np.nanmean(preds_sage['f_piecewise'], axis=0).flatten()))

##### SAGE-PM

In [None]:
# %%writefile sage_BSF_structure_231031a.py

from sage_2D_functions_230804a import *

import matplotlib.pyplot as plt

import numpyro
import numpy as np
from numpy.random import default_rng

import torch
from torch.distributions import constraints
import scipy.io as sio

import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True)
from jax.lax import dynamic_slice
from jax.nn import one_hot as jax_one_hot

from numpyro.infer import MCMC as nMCMC
from numpyro.infer import NUTS as nNUTS
from numpyro.infer import SVI as nSVI
from numpyro.infer import Predictive as nPredictive
from numpyro.infer import Trace_ELBO as nTrace_ELBO
import numpyro.distributions as ndist

torch.set_default_dtype(torch.float64)
from tqdm import trange
import dill
from tqdm import tqdm, trange
from torch.multiprocessing import Pool, Manager, Process
import functools
import math
import time
from numpyro import handlers
from numpyro.infer.initialization import init_to_value

# Request 200 host (CPU) devices from numpyro/JAX so that jax.pmap has devices to map over.
numpyro.set_host_device_count(200)

# Chunk size used by split_samples: number of posterior draws passed to each pmap call.
num_proc = 1

# ------ Load data -----------
# NOTE(review): absolute, machine-specific path; the cell only runs where this drive is mounted.
FGP = sio.loadmat(r'G:\My Drive\Data\FeGaPd\FeGaPd_full_data_200817a.mat')
C = FGP['C']
# Functional property, scaled down by a factor of 10.
Mag = FGP['Mag_modified']/10.
X = FGP['X']
# Xp holds the coordinates used both for training points and for prediction (see Xnew_ below).
Xp = FGP['XY']
L = FGP['labels_col'][0][1].astype(int)
# Shift labels down by one (presumably 1-based in the .mat file -- TODO confirm).
L = L - 1

# Row indices of the structure-labelled samples.
# NOTE(review): 189, 235 and 268 each appear twice in `edge`, so those rows are duplicated in xs/ys.
edge = np.asarray([266,267,238,213,189,189,158,159,156,144,153,147, \
                   268,235,216,183,165,89,52,53,40,16,166,119,88,48,15,236,237, \
                  269,235,268,234,180,181,182,168,178,274,131,130,177,275])

kp_st = np.concatenate((edge,[61,200,256,92,93,185,186,215,214]))

# NOTE(review): N is not referenced again in the visible part of this cell.
N = 40
# Row indices of the samples with a functional-property measurement.
kp_fp = [ 0,8,13,  14, 19, 20,  23,  27,  32,  35,36,42,  45,  60,  71,  72, 80,  91,  99, 105, 108, 124, 126, 132,
 137,138,139,142, 145, 152, 155, 157, 162, 163, 167, 171, 219, 221, 224, 232, 239, 241, 244, 254, 265, 273 ]


# Training subsets: (xs, ys) for structure labels, (xf, yf) for the functional property.
xs = Xp[kp_st,:]
ys = L[kp_st].flatten()
xf = Xp[kp_fp,:]
yf = Mag[kp_fp]
# Full-data copies kept as ground truth for scoring.
f = Mag.copy()
Lp = L.copy()

# Convert the training subsets to JAX arrays for numpyro.
xs = jnp.asarray( xs, dtype=jnp.float64).copy()
ys = jnp.asarray( ys, dtype=jnp.integer).copy()
xf = jnp.asarray( xf, dtype=jnp.float64).copy()
yf = jnp.asarray( yf, dtype=jnp.float64).copy()

# NOTE(review): Ns is not referenced again in the visible part of this cell.
Ns = xs.shape[0] + xf.shape[0]

# Prediction locations: every row of Xp.
Xnew_ = jnp.asarray( Xp, dtype=jnp.float64).copy()


# Fixed seed for SVI/MCMC; num_regions is passed to the model and predictive functions below.
key = jax.random.PRNGKey(0)
num_regions = 5
def predict_structure(post_samples, model, *args, **kwargs):
    """Trace ``model`` conditioned on ``post_samples`` and return the structure sites.

    The model is conditioned on the posterior draw(s), seeded with the fixed
    key ``PRNGKey(0)`` and executed once with ``*args``/``**kwargs``.

    Returns:
        A 2-tuple with the traced values of the sites ``gpc_new_probs`` and
        ``gpc_new_latent``, in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    sites = handlers.trace(seeded).get_trace(*args, **kwargs)
    probs = sites["gpc_new_probs"]["value"]
    latent = sites["gpc_new_latent"]["value"]
    return probs, latent

def predict_sage(post_samples, model, *args, **kwargs):
    """Trace ``model`` conditioned on ``post_samples`` and return its predictive sites.

    The model is conditioned on the posterior draw(s), seeded with the fixed
    key ``PRNGKey(0)`` and executed once with ``*args``/``**kwargs``.

    Returns:
        A 5-tuple with the traced values of the sites ``Fr_new``,
        ``f_piecewise``, ``f_sample``, ``gpc_new_probs`` and ``v_piecewise``,
        in that order.
    """
    rng_key = jax.random.PRNGKey(0)
    conditioned = handlers.condition(model, post_samples)
    seeded = handlers.seed(conditioned, rng_key)
    sites = handlers.trace(seeded).get_trace(*args, **kwargs)
    wanted = ("Fr_new", "f_piecewise", "f_sample", "gpc_new_probs", "v_piecewise")
    return tuple(sites[name]["value"] for name in wanted)

# Single-core predictor for the structure-only model: binds the prediction
# grid and the structure training data, leaving only the posterior draw(s) free.
predict_fn_structure = lambda samples: predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    )
# Single-core predictor for the joint model.
# NOTE(review): not called in the visible part of this cell.
predict_fn_sage_1core = lambda samples: predict_sage(
        samples, predict_SAGE_ND_230628a, Xnew_, xs, ys, xf, yf, num_regions=num_regions
    )
def subsample(samples, step):
    """Thin every entry of ``samples`` by keeping each ``step``-th element.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        step: slicing stride; ``1`` keeps every draw.

    Returns:
        A new dict with the same keys and the thinned sequences.
    """
    return {name: samples[name][::step] for name in samples.keys()}

def split_samples(samples, num_proc, length):
    """Split posterior samples into consecutive chunks of ``num_proc`` draws.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        num_proc: number of draws per chunk.
        length: total number of draws; ``int(length / num_proc)`` chunks are
            produced, so a trailing partial chunk is dropped.

    Returns:
        A list of dicts; entry ``i`` maps every key of ``samples`` to the
        slice ``[i * num_proc:(i + 1) * num_proc]``.
    """
    splits = int(length / num_proc)
    sample_list = []
    for i in trange(splits):
        start = i * num_proc
        stop = start + num_proc
        # A fresh dict is built for every chunk. Re-using one dict across
        # iterations made every list entry alias the same object, so all
        # entries ended up holding the last chunk.
        sample_list.append({k: samples[k][start:stop] for k in samples.keys()})
    return sample_list

def get_samples_split(samples, num_proc, length, i):
    """Return the ``i``-th chunk of ``num_proc`` consecutive draws.

    Args:
        samples: mapping from site name to a sliceable sequence of draws.
        num_proc: number of draws per chunk.
        length: accepted for signature compatibility; not used.
        i: zero-based chunk index.

    Returns:
        A dict mapping every key of ``samples`` to the slice
        ``[i * num_proc:(i + 1) * num_proc]``.
    """
    lo = i * num_proc
    hi = (i + 1) * num_proc
    return {name: samples[name][lo:hi] for name in samples.keys()}

# Positional arguments for the structure-only model during the SVI warm start.
# NOTE(review): the last entry (presumably the GPC length-scale bounds) is
# [1., 2.] here, while the NUTS run further down passes gpc_ls_bounds=[.1, 2.]
# -- confirm the difference is intended.
data = [xs, ys, xf, num_regions, jnp.asarray([5.,10.], dtype=jnp.float64), 
         jnp.asarray([1.,2.], dtype=jnp.float64)]

key = jax.random.PRNGKey(0)
# Variational fit with a low-rank multivariate normal guide; its median is used
# below as a point estimate of the latent variables.
autoguide_mle = numpyro.infer.autoguide.AutoLowRankMultivariateNormal(model_SAGE_ND_PM_230628a)
optimizer = numpyro.optim.Adam(step_size=0.05)

svi = nSVI(model_SAGE_ND_PM_230628a, autoguide_mle, optimizer, loss=nTrace_ELBO())
svi_result = svi.run(key, 100000, *data)

params = svi_result.params
mle_2a_st = autoguide_mle.median(params)
preds_st = predict_fn_structure(mle_2a_st)

# predict_structure always seeds with PRNGKey(0), so a second call with the
# same input returns the same values; reuse the result instead of running the
# predictive model twice.
gpc_new_probs_, gpc_new_latent_ = preds_st

gpc_latent_ = jnp.vstack((mle_2a_st['gpc_latent_0'],mle_2a_st['gpc_latent_1']))

# No functional-property predictions are produced by this structure-only script.
preds_fp = None

# !!!!!!!!!!!!!!!!!!!!!!!!

# gpc_new_probs_st = predict_fn_st(mle_2a_st)
# Start NUTS from the SVI median instead of a random initial point.
init_params = {'gpc_latent_0': mle_2a_st['gpc_latent_0'], 'gpc_latent_1': mle_2a_st['gpc_latent_1'],
               'gpc_var': mle_2a_st['gpc_var'],'gpc_lengthscale': mle_2a_st['gpc_lengthscale'],
               'gpc_bias': mle_2a_st['gpc_bias']}
init_strategy=init_to_value(values=init_params)

# Start the wall-clock timer; it is read after the prediction loop further down.
tic = time.perf_counter()
# NUTS on the structure-only model: 100 warm-up steps, 1000 draws, one chain.
nuts = nMCMC(nNUTS(model_SAGE_ND_PM_230628a, target_accept_prob=0.8, max_tree_depth=5, init_strategy=init_strategy),
            num_samples=1000, num_warmup=100, num_chains=1)
nuts.run(key, xs, ys, xf, num_regions, gpc_var_bounds = jnp.asarray([5.,10.], dtype=jnp.float64),
         gpc_ls_bounds = jnp.asarray([.1,2.], dtype=jnp.float64))

nuts_posterior_samples = nuts.get_samples()

import dill
# with open(r"2D_2bn_structure.dill", "wb") as output_file:
#     dill.dump(nuts_posterior_samples, output_file)

# step = 1 keeps every draw, so the two printed counts are equal.
print('start', nuts_posterior_samples['gpc_bias'].shape[0])
samples = subsample(nuts_posterior_samples, step = 1)
print('after subsampling', samples['gpc_bias'].shape[0]) 

num_length = samples['gpc_bias'].shape[0]

# Cut the draws into chunks of num_proc; sl[i] is the input of the i-th pmap call.
print('splitting')
sl = split_samples(samples, num_proc, num_length)
print('done splitting')

# Number of chunks (same formula as inside split_samples).
splits = np.array(num_length / num_proc).astype(int)


# pmap maps predict_structure over the leading axis of each chunk of draws.
predict_fn_st_multicore = jax.pmap(
    lambda samples: predict_structure(
        samples, predict_SAGE_ND_PM_230628a, Xnew=Xnew_, xs=xs, ys=ys, num_regions=num_regions
    ), axis_name = 0
)

print('starting pred analysis, for #', num_length)
# Only the first output of predict_structure (gpc_new_probs) is collected below;
# gpc_new_latent is discarded.
labels = ['gpc_new_probs']

# Evaluate the predictive model chunk by chunk. The squeezed outputs are
# gathered per label and stacked once after the loop: the result is the same as
# growing each array with np.vstack on every iteration, without re-copying the
# accumulated array each time and without a special case for the first chunk.
pred_chunks = {label: [] for label in labels}
for i in trange(splits):
    preds = predict_fn_st_multicore(sl[i])
    # zip stops at the shorter sequence, so only the outputs that have a label
    # (here: gpc_new_probs) are kept.
    for label, pred in zip(labels, preds):
        pred_chunks[label].append(pred.squeeze())
# With a single chunk the squeezed output is kept as is (no extra leading axis),
# matching what the previous first-iteration branch produced.
preds_stacked = {
    label: chunks[0] if len(chunks) == 1 else np.vstack(chunks)
    for label, chunks in pred_chunks.items()
}
toc = time.perf_counter()
print(f"Run in {toc - tic:0.4f} seconds")

print('done pred analysis')

# Bundle the inputs of this run with the predictions. `starting_data` was never
# assigned in this script, so the dump either raised NameError or silently
# stored a stale variable left over from another notebook cell. The order
# matches how the analysis cells unpack it: Xp, Lp, f, xs, ys, xf, yf.
starting_data = [Xp, Lp, f, xs, ys, xf, yf]

output = {'preds': preds_stacked, 'preds_st':preds_st, 'preds_fp':preds_fp, 'starting_data':starting_data}
with open(r"FGP_structure_matern52_231011a.dill", "wb") as output_file:
    dill.dump(output, output_file)

### Performance calculations

##### Examples

In [None]:
from matplotlib import pyplot as plt
import dill
import GPy
import sklearn
from sklearn import metrics
import scipy
import tensorflow as tf
import tensorflow_probability as tfp
import gpflow
f64 = gpflow.utilities.to_default_float
from gpflow.ci_utils import ci_niter
import numpy as np
import torch
torch.set_default_dtype(torch.float64)

from scipy.spatial import Voronoi
from scipy.spatial.distance import pdist, squareform
import applied_active_learning_191228a as al

# Regular N x N evaluation grid on [-2, 2]^2, flattened into an (N*N, 2) array X.
N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()
X = np.hstack((x[:,None].detach().numpy(),y[:,None].detach().numpy()))

# NOTE: dill.load executes arbitrary code from the file; only load files you created yourself.
with open(r"2D_2a_and_2b_points_240718a.dill", "rb") as input_file:
    [Xp, kp_st_2d1,kp_fp_2d1, kp_st_2d2, kp_fp_2d2, xs_2a, ys_2a, xf_2a, yf_2a, xs_2b, ys_2b, xf_2b, yf_2b] = dill.load(input_file)
# 2a -----------------------------------------------------
# Ground truth used for scoring: region labels Lv and functional property fv.
with open(r"2D_2a_and_2b_fv_231030a.dill", "rb") as input_file:
    Lv, fv = dill.load(input_file)

# joint
# NOTE(review): the file name says N41 while the grid above uses N = 40 -- confirm they match.
with open(r"2D_2an_matern52_N41_10ksamples_2init_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
# Overwrites the Xp loaded above with the one stored alongside the joint run.
Xp, Lp, f, xsi, ysi, xfi, yfi = output['starting_data']

print(type(xfi), type(yfi), type(X))

# Average the per-draw region probabilities (NaN draws ignored), then take the most probable region.
sage_pm_mean_joint = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
sage_pm_est_joint = np.argmax(sage_pm_mean_joint,axis=1)

# Posterior-mean functional-property prediction of the joint model.
sage_fp_est_joint = np.nanmean(preds_sage['f_piecewise'], axis=0)

# just structure
with open(r"2D_2an_structure_matern52_N40_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']

sage_pm_mean_st = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
sage_pm_est_st = np.argmax(sage_pm_mean_st,axis=1)


# just FP - SAGE
with open(r"2D_2an_fp_matern52_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)
preds_sage = output['preds']

sage_pm_mean_fp = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
# 1 - argmax swaps labels 0 and 1 (presumably because the FP-only model sees no
# structure labels, so its region numbering is arbitrary -- TODO confirm).
sage_pm_est_fp = 1-np.argmax(sage_pm_mean_fp,axis=1)

sage_fp_est_fp = np.nanmean(preds_sage['f_piecewise'], axis=0)

# just fp - GPR
# Baseline 1: a single GP regression on the functional-property data only.
k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])# White-noise term disabled: + gpflow.kernels.White(variance=0.001)
data = (tf.convert_to_tensor(xfi), tf.convert_to_tensor(yfi.flatten()[:,None]))
m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(yfi.mean())) # GPR with a constant mean initialised at the data mean

# Replace the likelihood variance by a parameter constrained to [0.001, 0.01]
# through a sigmoid transform, starting from 0.005.
m.likelihood.variance.assign(0.005)
p = m.likelihood.variance
m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )    

opt = gpflow.optimizers.Scipy() # set up hyperparameter optimization
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))  # run optimization
gpr_est_fp, temp_var = m.predict_f(tf.convert_to_tensor(X)) # predictive mean and variance over the full grid X

# just PM
# Baseline 2: a GP classifier on the structure labels only.
C = 2
data = (tf.convert_to_tensor(xsi), tf.convert_to_tensor(ysi)) # coordinates of the structure-labelled samples and their labels
kernel = gpflow.kernels.Matern52() # Matern-5/2 kernel only; White-noise term disabled: + gpflow.kernels.White(variance=0.01)
# Robustmax Multiclass Likelihood
invlink = gpflow.likelihoods.RobustMax(C)  # Robustmax inverse link function
likelihood = gpflow.likelihoods.MultiClass(C, invlink=invlink)  # Multiclass likelihood
m = gpflow.models.VGP(data=data, kernel=kernel, likelihood=likelihood, num_latent_gps=C) # set up the GP model

opt = gpflow.optimizers.Scipy() # set up the hyperparameter optimization
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=ci_niter(1000)) ) # run the optimization
# NOTE: this rebinds `y`, which previously held the flattened grid y-coordinates.
y = m.predict_y(tf.convert_to_tensor(X)) # predictive mean and variance of the class indicators over the full grid X
y_mean = y[0].numpy() # mean of y
y_var = y[1].numpy() # variance of y.
gpc_est_pm = np.argmax(y_mean,axis=1) # most probable class per grid point

# Challenge 2a scores. Naming: <metric>_2a_<method>, where the method is
# sage_joint (joint model), sage_st (structure only), sage_fp (functional
# property only), gpr_fp (GP regression baseline) or gpc (GP classifier baseline).

# R^2 of the functional-property predictions against the first column of fv.
r2_2a_sage_joint = sklearn.metrics.r2_score(fv[:,0],sage_fp_est_joint)
r2_2a_gpr_fp = sklearn.metrics.r2_score(fv[:,0],gpr_est_fp)
r2_2a_sage_fp = sklearn.metrics.r2_score(fv[:,0],sage_fp_est_fp)

# Accuracy of the region labels against Lv.
acc_2a_sage_joint = sklearn.metrics.accuracy_score(Lv, sage_pm_est_joint)
acc_2a_sage_st = sklearn.metrics.accuracy_score(Lv, sage_pm_est_st)
acc_2a_gpc = sklearn.metrics.accuracy_score(Lv, gpc_est_pm)
acc_2a_sage_fp = sklearn.metrics.accuracy_score(Lv, sage_pm_est_fp)

# Fowlkes-Mallows index of the region labels.
fmi_2a_sage_joint = sklearn.metrics.fowlkes_mallows_score(Lv, sage_pm_est_joint)
fmi_2a_sage_st = sklearn.metrics.fowlkes_mallows_score(Lv, sage_pm_est_st)
fmi_2a_gpc = sklearn.metrics.fowlkes_mallows_score(Lv, gpc_est_pm)
fmi_2a_sage_fp = sklearn.metrics.fowlkes_mallows_score(Lv, sage_pm_est_fp)

# Micro-averaged F1 score of the region labels.
f1s_2a_sage_joint = sklearn.metrics.f1_score(Lv, sage_pm_est_joint, average='micro')
f1s_2a_sage_st = sklearn.metrics.f1_score(Lv, sage_pm_est_st, average='micro')
f1s_2a_gpc = sklearn.metrics.f1_score(Lv, gpc_est_pm, average='micro')
f1s_2a_sage_fp = sklearn.metrics.f1_score(Lv, sage_pm_est_fp, average='micro')

# Parity plots (truth vs. prediction): joint SAGE on the left, GPR baseline on the right.
print(fv.shape, sage_fp_est_joint.shape, gpr_est_fp.shape)
plt.figure(figsize = (6.5,2))
plt.subplot(1,2,1)
plt.plot(fv[:,0],sage_fp_est_joint,'k.')
plt.title(r2_2a_sage_joint)

plt.subplot(1,2,2)
plt.plot(fv[:,0],gpr_est_fp,'k.')
plt.title('fp' + str(r2_2a_gpr_fp))

plt.show()


print('2a R2, SAGE:',r2_2a_sage_joint, ' SAGE-FP:', r2_2a_sage_fp, ' GPR:',r2_2a_gpr_fp)
print('2a Acc, SAGE:',acc_2a_sage_joint, 'SAGE-PM:',acc_2a_sage_st, 'SAGE-FP:', acc_2a_sage_fp, ' GPC:', acc_2a_gpc)
print('2a FMI, SAGE:',fmi_2a_sage_joint, 'SAGE-PM:',fmi_2a_sage_st, 'SAGE-FP:', fmi_2a_sage_fp, ' GPC:', fmi_2a_gpc)
print('2a F1s, SAGE:',f1s_2a_sage_joint, 'SAGE-PM:',f1s_2a_sage_st, 'SAGE-FP:', f1s_2a_sage_fp, ' GPC:', f1s_2a_gpc)

# # 2b -----------------------------------------------------
# Same analysis as for 2a, using the result files of challenge 2b. The variables
# of the 2a section (output, preds_sage, sage_*_est_*, ...) are overwritten.
# joint
with open(r"2D_2bn_matern52_N40_pred_init_230906a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
Xp, Lp, f, xsi, ysi, xfi, yfi = output['starting_data']

print(type(xfi), type(yfi), type(X))

# Average the per-draw region probabilities (NaN draws ignored), then take the most probable region.
sage_pm_mean_joint = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
sage_pm_est_joint = np.argmax(sage_pm_mean_joint,axis=1)

# Posterior-mean functional-property prediction of the joint model.
sage_fp_est_joint = np.nanmean(preds_sage['f_piecewise'], axis=0)

# just structure
with open(r"2D_2bn_structure_matern52_N40_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']

sage_pm_mean_st = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
sage_pm_est_st = np.argmax(sage_pm_mean_st,axis=1)


# just FP - SAGE
with open(r"2D_2bn_fp_matern52_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)
preds_sage = output['preds']

sage_pm_mean_fp = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
# 1 - argmax swaps labels 0 and 1 (presumably because the FP-only model sees no
# structure labels, so its region numbering is arbitrary -- TODO confirm).
sage_pm_est_fp = 1-np.argmax(sage_pm_mean_fp,axis=1)

sage_fp_est_fp = np.nanmean(preds_sage['f_piecewise'], axis=0)

# just fp - GPR
# Baseline 1: a single GP regression on the functional-property data only.
# Unlike the 2a section, inputs are cast with f64 (gpflow's default float) here.
k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])# White-noise term disabled: + gpflow.kernels.White(variance=0.001)
data = (tf.convert_to_tensor(f64(xfi)), tf.convert_to_tensor(f64(yfi.flatten()[:,None])))
m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(yfi.mean())) # GPR with a constant mean initialised at the data mean

# Replace the likelihood variance by a parameter constrained to [0.001, 0.01]
# through a sigmoid transform, starting from 0.005.
m.likelihood.variance.assign(0.005)
p = m.likelihood.variance
m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )    

opt = gpflow.optimizers.Scipy() # set up hyperparameter optimization
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))  # run optimization
gpr_est_fp, temp_var = m.predict_f(tf.convert_to_tensor(X)) # predictive mean and variance over the full grid X

# just PM
# Baseline 2: a GP classifier on the structure labels only.
C = 2
data = (f64(tf.convert_to_tensor(xsi)), f64(tf.convert_to_tensor(ysi))) # coordinates of the structure-labelled samples and their labels, both cast to the default float
kernel = gpflow.kernels.Matern52() # Matern-5/2 kernel only; White-noise term disabled: + gpflow.kernels.White(variance=0.01)
# Robustmax Multiclass Likelihood
invlink = gpflow.likelihoods.RobustMax(C)  # Robustmax inverse link function
likelihood = gpflow.likelihoods.MultiClass(C, invlink=invlink)  # Multiclass likelihood
m = gpflow.models.VGP(data=data, kernel=kernel, likelihood=likelihood, num_latent_gps=C) # set up the GP model

opt = gpflow.optimizers.Scipy() # set up the hyperparameter optimization
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=ci_niter(1000)) ) # run the optimization
y = m.predict_y(tf.convert_to_tensor(f64(X))) # predictive mean and variance of the class indicators over the full grid X
y_mean = y[0].numpy() # mean of y
y_var = y[1].numpy() # variance of y.
gpc_est_pm = np.argmax(y_mean,axis=1) # most probable class per grid point

# Challenge 2b scores. Naming: <metric>_2b_<method>, where the method is
# sage_joint (joint model), sage_st (structure only), sage_fp (functional
# property only), gpr_fp (GP regression baseline) or gpc (GP classifier baseline).
# The same ground truth (Lv, fv) as in the 2a section is used.

# R^2 of the functional-property predictions against the first column of fv.
r2_2b_sage_joint = sklearn.metrics.r2_score(fv[:,0],sage_fp_est_joint)
r2_2b_gpr_fp = sklearn.metrics.r2_score(fv[:,0],gpr_est_fp)
r2_2b_sage_fp = sklearn.metrics.r2_score(fv[:,0],sage_fp_est_fp)

# Accuracy of the region labels against Lv.
acc_2b_sage_joint = sklearn.metrics.accuracy_score(Lv, sage_pm_est_joint)
acc_2b_sage_st = sklearn.metrics.accuracy_score(Lv, sage_pm_est_st)
acc_2b_gpc = sklearn.metrics.accuracy_score(Lv, gpc_est_pm)
acc_2b_sage_fp = sklearn.metrics.accuracy_score(Lv, sage_pm_est_fp)

# Fowlkes-Mallows index of the region labels.
fmi_2b_sage_joint = sklearn.metrics.fowlkes_mallows_score(Lv, sage_pm_est_joint)
fmi_2b_sage_st = sklearn.metrics.fowlkes_mallows_score(Lv, sage_pm_est_st)
fmi_2b_gpc = sklearn.metrics.fowlkes_mallows_score(Lv, gpc_est_pm)
fmi_2b_sage_fp = sklearn.metrics.fowlkes_mallows_score(Lv, sage_pm_est_fp)

# Micro-averaged F1 score of the region labels.
f1s_2b_sage_joint = sklearn.metrics.f1_score(Lv, sage_pm_est_joint, average='micro')
f1s_2b_sage_st = sklearn.metrics.f1_score(Lv, sage_pm_est_st, average='micro')
f1s_2b_gpc = sklearn.metrics.f1_score(Lv, gpc_est_pm, average='micro')
f1s_2b_sage_fp = sklearn.metrics.f1_score(Lv, sage_pm_est_fp, average='micro')

# Parity plots (truth vs. prediction): joint SAGE on the left, GPR baseline on the right.
print(fv.shape, sage_fp_est_joint.shape, gpr_est_fp.shape)
plt.figure(figsize = (6.5,2))
plt.subplot(1,2,1)
plt.plot(fv[:,0],sage_fp_est_joint,'k.')
plt.title(r2_2b_sage_joint)

plt.subplot(1,2,2)
plt.plot(fv[:,0],gpr_est_fp,'k.')
plt.title('fp' + str(r2_2b_gpr_fp))
plt.show()


print('2b R2, SAGE:',r2_2b_sage_joint, ' SAGE-FP:', r2_2b_sage_fp, ' GPR:',r2_2b_gpr_fp)
print('2b Acc, SAGE:',acc_2b_sage_joint, 'SAGE-PM:',acc_2b_sage_st, 'SAGE-FP:', acc_2b_sage_fp, ' GPC:', acc_2b_gpc)
print('2b FMI, SAGE:',fmi_2b_sage_joint, 'SAGE-PM:',fmi_2b_sage_st, 'SAGE-FP:', fmi_2b_sage_fp, ' GPC:', fmi_2b_gpc)
print('2b F1s, SAGE:',f1s_2b_sage_joint, 'SAGE-PM:',f1s_2b_sage_st, 'SAGE-FP:', f1s_2b_sage_fp, ' GPC:', f1s_2b_gpc)

In [None]:
# Added: CAMEO

import torch
import dill
import GPy
import sklearn
from sklearn.metrics import f1_score, r2_score
import scipy
import tensorflow as tf
import gpflow
f64 = gpflow.utilities.to_default_float
from gpflow.ci_utils import ci_niter
import tensorflow_probability as tfp
from jax.nn import one_hot as jax_one_hot

from scipy.spatial import Voronoi
from scipy.spatial.distance import pdist, squareform
import applied_active_learning_191228a as al

from cameo_240821a import *

# Regular N x N evaluation grid on [-2, 2]^2, flattened into an (N*N, 2) array X.
# NOTE(review): np and plt are not imported in this cell; they must come from an
# earlier cell or from the star import of cameo_240821a.
N = 40
x_,y_ = torch.meshgrid(torch.linspace(-2,2,N),torch.linspace(-2,2,N),indexing='xy')
x = x_.flatten()
y = y_.flatten()
X = np.hstack((x[:,None].detach().numpy(),y[:,None].detach().numpy()))

# NOTE: dill.load executes arbitrary code from the file; only load files you created yourself.
with open(r"2D_2a_and_2b_points_240718a.dill", "rb") as input_file:
    [Xp, kp_st_2d1,kp_fp_2d1, kp_st_2d2, kp_fp_2d2, xs_2a, ys_2a, xf_2a, yf_2a, xs_2b, ys_2b, xf_2b, yf_2b] = dill.load(input_file)
# 2a -----------------------------------------------------
# Ground truth used for scoring: region labels Lv and functional property fv.
with open(r"2D_2a_and_2b_fv_231030a.dill", "rb") as input_file:
    Lv, fv = dill.load(input_file)

# joint
with open(r"2D_2an_matern52_N41_10ksamples_2init_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
Xp, Lp, f, xsi, ysi, xfi, yfi = output['starting_data']


# Translate the 2a sample indices into row indices of the grid X
# (presumably indices into Xp mapped to the nearest grid point -- TODO confirm
# against nearest_index in cameo_240821a).
kp_st_X = nearest_index(Xp, X, kp_st_2d1)
kp_fp_X = nearest_index(Xp, X, kp_fp_2d1)
# Training data read off the grid and the ground truth at those indices.
xf = X[kp_fp_X,:]
xs = X[kp_st_X,:]
ys = Lv[kp_st_X].numpy()
yf = fv[kp_fp_X,0]

# One-hot encoding of the two structure labels.
Ux = np.asarray(jax_one_hot(ys,2))
S = form_graph(X)
plt.figure()
# cl_full: one label per grid point returned by GRF_applied
# (presumably graph-based label propagation -- TODO confirm).
cl_full, _ = GRF_applied(kp_st_X, Ux, S)
cl_full= cl_full.flatten()
# Labels at the functional-property sample locations.
cl_fp = cl_full.flatten()[kp_fp_X]
# Output buffer for the per-region GPR predictions over the whole grid.
cameo_gpr_2a = np.zeros(X.shape[0])

# Fit one independent GPR per region label i, trained on the functional-property
# samples assigned to that region and evaluated on the grid points of that region.
# NOTE(review): a region without any functional-property sample yields an empty
# training set here.
for i in range(2):
    k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])# White-noise term disabled: + gpflow.kernels.White(variance=0.001)
    data = (f64(tf.convert_to_tensor(xf[cl_fp==i,:])), f64(tf.convert_to_tensor(yf[cl_fp==i].flatten()[:,None])))
    m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(yf[cl_fp==i].flatten().mean())) # constant mean initialised at the region's data mean
    
    # Likelihood variance constrained to [0.001, 0.01] through a sigmoid transform, starting from 0.005.
    m.likelihood.variance.assign(0.005)
    p = m.likelihood.variance
    m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )    
    
    opt = gpflow.optimizers.Scipy() # set up hyperparameter optimization
    opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))  # run optimization
    temp, _ = m.predict_f(tf.convert_to_tensor(f64(X[cl_full==i,:]))) # predictive mean for the grid points labelled i
    cameo_gpr_2a[cl_full==i] = temp.numpy().flatten()

# Figure 1: true functional property with the sampled points (left) vs. CAMEO prediction (right).
plt.figure()
plt.subplot(1,2,1)
plt.scatter(X[:,0],X[:,1],c=fv[:,0])
plt.plot(xf[:,0],xf[:,1],'ro')
plt.subplot(1,2,2)
plt.scatter(X[:,0],X[:,1],c=cameo_gpr_2a)
# Figure 2: true region labels with the labelled points (left) vs. CAMEO labels (right).
plt.figure()
plt.subplot(1,2,1)
plt.scatter(X[:,0],X[:,1],c=Lv)
plt.plot(xs[:,0],xs[:,1],'ro')
plt.subplot(1,2,2)
plt.scatter(X[:,0],X[:,1],c=cl_full)
r2_2a_cameo = r2_score(fv[:,0],cameo_gpr_2a)
# NOTE(review): despite the name this is an F1 score with sklearn's default
# averaging, not accuracy_score as used for the acc_2a_* values of the other methods.
acc_2a_cameo = f1_score(Lv, cl_full)
print( r2_2a_cameo, acc_2a_cameo)


# ---- 2b ---------------------------
# Same procedure as for 2a, using the 2b sample indices. xf, xs, ys, yf, Ux, S,
# cl_full and cl_fp from the 2a section are overwritten.
kp_st_X = nearest_index(Xp, X, kp_st_2d2)
kp_fp_X = nearest_index(Xp, X, kp_fp_2d2)
# Training data read off the grid and the ground truth at those indices.
xf = X[kp_fp_X,:]
xs = X[kp_st_X,:]
ys = Lv[kp_st_X].numpy()
yf = fv[kp_fp_X,0]

# One-hot encoding of the two structure labels.
Ux = np.asarray(jax_one_hot(ys,2))
S = form_graph(X)
plt.figure()
# cl_full: one label per grid point returned by GRF_applied.
cl_full, _ = GRF_applied(kp_st_X, Ux, S)
cl_full= cl_full.flatten()
# Labels at the functional-property sample locations.
cl_fp = cl_full.flatten()[kp_fp_X]
# Output buffer for the per-region GPR predictions over the whole grid.
cameo_gpr_2b = np.zeros(X.shape[0])

# Fit one independent GPR per region label i, trained on the functional-property
# samples assigned to that region and evaluated on the grid points of that region.
# NOTE(review): a region without any functional-property sample yields an empty
# training set here.
for i in range(2):
    k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])# White-noise term disabled: + gpflow.kernels.White(variance=0.001)
    data = (tf.convert_to_tensor(f64(xf[cl_fp==i,:])), tf.convert_to_tensor(f64(yf[cl_fp==i].flatten()[:,None])))
    m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(f64(yf[cl_fp==i].flatten().mean()))) # constant mean initialised at the region's data mean
    
    # Likelihood variance constrained to [0.001, 0.01] through a sigmoid transform, starting from 0.005.
    m.likelihood.variance.assign(0.005)
    p = m.likelihood.variance
    m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )    
    
    opt = gpflow.optimizers.Scipy() # set up hyperparameter optimization
    opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))  # run optimization
    temp, _ = m.predict_f(tf.convert_to_tensor(f64(X[cl_full==i,:]))) # predictive mean for the grid points labelled i
    cameo_gpr_2b[cl_full==i] = temp.numpy().flatten()

# Figure 1: true functional property with the sampled points (left) vs. CAMEO prediction (right).
plt.figure()
plt.subplot(1,2,1)
plt.scatter(X[:,0],X[:,1],c=fv[:,0])
plt.plot(xf[:,0],xf[:,1],'ro')
plt.subplot(1,2,2)
plt.scatter(X[:,0],X[:,1],c=cameo_gpr_2b)
# Figure 2: true region labels with the labelled points (left) vs. CAMEO labels (right).
plt.figure()
plt.subplot(1,2,1)
plt.scatter(X[:,0],X[:,1],c=Lv)
plt.plot(xs[:,0],xs[:,1],'ro')
plt.subplot(1,2,2)
plt.scatter(X[:,0],X[:,1],c=cl_full)

r2_2b_cameo = r2_score(fv[:,0],cameo_gpr_2b)
# NOTE(review): despite the name this is an F1 score with sklearn's default
# averaging, not accuracy_score as used for the acc_2b_* values of the other methods.
acc_2b_cameo = f1_score(Lv, cl_full)
print( r2_2b_cameo, acc_2b_cameo)

##### BSF

In [None]:
import dill
import GPy
import sklearn
import scipy
import gpflow
import tensorflow as tf
import scipy.io as sio
f64 = gpflow.utilities.to_default_float
from gpflow.ci_utils import ci_niter

# BSF --------------
# joint
# NOTE: dill.load executes arbitrary code from the file; only load files you created yourself.
with open(r"2D_BSF_1core_231031a.dill", "rb") as input_file:
    output = dill.load(input_file)

# preds_sage = output['preds']
# preds_st = output['preds_st']
# preds_fp = output['preds_fp']

# starting_data = output['starting_data']
# NOTE(review): starting_data is not assigned in this cell (the load above is
# commented out), so this line relies on a variable left behind by an earlier
# cell -- make sure it is the BSF data.
Xp, Lv, fv, xsi, ysi, xfi, yfi = starting_data

# sage_pm_mean_joint = output[] # np.nanmean(preds_sage['gpc_new_probs'], axis=0)
# The joint-model file already stores summary estimates, so they are read directly.
sage_pm_est_joint = output['phase_region_labels_mean_estimate'] #np.argmax(sage_pm_mean_joint,axis=1)
sage_fp_est_joint = output['functional_property_mean'].flatten() # np.nanmean(preds_sage['f_piecewise'], axis=0)


# just fp - GPR
# Baseline 1: a single GP regression on the functional-property data only.
# NOTE(review): tfp and np are not imported in this cell; they must come from an earlier cell.
k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])# White-noise term disabled: + gpflow.kernels.White(variance=0.001)
data = (tf.convert_to_tensor(xfi), tf.convert_to_tensor(yfi.flatten()[:,None]))
m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(yfi.mean())) # GPR with a constant mean initialised at the data mean

# Replace the likelihood variance by a parameter constrained to [0.001, 0.01]
# through a sigmoid transform, starting from 0.005.
m.likelihood.variance.assign(0.005)
p = m.likelihood.variance
m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )    

opt = gpflow.optimizers.Scipy() # set up hyperparameter optimization
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))  # run optimization
gpr_est_fp, temp_var = m.predict_f(tf.convert_to_tensor(Xp)) # predictive mean and variance at every point of Xp

# just PM
# Baseline 2: a GP classifier on the structure labels only, with three classes.
C = 3
data = (tf.convert_to_tensor(xsi), tf.convert_to_tensor(ysi)) # coordinates of the structure-labelled samples and their labels
kernel = gpflow.kernels.Matern52() # Matern-5/2 kernel only; White-noise term disabled: + gpflow.kernels.White(variance=0.01)
# Robustmax Multiclass Likelihood
invlink = gpflow.likelihoods.RobustMax(C)  # Robustmax inverse link function
likelihood = gpflow.likelihoods.MultiClass(C, invlink=invlink)  # Multiclass likelihood
m = gpflow.models.VGP(data=data, kernel=kernel, likelihood=likelihood, num_latent_gps=C) # set up the GP model

opt = gpflow.optimizers.Scipy() # set up the hyperparameter optimization
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=ci_niter(1000)) ) # run the optimization
y = m.predict_y(tf.convert_to_tensor(Xp)) # predictive mean and variance of the class indicators at every point of Xp
y_mean = y[0].numpy() # mean of y
y_var = y[1].numpy() # variance of y.
gpc_est_pm = np.argmax(y_mean,axis=1) # most probable class per point


# just FP - SAGE
with open(r"BSF_fp_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)
preds_sage = output['preds']
print('1',preds_sage['gpc_new_probs'].shape)
# The stored probabilities are reshaped to (draws, points, regions).
# The sizes 1000 and 156 are hard-coded for this particular run.
preds_sage['gpc_new_probs'] = preds_sage['gpc_new_probs'].reshape((1000,156,-1))
sage_pm_mean_fp = np.nanmean(preds_sage['gpc_new_probs'], axis=0)

# NOTE(review): 1 - argmax only swaps labels 0 and 1; if the argmax can return 2
# (C = 3 is used for this data set above) the result contains -1 -- confirm.
sage_pm_est_fp = 1-np.argmax(sage_pm_mean_fp,axis=1)

sage_fp_est_fp = np.nanmean(preds_sage['f_piecewise'], axis=0)


# just structure
with open(r"BSF_structure_matern52_231011a.dill", "rb") as input_file:
    output = dill.load(input_file)

preds_sage = output['preds']
# Same hard-coded (draws, points, regions) reshape as above.
preds_sage['gpc_new_probs'] = preds_sage['gpc_new_probs'].reshape((1000,156,-1))
sage_pm_mean_st = np.nanmean(preds_sage['gpc_new_probs'], axis=0)
sage_pm_est_st = np.argmax(sage_pm_mean_st,axis=1)

# CAMEO
# Indices of the structure-labelled and functional-property samples for BSF.
# NOTE(review): jax_one_hot, form_graph and GRF_applied are not imported in this
# cell; they must come from an earlier cell.
with open(r"BSF_st_and_fp_samples_231030a.dill", "rb") as input_file:
    kp_st, kp_fp = dill.load(input_file)
    
# One-hot encoding of the three structure labels.
Ux = np.asarray(jax_one_hot(ysi,3))
S = form_graph(Xp)
# cl_full: one label per point of Xp returned by GRF_applied.
cl_full, _ = GRF_applied(kp_st, Ux, S)
cl_full = cl_full.flatten()
# Labels at the functional-property sample locations.
cl_fp = cl_full.flatten()[kp_fp]
# NOTE: the *_2a names are reused from the 2D challenge cells; here they hold BSF results.
cameo_gpr_2a = np.zeros(Xp.shape[0])

# Fit one independent GPR per region label i and predict the points of that region.
for i in range(3):
    k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])# White-noise term disabled: + gpflow.kernels.White(variance=0.001)
    data = (tf.convert_to_tensor(xfi[cl_fp==i,:]), tf.convert_to_tensor(yfi[cl_fp==i].flatten()[:,None]))
    # NOTE(review): the constant mean is initialised at the mean of all yfi, whereas
    # the 2D CAMEO cells use the mean of the region's own samples -- confirm.
    m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(yfi.mean())) # set up GPR model
    
    # Likelihood variance constrained to [0.001, 0.01] through a sigmoid transform, starting from 0.005.
    m.likelihood.variance.assign(0.005)
    p = m.likelihood.variance
    m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )    
    
    opt = gpflow.optimizers.Scipy() # set up hyperparameter optimization
    opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))  # run optimization
    temp, _ = m.predict_f(tf.convert_to_tensor(Xp[cl_full==i,:])) # predictive mean for the points labelled i
    cameo_gpr_2a[cl_full==i] = temp.numpy().flatten()


# BSF scores. The *_2a variable names are reused from the 2D challenge cells.
# NOTE(review): r2_score, f1_score and plt are not imported in this cell; they
# must come from an earlier cell.

# R^2 of the functional-property predictions against the first column of fv.
r2_2a_sage_joint = r2_score(fv[:,0],sage_fp_est_joint)
r2_2a_gpr_fp = r2_score(fv[:,0],gpr_est_fp)
r2_2a_sage_fp = r2_score(fv[:,0],sage_fp_est_fp)
r2_2a_cameo = r2_score(fv[:,0],cameo_gpr_2a)

# The "acc" values are micro-averaged F1 scores of the region labels.
acc_2a_sage_joint = f1_score(Lv, sage_pm_est_joint, average='micro')
acc_2a_sage_st = f1_score(Lv, sage_pm_est_st, average='micro')
acc_2a_gpc = f1_score(Lv, gpc_est_pm, average='micro')
acc_2a_sage_fp = f1_score(Lv, sage_pm_est_fp, average='micro')
acc_2a_cameo = f1_score(Lv, cl_full, average='micro')

# Rescale the coordinates after all models have been evaluated
# (presumably back to physical units for plotting -- TODO confirm).
Xp = Xp*10.+20.
xsi = xsi*10.+20.
xfi = xfi*10.+20.

# Parity plots (truth vs. prediction): joint SAGE on the left, GPR baseline on the right.
print(fv.shape, sage_fp_est_joint.shape, gpr_est_fp.shape)
plt.figure(figsize = (6.5,2))
plt.subplot(1,2,1)
plt.plot(fv[:,0],sage_fp_est_joint,'k.')
plt.title(r2_2a_sage_joint)

plt.subplot(1,2,2)
plt.plot(fv[:,0],gpr_est_fp,'k.')
plt.title('fp' + str(r2_2a_gpr_fp))

plt.show()

print('BSF R2, SAGE:',r2_2a_sage_joint, ' SAGE-FP:', r2_2a_sage_fp, ' GPR:',r2_2a_gpr_fp, 'CAMEO:',r2_2a_cameo)
print('BSF Acc, SAGE:',acc_2a_sage_joint, 'SAGE-PM:',acc_2a_sage_st, 'SAGE-FP:', acc_2a_sage_fp, ' GPC:', acc_2a_gpc, 'CAMEO:',acc_2a_cameo)


##### FeGaPd

In [None]:
import dill
import sklearn
# `import sklearn` alone is not guaranteed to load the `metrics` subpackage on
# every scikit-learn version; import it explicitly because the scores at the
# end of this cell are called as `sklearn.metrics.<name>`.
import sklearn.metrics
import scipy
import gpflow
import tensorflow as tf
# `tfp.bijectors.Sigmoid` is used further down in this cell to bound the GPR
# noise variance; import it here so the cell does not depend on an earlier one.
import tensorflow_probability as tfp
f64 = gpflow.utilities.to_default_float
from gpflow.ci_utils import ci_niter

# Load the FeGaPd data set from a MATLAB file.
# NOTE(review): absolute, machine-specific path -- point it at your local copy.
FGP = sio.loadmat(r'G:\My Drive\Data\FeGaPd\FeGaPd_full_data_200817a.mat')
C = FGP['C']  # NOTE: `C` is rebound to the integer class count later in this cell
Mag = FGP['Mag_modified']/10.  # functional property, scaled by 1/10
X = FGP['X']  # not used in the visible part of this cell
Xp = FGP['XY']  # coordinates used as inputs to every model below
L = FGP['labels_col'][0][1].astype(int)
L = L - 1  # shift labels down by one (presumably 1-based -> 0-based; confirm)

# Indices of the samples whose phase labels are used for training.
# NOTE(review): `edge` lists 189, 235 and 268 twice, so those samples appear
# twice in the structure training set -- confirm this is intentional.
edge = np.asarray([266,267,238,213,189,189,158,159,156,144,153,147,
                   268,235,216,183,165,89,52,53,40,16,166,119,88,48,15,236,237,
                   269,235,268,234,180,181,182,168,178,274,131,130,177,275])

kp_st = np.concatenate((edge,[61,200,256,92,93,185,186,215,214]))

N = 40  # not used in the visible part of this cell
# Indices of the samples whose functional property is used for training.
kp_fp = [ 0,8,13,  14, 19, 20,  23,  27,  32,  35,36,42,  45,  60,  71,  72, 80,  91,  99, 105, 108, 124, 126, 132,
 137,138,139,142, 145, 152, 155, 157, 162, 163, 167, 171, 219, 221, 224, 232, 239, 241, 244, 254, 265, 273 ]


# Training subsets (structure: xsi/ysi, functional property: xfi/yfi) and the
# full-data references (fv, Lv) that the estimates are scored against.
xsi = Xp[kp_st,:]
ysi = L[kp_st].flatten()
xfi = Xp[kp_fp,:]
yfi = Mag[kp_fp]
fv = Mag.copy()
Lv = L.copy()

# Joint SAGE result (structure + functional property), precomputed and stored
# with dill.
# NOTE: dill.load can execute arbitrary code; only open files you trust.
with open(r"2D_FGP_1core_240718a.dill", "rb") as joint_file:
    output = dill.load(joint_file)

# The file already stores the mean phase-label estimate and the mean
# functional-property estimate, so no averaging over samples is done here.
sage_fp_est_joint = output['functional_property_mean'].flatten()
sage_pm_est_joint = output['phase_region_labels_mean_estimate']


# just fp - GPR
# Baseline: a single GP regression fitted to the functional-property training
# points only (no phase information), predicted at every row of Xp.
k = gpflow.kernels.SquaredExponential(lengthscales = [1., 1.])  # one length scale per input column
data = (tf.convert_to_tensor(xfi), tf.convert_to_tensor(yfi.flatten()[:,None]))
m = gpflow.models.GPR(data=data, kernel=k, mean_function=gpflow.mean_functions.Constant(yfi.mean()))

# Start the noise variance at 0.005, then re-wrap it in a Sigmoid-transformed
# Parameter so the optimiser keeps it inside (0.001, 0.01).
m.likelihood.variance.assign(0.005)
p = m.likelihood.variance
m.likelihood.variance = gpflow.Parameter(p, transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)) )

opt = gpflow.optimizers.Scipy()  # hyperparameter optimiser
opt_logs = opt.minimize(m.training_loss, m.trainable_variables, method = 'tnc', options=dict(maxiter=100))
gpr_mean, temp_var = m.predict_f(tf.convert_to_tensor(Xp))  # posterior mean / variance at all samples
# predict_f returns (N, 1) TensorFlow tensors. Store the mean as a flat NumPy
# array, like every other estimate in this cell, so that arithmetic against
# fv[:, 0] cannot silently broadcast to (N, N).
gpr_est_fp = gpr_mean.numpy().flatten()

# just PM
# Baseline: GP classification fitted to the structure (phase-label) training
# points only, predicted at every row of Xp.
C = 5  # number of classes and of latent GPs (rebinds `C`, which earlier held FGP['C'])

# Training pair: coordinates of the labelled samples and their labels.
gpc_train = (tf.convert_to_tensor(xsi), tf.convert_to_tensor(ysi))
gpc_kernel = gpflow.kernels.Matern52()
# Multiclass likelihood with a RobustMax inverse link.
gpc_likelihood = gpflow.likelihoods.MultiClass(C, invlink=gpflow.likelihoods.RobustMax(C))
gpc_model = gpflow.models.VGP(
    data=gpc_train,
    kernel=gpc_kernel,
    likelihood=gpc_likelihood,
    num_latent_gps=C,
)

# Hyperparameter optimisation (truncated Newton, at most ci_niter(1000) iterations).
gpflow.optimizers.Scipy().minimize(
    gpc_model.training_loss,
    gpc_model.trainable_variables,
    method='tnc',
    options=dict(maxiter=ci_niter(1000)),
)

# Predictive mean and variance of the class outputs at every sample.
y = gpc_model.predict_y(tf.convert_to_tensor(Xp))
y_mean, y_var = y[0].numpy(), y[1].numpy()
# Hard label = class with the largest predictive mean.
gpc_est_pm = y_mean.argmax(axis=1)


# just FP - SAGE
# SAGE run that used only the functional-property data; results precomputed.
# NOTE: dill.load can execute arbitrary code; only open files you trust.
with open(r"FGP_fp_231011a.dill", "rb") as fp_only_file:
    output = dill.load(fp_only_file)
preds_sage = output['preds']

# Reshape the class probabilities to (1000, -1, 5) in place in the dict
# (presumably samples x points x classes -- TODO confirm), then average over
# the leading axis while ignoring NaNs.
fp_only_probs = preds_sage['gpc_new_probs'].reshape((1000, -1, 5))
preds_sage['gpc_new_probs'] = fp_only_probs
sage_pm_mean_fp = np.nanmean(fp_only_probs, axis=0)
print(fp_only_probs.shape)

# NOTE(review): `1 - argmax` swaps labels 0 and 1 for a two-class problem, but
# with 5 classes it yields values in {1, 0, -1, -2, -3}. This looks carried
# over from a binary case; confirm the intended label alignment before relying
# on the SAGE-FP accuracy computed from it.
fp_only_argmax = np.argmax(sage_pm_mean_fp, axis=1)
sage_pm_est_fp = 1 - fp_only_argmax

# Functional-property estimate: NaN-ignoring mean over the leading axis.
sage_fp_est_fp = np.nanmean(preds_sage['f_piecewise'], axis=0)


# just structure
# SAGE run that used only the structure data; results precomputed.
# NOTE: dill.load can execute arbitrary code; only open files you trust.
with open(r"FGP_structure_matern52_231011a.dill", "rb") as structure_file:
    output = dill.load(structure_file)
preds_sage = output['preds']

# Reshape the class probabilities to (1000, -1, 5) in place in the dict
# (presumably samples x points x classes -- TODO confirm), average over the
# leading axis while ignoring NaNs, and take the most probable class.
structure_probs = preds_sage['gpc_new_probs'].reshape((1000, -1, 5))
preds_sage['gpc_new_probs'] = structure_probs
sage_pm_mean_st = np.nanmean(structure_probs, axis=0)
sage_pm_est_st = sage_pm_mean_st.argmax(axis=1)

# CAMEO
# Baseline: first assign a phase label to every sample from the labelled
# structure points, then fit one independent GPR per phase on the
# functional-property training points that fall in that phase.
def _fit_phase_gpr(x_train, y_train, prior_mean):
    """Fit a squared-exponential GPR with a bounded noise variance.

    x_train / y_train are the training inputs and targets for one phase;
    prior_mean is the value of the constant mean function. Returns the
    optimised gpflow GPR model.
    """
    kern = gpflow.kernels.SquaredExponential(lengthscales=[1., 1.])
    train = (tf.convert_to_tensor(x_train), tf.convert_to_tensor(y_train.flatten()[:, None]))
    model = gpflow.models.GPR(
        data=train,
        kernel=kern,
        mean_function=gpflow.mean_functions.Constant(prior_mean),
    )
    # Start the noise variance at 0.005, then re-wrap it so the optimiser
    # keeps it inside (0.001, 0.01).
    model.likelihood.variance.assign(0.005)
    model.likelihood.variance = gpflow.Parameter(
        model.likelihood.variance,
        transform=tfp.bijectors.Sigmoid(f64(0.001), f64(0.01)),
    )
    gpflow.optimizers.Scipy().minimize(
        model.training_loss,
        model.trainable_variables,
        method='tnc',
        options=dict(maxiter=100),
    )
    return model


# One-hot encode the structure labels and propagate them to all samples.
# NOTE(review): form_graph / GRF_applied are defined elsewhere in this file;
# presumably they build a similarity graph over Xp and run graph-based label
# propagation -- confirm against their definitions.
Ux = np.asarray(jax_one_hot(ysi, 5))
S = form_graph(Xp)
cl_full = GRF_applied(kp_st, Ux, S)[0].flatten()  # phase label for every sample
cl_fp = cl_full[kp_fp]  # phase label of each functional-property training point
cameo_gpr_fgp = np.zeros(Xp.shape[0])

# All per-phase models share the mean of the full training set as prior mean.
cameo_prior_mean = yfi.mean()
for phase in range(5):
    in_phase = cl_full == phase
    trains_phase = cl_fp == phase
    phase_model = _fit_phase_gpr(xfi[trains_phase, :], yfi[trains_phase], cameo_prior_mean)
    # Posterior mean at every sample assigned to this phase.
    phase_mean = phase_model.predict_f(tf.convert_to_tensor(Xp[in_phase, :]))[0]
    cameo_gpr_fgp[in_phase] = phase_mean.numpy().flatten()


# FeGaPd summary scores.
# NOTE(review): these reuse the `*_2a_*` variable names, so scores stored under
# those names by an earlier cell are overwritten here.
fgp_fp_truth = fv[:, 0]
fgp_r2 = sklearn.metrics.r2_score
fgp_f1 = sklearn.metrics.f1_score
fgp_micro = dict(average='micro')

# R^2 of each functional-property estimate against the first column of `fv`.
r2_2a_sage_joint = fgp_r2(fgp_fp_truth, sage_fp_est_joint)
r2_2a_gpr_fp = fgp_r2(fgp_fp_truth, gpr_est_fp)
r2_2a_sage_fp = fgp_r2(fgp_fp_truth, sage_fp_est_fp)
r2_2a_cameo = fgp_r2(fgp_fp_truth, cameo_gpr_fgp)

# Micro-averaged F1 of each phase-label estimate against `Lv`.
acc_2a_sage_joint = fgp_f1(Lv, sage_pm_est_joint, **fgp_micro)
acc_2a_sage_st = fgp_f1(Lv, sage_pm_est_st, **fgp_micro)
acc_2a_gpc = fgp_f1(Lv, gpc_est_pm, **fgp_micro)
acc_2a_sage_fp = fgp_f1(Lv, sage_pm_est_fp, **fgp_micro)
acc_2a_cameo = fgp_f1(Lv, cl_full, **fgp_micro)

# Shape check of the arrays that go into the parity plots.
print(fv.shape, sage_fp_est_joint.shape, gpr_est_fp.shape)

# Parity plots (truth on x, estimate on y): joint SAGE on the left, plain GPR
# on the right; each panel is titled with its R^2.
plt.figure(figsize = (6.5,2))
plt.subplot(1,2,1)
plt.plot(fv[:,0],sage_fp_est_joint,'k.')
plt.title(r2_2a_sage_joint)

plt.subplot(1,2,2)
plt.plot(fv[:,0],gpr_est_fp,'k.')
plt.title('fp' + str(r2_2a_gpr_fp))

plt.show()

# CAMEO baseline over the sample coordinates: propagated phase labels on the
# left, per-phase GPR estimate of the functional property on the right.
plt.figure()
plt.subplot(1,2,1)
plt.scatter(Xp[:,0],Xp[:,1],c=cl_full)
plt.subplot(1,2,2)
plt.scatter(Xp[:,0],Xp[:,1],c=cameo_gpr_fgp)
# Show this figure explicitly, like the first one, so it is also rendered when
# the code runs outside an inline notebook backend.
plt.show()

print('FGP R2, SAGE:',r2_2a_sage_joint, ' SAGE-FP:', r2_2a_sage_fp, ' GPR:',r2_2a_gpr_fp, 'CAMEO', r2_2a_cameo)
print('FGP Acc, SAGE:',acc_2a_sage_joint, 'SAGE-PM:',acc_2a_sage_st, 'SAGE-FP:', acc_2a_sage_fp, ' GPC:', acc_2a_gpc, 'CAMEO', acc_2a_cameo)
