# The metastable states of a molecule
In this example we reproduce the results of {footcite:t}`Mardt2018` by training a {class}`kooplearn.models.feature_maps.VAMPNet` to learn the kinetics of a small molecule, alanine dipeptide, from simulation data.

## Description of the dataset


## Data Loading

In [1]:
from pathlib import Path
import os
import numpy as np
# Dataset folder, located two levels above the current working directory.
data_path = Path.cwd().parent.parent / "examples" / "ala2" / "__data__"

In [2]:
# Display the resolved dataset directory so that a wrong working directory is easy to spot.
data_path

PosixPath('/Users/pietronovelli/code_repos/kooplearn/examples/ala2/__data__')

In [3]:
def load_ala2_data(model_path: os.PathLike, descriptor: str):
    """Load one descriptor of the alanine dipeptide dataset as a single array.

    The descriptor is stored in an ``.npz`` archive. Its entries ``arr_0``,
    ``arr_1`` and ``arr_2`` are read and concatenated, in that order, along
    the first axis.

    Args:
        model_path: Directory that contains the ``.npz`` archives.
        descriptor: One of ``'dihedrals'``, ``'distances'`` or ``'positions'``.

    Returns:
        numpy.ndarray: The concatenation of the three arrays in the archive.

    Raises:
        ValueError: If ``descriptor`` is not one of the supported names.
    """
    files = {
        "dihedrals": "alanine-dipeptide-3x250ns-backbone-dihedrals.npz",
        "distances": "alanine-dipeptide-3x250ns-heavy-atom-distances.npz",
        "positions": "alanine-dipeptide-3x250ns-heavy-atom-positions.npz",
    }
    # Validate against the mapping itself so the allowed names cannot drift from it.
    if descriptor not in files:
        raise ValueError(f"descriptor must be one of 'dihedrals', 'distances', 'positions'. Got {descriptor}")
    archive_path = os.path.join(model_path, files[descriptor])
    # Open the archive once and close it deterministically; the arrays are
    # materialised in memory before the file handle is released.
    with np.load(archive_path) as archive:
        # Three entries per archive -- presumably one per trajectory ("3x250ns"); TODO confirm.
        return np.concatenate([archive[f"arr_{arr_idx}"] for arr_idx in range(3)])

In [4]:
# Load the two descriptors used in this notebook, in the same order as before.
distances, dihedrals = (
    load_ala2_data(data_path, descriptor) for descriptor in ('distances', 'dihedrals')
)

# Size of the last axis of the distance array.
distances_dim = distances.shape[-1]

In [5]:
# Make the data into a context window Dataset
import torch
from kooplearn.nn.data import traj_to_contexts_dataset
from torch.utils.data import DataLoader, random_split

SPLIT_SEED = 0  # keeps the train/validation/test partition identical across kernel restarts

dist_dataset = traj_to_contexts_dataset(distances)

# Pass an explicitly seeded generator: without it the partition depends on the
# global RNG state at the time this cell runs and is not reproducible.
train_dist, val_dist, test_dist = random_split(
    dist_dataset,
    [0.8, 0.1, 0.1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dist, batch_size=10000, shuffle=True)
# Sample order is irrelevant for evaluation, so the validation loader is not shuffled.
val_loader = DataLoader(val_dist, batch_size=5000, shuffle=False)

The provided trajectory is of type <class 'numpy.ndarray'>. Converting to torch.Tensor.


In [6]:
import torch

class _old_MLP(torch.nn.Module):
    """Four-layer perceptron mapping ``feature_dim`` inputs to 8 outputs.

    Going by its name this is presumably a superseded encoder variant.
    The first linear layer is followed by layer normalisation, each hidden
    layer by ``activation``, and the output layer has no activation.
    """

    def __init__(self, feature_dim: int, activation=torch.nn.ELU):
        """Build the encoder.

        Args:
            feature_dim: Number of input features.
            activation: Module class instantiated (without arguments) after
                each hidden layer.
        """
        super().__init__()
        self.activation = activation
        nn = torch.nn
        # Layers are created in the same order as they are applied.
        stack = [
            nn.Linear(feature_dim, 64),
            nn.LayerNorm(64),
            activation(),
            nn.Linear(64, 128),
            activation(),
            nn.Linear(128, 64),
            activation(),
            nn.Linear(64, 8),
        ]
        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the encoder to ``x`` and return the result."""
        encoded = self.encoder(x)
        return encoded


class MLP(torch.nn.Module):
    """Fully connected encoder with a softmax output.

    The input is batch-normalised, passed through ``n_hidden`` linear layers
    of width ``hidden_dim`` (each followed by ``activation``), and finally
    projected to ``output_dim`` values that are normalised with a softmax
    over the last axis.
    """

    def __init__(
        self,
        feature_dim: int,
        activation=torch.nn.ELU,
        hidden_dim: int = 20,
        n_hidden: int = 5,
        output_dim: int = 6,
    ):
        """Build the encoder.

        Args:
            feature_dim: Number of input features.
            activation: Module class instantiated (without arguments) after
                each hidden linear layer.
            hidden_dim: Width of every hidden layer.
            n_hidden: Number of hidden linear layers.
            output_dim: Number of softmax outputs.
        """
        super().__init__()
        self.activation = activation
        layers = [torch.nn.BatchNorm1d(feature_dim)]
        in_dim = feature_dim
        for _ in range(n_hidden):
            # Honour the `activation` argument instead of hard-coding ELU.
            layers += [torch.nn.Linear(in_dim, hidden_dim), self.activation()]
            in_dim = hidden_dim
        layers += [torch.nn.Linear(in_dim, output_dim), torch.nn.Softmax(dim=-1)]
        # Same module order as the hand-written stack, so state_dict keys are unchanged.
        self.encoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the encoder to ``x`` and return the softmax outputs."""
        return self.encoder(x)

In [7]:
from kooplearn.models.feature_maps import VAMPNet, DPNet
from torch.optim import Adam
import lightning 


# Lightning Trainer options: CPU, one device, 30 epochs, no checkpointing and no logger.
trainer_kwargs = dict(
    accelerator="cpu",
    devices=1,
    max_epochs=30,
    enable_progress_bar=True,
    enable_model_summary=True,
    enable_checkpointing=False,
    logger=False,
)

trainer = lightning.Trainer(**trainer_kwargs)

# NOTE(review): the introduction announces a VAMPNet, but a DPNet is built here -- confirm which is intended.
feature_map = DPNet(
    MLP,
    Adam,
    trainer,
    encoder_kwargs={'feature_dim': distances_dim},
    optimizer_kwargs={"lr": 1e-3},
    center_covariances=False,
    seed=0,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Global seed set to 0


In [8]:
# Train the feature map. A validation loader is supplied as well; the recorded
# output shows Lightning skipping the validation loop (no `validation_step`).
feature_map.fit(
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name    | Type | Params
---------------------------------
0 | encoder | MLP  | 2.8 K 
---------------------------------
2.8 K     Trainable params
0         Non-trainable params
2.8 K     Total params
0.011     Total estimated model params size (MB)


Fitting DPNet. Lookback window length set to 1


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
