In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
import seaborn as sns
from pandas.api.types import CategoricalDtype
import matplotlib as mpl
import pymc as pm
import scipy as sp
import pickle
from sklearn.preprocessing import MinMaxScaler
import patsy as pt

In [20]:
def gamma(alpha, beta):
    """
    Return a prior factory for a Gamma distribution with fixed parameters.

    alpha, beta : passed through unchanged to ``pm.Gamma`` as keyword arguments.

    The returned callable takes the variable name and creates
    ``pm.Gamma(name, alpha=alpha, beta=beta)``. Creation is deferred so that
    the variable is only built when the factory is called.
    """
    def make_prior(name):
        # `alpha` and `beta` are captured from the enclosing call.
        return pm.Gamma(name, alpha=alpha, beta=beta)

    return make_prior

def hcauchy(beta):
    """
    Return a prior factory for a half-Cauchy distribution with fixed ``beta``.

    beta : passed through unchanged to ``pm.HalfCauchy`` as a keyword argument.

    The returned callable takes the variable name and creates
    ``pm.HalfCauchy(name, beta=beta)``.
    """
    def make_prior(name):
        # `beta` is captured from the enclosing call.
        return pm.HalfCauchy(name, beta=beta)

    return make_prior


def fit_gp(y, X, l_prior, eta_prior, sigma_prior, kernel_type='M52', bayes_kws=dict(draws=1000, tune=1000, chains=2, cores=1), prop_Xu=None):
    """
    Build a Gaussian-process regression model in PyMC and fit it.

    The covariance function is ``eta**2`` multiplied by one kernel per column
    of ``X`` (each kernel acts on a single column and has its own
    length-scale), i.e. a product kernel over all inputs.

    Parameters
    ----------
    y : DataFrame
        Dependent variable. Its values are flattened to 1-D.
    X : DataFrame
        Independent variables. The column names are used to name the
        length-scale variables (``'l_' + column``).
    l_prior, eta_prior, sigma_prior : callable
        Each takes a variable name and returns the corresponding random
        variable. They are called inside the model context. ``l_prior`` is
        called once per column of ``X``; ``eta_prior`` is called with
        ``'eta'`` and ``sigma_prior`` with ``'sigma_n'``.
    kernel_type : str
        One of ``'rbf'``, ``'exponential'``, ``'m52'``, ``'m32'``
        (case-insensitive).
    bayes_kws : dict or None
        Keyword arguments for ``pm.sample``. If None, ``pm.find_MAP()`` is used
        instead of sampling. The default dict is shared between calls; it is
        only read here, never modified.
    prop_Xu : float or None
        Proportion of the rows of ``X`` to use as the number of inducing
        points. If None, use the full marginal likelihood. If not None, use
        the FITC approximation with k-means inducing points.

    Returns
    -------
    (gp, result, model)
        The GP object, the result of ``pm.sample`` (or of ``pm.find_MAP`` when
        ``bayes_kws`` is None), and the model.

    Raises
    ------
    ValueError
        If ``kernel_type`` is not recognised, if ``prop_Xu`` is not positive,
        or if ``prop_Xu`` yields fewer than one inducing point.
    """
    kernel_type = kernel_type.lower()

    # Kernel name -> class name in pm.gp.cov. Validated up front: previously an
    # unrecognised name silently skipped every kernel and left `cov` as the
    # bare `eta**2` term, which only failed later with an unrelated error.
    kernel_classes = {
        'rbf': 'ExpQuad',
        'exponential': 'Exponential',
        'm52': 'Matern52',
        'm32': 'Matern32',
    }
    if kernel_type not in kernel_classes:
        raise ValueError(
            f"Unknown kernel_type {kernel_type!r}; expected one of {sorted(kernel_classes)}"
        )
    if prop_Xu is not None and not prop_Xu > 0:
        raise ValueError(f"prop_Xu must be positive or None, got {prop_Xu!r}")

    with pm.Model() as model:
        # Convert to arrays
        X_a = X.values
        y_a = y.values.flatten()
        X_cols = list(X.columns)
        n_dims = X_a.shape[1]

        # Kernels: amplitude times a product of one kernel per input column.
        kernel_cls = getattr(pm.gp.cov, kernel_classes[kernel_type])
        eta = eta_prior('eta')
        cov = eta**2
        for i, col in enumerate(X_cols):
            cov = cov * kernel_cls(n_dims, ls=l_prior('l_' + col), active_dims=[i])

        # Noise model
        sigma_n = sigma_prior('sigma_n')

        # Model
        if prop_Xu is not None:
            # Inducing variables
            num_Xu = int(X_a.shape[0] * prop_Xu)
            if num_Xu < 1:
                raise ValueError(
                    f"prop_Xu={prop_Xu!r} gives {num_Xu} inducing points for "
                    f"{X_a.shape[0]} rows; need at least 1"
                )
            Xu = pm.gp.util.kmeans_inducing_points(num_Xu, X_a)
            # Prefer `MarginalApprox`, the name used by newer PyMC releases,
            # and fall back to `MarginalSparse` so either version works.
            sparse_cls = getattr(pm.gp, 'MarginalApprox', None) or getattr(pm.gp, 'MarginalSparse')
            gp = sparse_cls(cov_func=cov, approx="FITC")
            gp.marginal_likelihood('y_', X=X_a, y=y_a, Xu=Xu, noise=sigma_n)
        else:
            gp = pm.gp.Marginal(cov_func=cov)
            gp.marginal_likelihood('y_', X=X_a, y=y_a, noise=sigma_n)

        # Full posterior sampling if sampler kwargs were given, otherwise MAP.
        if bayes_kws is not None:
            result = pm.sample(**bayes_kws)
        else:
            result = pm.find_MAP()

    return gp, result, model

