##### General Imports

In [1]:
#Custom functions
from training import TrainingState, loss, update, feature_map, optimiser
from data import TimeLaggedDataset, collate_fn

#JAX
import jax.numpy as jnp
import jax
import haiku as hk

#MISC
from torch.utils.data import DataLoader

#Standard library
import json

##### Experiment on __alanine dipeptide__

In [2]:
with open('config.json') as f:
    CONFIG = json.load(f)

# Directory holding the alanine dipeptide trajectories.
# To download these files, run "python get_dataset.py".
DATA_DIR = "../../examples/alanine_dipeptide/data/"
files = [
    "alanine-dipeptide-3x250ns-backbone-dihedrals.npz",
    "alanine-dipeptide-3x250ns-heavy-atom-distances.npz",
]
# Each .npz file consists of three independent simulations stored under the keys
# 'arr_0', 'arr_1' and 'arr_2'; any one of them can be used to train the model.
SIMULATION_KEY = 'arr_2'

# Load the selected simulation into memory. For .npz archives `jnp.load` returns
# numpy's NpzFile, which keeps the underlying file open; the `with` blocks close it
# once the array has been read out.
with jnp.load(DATA_DIR + files[0]) as npz:
    dihedrals = npz[SIMULATION_KEY]  # Dihedral angles \phi and \psi
with jnp.load(DATA_DIR + files[1]) as npz:
    distances = npz[SIMULATION_KEY]  # Distances between heavy atoms

In [3]:
# Initialise the network parameters and the optimiser state.
# A fixed PRNG seed keeps the initial weights reproducible across runs.
INIT_SEED = 42
init_key = jax.random.PRNGKey(INIT_SEED)
# A single frame serves as the example input from which Haiku infers parameter shapes.
sample_input = distances[0]
initial_params = feature_map.init(init_key, sample_input)
initial_opt_state = optimiser.init(initial_params)
# NOTE(review): the same params are passed twice, presumably as (params, avg_params);
# confirm against the TrainingState definition in training.py.
state = TrainingState(initial_params, initial_params, initial_opt_state)

In [4]:
# Summarise the network architecture (per-module parameter shapes, counts and sizes).
# `tabulate` returns a plain string, so print it to render the table's line breaks.
architecture_table = hk.experimental.tabulate(feature_map)(distances[0])
print(architecture_table)

+-----------------------------+--------------------------------------------------------------+-----------------+----------+----------+---------------+---------------+
| Module                      | Config                                                       | Module params   | Input    | Output   |   Param count |   Param bytes |
| VAMPnet (MLP)               | MLP(output_sizes=[64, 128, 128, 64, 64, 32], name='VAMPnet') |                 | f32[45]  | f32[32]  |        42,272 |     169.09 KB |
+-----------------------------+--------------------------------------------------------------+-----------------+----------+----------+---------------+---------------+
| VAMPnet/~/linear_0 (Linear) | Linear(output_size=64, name='linear_0')                      | w: f32[45,64]   | f32[45]  | f32[64]  |         2,944 |      11.78 KB |
|  └ VAMPnet (MLP)            |                                                              | b: f32[64]      |          |          |               |               

In [5]:
# Wrap the heavy-atom distances in a time-lagged dataset and batch it for training.
MAX_LAG = 50  # passed as `max_lag`; presumably measured in frames, so confirm in data.py
BATCH_SIZE = CONFIG['training']['batch_size']
dataset = TimeLaggedDataset(distances, max_lag=MAX_LAG)
# shuffle=False is DataLoader's default, written out explicitly here.
# NOTE(review): consecutive MD frames are strongly correlated; if TimeLaggedDataset
# is a map-style dataset, consider shuffle=True for less biased mini-batches.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=False,
)

In [6]:
# Training & evaluation loop.
# NOTE(review): the saved output of this cell is
#   NotImplementedError: Only the b=None case of eigh is implemented
# JAX's eigh does not solve the generalized eigenproblem (b != None). Presumably
# `loss`/`update` in training.py call it with a second matrix. TODO: reduce it to a
# standard eigenproblem (e.g. whiten with the inverse square root of that matrix).

# Hyperparameters of the loss, read once rather than on every epoch.
tikhonov_reg = CONFIG['opt']['tikhonov_reg']
vamp_order = CONFIG['opt']['VAMP_order']
rank = CONFIG['opt']['rank']

for epoch in range(CONFIG['training']['num_epochs']):
    last_batch = None
    for data_batch in dataloader:
        state = update(state, data_batch)
        last_batch = data_batch
    if last_batch is None:
        # Without this guard an empty dataloader leaves `data_batch` unbound
        # (NameError), or, in a notebook, silently reuses a batch from a previous run.
        raise RuntimeError("dataloader yielded no batches; check the dataset size and batch_size")
    # Loss of the updated parameters on the final batch of the epoch: a quick
    # progress signal, not an evaluation over the whole dataset.
    epoch_loss = loss(state.params, last_batch, tikhonov_reg, vamp_order, rank)
    print(f"epoch {epoch}: loss = {epoch_loss}")

NotImplementedError: Only the b=None case of eigh is implemented