# Information-Maximization Approach to Blind Separation and Blind Deconvolution

## Constants and Functions

### Imports / Libraries

In [None]:
import numpy as np
import cupy as cp
import pandas as pd
import matplotlib.pyplot as plt
import soundfile as sf
import csv
import os
import time as tick
from scipy.signal import firwin, freqz, lfilter

### Initialize and Load

In [None]:
# Load the input signals (update the paths if needed).
num_trials = 10
x_chirp = np.load("chirp.npy")
x_speech, fs_speech = sf.read("speech.WAV")
x_speech = x_speech.astype(np.float64)
if x_speech.ndim > 1:
    # Multi-channel file: average the channels, because the plant, noise and
    # deconvolution helpers in this notebook operate on 1-D signals.
    x_speech = x_speech.mean(axis=1)

# Experiment grid
noise_variances = [0.1, 0.3, 1.0]
M_candidates = [64, 128, 256, 512]               # deconvolution filter lengths (taps)
L_candidates = [4, 8, 16, 32, 64]                # number of past windows in the Parzen estimate
eta_candidates = [3**(-i) for i in range(2, 8)]  # learning rates 1/9 ... 1/2187
est_cdf_candidates = ['sigmoid', 'tanh']

print(fs_speech)
plt.plot(x_speech)

### Create Low-Pass Filter

In [None]:
fc = 1000       # Cutoff frequency (Hz)
numtaps = 501   # Number of FIR taps (filter order = numtaps - 1)

# Design the FIR low-pass filter at the speech sampling rate.
fir_coeff = firwin(numtaps, fc, fs=fs_speech, pass_zero=True)

# Frequency response: w is in rad/sample, h is the complex response.
w, h = freqz(fir_coeff, worN=8000)
freq_hz = w * fs_speech / (2 * np.pi)
# Floor |h| so exact stopband nulls do not produce log10(0) = -inf and a divide-by-zero warning.
gain_db = 20 * np.log10(np.maximum(np.abs(h), 1e-12))

# Plot Frequency Response
plt.figure(figsize=(8, 4))
plt.plot(freq_hz, gain_db, 'b')
plt.title('FIR Lowpass Filter Frequency Response')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Gain (dB)')
plt.grid()
plt.show()

### Helper Functions: Memoryless Monotonic Mappings (CDF Estimates) and Kernels

In [None]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator

def sigmoid_derivative(x):
    """Derivative of the logistic function: s * (1 - s) with s = sigmoid(x)."""
    s = sigmoid(x)
    return s * (1 - s)

def tanh(x):
    """Hyperbolic tangent, applied element-wise (thin wrapper over np.tanh)."""
    result = np.tanh(x)
    return result

def tanh_derivative(x):
    """Derivative of tanh: 1 - tanh(x)^2, applied element-wise."""
    t = np.tanh(x)
    return 1 - t**2

def gaussian(x):
    """Un-normalized, unit-width Gaussian kernel exp(-x^2), applied element-wise."""
    exponent = -(x ** 2)
    return np.exp(exponent)

def gaussian_derivative(x):
    """Derivative of exp(-x^2): -2 x exp(-x^2), applied element-wise."""
    slope = -2 * x
    return slope * np.exp(-(x ** 2))

### Signal Loading and Processing

In [None]:
def unknown_plant(x, fir_coeff=None):
    """
    Pass ``x`` through the unknown plant.

    Without ``fir_coeff`` the plant is the 10-sample moving sum
        H(z) = (1 - z^-10) / (1 - z^-1) = 1 + z^-1 + ... + z^-9,
    i.e. y[n] = x[n] + x[n-1] + ... + x[n-9], with samples before n = 0 taken as zero.
    It is evaluated as that FIR filter rather than through the recursion
    y[n] = y[n-1] + x[n] - x[n-10]. The recursion has a pole at z = 1, which lets
    floating-point round-off accumulate over long signals, and it runs as a Python loop.

    With ``fir_coeff`` the plant is the FIR filter with those coefficients
    (for example a low-pass design).

    Parameters:
        x         : array-like of samples; 1-D, or 2-D with samples along axis 0.
        fir_coeff : optional FIR coefficients replacing the default moving sum.

    Returns:
        np.ndarray of floats with the same shape as ``x``.
    """
    x = np.asarray(x)
    if len(x) == 0:
        return np.zeros(x.shape)
    taps = np.ones(10) if fir_coeff is None else fir_coeff
    # Filter along axis 0 so multi-channel (samples x channels) input is filtered per channel.
    return lfilter(taps, 1, x, axis=0)

