In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from importlib import reload

In [2]:
import particle_decoding as pd

In [76]:
#Re-import particle_decoding so that edits to the module source are picked up without restarting the kernel
#NOTE(review): this cell was last executed as In [76], i.e. after most later cells - re-run top to bottom to confirm ordering
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## Transformations

In [4]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

In [5]:
pd.identity_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]],

       [[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]]])>

In [6]:
pd.spherical_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [7]:
pd.spherical_transform(test_coords, dist_sq=tf.reduce_sum(test_coords * test_coords, axis=-1))

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [8]:
pd.spherical_transform(pd.spherical_transform(test_coords), reverse=True)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]],

       [[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]]])>

## Masking by distance

In [9]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

In [10]:
distance_mask = pd.DistanceMask()

In [11]:
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [12]:
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [13]:
#Need to also test if ref included in input coordinates
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[ 1.5, -0.5,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[2],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[3.5],
        [0.5]])>)

In [14]:
#In the saved output, ref_included=False returns the reference particle itself as the nearest neighbor
#(indices 1 and 0, zero offset, zero squared distance) - i.e. the reference is apparently not excluded from the search
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

In [15]:
#Now try distances with a periodic simulation box
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

distance_mask = pd.DistanceMask(box_lengths=np.array([3.0, 3.0, 3.0]))

In [16]:
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [17]:
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [18]:
#Need to also test if ref included in input coordinates
#With the periodic box the saved output for the first batch entry differs from the non-periodic case:
#particle 0 at (1, 1, 1) is now nearest to the ref at (-1, 1, 0), with offset (-1, 0, 1) and squared distance 2
#(the x-separation of 2 wraps to -1 for a box length of 3); without the box, particle 2 was nearest at 3.5
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[-1. ,  0. ,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[0],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[2. ],
        [0.5]])>)

In [19]:
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

## Creating probability distributions

In [20]:
params = np.array([[[[0.0, 1.0], [10.0, 1.0], [-10.0, 1.0]],
                    [[0.0, 1.0], [0.0, 2.0], [0.0, 10.0]],
                    [[5.0, 1.0], [5.0, 2.0], [5.0, 10.0]]],
                   [[[0.0, 1.0], [-10.0, 1.0], [10.0, 1.0]],
                    [[0.0, 10.0], [0.0, 2.0], [0.0, 1.0]],
                    [[-5.0, 10.0], [-5.0, 2.0], [-5.0, 1.0]]]])

dist_list = [tfp.distributions.Normal]*3
param_trans = [lambda x, y: [x, y]]*3

In [21]:
joint_dist = pd.create_dist(params, dist_list, param_trans)

In [22]:
sample = tf.stack(joint_dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[  0.21444572  10.98526594  -8.92859722]
  [  0.74477985  -3.60675276  -2.91630749]
  [  6.05203319   3.72033378   7.64744451]]

 [[ -1.23496414  -8.97304691   7.68607601]
  [  5.11984264   2.55245489   1.15245539]
  [-14.02047385  -3.52264444  -6.48308741]]], shape=(2, 3, 3), dtype=float64)


In [23]:
joint_dist.log_prob(tf.unstack(sample, axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-3.83913553, -7.69850382, -6.54567281],
       [-6.72382225, -7.36206677, -7.53198918]])>

In [24]:
#Hand-compute the joint log-probability for the last batch entry, particle 1:
#the three coordinates are independent Normals, so the joint log-density is the sum of the per-coordinate terms
check_sample = sample[-1, 1, :]
check_loc = params[-1, 1, :, 0]
check_scale = params[-1, 1, :, 1]
manual_log_prob = np.sum(-0.5*((check_sample - check_loc)**2)/(check_scale**2)
                         - np.log(check_scale) - 0.5*np.log(2.0*np.pi))
#Assert agreement with the distribution's own log_prob instead of comparing printed numbers by eye
np.testing.assert_allclose(manual_log_prob, joint_dist.log_prob(tf.unstack(sample, axis=-1))[-1, 1])
manual_log_prob

-7.362066771472351

## Solvation neural networks

In [25]:
#Input can be solute or solvent coordinates, with extra_coords (the keyword used in the calls below) being the other
#Start by testing without augmented input (only relevant for first network to predict layer of solvent for first solute)
#The (8, 3) shape and out_event_dims=2 show up as the trailing (8, 3, 2) dimensions of the outputs printed below
solute_net = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=False)

In [26]:
solute_coords = tf.random.normal((1, 2, 3))

