In [48]:
import pyemma as pm
import deeptime as dt
import numpy as np
from pathlib import Path
import pandas as pd
import seaborn as sns
from typing import *
from pathlib import Path
from deeptime.numeric import is_sorted, spd_inv_sqrt, schatten_norm
import scipy
from msmtools.estimation import transition_matrix


In [49]:
# Load every saved discrete trajectory (files matching dtraj*.npy) from each run directory.
dtrajs_24 = list(map(np.load, Path('hp_24/bs_0').glob('dtraj*.npy')))
dtrajs_235 = list(map(np.load, Path('hp_235/bs_0').glob('dtraj*.npy')))

# Lag time handed to the MSM estimators in the cells below.
lag = 41

In [50]:
# Reversible maximum-likelihood MSM estimated from the hp_24 trajectories at lag time `lag`.
est_24 = dt.markov.msm.MaximumLikelihoodMSM(lagtime=lag, reversible=True)
est_24.fit(dtrajs_24)
mod_24 = est_24.fetch_model()

# Displayed tuple: (model score for dim=2, first timescale, sum of the first two squared eigenvalues)
(
    mod_24.score(dim=2),
    mod_24.timescales(1)[0],
    (mod_24.eigenvalues()[:2] ** 2).sum(),
)

(1.9648893624588197, 23939.695689025433, 1.9965805862753765)

In [51]:
# Reversible maximum-likelihood MSM estimated from the hp_235 trajectories at lag time `lag`.
est_235 = dt.markov.msm.MaximumLikelihoodMSM(reversible=True, lagtime=lag)
est_235.fit(dtrajs_235)
mod_235 = est_235.fetch_model()

# Displayed tuple: (model score for dim=2, first timescale, sum of the first two squared eigenvalues).
# Only the first timescale is requested (`timescales(1)`), as in the dtrajs_24 cell, instead of
# computing every timescale and discarding all but one.
mod_235.score(dim=2), mod_235.timescales(1)[0], (mod_235.eigenvalues()[:2]**2).sum()

(1.9803537206523718, 2584.872309328571, 1.968774859683399)

