In [1]:
"""
Simulation-based Bayesian inference (using BayesFlow) on the Brownian-motion
simulators defined in Motion.py.
"""

import os

# Default Keras to the JAX backend unless one was already chosen
# (valid values: "torch", "tensorflow", "jax"). Keras reads this variable
# when it is first imported, so this has to run before keras/bayesflow.
os.environ.setdefault("KERAS_BACKEND", "jax")

import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import bayesflow as bf
import keras
import os
import jax
from Motion import Motion

INFO:2025-12-04 11:31:38,307:jax._src.xla_bridge:822: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)
INFO:bayesflow:Using backend 'jax'


In [2]:

# Summary network: compresses each simulated time series into a fixed-size
# vector of learned summary statistics that conditions the inference network.
class GRU(bf.networks.SummaryNetwork):
    """Recurrent summary network: a GRU encoder followed by a linear projection.

    Parameters
    ----------
    gru_units : int, default 128
        Hidden size of the GRU layer.
    summary_dim : int, default 64
        Number of summary statistics produced (width of the final Dense layer).
    dropout : float, default 0.1
        Dropout rate applied inside the GRU layer.
    **kwargs
        Forwarded unchanged to ``bf.networks.SummaryNetwork``.
    """

    def __init__(self, gru_units=128, summary_dim=64, dropout=0.1, **kwargs):
        # `dropout` is consumed here so it actually configures the GRU layer
        # instead of being passed on to the base class via **kwargs.
        super().__init__(**kwargs)

        self.gru = keras.layers.GRU(gru_units, dropout=dropout)
        self.summary_stats = keras.layers.Dense(summary_dim)

    def call(self, time_series, training=None, **kwargs):
        """Compress ``time_series`` into a ``(batch, summary_dim)`` tensor.

        Assumes ``time_series`` is ``(batch, time_steps, features)`` as produced
        by the adapter's ``as_time_series`` step — TODO confirm.

        Parameters
        ----------
        time_series : tensor
            Batch of sequences to summarise.
        training : bool or None
            Keras training flag. When not provided, it is derived from the
            BayesFlow ``stage`` keyword (``stage == "training"``).
        """
        if training is None:
            # BayesFlow may signal the phase via `stage` rather than Keras'
            # `training` flag; dropout must only be active while training.
            training = kwargs.get("stage") == "training"
        summary = self.gru(time_series, training=training)
        return self.summary_stats(summary)

In [3]:
# Simulator instance from Motion.py. "geom" presumably selects geometric
# Brownian motion and "fc" the parameter set — see Motion.py to confirm.
motion = Motion(
    simulator="geom",
    parameters="fc",
    with_prior=True,
)

In [12]:

def _sample_motion():
    """Return the result of one call to the configured ``motion`` simulator."""
    return motion()


# Wrap the zero-argument sampling function as a BayesFlow simulator.
simulator = bf.simulators.make_simulator([_sample_motion])


