In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from importlib import reload

In [2]:
import particle_decoding as pd

In [3]:
# Re-import particle_decoding so edits to the module are picked up without restarting the kernel
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## Transformations

In [4]:
# Batch of two identical frames, each holding the same three 3-D points
test_coords = np.tile(
    np.array([[1.0, 1.0, 1.0],
              [-1.0, 1.0, 0.0],
              [0.5, 0.5, 1.0]]),
    (2, 1, 1),
)

In [5]:
# Sanity check: the identity transform should hand the coordinates back unchanged (as a tf.Tensor)
pd.identity_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]],

       [[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]]])>

In [6]:
# Cartesian -> spherical; judging by the output, each row becomes (r, azimuthal angle, polar angle) in radians
pd.spherical_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [7]:
# Same transform, but supplying precomputed squared norms through dist_sq;
# the result should match the call that lets the function work them out itself
test_dist_sq = tf.reduce_sum(test_coords * test_coords, axis=-1)
pd.spherical_transform(test_coords, dist_sq=test_dist_sq)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [8]:
# Round trip: forward transform followed by reverse=True should recover the
# original Cartesian coordinates up to floating-point error
spherical_coords = pd.spherical_transform(test_coords)
pd.spherical_transform(spherical_coords, reverse=True)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]],

       [[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]]])>

## Masking by distance

In [9]:
# Same batch of two identical three-point frames used in the transformation tests
test_coords = np.tile(
    np.array([[1.0, 1.0, 1.0],
              [-1.0, 1.0, 0.0],
              [0.5, 0.5, 1.0]]),
    (2, 1, 1),
)

# One reference point per frame: the origin for frame 0, (0, 0, 1) for frame 1
ref_coords = np.array([[0.0, 0.0, 0.0],
                       [0.0, 0.0, 1.0]]).reshape(2, 1, 3)

In [10]:
# Keep the 2 nearest points to each reference. Per the output, the 3-tuple holds:
# neighbor coordinates relative to the reference (nearest first), the indices of
# those neighbors in test_coords, and their squared distances to the reference
pd.distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [11]:
# Unpack the three outputs of distance_mask and feed the local (reference-relative)
# coordinates to the spherical transform, passing the squared distances it already
# computed via dist_sq (presumably so they are not recomputed — confirm in the module)
local_coords, near_inds, dist_sq = pd.distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [12]:
# Also cover the case where the reference is itself one of the input points:
# frame 0 uses point 1 of test_coords, frame 1 uses point 0
ref_coords = np.array([[-1.0, 1.0, 0.0],
                       [1.0, 1.0, 1.0]]).reshape(2, 1, 3)

# With ref_included=True the reference should be skipped when picking the neighbor
pd.distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[ 1.5, -0.5,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[2],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[3.5],
        [0.5]])>)

In [13]:
# Contrast case: with ref_included=False the reference point is returned as its own
# nearest neighbor (zero offset and zero squared distance in the output)
pd.distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

## Creating probability distributions

In [14]:
# Parameters for 2 batches x 3 particles x 3 coordinates; the last axis holds the
# two Normal parameters, assembled here from separate mean and std tables
means = [[[0.0, 10.0, -10.0],
          [0.0, 0.0, 0.0],
          [5.0, 5.0, 5.0]],
         [[0.0, -10.0, 10.0],
          [0.0, 0.0, 0.0],
          [-5.0, -5.0, -5.0]]]
stds = [[[1.0, 1.0, 1.0],
         [1.0, 2.0, 10.0],
         [1.0, 2.0, 10.0]],
        [[1.0, 1.0, 1.0],
         [10.0, 2.0, 1.0],
         [10.0, 2.0, 1.0]]]
params = np.stack([means, stds], axis=-1)

# One distribution class per coordinate
dist_list = [tfp.distributions.Normal] * 3

In [15]:
# Build the joint distribution from the parameter array, using one distribution
# class from dist_list per coordinate
joint_dist = pd.create_dist(params, dist_list)

