In [1]:
import torch
from stochman import nnj
import numpy as np
from jax import numpy as jnp
from jax import vjp
from jax import lax # for convolutions

In [2]:
_batch_size = 1
_input_size = 2
_hidden_size = 48
_hidden_c_in, _hidden_c_out = 3, 3
_hidden_h, _hidden_w = 4, 4
assert _hidden_size == _hidden_c_in * _hidden_h * _hidden_w
_kernel_size, _padding = 3,1
assert _kernel_size == 2*_padding + 1
_output_size = 2


x = torch.randn(_batch_size, _input_size)

In [3]:
# Model: Linear -> Tanh -> residual(Reshape -> Conv2d -> Flatten) -> Tanh -> Linear.
# add_hooks=True is presumably what populates .feature_maps used in later cells — TODO confirm.
A_layers = [
    nnj.Linear(_input_size, _hidden_size),
    nnj.Tanh(),
]
B_layers = [
    nnj.Reshape(_hidden_c_in, _hidden_h, _hidden_w),
    nnj.Conv2d(
        _hidden_c_in,
        _hidden_c_out,
        _kernel_size,
        stride=1,
        padding=_padding,
    ),
    nnj.Flatten(),
]
C_layers = [
    nnj.Tanh(),
    nnj.Linear(_hidden_size, _output_size),
    # nnj.L2Norm() is left out; the JAX model below omits it as well.
]

residual_block = nnj.ResidualBlock(*B_layers, add_hooks=True)
stoch_model = nnj.Sequential(*A_layers, residual_block, *C_layers, add_hooks=True)

stoch_model

Sequential(
  (0): Linear(in_features=2, out_features=48, bias=True)
  (1): Tanh()
  (2): ResidualBlock(
    (_F): Sequential(
      (0): Reshape()
      (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): Flatten()
    )
  )
  (3): Tanh()
  (4): Linear(in_features=48, out_features=2, bias=True)
)

In [4]:
### test single input ###
y = stoch_model(x)

print('inner\n',[fm.shape for fm in stoch_model.feature_maps])
print('outer\n',[fm.shape for fm in stoch_model._modules_list[2]._F.feature_maps])

inner
 [torch.Size([1, 2]), torch.Size([1, 48]), torch.Size([1, 48]), torch.Size([1, 48]), torch.Size([1, 48]), torch.Size([1, 2])]
outer
 [torch.Size([1, 48]), torch.Size([1, 3, 4, 4]), torch.Size([1, 3, 4, 4]), torch.Size([1, 48])]


In [5]:
# Smoke test: propagate an all-ones matrix, given in diagonal form (from_diag=True), w.r.t. the input.
matrix = torch.ones(_batch_size, _output_size)
GGN = stoch_model._jTmjp(
    x,
    None,
    matrix,
    wrt='input',
    from_diag=True,
    to_diag=True,
)

reforwarding 6
non reforwarding 4


In [6]:
# Exercise every (from_diag, wrt, to_diag) combination of _jTmjp once.
# Only that each call succeeds is checked here; the results are discarded.
for from_diag in (True, False):
    # A diagonal input is (batch, out); a full input is (batch, out, out).
    if from_diag:
        matrix_shape = (_batch_size, _output_size)
    else:
        matrix_shape = (_batch_size, _output_size, _output_size)
    matrix = torch.ones(*matrix_shape)
    for wrt in ('input', 'weight'):
        for to_diag in (True, False):
            GGN = stoch_model._jTmjp(
                x, None, matrix, wrt=wrt, from_diag=from_diag, to_diag=to_diag
            )

reforwarding 6
non reforwarding 4
reforwarding 6
non reforwarding 4
reforwarding 6
non reforwarding 4
non reforwarding 4
reforwarding 6
non reforwarding 4
non reforwarding 4
reforwarding 6
non reforwarding 4
reforwarding 6
non reforwarding 4
reforwarding 6
non reforwarding 4
non reforwarding 4
reforwarding 6
non reforwarding 4
non reforwarding 4


### Get the weights

In [7]:
first_linear_layer_index = 0
resnet_index, conv_layer_index_in_resnet = 2, 1
second_linear_layer_index = 4

