In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import sympy.abc
from sympy import fps, ln
import matplotlib.pyplot as plt

import sys
sys.path.append("../../")
from kooplearn.estimators import ReducedRank
from kooplearn.kernels import Linear

# Synthetic data used to exercise the helpers defined below.
NUM_SAMPLES = 1_000
NUM_FEATURES = 50

# Fixed PRNG seed so the dummy trajectory is identical on every run.
key = jax.random.PRNGKey(0)

# Uniform samples in [0, 1), one row per sample and one column per feature.
dummy_traj = jax.random.uniform(key, shape=(NUM_SAMPLES, NUM_FEATURES))

In [7]:
def ln_series_coefficients(truncation_order: int, ascending: bool = True) -> jnp.ndarray:
    """Coefficients of the truncated power series of ln(T) expanded around T = 1.

    The series is built symbolically with sympy, truncated, converted to a sympy
    ``Poly`` and its coefficients are returned as a floating-point array.

    Args:
        truncation_order: Order at which the series is truncated. The minimum
            supported value is 1, i.e. a constant + linear expansion.
        ascending: If True (default), coefficients are ordered from the constant
            term upwards; if False, from the highest power downwards.

    Returns:
        1-D array holding the polynomial coefficients.

    Raises:
        ValueError: If ``truncation_order`` is smaller than 1.
    """
    # Reject unsupported orders up front instead of failing later inside the
    # symbolic pipeline with an unrelated-looking error.
    if truncation_order < 1:
        raise ValueError(
            f"truncation_order must be at least 1, got {truncation_order}."
        )
    series = fps(ln(sympy.abc.T), x0=1)
    poly = series.polynomial(n=truncation_order + 1).as_poly()
    # Convert the symbolic coefficients to Python floats explicitly, so the array
    # constructor never has to coerce sympy number objects itself.
    coeffs = jnp.array([float(c) for c in poly.all_coeffs()], dtype=float)
    # all_coeffs() lists the highest power first; flip for ascending order.
    if ascending:
        return jnp.flip(coeffs)
    return coeffs

def cross_covariance_series(traj: jnp.ndarray, truncation_order: int) -> jnp.ndarray:
    """Stack of lagged cross-covariance matrices of a trajectory.

    Anchor samples are taken every ``truncation_order`` steps (indices
    ``0, t, 2t, ...``). For each lag ``k`` in ``0..truncation_order`` the product
    ``traj[anchors].T @ traj[anchors + k]`` is computed. The products are neither
    centred nor divided by the number of anchors.

    Args:
        traj: Trajectory indexed by sample along its leading axis
            (assumed to be ``(num_samples, num_features)``, as in the notebook's
            dummy data -- TODO confirm for other callers).
        truncation_order: Largest lag to compute and stride between anchors.
            The minimum supported value is 1, i.e. a constant + linear expansion.

    Returns:
        Array whose leading axis has length ``truncation_order + 1``; entry ``k``
        is the lag-``k`` product. If the trajectory has fewer than
        ``truncation_order + 1`` samples no anchor fits and every entry is zero.

    Raises:
        ValueError: If ``truncation_order`` is smaller than 1.
    """
    # A zero order would divide by zero below and a negative one yields empty,
    # meaningless output; fail early with a clear message instead.
    if truncation_order < 1:
        raise ValueError(
            f"truncation_order must be at least 1, got {truncation_order}."
        )

    max_idx = traj.shape[0] - 1
    # Number of anchors such that anchor + truncation_order never exceeds max_idx.
    num_anchors = max_idx // truncation_order
    # TODO (original author's note): replace this index construction with a for loop.
    base_ids = jnp.arange(num_anchors, dtype=int) * truncation_order
    # The anchor block does not depend on the lag, so build it once.
    anchors_t = traj[base_ids].T

    def cross_covariance(order: jnp.ndarray) -> jnp.ndarray:
        # Only the lag is batched; everything else is closed over as a constant.
        return anchors_t @ traj[base_ids + order]

    orders = jnp.arange(truncation_order + 1, dtype=int)
    return jax.vmap(cross_covariance)(orders)

In [8]:
# Smoke test on the synthetic trajectory: expand ln(T) up to `max_order` and
# build the matching stack of lagged cross-covariance matrices.
max_order = 4
coeffs = ln_series_coefficients(truncation_order=max_order)
T = cross_covariance_series(traj=dummy_traj, truncation_order=max_order)

### Experiment on alanine dipeptide

In [73]:
#To download these files, run "python get_dataset.py"
files = [
    "alanine-dipeptide-3x250ns-backbone-dihedrals.npz",
    "alanine-dipeptide-3x250ns-heavy-atom-distances.npz",
]
# Directory holding the downloaded .npz archives, relative to this notebook.
DATA_DIR = "../../examples/alanine_dipeptide/data/"

#Load the data into memory. Each .npz file holds three independent simulations ['arr_0', 'arr_1', 'arr_2']. Any of them can be used to train the model.
# The archives are opened as context managers so their file handles are closed
# as soon as the selected array has been read.
with np.load(DATA_DIR + files[0]) as archive:
    dihedrals = archive['arr_2'] #Dihedral angles \phi and \psi
with np.load(DATA_DIR + files[1]) as archive:
    distances = archive['arr_2'] #Distance between heavy atoms