## Example of Trans-C sampling with Pool Parameter Exposure: Regression

This notebook demonstrates the new `pool` parameter, which lets you pass your own worker pool
(for example a `multiprocessing.Pool` or a schwimmbad pool) to `run_mcmc_per_state` and `run_ensemble_resampler`.

In [1]:
# general python utility packages
import time
from collections import Counter
from functools import partial
from multiprocessing import Pool

import corner
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from scipy.optimize import minimize

In [2]:
from pytransc.analysis.samples import resample_ensembles
from pytransc.analysis.visits import (
    count_state_changes,
    count_total_state_changes,
    get_acceptance_rate_between_states,
    get_autocorr_between_state_jumps,
    get_relative_marginal_likelihoods,
    get_visits_to_states,
)
from pytransc.samplers import run_ensemble_resampler, run_mcmc_per_state
from pytransc.utils.auto_pseudo import build_auto_pseudo_prior

In [3]:
# Probe for the optional schwimmbad dependency; the flag gates the
# schwimmbad-specific cells further down.
try:
    from schwimmbad import MPIPool, MultiPool
except ImportError:
    SCHWIMMBAD_AVAILABLE = False
    print("⚠ schwimmbad not available - will use standard multiprocessing only")
else:
    SCHWIMMBAD_AVAILABLE = True
    print("✓ schwimmbad available - will demonstrate MPI and multiprocessing pools")

✓ schwimmbad available - will demonstrate MPI and multiprocessing pools


### Set up the regression problem (same as original example)

In [4]:
def solveLS(d, G, Cdinv, mu=None, Cmpriorinv=None):
    """Weighted least-squares solution, optionally regularised by a Gaussian prior.

    Solves ``(G^T Cdinv G [+ Cmpriorinv]) m = G^T Cdinv d [+ Cmpriorinv mu]``.
    The bracketed prior terms are included only when BOTH ``mu`` and
    ``Cmpriorinv`` are supplied.

    Parameters
    ----------
    d : data vector.
    G : design matrix.
    Cdinv : inverse data covariance matrix.
    mu : prior mean (optional).
    Cmpriorinv : inverse prior covariance matrix (optional).

    Returns
    -------
    (mls, GtGinv) : the solution vector and the inverse of the
        left-hand-side matrix.
    """
    with_prior = (Cmpriorinv is not None) and (mu is not None)

    weighted_Gt = np.dot(np.transpose(G), Cdinv)
    normal_matrix = np.dot(weighted_Gt, G)
    rhs = np.dot(weighted_Gt, d)
    if with_prior:
        normal_matrix += Cmpriorinv
        rhs += np.dot(Cmpriorinv, mu)

    normal_inverse = np.linalg.inv(normal_matrix)
    return np.dot(normal_inverse, rhs), normal_inverse

def getG(x, order):
    """Polynomial design matrix: column ``k`` holds ``x ** k`` for ``k = 0..order``."""
    powers = [x ** exponent for exponent in range(order + 1)]
    return np.asarray(powers).T

In [5]:
# Synthetic data: a noisy straight line sampled at sorted random abscissae.
# The legacy global seed fixes both the abscissae and the noise draw.
np.random.seed(61254557)
ndata = 20
sigma = 0.2
xobs = np.sort(np.random.rand(ndata))
mtrue = np.array([0.3, 0.6])
G = getG(xobs, 1)
noise = np.random.normal(0, sigma, size=ndata)
dobs = G.dot(mtrue) + noise
Cdinv = np.eye(ndata) / sigma**2

# Candidate states are polynomials of order 0..maxorder, so state k has
# k + 1 coefficients, each given a zero-mean prior with covariance 20 * I.
maxorder = 3
nstates = maxorder + 1
ndims = [1, 2, 3, 4]
Cmprior = [20 * np.eye(ncoef) for ncoef in range(1, nstates + 1)]
muprior = [np.zeros(ncoef) for ncoef in range(1, nstates + 1)]

In [6]:
# Define log-posterior functions
def _log_likelihood(x, state, dobs, G, Cdinv):
    dpred = np.dot(G[state], x)
    misfit = 0.5 * (np.dot((dobs - dpred), np.dot(Cdinv, (dobs - dpred))))
    f = np.sqrt(2 * np.pi) ** len(dpred)
    detCdinv = np.linalg.det(Cdinv)
    logL = -np.log(f) - misfit + np.log(detCdinv)
    return logL

def _log_prior(x, state, muprior, Cmprior):
    mu = muprior[state]
    cov = Cmprior[state]
    rv = stats.multivariate_normal(mean=mu, cov=cov)
    return rv.logpdf(x)

