# NAS - Optuna

- **Authored by:** Matheus Ferreira Silva 
- **GitHub:** https://github.com/MatheusFS-dev

## 1. Setup and Configuration

### 1.1. Environment Variables

In [None]:
import os

# TensorFlow / XLA runtime settings. They are set in this cell, before the
# `_imports` cell runs (which presumably imports TensorFlow — confirm), so they
# are already in the environment when TensorFlow reads them.
os.environ.update(
    {
        # Async CUDA allocator
        "TF_GPU_ALLOCATOR": "cuda_malloc_async",
        # If cuDNN autotune fails, fall back to a safe (but slower) algorithm.
        "XLA_FLAGS": "--xla_gpu_strict_conv_algorithm_picker=false",
        # Allow TensorFlow to allocate GPU memory as needed
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    }
)

### 1.2. Imports

In [None]:
from _imports import * # Centralized file containing all imports

### 1.3. GPU Management

In [None]:
# Expose only CUDA device 0 to this process (change the string to pick another GPU).
# NOTE(review): this cell runs after `_imports`; CUDA_VISIBLE_DEVICES only takes
# effect if the CUDA runtime has not been initialised yet. Consider moving it
# into the environment-variable cell above to be safe.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

## 2. Run Parameters 

In [None]:
# Search budget and training length.
NUM_TRIALS = 200  # Number of Optuna trials (presumably passed to the study in a later cell — confirm)
EPOCHS = 50  # Max epochs per trial (presumably forwarded to `objective(epochs=...)` — confirm)
TOP_K = 5  # Number of top trials to save

# Port configuration.
TOTAL_NUM_PORTS = 100  # Total ports per sample; also the width of the model's output Dense layer
observed_ports_list = [3, 4, 5, 6, 7, 10, 15]  # Candidate observed-port counts (not used in this part of the notebook)

# Outage criterion used by getOP: a sample is NOT in outage only when the
# selected port's true value is strictly greater than THRESHOLD / SNR_LINEAR.
THRESHOLD = 1
SNR_LINEAR = 2.5

# Keras mixed precision: float16 computations with float32 variables.
mixed_precision.set_global_policy("mixed_float16")

#? Set to an existing path to resume training
RESUME_TRAINING_PATH = "runs/nas_cnn1d_v0" # None or "runs/nas_1" 

In [None]:
# Directory used for this run; currently always the resume path.
# NOTE(review): RESUME_TRAINING_PATH is documented as accepting None, but nothing in
# this cell creates a fresh run directory in that case, so RUN_DIR would be None.
# Confirm that the later cells handle it.
RUN_DIR = RESUME_TRAINING_PATH
print(f"Run directory: {RUN_DIR}")

## 3. Data Loading and Preprocessing

In [None]:
# --------------------- Load the dataset in matlab format -------------------- #
rng = np.random.default_rng(42)

# Every .mat file lives in the same folder and differs only in (kappa, mu, m).
DATA_DIR = "./data/w1_u1_n100"
TEST_FRACTION = 0.9  # Share of each sub-dataset held out for testing


def _load_snr_events(kappa: str, mu: str, m: str) -> np.ndarray:
    """Load the ``SNR_events`` array of one (kappa, mu, m) configuration.

    The arguments are inserted verbatim into the file name, e.g.
    ``_load_snr_events("5.0e+00", "2.0", "50.0")`` reads
    ``SNR_events_W1.0_U1_N100_kappa5.0e+00_mu2.0_m50.0.mat`` from ``DATA_DIR``.
    """
    path = f"{DATA_DIR}/SNR_events_W1.0_U1_N100_kappa{kappa}_mu{mu}_m{m}.mat"
    return scipy.io.loadmat(path)["SNR_events"]


def _split_test_train(data, generator, test_fraction=TEST_FRACTION):
    """Randomly split ``data`` row-wise and return ``(test_rows, train_rows)``.

    Draws one row permutation from ``generator``. The first
    ``int(test_fraction * n_rows)`` permuted rows form the test split and the
    remaining rows form the training split.
    """
    perm = generator.permutation(data.shape[0])
    n_test = int(test_fraction * data.shape[0])
    return data[perm[:n_test]], data[perm[n_test:]]


kappa0_mu1_m0 = _load_snr_events("1.0e-16", "1.0", "0.0")
kappa0_mu1_m2 = _load_snr_events("1.0e-16", "1.0", "2.0")
kappa0_mu1_m50 = _load_snr_events("1.0e-16", "1.0", "50.0")
kappa0_mu2_m50 = _load_snr_events("1.0e-16", "2.0", "50.0")
kappa0_mu5_m50 = _load_snr_events("1.0e-16", "5.0", "50.0")
kappa5_mu1_m0 = _load_snr_events("5.0e+00", "1.0", "0.0")
kappa5_mu1_m2 = _load_snr_events("5.0e+00", "1.0", "2.0")
kappa5_mu1_m50 = _load_snr_events("5.0e+00", "1.0", "50.0")
kappa5_mu2_m0 = _load_snr_events("5.0e+00", "2.0", "0.0")
kappa5_mu2_m2 = _load_snr_events("5.0e+00", "2.0", "2.0")
kappa5_mu2_m50 = _load_snr_events("5.0e+00", "2.0", "50.0")
kappa5_mu5_m0 = _load_snr_events("5.0e+00", "5.0", "0.0")
kappa5_mu5_m2 = _load_snr_events("5.0e+00", "5.0", "2.0")
kappa5_mu5_m50 = _load_snr_events("5.0e+00", "5.0", "50.0")

