# BIOSTAT 682 HW4 - Bayesian Neural Networks for Crime Data (5k draws/tune)

This notebook fits Bayesian neural networks with spike-and-slab priors to the UScrime dataset.

**Note:** This version uses a fixed 5,000 draws and 5,000 tuning steps for every grid search combination.

**Workflow:**
1. Load and preprocess data (standardize, train/test split)
2. **Parallelized** grid search over prior types and hidden units (draws/tune fixed at 5,000)
3. Compare BNN models with different hidden units using DIC
4. Evaluate test set performance
5. Compare with Bayesian linear regression

## 1. Setup and Imports

In [None]:
import time
import datetime
import os
import re
import warnings
import pickle
from pathlib import Path
from joblib import Parallel, delayed, dump, load
import multiprocessing as mp

import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

warnings.filterwarnings('ignore')

# --- Reproducibility and parallelism ---
SEED = 2025
np.random.seed(SEED)
N_WORKERS = max(1, mp.cpu_count() - 1)  # keep one core free for the notebook itself

# --- Where fitted models are written ---
MODELS_DIR = Path("../../data/models/Solution1_5k")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# --- Log file names (relative to the working directory) ---
GRIDSEARCH_LOG = "bnn_gridsearch_5k.log"
GENERAL_LOG = "crime_bnn_optimized_run_5k.log"

# Empty any log left over from a previous run; the files themselves are kept.
# NOTE(review): GRIDSEARCH_LOG is emptied here too, so the "load existing
# results" branch of the grid-search cell can only trigger when that cell is
# re-run without re-running this one.
for log_file in (GRIDSEARCH_LOG, GENERAL_LOG):
    if not os.path.exists(log_file):
        continue
    with open(log_file, "w"):
        pass  # opening in "w" mode is what truncates the file
    print(f"Cleared log file: {log_file}")

EXEC_START = time.time()
print(f"Started: {datetime.datetime.now().isoformat(timespec='seconds')}")
print(f"Available workers for parallel grid search: {N_WORKERS}")
print(f"Models will be saved to: {MODELS_DIR}")
print(f"Log files cleared and ready for new run")

## 2. Data Loading and Preprocessing

Load the UScrime dataset, standardize features and target, and create a 50/50 train/test split.

In [None]:
# Read the UScrime table: column 'y' is the response, every other column a predictor
df = pd.read_csv('../../data/UScrime.csv')
y = df['y'].values
X = df.drop('y', axis=1).values

# Z-score predictors and response with separate scalers
# NOTE(review): both scalers are fit on all rows before the split below, so the
# held-out rows contribute to the scaling statistics.
scaler_X = StandardScaler()
scaler_y = StandardScaler()
X_scaled = scaler_X.fit_transform(X)
y_column = y.reshape(-1, 1)  # StandardScaler expects a 2-D array
y_scaled = scaler_y.fit_transform(y_column).flatten()

# 50/50 train/test split, seeded for reproducibility
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_scaled, random_state=SEED, test_size=0.5
)

n_features = X_scaled.shape[1]
print(f"Data: {X_scaled.shape[0]} observations, {n_features} features")
print(f"Train: {len(y_train)}, Test: {len(y_test)}")

## 3. Model Definitions

Two BNN architectures with spike-and-slab priors:

1. **Current priors**: Supports centered/non-centered parameterization with Bernoulli selection
2. **HW3 attempt2 priors**: Precision-based spike-and-slab using inverse precision

Both use a single hidden layer with tanh activation.

In [None]:
def create_bnn_spike_slab(X_train, y_train, q, X_test=None, use_noncentered=True):
    """
    Build a one-hidden-layer (tanh) BNN whose weights carry spike-and-slab priors.

    Every weight has a Bernoulli(0.5) inclusion indicator that switches its
    prior standard deviation between a narrow spike (0.01) and a wide slab (1.0).

    Args:
        X_train (np.ndarray): Training features of shape (n, p)
        y_train (np.ndarray): Training targets, used as the observed data
        q (int): Number of hidden units
        X_test (np.ndarray, optional): If given, an unobserved "y_pred" variable
            with one entry per test row is added to the model. Defaults to None.
        use_noncentered (bool): If True, weights are built as
            ``raw * sd`` with ``raw ~ Normal(0, 1)`` (variables "W1_raw"/"W2_raw"
            plus deterministic "W1"/"W2"); otherwise "W1"/"W2" are drawn
            directly from ``Normal(0, sd)``. Defaults to True.

    Returns:
        pm.Model: The assembled PyMC model (not yet sampled)

    Raises:
        ValueError: If X_train is not 2-dimensional (shape unpacking fails)
    """
    _, p = X_train.shape
    inclusion_prob_1, inclusion_prob_2 = 0.5, 0.5
    spike_sd, slab_sd = 0.01, 1.0

    def _indicator_sd(indicator):
        # spike_sd when the indicator is 0, slab_sd when it is 1
        return spike_sd + indicator * (slab_sd - spike_sd)

    def _weights(name, scale, shape):
        # Register the weight variable(s) under `name` in the active model
        if not use_noncentered:
            return pm.Normal(name, mu=0, sigma=scale, shape=shape)
        raw = pm.Normal(f"{name}_raw", mu=0, sigma=1, shape=shape)
        return pm.Deterministic(name, raw * scale)

    with pm.Model() as model:
        # Input -> hidden
        gamma1 = pm.Bernoulli("gamma1", p=inclusion_prob_1, shape=(p, q))
        W1 = _weights("W1", _indicator_sd(gamma1), (p, q))
        b1 = pm.Normal("b1", mu=0, sigma=1, shape=q)

        # Hidden -> output
        gamma2 = pm.Bernoulli("gamma2", p=inclusion_prob_2, shape=q)
        W2 = _weights("W2", _indicator_sd(gamma2), q)
        b2 = pm.Normal("b2", mu=0, sigma=1)

        def _network(inputs):
            # Forward pass: tanh hidden layer followed by a linear output
            activations = pm.math.tanh(pm.math.dot(inputs, W1) + b1)
            return pm.math.dot(activations, W2) + b2

        # Gaussian likelihood with a half-normal noise scale
        sigma = pm.HalfNormal("sigma", sigma=1)
        pm.Normal("y_obs", mu=_network(X_train), sigma=sigma, observed=y_train)

        # Optional unobserved predictions for the test rows
        if X_test is not None:
            pm.Normal("y_pred", mu=_network(X_test), sigma=sigma, shape=X_test.shape[0])

    return model


