In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from importlib import reload

In [2]:
import particle_decoding as pd

In [3]:
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## Transformations

In [4]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

In [5]:
pd.identity_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]],

       [[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]]])>

In [6]:
pd.spherical_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [7]:
pd.spherical_transform(test_coords, dist_sq=tf.reduce_sum(test_coords * test_coords, axis=-1))

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [8]:
pd.spherical_transform(pd.spherical_transform(test_coords), reverse=True)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]],

       [[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]]])>

## Masking by distance

In [9]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

In [10]:
distance_mask = pd.DistanceMask()

In [11]:
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [12]:
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [13]:
#Need to also test if ref included in input coordinates
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[ 1.5, -0.5,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[2],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[3.5],
        [0.5]])>)

In [14]:
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

In [15]:
#Now try distances with a periodic simulation box
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

distance_mask = pd.DistanceMask(box_lengths=np.array([3.0, 3.0, 3.0]))

In [16]:
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [17]:
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [18]:
#Need to also test if ref included in input coordinates
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[-1. ,  0. ,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[0],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[2. ],
        [0.5]])>)

In [19]:
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

## Creating probability distributions

In [20]:
params = np.array([[[[0.0, 1.0], [10.0, 1.0], [-10.0, 1.0]],
                    [[0.0, 1.0], [0.0, 2.0], [0.0, 10.0]],
                    [[5.0, 1.0], [5.0, 2.0], [5.0, 10.0]]],
                   [[[0.0, 1.0], [-10.0, 1.0], [10.0, 1.0]],
                    [[0.0, 10.0], [0.0, 2.0], [0.0, 1.0]],
                    [[-5.0, 10.0], [-5.0, 2.0], [-5.0, 1.0]]]])

dist_list = [tfp.distributions.Normal, tfp.distributions.Normal, tfp.distributions.Normal]

In [21]:
joint_dist = pd.create_dist(params, dist_list)

In [22]:
sample = tf.stack(joint_dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ -1.80460269  11.46820636 -10.47440171]
  [  1.50618439  -2.54008052  -1.12336689]
  [  5.19470539   6.37537853  -4.60627149]]

 [[ -0.63418889  -8.93917175   9.22914558]
  [  2.57862519  -0.72715157  -0.24517709]
  [-11.11769958  -4.32197956  -4.48107333]]], shape=(2, 3, 3), dtype=float64)


In [23]:
joint_dist.log_prob(tf.unstack(sample, axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-5.57545448, -7.69965449, -6.46936349],
       [-3.81769993, -5.88194399, -6.13178552]])>

In [24]:
np.sum(-0.5*((sample[-1, 1, :] - params[-1, 1, :, 0])**2)/(params[-1, 1, :, 1]**2) - np.log(params[-1, 1, :, 1]) - 0.5*np.log(2.0*np.pi))

-5.881943990680755

## Solvation neural networks

In [25]:
#Input can be solute or solvent coordinates, with extra_inputs being the other
#Start by testing without augmented input (only relevant for first network to predict layer of solvent for first solute)
solute_net = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=False)

In [26]:
solute_coords = tf.random.normal((1, 2, 3))

In [27]:
params, shifts = solute_net(solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.26300415 -0.04491773]
   [ 0.01559608 -0.04290709]
   [ 0.09154715 -0.07192329]]

  [[-0.00531513  0.02146923]
   [-0.04789436  0.20241836]
   [-0.01482152  0.11503346]]

  [[ 0.09345279 -0.06281184]
   [ 0.18041205 -0.01978392]
   [ 0.02940929 -0.07436828]]

  [[-0.06230192 -0.12048668]
   [ 0.15268534  0.15264048]
   [ 0.11495566 -0.08427886]]

  [[ 0.0935947   0.06728826]
   [-0.08921711  0.00637102]
   [ 0.08959395  0.115782  ]]

  [[-0.21189848  0.00038444]
   [ 0.06282714  0.15107004]
   [ 0.12764709 -0.32938448]]

  [[-0.19525407  0.13234133]
   [ 0.05902193 -0.08482008]
   [-0.03620581 -0.0980647 ]]

  [[ 0.10565148 -0.06509417]
   [-0.1028333   0.03888365]
   [-0.21221206  0.01687928]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[ 0.49415493  0.17960471]
   [ 0.2806256   0.61243856]
   [-0.08388991 -0.12048071]]

  [[-0.05935885  0.01016361]
   [ 0.68902546 -0.774578  ]
   [ 0.8546278   0.41355968]]

  [[ 0.1572334   0.24603935]
   [-1.0171     -0.092

In [28]:
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3,
                      param_transforms=[tf.identity, lambda x: tf.math.exp(0.5*x)])
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[-1.3431163   1.2904217   1.4851311 ]
  [ 0.5081883  -0.15877563 -0.04802763]
  [ 0.11033186  0.7790389   0.16883728]
  [-1.6656454   0.9140979   1.0123812 ]
  [ 1.3491056   0.50307786  0.3068582 ]
  [ 1.3837591  -1.1227186   0.5600258 ]
  [ 0.25876427 -0.5565603   0.68984276]
  [ 0.19406873 -0.57859665 -0.599913  ]]], shape=(1, 8, 3), dtype=float32)


