In [2]:
import os
import sys
import copy
import numpy as np
np.random.seed(0)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

import jax
import jax.numpy as jnp
from flax import linen as jnn

import einops

In [4]:
# Test grid sample 
x = np.arange(25.).reshape((1, 1, 5, 5))
x

array([[[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.]]]])

In [13]:
# equal to x
src_grid = torch.rand((1, 5, 5, 2), dtype=torch.float)
torch_grid = src_grid * 2 - 1
numpy_grid = src_grid.permute(0,3,1,2).reshape((2, -1)).numpy() * (torch_grid.shape[1] - 1)
print('torch_grid: ',torch_grid[...,0], torch_grid.shape)
print('numpy_grid: ',numpy_grid[0], numpy_grid.shape)

torch_output = F.grid_sample(torch.tensor(copy.deepcopy(x), dtype = torch.float32), torch_grid)
print('torch output: ',torch_output[0,0,...], torch_output.shape)

jax_output = jax.scipy.ndimage.map_coordinates(copy.deepcopy(x[0,0,...]), numpy_grid, order = 1)
print('jax output: ',jax_output, jax_output.shape)

torch_grid:  tensor([[[-0.5048, -0.9436, -0.7713,  0.1125,  0.0106],
         [-0.5177, -0.7973, -0.3634, -0.0250,  0.4466],
         [-0.3332,  0.0376,  0.2303, -0.5626,  0.3011],
         [-0.5748, -0.5264,  0.1371,  0.5615, -0.6268],
         [-0.5502,  0.2950, -0.0785,  0.9615,  0.3548]]]) torch.Size([1, 5, 5, 2])
numpy_grid:  [0.9904866  0.11270928 0.45735264 2.2249892  2.021273   0.96463203
 0.40530038 1.2731702  1.9500761  2.8932679  1.3335369  2.0752418
 2.4606476  0.8748789  2.6021693  0.85030246 0.9471445  2.2742271
 3.1229355  0.74645925 0.8996985  2.5900836  1.8429229  3.9230647
 2.709521  ] (2, 25)
torch output:  tensor([[ 9.6686, 10.2754, 16.9405,  3.4484, 12.3776],
        [ 0.4979, 19.6236, 17.0953,  1.1155,  4.5924],
        [20.6826, 21.4377, 18.9383,  9.0154, 17.8115],
        [20.5534,  8.3198,  8.6805, 19.4595,  0.2202],
        [20.2133,  6.1495, 15.3396,  7.5569,  8.3712]]) torch.Size([1, 1, 5, 5])
jax output:  [ 6.7813168  3.5288527  5.3857765 11.711691  14.0568

