In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from importlib import reload

In [2]:
import particle_decoding as pd

In [3]:
# Development convenience: re-execute the local particle_decoding module (aliased
# `pd` here -- not pandas) so edits to the .py file are picked up without a kernel
# restart. On a fresh Restart & Run All this is redundant, since the module was
# imported in the previous cell.
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## Transformations

In [4]:
# Two identical frames of three particles each -> array of shape (2, 3, 3), float64.
# The single frame is written once and repeated instead of being typed out twice.
test_coords = np.array([[[1.0, 1.0, 1.0], [-1.0, 1.0, 0.0], [0.5, 0.5, 1.0]]] * 2)

In [5]:
# Sanity check: the identity transform should hand back the input coordinates
# unchanged (returned as a tf.Tensor, per the displayed output).
pd.identity_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]],

       [[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]]])>

In [6]:
# Cartesian -> spherical. From the displayed values the last axis looks like
# (r, azimuthal angle, polar angle), e.g. (1, 1, 1) -> (sqrt(3), pi/4, 0.955)
# -- TODO confirm the ordering against particle_decoding.spherical_transform.
pd.spherical_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [7]:
# Same transform, but supplying precomputed squared distances through dist_sq;
# the result should be identical to the previous cell.
pd.spherical_transform(test_coords, dist_sq=tf.reduce_sum(test_coords * test_coords, axis=-1))

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [8]:
# Round trip: applying the transform with reverse=True to spherical coordinates
# should recover test_coords up to floating-point error (note the ~1e-17 entry).
pd.spherical_transform(pd.spherical_transform(test_coords), reverse=True)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]],

       [[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]]])>

## Masking by distance

In [9]:
# Particle configuration for the distance-mask tests: the same three-particle
# frame repeated for both batch entries -> shape (2, 3, 3).
test_coords = np.array([[[1.0, 1.0, 1.0], [-1.0, 1.0, 0.0], [0.5, 0.5, 1.0]]] * 2)

# One reference point per batch entry -> shape (2, 1, 3): the origin for the
# first entry and (0, 0, 1) for the second.
ref_coords = np.array([[[0.0, 0.0, z_ref]] for z_ref in (0.0, 1.0)])

In [10]:
# Distance mask built with default arguments; no box_lengths is given here, so
# presumably no periodic wrapping is applied -- TODO confirm the default.
distance_mask = pd.DistanceMask()

In [11]:
# Select the 2 nearest particles to each reference point. Judging by the output,
# the returned tuple is (coordinates relative to the reference, indices into
# test_coords, squared distances), ordered nearest first.
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [12]:
# Chain the two pieces: take the local (reference-centred) neighbour coordinates
# and convert them to spherical form, reusing the squared distances already
# computed by the mask rather than recomputing them.
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [13]:
# Test the case where the reference point is itself one of the input particles:
# the references below are copies of test_coords[0, 1] and test_coords[1, 0].
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

# With ref_included=True the zero-distance self match should be skipped, so the
# nearest neighbour returned is a different particle (index 2 in both entries).
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[ 1.5, -0.5,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[2],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[3.5],
        [0.5]])>)

In [14]:
# Counter-example: with ref_included=False the reference particle is not excluded,
# so the "nearest neighbour" is the reference itself (zero offset, zero distance).
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

In [15]:
# Inputs for the periodic-simulation-box distance tests: the same three-particle
# frame repeated for both batch entries -> shape (2, 3, 3).
test_coords = np.array([[[1.0, 1.0, 1.0], [-1.0, 1.0, 0.0], [0.5, 0.5, 1.0]]] * 2)

# One reference point per batch entry -> shape (2, 1, 3).
ref_coords = np.array([[[0.0, 0.0, z_ref]] for z_ref in (0.0, 1.0)])

# Rebind distance_mask to a version that knows the box edge lengths (3 x 3 x 3),
# so distances are presumably computed with periodic wrapping from here on.
# Note this replaces the default mask created earlier under the same name.
distance_mask = pd.DistanceMask(box_lengths=np.array([3.0, 3.0, 3.0]))

In [16]:
# All of these neighbours lie within half a box length of the references, so the
# result is expected to match the non-periodic case exactly.
distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [17]:
# Same mask -> spherical chain as before, now using the periodic distance mask.
local_coords, near_inds, dist_sq = distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

In [18]:
# Repeat the "reference is one of the particles" test with the periodic mask;
# the references are copies of test_coords[0, 1] and test_coords[1, 0].
ref_coords = np.array([[[-1.0, 1.0, 0.0]],
                       [[1.0, 1.0, 1.0]]])

# Here periodicity matters: for the first entry the offset to particle 0 is
# (2, 0, 1), which appears to be wrapped to (-1, 0, 1) in the 3-unit box, making
# particle 0 (not particle 2) the nearest neighbour.
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=True)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[-1. ,  0. ,  1. ]],
 
        [[-0.5, -0.5,  0. ]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[0],
        [2]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[2. ],
        [0.5]])>)

In [19]:
# Without ref_included=True the reference matches itself again (zero distance),
# just as in the non-periodic case.
distance_mask(ref_coords, test_coords, k_neighbors=1, ref_included=False)

(<tf.Tensor: shape=(2, 1, 3), dtype=float64, numpy=
 array([[[0., 0., 0.]],
 
        [[0., 0., 0.]]])>,
 <tf.Tensor: shape=(2, 1), dtype=int32, numpy=
 array([[1],
        [0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
 array([[-0.],
        [-0.]])>)

## Creating probability distributions

In [20]:
# Hand-written distribution parameters, shape (2, 3, 3, 2): the last axis holds the
# two parameters of each 1-D distribution, used as (mean, standard deviation) in
# the manual log-probability check further down.
params = np.array([[[[0.0, 1.0], [10.0, 1.0], [-10.0, 1.0]],
                    [[0.0, 1.0], [0.0, 2.0], [0.0, 10.0]],
                    [[5.0, 1.0], [5.0, 2.0], [5.0, 10.0]]],
                   [[[0.0, 1.0], [-10.0, 1.0], [10.0, 1.0]],
                    [[0.0, 10.0], [0.0, 2.0], [0.0, 1.0]],
                    [[-5.0, 10.0], [-5.0, 2.0], [-5.0, 1.0]]]])

# One Normal per coordinate, each with a pass-through mapping from the two raw
# parameters to the distribution's arguments.
dist_list = [tfp.distributions.Normal]*3
param_trans = [lambda x, y: [x, y]]*3

In [21]:
# Build the joint distribution over the three coordinates from the raw parameters,
# the per-coordinate distribution classes and the parameter transforms.
joint_dist = pd.create_dist(params, dist_list, param_trans)

In [22]:
# sample() yields one tensor per coordinate; stack them on a new last axis to get
# a (2, 3, 3) coordinate tensor. No seed is set, so values change between runs.
sample = tf.stack(joint_dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[  0.23292805  12.04493878  -7.83952683]
  [ -1.18649911   0.29522801   8.1184565 ]
  [  3.6817518    3.12434873  -7.36771411]]

 [[ -0.35549221 -11.37528835  10.143287  ]
  [  2.59537173  -0.31698058   0.34192527]
  [ -6.74109612  -5.2156615   -4.37855254]]], shape=(2, 3, 3), dtype=float64)


In [23]:
# log_prob expects the per-coordinate list form, so undo the stack first.
# The result has one log-probability per (batch, particle) -> shape (2, 3).
joint_dist.log_prob(tf.unstack(sample, axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-7.2086528 , -6.79687957, -7.82599726],
       [-3.77597756, -5.85724368, -5.96661716]])>

In [24]:
# Independent check of joint_dist.log_prob for the last batch entry, particle 1:
# sum the Normal log-densities of its three coordinates by hand, using
# params[..., 0] as the mean and params[..., 1] as the standard deviation.
manual_log_prob = np.sum(-0.5*((sample[-1, 1, :] - params[-1, 1, :, 0])**2)/(params[-1, 1, :, 1]**2) - np.log(params[-1, 1, :, 1]) - 0.5*np.log(2.0*np.pi))

# Assert agreement with the library value instead of comparing printed numbers by
# eye, so a regression in create_dist/log_prob fails loudly on Run All.
np.testing.assert_allclose(
    manual_log_prob,
    np.asarray(joint_dist.log_prob(tf.unstack(sample, axis=-1))[-1, 1]),
)
manual_log_prob

-5.857243675854347

## Solvation neural networks

In [25]:
# The network input can be either solute or solvent coordinates, with the other
# species supplied as extra inputs.
# Start without augmented input (augment_input=False); this is only relevant for the
# first network, which predicts the layer of solvent around the first solute.
# The first argument (8, 3) matches the leading output shape seen below: 8 particles
# with 3 coordinates each, and out_event_dims=2 parameters per coordinate.
solute_net = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=False)

