In [2]:
import time
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import levy_stable
import tensorflow as tf

import bayesflow as bf












  from tqdm.autonotebook import tqdm


In [3]:
# Suppress scientific notation when NumPy prints floats
np.set_printoptions(suppress=True)

In [4]:
# Report how many GPU devices TensorFlow lists (0 means no GPU device was found)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  0


In [5]:
def simulate_levy_trial(a=1, z=0.5, v=0, t=0.2, alpha=2.0, dt=0.001, rng=None):
    """
    Simulates the response time for one trial from a Lévy-Flight Process
     * param a:     threshold separation (a>0)
     * param z:     relative starting point (0<z<1)
     * param v:     drift
     * param t:     non-decision time (t>0)
     * param alpha: stability parameter (0<alpha<=2)
     * param dt:    step size for simulation (dt>0)
     * param rng:   optional numpy Generator / RandomState (or a seed) used for the noise;
                    None keeps SciPy's default random state

    The function returns a simulated response time as np.float32. A negative sign indicates
    a response at the lower threshold
    """

    # A seed is turned into ONE generator up front; handing a bare seed to every
    # rvs call below would restart the stream and repeat the same noise in each bin.
    if rng is not None and not isinstance(rng, (np.random.Generator, np.random.RandomState)):
        rng = np.random.default_rng(rng)

    scale = np.power(dt, (1/alpha)) / np.sqrt(2)

    bin_size = np.int_(1/dt)  # The decision path is simulated in bins of 1 second
    start = a*z
    cnt = 0  # number of steps simulated in all previous bins

    while True:
        noise = levy_stable.rvs(alpha, 0, scale=scale, size=bin_size, random_state=rng)
        path = start + np.cumsum(v*dt + noise)

        # Look at both boundaries together and take the EARLIEST crossing. Testing the
        # lower boundary over the whole bin first would misreport a trial that reaches
        # the upper threshold and only later in the same bin drops below zero.
        crossed = (path < 0) | (path > a)
        if crossed.any():
            first = np.argmax(crossed)  # index of the first True
            rt = t + (cnt + first)*dt
            if path[first] < 0:
                rt = -rt
            return np.float32(rt)

        # No boundary reached in this bin: continue from where the path ended
        start = path[-1]
        cnt = cnt + bin_size

In [6]:
# Seeded NumPy generator; LF_prior draws all of its random numbers from it
RNG = np.random.default_rng(2024)

In [10]:
def LF_prior(rng=None):
    """
    Draws one set of marginally informative priors for the 6-parameter Levy-Flight-Model

     * param rng: optional numpy random Generator to draw from;
                  when None the module-level RNG is used

    The function returns a float32 np array with the parameter values in the order
    (a, z, v, t0, st0, alpha):
     * a:     threshold separation,         Gamma(shape=3, scale=1/3) + 0.1
     * z:     relative starting point,      Beta(5, 5)
     * v:     drift rate,                   Normal(mean=3, sd=3)
     * t0:    non-decision time,            Gamma(shape=1, scale=1/6) + 0.1
     * st0:   width of the non-decision time variability, Beta(1, 2) * 2*t0 (so st0 < 2*t0)
     * alpha: stability parameter,          Beta(6, 2) * 2
    """
    if rng is None:
        rng = RNG

    # The order of the draws is kept fixed so that a given generator state
    # always yields the same parameter vector.
    a      = rng.gamma(3, 1/3) + 0.1
    z      = rng.beta(5, 5)
    v      = rng.normal(3, 3)
    t0     = rng.gamma(1, 1/6) + 0.1
    st0    = rng.beta(1, 2) * 2*t0
    alpha  = rng.beta(6, 2) * 2

    return np.array([a, z, v, t0, st0, alpha]).astype(np.float32)

In [11]:
# Parameter labels, listed in the same order as the array returned by LF_prior
PARAM_NAMES = ['a', 'z', 'v', 't', 'st', 'alpha']

In [12]:
# Wrap LF_prior together with its parameter labels in a BayesFlow Prior object
prior = bf.simulation.Prior(prior_fun=LF_prior, param_names=PARAM_NAMES)

In [13]:
def LF_experiment(theta, n_obs=500):
    """
    Simulates response times for one participant of a Lévy-flight experiment.
     * param theta: numpy array with the parameters of the lfm (a,z,v,t,st,alpha)
     * param n_obs: number of trials to simulate

    The function returns an n_obs by 2 float32 array, where the first column gives the stimulus type
    (always 0 here, because a single drift rate is simulated) and the second column gives the
    response times (negative values indicate responses at the lower threshold)
    """
    a, z, v, t0, st0, alpha = theta[:6]

    # Column 0 (stimulus type) keeps its initial value of 0 for every trial
    sim_data = np.zeros((n_obs, 2), dtype=np.float32)

    for i in range(n_obs):
        # The non-decision time varies uniformly across trials within t0 +/- st0/2.
        # It is drawn from the seeded module-level RNG (as the prior is) rather than
        # from the unseeded global numpy state.
        t_trial = RNG.uniform(t0 - st0/2, t0 + st0/2)
        sim_data[i, 1] = simulate_levy_trial(a=a, z=z, v=v, t=t_trial, alpha=alpha)

    return sim_data