# ————————————— Split the data into 10% training and 90% testing ————————————— #
# Each split draws one permutation from `rng`; keep this order unchanged so the
# splits stay reproducible across runs (and consistent with resumed studies).
kappa0_mu1_m0_test, kappa0_mu1_m0 = _split_test_train(kappa0_mu1_m0, rng)
kappa0_mu1_m2_test, kappa0_mu1_m2 = _split_test_train(kappa0_mu1_m2, rng)
kappa0_mu1_m50_test, kappa0_mu1_m50 = _split_test_train(kappa0_mu1_m50, rng)
kappa0_mu2_m50_test, kappa0_mu2_m50 = _split_test_train(kappa0_mu2_m50, rng)
kappa0_mu5_m50_test, kappa0_mu5_m50 = _split_test_train(kappa0_mu5_m50, rng)
kappa5_mu1_m0_test, kappa5_mu1_m0 = _split_test_train(kappa5_mu1_m0, rng)
kappa5_mu1_m2_test, kappa5_mu1_m2 = _split_test_train(kappa5_mu1_m2, rng)
kappa5_mu1_m50_test, kappa5_mu1_m50 = _split_test_train(kappa5_mu1_m50, rng)
kappa5_mu2_m0_test, kappa5_mu2_m0 = _split_test_train(kappa5_mu2_m0, rng)
kappa5_mu2_m2_test, kappa5_mu2_m2 = _split_test_train(kappa5_mu2_m2, rng)
kappa5_mu2_m50_test, kappa5_mu2_m50 = _split_test_train(kappa5_mu2_m50, rng)
kappa5_mu5_m0_test, kappa5_mu5_m0 = _split_test_train(kappa5_mu5_m0, rng)
kappa5_mu5_m2_test, kappa5_mu5_m2 = _split_test_train(kappa5_mu5_m2, rng)
kappa5_mu5_m50_test, kappa5_mu5_m50 = _split_test_train(kappa5_mu5_m50, rng)

# ————————————— Concatenate all training subsamples along axis=0 ————————————— #
dataset = np.concatenate(
    [
        kappa0_mu1_m0,
        kappa0_mu1_m2,
        kappa0_mu1_m50,
        kappa0_mu2_m50,
        kappa0_mu5_m50,
        kappa5_mu1_m0,
        kappa5_mu1_m2,
        kappa5_mu1_m50,
        kappa5_mu2_m0,
        kappa5_mu2_m2,
        kappa5_mu2_m50,
        kappa5_mu5_m0,
        kappa5_mu5_m2,
        kappa5_mu5_m50,
    ],
    axis=0,
)

print(f"Original dataset shape: {dataset.shape}")

# Subsample data
# dataset = dataset[: int(0.01 * dataset.shape[0]), :]

print(f"Shape of the data after configuration: {dataset.shape}\n")

## 4. Getters

### 4.1. Regularizers

In [None]:
def get_regularizer(trial: optuna.Trial, name: str) -> Optional[tf.keras.regularizers.Regularizer]:
    """
    Suggests a regularization strategy using Optuna and returns the corresponding Keras regularizer.

    Only the selected regularizer is instantiated. Every call returns a new
    instance, so layers never share a regularizer object.

    Args:
        trial (optuna.Trial): Optuna trial object used to sample the regularizer.
        name (str): Unique identifier for this regularizer parameter (used as key).

    Returns:
        Optional[tf.keras.regularizers.Regularizer]: The selected Keras regularizer instance
        (all factors 0.01), or `None` if "none" was selected or the name is unknown.
    """
    # Suggest a regularizer type
    reg_type: str = trial.suggest_categorical(
        name,
        [
            "none",
            "l1",
            "l2",
            "l1l2",
            # "orthogonal",  #! only works for rank-2 tensors
        ],
    )

    # Factories instead of ready-made instances. Building every regularizer on each
    # call is wasted work. It also eagerly touches OrthogonalRegularizer, which
    # cannot be selected from the search space above.
    regularizer_factories = {
        "none": lambda: None,
        "l1": lambda: regularizers.L1(l1=0.01),
        "l2": lambda: regularizers.L2(l2=0.01),
        "l1l2": lambda: regularizers.L1L2(l1=0.01, l2=0.01),
        "orthogonal": lambda: regularizers.OrthogonalRegularizer(factor=0.01, mode="rows"),
    }

    # Build the chosen regularizer, or return None if the name is not in the table
    factory = regularizer_factories.get(reg_type)
    return factory() if factory is not None else None

### 4.2. Activation Functions

In [None]:
def get_activation(trial: Any, name: str) -> Union[str, Callable[..., layers.Layer]]:
    """
    Let Optuna choose the activation function for one layer.

    Args:
        trial (Any): Optuna trial (anything exposing ``suggest_categorical``).
        name (str): Hyperparameter key under which the choice is recorded,
            e.g. ``"layer_1_activation"``.

    Returns:
        Union[str, Callable[..., layers.Layer]]: The Keras name of the chosen activation
        (always a string here), ready for a layer's ``activation=`` argument.
    """
    candidates = [
        "relu",
        "tanh",
        "sigmoid",     # logistic function
        "elu",
        "swish",       # x * sigmoid(x)
        "leaky_relu",
    ]
    chosen = trial.suggest_categorical(name, candidates)
    return chosen

### 4.3. Optimizers

