# In [None]:
""" Hierarchical Bayesian Model """

import pymc as pm
import numpy as np
import matplotlib.pyplot as plt

# --- Simulation settings ---
# All knobs for the run live here; the simulation and plotting code below
# read these names directly.
n_simulations = 1   # how many seasons are simulated
n_nests = 1000      # nests drawn in each season
min_eggs, max_eggs = 3, 7              # inclusive bounds on eggs per nest
min_fragments, max_fragments = 10, 30  # inclusive bounds on fragments per egg
sample_size = 110   # fragments subsampled from each season

# --- Accumulator for the sampled δ18O values of every simulated season ---
all_sampled_fragments = list()

# --- Simulate seasons: nest-level δ18O -> per-fragment pool -> subsample ---

# The generative model does not depend on the season index, so it is built
# once here instead of being re-created on every loop iteration.
with pm.Model() as model:
    # Phenology prior: laying day ~ N(mean, std)
    mean_day = pm.Uniform("mean_day", lower=3, upper=7)
    std_day = pm.HalfNormal("std_day", sigma=2)
    laying_days = pm.Normal("laying_days", mu=mean_day, sigma=std_day, shape=n_nests)

    # δ18O of water over time: linear change with laying day
    d18O_water = -12.5 + 0.22 * laying_days
    # Eggshell value is the water value shifted by a constant offset of +32
    d18O_eggshell = pm.Deterministic("d18O_eggshell", d18O_water + 32)

for sim in range(n_simulations):
    # One joint prior draw per season. The draw count is passed positionally:
    # newer PyMC releases deprecate the `samples=` keyword in favour of
    # `draws=`, and the first positional slot is the draw count in both cases.
    prior = pm.sample_prior_predictive(1, model=model, return_inferencedata=False)

    # Extract eggshell values for each nest (index 0 = the single draw)
    d18O_per_nest = prior["d18O_eggshell"][0]  # Shape: (n_nests,)

    # Random egg count per nest, then a random fragment count for every egg.
    # Both upper bounds are inclusive, hence the "+ 1".
    eggs_per_nest = np.random.randint(min_eggs, max_eggs + 1, size=n_nests)
    fragments_per_egg = np.random.randint(
        min_fragments, max_fragments + 1, size=eggs_per_nest.sum()
    )

    # Every egg inherits its nest's value and every fragment inherits its
    # egg's value, so two repeats build the full pool (nest-major order)
    # without a Python-level loop over eggs.
    d18O_per_egg = np.repeat(d18O_per_nest, eggs_per_nest)
    fragments = np.repeat(d18O_per_egg, fragments_per_egg)

    # Draw `sample_size` fragments without replacement from the season's pool.
    # np.random.choice raises ValueError if the pool is smaller than that.
    sampled = np.random.choice(fragments, size=sample_size, replace=False)
    all_sampled_fragments.extend(sampled)

# --- Plot aggregated histogram ---
# A single histogram over the fragments pooled from all simulated seasons,
# drawn through the explicit Figure/Axes interface.
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(all_sampled_fragments, bins=30, edgecolor='black')
ax.set_title(f"Aggregated δ$^{{18}}$O Histogram\n({n_simulations} seasons × {sample_size} fragments = {len(all_sampled_fragments)})")
ax.set_xlabel("δ$^{18}$O (‰, eggshell)")
ax.set_ylabel("Frequency")
ax.grid(True, linestyle='--', alpha=0.6)  # light dashed grid behind the bars
fig.tight_layout()
plt.show()