In [16]:
# joint_dist.sample() yields one tensor per coordinate; stack them on a new last
# axis to get a single (batch, particle, coordinate) tensor
# NOTE: sampling is unseeded, so the values change on every run
sample = tf.stack(joint_dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ -0.63069479   7.32036578  -9.56933454]
  [ -2.87596283  -2.28326106  -7.53382161]
  [  3.6032184    5.60191437  17.14728616]]

 [[ -0.51090521 -10.45562484   9.51574286]
  [ -7.75650514  -4.69503275   0.4550419 ]
  [  1.74670781  -4.29780466  -3.38294273]]], shape=(2, 3, 3), dtype=float64)


In [17]:
# log_prob expects the per-coordinate list form again, so unstack the last axis;
# the result has one log-probability per (batch, particle)
joint_dist.log_prob(tf.unstack(sample, axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[ -6.63865971, -10.82358144,  -7.51111771],
       [ -3.10837715,  -8.91231287,  -7.34921009]])>

In [18]:
# Independent check of create_dist: recompute by hand the log-density of the sample
# at [last batch, particle 1] as a sum of three independent Normal log-pdfs, taking
# params[..., 0] as the mean and params[..., 1] as the standard deviation
check_loc = params[-1, 1, :, 0]
check_scale = params[-1, 1, :, 1]
check_resid = sample[-1, 1, :] - check_loc
manual_log_prob = np.sum(-0.5*(check_resid**2)/(check_scale**2) - np.log(check_scale) - 0.5*np.log(2.0*np.pi))

# Fail loudly if the hand computation and the joint distribution disagree, rather
# than relying on comparing two printed numbers by eye
np.testing.assert_allclose(manual_log_prob,
                           joint_dist.log_prob(tf.unstack(sample, axis=-1))[-1, 1])
manual_log_prob

-8.91231286702834

## Solvation neural networks

In [19]:
# The network input can be either solute or solvent coordinates, with the other
# kind passed separately as extra_coords
# Start without augmented input (only relevant for the first network, which
# predicts a layer of solvent for the first solute)
# (8, 3) with out_event_dims=2 shows up in the output shape below as (batch, 8, 3, 2)
solute_net = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=False)

In [20]:
# Random stand-in for solute coordinates, shape (1, 2, 3) — presumably
# (batch, particles, xyz); unseeded, so downstream outputs differ between runs
solute_coords = tf.random.normal((1, 2, 3))

In [21]:
# Without sampled_input the network returns two tensors, both of shape (1, 8, 3, 2)
# NOTE: this rebinds `params`, which until now was the NumPy array used in the
# probability-distribution tests
params, shifts = solute_net(solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.05343146 -0.01744708]
   [-0.01111172  0.10000312]
   [ 0.15373015 -0.05767399]]

  [[ 0.0199598   0.20465696]
   [ 0.16297239  0.00276308]
   [ 0.07671326  0.04065618]]

  [[-0.07596047 -0.15887283]
   [-0.0653896  -0.08447456]
   [-0.08472626  0.24786235]]

  [[ 0.08911207 -0.06785001]
   [ 0.19760217  0.2200958 ]
   [ 0.16323312  0.03194129]]

  [[ 0.09130453 -0.05914907]
   [-0.09558204  0.10732365]
   [-0.20040767  0.18727888]]

  [[ 0.03559308 -0.08360106]
   [ 0.02835881 -0.03329094]
   [ 0.17418751  0.11739578]]

  [[-0.01265201  0.10476752]
   [-0.1101061  -0.03044134]
   [-0.15527336  0.2520959 ]]

  [[ 0.05101128 -0.06426936]
   [ 0.08353112  0.02236869]
   [-0.16487661  0.10023498]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[ 0.05465062  0.40191245]
   [ 0.4882277  -0.02144564]
   [-0.15812159 -0.5027362 ]]

  [[ 0.53503436 -0.36344436]
   [ 0.12764154 -0.32422307]
   [ 0.17383929 -0.60133934]]

  [[-0.84924406  0.31073946]
   [ 0.95487845 -0.218

In [22]:
# Turn the network output into Normal distributions and draw one sample per particle.
# param_transforms presumably maps onto the two entries of the last axis in order:
# the first is used as-is, the second goes through exp(0.5*x), which keeps it
# positive (reads like a log-variance -> std conversion — confirm in create_dist)
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3,
                      param_transforms=[tf.identity, lambda x: tf.math.exp(0.5*x)])
# Stack the per-coordinate samples into one (batch, particle, coordinate) tensor
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[-0.558955    0.578966   -1.3150043 ]
  [ 0.577973    0.78255916  1.1857    ]
  [ 0.8864244   0.35352528 -0.37875664]
  [ 2.618908   -0.7341408   0.48986787]
  [ 0.85061383 -0.42614132 -1.8769361 ]
  [-0.4467808   2.1490645   1.5516778 ]
  [ 0.5655447  -1.5609803  -1.3636532 ]
  [ 0.61894566  0.81743     0.8687078 ]]], shape=(1, 8, 3), dtype=float32)