In [None]:
def get_optimizer(trial: optuna.Trial) -> tf.keras.optimizers.Optimizer:
    """
    Let Optuna pick an optimizer and its learning rate, then instantiate it.

    The optimizer name is sampled from the enabled choices (currently only ``"AdamW"``).
    The learning rate is sampled log-uniformly from [1e-5, 1e-2].

    Args:
        trial (optuna.Trial): Trial used to sample ``optimizer`` and ``learning_rate``.

    Returns:
        tf.keras.optimizers.Optimizer: The chosen optimizer built with the sampled learning rate.

    Raises:
        ValueError: If the sampled optimizer name has no entry in the lookup table.
    """
    # Uncomment entries to widen the search space
    enabled_choices = [
        "AdamW",
        # "SGD",
        # "Adam",
        # "RMSprop",
        # "Nadam",
        # "Lion",
    ]
    chosen_name = trial.suggest_categorical("optimizer", enabled_choices)
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)

    # Every optimizer this function knows how to build
    optimizer_classes: Dict[str, Type[tf.keras.optimizers.Optimizer]] = {
        "Adam": optimizers.Adam,
        "AdamW": optimizers.AdamW,
        "SGD": optimizers.SGD,
        "RMSprop": optimizers.RMSprop,
        "Nadam": optimizers.Nadam,
        "Lion": optimizers.Lion,
    }

    optimizer_cls = optimizer_classes.get(chosen_name)
    if optimizer_cls is None:
        raise ValueError(
            f"Optimizer '{chosen_name}' is not supported. "
            f"Supported optimizers are: {list(optimizer_classes.keys())}."
        )
    return optimizer_cls(learning_rate=lr)

### 4.4. Callbacks

In [None]:
def get_callbacks(trial: optuna.Trial, checkpoint_dir: str) -> List[tf.keras.callbacks.Callback]:
    """
    Build the Keras callbacks used for every Optuna trial.

    All callbacks watch ``val_loss``. In order, the list contains:
    early stopping, learning-rate reduction on plateau, best-weights checkpointing,
    a NaN-loss pruner, and Optuna's Keras pruning callback.

    Args:
        trial (optuna.Trial): The current Optuna trial (its number names the checkpoint file).
        checkpoint_dir (str): Directory that receives ``trial_<number>.weights.h5``.

    Returns:
        List[tf.keras.callbacks.Callback]: Callbacks ready to pass to ``model.fit()``.
    """
    weights_file = os.path.join(checkpoint_dir, f"trial_{trial.number}.weights.h5")
    watched = "val_loss"

    return [
        # Stop after 6 epochs without improvement and restore the best epoch's weights.
        callbacks.EarlyStopping(
            monitor=watched,
            patience=6,
            restore_best_weights=True,
            verbose=1,
        ),
        # Multiply the learning rate by 0.2 after 3 stagnant epochs, never below 1e-6.
        callbacks.ReduceLROnPlateau(
            monitor=watched,
            patience=3,
            factor=0.2,
            min_lr=1e-6,
            verbose=1,
        ),
        # Write only the weights, and only when the watched metric improves.
        callbacks.ModelCheckpoint(
            filepath=weights_file,
            monitor=watched,
            save_best_only=True,
            save_weights_only=True,
            verbose=0,
        ),
        #! WARNING: the two callbacks below do not work with multi-objective studies !#
        # Custom callback to prune the trial if a NaN loss is encountered.
        NanLossPrunerCallback(trial),
        # Optuna's built-in pruning callback for early trial termination.
        KerasPruningCallback(trial, watched),
    ]

### 4.5. Scalers

In [None]:
def get_scaler(
    trial: optuna.Trial,
) -> Union[StandardScaler, MinMaxScaler, RobustScaler, QuantileTransformer, PowerTransformer]:
    """
    Let Optuna pick a feature scaler and return a fresh, unfitted instance of it.

    Args:
        trial (optuna.Trial): Trial used to sample the ``scaler`` hyperparameter.

    Returns:
        Union[StandardScaler, MinMaxScaler, RobustScaler, QuantileTransformer, PowerTransformer]:
            A new scikit-learn scaler matching the sampled name.

    Raises:
        ValueError: If the sampled name has no factory below.
    """
    # Searchable options; uncomment entries to widen the space.
    scaler_name = trial.suggest_categorical(
        "scaler",
        [
            "StandardScaler",  # For normally-distributed data
            "MinMaxScaler_-1_1",  # Normalize to [-1, 1] range
            "MinMaxScaler_0_1",  # Normalize to [0, 1] range
            # "RobustScaler",  # For data with outliers
            # "QuantileTransformer",  # For non-normal or skewed data
            # "PowerTransformer",  # For heavy-tailed or skewed data
        ],
    )

    # Zero-argument factories, so only the chosen scaler is constructed.
    scaler_factories = {
        "StandardScaler": StandardScaler,
        "RobustScaler": RobustScaler,
        "QuantileTransformer": lambda: QuantileTransformer(output_distribution="normal"),
        "PowerTransformer": lambda: PowerTransformer(method="yeo-johnson"),
        "MinMaxScaler_0_1": lambda: MinMaxScaler(feature_range=(0, 1)),
        "MinMaxScaler_-1_1": lambda: MinMaxScaler(feature_range=(-1, 1)),
    }

    make_scaler = scaler_factories.get(scaler_name)
    if make_scaler is None:
        raise ValueError(f"Unknown scaler selected: {scaler_name}")
    return make_scaler()

### 4.6. Implementation getters

