# Phase 3

In [None]:
# Import Libraries
import os
import soundfile as sf
import scipy.signal as ss
from scipy.io.wavfile import write
from IPython.display import Audio, display
import matplotlib.pyplot as plt
import numpy as np
import pickle

In [None]:
# Resolve the working locations used throughout the notebook
current_dir = os.getcwd()
base_dir = os.path.dirname(current_dir)  # parent of the working directory

# Sibling folders of the working directory: inputs, outputs and filter designs
data_dir, results_dir, designs_dir = (
    os.path.join(base_dir, folder) for folder in ('data', 'results', 'designs')
)

# The cleaned dataset lives inside the data folder
cleaned_data_dir = os.path.join(data_dir, 'cleaned_dataset')

In [None]:
# Global signal constants (all values in Hz)
FS = 16_000    # sampling frequency
F_MIN = 100    # lower edge of the analysed frequency range
F_MAX = 8_000  # upper edge of the analysed frequency range

In [None]:
def hz_to_erb(f):
    """
    Map a frequency in Hz onto the ERB-number scale.

    Args:
        f (float or numpy.ndarray): Frequency (or frequencies) in Hz.

    Returns:
        float or numpy.ndarray: ``21.4 * log10(4.37e-3 * f + 1)``, evaluated element-wise.
    """
    log_argument = 4.37e-3 * f + 1
    return 21.4 * np.log10(log_argument)

def erb_to_hz(erb):
    """
    Map an ERB number back to a frequency in Hz (inverse of ``hz_to_erb``).

    Args:
        erb (float or numpy.ndarray): ERB number(s).

    Returns:
        float or numpy.ndarray: ``(10**(erb / 21.4) - 1) / 4.37e-3``, evaluated element-wise.
    """
    exponent = erb / 21.4
    scaled = 10**exponent - 1
    return scaled / 4.37e-3

def erb_width(f):
    """
    Return the ERB width at a given center frequency.

    Args:
        f (float or numpy.ndarray): Center frequency (or frequencies) in Hz.

    Returns:
        float or numpy.ndarray: ``24.7 * (4.37e-3 * f + 1)``, evaluated element-wise.
    """
    growth = 4.37e-3 * f + 1
    return 24.7 * growth

In [None]:
def get_center_frequencies(num_channels, lowcut, highcut, distr='linear', overlap=0.0):
    """
    Calculate center frequencies and bandwidths for a filter bank with specified overlap.

    The range ``[lowcut, highcut]`` is split into ``num_channels`` adjacent bands whose
    edges are evenly spaced on the scale selected by ``distr``. Each center frequency is
    the arithmetic midpoint (in Hz) of its band's two edges, and each bandwidth is the
    edge-to-edge width multiplied by ``(1 + overlap)``.

    Args:
        num_channels (int): Number of filters (channels) in the bank; must be >= 1.
        lowcut (float): Lower frequency bound in Hz.
        highcut (float): Upper frequency bound in Hz; must be greater than ``lowcut``.
        distr (str): Distribution type ('linear', 'log2', 'log10', 'erb').
        overlap (float): Desired overlap between adjacent filters as a fraction (e.g., 0.1 for 10%).

    Returns:
        tuple: Two numpy arrays of length ``num_channels`` containing center frequencies
        and overlap-adjusted bandwidths.

    Raises:
        ValueError: If ``num_channels`` is less than 1, if ``lowcut`` is not strictly
            below ``highcut``, if ``lowcut`` is outside the domain of the chosen scale
            (<= 0 for 'log2'/'log10', < 0 for 'erb'), or if ``distr`` is not implemented.
    """
    # Reject inputs that would otherwise silently yield empty, negative or NaN results.
    if num_channels < 1:
        raise ValueError(f"num_channels must be at least 1, got {num_channels}.")
    if not lowcut < highcut:
        raise ValueError(
            f"lowcut ({lowcut}) must be strictly less than highcut ({highcut})."
        )
    if distr in ('log2', 'log10') and lowcut <= 0:
        # log of a non-positive edge is undefined and would poison every band edge
        raise ValueError(f"lowcut must be positive for distr='{distr}', got {lowcut}.")
    if distr == 'erb' and lowcut < 0:
        raise ValueError(f"lowcut must be non-negative for distr='erb', got {lowcut}.")

    if distr == 'linear':
        edges = np.linspace(lowcut, highcut, num_channels + 1)
    elif distr == 'log10':
        edges = np.logspace(np.log10(lowcut), np.log10(highcut), num_channels + 1)
    elif distr == 'log2':
        edges = np.logspace(np.log2(lowcut), np.log2(highcut), num_channels + 1, base=2.0)
    elif distr == 'erb':
        # Space the edges evenly on the ERB scale, then map them back to Hz
        erb_low = hz_to_erb(lowcut)
        erb_high = hz_to_erb(highcut)
        erb_edges = np.linspace(erb_low, erb_high, num_channels + 1)
        edges = erb_to_hz(erb_edges)
    else:
        raise ValueError(f"Distribution '{distr}' not implemented.")

    # Center of each band: midpoint between its lower and upper edge (in Hz)
    center_freqs = (edges[:-1] + edges[1:]) / 2.0

    # Nominal width of each band, widened by the requested overlap fraction
    bandwidths = np.diff(edges)
    adjusted_bandwidths = bandwidths * (1 + overlap)

    return center_freqs, adjusted_bandwidths


