In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.8

env: XLA_PYTHON_CLIENT_MEM_FRACTION=.8


In [3]:
from jax import numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

import jax
from jax import Array
jax.devices()

[CpuDevice(id=0)]

In [4]:
EPSILON   = 6.189
SIGMA     = 2.3925
a         = 1.80
lam       = 23.15
gamma     = 1.20
A         = 7.049556277
B         = 0.6022245584
p         = 4.
q         = 0.
#theta_0   = np.radians(109.47)
NUM_PARTICLES = 180
SPATIAL_DIMENSIONS = 3
KB = 0.00831446261815324 # in (unit.kilojoule_per_mole/unit.kelvin)
PRESS_PRIOR = 1.
TEMP_PRIOR  = 273.

CUTOFF= a * SIGMA

In [5]:
import sys
sys.path.insert(0, "/home/ninarell/OneDrive/WF_GAN_FOR_GLASSES/B_GEN/flow_diagrams")
import flow_diagrams
from flow_diagrams.energy.stillinger_weber import fd_stillinger_weber_neighbor_list
from flow_diagrams.energy.lennard_jones import fd_lennard_jones_neighbor_list
from flow_diagrams.utils.data import NumpyLoader, split_data


In [6]:
from jax_md import space, partition

In [7]:
def wrap_to_unit_cube(pos, lower, upper):
    width = upper - lower
    return jnp.mod(pos - lower, width) + lower

def wrap_to_box(pos, box):
    return pos % box

def remove_disp_of_first_atom(displacements):
    # assert displacements.shape == (NUM_PARTICLES, SPATIAL_DIMENSIONS)

    disp_at_1 = displacements[0,:]

    return displacements - disp_at_1


def transform_abs_coords_to_rel_coords(absolute_coordinates: Array, side_length: Array):
    """Transforms relative coordinates inside the unit cube to absolute coordinates given a 3d box_vector."""
    assert absolute_coordinates.shape[-1] == SPATIAL_DIMENSIONS
    assert side_length.shape == (3,)
    return absolute_coordinates / side_length



In [8]:
LOWER = 0.
UPPER = 1.
CUT_TYPE = 'switch'


PRIOR_PRESSURE = 1.
TEMP_PRIOR  = 273.
REDUCED_TEMP_PRIOR = TEMP_PRIOR * KB / EPSILON
REDUCED_PRESS_PRIOR = PRIOR_PRESSURE * SIGMA**3 / EPSILON

filename_prior = f"prod.liquid_273K_1atm_mW.npz"
#filename_prior = f"N180-T1.2000-P8.556-RCUT2.20_1000.npz"
#data_prior = jnp.load(filename_prior)
data_prior = jnp.load(filename_prior)
positions_prior_abs = data_prior['pos']
box_prior = data_prior['box']
vols_prior = jnp.prod(box_prior,axis=-1)
BOX_EDGES = np.mean(box_prior,axis=0)

# fix first atom in origin and wrap to box
positions_prior = jax.vmap(wrap_to_box)(jax.vmap(remove_disp_of_first_atom)(positions_prior_abs),box_prior)
MEAN_CONFIG = np.mean(positions_prior,axis=0)

# scale to [0,1]
positions_prior= jax.vmap(transform_abs_coords_to_rel_coords)(positions_prior,box_prior)
positions_prior = wrap_to_unit_cube(positions_prior,LOWER,UPPER)    

scale_prior = box_prior[:,0] / BOX_EDGES[0]
energies_prior = data_prior['ene']

assert np.logical_and(1. >= UPPER, positions_prior >= LOWER).all()
assert np.allclose(positions_prior[:,0,:],0,atol=1e-7)

n_configurations_prior = positions_prior.shape[0] 

print('# Prior samples', n_configurations_prior)

# Prior samples 1000


In [9]:
type(data_prior)

numpy.lib.npyio.NpzFile

