In [None]:
import numpy as np
import torch
import einops
import math
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import numba as nb
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.mamba import Mamba, MambaConfig 
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi import analysis, utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
# import required modules
from sbi.utils.get_nn_models import posterior_nn
# Seed every RNG the notebook may draw from: torch (posterior sampling) and
# numpy (any numpy-based randomness, e.g. inside the simulator), so results are
# reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
from src.temporal_encoders import ResidualTemporalBlock, Residual, PreNorm, LinearAttention, Downsample1d, Conv1dBlock

In [None]:
def Brownian_Motion_simulator(params=None):
    """Placeholder for the simulator that generated the posterior's training data.

    The cells below call this as ``Brownian_Motion_simulator(true_params)`` with a
    float32 tensor of shape ``(batch, n_params)``. They expect a ``torch.Tensor``
    observation back, which is passed as ``x=`` to ``posterior.sample``; most call
    sites also move it with ``.to('cuda')``. Replace this body with the Brownian
    motion or Langevin simulator used by the matching inference procedure.

    Args:
        params: Parameter tensor of shape ``(batch, n_params)``.

    Raises:
        NotImplementedError: Always, until a real simulator is plugged in. The old
            stub accepted no arguments and returned ``0``, so every call site failed
            with an unrelated ``TypeError`` / ``AttributeError``.
    """
    raise NotImplementedError(
        "Brownian_Motion_simulator is a placeholder: plug in the BM or Langevin "
        "simulator used to train the loaded posterior."
    )

In [None]:
# Load the trained sbi posterior object.
# SECURITY: torch.load unpickles arbitrary Python objects -- only load posterior
# files you created yourself or otherwise trust.
# weights_only=False is required on PyTorch >= 2.6, whose default
# (weights_only=True) refuses to unpickle full objects such as sbi posteriors.
POSTERIOR_PATH = 'your_path.pkl'  # TODO: point this at the saved posterior
posterior = torch.load(POSTERIOR_PATH, weights_only=False)

In [None]:
# Plot the posterior over log(D) for a handful of true values: simulate one
# observation per true log(D), sample the posterior conditioned on it, and overlay
# the true value as a reference point.

params_array = np.array([-0.5, 0, 0.5, 1, 1.5])  # true log(D) values
posterior_samples = []
true_params_list = []
for log_d in params_array:
    true_params = torch.tensor([log_d], dtype=torch.float32).reshape(1, -1)
    observation = Brownian_Motion_simulator(true_params).to('cuda')
    # .cpu() so pairplot can convert the samples to numpy.
    samples = posterior.sample((100000,), x=observation).cpu()
    posterior_samples.append(samples)
    true_params_list.append(true_params)

# One red reference marker per true parameter.
points_colors = ['r'] * len(true_params_list)

# Legend entries: one per posterior, followed by one for the true-value markers.
labels = [
    f'$log(D_{k+1})$: {params.item()}' for k, params in enumerate(true_params_list)
] + ['$\\theta_0$']

_ = analysis.pairplot(
    posterior_samples,
    limits=[[-1, 2]],
    figsize=(10, 10),
    points=true_params_list,
    points_offdiag={'markersize': 6},
    points_colors=points_colors,
    # `labels` names the parameter dimensions (one entry per dim), as in the
    # Langevin cell; the per-posterior legend text is passed to plt.legend below.
    labels=['$log(D)$'],
)

# Add a title to the x-axis
plt.xlabel('$log(D)$')

# Add a legend
plt.legend(labels,
           title='Parameters',
           title_fontsize='large',
           loc='upper left');

In [None]:
# Sweep the true log(D) over 1000 values in [-1, 2] (the pairplot limits), infer a
# posterior for each simulated observation, and summarise it by mean and std.
params_array = np.linspace(-1, 2, 1000)
posterior_samples = []
for log_d in params_array:
    true_params = torch.tensor([log_d], dtype=torch.float32).reshape(1, -1)
    observation = Brownian_Motion_simulator(true_params).to('cuda')
    samples = posterior.sample((100000,), x=observation, show_progress_bars=False).cpu()
    posterior_samples.append(samples)

# .item() gives plain Python floats rather than 0-d tensors for plotting.
mean_posterior = [draws.mean().item() for draws in posterior_samples]
std_posterior = [draws.std().item() for draws in posterior_samples]

# The swept values are log(D) (same scale as the previous cell), so label as such.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].plot(params_array, mean_posterior)
axs[0].set_xlabel('$log(D)$')
axs[0].set_ylabel('Posterior Mean')
axs[1].plot(params_array, std_posterior)
axs[1].set_xlabel('$log(D)$')
axs[1].set_ylabel('Posterior STD');

