In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from functools import partial
import logging

# Global plotting theme: white grid, tab10 colour cycle, enlarged fonts for paper figures.
sns.set(
    style='whitegrid',
    palette=sns.color_palette('tab10'),
    font_scale=2.,
)

In [None]:
def get_coeffs(kernel_fn, order=1, half_width=30., n_samples=10**4, bounds=(.1, 9), eps=1e-4):
    """ Choose a sample spacing and discrete filter taps for a stationary, isotropic kernel.

        The spacing ``s`` is found by binary search so that the fraction of the kernel's
        mass inside the filter footprint |x| <= s*(2*order+1)/2 equals the fraction of its
        spectral mass inside the Nyquist band |w| <= pi/s (see ``coverage_diff``).

        Args:
            kernel_fn: callable mapping a 1-D float torch tensor of signed offsets to a
                tensor of kernel values of the same shape.
            order: filter half-length; the filter has 2*order+1 taps.
            half_width: the kernel is densely sampled on [-half_width, half_width].
            n_samples: number of dense sample points (also the FFT length).
            bounds: (lower, upper) search bracket for the spacing ``s``.
            eps: precision of the search on ``s``.

        Returns:
            coeffs: tap values normalised so the centre tap is 1 (torch tensor).
            offsets: tap positions s * [-order, ..., order] (torch tensor).
            x, fn_values: dense sample grid and kernel values on it (numpy arrays).
            w, fft_values: angular frequencies in np.fft order (not sorted) and |FFT|.
            s: the selected sample spacing.
    """
    # linspace includes both endpoints, so the true step is 2*half_width/(n_samples-1);
    # use it for the frequency axis so w matches the actual sampling of fn_values.
    x, dx = np.linspace(-half_width, half_width, n_samples, retstep=True)
    fn_values = kernel_fn(torch.from_numpy(x).float()).detach().cpu().numpy()
    w = 2 * np.pi * np.fft.fftfreq(n_samples, dx)
    # The constant normalisation cancels in the coverage ratios; it only sets the plotted scale.
    fft_values = np.absolute(np.fft.fft(fn_values) / (2 * np.pi * np.sqrt(n_samples)))

    obj_fn = partial(coverage_diff, order=order, x=x, w=w, fn_values=fn_values, fft_values=fft_values)
    # For non-negative kernels the objective increases with the spacing, so bisect on its zero.
    s = binary_search(0, bounds, obj_fn, eps)
    offsets = s * torch.arange(-order, order + 1).float()
    vals = kernel_fn(offsets)
    return vals / vals[order], offsets, x, fn_values, w, fft_values, s

def coverage_diff(spacing, order, x, w, fn_values, fft_values):
    """ Spatial coverage minus spectral coverage for 2*order+1 taps spaced `spacing` apart.

        Spatial coverage is the share of fn_values.sum() on grid points with
        |x| <= spacing*(2*order+1)/2; spectral coverage is the share of fft_values.sum()
        on frequencies with |w| <= pi/spacing. Grid steps cancel in both ratios. """
    half_extent = spacing * (2 * order + 1) / 2
    band_edge = np.pi / spacing

    in_footprint = (x >= -half_extent) & (x <= half_extent)
    in_band = (w >= -band_edge) & (w <= band_edge)

    spatial = fn_values[in_footprint].sum() / fn_values.sum()
    spectral = fft_values[in_band].sum() / fft_values.sum()
    return spatial - spectral

def binary_search(target, bounds, fn, eps=1e-2, max_iter=500):
    """ Bisection search for the input at which a non-decreasing fn reaches target.

        Args:
            target: output value to locate.
            bounds: (lower, upper) initial bracket on the input.
            fn: callable assumed non-decreasing on the bracket; guesses with
                fn(guess) < target raise the lower bound, all others lower the upper bound.
            eps: stop once the bracket is at most this wide.
            max_iter: maximum number of fn evaluations before giving up.

        Returns:
            Midpoint of the final bracket.

        Raises:
            AssertionError: if fn returns a value that is neither < nor >= target
                (e.g. NaN), or if the bracket is still wider than eps after max_iter
                evaluations (e.g. eps below floating-point resolution). Raised
                explicitly rather than via `assert` so the check survives `python -O`.
    """
    lb, ub = bounds
    n_evals = 0
    while ub - lb > eps:
        if n_evals >= max_iter:
            raise AssertionError(
                f"binary_search did not reach eps={eps} within {max_iter} evaluations "
                f"(bracket [{lb}, {ub}])")
        guess = (ub + lb) / 2
        y = fn(guess)
        n_evals += 1
        if y < target:
            lb = guess
        elif y >= target:
            ub = guess
        else:
            # Unordered value (NaN): the bracket could never move, so fail immediately.
            raise AssertionError(
                f"binary_search: fn({guess}) returned {y!r}, which cannot be compared with target")
    return (ub + lb) / 2

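## Kernels

Both kernels take *squared* distances `d2` (hence the call `rbf(d**2)` below) and return torch tensors.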

In [None]:
def matern(d2, nu=.5):
    """Matern kernel evaluated on squared distances.

    Args:
        d2: torch tensor of squared distances. abs() is applied before sqrt, so
            small negative values (e.g. round-off) do not yield NaNs.
        nu: smoothness; only the closed forms for 0.5, 1.5 and 2.5 are implemented.

    Returns:
        Tensor shaped like ``d2``; equals 1 at distance 0.

    Raises:
        NotImplementedError: for any other ``nu``.
    """
    # Validate before doing any tensor work.
    if nu not in (0.5, 1.5, 2.5):
        raise NotImplementedError(f"matern: nu={nu!r} is not supported (use 0.5, 1.5 or 2.5)")
    # NOTE: sqrt is not differentiable at 0, so gradients w.r.t. d2 at zero distance
    # may be non-finite; adding a small constant inside the sqrt avoids this if needed.
    d = d2.abs().sqrt()
    exp_component = torch.exp(-np.sqrt(nu * 2) * d)
    if nu == 0.5:
        constant_component = 1
    elif nu == 1.5:
        constant_component = (np.sqrt(3) * d).add(1)
    else:  # nu == 2.5
        constant_component = (np.sqrt(5) * d).add(1).add(5.0 / 3.0 * d ** 2)
    return constant_component * exp_component

def rbf(d2):
    """Squared-exponential (RBF) kernel on squared distances: returns exp(-d2)."""
    exponent = -d2
    return torch.exp(exponent)

## Coefficient search and coverage figure

In [None]:
order = 1
coeffs, coeffs_m, x, fn_values, w, fft_values, s = get_coeffs(lambda d: rbf(d**2), order)
# np.fft.fftfreq returns frequencies in FFT order (0, positives, negatives);
# sort them so the spectrum plots as one continuous curve.
asort_w = np.argsort(w)
w = w[asort_w]
fft_values = fft_values[asort_w]

k = 2 * order + 1   # number of filter taps
a = s * k / 2       # half-width of the spatial footprint covered by the taps
ny_w = np.pi / s    # Nyquist angular frequency for sample spacing s

# Covered regions, computed once and reused for ticks, shading and boundary markers.
in_space = (-a <= x) & (x <= a)
in_freq = (-ny_w <= w) & (w <= ny_w)
x_lo, x_hi = x[in_space][0], x[in_space][-1]
w_lo, w_hi = w[in_freq][0], w[in_freq][-1]


def _draw_window(ax, grid, values, inside, lo, hi, view):
    """Plot values over |grid| <= view, shade the `inside` region, and mark lo/hi with dashed lines."""
    shown = (-view <= grid) & (grid <= view)
    sns.lineplot(x=grid[shown], y=values[shown], ax=ax, color='gray')
    ax.fill_between(x=grid, y1=values, where=inside, color='gray', alpha=0.1)
    marker_top = np.max(values) / 2
    for edge in (lo, hi):
        ax.plot([edge, edge], [0.0, marker_top], '--', c='black', linewidth=2.0)


fig, axes = plt.subplots(figsize=(10, 10), ncols=1, nrows=2)

axes[0].set_xlabel(r'$\tau$')
axes[0].set_ylabel(r'$k(\tau)$')
axes[0].set_xticks([x_lo, x_hi])
axes[0].set_xticklabels([r'$-\frac{sm}{2}$', r'$+\frac{sm}{2}$'])
axes[0].set_yticks([0.0])

axes[1].set_xlabel(r'$\omega$')
axes[1].set_ylabel(r'$\mathcal{F}[k](\omega)$')
axes[1].set_xticks([w_lo, w_hi])
axes[1].set_xticklabels([r'$-\frac{\pi}{s}$', r'$+\frac{\pi}{s}$'])
axes[1].set_yticks([0.0])

# Spatial panel: kernel, covered footprint, and the normalised filter taps.
_draw_window(axes[0], x, fn_values, in_space, x_lo, x_hi, 2. * a)
# NOTE(review): Axes.arrow takes a displacement (dx, dy), not an end point. Passing the
# boundary coordinate as dx makes these arrows point outward with length ~|x_lo|; the
# "- .001" hints that an end point may have been intended. Geometry kept unchanged
# to preserve the existing figure -- confirm intent.
arrow_y = np.max(fn_values) / 4
axes[0].arrow(x_lo - 0.2, arrow_y, x_lo - .001, 0.0, head_width=0.05, head_length=0.1, color='black')
axes[0].arrow(x_hi + 0.2, arrow_y, x_hi + .001, 0.0, head_width=0.05, head_length=0.1, color='black')
# `use_line_collection` was deprecated in Matplotlib 3.6 and later removed
# (a LineCollection is always used), so it is no longer passed.
axes[0].stem(coeffs_m.numpy(), coeffs.numpy(), basefmt=None)

# Spectral panel: |FFT| of the kernel and the Nyquist band [-pi/s, pi/s].
_draw_window(axes[1], w, fft_values, in_freq, w_lo, w_hi, 2. * ny_w)
arrow_y = np.max(fft_values) / 4
axes[1].arrow(w_lo - 2.5, arrow_y, 2., 0.0, head_width=0.025, head_length=0.25, color='black')
axes[1].arrow(w_hi + 2.5, arrow_y, -2., 0.0, head_width=0.025, head_length=0.25, color='black')

fig.tight_layout()
fig.savefig('coverage.pdf', bbox_inches='tight')