In [3]:

import jax
import os

#os.environ['CUDA_VISIBLE_DEVICES'] = '3'
print(jax.local_devices())



import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random, vmap, jit, grad
import jax.scipy as scipy

from jax import random


import matplotlib.pyplot as plt

from utils_rough_pde import *
from utils_elliptic_coef import *

from scipy import integrate

import time

from jax.numpy import fft

from tqdm import tqdm

[cuda(id=0)]


In [4]:
def discrete_sine_transform(y):
    n = y.shape[0]
    y_extended = jnp.concatenate([y,  -y[::-1]])
    y_fft = fft.rfft(y_extended)
    return (-y_fft.imag/(jnp.sqrt(2)*(n)))

# We will compute the DST in 2 dimensions

In [23]:
lower = 0.0
upper = 1.0
domain = jnp.array([lower,upper])

# 2d grid 
n_points = 4
x = jnp.linspace(lower, upper, n_points, endpoint=False)
y = jnp.linspace(lower+1, upper+1, n_points, endpoint=False)
X, Y = jnp.meshgrid(x, y)

In [24]:
x

Array([0.  , 0.25, 0.5 , 0.75], dtype=float64)

In [25]:
X.shape, Y

((4, 4),
 Array([[1.  , 1.  , 1.  , 1.  ],
        [1.25, 1.25, 1.25, 1.25],
        [1.5 , 1.5 , 1.5 , 1.5 ],
        [1.75, 1.75, 1.75, 1.75]], dtype=float64))

In [26]:
# Defining the second order elliptic operators 
def matern_kernel_2d(x_1, x_2, y_1, y_2, length_scale):
    r = jnp.sqrt((x_1 - y_1) ** 2 + (x_2 - y_2) ** 2)
    return (1 + jnp.sqrt(5)*r/length_scale + (5 / 3) * (r ** 2) / (length_scale ** 2)) * jnp.exp(-jnp.sqrt(5)*r/length_scale)

In [38]:
length_scale = 1.0
matern_kernel_tensor = jit(vmap(vmap(vmap(vmap(matern_kernel_2d, in_axes=(None, None,None, 0, None)), in_axes=(None, None, 0, None, None)), in_axes = (None, 0, None, None, None)), in_axes=(0, None, None, None, None)))

In [40]:
K_tensor = matern_kernel_tensor(x, x, y, y, length_scale)

In [46]:
# Checking that the tensor is correct
for i in range(n_points):
    for j in range(n_points):
        for k in range(n_points):
            for l in range(n_points):
                assert K_tensor[i,j,k,l] == matern_kernel_2d(x[i], x[j], y[k], y[l], length_scale)

In [63]:
# Testing the sine transform
def discrete_sine_transform(y):
    n = y.shape[0]
    y_extended = jnp.concatenate([y,  -y[::-1]])
    y_fft = fft.rfft(y_extended)
    return (-y_fft.imag/(jnp.sqrt(2)*(n)))[:-1] # Removing the last spurious element 

In [64]:
def vectorize_function(f, ndim):
    for _ in range(ndim):
        f = jax.vmap(f)
    return f

f_test = lambda x: x**2

f_test_vectorized = vectorize_function(f_test, 4)

In [65]:
K_test = f_test_vectorized(K_tensor)
# Checking that the tensor is correct
for i in range(n_points):
    for j in range(n_points):
        for k in range(n_points):
            for l in range(n_points):
                assert K_test[i,j,k,l] == f_test(K_tensor[i,j,k,l])


In [120]:
vmap_dst = jit(vmap(discrete_sine_transform, in_axes=(0,)))
@jit
def dst_2d(A):
    # Receives a 2d array and returns the 2d discrete sine transform
    return vmap_dst(vmap_dst(A).T).T

In [121]:
i,j = 0,2
dst_2d(K_tensor[i,j]).shape

(4, 4)

In [162]:
def test_function(A):
    return A**2

In [163]:
# Create a 4D tensor of size NxNxNxN
N = 2
tensor_4d = jnp.arange(N**4).reshape(N, N, N, N)

In [165]:
test_function(tensor_4d[i,j])

Array([[144, 169],
       [196, 225]], dtype=int64)

In [142]:
i, j = 1,2

In [157]:
K_trans_1 = vmap(vmap(dst_2d, in_axes=0), in_axes=1)(K_tensor)
K_trans_1[i,j]

