In [5]:
import math
import torch
import numpy as np

# Working precision for every tensor generated below.
dtype = torch.double
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device_id = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
def _hyp_cross(d, M):
    """Enumerate the integer frequencies of the d-dimensional hyperbolic cross.

    Returns an int32 array of shape (K, d). For first coordinate k in [-M, M], the
    remaining d-1 coordinates range over the hyperbolic cross of size
    floor(M / max(1, |k|)). Rows come out in lexicographic order, and the set is
    symmetric under negation, so row i is the negative of row K-1-i.
    """
    if d == 1:
        return np.arange(-M, M + 1, dtype=np.int32).reshape(-1, 1)

    blocks = []
    for k in range(-M, M + 1):
        # Integer floor division: same as int(M / max(1, |k|)) for M >= 0, without float rounding.
        sub = _hyp_cross(d - 1, M // max(1, abs(k)))
        block = np.empty((len(sub), d), dtype=np.int32)
        block[:, 0] = k
        block[:, 1:] = sub
        blocks.append(block)
    return np.vstack(blocks)


def _fun1(x, D):
    """Evaluate the tensor-product test function prod_i max(0, 0.2 - (x_i - 0.5)^2).

    The scale factor (15 / (4*sqrt(3)) * 5^(3/4))^D makes its L2 norm on [0, 1]^D
    equal to 1. x is an (N, >=D) tensor; returns an (N, 1) tensor.
    """
    out = torch.ones_like(x[:, 0])
    for i in range(D):
        out = out * torch.clamp(0.2 - (x[:, i] - 0.5) ** 2, min=0)
    return (15 / (4 * math.sqrt(3)) * 5 ** (3 / 4)) ** D * out[:, None]


def _fun1_fourier_coeffs(f, D):
    """Closed-form Fourier coefficients of `_fun1` at the angular frequencies `f`.

    f is a (K, >=D) tensor whose entries are expected to be -2*pi*k for integer k.
    Returns a length-K tensor with the same dtype/device as f.
    """
    # Analytic limit of the 1-D factor below as f -> 0 (the formula itself is 0/0 there).
    zero_limit = 5 ** (1 / 4) / math.sqrt(3)
    out = torch.ones_like(f[:, 0])
    for i in range(D):
        fi = f[:, i]
        # (-1)^k for f = -2*pi*k: the shift of the bump's centre to x = 0.5.
        sign = (-1) ** torch.round(fi / (2 * math.pi))
        factor = (
            5 ** (5 / 4)
            * math.sqrt(3)
            * sign
            * (math.sqrt(5) * torch.sin(fi / math.sqrt(5)) - fi * torch.cos(fi / math.sqrt(5)))
            / fi ** 3
        )
        factor[fi == 0] = zero_limit
        out = out * factor
    return out


def generate_data(N, M, D, dtype=torch.float, device="cuda"):
    """Build a synthetic Fourier-approximation problem on [0, 1]^D.

    Parameters
    ----------
    N : int
        Number of sample points, drawn uniformly from [0, 1]^D.
    M : int
        Size parameter of the hyperbolic cross of frequencies.
    D : int
        Spatial dimension.
    dtype : torch.dtype
        Precision of the returned tensors. The truncation error is always computed in double.
    device : str or torch.device
        Device of the returned tensors (except `trunc_error`, which lives on the CPU).

    Returns
    -------
    frequencies : (M_f, D) tensor
        -2*pi times the hyperbolic-cross integer frequencies, in lexicographic order.
    frequencies_half : (ceil(M_f / 2), D) tensor
        The first half of `frequencies`, including the zero frequency. For a real
        function c(-f) = conj(c(f)), so this half determines all coefficients.
    samples : (N, D) tensor
        Random sample points.
    values : (N, 1) tensor
        The test function evaluated at `samples`.
    coeffs_gt : (M_f, 2) tensor
        Column 0 holds the exact Fourier coefficients, computed in `dtype`. Column 1
        is left at zero (presumably the imaginary part; these coefficients are real).
    trunc_error : 0-dim double CPU tensor
        sqrt(1 - sum c^2), i.e. the L2 norm of the part of the unit-norm test
        function outside the hyperbolic cross.
    """
    samples = torch.rand(N, D, dtype=dtype, device=device)

    frequencies = (
        -2 * math.pi * torch.from_numpy(_hyp_cross(D, M)).to(dtype=torch.double, device=device)
    )
    M_f = frequencies.size(0)

    # Truncation error in double precision regardless of `dtype`. Clamp so rounding
    # cannot push the radicand slightly below zero and yield NaN.
    coeffs_double = _fun1_fourier_coeffs(frequencies, D)
    trunc_error = torch.sqrt(torch.clamp(1 - torch.sum(coeffs_double ** 2), min=0)).cpu()

    # Cast to the requested precision and recompute the coefficients in it.
    frequencies = frequencies.to(dtype=dtype)
    coeffs_gt = torch.zeros((M_f, 2), dtype=dtype, device=device)
    coeffs_gt[:, 0] = _fun1_fourier_coeffs(frequencies, D)

    # Real target: the coefficients are conjugate-symmetric, so half of the frequencies suffice.
    frequencies_half = frequencies[: (M_f + 1) // 2]

    values = _fun1(samples, D)

    print("Number of Fourier frequencies in Hyperbolic cross:", M_f)
    print("Truncation error computed with double precision:", trunc_error.item())

    return frequencies, frequencies_half, samples, values, coeffs_gt, trunc_error

In [None]:
frequencies, frequencies_half, samples, values, coeffs_gt, trunc_error = generate_data(
    10_000, 1500, 5, device=device_id, dtype=dtype
)

# Sparsity levels k = 1000 * 2^p for p = 0..11 (1000 up to 2,048,000).
k_array = [1000 * 2**p for p in range(12)]


def best_k_term_errors(coeffs, ks):
    """Return the l2 error of the best k-term approximation of `coeffs` for every k in `ks`.

    Each row of `coeffs` is one coefficient (a 1-D input is treated as a single
    column), and the row's Euclidean norm is its magnitude. The best k-term
    approximation keeps the k largest-magnitude rows, so its error is the l2 norm
    of all remaining magnitudes. Sorting once makes this exact even when many
    magnitudes tie; a k at or beyond the number of rows gives an error of 0.

    Returns a 1-D float64 CPU tensor with one entry per k.
    """
    rows = coeffs[:, None] if coeffs.dim() == 1 else coeffs.flatten(start_dim=1)
    magnitudes_desc = torch.sort(rows.norm(dim=1), descending=True).values
    # Stored in double: the errors are printed with 12 significant digits below.
    return torch.tensor(
        [torch.norm(magnitudes_desc[k:]).item() for k in ks], dtype=torch.double
    )


norm_out = best_k_term_errors(coeffs_gt, k_array)
torch.set_printoptions(precision=12, sci_mode=False)
print(norm_out)

Number of Fourier frequencies in Hyperbolic cross: 31601129
Truncation error computed with double precision: 1.2184913765988108e-05
tensor([    0.062187109143,     0.035435069352,     0.020855875686,
            0.011940010823,     0.006612705998,     0.003610250074,
            0.001916919369,     0.001008824562,     0.000525453012,
            0.000268026110,     0.000134955146,     0.000066431465])


In [8]:
# Show the error table again with 12 digits. sci_mode=None (the default) hands the
# choice between fixed and scientific notation back to PyTorch's formatter.
torch.set_printoptions(precision=12, sci_mode=None)
print(norm_out)

tensor([6.218710914254e-02, 3.543506935239e-02, 2.085587568581e-02,
        1.194001082331e-02, 6.612705998123e-03, 3.610250074416e-03,
        1.916919369251e-03, 1.008824561723e-03, 5.254530115053e-04,
        2.680261095520e-04, 1.349551457679e-04, 6.643146480201e-05])
