In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from importlib import reload

In [42]:
import particle_decoding as pd

In [59]:
reload(pd)

<module 'particle_decoding' from '/Users/jim2/bin/libVAE/particle_decoding.py'>

## Transformations

In [6]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

In [7]:
pd.identity_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]],

       [[ 1. ,  1. ,  1. ],
        [-1. ,  1. ,  0. ],
        [ 0.5,  0.5,  1. ]]])>

In [8]:
pd.spherical_transform(test_coords)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [9]:
pd.spherical_transform(test_coords, dist_sq=tf.reduce_sum(test_coords * test_coords, axis=-1))

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]],

       [[1.73205081, 0.78539816, 0.95531662],
        [1.41421356, 2.35619449, 1.57079633],
        [1.22474487, 0.78539816, 0.61547971]]])>

In [10]:
pd.spherical_transform(pd.spherical_transform(test_coords), reverse=True)

<tf.Tensor: shape=(2, 3, 3), dtype=float64, numpy=
array([[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]],

       [[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
        [-1.00000000e+00,  1.00000000e+00,  8.65956056e-17],
        [ 5.00000000e-01,  5.00000000e-01,  1.00000000e+00]]])>

## Masking by distance

In [11]:
test_coords = np.array([[[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]],
                        [[1.0, 1.0, 1.0],
                        [-1.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0]]])

ref_coords = np.array([[[0.0, 0.0, 0.0]],
                       [[0.0, 0.0, 1.0]]])

In [12]:
pd.distance_mask(ref_coords, test_coords, k_neighbors=2)

(<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
 array([[[ 0.5,  0.5,  1. ],
         [-1. ,  1. ,  0. ]],
 
        [[ 0.5,  0.5,  0. ],
         [ 1. ,  1. ,  0. ]]])>,
 <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
 array([[2, 1],
        [2, 0]], dtype=int32)>,
 <tf.Tensor: shape=(2, 2), dtype=float64, numpy=
 array([[1.5, 2. ],
        [0.5, 2. ]])>)

In [13]:
local_coords, near_inds, dist_sq = pd.distance_mask(ref_coords, test_coords, k_neighbors=2)
pd.spherical_transform(local_coords, dist_sq=dist_sq)

<tf.Tensor: shape=(2, 2, 3), dtype=float64, numpy=
array([[[1.22474487, 0.78539816, 0.61547971],
        [1.41421356, 2.35619449, 1.57079633]],

       [[0.70710678, 0.78539816, 1.57079633],
        [1.41421356, 0.78539816, 1.57079633]]])>

## Creating probability distributions

In [14]:
params = np.array([[[[0.0, 1.0], [10.0, 1.0], [-10.0, 1.0]],
                    [[0.0, 1.0], [0.0, 2.0], [0.0, 10.0]],
                    [[5.0, 1.0], [5.0, 2.0], [5.0, 10.0]]],
                   [[[0.0, 1.0], [-10.0, 1.0], [10.0, 1.0]],
                    [[0.0, 10.0], [0.0, 2.0], [0.0, 1.0]],
                    [[-5.0, 10.0], [-5.0, 2.0], [-5.0, 1.0]]]])

dist_list = [tfp.distributions.Normal, tfp.distributions.Normal, tfp.distributions.Normal]

In [15]:
joint_dist = pd.create_dist(params, dist_list)

In [16]:
sample = tf.stack(joint_dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ -0.81760675  11.33741687 -10.07465369]
  [  0.80637715   2.17460065  -7.56174451]
  [  5.65615621   2.50775362  -3.24293371]]

 [[ -1.20212597  -9.72186889  10.74520898]
  [  1.77992614   2.44943838   0.05657853]
  [  1.37372276  -1.81720925  -5.56520884]]], shape=(2, 3, 3), dtype=float64)


In [17]:
joint_dist.log_prob(tf.unstack(sample, axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[-3.98818452, -6.95468083, -7.08395964],
       [-3.79571569, -6.51995767, -7.38166972]])>

In [18]:
np.sum(-0.5*((sample[-1, 1, :] - params[-1, 1, :, 0])**2)/(params[-1, 1, :, 1]**2) - np.log(params[-1, 1, :, 1]) - 0.5*np.log(2.0*np.pi))

-6.51995767293184

## Solvation neural networks

In [19]:
#Input can be solute or solvent coordinates, with extra_inputs being the other
#Start by testing without augmented input (only relevant for first network to predict layer of solvent for first solute)
solute_net = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=False)

In [20]:
solute_coords = tf.random.normal((1, 2, 3))