In [None]:
def get_observed_ports(sinr_data, num_observed_ports, total_ports):
    """
    Select the SINR columns of ``num_observed_ports`` evenly spread ports.

    Port indices come from ``np.linspace(0, total_ports - 1, num_observed_ports)``
    with an integer dtype. When at least two ports are observed, the first and last
    ports are therefore always included.

    Args:
        sinr_data (numpy.ndarray): 2D array with one row per observation and one
            column per port.
        num_observed_ports (int): How many ports to keep.
        total_ports (int): Number of ports spanned by the index grid.

    Returns:
        observed_sinr (numpy.ndarray): ``sinr_data`` restricted to the chosen columns.
        observed_indices (numpy.ndarray): 1D integer array of the chosen port indices.
    """
    port_idx = np.linspace(0, total_ports - 1, num_observed_ports, dtype=int)
    return sinr_data[:, port_idx], port_idx


def getOP(
    observed_indices: np.ndarray, 
    predicted_values: np.ndarray, 
    true_values: np.ndarray, 
    threshold: float, 
    snr_linear: float,
    total_ports: int
) -> float:
    """Estimate the outage probability for regression models.

    For each sample, the receiver uses the best true value among:

    * every observed port (their true values are known), and
    * the single port the model predicts to be strongest (argmax of the sample's
      row in ``predicted_values``), evaluated at its true value.

    A sample is *not* in outage when this selected value is strictly greater than
    ``threshold / snr_linear``. The outage probability is the fraction of samples
    in outage.

    Args:
        observed_indices (np.ndarray): Indices of the observed ports (channels).
        predicted_values (np.ndarray): 2D matrix of predicted values, one row per sample.
        true_values (np.ndarray): 2D matrix of ground-truth values, one row per sample
            and one column per port.
        threshold (float): Threshold value for determining outage.
        snr_linear (float): Signal-to-noise ratio in linear scale. The threshold is
            divided by it before the comparison.
        total_ports (int): Total number of ports. Kept for API compatibility; the
            port axis of ``true_values`` is indexed directly.

    Returns:
        float: Estimated outage probability.
    """
    # Compare in float64 (as the previous matrix-based implementation did), so
    # float16/float32 inputs do not change the threshold decision.
    true_values = np.asarray(true_values, dtype=np.float64)
    rows = np.arange(true_values.shape[0])

    # Best true value over the observed ports; -inf when no port is observed.
    best_observed = np.max(true_values[:, observed_indices], axis=1, initial=-np.inf)

    # True value at the port the model predicts to be strongest.
    best_predicted_ports = np.argmax(predicted_values, axis=1)
    predicted_port_values = true_values[rows, best_predicted_ports]

    # The receiver switches to whichever of the two is larger.
    selected_values = np.maximum(best_observed, predicted_port_values)

    # Determine which selected values are above the given threshold
    above_threshold = selected_values > (threshold / snr_linear)

    # Outage probability: share of samples whose selected value is not above the threshold
    outage_probability = 1.0 - (np.sum(above_threshold) / len(true_values))

    return outage_probability


## 5. Layer Builders

### 5.1. CNN