In [10]:
# Combine the prior and the trial simulator into one generative model.
# LF_experiment is passed as the simulator function; the GenerativeModel is named "lfm".
simulator = bf.simulation.Simulator(simulator_fun = LF_experiment)
model = bf.simulation.GenerativeModel(prior=prior, simulator=simulator, name="lfm")

INFO:root:Performing 2 pilot runs with the lfm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 7)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 500, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.


In [11]:
# Summary network: input_dim=2 matches the two columns produced by LF_experiment
# (stimulus type, signed response time); the output has 14 summary dimensions
summary_net = bf.networks.SetTransformer(input_dim=2, summary_dim=14, name="lfm_summary")







In [12]:
# Inference network: one output dimension per entry in prior.param_names.
# The coupling settings switch off kernel regularization and dropout.
inference_net = bf.networks.InvertibleNetwork(
    num_params=len(prior.param_names),
    coupling_settings={"dense_args": dict(kernel_regularizer=None), "dropout": False},
    name="lfm_inference",
)

In [13]:
# Bundle the inference and summary networks into a single amortized posterior approximator
amortizer = bf.amortizers.AmortizedPosterior(inference_net, summary_net, name="lfm_amortizer")

In [14]:
#prior_means, prior_stds = prior.estimate_means_and_stds(n_draws=100000)
#prior_means = np.round(prior_means, decimals=1)
#prior_stds = np.round(prior_stds, decimals=1)
#print(prior_means, prior_stds)

In [15]:
# Rounded prior means / standard deviations used to standardize the parameters.
# One entry per name in PARAM_NAMES, in the same order: a, z, v, t, st, alpha
prior_means = [1.1, 0.5, 3., 0.3, 0.2, 1.5]
prior_stds = [0.6, 0.2, 3., 0.2, 0.2, 0.3]

In [16]:
def configurator(forward_dict, min_trials=200, max_trials=500):
    """Configure the output of the GenerativeModel for a BayesFlow setup.

    Parameters
    ----------
    forward_dict : dict
        Must contain "sim_data" (indexed as batch x trials x features) and
        "prior_draws" (batch x parameters).
    min_trials, max_trials : int
        Inclusive range from which the number of retained trials is drawn.
        The same number of trials is used for every data set in the batch.

    Returns
    -------
    dict with float32 arrays under the keys
        "summary_conditions" : randomly subsampled trials (without replacement),
        "direct_conditions"  : sqrt(number of trials), shape (batch_size, 1),
        "parameters"         : prior draws standardized with the module-level
                               prior_means / prior_stds.
    """

    # Prepare placeholder dict
    out_dict = {}

    # Extract simulated response times and keep a random subset of the trials
    data = forward_dict["sim_data"]
    num_trials = np.random.randint(min_trials, max_trials + 1)
    idx = np.random.choice(data.shape[1], size=num_trials, replace=False)
    data = data[:, idx, :]

    out_dict["summary_conditions"] = data.astype(np.float32)

    # Make inference network aware of varying numbers of trials
    # We create a vector of shape (batch_size, 1) by repeating the sqrt(num_obs).
    # The square root is taken exactly once here.
    vec_num_obs = np.sqrt(num_trials) * np.ones((data.shape[0], 1))
    out_dict["direct_conditions"] = vec_num_obs.astype(np.float32)

    # Get data generating parameters
    params = forward_dict["prior_draws"].astype(np.float32)

    # Standardize parameters
    out_dict["parameters"] = ((params - prior_means) / prior_stds).astype(np.float32)

    return out_dict

In [17]:
# Set up the trainer with the generative model, the amortizer and the configurator.
# NOTE(review): the checkpoint folder is still called "lfm_7p_new" although PARAM_NAMES
# lists 6 parameters — confirm that no checkpoint from another model variant lives there.
trainer = bf.trainers.Trainer(
    generative_model=model, 
    amortizer=amortizer, 
    configurator=configurator,
    default_lr=0.0001,
    checkpoint_path="checkpoints//lfm_7p_new"
)

INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.


In [18]:
%%time
## Attention: Takes about 5 minutes per 1,000 data sets!
## For training, at least 100,000 data sets are recommended (~8 hours at that rate)

train_data = model(2000)

# Store the simulations so that training can be repeated without re-simulating.
# The with-block closes the file even if pickling fails.
with open("training_data_lfm_7p_short.obj", "wb") as f:
    pickle.dump(train_data, f, protocol=4)

