In [1]:
import numpy as np
import torch
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax import lax
from flax.nn.linear import _conv_dimension_numbers

In [73]:
# Compare a strided convolution between PyTorch (NCHW layout) and
# TensorFlow/Keras (NHWC layout) using the same input and the same weights.

# Seed both sources of randomness (the NumPy input and the torch Conv2d
# weight initialisation) so the printed values survive Restart & Run All.
np.random.seed(0)
torch.manual_seed(0)

s = (16, 16)  # strides
k = 16        # square kernel size
x = np.random.randn(1, 3, 384, 384).astype(np.float32)

th_x = torch.from_numpy(x)
tf_x = x.transpose(0, 2, 3, 1)  # NCHW -> NHWC; x is already float32

th_conv = torch.nn.Conv2d(3, 1, k, s, 0, bias=False)
tf_conv = tf.keras.layers.Convolution2D(1, k, s, 'valid', use_bias=False)
# Call the layer once and discard the result: the layer has to be built
# before set_weights() can be given a kernel of the right shape.
tf_conv(tf_x)
# Reorder the torch kernel axes (dims 2, 3, 1, 0) to the layout Keras stores.
tf_conv.set_weights([th_conv.weight.data.permute(2, 3, 1, 0).detach().numpy()])

th_y = th_conv(th_x).detach().numpy()
tf_y = tf_conv(tf_x).numpy()
tf_y = tf_y.transpose(0, 3, 1, 2)  # NHWC -> NCHW so both outputs line up

print(th_y.shape, th_y.ravel()[:10])
print(tf_y.shape, tf_y.ravel()[:10])
print('Error Rate:', np.mean(np.abs(th_y - tf_y) > 1e-6))

# The two backends accumulate the float32 products in a different order, so
# element-wise deviations around 1e-6 are expected; atol=1e-8 is too strict
# for that. NOTE(review): tolerances chosen for CPU float32 -- loosen them if
# a backend runs the convolution in reduced precision (e.g. TF32 on GPU).
np.testing.assert_allclose(th_y, tf_y, rtol=1e-4, atol=1e-5)

(1, 1, 24, 24) [-0.2502736   0.34197763  0.81398004  0.8719954  -1.2945161   0.07460865
  0.09205636 -0.24566422 -0.673575    0.24521644]
(1, 1, 24, 24) [-0.2502736   0.3419779   0.81398     0.87199515 -1.2945158   0.07460855
  0.09205629 -0.2456645  -0.673575    0.24521616]
Error Rate: 0.006944444444444444


In [74]:
# Run the same convolution with JAX, keeping PyTorch's native layouts
# (the lax.conv_general_dilated defaults are NCHW for the image and OIHW
# for the kernel), and compare against the torch result `th_y`.
jax_x = jnp.asarray(th_x.numpy())
jax_k = jnp.asarray(th_conv.weight.data.detach().numpy())

jax_y = lax.conv_general_dilated(
    jax_x,    # lhs = NCHW image tensor
    jax_k,    # rhs = OIHW conv kernel tensor
    s,        # window strides
    'VALID',  # padding mode
)
# Convert to a NumPy array through the public API instead of reading the
# private `_value` attribute, which is an implementation detail of JAX.
jax_y = np.asarray(jax_y)

print(th_y.shape, th_y.ravel()[:10])
print(jax_y.shape, jax_y.ravel()[:10])
print('Error Rate:', np.mean(np.abs(th_y - jax_y) > 1e-6))

(1, 1, 24, 24) [-0.2502736   0.34197763  0.81398004  0.8719954  -1.2945161   0.07460865
  0.09205636 -0.24566422 -0.673575    0.24521644]
(1, 1, 24, 24) [-0.2502736   0.3419779   0.81398     0.87199515 -1.2945158   0.07460855
  0.09205629 -0.2456645  -0.673575    0.24521616]
Error Rate: 0.006944444444444444


In [79]:
# Element-wise difference between the JAX and PyTorch outputs.
np.subtract(jax_y, th_y)

array([[[[ 0.0000000e+00,  2.6822090e-07, -5.9604645e-08,
          -2.3841858e-07,  2.3841858e-07, -9.6857548e-08,
          -7.4505806e-08, -2.8312206e-07,  0.0000000e+00,
          -2.8312206e-07,  3.5762787e-07,  2.3841858e-07,
           4.7683716e-07, -2.9802322e-08,  1.7881393e-07,
           1.1920929e-07, -2.3841858e-07, -1.0943040e-08,
          -1.7881393e-07,  2.3841858e-07,  6.0535967e-09,
           3.1292439e-07,  5.9604645e-08,  0.0000000e+00],
         [ 2.3841858e-07,  2.9802322e-08,  7.4505806e-08,
          -1.8253922e-07, -5.9604645e-08, -2.9802322e-08,
          -3.2782555e-07,  1.7881393e-07, -1.4901161e-08,
           7.4505806e-08, -1.0430813e-07,  1.3411045e-07,
          -3.5762787e-07,  2.9802322e-08, -1.4901161e-07,
           2.9802322e-08, -3.5762787e-07,  1.2665987e-07,
           0.0000000e+00, -2.9802322e-08,  0.0000000e+00,
           8.9406967e-08, -4.8428774e-08, -7.4505806e-08],
         [ 0.0000000e+00,  3.5762787e-07, -2.7567148e-07,
           5

In [80]:
# Run the convolution with JAX in channels-last layout (NHWC image, HWIO
# kernel) and compare against the torch result `th_y`.
jax_x = jnp.asarray(th_x.numpy().transpose(0, 2, 3, 1))                          # NCHW -> NHWC
jax_k = jnp.asarray(th_conv.weight.data.detach().numpy().transpose(2, 3, 1, 0))  # OIHW -> HWIO

# Build the layout spec with the public lax helper rather than the private
# flax.nn.linear._conv_dimension_numbers; both describe NHWC / HWIO / NHWC.
dimension_numbers = lax.conv_dimension_numbers(
    jax_x.shape, jax_k.shape, ('NHWC', 'HWIO', 'NHWC'))
jax_y = lax.conv_general_dilated(
    jax_x,    # lhs = NHWC image tensor
    jax_k,    # rhs = HWIO conv kernel tensor
    s,        # window strides
    'VALID',  # padding mode
    dimension_numbers=dimension_numbers,
)
# Public conversion to NumPy (no private `_value`), then NHWC -> NCHW so the
# result lines up with the torch output.
jax_y = np.asarray(jax_y).transpose(0, 3, 1, 2)

print(th_y.shape, th_y.ravel()[:10])
print(jax_y.shape, jax_y.ravel()[:10])
print('Error Rate:', np.mean(np.abs(th_y - jax_y) > 1e-6))

(1, 1, 24, 24) [-0.2502736   0.34197763  0.81398004  0.8719954  -1.2945161   0.07460865
  0.09205636 -0.24566422 -0.673575    0.24521644]
(1, 1, 24, 24) [-0.2502736   0.3419779   0.81398     0.87199515 -1.2945158   0.07460855
  0.09205629 -0.2456645  -0.673575    0.24521616]
Error Rate: 0.006944444444444444