In [27]:
params, shifts = solute_net(solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.02991906  0.03474456]
   [ 0.04694335 -0.30082986]
   [-0.27825925  0.01359881]]

  [[-0.22389111 -0.03151908]
   [ 0.11536059 -0.06560486]
   [-0.1917338   0.22249524]]

  [[ 0.04552162 -0.2812307 ]
   [-0.1865316  -0.02563034]
   [ 0.14342093 -0.07622399]]

  [[-0.22526799  0.41627496]
   [ 0.01712676 -0.04835619]
   [ 0.11832903 -0.11827321]]

  [[-0.22165766  0.05228366]
   [ 0.01474504 -0.27279344]
   [-0.09255562 -0.21349686]]

  [[ 0.12998202  0.20185329]
   [ 0.00334737  0.18556513]
   [ 0.12778433  0.258542  ]]

  [[-0.08710504 -0.0789277 ]
   [-0.3215373   0.06362516]
   [ 0.02210784 -0.09951571]]

  [[-0.00971321 -0.11772958]
   [-0.02228445  0.15683885]
   [-0.14728114  0.05829949]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.5702193   0.10242276]
   [ 0.06208383  0.0614965 ]
   [ 0.63484496  0.12287204]]

  [[-0.30224758  0.2080164 ]
   [ 0.2346755  -0.58575094]
   [-0.67766994  0.58275557]]

  [[ 0.17076838  0.7967104 ]
   [-0.24945955  0.403

In [28]:
#Build one Normal per coordinate from the combined network outputs and draw a sample
#The second output is exponentiated as exp(0.5*y), i.e. it is treated like a log-variance if it ends up as the Normal scale
normal_per_coord = [tfp.distributions.Normal]*3
loc_and_scale = [lambda x, y: [x, tf.math.exp(0.5*y)]]*3
dist = pd.create_dist(params + shifts, normal_per_coord, loc_and_scale)
drawn = dist.sample()
sample = tf.stack(drawn, axis=-1)
print(sample)

tf.Tensor(
[[[-2.0254989   1.0928754  -1.3855463 ]
  [-0.74807084  1.2247336  -2.184418  ]
  [-1.7917328  -2.1371703   0.30274838]
  [-2.2622733   0.03297795  0.20939031]
  [ 0.4065243   0.23652282 -0.36569336]
  [-0.6350927   0.41721833 -0.5318515 ]
  [-0.02749306  0.12260407 -0.8482939 ]
  [ 0.9013275  -1.3540046   1.5958984 ]]], shape=(1, 8, 3), dtype=float32)


In [29]:
#First set of 3 (first particle) should not change shifts, even with different input due to autoregressive
#All others should change, though
new_shifts = solute_net(solute_coords, sampled_input=sample)
print(new_shifts)
#Assert both expectations instead of comparing the printed tensors by eye
np.testing.assert_allclose(new_shifts[:, 0], shifts[:, 0], rtol=1e-5, atol=1e-6)
assert not np.allclose(new_shifts[:, 1:], shifts[:, 1:]), "shifts after the first particle should depend on sampled_input"

tf.Tensor(
[[[[-0.5702193   0.10242276]
   [ 0.06208383  0.0614965 ]
   [ 0.63484496  0.12287204]]

  [[-0.31552345  0.1815641 ]
   [ 0.18031707 -0.51817554]
   [-0.65162534  0.5954337 ]]

  [[ 0.12611629  0.7638406 ]
   [-0.14613794  0.43455833]
   [-0.3300029   0.03675803]]

  [[ 0.29884303 -0.1426386 ]
   [ 0.24490124 -0.2781062 ]
   [ 0.67148304 -0.2519494 ]]

  [[ 0.8468522  -0.45434064]
   [ 0.01736777 -0.71401054]
   [-0.4038055  -0.6100823 ]]

  [[-0.00461812 -0.14540403]
   [ 0.5071668   0.2602861 ]
   [-0.05079362 -0.27574855]]

  [[ 0.93765163  0.22175571]
   [ 0.91380775  0.5562156 ]
   [ 0.15205522  0.6775594 ]]

  [[ 1.0671461   0.38773963]
   [ 0.49657407 -0.12616023]
   [-0.0346818  -0.25413114]]]], shape=(1, 8, 3, 2), dtype=float32)


In [30]:
#And now with solvent inputs (remember, could flip if wanted)
solute_net_solv = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [31]:
solvent_coords = tf.random.normal((1, 6, 3))

In [32]:
params, shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.05719483 -0.04520281]
   [ 0.0688635  -0.13556404]
   [-0.09076778  0.61524194]]

  [[ 0.17296946 -0.25806147]
   [ 0.08461397  0.115769  ]
   [-0.4429516   0.1292496 ]]

  [[ 0.15389366  0.01381597]
   [-0.02399583  0.16826536]
   [ 0.08355334 -0.05810891]]

  [[ 0.00437843 -0.3261838 ]
   [-0.21915945 -0.20572868]
   [ 0.22116393  0.33660078]]

  [[ 0.073622    0.01056137]
   [ 0.12566684 -0.6766143 ]
   [-0.48783904  0.27167055]]

  [[ 0.08725085  0.00840436]
   [ 0.24962084 -0.30215055]
   [ 0.15827037 -0.22926673]]

  [[-0.1387734   0.01755169]
   [ 0.21355698  0.07237558]
   [-0.14446372 -0.02517876]]

  [[ 0.05749101 -0.18892531]
   [-0.09169422  0.27556375]
   [ 0.05544627  0.14360003]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[ 1.5454783  -0.504551  ]
   [-0.74806386  0.295422  ]
   [-0.58324474 -0.3088516 ]]

  [[ 0.41401672 -0.51051366]
   [-1.2858216  -1.3612124 ]
   [ 0.3708284   0.10739153]]

  [[ 0.260424    0.49330163]
   [ 0.06513315  0.311

In [33]:
#Build one Normal per coordinate from the combined network outputs and draw a sample
#The second output is exponentiated as exp(0.5*y), i.e. it is treated like a log-variance if it ends up as the Normal scale
normal_per_coord = [tfp.distributions.Normal]*3
loc_and_scale = [lambda x, y: [x, tf.math.exp(0.5*y)]]*3
dist = pd.create_dist(params + shifts, normal_per_coord, loc_and_scale)
drawn = dist.sample()
sample = tf.stack(drawn, axis=-1)
print(sample)

tf.Tensor(
[[[ 1.7191827  -1.0685421  -1.6340321 ]
  [ 0.92454123 -1.761539    1.9991454 ]
  [ 0.64554095 -1.8035946   0.19999373]
  [ 0.25530115 -0.25807044  1.0190977 ]
  [ 0.70228565  0.5526482  -0.8073046 ]
  [ 0.38177472  0.02222461 -1.7063066 ]
  [ 0.41057116  0.06930777 -1.9659926 ]
  [ 0.32516903 -1.4714203  -0.800588  ]]], shape=(1, 8, 3), dtype=float32)


In [34]:
new_shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[ 1.5454783  -0.504551  ]
   [-0.74806386  0.295422  ]
   [-0.58324474 -0.3088516 ]]

  [[ 0.4077664  -0.5114895 ]
   [-1.2859792  -1.3557202 ]
   [ 0.37200505  0.10508069]]

  [[ 0.26190686  0.48434004]
   [ 0.06507744  0.30370456]
   [-1.1102226  -0.03305111]]

  [[ 0.6901213   0.7827432 ]
   [ 0.33459574 -1.345809  ]
   [ 0.97541195  0.48326883]]

  [[-0.3575552   1.0713925 ]
   [-0.12701678 -0.16925704]
   [-1.1676636   0.675653  ]]

  [[ 0.8421416   0.7587639 ]
   [ 0.6292156   0.10328307]
   [-0.8352114  -0.17457661]]

  [[ 0.06267682 -1.0134772 ]
   [-0.16702676  0.8704173 ]
   [-1.549744    0.18348849]]

  [[ 0.7890519  -1.3944204 ]
   [-1.2938039   0.47328177]
   [ 0.5774903  -0.95883846]]]], shape=(1, 8, 3, 2), dtype=float32)


