In [1]:
import torch
torch.__version__

'1.10.0+cpu'

In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.__version__

'0.2.25'

In [4]:
jnp

<module 'jax.numpy' from '/opt/conda/lib/python3.7/site-packages/jax/numpy/__init__.py'>

In [5]:
key = jax.random.PRNGKey(42)
key



DeviceArray([ 0, 42], dtype=uint32)

In [6]:
xi = jax.random.normal(key, shape=(2048, 1))

In [7]:
xi.mean()

DeviceArray(0.00068037, dtype=float32)

In [8]:
import numpy as np
import scipy

In [9]:
init_seed = 42
size = 256

key = jax.random.PRNGKey(init_seed)
xi = jax.random.normal(key, shape=(size, 1))

x = jnp.linspace(0, size - 1, num=size, dtype=int)
y = jnp.linspace(0, size - 1, num=size, dtype=int)
k1, k2 = jnp.meshgrid(x, y)

alpha = 2
tau = 3

coef = tau ** (alpha - 1) * (np.pi ** 2 * (k1 ** 2 + k2 ** 2) + tau ** 2) ** (-alpha / 2)
L = size * coef * xi

L = L.at[1, 1].set(0)

result = jax.scipy.fft.dct(L)
result.shape

(256, 256)

In [10]:
coef.shape

(256, 256)

In [11]:
k1[1]

DeviceArray([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
              12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
              24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
              36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
              48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
              60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
              72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
              84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
              96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
             108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
             120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
             132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
             144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
             156, 157, 158, 159, 160, 161, 162, 163

In [12]:
def gaussian_random_field(alpha, tau, size):
    
    init_seed = 42
    key = jax.random.PRNGKey(init_seed)
    xi = jax.random.normal(key, shape=(size, 1))

    x_lin = jnp.linspace(0, size - 1, num=size, dtype=int)
    k1, k2 = jnp.meshgrid(x_lin, x_lin)

    coef = tau ** (alpha - 1) * (np.pi ** 2 * (k1 ** 2 + k2 ** 2) + tau ** 2) ** (-alpha / 2)
    L = size * coef * xi

    L = L.at[1, 1].set(0)

    result = jax.scipy.fft.dct(L)
    
    return result

gaussian_random_field(2, 3, 256)

DeviceArray([[-4.3577606e+02, -4.3276358e+02, -4.2868890e+02, ...,
              -5.0374043e-01, -3.3561754e-01, -1.6773775e-01],
             [ 1.0878515e+02,  1.0722731e+02,  1.0514499e+02, ...,
               1.9170768e+00,  1.2783842e+00,  6.3928860e-01],
             [-1.3798648e+02, -1.3542706e+02, -1.3202864e+02, ...,
               4.2358752e-02,  2.8240345e-02,  1.4117505e-02],
             ...,
             [ 5.8134389e-01,  8.4904045e-02, -7.0481044e-03, ...,
               2.8941432e-08, -1.8491789e-08,  7.4729947e-09],
             [ 6.6839337e-01,  9.7057305e-02, -8.1524616e-03, ...,
               3.0995896e-08, -1.9960886e-08,  4.4583777e-09],
             [ 6.9146261e-02,  9.9832267e-03, -8.4829697e-04, ...,
               3.0360867e-09, -1.6949059e-09,  1.2347762e-09]],            dtype=float32)

In [13]:
import numpy as np

In [44]:
def solve(coef, F):
    
    k = 20
    # k = len(coef)
    
    l1 = jnp.linspace(1 / (2 * k), (2 * k - 1) / (2 * k), num=k - 1)
    x1, y1 = jnp.meshgrid(l1, l1)
    # print(x1.shape)
    # print(x1)
    
    l2 = jnp.linspace(0, 1, num=k - 1)
    x2, y2 = jnp.meshgrid(l2, l2)
    
    points = jnp.vstack((
        jnp.reshape(x1, (1, -1)),
        jnp.reshape(y1, (1, -1))))
    xi = jnp.vstack((
        jnp.reshape(x2, (1, -1)),
        jnp.reshape(y2, (1, -1))))
    l = x1.shape[0] * x1.shape[1]
    values = np.reshape(np.random.rand(l), (1, -1))
    print(points.shape)
    print(values.shape)
    coef = scipy.interpolate.griddata(points=np.asarray(points).T,#(np.asarray(jnp.reshape(x1, ,
                               values=values.T,
                               xi=xi.T,
                               method='cubic')
    
    coef = jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)

In [45]:
solve(None, None)

(2, 361)
(1, 361)


NameError: name 'coordinates' is not defined

In [35]:
def func(x, y):
    return x*(1-x)*np.cos(4*np.pi*x) * np.sin(4*np.pi*y**2)**2

grid_x, grid_y = np.mgrid[0:1:100j, 0:1:200j]

rng = np.random.default_rng()
points = rng.random((1000, 2))
values = func(points[:,0], points[:,1])

from scipy.interpolate import griddata
grid_z0 = griddata(points, values, (grid_x, grid_y), method='nearest')

In [37]:
grid_x.shape

(100, 200)

In [38]:
points.shape

(1000, 2)

In [39]:
values.shape

(1000,)