# In [None]:
import os
import numpy as np
from scipy.signal import butter, filtfilt
import gc 


# ---- Filter / processing configuration --------------------------------------
FS = 500_000_000.0       # sampling rate given to the filter design (presumably Hz)
CUTOFF = 30_000_000.0    # low-pass corner frequency, same units as FS
ORDER = 6                # Butterworth filter order
CHUNK_SIZE = 1_000       # number of traces (rows) filtered per filtfilt call

# (path, invert) pairs to process. When invert is True the stored trace is
# 10 - filtered instead of filtered. Uncomment entries to process more files.
TRACE_FILES = [
    # ("./traces/trace_A_common_20000.npy", False),
    # ("./traces/trace_B_common_20000.npy", False),
    # ("./traces/trace_A_a_20000.npy", False),
    # ("./traces/trace_B_b_20000.npy", False),
    # ("./traces/trace_C_c_20000.npy", True),       # Invert
    ("./traces/trace_D_d_20000.npy", False),
]

def get_filter_coeffs(cutoff, fs, order):
    """Design a digital low-pass Butterworth filter.

    Args:
        cutoff: Corner frequency, in the same units as ``fs``.
        fs: Sampling rate.
        order: Filter order.

    Returns:
        Tuple ``(b, a)`` of numerator / denominator polynomial coefficients,
        as produced by ``scipy.signal.butter``.
    """
    # butter() expects the cutoff normalized to the Nyquist frequency (fs / 2).
    wn = cutoff / (fs * 0.5)
    numerator, denominator = butter(order, wn, btype='low', analog=False)
    return numerator, denominator

b_coeff, a_coeff = get_filter_coeffs(CUTOFF, FS, ORDER)


def _filter_trace_file(src_path, dst_path, invert, b, a, chunk_size):
    """Zero-phase filter every trace in ``src_path`` and save them to ``dst_path``.

    The source ``.npy`` is memory-mapped read-only and filtered ``chunk_size``
    rows at a time along the last axis, so the whole array never has to be
    loaded into RAM. Results are written as float32 into a temporary
    ``dst_path + ".part"`` file. That file is renamed onto ``dst_path`` only
    after every chunk has been written and flushed. An interrupted or failed
    run therefore never leaves a partial file that a later run would skip as
    "Already exists".

    Args:
        src_path: Path to a 2-D ``.npy`` array (assumed rows = traces,
            columns = samples; TODO confirm against the capture code).
        dst_path: Destination ``.npy`` path.
        invert: If true, store ``10 - filtered`` instead of ``filtered``.
        b, a: IIR filter numerator / denominator coefficients for ``filtfilt``.
        chunk_size: Number of rows filtered per ``filtfilt`` call (must be > 0).

    Returns:
        True on success, False if the source could not be loaded, is not 2-D,
        or the output file could not be created (a message is printed).

    Raises:
        Any exception from ``filtfilt`` (e.g. ``ValueError`` for traces
        shorter than its pad length) or from writing the output. The
        temporary file is removed before the exception propagates.
    """
    try:
        raw_data = np.load(src_path, mmap_mode='r')
    except Exception as e:
        print(f"File load failed: {e}")
        return False

    if raw_data.ndim != 2:
        print(f"File load failed: expected a 2-D array, got shape {raw_data.shape}")
        return False
    total_traces, trace_len = raw_data.shape

    tmp_path = dst_path + ".part"
    try:
        # open_memmap writes to exactly this filename (no ".npy" is appended).
        output_mmap = np.lib.format.open_memmap(
            tmp_path,
            mode='w+',
            dtype=np.float32,
            shape=(total_traces, trace_len),
        )
    except Exception as e:
        print(f"File save failed: {e}")
        return False

    num_chunks = -(-total_traces // chunk_size)  # ceiling division
    completed = False
    try:
        for chunk_no, start_idx in enumerate(range(0, total_traces, chunk_size), start=1):
            end_idx = min(start_idx + chunk_size, total_traces)

            filtered_batch = filtfilt(b, a, raw_data[start_idx:end_idx], axis=-1)
            if invert:
                filtered_batch = 10 - filtered_batch

            # Assignment into the float32 memmap performs the dtype cast.
            output_mmap[start_idx:end_idx] = filtered_batch

            if chunk_no % 5 == 0 or chunk_no == num_chunks:
                print(f"   -> Processing...: {end_idx}/{total_traces}")

        output_mmap.flush()
        completed = True
    finally:
        # Drop both mappings before renaming/removing files; on Windows an
        # open mapping blocks os.replace / os.remove.
        del output_mmap
        del raw_data
        gc.collect()
        if completed:
            os.replace(tmp_path, dst_path)
        elif os.path.exists(tmp_path):
            os.remove(tmp_path)
    return True


for path, is_c in TRACE_FILES:
    if not os.path.exists(path):
        print(f"File not found: {path}")
        continue

    dir_name, file_name = os.path.split(path)
    save_path = os.path.join(dir_name, "filtered_" + file_name)

    if os.path.exists(save_path):
        print(f"Already exists: {save_path}")
        continue

    # is_c marks files whose filtered traces are stored inverted (10 - x).
    _filter_trace_file(path, save_path, is_c, b_coeff, a_coeff, CHUNK_SIZE)
    