In [2]:
import math
import torch
import numpy as np

# Seed torch's RNG so the random sample points drawn in generate_data are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Precision used by the B-spline helpers and the reference Fourier coefficients.
dtype = torch.double
# Use the GPU when available; fall back to CPU so "Restart & Run All" works anywhere.
# NOTE: the 7-d hyperbolic cross used below has ~4e7 frequencies, so CPU runs are slow.
device_id = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def bspline_test_7d(x):
    """
    7-dimensional test function built from tensor products of B-splines.

    f(x) = bspline_o2(x_0, x_2, x_3) + bspline_o4(x_1, x_4, x_5, x_6)

    Args:
        x: 2-D tensor of shape (M, 7); each row is one evaluation point.

    Returns:
        1-D tensor of length M with the function values.

    Raises:
        ValueError: if ``x`` is not an M x 7 matrix.
    """
    # Check ndim first so 1-D input raises the documented ValueError, not IndexError.
    if x.ndim != 2 or x.shape[1] != 7:
        raise ValueError("input must be M x 7 matrix")

    # The two summands act on disjoint coordinate subsets; these index lists must
    # match the ones used in bspline_test_7d_fouriercoeff.
    return bspline_o2(x[:, [0, 2, 3]]) + bspline_o4(x[:, [1, 4, 5, 6]])


def bspline_o2(x):
    """
    Tensor product of 1-periodic order-2 B-splines (piecewise-linear hats).

    Per coordinate, with t = x - floor(x), the factor is 4*t on [0, 1/2) and
    4*(1 - t) on [1/2, 1), multiplied by sqrt(3/4) so it has unit L2 norm over
    one period. The result is the product of the factors over all columns.

    Args:
        x: 2-D tensor of shape (M, d).

    Returns:
        1-D tensor of length M in the global ``dtype``, on ``x.device``.
    """
    frac = x - torch.floor(x)
    val = torch.ones(x.shape[0], device=x.device, dtype=dtype)

    for t in range(x.shape[1]):
        xt = frac[:, t]
        # For tiny negative inputs x - floor(x) can round to exactly 1.0; the
        # falling branch then gives 0, the correct periodic value. NaN also
        # lands in that branch and propagates instead of being skipped.
        hat = torch.where(xt < 1 / 2, 4.0 * xt, 4.0 * (1 - xt))
        val = math.sqrt(3 / 4) * val * hat

    # Keep the output in the global dtype regardless of the input precision.
    return val.to(dtype)


def bspline_o4(x):
    """
    Tensor product of 1-periodic order-4 (piecewise-cubic) B-splines.

    Per coordinate, with t = x - floor(x), the factor is the cubic B-spline
    with knots 0, 1/4, 1/2, 3/4, 1 (peak 8/3 at t = 1/2), multiplied by
    sqrt(315/604) so it has unit L2 norm over one period. The result is the
    product of the factors over all columns.

    Args:
        x: 2-D tensor of shape (M, d).

    Returns:
        1-D tensor of length M in the global ``dtype``, on ``x.device``.
    """
    frac = x - torch.floor(x)
    val = torch.ones(x.shape[0], device=x.device, dtype=dtype)

    for t in range(x.shape[1]):
        xt = frac[:, t]
        # Polynomial piece on each knot interval.
        piece1 = 128 / 3 * xt**3
        piece2 = 8 / 3 - 32 * xt + 128 * xt**2 - 128 * xt**3
        piece3 = -88 / 3 - 256 * xt**2 + 160 * xt + 128 * xt**3
        piece4 = 128 / 3 - 128 * xt + 128 * xt**2 - (128 / 3) * xt**3
        # frac can round to exactly 1.0 for tiny negative inputs; it then falls
        # into piece4, which evaluates to 0 there (the correct periodic value).
        spline = torch.where(
            xt < 1 / 4,
            piece1,
            torch.where(xt < 2 / 4, piece2, torch.where(xt < 3 / 4, piece3, piece4)),
        )
        val = math.sqrt(315 / 604) * val * spline

    # Keep the output in the global dtype regardless of the input precision.
    return val.to(dtype)