In [27]:
# Fit one GP per (kernel, feature/transform) combination for a single
# protein / lag / process, and save each trace (pickle) and trace plot (PDF).
protein = '1fme'
lag = 41
proc = 2

# (feature, transform) subsets to model, and the kernels to try on each.
params = [['dihedrals', None], ['distances', 'linear'], ['distances', 'logistic']]
kernels = ['exponential', 'rbf', 'm32', 'm52']

# pre-processing params
data_cols = ['median', 'tica__dim', 'tica__lag', 'cluster__k', 'feature__value', 'distances__scheme', 'distances__transform',
             'distances__steepness', 'distances__centre'
]
var_names_short = ['ts', 'dim', 'lag', 'states', 'feat', 'scheme', 'trans', 'steep', 'cent']
name_dict = dict(zip(data_cols, var_names_short))
# Not referenced elsewhere in this cell.
scaling = dict(dim=[1, 20], lag=[1, 100],states=[10, 500], steep=[0, 50], cent=[0, 1.5])

# Bayesian kws
# NOTE(review): the recorded output of this cell reports post-tuning
# divergences for several fits; consider raising `target_accept`.
bayes_kws = dict(draws=5000, tune=1000, chains=4, cores=4, target_accept=0.90)

# Load data
summary_path = f'{protein}/summary.h5'
hp_path = '../experiments/hpsample.h5'
timescales = pd.read_hdf(summary_path, key='timescales').reset_index()
# vamps = pd.read_hdf(summary_path, key='vamps')
# vamps.reset_index(inplace=True)
hps = pd.read_hdf(hp_path).reset_index()

# Create main data DF: one process at one lag, joined to its hyper-parameters
# and reduced to the short column names used in the formulas below.
data = timescales.query(f"process=={proc}").query(f'lag=={lag}')
data = data.merge(hps, on=['hp_ix'], how='left')
data = data.loc[:, data_cols+['hp_ix']]
data = data.rename(columns=name_dict)

out_dir = Path(protein).joinpath('sensitivity')
out_dir.mkdir(parents=True, exist_ok=True)

# Priors are the same for every fit, so build the factories once.
l_prior = gamma(2, 0.5)
eta_prior = hcauchy(2)
sigma_prior = hcauchy(2)

