# FIR filter logic notebook - anthony

## real or synthetic data

If `USE_REAL_DATA` is `True`, the notebook parses an nvbx `.bin` file from `DATA_FILE_PATH`.

Otherwise it generates a 60 s synthetic square wave with additive Gaussian noise (standard deviation 0.1, i.e. 10% of the unit amplitude).

In [88]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import struct
import time

# Data source selection
USE_REAL_DATA = True  # Set to True to use real binary data, False for synthetic data

# Real data path (only used if USE_REAL_DATA = True)
# NOTE(review): the value below looks like a placeholder - point it at an actual .bin file.
DATA_FILE_PATH = r"path_to_channel.bin"

# Signal parameters
SAMPLE_RATE = 51200  # Hz (samples per second)
BASE_FREQ = 1000.7   # Hz (frequency of the square wave; used to derive samples per cycle)

## Helper Functions to parse .bin & find zero crossings of signal

In [89]:
def read_binary_file(file_path):
    """Read a raw binary file of little-endian 32-bit floats.

    The whole file is interpreted as a flat sequence of 4-byte little-endian
    floats; no header is skipped. Trailing bytes that do not form a complete
    4-byte sample are ignored instead of aborting the load.

    Args:
        file_path: Path of the binary file to read.

    Returns:
        1-D ``numpy.ndarray`` of dtype float64 containing the decoded samples.
    """
    with open(file_path, 'rb') as file:
        buffer = file.read()

    # Only whole samples can be decoded.
    num_samples = len(buffer) // 4

    # Decode directly from the byte buffer rather than materialising a Python
    # tuple with one float object per sample, then widen to float64 so the
    # result has the same dtype np.array() gives for Python floats.
    data = np.frombuffer(buffer, dtype='<f4', count=num_samples).astype(np.float64)
    print(f"Total samples read: {len(data):,}")
    return data

def find_zero_crossings(data, falling=True):
    """Find zero crossings and refine them with linear interpolation.

    A crossing is detected between consecutive samples ``i`` and ``i + 1``
    whose signs differ strictly (positive then negative for ``falling``,
    negative then positive otherwise). Samples that are exactly zero are on
    neither side, so a transition that passes through an exact zero sample
    is not reported.

    Args:
        data: 1-D sequence or array of samples.
        falling: If True, detect positive-to-negative crossings; otherwise
            detect negative-to-positive crossings.

    Returns:
        1-D float array of fractional sample positions ``i + frac`` where the
        straight line between the two samples crosses zero.
    """
    # Accept plain sequences as well as arrays.
    data = np.asarray(data, dtype=float)
    before, after = data[:-1], data[1:]

    if falling:
        sign_change = (before > 0) & (after < 0)
    else:
        sign_change = (before < 0) & (after > 0)
    crossings = np.flatnonzero(sign_change)

    # Linear interpolation for all crossings at once. The strict sign test
    # guarantees y1 != y2, so the division is always well defined.
    y1 = before[crossings]
    y2 = after[crossings]
    return crossings + (-y1 / (y2 - y1))

## Load or Generate Data

In [90]:
# Load or generate data based on configuration
if USE_REAL_DATA:
    print("Loading real data...")
    data = read_binary_file(DATA_FILE_PATH)
else:
    print("Generating synthetic data...")
    duration = 60.0  # seconds
    noise_std = 0.1  # standard deviation of the additive Gaussian noise
    t = np.arange(int(duration * SAMPLE_RATE)) / SAMPLE_RATE  # sample times in seconds
    # Square wave at BASE_FREQ: the sign of a sine, so values are -1, 0 or +1
    clean_signal = np.sign(np.sin(2 * np.pi * BASE_FREQ * t))
    # The RNG is not seeded here, so the synthetic data differs between runs
    data = clean_signal + np.random.normal(0, noise_std, len(clean_signal))

Loading real data...
Total samples read: 15,677,440


## Plot a few cycles of the real or synthetic waveform for QA/QC