In [21]:
params, shifts = solute_net(solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.09980661 -0.1041244 ]
   [ 0.04961302 -0.05271351]
   [-0.09814769  0.20591472]]

  [[-0.10259533  0.02306227]
   [-0.05728697 -0.00780068]
   [ 0.09675334  0.07201551]]

  [[-0.00781018  0.05733149]
   [-0.19918846 -0.01654456]
   [ 0.00218722  0.01779334]]

  [[ 0.02409838  0.04546865]
   [ 0.0926391  -0.17045996]
   [-0.18060714 -0.07500496]]

  [[-0.12327825 -0.09151483]
   [-0.04204683 -0.16377294]
   [-0.15425947 -0.0141225 ]]

  [[ 0.04088102  0.07029386]
   [-0.01120629 -0.0299336 ]
   [ 0.10975686  0.07654113]]

  [[ 0.08257192  0.04917239]
   [ 0.06631015 -0.13409425]
   [-0.04061389  0.13579434]]

  [[ 0.05932912 -0.06376822]
   [-0.1518034   0.13452949]
   [ 0.03586555 -0.03287418]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.0383045   0.12508091]
   [ 0.46848476  0.5283229 ]
   [ 0.37754977 -0.42560795]]

  [[ 0.0199008  -0.4003601 ]
   [ 0.14179192 -0.3156644 ]
   [-0.23247185  0.24679103]]

  [[ 0.38978565 -0.06835768]
   [ 0.4082023   0.311

In [22]:
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3,
                      param_transforms=[tf.identity, lambda x: tf.math.exp(0.5*x)])
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ 0.6135796  -0.33550513  0.37937278]
  [-0.37440145 -0.06462502  0.5282067 ]
  [ 0.8384571  -0.8439002   0.09957206]
  [ 0.9197638   0.15746492  0.6442377 ]
  [ 0.47552627 -0.57384294 -1.3950636 ]
  [-0.11309004 -1.0998665  -1.120734  ]
  [ 0.553651   -1.263895    0.88763237]
  [ 0.4490246  -1.9193971  -1.2477242 ]]], shape=(1, 8, 3), dtype=float32)


In [23]:
#First set of 3 (first particle) should not change shifts, even with different input due to autoregressive
#All others should change, though
new_shifts = solute_net(solute_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[-0.0383045   0.12508091]
   [ 0.46848476  0.5283229 ]
   [ 0.37754977 -0.42560795]]

  [[ 0.01660449 -0.39933178]
   [ 0.14763793 -0.32144842]
   [-0.23303536  0.25136343]]

  [[ 0.39120752 -0.06646115]
   [ 0.4022907   0.31650415]
   [-0.27128214 -0.13910626]]

  [[ 0.27705407 -0.11834546]
   [-0.14851934 -0.3232368 ]
   [ 0.02540156  0.06414494]]

  [[ 0.10513692 -0.06454682]
   [-0.56495523  0.30560976]
   [ 0.11102645 -0.07631263]]

  [[ 0.7589178   0.17683901]
   [-0.2802874   0.20163116]
   [-0.06162059  0.02359262]]

  [[ 0.11184192  0.2580911 ]
   [-0.09140304  0.46329448]
   [-0.20937099  0.19694963]]

  [[-0.5300697  -0.00627059]
   [-0.8134009  -0.24270348]
   [-0.31866944 -0.30322558]]]], shape=(1, 8, 3, 2), dtype=float32)