# Resolve each parametrised layer once instead of re-indexing it on every use.
first_linear_layer = stoch_model._modules_list[first_linear_layer_index]
conv_layer = stoch_model._modules_list[resnet_index]._F._modules_list[conv_layer_index_in_resnet]
second_linear_layer = stoch_model._modules_list[second_linear_layer_index]
# (printed layer name, printed weight symbol, layer) in the order the flat vector is laid out.
param_layers = [
    ('linear 1', 'W', first_linear_layer),
    ('conv', 'K', conv_layer),
    ('linear 2', 'W', second_linear_layer),
]

print('Weights in stoch model:')
for layer_name, weight_symbol, layer in param_layers:
    print(f'\t{layer_name} - {weight_symbol}\n\t', layer.weight.detach().numpy().shape)
    print(f'\t{layer_name} - b\n\t', layer.bias.detach().numpy().shape)


def _flatten_layer_params(layer):
    """Return a layer's weight (raveled row-major) followed by its bias as one 1-D array.

    Row-major order is what jax_model's reshapes assume: the (out, in) linear
    matrix row by row, and the conv kernel in (c_out, c_in, kh, kw) order.
    """
    return np.concatenate([layer.weight.detach().numpy().ravel(),
                           layer.bias.detach().numpy().ravel()])


# One concatenation at the end keeps this linear in the number of parameters.
param_chunks = [_flatten_layer_params(layer) for _, _, layer in param_layers]
first_linear_num_weights, conv_num_weights, second_linear_num_weights = (
    len(chunk) for chunk in param_chunks
)

weights = jnp.array(np.concatenate(param_chunks))
print('n weights per layer', first_linear_num_weights, conv_num_weights, second_linear_num_weights)
print('n weights total', len(weights))

Weights in stoch model:
	linear 1 - W
	 (48, 2)
	linear 1 - b
	 (48,)
	conv - K
	 (3, 3, 3, 3)
	conv - b
	 (3,)




	linear 2 - W
	 (2, 48)
	linear 2 - b
	 (2,)
n weights per layer 144 84 98
n weights total 326


### Define an equivalent neural network in JAX

In [8]:
def Tanh(x):
    """Element-wise hyperbolic tangent, the JAX counterpart of nnj.Tanh in this model."""
    activated = jnp.tanh(x)
    return activated

def L2Norm(x):
    """Scale each row of x to (approximately) unit L2 norm.

    Intended for a 2-D (batch, features) array: the norm is taken along axis 1,
    and 1e-6 is added to it to avoid division by zero. Currently unused: the
    layer is commented out in the JAX model's layer list.
    """
    row_norms = jnp.linalg.norm(x, ord=2, axis=1)
    return (x.T / (row_norms + 1e-6)).T

def LinearLayer(w, b, x, print_weights=False):
    """Affine layer x @ W.T + b, with W unpacked from the flat vector w.

    w holds the weight matrix row by row, so it is reshaped to (len(b), -1),
    i.e. (out_features, in_features) like a torch Linear weight.
    When print_weights is set, the reshaped W and the bias are printed.
    """
    weight_matrix = w.reshape(len(b), -1)
    result = jnp.dot(x, weight_matrix.T) + b
    if print_weights:
        print('\twi\n\t', weight_matrix)
        print('\tbi\n\t', b)
    return result
    
