In [5]:
import duckdb
import pandas as pd
import numpy as np
import pymc3 as pm
import arviz as az
import patsy
import matplotlib.pyplot as plt

In [6]:
# Read every row of the `scaled_cleaned_data` table from the local DuckDB
# file into a pandas DataFrame.
con = duckdb.connect(database='data/database/apixaban_data.duckdb')
try:
    df = pd.read_sql('select * from scaled_cleaned_data', con)
finally:
    # Always close the connection, even if the query raises, so a failed
    # cell does not leave an open handle on the database file in the kernel.
    con.close()

df

Unnamed: 0,subjectids,age,is_male,weight_kg,dose_mg_twice_daily,hrs_post_dose,yobs_ng_ml,creatinine_micromol_l,indication_for_doac_use,history_arrhythmia_treatments_e_g_ablation_cardioversion,...,history_cancer,history_ckd,amiodarone_mg_day,carbamazepine_mg_day,diltiazem_mg_day,phenobarbital_mg_day,phenytoin_mg_day,primidone_mg_day,rifampin_mg_day,verapamil_mg_day
0,DP001,-0.504361,1.0,1.506693,5.0,8.17,265.4,3.583247,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,-0.391087,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
1,DP002,0.025641,1.0,0.605848,2.5,2.58,144.8,2.774356,Atrial fibrillation,0.0,...,0.0,0.0,3.058254,-0.074893,-0.391087,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
2,DP003,0.237642,0.0,0.016510,5.0,7.67,287.2,0.010646,Atrial fibrillation,0.0,...,0.0,0.0,1.394265,-0.074893,-0.391087,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
3,DP004,1.297647,0.0,-1.288453,5.0,7.50,363.9,-0.281453,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,-0.391087,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
4,DP005,0.449643,0.0,-0.505476,5.0,5.50,315.0,-0.708368,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,-0.391087,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
396,DP397,0.873645,1.0,-0.825402,2.5,4.50,125.5,1.134106,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,1.892785,-0.049938,-0.105798,4.655134,-0.049938,-0.070711
397,DP398,-0.080359,1.0,2.302300,5.0,5.75,85.1,4.751645,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,-0.391087,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
398,DP399,-0.716362,1.0,0.294341,5.0,3.83,174.6,1.268921,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,1.892785,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711
399,DP400,-0.080359,1.0,0.795278,2.5,5.33,280.1,3.583247,Atrial fibrillation,0.0,...,0.0,0.0,-0.269724,-0.074893,1.892785,-0.049938,-0.105798,-0.070711,-0.049938,-0.070711


In [7]:
def make_design(v_formula, ke_formula, f_formula, a_formula, data):
    """Build one design matrix per formula.

    Each formula is passed to ``patsy.dmatrix`` together with ``data`` and
    the result is converted to a NumPy array with ``np.asarray``.

    Parameters
    ----------
    v_formula, ke_formula, f_formula, a_formula
        Formula specifications forwarded to ``patsy.dmatrix``, one per
        model component.
    data
        Object forwarded to ``patsy.dmatrix`` as its ``data`` argument.

    Returns
    -------
    tuple
        Four arrays ``(X_V, X_ke, X_f, X_a)``, in the same order as the
        formula arguments.
    """
    design_matrices = [
        np.asarray(patsy.dmatrix(formula, data=data))
        for formula in (v_formula, ke_formula, f_formula, a_formula)
    ]
    X_V, X_ke, X_f, X_a = design_matrices
    return X_V, X_ke, X_f, X_a


def concentration(time, dose, v, f, ka, ke):
    """Evaluate the two-exponential concentration expression.

    Returns
    ``dose * f * ka / (v * (ke - ka)) * (exp(-ka * time) - exp(-ke * time))``
    with the exponentials taken via ``pm.math.exp``, so the arguments may be
    plain numbers, arrays, or symbolic model variables.

    The expression divides by ``ke - ka`` and is therefore undefined when
    ``ke == ka``; no special handling is applied for that case.
    """
    scale = dose * f * ka / (v * (ke - ka))
    exp_ka = pm.math.exp(-ka * time)
    exp_ke = pm.math.exp(-ke * time)
    return scale * (exp_ka - exp_ke)

In [8]:
# Patsy formulas listing the covariates for each model component. The
# trailing "-1" removes patsy's intercept column from the design matrix.
v_formula = '~weight_kg + is_male -1'
ke_formula = '~weight_kg + age + is_male + creatinine_micromol_l -1'
f_formula = '~diltiazem_mg_day + amiodarone_mg_day-1'
a_formula = '~weight_kg + age + is_male + creatinine_micromol_l -1'

# One design matrix per formula, all built from the same DataFrame.
X_v, X_ke, X_f, X_a = make_design(
    v_formula=v_formula,
    ke_formula=ke_formula,
    f_formula=f_formula,
    a_formula=a_formula,
    data=df,
)

# Plain NumPy arrays for the observed response, sampling time and dose.
# NOTE(review): the 0.001 factor presumably rescales ng/mL to ug/mL (= mg/L)
# so the response matches dose-in-mg over volume-in-L -- confirm the units.
yobs = df['yobs_ng_ml'].values * 0.001
time = df['hrs_post_dose'].values
dose = df['dose_mg_twice_daily'].values