Array([[-0.        , -0.        , -0.        , -0.        ],
       [-0.        ,  0.42259911,  0.0607691 ,  0.06980065],
       [-0.        ,  0.0607691 ,  0.01222958,  0.01028895],
       [-0.        ,  0.06980065,  0.01028895,  0.01157917]],      dtype=float64)

In [158]:
for i in range(n_points):
    for j in range(n_points):
        print(jnp.mean(jnp.abs(K_trans_1[i,j]- dst_2d(K_tensor[i,j]))))
        #print(jnp.allclose(K_trans_1[i,j],dst_2d(K_tensor[i,j])), )

0.0
0.0009441505943644007
0.0024250106586377915
0.004435435412704201
0.0009441505943644007
0.0
0.00170036999000366
0.0042339261556458285
0.0024250106586377915
0.00170036999000366
0.0
0.002913701354881224
0.004435435412704201
0.0042339261556458285
0.002913701354881224
0.0


In [159]:
K_trans_1 = vmap(vmap(dst_2d, in_axes = (1,)), in_axes = (0,))(K_tensor)
K_trans_1[i,j]

Array([[-0.        , -0.        , -0.        , -0.        ],
       [-0.        ,  0.23252324,  0.04097938,  0.04011494],
       [-0.        , -0.04097938, -0.0007008 , -0.00590172],
       [-0.        ,  0.04011494,  0.00590172,  0.00676482]],      dtype=float64)

In [160]:
dst_2d(K_tensor[i,j])

Array([[-0.        , -0.        , -0.        , -0.        ],
       [-0.        ,  0.42259911,  0.0607691 ,  0.06980065],
       [-0.        ,  0.0607691 ,  0.01222958,  0.01028895],
       [-0.        ,  0.06980065,  0.01028895,  0.01157917]],      dtype=float64)

In [161]:
for i in range(n_points):
    for j in range(n_points):
        print(jnp.mean(jnp.abs(K_trans_1[i,j]- dst_2d(K_tensor[i,j]))))
        #print(jnp.allclose(K_trans_1[i,j],dst_2d(K_tensor[i,j])), )

0.016418072364128768
0.008168984022508527
0.009211634580691167
0.015271417379853745
0.0202497744450224
0.010506645412480833
0.01133266660321496
0.019321258777553463
0.023433988186899887
0.012817075585627153
0.01307937137625481
0.023099955117342013
0.02489168104543778
0.014555921705814776
0.01385987559621819
0.025581757110357678


In [67]:
tensor_sine_transform = vectorize_function(discrete_sine_transform, 4)

In [81]:
tensor_sine_transform = vmap(vmap(vmap(discrete_sine_transform)))

In [89]:
tensor_sine_transform = vmap(discrete_sine_transform)

In [90]:
tensor_sine_transform(K_tensor[i,j])

Array([[-0.        ,  0.30190417,  0.0504235 ,  0.05146831],
       [-0.        ,  0.23259504,  0.03608084,  0.03939931],
       [-0.        ,  0.17326523,  0.02496178,  0.02921281],
       [-0.        ,  0.12555461,  0.01682743,  0.02109987]],      dtype=float64)

In [82]:
K_tensor.shape

(4, 4, 4, 4)

In [83]:
K_transformed = tensor_sine_transform(K_tensor)

In [86]:
K_transformed

Array([[[[-0.        ,  0.1870605 ,  0.04298894,  0.03420684],
         [-0.        ,  0.14843572,  0.03231707,  0.02682937],
         [-0.        ,  0.11375217,  0.02339077,  0.02034504],
         [-0.        ,  0.08463736,  0.01641859,  0.01500015]],

        [[-0.        ,  0.24230867,  0.0487911 ,  0.04286009],
         [-0.        ,  0.18942345,  0.03579469,  0.03317524],
         [-0.        ,  0.14306054,  0.02532264,  0.02484872],
         [-0.        ,  0.10498727,  0.01740876,  0.01811432]],

        [[-0.        ,  0.30190417,  0.0504235 ,  0.05146831],
         [-0.        ,  0.23259504,  0.03608084,  0.03939931],
         [-0.        ,  0.17326523,  0.02496178,  0.02921281],
         [-0.        ,  0.12555461,  0.01682743,  0.02109987]],

        [[-0.        ,  0.35927939,  0.04532089,  0.05898382],
         [-0.        ,  0.273218  ,  0.0316827 ,  0.04480812],
         [-0.        ,  0.20112159,  0.0214868 ,  0.03298622],
         [-0.        ,  0.14419502,  0.01424179, 