def create_bnn_hw3_attempt2(X_train, y_train, q, X_test=None):
    """
    Build a one-hidden-layer (tanh) BNN with precision-parameterised spike-and-slab priors.

    Each weight's prior variance is the reciprocal of a precision that equals
    1000 when its Bernoulli(0.5) indicator is 0 (spike) and 0.01 when it is 1
    (slab). The noise precision "inv_sigma2" has a Gamma(0.0001, 0.0001) prior.

    Args:
        X_train (np.ndarray): Training features of shape (n, p)
        y_train (np.ndarray): Training targets, used as the observed data
        q (int): Number of hidden units
        X_test (np.ndarray, optional): If given, an unobserved "y_pred" variable
            with one entry per test row is added to the model. Defaults to None.

    Returns:
        pm.Model: The assembled PyMC model (not yet sampled)

    Raises:
        ValueError: If X_train is not 2-dimensional (shape unpacking fails)
    """
    _, p = X_train.shape
    inv_tau2_spike, inv_tau2_slab = 1000.0, 0.01

    def _prior_sd(indicator):
        # Precision mixes spike/slab by the indicator; sd = sqrt(1 / precision)
        precision = (1 - indicator) * inv_tau2_spike + indicator * inv_tau2_slab
        return pm.math.sqrt(1.0 / precision)

    with pm.Model() as model:
        # NOTE(review): "alpha" is registered in the model but never enters mu;
        # b2 is the only intercept that affects the likelihood.
        alpha = pm.Normal("alpha", mu=0.0, sigma=100.0)

        # Input -> hidden
        gamma1 = pm.Bernoulli("gamma1", p=0.5, shape=(p, q))
        W1 = pm.Normal("W1", mu=0.0, sigma=_prior_sd(gamma1), shape=(p, q))
        b1 = pm.Normal("b1", mu=0, sigma=1, shape=q)

        # Hidden -> output
        gamma2 = pm.Bernoulli("gamma2", p=0.5, shape=q)
        W2 = pm.Normal("W2", mu=0.0, sigma=_prior_sd(gamma2), shape=q)
        b2 = pm.Normal("b2", mu=0, sigma=1)

        def _network(inputs):
            # Forward pass: tanh hidden layer followed by a linear output
            activations = pm.math.tanh(pm.math.dot(inputs, W1) + b1)
            return pm.math.dot(activations, W2) + b2

        # Gaussian likelihood; noise sd derived from the Gamma-distributed precision
        inv_sigma2 = pm.Gamma("inv_sigma2", alpha=0.0001, beta=0.0001)
        sigma = pm.math.sqrt(1.0 / inv_sigma2)
        pm.Normal("y_obs", mu=_network(X_train), sigma=sigma, observed=y_train)

        # Optional unobserved predictions for the test rows
        if X_test is not None:
            pm.Normal("y_pred", mu=_network(X_test), sigma=sigma, shape=X_test.shape[0])

    return model

## 4. Helper Functions

Utilities for computing DIC, convergence diagnostics, and parsing grid search logs.

In [None]:
def compute_dic(idata):
    """
    Compute the Deviance Information Criterion (DIC) from pointwise log-likelihoods.

    For each posterior draw the deviance is ``D = -2 * sum_i log p(y_i | theta)``.
    The criterion returned is ``DIC = mean(D) + p_V`` with the effective number
    of parameters estimated as ``p_V = Var(D) / 2``. This variant is used
    because the plug-in deviance at the posterior mean cannot be recovered from
    stored pointwise log-likelihood values alone.

    Args:
        idata (arviz.InferenceData): Inference data whose ``log_likelihood``
            group holds "y_obs" with the observation axis last
            (typically chain x draw x observation)

    Returns:
        float: DIC value (lower is better)

    Raises:
        KeyError: If 'y_obs' log likelihood is not found in idata
        AttributeError: If idata does not have log_likelihood attribute
    """
    pointwise = np.asarray(idata.log_likelihood["y_obs"].values, dtype=float)
    # Collapse every leading (chain/draw) axis into one "sample" axis
    pointwise = pointwise.reshape(-1, pointwise.shape[-1])

    # Deviance of each posterior draw: sum over observations, not an average
    deviance = -2.0 * pointwise.sum(axis=1)

    mean_deviance = deviance.mean()
    effective_params = 0.5 * deviance.var()
    return float(mean_deviance + effective_params)


def compute_diagnostics(idata, exclude_vars=('y_pred', 'gamma1', 'gamma2')):
    """
    Summarise sampler health as (worst R-hat, smallest bulk ESS, divergence count).

    Args:
        idata (arviz.InferenceData): Inference data object
        exclude_vars (tuple): Variable names left out of the R-hat/ESS summary.
            Defaults to ('y_pred', 'gamma1', 'gamma2').

    Returns:
        tuple: (max_rhat, min_ess, n_divergences)
            - max_rhat (float): Largest R-hat over the retained variables
            - min_ess (float): Smallest bulk effective sample size over them
            - n_divergences (int): Total of ``sample_stats.diverging``, or 0
              when that statistic is absent
            Returns (nan, nan, 0) when no variable remains after exclusion.

    Raises:
        AttributeError: If idata does not have required attributes
    """
    rhat_ds = az.rhat(idata)
    monitored = [name for name in rhat_ds.data_vars if name not in exclude_vars]

    # Nothing left to summarise
    if len(monitored) == 0:
        return np.nan, np.nan, 0

    ess_ds = az.ess(idata, method="bulk")
    worst_rhat = max(float(rhat_ds[name].max()) for name in monitored)
    smallest_ess = min(float(ess_ds[name].min()) for name in monitored)

    if 'diverging' in idata.sample_stats:
        divergences = int(idata.sample_stats.diverging.values.sum())
    else:
        divergences = 0

    return worst_rhat, smallest_ess, divergences


