In [None]:
# Install BayesFlow straight from the repository's default branch.
# NOTE(review): unpinned — the API used below (BasicWorkflow, ContinuousApproximator,
# CouplingFlow, diagnostics.plots.calibration_ecdf) may differ between commits;
# consider pinning a tag/commit and using `%pip` so it targets the kernel's env.
!pip install git+https://github.com/stefanradev93/bayesflow.git

In [None]:
# Model_1:
import numpy as np
import pandas as pd
import tensorflow as tf
from scipy.stats import uniform, entropy, kurtosis, skew
from bayesflow.workflows import BasicWorkflow
from bayesflow.approximators import ContinuousApproximator
from bayesflow.networks import CouplingFlow
from bayesflow.adapters import Adapter
import bayesflow.diagnostics as bf_diag
import bayesflow as bf
from scipy.stats import uniform
from tensorflow import keras
import matplotlib.pyplot as plt

In [None]:
# Colab path to the synthetic eye-tracking CSV on a mounted Google Drive.
file_path = '/content/drive/My Drive/Synthetic_Eye-Tracking_Data.csv'
df = pd.read_csv(filepath_or_buffer=file_path)
df.head()

In [None]:
# Per-fixation duration plus lower-case aliases used by the rest of the notebook.
df['fix_duration'] = df['Fixation_End'].sub(df['Fixation_Start'])
_column_aliases = {
    'word_idx': 'Word_ID',
    'word_length': 'Word_Length',
    'frequency': 'Word_Frequency',
}
for _alias, _source in _column_aliases.items():
    df[_alias] = df[_source]

In [None]:
# One dict per fixation with just the two fields the simulators/summaries consume.
records = df.loc[:, ['word_idx', 'fix_duration']].to_dict('records')

In [None]:
# Model_1:
# ─────────────────────────────────────────────────────────────────────────────
# 1. Define Priors
# Independent uniform priors (scipy's uniform(loc, scale) spans [loc, loc + scale]):
# nu ~ U(0.1, 1.0), r ~ U(5, 20), mu_T ~ U(150, 300).
priors = {
    'nu': uniform(loc=0.1, scale=0.9),
    'r': uniform(loc=5.0, scale=15.0),
    'mu_T': uniform(loc=150.0, scale=150.0)
}

def sample_prior():
    """Draw one parameter set from ``priors``.

    Returns a dict keyed 'nu', 'r', 'mu_T' (drawn in that order).
    """
    return {name: dist.rvs() for name, dist in priors.items()}

# ─────────────────────────────────────────────────────────────────────────────
# 2. SWIFT-inspired Simulation

def compute_saliency(activations, eta=0.01):
    """Softmax of ``activations / eta`` (``eta`` acts as a temperature).

    The maximum is subtracted before exponentiating for numerical stability.
    Returns a probability vector that sums to 1.
    """
    scaled = activations / eta
    exp_vals = np.exp(scaled - np.max(scaled))
    return exp_vals / np.sum(exp_vals)

def simulate(theta, num_words=30, alpha=9, eta=2.0, max_steps=100):
    """SWIFT-inspired reading simulation.

    At each step the activation of word ``i`` is a Gaussian bump centred on the
    current position, ``r * exp(-(i - position)**2 / (2 * nu**2))``. The next
    fixation target is drawn from the softmax of those activations (temperature
    ``eta``). Its duration is drawn from Gamma(shape=alpha, scale=mu_T/alpha),
    whose mean is ``mu_T``.

    Simulation stops after ``max_steps`` fixations, or as soon as every word
    has been fixated at least once.

    Parameters
    ----------
    theta : dict
        Keys 'nu', 'r', 'mu_T'.
    num_words, alpha, eta, max_steps : model constants (see above).

    Returns
    -------
    list of dict
        One dict per fixation, with 'word_idx' (1-based) and 'fix_duration'.
    """
    nu, r, mu_T = theta['nu'], theta['r'], theta['mu_T']

    # Loop-invariant quantities, hoisted out of the step loop.
    word_positions = np.arange(num_words)
    two_nu_sq = 2 * nu**2
    gamma_scale = mu_T / alpha  # same as 1 / (alpha / mu_T)

    position = 0
    fixations = []
    visited = set()

    for _ in range(max_steps):
        # Vectorised form of the per-word activation loop; activations are
        # recomputed from scratch each step (not accumulated).
        activations = r * np.exp(-((word_positions - position) ** 2) / two_nu_sq)

        # compute_saliency already returns normalised probabilities.
        saliency_probs = compute_saliency(activations, eta)
        next_position = np.random.choice(num_words, p=saliency_probs)

        duration = np.random.gamma(shape=alpha, scale=gamma_scale)

        fixations.append({'word_idx': next_position + 1, 'fix_duration': duration})
        visited.add(next_position)

        if len(visited) == num_words:
            break

        position = next_position

    return fixations

# ─────────────────────────────────────────────────────────────────────────────
# 3. Summary Statistics Extractor