In [29]:
#First set of 3 (first particle) should not change shifts, even with different input due to autoregressive
#All others should change, though
new_shifts = solute_net(solute_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[ 4.94154930e-01  1.79604709e-01]
   [ 2.80625612e-01  6.12438560e-01]
   [-8.38899091e-02 -1.20480709e-01]]

  [[-6.08667508e-02  2.02550292e-02]
   [ 6.87840223e-01 -7.79687464e-01]
   [ 8.45931768e-01  4.23155040e-01]]

  [[ 1.58431262e-01  2.68051893e-01]
   [-1.01950407e+00 -7.72175267e-02]
   [-4.36457157e-01  5.72175145e-01]]

  [[-8.31106901e-01 -7.97481686e-02]
   [ 5.89055955e-01  2.02732414e-01]
   [ 7.79898226e-01  8.86875272e-01]]

  [[-4.99849409e-01  1.76101923e-02]
   [-2.16020614e-01 -5.33889771e-01]
   [-4.37884390e-01 -3.51978481e-01]]

  [[ 4.15069938e-01  4.69506025e-01]
   [ 5.71261495e-02  7.07432628e-05]
   [ 2.90810406e-01 -1.12195916e-01]]

  [[-1.24819875e-01  1.50045097e-01]
   [-4.24999595e-01 -7.93636441e-01]
   [-1.59736082e-01 -2.40971312e-01]]

  [[-3.94103199e-01 -1.41889465e+00]
   [ 1.75889656e-01 -4.85093653e-01]
   [-7.87071407e-01  1.76750004e-01]]]], shape=(1, 8, 3, 2), dtype=float32)


In [30]:
#And now with solvent inputs (remember, could flip if wanted)
solute_net_solv = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [31]:
solvent_coords = tf.random.normal((1, 6, 3))

In [32]:
params, shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.01299383 -0.16666514]
   [-0.22981699  0.16155532]
   [-0.3253996   0.02511324]]

  [[ 0.1638951  -0.23893051]
   [ 0.26900056  0.11380103]
   [ 0.08642846 -0.07228004]]

  [[ 0.22516319 -0.30298397]
   [ 0.19191541  0.4650655 ]
   [ 0.25404316 -0.09493596]]

  [[ 0.05824015  0.11704068]
   [-0.05323146 -0.04289784]
   [-0.19894265 -0.00670012]]

  [[ 0.06362901  0.010835  ]
   [-0.03701606  0.32602507]
   [ 0.2718621  -0.08033212]]

  [[-0.09240989  0.10615189]
   [-0.20541821  0.16596374]
   [-0.26124695  0.3199701 ]]

  [[ 0.2704656  -0.2546274 ]
   [-0.00376546 -0.07360049]
   [ 0.08219222 -0.10672975]]

  [[-0.08374881 -0.31508148]
   [ 0.00753965 -0.07012221]
   [-0.26083344 -0.3678459 ]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-4.0948918e-01 -1.2868618e+00]
   [-1.3320768e+00  9.8482639e-02]
   [ 1.9989736e-02 -2.7102521e-01]]

  [[ 5.3926492e-01  1.5616672e+00]
   [-2.2994334e-01  1.6195875e+00]
   [ 3.8140300e-01 -5.6129503e-01]]

  [[ 1.2284687e

In [33]:
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3,
                      param_transforms=[tf.identity, lambda x: tf.math.exp(0.5*x)])
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[-0.35613245  1.264949   -0.50865734]
  [ 1.9285396  -0.02806605  0.9506392 ]
  [-0.6086298  -0.10444391  0.9364103 ]
  [-0.02912618  2.0028615   1.4573545 ]
  [-1.1689296   1.7372679   0.8193018 ]
  [-1.3703885  -1.6392976  -2.9117887 ]
  [ 2.0111156  -0.94677645  0.5545196 ]
  [-0.2767012   0.9017415  -0.932635  ]]], shape=(1, 8, 3), dtype=float32)