def get_model_creator(prior_type):
    """
    Return appropriate model creation function based on prior type.

    Args:
        prior_type (str): Type of prior, either "current" or "hw3"

    Returns:
        callable: create_bnn_spike_slab for "current",
            create_bnn_hw3_attempt2 for "hw3"

    Raises:
        ValueError: If prior_type is not "current" or "hw3"
    """
    if prior_type == "current":
        return create_bnn_spike_slab
    if prior_type == "hw3":
        return create_bnn_hw3_attempt2
    # Previously any unknown value silently selected the hw3 model
    raise ValueError(
        f"Unknown prior_type {prior_type!r}; expected 'current' or 'hw3'"
    )


def log_general(message, log_file=GENERAL_LOG):
    """
    Append a timestamped message to the log file and echo it to the console.

    The file is written as UTF-8 so that messages containing non-ASCII status
    symbols do not fail on platforms whose default encoding cannot represent
    them.

    Args:
        message (str): Message to log
        log_file (str or Path): Path to log file. Defaults to GENERAL_LOG.

    Returns:
        None

    Raises:
        IOError: If log file cannot be written
    """
    timestamp = datetime.datetime.now().isoformat()
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"[{timestamp}] {message}\n")
    print(message)


def parse_gridsearch_log(log_file):
    """
    Parse grid search log file and return results DataFrame.

    Only successful ``[COMBO_END]`` lines are used; lines containing ``ERROR``
    or not matching the expected ``key=value`` layout are skipped.

    Args:
        log_file (str or Path): Path to grid search log file

    Returns:
        pd.DataFrame: One row per parsed line with columns: prior_type,
            use_noncentered, draws, tune, q, DIC, Rhat, minESS, Divergences,
            Converged, model_path. ``model_path`` is None when the line has no
            ``model=`` field. Empty (with these columns) when the file is
            missing or holds no usable lines.

    Raises:
        IOError: If log file cannot be read
    """
    columns = [
        'prior_type', 'use_noncentered', 'draws', 'tune', 'q',
        'DIC', 'Rhat', 'minESS', 'Divergences', 'Converged', 'model_path',
    ]
    if not os.path.exists(log_file):
        return pd.DataFrame(columns=columns)

    # Divergences must stop at a comma as well as whitespace: successful lines
    # continue with ", model=<path>", and capturing "0," would break int().
    pattern = re.compile(
        r'prior=([^,]+), use_noncentered=([^,]+), draws=([^,]+), tune=([^,]+), '
        r'q=([^,]+), DIC=([^,]+), Rhat=([^,]+), minESS=([^,]+), Divergences=([^,\s]+)'
        r'(?:, model=(.+))?'
    )
    bool_lookup = {'True': True, 'False': False, 'None': None}

    results = []
    with open(log_file, 'r') as f:
        for line in f:
            if '[COMBO_END]' not in line or 'ERROR' in line:
                continue
            match = pattern.search(line.rstrip('\r\n'))
            if not match:
                continue

            rhat = float(match.group(7))
            divs = int(match.group(9))

            results.append({
                'prior_type': match.group(1),
                'use_noncentered': bool_lookup.get(match.group(2)),
                'draws': int(match.group(3)),
                'tune': int(match.group(4)),
                'q': int(match.group(5)),
                'DIC': float(match.group(6)),
                'Rhat': rhat,
                'minESS': float(match.group(8)),
                'Divergences': divs,
                # Same convergence rule the grid search applies when fitting
                'Converged': rhat < 1.01 and divs == 0,
                'model_path': match.group(10),
            })

    return pd.DataFrame(results, columns=columns)


def get_model_filename(prior_type, use_noncentered, draws, tune, q, models_dir=MODELS_DIR):
    """
    Build the path under which a fitted configuration is stored.

    The file name has the form ``bnn_<prior>_d<draws>_t<tune>_q<q>.pkl`` where
    ``<prior>`` is ``current_nc`` / ``current_c`` for the "current" prior
    (depending on ``use_noncentered``) and ``hw3`` for any other prior type.

    Args:
        prior_type (str): Type of prior ("current" or "hw3")
        use_noncentered (bool or None): Whether non-centered parameterization is used
        draws (int): Number of draws
        tune (int): Number of tuning steps
        q (int): Number of hidden units
        models_dir (Path): Directory to save models. Defaults to MODELS_DIR.

    Returns:
        Path: Path object for the model file
    """
    prior_str = "hw3"
    if prior_type == "current":
        parameterization = 'nc' if use_noncentered else 'c'
        prior_str = f"current_{parameterization}"
    file_name = f"bnn_{prior_str}_d{draws}_t{tune}_q{q}.pkl"
    return models_dir / file_name


def save_model(idata, prior_type, use_noncentered, draws, tune, q, models_dir=MODELS_DIR):
    """
    Write inference data to disk with joblib under its conventional file name.

    Args:
        idata (arviz.InferenceData): Inference data to save
        prior_type (str): Type of prior ("current" or "hw3")
        use_noncentered (bool or None): Whether non-centered parameterization is used
        draws (int): Number of draws
        tune (int): Number of tuning steps
        q (int): Number of hidden units
        models_dir (Path): Directory to save models. Defaults to MODELS_DIR.

    Returns:
        Path: Path to saved model file

    Raises:
        IOError: If model cannot be saved
    """
    destination = get_model_filename(
        prior_type, use_noncentered, draws, tune, q, models_dir
    )
    # compress=3 trades a little CPU time for noticeably smaller files
    dump(idata, destination, compress=3)
    return destination


