In [1]:
import os
import sys
import copy
import numpy as np
np.random.seed(0)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

import jax
import jax.numpy as jnp
from flax import linen as jnn

import einops

In [20]:
# Test grid sample 
x = np.arange(75.).reshape((1, 3, 5, 5))
x

array([[[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.]],

        [[25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.]],

        [[50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.],
         [60., 61., 62., 63., 64.],
         [65., 66., 67., 68., 69.],
         [70., 71., 72., 73., 74.]]]])

In [4]:
# equal to x
torch_grid = torch.rand((1, 5, 5, 2), dtype=torch.float) * 2 - 1
numpy_grid = torch_grid.numpy()
print(f'max: {torch.max(torch_grid)} min: {torch.min(torch_grid)}')

max: 0.8556442260742188 min: -0.9678440093994141


In [17]:
def grid_sampler_compute_source_index(coord, size: int, align_corners: bool):
  coord = grid_sampler_unnormalize(coord, size, align_corners)
  print(coord, size)
  return jnp.clip(coord, 0, size - 1)

def grid_sampler_unnormalize(coord, size: int, align_corners: bool):
  if align_corners:
    return (coord + 1. / 2) * (size - 1)
  else:
    return ((coord + 1.) * size - 1) / 2

In [27]:
input_arr = jnp.asarray(x)
_, C, H, W = input_arr.shape
ix = grid_sampler_compute_source_index(numpy_grid[..., 0], W, False)
iy = grid_sampler_compute_source_index(numpy_grid[..., 1], H, False)
print(f'ix max: {jnp.max(ix)} min: {jnp.min(ix)}')
print(f'iy max: {jnp.max(iy)} min: {jnp.min(iy)}')

[[[ 2.855754   -0.27442712  0.515357    0.06548208  2.2507987 ]
  [-0.11631715  4.100039    1.1919976   2.9559836   4.1391106 ]
  [ 3.4129162   3.1455922  -0.03089565  1.6570919   1.5491953 ]
  [-0.34758735  0.91856575  1.525666    1.4007546   2.4929564 ]
  [ 3.1884384   1.2327335   3.9161587   3.9766836   0.34198922]]] 5
[[[ 1.8326066   3.049519    1.4485173   3.3466935   0.35571796]
  [ 1.9808667   0.28369826  1.5969338   0.9364698   2.5415833 ]
  [ 2.7321405  -0.41961002  2.009158    3.7586012   1.001533  ]
  [ 3.216091   -0.32091337  0.554492   -0.0871833   0.4820041 ]
  [ 3.7806883   3.8260922  -0.32138455  1.7887709   0.10089409]]] 5
ix max: 4.0 min: 0.0
iy max: 3.82609224319458 min: 0.0


In [74]:
ix_nw

Array([[[2., 0., 0., 0., 2.],
        [0., 4., 1., 2., 4.],
        [3., 3., 0., 1., 1.],
        [0., 0., 1., 1., 2.],
        [3., 1., 3., 3., 0.]]], dtype=float32)

In [26]:
# get NE, NW, SE, SW pixel values from (x, y)
ix_nw = jnp.floor(ix)
iy_nw = jnp.floor(iy)
ix_ne = ix_nw + 1
iy_ne = jnp.copy(iy_nw)
ix_sw = jnp.copy(ix_nw)
iy_sw = iy_nw +1
ix_se = ix_nw +1
iy_se = iy_nw +1

# get surfaces to each neighbor
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)

In [60]:
i_nw = ix_nw.reshape((1,25)) + iy_nw.reshape((1,25)) * 5
i_ne = ix_ne.reshape((1,25)) + iy_ne.reshape((1,25)) * 5
i_sw = ix_sw.reshape((1,25)) + iy_sw.reshape((1,25)) * 5
i_se = ix_se.reshape((1,25)) + iy_se.reshape((1,25)) * 5

In [58]:
print(ix_nw)
print(iy_nw)

[[[2. 0. 0. 0. 2.]
  [0. 4. 1. 2. 4.]
  [3. 3. 0. 1. 1.]
  [0. 0. 1. 1. 2.]
  [3. 1. 3. 3. 0.]]]
[[[1. 3. 1. 3. 0.]
  [1. 0. 1. 0. 2.]
  [2. 0. 2. 3. 1.]
  [3. 0. 0. 0. 0.]
  [3. 3. 0. 1. 0.]]]


In [61]:
input_arr = input_arr.reshape((1, 3, 25))
nw = jnp.take(input_arr, i_nw.astype(int), axis = -1).reshape((1, 3, 5, 5)) * nw
ne = jnp.take(input_arr, i_ne.astype(int), axis = -1).reshape((1, 3, 5, 5)) * ne
sw = jnp.take(input_arr, i_sw.astype(int), axis = -1).reshape((1, 3, 5, 5)) * sw
se = jnp.take(input_arr, i_se.astype(int), axis = -1).reshape((1, 3, 5, 5)) * se

In [73]:
nw[0,0,...]

Array([[1.1831467e+00, 2.1385820e+02, 6.6818051e+00, 1.3736848e+02,
        1.9307878e+00],
       [4.7833323e-01, 1.1460828e+01, 1.1724422e+01, 1.1185474e-02,
        8.9849670e+01],
       [2.6576254e+01, 7.6896701e+00, 9.9084206e+01, 2.1191071e+01,
        1.6204090e+01],
       [1.7637955e+02, 0.0000000e+00, 2.1131960e-01, 5.9924543e-01,
        1.0505860e+00],
       [5.7667130e+01, 3.4159000e+01, 7.5457191e-01, 3.1520629e-01,
        0.0000000e+00]], dtype=float32)

In [69]:
jax_output = nw + ne + sw + se
jax_output[0,0,...]

Array([[ 13.032912  , 214.84859   ,  13.103388  , 145.00954   ,
          4.9947824 ],
       [ 10.287     ,  14.014112  ,  18.94702   ,   7.643925  ,
        100.139755  ],
       [ 41.605545  ,   8.272039  ,  99.22157   ,  40.316727  ,
         20.06027   ],
       [180.70137   ,   0.91856575,   4.298126  ,   1.4007546 ,
          5.42827   ],
       [ 76.55528   ,  52.38726   ,   4.4192066 ,  13.196343  ,
          0.84645975]], dtype=float32)

In [70]:
torch_output = F.grid_sample(torch.tensor(x), torch_grid.to(float), align_corners = False)
torch_output[0, 0, ...]

tensor([[12.0188, 11.0632,  7.7579, 16.7990,  4.0294],
        [ 8.7523,  4.8764,  9.1767,  7.6383, 14.3837],
        [17.0736,  1.8257,  9.7354, 20.4501,  6.5569],
        [10.4911,  0.6238,  4.2981,  1.2786,  4.9030],
        [22.0919, 20.3632,  2.6576, 12.9205,  0.8465]], dtype=torch.float64)