In [56]:
def vamp(dtrajs: List[np.ndarray], lag: int, dim: int = 2, epsilon: float = 1e-15) -> None:
    """Print a step-by-step VAMP-1 / VAMP-2 computation for a reversible MSM.

    A ``MaximumLikelihoodMSM`` is fitted to ``dtrajs`` at lag time ``lag``.
    Covariance-like matrices are built from the fitted model's count matrix,
    the transition matrix is symmetrised with them and decomposed by SVD, and
    the leading ``dim`` singular vectors are used to form ``A @ B @ C``. Its
    Schatten 1-norm and squared Schatten 2-norm are printed as ``VAMP1`` and
    ``VAMP2`` next to the leading ``dim`` MSM eigenvalues so the two can be
    compared.

    Parameters
    ----------
    dtrajs : list of np.ndarray
        Discrete trajectories, passed unchanged to ``MaximumLikelihoodMSM.fit``.
    lag : int
        Lag time used to construct the estimator.
    dim : int, default 2
        Number of leading singular values/vectors kept, and number of MSM
        eigenvalues printed alongside them.
    epsilon : float, default 1e-15
        Forwarded as ``epsilon`` to every ``spd_inv_sqrt`` call.

    Returns
    -------
    None
        All results are printed; nothing is returned.

    Raises
    ------
    ValueError
        If ``dim`` is smaller than 1.
    """
    if dim < 1:
        raise ValueError(f'dim must be at least 1, got {dim}')

    # Fit the reversible maximum-likelihood MSM and pull out its matrices.
    est = dt.markov.msm.MaximumLikelihoodMSM(reversible=True, lagtime=lag)
    est.fit(dtrajs)
    mod = est.fetch_model()

    cmat = mod.count_model.count_matrix
    tmat = mod.transition_matrix

    # Empirical covariances: the count matrix plays the role of the
    # time-lagged covariance; its row and column sums form the two
    # diagonal instantaneous covariances.
    cov0t = cmat
    cov00 = np.diag(cov0t.sum(axis=1))
    covtt = np.diag(cov0t.sum(axis=0))

    # Reweight the transition matrix by the row counts, then symmetrise it.
    C0t_re = cov00 @ tmat
    tmat_sym = np.linalg.multi_dot([
        spd_inv_sqrt(cov00, epsilon=epsilon, method='schur'),
        C0t_re,
        spd_inv_sqrt(covtt, epsilon=epsilon, method='schur'),
    ])

    # Squared Frobenius-type norm of the antisymmetric part: zero iff tmat_sym is symmetric.
    print('norm(K.T - K) = ', schatten_norm(tmat_sym.T - tmat_sym, 2) ** 2)

    # SVD of the symmetrised matrix, then map the singular vectors back.
    # NOTE(review): these two spd_inv_sqrt calls use the default `method`,
    # whereas the symmetrisation above requests 'schur' - confirm this is intended.
    U, singular_values, Vt = scipy.linalg.svd(tmat_sym, compute_uv=True)
    U = spd_inv_sqrt(cov00, epsilon=epsilon) @ U
    Vt = Vt @ spd_inv_sqrt(covtt, epsilon=epsilon)
    V = Vt.T

    # Keep the `dim` largest singular values and their vectors (descending order).
    sort_ix = np.argsort(singular_values)[::-1][:dim]
    U = U[:, sort_ix]
    V = V[:, sort_ix]
    singular_values = singular_values[sort_ix]

    # Score matrix assembled from the retained singular vectors.
    A = np.atleast_2d(spd_inv_sqrt(U.T.dot(cov00).dot(U), epsilon=epsilon))
    B = np.atleast_2d(U.T.dot(cov0t).dot(V))
    C = np.atleast_2d(spd_inv_sqrt(V.T.dot(covtt).dot(V), epsilon=epsilon))
    ABC = np.linalg.multi_dot([A, B, C])
    vamp1 = schatten_norm(ABC, 1)
    vamp2 = schatten_norm(ABC, 2) ** 2

    # Leading MSM eigenvalues, requested once and matched to `dim` so they
    # line up with the retained singular values for any choice of `dim`.
    lambdas = mod.eigenvalues(dim)

    with np.printoptions(precision=10):
        print('A: ')
        print(np.round(A, 10))
        print('B: ')
        print(np.round(B, 10))
        print('C: ')
        print(np.round(C, 10))
        print('Lambdas     =', lambdas)
        print('Singulars   =', singular_values)
        print('Lambdas^2   =', lambdas**2)
        print('Singulars^2 =', singular_values**2)

        print('Sum lambdas =   ', np.round(lambdas.sum(), 4))
        print('VAMP1 =         ', np.round(vamp1, 4))
        print('Sum lambdas^2 = ', np.round((lambdas**2).sum(), 4))
        print('VAMP2 =         ', np.round(vamp2, 4))
        print(f't_2 =            {np.round(mod.timescales(1)[0], 0)}')
    print('-'*80)

# Run the diagnostic on both trajectory sets with the same lag time and the default dim.
vamp(dtrajs=dtrajs_24, lag=41, dim=2)
vamp(dtrajs=dtrajs_235, lag=41, dim=2)

norm(K.T - K) =  0.6186800006777994
A: 
[[1. 0.]
 [0. 1.]]
B: 
[[1.           0.          ]
 [0.           0.9822878206]]
C: 
[[1. 0.]
 [0. 1.]]
Lambdas     = [1.           0.9982888291]
Singulars   = [1.           0.9822849512]
Lambdas^2   = [1.           0.9965805863]
Singulars^2 = [1.           0.9648837253]
Sum lambdas =    1.9983
VAMP1 =          1.9823
Sum lambdas^2 =  1.9966
VAMP2 =          1.9649
t_2 =            23940.0
--------------------------------------------------------------------------------
norm(K.T - K) =  0.850867732038595
A: 
[[1. 0.]
 [0. 1.]]
B: 
[[ 1.000000000e+00 -1.800000000e-09]
 [ 0.000000000e+00  9.901281335e-01]]
C: 
[[ 1. -0.]
 [-0.  1.]]
Lambdas     = [1.           0.9842636129]
Singulars   = [1.          0.989646646]
Lambdas^2   = [1.           0.9687748597]
Singulars^2 = [1.          0.979400484]
Sum lambdas =    1.9843
VAMP1 =          1.9901
Sum lambdas^2 =  1.9688
VAMP2 =          1.9804
t_2 =            2585.0
-------------------------------------