# {TITLE}

- **Author:** {AUTHOR} 
- **Email:** {EMAIL}
- **Date:** {DATE}

{DESCRIPTION}

### Tips

1.
    If the distribution is highly concentrated, with most samples falling within a narrow range,
    the data are unevenly spread.

    When generating a subset of the dataset, only a small number of samples with higher values 
    are included, resulting in a skewed distribution and causing outliers in the test set.
    Outliers are data points that differ significantly from other observations and can distort 
    statistical analyses, leading to a high Mean Squared Error (MSE) when their values are far 
    from the training set's maximum values.

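    To address this issue, try normalizing the data before splitting the dataset into training 
    and testing subsets. Normalizing beforehand can help ensure a more balanced distribution, 
    reducing the impact of outliers. Be aware, however, that fitting the scaler on the full
    dataset leaks test-set statistics into training; when the goal is an unbiased estimate of
    generalization, fit the scaler on the training split only and reuse it for the test split.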

    Ideally, creating a larger dataset with more diverse samples and then applying proper 
    preprocessing would help achieve a more representative distribution, minimizing outliers 
    and improving overall model performance.

2.
    Inspect every variable and dataset (shape, value range, scaling) before using it for training,
    and make sure each one is computed correctly.

    Otherwise, the model will not train properly.

## 1. Setup and Configuration

### 1.1. Imports

In [1]:
import os
import logging
import datetime
import numpy as np
import matplotlib.pyplot as plt
import scipy.io

# Silence TensorFlow's C++ logging. This must be set *before* `import tensorflow`:
# setting it afterwards cannot suppress the messages the native runtime emits
# while the library is being loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.backend import clear_session

from keras import layers, Input, Model
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split

# Silence TensorFlow's Python-side logger as well.
tf.get_logger().setLevel(logging.ERROR)

### 1.2. GPU Configuration

In [5]:
def get_gpu_info():
    """
    Print TensorFlow build information and details for every visible GPU.

    Reports the TensorFlow version and whether TensorFlow was built with CUDA
    (and, if so, the CUDA/cuDNN versions it was built against). For each
    detected GPU it prints the device name and the current/peak memory
    allocated by TensorFlow on that device.

    Note: querying memory usage and `tf.test.gpu_device_name()` initialise the
    TensorFlow runtime, after which physical-device settings such as memory
    growth can no longer be changed. Call `enable_memory_growth()` first.
    """
    # Display TensorFlow version
    print(f"TensorFlow Version: {tf.__version__}")

    # Check if TensorFlow is built with CUDA support and retrieve build info
    if tf.test.is_built_with_cuda():
        build_info = tf.sysconfig.get_build_info()
        print("TensorFlow is built with CUDA support")
        print(f"CUDA Version: {build_info.get('cuda_version', 'Unknown')}")
        print(f"cuDNN Version: {build_info.get('cudnn_version', 'Unknown')}")
    else:
        print("Running on CPU (No CUDA support detected)")

    # Detect available GPUs
    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        print("No GPUs found")
        print("Running on CPU")
        return

    print(f"Number of GPUs detected: {len(gpus)}")
    print(f"Available GPU(s): {[gpu.name for gpu in gpus]}")
    for i, gpu in enumerate(gpus):
        details = tf.config.experimental.get_device_details(gpu)
        print(f"GPU {i} Details:")
        print(f"  Name: {details.get('device_name', 'Unknown')}")

        # get_memory_info expects a *logical* device string such as "GPU:0"
        # (not the "/physical_device:GPU:0" name) and reports the 'current' and
        # 'peak' bytes allocated by TensorFlow; it does not report total/free
        # memory. Assumes one logical device per physical GPU (the default).
        try:
            memory_info = tf.config.experimental.get_memory_info(f"GPU:{i}")
        except ValueError as e:
            print(f"  Memory info unavailable: {e}")
        else:
            print(f"  Current Memory Usage: {memory_info.get('current', 'Unknown')} bytes")
            print(f"  Peak Memory Usage: {memory_info.get('peak', 'Unknown')} bytes")

    print(f"Using GPU: {tf.test.gpu_device_name()}")


def enable_memory_growth():
    """
    Enable on-demand memory allocation for every detected GPU.

    Also restricts CUDA to GPU 0 via CUDA_VISIBLE_DEVICES. Both settings only
    take effect if applied before TensorFlow initialises its GPU devices, so
    this must run before anything that creates logical devices (e.g.
    get_gpu_info(), building a model, or executing an op). NOTE(review): CUDA
    reads CUDA_VISIBLE_DEVICES when the driver is initialised; to be certain
    it applies, set it before importing TensorFlow.
    """
    # Specify GPU to use (e.g., GPU 0)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        print("Error enabling memory growth: No GPUs found")
        return

    for i, gpu in enumerate(gpus):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
            print(f"Memory growth enabled for GPU {i}: {gpu.name}")
        except RuntimeError as e:
            # Raised when the runtime has already initialised this device.
            print(f"Error enabling memory growth for GPU {i}: {gpu.name}, {e}")


# Configure the GPUs first: memory growth cannot be changed once the runtime has
# initialised the devices, which get_gpu_info() does when it queries memory usage.
enable_memory_growth()
get_gpu_info()

TensorFlow Version: 2.13.0
Running on CPU (No CUDA support detected)
No GPUs found
Running on CPU
Error enabling memory growth: No GPUs found


## 2. Constants and Hyperparameters

In [None]:
# Training Parameters
TOTAL_NUM_PORTS = 144  # Total number of ports (size of the model output)
BATCH_SIZE = 32  # Batch size for training
EPOCHS = 20  # Number of epochs for training

# Threshold in dB for outage probability
OUTAGE_THRESHOLD = 20

# Scaling Method. clip=True clamps transformed values to the fitted feature
# range (default (0, 1)) when new data falls outside the fitted min/max.
SCALER = MinMaxScaler(clip=True)

# Checkpoint and Log Directories
CHECKPOINT_DIR = "dnn_checkpoints"
# exist_ok avoids the check-then-create race and still fails loudly if a
# regular file already occupies the path.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Timestamped (at the time this cell runs) so separate runs do not share a log directory.
LOG_DIR = os.path.join("dnn_logs", "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# Set random seeds for reproducibility (NumPy and TensorFlow global RNGs)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)


## 3. Data Loading and Preprocessing

In [None]:
def load_sinr_data(filepath, key="gamma_k"):
    """
    Read the SINR matrix stored under ``key`` in a MATLAB .mat file.

    Args:
        filepath (str): Location of the .mat file on disk.
        key (str, optional): Name of the MATLAB variable holding the SINR matrix.
                             Defaults to "gamma_k".

    Returns:
        numpy.ndarray: The array stored under ``key``, exactly as returned by
        ``scipy.io.loadmat`` (MATLAB matrices come back as 2-D arrays).

    Raises:
        KeyError: If the file contains no variable named ``key``.
    """
    return scipy.io.loadmat(filepath)[key]

In [None]:
# Load the SINR matrix. Downstream code treats rows as samples and columns as ports.
sinr_data = load_sinr_data('data/Rayleigh/SNR_ports.mat')

# Keep the leading fraction of rows; a factor of 1.0 keeps every sample.
sinr_data = sinr_data[: int(1.0 * len(sinr_data))]

# Linear SINR -> dB (non-positive entries would become -inf / NaN).
# NOTE: the dB scale compresses the dynamic range, which may make training harder.
sinr_data = np.log10(sinr_data) * 10

print("Shape of the data: ", sinr_data.shape)

# axis=0 averages over the rows, so despite its name this holds one mean per
# port (column). Averaging those per-port means gives the overall mean.
mean_sinr_per_sample = sinr_data.mean(axis=0)
mean_of_means = mean_sinr_per_sample.mean()
print(f"Mean of the means of SINR values: {mean_of_means}")


In [None]:
# Plot one histogram of all SINR values (dB), pooled across samples and ports.
# Passing the 2-D array directly would make Matplotlib treat every column as a
# separate dataset (one histogram per port), so flatten it first.
fig, axs = plt.subplots(1, 1, tight_layout=True)
axs.hist(sinr_data.ravel(), bins=100, density=True)
axs.set_xlabel("SINR (dB)")
axs.set_ylabel("Density")
plt.show()

## 4. Model Definition

In [None]:
def build_model(input_shape: int, num_ports=None) -> Model:
    """
    Build the fully connected SINR-reconstruction network.

    Layer widths and activations are the hard-coded best parameters from an
    Optuna optimization: Dense(491, relu) -> Dense(497, relu) ->
    Dense(476, tanh) -> Dense(num_ports, linear).

    Args:
        input_shape (int): The number of observed ports (input vector width).
        num_ports (int, optional): Number of output units, one predicted SINR
            per port. Defaults to TOTAL_NUM_PORTS.

    Returns:
        keras.Model: An *uncompiled* Keras model; call ``compile()`` before training.
    """
    # Resolved at call time so the module-level constant is read when the model is built.
    if num_ports is None:
        num_ports = TOTAL_NUM_PORTS

    # ----------------------------- Input Layer ----------------------------- #
    inputs = Input(shape=(input_shape,))  # Observed ports as input

    # ---------------------------- Dense Layers ---------------------------- #
    x = layers.Dense(units=491, activation="relu")(inputs)  # Optimized units / activation
    x = layers.Dense(units=497, activation="relu")(x)  # Optimized units / activation
    x = layers.Dense(units=476, activation="tanh")(x)  # Optimized units / activation

    # ------------------------------- Output Layer ------------------------------- #
    # Linear head: one predicted SINR value per port
    outputs = layers.Dense(num_ports, activation="linear")(x)

    return Model(inputs=inputs, outputs=outputs)


## 5. Utility Functions

### 5.1. Outage Probability Functions

In [None]:
def compute_ideal_op(sinr_data, threshold):
    """
    Outage probability of the ideal receiver, which always picks the best port.

    A sample is in outage when even its largest SINR across all ports is
    strictly below ``threshold``.

    Args:
        sinr_data (np.array): SINR for every port, shape (num_samples, num_ports).
        threshold (float): Outage threshold, in the same units as ``sinr_data``.

    Returns:
        outage_prob (float): Fraction of samples in outage.
    """
    best_port_sinr = np.amax(sinr_data, axis=1)
    return np.less(best_port_sinr, threshold).mean()


def compute_reference_op(observed_sinr, threshold):
    """
    Outage probability of the reference receiver, limited to the observed ports.

    For each sample the strongest *observed* port is selected; the sample is in
    outage when that SINR is strictly below ``threshold``.

    Args:
        observed_sinr (np.array): SINR of the observed ports only,
            shape (num_samples, num_observed_ports).
        threshold (float): Outage threshold, in the same units as ``observed_sinr``.

    Returns:
        outage_prob (float): Fraction of samples in outage.
    """
    in_outage = np.less(np.amax(observed_sinr, axis=1), threshold)
    return in_outage.mean()


def compute_model_op(predicted_sinr, original_data, threshold):
    """
    Outage probability when the port is chosen by the model's predictions.

    For each sample the port with the highest *predicted* SINR is selected, and
    the *true* SINR of that port (from ``original_data``) is compared against
    ``threshold``. A sample is in outage when that true SINR is strictly below it.

    Args:
        predicted_sinr (np.array): Model outputs, shape (num_samples, num_ports).
        original_data (np.array): True SINR values, shape (num_samples, num_ports).
        threshold (float): Outage threshold, in the same units as ``original_data``.

    Returns:
        outage_prob (float): Fraction of samples in outage.
    """
    chosen_ports = np.argmax(predicted_sinr, axis=1)
    sample_rows = np.arange(len(original_data))
    achieved_sinr = original_data[sample_rows, chosen_ports]
    return np.less(achieved_sinr, threshold).mean()

### 5.2. Data Preparation

In [None]:
def get_observed_ports(sinr_data, num_obs_ports, num_ports=144):
    """
    Select ``num_obs_ports`` approximately equally spaced ports from the SINR matrix.

    The indices are ``np.linspace(0, num_ports - 1, num_obs_ports)`` truncated
    to int. Port 0 is therefore always observed, and port ``num_ports - 1`` is
    observed whenever ``num_obs_ports >= 2``.

    Args:
        sinr_data (numpy.ndarray): 2-D array; each row is an observation and each
                                   column a port's SINR value.
        num_obs_ports (int): The number of observed ports to select (1..num_ports).
        num_ports (int, optional): The total number of ports in the SINR data. Defaults to 144.

    Returns:
        observed_sinr (numpy.ndarray): The columns of ``sinr_data`` for the observed ports.
        observed_indices (numpy.ndarray): 1-D array of the selected port indices.

    Raises:
        ValueError: If ``num_obs_ports`` is outside [1, num_ports]. Requesting more
            ports than exist would otherwise silently select duplicate columns.
    """
    if not 1 <= num_obs_ports <= num_ports:
        raise ValueError(
            f"num_obs_ports must be between 1 and num_ports ({num_ports}), got {num_obs_ports}"
        )

    observed_indices = np.linspace(0, num_ports - 1, num_obs_ports, dtype=int)
    observed_sinr = sinr_data[:, observed_indices]

    return observed_sinr, observed_indices

### 5.3. Plotting

In [None]:
def plot_outage_probability(ideal_data, reference_data, model_data, ports_list, *, log_scale=False):
    """
    Plot outage probability versus number of observed ports for three strategies.

    The ideal (all ports), reference (observed ports only), and model-selected
    curves are drawn on the same axes.

    Args:
        ideal_data (list): Outage probabilities for the ideal SINR data.
        reference_data (list): Outage probabilities for the reference (observed ports) SINR data.
        model_data (list): Outage probabilities for the model-generated SINR data.
        ports_list (list): Observed-port counts used for the x-axis.
        log_scale (bool, optional): Use a logarithmic y-axis (capped at 1) instead
            of the default linear [0, 1] axis. Zero probabilities cannot be
            shown on a log axis and are omitted from the curves.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(ports_list, ideal_data, label="Ideal", color="blue", linestyle="--", marker="o")
    plt.plot(
        ports_list,
        reference_data,
        label="Reference (Observed Ports)",
        color="green",
        linestyle="-",
        marker="x",
    )
    plt.plot(ports_list, model_data, label="Model", color="orange", linestyle=":", marker="s")

    if log_scale:
        plt.yscale("log")
        plt.ylim(top=1)  # a lower bound of 0 is invalid on a log axis
    else:
        plt.ylim(0, 1)

    plt.title("Outage Probability vs Number of Observed Ports")
    plt.xlabel("Number of Observed Ports")
    plt.ylabel("Outage Probability")
    plt.legend()
    plt.grid(True, which="both", ls="--")
    plt.tight_layout()
    plt.show()


def plot_mse(mse_history, title="", ylim=(0, 1)):
    """
    Plot one MSE curve against epoch number on a linear y-axis.

    Args:
        mse_history (list): The MSE value for each epoch, in order.
        title (str, optional): Plot title.
        ylim (tuple or None, optional): y-axis limits, (0, 1) by default. Pass
            None to let Matplotlib autoscale, e.g. when the MSE is computed on
            unscaled targets and exceeds 1.
    """
    plt.figure(figsize=(10, 6))

    # Epochs are numbered from 1
    epochs_range = range(1, len(mse_history) + 1)
    plt.plot(epochs_range, mse_history, marker="o")

    if ylim is not None:
        plt.ylim(*ylim)

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Mean Squared Error (MSE)")

    # Integer ticks only (one per epoch)
    plt.xticks(ticks=epochs_range)

    # No legend: the single curve is unlabelled, and plt.legend() would only
    # emit a "No artists with labels found" warning.
    plt.grid(True, which="both", ls="--")
    plt.tight_layout()
    plt.show()


def plot_all_mse_over_epochs(mse_history_list, observed_ports_list, title="", ylim=(0, 1)):
    """
    Plot the MSE-per-epoch curves of several models on one linear-scale figure.

    Each curve is labelled with its observed-port count; histories are paired
    with ``observed_ports_list`` by position.

    Args:
        mse_history_list (list of lists): MSE values per epoch, one list per model.
        observed_ports_list (list of int): Observed-port count for each model.
        title (str, optional): Plot title.
        ylim (tuple or None, optional): y-axis limits, (0, 1) by default; None autoscales.

    Raises:
        ValueError: If ``observed_ports_list`` has fewer entries than ``mse_history_list``.
    """
    if len(observed_ports_list) < len(mse_history_list):
        raise ValueError(
            "observed_ports_list must provide a port count for every MSE history "
            f"({len(observed_ports_list)} < {len(mse_history_list)})"
        )

    plt.figure(figsize=(10, 6))

    for mse_history, num_observed in zip(mse_history_list, observed_ports_list):
        epochs_range = range(1, len(mse_history) + 1)
        plt.plot(epochs_range, mse_history, label=f"{num_observed} observed ports", marker="o")

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Mean Squared Error (MSE)")

    # One tick per epoch of the longest run; default=0 keeps an empty input from crashing.
    max_epochs = max((len(history) for history in mse_history_list), default=0)
    plt.xticks(ticks=range(1, max_epochs + 1))
    if ylim is not None:
        plt.ylim(*ylim)

    # Legend outside the axes so it does not cover the curves
    plt.legend(title="Observed Ports", loc="center left", bbox_to_anchor=(1, 0.5))
    plt.grid(True, which="both", ls="--")
    plt.tight_layout()

    plt.show()

## 6. Training

### 6.1. Training Initialization

In [None]:
# Observed-port configurations to train: 1, 1 + STEP, ... up to MAX_OBSERVED_PORTS.
# (An explicit upper bound keeps the last value <= MAX_OBSERVED_PORTS for any STEP.)
MAX_OBSERVED_PORTS = 20
STEP = 1
observed_ports_list = list(range(1, MAX_OBSERVED_PORTS + 1, STEP))

# Prepare the data: split first, then fit the scaler on the training split only,
# so no test-set statistics leak into preprocessing. The same fitted scaler is
# applied to both splits.
sinr_train, sinr_test = train_test_split(sinr_data, test_size=0.2, random_state=SEED)
SCALER.fit(sinr_train)
sinr_train_scaled = SCALER.transform(sinr_train)
sinr_test_scaled = SCALER.transform(sinr_test)

# Per-configuration results, filled in by the training loop
train_mse_history_list = []
val_mse_history_list = []
outage_prob_ideal_list = []
outage_prob_reference_list = []
outage_prob_model_list = []
best_model_paths = []

### 6.2. Training Models for Various Observed Ports

In [None]:
for n in observed_ports_list:
    print(f"Training model with {n} observed ports...")

    # ----------------------------- Data Preparation ----------------------------- #
    # Inputs: the n observed ports taken from the min-max scaled splits, matching
    # the scaled inputs the evaluation cell feeds the model.
    # Targets: the real (unscaled, dB) SINR of all ports.
    X_train, _ = get_observed_ports(sinr_train_scaled, n)
    X_test, observed_indices_test = get_observed_ports(sinr_test_scaled, n)
    y_train = sinr_train  # Real SINR values
    y_test = sinr_test  # Real SINR values

    # --------------------------------- Callbacks -------------------------------- #
    # One TensorBoard run per configuration; a shared log_dir would mix all runs.
    # View with `tensorboard --logdir dnn_logs/fit`, run from the folder that
    # contains the dnn_logs directory.
    tensorboard_callback = TensorBoard(
        log_dir=os.path.join(LOG_DIR, f"observed_ports_{n}"), histogram_freq=1
    )

    # Best weights (lowest validation MSE) for this configuration.
    best_model_path = os.path.join(CHECKPOINT_DIR, f"best_model_observed_ports_{n}.weights.h5")
    checkpoint_callback = ModelCheckpoint(
        filepath=best_model_path,
        save_weights_only=True,  # Save only the weights
        monitor="val_mse",  # Use validation MSE to monitor the best model
        mode="min",
        save_best_only=True,  # Save only the best model
        verbose=1,  # Verbosity level for saving process
    )

    # Configured but currently disabled: it is not passed to fit() below.
    early_stopping = EarlyStopping(
        monitor="val_mse",  # The metric to monitor
        mode="min",  # Lower validation MSE is better
        patience=3,  # Number of epochs with no improvement before stopping
        verbose=1,  # Verbosity level (0 = silent, 1 = report stopping)
        restore_best_weights=True,  # Restore the best weights at the end of training
    )

    # ------------------------- Build and Train the model ------------------------ #
    optimizer = AdamW(learning_rate=8.050158722378705e-05)  # Optimized AdamW settings

    model = build_model(n)
    model.compile(optimizer=optimizer, loss="mse", metrics=["mse"])

    # NOTE(review): the test split doubles as the validation set, so checkpoint
    # selection sees test data; consider a validation split carved from the
    # training data.
    history = model.fit(
        X_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(X_test, y_test),
        callbacks=[checkpoint_callback, tensorboard_callback],
        verbose=1,
    )

    # loss="mse", so the loss curves are the MSE curves.
    train_mse_history_list.append(history.history["loss"])
    val_mse_history_list.append(history.history["val_loss"])
    plot_mse(history.history["loss"], title=f"Training loss for model with {n} observed ports")
    plot_mse(history.history["val_loss"], title=f"Validation loss for model with {n} observed ports")

    # Remember where this configuration's best weights were saved, for later loading
    best_model_paths.append(best_model_path)

    clear_session()  # Clear up internal variables of tensorflow, a must for loop training

## 7. Results

### 7.1. Plot MSE over epochs

In [None]:
# Re-plot every model's training curve followed by its validation curve.
for idx, n in enumerate(observed_ports_list):
    for split, histories in (("Training", train_mse_history_list), ("Validation", val_mse_history_list)):
        plot_mse(histories[idx], title=f"{split} MSE for model with {n} observed ports")

### 7.2. Plot the Merged MSE over epochs

In [None]:
# Plot every STEP-th configuration on one figure. STEP = 1 keeps them all;
# e.g. STEP = 5 would plot every fifth configuration.
STEP = 1
subset_observed_ports_list = observed_ports_list[::STEP]

# Only the first len(observed_ports_list) histories correspond to the current
# configurations, so bound the slice to keep the two lists aligned.
num_configs = len(observed_ports_list)

subset_mse_history_list = train_mse_history_list[:num_configs:STEP]
plot_all_mse_over_epochs(
    subset_mse_history_list, subset_observed_ports_list, title="Training MSE over Epochs for All Models"
)

subset_mse_history_list = val_mse_history_list[:num_configs:STEP]
plot_all_mse_over_epochs(
    subset_mse_history_list, subset_observed_ports_list, title="Validation MSE over Epochs for All Models"
)

### 7.3. Load Best Models and Compute Outage

In [None]:
# Reset the result lists so re-running this cell does not append duplicates.
outage_prob_ideal_list = []
outage_prob_reference_list = []
outage_prob_model_list = []

# Weights are loaded from "weights/dnn" (not CHECKPOINT_DIR): copy the
# checkpoints saved by the training loop there, or change this directory.
best_model_paths = [
    os.path.join("weights/dnn", f"best_model_observed_ports_{n}.weights.h5")
    for n in observed_ports_list
]

# Recreate the split used for training (same data, same random_state).
sinr_train, sinr_test = train_test_split(sinr_data, test_size=0.2, random_state=SEED)

OUTAGE_THRESHOLD = 20

# Scale model inputs with statistics from the *training* split, fitted once.
# Refitting on each test subset would leak test statistics and make the inputs
# depend on the test batch. MinMaxScaler works per column, so scaling all ports
# and then selecting the observed columns is the same as scaling those columns.
SCALER.fit(sinr_train)
sinr_test_scaled_all_ports = SCALER.transform(sinr_test)

# Be careful with the number of test samples: too few and the estimates are
# inaccurate. Rule of thumb: at least 100 / ideal_op samples.
# The ideal case uses every port, so it does not depend on n: compute it once.
outage_prob_ideal = compute_ideal_op(sinr_test, threshold=OUTAGE_THRESHOLD)

# Load Models and Compute Outage
for n, best_model_path in zip(observed_ports_list, best_model_paths):
    print("Loading: ", best_model_path)

    # ----------------------------- Rebuild the model ---------------------------- #
    model = build_model(n)
    model.load_weights(best_model_path)  # best weights saved for this n

    # ------------------------------ Preparing data ------------------------------ #
    # Raw (dB) observed SINR for the reference case; scaled values as model input.
    observed_ports_test, observed_indices_test = get_observed_ports(sinr_test, n)
    X_test = sinr_test_scaled_all_ports[:, observed_indices_test]

    # ------------------------ Compute Outage Probability ------------------------ #
    # Ideal outage probability (all ports), repeated so every list aligns with n
    outage_prob_ideal_list.append(outage_prob_ideal)

    # Reference outage probability (best observed port)
    outage_prob_reference = compute_reference_op(observed_ports_test, threshold=OUTAGE_THRESHOLD)
    outage_prob_reference_list.append(outage_prob_reference)

    # Model outage probability: port chosen by prediction, judged on true SINR
    predicted_sinr = model.predict(X_test)
    outage_prob_model = compute_model_op(predicted_sinr, sinr_test, threshold=OUTAGE_THRESHOLD)
    outage_prob_model_list.append(outage_prob_model)

    clear_session()  # release the per-configuration model before building the next

# Plot the outage probabilities
plot_outage_probability(
    outage_prob_ideal_list, outage_prob_reference_list, outage_prob_model_list, observed_ports_list
)