def load_model(prior_type, use_noncentered, draws, tune, q, models_dir=MODELS_DIR):
    """
    Read previously saved inference data back from disk with joblib.

    Args:
        prior_type (str): Type of prior ("current" or "hw3")
        use_noncentered (bool or None): Whether non-centered parameterization is used
        draws (int): Number of draws
        tune (int): Number of tuning steps
        q (int): Number of hidden units
        models_dir (Path): Directory containing models. Defaults to MODELS_DIR.

    Returns:
        arviz.InferenceData: Loaded inference data

    Raises:
        FileNotFoundError: If model file does not exist
    """
    source = get_model_filename(prior_type, use_noncentered, draws, tune, q, models_dir)
    # Guard clause: fail with a clear message before attempting to unpickle
    if not source.exists():
        raise FileNotFoundError(f"Model not found: {source}")
    # joblib.load unpickles the file; only use it on files this notebook wrote
    return load(source)

## 5. Parallelized Grid Search

Search over:
- Prior types: current (non-centered), current (centered), hw3 (attempt2)
- Draws/tune: [5000] (fixed)
- Hidden units q: [2, 3, 4, 5, 6]

Total: 15 combinations (3 prior types × 5 q values), run in parallel across available CPU cores.


In [None]:
def fit_single_config(config):
    """
    Fit a single BNN configuration. Designed to run in a separate process.

    Builds the model for the requested prior, samples it with 4 chains,
    attaches pointwise log-likelihoods, saves the inference data with joblib
    and computes DIC plus convergence diagnostics. Any exception raised while
    fitting, saving or scoring is caught and reported in the returned
    dictionary instead of propagating.

    Args:
        config (dict): Configuration dictionary with keys:
            - prior_type (str): Type of prior ("current" or "hw3")
            - use_noncentered (bool or None): Whether to use non-centered parameterization
            - draws_tune (int): Number of draws and tuning steps
            - q (int): Number of hidden units
            - X_train (np.ndarray): Training features
            - y_train (np.ndarray): Training targets
            - seed (int): Random seed
            - models_dir (Path, optional): Directory to save models. Defaults to 'models'.

    Returns:
        dict: Results dictionary with keys:
            - prior_type (str): Prior type used
            - use_noncentered (bool or None): Parameterization used
            - draws (int): Number of draws
            - tune (int): Number of tuning steps
            - q (int): Number of hidden units
            - DIC (float): Deviance Information Criterion (NaN on failure)
            - Rhat (float): Maximum R-hat value (NaN on failure)
            - minESS (float): Minimum effective sample size (NaN on failure)
            - Divergences (int): Number of divergent transitions (NaN on failure)
            - Converged (bool): Whether model converged (R-hat < 1.01 and no divergences)
            - Duration (float): Fitting duration in seconds
            - model_path (str or None): Path to saved model file, if it was written
            - Error (str or None): Error message (at most 100 chars) if fitting failed

    Raises:
        KeyError: If required keys are missing from config
    """
    import warnings
    from pathlib import Path
    from joblib import dump
    warnings.filterwarnings('ignore')

    prior_type = config['prior_type']
    use_nc = config['use_noncentered']
    draws_tune = config['draws_tune']
    q = config['q']
    X_train = config['X_train']
    y_train = config['y_train']
    seed = config['seed']
    models_dir = Path(config.get('models_dir', 'models'))
    # parents=True so a nested models directory works in a fresh checkout
    models_dir.mkdir(parents=True, exist_ok=True)

    start = datetime.datetime.now()

    # Failure defaults; overwritten step by step as the fit succeeds.
    # Key order fixes the column order of the results DataFrame.
    result = {
        'prior_type': prior_type,
        'use_noncentered': use_nc,
        'draws': draws_tune,
        'tune': draws_tune,
        'q': q,
        'DIC': np.nan,
        'Rhat': np.nan,
        'minESS': np.nan,
        'Divergences': np.nan,
        'Converged': False,
        'Duration': None,
        'model_path': None,
        'Error': None,
    }

    try:
        # Create model
        if prior_type == "current":
            model = create_bnn_spike_slab(X_train, y_train, q, use_noncentered=use_nc)
        else:
            model = create_bnn_hw3_attempt2(X_train, y_train, q)

        # Sample; cores=1 because parallelism happens across configurations
        with model:
            idata = pm.sample(
                draws=draws_tune, tune=draws_tune, chains=4, cores=1,
                target_accept=0.90, random_seed=seed, init="adapt_diag",
                return_inferencedata=True, progressbar=False
            )
            pm.compute_log_likelihood(idata)

        # Save under the same naming scheme load_model() expects
        model_path = get_model_filename(
            prior_type, use_nc, draws_tune, draws_tune, q, models_dir
        )
        dump(idata, model_path, compress=3)
        # Recorded immediately so the path survives a failure in the metrics below
        result['model_path'] = str(model_path)

        # Compute metrics
        dic = compute_dic(idata)
        max_rhat, min_ess, n_div = compute_diagnostics(idata)
        result.update({
            'DIC': dic,
            'Rhat': max_rhat,
            'minESS': min_ess,
            'Divergences': n_div,
            'Converged': max_rhat < 1.01 and n_div == 0,
        })

    except Exception as e:
        # Reset metrics so a partial success is never reported as a fit
        result.update({
            'DIC': np.nan, 'Rhat': np.nan, 'minESS': np.nan,
            'Divergences': np.nan, 'Converged': False,
        })
        # Fall back to the exception type: an empty message would be falsy and
        # make callers treat this failure as a success
        result['Error'] = str(e)[:100] or type(e).__name__

    result['Duration'] = (datetime.datetime.now() - start).total_seconds()
    return result


