In [1]:
import sys
sys.path.insert(1, '/home/richard/nfmc_jax/')
import nfmc_jax
import arviz as az
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import scipy
from sklearn.neighbors import KernelDensity
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal as mvn
import chaospy
import corner
import torch
import warnings
import pickle

import getdist
from getdist import plots, MCSamples

# Single seed for NumPy's global RNG and the JAX key; also read as a module-level
# default by get_icov (iseed) and forwarded by run_nfo as random_seed.
seed=1234
np.random.seed(seed)
# NOTE(review): `key` is reassigned with PRNGKey(0) in the next cell, so this
# seeded key is not the one used to draw jax_prior_init.
key = jax.random.PRNGKey(seed)

In [2]:
# Problem dimension; passed to get_icov and run_nfo in later cells.
n = 10

# NOTE(review): this replaces the `key` built from `seed` in the first cell with a fixed PRNGKey(0).
key = jax.random.PRNGKey(0)
# 10*n uniform draws between -1 and 1 in n dimensions. In the visible notebook this
# array is only referenced by a commented-out line inside run_nfo.
jax_prior_init = jax.random.uniform(key, shape=(10 * n, n), minval=-1, maxval=1)
print(jnp.shape(jax_prior_init))
print(jax_prior_init)

(100, 10)
[[-8.84091854e-01 -4.43298101e-01 -1.22544289e-01 -5.09226084e-01
   4.90239620e-01  2.34477520e-01  5.42016745e-01  6.91947699e-01
   9.19848204e-01  7.04192400e-01]
 [ 7.50247955e-01  7.75618553e-02 -1.93581581e-02 -8.29578400e-01
   1.59775496e-01 -3.33229303e-01 -1.17771864e-01  4.73334551e-01
  -6.47885323e-01 -3.71244907e-01]
 [ 1.63232803e-01  5.28733253e-01 -6.22561693e-01 -3.91986847e-01
   5.72552443e-01 -3.86230469e-01  7.46369123e-01 -4.86523390e-01
   3.35375786e-01  5.44512987e-01]
 [ 2.36198425e-01  2.10574150e-01 -9.22422409e-02 -1.90681696e-01
   4.45937634e-01 -7.79232264e-01 -7.89236546e-01  3.00645828e-04
   8.02695751e-01 -4.50159311e-01]
 [ 1.41799450e-02 -8.90784979e-01 -9.27755117e-01  3.31110477e-01
  -6.48730993e-01 -7.78982401e-01 -5.01650810e-01 -9.97568130e-01
   2.93876886e-01  5.40287733e-01]
 [-7.28118896e-01  6.98439837e-01  8.96300077e-01 -9.57493782e-01
   9.91392136e-01 -2.73618221e-01  8.73942852e-01  5.49952269e-01
  -1.89256191e-01  7.09

In [3]:
def run_nfo(log_like,log_prior,
    n, #dimension of the uniform box; 10*n Sobol points are drawn as the initial sample
    low=-1.,
    high=1.,
    knots=None,
    knots_trainable=5,
    bw=1.,
    rel_bw=1,
    layers=5,
    ktrunc=jnp.inf,
    t_ess=0.7,
    beta_max=1.0,
    min_delta_beta=0.05,
    rel_beta=0.5,
    frac_rel_beta_AF = 1,
    alpha_w = (0,0),
    alpha_uw = (0,0),
    latent_sigma=2.4/(2**2),  #FIXME put this in the code as default
    use_latent_beta2=False,
    use_pq_beta_IW1=False, 
    top_verbose=True,
    verbose=False,
    trainable_qw=True,
    sgd_steps=10,
    gamma=0,
    optimize_directions=False,
    snf_loss=2,
    logp_cut=0,
    a=1,
    b=1,
    c=1,
    d=1,
    Ntemp=None,
    eps_z=0.1,
    cull_lowp_tol=0.01, 
    max_cull_frac=0.5,
    ess_tol=1.0,
    local_thresh=jnp.inf,
    nfmc_draws=100, 
    nf_iter=100,
    edge_bins=0,
    local_step_factor=0.25,
    nfmc_local_exploration=True,
    nfmc_AF_samples=200,
    nfmc_full_qw_samples=True,
    agressive_cull=False,
    cull_mode='importance_weights',
    cull_iw_tol=1e-3,
    nfmc_frac_validate=0.1,
    frac_validate=0.1):
    """Draw a Sobol initial sample on a uniform box and run ``nfmc_jax.infer_nfomc``.

    Parameters
    ----------
    log_like, log_prior : callables forwarded unchanged to ``infer_nfomc``.
    n : dimension of the box. ``10*n`` Sobol points form the initial sample and
        the same count is passed as both ``n0`` and ``N``.
    low, high : scalar bounds applied to every dimension; also forwarded to the
        prior as ``log_prior_args=(jnp.array([low]), jnp.array([high]))``.
    a, b, c, d : multipliers giving ``N_AF=int(2*10*n*a)``, ``expl_top_AF=n*b``,
        ``expl_top_qw=n*c`` and ``expl_latent=n*d``.
    Ntemp : forwarded as ``N_temp``; defaults to ``int(12*n)`` when None.
    latent_sigma : replaced by 1 when ``n < 6``, otherwise multiplied by
        ``n**-0.5`` before being forwarded.
    bw : forwarded as both ``bw_factor_min`` and ``bw_factor_max`` (with
        ``bw_factor_num=1``) so a single bandwidth factor is used.
    knots, layers, ktrunc : forwarded as ``interp_nbin``, ``iteration`` and
        ``k_trunc`` respectively.

    Every other keyword argument is forwarded under its own name. The
    ``random_seed`` argument is taken from the module-level ``seed``.

    Returns
    -------
    Whatever ``nfmc_jax.infer_nfomc`` returns (called ``trace`` here).
    """
    
    ntemp = int(12*n) if(Ntemp is None) else Ntemp
    
    n_prior = 10*n
    n0=n_prior
    N=n_prior
    aN=int(2*N*a)
    bN=n*b
    cN=n*c
    dN=n*d
    if(n<6):
        latent_sigma=1
    else:
        latent_sigma*=((n)**(-1/2))
    dist1d = chaospy.Iid(chaospy.Uniform(lower=low,upper=high),n)
    bounds=np.array([low*np.ones(n),high*np.ones(n)])
    init_prior = np.atleast_2d(dist1d.sample(n_prior+1,rule='sobol')).T[1:] #drop first (0,0) sample because cheating
    #init_prior = jax_prior_init
    trace = nfmc_jax.infer_nfomc(log_like, log_prior, jnp.array(init_prior), 
                             log_prior_args=((jnp.array([low]), jnp.array([high]))), 
                             inference_mode="sampling", 
                              vmap=True, parallel_backend=None, 
                               n0=n0,
                               N=N,
                               t_ess=t_ess,
                               N_AF=aN,
                               expl_top_AF=bN,
                               expl_top_qw=cN,
                               expl_latent=dN,
                               bounds=bounds,
                               beta_max=beta_max,
                               N_temp=ntemp,
                               rel_bw=rel_bw,
                               rel_beta=rel_beta,
                               frac_rel_beta_AF = frac_rel_beta_AF,
                               latent_sigma=latent_sigma,
                               use_latent_beta2=use_latent_beta2,
                               use_pq_beta_IW1=use_pq_beta_IW1,
                               k_trunc=ktrunc,
                               #sinf parameters
                               frac_validate=frac_validate,
                               alpha_w=alpha_w,
                               alpha_uw=alpha_uw,
                               NBfirstlayer=True, 
                               verbose=verbose,interp_nbin=knots,iteration=layers,
                               bw_factor_min=bw,bw_factor_max=bw,bw_factor_num=1, #manually force bw factor
                               trainable_qw=trainable_qw,
                               sgd_steps=sgd_steps,
                               gamma=gamma,
                               knots_trainable=knots_trainable,
                               optimize_directions=optimize_directions,
                               logp_cut=logp_cut,
                               random_seed=seed,
                               eps_z=eps_z,
                               cull_lowp_tol=cull_lowp_tol, 
                               max_cull_frac=max_cull_frac,
                               ess_tol=ess_tol,
                               local_thresh=local_thresh,
                               nfmc_draws=nfmc_draws, 
                               nf_iter=nf_iter,
                               top_verbose=top_verbose,
                               edge_bins=edge_bins,
                               local_step_factor=local_step_factor,
                               nfmc_AF_samples=nfmc_AF_samples,
                               nfmc_full_qw_samples=nfmc_full_qw_samples,
                               nfmc_local_exploration=nfmc_local_exploration,
                               agressive_cull=agressive_cull,
                               cull_mode=cull_mode,
                               cull_iw_tol=cull_iw_tol,
                               nfmc_frac_validate=nfmc_frac_validate,
                               min_delta_beta=min_delta_beta,
                               snf_loss=snf_loss,
                             )
    
    return trace


def log_flat_prior(x,low,high):
    """Log-density of a flat prior on a box with side ``high - low``.

    Only the size of the last axis of ``x`` is used; the values of ``x`` are
    never read. Returns ``-d * log(high - low)`` with ``d = x.shape[-1]``.
    """
    ndim = x.shape[-1]
    side = high - low
    return -ndim * jnp.log(side)


def plot_Zs(trace,low,high,n,beta_max):
    """Scatter the stored evidence estimates against beta (scaled by ``beta_max``).

    Reads ``trace['logZ'][0]`` under the keys ``q{i}_pq_uw``, ``q{i}_pq_w`` and
    ``q{i}_pq_w_trainable`` for ``i = 1 .. len(trace['betas'][0]) - 1``. The
    ``_trainable`` values are passed through ``np.log`` before plotting; the
    other two series are plotted as stored. When both ``low`` and ``high`` are
    given, a dashed line marks ``log((high-low)**-n)``.
    """
    logz_store = trace['logZ'][0]
    beta_list = trace['betas'][0]
    levels = range(1,len(beta_list))

    Zuws = np.array([logz_store['q{0}_pq_uw'.format(i)] for i in levels])
    Zws = np.array([logz_store['q{0}_pq_w'.format(i)] for i in levels])
    Zts = np.log(np.array([logz_store['q{0}_pq_w_trainable'.format(i)] for i in levels]))

    scaled_betas = np.array(beta_list[1:])*beta_max
    # Same draw order as before (uw, w_t, w) so colours and legend order are unchanged.
    for series, marker, label in ((Zuws,'s','uw'),(Zts,'d','w_t'),(Zws,'.','w')):
        plt.plot(scaled_betas,series,ls=' ',marker=marker,label=label)

    have_bounds = low is not None and high is not None
    if have_bounds:
        plt.axhline(np.log((high-low)**-n),ls='--',c='k',zorder=-1,label='prior norm' )
    plt.ylabel(r'$\log Z$')
    plt.legend(prop={"size":10})
    plt.xlabel(r'$\beta$')
    plt.show()

In [4]:

def get_icov(n,target=200,iseed=seed,eps=1,scale=50, just_do_it=False):
    """Draw a Wishart precision matrix and return it with its inverse.

    A matrix ``iC`` is drawn from ``scipy.stats.wishart(df=n, scale=scale*I)``
    and inverted to ``C``. Unless ``just_do_it`` is set, the seed is incremented
    and the draw repeated until the condition number of ``C`` (largest over
    smallest eigenvalue) is within ``eps`` of ``target``.

    Parameters
    ----------
    n : matrix dimension, also used as the Wishart degrees of freedom.
    target : desired condition number of ``C``.
    iseed : seed of the first draw; later draws use ``iseed+1``, ``iseed+2``, ...
    eps : accepted absolute deviation from ``target``.
    scale : multiplier of the identity used as the Wishart scale matrix.
    just_do_it : when True, return the first draw without checking ``target``.

    Returns
    -------
    (iC, C) : the drawn matrix and its inverse.

    Prints whether all eigenvalues of ``C`` are positive, the condition number
    and the largest eigenvalue.
    """
    def _draw(rng_seed):
        # One Wishart draw plus the eigenvalues of its inverse.
        wish = scipy.stats.wishart(df=n, scale=np.eye(n)* scale,seed=rng_seed)
        drawn = wish.rvs(size=1)
        inverse = np.linalg.inv(drawn)
        return drawn, inverse, np.linalg.eigvals(inverse)

    this_seed=iseed
    # Always draw at least once: previously, a call with |0 - target| <= eps
    # skipped the loop entirely and failed on unbound locals.
    while True:
        iC, C, eigs = _draw(this_seed)
        eigmax,eigmin = eigs.max(),eigs.min()
        condition = eigmax/eigmin
        # Written as "not (... > eps)" to keep the original loop's exit test.
        if just_do_it or not (abs(condition-target)>eps):
            break
        this_seed += 1
    print("PSD: ",np.all(eigs>0))
    print("Condition number: ", condition)
    print("Op norm: ", eigmax)
    return iC,C

# Precision/covariance pair for the test problem: iCov is the default `icov` of
# log_like_cg below, and Cov is passed to the triangle-plot helpers as the truth.
# Uses get_icov's default target condition number (200) with tolerance eps=1.
iCov,Cov=get_icov(n,eps=1)
def log_like_cg(x,mu_diag=0,icov=iCov):
    """Unnormalised Gaussian log-likelihood ``-0.5 * (x-mu)^T icov (x-mu)``.

    The mean is ``mu_diag`` repeated along the last axis of ``x``. The
    normalisation constant is deliberately left out.
    """
    dim = x.shape[-1]
    resid = x - mu_diag*jnp.ones(dim)
    return - 0.5 * jnp.dot(resid,jnp.dot(icov,resid))

PSD:  True
Condition number:  199.08320751596906
Op norm:  0.1587924033016887


In [5]:
from scipy.stats import multivariate_normal as n_mvn

def t2a(tens):
    """Convert a torch tensor to a float64 NumPy array.

    The tensor is detached and moved to the CPU first, so tensors that require
    grad or live on another device convert as well instead of raising.
    """
    return tens.detach().cpu().numpy().astype(np.float64)
def a2t(arr):
    """Convert an array to a float32 torch tensor via ``torch.from_numpy``."""
    as_float32 = arr.astype(np.float32)
    return torch.from_numpy(as_float32)

def plot_corr_gd(trace,Cov,Ngd=1000,beta_idx=None, out_name=None):
    """Triangle plot comparing samples from two fitted q models with Gaussian truth.

    Parameters
    ----------
    trace : mapping with a ``'q_models'`` entry; ``trace['q_models'][0]`` must
        hold models under the keys ``'q{beta_idx}_w'`` and ``'q{beta_idx}_uw'``,
        each with a ``sample(N, device=...)`` method whose first return value is
        a torch tensor of samples.
    Cov : covariance of the zero-mean Gaussian used as the truth; its size also
        sets the number of plotted parameters.
    Ngd : number of samples drawn from the truth and from each model.
    beta_idx : index of the model pair to plot. When None, the largest index
        that has a ``'q<idx>_w'`` key is used.
    out_name : if given, the figure is exported to this path.

    Raises
    ------
    KeyError : if ``beta_idx`` is None and no ``'q<idx>_w'`` key exists, or if the
        requested model pair is missing.
    """
    import re  # local import: only needed to resolve the default beta_idx

    qmodels=trace['q_models'][0]
    # Take the dimension from Cov rather than from the notebook-level `n`.
    ndim = np.atleast_2d(Cov).shape[-1]
    names = ["x%s"%i for i in range(ndim)]
    labels =  ["x_%s"%i for i in range(ndim)]

    if beta_idx is None:
        # The None default used to be formatted straight into the key ('qNone_w').
        hits = (re.fullmatch(r'q(\d+)_w', k) for k in qmodels if isinstance(k, str))
        available = [int(m.group(1)) for m in hits if m]
        if not available:
            raise KeyError("no 'q<idx>_w' entry found in trace['q_models'][0]")
        beta_idx = max(available)

    truth = n_mvn.rvs(mean=np.zeros(ndim),cov=Cov,size=Ngd)
    truth_gd = MCSamples(samples=truth,names = names, labels = labels, label='truth')

    s_w = t2a(qmodels['q{0}_w'.format(beta_idx)].sample(Ngd,device='cpu')[0])
    s_uw = t2a(qmodels['q{0}_uw'.format(beta_idx)].sample(Ngd,device='cpu')[0])
    samples_w = MCSamples(samples=s_w,names = names, labels = labels, label='q_w')
    samples_uw = MCSamples(samples=s_uw,names = names, labels = labels, label='q_uw')

    # Triangle plot
    plt.figure()
    g = plots.get_subplot_plotter()
    g.triangle_plot([samples_uw,truth_gd,samples_w], filled=True)
    plt.show()
    if out_name is not None:
        g.export(out_name)

In [6]:
def plot_triangles(samples, weights, Cov, Ngd=1000, beta_idx=None,
                   out_name=None):
    """Triangle plot of weighted samples against draws from a zero-mean Gaussian.

    Parameters
    ----------
    samples : sample array passed to ``MCSamples`` as the posterior samples.
    weights : per-sample weights passed to ``MCSamples``.
    Cov : covariance of the zero-mean Gaussian used as the truth; its size also
        sets the number of plotted parameters.
    Ngd : number of truth samples to draw.
    beta_idx : unused; kept so existing calls keep working.
    out_name : if given, the figure is exported to this path.
    """
    # Take the dimension from Cov rather than from the notebook-level `n`.
    ndim = np.atleast_2d(Cov).shape[-1]
    names = ["x%s"%i for i in range(ndim)]
    labels =  ["x_%s"%i for i in range(ndim)]

    truth = n_mvn.rvs(mean=np.zeros(ndim),cov=Cov,size=Ngd)
    truth_gd = MCSamples(samples=truth,names = names, labels = labels, label='truth')

    samples_w = MCSamples(samples=samples, weights=weights,
                          names=names, labels=labels, 
                          label='posterior samples')

    # Triangle plot
    plt.figure()
    g = plots.get_subplot_plotter()
    g.triangle_plot([truth_gd, samples_w], filled=True)
    plt.show()
    if out_name is not None:
        g.export(out_name)

In [7]:
# Forwarded to run_nfo's `snf_loss` argument in the next cell and embedded in the output file names.
snf_loss = 4
# In the visible cells this is only used to tag the output file names.
nf_cull = 0

In [8]:
# Blanket filter: silences every warning from here on.
warnings.filterwarnings('ignore')
# Run the sampler on the correlated-Gaussian likelihood with the flat prior
# (run_nfo's default box bounds low=-1, high=1 apply).
cg10 = run_nfo(n=n,log_like=log_like_cg,log_prior=log_flat_prior,
               beta_max=1,c=0,t_ess=0.9,frac_rel_beta_AF=.75,b=2,d=2, Ntemp=500,
               min_delta_beta=1e-2, nfmc_draws=1000, nfmc_local_exploration=True,
               agressive_cull=False, local_thresh=jnp.inf, cull_mode=None,
               snf_loss=snf_loss, frac_validate=0.1, nfmc_frac_validate=0.1)

# Persist the returned trace to disk.
with open(f'cg10_cullNONE_AF_snfloss{snf_loss}_nfcull{nf_cull}_dict.pickle', 'wb') as cg_file:
    pickle.dump(cg10, cg_file, pickle.HIGHEST_PROTOCOL)

# 'qw' posterior samples and their weights (entry 0); samples with zero weight are dropped.
qw = cg10['qw_posterior'][0]
qw_iw = cg10['qw_posterior_weights'][0]
qw = qw[qw_iw != 0.0]
qw_iw = np.asarray(qw_iw[qw_iw != 0.0])

# Resample 10000 indices with probability proportional to the weights.
# NOTE(review): `posterior` is not read in this cell before it is overwritten further down.
resample_idx = np.random.choice(np.arange(len(qw)), size=10000, p=qw_iw/qw_iw.sum())
posterior = qw[resample_idx, ...]

# NOTE(review): beta_idx=100 is hard-coded; plot_corr_gd looks up 'q100_w' and 'q100_uw' in cg10['q_models'][0].
plot_corr_gd(cg10, Cov, beta_idx=100, out_name=f'./qw_quw_nfo_10d_cg_cullNONE_AF_snfloss{snf_loss}_nfcull{nf_cull}.png') 

# Same qw samples plotted twice: once with their weights, once with all weights set to one.
plot_triangles(qw, qw_iw, Cov, out_name=f'./qw_weighted_10d_cg_cullNONE_AF_snfloss{snf_loss}_nfcull{nf_cull}.png')
plot_triangles(qw, np.ones(len(qw)), Cov, out_name=f'./qw_unweighted_10d_cg_cullNONE_AF_snfloss{snf_loss}_nfcull{nf_cull}.png')

# Repeat the zero-weight filtering and resampling for the 'quw' posterior.
quw = cg10['quw_posterior'][0]
quw_iw = cg10['quw_posterior_weights'][0]
quw = quw[quw_iw != 0.0]
quw_iw = np.asarray(quw_iw[quw_iw != 0.0])
resample_idx = np.random.choice(np.arange(len(quw)), size=10000, p=quw_iw/quw_iw.sum())

# Overwrites the qw-based `posterior` from above; not used again in this cell.
posterior = quw[resample_idx, ...]

plot_triangles(quw, quw_iw, Cov, out_name=f'./quw_weighted_10d_cg_cullNONE_AF_snfloss{snf_loss}_nfcull{nf_cull}.png')

Inference mode is sampling. Maximum beta is set to 1.
After first quw fit, beta=0.00
min_delta_beta = 0.01
Updated beta = 9.5367431640625e-07


TypeError: 'method' object is not iterable