In [None]:
import os
# Hide all CUDA devices (presumably to force JAX onto the CPU); this has to run
# before `import jax` below for the setting to take effect.
os.environ['CUDA_VISIBLE_DEVICES']='-1'

import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

# NOTE(review): duplicates `from jax import random` above -- both bind `random`
# to the `jax.random` module; one of the two can be dropped.
import jax.random as random
# Root PRNG key for the notebook; the prior-sampling cell derives per-sample keys
# from it via get_keys_and_rng and reassigns `rng`.
rng = random.key(2)

from phyEBM.utils import  get_keys_and_rng

In [None]:
# Spectrum decay vs. mode index for several length-scales ell (nu and sigma fixed).
# Each curve is sqrt(C * ell^d * (ell^2 * 2 * j^2 * pi^2 + 1)^(-nu - d/2)), plotted
# along the diagonal j = k (per the x-axis label); C is the ell-independent prefactor.
# NOTE(review): the "GRF_WM" filename tag suggests a Whittle-Matern field -- confirm.
j = jnp.arange(1, stop=20 + 1, step=1)  # mode indices 1..20; also used by the cells below

nu = 0.75
sigma = 1.
dim = 2
ell_vals = jnp.exp(jnp.linspace(jnp.log(0.1), jnp.log(0.5), 10))  # 10 log-spaced values in [0.1, 0.5]
normalize_to_first = False  # True: plot vals / vals[0] to compare decay shapes only

# Prefactor sigma^2 * 2^d * pi^(d/2) * Gamma(nu + d/2) / Gamma(nu). It does not
# depend on ell, so it is computed once outside the sweep.
norm_const = sigma**2. * 2.**dim * jnp.pi**(dim / 2.) * \
    jax.scipy.special.gamma(nu + dim / 2.) / jax.scipy.special.gamma(nu)

os.makedirs('plots', exist_ok=True)  # savefig raises if the output directory is missing

fig, ax = plt.subplots()
for ell in ell_vals:
    vals = jnp.sqrt(norm_const * ell**dim * (ell**2. * 2 * j**2. * jnp.pi**2. + 1.)**(-nu - dim / 2))
    ax.plot(j, vals / vals[0] if normalize_to_first else vals, '-o', label=f'{ell:.2f}')
ax.set_xticks(j)  # one tick per mode index

ax.set_title(rf'$\ell$; $\quad\nu={nu},\;\sigma={sigma}$')
ax.set_xlabel(r'$j=k$')
ax.legend()
ax.grid(True)  # explicit: plt.grid() with no args toggles instead of enabling
fig.savefig(f'plots/TwoDSpectrumDecay_GRF_WM_nu_{nu:.3f}.pdf')
plt.show()


In [None]:
# Spectrum decay vs. mode index for several smoothness values nu (ell and sigma fixed).
# Uses the mode indices `j` defined in the ell-sweep cell above.
nus = jnp.exp(jnp.linspace(jnp.log(0.5), jnp.log(3.), 10))  # 10 log-spaced values in [0.5, 3]

ell = 0.5
sigma = 2.
dim = 2
normalize_to_first = False  # True: plot vals / vals[0] to compare decay shapes only

os.makedirs('plots', exist_ok=True)  # savefig raises if the output directory is missing

fig, ax = plt.subplots()
for nu in nus:
    # The prefactor depends on nu through the Gamma ratio, so it is recomputed per curve.
    norm_const = sigma**2. * 2.**dim * jnp.pi**(dim / 2.) * \
        jax.scipy.special.gamma(nu + dim / 2.) / jax.scipy.special.gamma(nu)
    vals = jnp.sqrt(norm_const * ell**dim * (ell**2. * 2 * j**2. * jnp.pi**2. + 1.)**(-nu - dim / 2))
    ax.plot(j, vals / vals[0] if normalize_to_first else vals, '-o', label=f'{nu:.2f}')
ax.set_xticks(j)  # one tick per mode index

ax.set_title(rf'$\nu$; $\quad\ell={ell}, \; \sigma={sigma}$')
ax.set_xlabel(r'$j=k$')
ax.legend()
ax.grid(True)  # explicit: plt.grid() with no args toggles instead of enabling
fig.savefig(f'plots/TwoDSpectrumDecay_GRF_WM_ell_{ell}.pdf')
plt.show()