def run_parallel_grid_search(X_train, y_train, n_workers=None, 
                             gridsearch_log=GRIDSEARCH_LOG, 
                             general_log=GENERAL_LOG,
                             models_dir=MODELS_DIR):
    """
    Run parallelized grid search over BNN hyperparameters using joblib.
    Saves all models to disk and logs progress appropriately.

    Searches over:
    - Prior types: current (non-centered), current (centered), hw3 (attempt2)
    - Draws/tune values: [5000] (fixed)
    - Hidden units q: [2, 3, 4, 5, 6]
    Total: 15 combinations (3 prior types × 5 q values)

    The grid search log is rewritten from scratch; one ``[COMBO_END]`` line
    per configuration is appended after all workers have returned, in
    submission order.

    Args:
        X_train (np.ndarray): Training features of shape (n, p)
        y_train (np.ndarray): Training targets of shape (n,)
        n_workers (int, optional): Number of parallel workers. 
            Defaults to cpu_count() - 1.
        gridsearch_log (str or Path): Path to grid search log file. 
            Defaults to GRIDSEARCH_LOG.
        general_log (str or Path): Path to general progress log file. 
            Defaults to GENERAL_LOG.
        models_dir (Path): Directory to save models. Defaults to MODELS_DIR.

    Returns:
        pd.DataFrame: DataFrame with results for all configurations, including:
            - prior_type, use_noncentered, draws, tune, q
            - DIC, Rhat, minESS, Divergences
            - Converged, Duration, model_path, Error

    Raises:
        IOError: If log files cannot be written
    """
    if n_workers is None:
        n_workers = max(1, mp.cpu_count() - 1)

    models_dir = Path(models_dir)
    # parents=True so a nested directory works even if nothing created it yet
    models_dir.mkdir(parents=True, exist_ok=True)

    draws_tune_values = [5000]
    q_values = [2, 3, 4, 5, 6]
    prior_configs = [
        ("current", True),   # non-centered
        ("current", False),  # centered
        ("hw3", None),       # hw3 attempt2
    ]

    # Build all configurations (prior type outermost, q innermost)
    configs = [
        {
            'prior_type': prior_type,
            'use_noncentered': use_nc,
            'draws_tune': draws_tune,
            'q': q,
            'X_train': X_train,
            'y_train': y_train,
            'seed': SEED,
            'models_dir': models_dir,
        }
        for prior_type, use_nc in prior_configs
        for draws_tune in draws_tune_values
        for q in q_values
    ]

    total = len(configs)
    log_general(f"Starting parallel grid search: {total} combinations using {n_workers} workers", general_log)
    print(f"Parallel Grid Search: {total} combinations using {n_workers} workers")
    print(f"Grid search log: {gridsearch_log}")
    print(f"General log: {general_log}")
    print(f"Models directory: {models_dir}\n")

    # Initialize grid search log file ("w" discards any earlier contents)
    search_start = datetime.datetime.now()
    with open(gridsearch_log, "w") as f:
        f.write(f"[GRIDSEARCH_START] {search_start.isoformat()}\n")
        f.write(f"[PARALLEL] workers={n_workers}, total_configs={total}\n")
        f.write(f"[MODELS_DIR] {models_dir}\n")

    def log_result(result, completed, total):
        """
        Log a single result to the grid search log, the general log and the console.

        Args:
            result (dict): Result dictionary from fit_single_config
            completed (int): 1-based position of this result
            total (int): Total number of configurations

        Returns:
            None

        Raises:
            IOError: If log files cannot be written
        """
        if result['prior_type'] == "current":
            prior_label = f"current-{'nc' if result['use_noncentered'] else 'c'}"
        else:
            prior_label = "hw3"

        # Shared prefix; its key=value layout is what parse_gridsearch_log reads
        prefix = (f"[COMBO_END] {datetime.datetime.now().isoformat()} - "
                  f"prior={result['prior_type']}, use_noncentered={result['use_noncentered']}, "
                  f"draws={result['draws']}, tune={result['tune']}, q={result['q']}, ")

        if result['Error']:
            with open(gridsearch_log, "a") as f:
                f.write(f"{prefix}ERROR: {result['Error']}\n")
            log_general(f"Grid search [{completed}/{total}]: {prior_label} d={result['draws']} q={result['q']} - ERROR: {result['Error'][:40]}", general_log)
            print(f"[{completed:2d}/{total}] {prior_label} d={result['draws']:5d} q={result['q']} ✗ {result['Error'][:40]}")
            return

        has_model = bool(result.get('model_path'))
        model_info = f", model={result['model_path']}" if has_model else ""
        with open(gridsearch_log, "a") as f:
            f.write(f"{prefix}"
                    f"DIC={result['DIC']:.2f}, Rhat={result['Rhat']:.4f}, "
                    f"minESS={result['minESS']:.0f}, Divergences={result['Divergences']}"
                    f"{model_info}\n")

        status = "✓" if result['Converged'] else "⚠"
        model_saved = "✓" if has_model else "✗"
        log_general(f"Grid search [{completed}/{total}]: {prior_label} d={result['draws']} q={result['q']} {status} DIC={result['DIC']:.1f} R\u0302={result['Rhat']:.3f} Model saved: {model_saved}", general_log)

        model_indicator = "💾" if has_model else "⚠"
        print(f"[{completed:2d}/{total}] {prior_label} d={result['draws']:5d} q={result['q']} "
              f"{status} DIC={result['DIC']:.1f} R\u0302={result['Rhat']:.3f} ({result['Duration']:.0f}s) {model_indicator}")

    # Run in parallel using joblib (handles pickling on macOS gracefully)
    log_general("Starting parallel execution...", general_log)
    results = Parallel(n_jobs=n_workers, backend='loky', verbose=0)(
        delayed(fit_single_config)(cfg) for cfg in configs
    )

    # Results only become available once every worker has finished
    for i, result in enumerate(results, 1):
        log_result(result, i, total)

    results_df = pd.DataFrame(results)

    # Report best result (log_general also prints, so no separate print)
    print(f"\n{'='*60}")
    converged_df = results_df[results_df['Converged']]
    if len(converged_df) > 0:
        best = converged_df.loc[converged_df['DIC'].idxmin()]
        best_msg = f"✓ Best converged: DIC={best['DIC']:.2f}, q={int(best['q'])}, draws={int(best['draws'])}"
    else:
        # Every R-hat is NaN when all configurations errored; idxmin would fail
        valid_rhat = results_df['Rhat'].dropna()
        if valid_rhat.empty:
            best_msg = "⚠ No converged models and no valid R-hat (all configurations failed)"
        else:
            best_msg = f"⚠ No converged models. Best R-hat: {valid_rhat.min():.4f}"
    log_general(best_msg, general_log)

    total_time = results_df['Duration'].sum()
    wall_time = (datetime.datetime.now() - search_start).total_seconds()
    speedup = total_time / wall_time if wall_time > 0 else float('nan')
    speedup_msg = f"Total CPU time: {total_time/60:.1f} min | Wall time: {wall_time/60:.1f} min | Speedup: {speedup:.1f}x"
    log_general(speedup_msg, general_log)

    # Log model count
    saved_models = results_df['model_path'].notna().sum()
    model_msg = f"Saved {saved_models}/{total} models to {models_dir}"
    log_general(model_msg, general_log)

    return results_df