def bspline_test_7d_fouriercoeff(freq_out):
    """
    Fourier coefficients and squared L2 norm of ``bspline_test_7d``.

    The test function is f = g + h, with g = bspline_o2 on coordinates
    (0, 2, 3) and h = bspline_o4 on coordinates (1, 4, 5, 6). Because g and h
    depend on disjoint coordinates, fhat(k) is:
      - g_hat(k on g's coordinates) when h's components of k vanish;
      - h_hat(k on h's coordinates) when g's components vanish;
      - the sum of both at k = 0;
      - 0 everywhere else.

    Args:
        freq_out: 2-D floating tensor of shape (M_f, 7) with integer-valued
            frequencies. A block of components counts as zero when the sum of
            its absolute values is <= 1e-8.

    Returns:
        fhat: 1-D tensor of length M_f (dtype and device of ``freq_out``).
        norm_fct_square: 0-d tensor ||f||_2^2 = 2 + 2 * g_hat(0) * h_hat(0).
    """
    o2_cols = [0, 2, 3]
    o4_cols = [1, 4, 5, 6]
    fhat = torch.zeros(freq_out.shape[0], device=freq_out.device, dtype=freq_out.dtype)

    # Frequencies whose o4 part is (numerically) zero carry the o2 coefficient.
    ind = torch.where(torch.sum(torch.abs(freq_out[:, o4_cols]), dim=1) <= 1e-8)[0]
    if len(ind) > 0:
        fhat[ind] = fhat[ind] + bspline_o2_hat(freq_out[ind][:, o2_cols])

    # Frequencies whose o2 part is (numerically) zero carry the o4 coefficient.
    ind = torch.where(torch.sum(torch.abs(freq_out[:, o2_cols]), dim=1) <= 1e-8)[0]
    if len(ind) > 0:
        fhat[ind] = fhat[ind] + bspline_o4_hat(freq_out[ind][:, o4_cols])

    # ||g + h||^2 = ||g||^2 + ||h||^2 + 2 <g, h>. Both parts are normalized to
    # unit norm, and since they act on disjoint coordinates
    # <g, h> = (integral of g) * (integral of h) = g_hat(0) * h_hat(0).
    zero_o2 = torch.zeros((1, len(o2_cols)), device=freq_out.device, dtype=freq_out.dtype)
    zero_o4 = torch.zeros((1, len(o4_cols)), device=freq_out.device, dtype=freq_out.dtype)
    norm_fct_square = 2 + 2 * bspline_o2_hat(zero_o2)[0] * bspline_o4_hat(zero_o4)[0]

    return fhat, norm_fct_square


def bspline_o2_hat(k):
    """
    Fourier coefficients of the normalized order-2 (hat) B-spline tensor product.

    Per coordinate the factor is sqrt(3/4) * sinc(pi*k/2)^2 * (-1)^k, where
    sinc(y) = sin(y)/y with sinc(0) = 1. The (-1)^k phase comes from the hat
    being centred at x = 1/2. The result is the product over all columns.

    Args:
        k: 2-D floating tensor of shape (n, d) with integer-valued frequencies
            ((-1)**k is NaN for non-integer entries).

    Returns:
        1-D tensor of length n with the dtype and device of ``k``.
    """
    # torch.sinc(y) = sin(pi*y) / (pi*y) and equals 1 at y = 0, so the zero
    # frequency needs no masking: torch.sinc(k / 2) == sin(pi*k/2) / (pi*k/2).
    factors = math.sqrt(3 / 4) * torch.sinc(k / 2) ** 2 * (-1) ** k
    return torch.prod(factors, dim=1)