In [9]:
with pm.Model() as model:
    # Model: Normal priors on the intercepts of the four linear predictors
    # built below (one each for V, ke, f and a).
    mu_v = pm.Normal('mu_v', np.log(21), 1)
    mu_ke = pm.Normal('mu_ke', -1.25, 0.25)
    mu_f = pm.Normal('mu_f', 0, 0.125)
    mu_a = pm.Normal('mu_a', -1, 0.25)
    
    # Scale parameter of the lognormal likelihood on Y.
    sigma = pm.Lognormal('sigma', np.log(0.1), 0.2)
    
    # One coefficient per design-matrix column, each with a Normal(0, 0.25) prior.
    beta_v = pm.Normal('b_v', 0, 0.25, shape=X_v.shape[1])
    beta_ke = pm.Normal('b_ke', 0, 0.25, shape=X_ke.shape[1])
    beta_f = pm.Normal('b_f', 0, 0.25, shape=X_f.shape[1])
    beta_a = pm.Normal('b_a', 0, 0.25, shape=X_a.shape[1])
    # Normal(3.5, 0.5) prior bounded below at 0; used only to derive ka.
    tmax = pm.Bound(pm.Normal, lower=0)('tmax', 3.5, 0.5)
    
    # Transformed parameters: intercept plus covariate effects per row of
    # the design matrices.
    # NOTE: despite their names, inv_logit_f / inv_logit_a hold the linear
    # predictors *before* the inverse-logit is applied (see `f` and `a` below).
    log_v = mu_v + pm.math.dot(X_v, beta_v)
    log_ke = mu_ke + pm.math.dot(X_ke, beta_ke)
    inv_logit_f = mu_f + pm.math.dot(X_f, beta_f)
    inv_logit_a = mu_a + pm.math.dot(X_a, beta_a)
    
    # Map the linear predictors to their natural scales: exp(...) for V and
    # ke, inverse-logit for f and a. Wrapping in Deterministic records them
    # in the trace under the given names.
    v = pm.Deterministic('V', pm.math.exp(log_v))
    ke = pm.Deterministic('ke', pm.math.exp(log_ke))
    f = pm.Deterministic('f', pm.math.invlogit(inv_logit_f))
    a = pm.Deterministic('a', pm.math.invlogit(inv_logit_a))
    # ka is derived from a and tmax rather than given its own prior.
    # NOTE(review): algebraically this is tmax = ln(ka/ke) / (ka - ke) solved
    # for ka with a standing in for ke/ka -- but the `ke` passed to
    # concentration() below is the separately modelled one, not a * ka.
    # Confirm that this is intended.
    ka = -pm.math.log(a)/(tmax*(1-a))

    
    # Sum the single-dose expression evaluated at 12.0 * i for i = 0..13.
    # The i = 0 term is evaluated at time 0, where both exponentials are 1
    # and the expression is therefore 0.
    C0 = 0
    for i in np.arange(14):
        C0+=concentration(12.0*i, dose, v, f, ka, ke)
        
    # NOTE(review): C0 is evaluated at fixed 12-hour multiples and added as a
    # constant; it is not shifted by `time`. Presumably it stands for
    # carry-over from earlier doses -- confirm that evaluating it at 12*i
    # rather than time + 12*i is the intended approximation.
    C = C0 + concentration(time, dose, v, f, ka, ke)
    
    # Likelihood: yobs is modelled as lognormal with log-scale location
    # log(C) and scale sigma.
    Y = pm.Lognormal('Y', pm.math.log(C), sigma, observed=yobs)

In [10]:
# Draw posterior samples with NUTS on 4 chains run in 4 processes.
with model:
    # random_seed makes the run reproducible across kernel restarts.
    # return_inferencedata=False pins the result to a MultiTrace, which is
    # what the later cells pass to pm.sample_posterior_predictive and
    # az.from_pymc3, instead of relying on the library default.
    # NOTE(review): assumes a PyMC3 release that accepts
    # `return_inferencedata`; the warning echo in this cell's saved output
    # suggests so -- confirm.
    trace = pm.sample(
        chains=4,
        cores=4,
        random_seed=42,
        return_inferencedata=False,
    )

  trace = pm.sample(chains=4, cores=4)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...


KeyboardInterrupt: 

In [None]:
# Simulate the observed variable from the fitted model; `samples=1000` is the
# number of posterior-predictive draws requested.
with model:
    ppc = pm.sample_posterior_predictive(trace, samples=1000)

In [None]:
# Posterior-predictive check on the log scale: log of the observed values
# against log of the per-observation mean of the posterior-predictive draws
# of Y. Points on the red identity line are reproduced exactly.
fig, ax = plt.subplots()
ax.scatter(np.log(yobs), np.log(ppc['Y'].mean(0)))
ax.plot([-4, 0], [-4, 0], color='red', label='y = x')
ax.set_xlabel('log(yobs)')
ax.set_ylabel('log(mean posterior predictive Y)')
ax.set_title('Posterior predictive mean vs. observed (log scale)')
ax.legend()
plt.show()

In [None]:
# Convert the PyMC3 trace into an ArviZ object (inside the model context so
# ArviZ can look the model up), then compute R-hat from it.
with model:
    data = az.from_pymc3(trace=trace)

rhat = az.rhat(data)

In [None]:
# Show the R-hat results as a DataFrame; being the cell's last expression,
# the table is rendered as the cell output.
rhat.to_dataframe()