In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from importlib import reload

In [2]:
import particle_decoding as pd

In [94]:
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## Transformations

In [4]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

In [5]:
pd.identity_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]],

       [[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]]])>

In [6]:
pd.spherical_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [7]:
pd.spherical_transform(test_coords, dist_sq=tf.reduce_sum(test_coords * test_coords, axis=-1))

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [8]:
pd.spherical_transform(pd.spherical_transform(test_coords), reverse=True)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]],

       [[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]]])>

## Masking by distance

In [9]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

In [10]:
distance_mask = pd.DistanceMask()

In [11]:
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [12]:
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [13]:
#Need to also test if ref included in input coordinates
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[ 1.5, -0.5,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[2],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[3.5],
        [0.5]])>)

In [14]:
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

In [15]:
#Now try distances with a periodic simulation box
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

distance_mask = pd.DistanceMask(box_lengths=np.array([3.0, 3.0, 3.0]))

In [16]:
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [17]:
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [18]:
#Need to also test if ref included in input coordinates
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[-1. ,  0. ,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[0],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[2. ],
        [0.5]])>)

In [19]:
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

## Creating probability distributions

In [20]:
params = np.array([[[[0.0, 1.0], [10.0, 1.0], [-10.0, 1.0]],
                    [[0.0, 1.0], [0.0, 2.0], [0.0, 10.0]],
                    [[5.0, 1.0], [5.0, 2.0], [5.0, 10.0]]],
                   [[[0.0, 1.0], [-10.0, 1.0], [10.0, 1.0]],
                    [[0.0, 10.0], [0.0, 2.0], [0.0, 1.0]],
                    [[-5.0, 10.0], [-5.0, 2.0], [-5.0, 1.0]]]])

dist_list = [tfp.distributions.Normal]*3
param_trans = [lambda x, y: [x, y]]*3

In [21]:
joint_dist = pd.create_dist(params, dist_list, param_trans)

In [22]:
sample = tf.stack(joint_dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[  0.49169714  11.19161727  -8.474033  ]
  [ -1.51291659   0.54238984  12.39250583]
  [  4.4238223    1.32622251  -3.22931312]]

 [[ -0.85670569  -9.72859304  10.14088885]
  [-13.93950195  -0.01701247  -0.22859501]
  [ -1.03448696  -3.55396454  -4.90337052]]], shape=(2, 3, 3), dtype=float64)


In [23]:
joint_dist.log_prob(tf.unstack(sample, axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-4.75196215, -7.70165052, -7.94422635],
       [-3.17054362, -6.75026046, -6.09722029]])>

In [24]:
np.sum(-0.5*((sample[-1, 1, :] - params[-1, 1, :, 0])**2)/(params[-1, 1, :, 1]**2) - np.log(params[-1, 1, :, 1]) - 0.5*np.log(2.0*np.pi))

-6.750260464092591

## Solvation neural networks

In [25]:
#Input can be solute or solvent coordinates, with extra_inputs being the other
#Start by testing without augmented input (only relevant for first network to predict layer of solvent for first solute)
solute_net = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=False)

In [26]:
solute_coords = tf.random.normal((1, 2, 3))

In [27]:
params, shifts = solute_net(solute_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.10781626 -0.11571729]
   [-0.16840316  0.03963847]
   [ 0.00373492  0.12946814]]

  [[ 0.04101098  0.10868616]
   [-0.11427411  0.00110314]
   [-0.030279    0.09210403]]

  [[ 0.12329493 -0.04062085]
   [ 0.0930428   0.06798562]
   [ 0.17172804  0.00833471]]

  [[ 0.08198394  0.16118045]
   [-0.02777326  0.12805092]
   [-0.08818378  0.01964958]]

  [[-0.06819345 -0.00276325]
   [-0.01321042  0.09585222]
   [-0.23026142  0.10316847]]

  [[-0.00448867 -0.01813732]
   [-0.03066178 -0.05994048]
   [-0.16927662  0.02164468]]

  [[-0.17675501 -0.10436471]
   [ 0.1461188   0.08162406]
   [-0.11232565 -0.03871739]]

  [[ 0.14816432  0.0419888 ]
   [-0.07042156 -0.01943495]
   [ 0.11227106  0.07579395]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.64500034 -0.08439755]
   [-0.75721383  0.84083635]
   [ 0.47223446 -0.35288173]]

  [[ 0.10515638 -0.19499454]
   [-0.25758034  0.09153756]
   [-0.131713   -0.72505105]]

  [[-0.32415056 -0.58341235]
   [-0.6797054  -0.119

In [28]:
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3, [lambda x, y: [x, tf.math.exp(0.5*y)]]*3)
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[-0.37553298 -2.6843796   1.3914394 ]
  [-1.4597013  -0.37443635 -1.1675941 ]
  [ 0.78088254  0.7249082  -1.1419271 ]
  [-0.43148774  1.5753844   1.0171566 ]
  [-1.4602644   0.88730437 -1.9841225 ]
  [-1.3848507   1.3533903   0.5335958 ]
  [-0.809906    0.29745895 -0.6161704 ]
  [-0.6153516  -1.8824135  -0.6535361 ]]], shape=(1, 8, 3), dtype=float32)