def bspline_o4_hat(k):
    """
    Fourier coefficients of the normalized order-4 (cubic) B-spline tensor product.

    Per coordinate the factor is sqrt(315/604) * sinc(pi*k/4)^4 * (-1)^k,
    where sinc(y) = sin(y)/y with sinc(0) = 1. The (-1)^k phase comes from the
    spline being centred at x = 1/2. The result is the product over all columns.

    Args:
        k: 2-D floating tensor of shape (n, d) with integer-valued frequencies
            ((-1)**k is NaN for non-integer entries).

    Returns:
        1-D tensor of length n with the dtype and device of ``k``.
    """
    # torch.sinc(y) = sin(pi*y) / (pi*y) and equals 1 at y = 0, so the zero
    # frequency needs no masking: torch.sinc(k / 4) == sin(pi*k/4) / (pi*k/4).
    factors = math.sqrt(315 / 604) * torch.sinc(k / 4) ** 4 * (-1) ** k
    return torch.prod(factors, dim=1)


def bspline_sinc(x):
    """
    Unnormalized sinc function sin(x)/x on a tensor.

    The removable singularity at x == 0 is filled in with its limit 1 instead
    of producing NaN from 0/0.
    """
    # torch.where evaluates both branches; the NaN produced at x == 0 is discarded.
    return torch.where(x == 0, torch.ones_like(x), torch.sin(x) / x)


def _hyperbolic_cross(d, M):
    """
    Integer frequencies of the d-dimensional hyperbolic cross with budget M.

    Returns every k in Z^d with prod_i max(1, |k_i|) <= M as an int32 array of
    shape (n, d), in lexicographic order (first coordinate varies slowest).
    """
    if d == 1:
        return np.arange(-M, M + 1, dtype=np.int32).reshape(-1, 1)

    blocks = []
    for k in range(-M, M + 1):
        # The remaining coordinates share the budget floor(M / max(1, |k|)).
        sub = _hyperbolic_cross(d - 1, int(M / max(1, abs(k))))
        block = np.empty((len(sub), d), dtype=np.int32)
        block[:, 0] = k
        block[:, 1:] = sub
        blocks.append(block)

    return np.vstack(blocks)


def generate_data(N, M, D, dtype=torch.float, device="cuda"):
    """
    Sample training data and reference Fourier coefficients for ``bspline_test_7d``.

    Args:
        N: number of sample points drawn uniformly from [0, 1)^D.
        M: hyperbolic-cross budget; all k with prod_i max(1, |k_i|) <= M are used.
        D: dimension; must be 7 because the test function is 7-dimensional.
        dtype: floating dtype of samples, frequencies and coeffs_gt.
        device: device of samples, frequencies and coeffs_gt.

    Returns:
        frequencies: (M_f, D) tensor holding -2*pi*k for every hyperbolic-cross k.
        frequencies_half: first ceil(M_f / 2) rows of ``frequencies``.
        samples: (N, D) tensor of sample points.
        values: (N, 1) tensor of function values divided by ||f||_2.
            NOTE(review): its dtype follows the B-spline helpers (global
            ``dtype``), not the ``dtype`` argument — confirm this is intended.
        coeffs_gt: (M_f, 2) tensor of normalized Fourier coefficients; column 0
            holds fhat / ||f||_2, column 1 is left at zero.
        trunc_error: 0-d double CPU tensor sqrt(1 - sum(coeffs_gt[:, 0]**2)),
            computed in double precision.

    Raises:
        ValueError: if D != 7.
    """
    # Fail fast: the test function and its coefficients are hard-wired to 7 dims,
    # and building the hyperbolic cross is expensive.
    if D != 7:
        raise ValueError("generate_data requires D == 7 (bspline_test_7d is 7-dimensional)")

    # Sample N points uniformly in [0, 1)^D.
    samples = torch.rand(N, D, dtype=dtype, device=device, requires_grad=False)

    # Integer frequencies of the hyperbolic cross, kept in double for the reference computation.
    frequencies = torch.from_numpy(_hyperbolic_cross(D, M)).to(dtype=torch.double, device=device)
    M_f = frequencies.size(0)

    # Reference coefficients and truncation error, computed in double precision on
    # the device the coefficient routine returns (avoids a large host round-trip).
    fhat, norm_sq = bspline_test_7d_fouriercoeff(frequencies)
    norm = math.sqrt(norm_sq)
    coeffs_gt = torch.zeros((M_f, 2), dtype=torch.double, device=fhat.device)
    coeffs_gt[:, 0] = fhat
    coeffs_gt = coeffs_gt / norm
    # Rounding can push the captured energy marginally above 1; clamp so sqrt
    # returns 0 instead of NaN. Moved to CPU to keep the original placement.
    trunc_error = torch.sqrt(torch.clamp(1 - torch.sum(coeffs_gt[:, 0] ** 2), min=0.0)).cpu()

    # Cast to the requested dtype/device.
    frequencies = -2 * math.pi * frequencies.to(dtype=dtype)
    coeffs_gt = coeffs_gt.to(dtype=dtype, device=device)

    # f is real, so fhat(-k) = conj(fhat(k)). The cross is symmetric and sorted
    # lexicographically, so rows i and M_f-1-i are negatives of each other. The
    # first ceil(M_f / 2) rows therefore hold one frequency of each +/- pair plus k = 0.
    frequencies_half = frequencies[: math.ceil(M_f / 2), :]

    # Function values normalized by ||f||_2.
    values = bspline_test_7d(samples)[:, None] / norm

    print("Number of Fourier frequencies in hyperbolic cross:", M_f)
    print("Truncation error computed with double precision:", trunc_error.item())

    return frequencies, frequencies_half, samples, values, coeffs_gt, trunc_error

