In [1]:
import torch
import torch.nn.functional as F
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# PyTorch example
input_tensor = torch.rand(1, 3, 4, 4)  # shape: (N, C, H, W)
grid = torch.rand(1, 4, 4, 2) * 2 - 1  # shape: (N, H_out, W_out, 2), grid coordinates in range [-1, 1]

output_pytorch = F.grid_sample(input_tensor, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

# JAX equivalent
input_tensor_jax = jnp.array(input_tensor.numpy())
grid_jax = jnp.array(grid.numpy())

# Rescale grid coordinates from [-1, 1] to [0, H] and [0, W]
N, H_out, W_out, _ = grid_jax.shape
_, C, H, W = input_tensor_jax.shape

grid_jax = (grid_jax + 1) * jnp.array([H / 2, W / 2]) - 0.5

# Prepare indices for map_coordinates
indices = grid_jax.transpose(3, 0, 1, 2).reshape(2, -1)

# Perform interpolation using map_coordinates
output_jax = jnp.stack([
    map_coordinates(input_tensor_jax[0, c], indices, order=1, mode='constant', cval=0).reshape(H_out, W_out)
    for c in range(C)
], axis=0).reshape(1, C, H_out, W_out)

print("PyTorch output:\n", output_pytorch)
print("JAX output:\n", output_jax)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


PyTorch output:
 tensor([[[[0.6928, 0.6761, 0.7940, 0.9091],
          [0.7974, 0.6731, 0.7090, 0.5539],
          [0.6608, 0.6316, 0.7225, 0.7137],
          [0.8180, 0.7477, 0.7154, 0.6863]],

         [[0.4126, 0.4962, 0.3670, 0.2200],
          [0.3902, 0.5720, 0.5016, 0.4495],
          [0.6357, 0.4255, 0.3580, 0.3559],
          [0.4487, 0.1934, 0.2901, 0.4173]],

         [[0.6130, 0.3389, 0.4732, 0.4540],
          [0.5872, 0.4066, 0.6574, 0.4376],
          [0.4930, 0.4427, 0.1858, 0.8283],
          [0.8252, 0.1546, 0.2341, 0.3013]]]])
JAX output:
 [[[[0.8020759  0.62820303 0.9284525  0.28567412]
   [0.78631866 0.5464505  0.64400995 0.49808082]
   [0.43834862 0.5277994  0.42378038 0.43630135]
   [0.59703594 0.55864716 0.7057093  0.657401  ]]

  [[0.424793   0.11666153 0.35884365 0.34929678]
   [0.38629887 0.17555654 0.71630436 0.3791138 ]
   [0.30152482 0.47732118 0.05740294 0.5406113 ]
   [0.47750637 0.0530642  0.08097727 0.10591275]]

  [[0.7333186  0.1623245  0.5286766  0.