{'motion': array([[[100.        , 100.        , 100.        ],
         [ 91.41938393, 102.77161044,  97.34266675],
         [101.20235985, 118.80527202,  99.10638145],
         [105.14865999, 126.55714949, 100.4767346 ],
         [ 99.66112319, 109.14884157,  99.62480082],
         [108.09556724, 124.13644334, 102.62797605],
         [107.27335947, 127.50419229, 101.64578111],
         [ 97.62310002, 103.96102957, 100.53567932],
         [110.09841347, 115.44025434, 103.30032798],
         [115.13857605, 115.3095913 , 104.88686396],
         [111.63895899, 110.34193041, 106.31176636],
         [118.21003001, 117.74038933, 107.72154496],
         [111.93053507, 117.19784467, 107.15030933],
         [111.93642118, 120.10985659, 108.50845238],
         [121.27742197, 122.6768765 , 112.76539916],
         [113.3964065 , 111.4146503 , 111.92068536],
         [118.26000695, 118.68643327, 115.23183225],
         [110.47710453, 128.3565287 , 112.57088564],
         [111.48139218, 148.43801126

In [15]:

# Adapter: turns raw simulator output into the tensors the networks consume.
adapter = bf.adapters.Adapter()
adapter = adapter.convert_dtype("float64", "float32")
adapter = adapter.as_time_series("motion")
# Parameters to infer; posterior samples are later keyed by these names.
adapter = adapter.concatenate(["b1", "b2", "b3"], into="inference_variables")
# Observed paths are routed to the summary network.
adapter = adapter.rename("motion", "summary_variables")

# Alternatives tried earlier (kept for reference):
#   adapter:   .log(["summary_variables"], p1=True)
#   summary:   bf.networks.TimeSeriesTransformer(dropout=0.1)
#   inference: bf.networks.CouplingFlow(transform="spline", depth=2, dropout=0.1)

# Custom recurrent summary network for the time series.
summary_net = GRU(dropout=0.1)

# Flow-matching posterior estimator over the inference variables.
inference_net = bf.networks.FlowMatching(dropout=0.1)

workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=summary_net,
    inference_network=inference_net,
    checkpoint_path="motion_workflow/",
)

In [16]:

# Offline training: simulate a fixed training set and a held-out validation
# set once (training first, then validation), then fit on them.
train = workflow.simulate(8000)      # training simulations
validation = workflow.simulate(300)  # validation simulations

history = workflow.fit_offline(
    data=train,
    validation_data=validation,
    epochs=100,
    batch_size=32,
)

INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.


Epoch 1/100
[1m 70/250[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m3s[0m 18ms/step - loss: 7.6047

KeyboardInterrupt: 

In [None]:

# Training/validation loss curves. Requires the training cell to have finished:
# an interrupted fit_offline never assigns `history`.
f = bf.diagnostics.plots.loss(history)

# The workflow was configured with checkpoint_path="motion_workflow/". To export
# the trained approximator explicitly, Keras 3 requires a `.keras` filename, e.g.
#   workflow.approximator.save("motion_workflow/approximator.keras")

# Render the figure (plt.plot() with no data only draws nothing and echoes `[]`).
plt.show()

In [None]:

num_datasets = 300   # number of fresh test scenarios
num_samples = 1000   # posterior draws per scenario

# Simulate num_datasets test scenarios (true parameters + observations)
print("Running simulations")
test_sims = workflow.simulate(num_datasets)

# Obtain num_samples posterior samples per scenario
print("Sampling")
samples = workflow.sample(conditions=test_sims, num_samples=num_samples)

# Parameter recovery: posterior estimates against the true simulated values
print("Making plots")
f = bf.diagnostics.plots.recovery(samples, test_sims)

plt.show()

In [None]:

# Pair plot of the posterior for one test scenario: marginal histograms on the
# diagonal, 2-D histograms above it, true parameter values marked in red.
#
# These keys must match the parameters concatenated into "inference_variables"
# by the adapter (b1, b2, b3). Posterior samples are keyed by those names, so
# other keys (such as "v1") raise a KeyError.
labels = ["b1", "b2", "b3"]
dataset_idx = 0  # which simulated test scenario to inspect

# True parameter values for the chosen scenario, shape (d,)
truths = np.asarray([test_sims[key][dataset_idx].item() for key in labels])

# Posterior draws for the chosen scenario, shape (num_samples, d)
out_samples = np.column_stack(
    [samples[key][dataset_idx].flatten() for key in labels]
)

d = out_samples.shape[1]
# squeeze=False keeps `axes` 2-D even when d == 1, so axes[i, j] always works.
fig, axes = plt.subplots(d, d, figsize=(8, 8), squeeze=False)

for i in range(d):
    for j in range(d):
        ax = axes[i, j]
        if i == j:
            # Diagonal: 1-D marginal posterior of parameter i
            ax.set_facecolor("white")
            ax.hist(out_samples[:, i], bins=40, histtype="step", color="lightblue")
            ax.axvline(truths[i], color="red")
            ax.set_xlabel(labels[i])
        elif i < j:
            # Upper triangle: joint posterior, parameter j on x, parameter i on y
            ax.set_facecolor("midnightblue")
            ax.hist2d(out_samples[:, j], out_samples[:, i], bins=50, cmap="viridis")
            ax.plot(truths[j], truths[i], "o", color="red")
            ax.set_xlabel(labels[j])
            ax.set_ylabel(labels[i])
        else:
            # Lower triangle would mirror the upper one
            ax.axis("off")

plt.tight_layout()
plt.show()