In [None]:
def build_cnn1d(
    trial: optuna.Trial,
    x: layers.Layer,
    num_layers: int = 5,
    max_filters: int = 256,
    min_filters: int = 32,
    filter_step: int = 32,
    max_kernel_size: int = 10,
    min_pool_size: int = 2,
    max_pool_size: int = 2,
    use_batch_norm: bool = False,
    use_regularization: bool = False,
    residual_method: Optional[str] = None,
    custom_name: str = "cnn1d",
) -> layers.Layer:
    """
    Builds a 1D CNN where Optuna picks filters, kernel sizes, and pooling window per layer.

    Each block is: Conv1D ("same" padding) -> optional BatchNormalization ->
    optional residual Add -> MaxPooling1D.

    Residual sources are captured *before* a block's pooling. A skip tensor from
    block ``j`` is therefore longer than the convolution output of a later block
    ``k`` whenever a pool size in blocks ``j .. k-1`` is greater than 1, and adding
    the two directly would fail on mismatched shapes. Such skip tensors are first
    max-pooled by the product of those pool sizes. MaxPooling1D uses its default
    stride and "valid" padding, so the lengths match exactly. If channel counts
    differ, the skip tensor is also projected with a 1x1 Conv1D. When every pool
    size is 1, no extra pooling layers are created.

    Args:
        trial (optuna.Trial): Optuna trial object.
        x (Layer): Input Keras tensor.
        num_layers (int): Number of Conv1D + Pool1D blocks.
        max_filters (int): Upper bound on number of filters.
        min_filters (int): Lower bound on number of filters.
        filter_step (int): Step size when sampling number of filters.
        max_kernel_size (int): Maximum size of the 1D convolution kernel.
        min_pool_size (int): Minimum size of the 1D pooling window.
        max_pool_size (int): Maximum size of the 1D pooling window.
        use_batch_norm (bool): If True, trial may enable BatchNormalization per layer.
        use_regularization (bool): If True, trial picks kernel, bias, and activity regularizers.
        residual_method (Optional[str]): One of {None, "beside", "all"} to control skip connections.
        custom_name (str): Prefix for naming all layers.

    Returns:
        Layer: Output tensor after applying all CNN blocks.
    """
    import math  # math.prod: combined pooling factor between two blocks

    # Placeholders for residual connection strategies
    beside_residual: Optional[layers.Layer] = None
    all_skip_connections: List[layers.Layer] = []
    # Pool size sampled for each block so far; used to shrink older skip tensors.
    pool_sizes: List[int] = []

    for layer_idx in range(num_layers):
        # 1) Sample pooling window size
        pool_size = trial.suggest_int(
            f"{custom_name}_pool_size_{layer_idx}", min_pool_size, max_pool_size
        )
        pool_sizes.append(pool_size)

        # 2) Sample number of filters
        num_filters = trial.suggest_int(
            f"{custom_name}_filters_{layer_idx}",
            min_filters,
            max_filters,
            step=filter_step,
        )

        # 3) Sample convolution kernel size
        kernel_size = trial.suggest_int(
            f"{custom_name}_kernel_size_{layer_idx}", 1, max_kernel_size
        )

        # 4) Activation function
        activation_fn = get_activation(trial, f"{custom_name}_activation_{layer_idx}")

        # 5) Regularizers (if enabled)
        kernel_reg = (
            get_regularizer(trial, f"{custom_name}_kernel_regularizer_{layer_idx}")
            if use_regularization
            else None
        )
        bias_reg = (
            get_regularizer(trial, f"{custom_name}_bias_regularizer_{layer_idx}")
            if use_regularization
            else None
        )
        activity_reg = (
            get_regularizer(trial, f"{custom_name}_activity_regularizer_{layer_idx}")
            if use_regularization
            else None
        )

        # 6) 1D convolution ("same" padding keeps the sequence length)
        x = layers.Conv1D(
            filters=num_filters,
            kernel_size=kernel_size,
            activation=activation_fn,
            padding="same",
            name=f"{custom_name}_conv1d_{layer_idx}",
            kernel_regularizer=kernel_reg,
            bias_regularizer=bias_reg,
            activity_regularizer=activity_reg,
        )(x)

        # 7) Optional BatchNormalization
        if use_batch_norm and trial.suggest_categorical(
            f"{custom_name}_use_batch_norm_{layer_idx}", [True, False]
        ):
            x = layers.BatchNormalization(name=f"{custom_name}_batch_norm_{layer_idx}")(x)

        # 8) Residual connections
        if residual_method == "beside":
            if layer_idx > 0 and trial.suggest_categorical(
                f"{custom_name}_use_residual_{layer_idx}", [True, False]
            ):
                prev = beside_residual
                # beside_residual was captured before the previous block's pooling,
                # so shrink it by that pool size to match x's length.
                shrink = pool_sizes[layer_idx - 1]
                if shrink > 1:
                    prev = layers.MaxPooling1D(
                        pool_size=shrink,
                        name=f"{custom_name}_res_pool_{layer_idx}",
                    )(prev)
                target_ch = x.shape[-1]
                # Align channel dimensions if needed
                if prev.shape[-1] != target_ch:
                    prev = layers.Conv1D(
                        filters=target_ch,
                        kernel_size=1,
                        padding="same",
                        name=f"{custom_name}_res_align_{layer_idx}",
                    )(prev)
                x = layers.Add(name=f"{custom_name}_res_add_{layer_idx}")([x, prev])
            # This block's (pre-pooling) output is the next block's skip source.
            beside_residual = x

        elif residual_method == "all":
            # all_skip_connections[j] holds block j's pre-pooling output.
            to_add: List[layers.Layer] = []
            for prev_idx, prev_layer in enumerate(all_skip_connections):
                if trial.suggest_categorical(
                    f"{custom_name}_use_residual_{layer_idx}_{prev_idx}", [True, False]
                ):
                    prev = prev_layer
                    # Blocks prev_idx .. layer_idx-1 each pooled x since prev was
                    # captured; apply the combined factor so the lengths match.
                    shrink = math.prod(pool_sizes[prev_idx:layer_idx])
                    if shrink > 1:
                        prev = layers.MaxPooling1D(
                            pool_size=shrink,
                            name=f"{custom_name}_skip_res_pool_{layer_idx}_{prev_idx}",
                        )(prev)
                    target_ch = x.shape[-1]
                    if prev.shape[-1] != target_ch:
                        prev = layers.Conv1D(
                            filters=target_ch,
                            kernel_size=1,
                            padding="same",
                            name=f"{custom_name}_skip_res_conv1d_{layer_idx}_{prev_idx}",
                        )(prev)
                    to_add.append(prev)
            if to_add:
                x = layers.Add(
                    name=f"{custom_name}_res_all_add_{layer_idx}"
                )([x] + to_add)
            all_skip_connections.append(x)

        # 9) 1D MaxPooling
        x = layers.MaxPooling1D(
            pool_size=pool_size, name=f"{custom_name}_maxpool_{layer_idx}"
        )(x)

    return x

## 6. Objective Function