In [10]:
# Using float32 for positions / velocities, but float64 for reductions.
dtype = np.float32

# Specify the format of the neighbor list. 
# Options are Dense, Sparse, or OrderedSparse. 
format = partition.Dense

In [11]:
from jax_md.energy import stillinger_weber_neighbor_list
displacement_frac, shift_frac = space.periodic_general(BOX_EDGES, fractional_coordinates=False)
neighbor_fn, energy_fn = stillinger_weber_neighbor_list(
    displacement=displacement_frac,
    box_size=BOX_EDGES,
    sigma=SIGMA,
    A = A,
    B = B,
    lam = lam,
    gamma = gamma,
    epsilon= EPSILON,
    cutoff = CUTOFF,
    dr_threshold= 0.5,
    fractional_coordinates=False,
    format = format
    ) 

"""
displacement_frac,box_size=BOX_EDGES,
    cutoff = REDUCED_CUTOFF,
    dr_threshold=1.,
    epsilon=EPSILON,
    sigma=SIGMA,
    lam=lam,
    format=format,
    fractional_coordinates=False)
"""

NEIGHBOR_LIST = neighbor_fn.allocate(MEAN_CONFIG)


In [12]:
def compute_sw_energy(pos_rel: jnp.ndarray, scale):
    box= scale * BOX_EDGES
    #print(box)
    #print(pos_rel*box)
    nbrs = NEIGHBOR_LIST.update(pos_rel * box)
    #nbrs = NEIGHBOR_LIST.update(data_prior['pos'][0])
    sw_energy = energy_fn(pos_rel * box, nbrs, box=box)
    #energy_fn(data_prior['pos'][0], nbrs, box=box)

    return sw_energy

In [13]:
train_fraction = 0.1
BATCH_SIZE = 128
num_samples = 10

# Store all displacements relative to first one (which stays at its equilibrium position)
dataset_prior_train, dataset_prior_test = split_data(train_fraction, positions_prior,
                        energies_prior,
                       scale_prior)
dataloader_train = NumpyLoader(dataset_prior_train,BATCH_SIZE,False)

ene_prior = dataset_prior_test.energies[:num_samples]
pos_latent = dataset_prior_test.pos[:num_samples]
scale_latent = dataset_prior_test.scale[:num_samples]
# config_latent = jax.vmap(transform_abs_pos_to_abs_config)(pos_latent)


In [14]:
energies_recomputed_prior = jax.vmap(compute_sw_energy)(pos_latent,scale_latent)

  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,


In [15]:
 energies_recomputed_prior

Array([-1836.9418, -1791.5863, -1820.9943, -1838.1879, -1802.7601,
       -1842.5334, -1844.6782, -1854.7607, -1857.1711, -1831.8352],      dtype=float32)

In [16]:
ene_prior

array([-1836.94156096, -1827.3783167 , -1837.9955649 , -1838.18766338,
       -1843.17967025, -1842.53303051, -1844.67845624, -1854.76075957,
       -1857.17064481, -1831.83512308])

In [17]:
energies_recomputed_prior - ene_prior

Array([-2.4414062e-04,  3.5791992e+01,  1.7001343e+01, -2.4414062e-04,
        4.0419556e+01, -3.6621094e-04,  2.4414062e-04,  0.0000000e+00,
       -4.8828125e-04, -1.2207031e-04], dtype=float32)

In [19]:
from flow_diagrams.utils.conditioning import convert_from_reduced_p, convert_from_reduced_t

convert_from_reduced_p(EPSILON, SIGMA)

0.4519232066944438

In [20]:
convert_from_reduced_t(EPSILON, SIGMA)

2.586833855799373

In [21]:
REDUCED_TEMP_PRIOR

0.36675525848373475

In [22]:
REDUCED_PRESS_PRIOR

2.212765322043141

In [23]:
2.212765322043141 * EPSILON / SIGMA**3

0.9999999999999999