# VAMPNets

In [None]:
import numpy as np
from tqdm.notebook import tqdm

import sktime
import sktime.decomposition.vampnet as vnet

import torch
import torch.nn as nn

from sklearn.model_selection import train_test_split

In [None]:
# Fixed seed for everything that draws from the global NumPy / torch RNGs
# (weight initialisation, shuffling). Whether sktime's data generators use the global
# NumPy RNG is not visible here — TODO confirm for full reproducibility.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the first GPU, but fall back to the CPU instead of aborting the notebook.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Let cuDNN pick the fastest kernels (may make GPU results slightly non-deterministic).
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

# Number of threads torch uses for intra-op parallelism on the CPU.
torch.set_num_threads(12)

## Ellipsoids dataset

In [None]:
# Two-state toy system; row i of the transition matrix holds the jump probabilities out of state i.
data_source = sktime.data.ellipsoids()
data_source.msm.update_transition_matrix([
    [.95, .05],
    [.09, .91],
])

# 50k two-dimensional observations, cast to float32 for the network.
data = (
    data_source
    .observations(50000, n_dim=2)
    .astype(np.float32)
)

# Pair every frame with the frame one step later (lag time 1).
dataset = sktime.data.TimeLaggedDataset.from_trajectory(data=data, lagtime=1)

The dataset is two-dimensional: a jump process between two metastable states, where each state is observed as an ellipsoidal (Gaussian) cloud of points.

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.stats import multivariate_normal

# Evaluate the sum of the two state densities on a regular grid covering the data.
x = np.linspace(-8, 8, 500)
y = np.linspace(-6, 10, 500)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))  # shape X.shape + (2,): one (x, y) point per grid node

rv1 = multivariate_normal(data_source.state_0_mean, data_source.covariance_matrix)
rv2 = multivariate_normal(data_source.state_1_mean, data_source.covariance_matrix)
# pdf() evaluates over the last axis, so the result already has X.shape == (len(y), len(x));
# reshaping to (len(x), len(y)) would scramble the grid whenever the two resolutions differ.
density = rv1.pdf(pos) + rv2.pdf(pos)

fig, ax = plt.subplots()
ax.contourf(X, Y, density)
ax.autoscale(False)  # keep the contour plot's limits when the samples are added
ax.scatter(*data_source.observations(100).T, color='cyan', marker='x', label='samples')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_title('Ellipsoids: state densities and samples')
ax.legend()
plt.show()

Split the data randomly into a training set (70%) and a validation set (30%). Both subsets are handed to the estimator directly; no manual conversion to torch tensors happens in this cell.