In [None]:
def objective(
    trial: optuna.Trial,
    X: List[np.ndarray],
    y: List[np.ndarray],
    checkpoint_dir: str,
    model_dir: str,
    fig_dir: str,
    logs_dir: str,
    epochs: int = 50,
    size_penalizer: Optional[str] = None,
    use_regularization: bool = False,
    residual_method: Optional[str] = None,
    show_summary: bool = False,
    plot_model: bool = False,
) -> float:
    """
    Objective function for Optuna to optimize a Neural Network NN on any-input data.

    Args:
        trial (optuna.Trial): Current trial for hyperparameter suggestions.
        X (List[np.ndarray]): List of input arrays.
        y (List[np.ndarray]): List of label arrays.
        checkpoint_dir (str): Path to store checkpoint files.
        model_dir (str): Path to store full models.
        fig_dir (str): Path to store plots.
        logs_dir (str): Path to store logs.
        epochs (int): Number of training epochs.
        size_penalizer (Optional[str]): type of penalizer to use:
            - "params": Penalizes based on the number of parameters.
            - "flops": Penalizes based on the number of FLOPs.
            - None: No penalization is applied.
        use_regularization (bool): If True, adds regularization (e.g., L1/L2) to layers to prevent overfitting.
        residual_method (Optional[str]): tyoe of residual connection to use:
            - "beside": Adds residual connections between consecutive layers.
            - "all": test residual connections between all layers.
            - None: No residual connections are applied.
        show_summary (bool): If True, display the model summary.
        plot_model (bool): If True, display a plot of the model architecture.

    Returns:
        float: Final validation loss (optionally penalized) used for optimization.
    """
    
    # ————————————————————————————— Prepare the Data ————————————————————————————— #
    X_train, X_val = X[0], X[1]
    y_train, y_val = y[0], y[1]

    # ———————————————————————————————————————————————————————————————————————————— #

    model = None
    try:

        # ———————————————————————————————————————————————————————————————————————————— #
        #                              Model Construction                              #
        # ———————————————————————————————————————————————————————————————————————————— #

        # —————————————————————————————————— Scaler —————————————————————————————————— #
        scaler = get_scaler(trial)
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        
        # ———————————————————————————— Reshape for 1D CNN ———————————————————————————— #
        # (num_samples, observed_ports) -> (num_samples, observed_ports, 1)
        X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
        X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)

        # ——————————————————————————— Observed ports input ——————————————————————————— #
        inputs = layers.Input(shape=(X_train.shape[1], 1))  # Observed ports as input

        max_layers = trial.suggest_int("num_layers", 1, 4)

        # Calculate max pool size
        max_pool_dim1 = math.floor(X_train.shape[1] ** (1.0 / max_layers))

        x = build_cnn1d(
            trial,
            inputs,
            num_layers=max_layers,
            max_filters=512,
            min_filters=32,
            filter_step=32,
            max_kernel_size=5,
            min_pool_size=1,
            max_pool_size=max_pool_dim1,
            use_batch_norm=True,
            use_regularization=use_regularization,
            residual_method=residual_method,
        )

        # ———————————————————————————— Flatten the Output ———————————————————————————— #
        x = layers.Flatten(name="flatten")(x)

        # ——————————————————————————————— Dense Layers ——————————————————————————————— #
        #? This was the best performing option in the previous trials
        num_dense_layers = trial.suggest_int("num_dense_layers", 0, 3)
        for i in range(num_dense_layers):
            # Suggest the number of units for each dense layer
            units = trial.suggest_int(f"dense_{i+1}_units", 64, 512, step=64)
            x = layers.Dense(
                units=units,
                activation=get_activation(trial, f"dense_{i+1}_activation"),
                name=f"dense_{i+1}",
            )(x)
            rate = trial.suggest_float(f"dense_{i+1}_dropout", 0.0, 0.5, step=0.1)
            x = layers.Dropout(rate=rate)(x)

        # —————————————————————————————————— Output —————————————————————————————————— #
        outputs = layers.Dense(TOTAL_NUM_PORTS, activation="linear")(x)

        # —————————————————————————— Set Inputs and Outputs —————————————————————————— #
        model = Model(inputs=inputs, outputs=(outputs,))

        # ———————————————————————————— Vizualize the Model ——————————————————————————— #
        if show_summary:
            model.summary()

        if plot_model:
            # Display the model architecture image
            tf.keras.utils.plot_model(
                model,
                to_file=os.path.join(fig_dir, f"model_plot_{trial.number}.png"),
                show_shapes=True,
                show_layer_names=True,
            )
            display(Image(filename=os.path.join(fig_dir, f"model_plot_{trial.number}.png")))

        # ————————————————————————————— Compile the Model ———————————————————————————— #
        optimizer = get_optimizer(trial)
        model.compile(
            optimizer=optimizer,
            loss="mse",
        )

        # ———————————————————————————————— Train Model ——————————————————————————————— #
        batch_size = trial.suggest_categorical("batch_size", [64, 128, 256])
        history = model.fit(
            X_train,
            y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=get_callbacks(trial, checkpoint_dir),
            verbose=2,
        )

        model.save(os.path.join(model_dir, f"trial_{trial.number}.keras"))
        loss = min(history.history["val_loss"])

        # ———————————————————————————————————————————————————————————————————————————— #
        #                                 Trial Results                                #
        # ———————————————————————————————————————————————————————————————————————————— #
        clear_output(wait=True)

        epochs = list(range(1, len(history.history["loss"]) + 1))
        train_loss = history.history["loss"]
        val_loss = history.history["val_loss"]
    
        # Create figure
        fig, ax_loss = plt.subplots(figsize=(8, 6))

        # Plot Loss
        ax_loss.plot(epochs, train_loss, marker="o", linestyle="-", label="Training Loss")
        ax_loss.plot(epochs, val_loss, marker="x", linestyle="--", label="Validation Loss")
        ax_loss.set_title("Training & Validation Loss")
        ax_loss.set_xlabel("Epoch")
        ax_loss.set_ylabel("Loss")
        ax_loss.set_xticks(epochs)
        ax_loss.set_ylim(0, max(max(train_loss), max(val_loss)) * 1.05)
        ax_loss.grid(True)
        ax_loss.legend()

        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, f"trial_{trial.number}.png"), dpi=300)
        plt.close(fig)

        # ————————————————————————————————— Evaluate ————————————————————————————————— #
        datasets = [
            kappa0_mu1_m0_test,
            kappa0_mu1_m2_test,
            kappa0_mu1_m50_test,
            kappa0_mu2_m50_test,
            kappa0_mu5_m50_test,
            kappa5_mu1_m0_test,
            kappa5_mu1_m2_test,
            kappa5_mu1_m50_test,
            kappa5_mu2_m0_test,
            kappa5_mu2_m2_test,
            kappa5_mu2_m50_test,
            kappa5_mu5_m0_test,
            kappa5_mu5_m2_test,
            kappa5_mu5_m50_test,
        ]
        test_losses = []
        ops = []

        n_ports = X_train.shape[1]

        for i, dataset in enumerate(datasets, start=1):
            observed_ports, observed_indices = get_observed_ports(
                dataset, num_observed_ports=n_ports, total_ports=TOTAL_NUM_PORTS
            )
            
            X_test, y_test = observed_ports, dataset
            X_test = scaler.transform(X_test)
            
            # Reshape for 1D CNN
            X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)
            
            test_loss = model.evaluate(X_test, y_test, batch_size=batch_size, verbose=0)
            test_losses.append(test_loss)
            
            # Make predictions
            y_pred = model.predict(X_test)

            # Calculate Outage Probability (OP)
            op_value = getOP(
                observed_indices=observed_indices,
                predicted_values=y_pred,
                true_values=y_test,
                threshold=THRESHOLD,
                snr_linear=SNR_LINEAR,
                total_ports=TOTAL_NUM_PORTS,
            )
            ops.append(op_value)

        dataset_names: list[str] = [
            "kappa0_mu1_m0_test",
            "kappa0_mu1_m2_test",
            "kappa0_mu1_m50_test",
            "kappa0_mu2_m50_test",
            "kappa0_mu5_m50_test",
            "kappa5_mu1_m0_test",
            "kappa5_mu1_m2_test",
            "kappa5_mu1_m50_test",
            "kappa5_mu2_m0_test",
            "kappa5_mu2_m2_test",
            "kappa5_mu2_m50_test",
            "kappa5_mu5_m0_test",
            "kappa5_mu5_m2_test",
            "kappa5_mu5_m50_test",
        ]

        dataset_test_losses = {
            f"{name}_loss": loss
            for name, loss in zip(dataset_names, test_losses)
        }
        dataset_test_ops = {
            f"{name}_op": op
            for name, op in zip(dataset_names, ops)
        }

        # Print test losses in a formatted manner
        print("\n" + "=" * 15)
        for name in dataset_names:
            loss = dataset_test_losses[f"{name}_loss"]
            op   = dataset_test_ops[f"{name}_op"]
            print(
                f"{name.replace('_', ' ').capitalize()}: "
                f"Loss = {loss:.12f}, OP = {op:.12f}"
            )
            

        params = model.count_params()
        print(f"\nNumber of parameters: {params}")
        print(f"Model size: {params * 4 / (1024 ** 2):.2f} MB")
        print("=" * 15 + "\n")
        
        trial.set_user_attr("num_params", params)
        trial.set_user_attr("model_size_mb", params * 4 / (1024 ** 2))

        return loss

    except optuna.exceptions.TrialPruned:
        raise  # simply propagate pruning
    except tf.errors.ResourceExhaustedError as oom_err:
        # Catch OOM / resource exhausted
        print(f"❌ Trial {trial.number} hit OOM (ResourceExhaustedError): {oom_err}")

        # Log the error to a file in the logs directory
        error_log_path = os.path.join(logs_dir, f"trial_{trial.number}_error.log")
        with open(error_log_path, "w") as log_file:
            log_file.write(f"Trial {trial.number} encountered an error:\n")
            log_file.write(str(oom_err) + "\n\n")
            log_file.write("Traceback:\n")
            traceback.print_exc(file=log_file)

        return float("inf")  # Return bad loss
    except Exception as e:
        print(f"An error occurred during the trial execution: {e}")
        traceback.print_exc()

        # Log the error to a file in the logs directory
        error_log_path = os.path.join(logs_dir, f"trial_{trial.number}_error.log")
        with open(error_log_path, "w") as log_file:
            log_file.write(f"Trial {trial.number} encountered an error:\n")
            log_file.write(str(e) + "\n\n")
            log_file.write("Traceback:\n")
            traceback.print_exc(file=log_file)

        return float("inf")  # Return bad loss
    finally:
        if model is not None:
            clear_session()
            del model