In [None]:
# Initialize general log (already cleared at start, now write header)
with open(GENERAL_LOG, "a") as f:
    f.write(f"[START] {datetime.datetime.now().isoformat()}\n")
    f.write(f"Grid search log: {GRIDSEARCH_LOG}\n")
    f.write(f"Models directory: {MODELS_DIR}\n\n")

log_general("Starting notebook execution", GENERAL_LOG)

# Load existing results or run parallel grid search.
# At least 10 parsed rows in the log are treated as a usable earlier run.
grid_results_df = parse_gridsearch_log(GRIDSEARCH_LOG)

if len(grid_results_df) >= 10:
    # log_general echoes to the console, so no separate print is needed
    log_general(f"Loaded {len(grid_results_df)} existing results from {GRIDSEARCH_LOG}", GENERAL_LOG)
else:
    print("Running parallel grid search...")
    log_general("Starting parallel grid search", GENERAL_LOG)
    grid_results_df = run_parallel_grid_search(
        X_train, y_train, 
        n_workers=N_WORKERS,
        gridsearch_log=GRIDSEARCH_LOG,
        general_log=GENERAL_LOG,
        models_dir=MODELS_DIR
    )

# Extract best configuration: lowest DIC among converged fits, otherwise the
# fit with the smallest R-hat
converged = grid_results_df[grid_results_df['Converged']]
if len(converged) > 0:
    best_config = converged.loc[converged['DIC'].idxmin()]
else:
    valid_rhat = grid_results_df['Rhat'].dropna()
    if valid_rhat.empty:
        # Every configuration failed; idxmin on all-NaN would raise a less helpful error
        raise RuntimeError(
            f"Grid search produced no usable fits; see {GRIDSEARCH_LOG} for the errors"
        )
    best_config = grid_results_df.loc[valid_rhat.idxmin()]

BEST_PRIOR_TYPE = best_config['prior_type']
# use_noncentered is None/NaN for the hw3 prior, which ignores it; default to True
BEST_USE_NONCENTERED = best_config['use_noncentered'] if pd.notna(best_config['use_noncentered']) else True
BEST_DRAWS = int(best_config['draws'])
BEST_TUNE = int(best_config['tune'])

best_msg = f"\n✓ Best configuration:\n  Prior: {BEST_PRIOR_TYPE}, Non-centered: {BEST_USE_NONCENTERED}\n  Draws: {BEST_DRAWS}, Tune: {BEST_TUNE}\n  DIC: {best_config['DIC']:.2f}, R-hat: {best_config['Rhat']:.4f}"
if 'model_path' in best_config and pd.notna(best_config['model_path']):
    best_msg += f"\n  Model: {best_config['model_path']}"
log_general(best_msg, GENERAL_LOG)

## 6. Problem 1(a): Model Comparison Using DIC

Fit BNN models on the full dataset for q ∈ {2, 3, 4, 5, 6} using the best hyperparameters.
Compare models using DIC and convergence diagnostics.

In [None]:
# Sampling configuration
CHAINS = 6
TARGET_ACCEPT = 0.98
Q_VALUES = [2, 3, 4, 5, 6]

model_creator = get_model_creator(BEST_PRIOR_TYPE)
full_results = {}   # q -> inference data of the full-data fit
dic_scores = []     # one summary row per q

log_general("Starting Problem 1(a): Model comparison using DIC", GENERAL_LOG)

for q in Q_VALUES:
    print(f"Fitting q={q}...", end=" ")
    log_general(f"Fitting model with q={q} for DIC comparison", GENERAL_LOG)

    # Full dataset (no test rows) for the DIC comparison
    if BEST_PRIOR_TYPE == "current":
        model = model_creator(X_scaled, y_scaled, q, use_noncentered=BEST_USE_NONCENTERED)
    else:
        model = model_creator(X_scaled, y_scaled, q)

    with model:
        idata = pm.sample(
            draws=BEST_DRAWS, tune=BEST_TUNE, chains=CHAINS, cores=1,
            target_accept=TARGET_ACCEPT, random_seed=SEED, init="adapt_diag",
            return_inferencedata=True, progressbar=True
        )
        # compute_dic reads the log_likelihood group, which pm.sample does not fill
        pm.compute_log_likelihood(idata)

    dic = compute_dic(idata)
    max_rhat, min_ess, n_div = compute_diagnostics(idata)
    full_results[q] = idata
    dic_scores.append({
        'q': q, 'DIC': dic, 'R-hat': max_rhat,
        'minESS': min_ess, 'Divergences': n_div
    })

    # Save under a "_full" suffix; the test-set cell writes "<stem>_test.pkl"
    # and would otherwise overwrite these fits
    full_filename = get_model_filename(
        BEST_PRIOR_TYPE, BEST_USE_NONCENTERED, BEST_DRAWS, BEST_TUNE, q
    )
    full_filename = full_filename.parent / f"{full_filename.stem}_full.pkl"
    dump(idata, full_filename, compress=3)
    print(f"Saved full-data fit to {full_filename.name}")
    log_general(f"q={q}: DIC={dic:.2f}, R-hat={max_rhat:.4f}, divergences={n_div}", GENERAL_LOG)

# Summary table and the q with the lowest DIC
dic_df = pd.DataFrame(dic_scores)
print("\n" + dic_df.to_string(index=False))