In [None]:
# Hold out 30% of the time-lagged pairs for validation. The seeded generator makes the
# split identical across kernel restarts.
n_val = int(len(dataset) * .3)
train_data, val_data = torch.utils.data.random_split(
    dataset, [len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(42),
)

Define the network lobe, i.e., the neural network that maps each frame to its learned features. Optionally, two separate lobes can be used: one for the instantaneous and one for the time-lagged data.

In [None]:
# Small MLP: input width = data dimension, PReLU activations, one sigmoid output unit.
# The activation classes themselves act as zero-argument factories for fresh modules.
lobe = vnet.MLPLobe(
    units=[data.shape[1], 20, 15, 10, 6, 1],
    nonlinearity=nn.PReLU,
    output_nonlinearity=nn.Sigmoid,
)

Construct the VAMPNet estimator, which holds the optimizer that trains the lobe together with the training hyperparameters (learning rate, scoring method).

In [None]:
vampnet = vnet.VAMPNet(
    lobe=lobe,
    lagtime=1,
    device=device,
    learning_rate=1e-3,
    score_method='VAMPE',
    score_mode='clamp',
)

In [None]:
train_scores = []
val_scores = []


def train_callback(step, score):
    """Record the training score reported at ``step``.

    Appends ``(step, score_as_numpy)`` to ``train_scores``. The tensor is detached first
    because ``Tensor.numpy()`` raises for tensors that require grad.
    """
    train_scores.append((step, score.detach().cpu().numpy()))


def val_callback(step, score):
    """Print the validation score reported at ``step`` and append it to ``val_scores``."""
    np_score = score.detach().cpu().numpy()
    print(f"Validation step {step}: {np_score:.5f}")
    val_scores.append((step, np_score))

In [None]:
# Train for 20 epochs; the score callbacks defined earlier collect the score history.
vampnet.fit(
    train_data,
    batch_size=32,
    n_epochs=20,
    validation_data=val_data,
    train_score_callback=train_callback,
    validation_score_callback=val_callback,
)
vampnet_model = vampnet.fetch_model()

In [None]:
# Fetch the trained model from the estimator and apply it to the full trajectory `data`.
vampnet_model = vampnet.fetch_model()
featurization = vampnet_model.transform(data)

In [None]:
np.shape(featurization)

In [None]:
fig, ax = plt.subplots()
ax.plot(featurization[:1000])
ax.set(xlabel='frame', ylabel='VAMPNet output', title='Learned feature, first 1000 frames')
plt.show()

In [None]:
f, ax = plt.subplots(1, 1)
# Colour each observation by the network output it was mapped to.
cm = ax.scatter(*data.T, c=featurization, cmap='coolwarm')
f.colorbar(cm, ax=ax, label='VAMPNet output')
ax.set(xlabel='$x_1$', ylabel='$x_2$', title='VAMPNet output on the ellipsoids data');

In [None]:
# Linear VAMP on top of the learned feature (lag time 1).
vamp_estimator = sktime.decomposition.VAMP(lagtime=1)
vamp_estimator.fit(featurization)
koopman_model = vamp_estimator.fetch_model()

In [None]:
fig, ax = plt.subplots()
ax.plot(koopman_model.transform(featurization)[:1000, 0])
ax.set(xlabel='frame', ylabel='VAMP component 0', title='VAMP projection of the VAMPNet feature');

In [None]:
# Linear VAMP directly on the raw coordinates, as a baseline for the VAMPNet feature.
koopman_model_direct = sktime.decomposition.VAMP(lagtime=1).fit(data).fetch_model()

print('VAMP score on VAMPNet feature:', koopman_model.score())
print('VAMP score on raw data:', koopman_model_direct.score())

In [None]:
fig, ax = plt.subplots()
ax.semilogy(*np.array(train_scores).T, label='train')
ax.semilogy(*np.array(val_scores).T, label='validation')
ax.set(xlabel='training step', ylabel='score', title='VAMPNet training progress')
ax.legend();

In [None]:
projection = koopman_model.transform(featurization)

# Discretise the projection into two clusters, then estimate an MSM on the cluster labels.
kmeans = sktime.clustering.KmeansClustering(2)
dtraj = kmeans.fit(projection).transform(projection)

msm_estimator = sktime.markov.msm.MaximumLikelihoodMSM()
msm = msm_estimator.fit(dtraj, lagtime=1).fetch_model()

In [None]:
# k-means labels are arbitrary, so the estimated matrix may list the two states in the
# opposite order to the reference.
for description, matrix in (("estimated transition matrix", msm.transition_matrix),
                            ("reference transition matrix", data_source.msm.transition_matrix)):
    print(description, matrix)

In [None]:
def print_states_pie_chart(discrete_trajectory=None):
    """Draw a pie chart of discrete-state populations and print them in percent.

    Parameters
    ----------
    discrete_trajectory : array-like of non-negative ints, optional
        State index of every frame. Defaults to the notebook-level ``dtraj``,
        which is what the function always used before this parameter existed.
    """
    if discrete_trajectory is None:
        discrete_trajectory = dtraj
    states = np.asarray(discrete_trajectory).ravel()

    # counts[i] = number of frames in state i, for i = 0 .. max(state)
    counts = np.bincount(states)
    populations = counts / states.size * 100

    fig1, ax1 = plt.subplots()
    ax1.pie(counts, autopct='%1.2f%%', startangle=90,
            labels=[f"State {i+1}" for i in range(len(counts))])
    ax1.axis('equal')  # draw the pie as a circle
    ax1.set_title('State populations')
    print('States population: ' + str(populations) + '%')
    plt.show()


print_states_pie_chart()

In [None]:
# Purely linear VAMP on the raw coordinates, truncated to a single dimension, for comparison.
vamp_linear = sktime.decomposition.VAMP(dim=1, lagtime=1)
linear_model = vamp_linear.fit(data).fetch_model()

In [None]:
# NOTE: singular functions are only defined up to sign/scale, so the curves may appear mirrored.
fig, ax = plt.subplots()
ax.plot(projection[:200, 0], label='VAMPNet estimator')
ax.plot(linear_model.transform(data)[:200, 0], label='VAMP estimator', linestyle='dotted')
ax.set(xlabel='frame', ylabel='projection (component 0)',
       title='VAMPNet vs. linear VAMP, first 200 frames')
ax.legend();

In [None]:
# Slowest implied timescale of the reference process and of both estimates.
timescale_sources = (
    ('Ground truth timescale:', data_source.msm),
    ('VAMPNet timescale:', koopman_model),
    ('VAMP timescale:', koopman_model_direct),
)
for caption, model in timescale_sources:
    print(caption, model.timescales()[0])

In [None]:
for caption, model in (('VAMPNet score:', koopman_model),
                       ('VAMP score:', koopman_model_direct)):
    print(caption, model.score())

## Alanine dipeptide

In [None]:
import mdshare

In [None]:
# np.load on an .npz archive returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is closed once the array has been read into memory.
# NOTE(review): the file name suggests three trajectories, but only 'arr_0' is used — confirm intended.
with np.load(mdshare.fetch('alanine-dipeptide-3x250ns-backbone-dihedrals.npz')) as dihedrals_file:
    dihedrals = dihedrals_file['arr_0']
with np.load(mdshare.fetch('alanine-dipeptide-3x250ns-heavy-atom-positions.npz')) as coordinates_file:
    coordinates = coordinates_file['arr_0']

In [None]:
# Tau, how much is the timeshift of the two datasets
tau = 1

# Batch size for Stochastic Gradient descent
batch_size = 768

# Which trajectory points percentage is used as training
train_ratio = 0.9

# How many output states the network has
output_size = 6

# Iteration over the training set in the fitting process
nb_epoch = 60

In [None]:
# The time-lagged dataset presumably holds len(coordinates) - tau pairs (random_split needs
# the two sizes to add up to that); train_ratio of the frames go to training.
n_train = int(len(coordinates) * train_ratio)
n_validation = (len(coordinates) - tau) - n_train

In [None]:
# Use the configured lag time `tau` rather than a hard-coded 1, so the dataset length matches
# n_train + n_validation as computed from tau. A seeded generator keeps the split reproducible.
dataset = sktime.data.TimeLaggedDataset.from_trajectory(lagtime=tau, data=coordinates.astype(np.float32))
traj_data_train, traj_data_validation = torch.utils.data.random_split(
    dataset, [n_train, n_validation],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Four hidden ELU layers of width 128 on the per-frame coordinate features; the width of the
# softmax output is tied to the `output_size` hyperparameter instead of a literal 6.
lobe = nn.Sequential(
    nn.Linear(coordinates.shape[1], 128), nn.ELU(),
    nn.Linear(128, 128), nn.ELU(),
    nn.Linear(128, 128), nn.ELU(),
    nn.Linear(128, 128), nn.ELU(),
    nn.Linear(128, output_size), nn.Softmax(dim=1)  # probability distribution over output_size states
)

In [None]:
vampnet = vnet.VAMPNet(
    lobe=lobe,
    lagtime=tau,
    device=device,
    dtype=np.float32,
    learning_rate=5e-4,
    epsilon=1e-6,        # small numerical regularisation constant — presumably for the score's matrix ops
    score_method='VAMPE',
    score_mode='clamp',
)

In [None]:
train_scores = []
val_scores = []


def train_callback(step, score):
    """Record the training score reported at ``step`` in ``train_scores``.

    The tensor is detached first because ``Tensor.numpy()`` raises for tensors that require grad.
    """
    train_scores.append((step, score.detach().cpu().numpy()))


def val_callback(step, score):
    """Record the validation score reported at ``step`` in ``val_scores`` (no printing)."""
    np_score = score.detach().cpu().numpy()
    val_scores.append((step, np_score))

In [None]:
# Epoch count comes from the `nb_epoch` hyperparameter rather than a duplicated literal.
vampnet.fit(traj_data_train, batch_size=batch_size, n_epochs=nb_epoch,
            validation_data=traj_data_validation,
            train_score_callback=train_callback, validation_score_callback=val_callback);

In [None]:
fig, ax = plt.subplots()
for scores, label in ((train_scores, 'train'), (val_scores, 'val')):
    ax.loglog(*np.array(scores).T, label=label)
ax.set_xlabel('training step')
ax.set_ylabel('score [a.u.]')
ax.legend();

In [None]:
# Retrieve the trained model from the estimator; it is used below to transform the coordinates.
vampnet_model = vampnet.fetch_model()

In [None]:
# Transform the input trajectory using the network
states_prob = vampnet_model.transform(coordinates)
# and transform into discrete states
dtraj = states_prob.argmax(axis=-1)

In [None]:
fig, ax = plt.subplots()
# Discrete colormap with one colour per state; vmin/vmax are offset by 0.5 so each integer
# state index sits in the middle of its colour band.
state_cmap = plt.get_cmap('plasma', output_size)
cb = ax.scatter(*dihedrals.T, c=dtraj, cmap=state_cmap,
                alpha=.5, s=5, vmin=-.5, vmax=output_size - .5)
cbar = fig.colorbar(cb, ax=ax, ticks=np.arange(output_size))
cbar.ax.set_yticklabels([f"State {i+1}" for i in range(output_size)])
# Set the dihedral range explicitly; assigning to `plt.axes` would instead overwrite
# pyplot's axes() function for the rest of the session.
ax.set_xlim(-np.pi, np.pi)
ax.set_ylim(-np.pi, np.pi)
ax.set_xlabel(r'$\varphi$ [rad]')
ax.set_ylabel(r'$\psi$ [rad]')
ax.set_title('State assignments on Ramachandran plot');

In [None]:
# Frames per state 0 .. output_size-1 (bincount over the hard assignments).
state_counts = np.bincount(dtraj, minlength=output_size)[:output_size]
state_labels = [f"State {i+1}" for i in range(output_size)]

fig, ax = plt.subplots()
ax.pie(state_counts, autopct='%1.2f%%', labels=state_labels)
ax.set_title('State populations');

In [None]:
# One hexbin panel per network output, laid out in rows of three. The grid size follows the
# number of output columns instead of being fixed at 2x3; unused panels are hidden.
n_states = states_prob.shape[1]
n_cols = 3
n_rows = (n_states + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows),
                         sharex=True, sharey=True, squeeze=False)
fig.suptitle('State probabilities for each of the output states')

for i, ax in enumerate(axes.flat):
    if i >= n_states:
        ax.set_visible(False)
        continue
    # Colour = mean probability of state i among the frames falling into each hexagon.
    ax.hexbin(*dihedrals.T, C=states_prob[:, i], vmin=0, vmax=1, cmap=plt.cm.coolwarm)
    ax.set_title(f"State {i+1}")
    ax.set_xlim([-np.pi, np.pi])
    ax.set_ylim([-np.pi, np.pi])
    ax.set_xlabel(r'$\varphi$ [rad]')
    ax.set_ylabel(r'$\psi$ [rad]')

# Shared colorbar: every panel uses the same fixed [0, 1] colour range.
norm = mpl.colors.Normalize(vmin=0, vmax=1)
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.coolwarm), ax=axes, shrink=.8);