In [26]:
# Random solute coordinates: batch of 1, two particles, 3-D.
# An op-level seed makes this input reproducible after Restart & Run All.
# NOTE(review): no seed is set for the network weights in this notebook, so the
# network outputs below may still differ between runs.
solute_coords = tf.random.normal((1, 2, 3), seed=0)

In [27]:
# Called without sampled_input the network returns two tensors of shape
# (1, 8, 3, 2): base parameters and shifts.
# Caution: this rebinds `params`, which previously held the hand-written NumPy
# array used in the probability-distribution section; re-running those earlier
# cells after this one would pick up the network output instead.
params, shifts = solute_net(solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.02379658  0.24545582]
   [ 0.12956512  0.27409476]
   [ 0.11925282  0.02949036]]

  [[-0.22070679  0.308375  ]
   [ 0.10471827  0.09005634]
   [ 0.06206036  0.31988555]]

  [[-0.04578574  0.01402279]
   [-0.156299    0.06304955]
   [-0.04154722  0.04393618]]

  [[ 0.35807756 -0.08217393]
   [-0.13659167 -0.01435896]
   [-0.00223473 -0.072791  ]]

  [[ 0.12733865 -0.0460002 ]
   [-0.18548144  0.05036053]
   [ 0.17022954  0.09171958]]

  [[-0.18497781  0.125972  ]
   [-0.18594837 -0.12062389]
   [-0.05190141 -0.28748202]]

  [[ 0.07677729  0.10898297]
   [-0.03361616 -0.15803109]
   [ 0.00987837 -0.2932418 ]]

  [[ 0.13682924  0.02703234]
   [ 0.09476364  0.07638013]
   [ 0.11665244  0.31628913]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.3358397   0.25417396]
   [-0.16339779 -0.32258683]
   [ 0.06453422 -0.26169235]]

  [[ 0.03252108 -0.21226408]
   [-0.11045979  0.11668792]
   [ 0.23011714  0.3562886 ]]

  [[-0.44168344 -0.52185684]
   [-0.03945018 -0.138

In [28]:
# Turn the network output into Normal distributions: the first parameter is used as
# the mean and the second is treated as a log-variance (scale = exp(0.5 * y)),
# which keeps the scale positive for any network output.
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3, [lambda x, y: [x, tf.math.exp(0.5*y)]]*3)
# Stack the per-coordinate samples into a (1, 8, 3) coordinate tensor.
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[-1.4158545  -0.99324286  0.2953406 ]
  [ 0.03383498  0.53300834  2.5849373 ]
  [-0.18131882 -0.05752528  0.8517498 ]
  [-0.33809483  2.1218865   0.03103745]
  [-0.34894547  0.5718595   0.8472733 ]
  [-0.3696894   0.45907575  0.4821236 ]
  [ 1.3244584   0.5103633   0.70408094]
  [-0.29772598  0.98282194  0.45256186]]], shape=(1, 8, 3), dtype=float32)


In [29]:
# Autoregressive check: when a sample is fed back in, the shifts for the first
# particle (first block of 3) should be unchanged, because nothing precedes it.
# The shifts for all later particles should change, since they are conditioned on
# the sampled values of the particles before them.
# With sampled_input given, the call returns a single tensor (the shifts only).
new_shifts = solute_net(solute_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[-0.3358397   0.25417396]
   [-0.16339779 -0.32258683]
   [ 0.06453422 -0.26169235]]

  [[ 0.03375171 -0.20828189]
   [-0.09607562  0.12971741]
   [ 0.22670537  0.36193663]]

  [[-0.42428315 -0.5263384 ]
   [-0.04825987 -0.16737193]
   [ 0.2267499   0.28345013]]

  [[ 0.26253954 -0.30709198]
   [ 0.23626184  0.24931854]
   [-0.42569014  0.00420201]]

  [[-0.45593983  0.08344761]
   [ 0.2524996   0.22826236]
   [-0.25032923 -0.06403652]]

  [[ 0.14105138  0.60200083]
   [ 0.04630354  0.6925131 ]
   [ 0.13035977 -0.76964086]]

  [[ 0.27394184 -0.27118647]
   [ 0.8978176  -0.09691283]
   [ 0.15606922  0.02202904]]

  [[ 0.01618966 -0.50535595]
   [ 0.22068036 -0.8412704 ]
   [-0.23749122  0.17567205]]]], shape=(1, 8, 3, 2), dtype=float32)


In [30]:
# Same architecture, but with augment_input=True so that coordinates of the other
# species (here solvent) can be passed as extra input. The roles of solute and
# solvent could be swapped if desired.
solute_net_solv = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [31]:
# Random solvent coordinates: batch of 1, six particles, 3-D.
# An op-level seed makes this input reproducible after Restart & Run All.
# NOTE(review): no seed is set for the network weights in this notebook, so the
# network outputs below may still differ between runs.
solvent_coords = tf.random.normal((1, 6, 3), seed=1)

In [32]:
# First call of the augmented network: the solvent coordinates go in through
# extra_coords. Output shapes are the same as for the un-augmented network.
params, shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.7550958   0.1851994 ]
   [ 0.17669357  0.05058133]
   [ 0.18894506  0.1655793 ]]

  [[ 0.09594494 -0.36294055]
   [-0.14140862 -0.19461644]
   [ 0.33395585 -0.15586087]]

  [[-0.03108068  0.18690585]
   [-0.19554432  0.08475019]
   [-0.20662773  0.2642927 ]]

  [[-0.37536556  0.03253983]
   [ 0.20404223 -0.29153493]
   [-0.1704972  -0.3375806 ]]

  [[ 0.25607115 -0.02822217]
   [ 0.07041867  0.10393742]
   [ 0.06587229  0.09794677]]

  [[ 0.16255382 -0.27246606]
   [-0.3043923   0.21931075]
   [-0.3118653  -0.18846235]]

  [[ 0.06915081  0.11242154]
   [ 0.19008125  0.27000338]
   [-0.1255779  -0.07191899]]

  [[ 0.01585248 -0.11549129]
   [-0.07575282 -0.13416024]
   [ 0.16235493 -0.06734431]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-1.3780967e+00 -5.6949031e-01]
   [ 9.4148654e-01 -3.7595785e-01]
   [ 9.8116690e-01 -5.2568877e-01]]

  [[-2.5131774e-01  9.2313236e-01]
   [ 2.0212075e-01  5.7437831e-01]
   [-5.0662857e-01 -5.0069755e-01]]

  [[ 1.0708573e

In [33]:
# Build Normals from the augmented network's output (mean, log-variance -> scale
# via exp(0.5 * y)) and draw one sample, stacked to shape (1, 8, 3).
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3, [lambda x, y: [x, tf.math.exp(0.5*y)]]*3)
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[-0.38588247  2.20116     0.08851027]
  [-0.98316795 -0.8278707  -1.3412374 ]
  [ 1.618958   -2.5488262   0.8774866 ]
  [-0.95591927  0.31746793  1.2323648 ]
  [-0.28575385 -0.9622369  -1.4316034 ]
  [-0.65631396 -0.05193239 -0.248436  ]
  [-0.8922453   0.53043336  1.4878199 ]
  [ 2.496495   -0.03772695  1.9672512 ]]], shape=(1, 8, 3), dtype=float32)


In [34]:
# Autoregressive check for the augmented network: feeding the sample back in
# should leave the first particle's shifts unchanged and alter the rest.
new_shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[-1.3780967  -0.5694903 ]
   [ 0.94148654 -0.37595785]
   [ 0.9811669  -0.52568877]]

  [[-0.25789416  0.9507365 ]
   [ 0.21071029  0.5677661 ]
   [-0.5036067  -0.5149995 ]]

  [[ 1.0648761  -0.3518599 ]
   [-1.2339132   1.2255844 ]
   [-0.0669743   0.19580929]]

  [[-0.6965189   0.50652725]
   [-0.03802601 -0.6176303 ]
   [ 1.9548005  -0.9394873 ]]

  [[-0.3045347  -2.4599104 ]
   [ 0.01998843  0.29156917]
   [ 0.04698271  0.16596243]]

  [[-0.9606475   0.10015298]
   [ 0.6002346  -1.2442107 ]
   [-0.34869766 -0.5321461 ]]

  [[-0.3953134  -0.06503797]
   [ 0.78322333  0.7560538 ]
   [ 0.35978368  0.41703868]]

  [[ 1.3709139   0.78528464]
   [ 0.04287516  2.0449677 ]
   [ 1.9870697   0.12183425]]]], shape=(1, 8, 3, 2), dtype=float32)