In [None]:
def save_fbank(fbank, dir_path, file_name):
    """
    Pickle a filter bank to ``<dir_path>/<file_name>.pkl``.

    The target directory is created (including parents) when it does not exist yet.

    Args:
        fbank (list): The filter bank to save.
        dir_path (str): Directory path where the file will be saved.
        file_name (str): File name without extension; ``.pkl`` is appended.
    """
    target = os.path.join(dir_path, f"{file_name}.pkl")
    os.makedirs(dir_path, exist_ok=True)

    with open(target, 'wb') as handle:
        pickle.dump(fbank, handle)

In [None]:
def load_fbank(file_path):
    """
    Load a pickled filter bank from disk.

    Args:
        file_path (str): Path to the pickle file.

    Returns:
        The object stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

In [None]:
def load_sound_file(file_path):
    """
    Read an audio file via ``soundfile`` and return its samples and sample rate.

    Any failure while reading is reported on stdout instead of being raised.

    Args:
        file_path (str): Path to the audio file.

    Returns:
        tuple: ``(data, sample_rate)`` as returned by ``sf.read``, or
        ``(None, None)`` when the file could not be loaded.
    """
    try:
        audio, rate = sf.read(file_path)
    except Exception as e:
        # Best-effort loader: report the problem and hand back a sentinel pair
        print(f"Error loading sound file: {e}")
        return None, None
    return audio, rate

In [None]:
def apply_fir_filter(input_signal, filter_coefficients):
    """
    Apply an FIR filter to the input signal.

    Args:
        input_signal (numpy.ndarray): The input audio signal.
        filter_coefficients (array-like): The FIR filter taps. Any shape is accepted;
            the taps are flattened to one dimension before filtering (e.g. a
            ``(1, N)`` array is treated as ``N`` taps).

    Returns:
        numpy.ndarray: The filtered signal, same shape as ``input_signal``.
    """
    # Ensure filter coefficients are one-dimensional: lfilter needs a 1-D numerator
    b = np.ravel(filter_coefficients)

    # An FIR filter has a unit denominator
    return ss.lfilter(b, [1.0], input_signal)

def apply_iir_filter(input_signal, b_coefficients, a_coefficients):
    """
    Run a signal through an IIR filter given in transfer-function (b, a) form.

    The signal is handed to ``scipy.signal.lfilter`` unchanged.

    Args:
        input_signal (numpy.ndarray): The input audio signal.
        b_coefficients (numpy.ndarray): The numerator (b) coefficients of the IIR filter.
        a_coefficients (numpy.ndarray): The denominator (a) coefficients of the IIR filter.

    Returns:
        numpy.ndarray: The filtered signal.
    """
    filtered = ss.lfilter(b=b_coefficients, a=a_coefficients, x=input_signal)
    return filtered


In [None]:
def apply_filter(input_signal, filter_bank):
    """
    Apply a filter bank to an input signal.

    Each entry of the bank is dispatched on the keys it contains: a ``'taps'`` key
    selects FIR filtering, otherwise a ``'b'`` and ``'a'`` pair selects IIR filtering.

    Args:
        input_signal (numpy.ndarray): The input audio signal.
        filter_bank (list): A list of dictionaries containing filter coefficients,
            either ``{'taps': ...}`` or ``{'b': ..., 'a': ...}`` (extra keys are ignored).

    Returns:
        list: A list of filtered signals corresponding to each filter in the filter bank.

    Raises:
        ValueError: If an entry has neither ``'taps'`` nor both ``'b'`` and ``'a'``.
    """
    filtered_signals = []

    # `filt` rather than `filter`, so the builtin is not shadowed
    for idx, filt in enumerate(filter_bank):
        if 'taps' in filt:
            # FIR filter
            filtered_signal = apply_fir_filter(input_signal, filt['taps'])
        elif 'b' in filt and 'a' in filt:
            # IIR filter
            filtered_signal = apply_iir_filter(input_signal, filt['b'], filt['a'])
        else:
            raise ValueError(f"Filter {idx} does not contain recognized coefficients.")
        filtered_signals.append(filtered_signal)
    return filtered_signals


## Gammatone Filter Bank

### **Filter Type**: IIR and FIR

#### **Reasoning**:
Gammatone filters mimic the auditory processing of the human cochlea, making them a natural choice for applications involving speech and hearing. They are especially effective in cochlear implants as they align with how humans perceive sound frequencies.

### **Final Design Parameters**:
- **Number of Bands (N)**:
  - **32**: High-resolution frequency analysis for precise auditory processing.
- **Frequency Distribution (DISTR)**:
  - **ERB (Equivalent Rectangular Bandwidth)**: Matches the auditory filter shapes of the human cochlea for a perceptually accurate design.
- **Filter Type**:
  - **IIR**: Computationally efficient, suitable for real-time applications.
  - **FIR**: Ensures phase linearity for better sound fidelity.
- **Filter Order**:
  - **4**: Provides a balance between sharpness and computational complexity.
- **Overlap**:
  - **0.1**: Ensures minimal frequency gaps while keeping computational load manageable.
- **Frequency Range**:
  - **100 Hz to 8000 Hz**: Captures the full range of human speech and relevant environmental sounds.
- **Sampling Frequency (FS)**:
  - **16,000 Hz**: Wideband-speech sampling rate; its 8,000 Hz Nyquist limit matches the upper edge of the frequency range.

In [1]:
def design_gammatone_fbank(num_channels, lowcut, highcut, distr, overlap, fs, ftype='iir', order=4):
    """
    Build a bank of gammatone filters, one per center frequency.

    Center frequencies come from ``get_center_frequencies``; each one is passed to
    ``scipy.signal.gammatone``. The bandwidths returned alongside the center
    frequencies are not used here, so ``overlap`` does not influence the designed
    filters.

    Args:
        num_channels (int): Number of filters (channels) in the bank.
        lowcut (float): Lower frequency bound in Hz.
        highcut (float): Upper frequency bound in Hz.
        distr (str): Distribution type ('linear', 'log2', 'log10', 'erb').
        overlap (float): Overlap fraction forwarded to ``get_center_frequencies``.
        fs (float): Sampling frequency in Hz.
        ftype (str): Filter type forwarded to ``scipy.signal.gammatone``
            (the notebook uses 'iir' and 'fir').
        order (int): Order of the gammatone filter.
            NOTE(review): SciPy documents ``order`` as only used for ``ftype='fir'`` - confirm.

    Returns:
        list: One dict per filter with keys ``'center_freq'``, ``'b'`` and ``'a'``.
    """
    center_freqs, _ = get_center_frequencies(num_channels, lowcut, highcut, distr, overlap)

    def _design(center):
        # Design a single gammatone filter and bundle it with its center frequency
        numerator, denominator = ss.gammatone(center, ftype=ftype, order=order, fs=fs)
        return {'center_freq': center, 'b': numerator, 'a': denominator}

    return [_design(center) for center in center_freqs]

In [None]:
# Gammatone bank specification
num_channels = 32
lowcut, highcut = 100.0, 8000.0  # band limits in Hz
distr = 'erb'
overlap = 0.1
fs = 16000.0  # Sampling frequency in Hz
order = 4

# Design an IIR and an FIR bank from the same specification
bank_spec = (num_channels, lowcut, highcut, distr, overlap, fs)
gammatone_iir_fbank = design_gammatone_fbank(*bank_spec, ftype='iir', order=order)
gammatone_fir_fbank = design_gammatone_fbank(*bank_spec, ftype='fir', order=order)

# Persist both banks; the file names encode the design parameters
gamma_iirbank_name = f'gamma_N_{num_channels}_distr_{distr}_order_{order}_filter_type_iir'
save_fbank(gammatone_iir_fbank, dir_path='filter_banks', file_name=gamma_iirbank_name)

gamma_firbank_name = f'gamma_N_{num_channels}_distr_{distr}_order_{order}_filter_type_fir'
save_fbank(gammatone_fir_fbank, dir_path='filter_banks', file_name=gamma_firbank_name)


In [None]:
# Rectify every filtered signal (absolute value), carrying its metadata along unchanged
rectified_signals_dict = {}

for name, signal in filtered_signals_dict.items():
    rectified_signals_dict[name] = {
        'signal': np.abs(signal['signal']),
        'center_freq': signal['center_freq'],
        'bandwidth': signal['bandwidth'],
    }

## Design 2: Butterworth Low-Pass Filter (LPF)

### **Objective**: Maximizing Computational Efficiency

#### **Filter Type**: IIR Filter

#### **Reasoning**:
The Butterworth filter is renowned for its maximally flat frequency response in the passband, ensuring minimal signal distortion. Implementing a second-order Butterworth IIR filter provides an efficient solution with a gentle roll-off, effectively balancing performance and computational load. This design is particularly advantageous for real-time applications, such as cochlear implants, where processing resources are limited.

---

### **Final Design Parameters**:
- **Sampling Frequency (fs)**: 
  - **16,000 Hz**: A standard rate suitable for speech processing applications.
- **Cutoff Frequency (fc)**: 
  - **400 Hz**: Targets the essential frequency range for speech envelope extraction.
- **Filter Order**: 
  - **2**: Provides a balance between adequate frequency separation and low computational complexity.

### **Sources**:

[MathWorks: Lowpass Filter Design](https://www.mathworks.com/help/dsp/ug/lowpass-filter-design.html)

In [None]:
# Low-pass filter specification: sampling frequency (Hz), cutoff frequency (Hz), filter order
# NOTE(review): the design notes above list a filter order of 2, while this cell
# uses 8 (consistent with the `butter_lpf_8` name below) - confirm which is intended.
fs, fc, order = 16000, 400, 8

In [None]:
# Digital Butterworth low-pass; the cutoff is expressed as a fraction of the
# Nyquist frequency (0.5 * fs), which is what butter expects when fs is not given
b, a = ss.butter(N=order, Wn=fc / (0.5 * fs), btype='low', analog=False)

# Store the coefficients in the {'b', 'a'} layout used for IIR filters
butter_lpf_8 = dict(b=b, a=a)

In [None]:
# Low-pass every rectified signal with every filter in lpf_banks;
# results are keyed "<signal name>_<filter name>"
filtered_signals_lpf_dict = {}

for signal_name, signal_data in rectified_signals_dict.items():
    # Work on an ndarray copy of the stored signal and keep its metadata at hand
    signal = np.array(signal_data['signal'])
    center_frequencies = signal_data['center_freq']
    bandwidths = signal_data['bandwidth']

    for filter_name, filter_params in lpf_banks.items():
        # Filter row by row (assumes the first axis indexes channels - TODO confirm).
        # apply_filter takes a bank and returns a list, so the single filter is
        # wrapped in a list and the single result is unwrapped with [0].
        filtered_signal = np.array([
            apply_filter(ch_signal, [filter_params])[0] for ch_signal in signal
        ])

        # Metadata is passed through untouched
        filtered_signals_lpf_dict[f"{signal_name}_{filter_name}"] = {
            'signal': filtered_signal,
            'center_freq': center_frequencies,
            'bandwidth': bandwidths,
        }
