In [1]:
import time
import arviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import stan
import nest_asyncio
nest_asyncio.apply()

In [2]:
# Seed NumPy's legacy global RNG so every np.random.* draw in this notebook
# is reproducible after a kernel restart.
np.random.seed(seed=1729)

In [3]:
def simulate_data(N, P):
    """
    Simulates binary activity/transmission data for N observations and P places.

    Per-place transmission and occurrence rates, plus one base rate, are drawn
    from Beta(2, 10). For each place an occurrence indicator is sampled, and a
    transmission indicator that can only be 1 when the activity occurred. A
    place-independent background transmission ('T0') is sampled from the base
    rate. An observation is positive when any transmission indicator is 1.

    :param N: Number of observations (rows) to simulate
    :param P: Number of places (columns of X)
    :return: dict with keys 'N', 'P', 'X' (N x P array of occurrence
             indicators) and 'y' (length-N array of 0/1 outcomes)
    """
    # Simulating transmission and occurrence rates.
    # The draw order is kept fixed so a given seed yields the same data.
    transmission_rate = np.random.beta(2, 10, P)
    occurrence_rate = np.random.beta(2, 10, P)
    base_rate = np.random.beta(2, 10, 1)

    columns = {}
    rate_pairs = zip(occurrence_rate, transmission_rate)
    for place, (occ_p, trans_p) in enumerate(rate_pairs, start=1):
        occurred = np.random.binomial(1, occ_p, N)
        # Multiplying by the occurrence indicator zeroes out transmissions
        # at places the observation never visited.
        transmitted = occurred * np.random.binomial(1, trans_p, N)
        columns[f'O{place}'] = occurred
        columns[f'T{place}'] = transmitted

    # Background transmission that is not attributed to any place
    columns['T0'] = np.random.binomial(1, base_rate, N)

    frame = pd.DataFrame(columns)
    occurrence_cols = [name for name in frame.columns if name.startswith('O')]
    transmission_cols = [name for name in frame.columns if name.startswith('T')]

    # Positive outcome as soon as at least one transmission source fired
    tested_positive = (frame[transmission_cols].sum(axis=1) > 0).astype(int)

    return {
        'N': N,
        'P': P,
        'X': frame[occurrence_cols].to_numpy(),
        'y': tested_positive.to_numpy(),
    }

In [4]:
# Stan model code
model_code = """
data {
  int<lower=0> N;                            // number of observations
  int<lower=0> P;                            // number of places
  int<lower=0, upper=1> X[N,P];              // activity occurrences
  int<lower=0, upper=1> y[N];                // transmission (tested positive)
  
}
parameters {
  real<lower=0, upper=1> theta[P];           // transmission rates
  real<lower=0, upper=1> rho;                // underlying risk
}
transformed parameters {
  // Precomputation
  real log1m_theta[P];
  real log1m_rho;
  
  for (p in 1:P) {
    log1m_theta[p] = log1m(theta[p]);
  }

  log1m_rho = log1m(rho);
}
model {
  // Priors
  theta ~ uniform(0, 1);
  rho ~ uniform(0, 1);
  
  // Likelihood
  for (n in 1:N) {
    real s = 0.0;
    for (p in 1:P) {
      if (X[n,p] == 1) {
        s += log1m_theta[p];
      }
    }
    s += log1m_rho;
    
    if (y[n] == 1) {
      target += log1m_exp(s);
    } 
    else {
      target += s;
    }
  }
}
""" 

In [5]:
def runtime_lineplot_N(N_space, P=4, filename='runtime_N_plot.png'):
    """
    Plots model runtime as a function of the number of fitting samples
    Saves plot as .png file

    Only the sampling call is timed. Data simulation and stan.build are done
    before the timer starts, so the measurement is not inflated by pandas
    work or by one-off model compilation on a cache miss.

    :param N_space: One-dimensional iterable of integers for each number of fitting samples
    :param P: Number of places passed to simulate_data
    :param filename: Path the figure is saved to

    """
    N_space = np.sort(np.array(N_space))
    runtimes = []
    for i in N_space:
        N = int(i)
        # Untimed setup: generate data and build (or fetch from cache) the model
        model_data = simulate_data(N, P)
        posterior = stan.build(model_code, data=model_data, random_seed=1)

        print('Running model...')
        # perf_counter is monotonic, unlike time.time(), so the interval
        # cannot be distorted by system clock adjustments.
        start = time.perf_counter()
        posterior.sample(num_samples=1000, num_warmup=500, num_chains=4)
        end = time.perf_counter()
        print('Finished running')
        runtimes.append(end - start)

    fig, ax = plt.subplots()
    ax.plot(N_space, runtimes, color='tab:orange')
    ax.set_ylabel('Runtime (s)')
    ax.set_xlabel('Number of fitting samples')
    ax.set_title('Model Runtime')
    fig.tight_layout()
    # Save through the figure handle rather than pyplot's "current figure"
    fig.savefig(filename)
    plt.show()

In [None]:
a = 10**4
# Each size is listed once: a repeated entry would refit the model at that
# size and put two points at the same x position on the line plot.
Ns = [a, 5*a, 10*a, 50*a]
runtime_lineplot_N(Ns, P=4)

Building...

Running model...



Found model in cache. Done.
Sampling...
    0/6000 [>---------------------------]   0%  1 sec/0     
    1/6000 [>---------------------------]   0%  1 sec/6453  
  100/6000 [>---------------------------]   1% 3 secs/129   
  200/6000 [>---------------------------]   3% 4 secs/97    
  300/6000 [=>--------------------------]   5% 5 secs/85    
  400/6000 [=>--------------------------]   6% 6 secs/79    
  500/6000 [==>-------------------------]   8% 7 secs/76    
  501/6000 [==>-------------------------]   8% 8 secs/88    
  600/6000 [==>-------------------------]  10% 8 secs/77    
  600/6000 [==>-------------------------]  10% 12 secs/110   
  900/6000 [====>-----------------------]  15% 20 secs/128   
 1100/6000 [=====>----------------------]  18% 21 secs/113   
 1200/6000 [=====>----------------------]  20% 22 secs/106   
 1200/6000 [=====>----------------------]  20% 23 secs/112   