# Bind the fixed problem data so that the likelihood and the prior are both
# callable as f(x, state). Gp holds one design matrix per polynomial order.
Gp = [getG(xobs, order) for order in range(maxorder + 1)]
log_prior = partial(_log_prior, muprior=muprior, Cmprior=Cmprior)
log_likelihood = partial(_log_likelihood, dobs=dobs, G=Gp, Cdinv=Cdinv)

def log_posterior(x, state):
    """Log-posterior of ``x`` in ``state``: log-likelihood plus log-prior."""
    likelihood_term = log_likelihood(x, state)
    prior_term = log_prior(x, state)
    return likelihood_term + prior_term

In [7]:
# Seed the walkers near the minimiser of the negative log-posterior of each state
rng = np.random.default_rng(42)


def nll(*args):
    """Negative log-posterior; the objective handed to `minimize`."""
    return -log_posterior(*args)


ml = [
    minimize(nll, 0.5 * np.ones(state + 1), args=(state,)).x
    for state in range(nstates)
]

# Sampler settings (reduced for the demo)
nsamples_es = [20000] * 4
nwalkers_es = [16] * 4

# One tight Gaussian ball of walkers around each state's optimum
pos = [
    ml[state] + 1e-4 * rng.standard_normal((nwalkers_es[state], ndims[state]))
    for state in range(nstates)
]

## Pool Parameter Demonstration

The new `pool` parameter lets you supply your own parallelization strategy. The sections below show three ways of running the sampler:

### Method 1: Standard multiprocessing.Pool (user-managed)

In [None]:
print("=== Method 1: User-managed multiprocessing.Pool ===")
start_time = time.time()

# The user creates and owns the pool. The `with` block guarantees the worker
# processes are shut down once sampling finishes (or raises), so no workers
# are left running in the background.
with Pool(processes=4) as user_pool:
    ensemble_per_state_1, log_posterior_ens_1 = run_mcmc_per_state(
        nstates,
        ndims,
        nwalkers_es,
        nsamples_es,
        pos,
        log_posterior,
        pool=user_pool,  # ← NEW: Pass user-managed pool
        discard=0,
        auto_thin=True,
        verbose=True,
    )

elapsed_1 = time.time() - start_time
print(f"✓ Completed with user-managed pool in {elapsed_1:.1f}s")
print(f"Samples per state: {[len(ens) for ens in ensemble_per_state_1]}")

=== Method 1: User-managed multiprocessing.Pool ===

Running within-state sampler separately on each state

Number of walkers               :  [16, 16, 16, 16]

Number of states being sampled:  4
Dimensions of each state:  [1, 2, 3, 4]


 82%|██████████████████████████████████████████████████████████▎            | 16428/20000 [01:56<00:27, 127.88it/s]

### Method 2: schwimmbad MultiPool (if available)

In [9]:
if not SCHWIMMBAD_AVAILABLE:
    # Without schwimmbad, reuse the Method 1 results so the Method 2 names
    # are still defined.
    print("⚠ Skipping schwimmbad example (not available)")
    ensemble_per_state_2, log_posterior_ens_2 = (
        ensemble_per_state_1,
        log_posterior_ens_1,
    )
else:
    print("=== Method 2: schwimmbad MultiPool ===")
    start_time = time.time()

    # The schwimmbad pool is passed through the same `pool` keyword
    with MultiPool(processes=4) as schwimm_pool:
        ensemble_per_state_2, log_posterior_ens_2 = run_mcmc_per_state(
            nstates,
            ndims,
            nwalkers_es,
            nsamples_es,
            pos,
            log_posterior,
            pool=schwimm_pool,
            discard=0,
            auto_thin=True,
            verbose=True,
        )

    elapsed_2 = time.time() - start_time
    print(f"✓ Completed with schwimmbad MultiPool in {elapsed_2:.1f}s")
    print(f"Samples per state: {[len(ens) for ens in ensemble_per_state_2]}")

=== Method 2: schwimmbad MultiPool ===

Running within-state sampler separately on each state

Number of walkers               :  [16, 16, 16, 16]

Number of states being sampled:  4
Dimensions of each state:  [1, 2, 3, 4]


 29%|█████████████████████▏                                                   | 5791/20000 [01:33<03:49, 61.83it/s]


KeyboardInterrupt: 

### Method 3: Library-managed (backward compatibility)

In [None]:
print("=== Method 3: Library-managed (existing behavior) ===")
start_time = time.time()