In [23]:
# Autoregressive property: the shifts for the first particle (first block of 3)
# must not depend on sampled_input, whereas the shifts of later particles should
# change once a sample is supplied
new_shifts = solute_net(solute_coords, sampled_input=sample)
print(new_shifts)

# Enforce the expectation instead of comparing two printed tensors by eye
np.testing.assert_allclose(new_shifts[:, 0], shifts[:, 0], rtol=1e-5, atol=1e-6)
assert not np.allclose(new_shifts[:, 1:], shifts[:, 1:]), \
    "shifts after the first particle should respond to sampled_input"

tf.Tensor(
[[[[ 0.05465062  0.40191245]
   [ 0.4882277  -0.02144564]
   [-0.15812159 -0.5027362 ]]

  [[ 0.53911024 -0.36447403]
   [ 0.12695514 -0.32888848]
   [ 0.172545   -0.6029239 ]]

  [[-0.850589    0.30576348]
   [ 0.96357864 -0.19923142]
   [-0.16083513 -0.22859132]]

  [[ 0.04739184  0.01950971]
   [ 0.02184865  0.8729154 ]
   [ 0.30949178  0.23660335]]

  [[ 0.23142959  0.38250232]
   [-0.32069373 -0.44441321]
   [-0.0651332   0.14782232]]

  [[ 0.12461222  0.16260451]
   [ 0.14633012  0.1290909 ]
   [-0.27555257  0.05995542]]

  [[-0.15914941  0.23213677]
   [-0.01835339 -0.3679172 ]
   [ 0.33140105  0.4842227 ]]

  [[-0.01002502 -0.3434873 ]
   [-0.12823251 -0.5214914 ]
   [ 0.28583765  0.63640076]]]], shape=(1, 8, 3, 2), dtype=float32)


In [24]:
# Same architecture, but with augment_input=True so that a second set of
# coordinates (here the solvent) can be supplied via extra_coords; the roles of
# solute and solvent could be swapped if wanted
solute_net_solv = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [25]:
# Random stand-in for solvent coordinates, shape (1, 6, 3) — presumably
# (batch, particles, xyz); unseeded, so values differ between runs
solvent_coords = tf.random.normal((1, 6, 3))

In [26]:
# First call of the augmented network; this is what builds its layers with the
# extra_coords input included (see the error demonstrated a few cells below)
# NOTE: rebinds `params` and `shifts` from the non-augmented network above
params, shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.08991517  0.32618904]
   [-0.06456754 -0.34168598]
   [ 0.15339841 -0.19027872]]

  [[-0.18956217  0.01222431]
   [ 0.13867798  0.43863368]
   [ 0.27719817  0.03524739]]

  [[-0.4396702  -0.01883183]
   [-0.15864518 -0.21449885]
   [ 0.02030725  0.18788826]]

  [[-0.02327124  0.16705602]
   [ 0.2512965  -0.5137874 ]
   [-0.13282327  0.10775443]]

  [[-0.02390646  0.46070448]
   [ 0.38383532 -0.26624614]
   [-0.09860896 -0.0698304 ]]

  [[ 0.07231984 -0.2696631 ]
   [-0.36812055 -0.08534073]
   [-0.14341965  0.12840396]]

  [[ 0.2580382   0.13138121]
   [ 0.04654905 -0.47386828]
   [ 0.17405878 -0.0088684 ]]

  [[-0.18014894 -0.19435543]
   [-0.20022985 -0.01195858]
   [ 0.18734047 -0.17433706]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[ 8.6616343e-01  6.9703162e-04]
   [-4.9962959e-01  1.9099520e+00]
   [-6.4729393e-01  1.7891903e+00]]

  [[ 4.2793494e-01 -6.2408465e-01]
   [-1.4102942e-01  6.7031658e-01]
   [ 5.3797752e-02  3.2218111e-01]]

  [[ 1.1774898e

In [27]:
# Same sampling recipe as for the non-augmented network: first parameter used
# as-is, second passed through exp(0.5*x) so it stays positive
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3,
                      param_transforms=[tf.identity, lambda x: tf.math.exp(0.5*x)])
# Stack the per-coordinate samples into one (batch, particle, coordinate) tensor
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ 1.8568175   1.2206428  -2.5038223 ]
  [-0.1255145   2.1375322  -0.07818604]
  [ 0.35163334  0.3742599   0.49022302]
  [ 1.2161851  -0.7739396   0.29444635]
  [ 4.865453    0.7365566  -0.5921209 ]
  [-1.0386637   0.67759234 -0.2098292 ]
  [-0.7775736  -0.091086   -1.6804595 ]
  [-1.2873342  -0.3309005  -1.4387039 ]]], shape=(1, 8, 3), dtype=float32)