def ConvLayer(k, b, x, print_weights=False):
    """2-D convolution (stride 1, 'SAME' padding) followed by a per-channel bias.

    Args:
        k: flat kernel vector laid out row-major as (c_out, c_in, kh, kw), kh == kw.
        b: bias vector of length c_out.
        x: input of shape (batch, c_in, h, w) (lax.conv's default NCHW layout).
        print_weights: if True, print the reshaped kernel and the bias.

    Returns:
        Array of shape (batch, c_out, h, w).

    Raises:
        ValueError: if k's length is not c_out * c_in * s * s for an integer s.
    """
    import math

    # Derive the kernel layout from the arguments instead of notebook globals,
    # so the layer works for any channel count and square kernel size.
    c_out = b.shape[0]
    c_in = x.shape[1]
    kernel_size = math.isqrt(k.size // (c_out * c_in))
    if c_out * c_in * kernel_size * kernel_size != k.size:
        raise ValueError(
            f"kernel of size {k.size} does not fit c_out={c_out}, c_in={c_in} "
            "with a square spatial kernel"
        )
    k = k.reshape(c_out, c_in, kernel_size, kernel_size)
    out = lax.conv(x, k, (1, 1), 'SAME')
    # Broadcast the bias over the batch and spatial dimensions.
    out = out + b.reshape(1, c_out, 1, 1)
    if print_weights:
        print('\tki\n\t', k)
        print('\tbi\n\t', b)
    return out

def Flatten(x):
    """Flatten every non-batch dimension: (batch, ...) -> (batch, prod(...)).

    The batch size is read from x itself rather than from the global
    _batch_size, so the JAX model also works for other batch sizes.
    """
    return x.reshape(x.shape[0], -1)

def Reshape(x):
    """Unflatten (batch, C*H*W) hidden vectors into (batch, C, H, W) feature maps.

    The batch size is read from x (not from the global _batch_size); the
    channel and spatial sizes come from _hidden_c_in, _hidden_h and _hidden_w.
    """
    return x.reshape(x.shape[0], _hidden_c_in, _hidden_h, _hidden_w)

def ResNet(x, *layers):
    """Residual connection: run x through `layers` in order, then add x back.

    With no layers this returns x + x.
    """
    branch = x
    for f in layers:
        branch = f(branch)
    return branch + x

def jax_model(weights, x, print_weights=False, return_feature_maps=False):
    """JAX re-implementation of stoch_model, parametrised by one flat weight vector.

    `weights` is laid out as [linear 1 (W, b) | conv (K, b) | linear 2 (W, b)].
    The first and last segment lengths come from the globals
    first_linear_num_weights and second_linear_num_weights; the conv takes the rest.

    Returns the model output, or (output, feature_maps) when return_feature_maps
    is set. feature_maps is [x] followed by the output of every top-level layer.
    """
    # --- first linear layer ---
    lin1 = weights[:first_linear_num_weights]
    n_w1 = _input_size * _hidden_size
    w1, b1 = lin1[:n_w1], lin1[n_w1:]

    # --- convolution inside the residual block ---
    conv_params = weights[first_linear_num_weights:-second_linear_num_weights]
    conv_k, conv_b = conv_params[:-_hidden_c_out], conv_params[-_hidden_c_out:]
    assert len(conv_k) == _hidden_c_in * _hidden_c_out * _kernel_size * _kernel_size and len(conv_b) == _hidden_c_out

    # --- second linear layer ---
    lin2 = weights[-second_linear_num_weights:]
    n_w2 = _hidden_c_out * _hidden_h * _hidden_w * _output_size
    w2, b2 = lin2[:n_w2], lin2[n_w2:]

    def LL1(h):
        return LinearLayer(w1, b1, h, print_weights=print_weights)

    def CL(h):
        return ConvLayer(conv_k, conv_b, h, print_weights=print_weights)

    def LL2(h):
        return LinearLayer(w2, b2, h, print_weights=print_weights)

    def RN(h):
        return ResNet(h, Reshape, CL, Flatten)

    feature_maps = [x]
    # L2Norm is left out, mirroring the commented-out nnj.L2Norm() in stoch_model.
    for layer in (LL1, Tanh, RN, Tanh, LL2):
        x = layer(x)
        feature_maps.append(x)

    return (x, feature_maps) if return_feature_maps else x


Check that the stochman and JAX models produce the same feature maps (every layer's output) for random inputs

In [9]:
random_inputs_amount = 100
feature_map_tol = 1e-5

for _input_trial in range(random_inputs_amount):
    x = torch.randn(_batch_size, _input_size)
    jax_x = jnp.array(x.numpy())

    y = stoch_model(x)
    jax_y, fm = jax_model(weights, jax_x, return_feature_maps=True)

    # zip() would silently drop unmatched feature maps, so check the counts first.
    assert len(stoch_model.feature_maps) == len(fm), (len(stoch_model.feature_maps), len(fm))
    # Max abs deviation per feature map between stochman and JAX.
    errors = [np.max(np.abs(fs.detach().numpy() - np.array(fj)))
              for fs, fj in zip(stoch_model.feature_maps, fm)]
    # Check each map individually: Python's max() over floats ignores a NaN that
    # is not the first element, which would let a NaN feature map pass unnoticed.
    assert all(err < feature_map_tol for err in errors), errors

# Check correctness wrt weights

Check that vector-Jacobian products w.r.t. the weights agree for random inputs and random vectors

In [10]:
random_inputs_amount = 10
random_vectors_amount = 10
weight_vjp_tol = 1e-4

# Compare vector-Jacobian products w.r.t. the weights: JAX autodiff vs stochman's _vjp.
for _input_trial in range(random_inputs_amount):
    x = torch.randn(_batch_size, _input_size)
    jax_x = jnp.array(x.numpy())
    # JAX pullback of the model w.r.t. the flat weight vector, at this input.
    jax_y_by_vjp, vjp_fun = vjp(lambda w: jax_model(w, jax_x), weights)

    for _vector_trial in range(random_vectors_amount):
        vector = torch.randn(_batch_size, _output_size)
        jax_vj = vjp_fun(jnp.array(vector.numpy()))
        stoch_vj = stoch_model._vjp(x, None, vector, wrt='weight').detach().numpy()
        max_abs_err = np.max(abs(np.array(jax_vj) - stoch_vj))
        assert max_abs_err < weight_vjp_tol


### Fix one input x and compute the Jacobian

In [11]:
x = torch.randn(_batch_size, _input_size)
jax_x = jnp.array(x.numpy())
jax_y_by_vjp, vjp_fun = vjp(lambda w: jax_model(w, jax_x), weights)

# One-hot cotangents: e_i is 1 in output column i for every batch row, 0 elsewhere.
identity = [
    jnp.zeros((_batch_size, _output_size)).at[:, i].set(1)
    for i in range(_output_size)
]

# Row i of the Jacobian w.r.t. the weights is the VJP with e_i.
# NOTE(review): the VJP accumulates over batch rows, so this equals a per-sample
# Jacobian only while _batch_size == 1.
J_by_jax = np.array([vjp_fun(e_i)[0] for e_i in identity])
print('Jacobian shape', J_by_jax.shape)


Jacobian shape (2, 326)


### Define a random matrix and backpropagate it

Check block diagonal

In [12]:
matrix = torch.randn(_batch_size, _output_size, _output_size)
jax_matrix = matrix.numpy()

# Dense reference J^T M J for each batch element (J_by_jax carries no batch axis).
jmj_by_jax = np.einsum("ji,bjk,kq->biq", J_by_jax, jax_matrix, J_by_jax)
jmj_by_stoch = stoch_model._jTmjp(x, None, matrix, wrt='weight')

# Diagonal blocks of the dense reference, one per parametrised layer.
ref = jmj_by_jax[0]
lin1_end = first_linear_num_weights
lin2_start = -second_linear_num_weights
blocks_by_jax = [
    ref[:lin1_end, :lin1_end],
    ref[lin1_end:lin2_start, lin1_end:lin2_start],
    ref[lin2_start:, lin2_start:],
]

# Compare from the last layer back to the first.
for block in reversed(range(3)):
    stoch_block = jmj_by_stoch[block][0].detach().numpy()
    print(blocks_by_jax[block].shape, stoch_block.shape)
    difference = blocks_by_jax[block] - stoch_block
    print(np.max(abs(difference)))
    print(np.max(abs(blocks_by_jax[block])),
          np.max(abs(stoch_block)))
    assert np.max(abs(difference)) < 1e-5

reforwarding 6
non reforwarding 4
non reforwarding 4
(98, 98) (98, 98)
2.9802322e-07
1.9350502 1.9350502
(84, 84) (1, 84, 84)
5.9604645e-08
0.32389998 0.32389998
(144, 144) (144, 144)
1.1175871e-08
0.04994989 0.049949884


Check exact diagonal

In [13]:
matrix = torch.randn(_batch_size, _output_size, _output_size)
jax_matrix = matrix.numpy()

# Diagonal of J^T M J only: entry i is sum_{j,k} J[j, i] M[j, k] J[k, i].
jmj_by_jax = np.einsum("ji,bjk,ki->bi", J_by_jax, jax_matrix, J_by_jax)
jmj_by_stoch = stoch_model._jTmjp(x, None, matrix, wrt='weight', to_diag=True)

diag_by_jax = jmj_by_jax[0]
diag_by_stoch = jmj_by_stoch[0].detach().numpy()
print(diag_by_jax.shape, diag_by_stoch.shape)
difference = diag_by_jax - diag_by_stoch
# Spot-check the tail of the difference vector.
print(difference[-31:])
assert np.max(abs(difference)) < 1e-5


reforwarding 6
non reforwarding 4
non reforwarding 4
(326,) (326,)
[ 1.1175871e-08  0.0000000e+00  3.7252903e-09 -1.8626451e-08
 -1.1641532e-10  0.0000000e+00  7.4505806e-09  0.0000000e+00
  7.4505806e-09 -2.2351742e-08  0.0000000e+00  0.0000000e+00
 -1.4901161e-08  0.0000000e+00  7.4505806e-09  2.9802322e-08
  5.5879354e-09  0.0000000e+00  1.1175871e-08 -2.2351742e-08
  1.3038516e-08  1.3969839e-09  0.0000000e+00  4.4703484e-08
 -2.9802322e-08  2.7939677e-09  7.4505806e-09  0.0000000e+00
 -1.4901161e-08  0.0000000e+00  0.0000000e+00]


# Check correctness wrt the input

Check that vector-Jacobian products w.r.t. the input agree for random inputs and random vectors

In [14]:
random_inputs_amount = 10
random_vectors_amount = 10
input_vjp_tol = 1e-5

# Compare vector-Jacobian products w.r.t. the input: JAX autodiff vs stochman's _vjp.
for _input_trial in range(random_inputs_amount):
    x = torch.randn(_batch_size, _input_size)
    jax_x = jnp.array(x.numpy())
    # JAX pullback of the model w.r.t. its input, at the current weights.
    jax_y_by_vjp, vjp_fun = vjp(lambda data: jax_model(weights, data), jax_x)

    for _vector_trial in range(random_vectors_amount):
        vector = torch.randn(_batch_size, _output_size)
        jax_vj = vjp_fun(jnp.array(vector.numpy()))
        stoch_vj = stoch_model._vjp(x, None, vector, wrt='input').detach().numpy()
        max_abs_err = np.max(abs(np.array(jax_vj) - stoch_vj))
        assert max_abs_err < input_vjp_tol


### Fix one input x and compute the Jacobian

In [15]:
x = torch.randn(_batch_size, _input_size)
jax_x = jnp.array(x.numpy())
jax_y_by_vjp, vjp_fun = vjp(lambda data: jax_model(weights, data), jax_x)

# One-hot cotangents: e_i is 1 in output column i for every batch row, 0 elsewhere.
identity = [
    jnp.zeros((_batch_size, _output_size)).at[:, i].set(1)
    for i in range(_output_size)
]

# Row i of the input Jacobian: the VJP with e_i, keeping the gradient of batch row 0.
J_by_jax = np.array([vjp_fun(e_i)[0][0] for e_i in identity])
print('Jacobian shape', J_by_jax.shape)


Jacobian shape (2, 2)


### Define a random matrix and backpropagate it

Check full case

In [16]:
matrix = torch.randn(_batch_size, _output_size, _output_size)
jax_matrix = matrix.numpy()

# Full J^T M J w.r.t. the input: dense reference vs stochman.
jmj_by_jax = np.einsum("ji,bjk,kq->biq", J_by_jax, jax_matrix, J_by_jax)
jmj_by_stoch = stoch_model._jTmjp(x, None, matrix, wrt='input')
stoch_np = jmj_by_stoch.detach().numpy()

print(jmj_by_jax[0].shape, stoch_np[0].shape)
difference = jmj_by_jax - stoch_np
assert np.max(abs(difference)) < 1e-5

reforwarding 6
non reforwarding 4
(2, 2) (2, 2)


Check exact diagonal

In [17]:
matrix = torch.randn(_batch_size, _output_size, _output_size)
jax_matrix = matrix.numpy()

# Only the diagonal of J^T M J w.r.t. the input.
jmj_by_jax = np.einsum("ji,bjk,ki->bi", J_by_jax, jax_matrix, J_by_jax)
jmj_by_stoch = stoch_model._jTmjp(x, None, matrix, wrt='input', to_diag=True)

diag_by_jax = jmj_by_jax[0]
diag_by_stoch = jmj_by_stoch[0].detach().numpy()
print(diag_by_jax.shape, diag_by_stoch.shape)
difference = diag_by_jax - diag_by_stoch
assert np.max(abs(difference)) < 1e-5

reforwarding 6
non reforwarding 4
(2,) (2,)


# Fix a pair of inputs (x1, x2) and compute the Jacobians (wrt weight)

In [18]:
x1 = torch.randn(_batch_size, _input_size)
x2 = torch.randn(_batch_size, _input_size)
jax_x1 = jnp.array(x1.numpy())
jax_x2 = jnp.array(x2.numpy())
jax_y1_by_vjp, vjp_fun1 = vjp(lambda w: jax_model(w, jax_x1), weights)
jax_y2_by_vjp, vjp_fun2 = vjp(lambda w: jax_model(w, jax_x2), weights)

# One-hot cotangents: e_i is 1 in output column i for every batch row, 0 elsewhere.
identity = [
    jnp.zeros((_batch_size, _output_size)).at[:, i].set(1)
    for i in range(_output_size)
]

# Jacobians w.r.t. the weights at x1 and at x2, one row per output.
J1_by_jax = np.array([vjp_fun1(e_i)[0] for e_i in identity])
J2_by_jax = np.array([vjp_fun2(e_i)[0] for e_i in identity])
assert J1_by_jax.shape == J2_by_jax.shape
print('Jacobian shape', J1_by_jax.shape)

Jacobian shape (2, 326)


### Define a random matrix and backpropagate it

Check block diagonal

In [19]:
matrixes = tuple(torch.randn(_batch_size, _output_size, _output_size) for _ in range(3))
jax_matrixes = tuple(m.numpy() for m in matrixes)

# Reference: J1^T M0 J1 - J1^T M1 J2 - (J1^T M1 J2)^T + J2^T M2 J2
# (the "->bqi" einsum is the transposed cross term).
jmj_by_jax = (
    np.einsum("ji,bjk,kq->biq", J1_by_jax, jax_matrixes[0], J1_by_jax)
    - np.einsum("ji,bjk,kq->biq", J1_by_jax, jax_matrixes[1], J2_by_jax)
    - np.einsum("ji,bjk,kq->bqi", J1_by_jax, jax_matrixes[1], J2_by_jax)
    + np.einsum("ji,bjk,kq->biq", J2_by_jax, jax_matrixes[2], J2_by_jax)
)
# stochman returns a 3-tuple per parametrised layer; combine each one like the reference.
jmj_by_stoch = [
    parts[0] - parts[1] - parts[1].transpose(-2, -1) + parts[2]
    for parts in stoch_model._jTmjp_batch2(x1, x2, None, None, matrixes, wrt='weight')
]

ref = jmj_by_jax[0]
lin1_end = first_linear_num_weights
lin2_start = -second_linear_num_weights
blocks_by_jax = [
    ref[:lin1_end, :lin1_end],
    ref[lin1_end:lin2_start, lin1_end:lin2_start],
    ref[lin2_start:, lin2_start:],
]

for block in range(3):
    stoch_block = jmj_by_stoch[block][0].detach().numpy()
    print(blocks_by_jax[block].shape, stoch_block.shape)
    difference = blocks_by_jax[block] - stoch_block
    print(np.max(abs(difference)))
    assert np.max(abs(difference)) < 1e-4

(144, 144) (144, 144)
5.9604645e-08
(84, 84) (84, 84)
1.3411045e-07
(98, 98) (98, 98)
1.1920929e-06


In [20]:
# Size of the leading dimension of the first combined block (indexed as the batch axis above).
len(jmj_by_stoch[0])

1

Check exact diagonal

In [21]:
matrixes = tuple(torch.randn(_batch_size, _output_size, _output_size) for _ in range(3))
jax_matrixes = tuple(m.numpy() for m in matrixes)

# Diagonal of the combined form; the cross term and its transpose share a diagonal, hence 2x.
jmj_by_jax = (
    np.einsum("ji,bjk,ki->bi", J1_by_jax, jax_matrixes[0], J1_by_jax)
    - 2 * np.einsum("ji,bjk,ki->bi", J1_by_jax, jax_matrixes[1], J2_by_jax)
    + np.einsum("ji,bjk,ki->bi", J2_by_jax, jax_matrixes[2], J2_by_jax)
)
per_layer_diags = stoch_model._jTmjp_batch2(x1, x2, None, None, matrixes, wrt='weight', to_diag=True)
# Combine each layer's triple, then join the layers along the parameter axis.
jmj_by_stoch = torch.cat([d[0] - 2 * d[1] + d[2] for d in per_layer_diags], dim=1)

diag_by_jax = jmj_by_jax[0]
diag_by_stoch = jmj_by_stoch[0].detach().numpy()
print(diag_by_jax.shape, diag_by_stoch.shape)
difference = diag_by_jax - diag_by_stoch
print(np.max(abs(difference)))
assert np.max(abs(difference)) < 1e-4

(326,) (326,)
4.7683716e-07


# Fix a pair of inputs (x1, x2) and compute the Jacobians (wrt input)

In [22]:
x1 = torch.randn(_batch_size, _input_size)
x2 = torch.randn(_batch_size, _input_size)
jax_x1 = jnp.array(x1.numpy())
jax_x2 = jnp.array(x2.numpy())
jax_y1_by_vjp, vjp_fun1 = vjp(lambda data: jax_model(weights, data), jax_x1)
jax_y2_by_vjp, vjp_fun2 = vjp(lambda data: jax_model(weights, data), jax_x2)

# One-hot cotangents: e_i is 1 in output column i for every batch row, 0 elsewhere.
identity = [
    jnp.zeros((_batch_size, _output_size)).at[:, i].set(1)
    for i in range(_output_size)
]

# Input Jacobians at x1 and x2 (gradient of batch row 0 for each one-hot cotangent).
J1_by_jax = np.array([vjp_fun1(e_i)[0][0] for e_i in identity])
J2_by_jax = np.array([vjp_fun2(e_i)[0][0] for e_i in identity])
assert J1_by_jax.shape == J2_by_jax.shape
print('Jacobian shape', J1_by_jax.shape)


Jacobian shape (2, 2)


### Define a random matrix and backpropagate it

Check full case

In [23]:
matrixes = tuple(torch.randn(_batch_size, _output_size, _output_size) for _ in range(3))
jax_matrixes = tuple(m.numpy() for m in matrixes)

# Reference: J1^T M0 J1 - J1^T M1 J2 - (J1^T M1 J2)^T + J2^T M2 J2
# (the "->bqi" einsum is the transposed cross term).
jmj_by_jax = (
    np.einsum("ji,bjk,kq->biq", J1_by_jax, jax_matrixes[0], J1_by_jax)
    - np.einsum("ji,bjk,kq->biq", J1_by_jax, jax_matrixes[1], J2_by_jax)
    - np.einsum("ji,bjk,kq->bqi", J1_by_jax, jax_matrixes[1], J2_by_jax)
    + np.einsum("ji,bjk,kq->biq", J2_by_jax, jax_matrixes[2], J2_by_jax)
)
parts = stoch_model._jTmjp_batch2(x1, x2, None, None, matrixes, wrt='input')
jmj_by_stoch = parts[0] - parts[1] - parts[1].transpose(-2, -1) + parts[2]

print(jmj_by_jax[0].shape, jmj_by_stoch[0].detach().numpy().shape)
difference = jmj_by_jax - jmj_by_stoch.detach().numpy()
assert np.max(abs(difference)) < 1e-5

(2, 2) (2, 2)


Check exact diagonal

In [24]:
matrixes = tuple(torch.randn(_batch_size, _output_size, _output_size) for _ in range(3))
jax_matrixes = tuple(m.numpy() for m in matrixes)

# Diagonal of the combined form; the cross term and its transpose share a diagonal, hence 2x.
jmj_by_jax = (
    np.einsum("ji,bjk,ki->bi", J1_by_jax, jax_matrixes[0], J1_by_jax)
    - 2 * np.einsum("ji,bjk,ki->bi", J1_by_jax, jax_matrixes[1], J2_by_jax)
    + np.einsum("ji,bjk,ki->bi", J2_by_jax, jax_matrixes[2], J2_by_jax)
)
parts = stoch_model._jTmjp_batch2(x1, x2, None, None, matrixes, wrt='input', to_diag=True)
jmj_by_stoch = parts[0] - 2 * parts[1] + parts[2]

print(jmj_by_jax[0].shape, jmj_by_stoch[0].detach().numpy().shape)
difference = jmj_by_jax - jmj_by_stoch.detach().numpy()
assert np.max(abs(difference)) < 1e-5

(2,) (2,)