In [35]:
#Now that it's been called (so built) with extra_coords provided, must provide always with this layer
#Omitting extra_coords is expected to fail, so catch the error and display it rather than halting a "Run All"
try:
    new_shifts = solute_net_solv(solute_coords, sampled_input=sample)
    print(new_shifts)
except ValueError as err:
    print('ValueError:', err)

ValueError: Input 0 of layer dense_18 is incompatible with the layer: expected axis -1 of input shape to have value 24 but received input with shape (1, 6)

In [36]:
#Now pretend we're working with a central solvent particle and adding one more
#(conditioned on other solutes and solvent already around)
solvent_net = pd.SolvationNet((1, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [37]:
params, shifts = solvent_net(solvent_coords, extra_coords=solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.04235222  0.3004281 ]
   [-0.01502383 -0.04668699]
   [ 0.09877495  0.12738968]]]], shape=(1, 1, 3, 2), dtype=float32) tf.Tensor(
[[[[ 0.666824    1.6624672 ]
   [-0.19942087  0.66355145]
   [-0.79080963  1.1689491 ]]]], shape=(1, 1, 3, 2), dtype=float32)


## Full particle decoder

In [77]:
decoder = pd.ParticleDecoder(3, 2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3)

In [78]:
solute_coords = tf.random.uniform((2, 5, 3), minval=-1.5, maxval=1.5)