In [29]:
#First set of 3 (first particle) should not change shifts, even with different input due to autoregressive
#All others should change, though
new_shifts = solute_net(solute_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[-0.64500034 -0.08439755]
   [-0.75721383  0.84083635]
   [ 0.47223446 -0.35288173]]

  [[ 0.09293883 -0.19159913]
   [-0.2719168   0.1069086 ]
   [-0.14381017 -0.7298844 ]]

  [[-0.28607392 -0.5723177 ]
   [-0.69163376 -0.11542855]
   [-0.35247365 -0.0740208 ]]

  [[-0.03576267  0.14998166]
   [-0.15114665  1.2878133 ]
   [ 0.00847869  0.712388  ]]

  [[-0.19274074  0.13678446]
   [-0.63451535  0.40722725]
   [ 0.83231175 -0.09545106]]

  [[-0.01921152 -0.6514305 ]
   [ 0.14742929  0.535235  ]
   [-0.38923907 -1.1309706 ]]

  [[ 0.50458133  0.32079434]
   [ 0.99562395  0.54062283]
   [ 0.0142978   0.35314912]]

  [[-0.1822684  -0.58434415]
   [ 0.20268929  0.31447667]
   [-0.15488133 -0.38507086]]]], shape=(1, 8, 3, 2), dtype=float32)


In [30]:
#And now with solvent inputs (remember, could flip if wanted)
solute_net_solv = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [31]:
solvent_coords = tf.random.normal((1, 6, 3))

In [32]:
params, shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.15956444  0.3244326 ]
   [-0.16784303 -0.00090821]
   [ 0.33943322  0.25059813]]

  [[-0.1017492   0.5648912 ]
   [ 0.17074701 -0.01971192]
   [-0.17403331  0.38767713]]

  [[-0.443408    0.02574154]
   [ 0.02508543 -0.47224748]
   [ 0.10160196 -0.00984256]]

  [[-0.11672117  0.46153378]
   [-0.07583715 -0.12084991]
   [ 0.4076216  -0.05110563]]

  [[ 0.18315119  0.07975881]
   [ 0.12006899 -0.02032234]
   [-0.19641884  0.00243369]]

  [[-0.19815093 -0.39651477]
   [ 0.02604538  0.40783793]
   [-0.33587584 -0.35284555]]

  [[ 0.14673568  0.33315682]
   [-0.22252207 -0.06051992]
   [ 0.1244607   0.28196537]]

  [[ 0.53577715 -0.10649478]
   [-0.07364794  0.14315231]
   [-0.01673634 -0.30133852]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-2.0634918   1.9269998 ]
   [ 0.4808945   1.0519842 ]
   [-1.0508817   1.2779062 ]]

  [[-0.82666254 -1.2249026 ]
   [-1.665776   -0.3654336 ]
   [-0.6936797   0.12217906]]

  [[-0.7007578  -0.58440024]
   [-0.82439053 -0.722

In [33]:
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3, [lambda x, y: [x, tf.math.exp(0.5*y)]]*3)
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ 1.6731229   0.29586104 -1.6792445 ]
  [-1.0380908  -1.8602877   0.2053709 ]
  [-0.40047026  0.02286899 -0.3240497 ]
  [-0.4913438  -0.19388306  0.20336539]
  [ 0.7611809  -1.2745836   0.5533415 ]
  [-2.1002738  -0.3551841   0.91953754]
  [ 2.2348485  -1.2436614   2.291748  ]
  [ 0.38886893 -0.20912799 -2.723728  ]]], shape=(1, 8, 3), dtype=float32)


In [34]:
new_shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[-2.0634918   1.9269998 ]
   [ 0.4808945   1.0519842 ]
   [-1.0508817   1.2779062 ]]

  [[-0.82429004 -1.2206783 ]
   [-1.6602033  -0.37007993]
   [-0.68835366  0.12643513]]

  [[-0.7150829  -0.570533  ]
   [-0.83665407 -0.7149752 ]
   [-0.9511814   0.519834  ]]

  [[ 0.35814998  0.28046036]
   [-0.66918874 -0.4724368 ]
   [ 0.49067736  0.21258077]]

  [[-1.7941884   0.42579362]
   [-0.99053204 -1.2763073 ]
   [ 1.0499139  -1.4782698 ]]

  [[-0.820827    0.25613546]
   [-0.09130436 -0.69646716]
   [ 1.2254927  -0.11581251]]

  [[ 2.2475696  -0.56147385]
   [-1.2798913  -0.15451515]
   [ 0.5717433   0.57452995]]

  [[ 0.47830212  1.0059398 ]
   [ 0.05792773 -1.3524187 ]
   [-0.23197085  0.99139225]]]], shape=(1, 8, 3, 2), dtype=float32)


In [35]:
#Now that it's been called (so built) with extra_coords provided, must provide always with this layer
new_shifts = solute_net_solv(solute_coords, sampled_input=sample)
print(new_shifts)

ValueError: Input 0 of layer dense_18 is incompatible with the layer: expected axis -1 of input shape to have value 24 but received input with shape (1, 6)

In [37]:
#Now pretend we're working with a central solvent particle and adding one more
#(conditioned on other solutes and solvent already around)
solvent_net = pd.SolvationNet((1, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [38]:
params, shifts = solvent_net(solvent_coords, extra_coords=solute_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.09617166 -0.1427853 ]
   [-0.21919696  0.07105578]
   [-0.2631318  -0.03186999]]]], shape=(1, 1, 3, 2), dtype=float32) tf.Tensor(
[[[[-1.4714721  -1.3842937 ]
   [ 0.05496937 -0.77525663]
   [-0.46102628 -0.3011788 ]]]], shape=(1, 1, 3, 2), dtype=float32)


## Solute decoding

In [59]:
solute_decoder = pd.SoluteDecoder(3,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             name='solutedecoder')

In [60]:
solute_coords = tf.random.uniform((2, 5, 3), minval=-1.5, maxval=1.5)

In [61]:
sample_out, params_out, ref_out, unused_data = solute_decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 15, 3) (2, 15, 3, 2) (2, 15, 3)