In [34]:
new_shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[-0.40948918 -1.2868618 ]
   [-1.3320768   0.09848264]
   [ 0.01998974 -0.2710252 ]]

  [[ 0.54069465  1.5628897 ]
   [-0.22852719  1.617799  ]
   [ 0.3844308  -0.5609374 ]]

  [[ 0.10627145 -0.45032242]
   [ 1.2721782   0.45240206]
   [ 0.74380195  0.31356728]]

  [[-0.29748812 -0.361576  ]
   [ 1.0889672  -0.70162094]
   [ 1.151593   -0.0942229 ]]

  [[ 0.48315272  0.8535464 ]
   [ 0.5731375  -0.02188936]
   [-0.36169398  0.139644  ]]

  [[-0.20386034  0.09887603]
   [-1.3147835  -0.56878173]
   [-0.26400772 -0.92518926]]

  [[ 1.7790726  -1.0535601 ]
   [-0.38271642  0.7601023 ]
   [ 0.8445685  -1.6633186 ]]

  [[-0.0816721  -0.8495386 ]
   [ 0.5954318   0.3435803 ]
   [ 0.11437541 -1.511545  ]]]], shape=(1, 8, 3, 2), dtype=float32)


In [35]:
#Now that it's been called (so built) with extra_coords provided, must provide always with this layer
new_shifts = solute_net_solv(solute_coords, sampled_input=sample)
print(new_shifts)

ValueError: Input 0 of layer dense_18 is incompatible with the layer: expected axis -1 of input shape to have value 24 but received input with shape (1, 6)

In [36]:
#Now pretend we're working with a central solvent particle and adding one more
#(conditioned on other solutes and solvent already around)
solvent_net = pd.SolvationNet((1, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [37]:
params, shifts = solvent_net(solvent_coords, extra_coords=solute_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.88141865 -0.3517618 ]
   [-0.56746614 -0.27149877]
   [-0.4007211  -0.30483955]]]], shape=(1, 1, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.7444008   0.20916139]
   [-0.77712953 -0.86723846]
   [-1.5608668  -3.0933537 ]]]], shape=(1, 1, 3, 2), dtype=float32)


## Full particle decoder

In [51]:
decoder = pd.ParticleDecoder(3, 2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3)

In [52]:
solute_coords = tf.random.normal((2, 5, 3))