In [79]:
sample_out, params_out, ref_out = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [80]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.32714581e+00, -4.13748503e-01, -8.38724852e-01],
        [-2.41272259e+00, -1.73847950e+00, -8.32716227e-02],
        [-5.04931033e-01,  1.29382944e+00, -1.98968959e+00],
        [-1.15179467e+00, -4.51758265e-01,  4.85966504e-01],
        [ 1.80733681e-01,  5.03747106e-01,  3.26197004e+00],
        [-1.43243492e+00, -9.29448009e-01, -1.58431399e+00],
        [-1.21276033e+00, -4.36940938e-01, -7.39343047e-01],
        [ 5.39698601e-01, -2.69666791e-01, -1.92056739e+00],
        [-9.74236727e-01, -3.67212248e+00,  8.90242577e-01],
        [-7.74036407e-01, -2.95437932e-01, -1.40186048e+00],
        [ 7.67808974e-01, -1.40716791e-01,  7.35237479e-01],
        [ 2.61848181e-01,  4.77797329e-01,  5.03151536e-01],
        [-9.31636691e-01, -2.17747998e+00,  1.35686302e+00],
        [ 2.12097216e+00,  2.09636593e+00,  1.33617020e+00],
        [-5.15720844e-02,  2.16723144e-01, -1.33502781e-01],
        [ 5.75878501e-01,  2.6237

In [81]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-4.36999261e-01,  1.03518236e+00],
         [-4.57027823e-01,  2.67408550e-01],
         [-5.56211054e-01, -7.03046471e-02]],

        [[ 1.11274168e-01,  8.93528223e-01],
         [-8.26731920e-02,  4.81074244e-01],
         [-4.03879315e-01,  4.35976863e-01]],

        [[ 9.92402732e-01,  1.14354753e+00],
         [ 5.58950961e-01, -3.35525930e-01],
         [-1.07525969e+00,  4.63239431e-01]],

        [[-4.34408844e-01, -2.81047672e-01],
         [ 1.81566775e-01, -3.66040528e-01],
         [ 1.04778290e+00,  2.49289304e-01]],

        [[ 3.81991953e-01, -1.63461876e+00],
         [ 4.97435451e-01, -5.79360485e-01],
         [ 1.22245347e+00,  1.05375099e+00]],

        [[-1.73260951e+00, -4.64159727e-01],
         [ 1.00381374e-02, -1.43386543e+00],
         [-9.97239232e-01, -2.93641329e-01]],

        [[ 1.36079222e-01, -2.75479406e-01],
         [ 6.99664652e-02, -7.76437461e-01],
         [-1.52638674e+00,  8.506

In [82]:
#For convenience, added log_prob function to decoder
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-3.557197 , -5.8461313, -4.39045  , -3.3109353, -3.0056617,
        -3.8150644, -4.265965 , -3.5549612, -5.449364 , -4.235356 ,
        -3.607459 , -3.2160277, -4.6676526, -5.482623 , -2.802143 ,
        -2.8454685, -3.3078678, -3.9064684, -3.2483158, -4.426918 ,
        -3.2159972, -3.5605755, -4.6209393, -4.84853  , -2.9602525,
        -3.858611 , -4.216859 , -4.6669188, -5.4510126, -3.8832846,
        -2.0755105, -2.9305077, -3.7148075, -4.211273 , -3.8556933,
        -4.7625303, -2.5326533, -6.8340425, -2.5500207, -3.8132644,
        -3.438814 , -3.9232273, -4.111244 , -3.0427547, -3.874486 ,
        -5.827339 , -5.1328306, -6.2676773, -5.0050745, -3.424739 ,
        -3.0061276, -3.5443642, -4.7158427, -3.5845582, -3.0858517,
        -5.3708677, -4.3300223, -7.1995077, -2.698897 , -0.764    ],
       [-2.503958 , -3.86901  , -6.457015 , -7.1328096, -4.492614 ,
        -3.9499195, -7.0429635, -4.0260715, -4.321005 , -2.2473812

In [83]:
#Above tested generation, now test if have training data
training_data = tf.random.normal((2, 60, 3))

In [84]:
sample_out, params_out, ref_out = decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [85]:
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-1.3828479 , -1.3479795 , -0.34994233],
        [-0.6745582 ,  1.0071992 , -0.23908341],
        [ 0.9883995 , -0.4445044 ,  0.6972556 ],
        [ 2.0301037 ,  1.2465769 , -1.3493426 ],
        [-1.0880244 , -1.6955938 , -0.28732374],
        [ 0.6905825 ,  0.09656796, -1.083081  ],
        [ 0.15604804, -0.08175212,  0.5491082 ],
        [ 1.6770103 ,  3.055316  ,  0.850607  ],
        [-0.26127937, -0.90105814,  0.36135206],
        [ 0.6267869 , -1.3457528 ,  0.7248802 ],
        [-1.3594846 ,  1.6323911 ,  1.787056  ],
        [-0.6866033 ,  0.5055509 ,  0.27402115],
        [ 1.1962193 ,  1.1337395 , -1.7196463 ],
        [ 0.84703195, -0.71926236, -1.3443736 ],
        [-0.5914132 ,  0.951369  , -0.3885472 ],
        [-0.2581287 ,  1.0542793 ,  0.18349864],
        [-0.74548185, -0.40614045,  0.6011698 ],
        [-0.69947404, -0.83691245,  0.56345594],
        [ 0.09931951,  0.8833676 ,  0.5097185 ],
        [-0.21803

In [86]:
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 4.85895932e-01,  2.37861097e-01, -1.85923815e+00],
        [ 8.57063115e-01,  1.02534306e+00, -1.78429711e+00],
        [ 1.75864029e+00,  7.38040030e-01, -1.98052692e+00],
        [-6.99474037e-01, -8.36912453e-01,  5.63455939e-01],
        [ 2.60786563e-01, -1.41070461e+00,  9.10110116e-01],
        [-1.08192599e+00,  1.88378870e-01,  8.98438692e-01],
        [ 2.92056382e-01,  2.44775921e-01, -6.29905283e-01],
        [ 6.90582514e-01,  9.65679586e-02, -1.08308101e+00],
        [ 2.71570385e-02, -2.42748737e-01, -1.64865541e+00],
        [ 1.89903140e-01,  5.21010280e-01, -4.30716991e-01],
        [ 3.98017168e-01, -1.54112124e+00,  1.27270246e+00],
        [-2.07388580e-01, -1.35633194e+00,  3.94645095e-01],
        [ 1.67701030e+00,  5.53160906e-02,  8.50606978e-01],
        [ 6.17530167e-01,  1.88103819e+00,  2.47260666e+00],
        [ 2.07624626e+00,  7.91807234e-01,  1.02593482e-01],
        [ 1.19621933e+00,  1.1337

In [87]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-0.08115214, -0.4802571 , -0.7585447 ],
        [ 0.29001504,  0.30722487, -0.68360364],
        [ 1.1915922 ,  0.01992184, -0.87983346],
        [-0.84187573, -0.38507968,  0.66303635],
        [ 0.11838487, -0.95887184,  1.0096905 ],
        [-1.2243277 ,  0.64021164,  0.9980191 ],
        [-0.19015652, -0.11363158, -0.19890255],
        [ 0.20836961, -0.26183954, -0.6520783 ],
        [-0.45505586, -0.60115623, -1.2176527 ],
        [ 0.8118795 ,  1.8502347 , -0.86851966],
        [ 1.0199935 , -0.21189678,  0.8348998 ],
        [ 0.4145878 , -0.02710748, -0.04315758],
        [ 0.29161787, -0.5947484 , -0.08241028],
        [-0.76786226,  1.2309737 ,  1.5395894 ],
        [ 0.69085383,  0.14174277, -0.8304238 ],
        [ 0.7103234 ,  0.8958784 ,  0.13959181],
        [ 1.3728173 ,  0.17899024, -0.5443753 ],
        [-1.1578709 ,  0.41595644, -0.09932661],
        [ 0.891119  ,  0.70541894,  0.50195396],
        [-1.23068

In [88]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-4.36999261e-01,  1.03518236e+00],
         [-4.57027823e-01,  2.67408550e-01],
         [-5.56211054e-01, -7.03046471e-02]],

        [[ 1.00579962e-01,  9.56189334e-01],
         [-9.92560089e-02,  5.51300943e-01],
         [-3.91707271e-01,  3.52663457e-01]],

        [[ 9.84847188e-01,  1.13165116e+00],
         [ 4.87213463e-01, -3.85258019e-01],
         [-1.13517690e+00,  4.44009393e-01]],

        [[-5.22794783e-01, -1.07320356e+00],
         [-4.33168769e-01, -2.83524394e-01],
         [ 9.58842635e-01, -5.33696949e-01]],

        [[ 9.89815593e-02, -1.45842999e-01],
         [-5.63703775e-01, -4.93251860e-01],
         [ 1.12202048e+00, -2.82653004e-01]],

        [[-9.14054751e-01,  2.37081498e-01],
         [ 6.93296015e-01, -1.75051308e+00],
         [ 8.47702980e-01,  6.52455211e-01]],

        [[-7.01401383e-02, -2.31047869e-01],
         [-1.35152087e-01, -3.04327279e-01],
         [-3.61152589e-02, -5.154

In [89]:
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-3.417612 , -3.7713318, -3.540316 , -2.0366306, -2.4323936,
        -2.3782976, -2.262974 , -2.7312515, -2.0419245, -2.4626434,
        -2.5799265, -3.668037 , -3.7322288, -2.8383238, -3.1867151,
        -1.5666976, -2.4462438, -2.1947606, -2.2360282, -3.2513745,
        -3.993558 , -2.7044768, -2.0219722, -3.5263915, -3.3662257,
        -3.8759081, -4.0746937, -2.2495909, -3.0002074, -3.411603 ,
        -3.0845313, -5.3162932, -3.5011563, -2.8490393, -2.8870347,
        -1.0947998, -3.58073  , -4.4498186, -2.4333282, -3.3267684,
        -1.8214128, -2.2098508, -2.828311 , -2.4143057, -4.0947595,
        -1.7808242, -1.3334777, -4.4160833, -3.0172281, -3.5853553,
        -3.4837756, -1.6976292, -2.456397 , -4.14011  , -2.6044974,
        -1.5045866, -2.8467824, -2.257277 , -3.1519673, -3.05616  ],
       [-2.172255 , -2.9427428, -4.023569 , -3.9549956, -2.8613398,
        -4.378444 , -3.3784091, -3.925931 , -3.371484 , -1.7041826

In [95]:
#Try again with more complicated transform function
#Which will also require a more complicated set of sampling distributions
#Don't want to use Gaussian for r, theta, and phi
#Want to use Gamma for r and von Mises for angles
#For Gamma, the two network outputs (x, y) act as log(mean) and log(var), not mean and log(var):
#  a Gamma with concentration a and rate b has mean a/b and variance a/b^2, so
#  a = mean^2/var = exp(2x - y) and b = mean/var = exp(x - y)
#  (assumes the returned pair is passed positionally as Gamma(concentration, rate) - TODO confirm in pd.create_dist)
#This is consistent with mean_transforms applying tf.math.exp to the first (r) output
#For von Mises, exp(-y) is the second parameter (the concentration if passed positionally), so larger y means a broader distribution
decoder = pd.ParticleDecoder(3, 2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             tfp_dist_list=[tfp.distributions.Gamma,
                                            tfp.distributions.VonMises,
                                            tfp.distributions.VonMises],
                             param_transforms=[lambda x, y: [tf.math.exp(2*x - y), tf.math.exp(x - y)],
                                               lambda x, y: [x, tf.math.exp(-y)],
                                               lambda x, y: [x, tf.math.exp(-y)]],
                             mean_transforms=[tf.math.exp, tf.identity, tf.identity],
                             coord_transform=pd.spherical_transform)

In [96]:
solute_coords = tf.random.uniform((2, 5, 3), minval=-1.5, maxval=1.5)

In [97]:
sample_out, params_out, ref_out = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [98]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-2.2886550e+00,  8.3291101e-01,  1.3229189e+00],
        [ 1.0996449e-01,  3.0662420e-01,  5.5479131e+00],
        [-6.2674403e-02,  5.3536236e-02, -3.2776606e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [-3.3945459e-01, -2.6465905e-01,  4.7985458e-01],
        [ 3.9935112e-05,  4.8995018e-05,  8.6665154e-05],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [-4.1817770e+00, -2.2565536e+00, -1.3592211e+01],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 2.2529373e+01, -1.9872732e+00,  2.8133444e+01],
        [-3.2246548e-01,  6.8763179e-01,  5.0522119e-01],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [-6.0466951e

In [99]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[ 9.13494229e-01,  1.62008390e-01],
         [-1.84427118e+00, -8.33066463e-01],
         [-2.32583857e+00,  1.15421593e-01]],

        [[ 1.71736801e+00,  2.77505755e-01],
         [-1.70329034e+00,  1.65046230e-01],
         [-1.96743703e+00,  1.90957582e+00]],

        [[ 1.16364503e+00, -2.40286902e-01],
         [-2.67079532e-01,  8.20937753e-02],
         [ 1.25382617e-01,  4.71965432e-01]],

        [[-3.80608535e+00,  1.78472340e+00],
         [-1.73739457e+00, -3.57307285e-01],
         [-5.29228151e-01, -6.45075083e-01]],

        [[-1.06403112e+00,  2.00739813e+00],
         [ 1.04672205e+00,  6.40796661e-01],
         [-2.16297197e+00, -1.67821026e+00]],

        [[ 3.09235990e-01, -1.40385354e+00],
         [ 7.81065941e-01, -1.87758136e+00],
         [-7.00425029e-01, -2.33340955e+00]],

        [[-2.64498639e+00, -1.32006794e-01],
         [-1.44929361e+00, -9.97731805e-01],
         [-5.10933876e-01, -2.578

In [100]:
#For convenience, added log_prob function to decoder
#NOTE(review): the saved output below contains inf entries; in the first batch row they line up exactly with the
#all-zero rows of sample_out - ref_out shown earlier (particles 3, 4, 7, 9, 12-15). Presumably the Gamma radius
#sample underflowed to exactly 0 in float32 where the concentration exp(2x - y) is tiny - confirm, and consider
#clipping the sampled radius or bounding the concentration from below
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-7.0931873e+00, -5.8065753e+00, -6.2555466e+00,            inf,
                   inf, -2.2830511e+01, -8.4844074e+00,            inf,
        -4.6860909e+01,            inf, -3.7814283e+00, -8.2640488e+01,
                   inf,            inf,            inf,            inf,
        -3.0188019e+01, -5.9071951e+00, -7.1418004e+00, -7.0081930e+00,
                   inf,            inf, -6.9002733e+00, -7.3732738e+00,
        -6.6030316e+00,            inf, -1.0041814e+01, -3.7858489e+00,
        -5.0469437e+00, -4.8505192e+00, -4.1099982e+00, -3.3069052e+02,
        -8.3925735e+01, -1.5955637e+01, -1.5905023e+01,  3.8115501e-02,
        -2.7534888e+00, -2.2700996e+00,  2.2721343e+00, -9.2394698e-01,
        -2.1516180e+01,  5.1308262e-01,  1.5286086e+00, -3.5845718e+00,
         1.7353827e-01, -4.4625902e+00, -2.5098228e+00, -9.2824972e-01,
        -5.9076843e+00,            inf,  9.6036756e-01, -1.6907973e+00,
        -2.8753

In [101]:
#Above tested generation, now test if have training data
training_data = tf.random.normal((2, 60, 3))

In [102]:
sample_out, params_out, ref_out = decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [103]:
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 1.32316101e+00, -8.19735348e-01,  1.28946936e+00],
        [-3.24979037e-01,  1.32114412e-02,  3.97429258e-01],
        [-7.16652751e-01, -1.17039204e+00, -1.40354896e+00],
        [ 1.48491216e+00,  5.03120184e-01, -1.63797534e+00],
        [-4.97044884e-02,  1.12673056e+00, -1.10722220e+00],
        [ 6.19082451e-01, -1.17576933e+00, -3.08949053e-01],
        [-3.39081436e-01, -5.31021059e-02,  6.37633204e-01],
        [ 9.17870700e-01,  3.34726691e-01,  6.24505877e-01],
        [ 2.33562693e-01, -8.31101179e-01, -1.06330252e+00],
        [-2.27743424e-02, -1.37636962e-04,  1.48302269e+00],
        [ 2.70574379e+00, -6.01301193e-01, -8.27764928e-01],
        [-1.51834011e+00,  8.72638822e-01, -1.82305172e-01],
        [-1.04272878e+00,  3.27202153e+00, -1.02172129e-01],
        [ 5.04144251e-01, -6.83612525e-01,  2.72092372e-01],
        [-1.00378585e+00,  2.76088148e-01,  1.33781042e-02],
        [-1.92420137e+00, -1.5260

In [104]:
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 1.99621415e+00,  2.76088119e-01,  1.33781433e-02],
        [ 1.66101170e+00,  1.66071546e+00,  6.24000788e-01],
        [ 1.81503999e+00, -8.49125683e-01,  2.39349151e+00],
        [-7.69833565e-01,  1.13404036e-01, -3.47941375e+00],
        [ 1.56266236e+00,  2.60099936e+00,  3.46436977e-01],
        [ 1.11968708e+00,  9.60994184e-01, -6.74027205e-03],
        [-1.28667831e-01,  1.15623820e+00, -3.86880207e+00],
        [-7.16652751e-01, -1.17039204e+00, -1.40354896e+00],
        [-9.42596436e-01, -1.27725327e+00,  2.69177437e-01],
        [ 8.19219410e-01,  7.21526086e-01, -1.01131654e+00],
        [ 3.43683243e-01, -4.10167992e-01,  1.82670856e+00],
        [ 3.21583837e-01, -3.39238465e-01,  4.44980204e-01],
        [ 8.63276601e-01, -1.40458059e+00, -2.26876020e+00],
        [-1.21239889e+00, -4.75581139e-01,  6.79086328e-01],
        [-1.94620478e+00,  1.31444693e-01, -2.50042510e+00],
        [ 1.32316101e+00, -8.1973

In [105]:
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 5.67108154e-01,  5.87262988e-01, -1.01025772e+00],
        [ 2.31905699e-01,  1.97189033e+00, -3.99635077e-01],
        [ 3.85933995e-01, -5.37950814e-01,  1.36985564e+00],
        [-1.90616107e+00, -1.30960870e+00, -3.11561465e+00],
        [ 4.26334858e-01,  1.17798662e+00,  7.10236073e-01],
        [-1.66404247e-02, -4.62018549e-01,  3.57058823e-01],
        [ 1.21444583e+00,  2.45365453e+00, -2.66798401e+00],
        [ 6.26460910e-01,  1.27024293e-01, -2.02730775e-01],
        [ 4.00517225e-01,  2.01630592e-02,  1.46999562e+00],
        [ 7.20139325e-01,  8.67764413e-01, -8.50621104e-01],
        [ 2.44603157e-01, -2.63929665e-01,  1.98740399e+00],
        [ 2.22503752e-01, -1.93000138e-01,  6.05675638e-01],
        [ 2.08673835e+00, -8.97365570e-01, -2.28261971e+00],
        [ 1.10628605e-02,  3.16338837e-02,  6.65226817e-01],
        [-7.22743034e-01,  6.38659716e-01, -2.51428461e+00],
        [-6.73053145e-01, -1.0958

In [106]:
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[ 9.13494229e-01,  1.62008390e-01],
         [-1.84427118e+00, -8.33066463e-01],
         [-2.32583857e+00,  1.15421593e-01]],

        [[ 1.87759459e+00,  2.59001136e-01],
         [-1.73462164e+00,  2.07481042e-01],
         [-1.85318542e+00,  1.85181355e+00]],

        [[ 1.26763666e+00, -9.55591947e-02],
         [ 7.72508830e-02,  3.61488700e-01],
         [-3.51049006e-02,  5.62028289e-01]],

        [[-3.71247482e+00,  1.78827703e+00],
         [-2.31971574e+00, -3.87682676e-01],
         [-6.59653127e-01, -6.70299113e-01]],

        [[-1.67426038e+00,  8.02883565e-01],
         [ 1.10971260e+00,  5.13852060e-01],
         [-1.98599339e+00, -1.10102558e+00]],

        [[ 3.66988719e-01, -1.25605911e-01],
         [ 2.04468757e-01, -1.75998342e+00],
         [-1.30405247e-01, -2.50981307e+00]],

        [[-4.07195282e+00,  1.06365347e+00],
         [-2.10613799e+00, -8.48157406e-02],
         [ 8.90335143e-01, -3.741

In [107]:
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[-8.1413822e+00, -2.0560381e+01, -6.8728480e+00, -1.6053120e+01,
        -1.2413630e+01, -1.3763941e+01, -1.5909858e+01, -9.5815611e+00,
        -6.5777617e+00, -9.4956150e+00, -4.5660004e+01, -8.2327087e+01,
        -1.1063395e+01, -1.2021087e+01, -2.6454905e+01, -9.2079029e+00,
        -1.4191542e+02, -4.0135698e+00, -2.2167222e+00, -5.4535468e+02,
        -4.4422559e+02, -4.5959830e+02, -6.4909941e+03, -6.9595919e+00,
        -1.5370132e+02, -6.9595399e+00, -5.4321966e+00, -5.0810943e+00,
        -1.7097773e+01, -9.5624435e+01, -3.6806103e+05, -6.8529625e+01,
        -4.9708063e+02, -1.4691078e+03, -4.0493314e+02, -1.7346703e+03,
        -2.8546425e+01, -6.5721222e+02, -3.5570786e+02, -9.2264969e+01,
        -1.6620876e+01, -2.2986349e+01, -5.8477703e+01, -1.1603130e+03,
        -1.6184973e+01, -7.5649304e+00, -5.7171733e+03, -9.4572418e+01,
        -3.3858552e+03, -1.6068682e+01, -6.4744541e+03, -3.6672412e+02,
        -1.7682