# No `pool` argument is given here. Parallelism is requested through the
# pre-existing `parallel` / `n_processors` options instead.
library_managed_options = dict(
    parallel=True,
    n_processors=4,
    discard=0,
    auto_thin=True,
    verbose=True,
)
ensemble_per_state_3, log_posterior_ens_3 = run_mcmc_per_state(
    nstates,
    ndims,
    nwalkers_es,
    nsamples_es,
    pos,
    log_posterior,
    **library_managed_options,
)

elapsed_3 = time.time() - start_time
print(f"✓ Completed with library-managed pool in {elapsed_3:.1f}s")
print(f"Samples per state: {[len(ens) for ens in ensemble_per_state_3]}")

### Run ensemble resampler with pool parameter

In [None]:
# The ensemble resampler is fed the samples produced by Method 1
ensemble_per_state = ensemble_per_state_1
log_posterior_ens = log_posterior_ens_1

# Build the automatic pseudo-prior from the ensembles, then evaluate it at
# every sample of every state (one array per state).
log_pseudo_prior = build_auto_pseudo_prior(ensemble_per_state=ensemble_per_state)
log_pseudo_prior_ens = [
    np.array([log_pseudo_prior(sample, state) for sample in state_ensemble])
    for state, state_ensemble in enumerate(ensemble_per_state)
]

In [None]:
print("=== Ensemble Resampler with Pool Parameter ===")

nwalkers_er, nsteps_er = 16, 50000
start_time = time.time()

# The resampler accepts a user-managed pool through the same `pool` keyword.
# The `with` block shuts the workers down when the run ends.
with Pool(processes=4) as resampler_pool:
    resampler_chains = run_ensemble_resampler(
        nwalkers_er,
        nsteps_er,
        nstates,
        ndims,
        pool=resampler_pool,
        log_posterior_ens=log_posterior_ens,
        log_pseudo_prior_ens=log_pseudo_prior_ens,
        progress=True,
    )

elapsed_ensemble = time.time() - start_time
print(f"✓ Ensemble resampler completed in {elapsed_ensemble:.1f}s")

### Results and comparison

In [None]:
# Extract results. `get_relative_marginal_likelihoods` comes from the import
# cell at the top of the notebook, so it is not re-imported here.
# NOTE(review): the slice presumably selects the final entry along the second
# axis of `state_chain_tot` — confirm the axis meaning against pytransc.
relative_marginal_likelihoods = get_relative_marginal_likelihoods(
    resampler_chains.state_chain_tot[:, -1, :]
)

print("\n=== Results Summary ===")
print(f"Estimated relative evidences: {np.round(relative_marginal_likelihoods, 4)}")
print("\n=== Pool Parameter Benefits ===")
print("✓ User controls pool lifecycle (creation/cleanup)")
print("✓ Can use schwimmbad for MPI, cluster computing")
print("✓ Backward compatible - existing code unchanged")
print("✓ Avoids memory leaks from unclosed pools")

## MPI Example (conceptual)

For cluster/MPI usage with schwimmbad:

In [10]:
# Conceptual MPI example: the snippet is only printed, never executed, because
# it has to be launched through mpirun to actually work.
mpi_usage_snippet = """
# For MPI execution:
from schwimmbad import MPIPool

# Run with: mpirun -n 8 python your_script.py
with MPIPool() as mpi_pool:
    ensemble_per_state, log_posterior_ens = run_mcmc_per_state(
        nstates, ndims, nwalkers_es, nsamples_es, pos, log_posterior,
        pool=mpi_pool,  # Uses MPI across cluster nodes
        discard=0, auto_thin=True
    )
"""

print("\n=== MPI Example (conceptual) ===")
print(mpi_usage_snippet)

if not SCHWIMMBAD_AVAILABLE:
    print("⚠ Install schwimmbad for MPI support: pip install schwimmbad")
else:
    print("✓ schwimmbad available - MPI example would work")


=== MPI Example (conceptual) ===

# For MPI execution:
from schwimmbad import MPIPool

# Run with: mpirun -n 8 python your_script.py
with MPIPool() as mpi_pool:
    ensemble_per_state, log_posterior_ens = run_mcmc_per_state(
        nstates, ndims, nwalkers_es, nsamples_es, pos, log_posterior,
        pool=mpi_pool,  # Uses MPI across cluster nodes
        discard=0, auto_thin=True
    )

✓ schwimmbad available - MPI example would work


## Key Implementation Notes

1. **Pool Parameter Precedence**: When `pool` is provided, it takes precedence over `parallel` and `n_processors`
2. **Memory Management**: User-provided pools are not closed by the library - user manages lifecycle
3. **Backward Compatibility**: All existing code continues to work unchanged
4. **Type Flexibility**: `pool: Any | None = None` accepts any pool-like object with a `map()` method