In [None]:
# Plot the Langevin posteriors over (log mass, log friction) for two true pairs.
mass_array = np.array([2, -1])
friction_array = np.array([-1, 2])
posterior_samples = []
true_params_list = []
for mass, friction in zip(mass_array, friction_array):
    # Column order is (mass, friction): the same order used by the grid cell and
    # by the pairplot axis labels below.
    true_params = torch.tensor([mass, friction], dtype=torch.float32).reshape(1, -1)
    observation = Brownian_Motion_simulator(true_params)
    # .cpu() is a no-op on CPU tensors and lets pairplot convert GPU samples too.
    samples = posterior.sample((100000,), x=observation).cpu()
    posterior_samples.append(samples)
    true_params_list.append(true_params)

# One red reference marker per true parameter pair.
points_colors = ['r'] * len(true_params_list)

# Legend text. Component 0 is mass and component 1 is friction (gamma); the
# previous labels had these swapped.
labels = [
    f'$m_{k+1}$: {params[0, 0].item()}, $\\gamma_{k+1}$: {params[0, 1].item()}'
    for k, params in enumerate(true_params_list)
] + ['$\\theta_0$']

_ = analysis.pairplot(
    posterior_samples,
    limits=[[-1, 2], [-1, 2]],
    figsize=(10, 10),
    points=true_params_list,
    points_offdiag={'markersize': 6},
    points_colors=points_colors,
    # Log-scale parameters, labelled consistently with the heatmap cell.
    labels=['$\\log(m)$', '$\\log(\\gamma)$'],
)

# Add a legend. A tuple `loc` anchors the legend's lower-left corner at these
# axes-fraction coordinates, i.e. to the left of the current (last) axes.
plt.legend(labels,
           title='Parameters',
           title_fontsize='large',
           loc=(-1, 0.5));

In [None]:
from tqdm.auto import tqdm

# Infer 100 x 100 = 10,000 posteriors on a regular grid of true (mass, friction).
mass_array = np.linspace(-1, 2, 100)
friction_array = np.linspace(-1, 2, 100)

mass_grid, friction_grid = np.meshgrid(mass_array, friction_array)
params_grid = np.stack([mass_grid.flatten(), friction_grid.flatten()], axis=-1)

true_params = torch.tensor(params_grid, dtype=torch.float32)

observations = Brownian_Motion_simulator(true_params).to('cuda')

n_posterior_draws = 100000
# Preallocate the (n_grid, n_draws, n_params) result and fill it row by row.
# The previous approach built a list and then torch.stack-ed it, which held two
# full copies (~8 GB each for 10,000 x 100,000 x 2 float32) at peak.
posterior_samples = torch.empty(
    (len(observations), n_posterior_draws, true_params.shape[1]),
    dtype=torch.float32,
)
for row, observation in enumerate(tqdm(observations)):
    posterior_samples[row] = posterior.sample(
        (n_posterior_draws,), x=observation.unsqueeze(0), show_progress_bars=False
    ).cpu()

In [None]:
# Split each grid point's draw matrix into its mass (column 0) and friction
# (column 1) marginals, then summarise every marginal by its mean and std.
mass_posterior_samples, friction_posterior_samples = [], []
for draws in posterior_samples:
    mass_posterior_samples.append(draws[:, 0])
    friction_posterior_samples.append(draws[:, 1])


def _marginal_moments(marginals):
    """Return ``(means, stds)``: per-marginal ``.mean()`` and ``.std()`` lists."""
    means = [column.mean() for column in marginals]
    stds = [column.std() for column in marginals]
    return means, stds


mean_posterior_mass, std_posterior_mass = _marginal_moments(mass_posterior_samples)
mean_posterior_fric, std_posterior_fric = _marginal_moments(friction_posterior_samples)

In [None]:
# Heatmaps of posterior accuracy (|mean - true|) and precision (std) for both
# parameters over the (log m, log gamma) grid. Re-running meshgrid + flatten
# reproduces the row-major order used to build `params_grid`. Entry k of every
# metric list therefore lines up with (m[k], f[k]), assuming the simulator
# returns observations in the same row order as `true_params`.
m, f = np.meshgrid(mass_array, friction_array)
m = m.flatten()
f = f.flatten()

mean_m = np.abs(np.array(mean_posterior_mass) - m)
std_m = np.array(std_posterior_mass)

mean_f = np.abs(np.array(mean_posterior_fric) - f)
std_f = np.array(std_posterior_fric)

# (axes position, colour values, panel title, colorbar label)
panels = [
    ((0, 0), mean_m, 'Accuracy mass', '|Mean - True| mass'),
    ((0, 1), std_m, 'Precision mass', 'STD mass'),
    ((1, 0), mean_f, 'Accuracy friction', '|Mean - True| friction'),
    ((1, 1), std_f, 'Precision friction', 'STD friction'),
]

fig, axs = plt.subplots(2, 2, figsize=(15, 15))
for position, values, title, cbar_label in panels:
    ax = axs[position]
    points = ax.scatter(m, f, c=values)
    ax.set_xlabel('$\\log(m)$')
    ax.set_ylabel('$\\log(\\gamma)$')
    ax.set_title(title)
    fig.colorbar(points, ax=ax, label=cbar_label)

plt.show()