## 7. Code Health Check

In [None]:
# resources_dir = os.path.join(RUN_DIR, "resources")
# os.makedirs(resources_dir, exist_ok=True)
# troo.log_resources(log_dir=resources_dir)

In [None]:
import shlex

# Bind the handle before anything can fail so the cleanup cell at the end of
# the notebook can always test it, even when no terminal could be launched.
_monitor_proc = None
pid = os.getpid()
try:
    # Quote every interpolated value: the command is re-parsed by `bash -c`
    # (and, for xfce4-terminal, by the emulator itself), so paths or titles
    # containing spaces or shell metacharacters would otherwise be split.
    script_path = shlex.quote(os.path.abspath("_monitor_kernel_life.py"))
    run_title = shlex.quote(str(RUN_DIR))
    cmd = f"python3 {script_path} --pid {pid} --custom-title {run_title}; exec bash"
    # Candidate emulators in order of preference. xfce4-terminal's `-e` takes a
    # single command string, so `bash -c` and its quoted argument go into one
    # string; the others accept the program and its arguments as separate argv
    # entries after `--` / `-e`.
    terminals = [
        [
            "xfce4-terminal",
            "--disable-server",
            "--hold",
            "-e",
            f"bash -c {shlex.quote(cmd)}",
        ],
        ["gnome-terminal", "--disable-factory", "--", "bash", "-i", "-c", cmd],
        ["xterm", "-hold", "-e", "bash", "-c", cmd],
        ["konsole", "--hold", "-e", "bash", "-c", cmd],
    ]
    term = next((t for t in terminals if shutil.which(t[0])), None)
    if not term:
        raise RuntimeError(
            "No supported terminal emulator found; install gnome-terminal, "
            "xfce4-terminal, konsole, or xterm."
        )
    # start_new_session=True makes the child a session and process-group
    # leader (pgid == pid), which the cleanup cell relies on for os.killpg.
    # Unlike preexec_fn=os.setpgrp, it is safe in a multithreaded process
    # such as a kernel that has already imported TensorFlow.
    _monitor_proc = subprocess.Popen(term, start_new_session=True)
    print(f"[INFO] Launched monitor in {term[0]} (PID={pid})")
