### _Setup_

In [None]:
# Reset memory: clear all user-defined names from the interactive namespace.
# The -f flag skips the confirmation prompt.
%reset -f

In [None]:
# Install correct package versions
# NOTE(review): tensorflow_probability, optuna, scikit-learn and matplotlib are
# imported in the next cell but are not installed here — confirm the runtime
# image already provides compatible versions.
!pip install "tensorflow[and-cuda]"
# Remove any preinstalled numpy/pandas, then reinstall with numpy pinned below 2.0
!pip uninstall numpy pandas -y
!pip install "numpy<2.0" pandas --upgrade --no-cache-dir

In [None]:
# Packages
from typing import Union, List, Tuple, Dict, Any
import time
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit, train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, initializers, optimizers, callbacks
from tensorflow.keras.layers import Layer
import tensorflow_probability as tfp
import optuna
import matplotlib.pyplot as plt

tfpl = tfp.layers
tfd = tfp.distributions

In [None]:
# GPU check: print the GPU devices visible to TensorFlow
# (an empty list means no GPU was detected)
print("Available GPUs:", tf.config.list_physical_devices('GPU'))

In [None]:
# Data
# Load the full dataset. The functions below read the columns "Variety",
# "Scan Date Year" and "Brix (Position)", and treat the first column whose name
# parses as a float, plus every column after it, as spectral features.
df = pd.read_csv('data.csv')

### _Functions_

In [None]:
def find_col_index_of_spectra(
    df: pd.DataFrame
) -> int:
    """
    Locate the first column whose name is numeric.

    Spectral columns are recognised by their names being convertible to float
    (e.g., "730.5", "731.0"); the columns are scanned from left to right.

    Parameters:
        df : Input DataFrame

    Returns:
        Position of the first column whose name parses as a float,
        or -1 when no column name does.
    """
    def _parses_as_float(label) -> bool:
        # A label counts as spectral when float() accepts it
        try:
            float(label)
        except (ValueError, TypeError):
            return False
        return True

    # First matching position, falling back to the -1 sentinel
    return next(
        (
            position
            for position, label in enumerate(df.columns)
            if _parses_as_float(label)
        ),
        -1,
    )

