This is a PyTorch module for Radial Basis Function (RBF) interpolation, translated from SciPy's implementation. Because it runs in PyTorch, it benefits from GPU acceleration, making it significantly faster and better suited to large interpolation problems.
pip install torchrbf
The only dependencies are PyTorch and NumPy. If you want to run the tests and benchmarks, you also need SciPy installed.
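If you want the package and the optional test/benchmark dependency in one step, the following also works (SciPy is only needed for the tests and benchmarks, not for normal use):
pip install torchrbf scipy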
If you are using TF32, you may experience numerical precision issues. TF32 is enabled by default in PyTorch versions 1.7 through 1.11 (see the PyTorch documentation on TF32). To disable it, use
torch.backends.cuda.matmul.allow_tf32 = False
torchrbf will issue a warning if TF32 is enabled.
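As a hedged alternative on PyTorch 1.12 and later, the global float32 matmul precision API controls the same behavior; this is just a sketch and not something torchrbf requires:
import torch
# "highest" forces full float32 precision for matmuls instead of TF32,
# which has the same effect as setting allow_tf32 = False
torch.set_float32_matmul_precision("highest")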
Here is a simple example of interpolating 3D data over a 2D domain:
import torch
import matplotlib.pyplot as plt
from torchrbf import RBFInterpolator
y = torch.rand(100, 2) # Data coordinates
d = torch.rand(100, 3) # Data vectors at each point
interpolator = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')
# Query coordinates (100x100 grid of points)
x_grid = torch.linspace(0, 1, 100)
y_grid = torch.linspace(0, 1, 100)
grid_points = torch.meshgrid(x_grid, y_grid, indexing='ij')
grid_points = torch.stack(grid_points, dim=-1).reshape(-1, 2)
# Query RBF on grid points
interp_vals = interpolator(grid_points)
# Plot the interpolated values in 2D
plt.scatter(grid_points[:, 0], grid_points[:, 1], c=interp_vals[:, 0])
plt.title('Interpolated values in 2D')
plt.show()
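Because the query points form a regular 100x100 grid, the interpolated values can also be reshaped and displayed as an image; the snippet below is just a usage sketch building on the example above:
# Reshape the first output channel back onto the 100x100 grid
vals_grid = interp_vals[:, 0].reshape(100, 100)
# Transpose so the first grid axis (x) runs along the image columns
plt.imshow(vals_grid.T, origin='lower', extent=(0, 1, 0, 1))
plt.colorbar()
plt.title('First output channel on the query grid')
plt.show()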
Since the module is implemented in PyTorch, it benefits from GPU acceleration. For larger interpolation problems, torchrbf is significantly faster than SciPy's implementation (over 100x faster on an RTX 3090).
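As a rough, hedged sketch of how you might measure this on your own hardware (it assumes a CUDA device is available and that the interpolator works with tensors already on the GPU; check the package documentation for an explicit device argument, and expect the speedup to vary with problem size):
import time
import torch
from scipy.interpolate import RBFInterpolator as ScipyRBFInterpolator
from torchrbf import RBFInterpolator

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Random 3D scattered data with scalar values (5000 points)
y = torch.rand(5000, 3, device=device)
d = torch.rand(5000, 1, device=device)
x = torch.rand(5000, 3, device=device)

# torchrbf timing (assumes the interpolator follows the device of its inputs)
start = time.perf_counter()
interp = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')
vals = interp(x)
if device == 'cuda':
    torch.cuda.synchronize()
print(f'torchrbf: {time.perf_counter() - start:.3f} s')

# SciPy reference on the CPU with the same data
start = time.perf_counter()
scipy_interp = ScipyRBFInterpolator(y.cpu().numpy(), d.cpu().numpy(),
                                    smoothing=1.0, kernel='thin_plate_spline')
scipy_vals = scipy_interp(x.cpu().numpy())
print(f'scipy: {time.perf_counter() - start:.3f} s')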