In [1]:
#General Purpose
import os
import sys
import json

import numpy as np
from tqdm import tqdm
#JAX
import jax
import jax.numpy as jnp
import jax_md
#DynamicalSystems
sys.path.append("../../")
from kooplearn.estimators import ReducedRank
from kooplearn.kernels import RBF

In [2]:
# Root PRNG state: fixed seed 0 for reproducibility, split once into `key` and `split`.
key, split = jax.random.split(jax.random.PRNGKey(0))

# Experiment configuration (physics, ML and run metadata) shared by the cells below.
with open("config.json") as config_file:
    config = json.load(config_file)

@jax.jit
def pande_potential(x):
    """
    See Example 1 of "Modeling Molecular Kinetics with tICA and the Kernel Trick" 10.1021/ct5007357
    """
    return jnp.sum(4*(x**8+ 0.8*jnp.exp(-80*(x**2)) +  0.2*jnp.exp(-80*((x - 0.5)**2)) + 0.5*jnp.exp(-40*((x + 0.5)**2))))

In [3]:
# jax_md free-space helpers (no periodic boundaries); only `shift_fn` is used by the simulation.
displacement_fn, shift_fn = jax_md.space.free()
kernel = RBF(length_scale=config["ML"]["RBF_length_scale"])

# Temperature sweep and simulation-length settings from config.json.
temperature_cfg = config["physics"]["temperature"]
T_min, T_max, T_num = temperature_cfg["min"], temperature_cfg["max"], temperature_cfg["num"]

meta_cfg = config["meta"]
steps = int(meta_cfg["simulation_steps"])
write_every = int(meta_cfg["write_every"])

def log_position(vars):
    """`jax.lax.cond` true-branch: record the current positions in the trajectory buffer.

    `vars` is the operand tuple `(value, positions, step)`. `value` is written to row
    `step // write_every` of `positions` (functional `.at[].set` update) and the new
    buffer is returned. `write_every` is read from module scope.
    """
    value, buffer, step = vars
    row = step // write_every
    return buffer.at[row].set(value)

def do_nothing(vars):
    """`jax.lax.cond` false-branch: return the position buffer unchanged.

    `vars` is the operand tuple `(value, positions, step)`; only `positions` is returned.
    """
    _value, kept_positions, _step = vars
    return kept_positions

# One row of `rank` eigenvalues per (temperature, simulation); complex dtype because
# the eigenvalues returned by the estimator are stored as complex numbers.
evals_shape = (T_num, config["meta"]["num_simulations"], config["ML"]["rank"])
evals_data = np.zeros(evals_shape, dtype=np.complex128)

In [4]:
# Grid on which eigenfunctions can be evaluated. It does not depend on temperature,
# so it is built once rather than on every loop iteration.
_min = config["ML"]["eigenfunction_sample"]["min"]
_max = config["ML"]["eigenfunction_sample"]["max"]
_num = config["ML"]["eigenfunction_sample"]["num"]
x = np.linspace(_min, _max, _num)[:, None]

num_simulations = config["meta"]["num_simulations"]
# `simulation_step` logs at every step index in [0, steps) divisible by `write_every`,
# i.e. ceil(steps / write_every) frames. Allocating only steps // write_every rows would
# make the last `.at[...].set` out of bounds whenever steps is not a multiple of
# write_every, and JAX silently drops out-of-bounds scatter updates.
num_frames = -(-steps // write_every)

# One independent key per temperature, derived without mutating the global `key` so
# re-running this cell reproduces the same results. Reusing one key for every
# temperature would give all of them identical initial positions and noise streams.
temperature_keys = jax.random.split(key, T_num)

evals_path = os.path.join("data", "evals_data.npy")
os.makedirs(os.path.dirname(evals_path), exist_ok=True)

for T_idx, temperature in tqdm(enumerate(np.linspace(T_min, T_max, T_num)), total=T_num):
    init_pos_key, sim_key = jax.random.split(temperature_keys[T_idx])
    R = jax.random.uniform(init_pos_key, (num_simulations, 1), minval=-1, maxval=1)
    # Trajectory buffer of shape (num_frames,) + R.shape = (frame, simulation, 1).
    positions = jnp.zeros((num_frames,) + R.shape)
    init_fn, apply_fn = jax_md.simulate.brownian(pande_potential, shift_fn, config["physics"]["time_step"], temperature)
    apply_fn = jax.jit(apply_fn)

    def simulation_step(curr_idx, state_and_positions):
        """One `fori_loop` iteration: log positions every `write_every` steps, then advance one step."""
        state, positions = state_and_positions
        should_log = (curr_idx % write_every == 0)
        operands = (state.position, positions, curr_idx)
        positions = jax.lax.cond(should_log, log_position, do_nothing, operands)
        state = apply_fn(state)
        return state, positions

    state = init_fn(sim_key, R)
    state, positions = jax.lax.fori_loop(0, steps, simulation_step, (state, positions))
    # Drop only the trailing size-1 axis: `jnp.squeeze` would also drop the simulation
    # axis when num_simulations == 1 and break `traj.shape[1]` / `traj[:, sim_idx]`.
    traj = np.asarray(positions[..., 0], dtype=np.float64)

    for sim_idx in range(traj.shape[1]):
        sample = traj[:, sim_idx]
        # Time-lagged pairs (x_t, x_{t+1}) from one trajectory, as column vectors.
        X = sample[:-1, None]
        Y = sample[1:, None]
        Koopman = ReducedRank(kernel, rank=config["ML"]["rank"], tikhonov_reg=config["ML"]["tikhonov_reg"])
        Koopman.fit(X, Y, _save_svals=True)
        evals, lefuns, refuns = Koopman.eig()
        evals_data[T_idx, sim_idx] = evals

    # Checkpoint the eigenvalues once per temperature. Only evals are persisted;
    # lefuns/refuns are not written to disk here.
    np.save(evals_path, evals_data)

  0%|          | 0/1 [00:00<?, ?it/s]2022-12-16 12:51:47.775963: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory
  0%|          | 0/1 [00:01<?, ?it/s]


XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: out of memory