In [35]:
# Once the layer has been called (and therefore built) with extra_coords, its
# input size is fixed, so extra_coords must be provided on every later call.
# Omitting it is expected to raise a ValueError about an incompatible input shape.
# The error is the point of this cell, so catch and display it rather than letting
# it abort a Restart & Run All.
try:
    new_shifts = solute_net_solv(solute_coords, sampled_input=sample)
except ValueError as err:
    print(f"ValueError: {err}")
else:
    print(new_shifts)

ValueError: Input 0 of layer dense_18 is incompatible with the layer: expected axis -1 of input shape to have value 24 but received input with shape (1, 6)

In [36]:
# Now pretend we are working with a central solvent particle and adding one more
# particle, conditioned on the solutes and solvent already around it.
# The target shape (1, 3) means a single particle with 3 coordinates is predicted.
solvent_net = pd.SolvationNet((1, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [37]:
# Roles flipped relative to the solute network: solvent coordinates are the main
# input and solute coordinates are the extra input. Outputs have shape (1, 1, 3, 2).
params, shifts = solvent_net(solvent_coords, extra_coords=solute_coords)
print(params, shifts)

tf.Tensor(
[[[[ 0.416312   -0.13859941]
   [ 0.29045895 -0.02378876]
   [-0.27068567  0.1401866 ]]]], shape=(1, 1, 3, 2), dtype=float32) tf.Tensor(
[[[[ 2.5410247   0.81326175]
   [-2.58177     1.0721965 ]
   [-0.01803738  0.7218741 ]]]], shape=(1, 1, 3, 2), dtype=float32)


## Decimation encoding

In [38]:
# Random test configuration: 2 batch entries, 6 particles, 3-D, standard normal.
# Drawn from a seeded Generator so the values (and the encoder outputs that are
# compared against them by eye below) are reproducible after Restart & Run All.
test_coords = np.random.default_rng(0).normal(size=(2, 6, 3))
test_coords

array([[[ 0.35197941, -1.77454827,  0.14737934],
        [-0.19495841, -1.2578131 ,  0.88655394],
        [-0.8343571 , -0.67848354, -0.28848646],
        [ 0.32884625,  1.53408741,  0.08089258],
        [-0.22151719, -1.08639542,  0.55394217],
        [ 0.76592815, -1.39207851, -0.1559114 ]],

       [[-0.11144715,  0.27861419, -2.3066409 ],
        [-2.67397693,  0.57243877,  0.1235203 ],
        [-0.14179341, -0.28693329, -0.80414923],
        [-1.07563638,  0.05536577,  1.02700245],
        [ 1.45412011, -0.82404985,  0.4734435 ],
        [ 1.05640824,  0.83340425,  0.47019862]]])

In [39]:
# Boolean mask over the 6 particles: True marks the particles kept by the encoder
# (every other particle here: indices 0, 2 and 4).
cg_mask = np.array([True, False, True, False, True, False])
encoder = pd.DecimationEncoder(cg_mask)

In [40]:
# Returns two float32 tensors: the particles selected by cg_mask (rows 0, 2, 4 of
# test_coords) and the remaining particles (rows 1, 3, 5).
encoder(test_coords)

(<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
 array([[[ 0.3519794 , -1.7745483 ,  0.14737934],
         [-0.8343571 , -0.67848355, -0.28848645],
         [-0.22151719, -1.0863954 ,  0.55394214]],
 
        [[-0.11144715,  0.2786142 , -2.3066409 ],
         [-0.1417934 , -0.2869333 , -0.8041492 ],
         [ 1.4541202 , -0.82404983,  0.4734435 ]]], dtype=float32)>,
 <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
 array([[[-0.19495842, -1.2578131 ,  0.88655394],
         [ 0.32884625,  1.5340874 ,  0.08089258],
         [ 0.76592815, -1.3920785 , -0.1559114 ]],
 
        [[-2.673977  ,  0.5724388 ,  0.12352029],
         [-1.0756364 ,  0.05536577,  1.0270025 ],
         [ 1.0564083 ,  0.83340424,  0.4701986 ]]], dtype=float32)>)

In [41]:
# With identical_randomize=True the order of the selected particles is shuffled
# (independently per batch entry, judging by the output), while the remaining
# particles keep their order. Note this rebinds `encoder`.
encoder = pd.DecimationEncoder(cg_mask, identical_randomize=True)
encoder(test_coords)

(<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
 array([[[-0.22151719, -1.0863954 ,  0.55394214],
         [-0.8343571 , -0.67848355, -0.28848645],
         [ 0.3519794 , -1.7745483 ,  0.14737934]],
 
        [[ 1.4541202 , -0.82404983,  0.4734435 ],
         [-0.11144715,  0.2786142 , -2.3066409 ],
         [-0.1417934 , -0.2869333 , -0.8041492 ]]], dtype=float32)>,
 <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
 array([[[-0.19495842, -1.2578131 ,  0.88655394],
         [ 0.32884625,  1.5340874 ,  0.08089258],
         [ 0.76592815, -1.3920785 , -0.1559114 ]],
 
        [[-2.673977  ,  0.5724388 ,  0.12352029],
         [-1.0756364 ,  0.05536577,  1.0270025 ],
         [ 1.0564083 ,  0.83340424,  0.4701986 ]]], dtype=float32)>)

## Solute decoding

In [42]:
# Solute decoder in a periodic 3 x 3 x 3 box, conditioning on the 3 nearest solute
# and 3 nearest solvent neighbours.
# NOTE(review): the leading 3 looks like the number of particles generated per
# solute (5 solutes -> 15 outputs below) -- confirm against SoluteDecoder.
solute_decoder = pd.SoluteDecoder(3,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             name='solutedecoder')

In [43]:
# Five solute particles per batch entry, placed uniformly inside the box
# (coordinates in [-1.5, 1.5), i.e. a 3-unit box centred on the origin).
# An op-level seed makes this input reproducible after Restart & Run All.
# NOTE(review): decoder weights and sampling are not seeded in this notebook, so
# the decoder outputs below may still differ between runs.
solute_coords = tf.random.uniform((2, 5, 3), minval=-1.5, maxval=1.5, seed=2)

In [44]:
# Generation mode (no train_data): the decoder returns the sampled coordinates, the
# distribution parameters, the reference coordinates for each sample, and any
# unused training data (None here).
sample_out, params_out, ref_out, unused_data = solute_decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 15, 3) (2, 15, 3, 2) (2, 15, 3)


In [45]:
# No training data was supplied, so nothing is left over: expected to print None.
print(unused_data)

None


In [46]:
# Offsets of the generated particles from their reference positions.
sample_out - ref_out

<tf.Tensor: shape=(2, 15, 3), dtype=float32, numpy=
array([[[ 0.8336817 ,  0.05688685,  1.3860383 ],
        [-0.97174877, -0.19506472,  2.8657384 ],
        [-1.3552151 ,  0.4493233 , -0.40065038],
        [-2.0063853 , -2.368535  ,  2.3537383 ],
        [-0.05119061, -0.16374892,  1.626596  ],
        [ 0.58612454, -1.2799978 ,  0.3651464 ],
        [ 2.9601038 ,  0.811373  ,  0.70613337],
        [ 0.17170656,  0.14409924, -0.22456977],
        [ 2.817066  , -1.2336887 ,  1.6163819 ],
        [-0.10634845,  1.6237172 , -0.15497398],
        [-0.14150959, -0.95371586, -2.0264423 ],
        [-1.0646422 , -0.94507337,  1.6430466 ],
        [-0.0923816 ,  0.8239846 ,  0.34679008],
        [-0.72767675, -0.73405445,  0.89950716],
        [-1.3501872 , -0.9117471 ,  0.70877534]],

       [[ 0.7251493 , -0.07881433, -1.2822386 ],
        [ 0.7912017 , -1.5737748 ,  1.441805  ],
        [-0.08033657,  0.66727054, -2.2722359 ],
        [ 1.8877407 , -0.8425393 , -0.41746005],
        [ 0.227

In [47]:
# Inspect the raw distribution parameters: shape (2, 15, 3, 2), i.e. two
# parameters per coordinate of each generated particle.
params_out

<tf.Tensor: shape=(2, 15, 3, 2), dtype=float32, numpy=
array([[[[ 0.47383893, -0.67874056],
         [ 0.03557665, -0.38291192],
         [-0.0337587 , -0.04436594]],

        [[-0.22422947, -1.2833833 ],
         [ 0.6337109 ,  0.7042277 ],
         [ 0.45558596,  0.44783542]],

        [[-0.6130725 , -0.0074392 ],
         [ 0.24627009, -0.24166892],
         [-0.42126945, -0.10587405]],

        [[ 1.3109969 ,  1.0063645 ],
         [ 1.8045746 ,  1.0621383 ],
         [ 1.0484841 ,  0.3242733 ]],

        [[-0.3869249 ,  0.5219101 ],
         [ 0.403221  ,  1.6266562 ],
         [ 0.88865775, -0.94309497]],

        [[ 0.8810882 , -1.9067734 ],
         [-1.027667  ,  0.6797792 ],
         [ 0.32283264, -0.28039154]],

        [[ 0.8133907 ,  0.19880359],
         [ 0.10626364,  0.2635614 ],
         [-0.39258283,  0.5138985 ]],

        [[ 0.8291457 ,  0.52107143],
         [ 0.24329543, -2.6577535 ],
         [ 0.27872142,  0.33230758]],

        [[ 2.0116036 ,  0.34650487],
    

In [48]:
# Convenience method on the decoder: log-probability of each generated particle
# given its parameters and reference position -> shape (2, 15).
solute_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 15), dtype=float32, numpy=
array([[-3.3854094, -5.7252116, -2.8832557, -9.590844 , -4.1237736,
        -2.3132644, -5.6857514, -2.1439962, -3.3392236, -3.1796653,
        -3.5966702, -4.4956026, -3.7993617, -4.1653123, -4.2372975],
       [-3.680969 , -4.3334446, -5.3502874, -3.9719794, -4.5625944,
        -4.911551 , -2.9716413, -3.1334872, -3.3870206, -2.8591685,
        -4.1119614, -6.565435 , -5.1597805, -5.0933022, -4.9748764]],
      dtype=float32)>

In [49]:
# The cells above tested pure generation; now test the path where training data is
# available. 60 candidate particle positions per batch entry.
# An op-level seed makes this input reproducible after Restart & Run All.
training_data = tf.random.normal((2, 60, 3), seed=3)

In [50]:
# Training mode: same call with train_data supplied. Output shapes are unchanged;
# unused_data now holds the training particles that were not consumed.
sample_out, params_out, ref_out, unused_data = solute_decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 15, 3) (2, 15, 3, 2) (2, 15, 3)


In [51]:
# Full training set, for comparison with unused_data and sample_out below.
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 0.26393723, -0.71674263,  0.77520674],
        [-0.32947597,  0.70190775, -0.05947131],
        [ 0.21751098,  0.19541635, -0.12415879],
        [ 1.1110301 ,  1.1321521 ,  0.3721803 ],
        [ 1.4552447 , -0.1140788 ,  1.9399314 ],
        [ 0.75434804,  0.4982361 , -0.15094593],
        [-0.5321035 ,  0.08645869,  0.318812  ],
        [ 0.22621667, -0.21715117,  0.73490256],
        [-0.6868557 ,  0.72648364, -0.7172816 ],
        [ 1.6075745 ,  1.4685405 ,  1.3330345 ],
        [-2.5688188 , -1.8809812 , -0.90459156],
        [ 1.0510867 ,  1.2612305 , -0.5535334 ],
        [ 0.44043976,  0.03686449,  0.38674033],
        [-2.2207112 ,  0.13744473,  1.1882442 ],
        [-0.38844374,  1.0599322 ,  0.41078016],
        [-0.19054541, -0.11730913,  1.2515503 ],
        [-1.4024612 ,  0.2902992 , -0.34164745],
        [-0.04007145,  0.576949  , -0.18520083],
        [-0.4225353 , -0.57950985,  0.05380371],
        [-0.95550

In [52]:
# Leftover training particles: 45 of the original 60 remain after 15 were used.
unused_data

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[ 0.26393723, -0.71674263,  0.77520674],
        [-0.32947597,  0.70190775, -0.05947131],
        [ 0.21751098,  0.19541635, -0.12415879],
        [ 0.75434804,  0.4982361 , -0.15094593],
        [-0.5321035 ,  0.08645869,  0.318812  ],
        [ 0.22621667, -0.21715117,  0.73490256],
        [-0.6868557 ,  0.72648364, -0.7172816 ],
        [ 1.6075745 ,  1.4685405 ,  1.3330345 ],
        [ 0.44043976,  0.03686449,  0.38674033],
        [-2.2207112 ,  0.13744473,  1.1882442 ],
        [-0.38844374,  1.0599322 ,  0.41078016],
        [-0.19054541, -0.11730913,  1.2515503 ],
        [-0.04007145,  0.576949  , -0.18520083],
        [-0.4225353 , -0.57950985,  0.05380371],
        [ 0.3273793 ,  0.44960517,  0.16199656],
        [ 0.1539331 ,  0.3426948 ,  0.02784622],
        [ 0.709939  , -0.78621745, -0.6908073 ],
        [-0.13660015,  0.2576262 ,  2.1143885 ],
        [ 0.18416835,  0.29799697, -3.6712215 ],
        [ 0.68248

In [53]:
# In training mode the "samples" appear to be particles taken from training_data
# (some components shifted by a box length of 3) -- compare with the listing above.
sample_out

<tf.Tensor: shape=(2, 15, 3), dtype=float32, numpy=
array([[[ 1.4552447 , -0.11407879, -1.0600686 ],
        [ 0.43118125,  1.1190188 , -0.90459156],
        [-0.02805418,  0.23880617, -1.4654468 ],
        [ 1.597539  ,  0.2902992 , -0.34164745],
        [ 0.11329991,  0.99487376,  0.85636604],
        [ 2.1023407 ,  0.8081109 ,  1.1681136 ],
        [ 1.1110301 ,  1.1321521 ,  0.3721803 ],
        [ 1.4536576 ,  1.0964034 ,  0.62047756],
        [ 1.0474855 ,  0.09206176,  1.0476453 ],
        [ 1.4293184 ,  0.71614265,  1.0434474 ],
        [ 2.0444906 , -0.30675924,  1.427661  ],
        [ 0.31054798,  0.15235573,  2.650246  ],
        [ 0.1481477 , -0.76727337, -0.44821805],
        [ 1.0510867 , -1.7387695 , -0.5535334 ],
        [ 1.172103  , -0.33534405, -0.273342  ]],

       [[-0.4542118 , -0.61438686,  0.6686751 ],
        [ 0.28118896, -0.7959853 ,  0.3252164 ],
        [-1.5043046 ,  0.99327993,  0.5628963 ],
        [ 0.7242004 ,  0.9174552 , -0.4594909 ],
        [ 2.261

In [54]:
# Offsets of the selected training particles from their reference positions.
sample_out - ref_out

<tf.Tensor: shape=(2, 15, 3), dtype=float32, numpy=
array([[[ 0.64286435, -0.2108908 ,  0.05854523],
        [-0.38119906,  1.0222068 ,  0.21402228],
        [-0.8404345 ,  0.14199416, -0.346833  ],
        [ 1.4539378 , -0.21824262, -0.16877693],
        [-0.03030127,  0.48633194,  1.0292366 ],
        [ 1.9587395 ,  0.29956907,  1.3409841 ],
        [ 0.18115687, -0.00314903, -0.17200819],
        [ 0.5237844 , -0.03889775,  0.07628906],
        [ 0.11761224, -1.0432394 ,  0.50345683],
        [ 0.6540146 ,  0.25021112, -0.25577652],
        [ 1.2691867 , -0.7726908 ,  0.12843704],
        [-0.46475586, -0.3135758 ,  1.351022  ],
        [-1.0826143 , -0.19147712,  0.15334135],
        [-0.17967534, -1.1629733 ,  0.04802603],
        [-0.05865896,  0.2404522 ,  0.3282174 ]],

       [[-0.14149234, -0.60829574,  0.12080115],
        [ 0.5939084 , -0.78989416, -0.22265756],
        [-1.1915852 ,  0.99937105,  0.01502234],
        [ 0.4462059 ,  0.20281982, -1.106412  ],
        [ 1.983

In [55]:
# Parameters predicted in training mode. The first particle's parameters match the
# generation run above; later ones differ, consistent with autoregressive
# conditioning on different preceding particles.
params_out

<tf.Tensor: shape=(2, 15, 3, 2), dtype=float32, numpy=
array([[[[ 0.47383893, -0.67874056],
         [ 0.03557665, -0.38291192],
         [-0.0337587 , -0.04436594]],

        [[-0.2621433 , -1.1468856 ],
         [ 0.61521506,  0.6563058 ],
         [ 0.45697862,  0.40357527]],

        [[-0.6014812 , -0.18545824],
         [ 0.26560807, -0.15702325],
         [-0.27442434, -0.15525849]],

        [[ 1.6751978 ,  0.10624464],
         [-0.2935787 , -0.07050999],
         [ 0.31004894, -0.32376567]],

        [[ 0.19791818,  1.7542708 ],
         [ 0.3655042 , -1.2175819 ],
         [ 0.72117615, -1.78944   ]],

        [[ 2.0440965 , -1.43951   ],
         [-0.12078054, -0.48031908],
         [ 0.9397254 , -1.2417763 ]],

        [[ 0.6648004 ,  0.1203116 ],
         [-0.1478462 ,  0.31266373],
         [-0.569482  ,  0.15464711]],

        [[-0.05259411, -0.24778342],
         [-0.02423684, -2.7753966 ],
         [-0.12280667,  0.5682157 ]],

        [[ 0.6960901 ,  1.6495652 ],
    

In [56]:
# Log-probability of the training particles under the predicted distributions.
solute_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 15), dtype=float32, numpy=
array([[-2.280964 , -2.7983038, -2.5543134, -2.7963214, -2.4436557,
        -1.612888 , -3.2296572, -1.7550983, -3.9355242, -3.5792947,
        -3.3002028, -3.1980584, -2.7390957, -3.6717882, -4.327573 ],
       [-2.6513963, -3.2145023, -3.272739 , -2.6364965, -2.6431005,
        -3.6270428, -2.9534717, -4.7176256, -1.3786598, -1.45429  ,
        -2.9807694, -4.0597796, -2.9642723, -2.3940837, -3.710791 ]],
      dtype=float32)>

## Solvent decoding

In [57]:
# Solvent decoder with the same box and neighbour settings as the solute decoder.
# NOTE(review): the meaning of the leading 2 is not evident from this notebook
# (15 inputs give 45 outputs below) -- check SolventDecoder's signature.
solvent_decoder = pd.SolventDecoder(2,
                             box_lengths=np.array([3.0, 3.0, 3.0]),
                             k_solute_neighbors=3,
                             k_solvent_neighbors=3,
                             name='solventdecoder')

In [58]:
# Keep copies of the solute decoder's results under new names, because sample_out
# and unused_data are rebound by the solvent decoder calls below.
# The 15 particles from the solute decoding step become the solvent input.
solvent_coords = tf.identity(sample_out)
training_data_unused = tf.identity(unused_data)

In [59]:
# Generation mode for the solvent decoder: solvent coordinates are the main input
# and the solutes are passed as context. This rebinds sample_out, params_out,
# ref_out and unused_data from the solute decoding section.
sample_out, params_out, ref_out, unused_data = solvent_decoder(solvent_coords, solute_coords=solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 45, 3) (2, 45, 3, 2) (2, 45, 3)


In [60]:
# Offsets of the generated solvent particles from their reference positions.
sample_out - ref_out

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[ 0.12343693, -0.4202201 , -0.34864056],
        [ 0.07669097,  0.426113  , -0.87373126],
        [-0.28532982,  0.9667836 ,  0.06998658],
        [ 1.9649742 ,  0.06711209,  0.5283185 ],
        [ 1.249219  , -0.6108271 , -0.8826252 ],
        [-1.2946079 ,  1.497864  ,  0.23090541],
        [-3.6773863 ,  0.29107153, -2.240098  ],
        [-0.39596987,  0.39447653,  0.52678525],
        [ 0.44296038,  1.1216705 ,  0.17155373],
        [ 0.6525836 , -1.4394284 ,  0.11215007],
        [ 2.3677952 ,  0.06342742,  1.6725798 ],
        [ 0.66939425, -1.0900638 ,  0.9279554 ],
        [ 1.029613  , -1.7301981 , -0.2502736 ],
        [-0.02132297, -0.17727828,  1.5651343 ],
        [-0.43776953, -0.51404643,  0.12131467],
        [-0.4731174 , -0.34618276, -1.3642681 ],
        [-0.47790334, -0.4624077 ,  0.52710664],
        [ 0.5938833 , -0.16325508,  4.1026382 ],
        [-0.33873653,  0.822685  , -0.61762524],
        [ 0.56883

In [61]:
# Raw distribution parameters from the solvent decoder: shape (2, 45, 3, 2).
params_out

<tf.Tensor: shape=(2, 45, 3, 2), dtype=float32, numpy=
array([[[[ 1.14011765e+00,  9.83528554e-01],
         [ 2.18420923e-01,  6.53838038e-01],
         [ 3.66929397e-02, -9.28745627e-01]],

        [[-9.01193857e-01, -1.01705924e-01],
         [ 4.41511348e-02, -1.08802450e+00],
         [-2.52516568e-02,  2.21947074e-01]],

        [[ 5.58661520e-02, -2.12526298e+00],
         [-7.17604280e-01, -5.74384451e-01],
         [-2.06811070e-01, -3.51178646e-02]],

        [[ 9.11037326e-01,  5.54280281e-01],
         [-1.24505013e-01, -7.11070657e-01],
         [ 1.53751388e-01, -4.18220520e-01]],

        [[-8.37231696e-01,  2.91448027e-01],
         [-6.99740499e-02, -3.28218758e-01],
         [-5.74121118e-01,  2.45969966e-02]],

        [[ 4.10268545e-01,  3.54031861e-01],
         [ 9.00114477e-01, -1.50641704e+00],
         [-2.50267088e-01,  6.18266016e-02]],

        [[-7.35000014e-01,  5.93128562e-01],
         [ 5.63981235e-01, -1.10470569e+00],
         [-5.35043001e-01,  4.518

In [62]:
# Convenience method on the decoder: log-probability of each generated solvent
# particle given its parameters and reference position -> shape (2, 45).
solvent_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 45), dtype=float32, numpy=
array([[-3.598393 , -3.3070903, -4.436056 , -2.9323316, -4.6265845,
        -4.146181 , -6.156593 , -3.2284634, -3.6166973, -3.4447887,
        -5.403331 , -4.9266815, -4.8836765, -2.7261915, -4.338925 ,
        -4.4963937, -2.6702323, -5.7534804, -3.7002516, -6.7871437,
        -2.6283526, -5.5147586, -3.8186643, -6.2947555, -5.2873406,
        -2.3010952, -3.346161 , -3.3937974, -3.0210865, -3.3522992,
        -3.76145  , -3.6697795, -4.2524896, -3.6685584, -2.7690055,
        -5.3313293, -3.9003813, -3.9239602, -6.8004456, -4.021591 ,
        -3.6852093, -3.7404125, -2.6353917, -3.3915076, -3.204545 ],
       [-3.6118438, -1.3173392, -3.3612902, -5.703704 , -6.0400887,
        -2.3099625, -3.3630376, -2.798106 , -3.1293657, -4.437932 ,
        -4.133655 , -4.977697 , -5.001131 , -2.4529667, -4.034671 ,
        -4.309471 , -2.0922344, -4.3303604, -5.7385616, -3.502195 ,
        -2.8904533, -2.8408077, -3.8040724, -3.377536 , -5.468847 

In [63]:
#Repeat the solvent decoding, this time passing training data
#(training_data_unused is what was left over from solute decoding)
sample_out, params_out, ref_out, unused_data = solvent_decoder(
    solvent_coords,
    solute_coords=solute_coords,
    train_data=training_data_unused,
)
#Only the shapes are printed; the tensors themselves are inspected in the following cells
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 45, 3) (2, 45, 3, 2) (2, 45, 3)


In [64]:
#Fourth return value of the decoder call; in the recorded output it is empty (shape (2, 0, 3)),
#presumably meaning that all of the supplied training data was consumed
unused_data

<tf.Tensor: shape=(2, 0, 3), dtype=float32, numpy=array([], shape=(2, 0, 3), dtype=float32)>

In [65]:
#Samples returned by the training-data pass, shape (2, 45, 3)
sample_out

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[ 2.8633997 ,  0.2576262 , -0.88561165],
        [ 0.05113849,  0.4466184 , -0.8329168 ],
        [-0.19054541, -0.11730912, -1.7484498 ],
        [ 2.4499388 , -0.27324826, -0.14391424],
        [-2.245652  ,  0.4982361 , -0.15094593],
        [ 2.4754367 ,  0.9312377 ,  0.5307437 ],
        [ 0.682483  ,  1.8729491 , -0.05705923],
        [ 0.5117935 ,  1.0830097 , -0.17820299],
        [ 1.6075745 ,  1.4685405 ,  1.3330345 ],
        [ 2.254748  ,  0.70257413,  2.04718   ],
        [ 2.3131444 ,  0.72648364,  2.2827184 ],
        [-0.5321035 ,  0.08645868,  3.3188121 ],
        [ 0.2281462 , -1.5717171 , -0.58806956],
        [ 1.1146951 , -2.9547322 , -0.13426077],
        [ 1.9772346 , -0.4640722 , -1.3439821 ],
        [ 0.917237  , -0.13660878, -2.050876  ],
        [ 0.18416834,  0.29799694, -0.6712215 ],
        [ 0.45407164,  1.402507  , -1.6716096 ],
        [ 1.9807253 , -0.49170157, -0.5292878 ],
        [ 0.04779

In [66]:
#Element-wise difference between the samples and ref_out
sample_out - ref_out

<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
array([[[ 1.4081551 ,  0.371705  ,  0.17445695],
        [-0.38004276, -0.67240036,  0.07167476],
        [-0.16249123, -0.35611528, -0.28300297],
        [ 0.8523998 , -0.5635475 ,  0.19773321],
        [-2.3589518 , -0.49663767, -1.0073119 ],
        [ 0.373096  ,  0.1231268 , -0.6373699 ],
        [-0.42854708,  0.74079704, -0.42923954],
        [-0.94186413, -0.01339364, -0.79868054],
        [ 0.560089  ,  1.3764788 ,  0.28538918],
        [ 0.8254297 , -0.01356852,  1.0037326 ],
        [ 0.26865387,  1.033243  ,  0.8550575 ],
        [-0.8426515 , -0.06589705,  0.6685662 ],
        [ 0.07999849, -0.8044438 , -0.13985151],
        [ 0.06360841, -1.2159626 ,  0.4192726 ],
        [ 0.80513155, -0.12872815, -1.0706401 ],
        [-0.5380077 , -0.02252999, -0.9908073 ],
        [-0.24701291, -0.82102185,  0.23337007],
        [ 0.48212582,  1.1637008 , -0.20616281],
        [ 0.38318634, -0.7820008 , -0.18764037],
        [-0.06550

In [67]:
#Raw decoder parameters, shape (2, 45, 3, 2): a pair of values for each of the 3 coordinates
params_out

<tf.Tensor: shape=(2, 45, 3, 2), dtype=float32, numpy=
array([[[[ 1.14011765e+00,  9.83528554e-01],
         [ 2.18420923e-01,  6.53838038e-01],
         [ 3.66929397e-02, -9.28745627e-01]],

        [[-6.96684122e-01,  3.23532224e-01],
         [-5.47349632e-01, -1.10962105e+00],
         [ 1.55373454e-01, -5.36144495e-01]],

        [[-4.49221969e-01, -8.42945874e-01],
         [-6.22938946e-02,  1.85685351e-01],
         [-3.11588883e-01,  7.05880105e-01]],

        [[ 9.11037326e-01,  5.54280281e-01],
         [-1.24505013e-01, -7.11070657e-01],
         [ 1.53751388e-01, -4.18220520e-01]],

        [[-2.15003276e+00,  1.02504003e+00],
         [-2.38875061e-01, -7.05369353e-01],
         [-8.95110250e-01,  5.42599857e-01]],

        [[-3.34204853e-01,  9.75678980e-01],
         [ 2.63719678e-01, -1.74253237e+00],
         [-5.67234635e-01, -1.66913748e-01]],

        [[-6.48261726e-01,  3.99543881e-01],
         [ 6.75787508e-01, -1.23136199e+00],
         [-3.49264264e-01,  3.388

In [68]:
#Log-probabilities for the training-data pass, for comparison with the
#get_log_probs output shown before the training data was supplied
solvent_decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 45), dtype=float32, numpy=
array([[-3.1546907, -2.161677 , -2.9126778, -2.6680114, -3.266698 ,
        -2.4435747, -2.536028 , -3.080057 , -4.072207 , -2.5719955,
        -3.7549706, -2.3955295, -3.8903341, -1.7808449, -3.6026013,
        -2.3721638, -3.431187 , -3.3825603, -2.978187 , -3.4649224,
        -2.961567 , -2.789726 , -3.307129 , -3.1924644, -3.4683175,
        -2.6581912, -2.7457504, -3.2488227, -3.3771925, -2.9494865,
        -3.570778 , -3.171203 , -4.315799 , -3.996235 , -3.1504748,
        -3.8764422, -4.4339843, -3.0799222, -3.228052 , -3.8990607,
        -5.1890287, -3.3028483, -3.3736608, -3.1985273, -4.061404 ],
       [-2.456504 , -1.4426608, -2.4327419, -2.1717885, -2.5091634,
        -2.0167367, -2.5612185, -2.9819584, -2.38617  , -3.5076218,
        -3.6467805, -3.3283474, -3.0994244, -2.353569 , -2.9181237,
        -3.5591607, -3.001525 , -4.320714 , -3.0254698, -3.8415453,
        -2.9350722, -2.38406  , -3.2761126, -3.0180025, -3.5756106

## Full decoding with more complex distributions

In [69]:
#Second pass with a more complicated coordinate transform (pd.spherical_transform),
#which in turn needs a more complicated set of sampling distributions:
#a Gaussian is not wanted for r, theta and phi, so use Gamma for r and von Mises for the angles
#Each transform maps the two raw values (x, y) of a coordinate to distribution parameters
def _gamma_params(x, y):
    """Convert x = log(mean), y = log(var) into Gamma parameters.

    Returns [mean**2 / var, mean / var], evaluated in log space; these are the
    concentration and rate of a Gamma with that mean and variance
    (assumes the decoder hands the list to tfp.distributions.Gamma in that order).
    """
    return [tf.math.exp(2*x - y), tf.math.exp(x - y)]


def _von_mises_params(x, y):
    """Pass x through unchanged and map y to the strictly positive value exp(-y)."""
    return [x, tf.math.exp(-y)]


decoder = pd.SoluteSolventDecoder(
    3, 2,
    box_lengths=np.array([3.0, 3.0, 3.0]),
    k_solute_neighbors=3,
    k_solvent_neighbors=3,
    tfp_dist_list=[
        tfp.distributions.Gamma,
        tfp.distributions.VonMises,
        tfp.distributions.VonMises,
    ],
    param_transforms=[_gamma_params, _von_mises_params, _von_mises_params],
    #The Gamma mean is strictly positive, so the network works with log(mean)
    #and the mean transform for r has to be exp; the angles are left untouched
    mean_transforms=[tf.math.exp, tf.identity, tf.identity],
    coord_transform=pd.spherical_transform,
)

In [70]:
#Random solute coordinates of shape (2, 5, 3) - presumably (batch, particles, xyz) - drawn from
#[-1.5, 1.5), an interval of width 3.0 like the box_lengths given to the decoder above
solute_coords = tf.random.uniform((2, 5, 3), minval=-1.5, maxval=1.5)

In [71]:
#Generation-only call: no train_data is supplied. Only the output shapes are printed here
sample_out, params_out, ref_out, unused_data = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [72]:
#Element-wise difference between the generated samples and ref_out
#NOTE(review): several rows of the recorded output are exactly zero, i.e. sample == ref for those entries
sample_out - ref_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 1.98708177e-02,  1.52838826e-02,  5.14532328e-02],
        [ 6.15119934e-05, -6.53147697e-04, -2.45690346e-04],
        [ 1.59545076e+00, -3.26237202e-01,  4.28077030e+00],
        [-1.19162664e+01,  1.56795990e+00, -6.72304153e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-9.31804955e-01,  8.90364170e-01,  2.65949607e-01],
        [ 2.35737157e+00,  6.04532528e+00, -1.52590621e+00],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-5.28097153e-05, -1.43051147e-05,  4.20808792e-05],
        [-2.01803017e+00,  6.52577400e-01, -1.26085396e+01],
        [-7.65323639e-05, -7.14063644e-05,  4.95463610e-05],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-3.05266261e+00, -2.08868861e-01,  5.27312338e-01],
        [-2.69567966e-03,  4.35059667e-02, -5.49995899e-03],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.0000

In [73]:
#Raw decoder parameters for the generation pass, shape (2, 60, 3, 2)
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-1.7425725e+00,  9.5725238e-01],
         [-2.0196581e+00,  2.7478486e-02],
         [-5.8626974e-01, -1.4174404e+00]],

        [[-1.8022878e+00,  2.3772842e-01],
         [-4.4621280e-01, -2.1322260e+00],
         [ 2.4439478e+00, -1.1399449e+00]],

        [[ 2.2930560e-01,  9.3413764e-01],
         [-2.7992318e+00,  2.4182706e+00],
         [-2.2869251e+00,  1.0889691e-01]],

        [[ 2.6567101e+00,  3.9586645e-01],
         [-5.6530535e-01,  2.7255580e-01],
         [-4.1942373e-01,  1.9774427e+00]],

        [[-1.2786601e+00,  3.6020100e-01],
         [ 3.7986729e+00, -1.2263867e+00],
         [ 5.5587912e-01, -2.2901921e+00]],

        [[ 7.9181957e-01, -8.5271347e-01],
         [ 3.2516031e+00,  1.2087592e+00],
         [-1.5857506e+00, -4.7505856e-01]],

        [[ 1.8036314e+00,  2.6070235e+00],
         [-1.7677748e+00, -1.6411619e+00],
         [ 5.9989467e-02,  1.0123758e+00]],

        [[-3.4271808e+00,  1

In [74]:
#For convenience, added log_prob function to decoder
#NOTE(review): the recorded output contains inf entries. In the first batch element they sit at
#the same indices (4, 7, 11, 14, 15, ...) as the exactly-zero rows of sample_out - ref_out shown
#above, so they presumably come from evaluating the Gamma density at r = 0 - confirm in particle_decoding
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[ -6.857469  ,  -1.5366277 ,  -8.379789  ,  -5.784135  ,
                 inf,  -6.9292893 , -14.660288  ,          inf,
          3.2428854 ,  -4.4781485 ,   6.4430594 ,          inf,
         -6.5302186 ,   0.85082847,          inf,          inf,
                 inf,          inf,          inf,  -4.7128563 ,
                 inf,          inf, -21.641037  ,  -3.4539216 ,
          4.488715  ,  -3.4520092 ,          inf,  -7.097623  ,
                 inf,          inf,          inf,  -0.9532613 ,
         -9.289783  ,          inf,  -2.8976188 ,          inf,
          0.13500744, -45.097733  ,  -4.0170536 ,  -3.8979442 ,
         -1.2809662 ,   1.6974928 ,  -2.182263  ,          inf,
         -4.710357  ,          inf,  -1.8490217 , -12.234593  ,
                 inf, -11.2926235 ,          inf, -12.904577  ,
          0.47254753,  -4.980419  ,  -1.7825274 ,  -1.1877342 ,
         -1.019699  ,  -9.087112  ,          inf,  -1.9

In [75]:
#The cells above tested pure generation; now test the case where training data is supplied
#Standard-normal random values with the same (2, 60, 3) shape as the decoder's sample output
training_data = tf.random.normal((2, 60, 3))

In [76]:
#Same decoder call as in the generation test, now with train_data supplied; print the output shapes
sample_out, params_out, ref_out, unused_data = decoder(solute_coords, train_data=training_data)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [77]:
#Show the training data so it can be compared with sample_out in the next cell
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-6.89380407e-01,  7.60281801e-01, -6.06886744e-01],
        [-4.58496511e-01, -1.02989078e+00,  2.32431844e-01],
        [ 4.71399724e-01, -1.04713678e+00,  2.35186696e-01],
        [-4.92227860e-02, -2.01346421e+00,  6.74899280e-01],
        [-2.57214379e+00,  9.10236597e-01, -1.26497358e-01],
        [-9.05323029e-01, -2.60664463e-01, -1.21049571e+00],
        [ 9.90537524e-01, -2.04775095e+00,  2.54041731e-01],
        [-8.74205709e-01,  2.14358211e-01,  7.16050148e-01],
        [-1.33341122e+00,  1.29978601e-02, -5.24770498e-01],
        [ 8.40752780e-01, -5.61721921e-01,  7.93349326e-01],
        [-1.69597113e+00, -1.28753245e+00,  1.14234909e-01],
        [ 1.85985819e-01, -7.48355865e-01,  1.28128612e+00],
        [-1.75291240e-01, -4.21844035e-01,  1.87200093e+00],
        [-1.16111755e+00,  1.10491669e+00, -1.14852309e+00],
        [-3.67780924e-01,  2.24724784e-01,  1.31746510e-03],
        [-2.39143163e-01,  1.3159

In [78]:
#Samples returned when training data is supplied
#In the recorded output some rows reproduce rows of training_data in a different order, in places
#with one coordinate shifted by 3.0 (the box length) - presumably periodic wrapping; other rows
#(e.g. those with values above 10) do not appear in training_data
sample_out

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 1.85985804e-01, -7.48355865e-01, -1.71871376e+00],
        [ 3.23065519e-01, -3.84563178e-01, -2.02125072e+00],
        [ 1.45065701e+00, -4.56527710e-01, -2.44822788e+00],
        [-1.17947130e+01,  3.42462540e+01,  3.03777313e+01],
        [ 3.82792711e-01, -1.45542312e+00, -1.32682168e+00],
        [ 1.03104994e-01, -5.02335310e-01, -2.49341631e+00],
        [-6.45849609e+00,  1.97010827e+00,  1.22324314e+01],
        [ 1.11949718e+00,  1.25107849e+00,  1.51688957e+00],
        [ 1.30402875e+00,  1.71246767e+00,  1.14235044e-01],
        [-2.57214379e+00,  9.10236418e-01,  2.87350273e+00],
        [-1.23707461e+00,  1.20487750e+00, -1.95867702e-01],
        [-7.69096613e-02,  1.92076159e+00, -3.14861000e-01],
        [-1.01611176e+01, -7.89508343e+00, -1.14852333e+00],
        [-1.27070773e+00, -7.96456635e-01, -1.81755066e+00],
        [-9.05323029e-01, -2.60664463e-01, -1.21049571e+00],
        [ 2.16467857e+00, -1.2678

In [79]:
#Apply the decoder's coordinate transform (pd.spherical_transform was passed as coord_transform
#when the decoder was built) to the sample-minus-reference differences
decoder.coord_transform(sample_out - ref_out)

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[ 3.4117663e-01, -2.4963977e+00,  2.4639339e+00],
        [ 6.1598760e-01,  1.7136718e+00,  2.7456391e+00],
        [ 1.4878110e+00,  1.4811143e-01,  2.3036225e+00],
        [ 4.9142471e+01,  1.8996463e+00,  8.6637896e-01],
        [ 3.0024350e-01, -1.2603278e+00,  1.1610138e+00],
        [ 1.2695262e+00,  1.8466389e+00,  2.5404627e+00],
        [ 1.3534561e+01,  3.0284836e+00,  6.3160497e-01],
        [ 4.5603713e-01,  2.6740046e+00,  1.0973229e+00],
        [ 1.3689756e+00,  1.8399991e+00,  2.6315098e+00],
        [ 3.5048161e+00, -2.8260128e+00,  4.7597253e-01],
        [ 2.8340986e-01, -2.3252399e+00,  1.4082073e+00],
        [ 1.0980991e+00,  4.8633477e-01,  1.6374304e+00],
        [ 1.1538096e+01, -2.4808819e+00,  1.5499315e+00],
        [ 4.8079857e-01,  3.0522530e+00,  2.6699421e+00],
        [ 6.0177737e-01,  1.3106543e+00,  1.2692082e+00],
        [ 2.5271990e+00, -2.5673813e-01,  9.4328493e-01],
        [ 4.3354490e

In [80]:
#Raw decoder parameters for the training-data pass, shape (2, 60, 3, 2)
#In the recorded output the first entry equals the one from the generation pass, while later
#entries differ - presumably because later parameters depend on the earlier samples (TODO confirm)
params_out

<tf.Tensor: shape=(2, 60, 3, 2), dtype=float32, numpy=
array([[[[-1.7425725 ,  0.9572524 ],
         [-2.019658  ,  0.02747849],
         [-0.58626974, -1.4174404 ]],

        [[-1.6357486 ,  0.10746148],
         [-0.39627022, -2.1319287 ],
         [ 2.4187899 , -1.1885003 ]],

        [[ 0.17605668,  1.1044862 ],
         [-2.7556765 ,  2.4561806 ],
         [-2.2293813 ,  0.17078829]],

        [[ 3.8927724 ,  0.83543086],
         [-1.2421163 ,  0.46113902],
         [-0.86446327,  2.4395235 ]],

        [[-2.1042595 , -0.02943283],
         [ 3.260982  , -1.3755053 ],
         [ 0.7570095 , -3.3565001 ]],

        [[-0.5521876 , -1.8833822 ],
         [ 5.230797  ,  1.1775465 ],
         [-2.7491872 , -0.01476082]],

        [[ 2.5854356 ,  0.37312967],
         [-0.08033901, -0.98602384],
         [-0.6345459 ,  1.7833376 ]],

        [[-2.2759328 , -0.39029336],
         [ 1.5117419 , -0.8272902 ],
         [ 0.3529255 , -2.9683723 ]],

        [[ 0.23744494, -1.2642412 ],
    

In [81]:
#Log-probabilities for the training-data pass; every entry visible in the recorded output is
#finite, unlike the generation pass, which produced inf values
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[ -13.10688  ,  -16.233967 ,   -5.948166 ,   -5.7599726,
          -9.587238 ,   -5.3962607,   -8.79209  ,  -10.026943 ,
          -3.7243218,   -5.0374002, -371.54346  ,   -4.9213233,
          -5.08078  ,   -8.202512 ,   -7.469711 ,   -5.8130693,
         -12.130564 ,   -9.050789 ,   -9.449261 ,   -4.6189737,
          -6.082814 ,  -11.237307 ,   -6.732051 ,   -5.954458 ,
          -4.9476233,   -2.331057 ,  -19.707373 ,   -8.95641  ,
          -1.2625558,   -1.9058256,   -5.3416824,  -11.188383 ,
         -14.357221 ,   -1.9820458,   -7.312926 ,   -4.6129093,
          -9.591605 ,   -8.480754 ,   -7.0285444,  -11.048829 ,
          -7.042932 ,  -10.172918 ,   -7.3640385,  -11.905553 ,
          -9.725476 ,  -10.289483 ,   -4.5743427,   -8.695995 ,
          -3.4493465,  -19.588467 ,  -11.772694 ,   -8.468766 ,
          -8.956662 ,   -7.0034747,  -18.75544  ,  -11.979624 ,
         -15.685166 ,  -16.816761 ,  -13.684622 ,  -11.

In [82]:
#Re-import particle_decoding so that edits made to the module on disk are picked up
#Objects built before this point (e.g. decoder) still use the previously loaded classes
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## VAE model

In [93]:
#Uniform random coordinates in [-1, 1), an interval of width 2.0 like the box lengths below
training_data = tf.random.uniform((2, 60, 3), minval=-1.0, maxval=1.0)

#Keyword arguments forwarded to the decoder via decoder_kwargs: Gamma for the first coordinate,
#von Mises for the other two, with the spherical coordinate transform
#Gamma transform: x = log(mean), y = log(var) -> [mean**2 / var, mean / var]
#Gamma mean transform is exp because the mean must be strictly positive
decoder_specs = dict(
    box_lengths=np.array([2.0, 2.0, 2.0]),
    tfp_dist_list=[
        tfp.distributions.Gamma,
        tfp.distributions.VonMises,
        tfp.distributions.VonMises,
    ],
    param_transforms=[
        lambda x, y: [tf.math.exp(2*x - y), tf.math.exp(x - y)],
        lambda x, y: [x, tf.math.exp(-y)],
        lambda x, y: [x, tf.math.exp(-y)],
    ],
    mean_transforms=[tf.math.exp, tf.identity, tf.identity],
    coord_transform=pd.spherical_transform,
)

#The VAE is given the per-sample shape (batch axis dropped) and 2 as its second argument
vae = pd.PriorFlowSolventVAE(training_data.shape[1:], 2, decoder_kwargs=decoder_specs)

In [98]:
#Forward pass without the training flag; the first returned tensor has shape (2, 45, 3)
#for the (2, 60, 3) input in the recorded output
vae(training_data)

(<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
 array([[[ 0.84751654,  0.7793739 ,  0.35686135],
         [ 0.5970199 ,  0.5875535 ,  0.9711249 ],
         [ 0.4000671 ,  0.32312107,  0.7345326 ],
         [ 0.49215126, -0.06343348, -0.405419  ],
         [-0.7680141 ,  0.6214082 ,  0.42270583],
         [ 0.89276266,  0.5571599 ,  0.8560276 ],
         [ 0.63390917, -0.5649847 ,  1.2110204 ],
         [-0.30097246,  0.9352832 ,  0.41817904],
         [-0.36377406, -0.70149684,  0.13857532],
         [12.8176    , -1.6962559 , -7.5967197 ],
         [-0.82741165, -0.22912908, -0.96879196],
         [-0.11737904, -0.5979032 , -0.8412246 ],
         [-0.58859134, -0.4399581 , -0.9775288 ],
         [ 1.8470466 , -1.2472007 , -0.8944028 ],
         [-0.8955488 , -0.80580425, -0.488379  ],
         [ 0.84751654,  0.7793739 ,  0.35686135],
         [ 0.52058506,  0.61047107,  1.2041911 ],
         [ 1.0725114 ,  0.08720723,  1.3256674 ],
         [ 1.0508078 ,  0.25482735,  0.8148002 

In [99]:
#Print the layers of the model with their parameter counts
vae.summary()

Model: "priorflow_vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (DecimationEncoder)  multiple                  0         
_________________________________________________________________
solvent_decoder (SolventDeco multiple                  28256     
_________________________________________________________________
rqs_realnvp_flow (NormFlowRQ multiple                  267870    
_________________________________________________________________
flatten_29 (Flatten)         multiple                  0         
Total params: 296,126
Trainable params: 296,126
Non-trainable params: 0
_________________________________________________________________


In [86]:
#Forward pass with training=True
#NOTE(review): this cell and the ones after it carry execution counts 86-89, lower than the
#cells above (93, 98, 99), so the recorded outputs may come from an earlier vae / training_data
vae(training_data, training=True)

(<tf.Tensor: shape=(2, 45, 3), dtype=float32, numpy=
 array([[[-1.56200814e+00,  1.93680048e-01, -9.97836590e-02],
         [ 1.62766457e-01, -7.45468378e-01,  1.27308774e+00],
         [-2.34459747e+02,  1.61096512e+02, -1.89957520e+02],
         [-1.08021736e+00,  1.77963972e-01, -1.23140836e+00],
         [-3.64162445e-01,  1.91607475e-01,  1.38972998e+00],
         [-1.10594511e-01, -2.31358290e-01,  5.28046608e-01],
         [-1.07955933e-02,  1.68232203e-01, -7.62535095e-01],
         [-1.52230740e-01,  1.24668598e-01, -1.29968882e-01],
         [-1.53504610e-01, -3.53822231e-01,  1.68604517e+00],
         [ 2.48547101e+00, -2.45865750e+00, -1.89912939e+00],
         [-4.70965862e-01, -5.91603756e-01, -2.01921463e-01],
         [-4.41015482e-01,  2.15812373e+00, -1.67918241e+00],
         [-1.27439007e-01, -1.02739382e+00, -2.22780704e-02],
         [-6.79298162e-01,  9.40869093e-01, -6.52676105e-01],
         [-2.30598211e-01, -3.49322796e-01,  2.24064112e-01],
         [-1.3550

In [87]:
#Show the training data for comparison with the VAE output above
training_data

<tf.Tensor: shape=(2, 60, 3), dtype=float32, numpy=
array([[[-0.46382737, -0.58660626,  0.1433382 ],
        [ 0.88593173, -0.81683135,  0.11178803],
        [ 0.36215687,  0.30775118,  0.9536052 ],
        [-0.06143665,  0.6284561 , -0.22390509],
        [-0.66986585,  0.29380536,  0.8461194 ],
        [-0.24324512,  0.10883665,  0.85266685],
        [ 0.83233404, -0.49434948,  0.6724057 ],
        [-0.15223074,  0.1246686 , -0.12996888],
        [ 0.33330512, -0.59068036, -0.81361985],
        [-0.3335111 , -0.47977972,  0.8369527 ],
        [-0.72361994, -0.93345356,  0.26679516],
        [ 0.7187345 , -0.13469839, -0.897449  ],
        [-0.96484923,  0.4205966 , -0.9628999 ],
        [-0.23706198,  0.85077596,  0.556273  ],
        [ 0.43799186,  0.19368005, -0.09978366],
        [-0.45975018, -0.9034815 ,  0.04248238],
        [-0.4168098 , -0.801286  ,  0.82297254],
        [ 0.5430856 ,  0.8067372 ,  0.14647198],
        [-0.08412647, -0.4365871 ,  0.17151141],
        [ 0.99643

In [88]:
#Loss tensors currently registered on the model; two scalar values in the recorded output
vae.losses

[<tf.Tensor: shape=(), dtype=float32, numpy=49.460617>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1275.0544>]

In [89]:
#Current value of every tracked metric, as (name, value) pairs
#NOTE(review): recon_loss is inf in the recorded output - worth checking against the inf
#log-probabilities that get_log_probs returned earlier in the notebook
[(metric.name, metric.result()) for metric in vae.metrics]

[('logp_z', <tf.Tensor: shape=(), dtype=float32, numpy=49.28447>),
 ('regularizer_loss', <tf.Tensor: shape=(), dtype=float32, numpy=49.28447>),
 ('recon_loss', <tf.Tensor: shape=(), dtype=float32, numpy=inf>)]

In [95]:
#Adam optimizer with every hyperparameter spelled out explicitly
opt = tf.keras.optimizers.Adam(
    learning_rate=0.0001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-08,
)
#Compile with the optimizer only; no loss argument is passed here
vae.compile(optimizer=opt)

In [96]:
#Larger random data set: 1000 configurations of shape (60, 3), uniform in [-1, 1)
#Kept as a plain tensor rather than tf.data.Dataset.from_tensor_slices(...).batch(50):
#the vae.fit call below passes batch_size and validation_split, and Keras does not support
#validation_split when x is a Dataset
big_test_data = tf.random.uniform((1000, 60, 3), minval=-1.0, maxval=1.0)

In [97]:
#Train for 10 epochs in batches of 50, holding out 10% of the samples for validation
vae.fit(
    x=big_test_data,
    epochs=10,
    batch_size=50,
    validation_split=0.1,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fc59842ec90>