def extract_summary_stats(fixations):
    """Hand-crafted 8-dimensional summary of a fixation sequence.

    Features, in order (float32):
      0. mean fixation duration
      1. std of fixation durations (population std)
      2. fraction of forward skips (saccade > 1 word)
      3. fraction of regressions (saccade < 0)
      4. mean absolute saccade amplitude
      5. skewness of durations
      6. excess kurtosis of durations
      7. entropy of add-one-smoothed word-visit counts

    Degenerate inputs map to finite values instead of NaN, so they cannot
    poison network training:
      - with fewer than two fixations, the saccade features are 0;
      - when durations have no spread, skew and kurtosis are 0.

    Parameters
    ----------
    fixations : sequence of dict
        Each dict has 'word_idx' (non-negative integer-valued) and 'fix_duration'.

    Raises
    ------
    ValueError
        If ``fixations`` is empty.
    """
    if len(fixations) == 0:
        raise ValueError("extract_summary_stats needs at least one fixation.")

    words = np.array([f['word_idx'] for f in fixations])
    durs = np.array([f['fix_duration'] for f in fixations], dtype=np.float64)
    saccades = np.diff(words)

    if saccades.size:
        forward_skip_rate = np.mean(saccades > 1)
        regression_rate = np.mean(saccades < 0)
        mean_amplitude = np.mean(np.abs(saccades))
    else:
        # np.mean of an empty array would be NaN (with a RuntimeWarning).
        forward_skip_rate = regression_rate = mean_amplitude = 0.0

    if durs.size > 1 and np.ptp(durs) > 0:
        dur_skew, dur_kurt = skew(durs), kurtosis(durs)
    else:
        # Higher moments are undefined (NaN) for constant or single-sample input.
        dur_skew = dur_kurt = 0.0

    # bincount needs integer indices; word ids may arrive as floats from pandas.
    visit_counts = np.bincount(words.astype(np.int64)) + 1

    return np.array([
        durs.mean(),
        durs.std(),
        forward_skip_rate,
        regression_rate,
        mean_amplitude,
        dur_skew,
        dur_kurt,
        entropy(visit_counts)
    ], dtype=np.float32)

# ─────────────────────────────────────────────────────────────────────────────
# 4. Generate Training Dataset

def simulate_and_extract():
    """Draw theta from the prior, simulate one reading trial, and summarise it.

    Returns
    -------
    (np.ndarray, np.ndarray)
        float32 vector [nu, r, mu_T] and the 8-dimensional summary vector.
    """
    params = sample_prior()
    summary = extract_summary_stats(simulate(params))
    theta_vec = np.array([params[key] for key in ('nu', 'r', 'mu_T')], dtype=np.float32)
    return theta_vec, summary

# Simulate the training set plus a *disjoint* held-out validation set.
# Validation rows must not also be training rows; otherwise the validation loss
# and the ECDF calibration check run on data the network has already seen.
N = 10000      # training simulations
N_VAL = 1000   # additional held-out validation simulations

theta_list, stats_list = [], []

for _ in range(N + N_VAL):
    theta, stats = simulate_and_extract()
    theta_list.append(theta)
    stats_list.append(stats)

theta_all = np.array(theta_list)
x_all = np.array(stats_list)

theta_train, x_train = theta_all[:N], x_all[:N]
theta_val, x_val = theta_all[N:], x_all[N:]

# ─────────────────────────────────────────────────────────────────────────────
# 5. Inference Network

# Normalizing-flow posterior estimator over the 3 parameters (nu, r, mu_T).
# NOTE(review): verify that the installed (unpinned) BayesFlow CouplingFlow accepts `num_params`.
inference_network = CouplingFlow(num_params=3)

# ─────────────────────────────────────────────────────────────────────────────
# 6. BayesFlow Setup

class IdentityAdapter:
    """Pass-through adapter returning ``(parameters, summary_conditions)`` unchanged."""

    def adapt(self, sim_out):
        params = sim_out['parameters']
        conditions = sim_out['summary_conditions']
        return params, conditions

# Combine adapter and flow into an approximator, then expose it through a
# workflow keyed on "theta" (inference targets) and "x" (conditions).
approximator = ContinuousApproximator(
    inference_network=inference_network,
    adapter=IdentityAdapter(),
)

workflow = BasicWorkflow(
    approximator=approximator,
    inference_conditions=["x"],
    inference_variables=["theta"],
)

# ─────────────────────────────────────────────────────────────────────────────
# 7. Training

# Offline training on pre-simulated arrays; the keys match the workflow's
# inference_variables ("theta") and inference_conditions ("x").
train_data = {"x": x_train, "theta": theta_train}
val_data = {"x": x_val, "theta": theta_val}
workflow.fit_offline(
    data=train_data,
    validation_data=val_data,
    epochs=50,
    batch_size=64,
    verbose=1,
)

# ─────────────────────────────────────────────────────────────────────────────
# 8. Real Observed Fixation Data

# Summarise the observed fixations with the same extractor used for training,
# as a single-row batch of shape (1, n_features).
real_fixations = records
obs_summary = extract_summary_stats(real_fixations)[np.newaxis, :]

# ─────────────────────────────────────────────────────────────────────────────
# 9. Posterior Sampling and Prediction

# Posterior for the observed data (a single dataset).
posterior_samples = workflow.sample(
    conditions={"x": obs_summary},
    num_samples=1000
)

# The leading axis has size 1 (one observed dataset); drop it -> (num_samples, 3).
samples_array = posterior_samples["theta"].squeeze(0)
posterior_mean = samples_array.mean(axis=0)
posterior_std = samples_array.std(axis=0)

print("\nPosterior Mean Estimates:")
print(f"ν (nu): {posterior_mean[0]:.4f} ± {posterior_std[0]:.4f}")
print(f"r:      {posterior_mean[1]:.4f} ± {posterior_std[1]:.4f}")
print(f"μ_T:    {posterior_mean[2]:.4f} ± {posterior_std[2]:.4f}")