In [62]:
print(unused_data)

None


In [63]:
sample_out - ref_out

<tf.Tensor: shape=(2, 15, 3), dtype=float32, numpy=
array([[[ 0.4073695 , -0.02867442,  1.67454   ],
        [ 1.0605726 , -2.8944564 , -0.7953121 ],
        [ 1.0504704 , -0.5953356 , -0.9838525 ],
        [-0.23768282, -1.3471911 ,  0.23579001],
        [ 1.3999076 , -1.1094917 ,  0.78841496],
        [-0.29680216,  0.08044308, -2.9119508 ],
        [-0.8726536 , -0.07400417,  1.2249012 ],
        [ 1.0328872 , -0.7593357 ,  0.6607555 ],
        [ 0.02047682, -0.159136  ,  0.57557094],
        [ 0.56785655, -0.609092  ,  1.125905  ],
        [ 0.59897065, -0.88984996, -0.24103546],
        [-0.86891294,  1.5069112 , -2.0948174 ],
        [ 1.2573322 , -1.5498912 ,  0.9216907 ],
        [ 2.8279605 , -0.19105208, -2.9742413 ],
        [ 3.0149536 ,  1.2677433 , -1.4484994 ]],

       [[-1.9155545 ,  0.5966111 , -1.3746209 ],
        [-2.550157  ,  1.6547565 ,  0.18368924],
        [ 0.12734151, -1.5268837 ,  0.6142406 ],
        [-0.4265014 , -0.36384952, -0.5473944 ],
        [-2.202

In [64]:
params_out

<tf.Tensor: shape=(2, 15, 3, 2), dtype=float32, numpy=
array([[[[ 0.30313596,  0.03677286],
         [ 0.06988432, -0.50537074],
         [-0.06098804,  0.79898125]],

        [[ 1.0190378 , -0.9878891 ],
         [-1.8065484 , -0.14789519],
         [-1.4334364 ,  0.44219863]],

        [[-0.02417737,  1.4134057 ],
         [ 0.5869621 , -0.180574  ],
         [-0.6869556 , -1.3051687 ]],

        [[-0.6793554 , -0.94451296],
         [-0.67884594,  0.3752394 ],
         [-0.15374206, -0.6022556 ]],

        [[-0.81860137,  2.1191244 ],
         [ 0.05659524, -0.1398349 ],
         [-0.01752928, -0.53684294]],

        [[-0.00444075, -0.64395547],
         [ 0.40476272,  0.0675981 ],
         [-2.3486712 ,  1.046513  ]],

        [[ 0.14150064, -0.46851993],
         [ 1.0613854 , -0.43430832],
         [ 0.07343089, -0.3408835 ]],

        [[ 1.4619094 , -0.8366637 ],
         [ 0.34619766,  0.09568012],
         [ 0.9580365 , -0.10289458]],

        [[-1.955654  ,  0.59390104],
    

In [65]:
#For convenience, added log_prob function to decoder
solute_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 15), dtype=float32, numpy=
array([[-3.6126866, -3.2292728, -3.860935 , -2.71389  , -5.111157 ,
        -3.1781282, -4.8838935, -3.1516604, -4.7994237, -2.6393197,
        -1.8344693, -3.6293135, -4.7647104, -6.104755 , -3.5642986],
       [-3.823821 , -3.895533 , -4.0120955, -3.9807448, -5.357432 ,
        -4.891387 , -3.8125367, -9.616757 , -4.420715 , -3.7828453,
        -2.9694428, -3.617165 , -3.281393 , -8.257726 , -6.186609 ]],
      dtype=float32)>

In [66]:
#Above tested generation, now test if have training data
training_data = tf.random.normal((2, 60, 3))

In [67]:
sample_out, params_out, ref_out, unused_data = solute_decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 15, 3) (2, 15, 3, 2) (2, 15, 3)


In [68]:
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 1.33831954e+00,  1.31551540e+00, -1.66626775e+00],
        [ 4.04382795e-01,  1.17365541e-02,  1.67298079e-01],
        [-7.93628633e-01,  1.68706465e+00,  1.68575048e+00],
        [-1.61827886e+00, -6.88336194e-01, -1.41281232e-01],
        [-1.12045026e+00,  6.89436853e-01,  1.38193369e+00],
        [ 1.44188955e-01, -2.05882502e+00, -1.89884090e+00],
        [ 1.25067502e-01,  7.96369672e-01, -3.15662265e-01],
        [ 2.20198080e-01, -1.46165574e+00, -1.44531858e+00],
        [ 1.69581771e-02, -6.43030882e-01,  6.67359710e-01],
        [-4.77539033e-01,  5.93882203e-01,  9.78645921e-01],
        [ 1.05426061e+00, -1.65352106e+00, -2.59309959e+00],
        [-4.60481226e-01, -8.06448519e-01,  2.88906980e+00],
        [-9.54926968e-01,  9.16537225e-01, -1.03455827e-01],
        [ 1.77636611e+00, -2.91218728e-01,  1.19345140e+00],
        [-4.10004050e-01, -1.06915092e+00,  2.14656973e+00],
        [-8.82443607e-01,  6.3430

In [69]:
unused_data

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[ 1.3383195 ,  1.3155154 , -1.6662678 ],
        [-0.79362863,  1.6870646 ,  1.6857505 ],
        [-1.6182789 , -0.6883362 , -0.14128123],
        [-1.1204503 ,  0.68943685,  1.3819337 ],
        [ 0.14418896, -2.058825  , -1.8988409 ],
        [ 0.22019808, -1.4616557 , -1.4453186 ],
        [ 0.01695818, -0.6430309 ,  0.6673597 ],
        [-0.47753903,  0.5938822 ,  0.9786459 ],
        [ 1.0542606 , -1.6535211 , -2.5930996 ],
        [-0.95492697,  0.9165372 , -0.10345583],
        [-0.41000405, -1.0691509 ,  2.1465697 ],
        [-0.8824436 ,  0.6343032 ,  1.1929923 ],
        [ 2.1096609 , -0.09342341,  1.5141119 ],
        [-0.18350191,  1.3684999 ,  1.3274722 ],
        [-1.19622   ,  1.4813558 , -1.9298594 ],
        [-1.7785381 , -1.6580665 ,  0.35193923],
        [-1.1718476 , -0.12321528,  1.8777704 ],
        [ 0.19212192, -0.9396484 , -1.7070075 ],
        [-0.36260855, -0.8596625 , -0.02605031],
        [-0.25688

In [70]:
sample_out

<tf.Tensor: shape=(2, 15, 3), dtype=float32, numpy=
array([[[ 0.1250675 ,  0.7963697 , -0.31566226],
        [ 0.9545727 , -0.72348   , -1.4100987 ],
        [ 0.0819529 ,  1.3748324 , -0.69111365],
        [ 0.5215315 , -1.9538057 ,  0.3595367 ],
        [ 0.6879007 , -0.2160123 ,  0.23717904],
        [ 0.09096932, -0.14759494, -3.0031033 ],
        [-1.0137408 ,  0.1238116 , -0.65662986],
        [-0.10494626, -0.07097322,  1.4757384 ],
        [-0.46048123,  2.1935515 , -0.11093009],
        [ 1.5007433 ,  0.23304205,  1.1246293 ],
        [ 1.7763661 , -0.29121876,  1.1934514 ],
        [ 0.43404546,  0.82995653,  0.6248821 ],
        [ 1.7001228 , -1.3638356 ,  1.0394785 ],
        [ 0.4043828 ,  0.01173657,  0.16729808],
        [ 0.19506425, -1.190726  ,  1.5221657 ]],

       [[-0.33815548, -0.02746361, -1.0163385 ],
        [-1.1731118 ,  1.094159  ,  0.11668152],
        [ 1.0489861 , -0.42739287,  0.6758878 ],
        [-0.9040997 , -0.2970416 , -0.50034666],
        [-1.833

In [71]:
sample_out - ref_out

<tf.Tensor: shape=(2, 15, 3), dtype=float32, numpy=
array([[[ 0.22596219,  0.12777054, -0.217762  ],
        [ 1.0554674 , -1.3920791 , -1.3121984 ],
        [ 0.18284759,  0.70623326, -0.5932134 ],
        [-0.86302334, -1.9646311 , -0.74225944],
        [-0.69665414, -0.2268377 , -0.8646171 ],
        [-1.2935855 , -0.15842034, -4.1048994 ],
        [-0.07128799, -0.35455966,  0.09116536],
        [ 0.83750653, -0.5493445 ,  2.2235336 ],
        [ 0.48197156,  1.7151803 ,  0.63686514],
        [ 1.0918945 , -0.05453654, -0.19422364],
        [ 1.3675174 , -0.57879734, -0.1254015 ],
        [ 0.0251967 ,  0.54237795, -0.6939708 ],
        [ 1.7772932 , -0.565495  , -0.03576303],
        [ 0.48155317,  0.81007713, -0.9079435 ],
        [ 0.27223462, -0.39238548,  0.4469241 ]],

       [[-0.4779404 ,  0.05238932, -0.7969363 ],
        [-1.3128967 ,  1.174012  ,  0.3360837 ],
        [ 0.90920115, -0.34753993,  0.89529   ],
        [ 0.47143257,  0.90856737,  0.7796296 ],
        [-0.457

In [72]:
params_out

<tf.Tensor: shape=(2, 15, 3, 2), dtype=float32, numpy=
array([[[[ 3.03135961e-01,  3.67728584e-02],
         [ 6.98843151e-02, -5.05370736e-01],
         [-6.09880388e-02,  7.98981249e-01]],

        [[ 8.37980747e-01, -1.16864085e+00],
         [-1.82580018e+00,  3.54081094e-02],
         [-1.40886664e+00,  6.38997495e-01]],

        [[ 9.84043330e-02,  1.49009645e+00],
         [ 6.39671564e-01,  2.18987465e-04],
         [-6.23661697e-01, -1.16230547e+00]],

        [[-1.24479032e+00, -1.02247134e-01],
         [-1.42395782e+00, -6.33332729e-02],
         [-1.07681513e+00, -1.13738585e+00]],

        [[-6.79478407e-01,  2.54240060e+00],
         [ 1.30007446e-01,  5.22295237e-01],
         [-6.27405524e-01,  1.33564401e+00]],

        [[-1.02103698e+00, -1.86996472e+00],
         [-2.49838591e-01, -1.52087283e+00],
         [-4.08163357e+00,  1.03303957e+00]],

        [[ 1.56504571e-01, -4.15059686e-01],
         [-1.71596035e-01, -5.21098495e-01],
         [-7.08015561e-02,  3.387

In [73]:
solute_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 15), dtype=float32, numpy=
array([[-2.9331825, -2.6790464, -2.9253209, -2.5162997, -5.002162 ,
        -1.8381107, -2.385829 , -0.599303 , -2.9929757, -2.600055 ,
        -1.8399032, -3.490314 , -3.1221867, -3.4067698, -2.9497783],
       [-2.1979475, -3.3517115, -3.422019 , -2.326581 , -2.1468146,
        -2.3502321, -3.2804394, -1.5019624, -2.2673373, -2.3138618,
        -2.6060233, -3.0315967, -3.2160254, -2.524601 , -2.3842678]],
      dtype=float32)>

## Solvent decoding

In [74]:
solvent_decoder = pd.SolventDecoder(2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             name='solventdecoder')

In [75]:
solvent_coords = tf.identity(sample_out)
training_data_unused = tf.identity(unused_data)

In [76]:
sample_out, params_out, ref_out, unused_data = solvent_decoder(solvent_coords, solute_coords=solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 45, 3) (2, 45, 3, 2) (2, 45, 3)


In [77]:
sample_out - ref_out

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[-7.8179181e-01,  1.4009084e+00,  2.4779563e+00],
        [-2.3164260e+00, -1.9469554e+00, -4.0619338e-01],
        [-3.9325924e+00, -3.6074233e-01,  1.3488020e+00],
        [ 6.0001332e-01, -2.5263035e+00,  6.4149737e-02],
        [ 4.6995559e+00, -1.2110132e-01,  1.7472327e-01],
        [-2.1069651e+00, -1.2048495e+00, -9.1210556e-01],
        [ 5.5306250e-01, -6.8921041e-01, -1.8860016e+00],
        [-2.3560815e+00, -5.0281882e-03,  1.4475805e+00],
        [ 1.7907677e+00,  1.8480859e+00,  8.0429107e-02],
        [ 1.5702116e-01,  1.0471189e+00,  7.0578146e-01],
        [ 1.9070717e+00,  6.6678637e-01, -1.0638847e+00],
        [ 4.2527780e-01,  1.1846163e+00,  2.4434009e+00],
        [ 8.9083433e-02, -1.8858397e-01, -1.5298676e-01],
        [ 4.0063134e-01,  1.0477858e+00,  9.1995192e-01],
        [ 1.8798981e+00,  1.9587212e+00,  2.0127349e+00],
        [-1.0618007e+00,  1.4004937e+00,  2.0799282e-01],
        [ 3.0462565e

In [78]:
params_out

<tf.Tensor: shape=(2, 45, 3, 2), dtype=float32, numpy=
array([[[[-1.09353960e+00,  7.62675405e-01],
         [-3.79723191e-01,  1.66837960e-01],
         [ 1.05406845e+00,  1.57093674e-01]],

        [[-6.80275381e-01,  1.08901262e+00],
         [-1.17626369e+00, -1.63915396e-01],
         [-2.56464481e-02, -1.44412148e+00]],

        [[-8.29614282e-01,  2.24788547e+00],
         [-3.56653750e-01, -2.26991343e+00],
         [ 8.09637547e-01, -1.17435873e+00]],

        [[ 6.54187381e-01, -1.23911476e+00],
         [-1.51479161e+00,  2.06830412e-01],
         [ 1.41734397e+00,  7.25943685e-01]],

        [[ 1.31913197e+00,  1.66188407e+00],
         [ 2.61916697e-01, -2.00362253e+00],
         [-1.16011649e-02, -1.27958357e+00]],

        [[-8.20178688e-02, -1.74209088e-01],
         [-5.91521621e-01, -3.61364573e-01],
         [-3.65606904e-01,  1.42655516e+00]],

        [[-5.70749342e-01,  2.89352238e-02],
         [ 6.78084016e-01,  1.09664530e-01],
         [-2.52331161e+00,  2.154

In [79]:
#For convenience, added log_prob function to decoder
solvent_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 45), dtype=float32, numpy=
array([[ -5.530859 ,  -3.6045263,  -3.1375532,  -3.467719 ,  -3.6368647,
         -5.9484916,  -4.5486746,  -5.631031 ,  -3.497005 ,  -3.448432 ,
         -3.8655944,  -6.9420404,  -2.7396197,  -3.1468472,  -5.130164 ,
         -3.2415755,  -7.0202527,  -4.304549 ,  -3.7210095,  -4.8481364,
         -5.1135187,  -2.7243598,  -4.687882 ,  -4.721973 ,  -4.3842554,
         -5.118205 ,  -6.115066 ,  -2.9149685,  -5.1569786,  -5.4117594,
         -4.909191 ,  -3.0814123,  -3.9029593,  -7.2492757,  -4.247234 ,
         -2.6004903,  -8.963356 ,  -3.2693887,  -5.2834287,  -3.7505312,
         -4.216944 ,  -4.3085394,  -4.1005454,  -5.8446045,  -3.1086004],
       [ -4.7219496,  -3.2494555,  -6.2112565,  -3.3008366,  -2.642427 ,
         -5.2862334,  -4.8507338,  -3.296557 ,  -3.0933979,  -3.6499894,
         -6.638645 ,  -5.398485 ,  -7.1496305,  -2.7331157,  -6.7870107,
         -3.2630959,  -4.691192 ,  -3.1464515,  -2.649967 , -12.114716 ,
 

In [81]:
#Now do with training data (leftover from solute decoding)
sample_out, params_out, ref_out, unused_data = solvent_decoder(solvent_coords,
                                                  solute_coords=solute_coords,
                                                  train_data=training_data_unused)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 45, 3) (2, 45, 3, 2) (2, 45, 3)


In [82]:
unused_data

<tf.Tensor: shape=(2, 0, 3), dtype=float32, numpy=array([], shape=(2, 0, 3), dtype=float32)>

In [83]:
sample_out

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[-0.8824436 ,  0.6343032 ,  1.1929923 ],
        [ 0.22019808, -1.4616557 , -1.4453186 ],
        [-1.4554077 ,  0.03112915,  0.43372697],
        [ 1.4898094 , -3.7483077 ,  1.9330842 ],
        [ 1.8037798 , -1.5186442 ,  1.0701406 ],
        [ 0.08663563, -0.80437016, -4.1144543 ],
        [-0.89033914, -0.0934234 , -1.485888  ],
        [-1.1718476 , -0.12321529, -1.1222295 ],
        [-0.31996953,  0.9711963 ,  0.12809676],
        [ 1.3817213 , -0.6883362 ,  2.8587189 ],
        [ 3.0136044 , -0.1369034 ,  1.8009857 ],
        [ 0.7950315 ,  0.74178123,  1.819874  ],
        [ 4.301643  , -0.64166   ,  0.8928647 ],
        [ 1.1081833 ,  0.8492683 ,  0.3368077 ],
        [ 0.6921603 ,  0.01043427,  2.6315434 ],
        [-0.25688517, -0.04312833, -0.00699444],
        [ 1.2982382 , -0.46375597,  0.42167568],
        [ 1.4030902 ,  0.6605246 ,  0.01447755],
        [ 1.5736504 , -1.8194174 ,  1.1383296 ],
        [ 1.14749

In [84]:
sample_out - ref_out

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[-1.0075111 , -0.16206646,  1.5086546 ],
        [-0.7343746 , -0.73817575, -0.03521991],
        [-1.5373607 , -1.3437033 ,  1.1248406 ],
        [ 0.9682779 , -1.794502  ,  1.5735476 ],
        [ 1.115879  , -1.3026319 ,  0.83296156],
        [-0.00433369, -0.65677524, -1.111351  ],
        [ 0.12340164, -0.217235  , -0.82925814],
        [-1.0669013 , -0.05224207, -2.5979679 ],
        [ 0.14051169, -1.2223552 ,  0.23902684],
        [-0.11902201, -0.92137825,  1.7340896 ],
        [ 1.2372383 ,  0.15431535,  0.6075343 ],
        [ 0.36098602, -0.0881753 ,  1.194992  ],
        [ 2.60152   ,  0.7221756 , -0.14661384],
        [ 0.70380044,  0.83753175,  0.16950962],
        [ 0.49709606,  1.2011603 ,  1.1093777 ],
        [-0.38195267, -0.839498  ,  0.30866784],
        [ 0.34366548,  0.25972402,  1.8317744 ],
        [ 1.3211373 , -0.7143078 ,  0.7055912 ],
        [ 1.0521188 ,  0.13438833,  0.7787929 ],
        [ 0.45959

In [85]:
params_out

<tf.Tensor: shape=(2, 45, 3, 2), dtype=float32, numpy=
array([[[[-1.09353960e+00,  7.62675405e-01],
         [-3.79723191e-01,  1.66837960e-01],
         [ 1.05406845e+00,  1.57093674e-01]],

        [[-6.80275381e-01,  1.08901262e+00],
         [-1.17626369e+00, -1.63915396e-01],
         [-2.56464481e-02, -1.44412148e+00]],

        [[-9.93697643e-01,  2.43652987e+00],
         [-1.03772020e+00, -1.71515524e+00],
         [ 1.02741814e+00, -1.60907650e+00]],

        [[ 6.54187381e-01, -1.23911476e+00],
         [-1.51479161e+00,  2.06830412e-01],
         [ 1.41734397e+00,  7.25943685e-01]],

        [[ 1.50815022e+00,  1.71732533e+00],
         [-6.93225920e-01, -1.26225519e+00],
         [ 8.27834904e-01, -1.74837422e+00]],

        [[ 1.57928541e-01, -2.35672459e-01],
         [-6.65214002e-01, -8.04778337e-02],
         [-4.06945616e-01,  1.46814895e+00]],

        [[ 4.11913157e-01, -5.03803134e-01],
         [ 3.42960566e-01,  4.75717261e-02],
         [-3.47494304e-01,  1.029

In [86]:
solvent_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 45), dtype=float32, numpy=
array([[-3.4101958, -2.6110425, -2.6097736, -2.8116608, -2.7801588,
        -3.4066653, -3.303232 , -3.1462336, -3.6866484, -2.0999413,
        -3.41468  , -2.5681367, -2.6948233, -2.5127072, -2.5453188,
        -2.8804874, -3.7060275, -3.6134243, -3.9799972, -4.521641 ,
        -3.8919733, -2.515728 , -3.470159 , -3.754457 , -3.30758  ,
        -3.8181767, -3.29625  , -2.7213554, -2.4627373, -3.2935743,
        -2.9441586, -3.0752473, -3.5557275, -3.4711332, -3.3708022,
        -3.7670174, -2.432826 , -2.8359795, -4.2835073, -3.8204436,
        -2.9582253, -3.6533852, -3.844265 , -2.822627 , -4.2954655],
       [-2.2312994, -2.7905626, -3.213862 , -2.2218797, -2.8241744,
        -4.904376 , -2.6343598, -2.939527 , -3.1651888, -2.2299862,
        -2.6402698, -1.8288848, -3.7621655, -2.752579 , -4.398548 ,
        -2.4148176, -2.9977016, -2.968325 , -3.0765905, -3.8539076,
        -3.251545 , -2.7912145, -3.8337522, -3.0343   , -2.5128975

## Full decoding with more complex distributions

In [95]:
#Try again with more complicated transform function
#Which will also require a more complicated set of sampling distributions
#Don't want to use Gaussian for r, theta, and phi
#Want to use Gamma for r and von Mises for angles
#For Gamma, specify log(mean) and log(var) and convert to distribution parameters
#Must do log(mean) since mean is strictly positive, which requires mean_transform
decoder = pd.SoluteSolventDecoder(3, 2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             tfp_dist_list=[tfp.distributions.Gamma,
                                            tfp.distributions.VonMises,
                                            tfp.distributions.VonMises],
                             param_transforms=[lambda x, y: [tf.math.exp(2*x - y), tf.math.exp(x - y)],
                                               lambda x, y: [x, tf.math.exp(-y)],
                                               lambda x, y: [x, tf.math.exp(-y)]],
                             mean_transforms=[tf.math.exp, tf.identity, tf.identity],
                             coord_transform=pd.spherical_transform)

In [96]:
solute_coords = tf.random.uniform((2, 5, 3), minval=-1.5, maxval=1.5)

In [97]:
sample_out, params_out, ref_out, unused_data = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [98]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-6.57248497e-03, -3.10897827e-04, -1.20344758e-02],
        [ 2.40434051e+00, -1.63986361e+00,  7.77477026e-01],
        [ 6.72146261e-01, -2.53318548e-01, -1.76914251e+00],
        [ 8.97169113e-04,  3.71456146e-04,  1.78523064e-02],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-6.56056404e-02, -1.76155269e-02,  5.93137145e-01],
        [ 7.80308723e+00,  1.22241366e+00, -4.24284792e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-1.08422017e+00, -4.25542593e-02, -1.30555224e+00],
        [-1.20478868e-02, -4.57699299e-02,  1.44327879e-02],
        [ 6.59548640e-01,  7.07664013e-01,  9.49708939e-01],
        [ 3.69158697e+00,  4.65112269e-01,  3.01725173e+00],
        [ 6.46457911e-01, -1.22598374e+00,  3.14612770e+00],
        [-2.38418579e-07, -5.06639481e-07,  4.76837158e-07],
        [-2.07675338e-01,  1.03485972e-01,  1.34845769e+00],
        [-9.64415741e+00, -1.0241

In [99]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-1.1336217e+00, -1.1167966e+00],
         [-1.1315372e+00,  1.4715846e+00],
         [ 1.2739428e+00, -1.0234891e+00]],

        [[ 1.2194245e+00,  5.4232514e-01],
         [-6.6474214e-02, -8.2336497e-01],
         [-2.9658990e+00, -6.0646695e-01]],

        [[ 1.0523709e+00, -6.6771275e-01],
         [-5.8930683e-01, -1.7612375e+00],
         [-6.4377820e-01,  3.0045841e+00]],

        [[ 3.6556000e-01,  1.6371391e+00],
         [-2.0010090e+00,  2.7556186e+00],
         [-8.1860685e-01, -4.8097074e-01]],

        [[-4.6133184e+00,  1.6156620e+00],
         [ 3.1463413e+00,  3.3611434e+00],
         [-3.8043982e-01,  3.4105823e+00]],

        [[-3.6033040e-01, -2.7756634e+00],
         [-2.8859148e+00, -3.1065086e-01],
         [-5.5977601e-01, -1.8299392e+00]],

        [[ 2.1311455e+00,  2.9007077e-01],
         [-1.1718750e+00,  1.1855065e+00],
         [ 4.4523919e-01,  2.0323086e+00]],

        [[-2.7248104e+00, -8

In [100]:
#For convenience, added log_prob function to decoder
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-2.75223756e+00, -5.35232687e+00, -3.43959904e+00,
        -2.16167450e+00,             inf, -1.75812650e+00,
        -4.86873865e+00,             inf, -1.94272625e+00,
        -4.38273621e+00, -8.34967709e+00, -4.12830782e+00,
        -3.12700653e+00,  6.70936489e+00, -2.72840166e+00,
        -3.33263683e+00, -3.10321569e+00, -4.15060902e+00,
         1.26252640e+02, -7.52055120e+00, -5.61295509e+00,
        -6.70908165e+00, -4.06388140e+00, -5.88162374e+00,
        -6.26864719e+00, -5.24774933e+00, -4.02879667e+00,
        -4.46981573e+00, -1.27465839e+01, -3.27493882e+00,
        -5.02163658e+01,             inf, -8.94665527e+00,
                    inf, -4.91919327e+00,  6.81336761e-01,
         8.02539587e-01,             inf, -1.89832573e+01,
                    inf, -9.59176178e+01, -1.89593077e+00,
        -2.59332991e+00, -7.17555761e+00,             inf,
                    inf, -4.92819977e+00,             inf,
       

In [101]:
#Above tested generation, now test if have training data
training_data = tf.random.normal((2, 60, 3))

In [103]:
sample_out, params_out, ref_out, unused_data = decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [104]:
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 0.7498355 ,  0.42316303, -0.23963082],
        [ 0.98107064, -0.245562  , -0.40743113],
        [-0.90513355,  0.1854421 ,  1.4763688 ],
        [ 0.257954  ,  0.3442438 ,  0.61162466],
        [-1.0138341 , -1.649005  ,  0.44575483],
        [-0.8896971 , -0.27316263, -0.67733246],
        [-0.70758194,  0.40318605,  1.0580208 ],
        [-0.5152939 ,  0.06764869, -0.28447795],
        [-0.6278028 , -0.67295986, -1.2732356 ],
        [-0.92933494,  1.7294288 ,  0.22123632],
        [ 1.0625159 ,  0.24621502, -0.50846475],
        [ 0.11945935, -0.52619207, -1.350426  ],
        [ 0.1372898 ,  0.41269913, -1.7860678 ],
        [ 0.4274886 ,  0.5674834 ,  0.20325209],
        [ 2.1492496 , -1.402494  ,  1.0385101 ],
        [-0.42640242,  1.3551216 , -0.67278975],
        [-1.515974  ,  1.3386755 , -0.46348563],
        [-0.56847155, -0.72927153, -0.14825559],
        [ 1.5308512 ,  1.2210469 , -0.37276384],
        [-0.49843

In [105]:
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-6.01524532e-01,  1.55853736e+00,  1.00311446e+00],
        [-4.26400036e-01,  1.35512006e+00, -3.67278957e+00],
        [-2.58310628e+00,  2.23791218e+00,  2.75241446e+00],
        [-3.79057646e-01,  6.50641561e-01, -1.76781273e+00],
        [-1.76687491e+00, -5.92385530e-01, -8.72709453e-01],
        [-1.99885976e+00, -7.25425601e-01, -4.68929172e-01],
        [-1.43985653e+00, -2.37393627e+01,  8.03683853e+00],
        [-9.35154557e-01,  1.62057877e-02,  1.10444760e+00],
        [-4.98435616e-01, -2.64510155e-01,  2.08618593e+00],
        [ 7.74865270e-01,  1.13498425e+00,  1.03071916e+00],
        [-1.54897285e+00,  3.85943353e-02,  1.38257122e+00],
        [ 4.27847576e+00, -3.84893751e+00, -3.31641197e+00],
        [ 1.61585534e+00, -1.35300726e-01,  1.36763489e+00],
        [ 9.99089718e-01, -1.71880871e-01,  6.94878161e-01],
        [ 2.76203656e+00, -1.62849212e+00,  2.44391441e+00],
        [-4.09351081e-01, -7.7877

In [108]:
decoder.coord_transform(sample_out - ref_out)

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 3.44586849e-01,  9.04809594e-01,  2.28413478e-01],
        [ 4.34833002e+00, -5.66554308e-01,  3.08068585e+00],
        [ 2.93829465e+00,  2.77571654e+00,  7.81903148e-01],
        [ 1.60410798e+00,  8.44344378e-01,  1.96813846e+00],
        [ 5.08364975e-01, -2.81495333e+00,  1.00071847e+00],
        [ 9.69102681e-01, -2.74039268e+00,  7.95728564e-01],
        [ 2.46718941e+01, -1.58723342e+00,  1.27755392e+00],
        [ 2.69954294e-01,  8.77012730e-01,  7.40780413e-01],
        [ 1.31166685e+00, -2.48963058e-01,  4.50241208e-01],
        [ 1.95578170e+00,  3.28542471e-01,  1.39523828e+00],
        [ 9.78589714e-01, -2.38291383e+00,  7.83261657e-01],
        [ 7.96530485e+00, -6.86252832e-01,  2.09771609e+00],
        [ 9.17827725e-01,  1.08885443e+00,  3.84069145e-01],
        [ 5.59326649e-01,  2.61130500e+00,  1.24653745e+00],
        [ 2.61366296e+00, -7.38470435e-01,  7.41625071e-01],
        [ 1.08212021e+02, -1.5683

In [109]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-1.13362169e+00, -1.11679661e+00],
         [-1.13153720e+00,  1.47158456e+00],
         [ 1.27394283e+00, -1.02348912e+00]],

        [[ 1.56116164e+00,  3.77810627e-01],
         [ 2.38537639e-02, -9.58272696e-01],
         [-3.19863558e+00, -3.56653929e-01]],

        [[ 1.16153479e+00, -5.91787696e-01],
         [-5.22830009e-01, -1.93499696e+00],
         [-6.22683525e-01,  3.11711001e+00]],

        [[ 2.33022153e-01,  2.56194878e+00],
         [-2.53031635e+00,  2.90314245e+00],
         [-1.67892241e+00, -7.35425651e-01]],

        [[-4.28254461e+00,  1.64707029e+00],
         [ 2.33105850e+00,  2.58344364e+00],
         [ 9.15082276e-01,  2.68955612e+00]],

        [[-2.19704449e-01, -3.23005819e+00],
         [-2.77788568e+00, -5.42020023e-01],
         [ 8.27930689e-01, -3.18728733e+00]],

        [[ 3.18068099e+00, -1.16994226e+00],
         [-1.57188785e+00,  4.76934314e-01],
         [ 1.27417874e+00,  1.019

In [110]:
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-4.4633522e+00, -2.8965063e+00, -1.6171749e+01, -9.3339386e+00,
        -1.3116118e+01,  9.6769929e-02, -3.7670903e+00, -9.3110018e+00,
        -4.3146877e+00, -6.8092022e+00, -4.8373327e+00, -2.2564783e+00,
        -5.5699291e+00, -7.0446463e+00, -3.0317314e+00, -5.3666191e+00,
        -7.6605253e+00, -5.0785108e+00, -1.7366251e+00, -6.5960388e+00,
        -5.1534891e+00, -3.7396002e+00, -1.1838749e+01, -2.5994997e+00,
        -6.5384579e+00, -5.3810949e+00, -5.6639729e+00, -3.3938851e+00,
        -4.7727003e+00, -8.9372063e+00, -1.2575382e+01, -6.0241590e+00,
        -6.5685663e+00, -2.3635096e+01, -7.8351960e+00, -3.9160428e+00,
        -7.8820122e+01, -2.2305595e+01, -9.9335070e+00, -1.2421956e+01,
        -6.9289703e+00, -4.6397433e+00, -1.6643188e+01, -1.6275740e+01,
        -3.1314344e+00, -1.7695084e+01, -4.3475063e+01, -7.8199148e+00,
        -4.4308529e+01, -1.6813417e+02, -2.3501032e+01, -1.6340253e+02,
        -1.2564