In [65]:
temp = 1.0

import numpy as np

def one_hot(x, n):
    z = np.zeros((x.shape[0], n))
    for i in range(x.shape[0]):
        z[i, x[i]] = 1
    return z

data = np.random.random((4, 8))
n_classes = 10
y = one_hot(np.random.randint(0, n_classes, (data.shape[0],)), n_classes)

In [66]:
data, y

(array([[0.11141843, 0.55237528, 0.85151549, 0.71874714, 0.66772003,
         0.52591008, 0.88956325, 0.82083697],
        [0.89945012, 0.59236676, 0.23607209, 0.60670265, 0.1502762 ,
         0.45755856, 0.43141032, 0.50547597],
        [0.01076259, 0.86810899, 0.67119655, 0.46519228, 0.50549821,
         0.11878284, 0.60170483, 0.32097283],
        [0.05121979, 0.5590654 , 0.78117268, 0.64897132, 0.98347347,
         0.27841603, 0.56809717, 0.07723502]]),
 array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]))

In [100]:
samples = 1000

In [101]:
import torch
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical

In [None]:
torch_linears = [
    torch.nn.Linear(data.shape[1], n_classes)
    for _ in range(samples)
]

In [124]:
def get_torch_grad(torch_linear):
    logits = torch_linear(torch.tensor(data, dtype=torch.float32))
    dist = RelaxedOneHotCategorical(temp, logits=logits)
    output = dist.rsample()
    loss = torch.mean((output - torch.tensor(y)) ** 2)
    loss.backward()
    return torch_linear.weight.grad.numpy().max()

grad_samples = [get_torch_grad(torch_linear) for torch_linear in torch_linears]
f'{np.mean(grad_samples):.5f} += {np.std(grad_samples):.5f}'

'0.01981 += 0.00830'

In [125]:
import tensorflow as tf
from tensorflow_probability import distributions as tfd

In [141]:
with tf.device('/CPU:0'):
    tf.random.set_seed(42)
    z = tfd.RelaxedOneHotCategorical(
        temp, logits=tf.convert_to_tensor([0., 1.])).sample()
z

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.25640118, 0.7435988 ], dtype=float32)>

In [153]:
z = RelaxedOneHotCategorical(temp, logits=torch.tensor([0., 1.])).rsample()
z

tensor([0.1661, 0.8339])

In [126]:
def get_tf_grad(torch_linear):
    with tf.device('/CPU:0'):
        tf_linear = tf.keras.layers.Dense(n_classes)

        tf_linear.build((None, data.shape[1]))
        tf_linear.set_weights([
            torch_linear.weight.detach().numpy().T,
            torch_linear.bias.detach().numpy()
        ])

        tf_data = tf.convert_to_tensor(data)
        with tf.GradientTape() as tape:
            tf_logits = tf_linear(tf_data)
            dist = tfd.RelaxedOneHotCategorical(temp, logits=tf_logits)
            output = dist.sample()
            loss = tf.reduce_mean(tf.pow((y - output), 2))

        weight_gradient, _ = tape.gradient(loss, tf_linear.weights)
        return weight_gradient.numpy().max()


grad_samples = [get_tf_grad(torch_linear) for torch_linear in torch_linears]
f'{np.mean(grad_samples):.5f} += {np.std(grad_samples):.5f}'

'0.00683 += 0.00212'

In [127]:
torch_linear = torch_linears[0]
tf_linear = tf.keras.layers.Dense(n_classes)

tf_linear.build((None, data.shape[1]))
tf_linear.set_weights([
    torch_linear.weight.detach().numpy().T,
    torch_linear.bias.detach().numpy()
])

tf_linear(tf.convert_to_tensor(data)), torch_linear(torch.tensor(data, dtype=torch.float32))

(<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
 array([[-0.6142186 ,  0.6297692 ,  0.49465802, -0.40699455,  0.3843618 ,
         -0.01220748, -0.06731844,  0.16490209,  0.7125151 , -0.35208547],
        [-0.41142106,  0.5612997 ,  0.12182345, -0.28776225,  0.64473546,
         -0.56076264,  0.39762443, -0.20173165,  0.27442896, -0.5084722 ],
        [-0.56447244,  0.46228093,  0.44391337, -0.18342935,  0.14141229,
          0.12823752, -0.17296311,  0.02412914,  0.56035423, -0.45750996],
        [-0.6395828 ,  0.34457055,  0.45277998, -0.28159922, -0.03297274,
          0.28652987, -0.20939389,  0.3001512 ,  0.6854204 , -0.19646247]],
       dtype=float32)>,
 tensor([[-0.6142,  0.6298,  0.4947, -0.4070,  0.3844, -0.0122, -0.0673,  0.1649,
           0.7125, -0.3521],
         [-0.4114,  0.5613,  0.1218, -0.2878,  0.6447, -0.5608,  0.3976, -0.2017,
           0.2744, -0.5085],
         [-0.5645,  0.4623,  0.4439, -0.1834,  0.1414,  0.1282, -0.1730,  0.0241,
           0.5604, -0.457