In [24]:
#And now with solvent inputs (remember, could flip if wanted)
solute_net_solv = pd.SolvationNet((8, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [25]:
solvent_coords = tf.random.normal((1, 6, 3))

In [26]:
params, shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.04393127  0.07799909]
   [-0.45480007  0.0830999 ]
   [ 0.09980106 -0.10134049]]

  [[-0.0212495  -0.07647125]
   [-0.27031705  0.22254787]
   [ 0.14784735 -0.10961125]]

  [[ 0.19027738  0.21500748]
   [-0.07181767 -0.21314739]
   [-0.12685727 -0.06784045]]

  [[-0.11054699  0.06168682]
   [ 0.10411571  0.20586964]
   [ 0.1029194   0.02431325]]

  [[-0.08037464  0.17289037]
   [-0.07414215 -0.21687378]
   [ 0.1044483   0.3811229 ]]

  [[ 0.21736358  0.01333422]
   [-0.18241525  0.23080651]
   [ 0.29402754 -0.15339898]]

  [[-0.27673924 -0.06864452]
   [-0.02252953 -0.06797273]
   [ 0.15385608  0.16442102]]

  [[-0.29977065 -0.14787084]
   [-0.10428659 -0.14010908]
   [-0.0601379  -0.04649765]]]], shape=(1, 8, 3, 2), dtype=float32) tf.Tensor(
[[[[ 0.1755567   0.619532  ]
   [-0.24137665  1.9391973 ]
   [ 0.05261597 -0.39134994]]

  [[ 0.9921926   0.48123786]
   [-0.15530773  0.08682685]
   [ 0.69199    -0.03299864]]

  [[-0.03132709 -0.7789026 ]
   [-0.9295109  -0.249

In [27]:
dist = pd.create_dist(params + shifts, [tfp.distributions.Normal]*3,
                      param_transforms=[tf.identity, lambda x: tf.math.exp(0.5*x)])
sample = tf.stack(dist.sample(), axis=-1)
print(sample)

tf.Tensor(
[[[ 1.25191     0.8854323  -0.5983267 ]
  [ 1.5736074   0.9306261   3.16043   ]
  [-0.05924763 -0.51728696 -1.4418948 ]
  [ 2.1890135   0.8281939   1.824611  ]
  [-0.2927664  -0.28309464 -0.7010681 ]
  [-0.7037216   2.2428987   0.15105975]
  [ 0.33747444  1.867202   -1.0906695 ]
  [-0.16583134 -1.3260478  -0.00592005]]], shape=(1, 8, 3), dtype=float32)


In [28]:
new_shifts = solute_net_solv(solute_coords, extra_coords=solvent_coords, sampled_input=sample)
print(new_shifts)

tf.Tensor(
[[[[ 0.1755567   0.619532  ]
   [-0.24137665  1.9391973 ]
   [ 0.05261597 -0.39134994]]

  [[ 0.9928813   0.4847834 ]
   [-0.15807655  0.08715776]
   [ 0.6944961  -0.02858201]]

  [[-0.04249711 -0.78527635]
   [-0.93515104 -0.24446882]
   [-0.7677029  -0.15466097]]

  [[ 0.4818019   0.18391809]
   [ 0.42915756 -1.2173065 ]
   [-0.53433484 -0.24436516]]

  [[-0.06524521 -0.78180045]
   [-0.3802561  -0.31051153]
   [-0.568477    0.75208867]]

  [[ 0.8100127   0.40884542]
   [ 0.4959828  -0.7091505 ]
   [ 0.4669512   0.29569528]]

  [[ 0.2391311   1.5754158 ]
   [ 0.40541324  1.8538301 ]
   [ 0.38033673  0.9968158 ]]

  [[ 0.06766802 -0.35980046]
   [ 0.7097578   0.5926517 ]
   [-0.86717904  0.2792175 ]]]], shape=(1, 8, 3, 2), dtype=float32)


In [29]:
#Now that it's been called (so built) with extra_coords provided, must provide always with this layer
new_shifts = solute_net_solv(solute_coords, sampled_input=sample)
print(new_shifts)

ValueError: Input 0 of layer dense_18 is incompatible with the layer: expected axis -1 of input shape to have value 24 but received input with shape (1, 6)

In [30]:
#Now pretend we're working with a central solvent particle and adding one more
#(conditioned on other solutes and solvent already around)
solvent_net = pd.SolvationNet((1, 3), out_event_dims=2, hidden_dim=50, n_hidden=3, augment_input=True)

In [31]:
params, shifts = solvent_net(solvent_coords, extra_coords=solute_coords)
print(params, shifts)

tf.Tensor(
[[[[-0.19094816 -0.4481198 ]
   [ 0.42995268  0.31935585]
   [-0.20115793 -0.49286693]]]], shape=(1, 1, 3, 2), dtype=float32) tf.Tensor(
[[[[-0.5909523   0.11009777]
   [ 0.2289467   2.0338945 ]
   [ 0.04213387 -0.708703  ]]]], shape=(1, 1, 3, 2), dtype=float32)


## Full particle decoder

In [66]:
decoder = pd.ParticleDecoder(3, 2, k_solute_neighbors=3, k_solvent_neighbors=3)

In [67]:
solute_coords = tf.random.normal((2, 5, 3))

In [68]:
sample_out, params_out, ref_out = decoder(solute_coords)
print(sample_out.shape, params_out.shape, ref_out.shape)

(2, 60, 3) (2, 60, 3, 2) (2, 60, 3)


In [69]:
#For convenience, added log_prob function to decoder
decoder.get_log_probs(sample_out, params_out, ref_out)

<tf.Tensor: shape=(2, 60), dtype=float32, numpy=
array([[ -3.2949982 ,  -3.1210432 ,  -2.2930758 ,  -3.7953968 ,
         -4.053272  ,  -3.326024  ,  -1.8224097 ,  -2.7933607 ,
         -3.4188285 ,  -3.6859188 ,  -3.7196665 ,  -4.9562745 ,
         -3.0743856 ,  -3.6298223 ,  -0.36489794, -11.974653  ,
         -1.3554857 ,  -4.7123537 ,  -3.5820088 ,  -3.716106  ,
         -2.4059625 ,  -4.0679    ,  -1.7192688 ,  -8.929336  ,
         -3.2049131 ,  -7.6591873 ,  -2.5990767 ,  -5.8504176 ,
         -4.976803  ,  -4.4901114 ,  -3.2681174 ,  -4.8130407 ,
         -3.9620962 ,  -4.741335  ,  -4.149606  ,  -3.2440553 ,
         -6.5622597 ,  -4.3301125 ,  -4.8495374 ,  -6.302489  ,
         -4.405369  ,  -3.095609  ,  -4.094417  ,  -3.2268963 ,
         -1.6817219 ,  -5.322399  ,  -3.1788025 ,  -3.6609015 ,
         -5.0258226 ,  -4.582065  ,  -2.811842  ,  -2.5754826 ,
         -2.6126313 ,   0.49051094,  -3.7317128 ,  -4.862255  ,
         -4.112133  ,  -2.9638703 ,  -2.5627127 ,  -4.0