except Exception as e:
    # Best-effort: the monitor is optional, so report the failure and show
    # the command the user can run by hand instead.
    print(f"[ERROR] Auto launching kernel monitoring failed! {e}\n")
    display(
        HTML(
            f"Call the monitor script manually: "
            f'<span style="color: orange;">'
            f"python _monitor_kernel_life.py --pid {pid} --custom-title {shlex.quote(str(RUN_DIR))}"
            f"</span>"
        )
    )

## 8. Main

In [None]:
for n in observed_ports_list:
    try:
        # ------------------------------------------------------------------ #
        # One Optuna study (and one SQLite database) per observed-port count #
        # ------------------------------------------------------------------ #
        study_dir = os.path.join(RUN_DIR, f"optuna_study_{n}_ports")
        os.makedirs(study_dir, exist_ok=True)

        # Artefact sub-directories, created in a fixed order.
        dirs = {
            sub: os.path.join(study_dir, sub)
            for sub in ("args", "figures", "weights", "models", "logs")
        }
        for path in dirs.values():
            os.makedirs(path, exist_ok=True)

        db_file = os.path.join(study_dir, "optuna_study.db")
        storage_path = f"sqlite:///{db_file}"
        args_dir = dirs["args"]
        fig_dir = dirs["figures"]
        checkpoint_dir = dirs["weights"]
        model_dir = dirs["models"]
        logs_dir = dirs["logs"]

        print(f"Initializing study at '{study_dir}'...")

        # Select the observed ports and hold out 20% for validation
        # (fixed seed, so the split is identical every time the cell runs).
        observed_ports, observed_indices = get_observed_ports(
            dataset, num_observed_ports=n, total_ports=TOTAL_NUM_PORTS
        )
        X_train, X_val, y_train, y_val = train_test_split(
            observed_ports, dataset, test_size=0.2, random_state=0, shuffle=True
        )

        # Resume the study if its database already exists.
        pruner = optuna.pruners.HyperbandPruner()
        study = optuna.create_study(
            study_name=os.path.basename(study_dir),
            storage=storage_path,
            direction="minimize",
            pruner=pruner,
            load_if_exists=True,
        )

        # Only top the study up to NUM_TRIALS: trials that completed, were
        # pruned or failed in earlier sessions all count toward the budget.
        finished_states = (
            optuna.trial.TrialState.COMPLETE,
            optuna.trial.TrialState.PRUNED,
            optuna.trial.TrialState.FAIL,
        )
        done_trials = len(study.get_trials(deepcopy=False, states=finished_states))
        n_remaining_trials = max(0, NUM_TRIALS - done_trials)

        def _run_trial(trial):
            """Run ``objective`` on this port count's data and output directories."""
            return objective(
                trial,
                X=[X_train, X_val],
                y=[y_train, y_val],
                checkpoint_dir=checkpoint_dir,
                model_dir=model_dir,
                fig_dir=fig_dir,
                logs_dir=logs_dir,
                epochs=EPOCHS,
                size_penalizer=None,
                use_regularization=False,
                residual_method=None,  #! Find your backbone first
                show_summary=False,
            )

        study.optimize(
            _run_trial,
            n_trials=n_remaining_trials,
            catch=(ValueError, RuntimeError),
            gc_after_trial=True,
            n_jobs=1,  # If you have multiple GPUs/Cores
            show_progress_bar=False,
        )
    except Exception as e:
        # A failure for one port count should not abort the remaining studies.
        print(f"An error occurred: {e}")
        traceback.print_exc()

In [None]:
# Kill the monitor kernel life process.
# Look the handle up defensively: `_monitor_proc` is only bound if the
# monitor-launch cell ran and got as far as creating the process, so a bare
# reference could raise NameError here.
_monitor_proc = globals().get("_monitor_proc")
if _monitor_proc is not None and _monitor_proc.poll() is None:
    # The monitor was started as its own process-group leader, so signalling
    # the group reaches the terminal emulator and everything it spawned.
    try:
        os.killpg(_monitor_proc.pid, signal.SIGINT)
    except ProcessLookupError:
        # It exited between poll() and killpg(); nothing left to stop.
        pass