In [28]:
# Autoregressive check for the augmented network: supplying sampled_input must
# leave the first particle's shifts untouched while later particles respond to it
new_shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords, sampled_input=sample)
print(new_shifts)

# Enforce the expectation instead of comparing two printed tensors by eye
np.testing.assert_allclose(new_shifts[:, 0], shifts[:, 0], rtol=1e-5, atol=1e-6)
assert not np.allclose(new_shifts[:, 1:], shifts[:, 1:]), \
    "shifts after the first particle should respond to sampled_input"

tf.Tensor(
[[[[ 8.6616343e-01  6.9703162e-04]
   [-4.9962959e-01  1.9099520e+00]
   [-6.4729393e-01  1.7891903e+00]]

  [[ 4.2172086e-01 -6.2437671e-01]
   [-1.2322435e-01  6.4647365e-01]
   [ 7.1760669e-02  3.2469419e-01]]

  [[ 1.1563253e+00 -7.8922087e-01]
   [ 4.5641485e-01  1.6661456e-01]
   [ 3.7734485e-01  9.2000648e-02]]

  [[-2.2684656e-01 -2.8150874e-01]
   [-1.2950764e+00  6.2683070e-01]
   [ 1.0175768e+00 -8.9644536e-02]]

  [[-1.4596611e-01  2.2316046e+00]
   [-1.0375872e-01 -1.4730238e+00]
   [ 1.0070620e+00  4.3811649e-01]]

  [[-1.3103859e+00  7.1715212e-01]
   [ 1.9439280e-02  5.3254253e-01]
   [ 2.8739139e-02 -1.2432921e-01]]

  [[-8.0750114e-01  7.1779042e-02]
   [-1.7204028e-01  9.7322905e-01]
   [-6.8046016e-01 -4.8173431e-01]]

  [[-8.9925218e-01  4.0616366e-01]
   [-5.1969230e-01 -2.4518207e-01]
   [-5.7692277e-01  9.2006624e-01]]]], shape=(1, 8, 3, 2), dtype=float32)


In [29]:
# Now that it's been called (so built) with extra_coords provided, extra_coords
# must always be provided to this layer. Omitting it is expected to raise a
# ValueError; catch and show it so the notebook still runs top to bottom.
try:
    new_shifts = solute_net_solv(solute_coords, sampled_input=sample)
except ValueError as err:
    print(f"ValueError: {err}")
else:
    print(new_shifts)

ValueError: Input 0 of layer dense_18 is incompatible with the layer: expected axis -1 of input shape to have value 24 but received input with shape (1, 6)

In [30]:
# Now pretend we're working with a central solvent particle and adding one more
# particle (conditioned on the solutes and solvent already around it), hence the
# (1, 3) output: a single particle with 3 coordinates
solvent_net = pd.SolvationNet((1, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [31]:
# Roles flipped relative to the solute network: solvent coordinates are the main
# input and the solute coordinates go in as extra_coords
# NOTE: rebinds `params` and `shifts` once more
params, shifts = solvent_net(solvent_coords, extra_coords=solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.44500983  0.5446396 ]
   [ 0.18715477 -0.21092294]
   [-0.18116441 -0.14093411]]]], shape=(1, 1, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.5137334  -0.37528616]
   [-3.3742995  -2.2601662 ]
   [-0.17854308  1.4736085 ]]]], shape=(1, 1, 3, 2), dtype=float32)


## Full particle decoder