def add_noise(y, noise_var=0.1, seed=0):
    """
    Return ``y`` plus white Gaussian noise of variance ``noise_var``.

    The noise comes from a private ``np.random.RandomState(seed)``. It is therefore
    identical to what ``np.random.seed(seed); np.random.randn(len(y))`` would give,
    but the global NumPy RNG is left untouched. Callers that need reproducible
    random filter initializations should seed the global RNG explicitly.

    Parameters:
        y         : 1-D array-like signal.
        noise_var : non-negative noise variance.
        seed      : seed for the private random generator.

    Raises:
        ValueError: if ``noise_var`` is negative.
    """
    if noise_var < 0:
        raise ValueError("noise_var must be non-negative")
    rng = np.random.RandomState(seed)
    noise = np.sqrt(noise_var) * rng.randn(len(y))
    return y + noise

### Filter Measurements

In [None]:
def get_w_true(M, FIR_coeffs=None):
    """
    Return the reference plant impulse response, zero-padded to at least ``M`` taps.

    Parameters:
        M          : desired minimum length (the length of the filter being compared).
        FIR_coeffs : optional plant FIR coefficients. If None, the default plant's
                     response of ten ones is used.

    Returns:
        1-D float array of length max(M, len(response)). A response longer than M is
        returned whole (never truncated).
    """
    if FIR_coeffs is None:
        w_true = np.ones(10)
    else:
        w_true = np.asarray(FIR_coeffs, dtype=float)
    # Pad by the response's real length. The old code always subtracted 10, which
    # is only right for the default plant.
    pad = M - len(w_true)
    if pad <= 0:
        return w_true
    return np.concatenate((w_true, np.zeros(pad)))

def compute_weighted_snr(w_est, FIR_coeffs=None):
    """
    Weight SNR in dB: 10 * log10( ||w_true||^2 / ||w_true - w_est||^2 ).

    ``w_true`` comes from ``get_w_true(len(w_est), FIR_coeffs)``. ``w_est`` is
    zero-padded when it is shorter than ``w_true``. If the squared error is below
    1e-15, the sentinel 999.0 is returned instead of an (almost) infinite value.

    NOTE(review): no sign/scale/delay alignment is applied before comparing. Also,
    w_est is a deconvolution filter while w_true is the plant response; confirm
    this is the intended metric.
    """
    w_true = get_w_true(len(w_est), FIR_coeffs)

    shortfall = len(w_true) - len(w_est)
    if shortfall > 0:
        w_est = np.concatenate((w_est, np.zeros(shortfall)))

    error = w_true - w_est
    den = np.dot(error, error)
    if den < 1e-15:
        return 999.0

    num = np.dot(w_true, w_true)
    return 10.0 * np.log10(num / den)

### Helper Plotting Functions

In [None]:
# Output locations for figures and tabular results.
PLOT_DIR = "plots"

# NOTE(review): CSV_FILENAME is not referenced in the visible notebook cells;
# confirm whether it is still needed.
CSV_FILENAME = "all_results.csv"

# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
# It still raises if PLOT_DIR exists but is not a directory.
os.makedirs(PLOT_DIR, exist_ok=True)

def plot_and_save(sig_true, sig_pred, figname, title="Comparison"):
    """
    Overlay the actual and predicted signals and save the figure as PLOT_DIR/figname.

    The figure is saved at 150 dpi. It is neither shown nor closed here.
    """
    plt.figure(figsize=(8, 4))
    for series, label in ((sig_true, "Actual Output"), (sig_pred, "Predicted Output")):
        plt.plot(series, label=label, alpha=0.7)
    plt.title(title)
    plt.xlabel("Sample index")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.tight_layout()
    out_path = os.path.join(PLOT_DIR, figname)
    plt.savefig(out_path, dpi=150)