In [6]:
frequencies, frequencies_half, samples, values, coeffs_gt, trunc_error = generate_data(
    10_000, 170, 7, device=device_id, dtype=dtype
)

# Best k-term approximation error ||c - c_k||_2, where c_k keeps only the k
# largest-magnitude entries of the coefficient column coeffs_gt[:, 0].
k_array = [
    500,
    1000,
    2000,
    4000,
    8000,
    16000,
    32000,
    64000,
    128000,
    256000,
]
# Sorting once turns each k-term error into the norm of a tail slice. This does
# not rely on torch.topk picking nested index sets for growing k. |c| has many
# ties (e.g. fhat(k) == fhat(-k)), so accumulating top-k selections across
# iterations could keep more than k entries. It also replaces one topk pass per k.
abs_sorted = torch.sort(torch.abs(coeffs_gt[:, 0]), descending=True).values
# Column 1 is never truncated, so its energy is part of every residual.
untruncated_sq = torch.sum(coeffs_gt[:, 1] ** 2)
norm_out = torch.zeros(len(k_array))
for i, k in enumerate(k_array):
    norm_out[i] = torch.sqrt(torch.sum(abs_sorted[k:] ** 2) + untruncated_sq)
torch.set_printoptions(precision=12, sci_mode=False)
print(norm_out)

Number of Fourier frequencies in hyperbolic cross: 40888265
Truncation error computed with double precision: 0.00012723915212107113
tensor([    0.032380435616,     0.008383532986,     0.002924559172,
            0.000990046072,     0.000304284971,     0.000057072208,
            0.000003781371,     0.000000441580,     0.000000026648,
            0.000000000000])


In [5]:
# Re-print the k-term truncation errors with 12 digits of precision.
# NOTE(review): sci_mode is not passed here, and the recorded output below is in
# scientific notation, unlike the sci_mode=False output of the In [6] cell.
# NOTE(review): this cell's execution count (In [5]) precedes In [6], and its
# recorded numbers differ from In [6]'s output — likely stale output from an
# earlier run; re-run top to bottom before trusting it.
torch.set_printoptions(precision=12)
print(norm_out)

tensor([3.238038718700e-02, 8.383346721530e-03, 2.924027619883e-03,
        9.884745813906e-04, 2.991751825903e-04, 4.937915946357e-05,
        3.548695531208e-06, 4.143513194776e-07, 1.492849577289e-08,
        0.000000000000e+00])