best_dic_q = int(dic_df.loc[dic_df['DIC'].idxmin(), 'q'])
log_general(
    f"Problem 1(a) complete: lowest DIC at q={best_dic_q} (DIC={dic_df['DIC'].min():.2f})",
    GENERAL_LOG
)


## 7. Problem 1(b): Test Set Prediction

Evaluate BNN models on the held-out test set using RMSE, MAE, R², and correlation.

In [None]:
log_general("Starting Problem 1(b): Test set prediction", GENERAL_LOG)
test_results = []

for q in Q_VALUES:
    print(f"Evaluating q={q}...", end=" ")
    log_general(f"Evaluating model with q={q} on test set", GENERAL_LOG)

    # Passing X_test adds the unobserved "y_pred" variable to the model
    if BEST_PRIOR_TYPE == "current":
        model = model_creator(X_train, y_train, q, X_test, use_noncentered=BEST_USE_NONCENTERED)
    else:
        model = model_creator(X_train, y_train, q, X_test)

    with model:
        idata = pm.sample(
            draws=BEST_DRAWS, tune=BEST_TUNE, chains=CHAINS, cores=1,
            target_accept=TARGET_ACCEPT, random_seed=SEED, init="adapt_diag",
            return_inferencedata=True, progressbar=True
        )

    # Save test fit to pkl file
    test_filename = get_model_filename(
        BEST_PRIOR_TYPE, BEST_USE_NONCENTERED, BEST_DRAWS, BEST_TUNE, q
    )
    test_filename = test_filename.parent / f"{test_filename.stem}_test.pkl"
    dump(idata, test_filename, compress=3)
    print(f"Saved test fit to {test_filename.name}")

    # Point prediction: per-row median over all chains and draws of "y_pred"
    y_pred_samples = idata.posterior["y_pred"].values
    y_pred_median = np.median(y_pred_samples.reshape(-1, len(y_test)), axis=0)

    # Metrics are computed on the standardized scale of y_test
    rmse = np.sqrt(mean_squared_error(y_test, y_pred_median))
    mae = mean_absolute_error(y_test, y_pred_median)
    r2 = r2_score(y_test, y_pred_median)
    corr = np.corrcoef(y_test, y_pred_median)[0, 1]
    max_rhat, _, _ = compute_diagnostics(idata)

    test_results.append({
        'q': q, 'RMSE': rmse, 'MAE': mae, 'R2': r2, 'Correlation': corr,
        'R-hat': max_rhat, 'y_pred_median': y_pred_median
    })
    # Status symbols below were mis-encoded in the original messages
    result_msg = f"q={q}: RMSE={rmse:.4f}, R²={r2:.4f}"
    print(f"RMSE={rmse:.4f}, R²={r2:.4f}")
    log_general(result_msg, GENERAL_LOG)

test_df = pd.DataFrame(test_results)
print("\n" + test_df[['q', 'RMSE', 'MAE', 'R2', 'Correlation']].to_string(index=False))

# Best q = smallest test RMSE
best_test_q = int(test_df.loc[test_df['RMSE'].idxmin(), 'q'])
best_test_msg = f"Best test performance: q={best_test_q} (RMSE={test_df['RMSE'].min():.4f})"
print(f"\n✓ {best_test_msg}")
log_general(f"Problem 1(b) complete: {best_test_msg}", GENERAL_LOG)

## 8. Problem 1(c): Comparison with Bayesian Linear Regression

Fit Bayesian linear regression with spike-and-slab priors using the same hyperparameters.
Compare test set performance with BNN models.

In [None]:
# Problem 1(c): Bayesian linear regression baseline with a spike-and-slab prior
# on the coefficients, scored on the same held-out set as the BNNs.
log_general("Starting Problem 1(c): Comparison with Bayesian linear regression", GENERAL_LOG)
print("Fitting Bayesian linear regression...")

with pm.Model() as linear_model:
    # Spike-and-slab priors (non-centered): each coefficient is a standard
    # normal draw scaled by spike_sd when its indicator is 0 and by slab_sd
    # when its indicator is 1.
    pi = 0.5
    spike_sd, slab_sd = 0.01, 1.0

    gamma_lin = pm.Bernoulli("gamma_lin", p=pi, shape=X_train.shape[1])
    beta_raw = pm.Normal("beta_raw", mu=0, sigma=1, shape=X_train.shape[1])
    sd_beta = spike_sd + gamma_lin * (slab_sd - spike_sd)
    beta = pm.Deterministic("beta", beta_raw * sd_beta)

    # Likelihood on the training data.
    alpha = pm.Normal("alpha", mu=0, sigma=1)
    mu = alpha + pm.math.dot(X_train, beta)
    sigma = pm.HalfNormal("sigma", sigma=1)
    pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_train)

    # Test predictions: "y_pred" has no observed data, so it is sampled with
    # the other unknowns and ends up in the posterior group of the trace.
    mu_test = alpha + pm.math.dot(X_test, beta)
    pm.Normal("y_pred", mu=mu_test, sigma=sigma, shape=len(y_test))

# NOTE(review): the BNN fits pass init="adapt_diag"; no init is given here, so
# PyMC's default initialisation applies — confirm this difference is intended.
with linear_model:
    trace_linear = pm.sample(
        BEST_DRAWS, tune=BEST_TUNE, chains=CHAINS, cores=1,
        target_accept=TARGET_ACCEPT, random_seed=SEED,
        return_inferencedata=True, progressbar=True
    )

# Save linear regression test fit to pkl file
linear_test_filename = MODELS_DIR / f"linear_regression_test_d{BEST_DRAWS}_t{BEST_TUNE}.pkl"
dump(trace_linear, linear_test_filename, compress=3)
print(f"Saved linear regression test fit to {linear_test_filename.name}")

max_rhat_lin, min_ess_lin, n_div_lin = compute_diagnostics(trace_linear)
print(f"R-hat: {max_rhat_lin:.4f}, ESS: {min_ess_lin:.0f}, Divergences: {n_div_lin}")