In [32]:
# Full decoder using 3 nearest solute and 3 nearest solvent neighbors.
# TODO(review): the meaning of the positional arguments 3 and 2 is not visible in
# this notebook — confirm against ParticleDecoder and consider passing them by keyword
decoder = pd.ParticleDecoder(3, 2, k_solute_neighbors=3, k_solvent_neighbors=3)

In [33]:
# New random solute input for the decoder, shape (2, 5, 3) — presumably
# (batch, solute particles, xyz); unseeded. Rebinds `solute_coords` from the
# solvation-network tests above
solute_coords = tf.random.normal((2, 5, 3))

In [34]:
# Generation mode (no training data): the decoder returns sampled coordinates,
# the distribution parameters, and the reference coordinates; only shapes are shown
sample_out, params_out, ref_out = decoder(solute_coords)
print(*[out.shape for out in (sample_out, params_out, ref_out)])

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [35]:
# Sampled positions expressed relative to their reference coordinates
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 2.59587079e-01, -1.78246009e+00,  1.59993303e+00],
        [-2.00532389e+00,  6.86201811e-01, -2.21410656e+00],
        [-3.14071178e+00, -1.01904881e+00, -2.22382021e+00],
        [ 5.69737911e+00, -3.24758053e+00, -5.43871117e+00],
        [-2.15248823e+00, -4.73428679e+00, -1.13199625e+01],
        [-4.48878002e+00, -4.24166679e+00, -2.33200908e-01],
        [ 3.49597359e+00, -2.27120829e+00,  1.25757766e+00],
        [-2.08799458e+00, -4.16380024e+00,  2.40583467e+00],
        [-7.33402967e+00, -2.39053702e+00,  3.35251451e-01],
        [-1.87421727e+00,  8.14388156e-01,  3.31929088e-01],
        [ 1.83158445e+00, -4.05200052e+00,  3.18570161e+00],
        [-2.74683237e-01,  7.18853712e-01,  2.78574562e+00],
        [ 6.56892836e-01, -1.34701753e+00, -2.22184324e+00],
        [-2.08578348e+00, -1.45382893e+00, -3.34285450e+00],
        [-4.32845688e+00, -3.85709000e+00,  2.18118250e-01],
        [ 1.56662261e+00, -2.7849

In [36]:
# Distribution parameters from the generation run: 2 values for each of the
# 3 coordinates of every generated particle
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-2.09930122e-01, -6.92374706e-01],
         [-1.55831778e+00,  2.16812992e+00],
         [ 1.12965584e+00, -3.02582324e-01]],

        [[-1.33067667e-01,  2.18980885e+00],
         [ 3.69626641e-01, -7.96890318e-01],
         [ 1.01741540e+00,  2.29061389e+00]],

        [[-1.70731068e-01,  8.40314627e-01],
         [-1.44030023e+00, -4.61926997e-01],
         [ 1.84366786e+00,  1.16136181e+00]],

        [[ 5.16262102e+00,  2.91189969e-01],
         [-3.45778775e+00, -9.45850611e-01],
         [ 7.68125117e-01,  2.53853130e+00]],

        [[-2.44260740e+00, -1.82026410e+00],
         [-4.40699959e+00, -2.13102770e+00],
         [-5.27596617e+00,  2.64644742e+00]],

        [[-3.66289282e+00,  3.12677574e+00],
         [-3.46074724e+00,  1.58756196e+00],
         [-2.15412569e+00,  4.11534929e+00]],

        [[ 4.52024746e+00,  6.84665024e-01],
         [-2.86772156e+00, -6.50373459e-01],
         [ 4.91524935e-01,  2.796

In [37]:
# For convenience, the decoder has a get_log_probs method; it returns one
# log-probability per generated particle
# NOTE(review): the saved output contains a nan and a large positive entry (16.35),
# which suggests a numerical problem for some samples — worth investigating
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[ -3.7162056 ,  -5.434373  ,  -8.160625  ,  -5.383893  ,
         -4.110385  ,  -7.2790594 ,  -4.795574  ,  -4.9891686 ,
         -6.841637  ,  -5.5094543 ,  -2.9828117 ,  -4.8690147 ,
         -2.8658853 ,  -3.9101076 ,  -8.185442  ,  -5.489682  ,
         -4.2537727 ,  -5.0451922 ,   0.47895908,  -1.3443118 ,
         -6.2717085 ,  -2.6589646 ,  -7.361408  , -10.722578  ,
         -4.972938  ,  -5.123552  ,  -1.8480511 ,  -0.07007366,
         -2.979459  ,  -4.9660606 ,  -4.1865935 ,  -6.3227124 ,
         -7.1134663 ,  -4.8312716 ,  -8.864256  ,  -7.8746586 ,
         -7.081497  ,  -5.9228315 , -10.651675  ,  -2.3836927 ,
         -5.567163  ,  -1.4994144 ,  -5.755042  ,  -3.5966048 ,
         -7.5979285 ,  -6.7693973 ,  -9.030688  ,  -7.5848885 ,
         -4.9578867 , -12.774134  , -10.246046  ,  -5.979232  ,
         -6.6009474 ,  16.34864   ,  -3.5216506 ,          nan,
         -4.90421   ,  -5.7627025 ,  -5.5139985 ,  -8.6

In [38]:
# The cells above tested generation; now test the path where training data is
# supplied. Random stand-in with the same shape as the decoder's sample output
# (2, 60, 3); unseeded
training_data = tf.random.normal((2, 60, 3))

In [39]:
# Training mode: same decoder call, but with train_data supplied; the three
# outputs overwrite those of the generation run and keep the same shapes
sample_out, params_out, ref_out = decoder(solute_coords, train_data=training_data)
print(*[out.shape for out in (sample_out, params_out, ref_out)])

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [40]:
# Show the training data for comparison with sample_out in the next cell
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-0.42785975, -0.7624844 ,  1.2198808 ],
        [-0.16642877, -2.3913453 ,  0.07976497],
        [-0.19348079, -0.5546516 ,  1.4156916 ],
        [ 0.19515382,  1.485254  ,  0.38537183],
        [-0.7376194 , -1.1319193 , -1.0823684 ],
        [-0.5561684 , -0.12261876,  0.08801163],
        [-0.5724013 ,  2.9457366 , -0.47197008],
        [ 0.24609105,  0.14984937, -0.6331455 ],
        [-1.9387635 ,  1.8798894 , -0.2645729 ],
        [-0.79543096, -0.23932967,  0.9587716 ],
        [ 0.7815791 ,  1.0798616 , -0.03349392],
        [-0.09923479, -0.7703586 , -0.793145  ],
        [ 1.2875804 ,  0.22833115,  0.21169512],
        [ 1.1668521 ,  0.12103882, -0.9839402 ],
        [-0.18979673,  0.32948118, -1.7640039 ],
        [-0.30397427, -0.8267363 , -0.66971993],
        [ 0.5779909 ,  1.3371624 ,  0.55148214],
        [ 0.2226666 ,  0.06426026,  0.7262613 ],
        [ 0.22647108,  0.9443663 ,  1.104705  ],
        [ 0.03900

In [41]:
# With train_data supplied, several of the rows shown match rows of training_data
# in a different order (e.g. the first row here equals training_data[0, 1] up to
# float rounding)
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-0.16642874, -2.3913453 ,  0.07976496],
        [-1.175374  , -0.91729075,  1.1604546 ],
        [-0.3084364 , -1.9193932 , -0.05191147],
        [ 0.6881696 , -0.9082798 ,  0.4718012 ],
        [-1.8997829 , -0.22091854, -0.9894229 ],
        [-1.9905825 ,  0.43481147, -0.27657443],
        [-0.3528995 , -0.7673694 ,  0.67086715],
        [-1.4533603 , -0.09526485, -0.31282353],
        [-1.8507493 , -1.1827408 , -0.40269694],
        [ 0.15134084,  0.42607367,  0.6000903 ],
        [-0.19348073, -0.5546516 ,  1.4156916 ],
        [ 1.6176678 ,  0.34661734, -0.06845785],
        [ 1.2875803 ,  0.22833121,  0.21169512],
        [-0.49906027, -0.7017304 , -0.91984546],
        [-1.099215  ,  0.6068254 ,  0.7844246 ],
        [-0.42785975, -0.7624844 ,  1.2198808 ],
        [-0.79543096, -0.23932964,  0.9587716 ],
        [-0.08785272, -0.11392468,  1.0688022 ],
        [ 0.60379505, -1.280178  , -0.86590904],
        [-1.93876

In [42]:
# Training-mode outputs expressed relative to their reference coordinates
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 5.6954354e-01, -6.4043558e-01,  3.1394115e-01],
        [-4.3940175e-01,  8.3361894e-01,  1.3946308e+00],
        [ 4.2753589e-01, -1.6848350e-01,  1.8226472e-01],
        [ 1.5500574e+00, -2.0019417e+00,  4.0143937e-01],
        [-1.0378951e+00, -1.3145804e+00, -1.0597848e+00],
        [-1.1286947e+00, -6.5885043e-01, -3.4693626e-01],
        [ 9.7289318e-01, -6.7610025e-01,  6.7657202e-01],
        [-1.2756765e-01, -3.9957315e-03, -3.0711865e-01],
        [-5.2495658e-01, -1.0914717e+00, -3.9699206e-01],
        [-5.0659722e-01,  1.1586285e+00,  5.1886719e-01],
        [-8.5141879e-01,  1.7790318e-01,  1.3344685e+00],
        [ 9.5972973e-01,  1.0791721e+00, -1.4968097e-01],
        [ 1.0086298e+00, -1.0047638e+00, -2.6345447e-02],
        [-7.7801079e-01, -1.9348254e+00, -1.1578860e+00],
        [-1.3781655e+00, -6.2626964e-01,  5.4638404e-01],
        [-2.6143101e-01,  1.6288610e+00,  1.1401159e+00],
        [ 3.7994307e

In [43]:
# Parameters from the training-mode run; in the saved output the first particle's
# parameters equal those of the generation run, while later particles differ
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-2.09930122e-01, -6.92374706e-01],
         [-1.55831778e+00,  2.16812992e+00],
         [ 1.12965584e+00, -3.02582324e-01]],

        [[-6.25643134e-02,  2.19550633e+00],
         [ 3.74187052e-01, -7.92216480e-01],
         [ 9.76022601e-01,  2.28536820e+00]],

        [[-1.32504940e-01,  8.95510912e-01],
         [-1.40616453e+00, -4.03023303e-01],
         [ 1.83116710e+00,  1.11648214e+00]],

        [[ 1.77213144e+00, -3.57632041e-02],
         [-3.75690627e+00, -1.05026290e-02],
         [ 2.79838115e-01, -5.86498082e-01]],

        [[-2.09946609e+00,  1.95201814e+00],
         [-2.25395155e+00, -1.72324443e+00],
         [-2.41547441e+00,  5.51177204e-01]],

        [[-1.29563391e+00,  3.18254209e+00],
         [-1.18598413e+00, -5.66563010e-03],
         [ 7.69680589e-02,  1.55898786e+00]],

        [[ 9.66130018e-01,  1.95487738e-02],
         [-1.57097673e+00,  5.80088019e-01],
         [ 4.86190110e-01,  1.184

In [44]:
# Log-probability of each training-mode output particle under the predicted
# distributions (no nan in the visible part of the saved output, unlike the generation run)
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[  -4.4489527,   -4.851021 ,   -5.216556 ,   -4.03549  ,
          -6.2283096,   -5.2839494,   -3.8785415,   -4.0942726,
          -4.0884147,   -2.1732135,  -10.851208 ,   -6.803324 ,
          -4.022485 ,   -3.0416205,   -5.1067777,   -5.159482 ,
          -4.452944 ,   -4.000231 ,   -4.41047  ,  -35.78415  ,
          -5.5126457,   -3.8589106,   -7.664651 ,   -4.863558 ,
          -3.1869595,   -4.934106 ,  -25.187737 ,   -9.815769 ,
          -5.8228216,   -3.2604923,  -33.913292 ,   -3.3540692,
         -11.815083 ,   -9.553243 ,   -1.621703 ,   -3.543353 ,
          -3.5333862,   -1.9590777,   -4.5257826,   -3.3852484,
          -3.6043162,   -4.1136065,   -3.1585479,   -2.178435 ,
          -4.2838697,   -4.652062 ,   -5.618073 ,  -11.733508 ,
         -12.120133 ,   -5.4675255,   -4.6535106,   -4.7930765,
          -5.668269 ,  -10.200572 ,   -9.043009 ,   -5.2856846,
          -7.781201 ,  -14.701249 ,   -8.860632 ,  -17.