In [None]:
# Spectrum decay vs. mode index for several amplitudes sigma (nu and ell fixed).
# Uses the mode indices `j` defined in the ell-sweep cell above.
nu = 0.75
sigmas = jnp.exp(jnp.linspace(jnp.log(0.1), jnp.log(1.), 10))  # 10 log-spaced values in [0.1, 1]
ell = 0.5
dim = 2  # set explicitly instead of relying on the value left by the previous cell
normalize_to_first = False  # True: plot vals / vals[0] (sigma then cancels out)

os.makedirs('plots', exist_ok=True)  # savefig raises if the output directory is missing

fig, ax = plt.subplots()
for sigma in sigmas:
    # The prefactor scales with sigma^2, so it is recomputed for every curve.
    norm_const = sigma**2. * 2.**dim * jnp.pi**(dim / 2.) * \
        jax.scipy.special.gamma(nu + dim / 2.) / jax.scipy.special.gamma(nu)
    vals = jnp.sqrt(norm_const * ell**dim * (ell**2. * 2 * j**2. * jnp.pi**2. + 1.)**(-nu - dim / 2))
    ax.plot(j, vals / vals[0] if normalize_to_first else vals, '-o', label=f'{sigma:.2f}')
ax.set_xticks(j)  # one tick per mode index

ax.set_title(rf'$\sigma$; $\quad  \nu={nu},\; \ell={ell}$')
ax.set_xlabel(r'$j=k$')
ax.legend()
ax.grid(True)  # explicit: plt.grid() with no args toggles instead of enabling
# The name encodes the swept variable (sigma) and the fixed nu/ell. It must differ from the
# nu-sweep's 'TwoDSpectrumDecay_GRF_WM_ell_{ell}.pdf', which also uses ell = 0.5 and would
# otherwise be overwritten by this figure.
fig.savefig(f'plots/TwoDSpectrumDecay_GRF_WM_sigma_nu_{nu:.3f}_ell_{ell}.pdf')
plt.show()

In [None]:
# Draw samples from the 2D WM prior on a regular grid and plot each sampled pair (Z, A).
from src.WM_prior_2D import WM_Prior_2D

n_basis = 50
# NOTE(review): n_basis is passed for both constructor arguments -- presumably the
# number of basis functions per spatial dimension; confirm against WM_Prior_2D.
prior = WM_Prior_2D(n_basis, n_basis)

# Unconstrained prior parameters. The intended constrained values are inferred from the
# literals and must be confirmed against WM_Prior_2D: log(1) -> sigma = 1 under exp;
# log(exp(0.5) - 1) is the inverse softplus of 0.5 -> ell = 0.5 under softplus;
# log(2 - 0.5) -> nu = 2 under a `0.5 + exp` map.
init_params_prior = {'sigma_val': jnp.log(1.),
                     'ell_val': jnp.log(jnp.exp(0.5) - 1.),
                     'nu_val': jnp.log(2. - 0.5)}

# 200 x 200 evaluation grid on [0, 1]^2, stacked into shape (2, 200, 200).
x = jnp.linspace(0., 1., 200)
X, Y = jnp.meshgrid(x, x, indexing='xy')
grid = jnp.concatenate((X[None, ...], Y[None, ...]), 0)

# Derive 10 sample keys and advance `rng`, so re-running this cell draws new samples.
keys, rng = get_keys_and_rng(rng, num=10)
Zs, As = jax.vmap(prior.sample_smooth_z, in_axes=(0, None, None))(keys, init_params_prior, grid)


def _plot_field(coords, field, path, **contourf_kwargs):
    """Draw a 50-level filled-contour plot of `field` over `coords`, save it to `path`, and show it.

    Args:
        coords: stacked grid whose [0] / [1] entries are the x_(1) / x_(2) coordinates.
        field: values to contour on that grid.
        path: output file passed to `Figure.savefig`.
        **contourf_kwargs: extra keyword arguments for `Axes.contourf` (e.g. `cmap`).
    """
    fig, ax = plt.subplots()
    contours = ax.contourf(coords[0], coords[1], field, 50, **contourf_kwargs)
    ax.set_xlabel(r'$x_{(1)}$')
    ax.set_ylabel(r'$x_{(2)}$')
    fig.colorbar(contours, ax=ax)
    fig.savefig(path)
    plt.show()
    plt.close(fig)  # release the figure; this runs once per sample inside a loop


os.makedirs('plots', exist_ok=True)  # savefig raises if the output directory is missing
for i, (Z, A) in enumerate(zip(Zs, As)):
    _plot_field(grid, Z, f'plots/WMPrior2D_Z_{i}.png')
    _plot_field(grid, A, f'plots/WMPrior2D_A_{i}.png', cmap='twilight')