In [None]:
# GemPy Gravity Response Tutorial - Setup and Helper Functions
# Trans-Conceptual Model Selection Example

import gempy as gp
import gempy_viewer as gpv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Tuple

# Set random seed for reproducibility
np.random.seed(42)

# Import functions from module for cross-platform multiprocessing compatibility
from gravity_functions import create_gravity_grid, compute_and_plot_gravity

# Common parameters
grav_res = 20
extent = [0, 1000, 0, 1000, -1000, 0]

print("Setup complete!")

## Mathematical Framework

### Trans-Conceptual Posterior

The trans-conceptual posterior combines parameter uncertainty and model uncertainty:

```
p(state, θ | data) ∝ p(data | θ, state) × p(θ | state) × p(state)
```

where:
- `state` ∈ {0, 1, 2} is the model index
- `θ` is the parameter vector (amplitude in our case)
- `p(data | θ, state)` is the likelihood
- `p(θ | state)` is the within-state prior
- `p(state)` is the prior over models (uniform in our case)

### Marginal Model Probabilities

By integrating over parameters, we get the posterior model probability:

```
p(state | data) = ∫ p(state, θ | data) dθ
```

These probabilities are approximated by the visitation frequencies in the MCMC chains, i.e. the fraction of post-burn-in iterations that the chains spend in each state.
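A minimal sketch of that estimate, assuming the state chain is an integer array of shape `(n_walkers, n_steps)` as used later in this notebook. The helper name `state_probabilities` and the 10% burn-in default are illustrative choices, not part of pyTransC.

```python
import numpy as np

def state_probabilities(state_chain, n_states, burn_in_frac=0.1):
    """Estimate p(state | data) from visit counts of an integer state chain.

    state_chain: array of shape (n_walkers, n_steps) holding state indices.
    The first `burn_in_frac` of the steps is discarded before counting.
    """
    chain = np.asarray(state_chain, dtype=int)
    start = int(burn_in_frac * chain.shape[1])
    visits = np.bincount(chain[:, start:].ravel(), minlength=n_states)
    return visits / visits.sum()

# Toy chain: 2 walkers, 10 steps (illustrative values only)
toy_chain = np.array([[0, 1, 1, 1, 2, 1, 1, 1, 0, 1],
                      [1, 1, 1, 2, 1, 1, 1, 1, 1, 1]])
probs = state_probabilities(toy_chain, n_states=3)
print(probs)
```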

### Bayes Factors

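The Bayes factor comparing models *i* and *j* is the ratio of their evidences (marginal likelihoods), which equals the posterior odds divided by the prior odds:

```
BF_ij = p(data | state=i) / p(data | state=j)
      = [p(state=i | data) / p(state=j | data)] × [p(state=j) / p(state=i)]
```

Because the prior over states is uniform here, BF_ij reduces to the ratio of posterior model probabilities, p(state=i | data) / p(state=j | data).

Interpretation:
- BF > 10: Strong evidence for model *i*
- 3 < BF ≤ 10: Moderate evidence for model *i*
- BF ≈ 1: Models are equally supported
- BF < 1: Evidence favors model *j* (apply the same thresholds to 1/BF)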
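The prior-odds correction matters as soon as p(state) is not uniform. A small sketch with a hypothetical helper `bayes_factor` and illustrative (not computed) model probabilities:

```python
import numpy as np

def bayes_factor(post_probs, prior_probs, i, j):
    """BF_ij = posterior odds of i vs j divided by prior odds of i vs j."""
    post = np.asarray(post_probs, dtype=float)
    prior = np.asarray(prior_probs, dtype=float)
    return (post[i] / post[j]) / (prior[i] / prior[j])

post = [0.05, 0.85, 0.10]          # illustrative posterior model probabilities
uniform = [1/3, 1/3, 1/3]          # uniform prior: BF equals posterior odds
skewed = [0.5, 0.25, 0.25]         # prior favouring State 0

print(bayes_factor(post, uniform, 1, 0))   # posterior odds 0.85 / 0.05
print(bayes_factor(post, skewed, 1, 0))    # twice the posterior odds
```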

# Trans-Conceptual Model Selection for Gravity Data

This notebook demonstrates **Trans-Conceptual Bayesian inference** for geological model selection using gravity data. We'll use three competing models of subsurface structure and let the data tell us which model is most probable.

## Problem Setup

We have three competing hypotheses about the geological structure:
- **Model 0**: Simple layered structure with no faults
- **Model 1**: Layered structure with one vertical fault
- **Model 2**: Layered structure with two opposing dipping faults

Each model produces a different gravity response at the surface. Given noisy gravity observations, we want to determine which geological model is most consistent with the data.

## Methodology

This example uses the **ensemble resampler** approach from pyTransC:

1. **Forward modeling**: Use GemPy to compute gravity responses for each geological model
2. **Within-state sampling**: Run MCMC independently for each model to explore its parameter space
3. **Pseudo-prior construction**: Build bridge distributions from the posterior ensembles
4. **Trans-Conceptual sampling**: Use ensemble resampler to jump between models and estimate posterior model probabilities

## References

- GemPy: 3D geological modeling (https://www.gempy.org/)
- pyTransC: Trans-Conceptual MCMC sampling
- Methodology based on Sambridge et al. (2013) and Bodin et al. (2012)

In [None]:
# ============================================================================
# MODEL 1: NO FAULT - Simple Layered Model
# ============================================================================
print("=" * 70)
print("MODEL 1: SIMPLE LAYERED MODEL (NO FAULT)")
print("=" * 70)

# Create surface points (layer interfaces)
surface_points_1 = pd.DataFrame({
    'X': [250, 750, 250, 750, 250, 750],
    'Y': [250, 250, 750, 750, 500, 500],
    'Z': [-200, -200, -200, -200, -500, -500],
    'surface': ['Layer2', 'Layer2', 'Layer2', 'Layer2', 'Layer3', 'Layer3']
})

# Create orientation data (dip direction and dip angle)
orientations_1 = pd.DataFrame({
    'X': [500, 500],
    'Y': [500, 500],
    'Z': [-200, -500],
    'surface': ['Layer2', 'Layer3'],
    'dip': [0, 0],  # horizontal layers
    'azimuth': [0, 0]
})

# Create GeoModel with default structural frame
geo_model_1 = gp.create_geomodel(
    project_name='Model1_NoFault',
    extent=extent,
    resolution=[50, 50, 50],
    refinement=1,
    structural_frame=gp.data.StructuralFrame.initialize_default_structure()
)

# Add surfaces
gp.add_surface_points(
    geo_model=geo_model_1,
    x=surface_points_1['X'].values,
    y=surface_points_1['Y'].values,
    z=surface_points_1['Z'].values,
    elements_names=surface_points_1['surface'].values
)

gp.add_orientations(
    geo_model=geo_model_1,
    x=orientations_1['X'].values,
    y=orientations_1['Y'].values,
    z=orientations_1['Z'].values,
    elements_names=orientations_1['surface'].values,
    pole_vector=np.array([[0, 0, 1], [0, 0, 1]])  # Horizontal layers
)

# Map geological series to surfaces
gp.map_stack_to_surfaces(
    gempy_model=geo_model_1,
    mapping_object={'Strata': ['Layer2', 'Layer3']}
)

# Set up gravity grid
xy_grid_1 = create_gravity_grid(extent, grav_res)

# Configure centered grid for gravity
gp.set_centered_grid(
    grid=geo_model_1.grid,
    centers=xy_grid_1,
    resolution=np.array([10, 10, 15]),
    radius=np.array([2000, 2000, 2000])
)

# Calculate gravity gradient
gravity_gradient_1 = gp.calculate_gravity_gradient(geo_model_1.grid.centered_grid)

# Assign densities (g/cm³): Layer1 (top), Layer2, Layer3
geo_model_1.geophysics_input = gp.data.GeophysicsInput(
    tz=gravity_gradient_1,
    densities=np.array([2.3, 2.6, 2.8])
)

# Compute and visualize
grav_1, sol_1 = compute_and_plot_gravity(geo_model_1, "Model 1: No Fault", grav_res)

## Part 1: Forward Modeling - Computing Gravity Responses

In this section, we define three competing geological models and compute their gravity responses using GemPy.

### Model 0: No Fault
A simple three-layer model with horizontal interfaces. This represents the null hypothesis: no structural complexity beyond layering.

In [None]:
# ============================================================================
# MODEL 2: SINGLE FAULT
# ============================================================================
print("\n" + "=" * 70)
print("MODEL 2: MODEL WITH ONE FAULT")
print("=" * 70)

# Stratigraphic layers
surface_points_2 = pd.DataFrame({
    'X': [200, 800, 200, 800, 200, 800],
    'Y': [250, 250, 750, 750, 500, 500],
    'Z': [-200, -200, -200, -200, -500, -500],
    'surface': ['Layer2', 'Layer2', 'Layer2', 'Layer2', 'Layer3', 'Layer3']
})

orientations_2 = pd.DataFrame({
    'X': [500, 500],
    'Y': [500, 500],
    'Z': [-200, -500],
    'surface': ['Layer2', 'Layer3'],
    'dip': [0, 0],
    'azimuth': [0, 0]
})

# Fault plane (vertical fault at X=500)
fault_points_2 = pd.DataFrame({
    'X': [500, 500, 500, 500],
    'Y': [200, 800, 200, 800],
    'Z': [-100, -100, -900, -900],
    'surface': ['Fault1', 'Fault1', 'Fault1', 'Fault1']
})

fault_orientations_2 = pd.DataFrame({
    'X': [500],
    'Y': [500],
    'Z': [-400],
    'surface': ['Fault1'],
    'dip': [90],  # vertical fault
    'azimuth': [90]  # striking N-S
})

# Create GeoModel with default structural frame
geo_model_2 = gp.create_geomodel(
    project_name='Model2_OneFault',
    extent=extent,
    resolution=[50, 50, 50],
    refinement=1,
    structural_frame=gp.data.StructuralFrame.initialize_default_structure()
)

# Add fault points
gp.add_surface_points(
    geo_model=geo_model_2,
    x=fault_points_2['X'].values,
    y=fault_points_2['Y'].values,
    z=fault_points_2['Z'].values,
    elements_names=fault_points_2['surface'].values
)

gp.add_orientations(
    geo_model=geo_model_2,
    x=fault_orientations_2['X'].values,
    y=fault_orientations_2['Y'].values,
    z=fault_orientations_2['Z'].values,
    elements_names=fault_orientations_2['surface'].values,
    pole_vector=np.array([[1, 0, 0]])  # Vertical fault
)

# Add stratigraphic layers
gp.add_surface_points(
    geo_model=geo_model_2,
    x=surface_points_2['X'].values,
    y=surface_points_2['Y'].values,
    z=surface_points_2['Z'].values,
    elements_names=surface_points_2['surface'].values
)

gp.add_orientations(
    geo_model=geo_model_2,
    x=orientations_2['X'].values,
    y=orientations_2['Y'].values,
    z=orientations_2['Z'].values,
    elements_names=orientations_2['surface'].values,
    pole_vector=np.array([[0, 0, 1], [0, 0, 1]])  # Horizontal layers
)

# Map geological series to surfaces
gp.map_stack_to_surfaces(
    gempy_model=geo_model_2,
    mapping_object={
        'Fault1_series': 'Fault1',
        'Strata': ['Layer2', 'Layer3']
    }
)

# Set fault relations
geo_model_2.structural_frame.structural_groups[0].structural_relation = gp.data.StackRelationType.FAULT
geo_model_2.structural_frame.fault_relations = np.array([[0, 1], [0, 0]])

# Gravity setup
xy_grid_2 = create_gravity_grid(extent, grav_res)
gp.set_centered_grid(
    grid=geo_model_2.grid,
    centers=xy_grid_2,
    resolution=np.array([10, 10, 15]),
    radius=np.array([2000, 2000, 2000])
)

gravity_gradient_2 = gp.calculate_gravity_gradient(geo_model_2.grid.centered_grid)
geo_model_2.geophysics_input = gp.data.GeophysicsInput(
    tz=gravity_gradient_2,
    densities=np.array([2.3, 2.6, 2.8])
)

# Compute and visualize
grav_2, sol_2 = compute_and_plot_gravity(geo_model_2, "Model 2: One Fault", grav_res)

### Model 1: Single Vertical Fault
A more complex model with a vertical, N-S-striking fault cutting through the layers at X = 500 m. This introduces a structural discontinuity and produces a distinct gravity signature.

In [None]:
# ============================================================================
# MODEL 3: TWO OPPOSING FAULTS
# ============================================================================
print("\n" + "=" * 70)
print("MODEL 3: MODEL WITH TWO OPPOSING FAULTS")
print("=" * 70)

# Stratigraphic layers
surface_points_3 = pd.DataFrame({
    'X': [200, 800, 200, 800, 200, 800],
    'Y': [250, 250, 750, 750, 500, 500],
    'Z': [-200, -200, -200, -200, -500, -500],
    'surface': ['Layer2', 'Layer2', 'Layer2', 'Layer2', 'Layer3', 'Layer3']
})

orientations_3 = pd.DataFrame({
    'X': [500, 500],
    'Y': [500, 500],
    'Z': [-200, -500],
    'surface': ['Layer2', 'Layer3'],
    'dip': [0, 0],
    'azimuth': [0, 0]
})

# Fault 1: Dipping to the right
fault1_points_3 = pd.DataFrame({
    'X': [350, 350, 350, 350],
    'Y': [200, 800, 200, 800],
    'Z': [-100, -100, -900, -900],
    'surface': ['Fault1', 'Fault1', 'Fault1', 'Fault1']
})

fault1_orientations_3 = pd.DataFrame({
    'X': [350],
    'Y': [500],
    'Z': [-400],
    'surface': ['Fault1'],
    'dip': [75],  # steep fault dipping right
    'azimuth': [90]
})

# Fault 2: Dipping to the left (opposite direction)
fault2_points_3 = pd.DataFrame({
    'X': [650, 650, 650, 650],
    'Y': [200, 800, 200, 800],
    'Z': [-100, -100, -900, -900],
    'surface': ['Fault2', 'Fault2', 'Fault2', 'Fault2']
})

fault2_orientations_3 = pd.DataFrame({
    'X': [650],
    'Y': [500],
    'Z': [-400],
    'surface': ['Fault2'],
    'dip': [75],  # steep fault dipping left
    'azimuth': [270]  # opposite direction
})

# Create GeoModel with default structural frame
geo_model_3 = gp.create_geomodel(
    project_name='Model3_TwoFaults',
    extent=extent,
    resolution=[50, 50, 50],
    refinement=1,
    structural_frame=gp.data.StructuralFrame.initialize_default_structure()
)

# Add fault 1
gp.add_surface_points(
    geo_model=geo_model_3,
    x=fault1_points_3['X'].values,
    y=fault1_points_3['Y'].values,
    z=fault1_points_3['Z'].values,
    elements_names=fault1_points_3['surface'].values
)

gp.add_orientations(
    geo_model=geo_model_3,
    x=fault1_orientations_3['X'].values,
    y=fault1_orientations_3['Y'].values,
    z=fault1_orientations_3['Z'].values,
    elements_names=fault1_orientations_3['surface'].values,
    pole_vector=np.array([[0.966, 0, 0.259]])  # Dipping fault
)

# Add fault 2
gp.add_surface_points(
    geo_model=geo_model_3,
    x=fault2_points_3['X'].values,
    y=fault2_points_3['Y'].values,
    z=fault2_points_3['Z'].values,
    elements_names=fault2_points_3['surface'].values
)

gp.add_orientations(
    geo_model=geo_model_3,
    x=fault2_orientations_3['X'].values,
    y=fault2_orientations_3['Y'].values,
    z=fault2_orientations_3['Z'].values,
    elements_names=fault2_orientations_3['surface'].values,
    pole_vector=np.array([[-0.966, 0, 0.259]])  # Opposite dipping fault
)

# Add stratigraphic layers
gp.add_surface_points(
    geo_model=geo_model_3,
    x=surface_points_3['X'].values,
    y=surface_points_3['Y'].values,
    z=surface_points_3['Z'].values,
    elements_names=surface_points_3['surface'].values
)

gp.add_orientations(
    geo_model=geo_model_3,
    x=orientations_3['X'].values,
    y=orientations_3['Y'].values,
    z=orientations_3['Z'].values,
    elements_names=orientations_3['surface'].values,
    pole_vector=np.array([[0, 0, 1], [0, 0, 1]])  # Horizontal layers
)

# Map geological series to surfaces
gp.map_stack_to_surfaces(
    gempy_model=geo_model_3,
    mapping_object={
        'Fault1_series': 'Fault1',
        'Fault2_series': 'Fault2',
        'Strata': ['Layer2', 'Layer3']
    }
)

# Set fault relations
geo_model_3.structural_frame.structural_groups[0].structural_relation = gp.data.StackRelationType.FAULT
geo_model_3.structural_frame.structural_groups[1].structural_relation = gp.data.StackRelationType.FAULT
geo_model_3.structural_frame.fault_relations = np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])

# Gravity setup
xy_grid_3 = create_gravity_grid(extent, grav_res)
gp.set_centered_grid(
    grid=geo_model_3.grid,
    centers=xy_grid_3,
    resolution=np.array([10, 10, 15]),
    radius=np.array([2000, 2000, 2000])
)

gravity_gradient_3 = gp.calculate_gravity_gradient(geo_model_3.grid.centered_grid)
geo_model_3.geophysics_input = gp.data.GeophysicsInput(
    tz=gravity_gradient_3,
    densities=np.array([2.3, 2.6, 2.8])
)

# Compute and visualize
grav_3, sol_3 = compute_and_plot_gravity(geo_model_3, "Model 3: Two Opposing Faults", grav_res)

### Model 2: Two Opposing Dipping Faults
The most complex model: two steep (75°) faults at X = 350 m and X = 650 m that dip toward each other, a graben-style geometry bounding a central block. This produces an even more distinctive gravity response.
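The code cells describe each orientation twice: as `dip`/`azimuth` columns in the DataFrames and as hard-coded `pole_vector` arrays passed to `gp.add_orientations`. The sketch below (a hypothetical helper `pole_vector`, assuming azimuth is the dip direction measured clockwise from +Y/north) checks that the two agree, e.g. that a 75° dip toward +X gives roughly `[0.966, 0, 0.259]`. It is not taken from GemPy.

```python
import numpy as np

def pole_vector(dip_deg, azimuth_deg):
    """Upward unit normal of a plane from dip and dip direction (azimuth).

    Assumed convention: azimuth is the dip direction, clockwise from +Y
    (north), so azimuth 90 points along +X (east).
    """
    d = np.radians(dip_deg)
    a = np.radians(azimuth_deg)
    return np.array([np.sin(d) * np.sin(a),
                     np.sin(d) * np.cos(a),
                     np.cos(d)])

print(np.round(pole_vector(75, 90), 3))    # Fault1 of the two-fault model
print(np.round(pole_vector(75, 270), 3))   # Fault2 of the two-fault model
print(np.round(pole_vector(90, 90), 3))    # vertical fault of the one-fault model
print(np.round(pole_vector(0, 0), 3))      # horizontal layers
```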

In [None]:
# ============================================================================
# COMPARISON PLOT
# ============================================================================
print("\n" + "=" * 70)
print("GRAVITY COMPARISON")
print("=" * 70)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

extent_vals = geo_model_1.grid.regular_grid.extent

im1 = axes[0].imshow(grav_1.reshape(grav_res, grav_res), 
                     extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                     cmap='RdBu_r', origin='lower', aspect='auto')
axes[0].set_title('Model 1: No Fault')
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
plt.colorbar(im1, ax=axes[0], label='μgal')

im2 = axes[1].imshow(grav_2.reshape(grav_res, grav_res),
                     extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                     cmap='RdBu_r', origin='lower', aspect='auto')
axes[1].set_title('Model 2: One Fault')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
plt.colorbar(im2, ax=axes[1], label='μgal')

im3 = axes[2].imshow(grav_3.reshape(grav_res, grav_res),
                     extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                     cmap='RdBu_r', origin='lower', aspect='auto')
axes[2].set_title('Model 3: Two Opposing Faults')
axes[2].set_xlabel('X')
axes[2].set_ylabel('Y')
plt.colorbar(im3, ax=axes[2], label='μgal')

plt.tight_layout()
plt.show()

print("\nGravity Statistics:")
print(f"Model 1 (No Fault):     mean={grav_1.mean():.2f}, std={grav_1.std():.2f}, range=[{grav_1.min():.2f}, {grav_1.max():.2f}]")
print(f"Model 2 (One Fault):    mean={grav_2.mean():.2f}, std={grav_2.std():.2f}, range=[{grav_2.min():.2f}, {grav_2.max():.2f}]")
print(f"Model 3 (Two Faults):   mean={grav_3.mean():.2f}, std={grav_3.std():.2f}, range=[{grav_3.min():.2f}, {grav_3.max():.2f}]")

### Comparison of Gravity Responses

Let's compare the gravity responses from all three models. Notice how each model produces a unique gravity signature:
- **No fault**: Smooth, symmetric gravity field
- **One fault**: Clear linear discontinuity in the gravity field
- **Two faults**: Two discontinuities creating a more complex pattern

These differences will allow the Trans-Conceptual sampler to distinguish between the models.
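One way to quantify "distinguishable": with Gaussian noise of known σ, if the data come from response g_j, the expected log-likelihood advantage of g_j over an alternative g_i is ||g_j − g_i||² / (2σ²). The sketch below uses toy 1-D profiles (not the GemPy maps) and hypothetical helper names; it ignores the amplitude parameter, which can absorb part of a difference between models.

```python
import numpy as np

def expected_log_lik_gap(g_true, g_alt, sigma):
    """Mean of log L(g_true) - log L(g_alt) when data = g_true + N(0, sigma^2) noise."""
    diff = np.asarray(g_true, dtype=float) - np.asarray(g_alt, dtype=float)
    return np.sum(diff**2) / (2.0 * sigma**2)

def gaussian_log_lik(data, predicted, sigma):
    # Constant terms are omitted: they cancel in a log-likelihood difference
    return -0.5 * np.sum((data - predicted)**2) / sigma**2

# Toy 1-D profiles standing in for two gravity maps (illustrative values)
x = np.linspace(0.0, 1000.0, 400)
g_layered = np.full_like(x, 10.0)                       # flat response
g_faulted = 10.0 + 0.5 * np.tanh((x - 500.0) / 50.0)    # step across a "fault"
sigma = 0.5

gap = expected_log_lik_gap(g_faulted, g_layered, sigma)

# Monte Carlo check: simulate noisy data from the faulted profile
rng = np.random.default_rng(0)
noisy = g_faulted + rng.normal(0.0, sigma, size=(2000, x.size))
draws = [gaussian_log_lik(d, g_faulted, sigma) - gaussian_log_lik(d, g_layered, sigma)
         for d in noisy]
print(f"analytic gap = {gap:.1f}, Monte Carlo mean = {np.mean(draws):.1f}")
```

Applied to this notebook, `expected_log_lik_gap(grav_2, grav_1, noise_std)` would give the same rough separability measure for the one-fault versus no-fault responses.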

# Trans-Conceptual Model Selection

Now we'll set up a Trans-Conceptual MCMC sampling example to determine which model best explains synthetic gravity observations.

**Goal**: Given noisy gravity observations, use pyTransC to identify the correct geological model (state) among the three competing hypotheses:
- **State 0**: No fault
- **State 1**: One fault  
- **State 2**: Two opposing faults

In [None]:
# ============================================================================
# GENERATE SYNTHETIC OBSERVATIONS
# ============================================================================
# Use Model 2 (One Fault) as the "true" model and add noise

print("=" * 70)
print("GENERATING SYNTHETIC OBSERVATIONS")
print("=" * 70)

# Choose Model 2 (one fault) as the true model
true_model = 2  # "Model 2" in the code cells' 1-based naming = State 1 (0-based)
true_gravity = grav_2.copy()

# Add Gaussian noise to create synthetic observations
noise_std = 0.5  # Standard deviation of measurement noise (μgal)
np.random.seed(123)  # For reproducibility
noise = np.random.normal(0, noise_std, size=true_gravity.shape)
observed_gravity = true_gravity + noise

# Store for later use
n_data = len(observed_gravity)

# Visualize the observed data
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot true gravity
im1 = axes[0].imshow(true_gravity.reshape(grav_res, grav_res),
                     extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                     cmap='RdBu_r', origin='lower', aspect='auto')
axes[0].set_title('True Gravity (Model 2: One Fault)')
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
plt.colorbar(im1, ax=axes[0], label='μgal')

# Plot noise
im2 = axes[1].imshow(noise.reshape(grav_res, grav_res),
                     extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                     cmap='seismic', origin='lower', aspect='auto', vmin=-2*noise_std, vmax=2*noise_std)
axes[1].set_title(f'Measurement Noise (σ={noise_std} μgal)')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
plt.colorbar(im2, ax=axes[1], label='μgal')

# Plot observed data
im3 = axes[2].imshow(observed_gravity.reshape(grav_res, grav_res),
                     extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                     cmap='RdBu_r', origin='lower', aspect='auto')
axes[2].set_title('Observed Gravity (True + Noise)')
axes[2].set_xlabel('X')
axes[2].set_ylabel('Y')
plt.colorbar(im3, ax=axes[2], label='μgal')

plt.tight_layout()
plt.show()

print(f"\nTrue model: State {true_model-1} (Model {true_model}: One Fault)")
print(f"Number of observations: {n_data}")
print(f"Noise level: σ = {noise_std} μgal")
print(f"SNR: {true_gravity.std() / noise_std:.2f}")

---

## Part 2: Creating Synthetic Observations

To test the Trans-Conceptual sampling approach, we'll create synthetic observations by:
1. Selecting one model as the "true" model (we'll use Model 1: one fault, whose gravity response is stored as `grav_2` in the code cells)
2. Adding realistic measurement noise to its gravity response
3. Treating this as our "observed" data

The goal is then to see if the Trans-Conceptual sampler can correctly identify which model generated the data, even in the presence of noise.

In [None]:
# ============================================================================
# DEFINE LOG-LIKELIHOOD AND LOG-PRIOR FUNCTIONS
# ============================================================================
print("=" * 70)
print("SETTING UP TRANS-CONCEPTUAL SAMPLING")
print("=" * 70)

# Import probability functions from module (for macOS/Windows compatibility)
from functools import partial
from gravity_functions import (
    compute_gravity_for_state as _compute_gravity_for_state,
    log_likelihood as _log_likelihood,
    log_prior,
    log_posterior as _log_posterior
)

# Create partial functions binding the gravity data
compute_gravity_for_state = partial(_compute_gravity_for_state, grav_1=grav_1, grav_2=grav_2, grav_3=grav_3)
log_likelihood = partial(_log_likelihood, observed_gravity=observed_gravity, noise_std=noise_std, grav_1=grav_1, grav_2=grav_2, grav_3=grav_3)
log_posterior = partial(_log_posterior, observed_gravity=observed_gravity, noise_std=noise_std, grav_1=grav_1, grav_2=grav_2, grav_3=grav_3)

print("✓ Log-likelihood and log-prior functions defined")
print(f"  - Each state has 1 parameter: amplitude (scaling factor)")
print(f"  - Prior: amplitude ~ Uniform(0.5, 1.5)")
print(f"  - Likelihood: Gaussian with fixed σ = {noise_std} μgal")

---

## Part 3: Bayesian Inference Setup

### Parameterization

For this example, we keep the parameterization simple:
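- **Each model has 1 parameter**: an amplitude scaling factor that represents uncertainty in the density contrasts between layers
- **Physical interpretation**: gravity is linear in density, so scaling all densities by a factor scales the predicted anomaly by the same factor. Real density contrasts are uncertain, and this parameter lets each model rescale its predicted anomaly to best fit the data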

### Likelihood Function

We use a Gaussian likelihood:
```
p(data | model, params) ∝ exp(-0.5 * ||observed - predicted||² / σ²)
```

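where σ is the known measurement-noise standard deviation (`noise_std` = 0.5 μgal) and `predicted` is the model's gravity response multiplied by its amplitude parameter.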

### Prior Distribution

We use a uniform prior on the amplitude: `amplitude ~ Uniform(0.5, 1.5)`, allowing for ±50% uncertainty in density contrast.
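The real `log_prior` and `log_likelihood` live in `gravity_functions`, which is not shown here. The following is only a sketch of what the text describes, assuming `predicted = amplitude * responses[state]`; the names `log_prior_sketch`, `log_likelihood_sketch`, `log_posterior_sketch` and `responses` are hypothetical. Both densities are kept normalized: constants that are identical across states cancel in model comparison, but an unnormalized within-state prior would bias the evidences.

```python
import numpy as np

AMP_LOW, AMP_HIGH = 0.5, 1.5   # Uniform prior bounds stated in the text

def log_prior_sketch(params, state):
    """Normalized log p(amplitude | state) for Uniform(0.5, 1.5), identical for all states."""
    amp = params[0]
    if AMP_LOW <= amp <= AMP_HIGH:
        return -np.log(AMP_HIGH - AMP_LOW)
    return -np.inf

def log_likelihood_sketch(params, state, observed, responses, sigma):
    """Gaussian log-likelihood with predicted = amplitude * responses[state]."""
    predicted = params[0] * responses[state]
    resid = observed - predicted
    n = observed.size
    return (-0.5 * np.sum(resid**2) / sigma**2
            - n * np.log(sigma) - 0.5 * n * np.log(2.0 * np.pi))

def log_posterior_sketch(params, state, observed, responses, sigma):
    lp = log_prior_sketch(params, state)
    if not np.isfinite(lp):
        return -np.inf          # outside the prior: skip the likelihood
    return lp + log_likelihood_sketch(params, state, observed, responses, sigma)

# Toy usage: three 4-station "responses", data equal to response 1
responses = [np.full(4, 1.0), np.array([1.0, 2.0, 2.0, 1.0]), np.array([2.0, 1.0, 1.0, 2.0])]
observed = responses[1].copy()
print([log_posterior_sketch([1.0], k, observed, responses, sigma=0.5) for k in range(3)])
```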

In [None]:
# ============================================================================
# SETUP TRANS-CONCEPTUAL SAMPLING PARAMETERS
# ============================================================================

# Import pyTransC samplers
from pytransc.samplers import run_mcmc_per_state, run_ensemble_resampler
from pytransc.utils.auto_pseudo import build_auto_pseudo_prior
from pytransc.analysis.visits import get_visits_to_states

# Sampling parameters
nstates = 3  # Number of competing models
ndims = [1, 1, 1]  # Each state has 1 parameter (amplitude)
nwalkers = 32  # Number of MCMC walkers per state
nsteps_per_state = 1000  # Steps for within-state sampling
nsteps_ensemble = 10000  # Steps for ensemble resampler

# Initialize walker starting positions for each state
# Start walkers around amplitude=1.0 with small perturbations
np.random.seed(42)
pos = []
for state in range(nstates):
    # Create starting positions: amplitude ~ Uniform(0.9, 1.1)
    pos_state = np.random.uniform(0.9, 1.1, size=(nwalkers, 1))
    pos.append(pos_state)
    
    # Test the log-posterior at the first walker position
    test_params = pos_state[0]
    lp = log_prior(test_params, state)
    ll = log_likelihood(test_params, state)
    lpost = lp + ll
    print(f"State {state}: log-prior={lp:.2f}, log-likelihood={ll:.2f}, log-posterior={lpost:.2f}")

print(f"\n✓ Sampling configuration complete")
print(f"  - Number of states: {nstates}")
print(f"  - Dimensions per state: {ndims}")
print(f"  - Walkers per state: {nwalkers}")
print(f"  - Steps for within-state sampling: {nsteps_per_state}")
print(f"  - Steps for ensemble resampler: {nsteps_ensemble}")

---

## Part 4: Trans-Conceptual MCMC Sampling

### Two-Stage Approach

We use the **ensemble resampler** methodology:

**Stage 1: Within-State Sampling**
- Run standard MCMC independently for each model
- This builds a posterior ensemble for each model's parameter space
- Goal: Learn about each model's parameter distribution given the data

**Stage 2: Trans-Conceptual Sampling**
- Use the posterior ensembles to construct "pseudo-priors" (bridge distributions)
- Run ensemble resampler to jump between models
- Goal: Estimate posterior probability of each model

### Why This Approach?

The two-stage approach avoids expensive forward model evaluations during trans-conceptual jumps by:
1. Pre-computing parameter samples for each model
2. Resampling from these ensembles during model jumps
3. Using pseudo-priors to ensure detailed balance and correct posterior probabilities

In [None]:
# ============================================================================
# STAGE 1: RUN MCMC PER STATE TO BUILD POSTERIOR ENSEMBLES
# ============================================================================
print("=" * 70)
print("STAGE 1: WITHIN-STATE MCMC SAMPLING")
print("=" * 70)

import time

start_time = time.time()

# Run MCMC independently for each state
ensembles, log_probs = run_mcmc_per_state(
    n_states=nstates,
    n_dims=ndims,
    n_walkers=nwalkers,
    n_steps=nsteps_per_state,
    pos=pos,
    log_posterior=log_posterior,
    verbose=True,
    skip_initial_state_check=True,
)

elapsed_time = time.time() - start_time

print(f"\n✓ Within-state sampling completed in {elapsed_time:.2f} seconds")
print(f"\nEnsemble shapes:")
for i, ens in enumerate(ensembles):
    print(f"  State {i}: {ens.shape} (walkers × steps, dims)")
print(f"\nLog-probability shapes:")
for i, lp in enumerate(log_probs):
    print(f"  State {i}: {lp.shape}")
    
# Visualize posterior distributions for each state
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

for state in range(nstates):
    ax = axes[state]
    
    # Extract amplitude samples
    amplitude_samples = ensembles[state][:, 0]
    
    # Plot histogram
    ax.hist(amplitude_samples, bins=50, density=True, alpha=0.7, edgecolor='black')
    ax.axvline(1.0, color='r', linestyle='--', linewidth=2, label='Amplitude = 1.0 (true value for State 1)')
    ax.set_xlabel('Amplitude')
    ax.set_ylabel('Density')
    ax.set_title(f'State {state} Posterior\n({"No fault" if state==0 else "One fault" if state==1 else "Two faults"})')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
plt.tight_layout()
plt.show()

# Print summary statistics
print("\nPosterior summary statistics:")
for state in range(nstates):
    amp_samples = ensembles[state][:, 0]
    print(f"  State {state}: mean={amp_samples.mean():.4f}, std={amp_samples.std():.4f}, "
          f"95% CI=[{np.percentile(amp_samples, 2.5):.4f}, {np.percentile(amp_samples, 97.5):.4f}]")

### Stage 1: Within-State Sampling

We now run MCMC independently for each of the three models using `run_mcmc_per_state()`. This function uses emcee (the affine-invariant ensemble sampler) to explore the parameter space of each model.

**What to expect:**
- Each model will find an optimal amplitude parameter that best fits the data
- Model 1 (one fault) should achieve the best fit since it generated the data
- Models 0 and 2 will try to compensate with different amplitude values but won't fit as well
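To make the within-state step concrete, here is a self-contained sketch on a toy 1-D problem. It swaps in a plain random-walk Metropolis sampler rather than emcee's affine-invariant ensemble sampler, so it illustrates the idea, not what `run_mcmc_per_state` actually runs. The toy response `g`, step size, and burn-in length are illustrative assumptions. Because the posterior is Gaussian in the amplitude, its mean and spread can be checked analytically.

```python
import numpy as np

def rw_metropolis(log_post, x0, n_steps, step=0.05, seed=0):
    """Random-walk Metropolis for a single scalar parameter; returns the chain."""
    rng = np.random.default_rng(seed)
    x, lp = float(x0), log_post(x0)
    chain = np.empty(n_steps)
    for t in range(n_steps):
        prop = x + step * rng.normal()
        lp_prop = log_post(prop)
        if np.log(rng.uniform()) < lp_prop - lp:   # accept with prob min(1, ratio)
            x, lp = prop, lp_prop
        chain[t] = x
    return chain

# Toy state: data generated with amplitude 1.0 on a fixed response g
rng = np.random.default_rng(1)
g = np.linspace(1.0, 3.0, 50)
sigma = 0.5
obs = 1.0 * g + rng.normal(0.0, sigma, g.size)

def toy_log_post(a):
    if not 0.5 <= a <= 1.5:          # Uniform(0.5, 1.5) prior
        return -np.inf
    return -0.5 * np.sum((obs - a * g)**2) / sigma**2

chain = rw_metropolis(toy_log_post, x0=1.0, n_steps=5000)
samples = chain[1000:]               # discard burn-in
print(samples.mean(), samples.std())
```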

In [None]:
# ============================================================================
# STAGE 2: BUILD PSEUDO-PRIORS FROM POSTERIOR ENSEMBLES
# ============================================================================
print("=" * 70)
print("STAGE 2: BUILDING PSEUDO-PRIORS")
print("=" * 70)

from scipy import stats

# Build pseudo-priors using Gaussian approximation
# For each state, fit a Gaussian to the posterior ensemble

log_pseudo_prior_ens = []

for state in range(nstates):
    # Get ensemble for this state
    ens = ensembles[state]
    
    # Fit Gaussian: mean and covariance
    mean = np.mean(ens, axis=0)
    cov = np.diag(np.var(ens, axis=0))  # Diagonal covariance
    
    # Create multivariate normal distribution
    rv = stats.multivariate_normal(mean=mean, cov=cov)
    
    # Evaluate pseudo-prior log-probability for all ensemble members
    log_pseudo_prior_ens.append(rv.logpdf(ens))
    
    print(f"State {state}: mean={mean[0]:.4f}, std={np.sqrt(cov[0,0]):.4f}")

print(f"\n✓ Pseudo-priors built for all {nstates} states")

# Prepare data for ensemble resampler
ensemble_per_state = ensembles
log_posterior_ens = log_probs

print(f"\nReady for ensemble resampler:")
for i in range(nstates):
    print(f"  State {i}: {len(ensemble_per_state[i])} samples")

### Stage 2: Building Pseudo-Priors

The pseudo-priors are bridge distributions that enable valid transitions between models. We construct them by fitting Gaussian distributions to each posterior ensemble.

**Key concept:** The pseudo-prior for state *i* is a normalized density that approximates the posterior distribution p(θ_i | data, state i). This allows us to:
1. Evaluate the density of parameter values resampled from state *i*'s ensemble when jumping TO state *i*
2. Calculate acceptance ratios that maintain detailed balance
3. Correctly estimate posterior model probabilities

For this simple 1D problem, we use diagonal Gaussian approximations.
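Why should the pseudo-prior be close to the posterior? For any normalized density q, the importance-sampling identity E_q[p̃(θ)/q(θ)] = Z holds, where p̃ is the unnormalized posterior (prior × likelihood) and Z is the state's evidence. When q matches the normalized posterior, the ratio p̃/q is nearly constant and equal to Z, so comparisons between states based on it have low variance. The toy below (known Z, Gaussian posterior, Gaussian pseudo-prior fitted as in the cell above) only illustrates the identity; it does not reproduce pyTransC's internals.

```python
import numpy as np

def gauss_logpdf(x, m, sd):
    return -0.5 * ((x - m) / sd) ** 2 - np.log(sd * np.sqrt(2.0 * np.pi))

rng = np.random.default_rng(2)

# Toy unnormalized posterior: Z * N(theta; mu, s^2), with known evidence Z
Z_true, mu, s = 0.3, 1.1, 0.05
def log_post_unnorm(theta):
    return np.log(Z_true) + gauss_logpdf(theta, mu, s)

# 1) Posterior ensemble (drawn exactly here; in the notebook it comes from MCMC)
ens = rng.normal(mu, s, size=4000)

# 2) Pseudo-prior: Gaussian fitted to the ensemble
m_hat, s_hat = ens.mean(), ens.std()

# 3) Importance-sampling estimate of the evidence, drawing from the pseudo-prior
draws = rng.normal(m_hat, s_hat, size=4000)
ratio = np.exp(log_post_unnorm(draws) - gauss_logpdf(draws, m_hat, s_hat))
Z_est = ratio.mean()
print(Z_est, ratio.std() / ratio.mean())   # estimate and relative spread of p~/q
```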

In [None]:
# ============================================================================
# STAGE 3: RUN ENSEMBLE RESAMPLER FOR TRANS-CONCEPTUAL MODEL SELECTION
# ============================================================================
print("=" * 70)
print("STAGE 3: TRANS-CONCEPTUAL ENSEMBLE RESAMPLER")
print("=" * 70)

start_time = time.time()

# Run ensemble resampler
er_results = run_ensemble_resampler(
    n_walkers=16,  # Number of walkers for ensemble resampler
    n_steps=nsteps_ensemble,
    n_states=nstates,
    n_dims=ndims,
    log_posterior_ens=log_posterior_ens,
    log_pseudo_prior_ens=log_pseudo_prior_ens,
    parallel=False,
    progress=True
)

elapsed_time = time.time() - start_time

print(f"\n✓ Ensemble resampler completed in {elapsed_time:.2f} seconds")

# Extract results
state_chains = er_results.state_chain  # (n_walkers, n_steps)
n_accepted = er_results.n_accepted
n_proposed = er_results.n_proposed
acceptance_rates = n_accepted / n_proposed * 100

print(f"\nState chain shape: {state_chains.shape}")
print(f"Average acceptance rate: {np.mean(acceptance_rates):.2f}%")
print(f"Acceptance rate range: {np.min(acceptance_rates):.2f}% - {np.max(acceptance_rates):.2f}%")

# Calculate state visitation frequencies
state_visits = np.bincount(state_chains.flatten().astype(int), minlength=nstates)
state_frequencies = state_visits / state_visits.sum()

print(f"\nState visitation frequencies:")
state_names = ["No fault", "One fault", "Two faults"]
for i in range(nstates):
    print(f"  State {i} ({state_names[i]:12s}): {state_frequencies[i]:.4f} ({state_visits[i]:,} visits)")

print(f"\nTrue model was State 1 (One fault)")

### Stage 3: Trans-Conceptual Sampling with Ensemble Resampler

Now we run the ensemble resampler using `run_ensemble_resampler()`. This sampler:

1. **Proposes model changes**: Randomly selects a new state (model) to jump to
2. **Resamples parameters**: Draws a parameter value from the target state's ensemble
3. **Calculates acceptance probability**: Using the Metropolis-Hastings ratio with pseudo-priors
4. **Accepts/rejects**: Maintains detailed balance to correctly sample the trans-conceptual posterior

**What to expect:**
- The chain will visit all three states
- Model 1 (one fault) should be visited most frequently since it's the true model
- The visitation frequencies approximate the posterior model probabilities
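The steps above can be illustrated with a hypothetical, simplified stand-in: a Metropolis chain over (state, ensemble-member) pairs whose target is proportional to w_k(i) = p̃_k(θ_i) / q_k(θ_i), the two quantities `log_posterior_ens` and `log_pseudo_prior_ens` hold for each member. This is not pyTransC's `run_ensemble_resampler`; its walker handling and proposal scheme may differ. In this toy the members are drawn from each pseudo-prior q_k, so the visit frequencies converge to Z_k / ΣZ; in the notebook the ensembles are posterior samples and q_k is fitted to them, so the two distributions approximately coincide.

```python
import numpy as np

def gauss_logpdf(x, m, sd):
    return -0.5 * ((x - m) / sd) ** 2 - np.log(sd * np.sqrt(2.0 * np.pi))

def toy_ensemble_resampler(log_w, n_steps, seed=0):
    """Metropolis chain over (state, member) pairs targeting P(k, i) ∝ w_k(i).

    log_w[k][i] = log p~_k(theta_i) - log q_k(theta_i) for member i of state k.
    All ensembles must have the same size so that choosing a state and then a
    member uniformly at random is a symmetric proposal.
    """
    sizes = {len(lw) for lw in log_w}
    if len(sizes) != 1:
        raise ValueError("ensembles must have equal size for a symmetric proposal")
    rng = np.random.default_rng(seed)
    n_states, n_members = len(log_w), sizes.pop()
    k, i = 0, 0
    states = np.empty(n_steps, dtype=int)
    for t in range(n_steps):
        k_new, i_new = rng.integers(n_states), rng.integers(n_members)
        # Metropolis acceptance: min(1, w_new / w_current)
        if np.log(rng.uniform()) < log_w[k_new][i_new] - log_w[k][i]:
            k, i = k_new, i_new
        states[t] = k
    return states

# Toy problem with known evidences Z_k; members are drawn from each pseudo-prior q_k
rng = np.random.default_rng(3)
Z = np.array([0.1, 0.6, 0.3])
means, s = [0.8, 1.0, 1.2], 0.05
log_w = []
for Zk, mk in zip(Z, means):
    theta = rng.normal(mk, 1.2 * s, size=2000)              # draws from q_k
    log_w.append(np.log(Zk) + gauss_logpdf(theta, mk, s)    # log p~_k(theta)
                 - gauss_logpdf(theta, mk, 1.2 * s))        # minus log q_k(theta)

states = toy_ensemble_resampler(log_w, n_steps=20000)
freq = np.bincount(states, minlength=3) / states.size
print(freq, Z / Z.sum())
```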

In [None]:
# ============================================================================
# VISUALIZATION AND ANALYSIS OF TRANS-CONCEPTUAL RESULTS
# ============================================================================
print("=" * 70)
print("RESULTS ANALYSIS")
print("=" * 70)

# Create comprehensive visualization
fig = plt.figure(figsize=(18, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# 1. State trace plot (top row, spanning all columns)
ax1 = fig.add_subplot(gs[0, :])
for walker_idx in range(state_chains.shape[0]):
    ax1.plot(state_chains[walker_idx, :], alpha=0.3, linewidth=0.5)
ax1.axhline(y=1, color='r', linestyle='--', linewidth=2, label='True state (One fault)')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('State Index')
ax1.set_title('State Trace: Trans-Conceptual MCMC Chains')
ax1.set_yticks([0, 1, 2])
ax1.set_yticklabels(['No fault', 'One fault', 'Two faults'])
ax1.grid(True, alpha=0.3)
ax1.legend()

# 2. State histogram (middle left)
ax2 = fig.add_subplot(gs[1, 0])
# Use all samples after burn-in (skip first 10%)
burnin = int(0.1 * state_chains.shape[1])
state_samples_burnin = state_chains[:, burnin:].flatten()
counts, bins, patches = ax2.hist(state_samples_burnin, bins=np.arange(nstates+1)-0.5, 
                                   density=True, alpha=0.7, edgecolor='black')
# Color the true state bar in red
patches[1].set_facecolor('red')
patches[1].set_alpha(0.8)
ax2.set_xlabel('State Index')
ax2.set_ylabel('Posterior Probability')  # unit-width bins, so density equals probability
ax2.set_title(f'State Posterior Distribution\n(after {burnin}-step burn-in)')
ax2.set_xticks([0, 1, 2])
ax2.set_xticklabels(['No\nfault', 'One\nfault', 'Two\nfaults'])
ax2.grid(True, alpha=0.3, axis='y')

# Add probability text on bars
state_probs_burnin = np.bincount(state_samples_burnin.astype(int), minlength=nstates) / len(state_samples_burnin)
for i, prob in enumerate(state_probs_burnin):
    ax2.text(i, prob + 0.05, f'{prob:.3f}', ha='center', fontsize=12, fontweight='bold')

# 3. Cumulative state probability (middle center)
ax3 = fig.add_subplot(gs[1, 1])
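# Running fraction of visits to each state, pooled over walkers (includes burn-in).
# Count walkers in each state at every step, then accumulate over steps; this equals
# calling np.bincount on state_chains[:, :step+1] at every step, without the O(T^2) loop.
state_chains_int = state_chains.astype(int)
visits_per_step = np.stack([(state_chains_int == s).sum(axis=0) for s in range(nstates)])
cumulative_counts = np.cumsum(visits_per_step, axis=1)
cumulative_visits = cumulative_counts / cumulative_counts.sum(axis=0, keepdims=True)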

colors = ['blue', 'red', 'green']
for state in range(nstates):
    ax3.plot(cumulative_visits[state, :], label=state_names[state], 
             color=colors[state], linewidth=2, alpha=0.8)
ax3.set_xlabel('Iteration')
ax3.set_ylabel('Cumulative Probability')
ax3.set_title('Convergence of State Probabilities')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# 4. Model comparison via Bayes factors (middle right)
ax4 = fig.add_subplot(gs[1, 2])
# Bayes factors relative to State 1 (true model): BF_i1 = P(State i | data) / P(State 1 | data).
# This posterior ratio equals the Bayes factor because the prior over states is uniform.
bayes_factors = state_probs_burnin / state_probs_burnin[1]
bar_colors = ['blue' if i != 1 else 'red' for i in range(nstates)]
bars = ax4.bar(range(nstates), bayes_factors, color=bar_colors, alpha=0.7, edgecolor='black')
ax4.axhline(y=1, color='gray', linestyle='--', linewidth=1)
ax4.set_xlabel('State Index')
ax4.set_ylabel('Bayes Factor (relative to State 1)')
ax4.set_title('Bayes Factors\n(Evidence Ratios)')
ax4.set_xticks([0, 1, 2])
ax4.set_xticklabels(['No\nfault', 'One\nfault', 'Two\nfaults'])
ax4.set_yscale('log')
ax4.grid(True, alpha=0.3, axis='y')

# Add values on bars
for i, bf in enumerate(bayes_factors):
    ax4.text(i, bf * 1.2, f'{bf:.2f}', ha='center', fontsize=10, fontweight='bold')

# 5-7. Gravity residuals for each model (bottom row)
for state in range(nstates):
    ax = fig.add_subplot(gs[2, state])
    
    # Compute predicted gravity for this state using mean amplitude from posterior
    mean_amplitude = ensembles[state][:, 0].mean()
    predicted = compute_gravity_for_state(state, mean_amplitude)
    residual = observed_gravity - predicted
    
    # Plot residual map
    im = ax.imshow(residual.reshape(grav_res, grav_res),
                   extent=(extent_vals[0], extent_vals[1], extent_vals[2], extent_vals[3]),
                   cmap='seismic', origin='lower', aspect='auto',
                   vmin=-2*noise_std, vmax=2*noise_std)
    ax.set_title(f'{state_names[state]}\nRMS={np.sqrt(np.mean(residual**2)):.3f} μgal')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.colorbar(im, ax=ax, label='Residual (μgal)')

plt.suptitle('Trans-Conceptual MCMC: Gravity Model Selection', fontsize=16, fontweight='bold', y=0.995)
plt.show()

# Print final summary
print("\n" + "=" * 70)
print("FINAL SUMMARY")
print("=" * 70)
print(f"\nTrue model: State 1 (One fault)")
print(f"\nPosterior probabilities (after burn-in):")
for i in range(nstates):
    print(f"  State {i} ({state_names[i]:12s}): {state_probs_burnin[i]:.4f}")
print(f"\nBayes factors (relative to State 1):")
for i in range(nstates):
    print(f"  State {i} ({state_names[i]:12s}): {bayes_factors[i]:.2f}")
    
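# Interpretation: compare the most probable state with the true state (State 1)
best_state = int(np.argmax(state_probs_burnin))
if state_probs_burnin[1] > 0.5:
    print("\n✓ SUCCESS: Trans-Conceptual sampling correctly identified State 1 (One fault)")
    print(f"  with posterior probability {state_probs_burnin[1]:.4f}")
elif best_state == 1:
    print(f"\n~ State 1 (One fault) is the most probable model, but its posterior probability "
          f"{state_probs_burnin[1]:.4f} is below 0.5, so the evidence is not decisive")
else:
    print(f"\n⚠ CAUTION: Trans-Conceptual sampling assigned the highest probability to State {best_state}")
    print("  but the true model was State 1 (One fault)")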

---

## Conclusions and Next Steps

### What We Demonstrated

This notebook showed how Trans-Conceptual MCMC can:
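1. **Simultaneously infer parameters and select models** - The trans-conceptual chain yields posterior model probabilities directly, without a separate evidence (marginal likelihood) calculation for each model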
2. **Quantify model uncertainty** - Get posterior probabilities, not just point estimates
3. **Handle model complexity** - Automatically account for different model complexities through the evidence
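To see why the evidence penalizes unnecessary flexibility, consider a toy problem that is separate from the gravity model: one noisy observation y, a model A whose mean is fixed at 0, and a model B whose mean μ is free with a uniform prior on [-W, W]. The evidence of B is its likelihood averaged over the prior, so widening the prior lowers it even though B's best fit (μ = y) is always better than A's. The setup and names below are illustrative assumptions; the integral is evaluated in closed form with the normal CDF.

```python
from math import erf, exp, pi, sqrt


def normal_cdf(x):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + erf(x / sqrt(2.0)))


def evidence_fixed(y, sigma):
    """Model A: mean fixed at 0, no free parameters, so evidence = likelihood."""
    return exp(-0.5 * (y / sigma) ** 2) / (sigma * sqrt(2.0 * pi))


def evidence_free(y, sigma, half_width):
    """Model B: mean mu ~ Uniform(-W, W); evidence = integral of L(mu) p(mu) over mu."""
    mass = normal_cdf((half_width - y) / sigma) - normal_cdf((-half_width - y) / sigma)
    return mass / (2.0 * half_width)


y, sigma = 0.5, 1.0  # one toy observation and its assumed noise level
for w in (1.0, 10.0, 100.0):
    bf = evidence_free(y, sigma, w) / evidence_fixed(y, sigma)
    print(f"W = {w:6.1f}: BF(free mean vs fixed mean) = {bf:.3f}")
```

The Bayes factor for the flexible model drops as its prior widens, which is the Occam's razor effect that the trans-conceptual posterior inherits.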

### Extensions and Variations

You can modify this example to explore:

**1. More complex parameterizations:**
- Add more parameters (fault location, dip angle, layer depths, density values)
- Use more realistic priors based on geological knowledge

**2. Different noise levels:**
- Increase noise to see when models become indistinguishable
- Add correlated noise to simulate systematic errors

**3. More models:**
- Add models with different numbers of layers
- Include models with intrusions or other geological features

**4. Alternative samplers:**
- Try `run_product_space_sampler()` for simultaneous parameter space exploration
- Try `run_state_jump_sampler()` for RJ-MCMC style trans-dimensional sampling

**5. Real data:**
- Replace synthetic observations with real gravity survey data
- Incorporate data uncertainties and measurement locations
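For the correlated-noise extension listed above, one option (an assumption, not something this notebook currently does) is to draw noise from a zero-mean Gaussian with an exponential spatial covariance over the observation grid, and to score models with the full-covariance Gaussian log-likelihood instead of an independent one. The helper names, the noise level `sigma`, the correlation length, and the station coordinates below are hypothetical; the grid mirrors `grav_res = 20` over the 1000 m × 1000 m extent used in this notebook.

```python
import numpy as np


def exponential_covariance(x, y, sigma, corr_length):
    """Covariance C_ij = sigma^2 * exp(-|r_i - r_j| / corr_length) between points (x, y)."""
    pts = np.column_stack([x, y])
    dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
    return sigma**2 * np.exp(-dist / corr_length)


def gaussian_log_likelihood(residual, cov):
    """log N(residual; 0, cov), evaluated through a Cholesky factorization."""
    chol = np.linalg.cholesky(cov)
    z = np.linalg.solve(chol, residual)  # whitened residual L^{-1} r
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    n = residual.size
    return -0.5 * (z @ z + log_det + n * np.log(2.0 * np.pi))


# Hypothetical 20 x 20 station grid (assumed coordinates in metres)
grav_res = 20
xs, ys = np.meshgrid(np.linspace(0, 1000, grav_res), np.linspace(0, 1000, grav_res))
cov = exponential_covariance(xs.ravel(), ys.ravel(), sigma=0.1, corr_length=200.0)

# Correlated noise sample: L @ z with z ~ N(0, I) has covariance L L^T = cov
rng = np.random.default_rng(42)
correlated_noise = np.linalg.cholesky(cov) @ rng.standard_normal(grav_res * grav_res)
print(gaussian_log_likelihood(correlated_noise, cov))
```

With a diagonal covariance this reduces to the usual independent-noise likelihood, so the same function covers both cases.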

### Key Takeaways

✓ Trans-Conceptual sampling provides a principled way to compare models with different structures

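✓ The ensemble resampler suits problems where forward model evaluation is expensive: the state jumps reuse the precomputed per-state ensembles, so the trans-conceptual stage need not call the forward model again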

✓ Posterior model probabilities account for both fit quality and model complexity (via Occam's razor)

✓ The methodology naturally handles model uncertainty in geophysical inverse problems

---

## Part 5: Results and Model Selection

### Interpreting the Results

The ensemble resampler output provides:

1. **State chains**: Trace of which model each walker visited at each iteration
2. **Posterior model probabilities**: Fraction of samples in each state (after burn-in)
3. **Bayes factors**: Relative evidence ratios between models
4. **Acceptance rates**: Diagnostic for sampler efficiency
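The quantities listed above can be computed from the state chains alone. The sketch below assumes `state_chains` is an integer array of shape (n_walkers, n_steps), as in the analysis cell; the helper name `summarize_state_chains` is hypothetical. Since it does not assume access to the sampler's acceptance flags, it reports the fraction of steps on which a walker changed state, which is a lower bound on the acceptance rate. It also divides out the prior odds when forming Bayes factors, which changes nothing for the uniform prior used here.

```python
import numpy as np


def summarize_state_chains(state_chains, n_states, burn_frac=0.1, ref_state=0, prior=None):
    """Hypothetical helper: summary statistics from an (n_walkers, n_steps) array of states."""
    chains = np.asarray(state_chains, dtype=int)
    burn = int(burn_frac * chains.shape[1])
    kept = chains[:, burn:]
    # Pooled posterior model probabilities: fraction of post-burn-in samples in each state
    probs = np.bincount(kept.ravel(), minlength=n_states) / kept.size
    # Per-walker estimates; a large spread means the walkers disagree (poor mixing)
    per_walker = np.stack([np.bincount(w, minlength=n_states) / w.size for w in kept])
    # Bayes factor = posterior odds / prior odds (prior odds are 1 for a uniform prior).
    # Division produces inf/nan warnings if the reference state was never visited.
    prior = np.full(n_states, 1.0 / n_states) if prior is None else np.asarray(prior, dtype=float)
    bayes_factors = (probs / probs[ref_state]) / (prior / prior[ref_state])
    # Fraction of steps on which a walker changed state (lower bound on the acceptance rate)
    switch_rate = float(np.mean(kept[:, 1:] != kept[:, :-1]))
    return {
        "probs": probs,
        "spread": per_walker.std(axis=0),
        "bayes_factors": bayes_factors,
        "switch_rate": switch_rate,
    }


# Toy example: two short walkers that mostly sit in state 1
toy_chains = np.array([[0, 1, 1, 1, 2, 1, 1, 1, 1, 1],
                       [1, 1, 0, 1, 1, 1, 2, 1, 1, 1]])
summary = summarize_state_chains(toy_chains, n_states=3, burn_frac=0.0, ref_state=1)
print(summary)
```

Called as `summarize_state_chains(state_chains, nstates, ref_state=1)`, it should reproduce `state_probs_burnin` and `bayes_factors` from the analysis cell, with the walker spread and switch rate added as mixing diagnostics.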

**Success criteria:**
- Model 1 (one fault) should have the highest posterior probability
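- Bayes factors of the other models relative to Model 1 should be well below 1 (below 0.1 corresponds to strong evidence for Model 1 on the interpretation scale in the Mathematical Framework)
- State chains should show good mixing: walkers should switch between models repeatedly, and the cumulative state probabilities should level off
- Gravity residuals should be smallest for Model 1, with an RMS close to the noise level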