def plot_comparison(sig_input, sig_noisy, sig_pred, title="", max_pts=300):
    """
    Show the input, the noisy/observed output and the predicted output on one set of axes.

    Only the first ``max_pts`` samples are drawn (fewer if any signal is shorter),
    so individual samples stay visible. The figure is displayed with plt.show().
    """
    series = (
        (sig_input, "Input Signal"),
        (sig_noisy, "Noisy Output"),
        (sig_pred, "Predicted Output"),
    )
    n_show = min([max_pts] + [len(sig) for sig, _ in series])

    plt.figure(figsize=(8, 4))
    for sig, label in series:
        plt.plot(sig[:n_show], label=label, alpha=0.7)
    plt.title(title)
    plt.xlabel("Sample index")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_with_variance(time, true_signal, pred_mean, pred_std,
                       title, xlabel="Sample index", ylabel="Signal",
                       label_true="True Signal", label_pred="Predicted Signal",
                       position='last', max_pts=500, show_std=False):
    '''
    Plot a true signal against a mean prediction over at most ``max_pts`` samples.

    Parameters:
        time, true_signal, pred_mean, pred_std : 1-D array-likes. All four are cut to
            their common length min(max_pts, len(time), len(true_signal),
            len(pred_mean), len(pred_std)).
        title, xlabel, ylabel, label_true, label_pred : plot text.
        position : 'first' plots the leading samples, 'last' the trailing samples.
        max_pts  : maximum number of samples shown.
        show_std : if True, shade pred_mean +/- pred_std over the same samples
                   (off by default).

    Raises:
        ValueError: if ``position`` is neither 'first' nor 'last'. This is checked
        before any figure is created.
    '''
    if position not in ('first', 'last'):
        raise ValueError("position must be either 'first' or 'last'")

    length = min(max_pts, len(time), len(true_signal), len(pred_mean), len(pred_std))
    if position == 'first':
        window = slice(None, length)
    elif length > 0:
        window = slice(-length, None)
    else:
        # A bare [-0:] slice would select the whole array instead of nothing.
        window = slice(0, 0)

    t = np.asarray(time)[window]
    truth = np.asarray(true_signal)[window]
    mean = np.asarray(pred_mean)[window]

    plt.figure(figsize=(10, 6))
    plt.plot(t, truth, label=label_true, color='blue', linewidth=2)
    plt.plot(t, mean, label=label_pred, color='red', linestyle='--', linewidth=2)
    if show_std:
        # Use the same window as the mean so the band lines up with the plotted curve.
        spread = np.asarray(pred_std)[window]
        plt.fill_between(t, mean - spread, mean + spread,
                         color='red', alpha=0.3, label="Prediction ±1 STD")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.show()

## Blind Deconvolution Code

### Blind Deconvolution Algorithms

In [None]:
def info_max_deconvolution_gaussian_pdf(x, w, filter_length=50, learning_rate=0.01, non_linearity='sigmoid'):
    """
    Perform a single InfoMax gradient-ascent step of a deconvolution filter on one window.

    v = x . w is passed through the chosen non-linearity to give y. The filter then
    moves along 1/w + x * score(y), where score(y) = 1 - 2y for 'sigmoid' and
    -2y for 'tanh'.

    NOTE(review): 1/w is applied to every tap. In Bell & Sejnowski's blind-deconvolution
    rule the anti-decay term belongs to the leading tap only, and element-wise 1/w
    produces huge steps for any tap near zero. Confirm this is intended.

    Parameters:
        x              : 1D array of length ``filter_length``, one window of the
                         observed signal.
        w              : 1D array of length ``filter_length``, the current filter.
                         None (or all zeros) draws a random unit-norm filter instead.
        filter_length  : int, optional (default=50). Number of filter taps.
        learning_rate  : float, optional (default=0.01). Gradient step size.
        non_linearity  : 'sigmoid' or 'tanh', optional (default='sigmoid').

    Returns:
        w_update       : 1D array of length ``filter_length``, the updated filter.
        y              : float, the non-linearity output for this window.

    Raises:
        ValueError: on an unsupported ``non_linearity`` or a length mismatch.
    """
    if non_linearity not in ('sigmoid', 'tanh'):
        raise ValueError("Non-linearity is either not implemented, or invalid. Try 'sigmoid' or 'tanh'.")

    x = np.asarray(x, dtype=float)
    # Check lengths without assuming w is an array (it may legitimately be None).
    if w is not None and len(w) != filter_length:
        raise ValueError("Filter length must match the length of the initial filter w.")
    if len(x) != filter_length:
        raise ValueError("Signal x must match the length of the initial filter w")

    # np.allclose gives a single bool; `w == np.zeros(...)` in a boolean context raised
    # "truth value of an array is ambiguous".
    if w is None or np.allclose(w, 0):
        w = np.random.randn(filter_length)
        w = w / np.linalg.norm(w)
    else:
        w = np.asarray(w, dtype=float)

    v = x.dot(w)

    if non_linearity == 'sigmoid':
        y = sigmoid(v)
        grad = 1.0 / w + x * (1.0 - 2.0 * y)
    else:
        y = tanh(v)
        grad = 1.0 / w - 2.0 * x * y

    w_update = w + learning_rate * grad
    return w_update, y