for kernel in kernels:
    print(kernel)
    for feat, trans in params:
        print(feat, trans)

        out_path = out_dir.joinpath(f"{feat}_{trans}_{kernel}.pkl")

        # Create formula: distance features add the scheme, and the logistic
        # transform additionally adds its steepness and centre.
        formula = "ts ~  dim + lag + states"
        if feat == 'distances':
            formula += ' + scheme'
            if (trans == 'logistic'):
                formula += " + steep + cent"
        print(formula)

        # Rows for this feature (and transform, when one is specified).
        subset = data.query(f"(feat == '{feat}')")
        if trans is not None:
            subset = subset.query(f"trans == '{trans}'")

        # Create scaled and subsetted DF. The response and the design matrix
        # are min-max scaled together, column by column, to [0, 1].
        ydf, Xdf = pt.dmatrices(formula, data=subset, return_type='dataframe', NA_action='raise')
        scaler = MinMaxScaler(feature_range=(0, 1))
        yX = np.concatenate([ydf.values, Xdf.values], axis=1)
        scaler.fit(yX)
        yX_s = scaler.transform(yX)
        data_s = pd.DataFrame(yX_s, columns = list(ydf.columns) + list(Xdf.columns))
        # Drop the 'Intercept' column of the design matrix so that it is not
        # passed to the GP as an input.
        data_s = data_s.drop(columns=['Intercept'])
        y = data_s.iloc[:, [0]]
        X = data_s.iloc[:, 1:]

        # Fit model
        gp, trace, model = fit_gp(y=y, X=X,  # Data
                                  l_prior=l_prior, eta_prior=eta_prior, sigma_prior=sigma_prior,  # Priors
                                  kernel_type=kernel,  # Kernel
                                  prop_Xu=None,  # proportion of data points which are inducing variables.
                                  bayes_kws=bayes_kws)  # Bayes kws

        results = {'trace': trace, 'data': data_s, 'formula': formula, 'scaler': scaler, 'lag': lag, 'proc': proc}
        # Use a context manager so the file is flushed and closed even if
        # pickling fails; the handle was previously left open.
        with out_path.open('wb') as fh:
            pickle.dump(obj=results, file=fh)

        # Save the trace plot next to the pickle and close the figure so that
        # figures do not accumulate across iterations.
        with sns.plotting_context('paper'):
            pm.plot_trace(trace)
            plt.tight_layout()
            plt.savefig(out_path.with_suffix('.pdf'), bbox_inches='tight')
            plt.close()


exponential
dihedrals None
ts ~  dim + lag + states


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 157 seconds.


distances linear
ts ~  dim + lag + states + scheme


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 577 seconds.


distances logistic
ts ~  dim + lag + states + scheme + steep + cent


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, l_steep, l_cent, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 911 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.


rbf
dihedrals None
ts ~  dim + lag + states


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 201 seconds.
There were 8 divergences after tuning. Increase `target_accept` or reparameterize.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
There were 16 divergences after tuning. Increase `target_accept` or reparameterize.
There were 7 divergences after tuning. Increase `target_accept` or reparameterize.


distances linear
ts ~  dim + lag + states + scheme


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 191 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.


distances logistic
ts ~  dim + lag + states + scheme + steep + cent


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, l_steep, l_cent, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 1040 seconds.
There were 50 divergences after tuning. Increase `target_accept` or reparameterize.
There were 21 divergences after tuning. Increase `target_accept` or reparameterize.
There were 21 divergences after tuning. Increase `target_accept` or reparameterize.


m32
dihedrals None
ts ~  dim + lag + states


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 237 seconds.


distances linear
ts ~  dim + lag + states + scheme


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 243 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.


distances logistic
ts ~  dim + lag + states + scheme + steep + cent


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, l_steep, l_cent, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 1388 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.


m52
dihedrals None
ts ~  dim + lag + states


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 331 seconds.


distances linear
ts ~  dim + lag + states + scheme


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 283 seconds.
There were 9 divergences after tuning. Increase `target_accept` or reparameterize.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 18 divergences after tuning. Increase `target_accept` or reparameterize.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.


distances logistic
ts ~  dim + lag + states + scheme + steep + cent


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, l_scheme[T.closest-heavy], l_dim, l_lag, l_states, l_steep, l_cent, sigma_n]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 2165 seconds.
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