# Linear model predictions: collapse all leading sample dimensions and take the
# per-observation median as the point prediction.
y_pred_linear = trace_linear.posterior["y_pred"].values
y_pred_linear_median = np.median(y_pred_linear.reshape(-1, len(y_test)), axis=0)

rmse_linear = np.sqrt(mean_squared_error(y_test, y_pred_linear_median))
mae_linear = mean_absolute_error(y_test, y_pred_linear_median)
r2_linear = r2_score(y_test, y_pred_linear_median)
corr_linear = np.corrcoef(y_test, y_pred_linear_median)[0, 1]

linear_msg = f"Linear Regression: RMSE={rmse_linear:.4f}, MAE={mae_linear:.4f}, R²={r2_linear:.4f}, Corr={corr_linear:.4f}"
print(f"\n{linear_msg}")
log_general(linear_msg, GENERAL_LOG)
log_general("Problem 1(c) complete", GENERAL_LOG)

## 9. Final Comparison and Summary

In [None]:
# Final scoreboard: print the linear baseline and every BNN side by side, then
# name the model with the lowest test RMSE.
print("=" * 60)
print("FINAL RESULTS")
print("=" * 60)

print(f"\nBayesian Linear Regression: RMSE={rmse_linear:.4f}, R²={r2_linear:.4f}")
print("\nBayesian Neural Networks:")
for _, row in test_df.iterrows():
    # Cast q with int() so it prints as a whole number.
    print(f"  q={int(row['q'])}: RMSE={row['RMSE']:.4f}, R²={row['R2']:.4f}")

best_bnn_rmse = test_df['RMSE'].min()
best_bnn_q = int(test_df.loc[test_df['RMSE'].idxmin(), 'q'])

print("\n" + "-" * 60)
# Strict comparison: an exact tie is reported as a BNN win.
if rmse_linear < best_bnn_rmse:
    print(f"✓ Winner: Linear Regression (RMSE={rmse_linear:.4f})")
else:
    print(f"✓ Winner: BNN with q={best_bnn_q} (RMSE={best_bnn_rmse:.4f})")

## 10. Visualization

Create a comprehensive 6-panel figure showing:
1. DIC by model size
2. Test RMSE comparison (BNN vs Linear)
3. Test R² comparison
4. Convergence diagnostics (R-hat)
5. Convergence diagnostics (ESS)
6. Predicted vs Actual for best BNN model

In [None]:
def _finish_panel(ax, ylabel, title, *, legend=True, grid_axis='both'):
    """Apply the labels, q-axis ticks, legend and grid shared by the per-q panels."""
    ax.set_xlabel('Hidden units (q)')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.set_xticks(Q_VALUES)
    if legend:
        ax.legend()
    ax.grid(True, alpha=0.3, axis=grid_axis)


# 2 x 3 summary figure: model selection, test accuracy, convergence, and a
# predicted-vs-actual scatter for the best BNN.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# Panel 1: DIC by model size (bars carry no labels, so no legend; y-grid only)
ax = axes[0, 0]
ax.bar(dic_df['q'], dic_df['DIC'], alpha=0.7, edgecolor='black', color='steelblue')
_finish_panel(ax, 'DIC', 'DIC by Model Size', legend=False, grid_axis='y')

# Panel 2: Test RMSE comparison (dashed red line = linear regression baseline)
ax = axes[0, 1]
ax.plot(test_df['q'], test_df['RMSE'], 'o-', lw=2, ms=8, label='BNN')
ax.axhline(rmse_linear, color='red', ls='--', lw=2, label='Linear')
_finish_panel(ax, 'RMSE', 'Test RMSE Comparison')

# Panel 3: Test R² comparison (dashed red line = linear regression baseline)
ax = axes[0, 2]
ax.plot(test_df['q'], test_df['R2'], 'o-', lw=2, ms=8, label='BNN', color='green')
ax.axhline(r2_linear, color='red', ls='--', lw=2, label='Linear')
_finish_panel(ax, 'R²', 'Test R² Comparison')

# Panel 4: R-hat convergence, with a reference line at 1.01
ax = axes[1, 0]
ax.plot(dic_df['q'], dic_df['R-hat'], 'o-', lw=2, ms=8, color='purple')
ax.axhline(1.01, color='red', ls='--', lw=1, label='Target (1.01)')
_finish_panel(ax, 'Max R-hat', 'Convergence: R-hat')

# Panel 5: ESS convergence, with a reference line at 400
ax = axes[1, 1]
ax.plot(dic_df['q'], dic_df['min_ESS'], 'o-', lw=2, ms=8, color='orange')
ax.axhline(400, color='red', ls='--', lw=1, label='Target (400)')
_finish_panel(ax, 'Min ESS', 'Convergence: ESS')

# Panel 6: Predicted vs Actual for the BNN with the lowest test RMSE; the
# dashed diagonal marks perfect prediction.
ax = axes[1, 2]
best_result = test_df.loc[test_df['RMSE'].idxmin()]
y_pred_best = best_result['y_pred_median']
ax.scatter(y_test, y_pred_best, alpha=0.6, edgecolors='black')
ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
ax.set_xlabel('Actual Crime Rate')
ax.set_ylabel('Predicted Crime Rate')
ax.set_title(f'Best BNN (q={int(best_result["q"])}): Predicted vs Actual', fontweight='bold')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('crime_bnn_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nFigure saved: crime_bnn_results.png")

## 11. Execution Summary

In [None]:
def _timestamp():
    """Return the current local time as an ISO-8601 string, to the second."""
    return datetime.datetime.now().isoformat(timespec='seconds')


# Minutes elapsed since EXEC_START; report to stdout first, then to the
# general log.
elapsed = (time.time() - EXEC_START) / 60.0
finish_msg = "Total runtime: {:.1f} minutes".format(elapsed)
print(finish_msg)
print("Finished: " + _timestamp())
log_general(finish_msg, GENERAL_LOG)
log_general("Notebook execution completed at " + _timestamp(), GENERAL_LOG)