def split_train_test(
    df: pd.DataFrame,
    test_variety: str,
    test_season: int       
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Partition a DataFrame into a training set and two hold-out test sets.

    - Variety test set: rows with Variety == test_variety AND year 2024
    - Season test set : rows with year == test_season, restricted to the
                        varieties that occur in the training set
    - Training set    : rows whose variety is not test_variety and whose
                        year is not test_season

    The year is read from the "Scan Date Year" column.

    Parameters:
        df           : Full pandas DataFrame
        test_variety : Variety held out for the variety test
        test_season  : Year held out for the season test

    Returns:
        df_train        : Training set
        df_test_variety : Rows of the held-out variety scanned in 2024
        df_test_season  : Rows of the held-out season (training varieties only)
    """
    variety = df["Variety"]
    year = df["Scan Date Year"]

    # Row masks for the held-out variety and the held-out season
    is_test_variety = variety.eq(test_variety)
    is_test_season = year.eq(test_season)

    # Training rows: neither the held-out variety nor the held-out season
    df_train = df[variety.ne(test_variety) & year.ne(test_season)]

    # Variety test: held-out variety, fixed to the 2024 scans
    df_test_variety = df[is_test_variety & year.eq(2024)]

    # Season test: held-out season, keeping only varieties seen in training
    season_rows = df[is_test_season]
    known_varieties = df_train["Variety"].unique()
    df_test_season = season_rows[season_rows["Variety"].isin(known_varieties)]

    return df_train, df_test_variety, df_test_season

def split_x_y(
    df: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split a DataFrame into X (spectral features) and y (target) arrays.

    The spectral block is taken to be the first column whose name parses as a
    float (as located by find_col_index_of_spectra()) together with every
    column to its right.

    Parameters:
        df : Input DataFrame containing both metadata and spectral data.

    Returns:
        x : NumPy array of shape (n_samples, n_spectral_features)
        y : NumPy array of shape (n_samples, 1) containing Brix values

    Raises:
        ValueError : If no column name can be parsed as a float.
    """
    # Locate the start of the spectral block
    first_spectral_idx = find_col_index_of_spectra(df)

    # The helper signals "not found" with -1. Used as a slice start that would
    # silently select only the last column, so fail loudly instead.
    if first_spectral_idx < 0:
        raise ValueError(
            "No spectral columns found: no column name can be parsed as a float."
        )

    spectra_cols = list(df.columns[first_spectral_idx:])

    # Define the target column (kept as a list so y is 2-D)
    target_cols = ['Brix (Position)']

    # Extract feature and target arrays
    x = df[spectra_cols].values
    y = df[target_cols].values

    return x, y

def take_subset(
    df: pd.DataFrame, 
    n_subset: int,
    random_state: int
) -> pd.DataFrame:
    """
    Draw a Brix-stratified subset of the DataFrame.

    Rows are stratified over (up to) 10 quantile bins of 'Brix (Position)'.
    When the requested size is at least the number of rows, a copy of the
    whole DataFrame is returned (its index is left untouched).

    Parameters:
        df           : Input DataFrame with 'Brix (Position)' column
        n_subset     : Desired subset size
        random_state : Random seed for reproducibility

    Returns:
        Stratified subset of df with a fresh 0..n-1 index, or a copy of df
    """
    # Nothing to subsample: hand back a copy of everything
    if len(df) <= n_subset:
        return df.copy()

    # Quantile bins of Brix act as the stratification labels
    brix_bins = pd.qcut(
        df["Brix (Position)"], q=10, labels=False, duplicates='drop'
    )

    # One stratified shuffle whose "train" side is the requested subset
    sampler = StratifiedShuffleSplit(
        n_splits=1, train_size=n_subset, random_state=random_state
    )
    chosen_rows = next(sampler.split(df, brix_bins))[0]

    # Positional selection, then renumber the rows
    subset = df.iloc[chosen_rows]
    return subset.reset_index(drop=True)

def create_train_val_split(
    df: pd.DataFrame,
    validation_size: float,
    random_state: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Produce a Brix-stratified train/validation split of a DataFrame.

    Stratification labels are (up to) 10 quantile bins of the
    'Brix (Position)' column.

    Parameters:
        df              : Input DataFrame
        validation_size : Proportion of validation samples (0 < float < 1)
        random_state    : Seed for reproducibility

    Returns:
        df_train, df_val : Training and validation DataFrames, each re-indexed from 0
    """
    # Stratification labels: quantile bin of each row's Brix value
    brix_bins = pd.qcut(
        df["Brix (Position)"], q=10, labels=False, duplicates="drop"
    )

    # Stratified split; the "test" side is the validation set
    train_part, val_part = train_test_split(
        df,
        test_size=validation_size,
        random_state=random_state,
        stratify=brix_bins
    )

    # Renumber both parts before returning them
    train_part = train_part.reset_index(drop=True)
    val_part = val_part.reset_index(drop=True)
    return train_part, val_part

def rmse_loss(
    y_true, 
    y_pred
):
    """
    Root Mean Squared Error (RMSE), for use as a training loss.

    The mean is taken over every element of the residual tensor.

    Parameters:
        y_true : Tensor of true target values
        y_pred : Tensor of predicted values

    Returns:
        RMSE as a scalar Tensor
    """
    residual = y_pred - y_true
    mean_squared_residual = tf.reduce_mean(tf.square(residual))
    return tf.sqrt(mean_squared_residual)

def rmse_metric(
    y_true, 
    y_pred
):
    """
    Root Mean Squared Error (RMSE), for use as a reported metric.

    The mean is taken over every element of the residual tensor.

    Parameters:
        y_true : Tensor of true target values
        y_pred : Tensor of predicted values

    Returns:
        RMSE as a scalar Tensor
    """
    squared_residual = tf.square(y_pred - y_true)
    mean_squared_residual = tf.reduce_mean(squared_residual)
    return tf.sqrt(mean_squared_residual)

def bcnn_model(
    input_shape: int,
    kernel_size: int,
    dropout_rate: float,
    l2_strength: float,
    learning_rate: float,
    random_state: int,
    kl_scale: float
) -> tf.keras.Model:
    """
    Build and compile a Bayesian Convolutional Neural Network (BCNN) model.

    The model includes:
    - A deterministic 1D convolutional layer (one filter) with L2 regularization.
    - Four Bayesian dense layers (DenseReparameterization) of 36, 18, 12 and 1 units.
    - Dropout after the convolution and after the first two Bayesian layers.
    - A kernel KL divergence term per Bayesian layer, multiplied by `kl_scale`.

    The model is compiled with Adam and the module-level `rmse_loss` /
    `rmse_metric` functions, so the metric is reported under a readable name
    instead of an anonymous lambda.

    Parameters:
        input_shape    : Number of input features.
        kernel_size    : Convolutional kernel size.
        dropout_rate   : Dropout rate applied after key layers.
        l2_strength    : L2 regularization strength for the Conv1D layer.
        learning_rate  : Learning rate for the Adam optimizer.
        random_state   : Random seed for the Conv1D weight initializer.
        kl_scale       : Scaling factor for KL divergence in Bayesian layers.

    Returns:
        model : Compiled Keras model ready for training.
    """
    # Define L2 regularizer and HeNormal initializer for Conv1D
    kernel_reg  = regularizers.l2(l2_strength)
    kernel_init = initializers.HeNormal(seed=random_state)

    def scaled_kl(q, p, _):
        # KL(posterior || prior) weighted by kl_scale; the third argument is unused
        return tfd.kl_divergence(q, p) * kl_scale

    def bayesian_dense(units: int, activation: str):
        # All Bayesian layers share the same prior/posterior/divergence setup;
        # each call builds a fresh posterior function for its own layer.
        return tfpl.DenseReparameterization(
            units=units,
            activation=activation,
            kernel_prior_fn=tfpl.default_multivariate_normal_fn,
            kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
            kernel_divergence_fn=scaled_kl
        )

    model = models.Sequential([
        # Input layer reshaping the flat input into 1D format for Conv1D
        tf.keras.Input(shape=(input_shape,)),
        layers.Reshape((input_shape, 1)),

        # Deterministic 1D convolutional layer
        layers.Conv1D(
            filters=1,
            kernel_size=kernel_size,
            padding="same",
            activation="elu",
            kernel_initializer=kernel_init,
            kernel_regularizer=kernel_reg
        ),

        # Dropout regularization
        layers.Dropout(dropout_rate),
        layers.Flatten(),

        # First Bayesian dense layer (variational)
        bayesian_dense(36, "elu"),
        layers.Dropout(dropout_rate),

        # Second Bayesian dense layer
        bayesian_dense(18, "elu"),
        layers.Dropout(dropout_rate),

        # Third Bayesian dense layer
        bayesian_dense(12, "elu"),

        # Output Bayesian dense layer (linear activation for regression)
        bayesian_dense(1, "linear")
    ])

    # Compile model with the shared RMSE loss and metric functions
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=rmse_loss,
        metrics=[rmse_metric]
    )

    return model

def train_bcnn(
    x_train: np.ndarray,
    y_train: np.ndarray,
    input_shape: int,
    kernel_size: int,
    dropout_rate: float,
    l2_strength: float,
    random_state: int,
    kl_scale: float,
    batch_size: int,
    epochs: int,
    patience_reduce_lr: int,
    patience_early_stop: int,
    min_lr: float,
    x_val: Union[np.ndarray, None] = None,
    y_val: Union[np.ndarray, None] = None,
    verbose: int = 0
) -> Tuple[tf.keras.Model, tf.keras.callbacks.History]:
    """
    Train a Bayesian Convolutional Neural Network (BCNN) model with optional validation.

    Uses KL-divergence regularized variational layers, with learning rate scheduling
    and early stopping. The initial learning rate is derived from the batch size
    as 0.01 * batch_size / 256.

    When both x_val and y_val are given, the callbacks monitor "val_loss";
    otherwise (either one is None) no validation data is used and the callbacks
    monitor the training "loss".

    Parameters:
        x_train, y_train       : Training data and labels
        input_shape            : Number of spectral input features
        kernel_size            : Convolutional kernel size
        dropout_rate           : Dropout rate for regularization
        l2_strength            : L2 penalty for deterministic Conv1D layer
        random_state           : Seed for reproducibility
        kl_scale               : Weighting for KL divergence in variational layers
        batch_size             : Mini-batch size
        epochs                 : Maximum number of training epochs
        patience_reduce_lr     : Patience for learning rate reduction callback
        patience_early_stop    : Patience for early stopping callback
        min_lr                 : Minimum learning rate allowed by scheduler
        x_val, y_val           : Optional validation set (default: None)
        verbose                : Verbosity level (0 = silent, 1 = progress)

    Returns:
        model   : Trained BCNN model
        history : Keras training history object
    """

    # Build the Bayesian CNN model
    model = bcnn_model(
        input_shape=input_shape,
        kernel_size=kernel_size,
        dropout_rate=dropout_rate,
        l2_strength=l2_strength,
        learning_rate=0.01 * batch_size / 256,
        random_state=random_state,
        kl_scale=kl_scale
    )

    # Determine whether to monitor training loss or validation loss
    has_validation = x_val is not None and y_val is not None
    monitor_metric = "val_loss" if has_validation else "loss"
    validation_data = (x_val, y_val) if has_validation else None

    # Configure training callbacks for learning rate scheduling and early stopping
    cb = [
        callbacks.ReduceLROnPlateau(
            monitor=monitor_metric,
            factor=0.5,
            patience=patience_reduce_lr,
            min_lr=min_lr,
            verbose=0
        ),
        callbacks.EarlyStopping(
            monitor=monitor_metric,
            patience=patience_early_stop,
            restore_best_weights=True,
            verbose=0
        )
    ]

    # Train the model with optional validation
    history = model.fit(
        x_train,
        y_train,
        validation_data=validation_data,
        epochs=epochs,
        batch_size=batch_size,
        callbacks=cb,
        verbose=verbose
    )

    return model, history

def perform_optuna_hyperparameter_optimization(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    input_shape: int,
    random_state: int,
    epochs: int,
    patience_reduce_lr: int,
    patience_early_stop: int,
    min_lr: float,
    kernel_size_range: Tuple[int, int],
    batch_size_list: list,
    dropout_range: Tuple[float, float],
    l2_range: Tuple[float, float],
    kl_range: Tuple[float, float],
    timeout_time: float
) -> Tuple['optuna.study.Study', float, dict]:
    """
    Perform Optuna-based hyperparameter tuning for a Bayesian CNN (BCNN) model.

    Each trial trains a model with train_bcnn() and is scored by the RMSE of a
    single prediction pass on the validation set. Trials run until
    `timeout_time` seconds have elapsed. The sampler is seeded with
    `random_state` so the sequence of suggested hyperparameters is repeatable.

    Parameters:
        x_train, y_train       : Training features and labels
        x_val, y_val           : Validation features and labels
        input_shape            : Number of spectral input features
        random_state           : Seed for the model initializer and the Optuna sampler
        epochs                 : Maximum number of training epochs
        patience_reduce_lr     : ReduceLROnPlateau callback patience
        patience_early_stop    : EarlyStopping callback patience
        min_lr                 : Minimum learning rate for LR scheduler
        kernel_size_range      : Tuple (min, max) kernel sizes to test
        batch_size_list        : List of candidate batch sizes
        dropout_range          : Tuple (min, max) dropout values
        l2_range               : Tuple (min, max) L2 regularization strengths (log-sampled)
        kl_range               : Tuple (min, max) KL divergence weights (log-sampled)
        timeout_time           : Max tuning time allowed (in seconds)

    Returns:
        study          : Optuna study object
        best_val_rmse  : Best RMSE on validation set
        best_params    : Dictionary of best trial hyperparameters
    """

    def objective(trial):
        # Drop Keras global state left over from earlier trials; without this,
        # building one model per trial makes memory use grow over a long study.
        tf.keras.backend.clear_session()

        # Suggest hyperparameters using Optuna's search space
        kernel_size   = trial.suggest_int("kernel_size", kernel_size_range[0], kernel_size_range[1])
        batch_size    = trial.suggest_categorical("batch_size", batch_size_list)
        dropout_rate  = trial.suggest_float("dropout_rate", dropout_range[0], dropout_range[1])
        l2_strength   = trial.suggest_float("l2_strength", l2_range[0], l2_range[1], log=True)
        kl_scale      = trial.suggest_float("kl_scale", kl_range[0], kl_range[1], log=True)

        # Learning rate is scaled based on batch size (same formula train_bcnn
        # applies internally; computed here only for reporting)
        learning_rate = 0.01 * batch_size / 256

        # Print trial configuration
        print(f"\n[Optuna Trial {trial.number}] Hyperparameters:")
        print(f"  kernel_size   = {kernel_size}")
        print(f"  batch_size    = {batch_size}")
        print(f"  dropout_rate  = {dropout_rate:.4f}")
        print(f"  l2_strength   = {l2_strength:.2e}")
        print(f"  learning_rate = {learning_rate:.2e}")
        print(f"  kl_scale      = {kl_scale:.2e}")

        # Train the model using the suggested parameters
        model, _history = train_bcnn(
            x_train=x_train,
            y_train=y_train,
            x_val=x_val,
            y_val=y_val,
            input_shape=input_shape,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            l2_strength=l2_strength,
            random_state=random_state,
            kl_scale=kl_scale,
            batch_size=batch_size,
            epochs=epochs,
            patience_reduce_lr=patience_reduce_lr,
            patience_early_stop=patience_early_stop,
            min_lr=min_lr,
            verbose=0
        )

        # Evaluate model on validation set
        y_pred = model.predict(x_val, batch_size=batch_size, verbose=0).flatten()
        y_true = y_val.flatten()
        val_rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
        print(f"  Validation RMSE: {val_rmse:.5f}")

        return val_rmse

    # Create and run Optuna study with a seeded sampler for a repeatable search
    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.TPESampler(seed=random_state)
    )
    study.optimize(objective, timeout=timeout_time)

    # Extract best result
    best_val_rmse = study.best_value
    best_params = study.best_trial.params

    return study, best_val_rmse, best_params

def test_bcnn(
    model: tf.keras.Model,
    x_test_data: np.ndarray,
    y_test_data: np.ndarray,
    batch_size: int,
    num_monte_carlo: int
) -> tuple:
    """
    Score a trained Bayesian CNN on a hold-out test set with repeated prediction passes.

    The model is asked to predict the full test set `num_monte_carlo` times;
    the per-sample mean over those passes is used as the point prediction.
    Metrics are printed and a predicted-vs-observed scatter plot is shown.

    Parameters:
        model            : Trained BCNN Keras model
        x_test_data      : Test input features
        y_test_data      : True target values
        batch_size       : Batch size for model prediction
        num_monte_carlo  : Number of prediction passes

    Returns:
        test_rmsep              : Root mean square error of prediction
        test_r2                 : R² score
        test_practical_accuracy: % predictions within ±20% of true value
        df_predictions          : DataFrame with one column per pass, the mean
                                  prediction ("predicted") and the truth ("observed")
    """
    # One flattened prediction vector per pass
    passes = [
        model.predict(x_test_data, batch_size=batch_size, verbose=0).flatten()
        for _ in range(num_monte_carlo)
    ]

    # Columns are passes, rows are test samples
    pass_names = [f"mc_pass_{k}" for k in range(1, num_monte_carlo + 1)]
    df_predictions = pd.DataFrame(np.stack(passes, axis=1), columns=pass_names)

    # Point prediction = row-wise mean over the passes; attach ground truth
    df_predictions["predicted"] = df_predictions.mean(axis=1)
    df_predictions["observed"] = y_test_data.flatten()

    predicted = df_predictions["predicted"].values
    observed = df_predictions["observed"].values

    # Error metrics
    test_rmsep = float(np.sqrt(mean_squared_error(observed, predicted)))
    test_r2 = float(r2_score(observed, predicted))
    relative_error = np.abs(predicted - observed) / np.abs(observed)
    within_tolerance = relative_error <= 0.2
    test_practical_accuracy = float(within_tolerance.mean() * 100.0)

    # Report
    print(f"Test RMSEP: {test_rmsep:.4f}")
    print(f"Test R²: {test_r2:.4f}")
    print(f"Practical accuracy (±20%): {test_practical_accuracy:.1f}%")

    # Parity plot with the identity line spanning the observed range
    identity_span = [observed.min(), observed.max()]
    plt.figure(figsize=(8, 6))
    plt.scatter(observed, predicted, alpha=0.7, label="Test Data")
    plt.plot(identity_span, identity_span, "k--", lw=2, label="Ideal")
    plt.xlabel("Observed")
    plt.ylabel("Predicted")
    plt.title("Observed vs. Predicted on Test Set (BCNN)")
    plt.legend()
    plt.grid(True)
    plt.show()

    return test_rmsep, test_r2, test_practical_accuracy, df_predictions


### _Parameters_

In [None]:
# --- Data and splitting ---
# NOTE(review): the Run cell passes `df` directly, so DF appears unused — confirm.
DF                              = df
# Requested size of the stratified training subset (see take_subset)
N_SUBSET                        = 23690
# Fraction of the subset held out for validation
VALIDATION_SIZE                 = 0.1
# Variety and season (year) held out as the two test sets
TEST_VARIETY                    = "TestVariety"
TEST_SEASON                     = 2025
# Seed shared by the subset sampling, the train/validation split and model init
RANDOM_STATE                    = 27

# --- Training callbacks (patience values are in epochs) ---
PATIENCE_CALLBACK_REDUCE_LR     = 25
PATIENCE_CALLBACK_EARLY_STOP    = 50
MIN_LR                          = 1e-6

# --- Optuna search space: (min, max) bounds, or a list of candidates ---
KERNEL_SIZE_RANGE               = (3, 1025)
BATCH_SIZE_OPTIONS              = [32, 64, 128, 256, 512, 1024]     
DROPOUT_RANGE                   = (0.01, 0.4)
L2_RANGE                        = (1e-6, 1e-2)
KL_SCALE_RANGE                  = (1e-6, 1e-2)
# Tuning time budget in seconds (72 hours)
TIMEOUT_TIME                    = 60 * 60 * 72

# --- Epoch limits and prediction passes ---
# Max epochs per tuning trial
TRAIN_EPOCHS                    = 250
# Max epochs for the final retraining on all training data
TEST_EPOCHS                     = 1000
# Number of prediction passes used at test and timing stages
NUM_MONTE_CARLO                 = 100

In [None]:
# Reduce warnings: silence UserWarning messages raised from modules whose
# name starts with "tensorflow_probability"; other warnings are unaffected.
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow_probability")

### _Run_

In [None]:
# === Hold out the variety and season test sets ===
df_train_all, df_test_variety, df_test_season = split_train_test(
    df=df, test_variety=TEST_VARIETY, test_season=TEST_SEASON
)

# === Stratified subset of the training pool ===
df_subset = take_subset(df=df_train_all, n_subset=N_SUBSET, random_state=RANDOM_STATE)

# === Stratified train/validation split of the subset ===
df_train, df_val = create_train_val_split(
    df=df_subset, validation_size=VALIDATION_SIZE, random_state=RANDOM_STATE
)

# === Feature (x) and target (y) arrays for every split ===
x_train_all, y_train_all = split_x_y(df_train_all)
x_train, y_train = split_x_y(df_train)
x_val, y_val = split_x_y(df_val)
x_test_variety, y_test_variety = split_x_y(df_test_variety)
x_test_season, y_test_season = split_x_y(df_test_season)

# === Hyperparameter search on the train/validation split ===
study, best_val_rmse, best_params = perform_optuna_hyperparameter_optimization(
    x_train=x_train, y_train=y_train,
    x_val=x_val, y_val=y_val,
    input_shape=x_train.shape[1],
    random_state=RANDOM_STATE,
    epochs=TRAIN_EPOCHS,
    patience_reduce_lr=PATIENCE_CALLBACK_REDUCE_LR,
    patience_early_stop=PATIENCE_CALLBACK_EARLY_STOP,
    min_lr=MIN_LR,
    kernel_size_range=KERNEL_SIZE_RANGE,
    batch_size_list=BATCH_SIZE_OPTIONS,
    dropout_range=DROPOUT_RANGE,
    l2_range=L2_RANGE,
    kl_range=KL_SCALE_RANGE,
    timeout_time=TIMEOUT_TIME,
)

print("\nBest params:", best_params)
print("Best validation RMSE:", best_val_rmse)

# === Final model: retrain on the full training pool with the best hyperparameters ===
# No validation set is passed, so the callbacks monitor the training loss.
bcnn_trained, _ = train_bcnn(
    x_train=x_train_all, y_train=y_train_all,
    x_val=None, y_val=None,
    input_shape=x_train_all.shape[1],
    kernel_size=best_params["kernel_size"],
    dropout_rate=best_params["dropout_rate"],
    l2_strength=best_params["l2_strength"],
    random_state=RANDOM_STATE,
    kl_scale=best_params["kl_scale"],
    batch_size=best_params["batch_size"],
    epochs=TEST_EPOCHS,
    patience_reduce_lr=PATIENCE_CALLBACK_REDUCE_LR,
    patience_early_stop=PATIENCE_CALLBACK_EARLY_STOP,
    min_lr=MIN_LR,
)

# === Test on the test sets ===
# Held-out variety
rmsep_variety, r2_variety, acc_variety, df_predictions_variety = test_bcnn(
    model=bcnn_trained,
    x_test_data=x_test_variety,
    y_test_data=y_test_variety,
    batch_size=best_params["batch_size"],
    num_monte_carlo=NUM_MONTE_CARLO
)
# Accuracy is a percentage; the suffix is a plain "%" (a stray "s" used to follow it)
print(f"VARIETY: RMSEP={rmsep_variety:.3f}, R2={r2_variety:.3f}, ACC(±20%)={acc_variety:.1f}%")

# Held-out season
rmsep_season, r2_season, acc_season, df_predictions_season = test_bcnn(
    model=bcnn_trained,
    x_test_data=x_test_season,
    y_test_data=y_test_season,
    batch_size=best_params["batch_size"],
    num_monte_carlo=NUM_MONTE_CARLO
)
print(f"SEASON:  RMSEP={rmsep_season:.3f}, R2={r2_season:.3f}, ACC(±20%)={acc_season:.1f}%")

# === Gather results ===
# One row per test set, one column per metric
results = {
    "Test Set": ["VARIETY", "SEASON"],
    "RMSEP": [rmsep_variety, rmsep_season],
    "R2": [r2_variety, r2_season],
    "Accuracy (±20%)": [acc_variety, acc_season]
}

# Create a DataFrame
df_results = pd.DataFrame(results)

### _Inference Time Analysis_

In [None]:
def get_inference_sample_set(
    df_variety: pd.DataFrame,
    df_season: pd.DataFrame,
    random_state: int,
    sample_size: int = 1000
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Combine two test sets (variety and season), sample rows randomly, and return X and y arrays.

    Rows are drawn without replacement. If the combined test set holds fewer
    than `sample_size` rows, all of its rows are used instead of raising.

    Parameters:
        df_variety   : DataFrame for variety-based test set
        df_season    : DataFrame for season-based test set
        random_state : Random seed for reproducibility
        sample_size  : Maximum number of rows to sample from combined test set

    Returns:
        x_sample : NumPy array of shape (n_sampled, n_features) with spectral features
        y_sample : NumPy array of shape (n_sampled, 1) with corresponding Brix values
    """
    # Combine the two test sets
    df_combined = pd.concat([df_variety, df_season], axis=0)

    # Sampling without replacement cannot draw more rows than exist,
    # so cap the request at the number of available rows.
    n_rows = min(sample_size, len(df_combined))

    # Randomly sample rows from the combined test set
    df_sample = df_combined.sample(
        n=n_rows,
        random_state=random_state
    )

    # Split into X and y arrays
    x_sample, y_sample = split_x_y(df_sample)

    return x_sample, y_sample

def test_inference_time_bcnn(
    model: tf.keras.Model,
    x_test_data: np.ndarray,
    num_monte_carlo: int
) -> float:
    """
    Measure average inference time per test sample for a Bayesian CNN using Monte Carlo sampling.

    Each sample is fed through the model one at a time (batch of 1). The time
    recorded for a sample covers all `num_monte_carlo` forward passes plus the
    conversion of each output to NumPy.

    NOTE(review): the forward passes here use training=True, whereas test_bcnn
    uses model.predict — confirm the timed configuration matches the one
    used for evaluation.

    Parameters:
        model            : Trained BCNN model
        x_test_data      : Test feature matrix [n_samples, n_features]
        num_monte_carlo  : Number of Monte Carlo forward passes per sample

    Returns:
        avg_inference_time_ms : Average inference time per sample in milliseconds
    """
    per_sample_seconds = []

    # Warm-up forward pass to avoid first-run latency
    _ = model(x_test_data[:1], training=True)

    for x in x_test_data:
        x_input = np.expand_dims(x, axis=0)

        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring short durations; time.time() can jump with clock updates.
        start = time.perf_counter()
        for _ in range(num_monte_carlo):
            _ = model(x_input, training=True).numpy()
        elapsed = time.perf_counter() - start

        per_sample_seconds.append(elapsed)

    # Seconds -> milliseconds; cast so the declared float return type holds
    avg_inference_time_ms = float(np.mean(per_sample_seconds) * 1e3)
    print(f"Average inference time: {avg_inference_time_ms:.3f} ms/sample")

    return avg_inference_time_ms


In [None]:
# === Draw a random sample from the two test sets for timing ===
x_inference_time, y_inference_time = get_inference_sample_set(
    df_variety=df_test_variety, df_season=df_test_season, random_state=RANDOM_STATE
)

# === Time the trained BCNN on that sample ===
bcnn_time_ms = test_inference_time_bcnn(
    model=bcnn_trained, x_test_data=x_inference_time, num_monte_carlo=NUM_MONTE_CARLO
)