In [None]:
# Model_1:
def parameter_recovery(workflow, num_cases=200):
    """Simulation-based parameter recovery check.

    For ``num_cases`` fresh prior draws, simulate data, take the mean of 500
    posterior draws, and print each parameter's Pearson correlation and RMSE
    between the true and estimated values.
    """
    pairs = []
    for _ in range(num_cases):
        theta_true, x = simulate_and_extract()
        draws = workflow.sample(conditions={"x": x.reshape(1, -1)}, num_samples=500)["theta"]
        pairs.append((theta_true, draws.squeeze(0).mean(axis=0)))

    truth = np.array([t for t, _ in pairs])
    estimate = np.array([e for _, e in pairs])

    for col, name in enumerate(['nu', 'r', 'mu_T']):
        t_col, e_col = truth[:, col], estimate[:, col]
        corr = np.corrcoef(t_col, e_col)[0, 1]
        rmse = np.sqrt(np.mean((t_col - e_col)**2))
        print(f"{name}: corr = {corr:.3f}, RMSE = {rmse:.3f}")

parameter_recovery(workflow, num_cases=200)


In [None]:
# Model_1:

def posterior_predictive_overlay(workflow, obs_summary, real_fixations, num_samples=100, show_individual_curves=True):
    """Overlay posterior-predictive fixation-duration sequences on the observed one.

    Draws ``num_samples`` parameter sets from the posterior given
    ``obs_summary`` and simulates one sequence per draw. Each simulated
    sequence is aligned to the observed length by truncating it or
    right-padding it with NaN. Plots the individual simulated curves
    (optional), their NaN-aware mean, and the observed durations.

    Parameters
    ----------
    workflow : trained workflow exposing ``sample(conditions=..., num_samples=...)``.
    obs_summary : conditioning input for the observed data, shape (1, n_features).
    real_fixations : list of dict with 'fix_duration'.
    num_samples : number of posterior draws / simulated sequences.
    show_individual_curves : whether to draw every simulated curve.
    """
    import warnings

    posterior = workflow.sample(conditions={"x": obs_summary}, num_samples=num_samples)["theta"].squeeze(0)
    if len(posterior) == 0:
        print("No posterior samples drawn; nothing to plot.")
        return

    observed = [f['fix_duration'] for f in real_fixations]
    target_len = len(observed)

    # Rows = posterior draws, columns = fixation index; NaN marks "no fixation".
    simulated_curves = np.full((len(posterior), target_len), np.nan)
    for row, theta_sample in enumerate(posterior):
        fix = simulate({'nu': theta_sample[0], 'r': theta_sample[1], 'mu_T': theta_sample[2]})
        fix_durs = [f['fix_duration'] for f in fix][:target_len]
        simulated_curves[row, :len(fix_durs)] = fix_durs

    # Columns beyond every simulated sequence are all-NaN: nanmean returns NaN
    # there (matplotlib leaves a gap) but would emit "Mean of empty slice".
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_curve = np.nanmean(simulated_curves, axis=0)

    x = np.arange(1, target_len + 1)

    if show_individual_curves:
        for curve in simulated_curves:
            plt.plot(x, curve, color='blue', alpha=0.1)
        # Empty proxy line so the legend entry is drawn at full opacity.
        plt.plot([], [], color='blue', label='Posterior predictive samples')
    plt.plot(x, mean_curve, color='orange', linestyle='--', label='Posterior predictive mean')
    plt.plot(x, observed, color='black', label='Observed', linewidth=2)

    plt.xlabel("Fixation Index")
    plt.ylabel("Fixation Duration (ms)")
    plt.title("Posterior Predictive Overlay")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

posterior_predictive_overlay(workflow, obs_summary, real_fixations)

In [None]:
# Model_1:
# ─────────────────────────────────────────────────────────────────────────────
# 7b. ECDF Calibration Plot using Existing Validation Data

import bayesflow.diagnostics as bf_diag
import matplotlib.pyplot as plt

# Run inference on all validation summaries in one call.
posterior_val = workflow.sample(
    conditions={"x": x_val},
    num_samples=1000
)["theta"]

# The single-dataset cells call `.squeeze(0)` on this same output, so the
# leading axis indexes datasets: the array is already
# (num_val, num_samples, num_params), the layout calibration_ecdf expects.
# Do NOT transpose: with num_val == num_samples == 1000 a transpose would
# silently swap datasets and draws.
assert posterior_val.shape[0] == len(x_val)
samples = {"parameters": posterior_val}
validation_data = {"parameters": theta_val}

# Parameter names for labeling
param_names = ['nu', 'r', 'mu_T']

# Plot ECDF with 95% confidence band
bf_diag.plots.calibration_ecdf(
    samples,
    validation_data,
    variable_names=param_names,
    difference=True,
    alpha=0.05
)
plt.show()


In [None]:
# Model_1:

def parameter_recovery_plot(workflow, num_cases=200):
    """Scatter true vs. posterior-mean parameters for fresh simulations.

    For each of ``num_cases`` prior draws: simulate data and take the mean of
    500 posterior draws. Then, per parameter, draw a scatter with the identity
    line, annotate RMSE and Pearson correlation, and print the same numbers.
    """
    truth_rows, estimate_rows = [], []
    for _ in range(num_cases):
        theta_true, x = simulate_and_extract()
        draws = workflow.sample(conditions={"x": x.reshape(1, -1)}, num_samples=500)["theta"].squeeze(0)
        truth_rows.append(theta_true)
        estimate_rows.append(draws.mean(axis=0))

    truth = np.array(truth_rows)
    estimates = np.array(estimate_rows)

    for idx, label in enumerate(['nu', 'r', 'mu_T']):
        t, e = truth[:, idx], estimates[:, idx]
        lo, hi = t.min(), t.max()
        residuals = e - t
        rmse = np.sqrt(np.mean(residuals ** 2))
        corr = np.corrcoef(t, e)[0, 1]

        fig, ax = plt.subplots(figsize=(5, 4))
        ax.scatter(t, e, alpha=0.6)
        ax.plot([lo, hi], [lo, hi], 'r--', label='Perfect Recovery')
        ax.set_xlabel(f"True {label}")
        ax.set_ylabel(f"Estimated {label}")
        ax.set_title(f"Parameter Recovery: {label}")
        ax.text(0.05, 0.95, f"RMSE = {rmse:.4f}\nCorr = {corr:.4f}",
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()

        print(f"{label}: RMSE = {rmse:.4f}, Correlation = {corr:.4f}")

parameter_recovery_plot(workflow)

In [None]:
# Model_2:
# ─────────────────────────────────────────────────────────────────────────────
# 1. Define Priors
# Independent uniform priors (scipy's uniform(loc, scale) spans [loc, loc + scale]):
# nu ~ U(0.1, 1.0), r ~ U(5, 20), mu_T ~ U(150, 300).
priors = {
    'nu': uniform(loc=0.1, scale=0.9),
    'r': uniform(loc=5.0, scale=15.0),
    'mu_T': uniform(loc=150.0, scale=150.0)
}

def sample_prior():
    """Draw one parameter set from ``priors``.

    Returns a dict keyed 'nu', 'r', 'mu_T' (drawn in that order).
    """
    return {name: dist.rvs() for name, dist in priors.items()}

# ─────────────────────────────────────────────────────────────────────────────
# 2. SWIFT-inspired Simulation

def compute_saliency(activations, eta=0.01):
    """Softmax of ``activations / eta`` (``eta`` acts as a temperature).

    The maximum is subtracted before exponentiating for numerical stability.
    Returns a probability vector that sums to 1.
    """
    scaled = activations / eta
    exp_vals = np.exp(scaled - np.max(scaled))
    return exp_vals / np.sum(exp_vals)

def simulate(theta, num_words=30, alpha=9, eta=2.0, max_steps=100):
    """SWIFT-inspired reading simulation.

    At each step the activation of word ``i`` is a Gaussian bump centred on the
    current position, ``r * exp(-(i - position)**2 / (2 * nu**2))``. The next
    fixation target is drawn from the softmax of those activations (temperature
    ``eta``). Its duration is drawn from Gamma(shape=alpha, scale=mu_T/alpha),
    whose mean is ``mu_T``.

    Simulation stops after ``max_steps`` fixations, or as soon as every word
    has been fixated at least once.

    Parameters
    ----------
    theta : dict
        Keys 'nu', 'r', 'mu_T'.
    num_words, alpha, eta, max_steps : model constants (see above).

    Returns
    -------
    list of dict
        One dict per fixation, with 'word_idx' (1-based) and 'fix_duration'.
    """
    nu, r, mu_T = theta['nu'], theta['r'], theta['mu_T']

    # Loop-invariant quantities, hoisted out of the step loop.
    word_positions = np.arange(num_words)
    two_nu_sq = 2 * nu**2
    gamma_scale = mu_T / alpha  # same as 1 / (alpha / mu_T)

    position = 0
    fixations = []
    visited = set()

    for _ in range(max_steps):
        # Vectorised form of the per-word activation loop; activations are
        # recomputed from scratch each step (not accumulated).
        activations = r * np.exp(-((word_positions - position) ** 2) / two_nu_sq)

        # compute_saliency already returns normalised probabilities.
        saliency_probs = compute_saliency(activations, eta)
        next_position = np.random.choice(num_words, p=saliency_probs)

        duration = np.random.gamma(shape=alpha, scale=gamma_scale)

        fixations.append({'word_idx': next_position + 1, 'fix_duration': duration})
        visited.add(next_position)

        if len(visited) == num_words:
            break

        position = next_position

    return fixations

# ─────────────────────────────────────────────────────────────────────────────
# 3. Learned Summary Network

class SummaryNetwork(bf.networks.SummaryNetwork):
    """Fully connected summary network mapping a flat input vector to 50 features.

    Four ReLU ``Dense`` layers of widths 400, 200, 100 and 50, applied in sequence.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        widths = (400, 200, 100, 50)
        self.network = keras.Sequential(
            [keras.layers.Dense(width, activation="relu") for width in widths]
        )

    def call(self, x, **kwargs):
        # Run the layers in training mode only when the caller passes stage="training".
        is_training = kwargs.get("stage") == "training"
        return self.network(x, training=is_training)

# ─────────────────────────────────────────────────────────────────────────────
# 4. Generate Training Dataset

MAX_FIXATIONS = 100
FEATURES_PER_FIXATION = 2  # (word_idx, fix_duration)
PADDED_LENGTH = MAX_FIXATIONS * FEATURES_PER_FIXATION

def pad_fixations(arr):
    """Bring a flat fixation vector to exactly ``PADDED_LENGTH`` entries.

    Shorter inputs are copied into a zero-filled float32 array. Inputs that are
    already at least that long are sliced (not copied) and keep their dtype.
    """
    n = len(arr)
    if n < PADDED_LENGTH:
        out = np.zeros(PADDED_LENGTH, dtype=np.float32)
        out[:n] = arr
        return out
    return arr[:PADDED_LENGTH]

def simulate_and_extract():
    """Draw theta from the prior, simulate, and return the raw padded sequence.

    Fixations are interleaved as [word_idx_0, fix_duration_0, word_idx_1, ...]
    and then padded/truncated with ``pad_fixations``.

    Returns
    -------
    (np.ndarray, np.ndarray)
        float32 [nu, r, mu_T] and the padded float32 fixation vector.
    """
    theta = sample_prior()
    fixations = simulate(theta)
    flat_fix = [value for f in fixations for value in (f['word_idx'], f['fix_duration'])]
    theta_vec = np.array([theta['nu'], theta['r'], theta['mu_T']], dtype=np.float32)
    return theta_vec, pad_fixations(np.array(flat_fix, dtype=np.float32))

# Simulate the training set plus a *disjoint* held-out validation set.
# Validation rows must not also be training rows; otherwise the validation loss
# and the ECDF calibration check run on data the network has already seen.
N = 10000      # training simulations
N_VAL = 1000   # additional held-out validation simulations

theta_list, fix_list = [], []

for _ in range(N + N_VAL):
    theta, fix = simulate_and_extract()
    theta_list.append(theta)
    fix_list.append(fix)

theta_all = np.array(theta_list)
x_all = np.array(fix_list)

theta_train, x_train = theta_all[:N], x_all[:N]
theta_val, x_val = theta_all[N:], x_all[N:]

# ─────────────────────────────────────────────────────────────────────────────
# 5. Adapter

class FixationAdapter:
    """Pass-through adapter returning ``(parameters, sim_data)`` unchanged."""

    def adapt(self, sim_out):
        params = sim_out["parameters"]
        data = sim_out["sim_data"]
        return params, data

# ─────────────────────────────────────────────────────────────────────────────
# 6. Setup BayesFlow Workflow

# Learned summary network + normalizing flow, combined in one approximator and
# exposed through a workflow keyed on "theta" (targets) and "x" (conditions).
summary_network = SummaryNetwork()
inference_network = CouplingFlow(num_params=3)

approximator = ContinuousApproximator(
    inference_network=inference_network,
    summary_network=summary_network,
    adapter=FixationAdapter(),
)

workflow = BasicWorkflow(
    approximator=approximator,
    inference_conditions=["x"],
    inference_variables=["theta"],
)

# ─────────────────────────────────────────────────────────────────────────────
# 7. Training

# Offline training on the padded fixation arrays; keys match the workflow's
# inference_variables ("theta") and inference_conditions ("x").
train_data = {"x": x_train, "theta": theta_train}
val_data = {"x": x_val, "theta": theta_val}
workflow.fit_offline(
    data=train_data,
    validation_data=val_data,
    epochs=50,
    batch_size=64,
    verbose=1,
)

# ─────────────────────────────────────────────────────────────────────────────
# 8. Real Observed Fixation Data

# Flatten the observed fixations exactly like the simulator output, then pad to
# the network's fixed input length as a single-row batch.
real_fixations = records

flat_obs = np.array(
    [value for f in real_fixations for value in (f['word_idx'], f['fix_duration'])],
    dtype=np.float32,
)
flat_obs = pad_fixations(flat_obs).reshape(1, -1)

# ─────────────────────────────────────────────────────────────────────────────
# 9. Posterior Sampling and Prediction

# Posterior for the single observed (padded) sequence.
posterior_samples = workflow.sample(num_samples=1000, conditions={"x": flat_obs})

# Leading axis has size 1 (one dataset); drop it -> (num_samples, 3).
samples_array = posterior_samples["theta"].squeeze(0)
posterior_mean, posterior_std = samples_array.mean(axis=0), samples_array.std(axis=0)

print("\nPosterior Mean Estimates:")
print(f"ν (nu): {posterior_mean[0]:.4f} ± {posterior_std[0]:.4f}")
print(f"r:      {posterior_mean[1]:.4f} ± {posterior_std[1]:.4f}")
print(f"μ_T:    {posterior_mean[2]:.4f} ± {posterior_std[2]:.4f}")


In [None]:
# Model_2:
# ─────────────────────────────────────────────────────────────────────────────
# 10. Posterior Predictive Check Plot


def posterior_predictive_overlay(workflow, obs_fixations, real_fixations, num_samples=100, show_individual_curves=True):
    """Overlay posterior-predictive fixation-duration sequences on the observed one.

    Parameters
    ----------
    workflow : trained workflow exposing ``sample(conditions=..., num_samples=...)``.
    obs_fixations : list of dict or np.ndarray
        - A list of fixation dicts ('word_idx', 'fix_duration') is flattened
          and padded with ``pad_fixations``.
        - A 1-D ndarray is padded/truncated with ``pad_fixations``.
        - Any other ndarray is used as is.
        The result is reshaped to a single-row batch.
    real_fixations : list of dict with 'fix_duration' (the observed curve).
    num_samples : number of posterior draws / simulated sequences.
    show_individual_curves : whether to draw every simulated curve.

    Raises
    ------
    ValueError
        If ``obs_fixations`` is empty or of an unsupported type.
    """
    import warnings

    if isinstance(obs_fixations, list):
        if not obs_fixations:
            raise ValueError("obs_fixations is empty.")
        if not isinstance(obs_fixations[0], dict):
            raise ValueError("Unsupported format for obs_fixations.")
        flat_obs = [value for f in obs_fixations for value in (f['word_idx'], f['fix_duration'])]
        obs_summary = pad_fixations(np.array(flat_obs, dtype=np.float32)).reshape(1, -1)
    elif isinstance(obs_fixations, np.ndarray):
        obs_summary = obs_fixations
        if obs_summary.ndim == 1:
            # An unpadded flat vector would not match the summary network's input size.
            obs_summary = pad_fixations(obs_summary.astype(np.float32))
        obs_summary = obs_summary.reshape(1, -1)
    else:
        raise ValueError("Unsupported format for obs_fixations.")

    posterior = workflow.sample(conditions={"x": obs_summary}, num_samples=num_samples)["theta"].squeeze(0)
    if len(posterior) == 0:
        print("No posterior samples drawn; nothing to plot.")
        return

    observed = [f['fix_duration'] for f in real_fixations]
    target_len = len(observed)

    # Rows = posterior draws, columns = fixation index; NaN marks "no fixation".
    simulated_curves = np.full((len(posterior), target_len), np.nan)
    for row, theta_sample in enumerate(posterior):
        fix = simulate({'nu': theta_sample[0], 'r': theta_sample[1], 'mu_T': theta_sample[2]})
        fix_durs = [f['fix_duration'] for f in fix][:target_len]
        simulated_curves[row, :len(fix_durs)] = fix_durs

    # Columns beyond every simulated sequence are all-NaN: nanmean returns NaN
    # there (matplotlib leaves a gap) but would emit "Mean of empty slice".
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_curve = np.nanmean(simulated_curves, axis=0)

    x = np.arange(1, target_len + 1)

    if show_individual_curves:
        for curve in simulated_curves:
            plt.plot(x, curve, color='blue', alpha=0.1)
        # Empty proxy line so the legend entry is drawn at full opacity.
        plt.plot([], [], color='blue', label='Posterior predictive samples')

    plt.plot(x, mean_curve, color='orange', linestyle='--', label='Posterior predictive mean')
    plt.plot(x, observed, color='black', label='Observed', linewidth=2)

    plt.xlabel("Fixation Index")
    plt.ylabel("Fixation Duration (ms)")
    plt.title("Posterior Predictive Overlay")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Run the overlay using real fixation data (automatic padding)
posterior_predictive_overlay(workflow, real_fixations, real_fixations)


In [None]:
# Model_2:
# ─────────────────────────────────────────────────────────────────────────────

import bayesflow.diagnostics as bf_diag
import matplotlib.pyplot as plt

# Run inference on the whole validation set in one call.
posterior_val = workflow.sample(
    conditions={"x": x_val},
    num_samples=1000
)["theta"]

# The single-dataset cells call `.squeeze(0)` on this same output, so the
# leading axis indexes datasets: the array is already
# (num_val, num_samples, num_parameters). Do NOT transpose: with
# num_val == num_samples == 1000 a transpose would silently swap datasets and draws.
assert posterior_val.shape[0] == len(x_val)
estimates = posterior_val
targets = theta_val

# Defined here instead of relying on a global left over from an earlier cell.
param_names = ['nu', 'r', 'mu_T']

# Plot ECDF
bf_diag.plots.calibration_ecdf(
    estimates=estimates,
    targets=targets,
    variable_names=param_names,
    difference=True,
    alpha=0.05
)
plt.tight_layout()
plt.show()


In [None]:
# Model_2:
from tqdm import tqdm  # progress bar

def parameter_recovery_plot(workflow, num_cases=200):
    """Scatter true vs. posterior-mean parameters for fresh simulations.

    For each of ``num_cases`` prior draws (with a tqdm progress bar): simulate a
    padded fixation sequence and take the mean of 500 posterior draws. Then,
    per parameter, draw a scatter with the identity line, annotate RMSE and
    Pearson correlation, and print the same numbers.
    """
    truth_rows, estimate_rows = [], []
    for _ in tqdm(range(num_cases), desc="Simulating + inferring"):
        theta_true, x = simulate_and_extract()
        # Single padded sequence as a one-row batch: shape (1, PADDED_LENGTH).
        draws = workflow.sample(
            conditions={"x": x.reshape(1, -1)},
            num_samples=500
        )["theta"].squeeze(0)
        truth_rows.append(theta_true)
        estimate_rows.append(draws.mean(axis=0))

    truth = np.array(truth_rows)
    estimates = np.array(estimate_rows)

    for idx, label in enumerate(['nu', 'r', 'mu_T']):
        t, e = truth[:, idx], estimates[:, idx]
        lo, hi = t.min(), t.max()
        residuals = e - t
        rmse = np.sqrt(np.mean(residuals ** 2))
        corr = np.corrcoef(t, e)[0, 1]

        fig, ax = plt.subplots(figsize=(5, 4))
        ax.scatter(t, e, alpha=0.6)
        ax.plot([lo, hi], [lo, hi], 'r--', label='Perfect Recovery')
        ax.set_xlabel(f"True {label}")
        ax.set_ylabel(f"Estimated {label}")
        ax.set_title(f"Parameter Recovery: {label}")
        ax.text(0.05, 0.95, f"RMSE = {rmse:.4f}\nCorr = {corr:.4f}",
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()

        print(f"{label}: RMSE = {rmse:.4f}, Correlation = {corr:.4f}")

# Run the parameter recovery
parameter_recovery_plot(workflow)

In [None]:
# Comparison_Model:
# ─────────────────────────────────────────────────────────────────────────────
# 0. Setup
# Seed NumPy's global RNG (the simulators below call np.random.* directly)
# and TensorFlow's RNG.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# ─────────────────────────────────────────────────────────────────────────────
# 1. Priors

# Independent uniform priors (scipy's uniform(loc, scale) spans [loc, loc + scale]):
# nu ~ U(0.1, 1.0), r ~ U(5, 20), mu_T ~ U(150, 300).
priors = {
    'nu': uniform(loc=0.1, scale=0.9),
    'r': uniform(loc=5.0, scale=15.0),
    'mu_T': uniform(loc=150.0, scale=150.0)
}

def sample_prior():
    """Draw one parameter set from ``priors``.

    Returns a dict keyed 'nu', 'r', 'mu_T' (drawn in that order).
    """
    return {name: dist.rvs() for name, dist in priors.items()}

# ─────────────────────────────────────────────────────────────────────────────
# 2. Simulation Variants

def simulate(theta, num_words=30, alpha=9, eta=1e-3, max_steps=100):
    """Comparison-model reading simulator with accumulating activations.

    Each step adds a Gaussian bump ``r * exp(-(i - position)**2 / (2 * nu**2))``,
    centred on the current position, to every word's activation. Activations
    are never reset. Saliency is ``sin(pi * a / max(a))**2``, forced to 0 for
    the most-activated word(s), plus a floor ``eta``. The next position is
    drawn in proportion to saliency. Durations follow
    Gamma(alpha, scale=mu_T/alpha), whose mean is ``mu_T``.

    Always returns exactly ``max_steps`` fixations.

    Parameters
    ----------
    theta : dict
        Keys 'nu', 'r', 'mu_T'.

    Returns
    -------
    list of dict
        One dict per fixation, with 'word' (1-based) and 'duration'.
    """
    nu, r, mu_T = theta['nu'], theta['r'], theta['mu_T']

    # Loop-invariant quantities, hoisted out of the step loop.
    word_positions = np.arange(num_words)
    two_nu_sq = 2 * nu**2
    gamma_scale = mu_T / alpha

    position = 0
    activations = np.zeros(num_words)
    fixations = []
    for _ in range(max_steps):
        # Vectorised form of the per-word accumulation loop.
        activations += r * np.exp(-((word_positions - position) ** 2) / two_nu_sq)
        peak = activations.max()
        norm = activations / peak if peak > 0 else activations
        sal = np.sin(np.pi * norm)**2
        # sin(pi)**2 is ~1e-32 rather than exactly 0, so zero the peak explicitly.
        sal = np.where(norm < 1, sal, 0) + eta
        probs = sal / sal.sum()
        position = np.random.choice(num_words, p=probs)
        duration = np.random.gamma(shape=alpha, scale=gamma_scale)
        fixations.append({'word': position + 1, 'duration': duration})
    return fixations

def simulate_no_saliency(theta, num_words=10, alpha=9, max_steps=50):
    """Baseline simulator: uniformly random word targets and Gamma durations.

    Each of ``max_steps`` fixations picks a word uniformly from 1..num_words.
    Its duration is drawn from Gamma(alpha, scale=mu_T/alpha). Only
    ``theta['mu_T']`` is used; 'nu' and 'r' have no effect.

    Note: the defaults (10 words, 50 steps) differ from the comparison
    ``simulate`` defaults (30 words, 100 steps).

    Returns a list of dicts with 'word' (1-based) and 'duration'.
    """
    mu_T = theta['mu_T']
    fixations = []
    for _ in range(max_steps):
        word = np.random.randint(1, num_words + 1)
        duration = np.random.gamma(alpha, mu_T / alpha)
        fixations.append({'word': word, 'duration': duration})
    return fixations

def simulate_random(theta, num_words=10, alpha=9, max_steps=50):
    """Delegate to ``simulate_no_saliency`` with the same arguments.

    NOTE(review): this makes the "Random" and "No Saliency" comparison models
    identical. Confirm whether a distinct baseline was intended.
    """
    return simulate_no_saliency(theta, num_words=num_words, alpha=alpha, max_steps=max_steps)

# ─────────────────────────────────────────────────────────────────────────────
# 3. Data Processing

MAX_FIXATIONS = 100
FEATURES_PER_FIX = 2  # (word, duration)
PADDED_LEN = MAX_FIXATIONS * FEATURES_PER_FIX

def flatten_fixations(fix):
    """Interleave fixations into float32 [word_0, duration_0, word_1, duration_1, ...]."""
    return np.array([value for f in fix for value in (f['word'], f['duration'])], dtype=np.float32)

def pad_fixations(fix_array):
    """Bring a flat fixation vector to exactly ``PADDED_LEN`` entries.

    Shorter inputs are copied into a zero-filled float32 array. Inputs that are
    already at least that long are sliced (not copied) and keep their dtype.
    """
    n = len(fix_array)
    if n < PADDED_LEN:
        out = np.zeros(PADDED_LEN, dtype=np.float32)
        out[:n] = fix_array
        return out
    return fix_array[:PADDED_LEN]

def simulate_and_prepare(sim_func):
    """Draw theta from the prior, simulate with ``sim_func``, and prepare arrays.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Padded flat fixation vector first, then float32 [nu, r, mu_T].
    """
    params = sample_prior()
    x = pad_fixations(flatten_fixations(sim_func(params)))
    y = np.array([params[name] for name in ('nu', 'r', 'mu_T')], dtype=np.float32)
    return x, y

# ─────────────────────────────────────────────────────────────────────────────
# 4. Summary Network

class SummaryNetwork(bf.networks.SummaryNetwork):
    """Fully connected summary network with an explicit (PADDED_LEN,) input.

    Four ReLU ``Dense`` layers of widths 400, 200, 100 and 50 follow the input layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        layers = [keras.layers.Input(shape=(PADDED_LEN,))]
        layers += [keras.layers.Dense(units, activation="relu") for units in (400, 200, 100, 50)]
        self.network = keras.Sequential(layers)

    def call(self, x, **kwargs):
        # Training mode only when the caller passes stage="training".
        training = kwargs.get("stage") == "training"
        return self.network(x, training=training)

# ─────────────────────────────────────────────────────────────────────────────
# 5. Adapter

class IdentityAdapter:
    """Pass-through adapter returning ``(parameters, sim_data)`` unchanged."""

    def adapt(self, sim_out):
        params = sim_out["parameters"]
        data = sim_out["sim_data"]
        return params, data

# ─────────────────────────────────────────────────────────────────────────────
# 6. Training Helper

def train_bayesflow_model(sim_func, epochs=50, N=10000, n_val=1000):
    """Simulate data with ``sim_func`` and train a fresh BayesFlow workflow on it.

    The validation set is simulated separately from the training set, so the
    validation loss is measured on unseen data.

    Parameters
    ----------
    sim_func : callable
        Takes a theta dict and returns fixation dicts with 'word' and 'duration'.
    epochs : int
        Training epochs.
    N : int
        Number of training simulations.
    n_val : int
        Number of additional held-out validation simulations (should be positive).

    Returns
    -------
    The trained workflow.
    """
    def _simulate_arrays(n):
        # Simulate n (x, theta) pairs and stack them into arrays.
        xs, thetas = [], []
        for _ in range(n):
            x, y = simulate_and_prepare(sim_func)
            xs.append(x)
            thetas.append(y)
        return np.array(xs), np.array(thetas)

    x_train, theta_train = _simulate_arrays(N)
    x_val, theta_val = _simulate_arrays(n_val)

    summary_net = SummaryNetwork()
    inference_net = CouplingFlow(num_params=3)

    approximator = ContinuousApproximator(
        adapter=IdentityAdapter(),
        summary_network=summary_net,
        inference_network=inference_net
    )

    workflow = BasicWorkflow(
        approximator=approximator,
        inference_variables=["theta"],
        inference_conditions=["x"]
    )

    workflow.fit_offline(
        data={"x": x_train, "theta": theta_train},
        validation_data={"x": x_val, "theta": theta_val},
        epochs=epochs,
        batch_size=64,
        verbose=0
    )

    return workflow

# ─────────────────────────────────────────────────────────────────────────────
# 7. Train 3 Models

# One workflow per simulator variant; each call simulates its own training data.
print("Training full model...")
workflow_full = train_bayesflow_model(sim_func=simulate)

print("Training no saliency model...")
workflow_no_sal = train_bayesflow_model(sim_func=simulate_no_saliency)

print("Training random model...")
workflow_rand = train_bayesflow_model(sim_func=simulate_random)

# ─────────────────────────────────────────────────────────────────────────────
# 8. Evaluate on Ground Truth

# Ground-truth check: simulate one dataset from the full model at known
# parameters and compare each workflow's posterior mean against them.
true_params = {'nu': 0.5, 'r': 10.0, 'mu_T': 200.0}
obs_vec = pad_fixations(flatten_fixations(simulate(true_params)))[np.newaxis, :]
true_vals = np.array(list(true_params.values()))

def get_abs_error(workflow):
    """Per-parameter |posterior mean (1000 draws) - true value| for ``obs_vec``."""
    draws = workflow.sample(conditions={"x": obs_vec}, num_samples=1000)["theta"].squeeze(0)
    return np.abs(draws.mean(axis=0) - true_vals)

err_full, err_no_sal, err_rand = (
    get_abs_error(wf) for wf in (workflow_full, workflow_no_sal, workflow_rand)
)

# ─────────────────────────────────────────────────────────────────────────────
# 9. Visualization: Separate Bar Plots for Each Parameter (with value labels)

import matplotlib.pyplot as plt

labels = ['Full', 'No Saliency', 'Random']
# Column i of each error vector belongs to parameter i.
errors = {
    name: [model_err[i] for model_err in (err_full, err_no_sal, err_rand)]
    for i, name in enumerate(['ν (nu)', 'r', 'μ_T'])
}

for param, values in errors.items():
    top = max(values)
    fig, ax = plt.subplots(figsize=(5, 4))
    bars = ax.bar(labels, values, color=['C0', 'C1', 'C2'])

    # Value label just above each bar (offset = 1% of the tallest bar).
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            height + 0.01 * top,
            f"{height:.3f}",
            ha='center',
            va='bottom',
            fontsize=10
        )

    ax.set_ylabel("Absolute Error")
    ax.set_title(f"Parameter Recovery Error: {param}")
    ax.set_ylim(0, top * 1.2)
    ax.grid(axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()
    plt.show()