def info_max_deconvolution_gaussian_pdf_simulation(x, filter_length=50, num_iterations=1000, learning_rate=0.01):
    """
    Batch InfoMax deconvolution of a 1D signal with a logistic (sigmoid) non-linearity.

    Every length-``filter_length`` sliding window of ``x`` becomes one row of a matrix X.
    Starting from a random unit-norm filter, each iteration:
      1. computes y = X @ w and the score 1 - 2*sigmoid(y);
      2. forms the gradient X.T @ score;
      3. takes a gradient-ascent step and renormalizes w to unit norm.
    A rough output-entropy proxy, -mean(log(g*(1-g))), is printed about ten times per run.

    Parameters:
      x              : 1D array, the observed signal.
      filter_length  : int, optional (default=50). Number of filter taps.
      num_iterations : int, optional (default=1000). Number of gradient-ascent iterations.
      learning_rate  : float, optional (default=0.01). Gradient step size.

    Returns:
      w              : 1D array of shape (filter_length,), the learned unit-norm filter.
      y_final        : 1D array of shape (len(x) - filter_length + 1,), equal to X @ w.

    Raises:
      ValueError: if ``x`` is shorter than ``filter_length``.
    """
    x = np.asarray(x, dtype=float)

    # Number of sliding windows (each of length filter_length)
    T = len(x) - filter_length + 1
    if T <= 0:
        raise ValueError("Signal x must be at least as long as the filter length.")

    # Rows of X are the sliding windows of x, giving shape (T, filter_length).
    try:
        from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20
        X = sliding_window_view(x, window_shape=filter_length)
    except ImportError:
        X = np.array([x[i:i + filter_length] for i in range(T)])

    w = np.random.randn(filter_length)
    w = w / np.linalg.norm(w)

    # Report about ten times per run. max(1, ...) avoids a modulo by zero when num_iterations < 10.
    report_every = max(1, num_iterations // 10)

    for it in range(num_iterations):
        y = X.dot(w)
        g_y = sigmoid(y)
        error = 1.0 - 2.0 * g_y       # logistic score function, shape (T,)
        grad = X.T.dot(error)         # tap k gets sum_i error[i] * X[i, k]
        w += learning_rate * grad
        w = w / np.linalg.norm(w)     # unit norm keeps the scale from diverging

        if it % report_every == 0 or it == num_iterations - 1:
            entropy_est = -np.mean(np.log(g_y * (1 - g_y) + 1e-8))
            print(f"Iteration {it}/{num_iterations}, entropy estimate: {entropy_est:.4f}")

    y_final = X.dot(w)
    return w, y_final

def info_max_deconvolution_parzen_window_sampler(x, w, M=50, L=10, eta=0.01, est_cdf='sigmoid', kernel='gaussian'):
    '''
    Perform one InfoMax-style update of a deconvolution filter. The update uses a
    Parzen-window estimate built from the L most recent windows.

    ``x`` holds L + 1 overlapping windows of length M. Window j (j = 0..L-1) is
    x[j:j+M], and the current window is x[L:L+M]. Each window is filtered
    (y = window . w) and mapped through the CDF estimate g ('sigmoid' or 'tanh'),
    giving z. The update is

        grad  = mean_j  k'(z_cur - z_j) * (g'(y_cur) * x_cur - g'(y_j) * x_j)
        w_new = w - eta * grad

    where k(u) = exp(-u^2) is the un-normalized, unit-width Gaussian kernel. This
    moves w against the gradient of the Parzen density estimate at the current output.
    NOTE(review): the exact gradient of the entropy estimate -log(p_hat) would also divide
    by sum_j k(z_cur - z_j). Confirm that omitting this normalization is intended.

    Parameters:
        x       : 1D array of length M + L.
        w       : 1D array of length M, the current filter. None (or all zeros)
                  draws a random unit-norm filter instead.
        M       : int, optional (default=50). Filter length.
        L       : int, optional (default=10). Number of previous windows in the estimate.
        eta     : float, optional (default=0.01). Learning rate.
        est_cdf : 'sigmoid' or 'tanh', optional (default='sigmoid').
        kernel  : 'gaussian', optional (default='gaussian'). The only supported kernel.

    Returns:
        w_new   : 1D array of length M, the updated filter.

    Raises:
        ValueError: on a length mismatch or an unsupported ``est_cdf`` / ``kernel``.
    '''
    x = np.asarray(x, dtype=float)

    # Check lengths without assuming w is an array (it may legitimately be None).
    if w is not None and len(w) != M:
        raise ValueError("M must match the length of the initial filter w.")
    if len(x) != M + L:
        # x must contain L previous windows plus the current window of length M.
        raise ValueError("Signal x must match the length of the initial filter M + L previous windows")

    if w is None or np.allclose(w, 0):
        w = np.random.randn(M)
        w = w / np.linalg.norm(w)
    else:
        w = np.asarray(w, dtype=float)

    # CDF estimate and its derivative
    if est_cdf == 'sigmoid':
        activation, activation_derivative = sigmoid, sigmoid_derivative
    elif est_cdf == 'tanh':
        activation, activation_derivative = tanh, tanh_derivative
    else:
        raise ValueError("Non-linearity is either not implemented, or invalid. Try 'sigmoid' or 'tanh'.")

    # Only the kernel's derivative appears in the update.
    if kernel != 'gaussian':
        raise ValueError("Window is either not implemented, or invalid. Try 'gaussian'.")
    kernel_derivative = gaussian_derivative

    windows = np.lib.stride_tricks.sliding_window_view(x, M)  # (L + 1, M)
    X_windows = windows[:L]   # (L, M) previous windows
    x_fin = windows[L]        # (M,) current window, equal to x[L:L+M]

    y_fin = x_fin @ w                  # scalar
    Y_windows = X_windows @ w          # (L,)
    z_fin = activation(y_fin)          # scalar
    Z_windows = activation(Y_windows)  # (L,)

    # dz/dw for each window: g'(y) * window
    xt_fin = x_fin * activation_derivative(y_fin)                             # (M,)
    Xt_windows = X_windows * activation_derivative(Y_windows)[:, np.newaxis]  # (L, M)

    # Broadcasting pairs the current window with each previous window.
    kernel_weight = kernel_derivative(z_fin - Z_windows)[:, np.newaxis]       # (L, 1)
    entropy_grad = np.mean(kernel_weight * (xt_fin - Xt_windows), axis=0)     # (M,)

    return w - eta * entropy_grad

def info_max_deconvolution_parzen_window_sampler_simulation(x, M=50, L=10, eta=0.01, est_cdf='sigmoid', kernel='gaussian'):
    """
    Run the Parzen-window InfoMax deconvolution sample by sample over a whole signal.

    A random unit-norm filter of length M is drawn. Then, for every start index i, the
    segment x[i : i + M + L] is passed to ``info_max_deconvolution_parzen_window_sampler``
    for one update. Each step records three things:
      - the prediction of the segment's last length-M window, using the filter from
        *before* the update;
      - the updated filter;
      - the norm of the filter change.
    The results of step i are stored at index i + L + M // 2 of length-N arrays
    (N = len(x) after any padding). All other entries are NaN.

    Parameters:
        x         : 1D array, the observed signal. It is zero-padded to M + L samples
                    if shorter.
        M         : int, optional (default=50). Number of filter taps.
        L         : int, optional (default=10). Number of past windows in the estimate.
        eta       : float, optional (default=0.01). Learning rate.
        est_cdf   : 'sigmoid' or 'tanh', optional (default='sigmoid').
        kernel    : 'gaussian', optional (default='gaussian').

    Returns:
        weights_full     : np.ndarray of shape (N, M), the filter history.
        predictions_full : np.ndarray of shape (N,), the pre-update predictions.
        errors_full      : np.ndarray of shape (N,), the norm of each filter change.
    """
    x = np.asarray(x)
    N = len(x)
    segment_len = M + L
    if N < segment_len:
        # Zero-pad so that at least one full segment exists.
        x = np.pad(x, (0, segment_len - N), mode='constant')
        N = len(x)

    w = np.random.randn(M)
    w = w / np.linalg.norm(w)

    weights_full = np.full((N, M), np.nan)
    predictions_full = np.full(N, np.nan)
    errors_full = np.full(N, np.nan)

    # Write results straight into the output arrays. Collecting them in per-step
    # lists first would hold a second N x M copy of the weight history.
    for i in range(N - segment_len + 1):
        segment = x[i:i + segment_len]
        current_prediction = np.dot(segment[L:], w)  # last M samples, filter before update
        new_w = info_max_deconvolution_parzen_window_sampler(segment, w, M, L, eta, est_cdf, kernel)

        center_index = i + L + M // 2
        if center_index < N:
            weights_full[center_index] = new_w
            predictions_full[center_index] = current_prediction
            errors_full[center_index] = np.linalg.norm(new_w - w)

        w = new_w

    return weights_full, predictions_full, errors_full


### Grid Search (Averaged over Random Trials) for Blind Deconvolution

In [None]:
def grid_search_deconvolution(x, d, M=50, L=10, candidate_etas=None, candidate_est_cdfs=None, num_trials=10):
    """
    Grid-search the learning rate (eta) and CDF non-linearity (est_cdf) of the
    Parzen-window InfoMax deconvolution, scored by normalized mean-squared error (NMSE).

    For each (eta, est_cdf) pair, the adaptive deconvolution runs ``num_trials`` times
    on the observed signal ``d``. Each run starts from its own random filter. The
    output is rescaled to the norm of the source ``x`` (blind deconvolution does not
    recover scale) and compared with ``x`` wherever a prediction exists.

    NOTE(review): the NMSE corrects neither the sign ambiguity of blind deconvolution
    nor any delay between ``x`` and the predictions. The simulation stores each
    prediction at an index offset from its window. Confirm this is intended.

    Parameters:
      x                 : 1D array, the source (plant input) used as the NMSE reference.
      d                 : 1D array, the observed (plant output) signal the filter adapts on.
      M                 : int, filter order (length of the deconvolution filter).
      L                 : int, number of windows used in the Parzen-window estimate.
      candidate_etas    : iterable of learning rates. If None, a default list is used.
      candidate_est_cdfs: iterable of non-linearities (e.g. ['sigmoid', 'tanh']).
                          If None, a default list is used.
      num_trials        : int, number of trials per candidate configuration.

    Returns:
      best_avg_nmse     : float, the lowest trial-averaged NMSE over the grid.
      best_params       : tuple (best_eta, best_est_cdf).
      best_trial_outputs: list with one dict per trial of the best candidate, with keys
                          'weights_full', 'predictions_full', 'errors_full'.
      performance_dict  : dict mapping (eta, est_cdf) -> (avg_nmse, std_nmse).

    Raises:
      ValueError: if no candidate produced a valid (non-NaN) average NMSE.
    """
    if candidate_etas is None:
        candidate_etas = [0.001, 0.005, 0.01, 0.05, 0.1]
    if candidate_est_cdfs is None:
        candidate_est_cdfs = ['sigmoid', 'tanh']

    x = np.asarray(x, dtype=float)
    d = np.asarray(d, dtype=float)

    performance_dict = {}
    best_avg_nmse = np.inf
    best_params = None
    best_trial_outputs = None

    for eta_candidate in candidate_etas:
        for est_candidate in candidate_est_cdfs:
            trial_nmse_errors = []
            trial_outputs = []

            for _ in range(num_trials):
                # Adapt on the observed signal d; the source x serves only as the reference.
                weights_full, predictions_full, errors_full = info_max_deconvolution_parzen_window_sampler_simulation(
                    d, M=M, L=L, eta=eta_candidate, est_cdf=est_candidate, kernel='gaussian')

                trial_nmse_errors.append(_norm_matched_nmse(x, predictions_full))
                trial_outputs.append({
                    'weights_full': weights_full,
                    'predictions_full': predictions_full,
                    'errors_full': errors_full
                })

            errs = np.asarray(trial_nmse_errors, dtype=float)
            if errs.size == 0 or np.all(np.isnan(errs)):
                # Avoids nanmean's "Mean of empty slice" warning; NaN never wins below.
                avg_nmse = std_nmse = np.nan
            else:
                avg_nmse = np.nanmean(errs)
                std_nmse = np.nanstd(errs)
            performance_dict[(eta_candidate, est_candidate)] = (avg_nmse, std_nmse)

            print(f"Candidate eta={eta_candidate}, est_cdf={est_candidate} --> "
                  f"Avg NMSE: {avg_nmse:.6f} +/- {std_nmse:.6f}")

            if avg_nmse < best_avg_nmse:
                best_avg_nmse = avg_nmse
                best_params = (eta_candidate, est_candidate)
                best_trial_outputs = trial_outputs

    if best_params is None:
        raise ValueError("No (eta, est_cdf) candidate produced a valid NMSE.")

    print(f"\nBest parameters: eta={best_params[0]}, est_cdf={best_params[1]} "
          f"with average NMSE: {best_avg_nmse:.6f}")

    return best_avg_nmse, best_params, best_trial_outputs, performance_dict


def _norm_matched_nmse(reference, predictions):
    """NMSE of ``predictions`` rescaled to the norm of ``reference``, over non-NaN predictions (NaN if undefined)."""
    n = min(len(reference), len(predictions))
    ref = reference[:n]
    pred = predictions[:n]
    valid = ~np.isnan(pred)
    ref, pred = ref[valid], pred[valid]

    pred_norm = np.linalg.norm(pred)
    if pred.size == 0 or pred_norm == 0:
        return np.nan

    scaled = pred * (np.linalg.norm(ref) / pred_norm)
    mse = np.mean((ref - scaled) ** 2)
    return mse / (np.mean(ref ** 2) + 1e-15)  # 1e-15 guards an all-zero reference


## Results for different M and L values

### Testing Script

In [None]:
# I was using this block to test stuff out, ignore this

### True Search

In [None]:
# --- Grid-search adapted script ---
# For every noise variance, model order M and Parzen window count L, this cell:
#   - runs the (eta, est_cdf) grid search and reports W-SNR / NMSE;
#   - plots the validation output and the weight tracks;
#   - keeps the best overall prediction for export as audio.
best_params_list = []  # best (eta, est_cdf) tuple of every (noise, M, L) run
results_rows = []      # one dict per (input, noise, M, L) run; turned into a DataFrame below
RESULT_COLUMNS = ['Input', 'Noise Variance', 'Model Order', 'Parzen Windows (L)',
                  'Best Eta', 'Best est_cdf', 'W-SNR', 'NMSE']
MAX_LEGEND_TAPS = 16   # a legend with hundreds of taps is unreadable and slow to draw

for input_name, x_in in zip(["speech"], [x_speech]):
    print(f"\n=== {input_name.upper()} Input ===")
    N = len(x_in)
    best_y_pred = None
    best_nmse = np.inf

    for nv in noise_variances:
        # Noise is added at the plant *input*; the plant then filters signal + noise.
        x_noisy = add_noise(x_in, noise_var=nv)
        y_noisy = unknown_plant(x_noisy, fir_coeff=fir_coeff)
        print(f"  Noise variance={nv}")

        for M in M_candidates:
            for L in L_candidates:
                # Use the full available data for this candidate.
                x_w = x_in[:N]
                y_w = y_noisy[:N]

                t_start = tick.perf_counter()
                best_avg_nmse, best_params, best_trial_outputs, _ = grid_search_deconvolution(
                    x_w, y_w, M=M, L=L,
                    candidate_etas=eta_candidates,
                    candidate_est_cdfs=est_cdf_candidates,
                    num_trials=num_trials,
                )
                t_stop = tick.perf_counter()

                best_eta, best_est_cdf = best_params
                best_params_list.append(best_params)

                # Each trial output is a dict with 'weights_full', 'predictions_full', 'errors_full'.
                trial_predictions = np.array([out['predictions_full'] for out in best_trial_outputs])

                # The second half of the record is the validation segment.
                train_len = int(0.5 * len(y_w))
                x_val = x_w[train_len:]
                val_preds_trials = trial_predictions[:, train_len:]
                avg_train_pred = np.mean(trial_predictions[:, :train_len], axis=0)
                avg_val_pred = np.mean(val_preds_trials, axis=0)
                std_val_pred = np.std(val_preds_trials, axis=0)

                # Norms that put source and prediction on a common scale for plotting.
                # They fall back to 1.0 so an all-zero segment cannot produce NaN/inf.
                x_val_norm = np.linalg.norm(x_val)
                x_val_norm = x_val_norm if x_val_norm > 0 else 1.0
                avg_val_pred_norm = np.linalg.norm(avg_val_pred[~np.isnan(avg_val_pred)])
                avg_val_pred_norm = avg_val_pred_norm if avg_val_pred_norm > 0 else 1.0

                # W-SNR of the trial-averaged final filter against the reference response.
                final_weights = []
                for out in best_trial_outputs:
                    weights_full = out['weights_full']
                    valid_indices = np.where(~np.isnan(weights_full[:, 0]))[0]
                    if len(valid_indices) > 0:
                        final_weights.append(weights_full[valid_indices[-1]])
                if final_weights:
                    final_w_est = np.mean(np.array(final_weights), axis=0)
                    w_snr = compute_weighted_snr(final_w_est, fir_coeff)
                else:
                    w_snr = np.nan

                print(f"M={M}, L={L}, best_eta={best_eta}, best_est_cdf={best_est_cdf}, W-SNR={w_snr:.2f} dB, "
                      f"NMSE={best_avg_nmse:.6f}, Time to find best params={t_stop - t_start:.2f} s")

                title_str_val = (f"{input_name.upper()}, noise={nv}, M={M}, L={L}, eta={best_eta}, est_cdf={best_est_cdf}\n"
                                 f"W-SNR_val={w_snr:.2f} dB, NMSE={best_avg_nmse:.6f}")

                # Validation portion: source vs. trial-averaged prediction.
                plot_with_variance(
                    time=np.arange(len(x_val)),
                    true_signal=x_val / x_val_norm,
                    pred_mean=avg_val_pred / avg_val_pred_norm,
                    # Scale the spread by the same factor as the mean so the two stay comparable.
                    pred_std=std_val_pred / avg_val_pred_norm,
                    title="Validation Signal Comparison: " + title_str_val,
                    xlabel="Validation Sample Index",
                    ylabel="Signal",
                    position='first',
                )

                # Evolution of the weights over the whole record (mean +/- std across trials).
                trial_w_tracks = np.array([out['weights_full'] for out in best_trial_outputs])
                avg_w_track = np.nanmean(trial_w_tracks, axis=0)
                std_w_track = np.nanstd(trial_w_tracks, axis=0)
                del trial_w_tracks  # num_trials x N x M floats; release before plotting

                plt.figure(figsize=(10, 5))
                time_steps = np.arange(avg_w_track.shape[0])
                for m in range(M):
                    plt.plot(time_steps, avg_w_track[:, m], label=f"w[{m}]")
                    plt.fill_between(time_steps,
                                     avg_w_track[:, m] - std_w_track[:, m],
                                     avg_w_track[:, m] + std_w_track[:, m],
                                     alpha=0.2)
                plt.title(f"Mean Weight Evolution Over Time (M={M}, L={L}, eta={best_eta}, est_cdf={best_est_cdf})")
                plt.xlabel("Time Steps")
                plt.ylabel("Weight Value")
                if M <= MAX_LEGEND_TAPS:
                    plt.legend()
                plt.grid()
                plt.tight_layout()
                # Include input, noise level and L so runs do not overwrite each other's figures.
                fig_wtrack = (f"{input_name}_deconv_wtrack_nv_{nv}_M_{M}_L_{L}_"
                              f"eta_{best_eta:.4g}_estcdf_{best_est_cdf}.png")
                plt.savefig(os.path.join(PLOT_DIR, fig_wtrack), dpi=150)
                plt.show()
                plt.close()

                if best_avg_nmse < best_nmse:
                    print(f"  ** New best NMSE found: {best_avg_nmse:.6f} **")
                    best_nmse = best_avg_nmse
                    best_y_pred = np.concatenate([avg_train_pred, avg_val_pred])

                results_rows.append({
                    'Input': input_name,
                    'Noise Variance': nv,
                    'Model Order': M,
                    'Parzen Windows (L)': L,
                    'Best Eta': best_eta,
                    'Best est_cdf': best_est_cdf,
                    'W-SNR': w_snr,
                    'NMSE': best_avg_nmse,
                })

    # Save the best prediction for this signal.
    if best_y_pred is None:
        print(f"No prediction was selected for {input_name}; skipping audio export.")
    else:
        pred_audio = best_y_pred[~np.isnan(best_y_pred)]
        peak = np.max(np.abs(pred_audio)) if pred_audio.size else 0.0
        if peak > 0:
            # Blind deconvolution recovers the source only up to scale. Peak-normalize so
            # every sample fits the [-1, 1] range representable in a PCM .wav file.
            pred_audio = pred_audio / peak
        sf.write(f"{input_name}_imax_pred.wav", pred_audio.reshape(-1, 1), fs_speech)

p1_df_deconv = pd.DataFrame(results_rows, columns=RESULT_COLUMNS)

# Report the average best eta over all candidate runs.
avg_best_eta = np.mean([params[0] for params in best_params_list]) if best_params_list else np.nan
print(f"\nAverage best eta: {avg_best_eta:.6f}")

# Save the DataFrame of results.
p1_df_deconv.to_csv(os.path.join(PLOT_DIR, "deconv_grid_search_results.csv"), index=False)


In [None]:
# Write the noisy plant output of the last noise variance in the sweep to a file for reference.
noisy_audio = np.asarray(y_noisy, dtype=float)
noisy_peak = np.max(np.abs(noisy_audio)) if noisy_audio.size else 0.0
if noisy_peak > 1.0:
    # soundfile's default WAV subtype is 16-bit PCM, which cannot represent samples
    # outside [-1, 1]. Rescale only when needed so quieter signals keep their level.
    noisy_audio = noisy_audio / noisy_peak
sf.write("speech_noisy.wav", noisy_audio.reshape(-1, 1), fs_speech)

In [None]:
# Display the accumulated grid-search results table (the last expression of the cell is rendered).
p1_df_deconv