CPU times: total: 8min 29s
Wall time: 8min 29s


In [19]:
# Reload the stored simulations.
# NOTE: pickle.load can execute arbitrary code — only load files you created yourself.
with open("training_data_lfm_7p_short.obj", "rb") as f:
    train_data = pickle.load(f)


In [20]:
# Same GPU check as at the top of the notebook, repeated right before training starts
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  0


In [None]:
# Offline training on the pre-simulated train_data.
# With the 2,000 data sets simulated above and batch_size = 20, one epoch is 100 iterations.
history = trainer.train_offline(
    simulations_dict = train_data, 
    epochs = 1000, 
    batch_size = 20
)  

Training epoch 1:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 2:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 3:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 4:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 5:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 6:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 7:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 8:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 9:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 10:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 11:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 12:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 13:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 14:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 15:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 16:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 17:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 18:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 19:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 20:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 21:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 22:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 23:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 24:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 25:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 26:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 27:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 28:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 29:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 30:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 31:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 32:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 33:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 34:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 35:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 36:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 37:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 38:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 39:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 40:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 41:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 42:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 43:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 44:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 45:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 46:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 47:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 48:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 49:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 50:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 51:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 52:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 53:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 54:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 55:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 56:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 57:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 58:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 59:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 60:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 61:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 62:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 63:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 64:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 65:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 66:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 67:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 68:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 69:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 70:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 71:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 72:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 73:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 74:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 75:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 76:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 77:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 78:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 79:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 80:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 81:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 82:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 83:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 84:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 85:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 86:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 87:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 88:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 89:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 90:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 91:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 92:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 93:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 94:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 95:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 96:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 97:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 98:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 99:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 100:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 101:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 102:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 103:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 104:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 105:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 106:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 107:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 108:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 109:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 110:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 111:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 112:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 113:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 114:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 115:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 116:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 117:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 118:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 119:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 120:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 121:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 122:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 123:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 124:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 125:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 126:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 127:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 128:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 129:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 130:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 131:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 132:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 133:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 134:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 135:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 136:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 137:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 138:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 139:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 140:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 141:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 142:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 143:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 144:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 145:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 146:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 147:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 148:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 149:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 150:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 151:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 152:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 153:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 154:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 155:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 156:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 157:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 158:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 159:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 160:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 161:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 162:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 163:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 164:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 165:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 166:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 167:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 168:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 169:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 170:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 171:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 172:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 173:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 174:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 175:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 176:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 177:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 178:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 179:   0%|          | 0/100 [00:00<?, ?it/s]

Training epoch 180:   0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Online training: 100 epochs of 100 iterations with batch_size = 16.
# Note that this rebinds `history`, replacing the value returned by train_offline above.
history = trainer.train_online(
    epochs = 100, 
    iterations_per_epoch = 100,
    batch_size = 16
)  

In [None]:
# Plot the loss history returned by the most recent training call
f = bf.diagnostics.plot_losses(history)

In [24]:
# Generate some validation data
validation_sims = configurator(model(batch_size=100))

# Keep the ground-truth parameters exactly as returned by the configurator, i.e. on the
# standardized scale (no back-transformation to the original scale is applied here)
prior_samples = validation_sims["parameters"]


In [25]:
# Report how many trials the configurator kept for the validation batch
# (the second axis of "summary_conditions")
print(
    f"Estimation will be performed on data sets with {validation_sims['summary_conditions'].shape[1]} simulated trials."
)


In [26]:
# Generate 100 posterior draws for each simulated validation data set.
# No back-transformation is applied: the draws are used as returned by the amortizer
# and compared with prior_samples as returned by the configurator.
post_samples = amortizer.sample(validation_sims, n_samples=100)

In [27]:
# Simulation-based calibration: rank histograms with 10 bins per parameter
f = bf.diagnostics.plot_sbc_histograms(post_samples, prior_samples, num_bins=10, param_names=PARAM_NAMES)

In [28]:
# Simulation-based calibration: ECDF difference plots, one panel per parameter (stacked=False)
f = bf.diagnostics.plot_sbc_ecdf(post_samples, prior_samples, stacked=False, difference=True, param_names=PARAM_NAMES)

In [29]:
# Draw 1000 posterior samples per validation data set for the plots below.
# This rebinds post_samples, replacing the 100-draw array used for the SBC plots.
post_samples = amortizer.sample(validation_sims, n_samples=1000)


In [30]:
# Parameter recovery: posterior mean (point estimate) and posterior standard deviation
# (uncertainty) plotted against the values in prior_samples
f = bf.diagnostics.plot_recovery(
    post_samples, prior_samples, param_names=prior.param_names, point_agg=np.mean, uncertainty_agg=np.std
)

In [31]:
# Posterior z-score versus posterior contraction, one panel per parameter
f = bf.diagnostics.plot_z_score_contraction(post_samples, prior_samples, param_names=prior.param_names)