In [53]:
sample_out, params_out, ref_out = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [54]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 8.8234580e-01, -3.8947886e-01,  1.4511508e+00],
        [ 2.9156053e-01,  3.0945319e-01,  7.1341336e-01],
        [-1.6083878e+00,  7.4755245e-01,  1.4846696e+00],
        [-3.0983088e+00, -1.3124341e+00, -7.8231931e-02],
        [ 7.2886348e-03, -2.1157482e+00, -1.3389307e+00],
        [-4.4196048e-01, -2.2615576e-01, -1.1380413e+00],
        [ 1.0045214e+00, -6.4560354e-01,  2.1094420e+00],
        [-1.0049415e+00,  1.4284598e+00, -1.9368618e+00],
        [ 1.6356500e+00, -1.7617359e+00, -2.8623223e-01],
        [ 7.2864485e-01, -2.0122251e-01,  9.7565842e-01],
        [-7.6967323e-01,  2.5229597e-01, -1.1702856e+00],
        [-1.2146485e-01,  2.8170800e-01,  1.4449694e+00],
        [-1.9597758e+00,  2.1649418e+00, -6.2073970e-01],
        [ 1.0124104e+00, -1.2154644e+00, -2.1656168e+00],
        [-2.0964320e+00,  2.9280782e+00,  2.9226470e-01],
        [-2.9261820e+00, -1.5752052e+00, -1.0938907e-01],
        [ 1.3262047e

In [55]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[ 2.24267662e-01,  4.08133417e-02],
         [ 5.49771786e-01, -1.02569066e-01],
         [ 4.34912667e-02,  1.32377994e+00]],

        [[ 3.48392367e-01,  3.09946418e-01],
         [-7.34403580e-02, -1.07901558e-01],
         [ 3.66896302e-01,  7.32461810e-01]],

        [[-6.35847986e-01, -1.07149613e+00],
         [-1.20127559e-01, -2.47392476e-01],
         [-1.13766372e-01,  6.83178306e-01]],

        [[-1.60088032e-01,  9.28622365e-01],
         [-2.95259446e-01, -2.59135187e-01],
         [-1.40495336e+00, -3.54035497e-01]],

        [[-2.97596663e-01, -9.04355049e-02],
         [-1.75277090e+00, -1.71590877e+00],
         [-1.57865167e-01, -3.64035368e-02]],

        [[-3.56974870e-01,  2.64566362e-01],
         [ 1.13862216e+00,  4.03557777e-01],
         [ 5.91090918e-02,  7.41369903e-01]],

        [[ 4.72817928e-01, -5.37541270e-01],
         [ 3.60018909e-02, -4.69236851e-01],
         [ 1.43122047e-01,  4.742

In [56]:
#For convenience, added log_prob function to decoder
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-4.3481083, -3.3357706, -4.9469986, -6.544323 , -2.9760199,
        -4.427817 , -4.3070583, -2.595    , -4.852415 , -3.381824 ,
        -3.5272987, -3.0901704, -5.885207 , -5.3657303, -8.399137 ,
        -5.0011015, -4.1334524, -3.7733693, -2.7852638, -4.192324 ,
        -3.6345158, -1.8082442, -4.2222342, -5.2243795, -3.5352893,
        -3.9909248, -4.4528904, -3.942157 , -2.0438998, -2.794443 ,
        -3.6165953, -2.9359753, -4.6390123, -1.7163305, -4.7450447,
        -3.1110125, -6.3029275, -3.494632 , -5.5376425, -2.4133673,
        -3.102149 , -6.5007324, -8.002174 , -4.847277 , -3.234931 ,
        -4.4653387, -6.2435427, -5.2144966, -5.929131 , -3.0619698,
        -7.758571 , -2.4337087, -5.2422132, -2.5941892, -3.2893748,
        -5.5963535, -3.7651231, -7.9952602, -3.341042 , -3.064176 ],
       [-3.6814313, -4.9668026, -3.7147474, -3.903227 , -4.0675516,
        -4.4615602, -5.035186 , -8.3018675, -3.7997017, -3.8644094

In [57]:
#Above tested generation, now test if have training data
training_data = tf.random.normal((2, 60, 3))

In [58]:
sample_out, params_out, ref_out = decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [59]:
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.20045257e+00, -1.42873979e+00, -2.17751598e+00],
        [-3.87140542e-01,  1.99903774e+00, -7.15324581e-01],
        [ 1.31742072e+00, -7.14077353e-01, -1.66773129e+00],
        [ 6.62990332e-01, -2.10018110e+00, -1.10725713e+00],
        [ 3.27913791e-01,  4.01114561e-02, -2.00809002e+00],
        [-1.28795668e-01, -6.55543089e-01,  5.53375594e-02],
        [ 9.59794819e-01,  1.30048037e+00,  8.91452491e-01],
        [ 1.61246210e-01,  9.80594754e-01, -3.38876158e-01],
        [-6.89834535e-01, -3.57737154e-01,  1.50017822e+00],
        [ 3.04798901e-01, -5.50457418e-01, -7.34910369e-01],
        [ 1.60559452e+00, -4.80351269e-01,  6.32781744e-01],
        [-8.74686956e-01, -9.14747417e-02, -2.54110903e-01],
        [ 1.45220923e+00, -1.43750179e+00,  8.79997015e-01],
        [ 4.03347373e-01, -2.51566708e-01, -7.56733000e-01],
        [ 6.52765810e-01, -1.38391867e-01, -2.47616386e+00],
        [ 1.81059107e-01,  1.2982

In [60]:
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.56113052e+00,  2.93025136e-01,  1.66974664e+00],
        [-8.70299637e-01,  7.28987008e-02,  1.69392967e+00],
        [-1.78796339e+00,  2.14231014e-03,  1.32722211e+00],
        [ 8.63396168e-01,  5.31996548e-01, -7.80187130e-01],
        [ 8.15135002e-01, -6.03680611e-02, -2.54538774e+00],
        [-1.70487070e+00,  6.80459023e-01, -1.01139820e+00],
        [ 6.62990332e-01,  8.99818897e-01, -1.10725713e+00],
        [-3.87140542e-01,  1.99903774e+00, -3.71532464e+00],
        [-3.82321239e-01,  2.28464961e-01, -1.34148395e+00],
        [-5.39047480e-01,  9.11137164e-02, -1.00185990e+00],
        [-6.43610835e-01, -1.61148727e-01, -8.23018402e-02],
        [-2.50694752e+00, -1.07600451e-01, -4.13690627e-01],
        [-3.87429714e-01, -7.84877360e-01, -1.33983290e+00],
        [ 3.27913761e-01,  4.01114225e-02, -2.00809002e+00],
        [ 2.37035847e+00, -7.16777921e-01, -2.81206632e+00],
        [-2.04020524e+00,  1.3004

In [61]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.61520362e-01,  4.19605315e-01,  3.41519475e-01],
        [ 5.29310524e-01,  1.99478865e-01,  3.65702510e-01],
        [-3.88353229e-01,  1.28722474e-01, -1.00505352e-03],
        [ 4.48692441e-02,  1.59915745e-01, -4.16733116e-01],
        [-3.39192152e-03, -4.32448864e-01, -2.18193364e+00],
        [-2.52339768e+00,  3.08378220e-01, -6.47944212e-01],
        [ 6.79922938e-01, -1.75326705e-01, -1.19406700e-01],
        [-3.70207906e-01,  9.23892140e-01, -2.72747421e+00],
        [-3.65388602e-01, -8.46680641e-01, -3.53633523e-01],
        [ 1.14137995e+00, -2.72142142e-01, -1.07032955e+00],
        [ 1.03681660e+00, -5.24404585e-01, -1.50771499e-01],
        [-8.26520085e-01, -4.70856309e-01, -4.82160270e-01],
        [-1.31297708e+00, -6.00332618e-02,  1.07900393e+00],
        [-5.97633600e-01,  7.64955521e-01,  4.10746813e-01],
        [ 1.44481111e+00,  8.06617737e-03, -3.93229485e-01],
        [-4.79074717e-01,  1.0074

In [62]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[ 2.24267662e-01,  4.08133417e-02],
         [ 5.49771786e-01, -1.02569066e-01],
         [ 4.34912667e-02,  1.32377994e+00]],

        [[ 3.29815447e-01,  1.60148665e-01],
         [ 4.27395701e-02, -2.17725798e-01],
         [ 2.79500872e-01,  8.11927319e-01]],

        [[-5.85581005e-01, -1.05535936e+00],
         [-8.32149982e-02, -2.04021037e-01],
         [-1.00026250e-01,  7.14432061e-01]],

        [[-2.90359080e-01,  2.16093734e-02],
         [ 1.57296956e-01, -5.74448228e-01],
         [-4.44163680e-02,  1.47815943e-01]],

        [[ 1.54832602e-01,  1.26071382e+00],
         [-3.20507467e-01, -1.77074254e-01],
         [-2.02278066e+00,  2.82913685e-01]],

        [[-2.28163576e+00,  6.54796422e-01],
         [ 6.67632937e-01, -9.69477892e-01],
         [-4.49395657e-01,  2.86307275e-01]],

        [[ 3.10298681e-01, -4.37809765e-01],
         [-5.44017553e-03, -4.43345577e-01],
         [-8.54672492e-03,  8.490

In [63]:
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-3.4804733, -3.1678662, -2.5701609, -2.669084 , -3.4606636,
        -2.9427586, -2.8716912, -3.4371178, -2.7656944, -2.5175672,
        -3.2088466, -2.70536  , -3.0402045, -2.9625616, -2.632444 ,
        -3.558968 , -3.789453 , -1.4754338, -2.6180592, -2.6638813,
        -3.1842701, -3.8260245, -2.906293 , -4.0445805, -4.580841 ,
        -4.078925 , -3.465204 , -2.8616803, -3.445475 , -2.2325158,
        -3.035531 , -2.9548836, -3.5819585, -3.5294435, -4.377815 ,
        -3.3026743, -2.1143641, -3.8727398, -2.9661775, -4.150196 ,
        -2.5793185, -3.1980333, -2.5645459, -5.153719 , -2.1670895,
        -3.2190661, -3.1405032, -2.547495 , -2.7377806, -4.153712 ,
        -3.3097072, -3.575777 , -3.3787253, -4.6647754, -2.9824586,
        -3.7591698, -3.6043847, -2.8973002, -3.9107597, -4.571171 ],
       [-2.6821904, -3.2742898, -3.5597808, -2.9755921, -3.1752095,
        -2.055773 , -4.3659554, -3.685823 , -2.8662906, -2.889393 

In [64]:
#Try again with more complicated transform function
#Which will also require a more complicated set of sampling distributions
#Don't want to use Gaussian for r, theta, and phi
#Want to use Gamma for r and von Mises for angles
decoder = pd.ParticleDecoder(3, 2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             tfp_dist_list=[tfp.distributions.Gamma,
                                            tfp.distributions.vonMises,
                                            tfp.distributions.vonMises],
                             param_transforms=[tf.identity, tf.math.exp]
                             coord_transform=pd.spherical_transform)

In [65]:
solute_coords = tf.random.normal((2, 5, 3))

In [66]:
sample_out, params_out, ref_out = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [67]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 8.88745785e-02,  1.27376556e-01,  6.07588291e-01],
        [-1.49032891e+00,  1.72509146e+00, -7.12687492e+00],
        [-3.26570272e-02,  1.17639756e+00,  2.20073152e+00],
        [-1.29945397e-01, -2.30975986e-01, -1.10048425e+00],
        [-7.01887131e-01, -1.23853004e+00, -9.27140951e-01],
        [-4.04428631e-01, -8.79619122e-02,  4.08250284e+00],
        [-9.11954641e-02,  5.07668257e-02,  9.73518848e-01],
        [ 2.26313412e-01, -7.19931364e-01, -1.05501521e+00],
        [ 2.71557999e+00, -1.29662275e+00, -4.09489870e-02],
        [ 1.44245708e+00, -9.36303973e-01,  8.96052718e-01],
        [-2.23563194e-01, -2.60478401e+00, -1.23701763e+00],
        [ 1.69520259e-01,  1.73900843e-01, -1.74681529e-01],
        [-2.18650401e-02, -4.11948681e-01,  1.73623621e-01],
        [-5.72397709e-02,  1.74308300e-01,  1.29768085e+00],
        [-4.95748520e-02, -1.00033283e-01,  9.33373213e-01],
        [ 8.64518046e-01,  1.1426

In [68]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-2.62904358e+00, -3.37280631e-02],
         [-1.76279449e+00, -1.36267483e+00],
         [-2.69201016e+00,  1.89421904e+00]],

        [[ 3.63301826e+00,  3.36498833e+00],
         [ 2.00444531e+00, -1.69517553e+00],
         [ 3.17204189e+00, -3.89058471e-01]],

        [[-2.67739105e+00, -4.15094280e+00],
         [-1.12657976e+00, -4.79282498e-01],
         [ 2.61853743e+00, -4.50963914e-01]],

        [[ 4.92209166e-01, -8.55443180e-02],
         [ 1.09600258e+00, -1.43001175e+00],
         [ 5.07735133e-01, -1.75398886e-01]],

        [[-1.75333786e+00, -2.34906793e+00],
         [ 7.05622911e-01,  9.58235681e-01],
         [ 3.20190609e-01,  1.89796007e+00]],

        [[-9.36390042e-01,  1.31708765e+00],
         [-7.24987388e-02, -1.00074470e+00],
         [-3.05643749e+00, -2.79399920e+00]],

        [[ 7.25040197e-01,  2.78438032e-02],
         [ 2.10535812e+00,  9.33347881e-01],
         [ 7.67737627e-03, -7.808

In [69]:
#For convenience, added log_prob function to decoder
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-2.36381378e+01, -3.95041776e+00, -8.59319275e+02,
        -2.66780987e+01, -6.71800461e+01, -9.74896851e+01,
        -2.94395852e+00, -1.12442007e+01, -4.37461823e+02,
        -1.29563370e+01, -1.31200867e+01, -1.85651501e+03,
        -3.58724475e+00, -2.26759505e+00, -5.51518593e+01,
        -5.23937130e+00, -6.12355652e+02, -4.19459198e+02,
        -6.65286636e+00, -5.56846619e+00, -7.28212070e+00,
        -5.27129841e+00, -8.86811066e+00, -3.33912802e+00,
        -6.65000153e+00, -1.61310902e+01, -3.50419712e+00,
        -2.20263718e+02, -3.66534591e+00, -5.89728737e+00,
        -2.82694840e+00, -2.32612133e+00, -7.51310825e+00,
        -1.61103916e+00, -6.80069809e+01, -1.79748268e+01,
        -3.58256340e+00, -1.17125454e+01, -1.01706862e+00,
        -5.05016876e+02, -8.58895838e-01, -6.14946365e+00,
        -7.13764071e-01, -1.04148079e+02, -2.37725282e+00,
        -1.81879210e+00, -6.09531832e+00, -1.39250488e+01,
       

In [70]:
#Above tested generation, now test if have training data
training_data = tf.random.normal((2, 60, 3))

In [71]:
sample_out, params_out, ref_out = decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [72]:
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.92230254e-01, -1.05604017e+00, -4.47135299e-01],
        [ 2.88723230e-01, -7.35843420e-01,  7.74484217e-01],
        [ 3.18407342e-02,  8.28364015e-01,  1.57843813e-01],
        [-1.20446217e+00,  9.72177088e-01, -5.68230927e-01],
        [ 3.32911193e-01, -4.96045202e-01, -7.48315334e-01],
        [-9.64890599e-01,  1.43367922e+00,  7.44469047e-01],
        [-3.80300432e-01,  5.83521664e-01,  1.90933049e+00],
        [-1.39143872e+00,  3.96839589e-01,  3.44763249e-01],
        [ 1.02160347e+00, -1.96630085e+00, -1.27616668e+00],
        [ 2.91529089e-01,  1.37729511e-01,  1.11598156e-01],
        [-1.04255569e+00, -5.86505890e-01, -2.13351083e+00],
        [ 7.37435281e-01, -1.21270275e+00,  1.43747699e+00],
        [ 7.54187047e-01, -3.74138355e-01, -3.36862981e-01],
        [-1.21364474e+00,  1.10764754e+00,  9.12102401e-01],
        [-2.20591962e-01,  3.79345655e-01, -3.76165599e-01],
        [ 1.17201161e+00, -1.2679

In [73]:
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 1.0318983e+00, -1.7096361e+00,  3.0025077e+00],
        [ 1.0762373e+00, -3.7684798e-02, -3.0958407e+00],
        [ 7.5201052e-01,  8.8985741e-01,  3.0192370e+00],
        [ 9.8214352e-01, -1.5612761e+00,  4.1514137e-01],
        [-1.0249257e-02,  3.7133038e-01, -9.6790868e-01],
        [ 6.6114682e-01, -1.5926405e+00,  4.9117267e-01],
        [-1.7572179e+00, -6.1550361e-01, -1.2138773e+00],
        [-2.7649099e-01, -1.1683900e+00, -9.1807795e-01],
        [ 7.3743534e-01, -1.2127028e+00, -1.5625229e+00],
        [ 1.7863553e+00, -1.8923525e+00,  9.1210240e-01],
        [ 3.0346823e-01, -2.7927217e+00,  8.3525908e-01],
        [ 1.6085613e+00, -2.6031604e+00,  3.4476328e-01],
        [-3.8030022e-01,  3.5835216e+00, -1.0906698e+00],
        [ 6.1908090e-01,  1.5729684e+00, -3.6035872e-01],
        [ 7.5418705e-01,  2.6258616e+00, -3.3686298e-01],
        [ 1.0216035e+00, -1.9663008e+00,  4.7238336e+00],
        [ 9.0573198e

In [74]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.53887630e-01, -1.21288180e+00,  2.41039920e+00],
        [-1.09548569e-01,  4.59069490e-01, -3.68794918e+00],
        [-4.33775365e-01,  1.38661170e+00,  2.42712855e+00],
        [ 3.31173599e-01, -3.24076653e-01, -9.95674729e-03],
        [-6.61219180e-01,  1.60852981e+00, -1.39300680e+00],
        [ 1.01768970e-02, -3.55441093e-01,  6.60745502e-02],
        [-1.15291965e+00, -6.02803826e-02, -1.58741474e-01],
        [ 3.27807248e-01, -6.13166809e-01,  1.37057900e-01],
        [ 1.34173357e+00, -6.57479525e-01, -5.07387042e-01],
        [ 6.67620778e-01, -7.91652203e-02,  7.09039629e-01],
        [-8.15266252e-01, -9.79534507e-01,  6.32196307e-01],
        [ 4.89826798e-01, -7.89973140e-01,  1.41700491e-01],
        [-6.71610653e-01,  1.44594789e+00, -2.00775921e-01],
        [ 3.27770472e-01, -5.64605355e-01,  5.29535115e-01],
        [ 4.62876618e-01,  4.88287926e-01,  5.53030849e-01],
        [-1.02947950e-02, -2.5666

In [75]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-2.62904358e+00, -3.37280631e-02],
         [-1.76279449e+00, -1.36267483e+00],
         [-2.69201016e+00,  1.89421904e+00]],

        [[ 3.62991309e+00,  3.29263735e+00],
         [ 2.00001788e+00, -1.77931774e+00],
         [ 3.16046953e+00, -3.66473138e-01]],

        [[-2.67379522e+00, -4.15707731e+00],
         [-1.13439798e+00, -4.66901779e-01],
         [ 2.61911345e+00, -4.11668003e-01]],

        [[-2.88453043e-01, -1.27172148e+00],
         [-8.47360909e-01,  3.76592278e-01],
         [-1.08057117e+00, -8.54909182e-01]],

        [[-2.02530932e+00, -5.89817882e-01],
         [ 2.14358926e+00,  7.29258180e-01],
         [-9.32745278e-01,  1.16867423e+00]],

        [[-3.32889795e-01, -6.91879034e-01],
         [-3.74328643e-01, -1.37222850e+00],
         [-1.35618114e+00, -3.11526346e+00]],

        [[ 9.52293754e-01,  1.00344419e+00],
         [ 3.49933863e+00,  9.38518882e-01],
         [ 1.41361976e+00, -4.878

In [76]:
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-1.8467478e+01, -3.4584227e+00, -9.7786780e+02, -1.1290482e+01,
        -2.1305691e+01, -8.8147942e+01, -1.2053821e+01, -6.9964828e+00,
        -1.0861506e+03, -8.0973358e+00, -5.9120889e+00, -1.7629150e+03,
        -1.5955196e+01, -7.6601799e+01, -1.1567530e+01, -5.7511249e+00,
        -1.7826802e+00, -2.7119277e+00, -9.0998459e+01, -7.1989174e+00,
        -5.1443949e+00, -6.4026909e+00, -2.8578169e+01, -6.0077019e+00,
        -1.0387278e+02, -7.2903900e+00, -1.2111935e+01, -3.3354986e+00,
        -3.1821754e+00, -8.5605011e+00, -1.4035792e+00, -1.6600070e+02,
        -5.6058502e+00, -2.3582531e+01, -8.6379738e+01, -1.3336767e+00,
        -3.5794561e+00, -3.0602295e+00, -7.2572285e-01, -7.2044439e+00,
        -1.9737881e+02, -1.3886117e+01, -6.5538855e+00, -2.4137804e+01,
        -5.0655914e+01, -2.1127648e+00, -2.4225393e+01, -1.4649710e+02,
        -4.2851425e+01, -4.0455375e+00, -9.0061226e+00, -3.3401024e+00,
        -1.3321