In [91]:
def plot_raw_signal(data, sample_rate, base_freq, num_cycles=5):
    """Plot the first few cycles of the raw signal using Plotly.

    Args:
        data: 1-D array of samples.
        sample_rate: Sampling rate in Hz.
        base_freq: Signal frequency in Hz; together with ``sample_rate`` it
            gives the (truncated) number of samples per cycle.
        num_cycles: Number of cycles to show. Defaults to 5.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    samples_per_cycle = int(sample_rate / base_freq)
    segment = data[:samples_per_cycle * num_cycles]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        # Size the x axis from the slice actually taken so x and y stay the
        # same length even when data is shorter than the requested window.
        x=np.arange(len(segment)),
        y=segment,
        mode='lines',
        name='Raw Signal',
        line=dict(color='blue')
    ))

    fig.update_layout(
        title=f'First {num_cycles} Cycles of Raw Signal',
        xaxis_title='Sample Index',
        yaxis_title='Amplitude',
        showlegend=True,
        width=1000,
        height=500,
        hovermode='x unified'
    )

    fig.show()

# Plot raw signal as a visual sanity check of the loaded or generated waveform
plot_raw_signal(data - np.mean(data), SAMPLE_RATE, BASE_FREQ)  # Remove DC offset for better visualization

## Identifying cycles and stacking 

In [92]:
def extract_cycle(data, zero_crossing_idx, expected_samples):
    """Cut out the samples surrounding one zero crossing.

    Starting from the given crossing, the nearest sign change (product of
    neighbouring samples <= 0) is searched on each side within 60% of
    ``expected_samples``. The slice returned lies strictly between those
    two sign changes, dropping one sample at each end.

    Args:
        data: 1-D array of samples.
        zero_crossing_idx: Position (possibly fractional) of the reference
            crossing; it is truncated to an integer index.
        expected_samples: Nominal number of samples in one cycle.

    Returns:
        The slice of ``data`` between the two neighbouring sign changes, or
        None if either neighbour is not found or the slice would hold fewer
        than two samples.
    """
    center = int(zero_crossing_idx)
    reach = int(expected_samples * 0.6)

    # Walk backwards from just before the reference crossing.
    backward = range(center - 1, max(0, center - reach), -1)
    left = next((k for k in backward if data[k] * data[k - 1] <= 0), None)
    if left is None:
        return None

    # Walk forwards from just after the reference crossing.
    forward = range(center + 1, min(len(data) - 1, center + reach))
    right = next((k for k in forward if data[k] * data[k + 1] <= 0), None)
    if right is None:
        return None

    first, last = left + 1, right - 1
    # `last` is always inside the array because `right` <= len(data) - 2.
    if first >= last:
        return None
    return data[first:last + 1]

def stack_cycles(data, sample_rate=51200, base_freq=1000.7):
    """Average many cycles of a periodic signal to reduce noise.

    The mean is removed from ``data``, falling zero crossings are located,
    and crossings whose distance to the next one is within +/-5% of the
    expected cycle length are kept. Around each kept crossing a cycle is cut
    out with ``extract_cycle``, linearly resampled to the expected length and
    folded into a running average.

    Args:
        data: 1-D array of samples.
        sample_rate: Sampling rate in Hz.
        base_freq: Signal frequency in Hz.

    Returns:
        Tuple ``(stacked, intermediate)``. ``stacked`` is the averaged cycle
        with ``int(sample_rate / base_freq)`` samples, or None if no cycle
        could be extracted. ``intermediate`` is a list of
        ``(num_cycles, averaged_cycle_copy)`` snapshots taken at fixed cycle
        counts.
    """
    data = data - np.mean(data)  # Remove DC offset
    expected_samples = int(sample_rate / base_freq)
    print(f"Expected samples per cycle: {expected_samples}")

    # Find candidate cycles
    falling_edges = np.asarray(find_zero_crossings(data, falling=True))
    print(f"Found {len(falling_edges):,} falling zero crossings")

    # Keep edges whose distance to the next edge is close to the expected
    # cycle length (vectorised; the last edge has no successor).
    spacing = np.diff(falling_edges)
    in_tolerance = (
        (spacing >= 0.95 * expected_samples)
        & (spacing <= 1.05 * expected_samples)
    )
    valid_edges = falling_edges[:-1][in_tolerance]

    print(f"Number of valid cycles found: {len(valid_edges):,}")

    # Loop invariants: snapshot counts and the common resampling grid.
    snapshot_counts = {5, 10, 50, 100, 1000, 10000, 50000}
    x_resampled = np.linspace(0, 1, expected_samples)

    stacked = None
    num_cycles = 0
    intermediate = []
    # perf_counter is monotonic, so elapsed times cannot go backwards if the
    # wall clock is adjusted.
    process_start = time.perf_counter()

    for edge in valid_edges:
        cycle = extract_cycle(data, edge, expected_samples)
        if cycle is None:
            continue

        # Resample to expected length before stacking
        x = np.linspace(0, 1, len(cycle))
        cycle_resampled = np.interp(x_resampled, x, cycle)

        if stacked is None:
            stacked = cycle_resampled
        else:
            # Incremental update of the running mean
            stacked = (stacked * num_cycles + cycle_resampled) / (num_cycles + 1)
        num_cycles += 1

        # Save intermediate results for visualization
        if num_cycles in snapshot_counts:
            intermediate.append((num_cycles, stacked.copy()))

        if num_cycles % 10000 == 0:
            elapsed = time.perf_counter() - process_start
            print(f"Processed {num_cycles:,} cycles ({elapsed:.1f}s)")

    print("\nProcessing complete:")
    print(f"Total cycles processed: {num_cycles:,}")
    print(f"Processing time: {time.perf_counter() - process_start:.2f}s")

    return stacked, intermediate

# Stack cycles and get intermediate results.
# stacked_cycle is the final averaged cycle (None if no cycle was stacked);
# intermediate_results holds (num_cycles, averaged_cycle) snapshots.
stacked_cycle, intermediate_results = stack_cycles(data, SAMPLE_RATE, BASE_FREQ)

Expected samples per cycle: 51
Found 308,348 falling zero crossings
Number of valid cycles found: 308,345
Processed 10,000 cycles (0.3s)
Processed 20,000 cycles (0.6s)
Processed 30,000 cycles (1.0s)
Processed 40,000 cycles (1.3s)
Processed 50,000 cycles (1.6s)
Processed 60,000 cycles (1.9s)
Processed 70,000 cycles (2.3s)
Processed 80,000 cycles (2.7s)
Processed 90,000 cycles (3.1s)
Processed 100,000 cycles (3.5s)
Processed 110,000 cycles (3.8s)
Processed 120,000 cycles (4.1s)
Processed 130,000 cycles (4.4s)
Processed 140,000 cycles (4.7s)
Processed 150,000 cycles (5.1s)
Processed 160,000 cycles (5.4s)
Processed 170,000 cycles (5.7s)
Processed 180,000 cycles (6.0s)
Processed 190,000 cycles (6.4s)
Processed 200,000 cycles (6.7s)
Processed 210,000 cycles (7.0s)
Processed 220,000 cycles (7.3s)
Processed 230,000 cycles (7.7s)
Processed 240,000 cycles (8.0s)
Processed 250,000 cycles (8.3s)
Processed 260,000 cycles (8.6s)
Processed 270,000 cycles (8.9s)
Processed 280,000 cycles (9.3s)
Process

### Stacking results for different numbers of cycles

In [93]:
def plot_stacking_evolution(intermediate_results, is_real_data=True):
    """Overlay the averaged cycle obtained after different cycle counts.

    Args:
        intermediate_results: Iterable of ``(num_cycles, averaged_cycle)``
            pairs; each pair becomes one trace.
        is_real_data: Selects the "Real" or "Synthetic" label in the title.

    Returns:
        None. The figure is displayed with ``show()``.
    """
    if is_real_data:
        source_label = "Real"
    else:
        source_label = "Synthetic"

    figure = go.Figure()
    for cycle_count, averaged in intermediate_results:
        trace = go.Scatter(
            x=np.arange(len(averaged)),
            y=averaged,
            name=f'{cycle_count:,} cycles',
            opacity=0.7,
        )
        figure.add_trace(trace)

    # Legend sits just outside the right edge of the plotting area.
    legend_position = dict(yanchor="top", y=0.99, xanchor="left", x=1.05)
    figure.update_layout(
        title=f'Stacked Signal ({source_label} Data)',
        xaxis_title='Sample Index',
        yaxis_title='Amplitude',
        showlegend=True,
        width=1000,
        height=600,
        hovermode='x unified',
        legend=legend_position,
    )

    figure.show()

# Plot stacking evolution: one trace per saved snapshot of the running average
plot_stacking_evolution(intermediate_results, USE_REAL_DATA)

## Ideal square wave and filter generation

In [86]:
# Resample stacked cycle to 2048 points by linear interpolation on a
# normalized [0, 1] axis (requires stacked_cycle to be an array, not None)
x_orig = np.linspace(0, 1, len(stacked_cycle))
x_resampled = np.linspace(0, 1, 2048)
resampled_cycle = np.interp(x_resampled, x_orig, stacked_cycle)

def create_perfect_square(resampled_cycle):
    """Build a two-level square wave aligned with the given cycle.

    The first strictly positive-to-negative transition splits the cycle in
    two; if there is none, the midpoint is used. Each part of the output is
    set to the mean of the corresponding part of the input. The two levels
    are printed.

    Args:
        resampled_cycle: 1-D array holding one cycle of the signal.

    Returns:
        Array of the same length as ``resampled_cycle`` containing the high
        level before the split index and the low level from it onwards.
    """
    n_points = len(resampled_cycle)

    # Index of the first sample after the positive-to-negative transition,
    # falling back to the midpoint when no such transition exists.
    split = next(
        (
            k for k in range(1, n_points)
            if resampled_cycle[k - 1] > 0 and resampled_cycle[k] < 0
        ),
        n_points // 2,
    )

    # Plateau levels are the averages on either side of the split.
    high_value = np.mean(resampled_cycle[:split])
    low_value = np.mean(resampled_cycle[split:])

    # Start from the low level everywhere, then raise the leading part.
    square = np.full(n_points, low_value)
    square[:split] = high_value

    print(f"  top of sqr wave: {high_value:.6f}")
    print(f"  bottom of sqr wave: {low_value:.6f}")

    return square

def generate_fir_filter(signal, target, reg_param=0.00001):
    """Generate FIR filter using regularized least squares solution.

    Builds the square matrix ``A`` whose row ``i`` is ``signal`` circularly
    shifted by ``i`` samples (``np.roll(signal, i)``) and solves the
    regularized normal equations ``(A^T A + lam * I) h = A^T target``, where
    ``lam`` is ``reg_param`` times the mean diagonal entry of ``A^T A``.

    Args:
        signal: 1-D input signal of length n (array or sequence).
        target: 1-D desired output of length n.
        reg_param: Regularization strength relative to the mean diagonal of
            ``A^T A``.

    Returns:
        1-D array of n filter coefficients.
    """
    signal = np.asarray(signal)
    # Integer input would make A^T A an integer matrix, and the in-place
    # addition of the floating-point regularization term below would then
    # fail with a casting error.
    if not np.issubdtype(signal.dtype, np.inexact):
        signal = signal.astype(np.float64)

    n = len(signal)
    shifts = np.arange(n)
    # A[i, j] = signal[(j - i) mod n], i.e. row i equals np.roll(signal, i),
    # built with one fancy-indexing operation instead of n separate rolls.
    A = signal[(shifts[np.newaxis, :] - shifts[:, np.newaxis]) % n]

    ATA = A.T @ A
    ATb = A.T @ target
    ATA += np.eye(n) * np.mean(np.diag(ATA)) * reg_param
    return np.linalg.solve(ATA, ATb)

def apply_filter(signal, coeffs):
    """Apply FIR filter coefficients to a signal using circular shifts.

    Output sample ``i`` is the dot product of ``coeffs`` with ``signal``
    circularly shifted by ``i`` samples.

    Args:
        signal: 1-D input signal.
        coeffs: 1-D filter coefficients with the same length as ``signal``.

    Returns:
        1-D float64 array with one output sample per input sample.
    """
    shifted_products = (
        np.dot(coeffs, np.roll(signal, shift))
        for shift in range(len(signal))
    )
    return np.fromiter(shifted_products, dtype=np.float64)

# Generate perfect square reference (the target the FIR filter is fitted to)
perfect_square = create_perfect_square(resampled_cycle)

  top of sqr wave: 1.264122
  bottom of sqr wave: -1.265880


### Compare Different Filter Parameters

In [94]:
def plot_filter_comparison(resampled_cycle, perfect_square, is_real_data=True):
    """Compare filtering results for different regularization (stabilization) parameters.

    For each regularization value an FIR filter is fitted with
    ``generate_fir_filter`` and applied with ``apply_filter``, and the mean
    squared error against ``perfect_square`` is computed. Two figures are
    shown: the coefficients per regularization value, and the original signal
    with its target next to the filtered results. The coefficients with the
    lowest MSE are written to 'fir_coefficients.txt'.

    Args:
        resampled_cycle: 1-D array holding the signal to be filtered.
        perfect_square: 1-D array of the same length with the target shape.
        is_real_data: Selects the "Real" or "Synthetic" label in the titles.

    Returns:
        None.
    """
    data_type = "Real" if is_real_data else "Synthetic"
    reg_params = [0.1, 0.01, 0.001, 0.00001, 0.000001]
    coeff_sets = []
    filtered_signals = []
    mse_values = []

    # Plot filter coefficients
    fig1 = go.Figure()
    for reg_param in reg_params:
        fir_coeffs = generate_fir_filter(resampled_cycle, perfect_square, reg_param)
        # Keep every solution so the best one does not need a second solve.
        coeff_sets.append(fir_coeffs)
        fig1.add_trace(go.Scatter(
            x=np.arange(len(fir_coeffs)),
            y=fir_coeffs,
            name=f'λ={reg_param}',
            opacity=0.7
        ))

        filtered = apply_filter(resampled_cycle, fir_coeffs)
        filtered_signals.append(filtered)
        mse_values.append(np.mean((filtered - perfect_square) ** 2))

    fig1.update_layout(
        title=f'FIR Filter Coefficients for Different Regularization Parameters ({data_type} Data)',
        xaxis_title='Coefficient Index',
        yaxis_title='Amplitude',
        width=1000,
        height=400,
        hovermode='x unified'
    )
    fig1.show()

    # Plot filtering results
    fig2 = make_subplots(rows=2, cols=1, vertical_spacing=0.1,
                         subplot_titles=[f'Original Signal vs Target ({data_type} Data)',
                                         f'Filtered Results ({data_type} Data)'])

    # Original vs Target
    fig2.add_trace(go.Scatter(x=np.arange(len(resampled_cycle)), y=resampled_cycle,
                              name='Original Signal', opacity=0.7, line=dict(color='blue')), row=1, col=1)
    fig2.add_trace(go.Scatter(x=np.arange(len(perfect_square)), y=perfect_square,
                              name='Target Square Wave', opacity=0.7, line=dict(color='black', dash='dash')), row=1, col=1)

    # Filtered results
    for filtered, reg_param, mse in zip(filtered_signals, reg_params, mse_values):
        fig2.add_trace(go.Scatter(x=np.arange(len(filtered)), y=filtered,
                                  name=f'λ={reg_param} (MSE={mse:.6f})',
                                  opacity=0.7), row=2, col=1)

    fig2.update_layout(
        height=800,
        width=1000,
        showlegend=True,
        hovermode='x unified'
    )
    fig2.show()

    # Pick the regularization value with the lowest MSE against the target
    best_idx = int(np.argmin(mse_values))
    best_coeffs = coeff_sets[best_idx]
    print(f"len of fir filter coefficients: {len(best_coeffs)}")

    # Save filter coefficients for the best stabilization parameter
    np.savetxt('fir_coefficients.txt', best_coeffs)
    print("\nBest FIR filter coefficients saved to 'fir_coefficients.txt'")

# Plot filter comparison and save the best coefficients to 'fir_coefficients.txt'
plot_filter_comparison(resampled_cycle, perfect_square, USE_REAL_DATA)

len of fir filter coefficients: 2048

Best FIR filter coefficients saved to 'fir_coefficients.txt'
